ライブラリを使わずにScalaプログラム内部でclasspathなどを丸ごと引き継いで別のJVMを立ち上げるサンプル

そもそもがかなりレアケースなのですが、一応、稀にありえるシチュエーションとしては

  • 特別な(微妙な)コードをテストしたい
    • scala.sys.exit 自体を至る所で呼び出していて、本来それ自体を改善したいが、一旦それをそのままテストしたい
    • mainスレッドが終わった後に、デーモンではないスレッドが残らないか?みたいな、JVM丸ごと変えないと実質テストが難しいもの
    • objectの内部などで java.lang.System.getProperty 的なグローバルに依存した(初期化の)処理を実行してしまっていて、そこのコード変えないとうまくテストができない(が、コードをすぐに変えられない)
    • shutdown hook的な処理自体のテストを本物に近い形で行いたい
  • 普通にプログラムを書くと依存ライブラリが衝突して困ってるが、ClassLoader使うテクニックもやりたくないのでJVMのプロセスを丸ごとforkして処理を実行したほうが(効率は微妙だが)わかりやすい

などです。タイトルにも書いた細かいポイントとして

  • ライブラリ使わない
  • classpath引き継ぐ

あたりもあります。

また、Scalaユーザーなら、sbtのrunやtestのtask自体が、そもそも勝手にforkして実行してくれる機能備わってますよね?と思うかもしれません。

確かにその通りなのですが、JVMをforkして実行し、結果のexit codeが想定通りになっているか?ということそのものを、scalatestなどのライブラリ使って普通の書き方でテストしたかったら、sbt自体の仕組みを普通に使うだけでは、おそらく足りません。sbtのbuildファイルをゴニョゴニョして、独自task呼ぶ、とか何かが必要になる気がします。 (実はもっとすごい手軽なやり方あったら教えてください・・・)

ライブラリ使っていいなら、たとえばsbt内部に以下のようなものがあります(sbtのものは何度も使ったことがある)

適当にググったらJavaのライブラリでいくつも出てくると思います(以下雑に出てきた例。メンテされてなさそうだが・・・)

あとはsbt pluginもありますね。(元々akkaのテストのために作られたもの?)

ただ、sbtはScalaのversionが固定されるし、単純な最低限の例でいいなら、外部のライブラリ使わなくても、それなりに簡単に書けるので、それをメモ程度に貼っておく、というだけの記事です。 scala.sys.processなど使って書いてるから少し短いですが、pure Javaでも同じような感じで書けると思います。

大まかな方針としては

  • java home経由で java コマンド自体の場所を特定
  • それにいい感じに引数を構築して渡す
  • scala.sys.process.Process で実行

というだけですね。 あとの詳細は、インラインでコメント書いておきます。

build.sbt

scalaVersion := "3.4.1"

// こうしないと `javaClassPath` などが、うまく取得不可能なので。
// 2段階でforkすることになる
// runではなくtest内部でforkする場合も、おそらく同様?
run / fork := true

Main.scala

package example

import java.io.File
import java.lang.management.ManagementFactory
import scala.jdk.CollectionConverters.*
import scala.sys.process.Process
import scala.sys.process.ProcessLogger
import scala.util.Properties

object A1 {
  def main(args: Array[String]): Unit = {
    // 異なるprocessなんだぞ!を確認するデバック用
    println("A1: process id = " + ProcessHandle.current().pid())

    // これで `-Xmx` とか `-Dkey=value` などが取れる
    // 今回はそのまま全部渡しているが、必要に応じて、filterするとか、独自に構築する
    val javaVmArgs = ManagementFactory.getRuntimeMXBean.getInputArguments.asScala.toList

    // OSの差異を無視してたり、細かい部分色々雑だが、とりあえずこれで見つかるぞ
    val javaCommand = new File(Properties.javaHome, "bin/java").getCanonicalPath

    // ベタ書きでもいいが、この場合は直接参照可能なので
    // Scalaの場合 `object` の末尾には `$` がつくのでそれを消すためのdropRight
    val className = A2.getClass.getName.dropRight(1)
    val args: Seq[String] = Seq(
      Seq(javaCommand),
      javaVmArgs,
      Seq("-cp", Properties.javaClassPath, className) // classpathも丸ごと渡してるが、必要に応じて変更しましょう
    ).flatten

    // とりあえず全部標準出力に出すだけの雑なlogger渡してるだけ
    // たとえば、標準入力もつなげたい場合、std err分けたい場合などは、もっと工夫が必要
    val exitCode = Process(args).!(ProcessLogger(str => println(str)))
    assert(exitCode == 0) // これで、とりあえず正常終了でも異常終了でも、それ自体のテストが可能になる
  }
}

object A2 {
  def main(args: Array[String]): Unit = {
    // 異なるprocessなんだぞ!を確認するデバック用
    println("A2: process id = " + ProcessHandle.current().pid())

    // ここに、JVM分かれてないと難しいテスト対象の処理などが書かれている、という想定
  }
}

sbt "runMain example.A1" した場合の実行結果の例が以下

[info] running (fork) example.A1 
[info] A1: process id = 13512
[info] A2: process id = 13513