DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
milstein.hpp
1#pragma once
2
3#include <sde/sde_base.hpp>
4#include <core/state_creator.hpp>
5#include <cmath>
6
7namespace diffeq {
8
21template<system_state StateType>
23public:
25 using state_type = typename base_type::state_type;
26 using time_type = typename base_type::time_type;
27 using value_type = typename base_type::value_type;
28
29 // Function signature for diffusion derivative
30 using diffusion_derivative_function = std::function<void(time_type, const state_type&, state_type&)>;
31
32 explicit MilsteinIntegrator(std::shared_ptr<typename base_type::sde_problem_type> problem,
33 diffusion_derivative_function diffusion_derivative,
34 std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr)
35 : base_type(problem, wiener)
36 , diffusion_derivative_(std::move(diffusion_derivative)) {}
37
38 void step(state_type& state, time_type dt) override {
39 // Create temporary states
40 state_type drift_term = StateCreator<state_type>::create(state);
41 state_type diffusion_term = StateCreator<state_type>::create(state);
42 state_type diffusion_deriv_term = StateCreator<state_type>::create(state);
43 state_type dW = StateCreator<state_type>::create(state);
44
45 // Generate Wiener increments
46 this->wiener_->generate_increment(dW, dt);
47
48 // Compute drift: f(t, X)
49 this->problem_->drift(this->current_time_, state, drift_term);
50
51 // Compute diffusion: g(t, X)
52 this->problem_->diffusion(this->current_time_, state, diffusion_term);
53
54 // Compute diffusion derivative: g'(t, X)
55 diffusion_derivative_(this->current_time_, state, diffusion_deriv_term);
56
57 // Apply noise to diffusion term
58 this->problem_->apply_noise(this->current_time_, state, diffusion_term, dW);
59
60 // Update state: X_{n+1} = X_n + f*dt + g*dW + 0.5*g*g'*(dW^2 - dt)
61 for (size_t i = 0; i < state.size(); ++i) {
62 auto state_it = state.begin();
63 auto drift_it = drift_term.begin();
64 auto diffusion_it = diffusion_term.begin();
65 auto diffusion_deriv_it = diffusion_deriv_term.begin();
66 auto dW_it = dW.begin();
67
68 value_type dW_squared = dW_it[i] * dW_it[i];
69 value_type correction = static_cast<value_type>(0.5) * diffusion_it[i] * diffusion_deriv_it[i] * (dW_squared - dt);
70
71 state_it[i] += drift_it[i] * dt + diffusion_it[i] + correction;
72 }
73
74 this->advance_time(dt);
75 }
76
77 std::string name() const override {
78 return "Milstein";
79 }
80
81private:
82 diffusion_derivative_function diffusion_derivative_;
83};
84
85} // namespace diffeq
Milstein method for SDEs.
Definition milstein.hpp:22
Abstract base class for SDE integrators.
Definition sde_base.hpp:147