洛谷 P3304 [SDOI2013] 直径 题解
题目分析
第一部分好说,求直径,dfs或者DP都可以。
第二部分,有一个定理,就是所有直径中点重叠。
那么有两种情况
-
一种是中点在一个节点上,那么显然这个点是每条直径的终点,也就是说直径的一半相等。从这个点出发dfs,找出所有最远点。如果只有两条,输出depth之和。否则求lca,lca的depth就是重叠的数量。
-
另一种,中点在一条边上。从这个边出发,两侧分别dfs找最远,再分类讨论,有的求lca,有的输出。具体见代码即可。
题解中还有其他思路:比如从一条直径上开始dfs(利用直径同侧长度一定相等的性质),还有两次dp求出总cnt和边cnt进行统计的,还有去掉中点所在边之后进行RootLeafPath的。这些思路都很有用。
顺便提一句,在题解里面看到了专门用结构体Data来统计数据(维护最大值和出现次数)以及通过左右分开上升最大值来排除自身这两个trick,觉得十分好用。
代码
#include <bits/stdc++.h>
#define N (long long)(200000 + 5)
#define LOGN (long long)(25)
using namespace std;
typedef long long LL;
struct node
{
LL v, w, next;
}e[N * 2];
LL n;
LL head[N];
LL cnt;
LL d1[N],d2[N];
LL k1[N],k2[N];
LL d,o;
bool vis[N];//
void adde(int u, int v, int w)
{
cnt++;
e[cnt].v = v;
e[cnt].w = w;
e[cnt].next = head[u];
head[u] = cnt;
}
void dpdfs(int p)
{
vis[p] = true;
d1[p] = d2[p] = 0;
k1[p] = k2[p] = 0;
for(int i = head[p];i != 0;i = e[i].next)
{
if(!vis[e[i].v])
{
vis[e[i].v] = true;
dpdfs(e[i].v);
if(d1[e[i].v] + e[i].w > d1[p])
{
d2[p] = d1[p];
k2[p] = k1[p];
d1[p] = d1[e[i].v] + e[i].w;
k1[p] = e[i].v;
}
else if(d1[e[i].v] + e[i].w > d2[p])//
{
d2[p] = d1[e[i].v] + e[i].w;
k2[p] = e[i].v;
}
}
}
if(d2[p] + d1[p] > d)
{
d = d2[p] + d1[p];
o = p;
}
vis[p] = false;
}
LL depth[N];
LL len[N];
LL maxlendep1;
LL maxlendep2;
LL maxlen1;
LL maxlen2;
LL maxlenidx1[N];
LL maxlenidx2[N];
LL maxlencnt1;
LL maxlencnt2;
void deepdfs(int p,LL *mxi,LL &mxc,LL &mxl, LL &mxd)
{
vis[p] = true;
for(int i = head[p];i != 0;i = e[i].next)
{
if(!vis[e[i].v])
{
vis[e[i].v] = true;
depth[e[i].v] = depth[p] + 1;
len[e[i].v] = len[p] + e[i].w;
if(mxl < len[e[i].v])
{
mxl = len[e[i].v];
mxd = depth[e[i].v];
mxc = 1;
mxi[mxc] = e[i].v;
}
else if(mxl == len[e[i].v])
{
mxi[++mxc] = e[i].v;
}
deepdfs(e[i].v,mxi,mxc,mxl,mxd);
}
}
vis[p] = false;
}
int lcadepth[N];
int fa[N][LOGN];
void lcadfs(int p)
{
vis[p] = true;
for(int i = head[p];i != 0;i = e[i].next)
{
if(!vis[e[i].v])
{
vis[e[i].v] = true;
lcadepth[e[i].v] = lcadepth[p] + 1;
fa[e[i].v][0] = p;
for(int j = 1;j <= log2(n);j++)
fa[e[i].v][j] = fa[fa[e[i].v][j - 1]][j - 1];
lcadfs(e[i].v);
}
}
vis[p] = false;
}
int lca(int x, int y)
{
if(lcadepth[y] > lcadepth[x])
swap(x,y);
for(int i = log2(n); i >= 0;i--)
if(lcadepth[fa[x][i]] >= lcadepth[y])
{
x = fa[x][i];
}
if(x == y) return x;
for(int i = log2(n);i >= 0;i--)
if(fa[x][i] != fa[y][i])
{
x = fa[x][i];
y = fa[y][i];
}
return fa[x][0];
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
cin >> n;
for(int i = 1;i <= n - 1;i++)
{
int u,v,w;
cin >> u >> v >> w;
adde(u,v,w);
adde(v,u,w);
}
dpdfs(1);
cout << d << "n";
LL idx = o, lidx = idx;
LL deep = 0;
while(idx != 0)
{
if(deep + d2[o] >= d / 2)
{
if(((d % 2) == 0) && ((deep + d2[o]) == (d / 2)))
{
lidx = idx;
}
break;
}
lidx = idx;
idx = k1[idx];
deep += d1[lidx] - d1[idx];
}
if(idx != lidx)
{
vis[lidx] = true;
depth[idx] = 0;
len[idx] = 0;
deepdfs(idx,maxlenidx1,maxlencnt1,maxlen1,maxlendep1);
vis[lidx] = false;
vis[idx] = true;
depth[lidx] = 0;
len[lidx] = 0;
deepdfs(lidx,maxlenidx2,maxlencnt2,maxlen2,maxlendep2);
vis[idx] = false;
lcadepth[0] = -1;
lcadfs(idx);
int lca1 = 0, lca2 = 0;
if(maxlencnt1 >= 1)
{
lca1 = maxlenidx1[1];
for(int i = 2;i <= maxlencnt1;i++)
{
lca1 = lca(lca1,maxlenidx1[i]);
}
}
if(maxlencnt2 >= 1)
{
lca2 = maxlenidx2[1];
for(int i = 2;i <= maxlencnt2;i++)
{
lca2 = lca(lca2,maxlenidx2[i]);
}
}
int cnt1 = maxlencnt1, cnt2 = maxlencnt2;
if(cnt1 == 1 && cnt2 == 0) cout << maxlendep1 + 1 << "n";
if(cnt1 == 0 && cnt2 == 1) cout << maxlendep2 + 1 << "n";
if(cnt1 == 2 && cnt2 == 0) cout << depth[lca1] + 1 << "n";
if(cnt1 == 1 && cnt2 == 1) cout << maxlendep1 + maxlendep2 << "n";
if(cnt1 == 0 && cnt2 == 2) cout << depth[lca2] + 1 << "n";
if(cnt1 >= 2 && cnt2 == 1) cout << depth[lca1] + maxlendep2 + 1 << "n";
if(cnt1 == 1 && cnt2 >= 2) cout << maxlendep1 + depth[lca2] + 1 << "n";
if(cnt1 >= 2 && cnt2 >= 2) cout << depth[lca1] + 1 + depth[lca2] << "n";
}
else
{
vis[idx] = true;
depth[idx] = 0;
len[idx] = 0;
deepdfs(idx,maxlenidx1,maxlencnt1,maxlen1,maxlendep1);
vis[idx] = false;
lcadepth[0] = -1;
lcadfs(idx);
int lca1 = 0;
if(maxlencnt1 >= 1)
{
lca1 = maxlenidx1[1];
for(int i = 2;i <= maxlencnt1;i++)
{
lca1 = lca(lca1,maxlenidx1[i]);
}
}
if(maxlencnt1 == 2) cout << depth[maxlenidx1[1]] + depth[maxlenidx1[2]] << "n";
else cout << depth[lca1] << "n";
}
return 0;
}
/*
6
3 1 80
1 4 10
4 2 70
4 5 50
4 6 90
*/
/*
6
1 2 1
2 3 4
2 4 3
1 6 2
1 5 3
*/
内容来源于网络如有侵权请私信删除
文章来源: 博客园
- 还没有人评论,欢迎说说您的想法!