Home ACM & 2023蓝桥杯省赛CB组讲解
Post
Cancel

ACM & 2023蓝桥杯省赛CB组讲解

视频讲解:Click here

Solution pdf:Click here

题目pdf下载:Click here

补题OJ:

  1. http://oj.ecustacm.cn/
  2. https://www.dotcpp.com/oj/train/

A:日期统计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#include<bits/stdc++.h>
using namespace std;

const int N = 110;

int a[N];
// int a[100] = {5, 6, 8, 6, 9, 1, 6, 1, 2, 4, 9, 1, 9, 8, 2, 3, 6, 4, 7, 7, 5, 9, 5, 0, 3, 8, 7, 5, 8, 1, 5, 8, 6, 1, 8, 3, 0, 3, 7, 9, 2, 7, 0, 5, 8, 8, 5, 7, 0, 9, 9, 1, 9, 4, 4, 6, 8, 6, 3, 3, 8, 5, 1, 6, 3, 4, 6, 7, 0, 7, 8, 2, 7, 6, 8, 9, 5, 6, 5, 6, 1, 4, 0, 1, 0, 0, 9, 4, 8, 0, 9, 1, 2, 8, 5, 0, 2, 5, 3, 3,};

void solve() {
	
    for (int i = 0; i < 100; i++) cin >> a[i];
    
    int year[4] = {2, 0, 2, 3}; //先找到2023
    int p = 0; //表示当前配对到第几个数
    int st; //记录下标,然后月份从这个下标开始配对
    for (int i = 0; i < 100; i++) {
        if (a[i] == year[p]) p++; //配对到一个,就准备配对下一个
        if (p == 4) { //说明此时年份的四个数都匹配到了
            st = i + 1; //实际算出st=59
            break; //也就是后面的月份和天数只用跑一半。
        }
    }
    
    int day[13] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31};
	//每个月的天数,注意下标从1开始而不是0。
    int ans = 0; //符合条件的日期数量
    for (int i = 1; i <= 12; i++) { //月份
        for (int j = 1; j <= day[i]; j++) { //天数
            string s; //用字符串去存每天的日期,如1月1日对应的就是s=0101;
            s += i / 10 + '0'; //月份的第一个数
            s += i % 10 + '0'; //月份的第二个数
            s += j / 10 + '0'; //天数的第一个数
            s += j % 10 + '0'; //天数的第二个数
            
            int p = 0;
            for (int k = st; k < 100; k++) { //后续的日期从下标为st的数开始找
                if (a[k] == s[p] - '0') p++;
                if (p == 4) { //这个日期匹配成功
                    ans++; //符合条件的日期数加一。
                    break;
                }
            }
        }
    }
    cout << ans << endl; //最终输出答案为235
}

int main() {
    solve();
    return 0;
}

B:01串的熵

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
#include<bits/stdc++.h>
using namespace std;
double n = 23333333;
double H(double a) {
	return - a/n * log2(a/n) * a - (n-a)/n * log2((n-a)/n) * (n-a);
}
void solve() {
	for(double a = 1; a < n/2; a += 1.0) {
		if(fabs(H(a) - 11625907.5798) <= 5e-5) {
			cout << (int)a << endl;
			return;
		}
	}
}
int main() {
	solve();
	return 0;
}

C:冶炼金属

二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e4+10;
int n;
 
int find_l(int x,int y){
	int l=0,r=1e9+1;
	while(l+1<r){
		int mid=l+r>>1;
		if(x/mid>y) l=mid;
		else r=mid;
	}
	return r;
}

int find_r(int x,int y){
	int l=0,r=1e9+1;
	while(l+1<r){
		int mid=l+r>>1;
		if(x/mid>=y) l=mid;
		else r=mid;
	}
	return l;
}

int main(){
	cin>>n;
	int l=0,r=0x3f3f3f3f;
	while(n--){
		int x,y;
		cin>>x>>y;
		//左端点找最大值 
		l=max(l,find_l(x,y));
		//右端点找最小值 
		r=min(r,find_r(x,y));
	}
	cout<<l<<" "<<r;
	return 0;
}

公式+部分枚举

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e4+10;
int n;

