n < = 1 0 18 n<=10^{18} n<=1018,保证 g c d ( a , b ) = 1 gcd(a,b)=1 gcd(a,b)=1
考虑计算大于等于 t t t的第一个满足 f ( x , k ) = s f(x,k)=s f(x,k)=s的数
我们可以从高位到低位贪心去填,每一位从零开始依次尝试,看剩余位用剩余数字去填能否大于等于 t t t
我们设这个过程叫 n e x t ( t , k , s ) next(t,k,s) next(t,k,s)
因为答案不会超过 1 e 18 1e18 1e18,经过 x j b xjb xjb计算后得出 s s s 最多只有 400 400 400
单次 n e x t next next操作的复杂度可以做到 O ( l o g n ) O(logn) O(logn)
令 g [ s ] g[s] g[s] 表示大于等于 n n n的满足 f ( x , a ) = f ( x , b ) = s f(x,a)=f(x,b)=s f(x,a)=f(x,b)=s 的整数 x x x
初始时 对于所有 s s s, g [ s ] = n g[s]=n g[s]=n
然后每次找最小的 g [ s ] = t g[s]=t g[s]=t,令 w = m a x ( n e x t ( t , a , s ) , n e x t ( t , b , s ) ) w=max(next(t,a,s),next(t,b,s)) w=max(next(t,a,s),next(t,b,s))
如果 t = w t=w t=w那 t t t就是一个合法答案,显然这种做法下这是最优答案,直接输出
否则将 g [ s ] g[s] g[s]设为 w w w,重复上述过程
为什么这样做是可以的?
我们考虑对于一个区间 [ l , r ] [l,r] [l,r]满足 f ( x , a ) = s f(x,a)=s f(x,a)=s的数和 f ( x , b ) = s f(x,b)=s f(x,b)=s数集,将他们排序之后会形如 a , b , a , a , a , b , b , a , a a,b,a,a,a,b,b,a,a a,b,a,a,a,b,b,a,a等,发现求 n e x t next next的过程相当于求有多少 a , b a,b a,b交替出现,如上例就变成了 a , b , a , b , a a,b,a,b,a a,b,a,b,a,只进行了 5 5 5次 n e x t next next操作
显然交替次数的级别不会高于 a , b a,b a,b任意一种数的出现的次数,由于进位的存在使得这两种数出现趋向随机,而根据生日悖论,出现矛盾的位置一般不超过出现次数的平方,所以寻找次数的上界显然为 O ( n ) O(\sqrt n) O(n)
于是总的复杂度上界约为 O ( s n l o g n ) O(s\sqrt nlogn) O(snlogn)
看起来过不去
但是对于不同的 s s s它的寻找次数会有较大差异,对于有些 s s s寻找次数会非常少,加上求 n e x t next next的常数较小,最终实现起来的速度是很快的
#include
using namespace std;
typedef long long ll;
typedef pair<ll,int> pi;
#define mp make_pair
#define fi first
#define se second
priority_queue<pi,vector<pi>,greater<pi> >q;
pi x;
ll xx,n,ans,tmp;
int cnt,sum,a[70],b[70],ss,i,A,B;
ll ne(ll x,int k,int s){
memset(b,0,sizeof(b));
xx=x;cnt=sum=0;
while (xx) sum+=(a[++cnt]=xx%k),xx/=k;
if (sum==s) return x;cnt++;
a[cnt]=0;
for (int i=cnt;i;i--){
s-=a[i],b[i]=a[i],ss=s;
int j=i-1;
for (;j;j--){
if (ss>a[j] && a[j]<k-1) break;
if (ss<a[j]){
j=0;
break;
}
ss-=a[j];
}
if (!j){
s--,b[i]++;
break;
}
}
for (int i=1;s;i++){
if (i>cnt) return 1e18;
if (s>=k-1-b[i]) s-=k-1-b[i],b[i]=k-1;
else b[i]+=s,s=0;
}
ans=0;
for (int i=cnt;i;i--) ans=ans*k+b[i];
return ans;
}
int main(){
scanf("%lld%d%d",&n,&A,&B);
if (!n) return puts("0"),0;
for (i=1;i<=400;i++) q.push(mp(max(ne(n,A,i),ne(n,B,i)),i));
while (1){
x=q.top(),q.pop();
tmp=max(ne(x.fi,A,x.se),ne(x.fi,B,x.se));
if (x.fi==tmp) return printf("%lld",tmp),0;
q.push(mp(tmp,x.se));
}
}