読者です 読者をやめる 読者になる 読者になる

技術memo

関数型ゴースト

再帰関数のメモ化(Memoization for Recursive Functions)とY Combinator

Y Combinatorのことを調べていて、気がついたらこんなことに。

目次

  • 関数のメモ化(Memoization)のこと
  • Y Combinatorのこと(不動点コンビネータ)
  • 再帰関数のメモ化(Memoization for Recursive Functions)
  • 参考記事

関数のメモ化(Memoization)のこと

メモ化とは

メモ化 - Wikipedia

  • メモ化(英: Memoization)とは、プログラムの高速化のための最適化技法の一種であり、サブルーチン呼び出しの結果を後で再利用するために保持し、そのサブルーチン(関数)の呼び出し毎の再計算を防ぐ手法である。

  • メモ化された関数は、以前の呼び出しの際の結果をそのときの引数と共に記憶しておき、後で同じ引数で呼び出されたとき、計算せずにその格納されている結果を返す。メモ化可能な関数は参照透過性を備えたものに限られる。すなわち、メモ化されたことで副作用が生じない場合に限られる。

なるほどですね。

やってみる

ここで事前に用意しておきました、F# の「関数をメモ化する関数」を見てみましょう。

// memoize function
// val memoize : ('a -> 'b) -> ('a -> 'b) when 'a : equality
let memoize f =
    let cache = new System.Collections.Generic.Dictionary<_,_>() in
    (fun x-> cache.TryGetValue(x) |> function
        | true,r -> r
        | _ -> let r = f x in cache.Add(x,r); r)

これくらいの長さだと、つい一行で書いてしまいたくなりますが、blogなので適度に改行します。 この関数は先ほどのWikiPediaからの説明の通り、「以前の呼び出しの際の結果をそのときの引数と共に記憶しておき、後で同じ引数で呼び出されたとき、計算せずにその格納されている結果を返す」ようになっています。 クロージャの概念に馴染みが無いと、「このcacheって奴はただのローカル変数だから消えてしまうのではないか?」と思うのですが、ここでは「memoize関数に関数fを渡して適用したときに返される関数オブジェクトの中に保持されている」という程度に解釈しておきましょう。詳しくはクロージャ - Wikipediaなど。

実際に使ってみます。

let plus_one x =
    printfn "called!";
    x + 1

// memoize plus_one
let plus_one_memo = memoize plus_one

// test
do printfn "*** test plus_one ***"
do plus_one 2 |> printfn "result:%d"
do plus_one 2 |> printfn "result:%d"
do printfn "*** test plus_one_memo ***"
do plus_one_memo 2 |> printfn "result:%d"
do plus_one_memo 2 |> printfn "result:%d"

(* results....
*** test plus_one ***
called!
result:3
called!
result:3
*** test plus_one_memo ***
called!
result:3
result:3
*)

何をするにもまずは単純な例からということで、plus_one関数はただ1増やした数を計算するだけの関数です。printfnで画面に出力しているのは、確認用です。「メモ化可能な関数は参照透過性を備えたものに限られる」と矛盾しますが、ここでは無視します。

結果を見ると、plus_one関数では呼ぶ度にcalled!と出力されていますが、plus_one_memo関数では初めの1回だけ出力されています。これは、1回目に「1」を渡されて計算を行い、2回目で「2」が渡されたときは「後で同じ引数で呼び出されたとき、計算せずにその格納されている結果を返す」ようにした結果です。

Y Combinatorのこと(不動点コンビネータ)

何それこわい

まずは解説から

不動点コンビネータ - Wikipedia

  • 与えられた関数の不動点(のひとつ)を求める高階関数である。

  • ここで関数fの不動点とは、f(x) = xを満たすようなxのことをいう。

ちょっとうろ覚えなのですが、二分法 - Wikipediaを思い出しました。関数の適用結果をまた同じ関数に適用していって、最終的に値が(ほとんど)変わらなくなったらそれを計算結果とする。たぶん似たような話だと思います。

それで何

私も理解があやふやなので、参考記事を持ち出すことにします。

フドーテンが倒せない: いげ太のブログ

  • f(x) = x なる x が不動点となるのであった。f(x) の具体的な式がわかれば不動点がわかるのは前述したとおりだが、わからない場合でもどうにかならないか、と考えてみる。そこで、この問題を抽象的に捉え、f(x) = x の条件を満たす x を算出する関数 g(f) を仮定してみるのだ。そうすると、f(x) = x なる xg(f) であるので、xg(f) を代入してもこの条件は成立するのであり、よって、f(g(f)) = g(f) が導かれる。

  • なんと、この f(g(f)) = g(f)不動点演算子と呼ばれる、再帰を表す関数なのである。

  • プログラム上で不動点の理論を使うというのは、なにを意味するのか。答えを先に言うと、それは再帰である。不動点を使うことは再帰を使うことに等しい。

何がなんだか……頭が爆発しそうです。

ともかくやってみる

まずは普通のフィボナッチ数から

// compute the n-th fibonacci number
let rec fib_normal x = if x <= 2 then 1 else fib_normal (x-1) + fib_normal (x-2)
let rec fib_normal' x = printfn "called!(%d)" x; if x <= 2 then 1 else fib_normal' (x-1) + fib_normal' (x-2)// debug version
// test
do printfn "*** test fib_normal' ***"
do fib_normal' 6 |> printfn "result:%d"
do printfn "*** test fib_normal ***"
do [1..10] |> List.map fib_normal |> List.iter (printfn "%d")

