You are given two huge binary integers a and b of lengths n and m respectively. You repeat the following process: if b > 0, add the value a & b to the answer and divide b by 2, rounding down (i.e. remove the last digit of b), then repeat the process; otherwise stop.
The value a & b means the bitwise AND of a and b. Your task is to calculate the answer modulo 998244353.
Note that you should add the value a & b to the answer in decimal notation, not in binary, so the answer itself is computed in decimal notation. For example, if a = 1010 in binary (10 in decimal) and b = 1000 in binary (8 in decimal), then a & b equals 8, not 1000.
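To make the process concrete, here is a minimal brute-force sketch in C++ that follows the statement literally. The function name `bruteForce` is hypothetical, and it assumes both inputs are binary strings with the most significant digit first. It recomputes a & b modulo 998244353 after every shift, so it runs in O(m · min(n, m)) time and is only suitable for small inputs or for cross-checking a faster solution.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>

// Hypothetical helper (not from the original): simulates the process directly.
// a and b are binary strings, most significant digit first, each starting with '1'.
// b is taken by value because it is shortened as the process runs.
long long bruteForce(const std::string& a, std::string b) {
    const long long MOD = 998244353;
    assert(!a.empty() && !b.empty());
    long long ans = 0;
    // Since b starts with '1', b > 0 exactly while it still has digits.
    while (!b.empty()) {
        long long pw = 1;  // 2^k mod MOD for bit position k (k = 0 is the last digit)
        std::size_t len = std::min(a.size(), b.size());
        for (std::size_t k = 0; k < len; ++k) {
            // Bits with the same weight 2^k are aligned from the right.
            if (a[a.size() - 1 - k] == '1' && b[b.size() - 1 - k] == '1')
                ans = (ans + pw) % MOD;
            pw = pw * 2 % MOD;
        }
        b.pop_back();  // divide b by 2, rounding down
    }
    return ans;
}
```

For example, `bruteForce("1010", "1101")` returns 12 and `bruteForce("1001", "10101")` returns 11, matching the two samples.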
The first line of the input contains two integers n and m (1 ≤ n, m ≤ 200000), the lengths of a and b respectively.
The second line of the input contains one huge integer a. It is guaranteed that this number consists of exactly n zeroes and ones and that its first digit is 1.
The third line of the input contains one huge integer b. It is guaranteed that this number consists of exactly m zeroes and ones and that its first digit is 1.
Print the answer to this problem in decimal notation modulo 998244353 .
Input
4 4
1010
1101
Output
12
Input
4 5
1001
10101
Output
11
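The algorithm for the first example:
- add 1010 & 1101 = 1000 (binary) = 8 to the answer and set b := 110;
- add 1010 & 110 = 10 (binary) = 2 to the answer and set b := 11;
- add 1010 & 11 = 10 (binary) = 2 to the answer and set b := 1;
- add 1010 & 1 = 0 to the answer and set b := 0.
So the answer is 8 + 2 + 2 + 0 = 12.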
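The algorithm for the second example:
- add 1001 & 10101 = 1 (binary) = 1 to the answer and set b := 1010;
- add 1001 & 1010 = 1000 (binary) = 8 to the answer and set b := 101;
- add 1001 & 101 = 1 (binary) = 1 to the answer and set b := 10;
- add 1001 & 10 = 0 to the answer and set b := 1;
- add 1001 & 1 = 1 (binary) = 1 to the answer and set b := 0.
So the answer is 1 + 8 + 1 + 0 + 1 = 11.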
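The samples can also be checked digit by digit of a instead of shift by shift. The derivation below uses notation introduced here, not in the original: write a = a_1 a_2 … a_n and b = b_1 b_2 … b_m with the most significant digit first, let b_j = 0 for j ≤ 0, and let S_j be the number of ones among b_1 … b_j (with S_0 = 0). After k shifts, b is b_1 … b_{m-k}, and the digit of b with the same weight 2^{n-i} as a_i is b_{m-n+i-k}. As k runs from 0 to m-1, that index covers every position from 1 to m-n+i, so each one among b_1 … b_{m-n+i} is counted exactly once (all sums taken modulo 998244353):

```latex
\begin{aligned}
\text{answer}
  &= \sum_{k=0}^{m-1} \sum_{i=1}^{n} a_i \, b_{m-n+i-k} \, 2^{\,n-i}
   = \sum_{i=1}^{n} a_i \, 2^{\,n-i} \sum_{k=0}^{m-1} b_{m-n+i-k}
   = \sum_{i=1}^{n} a_i \, 2^{\,n-i} \, S_{\max(0,\; m-n+i)} \\[4pt]
\text{Example 1 } (a = 1010,\ b = 1101):&\quad 2^{3} S_{1} + 2^{1} S_{3} = 8 \cdot 1 + 2 \cdot 2 = 12 \\
\text{Example 2 } (a = 1001,\ b = 10101):&\quad 2^{3} S_{2} + 2^{0} S_{5} = 8 \cdot 1 + 1 \cdot 3 = 11
\end{aligned}
```

With prefix counts S_j precomputed, the whole sum takes O(n + m) time.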
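#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <utility>
using namespace std;
#define pii pair<int,int>
#define pb push_back
#define mem(a,b) memset(a,b,sizeof(a))
#define per(i,a,b) for(int i=a;i<=b;i++)
#define rep(i,a,b) for(int i=a;i>=b;i--)
#define all(x) x.begin(),x.end()
#define PER(i,x) for(auto i=x.begin();i!=x.end();i++)
#define PI acos(-1.0)
#define inf 0x3f3f3f3f
typedef long long ll;
const double eps=1.0e-5;
const int maxn=200000+10;
const long long mod=998244353;
int a[maxn],b[maxn],sum[maxn],n,m; // digits of a and b, 1-based, most significant first
char s[maxn];
int main()
{
    scanf("%d%d",&n,&m);
    scanf("%s",s+1);
    per(i,1,n) a[i]=(int)(s[i]-'0');
    scanf("%s",s+1);
    per(i,1,m) b[i]=(int)(s[i]-'0');
    // sum[j] = number of ones among the first j (most significant) digits of b
    sum[0]=0;
    per(i,1,m) sum[i]=sum[i-1]+b[i];
    // cur = 2^(n-i) mod 998244353, the weight of digit a[i]
    int cur=1;
    ll ans=0;
    rep(i,n,1){
        // a[i] starts aligned with b[m-n+i]; as b is shifted right, every 1 among
        // b[1..m-n+i] passes under a[i] exactly once and adds cur to the answer
        if(a[i]==1) ans=(ans+((ll)sum[max(0,m-n+i)]*cur%mod))%mod;
        cur=(2*cur)%mod; // 2*cur < 2^31 because cur < mod
    }
    printf("%lld\n",ans);
}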
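The counting used by the solution above can also be packaged as a function on strings, which makes it easy to test against a brute-force simulation. This is a sketch: the name `andSumFast` and its string-based signature are assumptions, not part of the original code. It builds the same prefix counts (`ones` plays the role of the original's `sum`) and the same running power of two, in O(n + m) time.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical wrapper (not from the original): per-digit counting on binary strings,
// most significant digit first.
long long andSumFast(const std::string& a, const std::string& b) {
    const long long MOD = 998244353;
    const int n = static_cast<int>(a.size());
    const int m = static_cast<int>(b.size());
    assert(n >= 1 && m >= 1);
    // ones[j] = number of '1' digits among the first j (most significant) digits of b
    std::vector<int> ones(m + 1, 0);
    for (int j = 1; j <= m; ++j) ones[j] = ones[j - 1] + (b[j - 1] == '1');
    long long ans = 0;
    long long weight = 1;  // 2^(n-i) mod MOD for the 1-based digit i of a
    for (int i = n; i >= 1; --i) {
        if (a[i - 1] == '1') {
            int upto = m - n + i;  // digit of b initially aligned with a's i-th digit
            // weight < MOD and ones[upto] <= 200000, so the product fits in long long
            if (upto > 0) ans = (ans + weight * ones[upto]) % MOD;
        }
        weight = weight * 2 % MOD;
    }
    return ans;
}
```

On the samples it returns 12 and 11. For a = 1111 and b = 11 it returns 4, matching the process by hand (15 & 3 = 3, then 15 & 1 = 1).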