int main(){
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>num[i][1]>>num[i][2];
	}
	int ans=0x3f3f3f3f;
	for(int i=1;i<=n;i++){
		ans=min(ans,num[i][1]/num[i][2]);
	} 
	ll tmp=ans;
	while(tmp>0){
		bool flag=true;
		for(int i=1;i<=n;i++){
			if(num[i][1]/tmp!=num[i][2]){
				flag=false;
				break;
			}
		}
		if(!flag) break;
		else tmp--;
	}
	cout<<tmp+1<<" "<<ans;
	return 0;
}

完全公式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e4+10;
int n;

int main(){
	cin>>n;
	for(int i=1;i<=n;i++){
		cin>>num[i][1]>>num[i][2];
	}
	int ans=0;
	for(int i=1;i<=n;i++){
		//求左端点找最大值 
		ans=max(ans,num[i][1]/(num[i][2]+1)+1);
	}
	cout<<ans<<" ";
	ans=0x3f3f3f3f3f;
	for(int i=1;i<=n;i++){
		//求右端点找最小值 
		ans=min(ans,num[i][1]/num[i][2]);
	} 
	cout<<ans;
	return 0;
}

D:飞机降落

next_permutation

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <iostream>
#include <algorithm>
using namespace std;
int testCase, n, a[100];
struct NODE {
    int T, D, L;
} node[11];
bool check(const int& n) {
	int nowTime = 0; // 当前时间 
	for(int i=1; i<=n; i++) {
		int nowID = a[i]; // 根据排列获得当前点飞机是哪一架 
		// 如果当前时间已经超过当前飞机的最晚降落时间,那就不可行 
		if(nowTime>node[nowID].T+node[nowID].D) return false; 
		// 更新当前时间,注意要取max,因为又可能当前时间还没到飞机的可降落时间
		nowTime = max(nowTime, node[nowID].T) + node[nowID].L;  
	}
	return true;
}
int main() {
    cin >> testCase;
    while(testCase--) {
        bool flag = false;
        cin >> n;
        for(int i=1; i<=n; i++) {
        	cin >> node[i].T >> node[i].D >> node[i].L; 
        	a[i] = i; // 生成初始排列,[1, 2, ..., n] 
		}
        do {
        	if(check(n)) {
        		flag = true; // 标记为有解 
        		break;
			}
		} while(next_permutation(a+1, a+n+1)); // 生成下一个排列 
        puts(flag?"YES":"NO");
    }
	return 0;
}

DFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
#include <iostream>
using namespace std;
int testCase, n;
bool used[11];
struct NODE {
    int T, D, L;
} node[11];
bool dfs(int nowTime, int cnt) {
    if(cnt == n) return true;
    bool flag;
    for(int i=1; i<=n; i++) {
        if(used[i]) continue;
        if(nowTime>node[i].T+node[i].D) return false;
        used[i] = true;
        flag = dfs(max(nowTime, node[i].T)+node[i].L, cnt+1);
        used[i] = false;
        if(flag) return true;
    }
    return false;
}
int main() {
    cin >> testCase;
    while(testCase--) {
        cin >> n;
        for(int i=1; i<=n; i++) cin >> node[i].T >> node[i].D >> node[i].L;
        puts(dfs(0, 0)?"YES":"NO");
    }
    return 0;
}

状压dp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include<iostream>
#include<cstring>
using namespace std;
int T, n, dp[1<<10];
struct NODE {
    int T, D, L;
} node[10];
int main() {
    cin >> T;
    while(T--) {
        cin >> n;
        for(int i=0; i<n; i++) cin >> node[i].T >> node[i].D >> node[i].L;
        memset(dp, 0x3f, sizeof(dp));
        dp[0] = 0;
        for(int nowState=1; nowState<(1<<n); nowState++) {
            for(int i=0; i<n; i++) {
                if(nowState>>i&1) {
                    int lastState = nowState^(1<<i);
                    if(dp[lastState]>node[i].T+node[i].D) continue; 
                    dp[nowState] = min(dp[nowState], max(node[i].T, dp[lastState])+node[i].L);
                }
            }
        }
        puts(dp[(1<<n)-1]!=0x3f3f3f3f?"YES":"NO");
    }
    return 0;
}

