reconic 天空 模型

目录

推理代码:

EnvLight 代码:


推理代码:

        sky_model = self.models["Sky"]
        outputs["rgb_sky"] = sky_model(image_info)
        outputs["rgb_sky_blend"] = outputs["rgb_sky"] * (1.0 - outputs["opacity"])

EnvLight 代码:

import torch

# 定义环境光类(EnvLight),继承自 torch.nn.Module
class EnvLight(torch.nn.Module):
    def __init__(self, class_name: str, resolution: int = 1024, device: torch.device = torch.device("cuda"), **kwargs):
        # 初始化函数,接收类名、分辨率、设备(默认 GPU)以及其他关键字参数
        super().__init__()
        
        # 设置类的前缀,方便后续参数管理
        self.class_prefix = class_name + "#"
        
        # 设置设备(默认为 GPU)
        self.device = device
        
        # 定义 OpenGL 转换矩阵,将世界坐标系转换为 OpenGL 坐标系
        # 该矩阵的作用是转换方向向量
        self.to_opengl = torch.tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]], dtype=torch.float32, device="cuda")
        
        # 定义基础光照参数:初始化为一个 6 x resolution x resolution 的全 0.5 张量,
        # 每个光照样本有 3 个值(RGB)。该参数是可训练的(requires_grad=True)
        self.base = torch.nn.Parameter(
            0.5 * torch.ones(6, resolution, resolution, 3, requires_grad=True),
        )

    def forward(self, image_info: ImageInfo):
        # 前向传播函数,接受一个 ImageInfo 类型的输入(包含射线信息)

        # 获取传入图像信息中的方向向量(viewdirs),表示视角方向
        directions = image_info.rays.viewdirs

        # 将方向向量从世界坐标系转换到 OpenGL 坐标系
        directions = (directions.reshape(-1, 3) @ self.to_opengl.T).reshape(*directions.shape)
        
        # 重新调整方向向量的内存布局为连续的,以便后续操作
        directions = directions.contiguous()
        
        # 获取方向向量的前缀尺寸,用于后续的形状调整
        prefix = directions.shape[:-1]
        
        # 如果前缀尺寸不是三维(即 [B, H, W]),则将方向向量重塑为 [1, 1, -1, 3]
        # 目的是将其转换为适合批量处理的形状
        if len(prefix) != 3:  # reshape to [B, H, W, -1]
            directions = directions.reshape(1, 1, -1, directions.shape[-1])

        # 使用 dr.texture 函数计算光照(dr 是某个光照计算库)
        # `self.base[None, ...]` 代表基础光照纹理,`directions` 是输入的方向向量
        # `filter_mode="linear"` 表示纹理的过滤模式,`boundary_mode="cube"` 表示纹理的边界模式
        light = dr.texture(self.base[None, ...], directions, filter_mode="linear", boundary_mode="cube")
        
        # 将输出的光照结果 reshaped 为适合的形状
        light = light.view(*prefix, -1)

        return light

    def get_param_groups(self):
        # 获取模型参数分组,返回一个字典
        # 这里我们将所有参数归为一个组,键为 "class_name + all"
        return {
            self.class_prefix + "all": self.parameters(),
        }

你可能感兴趣的:(python基础,3d渲染,python,pytorch,深度学习)