たれぱんのびぼーろく

わたしの備忘録、生物学とプログラミングが多いかも

PyTorchのnn.Moduleを読み解く

レイヤーをattributeとして設定する必要がある理由

__setattr__でフック掛けて処理をしているから
フック内ではattribute valueの種類に基づいて内部登録がなされる.
module.parameters()ではparamsのみではなくmodulesへも再帰的にアクセスしてparamsを拾ってきてくれる。__setattr__フックによるmodule登録がこれを可能にしてくれている.
moduleは他にも.to(device)を提供しており、配下のParametersをGPUへ転送してくれる.
このparams取得も同様に子モジュール再帰で実現されている

torch.nn.modules.module — PyTorch master documentation

    def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]
        params = self.__dict__.get('_parameters')
        # 以下、つらつらとsetされたvalueをチェック
        # “Parameter”
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        # 既存paramの更新
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        # Paramじゃない新規attribute
        else:
            modules = self.__dict__.get('_modules')
            ## Module in Moduleの場合
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            ## 既存moduleの更新
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            ## Buffer扱い
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)