E:接龙数列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include<bits/stdc++.h>
using namespace std;
#define N 10 + 20
int n,ans;
int f[N];//dp数组,f_x表示以数字x结尾的最长接龙序列长度
int main()
{
    scanf("%d",&n);
    for(int i = 1;i <= n;i ++)
    {
        //输入整数
        string s;
        cin>>s;
        //提取当前整数的首位数字和末尾数字,st为首位数字,ed为末尾数字
        int st = s[0] - '0', ed = s[(int)s.size() - 1] - '0';
        //计算当前最长以数字ed结尾的接龙数列
        f[ed] = max(f[ed],f[st] + 1);
        //计算当前最长的接龙数列
        ans = max(ans,f[ed]);
    }
    //用数列总长减去最长的接龙序列,则为我们最少需要删除的个数
    ans = n - ans;
    printf("%d\n",ans);
}

F:岛屿个数

BFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <bits/stdc++.h>
using namespace std;

int main() {
    cin.tie(0)->sync_with_stdio(false);

    // 多测测试样例 读入数量
    int T;
    cin >> T;
    while (T--) {
        // 读入矩阵大小
        int n, m;
        cin >> n >> m;

        // 读入 01 矩阵,这里下标从 1 开始
        vector<vector<int>> g(n + 2, vector<int>(m + 2));
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                char c;
                cin >> c;
                g[i][j] = c - '0';
            }
        }

        // 定义偏移数组,[0, 4) 表示上下左右
        // [4, 8) 表示斜方向
        vector<int> dx{-1, 0, 1, 0, 1, 1, -1, -1};
        vector<int> dy{0, 1, 0, -1, 1, -1, -1, 1};

        // 这个 Dfs 函数是将最外层所在的连通块给染色成 2 
        function<void(int, int)> Dfs = [&](int u, int v) {
            // 先染色成 2
            g[u][v] = 2;

            // 然后将 8 个方向的节点都染一染
            for (int i = 0; i < 8; i++) {
                int nu = u + dx[i], nv = v + dy[i];
                // 如果当前这个节点没有越界 并且 值等于 0,等于 0 表示是与最外层同一个连通块
                if (0 <= nu && nu <= n + 1 && 0 <= nv && nv <= m + 1 && !g[nu][nv]) {
                    // Dfs 下去染色
                    Dfs(nu, nv);
                }
            }
        };
        Dfs(0, 0);

        // 这个 dfs 函数是将非 2 的元素所在的连通块染色成 2 
        function<void(int, int)> dfs = [&](int u, int v) {
            // 染成 2 
            g[u][v] = 2;

            // 枚举上下左右 4 个方向,跟上面 8 个方向不一样
            for (int i = 0; i < 4; i++) {
                int nu = u + dx[i], nv = v + dy[i];

                // 如果合法,并且不是 2,那么就同在一个连通块
                if (0 <= nu && nu <= n + 1 && 0 <= nv && nv <= m + 1 && g[nu][nv] != 2) {
                    dfs(nu, nv);
                }
            }
        };

        // 定义一个答案
        int ans = 0;

        // 枚举每个元素,如果不是 2 那么就将当前连通块染成 2 
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                if (g[i][j] != 2) {
                    // 先贡献答案,在染色
                    ans += 1;
                    dfs(i, j);
                }
            }
        }
        
        cout << ans << "\n";
    }
    return 0;
}

dijkstra

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#include <bits/stdc++.h>
using namespace std;

const vector<int> dx{-1, 0, 1, 0, 1, 1, -1, -1};
const vector<int> dy{0, 1, 0, -1, 1, -1, -1, 1};
const int inf = 0x3f3f3f3f;

