Actual source code: cupmcontext.hpp

  1: #pragma once

  3: #include <petsc/private/deviceimpl.h>
  4: #include <petsc/private/cupmsolverinterface.hpp>
  5: #include <petsc/private/logimpl.h>

  7: #include <petsc/private/cpp/array.hpp>

  9: #include "../segmentedmempool.hpp"
 10: #include "cupmallocator.hpp"
 11: #include "cupmstream.hpp"
 12: #include "cupmevent.hpp"

 14: namespace Petsc
 15: {

 17: namespace device
 18: {

 20: namespace cupm
 21: {

 23: namespace impl
 24: {

 26: template <DeviceType T>
 27: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
 28: public:
 29:   PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);

 31: private:
 32:   template <typename H, std::size_t>
 33:   struct HandleTag {
 34:     using type = H;
 35:   };

 37:   using stream_tag = HandleTag<cupmStream_t, 0>;
 38:   using blas_tag   = HandleTag<cupmBlasHandle_t, 1>;
 39:   using solver_tag = HandleTag<cupmSolverHandle_t, 2>;

 41:   using stream_type = CUPMStream<T>;
 42:   using event_type  = CUPMEvent<T>;

 44: public:
 45:   // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
 46:   // header, but since we are using the power of templates it must be declared part of
 47:   // this class to have easy access the same typedefs. Technically one can make a
 48:   // templated struct outside the class but it's more code for the same result.
 49:   struct PetscDeviceContext_IMPLS {
 50:     stream_type stream{};
 51:     cupmEvent_t event{};
 52:     cupmEvent_t begin{}; // timer-only
 53:     cupmEvent_t end{};   // timer-only
 54: #if PetscDefined(USE_DEBUG)
 55:     PetscBool timerInUse{};
 56:     PetscBool EnergyMeterInUse{};
 57: #endif
 58:     cupmBlasHandle_t   blas{};
 59:     cupmSolverHandle_t solver{};
 60: #if PetscDefined(HAVE_CUDA)
 61:     nvmlDevice_t       nvmlHandle{};
 62:     unsigned long long energymeterbegin{};
 63:     unsigned long long energymeterend{};
 64: #endif

 66:     constexpr PetscDeviceContext_IMPLS() noexcept = default;

 68:     PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }

 70:     PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }

 72:     PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
 73:   };

 75: private:
 76:   static bool initialized_;

 78:   static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES>   blashandles_;
 79:   static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;

 81:   PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }

 83:   PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }

 85:   PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }

 87:   PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }

 89:   // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
 90:   // handles
 91:   static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }

 93:   static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
 94:   {
 95:     const auto dci    = impls_cast_(dctx);
 96:     auto      &handle = blashandles_[dctx->device->deviceId];

 98:     PetscFunctionBegin;
 99:     if (!handle) {
100:       PetscCall(PetscLogEventsPause());
101:       PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
102:       for (auto i = 0; i < 3; ++i) {
103:         const auto cberr = cupmBlasCreate(handle.ptr_to());
104:         if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
105:         if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
106:         if (i != 2) {
107:           PetscCall(PetscSleep(3));
108:           continue;
109:         }
110:         PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111:       }
112:       PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113:       PetscCall(PetscLogEventsResume());
114:     }
115:     PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
116:     dci->blas = handle;
117:     PetscFunctionReturn(PETSC_SUCCESS);
118:   }

120:   static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121:   {
122:     const auto dci    = impls_cast_(dctx);
123:     auto      &handle = solverhandles_[dctx->device->deviceId];

125:     PetscFunctionBegin;
126:     if (!handle) {
127:       PetscCall(PetscLogEventsPause());
128:       PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
129:       for (auto i = 0; i < 3; ++i) {
130:         const auto cerr = cupmSolverCreate(&handle);
131:         if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
132:         if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
133:         if (i < 2) {
134:           PetscCall(PetscSleep(3));
135:           continue;
136:         }
137:         PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
138:       }
139:       PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140:       PetscCall(PetscLogEventsResume());
141:     }
142:     PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
143:     dci->solver = handle;
144:     PetscFunctionReturn(PETSC_SUCCESS);
145:   }

147:   static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148:   {
149:     const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;

151:     PetscFunctionBegin;
152:     PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
153:                PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
154:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
155:     PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
156:     PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
157:     PetscFunctionReturn(PETSC_SUCCESS);
158:   }

160:   static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }

