diff --git a/crypto/siphash.cpp b/crypto/siphash.cpp index 98033a9..4bd11ad 100644 --- a/crypto/siphash.cpp +++ b/crypto/siphash.cpp @@ -33,7 +33,7 @@ v1 ^= key->key[1]; \ v0 ^= key->key[0]; -#define POSTAMBLE \ +#define POSTAMBLE24 \ v3 ^= b; \ SIPROUND; \ SIPROUND; \ @@ -45,6 +45,17 @@ SIPROUND; \ return (v0 ^ v1) ^ (v2 ^ v3); +#define POSTAMBLE13 /* siphash13 tail: 1 SIPROUND on the final word b, then 3 finalization SIPROUNDs */ \ + v3 ^= b; \ + SIPROUND; \ + v0 ^= b; \ + v2 ^= 0xff; \ + SIPROUND; \ + SIPROUND; \ + SIPROUND; \ + return (v0 ^ v1) ^ (v2 ^ v3); + + uint64 siphash(const void *data, size_t len, const siphash_key_t *key) { const uint8 *end = (uint8*)data + len - (len % sizeof(uint64)); const uint8 left = len & (sizeof(uint64) - 1); @@ -66,7 +77,7 @@ uint64 siphash(const void *data, size_t len, const siphash_key_t *key) { case 2: b |= ReadLE16(data); break; case 1: b |= end[0]; } - POSTAMBLE + POSTAMBLE24 } /** @@ -81,7 +92,7 @@ uint64 siphash_1u64(const uint64 first, const siphash_key_t *key) SIPROUND; SIPROUND; v0 ^= first; - POSTAMBLE + POSTAMBLE24 } /** @@ -101,7 +112,7 @@ uint64 siphash_2u64(const uint64 first, const uint64 second, const siphash_key_t SIPROUND; SIPROUND; v0 ^= second; - POSTAMBLE + POSTAMBLE24 } /** @@ -127,7 +138,58 @@ uint64 siphash_3u64(const uint64 first, const uint64 second, const uint64 third, SIPROUND; SIPROUND; v0 ^= third; - POSTAMBLE + POSTAMBLE24 +} + +/** +* siphash13_3u64 - compute 64-bit siphash13 PRF value of 3 uint64 +* @first: first uint64 +* @second: second uint64 +* @third: third uint64 +* @key: the siphash key +*/ +uint64 siphash13_3u64(const uint64 first, const uint64 second, const uint64 third, + const siphash_key_t *key) { + PREAMBLE(24) + v3 ^= first; + SIPROUND; + v0 ^= first; + v3 ^= second; + SIPROUND; + v0 ^= second; + v3 ^= third; + SIPROUND; + v0 ^= third; + POSTAMBLE13 +} + +uint64 siphash13_2u64(const uint64 first, const uint64 second, const siphash_key_t *key) { /* siphash13 PRF value of 2 uint64 words */ + PREAMBLE(16) /* input length in bytes (2 x 8), matching PREAMBLE(4) for one uint32 and PREAMBLE(24) for three uint64; was 24, copied from the 3-word variant */ + v3 ^= first; + SIPROUND; + v0 ^= first; + v3 ^= second; + SIPROUND; + v0 ^= second; + POSTAMBLE13 +} + 
+uint64 siphash13_4u64(const uint64 first, const uint64 second, const uint64 third, const uint64 fourth, + const siphash_key_t *key) { /* siphash13 PRF value of 4 uint64 words */ + PREAMBLE(32) /* input length in bytes (4 x 8), matching PREAMBLE(4) for one uint32 and PREAMBLE(24) for three uint64; was 24, copied from the 3-word variant */ + v3 ^= first; + SIPROUND; + v0 ^= first; + v3 ^= second; + SIPROUND; + v0 ^= second; + v3 ^= third; + SIPROUND; + v0 ^= third; + v3 ^= fourth; + SIPROUND; + v0 ^= fourth; + POSTAMBLE13 } /** @@ -158,14 +220,14 @@ uint64 siphash_4u64(const uint64 first, const uint64 second, const uint64 third, SIPROUND; SIPROUND; v0 ^= forth; - POSTAMBLE + POSTAMBLE24 } uint64 siphash_1u32(const uint32 first, const siphash_key_t *key) { PREAMBLE(4) b |= first; - POSTAMBLE + POSTAMBLE24 } uint64 siphash_3u32(const uint32 first, const uint32 second, const uint32 third, @@ -178,7 +240,7 @@ uint64 siphash_3u32(const uint32 first, const uint32 second, const uint32 third, SIPROUND; v0 ^= combined; b |= third; - POSTAMBLE + POSTAMBLE24 } uint64 siphash_u64_u32(const uint64 combined, const uint32 third, const siphash_key_t *key) { @@ -188,6 +250,6 @@ uint64 siphash_u64_u32(const uint64 combined, const uint32 third, const siphash_ SIPROUND; v0 ^= combined; b |= third; - POSTAMBLE + POSTAMBLE24 } diff --git a/crypto/siphash.h b/crypto/siphash.h index 3b5dc74..d6775d0 100644 --- a/crypto/siphash.h +++ b/crypto/siphash.h @@ -50,4 +50,11 @@ uint64 siphash_u64_u32(const uint64 combined, const uint32 third, const siphash_ */ uint64 siphash(const void *data, size_t len, const siphash_key_t *key); +uint64 siphash13_2u64(const uint64 first, const uint64 second, const siphash_key_t *key); +uint64 siphash13_3u64(const uint64 first, const uint64 second, const uint64 third, + const siphash_key_t *key); + +uint64 siphash13_4u64(const uint64 first, const uint64 second, const uint64 third, + const uint64 fourth, const siphash_key_t *key); + #endif // TUNSAFE_CRYPTO_SIPHASH_H_ diff --git a/third_party/flat_hash_map/bytell_hash_map.hpp b/third_party/flat_hash_map/bytell_hash_map.hpp new file mode 100644 index 0000000..92f1f40 --- /dev/null +++ 
b/third_party/flat_hash_map/bytell_hash_map.hpp @@ -0,0 +1,1455 @@ +// Copyright Malte Skarupke 2017. +// Distributed under the Boost Software License, Version 1.0. +// (See http://www.boost.org/LICENSE_1_0.txt) + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define SKA_NOINLINE(...) __declspec(noinline) __VA_ARGS__ +#else +#define SKA_NOINLINE(...) __VA_ARGS__ __attribute__((noinline)) +#endif + +namespace ska { + +namespace detailv8 { + +template +struct functor_storage : Functor { + functor_storage() = default; + functor_storage(const Functor & functor) + : Functor(functor) { + } + template + Result operator()(Args &&... args) { + return static_cast(*this)(std::forward(args)...); + } + template + Result operator()(Args &&... args) const { + return static_cast(*this)(std::forward(args)...); + } +}; +template +struct functor_storage { + typedef Result(*function_ptr)(Args...); + function_ptr function; + functor_storage(function_ptr function) + : function(function) { + } + Result operator()(Args... 
args) const { + return function(std::forward(args)...); + } + operator function_ptr &() { + return function; + } + operator const function_ptr &() { + return function; + } +}; +template +struct KeyOrValueHasher : functor_storage { + typedef functor_storage hasher_storage; + KeyOrValueHasher() = default; + KeyOrValueHasher(const hasher & hash) + : hasher_storage(hash) { + } + size_t operator()(const key_type & key) { + return static_cast(*this)(key); + } + size_t operator()(const key_type & key) const { + return static_cast(*this)(key); + } + size_t operator()(const value_type & value) { + return static_cast(*this)(value.first); + } + size_t operator()(const value_type & value) const { + return static_cast(*this)(value.first); + } + template + size_t operator()(const std::pair & value) { + return static_cast(*this)(value.first); + } + template + size_t operator()(const std::pair & value) const { + return static_cast(*this)(value.first); + } +}; +template +struct KeyOrValueEquality : functor_storage { + typedef functor_storage equality_storage; + KeyOrValueEquality() = default; + KeyOrValueEquality(const key_equal & equality) + : equality_storage(equality) { + } + bool operator()(const key_type & lhs, const key_type & rhs) { + return static_cast(*this)(lhs, rhs); + } + bool operator()(const key_type & lhs, const value_type & rhs) { + return static_cast(*this)(lhs, rhs.first); + } + bool operator()(const value_type & lhs, const key_type & rhs) { + return static_cast(*this)(lhs.first, rhs); + } + bool operator()(const value_type & lhs, const value_type & rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const key_type & lhs, const std::pair & rhs) { + return static_cast(*this)(lhs, rhs.first); + } + template + bool operator()(const std::pair & lhs, const key_type & rhs) { + return static_cast(*this)(lhs.first, rhs); + } + template + bool operator()(const value_type & lhs, const std::pair & rhs) { + return 
static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair & lhs, const value_type & rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } + template + bool operator()(const std::pair & lhs, const std::pair & rhs) { + return static_cast(*this)(lhs.first, rhs.first); + } +}; + +template +struct AssignIfTrue { + void operator()(T & lhs, const T & rhs) { + lhs = rhs; + } + void operator()(T & lhs, T && rhs) { + lhs = std::move(rhs); + } +}; +template +struct AssignIfTrue { + void operator()(T &, const T &) { + } + void operator()(T &, T &&) { + } +}; + +struct fibonacci_hash_policy; + +template using void_t = void; + +template +struct HashPolicySelector { + typedef fibonacci_hash_policy type; +}; +template +struct HashPolicySelector> { + typedef typename T::hash_policy type; +}; + +inline uint64_t next_power_of_two(uint64_t i) { + --i; + i |= i >> 1; + i |= i >> 2; + i |= i >> 4; + i |= i >> 8; + i |= i >> 16; + i |= i >> 32; + ++i; + return i; +} + +inline int8_t log2(uint64_t value) { + static constexpr int8_t table[64] = + { + 63, 0, 58, 1, 59, 47, 53, 2, + 60, 39, 48, 27, 54, 33, 42, 3, + 61, 51, 37, 40, 49, 18, 28, 20, + 55, 30, 34, 11, 43, 14, 22, 4, + 62, 57, 46, 52, 38, 26, 32, 41, + 50, 36, 17, 19, 29, 10, 13, 21, + 56, 45, 25, 31, 35, 16, 9, 12, + 44, 24, 15, 8, 23, 7, 6, 5 + }; + value |= value >> 1; + value |= value >> 2; + value |= value >> 4; + value |= value >> 8; + value |= value >> 16; + value |= value >> 32; + return table[((value - (value >> 1)) * 0x07EDD5E59A4E28C2) >> 58]; +} + +struct fibonacci_hash_policy { + static constexpr bool is_32bit = sizeof(size_t) == 4; + static constexpr int8_t max_shift_value = is_32bit ? 
32 : 64; + + size_t index_for_hash(size_t hash, size_t num_slots_minus_one) const { + return hash & num_slots_minus_one; + if (is_32bit) { + return (2654435769 * hash) >> shift; + } else { + return (size_t)(11400714819323198485ull * hash) >> shift; + } + } + size_t keep_in_range(size_t index, size_t num_slots_minus_one) const { + return index & num_slots_minus_one; + } + int8_t next_size_over(size_t & size) const { + size = std::max(size_t(2), (size_t)detailv8::next_power_of_two(size)); + return max_shift_value - detailv8::log2(size); + } + void commit(int8_t shift) { + this->shift = shift; + } + void reset() { + shift = max_shift_value - 1; + } + +private: + int8_t shift = max_shift_value - 1; +}; + + +template +struct sherwood_v8_constants +{ + static constexpr bool is_32bit = sizeof(size_t) == 4; + static constexpr int8_t magic_for_empty = int8_t(0b11111111); + static constexpr int8_t magic_for_reserved = int8_t(0b11111110); + static constexpr int8_t bits_for_direct_hit = int8_t(0b10000000); + static constexpr int8_t magic_for_direct_hit = int8_t(0b00000000); + static constexpr int8_t magic_for_list_entry = int8_t(0b10000000); + + static constexpr int8_t bits_for_distance = int8_t(0b01111111); + inline static int distance_from_metadata(int8_t metadata) + { + return metadata & bits_for_distance; + } + + static constexpr int num_jump_distances = 126; + // jump distances chosen like this: + // 1. pick the first 16 integers to promote staying in the same block + // 2. add the next 66 triangular numbers to get even jumps when + // the hash table is a power of two + // 3. 
add 44 more triangular numbers at a much steeper growth rate + // to get a sequence that allows large jumps so that a table + // with 10000 sequential numbers doesn't endlessly re-allocate + static constexpr size_t jump_distances[num_jump_distances] + { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + + 21, 28, 36, 45, 55, 66, 78, 91, 105, 120, 136, 153, 171, 190, 210, 231, + 253, 276, 300, 325, 351, 378, 406, 435, 465, 496, 528, 561, 595, 630, + 666, 703, 741, 780, 820, 861, 903, 946, 990, 1035, 1081, 1128, 1176, + 1225, 1275, 1326, 1378, 1431, 1485, 1540, 1596, 1653, 1711, 1770, 1830, + 1891, 1953, 2016, 2080, 2145, 2211, 2278, 2346, 2415, 2485, 2556, + + 3741, 8385, 18915, 42486, 95703, 215496, 485605, 1091503, 2456436, + 5529475, 12437578, 27986421, 62972253, 141700195, 318819126, 717314626, + 1614000520, 3631437253, 8170829695, 18384318876, 41364501751, + 93070021080, 209407709220, 471167588430, 1060127437995, 2385287281530, + 5366895564381, 12075513791265, 27169907873235, 61132301007778, + 137547673121001, 309482258302503, 696335090510256, 1566753939653640, + 3525196427195653, 7931691866727775, 17846306747368716, + 40154190394120111, 90346928493040500, 203280588949935750, + 457381324898247375, 1029107980662394500, 2315492957028380766, + 5209859150892887590, + }; +}; +template +constexpr int8_t sherwood_v8_constants::magic_for_empty; +template +constexpr int8_t sherwood_v8_constants::magic_for_reserved; +template +constexpr int8_t sherwood_v8_constants::bits_for_direct_hit; +template +constexpr int8_t sherwood_v8_constants::magic_for_direct_hit; +template +constexpr int8_t sherwood_v8_constants::magic_for_list_entry; + +template +constexpr int8_t sherwood_v8_constants::bits_for_distance; + +template +constexpr int sherwood_v8_constants::num_jump_distances; +template +constexpr size_t sherwood_v8_constants::jump_distances[num_jump_distances]; + +template +struct sherwood_v8_block +{ + sherwood_v8_block() + { + } + ~sherwood_v8_block() + { + } + int8_t 
control_bytes[BlockSize]; + union + { + T data[BlockSize]; + }; + + static sherwood_v8_block * empty_block() + { + static std::array empty_bytes = [] + { + std::array result; + result.fill(sherwood_v8_constants<>::magic_for_empty); + return result; + }(); + return reinterpret_cast(&empty_bytes); + } + + int first_empty_index() const + { + for (int i = 0; i < BlockSize; ++i) + { + if (control_bytes[i] == sherwood_v8_constants<>::magic_for_empty) + return i; + } + return -1; + } + + void fill_control_bytes(int8_t value) + { + std::fill(std::begin(control_bytes), std::end(control_bytes), value); + } +}; + +template +class sherwood_v8_table : private ByteAlloc, private Hasher, private Equal +{ + using AllocatorTraits = std::allocator_traits; + using BlockType = sherwood_v8_block; + using BlockPointer = BlockType *; + using BytePointer = typename AllocatorTraits::pointer; + struct convertible_to_iterator; + using Constants = sherwood_v8_constants<>; + +public: + + using value_type = T; + using size_type = size_t; + using difference_type = std::ptrdiff_t; + using hasher = ArgumentHash; + using key_equal = ArgumentEqual; + using allocator_type = ByteAlloc; + using reference = value_type &; + using const_reference = const value_type &; + using pointer = value_type *; + using const_pointer = const value_type *; + + sherwood_v8_table() + { + } + explicit sherwood_v8_table(size_type bucket_count, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : ByteAlloc(alloc), Hasher(hash), Equal(equal) + { + if (bucket_count) + rehash(bucket_count); + } + sherwood_v8_table(size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v8_table(bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + sherwood_v8_table(size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v8_table(bucket_count, hash, ArgumentEqual(), alloc) + { + } + explicit 
sherwood_v8_table(const ArgumentAlloc & alloc) + : ByteAlloc(alloc) + { + } + template + sherwood_v8_table(It first, It last, size_type bucket_count = 0, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : sherwood_v8_table(bucket_count, hash, equal, alloc) + { + insert(first, last); + } + template + sherwood_v8_table(It first, It last, size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v8_table(first, last, bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + template + sherwood_v8_table(It first, It last, size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v8_table(first, last, bucket_count, hash, ArgumentEqual(), alloc) + { + } + sherwood_v8_table(std::initializer_list il, size_type bucket_count = 0, const ArgumentHash & hash = ArgumentHash(), const ArgumentEqual & equal = ArgumentEqual(), const ArgumentAlloc & alloc = ArgumentAlloc()) + : sherwood_v8_table(bucket_count, hash, equal, alloc) + { + if (bucket_count == 0) + rehash(il.size()); + insert(il.begin(), il.end()); + } + sherwood_v8_table(std::initializer_list il, size_type bucket_count, const ArgumentAlloc & alloc) + : sherwood_v8_table(il, bucket_count, ArgumentHash(), ArgumentEqual(), alloc) + { + } + sherwood_v8_table(std::initializer_list il, size_type bucket_count, const ArgumentHash & hash, const ArgumentAlloc & alloc) + : sherwood_v8_table(il, bucket_count, hash, ArgumentEqual(), alloc) + { + } + sherwood_v8_table(const sherwood_v8_table & other) + : sherwood_v8_table(other, AllocatorTraits::select_on_container_copy_construction(other.get_allocator())) + { + } + sherwood_v8_table(const sherwood_v8_table & other, const ArgumentAlloc & alloc) + : ByteAlloc(alloc), Hasher(other), Equal(other), _max_load_factor(other._max_load_factor) + { + rehash_for_other_container(other); + try + { + insert(other.begin(), other.end()); + } + catch(...) 
+ { + clear(); + deallocate_data(entries, num_slots_minus_one); + throw; + } + } + sherwood_v8_table(sherwood_v8_table && other) noexcept + : ByteAlloc(std::move(other)), Hasher(std::move(other)), Equal(std::move(other)) + , _max_load_factor(other._max_load_factor) + { + swap_pointers(other); + } + sherwood_v8_table(sherwood_v8_table && other, const ArgumentAlloc & alloc) noexcept + : ByteAlloc(alloc), Hasher(std::move(other)), Equal(std::move(other)) + , _max_load_factor(other._max_load_factor) + { + swap_pointers(other); + } + sherwood_v8_table & operator=(const sherwood_v8_table & other) + { + if (this == std::addressof(other)) + return *this; + + clear(); + if (AllocatorTraits::propagate_on_container_copy_assignment::value) + { + if (static_cast(*this) != static_cast(other)) + { + reset_to_empty_state(); + } + AssignIfTrue()(*this, other); + } + _max_load_factor = other._max_load_factor; + static_cast(*this) = other; + static_cast(*this) = other; + rehash_for_other_container(other); + insert(other.begin(), other.end()); + return *this; + } + sherwood_v8_table & operator=(sherwood_v8_table && other) noexcept + { + if (this == std::addressof(other)) + return *this; + else if (AllocatorTraits::propagate_on_container_move_assignment::value) + { + clear(); + reset_to_empty_state(); + AssignIfTrue()(*this, std::move(other)); + swap_pointers(other); + } + else if (static_cast(*this) == static_cast(other)) + { + swap_pointers(other); + } + else + { + clear(); + _max_load_factor = other._max_load_factor; + rehash_for_other_container(other); + for (T & elem : other) + emplace(std::move(elem)); + other.clear(); + } + static_cast(*this) = std::move(other); + static_cast(*this) = std::move(other); + return *this; + } + ~sherwood_v8_table() + { + clear(); + deallocate_data(entries, num_slots_minus_one); + } + + const allocator_type & get_allocator() const + { + return static_cast(*this); + } + const ArgumentEqual & key_eq() const + { + return static_cast(*this); + } + const 
ArgumentHash & hash_function() const + { + return static_cast(*this); + } + + template + struct templated_iterator + { + private: + friend class sherwood_v8_table; + BlockPointer current = BlockPointer(); + size_t index = 0; + + public: + templated_iterator() + { + } + templated_iterator(BlockPointer entries, size_t index) + : current(entries) + , index(index) + { + } + + using iterator_category = std::forward_iterator_tag; + using value_type = ValueType; + using difference_type = ptrdiff_t; + using pointer = ValueType *; + using reference = ValueType &; + + friend bool operator==(const templated_iterator & lhs, const templated_iterator & rhs) + { + return lhs.index == rhs.index; + } + friend bool operator!=(const templated_iterator & lhs, const templated_iterator & rhs) + { + return !(lhs == rhs); + } + + templated_iterator & operator++() + { + do + { + if (index % BlockSize == 0) + --current; + if (index-- == 0) + break; + } + while(current->control_bytes[index % BlockSize] == Constants::magic_for_empty); + return *this; + } + templated_iterator operator++(int) + { + templated_iterator copy(*this); + ++*this; + return copy; + } + + ValueType & operator*() const + { + return current->data[index % BlockSize]; + } + ValueType * operator->() const + { + return current->data + index % BlockSize; + } + + operator templated_iterator() const + { + return { current, index }; + } + }; + using iterator = templated_iterator; + using const_iterator = templated_iterator; + + iterator begin() + { + size_t num_slots = num_slots_minus_one ? num_slots_minus_one + 1 : 0; + return ++iterator{ entries + num_slots / BlockSize, num_slots }; + } + const_iterator begin() const + { + size_t num_slots = num_slots_minus_one ? 
num_slots_minus_one + 1 : 0; + return ++iterator{ entries + num_slots / BlockSize, num_slots }; + } + const_iterator cbegin() const + { + return begin(); + } + iterator end() + { + return { entries - 1, std::numeric_limits::max() }; + } + const_iterator end() const + { + return { entries - 1, std::numeric_limits::max() }; + } + const_iterator cend() const + { + return end(); + } + + inline iterator find(const FindKey & key) + { + size_t index = hash_object(key); + size_t num_slots_minus_one = this->num_slots_minus_one; + BlockPointer entries = this->entries; + index = hash_policy.index_for_hash(index, num_slots_minus_one); + bool first = true; + for (;;) + { + size_t block_index = index / BlockSize; + size_t index_in_block = index % BlockSize; + BlockPointer block = entries + block_index; + int8_t metadata = block->control_bytes[index_in_block]; + if (first) + { + if ((metadata & Constants::bits_for_direct_hit) != Constants::magic_for_direct_hit) + return end(); + first = false; + } + if (compares_equal(key, block->data[index_in_block])) + return { block, index }; + int8_t to_next_index = metadata & Constants::bits_for_distance; + if (to_next_index == 0) + return end(); + index += Constants::jump_distances[to_next_index]; + index = hash_policy.keep_in_range(index, num_slots_minus_one); + } + } + inline const_iterator find(const FindKey & key) const + { + return const_cast(this)->find(key); + } + size_t count(const FindKey & key) const + { + return find(key) == end() ? 0 : 1; + } + std::pair equal_range(const FindKey & key) + { + iterator found = find(key); + if (found == end()) + return { found, found }; + else + return { found, std::next(found) }; + } + std::pair equal_range(const FindKey & key) const + { + const_iterator found = find(key); + if (found == end()) + return { found, found }; + else + return { found, std::next(found) }; + } + + + template + inline std::pair emplace(Key && key, Args &&... 
args) + { + size_t index = hash_object(key); + size_t num_slots_minus_one = this->num_slots_minus_one; + BlockPointer entries = this->entries; + index = hash_policy.index_for_hash(index, num_slots_minus_one); + bool first = true; + for (;;) + { + size_t block_index = index / BlockSize; + size_t index_in_block = index % BlockSize; + BlockPointer block = entries + block_index; + int8_t metadata = block->control_bytes[index_in_block]; + if (first) + { + if ((metadata & Constants::bits_for_direct_hit) != Constants::magic_for_direct_hit) + return emplace_direct_hit({ index, block }, std::forward(key), std::forward(args)...); + first = false; + } + if (compares_equal(key, block->data[index_in_block])) + return { { block, index }, false }; + int8_t to_next_index = metadata & Constants::bits_for_distance; + if (to_next_index == 0) + return emplace_new_key({ index, block }, std::forward(key), std::forward(args)...); + index += Constants::jump_distances[to_next_index]; + index = hash_policy.keep_in_range(index, num_slots_minus_one); + } + } + + std::pair insert(const value_type & value) + { + return emplace(value); + } + std::pair insert(value_type && value) + { + return emplace(std::move(value)); + } + template + iterator emplace_hint(const_iterator, Args &&... 
args) + { + return emplace(std::forward(args)...).first; + } + iterator insert(const_iterator, const value_type & value) + { + return emplace(value).first; + } + iterator insert(const_iterator, value_type && value) + { + return emplace(std::move(value)).first; + } + + template + void insert(It begin, It end) + { + for (; begin != end; ++begin) + { + emplace(*begin); + } + } + void insert(std::initializer_list il) + { + insert(il.begin(), il.end()); + } + + void rehash(size_t num_items) + { + num_items = std::max(num_items, static_cast(std::ceil(num_elements / static_cast(_max_load_factor)))); + if (num_items == 0) + { + reset_to_empty_state(); + return; + } + auto new_prime_index = hash_policy.next_size_over(num_items); + if (num_items == num_slots_minus_one + 1) + return; + size_t num_blocks = num_items / BlockSize; + if (num_items % BlockSize) + ++num_blocks; + size_t memory_requirement = calculate_memory_requirement(num_blocks); + unsigned char * new_memory = &*AllocatorTraits::allocate(*this, memory_requirement); + + BlockPointer new_buckets = reinterpret_cast(new_memory); + + BlockPointer special_end_item = new_buckets + num_blocks; + for (BlockPointer it = new_buckets; it <= special_end_item; ++it) + it->fill_control_bytes(Constants::magic_for_empty); + using std::swap; + swap(entries, new_buckets); + swap(num_slots_minus_one, num_items); + --num_slots_minus_one; + hash_policy.commit(new_prime_index); + num_elements = 0; + if (num_items) + ++num_items; + size_t old_num_blocks = num_items / BlockSize; + if (num_items % BlockSize) + ++old_num_blocks; + for (BlockPointer it = new_buckets, end = new_buckets + old_num_blocks; it != end; ++it) + { + for (int i = 0; i < BlockSize; ++i) + { + int8_t metadata = it->control_bytes[i]; + if (metadata != Constants::magic_for_empty && metadata != Constants::magic_for_reserved) + { + emplace(std::move(it->data[i])); + AllocatorTraits::destroy(*this, it->data + i); + } + } + } + deallocate_data(new_buckets, num_items - 1); + 
} + + void reserve(size_t num_elements) + { + size_t required_buckets = num_buckets_for_reserve(num_elements); + if (required_buckets > bucket_count()) + rehash(required_buckets); + } + + // the return value is a type that can be converted to an iterator + // the reason for doing this is that it's not free to find the + // iterator pointing at the next element. if you care about the + // next iterator, turn the return value into an iterator + convertible_to_iterator erase(const_iterator to_erase) + { + LinkedListIt current = { to_erase.index, to_erase.current }; + if (current.has_next()) + { + LinkedListIt previous = current; + LinkedListIt next = current.next(*this); + while (next.has_next()) + { + previous = next; + next = next.next(*this); + } + AllocatorTraits::destroy(*this, std::addressof(*current)); + AllocatorTraits::construct(*this, std::addressof(*current), std::move(*next)); + AllocatorTraits::destroy(*this, std::addressof(*next)); + next.set_metadata(Constants::magic_for_empty); + previous.clear_next(); + } + else + { + if (!current.is_direct_hit()) + find_parent_block(current).clear_next(); + AllocatorTraits::destroy(*this, std::addressof(*current)); + current.set_metadata(Constants::magic_for_empty); + } + --num_elements; + return { to_erase.current, to_erase.index }; + } + + iterator erase(const_iterator begin_it, const_iterator end_it) + { + if (begin_it == end_it) + return { begin_it.current, begin_it.index }; + if (std::next(begin_it) == end_it) + return erase(begin_it); + if (begin_it == begin() && end_it == end()) + { + clear(); + return { end_it.current, end_it.index }; + } + std::vector> depth_in_chain; + for (const_iterator it = begin_it; it != end_it; ++it) + { + LinkedListIt list_it(it.index, it.current); + if (list_it.is_direct_hit()) + depth_in_chain.emplace_back(0, list_it); + else + { + LinkedListIt root = find_direct_hit(list_it); + int distance = 1; + for (;;) + { + LinkedListIt next = root.next(*this); + if (next == list_it) + break; 
+ ++distance; + root = next; + } + depth_in_chain.emplace_back(distance, list_it); + } + } + std::sort(depth_in_chain.begin(), depth_in_chain.end(), [](const auto & a, const auto & b) { return a.first < b.first; }); + for (auto it = depth_in_chain.rbegin(), end = depth_in_chain.rend(); it != end; ++it) + { + erase(it->second.it()); + } + + if (begin_it.current->control_bytes[begin_it.index % BlockSize] == Constants::magic_for_empty) + return ++iterator{ begin_it.current, begin_it.index }; + else + return { begin_it.current, begin_it.index }; + } + + size_t erase(const FindKey & key) + { + auto found = find(key); + if (found == end()) + return 0; + else + { + erase(found); + return 1; + } + } + + void clear() + { + if (!num_slots_minus_one) + return; + size_t num_slots = num_slots_minus_one + 1; + size_t num_blocks = num_slots / BlockSize; + if (num_slots % BlockSize) + ++num_blocks; + for (BlockPointer it = entries, end = it + num_blocks; it != end; ++it) + { + for (int i = 0; i < BlockSize; ++i) + { + if (it->control_bytes[i] != Constants::magic_for_empty) + { + AllocatorTraits::destroy(*this, std::addressof(it->data[i])); + it->control_bytes[i] = Constants::magic_for_empty; + } + } + } + num_elements = 0; + } + + void shrink_to_fit() + { + rehash_for_other_container(*this); + } + + void swap(sherwood_v8_table & other) + { + using std::swap; + swap_pointers(other); + swap(static_cast(*this), static_cast(other)); + swap(static_cast(*this), static_cast(other)); + if (AllocatorTraits::propagate_on_container_swap::value) + swap(static_cast(*this), static_cast(other)); + } + + size_t size() const + { + return num_elements; + } + size_t max_size() const + { + return (AllocatorTraits::max_size(*this)) / sizeof(T); + } + size_t bucket_count() const + { + return num_slots_minus_one ? 
num_slots_minus_one + 1 : 0; + } + size_type max_bucket_count() const + { + return (AllocatorTraits::max_size(*this)) / sizeof(T); + } + size_t bucket(const FindKey & key) const + { + return hash_policy.index_for_hash(hash_object(key), num_slots_minus_one); + } + float load_factor() const + { + return static_cast(num_elements) / (num_slots_minus_one + 1); + } + void max_load_factor(float value) + { + _max_load_factor = value; + } + float max_load_factor() const + { + return _max_load_factor; + } + + bool empty() const + { + return num_elements == 0; + } + +public: + BlockPointer entries = BlockType::empty_block(); + size_t num_slots_minus_one = 0; + typename HashPolicySelector::type hash_policy; + float _max_load_factor = 0.9375f; + size_t num_elements = 0; + + size_t num_buckets_for_reserve(size_t num_elements) const + { + return static_cast(std::ceil(num_elements / static_cast(_max_load_factor))); + } + void rehash_for_other_container(const sherwood_v8_table & other) + { + rehash(std::min(num_buckets_for_reserve(other.size()), other.bucket_count())); + } + bool is_full() const + { + if (!num_slots_minus_one) + return true; + else + return num_elements + 1 > (num_slots_minus_one + 1) * static_cast(_max_load_factor); + } + + void swap_pointers(sherwood_v8_table & other) + { + using std::swap; + swap(hash_policy, other.hash_policy); + swap(entries, other.entries); + swap(num_slots_minus_one, other.num_slots_minus_one); + swap(num_elements, other.num_elements); + swap(_max_load_factor, other._max_load_factor); + } + + struct LinkedListIt + { + size_t index = 0; + BlockPointer block = nullptr; + + LinkedListIt() + { + } + LinkedListIt(size_t index, BlockPointer block) + : index(index), block(block) + { + } + + iterator it() const + { + return { block, index }; + } + int index_in_block() const + { + return index % BlockSize; + } + bool is_direct_hit() const + { + return (metadata() & Constants::bits_for_direct_hit) == Constants::magic_for_direct_hit; + } + bool 
is_empty() const + { + return metadata() == Constants::magic_for_empty; + } + bool has_next() const + { + return jump_index() != 0; + } + int8_t jump_index() const + { + return Constants::distance_from_metadata(metadata()); + } + int8_t metadata() const + { + return block->control_bytes[index_in_block()]; + } + void set_metadata(int8_t metadata) + { + block->control_bytes[index_in_block()] = metadata; + } + + LinkedListIt next(sherwood_v8_table & table) const + { + int8_t distance = jump_index(); + size_t next_index = table.hash_policy.keep_in_range(index + Constants::jump_distances[distance], table.num_slots_minus_one); + return { next_index, table.entries + next_index / BlockSize }; + } + void set_next(int8_t jump_index) + { + int8_t & metadata = block->control_bytes[index_in_block()]; + metadata = (metadata & ~Constants::bits_for_distance) | jump_index; + } + void clear_next() + { + set_next(0); + } + + value_type & operator*() const + { + return block->data[index_in_block()]; + } + bool operator!() const + { + return !block; + } + explicit operator bool() const + { + return block != nullptr; + } + bool operator==(const LinkedListIt & other) const + { + return index == other.index; + } + bool operator!=(const LinkedListIt & other) const + { + return !(*this == other); + } + }; + + template + SKA_NOINLINE(std::pair) emplace_direct_hit(LinkedListIt block, Args &&... 
args) + { + using std::swap; + if (is_full()) + { + grow(); + return emplace(std::forward(args)...); + } + if (block.metadata() == Constants::magic_for_empty) + { + AllocatorTraits::construct(*this, std::addressof(*block), std::forward(args)...); + block.set_metadata(Constants::magic_for_direct_hit); + ++num_elements; + return { block.it(), true }; + } + else + { + LinkedListIt parent_block = find_parent_block(block); + std::pair free_block = find_free_index(parent_block); + if (!free_block.first) + { + grow(); + return emplace(std::forward(args)...); + } + value_type new_value(std::forward(args)...); + for (LinkedListIt it = block;;) + { + AllocatorTraits::construct(*this, std::addressof(*free_block.second), std::move(*it)); + AllocatorTraits::destroy(*this, std::addressof(*it)); + parent_block.set_next(free_block.first); + free_block.second.set_metadata(Constants::magic_for_list_entry); + if (!it.has_next()) + { + it.set_metadata(Constants::magic_for_empty); + break; + } + LinkedListIt next = it.next(*this); + it.set_metadata(Constants::magic_for_empty); + block.set_metadata(Constants::magic_for_reserved); + it = next; + parent_block = free_block.second; + free_block = find_free_index(free_block.second); + if (!free_block.first) + { + grow(); + return emplace(std::move(new_value)); + } + } + AllocatorTraits::construct(*this, std::addressof(*block), std::move(new_value)); + block.set_metadata(Constants::magic_for_direct_hit); + ++num_elements; + return { block.it(), true }; + } + } + + template + SKA_NOINLINE(std::pair) emplace_new_key(LinkedListIt parent, Args &&... 
args) + { + if (is_full()) + { + grow(); + return emplace(std::forward(args)...); + } + std::pair free_block = find_free_index(parent); + if (!free_block.first) + { + grow(); + return emplace(std::forward(args)...); + } + AllocatorTraits::construct(*this, std::addressof(*free_block.second), std::forward(args)...); + free_block.second.set_metadata(Constants::magic_for_list_entry); + parent.set_next(free_block.first); + ++num_elements; + return { free_block.second.it(), true }; + } + + LinkedListIt find_direct_hit(LinkedListIt child) const + { + size_t to_move_hash = hash_object(*child); + size_t to_move_index = hash_policy.index_for_hash(to_move_hash, num_slots_minus_one); + return { to_move_index, entries + to_move_index / BlockSize }; + } + LinkedListIt find_parent_block(LinkedListIt child) + { + LinkedListIt parent_block = find_direct_hit(child); + for (;;) + { + LinkedListIt next = parent_block.next(*this); + if (next == child) + return parent_block; + parent_block = next; + } + } + + std::pair find_free_index(LinkedListIt parent) const + { + for (int8_t jump_index = 1; jump_index < Constants::num_jump_distances; ++jump_index) + { + size_t index = hash_policy.keep_in_range(parent.index + Constants::jump_distances[jump_index], num_slots_minus_one); + BlockPointer block = entries + index / BlockSize; + if (block->control_bytes[index % BlockSize] == Constants::magic_for_empty) + return { jump_index, { index, block } }; + } + return { 0, {} }; + } + + void grow() + { + rehash(std::max(size_t(10), 2 * bucket_count())); + } + + size_t calculate_memory_requirement(size_t num_blocks) + { + size_t memory_required = sizeof(BlockType) * num_blocks; + memory_required += BlockSize; // for metadata of past-the-end pointer + return memory_required; + } + + void deallocate_data(BlockPointer begin, size_t num_slots_minus_one) + { + if (begin == BlockType::empty_block()) + return; + + ++num_slots_minus_one; + size_t num_blocks = num_slots_minus_one / BlockSize; + if 
(num_slots_minus_one % BlockSize) + ++num_blocks; + size_t memory = calculate_memory_requirement(num_blocks); + unsigned char * as_byte_pointer = reinterpret_cast(begin); + AllocatorTraits::deallocate(*this, typename AllocatorTraits::pointer(as_byte_pointer), memory); + } + + void reset_to_empty_state() + { + deallocate_data(entries, num_slots_minus_one); + entries = BlockType::empty_block(); + num_slots_minus_one = 0; + hash_policy.reset(); + } + + template + size_t hash_object(const U & key) + { + return static_cast(*this)(key); + } + template + size_t hash_object(const U & key) const + { + return static_cast(*this)(key); + } + template + bool compares_equal(const L & lhs, const R & rhs) + { + return static_cast(*this)(lhs, rhs); + } + + struct convertible_to_iterator + { + BlockPointer it; + size_t index; + + operator iterator() + { + if (it->control_bytes[index % BlockSize] == Constants::magic_for_empty) + return ++iterator{it, index}; + else + return { it, index }; + } + operator const_iterator() + { + if (it->control_bytes[index % BlockSize] == Constants::magic_for_empty) + return ++iterator{it, index}; + else + return { it, index }; + } + }; +}; +template +struct AlignmentOr8Bytes +{ + static constexpr size_t value = 8; +}; +template +struct AlignmentOr8Bytes= 1>::type> +{ + static constexpr size_t value = alignof(T); +}; +template +struct CalculateBytellBlockSize; +template +struct CalculateBytellBlockSize +{ + static constexpr size_t this_value = AlignmentOr8Bytes::value; + static constexpr size_t base_value = CalculateBytellBlockSize::value; + static constexpr size_t value = this_value > base_value ? 
this_value : base_value; +}; +template<> +struct CalculateBytellBlockSize<> +{ + static constexpr size_t value = 8; +}; +} + +template, typename E = std::equal_to, typename A = std::allocator > > +class bytell_hash_map + : public detailv8::sherwood_v8_table + < + std::pair, + K, + H, + detailv8::KeyOrValueHasher, H>, + E, + detailv8::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc, + detailv8::CalculateBytellBlockSize::value + > +{ + using Table = detailv8::sherwood_v8_table + < + std::pair, + K, + H, + detailv8::KeyOrValueHasher, H>, + E, + detailv8::KeyOrValueEquality, E>, + A, + typename std::allocator_traits::template rebind_alloc, + detailv8::CalculateBytellBlockSize::value + >; +public: + + using key_type = K; + using mapped_type = V; + + using Table::Table; + SKA_NOINLINE() bytell_hash_map() + { + } + + inline V & operator[](const K & key) + { + return emplace(key, convertible_to_value()).first->second; + } + inline V & operator[](K && key) + { + return emplace(std::move(key), convertible_to_value()).first->second; + } + V & at(const K & key) + { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + const V & at(const K & key) const + { + auto found = this->find(key); + if (found == this->end()) + throw std::out_of_range("Argument passed to at() was not in the map."); + return found->second; + } + + using Table::emplace; + std::pair emplace() + { + return emplace(key_type(), convertible_to_value()); + } + template + std::pair insert_or_assign(const key_type & key, M && m) + { + auto emplace_result = emplace(key, std::forward(m)); + if (!emplace_result.second) + emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + std::pair insert_or_assign(key_type && key, M && m) + { + auto emplace_result = emplace(std::move(key), std::forward(m)); + if (!emplace_result.second) + 
emplace_result.first->second = std::forward(m); + return emplace_result; + } + template + typename Table::iterator insert_or_assign(typename Table::const_iterator, const key_type & key, M && m) + { + return insert_or_assign(key, std::forward(m)).first; + } + template + typename Table::iterator insert_or_assign(typename Table::const_iterator, key_type && key, M && m) + { + return insert_or_assign(std::move(key), std::forward(m)).first; + } + + friend bool operator==(const bytell_hash_map & lhs, const bytell_hash_map & rhs) + { + if (lhs.size() != rhs.size()) + return false; + for (const typename Table::value_type & value : lhs) + { + auto found = rhs.find(value.first); + if (found == rhs.end()) + return false; + else if (value.second != found->second) + return false; + } + return true; + } + friend bool operator!=(const bytell_hash_map & lhs, const bytell_hash_map & rhs) + { + return !(lhs == rhs); + } + +private: + struct convertible_to_value + { + operator V() const + { + return V(); + } + }; +}; + +template, typename E = std::equal_to, typename A = std::allocator > +class bytell_hash_set + : public detailv8::sherwood_v8_table + < + T, + T, + H, + detailv8::functor_storage, + E, + detailv8::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc, + detailv8::CalculateBytellBlockSize::value + > +{ + using Table = detailv8::sherwood_v8_table + < + T, + T, + H, + detailv8::functor_storage, + E, + detailv8::functor_storage, + A, + typename std::allocator_traits::template rebind_alloc, + detailv8::CalculateBytellBlockSize::value + >; +public: + + using key_type = T; + + using Table::Table; + bytell_hash_set() + { + } + + template + std::pair emplace(Args &&... 
args) + { + return Table::emplace(T(std::forward(args)...)); + } + std::pair emplace(const key_type & arg) + { + return Table::emplace(arg); + } + std::pair emplace(key_type & arg) + { + return Table::emplace(arg); + } + std::pair emplace(const key_type && arg) + { + return Table::emplace(std::move(arg)); + } + std::pair emplace(key_type && arg) + { + return Table::emplace(std::move(arg)); + } + + friend bool operator==(const bytell_hash_set & lhs, const bytell_hash_set & rhs) + { + if (lhs.size() != rhs.size()) + return false; + for (const T & value : lhs) + { + if (rhs.find(value) == rhs.end()) + return false; + } + return true; + } + friend bool operator!=(const bytell_hash_set & lhs, const bytell_hash_set & rhs) + { + return !(lhs == rhs); + } +}; + +} // end namespace ska \ No newline at end of file diff --git a/tunsafe_config.h b/tunsafe_config.h index d493b25..bcbbdc0 100644 --- a/tunsafe_config.h +++ b/tunsafe_config.h @@ -9,3 +9,6 @@ #define WITH_HEADER_OBFUSCATION 0 #define WITH_AVX512_OPTIMIZATIONS 0 #define WITH_BENCHMARK 0 + +// Use bytell hashmap instead. 
Only works in 64-bit builds +#define WITH_BYTELL_HASHMAP 0 diff --git a/wireguard.cpp b/wireguard.cpp index fb616e3..ae9a74b 100644 --- a/wireguard.cpp +++ b/wireguard.cpp @@ -766,9 +766,7 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe data += 4, bytes_left -= 4; keypair = dev_.LookupKeypairByKeyId(key_id); } else { - // Lookup the packet source ip and port in the address mapping - uint64 addr_id = packet->addr.sin.sin_addr.s_addr | ((uint64)packet->addr.sin.sin_port << 32); - keypair = dev_.LookupKeypairInAddrEntryMap(addr_id, ((tag / WG_SHORT_HEADER_KEY_ID) & 3) - 1); + keypair = dev_.LookupKeypairInAddrEntryMap(packet->addr, ((tag & WG_SHORT_HEADER_KEY_ID_MASK) / WG_SHORT_HEADER_KEY_ID) - 1); } if (!keypair || !keypair->enabled_features[WG_FEATURE_ID_SHORT_HEADER]) @@ -854,10 +852,8 @@ void WireguardProcessor::HandleShortHeaderFormatPacket(uint32 tag, Packet *packe // Periodically broadcast out the short key if ((tag & WG_SHORT_HEADER_KEY_ID_MASK) == 0x00 && !keypair->did_attempt_remember_ip_port) { keypair->did_attempt_remember_ip_port = true; - if (keypair->enabled_features[WG_FEATURE_ID_SKIP_KEYID_IN]) { - uint64 addr_id = packet->addr.sin.sin_addr.s_addr | ((uint64)packet->addr.sin.sin_port << 32); - dev_.UpdateKeypairAddrEntry_Locked(addr_id, keypair); - } + if (keypair->enabled_features[WG_FEATURE_ID_SKIP_KEYID_IN]) + dev_.UpdateKeypairAddrEntry_Locked(packet->addr, keypair); } // Ack header may also signal that we can omit the key id in packets from now on. 
if (tag & WG_SHORT_HEADER_ACK) diff --git a/wireguard_proto.cpp b/wireguard_proto.cpp index a7a5567..033064b 100644 --- a/wireguard_proto.cpp +++ b/wireguard_proto.cpp @@ -263,20 +263,36 @@ void WgDevice::EraseKeypairAddrEntry_Locked(WgKeypair *kp) { } } -WgKeypair *WgDevice::LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot) { +static WgAddrEntry::IpPort ConvertIpAddrToAddrX(const IpAddr &src) { + WgAddrEntry::IpPort r; + if (src.sin.sin_family == AF_INET) { + Write64(r.bytes, src.sin.sin_addr.s_addr); + Write64(r.bytes + 8, 0); + Write32(r.bytes + 16, src.sin.sin_port); + } else { + memcpy(r.bytes, &src.sin6.sin6_addr, 16); + Write32(r.bytes + 16, (AF_INET6 << 16) + src.sin6.sin6_port); + } + return r; +} + +WgKeypair *WgDevice::LookupKeypairInAddrEntryMap(const IpAddr &addr, uint32 slot) { + // Convert IpAddr to WgAddrEntry::IpPort suitable for use in hash. + WgAddrEntry::IpPort addr_x = ConvertIpAddrToAddrX(addr); WG_SCOPED_RWLOCK_SHARED(addr_entry_lookup_lock_); - auto it = addr_entry_lookup_.find(addr); + auto it = addr_entry_lookup_.find(addr_x); if (it == addr_entry_lookup_.end()) return NULL; WgAddrEntry *addr_entry = it->second; return addr_entry->keys[slot]; } -void WgDevice::UpdateKeypairAddrEntry_Locked(uint64 addr_id, WgKeypair *keypair) { +void WgDevice::UpdateKeypairAddrEntry_Locked(const IpAddr &addr, WgKeypair *keypair) { assert(keypair->peer->IsPeerLocked()); + WgAddrEntry::IpPort addr_x = ConvertIpAddrToAddrX(addr); { WG_SCOPED_RWLOCK_SHARED(addr_entry_lookup_lock_); - if (keypair->addr_entry != NULL && keypair->addr_entry->addr_entry_id == addr_id) { + if (keypair->addr_entry != NULL && keypair->addr_entry->addr_entry_id == addr_x) { keypair->broadcast_short_key = 1; return; } @@ -286,10 +302,10 @@ void WgDevice::UpdateKeypairAddrEntry_Locked(uint64 addr_id, WgKeypair *keypair) if (keypair->addr_entry != NULL) EraseKeypairAddrEntry_Locked(keypair); - WgAddrEntry **aep = &addr_entry_lookup_[addr_id], *ae; + WgAddrEntry **aep = 
&addr_entry_lookup_[addr_x], *ae; if ((ae = *aep) == NULL) { - *aep = ae = new WgAddrEntry(addr_id); + *aep = ae = new WgAddrEntry(addr_x); } else { // Ensure we don't insert new things in this addr entry too often. if (ae->time_of_last_insertion + 1000 * 60 > low_resolution_timestamp_) @@ -1452,3 +1468,20 @@ bool WgKeypairDecryptPayload(uint8 *dst, size_t src_len, return memcmp_crypto(mac, dst + src_len, keypair->auth_tag_length) == 0; } } + +// A random siphash key that can be used for hashing so it gets harder to induce hash collisions. +struct RandomSiphashKey { + RandomSiphashKey() { OsGetRandomBytes((uint8*)&key, sizeof(key)); } + siphash_key_t key; +}; +static RandomSiphashKey random_siphash_key; + +size_t WgAddrEntry::IpPortHasher::operator()(const WgAddrEntry::IpPort &a) const { + uint32 xx = Read32(a.bytes + 16); + return siphash13_2u64(Read64(a.bytes) + xx, Read64(a.bytes + 8) + xx, &random_siphash_key.key); +} + +size_t WgPublicKeyHasher::operator()(const WgPublicKey&a) const { + return siphash13_4u64(a.u64[0], a.u64[1], a.u64[2], a.u64[3], &random_siphash_key.key); +} + diff --git a/wireguard_proto.h b/wireguard_proto.h index 15d1432..367951e 100644 --- a/wireguard_proto.h +++ b/wireguard_proto.h @@ -6,12 +6,18 @@ #include "netapi.h" #include "ipzip2/ipzip2.h" #include "tunsafe_config.h" +#include "tunsafe_endian.h" #include "tunsafe_threading.h" #include "ip_to_peer_map.h" #include #include #include #include + +#if WITH_BYTELL_HASHMAP +#include "third_party/flat_hash_map/bytell_hash_map.hpp" +#endif // WITH_BYTELL_HASHMAP + // Threading macros that enable locks only in MT builds #if WITH_WG_THREADING #define WG_SCOPED_LOCK(name) ScopedLock scoped_lock(&name) @@ -41,6 +47,13 @@ #define WG_IF_LOCKS_ENABLED_ELSE(expr, def) (def) #endif // WITH_WG_THREADING +// bytell hash is faster but more untested +#if WITH_BYTELL_HASHMAP +#define WG_HASHTABLE_IMPL ska::bytell_hash_map +#else +#define WG_HASHTABLE_IMPL std::unordered_map +#endif + enum ProtocolTimeouts 
{ COOKIE_SECRET_MAX_AGE_MS = 120000, COOKIE_SECRET_LATENCY_MS = 5000, @@ -235,13 +248,23 @@ private: }; struct WgAddrEntry { - // The id of the addr entry, so we can delete ourselves - uint64 addr_entry_id; + struct IpPort { + uint8 bytes[20]; - // Ensure there's at least 1 minute between we allow registering - // a new key in this table. This means that each key will have - // a life time of at least 3 minutes. - uint64 time_of_last_insertion; + friend bool operator==(const IpPort &a, const IpPort &b) { + uint64 rv = Read64(a.bytes) ^ Read64(b.bytes); + rv |= Read64(a.bytes + 8) ^ Read64(b.bytes + 8); + rv |= Read32(a.bytes + 16) ^ Read32(b.bytes + 16); + return (rv == 0); + } + }; + + struct IpPortHasher { + size_t operator()(const IpPort &a) const; + }; + + // The id of the addr entry, so we can delete ourselves + IpPort addr_entry_id; // This entry gets erased when there's no longer any key pointing at it. uint8 ref_count; @@ -249,13 +272,19 @@ struct WgAddrEntry { // Index of the next slot 0-2 where we'll insert the next key. uint8 next_slot; + // Ensure there's at least 1 minute between we allow registering + // a new key in this table. This means that each key will have + // a life time of at least 3 minutes. + uint64 time_of_last_insertion; + // The three keys. 
WgKeypair *keys[3]; - WgAddrEntry(uint64 addr_entry_id) : addr_entry_id(addr_entry_id), ref_count(0), next_slot(0) { + WgAddrEntry(const IpPort &addr_entry_id) + : addr_entry_id(addr_entry_id), ref_count(0), next_slot(0), time_of_last_insertion(0) { keys[0] = keys[1] = keys[2] = NULL; - time_of_last_insertion = 0x123456789123456; } + }; struct ScramblerSiphashKeys { @@ -271,10 +300,7 @@ union WgPublicKey { }; struct WgPublicKeyHasher { - size_t operator()(const WgPublicKey&a) const { - uint64 rv = a.u64[0] ^ a.u64[1] ^ a.u64[2] ^ a.u64[3]; - return (size_t)(rv ^ (rv >> 32)); - } + size_t operator()(const WgPublicKey&a) const; }; class WgDevice { @@ -314,14 +340,12 @@ public: bool CheckCookieMac2(Packet *packet); void CreateCookieMessage(MessageHandshakeCookie *dst, Packet *packet, uint32 remote_key_id); - void UpdateKeypairAddrEntry_Locked(uint64 addr_id, WgKeypair *keypair); void SecondLoop(uint64 now); IpToPeerMap &ip_to_peer_map() { return ip_to_peer_map_; } WgPeer *first_peer() { return peers_; } const uint8 *public_key() const { return s_pub_; } WgRateLimit *rate_limiter() { return &rate_limiter_; } - std::unordered_map &addr_entry_map() { return addr_entry_lookup_; } WgPacketCompressionVer01 *compression_header() { return &compression_header_; } bool is_private_key_initialized() { return is_private_key_initialized_; } @@ -333,7 +357,9 @@ public: private: std::pair *LookupPeerInKeyIdLookup(uint32 key_id); WgKeypair *LookupKeypairByKeyId(uint32 key_id); - WgKeypair *LookupKeypairInAddrEntryMap(uint64 addr, uint32 slot); + + void UpdateKeypairAddrEntry_Locked(const IpAddr &addr, WgKeypair *keypair); + WgKeypair *LookupKeypairInAddrEntryMap(const IpAddr &addr, uint32 slot); // Return the peer matching the |public_key| or NULL WgPeer *GetPeerFromPublicKey(const WgPublicKey &pubkey); // Create a cookie by inspecting the source address of the |packet| @@ -357,20 +383,26 @@ private: // For hooking Delegate *delegate_; + + // Keypair IDs are generated randomly by us 
so no point in wasting cycles on + // hashing the random value. + struct KeyIdHasher { + size_t operator()(uint32 v) const { return v; } + }; + // Lock that protects key_id_lookup_ WG_DECLARE_RWLOCK(key_id_lookup_lock_); // Mapping from key-id to either an active keypair (if keypair is non-NULL), // or to a handshake. - std::unordered_map > key_id_lookup_; + WG_HASHTABLE_IMPL, KeyIdHasher> key_id_lookup_; // Mapping from IPV4 IP/PORT to WgPeer*, so we can find the peer when a key id is // not explicitly included. - std::unordered_map addr_entry_lookup_; + WG_HASHTABLE_IMPL addr_entry_lookup_; WG_DECLARE_RWLOCK(addr_entry_lookup_lock_); // Mapping from peer id to peer. This may be accessed only from MT. - std::unordered_map peer_id_lookup_; - + WG_HASHTABLE_IMPL peer_id_lookup_; // Queue of things scheduled to run on the main thread. WG_DECLARE_LOCK(main_thread_scheduled_lock_); WgPeer *main_thread_scheduled_, **main_thread_scheduled_last_;