前言
疫情在家办公,新Team这边习惯用MMLab开发网络,正好趁这段时间理解一下商汤大佬们的框架。我之前其实网络开发的比较少,主要是学习用的,而且开发网络基本是靠手写或者copy,用这种架构开发我是十分赞成的,上手快,不容易出错,而且在这个网络训练网络的时代,config作为深度网络的上位机确实是王道。Anyway, 作为学习者,还是要知道网络是怎么通过config搭建好的,才能将自己的网络迁移进来,否则灵活性太差了。这期总结一下mmsegmentation的搭建网络的方法。
框架分析
关于config的分析就不多说了,继承形式的config,我们这里主要关心网络是如何形成的。
网络的搭建当然要从tools/train.py开始,整个main函数前面大部分都是在解析configs的配置并存到cfg对象,直到在196行开始终于开始用build_segmentor这个函数来建立模型。
追溯这个函数,可以找到mmseg/model文件夹下的builder.py,model文件中显然存放的是模型的结构文件,包括主干网、neck、检测头等,builder.py应该算作model的一个整体对外的接口。
在builder.py文件中,我们发现首先用Registry实例化MODELS,并且感觉像是继承了MMCV_MODELS,这个基类我们线猜测是MMLAB的模型库。然后将MODELS又传递给SEGMENTORS。
在build_segmentor中,SEGMENTORS使用build方法建立了模型,到目前为之,模型算子或者模块都没有显示出来,那核心就是这个注册表Registry类作了什么操作了。
我们找到Registry类,说明里面表明Registry是为了将字符串和类进行map,那懂了,Registry确实是注册表的意思。注册表是为了做什么的呢?注册表本质上是存储设置信息的一种数据库,说明Registry其实本质目的就是把config的信息传递到网络的类中。
我们看到的Registry的使用方法是如下这种形式,可以看到’model’是传入的name, MMCV_MODELS是传入的build_func,描述中可以看到,build_func如果没有给出,但是parent参数给出了,build_func会隐式继承parent传入的参数。
MODELS = Registry('models', parent=MMCV_MODELS)
到目前为之,我们还有两个问题没有搞清楚,第一个问题,这个注册表类是怎么生成模型的;第二个问题,父类注册表MMCV_MODELS里面又说了些啥。我们先解决第一个问题,看一下注册表类里面的build方法:
def __init__(self, name, build_func=None, parent=None, scope=None):self._name = nameself._module_dict = dict()self._children = dict()self._scope = self.infer_scope() if scope is None else scope# self.build_func will be set with the following priority:# 1. build_func# 2. parent.build_func# 3. build_from_cfgif build_func is None:if parent is not None:self.build_func = parent.build_funcelse:self.build_func = build_from_cfgelse:self.build_func = build_funcif parent is not None:assert isinstance(parent, Registry)parent._add_children(self)self.parent = parentelse:self.parent = None
def build(self, *args, **kwargs):return self.build_func(*args, **kwargs, registry=self)
其他代码先不关注,只看build,我们发现build方法实际就是调用了build_func,而build_func实际就是你传入的父类注册表,现在两个问题又回到了一个问题,父类或者基类的这个注册表描述了什么。去找一下他的定义,我发现了这个父类已经到了/python3.6/site-packages/mmcv/cnn/builder.py 这个路径下了,说明我们已经接近他的核心部分了。进取以后,惊呆了,竟然是个环,没错,你没看错,又是Registry,只不过这次直接传入了build_func,而且将build_func传入的builid_model_froom_cfg同时定义好了。
MODELS = Registry('model', build_func=build_model_from_cfg)
我们主要要看一下,这个build_func在干些啥,非常清晰,这个函数直接输出的就是nn.modules,就是我们要的pytorch模型
def build_model_from_cfg(cfg, registry, default_args=None):"""Build a PyTorch model from config dict(s). Different from``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.Args:cfg (dict, list[dict]): The config of modules, is is either a configdict or a list of config dicts. If cfg is a list, athe built modules will be wrapped with ``nn.Sequential``.registry (:obj:`Registry`): A registry the module belongs to.default_args (dict, optional): Default arguments to build the module.Defaults to None.Returns:nn.Module: A built nn module."""if isinstance(cfg, list):modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]return Sequential(*modules)else:return build_from_cfg(cfg, registry, default_args)
回想前面的build方法(如下),传入的参数其实就是给build_model_from_cfg这个函数服务的,传入的主要是cfg,train_cfg和test_cfg,看起来应该是cfg参数是模型主参数,先做个大胆的推测,然后等待打脸(补充:回到train.py 你会发现,传入的是cfg.model,确实是模型主参数)~ 模型到底是砸建的呢?我们又可以看到,build_model_from_cfg函数里面出现了一个build_from_cfg,而且执行了一个for循环去遍历cfg,我们有理由相信这步就是为了形成模型的各个模块,查找一下build_from_cfg这个函数,这个又回到了Reigister那个类的文件中。
return SEGMENTORS.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
分析一下这个函数,忽略前面一堆if,首先将cfg传给了arg,果然习惯用arg的人也不少啊哈哈~然后将default_args传入到args中,然后从args中pop出‘type’这个key对应的value,例如‘EncodeDecoder’,再将这个value传给registry.get方法。
def build_from_cfg(cfg, registry, default_args=None):"""Build a module from config dict.Args:cfg (dict): Config dict. It should at least contain the key "type".registry (:obj:`Registry`): The registry to search the type from.default_args (dict, optional): Default initialization arguments.Returns:object: The constructed object."""if not isinstance(cfg, dict):raise TypeError(f'cfg must be a dict, but got {type(cfg)}')if 'type' not in cfg:if default_args is None or 'type' not in default_args:raise KeyError('`cfg` or `default_args` must contain the key "type", 'f'but got {cfg}\n{default_args}')if not isinstance(registry, Registry):raise TypeError('registry must be an mmcv.Registry object, 'f'but got {type(registry)}')if not (isinstance(default_args, dict) or default_args is None):raise TypeError('default_args must be a dict or None, 'f'but got {type(default_args)}')args = cfg.copy()if default_args is not None:for name, value in default_args.items():args.setdefault(name, value)obj_type = args.pop('type')if isinstance(obj_type, str):obj_cls = registry.get(obj_type)if obj_cls is None:raise KeyError(f'{obj_type} is not in the {registry.name} registry')elif inspect.isclass(obj_type):obj_cls = obj_typeelse:raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')try:return obj_cls(**args)except Exception as e:# Normal TypeError does not print class name.raise type(e)(f'{obj_cls.__name__}: {e}')
谈起get方法就稍微有点复杂了,get其实实现了这个pop出来的value是一个什么样的任务,其中又用到了类的嵌套,我对这个一直没有搞清楚。Anyway,我们这步实现了提取对应模型的class
def get(self, key):"""Get the registry record.Args:key (str): The class name in string format.Returns:class: The corresponding class."""scope, real_key = self.split_scope_key(key)if scope is None or scope == self._scope:# get from selfif real_key in self._module_dict:return self._module_dict[real_key]else:# get from self._childrenif scope in self._children:return self._children[scope].get(real_key)else:# goto rootparent = self.parentwhile parent.parent is not None:parent = parent.parentreturn parent.get(key)
最后在build_from_cfg中用try方法实例化的这个类,从而生成了模型。如果我们去看对应的类的话,我们还会发现每个类的上面还有对应的装饰器方法,该装饰器方法会在实例化模型的过程中,将模型记录在对应的注册表类中。