|
15 | 15 |
|
16 | 16 | package spire.syntax.macros
|
17 | 17 |
|
18 |
| -import quoted._ |
19 |
| -import collection.immutable.NumericRange |
| 18 | +import scala.quoted.* |
| 19 | +import scala.collection.immutable.NumericRange |
| 20 | +import scala.PartialFunction.cond |
20 | 21 |
|
21 | 22 | import spire.syntax.fastFor.{RangeElem, RangeLike}
|
22 | 23 |
|
23 |
| -inline def fastForInline[R](init: R, test: R => Boolean, next: R => R, body: R => Unit): Unit = |
24 |
| - var index = init |
25 |
| - while test(index) do |
26 |
| - body(index) |
27 |
| - index = next(index) |
28 |
| - |
29 |
| -def fastForRangeMacroGen[R <: RangeLike: Type](r: Expr[R], body: Expr[RangeElem[R] => Unit])(using |
30 |
| - quotes: Quotes |
| 24 | +def fastForImpl[R: Type](init: Expr[R], test: Expr[R => Boolean], next: Expr[R => R], body: Expr[R => Unit])(using |
| 25 | + Quotes |
31 | 26 | ): Expr[Unit] =
|
32 |
| - import quotes._ |
33 |
| - import quotes.reflect._ |
34 |
| - |
35 |
| - type RangeL = NumericRange[Long] |
| 27 | + import quotes.reflect.* |
36 | 28 |
|
37 |
| - (r, body) match |
38 |
| - case '{ $r: Range } -> '{ $body: (Int => Unit) } => fastForRangeMacro(r, body) |
39 |
| - case '{ $r: NumericRange[Long] } -> '{ $body: (Long => Unit) } => fastForRangeMacroLong(r, body) |
40 |
| - case '{ $r } -> _ => report.error(s"Ineligible Range type ", r); '{} |
| 29 | + def code(testRef: Expr[R => Boolean], nextRef: Expr[R => R], bodyRef: Expr[R => Unit]): Expr[Unit] = '{ |
| 30 | + var index = $init |
| 31 | + while $testRef(index) do |
| 32 | + $bodyRef(index) |
| 33 | + index = $nextRef(index) |
| 34 | + } |
41 | 35 |
|
42 |
| -end fastForRangeMacroGen |
| 36 | + letFunc("test", test)(t => letFunc("next", next)(n => letFunc("body", body)(b => code(t, n, b)))) |
| 37 | +end fastForImpl |
43 | 38 |
|
44 |
| -def fastForRangeMacroLong(r: Expr[NumericRange[Long]], body: Expr[Long => Unit])(using quotes: Quotes): Expr[Unit] = |
45 |
| - import quotes._ |
| 39 | +def fastForRangeMacroGen[R <: RangeLike: Type](r: Expr[R], body: Expr[RangeElem[R] => Unit])(using |
| 40 | + quotes: Quotes |
| 41 | +): Expr[Unit] = |
46 | 42 | import quotes.reflect.*
|
47 | 43 |
|
48 |
| - def strideUpUntil(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{ |
49 |
| - var index = $fromExpr |
50 |
| - val limit = $untilExpr |
51 |
| - val body0 = $body |
52 |
| - while index < limit do |
53 |
| - ${ Expr.betaReduce(body) }(index) |
54 |
| - index += $stride |
55 |
| - } |
| 44 | + r match |
| 45 | + case '{ $r: Range } => RangeForImpl.ofInt(r, body.asExprOf[Int => Unit]) |
| 46 | + case '{ $r: NumericRange[Long] } => RangeForImpl.ofLong(r, body.asExprOf[Long => Unit]) |
| 47 | + case '{ $r } => report.error(s"Ineligible Range type ", r); '{} |
56 | 48 |
|
57 |
| - def strideUpTo(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{ |
58 |
| - var index = $fromExpr |
59 |
| - val end = $untilExpr |
60 |
| - while index <= end do |
61 |
| - ${ Expr.betaReduce(body) }(index) |
62 |
| - index += $stride |
63 |
| - } |
| 49 | +end fastForRangeMacroGen |
64 | 50 |
|
65 |
| - def strideDownTo(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{ |
66 |
| - var index = $fromExpr |
67 |
| - val end = $untilExpr |
68 |
| - while index >= end do |
69 |
| - ${ Expr.betaReduce(body) }(index) |
70 |
| - index -= $stride |
71 |
| - } |
| 51 | +private object RangeForImpl: |
| 52 | + type Code[T] = Expr[T => Unit] => Expr[Unit] |
| 53 | + type Test[T] = (Expr[T], Expr[T]) => Expr[Boolean] |
| 54 | + |
| 55 | + def ofInt(r: Expr[Range], body: Expr[Int => Unit])(using Quotes): Expr[Unit] = |
| 56 | + val code: Code[Int] = r match |
| 57 | + case '{ ($i: Int) to $j } => loopCode(i, j, 1, (x, y) => '{ $x <= $y }) |
| 58 | + case '{ ($i: Int) to $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x <= $y }) |
| 59 | + case '{ ($i: Int) to $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x >= $y }) |
| 60 | + case '{ ($i: Int) to $j by ${ Expr(k) } } if k == 0 => zeroStride(r) |
| 61 | + case '{ ($i: Int) until $j } => loopCode(i, j, 1, (x, y) => '{ $x < $y }) |
| 62 | + case '{ ($i: Int) until $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x < $y }) |
| 63 | + case '{ ($i: Int) until $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x > $y }) |
| 64 | + case '{ ($i: Int) until $j by ${ Expr(k) } } if k == 0 => zeroStride(r) |
| 65 | + case _ => deOpt(r, '{ $r.foreach($body) }) |
| 66 | + |
| 67 | + letFunc("body", body)(code) |
| 68 | + end ofInt |
| 69 | + |
| 70 | + def ofLong(r: Expr[NumericRange[Long]], body: Expr[Long => Unit])(using quotes: Quotes): Expr[Unit] = |
| 71 | + val code: Code[Long] = r match |
| 72 | + case '{ ($i: Long) to $j } => loopCode(i, j, 1L, (x, y) => '{ $x <= $y }) |
| 73 | + case '{ ($i: Long) to $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x <= $y }) |
| 74 | + case '{ ($i: Long) to $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x >= $y }) |
| 75 | + case '{ ($i: Long) to $j by ${ Expr(k) } } if k == 0 => zeroStride(r) |
| 76 | + case '{ ($i: Long) until $j } => loopCode(i, j, 1L, (x, y) => '{ $x < $y }) |
| 77 | + case '{ ($i: Long) until $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x < $y }) |
| 78 | + case '{ ($i: Long) until $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x > $y }) |
| 79 | + case '{ ($i: Long) until $j by ${ Expr(k) } } if k == 0 => zeroStride(r) |
| 80 | + case _ => deOpt(r, '{ $r.foreach($body) }) |
| 81 | + |
| 82 | + letFunc("body", body)(code) |
| 83 | + |
| 84 | + end ofLong |
| 85 | + |
| 86 | + def loopCode[T: Type: ToExpr: CanLoop](i: Expr[T], j: Expr[T], s: T, test: Test[T])(using Quotes): Code[T] = |
| 87 | + body => |
| 88 | + '{ |
| 89 | + var index = $i |
| 90 | + val limit = $j |
| 91 | + while ${ test('index, 'limit) } do |
| 92 | + $body(index) |
| 93 | + index = ${ 'index.stepBy(Expr(s)) } |
| 94 | + } |
72 | 95 |
|
73 |
| - def strideDownUntil(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{ |
74 |
| - var index = $fromExpr |
75 |
| - val limit = $untilExpr |
76 |
| - while index > limit do |
77 |
| - ${ Expr.betaReduce(body) }(index) |
78 |
| - index -= $stride |
79 |
| - } |
| 96 | + def zeroStride[T, R](orig: Expr[R])(using Quotes): Code[T] = _ => |
| 97 | + import quotes.reflect.* |
| 98 | + report.error("zero stride", orig) |
| 99 | + '{} |
80 | 100 |
|
81 |
| - r match |
82 |
| - case '{ ($i: Long) until $j } => strideUpUntil(i, j, Expr(1L)) |
83 |
| - case '{ ($i: Long) to $j } => strideUpTo(i, j, Expr(1L)) |
84 |
| - case '{ ($i: Long) until $j by $step } => |
85 |
| - step.asTerm match { |
86 |
| - case Literal(LongConstant(k)) if k > 0 => strideUpUntil(i, j, Expr(k)) |
87 |
| - case Literal(LongConstant(k)) if k < 0 => strideDownUntil(i, j, Expr(-k)) |
88 |
| - case Literal(LongConstant(k)) if k == 0 => report.error("zero stride", step); '{} |
89 |
| - case _ => |
90 |
| - report.warning(s"defaulting to foreach, can not optimise non-constant step", step) |
91 |
| - '{ val b = $body; $r.foreach(b) } |
92 |
| - } |
93 |
| - case '{ ($i: Long) to $j by $step } => |
94 |
| - step.asTerm match { |
95 |
| - case Literal(LongConstant(k)) if k > 0 => strideUpTo(i, j, Expr(k)) |
96 |
| - case Literal(LongConstant(k)) if k < 0 => strideDownTo(i, j, Expr(-k)) |
97 |
| - case Literal(LongConstant(k)) if k == 0 => report.error("zero stride", step); '{} |
98 |
| - case _ => |
99 |
| - report.warning(s"defaulting to foreach, can not optimise non-constant step", step) |
100 |
| - '{ val b = $body; $r.foreach(b) } |
101 |
| - } |
| 101 | + def deOpt[T, R](orig: Expr[R], foreach: Expr[Unit])(using Quotes): Code[T] = _ => |
| 102 | + import quotes.reflect.* |
| 103 | + report.warning(s"defaulting to foreach, can not optimise range expression", orig) |
| 104 | + foreach |
102 | 105 |
|
103 |
| - case _ => |
104 |
| - report.warning(s"defaulting to foreach, can not optimise range expression", r) |
105 |
| - '{ val b = $body; $r.foreach(b) } |
| 106 | + trait CanLoop[T]: |
| 107 | + extension (x: Expr[T]) def stepBy(y: Expr[T])(using Quotes): Expr[T] |
106 | 108 |
|
107 |
| -end fastForRangeMacroLong |
| 109 | + object CanLoop: |
| 110 | + given CanLoop[Int] with |
| 111 | + extension (x: Expr[Int]) def stepBy(y: Expr[Int])(using Quotes): Expr[Int] = '{ $x + $y } |
108 | 112 |
|
109 |
| -def fastForRangeMacro(r: Expr[Range], body: Expr[Int => Unit])(using quotes: Quotes): Expr[Unit] = |
110 |
| - import quotes._ |
111 |
| - import quotes.reflect._ |
| 113 | + given CanLoop[Long] with |
| 114 | + extension (x: Expr[Long]) def stepBy(y: Expr[Long])(using Quotes): Expr[Long] = '{ $x + $y } |
112 | 115 |
|
113 |
| - def strideUpUntil(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{ |
114 |
| - var index = $fromExpr |
115 |
| - val limit = $untilExpr |
116 |
| - while (index < limit) { |
117 |
| - ${ Expr.betaReduce(body) }(index) |
118 |
| - index += $stride |
119 |
| - } |
120 |
| - } |
| 116 | +end RangeForImpl |
121 | 117 |
|
122 |
| - def strideUpTo(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{ |
123 |
| - var index = $fromExpr |
124 |
| - val end = $untilExpr |
125 |
| - while index <= end do |
126 |
| - ${ Expr.betaReduce(body) }(index) |
127 |
| - index += $stride |
128 |
| - } |
| 118 | +/** |
| 119 | + * Equivalent to `'{ val name: A => B = $rhs; ${in('name)} }`, except when `rhs` is a function literal, then equivalent |
| 120 | + * to `in(rhs)`. |
| 121 | + * |
| 122 | + * This allows inlined function arguments to perform side-effects only once before their first evaluation, while still |
| 123 | + * avoiding the creation of closures for function literal arguments. |
| 124 | + */ |
| 125 | +private def letFunc[A, B, C](using Quotes)(name: String, rhs: Expr[A => B])(in: Expr[A => B] => Expr[C]): Expr[C] = |
| 126 | + import quotes.reflect.* |
129 | 127 |
|
130 |
| - def strideDownTo(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{ |
131 |
| - var index = $fromExpr |
132 |
| - val end = $untilExpr |
133 |
| - while index >= end do |
134 |
| - ${ Expr.betaReduce(body) }(index) |
135 |
| - index -= $stride |
136 |
| - } |
| 128 | + extension (t: Term) def unsafeAsExpr[A] = t.asExpr.asInstanceOf[Expr[A]] // cast without `quoted.Type[A]` |
137 | 129 |
|
138 |
| - def strideDownUntil(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{ |
139 |
| - var index = $fromExpr |
140 |
| - val limit = $untilExpr |
141 |
| - while index > limit do |
142 |
| - ${ Expr.betaReduce(body) }(index) |
143 |
| - index -= $stride |
| 130 | + def isFunctionLiteral[A, B](f: Expr[A => B]): Boolean = cond(f.asTerm.underlyingArgument) { case Lambda(_, _) => |
| 131 | + true |
144 | 132 | }
|
145 | 133 |
|
146 |
| - r match |
147 |
| - case '{ ($i: Int) until $j } => strideUpUntil(i, j, Expr(1)) |
148 |
| - case '{ ($i: Int) to $j } => strideUpTo(i, j, Expr(1)) |
149 |
| - case '{ ($i: Int) until $j by $step } => |
150 |
| - step.asTerm match { |
151 |
| - case Literal(IntConstant(k)) if k > 0 => strideUpUntil(i, j, Expr(k)) |
152 |
| - case Literal(IntConstant(k)) if k < 0 => strideDownUntil(i, j, Expr(-k)) |
153 |
| - case Literal(IntConstant(k)) if k == 0 => report.error("zero stride", step); '{} |
154 |
| - case _ => |
155 |
| - report.warning(s"defaulting to foreach, can not optimise non-constant step", step) |
156 |
| - '{ val b = $body; $r.foreach(b) } |
157 |
| - } |
158 |
| - case '{ ($i: Int) to $j by $step } => |
159 |
| - step.asTerm match { |
160 |
| - case Literal(IntConstant(k)) if k > 0 => strideUpTo(i, j, Expr(k)) |
161 |
| - case Literal(IntConstant(k)) if k < 0 => strideDownTo(i, j, Expr(-k)) |
162 |
| - case Literal(IntConstant(k)) if k == 0 => report.error("zero stride", step); '{} |
163 |
| - case _ => |
164 |
| - report.warning(s"defaulting to foreach, can not optimise non-constant step", step) |
165 |
| - '{ val b = $body; $r.foreach(b) } |
166 |
| - } |
167 |
| - case _ => |
168 |
| - report.warning(s"defaulting to foreach, can not optimise range expression", r) |
169 |
| - '{ val b = $body; $r.foreach(b) } |
| 134 | + def let[A, B](name: String, rhs: Expr[A])(in: Expr[A] => Expr[B])(using Quotes): Expr[B] = |
| 135 | + // Equivalent to `'{ val name = $rhs; ${in('name)} }` |
| 136 | + ValDef.let(Symbol.spliceOwner, name, rhs.asTerm)(ref => in(ref.unsafeAsExpr[A]).asTerm).unsafeAsExpr[B] |
170 | 137 |
|
171 |
| -end fastForRangeMacro |
| 138 | + if isFunctionLiteral(rhs) then in(Expr.betaReduce(rhs)) |
| 139 | + else let(name, rhs)(in) |
0 commit comments