Program Listing for File rk23.hpp
↰ Return to documentation for file (include/integrators/ode/rk23.hpp
)
#pragma once
#include <core/concepts.hpp>
#include <core/adaptive_integrator.hpp>
#include <core/state_creator.hpp>
#include <stdexcept>
namespace diffeq {
template<system_state S>
class RK23Integrator : public core::AdaptiveIntegrator<S> {
public:
using base_type = core::AdaptiveIntegrator<S>;
using state_type = typename base_type::state_type;
using time_type = typename base_type::time_type;
using value_type = typename base_type::value_type;
using system_function = typename base_type::system_function;
explicit RK23Integrator(system_function sys,
time_type rtol = static_cast<time_type>(1e-6),
time_type atol = static_cast<time_type>(1e-9))
: base_type(std::move(sys), rtol, atol) {}
void step(state_type& state, time_type dt) override {
adaptive_step(state, dt);
}
time_type adaptive_step(state_type& state, time_type dt) override {
const int max_attempts = 10;
time_type current_dt = dt;
for (int attempt = 0; attempt < max_attempts; ++attempt) {
state_type y_new = StateCreator<state_type>::create(state);
state_type error = StateCreator<state_type>::create(state);
rk23_step(state, y_new, error, current_dt);
// Calculate error norm
time_type err_norm = this->error_norm(error, y_new);
if (err_norm <= 1.0) {
// Accept step
state = y_new;
this->advance_time(current_dt);
// Suggest next step size
time_type next_dt = this->suggest_step_size(current_dt, err_norm, 3);
return std::max<time_type>(this->dt_min_, std::min<time_type>(this->dt_max_, next_dt));
} else {
// Reject step and reduce step size
current_dt *= std::max<time_type>(this->safety_factor_ * std::pow(err_norm, -1.0/3.0),
static_cast<time_type>(0.1));
current_dt = std::max<time_type>(current_dt, this->dt_min_);
}
}
throw std::runtime_error("RK23: Maximum number of step size reductions exceeded");
}
private:
void rk23_step(const state_type& y, state_type& y_new, state_type& error, time_type dt) {
// Bogacki-Shampine coefficients matching SciPy's RK23
// C = [0, 1/2, 3/4]
// A = [[0, 0, 0], [1/2, 0, 0], [0, 3/4, 0]]
// B = [2/9, 1/3, 4/9] (3rd order)
// E = [5/72, -1/12, -1/9, 1/8] (error estimate)
state_type k1 = StateCreator<state_type>::create(y);
state_type k2 = StateCreator<state_type>::create(y);
state_type k3 = StateCreator<state_type>::create(y);
state_type k4 = StateCreator<state_type>::create(y); // For error estimation
state_type temp = StateCreator<state_type>::create(y);
time_type t = this->current_time_;
// k1 = f(t, y)
this->sys_(t, y, k1);
// k2 = f(t + dt/2, y + dt*k1/2)
for (std::size_t i = 0; i < y.size(); ++i) {
auto y_it = y.begin();
auto k1_it = k1.begin();
auto temp_it = temp.begin();
temp_it[i] = y_it[i] + dt * k1_it[i] / static_cast<time_type>(2);
}
this->sys_(t + dt / static_cast<time_type>(2), temp, k2);
// k3 = f(t + 3*dt/4, y + 3*dt*k2/4)
for (std::size_t i = 0; i < y.size(); ++i) {
auto y_it = y.begin();
auto k2_it = k2.begin();
auto temp_it = temp.begin();
temp_it[i] = y_it[i] + static_cast<time_type>(3) * dt * k2_it[i] / static_cast<time_type>(4);
}
this->sys_(t + static_cast<time_type>(3) * dt / static_cast<time_type>(4), temp, k3);
// 3rd order solution: y_new = y + dt*(2*k1/9 + k2/3 + 4*k3/9)
for (std::size_t i = 0; i < y.size(); ++i) {
auto y_it = y.begin();
auto k1_it = k1.begin();
auto k2_it = k2.begin();
auto k3_it = k3.begin();
auto y_new_it = y_new.begin();
y_new_it[i] = y_it[i] + dt * (static_cast<time_type>(2) * k1_it[i] / static_cast<time_type>(9) +
k2_it[i] / static_cast<time_type>(3) +
static_cast<time_type>(4) * k3_it[i] / static_cast<time_type>(9));
}
// k4 = f(t + dt, y_new) - needed for error estimation
this->sys_(t + dt, y_new, k4);
// Error estimate using E = [5/72, -1/12, -1/9, 1/8]
for (std::size_t i = 0; i < y.size(); ++i) {
auto k1_it = k1.begin();
auto k2_it = k2.begin();
auto k3_it = k3.begin();
auto k4_it = k4.begin();
auto error_it = error.begin();
error_it[i] = dt * (static_cast<time_type>(5) * k1_it[i] / static_cast<time_type>(72) -
k2_it[i] / static_cast<time_type>(12) -
k3_it[i] / static_cast<time_type>(9) +
k4_it[i] / static_cast<time_type>(8));
}
}
};
} // namespace diffeq