HCSC (CVPR 2022) training error

Training on a small dataset (30k samples) runs without errors.

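Today I switched to a 1M-sample dataset and the problem showed up immediately. At first glance it looked like the GPU ran out of memory, although the traceback below actually reports a ValueError from torch.distributions rather than a CUDA out-of-memory error.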

This marks my 100th training attempt, and it still hasn't succeeded.

Traceback (most recent call last):
  File "pretrain-0718.py", line 626, in <module>
    main()
  File "pretrain-0718.py", line 258, in main
    cluster_result)
  File "pretrain-0718.py", line 302, in train
    images, cluster_result=cluster_result, index=index)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/distributed.py", line 799, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 235, in forward
    proto_logits, proto_labels, proto_selected, temp_protos = self.get_protos(q, index, cluster_result)
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 357, in get_protos
    neg_mask = self.sample_neg_protos(im2cluster, cluster2cluster, pos_proto_id, prot_logits, n, cluster_result) # [N, N_neg]
  File "/home/bujin1/hcsc/HCSC-master/hcsc/hcsc.py", line 404, in sample_neg_protos
    neg_sampler = torch.distributions.bernoulli.Bernoulli(sampling_prob.clamp(0.0001, 0.999))
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/distributions/bernoulli.py", line 48, in __init__
    super(Bernoulli, self).__init__(batch_shape, validate_args=validate_args)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/distributions/distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter probs has invalid values
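One likely explanation (an assumption, not verified against the HCSC code) is that `sampling_prob` at line 404 of hcsc.py contains NaN values. `clamp(0.0001, 0.999)` bounds finite values but leaves NaN unchanged, and Bernoulli's argument validation then rejects the tensor with this same kind of ValueError. The minimal sketch below reproduces that behavior in plain PyTorch with made-up tensor values, not real HCSC data. `find_bad_entries` is a hypothetical debugging helper, not part of HCSC or PyTorch.

```python
import torch


def find_bad_entries(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical debugging helper: flat indices of NaN or Inf entries in t."""
    return torch.nonzero(~torch.isfinite(t), as_tuple=False).flatten()


# Made-up probabilities standing in for sampling_prob; the second entry is NaN.
sampling_prob = torch.tensor([0.5, float("nan"), 2.0])

# clamp() bounds finite values (2.0 -> 0.999), but NaN passes through unchanged.
clamped = sampling_prob.clamp(0.0001, 0.999)
bad = find_bad_entries(clamped)
print("non-finite indices after clamp:", bad.tolist())

# With validation on, Bernoulli requires probs in [0, 1]; NaN fails that check.
raised = False
try:
    torch.distributions.bernoulli.Bernoulli(clamped, validate_args=True)
except ValueError as err:
    raised = True
    print("ValueError:", err)

# Replacing NaN first lets the distribution build. Illustration only: this hides
# the symptom and does not fix whatever produced the NaN upstream.
safe = torch.nan_to_num(sampling_prob, nan=0.0001).clamp(0.0001, 0.999)
sampler = torch.distributions.bernoulli.Bernoulli(safe, validate_args=True)
mask = sampler.sample()
print("sampled mask:", mask.tolist())
```

In the real script, calling something like `find_bad_entries(sampling_prob)` just before the Bernoulli line would confirm or rule out this hypothesis. If it reports indices, the real fix is to trace back to where the NaN first appears rather than masking it with `nan_to_num`.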
