DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
implicit_euler_maruyama.hpp
1#pragma once
2
3#include <sde/sde_base.hpp>
4#include <core/state_creator.hpp>
5#include <algorithm>
6#include <cmath>
7
8namespace diffeq {
9
18template<system_state StateType>
20public:
22 using state_type = typename base_type::state_type;
23 using time_type = typename base_type::time_type;
24 using value_type = typename base_type::value_type;
25
27 std::shared_ptr<typename base_type::sde_problem_type> problem,
28 std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr,
29 int max_iterations = 10,
30 value_type tolerance = 1e-8)
31 : base_type(problem, wiener)
32 , max_iterations_(max_iterations)
33 , tolerance_(tolerance) {}
34
35 void step(state_type& state, time_type dt) override {
36 // Create temporary states
37 state_type diffusion_term = StateCreator<state_type>::create(state);
38 state_type dW = StateCreator<state_type>::create(state);
39 state_type x_new = StateCreator<state_type>::create(state);
40 state_type x_old = StateCreator<state_type>::create(state);
41 state_type drift_term = StateCreator<state_type>::create(state);
42
43 // Generate Wiener increments
44 this->wiener_->generate_increment(dW, dt);
45
46 // Compute explicit diffusion term: g(t_n, X_n) * dW_n
47 this->problem_->diffusion(this->current_time_, state, diffusion_term);
48 this->problem_->apply_noise(this->current_time_, state, diffusion_term, dW);
49
50 // Initial guess: x_new = x_old (explicit Euler)
51 for (size_t i = 0; i < state.size(); ++i) {
52 auto state_it = state.begin();
53 auto x_new_it = x_new.begin();
54 auto diffusion_it = diffusion_term.begin();
55
56 x_new_it[i] = state_it[i] + diffusion_it[i];
57 }
58
59 // Fixed-point iteration to solve: x_new = x_old + f(t+dt, x_new)*dt + diffusion_term
60 for (int iter = 0; iter < max_iterations_; ++iter) {
61 // Save old iterate
62 for (size_t i = 0; i < state.size(); ++i) {
63 auto x_new_it = x_new.begin();
64 auto x_old_it = x_old.begin();
65 x_old_it[i] = x_new_it[i];
66 }
67
68 // Compute drift at new time and new state
69 this->problem_->drift(this->current_time_ + dt, x_old, drift_term);
70
71 // Update: x_new = x_n + f(t+dt, x_old)*dt + diffusion_term
72 value_type max_change = 0;
73 for (size_t i = 0; i < state.size(); ++i) {
74 auto state_it = state.begin();
75 auto x_new_it = x_new.begin();
76 auto x_old_it = x_old.begin();
77 auto drift_it = drift_term.begin();
78 auto diffusion_it = diffusion_term.begin();
79
80 value_type new_val = state_it[i] + drift_it[i] * dt + diffusion_it[i];
81 value_type change = std::abs(new_val - x_old_it[i]);
82 max_change = std::max<value_type>(max_change, change);
83 x_new_it[i] = new_val;
84 }
85
86 // Check convergence
87 if (max_change < tolerance_) {
88 break;
89 }
90 }
91
92 // Update state
93 for (size_t i = 0; i < state.size(); ++i) {
94 auto state_it = state.begin();
95 auto x_new_it = x_new.begin();
96 state_it[i] = x_new_it[i];
97 }
98
99 this->advance_time(dt);
100 }
101
102 std::string name() const override {
103 return "Implicit Euler-Maruyama";
104 }
105
106 void set_iteration_parameters(int max_iterations, value_type tolerance) {
107 max_iterations_ = max_iterations;
108 tolerance_ = tolerance;
109 }
110
111private:
112 int max_iterations_;
113 value_type tolerance_;
114};
115
116} // namespace diffeq
Abstract base class for SDE integrators.
Definition sde_base.hpp:147