15#include <Comm_Group_MPI.h>
16#include <petsc_for_kernel.h>
17#include <communications.h>
20#include <Perf_counters.h>
// Static state shared by every Comm_Group_MPI instance.
// NOTE(review): this listing is an extraction with the original line numbers fused to the
// start of each line and some original lines missing; comments below only describe what is visible.
// Status/request arrays used by send_recv_start()/send_recv_finish(); allocated by init_group_trio().
30MPI_Status * Comm_Group_MPI::mpi_status_ = 0;
31MPI_Request * Comm_Group_MPI::mpi_requests_ = 0;
// Number of requests currently posted. Point-to-point and collective calls assert it is negative,
// send_recv_start() asserts it is negative on entry and send_recv_finish() asserts it is >= 0.
32int Comm_Group_MPI::mpi_nrequests_ = -1;
// Capacity of mpi_requests_/mpi_status_; -1 until init_group_trio() sets it to 2 * nbproc.
33int Comm_Group_MPI::mpi_maxrequests_ = -1;
// Message size stored by send_recv_start() and reported to the statistics by send_recv_finish().
34int Comm_Group_MPI::current_msg_size_;
// Communicator used as the TRUST "world"; may be replaced with set_trio_u_world() before initialization.
35MPI_Comm Comm_Group_MPI::trio_u_world_ = MPI_COMM_WORLD;
// Whether init_group_trio() calls MPI_Init itself (see set_must_mpi_initialize()).
37bool Comm_Group_MPI::must_mpi_initialize_ =
true;
/*! @brief Report a failed MPI call and abort the run.
 *
 * Writes the numeric error code to Cerr and to the journal, then the message returned by
 * MPI_Error_string, and finally calls MPI_Abort on MPI_COMM_WORLD.
 * NOTE(review): `length` is used below, but its declaration is not part of this listing
 * (several original lines of this function are missing) - confirm against the real file.
 */
46void mpi_print_error(
int error_code)
48 Cerr <<
"mpi_error in Comm_Group_MPI : error_code = " << error_code << finl;
49 Process::Journal() <<
"mpi_error in Comm_Group_MPI : error_code = " << error_code << finl;
// Translate the error code into the MPI implementation's text message.
51 char message[MPI_MAX_ERROR_STRING];
52 MPI_Error_string(error_code, message, & length);
55 Cerr << message << finl;
// Abort on MPI_COMM_WORLD (not trio_u_world_).
62 MPI_Abort(MPI_COMM_WORLD,-1);
/*! @brief Check the return code of an MPI call.
 *
 * Does nothing for MPI_SUCCESS; any other code is passed to mpi_print_error(), which aborts.
 * Every MPI call in this file is expected to be wrapped as mpi_error(MPI_Xxx(...)).
 */
71inline void mpi_error(
int error_code)
73 if (error_code != MPI_SUCCESS)
74 mpi_print_error(error_code);
// Default constructor member-initializer list (the signature line is not part of this listing):
// the group and communicator handles start as null. They are filled in later by the
// group-initialisation methods (init_group_trio, init_comm_on_numa_node, ...).
100 : mpi_group_(MPI_GROUP_NULL),
101 mpi_comm_(MPI_COMM_NULL),
// Presumably the body of ~Comm_Group_MPI() (the signature and braces are not part of this listing).
// When this group's communicator is trio_u_world_, the shared static status/request arrays are
// released. The exact extent of this if-block is not visible here - confirm against the real file.
112 if ((mpi_comm_!=MPI_COMM_NULL) && (mpi_comm_ == trio_u_world_))
114 delete [] mpi_status_;
// Free every request handle that is still non-null before deleting the request array.
117 for (
int r=0; r<mpi_maxrequests_; r++)
119 if(mpi_requests_[r]!=MPI_REQUEST_NULL)
121 MPI_Request_free(&(mpi_requests_[r]));
125 delete [] mpi_requests_;
// Free this object's own communicator and group. MPI resets a freed handle to its NULL value,
// which is what the asserts check.
131 if (mpi_comm_!=MPI_COMM_NULL)
134 mpi_error(MPI_Comm_free(&mpi_comm_));
135 assert(mpi_comm_==MPI_COMM_NULL);
137 if (mpi_group_!=MPI_GROUP_NULL)
139 mpi_error(MPI_Group_free( &mpi_group_));
140 assert(mpi_group_==MPI_GROUP_NULL);
// Presumably the body of abort() (its signature is not part of this listing): aborts through
// MPI_Abort on trio_u_world_ with error code -1.
152 MPI_Abort(trio_u_world_,-1);
/*! @brief Shared implementation of mp_collective_op for one element type.
 *
 * TYP_IDX selects the MPI datatype: 1 -> MPI_INT, 2 -> MPI_LONG, 3 -> MPI_DOUBLE, 4 -> MPI_FLOAT.
 * The callers use int/1, trustIdType/2, double/3 and float/4.
 * NOTE(review): mapping trustIdType to MPI_LONG assumes trustIdType has the size of long.
 * Sum, min and max are done with MPI_Allreduce over mpi_comm_. They are timed with the
 * mpi_sumdouble / mpi_mindouble / mpi_maxdouble counters whatever the element type.
 * The call with nop = -1 hands the operation to internal_collective().
 * The function signature, the switch on `op` and the declarations of s/clock/clock_on are not
 * part of this listing.
 */
