Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
/build/
/install/
mlton-*.tgz
.vscode/
224 changes: 167 additions & 57 deletions basis-library/schedulers/spork/ForkJoin.sml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ struct
datatype TokenPolicy = datatype Scheduler.TokenPolicy

val spork = Scheduler.SporkJoin.spork
val primSporkChoose = Scheduler.primSporkChoose

fun par (f: unit -> 'a, g: unit -> 'b): 'a * 'b =
spork {
Expand Down Expand Up @@ -54,26 +55,14 @@ struct
end


signature LOOP_INDEX =
sig
type idx
type t = idx

val fromInt: int -> idx
val toInt: idx -> int

val increment: idx -> idx
val midpoint: idx * idx -> idx
val equal: idx * idx -> bool
end


functor ManagedLoops (LoopIndex: LOOP_INDEX) :>
sig
val pareduce: (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val parform: (int * int) -> (int -> unit) -> unit
val seqLoop: (int * int) -> (int -> unit) -> unit
val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
end =
struct
type idx = LoopIndex.t
Expand Down Expand Up @@ -160,52 +149,26 @@ struct

fun __inline_always__ parform (lo: int, hi: int) (f: int -> unit) : unit =
reducem (fn _ => ()) () (lo, hi) f
end


functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX =
struct
type idx = WordImpl.word
type t = idx
fun __inline_always__ seqLoop (lo: int, hi: int) (f: int -> unit) : unit =
let
fun loop (i: idx, j: idx) : unit =
if LoopIndex.equal (i, j) then ()
else (__inline_always__ f (LoopIndex.toInt i); loop (LoopIndex.increment i, j))
in
loop (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi)
end

fun __inline_always__ toInt (w: idx) = __inline_always__ WordImpl.toIntX w
fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i

fun __inline_always__ midpoint (i: idx, j: idx) =
fun __inline_always__ seqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
let
(* This way is broken! *)
(* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *)

val range_size = WordImpl.+ (j, WordImpl.~ i)
val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2))
fun loop (acc: 'a) (i: idx, j: idx) : 'a =
if LoopIndex.equal (i, j) then acc
else loop (__inline_always__ combine (acc, __inline_always__ f (LoopIndex.toInt i))) (LoopIndex.increment i, j)
in
(* If using a different midpoint calculation, consider uncommenting
* the following for debugging/testing.
*)

(* if toInt i <= toInt mid andalso toInt mid <= toInt j then
()
else
( print
( "ERROR: schedulers/spork/ForkJoin.sml: bug! midpoint failure: "
^ Int.toString (toInt i)
^ " "
^ Int.toString (toInt mid)
^ " "
^ Int.toString (toInt j)
^ "\n"
)

; OS.Process.exit OS.Process.failure
); *)

mid
loop zero (LoopIndex.fromInt (Int.min (lo, hi)), LoopIndex.fromInt hi)
end

fun __inline_always__ increment (i: idx) =
WordImpl.+ (i, fromInt 1)

fun __inline_always__ equal (i: idx, j: idx) = (i = j)
end


Expand All @@ -221,11 +184,17 @@ sig
val pareduceBreakExn: (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a

val reducem: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val reduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val reducemDefault: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val parform: (int * int) -> (int -> unit) -> unit
val parformDefault: (int * int) -> (int -> unit) -> unit

val parfor: int -> (int * int) -> (int -> unit) -> unit
val alloc: int -> 'a array

val seqLoop: (int * int) -> (int -> unit) -> unit
val seqReduce: ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a

val idleTimeSoFar: unit -> Time.time
val workTimeSoFar: unit -> Time.time
val maxForkDepthSoFar: unit -> int
Expand Down Expand Up @@ -254,6 +223,11 @@ struct
val equal = op=
end)

structure Unrolled8 = UnrolledLoops(Word8)
structure Unrolled16 = UnrolledLoops(Word16)
structure Unrolled32 = UnrolledLoops(Word32)
structure Unrolled64 = UnrolledLoops(Word64)

structure Pareduce =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
Expand Down Expand Up @@ -294,8 +268,144 @@ struct
val fIntInf = LoopsInt.parform
end)

val pareduce = Pareduce.f
val pareduceBreakExn = PareduceBreakExn.f
val reducem = Reducem.f
val parform = Parform.f
structure UnrolledPareduce =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> 'a -> (int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
val fInt8 = Unrolled8.pareduce
val fInt16 = Unrolled16.pareduce
val fInt32 = Unrolled32.pareduce
val fInt64 = Unrolled64.pareduce
val fIntInf = Unrolled64.pareduce
end)

