Program Listing for File sri1.hpp

Return to documentation for file (include/integrators/sde/sri1.hpp)

#pragma once

#include <cmath>
#include <cstddef>
#include <memory>
#include <string>

#include <sde/sde_base.hpp>
#include <core/state_creator.hpp>

namespace diffeq {

/**
 * @brief Two-stage stochastic Runge-Kutta integrator for SDEs of the form
 *        dX = f(t, X) dt + g(t, X) dW.
 *
 * One step of size dt computes
 *
 *   Xs      = X_n + f(t, X_n) * dt + g(t, X_n) * sqrt(dt)            (support state)
 *   X_{n+1} = X_n + (f(t, X_n) + f(t + dt, Xs)) / 2 * dt
 *                 + (N(t, X_n, g(t, X_n)) + N(t + dt, Xs, g(t + dt, Xs))) / 2
 *
 * where N(...) is the problem's `apply_noise` hook, invoked with the single
 * Wiener increment dW drawn for the step (presumably it forms g * dW for
 * diagonal noise -- confirm against the problem type).
 *
 * NOTE(review): the support state is advanced with the deterministic
 * g * sqrt(dt) rather than g * dW, and only dW (no iterated stochastic
 * integrals) enters the update, so this is a Heun-type scheme rather than
 * Roessler's full SRI1 tableau. The numerical behaviour is kept as-is;
 * confirm the intended order of convergence.
 *
 * @tparam StateType State container; must provide size() and a begin()
 *         iterator supporting it[i], and be creatable via StateCreator.
 */
template<system_state StateType>
class SRI1Integrator : public sde::AbstractSDEIntegrator<StateType> {
public:
    using base_type = sde::AbstractSDEIntegrator<StateType>;
    using state_type = typename base_type::state_type;
    using time_type = typename base_type::time_type;
    using value_type = typename base_type::value_type;

    /**
     * @param problem SDE problem supplying drift, diffusion and apply_noise.
     * @param wiener  Wiener process used to draw increments; forwarded to the
     *                base class unchanged (nullptr presumably selects a
     *                default process there -- confirm in sde_base.hpp).
     */
    explicit SRI1Integrator(std::shared_ptr<typename base_type::sde_problem_type> problem,
                          std::shared_ptr<typename base_type::wiener_process_type> wiener = nullptr)
        : base_type(problem, wiener) {}

    /**
     * @brief Advance `state` in place by one step of size `dt`, then advance
     *        the integrator's current time by `dt`.
     */
    void step(state_type& state, time_type dt) override {
        // Scratch buffers shaped like `state`.
        state_type k1 = StateCreator<state_type>::create(state);       // drift at (t, X_n)
        state_type k2 = StateCreator<state_type>::create(state);       // drift at (t + dt, Xs)
        state_type g1 = StateCreator<state_type>::create(state);       // diffusion at (t, X_n)
        state_type g2 = StateCreator<state_type>::create(state);       // diffusion at (t + dt, Xs)
        state_type support = StateCreator<state_type>::create(state);  // support state Xs
        state_type dW = StateCreator<state_type>::create(state);       // Wiener increment

        this->wiener_->generate_increment(dW, dt);

        const time_type t = this->current_time_;
        const std::size_t n = state.size();

        // Convert the step size once so all element arithmetic stays in
        // value_type: avoids float->double->float round trips and keeps
        // complex<T> states well-formed (complex<float> * double does not compile).
        const value_type h = static_cast<value_type>(dt);
        const value_type sqrt_h = std::sqrt(h);
        const value_type half = static_cast<value_type>(0.5);

        // Stage 1: k1 = f(t, X_n), g1 = g(t, X_n).
        this->problem_->drift(t, state, k1);
        this->problem_->diffusion(t, state, g1);

        // Support state: Xs = X_n + k1 * dt + g1 * sqrt(dt).
        {
            auto x = state.begin();
            auto f = k1.begin();
            auto g = g1.begin();
            auto out = support.begin();
            for (std::size_t i = 0; i < n; ++i) {
                out[i] = x[i] + f[i] * h + g[i] * sqrt_h;
            }
        }

        // Stage 2: k2 = f(t + dt, Xs), g2 = g(t + dt, Xs).
        this->problem_->drift(t + dt, support, k2);
        this->problem_->diffusion(t + dt, support, g2);

        // Combine each diffusion evaluation with dW in place. g1/g2 are not
        // needed afterwards, so no separate copies are made.
        this->problem_->apply_noise(t, state, g1, dW);
        this->problem_->apply_noise(t + dt, support, g2, dW);

        // Final update: X_{n+1} = X_n + (k1 + k2)/2 * dt + (g1 + g2)/2,
        // where g1/g2 now hold the noise-applied diffusion terms.
        {
            auto x = state.begin();
            auto fa = k1.begin();
            auto fb = k2.begin();
            auto ga = g1.begin();
            auto gb = g2.begin();
            for (std::size_t i = 0; i < n; ++i) {
                x[i] += (fa[i] + fb[i]) * h * half + (ga[i] + gb[i]) * half;
            }
        }

        this->advance_time(dt);
    }

    /// Human-readable integrator name.
    std::string name() const override {
        return "SRI1 (Stochastic Runge-Kutta)";
    }
};

} // namespace diffeq