轮舞前夕

君と彼女の恋这套题的t1,这是t2,经典的树形dp,多了一个统计方案数而已,写的常熟还不如JZ初中的人小QAQQQQQQQQQQQ

#include<cctype>
#include<cstdio>
#include<map>
#include<cmath>
#include<queue>
#include<vector>
#include<cstring>
#include<cstdlib>
#include<iostream>
#include<iomanip>
#include<algorithm>
#define inf 1000000007
#define LL long long
#define pb push_back
using namespace std;
LL read()
{
        LL x=0,f=1;char ch=getchar();
        while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
        while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
        return x*f;
}
vector<int> e[100005];
int n,root;
LL f[100005][3],sum[100005][3],g[3],gg[3];
int calc (int i,int j)
{
        if(i==2)return 2;
        if(i==1||j==2)return 1;
        return 0;
}
void dp(int r,int fa)
{
        for (int i=0;i<=2;i++)
                sum[r][i]=1;
        f[r][2]=1;
        f[r][0]=0;
        f[r][1]=inf;
        for(int p=0,i;p<e[r].size();p++)
        {
                i=e[r][p];
                if(i==fa) continue;
                dp(i,r);
                for (int j=0;j<=2;j++)
                {
                        g[j]=inf;
                        gg[j]=0;
                }
                for (int j=0;j<=2;j++)
                        for (int k=0;k<=2;k++)
                        {
                                if(j!=2&&k==0) continue;
                                int l=calc(j,k);
                                g[l]=min(g[l],f[r][j]+f[i][k]);
                        }

                for (int j=0;j<=2;j++)
                        for (int k=0;k<=2;k++)
                        {
                                if (j!=2&&k==0) continue;
                                int l=calc(j,k);
                                if(f[r][j]+f[i][k]==g[l]) gg[l]=(gg[l]+(LL)sum[r][j]*sum[i][k])%inf;
                        }
                for (int j=0;j<=2;j++)
                {
                        f[r][j]=g[j];
                        sum[r][j]=gg[j];
                }
        }
}
int main()
{
        n=read();root=(n+1)/2;
        if(n==1)
        {
                printf("1\n1\n");
                return 0;
        }
        int x,y;
        for (int i=1;i<=n-1;i++)
        {
                scanf("%d%d",&x,&y);
                e[x].pb(y);
                e[y].pb(x);
        }
        dp(root,0);
        int anss=min(f[root][1],f[root][2]),ans=0;
        if(f[root][1]==anss) ans=sum[root][1];
        if(f[root][2]==anss) ans=(ans+sum[root][2])%inf;
        cout<<anss<<endl;
        cout<<ans<<endl;
        return 0;
}


你可能感兴趣的:(dp)