首先求最小生成树,我们可以发现,严格次小生成树只会在MST上改变一条边,不可能改变更多的边。所以我们可以枚举不在MST里的边,比如枚举到(u,v),这条边一定是大于MST中(u,v)链上的每一条边的,所以我们找出MST中(u,v)链上与这条边的差最小的那条(注意不能为0)更换就好了。
至于链上最值的维护上倍增就好了,同noip2013货车运输。
(看起来可以拿来给noip小朋友做。)
#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 100005
#define M 300005
#define inf 2147483647
#define ll long long
using namespace std;
int n,m,cnt,tot,mn=inf;
ll ans;
int f[N],head[N],deep[N],fa[N][18],d1[N][18],d2[N][18];
struct node {int x,y,v;} a[M];
bool mark[M];
int next[N<<1],list[N<<1],key[N<<1];
inline int read()
{
int a=0,f=1; char c=getchar();
while (c<'0'||c>'9') {if (c=='-') f=-1; c=getchar();}
while (c>='0'&&c<='9') {a=a*10+c-'0'; c=getchar();}
return a*f;
}
inline void insert(int x,int y,int z)
{
next[++cnt]=head[x];
head[x]=cnt;
list[cnt]=y;
key[cnt]=z;
}
int find(int i)
{
return f[i]==i?i:f[i]=find(f[i]);
}
inline bool cmp(node a,node b)
{
return a.v<b.v;
}
void dfs(int x)
{
for (int i=1;(1<<i)<=deep[x];i++)
{
fa[x][i]=fa[fa[x][i-1]][i-1];
d1[x][i]=max(d1[x][i-1],d1[fa[x][i-1]][i-1]);
if (d1[x][i-1]<d1[fa[x][i-1]][i-1])
d2[x][i]=max(d1[x][i-1],d2[fa[x][i-1]][i-1]);
if (d1[x][i-1]==d1[fa[x][i-1]][i-1])
d2[x][i]=max(d2[x][i-1],d2[fa[x][i-1]][i-1]);
if (d1[x][i-1]>d1[fa[x][i-1]][i-1])
d2[x][i]=max(d2[x][i-1],d1[fa[x][i-1]][i-1]);
}
for (int i=head[x];i;i=next[i])
if (list[i]!=fa[x][0])
{
fa[list[i]][0]=x;
d1[list[i]][0]=key[i];
d2[list[i]][0]=0;
deep[list[i]]=deep[x]+1;
dfs(list[i]);
}
}
inline int lca(int x,int y)
{
if (deep[x]<deep[y]) swap(x,y);
int t=deep[x]-deep[y];
for (int i=0;(1<<i)<=t;i++)
if ((1<<i)&t) x=fa[x][i];
for (int i=17;i>=0;i--)
if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return x==y?x:fa[x][0];
}
inline void calc(int x,int f,int v)
{
int m1=0,m2=0,t=deep[x]-deep[f];
for (int i=0;(1<<i)<=t;i++)
if ((1<<i)&t)
{
if (d1[x][i]>m1) m2=m1,m1=d1[x][i];
m2=max(m2,d2[x][i]);
x=fa[x][i];
}
mn=min(mn,m1==v?v-m2:v-m1);
}
inline void solve(int i)
{
int t=lca(a[i].x,a[i].y);
calc(a[i].x,t,a[i].v); calc(a[i].y,t,a[i].v);
}
int main()
{
n=read(); m=read();
for (int i=1;i<=n;i++) f[i]=i;
for (int i=1;i<=m;i++)
a[i].x=read(),a[i].y=read(),a[i].v=read();
sort(a+1,a+m+1,cmp);
for (int i=1;i<=m;i++)
{
int p=find(a[i].x),q=find(a[i].y);
if (p!=q)
{
f[p]=q;
ans+=a[i].v;
mark[i]=1;
insert(a[i].x,a[i].y,a[i].v);
insert(a[i].y,a[i].x,a[i].v);
tot++;
if (tot==n-1) break;
}
}
dfs(1);
for (int i=1;i<=m;i++)
if (!mark[i]) solve(i);
cout << ans+mn;
return 0;
}