WMSE

class Weighted_mse_mae(nn.Module):
    """Threshold-balanced weighted sum of MSE and MAE.

    Each element gets a weight chosen by how many thresholds its ``target``
    value reaches:
    - ``balancing_weights[0]`` below the first threshold;
    - ``balancing_weights[k]`` once ``target >= thresholds[k - 1]``.

    This assumes ascending thresholds. The weighted squared and absolute
    errors are summed over dims 2 and 3 (the last two axes for an
    ``S*B*H*W`` input) and then averaged over the remaining dims. The two
    terms are mixed by ``mse_weight`` / ``mae_weight`` and scaled by
    ``NORMAL_LOSS_GLOBAL_SCALE``.

    Args:
        mse_weight: Multiplier on the weighted-MSE term.
        mae_weight: Multiplier on the weighted-MAE term.
        NORMAL_LOSS_GLOBAL_SCALE: Global factor applied to the final loss.
        balancing_weights: Per-band weights. Must have exactly
            ``len(thresholds) + 1`` entries.
        thresholds: Ascending thresholds in the same units as ``target``.
            When ``None``, the dBZ levels in ``DEFAULT_DBZ_THRESHOLDS`` are
            converted with the project helper ``dBZ_to_pixel``. That helper
            is defined elsewhere and presumably maps dBZ to the pixel scale
            of ``target``; verify against its definition.

    Raises:
        ValueError: If ``balancing_weights`` does not have exactly one more
            entry than ``thresholds``.
    """

    # Reflectivity levels (dBZ) used when no explicit thresholds are given.
    DEFAULT_DBZ_THRESHOLDS = (10, 20, 30, 40, 50)

    def __init__(self, mse_weight=1.0, mae_weight=1.0, NORMAL_LOSS_GLOBAL_SCALE=0.00005,
                 *, balancing_weights=(1, 1, 5, 10, 30, 32), thresholds=None):
        super().__init__()
        self.NORMAL_LOSS_GLOBAL_SCALE = NORMAL_LOSS_GLOBAL_SCALE
        self.mse_weight = mse_weight
        self.mae_weight = mae_weight

        if thresholds is None:
            # Convert once at construction instead of on every forward() call.
            thresholds = [dBZ_to_pixel(ele) for ele in np.array(self.DEFAULT_DBZ_THRESHOLDS)]
        self.thresholds = tuple(thresholds)
        self.balancing_weights = tuple(balancing_weights)
        if len(self.balancing_weights) != len(self.thresholds) + 1:
            raise ValueError(
                f"balancing_weights needs len(thresholds) + 1 = {len(self.thresholds) + 1} "
                f"entries, got {len(self.balancing_weights)}"
            )

    def forward(self, input, target):
        """Return the scalar weighted MSE+MAE loss of ``input`` against ``target``.

        ``input`` and ``target`` must have the same shape, with at least
        4 dims. The original comment gives the layout as ``S*B*H*W``;
        confirm this against callers.
        """
        bw = self.balancing_weights
        weights = torch.ones_like(input) * bw[0]
        for i, threshold in enumerate(self.thresholds):
            # Cumulative banding: reaching threshold i lifts the weight from
            # band i to band i + 1. The mask depends only on target, so no
            # gradient flows through the weights.
            weights = weights + (bw[i + 1] - bw[i]) * (target >= threshold).float()

        diff = input - target
        mse = torch.sum(weights * diff ** 2, (2, 3))
        mae = torch.sum(weights * torch.abs(diff), (2, 3))
        return self.NORMAL_LOSS_GLOBAL_SCALE * (
            self.mse_weight * torch.mean(mse) + self.mae_weight * torch.mean(mae)
        )

你可能感兴趣的:(实现某些功能的代码code,深度学习)