LCOV - code coverage report
Current view: top level - src/util/container - RefCountingHashMap.hpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 173 176 98.3 %
Date: 2025-03-25 01:19:55 Functions: 46 70 65.7 %
Branches: 43 52 82.7 %

           Branch data     Line data    Source code
       1                 :            : #pragma once
       2                 :            : 
       3                 :            : #include <mutable/util/fn.hpp>
       4                 :            : #include <mutable/util/macro.hpp>
       5                 :            : #include <algorithm>
       6                 :            : #include <cassert>
       7                 :            : #include <cmath>
       8                 :            : #include <cstdint>
       9                 :            : #include <cstdlib>
      10                 :            : #include <functional>
      11                 :            : #include <iomanip>
      12                 :            : #include <iostream>
      13                 :            : #include <memory>
      14                 :            : #include <stdexcept>
      15                 :            : #include <utility>
      16                 :            : 
      17                 :            : 
      18                 :            : namespace m {
      19                 :            : 
      20                 :            : 
      21                 :            : /*======================================================================================================================
      22                 :            :  * This class implements an open addressing hash map that uses reference counting for fast collision resolution.
      23                 :            :  *====================================================================================================================*/
      24                 :            : 
      25                 :            : template<
      26                 :            :     typename Key,
      27                 :            :     typename Value,
      28                 :            :     typename Hash = std::hash<Key>,
      29                 :            :     typename KeyEqual = std::equal_to<Key>
      30                 :            : >
      31                 :            : struct RefCountingHashMap
      32                 :            : {
      33                 :            :     using key_type          = Key;
      34                 :            :     using mapped_type       = Value;
      35                 :            :     using value_type        = std::pair<const Key, Value>;
      36                 :            :     using hasher            = Hash;
      37                 :            :     using key_equal         = KeyEqual;
      38                 :            :     using pointer           = value_type*;
      39                 :            :     using const_pointer     = const value_type*;
      40                 :            :     using reference         = value_type&;
      41                 :            :     using const_reference   = const value_type&;
      42                 :            :     using size_type         = std::size_t;
      43                 :            :     using difference_type   = std::ptrdiff_t;
      44                 :            : 
      45                 :            :     private:
      46                 :       1409 :     struct entry_type
      47                 :            :     {
      48                 :            :         /** Counts the length of the probe sequence. */
      49                 :       1409 :         uint32_t probe_length = 0;
      50                 :            :         /** The value stored in this entry. */
      51                 :            :         value_type value;
      52                 :            :     };
      53                 :            : 
      54                 :            :     template<bool C>
      55                 :            :     struct the_iterator
      56                 :            :     {
      57                 :            :         friend struct RefCountingHashMap;
      58                 :            : 
      59                 :            :         static constexpr bool Is_Const = C;
      60                 :            : 
      61                 :            :         using map_type = std::conditional_t<Is_Const, const RefCountingHashMap, RefCountingHashMap>;
      62                 :            :         using bucket_type = std::conditional_t<Is_Const, const typename map_type::entry_type, typename map_type::entry_type>;
      63                 :            :         using pointer = std::conditional_t<Is_Const, const typename map_type::const_pointer, typename map_type::pointer>;
      64                 :            :         using reference = std::conditional_t<Is_Const, const typename map_type::const_reference, typename map_type::reference>;
      65                 :            : 
      66                 :            :         private:
      67                 :            :         map_type &map_;
      68                 :            :         bucket_type *bucket_ = nullptr;
      69                 :            : 
      70                 :            :         public:
      71                 :        192 :         the_iterator(map_type &map, entry_type *bucket) : map_(map), bucket_(bucket) { }
      72                 :            : 
      73                 :            :         the_iterator & operator++() {
      74                 :            :             do {
      75                 :            :                 ++bucket_;
      76                 :            :             } while (bucket_ != map_.table_ + map_.capacity_ and bucket_->probe_length == 0);
      77                 :            :             return *this;
      78                 :            :         }
      79                 :            : 
      80                 :            :         the_iterator operator++(int) {
      81                 :            :             the_iterator tmp = *this;
      82                 :            :             operator++();
      83                 :            :             return tmp;
      84                 :            :         }
      85                 :            : 
      86                 :            :         reference operator*() const { return bucket_->value; }
      87                 :         48 :         pointer operator->() const { return &bucket_->value; }
      88                 :            : 
      89                 :         13 :         bool operator==(the_iterator other) const { return this->bucket_ == other.bucket_; }
      90                 :          8 :         bool operator!=(the_iterator other) const { return not operator==(other); }
      91                 :            :     };
      92                 :            : 
      93                 :            :     template<bool C>
      94                 :            :     struct the_bucket_iterator
      95                 :            :     {
      96                 :            :         friend struct RefCountingHashMap;
      97                 :            : 
      98                 :            :         static constexpr bool Is_Const = C;
      99                 :            : 
     100                 :            :         using map_type = std::conditional_t<Is_Const, const RefCountingHashMap, RefCountingHashMap>;
     101                 :            :         using bucket_type = std::conditional_t<Is_Const, const typename map_type::entry_type, typename map_type::entry_type>;
     102                 :            :         using pointer = std::conditional_t<Is_Const, const typename map_type::const_pointer, typename map_type::pointer>;
     103                 :            :         using reference = std::conditional_t<Is_Const, const typename map_type::const_reference, typename map_type::reference>;
     104                 :            :         using size_type = typename map_type::size_type;
     105                 :            : 
     106                 :            :         private:
     107                 :            :         map_type &map_;
     108                 :            :         size_type bucket_index_;
     109                 :        155 :         size_type step_ = 0;
     110                 :            :         size_type max_step_;
     111                 :            : 
     112                 :            :         public:
     113                 :        155 :         the_bucket_iterator(map_type &map, size_type bucket_index)
     114                 :        155 :             : map_(map)
     115                 :        155 :             , bucket_index_(bucket_index)
     116                 :            :         {
     117                 :        155 :             max_step_ = bucket()->probe_length;
     118                 :        155 :         }
     119                 :            : 
     120                 :        227 :         bool has_next() const { return step_ < max_step_; }
     121                 :            : 
     122                 :        155 :         the_bucket_iterator & operator++() {
     123                 :        155 :             ++step_;
     124                 :        155 :             bucket_index_ = map_.masked(bucket_index_ + step_);
     125                 :        155 :             return *this;
     126                 :            :         }
     127                 :            : 
     128                 :            :         the_bucket_iterator & operator++(int) {
     129                 :            :             the_bucket_iterator tmp = *this;
     130                 :            :             operator++();
     131                 :            :             return tmp;
     132                 :            :         }
     133                 :            : 
     134                 :         72 :         reference operator*() const { return map_.table_[bucket_index_].value; }
     135                 :        148 :         pointer operator->() const { return &map_.table_[bucket_index_].value; }
     136                 :            : 
     137                 :          8 :         size_type probe_length() const { return step_; }
     138                 :         16 :         size_type probe_distance() const { return (step_ * (step_ + 1)) / 2; }
     139                 :        171 :         size_type current_index() const { return bucket_index_; }
     140                 :          8 :         size_type bucket_index() const { return map_.masked(current_index() - probe_distance()); }
     141                 :            : 
     142                 :            :         private:
     143                 :        155 :         bucket_type * bucket() const { return map_.table_ + current_index(); }
     144                 :            :     };
     145                 :            : 
     146                 :            :     public:
     147                 :            :     using iterator = the_iterator<false>;
     148                 :            :     using const_iterator = the_iterator<true>;
     149                 :            : 
     150                 :            :     using bucket_iterator = the_bucket_iterator<false>;
     151                 :            :     using const_bucket_iterator = the_bucket_iterator<true>;
     152                 :            : 
     153                 :            :     private:
     154                 :            :     const hasher h_;
     155                 :            :     const key_equal eq_;
     156                 :            : 
     157                 :            :     /** A pointer to the beginning of the table. */
     158                 :         21 :     entry_type *table_ = nullptr;
     159                 :            :     /** The total number of entries allocated in the table. */
     160                 :            :     size_type capacity_ = 0;
     161                 :            :     /** The number of occupied entries in the table. */
     162                 :         21 :     size_type size_ = 0;
     163                 :            :     /** The maximum size before resizing. */
     164                 :            :     size_type watermark_high_;
     165                 :            :     /** The maximum load factor before resizing. */
     166                 :         21 :     float max_load_factor_ = .85;
     167                 :            : 
     168                 :            :     public:
     169                 :         21 :     RefCountingHashMap(size_type bucket_count,
     170                 :            :                        const hasher &hash = hasher(),
     171                 :            :                        const key_equal &equal = key_equal())
     172                 :         21 :         : h_(hash)
     173                 :         21 :         , eq_(equal)
     174                 :         21 :         , capacity_(ceil_to_pow_2(bucket_count))
     175                 :            :     {
     176         [ +  + ]:         21 :         if (bucket_count == 0)
     177         [ +  - ]:          1 :             throw std::invalid_argument("bucket_count must not be zero");
     178                 :            : 
     179                 :            :         /* Allocate and initialize table. */
     180                 :         20 :         table_ = allocate(capacity_);
     181                 :         20 :         initialize();
     182                 :            : 
     183                 :            :         /* Compute high watermark. */
     184                 :         20 :         watermark_high_ = capacity_ * max_load_factor_;
     185                 :         20 :     }
     186                 :            : 
     187                 :         20 :     ~RefCountingHashMap() {
     188         [ +  + ]:       1317 :         for (auto p = table_, end = table_ + capacity_; p != end; ++p) {
     189         [ +  + ]:       1297 :             if (p->probe_length != 0)
     190                 :        102 :                 p->~entry_type();
     191                 :       1297 :         }
     192                 :         20 :         free(table_);
     193                 :         20 :     }
     194                 :            : 
     195                 :         31 :     size_type capacity() const { return capacity_; }
     196                 :         53 :     size_type size() const { return size_; }
     197                 :        551 :     size_type mask() const { return capacity_ - size_type(1); }
     198                 :        551 :     size_type masked(size_type index) const { return index & mask(); }
     199                 :         24 :     float max_load_factor() const { return max_load_factor_; }
     200                 :         10 :     void max_load_factor(float ml) {
     201                 :         10 :         max_load_factor_ = std::clamp(ml, .0f, .99f);
     202                 :         10 :         watermark_high_ = max_load_factor_ * capacity_;
     203                 :         10 :     }
     204                 :         11 :     size_type watermark_high() const { return watermark_high_; }
     205                 :            : 
     206                 :            :     iterator begin() { return iterator(*this, table_); }
     207                 :         18 :     iterator end()   { return iterator(*this, table_ + capacity()); }
     208                 :            :     const_iterator begin() const { return const_iterator(*this, table_); }
     209                 :            :     const_iterator end()   const { return const_iterator(*this, table_ + capacity()); }
     210                 :            :     const_iterator cbegin() const { return begin(); }
     211                 :            :     const_iterator cend()   const { return end(); }
     212                 :            : 
     213                 :        103 :     iterator insert_with_duplicates(key_type key, mapped_type value) {
     214         [ +  + ]:        103 :         if (size_ >= watermark_high_)
     215                 :          4 :             resize(2 * capacity_);
     216                 :            : 
     217                 :        103 :         const auto hash = h_(key);
     218                 :        103 :         const size_type index = masked(hash);
     219                 :        103 :         entry_type * const bucket = table_ + index;
     220                 :            : 
     221         [ +  + ]:        103 :         if (bucket->probe_length == 0) [[likely]] { // bucket is free
     222                 :         30 :             ++bucket->probe_length;
     223                 :         30 :             new (&bucket->value) value_type(std::move(key), std::move(value));
     224                 :         30 :             ++size_;
     225                 :         30 :             return iterator(*this, bucket);
     226                 :            :         }
     227                 :            : 
     228                 :            :         /* Compute distance to end of probe sequence. */
     229                 :         73 :         size_type distance = (bucket->probe_length * bucket->probe_length + bucket->probe_length) >> 1;
     230                 :         73 :         M_insist(distance > 0, "the distance must not be 0, otherwise we would have run into the likely case above");
     231                 :            : 
     232                 :            :         /* Search next free slot in bucket's probe sequence. */
     233                 :         73 :         entry_type *probe = table_ + masked(index + distance);
     234                 :         73 :         M_insist(probe != bucket, "the probed slot must not be the original bucket as the distance is not 0 and always "
     235                 :            :                                  "less than capacity");
     236         [ +  + ]:         92 :         while (probe->probe_length != 0) {
     237                 :         19 :             ++bucket->probe_length;
     238                 :         19 :             distance += bucket->probe_length;
     239                 :         19 :             probe = table_ + masked(index + distance);
     240                 :            :         }
     241                 :            : 
     242                 :            :         /* Found free slot in bucket's probe sequence.  Place element in slot and update probe length. */
     243                 :         73 :         ++probe->probe_length; // set probe_length from 0 to 1
     244                 :         73 :         ++bucket->probe_length;
     245                 :         73 :         new (&probe->value) value_type(std::move(key), std::move(value));
     246                 :         73 :         ++size_;
     247                 :         73 :         return iterator(*this, probe);
     248                 :        103 :     }
     249                 :            : 
     250                 :         63 :     std::pair<iterator, bool> insert_without_duplicates(key_type key, mapped_type value) {
     251         [ +  + ]:         63 :         if (size_ >= watermark_high_)
     252                 :          2 :             resize(2 * capacity_);
     253                 :            : 
     254                 :         63 :         const auto hash = h_(key);
     255                 :         63 :         const size_type index = masked(hash);
     256                 :         63 :         entry_type * const bucket = table_ + index;
     257                 :            : 
     258                 :         63 :         entry_type *probe = bucket;
     259                 :         63 :         size_type insertion_probe_length = 0;
     260                 :         63 :         size_type insertion_probe_distance = 0;
     261         [ +  + ]:         92 :         while (probe->probe_length != 0) {
     262         [ +  + ]:         35 :             if (eq_(key, probe->value.first))
     263                 :          6 :                 return std::make_pair(iterator(*this, probe), false); // duplicate key
     264                 :         29 :             ++insertion_probe_length;
     265                 :         29 :             insertion_probe_distance += insertion_probe_length;
     266                 :         29 :             probe = table_ + masked(index + insertion_probe_distance);
     267                 :            :         }
     268                 :            : 
     269                 :         57 :         ++probe->probe_length; // set probe_length from 0 to 1
     270                 :         57 :         bucket->probe_length = insertion_probe_length + 1;
     271                 :         57 :         new (&probe->value) value_type(std::move(key), std::move(value));
     272                 :         57 :         ++size_;
     273                 :         57 :         return std::make_pair(iterator(*this, probe), true);
     274                 :         63 :     }
     275                 :            : 
     276                 :         13 :     iterator find(const key_type &key) {
     277                 :         13 :         const auto hash = h_(key);
     278                 :         13 :         const size_type index = masked(hash);
     279                 :         13 :         entry_type * const bucket = table_ + index;
     280                 :            : 
     281                 :            : #if 1
     282                 :            :         /* Search the probe sequence in natural order. */
     283                 :         13 :         entry_type *probe = bucket;
     284                 :         13 :         size_type lookup_probe_length = 0;
     285                 :         13 :         size_type lookup_probe_distance = 0;
     286   [ +  +  +  + ]:         21 :         while (probe->probe_length != 0 and lookup_probe_length < bucket->probe_length) {
     287         [ +  + ]:         16 :             if (eq_(key, probe->value.first))
     288                 :          8 :                 return iterator(*this, probe);
     289                 :          8 :             ++lookup_probe_length;
     290                 :          8 :             lookup_probe_distance += lookup_probe_length;
     291                 :          8 :             probe = table_ + masked(index + lookup_probe_distance);
     292                 :            :         }
     293                 :            : #else
     294                 :            :         /* Search the probe sequence in inversed order, starting at the last element of this bucket. */
     295                 :            :         size_type lookup_probe_length = bucket->probe_length;
     296                 :            :         size_type lookup_probe_distance = (lookup_probe_length * (lookup_probe_length - 1)) >> 1;
     297                 :            :         while (lookup_probe_length != 0) {
     298                 :            :             entry_type *probe = table_ + masked(index + lookup_probe_distance);
     299                 :            :             if (eq_(key, probe->value.first))
     300                 :            :                 return iterator(*this, probe);
     301                 :            :             --lookup_probe_length;
     302                 :            :             lookup_probe_distance -= lookup_probe_length;
     303                 :            :         }
     304                 :            : #endif
     305                 :          5 :         return end();
     306                 :         13 :     }
     307                 :            :     const_iterator find(const key_type &key) const {
     308                 :            :         return const_iterator(*this, const_cast<RefCountingHashMap*>(this)->find(key).bucket_);
     309                 :            :     }
     310                 :            : 
     311                 :         80 :     bucket_iterator bucket(const key_type &key) {
     312                 :         80 :         const auto hash = h_(key);
     313                 :         80 :         const size_type index = masked(hash);
     314                 :         80 :         return bucket_iterator(*this, index);
     315                 :            :     }
     316                 :         75 :     const_bucket_iterator bucket(const key_type &key) const {
     317                 :         75 :         return const_bucket_iterator(*this, const_cast<RefCountingHashMap*>(this)->bucket(key).bucket_index_);
     318                 :            :     }
     319                 :            : 
     320                 :          4 :     void for_all(const key_type &key, std::function<void(value_type&)> callback) {
     321         [ +  + ]:         16 :         for (auto it = bucket(key); it.has_next(); ++it) {
     322         [ +  + ]:         12 :             if (eq_(key, it->first))
     323                 :          8 :                 callback(*it);
     324                 :         12 :         }
     325                 :          4 :     }
     326                 :         75 :     void for_all(const key_type &key, std::function<void(const value_type&)> callback) const {
     327         [ +  + ]:        211 :         for (auto it = bucket(key); it.has_next(); ++it) {
     328         [ +  + ]:        136 :             if (eq_(key, it->first))
     329                 :         64 :                 callback(*it);
     330                 :        136 :         }
     331                 :         75 :     }
     332                 :            : 
     333                 :         75 :     size_type count(const key_type &key) const {
     334                 :         75 :         size_type cnt = 0;
     335         [ +  - ]:        139 :         for_all(key, [&cnt](auto) { ++cnt; });
     336                 :         75 :         return cnt;
     337                 :          0 :     }
     338                 :            : 
     339                 :            :     /** Rehash all elements. */
     340                 :            :     private:
     341                 :         12 :     void rehash(std::size_t new_capacity) {
     342                 :         12 :         M_insist((new_capacity & (new_capacity - 1)) == 0, "not a power of 2");
     343                 :         12 :         M_insist(size_ <= watermark_high_, "there are more elements to rehash than the high watermark allows");
     344                 :            : 
     345                 :         12 :         auto old_table = table_;
     346                 :         12 :         table_ = allocate(new_capacity);
     347                 :         12 :         auto old_capacity = capacity_;
     348                 :         12 :         capacity_ = new_capacity;
     349                 :         12 :         size_ = 0;
     350                 :         12 :         initialize();
     351                 :            : 
     352         [ +  + ]:        124 :         for (auto runner = old_table, end = old_table + old_capacity; runner != end; ++runner) {
     353         [ +  + ]:        112 :             if (runner->probe_length) {
     354                 :         58 :                 auto &key_ref = const_cast<key_type&>(runner->value.first); // hack around the `const key_type`
     355   [ #  #  #  # ]:         58 :                 insert_with_duplicates(std::move(key_ref), std::move(runner->value.second));
     356                 :         58 :             }
     357                 :        112 :         }
     358                 :            : 
     359                 :         12 :         free(old_table);
     360                 :         12 :     }
     361                 :            : 
     362                 :            :     public:
     363                 :          1 :     void rehash() { rehash(capacity_); }
     364                 :            : 
     365                 :         13 :     void resize(std::size_t new_capacity) {
     366                 :         13 :         new_capacity = std::max<decltype(new_capacity)>(new_capacity, std::ceil(size() / max_load_factor()));
     367                 :         13 :         new_capacity = ceil_to_pow_2(new_capacity);
     368                 :            : 
     369         [ +  + ]:         13 :         if (new_capacity != capacity_) {
     370                 :         11 :             watermark_high_ = new_capacity * max_load_factor();
     371                 :         11 :             M_insist(watermark_high() >= size());
     372                 :         11 :             rehash(new_capacity);
     373                 :         11 :         }
     374                 :         13 :     }
     375                 :            : 
     376                 :            :     void shrink_to_fit() { resize(size()); }
     377                 :            : 
     378                 :            : M_LCOV_EXCL_START
     379                 :            :     friend std::ostream & operator<<(std::ostream &out, const RefCountingHashMap &map) {
     380                 :            :         size_type log2 = log2_ceil(map.capacity());
     381                 :            :         size_type log10 = size_type(std::ceil(double(log2) / 3.322));
     382                 :            :         for (size_type i = 0; i != map.capacity_; ++i) {
     383                 :            :             auto &entry = map.table_[i];
     384                 :            :             out << '[' << std::setw(log10) << i << "]: probe length = " << std::setw(log10) << entry.probe_length;
     385                 :            :             if (entry.probe_length) {
     386                 :            :                 out << ", value = (" << entry.value.first << ", " << entry.value.second << ')';
     387                 :            :             }
     388                 :            :             out << '\n';
     389                 :            :         }
     390                 :            :         return out;
     391                 :            :     }
     392                 :            : 
     393                 :            :     void dump(std::ostream &out) const { out << *this; out.flush(); }
     394                 :            :     void dump() const { dump(std::cerr); }
     395                 :            : M_LCOV_EXCL_STOP
     396                 :            : 
     397                 :            :     private:
     398                 :         32 :     static entry_type * allocate(size_type n, entry_type *hint = nullptr) {
     399                 :         32 :         auto p = static_cast<entry_type*>(realloc(hint, n * sizeof(entry_type)));
     400         [ +  - ]:         32 :         if (p == nullptr)
     401         [ #  # ]:          0 :             throw std::runtime_error("allocation failed");
     402                 :         32 :         return p;
     403                 :          0 :     }
     404                 :            : 
     405                 :         32 :     void initialize() {
     406         [ +  + ]:       1441 :         for (auto runner = table_, end = table_ + capacity_; runner != end; ++runner)
     407                 :       1409 :             new (runner) entry_type();
     408                 :         32 :     }
     409                 :            : };
     410                 :            : 
     411                 :            : }

Generated by: LCOV version 1.16