DX12_Mesh Shader Instance

之前我们介绍过DX12_Mesh Shaders Render,但基于Mesh Shader我们能做的事情还有很多,比如实例化和剔除(视锥剔除与遮挡剔除),而这些恰好也是当前主流GPU-Driven管线的核心手段,可谓一举两得(毕竟MS本质上就是一种变种的CS)。那么我们一步步来,先说说Mesh Shader的实例化如何实现。

DX12_Mesh Shader Instance_第1张图片

本部分主要基于之前文章拓展实例化部分的代码,具体流程想回顾的直接看以前文章即可。

一、实例化数据

传统的实例化方式你肯定不陌生,主要有两种:

  • 一种是把实例化数据作为逐实例(per-instance)的顶点缓冲区,与VertexBuffer一样绑定到输入槽上,并在输入布局(Input Layout)中将其声明为D3D12_INPUT_CLASSIFICATION_PER_INSTANCE_DATA,之后调用DrawInstanced即可;
  • 另一种是把实例化数据放到常量缓冲区(或结构化缓冲区)中,调用DrawInstanced后,在Shader中用SV_InstanceID/gl_InstanceIndex索引对应实例的数据进行绘制。

其实Mesh Shader的实例化思路和第二种方式一样:在MS中直接读取实例化数据,为每个实例生成对应的Meshlet输出,再交给PS着色即可。需要注意的是,MS中没有SV_InstanceID,实例索引需要我们根据SV_GroupID自行推算(见下文MS代码)。当然,这种方式和传统API的实例化相比还是有区别的:

  • 通常效率比传统实例化更高(MS->PS管线比VS->PS更精简,更不用说再加上臃肿的曲面细分着色器与GS了),原因是Mesh Shader以Meshlet为单位组织数据,更适合硬件并行计算,能更充分地发挥GPU算力;
  • 更加灵活,可以进一步拓展为完全GPU-Driven的算法实现。

说了这么多,还是直接上代码吧,这样更直观:
这一步很简单,就是在CPU端创建存放实例数据的缓冲区(着色器中通过根签名里的SRV(t4)访问),然后更新数据:

