DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
rk4.hpp
1#pragma once
2#include <core/concepts.hpp>
3#include <core/abstract_integrator.hpp>
4#include <core/state_creator.hpp>
5
6namespace diffeq {
7
18template<system_state S>
20public:
22 using state_type = typename base_type::state_type;
23 using time_type = typename base_type::time_type;
24 using value_type = typename base_type::value_type;
25 using system_function = typename base_type::system_function;
26
27 explicit RK4Integrator(system_function sys)
28 : base_type(std::move(sys)) {}
29
30 void step(state_type& state, time_type dt) override {
31 // Create temporary states for RK4 calculations
32 state_type k1 = StateCreator<state_type>::create(state);
33 state_type k2 = StateCreator<state_type>::create(state);
34 state_type k3 = StateCreator<state_type>::create(state);
35 state_type k4 = StateCreator<state_type>::create(state);
36 state_type temp_state = StateCreator<state_type>::create(state);
37
38 // k1 = f(t, y)
39 this->sys_(this->current_time_, state, k1);
40
41 // k2 = f(t + dt/2, y + dt*k1/2)
42 for (std::size_t i = 0; i < state.size(); ++i) {
43 auto state_it = state.begin();
44 auto k1_it = k1.begin();
45 auto temp_it = temp_state.begin();
46
47 temp_it[i] = state_it[i] + dt * k1_it[i] / static_cast<time_type>(2);
48 }
49 this->sys_(this->current_time_ + dt / static_cast<time_type>(2), temp_state, k2);
50
51 // k3 = f(t + dt/2, y + dt*k2/2)
52 for (std::size_t i = 0; i < state.size(); ++i) {
53 auto state_it = state.begin();
54 auto k2_it = k2.begin();
55 auto temp_it = temp_state.begin();
56
57 temp_it[i] = state_it[i] + dt * k2_it[i] / static_cast<time_type>(2);
58 }
59 this->sys_(this->current_time_ + dt / static_cast<time_type>(2), temp_state, k3);
60
61 // k4 = f(t + dt, y + dt*k3)
62 for (std::size_t i = 0; i < state.size(); ++i) {
63 auto state_it = state.begin();
64 auto k3_it = k3.begin();
65 auto temp_it = temp_state.begin();
66
67 temp_it[i] = state_it[i] + dt * k3_it[i];
68 }
69 this->sys_(this->current_time_ + dt, temp_state, k4);
70
71 // y_new = y + dt/6 * (k1 + 2*k2 + 2*k3 + k4)
72 for (std::size_t i = 0; i < state.size(); ++i) {
73 auto state_it = state.begin();
74 auto k1_it = k1.begin();
75 auto k2_it = k2.begin();
76 auto k3_it = k3.begin();
77 auto k4_it = k4.begin();
78
79 state_it[i] = state_it[i] + dt * (k1_it[i] + static_cast<time_type>(2) * k2_it[i] +
80 static_cast<time_type>(2) * k3_it[i] + k4_it[i]) / static_cast<time_type>(6);
81 }
82
83 this->advance_time(dt);
84 }
85};
86
87} // namespace diffeq
Classical 4th-order Runge-Kutta integrator.
Definition rk4.hpp:19