图片标签 转为 one-hot 张量

图片大小 (530, 730),其中每个像素点的值代表其属于的类别,总共 38 个类,num_class = 38,现在要把每个像素点的值转化成 one-hot 形式。

import numpy as np
import torch

label = np.load('label.npy').astype('int')  # [0,37]

print(label.shape)  # (530, 730)

num_class = int(np.max(label) - np.min(label) + 1)
print(num_class)  # 38

# print(label[0])
print(label[200])
# see some class value
print(label[200][0])
print(label[200][48])
print(label[200][240])

# (530, 730) -> # (1, 530, 730)
label = label[np.newaxis, :, :]  # add new dim in any dim
print(label.shape)

# (1, 530, 730) -> [1, 530, 730, 1]
label = torch.LongTensor(label).unsqueeze(3)
print(label.shape)

# [1, 530, 730, 1] -> [1, 530, 730, 38]
label = torch.zeros(label.shape[0], label.shape[1], label.shape[2], num_class).scatter_(3, label, 1).long()
print(label.shape)

# see the class value -> one-hot tensor
print(label[0][200][0])
print(label[0][200][48])
print(label[0][200][240])

# [1, 530, 730, 38] -> [1, 38, 530, 730]
label = label.transpose(1, 3).transpose(2, 3)

print(label.shape)

# print(label[])
(530, 730)
38
[ 1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5
  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  1  1  1  1  1
  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  1  1  1  1  1  1  1  1  5  5  5  5  5  5  5  5  5  5  5  5  5  5
  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5
  5  5  5  5  1  1  1  1  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7
  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  7  1  1  1  1
  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1  1
  1  1  0  0  0  0  0  0 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35
 35  0  0  0  0  0  0  0  0  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5
  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
  0  0  0  0  0  0  0  0  0  0]
1
5
7
(1, 530, 730)
torch.Size([1, 530, 730, 1])
torch.Size([1, 530, 730, 38])
tensor([0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
torch.Size([1, 38, 530, 730])

你可能感兴趣的:(图片标签 转为 one-hot 张量)