树分治-点分治

看到akb在冬令营就切了好多点分治就好慌啊qaq,然后机房里也有一堆人写点分治qaq,然后我就被点分治虐了啊qaq……

树上的点分治是一种针对树上路径统计的问题的算法,通过找到重心使得子问题规模快速下降,再将所有子问题答案合并起来,通常能够将\(O(n ^ 2)\)规模的问题优化为\(O(n \log n)\)或者\(O(n \log ^ 2 n)\)。

先看一道题:
给定一颗\(n \leq 10 ^ 5\)个节点的带权无根树,设\((i, j)\)之间的简单路径边数为\(dis(i, j)\),求所有满足\(dis(i, j) \leq K\)且\(i \lt j\)的数对\((i, j)\)的数量。(CWOJ1201

暴力做法显然,从每个节点开始dfs (bfs是等价的,后文dfs均可以用bfs解决),记录答案即可。时间复杂度为\(O(n ^ 2)\)

但是这样太慢了!我们想要更快的做法。

如果我们任意选择树上的一个节点来观察(不妨暂时把它看作树根),那么所有可以作为答案的\((i, j)\)可以分成两类:

  1. 经过根节点的(两端点在不同子树内,或者其中一个端点是根节点)
  2. 不经过根节点的(两端点在同一子树内)

对于第二种问题,我们可以把子树当成一颗完整的树来看待,就可以按照同样的方法解决并求出;所以我们暂时只需考虑第一个问题即可。

对于经过根节点的路径,我们可以先做一次暴力把所有节点的深度搞出来,然后深度不超过\(K\)显然可以和根构成答案;同时,对于在两个不同子树内的节点\(i, j\),如果有\(dep(i) + dep(j) \leq k\),那么也可以构成答案。

有很多做法可以统计这个问题,时间复杂度均为\(O(n \log n)\):

  1. 将整课树里的存在的点的深度排序,先不管\(i, j\)是不是在一棵子树里,通过双指针法在线性的时间内统计出所有\(dep(i) + dep(j) \leq k\)的点;然后对于每棵子树再这样做一次,可以算出\(i, j\)在同一棵子树内被统计的总和,用第一次的答案减去第二次答案即可解决。
  2. 按子树顺序,依次加入答案,并与之前的子树统计答案,这个过程可以通过平衡树或者按照子树大小启发式合并有序数组来实现。

整个过程的时间复杂度为\(T(n) = \sum T(siz) + O(n \log n)\),其中\(\sum siz = n \ – 1\)。如果每一个\(siz\)均不超过\(\frac n 2\),那么最坏情况下的时间复杂度为\[
\begin{aligned}
T(n) = &\ 2 T(\frac n 2) + O(n \log n) \\
=&\ O(n \log ^ 2 n)
\end{aligned}
\]而为了使得每一个子树的大小都不超过\(\frac n 2\),我们每次选取的节点就不能是随机的了;我们需要选取一个比较优越的点,这样的点被称为重心。每棵树必然存在至少一个,至多两个重心,可以用反证法非常简易地证明。这里我们只需找出一个重心即可,需要先选取任意节点作为根开始进行dfs,算出每个子树的大小,然后通过树型dp或者再dfs一次得到重心。

这样,我们就在\(O(n \log ^ 2 n)\)的时间复杂度内解决了这道题。

/* Never say die. */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
int _c = 0;
template <class T> inline void read(T &_a) {
	bool f = 0;
	while (_c < '0' || _c>'9') {if (_c == '-') f = 1; _c=getchar();}
	_a = 0;
	while (_c >= '0' && _c <= '9') {_a = _a * 10 + _c - '0'; _c = getchar();}
	if (f) _a = -_a;
}

typedef long long LL;

#define MAXN 100319

LL ans = 0;

struct edge {
	int next, to;
} e[MAXN * 2];

int ecnt = 0, head[MAXN];

void add(int u, int v) {
	e[++ecnt].to = v; e[ecnt].next = head[u]; head[u] = ecnt;
	e[++ecnt].to = u; e[ecnt].next = head[v]; head[v] = ecnt;
}

int n, k;
int siz[MAXN];
bool vis[MAXN];

int ds(int x) {
	siz[x] = 1; vis[x] = 1;	
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) siz[x] += ds(e[now].to);
	}
	vis[x] = 0;
	return siz[x];
}

int find(int x) {
	ds(x);
	for (int rt = x; ; ) {
		int maxs = 0, to = -1;
		for (int now = head[rt]; now; now = e[now].next) {
			if (!vis[e[now].to] && siz[e[now].to] < siz[rt] && siz[e[now].to] > maxs)
					maxs = siz[e[now].to], to = e[now].to;
			}
		if (maxs > siz[x] / 2) rt = to;
		else return rt;
	}
}

