@@ -12,11 +12,6 @@ import Annotations.Annotation
1212
1313object MainProxies {
1414
15- /** Generate proxy classes for @main functions and @myMain functions where myMain <:< MainAnnotation */
16- def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
17- mainAnnotationProxies(stats) ++ mainProxies(stats)
18- }
19-
2015 /** Generate proxy classes for @main functions.
2116 * A function like
2217 *
@@ -35,7 +30,7 @@ object MainProxies {
3530 * catch case err: ParseError => showError(err)
3631 * }
3732 */
38- private def mainProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
33+ def proxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
3934 import tpd .*
4035 def mainMethods (stats : List [Tree ]): List [Symbol ] = stats.flatMap {
4136 case stat : DefDef if stat.symbol.hasAnnotation(defn.MainAnnot ) =>
@@ -127,323 +122,4 @@ object MainProxies {
127122 result
128123 }
129124
130- private type DefaultValueSymbols = Map [Int , Symbol ]
131- private type ParameterAnnotationss = Seq [Seq [Annotation ]]
132-
133- /**
134- * Generate proxy classes for main functions.
135- * A function like
136- *
137- * /* *
138- * * Lorem ipsum dolor sit amet
139- * * consectetur adipiscing elit.
140- * *
141- * * @param x my param x
142- * * @param ys all my params y
143- * */
144- * @myMain(80) def f(
145- * @myMain.Alias("myX") x: S,
146- * y: S,
147- * ys: T*
148- * ) = ...
149- *
150- * would be translated to something like
151- *
152- * final class f {
153- * static def main(args: Array[String]): Unit = {
154- * val annotation = new myMain(80)
155- * val info = new Info(
156- * name = "f",
157- * documentation = "Lorem ipsum dolor sit amet consectetur adipiscing elit.",
158- * parameters = Seq(
159- * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX"))),
160- * new scala.annotation.MainAnnotation.Parameter("y", "S", true, false, "", Seq()),
161- * new scala.annotation.MainAnnotation.Parameter("ys", "T", false, true, "all my params y", Seq())
162- * )
163- * ),
164- * val command = annotation.command(info, args)
165- * if command.isDefined then
166- * val cmd = command.get
167- * val args0: () => S = annotation.argGetter[S](info.parameters(0), cmd(0), None)
168- * val args1: () => S = annotation.argGetter[S](info.parameters(1), mainArgs(1), Some(() => sum$default$1()))
169- * val args2: () => Seq[T] = annotation.varargGetter[T](info.parameters(2), cmd.drop(2))
170- * annotation.run(() => f(args0(), args1(), args2()*))
171- * }
172- * }
173- */
174- private def mainAnnotationProxies (stats : List [tpd.Tree ])(using Context ): List [untpd.Tree ] = {
175- import tpd .*
176-
177- /**
178- * Computes the symbols of the default values of the function. Since they cannot be inferred anymore at this
179- * point of the compilation, they must be explicitly passed by [[mainProxy ]].
180- */
181- def defaultValueSymbols (scope : Tree , funSymbol : Symbol ): DefaultValueSymbols =
182- scope match {
183- case TypeDef (_, template : Template ) =>
184- template.body.flatMap((_ : Tree ) match {
185- case dd : DefDef if dd.name.is(DefaultGetterName ) && dd.name.firstPart == funSymbol.name =>
186- val DefaultGetterName .NumberedInfo (index) = dd.name.info: @ unchecked
187- List (index -> dd.symbol)
188- case _ => Nil
189- }).toMap
190- case _ => Map .empty
191- }
192-
193- /** Computes the list of main methods present in the code. */
194- def mainMethods (scope : Tree , stats : List [Tree ]): List [(Symbol , ParameterAnnotationss , DefaultValueSymbols , Option [Comment ])] = stats.flatMap {
195- case stat : DefDef =>
196- val sym = stat.symbol
197- sym.annotations.filter(_.matches(defn.MainAnnotationClass )) match {
198- case Nil =>
199- Nil
200- case _ :: Nil =>
201- val paramAnnotations = stat.paramss.flatMap(_.map(
202- valdef => valdef.symbol.annotations.filter(_.matches(defn.MainAnnotationParameterAnnotation ))
203- ))
204- (sym, paramAnnotations.toVector, defaultValueSymbols(scope, sym), stat.rawComment) :: Nil
205- case mainAnnot :: others =>
206- report.error(em " method cannot have multiple main annotations " , mainAnnot.tree)
207- Nil
208- }
209- case stat @ TypeDef (_, impl : Template ) if stat.symbol.is(Module ) =>
210- mainMethods(stat, impl.body)
211- case _ =>
212- Nil
213- }
214-
215- // Assuming that the top-level object was already generated, all main methods will have a scope
216- mainMethods(EmptyTree , stats).flatMap(mainAnnotationProxy)
217- }
218-
219- private def mainAnnotationProxy (mainFun : Symbol , paramAnnotations : ParameterAnnotationss , defaultValueSymbols : DefaultValueSymbols , docComment : Option [Comment ])(using Context ): Option [TypeDef ] = {
220- val mainAnnot = mainFun.getAnnotation(defn.MainAnnotationClass ).get
221- def pos = mainFun.sourcePos
222-
223- val documentation = new Documentation (docComment)
224-
225- /** () => value */
226- def unitToValue (value : Tree ): Tree =
227- val defDef = DefDef (nme.ANON_FUN , List (Nil ), TypeTree (), value)
228- Block (defDef, Closure (Nil , Ident (nme.ANON_FUN ), EmptyTree ))
229-
230- /** Generate a list of trees containing the ParamInfo instantiations.
231- *
232- * A ParamInfo has the following shape
233- * ```
234- * new scala.annotation.MainAnnotation.Parameter("x", "S", false, false, "my param x", Seq(new scala.main.Alias("myX")))
235- * ```
236- */
237- def parameterInfos (mt : MethodType ): List [Tree ] =
238- extension (tree : Tree ) def withProperty (sym : Symbol , args : List [Tree ]) =
239- Apply (Select (tree, sym.name), args)
240-
241- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
242- val param = paramName.toString
243- val paramType0 = if formal.isRepeatedParam then formal.argTypes.head.dealias else formal.dealias
244- val paramType = paramType0.dealias
245- val paramTypeOwner = paramType.typeSymbol.owner
246- val paramTypeStr =
247- if paramTypeOwner == defn.EmptyPackageClass then paramType.show
248- else paramTypeOwner.showFullName + " ." + paramType.show
249- val hasDefault = defaultValueSymbols.contains(idx)
250- val isRepeated = formal.isRepeatedParam
251- val paramDoc = documentation.argDocs.getOrElse(param, " " )
252- val paramAnnots =
253- val annotationTrees = paramAnnotations(idx).map(instantiateAnnotation).toList
254- Apply (ref(defn.SeqModule .termRef), annotationTrees)
255-
256- val constructorArgs = List (param, paramTypeStr, hasDefault, isRepeated, paramDoc)
257- .map(value => Literal (Constant (value)))
258-
259- New (TypeTree (defn.MainAnnotationParameter .typeRef), List (constructorArgs :+ paramAnnots))
260-
261- end parameterInfos
262-
263- /**
264- * Creates a list of references and definitions of arguments.
265- * The goal is to create the
266- * `val args0: () => S = annotation.argGetter[S](0, cmd(0), None)`
267- * part of the code.
268- */
269- def argValDefs (mt : MethodType ): List [ValDef ] =
270- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
271- val argName = nme.args ++ idx.toString
272- val isRepeated = formal.isRepeatedParam
273- val formalType = if isRepeated then formal.argTypes.head else formal
274- val getterName = if isRepeated then nme.varargGetter else nme.argGetter
275- val defaultValueGetterOpt = defaultValueSymbols.get(idx) match
276- case None => ref(defn.NoneModule .termRef)
277- case Some (dvSym) =>
278- val value = unitToValue(ref(dvSym.termRef))
279- Apply (ref(defn.SomeClass .companionModule.termRef), value)
280- val argGetter0 = TypeApply (Select (Ident (nme.annotation), getterName), TypeTree (formalType) :: Nil )
281- val index = Literal (Constant (idx))
282- val paramInfo = Apply (Select (Ident (nme.info), nme.parameters), index)
283- val argGetter =
284- if isRepeated then Apply (argGetter0, List (paramInfo, Apply (Select (Ident (nme.cmd), nme.drop), List (index))))
285- else Apply (argGetter0, List (paramInfo, Apply (Ident (nme.cmd), List (index)), defaultValueGetterOpt))
286- ValDef (argName, TypeTree (), argGetter)
287- end argValDefs
288-
289-
290- /** Create a list of argument references that will be passed as argument to the main method.
291- * `args0`, ...`argn*`
292- */
293- def argRefs (mt : MethodType ): List [Tree ] =
294- for ((formal, paramName), idx) <- mt.paramInfos.zip(mt.paramNames).zipWithIndex yield
295- val argRef = Apply (Ident (nme.args ++ idx.toString), Nil )
296- if formal.isRepeatedParam then repeated(argRef) else argRef
297- end argRefs
298-
299-
300- /** Turns an annotation (e.g. `@main(40)`) into an instance of the class (e.g. `new scala.main(40)`). */
301- def instantiateAnnotation (annot : Annotation ): Tree =
302- val argss = {
303- def recurse (t : tpd.Tree , acc : List [List [Tree ]]): List [List [Tree ]] = t match {
304- case Apply (t, args : List [tpd.Tree ]) => recurse(t, extractArgs(args) :: acc)
305- case _ => acc
306- }
307-
308- def extractArgs (args : List [tpd.Tree ]): List [Tree ] =
309- args.flatMap {
310- case Typed (SeqLiteral (varargs, _), _) => varargs.map(arg => TypedSplice (arg))
311- case arg : Select if arg.name.is(DefaultGetterName ) => Nil // Ignore default values, they will be added later by the compiler
312- case arg => List (TypedSplice (arg))
313- }
314-
315- recurse(annot.tree, Nil )
316- }
317-
318- New (TypeTree (annot.symbol.typeRef), argss)
319- end instantiateAnnotation
320-
321- def generateMainClass (mainCall : Tree , args : List [Tree ], parameterInfos : List [Tree ]): TypeDef =
322- val cmdInfo =
323- val nameTree = Literal (Constant (mainFun.showName))
324- val docTree = Literal (Constant (documentation.mainDoc))
325- val paramInfos = Apply (ref(defn.SeqModule .termRef), parameterInfos)
326- New (TypeTree (defn.MainAnnotationInfo .typeRef), List (List (nameTree, docTree, paramInfos)))
327-
328- val annotVal = ValDef (
329- nme.annotation,
330- TypeTree (),
331- instantiateAnnotation(mainAnnot)
332- )
333- val infoVal = ValDef (
334- nme.info,
335- TypeTree (),
336- cmdInfo
337- )
338- val command = ValDef (
339- nme.command,
340- TypeTree (),
341- Apply (
342- Select (Ident (nme.annotation), nme.command),
343- List (Ident (nme.info), Ident (nme.args))
344- )
345- )
346- val argsVal = ValDef (
347- nme.cmd,
348- TypeTree (),
349- Select (Ident (nme.command), nme.get)
350- )
351- val run = Apply (Select (Ident (nme.annotation), nme.run), mainCall)
352- val body0 = If (
353- Select (Ident (nme.command), nme.isDefined),
354- Block (argsVal :: args, run),
355- EmptyTree
356- )
357- val body = Block (List (annotVal, infoVal, command), body0) // TODO add `if (cmd.nonEmpty)`
358-
359- val mainArg = ValDef (nme.args, TypeTree (defn.ArrayType .appliedTo(defn.StringType )), EmptyTree )
360- .withFlags(Param )
361- /** Replace typed `Ident`s that have been typed with a TypeSplice with the reference to the symbol.
362- * The annotations will be retype-checked in another scope that may not have the same imports.
363- */
364- def insertTypeSplices = new TreeMap {
365- override def transform (tree : Tree )(using Context ): Tree = tree match
366- case tree : tpd.Ident @ unchecked => TypedSplice (tree)
367- case tree => super .transform(tree)
368- }
369- val annots = mainFun.annotations
370- .filterNot(_.matches(defn.MainAnnotationClass ))
371- .map(annot => insertTypeSplices.transform(annot.tree))
372- val mainMeth = DefDef (nme.main, (mainArg :: Nil ) :: Nil , TypeTree (defn.UnitType ), body)
373- .withFlags(JavaStatic )
374- .withAnnotations(annots)
375- val mainTempl = Template (emptyConstructor, Nil , Nil , EmptyValDef , mainMeth :: Nil )
376- val mainCls = TypeDef (mainFun.name.toTypeName, mainTempl)
377- .withFlags(Final | Invisible )
378- mainCls.withSpan(mainAnnot.tree.span.toSynthetic)
379- end generateMainClass
380-
381- if (! mainFun.owner.isStaticOwner)
382- report.error(em " main method is not statically accessible " , pos)
383- None
384- else mainFun.info match {
385- case _ : ExprType =>
386- Some (generateMainClass(unitToValue(ref(mainFun.termRef)), Nil , Nil ))
387- case mt : MethodType =>
388- if (mt.isImplicitMethod)
389- report.error(em " main method cannot have implicit parameters " , pos)
390- None
391- else mt.resType match
392- case restpe : MethodType =>
393- report.error(em " main method cannot be curried " , pos)
394- None
395- case _ =>
396- Some (generateMainClass(unitToValue(Apply (ref(mainFun.termRef), argRefs(mt))), argValDefs(mt), parameterInfos(mt)))
397- case _ : PolyType =>
398- report.error(em " main method cannot have type parameters " , pos)
399- None
400- case _ =>
401- report.error(em " main can only annotate a method " , pos)
402- None
403- }
404- }
405-
406- /** A class responsible for extracting the docstrings of a method. */
407- private class Documentation (docComment : Option [Comment ]):
408- import util .CommentParsing .*
409-
410- /** The main part of the documentation. */
411- lazy val mainDoc : String = _mainDoc
412- /** The parameters identified by @param. Maps from parameter name to its documentation. */
413- lazy val argDocs : Map [String , String ] = _argDocs
414-
415- private var _mainDoc : String = " "
416- private var _argDocs : Map [String , String ] = Map ()
417-
418- docComment match {
419- case Some (comment) => if comment.isDocComment then parseDocComment(comment.raw) else _mainDoc = comment.raw
420- case None =>
421- }
422-
423- private def cleanComment (raw : String ): String =
424- var lines : Seq [String ] = raw.trim.nn.split('\n ' ).nn.toSeq
425- lines = lines.map(l => l.substring(skipLineLead(l, - 1 ), l.length).nn.trim.nn)
426- var s = lines.foldLeft(" " ) {
427- case (" " , s2) => s2
428- case (s1, " " ) if s1.last == '\n ' => s1 // Multiple newlines are kept as single newlines
429- case (s1, " " ) => s1 + '\n '
430- case (s1, s2) if s1.last == '\n ' => s1 + s2
431- case (s1, s2) => s1 + ' ' + s2
432- }
433- s.replaceAll(raw " \[\[ " , " " ).nn.replaceAll(raw " \]\] " , " " ).nn.trim.nn
434-
435- private def parseDocComment (raw : String ): Unit =
436- // Positions of the sections (@) in the docstring
437- val tidx : List [(Int , Int )] = tagIndex(raw)
438-
439- // Parse main comment
440- var mainComment : String = raw.substring(skipLineLead(raw, 0 ), startTag(raw, tidx)).nn
441- _mainDoc = cleanComment(mainComment)
442-
443- // Parse arguments comments
444- val argsCommentsSpans : Map [String , (Int , Int )] = paramDocs(raw, " @param" , tidx)
445- val argsCommentsTextSpans = argsCommentsSpans.view.mapValues(extractSectionText(raw, _))
446- val argsCommentsTexts = argsCommentsTextSpans.mapValues({ case (beg, end) => raw.substring(beg, end).nn })
447- _argDocs = argsCommentsTexts.mapValues(cleanComment(_)).toMap
448- end Documentation
449125}
0 commit comments