点分治
解决树上点对问题
用于解决树上路径长问题的算法,复杂度是比较优秀的 $O(nlogn)$
直接按照题目来讲。点分治1
我们要维护树上路径长度为 $k$ 的路径是否存在(当然视题目而定,点分治的操作比较灵活)。这类问题我们选用点分治。
具体怎么操作呢?
一般就是,每次用vis标记删除根节点,然后对于所有没被删除的(除了它自己)计算到当前根节点的距离,然后对于不同子树中的点进行组合,看是否出现我们所要的答案即可。
但是显然这样并不是优秀的。我们发现还有很多地方可以优化。
根节点
很显然我们刚才的做法会被链这类的东西卡掉,于是考虑如何选择更优的节点。
然后我们发现如果每次选择树的重心,这样就可以保证复杂度。而重心我们直接套路的求就可以了。
void findRoot(int u,int f){
sz[u]=1;
weight[u]=0;
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]||to[i]==f) continue;
findRoot(to[i],u);
sz[u]+=sz[to[i]];
weight[u]=Max(weight[u],sz[to[i]]);
}
weight[u]=Max(weight[u],sum-sz[u]);
if(!rt||weight[u]<weight[rt]){
rt=u;
}
}
点组合
如果暴力组合显然是十分暴力的(听君一席话),考虑优化这个组合。不难发现如果对于我们要求的答案进行枚举,每次用这个要求的答案去减去一些到根的路径长,看这个差值是否出现,于是可以维护 $judge$ 数组,存储一个距离是否出现。
这样我们就可以把 $O(n^2)$ 降到 $O(nm)$。
标记删除
在我们遍历完之后需要清空 $judge$ 数组,为了保证复杂度正确,我们将要删除的部分加入清扫队列。最后删除。
那么展示完整代码:
#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 10000+3
#define M 110
#define MAXN 10000000+3
#define INF (int)(1e9)
inline int read(){
int s=0,f=1;
char ch=getchar();
while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
return s*f;
}
inline int Max(int x,int y){
return x>y?x:y;
}
int n,m;
vector<int>head,to,nxt,val;
void join(int u,int v,int w){
nxt.push_back(head[u]);
head[u]=to.size();
to.push_back(v);
val.push_back(w);
}
bool vis[N];
int rt,sum;
int sz[N],weight[N];
int ques[M],ans[M];
void findRoot(int u,int f){
sz[u]=1;
weight[u]=0;
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]||to[i]==f) continue;
findRoot(to[i],u);
sz[u]+=sz[to[i]];
weight[u]=Max(weight[u],sz[to[i]]);
}
weight[u]=Max(weight[u],sum-sz[u]);
if(!rt||weight[u]<weight[rt]){
rt=u;
}
}
bool judge[MAXN];
int dis[N],fd[N],tot;
void findDis(int u,int f){
fd[++tot]=dis[u];
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]||to[i]==f) continue;
dis[to[i]]=dis[u]+val[i];
findDis(to[i],u);
}
}
int clr[N],top;
void calc(int u){
top=0;
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]) continue;
tot=0;dis[to[i]]=val[i];findDis(to[i],u);
for(int i=1;i<=tot;++i){
for(int j=1;j<=m;++j){
if(ques[j]>=fd[i]){
ans[j]=judge[ques[j]-fd[i]];
}
}
}
for(int i=1;i<=tot;++i){
clr[++top]=fd[i];
judge[fd[i]]=1;
}
}
for(int i=1;i<=top;++i){
judge[clr[i]]=0;
}
}
void dfs(int u){
vis[u]=judge[0]=1;
calc(u);
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]) continue;
rt=0;sum=sz[to[i]];
findRoot(to[i],u);
dfs(to[i]);
}
}
int main(){
n=read();m=read();
head.resize(n+1,-1);
for(int i=1;i<n;++i){
int u=read(),v=read(),w=read();
join(u,v,w);join(v,u,w);
}
for(int i=1;i<=m;++i){
ques[i]=read();
}
sum=n;
findRoot(1,1);
dfs(rt);
for(int i=1;i<=m;++i){
if(ans[i]){
printf("AYE\n");
}else printf("NAY\n");
}
return 0;
}
这个题提供点分治的变换思路。用 $judge$ 维护一个长度的路径的出现次数即可。
代码:
#include<bits/stdc++.h>
using namespace std;
#define file(a) freopen(#a".in","r",stdin),freopen(#a".out","w",stdout)
#define LL long long
#define N 20000+3
inline int read(){
int s=0,f=1;
char ch=getchar();
while(ch<'0'||'9'<ch) {if(ch=='-') f=-1;ch=getchar();}
while('0'<=ch&&ch<='9') {s=s*10+(ch^48);ch=getchar();}
return s*f;
}
inline int Max(int x,int y){
return x>y?x:y;
}
vector<int>head,to,nxt,val;
void join(int u,int v,int w){
nxt.push_back(head[u]);
head[u]=to.size();
to.push_back(v);
val.push_back(w);
}
inline int inc(int x,int y){
return (x+=y)>=3?x-3:x;
}
inline int dec(int x,int y){
return (x-=y)<0?x+3:x;
}
int n;
int rt,sum;
int sz[N],weight[N];
int ans1,ans2;
bool vis[N];
void findRoot(int u,int f){
sz[u]=1;
weight[u]=0;
for(int i=head[u];~i;i=nxt[i]){
if(to[i]==f||vis[to[i]]) continue;
findRoot(to[i],u);
sz[u]+=sz[to[i]];
weight[u]=Max(weight[u],sz[to[i]]);
}
weight[u]=Max(weight[u],sum-sz[u]);
if(!rt||weight[u]<weight[rt]){
rt=u;
}
}
int judge[N],fd[N],tot;
int dis[N];
void findDis(int u,int f){
fd[++tot]=dis[u];
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]||to[i]==f) continue;
dis[to[i]]=inc(dis[u],val[i]);
findDis(to[i],u);
}
}
int clr[N],top;
void calc(int u){
top=0;
int res1=0,res2=0;
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]) continue;
tot=0;dis[to[i]]=val[i];findDis(to[i],u);
for(int j=1;j<=tot;++j){
int tmp=dec(0,fd[j]);
if(judge[tmp]){
res1+=judge[tmp];
}
for(int k=0;k<=2;++k){
res2+=judge[k];
}
}
/*
for(int j=0;j<=2;++j){
printf("%d",judge[j]);
putchar(j==2?'\n':' ');
}
*/
for(int j=1;j<=tot;++j){
clr[++top]=fd[j];
++judge[fd[j]];
}
/*
putchar(' ');
for(int j=0;j<=2;++j){
printf("%d",judge[j]);
putchar(j==2?'\n':' ');
}
*/
}
for(int i=1;i<=top;++i){
--judge[clr[i]];
}
res1=(res1)*2;res2=(res2)*2;
ans1+=res1+1;ans2+=res2+1;
}
void dfs(int u){
vis[u]=judge[0]=1;
calc(u);
for(int i=head[u];~i;i=nxt[i]){
if(vis[to[i]]) continue;
sum=sz[u];rt=0;
findRoot(to[i],u);
dfs(rt);
}
}
int gcd(int a,int b){
if(b==0) return a;
return gcd(b,a%b);
}
int main(){
n=read();
head.resize(n+1,-1);
for(int i=1;i<n;++i){
int u=read(),v=read(),w=read()%3;
join(u,v,w);join(v,u,w);
}
sum=n;
findRoot(1,1);
dfs(rt);
int GCD=gcd(ans1,ans2);
printf("%d/%d",ans1/GCD,ans2/GCD);
return 0;
}
需要做题的话,其他题目可以在题目的题目推荐里找到。