int l[MAXN], lcnt = 0, dep[MAXN];

void dd(int x) {
	l[++lcnt] = dep[x]; vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = dep[x] + 1;
			dd(e[now].to);
		}
	}
	vis[x] = 0;
}

LL work(int rt, int base) {
	dep[rt] = base; lcnt = 0; dd(rt);
	sort(l + 1, l + lcnt + 1);
	LL ret = 0; int i = 1, j = lcnt;
	while (i < j) {
		if (l[i] + l[j] <= k) ret += j - i, i++;
		else j--;
	}
	return ret;
}

void solve(int x) {
	int rt = find(x);
	ans += work(rt, 0); vis[rt] = 1;
	for (int now = head[rt]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			ans -= work(e[now].to, 1);
			solve(e[now].to);
		}
	}
}

int main() {
	read(n); read(k);
	for (int i = 1, u, v; i < n; i++) {
		read(u); read(v); add(u, v);
	}
	solve(1);
	printf("%lld\n", ans);
	return 0;
}

BZOJ1468 BZOJ3365 都是这道题的带边权写法,多倍经验啊。


给你一棵\(n \leq 10^4\)个点的带权有根树,有\(p \leq 100\)个询问,每次询问树中是否存在一条长度为\(len \leq 10 ^ 6\)(注意原题的数据范围写错了)的路径,输出YesNo .(BZOJ1460)

如果我们套用上一题的方法,时间复杂度为\(O(pn \log ^ 2 n)\),最坏情况下时间复杂度为\(2 \times 10 ^ 8\),如果没有十分高超的卡常技巧是过不了这道题的(事实上因为树的dfs自带至少为4的常数,是几乎不可能卡进时限的)。所以我们需要优化。

事实上由于我们只需判是否存在就可以了,而且良心的出题人给了我们边权小于等于\(10^6\)这个良心条件。那我们可以用栈记录每一颗子树内的深度值,并且在子树dfs完之后和之前的比较深度是否可以恰好达成\(len\)即可。时间复杂度优化为\(O(pn \log n)\)。

注意\(len = 0\)的情况啊qaq神坑……

/* Never say die. */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
int _c = 0;
template <class T> inline void read(T &_a) {
	bool f = 0;
	while (_c < '0' || _c>'9') {if (_c == '-') f = 1; _c=getchar();}
	_a = 0;
	while (_c >= '0' && _c <= '9') {_a = _a * 10 + _c - '0'; _c = getchar();}
	if (f) _a = -_a;
}

#define MAXN 10319
typedef long long LL;

struct edge {
	int next, to, w;
} e[MAXN * 2];

int head[MAXN], ecnt = 0;

void add(int u, int v, int w) {
	e[++ecnt].to = v; e[ecnt].next = head[u]; e[ecnt].w = w; head[u] = ecnt;
	e[++ecnt].to = u; e[ecnt].next = head[v]; e[ecnt].w = w; head[v] = ecnt;
}

int n, p, k;
bool vis[MAXN];
int siz[MAXN];

int ds(int x, int fa) {
	siz[x] = 1;	
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to] && e[now].to != fa) siz[x] += ds(e[now].to, x);
	}
	return siz[x];
}

int find(int x) {
	ds(x, 0);
	for (int rt = x; ; ) {
		int maxs = 0, to = -1;
		for (int now = head[rt]; now; now = e[now].next) {
			if (!vis[e[now].to] && siz[e[now].to] < siz[rt]) {
				if (siz[e[now].to] > maxs) maxs = siz[e[now].to], to = e[now].to;
			}
		}
		if (maxs * 2 >= siz[x]) rt = to;
		else return rt;
	}
}


int dep[MAXN], s[MAXN], top;
int ext[1000319], tag = 0;


void dd(int x, int fa) {
	if (dep[x] > k) return;
	s[++top] = dep[x];
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to] && e[now].to != fa) {
			dep[e[now].to] = dep[x] + e[now].w;
			dd(e[now].to, x);
		}
	}
}

bool work(int x) {
	ext[0] = ++tag;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = e[now].w; top = 0; dd(e[now].to, x);
			for (int i = 1; i <= top; i++) if (ext[k - s[i]] == tag) return 1;
			for (int i = 1; i <= top; i++) ext[s[i]] = tag;
		}
	}
	return 0;
}

bool solve(int x) {
	x = find(x);
	vis[x] = 1;
	if (work(x)) return 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) if (solve(e[now].to)) return 1;
	}
	return 0;
}

int main() {
	read(n); read(p);
	for (int i = 1, u, v, w; i < n; i++) read(u), read(v), read(w), add(u, v, w);
	for (int i = 1; i <= p; i++) {
		read(k);
		if (k == 0) {puts("Yes"); continue;} 
		memset(vis, 0, sizeof(vis));
		if (solve(1)) puts("Yes");
		else puts("No");
	}
	return 0;
}


