すたっくれす すから

うぃず

ふりー もなど

Sanshiro Yoshida (@halcat0x15a)

あぶすとらくと

Scalaコンパイラの末尾再帰除去は自分自身を呼び出すメソッドのみに限定されます

今回は、どんな再帰呼び出しでもスタックを消費しないようにする方法を紹介します

1. Introduction

StackOverflowError

みなさん経験ありますね?

val Zero = BigInt(0)
val One = BigInt(1)

lazy val factorial: BigInt => BigInt = {
  case Zero | One => One
  case n => n * factorial(n - 1)
}
scala> factorial(10000)
java.lang.StackOverflowError
        at scala.math.BigInt$.maxCached(BigInt.scala:22)
        at scala.math.BigInt$.apply(BigInt.scala:39)
        at scala.math.BigInt$.int2bigInt(BigInt.scala:102)
        at $anonfun$factorial$1.apply(<console>:17)
        at $anonfun$factorial$1.apply(<console>:15)
        at $anonfun$factorial$1.apply(<console>:17)
        at $anonfun$factorial$1.apply(<console>:15)
        at $anonfun$factorial$1.apply(<console>:17)
        .
        .
        .

2. Background: Tail-call elimination in Scala

末尾で自身を呼び出す関数

def foldl[A, B](as: List[A], b: B, f: (B, A) => B): B =
  as match {
    case Nil => b
    case x :: xs => foldl(xs, f(b, x), f)
  }

これは、varとwhileを使ったコードに機械的に変換出来る

コンパイルされたコードは以下と同等

def foldl[A, B](as: List[A], b: B, f: (B, A) => B): B = {
  var z = b
  var az = as
  while (true) {
    az match {
      case Nil => return z
      case x :: xs => {
        z = f(z, x)
        az = xs
      }
    }
  }
  z
}

末尾呼び出しならなんでも最適化されるのか?

最適化されない例

相互再帰

lazy val even: Int => Boolean = {
  case 0 => true
  case n => odd(n - 1)
}

lazy val odd: Int => Boolean = {
  case 0 => false
  case n => even(n - 1)
}
scala> even(100000)
java.lang.StackOverflowError
        at .even(<console>:13)
        at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
        at $anonfun$even$1.apply$mcZI$sp(<console>:15)
        at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
        at $anonfun$even$1.apply$mcZI$sp(<console>:15)
        at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
        .
        .
        .

これらの問題を解決するデータ構造が存在します

3. Tampolines: Trading stack for heap

Trampoline

sealed trait Trampoline[+A] {
  final def runT: A =
    this match {
       case More(k) => k().runT
       case Done(v) => v
    }
}

case class Done[+A](a: A)
  extends Trampoline[A]

case class More[+A](k: () => Trampoline[A])
  extends Trampoline[A]

runTは再帰的に次のステップを呼び出し、結果を得る

Trampolineを用いた相互再帰

lazy val even: Int => Trampoline[Boolean] = {
  case 0 => Done(true)
  case n => More(() => odd(n - 1))
}

lazy val odd: Int => Trampoline[Boolean] = {
  case 0 => Done(false)
  case n => More(() => even(n - 1))
}
scala> even(10000)
res0: Trampoline[Boolean] = More(<function0>)

scala> .runT
res1: Boolean = true

4. Making every call a tail cal

最初に挙げた例を解決出来るか?

val Zero = BigInt(0)
val One = BigInt(1)

lazy val factorial: BigInt => Trampoline[BigInt] = {
  case Zero | One => Done(One)
  case n => More(() => Done(n * factorial(n - 1).runT))
}
scala> factorial(10000)
res0: Trampoline[BigInt] = More(<function0>)

scala> .runT
java.lang.StackOverflowError
        at scala.math.BigInt.bigInteger(BigInt.scala:117)
        at scala.math.BigInt.compare(BigInt.scala:182)
        at scala.math.BigInt.equals(BigInt.scala:178)
        at scala.math.BigInt.equals(BigInt.scala:126)
        at scala.runtime.BoxesRunTime.equalsNumNum(Unknown Source)
        at $anonfun$factorial$1.apply(Trampoline.scala:77)
        at $anonfun$factorial$1.apply(Trampoline.scala:76)
        at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
        at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
        at Trampoline$class.runT(Trampoline.scala:52)
        at More.runT(Trampoline.scala:59)
        at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
        .
        .
        .

関数内でrunTを呼び出してしまっている

Trampoline Monad

4.1 A Trampoline monad?

モナドにすることで解決を試みます

>>=

単純に実装すると

def flatMap[B](f: A => Trampoline[B]) =
  More(() => f(runT))

しかし、flatMap内でrunTを呼び出してしまうと先ほどと同じ結果になってしまう。

4.2 Building the monad right in

ここではTrampolineの構成子を追加します

case class FlatMap[A, +B](sub: Trampoline[A], k: A => Trampoline[B])
  extends Trampoline[B]

flatMap, mapは次のように定義できる

def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  this match {
    case a FlatMap g =>
      FlatMap(a, (x: Any) => g(x) flatMap f)
    case x => FlatMap(x, f)
  }
def map[B](f: A => B): Trampoline[B] =
  flatMap(a => Done(f(a)))

構成子を追加したことでrunTに変更を加えなければならない

新しいrunTは次に示す、resumeメソッドによって定義される