int main() {
    cin.tie(0)->sync_with_stdio(false);
    int T;
    cin >> T;
    while (T--) {
        int n, m;
        cin >> n >> m;

        // 读入矩阵
        vector<vector<int>> g(n + 2, vector<int>(m + 2));
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                char c;
                cin >> c;
                g[i][j] = c - '0';
            }
        }

        using T = array<int, 3>;

        // 因为边权不相同,所以我们可以跑一个 Dijkstra 或者 01-BFS
        // 定义距离矩阵
        vector<vector<int>> dis(n + 2, vector<int>(m + 2, inf));

        // 跑一个 dijkstra 源点为 (0, 0)
        // 将 (0, 0) 的距离赋成 0
        dis[0][0] = 0;

        // 定义一个小根堆
        priority_queue<T, vector<T>, greater<T>> q;

        // 将源点入队
        q.push({0, 0, 0});

        while (!q.empty()) {
            // 取出堆顶
            auto t = q.top();
            q.pop();

            int d = t[0], u = t[1], v = t[2];

            // 建图
            for (int i = 0; i < 8; i++) {
                int nu = u + dx[i], nv = v + dy[i];

                // 如果合法的话
                if (0 <= nu && nu <= n + 1 && 0 <= nv && nv <= m + 1) {
                    int w = g[u][v] ^ g[nu][nv];

                    // 如果是上下左右四方向直接连边,然后是最短路的松弛操作
                    if (i < 4 && dis[nu][nv] > d + w) {
                        dis[nu][nv] = d + w;
                        q.push({d + w, nu, nv});
                    }

                    // 如果是斜着的方向,此时就特殊处理 0 -> 0 的边
                    // 同样是松弛操作
                    if (i >= 4 && !g[u][v] && !g[nu][nv] && dis[nu][nv] > d) {
                        dis[nu][nv] = d + w;
                        q.push({d + w, nu, nv});
                    }
                }
            }
        }

        // 将当前连通块全部染成 0
        function<void(int, int)> dfs = [&](int u, int v) {
            g[u][v] = 0;
            for (int i = 0; i < 4; i++) {
                int nu = u + dx[i], nv = v + dy[i];

                // 如果合法并且非 0,那就染色
                if (1 <= nu && nu <= n && 1 <= nv && nv <= m && g[nu][nv]) {
                    dfs(nu, nv);
                }
            }
        };

        // 定义答案
        int ans = 0;

        // 如果距离恰好为 1,并且没有被染色,那就染色成 0 
        for (int i = 1; i <= n; i++) {
            for (int j = 1; j <= m; j++) {
                if (dis[i][j] == 1 && g[i][j]) {
                    // 计数并且染色
                    ans += 1;
                    dfs(i, j);
                }
            }
        }

        cout << ans << "\n";
    }
    return 0;
}

G:子串简写

二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <bits/stdc++.h>
#define ll long long

using namespace std;
string s;
char c1, c2;
int k, n;
ll ans;

int main () {
    cin >> k >> s >> c1 >> c2;
    n = s.size ();
    vector<int> v;
    for (int i = 0; i < n; i++) {
        if (s[i] == c2)     v.push_back (i); 
    }
    v.push_back (n); 
    for (int i = 0; i < n; i++) {
        if (s[i] != c1)     continue;
        int cnt = v.end () - lower_bound (v.begin (), v.end (), max (k + i - 1, i + 1)) - 1;
        //cout << cnt << endl;
        ans += max (0, cnt);
    }
    cout << ans << endl;
}

//j-i+1 >= k && j-i >= 1
//j >= k+i-1 && j >= i+1

后缀和

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#include <bits/stdc++.h>
#define ll long long

using namespace std;
const int N = 5e5 + 5;
string s;
char c1, c2;
int k, n, sum[N]; //sum[i]:[i,n-1]中有多少个c2
ll ans;

int main () {
    cin >> k >> s >> c1 >> c2;
    n = s.size ();
    for (int i = n - 1; i >= 0; i--) {
        sum[i] = sum[i+1];
        if (s[i] == c2)     sum[i] ++;
    }
    for (int i = 0; i < n; i++) {
        if (s[i] == c1) {
            int r = min(n, max (k + i - 1, i + 1));
            ans += sum[r];
        }
    }
    cout << ans << endl;
}

H:整数删除

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
#include<iostream>
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<math.h>
#include<time.h>
#include<vector>
#include<queue>
#include <map>
#include <set>
using namespace std;
#define ll long long
const int NUM=1e6+10;