(* results...
*** test fib_normal' ***
called!(6)
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
called!(3)
called!(2)
called!(1)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
result:8
*** test fib_normal ***
1
1
2
3
5
8
13
21
34
55
*)

デバッグ出力を付けたパターンでひとつと、付けないパターンで1~10まで全てを出力してみました。 あまりやると長くなるので省略しますが、下の方は馴染み深いフィボナッチ数列のようになっています。

そしてY Combinator版。

// Y Combinator
// val fix : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b
let rec fix f x = f (fix f) x

// create (computing the n-th fibonacci number) function
let fib_maker' f x = printfn "called!(%d)" x; if x <= 2 then 1 else f (x-1) + f (x-2)

// computing the n-th fibonacci number
let fib' = fix fib_maker'

// test
do printfn "*** test fib' ***"
do fib' 6 |> printfn "result:%d"

(* results...
*** test fib' ***
called!(6)
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
called!(3)
called!(2)
called!(1)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
result:8
*)

なるほど、何やらよくわからなりませんが、同じ結果にはなっているようです。 そして、数学はさほどわからなくても関数はわかる私たちプログラマーとしては、fib_maker'関数の特徴として「再帰呼び出しの代わりに、引数で渡された関数を呼び出している」ことくらいは理解できます。なるほど、再帰呼び出しの部分だけを関数の外側に持ってきた、と考えるとわかりがよさそうです。

再帰関数のメモ化(Memoization for Recursive Functions)

やっと本題です。今回やりたかったのは「fib_normal関数を、先ほどの汎用メモ化関数memoizeでメモ化できないか」ということなのです。

何も考えずにやってみる

// memoized fibonacci function (memoization do not works)
let rec fib_m1 x =
    let f = memoize (fun x-> printfn "called!(%d)" x; fib_m1 x) in
    if x <= 2 then 1 else f (x-1) + f(x-2)

// test
do printfn "*** test fib_m2 ***"
do fib_m1 6 |> printfn "result:%d"

(* results...
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
called!(3)
called!(2)
called!(1)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
result:8
*)

あれれ、上手くいかない。

Y Combinatorバージョンは……

// memoized Y Combinator (memoization do not works)
let rec fix_m1 f x =
    let f' = memoize (fix_m1 f) in
    f (f') x

// test
let fib_m2 = fix_m1 fib_maker'
do printfn "*** test fib_m2 ***"
do fib_m2 6 |> printfn "result:%d"

(* results...
*** test fib_m2 ***
called!(6)
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
called!(3)
called!(2)
called!(1)
called!(4)
called!(3)
called!(2)
called!(1)
called!(2)
result:8
*)

こちらもダメ。何ということでしょう。

そして相互再帰

関数を百辺見直してみて気づくところによれば、どうやらmemoizeで関数で生成したキャッシュオブジェクトが、再帰呼び出しの為に引き回せていないように見えます。 そこで改良版です。

// memoized fibonacci function (it works)
let fib_m_ok =
    let rec f x =
        if x <= 2 then 1 else f_memo (x-1) + f_memo(x-2)
    and f_memo = memoize (fun x-> printfn "called!(%d)" x; f x) in
    f_memo
// test
do printfn "*** test fib_m_ok ***"
do fib_m_ok 6 |> printfn "result:%d"

(* results...
*** test fib_m_ok ***
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
result:8

*)

上手くいきました。ここまでくれば、「どんな再帰関数にも使えるメモ化関数」がY Combinatorを使って作れそうですね?

// memoized Y Combinator (it works)
// val fix_memo : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b when 'a : equality
let fix_memo f =
    let rec fix x = f m_fix x
    and m_fix = memoize fix in
    m_fix

// test
let fib_memo' = fix_memo fib_maker'
do printfn "*** test fib_memo ***"
do fib_memo' 6 |> printfn "result:%d"

(* results...
*** test fib_memo' ***
called!(6)
called!(5)
called!(4)
called!(3)
called!(2)
called!(1)
result:8
*)

やりました。great!

ところで、再帰関数といえば末尾再帰じゃないですか? そんなこともご存知なあなたにオススメなのがこちらメモ〜化したりスマス。 Memoization and Tail Recursive Function - Bug Catharsisです。私は末尾再帰にはまだまだ不慣れで挫折してしまいました。

最後に使いまわしの利きそうな関数だけ置いておきます。

// memoize function
// notice: if you want an async version, use System.Collections.Concurrent.ConcurrentDictionary.
// val memoize : ('a -> 'b) -> ('a -> 'b) when 'a : equality
let memoize f =
    let cache = new System.Collections.Generic.Dictionary<_,_>() in
    (fun x-> cache.TryGetValue(x) |> function
        | true,r -> r
        | _,_ -> let r = f x in cache.Add(x,r); r;)

// Y Combinator
// val fix : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b
let rec fix f x = f (fix f) x

// memoized Y Combinator
// val fix_memo : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b when 'a : equality
let fix_memo f =
    let rec fix x = f m_fix x
    and m_fix = memoize fix in
    m_fix

参考記事

追記

  • [2014/6/17] fix_memo, fib_m_ok関数を調整して、内部の再帰呼び出しだけでなく外部からの呼び出しに対してもきちんとメモ化されるように修正しました。