LCOV - code coverage report
Current view: top level - coupling/services - ParallelTimeIntegrationService.h (source / functions) Hit Total Coverage
Test: coverage.info Lines: 0 2 0.0 %
Date: 2026-02-16 14:39:39 Functions: 0 0 -

          Line data    Source code
       1             : // Copyright (C) 2023 Helmut Schmidt University
       2             : // This file is part of the Mamico project. For conditions of distribution
       3             : // and use, please see the copyright notice in Mamico's main folder
       4             : 
       5             : #pragma once
       6             : 
       7             : namespace coupling {
       8             : namespace services {
       9             : template <unsigned int dim> class ParallelTimeIntegrationService;
      10             : } // namespace services
      11             : } // namespace coupling
      12             : 
      13             : #include "coupling/configurations/MaMiCoConfiguration.h"
      14             : #include "coupling/interface/PintableMacroSolver.h"
      15             : #include "coupling/scenario/Scenario.h"
      16             : #include <functional>
      17             : 
      18             : // Convenience operators, to be able to write parareal iterations in a readable notation
      19             : using st_ptr = std::unique_ptr<coupling::interface::PintableMacroSolverState>;
      20             : inline st_ptr operator+(const st_ptr& lhs, const st_ptr& rhs) { return *lhs + *rhs; }
      21             : inline st_ptr operator-(const st_ptr& lhs, const st_ptr& rhs) { return *lhs - *rhs; }
      22             : 
      23             : /**
      24             :  * Service to manage timeloop of a coupled simulation scenario. Supports sequential or parallel-in-time integration using a Parareal variant,
      25             :  * as described in "Blumers, A. L., Li, Z., & Karniadakis, G. E. (2019). Supervised parallel-in-time algorithm for long-time Lagrangian
      26             :  * simulations of stochastic dynamics: Application to hydrodynamics. Journal of Computational Physics, 393, 214-228".
      27             :  *
      28             :  *  @author Piet Jarmatz
      29             :  */
      30             : template <unsigned int dim> class coupling::services::ParallelTimeIntegrationService {
      31             : public:
      32             :   using State = coupling::interface::PintableMacroSolverState;
      33             :   using Solver = coupling::interface::PintableMacroSolver;
      34             : 
      35             :   ParallelTimeIntegrationService(coupling::configurations::MaMiCoConfiguration<dim> mamicoConfig, Scenario* scenario)
      36             :       : _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0), _cfg(mamicoConfig.getTimeIntegrationConfiguration()), _scenario(scenario) {
      37             : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
      38             :     int world_size;
      39             :     MPI_Comm_size(MPI_COMM_WORLD, &world_size);
      40             :     MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
      41             :     if (_cfg.isPintEnabled()) {
      42             :       if (world_size % _cfg.getPintDomains() != 0) {
      43             :         if (_world_rank == 0) {
      44             :           std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: " << "MPI ranks not divisible by number of PinT subdomains!" << std::endl;
      45             :           std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of number-subdomains." << std::endl;
      46             :           std::cout << "Check your configuration." << std::endl;
      47             :         }
      48             :         exit(EXIT_FAILURE);
      49             :       }
      50             :       _ranks_per_domain = world_size / _cfg.getPintDomains();
      51             :       _pint_domain = _world_rank / _ranks_per_domain;
      52             :       // This initializes _local_pint_comm by splitting MPI_COMM_WORLD into getPintDomains() disjoint communicators
      53             :       MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
      54             :     } else {
      55             :       _local_pint_comm = MPI_COMM_WORLD;
      56             :     }
      57             :     MPI_Comm_rank(_local_pint_comm, &_rank);
      58             : 
      59             : #ifdef PINT_DEBUG
      60             :     std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank << " in pint domain " << _pint_domain << std::endl;
      61             : #endif
      62             : 
      63             : #endif
      64             :   }
      65             : 
      66             :   void run(int num_cycles) {
      67             :     if (!_cfg.isPintEnabled())
      68             :       run_cycles(0, num_cycles);
      69             :     else {
      70             :       PintDomain domain = setup_domain(num_cycles);
      71             :       setup_solvers(domain);
      72             :       init_parareal();
      73             :       run_parareal(_cfg.getPintIterations());
      74             :     }
      75             :   }
      76             : 
      77             :   int getPintDomain() const { return _pint_domain; }
      78             :   int getRank() const { return _rank; }
      79           0 :   bool isPintEnabled() const { return _cfg.isPintEnabled(); }
      80           0 :   int getIteration() const { return _iteration; }
      81             : 
      82             : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
      83             :   MPI_Comm getPintComm() const { return _local_pint_comm; }
      84             : #endif
      85             : 
      86             : private:
      87             :   void run_cycles(int start, int end) {
      88             :     for (int cycle = start; cycle < end; cycle++)
      89             :       _scenario->runOneCouplingCycle(cycle);
      90             :   }
      91             : 
      92             :   bool isFirst() const { return _pint_domain == 0; }
      93             :   bool isLast() const { return _pint_domain == _cfg.getPintDomains() - 1; }
      94             : 
      95             :   struct PintDomain {
      96             :     /** @brief number of time steps (coupling cycles) in this temporal domain */
      97             :     int size;
      98             :     /** @brief number of the first coupling cycle in this temporal domain (inclusive) */
      99             :     int minCycle;
     100             :     /** @brief number of the last coupling cycle of this temporal domain (exclusive) */
     101             :     int maxCycle;
     102             :   };
     103             : 
     104             :   PintDomain setup_domain(int num_cycles) const {
     105             :     PintDomain res;
     106             :     res.size = (int)(num_cycles / _cfg.getPintDomains());
     107             :     res.minCycle = _pint_domain * res.size;
     108             :     res.maxCycle = res.minCycle + res.size;
     109             :     // In case num_cycles is not divisible by _cfg.getPintDomains(), the last domain gets the remainder
     110             :     if (isLast())
     111             :       res.maxCycle = num_cycles;
     112             : #ifdef PINT_DEBUG
     113             :     if (_rank == 0) {
     114             :       std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle << " and maxCycle " << res.maxCycle << std::endl;
     115             :     }
     116             : #endif
     117             :     return res;
     118             :   }
     119             : 
     120             :   void setup_solvers(PintDomain domain) {
     121             :     auto solver = dynamic_cast<Solver*>(_scenario->getSolver());
     122             :     if (solver == nullptr) {
     123             :       std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
     124             :                 << "macroscopic solver is not pintable (= not compatible with parallel in time coupling)" << std::endl;
     125             :       exit(EXIT_FAILURE);
     126             :     }
     127             : #ifdef PINT_DEBUG
     128             :     if (_cfg.getViscMultiplier() != 1.0)
     129             :       if (_world_rank == 0) {
     130             :         std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
     131             :       }
     132             : #endif
     133             :     _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier());
     134             :     _F = [this, solver, domain](const std::unique_ptr<State>& s) {
     135             :       solver->setState(s, domain.minCycle);
     136             :       if (domain.minCycle > 0)
     137             :         _scenario->equilibrateMicro();
     138             :       run_cycles(domain.minCycle, domain.maxCycle);
     139             :       return solver->getState();
     140             :     };
     141             :     _G = [this, domain](const std::unique_ptr<State>& s) {
     142             :       auto& G = *_supervisor;
     143             :       return G(s, domain.minCycle);
     144             :     };
     145             :     _u_0 = solver->getState();
     146             :   }
     147             : 
     148             :   void receive(std::unique_ptr<State>& state) const {
     149             :     if (!state)
     150             :       state = _u_0->clone();
     151             : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     152             :     int source_rank = _world_rank - _ranks_per_domain;
     153             :     MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
     154             : #endif
     155             :   }
     156             : 
     157             :   void send(std::unique_ptr<State>& state) const {
     158             : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     159             :     int destination_rank = _world_rank + _ranks_per_domain;
     160             :     MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
     161             : #endif
     162             :   }
     163             : 
     164             :   void get_past_state(std::unique_ptr<State>& state) const {
     165             :     if (isFirst())
     166             :       state = _u_0->clone();
     167             :     else
     168             :       receive(state);
     169             :   }
     170             : 
     171             :   void init_parareal() {
     172             :     get_past_state(_u_last_past);
     173             :     _u_last_future = _G(_u_last_past);
     174             :     if (!isLast())
     175             :       send(_u_last_future);
     176             :   }
     177             : 
     178             :   void run_parareal(int iterations) {
     179             :     while (_iteration < iterations) {
     180             :       // Correction step
     181             :       auto delta = _F(_u_last_past) - _G(_u_last_past);
     182             :       // Alternative variant, together with second _u_last_future line below.
     183             :       // Better scalability, should yield the same results. TODO test and verify.
     184             :       // auto delta = _F(_u_last_past) - _u_last_future;
     185             : 
     186             :       _iteration++;
     187             : 
     188             :       // Prediction step
     189             :       get_past_state(_u_next_past);
     190             :       auto prediction = _G(_u_next_past);
     191             : 
     192             :       // Refinement step
     193             :       _u_next_future = prediction + delta;
     194             :       if (!isLast())
     195             :         send(_u_next_future);
     196             : 
     197             :       // move for next iteration
     198             :       _u_last_past = std::move(_u_next_past);
     199             :       _u_last_future = std::move(_u_next_future);
     200             :       // Alternative variant
     201             :       //_u_last_future = std::move(prediction);
     202             :     }
     203             : 
     204             : #ifdef PINT_DEBUG
     205             :     if (_world_rank == 0) {
     206             :       std::cout << "PINT_DEBUG: Finished all PinT iterations. " << std::endl;
     207             :     }
     208             : #endif
     209             :   }
     210             : 
     211             :   int _pint_domain;      // the index of the time domain to which this process belongs to
     212             :   int _rank;             // rank of current process in _local_pint_comm
     213             :   int _ranks_per_domain; // number of MPI ranks in each time domain
     214             :   int _world_rank;       // rank of current process in MPI_COMM_WORLD
     215             : #if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
     216             :   MPI_Comm _local_pint_comm; // the communicator of the local time domain of this rank
     217             : #endif
     218             :   coupling::configurations::TimeIntegrationConfiguration _cfg;
     219             :   Scenario* _scenario;
     220             :   std::unique_ptr<Solver> _supervisor; // The supervisor, i.e. the coarse predictor
     221             :   std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F;
     222             :   std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G;
     223             : 
     224             :   // These objects represent coupled simulation states. There are used by the supervised parallel in time algorithm for operations
     225             :   // "last" and "next" describe two consecutive parareal iterations
     226             :   // "past" and "future" describe two points in simulation time, one pint time domain apart
     227             :   std::unique_ptr<State> _u_0; // initial state
     228             :   std::unique_ptr<State> _u_last_past;
     229             :   std::unique_ptr<State> _u_last_future;
     230             :   std::unique_ptr<State> _u_next_past;
     231             :   std::unique_ptr<State> _u_next_future;
     232             : 
     233             :   int _iteration{0};
     234             : };

Generated by: LCOV version 1.14