DiffEq - Modern C++ ODE Integration Library 1.0.0
High-performance C++ library for solving ODEs with async signal processing
Loading...
Searching...
No Matches
coroutine_integration_demo.cpp
1#define NOMINMAX
2#include <diffeq.hpp>
3#include <coroutine>
4#include <iostream>
5#include <vector>
6#include <chrono>
7#include <thread>
8#include <queue>
9#include <mutex>
10#include <optional>
11#include <memory>
12
13#ifdef _WIN32
14#include <windows.h>
15#endif
16
30// ============================================================================
31// 协程基础设施
32// ============================================================================
33
37template<typename State>
39 struct promise_type {
40 State current_state;
41 double current_time{0.0};
42 std::exception_ptr exception;
43
44 IntegrationTask get_return_object() {
45 return IntegrationTask{
46 std::coroutine_handle<promise_type>::from_promise(*this)
47 };
48 }
49
50 std::suspend_always initial_suspend() { return {}; }
51 std::suspend_always final_suspend() noexcept { return {}; }
52
53 void unhandled_exception() {
54 exception = std::current_exception();
55 }
56
57 void return_void() {}
58
59 // 允许协程 co_yield 状态和时间
60 std::suspend_always yield_value(std::pair<const State&, double> value) {
61 current_state = value.first;
62 current_time = value.second;
63 return {};
64 }
65 };
66
67 using handle_type = std::coroutine_handle<promise_type>;
68 handle_type coro;
69
70 explicit IntegrationTask(handle_type h) : coro(h) {}
71
73 if (coro) {
74 coro.destroy();
75 }
76 }
77
78 // 移动构造和赋值
79 IntegrationTask(IntegrationTask&& other) noexcept
80 : coro(std::exchange(other.coro, {})) {}
81
82 IntegrationTask& operator=(IntegrationTask&& other) noexcept {
83 if (this != &other) {
84 if (coro) coro.destroy();
85 coro = std::exchange(other.coro, {});
86 }
87 return *this;
88 }
89
90 // 禁用拷贝
91 IntegrationTask(const IntegrationTask&) = delete;
92 IntegrationTask& operator=(const IntegrationTask&) = delete;
93
94 // 恢复协程执行
95 bool resume() {
96 if (!coro || coro.done()) return false;
97 coro.resume();
98 return !coro.done();
99 }
100
101 // 检查是否完成
102 bool done() const {
103 return !coro || coro.done();
104 }
105
106 // 获取当前状态
107 std::pair<State, double> get_current() const {
108 if (coro) {
109 return {coro.promise().current_state, coro.promise().current_time};
110 }
111 throw std::runtime_error("No current state available");
112 }
113
114 // 检查异常
115 void check_exception() {
116 if (coro && coro.promise().exception) {
117 std::rethrow_exception(coro.promise().exception);
118 }
119 }
120
121 // 使 IntegrationTask 可等待(awaitable)
122 bool await_ready() const noexcept {
123 return done();
124 }
125
126 void await_suspend(std::coroutine_handle<> h) {
127 // 在另一个线程中运行任务直到完成,然后恢复等待的协程
128 std::thread([this, h]() {
129 while (!done()) {
130 resume();
131 std::this_thread::sleep_for(std::chrono::milliseconds{1});
132 }
133 h.resume();
134 }).detach();
135 }
136
137 std::pair<State, double> await_resume() {
138 check_exception();
139 if (coro) {
140 return {coro.promise().current_state, coro.promise().current_time};
141 }
142 throw std::runtime_error("Coroutine not available");
143 }
144};
145
150 std::chrono::milliseconds delay;
151
152 bool await_ready() const noexcept { return delay.count() <= 0; }
153
154 void await_suspend(std::coroutine_handle<> h) const {
155 std::thread([h, this]() {
156 std::this_thread::sleep_for(delay);
157 h.resume();
158 }).detach();
159 }
160
161 void await_resume() const noexcept {}
162};
163
164// ============================================================================
165// 协程化的积分器包装
166// ============================================================================
167
171template<typename State>
173private:
174 std::unique_ptr<diffeq::core::AbstractIntegrator<State>> integrator_;
175
176public:
177 explicit CoroutineIntegrator(
178 std::unique_ptr<diffeq::core::AbstractIntegrator<State>> integrator)
179 : integrator_(std::move(integrator)) {}
180
189 State initial_state,
190 typename diffeq::core::AbstractIntegrator<State>::time_type dt,
191 typename diffeq::core::AbstractIntegrator<State>::time_type end_time,
192 size_t yield_interval = 10) {
193
194 State state = std::move(initial_state);
195 double current_time = 0.0;
196 integrator_->set_time(current_time);
197 size_t step_count = 0;
198
199 while (current_time < end_time) {
200 // 执行一步积分
201 auto step_dt = std::min(dt, end_time - current_time);
202 integrator_->step(state, step_dt);
203 current_time += step_dt; // 手动更新时间
204 integrator_->set_time(current_time); // 同步积分器时间
205 step_count++;
206
207 // 每隔一定步数,暂停并返回当前状态
208 if (step_count % yield_interval == 0) {
209 co_yield std::make_pair(std::cref(state), current_time);
210 }
211 }
212
213 // 返回最终状态
214 co_yield std::make_pair(std::cref(state), current_time);
215 }
216
220 template<typename ProgressCallback>
222 State initial_state,
223 typename diffeq::core::AbstractIntegrator<State>::time_type dt,
224 typename diffeq::core::AbstractIntegrator<State>::time_type end_time,
225 ProgressCallback&& callback) {
226
227 State state = std::move(initial_state);
228 double current_time = 0.0;
229 integrator_->set_time(current_time);
230
231 while (current_time < end_time) {
232 // 执行积分步
233 auto step_dt = std::min(dt, end_time - current_time);
234 integrator_->step(state, step_dt);
235 current_time += step_dt; // 手动更新时间
236 integrator_->set_time(current_time); // 同步积分器时间
237
238 // 调用进度回调
239 double progress = current_time / end_time;
240 bool should_continue = callback(state, current_time, progress);
241
242 if (!should_continue) {
243 break; // 用户请求停止
244 }
245
246 // 让出控制权
247 co_yield std::make_pair(std::cref(state), current_time);
248 }
249 }
250};
251
252// ============================================================================
253// 协程任务调度器
254// ============================================================================
255
260private:
261 struct Task {
262 std::function<bool()> resume_func;
263 std::string name;
264 std::chrono::steady_clock::time_point last_run;
265 std::chrono::milliseconds interval;
266 };
267
268 std::vector<Task> tasks_;
269 std::mutex mutex_;
270
271public:
275 template<typename State>
277 const std::string& name,
278 std::chrono::milliseconds interval = std::chrono::milliseconds{0}) {
279 std::lock_guard<std::mutex> lock(mutex_);
280
281 // 捕获任务的共享指针,确保生命周期
282 auto task_ptr = std::make_shared<IntegrationTask<State>>(std::move(task));
283
284 tasks_.push_back({
285 [task_ptr]() { return task_ptr->resume(); },
286 name,
287 std::chrono::steady_clock::now(),
288 interval
289 });
290 }
291
296 void run(std::chrono::milliseconds duration) {
297 auto end_time = std::chrono::steady_clock::now() + duration;
298
299 while (std::chrono::steady_clock::now() < end_time) {
300 std::vector<Task> active_tasks;
301
302 {
303 std::lock_guard<std::mutex> lock(mutex_);
304 // 移除已完成的任务,保留活跃任务
305 for (auto& task : tasks_) {
306 auto now = std::chrono::steady_clock::now();
307 if (now - task.last_run >= task.interval) {
308 if (task.resume_func()) {
309 task.last_run = now;
310 active_tasks.push_back(task);
311 } else {
312 std::cout << "任务 '" << task.name << "' 已完成" << std::endl;
313 }
314 } else {
315 active_tasks.push_back(task);
316 }
317 }
318 tasks_ = std::move(active_tasks);
319 }
320
321 // 短暂休眠,避免忙等待
322 std::this_thread::sleep_for(std::chrono::milliseconds{1});
323 }
324 }
325
330 std::lock_guard<std::mutex> lock(mutex_);
331 return tasks_.size();
332 }
333};
334
335// ============================================================================
336// 示例:多尺度积分
337// ============================================================================
338
344IntegrationTask<std::vector<double>> multiscale_integration_coro(
345 double epsilon = 0.01) {
346
347 // 快-慢耦合系统
348 auto system = [epsilon](double t, const std::vector<double>& x,
349 std::vector<double>& dx) {
350 // 慢变量
351 dx[0] = -x[0] + x[1];
352 // 快变量
353 dx[1] = -(1.0/epsilon) * (x[1] - x[0]*x[0]);
354 };
355
356 // 创建积分器
357 auto integrator = std::make_unique<diffeq::RK45Integrator<std::vector<double>>>(system);
358 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
359
360 // 初始条件
361 std::vector<double> state = {1.0, 0.0};
362
363 std::cout << "\n=== 多尺度系统积分 (ε = " << epsilon << ") ===" << std::endl;
364
365 // 使用自适应步长,协程每10步返回一次
366 auto task = coro_integrator.integrate_coro(state, 0.001, 5.0, 10);
367
368 size_t yield_count = 0;
369 while (!task.done()) {
370 task.resume();
371
372 if (!task.done()) {
373 auto [current_state, current_time] = task.get_current();
374 yield_count++;
375
376 // 每50次yield打印一次状态
377 if (yield_count % 50 == 0) {
378 std::cout << "t = " << current_time
379 << ", 慢变量 = " << current_state[0]
380 << ", 快变量 = " << current_state[1] << std::endl;
381 }
382
383 // 模拟其他计算
384 co_await TimedSuspend{std::chrono::milliseconds{1}};
385 }
386 }
387
388 std::cout << "多尺度积分完成,共 yield " << yield_count << " 次" << std::endl;
389}
390
391// ============================================================================
392// 示例:参数扫描与动态调度
393// ============================================================================
394
398template<typename State>
399IntegrationTask<State> parameter_scan_coro(
400 double param,
401 std::function<void(double, const State&, double)> result_handler) {
402
403 // 参数化的 Van der Pol 振荡器
404 auto system = [param](double t, const std::vector<double>& x,
405 std::vector<double>& dx) {
406 dx[0] = x[1];
407 dx[1] = param * (1 - x[0]*x[0]) * x[1] - x[0];
408 };
409
410 auto integrator = std::make_unique<diffeq::RK4Integrator<std::vector<double>>>(system);
411 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
412
413 State state = {2.0, 0.0};
414
415 // 带进度监控的积分
416 auto task = coro_integrator.integrate_with_progress(
417 state, 0.01, 20.0,
418 [param](const auto& s, double t, double progress) {
419 // 仅在关键时刻输出
420 if (static_cast<int>(progress * 100) % 25 == 0 &&
421 static_cast<int>(progress * 100) % 25 < 1) {
422 std::cout << "参数 " << param << " 的积分进度: "
423 << static_cast<int>(progress * 100) << "%" << std::endl;
424 }
425 return true; // 继续积分
426 }
427 );
428
429 // 执行积分
430 while (!task.done()) {
431 task.resume();
432
433 if (!task.done()) {
434 auto [current_state, current_time] = task.get_current();
435
436 // 让其他协程有机会运行
437 co_await std::suspend_always{};
438 }
439 }
440
441 // 获取最终结果
442 auto [final_state, final_time] = task.get_current();
443 result_handler(param, final_state, final_time);
444}
445
446// ============================================================================
447// 主程序
448// ============================================================================
449
450int main() {
451#ifdef _WIN32
452 SetConsoleOutputCP(CP_UTF8);
453#endif
454
455 std::cout << "=== C++20 协程与 diffeq 库集成示例 ===" << std::endl;
456 std::cout << "展示协程在细粒度 CPU 运行控制上的优势\n" << std::endl;
457
458 // 1. 基本协程积分示例
459 std::cout << "1. 基本协程积分示例" << std::endl;
460 {
461 // Lorenz 系统
462 auto lorenz = [](double t, const std::vector<double>& x,
463 std::vector<double>& dx) {
464 const double sigma = 10.0, rho = 28.0, beta = 8.0/3.0;
465 dx[0] = sigma * (x[1] - x[0]);
466 dx[1] = x[0] * (rho - x[2]) - x[1];
467 dx[2] = x[0] * x[1] - beta * x[2];
468 };
469
470 auto integrator = std::make_unique<diffeq::RK45Integrator<std::vector<double>>>(lorenz);
471 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
472
473 std::vector<double> initial_state = {1.0, 1.0, 1.0};
474
475 // 创建协程任务
476 auto task = coro_integrator.integrate_coro(initial_state, 0.01, 2.0, 20);
477
478 std::cout << "开始 Lorenz 系统积分..." << std::endl;
479 size_t step_count = 0;
480
481 while (!task.done()) {
482 task.resume();
483 step_count++;
484
485 if (!task.done() && step_count % 5 == 0) {
486 try {
487 auto [state, time] = task.get_current();
488 std::cout << " t = " << time
489 << ", ||x|| = " << std::sqrt(state[0]*state[0] +
490 state[1]*state[1] +
491 state[2]*state[2])
492 << std::endl;
493 } catch (const std::exception& e) {
494 std::cout << " 获取 Lorenz 状态失败: " << e.what() << std::endl;
495 }
496 }
497 }
498
499 std::cout << "Lorenz 积分完成,共暂停/恢复 " << step_count << " 次\n" << std::endl;
500 }
501
502 // 2. 参数扫描示例(简化版)
503 std::cout << "2. 参数扫描示例" << std::endl;
504 {
505 std::vector<double> parameters = {0.5, 1.0, 2.0};
506 std::vector<std::pair<double, std::vector<double>>> results;
507
508 for (double param : parameters) {
509 std::cout << "开始参数 μ = " << param << " 的积分..." << std::endl;
510
511 auto system = [param](double t, const std::vector<double>& x,
512 std::vector<double>& dx) {
513 dx[0] = x[1];
514 dx[1] = param * (1 - x[0]*x[0]) * x[1] - x[0];
515 };
516
517 auto integrator = std::make_unique<diffeq::RK4Integrator<std::vector<double>>>(system);
518 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
519
520 std::vector<double> state = {2.0, 0.0};
521 auto task = coro_integrator.integrate_coro(state, 0.05, 10.0, 50);
522
523 // 逐步执行并监控进度
524 size_t step_count = 0;
525 while (!task.done()) {
526 task.resume();
527 step_count++;
528
529 if (!task.done() && step_count % 20 == 0) {
530 try {
531 auto [current_state, current_time] = task.get_current();
532 std::cout << " μ = " << param
533 << ", t = " << current_time
534 << ", x = [" << current_state[0] << ", " << current_state[1] << "]"
535 << std::endl;
536 } catch (const std::exception& e) {
537 std::cout << " 获取状态失败: " << e.what() << std::endl;
538 }
539 }
540 }
541
542 try {
543 auto [final_state, final_time] = task.get_current();
544 results.emplace_back(param, final_state);
545 std::cout << "参数 μ = " << param << " 积分完成" << std::endl;
546 } catch (const std::exception& e) {
547 std::cout << "参数 μ = " << param << " 积分失败: " << e.what() << std::endl;
548 }
549 }
550
551 std::cout << "\n参数扫描结果:" << std::endl;
552 for (const auto& [param, state] : results) {
553 std::cout << " μ = " << param
554 << ", 最终状态 = [" << state[0] << ", " << state[1] << "]"
555 << std::endl;
556 }
557 }
558
559 // 3. 带进度监控的协程积分
560 std::cout << "\n3. 带进度监控的协程积分" << std::endl;
561 {
562 // 阻尼振荡器
563 auto damped_oscillator = [](double t, const std::vector<double>& x,
564 std::vector<double>& dx) {
565 double omega = 2.0, gamma = 0.1;
566 dx[0] = x[1];
567 dx[1] = -omega*omega*x[0] - 2*gamma*x[1];
568 };
569
570 auto integrator = std::make_unique<diffeq::RK45Integrator<std::vector<double>>>(damped_oscillator);
571 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
572
573 std::vector<double> state = {1.0, 0.0};
574
575 auto task = coro_integrator.integrate_with_progress(
576 state, 0.01, 10.0,
577 [](const auto& s, double t, double progress) {
578 // 每25%进度报告一次
579 int progress_percent = static_cast<int>(progress * 100);
580 if (progress_percent % 25 == 0 && progress_percent > 0) {
581 std::cout << " 进度: " << progress_percent
582 << "%, t = " << t
583 << ", 能量 = " << 0.5 * (s[0]*s[0] + s[1]*s[1])
584 << std::endl;
585 }
586 return true; // 继续积分
587 }
588 );
589
590 std::cout << "开始阻尼振荡器积分..." << std::endl;
591 while (!task.done()) {
592 task.resume();
593 }
594
595 auto [final_state, final_time] = task.get_current();
596 std::cout << "阻尼振荡器积分完成,最终能量 = "
597 << 0.5 * (final_state[0]*final_state[0] + final_state[1]*final_state[1])
598 << std::endl;
599 }
600
601 // 4. 协程的细粒度控制示例
602 std::cout << "\n4. 协程的细粒度控制示例" << std::endl;
603 {
604 // 二体问题(简化的轨道力学)
605 auto orbital_system = [](double t, const std::vector<double>& x,
606 std::vector<double>& dx) {
607 double mu = 1.0; // 引力参数
608 double r = std::sqrt(x[0]*x[0] + x[1]*x[1]);
609 double r3 = r*r*r;
610
611 dx[0] = x[2]; // vx
612 dx[1] = x[3]; // vy
613 dx[2] = -mu * x[0] / r3; // ax
614 dx[3] = -mu * x[1] / r3; // ay
615 };
616
617 auto integrator = std::make_unique<diffeq::RK45Integrator<std::vector<double>>>(orbital_system);
618 CoroutineIntegrator<std::vector<double>> coro_integrator(std::move(integrator));
619
620 // 椭圆轨道初始条件
621 std::vector<double> state = {1.0, 0.0, 0.0, 0.8};
622
623 auto task = coro_integrator.integrate_coro(state, 0.01, 6.28, 25); // 一个轨道周期
624
625 std::cout << "开始轨道积分(每25步暂停一次)..." << std::endl;
626 size_t resume_count = 0;
627
628 while (!task.done()) {
629 task.resume();
630 resume_count++;
631
632 if (!task.done()) {
633 auto [current_state, current_time] = task.get_current();
634 double r = std::sqrt(current_state[0]*current_state[0] + current_state[1]*current_state[1]);
635 double v = std::sqrt(current_state[2]*current_state[2] + current_state[3]*current_state[3]);
636 double energy = 0.5 * v*v - 1.0/r; // 比能量
637
638 std::cout << " 第 " << resume_count << " 次恢复: t = " << current_time
639 << ", r = " << r
640 << ", 能量 = " << energy << std::endl;
641
642 // 演示协程的暂停特性
643 std::this_thread::sleep_for(std::chrono::milliseconds{10});
644 }
645 }
646
647 auto [final_state, final_time] = task.get_current();
648 std::cout << "轨道积分完成,共恢复 " << resume_count << " 次" << std::endl;
649 std::cout << "最终位置: [" << final_state[0] << ", " << final_state[1] << "]" << std::endl;
650 }
651
652 std::cout << "\n=== 协程集成演示完成 ===" << std::endl;
653 std::cout << "关键优势:" << std::endl;
654 std::cout << "- 细粒度的执行控制" << std::endl;
655 std::cout << "- 协作式多任务处理" << std::endl;
656 std::cout << "- 零开销的状态保存和恢复" << std::endl;
657 std::cout << "- 与标准库的无缝集成" << std::endl;
658
659 return 0;
660}
将积分器包装为协程,支持细粒度控制
IntegrationTask< State > integrate_coro(State initial_state, typename diffeq::core::AbstractIntegrator< State >::time_type dt, typename diffeq::core::AbstractIntegrator< State >::time_type end_time, size_t yield_interval=10)
协程化的积分:每 yield_interval 步暂停一次并返回当前状态与时间,积分结束时再返回一次最终状态
IntegrationTask< State > integrate_with_progress(State initial_state, typename diffeq::core::AbstractIntegrator< State >::time_type dt, typename diffeq::core::AbstractIntegrator< State >::time_type end_time, ProgressCallback &&callback)
带进度回调的协程积分
简单的协程任务调度器
void run(std::chrono::milliseconds duration)
运行调度器
void add_task(IntegrationTask< State > &&task, const std::string &name, std::chrono::milliseconds interval=std::chrono::milliseconds{0})
添加一个协程任务
size_t active_task_count()
获取活跃任务数
Modern C++ ODE Integration Library with Real-time Signal Processing.
C++20 协程与 diffeq 库集成示例
可等待的延迟对象,用于协程中的定时暂停