Skip to content

Commit f717e9d

Browse files
authored
Merge pull request #1111 from bishabosha/patch-1
ensure fastForInline does not make closures
2 parents 435e9af + a1533af commit f717e9d

File tree

3 files changed

+105
-136
lines changed

3 files changed

+105
-136
lines changed

core/src/main/scala-3/spire/syntax/FastForSyntax.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ trait FastForSyntax:
2727
case NumericRange[Long] => Long
2828

2929
inline def fastFor[A](inline init: A)(inline test: A => Boolean, inline next: A => A)(inline body: A => Unit): Unit =
30-
fastForInline(init, test, next, body)
30+
${ fastForImpl('init, 'test, 'next, 'body) }
3131

3232
inline def fastForRange[R <: RangeLike](inline r: R)(inline body: RangeElem[R] => Unit): Unit =
3333
${ fastForRangeMacroGen('r, 'body) }

core/src/main/scala-3/spire/syntax/macros/cforMacros.scala

Lines changed: 99 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -15,157 +15,125 @@
1515

1616
package spire.syntax.macros
1717

18-
import quoted._
19-
import collection.immutable.NumericRange
18+
import scala.quoted.*
19+
import scala.collection.immutable.NumericRange
20+
import scala.PartialFunction.cond
2021

2122
import spire.syntax.fastFor.{RangeElem, RangeLike}
2223

23-
inline def fastForInline[R](init: R, test: R => Boolean, next: R => R, body: R => Unit): Unit =
24-
var index = init
25-
while test(index) do
26-
body(index)
27-
index = next(index)
28-
29-
def fastForRangeMacroGen[R <: RangeLike: Type](r: Expr[R], body: Expr[RangeElem[R] => Unit])(using
30-
quotes: Quotes
24+
def fastForImpl[R: Type](init: Expr[R], test: Expr[R => Boolean], next: Expr[R => R], body: Expr[R => Unit])(using
25+
Quotes
3126
): Expr[Unit] =
32-
import quotes._
33-
import quotes.reflect._
34-
35-
type RangeL = NumericRange[Long]
27+
import quotes.reflect.*
3628

37-
(r, body) match
38-
case '{ $r: Range } -> '{ $body: (Int => Unit) } => fastForRangeMacro(r, body)
39-
case '{ $r: NumericRange[Long] } -> '{ $body: (Long => Unit) } => fastForRangeMacroLong(r, body)
40-
case '{ $r } -> _ => report.error(s"Ineligible Range type ", r); '{}
29+
def code(testRef: Expr[R => Boolean], nextRef: Expr[R => R], bodyRef: Expr[R => Unit]): Expr[Unit] = '{
30+
var index = $init
31+
while $testRef(index) do
32+
$bodyRef(index)
33+
index = $nextRef(index)
34+
}
4135

42-
end fastForRangeMacroGen
36+
letFunc("test", test)(t => letFunc("next", next)(n => letFunc("body", body)(b => code(t, n, b))))
37+
end fastForImpl
4338

44-
def fastForRangeMacroLong(r: Expr[NumericRange[Long]], body: Expr[Long => Unit])(using quotes: Quotes): Expr[Unit] =
45-
import quotes._
39+
def fastForRangeMacroGen[R <: RangeLike: Type](r: Expr[R], body: Expr[RangeElem[R] => Unit])(using
40+
quotes: Quotes
41+
): Expr[Unit] =
4642
import quotes.reflect.*
4743