final def resume: Either[() => Trampoline[A], A] =
  this match {
    case Done(a) => Right(a)
    case More(k) => Left(k)
    case a FlatMap f => a match {
      case Done(a) => f(a).resume
      case More(k) => Left(() => k() flatMap f)
      case b FlatMap g => b.flatMap((x: Any) => g(x) flatMap f).resume
    }
  }

resumeメソッドはFlatMapを適用して結果か次のステップを返す

runTはresumeを利用して、以下の様に書くことが出来る

final def runT: A = resume match {
  case Right(a) => a
  case Left(k) => k().runT
}

resume, runTは末尾で自身を呼び出しているので、このメソッドはコンパイラによって最適化される

4.4 Stackless Scala

flatMap, mapが定義されたことによって最初の例は次のようになる

val Zero = BigInt(0)
val One = BigInt(1)
lazy val factorial: BigInt => Trampoline[BigInt] = {
  case Zero | One => Done(One)
  case n => More(() => factorial(n - 1)).map(n *)
}

もう一つ例を示す

よくあるふぃぼなっち数の例

lazy val fib: Int => Int = {
  case n if n < 2 => n
  case n => fib(n - 1) + fib(n - 2)
}

末尾で呼び出しているのは+めそっど

最適化はされない

Trampoline Monadとfor式を用いると自然な形で記述することが出来る

lazy val fib: Int => Trampoline[Int] = {
  case n if n < 2 => Done(n)
  case n => for {
    x <- More(() => fib(n - 1))
    y <- More(() => fib(n - 2))
  } yield x + y
}

6. Free Monads: A Generalization of

Trampoline

TrampolineはFunction0を利用しています

このFunction0の部分を抽象化すると次のような定義が可能です

sealed trait Free[S[+_], +A] {
  private case class FlatMap[S[+_], A, +B](a: Free[S, A], f: A => Free[S, B]) extends Free[S, B]
}

case class Done[S[+_], +A](a: A) extends Free[S, A]

case class More[S[+_], +A](k: S[Free[S, A]]) extends Free[S, A]

Trampolineは以下のように定義出来る

type Trampoline[+A] = Free[Function0, A]

Function0を抽象化したことによって、resumeを変更しなければならない

実は、resumeではFunction0をFunctorとして利用することが出来た

Functor

trait Functor[F[_]] {
  def map[A, B](m: F[A])(f: A => B): F[B]
}

Function0Functor

implicit val f0Functor =
  new Functor[Function0] {
    def map[A, B](a: () => A)(f: A => B): () => B =
      () => f(a())
  }

6.1 Functions defined on all free monads

resumeはFunctorを利用して次のように定義出来る

final def resume(implicit S: Functor[S]): Either[S[Free[S, A]], A] =
  this match {
    case Done(a) => Right(a)
    case More(k) => Left(k)
    case a FlatMap f => a match {
      case Done(a) => f(a).resume
      case More(k) => Left(S.map(k)(_ flatMap f))
      case b FlatMap g => b.flatMap((x: Any) => g(x) flatMap f).resume
    }
  }

6.2 Common data types as free monads

Freeで表現出来るデータ型はTrampolineだけではありません

Free[S, A]のSを枝、Aを葉と見做すことで木構造を表現出来ます

type Pair[+A] = (A, A)

type BinTree[+A] = Free[Pair, A]

この場合は枝はTuple2、葉はAで二分木を表現しています

Pairに対して2つの要素に関数を適用するようなFunctorを定義すれば、BinTreeは全ての葉を走査するようなMonadが定義されます

6.3 A free State monad

最後に、Freeを使ったプログラミングについて話します

ここでは例としてStateを構築します

まず最初に、枝となるデータ型を定義します

sealed trait StateF[S, +A]

case class Get[S, A](f: S => A)
  extends State[S, A]

case class Put[S, A](s: S, a: A)
  extends State[S, A]

ここで大切なことは関数のモデルをレコードで表現することで、実装は行ないません

次にFunctorを定義します

implicit def statefFun[S] =
  new Functor[({ type F[A] = StateF[S, A] })#F] {
    def map[A, B](m: StateF[S, A])(f: A => B): StateF[S, B] =
      m match {
        case Get(g) => Get((s: S) => f(g(s)))
        case Put(s, a) => Put(s, f(a))
      }
  }

Functor則に気を付ければ自然とmapを定義することが可能です

StateFを使ったFreeStateの定義は以下のようになります

type FreeState[S, +A] =
  Free[({ type F[B] = StateF[S, B] })#F, A]

FreeStateを返す関数として、次のようなものが定義出来ます

def pureState[S, A](a: A): FreeState[S, A] =
  Done[({ type F[+B] = StateF[S, B] })#F, A](a)

def getState[S]: FreeState[S, S] =
  More[({ type F[+B] = StateF[S, B] })#F, S](
    Get(s => Done[({ type F[+B] = StateF[S, B] })#F, S](s)))

def setState[S](s: S): FreeState[S, Unit] =
  More[({ type F[+B] = StateF[S, B] })#F, Unit](
    Put(s, Done[({ type F[+B] = StateF[S, B] })#F, Unit](())))

そして、最初に定義した関数のモデルの実装は、以下のように定義されます

def evalS[S, A](s: S, t: FreeState[S, A]): A =
  t.resume match {
    case Left(Get(f)) => evalS(s, f(s))
    case Left(Put(n, a)) => evalS(n, a)
    case Right(a) => a
  }

evalSは末尾で自身を呼び出しており、コンパイラによって最適化されます

このように、resumeを呼び出す関数で再帰的にモデルの評価を行なうことでス タックを消費しない関数を定義することが可能です

以上で終わります