同余最短路常用于解决这样一类问题:
有 n n n个正整数 a 1 , a 2 , a 3 , ⋯ , a n a_1, a_2,a_3 , \cdots, a_n a1,a2,a3,⋯,an,设:
x 1 a 1 + x 2 a 2 + x 3 a 3 + ⋯ + x n a n = k ( x 1 , x 2 , x 3 , ⋯ , x n ∈ N ) x_1a_1+x_2a_2+x_3a_3+\cdots+x_na_n=k\;\;\;\;\;\;\;\;(x_1,x_2,x_3,\cdots,x_n\in N) x1a1+x2a2+x3a3+⋯+xnan=k(x1,x2,x3,⋯,xn∈N)
即 使用无限个 a 1 , a 2 , a 3 , ⋯ , a n a_1,a_2,a_3,\cdots,a_n a1,a2,a3,⋯,an进行拼凑,然后对 k k k的可能值进行各种询问(例如询问 k k k在 [ l , r ] [l,r] [l,r]区间内的可能值个数,询问 k k k最大的无法拼凑的值,询问某个定值 k k k能否被拼凑出……)
取其中一个正整数作为剩余系,例如取 a 1 a_1 a1,设 f ( i ) = min { k ∣ k m o d a 1 = i } f(i)=\min\{k | k \bmod a_1 = i\} f(i)=min{k∣kmoda1=i} ,其中 i = 0 , 1 , 2 , ⋯ , a 1 − 1 i =0,1,2,\cdots,a_1-1 i=0,1,2,⋯,a1−1, f ( i ) f(i) f(i)为即 可拼凑出的 使得 k m o d a 1 = i k \bmod a_1 = i kmoda1=i 的 最小的 k k k。
首先根据定义易知: f ( ( i + a j ) m o d a 1 ) = min { f ( i ) + a j } f((i+a_j)\bmod a_1 )=\min\{f(i)+a_j\} f((i+aj)moda1)=min{f(i)+aj} 且 f ( 0 ) = 0 f(0)=0 f(0)=0,可以发现这和最短路的求法相同,
故可建图 G = { V , E } G=\{V,E\} G={V,E},其中 V = { 0 , 1 , 2 , ⋯ , a 1 − 1 } V=\{0,1,2,\cdots,a_1-1\} V={0,1,2,⋯,a1−1}, E = { < u , v , w > ∣ ( u + a j ) m o d a 1 = v ∧ w = a j } E=\{|(u+a_j) \bmod a_1 = v \land w=a_j\} E={<u,v,w>∣(u+aj)moda1=v∧w=aj}
然后从点 0 0 0开始,求一遍最短路,到点 i i i的最短路距离即为 f ( i ) 。 f(i)。 f(i)。
求得所有的 f ( i ) f(i) f(i)后,
若 n n n个正整数有重复值,则最好要去重,且应当取最小的 a i a_i ai作为剩余系,这样图更小,求最短路的时间复杂度也更优。
结点数 ∣ V ∣ = min i = 1 n { a i } |V| = \min \limits_{i=1}^n \{a_i\} ∣V∣=i=1minn{ai},边数 ∣ E ∣ = n × ∣ V ∣ |E|=n\times |V| ∣E∣=n×∣V∣,利用堆优化Dijkstra算法求最短路时间复杂度为 O ( ∣ E ∣ + ∣ V ∣ log ∣ V ∣ ) O(|E|+|V|\log|V|) O(∣E∣+∣V∣log∣V∣)。
洛谷P2371 墨墨的等式
#include
using namespace std;
typedef long long LL;
typedef pair<LL, int> PLI;
const int INF = 0x3f3f3f3f;
const int LINF = 0x3f3f3f3f3f3f3f3f;
const int maxn = 5e5 + 10;
int n, a[maxn];
LL l, r, dist[maxn];
struct edge
{
int v;
int w;
};
vector<edge> g[maxn];
bool vis[maxn];
void Dijkstra(int s)
{
memset(dist, 0x3f, sizeof(dist));
priority_queue<PLI, vector<PLI>, greater<PLI> > q;
dist[s] = 0;
q.push(make_pair(dist[s], s));
while(!q.empty())
{
int u = q.top().second;
q.pop();
if(vis[u])
continue;
else
vis[u] = true;
for(auto e : g[u])
{
int v = e.v, w = e.w;
if(!vis[v] && dist[u] + w < dist[v])
{
dist[v] = dist[u] + w;
q.push(make_pair(dist[v], v));
}
}
}
}
int main()
{
scanf("%d %lld %lld", &n, &l, &r);
for(int i = 1; i <= n; i++)
scanf("%d", &a[i]);
sort(a + 1, a + n + 1);
n = unique(a + 1, a + n + 1) - (a + 1);
for(int i = 0; i < a[1]; i++)
for(int j = 2; j <= n; j++)
g[i].push_back(edge{(i + a[j]) % a[1], a[j]});
Dijkstra(0);
LL ans = 0;
for(int i = 0; i < a[1]; i++)
{
if(r >= dist[i])
ans += (r- dist[i]) / a[1] + 1;
if(l - 1 >= dist[i])
ans -= ((l - 1) - dist[i]) / a[1] + 1;
}
printf("%lld\n", ans);
return 0;
}