Scala 3のinlineによるcompile時間増大に対処する方法

概要や結論を先に書いておくと、適切にプロファイルを取って、適切にボトルネック箇所を見つけて対処しましょう。 という話になるのですが、大まかな概要というか、結果的に使った方法を書くと

  1. sbtのTask毎の時間を記録する機能を使う
  2. Scala 3 compilerのログを多めに出してphase毎の時間を出す
  3. Scala 3のcompiler pluginを書いて特に時間がかかっているファイルを見つけ出す
  4. 普通のinlineやsummonFromだけで書いていたものを、あえてQuotes使ったlow level macroに書き換えて、compile time debugする
  5. 遅い原因となったinstanceを半手動定義する

といった感じです。タイトルに、

とありますが、一部はそれ以外のパターンでも使える、Scalaやsbtにおける、compile時間、build時間の最適化手法です。

こういった結論というか方法を見つけるために、かなり色々試行錯誤したのですが、やはり正攻法(?)で、しっかり計測することが大事ですね。

結果的にinlineが一番ボトルネックだったので、そこにたどり着くまでの話をしていきますが、別にそこがボトルネックでなければ、他の手法を適用しただけでも十分なことも多々あるでしょう。 また、以前書いたCIでのキャッシュの効率化など、触れない話題も色々あります。

xuwei-k.hatenablog.com

一応、versionは、このblog書いている時点で最新の

  • sbt 1.6.2
  • Scala 3.1.3-RC1-bin-20220328-ed9267e-NIGHTLY

としておきます。 (Scala 3がNIGHTLYであるのは後述)

1: sbtのTask毎の時間を記録する機能

github.com

sbt起動時に引数渡すだけで、いい感じのグラフ見れるためのjsonファイル吐き出してくれます。chromeに読み込ませる必要がありますが。

細かい部分をカスタマイズしたくなった(時間短いやつは除くなど)、ので、以下のように改変して、CIで毎回結果を保存してます。

https://gist.github.com/xuwei-k/5b8f1f924ff5c10eb30dcf9ad2fe5c03

これを眺めたり、また、以前Scala 3の移行の発表をした際にも、かるく話を出しましたが、

https://xuwei-k.hatenablog.com/entry/2022/03/05/100217

sbtで出したTask毎(sbtのsub project毎のcompile時間)と、projectの構成、依存関係を睨めっこして、どこがボトルネックか?をまず探りましょう。

https://github.com/dwijnand/sbt-project-graph

全自動でボトルネック箇所出すのも、もっと頑張れば不可能ではないかもしれませんが、数十や百個ほどsbtのsub projectがあっても、 個人的には数十秒ほど睨めっこすれば、ボトルネック箇所はわかるので、ひとまずそれで十分だとは思います。

場合によっては、sbtのsub project間の依存関係を見直すだけで、他の最適化をやらなくても、結構な効果がある場合は多々あると思います。

2: Scala 3 compilerのログを多めに出してphase毎の時間を出す

Scala 2でも一応同様の方法は原理上可能だと思いますが、とりあえずScala 3前提で話を進めます。

それぞれのphaseの詳細に詳しくないし、そもそもcompilerの内部実装なので、phaseはversionによって消えたり増えたりしますが*1、あくまで上記に書いた現在のScala 3の話です。

Scala 2でも3でも scala -Xshow-phases とすると、phaseの一覧が表示されます

(出力は省略) (この前のwartremoverの発表でも多少話しましたね)

https://xuwei-k.github.io/slides/wartremover-3/#1

一般的に(?) typer が色々やるので大抵遅い印象がありますが、Scala 3のinlineでコンパイル時間が爆発すると、typerではなくinliningというphaseが爆発することがわかりました。

--verbose というのをscalacOptionsに追加すると、phase毎の実行時間含めた、色々なlogが出るようになります。

https://github.com/lampepfl/dotty/blob/e8356687e294d8a6274c1decf271fbcc43275e5a/compiler/src/dotty/tools/dotc/config/ScalaSettings.scala#L102

(実行時間だけならもっと適切なオプションがあるかもしれないが、調べてない)