48-
def strideUpUntil(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{
49-
var index = $fromExpr
50-
val limit = $untilExpr
51-
val body0 = $body
52-
while index < limit do
53-
${ Expr.betaReduce(body) }(index)
54-
index += $stride
55-
}
44+
r match
45+
case '{ $r: Range } => RangeForImpl.ofInt(r, body.asExprOf[Int => Unit])
46+
case '{ $r: NumericRange[Long] } => RangeForImpl.ofLong(r, body.asExprOf[Long => Unit])
47+
case '{ $r } => report.error(s"Ineligible Range type ", r); '{}
5648

57-
def strideUpTo(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{
58-
var index = $fromExpr
59-
val end = $untilExpr
60-
while index <= end do
61-
${ Expr.betaReduce(body) }(index)
62-
index += $stride
63-
}
49+
end fastForRangeMacroGen
6450

65-
def strideDownTo(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{
66-
var index = $fromExpr
67-
val end = $untilExpr
68-
while index >= end do
69-
${ Expr.betaReduce(body) }(index)
70-
index -= $stride
71-
}
51+
private object RangeForImpl:
52+
type Code[T] = Expr[T => Unit] => Expr[Unit]
53+
type Test[T] = (Expr[T], Expr[T]) => Expr[Boolean]
54+
55+
def ofInt(r: Expr[Range], body: Expr[Int => Unit])(using Quotes): Expr[Unit] =
56+
val code: Code[Int] = r match
57+
case '{ ($i: Int) to $j } => loopCode(i, j, 1, (x, y) => '{ $x <= $y })
58+
case '{ ($i: Int) to $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x <= $y })
59+
case '{ ($i: Int) to $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x >= $y })
60+
case '{ ($i: Int) to $j by ${ Expr(k) } } if k == 0 => zeroStride(r)
61+
case '{ ($i: Int) until $j } => loopCode(i, j, 1, (x, y) => '{ $x < $y })
62+
case '{ ($i: Int) until $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x < $y })
63+
case '{ ($i: Int) until $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x > $y })
64+
case '{ ($i: Int) until $j by ${ Expr(k) } } if k == 0 => zeroStride(r)
65+
case _ => deOpt(r, '{ $r.foreach($body) })
66+
67+
letFunc("body", body)(code)
68+
end ofInt
69+
70+
def ofLong(r: Expr[NumericRange[Long]], body: Expr[Long => Unit])(using quotes: Quotes): Expr[Unit] =
71+
val code: Code[Long] = r match
72+
case '{ ($i: Long) to $j } => loopCode(i, j, 1L, (x, y) => '{ $x <= $y })
73+
case '{ ($i: Long) to $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x <= $y })
74+
case '{ ($i: Long) to $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x >= $y })
75+
case '{ ($i: Long) to $j by ${ Expr(k) } } if k == 0 => zeroStride(r)
76+
case '{ ($i: Long) until $j } => loopCode(i, j, 1L, (x, y) => '{ $x < $y })
77+
case '{ ($i: Long) until $j by ${ Expr(k) } } if k > 0 => loopCode(i, j, k, (x, y) => '{ $x < $y })
78+
case '{ ($i: Long) until $j by ${ Expr(k) } } if k < 0 => loopCode(i, j, k, (x, y) => '{ $x > $y })
79+
case '{ ($i: Long) until $j by ${ Expr(k) } } if k == 0 => zeroStride(r)
80+
case _ => deOpt(r, '{ $r.foreach($body) })
81+
82+
letFunc("body", body)(code)
83+
84+
end ofLong
85+
86+
def loopCode[T: Type: ToExpr: CanLoop](i: Expr[T], j: Expr[T], s: T, test: Test[T])(using Quotes): Code[T] =
87+
body =>
88+
'{
89+
var index = $i
90+
val limit = $j
91+
while ${ test('index, 'limit) } do
92+
$body(index)
93+
index = ${ 'index.stepBy(Expr(s)) }
94+
}
7295

73-
def strideDownUntil(fromExpr: Expr[Long], untilExpr: Expr[Long], stride: Expr[Long]): Expr[Unit] = '{
74-
var index = $fromExpr
75-
val limit = $untilExpr
76-
while index > limit do
77-
${ Expr.betaReduce(body) }(index)
78-
index -= $stride
79-
}
96+
def zeroStride[T, R](orig: Expr[R])(using Quotes): Code[T] = _ =>
97+
import quotes.reflect.*
98+
report.error("zero stride", orig)
99+
'{}
80100

