LCOV - code coverage report
Current view: top level - src/catalog - SpnWrapper.hpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 21 27 77.8 %
Date: 2025-03-25 01:19:55 Functions: 9 13 69.2 %
Branches: 8 14 57.1 %

           Branch data     Line data    Source code
       1                 :            : #pragma once
       2                 :            : 
       3                 :            : #include <mutable/util/Pool.hpp>
       4                 :            : #include <unordered_map>
       5                 :            : #include <util/Spn.hpp>
       6                 :            : #include <vector>
       7                 :            : 
       8                 :            : 
       9                 :            : namespace m {
      10                 :            : 
      11                 :            : /** A wrapper class for an Spn to be used in the context of databases. */
      12                 :            : struct SpnWrapper
      13                 :            : {
      14                 :            :     using Filter = std::unordered_map<unsigned, std::pair<Spn::SpnOperator, float>>;
      15                 :            :     using AttrFilter = std::unordered_map<ThreadSafePooledString, std::pair<Spn::SpnOperator, float>>;
      16                 :            : 
      17                 :            :     private:
      18                 :            :     Spn spn_;
      19                 :            :     std::unordered_map<ThreadSafePooledString, unsigned> attribute_to_id_; ///< a map from attribute to spn internal id
      20                 :            : 
      21                 :         19 :     SpnWrapper(Spn spn, std::unordered_map<ThreadSafePooledString, unsigned> attribute_to_id)
      22                 :         19 :         : spn_(std::move(spn))
      23                 :         19 :         , attribute_to_id_(std::move(attribute_to_id))
      24                 :         19 :     { }
      25                 :            : 
      26                 :         13 :     Filter translate_filter(const AttrFilter &attr_filter) const {
      27                 :         13 :         Filter filter;
      28   [ +  +  +  -  :         25 :         for (auto &elem : attr_filter) { filter.emplace(translate_attribute(elem.first), elem.second); }
                   +  - ]
      29                 :         13 :         return filter;
      30         [ +  - ]:         13 :     };
      31                 :            : 
      32                 :         13 :     unsigned translate_attribute(const ThreadSafePooledString &attribute) const {
      33                 :         13 :         unsigned spn_id = 0;
      34         [ +  - ]:         13 :         if (auto it = attribute_to_id_.find(attribute); it != attribute_to_id_.end()) {
      35                 :         13 :             spn_id = it->second;
      36                 :         13 :         } else { std::cerr << "could not find attribute: " << attribute << std::endl; }
      37                 :         13 :         return spn_id;
      38                 :            :     }
      39                 :            : 
      40                 :            :     public:
      41                 :            :     SpnWrapper(const SpnWrapper&) = delete;
      42                 :            :     SpnWrapper(SpnWrapper&&) = default;
      43                 :            : 
      44                 :            :     /** Get the reference to the attribute to spn internal id mapping. */
      45                 :          0 :     const std::unordered_map<ThreadSafePooledString, unsigned> & get_attribute_to_id() const { return attribute_to_id_; }
      46                 :            : 
      47                 :            :     /** Learn an SPN over the given table.
      48                 :            :      *
      49                 :            :      * @param name_of_database  the database
      50                 :            :      * @param name_of_table     the table in the database
      51                 :            :      * @param leaf_types        the types of a leaf for a non-primary key attribute
      52                 :            :      * @return                  the learned SPN
      53                 :            :      */
      54                 :            :     static SpnWrapper learn_spn_table(const ThreadSafePooledString &name_of_database,
      55                 :            :                                       const ThreadSafePooledString &name_of_table,
      56                 :            :                                       std::vector<Spn::LeafType> leaf_types = decltype(leaf_types)());
      57                 :            : 
      58                 :            :     /** Learn SPNs over the tables in the given database.
      59                 :            :      *
      60                 :            :      * @param name_of_database  the database
      61                 :            :      * @param leaf_types        the type of a leaf for a non-primary key attribute in the respective table
      62                 :            :      * @return                  the learned SPNs
      63                 :            :      */
      64                 :            :     static std::unordered_map<ThreadSafePooledString, SpnWrapper*>
      65                 :            :     learn_spn_database(
      66                 :            :         const ThreadSafePooledString &name_of_database,
      67                 :            :         std::unordered_map<ThreadSafePooledString, std::vector<Spn::LeafType>> leaf_types = decltype(leaf_types)()
      68                 :            :     );
      69                 :            : 
      70                 :            : 
      71                 :            :     /** returns the number of rows in the SPN. */
      72                 :          0 :     std::size_t num_rows() const { return spn_.num_rows(); }
      73                 :            : 
      74                 :            :     /** Compute the likelihood of the given filter predicates given by a map from attribute to the
      75                 :            :      * respective operator and value. The predicates in the map are seen as conjunctions. */
      76         [ +  - ]:         12 :     float likelihood(const AttrFilter &attr_filter) const { return spn_.likelihood(translate_filter(attr_filter)); };
      77                 :            :     /** Compute the likelihood of the given filter predicates given by a map from spn internal id to the
      78                 :            :      * respective operator and value. The predicates in the map are seen as conjunctions. */
      79                 :          0 :     float likelihood(const Filter &filter) const { return spn_.likelihood(filter); };
      80                 :            : 
      81                 :            :     /** Compute the upper bound probability for continuous domains. */
      82                 :            :     float upper_bound(const AttrFilter &attr_filter) const { return spn_.upper_bound(translate_filter(attr_filter)); };
      83                 :            :     /** Compute the upper bound probability for continuous domains. */
      84                 :            :     float upper_bound(const Filter &filter) const { return spn_.upper_bound(filter); };
      85                 :            : 
      86                 :            :     /** Compute the lower bound probability for continuous domains. */
      87                 :            :     float lower_bound(const AttrFilter &attr_filter) const { return spn_.lower_bound(translate_filter(attr_filter)); };
      88                 :            :     /** Compute the lower bound probability for continuous domains. */
      89                 :            :     float lower_bound(const Filter &filter) const { return spn_.lower_bound(filter); };
      90                 :            : 
      91                 :            :     /** Compute the expectation of the given attribute. */
      92                 :          1 :     float expectation(const ThreadSafePooledString &attribute, const AttrFilter &attr_filter) const {
      93         [ +  - ]:          1 :         return spn_.expectation(translate_attribute(attribute), translate_filter(attr_filter));
      94                 :          0 :     }
      95                 :            :     /** Compute the expectation of the given attribute. */
      96                 :            :     float expectation(unsigned attribute_id, const Filter &filter) const {
      97                 :            :         return spn_.expectation(attribute_id, filter);
      98                 :            :     };
      99                 :            : 
     100                 :            :     /** Update the SPN with the given row. */
     101                 :            :     void update_row(Eigen::VectorXf &old_row, Eigen::VectorXf &updated_row) { spn_.update_row(old_row, updated_row); };
     102                 :            : 
     103                 :            :     /** Insert the given row into the SPN. */
     104                 :            :     void insert_row(Eigen::VectorXf &row) { spn_.insert_row(row); };
     105                 :            : 
     106                 :            :     /** Delete the given row from the SPN. */
     107                 :            :     void delete_row(Eigen::VectorXf &row) { spn_.delete_row(row); };
     108                 :            : 
     109                 :            :     /** Estimate the number of distinct values of the given attribute. */
     110                 :            :     std::size_t estimate_number_distinct_values(const ThreadSafePooledString &attribute) const {
     111                 :            :         return spn_.estimate_number_distinct_values(translate_attribute(attribute));
     112                 :            :     }
     113                 :            :     /** Estimate the number of distinct values of the given attribute. */
     114                 :          0 :     std::size_t estimate_number_distinct_values(unsigned attribute_id) const {
     115                 :          0 :         return spn_.estimate_number_distinct_values(attribute_id);
     116                 :            :     };
     117                 :            : 
     118                 :          5 :     unsigned height() const { return spn_.height(); }
     119                 :          5 :     unsigned breadth() const { return spn_.breadth(); }
     120                 :          5 :     unsigned degree() const { return spn_.degree(); }
     121                 :            :     std::size_t memory_usage() const { return spn_.memory_usage(); }
     122                 :            :     void dump() const { spn_.dump(); };
     123                 :            :     void dump(std::ostream &out) const { spn_.dump(out); };
     124                 :            : };
     125                 :            : 
     126                 :            : }

Generated by: LCOV version 1.16