diff options
Diffstat (limited to 'src/int_utils.h')
-rw-r--r-- | src/int_utils.h | 290 |
1 files changed, 290 insertions, 0 deletions
diff --git a/src/int_utils.h b/src/int_utils.h new file mode 100644 index 0000000000..62b2c38a29 --- /dev/null +++ b/src/int_utils.h @@ -0,0 +1,290 @@ +/********************************************************************** + * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko * + * Distributed under the MIT software license, see the accompanying * + * file LICENSE or http://www.opensource.org/licenses/mit-license.php.* + **********************************************************************/ + +#ifndef _MINISKETCH_INT_UTILS_H_ +#define _MINISKETCH_INT_UTILS_H_ + +#include <stdlib.h> + +#include <limits> +#include <algorithm> +#include <type_traits> + +#ifdef _MSC_VER +# include <intrin.h> +#endif + +template<int bits> +static constexpr inline uint64_t Rot(uint64_t x) { return (x << bits) | (x >> (64 - bits)); } + +static inline void SipHashRound(uint64_t& v0, uint64_t& v1, uint64_t& v2, uint64_t& v3) { + v0 += v1; v1 = Rot<13>(v1); v1 ^= v0; + v0 = Rot<32>(v0); + v2 += v3; v3 = Rot<16>(v3); v3 ^= v2; + v0 += v3; v3 = Rot<21>(v3); v3 ^= v0; + v2 += v1; v1 = Rot<17>(v1); v1 ^= v2; + v2 = Rot<32>(v2); +} + +inline uint64_t SipHash(uint64_t k0, uint64_t k1, uint64_t data) { + uint64_t v0 = 0x736f6d6570736575ULL ^ k0; + uint64_t v1 = 0x646f72616e646f6dULL ^ k1; + uint64_t v2 = 0x6c7967656e657261ULL ^ k0; + uint64_t v3 = 0x7465646279746573ULL ^ k1 ^ data; + SipHashRound(v0, v1, v2, v3); + SipHashRound(v0, v1, v2, v3); + v0 ^= data; + v3 ^= 0x800000000000000ULL; + SipHashRound(v0, v1, v2, v3); + SipHashRound(v0, v1, v2, v3); + v0 ^= 0x800000000000000ULL; + v2 ^= 0xFF; + SipHashRound(v0, v1, v2, v3); + SipHashRound(v0, v1, v2, v3); + SipHashRound(v0, v1, v2, v3); + SipHashRound(v0, v1, v2, v3); + return v0 ^ v1 ^ v2 ^ v3; +} + +class BitWriter { + unsigned char state = 0; + int offset = 0; + unsigned char* out; + +public: + BitWriter(unsigned char* output) : out(output) {} + + template<int BITS, typename I> + inline void Write(I val) { + int bits = BITS; + if (bits + offset >= 8) { + state |= ((val & ((I(1) << (8 - offset)) - 1)) << offset); + *(out++) = state; + val >>= (8 - offset); + bits -= 8 - offset; + offset = 0; + state = 0; + } + while (bits >= 8) { + *(out++) = val & 255; + val >>= 8; + bits -= 8; + } + state |= ((val & ((I(1) << bits) - 1)) << offset); + offset += bits; + } + + inline void Flush() { + if (offset) { + *(out++) = state; + state = 0; + offset = 0; + } + } +}; + +class BitReader { + unsigned char state = 0; + int offset = 0; + const unsigned char* in; + +public: + BitReader(const unsigned char* input) : in(input) {} + + template<int BITS, typename I> + inline I Read() { + int bits = BITS; + if (offset >= bits) { + I ret = state & ((1 << bits) - 1); + state >>= bits; + offset -= bits; + return ret; + } + I val = state; + int out = offset; + while (out + 8 <= bits) { + val |= ((I(*(in++))) << out); + out += 8; + } + if (out < bits) { + unsigned char c = *(in++); + val |= (c & ((I(1) << (bits - out)) - 1)) << out; + state = c >> (bits - out); + offset = 8 - (bits - out); + } else { + state = 0; + offset = 0; + } + return val; + } +}; + +/** Return a value of type I with its `bits` lowest bits set (bits must be > 0). */ +template<int BITS, typename I> +constexpr inline I Mask() { return ((I((I(-1)) << (std::numeric_limits<I>::digits - BITS))) >> (std::numeric_limits<I>::digits - BITS)); } + +/** Compute the smallest power of two that is larger than val. */ +template<typename I> +static inline int CountBits(I val, int max) { +#ifdef HAVE_CLZ + (void)max; + if (val == 0) return 0; + if (std::numeric_limits<unsigned>::digits >= std::numeric_limits<I>::digits) { + return std::numeric_limits<unsigned>::digits - __builtin_clz(val); + } else if (std::numeric_limits<unsigned long>::digits >= std::numeric_limits<I>::digits) { + return std::numeric_limits<unsigned long>::digits - __builtin_clzl(val); + } else { + return std::numeric_limits<unsigned long long>::digits - __builtin_clzll(val); + } +#elif _MSC_VER + (void)max; + unsigned long index; + unsigned char ret; + if (std::numeric_limits<I>::digits <= 32) { + ret = _BitScanReverse(&index, val); + } else { + ret = _BitScanReverse64(&index, val); + } + if (!ret) return 0; + return index; +#else + while (max && (val >> (max - 1) == 0)) --max; + return max; +#endif +} + +template<typename I, int BITS> +class BitsInt { +private: + static_assert(std::is_integral<I>::value && std::is_unsigned<I>::value, "BitsInt requires an unsigned integer type"); + static_assert(BITS > 0 && BITS <= std::numeric_limits<I>::digits, "BitsInt requires 1 <= Bits <= representation type size"); + + static constexpr I MASK = Mask<BITS, I>(); + +public: + + typedef I Repr; + + static constexpr int SIZE = BITS; + + static void inline Swap(I& a, I& b) { + std::swap(a, b); + } + + static constexpr inline bool IsZero(I a) { return a == 0; } + static constexpr inline I Mask(I val) { return val & MASK; } + static constexpr inline I Shift(I val, int bits) { return ((val << bits) & MASK); } + static constexpr inline I UnsafeShift(I val, int bits) { return (val << bits); } + + template<int Offset, int Count> + static constexpr inline int MidBits(I val) { + static_assert(Count > 0, "BITSInt::MidBits needs Count > 0"); + static_assert(Count + Offset <= BITS, "BitsInt::MidBits overflow of Count+Offset"); + return (val >> Offset) & ((I(1) << Count) - 1); + } + + template<int Count> + static constexpr inline int TopBits(I val) { + static_assert(Count > 0, "BitsInt::TopBits needs Count > 0"); + static_assert(Count <= BITS, "BitsInt::TopBits needs Offset <= BITS"); + return val >> (BITS - Count); + } + + static inline constexpr I CondXorWith(I val, bool cond, I v) { + return val ^ (-I(cond) & v); + } + + template<I MOD> + static inline constexpr I CondXorWith(I val, bool cond) { + return val ^ (-I(cond) & MOD); + } + + static inline int Bits(I val, int max) { return CountBits<I>(val, max); } +}; + +/** Class which implements a stateless LFSR for generic moduli. */ +template<typename F, uint32_t MOD> +struct LFSR { + typedef typename F::Repr I; + /** Shift a value `a` up once, treating it as an `N`-bit LFSR, with pattern `MOD`. */ + static inline constexpr I Call(const I& a) { + return F::template CondXorWith<MOD>(F::Shift(a, 1), F::template TopBits<1>(a)); + } +}; + +/** Helper class for carryless multiplications. */ +template<typename I, int N, typename L, typename F, int K> struct GFMulHelper; +template<typename I, int N, typename L, typename F> struct GFMulHelper<I, N, L, F, 0> +{ + static inline constexpr I Run(const I& a, const I& b) { return I(0); } +}; +template<typename I, int N, typename L, typename F, int K> struct GFMulHelper +{ + static inline constexpr I Run(const I& a, const I& b) { return F::CondXorWith(GFMulHelper<I, N, L, F, K - 1>::Run(L::Call(a), b), F::template MidBits<N - K, 1>(b), a); } +}; + +/** Compute the carry-less multiplication of a and b, with N bits, using L as LFSR type. */ +template<typename I, int N, typename L, typename F> inline constexpr I GFMul(const I& a, const I& b) { return GFMulHelper<I, N, L, F, N>::Run(a, b); } + +/** Compute the inverse of x using an extgcd algorithm. */ +template<typename I, typename F, int BITS, uint32_t MOD> +inline I InvExtGCD(I x) +{ + if (F::IsZero(x)) return x; + I t(0), newt(1); + I r(MOD), newr = x; + int rlen = BITS + 1, newrlen = F::Bits(newr, BITS); + while (newr) { + int q = rlen - newrlen; + r ^= F::Shift(newr, q); + t ^= F::UnsafeShift(newt, q); + rlen = F::Bits(r, rlen - 1); + if (r < newr) { + F::Swap(t, newt); + F::Swap(r, newr); + std::swap(rlen, newrlen); + } + } + return t; +} + +/** Compute the inverse of x1 using an exponentiation ladder. + * + * The `MUL` argument is a multiplication function, `SQR` is a squaring function, and the `SQRi` arguments + * compute x**(2**i). + */ +template<typename I, typename F, int BITS, I (*MUL)(I, I), I (*SQR)(I), I (*SQR2)(I), I(*SQR4)(I), I(*SQR8)(I), I(*SQR16)(I)> +inline I InvLadder(I x1) +{ + static constexpr int INV_EXP = BITS - 1; + I x2 = (INV_EXP >= 2) ? MUL(SQR(x1), x1) : I(); + I x4 = (INV_EXP >= 4) ? MUL(SQR2(x2), x2) : I(); + I x8 = (INV_EXP >= 8) ? MUL(SQR4(x4), x4) : I(); + I x16 = (INV_EXP >= 16) ? MUL(SQR8(x8), x8) : I(); + I x32 = (INV_EXP >= 32) ? MUL(SQR16(x16), x16) : I(); + I r; + if (INV_EXP >= 32) { + r = x32; + } else if (INV_EXP >= 16) { + r = x16; + } else if (INV_EXP >= 8) { + r = x8; + } else if (INV_EXP >= 4) { + r = x4; + } else if (INV_EXP >= 2) { + r = x2; + } else { + r = x1; + } + if (INV_EXP >= 32 && (INV_EXP & 16)) r = MUL(SQR16(r), x16); + if (INV_EXP >= 16 && (INV_EXP & 8)) r = MUL(SQR8(r), x8); + if (INV_EXP >= 8 && (INV_EXP & 4)) r = MUL(SQR4(r), x4); + if (INV_EXP >= 4 && (INV_EXP & 2)) r = MUL(SQR2(r), x2); + if (INV_EXP >= 2 && (INV_EXP & 1)) r = MUL(SQR(r), x1); + return SQR(r); +} + +#endif |