// Rebuilds the instance grid: (2 * m_instanceLevel + 1)^3 copies of the model laid
// out on a cube centred on the origin.
//
// Side effects:
//   * sets m_instanceCount, and sets m_updateInstances so that command-list
//     recording copies m_instanceUpload into m_instanceBuffer;
//   * when the current default buffer is too small, waits for the GPU and
//     (re)creates m_instanceBuffer (default heap) and m_instanceUpload (upload
//     heap), then maps the upload buffer into m_instanceData;
//   * writes the per-instance matrices through m_instanceData.
//
// Throws (via ThrowIfFailed) if resource creation or mapping fails.
void D3D12MeshletInstancing::RegenerateInstances()
{
    m_updateInstances = true;

    // Grid cell size = bounding-sphere diameter * (1 + padding), so copies do not overlap.
    const float radius = m_model.GetBoundingSphere().Radius;
    const float padding = 0.0f;
    const float spacing = (1.0f + padding) * radius * 2.0f;

    const uint32_t width = m_instanceLevel * 2 + 1;  // instances along each cube edge
    const float extents = spacing * m_instanceLevel; // half-extent, used to centre the grid

    m_instanceCount = width * width * width;

    const uint32_t instanceBufferSize = (uint32_t)GetAlignedSize(m_instanceCount * sizeof(Instance));

    // Only recreate the buffers when they need to grow.
    if (!m_instanceBuffer || m_instanceBuffer->GetDesc().Width < instanceBufferSize)
    {
        // The GPU may still be using the old buffers; let it finish before they are released.
        WaitForGpu();

        const CD3DX12_HEAP_PROPERTIES instanceBufferDefaultHeapProps(D3D12_HEAP_TYPE_DEFAULT);
        const CD3DX12_RESOURCE_DESC instanceBufferDesc = CD3DX12_RESOURCE_DESC::Buffer(instanceBufferSize);

        // GPU-local (default heap) buffer read by the mesh shader. Its contents come
        // from a CopyResource out of m_instanceUpload while m_updateInstances is set.
        ThrowIfFailed(m_device->CreateCommittedResource(
            &instanceBufferDefaultHeapProps,
            D3D12_HEAP_FLAG_NONE,
            &instanceBufferDesc,
            D3D12_RESOURCE_STATE_GENERIC_READ,
            nullptr,
            IID_PPV_ARGS(&m_instanceBuffer)
        ));

        const CD3DX12_HEAP_PROPERTIES instanceBufferUploadHeapProps(D3D12_HEAP_TYPE_UPLOAD);

        // CPU-writable staging buffer (upload heap), same size as the default buffer.
        ThrowIfFailed(m_device->CreateCommittedResource(
            &instanceBufferUploadHeapProps,
            D3D12_HEAP_FLAG_NONE,
            &instanceBufferDesc,
            D3D12_RESOURCE_STATE_GENERIC_READ,
            nullptr,
            IID_PPV_ARGS(&m_instanceUpload)
        ));

        // Keep the upload buffer mapped so instance data can be written straight into it.
        // Map() takes a void** out-parameter; check its HRESULT like the calls above.
        ThrowIfFailed(m_instanceUpload->Map(0, nullptr, reinterpret_cast<void**>(&m_instanceData)));
    }

    // NOTE(review): when the buffers are not recreated there is no WaitForGpu(), so this
    // overwrites m_instanceUpload while a previously recorded copy might still be pending
    // on the GPU — confirm frames in flight cannot observe a half-written buffer.
    for (uint32_t i = 0; i < m_instanceCount; ++i)
    {
        // Decompose the linear index into (x, y, z) grid coordinates.
        XMVECTOR index = XMVectorSet(float(i % width), float((i / width) % width), float(i / (width * width)), 0);
        XMVECTOR location = index * spacing - XMVectorReplicate(extents);

        XMMATRIX world = XMMatrixTranslationFromVector(location);

        // Matrices are stored transposed (presumably to match HLSL's default
        // column_major packing, with the shader using mul(v, M) — confirm).
        auto& inst = m_instanceData[i];
        XMStoreFloat4x4(&inst.World, XMMatrixTranspose(world));
        XMStoreFloat4x4(&inst.WorldInvTranspose, XMMatrixTranspose(XMMatrixInverse(nullptr, XMMatrixTranspose(world))));
    }
}

由于DX12的命令是先录制到命令列表(Command List)中、再提交给命令队列在GPU上执行的,我们还必须保证实例化数据在被着色器使用之前,已经从上传堆正确地拷贝到默认堆。因此在绘制之前,需要录制一次拷贝,并用资源屏障完成资源状态的转换:

   // Re-upload instance data only after RegenerateInstances() has flagged a change.
    if (m_updateInstances)
    {
        auto* const dst = m_instanceBuffer.Get();

        // GENERIC_READ -> COPY_DEST so the default-heap buffer can receive the copy.
        const auto beginCopy = CD3DX12_RESOURCE_BARRIER::Transition(dst, D3D12_RESOURCE_STATE_GENERIC_READ, D3D12_RESOURCE_STATE_COPY_DEST);
        m_commandList->ResourceBarrier(1, &beginCopy);

        // Whole-resource copy from the CPU-written upload buffer.
        m_commandList->CopyResource(dst, m_instanceUpload.Get());

        // Back to GENERIC_READ so the shader can read it again.
        const auto endCopy = CD3DX12_RESOURCE_BARRIER::Transition(dst, D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_GENERIC_READ);
        m_commandList->ResourceBarrier(1, &endCopy);

        m_updateInstances = false;
    }

二、Instance Mesh Shader实现

主要就是新增了SRV(t4)来存放实例数据。大家可以自行对比之前的MS与本部分实例化的MS,区别主要在于资源布局和main中的实现,具体流程见注释。

