pytorch 常用函数参数详解

1、torch.cat(inputs, dim=0) -> Tensor 

参考链接:

[Pytorch] 详解 torch.cat()

Pytorch学习笔记(一):torch.cat()模块的详解

函数作用:cat 是 concatnate 的意思:拼接,联系在一起。在给定维度上对输入的 Tensor 序列进行拼接操作。torch.cat 可以看作是 torch.split 和 torch.chunk 的反操作

参数:

inputs(sequence of Tensors):可以是任意相同类型的 Tensor 的 python 序列

dim(int, optional):defaults=0

dim=0: 按列进行拼接 

dim=1: 按行进行拼接

dim=-1: 如果行和列数都相同则按行进行拼接,否则按照行数或列数相等的维度进行拼接

假设 a 和 b 都是 Tensor,且 a 的维度为 [2, 3],b 的维度为 [2, 4],则

torch.cat((a, b), dim=1) 的维度为 [2, 7]


2、torch.nn.CrossEntropyLoss()

函数作用:CrossEntropy 是交叉熵的意思,故而 CrossEntropyLoss 的作用是计算交叉熵。CrossEntropyLoss 函数是将 torch.nn.Softmax 和 torch.nn.NLLLoss 两个函数组合在一起使用,故而传入的预测值不需要先进行 torch.nnSoftmax 操作。

参数:

input(N, C):N 是 batch_size,C 则是类别数,即在定义模型输出时,输出节点个数要定义为 [N, C]。其中特别注意的是 target 的数据类型需要是浮点数,即 float32

target(N):N 是 batch_size,故 target 需要是 1D 张量。其中特别注意的是 target 的数据类型需要是 long,即 int64

例子:

loss = nn.CrossEntropyLoss()

input = torch.randn(3, 5, requires_grad=True, dtype=torch.float32)

target = torch.empty(3, dtype=torch.long).random_(5)

output = loss(input, target)

output

输出为:

tensor(1.6916, grad_fn=)

你可能感兴趣的:(pytorch 常用函数参数详解)