// Copyright (C) 2023 Helmut Schmidt University
// This file is part of the Mamico project. For conditions of distribution
// and use, please see the copyright notice in Mamico's main folder

#pragma once

namespace coupling {
namespace services {
template <unsigned int dim> class ParallelTimeIntegrationService;
} // namespace services
} // namespace coupling

#include "coupling/configurations/MaMiCoConfiguration.h"
#include "coupling/interface/PintableMacroSolver.h"
#include "coupling/scenario/Scenario.h"
#include <cstdlib>
#include <functional>
#include <iostream>
#include <memory>

// Convenience operators, to be able to write parareal iterations in a readable notation
using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs) {
  return *lhs + *rhs;
}
inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs) {
  return *lhs - *rhs;
}

/**
 * Service to manage the time loop of a coupled simulation scenario. Supports sequential or parallel-in-time integration using a Parareal variant,
 * as described in "Blumers, A. L., Li, Z., & Karniadakis, G. E. (2019). Supervised parallel-in-time algorithm for long-time Lagrangian
 * simulations of stochastic dynamics: Application to hydrodynamics. Journal of Computational Physics, 393, 214-228".
 *
 * @author Piet Jarmatz
 */
template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
public:
  using State = coupling::interface::PintableMacroSolverState;
  using Solver = coupling::interface::PintableMacroSolver;

  ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario)
      : _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0),
        _cfg(mamicoConfig.getTimeIntegrationConfiguration()),
        _scenario(scenario) {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int world_size;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
    if (_cfg.isPintEnabled()) {
      if (world_size % _cfg.getPintDomains() != 0) {
        if (_world_rank == 0) {
          std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
                    << "Number of MPI ranks is not divisible by the number of PinT subdomains!" << std::endl;
          std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of the number of subdomains." << std::endl;
          std::cout << "Check your configuration." << std::endl;
        }
        exit(EXIT_FAILURE);
      }
      _ranks_per_domain = world_size / _cfg.getPintDomains();
      _pint_domain = _world_rank / _ranks_per_domain;
      // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
      MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
    } else {
      _local_pint_comm = MPI_COMM_WORLD;
    }
    MPI_Comm_rank(_local_pint_comm, &_rank);

#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
#endif

#endif
  }

  void run(int num_cycles) {
    if (!_cfg.isPintEnabled())
      run_cycles(0, num_cycles);
    else {
      PintDomain domain = setup_domain(num_cycles);
      setup_solvers(domain);
      init_parareal();
      run_parareal(_cfg.getPintIterations());
    }
  }

  int getPintDomain() const { return _pint_domain; }
  int getRank() const { return _rank; }
  bool isPintEnabled() const { return _cfg.isPintEnabled(); }
  int getIteration() const { return _iteration; }

#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm getPintComm() const { return _local_pint_comm; }
#endif

private:
  void run_cycles(int start, int end) {
    for (int cycle = start; cycle < end; cycle++)
      _scenario->runOneCouplingCycle(cycle);
  }

  bool isFirst() const { return _pint_domain == 0; }
  bool isLast() const { return _pint_domain == _cfg.getPintDomains() - 1; }

  struct PintDomain {
    /** @brief number of time steps (coupling cycles) in this temporal domain */
    int size;
    /** @brief number of the first coupling cycle in this temporal domain (inclusive) */
    int minCycle;
    /** @brief number of the last coupling cycle of this temporal domain (exclusive) */
    int maxCycle;
  };

  PintDomain setup_domain(int num_cycles) const {
    PintDomain res;
    res.size = (int)(num_cycles / _cfg.getPintDomains());
    res.minCycle = _pint_domain * res.size;
    res.maxCycle = res.minCycle + res.size;
    // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
    if (isLast())
      res.maxCycle = num_cycles;
#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle << " and maxCycle " << res.maxCycle << std::endl;
#endif
    return res;
  }

  void setup_solvers(PintDomain domain) {
    auto solver = dynamic_cast<Solver*>(_scenario->getSolver());
    if (solver == nullptr) {
      std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
                << "macroscopic solver is not pintable (i.e. not compatible with parallel-in-time coupling)" << std::endl;
      exit(EXIT_FAILURE);
    }
#ifdef PINT_DEBUG
    if (_cfg.getViscMultiplier() != 1.0)
      std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
#endif
    _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier());
    // _F is the fine propagator: it runs the full coupled simulation over this time domain
    _F = [this, solver, domain](const std::unique_ptr<State>& s) {
      solver->setState(s, domain.minCycle);
      if (domain.minCycle > 0)
        _scenario->equilibrateMicro();
      run_cycles(domain.minCycle, domain.maxCycle);
      return solver->getState();
    };
    // _G is the coarse propagator: it advances the state with the supervisor only
    _G = [this, domain](const std::unique_ptr<State>& s) {
      auto& G = *_supervisor;
      return G(s, domain.minCycle);
    };
    _u_0 = solver->getState();
  }

  void receive(std::unique_ptr<State>& state) const {
    if (!state)
      state = _u_0->clone();
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int source_rank = _world_rank - _ranks_per_domain;
    MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
#endif
  }

  void send(std::unique_ptr<State>& state) const {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int destination_rank = _world_rank + _ranks_per_domain;
    MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
#endif
  }

  void get_past_state(std::unique_ptr<State>& state) const {
    if (isFirst())
      state = _u_0->clone();
    else
      receive(state);
  }

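  // The two methods below implement the supervised parareal iteration of the reference
  // cited above: after an initial coarse pass (init_parareal), each iteration k computes,
  // per time domain [t_n, t_{n+1}],
  //   u_{k+1}(t_{n+1}) = G(u_{k+1}(t_n)) + F(u_k(t_n)) - G(u_k(t_n)),
  // i.e. a coarse prediction by the supervisor G, refined by the difference between the
  // fine propagator F and G evaluated on the previous iterate.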
  void init_parareal() {
    get_past_state(_u_last_past);
    _u_last_future = _G(_u_last_past);
    if (!isLast())
      send(_u_last_future);
  }

  void run_parareal(int iterations) {
    while (_iteration < iterations) {
      // Correction step
      auto delta = _F(_u_last_past) - _G(_u_last_past);
      // Alternative variant, together with the second _u_last_future line below.
      // Better scalability, should yield the same results. TODO test and verify.
      // auto delta = _F(_u_last_past) - _u_last_future;

      _iteration++;

      // Prediction step
      get_past_state(_u_next_past);
      auto prediction = _G(_u_next_past);

      // Refinement step
      _u_next_future = prediction + delta;
      if (!isLast())
        send(_u_next_future);

      // move for next iteration
      _u_last_past = std::move(_u_next_past);
      _u_last_future = std::move(_u_next_future);
      // Alternative variant
      // _u_last_future = std::move(prediction);
    }

#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: Finished all PinT iterations." << std::endl;
#endif
  }

  int _pint_domain;      // index of the time domain to which this process belongs
  int _rank;             // rank of the current process in _local_pint_comm
  int _ranks_per_domain; // number of MPI ranks in each time domain
  int _world_rank;       // rank of the current process in MPI_COMM_WORLD
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm _local_pint_comm; // communicator of the local time domain of this rank
#endif
  coupling::configurations::TimeIntegrationConfiguration _cfg;
  Scenario* _scenario;
  std::unique_ptr<Solver> _supervisor; // the supervisor, i.e. the coarse predictor
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;

  // These objects represent coupled simulation states. They are used by the supervised parallel-in-time algorithm for state operations.
  // "last" and "next" refer to two consecutive parareal iterations,
  // "past" and "future" refer to two points in simulation time, one PinT time domain apart.
  std::unique_ptr<State> _u_0; // initial state
  std::unique_ptr<State> _u_last_past;
  std::unique_ptr<State> _u_last_future;
  std::unique_ptr<State> _u_next_past;
  std::unique_ptr<State> _u_next_future;

  int _iteration{0};
};
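
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the interface): a scenario could drive
// its coupled time loop through this service roughly as follows. MyScenario,
// numCouplingCycles and the configuration setup are hypothetical placeholders;
// only the constructor and run() correspond to the class above.
//
//   coupling::configurations::MaMiCoConfiguration<3> mamicoConfig; // parsed from XML elsewhere
//   MyScenario scenario;                                           // some Scenario subclass
//   coupling::services::ParallelTimeIntegrationService<3> timeIntegration(mamicoConfig, &scenario);
//   timeIntegration.run(numCouplingCycles);
//   // With PinT disabled, this runs all coupling cycles sequentially; otherwise it
//   // partitions them into getPintDomains() time domains and performs
//   // getPintIterations() parareal iterations.
// ---------------------------------------------------------------------------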