DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
adaptive_integrator.hpp
1#pragma once
2#include <functional>
3#include <concepts>
4#include <iterator>
5#include <type_traits>
6#include <vector>
7#include <array>
8#include <algorithm>
9#include <cmath>
10#include <stdexcept>
11#include <core/concepts.hpp>
12#include <core/abstract_integrator.hpp>
13#include <core/state_creator.hpp>
14
15namespace diffeq::core {
16
17// Abstract adaptive integrator with error control
18template<system_state S>
20public:
22 using state_type = typename base_type::state_type;
23 using time_type = typename base_type::time_type;
24 using value_type = typename base_type::value_type;
25 using system_function = typename base_type::system_function;
26
27 explicit AdaptiveIntegrator(system_function sys,
28 time_type rtol = static_cast<time_type>(1e-6),
29 time_type atol = static_cast<time_type>(1e-9))
30 : base_type(std::move(sys)), rtol_(rtol), atol_(atol),
31 dt_min_(static_cast<time_type>(1e-12)), dt_max_(static_cast<time_type>(1e2)),
32 safety_factor_(static_cast<time_type>(0.9)) {}
33
34 // Override integrate to use adaptive stepping
35 void integrate(state_type& state, time_type dt, time_type end_time) override {
36 time_type current_dt = dt;
37
38 while (this->current_time_ < end_time) {
39 if (this->current_time_ + current_dt > end_time) {
40 current_dt = end_time - this->current_time_;
41 }
42
43 current_dt = adaptive_step(state, current_dt);
44
45 if (current_dt < dt_min_) {
46 throw std::runtime_error("Step size became too small in adaptive integration");
47 }
48 }
49 }
50
51 // Pure virtual adaptive step - derived classes implement this
52 virtual time_type adaptive_step(state_type& state, time_type dt) = 0;
53
54 // Setters for tolerances
55 void set_tolerances(time_type rtol, time_type atol) {
56 rtol_ = rtol;
57 atol_ = atol;
58 }
59
60 void set_step_limits(time_type dt_min, time_type dt_max) {
61 dt_min_ = dt_min;
62 dt_max_ = dt_max;
63 }
64
65protected:
66 time_type rtol_, atol_; // Relative and absolute tolerances
67 time_type dt_min_, dt_max_; // Step size limits
68 time_type safety_factor_; // Safety factor for step size adjustment
69
70 // Calculate error tolerance for each component
71 time_type calculate_tolerance(value_type y_val) const {
72 return atol_ + rtol_ * std::abs(y_val);
73 }
74
75 // Calculate error norm using SciPy-style L2 norm
76 time_type error_norm(const state_type& error, const state_type& y) const {
77 time_type norm_squared = static_cast<time_type>(0);
78 std::size_t n = 0;
79
80 for (std::size_t i = 0; i < y.size(); ++i) {
81 auto y_it = y.begin();
82 auto err_it = error.begin();
83 time_type scale = atol_ + std::abs(y_it[i]) * rtol_;
84 time_type scaled_error = err_it[i] / scale;
85 norm_squared += scaled_error * scaled_error;
86 ++n;
87 }
88
89 if (n == 0) return static_cast<time_type>(0);
90 return std::sqrt(norm_squared / n);
91 }
92
93 // SciPy-style error norm calculation using max of current and new state
94 time_type error_norm_scipy_style(const state_type& error, const state_type& y_old, const state_type& y_new) const {
95 time_type norm_squared = static_cast<time_type>(0);
96 std::size_t n = 0;
97
98 for (std::size_t i = 0; i < y_old.size(); ++i) {
99 auto y_old_it = y_old.begin();
100 auto y_new_it = y_new.begin();
101 auto err_it = error.begin();
102
103 // SciPy uses: scale = atol + max(abs(y), abs(y_new)) * rtol
104 time_type scale = atol_ + std::max(std::abs(y_old_it[i]), std::abs(y_new_it[i])) * rtol_;
105 time_type scaled_error = err_it[i] / scale;
106 norm_squared += scaled_error * scaled_error;
107 ++n;
108 }
109
110 if (n == 0) return static_cast<time_type>(0);
111 return std::sqrt(norm_squared / n);
112 }
113
114 // Suggest new step size based on error
115 time_type suggest_step_size(time_type current_dt, time_type error_norm, int order) const {
116 if (error_norm == 0) {
117 return std::min(current_dt * static_cast<time_type>(2), dt_max_);
118 }
119
120 time_type factor = safety_factor_ * std::pow(static_cast<time_type>(1) / error_norm,
121 static_cast<time_type>(1) / order);
122 factor = std::max(static_cast<time_type>(0.1), std::min(factor, static_cast<time_type>(5.0)));
123
124 return std::max(dt_min_, std::min(dt_max_, current_dt * factor));
125 }
126};
127
128} // namespace diffeq::core