Registry的功能
mmlab中的registry类主要用来对model中的backbone、neck或者dataset、optimizer等进行一个构建,维护一个全局key-value对,当我们想更换某一部分的时候直接更改即可
例如我通过registry注册backbone类中有ResNet18、ResNet50、VGG19等等,假设我们需要实例化ResNet,只需要更改配置文件中的如下参数即可
1 2 3 4 5 6 7 8 9 backbone=dict ( type ='ResNet' , depth=50 , num_stages=4 , out_indices=(0 , 1 , 2 , 3 ), frozen_stages=1 , norm_cfg=dict (type ='BN' , requires_grad=True ), norm_eval=True , style='pytorch' ),
其中type就是我们要更换的backbone类型,随后的是backbone对应的参数
Registry的简单实现
上面提到registry类目的是维护一个全局的key-value对,所以我们需要一个全局变量,并不断完善它
下面的例子均来源于mmlab知乎
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 _module_dict = dict () def register_module (name ): def _register (cls ): _module_dict[name] = cls return cls return _register @register_module('one_class' ) class OneTest (object ): pass @register_module('two_class' ) class TwoTest (object ): pass if __name__ == '__main__' : one_test = _module_dict['one_class' ]() print (one_test) ''' <__main__.OneTest object at 0x7f4eee6bdc70> '''
对上面的例子进行解释,函数运行到@register_module('one_class')时,会对OneTest类进行注册,注册后全局字典_module_dict会多出一个参数,同样的,运行到TwoTest中时依然会对其进行注册,全部注册完后,_module_dict的内容如下所示,正好是两个key-value值
1 {'one_class' : <class '__main__.OneTest' >, 'two_class' : <class '__main__.TwoTest' >}
Registry类实现
registry的类实现就是mmcv中所使用的方法,方法非常简洁,registry类如下
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 class Registry : def __init__ (self, name ): self._name = name self._module_dict = dict () def _register_module (self, module_class, module_name=None , force=False ): if not inspect.isclass(module_class): raise TypeError('module must be a class, ' f'but got {type (module_class)} ' ) if module_name is None : module_name = module_class.__name__ if not force and module_name in self._module_dict: raise KeyError(f'{module_name} is already registered ' f'in {self.name} ' ) self._module_dict[module_name] = module_class def register_module (self, name=None , force=False , module=None ): if module is not None : self._register_module(module_class=module, module_name=name, force=force) return module def _register (cls ): self._register_module(module_class=cls, module_name=name, force=force) return cls return _register
registry类有两个属性,_name和_module_dict,第一个是registry的名字,比如backbone,第二个保存key-value属性,比如resnet和其对应类。其余两个函数就是用来注册字典属性的了,和之前的简单实现类似,还有一些特殊参数,下面举例看一下其使用
第一个例子
这个例子表示register装饰器可以自动命名,比如我们不想让他用默认的Converter1类名,想给他命名为abc,这样字典就是{“abc”: Converter1()}了,此外force为True表示,如果命名产生了重复,则会覆盖之前的类
1 2 3 4 5 @CONVERTERS.register_module(name="abc" , force=True ) class Converter1 (object ): def __init__ (self, a, b ): self.a = a self.b = b
第二个例子
下面这个例子表示不适用装饰器的方法,而是通过函数对其注册
1 CONVERTERS.register_module(module=Converter1())
Register的调用
我们维护了这么多register类,那么是如何调用的呢,mmlab给出了答案,只需要调用build_from_cfg函数即可,build_from_cfg(cfg, register, default_args=None)有三个参数
cfg: 与注册类相关的信息
register: 选择哪个类型的注册器
default_args: 是否要对cfg进行补充
下面我们看一个例子就很容易理解了
1 2 3 4 5 6 7 8 @CONVERTERS.register_module() class Converter1 (object ): def __init__ (self, a, b ): self.a = a self.b = b converter_cfg = dict (type ='Converter1' , a=1 , b=2 ) converter = build_from_cfg(converter_cfg, CONVERTERS)
如上所示,我们注册了Converter1这个key,注册的类别为CONVERTERS,build_from_cfg第一个参数为converter_cfg,表示我们想要Converter1这个key,他的参数为a和b,第二个参数为CONVERTERS表示我们要在CONVERTERS这个注册类中寻找Converter1这个key,找到之后就自动用ab参数初始化了,是不是很简单?
下面我们看一下build_from_cfg内部是如何实现的
代码如下,实现非常简单,首先如果default_args有值,则将其与cfg合并,随后取出type这个参数找到相应的key,如果有这个key则返回类,否则如果实例化了就直接返回,如果都没有就报错,最后找到这个类后,用其配置参数也就是我们上面说的ab来初始化他并返回~就本例而言,我们会直接返回一个用a=1,b=1初始化的Converter1类,随后使用即可
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 def build_from_cfg (cfg, registry, default_args=None ): args = cfg.copy() if default_args is not None : for name, value in default_args.items(): args.setdefault(name, value) obj_type = args.pop('type' ) if isinstance (obj_type, str ): obj_cls = registry.get(obj_type) if obj_cls is None : raise KeyError(f'{obj_type} is not in the {registry.name} registry' ) elif inspect.isclass(obj_type): obj_cls = obj_type else : raise TypeError(f'type must be a str or valid type, but got {type (obj_type)} ' ) return obj_cls(**args)