Branch data Line data Source code
1 : : #pragma once 2 : : 3 : : #include <mutable/util/Pool.hpp> 4 : : #include <unordered_map> 5 : : #include <util/Spn.hpp> 6 : : #include <vector> 7 : : 8 : : 9 : : namespace m { 10 : : 11 : : /** A wrapper class for an Spn to be used in the context of databases. */ 12 : : struct SpnWrapper 13 : : { 14 : : using Filter = std::unordered_map<unsigned, std::pair<Spn::SpnOperator, float>>; 15 : : using AttrFilter = std::unordered_map<ThreadSafePooledString, std::pair<Spn::SpnOperator, float>>; 16 : : 17 : : private: 18 : : Spn spn_; 19 : : std::unordered_map<ThreadSafePooledString, unsigned> attribute_to_id_; ///< a map from attribute to spn internal id 20 : : 21 : 19 : SpnWrapper(Spn spn, std::unordered_map<ThreadSafePooledString, unsigned> attribute_to_id) 22 : 19 : : spn_(std::move(spn)) 23 : 19 : , attribute_to_id_(std::move(attribute_to_id)) 24 : 19 : { } 25 : : 26 : 13 : Filter translate_filter(const AttrFilter &attr_filter) const { 27 : 13 : Filter filter; 28 [ + + + - : 25 : for (auto &elem : attr_filter) { filter.emplace(translate_attribute(elem.first), elem.second); } + - ] 29 : 13 : return filter; 30 [ + - ]: 13 : }; 31 : : 32 : 13 : unsigned translate_attribute(const ThreadSafePooledString &attribute) const { 33 : 13 : unsigned spn_id = 0; 34 [ + - ]: 13 : if (auto it = attribute_to_id_.find(attribute); it != attribute_to_id_.end()) { 35 : 13 : spn_id = it->second; 36 : 13 : } else { std::cerr << "could not find attribute: " << attribute << std::endl; } 37 : 13 : return spn_id; 38 : : } 39 : : 40 : : public: 41 : : SpnWrapper(const SpnWrapper&) = delete; 42 : : SpnWrapper(SpnWrapper&&) = default; 43 : : 44 : : /** Get the reference to the attribute to spn internal id mapping. */ 45 : 0 : const std::unordered_map<ThreadSafePooledString, unsigned> & get_attribute_to_id() const { return attribute_to_id_; } 46 : : 47 : : /** Learn an SPN over the given table. 48 : : * 49 : : * @param name_of_database the database 50 : : * @param name_of_table the table in the database 51 : : * @param leaf_types the types of a leaf for a non-primary key attribute 52 : : * @return the learned SPN 53 : : */ 54 : : static SpnWrapper learn_spn_table(const ThreadSafePooledString &name_of_database, 55 : : const ThreadSafePooledString &name_of_table, 56 : : std::vector<Spn::LeafType> leaf_types = decltype(leaf_types)()); 57 : : 58 : : /** Learn SPNs over the tables in the given database. 59 : : * 60 : : * @param name_of_database the database 61 : : * @param leaf_types the type of a leaf for a non-primary key attribute in the respective table 62 : : * @return the learned SPNs 63 : : */ 64 : : static std::unordered_map<ThreadSafePooledString, SpnWrapper*> 65 : : learn_spn_database( 66 : : const ThreadSafePooledString &name_of_database, 67 : : std::unordered_map<ThreadSafePooledString, std::vector<Spn::LeafType>> leaf_types = decltype(leaf_types)() 68 : : ); 69 : : 70 : : 71 : : /** returns the number of rows in the SPN. */ 72 : 0 : std::size_t num_rows() const { return spn_.num_rows(); } 73 : : 74 : : /** Compute the likelihood of the given filter predicates given by a map from attribute to the 75 : : * respective operator and value. The predicates in the map are seen as conjunctions. */ 76 [ + - ]: 12 : float likelihood(const AttrFilter &attr_filter) const { return spn_.likelihood(translate_filter(attr_filter)); }; 77 : : /** Compute the likelihood of the given filter predicates given by a map from spn internal id to the 78 : : * respective operator and value. The predicates in the map are seen as conjunctions. */ 79 : 0 : float likelihood(const Filter &filter) const { return spn_.likelihood(filter); }; 80 : : 81 : : /** Compute the upper bound probability for continuous domains. */ 82 : : float upper_bound(const AttrFilter &attr_filter) const { return spn_.upper_bound(translate_filter(attr_filter)); }; 83 : : /** Compute the upper bound probability for continuous domains. */ 84 : : float upper_bound(const Filter &filter) const { return spn_.upper_bound(filter); }; 85 : : 86 : : /** Compute the lower bound probability for continuous domains. */ 87 : : float lower_bound(const AttrFilter &attr_filter) const { return spn_.lower_bound(translate_filter(attr_filter)); }; 88 : : /** Compute the lower bound probability for continuous domains. */ 89 : : float lower_bound(const Filter &filter) const { return spn_.lower_bound(filter); }; 90 : : 91 : : /** Compute the expectation of the given attribute. */ 92 : 1 : float expectation(const ThreadSafePooledString &attribute, const AttrFilter &attr_filter) const { 93 [ + - ]: 1 : return spn_.expectation(translate_attribute(attribute), translate_filter(attr_filter)); 94 : 0 : } 95 : : /** Compute the expectation of the given attribute. */ 96 : : float expectation(unsigned attribute_id, const Filter &filter) const { 97 : : return spn_.expectation(attribute_id, filter); 98 : : }; 99 : : 100 : : /** Update the SPN with the given row. */ 101 : : void update_row(Eigen::VectorXf &old_row, Eigen::VectorXf &updated_row) { spn_.update_row(old_row, updated_row); }; 102 : : 103 : : /** Insert the given row into the SPN. */ 104 : : void insert_row(Eigen::VectorXf &row) { spn_.insert_row(row); }; 105 : : 106 : : /** Delete the given row from the SPN. */ 107 : : void delete_row(Eigen::VectorXf &row) { spn_.delete_row(row); }; 108 : : 109 : : /** Estimate the number of distinct values of the given attribute. */ 110 : : std::size_t estimate_number_distinct_values(const ThreadSafePooledString &attribute) const { 111 : : return spn_.estimate_number_distinct_values(translate_attribute(attribute)); 112 : : } 113 : : /** Estimate the number of distinct values of the given attribute. */ 114 : 0 : std::size_t estimate_number_distinct_values(unsigned attribute_id) const { 115 : 0 : return spn_.estimate_number_distinct_values(attribute_id); 116 : : }; 117 : : 118 : 5 : unsigned height() const { return spn_.height(); } 119 : 5 : unsigned breadth() const { return spn_.breadth(); } 120 : 5 : unsigned degree() const { return spn_.degree(); } 121 : : std::size_t memory_usage() const { return spn_.memory_usage(); } 122 : : void dump() const { spn_.dump(); }; 123 : : void dump(std::ostream &out) const { spn_.dump(out); }; 124 : : }; 125 : : 126 : : }