@@ -97,78 +97,128 @@ where
9797 let m = n << 1 ;
9898 assert ! ( x < m) ;
9999
100- // We need q and r s.t. RR/2 = qm + r, and `0 <= r < m`
101- // As R/4 < m < R/2,
102- // we have R <= q < 2R
103- // so let q = R + f
104- // RR/2 = (R + f)m + r
105- // R(R/2 - m) = fm + r
106-
107- // v = R/2 - m < R/4 < m
108- let v = ( _1 << ( U :: BITS - 1 ) ) - m;
109- let ( f, r) = v. widen_hi ( ) . checked_narrowing_div_rem ( m) . unwrap ( ) ;
110-
111- // xq < qm <= RR/2
112- // 2xq < RR
113- // 2xq = 2xR + 2xf;
114- let _2x: U = x << 1 ;
100+ // We need to compute the parameters
101+ // `q = (RR/2) / m`
102+ // `r = (RR/2) % m`
103+
104+ // Since `m` is in `(R/4, R/2)`, the quotient `q` is in `[R, 2R)`, and
105+ // it would overflow in `U` if computed directly. Instead, we compute
106+ // `f = q - R`, which is in `[0, R)`. To do so, we simply subtract `Rm`
107+ // from the dividend, which doesn't change the remainder:
108+ // `f = R(R/2 - m) / m`
109+ // `r = R(R/2 - m) % m`
110+ let dividend = ( ( _1 << ( U :: BITS - 1 ) ) - m) . widen_hi ( ) ;
111+ let ( f, r) = dividend. checked_narrowing_div_rem ( m) . unwrap ( ) ;
112+
113+ // As `x < m`, `xq < qm <= RR/2`
114+ // Thus `2xq = 2xR + 2xf` does not overflow in `U::D`.
115+ let _2x = x + x;
115116 let _2xq = _2x. widen_hi ( ) + _2x. widen_mul ( f) ;
116117 Self { m, r, _2xq }
117118 }
118119
119- /// Extract the current remainder in the range `[0, 2n)`
120+ /// Extract the current remainder `x` in the range `[0, 2n)`
120121 fn partial_remainder ( & self ) -> U {
121- // RR/2 = qm + r, 0 <= r < m
122- // 2xq = uR + v, 0 <= v < R
123- // muR = 2mxq - mv
124- // = xRR - 2xr - mv
125- // mu + (2xr + mv)/R == xR
126-
127- // 0 <= 2xq < RR
128- // R <= q < 2R
129- // 0 <= x < R/2
130- // R/4 < m < R/2
131- // 0 <= r < m
132- // 0 <= mv < mR
133- // 0 <= 2xr < rR < mR
134-
135- // 0 <= (2xr + mv)/R < 2m
136- // Add `mu` to each term to obtain:
137- // mu <= xR < mu + 2m
138-
139- // Since `0 <= 2m < R`, `xR` is the only multiple of `R` between
140- // `mu` and `m(u+2)`, so the high half of `m(u+2)` must equal `x`.
141- let _1 = U :: ONE ;
142- self . m . widen_mul ( self . _2xq . hi ( ) + ( _1 + _1) ) . hi ( )
122+ // `RR/2 = qm + r`, where `0 <= r < m`
123+ // `2xq = uR + v`, where `0 <= v < R`
124+
125+ // The goal is to extract the current value of `x` from the value `2xq`
126+ // that we actually have. A bit simplified, we could multiply it by `m`
127+ // to obtain `2xqm == 2x(RR/2 - r) == xRR - 2xr`, where `2xr < RR`.
128+ // We could just round that up to the next multiple of `RR` to get `x`,
129+ // but we can avoid having to multiply the full double-wide `2xq` by
130+ // making a couple of adjustments:
131+
132+ // First, let's only use the high half `u` for the product, and
133+ // include an additional error term due to the truncation:
134+ // `mu = xR - (2xr + mv)/R`
135+
136+ // Next, show bounds for the error term
137+ // `0 <= mv < mR` follows from `0 <= v < R`
138+ // `0 <= 2xr < mR` follows from `0 <= x < m < R/2` and `0 <= r < m`
139+ // Adding those together, we have:
140+ // `0 <= (mv + 2xr)/R < 2m`
141+ // Which also implies:
142+ // `0 < 2m - (mv + 2xr)/R <= 2m < R`
143+
144+ // For that reason, we can use `u + 2` as the factor to obtain
145+ // `m(u + 2) = xR + (2m - (mv + 2xr)/R)`
146+ // By the previous inequality, the second term fits neatly in the lower
147+ // half, so we get exactly `x` as the high half.
148+ let u = self . _2xq . hi ( ) ;
149+ let _2 = U :: ONE + U :: ONE ;
150+ self . m . widen_mul ( u + _2) . hi ( )
151+
152+ // Additionally, we should ensure that `u + 2` cannot overflow:
153+ // Since `x < m` and `2qm <= RR`,
154+ // `2xq <= 2q(m-1) <= RR - 2q`
155+ // As we also have `q > R`,
156+ // `2xq < RR - 2R`
157+ // which is sufficient.
143158 }
144159
145160 /// Replace the remainder `x` with `(x << k) - un`,
146161 /// for a suitable quotient `u`, which is returned.
162+ ///
163+ /// Requires that `k < U::BITS`.
147164 fn shift_reduce ( & mut self , k : u32 ) -> U {
148165 assert ! ( k < U :: BITS ) ;
149- // 2xq << k = aRR/2 + b;
166+
167+ // First, split the shifted value:
168+ // `2xq << k = aRR/2 + b`, where `0 <= b < RR/2`
150169 let a = self . _2xq . hi ( ) >> ( U :: BITS - 1 - k) ;
151170 let ( low, high) = ( self . _2xq << k) . lo_hi ( ) ;
152171 let b = U :: D :: from_lo_hi ( low, high & ( U :: MAX >> 1 ) ) ;
153172
173+ // Then, subtract `2anq = aqm`:
174+ // ```
154175 // (2xq << k) - aqm
155176 // = aRR/2 + b - aqm
156177 // = a(RR/2 - qm) + b
157178 // = ar + b
179+ // ```
158180 self . _2xq = a. widen_mul ( self . r ) + b;
159181 a
182+
183+ // Since `a` is at most the high half of `2xq`, we have
184+ // `a + 2 < R` (shown above, in `partial_remainder`)
185+ // Using that together with `b < RR/2` and `r < m < R/2`,
186+ // we get `(a + 2)r + b < RR`, so
187+ // `ar + b < RR - 2r = 2mq`
188+ // which shows that the new remainder still satisfies `x < m`.
160189 }
161190
191+ // NB: `word_reduce()` is just the special case `shift_reduce(U::BITS - 1)`
192+ // that optimizes especially well. The correspondence is that `a == u` and
193+ // `b == (v >> 1).widen_hi()`
194+ //
162195 /// Replace the remainder `x` with `x(R/2) - un`,
163196 /// for a suitable quotient `u`, which is returned.
164197 fn word_reduce ( & mut self ) -> U {
165- // 2xq = uR + v
166- let ( v, u) = self . _2xq . lo_hi ( ) ;
167- // xqR - uqm
198+ // To do so, we replace `2xq = uR + v` with
199+ // ```
200+ // 2 * (x(R/2) - un) * q
201+ // = xqR - 2unq
202+ // = xqR - uqm
168203 // = uRR/2 + vR/2 - uRR/2 + ur
169204 // = ur + (v/2)R
205+ // ```
206+ let ( v, u) = self . _2xq . lo_hi ( ) ;
170207 self . _2xq = u. widen_mul ( self . r ) + U :: widen_hi ( v >> 1 ) ;
171208 u
209+
210+ // Additional notes:
211+ // 1. As `v` is the low bits of `2xq`, it is even and can be halved.
212+ // 2. The new remainder is `(xr + mv/2) / R` (see below)
213+ // and since `v < R`, `r < m`, `x < m < R/2`,
214+ // that is also strictly less than `m`.
215+ // ```
216+ // (x(R/2) - un)R
217+ // = xRR/2 - (m/2)uR
218+ // = x(qm + r) - (m/2)(2xq - v)
219+ // = xqm + xr - xqm + mv/2
220+ // = xr + mv/2
221+ // ```
172222 }
173223}
174224
0 commit comments