在Android端使用OpenGL的compute shader加速计算

关于compute shader的介绍和使用,请参看博客《使用compute shader进行通用计算及示例》。

在Android端使用compute shader需要OpenGL ES 3.1,即Android 5.0(API 21)及以上的平台。可能是OpenGL ES的原因,在Android上使用compute shader有几个注意要点:

  • 生成texture时不能使用glTexImage2D, 需使用glTexStorage2D,然后使用glTexSubImage2D将数据赋予texture
  • 在写shader时,输入输出image2D需要显式地用限定符readonly或writeonly限定其读写权限,不然编译shader程序会失败
  • 注意生成texture时glTexStorage2D的levels参数(mipmap层数,本例为1)和内部格式(本例为GL_RGBA32F)需要与上传的数据格式以及shader中的格式限定符(rgba32f)对应
layout(binding = 0, rgba32f) readonly uniform  image2D input_image;
layout(binding = 1, rgba32f) writeonly uniform  image2D output_image;

在Android上使用OpenGL最方便的做法就是使用GLSurfaceView生成EGL环境,具体用法不清楚的话可以参看网上教程,有很多,这里不再详述。这个例子仍然是生成模拟数据,然后通过compute shader对数据做一些加法后再读回。

先在Activity中设置EGL环境

/**
 * Activity that hosts the {@link GLSurfaceView} used to obtain an EGL context
 * for the compute-shader demo. The actual GL work is done by {@code ComputeRender}.
 */
public class ComputeActivity extends Activity {

    private GLSurfaceView glsv;

    @Override
    protected void onCreate(@Nullable Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_compute);
        glsv = findViewById(R.id.glsv);
        // Requests an ES 3.x context. This call does not verify that the device
        // actually exposes ES 3.1 (required for compute shaders).
        glsv.setEGLContextClientVersion(3);
        glsv.setRenderer(new ComputeRender(this));
        // Render only on surface creation and on explicit requestRender() calls.
        glsv.setRenderMode(GLSurfaceView.RENDERMODE_WHEN_DIRTY);
    }

    @Override
    protected void onResume() {
        super.onResume();
        // GLSurfaceView must be told about activity lifecycle changes so it can
        // resume its rendering thread.
        glsv.onResume();
    }

    @Override
    protected void onPause() {
        super.onPause();
        // Pauses the rendering thread while the activity is not in the foreground.
        glsv.onPause();
    }
}

在ComputeRender中生成模拟数据

    /**
     * Builds the simulated input data: {@code mSize} floats where element
     * {@code i} holds the value {@code i}.
     *
     * @return an array-backed buffer whose position is 0
     */
    private FloatBuffer createInputBuffer() {
        final FloatBuffer ramp = FloatBuffer.allocate(mSize);
        // Fill through the backing array; the buffer's position is never moved,
        // so it is still 0 when the buffer is returned.
        final float[] backing = ramp.array();
        int index = 0;
        while (index < mSize) {
            backing[index] = index;
            index++;
        }
        return ramp;
    }

