武汉工程大学第三届ACM程序设计选拔赛(正式赛)题解

题目链接:https://ac.nowcoder.com/acm/contest/16172

感谢牛客网为此次比赛提供在线评测环境

A. 疯狂动物城

此题改至ACWing 240 食物链
此题知识点:带权并查集

#include 

using namespace std;

const int N = 5e4 + 5, mod = 4;
int n, m, cnt;
int d[N], pre[N];

int find(int x)
{
     
    if (x != pre[x]) {
     
        int root = find(pre[x]);
        d[x] = (d[x] + d[pre[x]]) % mod;
        pre[x] = root;
    }
    return pre[x];
}

int main()
{
     
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; i++)
        pre[i] = i;
    while (m--) {
     
        int t, x, y;
        scanf("%d %d %d", &t, &x, &y);
        if (x < 1 || x > n || y < 1 || y > n || t != 1 && x == y) {
     
            cnt++;
            continue;
        }
        int k = t - 1;
        int px = find(x), py = find(y);
        if (px == py && ((d[x] - d[y]) % mod + mod) % mod != k) {
     
            cnt++;
            continue;
        }
        if (px != py) {
     
            pre[px] = py;
            d[px] = ((k - d[x] + d[y]) % mod + mod) % mod;
        }
    }
    printf("%d\n", cnt);
    return 0;
}

方法二:扩展域并查集
作者不是很懂,只能把大佬的代码搬过来
https://ac.nowcoder.com/acm/contest/view-submission?submissionId=47751034

B. 密室逃脱

此题比较简单的做法是直接Dijkstra,如果传送门的两端都不是陷阱,并且不是相邻的点,那就将两端的边权设置为3。如果不习惯从下标0开始计数,可以先将所有的点的横纵坐标都+1,最后如果有输出点再将横纵坐标都-1

#include 
#include 
#include 
#include 
#include 

using namespace std;

typedef pair<int, int> PII;
typedef pair<int, PII> PIII;
const int N = 310;
const int dx[4] = {
     -1, 1, 0, 0}, dy[4] = {
     0, 0, -1, 1};
int n, m, q;
int stx, sty, edx, edy;
char mp[N][N];
vector<PII> edges[N][N];
int dis[N][N];
bool vis[N][N];

int Dijkstra()
{
     
    memset(dis, 0x3f, sizeof dis);
    priority_queue<PIII, vector<PIII>, greater<PIII>> heap;
    heap.push({
     0, {
     stx, sty}});
    dis[stx][sty] = 0;
    while (!heap.empty()) {
     
        int distance = heap.top().first;
        int a = heap.top().second.first, b = heap.top().second.second;
        heap.pop();
        if (vis[a][b])
            continue;
        vis[a][b] = true;
        if (a == edx && b == edy)
            return distance;
        for (int i = 0; i < 4; i++) {
     
            int x = a + dx[i], y = b + dy[i];
            if (!(x >= 1 && x <= n && y >= 1 && y <= m && mp[x][y] != '#' && dis[x][y] > distance + 1))
                continue;
            dis[x][y] = distance + 1;
            heap.push({
     dis[x][y], {
     x, y}});
        }
        for (auto t : edges[a][b]) {
     
            int x = t.first, y = t.second;
            if (!(mp[x][y] != '#' && dis[x][y] > distance + 3))
                continue;
            dis[x][y] = distance + 3;
            heap.push({
     dis[x][y], {
     x, y}});
        }
    }
    return -1;
}

int main()
{
     
    cin >> n >> m >> q;
    for (int i = 1; i <= n; i++) {
     
        cin >> mp[i] + 1;
        for (int j = 1; j <= m; j++)
            if (mp[i][j] == 'S')
                stx = i, sty = j;
            else if (mp[i][j] == 'T')
                edx = i, edy = j;
    }
    while (q--) {
     
        int a, b, c, d;
        cin >> a >> b >> c >> d;
        a++, b++, c++, d++;
        edges[a][b].push_back({
     c, d});
        edges[c][d].push_back({
     a, b});
    }
    cout << Dijkstra() << endl;
    return 0;
}

C. 露营?料理!

此题的纸质版的地方由于工作人员疏忽, ∑ \sum 的下标 i = 0 i=0 i=0写成了 i = 1 i=1 i=1,并且没有交代 w 0 w_0 w0恒等于0,对选手造成了干扰,在此感到非常抱歉。
此题知识点:前缀和+二分+双关键字排序。
前缀和求完之后需要把 w 0 w_0 w0的值放入到前缀和中,该元素的第二个关键字也要设置为0。
二分不是只有一种写法,在已经升序排列的数组 s s s中,求 s i ≤ x s_i \le x six s i < x , s i > x s_i < x,s_i > x si<xsi>x s i ≥ x s_i \ge x six 的下标 i i i,这四种二分写法是有区别的。进阶指南上面说只有10%的程序员会写二分

≥ x \ge x x的下界

    int l = 1, r = n + 1;
    a[n + 1] = 0x3f3f3f3f;
    while (l < r) {
     
        int mid = (l + r) >> 1;
        if (a[mid] >= x) {
     
            r = mid;
        } else {
     
            l = mid + 1;
        }
    }

    if (l == n + 1) {
     
        puts("Not Found");
    } else {
     
        printf("%d\n", a[l]);
    }