一颗\(n \leq 2 \times 10 ^ 4\)个节点的树,求独立随机选取地两个点(可以相同)路径长度是3的倍数的期望概率,输出既约分数。 (BZOJ2152)

按照长度\(\bmod \ 3\)分组统计合并答案算算就可以了……时间复杂度为\(O(3 n \log n)\)。

/* Never say die. */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
using namespace std;
int _c = 0;
template <class T> inline void read(T &_a) {
	bool f = 0;
	while (_c < '0' || _c>'9') {if (_c == '-') f = 1; _c=getchar();}
	_a = 0;
	while (_c >= '0' && _c <= '9') {_a = _a * 10 + _c - '0'; _c = getchar();}
	if (f) _a = -_a;
}
typedef long long LL;
#define MAXN 100319

LL gcd(LL a, LL b) {
	while (b) {
		LL c = a;
		a = b;
		b = c % b;
	}
	return a;
}

struct edge {
	int next, to, w;
} e[MAXN * 2];

int head[MAXN], ecnt;
void add(int u, int v, int w) {
	e[++ecnt] = (edge){head[u], v, w}; head[u] = ecnt;
	e[++ecnt] = (edge){head[v], u, w}; head[v] = ecnt;
}

int siz[MAXN];
bool vis[MAXN];

int ds(int x) {
	vis[x] = 1; siz[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) siz[x] += ds(e[now].to);
	}
	vis[x] = 0;
	return siz[x];
}

int find(int x) {
	ds(x);
	for (int rt = x; ; ) {
		int maxs = -1, to;
		for (int now = head[rt]; now; now = e[now].next) {
			if (!vis[e[now].to] && siz[e[now].to] < siz[rt] && siz[e[now].to] > maxs) {
				maxs = siz[e[now].to];
				to = e[now].to;
			}
		}
		if (maxs * 2 >= siz[x]) rt = to;
		else return rt;
	}
}

int n, dep[MAXN];
int s[3], cnt[3];
LL ans = 0;

void dd(int x) {
	dep[x] += dep[x] > 2 ? -3 : 0;
	s[dep[x]]++; vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = dep[x] + e[now].w;
			dd(e[now].to);
		}
	}
	vis[x] = 0;
}

void work(int x) {
	cnt[0] = 1, cnt[1] = cnt[2] = 0;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = e[now].w; dd(e[now].to);
			ans += (LL)s[0] * cnt[0] + (LL)s[1] * cnt[2] + (LL)s[2] * cnt[1];
			for (int i = 0; i < 3; i++) {
				cnt[i] += s[i], s[i] = 0;
			}
		}
	}
}

void solve(int x) {
	x = find(x);
	vis[x] = 1;
	work(x);
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) solve(e[now].to);
	}
}


int main() {
	read(n);
	for (int i = 1, u, v, w; i < n; i++) read(u), read(v), read(w), add(u, v, w % 3);
	solve(1); ans = ans * 2 + n;
	LL tot = (LL) n * n, g = gcd(tot, ans);
	printf("%lld/%lld\n", ans/g, tot/g);
	return 0;
}


给定一棵\(n \leq 5 \times 10 ^ 4\)个节点的带边权树,\(w_i \leq 10 ^ 4\),求前\(m \leq 3 \times 10 ^ 5 \)大的路径长度。(BZOJ3784)

如果使用类似于NOI2010:超级钢琴的做法,那么可以在套一个点分治在\(O(n \log ^ 2 n)\)时间复杂度内做出这道题。(然而我太蠢了并没有理解到该怎么搞qaq

如果我们知道了第\(m\)条边的长度,那么我们可以利用点分治在\(O(n \log ^ 2 n + m \log n)\)的时间复杂度内求出答案。现在主要问题转化为怎么求第\(m\)条边的长度。

一个比较好想到的思路是二分答案,每次点分治验证答案。然而这样子的话时间复杂度为\(O(n \log ^ 2 n \log \sum w_i)\),显然是不能接受的。

但是我们发现,其实每次分治时,找到的值是完全相同的!由于点分治保证了每次找到的值数量总和是在\(O(n \log n)\)级别的,瓶颈在于 sort ,那么我们完全可以将每次找到的深度值直接存下来,空间完全够用。这样,验证答案的复杂度就从\(O(n \log ^ 2 n \log \sum w_i)\)下降为\(O(n \log n \log \sum w_i)\),可以接受了。总的时间复杂度为\(O(2 n \log ^ 2 n + n \log n \log \sum w_i + m \log n)\),(相比起普通点分治)额外的空间复杂度为\(O(2 n \log n + 2 n)\)。

注意细节!注意细节!注意细节!我因为一些很奇怪的地方被坑了好多次……

/* Never say die. */
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <ctime>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <string>
#include <map>
#include <set>
#include <bitset>
#include <cassert>
using namespace std;
int _c = 0;
template <class T> inline void read(T &_a) {
	bool f = 0;
	while (_c < '0' || _c>'9') {if (_c == '-') f = 1; _c=getchar();}
	_a = 0;
	while (_c >= '0' && _c <= '9') {_a = _a * 10 + _c - '0'; _c = getchar();}
	if (f) _a = -_a;
}
typedef long long LL;
#define MAXN 50319

struct edge {
	int next, to, w;
} e[MAXN * 2];

int head[MAXN], ecnt = 0, blcnt = 0;
void add(int u, int v, int w) {
	e[++ecnt] = (edge) {head[u], v, w}; head[u] = ecnt;
	e[++ecnt] = (edge) {head[v], u, w}; head[v] = ecnt;
}

int siz[MAXN];
bool vis[MAXN];
int ds(int x) {
	siz[x] = 1; vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) 
		if (!vis[e[now].to]) siz[x] += ds(e[now].to);
	vis[x] = 0;
	return siz[x];
}

