Branch data Line data Source code
1 : : #pragma once
2 : :
3 : : #include <algorithm>
4 : : #include <array>
5 : : #include <cfenv>
6 : : #include <cmath>
7 : : #include <functional>
8 : : #include <iostream>
9 : : #include <mutable/util/macro.hpp>
10 : : #include <stdexcept>
11 : : #include <tuple>
12 : : #include <type_traits>
13 : : #include <utility>
14 : : #include <vector>
15 : :
16 : :
17 : : namespace m {
18 : :
19 : : namespace gs {
20 : :
21 : : template<typename T, template<typename> typename Derived>
22 : : struct Space
23 : : {
24 : : using derived_type = Derived<T>;
25 : : using value_type = T;
26 : :
27 : : #define CDERIVED (*static_cast<const derived_type*>(this))
28 : 1 : value_type lo() const { return CDERIVED.lo(); }
29 : 1 : value_type hi() const { return CDERIVED.hi(); }
30 : 1 : double step() const { return CDERIVED.step(); }
31 : 1 : unsigned num_steps() const { return CDERIVED.num_steps(); }
32 : :
33 : : value_type at(unsigned n) const { return CDERIVED.at(n); }
34 : : value_type operator()(unsigned n) const {return at(n); }
35 : :
36 : : std::vector<value_type> sequence() const { return CDERIVED.sequence(); }
37 : : #undef CDERIVED
38 : :
39 : : M_LCOV_EXCL_START
40 : : friend std::ostream & operator<<(std::ostream &out, const Space &S) {
41 : : return out << static_cast<const derived_type&>(S);
42 : : }
43 : :
44 : : void dump(std::ostream &out) const { out << *this << std::endl; }
45 : : void dump() const { dump(std::cerr); }
46 : : M_LCOV_EXCL_STOP
47 : : };
48 : :
49 : : template<typename T>
50 : : struct LinearSpace : Space<T, LinearSpace>
51 : : {
52 : : static_assert(std::is_arithmetic_v<T>, "type T must be an arithmetic type");
53 : : using value_type = T;
54 : : using difference_type = typename std::conditional_t<std::is_integral_v<T>,
55 : : std::make_signed<T>,
56 : : std::common_type<T>>::type;
57 : :
58 : : private:
59 : : value_type lo_;
60 : : value_type hi_;
61 : : double step_;
62 : : unsigned num_steps_;
63 : : bool is_ascending_;
64 : :
65 : : public:
66 : 38 : LinearSpace(value_type lowest, value_type highest, unsigned num_steps, bool is_ascending = true)
67 : 38 : : lo_(lowest), hi_(highest), num_steps_(num_steps), is_ascending_(is_ascending)
68 : : {
69 [ + - + - : 38 : if (lo_ > hi_)
+ + + - +
- + - #
# ]
70 [ # # # # : 1 : throw std::invalid_argument("invalid range");
+ - # # #
# # # #
# ]
71 [ + - + - : 37 : if (num_steps_ == 0)
+ + + - +
- + - #
# ]
72 [ # # # # : 1 : throw std::invalid_argument("number of steps must not be zero");
+ - # # #
# # # #
# ]
73 : :
74 : 36 : const int save_round = std::fegetround();
75 : 36 : std::fesetround(FE_TOWARDZERO);
76 : 36 : step_ = (double(hi_) - double(lo_)) / num_steps_;
77 : 36 : std::fesetround(save_round);
78 : 36 : }
79 : :
80 : 2 : static LinearSpace Ascending(value_type lowest, value_type highest, unsigned num_steps) {
81 : 2 : return LinearSpace(lowest, highest, num_steps, true);
82 : : }
83 : :
84 : 2 : static LinearSpace Descending(value_type lowest, value_type highest, unsigned num_steps) {
85 : 2 : return LinearSpace(lowest, highest, num_steps, false);
86 : : }
87 : :
88 : 199 : value_type lo() const { return lo_; }
89 : 71 : value_type hi() const { return hi_; }
90 : 199 : double step() const { return step_; }
91 : 200 : unsigned num_steps() const { return num_steps_; }
92 : 0 : difference_type delta() const { return hi_ - lo_; }
93 : 242 : bool ascending() const { return is_ascending_; }
94 : 26 : bool descending() const { return not is_ascending_; }
95 : :
96 : 217 : value_type at(unsigned n) const {
97 [ + - + - : 217 : if (n > num_steps_)
+ + + - +
- + - #
# ]
98 [ # # # # : 1 : throw std::out_of_range("n must be between 0 and num_steps()");
+ - # # #
# # # #
# ]
99 : : if constexpr (std::is_integral_v<value_type>) {
100 : 172 : const typename std::make_unsigned_t<value_type> delta = std::round(n * step());
101 [ + - + + : 172 : if (ascending())
+ + + + #
# ]
102 : 128 : return lo() + delta;
103 : : else
104 : 44 : return hi() - delta;
105 : : } else {
106 [ + - + - ]: 44 : if (ascending())
107 : 44 : return std::clamp<value_type>(value_type(double(lo()) + n * step_), lo_, hi_);
108 : : else
109 : 0 : return std::clamp<value_type>(value_type(double(hi()) - n * step_), lo_, hi_);
110 : : }
111 : 216 : }
112 : 23 : value_type operator()(unsigned n) const { return at(n); }
113 : :
114 : 19 : std::vector<value_type> sequence() const {
115 : 19 : std::vector<value_type> vec;
116 [ + - + - : 19 : vec.reserve(num_steps());
+ - + - +
- + - + -
+ - + - +
- + - +
- ]
117 : :
118 [ + - + + : 150 : for (unsigned i = 0; i <= num_steps(); ++i)
+ - + + +
- + + + -
+ + + - +
+ + - +
+ ]
119 [ + - + - : 131 : vec.push_back(at(i));
+ - + - +
- + - + -
+ - + - +
- + - +
- ]
120 : :
121 : 19 : return vec;
122 [ + - + - : 19 : }
+ - + - +
- + - ]
123 : :
124 : : M_LCOV_EXCL_START
125 : : friend std::ostream & operator<<(std::ostream &out, const LinearSpace &S) {
126 : : return out << "linear space from " << S.lo() << " to " << S.hi() << " with " << S.num_steps() << " steps of "
127 : : << S.step();
128 : : }
129 : :
130 : : void dump(std::ostream &out) const { out << *this << std::endl; }
131 : : void dump() const { dump(std::cerr); }
132 : : M_LCOV_EXCL_STOP
133 : : };
134 : :
135 : : template<typename... Spaces>
136 : : struct GridSearch
137 : : {
138 : : using callback_type = std::function<void(typename Spaces::value_type...)>;
139 : : static constexpr std::size_t NUM_SPACES = sizeof...(Spaces);
140 : :
141 : : private:
142 : : std::tuple<Spaces...> spaces_;
143 : :
144 : : public:
145 : 1 : GridSearch(Spaces... spaces) : spaces_(std::forward<Spaces>(spaces)...) { }
146 : :
147 : 1 : constexpr std::size_t num_spaces() const { return NUM_SPACES; }
148 : :
149 : 1 : std::size_t num_points() const {
150 : 2 : return std::apply([](auto&... space) {
151 : 1 : return ((space.num_steps() + 1) * ... );
152 : 1 : }, spaces_);
153 : : }
154 : :
155 : : void search(callback_type fn) const;
156 [ # # # # : 0 : void operator()(callback_type fn) const { search(fn); }
# # ]
157 : :
158 : : M_LCOV_EXCL_START
159 : : friend std::ostream & operator<<(std::ostream &out, const GridSearch &GS) {
160 : : out << "grid search with";
161 : :
162 : : std::apply([&out](auto&... space) {
163 : : ((out << "\n " << space), ...); // use C++17 fold-expression
164 : : }, GS.spaces_);
165 : :
166 : : return out;
167 : : }
168 : :
169 : : void dump(std::ostream &out) const { out << *this << std::endl; }
170 : : void dump() const { dump(std::cerr); }
171 : : M_LCOV_EXCL_STOP
172 : :
173 : : private:
174 : : template<std::size_t... I>
175 : : std::tuple<typename Spaces::value_type...>
176 : 10 : make_args(std::array<unsigned, NUM_SPACES> &counters, std::index_sequence<I...>) const {
177 : 30 : return std::apply([&counters](auto&... space) {
178 : 10 : return std::make_tuple(space(counters[I])...);
179 : 10 : }, spaces_);
180 : : }
181 : : };
182 : :
183 : : template<typename... Spaces>
184 : 1 : void GridSearch<Spaces...>::search(callback_type fn) const
185 : : {
186 : : std::array<unsigned, NUM_SPACES> counters;
187 : 1 : std::fill(counters.begin(), counters.end(), 0U);
188 : 2 : const std::array<unsigned, NUM_SPACES> num_steps = std::apply([](auto&... space) {
189 : 1 : return std::array<unsigned, NUM_SPACES>{ space.num_steps()... };
190 : 1 : }, spaces_);
191 : :
192 : 10 : for (;;) {
193 : 10 : auto args = make_args(counters, std::index_sequence_for<Spaces...>{});
194 : 10 : std::apply(fn, args);
195 : :
196 : 10 : std::size_t idx = NUM_SPACES - 1;
197 : :
198 [ + + # # : 12 : while (counters[idx] == num_steps[idx]) {
# # ]
199 [ + + # # : 3 : if (idx == 0) goto finished;
# # ]
200 : 2 : counters[idx] = 0;
201 : 2 : --idx;
202 : : }
203 : 9 : ++counters[idx];
204 : : }
205 : : finished:;
206 : 1 : }
207 : :
208 : : }
209 : :
210 : : }
|