DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
sri1.hpp
#pragma once

#include <cmath>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>

#include <sde/sde_base.hpp>
#include <core/state_creator.hpp>

7namespace diffeq {
8
18template<system_state StateType>
19class SRI1Integrator : public sde::AbstractSDEIntegrator<StateType> {
20public:
22 using state_type = typename base_type::state_type;
23 using time_type = typename base_type::time_type;
24 using value_type = typename base_type::value_type;
25
26 explicit SRI1Integrator(std::shared_ptr<typename base_type::sde_problem_type> problem,
27 std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr)
28 : base_type(problem, wiener) {}
29
30 void step(state_type& state, time_type dt) override {
31 // Create temporary states
32 state_type k1 = StateCreator<state_type>::create(state);
33 state_type k2 = StateCreator<state_type>::create(state);
34 state_type g1 = StateCreator<state_type>::create(state);
35 state_type g2 = StateCreator<state_type>::create(state);
36 state_type temp_state = StateCreator<state_type>::create(state);
37 state_type dW = StateCreator<state_type>::create(state);
38
39 // Generate Wiener increments
40 this->wiener_->generate_increment(dW, dt);
41
42 time_type t = this->current_time_;
43 value_type sqrt_dt = std::sqrt(static_cast<value_type>(dt));
44
45 // Stage 1: k1 = f(t, X), g1 = g(t, X)
46 this->problem_->drift(t, state, k1);
47 this->problem_->diffusion(t, state, g1);
48
49 // Intermediate state for stage 2
50 for (size_t i = 0; i < state.size(); ++i) {
51 auto state_it = state.begin();
52 auto k1_it = k1.begin();
53 auto g1_it = g1.begin();
54 auto temp_it = temp_state.begin();
55 auto dW_it = dW.begin();
56
57 temp_it[i] = state_it[i] + k1_it[i] * dt + g1_it[i] * sqrt_dt;
58 }
59
60 // Stage 2: k2 = f(t + dt, temp_state), g2 = g(t + dt, temp_state)
61 this->problem_->drift(t + dt, temp_state, k2);
62 this->problem_->diffusion(t + dt, temp_state, g2);
63
64 // Apply noise to diffusion terms
65 state_type g1_noise = StateCreator<state_type>::create(state);
66 state_type g2_noise = StateCreator<state_type>::create(state);
67
68 for (size_t i = 0; i < state.size(); ++i) {
69 auto g1_it = g1.begin();
70 auto g2_it = g2.begin();
71 auto g1_noise_it = g1_noise.begin();
72 auto g2_noise_it = g2_noise.begin();
73
74 g1_noise_it[i] = g1_it[i];
75 g2_noise_it[i] = g2_it[i];
76 }
77
78 this->problem_->apply_noise(t, state, g1_noise, dW);
79 this->problem_->apply_noise(t + dt, temp_state, g2_noise, dW);
80
81 // Final update: X_{n+1} = X_n + (k1 + k2)/2 * dt + (g1 + g2)/2 * dW
82 for (size_t i = 0; i < state.size(); ++i) {
83 auto state_it = state.begin();
84 auto k1_it = k1.begin();
85 auto k2_it = k2.begin();
86 auto g1_noise_it = g1_noise.begin();
87 auto g2_noise_it = g2_noise.begin();
88
89 state_it[i] += (k1_it[i] + k2_it[i]) * dt * static_cast<value_type>(0.5) +
90 (g1_noise_it[i] + g2_noise_it[i]) * static_cast<value_type>(0.5);
91 }
92
93 this->advance_time(dt);
94 }
95
96 std::string name() const override {
97 return "SRI1 (Stochastic Runge-Kutta)";
98 }
99};
100
101} // namespace diffeq
Stochastic Runge-Kutta method (SRI1)
Definition sri1.hpp:19
Abstract base class for SDE integrators.
Definition sde_base.hpp:147