Skip to content

Commit 222d928

Browse files
Pass Automated Testing SuitePass Automated Testing Suite
authored andcommitted
Implement PCG64 without using the uin128 type.
1 parent 29f6478 commit 222d928

File tree

2 files changed

+65
-31
lines changed

2 files changed

+65
-31
lines changed

lib/pcg.ml

Lines changed: 64 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,35 @@
55
SPDX-License-Identifier: BSD-3-Clause *)
66
open Stdint
77

8+
module U128 = struct
9+
type t = { high : uint64; low : uint64 }
10+
11+
let of_u64 high low = {high; low}
12+
13+
let one = Uint64.{high = zero; low = one}
14+
15+
let zero = Uint64.{high = zero; low = zero}
16+
17+
let ( + ) a b =
18+
match Uint64.{high = a.high + b.high; low = a.low + b.low} with
19+
| x when x.low < b.low -> {x with high = Uint64.(x.high + one)}
20+
| x -> x
21+
22+
let max32 = Uint32.(max_int |> to_uint64)
23+
let mult64 x y =
24+
let open Uint64 in
25+
let x0 = logand max32 x and y0 = logand max32 y
26+
and x1 = shift_right x 32 and y1 = shift_right y 32 in
27+
let t = shift_right (x0 * y0) 32 + x1 * y0 in
28+
{high = shift_right (logand max32 t + x0 * y1) 32 + (shift_right t 32) + x1 * y1; low = x * y}
29+
30+
let ( * ) a b = match mult64 a.low b.low with
31+
| {high;low} -> {high = Uint64.(high + a.high * b.low + a.low * b.high); low}
32+
33+
let ( ** ) a b = match mult64 a.low b with
34+
| x -> {x with high = Uint64.(x.high + a.high * b)}
35+
end
36+
837

938
module PCG64 : sig
1039
(** PCG-64 is a 128-bit implementation of O'Neill's permutation congruential
@@ -20,7 +49,7 @@ module PCG64 : sig
2049

2150
include Common.BITGEN
2251

23-
val advance : int128 -> t -> t
52+
val advance : uint64 * uint64 -> t -> t
2453
(** [advance delta] Advances the underlying RNG as if [delta] draws have been made.
2554
The returned state is that of the generator [delta] steps forward. *)
2655

@@ -29,55 +58,60 @@ module PCG64 : sig
2958
(0, bound) as well as the state of the generator advanced one step forward. *)
3059
end = struct
3160
type t = {s : setseq; ustore : uint32 option}
32-
and setseq = {state : uint128; increment : uint128}
61+
and setseq = {state : U128.t; increment : U128.t}
62+
63+
64+
let sixtythree = Uint32.of_int32 63l
65+
let multiplier = U128.of_u64 (Uint64.of_int64 2549297995355413924L)
66+
(Uint64.of_int64 4865540595714422341L)
3367

34-
let multiplier = Uint128.of_string "0x2360ed051fc65da44385df649fccf645"
35-
let sixtythree = Uint32.of_int 63
3668

3769
(* Uses the XSL-RR output function *)
38-
let output state =
39-
let v = Uint128.(shift_right state 64 |> logxor state |> to_uint64)
40-
and r = Uint128.(shift_right state 122 |> to_int) in
41-
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
42-
Uint64.(logor (shift_left v nr) (shift_right v r))
43-
70+
let output U128.{high; low} =
71+
let v = Uint64.(logxor high low) in
72+
let r = Uint64.(shift_right high 58 |> to_int) in
73+
let nr = Uint32.(of_int r |> neg |> logand sixtythree |> to_int) in
74+
Uint64.(logor (shift_left v nr) (shift_right v r))
75+
4476

4577
let next {state; increment} =
46-
let state' = Uint128.(state * multiplier + increment) in
78+
let state' = U128.(state * multiplier + increment) in
4779
output state', {state = state'; increment}
4880

4981

5082
let next_uint64 t = match next t.s with
5183
| u, s -> u, {t with s}
52-
84+
5385

5486
let next_uint32 t =
5587
match Common.next_uint32 ~next:next t.s t.ustore with
56-
| u, s, ustore -> u, {s; ustore}
88+
| u, s, ustore -> u, {s; ustore}
5789

5890

59-
let next_double t = Common.next_double ~nextu64:next_uint64 t
91+
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t
6092

6193

62-
let advance delta {s = {state; increment}; _} =
63-
let open Uint128 in
64-
let rec lcg d am ap cm cp = (* advance state using LCG method *)
65-
match d = zero, logand d one = one with
66-
| true, _ -> am * state + ap
67-
| false, true -> lcg (shift_right d 1) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
68-
| false, false -> lcg (shift_right d 1) am ap (cm * cm) (cp * (cm + one))
69-
in {s = {state = lcg (Uint128.of_int128 delta) one zero multiplier increment; increment}; ustore = None}
94+
let next_double t = Common.next_double ~nextu64:next_uint64 t
7095

7196

7297
let set_seed seed =
73-
let open Uint128 in
74-
let s = logor (shift_left (of_uint64 seed.(0)) 64) (of_uint64 seed.(1))
75-
and i = logor (shift_left (of_uint64 seed.(2)) 64) (of_uint64 seed.(3)) in
76-
let increment = logor (shift_left i 1) one in
77-
{state = (increment + s) * multiplier + increment; increment}
78-
79-
80-
let next_bounded_uint64 bound t = Common.next_bounded_uint64 bound ~nextu64:next_uint64 t
98+
let s2 = Uint64.(logor (shift_left seed.(2) 1) (shift_right seed.(3) 63)) in
99+
let s3 = Uint64.(logor (shift_left seed.(3) 1) one) in
100+
let increment = U128.of_u64 s2 s3 in
101+
let state = U128.(zero * multiplier + increment) in
102+
{state = U128.((of_u64 seed.(0) seed.(1) + state) * multiplier + increment); increment}
103+
104+
105+
let advance (d1, d0) {s = {state; increment}; _} =
106+
let open U128 in
107+
let half x = U128.{low = Uint64.(logor (shift_right x.low 1) (shift_left x.high 63));
108+
high = Uint64.(shift_right x.high 1)} in
109+
let rec lcg d am ap cm cp =
110+
match Uint64.(d.high <= zero && d.low <= zero, logand d.low one = one) with
111+
| true, _ -> am * state + ap
112+
| false, true -> lcg (half d) (am * cm) (ap * cm + cp) (cm * cm) (cp * (cm + one))
113+
| false, false -> lcg (half d) am ap (cm * cm) (cp * (cm + one))
114+
in {s = {state = lcg (of_u64 d1 d0) one zero multiplier increment; increment}; ustore = None}
81115

82116

83117
let initialize seed =

test/test_pcg.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ let test_advance _ =
77
let t = SeedSequence.initialize [Uint128.of_int 12345] |> PCG64.initialize in
88
let advance n = Seq.(iterate (fun s -> PCG64.next_uint64 s |> snd) t |> drop n |> uncons |> Option.get |> fst) in
99
assert_equal
10-
(PCG64.advance (Int128.of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
10+
(PCG64.advance Uint64.(of_int 0, of_int 100) t |> PCG64.next_uint64 |> fst |> Uint64.to_string)
1111
(advance 100 |> PCG64.next_uint64 |> fst |> Uint64.to_string)
1212
~printer:(fun x -> x)
1313

0 commit comments

Comments
 (0)