Branch data Line data Source code
1 : : #include "util/Kmeans.hpp" 2 : : 3 : : #include <cfloat> 4 : : #include <cmath> 5 : : #include <limits> 6 : : #include <mutable/util/macro.hpp> 7 : : #include <random> 8 : : 9 : : 10 : : using namespace m; 11 : : using namespace Eigen; 12 : : 13 : : 14 : : static constexpr unsigned KMEANS_MAX_ITERATIONS = 100; 15 : : 16 : : 17 : 258 : MatrixRXf m::kmeans_plus_plus(const MatrixXf &data, unsigned k) 18 : : { 19 : 258 : std::mt19937 g(0); 20 : : /** Chosen initial centroids for *k*-means. */ 21 : 258 : MatrixRXf centroids(k, data.cols()); 22 : : /** Stores the squared distance of each data point to its nearest centroid. Initialized to maximum representable 23 : : * value, as no centroid exists yet. Used as weighted probability for choosing the next centroid. */ 24 [ + - ]: 258 : VectorXf weights(data.rows()); 25 [ + - ]: 258 : weights.setConstant(std::numeric_limits<float>::max()); 26 : : 27 [ + + ]: 1073 : for (unsigned i = 0; i != k; ++i) { 28 : : /*----- Pick next centroid. -----*/ 29 [ + - + - : 815 : std::discrete_distribution<unsigned> dist(weights.data(), weights.data() + weights.size()); + - ] 30 [ + - + - : 815 : centroids.row(i) = data.row(dist(g)); + - + - ] 31 : : /*----- Update distances. -----*/ 32 [ + - + - : 815 : weights = weights.cwiseMin((data.rowwise() - centroids.row(i)).rowwise().squaredNorm()); // cell-wise minimum + - + - + - + - + - ] 33 : 815 : } 34 : : 35 : 258 : return centroids; 36 [ + - ]: 258 : } 37 : : 38 : 260 : std::pair<std::vector<unsigned>, MatrixRXf> m::kmeans_with_centroids(const MatrixXf &data, unsigned k) 39 : : { 40 : 259 : M_insist(k >= 1, "kmeans requires at least one cluster"); 41 [ + + + - : 259 : if (data.size() == 0) return std::make_pair(std::vector<unsigned>(), MatrixXf(0, data.cols())); + - + - ] 42 : : 43 : : /* Compute initial centroids via k-means++. */ 44 : 258 : MatrixRXf centroids = kmeans_plus_plus(data, k); 45 : : 46 [ + - ]: 258 : std::vector<unsigned> labels(data.rows(), 0); // the labels assigned to the data points 47 [ + - ]: 258 : std::vector<unsigned> label_counters(k, 0); // the frequency of each label, used for iterative mean 48 : 258 : bool change = true; // whether the assignment of labels changed 49 : : 50 : 258 : unsigned i = KMEANS_MAX_ITERATIONS; 51 [ + + + + ]: 1558 : while (change and i--) { 52 : 1300 : change = false; 53 : : 54 : : /*----- Assignment step: Compute nearest centroid for all data points. ---------------------------------------*/ 55 [ + + ]: 77724 : for (unsigned row_id = 0; row_id != data.rows(); ++row_id) { 56 : : unsigned label; 57 [ + - + - : 76425 : auto deltas = (centroids.rowwise() - data.row(row_id)).rowwise().squaredNorm(); + - + - + - ] 58 [ + - ]: 76424 : deltas.minCoeff(&label); 59 [ + + ]: 76424 : change = change or labels[row_id] != label; // label has changed 60 : 76424 : labels[row_id] = label; 61 : 76424 : } 62 : : 63 : : /*----- Update step: Compute new centroids as the mean of data points in the cluster. ------------------------*/ 64 [ + - ]: 1300 : label_counters.assign(k, 0); // reset frequencies 65 [ + - + - ]: 1300 : centroids = MatrixRXf::Zero(centroids.rows(), centroids.cols()); // reset centroids 66 [ + + ]: 77724 : for (unsigned row_id = 0; row_id != data.rows(); ++row_id) { 67 : 76424 : const auto l = labels[row_id]; 68 [ + - + - : 76424 : centroids.row(l) += (data.row(row_id) - centroids.row(l)) / ++label_counters[l]; // iterative mean + - + - + - + - ] 69 : 76424 : } 70 : : } 71 : : 72 [ + - ]: 258 : return std::make_pair(std::move(labels), std::move(centroids)); 73 : 259 : }