From: Julian Seward Date: Thu, 5 Nov 2009 08:43:38 +0000 (+0000) Subject: Correctly handle MPI_STATUS{ES}_IGNORE as valid values for X-Git-Tag: svn/VALGRIND_3_6_0~483 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f3bb59a960610e41aaf0ee3033bd204592080c61;p=thirdparty%2Fvalgrind.git Correctly handle MPI_STATUS{ES}_IGNORE as valid values for MPI_Status* arguments (as opposed to segfaulting :-) git-svn-id: svn://svn.valgrind.org/valgrind/trunk@10926 --- diff --git a/mpi/libmpiwrap.c b/mpi/libmpiwrap.c index c0f41b8b38..1519fd7cf4 100644 --- a/mpi/libmpiwrap.c +++ b/mpi/libmpiwrap.c @@ -57,6 +57,32 @@ without prior written permission. */ +/* Handling of MPI_STATUS{ES}_IGNORE for MPI_Status* arguments. + + The MPI-2 spec allows many functions which have MPI_Status* purely + as an out parameter, to accept the constants MPI_STATUS_IGNORE or + MPI_STATUSES_IGNORE there instead, if the caller does not care + about the status. See the MPI-2 spec sec 4.5.1 ("Passing + MPI_STATUS_IGNORE for Status"). (mpi2-report.pdf, 1615898 bytes, + md5=694a5efe2fd291eecf7e8c9875b5f43f). + + This library handles such cases by allocating a fake MPI_Status + object (on the stack) or an array thereof (on the heap), and + passing that onwards instead. From the outside the caller sees no + difference. Unfortunately the simpler approach of merely detecting + and handling these special cases at a lower level does not work, + because we need to use information returned in MPI_Status* + arguments to paint result buffers, even if the caller doesn't + supply a real MPI_Status object. + + Eg, MPI_Recv. We can't paint the result buffer without knowing how + many items arrived; but we can't find that out without passing a + real MPI_Status object to the (real) MPI_Recv call. Hence, if the + caller did not supply one, we have no option but to use a temporary + stack allocated one for the inner call. Ditto, more indirectly + (via maybe_complete) for nonblocking receives and the various + associated wait/test calls. */ + /*------------------------------------------------------------*/ /*--- includes ---*/ @@ -103,6 +129,17 @@ #endif +/* Define HAVE_MPI_STATUS_IGNORE iff we have to deal with + MPI_STATUS{ES}_IGNORE. */ +#if MPI_VERSION >= 2 \ + || (defined(MPI_STATUS_IGNORE) && defined(MPI_STATUSES_IGNORE)) +# undef HAVE_MPI_STATUS_IGNORE +# define HAVE_MPI_STATUS_IGNORE 1 +#else +# undef HAVE_MPI_STATUS_IGNORE +#endif + + /*------------------------------------------------------------*/ /*--- Decls ---*/ /*------------------------------------------------------------*/ @@ -401,6 +438,20 @@ Bool eq_MPI_Request ( MPI_Request r1, MPI_Request r2 ) return r1 == r2; } +/* Return True if status is MPI_STATUS_IGNORE or MPI_STATUSES_IGNORE. + On MPI-1.x platforms which don't have these symbols (and they would + only have them if they've been backported from 2.x) always return + False. */ +static __inline__ +Bool isMSI ( MPI_Status* status ) +{ +# if defined(HAVE_MPI_STATUS_IGNORE) + return status == MPI_STATUSES_IGNORE || status == MPI_STATUS_IGNORE; +# else + return False; +# endif +} + /* Get the 'extent' of a type. Note, as per the MPI spec this includes whatever padding would be required when using 'ty' in an array. */ @@ -1045,10 +1096,13 @@ int WRAPPER_FOR(PMPI_Recv)(void *buf, int count, MPI_Datatype datatype, int source, int tag, MPI_Comm comm, MPI_Status *status) { - OrigFn fn; - int err, recv_count = 0; + OrigFn fn; + int err, recv_count = 0; + MPI_Status fake_status; VALGRIND_GET_ORIG_FN(fn); before("Recv"); + if (isMSI(status)) + status = &fake_status; check_mem_is_addressable(buf, count, datatype); check_mem_is_addressable_untyped(status, sizeof(*status)); CALL_FN_W_7W(err, fn, buf,count,datatype,source,tag,comm,status); @@ -1386,10 +1440,13 @@ int WRAPPER_FOR(PMPI_Wait)( MPI_Request* request, MPI_Status* status ) { MPI_Request request_before; + MPI_Status fake_status; OrigFn fn; int err; VALGRIND_GET_ORIG_FN(fn); before("Wait"); + if (isMSI(status)) + status = &fake_status; check_mem_is_addressable_untyped(status, sizeof(MPI_Status)); check_mem_is_defined_untyped(request, sizeof(MPI_Request)); request_before = *request; @@ -1410,10 +1467,13 @@ int WRAPPER_FOR(PMPI_Waitany)( int count, MPI_Status* status ) { MPI_Request* requests_before = NULL; + MPI_Status fake_status; OrigFn fn; int err, i; VALGRIND_GET_ORIG_FN(fn); before("Waitany"); + if (isMSI(status)) + status = &fake_status; if (0) fprintf(stderr, "Waitany: %d\n", count); check_mem_is_addressable_untyped(index, sizeof(int)); check_mem_is_addressable_untyped(status, sizeof(MPI_Status)); @@ -1441,9 +1501,14 @@ int WRAPPER_FOR(PMPI_Waitall)( int count, MPI_Request* requests_before = NULL; OrigFn fn; int err, i; + Bool free_sta = False; VALGRIND_GET_ORIG_FN(fn); before("Waitall"); if (0) fprintf(stderr, "Waitall: %d\n", count); + if (isMSI(statuses)) { + free_sta = True; + statuses = malloc( (count < 0 ? 0 : count) * sizeof(MPI_Status) ); + } for (i = 0; i < count; i++) { check_mem_is_addressable_untyped(&statuses[i], sizeof(MPI_Status)); check_mem_is_defined_untyped(&requests[i], sizeof(MPI_Request)); @@ -1462,6 +1527,8 @@ int WRAPPER_FOR(PMPI_Waitall)( int count, } if (requests_before) free(requests_before); + if (free_sta) + free(statuses); after("Waitall", err); return err; } @@ -1472,10 +1539,13 @@ int WRAPPER_FOR(PMPI_Test)( MPI_Request* request, int* flag, MPI_Status* status ) { MPI_Request request_before; + MPI_Status fake_status; OrigFn fn; int err; VALGRIND_GET_ORIG_FN(fn); before("Test"); + if (isMSI(status)) + status = &fake_status; check_mem_is_addressable_untyped(status, sizeof(MPI_Status)); check_mem_is_addressable_untyped(flag, sizeof(int)); check_mem_is_defined_untyped(request, sizeof(MPI_Request)); @@ -1498,9 +1568,14 @@ int WRAPPER_FOR(PMPI_Testall)( int count, MPI_Request* requests, MPI_Request* requests_before = NULL; OrigFn fn; int err, i; + Bool free_sta = False; VALGRIND_GET_ORIG_FN(fn); before("Testall"); if (0) fprintf(stderr, "Testall: %d\n", count); + if (isMSI(statuses)) { + free_sta = True; + statuses = malloc( (count < 0 ? 0 : count) * sizeof(MPI_Status) ); + } check_mem_is_addressable_untyped(flag, sizeof(int)); for (i = 0; i < count; i++) { check_mem_is_addressable_untyped(&statuses[i], sizeof(MPI_Status)); @@ -1516,11 +1591,14 @@ int WRAPPER_FOR(PMPI_Testall)( int count, MPI_Request* requests, for (i = 0; i < count; i++) { maybe_complete(e_i_s, requests_before[i], requests[i], &statuses[i]); - make_mem_defined_if_addressable_untyped(&statuses[i], sizeof(MPI_Status)); + make_mem_defined_if_addressable_untyped(&statuses[i], + sizeof(MPI_Status)); } } if (requests_before) free(requests_before); + if (free_sta) + free(statuses); after("Testall", err); return err; } @@ -1533,10 +1611,13 @@ int WRAPPER_FOR(PMPI_Iprobe)(int source, int tag, MPI_Comm comm, int* flag, MPI_Status* status) { - OrigFn fn; - int err; + MPI_Status fake_status; + OrigFn fn; + int err; VALGRIND_GET_ORIG_FN(fn); before("Iprobe"); + if (isMSI(status)) + status = &fake_status; check_mem_is_addressable_untyped(flag, sizeof(*flag)); check_mem_is_addressable_untyped(status, sizeof(*status)); CALL_FN_W_5W(err, fn, source,tag,comm,flag,status); @@ -1555,10 +1636,13 @@ int WRAPPER_FOR(PMPI_Iprobe)(int source, int tag, int WRAPPER_FOR(PMPI_Probe)(int source, int tag, MPI_Comm comm, MPI_Status* status) { - OrigFn fn; - int err; + MPI_Status fake_status; + OrigFn fn; + int err; VALGRIND_GET_ORIG_FN(fn); before("Probe"); + if (isMSI(status)) + status = &fake_status; check_mem_is_addressable_untyped(status, sizeof(*status)); CALL_FN_W_WWWW(err, fn, source,tag,comm,status); make_mem_defined_if_addressable_if_success_untyped(err, status, sizeof(*status)); @@ -1606,12 +1690,16 @@ int WRAPPER_FOR(PMPI_Sendrecv)( int source, int recvtag, MPI_Comm comm, MPI_Status *status) { - OrigFn fn; - int err, recvcount_actual = 0; + MPI_Status fake_status; + OrigFn fn; + int err, recvcount_actual = 0; VALGRIND_GET_ORIG_FN(fn); before("Sendrecv"); + if (isMSI(status)) + status = &fake_status; check_mem_is_defined(sendbuf, sendcount, sendtype); check_mem_is_addressable(recvbuf, recvcount, recvtype); + check_mem_is_addressable_untyped(status, sizeof(*status)); CALL_FN_W_12W(err, fn, sendbuf,sendcount,sendtype,dest,sendtag, recvbuf,recvcount,recvtype,source,recvtag, comm,status);