题目:http://acm.sgu.ru/problem.php?contest=0&problem=261
题意:给定素数 p 以及整数 n、a，求同余方程 x^n ≡ a (mod p) 在模 p 意义下的所有解。
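说明:先求出模 p 的一个原根 g；再用大步小步（BSGS）算法求出离散对数 m，使 g^m ≡ a (mod p)。令 x = g^y，原方程化为线性同余方程 n·y ≡ m (mod p−1)。用扩展欧几里得求出它的全部 d = gcd(n, p−1) 个解 y（无解当且仅当 d 不整除 m），再还原出 x = g^y 并排序输出。a ≡ 0 (mod p) 时唯一解为 x = 0。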
代码:
/* ID: [email protected] PROG: LANG: C++ */
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<cmath>
#include<cstdio>
#include<vector>
#include<string>
#include<utility>
#include<fstream>
#include<cstring>
#include<ctype.h>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;

// Fast modular exponentiation: returns a^n mod m for n >= 0, a >= 0.
// Intermediate products are < m^2, so m must stay below ~3.03e9 to avoid
// signed overflow in ll.
ll pow_mod(ll a, ll n, ll m) {
    ll res = 1;
    a %= m;
    while (n) {
        if (n & 1) res = res * a % m;
        n >>= 1;
        a = a * a % m;
    }
    return res;
}

// Distinct prime factors of p - 1; filled by primitive_root(), read by g_test().
vector<ll> prime_factors;

// Returns true if g is a primitive root modulo the prime p, i.e. for every
// prime factor q of p - 1, g^((p-1)/q) != 1 (mod p).
// Requires prime_factors to already hold the prime factors of p - 1.
bool g_test(ll g, ll p) {
    for (size_t i = 0; i < prime_factors.size(); i++)
        if (pow_mod(g, (p - 1) / prime_factors[i], p) == 1) return false;
    return true;
}

// Returns the smallest primitive root modulo the prime p (1 when p == 2).
ll primitive_root(ll p) {
    prime_factors.clear();
    ll tmp = p - 1;
    // Trial division; could be sped up by dividing only by sieved primes.
    for (ll i = 2; i <= tmp / i; i++)
        if (tmp % i == 0) {
            prime_factors.push_back(i);
            while (tmp % i == 0) tmp /= i;
        }
    if (tmp != 1) prime_factors.push_back(tmp);
    for (ll g = 1; ; g++)
        if (g_test(g, p)) return g;
}

// Baby-step giant-step: returns some e >= 0 with x^e ≡ n (mod m), or -1 if
// no such exponent exists. m must be prime, because x^s is inverted with
// Fermat's little theorem.
ll discrete_log(ll x, ll n, ll m) {
    n %= m;
    ll s = (ll)sqrt((double)m + 0.5);
    while (s * s <= m) s++;  // s > sqrt(m), so exponents i*s + j cover [0, m)

    // Baby steps: (x^j, j) for 0 <= j < s. Sorting by (value, exponent) makes
    // lower_bound return the smallest exponent for a given value.
    vector<pair<ll, ll> > baby(s);
    ll cur = 1;
    for (ll j = 0; j < s; j++) {
        baby[j] = make_pair(cur, j);
        cur = cur * x % m;
    }
    sort(baby.begin(), baby.end());

    ll mul = pow_mod(cur, m - 2, m);  // cur == x^s here, so mul = x^(-s)

    // Giant steps: look for n * x^(-i*s) among the baby steps.
    cur = 1;
    for (ll i = 0; i < s; i++) {
        ll target = n * cur % m;
        vector<pair<ll, ll> >::iterator it =
            lower_bound(baby.begin(), baby.end(), make_pair(target, -1LL));
        // Guard against "not found" (it == end) before dereferencing.
        if (it != baby.end() && it->first == target) return i * s + it->second;
        cur = cur * mul % m;
    }
    return -1;
}

// Extended Euclid: returns gcd(a, b) and sets x, y so that a*x + b*y = gcd(a, b).
ll extend_gcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1, y = 0;
        return a;
    }
    ll r = extend_gcd(b, a % b, y, x);
    y -= x * (a / b);
    return r;
}

// n-th residue: returns all x in [0, p) with x^n ≡ a (mod p), sorted ascending.
// p must be prime; n is assumed >= 1.
vector<ll> residue(ll p, ll n, ll a) {
    vector<ll> ret;
    // Only the residue class of a matters; normalize it into [0, p).
    a %= p;
    if (a < 0) a += p;
    if (a == 0) {
        ret.push_back(0);  // x^n ≡ 0 (mod prime p) forces x ≡ 0
        return ret;
    }
    ll g = primitive_root(p);
    ll m = discrete_log(g, a, p);  // a ≡ g^m (mod p)
    if (m == -1) return ret;

    // Writing x = g^y turns the equation into n*y ≡ m (mod p - 1).
    ll B = p - 1, x, y;
    ll d = extend_gcd(n, B, x, y);
    if (m % d != 0) return ret;
    x = x * (m / d) % B;  // one particular solution (may be negative)
    ll delta = B / d;     // solutions repeat every delta, giving d of them mod B
    for (ll i = 0; i < d; i++) {
        x = ((x + delta) % B + B) % B;
        ret.push_back(pow_mod(g, x, p));
    }
    sort(ret.begin(), ret.end());
    ret.erase(unique(ret.begin(), ret.end()), ret.end());
    return ret;
}

int main() {
    ll n, a, p;
    // Input order: p, n, a.
    if (scanf("%lld%lld%lld", &p, &n, &a) != 3) return 0;
    vector<ll> ret = residue(p, n, a);
    printf("%d\n", (int)ret.size());
    for (size_t i = 0; i < ret.size(); i++) printf("%lld ", ret[i]);
    if (!ret.empty()) printf("\n");
    return 0;
}