一种方法为使用init和forward函数。my_model中传入torch.nn.module模块。下面为class的模版:
class my_model(torch.nn.module):
def __init__(self,a,b,c,...):
super (my_model,self)__init__():
self.my_model_a = a
self.b = b
self.c = c
...
def forward(self, input_1, imput_2):
input_x1 = imput_1
input_x2 = input_2
output = input_x1 + input_x2
return output
class my_model(torch.nn.module):
def __init__(self,a,b,c,...):
super (my_model,self)__init__():
self.my_model_a = a
self.b = b
self.c = c
...
def forward(self, input_1, imput_2):
input_x1 = imput_1
input_x2 = input_2
output = input_x1 + input_x2
return output
在使用上述定义时:
1.首先导入参数:
my_model_1 = my_model(a, b, c, …)
2.然后使用forward调用函数
y_out = my_model_1(input_1, input_2)
另一种实例化方式:
分为两个步骤:
1.init 初始化
2.call 调用函数(内部调用forward函数)
模板如下:
class my_model(torch.nn.module):
def __init__(self):
pass
...
def __call__(self):
pass
1.def function(a, b, c, d, e)
function(1, 2, 3, x = 4, y = 5)
2.当数量不确定时
function(1, 2, 3,…, x = 4, y = 5, …)
def function(*args, **kwargs)
print(args)
print(kwargs)
print(‘Hello’ + str(’’))
function(1, 2, 3, x = 4, y = 5)
使用上述定义时,传参按照call函数:
class A():
def __init__(self, origion_apple):
super().__init__()
print('我原来有苹果个数是:',origion_apple)
self.origion_apple= origion_apple
def __call__(self, added_apples):
res = self.forward(added_apples)
return res
def forward(self, input_apples):
print('forward 函数被调用了')
return input_apples + self.origion_apple
print('对象初始化')
a = A(5)
all_num_apples = a(2)
print("我现在苹果个数是:", all_num_apples )