生成FrameBuffer和Texture

    /**
     * Creates one framebuffer and three {@code mWidth x mHeight} RGBA32F textures,
     * and attaches the textures to colour attachments 0..2 of the framebuffer.
     *
     * <p>The framebuffer is left bound to {@code GL_FRAMEBUFFER} on return; no
     * texture is left bound to {@code GL_TEXTURE_2D}.
     *
     * <p>Objects created by a previous call are deleted first, so calling this
     * method repeatedly does not leak GL objects.
     */
    public void createEnvi() {
        // Release what an earlier call created. GL silently ignores names that are 0,
        // so this is a no-op on the first call provided the arrays start zeroed
        // (assumed: they are plain int[] fields - confirm at the declaration).
        // NOTE(review): after an EGL context loss the stored names are stale; this
        // class is assumed to create no other framebuffers/textures that could alias them.
        GLES31.glDeleteFramebuffers(1, fFrame, 0);
        GLES31.glDeleteTextures(3, fTexture, 0);

        GLES31.glGenFramebuffers(1, fFrame, 0);
        GLES31.glBindFramebuffer(GLES31.GL_FRAMEBUFFER, fFrame[0]);
        GLES31.glGenTextures(3, fTexture, 0);
        for (int i = 0; i < 3; i++) {
            GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, fTexture[i]);
            // Immutable storage with a single mip level.
            GLES31.glTexStorage2D(GLES31.GL_TEXTURE_2D, 1, GLES31.GL_RGBA32F, mWidth, mHeight);
            // RGBA32F is not texture-filterable in core ES 3.1, so a LINEAR filter
            // would leave the texture incomplete. The textures are only accessed as
            // images / framebuffer attachments here, so NEAREST loses nothing.
            GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_MIN_FILTER, GLES31.GL_NEAREST);
            GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_MAG_FILTER, GLES31.GL_NEAREST);
            GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_WRAP_S, GLES31.GL_CLAMP_TO_EDGE);
            GLES31.glTexParameteri(GLES31.GL_TEXTURE_2D, GLES31.GL_TEXTURE_WRAP_T, GLES31.GL_CLAMP_TO_EDGE);
            GLES31.glBindTexture(GLES31.GL_TEXTURE_2D, 0);
        }
        GLES31.glFramebufferTexture2D(GLES31.GL_FRAMEBUFFER, GLES31.GL_COLOR_ATTACHMENT0,
                GLES31.GL_TEXTURE_2D, fTexture[0], 0);
        GLES31.glFramebufferTexture2D(GLES31.GL_FRAMEBUFFER, GLES31.GL_COLOR_ATTACHMENT1,
                GLES31.GL_TEXTURE_2D, fTexture[1], 0);
        GLES31.glFramebufferTexture2D(GLES31.GL_FRAMEBUFFER, GLES31.GL_COLOR_ATTACHMENT2,
                GLES31.GL_TEXTURE_2D, fTexture[2], 0);

        // Float colour attachments are only renderable/readable when the device
        // supports them (e.g. via EXT_color_buffer_float); surface that early
        // instead of failing later in glReadPixels.
        int status = GLES31.glCheckFramebufferStatus(GLES31.GL_FRAMEBUFFER);
        if (status != GLES31.GL_FRAMEBUFFER_COMPLETE) {
            Log.w(TAG, "framebuffer incomplete: 0x" + Integer.toHexString(status));
        }
    }

绑定数据和Texture

    /**
     * Uploads {@code data} into mip level 0 of the 2D texture {@code texID},
     * covering the full {@code mWidth x mHeight} area as RGBA floats.
     * The texture is left bound to {@code GL_TEXTURE_2D} on return.
     *
     * @param data  source pixels, read as {@code GL_RGBA} / {@code GL_FLOAT}
     * @param texID name of the destination texture
     */
    private void transferToTexture(Buffer data, int texID) {
        final int target = GLES31.GL_TEXTURE_2D;
        final int mipLevel = 0;
        final int xOffset = 0;
        final int yOffset = 0;

        GLES31.glBindTexture(target, texID);
        GLES31.glTexSubImage2D(target, mipLevel, xOffset, yOffset, mWidth, mHeight,
                GLES31.GL_RGBA, GLES31.GL_FLOAT, data);
    }

创建并链接shader程序

#version 310 es

// One work group processes a 32x32 tile of the image.
layout (local_size_x = 32, local_size_y = 32, local_size_z = 1) in;

// Only v[999] is read by this shader.
uniform float v[1000];
// GLSL ES has no default precision for image types, so it is stated explicitly.
layout(binding = 0, rgba32f) readonly uniform highp image2D input_image;
layout(binding = 1, rgba32f) writeonly uniform highp image2D output_image;

// Per-work-group scratch tile; its dimensions match the local work group size.
shared vec4 scanline[32][32];

// Copies each texel of input_image to output_image, adding v[999] to the red channel.
void main(void)
{
    // Texel handled by this invocation.
    ivec2 pos = ivec2(gl_GlobalInvocationID.xy);
    // Slot inside the shared tile. This must use the LOCAL id: the global id
    // exceeds 31 as soon as more than one work group is dispatched and would
    // index past the end of scanline.
    ivec2 tile = ivec2(gl_LocalInvocationID.xy);

    scanline[tile.x][tile.y] = imageLoad(input_image, pos);
    barrier();
    vec4 data = scanline[tile.x][tile.y];
    data.r = data.r + v[999];
    imageStore(output_image, pos, data);
}
    /**
     * Creates the compute program: loads {@code compute.cs} from the assets,
     * attaches it as a compute shader and links the program into
     * {@code mComputeProg}. A link failure is logged together with the
     * program info log instead of being silently ignored.
     */
    private void initGLSL() {
        mComputeProg = GLES31.glCreateProgram();
        String source = ShaderUtils.loadFromAssetsFile("compute.cs", mContext.getResources());
        ShaderUtils.vglAttachShaderSource(mComputeProg, GLES31.GL_COMPUTE_SHADER, source);
        GLES31.glLinkProgram(mComputeProg);

        // glLinkProgram does not report failure by itself; query the status explicitly.
        int[] linkStatus = new int[1];
        GLES31.glGetProgramiv(mComputeProg, GLES31.GL_LINK_STATUS, linkStatus, 0);
        if (linkStatus[0] != GLES31.GL_TRUE) {
            Log.e(TAG, "compute program link failed: " + GLES31.glGetProgramInfoLog(mComputeProg));
        }
    }

