たれぱんのびぼーろく

わたしの備忘録、生物学とプログラミングが多いかも

Pythonのデフォルト引数とクラスインスタンス - 再利用が引き起こすバグ

Pythonは言語仕様としてデフォルト引数を再利用する。これは mutable オブジェクトとの組み合わせで容易にバグを引き起こす。

言語仕様

関数定義時にデフォルト引数が評価されて確定し(=評価は一度きり)、これが関数呼び出し時に渡される。

The default values are evaluated at the point of function definition in the defining scope ... The default value is evaluated only once.

意図しないオブジェクト共有

デフォルト引数に mutable オブジェクト(例: list、クラスインスタンス)を使うと、言語仕様によるオブジェクト共有が起こる。
これは(一般的な)デフォルト引数のセマンティクスと一致しないため、意図しない共有として機能してしまう可能性が高い。

こんな感じ:

from dataclasses import dataclass, field

@dataclass
class Child:
    c1: str
    c2: int
    c3: int = 1

@dataclass
class Parent:
    a1: int
    child: Child = Child(
        c1 = "from parent",
        c2 = 0)

p1, p2 = Parent(1), Parent(2)
p1.child.c3 = 100

assert p1.child.c3 != p2.child.c3, f"Different instance should have different number, but `{p1.child.c3}` == `{p2.child.c3}`"
print("passed")
# AssertionError: Different instance should have different number, but `100` == `100`

解決策

クラスの場合、Python標準の field(default_factory=lambda: Cls(args)) で解消できる。
default_factory はデフォルト引数が必要な際に都度呼出しされるため、その場でインスタンスが生成され共有されない。
これにより意図しない共有を防げる。

こんな感じ:

from dataclasses import dataclass, field

@dataclass
class Child:
    c1: str
    c2: int
    c3: int = 1

@dataclass
class Parent:
    a1: int
    child: Child = field(default_factory=lambda: Child(
        c1 = "from parent",
        c2 = 0))

p1, p2 = Parent(1), Parent(2)
p1.child.c3 = 100

assert p1.child.c3 != p2.child.c3, f"Different instance should have different number, but `{p1.child.c3}` == `{p2.child.c3}`"
print("passed")
# passed

糖衣構文

上記の解決策は問題を解決できるが、記述が煩雑になる。

私は以下のような糖衣構文を使っている:

from copy import deepcopy

def default(instance):
    return field(default_factory=lambda: deepcopy(instance))

これを用いることで次のように簡略化される:

from dataclasses import dataclass, field

@dataclass
class Child:
    c1: str
    c2: int
    c3: int = 1

@dataclass
class Parent:
    a1: int
    child: Child = default(Child(
        c1 = "from parent",
        c2 = 0))

p1, p2 = Parent(1), Parent(2)
p1.child.c3 = 100

assert p1.child.c3 != p2.child.c3, f"Different instance should have different number, but `{p1.child.c3}` == `{p2.child.c3}`"
print("passed")
# passed