欧几里德算法又称辗转相除法,用于计算两个整数a,b的最大公约数。
基本算法:设a=qb+r
,其中a,b,q,r
都是整数,则gcd(a,b)=gcd(b,r)
,即gcd(a,b)=gcd(b,a%b)
。
算法的实现:
最简单的方法就是应用递归算法,代码如下:
int gcd(int a,int b)
{
if(b==0)
return a;
return
gcd(b,a%b);
}
代码可优化如下:
int gcd(int a,int b)
{
return b ? gcd(b,a%b) : a;
}
当然你也可以用迭代形式:
int Gcd(int a, int b)
{
while(b != 0)
{
int r = b;
b = a % b;
a = r;
}
return a;
}
扩展的欧几里得算法用于计算满足形如a*x+b*y=c
的方程的解
首先,我们需要先判断这个方程有没有解
对于形如a*x+b*y=c
的方程,有解的条件是c是a和b最大公约数的倍数
即c = n*gcd(a, b)(n > 0)
关于然后计算gcd(a, b),上面有说明,这里不再赘述
在判定有解之后,我们需要计算出一组满足条件的(x, y),由于a, b, c都是gcd(a, b)的整数倍,我们可以将它们都缩小gcd(a, b)倍,即A=a/gcd(a,b), B=b/gcd(a,b), C=c/gcd(a,b);
化简为A*x+B*y=C,而且gcd(A, B) == 1
此时,我们可以先求A*x+B*y=1的解(x’, y’),然后将其扩大C倍,最后要求的解就是(x, y)=(C*x’, C*y’)
接下来就是研究如何解A*x+B*y=1
假设A>B>0,我们设
A*x[1]+B*y[1]=gcd(A, B); B*x[2]+(A mod B)*y[2]=gcd(B, A mod B)
已知gcd(A, B)==gcd(B, A mod B)
A * x[1] + B * y[1] = B * x[2] + (A mod B) * y[2] => A * x[1] + B * y[1] = B * x[2] + (A - kB) * y[2] // A = kB + r => A * x[1] + B * y[1] = A * y[2] + B * x[2] - kB * y[2] => A * x[1] + B * y[1] = A * y[2] + B * (x[2] - ky[2]) => x[1] = y[2], y[1] = (x[2] - ky[2])
利用这个性质,我们可以递归的去求解(x,y)。
其终止条件为gcd(A, B)=B,此时对应的(x,y)=(0,1)
对应代码如下
pair<long long, long long > extend_gcd(long long a, long long b){
pair<long long, long long> tmp;
if (a%b == 0){
return pair<long long , long long>(0, 1);
}
tmp = extend_gcd(b, a%b);
long long tmp_x = tmp.first, tmp_y = tmp.second;
x = tmp_y;
y = tmp_x-(a/b)*tmp_y;
return pair<long long , long long>(x, y);
}
这样我们就能求出一组满足条件的解(x, y)
但如果要保证x是最小非负数,我们又该需要怎么做呢
如果那样的话,我们需要将(A,B,x’,y’)扩充为一个解系。
由于A B是互质的,所以可以将A*x’+B*y’=1扩展为:
Ax'+By'+(u+(-u))AB=1 => (x' + uB)*A + (y' - uA)*B = 1 => X = x' + uB, Y = y' - uA
可以求得最小的X为(x’+uB) mod B,(x’+uB>0)
同时我们还需要将X扩大C倍,因此最后解为:
x = (x'*C') mod B'
若x<0,则不断累加B,直到x>0为止。
总体代码如下
long long solve(long long s1, long long s2, long long v1, long long v2, long long m){
long long a = v1-v2;
long long b = m;
long long c = s2-s1;
if (a < 0){
a += m;
}
long long d = gcd(a, b);
if (c%d){
return -1;
}
a /= d;
b /= d;
c /= d;
pair<long long, long long > tmp = extend_gcd(a, b);
long long x = tmp.first;
x = (x*c)%b;
while(x < 0) x += b;
return x;
}
#include <iostream>
#include <cstdio>
using namespace std;
long long x, y;
long long gcd(long long a, long long b){
if (!b) return a;
return gcd(b, a%b);
}
pair<long long, long long > extend_gcd(long long a, long long b){
pair<long long, long long> tmp;
if (a%b == 0){
return pair<long long , long long>(0, 1);
}
tmp = extend_gcd(b, a%b);
long long tmp_x = tmp.first, tmp_y = tmp.second;
x = tmp_y;
y = tmp_x-(a/b)*tmp_y;
return pair<long long , long long>(x, y);
}
long long solve(long long s1, long long s2, long long v1, long long v2, long long m){
long long a = v1-v2;
long long b = m;
long long c = s2-s1;
if (a < 0){
a += m;
}
long long d = gcd(a, b);
if (c%d){
return -1;
}
a /= d;
b /= d;
c /= d;
pair<long long, long long > tmp = extend_gcd(a, b);
long long x = tmp.first;
x = (x*c)%b;
while(x < 0) x += b;
return x;
}
int main(){
long long s1, s2, v1, v2, m, k;
//s1+v1*t=s2+v2*t-k*m
cin >> s1 >> s2 >> v1 >> v2 >> m;
cout << solve(s1, s2, v1, v2, m);
return 0;
}