bussorenre Laboratory

hoge piyo foo bar

Scala における末尾再帰

背景

S-99: Ninety-Nine Scala Problems 等の例題を説いていると、よく再帰による実装を見かけます。

確かに、 Scala関数型デザイン&プログラミング―Scalazコントリビューターによる関数型徹底ガイド | Paul Chiusano, Rúnar Bjarnason, 株式会社クイープ | 工学 | Kindleストア 等の書籍にも、可能な限り再帰的な考え方で実装しろと書かれています。

しかし、「再帰でばかり実装すると、スタック領域を食いつぶしてしまうんじゃないのか?」という不安があります。 基本的に、関数呼び出しの際は、現在実行している関数の情報(レジスタ情報や引数・戻り先のポインタ)を、メモリ上のスタック領域と呼ばれるところに押し込んでいくので、再帰はスタック階層が深くなり、かの有名なスタックオーバーフローエラーが出ることになります。

普通の再帰

階乗(n!)を実装します。 階乗とは、例えば n = 5 の時、 5! = 5 x 4 x 3 x 2 x 1 = 120 となります。

scala で実装すると以下のようになります。

def fact(n: Int): BigInt =
  n match {
    case 0 => 1
    case _ => n * fact(n - 1)
  }

さて、n の値が小さいうちは普通に計算してくれますが、10000 とかぶっこむと StackOverflowError が出ます。

java.lang.StackOverflowError
  at scala.math.BigInt$.apply(BigInt.scala:38)
  at scala.math.BigInt$.int2bigInt(BigInt.scala:96)
  at .fact(<console>:14)
  at .fact(<console>:14)
  at .fact(<console>:14)
  at .fact(<console>:14)
  at .fact(<console>:14)
  at .fact(<console>:14)
  at .fact(<console>:14)
# 以下略

延々とfact関数を呼び出しており、エラーログ的にも優しくありません。

末尾再帰

Scala では、関数の最後の処理として自分自身を呼び出す再帰関数(これを末尾再帰と言う)を検知すると、パラメーターを新しい値に更新した後、再帰呼び出しを関数の冒頭にジャンプするコードに書き換えるしくみがあるみたいです。 末尾再帰を検知すると、内部的にはwhile文に変換している。と捉えても大きな違いはなさそう。

実際にfact関数を末尾再帰にしてみる。

末尾再帰を妨げているのは、n * fact(n - 1) の部分で、この処理を別の関数として置き換え、その関数を再帰的に呼び出すように修正します。

def fact(n: Int): BigInt = {
  def innerFact(n: Int, f: BigInt): BigInt = 
    n match {
      case 0 => f
      case _ => innerFact(n - 1, n * f)
    }
  innerFact(n, 1)
}

fact の中に、内部関数としてinnerFact を定義しました。実装を見てもらえるとわかるように、関数の末尾は innerFact を呼び出すだけになっている。これにより、Scala の末尾再帰検出機構が働き、fact(10000) などもうまく実行してくれるようになります。

本当に書いた関数が末尾再帰になっているかを確認するには、 @tailrec アノテーションを利用します。

import scala.annotation.tailrec

@tailrec
def fact(n: Int): BigInt =
  n match {
    case 0 => 1
    case _ => n * fact(n - 1)
  }

<console>:19: error: could not optimize @tailrec annotated method fact: it contains a recursive call not in tail position
           case _ => n * fact(n - 1)

def fact(n: Int): BigInt = {
  @tailrec
  def innerFact(n: Int, f: BigInt): BigInt = 
    n match {
      case 0 => f
      case _ => innerFact(n - 1, n * f)
    }
  innerFact(n, 1)
}

// 何もエラーが発生しない

参考にしました