matrix

题意

定义一个矩阵价值为它的不相同的行的个数
给出n*m大小的矩阵,求它的所有子矩阵的价值

题解

这个问题相当于对于每个(p,S)(p为左端点所在列,S为一个字符串(em…这里跳了一步,我们可以把数字序列看成字符串))在多少个(x,y)中满足 ∃ z ∈ [ x , y ] \exist z ∈[x,y] z[x,y],从z行p列开始的字符串和S相同
对于p=1,我们可以这样,将这N行看做是N个字符串,然后插入一个trie树中,这个trie树ch指针用map保存(因为字符集大小太大),并且用一个set存这个节点有多少个行有相同的信息,每次插入的时候,这一行的贡献就是他和他的前驱之间的距离乘上他和他的后继之间的距离
感性理解一下,这样就可以覆盖所有的对数了
并且存下这个节点的答案贡献
对于p=2,我们将trie树的第一层节点合并,那么像0 0 1 ;1 0 0这样的第二层就会变成相同的,于是我们就把这样相同的节点合并。我们使用启发式合并,每次节点数都可以减少一半,就可以在log内完成了,每次合并的时候也是暴力地把两个set小的合进大的,同时计算新的贡献,再减去原来的贡献,我们就得到了现在答案相比第一次答案的变化量,对于p=3也是同理的
p=1时 O ( n m l o g 2 n ) O(nmlog^2n) O(nmlog2n),之后每列都是 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
综上总时间复杂度 O ( n m l o g 2 n ) O(nmlog^2n) O(nmlog2n)

#include
#include
#include
#include
#include
using namespace std;
typedef long long ll;
typedef set<int>::iterator sit;
typedef map<int,int>::iterator mit;
const int N=5e5+5;
int n,m;
int tcnt=1,rt=1;
ll now;
struct node{
    ll val;
    map<int,int>ch;
    set<int>*s;
    node(){s=new set<int>();}
    void Insert(int x){
        int l=x+1,r=n-x;
        sit it=s->insert(x).first,it1;
        if(it!=s->begin()){
            it1=it;
            it1--;
            l=x-*it1;
        }
        it1=it;
        it1++;
        if(it1!=s->end()){
            r=*it1-x;
        }
        //printf("%d\n",l*r);
        val+=1ll*l*r;
        now+=1ll*l*r;
    }
 }trie[N];
int Merge(int x,int y){
    if(!x) return y;
    if(!y) return x;
    if(trie[x].ch.size()<trie[y].ch.size())
        swap(x,y);
    for(mit it=trie[y].ch.begin();it!=trie[y].ch.end();it++){
        trie[x].ch[it->first]=Merge(trie[x].ch[it->first],it->second);
    }
    if(trie[x].s->size()<trie[y].s->size()){
        swap(trie[x].s,trie[y].s);
        swap(trie[x].val,trie[y].val);
    }
    for(sit it=trie[y].s->begin();it!=trie[y].s->end();it++){
        trie[x].Insert(*it);
    }
    now-=trie[y].val;
    return x;
}
void Read(int &x){
    x=0;
    char c=getchar();
    while(c<'0'||c>'9')
        c=getchar();
    while(c>='0'&&c<='9')
        x=x*10+c-'0',c=getchar();
}
int main()
{
    Read(n),Read(m);
    //scanf("%d%d",&n,&m);
    for(int i=0;i<n;i++){
        int p=rt;
        for(int j=0;j<m;j++){
            int x;
            Read(x);
            //scanf("%d",&x);
            if(trie[p].ch.find(x)==trie[p].ch.end())
                trie[p].ch[x]=++tcnt;
            p=trie[p].ch[x];
            trie[p].Insert(i);
        }
    }
    ll ans=0;
    ans+=now;
    for(int i=1;i<m;i++){
        int nrt=0;
        for(mit it=trie[rt].ch.begin();it!=trie[rt].ch.end();it++){
            nrt=Merge(nrt,it->second);
        }
        rt=nrt;
        now-=trie[rt].val;
        ans+=now;
    }
    printf("%lld\n",ans);
}

你可能感兴趣的:(数据结构,字符串)