DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
sra.hpp
#pragma once

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <core/state_creator.hpp>
#include <sde/sde_base.hpp>

8namespace diffeq {
9
13template<typename T>
14struct SRATableau {
15 // Drift coefficients
16 std::vector<std::vector<T>> A0;
17 std::vector<T> c0;
18 std::vector<T> alpha;
19
20 // Diffusion coefficients
21 std::vector<std::vector<T>> B0;
22 std::vector<T> c1;
23 std::vector<T> beta1;
24 std::vector<T> beta2;
25
26 int stages;
27 T order;
28};
29
43template<system_state StateType>
44class SRAIntegrator : public sde::AbstractSDEIntegrator<StateType> {
45public:
47 using state_type = typename base_type::state_type;
48 using time_type = typename base_type::time_type;
49 using value_type = typename base_type::value_type;
51
52 explicit SRAIntegrator(std::shared_ptr<typename base_type::sde_problem_type> problem,
53 std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr,
54 tableau_type tableau = SRAIntegrator::create_sra1_tableau())
55 : base_type(problem, wiener)
56 , tableau_(std::move(tableau)) {}
57
58 void step(state_type& state, time_type dt) override {
59 const int stages = tableau_.stages;
60
61 // Create temporary states
62 std::vector<state_type> H0(stages);
63 for (int i = 0; i < stages; ++i) {
65 }
66
67 state_type dW = StateCreator<state_type>::create(state);
68 state_type dZ = StateCreator<state_type>::create(state); // For chi2 computation
69 state_type ftmp = StateCreator<state_type>::create(state);
70 state_type gtmp = StateCreator<state_type>::create(state);
71 state_type atemp = StateCreator<state_type>::create(state);
72 state_type btemp = StateCreator<state_type>::create(state);
73 state_type E2 = StateCreator<state_type>::create(state);
74
75 // Generate Wiener increments
76 this->wiener_->generate_increment(dW, dt);
77 this->wiener_->generate_increment(dZ, dt); // Independent for chi2
78
79 // Compute chi2 = (1/2)*(dW + dZ/sqrt(3)) for I_(1,0)/h
80 value_type sqrt3 = std::sqrt(static_cast<value_type>(3));
81 state_type chi2 = StateCreator<state_type>::create(state);
82 for (size_t j = 0; j < chi2.size(); ++j) {
83 auto chi2_it = chi2.begin();
84 auto dW_it = dW.begin();
85 auto dZ_it = dZ.begin();
86 chi2_it[j] = static_cast<value_type>(0.5) * (dW_it[j] + dZ_it[j] / sqrt3);
87 }
88
89 // Initialize H0[0] = current state
90 for (size_t j = 0; j < state.size(); ++j) {
91 auto state_it = state.begin();
92 auto H0_0_it = H0[0].begin();
93 H0_0_it[j] = state_it[j];
94 }
95
96 // Compute stages
97 for (int i = 1; i < stages; ++i) {
98 // Compute A0temp and B0temp
99 state_type A0temp = StateCreator<state_type>::create(state);
100 state_type B0temp = StateCreator<state_type>::create(state);
101
102 for (int j = 0; j < i; ++j) {
103 this->problem_->drift(this->current_time_ + tableau_.c0[j] * dt, H0[j], ftmp);
104 this->problem_->diffusion(this->current_time_ + tableau_.c1[j] * dt, H0[j], gtmp);
105
106 for (size_t k = 0; k < state.size(); ++k) {
107 auto A0temp_it = A0temp.begin();
108 auto B0temp_it = B0temp.begin();
109 auto ftmp_it = ftmp.begin();
110 auto gtmp_it = gtmp.begin();
111 auto chi2_it = chi2.begin();
112
113 A0temp_it[k] += tableau_.A0[j][i] * ftmp_it[k];
114 B0temp_it[k] += tableau_.B0[j][i] * gtmp_it[k] * chi2_it[k];
115 }
116 }
117
118 // H0[i] = state + dt*A0temp + B0temp
119 for (size_t k = 0; k < state.size(); ++k) {
120 auto state_it = state.begin();
121 auto H0_i_it = H0[i].begin();
122 auto A0temp_it = A0temp.begin();
123 auto B0temp_it = B0temp.begin();
124
125 H0_i_it[k] = state_it[k] + dt * A0temp_it[k] + B0temp_it[k];
126 }
127 }
128
129 // Compute final update terms
130 std::fill(atemp.begin(), atemp.end(), value_type(0));
131 std::fill(btemp.begin(), btemp.end(), value_type(0));
132 std::fill(E2.begin(), E2.end(), value_type(0));
133
134 for (int i = 0; i < stages; ++i) {
135 this->problem_->drift(this->current_time_ + tableau_.c0[i] * dt, H0[i], ftmp);
136 this->problem_->diffusion(this->current_time_ + tableau_.c1[i] * dt, H0[i], gtmp);
137
138 for (size_t k = 0; k < state.size(); ++k) {
139 auto atemp_it = atemp.begin();
140 auto btemp_it = btemp.begin();
141 auto E2_it = E2.begin();
142 auto ftmp_it = ftmp.begin();
143 auto gtmp_it = gtmp.begin();
144 auto dW_it = dW.begin();
145 auto chi2_it = chi2.begin();
146
147 atemp_it[k] += tableau_.alpha[i] * ftmp_it[k];
148 btemp_it[k] += tableau_.beta1[i] * gtmp_it[k] * dW_it[k];
149 E2_it[k] += tableau_.beta2[i] * gtmp_it[k] * chi2_it[k];
150 }
151 }
152
153 // Final state update: u = uprev + dt*atemp + btemp + E2
154 for (size_t k = 0; k < state.size(); ++k) {
155 auto state_it = state.begin();
156 auto atemp_it = atemp.begin();
157 auto btemp_it = btemp.begin();
158 auto E2_it = E2.begin();
159
160 state_it[k] += dt * atemp_it[k] + btemp_it[k] + E2_it[k];
161 }
162
163 this->advance_time(dt);
164 }
165
166 std::string name() const override {
167 return "SRA (Strong Order 1.5 for Additive Noise)";
168 }
169
170 void set_tableau(const tableau_type& tableau) {
171 tableau_ = tableau;
172 }
173
174private:
175 tableau_type tableau_;
176
177 // Default SRA1 tableau
178 static tableau_type create_sra1_tableau() {
179 tableau_type tableau;
180 tableau.stages = 2;
181 tableau.order = static_cast<value_type>(1.5);
182
183 // Drift coefficients
184 tableau.A0 = {{0, 0}, {1, 0}};
185 tableau.c0 = {0, 1};
186 tableau.alpha = {static_cast<value_type>(0.5), static_cast<value_type>(0.5)};
187
188 // Diffusion coefficients
189 tableau.B0 = {{0, 0}, {1, 0}};
190 tableau.c1 = {0, 1};
191 tableau.beta1 = {static_cast<value_type>(0.5), static_cast<value_type>(0.5)};
192 tableau.beta2 = {0, 1};
193
194 return tableau;
195 }
196};
197
198} // namespace diffeq
SRA (Stochastic Runge-Kutta for additive noise SDEs) integrator.
Definition sra.hpp:44
Abstract base class for SDE integrators.
Definition sde_base.hpp:147
Tableau coefficients for SRA methods.
Definition sra.hpp:14