157template <
typename _TYPE_,
int TYP_IDX>
160 static_assert(TYP_IDX >= 1 && TYP_IDX <= 4,
"Invalid type index!");
161 MPI_Datatype mpi_typ = TYP_IDX==1 ? MPI_INT : (TYP_IDX==2 ? MPI_LONG : (TYP_IDX==3 ? MPI_DOUBLE : MPI_FLOAT));
// Sum reduction.
168 statistics().begin_count(STD_COUNTERS::mpi_sumdouble);
169 mpi_error(MPI_Allreduce(x, resu, n, mpi_typ, MPI_SUM, mpi_comm_));
170 if (clock_on && statistics().is_running(STD_COUNTERS::mpi_sumdouble))
171 s = statistics().get_time_since_last_open(STD_COUNTERS::mpi_sumdouble);
172 statistics().end_count(STD_COUNTERS::mpi_sumdouble);
// Min reduction.
175 statistics().begin_count(STD_COUNTERS::mpi_mindouble);
176 mpi_error(MPI_Allreduce(x, resu, n, mpi_typ, MPI_MIN, mpi_comm_));
177 if (clock_on && statistics().is_running(STD_COUNTERS::mpi_mindouble))
178 s = statistics().get_time_since_last_open(STD_COUNTERS::mpi_mindouble);
179 statistics().end_count(STD_COUNTERS::mpi_mindouble);
// Max reduction.
182 statistics().begin_count(STD_COUNTERS::mpi_maxdouble);
183 mpi_error(MPI_Allreduce(x, resu, n, mpi_typ, MPI_MAX, mpi_comm_));
184 if (clock_on && statistics().is_running(STD_COUNTERS::mpi_maxdouble))
185 s = statistics().get_time_since_last_open(STD_COUNTERS::mpi_maxdouble);
186 statistics().end_count(STD_COUNTERS::mpi_maxdouble);
// Single operation applied to all n elements (nop = -1), at level 0.
189 internal_collective(x, resu, n, &op, -1 , 0 );
// Optional timing trace: prints s * 0.001 labelled "ms" together with the reduction name.
195 std::string mpi_reduce =
"mp_sum";
198 printf(
"%s %7.3f ms [MPI] %s\n", clock.c_str(), 0.001 * s, mpi_reduce.c_str());
// Bodies of the public mp_collective_op overloads (their signatures are not part of this listing).
// Overloads taking a single `op` dispatch to mp_collective_op_template with the type index
// matching the element type (double/3, float/4, int/1, trustIdType/2).
207 mp_collective_op_template<double, 3 >(x, resu, n, op);
214 mp_collective_op_template<float, 4 >(x, resu, n, op);
221 mp_collective_op_template<int, 1 >(x, resu, n, op);
229 mp_collective_op_template<trustIdType, 2 >(x, resu, n, op);
// Overloads where `op` is passed through as-is, presumably an array of n operations, forward to
// internal_collective() with nop = n, i.e. one operation per element.
239 internal_collective(x, resu, n, op, n , 0 );
248 internal_collective(x, resu, n, op, n , 0 );
257 internal_collective(x, resu, n, op, n , 0 );
267 internal_collective(x, resu, n, op, n , 0 );
/*! @brief Tagged barrier that also checks communication synchronisation
 *  (body of barrier(int tag); the signature is not part of this listing).
 *
 * `tag` must lie in [0, max_tag).
 * A value tag_complet, computed on lines missing from this listing, is reduced with both MPI_MIN
 * and MPI_MAX over mpi_comm_. If either extreme differs from the local value, the processes
 * reached this barrier with different tags, and an error is printed to Cerr and to the journal.
 * The synchronisation itself is MPI_Barrier on mpi_comm_, timed with the mpi_barrier counter.
 */
284 static const int max_tag = 32;
285 statistics().begin_count(STD_COUNTERS::mpi_barrier);
286 assert(tag >= 0 && tag < max_tag);
293 int min_tag, amax_tag;
294 mpi_error(MPI_Allreduce(& tag_complet, & min_tag, 1, MPI_ENTIER, MPI_MIN, mpi_comm_));
295 mpi_error(MPI_Allreduce(& tag_complet, & amax_tag, 1, MPI_ENTIER, MPI_MAX, mpi_comm_));
// All processes agree only if the global min and max both equal the local value.
296 if (min_tag != tag_complet || amax_tag != tag_complet)
298 Cerr <<
"Error in Comm_Group_MPI::barrier(int tag)\n";
299 Cerr <<
" the tag is not identical on all the processes.\n";
300 Cerr <<
" (Loss of communications synchronisation)." << finl;
301 Process::Journal() <<
"Comm_Group_MPI::barrier\n Error : tag = " << tag << finl;
309 mpi_error(MPI_Barrier(mpi_comm_));
311 statistics().end_count(STD_COUNTERS::mpi_barrier);
/*! @brief send_recv_start: post non-blocking receives and sends for an exchange.
 *  (The first signature line and several body lines are not part of this listing.)
 *
 * send_size/recv_size are byte counts: each must be a multiple of `divisor`, and
 * sz / divisor elements of `datatype` are transferred. `datatype` defaults to MPI_CHAR
 * and is switched to MPI_ENTIER / MPI_DOUBLE / MPI_FLOAT, with the matching sizeof as divisor,
 * on lines where the selection logic is missing (presumably driven by the typehint).
 * Receives are posted before sends. Each request is stored in mpi_requests_[mpi_nrequests_].
 * The index increment, and the datatype argument lines of MPI_Irecv/MPI_Isend, are not part of
 * this listing.
 * The exchange must be completed with send_recv_finish().
 */
