37#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
39 MPI_Comm_size(MPI_COMM_WORLD, &world_size);
40 MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
41 if (_cfg.isPintEnabled()) {
42 if (world_size % _cfg.getPintDomains() != 0) {
43 if (_world_rank == 0) {
44 std::cout <<
"ERROR coupling::services::ParallelTimeIntegrationService: " <<
"MPI ranks not divisible by number of PinT subdomains!" << std::endl;
45 std::cout <<
"When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl;
46 std::cout <<
"Check your configuration." << std::endl;
50 _ranks_per_domain = world_size / _cfg.getPintDomains();
51 _pint_domain = _world_rank / _ranks_per_domain;
53 MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
55 _local_pint_comm = MPI_COMM_WORLD;
57 MPI_Comm_rank(_local_pint_comm, &_rank);
60 std::cout <<
"PINT_DEBUG: world_rank " << _world_rank <<
" is rank " << _rank <<
" in pint domain " << _pint_domain << std::endl;
/// @brief Entry point of the service: executes num_cycles coupling cycles,
/// either sequentially (PinT disabled) or via the parareal machinery.
/// @param num_cycles total number of coupling cycles to run
/// NOTE(review): this extraction appears to have dropped lines between the
/// original source lines 68 and 73 (e.g. the declaration of `domain` and the
/// else-branch / init_parareal() call) — confirm against the full header.
66 void run(
int num_cycles) {
// Serial fallback: without PinT, this rank simply runs every cycle itself.
67 if (!_cfg.isPintEnabled())
68 run_cycles(0, num_cycles);
// PinT path: wire up fine/coarse propagators for this temporal subdomain,
// then iterate parareal for the configured number of iterations.
71 setup_solvers(domain);
73 run_parareal(_cfg.getPintIterations());
77 int getPintDomain()
const {
return _pint_domain; }
78 int getRank()
const {
return _rank; }
79 bool isPintEnabled()
const {
return _cfg.isPintEnabled(); }
80 int getIteration()
const {
return _iteration; }
82#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
// Accessor for the per-subdomain MPI communicator: created via MPI_Comm_split
// in the constructor when PinT is enabled, otherwise MPI_COMM_WORLD.
// NOTE(review): the matching #endif is not visible in this extraction.
83 MPI_Comm getPintComm()
const {
return _local_pint_comm; }
87 void run_cycles(
int start,
int end) {
88 for (
int cycle = start; cycle < end; cycle++)
89 _scenario->runOneCouplingCycle(cycle);
92 bool isFirst()
const {
return _pint_domain == 0; }
93 bool isLast()
const {
return _pint_domain == _cfg.getPintDomains() - 1; }
// Computes this rank's temporal subdomain by evenly splitting num_cycles
// over the configured number of PinT domains.
// @param num_cycles total number of coupling cycles to distribute
// @return the PintDomain (size, minCycle, maxCycle) owned by this rank
// NOTE(review): the lines between original source lines 106 and 114
// (declaration of `res`, the minCycle/maxCycle assignments and the return
// statement) appear to be missing from this extraction — confirm against
// the full header.
104 PintDomain setup_domain(
int num_cycles)
const {
// Each temporal domain covers num_cycles / getPintDomains() cycles.
106 res.
size = (int)(num_cycles / _cfg.getPintDomains());
// Debug print of the [minCycle, maxCycle) range owned by this subdomain.
114 std::cout <<
"PINT_DEBUG: _pint_domain " << _pint_domain <<
" has minCycle " << res.
minCycle <<
" and maxCycle " << res.
maxCycle << std::endl;
// Wires up the fine propagator (_F), the coarse propagator (_G) and the
// initial parareal state (_u_0) for this rank's temporal subdomain.
// @param domain the [minCycle, maxCycle) cycle range assigned to this rank
120 void setup_solvers(PintDomain domain) {
// The macroscopic solver must support PinT: downcast to the pintable
// Solver interface and report a configuration error if it does not.
121 auto solver =
dynamic_cast<Solver*
>(_scenario->getSolver());
122 if (solver ==
nullptr) {
123 std::cout <<
"ERROR coupling::services::ParallelTimeIntegrationService: "
124 <<
"macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
// The coarse propagator (supervisor) may run with a modified viscosity;
// log that on the root rank only.
128 if (_cfg.getViscMultiplier() != 1.0)
129 if (_world_rank == 0) {
130 std::cout <<
"PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
133 _supervisor = solver->
getSupervisor(domain.size, _cfg.getViscMultiplier());
// Fine propagator F: run the full coupled simulation over this subdomain,
// starting from state s at cycle domain.minCycle.
134 _F = [
this, solver, domain](
const std::unique_ptr<State>& s) {
135 solver->setState(s, domain.minCycle);
// Re-equilibrate the microscopic system unless we start from cycle 0.
136 if (domain.minCycle > 0)
137 _scenario->equilibrateMicro();
138 run_cycles(domain.minCycle, domain.maxCycle);
139 return solver->getState();
// Coarse propagator G: delegate to the supervisor solver.
141 _G = [
this, domain](
const std::unique_ptr<State>& s) {
142 auto& G = *_supervisor;
143 return G(s, domain.minCycle);
// Initial state for the parareal iteration, taken from the fine solver.
145 _u_0 = solver->getState();
// Blocking receive of a state from the predecessor temporal subdomain.
// The state is first cloned from _u_0 so it has a correctly sized buffer
// for MPI_Recv to fill.
// NOTE(review): the matching #endif is not visible in this extraction.
148 void receive(std::unique_ptr<State>& state)
const {
150 state = _u_0->clone();
151#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
// The predecessor holds the same position, one PinT domain earlier.
152 int source_rank = _world_rank - _ranks_per_domain;
153 MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
// Blocking send of a state to the successor temporal subdomain
// (same position, one PinT domain later).
// NOTE(review): the matching #endif is not visible in this extraction.
157 void send(std::unique_ptr<State>& state)
const {
158#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
159 int destination_rank = _world_rank + _ranks_per_domain;
160 MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
// Obtains the state at the start of this rank's temporal subdomain.
// NOTE(review): presumably only the first subdomain clones _u_0 while later
// subdomains receive() from their predecessor; the branch appears to be lost
// in this extraction — confirm against the full header.
164 void get_past_state(std::unique_ptr<State>& state)
const {
166 state = _u_0->clone();
// Bootstraps parareal: coarse-propagates this rank's past state once and
// forwards the resulting future state to the successor subdomain.
// NOTE(review): guards such as isLast() around send() are not visible in
// this extraction — confirm against the full header.
171 void init_parareal() {
172 get_past_state(_u_last_past);
// Initial prediction via the coarse propagator G.
173 _u_last_future = _G(_u_last_past);
175 send(_u_last_future);
// Parareal main loop. Each iteration applies the classic update
//   u_next_future = G(u_next_past) + ( F(u_last_past) - G(u_last_past) )
// i.e. a coarse prediction corrected by the fine-minus-coarse defect.
// @param iterations number of parareal iterations to perform
178 void run_parareal(
int iterations) {
179 while (_iteration < iterations) {
// Correction term: fine minus coarse propagation of last iteration's state.
181 auto delta = _F(_u_last_past) - _G(_u_last_past);
189 get_past_state(_u_next_past);
// Coarse prediction for the new iterate ...
190 auto prediction = _G(_u_next_past);
// ... corrected by delta (the parareal update).
193 _u_next_future = prediction + delta;
195 send(_u_next_future);
// Shift iterates for the next round.
198 _u_last_past = std::move(_u_next_past);
199 _u_last_future = std::move(_u_next_future);
// NOTE(review): the increment of _iteration and the loop's closing brace
// are not visible in this extraction — confirm against the full header.
205 if (_world_rank == 0) {
206 std::cout <<
"PINT_DEBUG: Finished all PinT iterations. " << std::endl;
// Number of MPI ranks assigned to each temporal (PinT) subdomain.
213 int _ranks_per_domain;
215#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
// Communicator spanning only the ranks of this temporal subdomain
// (result of MPI_Comm_split, or MPI_COMM_WORLD when PinT is disabled).
216 MPI_Comm _local_pint_comm;
// Time-integration / PinT configuration (number of domains, iterations, ...).
218 coupling::configurations::TimeIntegrationConfiguration _cfg;
// Coarse solver acting as the parareal supervisor (see setup_solvers).
220 std::unique_ptr<Solver> _supervisor;
// Fine propagator F: full coupled simulation over this subdomain.
221 std::function<std::unique_ptr<State>(
const std::unique_ptr<State>&)> _F;
// Coarse propagator G: supervisor-based approximation.
222 std::function<std::unique_ptr<State>(
const std::unique_ptr<State>&)> _G;
// Initial state of the parareal iteration (from the fine solver).
227 std::unique_ptr<State> _u_0;
// Iterates of the previous parareal round (start / end of this subdomain).
228 std::unique_ptr<State> _u_last_past;
229 std::unique_ptr<State> _u_last_future;
// Iterates of the current parareal round (see run_parareal).
230 std::unique_ptr<State> _u_next_past;
231 std::unique_ptr<State> _u_next_future;
int minCycle
number of the first coupling cycle in this temporal domain (inclusive)
Definition ParallelTimeIntegrationService.h:99
int maxCycle
number of the last coupling cycle of this temporal domain (exclusive)
Definition ParallelTimeIntegrationService.h:101
int size
number of time steps (coupling cycles) in this temporal domain
Definition ParallelTimeIntegrationService.h:97