// NOTE(review): stray LCOV coverage-report table header ("Line data | Source code") left over
// from an export of this file; not C++ source — kept as a comment so the file stays compilable.
1 : // Copyright (C) 2023 Helmut Schmidt University 2 : // This file is part of the Mamico project. For conditions of distribution 3 : // and use, please see the copyright notice in Mamico's main folder 4 : 5 : #pragma once 6 : 7 : namespace coupling { 8 : namespace services { 9 : template <unsigned int dim> class ParallelTimeIntegrationService; 10 : } // namespace services 11 : } // namespace coupling 12 : 13 : #include "coupling/configurations/MaMiCoConfiguration.h" 14 : #include "coupling/interface/PintableMacroSolver.h" 15 : #include "coupling/scenario/Scenario.h" 16 : #include <functional> 17 : 18 : // Convenience operators, to be able to write parareal iterations in a readable notation 19 : using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>; 20 : inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs) { return *lhs + *rhs; } 21 : inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs) { return *lhs - *rhs; } 22 : 23 : /** 24 : * Service to manage timeloop of a coupled simulation scenario. Supports sequential or parallel-in-time integration using a Parareal variant, 25 : * as described in "Blumers, A. L., Li, Z., & Karniadakis, G. E. (2019). Supervised parallel-in-time algorithm for long-time Lagrangian 26 : * simulations of stochastic dynamics: Application to hydrodynamics. Journal of Computational Physics, 393, 214-228". 
27 : * 28 : * @author Piet Jarmatz 29 : */ 30 : template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService { 31 : public: 32 : using State = coupling::interface::PintableMacroSolverState; 33 : using Solver = coupling::interface::PintableMacroSolver; 34 : 35 : ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario) 36 : : _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0), _cfg(mamicoConfig.getTimeIntegrationConfiguration()), _scenario(scenario) { 37 : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES) 38 : int world_size; 39 : MPI_Comm_size(MPI_COMM_WORLD, &world_size); 40 : MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank); 41 : if (_cfg.isPintEnabled()) { 42 : if (world_size % _cfg.getPintDomains() != 0) { 43 : if (_world_rank == 0) { 44 : std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " << "MPI ranks not divisible by number of PinT subdomains!" << std::endl; 45 : std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl; 46 : std::cout << "Check your configuration." 
<< std::endl; 47 : } 48 : exit(EXIT_FAILURE); 49 : } 50 : _ranks_per_domain = world_size / _cfg.getPintDomains(); 51 : _pint_domain = _world_rank / _ranks_per_domain; 52 : // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators 53 : MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm); 54 : } else { 55 : _local_pint_comm = MPI_COMM_WORLD; 56 : } 57 : MPI_Comm_rank(_local_pint_comm, &_rank); 58 : 59 : #ifdef PINT_DEBUG 60 : std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl; 61 : #endif 62 : 63 : #endif 64 : } 65 : 66 : void run(int num_cycles) { 67 : if (!_cfg.isPintEnabled()) 68 : run_cycles(0, num_cycles); 69 : else { 70 : PintDomain domain = setup_domain(num_cycles); 71 : setup_solvers(domain); 72 : init_parareal(); 73 : run_parareal(_cfg.getPintIterations()); 74 : } 75 : } 76 : 77 : int getPintDomain() const { return _pint_domain; } 78 : int getRank() const { return _rank; } 79 0 : bool isPintEnabled() const { return _cfg.isPintEnabled(); } 80 0 : int getIteration() const { return _iteration; } 81 : 82 : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES) 83 : MPI_Comm getPintComm() const { return _local_pint_comm; } 84 : #endif 85 : 86 : private: 87 : void run_cycles(int start, int end) { 88 : for (int cycle = start; cycle < end; cycle++) 89 : _scenario->runOneCouplingCycle(cycle); 90 : } 91 : 92 : bool isFirst() const { return _pint_domain == 0; } 93 : bool isLast() const { return _pint_domain == _cfg.getPintDomains() - 1; } 94 : 95 : struct PintDomain { 96 : /** @brief number of time steps (coupling cycles) in this temporal domain */ 97 : int size; 98 : /** @brief number of the first coupling cycle in this temporal domain (inclusive) */ 99 : int minCycle; 100 : /** @brief number of the last coupling cycle of this temporal domain (exclusive) */ 101 : int maxCycle; 102 : }; 103 : 104 : PintDomain setup_domain(int 
num_cycles) const { 105 : PintDomain res; 106 : res.size = (int)(num_cycles / _cfg.getPintDomains()); 107 : res.minCycle = _pint_domain * res.size; 108 : res.maxCycle = res.minCycle + res.size; 109 : // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder 110 : if (isLast()) 111 : res.maxCycle = num_cycles; 112 : #ifdef PINT_DEBUG 113 : if (_rank == 0) { 114 : std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle << " and maxCycle " << res.maxCycle << std::endl; 115 : } 116 : #endif 117 : return res; 118 : } 119 : 120 : void setup_solvers(PintDomain domain) { 121 : auto solver = dynamic_cast<Solver*>(_scenario->getSolver()); 122 : if (solver == nullptr) { 123 : std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " 124 : << "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl; 125 : exit(EXIT_FAILURE); 126 : } 127 : #ifdef PINT_DEBUG 128 : if (_cfg.getViscMultiplier() != 1.0) 129 : if (_world_rank == 0) { 130 : std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl; 131 : } 132 : #endif 133 : _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier()); 134 : _F = [this, solver, domain](const std::unique_ptr<State>& s) { 135 : solver->setState(s, domain.minCycle); 136 : if (domain.minCycle > 0) 137 : _scenario->equilibrateMicro(); 138 : run_cycles(domain.minCycle, domain.maxCycle); 139 : return solver->getState(); 140 : }; 141 : _G = [this, domain](const std::unique_ptr<State>& s) { 142 : auto& G = *_supervisor; 143 : return G(s, domain.minCycle); 144 : }; 145 : _u_0 = solver->getState(); 146 : } 147 : 148 : void receive(std::unique_ptr<State>& state) const { 149 : if (!state) 150 : state = _u_0->clone(); 151 : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES) 152 : int source_rank = _world_rank - _ranks_per_domain; 153 : MPI_Recv(state->getData(), 
state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE); 154 : #endif 155 : } 156 : 157 : void send(std::unique_ptr<State>& state) const { 158 : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES) 159 : int destination_rank = _world_rank + _ranks_per_domain; 160 : MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD); 161 : #endif 162 : } 163 : 164 : void get_past_state(std::unique_ptr<State>& state) const { 165 : if (isFirst()) 166 : state = _u_0->clone(); 167 : else 168 : receive(state); 169 : } 170 : 171 : void init_parareal() { 172 : get_past_state(_u_last_past); 173 : _u_last_future = _G(_u_last_past); 174 : if (!isLast()) 175 : send(_u_last_future); 176 : } 177 : 178 : void run_parareal(int iterations) { 179 : while (_iteration < iterations) { 180 : // Correction step 181 : auto delta = _F(_u_last_past) - _G(_u_last_past); 182 : // Alternative variant, together with second _u_last_future line below. 183 : // Better scalability, should yield the same results. TODO test and verify. 184 : // auto delta = _F(_u_last_past) - _u_last_future; 185 : 186 : _iteration++; 187 : 188 : // Prediction step 189 : get_past_state(_u_next_past); 190 : auto prediction = _G(_u_next_past); 191 : 192 : // Refinement step 193 : _u_next_future = prediction + delta; 194 : if (!isLast()) 195 : send(_u_next_future); 196 : 197 : // move for next iteration 198 : _u_last_past = std::move(_u_next_past); 199 : _u_last_future = std::move(_u_next_future); 200 : // Alternative variant 201 : //_u_last_future = std::move(prediction); 202 : } 203 : 204 : #ifdef PINT_DEBUG 205 : if (_world_rank == 0) { 206 : std::cout << "PINT_DEBUG: Finished all PinT iterations. 
" << std::endl; 207 : } 208 : #endif 209 : } 210 : 211 : int _pint_domain; // the index of the time domain to which this process belongs to 212 : int _rank; // rank of current process in _local_pint_comm 213 : int _ranks_per_domain; // number of MPI ranks in each time domain 214 : int _world_rank; // rank of current process in MPI_COMM_WORLD 215 : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES) 216 : MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank 217 : #endif 218 : coupling::configurations::TimeIntegrationConfiguration _cfg; 219 : Scenario* _scenario; 220 : std::unique_ptr<Solver> _supervisor; // The supervisor, i.e. the coarse predictor 221 : std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F; 222 : std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G; 223 : 224 : // These objects represent coupled simulation states. There are used by the supervised parallel in time algorithm for operations 225 : // "last" and "next" describe two consecutive parareal iterations 226 : // "past" and "future" describe two points in simulation time, one pint time domain apart 227 : std::unique_ptr<State> _u_0; // initial state 228 : std::unique_ptr<State> _u_last_past; 229 : std::unique_ptr<State> _u_last_future; 230 : std::unique_ptr<State> _u_next_past; 231 : std::unique_ptr<State> _u_next_future; 232 : 233 : int _iteration{0}; 234 : };