An l2_normalize equivalent in PyTorch

pytorch_l2_normalize.py

import torch
import tensorflow as tf

########## PyTorch Version 1 ################
x = torch.randn(5, 6)
norm_th = x/torch.norm(x, p=2, dim=1, keepdim=True)
norm_th[torch.isnan(norm_th)] = 0  # all-zero rows give 0/0 = NaN; set them to 0

########## PyTorch Version 2 ################
norm_th = torch.nn.functional.normalize(x, p=2, dim=1)

########### TensorFlow equivalent ############
norm_tf = tf.nn.l2_normalize(x.numpy(), axis=1)

print(norm_th)
print(norm_tf)
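The two PyTorch versions above are interchangeable. To see why Version 2 needs no NaN patch, the following minimal sketch runs both on a batch that contains an all-zero row, then writes out the formula `torch.nn.functional.normalize` applies: v / max(‖v‖₂, eps), where eps defaults to 1e-12. It assumes PyTorch 1.10 or later (for `torch.linalg.vector_norm`). The helper `l2_normalize` is a hypothetical name introduced here for illustration and is not a PyTorch API.

```python
import torch
import torch.nn.functional as F

# A small batch with one all-zero row to exercise the edge case
x = torch.randn(4, 6)
x[2] = 0.0

# Version 1 style: divide by each row's L2 norm, then replace the NaNs
# produced by 0/0 on the zero row
v1 = x / torch.norm(x, p=2, dim=1, keepdim=True)
v1[torch.isnan(v1)] = 0

# Hypothetical helper spelling out what F.normalize computes:
# divide by the norm, clamped from below at eps, so the zero row stays 0
def l2_normalize(t: torch.Tensor, dim: int = 1, eps: float = 1e-12) -> torch.Tensor:
    norm = torch.linalg.vector_norm(t, ord=2, dim=dim, keepdim=True)
    return t / norm.clamp_min(eps)

v2 = F.normalize(x, p=2, dim=1)
v3 = l2_normalize(x, dim=1)

print(torch.allclose(v1, v2), torch.allclose(v2, v3))
print(v2.norm(dim=1))  # approximately 1.0 for non-zero rows, 0.0 for the zero row
```

One difference from TensorFlow is worth knowing: `tf.nn.l2_normalize` divides by sqrt(max(sum(x**2), epsilon)), so its epsilon (also 1e-12 by default) floors the sum of squares rather than the norm. The two libraries therefore agree except on vectors whose norm is below roughly 1e-6. Also note that Version 1's NaN patch zeroes any NaN, including NaNs that were already present in the input, whereas `F.normalize` propagates them.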

Reference:
https://gist.github.com/EdisonLeeeee/290691c8b1895427024875c3fafece67
