]> git.ipfire.org Git - thirdparty/gcc.git/commitdiff
libgomp/plugin/plugin-nvptx.c: Fix device used for stream creation
authorTobias Burnus <tburnus@baylibre.com>
Mon, 24 Mar 2025 15:08:20 +0000 (16:08 +0100)
committerTobias Burnus <tburnus@baylibre.com>
Mon, 24 Mar 2025 15:08:20 +0000 (16:08 +0100)
libgomp/ChangeLog:

* plugin/plugin-nvptx.c (GOMP_OFFLOAD_interop): Set context for
stream creation to use the specified device.

libgomp/plugin/plugin-nvptx.c

index 822c6a410e287fc574d90a3eefa684b169e9c111..a5cf859db197f3c94d51d6dc22b7e8d74277a013 100644 (file)
@@ -2483,12 +2483,26 @@ GOMP_OFFLOAD_interop (struct interop_obj_t *obj, int ord,
          break;
       }
 
-  obj->device_data = ptx_devices[ord];
+  struct ptx_device *ptx_dev = obj->device_data = ptx_devices[ord];
 
   if (targetsync)
     {
       CUstream stream = NULL;
-      CUDA_CALL_ASSERT (cuStreamCreate, &stream, CU_STREAM_DEFAULT);
+      CUdevice cur_ctx_dev;
+      CUresult res = CUDA_CALL_NOCHECK (cuCtxGetDevice, &cur_ctx_dev);
+      if (res != CUDA_SUCCESS && res != CUDA_ERROR_INVALID_CONTEXT)
+       GOMP_PLUGIN_fatal ("cuCtxGetDevice error: %s", cuda_error (res));
+      if (res != CUDA_ERROR_INVALID_CONTEXT && ptx_dev->dev == cur_ctx_dev)
+       CUDA_CALL_ASSERT (cuStreamCreate, &stream, CU_STREAM_DEFAULT);
+      else
+       {
+         CUcontext old_ctx;
+         assert (ptx_dev->ctx);
+         CUDA_CALL_ASSERT (cuCtxPushCurrent, ptx_dev->ctx);
+         CUDA_CALL_ASSERT (cuStreamCreate, &stream, CU_STREAM_DEFAULT);
+         if (res != CUDA_ERROR_INVALID_CONTEXT)
+           CUDA_CALL_ASSERT (cuCtxPopCurrent, &old_ctx);
+       }
       obj->stream = stream;
     }
 }