Actual source code: matmpidensecupm.hpp

#pragma once

#include <petsc/private/matdensecupmimpl.h>
#include <../src/mat/impls/dense/mpi/mpidense.h>

#include <../src/mat/impls/dense/seq/cupm/matseqdensecupm.hpp>
#include <../src/vec/vec/impls/mpi/cupm/vecmpicupm.hpp>

namespace Petsc
{

namespace mat
{

namespace cupm
{

namespace impl
{

template <device::cupm::DeviceType T>
class PETSC_SINGLE_LIBRARY_VISIBILITY_INTERNAL MatDense_MPI_CUPM : MatDense_CUPM<T, MatDense_MPI_CUPM<T>> {
public:
  MATDENSECUPM_HEADER(T, MatDense_MPI_CUPM<T>);

private:
  PETSC_NODISCARD static constexpr Mat_MPIDense *MatIMPLCast_(Mat) noexcept;
  PETSC_NODISCARD static constexpr MatType       MATIMPLCUPM_() noexcept;

  static PetscErrorCode SetPreallocation_(Mat, PetscDeviceContext, PetscScalar *) noexcept;

  template <bool to_host>
  static PetscErrorCode Convert_Dispatch_(Mat, MatType, MatReuse, Mat *) noexcept;

public:
  PETSC_NODISCARD static constexpr const char *MatConvert_mpidensecupm_mpidense_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept;

  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept;
  PETSC_NODISCARD static constexpr const char *MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept;

  static PetscErrorCode Create(Mat) noexcept;

  static PetscErrorCode BindToCPU(Mat, PetscBool) noexcept;
  static PetscErrorCode Convert_MPIDenseCUPM_MPIDense(Mat, MatType, MatReuse, Mat *) noexcept;
  static PetscErrorCode Convert_MPIDense_MPIDenseCUPM(Mat, MatType, MatReuse, Mat *) noexcept;

  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode GetArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;
  template <PetscMemType, PetscMemoryAccessMode>
  static PetscErrorCode RestoreArray(Mat, PetscScalar **, PetscDeviceContext = nullptr) noexcept;

private:
  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode GetArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return GetArray<mtype, mode>(m, p);
  }

  template <PetscMemType mtype, PetscMemoryAccessMode mode>
  static PetscErrorCode RestoreArrayC_(Mat m, PetscScalar **p) noexcept
  {
    return RestoreArray<mtype, mode>(m, p);
  }

public:
  template <PetscMemoryAccessMode>
  static PetscErrorCode GetColumnVec(Mat, PetscInt, Vec *) noexcept;
  template <PetscMemoryAccessMode>
  static PetscErrorCode RestoreColumnVec(Mat, PetscInt, Vec *) noexcept;

  static PetscErrorCode PlaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ReplaceArray(Mat, const PetscScalar *) noexcept;
  static PetscErrorCode ResetArray(Mat) noexcept;
};

} // namespace impl

namespace
{

// Declare this here so that the functions below can make use of it
template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateMPIDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr, bool preallocate = true) noexcept
{
  PetscFunctionBegin;
  PetscCall(impl::MatDense_MPI_CUPM<T>::CreateIMPLDenseCUPM(comm, m, n, M, N, data, A, dctx, preallocate));
  PetscFunctionReturn(PETSC_SUCCESS);
}
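
// Usage sketch: how the helper above might be called, assuming a CUDA build; the
// communicator, sizes, and null data pointer are illustrative only.
//
//   Mat A;
//   PetscCall(MatCreateMPIDenseCUPM<device::cupm::DeviceType::CUDA>(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 50, nullptr, &A));
//   // ... assemble and use A like any other dense matrix ...
//   PetscCall(MatDestroy(&A));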

} // anonymous namespace

namespace impl
{

// ==========================================================================================
// MatDense_MPI_CUPM -- Private API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr Mat_MPIDense *MatDense_MPI_CUPM<T>::MatIMPLCast_(Mat m) noexcept
{
  return static_cast<Mat_MPIDense *>(m->data);
}

template <device::cupm::DeviceType T>
inline constexpr MatType MatDense_MPI_CUPM<T>::MATIMPLCUPM_() noexcept
{
  return MATMPIDENSECUPM();
}

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::SetPreallocation_(Mat A, PetscDeviceContext dctx, PetscScalar *device_array) noexcept
{
  PetscFunctionBegin;
  if (auto &mimplA = MatIMPLCast(A)->A) {
    PetscCall(MatSetType(mimplA, MATSEQDENSECUPM()));
    PetscCall(MatDense_Seq_CUPM<T>::SetPreallocation(mimplA, dctx, device_array));
  } else {
    PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, device_array, &mimplA, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <bool to_host>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_Dispatch_(Mat M, MatType, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  if (reuse == MAT_INITIAL_MATRIX) {
    PetscCall(MatDuplicate(M, MAT_COPY_VALUES, newmat));
  } else if (reuse == MAT_REUSE_MATRIX) {
    PetscCall(MatCopy(M, *newmat, SAME_NONZERO_PATTERN));
  }
  {
    const auto B    = *newmat;
    const auto pobj = PetscObjectCast(B);

    // when converting to host rebind to the CPU, otherwise make sure the device backend is initialized
    if (to_host) PetscCall(BindToCPU(B, PETSC_TRUE));
    else PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUPM()));

    PetscCall(PetscStrFreeAllocpy(to_host ? VECSTANDARD : VecMPI_CUPM::VECCUPM(), &B->defaultvectype));
    PetscCall(PetscObjectChangeTypeName(pobj, to_host ? MATMPIDENSE : MATMPIDENSECUPM()));

    // ============================================================
    // Composed Ops
    // ============================================================
    MatComposeOp_CUPM(to_host, pobj, MatConvert_mpidensecupm_mpidense_C(), nullptr, Convert_MPIDenseCUPM_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaij_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C(), nullptr, MatProductSetFromOptions_MPIAIJ_MPIDense);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaij_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C(), nullptr, MatProductSetFromOptions_MPIDense_MPIAIJ);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArray_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayRead_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMGetArrayWrite_C(), nullptr, GetArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArray_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayRead_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_READ>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMRestoreArrayWrite_C(), nullptr, RestoreArrayC_<PETSC_MEMTYPE_DEVICE, PETSC_MEMORY_ACCESS_WRITE>);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMPlaceArray_C(), nullptr, PlaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMResetArray_C(), nullptr, ResetArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMReplaceArray_C(), nullptr, ReplaceArray);
    MatComposeOp_CUPM(to_host, pobj, MatDenseCUPMSetPreallocation_C(), nullptr, SetPreallocation);

    if (to_host) {
      if (auto &m_A = MatIMPLCast(B)->A) PetscCall(MatConvert(m_A, MATSEQDENSE, MAT_INPLACE_MATRIX, &m_A));
      B->offloadmask = PETSC_OFFLOAD_CPU;
    } else {
      if (auto &m_A = MatIMPLCast(B)->A) {
        PetscCall(MatConvert(m_A, MATSEQDENSECUPM(), MAT_INPLACE_MATRIX, &m_A));
        B->offloadmask = PETSC_OFFLOAD_BOTH;
      } else {
        B->offloadmask = PETSC_OFFLOAD_UNALLOCATED;
      }
      PetscCall(BindToCPU(B, PETSC_FALSE));
    }

    // ============================================================
    // Function Pointer Ops
    // ============================================================
    MatSetOp_CUPM(to_host, B, getdiagonal, MatGetDiagonal_MPIDense, GetDiagonal);
    MatSetOp_CUPM(to_host, B, bindtocpu, nullptr, BindToCPU);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
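
// How this dispatch surfaces to users, as a sketch (CUDA spelling; the matrix A and the
// in-place reuse mode are illustrative): MatConvert() moves a matrix between the host and
// device dense types, landing in the Convert_* wrappers defined further below.
//
//   PetscCall(MatConvert(A, MATMPIDENSECUDA, MAT_INPLACE_MATRIX, &A)); // host -> device
//   PetscCall(MatConvert(A, MATMPIDENSE, MAT_INPLACE_MATRIX, &A));     // device -> host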

// ==========================================================================================
// MatDense_MPI_CUPM -- Public API
// ==========================================================================================

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatConvert_mpidensecupm_mpidense_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatConvert_mpidensecuda_mpidense_C" : "MatConvert_mpidensehip_mpidense_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaij_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaij_mpidensecuda_C" : "MatProductSetFromOptions_mpiaij_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaij_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaij_C" : "MatProductSetFromOptions_mpidensehip_mpiaij_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpiaijcupmsparse_mpidensecupm_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpiaijcusparse_mpidensecuda_C" : "MatProductSetFromOptions_mpiaijhipsparse_mpidensehip_C";
}

template <device::cupm::DeviceType T>
inline constexpr const char *MatDense_MPI_CUPM<T>::MatProductSetFromOptions_mpidensecupm_mpiaijcupmsparse_C() noexcept
{
  return T == device::cupm::DeviceType::CUDA ? "MatProductSetFromOptions_mpidensecuda_mpiaijcusparse_C" : "MatProductSetFromOptions_mpidensehip_mpiaijhipsparse_C";
}

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Create(Mat A) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatCreate_MPIDense(A));
  PetscCall(Convert_MPIDense_MPIDenseCUPM(A, MATMPIDENSECUPM(), MAT_INPLACE_MATRIX, &A));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::BindToCPU(Mat A, PetscBool usehost) noexcept
{
  const auto mimpl = MatIMPLCast(A);
  const auto pobj  = PetscObjectCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  if (const auto mimpl_A = mimpl->A) PetscCall(MatBindToCPU(mimpl_A, usehost));
  A->boundtocpu = usehost;
  PetscCall(PetscStrFreeAllocpy(usehost ? PETSCRANDER48 : PETSCDEVICERAND(), &A->defaultrandtype));
  if (!usehost) {
    PetscBool iscupm;

    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cvec), VecMPI_CUPM::VECMPICUPM(), &iscupm));
    if (!iscupm) PetscCall(VecDestroy(&mimpl->cvec));
    PetscCall(PetscObjectTypeCompare(PetscObjectCast(mimpl->cmat), MATMPIDENSECUPM(), &iscupm));
    if (!iscupm) PetscCall(MatDestroy(&mimpl->cmat));
  }

  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVec_C", MatDenseGetColumnVec_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVec_C", MatDenseRestoreColumnVec_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecRead_C", MatDenseGetColumnVecRead_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecRead_C", MatDenseRestoreColumnVecRead_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_READ>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseGetColumnVecWrite_C", MatDenseGetColumnVecWrite_MPIDense, GetColumnVec<PETSC_MEMORY_ACCESS_WRITE>);
  MatComposeOp_CUPM(usehost, pobj, "MatDenseRestoreColumnVecWrite_C", MatDenseRestoreColumnVecWrite_MPIDense, RestoreColumnVec<PETSC_MEMORY_ACCESS_WRITE>);

  MatSetOp_CUPM(usehost, A, shift, MatShift_MPIDense, Shift);

  if (const auto mimpl_cmat = mimpl->cmat) PetscCall(MatBindToCPU(mimpl_cmat, usehost));
  PetscFunctionReturn(PETSC_SUCCESS);
}
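
// Sketch of the public entry point that reaches the bindtocpu hook above (matrix name
// illustrative): once bound, subsequent operations on A execute on the host until unbound.
//
//   PetscCall(MatBindToCPU(A, PETSC_TRUE));  // host execution from here on
//   PetscCall(MatBindToCPU(A, PETSC_FALSE)); // hand execution back to the device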

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDenseCUPM_MPIDense(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ true>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::Convert_MPIDense_MPIDenseCUPM(Mat M, MatType mtype, MatReuse reuse, Mat *newmat) noexcept
{
  PetscFunctionBegin;
  PetscCall(Convert_Dispatch_</* to host */ false>(M, mtype, reuse, newmat));
  PetscFunctionReturn(PETSC_SUCCESS);
}

// ==========================================================================================

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetArray(Mat A, PetscScalar **array, PetscDeviceContext dctx) noexcept
{
  auto &mimplA = MatIMPLCast(A)->A;

  PetscFunctionBegin;
  if (!mimplA) PetscCall(MatCreateSeqDenseCUPM<T>(PETSC_COMM_SELF, A->rmap->n, A->cmap->N, nullptr, &mimplA, dctx));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimplA, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemType, PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreArray(Mat A, PetscScalar **array, PetscDeviceContext) noexcept
{
  PetscFunctionBegin;
  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(MatIMPLCast(A)->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}
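
// Sketch of the matching public accessors (CUDA spelling, assuming a CUDA build): the
// pointer refers to device memory and must be restored before other matrix operations.
//
//   PetscScalar *d_a;
//   PetscCall(MatDenseCUDAGetArray(A, &d_a));
//   // ... pass d_a to kernels or library calls operating on device memory ...
//   PetscCall(MatDenseCUDARestoreArray(A, &d_a));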

// ==========================================================================================

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::GetColumnVec(Mat A, PetscInt col, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl   = MatIMPLCast(A);
  const auto mimpl_A = mimpl->A;
  const auto pobj    = PetscObjectCast(A);
  PetscInt   lda;

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(pobj), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  mimpl->vecinuse = col + 1;

  if (!mimpl->cvec) PetscCall(MatDenseCreateColumnVec_Private(A, &mimpl->cvec));

  // stash the base device pointer and point the cached column vector at column col
  PetscCall(MatDenseGetLDA(mimpl_A, &lda));
  PetscCall(MatDenseCUPMGetArray_Private<T, access>(mimpl_A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  PetscCall(VecCUPMPlaceArrayAsync<T>(mimpl->cvec, mimpl->ptrinuse + static_cast<std::size_t>(col) * static_cast<std::size_t>(lda)));

  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPush(mimpl->cvec));
  *v = mimpl->cvec;
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
template <PetscMemoryAccessMode access>
inline PetscErrorCode MatDense_MPI_CUPM<T>::RestoreColumnVec(Mat A, PetscInt, Vec *v) noexcept
{
  using namespace vec::cupm;

  const auto mimpl = MatIMPLCast(A);
  const auto cvec  = mimpl->cvec;

  PetscFunctionBegin;
  PetscCheck(mimpl->vecinuse, PETSC_COMM_SELF, PETSC_ERR_ORDER, "Need to call MatDenseGetColumnVec() first");
  PetscCheck(cvec, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing internal column vector");
  mimpl->vecinuse = 0;

  PetscCall(MatDenseCUPMRestoreArray_Private<T, access>(mimpl->A, const_cast<PetscScalar **>(&mimpl->ptrinuse)));
  if (access == PETSC_MEMORY_ACCESS_READ) PetscCall(VecLockReadPop(cvec));
  PetscCall(VecCUPMResetArrayAsync<T>(cvec));

  if (v) *v = nullptr;
  PetscFunctionReturn(PETSC_SUCCESS);
}
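
// Usage sketch for the column accessors above via the public interface (read-only variant
// shown; matrix and column index illustrative). Each Get must be paired with its Restore.
//
//   Vec       col;
//   PetscReal nrm;
//   PetscCall(MatDenseGetColumnVecRead(A, 0, &col));
//   PetscCall(VecNorm(col, NORM_2, &nrm));
//   PetscCall(MatDenseRestoreColumnVecRead(A, 0, &col));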

// ==========================================================================================

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::PlaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMPlaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ReplaceArray(Mat A, const PetscScalar *array) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMReplaceArray<T>(mimpl->A, array));
  PetscFunctionReturn(PETSC_SUCCESS);
}

template <device::cupm::DeviceType T>
inline PetscErrorCode MatDense_MPI_CUPM<T>::ResetArray(Mat A) noexcept
{
  const auto mimpl = MatIMPLCast(A);

  PetscFunctionBegin;
  PetscCheck(!mimpl->vecinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreColumnVec() first");
  PetscCheck(!mimpl->matinuse, PetscObjectComm(PetscObjectCast(A)), PETSC_ERR_ORDER, "Need to call MatDenseRestoreSubMatrix() first");
  PetscCall(MatDenseCUPMResetArray<T>(mimpl->A));
  PetscFunctionReturn(PETSC_SUCCESS);
}
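
// Sketch of the place/reset pairing through the CUDA-spelled public API (d_buf is an
// illustrative user-owned device allocation of sufficient size).
//
//   PetscCall(MatDenseCUDAPlaceArray(A, d_buf)); // temporarily alias A's storage to d_buf
//   // ... operations on A now read and write d_buf ...
//   PetscCall(MatDenseCUDAResetArray(A));        // restore A's original storage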

} // namespace impl

namespace
{

template <device::cupm::DeviceType T>
inline PetscErrorCode MatCreateDenseCUPM(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscScalar *data, Mat *A, PetscDeviceContext dctx = nullptr) noexcept
{
  PetscMPIInt size;

  PetscFunctionBegin;
  PetscAssertPointer(A, 7);
  PetscCallMPI(MPI_Comm_size(comm, &size));
  if (size > 1) {
    PetscCall(MatCreateMPIDenseCUPM<T>(comm, m, n, M, N, data, A, dctx));
  } else {
    if (m == PETSC_DECIDE) m = M;
    if (n == PETSC_DECIDE) n = N;
    // It's OK here if both are PETSC_DECIDE since PetscSplitOwnership() will catch that down
    // the line
    PetscCall(MatCreateSeqDenseCUPM<T>(comm, m, n, data, A, dctx));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}
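
// Usage sketch of the public constructor this mirrors (CUDA spelling): local sizes may be
// PETSC_DECIDE, and on a one-rank communicator a sequential matrix is created instead.
//
//   Mat A;
//   PetscCall(MatCreateDenseCUDA(PETSC_COMM_WORLD, PETSC_DECIDE, PETSC_DECIDE, 100, 50, NULL, &A));
//   PetscCall(MatDestroy(&A));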

} // anonymous namespace

} // namespace cupm

} // namespace mat

} // namespace Petsc