算法竞赛进阶指南,380页,树上差分
本题要点:
1、 附加边 (x, y) 把 x, y之间的路径上的每条边都"覆盖了一次", 需要统计每条主要边被覆盖了几次;
a) 主要边被覆盖0次,那么将该主要边打断,然后随意打断一条附加边即可;
b) 主要边被覆盖1次,那么将该主要边打断,只有唯一的 附加边被打断才能得到不连通的两部分;
c) 主要边被覆盖2次及以上,无论如何操作都不能打败 Dark
2、 这里涉及的是树上边的差分:
给每个节点初始为0的权值,然后对每条附加边(x, y), 使x点的权值加1, y节点的权值加1, LCA(x, y) 的权值减1;
深度递归, F[x] 表示以x为根的子树的各个几点的权值之和,那么 F[x] 就是x和它的父节点之间的树边的权值。
树上差分参考: https://www.cnblogs.com/TEoS/p/11376676.html
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <cmath>
using namespace std;
const int MaxN = 100010;
int ver[MaxN * 2], head[MaxN * 2], Next[MaxN * 2];
long long edge[MaxN * 2];
int f[MaxN][20]; //f[i][k] 表示 点i 向上走 2^k 步到达的节点号(n个节点,节点编号从1到n)
int d[MaxN]; //d[i] 点i的深度
long long val[MaxN]; // val[i] 表示第i点的权值
bool vis[MaxN];
int n, m, depth; //depth 表示树的深度
int tot;void add(int x, int y, int z)
{
ver[++tot] = y, edge[tot] = z, Next[tot] = head[x], head[x] = tot;
}void bfs()
{
d[1] = 1;queue<int> q;q.push(1);while(!q.empty()){
int x = q.front();q.pop();for(int i = head[x]; i; i = Next[i]){
int y = ver[i];if(d[y]){
continue;}d[y] = d[x] + 1;f[y][0] = x; //点y的父节点是xfor(int t = 1; t <= depth; ++t){
f[y][t] = f[f[y][t - 1]][t - 1]; }q.push(y);}}
}int lca(int x, int y) //节点x和y的最近公共祖先
{
if(d[x] > d[y]){
swap(x, y); //使得 d[x] <= d[y], 然后调整 y}for(int t = depth; t >= 0; --t)//这个循环的目的是,使得y和x处于同一高度{
if(d[f[y][t]] >= d[x]){
y = f[y][t];}}if(x == y){
return x;}for(int t = depth; t >= 0; --t){
if(f[x][t] != f[y][t]){
x = f[x][t], y = f[y][t];}}return f[x][0];
}void dfs(int x)
{
vis[x] = true;int sum = 0;for(int i = head[x]; i; i = Next[i]){
int y = ver[i];if(vis[y]){
continue;}dfs(y);sum += val[y];edge[i] = val[y];}val[x] += sum;
}int main()
{
int x, y;scanf("%d%d", &n, &m);depth = (int)(log(n) / log(2)) + 1;tot = 0;for(int i = 1; i <= n; ++i){
vis[i] = val[i] = head[i] = d[i] = 0;}for(int i = 1; i < n; ++i){
scanf("%d%d", &x, &y); add(x, y, 0), add(y, x, 0);}bfs();for(int i = 0; i < m; ++i){
scanf("%d%d", &x, &y);val[x]++, val[y]++;val[lca(x, y)] -= 2;}dfs(1);long long ans = 0;for(int i = 2; i <= n; ++i){
if(0 == val[i]){
ans += m; }else if(1 == val[i]){
++ans;}}printf("%lld\n", ans);return 0;
}/* 4 1 1 2 2 3 1 4 3 4 *//* 3 */