Doubeecat's Blog

“即便前路混沌,同她走过,才算人间。”

0%

CF gym 102201F Fruit tree 解题报告

CF gym 102201F Fruit tree

有一棵 $n$ 个节点的树,每个节点上有一个颜色,有 $q$ 次询问,每次询问给定两个点 $u,v$,要求你求出是否有一种颜色在 $u,v$ 的简单路径上出现超过一半次数。

$n,q \leq 2.5 \times 10^5$

Solution:

首先先考虑一下序列上的版本怎么做:

给你一个序列,每个位置有一个元素,多次询问,每次求出 $[l,r]$ 中是否有出现次数大于 $\frac{r-l+1}{2}$ 的颜色。

这个问题可以使用莫队在 $O(n\sqrt n)$ 的时间内解决,也可以通过主席树在 $O(n\log n)$ 的时间内解决。但是我们这里引入一种更加简单,并且时间复杂度仍然是 $O(n \log n)$的算法。

前置知识:摩尔投票法

在你知道了摩尔投票法之后,我们来理解一下以下推论:

将二元组 $(x_1,c_1),(x_2,c_2)$ (保证 $c_1 > c_2$) 合并的做法为:

若 $x_1 \not = {x_2}$,则合并后的二元组为 $(x_1,c_1-c_2)$

若 $x_1 = {x_2}$,则合并后的二元组为 $(x_1,c_1+c_2)$

证明十分显然,代码这么写:

1
2
3
4
5
6
7
8
9
10
11
pii merge(pii a,pii b) {
if (a.first != b.first) {
if (a.second >= b.second) {
return make_pair(a.first,a.second - b.second);
} else {
return make_pair(b.first,b.second - a.second);
}
} else {
return make_pair(a.first,a.second + b.second);
}
}

基于这个,我们可以使用倍增在 $O(n \log n)$ 的时间内预处理出第 $i$ 个元素到第 $i + 2 ^ j$ 个元素的代表颜色,然后对于每个区间,我们再用二进制拼凑法算出每个区间的 代表颜色

那么接下来的问题是,我们并不能保证每种代表颜色就一定是正确(即出现次数大于一半)的颜色。所以我们将每种颜色的点和询问单独提出来,然后按颜色处理。

对于每种颜色,我们先在树状数组上将该种颜色的点的前缀和全部 +1,然后直接查询对应区间内的颜色是否满足条件即可。时间复杂度 $O(n\log n)$

那么这个问题实际上就是把序列上问题上树了,把倍增换成树上倍增,在 DFS 序上建树状数组,然后按照序列上方法处理即可。

Code:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
#include <vector>
#include <cmath>
using namespace std;
#define ll long long
#define ri register int
#define pii pair<int,int>

char buf[1 << 20], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20, stdin), p1 == p2)?EOF: *p1++)
template <typename T> inline void read(T &t) {
ri v = getchar();T f = 1;t = 0;
while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
while (isdigit(v)) {t = t * 10 + v - 48;v = getchar();}
t *= f;
}
template <typename T,typename... Args> inline void read(T &t,Args&... args) {
read(t);read(args...);
}

const int N = 250001;
const int K = 21;

vector <int> G[N],c[N],c1[N];

int cnt,in[N],out[N],t,f[N][K],dep[N],a[N],n,q,st[N],ed[N],lca[N],ans[N],co[N],len[N];
pii col[N][K];

pii merge(pii a,pii b) {
if (a.first != b.first) {
if (a.second >= b.second) {
return make_pair(a.first,a.second - b.second);
} else {
return make_pair(b.first,b.second - a.second);
}
} else {
return make_pair(a.first,a.second + b.second);
}
}

struct BIT{
int qwq[N*2];

int lowbit(int x) {return (x & (-x));}

void modify(int x,int v) {
for (;x <= cnt;x += lowbit(x)) qwq[x] += v;
}

int query(int x) {
int ans = 0;
for (;x;x -= lowbit(x)) ans += qwq[x];
return ans;
}
}T;

void dfs(int x,int fa) {
col[x][0] = make_pair(a[x],1);
in[x] = ++cnt;
for (int i = 1;i <= t;++i) {
f[x][i] = f[f[x][i-1]][i-1];
}
for (int i = 1;i <= t;++i) {
col[x][i] = merge(col[x][i-1],col[f[x][i-1]][i-1]);
}
for (auto y:G[x]) {
if (y != fa) {
f[y][0] = x;
dep[y] = dep[x] + 1;
dfs(y,x);
}
}
out[x] = ++cnt;
}

int LCA(int x,int y) {
if (dep[x] < dep[y]) swap(x,y);
for (int d = dep[x] - dep[y],i = 0;d;d >>= 1,++i) if (d & 1) x = f[x][i];
if (x == y) return x;
for (int i = t;i >= 0;--i) {
if (f[x][i] != f[y][i]) {
x = f[x][i],y = f[y][i];
}
}
return f[x][0];
}

int work(int x,int y) {
pii now = make_pair(1,0);
if (dep[x] < dep[y]) swap(x,y);
for (int d = dep[x] - dep[y],i = 0;d;d >>= 1,++i) if (d & 1) now = merge(now,col[x][i]),x = f[x][i];
if (x == y) return merge(now,col[x][0]).first;
for (int i = t;i >= 0;--i) {
if (f[x][i] != f[y][i]) {
now = merge(now,col[x][i]),now = merge(now,col[y][i]);
x = f[x][i],y = f[y][i];
}
}
now = merge(now,col[x][0]);
now = merge(now,col[y][1]);
return now.first;
}

signed main() {
read(n,q);
t = 20;
for (int i = 1;i <= n;++i) {
read(a[i]);c[a[i]].push_back(i);
}
for (int i = 1;i < n;++i) {
int x,y;read(x,y);G[x].push_back(y),G[y].push_back(x);
}
dfs(1,0);

for (int i = 1;i <= q;++i) {
ans[i] = -1;
read(st[i],ed[i]);
lca[i] = LCA(st[i],ed[i]);
co[i] = work(st[i],ed[i]);
c1[co[i]].push_back(i);
len[i] = dep[st[i]] + dep[ed[i]] - 2 * dep[lca[i]] + 1;

}
for (int i = 1;i <= n;++i) {
for (auto y:c[i]) {
T.modify(in[y],1),T.modify(out[y],-1);
}
for (auto y:c1[i]) {
//printf("%d\n",y);
int u = st[y],v = ed[y],p = lca[y];
int tmp = T.query(in[u]) + T.query(in[v]) - 2 * T.query(in[p]-1) - (a[p] == i);
//printf("%d %d %d %d %d %d\n",i,y,tmp,len[y],u,v);
if (tmp * 2 > len[y]) ans[y] = i;
}
for (auto y:c[i]) {
T.modify(in[y],-1),T.modify(out[y],1);
}
}
for (int i = 1;i <= q;++i) printf("%d\n",ans[i]);
return 0;
}