> x > x >x的下界

    int l = 1, r = n + 1;
    a[n + 1] = 0x3f3f3f3f;
    while (l < r) {
     
        int mid = (l + r) >> 1;
        if (a[mid] > x) {
     
            r = mid;
        } else {
     
            l = mid + 1;
        }
    }

    if (l == n + 1) {
     
        puts("Not Found");
    } else {
     
        printf("%d\n", a[l]);
    }

≤ x \le x x的上界

    int l = 0, r = n;
    a[0] = 0xc0c0c0c0;
    while (l < r) {
     
        int mid = (l + r + 1) >> 1;
        if (a[mid] <= x) {
     
            l = mid;
        } else {
     
            r = mid - 1;
        }
    }

    if (l == 0) {
     
        puts("Not Found");
    } else {
     
        printf("%d\n", a[l]);
    }

< x < x <x的上界

    int l = 0, r = n;
    a[0] = 0xc0c0c0c0;
    while (l < r) {
     
        int mid = (l + r + 1) >> 1;
        if (a[mid] < x) {
     
            l = mid;
        } else {
     
            r = mid - 1;
        }
    }

    if (l == 0) {
     
        puts("Not Found");
    } else {
     
        printf("%d\n", a[l]);
    }

此题标程

# include 
# include 
# include 

typedef long long ll;

const int N = 1e5 + 5;

int n;

struct Sum {
     
    int ans;
    int id;

    const bool operator < (const Sum& rhs) const {
     
        return ans < rhs.ans || (ans == rhs.ans && id < rhs.id);
    }
};

int w[N];
Sum sum[N];

int main() {
     
    int m;
    std::cin >> n >> m;

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &w[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        sum[i].ans = sum[i-1].ans + w[i];
        sum[i].id = i;
    }

    sum[n + 1].ans = sum[n + 1].id = 0;
    std::sort(sum + 1, sum + 2 + n);

    sum[0].ans = sum[0].id = 0xc0c0c0c0;

    while (m--) {
     
        int k;
        scanf("%d", &k);      

        int l = 0; 
        int r = n + 1;

        while (l < r) {
     
            int mid = (l + r + 1) >> 1;
            if (sum[mid].ans <= k) {
     
                l = mid;
            } else {
     
                r = mid - 1;
            }
        }

        if (l == 0) {
     
            puts("-1");
        } else {
     
            printf("%d\n", sum[l].id);
        }
    }

    return 0;
}

STL:
lower_bound(iter.begin(), iter.end(), x)寻找 ∗ i t ≥ x *it\ge x itx的下界,如果返回 i t e r . e n d ( ) iter.end() iter.end()说明无解
upper_bound(iter.begin(), iter.end(), x)寻找 ∗ i t > x *it> x it>x的下界,如果返回 i t e r . e n d ( ) iter.end() iter.end()说明无解
lower_bound(iter.begin(), iter.end(), x) - 1寻找 ∗ i t < x *it< x it<x的上界,如果返回 i t e r . b e g i n ( ) − 1 iter.begin()-1 iter.begin()1说明无解
upper_bound(iter.begin(), iter.end(), x) - 1寻找 ∗ i t ≤ x *it\le x itx的上界,如果返回 i t e r . b e g i n ( ) − 1 iter.begin()-1 iter.begin()1说明无解

# include 
# include 
# include 

typedef long long ll;

const int N = 1e5 + 5;

int n;

struct Sum {
     
    ll ans;
    int id;

    const bool operator < (const Sum& rhs) const {
     
        return ans < rhs.ans || (ans == rhs.ans && id < rhs.id);
    }
};

ll w[N];
Sum sum[N];