// Root signature for the instanced mesh shader:
//   b0     - Constants (root CBV)
//   b1     - DrawParams root constants (2 x 32-bit)
//   b2     - MeshInfo root constants (3 x 32-bit)
//   t0..t3 - vertices, meshlets, unique vertex indices, packed primitive indices
//   t4     - per-instance data (Instances), the addition needed for instancing
#define ROOT_SIG "CBV(b0), \
                  RootConstants(b1, num32bitconstants=2), \
                  RootConstants(b2, num32bitconstants=3), \
                  SRV(t0), \
                  SRV(t1), \
                  SRV(t2), \
                  SRV(t3), \
                  SRV(t4)"

// Per-draw constants (b0). DrawMeshlets toggles per-meshlet debug colouring in the pixel shader.
struct Constants
{
    float4x4 World;
    float4x4 WorldView;
    float4x4 WorldViewProj;
    uint     DrawMeshlets;
};

// Per-instance transforms (t4). Layout must match the CPU-side Instance written in
// RegenerateInstances (two 4x4 matrices, World then WorldInvTranspose).
struct Instance
{
    float4x4 World;
    float4x4 WorldInvTranspose;
};

// Instancing parameters (b1). Only InstanceCount is read by the mesh shader in this file.
struct DrawParams
{
    uint InstanceCount;
    uint InstanceOffset;
};

// Per-mesh parameters (b2).
struct MeshInfo
{
    uint IndexBytes;    // 4 => 32-bit entries in UniqueVertexIndices, otherwise 16-bit
    uint MeshletCount;  // used to detect the last meshlet, whose groups pack several instances
    uint MeshletOffset; // not read by the mesh shader in this file
};

// One element of the Vertices buffer.
struct Vertex
{
    float3 Position;
    float3 Normal;
};

// Mesh-shader output / pixel-shader input.
struct VertexOut
{
    float4 PositionHS   : SV_Position;
    float3 PositionVS   : POSITION0;
    float3 Normal       : NORMAL0;
    uint   MeshletIndex : COLOR0;
};

// Meshlet descriptor: VertOffset/VertCount index into UniqueVertexIndices,
// PrimOffset/PrimCount index into PrimitiveIndices.
// This struct could be extended with data for culling etc. NOTE(review): adding fields
// changes the StructuredBuffer stride, so the CPU-side meshlet data would have to match.
struct Meshlet
{
    uint VertCount;
    uint VertOffset;
    uint PrimCount;
    uint PrimOffset;
};

ConstantBuffer<Constants> Globals             : register(b0);
ConstantBuffer<DrawParams> DrawParams          : register(b1);
ConstantBuffer<MeshInfo>  MeshInfo            : register(b2);

// Geometry buffers. Each PrimitiveIndices entry packs one triangle (see UnpackPrimitive).
StructuredBuffer<Vertex>  Vertices            : register(t0);
StructuredBuffer<Meshlet> Meshlets            : register(t1);
ByteAddressBuffer         UniqueVertexIndices : register(t2);
StructuredBuffer<uint>    PrimitiveIndices    : register(t3);
StructuredBuffer<Instance> Instances           : register(t4);


// Data Loaders

// Unpacks one triangle from a 32-bit word: three 10-bit local vertex indices in
// bits [0,10), [10,20) and [20,30). The top two bits are ignored.
uint3 UnpackPrimitive(uint primitive)
{
    const uint3 shifted = uint3(primitive, primitive >> 10, primitive >> 20);
    return shifted & 0x3FFu;
}

// Returns the three meshlet-local vertex indices of triangle `index` of meshlet `m`.
uint3 GetPrimitive(Meshlet m, uint index)
{
    const uint packedTriangle = PrimitiveIndices[m.PrimOffset + index];
    return UnpackPrimitive(packedTriangle);
}

// Maps a meshlet-local vertex slot to an index into Vertices by reading
// UniqueVertexIndices, which holds 32-bit entries when MeshInfo.IndexBytes == 4
// and 16-bit entries otherwise.
uint GetVertexIndex(Meshlet m, uint localIndex)
{
    const uint slot = m.VertOffset + localIndex;

    if (MeshInfo.IndexBytes != 4)
    {
        // 16-bit indices: two per 32-bit word, because ByteAddressBuffer loads must be
        // 4-byte aligned. Load the word, then select the low or high half.
        const uint packedPair = UniqueVertexIndices.Load((slot >> 1) << 2);
        const uint shift = (slot & 1) * 16;
        return (packedPair >> shift) & 0xffff;
    }

    // 32-bit indices: one per word.
    return UniqueVertexIndices.Load(slot * 4);
}

