From 877fb4a279cddba590e6901a5368b3680f99cb6e Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Wed, 24 Sep 2025 10:39:06 -0600 Subject: [PATCH 01/31] Add bignum.h and port exactfloat to use it. This removes the dependency on OpenSSL for its bignum library. The main benefit of this is that targetting things like WASM are much simpler since we don't need to compile OpenSSL to WASM. I didn't try as hard for performance optimization as OpenSSL (which also has a hard requirement around constant-time operation for security purposes), but benchmarks indicate addition and subtraction is generally the same and multiplication is within a factor of two. For the size of bignums we expect in S2 there shouldn't be a noticeable difference in performance. --- CMakeLists.txt | 10 +- README.md | 8 - src/s2/util/math/exactfloat/BUILD | 8 +- src/s2/util/math/exactfloat/bignum.h | 948 ++++++++++++++++ src/s2/util/math/exactfloat/bignum_test.cc | 1137 ++++++++++++++++++++ src/s2/util/math/exactfloat/exactfloat.cc | 208 +--- src/s2/util/math/exactfloat/exactfloat.h | 53 +- 7 files changed, 2160 insertions(+), 212 deletions(-) create mode 100644 src/s2/util/math/exactfloat/bignum.h create mode 100644 src/s2/util/math/exactfloat/bignum_test.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 37e4d8ca..889698ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,6 @@ add_definitions(-DABSL_MIN_LOG_LEVEL=1) if (NOT TARGET absl::base) find_package(absl REQUIRED) endif() -find_package(OpenSSL REQUIRED) # pthreads isn't used directly, but this is still required for std::thread. find_package(Threads REQUIRED) @@ -99,10 +98,6 @@ else() add_compile_options(-fsized-deallocation) endif() -# If OpenSSL is installed in a non-standard location, configure with -# something like: -# OPENSSL_ROOT_DIR=/usr/local/opt/openssl cmake .. -include_directories(${OPENSSL_INCLUDE_DIR}) if (WITH_PYTHON) include_directories(${Python3_INCLUDE_DIRS}) @@ -230,7 +225,6 @@ endif() target_link_libraries( s2 - ${OPENSSL_LIBRARIES} absl::absl_vlog_is_on absl::base absl::btree @@ -616,6 +610,7 @@ if (BUILD_TESTS) src/s2/s2shapeutil_shape_edge_id_test.cc src/s2/s2shapeutil_visit_crossing_edge_pairs_test.cc src/s2/s2text_format_test.cc + src/s2/util/math/exactfloat/bignum_test.cc src/s2/s2validation_query_test.cc src/s2/s2wedge_relations_test.cc src/s2/s2winding_operation_test.cc @@ -645,7 +640,8 @@ if (BUILD_TESTS) absl::status absl::strings absl::synchronization - gmock_main) + gmock_main + crypto) add_test(${test} ${test}) endforeach() endif() diff --git a/README.md b/README.md index 9a6b22de..d39c4a54 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,6 @@ This issue may require revision of boringssl or exactfloat. * [Abseil](https://github.com/abseil/abseil-cpp) >= LTS [`20250814`](https://github.com/abseil/abseil-cpp/releases/tag/20250814.1) (standard library extensions) -* [OpenSSL](https://github.com/openssl/openssl) (for its bignum library) * [googletest testing framework >= 1.10](https://github.com/google/googletest) (to build tests and example programs, optional) @@ -139,11 +138,6 @@ Disable building of shared libraries with `-DBUILD_SHARED_LIBS=OFF`. Enable the python interface with `-DWITH_PYTHON=ON`. -If OpenSSL is installed in a non-standard location set `OPENSSL_ROOT_DIR` -before running configure, for example on macOS: -``` -OPENSSL_ROOT_DIR=/opt/homebrew/Cellar/openssl@3/3.1.0 cmake -DCMAKE_PREFIX_PATH=/opt/homebrew -DCMAKE_CXX_STANDARD=17 -``` ## Installing @@ -218,8 +212,6 @@ python -m build The resulting wheel will be in the `dist` directory. -> If OpenSSL is in a non-standard location make sure to set `OPENSSL_ROOT_DIR`; -> see above for more information. ## Other S2 implementations diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 6621eaec..bfab4c2d 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -3,12 +3,16 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "exactfloat", srcs = ["exactfloat.cc"], - hdrs = ["exactfloat.h"], + hdrs = ["exactfloat.h", "bignum.h"], deps = [ "//s2/base:port", "//s2/base:logging", "@abseil-cpp//absl/log:log", "@abseil-cpp//absl/log:absl_check", - "@boringssl//:crypto", + "@abseil-cpp//absl/strings:str_format", + "@abseil-cpp//absl/container:inlined_vector", + "@abseil-cpp//absl/numeric:bits", + "@abseil-cpp//absl/numeric:int128", + "@abseil-cpp//absl/strings:ascii", ], ) diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h new file mode 100644 index 00000000..e53e0d46 --- /dev/null +++ b/src/s2/util/math/exactfloat/bignum.h @@ -0,0 +1,948 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" + +namespace internal { + +// Most of the STL cannot be overloaded per the spec, so we need to roll our own +// wrappers that will also work with absl::int128. + +template +constexpr bool IsInt = std::numeric_limits::is_integer; + +template +constexpr bool IsSigned() { + if constexpr (std::is_same_v) { + return false; + } + + if constexpr (std::is_same_v) { + return true; + } + + return std::is_signed_v; +} + +template +constexpr auto InferUnsigned() { + if constexpr (std::is_same_v || + std::is_same_v) { + return absl::uint128{}; + } else { + return std::make_unsigned_t{}; + } +} + +template +using MakeUnsigned = decltype(InferUnsigned()); + +} // namespace internal + +class Bignum { + private: + using Bigit = uint64_t; + + static constexpr int kKaratsubaThreshold = 32; + static constexpr int kBigitBits = std::numeric_limits::digits; + + public: + Bignum() = default; + + // Constructs a bignum from an integral value (signed or unsigned). + template >> + explicit Bignum(T value) { + using UT = internal::MakeUnsigned; + + if (value == 0) { + return; + } + + sign_ = +1; + if constexpr (internal::IsSigned()) { + sign_ = (value < 0) ? -1 : +1; + } + + // Get magnitude of value, handle minimum value of T cleanly. + UT mag = static_cast(value); + if constexpr (internal::IsSigned()) { + if (value < 0) { + mag = UT(0) - mag; + } + } + + // Pack the magnitude into bigits. + if constexpr (std::numeric_limits::digits <= kBigitBits) { + bigits_.push_back(static_cast(mag)); + } else { + while (mag) { + bigits_.push_back(static_cast(mag)); + mag >>= kBigitBits; + } + } + } + + // Constructs a bignum from an ASCII string containing decimal digits. + // + // The input string must only have an optional leading +/- and decimal digits. + // Any other characters will yield std::nullopt. + static std::optional FromString(absl::string_view s) { + // We can fit ~10^19 into a uint64_t. + constexpr int kMaxChunkDigits = 19; + + // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the + // sake of simplicity. This isn't the fastest algorithm, being quadratic in + // the number of chunks the input has. If we use divide and conquer approach + // or an FFT based multiply we could probably make this ~O(n^1.5) or + // semi-linear. + + // Precomputed powers of 10. + static constexpr uint64_t kPow10[20] = {1ull, + 10ull, + 100ull, + 1000ull, + 10000ull, + 100000ull, + 1000000ull, + 10000000ull, + 100000000ull, + 1000000000ull, + 10000000000ull, + 100000000000ull, + 1000000000000ull, + 10000000000000ull, + 100000000000000ull, + 1000000000000000ull, + 10000000000000000ull, + 100000000000000000ull, + 1000000000000000000ull, + 10000000000000000000ull}; + + Bignum out; + if (s.empty()) { + return out; + } + + // Reserve space for bigits. + out.bigits_.reserve((s.size() + kMaxChunkDigits - 1) / kMaxChunkDigits); + + int sign = +1; + uint64_t chunk = 0; + int clen = 0; + + // Finish processing the current chunk. + auto FlushChunk = [&]() { + if (clen) { + out.MulAddSmall(kPow10[clen], chunk); + chunk = 0; + clen = 0; + } + }; + + // Consume optional +/- at the front. + int start = 0; + if ((s[0] == '+' || s[0] == '-')) { + sign = (s[0] == '-') ? -1 : +1; + ++start; + } + + bool seen_digit = false; + for (char c : s.substr(start)) { + if (!absl::ascii_isdigit(c)) { + return std::nullopt; + } + + // Accumulate digit into the local 64-bit chunk. Skip leading zeros. + uint64_t digit = static_cast(c - '0'); + if (!seen_digit && digit == 0) { + continue; + } + seen_digit = true; + + chunk = 10 * chunk + digit; + ++clen; + + if (clen == kMaxChunkDigits) { + FlushChunk(); + } + } + FlushChunk(); + + out.NormalizeSign(sign); + return out; + } + + // Formats the bignum as a decimal integer into an abseil sink. + template + friend void AbslStringify(Sink& sink, const Bignum& b) { + if (b.zero()) { + sink.Append("0"); + return; + } + + // Sign + if (b.negative()) { + sink.Append("-"); + } + + // Work on a copy of the magnitude. + Bignum copy = b; + copy.sign_ = 1; + + // Repeatedly divide and modulo by 10^19 to get decimal chunks. + static constexpr uint64_t kBase = 10000000000000000000ull; + absl::InlinedVector chunks; + + while (!copy.zero()) { + absl::uint128 rem = 0; + for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { + absl::uint128 acc = (rem << 64) + copy.bigits_[i]; + uint64_t quot = static_cast(acc / kBase); + rem = acc - absl::uint128(quot) * kBase; + copy.bigits_[i] = quot; + } + + copy.Normalize(); + chunks.push_back(static_cast(rem)); + } + ABSL_DCHECK(!chunks.empty()); + + // Emit most significant chunk without zero padding. + absl::Format(&sink, "%d", chunks.back()); + + // Emit remaining chunks as fixed-width 19-digit zero-padded blocks. + for (int i = static_cast(chunks.size()) - 2; i >= 0; --i) { + absl::Format(&sink, "%019d", chunks[i]); + } + } + + friend std::ostream& operator<<(std::ostream& os, const Bignum& b) { + return os << absl::StrFormat("%v", b); + } + + friend std::ostream& operator<<(std::ostream& os, + const std::optional& b) { + if (!b) { + return os << "[nullopt]"; + } + return os << *b; + } + + // Returns true if bignum can be stored in T without truncation or overflow. + template >> + bool Compatible() const { + using UT = internal::MakeUnsigned; + + if (sign_ == 0) { + return true; + } + + // Maximum number of bits that could fit in the output type. + constexpr int kTBitWidth = std::numeric_limits::digits; + constexpr int kMaxBigits = (kTBitWidth + (kBigitBits - 1)) / kBigitBits; + + // Fast reject if the bignum couldn't conceivably fit. + if (bigits_.size() > kMaxBigits) { + return false; + } + + // Unsigned type T can hold the value iff the value is non-negative and the + // bitwidth is <= the maximum bit width of the type. + if constexpr (!internal::IsSigned()) { + if (negative()) { + return false; + } + return BitWidth() <= kTBitWidth; + } + + // T is signed and our bignum isn't zero. + ABSL_DCHECK(internal::IsSigned() && !zero()); + + if (positive()) { + return BitWidth() <= (kTBitWidth - 1); + } else /* negative() */ { + // Magnitude must fit in negative value. If the value is negative and the + // same bit width as the output type, the only valid value is -2^(k-1). + if (BitWidth() == kTBitWidth) { + return IsPow2(kTBitWidth - 1); + } + return BitWidth() < kTBitWidth; + } + } + + // Cast to an integral type T unconditionally. Use Compatible() or + // Convert to perform conversion with bounds checking. + template >> + T Cast() const { + using UT = internal::MakeUnsigned; + + constexpr int kTBitWidth = std::numeric_limits::digits; + + if (empty()) { + return 0; + } + + // Grab the bottom bits into an unsigned value. + UT residue = 0; + for (size_t i = 0; i < bigits_.size(); ++i) { + const int shift = i * kBigitBits; + if (shift >= kTBitWidth) { + break; + } + + const int room = kTBitWidth - shift; + UT chunk = static_cast(bigits_[i]); + if (room < kBigitBits && room < std::numeric_limits::digits) { + chunk &= (UT(1) << room) - UT(1); + } + residue |= (chunk << shift); + } + + // Compute two's complement of the residue if value is negative. + if (negative()) { + residue = UT(0) - residue; + } + + return static_cast(residue); + } + + // Casts the value to the given type if it fits, otherwise std::nullopt. + template >> + std::optional Convert() const { + if (!Compatible()) { + return std::nullopt; + } + return Cast(); + } + + // // Creates a Bignum by parsing a decimal representation from a string. + // explicit Bignum(absl::string_view dec) { ParseDecimal_(dec); } + + // Returns the number of bits required to represent the bignum. + int BitWidth() const { + ABSL_DCHECK(Normalized()); + if (empty()) { + return 0; + } + + // Bit width is the bits in the least significant bigits + bit width of the + // most significant word. + const int msw_width = (kBigitBits - absl::countl_zero(bigits_.back())); + const int lsw_width = (bigits_.size() - 1) * kBigitBits; + return msw_width + lsw_width; + } + + // Returns the number of consecutive 0 bits in the value, starting from the + // least significant bit. + int CountrZero() const { + if (zero()) { + return 0; + } + + int nzero = 0; + for (Bigit bigit : bigits_) { + if (bigit == 0) { + nzero += kBigitBits; + } else { + nzero += absl::countr_zero(bigit); + break; + } + } + return nzero; + } + + // Returns true if the n-th bit of the number's magnitude is set. + bool Bit(int nbit) const { + ABSL_DCHECK_GE(nbit, 0); + if (zero()) { + return false; + } + + const int digit = nbit / kBigitBits; + const int shift = nbit % kBigitBits; + + if (digit >= size()) { + return false; + } + + return ((bigits_[digit] >> shift) & 0x1) != 0; + } + + // Clears this bignum and sets it to zero. + Bignum& SetZero() { + sign_ = 0; + bigits_.clear(); + return *this; + } + + // Unconditionally makes the sign of this bignum negative. + Bignum& SetNegative() { + sign_ = -1; + return *this; + } + + // Unconditionally makes the sign of this bignum positive. + Bignum& SetPositive() { + sign_ = +1; + return *this; + } + + // Unconditionally set the sign of this bignum to match the sign of the + // argument. If the argument is zero, set the bignum to zero. + Bignum& SetSign(int sign) { + if (sign == 0) { + return SetZero(); + } + + if (sign < 0) { + return SetNegative(); + } + return SetPositive(); + } + + // Returns true if the number is zero. + bool zero() const { // + return sign_ == 0; + } + + // Returns true if the number is greater than zero. + bool positive() const { // + return sign_ > 0; + } + + // Returns true if the number is less than zero. + bool negative() const { // + return sign_ < 0; + } + + // Returns true if the number is odd (least significant bit is 1). + bool odd() const { return Bit(0); } + + // Returns true if the number is even (least significant bit is 0). + bool even() const { return !odd(); } + + bool operator==(const Bignum& b) const { + return sign_ == b.sign_ && bigits_ == b.bigits_; + } + + bool operator!=(const Bignum& b) const { return !(*this == b); } + + bool operator<(const Bignum& b) const { return Compare(b) < 0; } + + bool operator<=(const Bignum& b) const { return Compare(b) <= 0; } + + bool operator>(const Bignum& b) const { return Compare(b) > 0; } + + bool operator>=(const Bignum& b) const { return Compare(b) >= 0; } + + Bignum operator+() const { return *this; } + + Bignum operator-() const { + Bignum result = *this; + result.sign_ = -result.sign_; + return result; + } + + Bignum& operator+=(const Bignum& b) { + if (b.zero()) { + return *this; + } + + if (zero()) { + *this = b; + return *this; + } + + if (sign_ == b.sign_) { + // Same sign: + // (+a) + (+b) == +(a + b) + // (-a) + (-b) == -(a + b) + AddAbs(b); + } else { + if (CmpAbs(b) >= 0) { + // |a| >= |b|, so a - b is same sign as a. + SubAbsGe(b); + NormalizeSign(sign_); + } else { + // |a| < |b|, so a - b is same sign as b. + SubAbsLt(b); + NormalizeSign(b.sign_); + } + } + + return *this; + } + + Bignum& operator-=(const Bignum& b) { + if (this == &b) { + bigits_.clear(); + sign_ = 0; + return *this; + } + + if (b.zero()) { + return *this; + } + + if (zero()) { + return *this = -b; + } + + if (sign_ != b.sign_) { + AddAbs(b); + } else { + if (CmpAbs(b) >= 0) { + SubAbsGe(b); + NormalizeSign(sign_); + } else { + SubAbsLt(b); + NormalizeSign(-sign_); + } + } + + return *this; + } + + // Left-shift the bignum by nbit. + Bignum& operator<<=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (zero() || nbit == 0) { + return *this; + } + + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; + + // First, handle the whole-bigit shift by inserting zeros. + bigits_.insert(bigits_.begin(), nbigit, 0); + + // Then, handle the within-bigit shift, if any. + if (nrem != 0) { + Bigit carry = 0; + for (size_t i = 0; i < bigits_.size(); ++i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val << nrem) | carry; + carry = old_val >> (kBigitBits - nrem); + } + + if (carry) { + bigits_.push_back(carry); + } + } + + return *this; + } + + // Right-shift the bignum by nbit. + Bignum& operator>>=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (zero() || nbit == 0) { + return *this; + } + + // Shifting by more than the bit width results in zero. + if (nbit >= BitWidth()) { + bigits_.clear(); + sign_ = 0; + return *this; + } + + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; + + // First, handle the whole-bigit shift by removing bigits. + bigits_.erase(bigits_.begin(), bigits_.begin() + nbigit); + + // Then, handle the within-bigit shift, if any. + if (nrem != 0) { + Bigit carry = 0; + for (int i = static_cast(bigits_.size()) - 1; i >= 0; --i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val >> nrem) | carry; + carry = old_val << (kBigitBits - nrem); + } + } + + // Result might be smaller or zero, so normalize. + NormalizeSign(sign_); + return *this; + } + + Bignum& operator*=(const Bignum& b) { + if (zero() || b.zero()) { + bigits_.clear(); + sign_ = 0; + return *this; + } + + const int new_sign = sign_ * b.sign_; + bigits_ = MulAbs(bigits_, b.bigits_); + NormalizeSign(new_sign); + return *this; + } + + // Raise this value to the given power, which must be non-negative. + Bignum Pow(int32_t pow) const { + ABSL_DCHECK_GE(pow, 0); + + // Anything to the zero-th power is 1 (including zero). + if (pow == 0) { + return Bignum(1); + } + + if (zero()) { + return Bignum(0); + } + + if (*this == Bignum(1)) { + return Bignum(1); + } + + if (*this == Bignum(-1)) { + return (pow % 2 != 0) ? Bignum(-1) : Bignum(1); + } + + // Core algorithm: Exponentiation by squaring. + Bignum result(1); + Bignum base = *this; // A mutable copy of the base. + uint32_t upow = static_cast(pow); + + while (upow > 0) { + if (upow & 1) { // If current exponent bit is 1, multiply into result. + result *= base; + } + base *= base; + upow >>= 1; + } + + return result; + } + + friend Bignum operator*(Bignum a, const Bignum& b) { return a *= b; } + + friend Bignum operator+(Bignum a, const Bignum& b) { return a += b; } + + friend Bignum operator-(Bignum a, const Bignum& b) { return a -= b; } + + friend Bignum operator<<(Bignum a, int nbit) { return a <<= nbit; } + + friend Bignum operator>>(Bignum a, int nbit) { return a >>= nbit; } + + private: + // Construct a Bignum from bigits and an optional sign bit. + explicit Bignum(absl::Span bigits, int sign = +1) { + if (bigits.empty()) { + return; + } + + bigits_.assign(bigits.begin(), bigits.end()); + NormalizeSign(sign); + } + + // Returns the number of bigits in this bignum. + int size() const { // + return bigits_.size(); + } + + // Returns true if this value has no digits. + bool empty() const { // + return bigits_.empty(); + } + + // Compare to another bignum, returns -1, 0, +1. + int Compare(const Bignum& b) const { + if (sign_ != b.sign_) { + return sign_ < b.sign_ ? -1 : 1; + } + + // Signs are equal, are they both zero? + if (sign_ == 0) { + return 0; + } + + // Signs are equal and non-zero, compare magnitude. + return positive() ? CmpAbs(b) : -CmpAbs(b); + } + + // Compute value = value * mul + add where mul, add ≤ 10^19. We can accumulate + // using a 128 bit integer in a single pass over the bigits for small terms. + void MulAddSmall(uint64_t mul, uint64_t add) { + absl::uint128 carry = add; + for (size_t i = 0, n = bigits_.size(); i < n; ++i) { + absl::uint128 prod = absl::uint128(bigits_[i]) * mul + carry; + bigits_[i] = absl::Uint128Low64(prod); + carry = absl::Uint128High64(prod); + } + + if (carry != 0) { + bigits_.push_back(absl::Uint128Low64(carry)); + } + } + + // Multiplies two bigit operands. Uses either simple quadratic multiplication + // (SimpleMulAbs) or divide-and-conquer multiplication (KaratsubaMulAbs) based + // on the size of the operands. + static absl::InlinedVector MulAbs( + absl::Span a, absl::Span b) { + // Fast path for single-bigit multiplication. + if (a.size() == 1 && b.size() == 1) { + absl::uint128 prod = absl::uint128(a[0]) * b[0]; + const uint64_t lo = absl::Uint128Low64(prod); + const uint64_t hi = absl::Uint128High64(prod); + if (hi == 0) { + return {lo}; + } + return {lo, hi}; + } + + if (a.size() < kKaratsubaThreshold || b.size() < kKaratsubaThreshold) { + return SimpleMulAbs(a, b); + } + return KaratsubaMulAbs(a, b); + } + + // Performs simple quadratic long multiplication between two sets of bigits. + static absl::InlinedVector SimpleMulAbs( + absl::Span a, absl::Span b) { + if (a.empty() || b.empty()) { + return {}; + } + + absl::InlinedVector result(a.size() + b.size(), 0); + for (size_t i = 0; i < a.size(); ++i) { + if (a[i] == 0) { + continue; + } + + absl::uint128 carry = 0; + for (size_t j = 0; j < b.size(); ++j) { + absl::uint128 prod = absl::uint128(a[i]) * b[j] + result[i + j] + carry; + result[i + j] = absl::Uint128Low64(prod); + carry = absl::Uint128High64(prod); + } + + // Propagate final carry. This can ripple through multiple bigits. + for (size_t k = i + b.size(); carry != 0 && k < result.size(); ++k) { + absl::uint128 sum = absl::uint128(result[k]) + carry; + result[k] = absl::Uint128Low64(sum); + carry = absl::Uint128High64(sum); + } + ABSL_DCHECK_EQ(carry, 0); // Result vector must be large enough. + } + + // Normalize vector by removing leading zeros. + while (!result.empty() && result.back() == 0) { + result.pop_back(); + } + return result; + } + + // Repeatedly divides a multiplication in half and recurses, stitching + // results back together to get the final result. + static absl::InlinedVector KaratsubaMulAbs( + absl::Span a, absl::Span b) { + // Base case is handled by MulAbs dispatcher, so we only handle recursion. + const size_t n = std::max(a.size(), b.size()); + const size_t m = (n + 1) / 2; + + const Bignum a0(a.subspan(0, std::min(m, a.size()))); + const Bignum a1(a.size() > m ? a.subspan(m) : absl::Span()); + const Bignum b0(b.subspan(0, std::min(m, b.size()))); + const Bignum b1(b.size() > m ? b.subspan(m) : absl::Span()); + + Bignum z2; + z2.bigits_ = MulAbs(a1.bigits_, b1.bigits_); + z2.NormalizeSign(+1); + + Bignum z0; + z0.bigits_ = MulAbs(a0.bigits_, b0.bigits_); + z0.NormalizeSign(+1); + + Bignum z1; + z1.bigits_ = MulAbs((a0 + a1).bigits_, (b0 + b1).bigits_); + z1.NormalizeSign(+1); + + z1 -= z2; + z1 -= z0; + + // Recombine: result = (z2 << 2*m) + (z1 << m) + z0 + z2 <<= (2 * m * kBigitBits); + z1 <<= (m * kBigitBits); + + Bignum result = z2 + z1 + z0; + return result.bigits_; + } + + // Drop leading zero bigits. + void Normalize() { + while (!empty() && bigits_.back() == 0) { + bigits_.pop_back(); + } + + if (empty()) { + sign_ = 0; + } + } + + // Drop leading zero bigits and canonicalize sign. + void NormalizeSign(int sign) { + Normalize(); + sign_ = empty() ? 0 : sign; + } + + // Returns true if the bignum is in normal form (no extra leading zeros). + bool Normalized() const { // + return bigits_.empty() || bigits_.back() != 0; + } + + // Returns true if the bignum is the given power of two. + bool IsPow2(int pow2) const { + const int bigits = pow2 / kBigitBits; + if (bigits_.size() != bigits + 1) { + return false; + } + + // Verify lower words are zero. + for (int i = 0; i < bigits; ++i) { + if (bigits_[i] != 0) { + return false; + } + } + + // Check final word is power of two. + pow2 -= bigits * kBigitBits; + ABSL_DCHECK_LT(pow2, kBigitBits); + return bigits_.back() == (Bigit(1) << pow2); + } + + // Compares magnitude with another bignum, returning -1, 0, or +1. + int CmpAbs(const Bignum& b) const; + + // Adds another bignum to this bignum in place. + void AddAbs(const Bignum& b); + + // In-place subtraction: *this = |*this| - |b|, assuming |*this| >= |b|. + void SubAbsGe(const Bignum& b); + + // In-place subtraction: *this = |b| - |*this|, assuming |*this| < |b|. + void SubAbsLt(const Bignum& b); + + absl::InlinedVector bigits_; + char sign_ = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation Details +//////////////////////////////////////////////////////////////////////////////// + +inline int Bignum::CmpAbs(const Bignum& b) const { + if (size() != b.size()) { + return size() < b.size() ? -1 : +1; + } + + for (int i = size() - 1; i >= 0; --i) { + if (bigits_[i] != b.bigits_[i]) { + return bigits_[i] < b.bigits_[i] ? -1 : +1; + } + } + + return 0; +} + +inline void Bignum::AddAbs(const Bignum& b) { + // Grow if needed. + const bool a_longer = size() > b.size(); + const size_t min_size = std::min(size(), b.size()); + const size_t max_size = std::max(size(), b.size()); + bigits_.resize(max_size, 0); + + // Add common parts. + absl::uint128 sum; + absl::uint128 carry = 0; + for (size_t i = 0; i < min_size; ++i) { + sum = absl::uint128(bigits_[i]) + b.bigits_[i] + carry; + bigits_[i] = absl::Uint128Low64(sum); + carry = absl::Uint128High64(sum); + } + + // Propagate carry through the longer operand. + const auto* longer = a_longer ? this : &b; + for (size_t i = min_size; i < max_size; ++i) { + sum = absl::uint128(longer->bigits_[i]) + carry; + bigits_[i] = absl::Uint128Low64(sum); + carry = absl::Uint128High64(sum); + } + + if (carry) { + bigits_.push_back(absl::Uint128Low64(carry)); + } +} + +inline void Bignum::SubAbsGe(const Bignum& b) { + ABSL_DCHECK_GE(CmpAbs(b), 0); + uint64_t borrow = 0; + + size_t i = 0; + for (; i < b.size(); ++i) { + const uint64_t d1 = bigits_[i]; + const uint64_t d2 = b.bigits_[i]; + const uint64_t diff = d1 - d2 - borrow; + borrow = (d1 < d2) || (borrow && d1 == d2); + bigits_[i] = diff; + } + + for (; borrow && i < bigits_.size(); ++i) { + borrow = (bigits_[i] == 0); + bigits_[i]--; + } + ABSL_DCHECK(!borrow); + Normalize(); +} + +inline void Bignum::SubAbsLt(const Bignum& b) { + ABSL_DCHECK_LT(CmpAbs(b), 0); + uint64_t borrow = 0; + const size_t n_this = bigits_.size(); + const size_t n_b = b.size(); + bigits_.resize(n_b); + + size_t i = 0; + for (; i < n_this; ++i) { + const uint64_t d1 = b.bigits_[i]; + const uint64_t d2 = bigits_[i]; + const uint64_t diff = d1 - d2 - borrow; + borrow = (d1 < d2) || (borrow && d1 == d2); + bigits_[i] = diff; + } + + for (; i < n_b; ++i) { + const uint64_t d1 = b.bigits_[i]; + const uint64_t diff = d1 - borrow; + borrow = (borrow && d1 == 0); + bigits_[i] = diff; + } + ABSL_DCHECK(!borrow); + Normalize(); +} diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc new file mode 100644 index 00000000..f6838d89 --- /dev/null +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -0,0 +1,1137 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "s2/util/math/exactfloat/bignum.h" + +#include +#include +#include +#include + +#if 0 +#include "absl/base/no_destructor.h" +#include "absl/strings/string_view.h" +#include "benchmark/benchmark.h" +#include "openssl/bn.h" +#include "openssl/crypto.h" +#endif + +#include "gtest/gtest.h" + +const uint64_t u8max = std::numeric_limits::max(); +const uint64_t u16max = std::numeric_limits::max(); +const uint64_t u32max = std::numeric_limits::max(); +const uint64_t u64max = std::numeric_limits::max(); + +const int64_t i8max = std::numeric_limits::max(); +const int64_t i16max = std::numeric_limits::max(); +const int64_t i32max = std::numeric_limits::max(); +const int64_t i64max = std::numeric_limits::max(); + +const int64_t i8min = std::numeric_limits::min(); +const int64_t i16min = std::numeric_limits::min(); +const int64_t i32min = std::numeric_limits::min(); +const int64_t i64min = std::numeric_limits::min(); + +// To reduce duplication. +inline auto Bn(absl::string_view str) { return Bignum::FromString(str); }; + +TEST(BignumTest, ZeroInputsNormalizeToZero) { + EXPECT_EQ(Bn(""), Bignum(0)); + EXPECT_EQ(Bn("+"), Bignum(0)); + EXPECT_EQ(Bn("-"), Bignum(0)); + EXPECT_EQ(Bn("0"), Bignum(0)); + EXPECT_EQ(Bn("-0"), Bignum(0)); + EXPECT_EQ(Bn("000000"), Bignum(0)); + EXPECT_EQ(Bn("-000000"), Bignum(0)); + EXPECT_EQ(Bn("-00000000000000000000"), Bignum(0)); + EXPECT_EQ(Bn("+00000000000000000000"), Bignum(0)); +} + +TEST(BignumTest, BasicSmallNumbers) { + EXPECT_EQ(Bn("42"), Bignum(42)); + EXPECT_EQ(Bn("-17"), Bignum(-17)); + EXPECT_EQ(Bn("+123"), Bignum(123)); +} + +TEST(BignumTest, LeadingZeros) { + EXPECT_EQ(Bn("0000042"), Bignum(42)); + EXPECT_EQ(Bn("-00017"), Bignum(-17)); + EXPECT_EQ(Bn("+00007"), Bignum(7)); + + // Larger bignums. + // 10^19 and 10^19 - 1, near chunk boundaries. + EXPECT_EQ(Bn("100000000000000000"), Bn("000100000000000000000")); + EXPECT_EQ(Bn("99999999999999999"), Bn("000099999999999999999")); + + // 2^64 - 1 cross-check against integral constructor. + EXPECT_EQ(Bn("18446744073709551615"), Bignum(18446744073709551615ull)); + + // 2^64 cross-check with leading zeros variant via string constructor. + EXPECT_EQ(Bn("18446744073709551616"), Bn("00018446744073709551616")); + EXPECT_EQ(Bn("-18446744073709551616"), Bn("-00018446744073709551616")); + + // Multi-digit bignums. + // 2^80 + EXPECT_EQ(Bn("1208925819614629174706176"), + Bn("0001208925819614629174706176")); + EXPECT_EQ(Bn("-1208925819614629174706176"), + Bn("-0001208925819614629174706176")); + + // 2^128 + EXPECT_EQ(Bn("340282366920938463463374607431768211456"), + Bn("000340282366920938463463374607431768211456")); + EXPECT_EQ(Bn("-340282366920938463463374607431768211456"), + Bn("-000340282366920938463463374607431768211456")); +} + +TEST(BignumTest, MultipleSignsBeforeDigitsCausesFailure) { + EXPECT_EQ(Bn("++123"), std::nullopt); + EXPECT_EQ(Bn("--42"), std::nullopt); + EXPECT_EQ(Bn("+-9"), std::nullopt); + EXPECT_EQ(Bn("-+9"), std::nullopt); +} + +TEST(BignumTest, SignInWrongPlaceCausesFailure) { + EXPECT_EQ(Bn("123-"), std::nullopt); + EXPECT_EQ(Bn("456+"), std::nullopt); + EXPECT_EQ(Bn("-789+"), std::nullopt); + EXPECT_EQ(Bn("+314-"), std::nullopt); +} + +TEST(BignumTest, ZeroAlwaysCompatible) { + const Bignum zero(0); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.Compatible()); +} + +TEST(BignumTest, ZeroAlwaysCastsToZero) { + Bignum zero; + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); +} + +TEST(BignumTest, NegativeOnlyCompatibleSigned) { + const Bignum small_neg(-1); + EXPECT_FALSE(small_neg.Compatible()); + EXPECT_FALSE(small_neg.Compatible()); + EXPECT_FALSE(small_neg.Compatible()); + EXPECT_FALSE(small_neg.Compatible()); + EXPECT_FALSE(small_neg.Compatible()); + + EXPECT_TRUE(small_neg.Compatible()); + EXPECT_TRUE(small_neg.Compatible()); + EXPECT_TRUE(small_neg.Compatible()); + EXPECT_TRUE(small_neg.Compatible()); + EXPECT_TRUE(small_neg.Compatible()); +} + +TEST(BignumTest, CompatibleUnsignedBoundsChecks) { + const Bignum bn_u8max(u8max); + const Bignum bn_u8over(u8max + 1); + EXPECT_TRUE(bn_u8max.Compatible()); + EXPECT_TRUE(bn_u8max.Compatible()); + EXPECT_TRUE(bn_u8max.Compatible()); + EXPECT_FALSE(bn_u8over.Compatible()); + EXPECT_TRUE(bn_u8over.Compatible()); + EXPECT_TRUE(bn_u8over.Compatible()); + + const Bignum bn_u16max(u16max); + const Bignum bn_u16over(u16max + 1); + EXPECT_FALSE(bn_u16max.Compatible()); + EXPECT_TRUE(bn_u16max.Compatible()); + EXPECT_TRUE(bn_u16max.Compatible()); + EXPECT_FALSE(bn_u16over.Compatible()); + EXPECT_FALSE(bn_u16over.Compatible()); + EXPECT_TRUE(bn_u16over.Compatible()); + + const Bignum bn_u32max(u32max); + const Bignum bn_u32over(u32max + 1); + EXPECT_FALSE(bn_u32max.Compatible()); + EXPECT_FALSE(bn_u32max.Compatible()); + EXPECT_TRUE(bn_u32max.Compatible()); + EXPECT_FALSE(bn_u32over.Compatible()); + EXPECT_FALSE(bn_u32over.Compatible()); + EXPECT_FALSE(bn_u32over.Compatible()); + + const Bignum bn_u64max(u64max); + EXPECT_TRUE(bn_u64max.Compatible()); + + // 2^64, need to use string constructor. + Bignum bn0 = *Bn("18446744073709551616"); + EXPECT_FALSE(bn0.Compatible()); + EXPECT_TRUE(bn0.Compatible()); + + // (2^128 - 1) fits in absl::uint128. + Bignum bn1 = *Bn("340282366920938463463374607431768211455"); + EXPECT_TRUE(bn1.Compatible()); + + // 2^128 does not fit in absl::uint128. + Bignum bn2 = *Bn("340282366920938463463374607431768211456"); + EXPECT_FALSE(bn2.Compatible()); +} + +TEST(BignumTest, CompatibleSignedBoundsChecks) { + const Bignum bn_i8max(i8max); + const Bignum bn_i8over(i8max + 1); + EXPECT_TRUE(bn_i8max.Compatible()); + EXPECT_TRUE(bn_i8max.Compatible()); + EXPECT_TRUE(bn_i8max.Compatible()); + EXPECT_FALSE(bn_i8over.Compatible()); + EXPECT_TRUE(bn_i8over.Compatible()); + EXPECT_TRUE(bn_i8over.Compatible()); + + const Bignum bn_i16max(i16max); + const Bignum bn_i16over(i16max + 1); + EXPECT_FALSE(bn_i16max.Compatible()); + EXPECT_TRUE(bn_i16max.Compatible()); + EXPECT_TRUE(bn_i16max.Compatible()); + EXPECT_FALSE(bn_i16over.Compatible()); + EXPECT_FALSE(bn_i16over.Compatible()); + EXPECT_TRUE(bn_i16over.Compatible()); + + const Bignum bn_i32max(i32max); + const Bignum bn_i32over(i32max + 1); + EXPECT_FALSE(bn_i32max.Compatible()); + EXPECT_FALSE(bn_i32max.Compatible()); + EXPECT_TRUE(bn_i32max.Compatible()); + EXPECT_FALSE(bn_i32over.Compatible()); + EXPECT_FALSE(bn_i32over.Compatible()); + EXPECT_FALSE(bn_i32over.Compatible()); + + Bignum bn_i64max(i64max); + EXPECT_TRUE(bn_i64max.Compatible()); + + // 2^63, need to use string constructor. + Bignum bn0 = *Bn("9223372036854775808"); + EXPECT_FALSE(bn0.Compatible()); + + const Bignum bn_i8min(i8min); + const Bignum bn_i8under(i8min - 1); + EXPECT_TRUE(bn_i8min.Compatible()); + EXPECT_TRUE(bn_i8min.Compatible()); + EXPECT_TRUE(bn_i8min.Compatible()); + EXPECT_FALSE(bn_i8under.Compatible()); + EXPECT_TRUE(bn_i8under.Compatible()); + EXPECT_TRUE(bn_i8under.Compatible()); + + const Bignum bn_i16min(i16min); + const Bignum bn_i16under(i16min - 1); + EXPECT_FALSE(bn_i16min.Compatible()); + EXPECT_TRUE(bn_i16min.Compatible()); + EXPECT_TRUE(bn_i16min.Compatible()); + EXPECT_FALSE(bn_i16under.Compatible()); + EXPECT_FALSE(bn_i16under.Compatible()); + EXPECT_TRUE(bn_i16under.Compatible()); + + const Bignum bn_i32min(i32min); + const Bignum bn_i32under(i32min - 1); + EXPECT_FALSE(bn_i32min.Compatible()); + EXPECT_FALSE(bn_i32min.Compatible()); + EXPECT_TRUE(bn_i32min.Compatible()); + EXPECT_FALSE(bn_i32under.Compatible()); + EXPECT_FALSE(bn_i32under.Compatible()); + EXPECT_FALSE(bn_i32under.Compatible()); + + Bignum bn_i64min(i64min); + EXPECT_TRUE(bn_i64min.Compatible()); + + // -(2^63) - 1 doesn't fit in int64_t. + Bignum b0 = *Bn("-9223372036854775809"); + EXPECT_FALSE(b0.Compatible()); + + // Exact min and max of signed 128. + Bignum bn_s128min = *Bn("-170141183460469231731687303715884105728"); + Bignum bn_s128max = *Bn("170141183460469231731687303715884105727"); + EXPECT_TRUE(bn_s128min.Compatible()); + EXPECT_TRUE(bn_s128max.Compatible()); + + // +2^127 does not fit in signed 128, but does in unsigned 128. + Bignum bn1 = *Bn("170141183460469231731687303715884105728"); + EXPECT_FALSE(bn1.Compatible()); + EXPECT_TRUE(bn1.Compatible()); + + // Below min: -(2^127) - 1 should not fit. + Bignum bn2 = *Bn("-170141183460469231731687303715884105729"); + EXPECT_FALSE(bn2.Compatible()); +} + +TEST(BignumTest, CompatibleBasicSanityChecks) { + Bignum pos42(42); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.Compatible()); + + Bignum neg42(-42); + EXPECT_TRUE(neg42.Compatible()); + EXPECT_FALSE(neg42.Compatible()); + EXPECT_TRUE(neg42.Compatible()); + EXPECT_FALSE(neg42.Compatible()); + EXPECT_TRUE(neg42.Compatible()); + EXPECT_FALSE(neg42.Compatible()); + EXPECT_TRUE(neg42.Compatible()); + EXPECT_FALSE(neg42.Compatible()); + EXPECT_TRUE(neg42.Compatible()); + EXPECT_FALSE(neg42.Compatible()); +} + +TEST(BignumTest, UnsignedCasting) { + EXPECT_EQ(Bignum(300).Cast(), static_cast(300)); + + // Negative to unsigned -> maximum value. + Bignum bn1(-1); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + + // Big values via decimal: 2^64, 2^128-1, 2^128 + Bignum bn2 = *Bn("18446744073709551616"); + EXPECT_EQ(bn2.Cast(), 0); + + // 2^128 - 1 -> lower 64 bits = 2^64 - 1 + Bignum bn3 = *Bn("340282366920938463463374607431768211455"); + EXPECT_EQ(bn3.Cast(), std::numeric_limits::max()); + + // 2^128 -> lower 64 bits = 0 + Bignum bn4 = *Bn("340282366920938463463374607431768211456"); + EXPECT_EQ(bn4.Cast(), 0); +} + +TEST(BignumTest, SignedCasting) { + // In-range positives stay the same. + Bignum bn0(127); + EXPECT_EQ(bn0.Cast(), 127); + + // Positive overflow wraps into negative range. + Bignum bn1(128); + EXPECT_EQ(bn1.Cast(), -128); + + // In-range negatives stay the same. + Bignum bn2(-128); + EXPECT_EQ(bn2.Cast(), -128); + + // Negative overflow wraps into positive range. + Bignum bn3(-129); + EXPECT_EQ(bn3.Cast(), 127); + + // +2^63 over to -2^63 in signed int64. + Bignum bn4 = *Bn("9223372036854775808"); + EXPECT_EQ(bn4.Cast(), std::numeric_limits::min()); + + // +2^64 - 1 casts to -1. + Bignum bn5 = *Bn("18446744073709551615"); + EXPECT_EQ(bn5.Cast(), -1); + + // -(2^63) - 1 casts to 2^63 - 3 + Bignum bn6 = *Bn("-9223372036854775809"); + EXPECT_EQ(bn6.Cast(), std::numeric_limits::max()); +} + +TEST(BignumTest, CastingLargeResidues) { + // 2^80 + 0x1234 -> low 64 bits should be 0x1234. + Bignum bn0 = *Bn("1208925819614629174710836"); + EXPECT_EQ(bn0.Cast(), 0x1234); + EXPECT_EQ(bn0.Cast(), 0x1234); + + // -(2^80 + 1) -> low 64 bits = 0xFFFFFFFFFFFFFFFF; signed = -1 + Bignum bn1 = *Bn("-1208925819614629174706177"); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), -1); +} + +TEST(BignumTest, AbslUint128Casting) { + Bignum neg1(-1); + EXPECT_EQ(neg1.Cast(), ~absl::uint128(0)); + + // 2^128 -> low 128 bits == 0 + Bignum bn1 = *Bn("340282366920938463463374607431768211456"); + EXPECT_EQ(bn1.Cast(), absl::uint128(0)); + + // 2^200 + 5 -> low 128 bits == 5 + Bignum bn2 = + *Bn("1606938044258990275541962092341162602522202993782792835301381"); + EXPECT_EQ(bn2.Cast(), absl::uint128(5)); +} + +TEST(BignumTest, AbslInt128Casting) { + const absl::int128 two127 = absl::int128(1) << 127; + + // +2^127 -> wraps to -2^127 + Bignum bn0 = *Bn("170141183460469231731687303715884105728"); + EXPECT_EQ(bn0.Cast(), 0 - two127); + + // -(2^127) - 1 -> wraps to +2^127 - 1 + Bignum bn1 = *Bn("-170141183460469231731687303715884105729"); + EXPECT_EQ(bn1.Cast(), two127 - 1); + + // 2^200 + 5 -> low 128 bits == 5 + Bignum bn2 = + *Bn("1606938044258990275541962092341162602522202993782792835301381"); + EXPECT_EQ(bn2.Cast(), absl::int128(5)); + + // -(2^200 + 5) -> low 128 bits == -5 + Bignum bn3 = + *Bn("-1606938044258990275541962092341162602522202993782792835301381"); + EXPECT_EQ(bn3.Cast(), absl::int128(-5)); +} + +TEST(BignumTest, UnaryOperators) { + EXPECT_EQ(+Bignum(0), Bignum(0)); + EXPECT_EQ(-Bignum(0), Bignum(0)); + EXPECT_EQ(+Bignum(+42), Bignum(+42)); + EXPECT_EQ(+Bignum(-42), Bignum(-42)); + EXPECT_EQ(-Bignum(+42), Bignum(-42)); + EXPECT_EQ(-Bignum(-17), Bignum(+17)); +} + +TEST(BignumTest, Addition) { + // Basic combinations of signs + EXPECT_EQ(Bignum(+5) + Bignum(+3), Bignum(+8)); + EXPECT_EQ(Bignum(-5) + Bignum(-3), Bignum(-8)); + EXPECT_EQ(Bignum(+5) + Bignum(-3), Bignum(+2)); + EXPECT_EQ(Bignum(-5) + Bignum(+3), Bignum(-2)); + + // Identity and additive inverse + EXPECT_EQ(Bignum(42) + Bignum(0), Bignum(42)); + EXPECT_EQ(Bignum(0) + Bignum(42), Bignum(42)); + EXPECT_EQ(Bignum(5) + Bignum(-5), Bignum(0)); + + // Carry propagation + const auto bn_u64max = Bignum(u64max); + EXPECT_EQ(bn_u64max + Bignum(1), *Bn("18446744073709551616")); + EXPECT_EQ(bn_u64max + bn_u64max, *Bn("36893488147419103230")); + + // Aliasing (x += x) + Bignum a = bn_u64max; + a += a; + EXPECT_EQ(a, *Bn("36893488147419103230")); +} + +TEST(BignumTest, Subtraction) { + // Basic combinations of signs + EXPECT_EQ(Bignum(+5) - Bignum(+3), Bignum(+2)); + EXPECT_EQ(Bignum(+3) - Bignum(+5), Bignum(-2)); + EXPECT_EQ(Bignum(-5) - Bignum(-3), Bignum(-2)); + EXPECT_EQ(Bignum(+5) - Bignum(-3), Bignum(+8)); + EXPECT_EQ(Bignum(-5) - Bignum(+3), Bignum(-8)); + + // Identity and subtracting to zero + EXPECT_EQ(Bignum(42) - Bignum(0), Bignum(42)); + EXPECT_EQ(Bignum(0) - Bignum(42), Bignum(-42)); + EXPECT_EQ(Bignum(42) - Bignum(42), Bignum(0)); + + // Borrow propagation + const auto bn_u64max = Bignum(u64max); + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_EQ(two_pow_64 - Bignum(1), bn_u64max); + + // Aliasing (x -= x) + Bignum a(100); + a -= a; + EXPECT_EQ(a, Bignum(0)); +} + +TEST(BignumTest, MixedOperations) { + Bignum a(10), b(20), c(-5); + EXPECT_EQ((a + b) - a, b); + EXPECT_EQ((b - a) + a, b); + EXPECT_EQ(a + c, Bignum(5)); + EXPECT_EQ(c - a, Bignum(-15)); +} + +TEST(BignumTest, LargeNumberArithmetic) { + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + const auto two_pow_128 = *Bn("340282366920938463463374607431768211456"); + const auto two_pow_64 = *Bn("18446744073709551616"); + + // Test multi-bigit carry propagation: (2^128 - 1) + 1 = 2^128 + EXPECT_EQ(two_pow_128_minus_1 + Bignum(1), two_pow_128); + + // Test multi-bigit borrow propagation: 2^128 - 1 = (2^128 - 1) + EXPECT_EQ(two_pow_128 - Bignum(1), two_pow_128_minus_1); + + // Subtraction resulting in a sign change with large numbers. + const auto neg_two_pow_128_minus_1 = + *Bn("-340282366920938463463374607431768211455"); + EXPECT_EQ(Bignum(1) - two_pow_128, neg_two_pow_128_minus_1); + + // Addition of large numbers with different signs (triggers subtraction). + EXPECT_EQ(two_pow_128 + Bignum(-1), two_pow_128_minus_1); + + // Subtraction of large numbers with different signs (triggers addition). + EXPECT_EQ(two_pow_128_minus_1 - Bignum(-1), two_pow_128); + + // Add two different large positive numbers: 2^128 + 2^64 + const auto sum_128_64 = *Bn("340282366920938463481821351505477763072"); + EXPECT_EQ(two_pow_128 + two_pow_64, sum_128_64); + + // Subtract two different large positive numbers: 2^128 - 2^64 + const auto diff_128_64 = *Bn("340282366920938463444927863358058659840"); + EXPECT_EQ(two_pow_128 - two_pow_64, diff_128_64); +} + +TEST(BignumTest, LeftShift) { + EXPECT_EQ((Bignum(1) << 0), Bignum(1)); + EXPECT_EQ((Bignum(1) << 1), Bignum(2)); + EXPECT_EQ((Bignum(1) << 63), *Bn("9223372036854775808")); + EXPECT_EQ((Bignum(1) << 64), *Bn("18446744073709551616")); + EXPECT_EQ((Bignum(-1) << 64), *Bn("-18446744073709551616")); + + const auto bn_u64max = Bignum(u64max); + const auto two_pow_128_minus_two_pow_64 = + *Bn("340282366920938463444927863358058659840"); + EXPECT_EQ((bn_u64max << 64), two_pow_128_minus_two_pow_64); + + Bignum a(5); + a <<= 2; + EXPECT_EQ(a, Bignum(20)); + + // Shifting zero or by zero amount. + EXPECT_EQ((Bignum(0) << 100), Bignum(0)); + EXPECT_EQ((Bignum(123) << 0), Bignum(123)); +} + +TEST(BignumTest, RightShift) { + EXPECT_EQ((Bignum(8) >> 0), Bignum(8)); + EXPECT_EQ((Bignum(8) >> 3), Bignum(1)); + EXPECT_EQ((Bignum(8) >> 4), Bignum(0)); + EXPECT_EQ((Bignum(7) >> 2), Bignum(1)); + EXPECT_EQ((Bignum(-7) >> 2), Bignum(-1)); + + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_EQ((two_pow_64 >> 1), *Bn("9223372036854775808")); + EXPECT_EQ((two_pow_64 >> 64), Bignum(1)); + EXPECT_EQ((two_pow_64 >> 65), Bignum(0)); + + const auto u64max = Bignum(std::numeric_limits::max()); + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + EXPECT_EQ((two_pow_128_minus_1 >> 64), u64max); + + const auto two_pow_65_minus_1 = *Bn("36893488147419103231"); + EXPECT_EQ((two_pow_65_minus_1 >> 63), Bignum(3)); + + Bignum b(20); + b >>= 2; + EXPECT_EQ(b, Bignum(5)); + + // Shifting zero or by zero amount. + EXPECT_EQ((Bignum(0) >> 100), Bignum(0)); + EXPECT_EQ((Bignum(123) >> 0), Bignum(123)); + + // Shifting to zero. + EXPECT_EQ((Bignum(100) >> 100), Bignum(0)); +} + +TEST(BignumTest, Multiplication) { + // Zero + EXPECT_EQ(Bignum(123) * Bignum(0), Bignum(0)); + EXPECT_EQ(Bignum(0) * Bignum(456), Bignum(0)); + + // Identity + EXPECT_EQ(Bignum(123) * Bignum(1), Bignum(123)); + EXPECT_EQ(Bignum(1) * Bignum(456), Bignum(456)); + EXPECT_EQ(Bignum(-123) * Bignum(1), Bignum(-123)); + + // Signs + EXPECT_EQ(Bignum(10) * Bignum(20), Bignum(200)); + EXPECT_EQ(Bignum(-10) * Bignum(20), Bignum(-200)); + EXPECT_EQ(Bignum(10) * Bignum(-20), Bignum(-200)); + EXPECT_EQ(Bignum(-10) * Bignum(-20), Bignum(200)); + + // Simple carry + const auto bn_u32max = Bignum(u32max); + EXPECT_EQ(bn_u32max * Bignum(2), *Bn("8589934590")); + + // 1x1 bigit fast path + const auto bn_u64max = Bignum(u64max); + EXPECT_EQ(Bignum(2) * bn_u64max, *Bn("36893488147419103230")); + + // 1xN bigit multiplication + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + const auto res_1xN = *Bn("680564733841876926926749214863536422910"); + EXPECT_EQ(two_pow_128_minus_1 * Bignum(2), res_1xN); + + // Check that aliasing doesn't cause problems. + Bignum a(100); + a *= a; + EXPECT_EQ(a, Bignum(10000)); + + Bignum b = *Bn("10000000000000000000"); // > 64 bits + b *= b; + EXPECT_EQ(b, *Bn("100000000000000000000000000000000000000")); + + // Karatsuba threshold test + // (2^128 - 1) * (2^128 - 1) = 2^256 - 2*2^128 + 1 + // This should trigger karatsuba if the threshold is low enough. + EXPECT_EQ(two_pow_128_minus_1 * two_pow_128_minus_1, + *Bn("115792089237316195423570985008687907852589419931798687112530" + "834793049593217025")); + + // Karatsuba with uneven operands + EXPECT_EQ(two_pow_128_minus_1 * bn_u64max, + *Bn("6277101735386680763495507056286727952620534092958556749825")); +} + +TEST(BignumTest, CountrZero) { + EXPECT_EQ(Bignum(0).CountrZero(), 0); + EXPECT_EQ(Bignum(1).CountrZero(), 0); + EXPECT_EQ(Bignum(7).CountrZero(), 0); + EXPECT_EQ(Bignum(-7).CountrZero(), 0); + + EXPECT_EQ(Bignum(2).CountrZero(), 1); + EXPECT_EQ(Bignum(8).CountrZero(), 3); + EXPECT_EQ(Bignum(10).CountrZero(), 1); // 0b1010 + EXPECT_EQ(Bignum(12).CountrZero(), 2); // 0b1100 + + auto two_pow_64 = Bignum(1) << 64; + EXPECT_EQ(two_pow_64.CountrZero(), 64); + + auto large_shifted = Bignum(6) << 100; // 0b110 << 100 + EXPECT_EQ(large_shifted.CountrZero(), 101); + + auto neg_large_shifted = Bignum(-5) << 200; + EXPECT_EQ(neg_large_shifted.CountrZero(), 200); +} + +TEST(BignumTest, Bit) { + EXPECT_FALSE(Bignum(0).Bit(0)); + EXPECT_FALSE(Bignum(0).Bit(100)); + + // 5 = 0b101 + Bignum five(5); + EXPECT_TRUE(five.Bit(0)); + EXPECT_FALSE(five.Bit(1)); + EXPECT_TRUE(five.Bit(2)); + EXPECT_FALSE(five.Bit(3)); + + // Negative numbers should test the magnitude. + Bignum neg_five(-5); + EXPECT_TRUE(neg_five.Bit(0)); + EXPECT_FALSE(neg_five.Bit(1)); + EXPECT_TRUE(neg_five.Bit(2)); + + // Test edges of and across bigits. + Bignum high_bit_63 = Bignum(1) << 63; + EXPECT_FALSE(high_bit_63.Bit(62)); + EXPECT_TRUE(high_bit_63.Bit(63)); + EXPECT_FALSE(high_bit_63.Bit(64)); + + Bignum cross_bigit = (Bignum(1) << 100) + Bignum(1); + EXPECT_TRUE(cross_bigit.Bit(0)); + EXPECT_TRUE(cross_bigit.Bit(100)); + EXPECT_FALSE(cross_bigit.Bit(50)); + EXPECT_FALSE(cross_bigit.Bit(1000)); +} + +TEST(BignumTest, Pow) { + // Edge cases + EXPECT_EQ(Bignum(0).Pow(0), Bignum(1)); + EXPECT_EQ(Bignum(123).Pow(0), Bignum(1)); + EXPECT_EQ(Bignum(0).Pow(123), Bignum(0)); + EXPECT_EQ(Bignum(1).Pow(12345), Bignum(1)); + + // Negative base + EXPECT_EQ(Bignum(-1).Pow(2), Bignum(1)); + EXPECT_EQ(Bignum(-1).Pow(3), Bignum(-1)); + EXPECT_EQ(Bignum(-2).Pow(2), Bignum(4)); + EXPECT_EQ(Bignum(-2).Pow(3), Bignum(-8)); + + // Basic powers + EXPECT_EQ(Bignum(2).Pow(10), Bignum(1024)); + EXPECT_EQ(Bignum(3).Pow(5), Bignum(243)); + EXPECT_EQ(Bignum(10).Pow(18), *Bn("1000000000000000000")); + + // Large exponent + Bignum two_pow_100 = Bignum(1) << 100; + EXPECT_EQ(Bignum(2).Pow(100), two_pow_100); + + // Large base + Bignum ten_pow_19 = *Bn("10000000000000000000"); + Bignum ten_pow_38 = *Bn("100000000000000000000000000000000000000"); + EXPECT_EQ(ten_pow_19.Pow(2), ten_pow_38); +} + +TEST(BignumTest, SetZero) { + Bignum a(123); + a.SetZero(); + EXPECT_TRUE(a.zero()); + + Bignum b(-456); + b.SetZero(); + EXPECT_EQ(b, Bignum(0)); +} + +TEST(BignumTest, SetNegativeSetPositive) { + Bignum a(42); + a.SetNegative(); + EXPECT_TRUE(a.negative()); + EXPECT_EQ(a, Bignum(-42)); + + a.SetPositive(); + EXPECT_TRUE(a.positive()); + EXPECT_EQ(a, Bignum(42)); +} + +TEST(BignumTest, SetSign) { + Bignum a(99); + a.SetSign(-10); // any negative + EXPECT_EQ(a, Bignum(-99)); + + a.SetSign(5); // any positive + EXPECT_EQ(a, Bignum(99)); + + a.SetSign(0); + EXPECT_TRUE(a.zero()); +} + +TEST(BignumTest, Comparisons) { + EXPECT_EQ(*Bn("123"), Bignum(123)); + EXPECT_EQ(*Bn("-123"), Bignum(-123)); + EXPECT_NE(*Bn("123"), Bignum(-123)); + EXPECT_EQ(Bignum(0), Bignum(0)); + EXPECT_NE(Bignum(0), Bignum(1)); + + // Positive vs Positive + EXPECT_LT(Bignum(100), Bignum(200)); + EXPECT_GT(Bignum(200), Bignum(100)); + EXPECT_LE(Bignum(100), Bignum(200)); + EXPECT_GE(Bignum(200), Bignum(100)); + + // Negative vs Negative + EXPECT_LT(Bignum(-200), Bignum(-100)); + EXPECT_GT(Bignum(-100), Bignum(-200)); + EXPECT_GE(Bignum(-100), Bignum(-200)); + EXPECT_LE(Bignum(-200), Bignum(-100)); + + // Positive vs Negative + EXPECT_LT(Bignum(-10), Bignum(10)); + EXPECT_GT(Bignum(10), Bignum(-10)); + + // Zero + EXPECT_LT(Bignum(-1), Bignum(0)); + EXPECT_LT(Bignum(0), Bignum(1)); + EXPECT_GT(Bignum(0), Bignum(-1)); + EXPECT_GT(Bignum(1), Bignum(0)); + + // Multi-bigit + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_LT(Bignum(0), two_pow_64); + EXPECT_GT(two_pow_64, Bignum(0)); + EXPECT_LT(Bignum(-1), two_pow_64); + EXPECT_GT(two_pow_64, Bignum(-1)); + + EXPECT_LE(Bignum(100), Bignum(200)); + EXPECT_LE(Bignum(100), Bignum(100)); + EXPECT_LE(Bignum(-200), Bignum(-100)); + EXPECT_LE(Bignum(-100), Bignum(-100)); + EXPECT_GT(Bignum(200), Bignum(100)); + + EXPECT_GE(Bignum(200), Bignum(100)); + EXPECT_GE(Bignum(100), Bignum(100)); + EXPECT_GE(Bignum(-100), Bignum(-200)); + EXPECT_GE(Bignum(-100), Bignum(-100)); + EXPECT_LT(Bignum(100), Bignum(200)); + + EXPECT_LE(Bignum(0), Bignum(0)); + EXPECT_GE(Bignum(0), Bignum(0)); +} + +// TODO: Enable once benchmark is integrated. +#if 0 + +// RAII wrapper for OpenSSL BIGNUM +class OpenSSLBignum { + public: + OpenSSLBignum() : bn_(BN_new()) {} + + // Construct from a decimal number in a string. + explicit OpenSSLBignum(const absl::string_view& decimal) : bn_(BN_new()) { + BN_dec2bn(&bn_, decimal.data()); + } + + explicit OpenSSLBignum(uint64_t value) : bn_(BN_new()) { + BN_set_word(bn_, value); + } + + ~OpenSSLBignum() { BN_free(bn_); } + + OpenSSLBignum(OpenSSLBignum&& other) noexcept : bn_(other.bn_) { + other.bn_ = nullptr; + } + + OpenSSLBignum& operator=(OpenSSLBignum&& other) noexcept { + if (this != &other) { + BN_free(bn_); + bn_ = other.bn_; + other.bn_ = nullptr; + } + return *this; + } + + OpenSSLBignum(const OpenSSLBignum& other) : bn_(BN_dup(other.bn_)) {} + + OpenSSLBignum& operator=(const OpenSSLBignum& other) { + if (this != &other) { + BN_copy(bn_, other.bn_); + } + return *this; + } + + BIGNUM* get() const { return bn_; } + + private: + BIGNUM* bn_; +}; + +// Power of two for fast modulo. +const int kRandomBignumCount = 128; + +static std::vector GenerateRandomNumbers(int bits) { + std::vector numbers; + std::mt19937_64 rng(42); // Fixed seed for reproducibility + + for (int i = 0; i < kRandomBignumCount; ++i) { + std::string num; + + // Generate approximately `bits` worth of decimal digits + int decimal_digits = (bits * 3) / 10; // log10(2^bits) ≈ bits * 0.301 + + // First digit can't be zero + std::uniform_int_distribution first_digit(1, 9); + num += std::to_string(first_digit(rng)); + + std::uniform_int_distribution digit(0, 9); + for (int j = 1; j < decimal_digits; ++j) { + num += std::to_string(digit(rng)); + } + + numbers.push_back(num); + } + + return numbers; +} + +// Basic correctness test to ensure OpenSSL integration is working +TEST(BignumTestBenchmarkTest, OpenSSLIntegration) { + OpenSSLBignum a(123); + OpenSSLBignum b(456); + OpenSSLBignum result; + + BN_add(result.get(), a.get(), b.get()); + + char* str = BN_bn2dec(result.get()); + EXPECT_STREQ(str, "579"); + OPENSSL_free(str); +} + +TEST(BignumTestBenchmarkTest, ResultsMatch) { + // Test that and OpenSSL produce the same results + const Bignum w_a(12345); + const Bignum w_b(67890); + Bignum w_result = w_a + w_b; + + const OpenSSLBignum ssl_a(12345); + const OpenSSLBignum ssl_b(67890); + OpenSSLBignum ssl_result; + BN_add(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string w_str = absl::StrFormat("%v", w_result); + + EXPECT_EQ(w_str, std::string(ssl_str)); + OPENSSL_free(ssl_str); +} + +const std::vector& SmallNumbers() { + static absl::NoDestructor> numbers( // + GenerateRandomNumbers(64)); + return *numbers; +} + +const std::vector& MediumNumbers() { + static absl::NoDestructor> numbers( // + GenerateRandomNumbers(256)); + return *numbers; +} + +const std::vector& LargeNumbers() { + static absl::NoDestructor> numbers( // + GenerateRandomNumbers(1024)); + return *numbers; +} + +const std::vector& HugeNumbers() { + static absl::NoDestructor> numbers( // + GenerateRandomNumbers(4096)); + return *numbers; +} + +const std::vector& MegaNumbers() { + static absl::NoDestructor> numbers( // + GenerateRandomNumbers(18000)); + return *numbers; +} + +template +void BignumBinaryOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + BinaryOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.push_back(*Bignum::FromString(str)); + } + + Bignum result; + size_t idx = 0; + for (auto _ : state) { + const Bignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const Bignum& b = numbers[(idx + 1) % kRandomBignumCount]; + result = op(a, b); + benchmark::DoNotOptimize(result); + ++idx; + } +} + +void BignumPowBenchmark(benchmark::State& state, + const std::vector& number_strings, + int exponent) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.push_back(*Bignum::FromString(str)); + } + + Bignum result; + size_t idx = 0; + for (auto _ : state) { + const Bignum& base = numbers[(idx + 0) % kRandomBignumCount]; + result = base.Pow(exponent); + benchmark::DoNotOptimize(result); + ++idx; + } +} + +template +void OpenSSLBinaryOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + BinaryOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const OpenSSLBignum& b = numbers[(idx + 1) % kRandomBignumCount]; + op(result.get(), a.get(), b.get()); + benchmark::DoNotOptimize(result.get()); + ++idx; + } +} + +template +void OpenSSLMulOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + MulOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + BN_CTX* ctx = BN_CTX_new(); + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const OpenSSLBignum& b = numbers[(idx + 1) % kRandomBignumCount]; + op(result.get(), a.get(), b.get(), ctx); + benchmark::DoNotOptimize(result.get()); + ++idx; + } + + BN_CTX_free(ctx); +} + +void OpenSSLPowBenchmark(benchmark::State& state, + const std::vector& number_strings, + int exponent) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + const OpenSSLBignum exp(exponent); + BN_CTX* ctx = BN_CTX_new(); + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& base = numbers[(idx + 0) % kRandomBignumCount]; + BN_exp(result.get(), base.get(), exp.get(), ctx); + benchmark::DoNotOptimize(result.get()); + ++idx; + } + + BN_CTX_free(ctx); +} + +void BM_Bignum_AddSmall(benchmark::State& state) { + BignumBinaryOpBenchmark(state, SmallNumbers(), std::plus{}); +} +BENCHMARK(BM_Bignum_AddSmall); + +void BM_Bignum_AddMedium(benchmark::State& state) { + BignumBinaryOpBenchmark(state, MediumNumbers(), std::plus{}); +} +BENCHMARK(BM_Bignum_AddMedium); + +void BM_Bignum_AddLarge(benchmark::State& state) { + BignumBinaryOpBenchmark(state, LargeNumbers(), std::plus{}); +} +BENCHMARK(BM_Bignum_AddLarge); + +void BM_Bignum_AddHuge(benchmark::State& state) { + BignumBinaryOpBenchmark(state, HugeNumbers(), std::plus{}); +} +BENCHMARK(BM_Bignum_AddHuge); + +void BM_Bignum_AddMega(benchmark::State& state) { + BignumBinaryOpBenchmark(state, MegaNumbers(), std::plus{}); +} +BENCHMARK(BM_Bignum_AddMega); + +void BM_OpenSSL_AddSmall(benchmark::State& state) { + OpenSSLBinaryOpBenchmark(state, SmallNumbers(), BN_add); +} +BENCHMARK(BM_OpenSSL_AddSmall); + +void BM_OpenSSL_AddMedium(benchmark::State& state) { + OpenSSLBinaryOpBenchmark(state, MediumNumbers(), BN_add); +} +BENCHMARK(BM_OpenSSL_AddMedium); + +void BM_OpenSSL_AddLarge(benchmark::State& state) { + OpenSSLBinaryOpBenchmark(state, LargeNumbers(), BN_add); +} +BENCHMARK(BM_OpenSSL_AddLarge); + +void BM_OpenSSL_AddHuge(benchmark::State& state) { + OpenSSLBinaryOpBenchmark(state, HugeNumbers(), BN_add); +} +BENCHMARK(BM_OpenSSL_AddHuge); + +void BM_OpenSSL_AddMega(benchmark::State& state) { + OpenSSLBinaryOpBenchmark(state, MegaNumbers(), BN_add); +} +BENCHMARK(BM_OpenSSL_AddMega); + +void BM_Bignum_MulSmall(benchmark::State& state) { + BignumBinaryOpBenchmark(state, SmallNumbers(), std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulSmall); + +void BM_Bignum_MulMedium(benchmark::State& state) { + BignumBinaryOpBenchmark(state, MediumNumbers(), std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulMedium); + +void BM_Bignum_MulLarge(benchmark::State& state) { + BignumBinaryOpBenchmark(state, LargeNumbers(), std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulLarge); + +void BM_Bignum_MulHuge(benchmark::State& state) { + BignumBinaryOpBenchmark(state, HugeNumbers(), std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulHuge); + +void BM_Bignum_MulMega(benchmark::State& state) { + BignumBinaryOpBenchmark(state, MegaNumbers(), std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulMega); + +void BM_OpenSSL_MulSmall(benchmark::State& state) { + OpenSSLMulOpBenchmark(state, SmallNumbers(), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulSmall); + +void BM_OpenSSL_MulMedium(benchmark::State& state) { + OpenSSLMulOpBenchmark(state, MediumNumbers(), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulMedium); + +void BM_OpenSSL_MulLarge(benchmark::State& state) { + OpenSSLMulOpBenchmark(state, LargeNumbers(), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulLarge); + +void BM_OpenSSL_MulHuge(benchmark::State& state) { + OpenSSLMulOpBenchmark(state, HugeNumbers(), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulHuge); + +void BM_OpenSSL_MulMega(benchmark::State& state) { + OpenSSLMulOpBenchmark(state, MegaNumbers(), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulMega); + +void BM_Bignum_PowSmall(benchmark::State& state) { + BignumPowBenchmark(state, SmallNumbers(), 20); +} +BENCHMARK(BM_Bignum_PowSmall); + +void BM_Bignum_PowMedium(benchmark::State& state) { + BignumPowBenchmark(state, MediumNumbers(), 10); +} +BENCHMARK(BM_Bignum_PowMedium); + +void BM_OpenSSL_PowSmall(benchmark::State& state) { + OpenSSLPowBenchmark(state, SmallNumbers(), 20); +} +BENCHMARK(BM_OpenSSL_PowSmall); + +void BM_OpenSSL_PowMedium(benchmark::State& state) { + OpenSSLPowBenchmark(state, MediumNumbers(), 10); +} +BENCHMARK(BM_OpenSSL_PowMedium); +#endif diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 71af9208..9010b29d 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -24,14 +24,10 @@ #include #include -#include -#include // for OPENSSL_free - #include "absl/base/macros.h" #include "absl/container/fixed_array.h" #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" -// Used by !defined(OPENSSL_IS_BORINGSSL). #include "absl/numeric/bits.h" // IWYU pragma: keep #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -47,83 +43,6 @@ static_assert(ExactFloat::kMaxExp <= INT_MAX / 2 && ExactFloat::kMinExp - ExactFloat::kMaxPrec >= INT_MIN / 2, "exactfloat exponent might overflow"); -// We define a few simple extensions to the OpenSSL's BIGNUM interface. -// In some cases these depend on BIGNUM internal fields, so they might -// require tweaking if the BIGNUM implementation changes significantly. -// These are just thin wrappers for BoringSSL. - -#ifdef OPENSSL_IS_BORINGSSL - -inline static void BN_ext_set_uint64(BIGNUM* bn, uint64_t v) { - ABSL_CHECK(BN_set_u64(bn, v)); -} - -// Return the absolute value of a BIGNUM as a 64-bit unsigned integer. -// Requires that BIGNUM fits into 64 bits. -inline static uint64_t BN_ext_get_uint64(const BIGNUM* bn) { - uint64_t u64; - if (!BN_get_u64(bn, &u64)) { - ABSL_DCHECK(false) << "BN has " << BN_num_bits(bn) << " bits"; - return 0; - } - return u64; -} - -static int BN_ext_count_low_zero_bits(const BIGNUM* bn) { - return BN_count_low_zero_bits(bn); -} - -#else // !defined(OPENSSL_IS_BORINGSSL) - -// Set a BIGNUM to the given unsigned 64-bit value. -inline static void BN_ext_set_uint64(BIGNUM* bn, uint64_t v) { -#if BN_BITS2 == 64 - ABSL_CHECK(BN_set_word(bn, v)); -#else - static_assert(BN_BITS2 == 32, "at least 32 bit openssl build needed"); - ABSL_CHECK(BN_set_word(bn, static_cast(v >> 32))); - ABSL_CHECK(BN_lshift(bn, bn, 32)); - ABSL_CHECK(BN_add_word(bn, static_cast(v))); -#endif -} - -// Return the absolute value of a BIGNUM as a 64-bit unsigned integer. -// Requires that BIGNUM fits into 64 bits. -inline static uint64_t BN_ext_get_uint64(const BIGNUM* bn) { - ABSL_DCHECK_LE(BN_num_bytes(bn), sizeof(uint64_t)); -#if BN_BITS2 == 64 - return BN_get_word(bn); -#else - static_assert(BN_BITS2 == 32, "at least 32 bit openssl build needed"); - if (bn->top == 0) return 0; - if (bn->top == 1) return BN_get_word(bn); - ABSL_DCHECK_EQ(bn->top, 2); - return (static_cast(bn->d[1]) << 32) + bn->d[0]; -#endif -} - -static int BN_ext_count_low_zero_bits(const BIGNUM* bn) { - // In OpenSSL >= 1.1, BIGNUM is an opaque type, so d and top - // cannot be accessed. The bytes must be copied out at a ~25% - // performance penalty. - absl::FixedArray bytes(BN_num_bytes(bn)); - // "le" indicates little endian. - ABSL_CHECK_EQ(BN_bn2lebinpad(bn, bytes.data(), bytes.size()), bytes.size()); - - int count = 0; - for (unsigned char c : bytes) { - if (c == 0) { - count += 8; - } else { - count += absl::countr_zero(c); - break; - } - } - return count; -} - -#endif // !defined(OPENSSL_IS_BORINGSSL) - ExactFloat::ExactFloat(double v) { sign_ = std::signbit(v) ? -1 : 1; if (std::isnan(v)) { @@ -140,7 +59,7 @@ ExactFloat::ExactFloat(double v) { int exp; double f = frexp(fabs(v), &exp); uint64_t m = static_cast(ldexp(f, kDoubleMantissaBits)); - BN_ext_set_uint64(bn_.get(), m); + bn_ = Bignum(m); bn_exp_ = exp - kDoubleMantissaBits; Canonicalize(); } @@ -148,17 +67,14 @@ ExactFloat::ExactFloat(double v) { ExactFloat::ExactFloat(int v) { sign_ = (v >= 0) ? 1 : -1; - // Note that this works even for INT_MIN because the parameter type for - // BN_set_word() is unsigned. - ABSL_CHECK(BN_set_word(bn_.get(), abs(v))); + // Note that this works even for INT_MIN. + bn_ = Bignum(abs(v)); bn_exp_ = 0; Canonicalize(); } ExactFloat::ExactFloat(const ExactFloat& b) - : sign_(b.sign_), bn_exp_(b.bn_exp_) { - BN_copy(bn_.get(), b.bn_.get()); -} + : sign_(b.sign_), bn_exp_(b.bn_exp_), bn_(b.bn_) {} ExactFloat ExactFloat::SignedZero(int sign) { ExactFloat r; @@ -178,29 +94,29 @@ ExactFloat ExactFloat::NaN() { return r; } -int ExactFloat::prec() const { return BN_num_bits(bn_.get()); } +int ExactFloat::prec() const { return bn_.BitWidth(); } int ExactFloat::exp() const { ABSL_DCHECK(is_normal()); - return bn_exp_ + BN_num_bits(bn_.get()); + return bn_exp_ + bn_.BitWidth(); } void ExactFloat::set_zero(int sign) { sign_ = sign; bn_exp_ = kExpZero; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.SetZero(); } void ExactFloat::set_inf(int sign) { sign_ = sign; bn_exp_ = kExpInfinity; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.SetZero(); } void ExactFloat::set_nan() { sign_ = 1; bn_exp_ = kExpNaN; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.SetZero(); } double ExactFloat::ToDouble() const { @@ -214,7 +130,7 @@ double ExactFloat::ToDouble() const { } double ExactFloat::ToDoubleHelper() const { - ABSL_DCHECK_LE(BN_num_bits(bn_.get()), kDoubleMantissaBits); + ABSL_DCHECK_LE(bn_.BitWidth(), kDoubleMantissaBits); if (!is_normal()) { if (is_zero()) return copysign(0, sign_); if (is_inf()) { @@ -222,7 +138,9 @@ double ExactFloat::ToDoubleHelper() const { } return std::copysign(std::numeric_limits::quiet_NaN(), sign_); } - uint64_t d_mantissa = BN_ext_get_uint64(bn_.get()); + auto opt_mantissa = bn_.Convert(); + ABSL_DCHECK(opt_mantissa.has_value()); + uint64_t d_mantissa = opt_mantissa.value(); // We rely on ldexp() to handle overflow and underflow. (It will return a // signed zero or infinity if the result is too small or too large.) return sign_ * ldexp(static_cast(d_mantissa), bn_exp_); @@ -273,10 +191,10 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // Never increment. } else if (mode == kRoundTiesAwayFromZero) { // Increment if the highest discarded bit is 1. - if (BN_is_bit_set(bn_.get(), shift - 1)) increment = true; + if (bn_.Bit(shift - 1)) increment = true; } else if (mode == kRoundAwayFromZero) { // Increment unless all discarded bits are zero. - if (BN_ext_count_low_zero_bits(bn_.get()) < shift) increment = true; + if (bn_.CountrZero() < shift) increment = true; } else { ABSL_DCHECK_EQ(mode, kRoundTiesToEven); // Let "w/xyz" denote a mantissa where "w" is the lowest kept bit and @@ -285,16 +203,15 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // 0/10* -> Don't increment (fraction = 1/2, kept part even) // 1/10* -> Increment (fraction = 1/2, kept part odd) // ./1.*1.* -> Increment (fraction > 1/2) - if (BN_is_bit_set(bn_.get(), shift - 1) && - ((BN_is_bit_set(bn_.get(), shift) || - BN_ext_count_low_zero_bits(bn_.get()) < shift - 1))) { + if (bn_.Bit(shift - 1) && + ((bn_.Bit(shift) || bn_.CountrZero() < shift - 1))) { increment = true; } } r.bn_exp_ = bn_exp_ + shift; - ABSL_CHECK(BN_rshift(r.bn_.get(), bn_.get(), shift)); + r.bn_ = bn_ >> shift; if (increment) { - ABSL_CHECK(BN_add_word(r.bn_.get(), 1)); + r.bn_ += Bignum(1); } r.sign_ = sign_; r.Canonicalize(); @@ -397,43 +314,37 @@ static void IncrementDecimalDigits(std::string* digits) { int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { ABSL_DCHECK(is_normal()); // Convert the value to the form (bn * (10 ** bn_exp10)) where "bn" is a - // positive integer (BIGNUM). - BIGNUM* bn = BN_new(); + // positive integer. + Bignum bn; int bn_exp10; if (bn_exp_ >= 0) { // The easy case: bn = bn_ * (2 ** bn_exp_)), bn_exp10 = 0. - ABSL_CHECK(BN_lshift(bn, bn_.get(), bn_exp_)); + bn = bn_ << bn_exp_; bn_exp10 = 0; } else { // Set bn = bn_ * (5 ** -bn_exp_) and bn_exp10 = bn_exp_. This is // equivalent to the original value of (bn_ * (2 ** bn_exp_)). - BIGNUM* power = BN_new(); - ABSL_CHECK(BN_set_word(power, -bn_exp_)); - ABSL_CHECK(BN_set_word(bn, 5)); - BN_CTX* ctx = BN_CTX_new(); - ABSL_CHECK(BN_exp(bn, bn, power, ctx)); - ABSL_CHECK(BN_mul(bn, bn, bn_.get(), ctx)); - BN_CTX_free(ctx); - BN_free(power); + int power = -bn_exp_; + bn = Bignum(5).Pow(power) * bn_; bn_exp10 = bn_exp_; } - // Now convert "bn" to a decimal string. - char* all_digits = BN_bn2dec(bn); - ABSL_DCHECK(all_digits != nullptr); - BN_free(bn); + // Now convert "bn" to a decimal string using our Bignum's string conversion. + std::string all_digits = absl::StrFormat("%v", bn); + ABSL_DCHECK(!all_digits.empty()); // Check whether we have too many digits and round if necessary. - int num_digits = strlen(all_digits); + int num_digits = all_digits.length(); if (num_digits <= max_digits) { *digits = all_digits; } else { - digits->assign(all_digits, max_digits); + digits->assign(all_digits, 0, max_digits); // Standard "printf" formatting rounds ties to an even number. This means // that we round up (away from zero) if highest discarded digit is '5' or // more, unless all other discarded digits are zero in which case we round // up only if the lowest kept digit is odd. if (all_digits[max_digits] >= '5' && ((all_digits[max_digits - 1] & 1) == 1 || - strpbrk(all_digits + max_digits + 1, "123456789") != nullptr)) { + all_digits.substr(max_digits + 1).find_first_of("123456789") != + std::string::npos)) { // This can increase the number of digits by 1, but in that case at // least one trailing zero will be stripped off below. IncrementDecimalDigits(digits); @@ -441,7 +352,6 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { // Adjust the base-10 exponent to reflect the digits we have removed. bn_exp10 += num_digits - max_digits; } - OPENSSL_free(all_digits); // Now strip any trailing zeros. ABSL_DCHECK_NE((*digits)[0], '0'); @@ -463,7 +373,7 @@ ExactFloat& ExactFloat::operator=(const ExactFloat& b) { if (this != &b) { sign_ = b.sign_; bn_exp_ = b.bn_exp_; - BN_copy(bn_.get(), b.bn_.get()); + bn_ = b.bn_; } return *this; } @@ -508,28 +418,26 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, int b_sign, } // Shift "a" if necessary so that both values have the same bn_exp_. ExactFloat r; + Bignum a_bn; if (a->bn_exp_ > b->bn_exp_) { - ABSL_CHECK(BN_lshift(r.bn_.get(), a->bn_.get(), a->bn_exp_ - b->bn_exp_)); - a = &r; // The only field of "a" used below is bn_. + a_bn = a->bn_ << (a->bn_exp_ - b->bn_exp_); + } else { + a_bn = a->bn_; } r.bn_exp_ = b->bn_exp_; if (a_sign == b_sign) { - ABSL_CHECK(BN_add(r.bn_.get(), a->bn_.get(), b->bn_.get())); + r.bn_ = a_bn + b->bn_; r.sign_ = a_sign; } else { - // Note that the BIGNUM documentation is out of date -- all methods now - // allow the result to be the same as any input argument, so it is okay if - // (a == &r) due to the shift above. - ABSL_CHECK(BN_sub(r.bn_.get(), a->bn_.get(), b->bn_.get())); - if (BN_is_zero(r.bn_.get())) { - r.sign_ = +1; - } else if (BN_is_negative(r.bn_.get())) { - // The magnitude of "b" was larger. - r.sign_ = b_sign; - BN_set_negative(r.bn_.get(), false); - } else { - // They were equal, or the magnitude of "a" was larger. + if (a_bn >= b->bn_) { + r.bn_ = a_bn - b->bn_; r.sign_ = a_sign; + } else { + r.bn_ = b->bn_ - a_bn; + r.sign_ = b_sign; + } + if (r.bn_.zero()) { + r.sign_ = +1; } } r.Canonicalize(); @@ -542,16 +450,16 @@ void ExactFloat::Canonicalize() { // Underflow/overflow occurs if exp() is not in [kMinExp, kMaxExp]. // We also convert a zero mantissa to signed zero. int my_exp = exp(); - if (my_exp < kMinExp || BN_is_zero(bn_.get())) { + if (my_exp < kMinExp || bn_.zero()) { set_zero(sign_); } else if (my_exp > kMaxExp) { set_inf(sign_); - } else if (!BN_is_odd(bn_.get())) { + } else if (bn_.even()) { // Remove any low-order zero bits from the mantissa. - ABSL_DCHECK(!BN_is_zero(bn_.get())); - int shift = BN_ext_count_low_zero_bits(bn_.get()); + ABSL_DCHECK(!bn_.zero()); + int shift = bn_.CountrZero(); if (shift > 0) { - ABSL_CHECK(BN_rshift(bn_.get(), bn_.get(), shift)); + bn_ >>= shift; bn_exp_ += shift; } } @@ -583,9 +491,7 @@ ExactFloat operator*(const ExactFloat& a, const ExactFloat& b) { ExactFloat r; r.sign_ = result_sign; r.bn_exp_ = a.bn_exp_ + b.bn_exp_; - BN_CTX* ctx = BN_CTX_new(); - ABSL_CHECK(BN_mul(r.bn_.get(), a.bn_.get(), b.bn_.get(), ctx)); - BN_CTX_free(ctx); + r.bn_ = a.bn_ * b.bn_; r.Canonicalize(); return r; } @@ -603,14 +509,16 @@ bool operator==(const ExactFloat& a, const ExactFloat& b) { // Otherwise, the signs and mantissas must match. Note that non-normal // values such as infinity have a mantissa of zero. - return a.sign_ == b.sign_ && BN_ucmp(a.bn_.get(), b.bn_.get()) == 0; + return a.sign_ == b.sign_ && a.bn_ == b.bn_; } int ExactFloat::ScaleAndCompare(const ExactFloat& b) const { ABSL_DCHECK(is_normal() && b.is_normal() && bn_exp_ >= b.bn_exp_); ExactFloat tmp = *this; - ABSL_CHECK(BN_lshift(tmp.bn_.get(), tmp.bn_.get(), bn_exp_ - b.bn_exp_)); - return BN_ucmp(tmp.bn_.get(), b.bn_.get()); + tmp.bn_ <<= (bn_exp_ - b.bn_exp_); + if (tmp.bn_ < b.bn_) return -1; + if (tmp.bn_ > b.bn_) return 1; + return 0; } bool ExactFloat::UnsignedLess(const ExactFloat& b) const { @@ -702,7 +610,9 @@ T ExactFloat::ToInteger(RoundingMode mode) const { if (!r.is_inf()) { // If the unsigned value has more than 63 bits it is always clamped. if (r.exp() < 64) { - int64_t value = BN_ext_get_uint64(r.bn_.get()) << r.bn_exp_; + auto opt_value = r.bn_.Convert(); + ABSL_DCHECK(opt_value.has_value()); + int64_t value = static_cast(opt_value.value()) << r.bn_exp_; if (r.sign_ < 0) value = -value; return max(kMinValue, min(kMaxValue, value)); } diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index 62503087..e0867976 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -15,7 +15,7 @@ // Author: ericv@google.com (Eric Veach) // -// ExactFloat is a multiple-precision floating point type that uses the OpenSSL +// ExactFloat is a multiple-precision floating point type that uses a custom // Bignum library for numerical calculations. It has the same interface as the // built-in "float" and "double" types, but only supports the subset of // operators and intrinsics where it is possible to compute the result exactly. @@ -25,10 +25,8 @@ // algorithms, especially for disambiguating cases where ordinary // double-precision arithmetic yields an uncertain result. // -// ExactFloat is a subset of the now-retired MPFloat class, which used the GNU -// MPFR library for numerical calculations. The main reason for the switch to -// ExactFloat is that OpenSSL has a BSD-style license whereas MPFR has a much -// more restrictive LGPL license. +// ExactFloat was originally based on OpenSSL's Bignum library, but has been +// updated to use a custom implementation to remove external dependencies. // // ExactFloat has the following features: // @@ -109,18 +107,16 @@ #ifndef S2_UTIL_MATH_EXACTFLOAT_EXACTFLOAT_H_ #define S2_UTIL_MATH_EXACTFLOAT_EXACTFLOAT_H_ +#include #include #include - -#include #include #include #include #include #include -#include - +#include "s2/util/math/exactfloat/bignum.h" class ExactFloat { public: @@ -500,40 +496,6 @@ class ExactFloat { friend ExactFloat logb(const ExactFloat& a); protected: - // OpenSSL >= 1.1 does not have BN_init, and does not support stack- - // allocated BIGNUMS. We use BN_init when possible, but BN_new otherwise. - // If the performance penalty is too high, an object pool can be added - // in the future. -#if defined(OPENSSL_IS_BORINGSSL) - // BoringSSL supports stack allocated BIGNUMs and BN_init. - class BigNum { - public: - BigNum() { BN_init(&bn_); } - // Prevent accidental, expensive, copying. - BigNum(const BigNum&) = delete; - BigNum& operator=(const BigNum&) = delete; - ~BigNum() { BN_free(&bn_); } - BIGNUM* get() { return &bn_; } - const BIGNUM* get() const { return &bn_; } - - private: - BIGNUM bn_; - }; -#else - class BigNum { - public: - BigNum() : bn_(BN_new()) {} - BigNum(const BigNum&) = delete; - BigNum& operator=(const BigNum&) = delete; - ~BigNum() { BN_free(bn_); } - BIGNUM* get() { return bn_; } - const BIGNUM* get() const { return bn_; } - - private: - BIGNUM* bn_; - }; -#endif - // Non-normal numbers are represented using special exponent values and a // mantissa of zero. Do not change these values; methods such as // is_normal() make assumptions about their ordering. Non-normal numbers @@ -544,12 +506,11 @@ class ExactFloat { // Normal numbers are represented as (sign_ * bn_ * (2 ** bn_exp_)), where: // - sign_ is either +1 or -1 - // - bn_ is a BIGNUM with a positive value + // - bn_ is a Bignum with a positive value // - bn_exp_ is the base-2 exponent applied to bn_. - // Default value is zero. int32_t sign_ = 1; int32_t bn_exp_ = kExpZero; - BigNum bn_; + Bignum bn_; // A standard IEEE "double" has a 53-bit mantissa consisting of a 52-bit // fraction plus an implicit leading "1" bit. From 383f43b8d921743c946c941dab19c014a5f20387 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Wed, 24 Sep 2025 17:00:35 -0600 Subject: [PATCH 02/31] Rework Karatsuba and respond to PR comments. Karatsuba avoids allocating memory as it recurses by pre-allocating an arena. Generally reworked the arithmetic code to unroll loops a bit more. With these changes performance is closer to OpenSSL: --------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------- BM_Bignum_AddSmall 9.50 ns 9.50 ns 1463556659 BM_Bignum_AddMedium 27.5 ns 27.5 ns 495811317 BM_Bignum_AddLarge 33.8 ns 33.8 ns 438882689 BM_Bignum_AddHuge 52.6 ns 52.6 ns 260576695 BM_Bignum_AddMega 194 ns 194 ns 71530603 BM_OpenSSL_AddSmall 28.1 ns 28.1 ns 486002407 BM_OpenSSL_AddMedium 29.2 ns 29.2 ns 463641940 BM_OpenSSL_AddLarge 34.4 ns 34.4 ns 404689860 BM_OpenSSL_AddHuge 54.7 ns 54.7 ns 253974791 BM_OpenSSL_AddMega 163 ns 163 ns 87070697 BM_Bignum_MulSmall 8.12 ns 8.12 ns 1724260557 BM_Bignum_MulMedium 48.4 ns 48.4 ns 288699498 BM_Bignum_MulLarge 185 ns 185 ns 75525625 BM_Bignum_MulHuge 1788 ns 1788 ns 7811143 BM_Bignum_MulMega 27517 ns 27514 ns 509658 BM_OpenSSL_MulSmall 33.0 ns 33.0 ns 415226181 BM_OpenSSL_MulMedium 39.6 ns 39.6 ns 352798959 BM_OpenSSL_MulLarge 148 ns 148 ns 94531613 BM_OpenSSL_MulHuge 1427 ns 1427 ns 9805970 BM_OpenSSL_MulMega 29192 ns 29188 ns 479649 BM_Bignum_PowSmall 387 ns 387 ns 36209348 BM_Bignum_PowMedium 1020 ns 1020 ns 13670435 BM_OpenSSL_PowSmall 249 ns 248 ns 56293679 BM_OpenSSL_PowMedium 417 ns 417 ns 33756592 Moved Bignum into exactfloat_internal.(h|cc) --- CMakeLists.txt | 16 +- README.md | 13 + src/s2/util/math/exactfloat/BUILD | 4 +- src/s2/util/math/exactfloat/bignum.h | 948 ------------------ src/s2/util/math/exactfloat/exactfloat.cc | 24 +- src/s2/util/math/exactfloat/exactfloat.h | 4 +- .../math/exactfloat/exactfloat_internal.cc | 753 ++++++++++++++ .../math/exactfloat/exactfloat_internal.h | 489 +++++++++ ...um_test.cc => exactfloat_internal_test.cc} | 336 +++---- 9 files changed, 1429 insertions(+), 1158 deletions(-) delete mode 100644 src/s2/util/math/exactfloat/bignum.h create mode 100644 src/s2/util/math/exactfloat/exactfloat_internal.cc create mode 100644 src/s2/util/math/exactfloat/exactfloat_internal.h rename src/s2/util/math/exactfloat/{bignum_test.cc => exactfloat_internal_test.cc} (77%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 889698ba..f333766e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -211,6 +211,7 @@ add_library(s2 src/s2/util/coding/coder.cc src/s2/util/coding/varint.cc src/s2/util/math/exactfloat/exactfloat.cc + src/s2/util/math/exactfloat/exactfloat_internal.cc src/s2/util/math/mathutil.cc src/s2/util/units/length-units.cc) @@ -223,6 +224,7 @@ if (GOOGLETEST_ROOT) src/s2/thread_testing.cc) endif() + target_link_libraries( s2 absl::absl_vlog_is_on @@ -437,6 +439,8 @@ if(S2_ENABLE_INSTALL) DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/testing") install(FILES src/s2/util/bitmap/bitmap.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/bitmap") + install(FILES src/s2/util/bits/bits.h + DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/bits") install(FILES src/s2/util/coding/coder.h src/s2/util/coding/varint.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/coding") @@ -447,7 +451,6 @@ if(S2_ENABLE_INSTALL) src/s2/util/gtl/dense_hash_set.h src/s2/util/gtl/densehashtable.h src/s2/util/gtl/hashtable_common.h - src/s2/util/gtl/requires.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/gtl") install(FILES src/s2/util/hash/mix.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/hash") @@ -456,6 +459,7 @@ if(S2_ENABLE_INSTALL) src/s2/util/math/vector.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math") install(FILES src/s2/util/math/exactfloat/exactfloat.h + src/s2/util/math/exactfloat/exactfloat_internal.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math/exactfloat") install(FILES src/s2/util/units/length-units.h src/s2/util/units/physical-units.h @@ -485,6 +489,8 @@ if(S2_ENABLE_INSTALL) endif() # S2_ENABLE_INSTALL if (BUILD_TESTS) + add_library(test_main "src/s2/test_main.cc") + if (NOT GOOGLETEST_ROOT) message(FATAL_ERROR "BUILD_TESTS requires GOOGLETEST_ROOT") endif() @@ -610,14 +616,13 @@ if (BUILD_TESTS) src/s2/s2shapeutil_shape_edge_id_test.cc src/s2/s2shapeutil_visit_crossing_edge_pairs_test.cc src/s2/s2text_format_test.cc - src/s2/util/math/exactfloat/bignum_test.cc + src/s2/util/math/exactfloat/exactfloat_internal_test.cc src/s2/s2validation_query_test.cc src/s2/s2wedge_relations_test.cc src/s2/s2winding_operation_test.cc src/s2/s2wrapped_shape_test.cc src/s2/sequence_lexicon_test.cc - src/s2/value_lexicon_test.cc - src/s2/util/math/exactfloat/exactfloat_test.cc) + src/s2/value_lexicon_test.cc) enable_testing() @@ -627,6 +632,9 @@ if (BUILD_TESTS) target_link_libraries( ${test} s2testing s2 + test_main + benchmark + profiler absl::base absl::btree absl::check diff --git a/README.md b/README.md index d39c4a54..e900b684 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,19 @@ Disable building of shared libraries with `-DBUILD_SHARED_LIBS=OFF`. Enable the python interface with `-DWITH_PYTHON=ON`. +# For Testing + +If BUILD_TESTS is 'on' (the default), and benchmarks are enabled, then OpenSSL +must be available to build some tests: + +* [OpenSSL](https://github.com/openssl/openssl) (for its bignum library) + + +If OpenSSL is installed in a non-standard location set `OPENSSL_ROOT_DIR` +before running configure, for example on macOS: +``` +OPENSSL_ROOT_DIR=/opt/homebrew/Cellar/openssl@3/3.1.0 cmake -DCMAKE_PREFIX_PATH=/opt/homebrew -DCMAKE_CXX_STANDARD=17 +``` ## Installing diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index bfab4c2d..8642a1d4 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -2,8 +2,8 @@ package(default_visibility = ["//visibility:public"]) cc_library( name = "exactfloat", - srcs = ["exactfloat.cc"], - hdrs = ["exactfloat.h", "bignum.h"], + srcs = ["exactfloat.cc", "exactfloat_internal.cc"], + hdrs = ["exactfloat.h", "exactfloat_internal.h"], deps = [ "//s2/base:port", "//s2/base:logging", diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h deleted file mode 100644 index e53e0d46..00000000 --- a/src/s2/util/math/exactfloat/bignum.h +++ /dev/null @@ -1,948 +0,0 @@ -// Copyright 2025 Google LLC -// Author: smcallis@google.com (Sean McAllister) -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include - -#include "absl/container/inlined_vector.h" -#include "absl/log/absl_check.h" -#include "absl/numeric/bits.h" -#include "absl/numeric/int128.h" -#include "absl/strings/ascii.h" -#include "absl/strings/str_format.h" - -namespace internal { - -// Most of the STL cannot be overloaded per the spec, so we need to roll our own -// wrappers that will also work with absl::int128. - -template -constexpr bool IsInt = std::numeric_limits::is_integer; - -template -constexpr bool IsSigned() { - if constexpr (std::is_same_v) { - return false; - } - - if constexpr (std::is_same_v) { - return true; - } - - return std::is_signed_v; -} - -template -constexpr auto InferUnsigned() { - if constexpr (std::is_same_v || - std::is_same_v) { - return absl::uint128{}; - } else { - return std::make_unsigned_t{}; - } -} - -template -using MakeUnsigned = decltype(InferUnsigned()); - -} // namespace internal - -class Bignum { - private: - using Bigit = uint64_t; - - static constexpr int kKaratsubaThreshold = 32; - static constexpr int kBigitBits = std::numeric_limits::digits; - - public: - Bignum() = default; - - // Constructs a bignum from an integral value (signed or unsigned). - template >> - explicit Bignum(T value) { - using UT = internal::MakeUnsigned; - - if (value == 0) { - return; - } - - sign_ = +1; - if constexpr (internal::IsSigned()) { - sign_ = (value < 0) ? -1 : +1; - } - - // Get magnitude of value, handle minimum value of T cleanly. - UT mag = static_cast(value); - if constexpr (internal::IsSigned()) { - if (value < 0) { - mag = UT(0) - mag; - } - } - - // Pack the magnitude into bigits. - if constexpr (std::numeric_limits::digits <= kBigitBits) { - bigits_.push_back(static_cast(mag)); - } else { - while (mag) { - bigits_.push_back(static_cast(mag)); - mag >>= kBigitBits; - } - } - } - - // Constructs a bignum from an ASCII string containing decimal digits. - // - // The input string must only have an optional leading +/- and decimal digits. - // Any other characters will yield std::nullopt. - static std::optional FromString(absl::string_view s) { - // We can fit ~10^19 into a uint64_t. - constexpr int kMaxChunkDigits = 19; - - // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the - // sake of simplicity. This isn't the fastest algorithm, being quadratic in - // the number of chunks the input has. If we use divide and conquer approach - // or an FFT based multiply we could probably make this ~O(n^1.5) or - // semi-linear. - - // Precomputed powers of 10. - static constexpr uint64_t kPow10[20] = {1ull, - 10ull, - 100ull, - 1000ull, - 10000ull, - 100000ull, - 1000000ull, - 10000000ull, - 100000000ull, - 1000000000ull, - 10000000000ull, - 100000000000ull, - 1000000000000ull, - 10000000000000ull, - 100000000000000ull, - 1000000000000000ull, - 10000000000000000ull, - 100000000000000000ull, - 1000000000000000000ull, - 10000000000000000000ull}; - - Bignum out; - if (s.empty()) { - return out; - } - - // Reserve space for bigits. - out.bigits_.reserve((s.size() + kMaxChunkDigits - 1) / kMaxChunkDigits); - - int sign = +1; - uint64_t chunk = 0; - int clen = 0; - - // Finish processing the current chunk. - auto FlushChunk = [&]() { - if (clen) { - out.MulAddSmall(kPow10[clen], chunk); - chunk = 0; - clen = 0; - } - }; - - // Consume optional +/- at the front. - int start = 0; - if ((s[0] == '+' || s[0] == '-')) { - sign = (s[0] == '-') ? -1 : +1; - ++start; - } - - bool seen_digit = false; - for (char c : s.substr(start)) { - if (!absl::ascii_isdigit(c)) { - return std::nullopt; - } - - // Accumulate digit into the local 64-bit chunk. Skip leading zeros. - uint64_t digit = static_cast(c - '0'); - if (!seen_digit && digit == 0) { - continue; - } - seen_digit = true; - - chunk = 10 * chunk + digit; - ++clen; - - if (clen == kMaxChunkDigits) { - FlushChunk(); - } - } - FlushChunk(); - - out.NormalizeSign(sign); - return out; - } - - // Formats the bignum as a decimal integer into an abseil sink. - template - friend void AbslStringify(Sink& sink, const Bignum& b) { - if (b.zero()) { - sink.Append("0"); - return; - } - - // Sign - if (b.negative()) { - sink.Append("-"); - } - - // Work on a copy of the magnitude. - Bignum copy = b; - copy.sign_ = 1; - - // Repeatedly divide and modulo by 10^19 to get decimal chunks. - static constexpr uint64_t kBase = 10000000000000000000ull; - absl::InlinedVector chunks; - - while (!copy.zero()) { - absl::uint128 rem = 0; - for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { - absl::uint128 acc = (rem << 64) + copy.bigits_[i]; - uint64_t quot = static_cast(acc / kBase); - rem = acc - absl::uint128(quot) * kBase; - copy.bigits_[i] = quot; - } - - copy.Normalize(); - chunks.push_back(static_cast(rem)); - } - ABSL_DCHECK(!chunks.empty()); - - // Emit most significant chunk without zero padding. - absl::Format(&sink, "%d", chunks.back()); - - // Emit remaining chunks as fixed-width 19-digit zero-padded blocks. - for (int i = static_cast(chunks.size()) - 2; i >= 0; --i) { - absl::Format(&sink, "%019d", chunks[i]); - } - } - - friend std::ostream& operator<<(std::ostream& os, const Bignum& b) { - return os << absl::StrFormat("%v", b); - } - - friend std::ostream& operator<<(std::ostream& os, - const std::optional& b) { - if (!b) { - return os << "[nullopt]"; - } - return os << *b; - } - - // Returns true if bignum can be stored in T without truncation or overflow. - template >> - bool Compatible() const { - using UT = internal::MakeUnsigned; - - if (sign_ == 0) { - return true; - } - - // Maximum number of bits that could fit in the output type. - constexpr int kTBitWidth = std::numeric_limits::digits; - constexpr int kMaxBigits = (kTBitWidth + (kBigitBits - 1)) / kBigitBits; - - // Fast reject if the bignum couldn't conceivably fit. - if (bigits_.size() > kMaxBigits) { - return false; - } - - // Unsigned type T can hold the value iff the value is non-negative and the - // bitwidth is <= the maximum bit width of the type. - if constexpr (!internal::IsSigned()) { - if (negative()) { - return false; - } - return BitWidth() <= kTBitWidth; - } - - // T is signed and our bignum isn't zero. - ABSL_DCHECK(internal::IsSigned() && !zero()); - - if (positive()) { - return BitWidth() <= (kTBitWidth - 1); - } else /* negative() */ { - // Magnitude must fit in negative value. If the value is negative and the - // same bit width as the output type, the only valid value is -2^(k-1). - if (BitWidth() == kTBitWidth) { - return IsPow2(kTBitWidth - 1); - } - return BitWidth() < kTBitWidth; - } - } - - // Cast to an integral type T unconditionally. Use Compatible() or - // Convert to perform conversion with bounds checking. - template >> - T Cast() const { - using UT = internal::MakeUnsigned; - - constexpr int kTBitWidth = std::numeric_limits::digits; - - if (empty()) { - return 0; - } - - // Grab the bottom bits into an unsigned value. - UT residue = 0; - for (size_t i = 0; i < bigits_.size(); ++i) { - const int shift = i * kBigitBits; - if (shift >= kTBitWidth) { - break; - } - - const int room = kTBitWidth - shift; - UT chunk = static_cast(bigits_[i]); - if (room < kBigitBits && room < std::numeric_limits::digits) { - chunk &= (UT(1) << room) - UT(1); - } - residue |= (chunk << shift); - } - - // Compute two's complement of the residue if value is negative. - if (negative()) { - residue = UT(0) - residue; - } - - return static_cast(residue); - } - - // Casts the value to the given type if it fits, otherwise std::nullopt. - template >> - std::optional Convert() const { - if (!Compatible()) { - return std::nullopt; - } - return Cast(); - } - - // // Creates a Bignum by parsing a decimal representation from a string. - // explicit Bignum(absl::string_view dec) { ParseDecimal_(dec); } - - // Returns the number of bits required to represent the bignum. - int BitWidth() const { - ABSL_DCHECK(Normalized()); - if (empty()) { - return 0; - } - - // Bit width is the bits in the least significant bigits + bit width of the - // most significant word. - const int msw_width = (kBigitBits - absl::countl_zero(bigits_.back())); - const int lsw_width = (bigits_.size() - 1) * kBigitBits; - return msw_width + lsw_width; - } - - // Returns the number of consecutive 0 bits in the value, starting from the - // least significant bit. - int CountrZero() const { - if (zero()) { - return 0; - } - - int nzero = 0; - for (Bigit bigit : bigits_) { - if (bigit == 0) { - nzero += kBigitBits; - } else { - nzero += absl::countr_zero(bigit); - break; - } - } - return nzero; - } - - // Returns true if the n-th bit of the number's magnitude is set. - bool Bit(int nbit) const { - ABSL_DCHECK_GE(nbit, 0); - if (zero()) { - return false; - } - - const int digit = nbit / kBigitBits; - const int shift = nbit % kBigitBits; - - if (digit >= size()) { - return false; - } - - return ((bigits_[digit] >> shift) & 0x1) != 0; - } - - // Clears this bignum and sets it to zero. - Bignum& SetZero() { - sign_ = 0; - bigits_.clear(); - return *this; - } - - // Unconditionally makes the sign of this bignum negative. - Bignum& SetNegative() { - sign_ = -1; - return *this; - } - - // Unconditionally makes the sign of this bignum positive. - Bignum& SetPositive() { - sign_ = +1; - return *this; - } - - // Unconditionally set the sign of this bignum to match the sign of the - // argument. If the argument is zero, set the bignum to zero. - Bignum& SetSign(int sign) { - if (sign == 0) { - return SetZero(); - } - - if (sign < 0) { - return SetNegative(); - } - return SetPositive(); - } - - // Returns true if the number is zero. - bool zero() const { // - return sign_ == 0; - } - - // Returns true if the number is greater than zero. - bool positive() const { // - return sign_ > 0; - } - - // Returns true if the number is less than zero. - bool negative() const { // - return sign_ < 0; - } - - // Returns true if the number is odd (least significant bit is 1). - bool odd() const { return Bit(0); } - - // Returns true if the number is even (least significant bit is 0). - bool even() const { return !odd(); } - - bool operator==(const Bignum& b) const { - return sign_ == b.sign_ && bigits_ == b.bigits_; - } - - bool operator!=(const Bignum& b) const { return !(*this == b); } - - bool operator<(const Bignum& b) const { return Compare(b) < 0; } - - bool operator<=(const Bignum& b) const { return Compare(b) <= 0; } - - bool operator>(const Bignum& b) const { return Compare(b) > 0; } - - bool operator>=(const Bignum& b) const { return Compare(b) >= 0; } - - Bignum operator+() const { return *this; } - - Bignum operator-() const { - Bignum result = *this; - result.sign_ = -result.sign_; - return result; - } - - Bignum& operator+=(const Bignum& b) { - if (b.zero()) { - return *this; - } - - if (zero()) { - *this = b; - return *this; - } - - if (sign_ == b.sign_) { - // Same sign: - // (+a) + (+b) == +(a + b) - // (-a) + (-b) == -(a + b) - AddAbs(b); - } else { - if (CmpAbs(b) >= 0) { - // |a| >= |b|, so a - b is same sign as a. - SubAbsGe(b); - NormalizeSign(sign_); - } else { - // |a| < |b|, so a - b is same sign as b. - SubAbsLt(b); - NormalizeSign(b.sign_); - } - } - - return *this; - } - - Bignum& operator-=(const Bignum& b) { - if (this == &b) { - bigits_.clear(); - sign_ = 0; - return *this; - } - - if (b.zero()) { - return *this; - } - - if (zero()) { - return *this = -b; - } - - if (sign_ != b.sign_) { - AddAbs(b); - } else { - if (CmpAbs(b) >= 0) { - SubAbsGe(b); - NormalizeSign(sign_); - } else { - SubAbsLt(b); - NormalizeSign(-sign_); - } - } - - return *this; - } - - // Left-shift the bignum by nbit. - Bignum& operator<<=(int nbit) { - ABSL_DCHECK_GE(nbit, 0); - if (zero() || nbit == 0) { - return *this; - } - - const int nbigit = nbit / kBigitBits; - const int nrem = nbit % kBigitBits; - - // First, handle the whole-bigit shift by inserting zeros. - bigits_.insert(bigits_.begin(), nbigit, 0); - - // Then, handle the within-bigit shift, if any. - if (nrem != 0) { - Bigit carry = 0; - for (size_t i = 0; i < bigits_.size(); ++i) { - const Bigit old_val = bigits_[i]; - bigits_[i] = (old_val << nrem) | carry; - carry = old_val >> (kBigitBits - nrem); - } - - if (carry) { - bigits_.push_back(carry); - } - } - - return *this; - } - - // Right-shift the bignum by nbit. - Bignum& operator>>=(int nbit) { - ABSL_DCHECK_GE(nbit, 0); - if (zero() || nbit == 0) { - return *this; - } - - // Shifting by more than the bit width results in zero. - if (nbit >= BitWidth()) { - bigits_.clear(); - sign_ = 0; - return *this; - } - - const int nbigit = nbit / kBigitBits; - const int nrem = nbit % kBigitBits; - - // First, handle the whole-bigit shift by removing bigits. - bigits_.erase(bigits_.begin(), bigits_.begin() + nbigit); - - // Then, handle the within-bigit shift, if any. - if (nrem != 0) { - Bigit carry = 0; - for (int i = static_cast(bigits_.size()) - 1; i >= 0; --i) { - const Bigit old_val = bigits_[i]; - bigits_[i] = (old_val >> nrem) | carry; - carry = old_val << (kBigitBits - nrem); - } - } - - // Result might be smaller or zero, so normalize. - NormalizeSign(sign_); - return *this; - } - - Bignum& operator*=(const Bignum& b) { - if (zero() || b.zero()) { - bigits_.clear(); - sign_ = 0; - return *this; - } - - const int new_sign = sign_ * b.sign_; - bigits_ = MulAbs(bigits_, b.bigits_); - NormalizeSign(new_sign); - return *this; - } - - // Raise this value to the given power, which must be non-negative. - Bignum Pow(int32_t pow) const { - ABSL_DCHECK_GE(pow, 0); - - // Anything to the zero-th power is 1 (including zero). - if (pow == 0) { - return Bignum(1); - } - - if (zero()) { - return Bignum(0); - } - - if (*this == Bignum(1)) { - return Bignum(1); - } - - if (*this == Bignum(-1)) { - return (pow % 2 != 0) ? Bignum(-1) : Bignum(1); - } - - // Core algorithm: Exponentiation by squaring. - Bignum result(1); - Bignum base = *this; // A mutable copy of the base. - uint32_t upow = static_cast(pow); - - while (upow > 0) { - if (upow & 1) { // If current exponent bit is 1, multiply into result. - result *= base; - } - base *= base; - upow >>= 1; - } - - return result; - } - - friend Bignum operator*(Bignum a, const Bignum& b) { return a *= b; } - - friend Bignum operator+(Bignum a, const Bignum& b) { return a += b; } - - friend Bignum operator-(Bignum a, const Bignum& b) { return a -= b; } - - friend Bignum operator<<(Bignum a, int nbit) { return a <<= nbit; } - - friend Bignum operator>>(Bignum a, int nbit) { return a >>= nbit; } - - private: - // Construct a Bignum from bigits and an optional sign bit. - explicit Bignum(absl::Span bigits, int sign = +1) { - if (bigits.empty()) { - return; - } - - bigits_.assign(bigits.begin(), bigits.end()); - NormalizeSign(sign); - } - - // Returns the number of bigits in this bignum. - int size() const { // - return bigits_.size(); - } - - // Returns true if this value has no digits. - bool empty() const { // - return bigits_.empty(); - } - - // Compare to another bignum, returns -1, 0, +1. - int Compare(const Bignum& b) const { - if (sign_ != b.sign_) { - return sign_ < b.sign_ ? -1 : 1; - } - - // Signs are equal, are they both zero? - if (sign_ == 0) { - return 0; - } - - // Signs are equal and non-zero, compare magnitude. - return positive() ? CmpAbs(b) : -CmpAbs(b); - } - - // Compute value = value * mul + add where mul, add ≤ 10^19. We can accumulate - // using a 128 bit integer in a single pass over the bigits for small terms. - void MulAddSmall(uint64_t mul, uint64_t add) { - absl::uint128 carry = add; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - absl::uint128 prod = absl::uint128(bigits_[i]) * mul + carry; - bigits_[i] = absl::Uint128Low64(prod); - carry = absl::Uint128High64(prod); - } - - if (carry != 0) { - bigits_.push_back(absl::Uint128Low64(carry)); - } - } - - // Multiplies two bigit operands. Uses either simple quadratic multiplication - // (SimpleMulAbs) or divide-and-conquer multiplication (KaratsubaMulAbs) based - // on the size of the operands. - static absl::InlinedVector MulAbs( - absl::Span a, absl::Span b) { - // Fast path for single-bigit multiplication. - if (a.size() == 1 && b.size() == 1) { - absl::uint128 prod = absl::uint128(a[0]) * b[0]; - const uint64_t lo = absl::Uint128Low64(prod); - const uint64_t hi = absl::Uint128High64(prod); - if (hi == 0) { - return {lo}; - } - return {lo, hi}; - } - - if (a.size() < kKaratsubaThreshold || b.size() < kKaratsubaThreshold) { - return SimpleMulAbs(a, b); - } - return KaratsubaMulAbs(a, b); - } - - // Performs simple quadratic long multiplication between two sets of bigits. - static absl::InlinedVector SimpleMulAbs( - absl::Span a, absl::Span b) { - if (a.empty() || b.empty()) { - return {}; - } - - absl::InlinedVector result(a.size() + b.size(), 0); - for (size_t i = 0; i < a.size(); ++i) { - if (a[i] == 0) { - continue; - } - - absl::uint128 carry = 0; - for (size_t j = 0; j < b.size(); ++j) { - absl::uint128 prod = absl::uint128(a[i]) * b[j] + result[i + j] + carry; - result[i + j] = absl::Uint128Low64(prod); - carry = absl::Uint128High64(prod); - } - - // Propagate final carry. This can ripple through multiple bigits. - for (size_t k = i + b.size(); carry != 0 && k < result.size(); ++k) { - absl::uint128 sum = absl::uint128(result[k]) + carry; - result[k] = absl::Uint128Low64(sum); - carry = absl::Uint128High64(sum); - } - ABSL_DCHECK_EQ(carry, 0); // Result vector must be large enough. - } - - // Normalize vector by removing leading zeros. - while (!result.empty() && result.back() == 0) { - result.pop_back(); - } - return result; - } - - // Repeatedly divides a multiplication in half and recurses, stitching - // results back together to get the final result. - static absl::InlinedVector KaratsubaMulAbs( - absl::Span a, absl::Span b) { - // Base case is handled by MulAbs dispatcher, so we only handle recursion. - const size_t n = std::max(a.size(), b.size()); - const size_t m = (n + 1) / 2; - - const Bignum a0(a.subspan(0, std::min(m, a.size()))); - const Bignum a1(a.size() > m ? a.subspan(m) : absl::Span()); - const Bignum b0(b.subspan(0, std::min(m, b.size()))); - const Bignum b1(b.size() > m ? b.subspan(m) : absl::Span()); - - Bignum z2; - z2.bigits_ = MulAbs(a1.bigits_, b1.bigits_); - z2.NormalizeSign(+1); - - Bignum z0; - z0.bigits_ = MulAbs(a0.bigits_, b0.bigits_); - z0.NormalizeSign(+1); - - Bignum z1; - z1.bigits_ = MulAbs((a0 + a1).bigits_, (b0 + b1).bigits_); - z1.NormalizeSign(+1); - - z1 -= z2; - z1 -= z0; - - // Recombine: result = (z2 << 2*m) + (z1 << m) + z0 - z2 <<= (2 * m * kBigitBits); - z1 <<= (m * kBigitBits); - - Bignum result = z2 + z1 + z0; - return result.bigits_; - } - - // Drop leading zero bigits. - void Normalize() { - while (!empty() && bigits_.back() == 0) { - bigits_.pop_back(); - } - - if (empty()) { - sign_ = 0; - } - } - - // Drop leading zero bigits and canonicalize sign. - void NormalizeSign(int sign) { - Normalize(); - sign_ = empty() ? 0 : sign; - } - - // Returns true if the bignum is in normal form (no extra leading zeros). - bool Normalized() const { // - return bigits_.empty() || bigits_.back() != 0; - } - - // Returns true if the bignum is the given power of two. - bool IsPow2(int pow2) const { - const int bigits = pow2 / kBigitBits; - if (bigits_.size() != bigits + 1) { - return false; - } - - // Verify lower words are zero. - for (int i = 0; i < bigits; ++i) { - if (bigits_[i] != 0) { - return false; - } - } - - // Check final word is power of two. - pow2 -= bigits * kBigitBits; - ABSL_DCHECK_LT(pow2, kBigitBits); - return bigits_.back() == (Bigit(1) << pow2); - } - - // Compares magnitude with another bignum, returning -1, 0, or +1. - int CmpAbs(const Bignum& b) const; - - // Adds another bignum to this bignum in place. - void AddAbs(const Bignum& b); - - // In-place subtraction: *this = |*this| - |b|, assuming |*this| >= |b|. - void SubAbsGe(const Bignum& b); - - // In-place subtraction: *this = |b| - |*this|, assuming |*this| < |b|. - void SubAbsLt(const Bignum& b); - - absl::InlinedVector bigits_; - char sign_ = 0; -}; - -//////////////////////////////////////////////////////////////////////////////// -// Implementation Details -//////////////////////////////////////////////////////////////////////////////// - -inline int Bignum::CmpAbs(const Bignum& b) const { - if (size() != b.size()) { - return size() < b.size() ? -1 : +1; - } - - for (int i = size() - 1; i >= 0; --i) { - if (bigits_[i] != b.bigits_[i]) { - return bigits_[i] < b.bigits_[i] ? -1 : +1; - } - } - - return 0; -} - -inline void Bignum::AddAbs(const Bignum& b) { - // Grow if needed. - const bool a_longer = size() > b.size(); - const size_t min_size = std::min(size(), b.size()); - const size_t max_size = std::max(size(), b.size()); - bigits_.resize(max_size, 0); - - // Add common parts. - absl::uint128 sum; - absl::uint128 carry = 0; - for (size_t i = 0; i < min_size; ++i) { - sum = absl::uint128(bigits_[i]) + b.bigits_[i] + carry; - bigits_[i] = absl::Uint128Low64(sum); - carry = absl::Uint128High64(sum); - } - - // Propagate carry through the longer operand. - const auto* longer = a_longer ? this : &b; - for (size_t i = min_size; i < max_size; ++i) { - sum = absl::uint128(longer->bigits_[i]) + carry; - bigits_[i] = absl::Uint128Low64(sum); - carry = absl::Uint128High64(sum); - } - - if (carry) { - bigits_.push_back(absl::Uint128Low64(carry)); - } -} - -inline void Bignum::SubAbsGe(const Bignum& b) { - ABSL_DCHECK_GE(CmpAbs(b), 0); - uint64_t borrow = 0; - - size_t i = 0; - for (; i < b.size(); ++i) { - const uint64_t d1 = bigits_[i]; - const uint64_t d2 = b.bigits_[i]; - const uint64_t diff = d1 - d2 - borrow; - borrow = (d1 < d2) || (borrow && d1 == d2); - bigits_[i] = diff; - } - - for (; borrow && i < bigits_.size(); ++i) { - borrow = (bigits_[i] == 0); - bigits_[i]--; - } - ABSL_DCHECK(!borrow); - Normalize(); -} - -inline void Bignum::SubAbsLt(const Bignum& b) { - ABSL_DCHECK_LT(CmpAbs(b), 0); - uint64_t borrow = 0; - const size_t n_this = bigits_.size(); - const size_t n_b = b.size(); - bigits_.resize(n_b); - - size_t i = 0; - for (; i < n_this; ++i) { - const uint64_t d1 = b.bigits_[i]; - const uint64_t d2 = bigits_[i]; - const uint64_t diff = d1 - d2 - borrow; - borrow = (d1 < d2) || (borrow && d1 == d2); - bigits_[i] = diff; - } - - for (; i < n_b; ++i) { - const uint64_t d1 = b.bigits_[i]; - const uint64_t diff = d1 - borrow; - borrow = (borrow && d1 == 0); - bigits_[i] = diff; - } - ABSL_DCHECK(!borrow); - Normalize(); -} diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 9010b29d..634b46ed 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -94,11 +94,11 @@ ExactFloat ExactFloat::NaN() { return r; } -int ExactFloat::prec() const { return bn_.BitWidth(); } +int ExactFloat::prec() const { return bit_width(bn_); } int ExactFloat::exp() const { ABSL_DCHECK(is_normal()); - return bn_exp_ + bn_.BitWidth(); + return bn_exp_ + bit_width(bn_); } void ExactFloat::set_zero(int sign) { @@ -130,7 +130,7 @@ double ExactFloat::ToDouble() const { } double ExactFloat::ToDoubleHelper() const { - ABSL_DCHECK_LE(bn_.BitWidth(), kDoubleMantissaBits); + ABSL_DCHECK_LE(bit_width(bn_), kDoubleMantissaBits); if (!is_normal()) { if (is_zero()) return copysign(0, sign_); if (is_inf()) { @@ -138,7 +138,7 @@ double ExactFloat::ToDoubleHelper() const { } return std::copysign(std::numeric_limits::quiet_NaN(), sign_); } - auto opt_mantissa = bn_.Convert(); + auto opt_mantissa = bn_.ConvertTo(); ABSL_DCHECK(opt_mantissa.has_value()); uint64_t d_mantissa = opt_mantissa.value(); // We rely on ldexp() to handle overflow and underflow. (It will return a @@ -194,7 +194,7 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { if (bn_.Bit(shift - 1)) increment = true; } else if (mode == kRoundAwayFromZero) { // Increment unless all discarded bits are zero. - if (bn_.CountrZero() < shift) increment = true; + if (countr_zero(bn_) < shift) increment = true; } else { ABSL_DCHECK_EQ(mode, kRoundTiesToEven); // Let "w/xyz" denote a mantissa where "w" is the lowest kept bit and @@ -204,7 +204,7 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // 1/10* -> Increment (fraction = 1/2, kept part odd) // ./1.*1.* -> Increment (fraction > 1/2) if (bn_.Bit(shift - 1) && - ((bn_.Bit(shift) || bn_.CountrZero() < shift - 1))) { + ((bn_.Bit(shift) || countr_zero(bn_) < shift - 1))) { increment = true; } } @@ -436,7 +436,7 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, int b_sign, r.bn_ = b->bn_ - a_bn; r.sign_ = b_sign; } - if (r.bn_.zero()) { + if (r.bn_.is_zero()) { r.sign_ = +1; } } @@ -450,14 +450,14 @@ void ExactFloat::Canonicalize() { // Underflow/overflow occurs if exp() is not in [kMinExp, kMaxExp]. // We also convert a zero mantissa to signed zero. int my_exp = exp(); - if (my_exp < kMinExp || bn_.zero()) { + if (my_exp < kMinExp || bn_.is_zero()) { set_zero(sign_); } else if (my_exp > kMaxExp) { set_inf(sign_); - } else if (bn_.even()) { + } else if (bn_.is_even()) { // Remove any low-order zero bits from the mantissa. - ABSL_DCHECK(!bn_.zero()); - int shift = bn_.CountrZero(); + ABSL_DCHECK(!bn_.is_zero()); + int shift = countr_zero(bn_); if (shift > 0) { bn_ >>= shift; bn_exp_ += shift; @@ -610,7 +610,7 @@ T ExactFloat::ToInteger(RoundingMode mode) const { if (!r.is_inf()) { // If the unsigned value has more than 63 bits it is always clamped. if (r.exp() < 64) { - auto opt_value = r.bn_.Convert(); + auto opt_value = r.bn_.ConvertTo(); ABSL_DCHECK(opt_value.has_value()); int64_t value = static_cast(opt_value.value()) << r.bn_exp_; if (r.sign_ < 0) value = -value; diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index e0867976..f77939a0 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -116,7 +116,7 @@ #include #include -#include "s2/util/math/exactfloat/bignum.h" +#include "s2/util/math/exactfloat/exactfloat_internal.h" class ExactFloat { public: @@ -508,6 +508,8 @@ class ExactFloat { // - sign_ is either +1 or -1 // - bn_ is a Bignum with a positive value // - bn_exp_ is the base-2 exponent applied to bn_. + // + // Bignum supports negative values so that subtraction can be supported. int32_t sign_ = 1; int32_t bn_exp_ = kExpZero; Bignum bn_; diff --git a/src/s2/util/math/exactfloat/exactfloat_internal.cc b/src/s2/util/math/exactfloat/exactfloat_internal.cc new file mode 100644 index 00000000..eae68422 --- /dev/null +++ b/src/s2/util/math/exactfloat/exactfloat_internal.cc @@ -0,0 +1,753 @@ +#include "s2/util/math/exactfloat/exactfloat_internal.h" + +// Threshold for fallback to simple multiplication, determined empirically. +static constexpr int kKaratsubaThreshold = 64; + +// Avoid the dependent name clutter. +using Bigit = typename Bignum::Bigit; + +static Bigit MulAdd( // + absl::Span out, absl::Span a, Bigit b, Bigit c); + +std::optional Bignum::FromString(absl::string_view s) { + // A chunk is up to 19 decimal digits, which can always fit into a Bigit. + constexpr int kMaxChunkDigits = std::numeric_limits::digits10; + + // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the + // sake of simplicity. This isn't the fastest algorithm, being quadratic in + // the number of chunks the input has. If we use divide and conquer approach + // or an FFT based multiply we could probably make this ~O(n^1.5) or + // semi-linear. + + // Precomputed powers of 10. + static const auto kPow10 = []() { + std::array out; + + Bigit value = 1; + for (int i = 0; i < out.size(); ++i) { + out[i] = value; + value = value * 10; + } + return out; + }(); + + Bignum out; + if (s.empty()) { + return out; + } + + // Reserve space for bigits. + out.bigits_.reserve((s.size() + kMaxChunkDigits - 1) / kMaxChunkDigits); + + int sign = +1; + Bigit chunk = 0; + int clen = 0; + + // Finish processing the current chunk. + auto FlushChunk = [&]() { + if (clen) { + auto outspan = absl::MakeSpan(out.bigits_); + if (Bigit carry = MulAdd(outspan, outspan, kPow10[clen], chunk)) { + out.bigits_.emplace_back(carry); + } + chunk = 0; + clen = 0; + } + }; + + // Consume optional +/- at the front. + int start = 0; + if ((s[0] == '+' || s[0] == '-')) { + sign = (s[0] == '-') ? -1 : +1; + ++start; + } + + bool seen_digit = false; + for (char c : s.substr(start)) { + if (!absl::ascii_isdigit(c)) { + return std::nullopt; + } + + // Accumulate digit into the local 64-bit chunk. Skip leading + // zeros. + uint64_t digit = static_cast(c - '0'); + if (!seen_digit && digit == 0) { + continue; + } + seen_digit = true; + + chunk = 10 * chunk + digit; + ++clen; + + if (clen == kMaxChunkDigits) { + FlushChunk(); + } + } + FlushChunk(); + + out.NormalizeSign(sign); + return out; +} + +int bit_width(const Bignum& a) { + ABSL_DCHECK(a.Normalized()); + if (a.empty()) { + return 0; + } + + // Bit width is the bits in the least significant bigits + bit width of + // the most significant word. + const int msw_width = + (Bigit::kBits - absl::countl_zero(a.bigits_.back().value_)); + const int lsw_width = (a.bigits_.size() - 1) * Bigit::kBits; + return msw_width + lsw_width; +} + +int countr_zero(const Bignum& a) { + if (a.is_zero()) { + return 0; + } + + int nzero = 0; + for (Bigit bigit : a.bigits_) { + if (bigit == 0) { + nzero += Bigit::kBits; + } else { + nzero += absl::countr_zero(static_cast(bigit)); + break; + } + } + return nzero; +} + +bool Bignum::Bit(int nbit) const { + ABSL_DCHECK_GE(nbit, 0); + if (is_zero()) { + return false; + } + + const int digit = nbit / Bigit::kBits; + const int shift = nbit % Bigit::kBits; + + if (digit >= size()) { + return false; + } + + return ((bigits_[digit] >> shift) & 0x1) != 0; +} + +Bignum Bignum::operator-() const { + Bignum result = *this; + result.sign_ = -result.sign_; + return result; +} + +Bignum& Bignum::operator<<=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (is_zero() || nbit == 0) { + return *this; + } + + const int nbigit = nbit / Bigit::kBits; + const int nrem = nbit % Bigit::kBits; + + // First, handle the whole-bigit shift by inserting zeros. + bigits_.insert(bigits_.begin(), nbigit, 0); + + // Then, handle the within-bigit shift, if any. + if (nrem != 0) { + Bigit carry = 0; + for (size_t i = 0; i < bigits_.size(); ++i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val << nrem) | carry; + carry = old_val >> (Bigit::kBits - nrem); + } + + if (carry) { + bigits_.push_back(carry); + } + } + + return *this; +} + +Bignum& Bignum::operator>>=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (is_zero() || nbit == 0) { + return *this; + } + + // Shifting by more than the bit width results in zero. + if (nbit >= bit_width(*this)) { + return SetZero(); + } + + const int nbigit = nbit / Bigit::kBits; + const int nrem = nbit % Bigit::kBits; + + // First, handle the whole-bigit shift by removing bigits. + bigits_.erase(bigits_.begin(), bigits_.begin() + nbigit); + + // Then, handle the within-bigit shift, if any. + if (nrem != 0) { + Bigit carry = 0; + for (int i = static_cast(bigits_.size()) - 1; i >= 0; --i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val >> nrem) | carry; + carry = old_val << (Bigit::kBits - nrem); + } + } + + // Result might be smaller or zero, so normalize. + NormalizeSign(sign_); + return *this; +} + +// Raise this value to the given power, which must be non-negative. +Bignum Bignum::Pow(int32_t pow) const { + ABSL_DCHECK_GE(pow, 0); + + // Anything to the zero-th power is 1 (including zero). + if (pow == 0) { + return Bignum(1); + } + + if (is_zero()) { + return Bignum(0); + } + + if (*this == Bignum(1)) { + return Bignum(1); + } + + if (*this == Bignum(-1)) { + return (pow % 2 != 0) ? Bignum(-1) : Bignum(1); + } + + // Core algorithm: Exponentiation by squaring. + Bignum result(1); + Bignum base = *this; // A mutable copy of the base. + uint32_t upow = static_cast(pow); + + while (upow > 0) { + if (upow & 1) { // If current exponent bit is 1, multiply into result. + result *= base; + } + base *= base; + upow >>= 1; + } + + return result; +} + +// Computes a + b + c and updates the carry. +static Bigit AddCarry(Bigit a, Bigit b, Bigit& c) { + auto sum = absl::uint128(a) + b + c; + c = absl::Uint128High64(sum); + return static_cast(sum); +} + +// Computes a - b - c and updates the borrow. +static Bigit SubBorrow(Bigit a, Bigit b, Bigit& borrow) { + Bigit diff = a - b - borrow; + borrow = (a < b) || (borrow && (a == b)); + return diff; +} + +// Computes a * b + c and updates the carry. +static Bigit MulCarry(Bigit a, Bigit b, Bigit& c) { + auto sum = absl::uint128(a) * b + c; + c = absl::Uint128High64(sum); + return static_cast(sum); +} + +// Computes out += a * b + c and updates the carry. +static void MulAddCarry(Bigit& out, Bigit a, Bigit b, Bigit& c) { + auto sum = absl::uint128(a) * b + c + out; + c = absl::Uint128High64(sum); + out = static_cast(sum); +} + +// Computes a += b in place. Returns the final carry (if any). +// NOTE: the a operand must be pre-expanded to fit b. +static Bigit AddInPlace(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + + Bigit* pa = a.data(); + const Bigit* pb = b.data(); + + int left = b.size(); + Bigit carry = 0; + + // Dispatch four at a time to help loop unrolling. + while (left >= 4) { + for (int i = 0; i < 4; ++i) { + *pa = AddCarry(*pa, *pb++, carry); + ++pa; + --left; + } + } + + // Finish remainder. + while (left--) { + *pa = AddCarry(*pa, *pb++, carry); + ++pa; + } + + // Propagate carry through the rest of a. + int remaining = a.size() - b.size(); + while (carry && remaining--) { + *pa = AddCarry(*pa, 0, carry); + ++pa; + } + + return carry; +} + +static ssize_t AddInto( // + absl::Span dst, absl::Span a, + absl::Span b) { + const size_t max_size = std::max(a.size(), b.size()); + const size_t min_size = std::min(a.size(), b.size()); + ABSL_DCHECK_GE(dst.size(), max_size + 1); + + Bigit* pdst = dst.data(); + const Bigit* pa = a.data(); + const Bigit* pb = b.data(); + + // Add common parts. + Bigit carry = 0; + + // Dispatch four at a time to help loop unrolling. + int size = min_size; + int i = 0; + while (size >= i + 4) { + for (int j = 0; j < 4; ++j) { + pdst[i] = AddCarry(pa[i], pb[i], carry); + ++i; + } + } + + // Finish remainder of common parts. + for (; i < size; ++i) { + pdst[i] = AddCarry(pa[i], pb[i], carry); + } + + // Copy remaining digits from the longer operand and propagate carry. + auto longer = (a.size() > b.size()) ? a : b; + const Bigit* plonger = (a.size() > b.size()) ? pa : pb; + + // Dispatch four at a time for the remaining part. + size = longer.size(); + while (size >= i + 4) { + for (int j = 0; j < 4; ++j) { + pdst[i] = AddCarry(plonger[i], 0, carry); + ++i; + } + } + + // Finish remainder. + for (; i < size; ++i) { + pdst[i] = AddCarry(plonger[i], 0, carry); + } + + if (carry) { + pdst[i++] = carry; + return max_size + 1; + } + return max_size; +} + +// Computes a -= b. Returns the final borrow (if any). +// +// REQUIRES: |a| < |b|. +// NOTE: A must be pre-expanded to match the size of b. +static Bigit SubLtIp( // + absl::Span a, absl::Span b, ssize_t na) { + ABSL_DCHECK_EQ(a.size(), b.size()); + + Bigit* pa = a.data(); + const Bigit* pb = b.data(); + Bigit borrow = 0; + + // Dispatch four at a time to help loop unrolling. + int size = na; + int i = 0; + while (size >= i + 4) { + for (int j = 0; j < 4; ++j) { + pa[i] = SubBorrow(pb[i], pa[i], borrow); + ++i; + } + } + + // Finish remainder. + for (; i < na; ++i) { + pa[i] = SubBorrow(pb[i], pa[i], borrow); + } + + // Propagate borrow through the rest of b. + for (; borrow && i < b.size(); ++i) { + pa[i] = SubBorrow(pb[i], 0, borrow); + } + return borrow; +} + +// Computes a -= b. Returns the final borrow (if any). +// +// REQUIRES: |a| >= |b|. +static Bigit SubGeIp(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + + Bigit borrow = 0; + + Bigit* pa = a.data(); + const Bigit* pb = b.data(); + + // Dispatch four at a time to help loop unrolling. + int size = b.size(); + int done = 0; + while (size >= done + 4) { + for (int i = 0; i < 4; ++i) { + pa[done] = SubBorrow(pa[done], pb[done], borrow); + ++done; + } + } + + // Finish remainder of subtraction. + for (; done < size; ++done) { + pa[done] = SubBorrow(pa[done], pb[done], borrow); + } + + // Propagate the borrow through a. + for (; borrow && done < a.size(); ++done) { + borrow = (a[done] == 0); + a[done]--; + } + return borrow; +} + +// Computes out[i] = a[i]*b + c +// +// Returns the final carry, if any. +Bigit MulAdd( // + absl::Span out, absl::Span a, Bigit b, Bigit c = 0) { + ABSL_DCHECK_GE(out.size(), a.size()); + + Bigit* pout = out.data(); + const Bigit* pa = a.data(); + + int left = a.size(); + + // Dispatch four at a time to help loop unrolling. + while (left >= 4) { + for (int i = 0; i < 4; ++i) { + *pout++ = MulCarry(*pa++, b, c); + --left; + } + } + + while (left--) { + *pout++ = MulCarry(*pa++, b, c); + } + return c; +} + +// Computes out[i] += a[i]*b in place. +// +// Returns the final carry, if any. +static Bigit MulAddIp( // + absl::Span out, absl::Span a, Bigit b) { + Bigit* pout = out.data(); + const Bigit* pa = a.data(); + + int left = a.size(); + + // Dispatch four at a time to help loop unrolling. + Bigit carry = 0; + while (left >= 4) { + for (int i = 0; i < 4; ++i) { + MulAddCarry(*pout++, *pa++, b, carry); + --left; + } + } + + // Finish remainder. + while (left--) { + MulAddCarry(*pout++, *pa++, b, carry); + } + + return carry; +} + +static void MulQuadratic( // + absl::Span out, // + absl::Span a, absl::Span b) { + ABSL_DCHECK_EQ(out.size(), a.size() + b.size()); + + // Make sure A is the longer of the two arguments. + if (a.size() < b.size()) { + using std::swap; + swap(a, b); + } + + if (b.empty()) { + absl::c_fill(out, 0); + return; + } + + auto upper = out.subspan(a.size()); + upper[0] = MulAdd(out, a, b[0]); + + const int size = b.size(); + int i = 1; + while (size >= i + 4) { + for (int j = 0; j < 4; ++j) { + upper[i] = MulAddIp(out.subspan(i), a, b[i]); + ++i; + } + } + + // Finish remainder (if any). + for (; i < size; ++i) { + upper[i] = MulAddIp(out.subspan(i), a, b[i]); + } + + // Finish zeroing out upper half. + for (; i < upper.size(); ++i) { + upper[i] = 0; + } +} + +// Split a span into two contiguous pieces of length a and b, respectively. +template +static std::pair, absl::Span> Split( // + absl::Span span, int a, int b) { + return {span.subspan(0, a), span.subspan(a, b)}; +}; + +// A simple bump allocator to avoid allocating memory during recursion. +class Arena { + public: + explicit Arena(ssize_t size) { data_.reserve(size); } + + // Allocates a span of length n from the arena. + absl::Span Alloc(ssize_t n) { + ABSL_DCHECK_LE(used_ + n, data_.capacity()); + size_t start = used_; + used_ += n; + return absl::Span(data_.data() + start, n); + } + + void Release(ssize_t n) { + ABSL_DCHECK_LE(n, used_); + used_ -= n; + } + + private: + ssize_t used_ = 0; + absl::InlinedVector data_; +}; + +static void KaratsubaMulRec( // + absl::Span dst, // + absl::Span a, absl::Span b, Arena& arena) { + ABSL_DCHECK_EQ(dst.size(), a.size() + b.size()); + if (a.empty() || b.empty()) { + return; + } + + // Karatsuba lets us represent two numbers, A and B thusly: + // A = a1*10^M + a0 + // B = b1*10^M + b0 + // + // Which we can multiply out: + // AB = (a1*10^M + a0)*(b1*10^M + b0); + // = a1*b1*10^(2M) + (a1*b0 + a0*b1)*10^M + a0*b0 + // = z2 * 10^2M + z1*10^M + z0 + // + // Where: + // z0 = a0*b0 + // z1 = a1*b0 + a0*b1 + // z2 = a1*b1 + // + // We can replace the multiplications in z1 by computing: + // + // z3 = (a0 + a1)*(b0 + b1) + // + // And noting z1 = z3 - z2 - z0 + // + // This lets us compute a 2M digit multiply with three M digit multiplies, + // with those individual multiplies able to be recursively divided. + + // Fall back to long multiplication when we're small enough. + if (dst.size() < kKaratsubaThreshold) { + MulQuadratic(dst, a, b); + return; + } + + const int half = (std::min(a.size(), b.size()) + 1) / 2; + + // Split the inputs into contiguous subspans. + auto [a0, a1] = Split(a, half, half); + auto [b0, b1] = Split(b, half, half); + + // Make space to hold results in the output and multiply sub-terms. + // z0 = a0 * b0 + // z2 = a1 * b1 + auto [z0, z2] = Split(dst, 2 * half, 2 * half); + + KaratsubaMulRec(z0, a0, b0, arena); + KaratsubaMulRec(z2, a1, b1, arena); + + // Compute (a0 + a1) and (b0 + b1) using space from the arena. + // + // The sums may or may not carry. We pop the extra bigit off if they + // don't. + auto sa = arena.Alloc(half + 1); + auto sb = arena.Alloc(half + 1); + sa = sa.first(AddInto(sa, a0, a1)); + sb = sb.first(AddInto(sb, b0, b1)); + + // Compute z1 = sa*sb - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 + auto z1 = arena.Alloc(sa.size() + sb.size()); + + // Compute sa * sb into the beginning of z1 + KaratsubaMulRec(z1, sa, sb, arena); + + // NOTE: (a0 + a1) * (b0 + b1) >= a0*b0 + a1*b1 so this never underflows. + SubGeIp(z1, z0); + SubGeIp(z1, z2); + + // We need to add z1*10^half which we can do by adding it offset. + AddInPlace(dst.subspan(half), z1); + + // Release temporary memory we used. + arena.Release(z1.size() + sb.size() + sa.size()); +} + +Bignum::BigitVector Bignum::KaratsubaMul( // + absl::Span a, absl::Span b) { + if (a.empty() || b.empty()) { + return {}; + } + + // Each step of Karatsuba splits at: + // N = std::ceil(std::min(a.size(), b.size())/2) + // + // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. + // + // Simulate the recursion (log(n) steps) and compute the arena size. + int size = a.size() + b.size(); + int peak = 0; + do { + int half = (size + 1) / 2; + int next = half + 1; + peak += 4 * next; + size = next; + } while (size > kKaratsubaThreshold); + + Arena arena(peak); + BigitVector out(a.size() + b.size(), 0); + KaratsubaMulRec(absl::MakeSpan(out), a, b, arena); + return out; +} + +Bignum& Bignum::operator+=(const Bignum& b) { + if (b.is_zero()) { + return *this; + } + + if (is_zero()) { + *this = b; + return *this; + } + + if (sign_ == b.sign_) { + // Same sign: + // (+a) + (+b) == +(a + b) + // (-a) + (-b) == -(a + b) + bigits_.resize(std::max(size(), b.size()), 0); + Bigit carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); + if (carry) { + bigits_.emplace_back(carry); + } + Normalize(); + } else { + if (CmpAbs(b) >= 0) { + // |a| >= |b|, so a - b is same sign as a. + SubGeIp(absl::MakeSpan(bigits_), b.bigits_); + NormalizeSign(sign_); + } else { + // |a| < |b|, so a - b is same sign as b. + const int prev_size = size(); + bigits_.resize(b.size()); + SubLtIp(absl::MakeSpan(bigits_), b.bigits_, prev_size); + NormalizeSign(b.sign_); + } + } + + return *this; +} + +Bignum& Bignum::operator-=(const Bignum& b) { + if (this == &b) { + return SetZero(); + } + + if (b.is_zero()) { + return *this; + } + + if (is_zero()) { + return *this = -b; + } + + if (sign_ != b.sign_) { + bigits_.resize(std::max(size(), b.size()), 0); + uint64_t carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); + if (carry) { + bigits_.emplace_back(carry); + } + Normalize(); + } else { + if (CmpAbs(b) >= 0) { + SubGeIp(absl::MakeSpan(bigits_), b.bigits_); + NormalizeSign(sign_); + } else { + const int prev_size = size(); + bigits_.resize(b.size()); + SubLtIp(absl::MakeSpan(bigits_), b.bigits_, prev_size); + NormalizeSign(-sign_); + } + } + + return *this; +} + +Bignum& Bignum::operator*=(const Bignum& b) { + if (is_zero() || b.is_zero()) { + return SetZero(); + } + + const int new_sign = sign_ * b.sign_; + + // Fast path for single-bigit multiplication. + if (size() == 1 && b.size() == 1) { + absl::uint128 prod = absl::uint128(bigits_[0]) * b.bigits_[0]; + const uint64_t lo = absl::Uint128Low64(prod); + const uint64_t hi = absl::Uint128High64(prod); + if (hi == 0) { + bigits_ = {lo}; + } else { + bigits_ = {lo, hi}; + } + sign_ = new_sign; + return *this; + } + + // Use Karatsuba multiplication. + // If the inputs are small enough this will just do long multiplication. + bigits_ = KaratsubaMul(bigits_, b.bigits_); + NormalizeSign(new_sign); + return *this; +} diff --git a/src/s2/util/math/exactfloat/exactfloat_internal.h b/src/s2/util/math/exactfloat/exactfloat_internal.h new file mode 100644 index 00000000..596f6f4e --- /dev/null +++ b/src/s2/util/math/exactfloat/exactfloat_internal.h @@ -0,0 +1,489 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" + +// A class to support arithmetic on large, arbitrary precision integers. +// +// Large integers are represented as an array of uint64_t values. +class Bignum { + public: + // Wrap uint64_t in a struct so we can avoid default initialization. + // + // Avoiding default initialization overhead saves us 50% on some benchmarks. + struct Bigit { + static constexpr int kBits = std::numeric_limits::digits; + + Bigit() {} + constexpr Bigit(uint64_t value) : value_(value) {} + explicit Bigit(absl::uint128 value) : value_(absl::Uint128Low64(value)) {} + + constexpr operator uint64_t() const { return value_; } + constexpr Bigit& operator=(uint64_t value) { + value_ = value; + return *this; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr Bigit& operator--(int) { + value_--; + return *this; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr friend Bigit operator*( // + int a, Bigit b) { + return a * b.value_; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE friend absl::uint128 operator*( // + absl::uint128 a, Bigit b) { + return a * b.value_; + } + + ABSL_ATTRIBUTE_ALWAYS_INLINE friend absl::uint128 operator+( // + absl::uint128 a, Bigit b) { + return a + b.value_; + } + + uint64_t value_; + }; + + using BigitVector = absl::InlinedVector; + + Bignum() = default; + + // Constructs a bignum from an integral value (signed or unsigned). + template ::is_integer>> + explicit Bignum(T value); + + //-------------------------------------- + // String formatting and parsing. + //-------------------------------------- + + // Constructs a bignum from an ASCII string containing decimal digits. + // + // The input string must only have an optional leading +/- and decimal digits. + // Any other characters will yield std::nullopt. + static std::optional FromString(absl::string_view s); + + // Formats the bignum as a decimal integer into an abseil sink. + template + friend void AbslStringify(Sink& sink, const Bignum& b); + + friend std::ostream& operator<<(std::ostream& os, const Bignum& b) { + return os << absl::StrFormat("%v", b); + } + + friend std::ostream& operator<<( // + std::ostream& os, const std::optional& b) { + if (!b) { + return os << "[nullopt]"; + } + return os << *b; + } + + //-------------------------------------- + // Casting and coercion. + //-------------------------------------- + + // Returns true if bignum can be stored in T without truncation or overflow. + template >> + bool FitsIn() const; + + // Cast to an integral type T unconditionally. Use FitsIn() or + // ConvertTo to perform conversion with bounds checking. + template >> + T Cast() const; + + // Casts the value to the given type if it fits, otherwise std::nullopt. + template >> + std::optional ConvertTo() const { + if (!FitsIn()) { + return std::nullopt; + } + return Cast(); + } + + //-------------------------------------- + // General accessors. + //-------------------------------------- + + // Returns the number of bits required to represent the bignum. + friend int bit_width(const Bignum& a); + + // Returns the number of consecutive 0 bits in the value, starting from the + // least significant bit. + friend int countr_zero(const Bignum& a); + + // Returns true if the n-th bit of the number's magnitude is set. + bool Bit(int nbit) const; + + // Clears this bignum and sets it to zero. + Bignum& SetZero() { + sign_ = 0; + bigits_.clear(); + return *this; + } + + // Unconditionally makes the sign of this bignum negative. + Bignum& SetNegative() { + sign_ = -1; + return *this; + } + + // Unconditionally makes the sign of this bignum positive. + Bignum& SetPositive() { + sign_ = +1; + return *this; + } + + // Unconditionally set the sign of this bignum to match the sign of the + // argument. If the argument is zero, set the bignum to zero. + Bignum& SetSign(int sign) { + if (sign == 0) { + return SetZero(); + } + + if (sign < 0) { + return SetNegative(); + } + return SetPositive(); + } + + // Returns true if the number is zero. + bool is_zero() const { // + return sign_ == 0; + } + + // Returns true if the number is greater than zero. + bool positive() const { // + return sign_ > 0; + } + + // Returns true if the number is less than zero. + bool negative() const { // + return sign_ < 0; + } + + // Returns true if the number is odd (least significant bit is 1). + bool is_odd() const { return Bit(0); } + + // Returns true if the number is even (least significant bit is 0). + bool is_even() const { return !is_odd(); } + + //-------------------------------------- + // Comparisons. + //-------------------------------------- + + bool operator==(const Bignum& b) const { + return sign_ == b.sign_ && bigits_ == b.bigits_; + } + + bool operator!=(const Bignum& b) const { return !(*this == b); } + bool operator<(const Bignum& b) const { return Compare(b) < 0; } + bool operator<=(const Bignum& b) const { return Compare(b) <= 0; } + bool operator>(const Bignum& b) const { return Compare(b) > 0; } + bool operator>=(const Bignum& b) const { return Compare(b) >= 0; } + + //-------------------------------------- + // Arithmetic operators. + //-------------------------------------- + + Bignum operator+() const { return *this; } + Bignum operator-() const; + + // Raise this value to the given power, which must be non-negative. + Bignum Pow(int32_t pow) const; + + Bignum& operator+=(const Bignum& b); + Bignum& operator-=(const Bignum& b); + Bignum& operator*=(const Bignum& b); + Bignum& operator<<=(int nbit); + Bignum& operator>>=(int nbit); + + friend Bignum operator+(Bignum a, const Bignum& b) { return a += b; } + friend Bignum operator-(Bignum a, const Bignum& b) { return a -= b; } + friend Bignum operator*(Bignum a, const Bignum& b) { return a *= b; } + friend Bignum operator<<(Bignum a, int nbit) { return a <<= nbit; } + friend Bignum operator>>(Bignum a, int nbit) { return a >>= nbit; } + + private: + // Constructs a Bignum from bigits and an optional sign bit. + explicit Bignum(BigitVector bigits, int sign = +1) + : bigits_(std::move(bigits)) { + NormalizeSign(sign); + } + + // Returns the number of bigits in this bignum. + size_t size() const { // + return bigits_.size(); + } + + // Returns true if this value has no digits. + bool empty() const { // + return bigits_.empty(); + } + + // Compare to another bignum, returns -1, 0, +1. + int Compare(const Bignum& b) const { + if (sign_ != b.sign_) { + return sign_ < b.sign_ ? -1 : 1; + } + + // Signs are equal, are they both zero? + if (sign_ == 0) { + return 0; + } + + // Signs are equal and non-zero, compare magnitude. + return positive() ? CmpAbs(b) : -CmpAbs(b); + } + + // Multiplies two unsigned bigit vectors together using Karatsuba's algorithm. + static BigitVector KaratsubaMul( // + absl::Span a, absl::Span b); + + // Drop leading zero bigits. + void Normalize() { + while (!empty() && bigits_.back() == 0) { + bigits_.pop_back(); + } + + if (empty()) { + sign_ = 0; + } + } + + // Drop leading zero bigits and canonicalize sign. + void NormalizeSign(int sign) { + Normalize(); + sign_ = empty() ? 0 : sign; + } + + // Returns true if the bignum is in normal form (no extra leading zeros). + bool Normalized() const { // + return bigits_.empty() || bigits_.back() != 0; + } + + // Returns true if the bignum is the given power of two. + bool IsPow2(int pow2) const { + const int bigits = pow2 / Bigit::kBits; + if (bigits_.size() != bigits + 1) { + return false; + } + + // Verify lower words are zero. + for (int i = 0; i < bigits; ++i) { + if (bigits_[i] != 0) { + return false; + } + } + + // Check final word is power of two. + pow2 -= bigits * Bigit::kBits; + ABSL_DCHECK_LT(pow2, Bigit::kBits); + return bigits_.back() == (Bigit(1) << pow2); + } + + // Compares magnitude with another bignum, returning -1, 0, or +1. + int CmpAbs(const Bignum& b) const; + + BigitVector bigits_; + int sign_ = 0; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation Details +//////////////////////////////////////////////////////////////////////////////// + +// Constructs a bignum from an integral value (signed or unsigned). +template +Bignum::Bignum(T value) { + using UT = std::make_unsigned_t; + + if (value == 0) { + return; + } + + sign_ = +1; + if constexpr (std::is_signed_v) { + sign_ = (value < 0) ? -1 : +1; + } + + // Get magnitude of value, handle minimum value of T cleanly. + UT mag = static_cast(value); + if constexpr (std::is_signed_v) { + if (value < 0) { + mag = UT(0) - mag; + } + } + + // Pack the magnitude into bigits. + if constexpr (std::numeric_limits::digits <= Bigit::kBits) { + bigits_.push_back(static_cast(mag)); + } else { + while (mag) { + bigits_.push_back(static_cast(mag)); + mag >>= Bigit::kBits; + } + } +} + +// Formats the bignum as a decimal integer into an abseil sink. +template +void AbslStringify(Sink& sink, const Bignum& b) { + if (b.is_zero()) { + sink.Append("0"); + return; + } + + // Sign + if (b.negative()) { + sink.Append("-"); + } + + // Work on a copy of the magnitude. + Bignum copy = b; + copy.sign_ = 1; + + // Repeatedly divide and modulo by 10^19 to get decimal chunks. + static constexpr uint64_t kBase = 10000000000000000000ull; + Bignum::BigitVector chunks; + + while (!copy.is_zero()) { + absl::uint128 rem = 0; + for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { + absl::uint128 acc = (rem << 64) + copy.bigits_[i]; + Bignum::Bigit quot = static_cast(acc / kBase); + rem = acc - absl::uint128(quot) * kBase; + copy.bigits_[i] = quot; + } + + copy.Normalize(); + chunks.push_back(static_cast(rem)); + } + ABSL_DCHECK(!chunks.empty()); + + // Emit most significant chunk without zero padding. + absl::Format(&sink, "%d", chunks.back()); + + // Emit remaining chunks as fixed-width 19-digit zero-padded blocks. + for (int i = static_cast(chunks.size()) - 2; i >= 0; --i) { + absl::Format(&sink, "%019d", chunks[i]); + } +} + +template +inline bool Bignum::FitsIn() const { + using UT = std::make_unsigned_t; + + if (sign_ == 0) { + return true; + } + + // Maximum number of bits that could fit in the output type. + constexpr int kTBitWidth = std::numeric_limits::digits; + constexpr int kMaxBigits = (kTBitWidth + (Bigit::kBits - 1)) / Bigit::kBits; + + // Fast reject if the bignum couldn't conceivably fit. + if (bigits_.size() > kMaxBigits) { + return false; + } + + // Unsigned type T can hold the value iff the value is non-negative and + // the bitwidth is <= the maximum bit width of the type. + if constexpr (!std::is_signed_v) { + if (negative()) { + return false; + } + return bit_width(*this) <= kTBitWidth; + } + + // T is signed and our bignum isn't zero. + ABSL_DCHECK(std::is_signed_v && !is_zero()); + + if (positive()) { + return bit_width(*this) <= (kTBitWidth - 1); + } else /* negative() */ { + // Magnitude must fit in negative value. If the value is negative and + // the same bit width as the output type, the only valid value is + // -2^(k-1). + if (bit_width(*this) == kTBitWidth) { + return IsPow2(kTBitWidth - 1); + } + return bit_width(*this) < kTBitWidth; + } +} + +template +T Bignum::Cast() const { + using UT = std::make_unsigned_t; + + constexpr int kTBitWidth = std::numeric_limits::digits; + + if (empty()) { + return 0; + } + + // Grab the bottom bits into an unsigned value. + UT residue = 0; + for (size_t i = 0; i < bigits_.size(); ++i) { + const int shift = i * Bigit::kBits; + if (shift >= kTBitWidth) { + break; + } + + const int room = kTBitWidth - shift; + UT chunk = static_cast(bigits_[i]); + if (room < Bigit::kBits && room < std::numeric_limits::digits) { + chunk &= (UT(1) << room) - UT(1); + } + residue |= (chunk << shift); + } + + // Compute two's complement of the residue if value is negative. + if (negative()) { + residue = UT(0) - residue; + } + + return static_cast(residue); +} + +inline int Bignum::CmpAbs(const Bignum& b) const { + if (size() != b.size()) { + return size() < b.size() ? -1 : +1; + } + + for (int i = size() - 1; i >= 0; --i) { + if (bigits_[i] != b.bigits_[i]) { + return bigits_[i] < b.bigits_[i] ? -1 : +1; + } + } + + return 0; +} diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/exactfloat_internal_test.cc similarity index 77% rename from src/s2/util/math/exactfloat/bignum_test.cc rename to src/s2/util/math/exactfloat/exactfloat_internal_test.cc index f6838d89..7480b542 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/exactfloat_internal_test.cc @@ -13,22 +13,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "s2/util/math/exactfloat/bignum.h" +#include "s2/util/math/exactfloat/exactfloat_internal.h" #include #include #include #include -#if 0 #include "absl/base/no_destructor.h" #include "absl/strings/string_view.h" #include "benchmark/benchmark.h" +#include "gtest/gtest.h" #include "openssl/bn.h" #include "openssl/crypto.h" -#endif - -#include "gtest/gtest.h" const uint64_t u8max = std::numeric_limits::max(); const uint64_t u16max = std::numeric_limits::max(); @@ -111,18 +108,16 @@ TEST(BignumTest, SignInWrongPlaceCausesFailure) { EXPECT_EQ(Bn("+314-"), std::nullopt); } -TEST(BignumTest, ZeroAlwaysCompatible) { +TEST(BignumTest, ZeroAlwaysFitsIn) { const Bignum zero(0); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); - EXPECT_TRUE(zero.Compatible()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); } TEST(BignumTest, ZeroAlwaysCastsToZero) { @@ -137,175 +132,141 @@ TEST(BignumTest, ZeroAlwaysCastsToZero) { EXPECT_EQ(zero.Cast(), 0); } -TEST(BignumTest, NegativeOnlyCompatibleSigned) { +TEST(BignumTest, NegativeOnlyFitsInSigned) { const Bignum small_neg(-1); - EXPECT_FALSE(small_neg.Compatible()); - EXPECT_FALSE(small_neg.Compatible()); - EXPECT_FALSE(small_neg.Compatible()); - EXPECT_FALSE(small_neg.Compatible()); - EXPECT_FALSE(small_neg.Compatible()); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); - EXPECT_TRUE(small_neg.Compatible()); - EXPECT_TRUE(small_neg.Compatible()); - EXPECT_TRUE(small_neg.Compatible()); - EXPECT_TRUE(small_neg.Compatible()); - EXPECT_TRUE(small_neg.Compatible()); + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); } -TEST(BignumTest, CompatibleUnsignedBoundsChecks) { +TEST(BignumTest, FitsInUnsignedBoundsChecks) { const Bignum bn_u8max(u8max); const Bignum bn_u8over(u8max + 1); - EXPECT_TRUE(bn_u8max.Compatible()); - EXPECT_TRUE(bn_u8max.Compatible()); - EXPECT_TRUE(bn_u8max.Compatible()); - EXPECT_FALSE(bn_u8over.Compatible()); - EXPECT_TRUE(bn_u8over.Compatible()); - EXPECT_TRUE(bn_u8over.Compatible()); + EXPECT_TRUE(bn_u8max.FitsIn()); + EXPECT_TRUE(bn_u8max.FitsIn()); + EXPECT_TRUE(bn_u8max.FitsIn()); + EXPECT_FALSE(bn_u8over.FitsIn()); + EXPECT_TRUE(bn_u8over.FitsIn()); + EXPECT_TRUE(bn_u8over.FitsIn()); const Bignum bn_u16max(u16max); const Bignum bn_u16over(u16max + 1); - EXPECT_FALSE(bn_u16max.Compatible()); - EXPECT_TRUE(bn_u16max.Compatible()); - EXPECT_TRUE(bn_u16max.Compatible()); - EXPECT_FALSE(bn_u16over.Compatible()); - EXPECT_FALSE(bn_u16over.Compatible()); - EXPECT_TRUE(bn_u16over.Compatible()); + EXPECT_FALSE(bn_u16max.FitsIn()); + EXPECT_TRUE(bn_u16max.FitsIn()); + EXPECT_TRUE(bn_u16max.FitsIn()); + EXPECT_FALSE(bn_u16over.FitsIn()); + EXPECT_FALSE(bn_u16over.FitsIn()); + EXPECT_TRUE(bn_u16over.FitsIn()); const Bignum bn_u32max(u32max); const Bignum bn_u32over(u32max + 1); - EXPECT_FALSE(bn_u32max.Compatible()); - EXPECT_FALSE(bn_u32max.Compatible()); - EXPECT_TRUE(bn_u32max.Compatible()); - EXPECT_FALSE(bn_u32over.Compatible()); - EXPECT_FALSE(bn_u32over.Compatible()); - EXPECT_FALSE(bn_u32over.Compatible()); + EXPECT_FALSE(bn_u32max.FitsIn()); + EXPECT_FALSE(bn_u32max.FitsIn()); + EXPECT_TRUE(bn_u32max.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); const Bignum bn_u64max(u64max); - EXPECT_TRUE(bn_u64max.Compatible()); + EXPECT_TRUE(bn_u64max.FitsIn()); // 2^64, need to use string constructor. Bignum bn0 = *Bn("18446744073709551616"); - EXPECT_FALSE(bn0.Compatible()); - EXPECT_TRUE(bn0.Compatible()); - - // (2^128 - 1) fits in absl::uint128. - Bignum bn1 = *Bn("340282366920938463463374607431768211455"); - EXPECT_TRUE(bn1.Compatible()); - - // 2^128 does not fit in absl::uint128. - Bignum bn2 = *Bn("340282366920938463463374607431768211456"); - EXPECT_FALSE(bn2.Compatible()); + EXPECT_FALSE(bn0.FitsIn()); } -TEST(BignumTest, CompatibleSignedBoundsChecks) { +TEST(BignumTest, FitsInSignedBoundsChecks) { const Bignum bn_i8max(i8max); const Bignum bn_i8over(i8max + 1); - EXPECT_TRUE(bn_i8max.Compatible()); - EXPECT_TRUE(bn_i8max.Compatible()); - EXPECT_TRUE(bn_i8max.Compatible()); - EXPECT_FALSE(bn_i8over.Compatible()); - EXPECT_TRUE(bn_i8over.Compatible()); - EXPECT_TRUE(bn_i8over.Compatible()); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_FALSE(bn_i8over.FitsIn()); + EXPECT_TRUE(bn_i8over.FitsIn()); + EXPECT_TRUE(bn_i8over.FitsIn()); const Bignum bn_i16max(i16max); const Bignum bn_i16over(i16max + 1); - EXPECT_FALSE(bn_i16max.Compatible()); - EXPECT_TRUE(bn_i16max.Compatible()); - EXPECT_TRUE(bn_i16max.Compatible()); - EXPECT_FALSE(bn_i16over.Compatible()); - EXPECT_FALSE(bn_i16over.Compatible()); - EXPECT_TRUE(bn_i16over.Compatible()); + EXPECT_FALSE(bn_i16max.FitsIn()); + EXPECT_TRUE(bn_i16max.FitsIn()); + EXPECT_TRUE(bn_i16max.FitsIn()); + EXPECT_FALSE(bn_i16over.FitsIn()); + EXPECT_FALSE(bn_i16over.FitsIn()); + EXPECT_TRUE(bn_i16over.FitsIn()); const Bignum bn_i32max(i32max); const Bignum bn_i32over(i32max + 1); - EXPECT_FALSE(bn_i32max.Compatible()); - EXPECT_FALSE(bn_i32max.Compatible()); - EXPECT_TRUE(bn_i32max.Compatible()); - EXPECT_FALSE(bn_i32over.Compatible()); - EXPECT_FALSE(bn_i32over.Compatible()); - EXPECT_FALSE(bn_i32over.Compatible()); + EXPECT_FALSE(bn_i32max.FitsIn()); + EXPECT_FALSE(bn_i32max.FitsIn()); + EXPECT_TRUE(bn_i32max.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); Bignum bn_i64max(i64max); - EXPECT_TRUE(bn_i64max.Compatible()); + EXPECT_TRUE(bn_i64max.FitsIn()); // 2^63, need to use string constructor. Bignum bn0 = *Bn("9223372036854775808"); - EXPECT_FALSE(bn0.Compatible()); + EXPECT_FALSE(bn0.FitsIn()); const Bignum bn_i8min(i8min); const Bignum bn_i8under(i8min - 1); - EXPECT_TRUE(bn_i8min.Compatible()); - EXPECT_TRUE(bn_i8min.Compatible()); - EXPECT_TRUE(bn_i8min.Compatible()); - EXPECT_FALSE(bn_i8under.Compatible()); - EXPECT_TRUE(bn_i8under.Compatible()); - EXPECT_TRUE(bn_i8under.Compatible()); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_FALSE(bn_i8under.FitsIn()); + EXPECT_TRUE(bn_i8under.FitsIn()); + EXPECT_TRUE(bn_i8under.FitsIn()); const Bignum bn_i16min(i16min); const Bignum bn_i16under(i16min - 1); - EXPECT_FALSE(bn_i16min.Compatible()); - EXPECT_TRUE(bn_i16min.Compatible()); - EXPECT_TRUE(bn_i16min.Compatible()); - EXPECT_FALSE(bn_i16under.Compatible()); - EXPECT_FALSE(bn_i16under.Compatible()); - EXPECT_TRUE(bn_i16under.Compatible()); + EXPECT_FALSE(bn_i16min.FitsIn()); + EXPECT_TRUE(bn_i16min.FitsIn()); + EXPECT_TRUE(bn_i16min.FitsIn()); + EXPECT_FALSE(bn_i16under.FitsIn()); + EXPECT_FALSE(bn_i16under.FitsIn()); + EXPECT_TRUE(bn_i16under.FitsIn()); const Bignum bn_i32min(i32min); const Bignum bn_i32under(i32min - 1); - EXPECT_FALSE(bn_i32min.Compatible()); - EXPECT_FALSE(bn_i32min.Compatible()); - EXPECT_TRUE(bn_i32min.Compatible()); - EXPECT_FALSE(bn_i32under.Compatible()); - EXPECT_FALSE(bn_i32under.Compatible()); - EXPECT_FALSE(bn_i32under.Compatible()); + EXPECT_FALSE(bn_i32min.FitsIn()); + EXPECT_FALSE(bn_i32min.FitsIn()); + EXPECT_TRUE(bn_i32min.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); Bignum bn_i64min(i64min); - EXPECT_TRUE(bn_i64min.Compatible()); - - // -(2^63) - 1 doesn't fit in int64_t. - Bignum b0 = *Bn("-9223372036854775809"); - EXPECT_FALSE(b0.Compatible()); - - // Exact min and max of signed 128. - Bignum bn_s128min = *Bn("-170141183460469231731687303715884105728"); - Bignum bn_s128max = *Bn("170141183460469231731687303715884105727"); - EXPECT_TRUE(bn_s128min.Compatible()); - EXPECT_TRUE(bn_s128max.Compatible()); - - // +2^127 does not fit in signed 128, but does in unsigned 128. - Bignum bn1 = *Bn("170141183460469231731687303715884105728"); - EXPECT_FALSE(bn1.Compatible()); - EXPECT_TRUE(bn1.Compatible()); - - // Below min: -(2^127) - 1 should not fit. - Bignum bn2 = *Bn("-170141183460469231731687303715884105729"); - EXPECT_FALSE(bn2.Compatible()); + EXPECT_TRUE(bn_i64min.FitsIn()); } -TEST(BignumTest, CompatibleBasicSanityChecks) { +TEST(BignumTest, FitsInBasicSanityChecks) { Bignum pos42(42); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); - EXPECT_TRUE(pos42.Compatible()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); Bignum neg42(-42); - EXPECT_TRUE(neg42.Compatible()); - EXPECT_FALSE(neg42.Compatible()); - EXPECT_TRUE(neg42.Compatible()); - EXPECT_FALSE(neg42.Compatible()); - EXPECT_TRUE(neg42.Compatible()); - EXPECT_FALSE(neg42.Compatible()); - EXPECT_TRUE(neg42.Compatible()); - EXPECT_FALSE(neg42.Compatible()); - EXPECT_TRUE(neg42.Compatible()); - EXPECT_FALSE(neg42.Compatible()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); } TEST(BignumTest, UnsignedCasting) { @@ -373,42 +334,6 @@ TEST(BignumTest, CastingLargeResidues) { EXPECT_EQ(bn1.Cast(), -1); } -TEST(BignumTest, AbslUint128Casting) { - Bignum neg1(-1); - EXPECT_EQ(neg1.Cast(), ~absl::uint128(0)); - - // 2^128 -> low 128 bits == 0 - Bignum bn1 = *Bn("340282366920938463463374607431768211456"); - EXPECT_EQ(bn1.Cast(), absl::uint128(0)); - - // 2^200 + 5 -> low 128 bits == 5 - Bignum bn2 = - *Bn("1606938044258990275541962092341162602522202993782792835301381"); - EXPECT_EQ(bn2.Cast(), absl::uint128(5)); -} - -TEST(BignumTest, AbslInt128Casting) { - const absl::int128 two127 = absl::int128(1) << 127; - - // +2^127 -> wraps to -2^127 - Bignum bn0 = *Bn("170141183460469231731687303715884105728"); - EXPECT_EQ(bn0.Cast(), 0 - two127); - - // -(2^127) - 1 -> wraps to +2^127 - 1 - Bignum bn1 = *Bn("-170141183460469231731687303715884105729"); - EXPECT_EQ(bn1.Cast(), two127 - 1); - - // 2^200 + 5 -> low 128 bits == 5 - Bignum bn2 = - *Bn("1606938044258990275541962092341162602522202993782792835301381"); - EXPECT_EQ(bn2.Cast(), absl::int128(5)); - - // -(2^200 + 5) -> low 128 bits == -5 - Bignum bn3 = - *Bn("-1606938044258990275541962092341162602522202993782792835301381"); - EXPECT_EQ(bn3.Cast(), absl::int128(-5)); -} - TEST(BignumTest, UnaryOperators) { EXPECT_EQ(+Bignum(0), Bignum(0)); EXPECT_EQ(-Bignum(0), Bignum(0)); @@ -610,24 +535,24 @@ TEST(BignumTest, Multiplication) { } TEST(BignumTest, CountrZero) { - EXPECT_EQ(Bignum(0).CountrZero(), 0); - EXPECT_EQ(Bignum(1).CountrZero(), 0); - EXPECT_EQ(Bignum(7).CountrZero(), 0); - EXPECT_EQ(Bignum(-7).CountrZero(), 0); + EXPECT_EQ(countr_zero(Bignum(0)), 0); + EXPECT_EQ(countr_zero(Bignum(1)), 0); + EXPECT_EQ(countr_zero(Bignum(7)), 0); + EXPECT_EQ(countr_zero(Bignum(-7)), 0); - EXPECT_EQ(Bignum(2).CountrZero(), 1); - EXPECT_EQ(Bignum(8).CountrZero(), 3); - EXPECT_EQ(Bignum(10).CountrZero(), 1); // 0b1010 - EXPECT_EQ(Bignum(12).CountrZero(), 2); // 0b1100 + EXPECT_EQ(countr_zero(Bignum(2)), 1); + EXPECT_EQ(countr_zero(Bignum(8)), 3); + EXPECT_EQ(countr_zero(Bignum(10)), 1); // 0b1010 + EXPECT_EQ(countr_zero(Bignum(12)), 2); // 0b1100 auto two_pow_64 = Bignum(1) << 64; - EXPECT_EQ(two_pow_64.CountrZero(), 64); + EXPECT_EQ(countr_zero(two_pow_64), 64); auto large_shifted = Bignum(6) << 100; // 0b110 << 100 - EXPECT_EQ(large_shifted.CountrZero(), 101); + EXPECT_EQ(countr_zero(large_shifted), 101); auto neg_large_shifted = Bignum(-5) << 200; - EXPECT_EQ(neg_large_shifted.CountrZero(), 200); + EXPECT_EQ(countr_zero(neg_large_shifted), 200); } TEST(BignumTest, Bit) { @@ -691,7 +616,7 @@ TEST(BignumTest, Pow) { TEST(BignumTest, SetZero) { Bignum a(123); a.SetZero(); - EXPECT_TRUE(a.zero()); + EXPECT_TRUE(a.is_zero()); Bignum b(-456); b.SetZero(); @@ -718,7 +643,7 @@ TEST(BignumTest, SetSign) { EXPECT_EQ(a, Bignum(99)); a.SetSign(0); - EXPECT_TRUE(a.zero()); + EXPECT_TRUE(a.is_zero()); } TEST(BignumTest, Comparisons) { @@ -773,9 +698,6 @@ TEST(BignumTest, Comparisons) { EXPECT_GE(Bignum(0), Bignum(0)); } -// TODO: Enable once benchmark is integrated. -#if 0 - // RAII wrapper for OpenSSL BIGNUM class OpenSSLBignum { public: @@ -849,7 +771,7 @@ static std::vector GenerateRandomNumbers(int bits) { } // Basic correctness test to ensure OpenSSL integration is working -TEST(BignumTestBenchmarkTest, OpenSSLIntegration) { +TEST(BignumTest, OpenSSLIntegration) { OpenSSLBignum a(123); OpenSSLBignum b(456); OpenSSLBignum result; @@ -861,7 +783,7 @@ TEST(BignumTestBenchmarkTest, OpenSSLIntegration) { OPENSSL_free(str); } -TEST(BignumTestBenchmarkTest, ResultsMatch) { +TEST(BignumTest, ResultsMatch) { // Test that and OpenSSL produce the same results const Bignum w_a(12345); const Bignum w_b(67890); @@ -909,6 +831,38 @@ const std::vector& MegaNumbers() { return *numbers; } +TEST(BignumTest, MultiplyCorrectVsOpenSSL) { + // Test that multiplication produces correct results by comparing to OpenSSL. + BN_CTX* ctx = BN_CTX_new(); + for (const auto& numbers : {SmallNumbers(), MediumNumbers(), LargeNumbers(), + HugeNumbers(), MediumNumbers()}) { + for (const auto& number : numbers) { + // Test same number multiplication (most likely to trigger edge cases) + const Bignum bn_a = *Bignum::FromString(number); + const Bignum bn_result = bn_a * bn_a; + + const OpenSSLBignum ssl_a(number); + OpenSSLBignum ssl_result; + BN_mul(ssl_result.get(), ssl_a.get(), ssl_a.get(), ctx); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for multiplication" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) + << "..."; + OPENSSL_free(ssl_str); + } + } + BN_CTX_free(ctx); +} + +// TODO: Enable once benchmark is integrated. +#if 0 + template void BignumBinaryOpBenchmark(benchmark::State& state, const std::vector& number_strings, From 71314f08e481053595fdd4c3e94edac8d74bd1ec Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sun, 28 Sep 2025 10:50:48 -0600 Subject: [PATCH 03/31] Address round 2 CR comments. Move exactfloat_internal back to bignum.h and use namespace instead. Loosen equality checks in ABSL_DCHECK calls in bignum. --- CMakeLists.txt | 17 +-- src/s2/util/math/exactfloat/BUILD | 20 ++- .../{exactfloat_internal.cc => bignum.cc} | 10 +- .../{exactfloat_internal.h => bignum.h} | 17 ++- ...tfloat_internal_test.cc => bignum_test.cc} | 133 +++++++++++++----- src/s2/util/math/exactfloat/exactfloat.cc | 10 +- src/s2/util/math/exactfloat/exactfloat.h | 11 +- 7 files changed, 148 insertions(+), 70 deletions(-) rename src/s2/util/math/exactfloat/{exactfloat_internal.cc => bignum.cc} (98%) rename src/s2/util/math/exactfloat/{exactfloat_internal.h => bignum.h} (95%) rename src/s2/util/math/exactfloat/{exactfloat_internal_test.cc => bignum_test.cc} (90%) diff --git a/CMakeLists.txt b/CMakeLists.txt index f333766e..928fc356 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -210,8 +210,8 @@ add_library(s2 src/s2/util/bits/bit-interleave.cc src/s2/util/coding/coder.cc src/s2/util/coding/varint.cc + src/s2/util/math/exactfloat/bignum.cc src/s2/util/math/exactfloat/exactfloat.cc - src/s2/util/math/exactfloat/exactfloat_internal.cc src/s2/util/math/mathutil.cc src/s2/util/units/length-units.cc) @@ -439,8 +439,6 @@ if(S2_ENABLE_INSTALL) DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/testing") install(FILES src/s2/util/bitmap/bitmap.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/bitmap") - install(FILES src/s2/util/bits/bits.h - DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/bits") install(FILES src/s2/util/coding/coder.h src/s2/util/coding/varint.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/coding") @@ -451,6 +449,7 @@ if(S2_ENABLE_INSTALL) src/s2/util/gtl/dense_hash_set.h src/s2/util/gtl/densehashtable.h src/s2/util/gtl/hashtable_common.h + src/s2/util/gtl/requires.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/gtl") install(FILES src/s2/util/hash/mix.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/hash") @@ -458,8 +457,8 @@ if(S2_ENABLE_INSTALL) src/s2/util/math/matrix3x3.h src/s2/util/math/vector.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math") - install(FILES src/s2/util/math/exactfloat/exactfloat.h - src/s2/util/math/exactfloat/exactfloat_internal.h + install(FILES src/s2/util/math/exactfloat/bignum.h + src/s2/util/math/exactfloat/exactfloat.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math/exactfloat") install(FILES src/s2/util/units/length-units.h src/s2/util/units/physical-units.h @@ -489,8 +488,6 @@ if(S2_ENABLE_INSTALL) endif() # S2_ENABLE_INSTALL if (BUILD_TESTS) - add_library(test_main "src/s2/test_main.cc") - if (NOT GOOGLETEST_ROOT) message(FATAL_ERROR "BUILD_TESTS requires GOOGLETEST_ROOT") endif() @@ -616,12 +613,13 @@ if (BUILD_TESTS) src/s2/s2shapeutil_shape_edge_id_test.cc src/s2/s2shapeutil_visit_crossing_edge_pairs_test.cc src/s2/s2text_format_test.cc - src/s2/util/math/exactfloat/exactfloat_internal_test.cc src/s2/s2validation_query_test.cc src/s2/s2wedge_relations_test.cc src/s2/s2winding_operation_test.cc src/s2/s2wrapped_shape_test.cc src/s2/sequence_lexicon_test.cc + src/s2/util/math/exactfloat/bignum_test.cc + src/s2/util/math/exactfloat/exactfloat_test.cc src/s2/value_lexicon_test.cc) enable_testing() @@ -632,9 +630,6 @@ if (BUILD_TESTS) target_link_libraries( ${test} s2testing s2 - test_main - benchmark - profiler absl::base absl::btree absl::check diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 8642a1d4..82ae1312 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -1,18 +1,28 @@ package(default_visibility = ["//visibility:public"]) +cc_library( + name = "bignum" + srcs = ["bignum.cc"] + hdrs = ["bignum.h"] +) + cc_library( name = "exactfloat", - srcs = ["exactfloat.cc", "exactfloat_internal.cc"], - hdrs = ["exactfloat.h", "exactfloat_internal.h"], + srcs = ["exactfloat.cc"], + hdrs = ["exactfloat.h"], deps = [ + ":bignum", + "//s2/base:types", "//s2/base:port", "//s2/base:logging", - "@abseil-cpp//absl/log:log", - "@abseil-cpp//absl/log:absl_check", - "@abseil-cpp//absl/strings:str_format", "@abseil-cpp//absl/container:inlined_vector", + "@abseil-cpp//absl/log:absl_check", + "@abseil-cpp//absl/log:log", "@abseil-cpp//absl/numeric:bits", "@abseil-cpp//absl/numeric:int128", + "@abseil-cpp//absl/random:random", "@abseil-cpp//absl/strings:ascii", + "@abseil-cpp//absl/strings:str_cat", + "@abseil-cpp//absl/strings:str_format", ], ) diff --git a/src/s2/util/math/exactfloat/exactfloat_internal.cc b/src/s2/util/math/exactfloat/bignum.cc similarity index 98% rename from src/s2/util/math/exactfloat/exactfloat_internal.cc rename to src/s2/util/math/exactfloat/bignum.cc index eae68422..37106c72 100644 --- a/src/s2/util/math/exactfloat/exactfloat_internal.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -1,4 +1,6 @@ -#include "s2/util/math/exactfloat/exactfloat_internal.h" +#include "s2/util/math/exactfloat/bignum.h" + +namespace exactfloat_internal { // Threshold for fallback to simple multiplication, determined empirically. static constexpr int kKaratsubaThreshold = 64; @@ -482,7 +484,7 @@ static Bigit MulAddIp( // static void MulQuadratic( // absl::Span out, // absl::Span a, absl::Span b) { - ABSL_DCHECK_EQ(out.size(), a.size() + b.size()); + ABSL_DCHECK_GE(out.size(), a.size() + b.size()); // Make sure A is the longer of the two arguments. if (a.size() < b.size()) { @@ -551,7 +553,7 @@ class Arena { static void KaratsubaMulRec( // absl::Span dst, // absl::Span a, absl::Span b, Arena& arena) { - ABSL_DCHECK_EQ(dst.size(), a.size() + b.size()); + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); if (a.empty() || b.empty()) { return; } @@ -751,3 +753,5 @@ Bignum& Bignum::operator*=(const Bignum& b) { NormalizeSign(new_sign); return *this; } + +} // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/exactfloat_internal.h b/src/s2/util/math/exactfloat/bignum.h similarity index 95% rename from src/s2/util/math/exactfloat/exactfloat_internal.h rename to src/s2/util/math/exactfloat/bignum.h index 596f6f4e..7173cc98 100644 --- a/src/s2/util/math/exactfloat/exactfloat_internal.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -28,14 +28,16 @@ #include "absl/strings/ascii.h" #include "absl/strings/str_format.h" +namespace exactfloat_internal { + // A class to support arithmetic on large, arbitrary precision integers. // // Large integers are represented as an array of uint64_t values. class Bignum { public: - // Wrap uint64_t in a struct so we can avoid default initialization. + // Wrap uint64_t in a struct so we can make value-initialization a noop. // - // Avoiding default initialization overhead saves us 50% on some benchmarks. + // Avoiding value-initialization overhead saves us 50% on some benchmarks. struct Bigit { static constexpr int kBits = std::numeric_limits::digits; @@ -133,7 +135,7 @@ class Bignum { // General accessors. //-------------------------------------- - // Returns the number of bits required to represent the bignum. + // Returns the number of bits required for the magnitude of the value. friend int bit_width(const Bignum& a); // Returns the number of consecutive 0 bits in the value, starting from the @@ -290,7 +292,7 @@ class Bignum { return bigits_.empty() || bigits_.back() != 0; } - // Returns true if the bignum is the given power of two. + // Returns true if the bignum magnitude is the given power of two. bool IsPow2(int pow2) const { const int bigits = pow2 / Bigit::kBits; if (bigits_.size() != bigits + 1) { @@ -311,6 +313,9 @@ class Bignum { } // Compares magnitude with another bignum, returning -1, 0, or +1. + // + // Magnitudes are compared lexicographically from the most significant bigit + // (bigits_.back()) to the least significant (bigits_[0]). int CmpAbs(const Bignum& b) const; BigitVector bigits_; @@ -372,7 +377,7 @@ void AbslStringify(Sink& sink, const Bignum& b) { copy.sign_ = 1; // Repeatedly divide and modulo by 10^19 to get decimal chunks. - static constexpr uint64_t kBase = 10000000000000000000ull; + static constexpr uint64_t kBase = 10'000'000'000'000'000'000u; Bignum::BigitVector chunks; while (!copy.is_zero()) { @@ -487,3 +492,5 @@ inline int Bignum::CmpAbs(const Bignum& b) const { return 0; } + +} // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/exactfloat_internal_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc similarity index 90% rename from src/s2/util/math/exactfloat/exactfloat_internal_test.cc rename to src/s2/util/math/exactfloat/bignum_test.cc index 7480b542..7cff4535 100644 --- a/src/s2/util/math/exactfloat/exactfloat_internal_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -13,20 +13,29 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "s2/util/math/exactfloat/exactfloat_internal.h" +#include "s2/util/math/exactfloat/bignum.h" #include -#include #include #include +// TODO: remove once benchmarks are available +#if 0 +#include "benchmark/benchmark.h" +#endif + #include "absl/base/no_destructor.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" -#include "benchmark/benchmark.h" #include "gtest/gtest.h" #include "openssl/bn.h" #include "openssl/crypto.h" +namespace exactfloat_internal { + +using ::testing::TestWithParam; + const uint64_t u8max = std::numeric_limits::max(); const uint64_t u16max = std::numeric_limits::max(); const uint64_t u32max = std::numeric_limits::max(); @@ -747,8 +756,8 @@ const int kRandomBignumCount = 128; static std::vector GenerateRandomNumbers(int bits) { std::vector numbers; - std::mt19937_64 rng(42); // Fixed seed for reproducibility + absl::BitGen bitgen; for (int i = 0; i < kRandomBignumCount; ++i) { std::string num; @@ -756,12 +765,9 @@ static std::vector GenerateRandomNumbers(int bits) { int decimal_digits = (bits * 3) / 10; // log10(2^bits) ≈ bits * 0.301 // First digit can't be zero - std::uniform_int_distribution first_digit(1, 9); - num += std::to_string(first_digit(rng)); - - std::uniform_int_distribution digit(0, 9); + absl::StrAppend(&num, absl::StrFormat("%d", absl::Uniform(bitgen, 1, 9))); for (int j = 1; j < decimal_digits; ++j) { - num += std::to_string(digit(rng)); + num += std::to_string(absl::Uniform(bitgen, 0, 9)); } numbers.push_back(num); @@ -831,35 +837,96 @@ const std::vector& MegaNumbers() { return *numbers; } -TEST(BignumTest, MultiplyCorrectVsOpenSSL) { +class VsOpenSSLTest : public TestWithParam> {}; + +TEST_P(VsOpenSSLTest, MultiplyCorrect) { // Test that multiplication produces correct results by comparing to OpenSSL. BN_CTX* ctx = BN_CTX_new(); - for (const auto& numbers : {SmallNumbers(), MediumNumbers(), LargeNumbers(), - HugeNumbers(), MediumNumbers()}) { - for (const auto& number : numbers) { - // Test same number multiplication (most likely to trigger edge cases) - const Bignum bn_a = *Bignum::FromString(number); - const Bignum bn_result = bn_a * bn_a; - - const OpenSSLBignum ssl_a(number); - OpenSSLBignum ssl_result; - BN_mul(ssl_result.get(), ssl_a.get(), ssl_a.get(), ctx); - - // Compare string representations - char* ssl_str = BN_bn2dec(ssl_result.get()); - std::string bn_str = absl::StrFormat("%v", bn_result); - - EXPECT_EQ(bn_str, std::string(ssl_str)) - << "Mismatch for multiplication" - << "\nBignum result: " << bn_str.substr(0, 100) << "..." - << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) - << "..."; - OPENSSL_free(ssl_str); - } + for (const auto& number : GetParam()) { + // Test same number multiplication (most likely to trigger edge cases) + const Bignum bn_a = *Bignum::FromString(number); + const Bignum bn_result = bn_a * bn_a; + + const OpenSSLBignum ssl_a(number); + OpenSSLBignum ssl_result; + BN_mul(ssl_result.get(), ssl_a.get(), ssl_a.get(), ctx); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for multiplication" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); } BN_CTX_free(ctx); } +TEST_P(VsOpenSSLTest, AdditionCorrect) { + // Test that addition produces correct results by comparing to OpenSSL. + const std::vector numbers = GetParam(); + for (int i = 0; i < numbers.size(); ++i) { + const auto& num_a = numbers[i]; + const auto& num_b = numbers[(i + 1) % numbers.size()]; + + const Bignum bn_a = *Bignum::FromString(num_a); + const Bignum bn_b = *Bignum::FromString(num_b); + + const Bignum bn_result = bn_a + bn_b; + + const OpenSSLBignum ssl_a(num_a); + const OpenSSLBignum ssl_b(num_b); + OpenSSLBignum ssl_result; + BN_add(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for addition" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } +} + +TEST_P(VsOpenSSLTest, SubtractionCorrect) { + // Test that subtraction produces correct results by comparing to OpenSSL. + const std::vector numbers = GetParam(); + for (int i = 0; i < numbers.size(); ++i) { + const auto& num_a = numbers[i]; + const auto& num_b = numbers[(i + 1) % numbers.size()]; + + const Bignum bn_a = *Bignum::FromString(num_a); + const Bignum bn_b = *Bignum::FromString(num_b); + + const Bignum bn_result = bn_a - bn_b; + + const OpenSSLBignum ssl_a(num_a); + const OpenSSLBignum ssl_b(num_b); + OpenSSLBignum ssl_result; + BN_sub(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for addition" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } +} + +INSTANTIATE_TEST_SUITE_P(VsOpenSSL, VsOpenSSLTest, + ::testing::Values(SmallNumbers(), MediumNumbers(), + LargeNumbers(), HugeNumbers(), + MediumNumbers())); + // TODO: Enable once benchmark is integrated. #if 0 @@ -1089,3 +1156,5 @@ void BM_OpenSSL_PowMedium(benchmark::State& state) { } BENCHMARK(BM_OpenSSL_PowMedium); #endif + +} // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 634b46ed..eca17f8d 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -67,15 +67,13 @@ ExactFloat::ExactFloat(double v) { ExactFloat::ExactFloat(int v) { sign_ = (v >= 0) ? 1 : -1; - // Note that this works even for INT_MIN. + + // Note that this works even for INT_MIN, as |INT_MIN| < |INT_MAX|. bn_ = Bignum(abs(v)); bn_exp_ = 0; Canonicalize(); } -ExactFloat::ExactFloat(const ExactFloat& b) - : sign_(b.sign_), bn_exp_(b.bn_exp_), bn_(b.bn_) {} - ExactFloat ExactFloat::SignedZero(int sign) { ExactFloat r; r.set_zero(sign); @@ -140,10 +138,10 @@ double ExactFloat::ToDoubleHelper() const { } auto opt_mantissa = bn_.ConvertTo(); ABSL_DCHECK(opt_mantissa.has_value()); - uint64_t d_mantissa = opt_mantissa.value(); + // We rely on ldexp() to handle overflow and underflow. (It will return a // signed zero or infinity if the result is too small or too large.) - return sign_ * ldexp(static_cast(d_mantissa), bn_exp_); + return sign_ * ldexp(static_cast(*opt_mantissa), bn_exp_); } ExactFloat ExactFloat::RoundToMaxPrec(int max_prec, RoundingMode mode) const { diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index f77939a0..82ddce02 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -116,7 +116,7 @@ #include #include -#include "s2/util/math/exactfloat/exactfloat_internal.h" +#include "s2/util/math/exactfloat/bignum.h" class ExactFloat { public: @@ -181,13 +181,6 @@ class ExactFloat { // example, "0.125" is allowed but "0.1" is not). explicit ExactFloat(const char* s) { Unimplemented(); } - // Copy constructor. - ExactFloat(const ExactFloat& b); - - // The destructor is not virtual for efficiency reasons. Therefore no - // subclass should declare additional fields that require destruction. - inline ~ExactFloat() = default; - ///////////////////////////////////////////////////////////////////// // Constants // @@ -496,6 +489,8 @@ class ExactFloat { friend ExactFloat logb(const ExactFloat& a); protected: + using Bignum = exactfloat_internal::Bignum; + // Non-normal numbers are represented using special exponent values and a // mantissa of zero. Do not change these values; methods such as // is_normal() make assumptions about their ordering. Non-normal numbers From f374cb3d8c053418071db875b07a0f7c6e520965 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Mon, 29 Sep 2025 08:58:40 -0600 Subject: [PATCH 04/31] Address PR Round 3 comments. - Remove Bigit wrapper around uint64_t as it seemed to be moot. - Make CmpAbs a free function and use it to DCHECK preconditions. - Make powers-of-10 precomputation constexpr. - Ensure clean compile with -Wsign-compare. - General readability and comment updates. --- README.md | 5 +- src/s2/util/math/exactfloat/BUILD | 7 +- src/s2/util/math/exactfloat/bignum.cc | 112 +++++++++++---------- src/s2/util/math/exactfloat/bignum.h | 92 +++++------------ src/s2/util/math/exactfloat/bignum_test.cc | 4 +- 5 files changed, 92 insertions(+), 128 deletions(-) diff --git a/README.md b/README.md index e900b684..38b1be44 100644 --- a/README.md +++ b/README.md @@ -140,12 +140,11 @@ Enable the python interface with `-DWITH_PYTHON=ON`. # For Testing -If BUILD_TESTS is 'on' (the default), and benchmarks are enabled, then OpenSSL -must be available to build some tests: +If BUILD_TESTS is 'on' (the default), then OpenSSL must be available to build +some tests: * [OpenSSL](https://github.com/openssl/openssl) (for its bignum library) - If OpenSSL is installed in a non-standard location set `OPENSSL_ROOT_DIR` before running configure, for example on macOS: ``` diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 82ae1312..291d22c1 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -1,9 +1,10 @@ package(default_visibility = ["//visibility:public"]) cc_library( - name = "bignum" - srcs = ["bignum.cc"] - hdrs = ["bignum.h"] + name = "bignum", + srcs = ["bignum.cc"], + hdrs = ["bignum.h"], + visibility = "//visibility:private" ) cc_library( diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 37106c72..37ec248f 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -5,12 +5,23 @@ namespace exactfloat_internal { // Threshold for fallback to simple multiplication, determined empirically. static constexpr int kKaratsubaThreshold = 64; -// Avoid the dependent name clutter. -using Bigit = typename Bignum::Bigit; - static Bigit MulAdd( // absl::Span out, absl::Span a, Bigit b, Bigit c); +int CmpAbs(absl::Span a, absl::Span b) { + if (a.size() != b.size()) { + return a.size() < b.size() ? -1 : +1; + } + + for (int i = a.size() - 1; i >= 0; --i) { + if (a[i] != b[i]) { + return a[i] < b[i] ? -1 : +1; + } + } + + return 0; +} + std::optional Bignum::FromString(absl::string_view s) { // A chunk is up to 19 decimal digits, which can always fit into a Bigit. constexpr int kMaxChunkDigits = std::numeric_limits::digits10; @@ -22,13 +33,10 @@ std::optional Bignum::FromString(absl::string_view s) { // semi-linear. // Precomputed powers of 10. - static const auto kPow10 = []() { - std::array out; - - Bigit value = 1; - for (int i = 0; i < out.size(); ++i) { - out[i] = value; - value = value * 10; + static constexpr auto kPow10 = []() { + std::array out = {1}; + for (size_t i = 1; i < out.size(); ++i) { + out[i] = 10 * out[i - 1]; } return out; }(); @@ -70,8 +78,7 @@ std::optional Bignum::FromString(absl::string_view s) { return std::nullopt; } - // Accumulate digit into the local 64-bit chunk. Skip leading - // zeros. + // Accumulate digit into the local 64-bit chunk. Skip leading zeros. uint64_t digit = static_cast(c - '0'); if (!seen_digit && digit == 0) { continue; @@ -100,8 +107,8 @@ int bit_width(const Bignum& a) { // Bit width is the bits in the least significant bigits + bit width of // the most significant word. const int msw_width = - (Bigit::kBits - absl::countl_zero(a.bigits_.back().value_)); - const int lsw_width = (a.bigits_.size() - 1) * Bigit::kBits; + (Bignum::kBigitBits - absl::countl_zero(a.bigits_.back())); + const int lsw_width = (a.bigits_.size() - 1) * Bignum::kBigitBits; return msw_width + lsw_width; } @@ -113,7 +120,7 @@ int countr_zero(const Bignum& a) { int nzero = 0; for (Bigit bigit : a.bigits_) { if (bigit == 0) { - nzero += Bigit::kBits; + nzero += Bignum::kBigitBits; } else { nzero += absl::countr_zero(static_cast(bigit)); break; @@ -128,8 +135,8 @@ bool Bignum::Bit(int nbit) const { return false; } - const int digit = nbit / Bigit::kBits; - const int shift = nbit % Bigit::kBits; + const size_t digit = nbit / kBigitBits; + const size_t shift = nbit % kBigitBits; if (digit >= size()) { return false; @@ -150,8 +157,8 @@ Bignum& Bignum::operator<<=(int nbit) { return *this; } - const int nbigit = nbit / Bigit::kBits; - const int nrem = nbit % Bigit::kBits; + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; // First, handle the whole-bigit shift by inserting zeros. bigits_.insert(bigits_.begin(), nbigit, 0); @@ -162,7 +169,7 @@ Bignum& Bignum::operator<<=(int nbit) { for (size_t i = 0; i < bigits_.size(); ++i) { const Bigit old_val = bigits_[i]; bigits_[i] = (old_val << nrem) | carry; - carry = old_val >> (Bigit::kBits - nrem); + carry = old_val >> (kBigitBits - nrem); } if (carry) { @@ -184,8 +191,8 @@ Bignum& Bignum::operator>>=(int nbit) { return SetZero(); } - const int nbigit = nbit / Bigit::kBits; - const int nrem = nbit % Bigit::kBits; + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; // First, handle the whole-bigit shift by removing bigits. bigits_.erase(bigits_.begin(), bigits_.begin() + nbigit); @@ -196,7 +203,7 @@ Bignum& Bignum::operator>>=(int nbit) { for (int i = static_cast(bigits_.size()) - 1; i >= 0; --i) { const Bigit old_val = bigits_[i]; bigits_[i] = (old_val >> nrem) | carry; - carry = old_val << (Bigit::kBits - nrem); + carry = old_val << (kBigitBits - nrem); } } @@ -306,7 +313,7 @@ static Bigit AddInPlace(absl::Span a, absl::Span b) { return carry; } -static ssize_t AddInto( // +static size_t AddInto( // absl::Span dst, absl::Span a, absl::Span b) { const size_t max_size = std::max(a.size(), b.size()); @@ -321,8 +328,8 @@ static ssize_t AddInto( // Bigit carry = 0; // Dispatch four at a time to help loop unrolling. - int size = min_size; - int i = 0; + size_t size = min_size; + size_t i = 0; while (size >= i + 4) { for (int j = 0; j < 4; ++j) { pdst[i] = AddCarry(pa[i], pb[i], carry); @@ -362,19 +369,22 @@ static ssize_t AddInto( // // Computes a -= b. Returns the final borrow (if any). // +// A must be expanded to match the size of B and the total number of digits +// actually set in A must be passed in via a_digits. +// // REQUIRES: |a| < |b|. -// NOTE: A must be pre-expanded to match the size of b. -static Bigit SubLtIp( // - absl::Span a, absl::Span b, ssize_t na) { +static Bigit SubLtInPlace( // + absl::Span a, absl::Span b, size_t a_digits) { ABSL_DCHECK_EQ(a.size(), b.size()); + ABSL_DCHECK_LT(CmpAbs(a, b), 0); Bigit* pa = a.data(); const Bigit* pb = b.data(); Bigit borrow = 0; // Dispatch four at a time to help loop unrolling. - int size = na; - int i = 0; + size_t size = a_digits; + size_t i = 0; while (size >= i + 4) { for (int j = 0; j < 4; ++j) { pa[i] = SubBorrow(pb[i], pa[i], borrow); @@ -383,7 +393,7 @@ static Bigit SubLtIp( // } // Finish remainder. - for (; i < na; ++i) { + for (; i < a_digits; ++i) { pa[i] = SubBorrow(pb[i], pa[i], borrow); } @@ -397,8 +407,9 @@ static Bigit SubLtIp( // // Computes a -= b. Returns the final borrow (if any). // // REQUIRES: |a| >= |b|. -static Bigit SubGeIp(absl::Span a, absl::Span b) { +static Bigit SubGeInPlace(absl::Span a, absl::Span b) { ABSL_DCHECK_GE(a.size(), b.size()); + ABSL_DCHECK_GE(CmpAbs(a, b), 0); Bigit borrow = 0; @@ -406,8 +417,8 @@ static Bigit SubGeIp(absl::Span a, absl::Span b) { const Bigit* pb = b.data(); // Dispatch four at a time to help loop unrolling. - int size = b.size(); - int done = 0; + size_t size = b.size(); + size_t done = 0; while (size >= done + 4) { for (int i = 0; i < 4; ++i) { pa[done] = SubBorrow(pa[done], pb[done], borrow); @@ -457,7 +468,7 @@ Bigit MulAdd( // // Computes out[i] += a[i]*b in place. // // Returns the final carry, if any. -static Bigit MulAddIp( // +static Bigit MulAddInPlace( // absl::Span out, absl::Span a, Bigit b) { Bigit* pout = out.data(); const Bigit* pa = a.data(); @@ -500,18 +511,18 @@ static void MulQuadratic( // auto upper = out.subspan(a.size()); upper[0] = MulAdd(out, a, b[0]); - const int size = b.size(); - int i = 1; + const size_t size = b.size(); + size_t i = 1; while (size >= i + 4) { for (int j = 0; j < 4; ++j) { - upper[i] = MulAddIp(out.subspan(i), a, b[i]); + upper[i] = MulAddInPlace(out.subspan(i), a, b[i]); ++i; } } // Finish remainder (if any). for (; i < size; ++i) { - upper[i] = MulAddIp(out.subspan(i), a, b[i]); + upper[i] = MulAddInPlace(out.subspan(i), a, b[i]); } // Finish zeroing out upper half. @@ -547,7 +558,7 @@ class Arena { private: ssize_t used_ = 0; - absl::InlinedVector data_; + std::vector data_; }; static void KaratsubaMulRec( // @@ -558,7 +569,8 @@ static void KaratsubaMulRec( // return; } - // Karatsuba lets us represent two numbers, A and B thusly: + // Karatsuba lets us represent two numbers of 2M bigits each, A and B, as: + // // A = a1*10^M + a0 // B = b1*10^M + b0 // @@ -617,8 +629,8 @@ static void KaratsubaMulRec( // KaratsubaMulRec(z1, sa, sb, arena); // NOTE: (a0 + a1) * (b0 + b1) >= a0*b0 + a1*b1 so this never underflows. - SubGeIp(z1, z0); - SubGeIp(z1, z2); + SubGeInPlace(z1, z0); + SubGeInPlace(z1, z2); // We need to add z1*10^half which we can do by adding it offset. AddInPlace(dst.subspan(half), z1); @@ -677,13 +689,13 @@ Bignum& Bignum::operator+=(const Bignum& b) { } else { if (CmpAbs(b) >= 0) { // |a| >= |b|, so a - b is same sign as a. - SubGeIp(absl::MakeSpan(bigits_), b.bigits_); + SubGeInPlace(absl::MakeSpan(bigits_), b.bigits_); NormalizeSign(sign_); } else { // |a| < |b|, so a - b is same sign as b. const int prev_size = size(); bigits_.resize(b.size()); - SubLtIp(absl::MakeSpan(bigits_), b.bigits_, prev_size); + SubLtInPlace(absl::MakeSpan(bigits_), b.bigits_, prev_size); NormalizeSign(b.sign_); } } @@ -692,10 +704,6 @@ Bignum& Bignum::operator+=(const Bignum& b) { } Bignum& Bignum::operator-=(const Bignum& b) { - if (this == &b) { - return SetZero(); - } - if (b.is_zero()) { return *this; } @@ -713,12 +721,12 @@ Bignum& Bignum::operator-=(const Bignum& b) { Normalize(); } else { if (CmpAbs(b) >= 0) { - SubGeIp(absl::MakeSpan(bigits_), b.bigits_); + SubGeInPlace(absl::MakeSpan(bigits_), b.bigits_); NormalizeSign(sign_); } else { const int prev_size = size(); bigits_.resize(b.size()); - SubLtIp(absl::MakeSpan(bigits_), b.bigits_, prev_size); + SubLtInPlace(absl::MakeSpan(bigits_), b.bigits_, prev_size); NormalizeSign(-sign_); } } diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index 7173cc98..e02ece5e 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -30,52 +30,23 @@ namespace exactfloat_internal { +using Bigit = uint64_t; + +// Compares magnitude magnitude of two bigit vectors, returning -1, 0, or +1. +// +// Magnitudes are compared lexicographically from the most significant bigit +// to the least significant. +int CmpAbs(absl::Span a, absl::Span b); + // A class to support arithmetic on large, arbitrary precision integers. // // Large integers are represented as an array of uint64_t values. class Bignum { public: - // Wrap uint64_t in a struct so we can make value-initialization a noop. - // - // Avoiding value-initialization overhead saves us 50% on some benchmarks. - struct Bigit { - static constexpr int kBits = std::numeric_limits::digits; - - Bigit() {} - constexpr Bigit(uint64_t value) : value_(value) {} - explicit Bigit(absl::uint128 value) : value_(absl::Uint128Low64(value)) {} - - constexpr operator uint64_t() const { return value_; } - constexpr Bigit& operator=(uint64_t value) { - value_ = value; - return *this; - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr Bigit& operator--(int) { - value_--; - return *this; - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE constexpr friend Bigit operator*( // - int a, Bigit b) { - return a * b.value_; - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE friend absl::uint128 operator*( // - absl::uint128 a, Bigit b) { - return a * b.value_; - } - - ABSL_ATTRIBUTE_ALWAYS_INLINE friend absl::uint128 operator+( // - absl::uint128 a, Bigit b) { - return a + b.value_; - } - - uint64_t value_; - }; - using BigitVector = absl::InlinedVector; + static constexpr int kBigitBits = std::numeric_limits::digits; + Bignum() = default; // Constructs a bignum from an integral value (signed or unsigned). @@ -294,29 +265,28 @@ class Bignum { // Returns true if the bignum magnitude is the given power of two. bool IsPow2(int pow2) const { - const int bigits = pow2 / Bigit::kBits; + const size_t bigits = pow2 / kBigitBits; if (bigits_.size() != bigits + 1) { return false; } // Verify lower words are zero. - for (int i = 0; i < bigits; ++i) { + for (size_t i = 0; i < bigits; ++i) { if (bigits_[i] != 0) { return false; } } // Check final word is power of two. - pow2 -= bigits * Bigit::kBits; - ABSL_DCHECK_LT(pow2, Bigit::kBits); + pow2 -= bigits * kBigitBits; + ABSL_DCHECK_LT(pow2, kBigitBits); return bigits_.back() == (Bigit(1) << pow2); } // Compares magnitude with another bignum, returning -1, 0, or +1. - // - // Magnitudes are compared lexicographically from the most significant bigit - // (bigits_.back()) to the least significant (bigits_[0]). - int CmpAbs(const Bignum& b) const; + int CmpAbs(const Bignum& b) const { + return ::exactfloat_internal::CmpAbs(bigits_, b.bigits_); + } BigitVector bigits_; int sign_ = 0; @@ -349,12 +319,12 @@ Bignum::Bignum(T value) { } // Pack the magnitude into bigits. - if constexpr (std::numeric_limits::digits <= Bigit::kBits) { + if constexpr (std::numeric_limits::digits <= kBigitBits) { bigits_.push_back(static_cast(mag)); } else { while (mag) { bigits_.push_back(static_cast(mag)); - mag >>= Bigit::kBits; + mag >>= kBigitBits; } } } @@ -384,13 +354,13 @@ void AbslStringify(Sink& sink, const Bignum& b) { absl::uint128 rem = 0; for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { absl::uint128 acc = (rem << 64) + copy.bigits_[i]; - Bignum::Bigit quot = static_cast(acc / kBase); + Bigit quot = static_cast(acc / kBase); rem = acc - absl::uint128(quot) * kBase; copy.bigits_[i] = quot; } copy.Normalize(); - chunks.push_back(static_cast(rem)); + chunks.push_back(static_cast(rem)); } ABSL_DCHECK(!chunks.empty()); @@ -413,7 +383,7 @@ inline bool Bignum::FitsIn() const { // Maximum number of bits that could fit in the output type. constexpr int kTBitWidth = std::numeric_limits::digits; - constexpr int kMaxBigits = (kTBitWidth + (Bigit::kBits - 1)) / Bigit::kBits; + constexpr int kMaxBigits = (kTBitWidth + (kBigitBits - 1)) / kBigitBits; // Fast reject if the bignum couldn't conceivably fit. if (bigits_.size() > kMaxBigits) { @@ -458,14 +428,14 @@ T Bignum::Cast() const { // Grab the bottom bits into an unsigned value. UT residue = 0; for (size_t i = 0; i < bigits_.size(); ++i) { - const int shift = i * Bigit::kBits; + const int shift = i * kBigitBits; if (shift >= kTBitWidth) { break; } const int room = kTBitWidth - shift; UT chunk = static_cast(bigits_[i]); - if (room < Bigit::kBits && room < std::numeric_limits::digits) { + if (room < kBigitBits && room < std::numeric_limits::digits) { chunk &= (UT(1) << room) - UT(1); } residue |= (chunk << shift); @@ -479,18 +449,4 @@ T Bignum::Cast() const { return static_cast(residue); } -inline int Bignum::CmpAbs(const Bignum& b) const { - if (size() != b.size()) { - return size() < b.size() ? -1 : +1; - } - - for (int i = size() - 1; i >= 0; --i) { - if (bigits_[i] != b.bigits_[i]) { - return bigits_[i] < b.bigits_[i] ? -1 : +1; - } - } - - return 0; -} - } // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index 7cff4535..979dec06 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -867,7 +867,7 @@ TEST_P(VsOpenSSLTest, MultiplyCorrect) { TEST_P(VsOpenSSLTest, AdditionCorrect) { // Test that addition produces correct results by comparing to OpenSSL. const std::vector numbers = GetParam(); - for (int i = 0; i < numbers.size(); ++i) { + for (size_t i = 0; i < numbers.size(); ++i) { const auto& num_a = numbers[i]; const auto& num_b = numbers[(i + 1) % numbers.size()]; @@ -896,7 +896,7 @@ TEST_P(VsOpenSSLTest, AdditionCorrect) { TEST_P(VsOpenSSLTest, SubtractionCorrect) { // Test that subtraction produces correct results by comparing to OpenSSL. const std::vector numbers = GetParam(); - for (int i = 0; i < numbers.size(); ++i) { + for (size_t i = 0; i < numbers.size(); ++i) { const auto& num_a = numbers[i]; const auto& num_b = numbers[(i + 1) % numbers.size()]; From 54244eb7158bc5f3287b1719963b7843937a37f2 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Mon, 29 Sep 2025 11:44:59 -0600 Subject: [PATCH 05/31] Fix Karatsuba for highly asymmetric operand sizes. Previously an operand A much smaller than operand B wouldn't properly subdivide all the way which generated incorrect results. I wrote a fuzztest for both multiplication and addition and they found no further issues after several million iterations. --- src/s2/util/math/exactfloat/bignum.cc | 72 +++++++++++++++------- src/s2/util/math/exactfloat/bignum_test.cc | 27 +++++++- 2 files changed, 77 insertions(+), 22 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 37ec248f..8731ebd7 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -535,7 +535,10 @@ static void MulQuadratic( // template static std::pair, absl::Span> Split( // absl::Span span, int a, int b) { - return {span.subspan(0, a), span.subspan(a, b)}; + if (a < span.size()) { + return {span.subspan(0, a), span.subspan(a, b)}; + } + return {span.subspan(0, a), {}}; }; // A simple bump allocator to avoid allocating memory during recursion. @@ -551,6 +554,8 @@ class Arena { return absl::Span(data_.data() + start, n); } + size_t Used() const { return used_; } + void Release(ssize_t n) { ABSL_DCHECK_LE(n, used_); used_ -= n; @@ -566,18 +571,21 @@ static void KaratsubaMulRec( // absl::Span a, absl::Span b, Arena& arena) { ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); if (a.empty() || b.empty()) { + absl::c_fill(dst, 0); return; } - // Karatsuba lets us represent two numbers of 2M bigits each, A and B, as: + int arena_start = arena.Used(); + + // Karatsuba lets us represent two numbers of M bigits each, A and B, as: // - // A = a1*10^M + a0 - // B = b1*10^M + b0 + // A = a1*10^(M/2) + a0 + // B = b1*10^(M/2) + b0 // // Which we can multiply out: - // AB = (a1*10^M + a0)*(b1*10^M + b0); - // = a1*b1*10^(2M) + (a1*b0 + a0*b1)*10^M + a0*b0 - // = z2 * 10^2M + z1*10^M + z0 + // AB = (a1*10^(M/2) + a0)*(b1*10^(M/2) + b0); + // = a1*b1*10^M + (a1*b0 + a0*b1)*10^(M/2) + a0*b0 + // = z2 * 10^M + z1*10^(M/2) + z0 // // Where: // z0 = a0*b0 @@ -594,33 +602,42 @@ static void KaratsubaMulRec( // // with those individual multiplies able to be recursively divided. // Fall back to long multiplication when we're small enough. - if (dst.size() < kKaratsubaThreshold) { + if (dst.size() <= kKaratsubaThreshold) { MulQuadratic(dst, a, b); return; } - const int half = (std::min(a.size(), b.size()) + 1) / 2; + const int half = (std::max(a.size(), b.size()) + 1) / 2; // Split the inputs into contiguous subspans. auto [a0, a1] = Split(a, half, half); auto [b0, b1] = Split(b, half, half); + // We can skip adding the z2 term if a1 or b1 is zero. + const bool z2_zero = (a1.empty() || b1.empty()); + // Make space to hold results in the output and multiply sub-terms. // z0 = a0 * b0 // z2 = a1 * b1 - auto [z0, z2] = Split(dst, 2 * half, 2 * half); - + auto [z0, z2] = Split(dst, a0.size() + b0.size(), a1.size() + b1.size()); KaratsubaMulRec(z0, a0, b0, arena); KaratsubaMulRec(z2, a1, b1, arena); - // Compute (a0 + a1) and (b0 + b1) using space from the arena. + // Compute (a0 + a1) and (b0 + b1) // - // The sums may or may not carry. We pop the extra bigit off if they - // don't. - auto sa = arena.Alloc(half + 1); - auto sb = arena.Alloc(half + 1); - sa = sa.first(AddInto(sa, a0, a1)); - sb = sb.first(AddInto(sb, b0, b1)); + // If the upper terms are zero we can just re-use the terms we have, otherwise + // we compute the sum and pop off the MSB bigit if no carry occurred. + absl::Span sa = a0; + absl::Span sb = b0; + if (!a1.empty()) { + absl::Span tmp = arena.Alloc(half + 1); + sa = tmp.first(AddInto(tmp, a0, a1)); + } + + if (!b1.empty()) { + absl::Span tmp = arena.Alloc(half + 1); + sb = tmp.first(AddInto(tmp, b0, b1)); + } // Compute z1 = sa*sb - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 auto z1 = arena.Alloc(sa.size() + sb.size()); @@ -630,13 +647,26 @@ static void KaratsubaMulRec( // // NOTE: (a0 + a1) * (b0 + b1) >= a0*b0 + a1*b1 so this never underflows. SubGeInPlace(z1, z0); - SubGeInPlace(z1, z2); + if (!z2_zero) { + SubGeInPlace(z1, z2); + } + + // Z1 may overflow because of a carry in (a0 + b0) or (a1 + b1) but + // subtracting z0 and z2 will always bring it back in range, trim any leading + // zeros to shorten the value if needed. + int i = 0; + for (i = z1.size() - 1; i > 0; --i) { + if (z1[i]) { + break; + } + } + z1 = z1.first(i + 1); // We need to add z1*10^half which we can do by adding it offset. AddInPlace(dst.subspan(half), z1); // Release temporary memory we used. - arena.Release(z1.size() + sb.size() + sa.size()); + arena.Release(arena.Used() - arena_start); } Bignum::BigitVector Bignum::KaratsubaMul( // @@ -646,7 +676,7 @@ Bignum::BigitVector Bignum::KaratsubaMul( // } // Each step of Karatsuba splits at: - // N = std::ceil(std::min(a.size(), b.size())/2) + // N = (std::max(a.size() + b.size() + 1) / 2 // // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. // diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index 979dec06..769997d5 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -839,7 +839,7 @@ const std::vector& MegaNumbers() { class VsOpenSSLTest : public TestWithParam> {}; -TEST_P(VsOpenSSLTest, MultiplyCorrect) { +TEST_P(VsOpenSSLTest, SquaringCorrect) { // Test that multiplication produces correct results by comparing to OpenSSL. BN_CTX* ctx = BN_CTX_new(); for (const auto& number : GetParam()) { @@ -864,6 +864,31 @@ TEST_P(VsOpenSSLTest, MultiplyCorrect) { BN_CTX_free(ctx); } +TEST_P(VsOpenSSLTest, MultiplyCorrect) { + // Multiply by a small constant to test widely different operand sizes. + BN_CTX* ctx = BN_CTX_new(); + for (const auto& number : GetParam()) { + // Test same number multiplication (most likely to trigger edge cases) + const Bignum bn_a = *Bignum::FromString(number); + const Bignum bn_result = Bignum(2) * bn_a; + + const OpenSSLBignum ssl_a(number); + OpenSSLBignum ssl_result; + BN_mul(ssl_result.get(), OpenSSLBignum("2").get(), ssl_a.get(), ctx); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for multiplication" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } + BN_CTX_free(ctx); +} + TEST_P(VsOpenSSLTest, AdditionCorrect) { // Test that addition produces correct results by comparing to OpenSSL. const std::vector numbers = GetParam(); From 640f6b3ccfe4a630f36cfdd968a29bb67c540d56 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Mon, 29 Sep 2025 14:25:12 -0600 Subject: [PATCH 06/31] Remove assignment operator. Implicit assignment should be fine for us. --- src/s2/util/math/exactfloat/exactfloat.cc | 9 --------- src/s2/util/math/exactfloat/exactfloat.h | 3 --- 2 files changed, 12 deletions(-) diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index eca17f8d..fc4cf9a2 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -367,15 +367,6 @@ std::string ExactFloat::ToUniqueString() const { return absl::StrFormat("%s<%d>", ToString(), prec()); } -ExactFloat& ExactFloat::operator=(const ExactFloat& b) { - if (this != &b) { - sign_ = b.sign_; - bn_exp_ = b.bn_exp_; - bn_ = b.bn_; - } - return *this; -} - ExactFloat ExactFloat::operator-() const { return CopyWithSign(-sign_); } ExactFloat operator+(const ExactFloat& a, const ExactFloat& b) { diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index 82ddce02..0dd44002 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -321,9 +321,6 @@ class ExactFloat { ///////////////////////////////////////////////////////////////////////////// // Operators - // Assignment operator. - ExactFloat& operator=(const ExactFloat& b); - // Unary plus. ExactFloat operator+() const { return *this; } From b2f72d3b1d339e4f6745e519572f599c8809b229 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Tue, 30 Sep 2025 08:12:37 -0600 Subject: [PATCH 07/31] PR round 4 fixes. - Remove usage of ssize_t - Add include guards to bignum.h - Explictly plumb through bitgen in benchmarks/tests. - Variable renaming. --- src/s2/util/math/exactfloat/BUILD | 2 +- src/s2/util/math/exactfloat/bignum.cc | 13 +- src/s2/util/math/exactfloat/bignum.h | 5 + src/s2/util/math/exactfloat/bignum_test.cc | 191 ++++++++++++--------- 4 files changed, 127 insertions(+), 84 deletions(-) diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 291d22c1..f02ce32b 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -18,7 +18,7 @@ cc_library( "//s2/base:logging", "@abseil-cpp//absl/container:inlined_vector", "@abseil-cpp//absl/log:absl_check", - "@abseil-cpp//absl/log:log", + "@abseil-cpp//absl/log:absl_log", "@abseil-cpp//absl/numeric:bits", "@abseil-cpp//absl/numeric:int128", "@abseil-cpp//absl/random:random", diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 8731ebd7..3cbbb068 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -541,13 +541,16 @@ static std::pair, absl::Span> Split( // return {span.subspan(0, a), {}}; }; -// A simple bump allocator to avoid allocating memory during recursion. +// A simple bump allocator to allow us to very efficiently allocate temporary +// space when recursing in the Karatsuba multiply. The arena is pre-sized and +// returns spans of memory via Alloc() which are then returned to the arena via +// Release. class Arena { public: - explicit Arena(ssize_t size) { data_.reserve(size); } + explicit Arena(size_t size) { data_.reserve(size); } // Allocates a span of length n from the arena. - absl::Span Alloc(ssize_t n) { + absl::Span Alloc(size_t n) { ABSL_DCHECK_LE(used_ + n, data_.capacity()); size_t start = used_; used_ += n; @@ -556,13 +559,13 @@ class Arena { size_t Used() const { return used_; } - void Release(ssize_t n) { + void Release(size_t n) { ABSL_DCHECK_LE(n, used_); used_ -= n; } private: - ssize_t used_ = 0; + size_t used_ = 0; std::vector data_; }; diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index e02ece5e..80c1ec7c 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -13,6 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. +#ifndef S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ +#define S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ + #include #include #include @@ -450,3 +453,5 @@ T Bignum::Cast() const { } } // namespace exactfloat_internal + +#endif // S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index 769997d5..b6ef4ef8 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -16,6 +16,7 @@ #include "s2/util/math/exactfloat/bignum.h" #include +#include #include #include @@ -25,6 +26,7 @@ #endif #include "absl/base/no_destructor.h" +#include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" @@ -36,20 +38,20 @@ namespace exactfloat_internal { using ::testing::TestWithParam; -const uint64_t u8max = std::numeric_limits::max(); -const uint64_t u16max = std::numeric_limits::max(); -const uint64_t u32max = std::numeric_limits::max(); -const uint64_t u64max = std::numeric_limits::max(); +constexpr uint64_t kU8max = std::numeric_limits::max(); +constexpr uint64_t kU16max = std::numeric_limits::max(); +constexpr uint64_t kU32max = std::numeric_limits::max(); +constexpr uint64_t kU64max = std::numeric_limits::max(); -const int64_t i8max = std::numeric_limits::max(); -const int64_t i16max = std::numeric_limits::max(); -const int64_t i32max = std::numeric_limits::max(); -const int64_t i64max = std::numeric_limits::max(); +constexpr int64_t kI8max = std::numeric_limits::max(); +constexpr int64_t kI16max = std::numeric_limits::max(); +constexpr int64_t kI32max = std::numeric_limits::max(); +constexpr int64_t kI64max = std::numeric_limits::max(); -const int64_t i8min = std::numeric_limits::min(); -const int64_t i16min = std::numeric_limits::min(); -const int64_t i32min = std::numeric_limits::min(); -const int64_t i64min = std::numeric_limits::min(); +constexpr int64_t kI8min = std::numeric_limits::min(); +constexpr int64_t kI16min = std::numeric_limits::min(); +constexpr int64_t kI32min = std::numeric_limits::min(); +constexpr int64_t kI64min = std::numeric_limits::min(); // To reduce duplication. inline auto Bn(absl::string_view str) { return Bignum::FromString(str); }; @@ -155,8 +157,8 @@ TEST(BignumTest, NegativeOnlyFitsInSigned) { } TEST(BignumTest, FitsInUnsignedBoundsChecks) { - const Bignum bn_u8max(u8max); - const Bignum bn_u8over(u8max + 1); + const Bignum bn_u8max(kU8max); + const Bignum bn_u8over(kU8max + 1); EXPECT_TRUE(bn_u8max.FitsIn()); EXPECT_TRUE(bn_u8max.FitsIn()); EXPECT_TRUE(bn_u8max.FitsIn()); @@ -164,8 +166,8 @@ TEST(BignumTest, FitsInUnsignedBoundsChecks) { EXPECT_TRUE(bn_u8over.FitsIn()); EXPECT_TRUE(bn_u8over.FitsIn()); - const Bignum bn_u16max(u16max); - const Bignum bn_u16over(u16max + 1); + const Bignum bn_u16max(kU16max); + const Bignum bn_u16over(kU16max + 1); EXPECT_FALSE(bn_u16max.FitsIn()); EXPECT_TRUE(bn_u16max.FitsIn()); EXPECT_TRUE(bn_u16max.FitsIn()); @@ -173,8 +175,8 @@ TEST(BignumTest, FitsInUnsignedBoundsChecks) { EXPECT_FALSE(bn_u16over.FitsIn()); EXPECT_TRUE(bn_u16over.FitsIn()); - const Bignum bn_u32max(u32max); - const Bignum bn_u32over(u32max + 1); + const Bignum bn_u32max(kU32max); + const Bignum bn_u32over(kU32max + 1); EXPECT_FALSE(bn_u32max.FitsIn()); EXPECT_FALSE(bn_u32max.FitsIn()); EXPECT_TRUE(bn_u32max.FitsIn()); @@ -182,7 +184,7 @@ TEST(BignumTest, FitsInUnsignedBoundsChecks) { EXPECT_FALSE(bn_u32over.FitsIn()); EXPECT_FALSE(bn_u32over.FitsIn()); - const Bignum bn_u64max(u64max); + const Bignum bn_u64max(kU64max); EXPECT_TRUE(bn_u64max.FitsIn()); // 2^64, need to use string constructor. @@ -191,8 +193,8 @@ TEST(BignumTest, FitsInUnsignedBoundsChecks) { } TEST(BignumTest, FitsInSignedBoundsChecks) { - const Bignum bn_i8max(i8max); - const Bignum bn_i8over(i8max + 1); + const Bignum bn_i8max(kI8max); + const Bignum bn_i8over(kI8max + 1); EXPECT_TRUE(bn_i8max.FitsIn()); EXPECT_TRUE(bn_i8max.FitsIn()); EXPECT_TRUE(bn_i8max.FitsIn()); @@ -200,8 +202,8 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_TRUE(bn_i8over.FitsIn()); EXPECT_TRUE(bn_i8over.FitsIn()); - const Bignum bn_i16max(i16max); - const Bignum bn_i16over(i16max + 1); + const Bignum bn_i16max(kI16max); + const Bignum bn_i16over(kI16max + 1); EXPECT_FALSE(bn_i16max.FitsIn()); EXPECT_TRUE(bn_i16max.FitsIn()); EXPECT_TRUE(bn_i16max.FitsIn()); @@ -209,8 +211,8 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_FALSE(bn_i16over.FitsIn()); EXPECT_TRUE(bn_i16over.FitsIn()); - const Bignum bn_i32max(i32max); - const Bignum bn_i32over(i32max + 1); + const Bignum bn_i32max(kI32max); + const Bignum bn_i32over(kI32max + 1); EXPECT_FALSE(bn_i32max.FitsIn()); EXPECT_FALSE(bn_i32max.FitsIn()); EXPECT_TRUE(bn_i32max.FitsIn()); @@ -218,15 +220,15 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_FALSE(bn_i32over.FitsIn()); EXPECT_FALSE(bn_i32over.FitsIn()); - Bignum bn_i64max(i64max); + Bignum bn_i64max(kI64max); EXPECT_TRUE(bn_i64max.FitsIn()); // 2^63, need to use string constructor. Bignum bn0 = *Bn("9223372036854775808"); EXPECT_FALSE(bn0.FitsIn()); - const Bignum bn_i8min(i8min); - const Bignum bn_i8under(i8min - 1); + const Bignum bn_i8min(kI8min); + const Bignum bn_i8under(kI8min - 1); EXPECT_TRUE(bn_i8min.FitsIn()); EXPECT_TRUE(bn_i8min.FitsIn()); EXPECT_TRUE(bn_i8min.FitsIn()); @@ -234,8 +236,8 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_TRUE(bn_i8under.FitsIn()); EXPECT_TRUE(bn_i8under.FitsIn()); - const Bignum bn_i16min(i16min); - const Bignum bn_i16under(i16min - 1); + const Bignum bn_i16min(kI16min); + const Bignum bn_i16under(kI16min - 1); EXPECT_FALSE(bn_i16min.FitsIn()); EXPECT_TRUE(bn_i16min.FitsIn()); EXPECT_TRUE(bn_i16min.FitsIn()); @@ -243,8 +245,8 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_FALSE(bn_i16under.FitsIn()); EXPECT_TRUE(bn_i16under.FitsIn()); - const Bignum bn_i32min(i32min); - const Bignum bn_i32under(i32min - 1); + const Bignum bn_i32min(kI32min); + const Bignum bn_i32under(kI32min - 1); EXPECT_FALSE(bn_i32min.FitsIn()); EXPECT_FALSE(bn_i32min.FitsIn()); EXPECT_TRUE(bn_i32min.FitsIn()); @@ -252,7 +254,7 @@ TEST(BignumTest, FitsInSignedBoundsChecks) { EXPECT_FALSE(bn_i32under.FitsIn()); EXPECT_FALSE(bn_i32under.FitsIn()); - Bignum bn_i64min(i64min); + Bignum bn_i64min(kI64min); EXPECT_TRUE(bn_i64min.FitsIn()); } @@ -365,7 +367,7 @@ TEST(BignumTest, Addition) { EXPECT_EQ(Bignum(5) + Bignum(-5), Bignum(0)); // Carry propagation - const auto bn_u64max = Bignum(u64max); + const auto bn_u64max = Bignum(kU64max); EXPECT_EQ(bn_u64max + Bignum(1), *Bn("18446744073709551616")); EXPECT_EQ(bn_u64max + bn_u64max, *Bn("36893488147419103230")); @@ -389,7 +391,7 @@ TEST(BignumTest, Subtraction) { EXPECT_EQ(Bignum(42) - Bignum(42), Bignum(0)); // Borrow propagation - const auto bn_u64max = Bignum(u64max); + const auto bn_u64max = Bignum(kU64max); const auto two_pow_64 = *Bn("18446744073709551616"); EXPECT_EQ(two_pow_64 - Bignum(1), bn_u64max); @@ -446,7 +448,7 @@ TEST(BignumTest, LeftShift) { EXPECT_EQ((Bignum(1) << 64), *Bn("18446744073709551616")); EXPECT_EQ((Bignum(-1) << 64), *Bn("-18446744073709551616")); - const auto bn_u64max = Bignum(u64max); + const auto bn_u64max = Bignum(kU64max); const auto two_pow_128_minus_two_pow_64 = *Bn("340282366920938463444927863358058659840"); EXPECT_EQ((bn_u64max << 64), two_pow_128_minus_two_pow_64); @@ -509,11 +511,11 @@ TEST(BignumTest, Multiplication) { EXPECT_EQ(Bignum(-10) * Bignum(-20), Bignum(200)); // Simple carry - const auto bn_u32max = Bignum(u32max); + const auto bn_u32max = Bignum(kU32max); EXPECT_EQ(bn_u32max * Bignum(2), *Bn("8589934590")); // 1x1 bigit fast path - const auto bn_u64max = Bignum(u64max); + const auto bn_u64max = Bignum(kU64max); EXPECT_EQ(Bignum(2) * bn_u64max, *Bn("36893488147419103230")); // 1xN bigit multiplication @@ -754,10 +756,11 @@ class OpenSSLBignum { // Power of two for fast modulo. const int kRandomBignumCount = 128; -static std::vector GenerateRandomNumbers(int bits) { +static std::vector GenerateRandomNumbers(absl::BitGenRef bitgen, + int bits) { std::vector numbers; + numbers.reserve(kRandomBignumCount); - absl::BitGen bitgen; for (int i = 0; i < kRandomBignumCount; ++i) { std::string num; @@ -767,7 +770,7 @@ static std::vector GenerateRandomNumbers(int bits) { // First digit can't be zero absl::StrAppend(&num, absl::StrFormat("%d", absl::Uniform(bitgen, 1, 9))); for (int j = 1; j < decimal_digits; ++j) { - num += std::to_string(absl::Uniform(bitgen, 0, 9)); + num += absl::Uniform(bitgen, '0', '9'); } numbers.push_back(num); @@ -807,33 +810,33 @@ TEST(BignumTest, ResultsMatch) { OPENSSL_free(ssl_str); } -const std::vector& SmallNumbers() { +const std::vector& SmallNumbers(absl::BitGenRef bitgen) { static absl::NoDestructor> numbers( // - GenerateRandomNumbers(64)); + GenerateRandomNumbers(bitgen, 64)); return *numbers; } -const std::vector& MediumNumbers() { +const std::vector& MediumNumbers(absl::BitGenRef bitgen) { static absl::NoDestructor> numbers( // - GenerateRandomNumbers(256)); + GenerateRandomNumbers(bitgen, 256)); return *numbers; } -const std::vector& LargeNumbers() { +const std::vector& LargeNumbers(absl::BitGenRef bitgen) { static absl::NoDestructor> numbers( // - GenerateRandomNumbers(1024)); + GenerateRandomNumbers(bitgen, 1024)); return *numbers; } -const std::vector& HugeNumbers() { +const std::vector& HugeNumbers(absl::BitGenRef bitgen) { static absl::NoDestructor> numbers( // - GenerateRandomNumbers(4096)); + GenerateRandomNumbers(bitgen, 4096)); return *numbers; } -const std::vector& MegaNumbers() { +const std::vector& MegaNumbers(absl::BitGenRef bitgen) { static absl::NoDestructor> numbers( // - GenerateRandomNumbers(18000)); + GenerateRandomNumbers(bitgen, 18000)); return *numbers; } @@ -947,10 +950,13 @@ TEST_P(VsOpenSSLTest, SubtractionCorrect) { } } +absl::BitGen bitgen; INSTANTIATE_TEST_SUITE_P(VsOpenSSL, VsOpenSSLTest, - ::testing::Values(SmallNumbers(), MediumNumbers(), - LargeNumbers(), HugeNumbers(), - MediumNumbers())); + ::testing::Values(SmallNumbers(bitgen), + MediumNumbers(bitgen), + LargeNumbers(bitgen), + HugeNumbers(bitgen), + MediumNumbers(bitgen))); // TODO: Enable once benchmark is integrated. #if 0 @@ -1062,122 +1068,151 @@ void OpenSSLPowBenchmark(benchmark::State& state, } void BM_Bignum_AddSmall(benchmark::State& state) { - BignumBinaryOpBenchmark(state, SmallNumbers(), std::plus{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, SmallNumbers(bitgen), std::plus{}); } BENCHMARK(BM_Bignum_AddSmall); void BM_Bignum_AddMedium(benchmark::State& state) { - BignumBinaryOpBenchmark(state, MediumNumbers(), std::plus{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MediumNumbers(bitgen), std::plus{}); } BENCHMARK(BM_Bignum_AddMedium); void BM_Bignum_AddLarge(benchmark::State& state) { - BignumBinaryOpBenchmark(state, LargeNumbers(), std::plus{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, LargeNumbers(bitgen), std::plus{}); } BENCHMARK(BM_Bignum_AddLarge); void BM_Bignum_AddHuge(benchmark::State& state) { - BignumBinaryOpBenchmark(state, HugeNumbers(), std::plus{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, HugeNumbers(bitgen), std::plus{}); } BENCHMARK(BM_Bignum_AddHuge); void BM_Bignum_AddMega(benchmark::State& state) { - BignumBinaryOpBenchmark(state, MegaNumbers(), std::plus{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MegaNumbers(bitgen), std::plus{}); } BENCHMARK(BM_Bignum_AddMega); void BM_OpenSSL_AddSmall(benchmark::State& state) { - OpenSSLBinaryOpBenchmark(state, SmallNumbers(), BN_add); + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, SmallNumbers(bitgen), BN_add); } BENCHMARK(BM_OpenSSL_AddSmall); void BM_OpenSSL_AddMedium(benchmark::State& state) { - OpenSSLBinaryOpBenchmark(state, MediumNumbers(), BN_add); + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, MediumNumbers(bitgen), BN_add); } BENCHMARK(BM_OpenSSL_AddMedium); void BM_OpenSSL_AddLarge(benchmark::State& state) { - OpenSSLBinaryOpBenchmark(state, LargeNumbers(), BN_add); + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, LargeNumbers(bitgen), BN_add); } BENCHMARK(BM_OpenSSL_AddLarge); void BM_OpenSSL_AddHuge(benchmark::State& state) { - OpenSSLBinaryOpBenchmark(state, HugeNumbers(), BN_add); + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, HugeNumbers(bitgen), BN_add); } BENCHMARK(BM_OpenSSL_AddHuge); void BM_OpenSSL_AddMega(benchmark::State& state) { - OpenSSLBinaryOpBenchmark(state, MegaNumbers(), BN_add); + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, MegaNumbers(bitgen), BN_add); } BENCHMARK(BM_OpenSSL_AddMega); void BM_Bignum_MulSmall(benchmark::State& state) { - BignumBinaryOpBenchmark(state, SmallNumbers(), std::multiplies{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, SmallNumbers(bitgen), + std::multiplies{}); } BENCHMARK(BM_Bignum_MulSmall); void BM_Bignum_MulMedium(benchmark::State& state) { - BignumBinaryOpBenchmark(state, MediumNumbers(), std::multiplies{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MediumNumbers(bitgen), + std::multiplies{}); } BENCHMARK(BM_Bignum_MulMedium); void BM_Bignum_MulLarge(benchmark::State& state) { - BignumBinaryOpBenchmark(state, LargeNumbers(), std::multiplies{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, LargeNumbers(bitgen), + std::multiplies{}); } BENCHMARK(BM_Bignum_MulLarge); void BM_Bignum_MulHuge(benchmark::State& state) { - BignumBinaryOpBenchmark(state, HugeNumbers(), std::multiplies{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, HugeNumbers(bitgen), + std::multiplies{}); } BENCHMARK(BM_Bignum_MulHuge); void BM_Bignum_MulMega(benchmark::State& state) { - BignumBinaryOpBenchmark(state, MegaNumbers(), std::multiplies{}); + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MegaNumbers(bitgen), + std::multiplies{}); } BENCHMARK(BM_Bignum_MulMega); void BM_OpenSSL_MulSmall(benchmark::State& state) { - OpenSSLMulOpBenchmark(state, SmallNumbers(), BN_mul); + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, SmallNumbers(bitgen), BN_mul); } BENCHMARK(BM_OpenSSL_MulSmall); void BM_OpenSSL_MulMedium(benchmark::State& state) { - OpenSSLMulOpBenchmark(state, MediumNumbers(), BN_mul); + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, MediumNumbers(bitgen), BN_mul); } BENCHMARK(BM_OpenSSL_MulMedium); void BM_OpenSSL_MulLarge(benchmark::State& state) { - OpenSSLMulOpBenchmark(state, LargeNumbers(), BN_mul); + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, LargeNumbers(bitgen), BN_mul); } BENCHMARK(BM_OpenSSL_MulLarge); void BM_OpenSSL_MulHuge(benchmark::State& state) { - OpenSSLMulOpBenchmark(state, HugeNumbers(), BN_mul); + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, HugeNumbers(bitgen), BN_mul); } BENCHMARK(BM_OpenSSL_MulHuge); void BM_OpenSSL_MulMega(benchmark::State& state) { - OpenSSLMulOpBenchmark(state, MegaNumbers(), BN_mul); + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, MegaNumbers(bitgen), BN_mul); } BENCHMARK(BM_OpenSSL_MulMega); void BM_Bignum_PowSmall(benchmark::State& state) { - BignumPowBenchmark(state, SmallNumbers(), 20); + std::mt19937_64 bitgen; + BignumPowBenchmark(state, SmallNumbers(bitgen), 20); } BENCHMARK(BM_Bignum_PowSmall); void BM_Bignum_PowMedium(benchmark::State& state) { - BignumPowBenchmark(state, MediumNumbers(), 10); + std::mt19937_64 bitgen; + BignumPowBenchmark(state, MediumNumbers(bitgen), 10); } BENCHMARK(BM_Bignum_PowMedium); void BM_OpenSSL_PowSmall(benchmark::State& state) { - OpenSSLPowBenchmark(state, SmallNumbers(), 20); + std::mt19937_64 bitgen; + OpenSSLPowBenchmark(state, SmallNumbers(bitgen), 20); } BENCHMARK(BM_OpenSSL_PowSmall); void BM_OpenSSL_PowMedium(benchmark::State& state) { - OpenSSLPowBenchmark(state, MediumNumbers(), 10); + std::mt19937_64 bitgen; + OpenSSLPowBenchmark(state, MediumNumbers(bitgen), 10); } BENCHMARK(BM_OpenSSL_PowMedium); #endif From 095e140a59db6e87f77d11ef21c73141267558ed Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 3 Oct 2025 14:11:05 -0600 Subject: [PATCH 08/31] PR (Mega) round 5 changes. --- src/s2/util/math/exactfloat/bignum.cc | 390 ++++++++++-------- src/s2/util/math/exactfloat/bignum.h | 220 ++++------ src/s2/util/math/exactfloat/bignum_test.cc | 171 ++++---- src/s2/util/math/exactfloat/exactfloat.cc | 64 ++- src/s2/util/math/exactfloat/exactfloat.h | 5 +- .../util/math/exactfloat/exactfloat_test.cc | 4 + 6 files changed, 420 insertions(+), 434 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 3cbbb068..13c881ba 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -1,13 +1,54 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "s2/util/math/exactfloat/bignum.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + namespace exactfloat_internal { -// Threshold for fallback to simple multiplication, determined empirically. -static constexpr int kKaratsubaThreshold = 64; +// Number of bigits in the result of a multiplication before we fall back to +// simple multiplication in the Karatsuba recursion. Determined empirically. +static constexpr int kSimpleMulThreshold = 64; -static Bigit MulAdd( // - absl::Span out, absl::Span a, Bigit b, Bigit c); +// Computes out[i] = a[i]*b + c +// +// Returns the final carry, if any. +inline Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, + Bigit c); +// Compares magnitude magnitude of two bigit vectors, returning -1, 0, or +1. +// +// Magnitudes are compared lexicographically from the most significant bigit +// to the least significant. int CmpAbs(absl::Span a, absl::Span b) { if (a.size() != b.size()) { return a.size() < b.size() ? -1 : +1; @@ -22,9 +63,24 @@ int CmpAbs(absl::Span a, absl::Span b) { return 0; } +int Bignum::Compare(const Bignum& b) const { + if (is_negative() != b.is_negative()) { + return is_negative() ? -1 : +1; + } + + // Signs are equal, are they both zero? + if (is_zero() && b.is_zero()) { + return 0; + } + + // Signs are equal and non-zero, compare magnitude. + const int compare = CmpAbs(bigits_, b.bigits_); + return is_negative() ? -compare : compare; +} + std::optional Bignum::FromString(absl::string_view s) { // A chunk is up to 19 decimal digits, which can always fit into a Bigit. - constexpr int kMaxChunkDigits = std::numeric_limits::digits10; + constexpr ssize_t kMaxChunkDigits = std::numeric_limits::digits10; // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the // sake of simplicity. This isn't the fastest algorithm, being quadratic in @@ -46,61 +102,44 @@ std::optional Bignum::FromString(absl::string_view s) { return out; } - // Reserve space for bigits. out.bigits_.reserve((s.size() + kMaxChunkDigits - 1) / kMaxChunkDigits); - int sign = +1; - Bigit chunk = 0; - int clen = 0; - - // Finish processing the current chunk. - auto FlushChunk = [&]() { - if (clen) { - auto outspan = absl::MakeSpan(out.bigits_); - if (Bigit carry = MulAdd(outspan, outspan, kPow10[clen], chunk)) { - out.bigits_.emplace_back(carry); - } - chunk = 0; - clen = 0; - } - }; + bool negative = false; // Consume optional +/- at the front. - int start = 0; - if ((s[0] == '+' || s[0] == '-')) { - sign = (s[0] == '-') ? -1 : +1; - ++start; + auto begin = s.cbegin(); + if ((*begin == '+' || *begin == '-')) { + negative = (s[0] == '-'); + ++begin; } - bool seen_digit = false; - for (char c : s.substr(start)) { - if (!absl::ascii_isdigit(c)) { - return std::nullopt; - } + const auto end = s.cend(); + while (begin < end) { + size_t chunk_len = std::min(std::distance(begin, end), kMaxChunkDigits); - // Accumulate digit into the local 64-bit chunk. Skip leading zeros. - uint64_t digit = static_cast(c - '0'); - if (!seen_digit && digit == 0) { - continue; + Bigit chunk; + auto result = std::from_chars(begin, begin + chunk_len, chunk); + if (result.ec != std::errc() || (result.ptr - begin) != chunk_len) { + return std::nullopt; } - seen_digit = true; + begin += chunk_len; - chunk = 10 * chunk + digit; - ++clen; - - if (clen == kMaxChunkDigits) { - FlushChunk(); + // Shift out up by chunk_len digits and add the chunk to it. + auto outspan = absl::MakeSpan(out.bigits_); + Bigit carry = MulAdd(outspan, outspan, kPow10[chunk_len], chunk); + if (carry) { + out.bigits_.emplace_back(carry); } } - FlushChunk(); - out.NormalizeSign(sign); + out.negative_ = negative; + out.Normalize(); return out; } int bit_width(const Bignum& a) { - ABSL_DCHECK(a.Normalized()); - if (a.empty()) { + ABSL_DCHECK(a.is_normalized()); + if (a.bigits_.empty()) { return 0; } @@ -122,14 +161,14 @@ int countr_zero(const Bignum& a) { if (bigit == 0) { nzero += Bignum::kBigitBits; } else { - nzero += absl::countr_zero(static_cast(bigit)); + nzero += absl::countr_zero(bigit); break; } } return nzero; } -bool Bignum::Bit(int nbit) const { +bool Bignum::is_bit_set(int nbit) const { ABSL_DCHECK_GE(nbit, 0); if (is_zero()) { return false; @@ -138,7 +177,7 @@ bool Bignum::Bit(int nbit) const { const size_t digit = nbit / kBigitBits; const size_t shift = nbit % kBigitBits; - if (digit >= size()) { + if (digit >= bigits_.size()) { return false; } @@ -147,7 +186,7 @@ bool Bignum::Bit(int nbit) const { Bignum Bignum::operator-() const { Bignum result = *this; - result.sign_ = -result.sign_; + result.set_negative(!result.negative_); return result; } @@ -188,7 +227,7 @@ Bignum& Bignum::operator>>=(int nbit) { // Shifting by more than the bit width results in zero. if (nbit >= bit_width(*this)) { - return SetZero(); + return set_zero(); } const int nbigit = nbit / kBigitBits; @@ -208,7 +247,7 @@ Bignum& Bignum::operator>>=(int nbit) { } // Result might be smaller or zero, so normalize. - NormalizeSign(sign_); + Normalize(); return *this; } @@ -225,14 +264,6 @@ Bignum Bignum::Pow(int32_t pow) const { return Bignum(0); } - if (*this == Bignum(1)) { - return Bignum(1); - } - - if (*this == Bignum(-1)) { - return (pow % 2 != 0) ? Bignum(-1) : Bignum(1); - } - // Core algorithm: Exponentiation by squaring. Bignum result(1); Bignum base = *this; // A mutable copy of the base. @@ -249,37 +280,55 @@ Bignum Bignum::Pow(int32_t pow) const { return result; } -// Computes a + b + c and updates the carry. -static Bigit AddCarry(Bigit a, Bigit b, Bigit& c) { - auto sum = absl::uint128(a) + b + c; - c = absl::Uint128High64(sum); +// Computes a + b + carry and updates the carry. +inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { + auto sum = absl::uint128(a) + b + *carry; + *carry = absl::Uint128High64(sum); return static_cast(sum); } -// Computes a - b - c and updates the borrow. -static Bigit SubBorrow(Bigit a, Bigit b, Bigit& borrow) { - Bigit diff = a - b - borrow; - borrow = (a < b) || (borrow && (a == b)); +// Computes a - b - borrow and updates the borrow. +// +// NOTE: Borrow must be one or zero. +inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { + ABSL_DCHECK_LE(borrow, 1); + Bigit diff = a - b - *borrow; + *borrow = (a < b) || (*borrow && (a == b)); return diff; } -// Computes a * b + c and updates the carry. -static Bigit MulCarry(Bigit a, Bigit b, Bigit& c) { - auto sum = absl::uint128(a) * b + c; - c = absl::Uint128High64(sum); +// Computes a * b + carry and updates the carry. +inline Bigit MulCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { + auto sum = absl::uint128(a) * b + *carry; + *carry = absl::Uint128High64(sum); return static_cast(sum); } // Computes out += a * b + c and updates the carry. -static void MulAddCarry(Bigit& out, Bigit a, Bigit b, Bigit& c) { - auto sum = absl::uint128(a) * b + c + out; - c = absl::Uint128High64(sum); +// +// NOTE: Will not overflow even if a, b, and c are their maximum values. +inline void MulAddCarry(Bigit& out, Bigit a, Bigit b, + Bigit* absl_nonnull carry) { + auto sum = absl::uint128(a) * b + *carry + out; + *carry = absl::Uint128High64(sum); out = static_cast(sum); } // Computes a += b in place. Returns the final carry (if any). -// NOTE: the a operand must be pre-expanded to fit b. -static Bigit AddInPlace(absl::Span a, absl::Span b) { +// +// A operand must be at least as large as B. When adding two same-sized values, +// the result may overflow and be larger than either of them, in which case we +// will return the final carry value. +// +// This allows a work flow like this: +// Bigit carry = AddInPlace(a, b); +// if (carry) { +// a.bigits_.emplace_back(carry); +// } +// +// Rather than having to expand A to B.bigits_.size() + 1, and popping off the +// top bigit if it's unused (which is the most common case). +inline Bigit AddInPlace(absl::Span a, absl::Span b) { ABSL_DCHECK_GE(a.size(), b.size()); Bigit* pa = a.data(); @@ -290,32 +339,42 @@ static Bigit AddInPlace(absl::Span a, absl::Span b) { // Dispatch four at a time to help loop unrolling. while (left >= 4) { - for (int i = 0; i < 4; ++i) { - *pa = AddCarry(*pa, *pb++, carry); - ++pa; + for (int i = 0; i < 4; ++i, ++pa, ++pb) { + *pa = AddCarry(*pa, *pb, &carry); --left; } } // Finish remainder. - while (left--) { - *pa = AddCarry(*pa, *pb++, carry); - ++pa; + for (; left > 0; --left, pa++, pb++) { + *pa = AddCarry(*pa, *pb, &carry); } // Propagate carry through the rest of a. int remaining = a.size() - b.size(); - while (carry && remaining--) { - *pa = AddCarry(*pa, 0, carry); - ++pa; + for (; carry && remaining > 0; --remaining, pa++) { + *pa = AddCarry(*pa, 0, &carry); } return carry; } -static size_t AddInto( // - absl::Span dst, absl::Span a, - absl::Span b) { +// Computes dst = a + b out of place. Returns the number of bigits actually +// written into dst. +// +// NOTE: dst must be sized to be larger than max(a.size(), b.size()) + 1 (i.e. +// it must be able to hold the carry bigit, if any. +// +// This allows for using a pre-allocated buffer to store the result of an +// addition followed by trimming down to size: +// +// auto out = arena.Alloc(std::max(a.size(), b.size()) + 1); +// out = out.first(AddInto(out, a, b)); +// +// Which is used in the Karatsuba multiplication, where we don't have the option +// to expand the allocate space on demand. +inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, + absl::Span b) { const size_t max_size = std::max(a.size(), b.size()); const size_t min_size = std::min(a.size(), b.size()); ABSL_DCHECK_GE(dst.size(), max_size + 1); @@ -328,18 +387,17 @@ static size_t AddInto( // Bigit carry = 0; // Dispatch four at a time to help loop unrolling. - size_t size = min_size; size_t i = 0; - while (size >= i + 4) { + while (i + 4 < min_size) { for (int j = 0; j < 4; ++j) { - pdst[i] = AddCarry(pa[i], pb[i], carry); + pdst[i] = AddCarry(pa[i], pb[i], &carry); ++i; } } // Finish remainder of common parts. - for (; i < size; ++i) { - pdst[i] = AddCarry(pa[i], pb[i], carry); + for (; i < min_size; ++i) { + pdst[i] = AddCarry(pa[i], pb[i], &carry); } // Copy remaining digits from the longer operand and propagate carry. @@ -347,17 +405,17 @@ static size_t AddInto( // const Bigit* plonger = (a.size() > b.size()) ? pa : pb; // Dispatch four at a time for the remaining part. - size = longer.size(); - while (size >= i + 4) { + const int size = longer.size(); + while (i + 4 < size) { for (int j = 0; j < 4; ++j) { - pdst[i] = AddCarry(plonger[i], 0, carry); + pdst[i] = AddCarry(plonger[i], 0, &carry); ++i; } } // Finish remainder. for (; i < size; ++i) { - pdst[i] = AddCarry(plonger[i], 0, carry); + pdst[i] = AddCarry(plonger[i], 0, &carry); } if (carry) { @@ -373,8 +431,8 @@ static size_t AddInto( // // actually set in A must be passed in via a_digits. // // REQUIRES: |a| < |b|. -static Bigit SubLtInPlace( // - absl::Span a, absl::Span b, size_t a_digits) { +inline Bigit SubLtInPlace(absl::Span a, absl::Span b, + size_t a_digits) { ABSL_DCHECK_EQ(a.size(), b.size()); ABSL_DCHECK_LT(CmpAbs(a, b), 0); @@ -385,21 +443,21 @@ static Bigit SubLtInPlace( // // Dispatch four at a time to help loop unrolling. size_t size = a_digits; size_t i = 0; - while (size >= i + 4) { + while (i + 4 < size) { for (int j = 0; j < 4; ++j) { - pa[i] = SubBorrow(pb[i], pa[i], borrow); + pa[i] = SubBorrow(pb[i], pa[i], &borrow); ++i; } } // Finish remainder. for (; i < a_digits; ++i) { - pa[i] = SubBorrow(pb[i], pa[i], borrow); + pa[i] = SubBorrow(pb[i], pa[i], &borrow); } // Propagate borrow through the rest of b. for (; borrow && i < b.size(); ++i) { - pa[i] = SubBorrow(pb[i], 0, borrow); + pa[i] = SubBorrow(pb[i], 0, &borrow); } return borrow; } @@ -407,7 +465,7 @@ static Bigit SubLtInPlace( // // Computes a -= b. Returns the final borrow (if any). // // REQUIRES: |a| >= |b|. -static Bigit SubGeInPlace(absl::Span a, absl::Span b) { +inline Bigit SubGeInPlace(absl::Span a, absl::Span b) { ABSL_DCHECK_GE(a.size(), b.size()); ABSL_DCHECK_GE(CmpAbs(a, b), 0); @@ -419,16 +477,16 @@ static Bigit SubGeInPlace(absl::Span a, absl::Span b) { // Dispatch four at a time to help loop unrolling. size_t size = b.size(); size_t done = 0; - while (size >= done + 4) { + while (done + 4 < size) { for (int i = 0; i < 4; ++i) { - pa[done] = SubBorrow(pa[done], pb[done], borrow); + pa[done] = SubBorrow(pa[done], pb[done], &borrow); ++done; } } // Finish remainder of subtraction. for (; done < size; ++done) { - pa[done] = SubBorrow(pa[done], pb[done], borrow); + pa[done] = SubBorrow(pa[done], pb[done], &borrow); } // Propagate the borrow through a. @@ -439,11 +497,8 @@ static Bigit SubGeInPlace(absl::Span a, absl::Span b) { return borrow; } -// Computes out[i] = a[i]*b + c -// -// Returns the final carry, if any. -Bigit MulAdd( // - absl::Span out, absl::Span a, Bigit b, Bigit c = 0) { +Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, + Bigit c) { ABSL_DCHECK_GE(out.size(), a.size()); Bigit* pout = out.data(); @@ -453,14 +508,14 @@ Bigit MulAdd( // // Dispatch four at a time to help loop unrolling. while (left >= 4) { - for (int i = 0; i < 4; ++i) { - *pout++ = MulCarry(*pa++, b, c); + for (int i = 0; i < 4; ++i, pa++, pout++) { + *pout = MulCarry(*pa, b, &c); --left; } } - while (left--) { - *pout++ = MulCarry(*pa++, b, c); + for (; left > 0; --left, pa++, pout++) { + *pout = MulCarry(*pa, b, &c); } return c; } @@ -468,8 +523,8 @@ Bigit MulAdd( // // Computes out[i] += a[i]*b in place. // // Returns the final carry, if any. -static Bigit MulAddInPlace( // - absl::Span out, absl::Span a, Bigit b) { +inline Bigit MulAddInPlace(absl::Span out, absl::Span a, + Bigit b) { Bigit* pout = out.data(); const Bigit* pa = a.data(); @@ -479,22 +534,21 @@ static Bigit MulAddInPlace( // Bigit carry = 0; while (left >= 4) { for (int i = 0; i < 4; ++i) { - MulAddCarry(*pout++, *pa++, b, carry); + MulAddCarry(*pout++, *pa++, b, &carry); --left; } } // Finish remainder. while (left--) { - MulAddCarry(*pout++, *pa++, b, carry); + MulAddCarry(*pout++, *pa++, b, &carry); } return carry; } -static void MulQuadratic( // - absl::Span out, // - absl::Span a, absl::Span b) { +inline void MulQuadratic(absl::Span out, absl::Span a, + absl::Span b) { ABSL_DCHECK_GE(out.size(), a.size() + b.size()); // Make sure A is the longer of the two arguments. @@ -509,7 +563,7 @@ static void MulQuadratic( // } auto upper = out.subspan(a.size()); - upper[0] = MulAdd(out, a, b[0]); + upper[0] = MulAdd(out, a, b[0], 0); const size_t size = b.size(); size_t i = 1; @@ -533,8 +587,8 @@ static void MulQuadratic( // // Split a span into two contiguous pieces of length a and b, respectively. template -static std::pair, absl::Span> Split( // - absl::Span span, int a, int b) { +inline std::pair, absl::Span> Split(absl::Span span, + size_t a, size_t b) { if (a < span.size()) { return {span.subspan(0, a), span.subspan(a, b)}; } @@ -569,9 +623,8 @@ class Arena { std::vector data_; }; -static void KaratsubaMulRec( // - absl::Span dst, // - absl::Span a, absl::Span b, Arena& arena) { +inline void KaratsubaMulRec(absl::Span dst, absl::Span a, + absl::Span b, Arena& arena) { ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); if (a.empty() || b.empty()) { absl::c_fill(dst, 0); @@ -605,7 +658,7 @@ static void KaratsubaMulRec( // // with those individual multiplies able to be recursively divided. // Fall back to long multiplication when we're small enough. - if (dst.size() <= kKaratsubaThreshold) { + if (dst.size() <= kSimpleMulThreshold) { MulQuadratic(dst, a, b); return; } @@ -634,12 +687,12 @@ static void KaratsubaMulRec( // absl::Span sb = b0; if (!a1.empty()) { absl::Span tmp = arena.Alloc(half + 1); - sa = tmp.first(AddInto(tmp, a0, a1)); + sa = tmp.first(AddOutOfPlace(tmp, a0, a1)); } if (!b1.empty()) { absl::Span tmp = arena.Alloc(half + 1); - sb = tmp.first(AddInto(tmp, b0, b1)); + sb = tmp.first(AddOutOfPlace(tmp, b0, b1)); } // Compute z1 = sa*sb - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 @@ -672,8 +725,8 @@ static void KaratsubaMulRec( // arena.Release(arena.Used() - arena_start); } -Bignum::BigitVector Bignum::KaratsubaMul( // - absl::Span a, absl::Span b) { +Bignum::BigitVector Bignum::KaratsubaMul(absl::Span a, + absl::Span b) { if (a.empty() || b.empty()) { return {}; } @@ -691,7 +744,7 @@ Bignum::BigitVector Bignum::KaratsubaMul( // int next = half + 1; peak += 4 * next; size = next; - } while (size > kKaratsubaThreshold); + } while (size > kSimpleMulThreshold); Arena arena(peak); BigitVector out(a.size() + b.size(), 0); @@ -709,27 +762,31 @@ Bignum& Bignum::operator+=(const Bignum& b) { return *this; } - if (sign_ == b.sign_) { + if (is_negative() == b.is_negative()) { // Same sign: - // (+a) + (+b) == +(a + b) - // (-a) + (-b) == -(a + b) - bigits_.resize(std::max(size(), b.size()), 0); + // +|a| + +|b| == +|a + b| + // -|a| + -|b| == -|a + b| + // + // So we can just sum magnitudes. + bigits_.resize(std::max(bigits_.size(), b.bigits_.size()), 0); Bigit carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); if (carry) { bigits_.emplace_back(carry); } Normalize(); } else { - if (CmpAbs(b) >= 0) { - // |a| >= |b|, so a - b is same sign as a. + if (CmpAbs(bigits_, b.bigits_) >= 0) { + // |a| >= |b|, so a - b is the same sign as a. SubGeInPlace(absl::MakeSpan(bigits_), b.bigits_); - NormalizeSign(sign_); + Normalize(); } else { - // |a| < |b|, so a - b is same sign as b. - const int prev_size = size(); - bigits_.resize(b.size()); + // |a| < |b|, so a - b is the same sign as b. + const int prev_size = bigits_.size(); + bigits_.resize(b.bigits_.size()); SubLtInPlace(absl::MakeSpan(bigits_), b.bigits_, prev_size); - NormalizeSign(b.sign_); + + negative_ = b.negative_; + Normalize(); } } @@ -737,45 +794,28 @@ Bignum& Bignum::operator+=(const Bignum& b) { } Bignum& Bignum::operator-=(const Bignum& b) { - if (b.is_zero()) { + if (this == &b) { + set_zero(); return *this; } - if (is_zero()) { - return *this = -b; - } - - if (sign_ != b.sign_) { - bigits_.resize(std::max(size(), b.size()), 0); - uint64_t carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); - if (carry) { - bigits_.emplace_back(carry); - } - Normalize(); - } else { - if (CmpAbs(b) >= 0) { - SubGeInPlace(absl::MakeSpan(bigits_), b.bigits_); - NormalizeSign(sign_); - } else { - const int prev_size = size(); - bigits_.resize(b.size()); - SubLtInPlace(absl::MakeSpan(bigits_), b.bigits_, prev_size); - NormalizeSign(-sign_); - } - } - + // Compute -(-a + b) == a - b + negate(); + *this += b; + negate(); return *this; } Bignum& Bignum::operator*=(const Bignum& b) { if (is_zero() || b.is_zero()) { - return SetZero(); + return set_zero(); } - const int new_sign = sign_ * b.sign_; + // Result is only negative if signs are different. + const bool negative = (is_negative() != b.is_negative()); // Fast path for single-bigit multiplication. - if (size() == 1 && b.size() == 1) { + if (bigits_.size() == 1 && b.bigits_.size() == 1) { absl::uint128 prod = absl::uint128(bigits_[0]) * b.bigits_[0]; const uint64_t lo = absl::Uint128Low64(prod); const uint64_t hi = absl::Uint128High64(prod); @@ -784,14 +824,16 @@ Bignum& Bignum::operator*=(const Bignum& b) { } else { bigits_ = {lo, hi}; } - sign_ = new_sign; + set_negative(negative); return *this; } // Use Karatsuba multiplication. // If the inputs are small enough this will just do long multiplication. bigits_ = KaratsubaMul(bigits_, b.bigits_); - NormalizeSign(new_sign); + + negative_ = negative; + Normalize(); return *this; } diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index 80c1ec7c..89baae38 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -33,20 +33,20 @@ namespace exactfloat_internal { +// A digit of a bignum. A contraction of "big digit" (rhymes with the latter). using Bigit = uint64_t; -// Compares magnitude magnitude of two bigit vectors, returning -1, 0, or +1. -// -// Magnitudes are compared lexicographically from the most significant bigit -// to the least significant. -int CmpAbs(absl::Span a, absl::Span b); - // A class to support arithmetic on large, arbitrary precision integers. // // Large integers are represented as an array of uint64_t values. class Bignum { public: - using BigitVector = absl::InlinedVector; + // The most common use of ExactFloat involves evaluating a 3x3 determinant to + // determine whether 3 points are oriented clockwise or counter-clockwise. + // + // The typical number of mantissa bits in the result is probably about 170, so + // we allocate 4 bigits (256 bits) inline. + using BigitVector = absl::InlinedVector; static constexpr int kBigitBits = std::numeric_limits::digits; @@ -75,8 +75,8 @@ class Bignum { return os << absl::StrFormat("%v", b); } - friend std::ostream& operator<<( // - std::ostream& os, const std::optional& b) { + friend std::ostream& operator<<(std::ostream& os, + const std::optional& b) { if (!b) { return os << "[nullopt]"; } @@ -110,64 +110,40 @@ class Bignum { //-------------------------------------- // Returns the number of bits required for the magnitude of the value. + // + // Named to match std::bit_width. friend int bit_width(const Bignum& a); // Returns the number of consecutive 0 bits in the value, starting from the // least significant bit. + // + // Named to match std::countr_zero. friend int countr_zero(const Bignum& a); - // Returns true if the n-th bit of the number's magnitude is set. - bool Bit(int nbit) const; + // Returns true if the n-th bit of the number's magnitude is 1. + bool is_bit_set(int nbit) const; // Clears this bignum and sets it to zero. - Bignum& SetZero() { - sign_ = 0; + Bignum& set_zero() { + negative_ = false; bigits_.clear(); return *this; } - // Unconditionally makes the sign of this bignum negative. - Bignum& SetNegative() { - sign_ = -1; + // Sets the negative flag on the value. If the value is zero, has no effect. + Bignum& set_negative(bool negative = true) { + negative_ = !bigits_.empty() && negative; return *this; } - // Unconditionally makes the sign of this bignum positive. - Bignum& SetPositive() { - sign_ = +1; - return *this; - } - - // Unconditionally set the sign of this bignum to match the sign of the - // argument. If the argument is zero, set the bignum to zero. - Bignum& SetSign(int sign) { - if (sign == 0) { - return SetZero(); - } - - if (sign < 0) { - return SetNegative(); - } - return SetPositive(); - } - // Returns true if the number is zero. - bool is_zero() const { // - return sign_ == 0; - } - - // Returns true if the number is greater than zero. - bool positive() const { // - return sign_ > 0; - } + bool is_zero() const { return bigits_.empty(); } // Returns true if the number is less than zero. - bool negative() const { // - return sign_ < 0; - } + bool is_negative() const { return negative_; } // Returns true if the number is odd (least significant bit is 1). - bool is_odd() const { return Bit(0); } + bool is_odd() const { return is_bit_set(0); } // Returns true if the number is even (least significant bit is 0). bool is_even() const { return !is_odd(); } @@ -177,7 +153,7 @@ class Bignum { //-------------------------------------- bool operator==(const Bignum& b) const { - return sign_ == b.sign_ && bigits_ == b.bigits_; + return negative_ == b.negative_ && bigits_ == b.bigits_; } bool operator!=(const Bignum& b) const { return !(*this == b); } @@ -208,91 +184,43 @@ class Bignum { friend Bignum operator<<(Bignum a, int nbit) { return a <<= nbit; } friend Bignum operator>>(Bignum a, int nbit) { return a >>= nbit; } - private: - // Constructs a Bignum from bigits and an optional sign bit. - explicit Bignum(BigitVector bigits, int sign = +1) - : bigits_(std::move(bigits)) { - NormalizeSign(sign); - } - - // Returns the number of bigits in this bignum. - size_t size() const { // - return bigits_.size(); - } - - // Returns true if this value has no digits. - bool empty() const { // - return bigits_.empty(); + // Negates this bignum in place. + void negate() { + negative_ = !negative_; + Normalize(); } - // Compare to another bignum, returns -1, 0, +1. - int Compare(const Bignum& b) const { - if (sign_ != b.sign_) { - return sign_ < b.sign_ ? -1 : 1; - } - - // Signs are equal, are they both zero? - if (sign_ == 0) { - return 0; - } + // Compares to another bignum, returning -1, 0, +1. + int Compare(const Bignum& b) const; - // Signs are equal and non-zero, compare magnitude. - return positive() ? CmpAbs(b) : -CmpAbs(b); + private: + // Constructs a Bignum from bigits and an optional sign bit. + explicit Bignum(BigitVector bigits, bool negative = false) + : bigits_(std::move(bigits)), negative_(negative) { + Normalize(); } // Multiplies two unsigned bigit vectors together using Karatsuba's algorithm. - static BigitVector KaratsubaMul( // - absl::Span a, absl::Span b); + static BigitVector KaratsubaMul(absl::Span a, + absl::Span b); - // Drop leading zero bigits. + // Drop leading zero bigits, and ensure sign is positive if result is zero. void Normalize() { - while (!empty() && bigits_.back() == 0) { + while (!bigits_.empty() && bigits_.back() == 0) { bigits_.pop_back(); } - - if (empty()) { - sign_ = 0; - } - } - - // Drop leading zero bigits and canonicalize sign. - void NormalizeSign(int sign) { - Normalize(); - sign_ = empty() ? 0 : sign; + negative_ = !bigits_.empty() && negative_; } // Returns true if the bignum is in normal form (no extra leading zeros). - bool Normalized() const { // - return bigits_.empty() || bigits_.back() != 0; - } - - // Returns true if the bignum magnitude is the given power of two. - bool IsPow2(int pow2) const { - const size_t bigits = pow2 / kBigitBits; - if (bigits_.size() != bigits + 1) { - return false; - } - - // Verify lower words are zero. - for (size_t i = 0; i < bigits; ++i) { - if (bigits_[i] != 0) { - return false; - } - } - - // Check final word is power of two. - pow2 -= bigits * kBigitBits; - ABSL_DCHECK_LT(pow2, kBigitBits); - return bigits_.back() == (Bigit(1) << pow2); - } - - // Compares magnitude with another bignum, returning -1, 0, or +1. - int CmpAbs(const Bignum& b) const { - return ::exactfloat_internal::CmpAbs(bigits_, b.bigits_); - } + bool is_normalized() const { return bigits_.empty() || bigits_.back() != 0; } + // We store bignums in sign-magnitude form. bigits_ contains the individual + // 64-bit digits of the bignum. If bigits_ is non-empty, then the last element + // must be non-zero and when it is empty (representing a zero value), + // negative_ must be false. BigitVector bigits_; - int sign_ = 0; + bool negative_ = false; }; //////////////////////////////////////////////////////////////////////////////// @@ -308,9 +236,10 @@ Bignum::Bignum(T value) { return; } - sign_ = +1; + negative_ = false; if constexpr (std::is_signed_v) { - sign_ = (value < 0) ? -1 : +1; + // Put into constexpr if to avoid warnings when T is unsigned. + negative_ = (value < 0); } // Get magnitude of value, handle minimum value of T cleanly. @@ -341,24 +270,27 @@ void AbslStringify(Sink& sink, const Bignum& b) { } // Sign - if (b.negative()) { + if (b.is_negative()) { sink.Append("-"); } // Work on a copy of the magnitude. Bignum copy = b; - copy.sign_ = 1; + copy.negative_ = false; - // Repeatedly divide and modulo by 10^19 to get decimal chunks. - static constexpr uint64_t kBase = 10'000'000'000'000'000'000u; + // 10**19 is the largest power of 10 that fits in 64-bits. So we can + // repeatedly divide and modulo the bignum to get uint64_t values we can + // format as 19 decimal digits. + static_assert(sizeof(Bigit) * 8 == 64); + static constexpr uint64_t kChunkDivisor = 10'000'000'000'000'000'000u; Bignum::BigitVector chunks; while (!copy.is_zero()) { absl::uint128 rem = 0; for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { absl::uint128 acc = (rem << 64) + copy.bigits_[i]; - Bigit quot = static_cast(acc / kBase); - rem = acc - absl::uint128(quot) * kBase; + Bigit quot = static_cast(acc / kChunkDivisor); + rem = acc - absl::uint128(quot) * kChunkDivisor; copy.bigits_[i] = quot; } @@ -380,7 +312,7 @@ template inline bool Bignum::FitsIn() const { using UT = std::make_unsigned_t; - if (sign_ == 0) { + if (is_zero()) { return true; } @@ -396,7 +328,7 @@ inline bool Bignum::FitsIn() const { // Unsigned type T can hold the value iff the value is non-negative and // the bitwidth is <= the maximum bit width of the type. if constexpr (!std::is_signed_v) { - if (negative()) { + if (is_negative()) { return false; } return bit_width(*this) <= kTBitWidth; @@ -405,16 +337,16 @@ inline bool Bignum::FitsIn() const { // T is signed and our bignum isn't zero. ABSL_DCHECK(std::is_signed_v && !is_zero()); - if (positive()) { - return bit_width(*this) <= (kTBitWidth - 1); - } else /* negative() */ { + if (is_negative()) { // Magnitude must fit in negative value. If the value is negative and // the same bit width as the output type, the only valid value is // -2^(k-1). if (bit_width(*this) == kTBitWidth) { - return IsPow2(kTBitWidth - 1); + return countr_zero(*this) == kTBitWidth - 1; } return bit_width(*this) < kTBitWidth; + } else /* positive */ { + return bit_width(*this) <= (kTBitWidth - 1); } } @@ -424,28 +356,24 @@ T Bignum::Cast() const { constexpr int kTBitWidth = std::numeric_limits::digits; - if (empty()) { + if (bigits_.empty()) { return 0; } - // Grab the bottom bits into an unsigned value. + // T fits in a Bigit, so just cast to truncate. UT residue = 0; - for (size_t i = 0; i < bigits_.size(); ++i) { - const int shift = i * kBigitBits; - if (shift >= kTBitWidth) { - break; - } - - const int room = kTBitWidth - shift; - UT chunk = static_cast(bigits_[i]); - if (room < kBigitBits && room < std::numeric_limits::digits) { - chunk &= (UT(1) << room) - UT(1); + if (kTBitWidth <= kBigitBits) { + residue = static_cast(bigits_[0]); + } else { + ABSL_DCHECK_EQ(kTBitWidth % kBigitBits, 0); + for (int i = 0; i < kTBitWidth / kBigitBits; ++i) { + UT chunk = static_cast(bigits_[i]); + residue |= chunk << (i * kBigitBits); } - residue |= (chunk << shift); } // Compute two's complement of the residue if value is negative. - if (negative()) { + if (is_negative()) { residue = UT(0) - residue; } diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index b6ef4ef8..135433c2 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -15,7 +15,11 @@ #include "s2/util/math/exactfloat/bignum.h" -#include +#include +#include +#include +#include +#include #include #include #include @@ -25,10 +29,10 @@ #include "benchmark/benchmark.h" #endif -#include "absl/base/no_destructor.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "gtest/gtest.h" #include "openssl/bn.h" @@ -566,34 +570,34 @@ TEST(BignumTest, CountrZero) { EXPECT_EQ(countr_zero(neg_large_shifted), 200); } -TEST(BignumTest, Bit) { - EXPECT_FALSE(Bignum(0).Bit(0)); - EXPECT_FALSE(Bignum(0).Bit(100)); +TEST(BignumTest, is_bit_set) { + EXPECT_FALSE(Bignum(0).is_bit_set(0)); + EXPECT_FALSE(Bignum(0).is_bit_set(100)); // 5 = 0b101 Bignum five(5); - EXPECT_TRUE(five.Bit(0)); - EXPECT_FALSE(five.Bit(1)); - EXPECT_TRUE(five.Bit(2)); - EXPECT_FALSE(five.Bit(3)); + EXPECT_TRUE(five.is_bit_set(0)); + EXPECT_FALSE(five.is_bit_set(1)); + EXPECT_TRUE(five.is_bit_set(2)); + EXPECT_FALSE(five.is_bit_set(3)); // Negative numbers should test the magnitude. Bignum neg_five(-5); - EXPECT_TRUE(neg_five.Bit(0)); - EXPECT_FALSE(neg_five.Bit(1)); - EXPECT_TRUE(neg_five.Bit(2)); + EXPECT_TRUE(neg_five.is_bit_set(0)); + EXPECT_FALSE(neg_five.is_bit_set(1)); + EXPECT_TRUE(neg_five.is_bit_set(2)); // Test edges of and across bigits. - Bignum high_bit_63 = Bignum(1) << 63; - EXPECT_FALSE(high_bit_63.Bit(62)); - EXPECT_TRUE(high_bit_63.Bit(63)); - EXPECT_FALSE(high_bit_63.Bit(64)); + Bignum high_is_bit_set_63 = Bignum(1) << 63; + EXPECT_FALSE(high_is_bit_set_63.is_bit_set(62)); + EXPECT_TRUE(high_is_bit_set_63.is_bit_set(63)); + EXPECT_FALSE(high_is_bit_set_63.is_bit_set(64)); Bignum cross_bigit = (Bignum(1) << 100) + Bignum(1); - EXPECT_TRUE(cross_bigit.Bit(0)); - EXPECT_TRUE(cross_bigit.Bit(100)); - EXPECT_FALSE(cross_bigit.Bit(50)); - EXPECT_FALSE(cross_bigit.Bit(1000)); + EXPECT_TRUE(cross_bigit.is_bit_set(0)); + EXPECT_TRUE(cross_bigit.is_bit_set(100)); + EXPECT_FALSE(cross_bigit.is_bit_set(50)); + EXPECT_FALSE(cross_bigit.is_bit_set(1000)); } TEST(BignumTest, Pow) { @@ -626,35 +630,29 @@ TEST(BignumTest, Pow) { TEST(BignumTest, SetZero) { Bignum a(123); - a.SetZero(); + a.set_zero(); EXPECT_TRUE(a.is_zero()); Bignum b(-456); - b.SetZero(); + b.set_zero(); EXPECT_EQ(b, Bignum(0)); } TEST(BignumTest, SetNegativeSetPositive) { Bignum a(42); - a.SetNegative(); - EXPECT_TRUE(a.negative()); + a.set_negative(); + EXPECT_TRUE(a.is_negative()); EXPECT_EQ(a, Bignum(-42)); - a.SetPositive(); - EXPECT_TRUE(a.positive()); + a.set_negative(false); + EXPECT_FALSE(a.is_negative()); EXPECT_EQ(a, Bignum(42)); -} - -TEST(BignumTest, SetSign) { - Bignum a(99); - a.SetSign(-10); // any negative - EXPECT_EQ(a, Bignum(-99)); - a.SetSign(5); // any positive - EXPECT_EQ(a, Bignum(99)); - - a.SetSign(0); - EXPECT_TRUE(a.is_zero()); + // set_negative() has no effect on zero. + Bignum b(0); + b.set_negative(); + EXPECT_FALSE(b.is_negative()); + EXPECT_EQ(b, Bignum(0)); } TEST(BignumTest, Comparisons) { @@ -715,7 +713,9 @@ class OpenSSLBignum { OpenSSLBignum() : bn_(BN_new()) {} // Construct from a decimal number in a string. - explicit OpenSSLBignum(const absl::string_view& decimal) : bn_(BN_new()) { + // + // We take decimal as a string so that it's explicitly zero-terminated. + explicit OpenSSLBignum(const std::string& decimal) : bn_(BN_new()) { BN_dec2bn(&bn_, decimal.data()); } @@ -756,8 +756,8 @@ class OpenSSLBignum { // Power of two for fast modulo. const int kRandomBignumCount = 128; -static std::vector GenerateRandomNumbers(absl::BitGenRef bitgen, - int bits) { +static std::vector GenerateRandomNumberStrings( + absl::BitGenRef bitgen, int bits) { std::vector numbers; numbers.reserve(kRandomBignumCount); @@ -810,42 +810,35 @@ TEST(BignumTest, ResultsMatch) { OPENSSL_free(ssl_str); } -const std::vector& SmallNumbers(absl::BitGenRef bitgen) { - static absl::NoDestructor> numbers( // - GenerateRandomNumbers(bitgen, 64)); - return *numbers; -} - -const std::vector& MediumNumbers(absl::BitGenRef bitgen) { - static absl::NoDestructor> numbers( // - GenerateRandomNumbers(bitgen, 256)); - return *numbers; -} - -const std::vector& LargeNumbers(absl::BitGenRef bitgen) { - static absl::NoDestructor> numbers( // - GenerateRandomNumbers(bitgen, 1024)); - return *numbers; -} +// Different number sizes for benchmarking. +enum class NumberSizeClass : uint32_t { + kSmall = 64, + kMedium = 256, + kLarge = 1024, + kHuge = 4096, + kMega = 18000 +}; -const std::vector& HugeNumbers(absl::BitGenRef bitgen) { - static absl::NoDestructor> numbers( // - GenerateRandomNumbers(bitgen, 4096)); - return *numbers; +std::vector RandomNumberStrings(absl::BitGenRef bitgen, + NumberSizeClass size_class) { + return GenerateRandomNumberStrings(bitgen, static_cast(size_class)); } -const std::vector& MegaNumbers(absl::BitGenRef bitgen) { - static absl::NoDestructor> numbers( // - GenerateRandomNumbers(bitgen, 18000)); - return *numbers; -} +class VsOpenSSLTest : public TestWithParam { + protected: + std::vector Numbers() { + return RandomNumberStrings(bitgen_, GetParam()); + } -class VsOpenSSLTest : public TestWithParam> {}; + private: + absl::BitGen bitgen_; +}; TEST_P(VsOpenSSLTest, SquaringCorrect) { - // Test that multiplication produces correct results by comparing to OpenSSL. + // Test that multiplication produces correct results by comparing to + // OpenSSL. BN_CTX* ctx = BN_CTX_new(); - for (const auto& number : GetParam()) { + for (const auto& number : Numbers()) { // Test same number multiplication (most likely to trigger edge cases) const Bignum bn_a = *Bignum::FromString(number); const Bignum bn_result = bn_a * bn_a; @@ -870,7 +863,7 @@ TEST_P(VsOpenSSLTest, SquaringCorrect) { TEST_P(VsOpenSSLTest, MultiplyCorrect) { // Multiply by a small constant to test widely different operand sizes. BN_CTX* ctx = BN_CTX_new(); - for (const auto& number : GetParam()) { + for (const auto& number : Numbers()) { // Test same number multiplication (most likely to trigger edge cases) const Bignum bn_a = *Bignum::FromString(number); const Bignum bn_result = Bignum(2) * bn_a; @@ -894,7 +887,7 @@ TEST_P(VsOpenSSLTest, MultiplyCorrect) { TEST_P(VsOpenSSLTest, AdditionCorrect) { // Test that addition produces correct results by comparing to OpenSSL. - const std::vector numbers = GetParam(); + const std::vector numbers = Numbers(); for (size_t i = 0; i < numbers.size(); ++i) { const auto& num_a = numbers[i]; const auto& num_b = numbers[(i + 1) % numbers.size()]; @@ -922,8 +915,9 @@ TEST_P(VsOpenSSLTest, AdditionCorrect) { } TEST_P(VsOpenSSLTest, SubtractionCorrect) { - // Test that subtraction produces correct results by comparing to OpenSSL. - const std::vector numbers = GetParam(); + // Test that subtraction produces correct results by comparing to + // OpenSSL. + const std::vector numbers = Numbers(); for (size_t i = 0; i < numbers.size(); ++i) { const auto& num_a = numbers[i]; const auto& num_b = numbers[(i + 1) % numbers.size()]; @@ -950,13 +944,12 @@ TEST_P(VsOpenSSLTest, SubtractionCorrect) { } } -absl::BitGen bitgen; INSTANTIATE_TEST_SUITE_P(VsOpenSSL, VsOpenSSLTest, - ::testing::Values(SmallNumbers(bitgen), - MediumNumbers(bitgen), - LargeNumbers(bitgen), - HugeNumbers(bitgen), - MediumNumbers(bitgen))); + ::testing::Values(NumberSizeClass::kSmall, + NumberSizeClass::kMedium, + NumberSizeClass::kLarge, + NumberSizeClass::kHuge, + NumberSizeClass::kMega)); // TODO: Enable once benchmark is integrated. #if 0 @@ -1067,6 +1060,26 @@ void OpenSSLPowBenchmark(benchmark::State& state, BN_CTX_free(ctx); } +std::vector SmallNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kSmall); +} + +std::vector MediumNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kMedium); +} + +std::vector LargeNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kLarge); +} + +std::vector HugeNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kHuge); +} + +std::vector MegaNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kMega); +} + void BM_Bignum_AddSmall(benchmark::State& state) { std::mt19937_64 bitgen; BignumBinaryOpBenchmark(state, SmallNumbers(bitgen), std::plus{}); diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index fc4cf9a2..89521ef4 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -24,11 +24,8 @@ #include #include -#include "absl/base/macros.h" -#include "absl/container/fixed_array.h" #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" -#include "absl/numeric/bits.h" // IWYU pragma: keep #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" @@ -67,10 +64,13 @@ ExactFloat::ExactFloat(double v) { ExactFloat::ExactFloat(int v) { sign_ = (v >= 0) ? 1 : -1; - - // Note that this works even for INT_MIN, as |INT_MIN| < |INT_MAX|. - bn_ = Bignum(abs(v)); bn_exp_ = 0; + + if (v == std::numeric_limits::min()) { + bn_ = Bignum(unsigned(0) - static_cast(v)); + } else { + bn_ = Bignum(abs(v)); + } Canonicalize(); } @@ -102,19 +102,19 @@ int ExactFloat::exp() const { void ExactFloat::set_zero(int sign) { sign_ = sign; bn_exp_ = kExpZero; - bn_.SetZero(); + bn_.set_zero(); } void ExactFloat::set_inf(int sign) { sign_ = sign; bn_exp_ = kExpInfinity; - bn_.SetZero(); + bn_.set_zero(); } void ExactFloat::set_nan() { sign_ = 1; bn_exp_ = kExpNaN; - bn_.SetZero(); + bn_.set_zero(); } double ExactFloat::ToDouble() const { @@ -136,12 +136,12 @@ double ExactFloat::ToDoubleHelper() const { } return std::copysign(std::numeric_limits::quiet_NaN(), sign_); } - auto opt_mantissa = bn_.ConvertTo(); - ABSL_DCHECK(opt_mantissa.has_value()); + + const double d_mantissa = static_cast(bn_.Cast()); // We rely on ldexp() to handle overflow and underflow. (It will return a // signed zero or infinity if the result is too small or too large.) - return sign_ * ldexp(static_cast(*opt_mantissa), bn_exp_); + return sign_ * ldexp(d_mantissa, bn_exp_); } ExactFloat ExactFloat::RoundToMaxPrec(int max_prec, RoundingMode mode) const { @@ -189,7 +189,7 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // Never increment. } else if (mode == kRoundTiesAwayFromZero) { // Increment if the highest discarded bit is 1. - if (bn_.Bit(shift - 1)) increment = true; + if (bn_.is_bit_set(shift - 1)) increment = true; } else if (mode == kRoundAwayFromZero) { // Increment unless all discarded bits are zero. if (countr_zero(bn_) < shift) increment = true; @@ -201,8 +201,8 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // 0/10* -> Don't increment (fraction = 1/2, kept part even) // 1/10* -> Increment (fraction = 1/2, kept part odd) // ./1.*1.* -> Increment (fraction > 1/2) - if (bn_.Bit(shift - 1) && - ((bn_.Bit(shift) || countr_zero(bn_) < shift - 1))) { + if (bn_.is_bit_set(shift - 1) && + ((bn_.is_bit_set(shift) || countr_zero(bn_) < shift - 1))) { increment = true; } } @@ -322,8 +322,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { } else { // Set bn = bn_ * (5 ** -bn_exp_) and bn_exp10 = bn_exp_. This is // equivalent to the original value of (bn_ * (2 ** bn_exp_)). - int power = -bn_exp_; - bn = Bignum(5).Pow(power) * bn_; + bn = bn_ * Bignum(5).Pow(-bn_exp_); bn_exp10 = bn_exp_; } // Now convert "bn" to a decimal string using our Bignum's string conversion. @@ -341,7 +340,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { // up only if the lowest kept digit is odd. if (all_digits[max_digits] >= '5' && ((all_digits[max_digits - 1] & 1) == 1 || - all_digits.substr(max_digits + 1).find_first_of("123456789") != + all_digits.substr(max_digits + 1).find_first_not_of("0") != std::string::npos)) { // This can increase the number of digits by 1, but in that case at // least one trailing zero will be stripped off below. @@ -405,24 +404,25 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, int b_sign, swap(a_sign, b_sign); swap(a, b); } + // Shift "a" if necessary so that both values have the same bn_exp_. ExactFloat r; - Bignum a_bn; - if (a->bn_exp_ > b->bn_exp_) { - a_bn = a->bn_ << (a->bn_exp_ - b->bn_exp_); - } else { - a_bn = a->bn_; - } + r.bn_ = a->bn_; + r.bn_ <<= (a->bn_exp_ - b->bn_exp_); + r.bn_exp_ = b->bn_exp_; if (a_sign == b_sign) { - r.bn_ = a_bn + b->bn_; + r.bn_ += b->bn_; r.sign_ = a_sign; } else { - if (a_bn >= b->bn_) { - r.bn_ = a_bn - b->bn_; + if (r.bn_ >= b->bn_) { + // |a| >= |b|, compute |a| - |b|, result has same sign as a. + r.bn_ -= b->bn_; r.sign_ = a_sign; } else { - r.bn_ = b->bn_ - a_bn; + // |a| < |b|, compute -|a| + |b| == |b| - |a|, result has same sign as b. + r.bn_.negate(); + r.bn_ += b->bn_; r.sign_ = b_sign; } if (r.bn_.is_zero()) { @@ -505,9 +505,7 @@ int ExactFloat::ScaleAndCompare(const ExactFloat& b) const { ABSL_DCHECK(is_normal() && b.is_normal() && bn_exp_ >= b.bn_exp_); ExactFloat tmp = *this; tmp.bn_ <<= (bn_exp_ - b.bn_exp_); - if (tmp.bn_ < b.bn_) return -1; - if (tmp.bn_ > b.bn_) return 1; - return 0; + return tmp.bn_.Compare(b.bn_); } bool ExactFloat::UnsignedLess(const ExactFloat& b) const { @@ -599,9 +597,7 @@ T ExactFloat::ToInteger(RoundingMode mode) const { if (!r.is_inf()) { // If the unsigned value has more than 63 bits it is always clamped. if (r.exp() < 64) { - auto opt_value = r.bn_.ConvertTo(); - ABSL_DCHECK(opt_value.has_value()); - int64_t value = static_cast(opt_value.value()) << r.bn_exp_; + int64_t value = r.bn_.Cast() << r.bn_exp_; if (r.sign_ < 0) value = -value; return max(kMinValue, min(kMaxValue, value)); } diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index 0dd44002..3980e15b 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -501,7 +501,10 @@ class ExactFloat { // - bn_ is a Bignum with a positive value // - bn_exp_ is the base-2 exponent applied to bn_. // - // Bignum supports negative values so that subtraction can be supported. + // bn_ stores the magnitude for the mantissa of the floating point value. We + // store a sign bit here separately from bn_ so that functions like exp() are + // easier to reason about (as the sign would flip depending on whether the + // exponent were odd or even). int32_t sign_ = 1; int32_t bn_exp_ = kExpZero; Bignum bn_; diff --git a/src/s2/util/math/exactfloat/exactfloat_test.cc b/src/s2/util/math/exactfloat/exactfloat_test.cc index ee8002d3..bd95d684 100644 --- a/src/s2/util/math/exactfloat/exactfloat_test.cc +++ b/src/s2/util/math/exactfloat/exactfloat_test.cc @@ -446,6 +446,10 @@ TEST_F(ExactFloatTest, Constructors) { // Copy constructor. ExactFloat e = c; ExpectSameWithPrec(-125, 7, e); + + // Ensure that construction with INT_MIN works properly. + ExactFloat f = INT_MIN; + ExpectSame(INT_MIN, f); } TEST_F(ExactFloatTest, Constants) { From ddc7b6977638cb9fe9cc9c7373baf273c0c89d78 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 3 Oct 2025 15:56:51 -0600 Subject: [PATCH 09/31] PR (Mega) Round 5 more changes. --- src/s2/util/math/exactfloat/bignum.cc | 321 +++++++++++++------------- src/s2/util/math/exactfloat/bignum.h | 4 - 2 files changed, 159 insertions(+), 166 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 13c881ba..5e22b26a 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -304,7 +304,7 @@ inline Bigit MulCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { return static_cast(sum); } -// Computes out += a * b + c and updates the carry. +// Computes out += a * b + carry and updates the carry. // // NOTE: Will not overflow even if a, b, and c are their maximum values. inline void MulAddCarry(Bigit& out, Bigit a, Bigit b, @@ -331,29 +331,25 @@ inline void MulAddCarry(Bigit& out, Bigit a, Bigit b, inline Bigit AddInPlace(absl::Span a, absl::Span b) { ABSL_DCHECK_GE(a.size(), b.size()); - Bigit* pa = a.data(); - const Bigit* pb = b.data(); - - int left = b.size(); Bigit carry = 0; // Dispatch four at a time to help loop unrolling. - while (left >= 4) { - for (int i = 0; i < 4; ++i, ++pa, ++pb) { - *pa = AddCarry(*pa, *pb, &carry); - --left; + int i = 0; + while (i + 4 <= b.size()) { + for (int j = 0; j < 4; ++j, ++i) { + a[i] = AddCarry(a[i], b[i], &carry); } } // Finish remainder. - for (; left > 0; --left, pa++, pb++) { - *pa = AddCarry(*pa, *pb, &carry); + for (; i < b.size(); ++i) { + a[i] = AddCarry(a[i], b[i], &carry); } // Propagate carry through the rest of a. - int remaining = a.size() - b.size(); - for (; carry && remaining > 0; --remaining, pa++) { - *pa = AddCarry(*pa, 0, &carry); + int left = a.size() - b.size(); + for (; carry && left > 0; --left, ++i) { + a[i] = AddCarry(a[i], 0, &carry); } return carry; @@ -379,145 +375,130 @@ inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, const size_t min_size = std::min(a.size(), b.size()); ABSL_DCHECK_GE(dst.size(), max_size + 1); - Bigit* pdst = dst.data(); - const Bigit* pa = a.data(); - const Bigit* pb = b.data(); - // Add common parts. Bigit carry = 0; // Dispatch four at a time to help loop unrolling. size_t i = 0; while (i + 4 < min_size) { - for (int j = 0; j < 4; ++j) { - pdst[i] = AddCarry(pa[i], pb[i], &carry); - ++i; + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = AddCarry(a[i], b[i], &carry); } } - // Finish remainder of common parts. + // Finish remainder of the parts common to A and B. for (; i < min_size; ++i) { - pdst[i] = AddCarry(pa[i], pb[i], &carry); + dst[i] = AddCarry(a[i], b[i], &carry); } // Copy remaining digits from the longer operand and propagate carry. auto longer = (a.size() > b.size()) ? a : b; - const Bigit* plonger = (a.size() > b.size()) ? pa : pb; // Dispatch four at a time for the remaining part. const int size = longer.size(); while (i + 4 < size) { - for (int j = 0; j < 4; ++j) { - pdst[i] = AddCarry(plonger[i], 0, &carry); - ++i; + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = AddCarry(longer[i], 0, &carry); } } - // Finish remainder. + // Propagate carry through the longer operand. for (; i < size; ++i) { - pdst[i] = AddCarry(plonger[i], 0, &carry); + dst[i] = AddCarry(longer[i], 0, &carry); } if (carry) { - pdst[i++] = carry; + dst[i++] = carry; return max_size + 1; } + return max_size; } -// Computes a -= b. Returns the final borrow (if any). +// Computes a -= b. // -// A must be expanded to match the size of B and the total number of digits -// actually set in A must be passed in via a_digits. -// -// REQUIRES: |a| < |b|. -inline Bigit SubLtInPlace(absl::Span a, absl::Span b, - size_t a_digits) { - ABSL_DCHECK_EQ(a.size(), b.size()); - ABSL_DCHECK_LT(CmpAbs(a, b), 0); +// REQUIRES: |a| >= |b|. +inline void SubInPlace(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + ABSL_DCHECK_GE(CmpAbs(a, b), 0); - Bigit* pa = a.data(); - const Bigit* pb = b.data(); Bigit borrow = 0; // Dispatch four at a time to help loop unrolling. - size_t size = a_digits; - size_t i = 0; - while (i + 4 < size) { - for (int j = 0; j < 4; ++j) { - pa[i] = SubBorrow(pb[i], pa[i], &borrow); - ++i; + size_t size = b.size(); + int i = 0; + while (i + 4 <= size) { + for (int j = 0; j < 4; ++j, ++i) { + a[i] = SubBorrow(a[i], b[i], &borrow); } } - // Finish remainder. - for (; i < a_digits; ++i) { - pa[i] = SubBorrow(pb[i], pa[i], &borrow); + // Finish remainder of subtraction. + for (; i < size; ++i) { + a[i] = SubBorrow(a[i], b[i], &borrow); } - // Propagate borrow through the rest of b. - for (; borrow && i < b.size(); ++i) { - pa[i] = SubBorrow(pb[i], 0, &borrow); + // Propagate the borrow through a. + for (; borrow && i < a.size(); ++i) { + borrow = (a[i] == 0); + a[i]--; } - return borrow; } -// Computes a -= b. Returns the final borrow (if any). +// Computes dst = a - b. // -// REQUIRES: |a| >= |b|. -inline Bigit SubGeInPlace(absl::Span a, absl::Span b) { - ABSL_DCHECK_GE(a.size(), b.size()); - ABSL_DCHECK_GE(CmpAbs(a, b), 0); +// Requires |a| >= |b| and dst is thus the same size as a. +// A must be expanded to match the size of B and the total number of digits +// actually set in A must be passed in via a_digits. +inline Bigit SubOutOfPlace(absl::Span dst, absl::Span a, + absl::Span b, size_t digits) { + ABSL_DCHECK_EQ(dst.size(), a.size()); + ABSL_DCHECK_LT(CmpAbs(a, b), 0); Bigit borrow = 0; - Bigit* pa = a.data(); - const Bigit* pb = b.data(); - // Dispatch four at a time to help loop unrolling. - size_t size = b.size(); - size_t done = 0; - while (done + 4 < size) { - for (int i = 0; i < 4; ++i) { - pa[done] = SubBorrow(pa[done], pb[done], &borrow); - ++done; + size_t size = digits; + size_t i = 0; + while (i + 4 < size) { + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = SubBorrow(a[i], b[i], &borrow); } } - // Finish remainder of subtraction. - for (; done < size; ++done) { - pa[done] = SubBorrow(pa[done], pb[done], &borrow); + // Finish remainder. + for (; i < digits; ++i) { + dst[i] = SubBorrow(a[i], b[i], &borrow); } - // Propagate the borrow through a. - for (; borrow && done < a.size(); ++done) { - borrow = (a[done] == 0); - a[done]--; + // Propagate borrow through the rest of a. + for (; borrow && i < a.size(); ++i) { + dst[i] = SubBorrow(a[i], 0, &borrow); } + return borrow; } Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, - Bigit c) { + Bigit carry) { ABSL_DCHECK_GE(out.size(), a.size()); - Bigit* pout = out.data(); - const Bigit* pa = a.data(); - int left = a.size(); // Dispatch four at a time to help loop unrolling. - while (left >= 4) { - for (int i = 0; i < 4; ++i, pa++, pout++) { - *pout = MulCarry(*pa, b, &c); + int i = 0; + while (i + 4 <= a.size()) { + for (int j = 0; j < 4; ++j, ++i) { + out[i] = MulCarry(a[i], b, &carry); --left; } } - for (; left > 0; --left, pa++, pout++) { - *pout = MulCarry(*pa, b, &c); + for (; i < a.size(); ++i) { + out[i] = MulCarry(a[i], b, &carry); } - return c; + + return carry; } // Computes out[i] += a[i]*b in place. @@ -525,28 +506,31 @@ Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, // Returns the final carry, if any. inline Bigit MulAddInPlace(absl::Span out, absl::Span a, Bigit b) { - Bigit* pout = out.data(); - const Bigit* pa = a.data(); - int left = a.size(); // Dispatch four at a time to help loop unrolling. Bigit carry = 0; - while (left >= 4) { - for (int i = 0; i < 4; ++i) { - MulAddCarry(*pout++, *pa++, b, &carry); - --left; + int i = 0; + while (i + 4 <= a.size()) { + for (int j = 0; j < 4; ++j, ++i) { + MulAddCarry(out[i], a[i], b, &carry); } } // Finish remainder. - while (left--) { - MulAddCarry(*pout++, *pa++, b, &carry); + for (; i < a.size(); ++i) { + MulAddCarry(out[i], a[i], b, &carry); } return carry; } +// Implements the standard grade school long multiplication algorithm. The +// output is computed by multiplying A by each digit of B and summing the +// results as we go. This is a quadratic algorithm and only serves as the base +// case for the recursive Karatsuba algorithm below. +// +// NOTE: out must be at least as large as the sums of the sizes of A and B. inline void MulQuadratic(absl::Span out, absl::Span a, absl::Span b) { ABSL_DCHECK_GE(out.size(), a.size() + b.size()); @@ -562,30 +546,28 @@ inline void MulQuadratic(absl::Span out, absl::Span a, return; } + // Each call to MulAdd and MulAddInPlace only updates a.size() elements of out + // so we manually set the carries as we go. We grab a span to the upper half + // of out starting at a.size() to facilitate this. auto upper = out.subspan(a.size()); upper[0] = MulAdd(out, a, b[0], 0); const size_t size = b.size(); size_t i = 1; - while (size >= i + 4) { - for (int j = 0; j < 4; ++j) { - upper[i] = MulAddInPlace(out.subspan(i), a, b[i]); - ++i; - } - } - - // Finish remainder (if any). for (; i < size; ++i) { upper[i] = MulAddInPlace(out.subspan(i), a, b[i]); } - // Finish zeroing out upper half. + // Finish zeroing out the upper half. for (; i < upper.size(); ++i) { upper[i] = 0; } } -// Split a span into two contiguous pieces of length a and b, respectively. +// Split a span into at most two contiguous spans of length a and b. +// +// If a + b < span.size() then the two spans only cover part of the input. +// If span.size() <= a, then the second span is empty. template inline std::pair, absl::Span> Split(absl::Span span, size_t a, size_t b) { @@ -613,9 +595,10 @@ class Arena { size_t Used() const { return used_; } - void Release(size_t n) { - ABSL_DCHECK_LE(n, used_); - used_ -= n; + // Resets the arena to the given position which must be < Used(). + void Reset(size_t to) { + ABSL_DCHECK_LE(to, used_); + used_ = to; } private: @@ -623,15 +606,17 @@ class Arena { std::vector data_; }; -inline void KaratsubaMulRec(absl::Span dst, absl::Span a, - absl::Span b, Arena& arena) { +inline void KaratsubaMulRecursive(absl::Span dst, + absl::Span a, + absl::Span b, + Arena* absl_nonnull arena) { ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); if (a.empty() || b.empty()) { absl::c_fill(dst, 0); return; } - int arena_start = arena.Used(); + int arena_start = arena->Used(); // Karatsuba lets us represent two numbers of M bigits each, A and B, as: // @@ -658,7 +643,7 @@ inline void KaratsubaMulRec(absl::Span dst, absl::Span a, // with those individual multiplies able to be recursively divided. // Fall back to long multiplication when we're small enough. - if (dst.size() <= kSimpleMulThreshold) { + if (std::min(a.size(), b.size()) <= kSimpleMulThreshold) { MulQuadratic(dst, a, b); return; } @@ -669,66 +654,65 @@ inline void KaratsubaMulRec(absl::Span dst, absl::Span a, auto [a0, a1] = Split(a, half, half); auto [b0, b1] = Split(b, half, half); - // We can skip adding the z2 term if a1 or b1 is zero. - const bool z2_zero = (a1.empty() || b1.empty()); - // Make space to hold results in the output and multiply sub-terms. // z0 = a0 * b0 // z2 = a1 * b1 auto [z0, z2] = Split(dst, a0.size() + b0.size(), a1.size() + b1.size()); - KaratsubaMulRec(z0, a0, b0, arena); - KaratsubaMulRec(z2, a1, b1, arena); + KaratsubaMulRecursive(z0, a0, b0, arena); + KaratsubaMulRecursive(z2, a1, b1, arena); // Compute (a0 + a1) and (b0 + b1) // // If the upper terms are zero we can just re-use the terms we have, otherwise // we compute the sum and pop off the MSB bigit if no carry occurred. - absl::Span sa = a0; - absl::Span sb = b0; + absl::Span asum = a0; if (!a1.empty()) { - absl::Span tmp = arena.Alloc(half + 1); - sa = tmp.first(AddOutOfPlace(tmp, a0, a1)); + absl::Span tmp = arena->Alloc(half + 1); + asum = tmp.first(AddOutOfPlace(tmp, a0, a1)); } + absl::Span bsum = b0; if (!b1.empty()) { - absl::Span tmp = arena.Alloc(half + 1); - sb = tmp.first(AddOutOfPlace(tmp, b0, b1)); + absl::Span tmp = arena->Alloc(half + 1); + bsum = tmp.first(AddOutOfPlace(tmp, b0, b1)); } - // Compute z1 = sa*sb - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 - auto z1 = arena.Alloc(sa.size() + sb.size()); + // Compute z1 = asum*bsum - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 + auto z1 = arena->Alloc(asum.size() + bsum.size()); - // Compute sa * sb into the beginning of z1 - KaratsubaMulRec(z1, sa, sb, arena); + // Compute asum * bsum into the beginning of z1 + KaratsubaMulRecursive(z1, asum, bsum, arena); // NOTE: (a0 + a1) * (b0 + b1) >= a0*b0 + a1*b1 so this never underflows. - SubGeInPlace(z1, z0); - if (!z2_zero) { - SubGeInPlace(z1, z2); + SubInPlace(z1, z0); + if (!a1.empty() && !b1.empty()) { + SubInPlace(z1, z2); } // Z1 may overflow because of a carry in (a0 + b0) or (a1 + b1) but // subtracting z0 and z2 will always bring it back in range, trim any leading // zeros to shorten the value if needed. - int i = 0; - for (i = z1.size() - 1; i > 0; --i) { - if (z1[i]) { - break; - } + while (z1.back() == 0) { + z1 = z1.first(z1.size() - 1); } - z1 = z1.first(i + 1); - // We need to add z1*10^half which we can do by adding it offset. + // We need to add z1*10^half which we can do by adding it at an offset. AddInPlace(dst.subspan(half), z1); // Release temporary memory we used. - arena.Release(arena.Used() - arena_start); + arena->Reset(arena_start); } -Bignum::BigitVector Bignum::KaratsubaMul(absl::Span a, - absl::Span b) { +// Multiplies two unsigned bigit vectors together using Karatsuba's algorithm. +// +// This algorithm recursively subdivides the inputs until one or both is below +// some threshold, and then falls back to standard long multiplication. +void KaratsubaMul(absl::Span out, absl::Span a, + absl::Span b) { + ABSL_DCHECK_GE(out.size(), a.size() + b.size()); if (a.empty() || b.empty()) { - return {}; + absl::c_fill(out, 0); + return; } // Each step of Karatsuba splits at: @@ -737,19 +721,19 @@ Bignum::BigitVector Bignum::KaratsubaMul(absl::Span a, // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. // // Simulate the recursion (log(n) steps) and compute the arena size. - int size = a.size() + b.size(); + int a_size = a.size(); + int b_size = b.size(); int peak = 0; - do { - int half = (size + 1) / 2; + while (std::min(a_size, b_size) > kSimpleMulThreshold) { + int half = (std::max(a_size, b_size) + 1) / 2; int next = half + 1; peak += 4 * next; - size = next; - } while (size > kSimpleMulThreshold); + a_size = next; + b_size = next; + }; Arena arena(peak); - BigitVector out(a.size() + b.size(), 0); - KaratsubaMulRec(absl::MakeSpan(out), a, b, arena); - return out; + KaratsubaMulRecursive(out, a, b, &arena); } Bignum& Bignum::operator+=(const Bignum& b) { @@ -764,32 +748,42 @@ Bignum& Bignum::operator+=(const Bignum& b) { if (is_negative() == b.is_negative()) { // Same sign: - // +|a| + +|b| == +|a + b| - // -|a| + -|b| == -|a + b| + // +|a| + +|b| == +(|a| + |b|) + // -|a| + -|b| == -(|a| + |b|) // - // So we can just sum magnitudes. + // So we can just sum magnitudes, final sign is the same as A. bigits_.resize(std::max(bigits_.size(), b.bigits_.size()), 0); Bigit carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); if (carry) { bigits_.emplace_back(carry); } - Normalize(); } else { + // We know the signs are different, so there's two options: + // -|a| + +|b| = ?(|b| - |a|) + // +|a| + -|b| = ?(|a| - |b|) + // + // With the final sign being dependent on how |a| and |b| relate. if (CmpAbs(bigits_, b.bigits_) >= 0) { - // |a| >= |b|, so a - b is the same sign as a. - SubGeInPlace(absl::MakeSpan(bigits_), b.bigits_); - Normalize(); + // |a| >= |b| + // -|a| + +|b| --> -(|a| - |b|) + // +|a| + -|b| --> +(|a| - |b|) + // + // So we can subtract magnitudes, final sign is the same as A. + SubInPlace(absl::MakeSpan(bigits_), b.bigits_); } else { - // |a| < |b|, so a - b is the same sign as b. - const int prev_size = bigits_.size(); + // |a| < |b| + // -|a| + +|b| --> +(|b| - |a|) + // +|a| + -|b| --> -(|b| - |a|) + // + // So we can compute |b| - |a| and the final sign is the same as B. + size_t prev_size = bigits_.size(); bigits_.resize(b.bigits_.size()); - SubLtInPlace(absl::MakeSpan(bigits_), b.bigits_, prev_size); - - negative_ = b.negative_; - Normalize(); + SubOutOfPlace(absl::MakeSpan(bigits_), b.bigits_, bigits_, prev_size); + negative_ = b.is_negative(); } } + Normalize(); return *this; } @@ -830,7 +824,10 @@ Bignum& Bignum::operator*=(const Bignum& b) { // Use Karatsuba multiplication. // If the inputs are small enough this will just do long multiplication. - bigits_ = KaratsubaMul(bigits_, b.bigits_); + BigitVector result; + result.resize(bigits_.size() + b.bigits_.size()); + KaratsubaMul(absl::MakeSpan(result), bigits_, b.bigits_); + bigits_ = std::move(result); negative_ = negative; Normalize(); diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index 89baae38..adac4490 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -200,10 +200,6 @@ class Bignum { Normalize(); } - // Multiplies two unsigned bigit vectors together using Karatsuba's algorithm. - static BigitVector KaratsubaMul(absl::Span a, - absl::Span b); - // Drop leading zero bigits, and ensure sign is positive if result is zero. void Normalize() { while (!bigits_.empty() && bigits_.back() == 0) { From 29be9ffd5ceb0635634bdb0653fae0fa2366f313 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 09:15:08 -0600 Subject: [PATCH 10/31] PR Round 7 changes. - Ensure -Wsign-compare safety - Minor constexpr fixes - DCHECK fixes. --- src/s2/util/math/exactfloat/BUILD | 2 +- src/s2/util/math/exactfloat/bignum.cc | 24 +++++++++++------------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index f02ce32b..2f7b5de9 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -4,7 +4,7 @@ cc_library( name = "bignum", srcs = ["bignum.cc"], hdrs = ["bignum.h"], - visibility = "//visibility:private" + visibility = ["//visibility:private"] ) cc_library( diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 5e22b26a..5d4c8bbb 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -90,7 +90,7 @@ std::optional Bignum::FromString(absl::string_view s) { // Precomputed powers of 10. static constexpr auto kPow10 = []() { - std::array out = {1}; + std::array out = {1}; for (size_t i = 1; i < out.size(); ++i) { out[i] = 10 * out[i - 1]; } @@ -115,7 +115,7 @@ std::optional Bignum::FromString(absl::string_view s) { const auto end = s.cend(); while (begin < end) { - size_t chunk_len = std::min(std::distance(begin, end), kMaxChunkDigits); + ssize_t chunk_len = std::min(std::distance(begin, end), kMaxChunkDigits); Bigit chunk; auto result = std::from_chars(begin, begin + chunk_len, chunk); @@ -291,7 +291,7 @@ inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // // NOTE: Borrow must be one or zero. inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { - ABSL_DCHECK_LE(borrow, 1); + ABSL_DCHECK_LE(*borrow, 1); Bigit diff = a - b - *borrow; *borrow = (a < b) || (*borrow && (a == b)); return diff; @@ -334,7 +334,7 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { Bigit carry = 0; // Dispatch four at a time to help loop unrolling. - int i = 0; + size_t i = 0; while (i + 4 <= b.size()) { for (int j = 0; j < 4; ++j, ++i) { a[i] = AddCarry(a[i], b[i], &carry); @@ -347,8 +347,8 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { } // Propagate carry through the rest of a. - int left = a.size() - b.size(); - for (; carry && left > 0; --left, ++i) { + size_t left = a.size() - b.size(); + for (; carry && i < a.size(); ++i) { a[i] = AddCarry(a[i], 0, &carry); } @@ -395,7 +395,7 @@ inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, auto longer = (a.size() > b.size()) ? a : b; // Dispatch four at a time for the remaining part. - const int size = longer.size(); + const size_t size = longer.size(); while (i + 4 < size) { for (int j = 0; j < 4; ++j, ++i) { dst[i] = AddCarry(longer[i], 0, &carry); @@ -426,7 +426,7 @@ inline void SubInPlace(absl::Span a, absl::Span b) { // Dispatch four at a time to help loop unrolling. size_t size = b.size(); - int i = 0; + size_t i = 0; while (i + 4 <= size) { for (int j = 0; j < 4; ++j, ++i) { a[i] = SubBorrow(a[i], b[i], &borrow); @@ -453,7 +453,7 @@ inline void SubInPlace(absl::Span a, absl::Span b) { inline Bigit SubOutOfPlace(absl::Span dst, absl::Span a, absl::Span b, size_t digits) { ABSL_DCHECK_EQ(dst.size(), a.size()); - ABSL_DCHECK_LT(CmpAbs(a, b), 0); + ABSL_DCHECK_GE(CmpAbs(a, b), 1); Bigit borrow = 0; @@ -486,7 +486,7 @@ Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, int left = a.size(); // Dispatch four at a time to help loop unrolling. - int i = 0; + size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { out[i] = MulCarry(a[i], b, &carry); @@ -506,11 +506,9 @@ Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, // Returns the final carry, if any. inline Bigit MulAddInPlace(absl::Span out, absl::Span a, Bigit b) { - int left = a.size(); - // Dispatch four at a time to help loop unrolling. Bigit carry = 0; - int i = 0; + size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { MulAddCarry(out[i], a[i], b, &carry); From 308698e80d1940e0f96b9dce3abf8b324291ce35 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 12:15:09 -0600 Subject: [PATCH 11/31] Cleanup util/math BUILD file and ensure clean compilation and run. - Fix any unused variables and missing dependencies. - Add ASAN config to .bazelrc and ensure clean run. - Fix -Wcompare-unsigned warning in exactfloat.cc Now builds and runs cleanly even with ASAN on (only ASAN notification is internal to Abseil). --- src/.bazelrc | 2 +- src/s2/util/math/exactfloat/BUILD | 30 ++++++++++++++--------- src/s2/util/math/exactfloat/bignum.cc | 5 ++-- src/s2/util/math/exactfloat/exactfloat.cc | 2 +- 4 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/.bazelrc b/src/.bazelrc index a543738a..c7f7c7b7 100644 --- a/src/.bazelrc +++ b/src/.bazelrc @@ -1,3 +1,3 @@ # Enable Bzlmod for every Bazel command common --enable_bzlmod -common --cxxopt=-std=c++20 \ No newline at end of file +common --cxxopt=-std=c++20 diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 2f7b5de9..b9fd86b9 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -4,7 +4,12 @@ cc_library( name = "bignum", srcs = ["bignum.cc"], hdrs = ["bignum.h"], - visibility = ["//visibility:private"] + visibility = ["//visibility:private"], + deps = [ + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:inlined_vector", + "@abseil-cpp//absl/log:absl_check", + ], ) cc_library( @@ -13,17 +18,18 @@ cc_library( hdrs = ["exactfloat.h"], deps = [ ":bignum", - "//s2/base:types", - "//s2/base:port", "//s2/base:logging", - "@abseil-cpp//absl/container:inlined_vector", - "@abseil-cpp//absl/log:absl_check", - "@abseil-cpp//absl/log:absl_log", - "@abseil-cpp//absl/numeric:bits", - "@abseil-cpp//absl/numeric:int128", - "@abseil-cpp//absl/random:random", - "@abseil-cpp//absl/strings:ascii", - "@abseil-cpp//absl/strings:str_cat", - "@abseil-cpp//absl/strings:str_format", + "//s2/base:port", + ], +) + +cc_test( + name = "bignum_test", + srcs = ["bignum_test.cc"], + deps = [ + ":bignum", + "@abseil-cpp//absl/random:bit_gen_ref", + "@boringssl//:crypto", + "@googletest//:gtest_main", ], ) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 5d4c8bbb..d706b971 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -117,7 +117,7 @@ std::optional Bignum::FromString(absl::string_view s) { while (begin < end) { ssize_t chunk_len = std::min(std::distance(begin, end), kMaxChunkDigits); - Bigit chunk; + Bigit chunk = 0; auto result = std::from_chars(begin, begin + chunk_len, chunk); if (result.ec != std::errc() || (result.ptr - begin) != chunk_len) { return std::nullopt; @@ -291,7 +291,7 @@ inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // // NOTE: Borrow must be one or zero. inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { - ABSL_DCHECK_LE(*borrow, 1); + ABSL_DCHECK_LE(*borrow, 1u); Bigit diff = a - b - *borrow; *borrow = (a < b) || (*borrow && (a == b)); return diff; @@ -347,7 +347,6 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { } // Propagate carry through the rest of a. - size_t left = a.size() - b.size(); for (; carry && i < a.size(); ++i) { a[i] = AddCarry(a[i], 0, &carry); } diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 89521ef4..bbc5e7d6 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -355,7 +355,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { std::string::size_type pos = digits->find_last_not_of('0') + 1; bn_exp10 += digits->size() - pos; digits->erase(pos); - ABSL_DCHECK_LE(digits->size(), max_digits); + ABSL_DCHECK_LE(static_cast(digits->size()), max_digits); // Finally, we adjust the base-10 exponent so that the mantissa is a // fraction in the range [0.1, 1) rather than an integer. From 7e90c5c6d5d0fc0f28df0e5d3b4ee5e06a2e165e Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 12:21:46 -0600 Subject: [PATCH 12/31] Remove final usage of ssize_t. It's a POSIX standard, not standard C++ so avoid it entirely. --- src/s2/util/math/exactfloat/bignum.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index d706b971..a6bac321 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -80,7 +80,7 @@ int Bignum::Compare(const Bignum& b) const { std::optional Bignum::FromString(absl::string_view s) { // A chunk is up to 19 decimal digits, which can always fit into a Bigit. - constexpr ssize_t kMaxChunkDigits = std::numeric_limits::digits10; + constexpr int kMaxChunkDigits = std::numeric_limits::digits10; // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the // sake of simplicity. This isn't the fastest algorithm, being quadratic in @@ -115,7 +115,8 @@ std::optional Bignum::FromString(absl::string_view s) { const auto end = s.cend(); while (begin < end) { - ssize_t chunk_len = std::min(std::distance(begin, end), kMaxChunkDigits); + int chunk_len = + std::min(static_cast(std::distance(begin, end)), kMaxChunkDigits); Bigit chunk = 0; auto result = std::from_chars(begin, begin + chunk_len, chunk); From 81f846aa41119f4b902e3ab866685d0c8d5cdb64 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 12:42:24 -0600 Subject: [PATCH 13/31] Add exactfloat_test to BUILD file. --- src/s2/util/math/exactfloat/BUILD | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index b9fd86b9..c73a4f36 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -33,3 +33,14 @@ cc_test( "@googletest//:gtest_main", ], ) + +cc_test( + name = "exactfloat_test", + srcs = ["exactfloat_test.cc"], + deps = [ + ":exactfloat", + "//s2/util/math:vector", + "@abseil-cpp//absl/random:random", + "@googletest//:gtest_main", + ], +) From c675cca8d443e1d4d1942d36fb9c4ee54f3c4317 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 12:53:30 -0600 Subject: [PATCH 14/31] Use int64_t instead of int for size casts out of an abundance of caution. --- src/s2/util/math/exactfloat/bignum.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index a6bac321..817ac470 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -80,7 +80,7 @@ int Bignum::Compare(const Bignum& b) const { std::optional Bignum::FromString(absl::string_view s) { // A chunk is up to 19 decimal digits, which can always fit into a Bigit. - constexpr int kMaxChunkDigits = std::numeric_limits::digits10; + constexpr int64_t kMaxChunkDigits = std::numeric_limits::digits10; // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the // sake of simplicity. This isn't the fastest algorithm, being quadratic in @@ -115,8 +115,8 @@ std::optional Bignum::FromString(absl::string_view s) { const auto end = s.cend(); while (begin < end) { - int chunk_len = - std::min(static_cast(std::distance(begin, end)), kMaxChunkDigits); + int64_t chunk_len = std::min( + static_cast(std::distance(begin, end)), kMaxChunkDigits); Bigit chunk = 0; auto result = std::from_chars(begin, begin + chunk_len, chunk); @@ -240,7 +240,7 @@ Bignum& Bignum::operator>>=(int nbit) { // Then, handle the within-bigit shift, if any. if (nrem != 0) { Bigit carry = 0; - for (int i = static_cast(bigits_.size()) - 1; i >= 0; --i) { + for (auto i = static_cast(bigits_.size()) - 1; i >= 0; --i) { const Bigit old_val = bigits_[i]; bigits_[i] = (old_val >> nrem) | carry; carry = old_val << (kBigitBits - nrem); @@ -292,7 +292,7 @@ inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // // NOTE: Borrow must be one or zero. inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { - ABSL_DCHECK_LE(*borrow, 1u); + ABSL_DCHECK_LE(*borrow, Bigit(1)); Bigit diff = a - b - *borrow; *borrow = (a < b) || (*borrow && (a == b)); return diff; From 593414d70bc1cb7732f758b74420f3f798ff9e48 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 12:54:51 -0600 Subject: [PATCH 15/31] Split INT_MIN construction test case into a separate named case. --- .../util/math/exactfloat/exactfloat_test.cc | 65 ++++++++++--------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/s2/util/math/exactfloat/exactfloat_test.cc b/src/s2/util/math/exactfloat/exactfloat_test.cc index bd95d684..7b9636c4 100644 --- a/src/s2/util/math/exactfloat/exactfloat_test.cc +++ b/src/s2/util/math/exactfloat/exactfloat_test.cc @@ -17,13 +17,14 @@ #include "s2/util/math/exactfloat/exactfloat.h" +#include + #include #include #include #include #include -#include #include "absl/base/casts.h" #include "absl/base/macros.h" #include "absl/log/absl_log.h" @@ -119,14 +120,14 @@ double scalbln(double a, long exp) { // Here we fix the rounding functions to match MPFloat, which clamps out of // range values and returns the maximum possible value for NaN. -#define FIX_INT_ROUNDING(T, fname) \ - T fname(double a) { \ +#define FIX_INT_ROUNDING(T, fname) \ + T fname(double a) { \ if (std::isnan(a)) return std::numeric_limits::max(); \ - if (a <= std::numeric_limits::min()) \ - return std::numeric_limits::min(); \ - if (a >= std::numeric_limits::max()) \ - return std::numeric_limits::max(); \ - return ::fname(a); \ + if (a <= std::numeric_limits::min()) \ + return std::numeric_limits::min(); \ + if (a >= std::numeric_limits::max()) \ + return std::numeric_limits::max(); \ + return ::fname(a); \ } FIX_INT_ROUNDING(long, lrint) @@ -252,7 +253,7 @@ class ExactFloatTest : public ::testing::Test { // Expect "actual" to have the given value when converted to a "double". // Two values are considered equivalent if they have the same bit pattern or // they are both NaN. (So for example, +0 and -0 are not equivalent.) - void ExpectSame(double expected, const ExactFloat &xf_actual) { + void ExpectSame(double expected, const ExactFloat& xf_actual) { double actual = xf_actual.ToDouble(); if (std::isnan(expected)) { EXPECT_TRUE(std::isnan(actual)); @@ -267,13 +268,13 @@ class ExactFloatTest : public ::testing::Test { // Like ExpectSame() but also check that "actual" has the expected precision. void ExpectSameWithPrec(double expected_value, int expected_prec, - const ExactFloat &xf_actual) { + const ExactFloat& xf_actual) { ExpectSame(expected_value, xf_actual); EXPECT_EQ(expected_prec, xf_actual.prec()); } // Log an error when a math intrinsic does not return the expected result. - static void AddMathcallFailure(const testing::Message &call_msg, + static void AddMathcallFailure(const testing::Message& call_msg, double expected, double actual) { ADD_FAILURE() << call_msg << "\nExpected (glibc): " << ExactFloat(expected) << "\nActual (ExactFloat): " << ExactFloat(actual) @@ -284,8 +285,8 @@ class ExactFloatTest : public ::testing::Test { // Given two versions "f" and "mp_f" of the unary function called "fname", // check that their results agree to within the given number of ulps for a // range of test arguments. - void TestMathcall1(const char *fname, double f(double), - ExactFloat mp_f(const ExactFloat &), uint64_t ulps) { + void TestMathcall1(const char* fname, double f(double), + ExactFloat mp_f(const ExactFloat&), uint64_t ulps) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; double expected = f(a); @@ -301,8 +302,8 @@ class ExactFloatTest : public ::testing::Test { // Given two versions "f" and "mp_f" of the binary function called "fname", // check that their results agree to within the given number of ulps for a // range of test arguments. - void TestMathcall2(const char *fname, double f(double, double), - ExactFloat mp_f(const ExactFloat &, const ExactFloat &), + void TestMathcall2(const char* fname, double f(double, double), + ExactFloat mp_f(const ExactFloat&, const ExactFloat&), uint64_t ulps) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; @@ -323,7 +324,7 @@ class ExactFloatTest : public ::testing::Test { // function "mp_f", check that they return the same result on a range of // test arguments. template - void TestMethod0(const char *fname, ResultType f(double), + void TestMethod0(const char* fname, ResultType f(double), ResultType (ExactFloat::*mp_f)() const) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; @@ -341,8 +342,8 @@ class ExactFloatTest : public ::testing::Test { // returns the integer type ResultType, check that they return the same // value for a range of test arguments. template - void TestIntMathcall1(const char *fname, ResultType f(double), - ResultType mp_f(const ExactFloat &)) { + void TestIntMathcall1(const char* fname, ResultType f(double), + ResultType mp_f(const ExactFloat&)) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; ResultType expected = f(a); @@ -360,8 +361,8 @@ class ExactFloatTest : public ::testing::Test { // integer argument), check that they return the same result for a range of // test arguments. "ExpType" is the type of the integer argument. template - void TestLdexpCall(const char *fname, double f(double, ExpType), - ExactFloat mp_f(const ExactFloat &, ExpType)) { + void TestLdexpCall(const char* fname, double f(double, ExpType), + ExactFloat mp_f(const ExactFloat&, ExpType)) { static const ExpType kUnsignedExpValues[] = { // Doesn't test with numeric_limits::min() because it's // undefined @@ -446,7 +447,9 @@ TEST_F(ExactFloatTest, Constructors) { // Copy constructor. ExactFloat e = c; ExpectSameWithPrec(-125, 7, e); +} +TEST_F(ExactFloatTest, IntMinConstruction) { // Ensure that construction with INT_MIN works properly. ExactFloat f = INT_MIN; ExpectSame(INT_MIN, f); @@ -624,7 +627,7 @@ TEST_F(ExactFloatTest, RoundToMaxPrec) { // corresponding C++ operator. #define TEST_MATHOP1(op_name, op) \ double op_name(double a) { return op(a); } \ - ExactFloat mp_##op_name(const ExactFloat &a) { return op(a); } \ + ExactFloat mp_##op_name(const ExactFloat& a) { return op(a); } \ TEST_F(ExactFloatTest, op_name) { \ TestMathcall1(#op_name, op_name, mp_##op_name, 0); \ } @@ -636,7 +639,7 @@ TEST_MATHOP1(minus, -) // corresponding C++ operator. #define TEST_MATHOP2(op_name, op) \ double op_name(double a, double b) { return (a)op(b); } \ - ExactFloat mp_##op_name(const ExactFloat &a, const ExactFloat &b) { \ + ExactFloat mp_##op_name(const ExactFloat& a, const ExactFloat& b) { \ return (a)op(b); \ } \ TEST_F(ExactFloatTest, op_name) { \ @@ -660,7 +663,7 @@ TEST_MATHOP2(not_greater, <=); (a) op(b); \ return a; \ } \ - ExactFloat mp_##op_name(const ExactFloat &a, const ExactFloat &b) { \ + ExactFloat mp_##op_name(const ExactFloat& a, const ExactFloat& b) { \ ExactFloat x = a; \ x op(b); \ return x; \ @@ -679,7 +682,7 @@ TEST_ASSIGNOP(times_equals, *=); #define TEST_MATHCALL1(func, ulps) \ /* We define a wrapper function around ExactFloat version of "func" */ \ /* so that we can take its address (gcc can't find it otherwise). */ \ - ExactFloat mp_##func(const ExactFloat &a) { return func(a); } \ + ExactFloat mp_##func(const ExactFloat& a) { return func(a); } \ TEST_F(ExactFloatTest, func) { TestMathcall1(#func, func, mp_##func, ulps); } // Test all the unary math instrinsics (in the same order as the .h file). @@ -696,7 +699,7 @@ TEST_MATHCALL1(round, 0) // Check that the ExactFloat and glibc versions of "func" always return the // same value to within the given number of ulps. #define TEST_MATHCALL2(func, ulps) \ - ExactFloat mp_##func(const ExactFloat &a, const ExactFloat &b) { \ + ExactFloat mp_##func(const ExactFloat& a, const ExactFloat& b) { \ return func(a, b); \ } \ TEST_F(ExactFloatTest, func) { TestMathcall2(#func, func, mp_##func, ulps); } @@ -712,7 +715,7 @@ TEST_MATHCALL2(copysign, 0) // integer value. #define TEST_INTEGER_MATHCALL1(ResultType, func) \ - ResultType mp_##func(const ExactFloat &a) { return func(a); } \ + ResultType mp_##func(const ExactFloat& a) { return func(a); } \ TEST_F(ExactFloatTest, func) { TestIntMathcall1(#func, func, mp_##func); } TEST_INTEGER_MATHCALL1(long, lrint); @@ -732,11 +735,11 @@ int frexp_exp(double a) { (void)frexp(a, &exp_part); return exp_part; } -ExactFloat mp_frexp_frac(const ExactFloat &a) { +ExactFloat mp_frexp_frac(const ExactFloat& a) { int exp_part; return frexp(a, &exp_part); } -int mp_frexp_exp(const ExactFloat &a) { +int mp_frexp_exp(const ExactFloat& a) { int exp_part; (void)frexp(a, &exp_part); return exp_part; @@ -752,7 +755,7 @@ TEST_F(ExactFloatTest, frexp) { // ldexp(), scalbn(), scalbln() #define TEST_LDEXP_CALL(ExpType, func) \ - ExactFloat mp_##func(const ExactFloat &a, ExpType exp) { \ + ExactFloat mp_##func(const ExactFloat& a, ExpType exp) { \ return func(a, exp); \ } \ TEST_F(ExactFloatTest, func) { TestLdexpCall(#func, func, mp_##func); } @@ -786,7 +789,7 @@ TEST_METHOD0_VS_FUNCTION(int, sgn, ref_sgn) // Test a zero-argument ExactFloat member function against a corresponding // one-argument std:: function. #define TEST_METHOD0_VS_STD_FN(ResultType, method, fn) \ - ResultType ref_##fn(double a) { return std::fn(a); } \ + ResultType ref_##fn(double a) { return std::fn(a); } \ TEST_METHOD0_VS_FUNCTION(ResultType, method, ref_##fn) TEST_METHOD0_VS_STD_FN(bool, is_inf, isinf) @@ -865,7 +868,7 @@ TEST_F(ExactFloatTest, Vector3_Part3) { EXPECT_EQ(1, y.z()); y.x(3); EXPECT_EQ(MyVec3(3, 4, 1), y); - ExactFloat *y_data = y.Data(); + ExactFloat* y_data = y.Data(); y_data[2] = 12; EXPECT_EQ(MyVec3(3, 4, 12), y); EXPECT_EQ(y.Data()[1], ExactFloat(4)); From f0756554833f929f0dac87ff3e71770d9a5bbc79 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 15:25:48 -0600 Subject: [PATCH 16/31] Fix typo in bignum.cc. This caused us to subtract less than we should have in some cases. Expand bignum_test.cc test framework to test numbers from different size classes on the left and right. Verify that the tests would have caught the bug now. --- src/s2/util/math/exactfloat/bignum.cc | 2 +- src/s2/util/math/exactfloat/bignum_test.cc | 103 +++++++++------------ 2 files changed, 44 insertions(+), 61 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 817ac470..f08fc80e 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -774,7 +774,7 @@ Bignum& Bignum::operator+=(const Bignum& b) { // +|a| + -|b| --> -(|b| - |a|) // // So we can compute |b| - |a| and the final sign is the same as B. - size_t prev_size = bigits_.size(); + size_t prev_size = b.bigits_.size(); bigits_.resize(b.bigits_.size()); SubOutOfPlace(absl::MakeSpan(bigits_), b.bigits_, bigits_, prev_size); negative_ = b.is_negative(); diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index 135433c2..c5326e7d 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -824,53 +824,39 @@ std::vector RandomNumberStrings(absl::BitGenRef bitgen, return GenerateRandomNumberStrings(bitgen, static_cast(size_class)); } -class VsOpenSSLTest : public TestWithParam { +class VsOpenSSLTest + : public TestWithParam> { protected: - std::vector Numbers() { - return RandomNumberStrings(bitgen_, GetParam()); + std::vector> Numbers() { + auto numbers0 = RandomNumberStrings(bitgen_, GetParam().first); + auto numbers1 = RandomNumberStrings(bitgen_, GetParam().second); + ABSL_CHECK_EQ(numbers0.size(), numbers1.size()); + + std::vector> numbers; + numbers.reserve(numbers0.size()); + + for (size_t i = 0; i < numbers0.size(); ++i) { + numbers.emplace_back(numbers0[i], numbers1[i]); + } + return numbers; } private: absl::BitGen bitgen_; }; -TEST_P(VsOpenSSLTest, SquaringCorrect) { - // Test that multiplication produces correct results by comparing to - // OpenSSL. - BN_CTX* ctx = BN_CTX_new(); - for (const auto& number : Numbers()) { - // Test same number multiplication (most likely to trigger edge cases) - const Bignum bn_a = *Bignum::FromString(number); - const Bignum bn_result = bn_a * bn_a; - - const OpenSSLBignum ssl_a(number); - OpenSSLBignum ssl_result; - BN_mul(ssl_result.get(), ssl_a.get(), ssl_a.get(), ctx); - - // Compare string representations - char* ssl_str = BN_bn2dec(ssl_result.get()); - std::string bn_str = absl::StrFormat("%v", bn_result); - - EXPECT_EQ(bn_str, std::string(ssl_str)) - << "Mismatch for multiplication" - << "\nBignum result: " << bn_str.substr(0, 100) << "..." - << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; - OPENSSL_free(ssl_str); - } - BN_CTX_free(ctx); -} - TEST_P(VsOpenSSLTest, MultiplyCorrect) { - // Multiply by a small constant to test widely different operand sizes. + // Test that multiplication produces the same results as OpenSSL. BN_CTX* ctx = BN_CTX_new(); - for (const auto& number : Numbers()) { - // Test same number multiplication (most likely to trigger edge cases) - const Bignum bn_a = *Bignum::FromString(number); - const Bignum bn_result = Bignum(2) * bn_a; + for (const auto& [a, b] : Numbers()) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); + const Bignum bn_result = bn_a * bn_b; - const OpenSSLBignum ssl_a(number); + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); OpenSSLBignum ssl_result; - BN_mul(ssl_result.get(), OpenSSLBignum("2").get(), ssl_a.get(), ctx); + BN_mul(ssl_result.get(), ssl_a.get(), ssl_b.get(), ctx); // Compare string representations char* ssl_str = BN_bn2dec(ssl_result.get()); @@ -887,18 +873,14 @@ TEST_P(VsOpenSSLTest, MultiplyCorrect) { TEST_P(VsOpenSSLTest, AdditionCorrect) { // Test that addition produces correct results by comparing to OpenSSL. - const std::vector numbers = Numbers(); - for (size_t i = 0; i < numbers.size(); ++i) { - const auto& num_a = numbers[i]; - const auto& num_b = numbers[(i + 1) % numbers.size()]; - - const Bignum bn_a = *Bignum::FromString(num_a); - const Bignum bn_b = *Bignum::FromString(num_b); + for (const auto& [a, b] : Numbers()) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); const Bignum bn_result = bn_a + bn_b; - const OpenSSLBignum ssl_a(num_a); - const OpenSSLBignum ssl_b(num_b); + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); OpenSSLBignum ssl_result; BN_add(ssl_result.get(), ssl_a.get(), ssl_b.get()); @@ -917,18 +899,14 @@ TEST_P(VsOpenSSLTest, AdditionCorrect) { TEST_P(VsOpenSSLTest, SubtractionCorrect) { // Test that subtraction produces correct results by comparing to // OpenSSL. - const std::vector numbers = Numbers(); - for (size_t i = 0; i < numbers.size(); ++i) { - const auto& num_a = numbers[i]; - const auto& num_b = numbers[(i + 1) % numbers.size()]; - - const Bignum bn_a = *Bignum::FromString(num_a); - const Bignum bn_b = *Bignum::FromString(num_b); + for (const auto& [a, b] : Numbers()) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); const Bignum bn_result = bn_a - bn_b; - const OpenSSLBignum ssl_a(num_a); - const OpenSSLBignum ssl_b(num_b); + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); OpenSSLBignum ssl_result; BN_sub(ssl_result.get(), ssl_a.get(), ssl_b.get()); @@ -944,12 +922,17 @@ TEST_P(VsOpenSSLTest, SubtractionCorrect) { } } -INSTANTIATE_TEST_SUITE_P(VsOpenSSL, VsOpenSSLTest, - ::testing::Values(NumberSizeClass::kSmall, - NumberSizeClass::kMedium, - NumberSizeClass::kLarge, - NumberSizeClass::kHuge, - NumberSizeClass::kMega)); +// clang-format off +INSTANTIATE_TEST_SUITE_P( + VsOpenSSL, VsOpenSSLTest, ::testing::Values( + std::make_pair(NumberSizeClass::kSmall, NumberSizeClass::kSmall), + std::make_pair(NumberSizeClass::kSmall, NumberSizeClass::kHuge), + std::make_pair(NumberSizeClass::kHuge, NumberSizeClass::kSmall), + std::make_pair(NumberSizeClass::kMedium, NumberSizeClass::kMedium), + std::make_pair(NumberSizeClass::kLarge, NumberSizeClass::kLarge), + std::make_pair(NumberSizeClass::kHuge, NumberSizeClass::kHuge), + std::make_pair(NumberSizeClass::kMega, NumberSizeClass::kMega))); +// clang-format on // TODO: Enable once benchmark is integrated. #if 0 From 5853f7064dea1c42118d90fdf3f8cbbdb42d55d5 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 15:38:19 -0600 Subject: [PATCH 17/31] Add comment for the Numbers() function. --- src/s2/util/math/exactfloat/bignum_test.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index c5326e7d..9460896e 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -827,6 +827,12 @@ std::vector RandomNumberStrings(absl::BitGenRef bitgen, class VsOpenSSLTest : public TestWithParam> { protected: + // Returns a vector of pairs of numbers (as decimal strings) based on the + // parameters that the test suite was created with. + // + // E.g. If the test suite was instantiate with (kSmall, kHuge) as the number + // classes, then this will return a small value on the left and a huge value + // on the right. (kHuge, kSmall) would return the opposite. std::vector> Numbers() { auto numbers0 = RandomNumberStrings(bitgen_, GetParam().first); auto numbers1 = RandomNumberStrings(bitgen_, GetParam().second); From a470ace4777d8ee89a25b2fa18a3892f03679496 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 4 Oct 2025 16:37:15 -0600 Subject: [PATCH 18/31] Switch to std::unique_ptr in Arena. This prevents ASAN false positives due to leaving the vector unsized while still accessing it through its data() member. We want to intentionally leave the memory uninitialized. --- src/s2/util/math/exactfloat/bignum.cc | 16 +++-- .../util/math/exactfloat/exactfloat_test.cc | 63 +++++++++---------- 2 files changed, 43 insertions(+), 36 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index f08fc80e..3e99b7bf 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -579,16 +579,23 @@ inline std::pair, absl::Span> Split(absl::Span span, // space when recursing in the Karatsuba multiply. The arena is pre-sized and // returns spans of memory via Alloc() which are then returned to the arena via // Release. +// +// NOTE: We use std::unique_ptr here instead of std::vector because we don't +// want to initialize the memory unnecessarily and using std::vector without +// resizing (and thus initializing) the container leads to false positives with +// ASAN. class Arena { public: - explicit Arena(size_t size) { data_.reserve(size); } + // TODO: Use make_unique_for_overwrite when on C++20. + explicit Arena(size_t size) + : size_(size), data_(std::unique_ptr(new Bigit[size])){}; // Allocates a span of length n from the arena. absl::Span Alloc(size_t n) { - ABSL_DCHECK_LE(used_ + n, data_.capacity()); + ABSL_DCHECK_LE(used_ + n, size_); size_t start = used_; used_ += n; - return absl::Span(data_.data() + start, n); + return absl::Span(data_.get() + start, n); } size_t Used() const { return used_; } @@ -600,8 +607,9 @@ class Arena { } private: + size_t size_ = 0; size_t used_ = 0; - std::vector data_; + std::unique_ptr data_; }; inline void KaratsubaMulRecursive(absl::Span dst, diff --git a/src/s2/util/math/exactfloat/exactfloat_test.cc b/src/s2/util/math/exactfloat/exactfloat_test.cc index 7b9636c4..0a035728 100644 --- a/src/s2/util/math/exactfloat/exactfloat_test.cc +++ b/src/s2/util/math/exactfloat/exactfloat_test.cc @@ -17,14 +17,13 @@ #include "s2/util/math/exactfloat/exactfloat.h" -#include - #include #include #include #include #include +#include #include "absl/base/casts.h" #include "absl/base/macros.h" #include "absl/log/absl_log.h" @@ -120,14 +119,14 @@ double scalbln(double a, long exp) { // Here we fix the rounding functions to match MPFloat, which clamps out of // range values and returns the maximum possible value for NaN. -#define FIX_INT_ROUNDING(T, fname) \ - T fname(double a) { \ +#define FIX_INT_ROUNDING(T, fname) \ + T fname(double a) { \ if (std::isnan(a)) return std::numeric_limits::max(); \ - if (a <= std::numeric_limits::min()) \ - return std::numeric_limits::min(); \ - if (a >= std::numeric_limits::max()) \ - return std::numeric_limits::max(); \ - return ::fname(a); \ + if (a <= std::numeric_limits::min()) \ + return std::numeric_limits::min(); \ + if (a >= std::numeric_limits::max()) \ + return std::numeric_limits::max(); \ + return ::fname(a); \ } FIX_INT_ROUNDING(long, lrint) @@ -253,7 +252,7 @@ class ExactFloatTest : public ::testing::Test { // Expect "actual" to have the given value when converted to a "double". // Two values are considered equivalent if they have the same bit pattern or // they are both NaN. (So for example, +0 and -0 are not equivalent.) - void ExpectSame(double expected, const ExactFloat& xf_actual) { + void ExpectSame(double expected, const ExactFloat &xf_actual) { double actual = xf_actual.ToDouble(); if (std::isnan(expected)) { EXPECT_TRUE(std::isnan(actual)); @@ -268,13 +267,13 @@ class ExactFloatTest : public ::testing::Test { // Like ExpectSame() but also check that "actual" has the expected precision. void ExpectSameWithPrec(double expected_value, int expected_prec, - const ExactFloat& xf_actual) { + const ExactFloat &xf_actual) { ExpectSame(expected_value, xf_actual); EXPECT_EQ(expected_prec, xf_actual.prec()); } // Log an error when a math intrinsic does not return the expected result. - static void AddMathcallFailure(const testing::Message& call_msg, + static void AddMathcallFailure(const testing::Message &call_msg, double expected, double actual) { ADD_FAILURE() << call_msg << "\nExpected (glibc): " << ExactFloat(expected) << "\nActual (ExactFloat): " << ExactFloat(actual) @@ -285,8 +284,8 @@ class ExactFloatTest : public ::testing::Test { // Given two versions "f" and "mp_f" of the unary function called "fname", // check that their results agree to within the given number of ulps for a // range of test arguments. - void TestMathcall1(const char* fname, double f(double), - ExactFloat mp_f(const ExactFloat&), uint64_t ulps) { + void TestMathcall1(const char *fname, double f(double), + ExactFloat mp_f(const ExactFloat &), uint64_t ulps) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; double expected = f(a); @@ -302,8 +301,8 @@ class ExactFloatTest : public ::testing::Test { // Given two versions "f" and "mp_f" of the binary function called "fname", // check that their results agree to within the given number of ulps for a // range of test arguments. - void TestMathcall2(const char* fname, double f(double, double), - ExactFloat mp_f(const ExactFloat&, const ExactFloat&), + void TestMathcall2(const char *fname, double f(double, double), + ExactFloat mp_f(const ExactFloat &, const ExactFloat &), uint64_t ulps) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; @@ -324,7 +323,7 @@ class ExactFloatTest : public ::testing::Test { // function "mp_f", check that they return the same result on a range of // test arguments. template - void TestMethod0(const char* fname, ResultType f(double), + void TestMethod0(const char *fname, ResultType f(double), ResultType (ExactFloat::*mp_f)() const) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; @@ -342,8 +341,8 @@ class ExactFloatTest : public ::testing::Test { // returns the integer type ResultType, check that they return the same // value for a range of test arguments. template - void TestIntMathcall1(const char* fname, ResultType f(double), - ResultType mp_f(const ExactFloat&)) { + void TestIntMathcall1(const char *fname, ResultType f(double), + ResultType mp_f(const ExactFloat &)) { for (int i = 0; i < kSpecialDoubleValues.size(); ++i) { double a = kSpecialDoubleValues[i]; ResultType expected = f(a); @@ -361,8 +360,8 @@ class ExactFloatTest : public ::testing::Test { // integer argument), check that they return the same result for a range of // test arguments. "ExpType" is the type of the integer argument. template - void TestLdexpCall(const char* fname, double f(double, ExpType), - ExactFloat mp_f(const ExactFloat&, ExpType)) { + void TestLdexpCall(const char *fname, double f(double, ExpType), + ExactFloat mp_f(const ExactFloat &, ExpType)) { static const ExpType kUnsignedExpValues[] = { // Doesn't test with numeric_limits::min() because it's // undefined @@ -627,7 +626,7 @@ TEST_F(ExactFloatTest, RoundToMaxPrec) { // corresponding C++ operator. #define TEST_MATHOP1(op_name, op) \ double op_name(double a) { return op(a); } \ - ExactFloat mp_##op_name(const ExactFloat& a) { return op(a); } \ + ExactFloat mp_##op_name(const ExactFloat &a) { return op(a); } \ TEST_F(ExactFloatTest, op_name) { \ TestMathcall1(#op_name, op_name, mp_##op_name, 0); \ } @@ -639,7 +638,7 @@ TEST_MATHOP1(minus, -) // corresponding C++ operator. #define TEST_MATHOP2(op_name, op) \ double op_name(double a, double b) { return (a)op(b); } \ - ExactFloat mp_##op_name(const ExactFloat& a, const ExactFloat& b) { \ + ExactFloat mp_##op_name(const ExactFloat &a, const ExactFloat &b) { \ return (a)op(b); \ } \ TEST_F(ExactFloatTest, op_name) { \ @@ -663,7 +662,7 @@ TEST_MATHOP2(not_greater, <=); (a) op(b); \ return a; \ } \ - ExactFloat mp_##op_name(const ExactFloat& a, const ExactFloat& b) { \ + ExactFloat mp_##op_name(const ExactFloat &a, const ExactFloat &b) { \ ExactFloat x = a; \ x op(b); \ return x; \ @@ -682,7 +681,7 @@ TEST_ASSIGNOP(times_equals, *=); #define TEST_MATHCALL1(func, ulps) \ /* We define a wrapper function around ExactFloat version of "func" */ \ /* so that we can take its address (gcc can't find it otherwise). */ \ - ExactFloat mp_##func(const ExactFloat& a) { return func(a); } \ + ExactFloat mp_##func(const ExactFloat &a) { return func(a); } \ TEST_F(ExactFloatTest, func) { TestMathcall1(#func, func, mp_##func, ulps); } // Test all the unary math instrinsics (in the same order as the .h file). @@ -699,7 +698,7 @@ TEST_MATHCALL1(round, 0) // Check that the ExactFloat and glibc versions of "func" always return the // same value to within the given number of ulps. #define TEST_MATHCALL2(func, ulps) \ - ExactFloat mp_##func(const ExactFloat& a, const ExactFloat& b) { \ + ExactFloat mp_##func(const ExactFloat &a, const ExactFloat &b) { \ return func(a, b); \ } \ TEST_F(ExactFloatTest, func) { TestMathcall2(#func, func, mp_##func, ulps); } @@ -715,7 +714,7 @@ TEST_MATHCALL2(copysign, 0) // integer value. #define TEST_INTEGER_MATHCALL1(ResultType, func) \ - ResultType mp_##func(const ExactFloat& a) { return func(a); } \ + ResultType mp_##func(const ExactFloat &a) { return func(a); } \ TEST_F(ExactFloatTest, func) { TestIntMathcall1(#func, func, mp_##func); } TEST_INTEGER_MATHCALL1(long, lrint); @@ -735,11 +734,11 @@ int frexp_exp(double a) { (void)frexp(a, &exp_part); return exp_part; } -ExactFloat mp_frexp_frac(const ExactFloat& a) { +ExactFloat mp_frexp_frac(const ExactFloat &a) { int exp_part; return frexp(a, &exp_part); } -int mp_frexp_exp(const ExactFloat& a) { +int mp_frexp_exp(const ExactFloat &a) { int exp_part; (void)frexp(a, &exp_part); return exp_part; @@ -755,7 +754,7 @@ TEST_F(ExactFloatTest, frexp) { // ldexp(), scalbn(), scalbln() #define TEST_LDEXP_CALL(ExpType, func) \ - ExactFloat mp_##func(const ExactFloat& a, ExpType exp) { \ + ExactFloat mp_##func(const ExactFloat &a, ExpType exp) { \ return func(a, exp); \ } \ TEST_F(ExactFloatTest, func) { TestLdexpCall(#func, func, mp_##func); } @@ -789,7 +788,7 @@ TEST_METHOD0_VS_FUNCTION(int, sgn, ref_sgn) // Test a zero-argument ExactFloat member function against a corresponding // one-argument std:: function. #define TEST_METHOD0_VS_STD_FN(ResultType, method, fn) \ - ResultType ref_##fn(double a) { return std::fn(a); } \ + ResultType ref_##fn(double a) { return std::fn(a); } \ TEST_METHOD0_VS_FUNCTION(ResultType, method, ref_##fn) TEST_METHOD0_VS_STD_FN(bool, is_inf, isinf) @@ -868,7 +867,7 @@ TEST_F(ExactFloatTest, Vector3_Part3) { EXPECT_EQ(1, y.z()); y.x(3); EXPECT_EQ(MyVec3(3, 4, 1), y); - ExactFloat* y_data = y.Data(); + ExactFloat *y_data = y.Data(); y_data[2] = 12; EXPECT_EQ(MyVec3(3, 4, 12), y); EXPECT_EQ(y.Data()[1], ExactFloat(4)); From f9697b1626777c82e75ca1758f5ead1224db6af9 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sun, 5 Oct 2025 12:41:38 -0600 Subject: [PATCH 19/31] Create tagged seed sequences for bignum tests and minor fixes. --- src/s2/util/math/exactfloat/BUILD | 2 ++ src/s2/util/math/exactfloat/bignum.cc | 4 +-- src/s2/util/math/exactfloat/bignum_test.cc | 31 ++++++++++++++-------- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index c73a4f36..9dc5bc01 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -28,6 +28,8 @@ cc_test( srcs = ["bignum_test.cc"], deps = [ ":bignum", + "//:s2_testing_headers", + "@abseil-cpp//absl/log:log_streamer", "@abseil-cpp//absl/random:bit_gen_ref", "@boringssl//:crypto", "@googletest//:gtest_main", diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 3e99b7bf..535fef7c 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include "absl/algorithm/container.h" #include "absl/base/nullability.h" @@ -587,8 +586,7 @@ inline std::pair, absl::Span> Split(absl::Span span, class Arena { public: // TODO: Use make_unique_for_overwrite when on C++20. - explicit Arena(size_t size) - : size_(size), data_(std::unique_ptr(new Bigit[size])){}; + explicit Arena(size_t size) : size_(size), data_(new Bigit[size]){}; // Allocates a span of length n from the arena. absl::Span Alloc(size_t n) { diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc index 9460896e..038edd4b 100644 --- a/src/s2/util/math/exactfloat/bignum_test.cc +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -17,9 +17,8 @@ #include #include -#include +#include #include -#include #include #include #include @@ -29,6 +28,7 @@ #include "benchmark/benchmark.h" #endif +#include "absl/log/log_streamer.h" #include "absl/random/bit_gen_ref.h" #include "absl/random/random.h" #include "absl/strings/str_cat.h" @@ -37,6 +37,7 @@ #include "gtest/gtest.h" #include "openssl/bn.h" #include "openssl/crypto.h" +#include "s2/s2testing.h" namespace exactfloat_internal { @@ -833,9 +834,10 @@ class VsOpenSSLTest // E.g. If the test suite was instantiate with (kSmall, kHuge) as the number // classes, then this will return a small value on the left and a huge value // on the right. (kHuge, kSmall) would return the opposite. - std::vector> Numbers() { - auto numbers0 = RandomNumberStrings(bitgen_, GetParam().first); - auto numbers1 = RandomNumberStrings(bitgen_, GetParam().second); + std::vector> Numbers( + absl::BitGenRef bitgen) { + auto numbers0 = RandomNumberStrings(bitgen, GetParam().first); + auto numbers1 = RandomNumberStrings(bitgen, GetParam().second); ABSL_CHECK_EQ(numbers0.size(), numbers1.size()); std::vector> numbers; @@ -846,15 +848,15 @@ class VsOpenSSLTest } return numbers; } - - private: - absl::BitGen bitgen_; }; TEST_P(VsOpenSSLTest, MultiplyCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "MULTIPLY_CORRECT", absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + // Test that multiplication produces the same results as OpenSSL. BN_CTX* ctx = BN_CTX_new(); - for (const auto& [a, b] : Numbers()) { + for (const auto& [a, b] : Numbers(bitgen)) { const Bignum bn_a = *Bignum::FromString(a); const Bignum bn_b = *Bignum::FromString(b); const Bignum bn_result = bn_a * bn_b; @@ -878,8 +880,11 @@ TEST_P(VsOpenSSLTest, MultiplyCorrect) { } TEST_P(VsOpenSSLTest, AdditionCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "ADDITION_CORRECT", absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + // Test that addition produces correct results by comparing to OpenSSL. - for (const auto& [a, b] : Numbers()) { + for (const auto& [a, b] : Numbers(bitgen)) { const Bignum bn_a = *Bignum::FromString(a); const Bignum bn_b = *Bignum::FromString(b); @@ -903,9 +908,13 @@ TEST_P(VsOpenSSLTest, AdditionCorrect) { } TEST_P(VsOpenSSLTest, SubtractionCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "SUBTRACTION_CORRECT", + absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + // Test that subtraction produces correct results by comparing to // OpenSSL. - for (const auto& [a, b] : Numbers()) { + for (const auto& [a, b] : Numbers(bitgen)) { const Bignum bn_a = *Bignum::FromString(a); const Bignum bn_b = *Bignum::FromString(b); From 746be61ffe155b410dfccd55d8dc3087daeb8c74 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Wed, 8 Oct 2025 08:40:20 -0600 Subject: [PATCH 20/31] CR Round 7 fixes. --- src/s2/util/math/exactfloat/bignum.cc | 146 ++++++++++------------ src/s2/util/math/exactfloat/bignum.h | 12 +- src/s2/util/math/exactfloat/exactfloat.cc | 20 ++- 3 files changed, 83 insertions(+), 95 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 535fef7c..60746be2 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -34,15 +34,15 @@ namespace exactfloat_internal { -// Number of bigits in the result of a multiplication before we fall back to -// simple multiplication in the Karatsuba recursion. Determined empirically. -static constexpr int kSimpleMulThreshold = 64; +// Number of bigits in smaller of the two operands before we fall back to simple +// multiplication in the Karatsuba recursion. Determined empirically. +static constexpr int kSimpleMulThreshold = 24; -// Computes out[i] = a[i]*b + c +// Computes dst[i] = a[i]*b + c // // Returns the final carry, if any. -inline Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, - Bigit c); +inline Bigit MulAddWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry); // Compares magnitude magnitude of two bigit vectors, returning -1, 0, or +1. // @@ -79,7 +79,7 @@ int Bignum::Compare(const Bignum& b) const { std::optional Bignum::FromString(absl::string_view s) { // A chunk is up to 19 decimal digits, which can always fit into a Bigit. - constexpr int64_t kMaxChunkDigits = std::numeric_limits::digits10; + constexpr size_t kMaxChunkDigits = std::numeric_limits::digits10; // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the // sake of simplicity. This isn't the fastest algorithm, being quadratic in @@ -114,8 +114,8 @@ std::optional Bignum::FromString(absl::string_view s) { const auto end = s.cend(); while (begin < end) { - int64_t chunk_len = std::min( - static_cast(std::distance(begin, end)), kMaxChunkDigits); + int64_t chunk_len = std::min(static_cast(std::distance(begin, end)), + kMaxChunkDigits); Bigit chunk = 0; auto result = std::from_chars(begin, begin + chunk_len, chunk); @@ -124,9 +124,9 @@ std::optional Bignum::FromString(absl::string_view s) { } begin += chunk_len; - // Shift out up by chunk_len digits and add the chunk to it. + // Shift left by chunk_len digits and add the chunk to it. auto outspan = absl::MakeSpan(out.bigits_); - Bigit carry = MulAdd(outspan, outspan, kPow10[chunk_len], chunk); + Bigit carry = MulAddWithCarry(outspan, outspan, kPow10[chunk_len], chunk); if (carry) { out.bigits_.emplace_back(carry); } @@ -138,8 +138,8 @@ std::optional Bignum::FromString(absl::string_view s) { } int bit_width(const Bignum& a) { - ABSL_DCHECK(a.is_normalized()); - if (a.bigits_.empty()) { + ABSL_DCHECK(a.bigits_.empty() || a.bigits_.back() != 0); + if (a.is_zero()) { return 0; } @@ -152,10 +152,6 @@ int bit_width(const Bignum& a) { } int countr_zero(const Bignum& a) { - if (a.is_zero()) { - return 0; - } - int nzero = 0; for (Bigit bigit : a.bigits_) { if (bigit == 0) { @@ -170,10 +166,6 @@ int countr_zero(const Bignum& a) { bool Bignum::is_bit_set(int nbit) const { ABSL_DCHECK_GE(nbit, 0); - if (is_zero()) { - return false; - } - const size_t digit = nbit / kBigitBits; const size_t shift = nbit % kBigitBits; @@ -186,7 +178,7 @@ bool Bignum::is_bit_set(int nbit) const { Bignum Bignum::operator-() const { Bignum result = *this; - result.set_negative(!result.negative_); + result.negate(); return result; } @@ -281,7 +273,7 @@ Bignum Bignum::Pow(int32_t pow) const { } // Computes a + b + carry and updates the carry. -inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +inline Bigit AddWithCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { auto sum = absl::uint128(a) + b + *carry; *carry = absl::Uint128High64(sum); return static_cast(sum); @@ -290,7 +282,7 @@ inline Bigit AddCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // Computes a - b - borrow and updates the borrow. // // NOTE: Borrow must be one or zero. -inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { +inline Bigit SubWithBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { ABSL_DCHECK_LE(*borrow, Bigit(1)); Bigit diff = a - b - *borrow; *borrow = (a < b) || (*borrow && (a == b)); @@ -298,20 +290,20 @@ inline Bigit SubBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { } // Computes a * b + carry and updates the carry. -inline Bigit MulCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +inline Bigit MulWithCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { auto sum = absl::uint128(a) * b + *carry; *carry = absl::Uint128High64(sum); return static_cast(sum); } -// Computes out += a * b + carry and updates the carry. +// Computes sum += a * b + carry and updates the carry. // // NOTE: Will not overflow even if a, b, and c are their maximum values. -inline void MulAddCarry(Bigit& out, Bigit a, Bigit b, - Bigit* absl_nonnull carry) { - auto sum = absl::uint128(a) * b + *carry + out; - *carry = absl::Uint128High64(sum); - out = static_cast(sum); +inline void MulAddWithCarry(Bigit* absl_nonnull sum, Bigit a, Bigit b, + Bigit* absl_nonnull carry) { + auto term = absl::uint128(a) * b + *carry + *sum; + *carry = absl::Uint128High64(term); + *sum = static_cast(term); } // Computes a += b in place. Returns the final carry (if any). @@ -337,18 +329,18 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { size_t i = 0; while (i + 4 <= b.size()) { for (int j = 0; j < 4; ++j, ++i) { - a[i] = AddCarry(a[i], b[i], &carry); + a[i] = AddWithCarry(a[i], b[i], &carry); } } // Finish remainder. for (; i < b.size(); ++i) { - a[i] = AddCarry(a[i], b[i], &carry); + a[i] = AddWithCarry(a[i], b[i], &carry); } // Propagate carry through the rest of a. for (; carry && i < a.size(); ++i) { - a[i] = AddCarry(a[i], 0, &carry); + a[i] = AddWithCarry(a[i], 0, &carry); } return carry; @@ -368,8 +360,8 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { // // Which is used in the Karatsuba multiplication, where we don't have the option // to expand the allocate space on demand. -inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, - absl::Span b) { +inline size_t Add(absl::Span dst, absl::Span a, + absl::Span b) { const size_t max_size = std::max(a.size(), b.size()); const size_t min_size = std::min(a.size(), b.size()); ABSL_DCHECK_GE(dst.size(), max_size + 1); @@ -381,13 +373,13 @@ inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, size_t i = 0; while (i + 4 < min_size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = AddCarry(a[i], b[i], &carry); + dst[i] = AddWithCarry(a[i], b[i], &carry); } } // Finish remainder of the parts common to A and B. for (; i < min_size; ++i) { - dst[i] = AddCarry(a[i], b[i], &carry); + dst[i] = AddWithCarry(a[i], b[i], &carry); } // Copy remaining digits from the longer operand and propagate carry. @@ -397,13 +389,13 @@ inline size_t AddOutOfPlace(absl::Span dst, absl::Span a, const size_t size = longer.size(); while (i + 4 < size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = AddCarry(longer[i], 0, &carry); + dst[i] = AddWithCarry(longer[i], 0, &carry); } } // Propagate carry through the longer operand. for (; i < size; ++i) { - dst[i] = AddCarry(longer[i], 0, &carry); + dst[i] = AddWithCarry(longer[i], 0, &carry); } if (carry) { @@ -428,13 +420,13 @@ inline void SubInPlace(absl::Span a, absl::Span b) { size_t i = 0; while (i + 4 <= size) { for (int j = 0; j < 4; ++j, ++i) { - a[i] = SubBorrow(a[i], b[i], &borrow); + a[i] = SubWithBorrow(a[i], b[i], &borrow); } } // Finish remainder of subtraction. for (; i < size; ++i) { - a[i] = SubBorrow(a[i], b[i], &borrow); + a[i] = SubWithBorrow(a[i], b[i], &borrow); } // Propagate the borrow through a. @@ -449,8 +441,8 @@ inline void SubInPlace(absl::Span a, absl::Span b) { // Requires |a| >= |b| and dst is thus the same size as a. // A must be expanded to match the size of B and the total number of digits // actually set in A must be passed in via a_digits. -inline Bigit SubOutOfPlace(absl::Span dst, absl::Span a, - absl::Span b, size_t digits) { +inline void Sub(absl::Span dst, absl::Span a, + absl::Span b, size_t digits) { ABSL_DCHECK_EQ(dst.size(), a.size()); ABSL_DCHECK_GE(CmpAbs(a, b), 1); @@ -461,62 +453,57 @@ inline Bigit SubOutOfPlace(absl::Span dst, absl::Span a, size_t i = 0; while (i + 4 < size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = SubBorrow(a[i], b[i], &borrow); + dst[i] = SubWithBorrow(a[i], b[i], &borrow); } } // Finish remainder. for (; i < digits; ++i) { - dst[i] = SubBorrow(a[i], b[i], &borrow); + dst[i] = SubWithBorrow(a[i], b[i], &borrow); } // Propagate borrow through the rest of a. for (; borrow && i < a.size(); ++i) { - dst[i] = SubBorrow(a[i], 0, &borrow); + dst[i] = SubWithBorrow(a[i], 0, &borrow); } - - return borrow; } -Bigit MulAdd(absl::Span out, absl::Span a, Bigit b, - Bigit carry) { - ABSL_DCHECK_GE(out.size(), a.size()); - - int left = a.size(); +inline Bigit MulAddWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry) { + ABSL_DCHECK_GE(dst.size(), a.size()); // Dispatch four at a time to help loop unrolling. size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { - out[i] = MulCarry(a[i], b, &carry); - --left; + dst[i] = MulWithCarry(a[i], b, &carry); } } for (; i < a.size(); ++i) { - out[i] = MulCarry(a[i], b, &carry); + dst[i] = MulWithCarry(a[i], b, &carry); } return carry; } -// Computes out[i] += a[i]*b in place. +// Computes sum[i] += a[i]*b in place. // // Returns the final carry, if any. -inline Bigit MulAddInPlace(absl::Span out, absl::Span a, +inline Bigit MulAddInPlace(absl::Span sum, absl::Span a, Bigit b) { // Dispatch four at a time to help loop unrolling. Bigit carry = 0; size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { - MulAddCarry(out[i], a[i], b, &carry); + MulAddWithCarry(&sum[i], a[i], b, &carry); } } // Finish remainder. for (; i < a.size(); ++i) { - MulAddCarry(out[i], a[i], b, &carry); + MulAddWithCarry(&sum[i], a[i], b, &carry); } return carry; @@ -528,9 +515,9 @@ inline Bigit MulAddInPlace(absl::Span out, absl::Span a, // case for the recursive Karatsuba algorithm below. // // NOTE: out must be at least as large as the sums of the sizes of A and B. -inline void MulQuadratic(absl::Span out, absl::Span a, +inline void MulQuadratic(absl::Span dst, absl::Span a, absl::Span b) { - ABSL_DCHECK_GE(out.size(), a.size() + b.size()); + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); // Make sure A is the longer of the two arguments. if (a.size() < b.size()) { @@ -539,20 +526,20 @@ inline void MulQuadratic(absl::Span out, absl::Span a, } if (b.empty()) { - absl::c_fill(out, 0); + absl::c_fill(dst, 0); return; } // Each call to MulAdd and MulAddInPlace only updates a.size() elements of out // so we manually set the carries as we go. We grab a span to the upper half // of out starting at a.size() to facilitate this. - auto upper = out.subspan(a.size()); - upper[0] = MulAdd(out, a, b[0], 0); + auto upper = dst.subspan(a.size()); + upper[0] = MulAddWithCarry(dst, a, b[0], 0); const size_t size = b.size(); size_t i = 1; for (; i < size; ++i) { - upper[i] = MulAddInPlace(out.subspan(i), a, b[i]); + upper[i] = MulAddInPlace(dst.subspan(i), a, b[i]); } // Finish zeroing out the upper half. @@ -561,10 +548,11 @@ inline void MulQuadratic(absl::Span out, absl::Span a, } } -// Split a span into at most two contiguous spans of length a and b. +// Split a span into two contiguous spans of length at most a and b. // -// If a + b < span.size() then the two spans only cover part of the input. -// If span.size() <= a, then the second span is empty. +// If span.size() <= a, the second span is empty, otherwise the second span +// has length at most b. If span.size() > a + b, then the two spans only cover +// part of the input span. template inline std::pair, absl::Span> Split(absl::Span span, size_t a, size_t b) { @@ -586,7 +574,7 @@ inline std::pair, absl::Span> Split(absl::Span span, class Arena { public: // TODO: Use make_unique_for_overwrite when on C++20. - explicit Arena(size_t size) : size_(size), data_(new Bigit[size]){}; + explicit Arena(size_t size) : size_(size), data_(new Bigit[size]) {} // Allocates a span of length n from the arena. absl::Span Alloc(size_t n) { @@ -672,13 +660,13 @@ inline void KaratsubaMulRecursive(absl::Span dst, absl::Span asum = a0; if (!a1.empty()) { absl::Span tmp = arena->Alloc(half + 1); - asum = tmp.first(AddOutOfPlace(tmp, a0, a1)); + asum = tmp.first(Add(tmp, a0, a1)); } absl::Span bsum = b0; if (!b1.empty()) { absl::Span tmp = arena->Alloc(half + 1); - bsum = tmp.first(AddOutOfPlace(tmp, b0, b1)); + bsum = tmp.first(Add(tmp, b0, b1)); } // Compute z1 = asum*bsum - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 @@ -696,7 +684,7 @@ inline void KaratsubaMulRecursive(absl::Span dst, // Z1 may overflow because of a carry in (a0 + b0) or (a1 + b1) but // subtracting z0 and z2 will always bring it back in range, trim any leading // zeros to shorten the value if needed. - while (z1.back() == 0) { + while (!z1.empty() && z1.back() == 0) { z1 = z1.first(z1.size() - 1); } @@ -711,11 +699,11 @@ inline void KaratsubaMulRecursive(absl::Span dst, // // This algorithm recursively subdivides the inputs until one or both is below // some threshold, and then falls back to standard long multiplication. -void KaratsubaMul(absl::Span out, absl::Span a, +void KaratsubaMul(absl::Span dst, absl::Span a, absl::Span b) { - ABSL_DCHECK_GE(out.size(), a.size() + b.size()); + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); if (a.empty() || b.empty()) { - absl::c_fill(out, 0); + absl::c_fill(dst, 0); return; } @@ -737,7 +725,7 @@ void KaratsubaMul(absl::Span out, absl::Span a, }; Arena arena(peak); - KaratsubaMulRecursive(out, a, b, &arena); + KaratsubaMulRecursive(dst, a, b, &arena); } Bignum& Bignum::operator+=(const Bignum& b) { @@ -782,7 +770,7 @@ Bignum& Bignum::operator+=(const Bignum& b) { // So we can compute |b| - |a| and the final sign is the same as B. size_t prev_size = b.bigits_.size(); bigits_.resize(b.bigits_.size()); - SubOutOfPlace(absl::MakeSpan(bigits_), b.bigits_, bigits_, prev_size); + Sub(absl::MakeSpan(bigits_), b.bigits_, bigits_, prev_size); negative_ = b.is_negative(); } } diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index adac4490..5a08165c 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -187,7 +187,9 @@ class Bignum { // Negates this bignum in place. void negate() { negative_ = !negative_; - Normalize(); + if (bigits_.empty()) { + negative_ = false; + } } // Compares to another bignum, returning -1, 0, +1. @@ -205,11 +207,11 @@ class Bignum { while (!bigits_.empty() && bigits_.back() == 0) { bigits_.pop_back(); } - negative_ = !bigits_.empty() && negative_; - } - // Returns true if the bignum is in normal form (no extra leading zeros). - bool is_normalized() const { return bigits_.empty() || bigits_.back() != 0; } + if (bigits_.empty()) { + negative_ = false; + } + } // We store bignums in sign-magnitude form. bigits_ contains the individual // 64-bit digits of the bignum. If bigits_ is non-empty, then the last element diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index bbc5e7d6..320420d9 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -340,7 +340,7 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { // up only if the lowest kept digit is odd. if (all_digits[max_digits] >= '5' && ((all_digits[max_digits - 1] & 1) == 1 || - all_digits.substr(max_digits + 1).find_first_not_of("0") != + all_digits.find_first_not_of('0', max_digits + 1) != std::string::npos)) { // This can increase the number of digits by 1, but in that case at // least one trailing zero will be stripped off below. @@ -415,18 +415,16 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, int b_sign, r.bn_ += b->bn_; r.sign_ = a_sign; } else { - if (r.bn_ >= b->bn_) { - // |a| >= |b|, compute |a| - |b|, result has same sign as a. - r.bn_ -= b->bn_; - r.sign_ = a_sign; - } else { - // |a| < |b|, compute -|a| + |b| == |b| - |a|, result has same sign as b. - r.bn_.negate(); - r.bn_ += b->bn_; - r.sign_ = b_sign; - } + r.bn_ -= b->bn_; if (r.bn_.is_zero()) { r.sign_ = +1; + } else if (r.bn_.is_negative()) { + // |b| was greater than |a|. + r.sign_ = b_sign; + r.bn_.set_negative(false); + } else { + // |a| was greater than |b| + r.sign_ = a_sign; } } r.Canonicalize(); From d1b692011007576dfb025ece4846efa629127da2 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 10 Oct 2025 10:35:17 -0600 Subject: [PATCH 21/31] CR round 8 changes. - Use _addcarry_u64/_subcarry_u64 intrinsics when available. These generate better carry chains in assembly and speed up some benchmarks by 30%. - Rename single bigit operations to Bigit to distinguish from the span ops. - General clarifying comments and cleanup. --- src/s2/util/math/exactfloat/bignum.cc | 151 +++++++++++++++----------- src/s2/util/math/exactfloat/bignum.h | 6 +- 2 files changed, 92 insertions(+), 65 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 60746be2..74cee2d4 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -15,6 +15,10 @@ #include "s2/util/math/exactfloat/bignum.h" +#ifdef __x86_64__ +#include +#endif + #include #include #include @@ -273,24 +277,38 @@ Bignum Bignum::Pow(int32_t pow) const { } // Computes a + b + carry and updates the carry. -inline Bigit AddWithCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +inline Bigit AddBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +#ifdef __x86_64__ + Bigit out; + *carry = + _addcarry_u64(*carry, a, b, reinterpret_cast(&out)); + return out; +#else auto sum = absl::uint128(a) + b + *carry; *carry = absl::Uint128High64(sum); return static_cast(sum); +#endif } // Computes a - b - borrow and updates the borrow. // // NOTE: Borrow must be one or zero. -inline Bigit SubWithBorrow(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { +inline Bigit SubBigit(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { ABSL_DCHECK_LE(*borrow, Bigit(1)); +#ifdef __x86_64__ + Bigit out; + *borrow = _subborrow_u64(*borrow, a, b, + reinterpret_cast(&out)); + return out; +#else Bigit diff = a - b - *borrow; *borrow = (a < b) || (*borrow && (a == b)); return diff; +#endif } // Computes a * b + carry and updates the carry. -inline Bigit MulWithCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +inline Bigit MulBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { auto sum = absl::uint128(a) * b + *carry; *carry = absl::Uint128High64(sum); return static_cast(sum); @@ -299,8 +317,8 @@ inline Bigit MulWithCarry(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // Computes sum += a * b + carry and updates the carry. // // NOTE: Will not overflow even if a, b, and c are their maximum values. -inline void MulAddWithCarry(Bigit* absl_nonnull sum, Bigit a, Bigit b, - Bigit* absl_nonnull carry) { +inline void MulAddBigit(Bigit* absl_nonnull sum, Bigit a, Bigit b, + Bigit* absl_nonnull carry) { auto term = absl::uint128(a) * b + *carry + *sum; *carry = absl::Uint128High64(term); *sum = static_cast(term); @@ -329,18 +347,18 @@ inline Bigit AddInPlace(absl::Span a, absl::Span b) { size_t i = 0; while (i + 4 <= b.size()) { for (int j = 0; j < 4; ++j, ++i) { - a[i] = AddWithCarry(a[i], b[i], &carry); + a[i] = AddBigit(a[i], b[i], &carry); } } // Finish remainder. for (; i < b.size(); ++i) { - a[i] = AddWithCarry(a[i], b[i], &carry); + a[i] = AddBigit(a[i], b[i], &carry); } // Propagate carry through the rest of a. for (; carry && i < a.size(); ++i) { - a[i] = AddWithCarry(a[i], 0, &carry); + a[i] = AddBigit(a[i], 0, &carry); } return carry; @@ -373,13 +391,13 @@ inline size_t Add(absl::Span dst, absl::Span a, size_t i = 0; while (i + 4 < min_size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = AddWithCarry(a[i], b[i], &carry); + dst[i] = AddBigit(a[i], b[i], &carry); } } // Finish remainder of the parts common to A and B. for (; i < min_size; ++i) { - dst[i] = AddWithCarry(a[i], b[i], &carry); + dst[i] = AddBigit(a[i], b[i], &carry); } // Copy remaining digits from the longer operand and propagate carry. @@ -389,13 +407,13 @@ inline size_t Add(absl::Span dst, absl::Span a, const size_t size = longer.size(); while (i + 4 < size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = AddWithCarry(longer[i], 0, &carry); + dst[i] = AddBigit(longer[i], 0, &carry); } } // Propagate carry through the longer operand. for (; i < size; ++i) { - dst[i] = AddWithCarry(longer[i], 0, &carry); + dst[i] = AddBigit(longer[i], 0, &carry); } if (carry) { @@ -420,13 +438,13 @@ inline void SubInPlace(absl::Span a, absl::Span b) { size_t i = 0; while (i + 4 <= size) { for (int j = 0; j < 4; ++j, ++i) { - a[i] = SubWithBorrow(a[i], b[i], &borrow); + a[i] = SubBigit(a[i], b[i], &borrow); } } // Finish remainder of subtraction. for (; i < size; ++i) { - a[i] = SubWithBorrow(a[i], b[i], &borrow); + a[i] = SubBigit(a[i], b[i], &borrow); } // Propagate the borrow through a. @@ -436,35 +454,32 @@ inline void SubInPlace(absl::Span a, absl::Span b) { } } -// Computes dst = a - b. +// Computes a = b - a. // -// Requires |a| >= |b| and dst is thus the same size as a. -// A must be expanded to match the size of B and the total number of digits -// actually set in A must be passed in via a_digits. -inline void Sub(absl::Span dst, absl::Span a, - absl::Span b, size_t digits) { - ABSL_DCHECK_EQ(dst.size(), a.size()); - ABSL_DCHECK_GE(CmpAbs(a, b), 1); +// NOTE: Requires |b| >= |a|. +inline void SubReverseInPlace(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(b.size(), a.size()); + ABSL_DCHECK_GE(CmpAbs(b, a), 0); Bigit borrow = 0; // Dispatch four at a time to help loop unrolling. - size_t size = digits; + size_t size = a.size(); size_t i = 0; - while (i + 4 < size) { + while (i + 4 <= size) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = SubWithBorrow(a[i], b[i], &borrow); + a[i] = SubBigit(b[i], a[i], &borrow); } } // Finish remainder. - for (; i < digits; ++i) { - dst[i] = SubWithBorrow(a[i], b[i], &borrow); + for (; i < size; ++i) { + a[i] = SubBigit(b[i], a[i], &borrow); } - // Propagate borrow through the rest of a. + // Propagate borrow through the rest of dst. for (; borrow && i < a.size(); ++i) { - dst[i] = SubWithBorrow(a[i], 0, &borrow); + a[i] = SubBigit(a[i], 0, &borrow); } } @@ -476,12 +491,12 @@ inline Bigit MulAddWithCarry(absl::Span dst, absl::Span a, size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { - dst[i] = MulWithCarry(a[i], b, &carry); + dst[i] = MulBigit(a[i], b, &carry); } } for (; i < a.size(); ++i) { - dst[i] = MulWithCarry(a[i], b, &carry); + dst[i] = MulBigit(a[i], b, &carry); } return carry; @@ -497,13 +512,13 @@ inline Bigit MulAddInPlace(absl::Span sum, absl::Span a, size_t i = 0; while (i + 4 <= a.size()) { for (int j = 0; j < 4; ++j, ++i) { - MulAddWithCarry(&sum[i], a[i], b, &carry); + MulAddBigit(&sum[i], a[i], b, &carry); } } // Finish remainder. for (; i < a.size(); ++i) { - MulAddWithCarry(&sum[i], a[i], b, &carry); + MulAddBigit(&sum[i], a[i], b, &carry); } return carry; @@ -584,6 +599,8 @@ class Arena { return absl::Span(data_.get() + start, n); } + size_t Available() const { return size_ - used_; } + size_t Used() const { return used_; } // Resets the arena to the given position which must be < Used(). @@ -598,11 +615,39 @@ class Arena { std::unique_ptr data_; }; +// Returns the total arena size needed to multiply two number of a_size and +// b_size bigits using the recursive Karatsuba implementation. +inline size_t ArenaSize(size_t a_size, size_t b_size) { + // Each step of Karatsuba splits at: + // N = (std::max(a.size() + b.size() + 1) / 2 + // + // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. + // + // Simulate the recursion (log(n) steps) and compute the arena size. + int peak = 0; + while (std::min(a_size, b_size) > kSimpleMulThreshold) { + int half = (std::max(a_size, b_size) + 1) / 2; + int next = half + 1; + peak += 4 * next; + a_size = next; + b_size = next; + }; + return peak; +} + +// Recursive step in the Karatsuba multiplication. dst must be large enough to +// hold the product of a and b (i.e. it must be at least as large as a.size() + +// b.size()). The product is computed and stored in-place in dst. +// +// Additionally an arena must be provided for temporary storage for intermediate +// products. The arena must have at least ArenaSize(a.size(), b.size()) space +// available. inline void KaratsubaMulRecursive(absl::Span dst, absl::Span a, absl::Span b, Arena* absl_nonnull arena) { ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); + ABSL_DCHECK_GE(arena->Available(), ArenaSize(a.size(), b.size())); if (a.empty() || b.empty()) { absl::c_fill(dst, 0); return; @@ -640,7 +685,7 @@ inline void KaratsubaMulRecursive(absl::Span dst, return; } - const int half = (std::max(a.size(), b.size()) + 1) / 2; + const size_t half = (std::max(a.size(), b.size()) + 1) / 2; // Split the inputs into contiguous subspans. auto [a0, a1] = Split(a, half, half); @@ -681,15 +726,15 @@ inline void KaratsubaMulRecursive(absl::Span dst, SubInPlace(z1, z2); } - // Z1 may overflow because of a carry in (a0 + b0) or (a1 + b1) but - // subtracting z0 and z2 will always bring it back in range, trim any leading - // zeros to shorten the value if needed. - while (!z1.empty() && z1.back() == 0) { - z1 = z1.first(z1.size() - 1); - } + // We need to add z1*10^half, which we can do by simply adding z1 at a shifted + // position in the output. + auto dst_z1 = dst.subspan(half); - // We need to add z1*10^half which we can do by adding it at an offset. - AddInPlace(dst.subspan(half), z1); + // Z1 may overflow to half + 1 bigits because of a carry. This is a problem + // when we only have space for half bigits in the final sum. Fortunate, + // subtracting z0 and z2 will always bring it back into range, and we simply + // have to trim the leading zeros, if any. + AddInPlace(dst_z1, z1.first(std::min(z1.size(), dst_z1.size()))); // Release temporary memory we used. arena->Reset(arena_start); @@ -707,24 +752,7 @@ void KaratsubaMul(absl::Span dst, absl::Span a, return; } - // Each step of Karatsuba splits at: - // N = (std::max(a.size() + b.size() + 1) / 2 - // - // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. - // - // Simulate the recursion (log(n) steps) and compute the arena size. - int a_size = a.size(); - int b_size = b.size(); - int peak = 0; - while (std::min(a_size, b_size) > kSimpleMulThreshold) { - int half = (std::max(a_size, b_size) + 1) / 2; - int next = half + 1; - peak += 4 * next; - a_size = next; - b_size = next; - }; - - Arena arena(peak); + Arena arena(ArenaSize(a.size(), b.size())); KaratsubaMulRecursive(dst, a, b, &arena); } @@ -768,9 +796,8 @@ Bignum& Bignum::operator+=(const Bignum& b) { // +|a| + -|b| --> -(|b| - |a|) // // So we can compute |b| - |a| and the final sign is the same as B. - size_t prev_size = b.bigits_.size(); bigits_.resize(b.bigits_.size()); - Sub(absl::MakeSpan(bigits_), b.bigits_, bigits_, prev_size); + SubReverseInPlace(absl::MakeSpan(bigits_), b.bigits_); negative_ = b.is_negative(); } } diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h index 5a08165c..c85bb8ea 100644 --- a/src/s2/util/math/exactfloat/bignum.h +++ b/src/s2/util/math/exactfloat/bignum.h @@ -214,9 +214,9 @@ class Bignum { } // We store bignums in sign-magnitude form. bigits_ contains the individual - // 64-bit digits of the bignum. If bigits_ is non-empty, then the last element - // must be non-zero and when it is empty (representing a zero value), - // negative_ must be false. + // 64-bit digits of the bignum, stored in little-endian order. If bigits_ is + // non-empty, then the last element must be non-zero and when it is empty + // (representing a zero value), negative_ must be false. BigitVector bigits_; bool negative_ = false; }; From eb1d5018f118495d35f793936a195df367c1719c Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 11 Oct 2025 08:37:17 -0600 Subject: [PATCH 22/31] Minor renaming and comment fixes. --- src/s2/util/math/exactfloat/bignum.cc | 28 +++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 74cee2d4..73ca2930 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -45,8 +45,8 @@ static constexpr int kSimpleMulThreshold = 24; // Computes dst[i] = a[i]*b + c // // Returns the final carry, if any. -inline Bigit MulAddWithCarry(absl::Span dst, absl::Span a, - Bigit b, Bigit carry); +inline Bigit MulWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry); // Compares magnitude magnitude of two bigit vectors, returning -1, 0, or +1. // @@ -130,7 +130,7 @@ std::optional Bignum::FromString(absl::string_view s) { // Shift left by chunk_len digits and add the chunk to it. auto outspan = absl::MakeSpan(out.bigits_); - Bigit carry = MulAddWithCarry(outspan, outspan, kPow10[chunk_len], chunk); + Bigit carry = MulWithCarry(outspan, outspan, kPow10[chunk_len], chunk); if (carry) { out.bigits_.emplace_back(carry); } @@ -457,8 +457,11 @@ inline void SubInPlace(absl::Span a, absl::Span b) { // Computes a = b - a. // // NOTE: Requires |b| >= |a|. +// +// Since we write the result to a, but b is larger, a must be expanded with +// enough leading zeros to fit the result. inline void SubReverseInPlace(absl::Span a, absl::Span b) { - ABSL_DCHECK_GE(b.size(), a.size()); + ABSL_DCHECK_GE(a.size(), b.size()); ABSL_DCHECK_GE(CmpAbs(b, a), 0); Bigit borrow = 0; @@ -477,14 +480,14 @@ inline void SubReverseInPlace(absl::Span a, absl::Span b) { a[i] = SubBigit(b[i], a[i], &borrow); } - // Propagate borrow through the rest of dst. + // Propagate borrow through the rest of a. for (; borrow && i < a.size(); ++i) { a[i] = SubBigit(a[i], 0, &borrow); } } -inline Bigit MulAddWithCarry(absl::Span dst, absl::Span a, - Bigit b, Bigit carry) { +inline Bigit MulWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry) { ABSL_DCHECK_GE(dst.size(), a.size()); // Dispatch four at a time to help loop unrolling. @@ -549,7 +552,7 @@ inline void MulQuadratic(absl::Span dst, absl::Span a, // so we manually set the carries as we go. We grab a span to the upper half // of out starting at a.size() to facilitate this. auto upper = dst.subspan(a.size()); - upper[0] = MulAddWithCarry(dst, a, b[0], 0); + upper[0] = MulWithCarry(dst, a, b[0], 0); const size_t size = b.size(); size_t i = 1; @@ -730,10 +733,11 @@ inline void KaratsubaMulRecursive(absl::Span dst, // position in the output. auto dst_z1 = dst.subspan(half); - // Z1 may overflow to half + 1 bigits because of a carry. This is a problem - // when we only have space for half bigits in the final sum. Fortunate, - // subtracting z0 and z2 will always bring it back into range, and we simply - // have to trim the leading zeros, if any. + // Although the value of z1 is guaranteed to fit in the available space of + // dst, it may have one or more high-order zero bigits because it was sized + // conservatively to hold the intermediate result (asum * bsum). We trim these + // leading zeros if necessary to ensure that the Add() operation below does + // not attempt to write zero bigits past the end of dst. AddInPlace(dst_z1, z1.first(std::min(z1.size(), dst_z1.size()))); // Release temporary memory we used. From dde62ae76534c58c774ef650858075575799408a Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 11 Oct 2025 15:30:24 -0600 Subject: [PATCH 23/31] Remove redundant borrow propagation loop. --- src/s2/util/math/exactfloat/bignum.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 73ca2930..abaad92b 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -479,11 +479,6 @@ inline void SubReverseInPlace(absl::Span a, absl::Span b) { for (; i < size; ++i) { a[i] = SubBigit(b[i], a[i], &borrow); } - - // Propagate borrow through the rest of a. - for (; borrow && i < a.size(); ++i) { - a[i] = SubBigit(a[i], 0, &borrow); - } } inline Bigit MulWithCarry(absl::Span dst, absl::Span a, From f434e10a64a421526f42efa773fb69711fd53521 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Wed, 22 Oct 2025 09:47:34 -0600 Subject: [PATCH 24/31] Revert .bazelrc changes. --- src/.bazelrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/.bazelrc b/src/.bazelrc index c7f7c7b7..a543738a 100644 --- a/src/.bazelrc +++ b/src/.bazelrc @@ -1,3 +1,3 @@ # Enable Bzlmod for every Bazel command common --enable_bzlmod -common --cxxopt=-std=c++20 +common --cxxopt=-std=c++20 \ No newline at end of file From 26b180bb37353a05bce8929d7bc0dba6c0076c0f Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 24 Oct 2025 12:29:10 -0600 Subject: [PATCH 25/31] Minor formatting cleanup. --- src/s2/util/math/exactfloat/exactfloat.cc | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index cd616b3e..bba16288 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -25,17 +25,14 @@ #include #include -#include "absl/container/fixed_array.h" // IWYU pragma: keep #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" -#include "absl/numeric/bits.h" // IWYU pragma: keep #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" namespace exactfloat { using std::max; -using std::min; // To simplify the overflow/underflow logic, we limit the exponent and // precision range so that (2 * bn_exp_) does not overflow an "int". We take @@ -124,15 +121,11 @@ void ExactFloat::set_nan() { int fpclassify(ExactFloat const& x) { switch (x.bn_exp_) { - case ExactFloat::kExpNaN: - return FP_NAN; - case ExactFloat::kExpInfinity: - return FP_INFINITE; - case ExactFloat::kExpZero: - return FP_ZERO; + case ExactFloat::kExpNaN: return FP_NAN; + case ExactFloat::kExpInfinity: return FP_INFINITE; + case ExactFloat::kExpZero: return FP_ZERO; // There are no subnormal `ExactFloat`s. - default: - return FP_NORMAL; + default: return FP_NORMAL; } } From 6d5b227aa5437586e8abddba6ef57084ee0998b7 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 24 Oct 2025 13:47:34 -0600 Subject: [PATCH 26/31] Revert to using find_package for OpenSSL. --- CMakeLists.txt | 10 ++++++++-- README.md | 2 ++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 43ca0527..8d2e5cb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,6 +60,7 @@ add_definitions(-DABSL_MIN_LOG_LEVEL=1) if (NOT TARGET absl::base) find_package(absl REQUIRED) endif() +find_package(OpenSSL REQUIRED) # pthreads isn't used directly, but this is still required for std::thread. find_package(Threads REQUIRED) @@ -67,6 +68,11 @@ if (NOT TARGET absl::vlog_is_on) message(FATAL_ERROR "Could not find absl vlog module. Are you using an older version?") endif() +# If OpenSSL is installed in a non-standard location, configure with +# something like: +# OPENSSL_ROOT_DIR=/usr/local/opt/openssl cmake .. +include_directories(${OPENSSL_INCLUDE_DIR}) + if (WITH_PYTHON) # Should be easy to make it work with swig3, but some args to %pythonprepend # seem to be different and were changed. @@ -627,6 +633,7 @@ if (BUILD_TESTS) add_executable(${test} ${test_cc}) target_link_libraries( ${test} + ${OPENSSL_CRYPTO_LIBRARIES} s2testing s2 absl::base absl::btree @@ -641,8 +648,7 @@ if (BUILD_TESTS) absl::status absl::strings absl::synchronization - gmock_main - crypto) + gmock_main) add_test(${test} ${test}) endforeach() endif() diff --git a/README.md b/README.md index 38b1be44..2fce96e7 100644 --- a/README.md +++ b/README.md @@ -224,6 +224,8 @@ python -m build The resulting wheel will be in the `dist` directory. +> If OpenSSL is in a non-standard location make sure to set `OPENSSL_ROOT_DIR`; +> see above for more information. ## Other S2 implementations From 05605eaa0e766c10d4281c1dabba1d38fb572430 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 24 Oct 2025 14:14:05 -0600 Subject: [PATCH 27/31] Remove 0 - term when static_cast-ing INT_MIN to take absolute value. The cast itself adds 2^32 which is equivalent to taking the absolute value. --- src/s2/util/math/exactfloat/exactfloat.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index bba16288..91695d46 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -69,7 +69,7 @@ ExactFloat::ExactFloat(int v) { bn_exp_ = 0; if (v == std::numeric_limits::min()) { - bn_ = Bignum(unsigned(0) - static_cast(v)); + bn_ = Bignum(static_cast(v)); } else { bn_ = Bignum(abs(v)); } From 4fc88066dd29e857e7393000488bea2c9110254f Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 24 Oct 2025 14:20:47 -0600 Subject: [PATCH 28/31] Make find_package(OpenSSL) test-only and minor fixups. --- CMakeLists.txt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d2e5cb3..e86bb2fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,6 @@ add_definitions(-DABSL_MIN_LOG_LEVEL=1) if (NOT TARGET absl::base) find_package(absl REQUIRED) endif() -find_package(OpenSSL REQUIRED) # pthreads isn't used directly, but this is still required for std::thread. find_package(Threads REQUIRED) @@ -68,11 +67,6 @@ if (NOT TARGET absl::vlog_is_on) message(FATAL_ERROR "Could not find absl vlog module. Are you using an older version?") endif() -# If OpenSSL is installed in a non-standard location, configure with -# something like: -# OPENSSL_ROOT_DIR=/usr/local/opt/openssl cmake .. -include_directories(${OPENSSL_INCLUDE_DIR}) - if (WITH_PYTHON) # Should be easy to make it work with swig3, but some args to %pythonprepend # seem to be different and were changed. @@ -104,6 +98,10 @@ else() add_compile_options(-fsized-deallocation) endif() +# If OpenSSL is installed in a non-standard location, configure with +# something like: +# OPENSSL_ROOT_DIR=/usr/local/opt/openssl cmake .. +include_directories(${OPENSSL_INCLUDE_DIR}) if (WITH_PYTHON) include_directories(${Python3_INCLUDE_DIRS}) @@ -491,6 +489,8 @@ if(S2_ENABLE_INSTALL) endif() # S2_ENABLE_INSTALL if (BUILD_TESTS) + find_package(OpenSSL REQUIRED) + if (NOT GOOGLETEST_ROOT) message(FATAL_ERROR "BUILD_TESTS requires GOOGLETEST_ROOT") endif() From 1bbe409638819a2ba4b0655c01fc8bdcea539cab Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Sat, 25 Oct 2025 12:47:14 -0600 Subject: [PATCH 29/31] Minor spacing fixes and more comments on AddBigit and MulAddBigit. --- CMakeLists.txt | 1 - src/s2/util/math/exactfloat/bignum.cc | 31 ++++++++++++++++++++++++++- 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e86bb2fc..5426c9cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -228,7 +228,6 @@ if (GOOGLETEST_ROOT) src/s2/thread_testing.cc) endif() - target_link_libraries( s2 absl::absl_vlog_is_on diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index abaad92b..442d8df1 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -278,7 +278,22 @@ Bignum Bignum::Pow(int32_t pow) const { // Computes a + b + carry and updates the carry. inline Bigit AddBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +// Compilers such as GCC and Clang are known to be terrible at generating good +// code for long carry chains on Intel. Using the _addcarry_u64 intrinsic (which +// maps to the add/adc instructions produces a tight series of add/adc/adc/adc +// instructions, whereas writing the loop manually often generates many add/adc +// pairs with spurious bit twiddling. +// +// Using the intrinsic here improves benchmarks by ~30% when summing larger +// Bignums together. +// +// See this SO discussion for more information: +// https://stackoverflow.com/questions/33690791 +// +// Godbolt link for comparison: +// https://godbolt.org/z/cGnnfMbMn #ifdef __x86_64__ + static_assert(sizeof(Bigit) == sizeof(unsigned long long)); Bigit out; *carry = _addcarry_u64(*carry, a, b, reinterpret_cast(&out)); @@ -295,6 +310,7 @@ inline Bigit AddBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // NOTE: Borrow must be one or zero. inline Bigit SubBigit(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { ABSL_DCHECK_LE(*borrow, Bigit(1)); + // See notes in AddBigit on why using an intrinsic is the right choice here. #ifdef __x86_64__ Bigit out; *borrow = _subborrow_u64(*borrow, a, b, @@ -319,6 +335,18 @@ inline Bigit MulBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // NOTE: Will not overflow even if a, b, and c are their maximum values. inline void MulAddBigit(Bigit* absl_nonnull sum, Bigit a, Bigit b, Bigit* absl_nonnull carry) { + // Similar to the comment in AddBigit, and just for completeness, it's worth + // noting that the "best" way to implement this is with the Intel MULX, ADCQ, + // and ADOQ instructions (i.e. the _mulx_u64, _addcarry_u64, and + // _addcarryx_u64 intrinsics), but GCC and Clang do not support _addcarryx_u64 + // properly (and have no plans to do so). The issue is that gcc doesn't + // support reasoning about separate dependency chains for the carry and + // overflow flags, because all the flags are considered to be one + // register. (The ADOX instructions were added specifically for this use case, + // i.e. high-precision integer multiplies. They propagate carries using the + // overflow flag rather than the carry flag, which lets you do two + // extended-precision add operations in parallel without having them stomp on + // each other's carry flags. ) auto term = absl::uint128(a) * b + *carry + *sum; *carry = absl::Uint128High64(term); *sum = static_cast(term); @@ -338,7 +366,8 @@ inline void MulAddBigit(Bigit* absl_nonnull sum, Bigit a, Bigit b, // // Rather than having to expand A to B.bigits_.size() + 1, and popping off the // top bigit if it's unused (which is the most common case). -inline Bigit AddInPlace(absl::Span a, absl::Span b) { +ABSL_ATTRIBUTE_NOINLINE inline Bigit AddInPlace(absl::Span a, + absl::Span b) { ABSL_DCHECK_GE(a.size(), b.size()); Bigit carry = 0; From 4a4380b751797241ed41e2d43fd2674b4a24eba3 Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Thu, 30 Oct 2025 08:46:18 -0600 Subject: [PATCH 30/31] Minor edits. --- src/s2/util/math/exactfloat/bignum.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc index 442d8df1..0d51ac9c 100644 --- a/src/s2/util/math/exactfloat/bignum.cc +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -149,8 +149,7 @@ int bit_width(const Bignum& a) { // Bit width is the bits in the least significant bigits + bit width of // the most significant word. - const int msw_width = - (Bignum::kBigitBits - absl::countl_zero(a.bigits_.back())); + const int msw_width = absl::bit_width(a.bigits_.back()); const int lsw_width = (a.bigits_.size() - 1) * Bignum::kBigitBits; return msw_width + lsw_width; } @@ -291,7 +290,8 @@ inline Bigit AddBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { // https://stackoverflow.com/questions/33690791 // // Godbolt link for comparison: -// https://godbolt.org/z/cGnnfMbMn +// https://godbolt.org/z/cGnnfMbMn (no intrinsics) +// https://godbolt.org/z/jnM1Y3Tjs (intrinsics) #ifdef __x86_64__ static_assert(sizeof(Bigit) == sizeof(unsigned long long)); Bigit out; From fc43b3c80f5c8a51ef1e0c9d2f82627fd93c3f7c Mon Sep 17 00:00:00 2001 From: Sean McAllister Date: Fri, 31 Oct 2025 09:17:00 -0600 Subject: [PATCH 31/31] Add SafeAbs to avoid UB when casting INT_MIN. --- src/s2/util/math/exactfloat/exactfloat.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 91695d46..36cb781d 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -64,15 +64,17 @@ ExactFloat::ExactFloat(double v) { } } +// Calculates abs(v) without UB. SafeAbs(INT_MIN) == INT_MIN. +// Generates the same code as std::abs(). +// https://godbolt.org/z/eT6KW1zGb +int SafeAbs(int v) { + return v < 0 ? -static_cast(v) : v; +} + ExactFloat::ExactFloat(int v) { sign_ = (v >= 0) ? 1 : -1; bn_exp_ = 0; - - if (v == std::numeric_limits::min()) { - bn_ = Bignum(static_cast(v)); - } else { - bn_ = Bignum(abs(v)); - } + bn_ = Bignum(static_cast(SafeAbs(v))); Canonicalize(); }