81-
r match
82-
case '{ ($i: Long) until $j } => strideUpUntil(i, j, Expr(1L))
83-
case '{ ($i: Long) to $j } => strideUpTo(i, j, Expr(1L))
84-
case '{ ($i: Long) until $j by $step } =>
85-
step.asTerm match {
86-
case Literal(LongConstant(k)) if k > 0 => strideUpUntil(i, j, Expr(k))
87-
case Literal(LongConstant(k)) if k < 0 => strideDownUntil(i, j, Expr(-k))
88-
case Literal(LongConstant(k)) if k == 0 => report.error("zero stride", step); '{}
89-
case _ =>
90-
report.warning(s"defaulting to foreach, can not optimise non-constant step", step)
91-
'{ val b = $body; $r.foreach(b) }
92-
}
93-
case '{ ($i: Long) to $j by $step } =>
94-
step.asTerm match {
95-
case Literal(LongConstant(k)) if k > 0 => strideUpTo(i, j, Expr(k))
96-
case Literal(LongConstant(k)) if k < 0 => strideDownTo(i, j, Expr(-k))
97-
case Literal(LongConstant(k)) if k == 0 => report.error("zero stride", step); '{}
98-
case _ =>
99-
report.warning(s"defaulting to foreach, can not optimise non-constant step", step)
100-
'{ val b = $body; $r.foreach(b) }
101-
}
101+
def deOpt[T, R](orig: Expr[R], foreach: Expr[Unit])(using Quotes): Code[T] = _ =>
102+
import quotes.reflect.*
103+
report.warning(s"defaulting to foreach, can not optimise range expression", orig)
104+
foreach
102105

103-
case _ =>
104-
report.warning(s"defaulting to foreach, can not optimise range expression", r)
105-
'{ val b = $body; $r.foreach(b) }
106+
trait CanLoop[T]:
107+
extension (x: Expr[T]) def stepBy(y: Expr[T])(using Quotes): Expr[T]
106108

107-
end fastForRangeMacroLong
109+
object CanLoop:
110+
given CanLoop[Int] with
111+
extension (x: Expr[Int]) def stepBy(y: Expr[Int])(using Quotes): Expr[Int] = '{ $x + $y }
108112

109-
def fastForRangeMacro(r: Expr[Range], body: Expr[Int => Unit])(using quotes: Quotes): Expr[Unit] =
110-
import quotes._
111-
import quotes.reflect._
113+
given CanLoop[Long] with
114+
extension (x: Expr[Long]) def stepBy(y: Expr[Long])(using Quotes): Expr[Long] = '{ $x + $y }
112115

