【CodeChef】Lucas Theorem

【题目链接】

  • 点击打开链接

【思路要点】

  • 考虑 subtask1 s u b t a s k 1 ,我们很容易可以得到一个动态规划的解法。
  • 注意到行与行之间转移的卷积本质,我们可以用FFT快速计算出DP数组的某一行,可以通过 subtask2 s u b t a s k 2
  • 原题中 N N 非常大,我们不可能求得DP数组的第 N N 行。
  • 考虑多项式 x(x+1)(x+2)(x+3)(x+p1) x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + p − 1 ) ,在模质数 p p 意义下,应当等于 xpx x p − x 。因为我们打表发现这两个多项式拥有 p p 个相同的根( 0,1,2,3,...,p1 0 , 1 , 2 , 3 , . . . , p − 1 )。
  • 因此令 A=x(x+1)(x+2)(x+3)(x+p1) A = x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + p − 1 ) ,令 a=Np,b=Nap a = ⌊ N p ⌋ , b = N − a p
  • x(x+1)(x+2)(x+3)(x+N)Aax(x+1)(x+2)(x+3)...(x+b)   (Mod p) x ( x + 1 ) ( x + 2 ) ( x + 3 ) … ( x + N ) ≡ A a ∗ x ( x + 1 ) ( x + 2 ) ( x + 3 ) . . . ( x + b )       ( M o d   p )
  • 乘号后面的部分可以用分治FFT计算得到,并且其次数不超过 p1 p − 1
  • 乘号前面的部分相邻两个系数非零的项次数差为 p1 p − 1
  • 特殊处理 b=p1 b = p − 1 的情况,对于其余部分,我们可以分别解决乘号前后的子问题,将它们的答案相乘得到答案。
  • 乘号前面的部分相当于在询问所有 (ai)(0ia) ( a i ) ( 0 ≤ i ≤ a ) 中,有多少不是 p p 的倍数。
  • 我们可以用Lucas定理解决这个问题:即 (ai)(a/pi/p)(a%pi%p)(Mod p) ( a i ) ≡ ( a / p i / p ) ∗ ( a % p i % p ) ( M o d   p )
  • 不难发现 (ai) ( a i ) 不是 p p 的倍数当且仅当 i i p p 进制表示每一位都不超过 a a p p 进制表示的对应位。
  • 进制转换即可。
  • 时间复杂度 O(|N|2+PLog2P) O ( | N | 2 + P L o g 2 P )

【代码】


#include

using namespace std;
const int MAXN = 262144;
const int MAXLOG = 30;
const int P = 1e9 + 7;
template <typename T> void chkmax(T &x, T y) {x = max(x, y); }
template <typename T> void chkmin(T &x, T y) {x = min(x, y); } 
template <typename T> void read(T &x) {
  x = 0; int f = 1;
  char c = getchar();
  for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
  for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
  x *= f;
}
template <typename T> void write(T x) {
  if (x < 0) x = -x, putchar('-');
  if (x > 9) write(x / 10);
  putchar(x % 10 + '0');
}
template <typename T> void writeln(T x) {
  write(x);
  puts("");
}
namespace FFT {
  const int MAXN = 262144;
  const long double pi = acosl(-1);
  struct point {long double x, y; };
  point operator + (point a, point b) {return (point) {a.x + b.x, a.y + b.y}; }
  point operator - (point a, point b) {return (point) {a.x - b.x, a.y - b.y}; }
  point operator * (point a, point b) {return (point) {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x}; }
  point operator / (point a, long double x) {return (point) {a.x / x, a.y / x}; }
  int N, Log, home[MAXN];
  point tmp[MAXN];
  void FFTinit() {
      for (int i = 0; i < N; i++) {
          int tmp = i, ans = 0;
          for (int j = 1; j <= Log; j++) {
              ans <<= 1;
              ans += tmp & 1;
              tmp >>= 1;
          }
          home[i] = ans;
      }
  }
  void FFT(point *a, int mode) {
      for (int i = 0; i < N; i++)
          if (home[i] < i) swap(a[i], a[home[i]]);
      for (int len = 2; len <= N; len <<= 1) {
          point delta = (point) {cosl(2 * pi / len * mode), sinl(2 * pi / len * mode)};
          for (int i = 0; i < N; i += len) {
              point now = (point) {1, 0};
              for (int j = i, k = i + len / 2; k < i + len; j++, k++) {
                  point tmp = a[j];
                  point tnp = a[k] * now;
                  a[j] = tmp + tnp;
                  a[k] = tmp - tnp;
                  now = now * delta;
              }
          }
      }
      if (mode == -1) {
          for (int i = 0; i < N; i++)
              a[i] = a[i] / (4 * N);
      }
  }
  void times(int *a, int *b, int *c, int limit, int p) {
      N = 1, Log = 0;
      while (N < 2 * limit) {
          N <<= 1;
          Log++;
      }
      for (int i = 0; i < limit; i++)
          tmp[i] = (point) {(long double) (a[i] + b[i]), (long double) (a[i] - b[i])};
      for (int i = limit; i < N; i++)
          tmp[i] = (point) {0, 0};
      FFTinit();
      FFT(tmp, 1);
      for (int i = 0; i < N; i++)
          tmp[i] = tmp[i] * tmp[i];
      FFT(tmp, -1);
      for (int i = 0; i < N; i++)
          c[i] = (long long) (tmp[i].x + 0.5) % p;
  }
}
char s[MAXN];
int a[MAXN], f[MAXLOG][MAXN];
int n, len, p, bits[MAXN];
int modulo() {
  int r = 0;
  for (int i = len; i >= 1; i--)
      r = (r * 10 + a[i]) % p;
  return r;
}
void divide() {
  int r = 0;
  for (int i = len; i >= 1; i--) {
      r = r * 10 + a[i];
      a[i] = r / p;
      r %= p;
  }
  while (len && a[len] == 0) len--;
}
void work(int l, int r, int depth) {
  int len = r - l + 1;
  for (int i = 0; i <= 2 * len; i++)
      f[depth][i] = 0;
  if (l == r) {
      f[depth][0] = 1;
      if (l != 0) f[depth][1] = l;
      return;
  }
  int mid = (l + r) / 2;
  work(l, mid, depth);
  work(mid + 1, r, depth + 1);
  int lim = max(mid - l + 1, r - mid) + 1;
  FFT :: times(f[depth], f[depth + 1], f[depth], lim, p);
}
int main() {
  int T; read(T);
  while (T--) {
      scanf("\n%s", s + 1); read(p);
      len = strlen(s + 1);
      reverse(s + 1, s + len + 1);
      for (int i = 1; i <= len; i++)
          a[i] = s[i] - '0';
      n = 0;
      while (len != 0) {
          bits[++n] = modulo();
          divide();
      }
      bits[n + 1] = 0;
      if (bits[1] == p - 1) {
          bits[1] = 0;
          for (int i = 2; true; i++) {
              chkmax(n, i);
              if (++bits[i] == p) bits[i] = 0;
              else break;
          }
      }
      int ans = 1;
      for (int i = 2; i <= n; i++)
          ans = ans * (bits[i] + 1ll) % P;
      work(0, bits[1], 0);
      int tans = 0;
      for (int i = 0; i <= bits[1]; i++)
          if (f[0][i]) tans++;
      writeln(1ll * tans * ans % P);
  }
  return 0;
}

你可能感兴趣的:(【OJ】CodeChef,【类型】做题记录,【算法】FFT与NTT,【算法】找规律,【算法】高精度,【算法】Lucas定理,【算法】数学)