328 const ArrOfInt& send_size,
329 const char *
const *
const send_buffers,
330 const ArrOfInt& recv_list,
331 const ArrOfInt& recv_size,
332 char *
const *
const recv_buffers,
336 statistics().begin_count(STD_COUNTERS::mpi_sendrecv);
// No other exchange may be in progress.
337 assert(mpi_nrequests_ < 0);
345 MPI_Datatype datatype = MPI_CHAR;
// NOTE(review): as extracted, this compares sizeof(int) with itself and checks nothing;
// the original presumably compared two different integer types - confirm.
346 assert(
sizeof(
int) ==
sizeof(
int));
353 divisor =
sizeof(int);
354 datatype = MPI_ENTIER;
357 divisor =
sizeof(double);
358 datatype = MPI_DOUBLE;
361 divisor =
sizeof(float);
362 datatype = MPI_FLOAT;
// Post one MPI_Irecv per source rank in recv_list.
371 for (i = 0; i < n; i++)
373 int source = recv_list[i];
374 int sz = recv_size[i];
376 assert(source >= 0 && source <
nproc());
377 assert(mpi_nrequests_ < mpi_maxrequests_);
378 assert(sz % divisor == 0);
379 assert(mpi_requests_[mpi_nrequests_]==MPI_REQUEST_NULL);
380 mpi_error(MPI_Irecv(recv_buffers[i], sz / divisor,
382 source, tag, mpi_comm_,
383 & mpi_requests_[mpi_nrequests_]));
// Post one MPI_Isend per destination rank in send_list.
// NOTE(review): unlike the receive loop, no assert checks that the request slot is MPI_REQUEST_NULL.
388 for (i = 0; i < n; i++)
390 int dest = send_list[i];
391 int sz = send_size[i];
393 assert(dest >= 0 && dest <
nproc());
394 assert(mpi_nrequests_ < mpi_maxrequests_);
395 assert(sz % divisor == 0);
396 mpi_error(MPI_Isend((
char*) send_buffers[i], sz / divisor,
398 dest, tag, mpi_comm_,
399 & mpi_requests_[mpi_nrequests_]));
// Remembered for the statistics reported in send_recv_finish().
402 current_msg_size_ = msg_size;
/*! @brief send_recv_finish: wait for every request posted by send_recv_start().
 *
 * Uses MPI_Waitall on the first mpi_nrequests_ entries of mpi_requests_/mpi_status_.
 * It may print the elapsed time (the printing condition is not part of this listing), then closes
 * the mpi_sendrecv counter with the request count and the stored message size.
 * NOTE(review): the reset of mpi_nrequests_ after completion is not visible in this listing.
 */
412 assert(mpi_nrequests_ >= 0);
413 mpi_error(MPI_Waitall(mpi_nrequests_, mpi_requests_, mpi_status_));
417 double ms = 0.001 * statistics().get_time_since_last_open(STD_COUNTERS::mpi_sendrecv) ;
418 printf(
"%s %7.3f ms [MPI] Comm_Group_MPI::exchange\n", clock.c_str(), ms);
421 statistics().end_count(STD_COUNTERS::mpi_sendrecv,mpi_nrequests_,current_msg_size_);
/*! @brief send: blocking send of `size` bytes (MPI_CHAR) to rank `dest` of mpi_comm_.
 *
 * Must not be called while a send_recv_start() exchange is pending (mpi_nrequests_ < 0).
 * Either MPI_Ssend or MPI_Send is used. The switch between them is on lines missing from this
 * listing (presumably a compile-time or debug option - confirm).
 */
448 statistics().begin_count(STD_COUNTERS::mpi_send);
449 assert(mpi_nrequests_ < 0);
451 assert(dest >= 0 && dest <
nproc());
455 mpi_error(MPI_Ssend ((
void*)buffer, size, MPI_CHAR, dest, tag, mpi_comm_));
457 mpi_error(MPI_Send ((
void*)buffer, size, MPI_CHAR, dest, tag, mpi_comm_));
458 statistics().end_count(STD_COUNTERS::mpi_send,1,size);
/*! @brief recv: blocking receive of up to `size` bytes (MPI_CHAR) from rank `source` of mpi_comm_.
 *
 * Must not be called while a send_recv_start() exchange is pending (mpi_nrequests_ < 0).
 * The declarations of `source` and `status` are not part of this listing.
 */