int main() {
     
    int m;
    std::cin >> n >> m;

    for (int i = 1; i <= n; i++) {
     
        scanf("%lld", &w[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        sum[i].ans = sum[i-1].ans + w[i];
        sum[i].id = i;
    }

    sum[n + 1].ans = 0;
    sum[n + 1].id = 0;
    std::sort(sum + 1, sum + 2 + n);

    while (m--) {
     
        int k;
        scanf("%d", &k);

        Sum t;
        t.ans = k;
        t.id = 0x3f3f3f3f;

        auto it = std::upper_bound(sum + 1, sum + 2 + n, t) - 1;

        if (it == sum) {
     
            puts("-1");
            continue;
        }

        printf("%d\n", it->id);
    }

    return 0;
}

D. 命运之轮

此题肯定不能暴力。关于时间复杂度与超时,可以查看去年新生赛I题,密码acmwitedu2020
标程为线段树的RMQ
出题人出这题的时候没有发现这就是前年新生赛K题的简化版。等到验题人验的时候发现出新题太麻烦,所以就用了这题。再次论往届题目的重要性。

# include 
# include 
# include 
# include 
# include 

# define l(x) tree[x].l
# define r(x) tree[x].r
# define res(x) tree[x].res

typedef long long ll;

const int N = 1e5 + 5;

int n;

int a[N];

struct SegTree {
     
    int l;
    int r;
    int res;
};

SegTree tree[N << 2];

void pushup(int p) {
     
    res(p) = std::max(res(p * 2), res(p * 2 + 1));
}

void build(int p, int l, int r) {
     
    l(p) = l;
    r(p) = r;
    if (l == r) {
     
        res(p) = a[l];
        return;
    }

    int mid = l + (r - l) / 2;
    build(p * 2, l, mid);
    build(p * 2 + 1, mid + 1, r);
    pushup(p);
}

void update(int p, int x, int v) {
     
    if (l(p) == r(p)) {
     
        res(p) = v;
        return;
    }

    int mid = l(p) + (r(p) - l(p)) / 2;
    if (x <= mid) {
     
        update(p * 2, x, v);
    }

    if (x > mid) {
     
        update(p * 2 + 1, x, v);
    }

    pushup(p);
}

int get(int p, int v) {
     
    if (l(p) == r(p)) {
     
        return l(p);
    }

    return res(p * 2) > v ? get(p * 2, v) : get(p * 2 + 1, v);
}

int flag = 0;

int query2(int p, int l, int r, int max) {
      
    if (l <= l(p) && r(p) <= r) {
     
        int ans = res(p);

        if (ans > max && !flag) {
     
            flag = 1;
            return get(p, max);
        }

        return -1;
    }

    int mid = l(p) + (r(p) - l(p)) / 2;
    int t = -1;
    if (l <= mid && !flag) {
     
        t = query2(p * 2, l, r, max);
    }

    if (r > mid && !flag) {
     
        t = query2(p * 2 + 1, l, r, max);
    }

    return t;
}

int main() {
     
    int m;
    std::cin >> n >> m;

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &a[i]);
    }

    build(1, 1, n);

    for (int i = 1; i <= m; i++) {
     
        int p;
        scanf("%d", &p);

        flag = 0;

    //    if (p == n) {   
    //        puts("-1");
    //    } else {
     
            int pos = query2(1, p + 1, n, a[p]);

            printf("%d\n", pos);
    //    }
         
        if (i != m) {
     
            int t, v;
            scanf("%d%d", &t, &v);

            update(1, t, v);
            a[t] = v;
        } 
    }

	return 0;
}

E. 找规律

签到题。太多种写法了,如果不会请自行看答案正确的同学的代码。
其实是作者认为自己的标程太复杂

F. 小布丁的电影

KMP+逆波兰
好像不用KMP直接暴力匹配也可以过
注意%0的时候也是输出"Error!"

#include 
#include 
#include 
#include 
#include 
using namespace std;
typedef long long ll;
const ll N=1e4+10;
const ll mod=1e9+7;
stack<ll> num;
stack<char> op;
int ne[N];
char s[N], p[N],temp[N];
int n;
int m;
bool flag=true;
ll kmp(char p[],int n){
     
	
	for (int i = 2, j = 0; i <= n; i ++ )
    {
     
        while (j && p[i] != p[j + 1]) j = ne[j];
        if (p[i] == p[j + 1]) j ++ ;
        ne[i] = j;
    }

    for (int i = 1, j = 0; i <= m; i ++ )
    {
     
        while (j && s[i] != p[j + 1]) j = ne[j];
        if (s[i] == p[j + 1]) j ++ ;
        if (j == n)
		{
     
            return i-n+1;
        }
    }
    return 0;
} 

void eval()
{
     
	if(num.empty()){
     
		cout<<"Error!";
		flag=false;
		return;
	}
    ll b = num.top(); num.pop();
    if(num.empty()){
     
		cout<<"Error!";
		flag=false;
		return;
	}
    ll a = num.top(); num.pop();
    if(op.empty()){
     
		cout<<"Error!";
		flag=false;
		return;
	}
    char c = op.top(); op.pop();
    ll x;
    if (c == '+') x = ((a + b)%mod+mod)%mod;
    else if (c == '-') x = ((a - b)%mod+mod)%mod;
    else if (c == '*'){
     
    	x = ((a %mod* b%mod)%mod+mod)%mod;
	}
    else {
     
    	if(b==0){
     
    		cout<<"Error!";
			flag=false;
			return;
		}else{
     
			if(b<0) b=-b;
			x = (a % b + b ) % b;	
		}
	}
	//cout<
    num.push(x);
}

int main()
{
     
	
	cin>>m>>s+1;
    unordered_map<char, int> pr{
     {
     '+', 1}, {
     '-', 1}, {
     '*', 2}, {
     '%', 2}};
    string str;
    cin >> str;
    int len=str.size();
    for (int i = 0; i < len; i ++ )
    {
     
        char c = str[i];
        if (c>='a'&&c<='z')
        {
     
            ll x = 0, j = i;
            int ans=1;
            while (j < str.size() && str[j]>='a'&& str[j]<='z')
                temp[ans++]=str[j++];
            x = kmp(temp , ans-1);
            i = j - 1;
            num.push(x);
        }
        else if (c == '(') op.push(c);
        else if (c == ')')
        {
     
            while (op.top() != '(') eval();
            op.pop();
        }
        else
        {
     
            while (op.size() && op.top() != '(' && pr[op.top()] >= pr[c])
			{
     
    			eval();
    			if(!flag) return 0;
			}
            op.push(c);
        }
        if(!flag) return 0;
    }
    while (op.size()) {
     
    	eval();
    	if(!flag) return 0;
	}
    if(flag) cout << (num.top() % mod + mod) % mod;
    return 0;
}

