trifle

技術メモ

Monadic Memoization in Python

関数の再帰呼び出しにおいて値を保持しておくことで計算が高速化されるメモ化という手法が知られています.
例えば以下のようにして80番目のフィボナッチ数が高速に求まります.

d = [0 for i in range(100)]

def fib(n):
    if n <= 1:
        return n
    if d[n] != 0:
        return d[n]

    d[n] = fib(n-1) + fib(n-2)
    return d[n]

print(fib(80)) // 23416728348467685

実は上記の d のようなグローバル変数をあらかじめ用意せず, メモの部分を隠蔽しながら計算を実行することができます. State モナドの考え方を使うのです.
学科の課題でメモ化の隠蔽を OCaml でやったのが面白かったので, Python のような動的型付け言語でも何とかできないかなーと考えてやってみました.
もちろん, Python で型アノテーションを使ってモナドを定式化するというようなことは流石にしません. あくまで考え方を使うというだけです.


ではやっていきましょう.
そもそも State モナド

state :: (s -> (a, s)) -> State s a

こういう形状をしています. ある状態をもらって, 値と新しい状態のタプルを返す関数です.
今回はこの「状態」にメモの部分(値の配列)が相当します. なので, 雰囲気としては

type int m = [int] -> (int, [int])

なるモナド int m があると考えてみてください(もちろんこんな式を直接 Python のコードにはかけません).


まず, モナド則として必ず用意されるべき2つの関数をつくります.

1つ目は return: int -> int m で, 副作用を起こさずモナドを返します. これは

def return_(x):
    return lambda t: (x, t)

こんな風に無名関数を返してやればよいです(return予約語なので return_ で).


2つ目は副作用を繋げてモナドを返す bind: int m -> (int -> int m) -> int m です.
これは OCaml のような関数型言語だととてもシンプルに書けますが, 普通のプログラミング言語だと割とつらいです. というのも, int m というモナドを一つ目の引数に取り, int -> int m という, 引数をとってモナドを返す関数を二つ目の引数に取り, そのモナド自体も関数なのですから.
実際こんな形状をしています.

def bind(x):
    def bind_sub(f):
        def bind_sub2(t):
            tup = x(t)
            return f(tup[0])(tup[1])
        return bind_sub2
    return bind_sub

この関数のお気持ちはどうなっているのかというと, メモ t を一つ目のモナドが受け取って, 値と状態を返し, それを二つ目の関数 f にぶっこんで得た結果を返すので, 全体としては, 1つ目の引数の副作用と2つ目の引数の副作用を繋げた形になっているということです. (自分でも説明になっていない...)
この bind は, HaskellOCaml のような関数の中置記法が使える言語では [1つ目の引数] >>= [2つ目の引数] というように書け, 副作用を繋げている感じがあってよいのですが, Python ではそういうことはしづらいので残念です(流石に演算子オーバーロードをやるのはちょっと...)

実装で lambda 式をあえて使わなかったのは, 最初 lambda 式を使って試してみたところ式の評価がうまくいかず実装に失敗してしまったのと, lambda 式の中で局所変数が定義できず書くのに不便だったのが理由です.
JavaScript のアロー関数だったらもっと融通が利いたのかなーという気がします.


次に, 今回のメモ化の隠蔽の中で核心にあたる memo: (int -> int m) -> int -> int m をつくります.

def memo(f):
    def memo_sub(n):
        def memo_sub2(t):
            if t[n] == 0:
                tup = f(n)(t)
                t[n] = tup[0]
            return t[n], t
        return memo_sub2
    return memo_sub

これは, 引数としてメモ t が与えられた時に, メモにデータがなかったら関数 f を評価してメモを行います.


また, メモの初期値を与えてメモ化を走らせ, 最終的な結果を得る run_memo: int m -> int をつくります.

def run_memo(m):
    tup = m([0 for i in range(100)])
    return tup[0]

以上をもって, 構成パーツの準備が完了しました.


そして, フィボナッチ数を計算する関数 fib: int -> int m がやっと登場します.

def fib(n):
    if n <= 1:
        return return_(n)
    else:
        def fib_sub(x):
            def fib_sub2(y):
                return return_(x + y)
            return bind(memo(fib)(n-1))(fib_sub2)
        return bind(memo(fib)(n-2))(fib_sub)

形がエグいですね. 中置記法 >>= と無名関数を仮に使ってみれば, 下の5行は

(memo fib (n-2)) >>= (fun x ->
(memo fib (n-1)) >>= (fun y ->
   return (x + y)))

という感じになります.

ところで, これはフィボナッチ数を計算する関数ですが, フィボナッチ数を返す関数ではありません!
例えるなら, この fib東北新幹線東海道新幹線山陽新幹線を繋いだだけであって, 実際に青森に列車を用意していないので, 博多に着いた時の乗客の様子は分からないのです. だからこそ, 先ほどの run_memo が必要なのですね.


全体像は以下のようになります.

def return_(x):
    return lambda t: (x, t)

def bind(x):
    def bind_sub(f):
        def bind_sub2(t):
            tup = x(t)
            return f(tup[0])(tup[1])
        return bind_sub2
    return bind_sub

def memo(f):
    def memo_sub(n):
        def memo_sub2(t):
            if t[n] == 0:
                tup = f(n)(t)
                t[n] = tup[0]
            return t[n], t
        return memo_sub2
    return memo_sub

def run_memo(m):
    tup = m([0 for i in range(100)])
    return tup[0]

def fib(n):
    if n <= 1:
        return return_(n)
    else:
        def fib_sub(x):
            def fib_sub2(y):
                return return_(x + y)
            return bind(memo(fib)(n-1))(fib_sub2)
        return bind(memo(fib)(n-2))(fib_sub)

https://gist.github.com/7ma7X/ad8b56f6ca6781bcf742dfc8b070bc1d

そして

print(run_memo(fib(80))) // 23416728348467685

このように値を求めることができました.



メモ化の隠蔽を Python でやってみて思ったのは, やはり関数型言語はすごいということです.
State モナドは, 実態としては一つの関数です. なので, Python のような言語でこのモナドを使うときは, ネストした関数を意識して実装する必要があり難しいです. 一方 OCaml のような関数型言語では(学科の課題の解答になるのでここには書きませんが)非常に綺麗な形で書けます. 関数を型によって抽象化することによって, 関数を関数として意識しなくても実装ができるのがとても楽なのです.
というわけで, 関数型言語のありがたみを分かるためには, こんな風に, 関数型言語でない言語でモナドを作ってみるのがいいんじゃないかなーと思うのでありました.