struct QNode
{
  ll pos,num;//pos为数组下标,num为存储数字
};
bool operator <(QNode const &a,QNode const &b)
{
  if(a.num==b.num)
    return a.pos > b.pos;
  return a.num > b.num;
}

priority_queue<QNode>q;
int n,k;
int vis[NUM] = {0};
struct Num
{
  ll front,back,num;
}num[NUM];//数组摸拟链表

inline void init()
{
  cin>>n>>k;
  for(int i = 1; i <= n; i++)
  {
    cin>>num[i].num;
    num[i].front = i-1;
    num[i].back = i + 1;
    q.push({i,num[i].num});//先把每一个元素都塞进队列
  }
}

inline void solve()
{
  int op = 0;
  while(!q.empty()&&op<k)//判断删除次数是否足够
  {
    QNode tmp = q.top();//取出头部元素
    q.pop();
    if(vis[tmp.pos]) continue;//如果取出的节点已经被删了,则出错
    if(tmp.num!=num[tmp.pos].num) continue;//存储的节点数值和当前节点不符合
    vis[tmp.pos]  = 1;//删除节点
    //修改前后节点
    num[num[tmp.pos].front].num += tmp.num;
    num[num[tmp.pos].back].num += tmp.num;
    num[num[tmp.pos].front].back = num[tmp.pos].back;
    num[num[tmp.pos].back].front = num[tmp.pos].front;
    if(num[tmp.pos].front > 0)//判断是否为头节点
      q.push({num[tmp.pos].front,num[num[tmp.pos].front].num});
    if(num[tmp.pos].back <= n)//判断是否是尾节点
      q.push({num[tmp.pos].back,num[num[tmp.pos].back].num});
    op++;
  }
  //输出答案部分
  int begin = 1;
  while(vis[begin])//找到头节点
    begin ++;
  //输出
  while(begin<= n)
  {
    cout<<num[begin].num<<" ";
    begin = num[begin].back;
    
  }
}


int main()
{
  int T = 1;
  while(T--)
  {
    init();
    solve();
  }
}