G. 厨房

这题真心不难,就是题目有点长而已。

# include 
# include 
# include 
# include 
# include 

typedef long long ll;

int n, m, k;

const int N = 1e2 + 5;

int a[N];
int b[N];
int c[N];
int d[N];
int v[N];

struct Custom {
     
    int time;
    int id;

    bool operator < (Custom rhs) {
     
        if (time == rhs.time) {
     
            return time < rhs.time;
        } else {
     
            return id < rhs.id;
        }
    }
};

Custom custom[N];

void solve() {
     
    scanf("%d%d%d", &n, &m, &k);

    for (int i = 1; i <= N - 2; i++) {
     
        custom[i].id = i;
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &custom[i].time);
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &v[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &a[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &b[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &c[i]);
    }

    for (int i = 1; i <= n; i++) {
     
        scanf("%d", &d[i]);
    }

    ll sum = 0;
    int inda = 1;
    int indb = 1;
    int indc = 1;
    int indd = 1;

    //std::sort(custom + 1, custom + n + 1);

    for (int i = 1; i <= n; i++) {
     
        if (inda > n || indb > n) {
     
            sum -= k / 2;
            continue;
        }

        if (v[i] == 1) {
     
            if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && c[indc] <= custom[i].time + m) {
     
                sum += k;
                inda++;
                indb++;
                indc++;
            } else {
     
                sum -= k / 2;
            }
        } else if (v[i] == 2) {
     
            if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && d[indd] <= custom[i].time + m) {
     
                sum += k;
                inda++;
                indb++;
                indd++;
            } else {
     
                sum -= k / 2;
            }
        } else {
     
            if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && c[indc] <= custom[i].time + m && d[indd] <= custom[i].time + m) {
     
                sum += k;
                inda++;
                indb++;
                indc++;
                indd++;
            } else {
     
                sum -= k / 2;
            }
        }
    }

    printf("%lld\n", sum);
}

int main() {
     
    int T;
    std::cin >> T;

    while (T--) {
     
        solve();
    }

	return 0;
}

H. 量子计算机

此题改编自Codeforces Round #674 (Div. 3)的F
本来一开始想用英文出这题的,但是后来感觉这题有点难所以把英文砍了
此题涉及的物理知识可能存在问题,如果感兴趣请查阅专业资料

d p [ i ] [ j ] dp[i][j] dp[i][j] i i i从下标 1 1 1开始计数, 1 ≤ j ≤ ∣ t ∣ 1\le j \le |t| 1jt)表示前 i i i个字符串中,匹配了 t t t的前 j j j位。状态转移方程为( t t t从下标 1 1 1开始计数):

d p [ 0 ] [ 0 ] = 1 dp[0][0]=1 dp[0][0]=1

每一次循环, d p [ i ] dp[i] dp[i]中的每一个元素值等于 d p [ i − 1 ] dp[i-1] dp[i1]中每一个元素的值

d p [ i ] [ 1 ] = ( d p [ i − 1 ] [ 1 ] + ( s [ i ] = = t [ 1 ] ) ) % M o d dp[i][1] = (dp[i-1][1] + (s[i]==t[1])) \% Mod dp[i][1]=(dp[i1][1]+(s[i]==t[1]))%Mod
d p [ i ] [ 2 ] = ( d p [ i − 1 ] [ 2 ] + ( s [ i ] = = t [ 2 ] ) ∗ d p [ i − 1 ] [ 1 ] ) % M o d dp[i][2] = (dp[i-1][2] + (s[i]==t[2]) * dp[i-1][1]) \% Mod dp[i][2]=(dp[i1][2]+(s[i]==t[2])dp[i1][1])%Mod
d p [ i ] [ 3 ] = ( d p [ i − 1 ] [ 3 ] + ( s [ i ] = = t [ 3 ] ) ∗ d p [ i − 1 ] [ 2 ] ) % M o d dp[i][3] = (dp[i-1][3] + (s[i]==t[3]) * dp[i-1][2]) \% Mod dp[i][3]=(dp[i1][3]+(s[i]==t[3])dp[i1][2])%Mod

d p [ i ] [ m ] = ( d p [ i − 1 ] [ 3 ] + ( s [ i ] = = t [ m ] ) ∗ d p [ i − 1 ] [ m − 1 ] ) % M o d dp[i][m] = (dp[i-1][3] + (s[i]==t[m]) * dp[i-1][m-1]) \% Mod dp[i][m]=(dp[i1][3]+(s[i]==t[m])dp[i1][m1])%Mod

