第一次遇到这个报错时一头雾水，排查后发现原因很简单：子类没有正确调用父类（nn.Module）的初始化函数。
情况一：子类的 __init__ 中没有调用父类的初始化函数
import torch
import torch.nn as nn
class Convolution(nn.Module):
    """Deliberately broken example: a single 3x3 convolution to 64 channels.

    BUG (intentional -- this is the mistake being demonstrated): __init__
    never calls the parent nn.Module.__init__(), so assigning the Conv2d
    submodule to self.conv fails; PyTorch raises
    "AttributeError: cannot assign module before Module.__init__() call".
    """

    def __init__(self, in_channels: int):
        # Missing here: super(Convolution, self).__init__()
        self.conv = nn.Conv2d(kernel_size=3,in_channels=in_channels,out_channels=64,stride=1,padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Never reached in this example: construction already failed above.
        x = self.conv(x)
        return x
if __name__ == '__main__':
    # Dummy NCHW input: batch of 1, 3 channels, 224x224, filled with ones.
    inputs = torch.ones(1, 3, 224, 224)
    # Building the model is where the missing-super() error surfaces.
    net = Convolution(in_channels=3)
    outputs = net(inputs)
    print(outputs)
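报错（在执行 self.conv = nn.Conv2d(...) 这一行时抛出）：
AttributeError: cannot assign module before Module.__init__() call
原因：nn.Module 必须先在自身的 __init__() 中完成内部初始化，之后才能通过给 self 的属性赋值来注册子模块。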
正确形式：在给 self 赋值任何子模块之前，先调用父类的初始化函数
import torch
import torch.nn as nn
class Convolution(nn.Module):
    """A single 3x3 convolution mapping ``in_channels`` to 64 channels.

    With stride 1 and padding 1 the output keeps the input's height and width.
    """

    def __init__(self, in_channels):
        # nn.Module.__init__ has to run before any submodule is assigned to
        # self; otherwise nn.Module cannot register self.conv.
        super(Convolution, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Return the convolution of ``x`` (shape N x 64 x H x W)."""
        return self.conv(x)
if __name__ == '__main__':
    model = Convolution(in_channels=3)
    # Feed a ones-filled NCHW tensor (1 image, 3 channels, 224x224).
    result = model(torch.ones(1, 3, 224, 224))
    print(result)
情况二：调用父类初始化函数时没有传入当前的 self 对象
import torch
import torch.nn as nn
class Convolution(nn.Module):
    """Deliberately broken example: a single 3x3 convolution to 64 channels.

    BUG (intentional -- this is the mistake being demonstrated): the parent
    initializer is invoked as super(Convolution).__init__(), without self.
    One-argument super() returns an *unbound* super object, so calling
    __init__() on it does not run nn.Module.__init__() for this instance, and
    the self.conv assignment below fails just as if super() had never been
    called. The working form is super(Convolution, self).__init__().
    """

    def __init__(self, in_channels: int):
        # Wrong: self is not passed, so nn.Module.__init__ never runs on self.
        super(Convolution).__init__()
        self.conv = nn.Conv2d(kernel_size=3,in_channels=in_channels,out_channels=64,stride=1,padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Never reached in this example: construction already failed above.
        x = self.conv(x)
        return x
if __name__ == '__main__':
    batch = torch.ones(1, 3, 224, 224)  # N=1, C=3, H=W=224
    conv_model = Convolution(in_channels=3)
    print(conv_model(batch))
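报错形式与上面一致：只传一个参数的 super(Convolution) 返回的是未绑定的 super 对象，在它上面调用 __init__() 并不会对当前实例执行 nn.Module.__init__()，所以父类依然没有被初始化。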
正确形式：super 需要同时传入当前类和 self（Python 3 中也可以直接写 super().__init__()）
import torch
import torch.nn as nn
class Convolution(nn.Module):
    """Same-padding 3x3 convolution producing 64 feature maps.

    Args:
        in_channels: number of channels of the input tensor.
    """

    def __init__(self, in_channels):
        # Two-argument super() binds the lookup to this instance, so
        # nn.Module.__init__ really runs before self.conv is assigned.
        # (In Python 3, super().__init__() is the equivalent short form.)
        super(Convolution, self).__init__()
        conv = nn.Conv2d(in_channels, 64, 3, stride=1, padding=1)
        self.conv = conv

    def forward(self, x):
        """Apply the convolution; stride 1 + padding 1 keep H and W unchanged."""
        out = self.conv(x)
        return out
if __name__ == '__main__':
    dummy = torch.ones(1, 3, 224, 224)  # one 3-channel 224x224 image of ones
    net = Convolution(in_channels=3)
    y = net(dummy)
    print(y)