diff --git a/CMakeLists.txt b/CMakeLists.txt index b564d962..5426c9cd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -60,7 +60,6 @@ add_definitions(-DABSL_MIN_LOG_LEVEL=1) if (NOT TARGET absl::base) find_package(absl REQUIRED) endif() -find_package(OpenSSL REQUIRED) # pthreads isn't used directly, but this is still required for std::thread. find_package(Threads REQUIRED) @@ -215,6 +214,7 @@ add_library(s2 src/s2/util/bits/bit-interleave.cc src/s2/util/coding/coder.cc src/s2/util/coding/varint.cc + src/s2/util/math/exactfloat/bignum.cc src/s2/util/math/exactfloat/exactfloat.cc src/s2/util/math/mathutil.cc src/s2/util/units/length-units.cc) @@ -230,7 +230,6 @@ endif() target_link_libraries( s2 - ${OPENSSL_LIBRARIES} absl::absl_vlog_is_on absl::base absl::btree @@ -458,7 +457,8 @@ if(S2_ENABLE_INSTALL) src/s2/util/math/matrix3x3.h src/s2/util/math/vector.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math") - install(FILES src/s2/util/math/exactfloat/exactfloat.h + install(FILES src/s2/util/math/exactfloat/bignum.h + src/s2/util/math/exactfloat/exactfloat.h DESTINATION "${CMAKE_INSTALL_INCLUDEDIR}/s2/util/math/exactfloat") install(FILES src/s2/util/units/length-units.h src/s2/util/units/physical-units.h @@ -488,6 +488,8 @@ if(S2_ENABLE_INSTALL) endif() # S2_ENABLE_INSTALL if (BUILD_TESTS) + find_package(OpenSSL REQUIRED) + if (NOT GOOGLETEST_ROOT) message(FATAL_ERROR "BUILD_TESTS requires GOOGLETEST_ROOT") endif() @@ -619,6 +621,7 @@ if (BUILD_TESTS) src/s2/s2wrapped_shape_test.cc src/s2/sequence_lexicon_test.cc src/s2/value_lexicon_test.cc + src/s2/util/math/exactfloat/bignum_test.cc src/s2/util/math/exactfloat/exactfloat_test.cc src/s2/util/math/exactfloat/exactfloat_underflow_test.cc) @@ -629,6 +632,7 @@ if (BUILD_TESTS) add_executable(${test} ${test_cc}) target_link_libraries( ${test} + ${OPENSSL_CRYPTO_LIBRARIES} s2testing s2 absl::base absl::btree diff --git a/README.md b/README.md index 9a6b22de..2fce96e7 100644 --- a/README.md +++ b/README.md @@ -59,7 
+59,6 @@ This issue may require revision of boringssl or exactfloat. * [Abseil](https://github.com/abseil/abseil-cpp) >= LTS [`20250814`](https://github.com/abseil/abseil-cpp/releases/tag/20250814.1) (standard library extensions) -* [OpenSSL](https://github.com/openssl/openssl) (for its bignum library) * [googletest testing framework >= 1.10](https://github.com/google/googletest) (to build tests and example programs, optional) @@ -139,6 +138,13 @@ Disable building of shared libraries with `-DBUILD_SHARED_LIBS=OFF`. Enable the python interface with `-DWITH_PYTHON=ON`. +# For Testing + +If BUILD_TESTS is 'on' (the default), then OpenSSL must be available to build +some tests: + +* [OpenSSL](https://github.com/openssl/openssl) (for its bignum library) + If OpenSSL is installed in a non-standard location set `OPENSSL_ROOT_DIR` before running configure, for example on macOS: ``` @@ -218,7 +224,7 @@ python -m build The resulting wheel will be in the `dist` directory. -> If OpenSSL is in a non-standard location make sure to set `OPENSSL_ROOT_DIR`; +> If OpenSSL is in a non-standard location make sure to set `OPENSSL_ROOT_DIR`; > see above for more information. 
## Other S2 implementations diff --git a/src/s2/util/math/exactfloat/BUILD b/src/s2/util/math/exactfloat/BUILD index 4f5499a9..1056df56 100644 --- a/src/s2/util/math/exactfloat/BUILD +++ b/src/s2/util/math/exactfloat/BUILD @@ -1,14 +1,37 @@ package(default_visibility = ["//visibility:public"]) +cc_library( + name = "bignum", + srcs = ["bignum.cc"], + hdrs = ["bignum.h"], + visibility = ["//visibility:private"], + deps = [ + "@abseil-cpp//absl/algorithm:container", + "@abseil-cpp//absl/container:inlined_vector", + "@abseil-cpp//absl/log:absl_check", + ], +) + cc_library( name = "exactfloat", srcs = ["exactfloat.cc"], hdrs = ["exactfloat.h"], deps = [ + ":bignum", "//s2/base:logging", - "@abseil-cpp//absl/log:log", - "@abseil-cpp//absl/log:absl_check", + ], +) + +cc_test( + name = "bignum_test", + srcs = ["bignum_test.cc"], + deps = [ + ":bignum", + "//:s2_testing_headers", + "@abseil-cpp//absl/log:log_streamer", + "@abseil-cpp//absl/random:bit_gen_ref", "@boringssl//:crypto", + "@googletest//:gtest_main", ], ) diff --git a/src/s2/util/math/exactfloat/bignum.cc b/src/s2/util/math/exactfloat/bignum.cc new file mode 100644 index 00000000..0d51ac9c --- /dev/null +++ b/src/s2/util/math/exactfloat/bignum.cc @@ -0,0 +1,884 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "s2/util/math/exactfloat/bignum.h" + +#ifdef __x86_64__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace exactfloat_internal { + +// Number of bigits in the smaller of the two operands before we fall back to +// simple multiplication in the Karatsuba recursion. Determined empirically. +static constexpr int kSimpleMulThreshold = 24; + +// Computes dst = a*b + carry, one bigit at a time: each dst[i] receives the low +// half of a[i]*b plus the running carry, and the high half carries forward. +// +// Returns the final carry, if any. +inline Bigit MulWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry); + +// Compares the magnitudes of two bigit vectors, returning -1, 0, or +1. +// +// A shorter vector always compares smaller; equal-length vectors are compared +// bigit by bigit from the most significant to the least significant. +int CmpAbs(absl::Span a, absl::Span b) { + if (a.size() != b.size()) { + return a.size() < b.size() ? -1 : +1; + } + + for (int i = a.size() - 1; i >= 0; --i) { + if (a[i] != b[i]) { + return a[i] < b[i] ? -1 : +1; + } + } + + return 0; +} + +int Bignum::Compare(const Bignum& b) const { + if (is_negative() != b.is_negative()) { + return is_negative() ? -1 : +1; + } + + // Signs are equal, are they both zero? + if (is_zero() && b.is_zero()) { + return 0; + } + + // Signs are equal and non-zero, compare magnitude. + const int compare = CmpAbs(bigits_, b.bigits_); + return is_negative() ? -compare : compare; +} + +std::optional Bignum::FromString(absl::string_view s) { + // A chunk is up to 19 decimal digits, which can always fit into a Bigit. + constexpr size_t kMaxChunkDigits = std::numeric_limits::digits10; + + // NOTE: We use a simple multiply-and-add (aka Horner's) method here for the + // sake of simplicity.
This isn't the fastest algorithm, being quadratic in + // the number of chunks the input has. If we used a divide-and-conquer approach + // or an FFT-based multiply we could probably make this ~O(n^1.5) or + // quasi-linear. + + // Precomputed powers of 10. + static constexpr auto kPow10 = []() { + std::array out = {1}; + for (size_t i = 1; i < out.size(); ++i) { + out[i] = 10 * out[i - 1]; + } + return out; + }(); + + Bignum out; + if (s.empty()) { + return out; + } + + out.bigits_.reserve((s.size() + kMaxChunkDigits - 1) / kMaxChunkDigits); + + bool negative = false; + + // Consume optional +/- at the front. + // NOTE(review): a sign with no digits after it ("+" or "-") skips the loop + // below, so no std::nullopt is returned for it; confirm that is intended. + auto begin = s.cbegin(); + if ((*begin == '+' || *begin == '-')) { + negative = (s[0] == '-'); + ++begin; + } + + const auto end = s.cend(); + while (begin < end) { + int64_t chunk_len = std::min(static_cast(std::distance(begin, end)), + kMaxChunkDigits); + + Bigit chunk = 0; + auto result = std::from_chars(begin, begin + chunk_len, chunk); + if (result.ec != std::errc() || (result.ptr - begin) != chunk_len) { + return std::nullopt; + } + begin += chunk_len; + + // Shift left by chunk_len decimal digits and add the chunk to it. + auto outspan = absl::MakeSpan(out.bigits_); + Bigit carry = MulWithCarry(outspan, outspan, kPow10[chunk_len], chunk); + if (carry) { + out.bigits_.emplace_back(carry); + } + } + + out.negative_ = negative; + out.Normalize(); + return out; +} + +int bit_width(const Bignum& a) { + ABSL_DCHECK(a.bigits_.empty() || a.bigits_.back() != 0); + if (a.is_zero()) { + return 0; + } + + // Bit width is the bits in the least significant bigits + bit width of + // the most significant word.
+ const int msw_width = absl::bit_width(a.bigits_.back()); + const int lsw_width = (a.bigits_.size() - 1) * Bignum::kBigitBits; + return msw_width + lsw_width; +} + +int countr_zero(const Bignum& a) { + int nzero = 0; + for (Bigit bigit : a.bigits_) { + if (bigit == 0) { + nzero += Bignum::kBigitBits; + } else { + nzero += absl::countr_zero(bigit); + break; + } + } + return nzero; +} + +bool Bignum::is_bit_set(int nbit) const { + ABSL_DCHECK_GE(nbit, 0); + const size_t digit = nbit / kBigitBits; + const size_t shift = nbit % kBigitBits; + + if (digit >= bigits_.size()) { + return false; + } + + return ((bigits_[digit] >> shift) & 0x1) != 0; +} + +Bignum Bignum::operator-() const { + Bignum result = *this; + result.negate(); + return result; +} + +Bignum& Bignum::operator<<=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (is_zero() || nbit == 0) { + return *this; + } + + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; + + // First, handle the whole-bigit shift by inserting zeros. + bigits_.insert(bigits_.begin(), nbigit, 0); + + // Then, handle the within-bigit shift, if any. + if (nrem != 0) { + Bigit carry = 0; + for (size_t i = 0; i < bigits_.size(); ++i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val << nrem) | carry; + carry = old_val >> (kBigitBits - nrem); + } + + if (carry) { + bigits_.push_back(carry); + } + } + + return *this; +} + +Bignum& Bignum::operator>>=(int nbit) { + ABSL_DCHECK_GE(nbit, 0); + if (is_zero() || nbit == 0) { + return *this; + } + + // Shifting by the bit width or more results in zero. + if (nbit >= bit_width(*this)) { + return set_zero(); + } + + const int nbigit = nbit / kBigitBits; + const int nrem = nbit % kBigitBits; + + // First, handle the whole-bigit shift by removing bigits. + bigits_.erase(bigits_.begin(), bigits_.begin() + nbigit); + + // Then, handle the within-bigit shift, if any.
+ if (nrem != 0) { + Bigit carry = 0; + for (auto i = static_cast(bigits_.size()) - 1; i >= 0; --i) { + const Bigit old_val = bigits_[i]; + bigits_[i] = (old_val >> nrem) | carry; + carry = old_val << (kBigitBits - nrem); + } + } + + // Result might be smaller or zero, so normalize. + Normalize(); + return *this; +} + +// Raise this value to the given power, which must be non-negative. +Bignum Bignum::Pow(int32_t pow) const { + ABSL_DCHECK_GE(pow, 0); + + // Anything to the zero-th power is 1 (including zero). + if (pow == 0) { + return Bignum(1); + } + + if (is_zero()) { + return Bignum(0); + } + + // Core algorithm: Exponentiation by squaring. + Bignum result(1); + Bignum base = *this; // A mutable copy of the base. + uint32_t upow = static_cast(pow); + + // pow == 0 returned above, so upow is non-zero on entry and the loop always + // terminates through the break below. + while (true) { + if (upow & 1) { // If current exponent bit is 1, multiply into result. + result *= base; + } + upow >>= 1; + if (upow == 0) { + break; // No bits left: skip the final squaring, whose result is unused. + } + base *= base; + } + + return result; +} + +// Computes a + b + carry and updates the carry. +// +// NOTE: Carry must be one or zero; on x86-64 it is handed to _addcarry_u64 as +// the incoming carry flag. +inline Bigit AddBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { +// Compilers such as GCC and Clang are known to be terrible at generating good +// code for long carry chains on Intel. Using the _addcarry_u64 intrinsic (which +// maps to the add/adc instructions) produces a tight series of add/adc/adc/adc +// instructions, whereas writing the loop manually often generates many add/adc +// pairs with spurious bit twiddling. +// +// Using the intrinsic here improves benchmarks by ~30% when summing larger +// Bignums together.
+// +// See this SO discussion for more information: +// https://stackoverflow.com/questions/33690791 +// +// Godbolt link for comparison: +// https://godbolt.org/z/cGnnfMbMn (no intrinsics) +// https://godbolt.org/z/jnM1Y3Tjs (intrinsics) +#ifdef __x86_64__ + static_assert(sizeof(Bigit) == sizeof(unsigned long long)); + Bigit out; + *carry = + _addcarry_u64(*carry, a, b, reinterpret_cast(&out)); + return out; +#else + auto sum = absl::uint128(a) + b + *carry; + *carry = absl::Uint128High64(sum); + return static_cast(sum); +#endif +} + +// Computes a - b - borrow and updates the borrow. +// +// NOTE: Borrow must be one or zero. +inline Bigit SubBigit(Bigit a, Bigit b, Bigit* absl_nonnull borrow) { + ABSL_DCHECK_LE(*borrow, Bigit(1)); + // See notes in AddBigit on why using an intrinsic is the right choice here. +#ifdef __x86_64__ + Bigit out; + *borrow = _subborrow_u64(*borrow, a, b, + reinterpret_cast(&out)); + return out; +#else + Bigit diff = a - b - *borrow; + *borrow = (a < b) || (*borrow && (a == b)); + return diff; +#endif +} + +// Computes a * b + carry and updates the carry. +// +// Returns the low half of the 128-bit result; the high half becomes the carry. +inline Bigit MulBigit(Bigit a, Bigit b, Bigit* absl_nonnull carry) { + auto sum = absl::uint128(a) * b + *carry; + *carry = absl::Uint128High64(sum); + return static_cast(sum); +} + +// Computes sum += a * b + carry and updates the carry. +// +// NOTE: Will not overflow even if a, b, *carry, and *sum are all at their +// maximum values: (2^64-1)^2 + 2*(2^64-1) == 2^128 - 1. +inline void MulAddBigit(Bigit* absl_nonnull sum, Bigit a, Bigit b, + Bigit* absl_nonnull carry) { + // Similar to the comment in AddBigit, and just for completeness, it's worth + // noting that the "best" way to implement this is with the Intel MULX, ADCQ, + // and ADOX instructions (i.e. the _mulx_u64, _addcarry_u64, and + // _addcarryx_u64 intrinsics), but GCC and Clang do not support _addcarryx_u64 + // properly (and have no plans to do so).
The issue is that gcc doesn't + // support reasoning about separate dependency chains for the carry and + // overflow flags, because all the flags are considered to be one + // register. (The ADOX instructions were added specifically for this use case, + // i.e. high-precision integer multiplies. They propagate carries using the + // overflow flag rather than the carry flag, which lets you do two + // extended-precision add operations in parallel without having them stomp on + // each other's carry flags.) + auto term = absl::uint128(a) * b + *carry + *sum; + *carry = absl::Uint128High64(term); + *sum = static_cast(term); +} + +// Computes a += b in place. Returns the final carry (if any). +// +// The A operand must have at least as many bigits as B. When adding two +// same-sized values, the result may overflow and be larger than either of them, +// in which case we will return the final carry value. +// +// This allows a work flow like this: +// Bigit carry = AddInPlace(a, b); +// if (carry) { +// a.bigits_.emplace_back(carry); +// } +// +// Rather than having to expand A to B.bigits_.size() + 1, and popping off the +// top bigit if it's unused (which is the most common case). +ABSL_ATTRIBUTE_NOINLINE inline Bigit AddInPlace(absl::Span a, + absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + + Bigit carry = 0; + + // Dispatch four at a time to help loop unrolling. + size_t i = 0; + while (i + 4 <= b.size()) { + for (int j = 0; j < 4; ++j, ++i) { + a[i] = AddBigit(a[i], b[i], &carry); + } + } + + // Finish remainder. + for (; i < b.size(); ++i) { + a[i] = AddBigit(a[i], b[i], &carry); + } + + // Propagate carry through the rest of a. + for (; carry && i < a.size(); ++i) { + a[i] = AddBigit(a[i], 0, &carry); + } + + return carry; +} + +// Computes dst = a + b out of place. Returns the number of bigits actually +// written into dst. +// +// NOTE: dst must hold at least max(a.size(), b.size()) + 1 bigits (i.e. +// it must be able to hold the carry bigit, if any).
+// +// This allows for using a pre-allocated buffer to store the result of an +// addition followed by trimming down to size: +// +// auto out = arena.Alloc(std::max(a.size(), b.size()) + 1); +// out = out.first(Add(out, a, b)); +// +// Which is used in the Karatsuba multiplication, where we don't have the option +// to expand the allocated space on demand. +inline size_t Add(absl::Span dst, absl::Span a, + absl::Span b) { + const size_t max_size = std::max(a.size(), b.size()); + const size_t min_size = std::min(a.size(), b.size()); + ABSL_DCHECK_GE(dst.size(), max_size + 1); + + // Add common parts. + Bigit carry = 0; + + // Dispatch four at a time to help loop unrolling. + size_t i = 0; + while (i + 4 <= min_size) { + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = AddBigit(a[i], b[i], &carry); + } + } + + // Finish remainder of the parts common to A and B. + for (; i < min_size; ++i) { + dst[i] = AddBigit(a[i], b[i], &carry); + } + + // Copy remaining digits from the longer operand and propagate carry. + auto longer = (a.size() > b.size()) ? a : b; + + // Dispatch four at a time for the remaining part. + const size_t size = longer.size(); + while (i + 4 <= size) { + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = AddBigit(longer[i], 0, &carry); + } + } + + // Propagate carry through the longer operand. + for (; i < size; ++i) { + dst[i] = AddBigit(longer[i], 0, &carry); + } + + // Here i == max_size, so a leftover carry becomes the extra top bigit. + if (carry) { + dst[i] = carry; + return max_size + 1; + } + + return max_size; +} + +// Computes a -= b. +// +// REQUIRES: |a| >= |b|. +inline void SubInPlace(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + ABSL_DCHECK_GE(CmpAbs(a, b), 0); + + Bigit borrow = 0; + + // Dispatch four at a time to help loop unrolling. + size_t size = b.size(); + size_t i = 0; + while (i + 4 <= size) { + for (int j = 0; j < 4; ++j, ++i) { + a[i] = SubBigit(a[i], b[i], &borrow); + } + } + + // Finish remainder of subtraction.
+ for (; i < size; ++i) { + a[i] = SubBigit(a[i], b[i], &borrow); + } + + // Propagate the borrow through a: a zero bigit wraps and keeps borrowing. + for (; borrow && i < a.size(); ++i) { + borrow = (a[i] == 0); + a[i]--; + } +} + +// Computes a = b - a. +// +// NOTE: Requires |b| >= |a|. +// +// Since we write the result to a, but b is larger, the caller must first +// zero-extend a to b.size() bigits: the loops below read b[i] for every +// i < a.size(), so the two spans have to be the same length. +inline void SubReverseInPlace(absl::Span a, absl::Span b) { + ABSL_DCHECK_GE(a.size(), b.size()); + ABSL_DCHECK_GE(CmpAbs(b, a), 0); + + Bigit borrow = 0; + + // Dispatch four at a time to help loop unrolling. + size_t size = a.size(); + size_t i = 0; + while (i + 4 <= size) { + for (int j = 0; j < 4; ++j, ++i) { + a[i] = SubBigit(b[i], a[i], &borrow); + } + } + + // Finish remainder. + for (; i < size; ++i) { + a[i] = SubBigit(b[i], a[i], &borrow); + } +} + +inline Bigit MulWithCarry(absl::Span dst, absl::Span a, + Bigit b, Bigit carry) { + ABSL_DCHECK_GE(dst.size(), a.size()); + + // Dispatch four at a time to help loop unrolling. + size_t i = 0; + while (i + 4 <= a.size()) { + for (int j = 0; j < 4; ++j, ++i) { + dst[i] = MulBigit(a[i], b, &carry); + } + } + + for (; i < a.size(); ++i) { + dst[i] = MulBigit(a[i], b, &carry); + } + + return carry; +} + +// Computes sum[i] += a[i]*b in place. +// +// Returns the final carry, if any. +inline Bigit MulAddInPlace(absl::Span sum, absl::Span a, + Bigit b) { + // Dispatch four at a time to help loop unrolling. + Bigit carry = 0; + size_t i = 0; + while (i + 4 <= a.size()) { + for (int j = 0; j < 4; ++j, ++i) { + MulAddBigit(&sum[i], a[i], b, &carry); + } + } + + // Finish remainder. + for (; i < a.size(); ++i) { + MulAddBigit(&sum[i], a[i], b, &carry); + } + + return carry; +} + +// Implements the standard grade school long multiplication algorithm. The +// output is computed by multiplying A by each digit of B and summing the +// results as we go.
This is a quadratic algorithm and only serves as the base +// case for the recursive Karatsuba algorithm below. +// +// NOTE: dst must be at least as large as the sum of the sizes of A and B. +inline void MulQuadratic(absl::Span dst, absl::Span a, + absl::Span b) { + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); + + // Make sure A is the longer of the two arguments. + if (a.size() < b.size()) { + using std::swap; + swap(a, b); + } + + if (b.empty()) { + absl::c_fill(dst, 0); + return; + } + + // Each call to MulWithCarry and MulAddInPlace only updates a.size() elements + // of dst so we manually set the carries as we go. We grab a span to the upper + // part of dst starting at a.size() to facilitate this. + auto upper = dst.subspan(a.size()); + upper[0] = MulWithCarry(dst, a, b[0], 0); + + const size_t size = b.size(); + size_t i = 1; + for (; i < size; ++i) { + upper[i] = MulAddInPlace(dst.subspan(i), a, b[i]); + } + + // Finish zeroing out the upper half. + for (; i < upper.size(); ++i) { + upper[i] = 0; + } +} + +// Split a span into two contiguous spans of length at most a and b. +// +// If span.size() <= a, the second span is empty, otherwise the second span +// has length at most b. If span.size() > a + b, then the two spans only cover +// part of the input span. +template +inline std::pair, absl::Span> Split(absl::Span span, + size_t a, size_t b) { + if (a < span.size()) { + return {span.subspan(0, a), span.subspan(a, b)}; + } + return {span.subspan(0, a), {}}; +}; + +// A simple bump allocator to allow us to very efficiently allocate temporary +// space when recursing in the Karatsuba multiply. The arena is pre-sized and +// returns spans of memory via Alloc() which are then returned to the arena via +// Reset(). +// +// NOTE: We use std::unique_ptr here instead of std::vector because we don't +// want to initialize the memory unnecessarily and using std::vector without +// resizing (and thus initializing) the container leads to false positives with +// ASAN.
+class Arena { + public: + // TODO: Use make_unique_for_overwrite when on C++20. + explicit Arena(size_t size) : size_(size), data_(new Bigit[size]) {} + + // Allocates a span of length n from the arena. The bigits in the returned + // span are not initialized. + absl::Span Alloc(size_t n) { + ABSL_DCHECK_LE(used_ + n, size_); + size_t start = used_; + used_ += n; + return absl::Span(data_.get() + start, n); + } + + // Number of bigits that can still be allocated. + size_t Available() const { return size_ - used_; } + + // Number of bigits currently allocated; pass this to Reset() later to release + // everything allocated after this point. + size_t Used() const { return used_; } + + // Resets the arena to the given position which must be <= Used(). + void Reset(size_t to) { + ABSL_DCHECK_LE(to, used_); + used_ = to; + } + + private: + size_t size_ = 0; + size_t used_ = 0; + std::unique_ptr data_; +}; + +// Returns the total arena size needed to multiply two numbers of a_size and +// b_size bigits using the recursive Karatsuba implementation. +inline size_t ArenaSize(size_t a_size, size_t b_size) { + // Each step of Karatsuba splits at: + // N = (std::max(a.size(), b.size()) + 1) / 2 + // + // We have to hold a total of 4*(N + 1) bigits as temporaries at each step. + // + // Simulate the recursion (log(n) steps) and compute the arena size. All + // arithmetic stays in size_t to avoid narrowing the operand sizes. + size_t peak = 0; + while (std::min(a_size, b_size) > kSimpleMulThreshold) { + const size_t half = (std::max(a_size, b_size) + 1) / 2; + const size_t next = half + 1; + peak += 4 * next; + a_size = next; + b_size = next; + } + return peak; +} + +// Recursive step in the Karatsuba multiplication. dst must be large enough to +// hold the product of a and b (i.e. it must be at least as large as a.size() + +// b.size()). The product is computed and stored in-place in dst. +// +// Additionally an arena must be provided for temporary storage for intermediate +// products. The arena must have at least ArenaSize(a.size(), b.size()) space +// available.
+inline void KaratsubaMulRecursive(absl::Span dst, + absl::Span a, + absl::Span b, + Arena* absl_nonnull arena) { + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); + ABSL_DCHECK_GE(arena->Available(), ArenaSize(a.size(), b.size())); + if (a.empty() || b.empty()) { + absl::c_fill(dst, 0); + return; + } + + // Remember the arena position so the temporaries can be released at the end. + const size_t arena_start = arena->Used(); + + // Karatsuba lets us represent two numbers of M bigits each, A and B, as: + // (10 stands in for the bigit base in the formulas below) + // + // A = a1*10^(M/2) + a0 + // B = b1*10^(M/2) + b0 + // + // Which we can multiply out: + // AB = (a1*10^(M/2) + a0)*(b1*10^(M/2) + b0); + // = a1*b1*10^M + (a1*b0 + a0*b1)*10^(M/2) + a0*b0 + // = z2 * 10^M + z1*10^(M/2) + z0 + // + // Where: + // z0 = a0*b0 + // z1 = a1*b0 + a0*b1 + // z2 = a1*b1 + // + // We can replace the multiplications in z1 by computing: + // + // z3 = (a0 + a1)*(b0 + b1) + // + // And noting z1 = z3 - z2 - z0 + // + // This lets us compute a 2M digit multiply with three M digit multiplies, + // with those individual multiplies able to be recursively divided. + + // Fall back to long multiplication when we're small enough. + if (std::min(a.size(), b.size()) <= kSimpleMulThreshold) { + MulQuadratic(dst, a, b); + return; + } + + const size_t half = (std::max(a.size(), b.size()) + 1) / 2; + + // Split the inputs into contiguous subspans. + auto [a0, a1] = Split(a, half, half); + auto [b0, b1] = Split(b, half, half); + + // Make space to hold results in the output and multiply sub-terms. + // z0 = a0 * b0 + // z2 = a1 * b1 + auto [z0, z2] = Split(dst, a0.size() + b0.size(), a1.size() + b1.size()); + KaratsubaMulRecursive(z0, a0, b0, arena); + KaratsubaMulRecursive(z2, a1, b1, arena); + + // Compute (a0 + a1) and (b0 + b1) + // + // If the upper terms are zero we can just re-use the terms we have, otherwise + // we compute the sum and pop off the MSB bigit if no carry occurred.
+ absl::Span asum = a0; + if (!a1.empty()) { + absl::Span tmp = arena->Alloc(half + 1); + asum = tmp.first(Add(tmp, a0, a1)); + } + + absl::Span bsum = b0; + if (!b1.empty()) { + absl::Span tmp = arena->Alloc(half + 1); + bsum = tmp.first(Add(tmp, b0, b1)); + } + + // Compute z1 = asum*bsum - z0 - z2 = (a0 + a1)*(b0 + b1) - z0 - z2 + auto z1 = arena->Alloc(asum.size() + bsum.size()); + + // Compute asum * bsum into the beginning of z1 + KaratsubaMulRecursive(z1, asum, bsum, arena); + + // NOTE: (a0 + a1) * (b0 + b1) >= a0*b0 + a1*b1 so this never underflows. + // z2 is zero whenever a1 or b1 is empty, so that subtraction is skipped. + SubInPlace(z1, z0); + if (!a1.empty() && !b1.empty()) { + SubInPlace(z1, z2); + } + + // We need to add z1*10^half, which we can do by simply adding z1 at a shifted + // position in the output. + auto dst_z1 = dst.subspan(half); + + // Although the value of z1 is guaranteed to fit in the available space of + // dst, it may have one or more high-order zero bigits because it was sized + // conservatively to hold the intermediate result (asum * bsum). We trim these + // leading zeros if necessary to ensure that the AddInPlace() call below does + // not attempt to write zero bigits past the end of dst. + AddInPlace(dst_z1, z1.first(std::min(z1.size(), dst_z1.size()))); + + // Release temporary memory we used. + arena->Reset(arena_start); +} + +// Multiplies two unsigned bigit vectors together using Karatsuba's algorithm. +// +// This algorithm recursively subdivides the inputs until one or both is below +// some threshold, and then falls back to standard long multiplication.
+void KaratsubaMul(absl::Span dst, absl::Span a, + absl::Span b) { + ABSL_DCHECK_GE(dst.size(), a.size() + b.size()); + if (a.empty() || b.empty()) { + absl::c_fill(dst, 0); + return; + } + + Arena arena(ArenaSize(a.size(), b.size())); + KaratsubaMulRecursive(dst, a, b, &arena); +} + +Bignum& Bignum::operator+=(const Bignum& b) { + if (b.is_zero()) { + return *this; + } + + if (is_zero()) { + *this = b; + return *this; + } + + if (is_negative() == b.is_negative()) { + // Same sign: + // +|a| + +|b| == +(|a| + |b|) + // -|a| + -|b| == -(|a| + |b|) + // + // So we can just sum magnitudes, final sign is the same as A. + bigits_.resize(std::max(bigits_.size(), b.bigits_.size()), 0); + Bigit carry = AddInPlace(absl::MakeSpan(bigits_), b.bigits_); + if (carry) { + bigits_.emplace_back(carry); + } + } else { + // We know the signs are different, so there are two options: + // -|a| + +|b| = ?(|b| - |a|) + // +|a| + -|b| = ?(|a| - |b|) + // + // With the final sign being dependent on how |a| and |b| relate. + if (CmpAbs(bigits_, b.bigits_) >= 0) { + // |a| >= |b| + // -|a| + +|b| --> -(|a| - |b|) + // +|a| + -|b| --> +(|a| - |b|) + // + // So we can subtract magnitudes, final sign is the same as A. + SubInPlace(absl::MakeSpan(bigits_), b.bigits_); + } else { + // |a| < |b| + // -|a| + +|b| --> +(|b| - |a|) + // +|a| + -|b| --> -(|b| - |a|) + // + // So we can compute |b| - |a| and the final sign is the same as B. + bigits_.resize(b.bigits_.size()); + SubReverseInPlace(absl::MakeSpan(bigits_), b.bigits_); + negative_ = b.is_negative(); + } + } + + Normalize(); + return *this; +} + +Bignum& Bignum::operator-=(const Bignum& b) { + // a - a is zero. This has to be special-cased: when b aliases *this, the + // negate() calls below would flip the sign of b as well. + if (this == &b) { + set_zero(); + return *this; + } + + // Compute -(-a + b) == a - b + negate(); + *this += b; + negate(); + return *this; +} + +Bignum& Bignum::operator*=(const Bignum& b) { + if (is_zero() || b.is_zero()) { + return set_zero(); + } + + // Result is only negative if signs are different.
+ const bool negative = (is_negative() != b.is_negative()); + + // Fast path for single-bigit multiplication. + if (bigits_.size() == 1 && b.bigits_.size() == 1) { + absl::uint128 prod = absl::uint128(bigits_[0]) * b.bigits_[0]; + const uint64_t lo = absl::Uint128Low64(prod); + const uint64_t hi = absl::Uint128High64(prod); + if (hi == 0) { + bigits_ = {lo}; + } else { + bigits_ = {lo, hi}; + } + set_negative(negative); + return *this; + } + + // Use Karatsuba multiplication. + // If the inputs are small enough this will just do long multiplication. + BigitVector result; + result.resize(bigits_.size() + b.bigits_.size()); + KaratsubaMul(absl::MakeSpan(result), bigits_, b.bigits_); + bigits_ = std::move(result); + + negative_ = negative; + Normalize(); + return *this; +} + +} // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/bignum.h b/src/s2/util/math/exactfloat/bignum.h new file mode 100644 index 00000000..c85bb8ea --- /dev/null +++ b/src/s2/util/math/exactfloat/bignum.h @@ -0,0 +1,383 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ +#define S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/inlined_vector.h" +#include "absl/log/absl_check.h" +#include "absl/numeric/bits.h" +#include "absl/numeric/int128.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_format.h" + +namespace exactfloat_internal { + +// A digit of a bignum. A contraction of "big digit" (rhymes with the latter). +using Bigit = uint64_t; + +// A class to support arithmetic on large, arbitrary precision integers. +// +// Large integers are represented as an array of uint64_t values. +class Bignum { + public: + // The most common use of ExactFloat involves evaluating a 3x3 determinant to + // determine whether 3 points are oriented clockwise or counter-clockwise. + // + // The typical number of mantissa bits in the result is probably about 170, so + // we allocate 4 bigits (256 bits) inline. + using BigitVector = absl::InlinedVector; + + static constexpr int kBigitBits = std::numeric_limits::digits; + + Bignum() = default; + + // Constructs a bignum from an integral value (signed or unsigned). + template ::is_integer>> + explicit Bignum(T value); + + //-------------------------------------- + // String formatting and parsing. + //-------------------------------------- + + // Constructs a bignum from an ASCII string containing decimal digits. + // + // The input string must only have an optional leading +/- and decimal digits. + // Any other characters will yield std::nullopt. + static std::optional FromString(absl::string_view s); + + // Formats the bignum as a decimal integer into an abseil sink. 
+ template + friend void AbslStringify(Sink& sink, const Bignum& b); + + friend std::ostream& operator<<(std::ostream& os, const Bignum& b) { + return os << absl::StrFormat("%v", b); + } + + friend std::ostream& operator<<(std::ostream& os, + const std::optional& b) { + if (!b) { + return os << "[nullopt]"; + } + return os << *b; + } + + //-------------------------------------- + // Casting and coercion. + //-------------------------------------- + + // Returns true if bignum can be stored in T without truncation or overflow. + template >> + bool FitsIn() const; + + // Cast to an integral type T unconditionally. Use FitsIn() or + // ConvertTo to perform conversion with bounds checking. + template >> + T Cast() const; + + // Casts the value to the given type if it fits, otherwise std::nullopt. + template >> + std::optional ConvertTo() const { + if (!FitsIn()) { + return std::nullopt; + } + return Cast(); + } + + //-------------------------------------- + // General accessors. + //-------------------------------------- + + // Returns the number of bits required for the magnitude of the value. + // + // Named to match std::bit_width. + friend int bit_width(const Bignum& a); + + // Returns the number of consecutive 0 bits in the value, starting from the + // least significant bit. + // + // Named to match std::countr_zero. + friend int countr_zero(const Bignum& a); + + // Returns true if the n-th bit of the number's magnitude is 1. + bool is_bit_set(int nbit) const; + + // Clears this bignum and sets it to zero. + Bignum& set_zero() { + negative_ = false; + bigits_.clear(); + return *this; + } + + // Sets the negative flag on the value. If the value is zero, has no effect. + Bignum& set_negative(bool negative = true) { + negative_ = !bigits_.empty() && negative; + return *this; + } + + // Returns true if the number is zero. + bool is_zero() const { return bigits_.empty(); } + + // Returns true if the number is less than zero. 
+ bool is_negative() const { return negative_; } + + // Returns true if the number is odd (least significant bit is 1). + bool is_odd() const { return is_bit_set(0); } + + // Returns true if the number is even (least significant bit is 0). + bool is_even() const { return !is_odd(); } + + //-------------------------------------- + // Comparisons. + //-------------------------------------- + + bool operator==(const Bignum& b) const { + return negative_ == b.negative_ && bigits_ == b.bigits_; + } + + bool operator!=(const Bignum& b) const { return !(*this == b); } + bool operator<(const Bignum& b) const { return Compare(b) < 0; } + bool operator<=(const Bignum& b) const { return Compare(b) <= 0; } + bool operator>(const Bignum& b) const { return Compare(b) > 0; } + bool operator>=(const Bignum& b) const { return Compare(b) >= 0; } + + //-------------------------------------- + // Arithmetic operators. + //-------------------------------------- + + Bignum operator+() const { return *this; } + Bignum operator-() const; + + // Raise this value to the given power, which must be non-negative. + Bignum Pow(int32_t pow) const; + + Bignum& operator+=(const Bignum& b); + Bignum& operator-=(const Bignum& b); + Bignum& operator*=(const Bignum& b); + Bignum& operator<<=(int nbit); + Bignum& operator>>=(int nbit); + + friend Bignum operator+(Bignum a, const Bignum& b) { return a += b; } + friend Bignum operator-(Bignum a, const Bignum& b) { return a -= b; } + friend Bignum operator*(Bignum a, const Bignum& b) { return a *= b; } + friend Bignum operator<<(Bignum a, int nbit) { return a <<= nbit; } + friend Bignum operator>>(Bignum a, int nbit) { return a >>= nbit; } + + // Negates this bignum in place. + void negate() { + negative_ = !negative_; + if (bigits_.empty()) { + negative_ = false; + } + } + + // Compares to another bignum, returning -1, 0, +1. + int Compare(const Bignum& b) const; + + private: + // Constructs a Bignum from bigits and an optional sign bit. 
+ explicit Bignum(BigitVector bigits, bool negative = false) + : bigits_(std::move(bigits)), negative_(negative) { + Normalize(); + } + + // Drop leading zero bigits, and ensure sign is positive if result is zero. + void Normalize() { + while (!bigits_.empty() && bigits_.back() == 0) { + bigits_.pop_back(); + } + + if (bigits_.empty()) { + negative_ = false; + } + } + + // We store bignums in sign-magnitude form. bigits_ contains the individual + // 64-bit digits of the bignum, stored in little-endian order. If bigits_ is + // non-empty, then the last element must be non-zero and when it is empty + // (representing a zero value), negative_ must be false. + BigitVector bigits_; + bool negative_ = false; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Implementation Details +//////////////////////////////////////////////////////////////////////////////// + +// Constructs a bignum from an integral value (signed or unsigned). +template +Bignum::Bignum(T value) { + using UT = std::make_unsigned_t; + + if (value == 0) { + return; + } + + negative_ = false; + if constexpr (std::is_signed_v) { + // Put into constexpr if to avoid warnings when T is unsigned. + negative_ = (value < 0); + } + + // Get magnitude of value, handle minimum value of T cleanly. + UT mag = static_cast(value); + if constexpr (std::is_signed_v) { + if (value < 0) { + mag = UT(0) - mag; + } + } + + // Pack the magnitude into bigits. + if constexpr (std::numeric_limits::digits <= kBigitBits) { + bigits_.push_back(static_cast(mag)); + } else { + while (mag) { + bigits_.push_back(static_cast(mag)); + mag >>= kBigitBits; + } + } +} + +// Formats the bignum as a decimal integer into an abseil sink. +template +void AbslStringify(Sink& sink, const Bignum& b) { + if (b.is_zero()) { + sink.Append("0"); + return; + } + + // Sign + if (b.is_negative()) { + sink.Append("-"); + } + + // Work on a copy of the magnitude. 
+ Bignum copy = b; + copy.negative_ = false; + + // 10**19 is the largest power of 10 that fits in 64-bits. So we can + // repeatedly divide and modulo the bignum to get uint64_t values we can + // format as 19 decimal digits. + static_assert(sizeof(Bigit) * 8 == 64); + static constexpr uint64_t kChunkDivisor = 10'000'000'000'000'000'000u; + Bignum::BigitVector chunks; + + while (!copy.is_zero()) { + absl::uint128 rem = 0; + for (int i = static_cast(copy.bigits_.size()) - 1; i >= 0; --i) { + absl::uint128 acc = (rem << 64) + copy.bigits_[i]; + Bigit quot = static_cast(acc / kChunkDivisor); + rem = acc - absl::uint128(quot) * kChunkDivisor; + copy.bigits_[i] = quot; + } + + copy.Normalize(); + chunks.push_back(static_cast(rem)); + } + ABSL_DCHECK(!chunks.empty()); + + // Emit most significant chunk without zero padding. + absl::Format(&sink, "%d", chunks.back()); + + // Emit remaining chunks as fixed-width 19-digit zero-padded blocks. + for (int i = static_cast(chunks.size()) - 2; i >= 0; --i) { + absl::Format(&sink, "%019d", chunks[i]); + } +} + +template +inline bool Bignum::FitsIn() const { + using UT = std::make_unsigned_t; + + if (is_zero()) { + return true; + } + + // Maximum number of bits that could fit in the output type. + constexpr int kTBitWidth = std::numeric_limits::digits; + constexpr int kMaxBigits = (kTBitWidth + (kBigitBits - 1)) / kBigitBits; + + // Fast reject if the bignum couldn't conceivably fit. + if (bigits_.size() > kMaxBigits) { + return false; + } + + // Unsigned type T can hold the value iff the value is non-negative and + // the bitwidth is <= the maximum bit width of the type. + if constexpr (!std::is_signed_v) { + if (is_negative()) { + return false; + } + return bit_width(*this) <= kTBitWidth; + } + + // T is signed and our bignum isn't zero. + ABSL_DCHECK(std::is_signed_v && !is_zero()); + + if (is_negative()) { + // Magnitude must fit in negative value. 
If the value is negative and + // the same bit width as the output type, the only valid value is + // -2^(k-1). + if (bit_width(*this) == kTBitWidth) { + return countr_zero(*this) == kTBitWidth - 1; + } + return bit_width(*this) < kTBitWidth; + } else /* positive */ { + return bit_width(*this) <= (kTBitWidth - 1); + } +} + +template +T Bignum::Cast() const { + using UT = std::make_unsigned_t; + + constexpr int kTBitWidth = std::numeric_limits::digits; + + if (bigits_.empty()) { + return 0; + } + + // Keep the low kTBitWidth bits; bigits past size() are implicitly zero. + UT residue = 0; + if (kTBitWidth <= kBigitBits) { + residue = static_cast(bigits_[0]); + } else { + ABSL_DCHECK_EQ(kTBitWidth % kBigitBits, 0); + for (int i = 0; i < kTBitWidth / kBigitBits; ++i) { + if (i >= static_cast<int>(bigits_.size())) break; + residue |= UT(bigits_[i]) << (i * kBigitBits); + } + } + + // Compute two's complement of the residue if value is negative. + if (is_negative()) { + residue = UT(0) - residue; + } + + return static_cast(residue); +} + +} // namespace exactfloat_internal + +#endif // S2_UTIL_MATH_EXACTFLOAT_BIGNUM_H_ diff --git a/src/s2/util/math/exactfloat/bignum_test.cc b/src/s2/util/math/exactfloat/bignum_test.cc new file mode 100644 index 00000000..038edd4b --- /dev/null +++ b/src/s2/util/math/exactfloat/bignum_test.cc @@ -0,0 +1,1231 @@ +// Copyright 2025 Google LLC +// Author: smcallis@google.com (Sean McAllister) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "s2/util/math/exactfloat/bignum.h" + +#include +#include +#include +#include +#include +#include +#include + +// TODO: remove once benchmarks are available +#if 0 +#include "benchmark/benchmark.h" +#endif + +#include "absl/log/log_streamer.h" +#include "absl/random/bit_gen_ref.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "gtest/gtest.h" +#include "openssl/bn.h" +#include "openssl/crypto.h" +#include "s2/s2testing.h" + +namespace exactfloat_internal { + +using ::testing::TestWithParam; + +constexpr uint64_t kU8max = std::numeric_limits::max(); +constexpr uint64_t kU16max = std::numeric_limits::max(); +constexpr uint64_t kU32max = std::numeric_limits::max(); +constexpr uint64_t kU64max = std::numeric_limits::max(); + +constexpr int64_t kI8max = std::numeric_limits::max(); +constexpr int64_t kI16max = std::numeric_limits::max(); +constexpr int64_t kI32max = std::numeric_limits::max(); +constexpr int64_t kI64max = std::numeric_limits::max(); + +constexpr int64_t kI8min = std::numeric_limits::min(); +constexpr int64_t kI16min = std::numeric_limits::min(); +constexpr int64_t kI32min = std::numeric_limits::min(); +constexpr int64_t kI64min = std::numeric_limits::min(); + +// To reduce duplication. 
+inline auto Bn(absl::string_view str) { return Bignum::FromString(str); }; + +TEST(BignumTest, ZeroInputsNormalizeToZero) { + EXPECT_EQ(Bn(""), Bignum(0)); + EXPECT_EQ(Bn("+"), Bignum(0)); + EXPECT_EQ(Bn("-"), Bignum(0)); + EXPECT_EQ(Bn("0"), Bignum(0)); + EXPECT_EQ(Bn("-0"), Bignum(0)); + EXPECT_EQ(Bn("000000"), Bignum(0)); + EXPECT_EQ(Bn("-000000"), Bignum(0)); + EXPECT_EQ(Bn("-00000000000000000000"), Bignum(0)); + EXPECT_EQ(Bn("+00000000000000000000"), Bignum(0)); +} + +TEST(BignumTest, BasicSmallNumbers) { + EXPECT_EQ(Bn("42"), Bignum(42)); + EXPECT_EQ(Bn("-17"), Bignum(-17)); + EXPECT_EQ(Bn("+123"), Bignum(123)); +} + +TEST(BignumTest, LeadingZeros) { + EXPECT_EQ(Bn("0000042"), Bignum(42)); + EXPECT_EQ(Bn("-00017"), Bignum(-17)); + EXPECT_EQ(Bn("+00007"), Bignum(7)); + + // Larger bignums. + // 10^17 and 10^17 - 1, near chunk boundaries. + EXPECT_EQ(Bn("100000000000000000"), Bn("000100000000000000000")); + EXPECT_EQ(Bn("99999999999999999"), Bn("000099999999999999999")); + + // 2^64 - 1 cross-check against integral constructor. + EXPECT_EQ(Bn("18446744073709551615"), Bignum(18446744073709551615ull)); + + // 2^64 cross-check with leading zeros variant via string constructor. + EXPECT_EQ(Bn("18446744073709551616"), Bn("00018446744073709551616")); + EXPECT_EQ(Bn("-18446744073709551616"), Bn("-00018446744073709551616")); + + // Multi-digit bignums.
+ // 2^80 + EXPECT_EQ(Bn("1208925819614629174706176"), + Bn("0001208925819614629174706176")); + EXPECT_EQ(Bn("-1208925819614629174706176"), + Bn("-0001208925819614629174706176")); + + // 2^128 + EXPECT_EQ(Bn("340282366920938463463374607431768211456"), + Bn("000340282366920938463463374607431768211456")); + EXPECT_EQ(Bn("-340282366920938463463374607431768211456"), + Bn("-000340282366920938463463374607431768211456")); +} + +TEST(BignumTest, MultipleSignsBeforeDigitsCausesFailure) { + EXPECT_EQ(Bn("++123"), std::nullopt); + EXPECT_EQ(Bn("--42"), std::nullopt); + EXPECT_EQ(Bn("+-9"), std::nullopt); + EXPECT_EQ(Bn("-+9"), std::nullopt); +} + +TEST(BignumTest, SignInWrongPlaceCausesFailure) { + EXPECT_EQ(Bn("123-"), std::nullopt); + EXPECT_EQ(Bn("456+"), std::nullopt); + EXPECT_EQ(Bn("-789+"), std::nullopt); + EXPECT_EQ(Bn("+314-"), std::nullopt); +} + +TEST(BignumTest, ZeroAlwaysFitsIn) { + const Bignum zero(0); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); + EXPECT_TRUE(zero.FitsIn()); +} + +TEST(BignumTest, ZeroAlwaysCastsToZero) { + Bignum zero; + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); + EXPECT_EQ(zero.Cast(), 0); +} + +TEST(BignumTest, NegativeOnlyFitsInSigned) { + const Bignum small_neg(-1); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); + EXPECT_FALSE(small_neg.FitsIn()); + + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); + EXPECT_TRUE(small_neg.FitsIn()); +} + +TEST(BignumTest, FitsInUnsignedBoundsChecks) { + const Bignum bn_u8max(kU8max); + const Bignum bn_u8over(kU8max + 1); + EXPECT_TRUE(bn_u8max.FitsIn()); + 
EXPECT_TRUE(bn_u8max.FitsIn()); + EXPECT_TRUE(bn_u8max.FitsIn()); + EXPECT_FALSE(bn_u8over.FitsIn()); + EXPECT_TRUE(bn_u8over.FitsIn()); + EXPECT_TRUE(bn_u8over.FitsIn()); + + const Bignum bn_u16max(kU16max); + const Bignum bn_u16over(kU16max + 1); + EXPECT_FALSE(bn_u16max.FitsIn()); + EXPECT_TRUE(bn_u16max.FitsIn()); + EXPECT_TRUE(bn_u16max.FitsIn()); + EXPECT_FALSE(bn_u16over.FitsIn()); + EXPECT_FALSE(bn_u16over.FitsIn()); + EXPECT_TRUE(bn_u16over.FitsIn()); + + const Bignum bn_u32max(kU32max); + const Bignum bn_u32over(kU32max + 1); + EXPECT_FALSE(bn_u32max.FitsIn()); + EXPECT_FALSE(bn_u32max.FitsIn()); + EXPECT_TRUE(bn_u32max.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); + EXPECT_FALSE(bn_u32over.FitsIn()); + + const Bignum bn_u64max(kU64max); + EXPECT_TRUE(bn_u64max.FitsIn()); + + // 2^64, need to use string constructor. + Bignum bn0 = *Bn("18446744073709551616"); + EXPECT_FALSE(bn0.FitsIn()); +} + +TEST(BignumTest, FitsInSignedBoundsChecks) { + const Bignum bn_i8max(kI8max); + const Bignum bn_i8over(kI8max + 1); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_TRUE(bn_i8max.FitsIn()); + EXPECT_FALSE(bn_i8over.FitsIn()); + EXPECT_TRUE(bn_i8over.FitsIn()); + EXPECT_TRUE(bn_i8over.FitsIn()); + + const Bignum bn_i16max(kI16max); + const Bignum bn_i16over(kI16max + 1); + EXPECT_FALSE(bn_i16max.FitsIn()); + EXPECT_TRUE(bn_i16max.FitsIn()); + EXPECT_TRUE(bn_i16max.FitsIn()); + EXPECT_FALSE(bn_i16over.FitsIn()); + EXPECT_FALSE(bn_i16over.FitsIn()); + EXPECT_TRUE(bn_i16over.FitsIn()); + + const Bignum bn_i32max(kI32max); + const Bignum bn_i32over(kI32max + 1); + EXPECT_FALSE(bn_i32max.FitsIn()); + EXPECT_FALSE(bn_i32max.FitsIn()); + EXPECT_TRUE(bn_i32max.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); + EXPECT_FALSE(bn_i32over.FitsIn()); + + Bignum bn_i64max(kI64max); + EXPECT_TRUE(bn_i64max.FitsIn()); + + // 2^63, need to use string constructor. 
+ Bignum bn0 = *Bn("9223372036854775808"); + EXPECT_FALSE(bn0.FitsIn()); + + const Bignum bn_i8min(kI8min); + const Bignum bn_i8under(kI8min - 1); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_TRUE(bn_i8min.FitsIn()); + EXPECT_FALSE(bn_i8under.FitsIn()); + EXPECT_TRUE(bn_i8under.FitsIn()); + EXPECT_TRUE(bn_i8under.FitsIn()); + + const Bignum bn_i16min(kI16min); + const Bignum bn_i16under(kI16min - 1); + EXPECT_FALSE(bn_i16min.FitsIn()); + EXPECT_TRUE(bn_i16min.FitsIn()); + EXPECT_TRUE(bn_i16min.FitsIn()); + EXPECT_FALSE(bn_i16under.FitsIn()); + EXPECT_FALSE(bn_i16under.FitsIn()); + EXPECT_TRUE(bn_i16under.FitsIn()); + + const Bignum bn_i32min(kI32min); + const Bignum bn_i32under(kI32min - 1); + EXPECT_FALSE(bn_i32min.FitsIn()); + EXPECT_FALSE(bn_i32min.FitsIn()); + EXPECT_TRUE(bn_i32min.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); + EXPECT_FALSE(bn_i32under.FitsIn()); + + Bignum bn_i64min(kI64min); + EXPECT_TRUE(bn_i64min.FitsIn()); +} + +TEST(BignumTest, FitsInBasicSanityChecks) { + Bignum pos42(42); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + EXPECT_TRUE(pos42.FitsIn()); + + Bignum neg42(-42); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); + EXPECT_TRUE(neg42.FitsIn()); + EXPECT_FALSE(neg42.FitsIn()); +} + +TEST(BignumTest, UnsignedCasting) { + EXPECT_EQ(Bignum(300).Cast(), static_cast(300)); + + // Negative to unsigned -> maximum value. 
+ Bignum bn1(-1); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + + // Big values via decimal: 2^64, 2^128-1, 2^128 + Bignum bn2 = *Bn("18446744073709551616"); + EXPECT_EQ(bn2.Cast(), 0); + + // 2^128 - 1 -> lower 64 bits = 2^64 - 1 + Bignum bn3 = *Bn("340282366920938463463374607431768211455"); + EXPECT_EQ(bn3.Cast(), std::numeric_limits::max()); + + // 2^128 -> lower 64 bits = 0 + Bignum bn4 = *Bn("340282366920938463463374607431768211456"); + EXPECT_EQ(bn4.Cast(), 0); +} + +TEST(BignumTest, SignedCasting) { + // In-range positives stay the same. + Bignum bn0(127); + EXPECT_EQ(bn0.Cast(), 127); + + // Positive overflow wraps into negative range. + Bignum bn1(128); + EXPECT_EQ(bn1.Cast(), -128); + + // In-range negatives stay the same. + Bignum bn2(-128); + EXPECT_EQ(bn2.Cast(), -128); + + // Negative overflow wraps into positive range. + Bignum bn3(-129); + EXPECT_EQ(bn3.Cast(), 127); + + // +2^63 over to -2^63 in signed int64. + Bignum bn4 = *Bn("9223372036854775808"); + EXPECT_EQ(bn4.Cast(), std::numeric_limits::min()); + + // +2^64 - 1 casts to -1. + Bignum bn5 = *Bn("18446744073709551615"); + EXPECT_EQ(bn5.Cast(), -1); + + // -(2^63) - 1 casts to 2^63 - 1 + Bignum bn6 = *Bn("-9223372036854775809"); + EXPECT_EQ(bn6.Cast(), std::numeric_limits::max()); +} + +TEST(BignumTest, CastingLargeResidues) { + // 2^80 + 0x1234 -> low 64 bits should be 0x1234.
+ Bignum bn0 = *Bn("1208925819614629174710836"); + EXPECT_EQ(bn0.Cast(), 0x1234); + EXPECT_EQ(bn0.Cast(), 0x1234); + + // -(2^80 + 1) -> low 64 bits = 0xFFFFFFFFFFFFFFFF; signed = -1 + Bignum bn1 = *Bn("-1208925819614629174706177"); + EXPECT_EQ(bn1.Cast(), std::numeric_limits::max()); + EXPECT_EQ(bn1.Cast(), -1); +} + +TEST(BignumTest, UnaryOperators) { + EXPECT_EQ(+Bignum(0), Bignum(0)); + EXPECT_EQ(-Bignum(0), Bignum(0)); + EXPECT_EQ(+Bignum(+42), Bignum(+42)); + EXPECT_EQ(+Bignum(-42), Bignum(-42)); + EXPECT_EQ(-Bignum(+42), Bignum(-42)); + EXPECT_EQ(-Bignum(-17), Bignum(+17)); +} + +TEST(BignumTest, Addition) { + // Basic combinations of signs + EXPECT_EQ(Bignum(+5) + Bignum(+3), Bignum(+8)); + EXPECT_EQ(Bignum(-5) + Bignum(-3), Bignum(-8)); + EXPECT_EQ(Bignum(+5) + Bignum(-3), Bignum(+2)); + EXPECT_EQ(Bignum(-5) + Bignum(+3), Bignum(-2)); + + // Identity and additive inverse + EXPECT_EQ(Bignum(42) + Bignum(0), Bignum(42)); + EXPECT_EQ(Bignum(0) + Bignum(42), Bignum(42)); + EXPECT_EQ(Bignum(5) + Bignum(-5), Bignum(0)); + + // Carry propagation + const auto bn_u64max = Bignum(kU64max); + EXPECT_EQ(bn_u64max + Bignum(1), *Bn("18446744073709551616")); + EXPECT_EQ(bn_u64max + bn_u64max, *Bn("36893488147419103230")); + + // Aliasing (x += x) + Bignum a = bn_u64max; + a += a; + EXPECT_EQ(a, *Bn("36893488147419103230")); +} + +TEST(BignumTest, Subtraction) { + // Basic combinations of signs + EXPECT_EQ(Bignum(+5) - Bignum(+3), Bignum(+2)); + EXPECT_EQ(Bignum(+3) - Bignum(+5), Bignum(-2)); + EXPECT_EQ(Bignum(-5) - Bignum(-3), Bignum(-2)); + EXPECT_EQ(Bignum(+5) - Bignum(-3), Bignum(+8)); + EXPECT_EQ(Bignum(-5) - Bignum(+3), Bignum(-8)); + + // Identity and subtracting to zero + EXPECT_EQ(Bignum(42) - Bignum(0), Bignum(42)); + EXPECT_EQ(Bignum(0) - Bignum(42), Bignum(-42)); + EXPECT_EQ(Bignum(42) - Bignum(42), Bignum(0)); + + // Borrow propagation + const auto bn_u64max = Bignum(kU64max); + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_EQ(two_pow_64 - 
Bignum(1), bn_u64max); + + // Aliasing (x -= x) + Bignum a(100); + a -= a; + EXPECT_EQ(a, Bignum(0)); +} + +TEST(BignumTest, MixedOperations) { + Bignum a(10), b(20), c(-5); + EXPECT_EQ((a + b) - a, b); + EXPECT_EQ((b - a) + a, b); + EXPECT_EQ(a + c, Bignum(5)); + EXPECT_EQ(c - a, Bignum(-15)); +} + +TEST(BignumTest, LargeNumberArithmetic) { + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + const auto two_pow_128 = *Bn("340282366920938463463374607431768211456"); + const auto two_pow_64 = *Bn("18446744073709551616"); + + // Test multi-bigit carry propagation: (2^128 - 1) + 1 = 2^128 + EXPECT_EQ(two_pow_128_minus_1 + Bignum(1), two_pow_128); + + // Test multi-bigit borrow propagation: 2^128 - 1 = (2^128 - 1) + EXPECT_EQ(two_pow_128 - Bignum(1), two_pow_128_minus_1); + + // Subtraction resulting in a sign change with large numbers. + const auto neg_two_pow_128_minus_1 = + *Bn("-340282366920938463463374607431768211455"); + EXPECT_EQ(Bignum(1) - two_pow_128, neg_two_pow_128_minus_1); + + // Addition of large numbers with different signs (triggers subtraction). + EXPECT_EQ(two_pow_128 + Bignum(-1), two_pow_128_minus_1); + + // Subtraction of large numbers with different signs (triggers addition). 
+ EXPECT_EQ(two_pow_128_minus_1 - Bignum(-1), two_pow_128); + + // Add two different large positive numbers: 2^128 + 2^64 + const auto sum_128_64 = *Bn("340282366920938463481821351505477763072"); + EXPECT_EQ(two_pow_128 + two_pow_64, sum_128_64); + + // Subtract two different large positive numbers: 2^128 - 2^64 + const auto diff_128_64 = *Bn("340282366920938463444927863358058659840"); + EXPECT_EQ(two_pow_128 - two_pow_64, diff_128_64); +} + +TEST(BignumTest, LeftShift) { + EXPECT_EQ((Bignum(1) << 0), Bignum(1)); + EXPECT_EQ((Bignum(1) << 1), Bignum(2)); + EXPECT_EQ((Bignum(1) << 63), *Bn("9223372036854775808")); + EXPECT_EQ((Bignum(1) << 64), *Bn("18446744073709551616")); + EXPECT_EQ((Bignum(-1) << 64), *Bn("-18446744073709551616")); + + const auto bn_u64max = Bignum(kU64max); + const auto two_pow_128_minus_two_pow_64 = + *Bn("340282366920938463444927863358058659840"); + EXPECT_EQ((bn_u64max << 64), two_pow_128_minus_two_pow_64); + + Bignum a(5); + a <<= 2; + EXPECT_EQ(a, Bignum(20)); + + // Shifting zero or by zero amount. + EXPECT_EQ((Bignum(0) << 100), Bignum(0)); + EXPECT_EQ((Bignum(123) << 0), Bignum(123)); +} + +TEST(BignumTest, RightShift) { + EXPECT_EQ((Bignum(8) >> 0), Bignum(8)); + EXPECT_EQ((Bignum(8) >> 3), Bignum(1)); + EXPECT_EQ((Bignum(8) >> 4), Bignum(0)); + EXPECT_EQ((Bignum(7) >> 2), Bignum(1)); + EXPECT_EQ((Bignum(-7) >> 2), Bignum(-1)); + + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_EQ((two_pow_64 >> 1), *Bn("9223372036854775808")); + EXPECT_EQ((two_pow_64 >> 64), Bignum(1)); + EXPECT_EQ((two_pow_64 >> 65), Bignum(0)); + + const auto u64max = Bignum(std::numeric_limits::max()); + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + EXPECT_EQ((two_pow_128_minus_1 >> 64), u64max); + + const auto two_pow_65_minus_1 = *Bn("36893488147419103231"); + EXPECT_EQ((two_pow_65_minus_1 >> 63), Bignum(3)); + + Bignum b(20); + b >>= 2; + EXPECT_EQ(b, Bignum(5)); + + // Shifting zero or by zero amount. 
+ EXPECT_EQ((Bignum(0) >> 100), Bignum(0)); + EXPECT_EQ((Bignum(123) >> 0), Bignum(123)); + + // Shifting to zero. + EXPECT_EQ((Bignum(100) >> 100), Bignum(0)); +} + +TEST(BignumTest, Multiplication) { + // Zero + EXPECT_EQ(Bignum(123) * Bignum(0), Bignum(0)); + EXPECT_EQ(Bignum(0) * Bignum(456), Bignum(0)); + + // Identity + EXPECT_EQ(Bignum(123) * Bignum(1), Bignum(123)); + EXPECT_EQ(Bignum(1) * Bignum(456), Bignum(456)); + EXPECT_EQ(Bignum(-123) * Bignum(1), Bignum(-123)); + + // Signs + EXPECT_EQ(Bignum(10) * Bignum(20), Bignum(200)); + EXPECT_EQ(Bignum(-10) * Bignum(20), Bignum(-200)); + EXPECT_EQ(Bignum(10) * Bignum(-20), Bignum(-200)); + EXPECT_EQ(Bignum(-10) * Bignum(-20), Bignum(200)); + + // Simple carry + const auto bn_u32max = Bignum(kU32max); + EXPECT_EQ(bn_u32max * Bignum(2), *Bn("8589934590")); + + // 1x1 bigit fast path + const auto bn_u64max = Bignum(kU64max); + EXPECT_EQ(Bignum(2) * bn_u64max, *Bn("36893488147419103230")); + + // 1xN bigit multiplication + const auto two_pow_128_minus_1 = + *Bn("340282366920938463463374607431768211455"); + const auto res_1xN = *Bn("680564733841876926926749214863536422910"); + EXPECT_EQ(two_pow_128_minus_1 * Bignum(2), res_1xN); + + // Check that aliasing doesn't cause problems. + Bignum a(100); + a *= a; + EXPECT_EQ(a, Bignum(10000)); + + Bignum b = *Bn("10000000000000000000"); // > 64 bits + b *= b; + EXPECT_EQ(b, *Bn("100000000000000000000000000000000000000")); + + // Karatsuba threshold test + // (2^128 - 1) * (2^128 - 1) = 2^256 - 2*2^128 + 1 + // This should trigger karatsuba if the threshold is low enough. 
+ EXPECT_EQ(two_pow_128_minus_1 * two_pow_128_minus_1, + *Bn("115792089237316195423570985008687907852589419931798687112530" + "834793049593217025")); + + // Karatsuba with uneven operands + EXPECT_EQ(two_pow_128_minus_1 * bn_u64max, + *Bn("6277101735386680763495507056286727952620534092958556749825")); +} + +TEST(BignumTest, CountrZero) { + EXPECT_EQ(countr_zero(Bignum(0)), 0); + EXPECT_EQ(countr_zero(Bignum(1)), 0); + EXPECT_EQ(countr_zero(Bignum(7)), 0); + EXPECT_EQ(countr_zero(Bignum(-7)), 0); + + EXPECT_EQ(countr_zero(Bignum(2)), 1); + EXPECT_EQ(countr_zero(Bignum(8)), 3); + EXPECT_EQ(countr_zero(Bignum(10)), 1); // 0b1010 + EXPECT_EQ(countr_zero(Bignum(12)), 2); // 0b1100 + + auto two_pow_64 = Bignum(1) << 64; + EXPECT_EQ(countr_zero(two_pow_64), 64); + + auto large_shifted = Bignum(6) << 100; // 0b110 << 100 + EXPECT_EQ(countr_zero(large_shifted), 101); + + auto neg_large_shifted = Bignum(-5) << 200; + EXPECT_EQ(countr_zero(neg_large_shifted), 200); +} + +TEST(BignumTest, is_bit_set) { + EXPECT_FALSE(Bignum(0).is_bit_set(0)); + EXPECT_FALSE(Bignum(0).is_bit_set(100)); + + // 5 = 0b101 + Bignum five(5); + EXPECT_TRUE(five.is_bit_set(0)); + EXPECT_FALSE(five.is_bit_set(1)); + EXPECT_TRUE(five.is_bit_set(2)); + EXPECT_FALSE(five.is_bit_set(3)); + + // Negative numbers should test the magnitude. + Bignum neg_five(-5); + EXPECT_TRUE(neg_five.is_bit_set(0)); + EXPECT_FALSE(neg_five.is_bit_set(1)); + EXPECT_TRUE(neg_five.is_bit_set(2)); + + // Test edges of and across bigits. 
+ Bignum high_is_bit_set_63 = Bignum(1) << 63; + EXPECT_FALSE(high_is_bit_set_63.is_bit_set(62)); + EXPECT_TRUE(high_is_bit_set_63.is_bit_set(63)); + EXPECT_FALSE(high_is_bit_set_63.is_bit_set(64)); + + Bignum cross_bigit = (Bignum(1) << 100) + Bignum(1); + EXPECT_TRUE(cross_bigit.is_bit_set(0)); + EXPECT_TRUE(cross_bigit.is_bit_set(100)); + EXPECT_FALSE(cross_bigit.is_bit_set(50)); + EXPECT_FALSE(cross_bigit.is_bit_set(1000)); +} + +TEST(BignumTest, Pow) { + // Edge cases + EXPECT_EQ(Bignum(0).Pow(0), Bignum(1)); + EXPECT_EQ(Bignum(123).Pow(0), Bignum(1)); + EXPECT_EQ(Bignum(0).Pow(123), Bignum(0)); + EXPECT_EQ(Bignum(1).Pow(12345), Bignum(1)); + + // Negative base + EXPECT_EQ(Bignum(-1).Pow(2), Bignum(1)); + EXPECT_EQ(Bignum(-1).Pow(3), Bignum(-1)); + EXPECT_EQ(Bignum(-2).Pow(2), Bignum(4)); + EXPECT_EQ(Bignum(-2).Pow(3), Bignum(-8)); + + // Basic powers + EXPECT_EQ(Bignum(2).Pow(10), Bignum(1024)); + EXPECT_EQ(Bignum(3).Pow(5), Bignum(243)); + EXPECT_EQ(Bignum(10).Pow(18), *Bn("1000000000000000000")); + + // Large exponent + Bignum two_pow_100 = Bignum(1) << 100; + EXPECT_EQ(Bignum(2).Pow(100), two_pow_100); + + // Large base + Bignum ten_pow_19 = *Bn("10000000000000000000"); + Bignum ten_pow_38 = *Bn("100000000000000000000000000000000000000"); + EXPECT_EQ(ten_pow_19.Pow(2), ten_pow_38); +} + +TEST(BignumTest, SetZero) { + Bignum a(123); + a.set_zero(); + EXPECT_TRUE(a.is_zero()); + + Bignum b(-456); + b.set_zero(); + EXPECT_EQ(b, Bignum(0)); +} + +TEST(BignumTest, SetNegativeSetPositive) { + Bignum a(42); + a.set_negative(); + EXPECT_TRUE(a.is_negative()); + EXPECT_EQ(a, Bignum(-42)); + + a.set_negative(false); + EXPECT_FALSE(a.is_negative()); + EXPECT_EQ(a, Bignum(42)); + + // set_negative() has no effect on zero. 
+ Bignum b(0); + b.set_negative(); + EXPECT_FALSE(b.is_negative()); + EXPECT_EQ(b, Bignum(0)); +} + +TEST(BignumTest, Comparisons) { + EXPECT_EQ(*Bn("123"), Bignum(123)); + EXPECT_EQ(*Bn("-123"), Bignum(-123)); + EXPECT_NE(*Bn("123"), Bignum(-123)); + EXPECT_EQ(Bignum(0), Bignum(0)); + EXPECT_NE(Bignum(0), Bignum(1)); + + // Positive vs Positive + EXPECT_LT(Bignum(100), Bignum(200)); + EXPECT_GT(Bignum(200), Bignum(100)); + EXPECT_LE(Bignum(100), Bignum(200)); + EXPECT_GE(Bignum(200), Bignum(100)); + + // Negative vs Negative + EXPECT_LT(Bignum(-200), Bignum(-100)); + EXPECT_GT(Bignum(-100), Bignum(-200)); + EXPECT_GE(Bignum(-100), Bignum(-200)); + EXPECT_LE(Bignum(-200), Bignum(-100)); + + // Positive vs Negative + EXPECT_LT(Bignum(-10), Bignum(10)); + EXPECT_GT(Bignum(10), Bignum(-10)); + + // Zero + EXPECT_LT(Bignum(-1), Bignum(0)); + EXPECT_LT(Bignum(0), Bignum(1)); + EXPECT_GT(Bignum(0), Bignum(-1)); + EXPECT_GT(Bignum(1), Bignum(0)); + + // Multi-bigit + const auto two_pow_64 = *Bn("18446744073709551616"); + EXPECT_LT(Bignum(0), two_pow_64); + EXPECT_GT(two_pow_64, Bignum(0)); + EXPECT_LT(Bignum(-1), two_pow_64); + EXPECT_GT(two_pow_64, Bignum(-1)); + + EXPECT_LE(Bignum(100), Bignum(200)); + EXPECT_LE(Bignum(100), Bignum(100)); + EXPECT_LE(Bignum(-200), Bignum(-100)); + EXPECT_LE(Bignum(-100), Bignum(-100)); + EXPECT_GT(Bignum(200), Bignum(100)); + + EXPECT_GE(Bignum(200), Bignum(100)); + EXPECT_GE(Bignum(100), Bignum(100)); + EXPECT_GE(Bignum(-100), Bignum(-200)); + EXPECT_GE(Bignum(-100), Bignum(-100)); + EXPECT_LT(Bignum(100), Bignum(200)); + + EXPECT_LE(Bignum(0), Bignum(0)); + EXPECT_GE(Bignum(0), Bignum(0)); +} + +// RAII wrapper for OpenSSL BIGNUM +class OpenSSLBignum { + public: + OpenSSLBignum() : bn_(BN_new()) {} + + // Construct from a decimal number in a string. + // + // We take decimal as a string so that it's explicitly zero-terminated. 
+ explicit OpenSSLBignum(const std::string& decimal) : bn_(BN_new()) { + BN_dec2bn(&bn_, decimal.data()); + } + + explicit OpenSSLBignum(uint64_t value) : bn_(BN_new()) { + BN_set_word(bn_, value); + } + + ~OpenSSLBignum() { BN_free(bn_); } + + OpenSSLBignum(OpenSSLBignum&& other) noexcept : bn_(other.bn_) { + other.bn_ = nullptr; + } + + OpenSSLBignum& operator=(OpenSSLBignum&& other) noexcept { + if (this != &other) { + BN_free(bn_); + bn_ = other.bn_; + other.bn_ = nullptr; + } + return *this; + } + + OpenSSLBignum(const OpenSSLBignum& other) : bn_(BN_dup(other.bn_)) {} + + OpenSSLBignum& operator=(const OpenSSLBignum& other) { + if (this != &other) { + BN_copy(bn_, other.bn_); + } + return *this; + } + + BIGNUM* get() const { return bn_; } + + private: + BIGNUM* bn_; +}; + +// Power of two for fast modulo. +const int kRandomBignumCount = 128; + +static std::vector GenerateRandomNumberStrings( + absl::BitGenRef bitgen, int bits) { + std::vector numbers; + numbers.reserve(kRandomBignumCount); + + for (int i = 0; i < kRandomBignumCount; ++i) { + std::string num; + + // Generate approximately `bits` worth of decimal digits + int decimal_digits = (bits * 3) / 10; // log10(2^bits) ≈ bits * 0.301 + + // First digit is 1-9 (never zero); IntervalClosed makes 9 reachable. + absl::StrAppend(&num, absl::Uniform(absl::IntervalClosed, bitgen, 1, 9)); + for (int j = 1; j < decimal_digits; ++j) { + num += absl::Uniform(absl::IntervalClosed, bitgen, '0', '9'); + } + + numbers.push_back(num); + } + + return numbers; +} + +// Basic correctness test to ensure OpenSSL integration is working +TEST(BignumTest, OpenSSLIntegration) { + OpenSSLBignum a(123); + OpenSSLBignum b(456); + OpenSSLBignum result; + + BN_add(result.get(), a.get(), b.get()); + + char* str = BN_bn2dec(result.get()); + EXPECT_STREQ(str, "579"); + OPENSSL_free(str); +} + +TEST(BignumTest, ResultsMatch) { + // Test that Bignum and OpenSSL produce the same results + const Bignum w_a(12345); + const Bignum w_b(67890); + Bignum w_result = w_a + w_b; + + const OpenSSLBignum ssl_a(12345);
+ const OpenSSLBignum ssl_b(67890); + OpenSSLBignum ssl_result; + BN_add(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string w_str = absl::StrFormat("%v", w_result); + + EXPECT_EQ(w_str, std::string(ssl_str)); + OPENSSL_free(ssl_str); +} + +// Different number sizes for benchmarking. +enum class NumberSizeClass : uint32_t { + kSmall = 64, + kMedium = 256, + kLarge = 1024, + kHuge = 4096, + kMega = 18000 +}; + +std::vector RandomNumberStrings(absl::BitGenRef bitgen, + NumberSizeClass size_class) { + return GenerateRandomNumberStrings(bitgen, static_cast(size_class)); +} + +class VsOpenSSLTest + : public TestWithParam> { + protected: + // Returns a vector of pairs of numbers (as decimal strings) based on the + // parameters that the test suite was created with. + // + // E.g. If the test suite was instantiated with (kSmall, kHuge) as the number + // classes, then this will return a small value on the left and a huge value + // on the right. (kHuge, kSmall) would return the opposite. + std::vector> Numbers( + absl::BitGenRef bitgen) { + auto numbers0 = RandomNumberStrings(bitgen, GetParam().first); + auto numbers1 = RandomNumberStrings(bitgen, GetParam().second); + ABSL_CHECK_EQ(numbers0.size(), numbers1.size()); + + std::vector> numbers; + numbers.reserve(numbers0.size()); + + for (size_t i = 0; i < numbers0.size(); ++i) { + numbers.emplace_back(numbers0[i], numbers1[i]); + } + return numbers; + } +}; + +TEST_P(VsOpenSSLTest, MultiplyCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "MULTIPLY_CORRECT", absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + + // Test that multiplication produces the same results as OpenSSL. 
+ BN_CTX* ctx = BN_CTX_new(); + for (const auto& [a, b] : Numbers(bitgen)) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); + const Bignum bn_result = bn_a * bn_b; + + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); + OpenSSLBignum ssl_result; + BN_mul(ssl_result.get(), ssl_a.get(), ssl_b.get(), ctx); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for multiplication" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } + BN_CTX_free(ctx); +} + +TEST_P(VsOpenSSLTest, AdditionCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "ADDITION_CORRECT", absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + + // Test that addition produces correct results by comparing to OpenSSL. + for (const auto& [a, b] : Numbers(bitgen)) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); + + const Bignum bn_result = bn_a + bn_b; + + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); + OpenSSLBignum ssl_result; + BN_add(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + // Compare string representations + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for addition" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } +} + +TEST_P(VsOpenSSLTest, SubtractionCorrect) { + absl::BitGen bitgen(S2Testing::MakeTaggedSeedSeq( + "SUBTRACTION_CORRECT", + absl::LogInfoStreamer(__FILE__, __LINE__).stream())); + + // Test that subtraction produces correct results by comparing to + // OpenSSL. 
+ for (const auto& [a, b] : Numbers(bitgen)) { + const Bignum bn_a = *Bignum::FromString(a); + const Bignum bn_b = *Bignum::FromString(b); + + const Bignum bn_result = bn_a - bn_b; + + const OpenSSLBignum ssl_a(a); + const OpenSSLBignum ssl_b(b); + OpenSSLBignum ssl_result; + BN_sub(ssl_result.get(), ssl_a.get(), ssl_b.get()); + + // Compare string representations (the difference may be negative) + char* ssl_str = BN_bn2dec(ssl_result.get()); + std::string bn_str = absl::StrFormat("%v", bn_result); + + EXPECT_EQ(bn_str, std::string(ssl_str)) + << "Mismatch for subtraction" + << "\nBignum result: " << bn_str.substr(0, 100) << "..." + << "\nOpenSSL result: " << std::string(ssl_str).substr(0, 100) << "..."; + OPENSSL_free(ssl_str); + } +} + +// clang-format off +INSTANTIATE_TEST_SUITE_P( + VsOpenSSL, VsOpenSSLTest, ::testing::Values( + std::make_pair(NumberSizeClass::kSmall, NumberSizeClass::kSmall), + std::make_pair(NumberSizeClass::kSmall, NumberSizeClass::kHuge), + std::make_pair(NumberSizeClass::kHuge, NumberSizeClass::kSmall), + std::make_pair(NumberSizeClass::kMedium, NumberSizeClass::kMedium), + std::make_pair(NumberSizeClass::kLarge, NumberSizeClass::kLarge), + std::make_pair(NumberSizeClass::kHuge, NumberSizeClass::kHuge), + std::make_pair(NumberSizeClass::kMega, NumberSizeClass::kMega))); +// clang-format on + +// TODO: Enable once benchmark is integrated. 
+#if 0 + +template +void BignumBinaryOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + BinaryOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.push_back(*Bignum::FromString(str)); + } + + Bignum result; + size_t idx = 0; + for (auto _ : state) { + const Bignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const Bignum& b = numbers[(idx + 1) % kRandomBignumCount]; + result = op(a, b); + benchmark::DoNotOptimize(result); + ++idx; + } +} + +void BignumPowBenchmark(benchmark::State& state, + const std::vector& number_strings, + int exponent) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.push_back(*Bignum::FromString(str)); + } + + Bignum result; + size_t idx = 0; + for (auto _ : state) { + const Bignum& base = numbers[(idx + 0) % kRandomBignumCount]; + result = base.Pow(exponent); + benchmark::DoNotOptimize(result); + ++idx; + } +} + +template +void OpenSSLBinaryOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + BinaryOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const OpenSSLBignum& b = numbers[(idx + 1) % kRandomBignumCount]; + op(result.get(), a.get(), b.get()); + benchmark::DoNotOptimize(result.get()); + ++idx; + } +} + +template +void OpenSSLMulOpBenchmark(benchmark::State& state, + const std::vector& number_strings, + MulOp op) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + BN_CTX* ctx = BN_CTX_new(); + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& a = numbers[(idx + 0) % kRandomBignumCount]; + const OpenSSLBignum& b = numbers[(idx + 1) % kRandomBignumCount]; + op(result.get(), a.get(), b.get(), ctx); + benchmark::DoNotOptimize(result.get()); + 
++idx; + } + + BN_CTX_free(ctx); +} + +void OpenSSLPowBenchmark(benchmark::State& state, + const std::vector& number_strings, + int exponent) { + std::vector numbers; + for (const auto& str : number_strings) { + numbers.emplace_back(str); + } + + const OpenSSLBignum exp(exponent); + BN_CTX* ctx = BN_CTX_new(); + size_t idx = 0; + + for (auto _ : state) { + OpenSSLBignum result; + const OpenSSLBignum& base = numbers[(idx + 0) % kRandomBignumCount]; + BN_exp(result.get(), base.get(), exp.get(), ctx); + benchmark::DoNotOptimize(result.get()); + ++idx; + } + + BN_CTX_free(ctx); +} + +std::vector SmallNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kSmall); +} + +std::vector MediumNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kMedium); +} + +std::vector LargeNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kLarge); +} + +std::vector HugeNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kHuge); +} + +std::vector MegaNumbers(absl::BitGenRef bitgen) { + return RandomNumberStrings(bitgen, NumberSizeClass::kMega); +} + +void BM_Bignum_AddSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, SmallNumbers(bitgen), std::plus{}); +} +BENCHMARK(BM_Bignum_AddSmall); + +void BM_Bignum_AddMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MediumNumbers(bitgen), std::plus{}); +} +BENCHMARK(BM_Bignum_AddMedium); + +void BM_Bignum_AddLarge(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, LargeNumbers(bitgen), std::plus{}); +} +BENCHMARK(BM_Bignum_AddLarge); + +void BM_Bignum_AddHuge(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, HugeNumbers(bitgen), std::plus{}); +} +BENCHMARK(BM_Bignum_AddHuge); + +void BM_Bignum_AddMega(benchmark::State& state) { + std::mt19937_64 bitgen; + 
BignumBinaryOpBenchmark(state, MegaNumbers(bitgen), std::plus{}); +} +BENCHMARK(BM_Bignum_AddMega); + +void BM_OpenSSL_AddSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, SmallNumbers(bitgen), BN_add); +} +BENCHMARK(BM_OpenSSL_AddSmall); + +void BM_OpenSSL_AddMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, MediumNumbers(bitgen), BN_add); +} +BENCHMARK(BM_OpenSSL_AddMedium); + +void BM_OpenSSL_AddLarge(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, LargeNumbers(bitgen), BN_add); +} +BENCHMARK(BM_OpenSSL_AddLarge); + +void BM_OpenSSL_AddHuge(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, HugeNumbers(bitgen), BN_add); +} +BENCHMARK(BM_OpenSSL_AddHuge); + +void BM_OpenSSL_AddMega(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLBinaryOpBenchmark(state, MegaNumbers(bitgen), BN_add); +} +BENCHMARK(BM_OpenSSL_AddMega); + +void BM_Bignum_MulSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, SmallNumbers(bitgen), + std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulSmall); + +void BM_Bignum_MulMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MediumNumbers(bitgen), + std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulMedium); + +void BM_Bignum_MulLarge(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, LargeNumbers(bitgen), + std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulLarge); + +void BM_Bignum_MulHuge(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, HugeNumbers(bitgen), + std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulHuge); + +void BM_Bignum_MulMega(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumBinaryOpBenchmark(state, MegaNumbers(bitgen), + std::multiplies{}); +} +BENCHMARK(BM_Bignum_MulMega); + +void 
BM_OpenSSL_MulSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, SmallNumbers(bitgen), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulSmall); + +void BM_OpenSSL_MulMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, MediumNumbers(bitgen), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulMedium); + +void BM_OpenSSL_MulLarge(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, LargeNumbers(bitgen), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulLarge); + +void BM_OpenSSL_MulHuge(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, HugeNumbers(bitgen), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulHuge); + +void BM_OpenSSL_MulMega(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLMulOpBenchmark(state, MegaNumbers(bitgen), BN_mul); +} +BENCHMARK(BM_OpenSSL_MulMega); + +void BM_Bignum_PowSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumPowBenchmark(state, SmallNumbers(bitgen), 20); +} +BENCHMARK(BM_Bignum_PowSmall); + +void BM_Bignum_PowMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + BignumPowBenchmark(state, MediumNumbers(bitgen), 10); +} +BENCHMARK(BM_Bignum_PowMedium); + +void BM_OpenSSL_PowSmall(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLPowBenchmark(state, SmallNumbers(bitgen), 20); +} +BENCHMARK(BM_OpenSSL_PowSmall); + +void BM_OpenSSL_PowMedium(benchmark::State& state) { + std::mt19937_64 bitgen; + OpenSSLPowBenchmark(state, MediumNumbers(bitgen), 10); +} +BENCHMARK(BM_OpenSSL_PowMedium); +#endif + +} // namespace exactfloat_internal diff --git a/src/s2/util/math/exactfloat/exactfloat.cc b/src/s2/util/math/exactfloat/exactfloat.cc index 3272834d..36cb781d 100644 --- a/src/s2/util/math/exactfloat/exactfloat.cc +++ b/src/s2/util/math/exactfloat/exactfloat.cc @@ -25,23 +25,14 @@ #include #include -#include -#include // for OPENSSL_free - #include "absl/log/absl_check.h" #include "absl/log/absl_log.h" 
#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" -#ifndef OPENSSL_IS_BORINGSSL -#include "absl/container/fixed_array.h" // IWYU pragma: keep -#include "absl/numeric/bits.h" // IWYU pragma: keep -#endif - namespace exactfloat { using std::max; -using std::min; // To simplify the overflow/underflow logic, we limit the exponent and // precision range so that (2 * bn_exp_) does not overflow an "int". We take @@ -51,83 +42,6 @@ static_assert(ExactFloat::kMaxExp <= INT_MAX / 2 && ExactFloat::kMinExp - ExactFloat::kMaxPrec >= INT_MIN / 2, "exactfloat exponent might overflow"); -// We define a few simple extensions to the OpenSSL's BIGNUM interface. -// In some cases these depend on BIGNUM internal fields, so they might -// require tweaking if the BIGNUM implementation changes significantly. -// These are just thin wrappers for BoringSSL. - -#ifdef OPENSSL_IS_BORINGSSL - -inline static void BN_ext_set_uint64(BIGNUM* bn, uint64_t v) { - ABSL_CHECK(BN_set_u64(bn, v)); -} - -// Return the absolute value of a BIGNUM as a 64-bit unsigned integer. -// Requires that BIGNUM fits into 64 bits. -inline static uint64_t BN_ext_get_uint64(const BIGNUM* bn) { - uint64_t u64; - if (!BN_get_u64(bn, &u64)) { - ABSL_DLOG(FATAL) << "BN has " << BN_num_bits(bn) << " bits"; - return 0; - } - return u64; -} - -static int BN_ext_count_low_zero_bits(const BIGNUM* bn) { - return BN_count_low_zero_bits(bn); -} - -#else // !defined(OPENSSL_IS_BORINGSSL) - -// Set a BIGNUM to the given unsigned 64-bit value. -inline static void BN_ext_set_uint64(BIGNUM* bn, uint64_t v) { -#if BN_BITS2 == 64 - ABSL_CHECK(BN_set_word(bn, v)); -#else - static_assert(BN_BITS2 == 32, "at least 32 bit openssl build needed"); - ABSL_CHECK(BN_set_word(bn, static_cast(v >> 32))); - ABSL_CHECK(BN_lshift(bn, bn, 32)); - ABSL_CHECK(BN_add_word(bn, static_cast(v))); -#endif -} - -// Return the absolute value of a BIGNUM as a 64-bit unsigned integer. -// Requires that BIGNUM fits into 64 bits. 
-inline static uint64_t BN_ext_get_uint64(const BIGNUM* bn) { - ABSL_DCHECK_LE(BN_num_bytes(bn), sizeof(uint64_t)); -#if BN_BITS2 == 64 - return BN_get_word(bn); -#else - static_assert(BN_BITS2 == 32, "at least 32 bit openssl build needed"); - if (bn->top == 0) return 0; - if (bn->top == 1) return BN_get_word(bn); - ABSL_DCHECK_EQ(bn->top, 2); - return (static_cast(bn->d[1]) << 32) + bn->d[0]; -#endif -} - -static int BN_ext_count_low_zero_bits(const BIGNUM* bn) { - // In OpenSSL >= 1.1, BIGNUM is an opaque type, so d and top - // cannot be accessed. The bytes must be copied out at a ~25% - // performance penalty. - absl::FixedArray bytes(BN_num_bytes(bn)); - // "le" indicates little endian. - ABSL_CHECK_EQ(BN_bn2lebinpad(bn, bytes.data(), bytes.size()), bytes.size()); - - int count = 0; - for (unsigned char c : bytes) { - if (c == 0) { - count += 8; - } else { - count += absl::countr_zero(c); - break; - } - } - return count; -} - -#endif // !defined(OPENSSL_IS_BORINGSSL) - ExactFloat::ExactFloat(double v) { sign_ = std::signbit(v) ? -1 : 1; if (std::isnan(v)) { @@ -144,26 +58,26 @@ ExactFloat::ExactFloat(double v) { int exp; double f = frexp(fabs(v), &exp); uint64_t m = static_cast(ldexp(f, kDoubleMantissaBits)); - BN_ext_set_uint64(bn_.get(), m); + bn_ = Bignum(m); bn_exp_ = exp - kDoubleMantissaBits; Canonicalize(); } } +// Calculates abs(v) without UB. SafeAbs(INT_MIN) == INT_MIN. +// Generates the same code as std::abs(). +// https://godbolt.org/z/eT6KW1zGb +int SafeAbs(int v) { + return v < 0 ? -static_cast(v) : v; +} + ExactFloat::ExactFloat(int v) { sign_ = (v >= 0) ? 1 : -1; - // Note that this works even for INT_MIN because the parameter type for - // BN_set_word() is unsigned. 
- ABSL_CHECK(BN_set_word(bn_.get(), abs(v))); bn_exp_ = 0; + bn_ = Bignum(static_cast(SafeAbs(v))); Canonicalize(); } -ExactFloat::ExactFloat(const ExactFloat& b) - : sign_(b.sign_), bn_exp_(b.bn_exp_) { - BN_copy(bn_.get(), b.bn_.get()); -} - ExactFloat ExactFloat::SignedZero(int sign) { ExactFloat r; r.set_zero(sign); @@ -182,29 +96,29 @@ ExactFloat ExactFloat::NaN() { return r; } -int ExactFloat::prec() const { return BN_num_bits(bn_.get()); } +int ExactFloat::prec() const { return bit_width(bn_); } int ExactFloat::exp() const { ABSL_DCHECK(isnormal(*this)); - return bn_exp_ + BN_num_bits(bn_.get()); + return bn_exp_ + bit_width(bn_); } void ExactFloat::set_zero(int sign) { sign_ = sign; bn_exp_ = kExpZero; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.set_zero(); } void ExactFloat::set_inf(int sign) { sign_ = sign; bn_exp_ = kExpInfinity; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.set_zero(); } void ExactFloat::set_nan() { sign_ = 1; bn_exp_ = kExpNaN; - if (!BN_is_zero(bn_.get())) BN_zero(bn_.get()); + bn_.set_zero(); } int fpclassify(ExactFloat const& x) { @@ -228,7 +142,7 @@ ExactFloat::operator double() const { } double ExactFloat::ToDoubleHelper() const { - ABSL_DCHECK_LE(BN_num_bits(bn_.get()), kDoubleMantissaBits); + ABSL_DCHECK_LE(bit_width(bn_), kDoubleMantissaBits); if (!isnormal(*this)) { if (is_zero()) return copysign(0, sign_); if (isinf(*this)) { @@ -236,10 +150,12 @@ double ExactFloat::ToDoubleHelper() const { } return std::copysign(std::numeric_limits::quiet_NaN(), sign_); } - uint64_t d_mantissa = BN_ext_get_uint64(bn_.get()); + + const double d_mantissa = static_cast(bn_.Cast()); + // We rely on ldexp() to handle overflow and underflow. (It will return a // signed zero or infinity if the result is too small or too large.) 
- return sign_ * ldexp(static_cast(d_mantissa), bn_exp_); + return sign_ * ldexp(d_mantissa, bn_exp_); } ExactFloat ExactFloat::RoundToMaxPrec(int max_prec, RoundingMode mode) const { @@ -287,10 +203,10 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // Never increment. } else if (mode == kRoundTiesAwayFromZero) { // Increment if the highest discarded bit is 1. - if (BN_is_bit_set(bn_.get(), shift - 1)) increment = true; + if (bn_.is_bit_set(shift - 1)) increment = true; } else if (mode == kRoundAwayFromZero) { // Increment unless all discarded bits are zero. - if (BN_ext_count_low_zero_bits(bn_.get()) < shift) increment = true; + if (countr_zero(bn_) < shift) increment = true; } else { ABSL_DCHECK_EQ(mode, kRoundTiesToEven); // Let "w/xyz" denote a mantissa where "w" is the lowest kept bit and @@ -299,16 +215,15 @@ ExactFloat ExactFloat::RoundToPowerOf2(int bit_exp, RoundingMode mode) const { // 0/10* -> Don't increment (fraction = 1/2, kept part even) // 1/10* -> Increment (fraction = 1/2, kept part odd) // ./1.*1.* -> Increment (fraction > 1/2) - if (BN_is_bit_set(bn_.get(), shift - 1) && - ((BN_is_bit_set(bn_.get(), shift) || - BN_ext_count_low_zero_bits(bn_.get()) < shift - 1))) { + if (bn_.is_bit_set(shift - 1) && + ((bn_.is_bit_set(shift) || countr_zero(bn_) < shift - 1))) { increment = true; } } r.bn_exp_ = bn_exp_ + shift; - ABSL_CHECK(BN_rshift(r.bn_.get(), bn_.get(), shift)); + r.bn_ = bn_ >> shift; if (increment) { - ABSL_CHECK(BN_add_word(r.bn_.get(), 1)); + r.bn_ += Bignum(1); } r.sign_ = sign_; r.Canonicalize(); @@ -411,41 +326,36 @@ static void IncrementDecimalDigits(std::string* digits) { int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { ABSL_DCHECK(isnormal(*this)); // Convert the value to the form (bn * (10 ** bn_exp10)) where "bn" is a - // positive integer (BIGNUM). - BigNum bn; + // positive integer. 
+ Bignum bn; int bn_exp10; if (bn_exp_ >= 0) { // The easy case: bn = bn_ * (2 ** bn_exp_)), bn_exp10 = 0. - ABSL_CHECK(BN_lshift(bn.get(), bn_.get(), bn_exp_)); + bn = bn_ << bn_exp_; bn_exp10 = 0; } else { // Set bn = bn_ * (5 ** -bn_exp_) and bn_exp10 = bn_exp_. This is // equivalent to the original value of (bn_ * (2 ** bn_exp_)). - BigNum power; - ABSL_CHECK(BN_set_word(power.get(), -bn_exp_)); - ABSL_CHECK(BN_set_word(bn.get(), 5)); - BN_CTX* ctx = BN_CTX_new(); - ABSL_CHECK(BN_exp(bn.get(), bn.get(), power.get(), ctx)); - ABSL_CHECK(BN_mul(bn.get(), bn.get(), bn_.get(), ctx)); - BN_CTX_free(ctx); + bn = bn_ * Bignum(5).Pow(-bn_exp_); bn_exp10 = bn_exp_; } - // Now convert "bn" to a decimal string. - char* all_digits = BN_bn2dec(bn.get()); - ABSL_DCHECK(all_digits != nullptr); + // Now convert "bn" to a decimal string using our Bignum's string conversion. + std::string all_digits = absl::StrFormat("%v", bn); + ABSL_DCHECK(!all_digits.empty()); // Check whether we have too many digits and round if necessary. - int num_digits = strlen(all_digits); + int num_digits = all_digits.length(); if (num_digits <= max_digits) { *digits = all_digits; } else { - digits->assign(all_digits, max_digits); + digits->assign(all_digits, 0, max_digits); // Standard "printf" formatting rounds ties to an even number. This means // that we round up (away from zero) if highest discarded digit is '5' or // more, unless all other discarded digits are zero in which case we round // up only if the lowest kept digit is odd. if (all_digits[max_digits] >= '5' && ((all_digits[max_digits - 1] & 1) == 1 || - strpbrk(all_digits + max_digits + 1, "123456789") != nullptr)) { + all_digits.find_first_not_of('0', max_digits + 1) != + std::string::npos)) { // This can increase the number of digits by 1, but in that case at // least one trailing zero will be stripped off below. 
IncrementDecimalDigits(digits); @@ -453,14 +363,13 @@ int ExactFloat::GetDecimalDigits(int max_digits, std::string* digits) const { // Adjust the base-10 exponent to reflect the digits we have removed. bn_exp10 += num_digits - max_digits; } - OPENSSL_free(all_digits); // Now strip any trailing zeros. ABSL_DCHECK_NE((*digits)[0], '0'); std::string::size_type pos = digits->find_last_not_of('0') + 1; bn_exp10 += digits->size() - pos; digits->erase(pos); - ABSL_DCHECK_LE(digits->size(), max_digits); + ABSL_DCHECK_LE(static_cast(digits->size()), max_digits); // Finally, we adjust the base-10 exponent so that the mantissa is a // fraction in the range [0.1, 1) rather than an integer. @@ -471,15 +380,6 @@ std::string ExactFloat::ToUniqueString() const { return absl::StrFormat("%s<%d>", ToString(), prec()); } -ExactFloat& ExactFloat::operator=(const ExactFloat& b) { - if (this != &b) { - sign_ = b.sign_; - bn_exp_ = b.bn_exp_; - BN_copy(bn_.get(), b.bn_.get()); - } - return *this; -} - ExactFloat ExactFloat::operator-() const { return CopyWithSign(-sign_); } ExactFloat operator+(const ExactFloat& a, const ExactFloat& b) { @@ -518,29 +418,26 @@ ExactFloat ExactFloat::SignedSum(int a_sign, const ExactFloat* a, int b_sign, swap(a_sign, b_sign); swap(a, b); } + // Shift "a" if necessary so that both values have the same bn_exp_. ExactFloat r; - if (a->bn_exp_ > b->bn_exp_) { - ABSL_CHECK(BN_lshift(r.bn_.get(), a->bn_.get(), a->bn_exp_ - b->bn_exp_)); - a = &r; // The only field of "a" used below is bn_. - } + r.bn_ = a->bn_; + r.bn_ <<= (a->bn_exp_ - b->bn_exp_); + r.bn_exp_ = b->bn_exp_; if (a_sign == b_sign) { - ABSL_CHECK(BN_add(r.bn_.get(), a->bn_.get(), b->bn_.get())); + r.bn_ += b->bn_; r.sign_ = a_sign; } else { - // Note that the BIGNUM documentation is out of date -- all methods now - // allow the result to be the same as any input argument, so it is okay if - // (a == &r) due to the shift above. 
- ABSL_CHECK(BN_sub(r.bn_.get(), a->bn_.get(), b->bn_.get())); - if (BN_is_zero(r.bn_.get())) { + r.bn_ -= b->bn_; + if (r.bn_.is_zero()) { r.sign_ = +1; - } else if (BN_is_negative(r.bn_.get())) { - // The magnitude of "b" was larger. + } else if (r.bn_.is_negative()) { + // |b| was greater than |a|. r.sign_ = b_sign; - BN_set_negative(r.bn_.get(), false); + r.bn_.set_negative(false); } else { - // They were equal, or the magnitude of "a" was larger. + // |a| was greater than |b| r.sign_ = a_sign; } } @@ -554,16 +451,16 @@ void ExactFloat::Canonicalize() { // Underflow/overflow occurs if exp() is not in [kMinExp, kMaxExp]. // We also convert a zero mantissa to signed zero. int my_exp = exp(); - if (my_exp < kMinExp || BN_is_zero(bn_.get())) { + if (my_exp < kMinExp || bn_.is_zero()) { set_zero(sign_); } else if (my_exp > kMaxExp) { set_inf(sign_); - } else if (!BN_is_odd(bn_.get())) { + } else if (bn_.is_even()) { // Remove any low-order zero bits from the mantissa. - ABSL_DCHECK(!BN_is_zero(bn_.get())); - int shift = BN_ext_count_low_zero_bits(bn_.get()); + ABSL_DCHECK(!bn_.is_zero()); + int shift = countr_zero(bn_); if (shift > 0) { - ABSL_CHECK(BN_rshift(bn_.get(), bn_.get(), shift)); + bn_ >>= shift; bn_exp_ += shift; } } @@ -595,9 +492,7 @@ ExactFloat operator*(const ExactFloat& a, const ExactFloat& b) { ExactFloat r; r.sign_ = result_sign; r.bn_exp_ = a.bn_exp_ + b.bn_exp_; - BN_CTX* ctx = BN_CTX_new(); - ABSL_CHECK(BN_mul(r.bn_.get(), a.bn_.get(), b.bn_.get(), ctx)); - BN_CTX_free(ctx); + r.bn_ = a.bn_ * b.bn_; r.Canonicalize(); return r; } @@ -615,14 +510,14 @@ bool operator==(const ExactFloat& a, const ExactFloat& b) { // Otherwise, the signs and mantissas must match. Note that non-normal // values such as infinity have a mantissa of zero. 
- return a.sign_ == b.sign_ && BN_ucmp(a.bn_.get(), b.bn_.get()) == 0; + return a.sign_ == b.sign_ && a.bn_ == b.bn_; } int ExactFloat::ScaleAndCompare(const ExactFloat& b) const { ABSL_DCHECK(isnormal(*this) && isnormal(b) && bn_exp_ >= b.bn_exp_); ExactFloat tmp = *this; - ABSL_CHECK(BN_lshift(tmp.bn_.get(), tmp.bn_.get(), bn_exp_ - b.bn_exp_)); - return BN_ucmp(tmp.bn_.get(), b.bn_.get()); + tmp.bn_ <<= (bn_exp_ - b.bn_exp_); + return tmp.bn_.Compare(b.bn_); } bool ExactFloat::UnsignedLess(const ExactFloat& b) const { @@ -714,7 +609,7 @@ T ExactFloat::ToInteger(RoundingMode mode) const { if (!isinf(r)) { // If the unsigned value has more than 63 bits it is always clamped. if (r.exp() < 64) { - int64_t value = BN_ext_get_uint64(r.bn_.get()) << r.bn_exp_; + int64_t value = r.bn_.Cast() << r.bn_exp_; if (r.sign_ < 0) value = -value; return std::clamp(value, kMinValue, kMaxValue); } diff --git a/src/s2/util/math/exactfloat/exactfloat.h b/src/s2/util/math/exactfloat/exactfloat.h index 3e271e5d..37d864d4 100644 --- a/src/s2/util/math/exactfloat/exactfloat.h +++ b/src/s2/util/math/exactfloat/exactfloat.h @@ -15,7 +15,7 @@ // Author: ericv@google.com (Eric Veach) // -// ExactFloat is a multiple-precision floating point type that uses the OpenSSL +// ExactFloat is a multiple-precision floating point type that uses a custom // Bignum library for numerical calculations. It has the same interface as the // built-in "float" and "double" types, but only supports the subset of // operators and intrinsics where it is possible to compute the result exactly. @@ -25,10 +25,8 @@ // algorithms, especially for disambiguating cases where ordinary // double-precision arithmetic yields an uncertain result. // -// ExactFloat is a subset of the now-retired MPFloat class, which used the GNU -// MPFR library for numerical calculations. The main reason for the switch to -// ExactFloat is that OpenSSL has a BSD-style license whereas MPFR has a much -// more restrictive LGPL license. 
+// ExactFloat was originally based on OpenSSL's Bignum library, but has been +// updated to use a custom implementation to remove external dependencies. // // ExactFloat has the following features: // @@ -95,13 +93,14 @@ #ifndef S2_UTIL_MATH_EXACTFLOAT_EXACTFLOAT_H_ #define S2_UTIL_MATH_EXACTFLOAT_EXACTFLOAT_H_ +#include #include #include #include #include #include -#include +#include "s2/util/math/exactfloat/bignum.h" namespace exactfloat { @@ -168,13 +167,6 @@ class ExactFloat { // example, "0.125" is allowed but "0.1" is not). explicit ExactFloat(const char* s) { Unimplemented(); } - // Copy constructor. - ExactFloat(const ExactFloat& b); - - // The destructor is not virtual for efficiency reasons. Therefore no - // subclass should declare additional fields that require destruction. - inline ~ExactFloat() = default; - ///////////////////////////////////////////////////////////////////// // Constants // @@ -314,9 +306,6 @@ class ExactFloat { ///////////////////////////////////////////////////////////////////////////// // Operators - // Assignment operator. - ExactFloat& operator=(const ExactFloat& b); - // Unary plus. ExactFloat operator+() const { return *this; } @@ -487,39 +476,7 @@ class ExactFloat { friend ExactFloat logb(const ExactFloat& a); protected: - // OpenSSL >= 1.1 does not have BN_init, and does not support stack- - // allocated BIGNUMS. We use BN_init when possible, but BN_new otherwise. - // If the performance penalty is too high, an object pool can be added - // in the future. -#if defined(OPENSSL_IS_BORINGSSL) - // BoringSSL supports stack allocated BIGNUMs and BN_init. - class BigNum { - public: - BigNum() { BN_init(&bn_); } - // Prevent accidental, expensive, copying. 
- BigNum(const BigNum&) = delete; - BigNum& operator=(const BigNum&) = delete; - ~BigNum() { BN_free(&bn_); } - BIGNUM* get() { return &bn_; } - const BIGNUM* get() const { return &bn_; } - - private: - BIGNUM bn_; - }; -#else - class BigNum { - public: - BigNum() : bn_(BN_new()) {} - BigNum(const BigNum&) = delete; - BigNum& operator=(const BigNum&) = delete; - ~BigNum() { BN_free(bn_); } - BIGNUM* get() { return bn_; } - const BIGNUM* get() const { return bn_; } - - private: - BIGNUM* bn_; - }; -#endif + using Bignum = exactfloat_internal::Bignum; // Non-normal numbers are represented using special exponent values and a // mantissa of zero. Do not change these values; methods such as @@ -531,12 +488,16 @@ class ExactFloat { // Normal numbers are represented as (sign_ * bn_ * (2 ** bn_exp_)), where: // - sign_ is either +1 or -1 - // - bn_ is a BIGNUM with a positive value + // - bn_ is a Bignum with a positive value // - bn_exp_ is the base-2 exponent applied to bn_. - // Default value is zero. + // + // bn_ stores the magnitude for the mantissa of the floating point value. We + // store a sign bit here separately from bn_ so that functions like exp() are + // easier to reason about (as the sign would flip depending on whether the + // exponent were odd or even). int32_t sign_ = 1; int32_t bn_exp_ = kExpZero; - BigNum bn_; + Bignum bn_; // A standard IEEE "double" has a 53-bit mantissa consisting of a 52-bit // fraction plus an implicit leading "1" bit. diff --git a/src/s2/util/math/exactfloat/exactfloat_test.cc b/src/s2/util/math/exactfloat/exactfloat_test.cc index 282d6391..cb082104 100644 --- a/src/s2/util/math/exactfloat/exactfloat_test.cc +++ b/src/s2/util/math/exactfloat/exactfloat_test.cc @@ -457,6 +457,12 @@ TEST_F(ExactFloatTest, Constructors) { ExpectSameWithPrec(-125, 7, e); } +TEST_F(ExactFloatTest, IntMinConstruction) { + // Ensure that construction with INT_MIN works properly. 
+ ExactFloat f = INT_MIN; + ExpectSame(INT_MIN, f); +} + TEST_F(ExactFloatTest, Constants) { EXPECT_TRUE(ExactFloat::SignedZero(+1).is_zero()); EXPECT_EQ(false, signbit(ExactFloat::SignedZero(+1)));