题目链接
膜拜ydc大神的题解。
对于1000000以内的互质勾股数两两建边。
即在这个森林中选择不相邻的点,可以树形DP解决。
但是并不是树,可能有回边。
那么暴力枚举回边相连的点选还是不选,然后在这个基础上跑树形DP。
并不清楚bzoj上跑得最快的那些人是怎么做的。。但网上只能找到几篇题解就这么写了。
#include
using namespace std;
inline int read(){
int x = 0, f = 1; char c = getchar();
while(!isdigit(c)) { if(c == '-') f = -1; c = getchar(); }
while(isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); }
return x * f;
}
typedef long long ll;
const int N = 1000000 + 10, H = 1000000, mod = 1000000000 + 7;
int a[N], t[N], n;
int tot = 1, to[N], nxt[N], hd[N], fa[N];
int sp[N], s, se[N], wt[N];
ll sq2[N], f[N][2], ans = 1;
bool vis[N], mark[N], must[N], out[N];
inline ll gcd(ll x, ll y){ return x % y ? gcd(y, x % y) : y; }
inline int insert(ll u, ll v){
to[++tot] = v; nxt[tot] = hd[u]; hd[u] = tot;
to[++tot] = u; nxt[tot] = hd[v]; hd[v] = tot;
}
void init(){
for(ll i = 2; (i << 1) <= H; i++)
for(ll j = 1; i * j * 2 <= H && j < i; j++)
if((i & 1) != (j & 1) && gcd(i, j) == 1 && i * i - j * j <= H)
insert(i * i - j * j, 2 * i * j);
sq2[0] = 1;
n = read();
for(int i = 1; i <= n; i++){
a[i] = read();
t[a[i]]++;
sq2[i] = (sq2[i-1]<<1) % mod;
}
sort(a + 1, a + n + 1);
}
void treedp(int u){
f[u][0] = 1, f[u][1] = sq2[t[u]]-1;
if(mark[u] && must[u]) f[u][0] = 0;
if(mark[u] && !must[u]) f[u][1] = 0;
for(int i = hd[u]; i; i = nxt[i]){
if(!out[i] && to[i] != fa[u] && t[to[i]]){
int v = to[i];
treedp(v);
f[u][0] = f[u][0] * (f[v][0] + f[v][1]) % mod;
f[u][1] = f[u][1] * f[v][0] % mod;
}
}
}
void dfs1(int u){
vis[u] = true, sp[++s] = u;
for(int i = hd[u]; i; i = nxt[i])
if(t[to[i]] && to[i] != fa[u]){
if(!vis[to[i]]) fa[to[i]] = u, dfs1(to[i]);
else out[i] = true, se[++se[0]] = i, mark[to[i]] = mark[u] = true;
}
}
void dfs2(int p, int n, ll &ans){
if(p > n){
for(int i = 1; i <= se[0]; i++){
int e = se[i], x = to[e], y = to[e^1];
if(must[x] && must[y]) return;
}
treedp(sp[1]);
ans = (ans + f[sp[1]][0] + f[sp[1]][1]) % mod;
return;
}
must[wt[p]] = true, dfs2(p+1, n, ans);
must[wt[p]] = false, dfs2(p+1, n, ans);
}
ll solve(int u){
s = 0, se[0] = 0, wt[0] = 0;
dfs1(u);
ll ret = 0;
for(int i = 1; i <= s; i++)
if(mark[sp[i]])
wt[++wt[0]] = sp[i];
dfs2(1, wt[0], ret);
return ret;
}
void work(){
for(int i = 1; i <= n; i++)
if(!vis[a[i]])
ans = ans * solve(a[i]) % mod;
printf("%lld\n", ans-1);
}
int main(){
init();
work();
return 0;
}