Branch data Line data Source code
1 : : #pragma once
2 : :
3 : : #include <mutable/util/fn.hpp>
4 : : #include <mutable/util/macro.hpp>
5 : : #include <algorithm>
6 : : #include <cassert>
7 : : #include <cmath>
8 : : #include <cstdint>
9 : : #include <cstdlib>
10 : : #include <functional>
11 : : #include <iomanip>
12 : : #include <iostream>
13 : : #include <memory>
14 : : #include <stdexcept>
15 : : #include <utility>
16 : :
17 : :
18 : : namespace m {
19 : :
20 : :
21 : : /*======================================================================================================================
22 : : * This class implements an open addressing hash map that uses reference counting for fast collision resolution.
23 : : *====================================================================================================================*/
24 : :
25 : : template<
26 : : typename Key,
27 : : typename Value,
28 : : typename Hash = std::hash<Key>,
29 : : typename KeyEqual = std::equal_to<Key>
30 : : >
31 : : struct RefCountingHashMap
32 : : {
33 : : using key_type = Key;
34 : : using mapped_type = Value;
35 : : using value_type = std::pair<const Key, Value>;
36 : : using hasher = Hash;
37 : : using key_equal = KeyEqual;
38 : : using pointer = value_type*;
39 : : using const_pointer = const value_type*;
40 : : using reference = value_type&;
41 : : using const_reference = const value_type&;
42 : : using size_type = std::size_t;
43 : : using difference_type = std::ptrdiff_t;
44 : :
45 : : private:
46 : 1409 : struct entry_type
47 : : {
48 : : /** Counts the length of the probe sequence. */
49 : 1409 : uint32_t probe_length = 0;
50 : : /** The value stored in this entry. */
51 : : value_type value;
52 : : };
53 : :
54 : : template<bool C>
55 : : struct the_iterator
56 : : {
57 : : friend struct RefCountingHashMap;
58 : :
59 : : static constexpr bool Is_Const = C;
60 : :
61 : : using map_type = std::conditional_t<Is_Const, const RefCountingHashMap, RefCountingHashMap>;
62 : : using bucket_type = std::conditional_t<Is_Const, const typename map_type::entry_type, typename map_type::entry_type>;
63 : : using pointer = std::conditional_t<Is_Const, const typename map_type::const_pointer, typename map_type::pointer>;
64 : : using reference = std::conditional_t<Is_Const, const typename map_type::const_reference, typename map_type::reference>;
65 : :
66 : : private:
67 : : map_type &map_;
68 : : bucket_type *bucket_ = nullptr;
69 : :
70 : : public:
71 : 192 : the_iterator(map_type &map, entry_type *bucket) : map_(map), bucket_(bucket) { }
72 : :
73 : : the_iterator & operator++() {
74 : : do {
75 : : ++bucket_;
76 : : } while (bucket_ != map_.table_ + map_.capacity_ and bucket_->probe_length == 0);
77 : : return *this;
78 : : }
79 : :
80 : : the_iterator operator++(int) {
81 : : the_iterator tmp = *this;
82 : : operator++();
83 : : return tmp;
84 : : }
85 : :
86 : : reference operator*() const { return bucket_->value; }
87 : 48 : pointer operator->() const { return &bucket_->value; }
88 : :
89 : 13 : bool operator==(the_iterator other) const { return this->bucket_ == other.bucket_; }
90 : 8 : bool operator!=(the_iterator other) const { return not operator==(other); }
91 : : };
92 : :
93 : : template<bool C>
94 : : struct the_bucket_iterator
95 : : {
96 : : friend struct RefCountingHashMap;
97 : :
98 : : static constexpr bool Is_Const = C;
99 : :
100 : : using map_type = std::conditional_t<Is_Const, const RefCountingHashMap, RefCountingHashMap>;
101 : : using bucket_type = std::conditional_t<Is_Const, const typename map_type::entry_type, typename map_type::entry_type>;
102 : : using pointer = std::conditional_t<Is_Const, const typename map_type::const_pointer, typename map_type::pointer>;
103 : : using reference = std::conditional_t<Is_Const, const typename map_type::const_reference, typename map_type::reference>;
104 : : using size_type = typename map_type::size_type;
105 : :
106 : : private:
107 : : map_type &map_;
108 : : size_type bucket_index_;
109 : 155 : size_type step_ = 0;
110 : : size_type max_step_;
111 : :
112 : : public:
113 : 155 : the_bucket_iterator(map_type &map, size_type bucket_index)
114 : 155 : : map_(map)
115 : 155 : , bucket_index_(bucket_index)
116 : : {
117 : 155 : max_step_ = bucket()->probe_length;
118 : 155 : }
119 : :
120 : 227 : bool has_next() const { return step_ < max_step_; }
121 : :
122 : 155 : the_bucket_iterator & operator++() {
123 : 155 : ++step_;
124 : 155 : bucket_index_ = map_.masked(bucket_index_ + step_);
125 : 155 : return *this;
126 : : }
127 : :
128 : : the_bucket_iterator & operator++(int) {
129 : : the_bucket_iterator tmp = *this;
130 : : operator++();
131 : : return tmp;
132 : : }
133 : :
134 : 72 : reference operator*() const { return map_.table_[bucket_index_].value; }
135 : 148 : pointer operator->() const { return &map_.table_[bucket_index_].value; }
136 : :
137 : 8 : size_type probe_length() const { return step_; }
138 : 16 : size_type probe_distance() const { return (step_ * (step_ + 1)) / 2; }
139 : 171 : size_type current_index() const { return bucket_index_; }
140 : 8 : size_type bucket_index() const { return map_.masked(current_index() - probe_distance()); }
141 : :
142 : : private:
143 : 155 : bucket_type * bucket() const { return map_.table_ + current_index(); }
144 : : };
145 : :
146 : : public:
147 : : using iterator = the_iterator<false>;
148 : : using const_iterator = the_iterator<true>;
149 : :
150 : : using bucket_iterator = the_bucket_iterator<false>;
151 : : using const_bucket_iterator = the_bucket_iterator<true>;
152 : :
153 : : private:
154 : : const hasher h_;
155 : : const key_equal eq_;
156 : :
157 : : /** A pointer to the beginning of the table. */
158 : 21 : entry_type *table_ = nullptr;
159 : : /** The total number of entries allocated in the table. */
160 : : size_type capacity_ = 0;
161 : : /** The number of occupied entries in the table. */
162 : 21 : size_type size_ = 0;
163 : : /** The maximum size before resizing. */
164 : : size_type watermark_high_;
165 : : /** The maximum load factor before resizing. */
166 : 21 : float max_load_factor_ = .85;
167 : :
168 : : public:
169 : 21 : RefCountingHashMap(size_type bucket_count,
170 : : const hasher &hash = hasher(),
171 : : const key_equal &equal = key_equal())
172 : 21 : : h_(hash)
173 : 21 : , eq_(equal)
174 : 21 : , capacity_(ceil_to_pow_2(bucket_count))
175 : : {
176 [ + + ]: 21 : if (bucket_count == 0)
177 [ + - ]: 1 : throw std::invalid_argument("bucket_count must not be zero");
178 : :
179 : : /* Allocate and initialize table. */
180 : 20 : table_ = allocate(capacity_);
181 : 20 : initialize();
182 : :
183 : : /* Compute high watermark. */
184 : 20 : watermark_high_ = capacity_ * max_load_factor_;
185 : 20 : }
186 : :
187 : 20 : ~RefCountingHashMap() {
188 [ + + ]: 1317 : for (auto p = table_, end = table_ + capacity_; p != end; ++p) {
189 [ + + ]: 1297 : if (p->probe_length != 0)
190 : 102 : p->~entry_type();
191 : 1297 : }
192 : 20 : free(table_);
193 : 20 : }
194 : :
195 : 31 : size_type capacity() const { return capacity_; }
196 : 53 : size_type size() const { return size_; }
197 : 551 : size_type mask() const { return capacity_ - size_type(1); }
198 : 551 : size_type masked(size_type index) const { return index & mask(); }
199 : 24 : float max_load_factor() const { return max_load_factor_; }
200 : 10 : void max_load_factor(float ml) {
201 : 10 : max_load_factor_ = std::clamp(ml, .0f, .99f);
202 : 10 : watermark_high_ = max_load_factor_ * capacity_;
203 : 10 : }
204 : 11 : size_type watermark_high() const { return watermark_high_; }
205 : :
206 : : iterator begin() { return iterator(*this, table_); }
207 : 18 : iterator end() { return iterator(*this, table_ + capacity()); }
208 : : const_iterator begin() const { return const_iterator(*this, table_); }
209 : : const_iterator end() const { return const_iterator(*this, table_ + capacity()); }
210 : : const_iterator cbegin() const { return begin(); }
211 : : const_iterator cend() const { return end(); }
212 : :
213 : 103 : iterator insert_with_duplicates(key_type key, mapped_type value) {
214 [ + + ]: 103 : if (size_ >= watermark_high_)
215 : 4 : resize(2 * capacity_);
216 : :
217 : 103 : const auto hash = h_(key);
218 : 103 : const size_type index = masked(hash);
219 : 103 : entry_type * const bucket = table_ + index;
220 : :
221 [ + + ]: 103 : if (bucket->probe_length == 0) [[likely]] { // bucket is free
222 : 30 : ++bucket->probe_length;
223 : 30 : new (&bucket->value) value_type(std::move(key), std::move(value));
224 : 30 : ++size_;
225 : 30 : return iterator(*this, bucket);
226 : : }
227 : :
228 : : /* Compute distance to end of probe sequence. */
229 : 73 : size_type distance = (bucket->probe_length * bucket->probe_length + bucket->probe_length) >> 1;
230 : 73 : M_insist(distance > 0, "the distance must not be 0, otherwise we would have run into the likely case above");
231 : :
232 : : /* Search next free slot in bucket's probe sequence. */
233 : 73 : entry_type *probe = table_ + masked(index + distance);
234 : 73 : M_insist(probe != bucket, "the probed slot must not be the original bucket as the distance is not 0 and always "
235 : : "less than capacity");
236 [ + + ]: 92 : while (probe->probe_length != 0) {
237 : 19 : ++bucket->probe_length;
238 : 19 : distance += bucket->probe_length;
239 : 19 : probe = table_ + masked(index + distance);
240 : : }
241 : :
242 : : /* Found free slot in bucket's probe sequence. Place element in slot and update probe length. */
243 : 73 : ++probe->probe_length; // set probe_length from 0 to 1
244 : 73 : ++bucket->probe_length;
245 : 73 : new (&probe->value) value_type(std::move(key), std::move(value));
246 : 73 : ++size_;
247 : 73 : return iterator(*this, probe);
248 : 103 : }
249 : :
250 : 63 : std::pair<iterator, bool> insert_without_duplicates(key_type key, mapped_type value) {
251 [ + + ]: 63 : if (size_ >= watermark_high_)
252 : 2 : resize(2 * capacity_);
253 : :
254 : 63 : const auto hash = h_(key);
255 : 63 : const size_type index = masked(hash);
256 : 63 : entry_type * const bucket = table_ + index;
257 : :
258 : 63 : entry_type *probe = bucket;
259 : 63 : size_type insertion_probe_length = 0;
260 : 63 : size_type insertion_probe_distance = 0;
261 [ + + ]: 92 : while (probe->probe_length != 0) {
262 [ + + ]: 35 : if (eq_(key, probe->value.first))
263 : 6 : return std::make_pair(iterator(*this, probe), false); // duplicate key
264 : 29 : ++insertion_probe_length;
265 : 29 : insertion_probe_distance += insertion_probe_length;
266 : 29 : probe = table_ + masked(index + insertion_probe_distance);
267 : : }
268 : :
269 : 57 : ++probe->probe_length; // set probe_length from 0 to 1
270 : 57 : bucket->probe_length = insertion_probe_length + 1;
271 : 57 : new (&probe->value) value_type(std::move(key), std::move(value));
272 : 57 : ++size_;
273 : 57 : return std::make_pair(iterator(*this, probe), true);
274 : 63 : }
275 : :
276 : 13 : iterator find(const key_type &key) {
277 : 13 : const auto hash = h_(key);
278 : 13 : const size_type index = masked(hash);
279 : 13 : entry_type * const bucket = table_ + index;
280 : :
281 : : #if 1
282 : : /* Search the probe sequence in natural order. */
283 : 13 : entry_type *probe = bucket;
284 : 13 : size_type lookup_probe_length = 0;
285 : 13 : size_type lookup_probe_distance = 0;
286 [ + + + + ]: 21 : while (probe->probe_length != 0 and lookup_probe_length < bucket->probe_length) {
287 [ + + ]: 16 : if (eq_(key, probe->value.first))
288 : 8 : return iterator(*this, probe);
289 : 8 : ++lookup_probe_length;
290 : 8 : lookup_probe_distance += lookup_probe_length;
291 : 8 : probe = table_ + masked(index + lookup_probe_distance);
292 : : }
293 : : #else
294 : : /* Search the probe sequence in inversed order, starting at the last element of this bucket. */
295 : : size_type lookup_probe_length = bucket->probe_length;
296 : : size_type lookup_probe_distance = (lookup_probe_length * (lookup_probe_length - 1)) >> 1;
297 : : while (lookup_probe_length != 0) {
298 : : entry_type *probe = table_ + masked(index + lookup_probe_distance);
299 : : if (eq_(key, probe->value.first))
300 : : return iterator(*this, probe);
301 : : --lookup_probe_length;
302 : : lookup_probe_distance -= lookup_probe_length;
303 : : }
304 : : #endif
305 : 5 : return end();
306 : 13 : }
307 : : const_iterator find(const key_type &key) const {
308 : : return const_iterator(*this, const_cast<RefCountingHashMap*>(this)->find(key).bucket_);
309 : : }
310 : :
311 : 80 : bucket_iterator bucket(const key_type &key) {
312 : 80 : const auto hash = h_(key);
313 : 80 : const size_type index = masked(hash);
314 : 80 : return bucket_iterator(*this, index);
315 : : }
316 : 75 : const_bucket_iterator bucket(const key_type &key) const {
317 : 75 : return const_bucket_iterator(*this, const_cast<RefCountingHashMap*>(this)->bucket(key).bucket_index_);
318 : : }
319 : :
320 : 4 : void for_all(const key_type &key, std::function<void(value_type&)> callback) {
321 [ + + ]: 16 : for (auto it = bucket(key); it.has_next(); ++it) {
322 [ + + ]: 12 : if (eq_(key, it->first))
323 : 8 : callback(*it);
324 : 12 : }
325 : 4 : }
326 : 75 : void for_all(const key_type &key, std::function<void(const value_type&)> callback) const {
327 [ + + ]: 211 : for (auto it = bucket(key); it.has_next(); ++it) {
328 [ + + ]: 136 : if (eq_(key, it->first))
329 : 64 : callback(*it);
330 : 136 : }
331 : 75 : }
332 : :
333 : 75 : size_type count(const key_type &key) const {
334 : 75 : size_type cnt = 0;
335 [ + - ]: 139 : for_all(key, [&cnt](auto) { ++cnt; });
336 : 75 : return cnt;
337 : 0 : }
338 : :
339 : : /** Rehash all elements. */
340 : : private:
341 : 12 : void rehash(std::size_t new_capacity) {
342 : 12 : M_insist((new_capacity & (new_capacity - 1)) == 0, "not a power of 2");
343 : 12 : M_insist(size_ <= watermark_high_, "there are more elements to rehash than the high watermark allows");
344 : :
345 : 12 : auto old_table = table_;
346 : 12 : table_ = allocate(new_capacity);
347 : 12 : auto old_capacity = capacity_;
348 : 12 : capacity_ = new_capacity;
349 : 12 : size_ = 0;
350 : 12 : initialize();
351 : :
352 [ + + ]: 124 : for (auto runner = old_table, end = old_table + old_capacity; runner != end; ++runner) {
353 [ + + ]: 112 : if (runner->probe_length) {
354 : 58 : auto &key_ref = const_cast<key_type&>(runner->value.first); // hack around the `const key_type`
355 [ # # # # ]: 58 : insert_with_duplicates(std::move(key_ref), std::move(runner->value.second));
356 : 58 : }
357 : 112 : }
358 : :
359 : 12 : free(old_table);
360 : 12 : }
361 : :
362 : : public:
363 : 1 : void rehash() { rehash(capacity_); }
364 : :
365 : 13 : void resize(std::size_t new_capacity) {
366 : 13 : new_capacity = std::max<decltype(new_capacity)>(new_capacity, std::ceil(size() / max_load_factor()));
367 : 13 : new_capacity = ceil_to_pow_2(new_capacity);
368 : :
369 [ + + ]: 13 : if (new_capacity != capacity_) {
370 : 11 : watermark_high_ = new_capacity * max_load_factor();
371 : 11 : M_insist(watermark_high() >= size());
372 : 11 : rehash(new_capacity);
373 : 11 : }
374 : 13 : }
375 : :
376 : : void shrink_to_fit() { resize(size()); }
377 : :
378 : : M_LCOV_EXCL_START
379 : : friend std::ostream & operator<<(std::ostream &out, const RefCountingHashMap &map) {
380 : : size_type log2 = log2_ceil(map.capacity());
381 : : size_type log10 = size_type(std::ceil(double(log2) / 3.322));
382 : : for (size_type i = 0; i != map.capacity_; ++i) {
383 : : auto &entry = map.table_[i];
384 : : out << '[' << std::setw(log10) << i << "]: probe length = " << std::setw(log10) << entry.probe_length;
385 : : if (entry.probe_length) {
386 : : out << ", value = (" << entry.value.first << ", " << entry.value.second << ')';
387 : : }
388 : : out << '\n';
389 : : }
390 : : return out;
391 : : }
392 : :
393 : : void dump(std::ostream &out) const { out << *this; out.flush(); }
394 : : void dump() const { dump(std::cerr); }
395 : : M_LCOV_EXCL_STOP
396 : :
397 : : private:
398 : 32 : static entry_type * allocate(size_type n, entry_type *hint = nullptr) {
399 : 32 : auto p = static_cast<entry_type*>(realloc(hint, n * sizeof(entry_type)));
400 [ + - ]: 32 : if (p == nullptr)
401 [ # # ]: 0 : throw std::runtime_error("allocation failed");
402 : 32 : return p;
403 : 0 : }
404 : :
405 : 32 : void initialize() {
406 [ + + ]: 1441 : for (auto runner = table_, end = table_ + capacity_; runner != end; ++runner)
407 : 1409 : new (runner) entry_type();
408 : 32 : }
409 : : };
410 : :
411 : : }
|