/* * Copyright (c) 2021, Leon Albrecht * Copyright (c) 2023, Dan Klishch * * SPDX-License-Identifier: BSD-2-Clause */ #pragma once #include #include #include #include #include #include #include #include #include #include namespace AK { namespace Detail { // As noted near the declaration of StaticStorage, bit_size is more like a hint for a storage size. // The effective bit size is `sizeof(StaticStorage<...>) * 8`. It is a programmer's responsibility // to ensure that the hinted bit_size is always greater than the actual integer size. // That said, do not use unaligned (bit_size % 64 != 0) `UFixedBigInt`s if you do not know what you // are doing. template> class UFixedBigInt; // ===== Concepts ===== template constexpr inline size_t assumed_bit_size = 0; template<> constexpr inline size_t assumed_bit_size = bit_width; template constexpr inline size_t assumed_bit_size> = bit_size; template constexpr inline size_t assumed_bit_size = bit_width; template concept ConvertibleToUFixedInt = (assumed_bit_size != 0); template concept UFixedInt = (ConvertibleToUFixedInt && !IsSame); template concept NotBuiltInUFixedInt = (UFixedInt && !BuiltInUFixedInt); // ===== UFixedBigInt itself ===== template constexpr auto& get_storage_of(UFixedBigInt& value) { return value.m_data; } template constexpr auto& get_storage_of(UFixedBigInt const& value) { return value.m_data; } template constexpr void mul_internal(Operand1 const& operand1, Operand2 const& operand2, Result& result) { StorageOperations<>::baseline_mul(operand1, operand2, result, g_null_allocator); } template constexpr void div_mod_internal( // Include AK/UFixedBigIntDivision.h to use UFixedBigInt division StaticStorage const& dividend, StaticStorage const& divisor, StaticStorage& quotient, StaticStorage& remainder); template class UFixedBigInt { constexpr static size_t static_size = Storage::static_size; constexpr static size_t part_size = static_size / 2; using UFixedBigIntPart = Conditional>; using Ops = StorageOperations<>; public: constexpr UFixedBigInt() = default; explicit constexpr UFixedBigInt(IntegerWrapper value) { Ops::copy(value.m_data, m_data); } consteval UFixedBigInt(int value) { Ops::copy(IntegerWrapper(value).m_data, m_data); } template requires(sizeof(T) > sizeof(Storage)) explicit constexpr UFixedBigInt(T const& value) { Ops::copy(get_storage_of(value), m_data); } template requires(sizeof(T) <= sizeof(Storage)) constexpr UFixedBigInt(T const& value) { Ops::copy(get_storage_of(value), m_data); } constexpr UFixedBigInt(UFixedBigIntPart const& low, UFixedBigIntPart const& high) requires(static_size % 2 == 0) { decltype(auto) low_storage = get_storage_of(low); decltype(auto) high_storage = get_storage_of(high); for (size_t i = 0; i < part_size; ++i) m_data[i] = low_storage[i]; for (size_t i = 0; i < part_size; ++i) m_data[i + part_size] = high_storage[i]; } template requires((assumed_bit_size * n) <= bit_size) constexpr UFixedBigInt(T const (&value)[n]) { size_t offset = 0; for (size_t i = 0; i < n; ++i) { if (offset % native_word_size == 0) { // Aligned initialization (i. e. u256 from two u128) decltype(auto) storage = get_storage_of(value[i]); for (size_t i = 0; i < storage.size(); ++i) m_data[i + offset / native_word_size] = storage[i]; } else if (offset % native_word_size == 32 && IsSame) { // u32 vector initialization on 64-bit platforms m_data[offset / native_word_size] |= static_cast(value[i]) << 32; } else { VERIFY_NOT_REACHED(); } offset += assumed_bit_size; } for (size_t i = (offset + native_word_size - 1) / native_word_size; i < m_data.size(); ++i) m_data[i] = 0; } // Casts & parts extraction template constexpr explicit operator T() const { T result; Ops::copy(m_data, result.m_data); return result; } template requires(sizeof(T) <= sizeof(NativeWord)) constexpr explicit operator T() const { return m_data[0]; } template requires(sizeof(T) == sizeof(NativeDoubleWord)) constexpr explicit operator T() const { return (static_cast(m_data[1]) << native_word_size) + m_data[0]; } constexpr UFixedBigIntPart low() const requires(static_size % 2 == 0) { if constexpr (part_size == 1) { return m_data[0]; } else if constexpr (IsSame) { return m_data[0] + (static_cast(m_data[1]) << native_word_size); } else { UFixedBigInt result; Ops::copy(m_data, result.m_data); return result; } } constexpr UFixedBigIntPart high() const requires(static_size % 2 == 0) { if constexpr (part_size == 1) { return m_data[part_size]; } else if constexpr (IsSame) { return m_data[part_size] + (static_cast(m_data[part_size + 1]) << native_word_size); } else { UFixedBigInt result; Ops::copy(m_data, result.m_data, part_size); return result; } } Bytes bytes() { return Bytes(reinterpret_cast(this), sizeof(Storage)); } ReadonlyBytes bytes() const { return ReadonlyBytes(reinterpret_cast(this), sizeof(Storage)); } constexpr UnsignedStorageSpan span() { return { m_data.data(), static_size }; } constexpr UnsignedStorageReadonlySpan span() const { return { m_data.data(), static_size }; } // Binary utils constexpr size_t popcnt() const { size_t result = 0; for (size_t i = 0; i < m_data.size(); ++i) result += popcount(m_data[i]); return result; } constexpr size_t ctz() const { size_t result = 0; for (size_t i = 0; i < m_data.size(); ++i) { if (m_data[i]) { result += count_trailing_zeroes(m_data[i]); break; } else { result += native_word_size; } } return result; } constexpr size_t clz() const { size_t result = 0; for (size_t i = m_data.size(); i--;) { if (m_data[i]) { result += count_leading_zeroes(m_data[i]); break; } else { result += native_word_size; } } return result + bit_size - native_word_size * static_size; } // Comparisons constexpr bool operator!() const { bool result = true; for (size_t i = 0; i < m_data.size(); ++i) result &= !m_data[i]; return result; } constexpr explicit operator bool() const { bool result = false; for (size_t i = 0; i < m_data.size(); ++i) result |= m_data[i]; return result; } constexpr bool operator==(UFixedInt auto const& other) const { return Ops::compare(m_data, get_storage_of(other), true) == 0; } constexpr bool operator==(IntegerWrapper other) const { return Ops::compare(m_data, get_storage_of(other), true) == 0; } constexpr int operator<=>(UFixedInt auto const& other) const { return Ops::compare(m_data, get_storage_of(other), false); } constexpr int operator<=>(IntegerWrapper other) const { return Ops::compare(m_data, get_storage_of(other), false); } #define DEFINE_STANDARD_BINARY_OPERATOR(op, function) \ constexpr auto operator op(UFixedInt auto const& other) const \ { \ auto func = [](auto&& a, auto&& b, auto&& c) { function(a, b, c); }; \ return do_standard_binary_operation(other, func); \ } \ \ constexpr auto operator op(IntegerWrapper other) const \ { \ auto func = [](auto&& a, auto&& b, auto&& c) { function(a, b, c); }; \ return do_standard_binary_operation(other, func); \ } #define DEFINE_STANDARD_COMPOUND_ASSIGNMENT(op, function) \ constexpr auto& operator op(UFixedInt auto const& other) \ { \ auto func = [](auto&& a, auto&& b, auto&& c) { function(a, b, c); }; \ do_standard_compound_assignment(other, func); \ return *this; \ } \ \ constexpr auto& operator op(IntegerWrapper other) \ { \ auto func = [](auto&& a, auto&& b, auto&& c) { function(a, b, c); }; \ do_standard_compound_assignment(other, func); \ return *this; \ } // Binary operators DEFINE_STANDARD_BINARY_OPERATOR(^, Ops::compute_bitwise) DEFINE_STANDARD_BINARY_OPERATOR(&, Ops::compute_bitwise) DEFINE_STANDARD_BINARY_OPERATOR(|, Ops::compute_bitwise) DEFINE_STANDARD_COMPOUND_ASSIGNMENT(^=, Ops::compute_inplace_bitwise) DEFINE_STANDARD_COMPOUND_ASSIGNMENT(&=, Ops::compute_inplace_bitwise) DEFINE_STANDARD_COMPOUND_ASSIGNMENT(|=, Ops::compute_inplace_bitwise) constexpr auto operator~() const { UFixedBigInt result; Ops::compute_bitwise(m_data, m_data, result.m_data); return result; } constexpr auto operator<<(size_t shift) const { UFixedBigInt result; Ops::shift_left(m_data, shift, result.m_data); return result; } constexpr auto& operator<<=(size_t shift) { Ops::shift_left(m_data, shift, m_data); return *this; } constexpr auto operator>>(size_t shift) const { UFixedBigInt result; Ops::shift_right(m_data, shift, result.m_data); return result; } constexpr auto& operator>>=(size_t shift) { Ops::shift_right(m_data, shift, m_data); return *this; } // Arithmetic template constexpr auto addc(T const& other, bool& carry) const { UFixedBigInt)> result; carry = Ops::add(m_data, get_storage_of(other), result.m_data, carry); return result; } template constexpr auto subc(T const& other, bool& borrow) const { UFixedBigInt)> result; borrow = Ops::add(m_data, get_storage_of(other), result.m_data, borrow); return result; } DEFINE_STANDARD_BINARY_OPERATOR(+, Ops::add) DEFINE_STANDARD_BINARY_OPERATOR(-, Ops::add) DEFINE_STANDARD_COMPOUND_ASSIGNMENT(+=, Ops::add) DEFINE_STANDARD_COMPOUND_ASSIGNMENT(-=, Ops::add) constexpr auto& operator++() { Ops::increment(m_data); return *this; } constexpr auto& operator--() { Ops::increment(m_data); return *this; } constexpr auto operator++(int) { UFixedBigInt result = *this; Ops::increment(m_data); return result; } constexpr auto operator--(int) { UFixedBigInt result = *this; Ops::increment(m_data); return result; } DEFINE_STANDARD_BINARY_OPERATOR(*, mul_internal) constexpr auto& operator*=(UFixedInt auto const& other) { return *this = *this * other; } constexpr auto& operator*=(IntegerWrapper const& other) { return *this = *this * other; } template constexpr auto wide_multiply(T const& other) const { UFixedBigInt> result; mul_internal(m_data, get_storage_of(other), result.m_data); return result; } template constexpr UFixedBigInt div_mod(T const& divisor, T& remainder) const { UFixedBigInt quotient; UFixedBigInt> resulting_remainder; div_mod_internal, true>(m_data, get_storage_of(divisor), get_storage_of(quotient), get_storage_of(resulting_remainder)); remainder = resulting_remainder; return quotient; } template constexpr auto operator/(T const& other) const { UFixedBigInt quotient; StaticStorage> remainder; // unused div_mod_internal, false>(m_data, get_storage_of(other), get_storage_of(quotient), remainder); return quotient; } template constexpr auto operator%(T const& other) const { StaticStorage quotient; // unused UFixedBigInt> remainder; div_mod_internal, true>(m_data, get_storage_of(other), quotient, get_storage_of(remainder)); return remainder; } constexpr auto operator/(IntegerWrapper const& other) const { return *this / static_cast>(other); } constexpr auto operator%(IntegerWrapper const& other) const { return *this % static_cast>(other); } template constexpr auto& operator/=(T const& other) { return *this = *this / other; } constexpr auto& operator/=(IntegerWrapper const& other) { return *this = *this / other; } template constexpr auto& operator%=(U const& other) { return *this = *this % other; } constexpr auto& operator%=(IntegerWrapper const& other) { return *this = *this % other; } // Note: If there ever be need for non side-channel proof sqrt/pow/pow_mod of UFixedBigInt, you // can restore them from Git history. #undef DEFINE_STANDARD_BINARY_OPERATOR #undef DEFINE_STANDARD_COMPOUND_ASSIGNMENT // These functions are intended to be used in LibCrypto for equality checks without branching. constexpr bool is_zero_constant_time() const { NativeWord fold = 0; for (size_t i = 0; i < m_data.size(); ++i) taint_for_optimizer(fold |= m_data[i]); return !fold; } constexpr bool is_equal_to_constant_time(UFixedBigInt other) const { NativeWord fold = 0; for (size_t i = 0; i < m_data.size(); ++i) taint_for_optimizer(fold |= m_data[i] ^ other.m_data[i]); return !fold; } private: template constexpr auto do_standard_binary_operation(T const& other, Function function) const { UFixedBigInt)> result; function(m_data, get_storage_of(other), result.m_data); return result; } template constexpr void do_standard_compound_assignment(T const& other, Function function) { static_assert(bit_size >= assumed_bit_size, "Requested operation requires integer size to be expanded."); function(m_data, get_storage_of(other), m_data); } template friend class UFixedBigInt; friend constexpr auto& get_storage_of(UFixedBigInt&); friend constexpr auto& get_storage_of(UFixedBigInt const&); Storage m_data; }; // FIXME: There is a bug in LLVM (https://github.com/llvm/llvm-project/issues/59783) which doesn't // allow to use the following comparisons. bool operator==(BuiltInUFixedInt auto const& a, NotBuiltInUFixedInt auto const& b) { return b.operator==(a); } int operator<=>(BuiltInUFixedInt auto const& a, NotBuiltInUFixedInt auto const& b) { return -b.operator<=>(a); } bool operator==(IntegerWrapper const& a, NotBuiltInUFixedInt auto const& b) { return b.operator==(a); } int operator<=>(IntegerWrapper const& a, NotBuiltInUFixedInt auto const& b) { return -b.operator<=>(a); } } using Detail::UFixedBigInt; template constexpr inline bool IsUnsigned> = true; template constexpr inline bool IsSigned> = false; template struct NumericLimits> { using T = UFixedBigInt; static constexpr T min() { return T {}; } static constexpr T max() { return --T {}; } static constexpr bool is_signed() { return false; } }; template class LittleEndian> { template constexpr static auto byte_swap_if_not_little_endian(UFixedBigInt value) { if constexpr (HostIsLittleEndian) { return value; } else { auto words = value.span(); auto front_it = words.begin(); auto ending_half_words = words.slice(ceil_div(words.size(), static_cast(2))); for (size_t i = 0; i < ending_half_words.size(); ++i, ++front_it) *front_it = convert_between_host_and_little_endian(exchange(ending_half_words[ending_half_words.size() - i - 1], convert_between_host_and_little_endian(*front_it))); if (words.size() % 2) words[words.size() / 2] = convert_between_host_and_little_endian(*front_it); return value; } } public: constexpr LittleEndian() = default; constexpr LittleEndian(UFixedBigInt value) : m_value(byte_swap_if_not_little_endian(value)) { } constexpr operator UFixedBigInt() const { return byte_swap_if_not_little_endian(m_value); } private: UFixedBigInt m_value { 0 }; }; template class BigEndian> { template constexpr static auto byte_swap_if_not_big_endian(UFixedBigInt value) { if constexpr (!HostIsLittleEndian) { return value; } else { auto words = value.span(); auto front_it = words.begin(); auto ending_half_words = words.slice(ceil_div(words.size(), static_cast(2))); for (size_t i = 0; i < ending_half_words.size(); ++i, ++front_it) *front_it = convert_between_host_and_big_endian(exchange(ending_half_words[ending_half_words.size() - i - 1], convert_between_host_and_big_endian(*front_it))); if (words.size() % 2) words[words.size() / 2] = convert_between_host_and_big_endian(*front_it); return value; } } public: constexpr BigEndian() = default; constexpr BigEndian(UFixedBigInt value) : m_value(byte_swap_if_not_big_endian(value)) { } constexpr operator UFixedBigInt() const { return byte_swap_if_not_big_endian(m_value); } private: UFixedBigInt m_value { 0 }; }; template struct Traits> : public DefaultTraits> { static constexpr bool is_trivially_serializable() { return true; } static constexpr bool is_trivial() { return true; } }; // ===== Formatting ===== // FIXME: This does not work for size != 2 ** x template struct Formatter : StandardFormatter { Formatter() = default; explicit Formatter(StandardFormatter formatter) : StandardFormatter(formatter) { } ErrorOr format(FormatBuilder& builder, T const& value) { using U = decltype(value.low()); if (m_precision.has_value()) VERIFY_NOT_REACHED(); if (m_mode == Mode::Pointer) { // these are way to big for a pointer VERIFY_NOT_REACHED(); } if (m_mode == Mode::Default) m_mode = Mode::Hexadecimal; if (!value.high()) { Formatter formatter { *this }; return formatter.format(builder, value.low()); } u8 base = 0; if (m_mode == Mode::Binary) { base = 2; } else if (m_mode == Mode::BinaryUppercase) { base = 2; } else if (m_mode == Mode::Octal) { TODO(); } else if (m_mode == Mode::Decimal) { TODO(); } else if (m_mode == Mode::Hexadecimal) { base = 16; } else if (m_mode == Mode::HexadecimalUppercase) { base = 16; } else { VERIFY_NOT_REACHED(); } ssize_t width = m_width.value_or(0); ssize_t lower_length = ceil_div(Detail::assumed_bit_size, (ssize_t)base); Formatter formatter { *this }; formatter.m_width = max(width - lower_length, (ssize_t)0); TRY(formatter.format(builder, value.high())); TRY(builder.put_literal("'"sv)); formatter.m_zero_pad = true; formatter.m_alternative_form = false; formatter.m_width = lower_length; TRY(formatter.format(builder, value.low())); return {}; } }; } // these sizes should suffice for most usecases using u128 = AK::UFixedBigInt<128>; using u256 = AK::UFixedBigInt<256>; using u384 = AK::UFixedBigInt<384>; using u512 = AK::UFixedBigInt<512>; using u768 = AK::UFixedBigInt<768>; using u1024 = AK::UFixedBigInt<1024>; using u1536 = AK::UFixedBigInt<1536>; using u2048 = AK::UFixedBigInt<2048>; using u4096 = AK::UFixedBigInt<4096>;