特判 ‘?’
i f   ( s [ i ] = = ′ ? ′ ) if\ (s[i]=='?') if (s[i]==?)
d p [ i ] [ 0 ] = 2 ∗ d p [ i − 1 ] [ 0 ] % M o d dp[i][0] = 2 * dp[i - 1][0] \% Mod dp[i][0]=2dp[i1][0]%Mod
d p [ i ] [ 1 ] = ( 2 ∗ d p [ i − 1 ] [ 1 ] + d p [ i − 1 ] [ 0 ] ) % M o d dp[i][1] = (2 * dp[i-1][1] + dp[i-1][0]) \% Mod dp[i][1]=(2dp[i1][1]+dp[i1][0])%Mod
d p [ i ] [ 2 ] = ( 2 ∗ d p [ i − 1 ] [ 2 ] + d p [ i − 1 ] [ 1 ] ) % M o d dp[i][2] = (2 * dp[i-1][2] + dp[i-1][1]) \% Mod dp[i][2]=(2dp[i1][2]+dp[i1][1])%Mod
d p [ i ] [ 3 ] = ( 2 ∗ d p [ i − 1 ] [ 3 ] + d p [ i − 1 ] [ 2 ] ) % M o d dp[i][3] = (2 * dp[i-1][3] + dp[i-1][2]) \% Mod dp[i][3]=(2dp[i1][3]+dp[i1][2])%Mod

d p [ i ] [ m ] = ( 2 ∗ d p [ i − 1 ] [ m ] + d p [ i − 1 ] [ m − 1 ] ) % M o d dp[i][m] = (2 * dp[i-1][m] + dp[i-1][m-1]) \% Mod dp[i][m]=(2dp[i1][m]+dp[i1][m1])%Mod

d p [ n ] [ m ] dp[n][m] dp[n][m]( n n n s s s的长度, m m m t t t的长度)。

#include 
#include 
#include 
#include 

const int mod = 998244353;

const int N = 1e5 + 5;
char s[N];
char t[N];

int n;
int m;

std::vector<std::vector<int> > dp(N, std::vector<int>(8));

int main() {
     
    int T;
    std::cin >> T;

    while (T--) {
     
        scanf("%s%s", s + 1, t + 1);
        n = strlen(s + 1);
        m = strlen(t + 1);

        //auto <--> std::vector >::iterator
        for (auto it = dp.begin(); it != dp.begin() + n + 2; it++) {
     
            std::fill(it->begin(), it->end(), 0);
        }

        dp[0][0] = 1;

        for (int i = 1; i <= n; i++)
        {
     
            dp[i] = dp[i - 1]; 

            if (s[i] == '?')
            {
     
                dp[i][0] = 2ll * dp[i - 1][0] % mod;
                for (int j = 1;  j <= m; j++) {
     
                    dp[i][j] = (2ll * dp[i - 1][j] + dp[i - 1][j-1]) % mod;
                }
            }
            else
            {
     
                for (int j = 1; j <= m; j++) {
     
                    dp[i][j] = (1ll * dp[i - 1][j] + (s[i] == t[j]) * dp[i - 1][j-1]) % mod;
                }
            }
        }

        printf("%d\n", dp[n][m]);
    }

    return 0;
}

也可以用滚动数组降维

# include 
# include 
# include 

const int mod = 998244353;

const int N = 1e5 + 5;
char s[N];
char t[N];

int dp[10];
int temp[10];

int n;
int m;

int main() {
     
    int T;
    std::cin >> T;

    while (T--) {
     
        scanf("%s%s", s + 1, t + 1);
        n = strlen(s + 1);
        m = strlen(t + 1);

        memset(dp, 0, sizeof(int) * (m + 2));
        dp[0] = 1;

        for (int i = 1; i <= n; i++) {
     
            if (s[i] == '?') {
     
                for (int j = 0; j < m; j++) {
     
                    temp[j] = dp[j];
                }

                dp[0] = 2ll * dp[0] % mod;
                for (int j = 1; j <= m; j++) {
     
                    dp[j] = (2ll * dp[j] + temp[j-1]) % mod;
                }
            } else {
     
                for (int j = 0; j < m; j++) {
     
                    temp[j] = dp[j];
                }

                for (int j = 1; j <= m; j++) {
     
                    dp[j] = (1ll * dp[j] + (s[i] == t[j]) * temp[j-1]) % mod;
                }
            }
        }
        
        printf("%d\n", dp[m]);
    }
	
	return 0;
}

不要直接memset/清零整个二维dp数组,这样会超时。

时间复杂度 O ( T n m ) O(Tnm) O(Tnm)

I. 我们是冠军

签到题。就是数 2 2 2的多少次方。
老生长谈的问题,当输入的数据过多的时候输入输出的卡常不能忽略。

#include 

using namespace std;

typedef long long LL;

int main()
{
     
    int T;
    scanf("%d", &T);
    while (T--) {
     
        LL n;
        scanf("%lld", &n);
        int cnt = 0;
        while (n > 4) {
     
            cnt++;
            n /= 2;
        }
        printf("%d %d %d\n", cnt + 3, cnt + 4, 1);
    }
    return 0;
}
import java.util.*;
import java.io.*;

public class Main {
     
    public static void solve() {
     
        int T = nextInt();

        while (T-- > 0) {
     
            long n = nextLong();

            int cnt = 3;
            while (n > 4) {
     
                cnt++;
                n >>= 1;
            }
            
            int cnt2 = cnt + 1;
            out.println(cnt + " " + cnt2 + " 1");
        }
    }  