468 statistics().begin_count(STD_COUNTERS::mpi_recv);
469 assert(mpi_nrequests_ < 0);
472 assert(source >= 0 && source <
nproc());
473 mpi_error(MPI_Recv (buffer, size, MPI_CHAR, source, tag, mpi_comm_, & status));
474 statistics().end_count(STD_COUNTERS::mpi_recv,1,size);
/*! @brief broadcast: broadcast `size` bytes from rank `pe_source` to every rank of mpi_comm_.
 *
 * Must not be called while a send_recv_start() exchange is pending (mpi_nrequests_ < 0).
 */
481 statistics().begin_count(STD_COUNTERS::mpi_bcast);
482 assert(mpi_nrequests_ < 0);
483 mpi_error(MPI_Bcast (buffer, size, MPI_CHAR, pe_source, mpi_comm_));
484 statistics().end_count(STD_COUNTERS::mpi_bcast,1,size);
/*! @brief all_to_all: every rank sends a block of `data_size` bytes to every rank (MPI_Alltoall).
 *
 * src_buffer and dest_buffer must be distinct, because MPI forbids aliased buffers without
 * MPI_IN_PLACE. The const source pointer is cast to void* to match the MPI signature.
 * The statistics record data_size once per call.
 */
491 statistics().begin_count(STD_COUNTERS::mpi_alltoall);
492 assert(src_buffer != dest_buffer);
493 void * ptr = (
void *) src_buffer;
494 mpi_error(MPI_Alltoall(ptr, data_size, MPI_CHAR, dest_buffer, data_size, MPI_CHAR, mpi_comm_));
495 statistics().end_count(STD_COUNTERS::mpi_alltoall,1,data_size);
/*! @brief gather: gather `data_size` bytes from every rank into dest_buffer on rank `root` (MPI_Gather).
 *
 * The const source pointer is cast to void* to match the MPI signature.
 */
502 statistics().begin_count(STD_COUNTERS::mpi_gather);
503 void * ptr = (
void *) src_buffer;
504 mpi_error(MPI_Gather(ptr, data_size, MPI_CHAR, dest_buffer, data_size, MPI_CHAR, root, mpi_comm_));
505 statistics().end_count(STD_COUNTERS::mpi_gather,1,data_size);
/*! @brief all_gather: gather `data_size` bytes from every rank into dest_buffer on all ranks (MPI_Allgather).
 *
 * The const source pointer is cast to void* to match the MPI signature.
 */
512 statistics().begin_count(STD_COUNTERS::mpi_allgather);
513 void * ptr = (
void *) src_buffer;
514 mpi_error(MPI_Allgather(ptr, data_size, MPI_CHAR, dest_buffer, data_size, MPI_CHAR, mpi_comm_));
515 statistics().end_count(STD_COUNTERS::mpi_allgather,1,data_size);
/*! @brief all_gatherv: variable-size all-gather (MPI_Allgatherv). Each rank sends send_size bytes;
 *  recv_size[r] bytes from rank r are placed at offset displs[r] of dest_buffer.
 *
 * This call is accounted under the same mpi_allgather counter as all_gather(), and only the local
 * send_size is recorded as the transferred quantity.
 */
522 statistics().begin_count(STD_COUNTERS::mpi_allgather);
523 void * ptr = (
void *) src_buffer;
524 mpi_error(MPI_Allgatherv(ptr, send_size, MPI_CHAR, dest_buffer, recv_size, displs, MPI_CHAR, mpi_comm_));
525 statistics().end_count(STD_COUNTERS::mpi_allgather,1,send_size);
/*! @brief init_group_trio: one-time construction of the global Comm_Group_MPI
 *  (the signature and several body lines are not part of this listing).
 *
 * - Reports an error if it already ran (mpi_status_ is non-null).
 * - If must_mpi_initialize_ is set, calls MPI_Init. This is only allowed when trio_u_world_ is
 *   still MPI_COMM_WORLD, and a failure of MPI_Init is reported.
 * - Queries size and rank of trio_u_world_, then makes it this group's communicator and takes its group.
 * - Allocates 2 * nbproc status/request slots and sets every request to MPI_REQUEST_NULL.
 * - Logs which communicator is in use.
 * NOTE(review): unlike the surrounding calls, MPI_Comm_group's return code is not checked with mpi_error().
 */