113-
def strideUpUntil(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{
114-
var index = $fromExpr
115-
val limit = $untilExpr
116-
while (index < limit) {
117-
${ Expr.betaReduce(body) }(index)
118-
index += $stride
119-
}
120-
}
116+
end RangeForImpl
121117

122-
def strideUpTo(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{
123-
var index = $fromExpr
124-
val end = $untilExpr
125-
while index <= end do
126-
${ Expr.betaReduce(body) }(index)
127-
index += $stride
128-
}
118+
/**
119+
* Equivalent to `'{ val name: A => B = $rhs; ${in('name)} }`, except when `rhs` is a function literal, then equivalent
120+
* to `in(rhs)`.
121+
*
122+
* This allows inlined function arguments to perform side-effects only once before their first evaluation, while still
123+
* avoiding the creation of closures for function literal arguments.
124+
*/
125+
private def letFunc[A, B, C](using Quotes)(name: String, rhs: Expr[A => B])(in: Expr[A => B] => Expr[C]): Expr[C] =
126+
import quotes.reflect.*
129127

130-
def strideDownTo(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{
131-
var index = $fromExpr
132-
val end = $untilExpr
133-
while index >= end do
134-
${ Expr.betaReduce(body) }(index)
135-
index -= $stride
136-
}
128+
extension (t: Term) def unsafeAsExpr[A] = t.asExpr.asInstanceOf[Expr[A]] // cast without `quoted.Type[A]`
137129

138-
def strideDownUntil(fromExpr: Expr[Int], untilExpr: Expr[Int], stride: Expr[Int]): Expr[Unit] = '{
139-
var index = $fromExpr
140-
val limit = $untilExpr
141-
while index > limit do
142-
${ Expr.betaReduce(body) }(index)
143-
index -= $stride
130+
def isFunctionLiteral[A, B](f: Expr[A => B]): Boolean = cond(f.asTerm.underlyingArgument) { case Lambda(_, _) =>
131+
true
144132
}
145133

146-
r match
147-
case '{ ($i: Int) until $j } => strideUpUntil(i, j, Expr(1))
148-
case '{ ($i: Int) to $j } => strideUpTo(i, j, Expr(1))
149-
case '{ ($i: Int) until $j by $step } =>
150-
step.asTerm match {
151-
case Literal(IntConstant(k)) if k > 0 => strideUpUntil(i, j, Expr(k))
152-
case Literal(IntConstant(k)) if k < 0 => strideDownUntil(i, j, Expr(-k))
153-
case Literal(IntConstant(k)) if k == 0 => report.error("zero stride", step); '{}
154-
case _ =>
155-
report.warning(s"defaulting to foreach, can not optimise non-constant step", step)
156-
'{ val b = $body; $r.foreach(b) }
157-
}
158-
case '{ ($i: Int) to $j by $step } =>
159-
step.asTerm match {
160-
case Literal(IntConstant(k)) if k > 0 => strideUpTo(i, j, Expr(k))
161-
case Literal(IntConstant(k)) if k < 0 => strideDownTo(i, j, Expr(-k))
162-
case Literal(IntConstant(k)) if k == 0 => report.error("zero stride", step); '{}
163-
case _ =>
164-
report.warning(s"defaulting to foreach, can not optimise non-constant step", step)
165-
'{ val b = $body; $r.foreach(b) }
166-
}
167-
case _ =>
168-
report.warning(s"defaulting to foreach, can not optimise range expression", r)
169-
'{ val b = $body; $r.foreach(b) }
134+
def let[A, B](name: String, rhs: Expr[A])(in: Expr[A] => Expr[B])(using Quotes): Expr[B] =
135+
// Equivalent to `'{ val name = $rhs; ${in('name)} }`
136+
ValDef.let(Symbol.spliceOwner, name, rhs.asTerm)(ref => in(ref.unsafeAsExpr[A]).asTerm).unsafeAsExpr[B]
170137

171-
end fastForRangeMacro
138+
if isFunctionLiteral(rhs) then in(Expr.betaReduce(rhs))
139+
else let(name, rhs)(in)

tests/shared/src/test/scala/spire/syntax/FastForSuite.scala renamed to tests/shared/src/test/scala-3/scala/spire/syntax/FastForSuite.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,15 +99,16 @@ class FastForSuite extends munit.FunSuite {
9999
assertEquals(b.toList, List(0, 1, 2))
100100
}
101101

102-
// This test distinguishes fastFor from cfor
103-
test("doesn't capture value in closure") {
102+
test("capture value in closure") { // same behavior as cfor
104103
val b1 = collection.mutable.ArrayBuffer.empty[() => Int]
105104
fastFor(0)(_ < 3, _ + 1) { x =>
106105
b1 += (() => x)
107106
}
108107
val b2 = collection.mutable.ArrayBuffer[() => Int]()
109-
(0 until 3).foreach { x =>
110-
b2 += (() => x)
108+
var i = 0
109+
while (i < 3) {
110+
b2 += (() => i)
111+
i += 1
111112
}
112113
assertEquals(b1.map(_.apply()).toList, b2.map(_.apply()).toList)
113114
}

0 commit comments

Comments
 (0)