Branch data Line data Source code
1 : : #include "SpnWrapper.hpp"
2 : :
3 : : #include <mutable/mutable.hpp>
4 : : #include <mutable/util/Diagnostic.hpp>
5 : :
6 : :
7 : : using namespace m;
8 : : using namespace Eigen;
9 : :
10 : :
11 : 19 : SpnWrapper SpnWrapper::learn_spn_table(const ThreadSafePooledString &name_of_database,
12 : : const ThreadSafePooledString &name_of_table,
13 : : std::vector<Spn::LeafType> leaf_types)
14 : : {
15 : 19 : auto &C = Catalog::Get();
16 : 19 : auto &db = C.get_database(name_of_database);
17 : 19 : auto &table = db.get_table(name_of_table);
18 : :
19 : 19 : leaf_types.resize(table.num_attrs(), Spn::AUTO); // pad with AUTO
20 : :
21 : : /* use CartesianProductEstimator to query data since there currently are no SPNs on the data. */
22 [ + - + - : 19 : auto old_estimator = db.cardinality_estimator(C.create_cardinality_estimator(C.pool("CartesianProduct"), db.name));
+ - ]
23 : :
24 [ + - ]: 19 : std::size_t num_columns = table.num_attrs();
25 [ + - + - ]: 19 : std::size_t num_rows = table.store().num_rows();
26 : :
27 [ + - ]: 19 : Diagnostic diag(false, std::cout, std::cerr);
28 : :
29 [ + - ]: 19 : auto primary_key = table.primary_key();
30 : 19 : std::vector<std::size_t> primary_key_id;
31 [ + + ]: 38 : for (auto &elem : primary_key) {
32 [ + - ]: 19 : primary_key_id.push_back(elem.get().id);
33 : : }
34 : :
35 [ + - ]: 19 : MatrixXf data(num_rows, num_columns - primary_key_id.size());
36 [ + - + - ]: 19 : MatrixXi null_matrix = MatrixXi::Zero(data.rows(), data.cols());
37 : 19 : std::unordered_map<ThreadSafePooledString, unsigned> attribute_to_id;
38 : 1 :
39 : 19 : std::size_t primary_key_count = 0;
40 : :
41 [ + - + - : 19 : const std::string table_name = *table.name();
+ - ]
42 [ + - + - : 19 : auto stmt = statement_from_string(diag, "SELECT * FROM " + table_name + ";");
+ - ]
43 [ + - ]: 19 : std::unique_ptr<ast::SelectStmt> select_stmt(dynamic_cast<ast::SelectStmt*>(stmt.release()));
44 : :
45 : : /* fill the data matrix with the given table */
46 [ + + ]: 89 : for (std::size_t current_column = 0; current_column < num_columns; current_column++) {
47 [ + - ]: 70 : auto lower_bound = std::lower_bound(primary_key_id.begin(), primary_key_id.end(), current_column);
48 [ + + + + ]: 70 : if (lower_bound != primary_key_id.end() && *lower_bound == current_column) {
49 : 19 : primary_key_count++;
50 : 19 : continue;
51 : : }
52 : :
53 [ + - + - : 51 : auto attribute = table.schema()[current_column].id.name;
+ - ]
54 [ + - ]: 51 : attribute_to_id.emplace(attribute, current_column - primary_key_count);
55 : :
56 [ + - ]: 51 : auto &type = table.at(current_column).type;
57 : 52 : std::size_t current_row = 0;
58 : :
59 [ + - - + ]: 51 : if (type->is_float()) {
60 [ # # ]: 0 : if (leaf_types[current_column - primary_key_count] == Spn::AUTO) {
61 : 0 : leaf_types[current_column - primary_key_count] = Spn::CONTINUOUS;
62 : 0 : }
63 [ # # ]: 0 : auto callback_data = std::make_unique<CallbackOperator>([&](const Schema &S, const Tuple &T) {
64 [ # # ]: 0 : if (T.is_null(current_column)) {
65 : 0 : null_matrix(current_row, current_column - primary_key_count) = 1;
66 : 0 : data(current_row, current_column - primary_key_count) = 0;
67 : 0 : } else {
68 : 0 : data(current_row, current_column - primary_key_count) = T.get(current_column).as_f();
69 : : }
70 : 0 : current_row++;
71 : 0 : });
72 [ # # ]: 0 : execute_query(diag, *select_stmt, std::move(callback_data));
73 : 0 : }
74 : 1 :
75 [ + - - + ]: 51 : if (type->is_double()) {
76 [ # # ]: 0 : if (leaf_types[current_column - primary_key_count] == Spn::AUTO) {
77 : 0 : leaf_types[current_column - primary_key_count] = Spn::CONTINUOUS;
78 : 0 : }
79 [ # # ]: 0 : auto callback_data = std::make_unique<CallbackOperator>([&](const Schema &S, const Tuple &T) {
80 [ # # ]: 0 : if (T.is_null(current_column)) {
81 : 0 : null_matrix(current_row, current_column - primary_key_count) = 1;
82 : 0 : data(current_row, current_column - primary_key_count) = 0;
83 : 0 : } else {
84 : 0 : data(current_row, current_column - primary_key_count) = float(T.get(current_column).as_d());
85 : : }
86 : 0 : current_row++;
87 : 0 : });
88 [ # # ]: 0 : execute_query(diag, *select_stmt, std::move(callback_data));
89 : 0 : }
90 : :
91 [ + - + - ]: 51 : if (type->is_integral()) {
92 [ + + ]: 51 : if (leaf_types[current_column - primary_key_count] == Spn::AUTO) {
93 : 9 : leaf_types[current_column - primary_key_count] = Spn::DISCRETE;
94 : 9 : }
95 [ + - ]: 4277 : auto callback_data = std::make_unique<CallbackOperator>([&](const Schema &S, const Tuple &T) {
96 [ - + ]: 4226 : if (T.is_null(current_column)) {
97 : 0 : null_matrix(current_row, current_column - primary_key_count) = 1;
98 : 0 : data(current_row, current_column - primary_key_count) = 0;
99 : 0 : } else {
100 : 4226 : data(current_row, current_column - primary_key_count) = float(T.get(current_column).as_i());
101 : : }
102 : 4226 : current_row++;
103 : 4226 : });
104 [ - + ]: 51 : execute_query(diag, *select_stmt, std::move(callback_data));
105 : 51 : }
106 : :
107 [ + - - + ]: 51 : if (type->is_character_sequence()) {
108 [ # # ]: 0 : if (leaf_types[current_column - primary_key_count] == Spn::AUTO) {
109 : 0 : leaf_types[current_column - primary_key_count] = Spn::CONTINUOUS;
110 : 0 : }
111 [ # # ]: 0 : auto callback_data = std::make_unique<CallbackOperator>([&](const Schema &S, const Tuple &T) {
112 [ # # ]: 0 : if (T.is_null(current_column)) {
113 : 0 : null_matrix(current_row, current_column - primary_key_count) = 1;
114 : 0 : data(current_row, current_column - primary_key_count) = 0;
115 : 0 : } else {
116 : 0 : auto v_pointer = T.get(current_column).as_p();
117 : 0 : const char* value = static_cast<const char*>(v_pointer);
118 : 0 : data(current_row, current_column - primary_key_count) = float(std::hash<const char*>{}(value));
119 : : //data(current_row, current_column-primary_key_count) = 0;
120 : : }
121 : 0 : current_row++;
122 : 0 : });
123 [ # # ]: 0 : execute_query(diag, *select_stmt, std::move(callback_data));
124 : 0 : }
125 : 51 : }
126 : :
127 [ + - ]: 19 : db.cardinality_estimator(std::move(old_estimator));
128 : :
129 [ + - - + ]: 19 : return SpnWrapper(Spn::learn_spn(data, null_matrix, leaf_types), std::move(attribute_to_id));
130 : 19 : }
131 : :
132 : : std::unordered_map<ThreadSafePooledString, SpnWrapper*>
133 : 0 : SpnWrapper::learn_spn_database(const ThreadSafePooledString &name_of_database,
134 : : std::unordered_map<ThreadSafePooledString, std::vector<Spn::LeafType>> leaf_types)
135 : : {
136 : 0 : auto &C = Catalog::Get();
137 : 0 : auto &db = C.get_database(name_of_database);
138 : :
139 : 0 : std::unordered_map<ThreadSafePooledString, SpnWrapper*> spns;
140 : :
141 [ # # # # : 0 : for (auto table_it = db.begin_tables(); table_it != db.end_tables(); table_it++) {
# # ]
142 [ # # # # ]: 0 : spns.emplace(
143 : 0 : table_it->first,
144 [ # # ]: 0 : new SpnWrapper(
145 [ # # # # ]: 0 : learn_spn_table(name_of_database, table_it->first, std::move(leaf_types[table_it->first]))
146 : : )
147 : : );
148 : 0 : }
149 : :
150 : 0 : return spns;
151 [ # # ]: 0 : }
|