MaMiCo 1.2
Loading...
Searching...
No Matches
ParallelTimeIntegrationService.h
// Copyright (C) 2023 Helmut Schmidt University
// This file is part of the Mamico project. For conditions of distribution
// and use, please see the copyright notice in Mamico's main folder

#pragma once

namespace coupling {
namespace services {
// Forward declaration, so that the headers included below may refer to this class.
template <unsigned int dim> class ParallelTimeIntegrationService;
} // namespace services
} // namespace coupling
12
#include "coupling/configurations/MaMiCoConfiguration.h"
#include "coupling/interface/PintableMacroSolver.h"
#include "coupling/scenario/Scenario.h"

#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>
17
18// Convenience operators, to be able to write parareal iterations in a readable notation
19using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
20inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs) { return *lhs + *rhs; }
21inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs) { return *lhs - *rhs; }
22
30template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
31public:
34
35 ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario)
36 : _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0), _cfg(mamicoConfig.getTimeIntegrationConfiguration()), _scenario(scenario) {
37#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
38 int world_size;
39 MPI_Comm_size(MPI_COMM_WORLD, &world_size);
40 MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
41 if (_cfg.isPintEnabled()) {
42 if (world_size % _cfg.getPintDomains() != 0) {
43 if (_world_rank == 0) {
44 std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " << "MPI ranks not divisible by number of PinT subdomains!" << std::endl;
45 std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl;
46 std::cout << "Check your configuration." << std::endl;
47 }
48 exit(EXIT_FAILURE);
49 }
50 _ranks_per_domain = world_size / _cfg.getPintDomains();
51 _pint_domain = _world_rank / _ranks_per_domain;
52 // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
53 MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
54 } else {
55 _local_pint_comm = MPI_COMM_WORLD;
56 }
57 MPI_Comm_rank(_local_pint_comm, &_rank);
58
59#ifdef PINT_DEBUG
60 std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
61#endif
62
63#endif
64 }
65
66 void run(int num_cycles) {
67 if (!_cfg.isPintEnabled())
68 run_cycles(0, num_cycles);
69 else {
70 PintDomain domain = setup_domain(num_cycles);
71 setup_solvers(domain);
72 init_parareal();
73 run_parareal(_cfg.getPintIterations());
74 }
75 }
76
77 int getPintDomain() const { return _pint_domain; }
78 int getRank() const { return _rank; }
79 bool isPintEnabled() const { return _cfg.isPintEnabled(); }
80 int getIteration() const { return _iteration; }
81
82#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
83 MPI_Comm getPintComm() const { return _local_pint_comm; }
84#endif
85
86private:
87 void run_cycles(int start, int end) {
88 for (int cycle = start; cycle < end; cycle++)
89 _scenario->runOneCouplingCycle(cycle);
90 }
91
92 bool isFirst() const { return _pint_domain == 0; }
93 bool isLast() const { return _pint_domain == _cfg.getPintDomains() - 1; }
94
95 struct PintDomain {
97 int size;
102 };
103
104 PintDomain setup_domain(int num_cycles) const {
105 PintDomain res;
106 res.size = (int)(num_cycles / _cfg.getPintDomains());
107 res.minCycle = _pint_domain * res.size;
108 res.maxCycle = res.minCycle + res.size;
109 // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
110 if (isLast())
111 res.maxCycle = num_cycles;
112#ifdef PINT_DEBUG
113 if (_rank == 0) {
114 std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle << " and maxCycle " << res.maxCycle << std::endl;
115 }
116#endif
117 return res;
118 }
119
120 void setup_solvers(PintDomain domain) {
121 auto solver = dynamic_cast<Solver*>(_scenario->getSolver());
122 if (solver == nullptr) {
123 std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
124 << "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
125 exit(EXIT_FAILURE);
126 }
127#ifdef PINT_DEBUG
128 if (_cfg.getViscMultiplier() != 1.0)
129 if (_world_rank == 0) {
130 std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
131 }
132#endif
133 _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier());
134 _F = [this, solver, domain](const std::unique_ptr<State>& s) {
135 solver->setState(s, domain.minCycle);
136 if (domain.minCycle > 0)
137 _scenario->equilibrateMicro();
138 run_cycles(domain.minCycle, domain.maxCycle);
139 return solver->getState();
140 };
141 _G = [this, domain](const std::unique_ptr<State>& s) {
142 auto& G = *_supervisor;
143 return G(s, domain.minCycle);
144 };
145 _u_0 = solver->getState();
146 }
147
148 void receive(std::unique_ptr<State>& state) const {
149 if (!state)
150 state = _u_0->clone();
151#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
152 int source_rank = _world_rank - _ranks_per_domain;
153 MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
154#endif
155 }
156
157 void send(std::unique_ptr<State>& state) const {
158#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
159 int destination_rank = _world_rank + _ranks_per_domain;
160 MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
161#endif
162 }
163
164 void get_past_state(std::unique_ptr<State>& state) const {
165 if (isFirst())
166 state = _u_0->clone();
167 else
168 receive(state);
169 }
170
171 void init_parareal() {
172 get_past_state(_u_last_past);
173 _u_last_future = _G(_u_last_past);
174 if (!isLast())
175 send(_u_last_future);
176 }
177
178 void run_parareal(int iterations) {
179 while (_iteration < iterations) {
180 // Correction step
181 auto delta = _F(_u_last_past) - _G(_u_last_past);
182 // Alternative variant, together with second _u_last_future line below.
183 // Better scalability, should yield the same results. TODO test and verify.
184 // auto delta = _F(_u_last_past) - _u_last_future;
185
186 _iteration++;
187
188 // Prediction step
189 get_past_state(_u_next_past);
190 auto prediction = _G(_u_next_past);
191
192 // Refinement step
193 _u_next_future = prediction + delta;
194 if (!isLast())
195 send(_u_next_future);
196
197 // move for next iteration
198 _u_last_past = std::move(_u_next_past);
199 _u_last_future = std::move(_u_next_future);
200 // Alternative variant
201 //_u_last_future = std::move(prediction);
202 }
203
204#ifdef PINT_DEBUG
205 if (_world_rank == 0) {
206 std::cout << "PINT_DEBUG: Finished all PinT iterations. " << std::endl;
207 }
208#endif
209 }
210
211 int _pint_domain; // the index of the time domain to which this process belongs to
212 int _rank; // rank of current process in _local_pint_comm
213 int _ranks_per_domain; // number of MPI ranks in each time domain
214 int _world_rank; // rank of current process in MPI_COMM_WORLD
215#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
216 MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank
217#endif
218 coupling::configurations::TimeIntegrationConfiguration _cfg;
219 Scenario* _scenario;
220 std::unique_ptr<Solver> _supervisor; // The supervisor, i.e. the coarse predictor
221 std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
222 std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;
223
224 // These objects represent coupled simulation states. There are used by the supervised parallel in time algorithm for operations
225 // "last" and "next" describe two consecutive parareal iterations
226 // "past" and "future" describe two points in simulation time, one pint time domain apart
227 std::unique_ptr<State> _u_0; // initial state
228 std::unique_ptr<State> _u_last_past;
229 std::unique_ptr<State> _u_last_future;
230 std::unique_ptr<State> _u_next_past;
231 std::unique_ptr<State> _u_next_future;
232
233 int _iteration{0};
234};
Definition Scenario.h:19
parses all sub-tags for MaMiCo configuration.
Definition MaMiCoConfiguration.h:31
const coupling::configurations::TimeIntegrationConfiguration & getTimeIntegrationConfiguration() const
Definition MaMiCoConfiguration.h:138
Definition PintableMacroSolver.h:67
Definition PintableMacroSolver.h:30
virtual std::unique_ptr< PintableMacroSolver > getSupervisor(int num_cycles, double visc_multiplier=1) const =0
Definition ParallelTimeIntegrationService.h:30
everything necessary for coupling operations, is defined in here
Definition AdditiveMomentumInsertion.h:15
Definition ParallelTimeIntegrationService.h:95
int minCycle
number of the first coupling cycle in this temporal domain (inclusive)
Definition ParallelTimeIntegrationService.h:99
int maxCycle
number of the last coupling cycle of this temporal domain (exclusive)
Definition ParallelTimeIntegrationService.h:101
int size
number of time steps (coupling cycles) in this temporal domain
Definition ParallelTimeIntegrationService.h:97