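Background: we need to upload the model to a cluster and run it there, so the pretrained model has to live inside the project folder and be loaded from there. Once the environment and configuration have been copied into env, no libraries outside the folder can be used, so the pretrained resnet101 has to be placed directly in the project directory and loaded from it.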
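Contents
1. Loading the pretrained model
1.1 Model loading
1.2 Loading flow
1.3 Model location
1.4 Drawbacks
1.5 Finding where the pretrained model is stored
2. Loading a model from a specified location
2.1 Example program
2.2 Putting the network model into the directory
2.3 Our program
3. Verification (optional)
4. Solving the pretrained-model problem on the cluster
4.1 The error
4.2 Where to load the model from
4.3 Copying to the server and running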
The model is loaded directly through torchvision's models module:
class HGAT_FC(nn.Module):
    def __init__(self, backbone, groups, nclasses, nclasses_per_group, group_channels, class_channels):
        super(HGAT_FC, self).__init__()
        self.groups = groups
        self.nclasses = nclasses
        self.nclasses_per_group = nclasses_per_group
        self.group_channels = group_channels
        self.class_channels = class_channels
        if backbone == 'resnet101':
            model = models.resnet101(pretrained=True)
        elif backbone == 'resnet50':
            model = models.resnet50(pretrained=False)
        else:
            raise ValueError('unsupported backbone: {}'.format(backbone))
The libraries to import include torchvision.models:
import torch
import torch.nn.functional as F
import torchvision.models as models
from torch import nn
import mymodels.utils as utils
...
if backbone == 'resnet101':
    model = models.resnet101(pretrained=True)
elif backbone == 'resnet50':
    model = models.resnet50(pretrained=False)
else:
    raise ValueError('unsupported backbone: {}'.format(backbone))
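cd ~ goes to the home directory. The listing below shows that torch caches its downloaded models under the home directory (~/.torch/models).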
$ cd ~/.torch/models
$ pwd
/home/xingxiangrui/.torch/models
$ ls
resnet101-5d3b4d8f.pth
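If the model has not been downloaded yet, torchvision downloads it automatically.
Without network access, or without the necessary permissions, the download fails, so the program cannot run and raises an error: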
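requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, 'A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond.', None, 10060, None))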
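The model therefore has to be loaded directly from a directory, using the method below.
The model location differs between environments (this article shows both ~/.torch/models and ~/.cache/torch/checkpoints). If the model has already been downloaded, you need to find where it is stored.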
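If you would rather not step through a debugger, a small search over the likely cache folders can also locate an already-downloaded file. This is a minimal sketch, assuming the weights sit in one of the two cache directories that appear in this article (~/.torch/models for older versions, ~/.cache/torch/checkpoints for newer ones); find_checkpoint and DEFAULT_CANDIDATES are hypothetical names, not part of PyTorch.

```python
import os
from typing import Iterable, Optional

# Assumed cache locations, taken from the listings in this article;
# adjust the tuple for your own environment.
DEFAULT_CANDIDATES = (
    "~/.torch/models",             # older torchvision / model_zoo cache
    "~/.cache/torch/checkpoints",  # newer torch.hub default
)


def find_checkpoint(filename: str,
                    candidates: Iterable[str] = DEFAULT_CANDIDATES) -> Optional[str]:
    """Return the first existing path to `filename` among `candidates`, or None."""
    for directory in candidates:
        path = os.path.join(os.path.expanduser(directory), filename)
        if os.path.isfile(path):
            return path
    return None
```

On the machine shown above, find_checkpoint("resnet101-5d3b4d8f.pth") would be expected to return the full path under ~/.torch/models; a None result means the file was never downloaded there.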
When pretrained=True, the relevant code in torchvision is:
def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
    return model
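Jump to the definition of load_url (Ctrl+B in PyCharm).
This leads to the corresponding code: if the model is already present locally, it is loaded from disk; otherwise it is downloaded from the URL.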
def load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True):
    r"""Loads the Torch serialized object at the given URL.
    If the object is already present in `model_dir`, it's deserialized and
    returned. The filename part of the URL should follow the naming convention
    ``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more
    digits of the SHA256 hash of the contents of the file. The hash is used to
    ensure unique names and to verify the contents of the file.
    The default value of `model_dir` is ``$TORCH_HOME/checkpoints`` where
    environment variable ``$TORCH_HOME`` defaults to ``$XDG_CACHE_HOME/torch``.
    ``$XDG_CACHE_HOME`` follows the X Design Group specification of the Linux
    filesytem layout, with a default value ``~/.cache`` if not set.
    Args:
        url (string): URL of the object to download
        model_dir (string, optional): directory in which to save the object
        map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load)
        progress (bool, optional): whether or not to display a progress bar to stderr
    Example:
        >>> state_dict = torch.hub.load_state_dict_from_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth')
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        warnings.warn('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
    if model_dir is None:
        torch_home = _get_torch_home()
        model_dir = os.path.join(torch_home, 'checkpoints')
    try:
        os.makedirs(model_dir)
    except OSError as e:
        if e.errno == errno.EEXIST:
            # Directory already exists, ignore.
            pass
        else:
            # Unexpected OSError, re-raise.
            raise
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    cached_file = os.path.join(model_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = HASH_REGEX.search(filename).group(1)
        _download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return torch.load(cached_file, map_location=map_location)
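The lookup logic above can be reproduced without PyTorch to see exactly which file and directory will be checked before a download is attempted. This is a sketch that mirrors the docstring's rules ($TORCH_HOME, $XDG_CACHE_HOME, ~/.cache) and the filename and hash handling in the code; default_model_dir and cached_file_for are hypothetical helper names, and the regular expression is an assumption modeled on HASH_REGEX, whose definition is not shown above. Older model_zoo versions (like the ones producing ~/.torch/models) use a different default directory.

```python
import os
import re
from typing import Optional, Tuple
from urllib.parse import urlparse

# Assumed pattern: the hash prefix sits between a "-" and the extension,
# e.g. "resnet101-5d3b4d8f.pth" -> "5d3b4d8f".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def default_model_dir() -> str:
    """$TORCH_HOME/checkpoints, where TORCH_HOME defaults to $XDG_CACHE_HOME/torch
    and XDG_CACHE_HOME defaults to ~/.cache (as described in the docstring)."""
    xdg_cache = os.getenv("XDG_CACHE_HOME", os.path.join("~", ".cache"))
    torch_home = os.getenv("TORCH_HOME", os.path.join(xdg_cache, "torch"))
    return os.path.join(os.path.expanduser(torch_home), "checkpoints")


def cached_file_for(url: str, model_dir: Optional[str] = None) -> Tuple[str, Optional[str]]:
    """Return (cached_file, hash_prefix) derived from the URL like the code above."""
    if model_dir is None:
        model_dir = default_model_dir()
    filename = os.path.basename(urlparse(url).path)
    match = HASH_REGEX.search(filename)
    hash_prefix = match.group(1) if match else None
    return os.path.join(model_dir, filename), hash_prefix
```

If cached_file_for(url) points at a file that already exists, the loader above returns it without touching the network, which is why copying the file into that directory (or loading it by path, as in section 2) avoids the download.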
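Set a breakpoint and step through in the debugger to find where the model is stored.
Loading from a local file means network access is no longer a concern, and it makes explicit exactly which model file is used.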
Reference: https://blog.csdn.net/u014264373/article/details/85332181
Load the weights directly from the .pth file.
For example:
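import torch
import torchvision.models as models
# pretrained=True would use the pretrained weights (downloading them if needed);
# here the network is built without them and the weights are loaded from a local file
net = models.squeezenet1_1(pretrained=False)
pthfile = r'E:\anaconda\app\envs\luo\Lib\site-packages\torchvision\models\squeezenet1_1.pth'
net.load_state_dict(torch.load(pthfile))
print(net)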
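Write the program so that it reads the weight file directly from the directory.
Loading directly from the directory
Put the file in the directory the program is run from (the snippet below is for reference only and may need adjusting before it runs):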
def gcn_resnet101(num_classes, t, pretrained=True, adj_file=None, in_channel=300):
    # fixme
    model = models.resnet101(pretrained=False)
    if pretrained:
        print('load pretrained model...')
        model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
    return GCNResnet(model, num_classes, t=t, adj_file=adj_file, in_channel=in_channel)
cp ~/.torch/models/resnet101-5d3b4d8f.pth chun-ML_GCN/
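Note that a relative path such as './resnet101-5d3b4d8f.pth' is resolved against the directory the program is launched from, so the file's location must match both that directory and the path passed to load_state_dict.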
if backbone == 'resnet101':
    model = models.resnet101(pretrained=False)
    print('load pretrained model...')
    model.load_state_dict(torch.load('./resnet101-5d3b4d8f.pth'))
elif backbone == 'resnet50':
    model = models.resnet50(pretrained=False)
    print('load pretrained model...')
    # the ResNet-50 checkpoint has its own hash suffix, not resnet101's
    model.load_state_dict(torch.load('./resnet50-19c8e357.pth'))
else:
    raise ValueError('unsupported backbone: {}'.format(backbone))
That is, the resnet101-5d3b4d8f.pth file in the working directory is loaded directly.
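Instead of relying on the working directory, the file name can be looked up per backbone and resolved against a fixed base directory, failing loudly when the file is missing rather than surfacing later as a network error. A minimal sketch, assuming the two file names that appear in this article; CHECKPOINTS and local_checkpoint are hypothetical names, and the torch.load call is only shown in a comment.

```python
import os

# File names as they appear in this article's ~/.torch/models listing.
CHECKPOINTS = {
    "resnet101": "resnet101-5d3b4d8f.pth",
    "resnet50": "resnet50-19c8e357.pth",
}


def local_checkpoint(backbone: str, base_dir: str) -> str:
    """Return the path of the checkpoint for `backbone` inside `base_dir`.

    Raises instead of falling back to a download, so a missing file is
    reported immediately.
    """
    if backbone not in CHECKPOINTS:
        raise ValueError("unsupported backbone: {}".format(backbone))
    path = os.path.join(base_dir, CHECKPOINTS[backbone])
    if not os.path.isfile(path):
        raise FileNotFoundError("pretrained weights not found: {}".format(path))
    return path

# Assumed usage inside the model file, anchoring on the script's own folder
# instead of the current working directory:
#   base = os.path.dirname(os.path.abspath(__file__))
#   model.load_state_dict(torch.load(local_checkpoint(backbone, base)))
```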
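This part verifies our own program and can be skipped, since everyone's model is different.
Make the changes exactly as described above.
In general_train.py, switch to exp_3, and modify hgat_fc.py as shown above.
Run env/bin/python general_train.py from the project directory; if no error is reported, the change works.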
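Solving the pretrained-model problem on the cluster
The error shows that the job on the cluster still tries to download the pretrained model: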
Downloading: "http://xxxxxxxr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth" to /home/xxx/.torch/models/se_resnet152-d17c99b7.pth
Traceback (most recent call last):
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 159, in _new_conn
(self._dns_host, self.port), self.timeout, **extra_kw)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 80, in create_connection
raise err
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/connection.py", line 70, in create_connection
sock.connect(sa)
OSError: [Errno 101] Network is unreachable
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 600, in urlopen
chunked=chunked)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 354, in _make_request
conn.request(method, url, **httplib_request_kw)
File "/home/sxxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1107, in request
self._send_request(method, url, body, headers)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/ccccccccc/client.py", line 1152, in _send_request
self.endheaders(body)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 1103, in endheaders
self._send_output(message_body)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 934, in _send_output
self.send(msg)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/http/client.py", line 877, in send
self.connect()
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 181, in connect
conn = self._new_conn()
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connection.py", line 168, in _new_conn
self, "Failed to establish a new connection: %s" % e)
urllib3.exceptions.NewConnectionError: : Failed to establish a new connection: [Errno 101] Network is unreachable
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/xx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 449, in send
timeout=timeout
File "/home/xxxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/connectionpool.py", line 638, in urlopen
_stacktrace=sys.exc_info()[2])
File "/home/xxxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/urllib3/util/retry.py", line 398, in increment
raise MaxRetryError(_pool, url, error or ResponseError(cause))
urllib3.exceptions.MaxRetryError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 101] Network is unreachable',))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "train_se_clsgat.py", line 128, in <module>
main()
File "train_se_clsgat.py", line 107, in main
model = util.get_model(args)
File "/home/xxx/job/tmp/job-25509/util.py", line 266, in get_model
class_channels=args.CLASS_CHANNELS)
File "/home/xxxx/job/tmp/job-25509/models/se_clsgat.py", line 379, in __init__
model=senet_origin.se_resnet152()
File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 423, in se_resnet152
initialize_pretrained_model(model, num_classes, settings)
File "/home/xxx/job/tmp/job-25509/models/senet_origin.py", line 377, in initialize_pretrained_model
model.load_state_dict(model_zoo.load_url(settings['url']))
File "/home/slurm/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 65, in load_url
_download_url_to_file(url, cached_file, hash_prefix, progress=progress)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/torch/utils/model_zoo.py", line 71, in _download_url_to_file
u = urlopen(url, stream=True)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 75, in get
return request('get', url, params=params, **kwargs)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/api.py", line 60, in request
return session.request(method=method, url=url, **kwargs)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 533, in request
resp = self.send(prep, **send_kwargs)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/sessions.py", line 646, in send
r = adapter.send(request, **kwargs)
File "/home/xxx/job/tmp/job-25509/torch/lib/python3.5/site-packages/requests/adapters.py", line 516, in send
raise ConnectionError(e, request=request)
requests.exceptions.ConnectionError: HTTPConnectionPool(host='data.lip6.fr', port=80): Max retries exceeded with url: /cadene/pretrainedmodels/se_resnet152-d17c99b7.pth (Caused by NewConnectionError(': Failed to establish a new connection: [Errno 101] Network is unreachable',))
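The pretrained model therefore needs to be placed inside the project directory so that the cluster does not try to download it again.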
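Before submitting a job, it can help to check that every file the model code would otherwise download is already in the project directory. A sketch, assuming you collect the weight URLs yourself (for example the SE-ResNet-152 URL from the traceback above); missing_local_copies is a hypothetical helper, and it only compares file names, not contents.

```python
import os
from typing import Iterable, List
from urllib.parse import urlparse


def missing_local_copies(urls: Iterable[str], project_dir: str) -> List[str]:
    """For each weight URL, check whether its file name exists in project_dir.

    Returns the file names that still have to be copied before submitting.
    """
    missing = []
    for url in urls:
        filename = os.path.basename(urlparse(url).path)
        if not os.path.isfile(os.path.join(project_dir, filename)):
            missing.append(filename)
    return missing

# Assumed usage right before submitting the job:
#   missing = missing_local_copies(
#       ["http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth"], ".")
#   if missing:
#       raise SystemExit("copy these files first: " + ", ".join(missing))
```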
The program did not reach the model-loading step:
==== GLOBAL INFO ====
IPLIST: xx.xx.xx.xx
IP0: xx.xx.xx.xx
====================
==== NODE INFO ====
NODE_RNAK: 0
IP0: xx.xx.xx.xx
NODE_IP: xx.xx.xx
===================
{'ADJ_FILE': 'data/data/coco/coco_adj.pkl',
'ALPHA': 0.8,
'BACKBONE': 'resnet150',
'BATCH_SIZE': 16,
'CLASS_CHANNELS': 256,
'CPROB': array([[1.00000000e+00, 8.26410144e-01, 7.04392284e-01, ...,
4.03311258e-01, 4.45312500e-01, 5.40000000e-01],
[4.18382255e-02, 1.00000000e+00, 1.02719033e-01, ...,
1.12582781e-02, 0.00000000e+00, 5.71428571e-03],
[1.34192234e-01, 3.86532575e-01, 1.00000000e+00, ...,
3.84105960e-02, 7.81250000e-03, 8.57142857e-03],
...,
[1.34812060e-02, 7.43331876e-03, 6.73948408e-03, ...,
1.00000000e+00, 2.34375000e-02, 8.57142857e-03],
[1.26178775e-03, 0.00000000e+00, 1.16198001e-04, ...,
1.98675497e-03, 1.00000000e+00, 2.57142857e-02],
[8.36764511e-03, 1.74901618e-03, 6.97188008e-04, ...,
3.97350993e-03, 1.40625000e-01, 1.00000000e+00]]),
'DATA': 'data/data/coco',
'DATA_TYPE': 'coco',
'DEEPMAR_LOSS': ,
'DEVICE_IDS': [0, 1, 2, 3, 4, 5, 6, 7],
'EPOCH': 100,
'EPOCH_STEP': 30,
'EVALUATE': False,
'EXP_NAME': 'se_clsgat',
'GROUPS': 12,
'GROUP_CHANNELS': 512,
'IMAGE_SIZE': 448,
'INP_NAME': 'data/data/coco/coco_glove_word2vec.pkl',
'IS_SLURM': False,
'LOSS_TYPE': 'DeepMarLoss',
'LR': 0.01,
'LRP': 0.01,
'LR_SCHEDULER': None,
'LR_SCHEDULER_PARAMS': None,
'MAX_EPOCH': 100,
'MODEL': 'se_clsgat',
'MOMENTUM': 0.9,
'NCLASSES': 80,
'NCLASSES_PER_GROUP': [1, 8, 5, 10, 5, 10, 7, 10, 6, 6, 5, 7],
'PRINT_FREQ': 10,
'RESUME': 'checkpoints/coco/se_clsgat/checkpoint.pth.tar',
'SAVE_MODEL_PATH': 'checkpoints/coco/se_clsgat',
'START_EPOCH': 0,
'WEIGHT_DECAY': 1e-05,
'WEIGHT_FILE': 'data/coco/coco_rate.pkl',
'WORKERS': 4}
Compose(
Resize(size=(512, 512), interpolation=PIL.Image.BILINEAR)
MultiScaleCrop
RandomHorizontalFlip(p=0.5)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
Compose(
Warp (size=448, interpolation=2)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
[dataset] Done!
[annotation] Done!
[json] Done!
[dataset] Done!
[annotation] Done!
[json] Done!
-------------------------------------------------------
Primary job terminated normally, but 1 process returned
a non-zero exit code.. Per user-direction, the job has been aborted.
-------------------------------------------------------
--------------------------------------------------------------------------
mpirun detected that one or more processes exited with non-zero status, thus causing
the job to be terminated. The first process to do so was:
Process name: [[58771,1],0]
Exit code: 1
--------------------------------------------------------------------------
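Use the breakpoint method from section 1.5 to find where the model is stored, then copy the file over.
Locally, this can be done as follows: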
(torch041) $ cd /Users/baidu/.cache/torch/checkpoints/
(torch041) baidudeMacBook-Pro:checkpoints baidu$ ls
resnet101-5d3b4d8f.pth se_resnet152-d17c99b7.pth
(torch041) baidudeMacBook-Pro:checkpoints baidu$ cp se_resnet152-d17c99b7.pth /Users/baidu/Desktop/code/ML_GAT-master/
This runs without errors.
On the server, the torch cache location is already known:
cd ~/.torch/models/
ls
resnet101-5d3b4d8f.pth resnet50-19c8e357.pth se_resnet152-d17c99b7.pth
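After copying checkpoint files between machines, the hash prefix in the file name can be used to check that the copy is intact, since the load_state_dict_from_url docstring says the prefix is the first eight or more hex digits of the file's SHA256. A sketch, assuming the file follows that naming convention (checkpoints from other sources may not); sha256_of and matches_name_hash are hypothetical helpers, and the regular expression is an assumed stand-in for torch's HASH_REGEX.

```python
import hashlib
import os
import re

# Assumed pattern for the "-<hash>." part of names like "resnet50-19c8e357.pth".
HASH_REGEX = re.compile(r"-([a-f0-9]*)\.")


def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Hex SHA256 of a file, read in chunks so large checkpoints fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def matches_name_hash(path: str) -> bool:
    """True if the file's SHA256 starts with the hash prefix embedded in its name."""
    match = HASH_REGEX.search(os.path.basename(path))
    if not match or not match.group(1):
        return False  # no usable hash in the name, nothing to compare against
    return sha256_of(path).startswith(match.group(1))
```

Running matches_name_hash on each of the three files listed above, on both machines, would flag a truncated or corrupted copy before the job is submitted.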
Then swap in the modified senet_origin, in which initialize_pretrained_model loads the weights from the local file:
# senet_origin.py: make sure torch is imported at the top of the file
def initialize_pretrained_model(model, num_classes, settings):
    assert num_classes == settings['num_classes'], \
        'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
    # model.load_state_dict(model_zoo.load_url(settings['url']))
    print('loading pretrained model from local...')
    # relative path: resolved against the directory the job is launched from
    model.load_state_dict(torch.load('./se_resnet152-d17c99b7.pth'))
    model.input_space = settings['input_space']
    model.input_size = settings['input_size']
    model.input_range = settings['input_range']
    model.mean = settings['mean']
    model.std = settings['std']