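A walkthrough of the diffusers codebase

Common classes

This section summarizes a few supporting classes that show up frequently throughout the library.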

from collections import OrderedDict
from dataclasses import dataclass

import torch

# BaseOutput lives in the utils package; its definition is repeated here
# only to show which class it inherits from.
class BaseOutput(OrderedDict):
	...

# models/unet_2d.py pulls it in with `from ..utils import BaseOutput`.
@dataclass
class UNet2DOutput(BaseOutput):
	"""
	The output of [`UNet2DModel`].
	Args:
		sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
			The hidden states output from the last layer of the model.
	"""
	sample: torch.FloatTensor

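BaseOutput inherits from OrderedDict, so it remembers the order in which entries were inserted. It is the base class for every model output. For example, models/unet_2d.py defines UNet2DOutput as the output class of that model. The class is also decorated with @dataclass, which signals that it exists only to carry the output data.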
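To make the combination of OrderedDict and @dataclass concrete, here is a minimal standard-library sketch. SimpleOutput and ToyUNetOutput are hypothetical names invented for this illustration, and a plain list of floats stands in for the tensor; the real BaseOutput does more and may differ in its details.

```python
from collections import OrderedDict
from dataclasses import dataclass, fields
from typing import Any, List, Tuple


class SimpleOutput(OrderedDict):
    """Hypothetical, simplified stand-in for an OrderedDict-backed output base class."""

    def __post_init__(self):
        # The dataclass-generated __init__ calls this hook. Copy every
        # non-None field into the underlying dict, in declaration order.
        for f in fields(self):
            value = getattr(self, f.name)
            if value is not None:
                self[f.name] = value

    def __getitem__(self, key):
        # String keys behave like a dict lookup; integer keys index the
        # values in insertion order, like a tuple.
        if isinstance(key, str):
            return super().__getitem__(key)
        return self.to_tuple()[key]

    def to_tuple(self) -> Tuple[Any, ...]:
        # Values in insertion order.
        return tuple(self[k] for k in self.keys())


@dataclass
class ToyUNetOutput(SimpleOutput):
    # A list of floats stands in for the tensor field here.
    sample: List[float]


out = ToyUNetOutput(sample=[0.1, 0.2])
print(out.sample)     # attribute access
print(out["sample"])  # dict-style access
print(out[0])         # tuple-style access
```

All three access styles return the same object, which is why a caller can treat such an output as a dataclass, a dict, or a tuple.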

from .modeling_utils import ModelMixin
from ..configuration_utils import ConfigMixin, register_to_config

class UNet2DModel(ModelMixin, ConfigMixin):
	"""
	A 2D UNet model that takes a noisy sample and a timestep and returns a sample shaped output.
	"""
	...

unet

UNet2DModel

The body of the model consists of three parts: down_blocks, mid_blocks, and up_blocks. Besides sample, the inputs also include time_embedding and label_embedding.
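The data flow described above can be sketched with a toy PyTorch module. This is only an illustration under stated assumptions: TinyUNet and sinusoidal_embedding are hypothetical names, single convolutions stand in for the real blocks, the sinusoidal timestep encoding is an assumed choice, and the way the embedding is added to the features is a simplification rather than the diffusers implementation.

```python
import math

import torch
import torch.nn.functional as F
from torch import nn


def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Map timesteps of shape (B,) to sinusoidal features of shape (B, dim)."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class TinyUNet(nn.Module):
    """Hypothetical toy model; single convolutions stand in for real blocks."""

    def __init__(self, channels: int = 3, width: int = 16, num_classes: int = 0):
        super().__init__()
        self.width = width
        self.time_embedding = nn.Sequential(
            nn.Linear(width, width), nn.SiLU(), nn.Linear(width, width)
        )
        # Optional label embedding, only created when classes are given.
        self.label_embedding = nn.Embedding(num_classes, width) if num_classes > 0 else None
        self.conv_in = nn.Conv2d(channels, width, 3, padding=1)
        # Each down block halves the spatial resolution.
        self.down_blocks = nn.ModuleList(
            [nn.Conv2d(width, width, 3, stride=2, padding=1) for _ in range(2)]
        )
        self.mid_blocks = nn.Sequential(nn.Conv2d(width, width, 3, padding=1), nn.SiLU())
        # Each up block consumes upsampled features concatenated with a skip.
        self.up_blocks = nn.ModuleList(
            [nn.Conv2d(2 * width, width, 3, padding=1) for _ in range(2)]
        )
        self.conv_out = nn.Conv2d(width, channels, 3, padding=1)

    def forward(self, sample, timestep, class_labels=None):
        emb = self.time_embedding(sinusoidal_embedding(timestep, self.width))
        if self.label_embedding is not None and class_labels is not None:
            emb = emb + self.label_embedding(class_labels)
        emb = emb[:, :, None, None]  # broadcast over height and width

        h = self.conv_in(sample)
        skips = []
        for down in self.down_blocks:
            skips.append(h)  # remember features before downsampling
            h = F.silu(down(h + emb))
        h = self.mid_blocks(h + emb)
        for up in self.up_blocks:
            h = F.interpolate(h, scale_factor=2.0, mode="nearest")
            h = torch.cat([h, skips.pop()], dim=1)  # skip connection
            h = F.silu(up(h))
        return self.conv_out(h)


model = TinyUNet(num_classes=10)
x = torch.randn(2, 3, 32, 32)
out = model(x, torch.tensor([1, 500]), torch.tensor([3, 7]))
print(out.shape)  # same shape as the input sample
```

The skip list is filled on the way down and popped in reverse on the way up, so each up block sees the features of the matching resolution from the down path.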

down_blocks
mid_blocks
up_blocks
embeddings
forward
