// vecseqcupm_impl.hpp

  1: #pragma once

  3: #include "vecseqcupm.hpp"

  5: #include <petsc/private/randomimpl.h>

  7: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"
  8: #include "../src/sys/objects/device/impls/cupm/kernels.hpp"

 10: #if PetscDefined(USE_COMPLEX)
 11:   #include <thrust/transform_reduce.h>
 12: #endif
 13: #include <thrust/transform.h>
 14: #include <thrust/reduce.h>
 15: #include <thrust/functional.h>
 16: #include <thrust/tuple.h>
 17: #include <thrust/device_ptr.h>
 18: #include <thrust/iterator/zip_iterator.h>
 19: #include <thrust/iterator/counting_iterator.h>
 20: #include <thrust/iterator/constant_iterator.h>
 21: #include <thrust/inner_product.h>

 23: namespace Petsc
 24: {

 26: namespace vec
 27: {

 29: namespace cupm
 30: {

 32: namespace impl
 33: {

 35: // ==========================================================================================
 36: // VecSeq_CUPM - Private API
 37: // ==========================================================================================

// Cast the host-side implementation pointer stored in v->data to the Vec_Seq layout
template <device::cupm::DeviceType T>
inline Vec_Seq *VecSeq_CUPM<T>::VecIMPLCast_(Vec v) noexcept
{
  return static_cast<Vec_Seq *>(v->data);
}

// The device-backed vector type name for this class (VECSEQCUDA or VECSEQHIP, per T)
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPLCUPM_() noexcept
{
  return VECSEQCUPM();
}

// The plain host vector type this class shadows
template <device::cupm::DeviceType T>
inline constexpr VecType VecSeq_CUPM<T>::VECIMPL_() noexcept
{
  return VECSEQ;
}

// Remove every composed async (PetscDeviceContext-taking) operation from v. Composing
// nullptr under a name removes the previously composed function. Called from
// VecDestroy_IMPL_() below; mirrors the list installed by InitializeAsyncFunctions().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ClearAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseSign), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), nullptr));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Compose the device-aware *Async method of this class under each VecAsyncFnName() so the
// generic VecXXXAsync() entry points can dispatch to them. Installed during
// VecCreate_IMPL_Private_(); the exact inverse of ClearAsyncFunctions().
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::InitializeAsyncFunctions(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Abs), VecSeq_CUPM<T>::AbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBY), VecSeq_CUPM<T>::AXPBYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPBYPCZ), VecSeq_CUPM<T>::AXPBYPCZAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AXPY), VecSeq_CUPM<T>::AXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(AYPX), VecSeq_CUPM<T>::AYPXAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Conjugate), VecSeq_CUPM<T>::ConjugateAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Copy), VecSeq_CUPM<T>::CopyAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Exp), VecSeq_CUPM<T>::ExpAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Log), VecSeq_CUPM<T>::LogAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(MAXPY), VecSeq_CUPM<T>::MAXPYAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseDivide), VecSeq_CUPM<T>::PointwiseDivideAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMax), VecSeq_CUPM<T>::PointwiseMaxAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMaxAbs), VecSeq_CUPM<T>::PointwiseMaxAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMin), VecSeq_CUPM<T>::PointwiseMinAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseMult), VecSeq_CUPM<T>::PointwiseMultAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(PointwiseSign), VecSeq_CUPM<T>::PointwiseSignAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Reciprocal), VecSeq_CUPM<T>::ReciprocalAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Scale), VecSeq_CUPM<T>::ScaleAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Set), VecSeq_CUPM<T>::SetAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Shift), VecSeq_CUPM<T>::ShiftAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(SqrtAbs), VecSeq_CUPM<T>::SqrtAbsAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(Swap), VecSeq_CUPM<T>::SwapAsync));
  PetscCall(PetscObjectComposeFunction(PetscObjectCast(v), VecAsyncFnName(WAXPY), VecSeq_CUPM<T>::WAXPYAsync));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Destroy hook: un-compose the async functions first, then defer to the host VecDestroy_Seq()
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecDestroy_IMPL_(Vec v) noexcept
{
  PetscFunctionBegin;
  PetscCall(ClearAsyncFunctions(v));
  PetscCall(VecDestroy_Seq(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Host-side resetarray: forwarded verbatim to the VecSeq implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecResetArray_IMPL_(Vec v) noexcept
{
  return VecResetArray_Seq(v);
}

// Host-side placearray: forwarded verbatim to the VecSeq implementation
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecPlaceArray_IMPL_(Vec v, const PetscScalar *a) noexcept
{
  return VecPlaceArray_Seq(v, a);
}

// Create the host-side (Vec_Seq) portion of the vector and install the async function
// table. The unnamed PetscInt parameter is unused here (present only to match the
// base-class signature); *alloc_missing is always set to PETSC_FALSE since
// VecCreate_Seq_Private() performs the host allocation.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::VecCreate_IMPL_Private_(Vec v, PetscBool *alloc_missing, PetscInt, PetscScalar *host_array) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  if (alloc_missing) *alloc_missing = PETSC_FALSE;
  // a sequential vector only makes sense on a single-rank communicator
  PetscCallMPI(MPI_Comm_size(PetscObjectComm(PetscObjectCast(v)), &size));
  PetscCheck(size <= 1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Must create VecSeq on communicator of size 1, have size %d", size);
  PetscCall(VecCreate_Seq_Private(v, host_array));
  PetscCall(InitializeAsyncFunctions(v));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// for functions with an early return based one vec size we still need to artificially bump the
// object state. This is to prevent the following:
//
// 0. Suppose you have a Vec {
//   rank 0: [0],
//   rank 1: []
// }
// 1. both ranks have Vec with PetscObjectState = 0, stashed norm of 0
// 2. Vec enters e.g. VecSet(10)
// 3. rank 1 has local size 0 and bails immediately
// 4. rank 0 has local size 1 and enters function, eventually calls DeviceArrayWrite()
// 5. DeviceArrayWrite() calls PetscObjectStateIncrease(), now state = 1
// 6. Vec enters VecNorm(), and calls VecNormAvailable()
// 7. rank 1 has object state = 0, equal to stash and returns early with norm = 0
// 8. rank 0 has object state = 1, not equal to stash, continues to impl function
// 9. rank 0 deadlocks on MPI_Allreduce() because rank 1 bailed early
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MaybeIncrementEmptyLocalVec(Vec v) noexcept
{
  PetscFunctionBegin;
  // only needed when this rank is empty (n == 0) but the global vector is not (N != 0)
  if (PetscUnlikely((v->map->n == 0) && (v->map->N != 0))) PetscCall(PetscObjectStateIncrease(PetscObjectCast(v)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Build both halves of a VECSEQCUPM: host (Vec_Seq) side first, then the CUPM device side.
// host_array/device_array may be user-provided buffers (see CreateSeqCUPMWithBothArrays())
// or nullptr to have them allocated lazily.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM_(Vec v, PetscDeviceContext dctx, PetscScalar *host_array, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  PetscCall(base_type::VecCreate_IMPL_Private(v, nullptr, 0, host_array));
  PetscCall(Initialize_CUPMBase(v, PETSC_FALSE, host_array, device_array, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Apply an elementwise binary functor on the device: zout[i] = binary(xin[i], yin[i]).
// Runs thrust::transform on the stream of dctx (a null dctx is replaced by the default).
// On an empty local vector no kernel is launched, but the object state may still need a
// bump -- see MaybeIncrementEmptyLocalVec() above.
template <device::cupm::DeviceType T>
template <typename BinaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinary_(BinaryFuncT &&binary, Vec xin, Vec yin, Vec zout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (const auto n = zout->map->n) {
    cupmStream_t stream;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    // clang-format off
    PetscCallThrust(
      const auto dxptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, xin).data());

      THRUST_CALL(
        thrust::transform,
        stream,
        dxptr, dxptr + n,
        thrust::device_pointer_cast(DeviceArrayRead(dctx, yin).data()),
        thrust::device_pointer_cast(DeviceArrayWrite(dctx, zout).data()),
        std::forward<BinaryFuncT>(binary)
      )
    );
    // clang-format on
    // one flop per element (approximation; the actual functor cost may differ)
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(zout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

217: template <device::cupm::DeviceType T>
218: template <typename BinaryFuncT>
219: inline PetscErrorCode VecSeq_CUPM<T>::PointwiseBinaryDispatch_(PetscErrorCode (*VecSeqFunction)(Vec, Vec, Vec), BinaryFuncT &&binary, Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
220: {
221:   PetscFunctionBegin;
222:   if (xin->boundtocpu || yin->boundtocpu) PetscCall((*VecSeqFunction)(wout, xin, yin));
223:   else PetscCall(PointwiseBinary_(std::forward<BinaryFuncT>(binary), xin, yin, wout, dctx)); // note order of arguments! xin and yin are read, wout is written!
224:   PetscFunctionReturn(PETSC_SUCCESS);
225: }

// Apply an elementwise unary functor on the device. If yout is null or aliases xinout the
// transform runs in place on xinout; otherwise xinout is read and yout written. The empty
// branch mirrors PointwiseBinary_(): no kernel, only a possible state bump on whichever
// vector would have been written.
template <device::cupm::DeviceType T>
template <typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseUnary_(UnaryFuncT &&unary, Vec xinout, Vec yout, PetscDeviceContext dctx) noexcept
{
  const auto inplace = !yout || (xinout == yout);

  PetscFunctionBegin;
  if (const auto n = xinout->map->n) {
    cupmStream_t stream;
    // helper capturing n/stream/unary; output pointer defaults to in-place on xinout
    const auto   apply = [&](PetscScalar *xinout, PetscScalar *yout = nullptr) {
      PetscFunctionBegin;
      // clang-format off
      PetscCallThrust(
        const auto xptr = thrust::device_pointer_cast(xinout);

        THRUST_CALL(
          thrust::transform,
          stream,
          xptr, xptr + n,
          (yout && (yout != xinout)) ? thrust::device_pointer_cast(yout) : xptr,
          std::forward<UnaryFuncT>(unary)
        )
      );
      // clang-format on
      PetscFunctionReturn(PETSC_SUCCESS);
    };

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &stream));
    if (inplace) {
      PetscCall(apply(DeviceArrayReadWrite(dctx, xinout).data()));
    } else {
      PetscCall(apply(DeviceArrayRead(dctx, xinout).data(), DeviceArrayWrite(dctx, yout).data()));
    }
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    if (inplace) {
      PetscCall(MaybeIncrementEmptyLocalVec(xinout));
    } else {
      PetscCall(MaybeIncrementEmptyLocalVec(yout));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

273: // ==========================================================================================
274: // VecSeq_CUPM - Public API - Constructors
275: // ==========================================================================================

// VecCreateSeqCUPM()
// Create a sequential device vector of local (== global) size n with block size bs.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPM(MPI_Comm comm, PetscInt bs, PetscInt n, Vec *v, PetscBool call_set_type) noexcept
{
  PetscFunctionBegin;
  // sequential: local and global sizes coincide
  PetscCall(Create_CUPMBase(comm, bs, n, n, v, call_set_type));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCreateSeqCUPMWithArrays()
// Create a sequential device vector wrapping caller-owned host and/or device buffers.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CreateSeqCUPMWithBothArrays(MPI_Comm comm, PetscInt bs, PetscInt n, const PetscScalar host_array[], const PetscScalar device_array[], Vec *v) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  // do NOT call VecSetType(), otherwise ops->create() -> create() ->
  // CreateSeqCUPM_() is called!
  PetscCall(CreateSeqCUPM(comm, bs, n, v, PETSC_FALSE));
  PetscCall(CreateSeqCUPM_(*v, dctx, PetscRemoveConstCast(host_array), PetscRemoveConstCast(device_array)));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->duplicate
// Create a new vector *y with the same layout/type as v (values are not copied)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Duplicate(Vec v, Vec *y) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(Duplicate_CUPMBase(v, y, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

313: // ==========================================================================================
314: // VecSeq_CUPM - Public API - Utility
315: // ==========================================================================================

// v->ops->bindtocpu
// Toggle between host and device implementations: the base class flips the common ops,
// then each VecSetOp_CUPM() below selects the Seq (host) or CUPM (device) variant of the
// remaining operations depending on usehost.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::BindToCPU(Vec v, PetscBool usehost) noexcept
{
  PetscDeviceContext dctx;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx));
  PetscCall(BindToCPU_CUPMBase(v, usehost, dctx));

  // REVIEW ME: this absolutely should be some sort of bulk mempcy rather than this mess
  VecSetOp_CUPM(dot, VecDot_Seq, Dot);
  VecSetOp_CUPM(norm, VecNorm_Seq, Norm);
  VecSetOp_CUPM(tdot, VecTDot_Seq, TDot);
  VecSetOp_CUPM(mdot, VecMDot_Seq, MDot);
  VecSetOp_CUPM(resetarray, VecResetArray_Seq, base_type::template ResetArray<PETSC_MEMTYPE_HOST>);
  VecSetOp_CUPM(placearray, VecPlaceArray_Seq, base_type::template PlaceArray<PETSC_MEMTYPE_HOST>);
  // mtdot has no device variant; always the host implementation
  v->ops->mtdot = v->ops->mtdot_local = VecMTDot_Seq;
  VecSetOp_CUPM(max, VecMax_Seq, Max);
  VecSetOp_CUPM(min, VecMin_Seq, Min);
  VecSetOp_CUPM(setpreallocationcoo, VecSetPreallocationCOO_Seq, SetPreallocationCOO);
  VecSetOp_CUPM(setvaluescoo, VecSetValuesCOO_Seq, SetValuesCOO);
  PetscFunctionReturn(PETSC_SUCCESS);
}

342: // ==========================================================================================
343: // VecSeq_CUPM - Public API - Mutators
344: // ==========================================================================================

// v->ops->getlocalvector or v->ops->getlocalvectorread
// Make w a local view of v. When both are native CUPM vectors this is a zero-copy
// ownership transfer (w's own buffers are freed first, then it borrows v's data/spptr);
// otherwise w's host array is aliased onto v's via VecGetArray(Read).
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::GetLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (wisseqcupm) {
    // release any storage w currently owns, since it is about to alias v's
    if (const auto wseq = VecIMPLCast(w)) {
      if (auto &alloced = wseq->array_allocated) {
        // temporarily switch the allocator if the array was pinned; exchange() resets the flag
        const auto useit = UseCUPMHostAlloc(util::exchange(w->pinned_memory, PETSC_FALSE));

        PetscCall(PetscFree(alloced));
      }
      wseq->array         = nullptr;
      wseq->unplacedarray = nullptr;
    }
    if (const auto wcu = VecCUPMCast(w)) {
      if (auto &device_array = wcu->array_d) {
        cupmStream_t stream;

        PetscCall(GetHandles_(&stream));
        PetscCallCUPM(cupmFreeAsync(device_array, stream));
      }
      PetscCall(PetscFree(w->spptr /* wcu */));
    }
  }
  if (v->petscnative && wisseqcupm) {
    // fast path: w takes over v's host data, device data and offload state wholesale
    PetscCall(PetscFree(w->data));
    w->data          = v->data;
    w->offloadmask   = v->offloadmask;
    w->pinned_memory = v->pinned_memory;
    w->spptr         = v->spptr;
    PetscCall(PetscObjectStateIncrease(PetscObjectCast(w)));
  } else {
    // slow path: alias only the host array; device storage (if any) allocated on demand
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecGetArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecGetArray(v, array));
    }
    w->offloadmask = PETSC_OFFLOAD_CPU;
    if (wisseqcupm) {
      PetscDeviceContext dctx;

      PetscCall(GetHandles_(&dctx));
      PetscCall(DeviceAllocateCheck_(dctx, w));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->restorelocalvector or v->ops->restorelocalvectorread
// Undo GetLocalVector(): return borrowed storage from w back to v. The fast path reverses
// the zero-copy transfer; the slow path restores the host array alias and frees any device
// scratch w allocated.
template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode VecSeq_CUPM<T>::RestoreLocalVector(Vec v, Vec w) noexcept
{
  PetscBool wisseqcupm;

  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  PetscCall(PetscObjectTypeCompare(PetscObjectCast(w), VECSEQCUPM(), &wisseqcupm));
  if (v->petscnative && wisseqcupm) {
    // the assignments to nullptr are __critical__, as w may persist after this call returns
    // and shouldn't share data with v!
    v->pinned_memory = w->pinned_memory;
    v->offloadmask   = util::exchange(w->offloadmask, PETSC_OFFLOAD_UNALLOCATED);
    v->data          = util::exchange(w->data, nullptr);
    v->spptr         = util::exchange(w->spptr, nullptr);
  } else {
    const auto array = &VecIMPLCast(w)->array;

    if (access == PETSC_MEMORY_ACCESS_READ) {
      PetscCall(VecRestoreArrayRead(v, const_cast<const PetscScalar **>(array)));
    } else {
      PetscCall(VecRestoreArray(v, array));
    }
    if (w->spptr && wisseqcupm) {
      cupmStream_t stream;

      PetscCall(GetHandles_(&stream));
      PetscCallCUPM(cupmFreeAsync(VecCUPMCast(w)->array_d, stream));
      PetscCall(PetscFree(w->spptr));
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

438: // ==========================================================================================
439: // VecSeq_CUPM - Public API - Compute Methods
440: // ==========================================================================================

// VecAYPXAsync_Private
// y = x + alpha*y. Special cases: alpha == 0 degenerates to a device-to-device copy of x
// into y; alpha == 1 is a single axpy; otherwise scale y by alpha then add x.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPXAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
{
  const auto n = static_cast<cupmBlasInt_t>(yin->map->n);
  PetscBool  xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!xiscupm) {
    // x is not a device vector, fall back to the host implementation
    PetscCall(VecAYPX_Seq(yin, alpha, xin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (alpha == PetscScalar(0.0)) {
    cupmStream_t stream;

    // alpha == 0: y = x, a plain memcpy on the stream
    PetscCall(GetHandlesFrom_(dctx, &stream));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCall(PetscCUPMMemcpyAsync(DeviceArrayWrite(dctx, yin).data(), DeviceArrayRead(dctx, xin).data(), n, cupmMemcpyDeviceToDevice, stream));
    PetscCall(PetscLogGpuTimeEnd());
  } else if (n) {
    const auto       alphaIsOne = alpha == PetscScalar(1.0);
    const auto       calpha     = cupmScalarPtrCast(&alpha);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    {
      // scope the array accessors so they are released before flop logging
      const auto yptr = DeviceArrayReadWrite(dctx, yin);
      const auto xptr = DeviceArrayRead(dctx, xin);

      PetscCall(PetscLogGpuTimeBegin());
      if (alphaIsOne) {
        // y += x
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, calpha, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      } else {
        const auto one = cupmScalarCast(1.0);

        // y = alpha*y, then y += x
        PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, calpha, yptr.cupmdata(), 1));
        PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, &one, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
      }
      PetscCall(PetscLogGpuTimeEnd());
    }
    PetscCall(PetscLogGpuFlops((alphaIsOne ? 1 : 2) * n));
  }
  if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->aypx
// Synchronous entry point: AYPXAsync() with a null (default) device context
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AYPX(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AYPXAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecAXPYAsync_Private
// y += alpha*x via cupmBlasXaxpy; falls back to VecAXPY_Seq() when x is not a device vector
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPYAsync(Vec yin, PetscScalar alpha, Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscBool xiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (xiscupm) {
    const auto       n = static_cast<cupmBlasInt_t>(yin->map->n);
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    // one multiply and one add per element
    PetscCall(PetscLogGpuFlops(2 * n));
    if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    PetscCall(VecAXPY_Seq(yin, alpha, xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpy
// Synchronous entry point: AXPYAsync() with a null (default) device context
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPY(Vec yin, PetscScalar alpha, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPYAsync(yin, alpha, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

533: namespace detail
534: {

536: struct divides {
537:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return rhs == PetscScalar{0.0} ? rhs : lhs / rhs; }
538: };

540: } // namespace detail

// VecPointwiseDivideAsync_Private
// w = x ./ y (zero denominators map to zero, see detail::divides)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivideAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseDivide_Seq, detail::divides{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisedivide
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseDivide(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseDivideAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecPointwiseMultAsync_Private
// w = x .* y, using thrust's stock multiplies functor
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMultAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMult_Seq, thrust::multiplies<PetscScalar>{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemult
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMult(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMultAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

578: namespace detail
579: {

581: struct MaximumRealPart {
582:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
583: };

585: } // namespace detail

// VecPointwiseMaxAsync_Private
// w[i] = max(Re(x[i]), Re(y[i]))
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMax_Seq, detail::MaximumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemax
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMax(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

605: namespace detail
606: {

608: struct MaximumAbsoluteValue {
609:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::maximum<PetscReal>{}(PetscAbsScalar(lhs), PetscAbsScalar(rhs)); }
610: };

612: } // namespace detail

// VecPointwiseMaxAbsAsync_Private
// w[i] = max(|x[i]|, |y[i]|)
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbsAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMaxAbs_Seq, detail::MaximumAbsoluteValue{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemaxabs
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMaxAbs(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMaxAbsAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

632: namespace detail
633: {

635: struct MinimumRealPart {
636:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &lhs, const PetscScalar &rhs) const noexcept { return thrust::minimum<PetscReal>{}(PetscRealPart(lhs), PetscRealPart(rhs)); }
637: };

639: } // namespace detail

// VecPointwiseMinAsync_Private
// w[i] = min(Re(x[i]), Re(y[i]))
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMinAsync(Vec wout, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseBinaryDispatch_(VecPointwiseMin_Seq, detail::MinimumRealPart{}, wout, xin, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->pointwisemin
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseMin(Vec wout, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseMinAsync(wout, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

659: namespace detail
660: {

662: struct Reciprocal {
663:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept
664:   {
665:     // yes all of this verbosity is needed because sometimes PetscScalar is a thrust::complex
666:     // and then it matters whether we do s ? true : false vs s == 0, as well as whether we wrap
667:     // everything in PetscScalar...
668:     return s == PetscScalar{0.0} ? s : PetscScalar{1.0} / s;
669:   }
670: };

672: } // namespace detail

// VecReciprocalAsync_Private
// x[i] = 1/x[i] in place (zeros stay zero); yout == nullptr selects the in-place path
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ReciprocalAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::Reciprocal{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->reciprocal
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Reciprocal(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ReciprocalAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

692: namespace detail
693: {

695: struct AbsoluteValue {
696:   PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscAbsScalar(s); }
697: };

699: } // namespace detail

// VecAbsAsync_Private
// x[i] = |x[i]| in place
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(detail::AbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->abs
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Abs(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Unary functors for the three VecSignMode variants of VecPointwiseSign(); each defers to
// the shared VecSign*_Private() helper on the real part of the input

struct SignZeroToSignedUnit {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToSignedUnit_Private(PetscRealPart(s)); }
};

struct SignZeroToZero {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToZero_Private(PetscRealPart(s)); }
};

struct SignZeroToSignedZero {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return VecSignZeroToSignedZero_Private(PetscRealPart(s)); }
};

} // namespace detail

// VecPointwiseSignAsync_Private
// y[i] = sign(x[i]) where the treatment of zero entries is selected by sign_type; an
// unrecognized sign_type falls through the switch and is a silent no-op
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::PointwiseSignAsync(Vec yout, Vec xin, VecSignMode sign_type, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  switch (sign_type) {
  case VEC_SIGN_ZERO_TO_ZERO:
    PetscCall(PointwiseUnary_(detail::SignZeroToZero{}, xin, yout, dctx));
    break;
  case VEC_SIGN_ZERO_TO_SIGNED_ZERO:
    PetscCall(PointwiseUnary_(detail::SignZeroToSignedZero{}, xin, yout, dctx));
    break;
  case VEC_SIGN_ZERO_TO_SIGNED_UNIT:
    PetscCall(PointwiseUnary_(detail::SignZeroToSignedUnit{}, xin, yout, dctx));
    break;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Elementwise functor: maps s -> sqrt(|s|) (result is real even for complex scalars)
struct SquareRootAbsoluteValue {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscSqrtReal(PetscAbsScalar(s)); }
};

} // namespace detail

// VecSqrtAbsAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbsAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  // x[i] <- sqrt(|x[i]|), in place (nullptr output vec)
  PetscCall(PointwiseUnary_(detail::SquareRootAbsoluteValue{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->sqrt
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SqrtAbs(Vec xin) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(SqrtAbsAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Elementwise functor: maps s -> exp(s)
struct Exponent {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscExpScalar(s); }
};

} // namespace detail

// VecExpAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ExpAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  // x[i] <- exp(x[i]), in place (nullptr output vec)
  PetscCall(PointwiseUnary_(detail::Exponent{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->exp
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Exp(Vec xin) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(ExpAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace detail
{

// Elementwise functor: maps s -> log(s)
struct Logarithm {
  PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &s) const noexcept { return PetscLogScalar(s); }
};

} // namespace detail

// VecLogAsync_Private
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::LogAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  // x[i] <- log(x[i]), in place (nullptr output vec)
  PetscCall(PointwiseUnary_(detail::Logarithm{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->log
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Log(Vec xin) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(LogAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

836: // v->ops->waxpy
837: template <device::cupm::DeviceType T>
838: inline PetscErrorCode VecSeq_CUPM<T>::WAXPYAsync(Vec win, PetscScalar alpha, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
839: {
840:   PetscBool xiscupm, yiscupm;

842:   PetscFunctionBegin;
843:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
844:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
845:   if (!xiscupm || !yiscupm) {
846:     PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
847:     PetscFunctionReturn(PETSC_SUCCESS);
848:   }
849:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
850:   if (alpha == PetscScalar(0.0)) {
851:     PetscCall(CopyAsync(yin, win, dctx));
852:   } else if (const auto n = static_cast<cupmBlasInt_t>(win->map->n)) {
853:     cupmBlasHandle_t cupmBlasHandle;
854:     cupmStream_t     stream;
855:     PetscBool        xiscupm, yiscupm;

857:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
858:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
859:     if (!xiscupm || !yiscupm) {
860:       PetscCall(VecWAXPY_Seq(win, alpha, xin, yin));
861:       PetscFunctionReturn(PETSC_SUCCESS);
862:     }
863:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle, NULL, &stream));
864:     {
865:       const auto wptr = DeviceArrayWrite(dctx, win);

867:       PetscCall(PetscLogGpuTimeBegin());
868:       PetscCall(PetscCUPMMemcpyAsync(wptr.data(), DeviceArrayRead(dctx, yin).data(), n, cupmMemcpyDeviceToDevice, stream, true));
869:       PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayRead(dctx, xin), 1, wptr.cupmdata(), 1));
870:       PetscCall(PetscLogGpuTimeEnd());
871:     }
872:     PetscCall(PetscLogGpuFlops(2 * n));
873:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
874:   }
875:   PetscFunctionReturn(PETSC_SUCCESS);
876: }

// v->ops->waxpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::WAXPY(Vec win, PetscScalar alpha, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(WAXPYAsync(win, alpha, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

namespace kernels
{

// Fused multiply-add kernel: x[i] += sum_j aptr[j] * yptr_j[i] for a compile-time
// pack of N y-vectors. The N alpha coefficients are staged once into shared memory
// so each element iteration reads them from fast storage rather than global memory.
template <typename... Args>
PETSC_KERNEL_DECL static void MAXPY_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT xptr, const PetscScalar *PETSC_RESTRICT aptr, Args... yptr)
{
  constexpr int      N        = sizeof...(Args);
  const auto         tx       = threadIdx.x;
  const PetscScalar *yptr_p[] = {yptr...};

  PETSC_SHAREDMEM_DECL PetscScalar aptr_shmem[N];

  // load a to shared memory
  if (tx < N) aptr_shmem[tx] = aptr[tx];
  __syncthreads();

  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
  // these may look the same but give different results!
  // (the two variants accumulate in a different order, so floating-point rounding differs;
  // the enabled variant seeds the sum with x[i] instead of adding it at the end)
#if 0
    PetscScalar sum = 0.0;

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j]*yptr_p[j][i];
    xptr[i] += sum;
#else
    auto sum = xptr[i];

  #pragma unroll
    for (auto j = 0; j < N; ++j) sum += aptr_shmem[j] * yptr_p[j][i];
    xptr[i] = sum;
#endif
  });
  return;
}

} // namespace kernels

namespace detail
{

// a helper-struct to gobble the size_t input, it is used with template parameter pack
// expansion such that
// typename repeat_type...
// expands to
// MyType, MyType, MyType, ... [repeated sizeof...(IdxParamPack) times]
//
// i.e. repeat_type<T, I>::type is always T; the second (index) parameter exists only
// to drive the pack expansion
template <typename T, std::size_t>
struct repeat_type {
  using type = T;
};

} // namespace detail

// Launch MAXPY_kernel with exactly sizeof...(Idx) y-vectors: the index sequence is
// expanded twice, once to name N repeated `const PetscScalar *` template arguments and
// once to pass the N device pointers of yin[0..N-1].
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // clang-format off
  PetscCall(
    PetscCUPMLaunchKernel1D(
      size, 0, stream,
      kernels::MAXPY_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      size, xptr, aptr, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Dispatch a batch of N y-vectors starting at offset yidx, then advance yidx by N so the
// caller's loop can continue with the remaining vectors.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, PetscScalar *xptr, const PetscScalar *aptr, const Vec *yin, PetscInt size, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MAXPY_kernel_dispatch_(dctx, stream, xptr, aptr + yidx, yin + yidx, size, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecMAXPYAsync_Private
//
// x <- x + sum_i alpha[i] * y[i] for nv vectors. Falls back to the host implementation
// if any y is not a CUPM vector. The alpha coefficients are copied to device memory once,
// then the y-vectors are processed in batches of up to 8 per kernel launch.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPYAsync(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin, PetscDeviceContext dctx) noexcept
{
  const auto   n = xin->map->n;
  cupmStream_t stream;
  PetscBool    yiscupm = PETSC_TRUE;

  PetscFunctionBegin;
  // loop exits early (via the && yiscupm condition) as soon as one y fails the type check
  for (PetscInt i = 0; i < nv && yiscupm; i++) PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin[i]), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecMAXPY_Seq(xin, nv, alpha, yin));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(GetHandlesFrom_(dctx, &stream));
  {
    const auto   xptr    = DeviceArrayReadWrite(dctx, xin);
    PetscScalar *d_alpha = nullptr;
    PetscInt     yidx    = 0;

    // placement of early-return is deliberate, we would like to capture the
    // DeviceArrayReadWrite() call (which calls PetscObjectStateIncreate()) before we bail
    if (!n || !nv) PetscFunctionReturn(PETSC_SUCCESS);
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_alpha));
    PetscCall(PetscCUPMMemcpyAsync(d_alpha, alpha, nv, cupmMemcpyHostToDevice, stream));
    PetscCall(PetscLogGpuTimeBegin());
    // process the y-vectors in chunks: 8 at a time while >= 8 remain, then one launch
    // sized exactly to the 1..7 leftovers (each dispatch advances yidx by its batch size)
    do {
      switch (nv - yidx) {
      case 7:
        PetscCall(MAXPY_kernel_dispatch_<7>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 6:
        PetscCall(MAXPY_kernel_dispatch_<6>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 5:
        PetscCall(MAXPY_kernel_dispatch_<5>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 4:
        PetscCall(MAXPY_kernel_dispatch_<4>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 3:
        PetscCall(MAXPY_kernel_dispatch_<3>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 2:
        PetscCall(MAXPY_kernel_dispatch_<2>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      case 1:
        PetscCall(MAXPY_kernel_dispatch_<1>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      default: // 8 or more
        PetscCall(MAXPY_kernel_dispatch_<8>(dctx, stream, xptr.data(), d_alpha, yin, n, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceFree(dctx, d_alpha));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  }
  PetscCall(PetscLogGpuFlops(nv * 2 * n));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->maxpy
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MAXPY(Vec xin, PetscInt nv, const PetscScalar *alpha, Vec *yin) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(MAXPYAsync(xin, nv, alpha, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->dot: z = conj(y)^T x via BLAS dot. Falls back to the host implementation when
// y is not a CUPM vector (x is this class's own vector so needs no check).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Dot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecDot_Seq(xin, yin, z));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    // arguments y, x are reversed because BLAS complex conjugates the first argument, PETSc the
    // second
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdot(cupmBlasHandle, n, DeviceArrayRead(dctx, yin), 1, DeviceArrayRead(dctx, xin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    // empty local vector: the dot product is trivially zero
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1066: #define MDOT_WORKGROUP_NUM  128
1067: #define MDOT_WORKGROUP_SIZE MDOT_WORKGROUP_NUM

namespace kernels
{

// Number of vector entries assigned to one block: ceil(size / gridDim.x), clamped to >= 1
PETSC_DEVICE_INLINE_DECL static PetscInt EntriesPerGroup(const PetscInt size) noexcept
{
  const auto group_entries = (size - 1) / gridDim.x + 1;
  // for very small vectors, a group should still do some work
  return group_entries ? group_entries : 1;
}

// Computes N simultaneous (unconjugated) dot products y_j . x for a compile-time pack of
// N y-vectors. Each block reduces its slice of the vectors in shared memory; block bx
// writes its partial sum for vector j to results[bx + j * gridDim.x], leaving the final
// cross-block reduction to sum_kernel below.
template <typename... ConstPetscScalarPointer>
PETSC_KERNEL_DECL static void MDot_kernel(const PetscScalar *PETSC_RESTRICT x, const PetscInt size, PetscScalar *PETSC_RESTRICT results, ConstPetscScalarPointer... y)
{
  constexpr int      N        = sizeof...(ConstPetscScalarPointer);
  const PetscScalar *ylocal[] = {y...};
  PetscScalar        sumlocal[N];

  PETSC_SHAREDMEM_DECL PetscScalar shmem[N * MDOT_WORKGROUP_SIZE];

  // HIP -- for whatever reason -- has threadIdx, blockIdx, blockDim, and gridDim as separate
  // types, so each of these go on separate lines...
  const auto tx       = threadIdx.x;
  const auto bx       = blockIdx.x;
  const auto bdx      = blockDim.x;
  const auto gdx      = gridDim.x;
  const auto worksize = EntriesPerGroup(size);
  const auto begin    = tx + bx * worksize;
  const auto end      = min((bx + 1) * worksize, size);

#pragma unroll
  for (auto i = 0; i < N; ++i) sumlocal[i] = 0;

  for (auto i = begin; i < end; i += bdx) {
    const auto xi = x[i]; // load only once from global memory!

#pragma unroll
    for (auto j = 0; j < N; ++j) sumlocal[j] += ylocal[j][i] * xi;
  }

  // stage each thread's N partial sums into shared memory for the tree reduction
#pragma unroll
  for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] = sumlocal[i];

  // parallel reduction
  for (auto stride = bdx / 2; stride > 0; stride /= 2) {
    __syncthreads();
    if (tx < stride) {
#pragma unroll
      for (auto i = 0; i < N; ++i) shmem[tx + i * MDOT_WORKGROUP_SIZE] += shmem[tx + stride + i * MDOT_WORKGROUP_SIZE];
    }
  }
  // bottom N threads per block write to global memory
  // REVIEW ME: I am ~pretty~ sure we don't need another __syncthreads() here since each thread
  // writes to the same sections in the above loop that it is about to read from below, but
  // running this under the racecheck tool of compute-sanitizer reports a write-after-write hazard.
  __syncthreads();
  if (tx < N) results[bx + tx * gdx] = shmem[tx * MDOT_WORKGROUP_SIZE];
  return;
}

namespace
{

// Final reduction stage: collapses the per-block partial sums produced by MDot_kernel so
// that results[i] ends up holding the full dot product for vector i. Works in place:
// partial sums are first collected into a thread-local buffer, then written back.
// NOTE(review): local_results has room for 8 entries, so this assumes each thread is
// assigned at most 8 grid-stride iterations by the launch configuration -- confirm
// against the PetscCUPMLaunchKernel1D(nv, ...) call in MDot_.
PETSC_KERNEL_DECL void sum_kernel(const PetscInt size, PetscScalar *PETSC_RESTRICT results)
{
  int         local_i = 0;
  PetscScalar local_results[8];

  // each thread sums up MDOT_WORKGROUP_NUM entries of the result, storing it in a local buffer
  //
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  // | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | ...
  // *-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*-*
  //  |  ______________________________________________________/
  //  | /            <- MDOT_WORKGROUP_NUM ->
  //  |/
  //  +
  //  v
  // *-*-*
  // | | | ...
  // *-*-*
  //
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    PetscScalar z_sum = 0;

    for (auto j = i * MDOT_WORKGROUP_SIZE; j < (i + 1) * MDOT_WORKGROUP_SIZE; ++j) z_sum += results[j];
    local_results[local_i++] = z_sum;
  });
  // if we needed more than 1 workgroup to handle the vector we should sync since other threads
  // may currently be reading from results
  if (size >= MDOT_WORKGROUP_SIZE) __syncthreads();
  // Local buffer is now written to global memory
  ::Petsc::device::cupm::kernels::util::grid_stride_1D(size, [&](PetscInt i) {
    const auto j = --local_i;

    if (j >= 0) results[i] = local_results[j];
  });
  return;
}

} // namespace

#if PetscDefined(USING_HCC)
namespace do_not_use
{

// HIP-only: reference sum_kernel so the compiler does not warn about (and discard) it
inline void silence_warning_function_sum_kernel_is_not_needed_and_will_not_be_emitted()
{
  (void)sum_kernel;
}

} // namespace do_not_use
#endif

} // namespace kernels

// Launch MDot_kernel with exactly sizeof...(Idx) y-vectors (same double pack-expansion
// trick as MAXPY_kernel_dispatch_): Idx names N repeated pointer template arguments and
// supplies the N device pointers of yin[0..N-1].
template <device::cupm::DeviceType T>
template <std::size_t... Idx>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, util::index_sequence<Idx...>) noexcept
{
  PetscFunctionBegin;
  // REVIEW ME: convert this kernel launch to PetscCUPMLaunchKernel1D(), it currently launches
  // 128 blocks of 128 threads every time which may be wasteful
  // clang-format off
  PetscCallCUPM(
    cupmLaunchKernel(
      kernels::MDot_kernel<typename detail::repeat_type<const PetscScalar *, Idx>::type...>,
      MDOT_WORKGROUP_NUM, MDOT_WORKGROUP_SIZE, 0, stream,
      xarr, size, results, DeviceArrayRead(dctx, yin[Idx]).data()...
    )
  );
  // clang-format on
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Dispatch a batch of N y-vectors starting at yidx; each batch writes its partial sums to
// a disjoint MDOT_WORKGROUP_NUM-wide slice of results, then yidx is advanced by N.
template <device::cupm::DeviceType T>
template <int N>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_kernel_dispatch_(PetscDeviceContext dctx, cupmStream_t stream, const PetscScalar *xarr, const Vec yin[], PetscInt size, PetscScalar *results, PetscInt &yidx) noexcept
{
  PetscFunctionBegin;
  PetscCall(MDot_kernel_dispatch_(dctx, stream, xarr, yin + yidx, size, results + yidx * MDOT_WORKGROUP_NUM, util::make_index_sequence<N>{}));
  yidx += N;
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Real-scalar (std::false_type) multi-dot implementation: batches the y-vectors through the
// custom MDot_kernel in groups of up to 8, optionally spread over forked sub-streams for
// concurrency, then reduces the per-block partial sums with sum_kernel and copies the nv
// results back to the host buffer z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::false_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // the largest possible size of a batch
  constexpr PetscInt batchsize = 8;
  // how many sub streams to create, if nv <= batchsize we can do this without looping, so we
  // do not create substreams. Note we don't create more than 8 streams, in practice we could
  // not get more parallelism with higher numbers.
  const auto   num_sub_streams = nv > batchsize ? std::min((nv + batchsize) / batchsize, batchsize) : 0;
  const auto   n               = xin->map->n;
  const auto   nwork           = nv * MDOT_WORKGROUP_NUM;
  PetscScalar *d_results;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  // allocate scratchpad memory for the results of individual work groups
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nwork, &d_results));
  {
    const auto          xptr       = DeviceArrayRead(dctx, xin);
    PetscInt            yidx       = 0;
    auto                subidx     = 0;
    auto                cur_stream = stream;
    auto                cur_ctx    = dctx;
    PetscDeviceContext *sub        = nullptr;
    PetscStreamType     stype;

    // REVIEW ME: maybe PetscDeviceContextFork() should insert dctx into the first entry of
    // sub. Ideally the parent context should also join in on the fork, but it is extremely
    // fiddly to do so presently
    PetscCall(PetscDeviceContextGetStreamType(dctx, &stype));
    if (stype == PETSC_STREAM_DEFAULT || stype == PETSC_STREAM_DEFAULT_WITH_BARRIER) stype = PETSC_STREAM_NONBLOCKING;
    // If we have a default stream create nonblocking streams instead (as we can
    // locally exploit the parallelism). Otherwise use the prescribed stream type.
    PetscCall(PetscDeviceContextForkWithStreamType(dctx, stype, num_sub_streams, &sub));
    PetscCall(PetscLogGpuTimeBegin());
    do {
      // round-robin the batches over the sub-streams (if any were created)
      if (num_sub_streams) {
        cur_ctx = sub[subidx++ % num_sub_streams];
        PetscCall(GetHandlesFrom_(cur_ctx, &cur_stream));
      }
      // REVIEW ME: Should probably try and load-balance these. Consider the case where nv = 9;
      // it is very likely better to do 4+5 rather than 8+1
      switch (nv - yidx) {
      case 7:
        PetscCall(MDot_kernel_dispatch_<7>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 6:
        PetscCall(MDot_kernel_dispatch_<6>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 5:
        PetscCall(MDot_kernel_dispatch_<5>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 4:
        PetscCall(MDot_kernel_dispatch_<4>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 3:
        PetscCall(MDot_kernel_dispatch_<3>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 2:
        PetscCall(MDot_kernel_dispatch_<2>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      case 1:
        PetscCall(MDot_kernel_dispatch_<1>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      default: // 8 or more
        PetscCall(MDot_kernel_dispatch_<8>(cur_ctx, cur_stream, xptr.data(), yin, n, d_results, yidx));
        break;
      }
    } while (yidx < nv);
    PetscCall(PetscLogGpuTimeEnd());
    // join (and destroy) the sub-contexts so dctx observes all batch kernels as complete
    PetscCall(PetscDeviceContextJoin(dctx, num_sub_streams, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &sub));
  }

  PetscCall(PetscCUPMLaunchKernel1D(nv, 0, stream, kernels::sum_kernel, nv, d_results));
  // copy result of device reduction to host
  PetscCall(PetscCUPMMemcpyAsync(z, d_results, nv, cupmMemcpyDeviceToHost, stream));
  // do these now while final reduction is in flight
  PetscCall(PetscLogGpuFlops(nwork));
  PetscCall(PetscDeviceFree(dctx, d_results));
  PetscFunctionReturn(PETSC_SUCCESS);
}

1296: #undef MDOT_WORKGROUP_NUM
1297: #undef MDOT_WORKGROUP_SIZE

// Complex-scalar (std::true_type) multi-dot implementation: issues one cupmBlasXdot per
// y-vector, round-robined over forked sub-contexts, with results accumulated directly in
// device memory (device pointer mode) and copied back to z in one transfer at the end.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot_(std::true_type, Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z, PetscDeviceContext dctx) noexcept
{
  // probably not worth it to run more than 8 of these at a time?
  const auto          n_sub = PetscMin(nv, 8);
  const auto          n     = static_cast<cupmBlasInt_t>(xin->map->n);
  const auto          xptr  = DeviceArrayRead(dctx, xin);
  PetscScalar        *d_z;
  PetscDeviceContext *subctx;
  cupmStream_t        stream;

  PetscFunctionBegin;
  PetscCall(GetHandlesFrom_(dctx, &stream));
  PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), nv, &d_z));
  PetscCall(PetscDeviceContextFork(dctx, n_sub, &subctx));
  PetscCall(PetscLogGpuTimeBegin());
  for (PetscInt i = 0; i < nv; ++i) {
    const auto            sub = subctx[i % n_sub];
    cupmBlasHandle_t      handle;
    cupmBlasPointerMode_t old_mode;

    PetscCall(GetHandlesFrom_(sub, &handle));
    // the dot result is written to device memory (d_z + i), which requires the BLAS
    // handle to be in device pointer mode; save and restore the caller's mode
    PetscCallCUPMBLAS(cupmBlasGetPointerMode(handle, &old_mode));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, CUPMBLAS_POINTER_MODE_DEVICE));
    PetscCallCUPMBLAS(cupmBlasXdot(handle, n, DeviceArrayRead(sub, yin[i]), 1, xptr.cupmdata(), 1, cupmScalarPtrCast(d_z + i)));
    if (old_mode != CUPMBLAS_POINTER_MODE_DEVICE) PetscCallCUPMBLAS(cupmBlasSetPointerMode(handle, old_mode));
  }
  PetscCall(PetscLogGpuTimeEnd());
  PetscCall(PetscDeviceContextJoin(dctx, n_sub, PETSC_DEVICE_CONTEXT_JOIN_DESTROY, &subctx));
  PetscCall(PetscCUPMMemcpyAsync(z, d_z, nv, cupmMemcpyDeviceToHost, stream));
  PetscCall(PetscDeviceFree(dctx, d_z));
  // REVIEW ME: flops?????
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->mdot
//
// z[i] = y[i] . x for nv vectors. Selects the complex (per-dot BLAS) or real (batched
// custom kernel) implementation at compile time via tag dispatch on USE_COMPLEX.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::MDot(Vec xin, PetscInt nv, const Vec yin[], PetscScalar *z) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(nv == 1)) {
    // dot handles nv = 0 correctly
    PetscCall(Dot(xin, const_cast<Vec>(yin[0]), z));
  } else if (const auto n = xin->map->n) {
    PetscDeviceContext dctx;

    PetscCheck(nv > 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "Number of vectors provided to %s %" PetscInt_FMT " not positive", PETSC_FUNCTION_NAME, nv);
    PetscCall(GetHandles_(&dctx));
    PetscCall(MDot_(std::integral_constant<bool, PetscDefined(USE_COMPLEX)>{}, xin, nv, yin, z, dctx));
    // REVIEW ME: double count of flops??
    PetscCall(PetscLogGpuFlops(nv * (2 * n - 1)));
    // block until the async device-to-host copy of z has landed before returning
    PetscCall(PetscDeviceContextSynchronize(dctx));
  } else {
    // empty local vector: every dot product is zero
    PetscCall(PetscArrayzero(z, nv));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecSetAsync_Private
//
// x[i] <- alpha for all i. Uses a raw memset for the common alpha == 0 case and a
// thrust::fill otherwise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
{
  const auto   n = xin->map->n;
  cupmStream_t stream;

  PetscFunctionBegin;
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  PetscCall(GetHandlesFrom_(dctx, &stream));
  {
    const auto xptr = DeviceArrayWrite(dctx, xin);

    if (alpha == PetscScalar(0.0)) {
      // zero-fill is just a byte-wise memset
      PetscCall(PetscCUPMMemsetAsync(xptr.data(), 0, n, stream));
    } else {
      const auto dptr = thrust::device_pointer_cast(xptr.data());

      PetscCallThrust(THRUST_CALL(thrust::fill, stream, dptr, dptr + n, alpha));
    }
  }
  if (n > 0) PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->set
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Set(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(SetAsync(xin, alpha, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecScaleAsync_Private
//
// x <- alpha * x. Short-circuits alpha == 1 (no-op), routes alpha == 0 through SetAsync,
// and otherwise uses BLAS scal.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ScaleAsync(Vec xin, PetscScalar alpha, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (PetscUnlikely(alpha == PetscScalar(1.0))) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (PetscUnlikely(alpha == PetscScalar(0.0))) {
    PetscCall(SetAsync(xin, alpha, dctx));
  } else if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&alpha), DeviceArrayReadWrite(dctx, xin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(n));
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    // empty local vector: still record the logical modification
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->scale
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Scale(Vec xin, PetscScalar alpha) noexcept
{
  PetscFunctionBegin;
  // synchronous entry point: run the async variant on the default (null) device context
  PetscCall(ScaleAsync(xin, alpha, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->tdot
//
// z = y^T x, the *unconjugated* (indefinite) dot product, hence cupmBlasXdotu rather
// than cupmBlasXdot and no argument reversal. Falls back to the host implementation when
// y is not a CUPM vector.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::TDot(Vec xin, Vec yin, PetscScalar *z) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  if (!yiscupm) {
    PetscCall(VecTDot_Seq(xin, yin, z));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    PetscDeviceContext dctx;
    cupmBlasHandle_t   cupmBlasHandle;

    PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXdotu(cupmBlasHandle, n, DeviceArrayRead(dctx, xin), 1, DeviceArrayRead(dctx, yin), 1, cupmScalarPtrCast(z)));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(2 * n - 1));
  } else {
    // empty local vector: the dot product is trivially zero
    *z = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecCopyAsync_Private
//
// y <- x. The copy direction (host/device on each side) is derived from x's offload mask,
// y's offload mask, and whether y is a CUPM vector at all; the resulting cupmMemcpyKind
// then selects between a direct device write and a VecGetArrayWrite round-trip.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::CopyAsync(Vec xin, Vec yout, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (xin == yout) PetscFunctionReturn(PETSC_SUCCESS);
  if (const auto n = xin->map->n) {
    const auto xmask = xin->offloadmask;
    // silence buggy gcc warning: mode may be used uninitialized in this function
    auto         mode = cupmMemcpyDeviceToDevice;
    cupmStream_t stream;

    // translate from PetscOffloadMask to cupmMemcpyKind
    PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
    switch (const auto ymask = yout->offloadmask) {
    case PETSC_OFFLOAD_UNALLOCATED: {
      PetscBool yiscupm;

      // an unallocated CUPM vector may be written wherever x currently lives; a
      // non-CUPM one is handled as if its data were on the CPU (fall-through below)
      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToHost;
        break;
      }
    } // fall-through if unallocated and not cupm
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case PETSC_OFFLOAD_CPU: {
      PetscBool yiscupm;

      PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yout), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
      if (yiscupm) {
        mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToDevice : cupmMemcpyDeviceToDevice;
      } else {
        mode = PetscOffloadHost(xmask) ? cupmMemcpyHostToHost : cupmMemcpyDeviceToHost;
      }
      break;
    }
    case PETSC_OFFLOAD_BOTH:
    case PETSC_OFFLOAD_GPU:
      mode = PetscOffloadDevice(xmask) ? cupmMemcpyDeviceToDevice : cupmMemcpyHostToDevice;
      break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_ARG_INCOMP, "Incompatible offload mask %s", PetscOffloadMaskToString(ymask));
    }

    PetscCall(GetHandlesFrom_(dctx, &stream));
    switch (mode) {
    case cupmMemcpyDeviceToDevice: // the best case
    case cupmMemcpyHostToDevice: { // not terrible
      const auto yptr = DeviceArrayWrite(dctx, yout);
      const auto xptr = mode == cupmMemcpyDeviceToDevice ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();

      PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr, n, mode, stream));
      PetscCall(PetscLogGpuTimeEnd());
    } break;
    case cupmMemcpyDeviceToHost: // not great
    case cupmMemcpyHostToHost: { // worst case
      const auto   xptr = mode == cupmMemcpyDeviceToHost ? DeviceArrayRead(dctx, xin).data() : HostArrayRead(dctx, xin).data();
      PetscScalar *yptr;

      // destination is host memory, so go through the generic VecGetArrayWrite/
      // VecRestoreArrayWrite protocol instead of the device accessors
      PetscCall(VecGetArrayWrite(yout, &yptr));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeBegin());
      PetscCall(PetscCUPMMemcpyAsync(yptr, xptr, n, mode, stream, /* force async */ true));
      if (mode == cupmMemcpyDeviceToHost) PetscCall(PetscLogGpuTimeEnd());
      PetscCall(VecRestoreArrayWrite(yout, &yptr));
    } break;
    default:
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_GPU, "Unknown cupmMemcpyKind %d", static_cast<int>(mode));
    }
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    // empty local vector: still record the logical modification of y
    PetscCall(MaybeIncrementEmptyLocalVec(yout));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->copy
//
// Synchronous entry point: delegates to CopyAsync() with a null device context, which
// CopyAsync() resolves to the default context internally.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Copy(Vec xin, Vec yout) noexcept
{
  PetscFunctionBegin;
  PetscCall(CopyAsync(xin, yout, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecSwapAsync_Private
//
// Swap the contents of x and y on-device via a single cupmBlasXswap() call. Y must also
// be a CUPM vector (checked below); a null dctx is replaced by the default context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SwapAsync(Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscBool yiscupm;

  PetscFunctionBegin;
  // self-swap is a no-op
  if (xin == yin) PetscFunctionReturn(PETSC_SUCCESS);
  PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(yin), &yiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
  PetscCheck(yiscupm, PetscObjectComm(PetscObjectCast(yin)), PETSC_ERR_SUP, "Cannot swap with Y of type %s", PetscObjectCast(yin)->type_name);
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    cupmBlasHandle_t cupmBlasHandle;

    PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
    PetscCall(PetscLogGpuTimeBegin());
    PetscCallCUPMBLAS(cupmBlasXswap(cupmBlasHandle, n, DeviceArrayReadWrite(dctx, xin), 1, DeviceArrayReadWrite(dctx, yin), 1));
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
  } else {
    // empty local vectors still bump their object state so consistency checks pass
    PetscCall(MaybeIncrementEmptyLocalVec(xin));
    PetscCall(MaybeIncrementEmptyLocalVec(yin));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->swap
//
// Synchronous entry point: delegates to SwapAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Swap(Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(SwapAsync(xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

1573: // VecAXPYBYAsync_Private
1574: template <device::cupm::DeviceType T>
1575: inline PetscErrorCode VecSeq_CUPM<T>::AXPBYAsync(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin, PetscDeviceContext dctx) noexcept
1576: {
1577:   PetscBool xiscupm;

1579:   PetscFunctionBegin;
1580:   PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1581:   if (!xiscupm) {
1582:     PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1583:     PetscFunctionReturn(PETSC_SUCCESS);
1584:   }
1585:   PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
1586:   if (alpha == PetscScalar(0.0)) {
1587:     PetscCall(ScaleAsync(yin, beta, dctx));
1588:   } else if (beta == PetscScalar(1.0)) {
1589:     PetscCall(AXPYAsync(yin, alpha, xin, dctx));
1590:   } else if (alpha == PetscScalar(1.0)) {
1591:     PetscCall(AYPXAsync(yin, beta, xin, dctx));
1592:   } else if (const auto n = static_cast<cupmBlasInt_t>(yin->map->n)) {
1593:     PetscBool xiscupm;

1595:     PetscCall(PetscObjectTypeCompareAny(PetscObjectCast(xin), &xiscupm, VECSEQCUPM(), VECMPICUPM(), ""));
1596:     if (!xiscupm) {
1597:       PetscCall(VecAXPBY_Seq(yin, alpha, beta, xin));
1598:       PetscFunctionReturn(PETSC_SUCCESS);
1599:     }

1601:     const auto       betaIsZero = beta == PetscScalar(0.0);
1602:     const auto       aptr       = cupmScalarPtrCast(&alpha);
1603:     cupmBlasHandle_t cupmBlasHandle;

1605:     PetscCall(GetHandlesFrom_(dctx, &cupmBlasHandle));
1606:     {
1607:       const auto xptr = DeviceArrayRead(dctx, xin);

1609:       if (betaIsZero /* beta = 0 */) {
1610:         // here we can get away with purely write-only as we memcpy into it first
1611:         const auto   yptr = DeviceArrayWrite(dctx, yin);
1612:         cupmStream_t stream;

1614:         PetscCall(GetHandlesFrom_(dctx, &stream));
1615:         PetscCall(PetscLogGpuTimeBegin());
1616:         PetscCall(PetscCUPMMemcpyAsync(yptr.data(), xptr.data(), n, cupmMemcpyDeviceToDevice, stream));
1617:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, aptr, yptr.cupmdata(), 1));
1618:       } else {
1619:         const auto yptr = DeviceArrayReadWrite(dctx, yin);

1621:         PetscCall(PetscLogGpuTimeBegin());
1622:         PetscCallCUPMBLAS(cupmBlasXscal(cupmBlasHandle, n, cupmScalarPtrCast(&beta), yptr.cupmdata(), 1));
1623:         PetscCallCUPMBLAS(cupmBlasXaxpy(cupmBlasHandle, n, aptr, xptr.cupmdata(), 1, yptr.cupmdata(), 1));
1624:       }
1625:     }
1626:     PetscCall(PetscLogGpuTimeEnd());
1627:     PetscCall(PetscLogGpuFlops((betaIsZero ? 1 : 3) * n));
1628:     PetscCall(PetscDeviceContextSynchronizeIfWithBarrier_Internal(dctx));
1629:   } else {
1630:     PetscCall(MaybeIncrementEmptyLocalVec(yin));
1631:   }
1632:   PetscFunctionReturn(PETSC_SUCCESS);
1633: }

// v->ops->axpby
//
// Synchronous entry point: delegates to AXPBYAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBY(Vec yin, PetscScalar alpha, PetscScalar beta, Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPBYAsync(yin, alpha, beta, xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecAXPBYPCZAsync_Private
//
// Compute z = alpha*x + beta*y + gamma*z as scale-then-two-axpys. The scale is skipped
// when gamma == 1 since it would be a no-op; the order of the two AXPYs is immaterial
// but both must follow the scale.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZAsync(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PetscDeviceContextGetOptionalNullContext_Internal(&dctx));
  if (gamma != PetscScalar(1.0)) PetscCall(ScaleAsync(zin, gamma, dctx));
  PetscCall(AXPYAsync(zin, alpha, xin, dctx));
  PetscCall(AXPYAsync(zin, beta, yin, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->axpbypcz
//
// Synchronous entry point: delegates to AXPBYPCZAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::AXPBYPCZ(Vec zin, PetscScalar alpha, PetscScalar beta, PetscScalar gamma, Vec xin, Vec yin) noexcept
{
  PetscFunctionBegin;
  PetscCall(AXPBYPCZAsync(zin, alpha, beta, gamma, xin, yin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->norm
//
// Compute the requested norm(s) of x using cuBLAS/hipBLAS level-1 routines:
// asum for NORM_1, nrm2 for NORM_2/NORM_FROBENIUS, amax + a single-element device-to-host
// copy for NORM_INFINITY. NORM_1_AND_2 writes z[0] (1-norm) and z[1] (2-norm) by falling
// through the switch after bumping z.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Norm(Vec xin, NormType type, PetscReal *z) noexcept
{
  PetscDeviceContext dctx;
  cupmBlasHandle_t   cupmBlasHandle;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &cupmBlasHandle));
  if (const auto n = static_cast<cupmBlasInt_t>(xin->map->n)) {
    const auto xptr      = DeviceArrayRead(dctx, xin);
    PetscInt   flopCount = 0;

    PetscCall(PetscLogGpuTimeBegin());
    switch (type) {
    case NORM_1_AND_2:
    case NORM_1:
      PetscCallCUPMBLAS(cupmBlasXasum(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount = std::max(n - 1, 0); // n - 1 additions (clamped for n == 0 safety)
      if (type == NORM_1) break;
      ++z; // fall-through: the 2-norm of NORM_1_AND_2 goes into z[1]
#if PETSC_CPP_VERSION >= 17
      [[fallthrough]];
#endif
    case NORM_2:
    case NORM_FROBENIUS:
      PetscCallCUPMBLAS(cupmBlasXnrm2(cupmBlasHandle, n, xptr.cupmdata(), 1, cupmRealPtrCast(z)));
      flopCount += std::max(2 * n - 1, 0); // += in case we've fallen through from NORM_1_AND_2
      break;
    case NORM_INFINITY: {
      cupmBlasInt_t max_loc = 0;
      PetscScalar   xv      = 0.;
      cupmStream_t  stream;

      PetscCall(GetHandlesFrom_(dctx, &stream));
      // amax returns a 1-based index (BLAS convention), hence the "- 1" below
      PetscCallCUPMBLAS(cupmBlasXamax(cupmBlasHandle, n, xptr.cupmdata(), 1, &max_loc));
      // NOTE(review): xv is read immediately below; this relies on PetscCUPMMemcpyAsync()
      // synchronizing device-to-host copies when not forced async -- confirm against its
      // implementation
      PetscCall(PetscCUPMMemcpyAsync(&xv, xptr.data() + max_loc - 1, 1, cupmMemcpyDeviceToHost, stream));
      *z = PetscAbsScalar(xv);
      // REVIEW ME: flopCount = ???
    } break;
    }
    PetscCall(PetscLogGpuTimeEnd());
    PetscCall(PetscLogGpuFlops(flopCount));
  } else {
    // empty vector: all norms are zero; the second store also covers z[1] for NORM_1_AND_2
    z[0]                    = 0.0;
    z[type == NORM_1_AND_2] = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1715: namespace detail
1716: {

// Common base for the weighted-error-norm transform functors below. Holds the
// "ignore" threshold and provides the per-entry norm contribution helper.
template <NormType wnormtype>
class ErrorWNormTransformBase {
public:
  // (norm, norma, normr, norm_loc, norma_loc, normr_loc) accumulator layout
  using result_type = thrust::tuple<PetscReal, PetscReal, PetscReal, PetscInt, PetscInt, PetscInt>;

  constexpr explicit ErrorWNormTransformBase(PetscReal v) noexcept : ignore_max_{v} { }

protected:
  // a single entry's contribution: its (possibly squared) scaled error and a 0/1 count
  struct NormTuple {
    PetscReal norm;
    PetscInt  loc;
  };

  // err/tol contribution for one entry; entries with non-positive tolerance contribute
  // nothing (norm 0, count 0). For NORM_INFINITY the raw ratio is kept (reduced with
  // max later); otherwise the ratio is squared (summed, then sqrt'd by the caller).
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL static NormTuple compute_norm_(PetscReal err, PetscReal tol) noexcept
  {
    if (tol > 0.) {
      const auto val = err / tol;

      return {wnormtype == NORM_INFINITY ? val : PetscSqr(val), 1};
    } else {
      return {0.0, 0};
    }
  }

  // entries whose magnitude falls below this threshold are skipped entirely
  PetscReal ignore_max_;
};

// Transform functor for the E == nullptr case: the error is computed as |u - y| from the
// zipped (u, y, atol, rtol) input tuple.
template <NormType wnormtype>
struct ErrorWNormTransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  // (u, y, absolute tolerance, relative tolerance) for one entry
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto u     = thrust::get<0>(x); // with x.get<0>(), cuda-12.4.0 gives error: class "cuda::std::__4::tuple" has no member "get"
    const auto y     = thrust::get<1>(x);
    const auto au    = PetscAbsScalar(u);
    const auto ay    = PetscAbsScalar(y);
    // entries below the ignore threshold get zero tolerances, hence zero contributions
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<2>(x));
    const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x)) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(u - y);
    // separate contributions for the absolute-, relative-, and combined-tolerance norms
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

// Transform functor for the case where an explicit error vector E is supplied: the error
// is |e| from the zipped (u, y, e, atol, rtol) input tuple instead of |u - y|.
template <NormType wnormtype>
struct ErrorWNormETransform : ErrorWNormTransformBase<wnormtype> {
  using base_type     = ErrorWNormTransformBase<wnormtype>;
  using result_type   = typename base_type::result_type;
  // (u, y, e, absolute tolerance, relative tolerance) for one entry
  using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar, PetscScalar, PetscScalar>;

  using base_type::base_type;

  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL result_type operator()(const argument_type &x) const noexcept
  {
    const auto au    = PetscAbsScalar(thrust::get<0>(x));
    const auto ay    = PetscAbsScalar(thrust::get<1>(x));
    // entries below the ignore threshold get zero tolerances, hence zero contributions
    const auto skip  = au < this->ignore_max_ || ay < this->ignore_max_;
    const auto tola  = skip ? 0.0 : PetscRealPart(thrust::get<3>(x));
    const auto tolr  = skip ? 0.0 : PetscRealPart(thrust::get<4>(x)) * PetscMax(au, ay);
    const auto tol   = tola + tolr;
    const auto err   = PetscAbsScalar(thrust::get<2>(x));
    const auto tup_a = this->compute_norm_(err, tola);
    const auto tup_r = this->compute_norm_(err, tolr);
    const auto tup_n = this->compute_norm_(err, tol);

    return {tup_n.norm, tup_a.norm, tup_r.norm, tup_n.loc, tup_a.loc, tup_r.loc};
  }
};

1797: template <NormType wnormtype>
1798: struct ErrorWNormReduce {
1799:   using value_type = typename ErrorWNormTransformBase<wnormtype>::result_type;

1801:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept
1802:   {
1803:     // cannot use lhs.get<0>() etc since the using decl above ambiguates the fact that
1804:     // result_type is a template, so in order to fix this we would need to write:
1805:     //
1806:     // lhs.template get<0>()
1807:     //
1808:     // which is unseemly.
1809:     if (wnormtype == NORM_INFINITY) {
1810:       // clang-format off
1811:       return {
1812:         PetscMax(thrust::get<0>(lhs), thrust::get<0>(rhs)),
1813:         PetscMax(thrust::get<1>(lhs), thrust::get<1>(rhs)),
1814:         PetscMax(thrust::get<2>(lhs), thrust::get<2>(rhs)),
1815:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1816:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1817:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1818:       };
1819:       // clang-format on
1820:     } else {
1821:       // clang-format off
1822:       return {
1823:         thrust::get<0>(lhs) + thrust::get<0>(rhs),
1824:         thrust::get<1>(lhs) + thrust::get<1>(rhs),
1825:         thrust::get<2>(lhs) + thrust::get<2>(rhs),
1826:         thrust::get<3>(lhs) + thrust::get<3>(rhs),
1827:         thrust::get<4>(lhs) + thrust::get<4>(rhs),
1828:         thrust::get<5>(lhs) + thrust::get<5>(rhs)
1829:       };
1830:       // clang-format on
1831:     }
1832:   }
1833: };

// Zip the supplied iterator tuples and run a single fused transform_reduce that produces
// all three weighted norms (combined, absolute-only, relative-only) and their entry
// counts in one pass. For NORM_2 the summed squares are square-rooted before returning.
// WNormTransformType selects between the E-vector and u-minus-y transforms above.
template <template <NormType> class WNormTransformType, typename Tuple, typename cupmStream_t>
inline PetscErrorCode ExecuteWNorm(Tuple &&first, Tuple &&last, NormType wnormtype, cupmStream_t stream, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  auto      begin = thrust::make_zip_iterator(std::forward<Tuple>(first));
  auto      end   = thrust::make_zip_iterator(std::forward<Tuple>(last));
  // zero-initialized accumulator seeds for the reduction
  PetscReal n = 0, na = 0, nr = 0;
  PetscInt  n_loc = 0, na_loc = 0, nr_loc = 0;

  PetscFunctionBegin;
  // clang-format off
  if (wnormtype == NORM_INFINITY) {
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_INFINITY>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_INFINITY>{}
      )
    );
  } else {
    // anything that is not NORM_INFINITY is treated as a 2-norm-style reduction
    PetscCallThrust(
      thrust::tie(*norm, *norma, *normr, *norm_loc, *norma_loc, *normr_loc) = THRUST_CALL(
        thrust::transform_reduce,
        stream,
        std::move(begin),
        std::move(end),
        WNormTransformType<NORM_2>{ignore_max},
        thrust::make_tuple(n, na, nr, n_loc, na_loc, nr_loc),
        ErrorWNormReduce<NORM_2>{}
      )
    );
  }
  // clang-format on
  if (wnormtype == NORM_2) {
    // the reduction summed squared ratios; finish the 2-norm here
    *norm  = PetscSqrtReal(*norm);
    *norma = PetscSqrtReal(*norma);
    *normr = PetscSqrtReal(*normr);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

1879: } // namespace detail

// v->ops->errorwnorm
//
// Weighted error norms between U and Y for adaptive time-stepping. E (explicit error
// vector), vatol, and vrtol (vector-valued tolerances) are each optional; scalar atol/
// rtol are broadcast via constant iterators when the vector form is absent. The 8-way
// dispatch below exists because the zip-tuple types differ per combination.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ErrorWnorm(Vec U, Vec Y, Vec E, NormType wnormtype, PetscReal atol, Vec vatol, PetscReal rtol, Vec vrtol, PetscReal ignore_max, PetscReal *norm, PetscInt *norm_loc, PetscReal *norma, PetscInt *norma_loc, PetscReal *normr, PetscInt *normr_loc) noexcept
{
  const auto         nl  = U->map->n;
  auto               ait = thrust::make_constant_iterator(static_cast<PetscScalar>(atol));
  auto               rit = thrust::make_constant_iterator(static_cast<PetscScalar>(rtol));
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // a null Vec yields a null device pointer (only dereferenced on branches where the
    // corresponding Vec is non-null)
    const auto ConditionalDeviceArrayRead = [&](Vec v) {
      if (v) {
        return thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
      } else {
        return thrust::device_ptr<PetscScalar>{nullptr};
      }
    };

    // keep the RAII read-handles alive for the duration of all thrust calls below
    const auto uarr = DeviceArrayRead(dctx, U);
    const auto yarr = DeviceArrayRead(dctx, Y);
    const auto uptr = thrust::device_pointer_cast(uarr.data());
    const auto yptr = thrust::device_pointer_cast(yarr.data());
    const auto eptr = ConditionalDeviceArrayRead(E);
    const auto rptr = ConditionalDeviceArrayRead(vrtol);
    const auto aptr = ConditionalDeviceArrayRead(vatol);

    // scalar atol, scalar rtol
    if (!vatol && !vrtol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rit),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rit),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    // scalar atol, vector rtol
    } else if (!vatol) {
      if (E) {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormETransform>(
            thrust::make_tuple(uptr, yptr, eptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCall(
          detail::ExecuteWNorm<detail::ErrorWNormTransform>(
            thrust::make_tuple(uptr, yptr, ait, rptr),
            thrust::make_tuple(uptr + nl, yptr + nl, ait, rptr + nl),
            wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
          )
        );
        // clang-format on
      }
    // vector atol, scalar rtol
    } else if (!vrtol) {
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rit),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rit),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    // vector atol, vector rtol
    } else {
      if (E) {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormETransform>(
              thrust::make_tuple(uptr, yptr, eptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, eptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      } else {
        // clang-format off
          PetscCall(
            detail::ExecuteWNorm<detail::ErrorWNormTransform>(
              thrust::make_tuple(uptr, yptr, aptr, rptr),
              thrust::make_tuple(uptr + nl, yptr + nl, aptr + nl, rptr + nl),
              wnormtype, stream, ignore_max, norm, norm_loc, norma, norma_loc, normr, normr_loc
            )
          );
        // clang-format on
      }
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

2003: namespace detail
2004: {
2005: struct dotnorm2_mult {
2006:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscScalar, PetscScalar> operator()(const PetscScalar &s, const PetscScalar &t) const noexcept
2007:   {
2008:     const auto conjt = PetscConj(t);

2010:     return {s * conjt, t * conjt};
2011:   }
2012: };

2014: // it is positively __bananas__ that thrust does not define default operator+ for tuples... I
2015: // would do it myself but now I am worried that they do so on purpose...
2016: struct dotnorm2_tuple_plus {
2017:   using value_type = thrust::tuple<PetscScalar, PetscScalar>;

2019:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL value_type operator()(const value_type &lhs, const value_type &rhs) const noexcept { return {thrust::get<0>(lhs) + thrust::get<0>(rhs), thrust::get<1>(lhs) + thrust::get<1>(rhs)}; }
2020: };

2022: } // namespace detail

// v->ops->dotnorm2
//
// Compute dp = conj(t) . s and nm = conj(t) . t in a single fused pass over both
// vectors, using thrust::inner_product with the tuple-valued functors above.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::DotNorm2(Vec s, Vec t, PetscScalar *dp, PetscScalar *nm) noexcept
{
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  {
    // zero seeds for the (dot, norm) accumulator tuple
    PetscScalar dpt = 0.0, nmt = 0.0;
    const auto  sdptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, s).data());

    // clang-format off
    PetscCallThrust(
      thrust::tie(*dp, *nm) = THRUST_CALL(
        thrust::inner_product,
        stream,
        sdptr, sdptr+s->map->n, thrust::device_pointer_cast(DeviceArrayRead(dctx, t).data()),
        thrust::make_tuple(dpt, nmt),
        detail::dotnorm2_tuple_plus{}, detail::dotnorm2_mult{}
      );
    );
    // clang-format on
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

2052: namespace detail
2053: {
2054: struct conjugate {
2055:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const PetscScalar &x) const noexcept { return PetscConj(x); }
2056: };

2058: } // namespace detail

// VecConjugateAsync_Private (the v->ops->conjugate entry point is Conjugate() below)
//
// Conjugate every entry in place; a no-op for real-scalar builds.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ConjugateAsync(Vec xin, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  if (PetscDefined(USE_COMPLEX)) PetscCall(PointwiseUnary_(detail::conjugate{}, xin, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->conjugate
//
// Synchronous entry point: delegates to ConjugateAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Conjugate(Vec xin) noexcept
{
  PetscFunctionBegin;
  PetscCall(ConjugateAsync(xin, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

2078: namespace detail
2079: {

struct real_part {
  // zipped (value, index) overload: take the real part, preserve the index
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL thrust::tuple<PetscReal, PetscInt> operator()(const thrust::tuple<PetscScalar, PetscInt> &x) const noexcept { return {PetscRealPart(thrust::get<0>(x)), thrust::get<1>(x)}; }

  // plain-value overload used when no location is requested
  PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscReal operator()(const PetscScalar &x) const noexcept { return PetscRealPart(x); }
};

2087: // deriving from Operator allows us to "store" an instance of the operator in the class but
2088: // also take advantage of empty base class optimization if the operator is stateless
2089: template <typename Operator>
2090: class tuple_compare : Operator {
2091: public:
2092:   using tuple_type    = thrust::tuple<PetscReal, PetscInt>;
2093:   using operator_type = Operator;

2095:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL tuple_type operator()(const tuple_type &x, const tuple_type &y) const noexcept
2096:   {
2097:     if (op_()(thrust::get<0>(y), thrust::get<0>(x))) {
2098:       // if y is strictly greater/less than x, return y
2099:       return y;
2100:     } else if (thrust::get<0>(y) == thrust::get<0>(x)) {
2101:       // if equal, prefer lower index
2102:       return thrust::get<1>(y) < thrust::get<1>(x) ? y : x;
2103:     }
2104:     // otherwise return x
2105:     return x;
2106:   }

2108: private:
2109:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL const operator_type &op_() const noexcept { return *this; }
2110: };

2112: } // namespace detail

// Shared implementation of Max()/Min(). tuple_ftr combines (value, index) tuples when a
// location is requested (p != nullptr); unary_ftr combines plain values otherwise. The
// caller pre-seeds *m with the identity (PETSC_MIN_REAL / PETSC_MAX_REAL).
template <device::cupm::DeviceType T>
template <typename TupleFuncT, typename UnaryFuncT>
inline PetscErrorCode VecSeq_CUPM<T>::MinMax_(TupleFuncT &&tuple_ftr, UnaryFuncT &&unary_ftr, Vec v, PetscInt *p, PetscReal *m) noexcept
{
  PetscFunctionBegin;
  PetscCheckTypeNames(v, VECSEQCUPM(), VECMPICUPM());
  // -1 signals "no location" (empty vector leaves it that way)
  if (p) *p = -1;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    // needed to:
    // 1. switch between transform_reduce and reduce
    // 2. strip the real_part functor from the arguments
#if PetscDefined(USE_COMPLEX)
  #define THRUST_MINMAX_REDUCE(...) THRUST_CALL(thrust::transform_reduce, __VA_ARGS__)
#else
  #define THRUST_MINMAX_REDUCE(s, b, e, real_part__, ...) THRUST_CALL(thrust::reduce, s, b, e, __VA_ARGS__)
#endif
    {
      const auto vptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());

      if (p) {
        // pair each value with its index so the reduction can report a location
        // clang-format off
        const auto zip = thrust::make_zip_iterator(
          thrust::make_tuple(std::move(vptr), thrust::make_counting_iterator(PetscInt{0}))
        );
        // clang-format on
        // need to use preprocessor conditionals since otherwise thrust complains about not being
        // able to convert a thrust::device_reference to a PetscReal on complex
        // builds...
        // clang-format off
        PetscCallThrust(
          thrust::tie(*m, *p) = THRUST_MINMAX_REDUCE(
            stream, zip, zip + n, detail::real_part{},
            thrust::make_tuple(*m, *p), std::forward<TupleFuncT>(tuple_ftr)
          );
        );
        // clang-format on
      } else {
        // clang-format off
        PetscCallThrust(
          *m = THRUST_MINMAX_REDUCE(
            stream, vptr, vptr + n, detail::real_part{},
            *m, std::forward<UnaryFuncT>(unary_ftr)
          );
        );
        // clang-format on
      }
    }
#undef THRUST_MINMAX_REDUCE
  }
  // REVIEW ME: flops?
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->max
//
// Maximum entry (by real part) and, if p is non-null, its index. Comparator types come
// from cuda::std / cuda for newer CCCL releases and from thrust otherwise.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Max(Vec v, PetscInt *p, PetscReal *m) noexcept
{
#if CCCL_VERSION >= 3001000
  using tuple_functor = detail::tuple_compare<cuda::std::greater<PetscReal>>;
  using unary_functor = cuda::maximum<PetscReal>;
#else
  using tuple_functor = detail::tuple_compare<thrust::greater<PetscReal>>;
  using unary_functor = thrust::maximum<PetscReal>;
#endif

  PetscFunctionBegin;
  // identity element for max; also the result for an empty vector
  *m = PETSC_MIN_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->min
//
// Minimum entry (by real part) and, if p is non-null, its index. Mirrors Max() with the
// opposite comparators.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Min(Vec v, PetscInt *p, PetscReal *m) noexcept
{
#if CCCL_VERSION >= 3001000
  using tuple_functor = detail::tuple_compare<cuda::std::less<PetscReal>>;
  using unary_functor = cuda::minimum<PetscReal>;
#else
  using tuple_functor = detail::tuple_compare<thrust::less<PetscReal>>;
  using unary_functor = thrust::minimum<PetscReal>;
#endif

  PetscFunctionBegin;
  // identity element for min; also the result for an empty vector
  *m = PETSC_MAX_REAL;
  // use {} constructor syntax otherwise most vexing parse
  PetscCall(MinMax_(tuple_functor{}, unary_functor{}, v, p, m));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// v->ops->sum
//
// Sum of all entries via a device-side thrust::reduce (not asum, which would take
// absolute values). An empty vector sums to zero.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Sum(Vec v, PetscScalar *sum) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscDeviceContext dctx;
    cupmStream_t       stream;

    PetscCall(GetHandles_(&dctx, &stream));
    const auto dptr = thrust::device_pointer_cast(DeviceArrayRead(dctx, v).data());
    // REVIEW ME: why not cupmBlasXasum()?
    PetscCallThrust(*sum = THRUST_CALL(thrust::reduce, stream, dptr, dptr + n, PetscScalar{0.0}););
    // REVIEW ME: must be at least n additions
    PetscCall(PetscLogGpuFlops(n));
  } else {
    *sum = 0.0;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

// VecShiftAsync_Private: add the scalar `shift` to every entry in place.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::ShiftAsync(Vec v, PetscScalar shift, PetscDeviceContext dctx) noexcept
{
  PetscFunctionBegin;
  PetscCall(PointwiseUnary_(device::cupm::functors::make_plus_equals(shift), v, nullptr, dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Synchronous entry point: delegates to ShiftAsync() with a null device context.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::Shift(Vec v, PetscScalar shift) noexcept
{
  PetscFunctionBegin;
  PetscCall(ShiftAsync(v, shift, nullptr));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// Fill the vector with random values. When the generator is CURAND the values are
// produced directly into device memory; any other generator fills the host array
// (the offload machinery migrates it to the device on next device access).
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetRandom(Vec v, PetscRandom rand) noexcept
{
  PetscFunctionBegin;
  if (const auto n = v->map->n) {
    PetscBool          iscurand;
    PetscDeviceContext dctx;

    PetscCall(GetHandles_(&dctx));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(rand), PETSCCURAND, &iscurand));
    if (iscurand) PetscCall(PetscRandomGetValues(rand, n, DeviceArrayWrite(dctx, v)));
    else PetscCall(PetscRandomGetValues(rand, n, HostArrayWrite(dctx, v)));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(v));
  }
  // REVIEW ME: flops????
  // REVIEW ME: Timing???
  PetscFunctionReturn(PETSC_SUCCESS);
}

2266: // v->ops->setpreallocation
2267: template <device::cupm::DeviceType T>
2268: inline PetscErrorCode VecSeq_CUPM<T>::SetPreallocationCOO(Vec v, PetscCount ncoo, const PetscInt coo_i[]) noexcept
2269: {
2270:   PetscDeviceContext dctx;

2272:   PetscFunctionBegin;
2273:   PetscCall(GetHandles_(&dctx));
2274:   PetscCall(VecSetPreallocationCOO_Seq(v, ncoo, coo_i));
2275:   PetscCall(SetPreallocationCOO_CUPMBase(v, ncoo, coo_i, dctx));
2276:   PetscFunctionReturn(PETSC_SUCCESS);
2277: }

// v->ops->setvaluescoo: scatter-add (or insert) the COO values v[] into x using the
// mapping (jmap1_d, perm1_d) prepared by SetPreallocationCOO().
//
// If v[] lives in host memory it is first staged into a temporary device buffer so the
// kernel can read it; the staging buffer is freed before returning.
template <device::cupm::DeviceType T>
inline PetscErrorCode VecSeq_CUPM<T>::SetValuesCOO(Vec x, const PetscScalar v[], InsertMode imode) noexcept
{
  // const_cast: vv either aliases the caller's (device-resident) v[] untouched, or is
  // re-pointed below at a freshly allocated device staging buffer we own
  auto               vv = const_cast<PetscScalar *>(v);
  PetscMemType       memtype;
  PetscDeviceContext dctx;
  cupmStream_t       stream;

  PetscFunctionBegin;
  PetscCall(GetHandles_(&dctx, &stream));
  PetscCall(PetscGetMemType(v, &memtype));
  if (PetscMemTypeHost(memtype)) {
    // coo_n was recorded at preallocation time; presumably the element count of v[] --
    // TODO confirm against VecSetPreallocationCOO_Seq
    const auto size = VecIMPLCast(x)->coo_n;

    // If user gave v[] in host, we might need to copy it to device if any
    PetscCall(PetscDeviceMalloc(dctx, PETSC_MEMTYPE_CUPM(), size, &vv));
    PetscCall(PetscCUPMMemcpyAsync(vv, v, size, cupmMemcpyHostToDevice, stream));
  }

  if (const auto n = x->map->n) {
    const auto vcu = VecCUPMCast(x);

    // INSERT_VALUES overwrites every entry, so a write-only device array suffices
    // (avoids migrating stale values to the device); ADD_VALUES needs read-write.
    // The DeviceArray* temporaries stay alive for the full statement, covering the launch.
    PetscCall(PetscCUPMLaunchKernel1D(n, 0, stream, kernels::add_coo_values, vv, n, vcu->jmap1_d, vcu->perm1_d, imode, imode == INSERT_VALUES ? DeviceArrayWrite(dctx, x).data() : DeviceArrayReadWrite(dctx, x).data()));
  } else {
    PetscCall(MaybeIncrementEmptyLocalVec(x));
  }

  // release the staging buffer allocated above (only if we made one)
  if (PetscMemTypeHost(memtype)) PetscCall(PetscDeviceFree(dctx, vv));
  // NOTE(review): the sync appears to guarantee the async memcpy/kernel have finished
  // before the caller may reuse or free v[] -- confirm this is the intended contract
  PetscCall(PetscDeviceContextSynchronize(dctx));
  PetscFunctionReturn(PETSC_SUCCESS);
}

2312: } // namespace impl

2314: } // namespace cupm

2316: } // namespace vec

2318: } // namespace Petsc