已知多项式 $f, g$,$\deg f = n$,$\deg g = m\ (m \le n)$。
求唯一的多项式 $q, r$,使得 $f = q \times g + r$,其中 $\deg r < m$(此时必有 $\deg q = n - m$)。
例:$f(x) = x^4 + x^3 + 2x^2 + 4x + 2$,$g(x) = x^2 + x + 3$,则
$$f(x) = (x^2 - 1)\,g(x) + 5x + 5$$
其中 $q(x) = x^2 - 1$,$r(x) = 5x + 5$。
$$f = q \times g + r$$
$r$ 的存在使得运算很不方便,考虑消去 $r$ 的影响。
由于 $\deg(q \times g) = \deg f = n$(从而 $\deg q = n - m$),将上式中的 $x$ 换成 $\frac{1}{x}$ 并两边同乘 $x^n$(即“翻转”),有
$$f^R(x) = (q^R \times g^R)(x) + x^{n - \deg r}\, r^R(x)$$
其中上标 $R$ 表示将多项式的系数翻转,即 $A^R(x) = x^{\deg A} A\left(\frac{1}{x}\right)$。
由于 $\deg r \le \deg g - 1 = m - 1$,有 $n - \deg r \ge n - m + 1$,因此
$$x^{n - \deg r}\, r^R(x) \equiv 0 \pmod{x^{n - m + 1}}$$
$$f^R(x) \equiv (q^R \times g^R)(x) \pmod{x^{n - m + 1}}$$
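至此,完全消去了 $r$ 的影响。
$g^R$ 的常数项是 $g$ 的首项系数(非零),故 $g^R$ 在模 $x^{n-m+1}$ 意义下可逆。又因为 $\deg q^R = n - m < n - m + 1$,所以 $q^R \equiv f^R \times (g^R)^{-1} \pmod{x^{n-m+1}}$ 唯一确定了 $q^R$;将其翻转得到 $q$,再回代 $r = f - q \times g$ 即可。
时间复杂度 $T(n) = O(n \log n)$。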
代码如下:
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
// NTT parameters.
const int N = 1000010;      // capacity of every coefficient / scratch array below
const int mod = 998244353;  // NTT-friendly prime: 119 * 2^23 + 1
const int G = 3;            // primitive root modulo `mod`, used to build roots of unity

// f and g as read by main() (coefficient of x^i at index i), and their
// coefficient reversals: revA[n - i] == A[i], revB[m - i] == B[i].
int A[N], B[N];
int revA[N], revB[N];

// Bit-reversal permutation consumed by dft(); rebuilt before each transform size.
int rev[N];

// a: not referenced at file scope in this listing (work/qpow have parameters named a).
// b: receives the series inverse computed by work() in main().
// c: scratch buffer shared by work() and main().
int a[N], b[N], c[N];

// Degrees of f and g.
int n, m;
int read() {
int ans = 0 , flag = 1;
char ch = getchar();
while(ch > '9' || ch < '0') {if(ch=='-') flag = -1; ch = getchar();}
while(ch <= '9' && ch >= '0') {ans = ans * 10 + ch - '0'; ch = getchar();}
return ans * flag;
}
// Binary exponentiation: returns a^b reduced modulo `mod`, using O(log b)
// multiplications. Callers in this file pass non-negative exponents.
int qpow(int a , int b) {
    long long base = a;
    long long result = 1;
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * base % mod;  // fold in the current bit
        base = base * base % mod;                 // square for the next bit
    }
    return (int)result;
}
// In-place iterative radix-2 number-theoretic transform of now[0..n) modulo `mod`.
// Expects n to be a power of two and rev[] to already hold the bit-reversal
// permutation for length n. f == 1 runs the forward transform; any other value
// runs the inverse transform, including the final multiplication by n^{-1}.
void dft(int *now , int n , int f) {
    // Reorder into bit-reversed order so butterflies can be applied bottom-up.
    for (int i = 0; i < n; ++i)
        if (i < rev[i]) swap(now[i], now[rev[i]]);

    for (int half = 1; half < n; half <<= 1) {
        const int len = half << 1;
        int root = qpow(G, (mod - 1) / len);     // len-th root of unity
        if (f != 1) root = qpow(root, mod - 2);  // its inverse for the inverse transform
        for (int start = 0; start < n; start += len) {
            int w = 1;
            for (int k = 0; k < half; ++k) {
                int *lo = now + start + k;
                int *hi = lo + half;
                const int u = *lo;
                const int v = (int)(1ll * w * *hi % mod);
                *lo = (u + v) % mod;
                *hi = ((u - v) % mod + mod) % mod;
                w = (int)(1ll * root * w % mod);
            }
        }
    }

    if (f == 1) return;
    const int inv_n = qpow(n, mod - 2);
    for (int i = 0; i < n; ++i) now[i] = (int)(1ll * inv_n * now[i] % mod);
}
// Power-series inverse: fills b[0..deg) with a^{-1} mod x^deg via Newton iteration.
// Given the inverse b0 modulo x^ceil(deg/2), it computes b = b0 * (2 - a * b0) mod x^deg.
// The constant term a[0] must be invertible modulo `mod`.
// Overwrites the global scratch arrays c[] and rev[]. On return, b[deg..len) is
// zeroed, where len is the transform length used at this level.
void work(int deg , int *a , int *b) {
    if (deg == 1) {
        b[0] = qpow(a[0], mod - 2);   // Fermat inverse of the constant term
        return;
    }
    work((deg + 1) >> 1, a, b);       // b now holds the inverse modulo x^ceil(deg/2)

    // Smallest power of two >= 2*deg, so that a * b * b does not wrap around.
    int len = 1;
    int bits = 0;
    while (len < deg * 2) {
        len <<= 1;
        ++bits;
    }
    for (int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bits - 1));

    // c = a mod x^deg, zero-padded up to len.
    for (int i = 0; i < len; ++i) c[i] = (i < deg) ? a[i] : 0;

    dft(b, len, 1);
    dft(c, len, 1);
    for (int i = 0; i < len; ++i) {
        long long t = (2 - 1ll * c[i] * b[i] % mod) % mod;  // 2 - c*b, pointwise
        t = (t + mod) % mod;
        b[i] = t * b[i] % mod;
    }
    dft(b, len, -1);
    for (int i = deg; i < len; ++i) b[i] = 0;             // truncate to mod x^deg
}
// Entry point. Reads n and m from stdin, then the n+1 coefficients of f and the
// m+1 coefficients of g (coefficient of x^i at index i). Prints two lines:
//   1) the n-m+1 coefficients of the quotient q,
//   2) the m coefficients of the remainder r,
// where f = q*g + r and deg r < m.
int main() {
    n = read(); m = read();
    // Transform length nn >= 2n. revA * b has degree <= 2n - m (< 2n for m >= 1;
    // for m == 0, b is the constant 1/g0, so the degree is n), and B * q has
    // degree n, so no product wraps around cyclically.
    int nn = 1;
    while (nn < n * 2) nn <<= 1;
    for(int i = 0 ; i <= n ; ++ i) A[i] = revA[n - i] = read();
    for(int i = 0 ; i <= m ; ++ i) B[i] = revB[m - i] = read();
    // b = (g^R)^{-1} mod x^{n-m+1}.
    work(n - m + 1 , revB , b);
    memset(rev , 0 , sizeof(rev));
    // Bit-reversal table for length nn. (nn >> 1) equals 1 << (log2(nn) - 1) but
    // stays well-defined when nn == 1, where a shift by -1 would be UB.
    for(int i = 0 ; i < nn ; ++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (nn >> 1) : 0);
    // q^R = f^R * (g^R)^{-1} mod x^{n-m+1}.
    dft(revA , nn , 1); dft(b , nn , 1);
    for(int i = 0 ; i < nn ; ++ i) c[i] = 1ll * revA[i] * b[i] % mod;
    dft(c , nn , -1);
    for(int i = n - m + 1 ; i < nn ; ++ i) c[i] = 0;
    // Un-reverse q^R to obtain q, then print it.
    reverse(c , c + n - m + 1);
    for(int i = 0 ; i < n - m + 1 ; ++ i) printf("%d ", c[i]);
    puts("");
    // r = f - q * g; only its low m coefficients are printed.
    dft(B , nn , 1); dft(c , nn , 1);
    for(int i = 0 ; i < nn ; ++ i) B[i] = 1ll * B[i] * c[i] % mod;
    dft(B , nn , -1);
    for(int i = 0 ; i <= n ; ++ i) A[i] = (A[i] - B[i] + mod) % mod;
    for(int i = 0 ; i < m ; ++ i) printf("%d ",A[i]);
    puts("");
    return 0;
}