DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
rk23.hpp
1#pragma once
2#include <core/concepts.hpp>
3#include <core/adaptive_integrator.hpp>
4#include <core/state_creator.hpp>
5#include <stdexcept>
6
7namespace diffeq {
8
20template<system_state S>
22public:
24 using state_type = typename base_type::state_type;
25 using time_type = typename base_type::time_type;
26 using value_type = typename base_type::value_type;
27 using system_function = typename base_type::system_function;
28
29 explicit RK23Integrator(system_function sys,
30 time_type rtol = static_cast<time_type>(1e-6),
31 time_type atol = static_cast<time_type>(1e-9))
32 : base_type(std::move(sys), rtol, atol) {}
33
34 void step(state_type& state, time_type dt) override {
35 adaptive_step(state, dt);
36 }
37
38 time_type adaptive_step(state_type& state, time_type dt) override {
39 const int max_attempts = 10;
40 time_type current_dt = dt;
41
42 for (int attempt = 0; attempt < max_attempts; ++attempt) {
43 state_type y_new = StateCreator<state_type>::create(state);
44 state_type error = StateCreator<state_type>::create(state);
45
46 rk23_step(state, y_new, error, current_dt);
47
48 // Calculate error norm
49 time_type err_norm = this->error_norm(error, y_new);
50
51 if (err_norm <= 1.0) {
52 // Accept step
53 state = y_new;
54 this->advance_time(current_dt);
55
56 // Suggest next step size
57 time_type next_dt = this->suggest_step_size(current_dt, err_norm, 3);
58 return std::max<time_type>(this->dt_min_, std::min<time_type>(this->dt_max_, next_dt));
59 } else {
60 // Reject step and reduce step size
61 current_dt *= std::max<time_type>(this->safety_factor_ * std::pow(err_norm, -1.0/3.0),
62 static_cast<time_type>(0.1));
63 current_dt = std::max<time_type>(current_dt, this->dt_min_);
64 }
65 }
66
67 throw std::runtime_error("RK23: Maximum number of step size reductions exceeded");
68 }
69
70private:
71 void rk23_step(const state_type& y, state_type& y_new, state_type& error, time_type dt) {
72 // Bogacki-Shampine coefficients matching SciPy's RK23
73 // C = [0, 1/2, 3/4]
74 // A = [[0, 0, 0], [1/2, 0, 0], [0, 3/4, 0]]
75 // B = [2/9, 1/3, 4/9] (3rd order)
76 // E = [5/72, -1/12, -1/9, 1/8] (error estimate)
77
78 state_type k1 = StateCreator<state_type>::create(y);
79 state_type k2 = StateCreator<state_type>::create(y);
80 state_type k3 = StateCreator<state_type>::create(y);
81 state_type k4 = StateCreator<state_type>::create(y); // For error estimation
82 state_type temp = StateCreator<state_type>::create(y);
83
84 time_type t = this->current_time_;
85
86 // k1 = f(t, y)
87 this->sys_(t, y, k1);
88
89 // k2 = f(t + dt/2, y + dt*k1/2)
90 for (std::size_t i = 0; i < y.size(); ++i) {
91 auto y_it = y.begin();
92 auto k1_it = k1.begin();
93 auto temp_it = temp.begin();
94 temp_it[i] = y_it[i] + dt * k1_it[i] / static_cast<time_type>(2);
95 }
96 this->sys_(t + dt / static_cast<time_type>(2), temp, k2);
97
98 // k3 = f(t + 3*dt/4, y + 3*dt*k2/4)
99 for (std::size_t i = 0; i < y.size(); ++i) {
100 auto y_it = y.begin();
101 auto k2_it = k2.begin();
102 auto temp_it = temp.begin();
103 temp_it[i] = y_it[i] + static_cast<time_type>(3) * dt * k2_it[i] / static_cast<time_type>(4);
104 }
105 this->sys_(t + static_cast<time_type>(3) * dt / static_cast<time_type>(4), temp, k3);
106
107 // 3rd order solution: y_new = y + dt*(2*k1/9 + k2/3 + 4*k3/9)
108 for (std::size_t i = 0; i < y.size(); ++i) {
109 auto y_it = y.begin();
110 auto k1_it = k1.begin();
111 auto k2_it = k2.begin();
112 auto k3_it = k3.begin();
113 auto y_new_it = y_new.begin();
114
115 y_new_it[i] = y_it[i] + dt * (static_cast<time_type>(2) * k1_it[i] / static_cast<time_type>(9) +
116 k2_it[i] / static_cast<time_type>(3) +
117 static_cast<time_type>(4) * k3_it[i] / static_cast<time_type>(9));
118 }
119
120 // k4 = f(t + dt, y_new) - needed for error estimation
121 this->sys_(t + dt, y_new, k4);
122
123 // Error estimate using E = [5/72, -1/12, -1/9, 1/8]
124 for (std::size_t i = 0; i < y.size(); ++i) {
125 auto k1_it = k1.begin();
126 auto k2_it = k2.begin();
127 auto k3_it = k3.begin();
128 auto k4_it = k4.begin();
129 auto error_it = error.begin();
130
131 error_it[i] = dt * (static_cast<time_type>(5) * k1_it[i] / static_cast<time_type>(72) -
132 k2_it[i] / static_cast<time_type>(12) -
133 k3_it[i] / static_cast<time_type>(9) +
134 k4_it[i] / static_cast<time_type>(8));
135 }
136 }
137};
138
139} // namespace diffeq
RK23 (Bogacki-Shampine) adaptive integrator.
Definition rk23.hpp:21