읽는데 약 3분
[백준 13511] 트리와 쿼리 2
문제
https://www.acmicpc.net/problem/13511
알고리즘
HLD
풀이
첫번째 쿼리는 두 정점 사이의 간선가중치의 합, 두번째 쿼리는 k번째 정점을 구하는 문제입니다.
트리에서는 루트정점을 제외하고는 모든 정점의 부모정점을 유일합니다. 이를 이용하면 트리의 리프노드마다 간선의 가중치를 넣을 수 있습니다.
u와 v의 lca를 a라고 가정하겠습니다. k번째 정점을 구하기 위해서는 그 정점이 u와 a사이에 있는지 a와 v사이에 있는지를 일차적으로 확인해야합니다. u와 a사이의 정점의 갯수가 cnt개라고 할 때, $k\leq cnt$ 라면 u로 부터 k번 위로 올라가면 되고 a와 v사이에 k번째 정점이 존재한다면 (전체 정점의 갯수 -k+1)번 위로 올라가면 k번째 정점을 구할 수 있습니다.
코드
#include <bits/stdc++.h>
#define rep(i, n) for (int i = 0; i < n; ++i)
#define REP(i, n) for (int i = 1; i <= n; ++i)
using namespace std;
#define int long long
const int MAXN = 1e5;
#define left (i << 1)
#define right (i << 1 | 1)
int N, M, cnt;
vector<int> adj[MAXN];
vector<tuple<int, int, int>> edge;
int p[MAXN], idx[MAXN], st[MAXN], top[MAXN], dep[MAXN], tree[MAXN<<1], rev[MAXN];
void dfs(int here) {
for (auto& there : adj[here]) {
if (there ^ p[here]) {
p[there] = here;
dep[there] = dep[here] + 1;
dfs(there);
st[here] += st[there];
int& he = adj[here][0];
if (he == p[here] || st[there] > st[he]) {
swap(there, he);
}
}
}
}
void hld(int here) {
idx[here] = cnt;
rev[cnt++] = here;
for (auto there : adj[here]) {
if (there ^ p[here]) {
top[there] = there == adj[here][0] ? top[here] : there;
hld(there);
}
}
}
void init() {
for (auto [u, v, c] : edge) {
if (p[u] == v)
swap(u, v);
tree[N + idx[v]] = c;
}
for (int i = N - 1; i > 0; --i) {
tree[i] = tree[left] + tree[right];
}
}
int query(int l, int r) {
int ret = 0;
l += N;
r += N;
for (; l <= r; l >>= 1, r >>= 1) {
if (l & 1)
ret += tree[l++];
if (!(r & 1))
ret += tree[r--];
}
return ret;
}
int pathSum(int u, int v) {
int ret = 0;
while (top[u] ^ top[v]) {
if (st[top[u]] < st[top[v]])
swap(u, v);
ret += query(idx[top[v]], idx[v]);
v = p[top[v]];
}
if (idx[u] < idx[v])
swap(u, v);
ret += query(idx[v] + 1, idx[u]);
return ret;
}
int lca(int u, int v) {
while (top[u] ^ top[v]) {
if (st[top[u]] < st[top[v]])
swap(u, v);
v = p[top[v]];
}
return (idx[u] < idx[v] ? u : v);
}
int upkth(int u, int k) {
if (k == 1) {
return u;
}
int jump = idx[u] - idx[top[u]];
if (jump == 0) {
return upkth(p[u], k - 1);
} else if (k >= jump + 1) {
return upkth(top[u], k - jump);
} else {
return rev[idx[u] - k + 1];
}
}
int kth(int u, int v, int k) {
int LCA = lca(u, v);
int cnt = dep[u] - dep[LCA] + 1;
if (k <= cnt) {
return upkth(u, k);
} else {
cnt += dep[v] - dep[LCA];
return upkth(v, cnt - k + 1);
}
}
signed main() {
#ifndef ONLINE_JUDGE
freopen("in", "r", stdin);
freopen("out", "w", stdout);
#endif
cin.tie(NULL);
cout.tie(NULL);
ios::sync_with_stdio(false);
cin >> N;
fill(st, st + N, 1);
rep(i, N - 1) {
int u, v, c;
cin >> u >> v >> c;
--u, --v;
adj[u].emplace_back(v);
adj[v].emplace_back(u);
edge.emplace_back(u, v, c);
}
dfs(0);
hld(0);
init();
cin >> M;
while (M--) {
int a, u, v;
cin >> a >> u >> v;
--u, --v;
if (a == 1) {
cout << pathSum(u, v) << '\n';
} else {
int k;
cin >> k;
cout << kth(u, v, k) + 1 << '\n';
}
}
return 0;
}
Read other posts