题意:给出一个正定 n × n n \times n n×n矩阵,和一个 n n n维的向量 b b b,现在叫你找到一个 x 1 , x 2 , x 3 , . . . , x n x_1,x_2,x_3,...,x_n x1,x2,x3,...,xn满足如下条件:
然后叫你求 ( ∑ i = 1 n b i x i ) 2 (\sum_{i=1}^{n}b_ix_i)^2 (∑i=1nbixi)2在模998244353的值。
其他限制条件就不说了。
做法:
首先你要知道对于 n n n阶正定的实矩阵等价于 n n n阶对称实矩阵。
然后上面的第二个条件应该知道是一个正定二次型,即 x T A x ≤ 1 x^TAx\leq1 xTAx≤1
我们这里设 x x x是一个列向量, b b b是一个列向量。
其实我也不怎么会这样,大一学的都忘了。
然后前置知识,不等式约束条件下的最值问题(KKT),矩阵微积分,这里就不放连接了,可以自己去维基百科或者知乎。
然后高斯消元就行了。
#include "bits/stdc++.h"
using namespace std;
inline int read() {
int x = 0;
bool f = 1;
char c = getchar();
for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
if (f) return x;
return 0 - x;
}
#define ll long long
#define SZ(x) ((int)((x).size()))
#define all(x) (x).begin(),(x).end()
const int maxn = 2e2 + 10;
const ll mod = 998244353;
const double eps = 1e-8;
#define lowbit(x) (x&-x)
ll ksm(ll a, ll n) {
ll ans = 1;
while (n) {
if (n & 1) ans = ans * a % mod;
n >>= 1;
a = a * a % mod;
}
return ans;
}
ll b[maxn], a[maxn][maxn], inv_a[maxn][maxn];
int n;
void solve() {
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) inv_a[i][j] = (i == j);
}
for (int i = 1; i <= n; i++) {
int t;
for (t = i; t <= n; t++) if (a[t][i] != 0) break;
for (int j = 1; j <= n; j++) {
swap(a[i][j], a[t][j]);
swap(inv_a[i][j], inv_a[t][j]);
}
if (a[i][i] == 0) {
return;
}
ll inv = ksm(a[i][i], mod - 2);
for (int j = 1; j <= n; j++) {
a[i][j] = inv * a[i][j] % mod;
inv_a[i][j] = inv * inv_a[i][j] % mod;
}
for (int k = 1; k <= n; k++) {
if (k == i) continue;
ll tmp = a[k][i];
for (int j = 1; j <= n; j++) {
a[k][j] = (a[k][j] - a[i][j] * tmp % mod + mod) % mod;
inv_a[k][j] = (inv_a[k][j] - inv_a[i][j] * tmp % mod + mod) % mod;
}
}
}
ll tmp[maxn] = {0};
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
tmp[i] = (tmp[i] + b[j] * inv_a[j][i] % mod) % mod;
}
}
ll ans = 0;
for (int i = 1; i <= n; i++) ans = (ans + tmp[i] * b[i] % mod) % mod;
printf("%lld\n", ans);
}
int main() {
while (~scanf("%d", &n)) {
for (int i = 1; i <= n; i++)
for (int j = 1; j <= n; j++) {
scanf("%lld", &a[i][j]);
a[i][j] = (a[i][j] + mod) % mod;
}
for (int i = 1; i <= n; i++) {
scanf("%lld", &b[i]);
b[i] = (b[i] + mod) % mod;
}
solve();
}
return 0;
}