// Builds the mesh-shader output for one vertex (the MS equivalent of a VS output),
// without any per-instance transform. Kept for callers that draw a single copy.
VertexOut GetVertexAttributes(uint meshletIndex, uint vertexIndex)
{
    Vertex v = Vertices[vertexIndex];

    VertexOut vout;
    vout.PositionVS = mul(float4(v.Position, 1), Globals.WorldView).xyz;
    vout.PositionHS = mul(float4(v.Position, 1), Globals.WorldViewProj);
    vout.Normal = mul(float4(v.Normal, 0), Globals.World).xyz;
    vout.MeshletIndex = meshletIndex;

    return vout;
}

// Instanced variant, called from main(): builds the output for one vertex of one instance.
//   meshletIndex  - meshlet the vertex belongs to (forwarded to the PS for debug colouring)
//   vertexIndex   - index into Vertices
//   instanceIndex - index into Instances; selects this copy's placement
VertexOut GetVertexAttributes(uint meshletIndex, uint vertexIndex, uint instanceIndex)
{
    Instance inst = Instances[instanceIndex];
    Vertex v = Vertices[vertexIndex];

    // Apply the instance transform first, then the shared Globals matrices on top.
    // NOTE(review): assumes the CPU sets Globals.World to a transform shared by all
    // instances (e.g. identity) rather than a per-instance one — confirm.
    float4 positionInst = mul(float4(v.Position, 1), inst.World);
    float4 normalInst = float4(mul(float4(v.Normal, 0), inst.WorldInvTranspose).xyz, 0);

    VertexOut vout;
    vout.PositionVS = mul(positionInst, Globals.WorldView).xyz;
    vout.PositionHS = mul(positionInst, Globals.WorldViewProj);
    vout.Normal = mul(normalInst, Globals.World).xyz;
    vout.MeshletIndex = meshletIndex;

    return vout;
}

// Per-group output limits. They default to the output array sizes this shader
// originally hard-coded (64 vertices / 126 triangles); a shared header may define
// them first. They must be compile-time constants because they size the outputs.
#ifndef MAX_VERTS
#define MAX_VERTS 64
#endif
#ifndef MAX_PRIMS
#define MAX_PRIMS 126
#endif