    public static void main(String[] args) {
     
            reader = new BufferedReader(new InputStreamReader(System.in));
            tokenizer = null;
            out = new PrintWriter(System.out);
            solve();
            out.close();
    }
  
    static BufferedReader reader;
    static StringTokenizer tokenizer;
    static PrintWriter out;
  
    static int nextInt(){
     
        return Integer.parseInt(next());
    }
  
    static long nextLong(){
     
        return Long.parseLong(next());
    }
  
    static double nextDouble(){
     
        return Double.parseDouble(next());
    }
  
    static String next(){
     
        while (tokenizer == null || !tokenizer.hasMoreTokens()) {
     
            try {
     
                tokenizer = new StringTokenizer(reader.readLine());
            } catch (IOException e) {
     
                throw new RuntimeException(e);
            }
        }
        return tokenizer.nextToken();
    }
}

更多写法请参考去年新生赛预选赛的A

J. 又是斐波那契

这题真的是数论的基础题目,公式全都摆出来了,按照公式打就行

# include 
# include 
# include 

typedef long long ll;

const int mod = 1e9 + 9;

const int Sq5 = 383008016;    // sqrt(5)
const int A = 691504013;	  // (1 + sqrt(5)) / 2
const int B = 308495997;	  // (1 - sqrt(5)) / 2
const int C = 276601605;	  // 1 / sqrt(5)
const int Inv2 = 500000005;   // 1 / 2
const int Inv10 = 100000001;  // 1 / 10

int qpow(int a, ll b) {
     
    int ans = 1 % mod;

    while (b) {
     
        if (b & 1) {
     
            ans = 1ll * ans * a % mod;
        } 

        a = 1ll * a * a % mod;
        b >>= 1;
    }

    return ans;
}

int main() {
     
    ll n;
    int a, b, c;

    while (~scanf("%lld%d%d%d", &n, &a, &b, &c)) {
     
        int qan = qpow(A, n);
        int qbn = qpow(B, n);
        int a1 = 1ll * qan * C % mod;
        int a2 = 1ll * qbn * C % mod;
        int ans1 = (1ll * a1 - a2 + mod) % mod * b % mod;

        int a3 = 1ll * (Sq5 - 1) * qan % mod * Inv2 % mod * C % mod;
        int a4 = 1ll * (Sq5 + 1) * qbn % mod * Inv2 % mod * C % mod;

        int ans2 = (1ll * a3 + a4) % mod * a % mod;

        int a5 = 1ll * (5 + Sq5) * Inv10 % mod * qan % mod;
        int a6 = 1ll * (5 - Sq5) * Inv10 % mod * qbn % mod;
        int ans3 = (1ll * a5 + a6 - 1 + mod) % mod * c % mod;

        int ans = (1ll * ans1 + ans2 + ans3 + mod) % mod;
        printf("%d\n", ans);
    } 

    return 0;
}

乘法的时候注意先将数据变成long long型,不然会超出int的范围。
在模数下 ( ϕ 1 ) n (\phi_1)^n (ϕ1)n不一定大于 ( ϕ 2 ) n (\phi_2)^n (ϕ2)n,模数下用减法需要先加上mod在取模。例如:(a-b+mod)%mod。全场唯一开了此题的队伍就是因为这里没有注意所以没有过。

还可以用广义欧拉降幂将 n n n模一个 ϕ ( m o d ) \phi(mod) ϕ(mod)。这里的 ϕ \phi ϕ为欧拉函数。

# include 
# include 
# include 

typedef long long ll;

const double phi1 = (1 + sqrt(5)) * 0.5;
const double phi2 = (1 - sqrt(5)) * 0.5;

const int mod = 1e9 + 9;

const int Sq5 = 383008016;    // sqrt(5)
const int A = 691504013;	  // (1 + sqrt(5) / 2
const int B = 308495997;	  // (1 - sqrt(5) / 2
const int C = 276601605;	  // 1 / sqrt(5)
const int Inv2 = 500000005;   // 1 / 2
const int Inv10 = 100000001;  // 1 / 10

int qpow(int a, ll b) {
     
    int ans = 1 % mod;

    while (b) {
     
        if (b & 1) {
     
            ans = 1ll * ans * a % mod;
        } 

        a = 1ll * a * a % mod;
        b >>= 1;
    }

    return ans;
}

int main() {
     
    ll n;
    int a, b, c;

    while (~scanf("%lld%d%d%d", &n, &a, &b, &c)) {
     
        n = n % (mod - 1);

        int qan = qpow(A, n);
        int qbn = qpow(B, n);
        int a1 = 1ll * qan * C % mod;
        int a2 = 1ll * qbn * C % mod;
        int ans1 = (1ll * a1 - a2 + mod) % mod * b % mod;

        int a3 = 1ll * (Sq5 - 1) * qan % mod * Inv2 % mod * C % mod;
        int a4 = 1ll * (Sq5 + 1) * qbn % mod * Inv2 % mod * C % mod;

        int ans2 = (1ll * a3 + a4) % mod * a % mod;

        int a5 = 1ll * (5 + Sq5) * Inv10 % mod * qan % mod;
        int a6 = 1ll * (5 - Sq5) * Inv10 % mod * qbn % mod;
        int ans3 = (1ll * a5 + a6 - 1 + mod) % mod * c % mod;

        int ans = (1ll * ans1 + ans2 + ans3 + mod) % mod;
        printf("%d\n", ans);
    } 

    return 0;
}