544 if (mpi_status_ != 0)
546 Cerr <<
"Error : the construction of the global Comm_Group_MPI has already been done." << finl;
550 if (must_mpi_initialize_)
552 if (trio_u_world_ != MPI_COMM_WORLD)
554 Cerr <<
"Error in Comm_Group_MPI::init_group_trio(...) : you cannot ask to initialize MPI\n"
555 <<
" with something else than MPI_COMM_WORLD !" << finl;
561 int errcode = MPI_Init(&argc, &argv);
563 if (errcode != MPI_SUCCESS)
565 Cerr <<
"Error in Comm_Group_MPI::init_group_trio()\n"
566 <<
" MPI_Init() failed (forget to run with mpirun ?)" << finl;
574 mpi_error(MPI_Comm_size (trio_u_world_, & nbproc));
575 mpi_error(MPI_Comm_rank (trio_u_world_, & arank));
579 mpi_comm_ = trio_u_world_;
580 MPI_Comm_group(mpi_comm_, &mpi_group_);
// Room for one send and one receive request per process in a send_recv_start() exchange.
585 mpi_maxrequests_ = nbproc * 2;
586 mpi_status_ =
new MPI_Status[mpi_maxrequests_];
587 mpi_requests_ =
new MPI_Request[mpi_maxrequests_];
589 for (
int r=0; r<mpi_maxrequests_; r++)
591 mpi_requests_[r]=MPI_REQUEST_NULL;
595 if (trio_u_world_ == MPI_COMM_WORLD)
597 Cerr <<
"Initialized MPI with MPI_COMM_WORLD (using all processors)" << finl;
601 Cerr <<
"Initialized MPI with communicator!=MPI_COMM_WORLD: using " << (int)nbproc <<
" processors" << finl;
608void Comm_Group_MPI::free()
610 if (mpi_maxrequests_!=-1)
611 mpi_error(MPI_Group_free(& mpi_group_));
/*! @brief Release both the MPI group and the communicator of this Comm_Group_MPI.
 *
 * Only acts once the global group has been initialised (mpi_maxrequests_ != -1). Each handle is
 * freed only if it is not already null; MPI resets freed handles to their NULL value.
 */
618void Comm_Group_MPI::free_all()
620 if (mpi_maxrequests_!=-1)
622 if (mpi_group_!=MPI_GROUP_NULL)
623 mpi_error(MPI_Group_free(& mpi_group_));
624 if (mpi_comm_!=MPI_COMM_NULL)
625 mpi_error(MPI_Comm_free(&mpi_comm_));
/*! @brief Variable-size all-to-all exchange of bytes (MPI_Alltoallv).
 *
 * @param src_buffer        bytes to send; must differ from dest_buffer
 * @param send_data_size    per-rank byte counts to send (nproc() entries)
 * @param send_data_offset  per-rank byte offsets into src_buffer
 * @param dest_buffer       receive buffer
 * @param recv_data_size    per-rank byte counts to receive
 * @param recv_data_offset  per-rank byte offsets into dest_buffer
 *
 * Two code paths are visible. The choice between them is on lines missing from this listing
 * (presumably a preprocessor switch on the integer width - confirm).
 * - The first copies the four arrays into std::vector<int> before calling MPI_Alltoallv.
 * - The second passes the arrays to MPI_Alltoallv directly.
 * The statistic `size` (declared on a missing line) is "last offset + last size" for both the
 * send and the receive side. This assumes the last rank's block ends each buffer, and reads
 * index n-1, so nproc() >= 1 is required.
 */
631void Comm_Group_MPI::all_to_allv(
const void *src_buffer,
int *send_data_size,
int *send_data_offset,
632 void *dest_buffer,
int *recv_data_size,
int *recv_data_offset)
const
634 statistics().
begin_count(STD_COUNTERS::mpi_alltoall);
635 assert(src_buffer != dest_buffer);
636 void * ptr = (
void *) src_buffer;
638 const int n =
nproc();
642 std::vector<int> send_data_size_int(n);
643 std::vector<int> send_data_offset_int(n);
644 std::vector<int> recv_data_size_int(n);
645 std::vector<int> recv_data_offset_int(n);
// NOTE(review): in this listing the parameters are already int*, so cast_func is an identity
// int -> int conversion and the copies are redundant. Presumably the original source type was
// wider - confirm.
647 auto cast_func = [](
int i) ->
int {
return static_cast<int>(i); };
648 std::transform(send_data_size, send_data_size + n, send_data_size_int.begin(), cast_func);
649 std::transform(send_data_offset, send_data_offset + n, send_data_offset_int.begin(), cast_func);
650 std::transform(recv_data_size, recv_data_size + n, recv_data_size_int.begin(), cast_func);
651 std::transform(recv_data_offset, recv_data_offset + n, recv_data_offset_int.begin(), cast_func);
653 mpi_error(MPI_Alltoallv(ptr, send_data_size_int.data(), send_data_offset_int.data(), MPI_CHAR,
654 dest_buffer, recv_data_size_int.data(), recv_data_offset_int.data(), MPI_CHAR, mpi_comm_));
655 size = send_data_offset_int[n-1] + send_data_size_int[n-1] + recv_data_size_int[n-1] + recv_data_offset_int[n-1];
// Direct path: the caller's arrays are handed to MPI unchanged.
657 mpi_error(MPI_Alltoallv(ptr, send_data_size, send_data_offset, MPI_CHAR,
658 dest_buffer, recv_data_size, recv_data_offset, MPI_CHAR, mpi_comm_));
659 size = send_data_offset[n-1] + send_data_size[n-1] + recv_data_size[n-1] + recv_data_offset[n-1];
662 statistics().
end_count(STD_COUNTERS::mpi_alltoall,1,size);
/*! @brief Choose the communicator used as the TRUST "world".
 *
 * Must be called before the global Comm_Group_MPI is built. If mpi_status_ is already allocated,
 * an error is printed. What happens next is on lines missing from this listing.
 * Also assigns PETSC_COMM_WORLD. The surrounding lines, presumably a PETSc availability guard,
 * are not part of this listing.
 */
