  _pint_domain(0), _rank(0), _ranks_per_domain(1), _world_rank(0), _iteration(0) {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int world_size;
    MPI_Comm_size(MPI_COMM_WORLD, &world_size);
    MPI_Comm_rank(MPI_COMM_WORLD, &_world_rank);
    if (_cfg.isPintEnabled()) {
      if (world_size % _cfg.getPintDomains() != 0) {
        std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
                  << "MPI ranks not divisible by number of PinT subdomains!" << std::endl;
        std::cout << "When PinT is used, the number of required MPI ranks increases by a factor of the number of subdomains." << std::endl;
        std::cout << "Check your configuration." << std::endl;
        exit(EXIT_FAILURE);
      }
      _ranks_per_domain = world_size / _cfg.getPintDomains();
      _pint_domain = _world_rank / _ranks_per_domain;
      // Split MPI_COMM_WORLD into one communicator per temporal (PinT) domain, ordered by world rank
      MPI_Comm_split(MPI_COMM_WORLD, _pint_domain, _world_rank, &_local_pint_comm);
    } else {
      _local_pint_comm = MPI_COMM_WORLD;
    }
    MPI_Comm_rank(_local_pint_comm, &_rank);
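    // Illustrative example of the resulting rank layout: with 8 MPI ranks and
    // getPintDomains() == 2, _ranks_per_domain == 4, so world ranks 0-3 form
    // PinT domain 0 and ranks 4-7 form domain 1; within each domain the local
    // rank _rank runs 0-3 in _local_pint_comm.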
    std::cout << "PINT_DEBUG: world_rank " << _world_rank << " is rank " << _rank
              << " in pint domain " << _pint_domain << std::endl;
#endif
  }
  void run(int num_cycles) {
    if (!_cfg.isPintEnabled())
      run_cycles(0, num_cycles);
    else {
      PintDomain domain = setup_domain(num_cycles);
      setup_solvers(domain);
      init_parareal();
      run_parareal(_cfg.getPintIterations());
    }
  }
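  // Usage sketch (hypothetical call site; the constructor arguments shown here
  // are assumptions, not part of this listing):
  //   ParallelTimeIntegrationService service(cfg, &scenario);
  //   service.run(totalCouplingCycles); // plain time stepping, or parareal when PinT is enabled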
  int getPintDomain() const { return _pint_domain; }
  int getRank() const { return _rank; }
  bool isPintEnabled() const { return _cfg.isPintEnabled(); }
  int getIteration() const { return _iteration; }

#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm getPintComm() const { return _local_pint_comm; }
#endif
private:
  void run_cycles(int start, int end) {
    for (int cycle = start; cycle < end; cycle++)
      _scenario->runOneCouplingCycle(cycle);
  }
  bool isFirst() const { return _pint_domain == 0; }
  bool isLast() const { return _pint_domain == _cfg.getPintDomains() - 1; }
  PintDomain setup_domain(int num_cycles) const {
    PintDomain res;
    res.size = (int)(num_cycles / _cfg.getPintDomains());
    // Equal partition: domain d covers cycles [d*size, (d+1)*size)
    res.minCycle = res.size * _pint_domain;
    res.maxCycle = res.minCycle + res.size;
    std::cout << "PINT_DEBUG: _pint_domain " << _pint_domain << " has minCycle " << res.minCycle
              << " and maxCycle " << res.maxCycle << std::endl;
    return res;
  }
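  // Worked example, assuming the equal partitioning above: num_cycles == 100
  // and getPintDomains() == 4 give size == 25; domain 2 then spans the
  // coupling cycles [50, 75).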
  void setup_solvers(PintDomain domain) {
    // The macroscopic solver must implement the PinT-capable Solver interface
    auto solver = dynamic_cast<Solver*>(_scenario->getSolver());
    if (solver == nullptr) {
      std::cout << "ERROR coupling::services::ParallelTimeIntegrationService: "
                << "macroscopic solver is not pintable (= not compatible with parallel-in-time coupling)" << std::endl;
      exit(EXIT_FAILURE);
    }
    if (_cfg.getViscMultiplier() != 1.0)
      std::cout << "PINT_DEBUG: Starting supervisor with viscosity modified by " << _cfg.getViscMultiplier() << std::endl;
    _supervisor = solver->getSupervisor(domain.size, _cfg.getViscMultiplier());
    // Fine propagator F: run the full coupled simulation over this domain's cycles
    _F = [this, solver, domain](const std::unique_ptr<State>& s) {
      solver->setState(s, domain.minCycle);
      run_cycles(domain.minCycle, domain.maxCycle);
      return solver->getState();
    };
    // Coarse propagator G: advance the state with the cheaper supervisor solver
    _G = [this, domain](const std::unique_ptr<State>& s) {
      auto& G = *_supervisor;
      return G(s, domain.minCycle);
    };
    _u_0 = solver->getState();
  }
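  // Parareal background (standard formulation, stated here for orientation):
  // with fine propagator F and coarse propagator G over one temporal domain,
  // the iteration is
  //   u_{k+1}^{n+1} = G(u_{k+1}^{n}) + F(u_k^{n}) - G(u_k^{n}),
  // where n indexes temporal domains and k parareal iterations; run_parareal()
  // below realizes exactly this update.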
  void receive(std::unique_ptr<State>& state) const {
    if (!state) state = _u_0->clone(); // allocate a buffer of the right size on first use
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int source_rank = _world_rank - _ranks_per_domain;
    MPI_Recv(state->getData(), state->getSizeBytes(), MPI_BYTE, source_rank, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
#endif
  }
  void send(std::unique_ptr<State>& state) const {
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
    int destination_rank = _world_rank + _ranks_per_domain;
    MPI_Send(state->getData(), state->getSizeBytes(), MPI_BYTE, destination_rank, 0, MPI_COMM_WORLD);
#endif
  }
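  // Since ranks are laid out contiguously per domain, offsetting the world
  // rank by _ranks_per_domain pairs each rank with the rank holding the same
  // local index in the neighboring temporal domain, so states travel along a
  // pipeline domain 0 -> 1 -> ... -> getPintDomains()-1.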
  void get_past_state(std::unique_ptr<State>& state) const {
    if (isFirst())
      state = _u_0->clone(); // the first domain starts from the initial condition
    else
      receive(state); // all other domains receive from their predecessor
  }
  void init_parareal() {
    get_past_state(_u_last_past);
    _u_last_future = _G(_u_last_past);
    if (!isLast())
      send(_u_last_future);
  }
  void run_parareal(int iterations) {
    while (_iteration < iterations) {
      // Correction term from the previous iterate: F(u_k^n) - G(u_k^n)
      auto delta = _F(_u_last_past) - _G(_u_last_past);
      // Updated state at the start of this domain, u_{k+1}^n
      get_past_state(_u_next_past);
      auto prediction = _G(_u_next_past);
      // Parareal update: u_{k+1}^{n+1} = G(u_{k+1}^n) + F(u_k^n) - G(u_k^n)
      _u_next_future = prediction + delta;
      if (!isLast())
        send(_u_next_future);
      _u_last_past = std::move(_u_next_past);
      _u_last_future = std::move(_u_next_future);
      _iteration++;
    }
    std::cout << "PINT_DEBUG: Finished all PinT iterations." << std::endl;
  }
  int _pint_domain;
  int _rank;
  int _ranks_per_domain;
  int _world_rank;
  int _iteration;
#if (COUPLING_MD_PARALLEL == COUPLING_MD_YES)
  MPI_Comm _local_pint_comm;
#endif
  coupling::configurations::TimeIntegrationConfiguration _cfg;
  std::unique_ptr<Solver> _supervisor; // coarse solver used as the parareal supervisor
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _F; // fine propagator
  std::function<std::unique_ptr<State>(const std::unique_ptr<State>&)> _G; // coarse propagator
  std::unique_ptr<State> _u_0;           // initial state of the simulation
  std::unique_ptr<State> _u_last_past;   // u_k^n
  std::unique_ptr<State> _u_last_future; // u_k^{n+1}
  std::unique_ptr<State> _u_next_past;   // u_{k+1}^n
  std::unique_ptr<State> _u_next_future; // u_{k+1}^{n+1}
  // Members of the PintDomain helper struct (defined earlier in ParallelTimeIntegrationService.h):
  //   int minCycle; // number of the first coupling cycle in this temporal domain (inclusive)
  //   int maxCycle; // number of the last coupling cycle of this temporal domain (exclusive)
  //   int size;     // number of time steps (coupling cycles) in this temporal domain