世界树

 

如果不会建立虚树的话可以参考这个

题意

我们得到一棵树。这些节点中会定期选出几个关键点。对于所有的点他们受到距离他们最小的点且编号最小的点的控制,求最后每个关键点都会控制多少个点。

思路

我们不难想到一种树上 DP 的方式,首先先从子树向上传递,然后再由父节点向子树传递,就可以传递成功了。

但由于我们的询问轮数太多了,我们根本来不及一个一个传递。那么对于一些点我们建立一颗虚树。

我们记该点受 g[u] 点控制,其距离为 d[u]

我们拷贝t=size[x] 并遍历点 x 的所有子树。

  1. xy 表示虚树上的两个节点,且 xy 两个节点受到的控制点不一样。那么我们计算得出两个控制点之间的分界点 z(z 受到 x 控制)。
    alt

然后我们将这个分为三个部分
alt
接下来我们将 t -= size[z] , 然后将第二部分的值加上去f[g[y]] += size[z] - size[y] ,最后在计算节点 y 的值。
我们在遍历完这个点之后,f[g[x]] += t

  1. xy 受到同一个节点控制,情况较为简单不在赘述。

那么最后体现在代码中就是。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#include <algorithm>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;

const int N = 3e5 + 10;

int he[N], nxt[N * 2], v[N * 2], id = 0;
void add(int x, int y) {
    v[id] = y;
    nxt[id] = he[x];
    he[x] = id++;
}

int dpn[N], sz[N], deps[N], pre[N][20], dfn[N];
void get_dpn(int u, int fa = -1, int deep = 1) {
    dfn[u] = id++;
    dpn[u] = deep;
    deps[deep] = u;
    sz[u] = 1;
    for (int i = 0; i < 20; i++) {
        if (deep - (1 << i) < 0) break;
        pre[u][i] = deps[deep - (1 << i)];
    }

    for (int i = he[u]; ~i; i = nxt[i]) {
        int x = v[i];
        if (x == fa) continue;
        get_dpn(x, u, deep + 1);
        sz[u] += sz[x];
    }
}

int lca(int x, int y) {
    if (dpn[x] > dpn[y]) swap(x, y);
    int kki = dpn[y] - dpn[x];
    for (int i = 0; kki; kki >>= 1, i++) {
        if (kki & 1) y = pre[y][i];
    }
    if (x == y) return x;
    for (int i = 19; i >= 0; i--) {
        if (pre[x][i] != pre[y][i]) {
            x = pre[x][i];
            y = pre[y][i];
        }
    }
    return pre[x][0];
}

vector<int> E[N];
bool is[N];

int stk[N], top;
void insert(int kki) {
    int x = stk[top];
    int f = lca(x, kki);
    if (f == x) {
        stk[++top] = kki;
        return;
    }
    int y = stk[top - 1];

    while (dfn[y] > dfn[f]) {
        E[y].push_back(x);
        top--;
        x = stk[top];
        y = stk[top - 1];
    }
    if (dfn[y] == dfn[f]) {
        E[y].push_back(x);
        stk[top] = kki;
    } else if (dfn[y] < dfn[f]) {
        E[f].push_back(x);
        stk[top] = f;
        stk[++top] = kki;
    }
}

int seq[N];
int org[N];

int g[N], mn[N];
void dfs1(int u) {
    if (is[u]) {
        mn[u] = 0;
        g[u] = u;
    } else
        mn[u] = N;
    for (auto x : E[u]) {
        dfs1(x);
        if (mn[x] + dpn[x] - dpn[u] < mn[u]) {
            mn[u] = mn[x] + dpn[x] - dpn[u];
            g[u] = g[x];
        } else if (mn[x] + dpn[x] - dpn[u] == mn[u]) {
            g[u] = min(g[u], g[x]);
        }
    }
}

int ans[N];

void dfs2(int u) {
    for (auto x : E[u]) {
        if (mn[u] + dpn[x] - dpn[u] < mn[x]) {
            mn[x] = mn[u] + dpn[x] - dpn[u];
            g[x] = g[u];
        } else if (mn[u] + dpn[x] - dpn[u] == mn[x]) {
            g[x] = min(g[x], g[u]);
        }
        dfs2(x);
    }
}

int calc(int x, int y) {
    if (g[x] == g[y]) return sz[y];
    int yyyy = y;
    int deep = dpn[y] - dpn[x];
    int kki = deep + mn[x] - mn[y];
    if (((kki & 1) == 0) && g[x] < g[y]) {
        kki >>= 1;
        kki -= 1;
    } else
        kki >>= 1;
    for (int i = 0; kki; i++, kki >>= 1) {
        if (kki & 1) y = pre[y][i];
    }
    ans[g[yyyy]] += sz[y] - sz[yyyy];
    return sz[y];
}

void dfs3(int u) {
    int t = sz[u];
    for (auto x : E[u]) {
        t -= calc(u, x);
        dfs3(x);
    }
    ans[g[u]] += t;
    E[u].clear();
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    memset(he, 0xff, sizeof(he));
    int n;
    cin >> n;

    for (int i = 1; i < n; i++) {
        int x, y;
        cin >> x >> y;
        add(x, y);
        add(y, x);
    }

    id = 1;
    get_dpn(1);
    stk[0] = 1;

    int m;
    cin >> m;

    while (m--) {
        int time;
        cin >> time;
        for (int i = 0; i < time; i++) cin >> seq[i], org[i] = seq[i];
        sort(seq, seq + time, [](int x, int y) { return dfn[x] < dfn[y]; });
        stk[0] = 1;
        top = 0;
        for (int i = (seq[0] == 1) ? 1 : 0; i < time; i++) insert(seq[i]);
        while (top) {
            E[stk[top - 1]].push_back(stk[top]);
            top--;
        }

        for (int i = 0; i < time; i++) is[seq[i]] = true;
        dfs1(1);
        dfs2(1);
        dfs3(1);
        for (int i = 0; i < time; i++) {
            cout << ans[org[i]] << ' ';
            is[org[i]] = false;
            ans[org[i]] = 0;
        }
        cout << '\n';
    }

    return 0;
}
algorithm