E. Sasha and Array 矩阵快速幂 + 线段树

时间:2019-09-12 13:09:06   收藏:0   阅读:86

E. Sasha and Array

这个题目没有特别难,需要自己仔细想想,一开始我想了一个方法,不对,而且还很复杂,然后lj提示了我一下说矩阵乘,然后再仔细想想就知道怎么写了。

这个就是直接把矩阵放到线段树里面去了。

注意优化,降低复杂度。

技术图片
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <queue>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
const int mod=1e9+7;
struct mat
{
    ll m[3][3];
}unite,zero;

mat operator*(mat a,mat b){
    mat ans;
    for(int i=1;i<=2;i++){
        for(int j=1;j<=2;j++){
            ll x = 0;
            for(int k=1;k<=2;k++) x+=(a.m[i][k]*b.m[k][j])%mod;
            ans.m[i][j]=x%mod;
        }
    }
    return ans;
}
mat operator+(mat a,mat b){
    mat ans;
    for(int i=1;i<=2;i++)
        for(int j=1;j<=2;j++)
            ans.m[i][j]=(a.m[i][j]+b.m[i][j])%mod;
    return ans;
}

mat mod_pow(mat a,ll n){
    mat ans=unite;
    while(n){
        if(n&1) ans=ans*a;
        a=a*a;
        n>>=1;
    }
    return ans;
}

const int maxn=1e5+10;
mat sum[maxn*4],a,lazy[maxn*4];
ll v[maxn];

void init(){
    a.m[1][1]=a.m[1][2]=a.m[2][1]=1,a.m[2][2]=0;
    for(int i=0;i<3;i++) unite.m[i][i]=1;
    for(int i=0;i<3;i++)
        for(int j=0;j<3;j++) zero.m[i][j]=0;
}

void push_up(int id){
    sum[id]=sum[id<<1]+sum[id<<1|1];
}

void build(int id,int l,int r){
    lazy[id]=unite;
    if(l==r){
        sum[id]=mod_pow(a,v[l]);
        return ;
    }
    int mid=(l+r)>>1;
    build(id<<1,l,mid);
    build(id<<1|1,mid+1,r);
    push_up(id);
}

bool same(mat a,mat b){
    for(int i=1;i<=2;i++){
        for(int j=1;j<=2;j++){
            if(a.m[i][j]!=b.m[i][j]) return false;
        }
    }
    return true;
}

void push_down(int id){
    if(same(lazy[id],unite)) return ;
    sum[id<<1]=sum[id<<1]*lazy[id];
    sum[id<<1|1]=sum[id<<1|1]*lazy[id];
    lazy[id<<1]=lazy[id<<1]*lazy[id];
    lazy[id<<1|1]=lazy[id<<1|1]*lazy[id];
    lazy[id]=unite;
}

void update(int id,int l,int r,int x,int y,mat val){
    if(x<=l&&y>=r){
        lazy[id]=lazy[id]*val;
        sum[id]=sum[id]*val;
        return ;
    }
    push_down(id);
    int mid=(l+r)>>1;
    if(x<=mid) update(id<<1,l,mid,x,y,val);
    if(y>mid) update(id<<1|1,mid+1,r,x,y,val);
    push_up(id);
}

mat query(int id,int l,int r,int x,int y){
    if(x<=l&&y>=r) return sum[id];
    mat ans=zero;
    int mid=(l+r)>>1;
    push_down(id);
    if(x<=mid) ans=ans+query(id<<1,l,mid,x,y);
    if(y>mid) ans=ans+query(id<<1|1,mid+1,r,x,y);
    return ans;
}

int main(){
    init();
    int n,m;
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++) scanf("%lld",&v[i]);
    build(1,1,n);
    while(m--){
        int opt;
        ll l,r,x;
        scanf("%d",&opt);
        if(opt==1){
            scanf("%lld%lld%lld",&l,&r,&x);
            mat val=mod_pow(a,x);
            update(1,1,n,l,r,val);
        }
        else {
            scanf("%lld%lld",&l,&r);
            mat ans=query(1,1,n,l,r);
            printf("%lld\n",ans.m[1][2]);
        }
    }
    return 0;
}
View Code

 

评论(0
© 2014 mamicode.com 版权所有 京ICP备13008772号-2  联系我们:gaon5@hotmail.com
迷上了代码!