// Mesh-shader entry point (instanced).
//
// Group -> work mapping, as derived below from SV_GroupID:
//   * gid < (MeshletCount - 1) * InstanceCount: one meshlet for one instance,
//     with gid = meshletIndex * InstanceCount + instance;
//   * remaining groups: the LAST meshlet, with as many instances packed into one
//     group as fit within MAX_VERTS / MAX_PRIMS (presumably because the last
//     meshlet is often only partially filled).
// NOTE(review): the CPU must dispatch exactly that many groups; over-dispatch would
// make startInstance exceed InstanceCount in the packed path.
[RootSignature(ROOT_SIG)]
[NumThreads(128, 1, 1)]
[OutputTopology("triangle")]
void main(
    uint gtid : SV_GroupThreadID,
    uint gid : SV_GroupID,
    out indices uint3 tris[MAX_PRIMS],
    out vertices VertexOut verts[MAX_VERTS]
)
{
    //--------------------------------------------------------------------
    // Work out which meshlet and which instance(s) this group handles.

    uint meshletIndex = gid / DrawParams.InstanceCount;
    Meshlet m = Meshlets[meshletIndex];

    // Default: one instance per group.
    uint startInstance = gid % DrawParams.InstanceCount;
    uint instanceCount = 1;

    // Last meshlet: one group emits several instances.
    if (meshletIndex == MeshInfo.MeshletCount - 1)
    {
        const uint instancesPerGroup = min(MAX_VERTS / m.VertCount, MAX_PRIMS / m.PrimCount);

        // Index of this group among the packed groups that follow the unpacked ones.
        uint unpackedGroupCount = (MeshInfo.MeshletCount - 1) * DrawParams.InstanceCount;
        uint packedIndex = gid - unpackedGroupCount;

        startInstance = packedIndex * instancesPerGroup;
        instanceCount = min(DrawParams.InstanceCount - startInstance, instancesPerGroup);
    }

    // Number of vertices and triangles this group outputs.
    uint vertCount = m.VertCount * instanceCount;
    uint primCount = m.PrimCount * instanceCount;

    SetMeshOutputCounts(vertCount, primCount);

    //--------------------------------------------------------------------
    // Export vertices and primitives.

    if (gtid < vertCount)
    {
        uint readIndex = gtid % m.VertCount;  // Wrap reads so each packed instance re-reads the meshlet's vertices.
        uint instanceId = gtid / m.VertCount; // Instance within this group (non-zero only for packed groups).

        uint vertexIndex = GetVertexIndex(m, readIndex);
        uint instanceIndex = startInstance + instanceId;

        verts[gtid] = GetVertexAttributes(meshletIndex, vertexIndex, instanceIndex);
    }

    if (gtid < primCount)
    {
        uint readIndex = gtid % m.PrimCount;  // Wrap reads so each packed instance re-reads the meshlet's triangles.
        uint instanceId = gtid / m.PrimCount; // Instance within this group (non-zero only for packed groups).

        // Offset the local vertex indices so they point at this instance's copy of the vertices.
        tris[gtid] = GetPrimitive(m, readIndex) + (m.VertCount * instanceId);
    }
}

PS部分就不再赘述了,代码如下:

// Per-draw constants at b0. Mirrors the mesh shader's Constants declaration;
// this pixel shader only reads DrawMeshlets.
struct Constants
{
    float4x4 World;
    float4x4 WorldView;
    float4x4 WorldViewProj;
    uint     DrawMeshlets;
};

// Interpolated mesh-shader output. Must match the mesh shader's VertexOut
// (same fields and semantics).
struct VertexOut
{
    float4 PositionHS   : SV_Position;
    float3 PositionVS   : POSITION0;
    float3 Normal       : NORMAL0;
    uint   MeshletIndex : COLOR0;
};

ConstantBuffer<Constants> Globals : register(b0);

// Pixel shader: Blinn-Phong lighting from one fixed directional light plus a
// constant ambient term. When Globals.DrawMeshlets is non-zero, each meshlet gets a
// debug colour derived from the low bits of its index; otherwise flat grey is used.
float4 main(VertexOut input) : SV_TARGET
{
    const float  ambient = 0.1;
    const float3 toLight = -normalize(float3(1, -1, 1));

    // Material: flat grey by default, per-meshlet debug colour when requested.
    float3 albedo    = 0.8;
    float  specPower = 64.0;
    if (Globals.DrawMeshlets)
    {
        const uint id = input.MeshletIndex;
        albedo    = float3(float(id & 1), float(id & 3) / 4, float(id & 7) / 8);
        specPower = 16.0;
    }

    // NOTE(review): per the mesh shader in this article, input.Normal is transformed by
    // Globals.World while PositionVS is in view space, so the half vector mixes spaces
    // unless the view has no rotation — confirm this is intended.
    const float3 n       = normalize(input.Normal);
    const float3 toEye   = -normalize(input.PositionVS);
    const float3 halfVec = normalize(toLight + toEye);

    const float nDotL = saturate(dot(n, toLight));
    // No highlight on surfaces facing away from the light.
    const float nDotH    = (nDotL != 0.0) ? saturate(dot(n, halfVec)) : 0.0;
    const float specular = pow(nDotH, specPower);

    return float4((nDotL + specular + ambient) * albedo, 1);
}


当然,这里展示的是不做任何剔除、全部绘制的效果;后续我们再继续介绍Mesh Shader的遮挡剔除与LOD,以进一步提升效率。

你可能感兴趣的:(DirectX12,Mesh,Shader)