Time Limit: 2000/2000 MS (Java/Others) Memory Limit: 524288/524288 K (Java/Others)
Total Submission(s): 1385 Accepted Submission(s): 478
Statistic | Submit | Discuss | Note
有n个人,每个人的权值为1或2。问从中选出三个人、使三人权值之和至少为5(即至少有两个人的权值为2)的方案数,并且要求选出的三个人两两互不认识;答案对1e9+7取模。
接下来输入n - 1对u、v,表示u与v认识。认识关系是相互的,并且具有传递性。需要先输出初始答案,之后每读入一对关系就输出一次当前答案。
用并查集维护认识关系,并记录每个集合中权值为2和权值为1的人数。每次合并两个集合时,计算合并后总方案数减少了多少:失效的正是"一人来自其中一个集合、一人来自另一个集合、第三人来自其余集合"的合法三元组。
#include <cstdio>
#include <ctime>
#include <iostream>
#include <utility>
using namespace std;
using ll = long long;
using ull = unsigned long long;
#ifdef LOCAL
// __FUNCTION__ is a variable (not a string literal) in GCC/Clang, so it must be
// streamed with << rather than concatenated with adjacent literals.
#define debug(x) std::cout << "[" << __FUNCTION__ << ": " #x " = " << (x) << "]\n"
// clock() counts ticks; convert to milliseconds so the "ms" label is true on every
// platform (CLOCKS_PER_SEC is 1000 on Windows but 1000000 on POSIX).
#define TIME std::cout << "RuningTime: " << clock() * 1000 / CLOCKS_PER_SEC << "ms\n", 0
#else
#define TIME 0
#endif
#define hash_ 1000000009
// Run statement `x`, then continue / break the enclosing loop.
#define Continue(x) { x; continue; }
#define Break(x) { x; break; }
constexpr int mod = 1000000007;          // answer modulus (prime)
constexpr int N = 200000 + 10;           // array capacity: supports n <= 200000
constexpr int INF = 0x3f3f3f3f;
constexpr ll LINF = 0x3f3f3f3f3f3f3f3f;
// Next byte of stdin via read()'s 1 MB static buffer, or EOF once input is exhausted.
// Fully parenthesized and without a trailing ';' so it is safe inside any expression.
// Bytes are widened through unsigned char so a 0xFF byte cannot be mistaken for EOF.
#define gc (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, sizeof buf, stdin), p1 == p2) ? EOF : (unsigned char)*p1++)
/// Fast reader for one optionally negative decimal integer from stdin.
/// Skips any characters before the first '-' or digit; returns 0 if input ends first
/// (previously this spun forever at end of input).
/// NOTE(review): main() reads with cin/scanf instead; mixing this buffered reader with
/// other stdin readers would desynchronize them, since it reads ahead into its buffer.
inline int read()
{
    static char buf[1000000], *p1 = buf, *p2 = buf;
    int x = 0;
    bool negative = false;
    int ch = gc;  // int, not char, so EOF is distinguishable from real bytes
    while (ch != EOF && ch != '-' && (ch < '0' || ch > '9'))
        ch = gc;
    if (ch == '-')
    {
        negative = true;
        ch = gc;
    }
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = gc;
    }
    return negative ? -x : x;
}
/// Computes a^b modulo `mod` by binary exponentiation.
/// @param a    base; may be negative or >= mod (it is reduced first, so a*a cannot overflow)
/// @param b    exponent; b <= 0 yields 1 % mod
/// @param mod  modulus, expected > 0
/// @return     the canonical residue in [0, mod)
long long fpow(long long a, int b, int mod)
{
    a %= mod;
    if (a < 0)
        a += mod;
    long long res = 1 % mod;  // so that mod == 1 correctly yields 0
    for (; b > 0; b >>= 1)
    {
        if (b & 1)
            res = res * a % mod;
        a = a * a % mod;
    }
    return res;
}
// Modular inverses modulo 1e9+7: 2 * 500000004 and 6 * 166666668 both equal 1000000008 ≡ 1.
constexpr ll INV2 = 500000004;
constexpr ll INV6 = 166666668;

// Disjoint-set bookkeeping, indexed 1..n (filled in by main).
int a[N];    // a[r]: count of value-2 people in the group rooted at r
int b[N];    // b[r]: count of value-1 people in the group rooted at r
int pre[N];  // pre[x]: parent of x in the forest; pre[x] == x marks a root
/// Returns the root of x's group in the disjoint-set forest `pre`, and re-points every
/// node on the walked path directly at that root (path compression).
/// Iterative: main() links roots without rank, so parent chains can be ~n long and a
/// recursive version could overflow the call stack.
int find(int x)
{
    int root = x;
    while (pre[root] != root)
        root = pre[root];
    while (pre[x] != root)
    {
        int next = pre[x];
        pre[x] = root;
        x = next;
    }
    return root;
}
/// C(x, 2) modulo `mod`, i.e. x*(x-1)/2 using the modular inverse INV2.
/// x is reduced into [0, mod) first so the product cannot overflow for large x,
/// and the result is always the canonical residue in [0, mod).
ll two(ll x)
{
    x %= mod;
    if (x < 0)
        x += mod;
    return x * (x - 1) % mod * INV2 % mod;
}
/// C(x, 3) modulo `mod`, i.e. x*(x-1)*(x-2)/6 using the modular inverse INV6.
/// x is reduced into [0, mod) and the product is reduced after every factor; the old
/// form x*(x-1)*(x-2) overflowed long long once x exceeded about 2.1e6.
/// When x reduces to 0 or 1 the first partial product is 0, so negative factors
/// (x-1 or x-2) never produce a negative result.
ll three(ll x)
{
    x %= mod;
    if (x < 0)
        x += mod;
    return x * (x - 1) % mod * (x - 2) % mod * INV6 % mod;
}
int main()
{
#ifdef LOCAL
// freopen("C:/input.txt", "r", stdin);
#endif
int t;
cin >> t;
while (t--)
{
int n;
cin >> n;
ll cnt1 = 0, cnt2 = 0;
for (int i = 1; i <= n; i++)
a[i] = b[i] = 0;
for (int i = 1; i <= n; i++)
{
int num;
pre[i] = i;
scanf("%d", &num);
num == 2 ? ++a[i], ++cnt2 : (++b[i], ++cnt1);
}
ll ans = (cnt2 * (cnt2 - 1) / 2 % mod * cnt1 % mod + cnt2 * (cnt2 - 1) * (cnt2 - 2) / 6 % mod) % mod;
printf("%lld\n", ans);
for (int i = 1; i <= n - 1; i++)
{
int u, v;
scanf("%d%d", &u, &v);
u = find(u);
v = find(v);
if (u == v || !ans)
Continue(printf("%lld\n", ans))
ll a1 = a[u], b1 = b[u], a2 = a[v], b2 = b[v];
ll c1 = cnt1 - b1 - b2;
ll c2 = cnt2 - a1 - a2;
ans = (ans - c2 * a1 * a2 % mod + mod) % mod;//2 2 2
ans = (ans - c2 * b1 * a2 % mod + mod) % mod;//2 1 2
ans = (ans - c2 * a1 * b2 % mod + mod) % mod;//2 2 1
ans = (ans - c1 * a1 * a2 % mod + mod) % mod;//1 2 2
printf("%lld\n", ans);
pre[u] = v;
a[v] += a[u];
b[v] += b[u];
}
}
return TIME;
}