我说的是不要轻易用新生赛的代码,没有说矩阵快速幂就不行,题解中的第一种方法确实会超时,但是题解中的第二种方法改改还是可以过的

# include 
# include 
# include 
# include 

typedef long long ll;

const ll mod = 1e9 + 9;;

struct Node {
     
	ll m[3][3];

	Node operator * (const Node& rhs) {
     
		Node t = {
     0};
		for (int i = 0; i < 3; i++) {
     
			for (int j = 0; j < 3; j++) {
     
				for (int k = 0; k < 3; k++) {
     
					t.m[i][j] += m[i][k] * rhs.m[k][j] % mod;

					if (t.m[i][j] >= mod) {
     
						t.m[i][j] -= mod;
					}
				}
			}
		}

		return t;
	}

};

const Node I = {
     1, 0, 0, 0, 1, 0, 0, 0, 1};

Node qpow(Node x, ll p) {
     
	Node ans = I;
	while (p) {
     
		if (p & 1) {
     
			ans = ans * x;
		}

		x = x * x;
		p >>= 1;
	}

	return ans;
}

int main() {
     
	int a = 1;
	int b = 1;
	int f0;
	int f1;
    int c;

    ll n;
    while (~scanf("%lld%d%d%d", &n, &f0, &f1, &c)) {
     
        Node T = {
     a, b, 1, 1, 0, 0, 0, 0, 1};
        Node ans = qpow(T, n);

        int res = (ans.m[1][0] * f1 % mod + ans.m[1][1] * f0 % mod + ans.m[1][2] * c % mod) % mod;

        printf("%d\n", res);
    }

	return 0;
}
# include 
# include 
# include 

typedef long long ll;

const ll mod = 1e9 + 9;;

ll qpow(ll x, ll p, int Mod = mod) {
     
	ll ans = 1 % Mod;
	x %= Mod;
	while (p) {
     
		if (p & 1) {
     
			ans = ans * x % Mod;
		}

		x = x * x % Mod;
		p >>= 1;
	}

	return ans;
}

ll lcm(ll a, ll b) {
     
    return a / std::__gcd(a, b) * b;
}
 
ll pFac[105][2];
int getFactors(ll n) {
     
    int pCnt = 0;
    for (ll i = 2; i * i <= n; ++i) {
     
        if (n % i) {
     
            continue;
        }
 
        pFac[pCnt][0] = i;
        pFac[pCnt][1] = 0;
        while (n % i == 0) {
     
            n /= i;
            pFac[pCnt][1]++;
        }
 
        pCnt++;
    }
 
    if (n > 1) {
     
        pFac[pCnt][0] = n;
        pFac[pCnt++][1] = 1;
    }
 
    return pCnt;
}

int Legendre(ll a, ll p) {
     
    if (qpow(a, (p - 1) >> 1, p) == 1) {
     
        return 1;
    }
 
    return -1;
}
 
ll find_loop(ll n, ll a = 1, ll b = 1) {
     
    int cnt = getFactors(n);
    ll c = a * a + b * 4;
    ll ans = 1, record;
 
    for (int i = 0; i < cnt; ++i) {
     
        if (pFac[i][0] == 2) {
     
            record = 3 * 2 * 2;
        } else if (c % pFac[i][0] == 0) {
     
            record = pFac[i][0] * (pFac[i][0] - 1);
        } else if (Legendre(c, pFac[i][0]) == 1) {
     
            record = pFac[i][0] - 1;
        } else {
     
            record = (pFac[i][0] - 1) * (pFac[i][0] + 1);
        }
 
        for (int j = 1; j < pFac[i][1]; ++j) {
     
            record *= pFac[i][0];
        }
 
        ans = lcm(ans, record);
    }
 
    return ans;
}

struct Node {
     
	ll m[3][3];

	Node operator * (const Node& rhs) {
     
		Node t = {
     0};
		for (int i = 0; i < 3; i++) {
     
			for (int j = 0; j < 3; j++) {
     
				for (int k = 0; k < 3; k++) {
     
					t.m[i][j] += m[i][k] * rhs.m[k][j] % mod;

					if (t.m[i][j] >= mod) {
     
						t.m[i][j] -= mod;
					}
				}
			}
		}

		return t;
	}

};

const Node I = {
     1, 0, 0, 0, 1, 0, 0, 0, 1};

Node qpow(Node x, ll p) {
     
	Node ans = I;
	while (p) {
     
		if (p & 1) {
     
			ans = ans * x;
		}

		x = x * x;
		p >>= 1;
	}

	return ans;
}

int main() {
     
	int a = 1;
	int b = 1;
	int f0;
	int f1;
    int c;
    
    ll n;
    ll loop = find_loop(mod);
    while (~scanf("%lld%d%d%d", &n, &f0, &f1, &c)) {
     
        Node T = {
     a, b, 1, 1, 0, 0, 0, 0, 1};

        n = n % loop;

        Node ans = qpow(T, n);

        int res = (ans.m[1][0] * f1 % mod + ans.m[1][1] * f0 % mod + ans.m[1][2] * c % mod) % mod;

        printf("%d\n", res);
    }

	return 0;
}

