DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
sri.hpp
1#pragma once
2
3#include <sde/sde_base.hpp>
4#include <core/state_creator.hpp>
5#include <cmath>
6#include <vector>
7#include <algorithm>
8
9namespace diffeq {
10
/**
 * @brief Coefficient tableau for a stochastic Runge-Kutta method of SRI type.
 *
 * Matrices follow the usual Butcher convention M[i][j]: row i is the stage
 * being formed and column j (< i for explicit schemes) is the earlier stage it
 * depends on, so explicit tableaus are strictly lower triangular. Every vector
 * is expected to have extent @c stages and every matrix @c stages x @c stages.
 *
 * @tparam T scalar coefficient type.
 */
template<typename T>
struct SRITableau {
    // Drift side: A0 / A1 weight drift evaluations when forming the H0 / H1
    // stages, c0 holds the drift evaluation time offsets (as fractions of dt),
    // alpha holds the drift weights of the final update.
    std::vector<std::vector<T>> A0, A1;
    std::vector<T> c0;
    std::vector<T> alpha;

    // Diffusion side: B0 / B1 weight diffusion evaluations when forming the
    // H0 / H1 stages, c1 holds the diffusion time offsets, and beta1..beta4
    // weight the stochastic terms dW (I_(1)), I_(1,1)/sqrt(h), I_(1,0)/h and
    // I_(1,1,1)/h in the final update.
    std::vector<std::vector<T>> B0, B1;
    std::vector<T> c1;
    std::vector<T> beta1, beta2, beta3, beta4;

    // Initialized so that a default-constructed tableau is well defined
    // (zero stages) instead of holding indeterminate values.
    int stages = 0;
    T order = T(0);  // nominal order of the scheme (informational)
};
29
43template<system_state StateType>
44class SRIIntegrator : public sde::AbstractSDEIntegrator<StateType> {
45public:
47 using state_type = typename base_type::state_type;
48 using time_type = typename base_type::time_type;
49 using value_type = typename base_type::value_type;
51
52 explicit SRIIntegrator(std::shared_ptr<typename base_type::sde_problem_type> problem,
53 std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr,
54 tableau_type tableau = SRIIntegrator::create_sriw1_tableau())
55 : base_type(problem, wiener)
56 , tableau_(std::move(tableau)) {}
57
58 void step(state_type& state, time_type dt) override {
59 const int stages = tableau_.stages;
60
61 // Create temporary states
62 std::vector<state_type> H0(stages), H1(stages);
63 for (int i = 0; i < stages; ++i) {
66 }
67
68 state_type dW = StateCreator<state_type>::create(state);
69 state_type dZ = StateCreator<state_type>::create(state);
70 state_type ftmp = StateCreator<state_type>::create(state);
71 state_type gtmp = StateCreator<state_type>::create(state);
72
73 // Generate Wiener increments
74 this->wiener_->generate_increment(dW, dt);
75 this->wiener_->generate_increment(dZ, dt);
76
77 // Compute multiple stochastic integrals
78 value_type sqrt3 = std::sqrt(static_cast<value_type>(3));
79 value_type sqrt_dt = std::sqrt(static_cast<value_type>(dt));
80
81 // chi1 = (1/2) * ((dW)^2 - dt) / sqrt(dt) for I_(1,1)/sqrt(h)
82 // chi2 = (1/2) * (dW + dZ/sqrt(3)) for I_(1,0)/h
83 // chi3 = (1/6) * ((dW)^3 - 3*dW*dt) / dt for I_(1,1,1)/h
84 state_type chi1 = StateCreator<state_type>::create(state);
85 state_type chi2 = StateCreator<state_type>::create(state);
86 state_type chi3 = StateCreator<state_type>::create(state);
87
88 for (size_t j = 0; j < state.size(); ++j) {
89 auto dW_it = dW.begin();
90 auto dZ_it = dZ.begin();
91 auto chi1_it = chi1.begin();
92 auto chi2_it = chi2.begin();
93 auto chi3_it = chi3.begin();
94
95 value_type dW_val = dW_it[j];
96 value_type dW_squared = dW_val * dW_val;
97
98 chi1_it[j] = static_cast<value_type>(0.5) * (dW_squared - dt) / sqrt_dt;
99 chi2_it[j] = static_cast<value_type>(0.5) * (dW_val + dZ_it[j] / sqrt3);
100 chi3_it[j] = static_cast<value_type>(1.0/6.0) * (dW_val * dW_squared - 3 * dW_val * dt) / dt;
101 }
102
103 // Initialize H0[0] = H1[0] = current state
104 for (size_t j = 0; j < state.size(); ++j) {
105 auto state_it = state.begin();
106 auto H0_0_it = H0[0].begin();
107 auto H1_0_it = H1[0].begin();
108 H0_0_it[j] = state_it[j];
109 H1_0_it[j] = state_it[j];
110 }
111
112 // Compute stages
113 for (int i = 1; i < stages; ++i) {
114 state_type A0temp = StateCreator<state_type>::create(state);
115 state_type A1temp = StateCreator<state_type>::create(state);
116 state_type B0temp = StateCreator<state_type>::create(state);
117 state_type B1temp = StateCreator<state_type>::create(state);
118
119 for (int j = 0; j < i; ++j) {
120 this->problem_->drift(this->current_time_ + tableau_.c0[j] * dt, H0[j], ftmp);
121 this->problem_->diffusion(this->current_time_ + tableau_.c1[j] * dt, H1[j], gtmp);
122
123 for (size_t k = 0; k < state.size(); ++k) {
124 auto A0temp_it = A0temp.begin();
125 auto A1temp_it = A1temp.begin();
126 auto B0temp_it = B0temp.begin();
127 auto B1temp_it = B1temp.begin();
128 auto ftmp_it = ftmp.begin();
129 auto gtmp_it = gtmp.begin();
130 auto chi1_it = chi1.begin();
131 auto chi2_it = chi2.begin();
132
133 A0temp_it[k] += tableau_.A0[j][i] * ftmp_it[k];
134 A1temp_it[k] += tableau_.A1[j][i] * ftmp_it[k];
135 B0temp_it[k] += tableau_.B0[j][i] * gtmp_it[k];
136 B1temp_it[k] += tableau_.B1[j][i] * gtmp_it[k] * chi1_it[k];
137 }
138 }
139
140 // Update H0[i] and H1[i]
141 for (size_t k = 0; k < state.size(); ++k) {
142 auto state_it = state.begin();
143 auto H0_i_it = H0[i].begin();
144 auto H1_i_it = H1[i].begin();
145 auto A0temp_it = A0temp.begin();
146 auto A1temp_it = A1temp.begin();
147 auto B0temp_it = B0temp.begin();
148 auto B1temp_it = B1temp.begin();
149 auto chi2_it = chi2.begin();
150 auto dW_it = dW.begin();
151
152 H0_i_it[k] = state_it[k] + dt * A0temp_it[k] + B0temp_it[k] * dW_it[k];
153 H1_i_it[k] = state_it[k] + dt * A1temp_it[k] + B0temp_it[k] * sqrt_dt + B1temp_it[k] + chi2_it[k] * B0temp_it[k];
154 }
155 }
156
157 // Compute final update
158 state_type drift_sum = StateCreator<state_type>::create(state);
159 state_type E1 = StateCreator<state_type>::create(state);
160 state_type E2 = StateCreator<state_type>::create(state);
161 state_type E3 = StateCreator<state_type>::create(state);
162
163 std::fill(drift_sum.begin(), drift_sum.end(), value_type(0));
164 std::fill(E1.begin(), E1.end(), value_type(0));
165 std::fill(E2.begin(), E2.end(), value_type(0));
166 std::fill(E3.begin(), E3.end(), value_type(0));
167
168 for (int i = 0; i < stages; ++i) {
169 this->problem_->drift(this->current_time_ + tableau_.c0[i] * dt, H0[i], ftmp);
170 this->problem_->diffusion(this->current_time_ + tableau_.c1[i] * dt, H1[i], gtmp);
171
172 for (size_t k = 0; k < state.size(); ++k) {
173 auto drift_sum_it = drift_sum.begin();
174 auto E1_it = E1.begin();
175 auto E2_it = E2.begin();
176 auto E3_it = E3.begin();
177 auto ftmp_it = ftmp.begin();
178 auto gtmp_it = gtmp.begin();
179 auto dW_it = dW.begin();
180 auto chi1_it = chi1.begin();
181 auto chi2_it = chi2.begin();
182 auto chi3_it = chi3.begin();
183
184 drift_sum_it[k] += tableau_.alpha[i] * ftmp_it[k];
185 E1_it[k] += tableau_.beta1[i] * gtmp_it[k] * dW_it[k];
186 E2_it[k] += tableau_.beta2[i] * gtmp_it[k] * chi1_it[k];
187 E2_it[k] += tableau_.beta3[i] * gtmp_it[k] * chi2_it[k];
188 E3_it[k] += tableau_.beta4[i] * gtmp_it[k] * chi3_it[k];
189 }
190 }
191
192 // Final state update
193 for (size_t k = 0; k < state.size(); ++k) {
194 auto state_it = state.begin();
195 auto drift_sum_it = drift_sum.begin();
196 auto E1_it = E1.begin();
197 auto E2_it = E2.begin();
198 auto E3_it = E3.begin();
199
200 state_it[k] += dt * drift_sum_it[k] + E1_it[k] + E2_it[k] + E3_it[k];
201 }
202
203 this->advance_time(dt);
204 }
205
206 std::string name() const override {
207 return "SRI (Strong Order 1.5 for General Itô SDEs)";
208 }
209
210 void set_tableau(const tableau_type& tableau) {
211 tableau_ = tableau;
212 }
213
214private:
215 tableau_type tableau_;
216
217 // Default SRIW1 tableau
218 static tableau_type create_sriw1_tableau() {
219 tableau_type tableau;
220 tableau.stages = 2;
221 tableau.order = static_cast<value_type>(1.5);
222
223 // Basic SRIW1 coefficients (simplified)
224 tableau.A0 = {{0, 0}, {1, 0}};
225 tableau.A1 = {{0, 0}, {1, 0}};
226 tableau.c0 = {0, 1};
227 tableau.alpha = {static_cast<value_type>(0.5), static_cast<value_type>(0.5)};
228
229 tableau.B0 = {{0, 0}, {1, 0}};
230 tableau.B1 = {{0, 0}, {1, 0}};
231 tableau.c1 = {0, 1};
232 tableau.beta1 = {static_cast<value_type>(0.5), static_cast<value_type>(0.5)};
233 tableau.beta2 = {0, 1};
234 tableau.beta3 = {0, static_cast<value_type>(0.5)};
235 tableau.beta4 = {0, static_cast<value_type>(1.0/6.0)};
236
237 return tableau;
238 }
239};
240
241} // namespace diffeq
SRI (Stochastic Runge-Kutta for general Itô SDEs) integrator.
Definition sri.hpp:44
Abstract base class for SDE integrators.
Definition sde_base.hpp:147
Tableau coefficients for SRI methods.
Definition sri.hpp:15