DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
interpolation_decorator.hpp
1#pragma once
2
#include "integrator_decorator.hpp"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
11
12namespace diffeq::core::composable {
13
/**
 * @brief Interpolation schemes selectable through InterpolationConfig::method.
 *
 * HERMITE and AKIMA are currently served by the cubic-spline code path of
 * InterpolationDecorator.
 */
enum class InterpolationMethod {
    LINEAR = 0,        ///< Linear interpolation
    CUBIC_SPLINE = 1,  ///< Cubic spline interpolation
    HERMITE = 2,       ///< Hermite polynomial interpolation
    AKIMA = 3          ///< Akima spline (smooth, avoids oscillation)
};
23
28 InterpolationMethod method{InterpolationMethod::CUBIC_SPLINE};
29 size_t max_history_size{10000};
30 bool enable_adaptive_sampling{true};
31 double adaptive_tolerance{1e-6};
32 double min_step_size{1e-12};
33
34 // Memory management
35 bool enable_compression{false};
36 size_t compression_threshold{1000};
37 double compression_tolerance{1e-8};
38
39 // Extrapolation settings
40 bool allow_extrapolation{false};
41 double extrapolation_warning_threshold{0.1}; // Warn if extrapolating beyond 10% of range
42
47 void validate() const {
48 if (max_history_size < 2) {
49 throw std::invalid_argument("max_history_size must be at least 2 for interpolation");
50 }
51
52 if (adaptive_tolerance <= 0) {
53 throw std::invalid_argument("adaptive_tolerance must be positive");
54 }
55
56 if (min_step_size <= 0) {
57 throw std::invalid_argument("min_step_size must be positive");
58 }
59
60 if (compression_threshold > max_history_size) {
61 throw std::invalid_argument("compression_threshold cannot exceed max_history_size");
62 }
63 }
64};
65
// Incremented by update_interpolation_time(); interpolate_at() calls it after each successful query.
size_t total_interpolations = 0;
// Compression passes that ran (the history was above compression_threshold).
size_t history_compressions = 0;
// Extrapolated queries flagged by the warning-threshold check.
size_t extrapolation_warnings = 0;
// Queries rejected because they fell outside the history with extrapolation disabled.
size_t out_of_bounds_queries = 0;
// NOTE(review): not updated by any code visible in this file.
double max_interpolation_error = 0.0;
// Running mean of interpolate_at() wall-clock time, in nanoseconds.
double average_interpolation_time_ns = 0.0;

77 void update_interpolation_time(double time_ns) {
78 average_interpolation_time_ns = (average_interpolation_time_ns * total_interpolations + time_ns) / (total_interpolations + 1);
79 total_interpolations++;
80 }
81};
82
namespace detail {

/// Scalar type stored in spline samples for time type T: `T::value_type`
/// when T declares one, otherwise T itself (e.g. `double`).
template<typename T, typename = void>
struct spline_value_type {
    using type = T;
};

template<typename T>
struct spline_value_type<T, std::void_t<typename T::value_type>> {
    using type = typename T::value_type;
};

} // namespace detail

/**
 * @brief Natural cubic spline through a set of vector-valued samples.
 *
 * Every state component is interpolated independently. The knot second
 * derivatives come from a tridiagonal solve with natural boundary conditions
 * (zero second derivative at both ends).
 *
 * @tparam T Time (abscissa) type. Sample components use `typename T::value_type`
 *           when T declares one, otherwise T itself, so scalar times such as
 *           `double` are supported.
 */
template<typename T>
class CubicSplineInterpolator {
public:
    using value_type = typename detail::spline_value_type<T>::type;
    using sample_type = std::vector<value_type>;

private:
    std::vector<T> times_;
    std::vector<sample_type> states_;
    std::vector<sample_type> second_derivatives_;  // M_i for each knot and component
    bool computed_{false};

public:
    /**
     * @brief Replace the sample data and recompute the spline.
     * @param times  Knot times; must be strictly increasing.
     * @param states One sample per knot; all samples must have the same size.
     * @throws std::invalid_argument if the lengths differ, times are not strictly
     *         increasing, or sample sizes differ. The previous data is kept then.
     */
    void set_data(const std::vector<T>& times, const std::vector<sample_type>& states) {
        validate_data(times, states);
        times_ = times;
        states_ = states;
        computed_ = false;
        compute_derivatives();
    }

