LCOV - code coverage report
Current view: top level - src/util - Spn.hpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 79 100 79.0 %
Date: 2025-03-25 01:19:55 Functions: 39 48 81.2 %
Branches: 12 16 75.0 %

           Branch data     Line data    Source code
       1                 :            : #pragma once
       2                 :            : 
       3                 :            : #include <Eigen/Core>
       4                 :            : #include <iostream>
       5                 :            : #include <map>
       6                 :            : #include <memory>
       7                 :            : #include "mutable/util/ADT.hpp"
       8                 :            : #include <set>
       9                 :            : #include <type_traits>
      10                 :            : #include <unordered_map>
      11                 :            : #include <vector>
      12                 :            : 
      13                 :            : 
      14                 :            : namespace m {
      15                 :            : 
      16                 :            : /**
      17                 :            :  * Tree structure for Sum Product Networks
      18                 :            :  */
      19                 :            : struct Spn
      20                 :            : {
      21                 :            :     public:
      22                 :            :     /** The different types of leaves for an attribute. `AUTO` if `learn_spn` should choose the leaf type by itself. */
      23                 :            :     enum LeafType {
      24                 :            :         AUTO,
      25                 :            :         DISCRETE,
      26                 :            :         CONTINUOUS
      27                 :            :     };
      28                 :            :     enum SpnOperator {
      29                 :            :         EQUAL,
      30                 :            :         LESS,
      31                 :            :         LESS_EQUAL,
      32                 :            :         GREATER,
      33                 :            :         GREATER_EQUAL,
      34                 :            :         IS_NULL,
      35                 :            :         EXPECTATION
      36                 :            :     };
      37                 :            :     enum EvalType {
      38                 :            :         APPROXIMATE,
      39                 :            :         UPPER_BOUND,
      40                 :            :         LOWER_BOUND
      41                 :            :     };
      42                 :            :     enum UpdateType {
      43                 :            :         INSERT,
      44                 :            :         DELETE
      45                 :            :     };
      46                 :            : 
      47                 :            :     using Filter = std::unordered_map<unsigned, std::pair<SpnOperator, float>>;
      48                 :            : 
      49                 :            :     private:
      50                 :            : 
      51                 :            :     struct LearningData
      52                 :            :     {
      53                 :            :         const Eigen::MatrixXf &data;
      54                 :            :         const Eigen::MatrixXf &normalized;
      55                 :            :         const Eigen::MatrixXi &null_matrix;
      56                 :            :         SmallBitset variables;
      57                 :            :         std::vector<LeafType> leaf_types;
      58                 :            : 
      59                 :        645 :         LearningData(
      60                 :            :             const Eigen::MatrixXf &data,
      61                 :            :             const Eigen::MatrixXf &normalized,
      62                 :            :             const Eigen::MatrixXi &null_matrix,
      63                 :            :             SmallBitset variables,
      64                 :            :             std::vector<LeafType> leaf_types
      65                 :            :         )
      66                 :        645 :             : data(data)
      67                 :        645 :             , normalized(normalized)
      68                 :        645 :             , null_matrix(null_matrix)
      69                 :        645 :             , variables(variables)
      70                 :        645 :             , leaf_types(std::move(leaf_types))
      71                 :        645 :         { }
      72                 :            :     };
      73                 :            : 
      74                 :            :     struct Node
      75                 :            :     {
      76                 :            :         std::size_t num_rows;
      77                 :            : 
      78                 :        646 :         explicit Node(std::size_t num_rows) : num_rows(num_rows) { }
      79                 :            : 
      80                 :        646 :         virtual ~Node() = default;
      81                 :            :         void dump() const;
      82                 :            :         void dump(std::ostream &out) const;
      83                 :            : 
      84                 :            :         /** Evaluate the SPN bottom up with a filter condition.
      85                 :            :          *
      86                 :            :          * @param filter    the filter condition
      87                 :            :          * @param eval_type for continuous leaves to test bin accuracy
      88                 :            :          * @return          a pair <conditional expectation, likelihood> (cond. expectation undefined if only likelihood
      89                 :            :          *                  is evaluated)
      90                 :            :          */
      91                 :            :         virtual std::pair<float, float> evaluate(const Filter &filter, unsigned leaf_id, EvalType eval_type) const = 0;
      92                 :            : 
      93                 :            :         virtual void update(Eigen::VectorXf &row, SmallBitset variables, UpdateType update_type) = 0;
      94                 :            : 
      95                 :            :         virtual std::size_t estimate_number_distinct_values(unsigned id) const = 0;
      96                 :            : 
      97                 :            :         virtual unsigned height() const = 0;
      98                 :            :         virtual unsigned breadth() const = 0;
      99                 :            :         virtual unsigned degree() const = 0;
     100                 :            :         virtual std::size_t memory_usage() const = 0;
     101                 :            : 
     102                 :            :         virtual void print(std::ostream &out, std::size_t num_tabs) const = 0;
     103                 :            :     };
     104                 :            : 
     105                 :            :     struct Sum : Node
     106                 :            :     {
     107                 :            :         struct ChildWithWeight
     108                 :            :         {
     109                 :            :             std::unique_ptr<Node> child;
     110                 :            :             float weight; ///< weight of a child of a sum node
     111                 :            :             Eigen::VectorXf centroid; ///< centroid of this child according to kmeans cluster
     112                 :            : 
     113                 :        254 :             ChildWithWeight(std::unique_ptr<Node> child, float weight, Eigen::VectorXf centroid)
     114                 :        254 :                 : child(std::move(child))
     115                 :        254 :                 , weight(weight)
     116                 :        254 :                 , centroid(std::move(centroid))
     117                 :        254 :             { }
     118                 :            :         };
     119                 :            : 
     120                 :            :         std::vector<std::unique_ptr<ChildWithWeight>> children;
     121                 :            : 
     122                 :         99 :         Sum(std::vector<std::unique_ptr<ChildWithWeight>> children, std::size_t num_rows)
     123                 :         99 :             : Node(num_rows)
     124                 :         99 :             , children(std::move(children))
     125                 :        198 :         { }
     126                 :            : 
     127                 :            :         std::pair<float, float> evaluate(const Filter &filter, unsigned leaf_id, EvalType eval_type) const override;
     128                 :            : 
     129                 :            :         void update(Eigen::VectorXf &row, SmallBitset variables, UpdateType update_type) override;
     130                 :            : 
     131                 :            :         std::size_t estimate_number_distinct_values(unsigned id) const override;
     132                 :            : 
     133                 :          1 :         unsigned height() const override {
     134                 :          1 :             unsigned max_height = 0;
     135         [ +  + ]:          3 :             for (auto &child : children) { max_height = std::max(max_height, child->child->height()); }
     136                 :          1 :             return 1 + max_height;
     137                 :            :         }
     138                 :          1 :         unsigned breadth() const override {
     139                 :          1 :             unsigned breadth = 0;
     140         [ +  + ]:          3 :             for (auto &child : children) { breadth += child->child->breadth(); }
     141                 :          1 :             return breadth;
     142                 :            :         }
     143                 :          1 :         unsigned degree() const override {
     144                 :          1 :             unsigned max_degree = children.size();
     145         [ +  + ]:          3 :             for (auto &child : children) { max_degree = std::max(max_degree, child->child->degree()); }
     146                 :          1 :             return max_degree;
     147                 :            :         }
     148                 :          0 :         std::size_t memory_usage() const override {
     149                 :          0 :             std::size_t memory_used = 0;
     150                 :          0 :             memory_used += sizeof *this + children.size() * sizeof(decltype(children)::value_type);
     151         [ #  # ]:          0 :             for (auto &child : children) {
     152                 :          0 :                 memory_used += child->child->memory_usage();
     153                 :            :             }
     154                 :          0 :             return memory_used;
     155                 :            :         }
     156                 :            : 
     157                 :            :         void print(std::ostream &out, std::size_t num_tabs) const override;
     158                 :            :     };
     159                 :            : 
     160                 :            :     struct Product : Node
     161                 :            :     {
     162                 :            :         struct ChildWithVariables
     163                 :            :         {
     164                 :            :             std::unique_ptr<Node> child; ///< a child of the Product node
     165                 :            :             SmallBitset variables; ///< the set of variables(attributes), that are in this child
     166                 :            : 
     167                 :        373 :             ChildWithVariables(std::unique_ptr<Node> child, SmallBitset variables)
     168                 :        373 :                 : child(std::move(child))
     169                 :        373 :                 , variables(variables)
     170                 :        373 :             { }
     171                 :            :         };
     172                 :            : 
     173                 :            :         std::vector<std::unique_ptr<ChildWithVariables>> children;
     174                 :            : 
     175                 :        186 :         Product(std::vector<std::unique_ptr<ChildWithVariables>> children, std::size_t num_rows)
     176                 :        186 :             : Node(num_rows)
     177                 :        186 :             , children(std::move(children))
     178                 :        372 :         { }
     179                 :            : 
     180                 :            :         std::pair<float, float> evaluate(const Filter &filter, unsigned leaf_id, EvalType eval_type) const override;
     181                 :            : 
     182                 :            :         void update(Eigen::VectorXf &row, SmallBitset variables, UpdateType update_type) override;
     183                 :            : 
     184                 :            :         std::size_t estimate_number_distinct_values(unsigned id) const override;
     185                 :            : 
     186                 :          4 :         unsigned height() const override {
     187                 :          4 :             unsigned max_height = 0;
     188         [ +  + ]:         13 :             for (auto &child : children) { max_height = std::max(max_height, child->child->height()); }
     189                 :          4 :             return 1 + max_height;
     190                 :            :         }
     191                 :          4 :         unsigned breadth() const override {
     192                 :          4 :             unsigned breadth = 0;
     193         [ +  + ]:         13 :             for (auto &child : children) { breadth += child->child->breadth(); }
     194                 :          4 :             return breadth;
     195                 :            :         }
     196                 :          4 :         unsigned degree() const override {
     197                 :          4 :             unsigned max_degree = children.size();
     198         [ +  + ]:         13 :             for (auto &child : children) { max_degree = std::max(max_degree, child->child->degree()); }
     199                 :          4 :             return max_degree;
     200                 :            :         }
     201                 :          0 :         std::size_t memory_usage() const override {
     202                 :          0 :             std::size_t memory_used = 0;
     203                 :          0 :             memory_used += sizeof *this + children.size() * sizeof(decltype(children)::value_type);
     204         [ #  # ]:          0 :             for (auto &child : children) {
     205                 :          0 :                 memory_used += child->variables.size() * sizeof child->variables;
     206                 :          0 :                 memory_used += child->child->memory_usage();
     207                 :            :             }
     208                 :          0 :             return memory_used;
     209                 :            :         }
     210                 :            : 
     211                 :            :         void print(std::ostream &out, std::size_t num_tabs) const override;
     212                 :            :     };
     213                 :            : 
     214                 :            :     struct DiscreteLeaf : Node
     215                 :            :     {
     216                 :            :         struct Bin
     217                 :            :         {
     218                 :            :             float value; ///< the value of this bin
     219                 :            :             ///> the cumulative probability of this and all predecessor bins; in the range [0;1]
     220                 :            :             float cumulative_probability;
     221                 :            : 
     222                 :       1425 :             Bin(float value, float cumulative_probability)
     223                 :       1425 :                 : value(value)
     224                 :       1425 :                 , cumulative_probability(cumulative_probability)
     225                 :       1425 :             { }
     226                 :            : 
     227                 :            :             bool operator<(Bin &other) const { return this->value < other.value; }
     228                 :       3066 :             bool operator<(float other) const { return this->value < other; }
     229                 :            :         };
     230                 :            : 
     231                 :            :         std::vector<Bin> bins; ///< bins of this leaf
     232                 :            :         float null_probability; ///< the probability of null values in this leaf
     233                 :            : 
     234                 :        186 :         DiscreteLeaf(std::vector<Bin> bins, float null_probability, std::size_t num_rows)
     235                 :        186 :             : Node(num_rows)
     236                 :        186 :             , bins(std::move(bins))
     237                 :        186 :             , null_probability(null_probability)
     238                 :        372 :         { }
     239                 :            : 
     240                 :            :         std::pair<float, float> evaluate(const Filter &bin_value, unsigned leaf_id, EvalType eval_type) const override;
     241                 :            : 
     242                 :            :         void update(Eigen::VectorXf &row, SmallBitset variables, UpdateType update_type) override;
     243                 :            : 
     244                 :            :         std::size_t estimate_number_distinct_values(unsigned id) const override;
     245                 :            : 
     246                 :         11 :         unsigned height() const override { return 0; }
     247                 :         11 :         unsigned breadth() const override { return 1; }
     248                 :         11 :         unsigned degree() const override { return 0; }
     249                 :          0 :         std::size_t memory_usage() const override {
     250                 :          0 :             return sizeof *this + bins.size() * sizeof(decltype(bins)::value_type);
     251                 :            :         }
     252                 :            :         void print(std::ostream &out, std::size_t num_tabs) const override;
     253                 :            :     };
     254                 :            : 
     255                 :            :     struct ContinuousLeaf : Node
     256                 :            :     {
     257                 :            :         struct Bin
     258                 :            :         {
     259                 :            :             float upper_bound; ///< the upper bound of this bin
     260                 :            :             ///> the cumulative probability of this and all predecessor bins; in the range [0;1]
     261                 :            :             float cumulative_probability;
     262                 :            : 
     263                 :        784 :             Bin(float upper_bound, float cumulative_probability)
     264                 :        784 :                 : upper_bound(upper_bound)
     265                 :        784 :                 , cumulative_probability(cumulative_probability)
     266                 :        784 :             { }
     267                 :            : 
     268                 :            :             bool operator<(Bin &other) const { return this->upper_bound < other.upper_bound; }
     269                 :       3192 :             bool operator<(float other) const { return this->upper_bound < other; }
     270                 :            :         };
     271                 :            : 
     272                 :            :         std::vector<Bin> bins; ///< bins of this leaf
     273                 :            :         float lower_bound; ///< the lower bound of the first bin
     274                 :            :         float lower_bound_probability; ///< probability of the lower_bound
     275                 :            :         float null_probability; ///< the probability of null values in this leaf
     276                 :            : 
     277                 :        175 :         ContinuousLeaf(
     278                 :            :             std::vector<Bin> bins,
     279                 :            :             float lower_bound,
     280                 :            :             float lower_bound_probability,
     281                 :            :             float null_probability,
     282                 :            :             std::size_t num_rows
     283                 :            :         )
     284                 :        175 :             : Node(num_rows)
     285                 :        175 :             , bins(std::move(bins))
     286                 :        175 :             , lower_bound(lower_bound)
     287                 :        175 :             , lower_bound_probability(lower_bound_probability)
     288                 :        175 :             , null_probability(null_probability)
     289                 :        350 :         { }
     290                 :            : 
     291                 :            :         std::pair<float, float> evaluate(const Filter &filter, unsigned leaf_id, EvalType eval_type) const override;
     292                 :            : 
     293                 :            :         void update(Eigen::VectorXf &row, SmallBitset variables, UpdateType update_type) override;
     294                 :            : 
     295                 :            :         std::size_t estimate_number_distinct_values(unsigned id) const override;
     296                 :            : 
     297                 :          0 :         unsigned height() const override { return 0; }
     298                 :          0 :         unsigned breadth() const override { return 1; }
     299                 :          0 :         unsigned degree() const override { return 0; }
     300                 :          0 :         std::size_t memory_usage() const override {
     301                 :          0 :             return sizeof *this + bins.size() * sizeof(decltype(bins)::value_type);
     302                 :            :         }
     303                 :            :         void print(std::ostream &out, std::size_t num_tabs) const override;
     304                 :            :     };
     305                 :            : 
     306                 :            :     std::size_t num_rows_;
     307                 :            :     std::unique_ptr<Node> root_;
     308                 :            : 
     309                 :         19 :     Spn(std::size_t num_rows, std::unique_ptr<Node> root) : num_rows_(num_rows), root_(std::move(root)) { }
     310                 :            : 
     311                 :            :     public:
     312                 :            : 
     313                 :            :     /** returns the number of rows in the SPN. */
     314                 :          0 :     std::size_t num_rows() const { return num_rows_; }
     315                 :            : 
     316                 :            :     /*==================================================================================================================
     317                 :            :      * Learning
     318                 :            :      *================================================================================================================*/
     319                 :            : 
     320                 :            :     private:
     321                 :            : 
     322                 :            :     /** Create a product node by splitting all columns */
     323                 :            :     static std::unique_ptr<Spn::Product> create_product_min_slice(LearningData &learning_data);
     324                 :            : 
     325                 :            :     /** Create a product node with the given candidates (vertical clustering) */
     326                 :            :     static std::unique_ptr<Product> create_product_rdc(
     327                 :            :         LearningData &learning_data,
     328                 :            :         std::vector<SmallBitset> &column_candidates,
     329                 :            :         std::vector<SmallBitset> &variable_candidates
     330                 :            :     );
     331                 :            : 
     332                 :            :     /** Create a sum node by clustering the rows */
     333                 :            :     static std::unique_ptr<Spn::Sum> create_sum(LearningData &learning_data);
     334                 :            : 
     335                 :            :     /** Recursively learns the nodes of an SPN. */
     336                 :            :     static std::unique_ptr<Node> learn_node(LearningData &learning_data);
     337                 :            : 
     338                 :            :     public:
     339                 :            : 
     340                 :            :     /** Learn an SPN over the given data.
     341                 :            :      *
     342                 :            :      * @param data              the data
     343                 :            :      * @param null_matrix       the NULL values of the data as a matrix
     344                 :            :      * @param attribute_to_id   a map from the attributes (random variables) to internal id
     345                 :            :      * @param leaf_types        the types of a leaf for a non-primary key attribute
     346                 :            :      * @return                  the learned SPN
     347                 :            :      */
     348                 :            :     static Spn learn_spn(Eigen::MatrixXf &data, Eigen::MatrixXi &null_matrix, std::vector<LeafType> &leaf_types);
     349                 :            : 
     350                 :            :     /*==================================================================================================================
     351                 :            :      * Inference
     352                 :            :      *================================================================================================================*/
     353                 :            : 
     354                 :            :     private:
     355                 :            : 
     356                 :            :     /** Update the SPN from the top down and adjust weights of sum nodes and the distributions on leaves.
     357                 :            :      *
     358                 :            :      * @param row the row to update in the SPN
     359                 :            :      * @param update_type the type of update (insert or delete)
     360                 :            :      */
     361                 :            :     void update(Eigen::VectorXf &row, UpdateType update_type);
     362                 :            : 
     363                 :            :     public:
     364                 :            : 
     365                 :            :     /** Compute the likelihood of the given filter predicates given by a map from attribute to the
     366                 :            :      * respective operator and value. The predicates in the map are seen as conjunctions. */
     367                 :            :     float likelihood(const Filter &filter) const;
     368                 :            : 
     369                 :            :     /** Compute the upper bound probability for continuous domains. */
     370                 :            :     float upper_bound(const Filter &filter) const;
     371                 :            : 
     372                 :            :     /** Compute the lower bound probability for continuous domains. */
     373                 :            :     float lower_bound(const Filter &filter) const;
     374                 :            : 
     375                 :            :     /** Compute the expectation of the given attribute. */
     376                 :            :     float expectation(unsigned attribute_id, const Filter &filter) const;
     377                 :            : 
     378                 :            :     /** Update the SPN with the given row. */
     379                 :            :     void update_row(Eigen::VectorXf &old_row, Eigen::VectorXf &updated_row);
     380                 :            : 
     381                 :            :     /** Insert the given row into the SPN. */
     382                 :            :     void insert_row(Eigen::VectorXf &row);
     383                 :            : 
     384                 :            :     /** Delete the given row from the SPN. */
     385                 :            :     void delete_row(Eigen::VectorXf &row);
     386                 :            : 
     387                 :            :     /** Estimate the number of distinct values of the given attribute. */
     388                 :            :     std::size_t estimate_number_distinct_values(unsigned attribute_id) const;
     389                 :            : 
     390                 :          5 :     unsigned height() const { return root_->height(); }
     391                 :          5 :     unsigned breadth() const { return root_->breadth(); }
     392                 :          5 :     unsigned degree() const { return root_->degree(); }
     393                 :            :     std::size_t memory_usage() const {
     394                 :            :         std::size_t memory_used = 0;
     395                 :            :         memory_used += sizeof root_;
     396                 :            :         memory_used += root_->memory_usage();
     397                 :            :         return memory_used;
     398                 :            :     }
     399                 :            : 
     400                 :            :     void dump() const;
     401                 :            :     void dump(std::ostream &out) const;
     402                 :            : };
     403                 :            : 
     404                 :            : }

Generated by: LCOV version 1.16