162:   static PetscErrorCode finalize_() noexcept
163:   {
164:     PetscFunctionBegin;
165:     for (auto &&handle : blashandles_) {
166:       if (handle) {
167:         PetscCallCUPMBLAS(cupmBlasDestroy(handle));
168:         handle = nullptr;
169:       }
170:     }
171:     for (auto &&handle : solverhandles_) {
172:       if (handle) {
173:         PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
174:         handle = nullptr;
175:       }
176:     }
177:     initialized_ = false;
178:     PetscFunctionReturn(PETSC_SUCCESS);
179:   }

181:   template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
182:   PETSC_NODISCARD static PoolType &default_pool_() noexcept
183:   {
184:     static PoolType pool;
185:     return pool;
186:   }

188:   static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189:   {
190:     PetscFunctionBegin;
191:     PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
192:     PetscFunctionReturn(PETSC_SUCCESS);
193:   }

195: public:
196:   // All of these functions MUST be static in order to be callable from C, otherwise they
197:   // get the implicit 'this' pointer tacked on
198:   static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199:   static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200:   static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201:   static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202:   static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203:   static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204:   template <typename Handle_t>
205:   static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
206:   template <typename Handle_t>
207:   static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208:   static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209:   static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210:   static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211:   static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212:   static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213:   static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214:   static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215:   static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216:   static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217:   static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218:   static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219:   static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;

221:   // not a PetscDeviceContext method, this registers the class
222:   static PetscErrorCode initialize(PetscDevice) noexcept;

224:   // clang-format off
225:   static constexpr _DeviceContextOps ops = {
226:     PetscDesignatedInitializer(destroy, destroy),
227:     PetscDesignatedInitializer(changestreamtype, changeStreamType),
228:     PetscDesignatedInitializer(setup, setUp),
229:     PetscDesignatedInitializer(query, query),
230:     PetscDesignatedInitializer(waitforcontext, waitForContext),
231:     PetscDesignatedInitializer(synchronize, synchronize),
232:     PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
233:     PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
234:     PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
235:     PetscDesignatedInitializer(begintimer, beginTimer),
236:     PetscDesignatedInitializer(endtimer, endTimer),
237: #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238:     PetscDesignatedInitializer(getpower, getPower),
239: #else
240:     PetscDesignatedInitializer(getpower, nullptr),
241: #endif
242: #if PetscDefined(HAVE_CUDA)
243:     PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244:     PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245: #else
246:     PetscDesignatedInitializer(beginenergymeter, nullptr),
247:     PetscDesignatedInitializer(endenergymeter, nullptr),
248: #endif
249:     PetscDesignatedInitializer(memalloc, memAlloc),
250:     PetscDesignatedInitializer(memfree, memFree),
251:     PetscDesignatedInitializer(memcopy, memCopy),
252:     PetscDesignatedInitializer(memset, memSet),
253:     PetscDesignatedInitializer(createevent, createEvent),
254:     PetscDesignatedInitializer(recordevent, recordEvent),
255:     PetscDesignatedInitializer(waitforevent, waitForEvent)
256:   };
257:   // clang-format on
258: };

260: // not a PetscDeviceContext method, this initializes the CLASS
261: template <DeviceType T>
262: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263: {
264:   PetscFunctionBegin;
265:   if (PetscUnlikely(!initialized_)) {
266:     uint64_t      threshold = UINT64_MAX;
267:     cupmMemPool_t mempool;

269:     initialized_ = true;
270:     PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
271:     PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
272:     blashandles_.fill(nullptr);
273:     solverhandles_.fill(nullptr);
274:     PetscCall(PetscRegisterFinalize(finalize_));
275:   }
276:   PetscFunctionReturn(PETSC_SUCCESS);
277: }

279: template <DeviceType T>
280: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281: {
282:   PetscFunctionBegin;
283:   if (const auto dci = impls_cast_(dctx)) {
284:     PetscCall(dci->stream.destroy());
285:     if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
286:     if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
287:     if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
288:     delete dci;
289:     dctx->data = nullptr;
290:   }
291:   PetscFunctionReturn(PETSC_SUCCESS);
292: }

294: template <DeviceType T>
295: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296: {
297:   const auto dci = impls_cast_(dctx);

299:   PetscFunctionBegin;
300:   PetscCall(dci->stream.destroy());
301:   // set these to null so they aren't usable until setup is called again
302:   dci->blas   = nullptr;
303:   dci->solver = nullptr;
304:   PetscFunctionReturn(PETSC_SUCCESS);
305: }

