Tail-Call Optimization and Performance in Scala

This may be a case of "those who know, already know", and it is nothing new, but here is a short explanation, with simple working examples, of the following tweet.

For example, take reversing a List. If you measure it very roughly (and deliberately so) in the REPL as below, the tail-recursive version comes out looking clearly dozens of times faster than the foldLeft version, even though the two do almost exactly the same work.

https://github.com/scala/scala/blob/v2.13.12/src/library/scala/collection/LinearSeq.scala#L179-L187
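The reverse implementation linked above is, roughly speaking, an imperative loop that prepends each element onto an accumulator. A simplified sketch of that approach (not the verbatim library source):

```scala
// Sketch of the technique used by the linked LinearSeq.reverse:
// walk the list, prepending each element onto an accumulator.
def reverseSketch[A](list: List[A]): List[A] = {
  var acc: List[A] = Nil
  var rest = list
  while (rest.nonEmpty) {
    acc = rest.head :: acc
    rest = rest.tail
  }
  acc
}
```

This is the same shape as a tail-recursive loop; the compiler turns a @tailrec method into essentially this bytecode.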

Welcome to Scala 3.3.1 (11.0.20, Java OpenJDK 64-Bit Server VM).
Type in expressions for evaluation. Or try :help.

scala> def time[A](a: => A): Unit = { val start = System.currentTimeMillis; a ; println(System.currentTimeMillis - start) }
def time[A](a: => A): Unit

scala> lazy val list = (1 to 1000_000).toList
lazy val list: List[Int]

scala> time { list.foldLeft(List.empty[Int])((a, b) => b :: a) }
391

scala> time { list.foldLeft(List.empty[Int])((a, b) => b :: a) }
377

scala> time { list.foldLeft(List.empty[Int])((a, b) => b :: a) }
439

scala> time { list.foldLeft(List.empty[Int])((a, b) => b :: a) }
272

scala> time { list.foldLeft(List.empty[Int])((a, b) => b :: a) }
123

scala> @annotation.tailrec
     | def loop[A](src: List[A], acc: List[A]): List[A] =
     |   src match {
     |     case x :: xs =>
     |       loop(src = xs, acc = x :: acc)
     |     case _ =>
     |       acc
     |   }
     |
def loop[A](src: List[A], acc: List[A]): List[A]

scala> time { loop(list, Nil) }
25

scala> time { loop(list, Nil) }
6

scala> time { loop(list, Nil) }
22

scala> time { loop(list, Nil) }
4

scala> time { loop(list, Nil) }
4
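The one-shot REPL timings above are noisy because each call is measured exactly once, on a JVM that is still warming up. A slightly less naive, though still crude, timer would warm up first and then average over repetitions; a sketch:

```scala
// Crude improvement over the one-shot timer: run the thunk a number of
// times to let the JIT warm up, then average wall-clock time per call.
def timeAvg[A](reps: Int)(a: => A): Double = {
  (1 to reps).foreach(_ => a)              // warm-up runs (results discarded)
  val start = System.nanoTime
  (1 to reps).foreach(_ => a)              // measured runs
  (System.nanoTime - start) / 1e6 / reps   // average milliseconds per call
}
```

Even this is no substitute for a real harness, which is where JMH comes in below.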

However, regarding the part about "if the JVM's runtime optimization is smart enough, there may be effectively no difference": when I measured a bit more seriously with a tool called JMH, as shown below, some difference still remains, at least with this experimental code, but it is nothing like the dozens-of-times gap seen in the REPL. It is about 1.7x at most.

(ops/s is the number of operations executed per unit of time, so a higher number means better performance.)
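For intuition, throughput converts directly into average latency. Using the scores from the run below:

```scala
// 1000 / (ops/s) gives the average number of milliseconds per operation.
val foldLeftMsPerOp = 1000.0 / 187.467  // about 5.3 ms per reversal
val tailrecMsPerOp  = 1000.0 / 323.907  // about 3.1 ms per reversal
// 187.467 vs 323.907 ops/s is the same ratio either way: about 1.7x.
```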

https://github.com/sbt/sbt-jmh

By "a bit more seriously" I mean there are still plenty of rough edges even here; measuring truly correctly on the JVM with JMH is a deep subject, so if you are curious, try varying the setup in more ways yourself.

Even that 1.7x may reflect possible extra boxing and unboxing rather than a difference between foldLeft and tail recursion as such.
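One hypothetical way to isolate the boxing cost (not part of the benchmark above, and how much it matters depends on what the JIT manages to eliminate): compare the generic loop against a variant hard-coded to Int. The generic version keeps each element as an erased reference the whole time, while the Int version reads the element back as a primitive and lets :: rebox it, which is roughly what foldLeft's Function2 does per element.

```scala
import scala.annotation.tailrec

// Generic loop: elements stay as erased references; no un/boxing occurs.
@tailrec
def loopGeneric[A](src: List[A], acc: List[A]): List[A] = src match {
  case x :: xs => loopGeneric(xs, x :: acc)
  case _       => acc
}

// Int-specific loop: `x` is read back as a primitive Int and then reboxed
// by `::`, mimicking the per-element boxing of foldLeft's Function2.
@tailrec
def loopInt(src: List[Int], acc: List[Int]): List[Int] = src match {
  case x :: xs => loopInt(xs, x :: acc)
  case _       => acc
}
```

Adding both as JMH benchmarks would show how much of the 1.7x gap is attributable to boxing alone.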

project/plugins.sbt

addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.6")

build.sbt

scalaVersion := "3.3.1"

enablePlugins(JmhPlugin)

Main.scala

package example

import org.openjdk.jmh.annotations.Benchmark
import scala.annotation.tailrec

object Main {
  // Shared input: every benchmark reverses the same one-million-element list.
  val list: List[Int] = (1 to 1000_000).toList
}

class Main {

  @Benchmark
  def foldLeft(): List[Int] = {
    Main.list.foldLeft(List.empty[Int])((xs, x) => x :: xs)
  }

  @Benchmark
  def tailrecLoop(): List[Int] = {
    @tailrec
    def loop[A](src: List[A], acc: List[A]): List[A] =
      src match {
        case x :: xs =>
          loop(src = xs, acc = x :: acc)
        case _ =>
          acc
      }

    loop(src = Main.list, acc = Nil)
  }
}
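Before trusting any throughput numbers, it is also worth checking that the two strategies actually compute the same thing; a quick sanity check in plain Scala (no JMH needed):

```scala
import scala.annotation.tailrec

@tailrec
def loop[A](src: List[A], acc: List[A]): List[A] =
  src match {
    case x :: xs => loop(src = xs, acc = x :: acc)
    case _       => acc
  }

val list = (1 to 1000).toList
// Both strategies must agree with the standard library's own reverse.
assert(list.foldLeft(List.empty[Int])((a, b) => b :: a) == list.reverse)
assert(loop(list, Nil) == list.reverse)
```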

Results of running Jmh/run -i 5 -wi 5 -f1 -t1:

[info] running (fork) org.openjdk.jmh.Main -i 5 -wi 5 -f1 -t1
[info] # JMH version: 1.37
[info] # VM version: JDK 11.0.20, OpenJDK 64-Bit Server VM, 11.0.20+8-LTS
[info] # VM invoker: /Library/Java/JavaVirtualMachines/zulu-11.jdk/Contents/Home/bin/java
[info] # VM options: <none>
[info] # Blackhole mode: full + dont-inline hint (auto-detected, use -Djmh.blackhole.autoDetect=false to disable)
[info] # Warmup: 5 iterations, 10 s each
[info] # Measurement: 5 iterations, 10 s each
[info] # Timeout: 10 min per iteration
[info] # Threads: 1 thread, will synchronize iterations
[info] # Benchmark mode: Throughput, ops/time
[info] # Benchmark: example.Main.foldLeft
[info] # Run progress: 0.00% complete, ETA 00:03:20
[info] # Fork: 1 of 1
[info] # Warmup Iteration   1: 103.134 ops/s
[info] # Warmup Iteration   2: 211.510 ops/s
[info] # Warmup Iteration   3: 182.879 ops/s
[info] # Warmup Iteration   4: 206.715 ops/s
[info] # Warmup Iteration   5: 203.103 ops/s
[info] Iteration   1: 187.781 ops/s
[info] Iteration   2: 190.669 ops/s
[info] Iteration   3: 183.167 ops/s
[info] Iteration   4: 186.122 ops/s
[info] Iteration   5: 189.595 ops/s
[info] Result "example.Main.foldLeft":
[info]   187.467 ±(99.9%) 11.420 ops/s [Average]
[info]   (min, avg, max) = (183.167, 187.467, 190.669), stdev = 2.966
[info]   CI (99.9%): [176.047, 198.887] (assumes normal distribution)
[info] # JMH version: 1.37
[info] # VM version: JDK 11.0.20, OpenJDK 64-Bit Server VM, 11.0.20+8-LTS
[info] # VM invoker: /Library/Java/JavaVirtualMachines/zulu-11.jdk/Contents/Home/bin/java
[info] # VM options: <none>
[info] # Blackhole mode: full + dont-inline hint (auto-detected, use -Djmh.blackhole.autoDetect=false to disable)
[info] # Warmup: 5 iterations, 10 s each
[info] # Measurement: 5 iterations, 10 s each
[info] # Timeout: 10 min per iteration
[info] # Threads: 1 thread, will synchronize iterations
[info] # Benchmark mode: Throughput, ops/time
[info] # Benchmark: example.Main.tailrecLoop
[info] # Run progress: 50.00% complete, ETA 00:01:41
[info] # Fork: 1 of 1
[info] # Warmup Iteration   1: 300.062 ops/s
[info] # Warmup Iteration   2: 326.647 ops/s
[info] # Warmup Iteration   3: 324.159 ops/s
[info] # Warmup Iteration   4: 324.243 ops/s
[info] # Warmup Iteration   5: 323.696 ops/s
[info] Iteration   1: 324.353 ops/s
[info] Iteration   2: 324.258 ops/s
[info] Iteration   3: 322.606 ops/s
[info] Iteration   4: 324.562 ops/s
[info] Iteration   5: 323.758 ops/s
[info] Result "example.Main.tailrecLoop":
[info]   323.907 ±(99.9%) 3.023 ops/s [Average]
[info]   (min, avg, max) = (322.606, 323.907, 324.562), stdev = 0.785
[info]   CI (99.9%): [320.884, 326.931] (assumes normal distribution)
[info] # Run complete. Total time: 00:03:21
[info] REMEMBER: The numbers below are just data. To gain reusable insights, you need to follow up on
[info] why the numbers are the way they are. Use profilers (see -prof, -lprof), design factorial
[info] experiments, perform baseline and negative tests that provide experimental control, make sure
[info] the benchmarking environment is safe on JVM/OS/HW level, ask for reviews from the domain experts.
[info] Do not assume the numbers tell you what you want them to tell.
[info] Benchmark          Mode  Cnt    Score    Error  Units
[info] Main.foldLeft     thrpt    5  187.467 ± 11.420  ops/s
[info] Main.tailrecLoop  thrpt    5  323.907 ±  3.023  ops/s