문제

https://www.acmicpc.net/problem/13511


알고리즘

HLD


풀이

첫번째 쿼리는 두 정점 사이의 간선가중치의 합, 두번째 쿼리는 k번째 정점을 구하는 문제입니다.

트리에서는 루트정점을 제외하고는 모든 정점의 부모정점을 유일합니다. 이를 이용하면 트리의 리프노드마다 간선의 가중치를 넣을 수 있습니다.

u와 v의 lca를 a라고 가정하겠습니다. k번째 정점을 구하기 위해서는 그 정점이 u와 a사이에 있는지 a와 v사이에 있는지를 일차적으로 확인해야합니다. u와 a사이의 정점의 갯수가 cnt개라고 할 때, $k\leq cnt$ 라면 u로 부터 k번 위로 올라가면 되고 a와 v사이에 k번째 정점이 존재한다면 (전체 정점의 갯수 -k+1)번 위로 올라가면 k번째 정점을 구할 수 있습니다.

코드

#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < n; ++i)
#define REP(i, n) for (int i = 1; i <= n; ++i)
using namespace std;

#define int long long

const int MAXN = 1e5;
#define left (i << 1)
#define right (i << 1 | 1)
int N, M, cnt;

vector<int> adj[MAXN];
vector<tuple<int, int, int>> edge;

int p[MAXN], idx[MAXN], st[MAXN], top[MAXN], dep[MAXN], tree[MAXN<<1], rev[MAXN];

void dfs(int here) {
    for (auto& there : adj[here]) {
        if (there ^ p[here]) {
            p[there] = here;
            dep[there] = dep[here] + 1;
            dfs(there);
            st[here] += st[there];
            int& he = adj[here][0];
            if (he == p[here] || st[there] > st[he]) {
                swap(there, he);
            }
        }
    }
}

void hld(int here) {
    idx[here] = cnt;
    rev[cnt++] = here;
    for (auto there : adj[here]) {
        if (there ^ p[here]) {
            top[there] = there == adj[here][0] ? top[here] : there;
            hld(there);
        }
    }
}

void init() {
    for (auto [u, v, c] : edge) {
        if (p[u] == v)
            swap(u, v);
        tree[N + idx[v]] = c;
    }
    for (int i = N - 1; i > 0; --i) {
        tree[i] = tree[left] + tree[right];
    }
}

int query(int l, int r) {
    int ret = 0;
    l += N;
    r += N;
    for (; l <= r; l >>= 1, r >>= 1) {
        if (l & 1)
            ret += tree[l++];
        if (!(r & 1))
            ret += tree[r--];
    }

    return ret;
}

int pathSum(int u, int v) {
    int ret = 0;
    while (top[u] ^ top[v]) {
        if (st[top[u]] < st[top[v]])
            swap(u, v);
        ret += query(idx[top[v]], idx[v]);
        v = p[top[v]];
    }
    if (idx[u] < idx[v])
        swap(u, v);
    ret += query(idx[v] + 1, idx[u]);
    return ret;
}

int lca(int u, int v) {
    while (top[u] ^ top[v]) {
        if (st[top[u]] < st[top[v]])
            swap(u, v);
        v = p[top[v]];
    }
    return (idx[u] < idx[v] ? u : v);
}

int upkth(int u, int k) {
    if (k == 1) {
        return u;
    }
    int jump = idx[u] - idx[top[u]];
    if (jump == 0) {
        return upkth(p[u], k - 1);
    } else if (k >= jump + 1) {
        return upkth(top[u], k - jump);
    } else {
        return rev[idx[u] - k + 1];
    }
}

int kth(int u, int v, int k) {
    int LCA = lca(u, v);
    int cnt = dep[u] - dep[LCA] + 1;
    if (k <= cnt) {
        return upkth(u, k);
    } else {
        cnt += dep[v] - dep[LCA];

        return upkth(v, cnt - k + 1);
    }
}

signed main() {
#ifndef ONLINE_JUDGE
    freopen("in", "r", stdin);
    freopen("out", "w", stdout);
#endif
    cin.tie(NULL);
    cout.tie(NULL);
    ios::sync_with_stdio(false);

    cin >> N;
    fill(st, st + N, 1);
    rep(i, N - 1) {
        int u, v, c;
        cin >> u >> v >> c;
        --u, --v;
        adj[u].emplace_back(v);
        adj[v].emplace_back(u);
        edge.emplace_back(u, v, c);
    }
    dfs(0);
    hld(0);
    init();

    cin >> M;
    while (M--) {
        int a, u, v;
        cin >> a >> u >> v;
        --u, --v;
        if (a == 1) {
            cout << pathSum(u, v) << '\n';
        } else {
            int k;
            cin >> k;
            cout << kth(u, v, k) + 1 << '\n';
        }
    }
    return 0;
}