Actual source code: matdiagonalkokkos.kokkos.cxx
1: #include <petscvec_kokkos.hpp>
2: #include <petsc_kokkos.hpp>
3: #include <petsc/private/kokkosimpl.hpp>
4: #include <petsc/private/vecimpl.h>
5: #include <petsc/private/matimpl.h>
7: PETSC_INTERN PetscErrorCode MatADot_Diagonal_SeqKokkos(Mat A, Vec x, Vec y, PetscScalar *z)
8: {
9: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
10: ConstPetscScalarKokkosView xv, yv, wv;
12: PetscFunctionBegin;
13: PetscCall(PetscLogGpuTimeBegin());
14: PetscCall(VecGetKokkosView(x, &xv));
15: PetscCall(VecGetKokkosView(y, &yv));
16: PetscCall(VecGetKokkosView(ctx->diag, &wv));
17: // Kokkos always overwrites z, so no need to init it
18: PetscCallCXX(Kokkos::parallel_reduce("MatADot_Diagonal", Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, x->map->n), KOKKOS_LAMBDA(const PetscInt &i, PetscScalar &update) { update += PetscConj(yv(i)) * wv(i) * xv(i); }, *z));
19: PetscCall(VecRestoreKokkosView(x, &xv));
20: PetscCall(VecRestoreKokkosView(y, &yv));
21: PetscCall(VecRestoreKokkosView(ctx->diag, &wv));
22: PetscCall(PetscLogGpuTimeEnd());
23: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
24: PetscFunctionReturn(PETSC_SUCCESS);
25: }
27: PETSC_INTERN PetscErrorCode MatANormSq_Diagonal_SeqKokkos(Mat A, Vec x, PetscReal *z)
28: {
29: Mat_Diagonal *ctx = (Mat_Diagonal *)A->data;
30: ConstPetscScalarKokkosView xv, wv;
31: PetscScalar res = 0.;
33: PetscFunctionBegin;
34: PetscCall(PetscLogGpuTimeBegin());
35: PetscCall(VecGetKokkosView(x, &xv));
36: PetscCall(VecGetKokkosView(ctx->diag, &wv));
37: PetscCallCXX(Kokkos::parallel_reduce("MatANorm_Diagonal", Kokkos::RangePolicy<>(PetscGetKokkosExecutionSpace(), 0, x->map->n), KOKKOS_LAMBDA(const PetscInt &i, PetscScalar &update) { update += PetscConj(xv(i)) * wv(i) * xv(i); }, res));
38: PetscCall(VecRestoreKokkosView(x, &xv));
39: PetscCall(VecRestoreKokkosView(ctx->diag, &wv));
40: PetscCall(PetscLogGpuTimeEnd());
41: *z = PetscRealPart(res);
42: if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
43: PetscFunctionReturn(PETSC_SUCCESS);
44: }