LCOV - code coverage report
Current view: top level - coupling/services - ParallelTimeIntegrationService.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 2 0.0 %
Date: 2025-06-25 11:26:37 Functions: 0 0 -

          Line data    Source code
       1             : // Copyright (C) 2023 Helmut Schmidt University
       2             : // This file is part of the Mamico project. For conditions of distribution
       3             : // and use, please see the copyright notice in Mamico's main folder
       4             : 
       5             : #pragma once
       6             : 
       7             : namespace coupling {
       8             : namespace services {
       9             : template <unsigned int dim> class ParallelTimeIntegrationService;
      10             : } // namespace services
      11             : } // namespace coupling
      12             : 
      13             : #include "coupling/configurations/MaMiCoConfiguration.h"
      14             : #include "coupling/interface/PintableMacroSolver.h"
      15             : #include "coupling/scenario/Scenario.h"
      16             : #include <functional>
      17             : 
      18             : // Convenience operators, to be able to write parareal iterations in a readable notation
      19             : using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
      20             : inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs){
      21             :     return *lhs + *rhs;
      22             : }
      23             : inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs){
      24             :     return *lhs - *rhs;
      25             : }
      26             : 
      27             : /** 
      28             :  * Service to manage timeloop of a coupled simulation scenario. Supports sequential or parallel-in-time integration using a Parareal variant,
      29             :  * as described in "Blumers, A. L., Li, Z., & Karniadakis, G. E. (2019). Supervised parallel-in-time algorithm for long-time Lagrangian 
      30             :  * simulations of stochastic dynamics: Application to hydrodynamics. Journal of Computational Physics, 393, 214-228".
      31             :  * 
      32             :  *  @author Piet Jarmatz
      33             :  */
      34             : template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
      35             : public:
      36             :     using State = coupling::interface::PintableMacroSolverState;
      37             :     using Solver = coupling::interface::PintableMacroSolver;
      38             : 
      39             :     ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario):
      40             :     _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0),
      41             :     _cfg( mamicoConfig.getTimeIntegrationConfiguration() ),
      42             :     _scenario(scenario) {
      43             :         #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
      44             :         int world_size;
      45             :         MPI_Comm_size(MPI_COMM_WORLD, &world_size);
      46             :         MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
      47             :         if(_cfg.isPintEnabled()){
      48             :             if(world_size % _cfg.getPintDomains() != 0){
      49             :                 if(_world_rank == 0){
      50             :                     std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
      51             :                         "MPI ranks not divisible by number of PinT subdomains!" << std::endl;
      52             :                     std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl;
      53             :                     std::cout << "Check your configuration." << std::endl;
      54             :                 }
      55             :                 exit(EXIT_FAILURE);
      56             :             }
      57             :             _ranks_per_domain = world_size / _cfg.getPintDomains();
      58             :             _pint_domain = _world_rank / _ranks_per_domain;
      59             :             // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
      60             :             MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
      61             :         } else {
      62             :             _local_pint_comm = MPI_COMM_WORLD;
      63             :         }
      64             :         MPI_Comm_rank(_local_pint_comm, &_rank);
      65             : 
      66             :         #ifdef PINT_DEBUG
      67             :         std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
      68             :         #endif
      69             : 
      70             :         #endif
      71             :     }
      72             : 
      73             :     void run(int num_cycles) {
      74             :         if(!_cfg.isPintEnabled())
      75             :             run_cycles(0, num_cycles);
      76             :         else {
      77             :             PintDomain domain = setup_domain(num_cycles);            
      78             :             setup_solvers(domain);
      79             :             init_parareal();
      80             :             run_parareal( _cfg.getPintIterations() );
      81             :         }
      82             :     }
      83             : 
      84             :     int getPintDomain() const { return _pint_domain; }
      85             :     int getRank() const { return _rank; }
      86           0 :     bool isPintEnabled() const { return _cfg.isPintEnabled(); }
      87           0 :     int getInteration() const { return _iteration; }
      88             : 
      89             :     #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
      90             :     MPI_Comm getPintComm() const { return _local_pint_comm; }
      91             :     #endif
      92             : 
      93             : private:
      94             :     void run_cycles(int start, int end){
      95             :         for (int cycle = start; cycle < end; cycle++)
      96             :             _scenario->runOneCouplingCycle(cycle);
      97             :     }
      98             : 
      99             :     bool isFirst() const { return _pint_domain == 0; }
     100             :     bool isLast() const { return _pint_domain == _cfg.getPintDomains()-1; }
     101             : 
     102             :     struct PintDomain {
     103             :         /** @brief number of time steps (coupling cycles) in this temporal domain */
     104             :         int size;
     105             :         /** @brief number of the first coupling cycle in this temporal domain (inclusive) */
     106             :         int minCycle;
     107             :         /** @brief number of the last coupling cycle of this temporal domain (exclusive) */
     108             :         int maxCycle;
     109             :     };
     110             : 
     111             :     PintDomain setup_domain(int num_cycles) const {
     112             :         PintDomain res;
     113             :         res.size = (int)( num_cycles / _cfg.getPintDomains() );
     114             :         res.minCycle = _pint_domain * res.size;
     115             :         res.maxCycle = res.minCycle + res.size;
     116             :         // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
     117             :         if( isLast() ) 
     118             :             res.maxCycle = num_cycles;
     119             :         #ifdef PINT_DEBUG
     120             :         std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle 
     121             :             << " and maxCycle " << res.maxCycle << std::endl;
     122             :         #endif
     123             :         return res;
     124             :     }
     125             : 
     126             :     void setup_solvers(PintDomain domain){
     127             :         auto solver = dynamic_cast<Solver*>( _scenario->getSolver() );
     128             :         if(solver == nullptr){
     129             :             std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " <<
     130             :                 "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
     131             :             exit(EXIT_FAILURE);
     132             :         }
     133             :         #ifdef PINT_DEBUG
     134             :         if(_cfg.getViscMultiplier() != 1.0)
     135             :             std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
     136             :         #endif
     137             :         _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier() );
     138             :         _F = [this, solver, domain](const std::unique_ptr<State>& s){
     139             :             solver->setState(s, domain.minCycle);
     140             :             // TODO enable momentumTransfer on inner cells for MD equilibration steps here
     141             :             // TODO run MD equilibration here
     142             :             run_cycles(domain.minCycle, domain.maxCycle);
     143             :             // TODO double check that filter pipeline output ends up in solver.getState() here
     144             :             return solver->getState();
     145             :         };
     146             :         _G = [this, domain](const std::unique_ptr<State>& s){
     147             :             auto& G = *_supervisor;
     148             :             return G(s, domain.minCycle);
     149             :         };
     150             :         _u_0 = solver->getState();
     151             :     }
     152             : 
     153             :     void receive(std::unique_ptr<State>& state) const{
     154             :         if(!state) state = _u_0->clone();
     155             :         #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     156             :         int source_rank = _world_rank - _ranks_per_domain;
     157             :         MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
     158             :         #endif
     159             :     }
     160             : 
     161             :     void send(std::unique_ptr<State>& state) const {
     162             :         #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     163             :         int destination_rank = _world_rank + _ranks_per_domain;
     164             :         MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
     165             :         #endif
     166             :     }
     167             : 
     168             :     void get_past_state(std::unique_ptr<State>& state) const {
     169             :         if( isFirst() ) 
     170             :             state = _u_0->clone();
     171             :         else
     172             :             receive(state);
     173             :     }
     174             : 
     175             :     void init_parareal(){
     176             :         get_past_state(_u_last_past);
     177             :         _u_last_future = _G(_u_last_past);
     178             :         if( !isLast() )
     179             :             send(_u_last_future);
     180             :     }
     181             : 
     182             :     void run_parareal(int iterations){
     183             :         while(_iteration < iterations){
     184             :             // Correction step
     185             :             auto delta = _F(_u_last_past) - _G(_u_last_past);
     186             : 
     187             :             _iteration++;
     188             : 
     189             :             // Prediction step
     190             :             get_past_state(_u_next_past);
     191             :             auto prediction = _G(_u_next_past);
     192             : 
     193             :             // Refinement step
     194             :             _u_next_future = prediction + delta;
     195             :             if( !isLast() )
     196             :                 send(_u_next_future);
     197             :             
     198             :             // move for next iteration
     199             :             _u_last_past   = std::move(_u_next_past);
     200             :             _u_last_future = std::move(_u_next_future);
     201             :         }
     202             : 
     203             :         #ifdef PINT_DEBUG
     204             :         std::cout << "PINT_DEBUG: Finished all PinT iterations. " << std::endl;
     205             :         #endif
     206             :     }
     207             : 
     208             :     int _pint_domain;       // the index of the time domain to which this process belongs to
     209             :     int _rank;              // rank of current process in _local_pint_comm
     210             :     int _ranks_per_domain;  // number of MPI ranks in each time domain
     211             :     int _world_rank;        // rank of current process in MPI_COMM_WORLD
     212             :     #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     213             :     MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank
     214             :     #endif
     215             :     coupling::configurations::TimeIntegrationConfiguration _cfg;
     216             :     Scenario* _scenario;
     217             :     std::unique_ptr<Solver> _supervisor;  // The supervisor, i.e. the coarse predictor
     218             :     std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
     219             :     std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;
     220             : 
     221             :     // These objects represent coupled simulation states. There are used by the supervised parallel in time algorithm for operations
     222             :     // "last" and "next" describe two consecutive parareal iterations
     223             :     // "past" and "future" describe two points in simulation time, one pint time domain apart
     224             :     std::unique_ptr<State> _u_0;       // initial state
     225             :     std::unique_ptr<State> _u_last_past;
     226             :     std::unique_ptr<State> _u_last_future;
     227             :     std::unique_ptr<State> _u_next_past;
     228             :     std::unique_ptr<State> _u_next_future;
     229             : 
     230             :     int _iteration{0};
     231             : };

Generated by: LCOV version 1.14