原题地址
当时比赛的时候,在观战席看出来了应该要用启发式合并,后来自己敲的时候想当然地使用了STL。由于不仅要记录结点的值,还要记录结点的id,所以必须定义一个结构体,但是重载结构体<时出现了许多问题。首先是set.insert()不了,后来查了博客发现重载函数必须写成这样
bool operator <(const nod& u)const
{
return a < u.a;
}
才能insert()进去,而且set的迭代器无法直接相减(set内部存储空间并非相邻的)。用set的另一个问题是,如果我的重载函数中不比较id的话,所有a相同而id不同的对象会被认为是相同的,所以最后只能用vector来维护,代码如下:
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define maxn 100010
#define mkp make_pair
#define inf 1e6
typedef long long ll;
const ll mod = 1e9 + 7;
const double pi = acos(-1.0);
int n, son[maxn], sz[maxn], Son, Fa, qq[maxn];
vector<int>gg[maxn];
ll ans[maxn];
struct nod
{
ll a;
int id;
nod(int r, int v)
{
a = r;
id = v;
}
bool operator <(nod u)
{
return a < u.a;
}
};
vector<nod>s;
vector<nod>ww;
void addedge(int u, int v)
{
gg[u].push_back(v);
gg[v].push_back(u);
}
void dfs1(int x, int f)//树链剖分找重儿子
{
sz[x] = 1;
for (int i = 0; i < gg[x].size(); i++)
{
int tx = gg[x][i];
if (tx == f) continue;
dfs1(tx, x);
sz[x] += sz[tx];
if (sz[tx] > sz[son[x]]) son[x] = tx;
}
}
void add(int x, int f)
{
int i, tx, j;
for (i = 0; i < gg[x].size(); i++)
{
tx = gg[x][i];
if (tx == Son || tx == f) continue;
ww.push_back(nod(qq[tx], tx));
add(tx, x);
if (x == Fa)
{
for (j = 0; j < ww.size(); j++)
{
int opt = qq[Fa] ^ ww[j].a;
int lt = lower_bound(s.begin(), s.end(), nod(opt, ww[j].id)) - s.begin();
if (lt == s.size()) continue;
while (s[lt].a == opt)
{
ans[x] = ans[x] + (ll)(s[lt].id ^ ww[j].id);
++lt;
if (lt == s.size()) break;
}
}
for (j = 0; j < ww.size(); j++)
{
s.push_back(ww[j]);
}
sort(s.begin(), s.end());
ww.clear();
}
}
}
void dfs2(int x, int f, bool keep)//统计
{
int i, tx;
for (i = 0; i < gg[x].size(); i++)
{
tx = gg[x][i];
if (tx == f) continue;
if (tx != son[x]) dfs2(tx, x, false);
}
if (son[x])
{
dfs2(son[x], x, true);
}
Son = son[x];
Fa = x;
add(x, f);
s.push_back(nod(qq[x], x));
sort(s.begin(), s.end());
if (!keep) s.clear();
}
int main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int i, j, r, l;
ll sum = 0;
cin >> n;
for (i = 1; i <= n; i++) cin >> qq[i];
for (i = 1; i < n; i++)
{
cin >> l >> r;
addedge(l, r);
}
dfs1(1, 0);
dfs2(1, 0, true);
for (i = 1; i <= n; i++) sum = sum + ans[i];
cout << sum << endl;
return 0;
}
个人认为这个时间复杂度时候可以被接受的,DSU ON TREE的时间复杂度本质是O(lg(N)O(处理每颗子树的时间)),那么在理想情况下,sort的复杂度lg(N),那么总体的复杂度为O(Nlg(N)^2),但显然这个复杂度过于理想化了,最终该代码TLE在第7个测试点。
后来看了眼网上的题解,大体思路我一直是对的,但确实没想到去分解存储id的贡献值。而且之前自己的DSU ON TREE板子不够优化,一定程度上影响了思考方向。
AC代码:
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define maxn 100010
#define mkp make_pair
#define inf 1e6
typedef long long ll;
const ll mod = 1e9 + 7;
const double pi = acos(-1.0);
int n, son[maxn], sz[maxn], Son, qq[maxn], Fa;
vector<int>gg[maxn];
ll ans[maxn];
struct nod
{
ll a;
int id;
nod(int r, int v)
{
a = r;
id = v;
}
bool operator <(const nod& u)const
{
return a < u.a;
}
};
int cnt[10 * maxn][20][2];
vector<nod>ww;
void addedge(int u, int v)
{
gg[u].push_back(v);
gg[v].push_back(u);
}
void dfs1(int x, int f)//树链剖分找重儿子
{
sz[x] = 1;
for (int i = 0; i < gg[x].size(); i++)
{
int tx = gg[x][i];
if (tx == f) continue;
dfs1(tx, x);
sz[x] += sz[tx];
if (sz[tx] > sz[son[x]]) son[x] = tx;
}
}
void add(int x, int f)
{
int i, tx, j, r;
int opt = qq[Fa] ^ qq[x];
if (opt <= 1e6)
{
for (j = 0; (1 << j) <= n; j++)
{
int ct = (1 << j);
if (ct & x) ans[x] = ans[x] + 1ll * ct * cnt[opt][j][0];
else ans[x] = ans[x] + 1ll * ct * cnt[opt][j][1];
}
}
for (i = 0; i < gg[x].size(); i++)
{
tx = gg[x][i];
if (tx == f) continue;
add(tx, x);
}
}
void update(int x, int f, int k)
{
int i, j, r;
for (r = 0; (1 << r) <= n; r++)//这里一开始写成了<=x,WA在了第3个点
{
int ct = (1 << r);
if (ct & x) cnt[qq[x]][r][1] += k;
else cnt[qq[x]][r][0] += k;
}
for (i = 0; i < gg[x].size(); i++)
{
int tx = gg[x][i];
if (tx == f) continue;
update(tx, x, k);
}
}
void dfs2(int x, int f, bool keep)//统计
{
int i, tx;
for (i = 0; i < gg[x].size(); i++)
{
tx = gg[x][i];
if (tx == f) continue;
if (tx != son[x]) dfs2(tx, x, false);
}
if (son[x])
{
dfs2(son[x], x, true);
}
Son = son[x];
Fa = x;
for (i = 0; i < gg[x].size(); i++)
{
int tx = gg[x][i];
if (tx == f || tx == Son) continue;
add(tx, x);
update(tx, x, 1);
}
for (int r = 0; (1 << r) <= n; r++)
{
int ct = (1 << r);
if (ct & x) cnt[qq[x]][r][1]++;
else cnt[qq[x]][r][0]++;
}
if (!keep)
{
update(x, f, -1);
}
}
int main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int i, j, r, l;
ll sum = 0;
cin >> n;
for (i = 1; i <= n; i++) cin >> qq[i];
for (i = 1; i < n; i++)
{
cin >> l >> r;
addedge(l, r);
}
dfs1(1, 0);
dfs2(1, 0, true);
for (i = 1; i <= n; i++) sum = sum + ans[i];
cout << sum << endl;
return 0;
}
原题地址
和上题有些相似,但是这题可以直接利用dfs序来做。dfs序的一个优点在于每个点的子树是连续的,这可以解决部分DSU ON TREE的问题。如果上题统计的不是(i^j)之和而仅仅统计(i,j)对数,我认为是可以仿照本题使用dfs序做的。
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std;
#define maxn 100010
#define mkp make_pair
#define inf 1e6
typedef long long ll;
const ll mod = 1e9 + 7;
const double pi = acos(-1.0);
int n, m;
vector<int>gg[maxn];
int dep[maxn], fa[maxn][20], lm[maxn], rm[maxn], cnt = 0, dfn[maxn], ans[maxn];
vector<int>d[maxn];
void dfs(int x, int de)
{
int i;
lm[x] = ++cnt;
dep[x] = de;
dfn[cnt] = x;
d[de].push_back(cnt);
for (i = 1; (1 << i) <= n; i++)
{
fa[x][i] = fa[fa[x][i - 1]][i - 1];
}
for (i = 0; i < gg[x].size(); i++)
{
int tx = gg[x][i];
dfs(tx, de + 1);
}
rm[x] = cnt;
}
int solve(int x, int p)
{
int i, tx = x;
for (i = 0; (1 << i) <= n; i++)
{
int ct = (1 << i);
if (ct & p)
{
tx = fa[tx][i];
}
}
int deep = dep[x];
int lb = lower_bound(d[deep].begin(), d[deep].end(), lm[tx]) - d[deep].begin();
int rb = upper_bound(d[deep].begin(), d[deep].end(), rm[tx]) - d[deep].begin();
return rb - lb;
}
int main()
{
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int i, j, r, l;
cin >> n;
for (i = 1; i <= n; i++)
{
cin >> fa[i][0];
gg[fa[i][0]].push_back(i);
}
dfs(0, 0);
for (i = 1; i <= n; i++)
{
if (d[i].size() == 0) break;
sort(d[i].begin(), d[i].end());
}
cin >> m;
for (i = 1; i <= m; i++)
{
cin >> r >> l;
if (l >= dep[r])
{
ans[i] = 0;
continue;
}
ans[i] = solve(r, l) - 1;
}
for (i = 1; i < m; i++) cout << ans[i] << ' ';
cout << ans[m] << endl;
return 0;
}