I:景区导游

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <bits/stdc++.h>
#define int long long
#define N 520203
using namespace std;
struct edge
{
    int dis, to, next;
} e[N << 1];
inline int read()
{
    int x = 0, w = 0;
    char ch = getchar();
    while (!isdigit(ch))
        w |= ch == '-', ch = getchar();
    while (isdigit(ch))
        x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
    return w ? -x : x;
}
int cnt, sum1[N], sum2[N], a[N], len[N], head[N], fa[N], dep[N], size[N], son[N], top[N];
inline void add(int u, int v, int d)
{
    e[++cnt].to = v;
    e[cnt].dis = d;
    e[cnt].next = head[u];
    head[u] = cnt;
}
void dfs(int u)
{
    size[u] = 1;
    for (int i = head[u]; i; i = e[i].next)
    {
        int v = e[i].to;
        if (v == fa[u])
            continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        len[v] = len[u] + e[i].dis;
        dfs(v);
        size[u] += size[v];
        if (size[v] > size[son[u]])
            son[u] = v;
    }
}
void dfs1(int u, int tp)
{
    top[u] = tp;
    if (son[u])
        dfs1(son[u], tp);
    for (int i = head[u]; i; i = e[i].next)
    {
        int v = e[i].to;
        if (v == fa[u] || v == son[u])
            continue;
        dfs1(v, v);
    }
}
int lca(int x, int y)
{
    while (top[x] != top[y])
    {
        if (dep[top[x]] < dep[top[y]])
            swap(x, y);
        x = fa[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}
int dis(int x, int y)
{
    if (!x || !y)
        return 0;
    return len[x] + len[y] - 2 * len[lca(x, y)];
}
signed main()
{
    int n = read(), k = read();
    for (int i = 1; i < n; ++i)
    {
        int u = read(), v = read(), d = read();
        add(u, v, d);
        add(v, u, d);
    }
    dfs(1);
    dfs1(1, 1);
    int ans = 0;
    for (int i = 1; i <= k; ++i)
        a[i] = read(), ans += dis(a[i - 1], a[i]);
    for (int i = 1; i <= k; ++i)
        printf("%lld ", ans - dis(a[i - 1], a[i]) - dis(a[i], a[i + 1]) + dis(a[i - 1], a[i + 1]));
    return 0;
}

J:砍树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int INF = 0x3f3f3f3f;

struct RMQLCA {
    int MAXJ;
    int n;
    vector<vector<int>> edges;
    vector<int> depth;
    vector<vector<int>> fa;

    RMQLCA() {}
    RMQLCA(int _n, const vector<vector<int>>& _edges) 
        :n(_n), edges(_edges), MAXJ(log2(_n) + 1) {
        depth = vector<int>(n + 1, INF);
        fa = vector<vector<int>>(n + 1, vector<int>(MAXJ + 1));
    }

    void bfs(int root) {  // 预处理depth[], fa[][]: 根节点为root
        depth[0] = 0, depth[root] = 1;

        queue<int> que;
        que.push(root);

        while (que.size()) {
            auto u = que.front(); que.pop();
            for (auto v : edges[u]) {
                if (depth[v] > depth[u] + 1) {
                    depth[v] = depth[u] + 1;
                    que.push(v);

                    fa[v][0] = u;
                    for (int k = 1; k <= MAXJ; k++)
                        fa[v][k] = fa[fa[v][k - 1]][k - 1];
                }
            }
        }
    }

    int lca(int u, int v) {
        if (depth[u] < depth[v]) swap(u, v);  // 保证节点u深度大

        // 将节点u与节点v跳到同一深度
        for (int k = MAXJ; k >= 0; k--)
            if (depth[fa[u][k]] >= depth[v]) u = fa[u][k];

        if (u == v) return u;  // u与v原本在一条链上

        // u与v同时跳到LCA的下一层
        for (int k = MAXJ; k >= 0; k--) {
            if (fa[u][k] != fa[v][k])
                u = fa[u][k], v = fa[v][k];
        }
        return fa[u][0];  // 节点u向上跳1步到LCA
    }
};

struct Solver {
    int n;
    vector<vector<int>> edges;
    vector<int> w;  // 节点的点权
    int root = 1;  // 任取
    RMQLCA sol;
    
    Solver(int _n, const vector<vector<int>>& _edges) :n(_n), edges(_edges) {
        w.resize(n + 1);
        sol = RMQLCA(n, edges);
        sol.bfs(root);
    }

    void modify(int u, int v, int val) {  // 给节点u到节点v的简单路径上的点权 + val
        int lca = sol.lca(u, v);
        w[u] += val, w[v] += val, w[lca] -= 2 * val;
    }

    void dfs(int u, int pre) {  // 对差分数组求前缀和得到原数组: 当前节点、前驱节点
        for (auto v : edges[u]) {
            if (v != pre) {
                dfs(v, u);
                w[u] += w[v];
            }
        }
    }

    int get(int u, int v) {  // 返回节点u和节点v中深度较大的节点的权值
        return sol.depth[u] > sol.depth[v] ? w[u] : w[v];
    }
};

void solve() {
    int n, m; cin >> n >> m;
    vector<tuple<int, int, int>> eds;  // id, u, v
    vector<vector<int>> edges(n + 1);
    for (int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;

        eds.push_back({ i, u, v });
        edges[u].push_back(v), edges[v].push_back(u);
    }

    Solver solver(n, edges);
    for (int i = 0; i < m; i++) {
        int u, v; cin >> u >> v;
        solver.modify(u, v, 1);
    }

    solver.dfs(solver.root, 0);
    for (int i = n - 2; i >= 0; i--) {  // 倒序枚举边
        // C++17支持结构化绑定
        auto [id, u, v] = eds[i];

        // C++11(蓝桥杯)写法
        // auto it = eds[i];
        // int id = get<0>(it), u = get<1>(it), v = get<2>(it);

        if (solver.get(u, v) == m) {  // 边被覆盖m次
            cout << id << endl;
            return;
        }
    }
    cout << -1 << endl;
}

int main() {
    solve();
    return 0;
}
This post is licensed under CC BY 4.0 by the author.

鱼书 —— 《深度学习入门》读书笔记

-

Comments powered by Disqus.