c++笔记--基于C++实现tensor合并

1--问题描述

        给定两个 NCHW 维度的 Blob，在 H 维度上进行拼接（要求 N1 = N2、C1 = C2、W1 = W2）：
        (N1, C1, H1, W1), (N2, C2, H2, W2) → (N1, C1, H1 + H2, W1)

2--示例代码

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
 
// c++实现tensor合并,在H维度
struct Tensor{
    // Owning NCHW float tensor stored contiguously: W varies fastest, then H, then C, then N.
    // Allocates n*c*h*w floats and fills them with integers in [0, 9] drawn from rand();
    // call srand() beforehand to get different values per run.
    Tensor(int n, int c, int h, int w){
        this->N = n;
        this->C = c;
        this->H = h;
        this->W = w;
        this->num = n*c*h*w;
        this->data = new float[this->num];
        // Random initialization of the data.
        for(int i = 0; i < this->num; i++){
            data[i] = static_cast<float>(rand() % 10);
        }
    }

    // Releases the buffer allocated in the constructor (previously it was leaked).
    ~Tensor(){
        delete[] data;
    }

    // The tensor owns a raw buffer: a shallow copy would double-free it, so copying is forbidden.
    Tensor(const Tensor &) = delete;
    Tensor &operator=(const Tensor &) = delete;

    // Moving transfers buffer ownership; the moved-from tensor is left empty (data == nullptr).
    Tensor(Tensor &&other) noexcept
        : data(other.data), N(other.N), C(other.C), H(other.H), W(other.W), num(other.num){
        other.data = nullptr;
        other.num = 0;
    }

    Tensor &operator=(Tensor &&other) noexcept {
        if(this != &other){
            delete[] data;
            data = other.data;
            N = other.N;
            C = other.C;
            H = other.H;
            W = other.W;
            num = other.num;
            other.data = nullptr;
            other.num = 0;
        }
        return *this;
    }

    float *data = nullptr;  // owned buffer of `num` floats
    int N = 0;
    int C = 0;
    int H = 0;
    int W = 0;
    int num = 0;            // total element count N*C*H*W

    // Returns a reference to the element at index (ni, ci, hi, wi); no bounds checking.
    float &at(int ni, int ci, int hi, int wi) const {
        // idx = wi + hi*W + ci*H*W + ni*C*H*W, evaluated in Horner form.
        return data[W * (H * (C * ni + ci) + hi) + wi];
    }
};
 
// 给定两个NCHW维度的Blob,在H维度上进行拼接
void Concat1(const Tensor *a, const Tensor *b, Tensor *c){
    for(int ni = 0; ni < a->N; ++ni){
        for(int ci = 0; ci < a->C; ++ci){
            for(int hi = 0; hi < a->H; ++hi){
                for(int wi = 0; wi < a->W; ++wi){
                    c->at(ni, ci, hi, wi) = a->at(ni, ci, hi, wi);
                }
            }
            for(int hi = 0; hi < b->H; ++hi){
                for(int wi = 0; wi < b->W; ++wi){
                    c->at(ni, ci, a->H+hi, wi) = b->at(ni, ci, hi, wi);
                }
            }  
        }
    }
}
 
void Concat2(const Tensor *a, const Tensor *b, Tensor *c){
    for(int ni = 0; ni < a->N; ++ni){
        for(int ci = 0; ci < a->C; ++ci){
            for(int hi = 0; hi < a->H; ++hi){
                int offseta = (hi * a->W + ci * a->H * a->W + ni * a->C * a->H * a->W)*sizeof(float);
                int offsetc = (hi * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
                memcpy(c->data+offsetc, a->data+offseta, a->W*sizeof(float));
            }
            for(int hi = 0; hi < b->H; ++hi){
                int offsetc = ((hi + a->H) * c->W + ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
                int offsetb = (hi * b->W + ci*b->H*b->W + ni * b->C * b->H * b->W)*sizeof(float);
                memcpy(c->data+offsetc, c->data+offsetb, b->W*sizeof(float));
            }  
        }
    }
}

void Concat3(const Tensor *a, const Tensor *b, Tensor *c){
    for(int ni = 0; ni < a->N; ++ni){
        for(int ci = 0; ci < a->C; ++ci){    
            int offseta = (ci * a->H * a->W + ni * a->C * a->H * a->W)*sizeof(float);
            int offsetc1 = (ci * c->H * c->W + ni * c->C * c->H * c->W)*sizeof(float);
            memcpy(c->data+offsetc1, a->data+offseta, a->W*a->H*sizeof(float));

            int offsetc2 = offsetc1 + a->W*a->H*sizeof(float);
            int offsetb = (ci * b->H * b->W + ni * b->C * b->H * b->W)*sizeof(float);
            memcpy(c->data+offsetc2, c->data+offsetb, b->W*b->H*sizeof(float));
        }
    }
}
 
int main(int argc, char argv[]){
    srand(time(nullptr));
 
    int N1 = 1, C1 = 1, H1 = 2, W1 = 2;
    int N2 = 1, C2 = 1, H2 = 2, W2 = 2;
    Tensor *a = new Tensor(N1, C1, H1, W1);
    Tensor *b = new Tensor(N2, C2, H2, W2);
    Tensor *c = new Tensor(N1, C1, H1+H2, W1);
    // Tensor a(N1, C1, H1, W1);
    // Tensor b(N2, C2, H2, W2);
    // Tensor c(N1, C1, H1+H2, W1);
 
    Concat1(a, b, c);
    for(int n = 0; n < N1; n++){
        for(int channel = 0; channel < C1; channel++){
            for(int h = 0; h < H1+H2; h++){
                for(int w = 0; w < W1; w++){
                    std::cout << c->at(n, channel, h, w) << " ";
                }
                std::cout << std::endl;
            }
        }
    }

    std::cout << "-------------" << std::endl;
    Concat2(a, b, c);
    for(int n = 0; n < N1; n++){
        for(int channel = 0; channel < C1; channel++){
            for(int h = 0; h < H1+H2; h++){
                for(int w = 0; w < W1; w++){
                    std::cout << c->at(n, channel, h, w) << " ";
                }
                std::cout << std::endl;
            }
        }
    }

    std::cout << "-------------" << std::endl;
    Concat3(a, b, c);
    for(int n = 0; n < N1; n++){
        for(int channel = 0; channel < C1; channel++){
            for(int h = 0; h < H1+H2; h++){
                for(int w = 0; w < W1; w++){
                    std::cout << c->at(n, channel, h, w) << " ";
                }
                std::cout << std::endl;
            }
        }
    }
    return 0;
}

你可能感兴趣的:(C++复习笔记,c++)