Marking the maximum along a dimension in torch

 

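import torch

a = torch.rand(2, 2, 2)
print(a)

# Maximum along dim 2; keepdim=True keeps shape (2, 2, 1) so it broadcasts against a.
# a.max(2, keepdim=True) is the equivalent method form.
values, indices = torch.max(a, 2, keepdim=True)
c = a - values  # zero exactly where a equals the maximum of its slice

# Position of the maximum along dim 2 (not used below; shown for reference).
b = torch.argmax(a, dim=2).type(torch.uint8)

a[c == 0] = 2  # mark every maximal element with the value 2
print(a)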

 

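A simpler approach: compare a against its maximum along the dimension and index with the resulting boolean mask. This selects the maximal elements; assign to the selection to mark them.

a[a == torch.max(a, 2, keepdim=True)[0]]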
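The equality mask above marks every element tied for the maximum. If exactly one mark per slice is wanted and the original tensor should stay untouched, one possible variant is to scatter the mark at the argmax index. This is only a sketch: the helper name mark_max and the mark value 2.0 are assumptions made for illustration, not part of the original post.

```python
import torch

def mark_max(a: torch.Tensor, dim: int, mark: float = 2.0) -> torch.Tensor:
    """Return a copy of `a` with one maximum per slice along `dim` set to `mark`.

    Hypothetical helper, not from the original post.
    """
    # keepdim=True keeps the reduced dimension so idx has the shape scatter_ expects.
    idx = torch.argmax(a, dim=dim, keepdim=True)
    out = a.clone()  # leave the input unchanged
    # Write `mark` at a single index per slice, even when several values tie.
    out.scatter_(dim, idx, mark)
    return out

a = torch.tensor([[[0.1, 0.9], [0.7, 0.7]],
                  [[0.3, 0.2], [0.5, 0.8]]])
marked = mark_max(a, dim=2)
print(marked)
```

Compared with a[c == 0] = 2, this version does not modify a in place and does not depend on an exact floating-point comparison.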