structure UnrolledPareduceBreakExn =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> 'a -> (('a -> exn) * int * 'a -> 'a) -> ('a * 'a -> 'a) -> 'a
val fInt8 = Unrolled8.pareduceBreakExn
val fInt16 = Unrolled16.pareduceBreakExn
val fInt32 = Unrolled32.pareduceBreakExn
val fInt64 = Unrolled64.pareduceBreakExn
val fIntInf = Unrolled64.pareduceBreakExn
end)

structure UnrolledReducem =
Int_ChooseFromInt (struct
type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val fInt8 = Unrolled8.reducem
val fInt16 = Unrolled16.reducem
val fInt32 = Unrolled32.reducem
val fInt64 = Unrolled64.reducem
val fIntInf = Unrolled64.reducem
end)

structure UnrolledParform =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> (int -> unit) -> unit
val fInt8 = Unrolled8.parform
val fInt16 = Unrolled16.parform
val fInt32 = Unrolled32.parform
val fInt64 = Unrolled64.parform
val fIntInf = Unrolled64.parform
end)

structure SeqLoop =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> (int -> unit) -> unit
val fInt8 = Loops8.seqLoop
val fInt16 = Loops16.seqLoop
val fInt32 = Loops32.seqLoop
val fInt64 = Loops64.seqLoop
val fIntInf = LoopsInt.seqLoop
end)

structure SeqReduce =
Int_ChooseFromInt (struct
type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val fInt8 = Loops8.seqReduce
val fInt16 = Loops16.seqReduce
val fInt32 = Loops32.seqReduce
val fInt64 = Loops64.seqReduce
val fIntInf = LoopsInt.seqReduce
end)

structure UnrolledSeqLoop =
Int_ChooseFromInt (struct
type 'a t = (int * int) -> (int -> unit) -> unit
val fInt8 = Unrolled8.seqLoop
val fInt16 = Unrolled16.seqLoop
val fInt32 = Unrolled32.seqLoop
val fInt64 = Unrolled64.seqLoop
val fIntInf = Unrolled64.seqLoop
end)

structure UnrolledSeqReduce =
Int_ChooseFromInt (struct
type 'a t = ('a * 'a -> 'a) -> 'a -> (int * int) -> (int -> 'a) -> 'a
val fInt8 = Unrolled8.seqReduce
val fInt16 = Unrolled16.seqReduce
val fInt32 = Unrolled32.seqReduce
val fInt64 = Unrolled64.seqReduce
val fIntInf = Unrolled64.seqReduce
end)

local

