Actual source code: cupmcontext.hpp
1: #pragma once
3: #include <petsc/private/deviceimpl.h>
4: #include <petsc/private/cupmsolverinterface.hpp>
5: #include <petsc/private/logimpl.h>
7: #include <petsc/private/cpp/array.hpp>
9: #include "../segmentedmempool.hpp"
10: #include "cupmallocator.hpp"
11: #include "cupmstream.hpp"
12: #include "cupmevent.hpp"
14: namespace Petsc
15: {
17: namespace device
18: {
20: namespace cupm
21: {
23: namespace impl
24: {
26: template <DeviceType T>
27: class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL DeviceContext : SolverInterface<T> {
28: public:
29: PETSC_CUPMSOLVER_INHERIT_INTERFACE_TYPEDEFS_USING(T);
31: private:
32: template <typename H, std::size_t>
33: struct HandleTag {
34: using type = H;
35: };
37: using stream_tag = HandleTag<cupmStream_t, 0>;
38: using blas_tag = HandleTag<cupmBlasHandle_t, 1>;
39: using solver_tag = HandleTag<cupmSolverHandle_t, 2>;
41: using stream_type = CUPMStream<T>;
42: using event_type = CUPMEvent<T>;
44: public:
45: // This is the canonical PETSc "impls" struct that normally resides in a standalone impls
46: // header, but since we are using the power of templates it must be declared part of
47: // this class to have easy access the same typedefs. Technically one can make a
48: // templated struct outside the class but it's more code for the same result.
49: struct PetscDeviceContext_IMPLS {
50: stream_type stream{};
51: cupmEvent_t event{};
52: cupmEvent_t begin{}; // timer-only
53: cupmEvent_t end{}; // timer-only
54: #if PetscDefined(USE_DEBUG)
55: PetscBool timerInUse{};
56: PetscBool EnergyMeterInUse{};
57: #endif
58: cupmBlasHandle_t blas{};
59: cupmSolverHandle_t solver{};
60: #if PetscDefined(HAVE_CUDA)
61: nvmlDevice_t nvmlHandle{};
62: unsigned long long energymeterbegin{};
63: unsigned long long energymeterend{};
64: #endif
66: constexpr PetscDeviceContext_IMPLS() noexcept = default;
68: PETSC_NODISCARD const cupmStream_t &get(stream_tag) const noexcept { return this->stream.get_stream(); }
70: PETSC_NODISCARD const cupmBlasHandle_t &get(blas_tag) const noexcept { return this->blas; }
72: PETSC_NODISCARD const cupmSolverHandle_t &get(solver_tag) const noexcept { return this->solver; }
73: };
75: private:
76: static bool initialized_;
78: static std::array<cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> blashandles_;
79: static std::array<cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> solverhandles_;
81: PETSC_NODISCARD static constexpr PetscDeviceContext_IMPLS *impls_cast_(PetscDeviceContext ptr) noexcept { return static_cast<PetscDeviceContext_IMPLS *>(ptr->data); }
83: PETSC_NODISCARD static constexpr CUPMEvent<T> *event_cast_(PetscEvent event) noexcept { return static_cast<CUPMEvent<T> *>(event->data); }
85: PETSC_NODISCARD static PetscLogEvent CUPMBLAS_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUBLAS_HANDLE_CREATE : HIPBLAS_HANDLE_CREATE; }
87: PETSC_NODISCARD static PetscLogEvent CUPMSOLVER_HANDLE_CREATE() noexcept { return T == DeviceType::CUDA ? CUSOLVER_HANDLE_CREATE : HIPSOLVER_HANDLE_CREATE; }
89: // this exists purely to satisfy the compiler so the tag-based dispatch works for the other
90: // handles
91: static PetscErrorCode initialize_handle_(stream_tag, PetscDeviceContext) noexcept { return PETSC_SUCCESS; }
93: static PetscErrorCode initialize_handle_(blas_tag, PetscDeviceContext dctx) noexcept
94: {
95: const auto dci = impls_cast_(dctx);
96: auto &handle = blashandles_[dctx->device->deviceId];
98: PetscFunctionBegin;
99: if (!handle) {
100: PetscCall(PetscLogEventsPause());
101: PetscCall(PetscLogEventBegin(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
102: for (auto i = 0; i < 3; ++i) {
103: const auto cberr = cupmBlasCreate(handle.ptr_to());
104: if (PetscLikely(cberr == CUPMBLAS_STATUS_SUCCESS)) break;
105: if (PetscUnlikely(cberr != CUPMBLAS_STATUS_ALLOC_FAILED) && (cberr != CUPMBLAS_STATUS_NOT_INITIALIZED)) PetscCallCUPMBLAS(cberr);
106: if (i != 2) {
107: PetscCall(PetscSleep(3));
108: continue;
109: }
110: PetscCheck(cberr == CUPMBLAS_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmBlasName());
111: }
112: PetscCall(PetscLogEventEnd(CUPMBLAS_HANDLE_CREATE(), 0, 0, 0, 0));
113: PetscCall(PetscLogEventsResume());
114: }
115: PetscCallCUPMBLAS(cupmBlasSetStream(handle, dci->stream.get_stream()));
116: dci->blas = handle;
117: PetscFunctionReturn(PETSC_SUCCESS);
118: }
120: static PetscErrorCode initialize_handle_(solver_tag, PetscDeviceContext dctx) noexcept
121: {
122: const auto dci = impls_cast_(dctx);
123: auto &handle = solverhandles_[dctx->device->deviceId];
125: PetscFunctionBegin;
126: if (!handle) {
127: PetscCall(PetscLogEventsPause());
128: PetscCall(PetscLogEventBegin(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
129: for (auto i = 0; i < 3; ++i) {
130: const auto cerr = cupmSolverCreate(&handle);
131: if (PetscLikely(cerr == CUPMSOLVER_STATUS_SUCCESS)) break;
132: if ((cerr != CUPMSOLVER_STATUS_NOT_INITIALIZED) && (cerr != CUPMSOLVER_STATUS_ALLOC_FAILED)) PetscCallCUPMSOLVER(cerr);
133: if (i < 2) {
134: PetscCall(PetscSleep(3));
135: continue;
136: }
137: PetscCheck(cerr == CUPMSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_GPU_RESOURCE, "Unable to initialize %s", cupmSolverName());
138: }
139: PetscCall(PetscLogEventEnd(CUPMSOLVER_HANDLE_CREATE(), 0, 0, 0, 0));
140: PetscCall(PetscLogEventsResume());
141: }
142: PetscCallCUPMSOLVER(cupmSolverSetStream(handle, dci->stream.get_stream()));
143: dci->solver = handle;
144: PetscFunctionReturn(PETSC_SUCCESS);
145: }
147: static PetscErrorCode check_current_device_(PetscDeviceContext dctxl, PetscDeviceContext dctxr) noexcept
148: {
149: const auto devidl = dctxl->device->deviceId, devidr = dctxr->device->deviceId;
151: PetscFunctionBegin;
152: PetscCheck(devidl == devidr, PETSC_COMM_SELF, PETSC_ERR_GPU, "Device contexts must be on the same device; dctx A (id %" PetscInt64_FMT " device id %" PetscInt_FMT ") dctx B (id %" PetscInt64_FMT " device id %" PetscInt_FMT ")",
153: PetscObjectCast(dctxl)->id, devidl, PetscObjectCast(dctxr)->id, devidr);
154: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidl));
155: PetscCall(PetscDeviceCheckDeviceCount_Internal(devidr));
156: PetscCallCUPM(cupmSetDevice(static_cast<int>(devidl)));
157: PetscFunctionReturn(PETSC_SUCCESS);
158: }
160: static PetscErrorCode check_current_device_(PetscDeviceContext dctx) noexcept { return check_current_device_(dctx, dctx); }
162: static PetscErrorCode finalize_() noexcept
163: {
164: PetscFunctionBegin;
165: for (auto &&handle : blashandles_) {
166: if (handle) {
167: PetscCallCUPMBLAS(cupmBlasDestroy(handle));
168: handle = nullptr;
169: }
170: }
171: for (auto &&handle : solverhandles_) {
172: if (handle) {
173: PetscCallCUPMSOLVER(cupmSolverDestroy(handle));
174: handle = nullptr;
175: }
176: }
177: initialized_ = false;
178: PetscFunctionReturn(PETSC_SUCCESS);
179: }
181: template <typename Allocator, typename PoolType = ::Petsc::memory::SegmentedMemoryPool<typename Allocator::value_type, stream_type, Allocator, 256 * sizeof(PetscScalar)>>
182: PETSC_NODISCARD static PoolType &default_pool_() noexcept
183: {
184: static PoolType pool;
185: return pool;
186: }
188: static PetscErrorCode check_memtype_(PetscMemType mtype, const char mess[]) noexcept
189: {
190: PetscFunctionBegin;
191: PetscCheck(PetscMemTypeHost(mtype) || (mtype == PETSC_MEMTYPE_DEVICE) || (mtype == PETSC_MEMTYPE_CUPM()), PETSC_COMM_SELF, PETSC_ERR_SUP, "%s device context can only handle %s (pinned) host or device memory", cupmName(), mess);
192: PetscFunctionReturn(PETSC_SUCCESS);
193: }
195: public:
196: // All of these functions MUST be static in order to be callable from C, otherwise they
197: // get the implicit 'this' pointer tacked on
198: static PetscErrorCode destroy(PetscDeviceContext) noexcept;
199: static PetscErrorCode changeStreamType(PetscDeviceContext, PetscStreamType) noexcept;
200: static PetscErrorCode setUp(PetscDeviceContext) noexcept;
201: static PetscErrorCode query(PetscDeviceContext, PetscBool *) noexcept;
202: static PetscErrorCode waitForContext(PetscDeviceContext, PetscDeviceContext) noexcept;
203: static PetscErrorCode synchronize(PetscDeviceContext) noexcept;
204: template <typename Handle_t>
205: static PetscErrorCode getHandle(PetscDeviceContext, void *) noexcept;
206: template <typename Handle_t>
207: static PetscErrorCode getHandlePtr(PetscDeviceContext, void **) noexcept;
208: static PetscErrorCode beginTimer(PetscDeviceContext) noexcept;
209: static PetscErrorCode endTimer(PetscDeviceContext, PetscLogDouble *) noexcept;
210: static PetscErrorCode getPower(PetscDeviceContext, PetscLogDouble *) noexcept;
211: static PetscErrorCode beginEnergyMeter(PetscDeviceContext) noexcept;
212: static PetscErrorCode endEnergyMeter(PetscDeviceContext, PetscLogDouble *) noexcept;
213: static PetscErrorCode memAlloc(PetscDeviceContext, PetscBool, PetscMemType, std::size_t, std::size_t, void **) noexcept;
214: static PetscErrorCode memFree(PetscDeviceContext, PetscMemType, void **) noexcept;
215: static PetscErrorCode memCopy(PetscDeviceContext, void *PETSC_RESTRICT, const void *PETSC_RESTRICT, std::size_t, PetscDeviceCopyMode) noexcept;
216: static PetscErrorCode memSet(PetscDeviceContext, PetscMemType, void *, PetscInt, std::size_t) noexcept;
217: static PetscErrorCode createEvent(PetscDeviceContext, PetscEvent) noexcept;
218: static PetscErrorCode recordEvent(PetscDeviceContext, PetscEvent) noexcept;
219: static PetscErrorCode waitForEvent(PetscDeviceContext, PetscEvent) noexcept;
221: // not a PetscDeviceContext method, this registers the class
222: static PetscErrorCode initialize(PetscDevice) noexcept;
224: // clang-format off
225: static constexpr _DeviceContextOps ops = {
226: PetscDesignatedInitializer(destroy, destroy),
227: PetscDesignatedInitializer(changestreamtype, changeStreamType),
228: PetscDesignatedInitializer(setup, setUp),
229: PetscDesignatedInitializer(query, query),
230: PetscDesignatedInitializer(waitforcontext, waitForContext),
231: PetscDesignatedInitializer(synchronize, synchronize),
232: PetscDesignatedInitializer(getblashandle, getHandle<blas_tag>),
233: PetscDesignatedInitializer(getsolverhandle, getHandle<solver_tag>),
234: PetscDesignatedInitializer(getstreamhandle, getHandlePtr<stream_tag>),
235: PetscDesignatedInitializer(begintimer, beginTimer),
236: PetscDesignatedInitializer(endtimer, endTimer),
237: #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
238: PetscDesignatedInitializer(getpower, getPower),
239: #else
240: PetscDesignatedInitializer(getpower, nullptr),
241: #endif
242: #if PetscDefined(HAVE_CUDA)
243: PetscDesignatedInitializer(beginenergymeter, beginEnergyMeter),
244: PetscDesignatedInitializer(endenergymeter, endEnergyMeter),
245: #else
246: PetscDesignatedInitializer(beginenergymeter, nullptr),
247: PetscDesignatedInitializer(endenergymeter, nullptr),
248: #endif
249: PetscDesignatedInitializer(memalloc, memAlloc),
250: PetscDesignatedInitializer(memfree, memFree),
251: PetscDesignatedInitializer(memcopy, memCopy),
252: PetscDesignatedInitializer(memset, memSet),
253: PetscDesignatedInitializer(createevent, createEvent),
254: PetscDesignatedInitializer(recordevent, recordEvent),
255: PetscDesignatedInitializer(waitforevent, waitForEvent)
256: };
257: // clang-format on
258: };
260: // not a PetscDeviceContext method, this initializes the CLASS
261: template <DeviceType T>
262: inline PetscErrorCode DeviceContext<T>::initialize(PetscDevice device) noexcept
263: {
264: PetscFunctionBegin;
265: if (PetscUnlikely(!initialized_)) {
266: uint64_t threshold = UINT64_MAX;
267: cupmMemPool_t mempool;
269: initialized_ = true;
270: PetscCallCUPM(cupmDeviceGetMemPool(&mempool, static_cast<int>(device->deviceId)));
271: PetscCallCUPM(cupmMemPoolSetAttribute(mempool, cupmMemPoolAttrReleaseThreshold, &threshold));
272: blashandles_.fill(nullptr);
273: solverhandles_.fill(nullptr);
274: PetscCall(PetscRegisterFinalize(finalize_));
275: }
276: PetscFunctionReturn(PETSC_SUCCESS);
277: }
279: template <DeviceType T>
280: inline PetscErrorCode DeviceContext<T>::destroy(PetscDeviceContext dctx) noexcept
281: {
282: PetscFunctionBegin;
283: if (const auto dci = impls_cast_(dctx)) {
284: PetscCall(dci->stream.destroy());
285: if (dci->event) PetscCall(cupm_fast_event_pool<T>().deallocate(&dci->event));
286: if (dci->begin) PetscCallCUPM(cupmEventDestroy(dci->begin));
287: if (dci->end) PetscCallCUPM(cupmEventDestroy(dci->end));
288: delete dci;
289: dctx->data = nullptr;
290: }
291: PetscFunctionReturn(PETSC_SUCCESS);
292: }
294: template <DeviceType T>
295: inline PetscErrorCode DeviceContext<T>::changeStreamType(PetscDeviceContext dctx, PETSC_UNUSED PetscStreamType stype) noexcept
296: {
297: const auto dci = impls_cast_(dctx);
299: PetscFunctionBegin;
300: PetscCall(dci->stream.destroy());
301: // set these to null so they aren't usable until setup is called again
302: dci->blas = nullptr;
303: dci->solver = nullptr;
304: PetscFunctionReturn(PETSC_SUCCESS);
305: }
307: template <DeviceType T>
308: inline PetscErrorCode DeviceContext<T>::setUp(PetscDeviceContext dctx) noexcept
309: {
310: const auto dci = impls_cast_(dctx);
311: auto &event = dci->event;
313: PetscFunctionBegin;
314: PetscCall(check_current_device_(dctx));
315: PetscCall(dci->stream.change_type(dctx->streamType));
316: if (!event) PetscCall(cupm_fast_event_pool<T>().allocate(&event));
317: #if PetscDefined(USE_DEBUG)
318: dci->timerInUse = PETSC_FALSE;
319: #endif
320: PetscFunctionReturn(PETSC_SUCCESS);
321: }
323: template <DeviceType T>
324: inline PetscErrorCode DeviceContext<T>::query(PetscDeviceContext dctx, PetscBool *idle) noexcept
325: {
326: PetscFunctionBegin;
327: PetscCall(check_current_device_(dctx));
328: switch (auto cerr = cupmStreamQuery(impls_cast_(dctx)->stream.get_stream())) {
329: case cupmSuccess:
330: *idle = PETSC_TRUE;
331: break;
332: case cupmErrorNotReady:
333: *idle = PETSC_FALSE;
334: // reset the error
335: cerr = cupmGetLastError();
336: static_cast<void>(cerr);
337: break;
338: default:
339: PetscCallCUPM(cerr);
340: PetscUnreachable();
341: }
342: PetscFunctionReturn(PETSC_SUCCESS);
343: }
345: template <DeviceType T>
346: inline PetscErrorCode DeviceContext<T>::waitForContext(PetscDeviceContext dctxa, PetscDeviceContext dctxb) noexcept
347: {
348: const auto dcib = impls_cast_(dctxb);
349: const auto event = dcib->event;
351: PetscFunctionBegin;
352: PetscCall(check_current_device_(dctxa, dctxb));
353: PetscCallCUPM(cupmEventRecord(event, dcib->stream.get_stream()));
354: PetscCallCUPM(cupmStreamWaitEvent(impls_cast_(dctxa)->stream.get_stream(), event, 0));
355: PetscFunctionReturn(PETSC_SUCCESS);
356: }
358: template <DeviceType T>
359: inline PetscErrorCode DeviceContext<T>::synchronize(PetscDeviceContext dctx) noexcept
360: {
361: auto idle = PETSC_TRUE;
363: PetscFunctionBegin;
364: PetscCall(query(dctx, &idle));
365: if (!idle) PetscCallCUPM(cupmStreamSynchronize(impls_cast_(dctx)->stream.get_stream()));
366: PetscFunctionReturn(PETSC_SUCCESS);
367: }
369: template <DeviceType T>
370: template <typename handle_t>
371: inline PetscErrorCode DeviceContext<T>::getHandle(PetscDeviceContext dctx, void *handle) noexcept
372: {
373: PetscFunctionBegin;
374: PetscCall(initialize_handle_(handle_t{}, dctx));
375: *static_cast<typename handle_t::type *>(handle) = impls_cast_(dctx)->get(handle_t{});
376: PetscFunctionReturn(PETSC_SUCCESS);
377: }
379: template <DeviceType T>
380: template <typename handle_t>
381: inline PetscErrorCode DeviceContext<T>::getHandlePtr(PetscDeviceContext dctx, void **handle) noexcept
382: {
383: using handle_type = typename handle_t::type;
385: PetscFunctionBegin;
386: PetscCall(initialize_handle_(handle_t{}, dctx));
387: *reinterpret_cast<handle_type **>(handle) = const_cast<handle_type *>(std::addressof(impls_cast_(dctx)->get(handle_t{})));
388: PetscFunctionReturn(PETSC_SUCCESS);
389: }
391: template <DeviceType T>
392: inline PetscErrorCode DeviceContext<T>::beginTimer(PetscDeviceContext dctx) noexcept
393: {
394: const auto dci = impls_cast_(dctx);
396: PetscFunctionBegin;
397: PetscCall(check_current_device_(dctx));
398: #if PetscDefined(USE_DEBUG)
399: PetscCheck(!dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeEnd()?");
400: dci->timerInUse = PETSC_TRUE;
401: #endif
402: if (!dci->begin) {
403: PetscAssert(!dci->end, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Don't have a 'begin' event, but somehow have an end event");
404: PetscCallCUPM(cupmEventCreate(&dci->begin));
405: PetscCallCUPM(cupmEventCreate(&dci->end));
406: }
407: PetscCallCUPM(cupmEventRecord(dci->begin, dci->stream.get_stream()));
408: PetscFunctionReturn(PETSC_SUCCESS);
409: }
411: template <DeviceType T>
412: inline PetscErrorCode DeviceContext<T>::endTimer(PetscDeviceContext dctx, PetscLogDouble *elapsed) noexcept
413: {
414: float gtime;
415: const auto dci = impls_cast_(dctx);
416: const auto end = dci->end;
418: PetscFunctionBegin;
419: PetscCall(check_current_device_(dctx));
420: #if PetscDefined(USE_DEBUG)
421: PetscCheck(dci->timerInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuTimeBegin()?");
422: dci->timerInUse = PETSC_FALSE;
423: #endif
424: PetscCallCUPM(cupmEventRecord(end, dci->stream.get_stream()));
425: PetscCallCUPM(cupmEventSynchronize(end));
426: PetscCallCUPM(cupmEventElapsedTime(>ime, dci->begin, end));
427: *elapsed = static_cast<util::remove_pointer_t<decltype(elapsed)>>(gtime);
428: PetscFunctionReturn(PETSC_SUCCESS);
429: }
431: #if PetscDefined(HAVE_CUDA_VERSION_12_2PLUS)
432: template <DeviceType T>
433: inline PetscErrorCode DeviceContext<T>::getPower(PetscDeviceContext dctx, PetscLogDouble *power) noexcept
434: {
435: const auto dci = impls_cast_(dctx);
436: nvmlFieldValue_t values[1];
438: PetscFunctionBegin;
439: PetscCall(check_current_device_(dctx));
440: PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
441: values[0].fieldId = NVML_FI_DEV_POWER_INSTANT;
442: if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
443: PetscCallNVML(nvmlDeviceGetFieldValues(dci->nvmlHandle, 1, values));
444: *power = static_cast<util::remove_pointer_t<decltype(power)>>(values[0].value.uiVal);
445: PetscFunctionReturn(PETSC_SUCCESS);
446: }
447: #endif
449: #if PetscDefined(HAVE_CUDA)
450: template <DeviceType T>
451: inline PetscErrorCode DeviceContext<T>::beginEnergyMeter(PetscDeviceContext dctx) noexcept
452: {
453: const auto dci = impls_cast_(dctx);
455: PetscFunctionBegin;
456: PetscCall(check_current_device_(dctx));
457: #if PetscDefined(USE_DEBUG)
458: PetscCheck(!dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterEnd()?");
459: dci->EnergyMeterInUse = PETSC_TRUE;
460: #endif
461: if (!dci->nvmlHandle) PetscCallNVML(nvmlDeviceGetHandleByIndex(dctx->device->deviceId, &dci->nvmlHandle));
462: PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterbegin));
463: PetscFunctionReturn(PETSC_SUCCESS);
464: }
466: template <DeviceType T>
467: inline PetscErrorCode DeviceContext<T>::endEnergyMeter(PetscDeviceContext dctx, PetscLogDouble *energy) noexcept
468: {
469: const auto dci = impls_cast_(dctx);
471: PetscFunctionBegin;
472: PetscCall(check_current_device_(dctx));
473: #if PetscDefined(USE_DEBUG)
474: PetscCheck(dci->EnergyMeterInUse, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Forgot to call PetscLogGpuEnergyMeterBegin()?");
475: dci->EnergyMeterInUse = PETSC_FALSE;
476: #endif
477: PetscCallCUPM(cupmStreamSynchronize(dci->stream.get_stream()));
478: PetscCallNVML(nvmlDeviceGetTotalEnergyConsumption(dci->nvmlHandle, &dci->energymeterend));
479: *energy = static_cast<util::remove_pointer_t<decltype(energy)>>(dci->energymeterend - dci->energymeterbegin) / 1000; // convert to Joule
480: PetscFunctionReturn(PETSC_SUCCESS);
481: }
482: #endif
484: template <DeviceType T>
485: inline PetscErrorCode DeviceContext<T>::memAlloc(PetscDeviceContext dctx, PetscBool clear, PetscMemType mtype, std::size_t n, std::size_t alignment, void **dest) noexcept
486: {
487: const auto &stream = impls_cast_(dctx)->stream;
489: PetscFunctionBegin;
490: PetscCall(check_current_device_(dctx));
491: PetscCall(check_memtype_(mtype, "allocating"));
492: if (PetscMemTypeHost(mtype)) {
493: PetscCall(default_pool_<HostAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
494: } else {
495: PetscCall(default_pool_<DeviceAllocator<T>>().allocate(n, reinterpret_cast<char **>(dest), &stream, alignment));
496: }
497: if (clear) PetscCallCUPM(cupmMemsetAsync(*dest, 0, n, stream.get_stream()));
498: PetscFunctionReturn(PETSC_SUCCESS);
499: }
501: template <DeviceType T>
502: inline PetscErrorCode DeviceContext<T>::memFree(PetscDeviceContext dctx, PetscMemType mtype, void **ptr) noexcept
503: {
504: const auto &stream = impls_cast_(dctx)->stream;
506: PetscFunctionBegin;
507: PetscCall(check_current_device_(dctx));
508: PetscCall(check_memtype_(mtype, "freeing"));
509: if (!*ptr) PetscFunctionReturn(PETSC_SUCCESS);
510: if (PetscMemTypeHost(mtype)) {
511: PetscCall(default_pool_<HostAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
512: // if ptr exists still exists the pool didn't own it
513: if (*ptr) {
514: auto registered = PETSC_FALSE, managed = PETSC_FALSE;
516: PetscCall(PetscCUPMGetMemType(*ptr, nullptr, ®istered, &managed));
517: if (registered) {
518: PetscCallCUPM(cupmFreeHost(*ptr));
519: } else if (managed) {
520: PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
521: }
522: }
523: } else {
524: PetscCall(default_pool_<DeviceAllocator<T>>().deallocate(reinterpret_cast<char **>(ptr), &stream));
525: // if ptr still exists the pool didn't own it
526: if (*ptr) PetscCallCUPM(cupmFreeAsync(*ptr, stream.get_stream()));
527: }
528: PetscFunctionReturn(PETSC_SUCCESS);
529: }
531: template <DeviceType T>
532: inline PetscErrorCode DeviceContext<T>::memCopy(PetscDeviceContext dctx, void *PETSC_RESTRICT dest, const void *PETSC_RESTRICT src, std::size_t n, PetscDeviceCopyMode mode) noexcept
533: {
534: const auto stream = impls_cast_(dctx)->stream.get_stream();
536: PetscFunctionBegin;
537: // can't use PetscCUPMMemcpyAsync here since we don't know sizeof(*src)...
538: if (mode == PETSC_DEVICE_COPY_HTOH) {
539: const auto cerr = cupmStreamQuery(stream);
541: // yes this is faster
542: if (cerr == cupmSuccess) {
543: PetscCall(PetscMemcpy(dest, src, n));
544: PetscFunctionReturn(PETSC_SUCCESS);
545: } else if (cerr == cupmErrorNotReady) {
546: auto PETSC_UNUSED unused = cupmGetLastError();
548: static_cast<void>(unused);
549: } else {
550: PetscCallCUPM(cerr);
551: }
552: }
553: PetscCallCUPM(cupmMemcpyAsync(dest, src, n, PetscDeviceCopyModeToCUPMMemcpyKind(mode), stream));
554: PetscFunctionReturn(PETSC_SUCCESS);
555: }
557: template <DeviceType T>
558: inline PetscErrorCode DeviceContext<T>::memSet(PetscDeviceContext dctx, PetscMemType mtype, void *ptr, PetscInt v, std::size_t n) noexcept
559: {
560: PetscFunctionBegin;
561: PetscCall(check_current_device_(dctx));
562: PetscCall(check_memtype_(mtype, "zeroing"));
563: PetscCallCUPM(cupmMemsetAsync(ptr, static_cast<int>(v), n, impls_cast_(dctx)->stream.get_stream()));
564: PetscFunctionReturn(PETSC_SUCCESS);
565: }
567: template <DeviceType T>
568: inline PetscErrorCode DeviceContext<T>::createEvent(PetscDeviceContext, PetscEvent event) noexcept
569: {
570: PetscFunctionBegin;
571: PetscCallCXX(event->data = new event_type{});
572: event->destroy = [](PetscEvent event) {
573: PetscFunctionBegin;
574: delete event_cast_(event);
575: event->data = nullptr;
576: PetscFunctionReturn(PETSC_SUCCESS);
577: };
578: PetscFunctionReturn(PETSC_SUCCESS);
579: }
581: template <DeviceType T>
582: inline PetscErrorCode DeviceContext<T>::recordEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
583: {
584: PetscFunctionBegin;
585: PetscCall(impls_cast_(dctx)->stream.record_event(*event_cast_(event)));
586: PetscFunctionReturn(PETSC_SUCCESS);
587: }
589: template <DeviceType T>
590: inline PetscErrorCode DeviceContext<T>::waitForEvent(PetscDeviceContext dctx, PetscEvent event) noexcept
591: {
592: PetscFunctionBegin;
593: PetscCall(impls_cast_(dctx)->stream.wait_for_event(*event_cast_(event)));
594: PetscFunctionReturn(PETSC_SUCCESS);
595: }
597: // initialize the static member variables
598: template <DeviceType T>
599: bool DeviceContext<T>::initialized_ = false;
601: template <DeviceType T>
602: std::array<typename DeviceContext<T>::cupmBlasHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::blashandles_ = {};
604: template <DeviceType T>
605: std::array<typename DeviceContext<T>::cupmSolverHandle_t, PETSC_DEVICE_MAX_DEVICES> DeviceContext<T>::solverhandles_ = {};
607: template <DeviceType T>
608: constexpr _DeviceContextOps DeviceContext<T>::ops;
610: } // namespace impl
612: // shorten this one up a bit (and instantiate the templates)
613: using CUPMContextCuda = impl::DeviceContext<DeviceType::CUDA>;
614: using CUPMContextHip = impl::DeviceContext<DeviceType::HIP>;
616: // shorthand for what is an EXTREMELY long name
617: #define PetscDeviceContext_(IMPLS) ::Petsc::device::cupm::impl::DeviceContext<::Petsc::device::cupm::DeviceType::IMPLS>::PetscDeviceContext_IMPLS
619: } // namespace cupm
621: } // namespace device
623: } // namespace Petsc