307: template <DeviceType T>
308: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309: {
310:   const auto dci   = impls_cast_(dctx);
311:   auto      &event = dci->event;

313:   PetscFunctionBegin;
314:   PetscCall(check_current_device_(dctx));
315:   PetscCall(dci->stream.change_type(dctx->streamType));
316:   if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317: #if PetscDefined(USE_DEBUG)
318:   dci->timerInUse = PETSC_FALSE;
319: #endif
320:   PetscFunctionReturn(PETSC_SUCCESS);
321: }

323: template <DeviceType T>
324: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325: {
326:   PetscFunctionBegin;
327:   PetscCall(check_current_device_(dctx));
328:   switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329:   case cupmSuccess:
330:     *idle = PETSC_TRUE;
331:     break;
332:   case cupmErrorNotReady:
333:     *idle = PETSC_FALSE;
334:     // reset the error
335:     cerr = cupmGetLastError();
336:     static_cast<void>(cerr);
337:     break;
338:   default:
339:     PetscCallCUPM(cerr);
340:     PetscUnreachable();
341:   }
342:   PetscFunctionReturn(PETSC_SUCCESS);
343: }

345: template <DeviceType T>
346: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347: {
348:   const auto dcib  = impls_cast_(dctxb);
349:   const auto event = dcib->event;

351:   PetscFunctionBegin;
352:   PetscCall(check_current_device_(dctxa, dctxb));
353:   PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
354:   PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
355:   PetscFunctionReturn(PETSC_SUCCESS);
356: }

358: template <DeviceType T>
359: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360: {
361:   auto idle = PETSC_TRUE;

363:   PetscFunctionBegin;
364:   PetscCall(query(dctx, &idle));
365:   if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
366:   PetscFunctionReturn(PETSC_SUCCESS);
367: }

369: template <DeviceType T>
370: template <typename handle_t>
371: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372: {
373:   PetscFunctionBegin;
374:   PetscCall(initialize_handle_(handle_t{}, dctx));
375:   *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
376:   PetscFunctionReturn(PETSC_SUCCESS);
377: }

379: template <DeviceType T>
380: template <typename handle_t>
381: inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
382: {
383:   using handle_type = typename handle_t::type;

385:   PetscFunctionBegin;
386:   PetscCall(initialize_handle_(handle_t{}, dctx));
387:   *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
388:   PetscFunctionReturn(PETSC_SUCCESS);
389: }

391: template <DeviceType T>
392: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393: {
394:   const auto dci = impls_cast_(dctx);

396:   PetscFunctionBegin;
397:   PetscCall(check_current_device_(dctx));
398: #if PetscDefined(USE_DEBUG)
399:   PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400:   dci->timerInUse = PETSC_TRUE;
401: #endif
402:   if (!dci->begin) {
403:     PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
404:     PetscCallCUPM(cupmEventCreate(&dci->begin));
405:     PetscCallCUPM(cupmEventCreate(&dci->end));
406:   }
407:   PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
408:   PetscFunctionReturn(PETSC_SUCCESS);
409: }

411: template <DeviceType T>
412: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413: {
414:   float      gtime;
415:   const auto dci = impls_cast_(dctx);
416:   const auto end = dci->end;

418:   PetscFunctionBegin;
419:   PetscCall(check_current_device_(dctx));
420: #if PetscDefined(USE_DEBUG)
421:   PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422:   dci->timerInUse = PETSC_FALSE;
423: #endif
424:   PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
425:   PetscCallCUPM(cupmEventSynchronize(end));
426:   PetscCallCUPM(cupmEventElapsedTime(&gtime, dci->begin, end));
427:   *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
428:   PetscFunctionReturn(PETSC_SUCCESS);
429: }

431: #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432: template <DeviceType T>
433: inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434: {
435:   const auto       dci = impls_cast_(dctx);
436:   nvmlFieldValue_t values[1];

438:   PetscFunctionBegin;
439:   PetscCall(check_current_device_(dctx));
440:   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441:   values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442:   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443:   PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444:   *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445:   PetscFunctionReturn(PETSC_SUCCESS);
446: }
447: #endif

449: #if PetscDefined(HAVE_CUDA)
450: template <DeviceType T>
451: inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452: {
453:   const auto dci = impls_cast_(dctx);

455:   PetscFunctionBegin;
456:   PetscCall(check_current_device_(dctx));
457:   #if PetscDefined(USE_DEBUG)
458:   PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459:   dci->EnergyMeterInUse = PETSC_TRUE;
460:   #endif
461:   if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462:   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463:   PetscFunctionReturn(PETSC_SUCCESS);
464: }

