MaMiCo 1.2
Loading...
Searching...
No Matches
ParallelTimeIntegrationService.h
// Copyright (C) 2023 Helmut Schmidt University
// This file is part of the Mamico project. For conditions of distribution
// and use, please see the copyright notice in Mamico's main folder
4
5#pragma once
6
// Forward declaration of the service class template; its definition follows after the #include directives below.
namespace coupling {
namespace services {

template <unsigned int dim>
class ParallelTimeIntegrationService;

} // namespace services
} // namespace coupling
12
13#include "coupling/configurations/MaMiCoConfiguration.h"
14#include "coupling/interface/PintableMacroSolver.h"
15#include "coupling/scenario/Scenario.h"
16#include <functional>
17
18// Convenience operators, to be able to write parareal iterations in a readable notation
19using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
20inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs){
21 return *lhs + *rhs;
22}
23inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs){
24 return *lhs - *rhs;
25}
26
34template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
35public:
38
39 ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario):
40 _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0),
41 _cfg( mamicoConfig.getTimeIntegrationConfiguration() ),
42 _scenario(scenario) {
43 #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
44 int world_size;
45 MPI_Comm_size(MPI_COMM_WORLD, &world_size);
46 MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
47 if(_cfg.isPintEnabled()){
48 if(world_size % _cfg.getPintDomains() != 0){
49 if(_world_rank == 0){
50 std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
51 "MPI ranks not divisible by number of PinT subdomains!" << std::endl;
52 std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl;
53 std::cout << "Check your configuration." << std::endl;
54 }
55 exit(EXIT_FAILURE);
56 }
57 _ranks_per_domain = world_size / _cfg.getPintDomains();
58 _pint_domain = _world_rank / _ranks_per_domain;
59 // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
60 MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
61 } else {
62 _local_pint_comm = MPI_COMM_WORLD;
63 }
64 MPI_Comm_rank(_local_pint_comm, &_rank);
65
66 #ifdef PINT_DEBUG
67 std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
68 #endif
69
70 #endif
71 }
72
73 void run(int num_cycles) {
74 if(!_cfg.isPintEnabled())
75 run_cycles(0, num_cycles);
76 else {
77 PintDomain domain = setup_domain(num_cycles);
78 setup_solvers(domain);
79 init_parareal();
80 run_parareal( _cfg.getPintIterations() );
81 }
82 }
83
84 int getPintDomain() const { return _pint_domain; }
85 int getRank() const { return _rank; }
86 bool isPintEnabled() const { return _cfg.isPintEnabled(); }
87 int getInteration() const { return _iteration; }
88
89 #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
90 MPI_Comm getPintComm() const { return _local_pint_comm; }
91 #endif
92
93private:
94 void run_cycles(int start, int end){
95 for (int cycle = start; cycle < end; cycle++)
96 _scenario->runOneCouplingCycle(cycle);
97 }
98
99 bool isFirst() const { return _pint_domain == 0; }
100 bool isLast() const { return _pint_domain == _cfg.getPintDomains()-1; }
101
102 struct PintDomain {
104 int size;
109 };
110
111 PintDomain setup_domain(int num_cycles) const {
112 PintDomain res;
113 res.size = (int)( num_cycles / _cfg.getPintDomains() );
114 res.minCycle = _pint_domain * res.size;
115 res.maxCycle = res.minCycle + res.size;
116 // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
117 if( isLast() )
118 res.maxCycle = num_cycles;
119 #ifdef PINT_DEBUG
120 std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle
121 << " and maxCycle " << res.maxCycle << std::endl;
122 #endif
123 return res;
124 }
125
126 void setup_solvers(PintDomain domain){
127 auto solver = dynamic_cast<Solver*>( _scenario->getSolver() );
128 if(solver == nullptr){
129 std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
130 "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
131 exit(EXIT_FAILURE);
132 }
133 #ifdef PINT_DEBUG
134 if(_cfg.getViscMultiplier() != 1.0)
135 std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
136 #endif
137 _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier() );
138 _F = [this, solver, domain](const std::unique_ptr<State>& s){
139 solver->setState(s, domain.minCycle);
140 // TODO enable momentumTransfer on inner cells for MD equilibration steps here
141 // TODO run MD equilibration here
142 run_cycles(domain.minCycle, domain.maxCycle);
143 // TODO double check that filter pipeline output ends up in solver.getState() here
144 return solver->getState();
145 };
146 _G = [this, domain](const std::unique_ptr<State>& s){
147 auto& G = *_supervisor;
148 return G(s, domain.minCycle);
149 };
150 _u_0 = solver->getState();
151 }
152
153 void receive(std::unique_ptr<State>& state) const{
154 if(!state) state = _u_0->clone();
155 #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
156 int source_rank = _world_rank - _ranks_per_domain;
157 MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
158 #endif
159 }
160
161 void send(std::unique_ptr<State>& state) const {
162 #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
163 int destination_rank = _world_rank + _ranks_per_domain;
164 MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
165 #endif
166 }
167
168 void get_past_state(std::unique_ptr<State>& state) const {
169 if( isFirst() )
170 state = _u_0->clone();
171 else
172 receive(state);
173 }
174
175 void init_parareal(){
176 get_past_state(_u_last_past);
177 _u_last_future = _G(_u_last_past);
178 if( !isLast() )
179 send(_u_last_future);
180 }
181
182 void run_parareal(int iterations){
183 while(_iteration < iterations){
184 // Correction step
185 auto delta = _F(_u_last_past) - _G(_u_last_past);
186
187 _iteration++;
188
189 // Prediction step
190 get_past_state(_u_next_past);
191 auto prediction = _G(_u_next_past);
192
193 // Refinement step
194 _u_next_future = prediction + delta;
195 if( !isLast() )
196 send(_u_next_future);
197
198 // move for next iteration
199 _u_last_past = std::move(_u_next_past);
200 _u_last_future = std::move(_u_next_future);
201 }
202
203 #ifdef PINT_DEBUG
204 std::cout << "PINT_DEBUG: Finished all PinT iterations. " << std::endl;
205 #endif
206 }
207
208 int _pint_domain; // the index of the time domain to which this process belongs to
209 int _rank; // rank of current process in _local_pint_comm
210 int _ranks_per_domain; // number of MPI ranks in each time domain
211 int _world_rank; // rank of current process in MPI_COMM_WORLD
212 #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
213 MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank
214 #endif
215 coupling::configurations::TimeIntegrationConfiguration _cfg;
216 Scenario* _scenario;
217 std::unique_ptr<Solver> _supervisor; // The supervisor, i.e. the coarse predictor
218 std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
219 std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;
220
221 // These objects represent coupled simulation states. There are used by the supervised parallel in time algorithm for operations
222 // "last" and "next" describe two consecutive parareal iterations
223 // "past" and "future" describe two points in simulation time, one pint time domain apart
224 std::unique_ptr<State> _u_0; // initial state
225 std::unique_ptr<State> _u_last_past;
226 std::unique_ptr<State> _u_last_future;
227 std::unique_ptr<State> _u_next_past;
228 std::unique_ptr<State> _u_next_future;
229
230 int _iteration{0};
231};
Definition Scenario.h:15
parses all sub-tags of the MaMiCo configuration.
Definition MaMiCoConfiguration.h:31
const coupling::configurations::TimeIntegrationConfiguration & getTimeIntegrationConfiguration() const
Definition MaMiCoConfiguration.h:138
Definition PintableMacroSolver.h:67
Definition PintableMacroSolver.h:30
Definition ParallelTimeIntegrationService.h:34
everything necessary for coupling operations is defined in here
Definition AdditiveMomentumInsertion.h:15
Definition ParallelTimeIntegrationService.h:102
int minCycle
number of the first coupling cycle in this temporal domain (inclusive)
Definition ParallelTimeIntegrationService.h:106
int maxCycle
number of the first coupling cycle after this temporal domain (exclusive upper bound)
Definition ParallelTimeIntegrationService.h:108
int size
number of time steps (coupling cycles) in this temporal domain
Definition ParallelTimeIntegrationService.h:104