Quadratic Form(约束条件下的最值)牛客多校第一场D题

Quadratic Form

题意:给出一个正定 n × n n \times n n×n矩阵,和一个 n n n维的向量 b b b,现在叫你找到一个 x 1 , x 2 , x 3 , . . . , x n x_1,x_2,x_3,...,x_n x1,x2,x3,...,xn满足如下条件:

  • x 1 , x 2 , . . . , x n ∈ R x_1,x_2,...,x_n \in R x1,x2,...,xnR
  • ∑ i = 1 n ∑ j = 1 n A i , j x i x j ≤ 1 \sum_{i=1}^n\sum_{j=1}^n A_{i,j}x_ix_j \leq 1 i=1nj=1nAi,jxixj1
  • ∑ i = 1 n b i x i   i s   m a x i m u m \sum_{i=1}^{n}b_ix_i \ is \ maximum i=1nbixi is maximum

然后叫你求 ( ∑ i = 1 n b i x i ) 2 (\sum_{i=1}^{n}b_ix_i)^2 (i=1nbixi)2在模998244353的值。
其他限制条件就不说了。
做法:
首先你要知道对于 n n n阶正定的实矩阵等价于 n n n阶对称实矩阵。
然后上面的第二个条件应该知道是一个正定二次型,即 x T A x ≤ 1 x^TAx\leq1 xTAx1
我们这里设 x x x是一个列向量, b b b是一个列向量。
其实我也不怎么会这样,大一学的都忘了。
然后前置知识,不等式约束条件下的最值问题(KKT),矩阵微积分,这里就不放连接了,可以自己去维基百科或者知乎。

  • 构造拉格朗日函数 L ( x , λ ) = b T x + λ ( x T A x − 1 ) L(x,\lambda)=b^Tx +\lambda(x^TAx - 1) L(x,λ)=bTx+λ(xTAx1)然后对 x x x求导:
    如二次型求导不清楚的可以百度。(我百度了)
    ∂ L ( x , λ ) ∂ x = b + 2 λ A x \frac{\partial{L(x,\lambda)}}{\partial x}=b+2\lambda Ax xL(x,λ)=b+2λAx
  • 根据不等式约束条件下的最值情况下, b + 2 λ A x = 0 , λ ≥ 0 , λ ( x T A x − 1 ) = 0 b+2\lambda Ax=0,\lambda \geq0,\lambda (x^TAx- 1) = 0 b+2λAx=0,λ0,λ(xTAx1)=0
  • 根据上面的条件: x = − A − 1 b 2 λ x=\frac{-A^{-1}b}{2\lambda} x=2λA1b
  • 接下来我们找 λ \lambda λ的关系:把 x x x带入约束条件下: λ ( − A − 1 b 2 λ T A − A − 1 b 2 λ − 1 ) = 0 , = > λ = − b T ( A − 1 ) T b / 2 \lambda({\frac{-A^{-1}b}{2\lambda}}^TA\frac{-A^{-1}b}{2\lambda} - 1)=0,=>\lambda =-\sqrt{b^T(A^{-1})^Tb}/2 λ(2λA1bTA2λA1b1)=0,=>λ=bT(A1)Tb /2并且 A − 1 = ( A − 1 ) T A^{-1}=(A^{-1})^T A1=(A1)T
  • 最后,我们计算: b T x = b T A − 1 b b T A − 1 b = b T A − 1 b b^Tx=\frac{b^TA^{-1}b}{\sqrt{b^TA^{-1}b}}=\sqrt{b^TA^{-1}b} bTx=bTA1b bTA1b=bTA1b 因此答案就是 b T A − 1 b b^TA^{-1}b bTA1b

然后高斯消元就行了。

#include "bits/stdc++.h"

using namespace std;
inline int read() {
    int x = 0;
    bool f = 1;
    char c = getchar();
    for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
    for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
    if (f) return x;
    return 0 - x;
}
#define ll long long
#define SZ(x) ((int)((x).size()))
#define all(x) (x).begin(),(x).end()

const int maxn = 2e2 + 10;
const ll mod = 998244353;
const double eps = 1e-8;
#define lowbit(x) (x&-x)
ll ksm(ll a, ll n) {
    ll ans = 1;
    while (n) {
        if (n & 1) ans = ans * a % mod;
        n >>= 1;
        a = a * a % mod;
    }
    return ans;
}

ll b[maxn], a[maxn][maxn], inv_a[maxn][maxn];
int n;
void solve() {
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) inv_a[i][j] = (i == j);
    }
    for (int i = 1; i <= n; i++) {
        int t;
        for (t = i; t <= n; t++) if (a[t][i] != 0) break;

        for (int j = 1; j <= n; j++) {
            swap(a[i][j], a[t][j]);
            swap(inv_a[i][j], inv_a[t][j]);
        }
        if (a[i][i] == 0) {
            return;
        }
        ll inv = ksm(a[i][i], mod - 2);
        for (int j = 1; j <= n; j++) {
            a[i][j] = inv * a[i][j] % mod;
            inv_a[i][j] = inv * inv_a[i][j] % mod;
        }
        for (int k = 1; k <= n; k++) {
            if (k == i) continue;
            ll tmp = a[k][i];
            for (int j = 1; j <= n; j++) {
                a[k][j] = (a[k][j] - a[i][j] * tmp % mod + mod) % mod;
                inv_a[k][j] = (inv_a[k][j] - inv_a[i][j] * tmp % mod + mod) % mod;
            }
        }
    }
    ll tmp[maxn] = {0};
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            tmp[i] = (tmp[i] + b[j] * inv_a[j][i] % mod) % mod;
        }
    }
    ll ans = 0;
    for (int i = 1; i <= n; i++) ans = (ans + tmp[i] * b[i] % mod) % mod;
    printf("%lld\n", ans);
}

int main() {
    while (~scanf("%d", &n)) {
        for (int i = 1; i <= n; i++)
            for (int j = 1; j <= n; j++) {
                scanf("%lld", &a[i][j]);
                a[i][j] = (a[i][j] + mod) % mod;
            }

        for (int i = 1; i <= n; i++) {
            scanf("%lld", &b[i]);
            b[i] = (b[i] + mod) % mod;
        }

        solve();
    }
    return 0;
}

你可能感兴趣的:(数学数论)