执行计算

    /**
     * Runs the compute program once, reading from one texture and writing to another.
     *
     * @param inputTeture   texture bound read-only to image unit 0 (shader {@code input_image})
     * @param outputTexture texture bound write-only to image unit 1 (shader {@code output_image})
     */
    private void performCompute(int inputTeture, int outputTexture) {
        GLES31.glUseProgram(mComputeProg);
        GLES31.glUniform1fv(GLES31.glGetUniformLocation(mComputeProg, "v"), mValueSize, mValueBuffer);

        // Image units 0 and 1 correspond to "binding = 0" / "binding = 1" in the shader.
        GLES31.glBindImageTexture(0, inputTeture, 0, false, 0, GLES31.GL_READ_ONLY, GLES31.GL_RGBA32F);
        GLES31.glBindImageTexture(1, outputTexture, 0, false, 0, GLES31.GL_WRITE_ONLY, GLES31.GL_RGBA32F);

        // A single work group: with the shader's 32x32 local size this covers
        // only the first 32x32 texels of the image.
        GLES31.glDispatchCompute(1, 1, 1);
        // The image-access bit makes the stores visible to image loads in a following
        // dispatch; the framebuffer bit makes them visible to reads through framebuffer
        // attachments (the results are fetched with glReadPixels).
        GLES31.glMemoryBarrier(GLES31.GL_SHADER_IMAGE_ACCESS_BARRIER_BIT
                | GLES31.GL_FRAMEBUFFER_BARRIER_BIT);
    }

读回数据

    /**
     * Runs the demo once per rendered frame: (re)creates the GL objects, uploads the
     * input data to texture 0, runs the compute program twice (0 -> 1, then 1 -> 2),
     * logs the elapsed time and reads all three textures back through the framebuffer.
     *
     * @param gl unused; all calls go through {@code GLES31}
     */
    @Override
    public void onDrawFrame(GL10 gl) {
        createEnvi();
        transferToTexture(mInputBuffer, fTexture[0]);
        FloatBuffer a0 = FloatBuffer.allocate(mSize);
        FloatBuffer a1 = FloatBuffer.allocate(mSize);
        FloatBuffer a2 = FloatBuffer.allocate(mSize);

        long begin = System.currentTimeMillis();

        performCompute(fTexture[0], fTexture[1]);
        performCompute(fTexture[1], fTexture[2]);

        // glDispatchCompute only queues work. Block until the GPU has finished so
        // the logged time includes execution, not just command submission.
        GLES31.glFinish();

        Log.w(TAG, "total compute spent:" + (System.currentTimeMillis() - begin));
        // Select each colour attachment in turn as the read source and copy it out.
        GLES31.glReadBuffer(GLES31.GL_COLOR_ATTACHMENT0);
        GLES31.glReadPixels(0, 0, mWidth, mHeight, GLES31.GL_RGBA, GLES31.GL_FLOAT, a0);
        GLES31.glReadBuffer(GLES31.GL_COLOR_ATTACHMENT1);
        GLES31.glReadPixels(0, 0, mWidth, mHeight, GLES31.GL_RGBA, GLES31.GL_FLOAT, a1);
        GLES31.glReadBuffer(GLES31.GL_COLOR_ATTACHMENT2);
        GLES31.glReadPixels(0, 0, mWidth, mHeight, GLES31.GL_RGBA, GLES31.GL_FLOAT, a2);
        // Backing arrays of the read-back buffers; not used further in this method
        // (kept so the values can be inspected, e.g. in a debugger).
        float[] o1 = a0.array();
        float[] o2 = a1.array();
        float[] o3 = a2.array();
    }

最后可以观察o1,o2,o3三个数组中的数据是否正确。经测试,通过compute shader计算,运行200次计算着色器,也仅耗时5~7ms。因此用来做移动端深度学习加速完全可行。

全部代码

你可能感兴趣的:(移动端深度学习)