想了解具体原理与API的同学可以看一下上一篇文章:
DX12_Amplification/Mesh Shaders and API。本文主要从应用程序和GPU两部分来看一下Mesh Shader如何实现。
本文参考文章:coming-to-directx-12-mesh-shaders-and-amplification-shaders-reinventing-the-geometry-pipeline
网格着色器(Mesh Shader)的工作从调度一组线程组开始,每个线程组处理一个较大网格的子集。与计算着色器一样,每个线程组都可以访问组共享内存;但它输出的顶点和图元不必与组内某个特定线程一一对应。只要线程组处理了其图元所引用的全部顶点,就可以按最高效的方式在线程之间分配工作。
单独编写一个网格着色器(Mesh Shader,不搭配放大着色器)非常简单,只需用以下属性指定线程组中的线程数以及输出图元的拓扑类型:
[ numthreads ( X, Y, Z ) ]
[ outputtopology ( T ) ]
当然,网格着色器可以将许多系统值作为输入,包括 SV_DispatchThreadID、SV_GroupThreadID、SV_ViewID 等(见后文的 MS 使用示例),但必须输出一个表示顶点的数组和一个表示图元索引的数组。这两个数组就是在计算结束时要写入的输出。
如果网格着色器使用放大着色器(AS),则它还必须具有payload的输入。
SetMeshOutputCounts ( uint numVertices, uint numPrimatives )
也就是说,需要调用 SetMeshOutputCounts 来告知管线该网格着色器实际输出的顶点数和图元数。此函数必须在写入输出数组之前调用,且只能调用一次。如果不调用此函数,网格着色器将不会输出任何数据;具体规则可参考上文链接中的文章。
除了这些规则之外,你还可以做很多GPU-Driven的事情,比如剔除与实例化等操作,后续文章中具体介绍,本文主要介绍Mesh Shader的使用方法。
在创建之前,首先要确定下硬件是否支持:
D3D12_FEATURE_DATA_SHADER_MODEL shaderModel = { D3D_SHADER_MODEL_6_5 };
if (FAILED(m_device->CheckFeatureSupport(D3D12_FEATURE_SHADER_MODEL, &shaderModel, sizeof(shaderModel)))
|| (shaderModel.HighestShaderModel < D3D_SHADER_MODEL_6_5))
{
OutputDebugStringA("ERROR: Shader Model 6.5 is not supported\n");
throw std::exception("Shader Model 6.5 is not supported");
}
D3D12_FEATURE_DATA_D3D12_OPTIONS7 features = {};
if (FAILED(m_device->CheckFeatureSupport(D3D12_FEATURE_D3D12_OPTIONS7, &features, sizeof(features)))
|| (features.MeshShaderTier == D3D12_MESH_SHADER_TIER_NOT_SUPPORTED))
{
OutputDebugStringA("ERROR: Mesh Shaders aren't supported!\n");
throw std::exception("Mesh Shaders aren't supported!");
}
确认机器支持 MS 之后,接下来创建管线状态对象(PSO),具体流程如下:
// Create the pipeline state object (PSO) for the mesh shader + pixel shader pipeline.
{
    // Compiled shader bytecode read from disk.
    // NOTE(review): ownership of the buffers returned by ReadDataFromFile is not visible
    // here; confirm whether they must be released after CreatePipelineState.
    struct
    {
        byte* data;
        uint32_t size;
    } meshShader, pixelShader;
    ReadDataFromFile(GetAssetFullPath(c_meshShaderFilename).c_str(), &meshShader.data, &meshShader.size);
    ReadDataFromFile(GetAssetFullPath(c_pixelShaderFilename).c_str(), &pixelShader.data, &pixelShader.size);

    // The mesh shader declares its root signature with [RootSignature(...)], so the
    // root signature is created directly from the mesh shader bytecode.
    ThrowIfFailed(m_device->CreateRootSignature(0, meshShader.data, meshShader.size, IID_PPV_ARGS(&m_rootSignature)));

    D3DX12_MESH_SHADER_PIPELINE_STATE_DESC psoDesc = {};
    psoDesc.pRootSignature = m_rootSignature.Get();
    psoDesc.MS = { meshShader.data, meshShader.size };
    psoDesc.PS = { pixelShader.data, pixelShader.size };
    psoDesc.NumRenderTargets = 1;
    psoDesc.RTVFormats[0] = m_renderTargets[0]->GetDesc().Format;
    psoDesc.DSVFormat = m_depthStencil->GetDesc().Format;
    psoDesc.RasterizerState = CD3DX12_RASTERIZER_DESC(D3D12_DEFAULT);        // CW front; cull back
    psoDesc.BlendState = CD3DX12_BLEND_DESC(D3D12_DEFAULT);                  // Opaque
    psoDesc.DepthStencilState = CD3DX12_DEPTH_STENCIL_DESC(D3D12_DEFAULT);   // LESS depth test w/ writes; no stencil
    psoDesc.SampleMask = UINT_MAX;
    psoDesc.SampleDesc = DefaultSampleDesc();

    // Mesh shader pipelines are created through the pipeline-state-stream API:
    // convert the desc into a subobject stream and hand that to CreatePipelineState.
    auto psoStream = CD3DX12_PIPELINE_MESH_STATE_STREAM(psoDesc);
    D3D12_PIPELINE_STATE_STREAM_DESC streamDesc = {};
    streamDesc.pPipelineStateSubobjectStream = &psoStream;
    streamDesc.SizeInBytes = sizeof(psoStream);
    ThrowIfFailed(m_device->CreatePipelineState(&streamDesc, IID_PPV_ARGS(&m_pipelineState)));
}
可以看出,其创建流程与传统管线几乎一致,这里不再赘述。
首先来看一下几个用到的结构体(为了更好的记录Meshlet中的数据):
// Fixed-size header at the start of a mesh file. It is read verbatim with
// stream.read, so the field order and sizes define the on-disk format.
struct FileHeader
{
uint32_t Prolog;          // magic value; the loader rejects the file unless it equals c_prolog
uint32_t Version;         // the loader rejects the file unless it equals CURRENT_FILE_VERSION
uint32_t MeshCount;       // number of MeshHeader records that follow the header
uint32_t AccessorCount;   // number of Accessor records that follow the meshes
uint32_t BufferViewCount; // number of BufferView records that follow the accessors
uint32_t BufferSize;      // size in bytes of the binary blob that follows the buffer views
};
// Per-mesh record in the file. Every field is an index into the Accessor table,
// which in turn locates the data inside the binary blob.
struct MeshHeader
{
uint32_t Indices;                        // accessor for the index buffer
uint32_t IndexSubsets;                   // accessor for the index subsets
uint32_t Attributes[Attribute::Count];   // one accessor per vertex attribute; -1 (all ones) = attribute absent
uint32_t Meshlets;                       // accessor for the meshlet table
uint32_t MeshletSubsets;                 // accessor for meshlet subsets (each subset is dispatched separately)
uint32_t UniqueVertexIndices;            // accessor for the meshlets' unique vertex index list
uint32_t PrimitiveIndices;               // accessor for the packed meshlet triangles
uint32_t CullData;                       // accessor for per-meshlet cull data (not used by this sample's shaders)
};
// Byte range inside the file's binary blob (m_buffer once loaded).
struct BufferView
{
uint32_t Offset; // byte offset from the start of the blob
uint32_t Size;   // length of the range in bytes
};
// Describes how to interpret a BufferView as an array of elements.
struct Accessor
{
uint32_t BufferView; // index into the BufferView table
uint32_t Offset;     // not read by the loading code shown here
uint32_t Size;       // size of one element in bytes (e.g. the index size for index data)
uint32_t Stride;     // byte stride between consecutive elements (vertex stride for vertex data)
uint32_t Count;      // number of elements
};
之后从文件中读取相应的数据信息:
//读取原始数据
std::ifstream stream(filename, std::ios::binary);
if (!stream.is_open())
{
return E_INVALIDARG;
}
std::vector meshes;
std::vector bufferViews;
std::vector accessors;
FileHeader header;
stream.read(reinterpret_cast(&header), sizeof(header));
if (header.Prolog != c_prolog)
{
return E_FAIL;
}
if (header.Version != CURRENT_FILE_VERSION)
{
return E_FAIL;
}
// 取数据
meshes.resize(header.MeshCount);
stream.read(reinterpret_cast(meshes.data()), meshes.size() * sizeof(meshes[0]));
accessors.resize(header.AccessorCount);
stream.read(reinterpret_cast(accessors.data()), accessors.size() * sizeof(accessors[0]));
bufferViews.resize(header.BufferViewCount);
stream.read(reinterpret_cast(bufferViews.data()), bufferViews.size() * sizeof(bufferViews[0]));
m_buffer.resize(header.BufferSize);
stream.read(reinterpret_cast(m_buffer.data()), header.BufferSize);
char eofbyte;
stream.read(&eofbyte, 1);
assert(stream.eof());
stream.close();
之后根据Mesh中预处理的Meshlet数据分配至各部分(当然了有自动划分Meshlet的工具):
// Populate m_meshes from the parsed headers. Each field is built with MakeSpan over a
// region of m_buffer, so these views point into m_buffer. m_buffer must not be resized
// or freed while the meshes are in use.
// The file marks an absent vertex attribute with an all-ones (-1) accessor index.
const uint32_t c_absentAttribute = static_cast<uint32_t>(-1);
m_meshes.resize(meshes.size());
for (uint32_t i = 0; i < static_cast<uint32_t>(meshes.size()); ++i)
{
    auto& meshView = meshes[i];
    auto& mesh = m_meshes[i];

    // Index data: raw bytes. Bytes-per-index and index count come from the accessor.
    {
        Accessor& accessor = accessors[meshView.Indices];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.IndexSize = accessor.Size;
        mesh.IndexCount = accessor.Count;
        mesh.Indices = MakeSpan(m_buffer.data() + bufferView.Offset, bufferView.Size);
    }

    // The typed views below cast raw blob bytes to the destination member's own element
    // pointer type (via decltype), so the cast always agrees with the member declaration.

    // Index subset data
    {
        Accessor& accessor = accessors[meshView.IndexSubsets];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.IndexSubsets = MakeSpan(reinterpret_cast<decltype(mesh.IndexSubsets.data())>(m_buffer.data() + bufferView.Offset), accessor.Count);
    }

    // Vertex data & layout metadata.
    // Several attributes may share one buffer view (interleaved vertices). Each distinct
    // buffer view becomes one vertex stream; vbMap records the stream order.
    std::vector<uint32_t> vbMap;
    mesh.LayoutDesc.pInputElementDescs = mesh.LayoutElems;
    mesh.LayoutDesc.NumElements = 0;
    for (uint32_t j = 0; j < Attribute::Count; ++j)
    {
        if (meshView.Attributes[j] == c_absentAttribute)
            continue;

        Accessor& accessor = accessors[meshView.Attributes[j]];
        if (std::find(vbMap.begin(), vbMap.end(), accessor.BufferView) != vbMap.end())
        {
            continue; // this buffer view is already registered as a stream
        }

        // New buffer view: register it as the next vertex stream.
        vbMap.push_back(accessor.BufferView);
        BufferView& bufferView = bufferViews[accessor.BufferView];
        auto verts = MakeSpan(m_buffer.data() + bufferView.Offset, bufferView.Size);
        mesh.VertexStrides.push_back(accessor.Stride);
        mesh.Vertices.push_back(verts);
        // Guard against a malformed zero stride (integer division by zero is UB).
        mesh.VertexCount = accessor.Stride != 0 ? static_cast<uint32_t>(verts.size()) / accessor.Stride : 0;
    }

    // Fill the input layout from the accessors. Each element's InputSlot is the stream
    // index of the buffer view that holds its data.
    for (uint32_t j = 0; j < Attribute::Count; ++j)
    {
        if (meshView.Attributes[j] == c_absentAttribute)
            continue;

        Accessor& accessor = accessors[meshView.Attributes[j]];
        // Always found: the previous loop registered every present attribute's buffer view.
        auto it = std::find(vbMap.begin(), vbMap.end(), accessor.BufferView);

        D3D12_INPUT_ELEMENT_DESC desc = c_elementDescs[j];
        desc.InputSlot = static_cast<uint32_t>(std::distance(vbMap.begin(), it));
        mesh.LayoutElems[mesh.LayoutDesc.NumElements++] = desc;
    }

    // Meshlet data
    {
        Accessor& accessor = accessors[meshView.Meshlets];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.Meshlets = MakeSpan(reinterpret_cast<decltype(mesh.Meshlets.data())>(m_buffer.data() + bufferView.Offset), accessor.Count);
    }

    // Meshlet subset data
    {
        Accessor& accessor = accessors[meshView.MeshletSubsets];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.MeshletSubsets = MakeSpan(reinterpret_cast<decltype(mesh.MeshletSubsets.data())>(m_buffer.data() + bufferView.Offset), accessor.Count);
    }

    // Unique vertex index data: raw bytes (16- or 32-bit indices, per IndexSize).
    {
        Accessor& accessor = accessors[meshView.UniqueVertexIndices];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.UniqueVertexIndices = MakeSpan(m_buffer.data() + bufferView.Offset, bufferView.Size);
    }

    // Primitive index data (packed triangles)
    {
        Accessor& accessor = accessors[meshView.PrimitiveIndices];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.PrimitiveIndices = MakeSpan(reinterpret_cast<decltype(mesh.PrimitiveIndices.data())>(m_buffer.data() + bufferView.Offset), accessor.Count);
    }

    // Cull data
    {
        Accessor& accessor = accessors[meshView.CullData];
        BufferView& bufferView = bufferViews[accessor.BufferView];
        mesh.CullingData = MakeSpan(reinterpret_cast<decltype(mesh.CullingData.data())>(m_buffer.data() + bufferView.Offset), accessor.Count);
    }
}
最后为每个网格(Mesh)计算包围球,并把它们合并为整个模型的包围球(供后续剔除使用,本示例中暂未用到此数据):
// Compute a bounding sphere for each mesh from its POSITION attribute, and merge the
// spheres into m_boundingSphere, the bounds of the whole model.
bool haveModelBounds = false;
for (uint32_t i = 0; i < static_cast<uint32_t>(m_meshes.size()); ++i)
{
    auto& m = m_meshes[i];

    // Find the first POSITION element, searching from element 0.
    uint32_t posElement = m.LayoutDesc.NumElements;
    for (uint32_t j = 0; j < m.LayoutDesc.NumElements; ++j)
    {
        if (strcmp(m.LayoutElems[j].SemanticName, "POSITION") == 0)
        {
            posElement = j;
            break;
        }
    }
    if (posElement == m.LayoutDesc.NumElements)
    {
        // No POSITION attribute: nothing to bound. Exclude this mesh from the model bounds.
        continue;
    }

    // The vertex buffer that holds POSITION is that element's InputSlot, which indexes
    // m.Vertices / m.VertexStrides. It is not the element's index within LayoutElems.
    const uint32_t vbIndexPos = m.LayoutElems[posElement].InputSlot;

    // Byte offset of POSITION within one vertex of that buffer: the summed sizes of the
    // elements that precede it in the same slot.
    // NOTE(review): this assumes elements are tightly packed in LayoutElems order;
    // AlignedByteOffset is not consulted. Confirm against c_elementDescs.
    uint32_t positionOffset = 0;
    for (uint32_t j = 0; j < posElement; ++j)
    {
        if (m.LayoutElems[j].InputSlot == vbIndexPos)
        {
            positionOffset += GetFormatSize(m.LayoutElems[j].Format);
        }
    }

    const XMFLOAT3* v0 = reinterpret_cast<const XMFLOAT3*>(m.Vertices[vbIndexPos].data() + positionOffset);
    const uint32_t stride = m.VertexStrides[vbIndexPos];
    BoundingSphere::CreateFromPoints(m.BoundingSphere, m.VertexCount, v0, stride);

    if (!haveModelBounds)
    {
        m_boundingSphere = m.BoundingSphere;
        haveModelBounds = true;
    }
    else
    {
        BoundingSphere::CreateMerged(m_boundingSphere, m_boundingSphere, m.BoundingSphere);
    }
}
数据准备全了,接下来就是把数据更新到GPU了,其中涉及到数据的对应。
// Creates DEFAULT-heap GPU buffers for every mesh in m_meshes and fills them with the
// CPU-side data loaded earlier: vertices, indices, meshlets, cull data, unique vertex
// indices, packed primitive indices, and a MeshInfo record.
// Each mesh is staged through UPLOAD-heap buffers and copied on cmdQueue using
// cmdList/cmdAlloc. The CPU then waits on a fence until the copy completes, so the
// staging buffers can be released safely at the end of each iteration.
// Returns S_OK; D3D failures are reported through ThrowIfFailed.
HRESULT Model::UploadGpuResources(ID3D12Device* device, ID3D12CommandQueue* cmdQueue, ID3D12CommandAllocator* cmdAlloc, ID3D12GraphicsCommandList* cmdList)
{
    // Copies `size` bytes from `src` into an UPLOAD-heap buffer. The empty read range
    // tells the driver that the CPU does not read the mapped memory.
    auto writeToUpload = [](ID3D12Resource* upload, const void* src, size_t size)
    {
        const D3D12_RANGE noCpuRead = { 0, 0 };
        void* memory = nullptr;
        ThrowIfFailed(upload->Map(0, &noCpuRead, &memory));
        std::memcpy(memory, src, size);
        upload->Unmap(0, nullptr);
    };

    for (uint32_t i = 0; i < m_meshes.size(); ++i)
    {
        auto& m = m_meshes[i];

        // Destination buffers (DEFAULT heap), created in COPY_DEST to receive the upload.
        auto indexDesc = CD3DX12_RESOURCE_DESC::Buffer(m.Indices.size());
        auto meshletDesc = CD3DX12_RESOURCE_DESC::Buffer(m.Meshlets.size() * sizeof(m.Meshlets[0]));
        auto cullDataDesc = CD3DX12_RESOURCE_DESC::Buffer(m.CullingData.size() * sizeof(m.CullingData[0]));
        // Rounded up to a multiple of 4 bytes, presumably so the mesh shader's 4-byte
        // ByteAddressBuffer loads (16-bit indices are read in pairs) stay in bounds.
        auto vertexIndexDesc = CD3DX12_RESOURCE_DESC::Buffer(DivRoundUp(m.UniqueVertexIndices.size(), 4) * 4);
        auto primitiveDesc = CD3DX12_RESOURCE_DESC::Buffer(m.PrimitiveIndices.size() * sizeof(m.PrimitiveIndices[0]));
        auto meshInfoDesc = CD3DX12_RESOURCE_DESC::Buffer(sizeof(MeshInfo));

        auto defaultHeap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_DEFAULT);
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &indexDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.IndexResource)));
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &meshletDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.MeshletResource)));
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &cullDataDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.CullDataResource)));
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &vertexIndexDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.UniqueVertexIndexResource)));
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &primitiveDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.PrimitiveIndexResource)));
        ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &meshInfoDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.MeshInfoResource)));

        // Index buffer view: IndexSize 4 selects 32-bit indices, anything else 16-bit.
        m.IBView.BufferLocation = m.IndexResource->GetGPUVirtualAddress();
        m.IBView.Format = m.IndexSize == 4 ? DXGI_FORMAT_R32_UINT : DXGI_FORMAT_R16_UINT;
        m.IBView.SizeInBytes = m.IndexCount * m.IndexSize;

        // One vertex buffer and view per vertex stream.
        m.VertexResources.resize(m.Vertices.size());
        m.VBViews.resize(m.Vertices.size());
        for (uint32_t j = 0; j < m.Vertices.size(); ++j)
        {
            auto vertexDesc = CD3DX12_RESOURCE_DESC::Buffer(m.Vertices[j].size());
            ThrowIfFailed(device->CreateCommittedResource(&defaultHeap, D3D12_HEAP_FLAG_NONE, &vertexDesc, D3D12_RESOURCE_STATE_COPY_DEST, nullptr, IID_PPV_ARGS(&m.VertexResources[j])));
            m.VBViews[j].BufferLocation = m.VertexResources[j]->GetGPUVirtualAddress();
            m.VBViews[j].SizeInBytes = static_cast<uint32_t>(m.Vertices[j].size());
            m.VBViews[j].StrideInBytes = m.VertexStrides[j];
        }

        // Staging buffers (UPLOAD heap). They are locals of this iteration; they are
        // released only after the fence wait below confirms the GPU copy has finished.
        auto uploadHeap = CD3DX12_HEAP_PROPERTIES(D3D12_HEAP_TYPE_UPLOAD);
        std::vector<ComPtr<ID3D12Resource>> vertexUploads(m.Vertices.size());
        ComPtr<ID3D12Resource> indexUpload;
        ComPtr<ID3D12Resource> meshletUpload;
        ComPtr<ID3D12Resource> cullDataUpload;
        ComPtr<ID3D12Resource> uniqueVertexIndexUpload;
        ComPtr<ID3D12Resource> primitiveIndexUpload;
        ComPtr<ID3D12Resource> meshInfoUpload;
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &indexDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&indexUpload)));
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &meshletDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&meshletUpload)));
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &cullDataDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&cullDataUpload)));
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &vertexIndexDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&uniqueVertexIndexUpload)));
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &primitiveDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&primitiveIndexUpload)));
        ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &meshInfoDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&meshInfoUpload)));

        // Fill the staging buffers.
        for (uint32_t j = 0; j < m.Vertices.size(); ++j)
        {
            auto vertexDesc = CD3DX12_RESOURCE_DESC::Buffer(m.Vertices[j].size());
            ThrowIfFailed(device->CreateCommittedResource(&uploadHeap, D3D12_HEAP_FLAG_NONE, &vertexDesc, D3D12_RESOURCE_STATE_GENERIC_READ, nullptr, IID_PPV_ARGS(&vertexUploads[j])));
            writeToUpload(vertexUploads[j].Get(), m.Vertices[j].data(), m.Vertices[j].size());
        }
        writeToUpload(indexUpload.Get(), m.Indices.data(), m.Indices.size());
        writeToUpload(meshletUpload.Get(), m.Meshlets.data(), m.Meshlets.size() * sizeof(m.Meshlets[0]));
        writeToUpload(cullDataUpload.Get(), m.CullingData.data(), m.CullingData.size() * sizeof(m.CullingData[0]));
        writeToUpload(uniqueVertexIndexUpload.Get(), m.UniqueVertexIndices.data(), m.UniqueVertexIndices.size());
        writeToUpload(primitiveIndexUpload.Get(), m.PrimitiveIndices.data(), m.PrimitiveIndices.size() * sizeof(m.PrimitiveIndices[0]));
        {
            MeshInfo info = {};
            info.IndexSize = m.IndexSize;
            info.MeshletCount = static_cast<uint32_t>(m.Meshlets.size());
            if (m.Meshlets.size() > 0) // back() on an empty span is undefined
            {
                info.LastMeshletVertCount = m.Meshlets.back().VertCount;
                info.LastMeshletPrimCount = m.Meshlets.back().PrimCount;
            }
            writeToUpload(meshInfoUpload.Get(), &info, sizeof(MeshInfo));
        }

        // Record the copies and the post-copy transitions.
        // Buffers that shaders read as SRVs go to NON_PIXEL_SHADER_RESOURCE, because
        // PIXEL_SHADER_RESOURCE only covers pixel-shader reads. The index buffer goes to
        // INDEX_BUFFER to match IBView.
        ThrowIfFailed(cmdList->Reset(cmdAlloc, nullptr));
        const D3D12_RESOURCE_STATES srvState = D3D12_RESOURCE_STATE_NON_PIXEL_SHADER_RESOURCE;
        for (uint32_t j = 0; j < m.Vertices.size(); ++j)
        {
            cmdList->CopyResource(m.VertexResources[j].Get(), vertexUploads[j].Get());
            // VBViews are created for these buffers, and VertexResources[0] is also bound
            // as a mesh-shader SRV, so both read states are requested.
            const auto barrier = CD3DX12_RESOURCE_BARRIER::Transition(m.VertexResources[j].Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_VERTEX_AND_CONSTANT_BUFFER | srvState);
            cmdList->ResourceBarrier(1, &barrier);
        }

        D3D12_RESOURCE_BARRIER postCopyBarriers[6];
        cmdList->CopyResource(m.IndexResource.Get(), indexUpload.Get());
        postCopyBarriers[0] = CD3DX12_RESOURCE_BARRIER::Transition(m.IndexResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_INDEX_BUFFER);
        cmdList->CopyResource(m.MeshletResource.Get(), meshletUpload.Get());
        postCopyBarriers[1] = CD3DX12_RESOURCE_BARRIER::Transition(m.MeshletResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, srvState);
        cmdList->CopyResource(m.CullDataResource.Get(), cullDataUpload.Get());
        postCopyBarriers[2] = CD3DX12_RESOURCE_BARRIER::Transition(m.CullDataResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, srvState);
        cmdList->CopyResource(m.UniqueVertexIndexResource.Get(), uniqueVertexIndexUpload.Get());
        postCopyBarriers[3] = CD3DX12_RESOURCE_BARRIER::Transition(m.UniqueVertexIndexResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, srvState);
        cmdList->CopyResource(m.PrimitiveIndexResource.Get(), primitiveIndexUpload.Get());
        postCopyBarriers[4] = CD3DX12_RESOURCE_BARRIER::Transition(m.PrimitiveIndexResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, srvState);
        cmdList->CopyResource(m.MeshInfoResource.Get(), meshInfoUpload.Get());
        postCopyBarriers[5] = CD3DX12_RESOURCE_BARRIER::Transition(m.MeshInfoResource.Get(), D3D12_RESOURCE_STATE_COPY_DEST, D3D12_RESOURCE_STATE_VERTEX_AND_CONSTANT_BUFFER);
        cmdList->ResourceBarrier(ARRAYSIZE(postCopyBarriers), postCopyBarriers);

        ThrowIfFailed(cmdList->Close());
        ID3D12CommandList* ppCommandLists[] = { cmdList };
        cmdQueue->ExecuteCommandLists(1, ppCommandLists);

        // Block until the GPU has finished the copy; the staging buffers are released
        // when this iteration ends.
        ComPtr<ID3D12Fence> fence;
        ThrowIfFailed(device->CreateFence(0, D3D12_FENCE_FLAG_NONE, IID_PPV_ARGS(&fence)));
        ThrowIfFailed(cmdQueue->Signal(fence.Get(), 1));
        if (fence->GetCompletedValue() < 1)
        {
            HANDLE event = CreateEvent(nullptr, FALSE, FALSE, nullptr);
            if (event == nullptr)
            {
                ThrowIfFailed(HRESULT_FROM_WIN32(GetLastError()));
            }
            const HRESULT hr = fence->SetEventOnCompletion(1, event);
            if (SUCCEEDED(hr))
            {
                WaitForSingleObjectEx(event, INFINITE, FALSE);
            }
            CloseHandle(event);
            ThrowIfFailed(hr);
        }
    }
    return S_OK;
}
之后在绘制时为每个网格绑定所需的资源,并按每个 meshlet 子集中的 meshlet 数量调用 DispatchMesh 分发网格着色器线程组:
// Bind each mesh's buffers and dispatch one mesh-shader thread group per meshlet.
// Root parameters follow ROOT_SIG in the mesh shader:
//   1 = two root constants at b1 (MeshInfo: IndexBytes, MeshletOffset)
//   2..5 = root SRVs t0..t3 (vertices, meshlets, unique vertex indices, primitives)
for (auto& mesh : m_model)
{
    // b1[0]: bytes per unique vertex index.
    m_commandList->SetGraphicsRoot32BitConstant(1, mesh.IndexSize, 0);

    const D3D12_GPU_VIRTUAL_ADDRESS srvAddresses[] =
    {
        mesh.VertexResources[0]->GetGPUVirtualAddress(),       // t0
        mesh.MeshletResource->GetGPUVirtualAddress(),          // t1
        mesh.UniqueVertexIndexResource->GetGPUVirtualAddress(),// t2
        mesh.PrimitiveIndexResource->GetGPUVirtualAddress(),   // t3
    };
    const uint32_t firstSrvParameter = 2;
    for (uint32_t slot = 0; slot < sizeof(srvAddresses) / sizeof(srvAddresses[0]); ++slot)
    {
        m_commandList->SetGraphicsRootShaderResourceView(firstSrvParameter + slot, srvAddresses[slot]);
    }

    for (auto& subset : mesh.MeshletSubsets)
    {
        // b1[1]: index of the subset's first meshlet; the shader adds SV_GroupID to it.
        m_commandList->SetGraphicsRoot32BitConstant(1, subset.Offset, 1);
        m_commandList->DispatchMesh(subset.Count, 1, 1);
    }
}
主要来看一下MS,具体含义见注释:
// Root signature, embedded in the mesh shader through [RootSignature(ROOT_SIG)].
// Parameter order must match the SetGraphicsRoot* calls on the C++ side:
//   0: CBV b0                 -> Globals (Constants)
//   1: 2 root constants at b1 -> MeshInfo (IndexBytes, MeshletOffset)
//   2..5: root SRVs t0..t3    -> Vertices, Meshlets, UniqueVertexIndices, PrimitiveIndices
#define ROOT_SIG "CBV(b0), \
RootConstants(b1, num32bitconstants=2), \
SRV(t0), \
SRV(t1), \
SRV(t2), \
SRV(t3)"
// Constant buffer at b0 (Globals). The C++ buffer that fills it is not shown here;
// its layout must match this struct and the pixel shader's copy (TODO confirm).
struct Constants
{
float4x4 World;          // transforms normals (GetVertexAttributes)
float4x4 WorldView;      // produces PositionVS
float4x4 WorldViewProj;  // produces PositionHS (SV_Position)
uint DrawMeshlets;       // read by the pixel shader: non-zero = color by meshlet index
};
// The two root constants at b1, set per draw with SetGraphicsRoot32BitConstant(1, ...).
struct MeshInfo
{
uint IndexBytes;     // size of one unique vertex index: 4 = 32-bit, otherwise 16-bit
uint MeshletOffset;  // first meshlet of the current subset; SV_GroupID is added to it
};
// One element of the Vertices buffer (t0).
// NOTE(review): this assumes the C++ vertex buffer 0 bound at t0 is interleaved
// float3 position + float3 normal (24-byte stride). Confirm against the mesh file.
struct Vertex
{
float3 Position;
float3 Normal;
};
// Per-vertex output of the mesh shader. It must match the pixel shader's VertexOut input.
struct VertexOut
{
float4 PositionHS : SV_Position;  // clip-space position (WorldViewProj)
float3 PositionVS : POSITION0;    // view-space position (WorldView); the PS derives the view direction from it
float3 Normal : NORMAL0;          // normal transformed by World
uint MeshletIndex : COLOR0;       // thread-group id, used by the PS for per-meshlet debug colors
};
// One meshlet record from the Meshlets buffer (t1). It can be extended for culling and similar work.
struct Meshlet
{
uint VertCount;   // number of unique vertices in this meshlet
uint VertOffset;  // first entry in UniqueVertexIndices
uint PrimCount;   // number of triangles in this meshlet
uint PrimOffset;  // first entry in PrimitiveIndices
};
// Resource bindings. Registers must match ROOT_SIG, and element types must match the C++ data.
ConstantBuffer<Constants> Globals             : register(b0);
ConstantBuffer<MeshInfo>  MeshInfo            : register(b1);
StructuredBuffer<Vertex>  Vertices            : register(t0);
StructuredBuffer<Meshlet> Meshlets            : register(t1);
ByteAddressBuffer         UniqueVertexIndices : register(t2); // 16- or 32-bit indices, per MeshInfo.IndexBytes
StructuredBuffer<uint>    PrimitiveIndices    : register(t3); // one packed triangle per uint (see UnpackPrimitive)
// Data Loaders
// Splits one packed triangle into its three 10-bit meshlet-local vertex indices,
// stored in bits [0,10), [10,20) and [20,30).
uint3 UnpackPrimitive(uint primitive)
{
    const uint3 fieldShift = uint3(0, 10, 20);
    return (primitive.xxx >> fieldShift) & 0x3FF;
}
// Returns the meshlet-local vertex indices of triangle `index` of meshlet `m`.
uint3 GetPrimitive(Meshlet m, uint index)
{
    uint packedTriangle = PrimitiveIndices[m.PrimOffset + index];
    return UnpackPrimitive(packedTriangle);
}
// Maps a meshlet-local vertex index to an index into the Vertices buffer, so the vertex
// attributes can be fetched. The lookup table UniqueVertexIndices stores 32-bit entries
// when MeshInfo.IndexBytes == 4, and 16-bit entries otherwise.
uint GetVertexIndex(Meshlet m, uint localIndex)
{
    uint entry = m.VertOffset + localIndex;

    if (MeshInfo.IndexBytes != 4) // 16-bit vertex indices
    {
        // ByteAddressBuffer loads must be 4-byte aligned: fetch the dword holding the
        // pair of 16-bit entries, then select the low or high half.
        uint packedPair = UniqueVertexIndices.Load((entry >> 1) << 2);
        uint halfShift = (entry & 1) * 16;
        return (packedPair >> halfShift) & 0xffff;
    }

    // 32-bit vertex indices
    return UniqueVertexIndices.Load(entry * 4);
}
// Builds the mesh shader's per-vertex output, playing the role of a vertex shader.
// meshletIndex is passed through for the pixel shader's debug coloring.
VertexOut GetVertexAttributes(uint meshletIndex, uint vertexIndex)
{
    Vertex source = Vertices[vertexIndex];
    float4 position = float4(source.Position, 1);

    VertexOut result;
    result.PositionHS = mul(position, Globals.WorldViewProj);
    result.PositionVS = mul(position, Globals.WorldView).xyz;
    result.Normal = mul(float4(source.Normal, 0), Globals.World).xyz;
    result.MeshletIndex = meshletIndex;
    return result;
}
// Mesh shader entry point: one thread group per meshlet, where the meshlet index is
// MeshInfo.MeshletOffset + SV_GroupID. Each of the 128 threads writes at most one
// triangle and at most one vertex. The output arrays hold up to 126 triangles and
// 64 vertices per meshlet.
[RootSignature(ROOT_SIG)]
[NumThreads(128, 1, 1)]
[OutputTopology("triangle")]
void main(
    uint gtid : SV_GroupThreadID,
    uint gid : SV_GroupID,
    out indices uint3 tris[126],
    out vertices VertexOut verts[64]
)
{
    Meshlet meshlet = Meshlets[MeshInfo.MeshletOffset + gid];

    // Declare the output sizes before writing to tris/verts.
    SetMeshOutputCounts(meshlet.VertCount, meshlet.PrimCount);

    bool writesTriangle = gtid < meshlet.PrimCount;
    bool writesVertex = gtid < meshlet.VertCount;

    if (writesTriangle)
    {
        tris[gtid] = GetPrimitive(meshlet, gtid);
    }
    if (writesVertex)
    {
        uint globalVertex = GetVertexIndex(meshlet, gtid);
        verts[gtid] = GetVertexAttributes(gid, globalVertex);
    }
}
PS:
// Same b0 constant buffer as in the mesh shader; this layout must stay identical to it.
struct Constants
{
float4x4 World;
float4x4 WorldView;
float4x4 WorldViewProj;
uint DrawMeshlets;  // non-zero: color each meshlet differently (see main)
};
// Interpolated input from the mesh shader; it must match the mesh shader's VertexOut.
struct VertexOut
{
float4 PositionHS : SV_Position;
float3 PositionVS : POSITION0;  // view-space position, used for the view direction
float3 Normal : NORMAL0;
uint MeshletIndex : COLOR0;     // meshlet id, used for debug coloring
};
ConstantBuffer<Constants> Globals : register(b0); // shared with the mesh shader's b0
// Pixel shader: Blinn-Phong lighting with a fixed directional light plus an ambient term.
// When Globals.DrawMeshlets is non-zero, the diffuse color is derived from the meshlet
// index so that meshlet boundaries are visible.
// NOTE(review): the mesh shader transforms Normal by Globals.World, while viewDir below
// comes from the view-space position, so the specular term mixes coordinate spaces.
// Confirm whether this is intended.
float4 main(VertexOut input) : SV_TARGET
{
    const float ambientIntensity = 0.1;
    const float3 lightColor = float3(1, 1, 1); // declared but not applied to the result
    const float3 lightDir = -normalize(float3(1, -1, 1));

    float3 diffuseColor = 0.8;
    float shininess = 64.0;
    if (Globals.DrawMeshlets)
    {
        uint id = input.MeshletIndex;
        diffuseColor = float3(float(id & 1), float(id & 3) / 4, float(id & 7) / 8);
        shininess = 16.0;
    }

    float3 N = normalize(input.Normal);
    float3 V = -normalize(input.PositionVS);
    float3 H = normalize(lightDir + V);

    float diffuse = saturate(dot(N, lightDir));
    // The specular highlight only appears on lit surfaces (pow(0, shininess) is 0).
    float specular = 0.0;
    if (diffuse != 0.0)
    {
        specular = pow(saturate(dot(N, H)), shininess);
    }

    float3 finalColor = (diffuse + specular + ambientIntensity) * diffuseColor;
    return float4(finalColor, 1);
}