题目链接:https://ac.nowcoder.com/acm/contest/16172
此题改至ACWing 240 食物链
此题知识点:带权并查集
#include
using namespace std;
const int N = 5e4 + 5, mod = 4;
int n, m, cnt;
int d[N], pre[N];
int find(int x)
{
if (x != pre[x]) {
int root = find(pre[x]);
d[x] = (d[x] + d[pre[x]]) % mod;
pre[x] = root;
}
return pre[x];
}
int main()
{
scanf("%d %d", &n, &m);
for (int i = 1; i <= n; i++)
pre[i] = i;
while (m--) {
int t, x, y;
scanf("%d %d %d", &t, &x, &y);
if (x < 1 || x > n || y < 1 || y > n || t != 1 && x == y) {
cnt++;
continue;
}
int k = t - 1;
int px = find(x), py = find(y);
if (px == py && ((d[x] - d[y]) % mod + mod) % mod != k) {
cnt++;
continue;
}
if (px != py) {
pre[px] = py;
d[px] = ((k - d[x] + d[y]) % mod + mod) % mod;
}
}
printf("%d\n", cnt);
return 0;
}
方法二:扩展域并查集
作者不是很懂,只能把大佬的代码搬过来
https://ac.nowcoder.com/acm/contest/view-submission?submissionId=47751034
此题比较简单的做法是直接Dijkstra,如果传送门的两端都不是陷阱,并且不是相邻的点,那就将两端的边权设置为3。如果不习惯从下标0开始计数,可以先将所有的点的横纵坐标都+1,最后如果有输出点再将横纵坐标都-1
#include
#include
#include
#include
#include
using namespace std;
typedef pair<int, int> PII;
typedef pair<int, PII> PIII;
const int N = 310;
const int dx[4] = {
-1, 1, 0, 0}, dy[4] = {
0, 0, -1, 1};
int n, m, q;
int stx, sty, edx, edy;
char mp[N][N];
vector<PII> edges[N][N];
int dis[N][N];
bool vis[N][N];
int Dijkstra()
{
memset(dis, 0x3f, sizeof dis);
priority_queue<PIII, vector<PIII>, greater<PIII>> heap;
heap.push({
0, {
stx, sty}});
dis[stx][sty] = 0;
while (!heap.empty()) {
int distance = heap.top().first;
int a = heap.top().second.first, b = heap.top().second.second;
heap.pop();
if (vis[a][b])
continue;
vis[a][b] = true;
if (a == edx && b == edy)
return distance;
for (int i = 0; i < 4; i++) {
int x = a + dx[i], y = b + dy[i];
if (!(x >= 1 && x <= n && y >= 1 && y <= m && mp[x][y] != '#' && dis[x][y] > distance + 1))
continue;
dis[x][y] = distance + 1;
heap.push({
dis[x][y], {
x, y}});
}
for (auto t : edges[a][b]) {
int x = t.first, y = t.second;
if (!(mp[x][y] != '#' && dis[x][y] > distance + 3))
continue;
dis[x][y] = distance + 3;
heap.push({
dis[x][y], {
x, y}});
}
}
return -1;
}
int main()
{
cin >> n >> m >> q;
for (int i = 1; i <= n; i++) {
cin >> mp[i] + 1;
for (int j = 1; j <= m; j++)
if (mp[i][j] == 'S')
stx = i, sty = j;
else if (mp[i][j] == 'T')
edx = i, edy = j;
}
while (q--) {
int a, b, c, d;
cin >> a >> b >> c >> d;
a++, b++, c++, d++;
edges[a][b].push_back({
c, d});
edges[c][d].push_back({
a, b});
}
cout << Dijkstra() << endl;
return 0;
}
此题的纸质版的地方由于工作人员疏忽, ∑ \sum ∑的下标 i = 0 i=0 i=0写成了 i = 1 i=1 i=1,并且没有交代 w 0 w_0 w0恒等于0,对选手造成了干扰,在此感到非常抱歉。
此题知识点:前缀和+二分+双关键字排序。
前缀和求完之后需要把 w 0 w_0 w0的值放入到前缀和中,该元素的第二个关键字也要设置为0。
二分不是只有一种写法,在已经升序排列的数组 s s s中,求 s i ≤ x s_i \le x si≤x, s i < x , s i > x s_i < x,s_i > x si<x,si>x, s i ≥ x s_i \ge x si≥x 的下标 i i i,这四种二分写法是有区别的。进阶指南上面说只有10%的程序员会写二分
≥ x \ge x ≥x的下界
int l = 1, r = n + 1;
a[n + 1] = 0x3f3f3f3f;
while (l < r) {
int mid = (l + r) >> 1;
if (a[mid] >= x) {
r = mid;
} else {
l = mid + 1;
}
}
if (l == n + 1) {
puts("Not Found");
} else {
printf("%d\n", a[l]);
}
> x > x >x的下界
int l = 1, r = n + 1;
a[n + 1] = 0x3f3f3f3f;
while (l < r) {
int mid = (l + r) >> 1;
if (a[mid] > x) {
r = mid;
} else {
l = mid + 1;
}
}
if (l == n + 1) {
puts("Not Found");
} else {
printf("%d\n", a[l]);
}
≤ x \le x ≤x的上界
int l = 0, r = n;
a[0] = 0xc0c0c0c0;
while (l < r) {
int mid = (l + r + 1) >> 1;
if (a[mid] <= x) {
l = mid;
} else {
r = mid - 1;
}
}
if (l == 0) {
puts("Not Found");
} else {
printf("%d\n", a[l]);
}
< x < x <x的上界
int l = 0, r = n;
a[0] = 0xc0c0c0c0;
while (l < r) {
int mid = (l + r + 1) >> 1;
if (a[mid] < x) {
l = mid;
} else {
r = mid - 1;
}
}
if (l == 0) {
puts("Not Found");
} else {
printf("%d\n", a[l]);
}
此题标程
# include
# include
# include
typedef long long ll;
const int N = 1e5 + 5;
int n;
struct Sum {
int ans;
int id;
const bool operator < (const Sum& rhs) const {
return ans < rhs.ans || (ans == rhs.ans && id < rhs.id);
}
};
int w[N];
Sum sum[N];
int main() {
int m;
std::cin >> n >> m;
for (int i = 1; i <= n; i++) {
scanf("%d", &w[i]);
}
for (int i = 1; i <= n; i++) {
sum[i].ans = sum[i-1].ans + w[i];
sum[i].id = i;
}
sum[n + 1].ans = sum[n + 1].id = 0;
std::sort(sum + 1, sum + 2 + n);
sum[0].ans = sum[0].id = 0xc0c0c0c0;
while (m--) {
int k;
scanf("%d", &k);
int l = 0;
int r = n + 1;
while (l < r) {
int mid = (l + r + 1) >> 1;
if (sum[mid].ans <= k) {
l = mid;
} else {
r = mid - 1;
}
}
if (l == 0) {
puts("-1");
} else {
printf("%d\n", sum[l].id);
}
}
return 0;
}
STL:
lower_bound(iter.begin(), iter.end(), x)
寻找 ∗ i t ≥ x *it\ge x ∗it≥x的下界,如果返回 i t e r . e n d ( ) iter.end() iter.end()说明无解
upper_bound(iter.begin(), iter.end(), x)
寻找 ∗ i t > x *it> x ∗it>x的下界,如果返回 i t e r . e n d ( ) iter.end() iter.end()说明无解
lower_bound(iter.begin(), iter.end(), x) - 1
寻找 ∗ i t < x *it< x ∗it<x的上界,如果返回 i t e r . b e g i n ( ) − 1 iter.begin()-1 iter.begin()−1说明无解
upper_bound(iter.begin(), iter.end(), x) - 1
寻找 ∗ i t ≤ x *it\le x ∗it≤x的上界,如果返回 i t e r . b e g i n ( ) − 1 iter.begin()-1 iter.begin()−1说明无解
# include
# include
# include
typedef long long ll;
const int N = 1e5 + 5;
int n;
struct Sum {
ll ans;
int id;
const bool operator < (const Sum& rhs) const {
return ans < rhs.ans || (ans == rhs.ans && id < rhs.id);
}
};
ll w[N];
Sum sum[N];
int main() {
int m;
std::cin >> n >> m;
for (int i = 1; i <= n; i++) {
scanf("%lld", &w[i]);
}
for (int i = 1; i <= n; i++) {
sum[i].ans = sum[i-1].ans + w[i];
sum[i].id = i;
}
sum[n + 1].ans = 0;
sum[n + 1].id = 0;
std::sort(sum + 1, sum + 2 + n);
while (m--) {
int k;
scanf("%d", &k);
Sum t;
t.ans = k;
t.id = 0x3f3f3f3f;
auto it = std::upper_bound(sum + 1, sum + 2 + n, t) - 1;
if (it == sum) {
puts("-1");
continue;
}
printf("%d\n", it->id);
}
return 0;
}
此题肯定不能暴力。关于时间复杂度与超时,可以查看去年新生赛I题,密码acmwitedu2020
标程为线段树的RMQ
出题人出这题的时候没有发现这就是前年新生赛K题的简化版。等到验题人验的时候发现出新题太麻烦,所以就用了这题。再次论往届题目的重要性。
# include
# include
# include
# include
# include
# define l(x) tree[x].l
# define r(x) tree[x].r
# define res(x) tree[x].res
typedef long long ll;
const int N = 1e5 + 5;
int n;
int a[N];
struct SegTree {
int l;
int r;
int res;
};
SegTree tree[N << 2];
void pushup(int p) {
res(p) = std::max(res(p * 2), res(p * 2 + 1));
}
void build(int p, int l, int r) {
l(p) = l;
r(p) = r;
if (l == r) {
res(p) = a[l];
return;
}
int mid = l + (r - l) / 2;
build(p * 2, l, mid);
build(p * 2 + 1, mid + 1, r);
pushup(p);
}
void update(int p, int x, int v) {
if (l(p) == r(p)) {
res(p) = v;
return;
}
int mid = l(p) + (r(p) - l(p)) / 2;
if (x <= mid) {
update(p * 2, x, v);
}
if (x > mid) {
update(p * 2 + 1, x, v);
}
pushup(p);
}
int get(int p, int v) {
if (l(p) == r(p)) {
return l(p);
}
return res(p * 2) > v ? get(p * 2, v) : get(p * 2 + 1, v);
}
int flag = 0;
int query2(int p, int l, int r, int max) {
if (l <= l(p) && r(p) <= r) {
int ans = res(p);
if (ans > max && !flag) {
flag = 1;
return get(p, max);
}
return -1;
}
int mid = l(p) + (r(p) - l(p)) / 2;
int t = -1;
if (l <= mid && !flag) {
t = query2(p * 2, l, r, max);
}
if (r > mid && !flag) {
t = query2(p * 2 + 1, l, r, max);
}
return t;
}
int main() {
int m;
std::cin >> n >> m;
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int p;
scanf("%d", &p);
flag = 0;
// if (p == n) {
// puts("-1");
// } else {
int pos = query2(1, p + 1, n, a[p]);
printf("%d\n", pos);
// }
if (i != m) {
int t, v;
scanf("%d%d", &t, &v);
update(1, t, v);
a[t] = v;
}
}
return 0;
}
签到题。太多种写法了,如果不会请自行看答案正确的同学的代码。
其实是作者认为自己的标程太复杂
KMP+逆波兰
好像不用KMP直接暴力匹配也可以过
注意%0的时候也是输出"Error!"
#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
const ll N=1e4+10;
const ll mod=1e9+7;
stack<ll> num;
stack<char> op;
int ne[N];
char s[N], p[N],temp[N];
int n;
int m;
bool flag=true;
ll kmp(char p[],int n){
for (int i = 2, j = 0; i <= n; i ++ )
{
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j ++ ;
ne[i] = j;
}
for (int i = 1, j = 0; i <= m; i ++ )
{
while (j && s[i] != p[j + 1]) j = ne[j];
if (s[i] == p[j + 1]) j ++ ;
if (j == n)
{
return i-n+1;
}
}
return 0;
}
void eval()
{
if(num.empty()){
cout<<"Error!";
flag=false;
return;
}
ll b = num.top(); num.pop();
if(num.empty()){
cout<<"Error!";
flag=false;
return;
}
ll a = num.top(); num.pop();
if(op.empty()){
cout<<"Error!";
flag=false;
return;
}
char c = op.top(); op.pop();
ll x;
if (c == '+') x = ((a + b)%mod+mod)%mod;
else if (c == '-') x = ((a - b)%mod+mod)%mod;
else if (c == '*'){
x = ((a %mod* b%mod)%mod+mod)%mod;
}
else {
if(b==0){
cout<<"Error!";
flag=false;
return;
}else{
if(b<0) b=-b;
x = (a % b + b ) % b;
}
}
//cout<
num.push(x);
}
int main()
{
cin>>m>>s+1;
unordered_map<char, int> pr{
{
'+', 1}, {
'-', 1}, {
'*', 2}, {
'%', 2}};
string str;
cin >> str;
int len=str.size();
for (int i = 0; i < len; i ++ )
{
char c = str[i];
if (c>='a'&&c<='z')
{
ll x = 0, j = i;
int ans=1;
while (j < str.size() && str[j]>='a'&& str[j]<='z')
temp[ans++]=str[j++];
x = kmp(temp , ans-1);
i = j - 1;
num.push(x);
}
else if (c == '(') op.push(c);
else if (c == ')')
{
while (op.top() != '(') eval();
op.pop();
}
else
{
while (op.size() && op.top() != '(' && pr[op.top()] >= pr[c])
{
eval();
if(!flag) return 0;
}
op.push(c);
}
if(!flag) return 0;
}
while (op.size()) {
eval();
if(!flag) return 0;
}
if(flag) cout << (num.top() % mod + mod) % mod;
return 0;
}
这题真心不难,就是题目有点长而已。
# include
# include
# include
# include
# include
typedef long long ll;
int n, m, k;
const int N = 1e2 + 5;
int a[N];
int b[N];
int c[N];
int d[N];
int v[N];
struct Custom {
int time;
int id;
bool operator < (Custom rhs) {
if (time == rhs.time) {
return time < rhs.time;
} else {
return id < rhs.id;
}
}
};
Custom custom[N];
void solve() {
scanf("%d%d%d", &n, &m, &k);
for (int i = 1; i <= N - 2; i++) {
custom[i].id = i;
}
for (int i = 1; i <= n; i++) {
scanf("%d", &custom[i].time);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &v[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &b[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &c[i]);
}
for (int i = 1; i <= n; i++) {
scanf("%d", &d[i]);
}
ll sum = 0;
int inda = 1;
int indb = 1;
int indc = 1;
int indd = 1;
//std::sort(custom + 1, custom + n + 1);
for (int i = 1; i <= n; i++) {
if (inda > n || indb > n) {
sum -= k / 2;
continue;
}
if (v[i] == 1) {
if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && c[indc] <= custom[i].time + m) {
sum += k;
inda++;
indb++;
indc++;
} else {
sum -= k / 2;
}
} else if (v[i] == 2) {
if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && d[indd] <= custom[i].time + m) {
sum += k;
inda++;
indb++;
indd++;
} else {
sum -= k / 2;
}
} else {
if (a[inda] <= custom[i].time + m && b[indb] <= custom[i].time + m && c[indc] <= custom[i].time + m && d[indd] <= custom[i].time + m) {
sum += k;
inda++;
indb++;
indc++;
indd++;
} else {
sum -= k / 2;
}
}
}
printf("%lld\n", sum);
}
int main() {
int T;
std::cin >> T;
while (T--) {
solve();
}
return 0;
}
此题改编自Codeforces Round #674 (Div. 3)的F
本来一开始想用英文出这题的,但是后来感觉这题有点难所以把英文砍了
此题涉及的物理知识可能存在问题,如果感兴趣请查阅专业资料
d p [ i ] [ j ] dp[i][j] dp[i][j]( i i i从下标 1 1 1开始计数, 1 ≤ j ≤ ∣ t ∣ 1\le j \le |t| 1≤j≤∣t∣)表示前 i i i个字符串中,匹配了 t t t的前 j j j位。状态转移方程为( t t t从下标 1 1 1开始计数):
d p [ 0 ] [ 0 ] = 1 dp[0][0]=1 dp[0][0]=1
每一次循环, d p [ i ] dp[i] dp[i]中的每一个元素值等于 d p [ i − 1 ] dp[i-1] dp[i−1]中每一个元素的值
d p [ i ] [ 1 ] = ( d p [ i − 1 ] [ 1 ] + ( s [ i ] = = t [ 1 ] ) ) % M o d dp[i][1] = (dp[i-1][1] + (s[i]==t[1])) \% Mod dp[i][1]=(dp[i−1][1]+(s[i]==t[1]))%Mod
d p [ i ] [ 2 ] = ( d p [ i − 1 ] [ 2 ] + ( s [ i ] = = t [ 2 ] ) ∗ d p [ i − 1 ] [ 1 ] ) % M o d dp[i][2] = (dp[i-1][2] + (s[i]==t[2]) * dp[i-1][1]) \% Mod dp[i][2]=(dp[i−1][2]+(s[i]==t[2])∗dp[i−1][1])%Mod
d p [ i ] [ 3 ] = ( d p [ i − 1 ] [ 3 ] + ( s [ i ] = = t [ 3 ] ) ∗ d p [ i − 1 ] [ 2 ] ) % M o d dp[i][3] = (dp[i-1][3] + (s[i]==t[3]) * dp[i-1][2]) \% Mod dp[i][3]=(dp[i−1][3]+(s[i]==t[3])∗dp[i−1][2])%Mod
…
d p [ i ] [ m ] = ( d p [ i − 1 ] [ 3 ] + ( s [ i ] = = t [ m ] ) ∗ d p [ i − 1 ] [ m − 1 ] ) % M o d dp[i][m] = (dp[i-1][3] + (s[i]==t[m]) * dp[i-1][m-1]) \% Mod dp[i][m]=(dp[i−1][3]+(s[i]==t[m])∗dp[i−1][m−1])%Mod
特判 ‘?’
i f ( s [ i ] = = ′ ? ′ ) if\ (s[i]=='?') if (s[i]==′?′)
d p [ i ] [ 0 ] = 2 ∗ d p [ i − 1 ] [ 0 ] % M o d dp[i][0] = 2 * dp[i - 1][0] \% Mod dp[i][0]=2∗dp[i−1][0]%Mod
d p [ i ] [ 1 ] = ( 2 ∗ d p [ i − 1 ] [ 1 ] + d p [ i − 1 ] [ 0 ] ) % M o d dp[i][1] = (2 * dp[i-1][1] + dp[i-1][0]) \% Mod dp[i][1]=(2∗dp[i−1][1]+dp[i−1][0])%Mod
d p [ i ] [ 2 ] = ( 2 ∗ d p [ i − 1 ] [ 2 ] + d p [ i − 1 ] [ 1 ] ) % M o d dp[i][2] = (2 * dp[i-1][2] + dp[i-1][1]) \% Mod dp[i][2]=(2∗dp[i−1][2]+dp[i−1][1])%Mod
d p [ i ] [ 3 ] = ( 2 ∗ d p [ i − 1 ] [ 3 ] + d p [ i − 1 ] [ 2 ] ) % M o d dp[i][3] = (2 * dp[i-1][3] + dp[i-1][2]) \% Mod dp[i][3]=(2∗dp[i−1][3]+dp[i−1][2])%Mod
…
d p [ i ] [ m ] = ( 2 ∗ d p [ i − 1 ] [ m ] + d p [ i − 1 ] [ m − 1 ] ) % M o d dp[i][m] = (2 * dp[i-1][m] + dp[i-1][m-1]) \% Mod dp[i][m]=(2∗dp[i−1][m]+dp[i−1][m−1])%Mod
求 d p [ n ] [ m ] dp[n][m] dp[n][m]( n n n是 s s s的长度, m m m是 t t t的长度)。
#include
#include
#include
#include
const int mod = 998244353;
const int N = 1e5 + 5;
char s[N];
char t[N];
int n;
int m;
std::vector<std::vector<int> > dp(N, std::vector<int>(8));
int main() {
int T;
std::cin >> T;
while (T--) {
scanf("%s%s", s + 1, t + 1);
n = strlen(s + 1);
m = strlen(t + 1);
//auto <--> std::vector >::iterator
for (auto it = dp.begin(); it != dp.begin() + n + 2; it++) {
std::fill(it->begin(), it->end(), 0);
}
dp[0][0] = 1;
for (int i = 1; i <= n; i++)
{
dp[i] = dp[i - 1];
if (s[i] == '?')
{
dp[i][0] = 2ll * dp[i - 1][0] % mod;
for (int j = 1; j <= m; j++) {
dp[i][j] = (2ll * dp[i - 1][j] + dp[i - 1][j-1]) % mod;
}
}
else
{
for (int j = 1; j <= m; j++) {
dp[i][j] = (1ll * dp[i - 1][j] + (s[i] == t[j]) * dp[i - 1][j-1]) % mod;
}
}
}
printf("%d\n", dp[n][m]);
}
return 0;
}
也可以用滚动数组降维
# include
# include
# include
const int mod = 998244353;
const int N = 1e5 + 5;
char s[N];
char t[N];
int dp[10];
int temp[10];
int n;
int m;
int main() {
int T;
std::cin >> T;
while (T--) {
scanf("%s%s", s + 1, t + 1);
n = strlen(s + 1);
m = strlen(t + 1);
memset(dp, 0, sizeof(int) * (m + 2));
dp[0] = 1;
for (int i = 1; i <= n; i++) {
if (s[i] == '?') {
for (int j = 0; j < m; j++) {
temp[j] = dp[j];
}
dp[0] = 2ll * dp[0] % mod;
for (int j = 1; j <= m; j++) {
dp[j] = (2ll * dp[j] + temp[j-1]) % mod;
}
} else {
for (int j = 0; j < m; j++) {
temp[j] = dp[j];
}
for (int j = 1; j <= m; j++) {
dp[j] = (1ll * dp[j] + (s[i] == t[j]) * temp[j-1]) % mod;
}
}
}
printf("%d\n", dp[m]);
}
return 0;
}
不要直接memset/清零整个二维dp数组,这样会超时。
时间复杂度 O ( T n m ) O(Tnm) O(Tnm)
签到题。就是数 2 2 2的多少次方。
老生长谈的问题,当输入的数据过多的时候输入输出的卡常不能忽略。
#include
using namespace std;
typedef long long LL;
int main()
{
int T;
scanf("%d", &T);
while (T--) {
LL n;
scanf("%lld", &n);
int cnt = 0;
while (n > 4) {
cnt++;
n /= 2;
}
printf("%d %d %d\n", cnt + 3, cnt + 4, 1);
}
return 0;
}
import java.util.*;
import java.io.*;
public class Main {
public static void solve() {
int T = nextInt();
while (T-- > 0) {
long n = nextLong();
int cnt = 3;
while (n > 4) {
cnt++;
n >>= 1;
}
int cnt2 = cnt + 1;
out.println(cnt + " " + cnt2 + " 1");
}
}
public static void main(String[] args) {
reader = new BufferedReader(new InputStreamReader(System.in));
tokenizer = null;
out = new PrintWriter(System.out);
solve();
out.close();
}
static BufferedReader reader;
static StringTokenizer tokenizer;
static PrintWriter out;
static int nextInt(){
return Integer.parseInt(next());
}
static long nextLong(){
return Long.parseLong(next());
}
static double nextDouble(){
return Double.parseDouble(next());
}
static String next(){
while (tokenizer == null || !tokenizer.hasMoreTokens()) {
try {
tokenizer = new StringTokenizer(reader.readLine());
} catch (IOException e) {
throw new RuntimeException(e);
}
}
return tokenizer.nextToken();
}
}
更多写法请参考去年新生赛预选赛的A
这题真的是数论的基础题目,公式全都摆出来了,按照公式打就行
# include
# include
# include
typedef long long ll;
const int mod = 1e9 + 9;
const int Sq5 = 383008016; // sqrt(5)
const int A = 691504013; // (1 + sqrt(5)) / 2
const int B = 308495997; // (1 - sqrt(5)) / 2
const int C = 276601605; // 1 / sqrt(5)
const int Inv2 = 500000005; // 1 / 2
const int Inv10 = 100000001; // 1 / 10
int qpow(int a, ll b) {
int ans = 1 % mod;
while (b) {
if (b & 1) {
ans = 1ll * ans * a % mod;
}
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
int main() {
ll n;
int a, b, c;
while (~scanf("%lld%d%d%d", &n, &a, &b, &c)) {
int qan = qpow(A, n);
int qbn = qpow(B, n);
int a1 = 1ll * qan * C % mod;
int a2 = 1ll * qbn * C % mod;
int ans1 = (1ll * a1 - a2 + mod) % mod * b % mod;
int a3 = 1ll * (Sq5 - 1) * qan % mod * Inv2 % mod * C % mod;
int a4 = 1ll * (Sq5 + 1) * qbn % mod * Inv2 % mod * C % mod;
int ans2 = (1ll * a3 + a4) % mod * a % mod;
int a5 = 1ll * (5 + Sq5) * Inv10 % mod * qan % mod;
int a6 = 1ll * (5 - Sq5) * Inv10 % mod * qbn % mod;
int ans3 = (1ll * a5 + a6 - 1 + mod) % mod * c % mod;
int ans = (1ll * ans1 + ans2 + ans3 + mod) % mod;
printf("%d\n", ans);
}
return 0;
}
乘法的时候注意先将数据变成long long型,不然会超出int的范围。
在模数下 ( ϕ 1 ) n (\phi_1)^n (ϕ1)n不一定大于 ( ϕ 2 ) n (\phi_2)^n (ϕ2)n,模数下用减法需要先加上mod在取模。例如:(a-b+mod)%mod。全场唯一开了此题的队伍就是因为这里没有注意所以没有过。
还可以用广义欧拉降幂将 n n n模一个 ϕ ( m o d ) \phi(mod) ϕ(mod)。这里的 ϕ \phi ϕ为欧拉函数。
# include
# include
# include
typedef long long ll;
const double phi1 = (1 + sqrt(5)) * 0.5;
const double phi2 = (1 - sqrt(5)) * 0.5;
const int mod = 1e9 + 9;
const int Sq5 = 383008016; // sqrt(5)
const int A = 691504013; // (1 + sqrt(5) / 2
const int B = 308495997; // (1 - sqrt(5) / 2
const int C = 276601605; // 1 / sqrt(5)
const int Inv2 = 500000005; // 1 / 2
const int Inv10 = 100000001; // 1 / 10
int qpow(int a, ll b) {
int ans = 1 % mod;
while (b) {
if (b & 1) {
ans = 1ll * ans * a % mod;
}
a = 1ll * a * a % mod;
b >>= 1;
}
return ans;
}
int main() {
ll n;
int a, b, c;
while (~scanf("%lld%d%d%d", &n, &a, &b, &c)) {
n = n % (mod - 1);
int qan = qpow(A, n);
int qbn = qpow(B, n);
int a1 = 1ll * qan * C % mod;
int a2 = 1ll * qbn * C % mod;
int ans1 = (1ll * a1 - a2 + mod) % mod * b % mod;
int a3 = 1ll * (Sq5 - 1) * qan % mod * Inv2 % mod * C % mod;
int a4 = 1ll * (Sq5 + 1) * qbn % mod * Inv2 % mod * C % mod;
int ans2 = (1ll * a3 + a4) % mod * a % mod;
int a5 = 1ll * (5 + Sq5) * Inv10 % mod * qan % mod;
int a6 = 1ll * (5 - Sq5) * Inv10 % mod * qbn % mod;
int ans3 = (1ll * a5 + a6 - 1 + mod) % mod * c % mod;
int ans = (1ll * ans1 + ans2 + ans3 + mod) % mod;
printf("%d\n", ans);
}
return 0;
}
我说的是不要轻易用新生赛的代码,没有说矩阵快速幂就不行,题解中的第一种方法确实会超时,但是题解中的第二种方法改改还是可以过的
# include
# include
# include
# include
typedef long long ll;
const ll mod = 1e9 + 9;;
struct Node {
ll m[3][3];
Node operator * (const Node& rhs) {
Node t = {
0};
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
t.m[i][j] += m[i][k] * rhs.m[k][j] % mod;
if (t.m[i][j] >= mod) {
t.m[i][j] -= mod;
}
}
}
}
return t;
}
};
const Node I = {
1, 0, 0, 0, 1, 0, 0, 0, 1};
Node qpow(Node x, ll p) {
Node ans = I;
while (p) {
if (p & 1) {
ans = ans * x;
}
x = x * x;
p >>= 1;
}
return ans;
}
int main() {
int a = 1;
int b = 1;
int f0;
int f1;
int c;
ll n;
while (~scanf("%lld%d%d%d", &n, &f0, &f1, &c)) {
Node T = {
a, b, 1, 1, 0, 0, 0, 0, 1};
Node ans = qpow(T, n);
int res = (ans.m[1][0] * f1 % mod + ans.m[1][1] * f0 % mod + ans.m[1][2] * c % mod) % mod;
printf("%d\n", res);
}
return 0;
}
# include
# include
# include
typedef long long ll;
const ll mod = 1e9 + 9;;
ll qpow(ll x, ll p, int Mod = mod) {
ll ans = 1 % Mod;
x %= Mod;
while (p) {
if (p & 1) {
ans = ans * x % Mod;
}
x = x * x % Mod;
p >>= 1;
}
return ans;
}
ll lcm(ll a, ll b) {
return a / std::__gcd(a, b) * b;
}
ll pFac[105][2];
int getFactors(ll n) {
int pCnt = 0;
for (ll i = 2; i * i <= n; ++i) {
if (n % i) {
continue;
}
pFac[pCnt][0] = i;
pFac[pCnt][1] = 0;
while (n % i == 0) {
n /= i;
pFac[pCnt][1]++;
}
pCnt++;
}
if (n > 1) {
pFac[pCnt][0] = n;
pFac[pCnt++][1] = 1;
}
return pCnt;
}
int Legendre(ll a, ll p) {
if (qpow(a, (p - 1) >> 1, p) == 1) {
return 1;
}
return -1;
}
ll find_loop(ll n, ll a = 1, ll b = 1) {
int cnt = getFactors(n);
ll c = a * a + b * 4;
ll ans = 1, record;
for (int i = 0; i < cnt; ++i) {
if (pFac[i][0] == 2) {
record = 3 * 2 * 2;
} else if (c % pFac[i][0] == 0) {
record = pFac[i][0] * (pFac[i][0] - 1);
} else if (Legendre(c, pFac[i][0]) == 1) {
record = pFac[i][0] - 1;
} else {
record = (pFac[i][0] - 1) * (pFac[i][0] + 1);
}
for (int j = 1; j < pFac[i][1]; ++j) {
record *= pFac[i][0];
}
ans = lcm(ans, record);
}
return ans;
}
struct Node {
ll m[3][3];
Node operator * (const Node& rhs) {
Node t = {
0};
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
t.m[i][j] += m[i][k] * rhs.m[k][j] % mod;
if (t.m[i][j] >= mod) {
t.m[i][j] -= mod;
}
}
}
}
return t;
}
};
const Node I = {
1, 0, 0, 0, 1, 0, 0, 0, 1};
Node qpow(Node x, ll p) {
Node ans = I;
while (p) {
if (p & 1) {
ans = ans * x;
}
x = x * x;
p >>= 1;
}
return ans;
}
int main() {
int a = 1;
int b = 1;
int f0;
int f1;
int c;
ll n;
ll loop = find_loop(mod);
while (~scanf("%lld%d%d%d", &n, &f0, &f1, &c)) {
Node T = {
a, b, 1, 1, 0, 0, 0, 0, 1};
n = n % loop;
Node ans = qpow(T, n);
int res = (ans.m[1][0] * f1 % mod + ans.m[1][1] * f0 % mod + ans.m[1][2] * c % mod) % mod;
printf("%d\n", res);
}
return 0;
}
通项公式证明过程:(有很多种证明过程,我只会这一种)
( f n + 1 f n 0 f n f n − 1 0 c c 0 ) = ( 1 1 1 1 0 0 0 0 1 ) ∗ ( f n f n − 1 0 f n − 1 f n − 2 0 c c 0 ) = . . . \begin{pmatrix} f_{n+1} & f_{n} & 0 \\ f_{n} & f_{n-1} & 0 \\ c & c & 0 \end{pmatrix} = \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix} * \begin{pmatrix} f_{n} & f_{n-1} & 0 \\ f_{n-1} & f_{n-2} & 0 \\ c & c & 0 \end{pmatrix} =... ⎝⎛fn+1fncfnfn−1c000⎠⎞=⎝⎛110100101⎠⎞∗⎝⎛fnfn−1cfn−1fn−2c000⎠⎞=...
= ( 1 1 1 1 0 0 0 0 1 ) n ∗ ( f 1 f 0 0 f 0 f − 1 0 c c 0 ) = { \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix}}^n * \begin{pmatrix} f_{1} & f_{0} & 0 \\ f_{0} & f_{-1} & 0 \\ c & c & 0 \end{pmatrix} =⎝⎛110100101⎠⎞n∗⎝⎛f1f0cf0f−1c000⎠⎞
令 A = A= A= ( 1 1 1 1 0 0 0 0 1 ) \begin{pmatrix} 1 & 1 & 1 \\ 1 & 0 & 0 \\ 0 & 0 & 1 \end{pmatrix} ⎝⎛110100101⎠⎞,对转移矩阵 A A A进行相似对角化。
特征值 λ 1 = 1 \lambda_1 = 1 λ1=1,特征向量 ξ 1 \xi_1 ξ1 = = = ( − 1 , − 1 , 1 ) T \begin{pmatrix} -1 ,& -1, & 1 \end{pmatrix}^T (−1,−1,1)T
特征值 λ 2 = 1 + 5 2 \lambda_2 = \frac{1+\sqrt5}{2} λ2=21+5,特征向量 ξ 2 \xi_2 ξ2 = = = ( 1 + 5 , 2 , 0 ) T \begin{pmatrix} 1+\sqrt5, & 2, & 0 \end{pmatrix}^T (1+5,2,0)T
特征值 λ 3 = 1 − 5 2 \lambda_3 = \frac{1-\sqrt5}{2} λ3=21−5,特征向量 ξ 3 \xi_3 ξ3 = = = ( 1 − 5 , 2 , 0 ) T \begin{pmatrix} 1-\sqrt5, & 2, & 0 \end{pmatrix}^T (1−5,2,0)T
P − 1 A P = Λ P^{-1}AP=\Lambda P−1AP=Λ
P = ( ξ 1 , ξ 2 , ξ 3 ) = ( − 1 1 + 5 1 − 5 − 1 2 2 1 0 0 ) P=\begin{pmatrix} \xi_1 ,& \xi_2, & \xi_3 \end{pmatrix}=\begin{pmatrix} -1 & 1+\sqrt5 & 1-\sqrt5 \\ -1 & 2 & 2 \\ 1 & 0 & 0 \end{pmatrix} P=(ξ1,ξ2,ξ3)=⎝⎛−1−111+5201−520⎠⎞
P − 1 = ( 0 0 1 5 10 5 − 5 20 5 + 5 20 − 5 10 5 + 5 20 5 − 5 20 ) P^{-1}=\begin{pmatrix} 0 & 0 & 1 \\ \frac{\sqrt5}{10} & \frac{5-\sqrt5}{20} & \frac{5+\sqrt5}{20} \\ -\frac{\sqrt5}{10} & \frac{5+\sqrt5}{20} & \frac{5-\sqrt5}{20} \end{pmatrix} P−1=⎝⎜⎛0105−1050205−5205+51205+5205−5⎠⎟⎞
Λ = ( λ 1 0 0 0 λ 2 0 0 0 λ 3 ) = ( 1 0 0 0 1 + 5 2 0 0 0 1 − 5 2 ) \Lambda=\begin{pmatrix} \lambda_1 & 0 & 0 \\ 0 & \lambda_2 & 0 \\ 0 & 0 & \lambda_3 \end{pmatrix}=\begin{pmatrix} 1 & 0 & 0 \\ 0 & \frac{1+\sqrt5}{2} & 0 \\ 0 & 0 & \frac{1-\sqrt5}{2} \end{pmatrix} Λ=⎝⎛λ1000λ2000λ3⎠⎞=⎝⎜⎛100021+500021−5⎠⎟⎞
A = P Λ P − 1 A=P\Lambda P^{-1} A=PΛP−1, A n = P Λ n P − 1 A^n=P\Lambda^n P^{-1} An=PΛnP−1
f n = ( A n ) 21 ∗ f 1 + ( A n ) 22 ∗ f 0 + ( A n ) 23 ∗ c f_n=(A^n)_{21}*f_1+(A^n)_{22}*f_0+(A^n)_{23}*c fn=(An)21∗f1+(An)22∗f0+(An)23∗c
这里的 ( A n ) i j (A^n)_{ij} (An)ij代表的是 A n A^n An的第 i i i行第 j j j列的元素。化简即可得到通项公式。
你可以在Symnolab Math Solver在线计算矩阵的逆、特征值、特征向量。