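Training on the small dataset (30k samples) runs without errors.
Today I switched to the 1M-sample dataset and immediately ran into a problem. My first guess is that GPU memory blew up.
Commemorating my 100th training attempt: still no success.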
Traceback (most recent call last):
  File "pretrain-0718.py", line 626, in <module>
    main()
  File "pretrain-0718.py", line 258, in main
    cluster_result)
  File "pretrain-0718.py", line 302, in train
    images, cluster_result=cluster_result, index=index)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 235, in forward
    proto_logits, proto_labels, proto_selected, temp_protos = self.get_protos(q, index, cluster_result)
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 357, in get_protos
    neg_mask = self.sample_neg_protos(im2cluster, cluster2cluster, pos_proto_id, prot_logits, n, cluster_result) # [N, N_neg]
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 404, in sample_neg_protos
    neg_sampler = torch.distributions.bernoulli.Bernoulli(sampling_prob.clamp(0.0001, 0.999))
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/distributions/bernoulli.py", line 48, in __init__
    super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/distributions/distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter probs has invalid values
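
Note that the traceback ends in a ValueError from the Bernoulli constructor, not in a CUDA out-of-memory error. Since sampling_prob is already clamped to [0.0001, 0.999] before it reaches Bernoulli, values that are merely too large or too small should have been fixed by the clamp. What can survive a clamp is NaN, so NaN in sampling_prob looks like a plausible suspect. The traceback alone does not show where the NaN would come from. Below is a minimal sketch for checking this on a toy tensor. The helper names find_invalid_probs and safe_bernoulli_probs are hypothetical and not part of HCSC, and the nan_to_num workaround only hides the symptom, so it is meant as a debugging aid rather than a fix.

```python
import torch


def find_invalid_probs(probs: torch.Tensor) -> dict:
    """Count the kinds of values that a Bernoulli `probs` parameter rejects."""
    return {
        "nan": int(torch.isnan(probs).sum()),
        "inf": int(torch.isinf(probs).sum()),
        "out_of_range": int(((probs < 0) | (probs > 1)).sum()),
    }


def safe_bernoulli_probs(probs: torch.Tensor, lo: float = 1e-4, hi: float = 0.999) -> torch.Tensor:
    # Hypothetical workaround: replace NaN with `lo` first, then clamp
    # with the same bounds used in the traceback above.
    return torch.nan_to_num(probs, nan=lo).clamp(lo, hi)


# Toy stand-in for sampling_prob: one valid value, one NaN, two out-of-range values.
raw = torch.tensor([0.5, float("nan"), 2.0, -1.0])

# Same clamp as in sample_neg_protos: it repairs 2.0 and -1.0 but leaves the NaN.
clamped = raw.clamp(0.0001, 0.999)
report = find_invalid_probs(clamped)

# Bernoulli validation rejects the clamped tensor because of the remaining NaN.
try:
    torch.distributions.Bernoulli(clamped, validate_args=True)
    raised = False
except ValueError:
    raised = True

# After replacing NaN, the distribution can be constructed.
fixed = safe_bernoulli_probs(raw)
sampler = torch.distributions.Bernoulli(fixed, validate_args=True)
```

In the real training loop, the same check could be run on sampling_prob right before line 404 of hcsc.py, together with the batch index, to see which inputs produce the invalid values and then trace them back to their source.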