LCOV - code coverage report
Current view: top level - src/util - GridSearch.hpp (source / functions) Hit Total Coverage
Test: coverage.info Lines: 68 71 95.8 %
Date: 2025-03-25 01:19:55 Functions: 74 108 68.5 %
Branches: 85 200 42.5 %

           Branch data     Line data    Source code
       1                 :            : #pragma once
       2                 :            : 
       3                 :            : #include <algorithm>
       4                 :            : #include <array>
       5                 :            : #include <cfenv>
       6                 :            : #include <cmath>
       7                 :            : #include <functional>
       8                 :            : #include <iostream>
       9                 :            : #include <mutable/util/macro.hpp>
      10                 :            : #include <stdexcept>
      11                 :            : #include <tuple>
      12                 :            : #include <type_traits>
      13                 :            : #include <utility>
      14                 :            : #include <vector>
      15                 :            : 
      16                 :            : 
      17                 :            : namespace m {
      18                 :            : 
      19                 :            : namespace gs {
      20                 :            : 
      21                 :            : template<typename T, template<typename> typename Derived>
      22                 :            : struct Space
      23                 :            : {
      24                 :            :     using derived_type = Derived<T>;
      25                 :            :     using value_type = T;
      26                 :            : 
      27                 :            : #define CDERIVED (*static_cast<const derived_type*>(this))
      28                 :          1 :     value_type lo() const { return CDERIVED.lo(); }
      29                 :          1 :     value_type hi() const { return CDERIVED.hi(); }
      30                 :          1 :     double step() const { return CDERIVED.step(); }
      31                 :          1 :     unsigned num_steps() const { return CDERIVED.num_steps(); }
      32                 :            : 
      33                 :            :     value_type at(unsigned n) const { return CDERIVED.at(n); }
      34                 :            :     value_type operator()(unsigned n) const {return at(n); }
      35                 :            : 
      36                 :            :     std::vector<value_type> sequence() const { return CDERIVED.sequence(); }
      37                 :            : #undef CDERIVED
      38                 :            : 
      39                 :            : M_LCOV_EXCL_START
      40                 :            :     friend std::ostream & operator<<(std::ostream &out, const Space &S) {
      41                 :            :         return out << static_cast<const derived_type&>(S);
      42                 :            :     }
      43                 :            : 
      44                 :            :     void dump(std::ostream &out) const { out << *this << std::endl; }
      45                 :            :     void dump() const { dump(std::cerr); }
      46                 :            : M_LCOV_EXCL_STOP
      47                 :            : };
      48                 :            : 
      49                 :            : template<typename T>
      50                 :            : struct LinearSpace : Space<T, LinearSpace>
      51                 :            : {
      52                 :            :     static_assert(std::is_arithmetic_v<T>, "type T must be an arithmetic type");
      53                 :            :     using value_type = T;
      54                 :            :     using difference_type = typename std::conditional_t<std::is_integral_v<T>,
      55                 :            :                                                         std::make_signed<T>,
      56                 :            :                                                         std::common_type<T>>::type;
      57                 :            : 
      58                 :            :     private:
      59                 :            :     value_type lo_;
      60                 :            :     value_type hi_;
      61                 :            :     double step_;
      62                 :            :     unsigned num_steps_;
      63                 :            :     bool is_ascending_;
      64                 :            : 
      65                 :            :     public:
      66                 :         38 :     LinearSpace(value_type lowest, value_type highest, unsigned num_steps, bool is_ascending = true)
      67                 :         38 :         : lo_(lowest), hi_(highest), num_steps_(num_steps), is_ascending_(is_ascending)
      68                 :            :     {
      69   [ +  -  +  -  :         38 :         if (lo_ > hi_)
          +  +  +  -  +  
             -  +  -  #  
                      # ]
      70   [ #  #  #  #  :          1 :             throw std::invalid_argument("invalid range");
          +  -  #  #  #  
             #  #  #  #  
                      # ]
      71   [ +  -  +  -  :         37 :         if (num_steps_ == 0)
          +  +  +  -  +  
             -  +  -  #  
                      # ]
      72   [ #  #  #  #  :          1 :             throw std::invalid_argument("number of steps must not be zero");
          +  -  #  #  #  
             #  #  #  #  
                      # ]
      73                 :            : 
      74                 :         36 :         const int save_round = std::fegetround();
      75                 :         36 :         std::fesetround(FE_TOWARDZERO);
      76                 :         36 :         step_ = (double(hi_) - double(lo_)) / num_steps_;
      77                 :         36 :         std::fesetround(save_round);
      78                 :         36 :     }
      79                 :            : 
      80                 :          2 :     static LinearSpace Ascending(value_type lowest, value_type highest, unsigned num_steps) {
      81                 :          2 :         return LinearSpace(lowest, highest, num_steps, true);
      82                 :            :     }
      83                 :            : 
      84                 :          2 :     static LinearSpace Descending(value_type lowest, value_type highest, unsigned num_steps) {
      85                 :          2 :         return LinearSpace(lowest, highest, num_steps, false);
      86                 :            :     }
      87                 :            : 
      88                 :        199 :     value_type lo() const { return lo_; }
      89                 :         71 :     value_type hi() const { return hi_; }
      90                 :        199 :     double step() const { return step_; }
      91                 :        200 :     unsigned num_steps() const { return num_steps_; }
      92                 :          0 :     difference_type delta() const { return hi_ - lo_; }
      93                 :        242 :     bool ascending() const { return is_ascending_; }
      94                 :         26 :     bool descending() const { return not is_ascending_; }
      95                 :            : 
      96                 :        217 :     value_type at(unsigned n) const {
      97   [ +  -  +  -  :        217 :         if (n > num_steps_)
          +  +  +  -  +  
             -  +  -  #  
                      # ]
      98   [ #  #  #  #  :          1 :             throw std::out_of_range("n must be between 0 and num_steps()");
          +  -  #  #  #  
             #  #  #  #  
                      # ]
      99                 :            :         if constexpr (std::is_integral_v<value_type>) {
     100                 :        172 :             const typename std::make_unsigned_t<value_type> delta = std::round(n * step());
     101   [ +  -  +  +  :        172 :             if (ascending())
          +  +  +  +  #  
                      # ]
     102                 :        128 :                 return lo() + delta;
     103                 :            :             else
     104                 :         44 :                 return hi() - delta;
     105                 :            :         } else {
     106   [ +  -  +  - ]:         44 :             if (ascending())
     107                 :         44 :                 return std::clamp<value_type>(value_type(double(lo()) + n * step_), lo_, hi_);
     108                 :            :             else
     109                 :          0 :                 return std::clamp<value_type>(value_type(double(hi()) - n * step_), lo_, hi_);
     110                 :            :         }
     111                 :        216 :     }
     112                 :         23 :     value_type operator()(unsigned n) const { return at(n); }
     113                 :            : 
     114                 :         19 :     std::vector<value_type> sequence() const {
     115                 :         19 :         std::vector<value_type> vec;
     116   [ +  -  +  -  :         19 :         vec.reserve(num_steps());
          +  -  +  -  +  
          -  +  -  +  -  
          +  -  +  -  +  
             -  +  -  +  
                      - ]
     117                 :            : 
     118   [ +  -  +  +  :        150 :         for (unsigned i = 0; i <= num_steps(); ++i)
          +  -  +  +  +  
          -  +  +  +  -  
          +  +  +  -  +  
             +  +  -  +  
                      + ]
     119   [ +  -  +  -  :        131 :             vec.push_back(at(i));
          +  -  +  -  +  
          -  +  -  +  -  
          +  -  +  -  +  
             -  +  -  +  
                      - ]
     120                 :            : 
     121                 :         19 :         return vec;
     122   [ +  -  +  -  :         19 :     }
          +  -  +  -  +  
                -  +  - ]
     123                 :            : 
     124                 :            : M_LCOV_EXCL_START
     125                 :            :     friend std::ostream & operator<<(std::ostream &out, const LinearSpace &S) {
     126                 :            :         return out << "linear space from " << S.lo() << " to " << S.hi() << " with " << S.num_steps() << " steps of "
     127                 :            :                    << S.step();
     128                 :            :     }
     129                 :            : 
     130                 :            :     void dump(std::ostream &out) const { out << *this << std::endl; }
     131                 :            :     void dump() const { dump(std::cerr); }
     132                 :            : M_LCOV_EXCL_STOP
     133                 :            : };
     134                 :            : 
     135                 :            : template<typename... Spaces>
     136                 :            : struct GridSearch
     137                 :            : {
     138                 :            :     using callback_type = std::function<void(typename Spaces::value_type...)>;
     139                 :            :     static constexpr std::size_t NUM_SPACES = sizeof...(Spaces);
     140                 :            : 
     141                 :            :     private:
     142                 :            :     std::tuple<Spaces...> spaces_;
     143                 :            : 
     144                 :            :     public:
     145                 :          1 :     GridSearch(Spaces... spaces) : spaces_(std::forward<Spaces>(spaces)...) { }
     146                 :            : 
     147                 :          1 :     constexpr std::size_t num_spaces() const { return NUM_SPACES; }
     148                 :            : 
     149                 :          1 :     std::size_t num_points() const {
     150                 :          2 :         return std::apply([](auto&... space) {
     151                 :          1 :             return ((space.num_steps() + 1) * ... );
     152                 :          1 :         }, spaces_);
     153                 :            :     }
     154                 :            : 
     155                 :            :     void search(callback_type fn) const;
     156   [ #  #  #  #  :          0 :     void operator()(callback_type fn) const { search(fn); }
                   #  # ]
     157                 :            : 
     158                 :            : M_LCOV_EXCL_START
     159                 :            :     friend std::ostream & operator<<(std::ostream &out, const GridSearch &GS) {
     160                 :            :         out << "grid search with";
     161                 :            : 
     162                 :            :         std::apply([&out](auto&... space) {
     163                 :            :             ((out << "\n  " << space), ...); // use C++17 fold-expression
     164                 :            :         }, GS.spaces_);
     165                 :            : 
     166                 :            :         return out;
     167                 :            :     }
     168                 :            : 
     169                 :            :     void dump(std::ostream &out) const { out << *this << std::endl; }
     170                 :            :     void dump() const { dump(std::cerr); }
     171                 :            : M_LCOV_EXCL_STOP
     172                 :            : 
     173                 :            :     private:
     174                 :            :     template<std::size_t... I>
     175                 :            :     std::tuple<typename Spaces::value_type...>
     176                 :         10 :     make_args(std::array<unsigned, NUM_SPACES> &counters, std::index_sequence<I...>) const {
     177                 :         30 :         return std::apply([&counters](auto&... space) {
     178                 :         10 :             return std::make_tuple(space(counters[I])...);
     179                 :         10 :         }, spaces_);
     180                 :            :     }
     181                 :            : };
     182                 :            : 
     183                 :            : template<typename... Spaces>
     184                 :          1 : void GridSearch<Spaces...>::search(callback_type fn) const
     185                 :            : {
     186                 :            :     std::array<unsigned, NUM_SPACES> counters;
     187                 :          1 :     std::fill(counters.begin(), counters.end(), 0U);
     188                 :          2 :     const std::array<unsigned, NUM_SPACES> num_steps = std::apply([](auto&... space) {
     189                 :          1 :         return std::array<unsigned, NUM_SPACES>{ space.num_steps()... };
     190                 :          1 :     }, spaces_);
     191                 :            : 
     192                 :         10 :     for (;;) {
     193                 :         10 :         auto args = make_args(counters, std::index_sequence_for<Spaces...>{});
     194                 :         10 :         std::apply(fn, args);
     195                 :            : 
     196                 :         10 :         std::size_t idx = NUM_SPACES - 1;
     197                 :            : 
     198   [ +  +  #  #  :         12 :         while (counters[idx] == num_steps[idx]) {
                   #  # ]
     199   [ +  +  #  #  :          3 :             if (idx == 0) goto finished;
                   #  # ]
     200                 :          2 :             counters[idx] = 0;
     201                 :          2 :             --idx;
     202                 :            :         }
     203                 :          9 :         ++counters[idx];
     204                 :            :     }
     205                 :            : finished:;
     206                 :          1 : }
     207                 :            : 
     208                 :            : }
     209                 :            : 
     210                 :            : }

Generated by: LCOV version 1.16