Problem link
You are given a tree (a connected undirected graph without cycles) of n vertices. Each of the n−1 edges of the tree is colored in either black or red.
You are also given an integer k. Consider sequences of k vertices. Let’s call a sequence [a1,a2,…,ak] good if it satisfies the following criterion:
We will walk a path (possibly visiting same edge/vertex multiple times) on the tree, starting from a1 and ending at ak.
Start at a1, then go to a2 using the shortest path between a1 and a2, then go to a3 in a similar way, and so on, until you travel the shortest path between ak−1 and ak.
If you walked over at least one black edge during this process, then the sequence is good.
Consider the tree on the picture. If k=3 then the following sequences are good: [1,4,7], [5,5,3] and [2,3,7]. The following sequences are not good: [1,4,6], [5,5,5], [3,7,3].
There are n^k sequences of vertices, count how many of them are good. Since this number can be quite large, print it modulo 10^9+7.
Input
The first line contains two integers n and k (2≤n≤10^5, 2≤k≤100), the size of the tree and the length of the vertex sequence.
Each of the next n−1 lines contains three integers ui, vi and xi (1≤ui,vi≤n, xi∈{0,1}), where ui and vi denote the endpoints of the corresponding edge and xi is the color of this edge (0 denotes red edge and 1 denotes black edge).
Output
Print the number of good sequences modulo 10^9+7.
Examples
input
4 4
1 2 1
2 3 1
3 4 1
output
252
input
4 6
1 2 0
1 3 0
1 4 0
output
0
input
3 5
1 2 1
2 3 0
output
210
Note
In the first example, all sequences (4^4) of length 4 except the following are good:
[1,1,1,1]
[2,2,2,2]
[3,3,3,3]
[4,4,4,4]
In the second example, all edges are red, hence there aren’t any good sequences.
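Problem summary: Given a tree with n vertices and n-1 edges, each edge colored black or red, and an integer k, count the vertex sequences of length k whose walk crosses at least one black edge, modulo 1e9+7.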
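Approach: Subtract the number of sequences that never cross a black edge from the total number of sequences. The total is pow(n,k). A sequence avoids every black edge exactly when all k of its vertices lie in the same connected component formed by red edges only, so a component with m vertices contributes pow(m,k) such sequences (an isolated vertex counts as a component with m=1). The number of sequences that are not good is therefore ∑pow(m,k) over all red components, and the answer is pow(n,k) − ∑pow(m,k) modulo 1e9+7.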
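As a quick check of this formula against the third example (n=3, k=5, edge 1-2 black, edge 2-3 red): the red components are {1} and {2,3}, so the sizes are m=1 and m=2, and

\[
n^k - \sum m^k = 3^5 - \left(1^5 + 2^5\right) = 243 - 33 = 210,
\]

which matches the expected output. The first example works the same way: every edge is black, so there are four components of size 1 and the count is \(4^4 - 4 \cdot 1^4 = 252\).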
Accepted code:
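#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
#define Max 200005
#define mod 1000000007
// mmp[] is the union-find parent array; sum[r] is the size of the component rooted at r.
int n, ans, k, mmp[Max], sum[Max];
// Return the representative of x, compressing the path on the way.
int Find(int x)
{
    return x == mmp[x] ? x : mmp[x] = Find(mmp[x]);
}
// Fast exponentiation: a^b modulo 1e9+7.
int poww(int a, int b)
{
    ll ans = 1, base = a;
    while (b != 0)
    {
        // Parentheses are required: != binds tighter than &.
        if ((b & 1) != 0)
            ans *= base, ans %= mod;
        base *= base, base %= mod;
        b >>= 1;
    }
    return ans % mod;
}
int main()
{
    cin >> n >> k;
    for (int i = 1; i <= n; i++)
        mmp[i] = i, sum[i] = 1;
    for (int a, b, c, i = 1; i < n; i++)
    {
        cin >> a >> b >> c;
        if (!c) // red edge: merge the two components
        {
            int x = Find(a), y = Find(b);
            mmp[x] = y, sum[y] += sum[x];
        }
    }
    // Each red component of size m contributes m^k sequences that are not good.
    // Vertices are numbered from 1, so the scan starts at 1.
    for (int i = 1; i <= n; i++)
        if (mmp[i] == i)
            (ans += poww(sum[i], k)) %= mod;
    printf("%d", (poww(n, k) - ans + mod) % mod);
    return 0;
}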
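To sanity-check the counting argument on small inputs, the sketch below compares the formula with a direct enumeration of all n^k sequences. It is not part of the accepted solution: the names `Edge`, `powMod`, `redComponents`, `countGoodFormula` and `countGoodBrute` are hypothetical helpers introduced here, and components are found with an iterative DFS instead of union-find. The brute force relies on the tree property that two vertices are joined by a red-only path exactly when they lie in the same red component.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical edge record: 1-based endpoints u, v; x = 1 for black, 0 for red.
struct Edge { int u, v, x; };

const std::int64_t MOD = 1000000007;

// a^b modulo MOD by repeated squaring.
std::int64_t powMod(std::int64_t a, std::int64_t b) {
    std::int64_t r = 1;
    a %= MOD;
    while (b > 0) {
        if (b & 1) r = r * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return r;
}

// Label every vertex with the id (1, 2, ...) of its red-edge component.
std::vector<int> redComponents(int n, const std::vector<Edge>& edges) {
    std::vector<std::vector<int>> adj(n + 1);
    for (const Edge& e : edges)
        if (e.x == 0) { adj[e.u].push_back(e.v); adj[e.v].push_back(e.u); }
    std::vector<int> comp(n + 1, 0);  // 0 means "not visited yet"
    int next = 0;
    for (int s = 1; s <= n; ++s) {
        if (comp[s]) continue;
        comp[s] = ++next;
        std::vector<int> stack{s};
        while (!stack.empty()) {
            int v = stack.back();
            stack.pop_back();
            for (int w : adj[v])
                if (!comp[w]) { comp[w] = comp[s]; stack.push_back(w); }
        }
    }
    return comp;
}

// n^k minus the sum of m^k over red components of size m.
std::int64_t countGoodFormula(int n, int k, const std::vector<Edge>& edges) {
    std::vector<int> comp = redComponents(n, edges);
    std::vector<std::int64_t> size(n + 1, 0);
    for (int v = 1; v <= n; ++v) ++size[comp[v]];
    std::int64_t bad = 0;
    for (int c = 1; c <= n; ++c)
        if (size[c] > 0) bad = (bad + powMod(size[c], k)) % MOD;
    return (powMod(n, k) - bad + MOD) % MOD;
}

// Enumerate every sequence (tiny inputs only) and count those in which some
// consecutive pair lies in different red components.
std::int64_t countGoodBrute(int n, int k, const std::vector<Edge>& edges) {
    assert(n <= 6 && k <= 6);  // at most 6^6 = 46656 sequences
    std::vector<int> comp = redComponents(n, edges);
    std::int64_t total = 1;
    for (int i = 0; i < k; ++i) total *= n;
    std::int64_t good = 0;
    for (std::int64_t code = 0; code < total; ++code) {
        std::int64_t rest = code;  // base-n digits of code are the sequence
        int prev = static_cast<int>(rest % n) + 1;
        rest /= n;
        bool sawBlack = false;
        for (int i = 1; i < k; ++i) {
            int cur = static_cast<int>(rest % n) + 1;
            rest /= n;
            if (comp[cur] != comp[prev]) sawBlack = true;
            prev = cur;
        }
        if (sawBlack) ++good;
    }
    return good % MOD;
}
```

Calling both functions on the three sample inputs should give the same values as the expected outputs (252, 0 and 210). The size guard in `countGoodBrute` is deliberate, since the enumeration grows as n^k.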