// Copyright (C) 2023 Helmut Schmidt University
// This file is part of the Mamico project. For conditions of distribution
// and use, please see the copyright notice in Mamico's main folder

#pragma once

namespace coupling {
namespace services {
template <unsigned int dim> class ParallelTimeIntegrationService;
} // namespace services
} // namespace coupling

#include "coupling/configurations/MaMiCoConfiguration.h"
#include "coupling/interface/PintableMacroSolver.h"
#include "coupling/scenario/Scenario.h"
#include <functional>

// Convenience operators, to be able to write parareal iterations in a readable notation
using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs){
  return *lhs + *rhs;
}
inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs){
  return *lhs - *rhs;
}

/**
 * Service to manage the time loop of a coupled simulation scenario. Supports sequential or parallel-in-time integration using a Parareal variant,
 * as described in "Blumers, A. L., Li, Z., & Karniadakis, G. E. (2019). Supervised parallel-in-time algorithm for long-time Lagrangian
 * simulations of stochastic dynamics: Application to hydrodynamics. Journal of Computational Physics, 393, 214-228".
 *
 * @author Piet Jarmatz
 */
template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
public:
  using State = coupling::interface::PintableMacroSolverState;
  using Solver = coupling::interface::PintableMacroSolver;

  ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario):
    _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0),
    _cfg( mamicoConfig.getTimeIntegrationConfiguration() ),
    _scenario(scenario) {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int world_size;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
    if(_cfg.isPintEnabled()){
      if(world_size % _cfg.getPintDomains() != 0){
        if(_world_rank == 0){
          std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
            "Number of MPI ranks is not divisible by the number of PinT subdomains!" << std::endl;
          std::cout << "When PinT is used, the number of required MPI ranks increases by a factor equal to the number of subdomains." << std::endl;
          std::cout << "Check your configuration." << std::endl;
        }
        exit(EXIT_FAILURE);
      }
      _ranks_per_domain = world_size / _cfg.getPintDomains();
      _pint_domain = _world_rank / _ranks_per_domain;
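      // Example (illustrative numbers only): with world_size = 8 and getPintDomains() = 2,
      // _ranks_per_domain = 4, so world ranks 0-3 belong to pint domain 0 and
      // world ranks 4-7 to pint domain 1.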
      // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
      MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
    } else {
      _local_pint_comm = MPI_COMM_WORLD;
    }
    MPI_Comm_rank(_local_pint_comm, &_rank);

#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
#endif

#endif
  }

  void run(int num_cycles) {
    if(!_cfg.isPintEnabled())
      run_cycles(0, num_cycles);
    else {
      PintDomain domain = setup_domain(num_cycles);
      setup_solvers(domain);
      init_parareal();
      run_parareal( _cfg.getPintIterations() );
    }
  }

  int getPintDomain() const { return _pint_domain; }
  int getRank() const { return _rank; }
  bool isPintEnabled() const { return _cfg.isPintEnabled(); }
  int getInteration() const { return _iteration; }

#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm getPintComm() const { return _local_pint_comm; }
#endif

private:
  void run_cycles(int start, int end){
    for (int cycle = start; cycle < end; cycle++)
      _scenario->runOneCouplingCycle(cycle);
  }

  bool isFirst() const { return _pint_domain == 0; }
  bool isLast() const { return _pint_domain == _cfg.getPintDomains()-1; }

  struct PintDomain {
    /** @brief number of time steps (coupling cycles) in this temporal domain */
    int size;
    /** @brief number of the first coupling cycle in this temporal domain (inclusive) */
    int minCycle;
    /** @brief number of the last coupling cycle of this temporal domain (exclusive) */
    int maxCycle;
  };

  PintDomain setup_domain(int num_cycles) const {
    PintDomain res;
    res.size = (int)( num_cycles / _cfg.getPintDomains() );
    res.minCycle = _pint_domain * res.size;
    res.maxCycle = res.minCycle + res.size;
    // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
    if( isLast() )
      res.maxCycle = num_cycles;
#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle
              << " and maxCycle " << res.maxCycle << std::endl;
#endif
    return res;
  }

  void setup_solvers(PintDomain domain){
    auto solver = dynamic_cast<Solver*>( _scenario->getSolver() );
    if(solver == nullptr){
      std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
        "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
      exit(EXIT_FAILURE);
    }
#ifdef PINT_DEBUG
    if(_cfg.getViscMultiplier() != 1.0)
      std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
#endif
    _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier() );
    _F = [this, solver, domain](const std::unique_ptr<State>& s){
      solver->setState(s, domain.minCycle);
      // TODO enable momentumTransfer on inner cells for MD equilibration steps here
      // TODO run MD equilibration here
      run_cycles(domain.minCycle, domain.maxCycle);
      // TODO double check that filter pipeline output ends up in solver.getState() here
      return solver->getState();
    };
    _G = [this, domain](const std::unique_ptr<State>& s){
      auto& G = *_supervisor;
      return G(s, domain.minCycle);
    };
    _u_0 = solver->getState();
  }

  void receive(std::unique_ptr<State>& state) const{
    if(!state) state = _u_0->clone();
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int source_rank = _world_rank - _ranks_per_domain;
    MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
#endif
  }

  void send(std::unique_ptr<State>& state) const {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int destination_rank = _world_rank + _ranks_per_domain;
    MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
#endif
  }

  void get_past_state(std::unique_ptr<State>& state) const {
    if( isFirst() )
      state = _u_0->clone();
    else
      receive(state);
  }

  void init_parareal(){
    get_past_state(_u_last_past);
    _u_last_future = _G(_u_last_past);
    if( !isLast() )
      send(_u_last_future);
  }
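  /**
   * Parareal iteration loop, sketched here for orientation (notation added for
   * illustration, following the standard parareal update and the paper cited above):
   * with G the coarse supervisor (_G), F the fine coupled solver (_F), k the
   * parareal iteration (_iteration) and n the start of this temporal domain,
   * each iteration evaluates
   *   u_{k+1}^{n+1} = G(u_{k+1}^{n}) + F(u_k^{n}) - G(u_k^{n}),
   * where "delta" below holds F(u_k^n) - G(u_k^n) and "prediction" holds G(u_{k+1}^n).
   */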
  void run_parareal(int iterations){
    while(_iteration < iterations){
      // Correction step
      auto delta = _F(_u_last_past) - _G(_u_last_past);

      _iteration++;

      // Prediction step
      get_past_state(_u_next_past);
      auto prediction = _G(_u_next_past);

      // Refinement step
      _u_next_future = prediction + delta;
      if( !isLast() )
        send(_u_next_future);

      // move for next iteration
      _u_last_past = std::move(_u_next_past);
      _u_last_future = std::move(_u_next_future);
    }

#ifdef PINT_DEBUG
    std::cout << "PINT_DEBUG: Finished all PinT iterations." << std::endl;
#endif
  }

  int _pint_domain;      // index of the time domain to which this process belongs
  int _rank;             // rank of the current process in _local_pint_comm
  int _ranks_per_domain; // number of MPI ranks in each time domain
  int _world_rank;       // rank of the current process in MPI_COMM_WORLD
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank
#endif
  coupling::configurations::TimeIntegrationConfiguration _cfg;
  Scenario* _scenario;
  std::unique_ptr<Solver> _supervisor; // the supervisor, i.e. the coarse predictor
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;

  // These objects represent coupled simulation states. They are used by the supervised parallel-in-time algorithm for its state operations.
  // "last" and "next" refer to two consecutive parareal iterations;
  // "past" and "future" refer to two points in simulation time, one pint time domain apart.
  std::unique_ptr<State> _u_0; // initial state
  std::unique_ptr<State> _u_last_past;
  std::unique_ptr<State> _u_last_future;
  std::unique_ptr<State> _u_next_past;
  std::unique_ptr<State> _u_next_future;

  int _iteration{0};
};
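// Minimal usage sketch (illustrative only; "MyScenario" and the way the configuration is
// obtained are assumptions, not part of this header). Any Scenario whose macroscopic
// solver implements coupling::interface::PintableMacroSolver can be driven like this:
//
//   coupling::configurations::MaMiCoConfiguration<3> cfg = /* parse MaMiCo XML configuration */;
//   MyScenario scenario;                  // hypothetical Scenario subclass
//   coupling::services::ParallelTimeIntegrationService<3> service(cfg, &scenario);
//   service.run(numberOfCouplingCycles);  // runs sequentially or with PinT, depending on cfg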