int find(int x) {
	ds(x);
	for (int rt = x; ; ) {
		int maxs = 0, to;
		for (int now = head[rt]; now; now = e[now].next) {
			if (!vis[e[now].to] && siz[rt] > siz[e[now].to] && siz[e[now].to] > maxs) {
				maxs = siz[e[now].to];
				to = e[now].to;
			}
		}
		if (maxs * 2 >= siz[x]) rt = to;
		else return rt;
	}
}

vector <int> l[MAXN], bl[MAXN];

int dep[MAXN];
int n, m;

void dd(int x, vector<int> &r) {
	r.push_back(dep[x]); vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = dep[x] + e[now].w;
			dd(e[now].to, r);
		}
	}
	vis[x] = 0;
}

void pre_work(int x, vector <int> &a, int base = 0) {
	dep[x] = base; dd(x, a);
	sort(a.begin(), a.end());
}

void pre_solve(int x) {
	x = find(x);
	pre_work(x, l[x]); vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			pre_work(e[now].to, bl[++blcnt], e[now].w);
			pre_solve(e[now].to);
		}
	}
}

int mid;

LL work(vector<int> &a) {
	int i = 0, j = a.size() - 1;
	LL ret = 0;
	while (i < j) {
		if (a[i] + a[j] >= mid) ret += j-- - i;
		else i++;
	}
	return ret;
}

bool check() {
	LL tot = 0;
	for (int i = 1; i <= n; i++) if (l[i].size() > 1) tot += work(l[i]);
	for (int i = 1; i <= n; i++) if (bl[i].size() > 1) tot -= work(bl[i]);
	return tot <= m;
}

vector <int> ans;

multiset <int, greater<int> > s;
multiset <int, greater<int> >::iterator it;


int ls[MAXN], lcnt = 0;

void ddd(int x) {
	ls[++lcnt] = dep[x]; vis[x] = 1;
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = dep[x] + e[now].w;
			ddd(e[now].to);
		}
	}
	vis[x] = 0;
}

void getans(int x) {
	s.clear(); s.insert(0);
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			dep[e[now].to] = e[now].w;
			lcnt = 0;
			ddd(e[now].to);
			for (int i = 1; i <= lcnt; i++) {
				for (it = s.begin(); it != s.end() && ls[i] + *it > mid; it++) {
					ans.push_back(ls[i] + *it);
				}
			}
			for (int i = 1; i <= lcnt; i++) s.insert(ls[i]); 
		}
	}
}

void solve(int x) {
	x = find(x);
	vis[x] = 1; getans(x);
	for (int now = head[x]; now; now = e[now].next) {
		if (!vis[e[now].to]) {
			solve(e[now].to);
		}
	}
}

int main() {
	read(n); read(m);
	int L = 1, R = 0;
	for (int i = 1, u, v, w; i < n; i++) {
		read(u); read(v); read(w); add(u, v, w); R += w;
	}
	pre_solve(1);
	while (L < R) {
		mid = (L + R) >> 1;
		if (check()) R = mid;
		else L = mid + 1;
	}
	memset(vis, 0, sizeof vis);
	solve(1); sort(ans.begin(), ans.end());
	for (int i = ans.size() - 1; i >= 0 && m > 0; i--) printf("%d\n", ans[i]), m--;
	while (m--) printf("%d\n", mid);
	return 0;
}


树分治思路很妙但是不直观,而且实现比较绕……调试累死人。哎什么时候码力强了才能够熟练啊。

发表评论

电子邮件地址不会被公开。 必填项已用*标注