DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
dop853.hpp
1#pragma once
#include <core/adaptive_integrator.hpp>
#include <core/state_creator.hpp>
#include <integrators/ode/dop853_coefficients.hpp>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
7
8namespace diffeq {
9
// Forward declaration: the dense-output helper below names DOP853Integrator<S>::value_type
// before the integrator class itself is defined.
template<system_state S>
class DOP853Integrator;
12
13template<system_state S>
15public:
16 using value_type = typename DOP853Integrator<S>::value_type;
17
18 // Dense output for DOP853: ported from Fortran CONTD8
19 // CON: continuous output coefficients, size 8*nd
20 // ICOMP: index mapping, size nd
21 // nd: number of dense output components
22 // xold: left endpoint of interval, h: step size
23 // x: interpolation point (xold <= x <= xold+h)
24 // Returns: interpolated value for component ii at x
25 static value_type contd8(
26 int ii, value_type x, const value_type* con, const int* icomp, int nd,
27 value_type xold, value_type h) {
28 int i = -1;
29 for (int j = 0; j < nd; ++j) {
30 if (icomp[j] == ii) { i = j; break; }
31 }
32 if (i == -1) {
33 throw std::runtime_error("No dense output available for component " + std::to_string(ii));
34 }
35 value_type s = (x - xold) / h;
36 value_type s1 = 1.0 - s;
37 value_type conpar = con[i + nd*4] + s * (con[i + nd*5] + s1 * (con[i + nd*6] + s * con[i + nd*7]));
38 value_type result = con[i] + s * (con[i + nd] + s1 * (con[i + nd*2] + s * (con[i + nd*3] + s1 * conpar)));
39 return result;
40 }
41};
42
43
50template<system_state S>
52
53
54public:
56 using state_type = typename base_type::state_type;
57 using time_type = typename base_type::time_type;
58 using value_type = typename base_type::value_type;
59 using system_function = typename base_type::system_function;
60
61 // Fortran default parameters (do not change unless you know what you are doing)
62 static constexpr time_type fortran_safety = static_cast<time_type>(0.9); // SAFE
63 static constexpr time_type fortran_fac1 = static_cast<time_type>(0.333); // FAC1 (min step size factor)
64 static constexpr time_type fortran_fac2 = static_cast<time_type>(6.0); // FAC2 (max step size factor)
65 static constexpr time_type fortran_beta = static_cast<time_type>(0.0); // BETA (step size stabilization)
66 static constexpr time_type fortran_dt_max = static_cast<time_type>(1e100); // HMAX (max step size)
67 static constexpr time_type fortran_dt_min = static_cast<time_type>(1e-16); // practical min (not in Fortran, but practical)
68 static constexpr int fortran_nmax = 100000; // NMAX (max steps)
69 static constexpr int fortran_nstiff = 1000; // NSTIFF (stiffness test interval)
70
71 // Internal state (Fortran default values)
72 time_type safety_factor_ = fortran_safety;
73 time_type fac1_ = fortran_fac1;
74 time_type fac2_ = fortran_fac2;
75 time_type beta_ = fortran_beta;
76 time_type dt_max_ = fortran_dt_max;
77 time_type dt_min_ = fortran_dt_min;
78 int nmax_ = fortran_nmax;
79 int nstiff_ = fortran_nstiff;
80 time_type facold_ = static_cast<time_type>(1e-4); // Fortran FACOLD
81 // Stiffness detection state
82 int iastiff_ = 0;
83 int nonsti_ = 0;
84 time_type hlamb_ = 0;
85 // For statistics (optional)
86 int nstep_ = 0;
87 int naccpt_ = 0;
88 int nrejct_ = 0;
89 int nfcn_ = 0;
90
91private:
92 void check_nan_inf(const std::string& context, const state_type& state, const state_type& y_new,
93 const state_type& error, time_type dt, time_type err, time_type err2, time_type deno) {
94 // Check for NaN/Inf in all vectors and scalars
95 for (std::size_t i = 0; i < state.size(); ++i) {
96 if (std::isnan(state[i]) || std::isinf(state[i])) {
97 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " state[" + std::to_string(i) + "]=" + std::to_string(state[i]));
98 }
99 }
100 for (std::size_t i = 0; i < y_new.size(); ++i) {
101 if (std::isnan(y_new[i]) || std::isinf(y_new[i])) {
102 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " y_new[" + std::to_string(i) + "]=" + std::to_string(y_new[i]));
103 }
104 }
105 for (std::size_t i = 0; i < error.size(); ++i) {
106 if (std::isnan(error[i]) || std::isinf(error[i])) {
107 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " error[" + std::to_string(i) + "]=" + std::to_string(error[i]));
108 }
109 }
110 if (std::isnan(dt) || std::isinf(dt)) {
111 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " dt=" + std::to_string(dt));
112 }
113 if (std::isnan(err) || std::isinf(err)) {
114 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " err=" + std::to_string(err));
115 }
116 if (std::isnan(err2) || std::isinf(err2)) {
117 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " err2=" + std::to_string(err2));
118 }
119 if (std::isnan(deno) || std::isinf(deno)) {
120 throw std::runtime_error("DOP853: NaN/Inf detected in " + context + " deno=" + std::to_string(deno));
121 }
122 }
123
124 // Compute a good initial step size (HINIT from Fortran)
125 time_type compute_initial_step(const state_type& y, time_type t, const system_function& sys, time_type t_end) const {
126 // Compute f0 = f(t, y)
127 state_type f0 = StateCreator<state_type>::create(y);
128 sys(t, y, f0);
129
130 // Compute a norm for y and f0
131 time_type dnf = 0.0, dny = 0.0;
132 for (std::size_t i = 0; i < y.size(); ++i) {
133 time_type sk = this->atol_ + this->rtol_ * std::abs(y[i]);
134 dnf += (f0[i] / sk) * (f0[i] / sk);
135 dny += (y[i] / sk) * (y[i] / sk);
136 }
137 time_type h = 1e-6;
138 if (dnf > 1e-10 && dny > 1e-10) {
139 h = std::sqrt(dny / dnf) * 0.01;
140 }
141 h = std::min<time_type>(h, std::abs(t_end - t));
142 h = std::copysign(h, t_end - t);
143
144 // Perform an explicit Euler step
145 state_type y1 = StateCreator<state_type>::create(y);
146 for (std::size_t i = 0; i < y.size(); ++i)
147 y1[i] = y[i] + h * f0[i];
148 state_type f1 = StateCreator<state_type>::create(y);
149 sys(t + h, y1, f1);
150
151 // Estimate the second derivative
152 time_type der2 = 0.0;
153 for (std::size_t i = 0; i < y.size(); ++i) {
154 time_type sk = this->atol_ + this->rtol_ * std::abs(y[i]);
155 der2 += ((f1[i] - f0[i]) / sk) * ((f1[i] - f0[i]) / sk);
156 }
157 der2 = std::sqrt(der2) / h;
158
159 // Step size is computed such that h^order * max(norm(f0), norm(der2)) = 0.01
160 time_type der12 = std::max<time_type>(std::abs(der2), std::sqrt(dnf));
161 time_type h1 = h;
162 if (der12 > 1e-15) {
163 h1 = std::pow(0.01 / der12, 1.0 / 8.0);
164 } else {
165 h1 = std::max<time_type>(1e-6, std::abs(h) * 1e-3);
166 }
167 // Avoid std::min(a, b, c) which is not standard C++
168 time_type hmax = 100 * std::abs(h);
169 time_type htmp = (h1 < hmax) ? h1 : hmax;
170 htmp = (htmp < std::abs(t_end - t)) ? htmp : std::abs(t_end - t);
171 h = std::copysign(htmp, t_end - t);
172 return h;
173 }
174
175public:
176 explicit DOP853Integrator(system_function sys,
177 time_type rtol = static_cast<time_type>(1e-8),
178 time_type atol = static_cast<time_type>(1e-10))
179 : base_type(std::move(sys), rtol, atol) {}
180
181 void step(state_type& state, time_type dt) override {
182 adaptive_step(state, dt);
183 }
184
185 // To match Fortran DOP853, we need to know the integration target time for HINIT
186 // This version assumes you set target_time_ before calling adaptive_step
187 time_type target_time_ = 0; // User must set this before integration
188
189 time_type adaptive_step(state_type& state, time_type dt) override {
190 // Fortran DOP853: if dt <= 0, estimate initial step size using HINIT (compute_initial_step)
191 time_type t = this->current_time_;
192 time_type t_end = target_time_;
193 if (t_end == t) t_end = t + 1.0; // fallback if not set
194 time_type current_dt = dt;
195 if (current_dt <= 0) {
196 // Use the system function and current state to estimate initial step
197 current_dt = compute_initial_step(state, t, this->sys_, t_end);
198 // Clamp to allowed min/max
199 current_dt = std::max<time_type>(dt_min_, std::min<time_type>(dt_max_, current_dt));
200 }
201 int attempt = 0;
202 for (; attempt < nmax_; ++attempt) {
203
204 state_type y_new = StateCreator<state_type>::create(state);
205 state_type error = StateCreator<state_type>::create(state);
206 dop853_step(state, y_new, error, current_dt);
207 nfcn_ += 12; // 12 stages per step
208
209 // Check for NaN/Inf in step computation
210 check_nan_inf("step_computation", state, y_new, error, current_dt, 0.0, 0.0, 0.0);
211
212 // Fortran error norm (ERR, ERR2, DENO, etc.)
213 time_type err = 0.0, err2 = 0.0;
214 for (std::size_t i = 0; i < state.size(); ++i) {
215 time_type sk = this->atol_ + this->rtol_ * std::max<time_type>(std::abs(state[i]), std::abs(y_new[i]));
216 // Fortran: ERRI=K4(I)-BHH1*K1(I)-BHH2*K9(I)-BHH3*K3(I) (here, error[i] is 8th-5th order diff)
217 // We use error[i] as the embedded error estimate, so for full Fortran, you may need to store all k's
218 err2 += (error[i] / sk) * (error[i] / sk); // proxy for Fortran's ERR2
219 err += (error[i] / sk) * (error[i] / sk); // Fortran's ERR
220 }
221 time_type deno = err + 0.01 * err2;
222 if (deno <= 0.0 || std::isnan(deno) || std::isinf(deno)) {
223 deno = 1.0;
224 }
225 err = std::abs(current_dt) * err * std::sqrt(1.0 / (state.size() * deno));
226 if (std::isnan(err) || std::isinf(err)) {
227 err = 1.0;
228 }
229
230 // Check for NaN/Inf in error norm calculation
231 check_nan_inf("error_norm", state, y_new, error, current_dt, err, err2, deno);
232
233 // Fortran: FAC11 = ERR**EXPO1, FAC = FAC11 / FACOLD**BETA
234 time_type expo1 = 1.0 / 8.0 - beta_ * 0.2;
235 time_type fac11 = std::pow(std::max<time_type>(err, static_cast<time_type>(1e-16)), expo1);
236 time_type fac = fac11 / std::pow(facold_, beta_);
237 // Clamp fac between fac1_ (min, <1) and fac2_ (max, >1)
238 fac = std::min<time_type>(fac2_, std::max<time_type>(fac1_, fac / safety_factor_));
239 if (std::isnan(fac) || std::isinf(fac)) {
240 fac = 1.0;
241 }
242 time_type next_dt = current_dt / fac;
243 if (next_dt <= 0.0 || std::isnan(next_dt) || std::isinf(next_dt)) {
244 next_dt = dt_min_;
245 }
246
247 if (err <= 1.0) {
248 facold_ = std::max<time_type>(err, static_cast<time_type>(1e-4));
249 naccpt_++;
250 nstep_++;
251 state = y_new;
252 this->advance_time(current_dt);
253
254 // stiffness detection (Fortran HLAMB)
255 if (nstiff_ > 0 && (naccpt_ % nstiff_ == 0 || iastiff_ > 0)) {
256 // Compute HLAMB = |h| * sqrt(stnum / stden)
257 time_type stnum = 0, stden = 0;
258 for (std::size_t i = 0; i < state.size(); ++i) {
259 stnum += (error[i]) * (error[i]);
260 stden += (y_new[i] - state[i]) * (y_new[i] - state[i]);
261 }
262 if (stden > 0) hlamb_ = std::abs(current_dt) * std::sqrt(stnum / stden);
263 if (hlamb_ > 6.1) {
264 nonsti_ = 0;
265 iastiff_++;
266 if (iastiff_ == 15) {
267 throw std::runtime_error("DOP853: Problem seems to become stiff");
268 }
269 } else {
270 nonsti_++;
271 if (nonsti_ == 6) iastiff_ = 0;
272 }
273 }
274 // Clamp next step size
275 next_dt = std::max<time_type>(dt_min_, std::min<time_type>(dt_max_, next_dt));
276 return next_dt;
277 } else {
278 // Step rejected
279 nrejct_++;
280 nstep_++;
281 next_dt = current_dt / std::min<time_type>(fac1_, fac11 / safety_factor_);
282 current_dt = std::max<time_type>(dt_min_, next_dt);
283 }
284 }
285 throw std::runtime_error("DOP853: Maximum number of step size reductions or steps exceeded");
286 }
287
288private:
289 // Helper functions to access coefficients
290 static constexpr time_type get_c(int i) { return diffeq::integrators::ode::dop853::C<time_type>[i]; }
291 static constexpr time_type get_a(int i, int j) { return diffeq::integrators::ode::dop853::A<time_type>[i][j]; }
292 static constexpr time_type get_b(int i) { return diffeq::integrators::ode::dop853::B<time_type>[i]; }
293 static constexpr time_type get_e5(int i) { return diffeq::integrators::ode::dop853::E5<time_type>[i]; }
294
295 void dop853_step(const state_type& y, state_type& y_new, state_type& error, time_type dt) {
296 // Allocate all needed k vectors
297 std::vector<state_type> k(12, StateCreator<state_type>::create(y));
298 state_type temp = StateCreator<state_type>::create(y);
299 time_type t = this->current_time_;
300
301 // k1 = f(t, y)
302 this->sys_(t, y, k[0]);
303
304 // k2 = f(t + c2*dt, y + dt*a21*k1)
305 for (std::size_t i = 0; i < y.size(); ++i)
306 temp[i] = y[i] + dt * get_a(1, 0) * k[0][i];
307 this->sys_(t + get_c(1) * dt, temp, k[1]);
308
309 // k3 = f(t + c3*dt, y + dt*(a31*k1 + a32*k2))
310 for (std::size_t i = 0; i < y.size(); ++i)
311 temp[i] = y[i] + dt * (get_a(2, 0) * k[0][i] + get_a(2, 1) * k[1][i]);
312 this->sys_(t + get_c(2) * dt, temp, k[2]);
313
314 // k4 = f(t + c4*dt, y + dt*(a41*k1 + a43*k3))
315 for (std::size_t i = 0; i < y.size(); ++i)
316 temp[i] = y[i] + dt * (get_a(3, 0) * k[0][i] + get_a(3, 2) * k[2][i]);
317 this->sys_(t + get_c(3) * dt, temp, k[3]);
318
319 // k5 = f(t + c5*dt, y + dt*(a51*k1 + a53*k3 + a54*k4))
320 for (std::size_t i = 0; i < y.size(); ++i)
321 temp[i] = y[i] + dt * (get_a(4, 0) * k[0][i] + get_a(4, 2) * k[2][i] + get_a(4, 3) * k[3][i]);
322 this->sys_(t + get_c(4) * dt, temp, k[4]);
323
324 // k6 = f(t + c6*dt, y + dt*(a61*k1 + a64*k4 + a65*k5))
325 for (std::size_t i = 0; i < y.size(); ++i)
326 temp[i] = y[i] + dt * (get_a(5, 0) * k[0][i] + get_a(5, 3) * k[3][i] + get_a(5, 4) * k[4][i]);
327 this->sys_(t + get_c(5) * dt, temp, k[5]);
328
329 // k7 = f(t + c7*dt, y + dt*(a71*k1 + a74*k4 + a75*k5 + a76*k6))
330 for (std::size_t i = 0; i < y.size(); ++i)
331 temp[i] = y[i] + dt * (get_a(6, 0) * k[0][i] + get_a(6, 3) * k[3][i] + get_a(6, 4) * k[4][i] + get_a(6, 5) * k[5][i]);
332 this->sys_(t + get_c(6) * dt, temp, k[6]);
333
334 // k8 = f(t + c8*dt, y + dt*(a81*k1 + a84*k4 + a85*k5 + a86*k6 + a87*k7))
335 for (std::size_t i = 0; i < y.size(); ++i)
336 temp[i] = y[i] + dt * (get_a(7, 0) * k[0][i] + get_a(7, 3) * k[3][i] + get_a(7, 4) * k[4][i] + get_a(7, 5) * k[5][i] + get_a(7, 6) * k[6][i]);
337 this->sys_(t + get_c(7) * dt, temp, k[7]);
338
339 // k9 = f(t + c9*dt, y + dt*(a91*k1 + a94*k4 + a95*k5 + a96*k6 + a97*k7 + a98*k8))
340 for (std::size_t i = 0; i < y.size(); ++i)
341 temp[i] = y[i] + dt * (get_a(8, 0) * k[0][i] + get_a(8, 3) * k[3][i] + get_a(8, 4) * k[4][i] + get_a(8, 5) * k[5][i] + get_a(8, 6) * k[6][i] + get_a(8, 7) * k[7][i]);
342 this->sys_(t + get_c(8) * dt, temp, k[8]);
343
344 // k10 = f(t + c10*dt, y + dt*(a101*k1 + a104*k4 + a105*k5 + a106*k6 + a107*k7 + a108*k8 + a109*k9))
345 for (std::size_t i = 0; i < y.size(); ++i)
346 temp[i] = y[i] + dt * (get_a(9, 0) * k[0][i] + get_a(9, 3) * k[3][i] + get_a(9, 4) * k[4][i] + get_a(9, 5) * k[5][i] + get_a(9, 6) * k[6][i] + get_a(9, 7) * k[7][i] + get_a(9, 8) * k[8][i]);
347 this->sys_(t + get_c(9) * dt, temp, k[9]);
348
349 // k11 = f(t + c11*dt, y + dt*(a111*k1 + a114*k4 + a115*k5 + a116*k6 + a117*k7 + a118*k8 + a119*k9 + a1110*k10))
350 for (std::size_t i = 0; i < y.size(); ++i)
351 temp[i] = y[i] + dt * (get_a(10, 0) * k[0][i] + get_a(10, 3) * k[3][i] + get_a(10, 4) * k[4][i] + get_a(10, 5) * k[5][i] + get_a(10, 6) * k[6][i] + get_a(10, 7) * k[7][i] + get_a(10, 8) * k[8][i] + get_a(10, 9) * k[9][i]);
352 this->sys_(t + get_c(10) * dt, temp, k[10]);
353
354 // k12 = f(t + dt, y + dt*(a121*k1 + a124*k4 + a125*k5 + a126*k6 + a127*k7 + a128*k8 + a129*k9 + a1210*k10 + a1211*k11))
355 for (std::size_t i = 0; i < y.size(); ++i)
356 temp[i] = y[i] + dt * (get_a(11, 0) * k[0][i] + get_a(11, 3) * k[3][i] + get_a(11, 4) * k[4][i] + get_a(11, 5) * k[5][i] + get_a(11, 6) * k[6][i] + get_a(11, 7) * k[7][i] + get_a(11, 8) * k[8][i] + get_a(11, 9) * k[9][i] + get_a(11, 10) * k[10][i]);
357 this->sys_(t + dt, temp, k[11]);
358
359 // 8th order solution (y_new)
360 for (std::size_t i = 0; i < y.size(); ++i) {
361 y_new[i] = y[i] + dt * (get_b(0) * k[0][i] + get_b(5) * k[5][i] + get_b(6) * k[6][i] + get_b(7) * k[7][i] + get_b(8) * k[8][i] + get_b(9) * k[9][i] + get_b(10) * k[10][i] + get_b(11) * k[11][i]);
362 }
363
364 // 5th order error estimate (embedded)
365 for (std::size_t i = 0; i < y.size(); ++i) {
366 error[i] = dt * (get_e5(0) * k[0][i] + get_e5(5) * k[5][i] + get_e5(6) * k[6][i] + get_e5(7) * k[7][i] + get_e5(8) * k[8][i] + get_e5(9) * k[9][i] + get_e5(10) * k[10][i] + get_e5(11) * k[11][i]);
367 }
368 }
369};
370
371} // namespace diffeq
DOP853 (Dormand-Prince 8(5,3)) adaptive integrator.
Definition dop853.hpp:51