通项公式证明过程:(有很多种证明过程,我只会这一种)
( f n + 1 f n 0 f n f n − 1 0 c c 0 ) = ( 1 1 1 1 0 0 0 0 1 ) ∗ ( f n f n − 1 0 f n − 1 f n − 2 0 c c 0 ) = . . . \begin{pmatrix} f_{n+1} & f_{n} & 0 \\ f_{n} & f_{n-1} & 0 \\ c & c & 0 \end{pmatrix} = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix} * \begin{pmatrix} f_{n} & f_{n-1} & 0 \\ f_{n-1} & f_{n-2} & 0 \\ c & c & 0 \end{pmatrix} =... fn+1fncfnfn1c000=110100101fnfn1cfn1fn2c000=...

= ( 1 1 1 1 0 0 0 0 1 ) n ∗ ( f 1 f 0 0 f 0 f − 1 0 c c 0 ) = { \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix}}^n * \begin{pmatrix} f_{1} & f_{0} & 0 \\ f_{0} & f_{-1} & 0 \\ c & c & 0 \end{pmatrix} =110100101nf1f0cf0f1c000

A = A= A= ( 1 1 1 1 0 0 0 0 1 ) \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix} 110100101,对转移矩阵 A A A进行相似对角化。
特征值 λ 1 = 1 \lambda_1 = 1 λ1=1,特征向量 ξ 1 \xi_1 ξ1 = = = ( − 1 , − 1 , 1 ) T \begin{pmatrix} -1 ,& -1, & 1 \end{pmatrix}^T (1,1,1)T
特征值 λ 2 = 1 + 5 2 \lambda_2 = \frac{1+\sqrt5}{2} λ2=21+5 ,特征向量 ξ 2 \xi_2 ξ2 = = = ( 1 + 5 , 2 , 0 ) T \begin{pmatrix} 1+\sqrt5, & 2, & 0 \end{pmatrix}^T (1+5 ,2,0)T
特征值 λ 3 = 1 − 5 2 \lambda_3 = \frac{1-\sqrt5}{2} λ3=215 ,特征向量 ξ 3 \xi_3 ξ3 = = = ( 1 − 5 , 2 , 0 ) T \begin{pmatrix} 1-\sqrt5, & 2, & 0 \end{pmatrix}^T (15 ,2,0)T

P − 1 A P = Λ P^{-1}AP=\Lambda P1AP=Λ

P = ( ξ 1 , ξ 2 , ξ 3 ) = ( − 1 1 + 5 1 − 5 − 1 2 2 1 0 0 ) P=\begin{pmatrix} \xi_1 ,& \xi_2, & \xi_3 \end{pmatrix}=\begin{pmatrix} -1 & 1+\sqrt5 & 1-\sqrt5 \\ -1 & 2 & 2 \\ 1 & 0 & 0 \end{pmatrix} P=(ξ1,ξ2,ξ3)=1111+5 2015 20

P − 1 = ( 0 0 1 5 10 5 − 5 20 5 + 5 20 − 5 10 5 + 5 20 5 − 5 20 ) P^{-1}=\begin{pmatrix} 0 & 0 & 1 \\ \frac{\sqrt5}{10} & \frac{5-\sqrt5}{20} & \frac{5+\sqrt5}{20} \\ -\frac{\sqrt5}{10} & \frac{5+\sqrt5}{20} & \frac{5-\sqrt5}{20} \end{pmatrix} P1=0105 105 02055 205+5 1205+5 2055

Λ = ( λ 1 0 0 0 λ 2 0 0 0 λ 3 ) = ( 1 0 0 0 1 + 5 2 0 0 0 1 − 5 2 ) \Lambda=\begin{pmatrix} \lambda_1 & 0 & 0 \\ 0 & \lambda_2 & 0 \\ 0 & 0 & \lambda_3 \end{pmatrix}=\begin{pmatrix} 1 & 0 & 0 \\ 0 & \frac{1+\sqrt5}{2} & 0 \\ 0 & 0 & \frac{1-\sqrt5}{2} \end{pmatrix} Λ=λ1000λ2000λ3=100021+5 000215

A = P Λ P − 1 A=P\Lambda P^{-1} A=PΛP1 A n = P Λ n P − 1 A^n=P\Lambda^n P^{-1} An=PΛnP1

f n = ( A n ) 21 ∗ f 1 + ( A n ) 22 ∗ f 0 + ( A n ) 23 ∗ c f_n=(A^n)_{21}*f_1+(A^n)_{22}*f_0+(A^n)_{23}*c fn=(An)21f1+(An)22f0+(An)23c
这里的 ( A n ) i j (A^n)_{ij} (An)ij代表的是 A n A^n An的第 i i i行第 j j j列的元素。化简即可得到通项公式。
你可以在Symnolab Math Solver在线计算矩阵的逆、特征值、特征向量。

你可能感兴趣的:(笔记)