用于解决树上路径长问题的算法,复杂度是比较优秀的 $O(nlogn)$

直接按照题目来讲。点分治1

我们要维护树上路径长度为 $k$ 的路径是否存在(当然视题目而定,点分治的操作比较灵活)。这类问题我们选用点分治。

具体怎么操作呢?

一般就是,每次用vis标记删除根节点,然后对于所有没被删除的(除了它自己)计算到当前根节点的距离,然后对于不同子树中的点进行组合,看是否出现我们所要的答案即可。

但是显然这样并不是优秀的。我们发现还有很多地方可以优化。

根节点

很显然我们刚才的做法会被链这类的东西卡掉,于是考虑如何选择更优的节点。

然后我们发现如果每次选择树的重心,这样就可以保证复杂度。而重心我们直接套路的求就可以了。

void findRoot(int u,int f){
	sz[u]=1;
	weight[u]=0;
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]||to[i]==f) continue;
		findRoot(to[i],u);
		sz[u]+=sz[to[i]];
		weight[u]=Max(weight[u],sz[to[i]]);
	}
	weight[u]=Max(weight[u],sum-sz[u]);
	if(!rt||weight[u]<weight[rt]){
		rt=u;
	}
}

点组合

如果暴力组合显然是十分暴力的(听君一席话),考虑优化这个组合。不难发现如果对于我们要求的答案进行枚举,每次用这个要求的答案去减去一些到根的路径长,看这个差值是否出现,于是可以维护 $judge$ 数组,存储一个距离是否出现。

这样我们就可以把 $O(n^2)$ 降到 $O(nm)$。

标记删除

在我们遍历完之后需要清空 $judge$ 数组,为了保证复杂度正确,我们将要删除的部分加入清扫队列。最后删除。

那么展示完整代码:

#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 10000+3
#define M 110 
#define MAXN 10000000+3
#define INF (int)(1e9)
inline int read(){
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
	while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
	return s*f;
}
inline int Max(int x,int y){
	return x>y?x:y;
}
int n,m; 
vector<int>head,to,nxt,val;
void join(int u,int v,int w){
	nxt.push_back(head[u]);
	head[u]=to.size();
	to.push_back(v);
	val.push_back(w);
}
bool vis[N];
int rt,sum;
int sz[N],weight[N];
int ques[M],ans[M];
void findRoot(int u,int f){
	sz[u]=1;
	weight[u]=0;
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]||to[i]==f) continue;
		findRoot(to[i],u);
		sz[u]+=sz[to[i]];
		weight[u]=Max(weight[u],sz[to[i]]);
	}
	weight[u]=Max(weight[u],sum-sz[u]);
	if(!rt||weight[u]<weight[rt]){
		rt=u;
	}
}
bool judge[MAXN];
int dis[N],fd[N],tot;
void findDis(int u,int f){
	fd[++tot]=dis[u];
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]||to[i]==f) continue;
		dis[to[i]]=dis[u]+val[i];
		findDis(to[i],u);
	}
}
int clr[N],top;
void calc(int u){
	top=0;
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]) continue;
		tot=0;dis[to[i]]=val[i];findDis(to[i],u);
		for(int i=1;i<=tot;++i){
			for(int j=1;j<=m;++j){
				if(ques[j]>=fd[i]){
					ans[j]=judge[ques[j]-fd[i]];
				}
			}
		}
		for(int i=1;i<=tot;++i){
			clr[++top]=fd[i];
			judge[fd[i]]=1;
		}
	}
	for(int i=1;i<=top;++i){
		judge[clr[i]]=0;
	}
}
void dfs(int u){
	vis[u]=judge[0]=1;
	calc(u);
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]) continue;
		rt=0;sum=sz[to[i]];
		findRoot(to[i],u);
		dfs(to[i]);
	}
}
int main(){
	n=read();m=read();
	head.resize(n+1,-1); 
	for(int i=1;i<n;++i){
		int u=read(),v=read(),w=read();
		join(u,v,w);join(v,u,w);
	} 
	for(int i=1;i<=m;++i){
		ques[i]=read();
	}
	sum=n;
	findRoot(1,1);
	dfs(rt);
	for(int i=1;i<=m;++i){
		if(ans[i]){
			printf("AYE\n");
		}else printf("NAY\n");
	}
	return 0;
}

