Pytorch 读取t7文件

Pytorch 1.0以上可以使用:

import torchfile

th_path = r"./path/xx.t7"
data = torchfile.load(th_path)

print(data.shape)

若data的尺寸为0,则将torch版本降为0.4.1,并使用以下函数:

from torch.utils.serialization import load_lua

th_path = r"./path/xx.t7"
data = load_lua(th_path).numpy()

print(data.shape)

注意:
若是在Windows的系统中读取t7文件,一定要记得要用long_size=8

data = torchfile.load(th_path,long_size=8)
或
data = load_lua(th_path,long_size=8).numpy()

你可能感兴趣的:(Python,pytorch,人工智能,python,t7)