MapReduce on OCaml

こないだ、おーくらから聞いたMapReduceのフリー実装は、Java実装でしかも遅いらしい。で、分散環境用ではなくてマルチCPU用MapReduceRuby実装して遊んでたらしい。マルチCPU用ってなんじゃって、要は通信の代わりにメモリ使って、プロセスの代わりにスレッド使ってってこと。って、メモリ使ったら簡単やん。
簡単なので、ボクも遊んでみる。Rubyじゃないよ。一応、研究室には2CPUマシンあるし。いずれ、マルチコア4CPUマシンとか買ってくれるだろう(とか言ってみる)。研究目的ですといって、9コアCPU搭載の「計算機」を買ってもらうのもいい。

let map_reduce
    (n_thread: int)
    (map: 'a -> ('b * 'c) list)
    (reduce: 'b * ('c list) -> 'd)
    (data: 'a list)
    : ('b * 'd) list =
  let lock = Mutex.create () in
  let d = ref data in

  let read () =
    Mutex.lock lock;
    let r = match !d with
      | [] -> None
      | x::xs -> d := xs; Some x in
    Mutex.unlock lock;
    r in

  let mapped = Hashtbl.create 1000 in
  let save h =
    Mutex.lock lock;
    Hashtbl.iter (fun key values ->
      if Hashtbl.mem mapped key then
        Hashtbl.replace mapped key ((Hashtbl.find mapped key) @ values)
      else
        Hashtbl.add mapped key values) h;
    Mutex.unlock lock in

  let threads = Array.init n_thread (fun _ ->
    Thread.create (fun () ->
      output_string stderr "new thread\n";
      flush stderr;

      let rec iter xs =
        match read () with
        | None -> xs
        | Some x ->
            iter (map x @ xs) in
      let ys = iter [] in
      
      let h = Hashtbl.create 1000 in
      List.iter (fun (key, value) ->
        if Hashtbl.mem h key then
          Hashtbl.replace h key (value::(Hashtbl.find h key))
        else
          Hashtbl.add h key [value]) ys;
      save h) ()) in

  Array.iter Thread.join threads;
  output_string stderr "map period finished\n";
  flush stderr;

  let keys = ref [] in
  Hashtbl.iter (fun k _ -> keys := (k::!keys)) mapped;
  let next () =
    Mutex.lock lock;
    let r =
      match !keys with
      | [] -> None
      | k::ks ->
          keys := ks;
          Some (k, Hashtbl.find mapped k) in
    Mutex.unlock lock;
    r in

  let result = ref [] in
  let collect v =
    Mutex.lock lock;
    result := v @ !result;
    Mutex.unlock lock in

  let threads = Array.init n_thread (fun _ ->
    Thread.create (fun () ->
      let rec iter xs =
        match next () with
        | None -> xs
        | Some (key, values) ->
            let v = reduce (key, values) in
            iter ((key, v)::xs) in
      let r = iter [] in
      collect r) ()) in

  Array.iter Thread.join threads;
  output_string stderr "reduce period finished\n";
  flush stderr;

  !result

lockする分がちょっともったいないなぁ。いろいろパフォーマンス的にもったいなさそうなところがちらほら。まぁ、いいや。mapとreduce処理が重い処理なら、複数CPUで高パフォーマンスがでると期待。OCamlでthread使うのはめんどいかなぁ、と思っていたが、思ってたよりもすっきり書けた。かもしれない。
使い方だが、入力は'aの集合(list)として、mapでは、'aからの集合に変換します。で、全入力で処理したら、keyごとにデータをまとめて、 listになります。最後に、reduce処理でを好きな出力データに変換して、になりますと。
例として、単語頻度でも測ってみましょうか。

  let lst =
    map_reduce
      2
      (fun w -> [w, 1])
      (fun (w, l) -> List.length l)
      ["a"; "b"; "a"; "c"; "b"] in
  List.iter (fun (w, n) ->
    Printf.printf "%s: %d\n" w n) lst

実行

$ ./a.out
new thread
new thread
map period finished
reduce period finished
a: 2
b: 2
c: 1

mapで<単語, 1>にすると、reduceフェーズで<単語, [1;1;1;1...]>を受け取るので、reduceでlistの長さを返します。すると、<単語, 頻度>のリストがかえって来るという寸法。データとして、単語集合じゃなくて、ドキュメント集合程度の粒度で扱った方が効率的ですね。パフォーマンスのチェックとかは、・・・めんどいなぁ。