670void Comm_Group_MPI::set_trio_u_world(MPI_Comm world)
672 if (mpi_status_ != 0)
674 Cerr <<
"Error : the construction of the global Comm_Group_MPI has already been done\n"
675 <<
" set_trio_u_world call is forbidden" << finl;
679 PETSC_COMM_WORLD= world;
681 trio_u_world_ = world;
/*! @brief Return the communicator currently used as the TRUST "world" (MPI_COMM_WORLD by default). */
684MPI_Comm Comm_Group_MPI::get_trio_u_world()
686 return trio_u_world_;
/*! @brief Set whether init_group_trio() must call MPI_Init itself.
 *
 * @param flag  false when MPI has already been initialised by the caller
 *
 * Must be called before the global Comm_Group_MPI is built. If mpi_status_ is already allocated,
 * an error is printed. What happens next is on lines missing from this listing.
 */
691void Comm_Group_MPI::set_must_mpi_initialize(
bool flag)
693 if (mpi_status_ != 0)
695 Cerr <<
"Error : the construction of the global Comm_Group_MPI has already been done\n"
696 <<
" set_must_mpi_initialize() call is forbidden." << finl;
699 must_mpi_initialize_ = flag;
/*! @brief Point-to-point exchange: send send_buf to send_proc and receive recv_buf from recv_proc.
 *
 * A negative rank disables that direction:
 * - both negative: nothing visible in this listing is done;
 * - only recv_proc >= 0: blocking MPI_Recv;
 * - only send_proc >= 0: blocking MPI_Send;
 * - both >= 0: MPI_Sendrecv, which avoids send/recv ordering deadlocks.
 * `src`, `tag` and `status` are declared on lines missing from this listing (src presumably = recv_proc).
 * Must not be called while a send_recv_start() exchange is pending.
 * NOTE(review): as far as this listing shows, the statistics add send_buf_size + recv_buf_size
 * even when one direction was skipped.
 */
