LCOV - code coverage report
Current view: top level - src/util - Kmeans.cpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 39 39 100.0 %
Date: 2025-03-25 01:19:55 Functions: 7 8 87.5 %
Branches: 52 90 57.8 %

           Branch data     Line data    Source code
       1                 :            : #include "util/Kmeans.hpp"
       2                 :            : 
       3                 :            : #include <cfloat>
       4                 :            : #include <cmath>
       5                 :            : #include <limits>
       6                 :            : #include <mutable/util/macro.hpp>
       7                 :            : #include <random>
       8                 :            : 
       9                 :            : 
      10                 :            : using namespace m;
      11                 :            : using namespace Eigen;
      12                 :            : 
      13                 :            : 
      /** Upper bound on the number of assignment/update rounds performed by `kmeans_with_centroids()`; the loop there
       * stops earlier as soon as a round leaves all labels unchanged. */
      14                 :            : static constexpr unsigned KMEANS_MAX_ITERATIONS = 100;
      15                 :            : 
      16                 :            : 
      17                 :        258 : MatrixRXf m::kmeans_plus_plus(const MatrixXf &data, unsigned k)
      18                 :            : {
      19                 :        258 :     std::mt19937 g(0);
      20                 :            :     /** Chosen initial centroids for *k*-means. */
      21                 :        258 :     MatrixRXf centroids(k, data.cols());
      22                 :            :     /** Stores the squared distance of each data point to its nearest centroid.  Initialized to maximum representable
      23                 :            :      * value, as no centroid exists yet.  Used as weighted probability for choosing the next centroid. */
      24         [ +  - ]:        258 :     VectorXf weights(data.rows());
      25         [ +  - ]:        258 :     weights.setConstant(std::numeric_limits<float>::max());
      26                 :            : 
      27         [ +  + ]:       1073 :     for (unsigned i = 0; i != k; ++i) {
      28                 :            :         /*----- Pick next centroid. -----*/
      29   [ +  -  +  -  :        815 :         std::discrete_distribution<unsigned> dist(weights.data(), weights.data() + weights.size());
                   +  - ]
      30   [ +  -  +  -  :        815 :         centroids.row(i) = data.row(dist(g));
             +  -  +  - ]
      31                 :            :         /*----- Update distances. -----*/
      32   [ +  -  +  -  :        815 :         weights = weights.cwiseMin((data.rowwise() - centroids.row(i)).rowwise().squaredNorm()); // cell-wise minimum
          +  -  +  -  +  
             -  +  -  +  
                      - ]
      33                 :        815 :     }
      34                 :            : 
      35                 :        258 :     return centroids;
      36         [ +  - ]:        258 : }
      37                 :            : 
      38                 :        260 : std::pair<std::vector<unsigned>, MatrixRXf> m::kmeans_with_centroids(const MatrixXf &data, unsigned k)
      39                 :            : {
      40                 :        259 :     M_insist(k >= 1, "kmeans requires at least one cluster");
      41   [ +  +  +  -  :        259 :     if (data.size() == 0) return std::make_pair(std::vector<unsigned>(), MatrixXf(0, data.cols()));
             +  -  +  - ]
      42                 :            : 
      43                 :            :     /* Compute initial centroids via k-means++. */
      44                 :        258 :     MatrixRXf centroids = kmeans_plus_plus(data, k);
      45                 :            : 
      46         [ +  - ]:        258 :     std::vector<unsigned> labels(data.rows(), 0); // the labels assigned to the data points
      47         [ +  - ]:        258 :     std::vector<unsigned> label_counters(k, 0); // the frequency of each label, used for iterative mean
      48                 :        258 :     bool change = true; // whether the assignment of labels changed
      49                 :            : 
      50                 :        258 :     unsigned i = KMEANS_MAX_ITERATIONS;
      51   [ +  +  +  + ]:       1558 :     while (change and i--) {
      52                 :       1300 :         change = false;
      53                 :            : 
      54                 :            :         /*----- Assignment step: Compute nearest centroid for all data points. ---------------------------------------*/
      55         [ +  + ]:      77724 :         for (unsigned row_id = 0; row_id != data.rows(); ++row_id) {
      56                 :            :             unsigned label;
      57   [ +  -  +  -  :      76425 :             auto deltas = (centroids.rowwise() - data.row(row_id)).rowwise().squaredNorm();
          +  -  +  -  +  
                      - ]
      58         [ +  - ]:      76424 :             deltas.minCoeff(&label);
      59         [ +  + ]:      76424 :             change = change or labels[row_id] != label; // label has changed
      60                 :      76424 :             labels[row_id] = label;
      61                 :      76424 :         }
      62                 :            : 
      63                 :            :         /*----- Update step: Compute new centroids as the mean of data points in the cluster. ------------------------*/
      64         [ +  - ]:       1300 :         label_counters.assign(k, 0); // reset frequencies
      65   [ +  -  +  - ]:       1300 :         centroids = MatrixRXf::Zero(centroids.rows(), centroids.cols()); // reset centroids
      66         [ +  + ]:      77724 :         for (unsigned row_id = 0; row_id != data.rows(); ++row_id) {
      67                 :      76424 :             const auto l = labels[row_id];
      68   [ +  -  +  -  :      76424 :             centroids.row(l) += (data.row(row_id) - centroids.row(l)) / ++label_counters[l]; // iterative mean
          +  -  +  -  +  
                -  +  - ]
      69                 :      76424 :         }
      70                 :            :     }
      71                 :            : 
      72         [ +  - ]:        258 :     return std::make_pair(std::move(labels), std::move(centroids));
      73                 :        259 : }

Generated by: LCOV version 1.16