为了看懂基于MMDetection/MMDetection3D的目标检测模型代码,有必要先了解一些重要但平时不常用的python基础知识。
参考:Python中的__init__和super() - 知乎
python定义类的语句如下:
class ClassName:
也可在类名后加括号,括号内写上另一个已定义类的名称表示新类继承旧类的属性和方法:
class DerivedClassName(BaseClassName):
这里DerivedClassName称为子类(派生类),BaseClassName称为父类(基类)。
假设我们有一个类Fruit,其定义如下:
class Fruit:
def __init__(self, name="Apple"):
self.name = name
以Fruit作为父类,定义下面的子类:
class Apple(Fruit):
pass
创建该子类的实例:
f = Apple()
使用print函数输出其属性:
print(f.name) # 输出Apple
可以看到,即使在定义Apple类时没有类似name="Apple"的语句,也能获取其name属性的值。这就是因为Apple类继承了其父类Fruit的属性和初始化方法。
被继承的方法可以重写。例如,新建一个Apple_Init类(仍以Fruit为父类),创建其实例并输出属性:
class Apple_Init(Fruit):
def __init__(self, color):
self.color = color
fi = Apple_Init('red')
print(fi.name) # 该语句会报错说Fruit_Init没有name属性
print(fi.color) # 输出red
可以看到,由于定义Apple_Init时定义了初始化方法,覆盖了继承自其父类的初始化方法,因此访问属性name失败。
若要同时继承其父类的初始化方法并添加新的属性,则可以使用super()函数,其语法为
super(DerivedClassName, self).BaseClassMethodName(*ArgsOfBaseClassMethod)
表示继承父类BaseClassName的BaseClassMethodName方法(参数为ArgsOfBaseClassMethod)。
仍以Fruit为父类创建另一Apple_Super类,在继承Fruit类初始化方法的基础上添加新的属性:
class Apple_Super(Fruit):
def __init__(self, name, color):
self.color = color
super(Apple_Super, self).__init__(name)
fs = Apple_Super('Apple','red')
print(fs.name) # 输出Apple
print(fs.color) # 输出red
可见Apple_Super类的实例同时拥有父类属性name和子类属性color。
参考:mmdetecion 中类注册的实现(@x.register_module())
假设现在有函数func1,以函数为参数:
def func1(fn):
fn()
print(1)
假设现在有另一函数func2,功能是在屏幕上打印“2”。我们使用func1修饰func2:
@func1
def func2():
print(2)
然后执行func2,可观察到依次输出2和1。实际上,经过@的修饰,func2与下面的语句等价:
def func1(fn):
fn()
print(1)
def func2():
print(2)
func1(func2)
类似的,若func以类为输入和返回值,修饰类my_class:
def func(cls):
print(0)
return cls
@func
class my_class():
def __init__(self):
...
直接运行上述程序(注意该程序并没有对类进行实例化),会发现屏幕输出0,说明执行了func中的语句。
在MMDetection中,自定义模型时用到的@x.register_module()即在运行时调用注册函数(维护一个模块列表,在搭建(build)时从配置文件的type字段取出类别名,然后将剩余字段传入该类中进行初始化)。
若定义函数时允许函数接收可变数量的参数,可以使用*args或**kwargs。二者区别在于:
(1)*args会将非键值对参数打包为元组。例如
def func(*args):
print(args)
func("name", "color") # 输出('name', 'color')
(2)**kwargs会将键值对参数打包为字典。例如
def func(**kwargs):
print(kwargs)
func(name='apple',color='red') # 输出{'name': 'apple', 'color': 'red'}
注意,在同时使用*args和**kwargs时,*args必须放在**kwargs前面。此外,args和kwargs的名称可以随意修改。
类似地,在传参时,也可以使用*和**,会自动将传入的列表/元组以及字典分开。例如:
def func(name,color):
print(name,color)
func_inputs = ['apple','red']
func(*func_inputs) # 输出:apple red
附:其他一些pytorch中可能会遇见的不常用操作
(1)None作为tensor索引,如
a = torch.zeros(2,3) print(a[:,None,None].shape) # 输出torch.Size([2,1,1,3])
即None作为索引时相对于在相应的维度进行一次unsqueeze操作。
(2)@运算符,即矩阵乘法。
a = torch.zeros(2,3) b = torch.zeros(3,2) c = a @ b # 即a和b的矩阵乘积