702void Comm_Group_MPI::ptop_send_recv(
const void * send_buf,
int send_buf_size,
int send_proc,
703 void * recv_buf,
int recv_buf_size,
int recv_proc)
const
705 statistics().
begin_count(STD_COUNTERS::mpi_sendrecv);
706 assert(mpi_nrequests_ < 0);
707 int dest = send_proc;
711 if (send_proc < 0 && recv_proc < 0)
715 else if (send_proc < 0 && recv_proc >= 0)
717 mpi_error(MPI_Recv (recv_buf, recv_buf_size, MPI_CHAR, src, tag, mpi_comm_, &status));
719 else if (recv_proc < 0 && send_proc >= 0)
721 mpi_error(MPI_Send ((
void*)send_buf, send_buf_size, MPI_CHAR, dest, tag, mpi_comm_));
725 assert(dest >= 0 && dest <
nproc());
726 assert(src >= 0 && src <
nproc());
729 mpi_error(MPI_Sendrecv((
void*)send_buf, send_buf_size, MPI_CHAR, dest, tag,
730 recv_buf, recv_buf_size, MPI_CHAR, src, tag, mpi_comm_,
733 statistics().
end_count(STD_COUNTERS::mpi_sendrecv, 1, send_buf_size + recv_buf_size);
// Presumably the body of init_group(const ArrOfInt& pe_list); the signature and the definition of
// `cg` (the current group) are not part of this listing.
// Builds this group from the first nproc() ranks listed in pe_list, taken relative to the current
// group, then creates the matching communicator. MPI_Comm_create is collective over
// current_mpi_comm, and MPI returns MPI_COMM_NULL to processes outside mpi_group_.
756 const MPI_Group& current_mpi_group = cg.mpi_group_;
757 const MPI_Comm& current_mpi_comm = cg.mpi_comm_;
759 const int nbproc = this->
nproc();
// NOTE(review): raw new[]; the matching delete[] is not visible in this listing - confirm it
// exists, or use a std::vector<int>.
760 int *ranks =
new int[nbproc];
761 for (
int i = 0; i < nbproc; i++)
762 ranks[i] = pe_list[i];
763 assert(mpi_group_==MPI_GROUP_NULL);
764 mpi_error(MPI_Group_incl(current_mpi_group, nbproc, ranks, & mpi_group_));
769 mpi_error(MPI_Comm_create(current_mpi_comm, mpi_group_, & mpi_comm_));
/*! @brief Build a communicator grouping the processes that share memory with this one, and
 *  compute the node numbering.
 *
 * The current communicator (from `cg`, defined on lines missing from this listing) is split with
 * MPI_COMM_TYPE_SHARED, keyed by the current rank.
 * Then the local rank-0 process of each node joins a temporary leaders communicator `tmp`; the
 * other processes pass MPI_UNDEFINED and get MPI_COMM_NULL. In `tmp`, the leaders read
 * node_id_ (rank among leaders, ordered by current rank) and nb_nodes_ (number of leaders).
 * Each leader then broadcasts both values to its node, and `tmp` is freed.
 * The use of nbproc/loc_rank after lines 795-796 is not part of this listing.
 */
777void Comm_Group_MPI::init_comm_on_numa_node()
782 assert(mpi_group_==MPI_GROUP_NULL);
788 const MPI_Comm& current_mpi_comm = cg.mpi_comm_;
789 int current_rank = cg.rank();
790 mpi_error(MPI_Comm_split_type(current_mpi_comm, MPI_COMM_TYPE_SHARED, current_rank, MPI_INFO_NULL, &mpi_comm_));
791 mpi_error(MPI_Comm_group(mpi_comm_, &mpi_group_));
795 mpi_error(MPI_Comm_size(mpi_comm_, &nbproc));
796 mpi_error(MPI_Comm_rank(mpi_comm_, &loc_rank));
// Only local rank 0 joins the leaders communicator.
803 int master = loc_rank==0 ? 0 : MPI_UNDEFINED;
805 mpi_error(MPI_Comm_split(current_mpi_comm, master, current_rank, &tmp));
806 if(tmp != MPI_COMM_NULL)
808 mpi_error(MPI_Comm_rank(tmp, &
node_id_));
809 mpi_error(MPI_Comm_size(tmp, &
nb_nodes_));
// Local rank 0 (the leader) shares its node information with the rest of the node.
812 mpi_error(MPI_Bcast(&
node_id_, 1, MPI_INT, 0, mpi_comm_));
813 mpi_error(MPI_Bcast(&
nb_nodes_, 1, MPI_INT, 0, mpi_comm_));
815 if (tmp!= MPI_COMM_NULL)
816 mpi_error(MPI_Comm_free(&tmp));
/*! @brief Build a single-process group and communicator from the current group.
 *
 * The group contains only rank `master` of the current group. `master` and `cg` are defined on
 * lines missing from this listing. MPI_Comm_create is collective over the current communicator,
 * and MPI returns MPI_COMM_NULL to processes outside the new group.
 * loc_rank is 0 on current rank 0 and -1 elsewhere; its later use is not part of this listing.
 */
823void Comm_Group_MPI::init_comm_on_node_master()
828 assert(mpi_group_==MPI_GROUP_NULL);
834 const MPI_Comm& current_mpi_comm = cg.mpi_comm_;
835 const MPI_Group& current_mpi_group = cg.mpi_group_;
837 mpi_error(MPI_Group_incl(current_mpi_group, 1, &master, & mpi_group_));
838 mpi_error(MPI_Comm_create(current_mpi_comm, mpi_group_, & mpi_comm_));
841 int loc_rank = cg.rank() == 0 ? 0 : -1;
/*! @brief Element-wise collective fallback for int arrays.
 *
 * @param nop  number of operations in `op`; when negative, op[0] applies to every element,
 *             otherwise op[i] applies to element i
 *
 * Only the partial-sum case is visible in this listing. It is computed in trustIdType through
 * mppartial_sum_impl() and narrowed back to int. The switch on op[j] is not visible.
 */
845void Comm_Group_MPI::internal_collective(
const int *x,
int *resu,
int nx,
const Collective_Op *op,
int nop,
int level)
const
848 for (
int i = 0; i < nx; i++)
850 int j = (nop < 0) ? 0 : i;
851 trustIdType xx = x[i], resu2 = -1;
855 resu2 = mppartial_sum_impl(x[i]);
// NOTE(review): this assert rejects exactly INT_MAX (which fits; `<=` was probably meant),
// does not check the lower bound, and vanishes in NDEBUG builds.
856 assert(resu2 < std::numeric_limits<int>::max());
857 resu[i] =
static_cast<int>(resu2);
/*! @brief Element-wise collective fallback for trustIdType arrays.
 *
 * @param nop  when negative, op[0] applies to every element; otherwise op[i] applies to element i
 *
 * The visible case computes a partial sum per element with mppartial_sum_impl(). The switch on
 * op[j] is not part of this listing.
 */
862void Comm_Group_MPI::internal_collective(
const trustIdType *x, trustIdType *resu,
int nx,
const Collective_Op *op,
int nop,
int level)
const
865 for (
int i = 0; i < nx; i++)
867 int j = (nop < 0) ? 0 : i;
871 resu[i] = mppartial_sum_impl(x[i]);
/*! @brief Element-wise collective fallback for double arrays.
 *
 * @param nop  when negative, op[0] applies to every element; otherwise op[i] applies to element i
 *
 * COLL_PARTIAL_SUM is not implemented for double; the visible lines print an error for it.
 * What follows the message is not part of this listing.
 */
877void Comm_Group_MPI::internal_collective(
const double *x,
double *resu,
int nx,
const Collective_Op *op,
int nop,
int level)
const
880 for (
int i = 0; i < nx; i++)
882 int j = (nop < 0) ? 0 : i;
887 Cerr <<
"Error in Comm_Group_MPI: COLL_PARTIAL_SUM not coded for double" << finl;
/*! @brief Element-wise collective fallback for float arrays.
 *
 * @param nop  when negative, op[0] applies to every element; otherwise op[i] applies to element i
 *
 * COLL_PARTIAL_SUM is not implemented for float; the visible lines print an error for it.
 * What follows the message is not part of this listing.
 */
893void Comm_Group_MPI::internal_collective(
const float *x,
float *resu,
int nx,
const Collective_Op *op,
int nop,
int level)
const
896 for (
int i = 0; i < nx; i++)
898 int j = (nop < 0) ? 0 : i;
903 Cerr <<
"Error in Comm_Group_MPI: COLL_PARTIAL_SUM not coded for float" << finl;
/*! @brief Running (prefix) sum of `x` along the ranks of mpi_comm_, passed rank to rank.
 *
 * Each rank receives the accumulated value `somme` from rank-1, presumably skipped on rank 0
 * where it stays 0. It then sends somme + x to rank+1, presumably skipped on the last rank.
 * The receive happens before the send, so the ranks are processed one after another.
 * MPI_INT or MPI_LONG is used depending on lines missing from this listing (presumably the width
 * of trustIdType). The declarations of rang/tag/status, the guards, and the returned value
 * (somme or s) are not part of this listing.
 */
916trustIdType Comm_Group_MPI::mppartial_sum_impl(trustIdType x)
const
918 statistics().
begin_count(STD_COUNTERS::mpi_partialsum);
919 trustIdType somme = 0;
929 mpi_error(MPI_Recv(& somme, 1, MPI_INT, rang-1, tag, mpi_comm_, &status));
931 mpi_error(MPI_Recv(& somme, 1, MPI_LONG, rang-1, tag, mpi_comm_, &status));
937 trustIdType s = somme + x;
939 mpi_error(MPI_Send(& s, 1, MPI_INT, rang+1, tag, mpi_comm_));
941 mpi_error(MPI_Send(& s, 1, MPI_LONG, rang+1, tag, mpi_comm_));
944 statistics().
end_count(STD_COUNTERS::mpi_partialsum);
: Classe Comm_Group_MPI, derivee de la classe abstraite Comm_Group.
void recv(int pe, void *buffer, int size, int tag) const override
Reception bloquante d'un message.
void all_to_all(const void *src_buffer, void *dest_buffer, int data_size) const override
void all_gather(const void *src_buffer, void *dest_buffer, int data_size) const override
void abort() const override
appel a MPI_Abort et rend la main
void send(int pe, const void *buffer, int size, int tag) const override
Envoi bloquant.
void broadcast(void *buffer, int size, int pe_source) const override
void all_gatherv(const void *src_buffer, void *dest_buffer, int send_size, const int *recv_size, const int *displs) const override
~Comm_Group_MPI() override
void send_recv_finish() const override
Attend que l'ensemble des communications lancees par send_recv_start soient finies.
void send_recv_start(const ArrOfInt &send_list, const ArrOfInt &send_size, const char *const *const send_buffers, const ArrOfInt &recv_list, const ArrOfInt &recv_size, char *const *const recv_buffers, TypeHint typehint=CHAR) const override
Demarre l'envoi et la reception des buffers.
Comm_Group_MPI()
Constructeur par defaut.
void gather(const void *src_buffer, void *dest_buffer, int data_size, int root) const override
void mp_collective_op(const double *x, double *resu, int n, Collective_Op op) const override
: Cette classe decrit un groupe de processeurs sur lesquels
static int check_enabled()
int nproc() const
Renvoie le nombre de processeurs dans le groupe *this.
int rank() const
Renvoie le rang du processeur local dans le groupe *this.
void init_group_node(int nproc, int loc_rank, int glob_rank)
Initialize all the information relative to world sizes and ranks for node communicator.
virtual void init_group(const ArrOfInt &pe_list)
Cette fonction doit etre appelee simultanement par tous les PEs du groupe current_group avec les meme...
void init_group_trio(int nproc, int rank)
Initialise le groupe_TRUST().
int get_new_tag() const
Cette fonction renvoie un nouveau tag de communication pour le groupe.
Class defining operators and methods for all reading operation in an input flow (file,...
virtual Entree & readOn(Entree &)
Lecture d'un Objet_U sur un flot d'entree Methode a surcharger.
virtual Sortie & printOn(Sortie &) const
Ecriture de l'objet sur un flot de sortie Methode a surcharger.
static const Comm_Group & get_node_group()
Renvoie une reference au groupe sur les noeuds.
static const Comm_Group & current_group()
renvoie une reference au groupe de processeurs actif courant
void begin_count(const STD_COUNTERS &std_cnt, int counter_lvl=-100000)
void end_count(const std::string &custom_count_name, int count_increment=1, long int quantity_increment=0)
End the count of a counter and update the counter values.
static bool is_parallel()
static Sortie & Journal(int message_level=0)
Renvoie un objet statique de type Sortie qui sert de journal d'evenements.
static void barrier()
Synchronise tous les processeurs du groupe courant (attend que tous les processeurs soient arrives a ...
static int me()
renvoie mon rang dans le groupe de communication courant.
static void exit(int exit_code=-1)
Routine de sortie de TRUST dans une region Kokkos.
static int je_suis_maitre()
renvoie 1 si on est sur le processeur maitre du groupe courant (c'est a dire me() == 0),...
Classe de base des flux de sortie.
_SIZE_ size_array() const