(そもそもScala 3内部にある程度profiler的なclassが存在していそうだったが、中途半端な実装になっていたり、使えるのか謎だったので、結局使っていない。誰か試して使えたら、使い方の記事書いて・・・ https://github.com/lampepfl/dotty/blob/e8356687e294d8a6274c1decf271fbcc43275e5a/compiler/src/dotty/tools/dotc/profile/Profiler.scala#L218-L223 )

実例でいうと、とあるボトルネックだと思われるsbtのsub project部分だけ計測したところ、以下のようになってました。inliningがtyperの20倍以上かかっていました。

[info] [parser  in 203ms]
[info] [typer  in 13146ms]
[info] [inlinedPositions  in 1ms]
[info] [sbt-deps  in 851ms]
[info] [posttyper  in 228ms]
[info] [sbt-api  in 4932ms]
[info] [pickler  in 303ms]
[info] [inlining  in 276505ms]
[info] [postInlining  in 113ms]
[info] [staging  in 1ms]
[info] [pickleQuotes  in 227ms]
[info] [MegaPhase{firstTransform, checkReentrant, elimPackagePrefixes, cookComments, checkStatic, checkLoopingImplicits, betaReduce, inlineVals, expandSAMs}  in 287ms]
[info] [MegaPhase{elimRepeated, protectedAccessors, extmethods, uncacheGivenAliases, byNameClosures, hoistSuperArgs, specializeApplyMethods, refchecks, tryCatchPatterns, patternMatcher}  in 2018ms]
[info] [MegaPhase{elimOpaque, explicitOuter, explicitSelf, elimByName, stringInterpolatorOpt}  in 315ms]
[info] [MegaPhase{pruneErasedDefs, uninitializedDefs, inlinePatterns, vcInlineMethods, seqLiterals, intercepted, getters, specializeFunctions, liftTry, collectNullableFields, elimOuterSelect, resolveSuper, functionXXLForwarders, paramForwarding, genericTuples, letOverApply, arrayConstructors}  in 799ms]
[info] [erasure  in 1713ms]
[info] [MegaPhase{elimErasedValueType, pureStats, vcElideAllocations, arrayApply, elimPolyFunction, tailrec, completeJavaEnums, mixin, lazyVals, memoize, nonLocalReturns, capturedVars}  in 615ms]
[info] [constructors  in 335ms]
[info] [MegaPhase{lambdaLift, elimStaticThis, countOuterAccesses}  in 562ms]
[info] [MegaPhase{dropOuterAccessors, checkNoSuperThis, flatten, transformWildcards, moveStatic, expandPrivate, restoreScopes, selectStatic, Collect entry points, collectSuperCalls, repeatableAnnotations}  in 243ms]
[info] [genBCode  in 5227ms]

3: Scala 3のcompiler pluginを書いて特に時間がかかっているファイルを見つけ出す

わざわざcompiler pluginを書かずに見つけられるならば、その方が望ましいですが、これで見つけることが出来てしまって、他の方法知らないので、ひとまずこの方法を紹介します。

https://docs.scala-lang.org/scala3/reference/changed-features/compiler-plugins.html

Scala 2にも3にもcompiler pluginという仕組みがあり、任意のphaseの前や後に、処理を追加して挟むことが出来ます。 さらにScala 3では、処理の順番というかphaseそのものを完全に入れ替えられるより強力なresearch pluginsというものが追加されているそうです。 (Scala 2でもすごく無理やり頑張れば、実質似たようなことが出来たらしい?がよく知らない)

https://github.com/lampepfl/dotty/blob/e8356687e294d8a6274c1decf271fbcc43275e5a/compiler/src/dotty/tools/dotc/plugins/Plugin.scala#L53-L65

この前発表したwartremoverは、research pluginではなく、(今の実装は)typerの後にphaseを一つ挟むだけの、普通のcompiler pluginです。

ここで注意するのが、research pluginをbuildする側は関係ないかもしれませんが、research pluginを使う側が(?)NIGHTLYである必要がある、ということです。 3.1.1や3.1.2-RC3のようなversionでは動きません(警告すら出ずに無視されます) https://github.com/lampepfl/dotty/blob/ed9267e8e1af6472db2b9164d92668de5b61376f/compiler/src/dotty/tools/dotc/plugins/Plugins.scala#L129-L133

gist.github.com

作成方法としては、以下のようなことをしました。

  • build.sbtでcompilerの依存追加
  • src/main/resourcesにplugin.propertiesというファイルを置いて、pluginのclass名を記述
  • ResearchPluginを継承したclassを作り、name、initなどの必要なメソッドをoverrideする
  • 詳細はResearchPluginのscaladocに書いてあるが、plugin向けのコンパイルオプションと、既存のphase一覧が渡ってくるので、phaseを独自のものに置き換えるなどして、新しいphase一覧を返す
  • 今回はinliningだけ置き換えたいので、既存のInliningをそのまま継承した、独自のphaseを作る
  • 独自のphaseでは、今回は、runOn(全てのfileのinlining phaseを行うタイミングで呼ばれる)と、run(実質1ファイル毎に呼ばれる?)、をoverrideして、時間を記録、記録した時間を(時間がかかった順に)出力するだけのコードを書く(Inliningの実装詳細見るとわかるが、別にどちらかだけoverrideでもいける気はする)

https://github.com/lampepfl/dotty/blob/e8356687e294d8a6274c1decf271fbcc43275e5a/compiler/src/dotty/tools/dotc/transform/Inlining.scala#L26-L34

これによって、特にinliningでどのファイルが遅いのか?が判明したので、次にいきます。 (結果としては、明らかに特定のinlineやってるファイルだけ6分〜7分かかっていました!!!)

4: 普通のinlineやsummonFromだけで書いていたものを、あえてQuotes使ったlow level macroに書き換えて、compile time debugする

これまた辛いというか、マニアックな話で、本来もう少しカッコ良い方法でいけたらよかったのですが、結果的にこれでうまくいってしまったので、この方法の詳細を紹介します。

これ自体は、そもそも独自の再帰的なtype classのinstance導出のinlineを書いていないと使えない、というか、発生しない問題です。 ライブラリ側がやってしまっていたら、より面倒ですが、原理的に頑張ればいけるとは思います。

独自の再帰的なtype classのinstance導出のinline とは、具体的には例えば、circeのautoとsemiautoをイメージすると良いと思います。

(Scala 3ではsemiautoもautoと同様の挙動になってしまっている問題がありますが!!! https://github.com/circe/circe/pull/1923 )

さて、再帰的なtype classのinstance導出のinline、は、大抵以下のようなコードになると思います。

import scala.compiletime.erasedValue
import scala.compiletime.summonFrom
import scala.deriving.Mirror

trait MyTypeClass[A] {
  // 実装省略
}

object MyTypeClass {
  implicit inline def derive[A]: MyTypeClass[A] = summonFrom {
    case x: MyTypeClass[A] =>
      // すでにinstanceが存在すればそれを返すだけ
      x
    case _ =>
      // 存在しなければ作成
      create[A]
  }

  // 直接呼びたいこともあるので、あえてderiveとは分けて定義
  inline def create[A]: MyTypeClass[A] = summonFrom {
    case mirror: Mirror.ProductOf[A] =>
      val typeclasses = deriveRec[mirror.MirroredElemTypes]
      deriveProduct[A](typeclasses, mirror)
    case mirror: Mirror.SumOf[A] =>
      val typeclasses = deriveRec[mirror.MirroredElemTypes]
      deriveSum[A](typeclasses, mirror)
  }

  /** 展開されたコードが肥大化しないようにわざとこれはinlineにしていない */
  final def deriveProduct[A](typeclasses: Seq[MyTypeClass[?]], mirror: Mirror.ProductOf[A]): MyTypeClass[A] =
    ??? // 実装省略
    
  /** 展開されたコードが肥大化しないようにわざとこれはinlineにしていない */
  final def deriveSum[A](typeclasses: Seq[MyTypeClass[?]], mirror: Mirror.SumOf[A]): MyTypeClass[A] =
    ??? // 実装省略

  inline def deriveRec[T <: Tuple]: List[MyTypeClass[?]] =
    inline erasedValue[T] match {
      case _: EmptyTuple =>
        Nil
      case _: (t *: ts) =>
        derive[t] :: deriveRec[ts]
    }
}

summonFrom、erasedValue、Mirrorなどを使った実装は、綺麗に簡素に書ける一方で、compile timeに副作用を簡単に発生させることが不可能?なので、 このinline自体がcompile時間爆発の原因だった場合に、debugがやりずらいです(実は他の方法があったら教えてください)

さて、これらをlow levelな scala.quoted.Quotes 使ったmacroで書き換えると、大体、以下のようになるはずです。 *2

package example 

import scala.deriving.Mirror
import scala.quoted.Quotes
import scala.quoted.Expr
import scala.quoted.Type
import scala.collection.concurrent.TrieMap
import scala.concurrent.ExecutionContext
import scala.concurrent.Future
import java.util.concurrent.atomic.AtomicInteger

object MyTypeClass {

  private val derivedTypes = TrieMap.empty[String, AtomicInteger]
  private val sum = new AtomicInteger

  implicit inline def derive[A]: MyTypeClass[A] = ${ deriveImpl[A] }

  inline def create[A]: MyTypeClass[A] = ${ createImpl[A] }

  final def deriveImpl[A](using t: Type[A], q: Quotes): Expr[MyTypeClass[A]] = {
    import q.reflect.*

    Implicits.search(TypeRepr.of[MyTypeClass[A]]) match {
      case a: ImplicitSearchSuccess =>
        a.tree.asExprOf[MyTypeClass[A]]
      case _ =>
        createImpl[A]
    }
  }

  final def createImpl[A](using t: Type[A], q: Quotes): Expr[MyTypeClass[A]] = {
    import q.reflect.*

    scala.concurrent.Future {
      val typeName = TypeRepr.of[A].show
      derivedTypes.getOrElseUpdate(typeName, new AtomicInteger).getAndIncrement

      if (sum.getAndIncrement % 1000 == 0) { // この数は適当に調節
        val values = derivedTypes.toList
          .map { (k, v) => k -> v.get }
          .sortBy(-_._2)
          .take(20)
          .map { (name, count) =>
            s"$name $count"
          }
          .mkString(", ")

        println(values)
      }
    }(ExecutionContext.global)

    @annotation.tailrec
    def loop(tpe: Type[?], acc: List[Expr[MyTypeClass[?]]]): Expr[Seq[MyTypeClass[?]]] =
      tpe match {
        case '[y *: ys] =>
          loop(Type.of[ys], deriveImpl[y] :: acc)
        case '[EmptyTuple] =>
          Expr.ofSeq[MyTypeClass[?]](acc.reverse)
      }

    Implicits.search(TypeRepr.of[Mirror.ProductOf[A]]) match {
      case s: ImplicitSearchSuccess =>
        s.tree.tpe.asType match {
          case '[ { type MirroredElemTypes = x }] =>
            val typeclasses = loop(Type.of[x], Nil)
            s.tree.asExpr match {
              case '{ ($z: Mirror.ProductOf[A]) } =>
                '{ deriveProduct[A](${ typeclasses }, $z.fromProduct) }
            }
          case '[x] =>
            report.errorAndAbort("not found MirroredElemTypes", s.tree.pos)
        }
      case _ =>
        Implicits.search(TypeRepr.of[Mirror.SumOf[A]]) match {
          case s: ImplicitSearchSuccess =>
            s.tree.tpe.asType match {
              case '[ { type MirroredElemTypes = x }] =>
                val typeclasses = loop(Type.of[x], Nil)
                s.tree.asExpr match {
                  case '{ ($z: Mirror.SumOf[A]) } =>
                    '{ deriveSum[A](${ typeclasses }, $z) }
                }
              case '[x] =>
                report.errorAndAbort("not found MirroredElemTypes", s.tree.pos)
            }
        }
    }
  }

  /** 展開されたコードが肥大化しないようにわざとこれはinlineにしていない */
  final def deriveProduct[A](typeclasses: Seq[MyTypeClass[?]], mirror: Mirror.ProductOf[A]): MyTypeClass[A] =
    ??? // 実装省略
    
  /** 展開されたコードが肥大化しないようにわざとこれはinlineにしていない */
  final def deriveSum[A](typeclasses: Seq[MyTypeClass[?]], mirror: Mirror.SumOf[A]): MyTypeClass[A] =
    ??? // 実装省略
}

deriveProductとderiveSumを残して、他を書き換えました。

同時にprofileの仕組みというか、生成されたinstance多い順に表示する仕組みも入れました。

このような仕組みを使って実際に計測したところ、多いものではインスタンスが数万個以上も生成されていることがわかりました。 数万個というのも、時間がかかり過ぎて途中でcompileを止めたので最終的には10万を超えていた可能性すらあります。

shapeless 2でほぼ同様のことをやると、ここまで顕著に遅くなったりしないのですが、なぜScala 3だと明らかに遅くなるのか?は調べきれていません。

5: 遅い原因となったinstanceを半手動定義する

ここまでくれば、あとは多い順に、上記だとcreateメソッドを呼び出して、半手動でインスタンス定義をしておけば、それが再利用されることにより、compile時にTreeの大きさやcompile時間が爆発せずにすみます。 (deriveとは別にcreateがあるのは、上記のような定義でderiveの方を呼ぶと自己再帰になる可能性があるため)

6〜7分かかっていたのが、半手動のinstance定義を追加していったら、1〜2分に縮みました (一旦そこで満足したが、後でもっとinstance定義を追加するかも?)

  implicit val someTypeInstance: MyTypeClass[SomeValue] =
    create[SomeValue]
  implicit val anotherTypeInstance: MyTypeClass[AnotherValue] =
    create[AnotherValue]

*1: https://twitter.com/xuwei_k/status/1446464694802284546

*2:元のコードでもそうだが、そもそも今回はMirroredElemTypesだけ使うパターンで、他のものは使わないtype classだが・・・