fun __inline_always__ unifiedReducem (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
let
fun __inline_always__ regularImpl () = __inline_always__ Reducem.f combine zero (lo, hi) f
fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledReducem.f combine zero (lo, hi) f
in
primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
end

fun unifiedParform (lo: int, hi: int) (f: int -> unit) : unit =
let
fun __inline_always__ regularImpl () = __inline_always__ Parform.f (lo, hi) f

fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledParform.f (lo, hi) f
in
primSporkChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
end

fun __inline_always__ unifiedPareduce (lo: int, hi: int) (zero: 'a) (step: int * 'a -> 'a) (combine: 'a * 'a -> 'a) : 'a =
let
fun __inline_always__ regularImpl () =
__inline_always__ Pareduce.f (lo, hi) zero step combine

fun __inline_always__ unrolledImpl () =
__inline_always__ UnrolledPareduce.f (lo, hi) zero step combine

fun __inline_always__ loopBody i = __inline_always__ step (i, zero)
in
primSporkChoose (__inline_always__ loopBody, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
end

fun __inline_always__ unifiedSeqLoop (lo: int, hi: int) (f: int -> unit) : unit =
let
fun __inline_always__ regularImpl () = __inline_always__ SeqLoop.f (lo, hi) f
fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqLoop.f (lo, hi) f
in
Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
end

fun __inline_always__ unifiedSeqReduce (combine: 'a * 'a -> 'a) (zero: 'a) (lo: int, hi: int) (f: int -> 'a) : 'a =
let
fun __inline_always__ regularImpl () = __inline_always__ SeqReduce.f combine zero (lo, hi) f
fun __inline_always__ unrolledImpl () = __inline_always__ UnrolledSeqReduce.f combine zero (lo, hi) f
in
Scheduler.primLoopChoose (__inline_always__ f, __inline_always__ unrolledImpl, __inline_always__ regularImpl)
end
in
val reducem = __inline_always__ unifiedReducem
val reduce = __inline_always__ unifiedReducem
val reducemDefault = __inline_always__ Reducem.f
val parform = __inline_always__ unifiedParform
val parformDefault = __inline_always__ Parform.f
val pareduce = __inline_always__ unifiedPareduce
val parfor = __inline_always__ ForkJoin0.parfor
val seqLoop = __inline_always__ unifiedSeqLoop
val seqReduce = __inline_always__ unifiedSeqReduce
end

val pareduceBreakExn = __inline_always__ PareduceBreakExn.f
end
58 changes: 58 additions & 0 deletions basis-library/schedulers/spork/LoopIndex.sml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
signature LOOP_INDEX =
sig
type idx
type t = idx

val fromInt: int -> idx
val toInt: idx -> int

val increment: idx -> idx
val midpoint: idx * idx -> idx
val equal: idx * idx -> bool
end


functor LoopIndexFromWord(WordImpl: WORD) :> LOOP_INDEX =
struct
type idx = WordImpl.word
type t = idx

fun __inline_always__ toInt (w: idx) = __inline_always__ WordImpl.toIntX w
fun __inline_always__ fromInt i = __inline_always__ WordImpl.fromInt i

fun __inline_always__ midpoint (i: idx, j: idx) =
let
(* This way is broken! *)
(* val mid = WordImpl.~>> (WordImpl.+ (i, j), 0w1) *)

val range_size = WordImpl.+ (j, WordImpl.~ i)
val mid = WordImpl.+ (i, WordImpl.div (range_size, WordImpl.fromInt 2))
in
(* If using a different midpoint calculation, consider uncommenting
* the following for debugging/testing.
*)

(* if toInt i <= toInt mid andalso toInt mid <= toInt j then
()
else
( print
( "ERROR: schedulers/spork/ForkJoin.sml: bug! midpoint failure: "
^ Int.toString (toInt i)
^ " "
^ Int.toString (toInt mid)
^ " "
^ Int.toString (toInt j)
^ "\n"
)

; OS.Process.exit OS.Process.failure
); *)

mid
end

fun __inline_always__ increment (i: idx) =
WordImpl.+ (i, fromInt 1)

fun __inline_always__ equal (i: idx, j: idx) = (i = j)
end
21 changes: 20 additions & 1 deletion basis-library/schedulers/spork/Scheduler.sml
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,30 @@ struct
* (exn -> 'c) (* exn seq *)
* (exn * 'd -> 'c) (* exn sync *)
-> 'c;
val primSporkChoose' =
_prim "spork_choose"
: ('u -> 'a) (* loop body *)
* (unit -> 'a) (* unrolled implementation *)
* (unit -> 'a) (* regular implementation *)
-> 'a;
val primLoopChoose' =
_prim "loop_choose"
: ('u -> 'a) (* loop body *)
* (unit -> 'a) (* unrolled implementation *)
* (unit -> 'a) (* regular implementation *)
-> 'a;

fun __inline_always__ primSporkFair (body, spwn, seq, sync, exnseq, exnsync) =
__inline_always__ primSporkFair' (body, (), spwn, (), seq, sync, exnseq, exnsync)
fun __inline_always__ primSporkKeep (body, spwn, seq, sync, exnseq, exnsync) =
__inline_always__ primSporkKeep' (body, (), spwn, (), seq, sync, exnseq, exnsync)
fun __inline_always__ primSporkGive (body, spwn, seq, sync, exnseq, exnsync) =
__inline_always__ primSporkGive' (body, (), spwn, (), seq, sync, exnseq, exnsync)

fun __inline_always__ primSporkChoose (loopBody, unrolled, regular) =
__inline_always__ primSporkChoose' (loopBody, unrolled, regular)
fun __inline_always__ primLoopChoose (loopBody, unrolled, regular) =
__inline_always__ primLoopChoose' (loopBody, unrolled, regular)

val primForkThreadAndSetData = _prim "spork_forkThreadAndSetData": Thread.t * 'a -> Thread.p;
val primForkThreadAndSetData_youngest = _prim "spork_forkThreadAndSetData_youngest": Thread.t * 'a -> Thread.p;

Expand Down Expand Up @@ -1024,6 +1041,7 @@ struct
fun __inline_always__ tryPromoteNow yo =
( Thread.atomicBegin ()
; if
(* ! Second heartbeat check *)
Heartbeat.enoughToSpawn () andalso
#maybeSpawn (sched_package ()) yo (Thread.current ())
then
Expand Down Expand Up @@ -1052,6 +1070,7 @@ struct
val (inject, project) = Universal.embed ()

fun __inline_always__ body' (): 'a =
(* ! First Hearbeat Check *)
((if not (Heartbeat.enoughToSpawn ()) then () else tryPromoteNow {youngestOptimization = true});
__inline_always__ body ())

Expand Down
Loading