DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
sde_base.hpp
1#pragma once
2
#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>

#include <core/concepts.hpp>
8
9namespace diffeq::sde {
10
/// Structure of the stochastic forcing applied by SDEProblem::apply_noise
/// when no custom noise function has been installed.
enum class NoiseType {
    SCALAR_NOISE = 0,   ///< One increment dW[0] shared by every component.
    DIAGONAL_NOISE = 1, ///< Component i is scaled by its own increment dW[i].
    GENERAL_NOISE = 2   ///< Correlated noise; needs a custom noise function (default falls back to diagonal).
};
19
31template<system_state StateType>
33public:
34 using state_type = StateType;
35 using time_type = typename StateType::value_type;
36 using value_type = typename StateType::value_type;
37
38 // Function signatures
39 using drift_function = std::function<void(time_type, const state_type&, state_type&)>;
40 using diffusion_function = std::function<void(time_type, const state_type&, state_type&)>;
41 using noise_function = std::function<void(time_type, const state_type&, StateType&, const StateType&)>;
42
43 SDEProblem(drift_function drift, diffusion_function diffusion,
44 NoiseType noise_type = NoiseType::DIAGONAL_NOISE)
45 : drift_(std::move(drift))
46 , diffusion_(std::move(diffusion))
47 , noise_type_(noise_type) {}
48
49 void drift(time_type t, const state_type& x, state_type& fx) const {
50 drift_(t, x, fx);
51 }
52
53 void diffusion(time_type t, const state_type& x, state_type& gx) const {
54 diffusion_(t, x, gx);
55 }
56
57 NoiseType get_noise_type() const { return noise_type_; }
58
59 void set_noise_function(noise_function noise) {
60 noise_ = std::move(noise);
61 }
62
63 bool has_custom_noise() const { return noise_ != nullptr; }
64
65 void apply_noise(time_type t, const state_type& x, state_type& noise_term, const state_type& dW) const {
66 if (noise_) {
67 noise_(t, x, noise_term, dW);
68 } else {
69 // Default noise application based on noise type
70 apply_default_noise(noise_term, dW);
71 }
72 }
73
74private:
75 drift_function drift_;
76 diffusion_function diffusion_;
77 noise_function noise_;
78 NoiseType noise_type_;
79
80 void apply_default_noise(state_type& noise_term, const state_type& dW) const {
81 switch (noise_type_) {
82 case NoiseType::SCALAR_NOISE:
83 // All components use the same noise
84 for (size_t i = 0; i < noise_term.size(); ++i) {
85 noise_term[i] *= dW[0];
86 }
87 break;
88
89 case NoiseType::DIAGONAL_NOISE:
90 // Each component has independent noise
91 for (size_t i = 0; i < noise_term.size() && i < dW.size(); ++i) {
92 noise_term[i] *= dW[i];
93 }
94 break;
95
96 case NoiseType::GENERAL_NOISE:
97 // Custom noise - should be handled by noise function
98 // Default to diagonal for safety
99 for (size_t i = 0; i < noise_term.size() && i < dW.size(); ++i) {
100 noise_term[i] *= dW[i];
101 }
102 break;
103 }
104 }
105};
106
110template<system_state StateType>
112public:
113 using state_type = StateType;
114 using time_type = typename StateType::value_type;
115 using value_type = typename StateType::value_type;
116
117 explicit WienerProcess(size_t dimension, uint32_t seed = 0)
118 : dimension_(dimension)
119 , generator_(seed == 0 ? std::chrono::steady_clock::now().time_since_epoch().count() : seed)
120 , normal_dist_(0.0, 1.0) {}
121
122 void generate_increment(state_type& dW, time_type dt) {
123 value_type sqrt_dt = std::sqrt(static_cast<value_type>(dt));
124
125 for (size_t i = 0; i < dimension_ && i < dW.size(); ++i) {
126 auto dW_it = dW.begin();
127 dW_it[i] = static_cast<value_type>(normal_dist_(generator_)) * sqrt_dt;
128 }
129 }
130
131 void set_seed(uint32_t seed) {
132 generator_.seed(seed);
133 }
134
135 size_t dimension() const { return dimension_; }
136
137private:
138 size_t dimension_;
139 std::mt19937 generator_;
140 std::normal_distribution<double> normal_dist_;
141};
142
146template<system_state StateType>
148public:
149 using state_type = StateType;
150 using time_type = typename StateType::value_type;
151 using value_type = typename StateType::value_type;
154
155 explicit AbstractSDEIntegrator(std::shared_ptr<sde_problem_type> problem,
156 std::shared_ptr<wiener_process_type> wiener = nullptr)
157 : problem_(problem)
158 , wiener_(wiener ? wiener : std::make_shared<wiener_process_type>(get_default_dimension(), 0))
159 , current_time_(0) {}
160
161 virtual ~AbstractSDEIntegrator() = default;
162
163 // Pure virtual methods to be implemented by derived classes
164 virtual void step(state_type& state, time_type dt) = 0;
165 virtual std::string name() const = 0;
166
167 // Integration interface
168 void integrate(state_type& state, time_type dt, time_type end_time) {
169 while (current_time_ < end_time) {
170 time_type step_size = std::min<time_type>(dt, end_time - current_time_);
171 step(state, step_size);
172 }
173 }
174
175 // Accessors
176 time_type current_time() const { return current_time_; }
177 void set_time(time_type t) { current_time_ = t; }
178
179 std::shared_ptr<sde_problem_type> get_problem() const { return problem_; }
180 std::shared_ptr<wiener_process_type> get_wiener_process() const { return wiener_; }
181
182 void set_wiener_process(std::shared_ptr<wiener_process_type> wiener) {
183 wiener_ = wiener;
184 }
185
186protected:
187 void advance_time(time_type dt) { current_time_ += dt; }
188
189 virtual size_t get_default_dimension() {
190 // Default to assuming state dimension equals noise dimension
191 return 10; // Will be overridden by actual state size in practice
192 }
193
194 std::shared_ptr<sde_problem_type> problem_;
195 std::shared_ptr<wiener_process_type> wiener_;
196 time_type current_time_;
197};
198
202namespace factory {
203
204template<system_state StateType>
205auto make_sde_problem(
206 typename SDEProblem<StateType>::drift_function drift,
207 typename SDEProblem<StateType>::diffusion_function diffusion,
208 NoiseType noise_type = NoiseType::DIAGONAL_NOISE) {
209 return std::make_shared<SDEProblem<StateType>>(std::move(drift), std::move(diffusion), noise_type);
210}
211
212template<system_state StateType>
213auto make_wiener_process(size_t dimension, uint32_t seed = 0) {
214 return std::make_shared<WienerProcess<StateType>>(dimension, seed);
215}
216
217} // namespace factory
218
219} // namespace diffeq::sde
Abstract base class for SDE integrators.
Definition sde_base.hpp:147
SDE problem definition.
Definition sde_base.hpp:32
Wiener process (Brownian motion) generator.
Definition sde_base.hpp:111