    /**
     * @brief Evaluate the spline at @p t.
     *
     * Queries at or before the first knot return the first sample and queries
     * after the last knot return the last sample.
     *
     * @throws std::runtime_error if set_data() was never called, or was called
     *         with no samples.
     */
    sample_type interpolate(T t) const {
        if (!computed_) {
            throw std::runtime_error("Spline not computed");
        }
        if (times_.empty()) {
            throw std::runtime_error("No data for interpolation");
        }

        const auto it = std::lower_bound(times_.begin(), times_.end(), t);
        if (it == times_.begin()) {
            return states_.front();
        }
        if (it == times_.end()) {
            return states_.back();
        }

        const auto hi = static_cast<std::size_t>(std::distance(times_.begin(), it));
        const std::size_t lo = hi - 1;
        const T h = times_[hi] - times_[lo];
        const T a = (times_[hi] - t) / h;
        const T b = (t - times_[lo]) / h;

        const sample_type& y_lo = states_[lo];
        const sample_type& y_hi = states_[hi];
        const sample_type& m_lo = second_derivatives_[lo];
        const sample_type& m_hi = second_derivatives_[hi];

        sample_type result(y_lo.size());
        for (std::size_t i = 0; i < result.size(); ++i) {
            result[i] = a * y_lo[i] + b * y_hi[i] +
                        ((a * a * a - a) * m_lo[i] + (b * b * b - b) * m_hi[i]) * (h * h) / 6.0;
        }
        return result;
    }

private:
    /// Check the invariants compute_derivatives() and interpolate() rely on.
    static void validate_data(const std::vector<T>& times, const std::vector<sample_type>& states) {
        if (times.size() != states.size()) {
            throw std::invalid_argument("times and states must have the same length");
        }
        for (std::size_t i = 1; i < times.size(); ++i) {
            if (!(times[i - 1] < times[i])) {
                throw std::invalid_argument("times must be strictly increasing");
            }
            if (states[i].size() != states[0].size()) {
                throw std::invalid_argument("all states must have the same dimension");
            }
        }
    }

    /// Solve for the knot second derivatives of a natural cubic spline.
    void compute_derivatives() {
        const std::size_t n = times_.size();
        const std::size_t dim = states_.empty() ? 0 : states_.front().size();
        second_derivatives_.assign(n, sample_type(dim, value_type{}));

        // With fewer than three knots the natural spline is linear (or constant),
        // so every second derivative stays zero.
        if (n < 3) {
            computed_ = true;
            return;
        }

        std::vector<value_type> lower(n), diag(n), upper(n), rhs(n);
        for (std::size_t j = 0; j < dim; ++j) {
            // Natural boundary rows: M_0 = M_{n-1} = 0.
            lower[0] = upper[0] = value_type{};
            lower[n - 1] = upper[n - 1] = value_type{};
            diag[0] = diag[n - 1] = value_type(1);
            rhs[0] = rhs[n - 1] = value_type{};

            // Interior rows: h1*M_{i-1} + 2(h1+h2)*M_i + h2*M_{i+1} = 6*(slope difference).
            for (std::size_t i = 1; i + 1 < n; ++i) {
                const T h1 = times_[i] - times_[i - 1];
                const T h2 = times_[i + 1] - times_[i];
                lower[i] = h1;
                diag[i] = 2.0 * (h1 + h2);
                upper[i] = h2;
                rhs[i] = 6.0 * ((states_[i + 1][j] - states_[i][j]) / h2 -
                                (states_[i][j] - states_[i - 1][j]) / h1);
            }

            solve_tridiagonal(lower, diag, upper, rhs);

            for (std::size_t i = 0; i < n; ++i) {
                second_derivatives_[i][j] = rhs[i];
            }
        }

        computed_ = true;
    }

