Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions modules/build/src/main/scala/scala/build/ScopedSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ final case class ScopedSources(
): Either[BuildException, Sources] = either {
val combinedOptions = combinedBuildOptions(scope, baseOptions)

val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions)
val codeWrapper = ScriptPreprocessor.getScriptWrapper(combinedOptions, logger)

val wrappedScripts = unwrappedScripts
.flatMap(_.valueFor(scope).toSeq)
.map(_.wrap(codeWrapper))

codeWrapper match {
case _: AppCodeWrapper.type if wrappedScripts.size > 1 =>
case _: AppCodeWrapper if wrappedScripts.size > 1 =>
wrappedScripts.find(_.originalPath.exists(_._1.toString == "main.sc"))
.foreach(_ => logger.diagnostic(WarningMessages.mainScriptNameClashesWithAppWrapper))
case _ => ()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala.build.internal

case object AppCodeWrapper extends CodeWrapper {
case class AppCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {
override def mainClassObject(className: Name) = className

def apply(
Expand All @@ -12,13 +12,19 @@ case object AppCodeWrapper extends CodeWrapper {
) = {
val wrapperObjectName = indexedWrapperName.backticked

val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val invokeMain = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"\n$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
""
val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"
val top = AmmUtil.normalizeNewlines(
s"""$packageDirective
|
|object $wrapperObjectName extends App {
|val scriptPath = \"\"\"$scriptPath\"\"\"
|val scriptPath = \"\"\"$scriptPath\"\"\"$invokeMain
|""".stripMargin
)
val bottom = AmmUtil.normalizeNewlines(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ package scala.build.internal
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {
case class ClassCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
Expand All @@ -16,8 +16,16 @@ case object ClassCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {

val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val mainInvocation = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"script.$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
s"val _ = script.hashCode()"

val name = mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val wrapperClassName = scala.build.internal.Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
| private var args$$opt0 = Option.empty[Array[String]]
Expand All @@ -33,7 +41,7 @@ case object ClassCodeWrapper extends CodeWrapper {
|
| def main(args: Array[String]): Unit = {
| args$$set(args)
| val _ = script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
| $mainInvocation // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package scala.build.internal
* or/and not using JS native prefer [[ClassCodeWrapper]], since it prevents deadlocks when running
* threads from script
*/
case object ObjectCodeWrapper extends CodeWrapper {
case class ObjectCodeWrapper(scalaVersion: String, log: String => Unit) extends CodeWrapper {

override def mainClassObject(className: Name): Name =
Name(className.raw ++ "_sc")
Expand All @@ -15,12 +15,19 @@ case object ObjectCodeWrapper extends CodeWrapper {
extraCode: String,
scriptPath: String
) = {
val mainObject = WrapperUtils.mainObjectInScript(scalaVersion, code)
val name = mainClassObject(indexedWrapperName).backticked
val aliasedWrapperName = name + "$$alias"
val funHashCodeMethod =
val realScript =
if (name == "main_sc")
s"$aliasedWrapperName.alias.hashCode()" // https://github.com/VirtusLab/scala-cli/issues/314
else s"${indexedWrapperName.backticked}.hashCode()"
s"$aliasedWrapperName.alias" // https://github.com/VirtusLab/scala-cli/issues/314
else s"${indexedWrapperName.backticked}"

val funHashCodeMethod = mainObject match
case WrapperUtils.ScriptMainMethod.Exists(name) => s"$realScript.$name.main(args)"
case otherwise =>
otherwise.warningMessage.foreach(log)
s"val _ = $realScript.hashCode()"
// We need to call hashCode (or any other method so compiler does not report a warning)
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
Expand All @@ -34,7 +41,7 @@ case object ObjectCodeWrapper extends CodeWrapper {
| }
| def main(args: Array[String]): Unit = {
| args$$set(args)
| val _ = $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
| $funHashCodeMethod // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|""".stripMargin)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package scala.build.internal

import scala.build.internal.util.WarningMessages

object WrapperUtils {

enum ScriptMainMethod:
case Exists(name: String)
case Multiple(names: Seq[String])
case ToplevelStatsPresent
case NoMain

def warningMessage: Option[String] =
this match
case ScriptMainMethod.Multiple(names) =>
Some(WarningMessages.multipleMainObjectsInScript(names))
case ScriptMainMethod.ToplevelStatsPresent => Some(
WarningMessages.mixedToplvelAndObjectInScript
)
case _ => None

def mainObjectInScript(scalaVersion: String, code: String): ScriptMainMethod =
import scala.meta.*

val scriptDialect =
if scalaVersion.startsWith("3") then dialects.Scala3Future else dialects.Scala213Source3

given Dialect = scriptDialect.withAllowToplevelStatements(true).withAllowToplevelTerms(true)
val parsedCode = code.parse[Source] match
case Parsed.Success(Source(stats)) => stats
case _ => Nil

// Check if there is a main function defined inside an object
def checkSignature(defn: Defn.Def) =
defn.paramClauseGroups match
case List(Member.ParamClauseGroup(
Type.ParamClause(Nil),
List(Term.ParamClause(
List(Term.Param(
Nil,
_: Term.Name,
Some(Type.Apply.After_4_6_0(
Type.Name("Array"),
Type.ArgClause(List(Type.Name("String")))
)),
None
)),
None
))
)) => true
case _ => false

def noToplevelStatements = parsedCode.forall {
case _: Term => false
case _ => true
}

def hasMainSignature(templ: Template) = templ.body.stats.exists {
case defn: Defn.Def =>
defn.name.value == "main" && checkSignature(defn)
case _ => false
}
def extendsApp(templ: Template) = templ.inits match
case Init.After_4_6_0(Type.Name("App"), _, Nil) :: Nil => true
case _ => false
val potentialMains = parsedCode.collect {
case Defn.Object(_, objName, templ) if extendsApp(templ) || hasMainSignature(templ) =>
Seq(objName.value)
}.flatten

potentialMains match
case head :: Nil if noToplevelStatements =>
ScriptMainMethod.Exists(head)
case head :: Nil =>
ScriptMainMethod.ToplevelStatsPresent
case Nil => ScriptMainMethod.NoMain
case seq =>
ScriptMainMethod.Multiple(seq)

}
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ object WarningMessages {
val offlineModeBloopJvmNotFound =
"Offline mode is ON and a JVM for Bloop could not be fetched from the local cache, using scalac as fallback"

def multipleMainObjectsInScript(names: Seq[String]) =
s"Only single main is allowed within scripts and multiple main classes were found in the script: ${names.mkString(", ")}"

def mixedToplvelAndObjectInScript =
"Script contains objects with main methods and top-level statements, only the latter will be run."

def directivesInMultipleFilesWarning(
projectFilePath: String,
pathsToReport: Iterable[String] = Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ case object ScriptPreprocessor extends Preprocessor {
(codeWrapper: CodeWrapper) =>
if (containsMainAnnot) logger.diagnostic(
codeWrapper match {
case _: AppCodeWrapper.type =>
case _: AppCodeWrapper =>
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ true)
case _ => WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
}
Expand All @@ -157,24 +157,27 @@ case object ScriptPreprocessor extends Preprocessor {
* @return
* code wrapper compatible with provided BuildOptions
*/
def getScriptWrapper(buildOptions: BuildOptions): CodeWrapper = {
def getScriptWrapper(buildOptions: BuildOptions, logger: Logger): CodeWrapper = {
val effectiveScalaVersion =
buildOptions.scalaOptions.scalaVersion.flatMap(_.versionOpt)
.orElse(buildOptions.scalaOptions.defaultScalaVersion)
.getOrElse(Constants.defaultScalaVersion)
def logWarning(msg: String) = logger.diagnostic(msg)

def objectCodeWrapperForScalaVersion =
// AppObjectWrapper only introduces the 'main.sc' restriction when used in Scala 3, there's no gain in using it with Scala 3
if effectiveScalaVersion.startsWith("2") then AppCodeWrapper
else ObjectCodeWrapper
if effectiveScalaVersion.startsWith("2") then
AppCodeWrapper(effectiveScalaVersion, logWarning)
else ObjectCodeWrapper(effectiveScalaVersion, logWarning)

buildOptions.scriptOptions.forceObjectWrapper match {
case Some(true) => objectCodeWrapperForScalaVersion
case _ =>
buildOptions.scalaOptions.platform.map(_.value) match {
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
case _ if effectiveScalaVersion.startsWith("2") => AppCodeWrapper
case _ => ClassCodeWrapper
case Some(_: Platform.JS.type) => objectCodeWrapperForScalaVersion
case _ if effectiveScalaVersion.startsWith("2") =>
AppCodeWrapper(effectiveScalaVersion, logWarning)
case _ => ClassCodeWrapper(effectiveScalaVersion, logWarning)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,133 @@ trait RunScriptTestDefinitions { _: RunTestDefinitions =>
}
}

test("main.sc has an object with a main method") {
val message = "Hello"
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|
|object Main {
| def main(args: Array[String]): Unit = println("$message")
|}
|""".stripMargin
)
inputs.fromRoot { root =>
val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd =
root
).out.trim()
expect(output == message)
}
}
test("main.sc has an object that extends App") {
val message = "Hello"
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|
|object Main extends App{
| println("$message")
|}
|
|object Other {}
|""".stripMargin
)
inputs.fromRoot { root =>
val output = os.proc(TestUtil.cli, extraOptions, "main.sc").call(cwd =
root
).out.trim()
expect(output == message)
}
}

test("main.sc has an object with a main method and an object wrapper") {
val message = "Hello"
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|//> using objectWrapper
|object Main {
| def main(args: Array[String]): Unit = println("$message")
|}
|""".stripMargin
)
inputs.fromRoot { root =>
val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(cwd =
root
).out.trim()
expect(output == message)
}
}

test("main.sc has multiple main methods") {
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|//> using objectWrapper
|object Main {
| def main(args: Array[String]): Unit = println("1")
|}
|object AnotherMain {
| def main(args: Array[String]): Unit = println("2")
|}
|""".stripMargin
)
inputs.fromRoot { root =>
val result = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc").call(
cwd = root,
stderr = os.Pipe
)
val output = result.out.trim()
val err = result.err.trim()
expect(output == "")
expect(err.contains(
"Only single main is allowed within scripts and multiple main classes were found in the script: Main, AnotherMain"
))
}
}

test("main.sc has both an object with a main method as well as top-level definitions") {
val message1 = "Hello"
val message2 = "Another hello"
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|object Main {
| def main(args: Array[String]): Unit = println("$message1")
|}
|println("$message2")
|""".stripMargin
)
inputs.fromRoot { root =>
val result = os.proc(TestUtil.cli, extraOptions, "main.sc").call(
cwd = root,
stderr = os.Pipe
)
val output = result.out.trim()
val err = result.err.trim()
expect(output == message2)
expect(err.contains(
"Script contains objects with main methods and top-level statements, only the latter will be run."
))
expect(output == message2)
}
}

test(
"main.sc has both an object with a main method and an object wrapper as well as top-level calls"
) {
val message1 = "Hello"
val message2 = "Another hello"
val inputs = TestInputs(
os.rel / "main.sc" ->
s"""|//> using objectWrapper
|object Main {
| def main(args: Array[String]): Unit = println("$message1")
|}
|println("$message2")
|""".stripMargin
)
inputs.fromRoot { root =>
val output = os.proc(TestUtil.cli, extraOptions, "--power", "main.sc")
.call(cwd = root).out.trim()
expect(output == message2)
}
}
if (actualScalaVersion.startsWith("3"))
test("use method from main.sc file") {
val message = "Hello"
Expand Down
Loading
Loading