466: template <DeviceType T>
467: inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468: {
469:   const auto dci = impls_cast_(dctx);

471:   PetscFunctionBegin;
472:   PetscCall(check_current_device_(dctx));
473:   #if PetscDefined(USE_DEBUG)
474:   PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475:   dci->EnergyMeterInUse = PETSC_FALSE;
476:   #endif
477:   PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478:   PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479:   *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480:   PetscFunctionReturn(PETSC_SUCCESS);
481: }
482: #endif

484: template <DeviceType T>
485: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486: {
487:   const auto &stream = impls_cast_(dctx)->stream;

489:   PetscFunctionBegin;
490:   PetscCall(check_current_device_(dctx));
491:   PetscCall(check_memtype_(mtype, "allocating"));
492:   if (PetscMemTypeHost(mtype)) {
493:     PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
494:   } else {
495:     PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
496:   }
497:   if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
498:   PetscFunctionReturn(PETSC_SUCCESS);
499: }

501: template <DeviceType T>
502: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503: {
504:   const auto &stream = impls_cast_(dctx)->stream;

506:   PetscFunctionBegin;
507:   PetscCall(check_current_device_(dctx));
508:   PetscCall(check_memtype_(mtype, "freeing"));
509:   if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
510:   if (PetscMemTypeHost(mtype)) {
511:     PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
512:     // if ptr exists still exists the pool didn't own it
513:     if (*ptr) {
514:       auto registered = PETSC_FALSE, managed = PETSC_FALSE;

516:       PetscCall(PetscCUPMGetMemType(*ptr, nullptr, &registered, &managed));
517:       if (registered) {
518:         PetscCallCUPM(cupmFreeHost(*ptr));
519:       } else if (managed) {
520:         PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
521:       }
522:     }
523:   } else {
524:     PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
525:     // if ptr still exists the pool didn't own it
526:     if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
527:   }
528:   PetscFunctionReturn(PETSC_SUCCESS);
529: }

531: template <DeviceType T>
532: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533: {
534:   const auto stream = impls_cast_(dctx)->stream.get_stream();

536:   PetscFunctionBegin;
537:   // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
538:   if (mode == PETSC_DEVICE_COPY_HTOH) {
539:     const auto cerr = cupmStreamQuery(stream);

541:     // yes this is faster
542:     if (cerr == cupmSuccess) {
543:       PetscCall(PetscMemcpy(dest, src, n));
544:       PetscFunctionReturn(PETSC_SUCCESS);
545:     } else if (cerr == cupmErrorNotReady) {
546:       auto PETSC_UNUSED unused = cupmGetLastError();

548:       static_cast<void>(unused);
549:     } else {
550:       PetscCallCUPM(cerr);
551:     }
552:   }
553:   PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
554:   PetscFunctionReturn(PETSC_SUCCESS);
555: }

557: template <DeviceType T>
558: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559: {
560:   PetscFunctionBegin;
561:   PetscCall(check_current_device_(dctx));
562:   PetscCall(check_memtype_(mtype, "zeroing"));
563:   PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
564:   PetscFunctionReturn(PETSC_SUCCESS);
565: }

567: template <DeviceType T>
568: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569: {
570:   PetscFunctionBegin;
571:   PetscCallCXX(event->data = new event_type{});
572:   event->destroy = [](PetscEvent event) {
573:     PetscFunctionBegin;
574:     delete event_cast_(event);
575:     event->data = nullptr;
576:     PetscFunctionReturn(PETSC_SUCCESS);
577:   };
578:   PetscFunctionReturn(PETSC_SUCCESS);
579: }

581: template <DeviceType T>
582: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583: {
584:   PetscFunctionBegin;
585:   PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
586:   PetscFunctionReturn(PETSC_SUCCESS);
587: }

589: template <DeviceType T>
590: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591: {
592:   PetscFunctionBegin;
593:   PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
594:   PetscFunctionReturn(PETSC_SUCCESS);
595: }

597: // initialize the static member variables
598: template <DeviceType T>
599: bool DeviceContext<T>::initialized_ = false;

601: template <DeviceType T>
602: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};

604: template <DeviceType T>
605: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};

607: template <DeviceType T>
608: constexpr _DeviceContextOps DeviceContext<T>::ops;

610: } // namespace impl

612: // shorten this one up a bit (and instantiate the templates)
613: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
614: using CUPMContextHip  = impl::DeviceContext<DeviceType::HIP>;

616: // shorthand for what is an EXTREMELY long name
617: #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS

619: } // namespace cupm

621: } // namespace device

623: } // namespace Petsc