【codeforces】528D. Fuzzy Search【FFT】

传送门:【codeforces】528D. Fuzzy Search

题目分析:

首先,我们先来了解一下 FFT 求字符串匹配的方法。

问题:求 B 串在 A 串中匹配的次数。

解法:
我们用 ai 表示串 A i 个位置的字符, bj 表示串 B j 个位置的字符。

A 串用多项式表示为:
A=ni=1xi=ni=1A(i)

B 串用多项式表示为:
B=mi=1yi=mi=1B(i)

A 串从第 l 位开始和 B 串匹配的贡献是:
Cost(l+m)=mi=1{ai=bi}A(l+(i1))B(m(i1))

这正是 FFT 可以做的事情(其中我们需要把其中一个串翻转,这样匹配过程正是一个卷积,为了方便,我们一般翻转模式串)。

我们考虑枚举字符集中每个字符,每个字符求一次 FFT ,可以得到这个字符对匹配的贡献。

然后我们看所有字符对某个位置的贡献之和是否为 B 串串长 m ,是的话则匹配次数加一。

检查完所有位置后, B 串在 A 串中匹配的次数以及匹配的位置我们也可以知道了。

回到本题。

考虑主串每一个位置放 A,C,G,T 四个字母中的其中一个时能否找到一个匹配的位置。假设位置 i 上放 A ,那么我们看 [ik+1,i+k1] 这个区间内是否有 A ,如果有则说明这个位置能和 A 匹配。

接下来我们枚举字符 A,C,G,T ,用 FFT 求一下每个字符对匹配的贡献,最后枚举位置累加一下所有字符对该位置的贡献,如果四个字符的匹配贡献之和等于模式串长度(说明匹配了),则答案加一。

my  code:

#include <stdio.h>
#include <string.h>
#include <set>
#include <map>
#include <math.h>
#include <vector>
#include <algorithm>
using namespace std ;

typedef long long LL ;

#define clr( a , x ) memset ( a , x , sizeof a )
#define cpy( a , x ) memcpy ( a , x , sizeof a )
#define clrs( a , x , size ) memset ( a , x , sizeof ( a[0] ) * ( size ) )
#define cpys( a , x , size ) memcpy ( a , x , sizeof ( a[0] ) * ( size ) )

const int MAXN = 600000 ;
const double pi = acos ( -1.0 ) ;

struct Complex {
    double r , i ;
    Complex () {}
    Complex ( double r , double i ) : r ( r ) , i ( i ) {}
    Complex operator + ( const Complex& p ) const {
        return Complex ( r + p.r , i + p.i ) ;
    }
    Complex operator - ( const Complex& p ) const {
        return Complex ( r - p.r , i - p.i ) ;
    }
    Complex operator * ( const Complex& p ) const {
        return Complex ( r * p.r - i * p.i , r * p.i + i * p.r ) ;
    }
} ;

Complex x1[MAXN] , x2[MAXN] ;
char s[MAXN] , p[MAXN] ;
int num[MAXN] ;
int pre[4] ;
int nxt[4] ;
int can[MAXN][4] ;
int cost[MAXN] ;
int n1 , n2 , k ;

void FFT ( Complex y[] , int n , int rev ) {
    for ( int i = 1 , j , k , t ; i < n ; ++ i ) {
        for ( j = 0 , t = i , k = n >> 1 ; k ; k >>= 1 , t >>= 1 ) {
            j = j << 1 | t & 1 ;
        }
        if ( i < j ) swap ( y[i] , y[j] ) ;
    }
    for ( int s = 2 , ds = 1 ; s <= n ; ds = s , s <<= 1 ) {
        Complex wn ( cos ( rev * 2 * pi / s ) , sin ( rev * 2 * pi / s ) ) ;
        for ( int k = 0 ; k < n ; k += s ) {
            Complex w ( 1 , 0 ) , t ;
            for ( int i = k ; i < k + ds ; ++ i ) {
                y[i + ds] = y[i] - ( t = w * y[i + ds] ) ;
                y[i] = y[i] + t ;
                w = w * wn ;
            }
        }
    }
    if ( rev < 0 ) {
        for ( int i = 0 ; i < n ; ++ i ) {
            y[i].r /= n ;
        }
    }
}

int get_idx ( char c ) {
    if ( c == 'A' ) return 0 ;
    if ( c == 'C' ) return 1 ;
    if ( c == 'G' ) return 2 ;
    if ( c == 'T' ) return 3 ;
}

void calc ( int n1 , int n2 , int n , int x ) {
    for ( int i = 0 ; i < n ; ++ i ) {
        x1[i] = Complex ( can[i][x] , 0 ) ;
        x2[i] = Complex ( i < n2 ? get_idx ( p[n2 - i - 1] ) == x : 0 , 0 ) ;
    }
    FFT ( x1 , n , 1 ) ;
    FFT ( x2 , n , 1 ) ;
    for ( int i = 0 ; i < n ; ++ i ) {
        x1[i] = x1[i] * x2[i] ;
    }
    FFT ( x1 , n , -1 ) ;
    for ( int i = 0 ; i < n ; ++ i ) {
        cost[i] += ( int ) ( x1[i].r + 0.5 ) ;
    }
}

void solve () {
    int n = 1 ;
    while ( n < n1 + n2 - 1 ) n <<= 1 ;
    clr ( pre , -1 ) ;
    clr ( nxt , -1 ) ;
    clr ( can , 0 ) ;
    clr ( cost , 0 ) ;
    for ( int i = 0 ; i < n1 ; ++ i ) {
        num[i] = get_idx ( s[i] ) ;
        int x = num[i] ;
        can[i][x] = 1 ;
        for ( int j = 0 ; j < 4 ; ++ j ) {
            if ( ~pre[j] && i - pre[j] <= k ) can[i][j] = 1 ;
        }
        pre[x] = i ;
    }
    for ( int i = n1 - 1 ; i >= 0 ; -- i ) {
        int x = num[i] ;
        for ( int j = 0 ; j < 4 ; ++ j ) {
            if ( ~nxt[j] && nxt[j] - i <= k ) can[i][j] = 1 ;
        }
        nxt[x] = i ;
    }
    for ( int i = 0 ; i < 4 ; ++ i ) {
        calc ( n1 , n2 , n , i ) ;
    }
    int ans = 0 ;
    for ( int i = n2 - 1 ; i < n1 ; ++ i ) {
        if ( cost[i] == n2 ) {
            ++ ans ;
        }
    }
    printf ( "%d\n" , ans ) ;
}

int main () {
    while ( ~scanf ( "%d%d%d%s%s" , &n1 , &n2 , &k , s , p ) ) solve () ;
    return 0 ;
}

你可能感兴趣的:(fft)