mmlab中Registry类

Registry的功能

mmlab中的registry类主要用来对model中的backbone、neck或者dataset、optimizer等进行一个构建,维护一个全局key-value对,当我们想更换某一部分的时候直接更改即可

例如我通过registry注册backbone类中有ResNet18、ResNet50、VGG19等等,假设我们需要实例化ResNet,只需要更改配置文件中的如下参数即可

1
2
3
4
5
6
7
8
9
backbone=dict(
type='ResNet', # 待实例化的类名
depth=50, # 后面的都是对于的类初始化参数
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),

其中type就是我们要更换的backbone类型,随后的是backbone对应的参数

Registry的简单实现

上面提到registry类目的是维护一个全局的key-value对,所以我们需要一个全局变量,并不断完善它

下面的例子均来源于mmlab知乎

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
_module_dict = dict()

# 定义装饰器函数
def register_module(name):
def _register(cls):
_module_dict[name] = cls
return cls
return _register

# 装饰器用法
@register_module('one_class')
class OneTest(object):
pass

@register_module('two_class')
class TwoTest(object):
pass

if __name__ == '__main__':
# 通过注册类名实现自动实例化功能
one_test = _module_dict['one_class']()
print(one_test)

'''
<__main__.OneTest object at 0x7f4eee6bdc70>
'''

对上面的例子进行解释,函数运行到@register_module('one_class')时,会对OneTest类进行注册,注册后全局字典_module_dict会多出一个参数,同样的,运行到TwoTest中时依然会对其进行注册,全部注册完后,_module_dict的内容如下所示,正好是两个key-value值

1
{'one_class': <class '__main__.OneTest'>, 'two_class': <class '__main__.TwoTest'>}

Registry类实现

registry的类实现就是mmcv中所使用的方法,方法非常简洁,registry类如下

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Registry:
def __init__(self, name):
self._name = name # 类名, 例如backbone, neck等
self._module_dict = dict() # 全局key-value对

def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class): # 判断是否为类
raise TypeError('module must be a class, ' f'but got {type(module_class)}')

if module_name is None: # 如果没有名, 默认用注册的类名
module_name = module_class.__name__
if not force and module_name in self._module_dict: # 是否覆盖重名类
raise KeyError(f'{module_name} is already registered ' f'in {self.name}')
self._module_dict[module_name] = module_class # 加入key-value对

# 装饰器函数
def register_module(self, name=None, force=False, module=None):
if module is not None: # 如果给定module,直接增加到字典中即可
self._register_module(module_class=module, module_name=name, force=force)
return module

def _register(cls): # 装饰器用法
self._register_module(module_class=cls, module_name=name, force=force)
return cls
return _register

registry类有两个属性,_name_module_dict,第一个是registry的名字,比如backbone,第二个保存key-value属性,比如resnet和其对应类。其余两个函数就是用来注册字典属性的了,和之前的简单实现类似,还有一些特殊参数,下面举例看一下其使用

第一个例子

这个例子表示register装饰器可以自动命名,比如我们不想让他用默认的Converter1类名,想给他命名为abc,这样字典就是{“abc”: Converter1()}了,此外force为True表示,如果命名产生了重复,则会覆盖之前的类

1
2
3
4
5
@CONVERTERS.register_module(name="abc", force=True)
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b

第二个例子

下面这个例子表示不适用装饰器的方法,而是通过函数对其注册

1
CONVERTERS.register_module(module=Converter1())

Register的调用

我们维护了这么多register类,那么是如何调用的呢,mmlab给出了答案,只需要调用build_from_cfg函数即可,build_from_cfg(cfg, register, default_args=None)有三个参数

  • cfg: 与注册类相关的信息
  • register: 选择哪个类型的注册器
  • default_args: 是否要对cfg进行补充

下面我们看一个例子就很容易理解了

1
2
3
4
5
6
7
8
@CONVERTERS.register_module()
class Converter1(object):
def __init__(self, a, b):
self.a = a
self.b = b

converter_cfg = dict(type='Converter1', a=1, b=2)
converter = build_from_cfg(converter_cfg, CONVERTERS)

如上所示,我们注册了Converter1这个key,注册的类别为CONVERTERS,build_from_cfg第一个参数为converter_cfg,表示我们想要Converter1这个key,他的参数为a和b,第二个参数为CONVERTERS表示我们要在CONVERTERS这个注册类中寻找Converter1这个key,找到之后就自动用ab参数初始化了,是不是很简单?

下面我们看一下build_from_cfg内部是如何实现的

代码如下,实现非常简单,首先如果default_args有值,则将其与cfg合并,随后取出type这个参数找到相应的key,如果有这个key则返回类,否则如果实例化了就直接返回,如果都没有就报错,最后找到这个类后,用其配置参数也就是我们上面说的ab来初始化他并返回~就本例而言,我们会直接返回一个用a=1,b=1初始化的Converter1类,随后使用即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()

if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)

obj_type = args.pop('type') # 注册 str 类名
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type) # 从dict中根据key获取value
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type): # 如果已经实例化了,那就直接返回
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')

return obj_cls(**args) # 根据给定cfg参数初始化对应类