Scalaコンパイラの末尾再帰除去は自分自身を呼び出すメソッドのみに限定されます
今回は、どんな再帰呼び出しでもスタックを消費しないようにする方法を紹介します
val Zero = BigInt(0)
val One = BigInt(1)
lazy val factorial: BigInt => BigInt = {
case Zero | One => One
case n => n * factorial(n - 1)
}
scala> factorial(10000)
java.lang.StackOverflowError
at scala.math.BigInt$.maxCached(BigInt.scala:22)
at scala.math.BigInt$.apply(BigInt.scala:39)
at scala.math.BigInt$.int2bigInt(BigInt.scala:102)
at $anonfun$factorial$1.apply(<console>:17)
at $anonfun$factorial$1.apply(<console>:15)
at $anonfun$factorial$1.apply(<console>:17)
at $anonfun$factorial$1.apply(<console>:15)
at $anonfun$factorial$1.apply(<console>:17)
.
.
.
def foldl[A, B](as: List[A], b: B, f: (B, A) => B): B =
as match {
case Nil => b
case x :: xs => foldl(xs, f(b, x), f)
}
これは、varとwhileを使ったコードに機械的に変換出来る
def foldl[A, B](as: List[A], b: B, f: (B, A) => B): B = {
var z = b
var az = as
while (true) {
az match {
case Nil => return z
case x :: xs => {
z = f(z, x)
az = xs
}
}
}
z
}
末尾呼び出しならなんでも最適化されるのか?
lazy val even: Int => Boolean = {
case 0 => true
case n => odd(n - 1)
}
lazy val odd: Int => Boolean = {
case 0 => false
case n => even(n - 1)
}
scala> even(100000)
java.lang.StackOverflowError
at .even(<console>:13)
at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
at $anonfun$even$1.apply$mcZI$sp(<console>:15)
at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
at $anonfun$even$1.apply$mcZI$sp(<console>:15)
at $anonfun$odd$1.apply$mcZI$sp(<console>:20)
.
.
.
これらの問題を解決するデータ構造が存在します
sealed trait Trampoline[+A] {
final def runT: A =
this match {
case More(k) => k().runT
case Done(v) => v
}
}
case class Done[+A](a: A)
extends Trampoline[A]
case class More[+A](k: () => Trampoline[A])
extends Trampoline[A]
runTは再帰的に次のステップを呼び出し、結果を得る
lazy val even: Int => Trampoline[Boolean] = {
case 0 => Done(true)
case n => More(() => odd(n - 1))
}
lazy val odd: Int => Trampoline[Boolean] = {
case 0 => Done(false)
case n => More(() => even(n - 1))
}
scala> even(10000)
res0: Trampoline[Boolean] = More(<function0>)
scala> .runT
res1: Boolean = true
val Zero = BigInt(0)
val One = BigInt(1)
lazy val factorial: BigInt => Trampoline[BigInt] = {
case Zero | One => Done(One)
case n => More(() => Done(n * factorial(n - 1).runT))
}
scala> factorial(10000)
res0: Trampoline[BigInt] = More(<function0>)
scala> .runT
java.lang.StackOverflowError
at scala.math.BigInt.bigInteger(BigInt.scala:117)
at scala.math.BigInt.compare(BigInt.scala:182)
at scala.math.BigInt.equals(BigInt.scala:178)
at scala.math.BigInt.equals(BigInt.scala:126)
at scala.runtime.BoxesRunTime.equalsNumNum(Unknown Source)
at $anonfun$factorial$1.apply(Trampoline.scala:77)
at $anonfun$factorial$1.apply(Trampoline.scala:76)
at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
at Trampoline$class.runT(Trampoline.scala:52)
at More.runT(Trampoline.scala:59)
at $anonfun$factorial$1$$anonfun$apply$3.apply(Trampoline.scala:78)
.
.
.
関数内でrunTを呼び出してしまっている
モナドにすることで解決を試みます
def flatMap[B](f: A => Trampoline[B]) =
More(() => f(runT))
しかし、flatMap内でrunTを呼び出してしまうと先ほどと同じ結果になってしまう。
ここではTrampolineの構成子を追加します
case class FlatMap[A, +B](sub: Trampoline[A], k: A => Trampoline[B])
extends Trampoline[B]
flatMap, mapは次のように定義できる
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
this match {
case a FlatMap g =>
FlatMap(a, (x: Any) => g(x) flatMap f)
case x => FlatMap(x, f)
}
def map[B](f: A => B): Trampoline[B] =
flatMap(a => Done(f(a)))
構成子を追加したことでrunTに変更を加えなければならない
新しいrunTは次に示す、resumeメソッドによって定義される
final def resume: Either[() => Trampoline[A], A] =
this match {
case Done(a) => Right(a)
case More(k) => Left(k)
case a FlatMap f => a match {
case Done(a) => f(a).resume
case More(k) => Left(() => k() flatMap f)
case b FlatMap g => b.flatMap((x: Any) => g(x) flatMap f).resume
}
}
resumeメソッドはFlatMapを適用して結果か次のステップを返す
runTはresumeを利用して、以下の様に書くことが出来る
final def runT: A = resume match {
case Right(a) => a
case Left(k) => k().runT
}
resume, runTは末尾で自身を呼び出しているので、このメソッドはコンパイラによって最適化される
flatMap, mapが定義されたことによって最初の例は次のようになる
val Zero = BigInt(0)
val One = BigInt(1)
lazy val factorial: BigInt => Trampoline[BigInt] = {
case Zero | One => Done(One)
case n => More(() => factorial(n - 1)).map(n *)
}
もう一つ例を示す
lazy val fib: Int => Int = {
case n if n < 2 => n
case n => fib(n - 1) + fib(n - 2)
}
末尾で呼び出しているのは+めそっど
最適化はされない
Trampoline Monadとfor式を用いると自然な形で記述することが出来る
lazy val fib: Int => Trampoline[Int] = {
case n if n < 2 => Done(n)
case n => for {
x <- More(() => fib(n - 1))
y <- More(() => fib(n - 2))
} yield x + y
}
Trampoline
TrampolineはFunction0を利用しています
このFunction0の部分を抽象化すると次のような定義が可能です
sealed trait Free[S[+_], +A] {
private case class FlatMap[S[+_], A, +B](a: Free[S, A], f: A => Free[S, B]) extends Free[S, B]
}
case class Done[S[+_], +A](a: A) extends Free[S, A]
case class More[S[+_], +A](k: S[Free[S, A]]) extends Free[S, A]
Trampolineは以下のように定義出来る
type Trampoline[+A] = Free[Function0, A]
Function0を抽象化したことによって、resumeを変更しなければならない
実は、resumeではFunction0をFunctorとして利用することが出来た
trait Functor[F[_]] {
def map[A, B](m: F[A])(f: A => B): F[B]
}
implicit val f0Functor =
new Functor[Function0] {
def map[A, B](a: () => A)(f: A => B): () => B =
() => f(a())
}
resumeはFunctorを利用して次のように定義出来る
final def resume(implicit S: Functor[S]): Either[S[Free[S, A]], A] =
this match {
case Done(a) => Right(a)
case More(k) => Left(k)
case a FlatMap f => a match {
case Done(a) => f(a).resume
case More(k) => Left(S.map(k)(_ flatMap f))
case b FlatMap g => b.flatMap((x: Any) => g(x) flatMap f).resume
}
}
Freeで表現出来るデータ型はTrampolineだけではありません
Free[S, A]のSを枝、Aを葉と見做すことで木構造を表現出来ます
type Pair[+A] = (A, A)
type BinTree[+A] = Free[Pair, A]
この場合は枝はTuple2、葉はAで二分木を表現しています
Pairに対して2つの要素に関数を適用するようなFunctorを定義すれば、BinTreeは全ての葉を走査するようなMonadが定義されます
最後に、Freeを使ったプログラミングについて話します
ここでは例としてStateを構築します
まず最初に、枝となるデータ型を定義します
sealed trait StateF[S, +A]
case class Get[S, A](f: S => A)
extends State[S, A]
case class Put[S, A](s: S, a: A)
extends State[S, A]
ここで大切なことは関数のモデルをレコードで表現することで、実装は行ないません
次にFunctorを定義します
implicit def statefFun[S] =
new Functor[({ type F[A] = StateF[S, A] })#F] {
def map[A, B](m: StateF[S, A])(f: A => B): StateF[S, B] =
m match {
case Get(g) => Get((s: S) => f(g(s)))
case Put(s, a) => Put(s, f(a))
}
}
Functor則に気を付ければ自然とmapを定義することが可能です
StateFを使ったFreeStateの定義は以下のようになります
type FreeState[S, +A] =
Free[({ type F[B] = StateF[S, B] })#F, A]
FreeStateを返す関数として、次のようなものが定義出来ます
def pureState[S, A](a: A): FreeState[S, A] =
Done[({ type F[+B] = StateF[S, B] })#F, A](a)
def getState[S]: FreeState[S, S] =
More[({ type F[+B] = StateF[S, B] })#F, S](
Get(s => Done[({ type F[+B] = StateF[S, B] })#F, S](s)))
def setState[S](s: S): FreeState[S, Unit] =
More[({ type F[+B] = StateF[S, B] })#F, Unit](
Put(s, Done[({ type F[+B] = StateF[S, B] })#F, Unit](())))
そして、最初に定義した関数のモデルの実装は、以下のように定義されます
def evalS[S, A](s: S, t: FreeState[S, A]): A =
t.resume match {
case Left(Get(f)) => evalS(s, f(s))
case Left(Put(n, a)) => evalS(n, a)
case Right(a) => a
}
evalSは末尾で自身を呼び出しており、コンパイラによって最適化されます
このように、resumeを呼び出す関数で再帰的にモデルの評価を行なうことでス タックを消費しない関数を定義することが可能です
以上で終わります