记录_nn.functional.one_hot

记录_nn.functional.one_hot(x, num_classes)

  • 作用:将x的元素转化为0/1标签值
  • Return.shape: (x.size, num_classes)
  • num_classes 为目标类别数,需要>=实际类别数(输入x的元素值的类别)
a = torch.randint(0, 24, (3, 4, 5))  
>>>
tensor([[[21,  2,  0,  5, 19],
         [10,  6, 10,  1,  5],
         [10,  7, 11,  0,  2],
         [17, 15, 16,  1,  6]],

        [[10,  2,  3,  8, 12],
         [10,  5, 18, 13,  6],
         [ 3, 14, 10, 16,  6],
         [15,  5,  4,  6,  3]],

        [[21, 14, 11,  2, 11],
         [ 3, 17,  0, 16,  5],
         [16, 15,  8,  3, 10],
         [21, 13, 15, 10, 14]]])
  • 寻找dim=2的最小值序列
b = torch.argmin(a, dim=2)  
>>>
	tensor([[2, 3, 3, 3],
					 [1, 1, 0, 4],
					 [3, 2, 3, 3]]) torch.Size([3, 4])
c = torch.nn.functional.one_hot(b, 5)   
  • 注意此处 5 >= b中索引值(0,1,2,3,4)的总数5
>>>
tensor([[[0, 0, 1, 0, 0], 	# 2
		 [0, 0, 0, 1, 0],	# 3
		 [0, 0, 0, 1, 0],	# 3
         [0, 0, 0, 1, 0]],	# 3
		
         [[0, 1, 0, 0, 0],	# 1
          [0, 1, 0, 0, 0],	# 1
          [1, 0, 0, 0, 0],	# 0
          [0, 0, 0, 0, 1]],	# 4
		
		 [[0, 0, 0, 1, 0],	
		  [0, 0, 1, 0, 0],		
		  [0, 0, 0, 1, 0],			
		  [0, 0, 0, 1, 0]]]) torch.Size([3, 4, 5])
  • 比如,num_classes = 6,但其实是多余了
c = torch.nn.functional.one_hot(b, 6) 
tensor([[[0, 0, 1, 0, 0, 0],
		 [0, 0, 0, 1, 0, 0],
		 [0, 0, 0, 1, 0, 0],
		 [0, 0, 0, 1, 0, 0]],
		
		[[0, 1, 0, 0, 0, 0],
		 [0, 1, 0, 0, 0, 0],
		 [1, 0, 0, 0, 0, 0],
		 [0, 0, 0, 0, 1, 0]],
		
		[[0, 0, 0, 1, 0, 0],
		 [0, 0, 1, 0, 0, 0],
	 	 [0, 0, 0, 1, 0, 0],
		 [0, 0, 0, 1, 0, 0]]]) torch.Size([3, 4, 6])

你可能感兴趣的:(记录_nn.functional.one_hot)