    /// Thomas algorithm: overwrites @p diag and leaves the solution in @p rhs.
    /// No pivoting; the spline system is diagonally dominant for strictly
    /// increasing knots.
    static void solve_tridiagonal(const std::vector<value_type>& lower,
                                  std::vector<value_type>& diag,
                                  const std::vector<value_type>& upper,
                                  std::vector<value_type>& rhs) {
        const std::size_t n = diag.size();
        if (n == 0) {
            return;
        }

        // Forward elimination
        for (std::size_t i = 1; i < n; ++i) {
            const value_type m = lower[i] / diag[i - 1];
            diag[i] -= m * upper[i - 1];
            rhs[i] -= m * rhs[i - 1];
        }

        // Back substitution (unsigned countdown: i = n-2 ... 0)
        rhs[n - 1] /= diag[n - 1];
        for (std::size_t i = n - 1; i-- > 0;) {
            rhs[i] = (rhs[i] - upper[i] * rhs[i + 1]) / diag[i];
        }
    }
};
209
225template<system_state S>
227private:
// Validated in the constructor; config() hands out a mutable reference to it.
228 InterpolationConfig config_;
// Recorded samples keyed by time (ordered; one state per distinct time).
229 std::map<typename IntegratorDecorator<S>::time_type, S> state_history_;
// Rebuilt from state_history_ on every CUBIC_SPLINE / HERMITE / AKIMA query.
230 std::unique_ptr<CubicSplineInterpolator<typename IntegratorDecorator<S>::time_type>> spline_interpolator_;
// Counters updated by interpolate_at() and compress_history().
231 InterpolationStats stats_;
// Guards state_history_ in record_state(), interpolate_at(), get_time_bounds(),
// get_history_size() and the history-clearing method; std::mutex is not recursive.
232 mutable std::mutex history_mutex_;
// Time argument of the most recent successful interpolate_at() call.
233 typename IntegratorDecorator<S>::time_type last_query_time_{};
// Set after a compression pass; reset to false when the history is cleared.
234 bool history_compressed_{false};
235
236public:
243 explicit InterpolationDecorator(std::unique_ptr<AbstractIntegrator<S>> integrator,
245 : IntegratorDecorator<S>(std::move(integrator))
246 , config_(std::move(config))
247 , spline_interpolator_(std::make_unique<CubicSplineInterpolator<typename IntegratorDecorator<S>::time_type>>()) {
248
249 config_.validate();
250 }
251
255 void step(typename IntegratorDecorator<S>::state_type& state, typename IntegratorDecorator<S>::time_type dt) override {
256 this->wrapped_integrator_->step(state, dt);
257 record_state(state, this->current_time());
258 }
259
263 void integrate(typename IntegratorDecorator<S>::state_type& state, typename IntegratorDecorator<S>::time_type dt,
264 typename IntegratorDecorator<S>::time_type end_time) override {
265 // Record initial state
266 record_state(state, this->current_time());
267
268 // Integrate with history recording
269 this->wrapped_integrator_->integrate(state, dt, end_time);
270
271 // Record final state
272 record_state(state, this->current_time());
273
274 // Compress history if needed
275 if (config_.enable_compression && state_history_.size() > config_.compression_threshold) {
276 compress_history();
277 }
278 }
279
286 S interpolate_at(typename IntegratorDecorator<S>::time_type t) {
287 std::lock_guard<std::mutex> lock(history_mutex_);
288
289 auto start_time = std::chrono::high_resolution_clock::now();
290
291 if (state_history_.empty()) {
292 throw std::runtime_error("No history available for interpolation");
293 }
294
295 // Check bounds
296 auto bounds = get_time_bounds();
297 if (t < bounds.first || t > bounds.second) {
298 if (!config_.allow_extrapolation) {
299 stats_.out_of_bounds_queries++;
300 throw std::runtime_error("Time " + std::to_string(t) + " is outside interpolation bounds [" +
301 std::to_string(bounds.first) + ", " + std::to_string(bounds.second) + "]");
302 }
303
304 // Check extrapolation warning threshold
305 typename IntegratorDecorator<S>::time_type range = bounds.second - bounds.first;
306 if (std::abs(t - bounds.first) > config_.extrapolation_warning_threshold * range ||
307 std::abs(t - bounds.second) > config_.extrapolation_warning_threshold * range) {
308 stats_.extrapolation_warnings++;
309 }
310 }
311
312 S result = perform_interpolation(t);
313
314 auto end_time = std::chrono::high_resolution_clock::now();
315 double duration_ns = std::chrono::duration_cast<std::chrono::nanoseconds>(end_time - start_time).count();
316 stats_.update_interpolation_time(duration_ns);
317
318 last_query_time_ = t;
319 return result;
320 }
321
327 std::vector<S> interpolate_at_multiple(const std::vector<typename IntegratorDecorator<S>::time_type>& time_points) {
328 std::vector<S> results;
329 results.reserve(time_points.size());
330
331 for (typename IntegratorDecorator<S>::time_type t : time_points) {
332 results.push_back(interpolate_at(t));
333 }
334
335 return results;
336 }
337
345 std::pair<std::vector<typename IntegratorDecorator<S>::time_type>, std::vector<S>> get_dense_output(
346 typename IntegratorDecorator<S>::time_type start_time,
347 typename IntegratorDecorator<S>::time_type end_time,
348 size_t num_points) {
349 if (num_points < 2) {
350 throw std::invalid_argument("num_points must be at least 2");
351 }
352
353 std::vector<typename IntegratorDecorator<S>::time_type> times;
354 std::vector<S> states;
355
356 typename IntegratorDecorator<S>::time_type dt = (end_time - start_time) / (num_points - 1);
357
358 for (size_t i = 0; i < num_points; ++i) {
359 typename IntegratorDecorator<S>::time_type t = start_time + i * dt;
360 times.push_back(t);
361 states.push_back(interpolate_at(t));
362 }
363
364 return {std::move(times), std::move(states)};
365 }
366
371 return stats_;
372 }
373
378 stats_ = InterpolationStats{};
379 }
380
385 std::pair<typename IntegratorDecorator<S>::time_type, typename IntegratorDecorator<S>::time_type> get_time_bounds() const {
386 std::lock_guard<std::mutex> lock(history_mutex_);
387 if (state_history_.empty()) {
388 return {typename IntegratorDecorator<S>::time_type{}, typename IntegratorDecorator<S>::time_type{}};
389 }
390 return {state_history_.begin()->first, state_history_.rbegin()->first};
391 }
392
397 std::lock_guard<std::mutex> lock(history_mutex_);
398 state_history_.clear();
399 history_compressed_ = false;
400 }
401
405 size_t get_history_size() const {
406 std::lock_guard<std::mutex> lock(history_mutex_);
407 return state_history_.size();
408 }
409
413 InterpolationConfig& config() { return config_; }
414 const InterpolationConfig& config() const { return config_; }
415
416private:
420 void record_state(const S& state, typename IntegratorDecorator<S>::time_type time) {
421 std::lock_guard<std::mutex> lock(history_mutex_);
422
423 // Check if we need to make room
424 if (state_history_.size() >= config_.max_history_size) {
425 // Remove oldest entry
426 state_history_.erase(state_history_.begin());
427 }
428
429 state_history_[time] = state;
430 }
431
435 S perform_interpolation(typename IntegratorDecorator<S>::time_type t) {
436 switch (config_.method) {
437 case InterpolationMethod::LINEAR:
438 return linear_interpolation(t);
439 case InterpolationMethod::CUBIC_SPLINE:
440 return cubic_spline_interpolation(t);
441 case InterpolationMethod::HERMITE:
442 return hermite_interpolation(t);
443 case InterpolationMethod::AKIMA:
444 return akima_interpolation(t);
445 default:
446 throw std::runtime_error("Unknown interpolation method");
447 }
448 }
449
453 S linear_interpolation(typename IntegratorDecorator<S>::time_type t) {
454 auto it = state_history_.lower_bound(t);
455
456 if (it == state_history_.begin()) {
457 return it->second;
458 }
459
460 if (it == state_history_.end()) {
461 return state_history_.rbegin()->second;
462 }
463
464 auto prev_it = std::prev(it);
465
466 typename IntegratorDecorator<S>::time_type t1 = prev_it->first;
467 typename IntegratorDecorator<S>::time_type t2 = it->first;
468 const S& s1 = prev_it->second;
469 const S& s2 = it->second;
470
471 typename IntegratorDecorator<S>::time_type alpha = (t - t1) / (t2 - t1);
472
473 S result = s1;
474 for (size_t i = 0; i < result.size(); ++i) {
475 result[i] = (1 - alpha) * s1[i] + alpha * s2[i];
476 }
477
478 return result;
479 }
480
484 S cubic_spline_interpolation(typename IntegratorDecorator<S>::time_type t) {
485 // Prepare data for spline interpolator
486 std::vector<typename IntegratorDecorator<S>::time_type> times;
487 std::vector<std::vector<typename S::value_type>> states;
488
489 times.reserve(state_history_.size());
490 states.reserve(state_history_.size());
491
492 for (const auto& [time, state] : state_history_) {
493 times.push_back(time);
494 std::vector<typename S::value_type> state_vec(state.begin(), state.end());
495 states.push_back(std::move(state_vec));
496 }
497
498 spline_interpolator_->set_data(times, states);
499 auto result_vec = spline_interpolator_->interpolate(t);
500
501 // Convert back to state type
502 S result;
503 if constexpr (std::is_same_v<S, std::vector<typename S::value_type>>) {
504 result = result_vec;
505 } else {
506 std::copy(result_vec.begin(), result_vec.end(), result.begin());
507 }
508
509 return result;
510 }
511
515 S hermite_interpolation(typename IntegratorDecorator<S>::time_type t) {
516 // For now, fall back to cubic spline
517 return cubic_spline_interpolation(t);
518 }
519
523 S akima_interpolation(typename IntegratorDecorator<S>::time_type t) {
524 // For now, fall back to cubic spline
525 return cubic_spline_interpolation(t);
526 }
527
531 void compress_history() {
532 if (state_history_.size() <= config_.compression_threshold) {
533 return;
534 }
535
536 // Simple compression: remove every other point if error is small
537 auto it = state_history_.begin();
538 while (it != state_history_.end() && state_history_.size() > config_.compression_threshold / 2) {
539 auto next_it = std::next(it);
540 if (next_it != state_history_.end()) {
541 auto next_next_it = std::next(next_it);
542 if (next_next_it != state_history_.end()) {
543 // Check if middle point can be removed
544 if (is_point_redundant(*it, *next_it, *next_next_it)) {
545 it = state_history_.erase(next_it);
546 continue;
547 }
548 }
549 }
550 ++it;
551 }
552
553 stats_.history_compressions++;
554 history_compressed_ = true;
555 }
556
560 bool is_point_redundant(const std::pair<typename IntegratorDecorator<S>::time_type, S>& p1,
561 const std::pair<typename IntegratorDecorator<S>::time_type, S>& p2,
562 const std::pair<typename IntegratorDecorator<S>::time_type, S>& p3) {
563 // Simple linear interpolation error check
564 typename IntegratorDecorator<S>::time_type alpha = (p2.first - p1.first) / (p3.first - p1.first);
565
566 for (size_t i = 0; i < p2.second.size(); ++i) {
567 double interpolated = (1 - alpha) * p1.second[i] + alpha * p3.second[i];
568 double error = std::abs(interpolated - p2.second[i]);
569 if (error > config_.compression_tolerance) {
570 return false;
571 }
572 }
573
574 return true;
575 }
576};
577
578} // namespace diffeq::core::composable
Base decorator interface for integrator enhancements.
Interpolation decorator - adds dense output capabilities to any integrator.
void reset_statistics()
Reset interpolation statistics.
InterpolationConfig & config()
Access and modify interpolation configuration.
void step(typename IntegratorDecorator< S >::state_type &state, typename IntegratorDecorator< S >::time_type dt) override
Override step to record state history.
std::pair< std::vector< typename IntegratorDecorator< S >::time_type >, std::vector< S > > get_dense_output(typename IntegratorDecorator< S >::time_type start_time, typename IntegratorDecorator< S >::time_type end_time, size_t num_points)
Get dense output over time interval.
InterpolationDecorator(std::unique_ptr< AbstractIntegrator< S > > integrator, InterpolationConfig config={})
Construct interpolation decorator.
const InterpolationStats & get_statistics() const
Get current interpolation statistics.
S interpolate_at(typename IntegratorDecorator< S >::time_type t)
Get the interpolated state at a time inside the recorded history; out-of-range times throw unless extrapolation is enabled.
std::vector< S > interpolate_at_multiple(const std::vector< typename IntegratorDecorator< S >::time_type > &time_points)
Get interpolated states at multiple time points.
size_t get_history_size() const
Get number of stored history points.
std::pair< typename IntegratorDecorator< S >::time_type, typename IntegratorDecorator< S >::time_type > get_time_bounds() const
Get time bounds of available history.
void integrate(typename IntegratorDecorator< S >::state_type &state, typename IntegratorDecorator< S >::time_type dt, typename IntegratorDecorator< S >::time_type end_time) override
Override integrate to record the start and end states of the integration in the history.
Configuration for interpolation and dense output.
void validate() const
Validate configuration parameters.
Statistics for interpolation operations.