聪聪可可

这个题提供点分治的变换思路。用 $judge$ 维护一个长度的路径的出现次数即可。

代码:

#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 20000+3
inline int read(){
	int s=0,f=1;
	char ch=getchar();
	while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
	while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
	return s*f;
}
inline int Max(int x,int y){
	return x>y?x:y;
}
vector<int>head,to,nxt,val;
void join(int u,int v,int w){
	nxt.push_back(head[u]);
	head[u]=to.size();
	to.push_back(v);
	val.push_back(w);
}
inline int inc(int x,int y){
	return (x+=y)>=3?x-3:x;
}
inline int dec(int x,int y){
	return (x-=y)<0?x+3:x;
}
int n;
int rt,sum;
int sz[N],weight[N];
int ans1,ans2;
bool vis[N];
void findRoot(int u,int f){
	sz[u]=1;
	weight[u]=0;
	for(int i=head[u];~i;i=nxt[i]){
		if(to[i]==f||vis[to[i]]) continue;
		findRoot(to[i],u);
		sz[u]+=sz[to[i]];
		weight[u]=Max(weight[u],sz[to[i]]);
	} 
	weight[u]=Max(weight[u],sum-sz[u]);
	if(!rt||weight[u]<weight[rt]){
		rt=u;
	}
} 
int judge[N],fd[N],tot;
int dis[N];
void findDis(int u,int f){
	fd[++tot]=dis[u];
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]||to[i]==f) continue;
		dis[to[i]]=inc(dis[u],val[i]);
		findDis(to[i],u);
	}
}
int clr[N],top;
void calc(int u){
	top=0;
	int res1=0,res2=0;
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]) continue; 
		tot=0;dis[to[i]]=val[i];findDis(to[i],u);
		for(int j=1;j<=tot;++j){
			int tmp=dec(0,fd[j]);
			if(judge[tmp]){
				res1+=judge[tmp]; 
			}
			for(int k=0;k<=2;++k){
				res2+=judge[k];
			}
		}
		/*
		for(int j=0;j<=2;++j){
			printf("%d",judge[j]);
			putchar(j==2?'\n':' '); 
		}
		*/
		for(int j=1;j<=tot;++j){
			clr[++top]=fd[j];
			++judge[fd[j]];
		}
		/*
		putchar('	');
		for(int j=0;j<=2;++j){
			printf("%d",judge[j]);
			putchar(j==2?'\n':' '); 
		}
		*/
	}
	for(int i=1;i<=top;++i){
		--judge[clr[i]];
	}
	res1=(res1)*2;res2=(res2)*2;
	ans1+=res1+1;ans2+=res2+1; 
}
void dfs(int u){
	vis[u]=judge[0]=1;
	calc(u);
	for(int i=head[u];~i;i=nxt[i]){
		if(vis[to[i]]) continue;
		sum=sz[u];rt=0;
		findRoot(to[i],u);
		dfs(rt);
	}
}
int gcd(int a,int b){
	if(b==0) return a;
	return gcd(b,a%b);
}
int main(){
	n=read();
	head.resize(n+1,-1);
	for(int i=1;i<n;++i){
		int u=read(),v=read(),w=read()%3;
		join(u,v,w);join(v,u,w);
	}
	sum=n;
	findRoot(1,1);
	dfs(rt);
	int GCD=gcd(ans1,ans2);
	printf("%d/%d",ans1/GCD,ans2/GCD);
	return 0;
}

需要做题的话,其他题目可以在题目的题目推荐里找到。