1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
#include <bits/stdc++.h>
// Competitive-programming shortcut: every `int` below is really 64-bit.
// That is why main() is declared `signed main()` and why `%lld` is used
// when reading variables declared as `int`.
#define int long long
using namespace std;
typedef long long ll;
// Array bound. Node-indexed arrays must hold 2n-1 nodes: n vertices plus one
// extra node n+i per input edge, created in main.
const int N = 200000 + 10;
ll n, q;                 // number of vertices, number of queries
ll f[N][20];             // f[u][k]: 2^k-th ancestor of u (0 when absent), filled by dfs1
ll col[N];               // colour stored on node (set only for edge nodes n+i)
ll len[N];               // length stored on node (set only for edge nodes n+i)
ll sum[N];               // sum of len[] on the root..node path
ll d[N];                 // depth (root = 1)
ll sz[N];                // subtree size
ll son[N];               // heavy child (0 for a leaf)
ll pos[N];               // pos[k]: node whose HLD position is k
ll top[N];               // head of the heavy chain containing the node
ll id[N];                // HLD position of the node (1-based)
ll df;                   // running HLD position counter used by dfs2
vector<ll> g[N];         // adjacency lists
// bels[0][c]: increasing HLD positions of colour-c nodes;
// bels[1][c]: matching lengths, turned into prefix sums in main.
vector<ll> bels[2][N];
// Records an undirected edge a-b in both adjacency lists.
void link(ll u, ll v) {
    ll a = u, b = v;
    g[a].push_back(b);
    g[b].push_back(a);
}
void dfs1(ll u, ll pre) { f[u][0] = pre; d[u] = d[pre] + 1; sum[u] = sum[pre] + len[u]; sz[u] = 1; for(int i = 1; i <= 18; i++) { if(f[f[u][i - 1]][i - 1]) f[u][i] = f[f[u][i - 1]][i - 1]; else break; } for(int i = 0; i < g[u].size(); i++) { int v = g[u][i]; if(v == pre) continue; dfs1(v, u); sz[u] += sz[v]; if(!son[u] || sz[v] > sz[son[u]]) son[u] = v; } } int lca(int u, int v) { if(d[u] < d[v]) swap(u, v); for(int i = 19; i >= 0; i--) if(d[u] - (1 << i) >= d[v]) u = f[u][i]; if(u == v) return u; for(int i = 19; i >= 0; i--) if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i]; return f[u][0]; } void dfs2(ll u, ll t) { top[u] = t, id[u] = ++df; pos[df] = u; if(son[u]) dfs2(son[u], t); for(int i = 0; i < g[u].size(); i++) { int v = g[u][i]; if(v == f[u][0] || v == son[u]) continue; dfs2(v, v); } }
ll Bs1(int x, int t) { int l = 0, r = bels[0][x].size() - 1; ll ans = -1; while(l <= r) { int mid = (l + r) >> 1; if(bels[0][x][mid] <= t) ans = mid, l = mid + 1; else r = mid - 1; } return ans; } ll BinarySearch(int x, int t) { int l = 0, r = bels[1][x].size() - 1; ll ans = -1; while(l <= r) { int mid = (l + r) >> 1; if(bels[0][x][mid] <= t) ans = mid, l = mid + 1; else r = mid - 1; } if(ans == -1) return 0; return bels[1][x][ans]; } pair <ll, ll> bsearch(int u, int v, int po) { if(bels[0][po].empty()) return make_pair(0, 0); ll A = Bs1(po, v) - Bs1(po, u - 1); ll C = BinarySearch(po, v), D = BinarySearch(po, u - 1); return make_pair(A, C - D); } pair <ll, ll> Query(int u, int v, int c) { pair <ll, ll> t; t.first = 0, t.second = 0; while(top[u] != top[v]) { if(d[top[u]] < d[top[v]]) swap(u, v); pair <ll, ll> te = bsearch(id[top[u]], id[u], c); t.first += te.first, t.second += te.second; u = f[top[u]][0]; } if(id[u] > id[v]) swap(u, v); pair <ll, ll> te = bsearch(id[u], id[v], c); t.first += te.first, t.second += te.second; return t; } signed main(){ scanf("%lld%lld", &n, &q); for(ll i = 1, u, v, c, d; i < n; i++) { scanf("%lld%lld%lld%lld", &u, &v, &c, &d); link(u, n + i); link(v, n + i); col[n + i] = c, len[n + i] = d; } dfs1(1, 0); dfs2(1, 1); for(int i = 1; i <= 2 * n - 1; i++) bels[0][col[pos[i]]].push_back(i),bels[1][col[pos[i]]].push_back(len[pos[i]]); for(int i = 1; i <= n - 1; i++) for(int j = 1; j < bels[1][i].size(); j++) bels[1][i][j] += bels[1][i][j - 1];
while(q--) { int x, y, u, v; scanf("%lld%lld%lld%lld", &x, &y, &u, &v); pair <ll, ll> ans = Query(u, v, x); ll ans1 = sum[u] + sum[v] - 2 * sum[lca(u, v)]; ll ans2 = - ans.second + 1ll * y * ans.first; printf("%lld\n", ans1 + ans2); } return 0; }
|