Branch data Line data Source code
1 : : #include <mutable/util/LinearModel.hpp> 2 : : 3 : : #include "util/eigen_meta.hpp" 4 : : #include <mutable/util/Diagnostic.hpp> 5 : : #include <mutable/util/Diagnostic.hpp> 6 : : 7 : : 8 : : using namespace m; 9 : : 10 : : 11 : : //====================================================================================================================== 12 : : // LinearModel Methods 13 : : //====================================================================================================================== 14 : : 15 : 4 : LinearModel::LinearModel(const Eigen::MatrixXd &X, const Eigen::VectorXd &y) 16 : 4 : : coefficients_(regression_linear_closed_form(X, y)), num_features_(coefficients_.rows()) {} 17 : : 18 : 0 : LinearModel::LinearModel(const Eigen::MatrixXd &X, const Eigen::VectorXd &y, 19 : : const std::function<Eigen::MatrixXd(Eigen::MatrixXd)> &transform_function) 20 : : { 21 : : /* apply transformations */ 22 [ # # # # ]: 0 : auto X_trans = transform_function(X); 23 : : 24 : 0 : num_features_ = X.cols(); 25 [ # # ]: 0 : transformation_ = transform_function; 26 [ # # ]: 0 : coefficients_ = regression_linear_closed_form(X_trans, y); 27 : 0 : } 28 : : 29 : 6 : double LinearModel::predict_target(const Eigen::RowVectorXd& feature_vector) const 30 : : { 31 [ - + ]: 6 : M_insist(feature_vector.rows() == 1 and num_features_ - 1 == feature_vector.cols()); 32 : : /* The linear regression algorithm requires a column of only ones to properly 33 : : * train the y-intercept of the model. This column is added here. */ 34 : 6 : Eigen::RowVectorXd concat_vector(1, feature_vector.cols() + 1); 35 [ + - + - ]: 6 : concat_vector << 1, feature_vector; 36 [ + - ]: 6 : M_insist(num_features_ == concat_vector.cols()); 37 : : 38 [ - + ]: 7 : if (bool(transformation_)) { 39 : : /* apply transformations */ 40 [ # # # # : 0 : concat_vector = transformation_(concat_vector); # # ] 41 : 0 : } 42 : : 43 [ + - + - ]: 6 : return concat_vector * coefficients_; 44 : 6 : } 45 : : 46 : : M_LCOV_EXCL_START 47 : : std::ostream & m::operator<<(std::ostream &out, const LinearModel &linear_model) 48 : : { 49 : : out << "LinearModel "; 50 : : if (bool(linear_model.transformation_)) { 51 : : out << "with transformations "; 52 : : } 53 : : out << "\n[\n" << linear_model.coefficients_ << "\n]\n"; 54 : : return out; 55 : : } 56 : : 57 : : void LinearModel::dump(std::ostream &out) const 58 : : { 59 : : out << *this; 60 : : out.flush(); 61 : : } 62 : : 63 : : void LinearModel::dump() const { dump(std::cerr); } 64 : : M_LCOV_EXCL_STOP