Chainerをバックエンドにブラックボックス変分ベイズを実装する日記[2]
内容
- 現時点で思う理想のAPIの雰囲気を書いておく
- 雑にクラス設計をする
考えているAPI
つくるものとして近いものはEdwardやPyro。
PyroのAPIは結構わかりやすいと感じたのでベースとして真似したい。
EdwardのAPIをもはや覚えていない。
PyroのAPIをざっくりいうと「生成モデル(model)」と「変分事後分布(guide)」の生成過程を関数として定義して推論用オブジェクトに渡す感じになっている。
modelやguideは手続き的に生成過程を書いていけばいいのでかなり簡単。
PyroのAPIがなんとなくわかるドキュメント: SVI Part I: An Introduction to Stochastic Variational Inference in Pyro — Pyro Tutorials 0.3.1 documentation
だが、真似したくない部分があってそれがPoutineによるエフェクトハンドリング。
Poutine自体がよくできていることは確かだが、初見何が起こっているのかわからない。
一見独立して定義して互いに内部の実装を知り得ないmodelとguideがグローバルな状態の管理によって紐づいている(正確でなければすみません)。
エフェクトハンドラを使わずに生成モデルと変分事後分布(と条件付け)を紐づけるために、変数名をサンプリング時ではなく生成モデルの戻り値として定義するイメージをしている。
def model(): x = Normal(0, 1) # Distributionの扱いは雑に書いています y = Normal(x, 0.5) return {"x": x, "y": y}
雑なクラス設計
上記のAPIになるように愚直にChainerのDistributionを使って実装しようとすると、推論時に条件付けや変分事後分布によるサンプルを上書きするような介入ができない問題がある。
PyroではPoutineがやってくれる部分なのでエフェクトハンドラをやめれば当然である。
modelを呼び出した時点では計算グラフだけが構築されていて実際のサンプリングは行われていないような状態にする必要がありそう。
PyroもそうだがChainerのDistributionは「分布」から「サンプル」が得られるAPIになっており、確率変数を作ろうとした際にサンプリングが行われてしまう。
そこで、「分布」と「サンプル」の間に「確率変数」クラスを設けてみてはどうかと考えた。
クラス図とは言えない雑な図だが、分布(Distribution)の__call__()からは確率変数(StochasticVariable)が得られ、確率変数(StochasticVariable)のsample()をコールすることでVariable型のサンプルとして初めて値を持つ形を考えた。
modelを呼び出したときに得られるのはStochasticVariableを使った計算グラフで、推論オブジェクトがサンプリングしたり条件付けしたりするイメージ。
うまくいくかはわからないがこのまま実装に入る。