Actual source code: letkf_local_analysis.kokkos.cxx

  1: #include "../src/ml/da/impls/ensemble/letkf/letkf.h"
  2: #include <petscblaslapack.h>
  3: #include <Kokkos_Core.hpp>
  4: #include <KokkosBlas.hpp>
  5: #include <KokkosBatched_SVD_Decl.hpp>
  6: #include <KokkosBatched_SVD_Serial_Impl.hpp>
  7: #include <KokkosBatched_Gemm_Decl.hpp>
  8: #include <KokkosBatched_Gemm_Serial_Impl.hpp>
  9: #include <KokkosBatched_Util.hpp>

 11: #if defined(KOKKOS_ENABLE_CUDA)
 12:   #include <cusolverDn.h>
 13:   #include <cuda_runtime.h>
 14: #include <petscdevice_cuda.h>
 15: #elif defined(KOKKOS_ENABLE_HIP)
 16:   #include <rocsolver/rocsolver.h>
 17:   #include <hip/hip_runtime.h>
 18: #include <petscdevice_hip.h>
 19: #elif defined(KOKKOS_ENABLE_SYCL)
 20:   #include <oneapi/mkl.hpp>
 21:   #include <sycl/sycl.hpp>
 22:   #include <petscdevice_sycl.h>
 23: #endif

 25: /* ========================================================================== */
 26: /*                    Batched Eigendecomposition for LETKF                    */
 27: /* ========================================================================== */

 29: /* Structure to hold reusable workspace for eigensolvers */
 30: struct EigenWorkspace {
 31:   /* Tracking for reuse */
 32:   PetscInt max_chunk_size;
 33:   PetscInt m;
 34:   PetscInt n_obs_vertex;

 36:   /* Persistent Kokkos Views */
 37:   using exec_space = Kokkos::DefaultExecutionSpace;
 38:   using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
 39:   using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

 41:   view_3d Z_batch;
 42:   view_3d S_batch;
 43:   view_3d T_batch;
 44:   view_3d V_batch;
 45:   view_2d Lambda_batch;
 46:   view_3d T_sqrt_batch;
 47:   view_2d w_batch;
 48:   view_2d delta_batch;
 49:   view_2d y_batch;
 50:   view_2d y_mean_batch;
 51:   view_2d r_inv_sqrt_batch;
 52:   view_2d temp1_batch;
 53:   view_2d temp2_batch;
 54:   view_2d inv_sqrt_lambda_batch;

 56:   /* Host workspace */
 57:   PetscScalar *all_v;
 58:   PetscReal   *all_lambda;
 59:   PetscScalar *all_work;
 60: #if defined(PETSC_USE_COMPLEX)
 61:   PetscReal *all_rwork;
 62: #endif
 63:   PetscBLASInt lwork;
 64:   PetscBLASInt n_blas;

 66:   /* Device workspace */
 67: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
 68:   #if defined(KOKKOS_ENABLE_CUDA)
 69:   syevjInfo_t  syevj_params;
 70:   PetscScalar *d_work;
 71:   int         *d_info;
 72:   PetscScalar *d_A_contig;
 73:   PetscScalar *d_W_contig;
 74:   int          lwork_device;
 75:   #elif defined(KOKKOS_ENABLE_HIP)
 76:   PetscScalar *d_work;
 77:   int         *d_info;
 78:   PetscScalar *d_A_contig;
 79:   PetscScalar *d_W_contig;
 80:   int          lwork_device;
 81:   #elif defined(KOKKOS_ENABLE_SYCL)
 82:   PetscScalar *d_work;
 83:   int         *d_info;
 84:   PetscScalar *d_A_contig;
 85:   PetscScalar *d_W_contig;
 86:   int          lwork_device;
 87:   #endif
 88: #endif

 90:   EigenWorkspace() : max_chunk_size(0), m(0), n_obs_vertex(0), all_v(nullptr), all_lambda(nullptr), all_work(nullptr)
 91:   {
 92: #if defined(PETSC_USE_COMPLEX)
 93:     all_rwork = nullptr;
 94: #endif
 95: #if defined(KOKKOS_ENABLE_CUDA)
 96:     d_work       = nullptr;
 97:     d_info       = nullptr;
 98:     d_A_contig   = nullptr;
 99:     d_W_contig   = nullptr;
100:     syevj_params = nullptr;
101: #elif defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
102:     d_work     = nullptr;
103:     d_info     = nullptr;
104:     d_A_contig = nullptr;
105:     d_W_contig = nullptr;
106: #endif
107:   }
108: };

110: /*
111:   BatchedEigenSolve_Host - Compute eigendecomposition for a batch of symmetric matrices (CPU version)

113:   Input Parameters:
114: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
115: . n_batch      - number of matrices in the batch
116: - n_size       - size of each matrix (m x m)
117: - work         - reusable workspace structure

119:   Output Parameters:
120: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
121: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

123:   Notes:
124:   Uses LAPACK's syev routine to compute eigendecomposition sequentially on host.
125: */
#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
static PetscErrorCode BatchedEigenSolve_Host(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Create host mirrors and copy data in one operation */
  /* This is required for HIP+complex where create_mirror_view + deep_copy fails */
  /* (in a host-only build these mirrors alias the originals, so the copies are no-ops) */
  auto T_host      = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), T_batch);
  auto Lambda_host = Kokkos::create_mirror_view(Kokkos::HostSpace(), Lambda_batch);
  auto V_host      = Kokkos::create_mirror_view(Kokkos::HostSpace(), V_batch);

  /* Use pre-allocated workspace. Each batch entry i gets a disjoint slice of
     these arrays, so the parallel loop below has no shared mutable state. */
  PetscScalar *all_v      = work->all_v;      /* n_size*n_size scalars per matrix */
  PetscReal   *all_lambda = work->all_lambda; /* n_size reals per matrix */
  PetscScalar *all_work   = work->all_work;   /* lwork scalars per matrix */
  PetscBLASInt lwork      = work->lwork;
  PetscBLASInt n_blas     = work->n_blas;
  #if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork = work->all_rwork; /* 3*n_size-2 reals per matrix, required by zheev */
  #endif

  /* Process each matrix in parallel on host using LAPACK */
  /* NOTE(review): assumes the linked LAPACK is safe to call concurrently when
     every call uses disjoint buffers — confirm for the deployed BLAS/LAPACK. */
  Kokkos::parallel_for(
    "BatchedEigenSolve_Host", Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      PetscBLASInt n   = n_blas;
      PetscBLASInt lda = n;
      PetscBLASInt info;
      PetscBLASInt lw = lwork;

      /* Pointers for this matrix (disjoint slices of the shared workspace) */
      PetscScalar *v_ptr      = all_v + i * n_size * n_size;
      PetscReal   *lambda_ptr = all_lambda + i * n_size;
      PetscScalar *work_ptr   = all_work + i * lwork;
  #if defined(PETSC_USE_COMPLEX)
      PetscReal *rwork_ptr = all_rwork + i * (3 * n_size - 2);
  #endif

      /* Copy T_host(i, :, :) to v_ptr (column-major) */
      for (PetscInt j = 0; j < n_size; j++) {
        for (PetscInt k = 0; k < n_size; k++) v_ptr[k + j * n_size] = T_host(i, k, j);
      }

    /* Compute eigendecomposition: T = V * Lambda * V^T (syev overwrites v_ptr with eigenvectors) */
  #if defined(PETSC_USE_COMPLEX)
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, rwork_ptr, &info);
  #else
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, &info);
  #endif

      if (info != 0) {
        /* We cannot return error code from lambda, so we just abort or ignore.
           In production code, we should use a reduction to report errors. */
        Kokkos::abort("LAPACK eigendecomposition failed in parallel region");
      }

      /* Copy results back to host views */
      for (PetscInt j = 0; j < n_size; j++) {
        Lambda_host(i, j) = (PetscScalar)lambda_ptr[j];
        for (PetscInt k = 0; k < n_size; k++) V_host(i, k, j) = v_ptr[k + j * n_size];
      }
    });

  /* Copy results back to device (no-op in a host-only build) */
  Kokkos::deep_copy(Lambda_batch, Lambda_host);
  Kokkos::deep_copy(V_batch, V_host);
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif

194: /*
195:   BatchedEigenSolve_Device - Compute eigendecomposition for a batch of symmetric matrices (Device version)

197:   Input Parameters:
198: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
199: . n_batch      - number of matrices in the batch
200: - n_size       - size of each matrix (m x m)
201: - device_handle - device-specific solver handle (cusolverDnHandle_t, rocblas_handle, or sycl::queue*)
202: - work         - reusable workspace structure

204:   Output Parameters:
205: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
206: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

208:   Notes:
209:   Uses vendor-specific batched symmetric eigensolvers:
210:   - CUDA: cuSOLVER's syevjBatched
211:   - HIP: rocSOLVER's rocsolver_dsyevj_batched
212:   - SYCL: oneMKL's syevd_batch
213: */
214: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
215:   #if defined(KOKKOS_ENABLE_CUDA)
216: static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t cusolverH, EigenWorkspace *work)
217: {
218:   cusolverStatus_t cusolver_status;

220:   PetscFunctionBegin;
221:   /* Use pre-allocated workspace */
222:   syevjInfo_t  syevj_params = work->syevj_params;
223:   PetscScalar *d_work       = work->d_work;
224:   int         *d_info       = work->d_info;
225:   PetscScalar *d_A_contig   = work->d_A_contig;
226:   PetscScalar *d_W_contig   = work->d_W_contig;
227:   int          lwork        = work->lwork_device;

229:   /* Copy T_batch to contiguous layout for cuSOLVER */
230:   Kokkos::parallel_for(
231:     "ReorganizeForCuSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
232:       for (int j = 0; j < n_size; j++) {
233:         for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
234:       }
235:     });
236:   Kokkos::fence();

238:     /* Solve batched eigendecomposition */
239:     #if defined(PETSC_USE_REAL_SINGLE)
240:   cusolver_status = cusolverDnSsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
241:     #else
242:   cusolver_status = cusolverDnDsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
243:     #endif
244:   PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched failed");

246:   /* Check info */
247:   int *h_info;
248:   PetscCall(PetscMalloc1(n_batch, &h_info));
249:   PetscCallCUDA(cudaMemcpy(h_info, d_info, sizeof(int) * n_batch, cudaMemcpyDeviceToHost));
250:   for (int i = 0; i < n_batch; i++) {
251:     if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "cuSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
252:   }
253:   PetscCall(PetscFree(h_info));

255:   /* Copy results back from contiguous layout to V_batch */
256:   Kokkos::parallel_for(
257:     "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
258:       for (int j = 0; j < n_size; j++) {
259:         Lambda_batch(i, j) = d_W_contig[i * n_size + j];
260:         for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
261:       }
262:     });
263:   Kokkos::fence();
264:   PetscFunctionReturn(PETSC_SUCCESS);
265: }
266:   #elif defined(KOKKOS_ENABLE_HIP)
267: static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle rocblasH, EigenWorkspace *work)
268: {
269:   PetscFunctionBegin;
270:   /* Use pre-allocated workspace */
271:   PetscScalar *d_work = work->d_work;
272:   (void)d_work;
273:   int         *d_info     = work->d_info;
274:   PetscScalar *d_A_contig = work->d_A_contig;
275:   PetscScalar *d_W_contig = work->d_W_contig;

277:   /* Copy T_batch to contiguous layout for rocSOLVER */
278:   Kokkos::parallel_for(
279:     "ReorganizeForRocSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
280:       for (int j = 0; j < n_size; j++) {
281:         for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
282:       }
283:     });
284:   Kokkos::fence();

286:     /* rocSOLVER doesn't have a native batched syevj, so we loop over batch */
287:     /* Use rocsolver_dsyevd which is more efficient than calling syev in a loop */
288:     #if defined(PETSC_USE_COMPLEX)
289:   SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on HIP backend for LETKF");
290:     #else
291:   for (int i = 0; i < n_batch; i++) {
292:     PetscScalar   *A_ptr    = d_A_contig + i * n_size * n_size;
293:     PetscScalar   *W_ptr    = d_W_contig + i * n_size;
294:     int           *info_ptr = d_info + i;
295:     rocblas_status hip_status;

297:       #if defined(PETSC_USE_REAL_SINGLE)
298:     hip_status = rocsolver_ssyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
299:       #else
300:     hip_status = rocsolver_dsyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
301:       #endif
302:     PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocsolver_*syevd failed for batch %" PetscInt_FMT, i);
303:   }
304:     #endif

306:   /* Check info */
307:   int *h_info;
308:   PetscCall(PetscMalloc1(n_batch, &h_info));
309:   PetscCallHIP(hipMemcpy(h_info, d_info, sizeof(int) * n_batch, hipMemcpyDeviceToHost));
310:   for (int i = 0; i < n_batch; i++) {
311:     if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocSOLVER eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
312:   }
313:   PetscCall(PetscFree(h_info));

315:   /* Copy results back from contiguous layout to V_batch */
316:   Kokkos::parallel_for(
317:     "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
318:       for (int j = 0; j < n_size; j++) {
319:         Lambda_batch(i, j) = d_W_contig[i * n_size + j];
320:         for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
321:       }
322:     });
323:   Kokkos::fence();
324:   PetscFunctionReturn(PETSC_SUCCESS);
325: }
326:   #elif defined(KOKKOS_ENABLE_SYCL)
327: static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *q, EigenWorkspace *work)
328: {
329:   PetscFunctionBegin;
330:   /* Use pre-allocated workspace */
331:   PetscScalar *d_work     = work->d_work;
332:   int         *d_info     = work->d_info;
333:   PetscScalar *d_A_contig = work->d_A_contig;
334:   PetscScalar *d_W_contig = work->d_W_contig;

336:   /* Copy T_batch to contiguous layout for oneMKL */
337:   Kokkos::parallel_for(
338:     "ReorganizeForOneMKL", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
339:       for (int j = 0; j < n_size; j++) {
340:         for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
341:       }
342:     });
343:   Kokkos::fence();

345:   /* oneMKL doesn't have a native batched syevd, so we loop over batch */
346:   /* Use oneapi::mkl::lapack::syevd which computes eigenvalues and eigenvectors */
347:   for (int i = 0; i < n_batch; i++) {
348:     PetscScalar *A_ptr    = d_A_contig + i * n_size * n_size;
349:     PetscScalar *W_ptr    = d_W_contig + i * n_size;
350:     int         *info_ptr = d_info + i;

352:     try {
353:     #if defined(PETSC_USE_REAL_SINGLE)
354:       oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device, info_ptr);
355:     #else
356:       oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device, info_ptr);
357:     #endif
358:       q->wait();
359:     } catch (sycl::exception const &e) {
360:       SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL syevd failed for batch %d: %s", i, e.what());
361:     }
362:   }

364:   /* Check info */
365:   int *h_info;
366:   PetscCall(PetscMalloc1(n_batch, &h_info));
367:   q->memcpy(h_info, d_info, sizeof(int) * n_batch).wait();
368:   for (int i = 0; i < n_batch; i++) {
369:     if (h_info[i] != 0) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL eigendecomposition failed for matrix %" PetscInt_FMT ": info=%d", i, h_info[i]);
370:   }
371:   PetscCall(PetscFree(h_info));

373:   /* Copy results back from contiguous layout to V_batch */
374:   Kokkos::parallel_for(
375:     "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
376:       for (int j = 0; j < n_size; j++) {
377:         Lambda_batch(i, j) = d_W_contig[i * n_size + j];
378:         for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
379:       }
380:     });
381:   Kokkos::fence();
382:   PetscFunctionReturn(PETSC_SUCCESS);
383: }
384:   #endif
385: #endif

387: /*
388:   BatchedEigenSolve - Compute eigendecomposition for a batch of symmetric matrices

390:   Input Parameters:
391: + T_batch      - batch of symmetric matrices (n_batch x n_size x n_size)
392: . n_batch      - number of matrices in the batch
393: - n_size       - size of each matrix (m x m)
394: - device_handle - device-specific solver handle (only for device builds)
395: - work         - reusable workspace structure

397:   Output Parameters:
398: + Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
399: - V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

401:   Notes:
402:   Dispatcher function that calls the appropriate backend (Device or Host).
403: */
404: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
405:   #if defined(KOKKOS_ENABLE_CUDA)
/* Dispatcher (CUDA build): forwards to the cuSOLVER-based device implementation. */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
412:   #elif defined(KOKKOS_ENABLE_HIP)
/* Dispatcher (HIP build): forwards to the rocSOLVER-based device implementation. */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
419:   #elif defined(KOKKOS_ENABLE_SYCL)
/* Dispatcher (SYCL build): forwards to the oneMKL-based device implementation. */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
426:   #endif
427: #else
/* Dispatcher (CPU-only build): forwards to the LAPACK-based host implementation.
   Note this overload takes no device handle, unlike the GPU variants. */
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Host(T_batch, Lambda_batch, V_batch, n_batch, n_size, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
434: #endif

436: /*
437:   PetscDALETKFSetupLocalization_Kokkos - Prepares device views for localization matrix Q
438: */
439: PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *impl, Mat H)
440: {
441:   PetscInt nrows;

443:   PetscFunctionBegin;
444:   PetscCheck(impl->Q, PETSC_COMM_SELF, PETSC_ERR_LIB, "impl->Q = 0");
445:   PetscCall(PetscKokkosInitializeCheck());

447:   /* Get CSR data */
448:   PetscInt rstart, rend, i, nnz;
449:   PetscCall(MatGetOwnershipRange(impl->Q, &rstart, &rend));
450:   nrows = rend - rstart;

452:   /* Create IS for local observations needed by this process */
453:   /* We need to find all unique column indices in the local rows of Q */
454:   {
455:     PetscInt     *obs_indices;
456:     PetscInt      n_obs_local_total = 0;
457:     PetscInt      max_obs           = nrows * impl->n_obs_vertex;
458:     PetscInt      count             = 0;
459:     PetscHMapI    ht;
460:     PetscHashIter iter;
461:     PetscBool     missing;

463:     PetscCall(PetscHMapICreate(&ht));
464:     PetscCall(PetscMalloc1(max_obs, &obs_indices));

466:     for (i = 0; i < nrows; i++) {
467:       const PetscInt    *cols;
468:       const PetscScalar *vals;
469:       PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
470:       for (PetscInt k = 0; k < nnz; k++) {
471:         PetscCall(PetscHMapIPut(ht, cols[k], &iter, &missing));
472:         if (missing) {
473:           obs_indices[count] = cols[k];
474:           count++;
475:         }
476:       }
477:       PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
478:     }
479:     n_obs_local_total = count;

481:     /* Sort indices for consistent ordering */
482:     PetscCall(PetscSortInt(n_obs_local_total, obs_indices));

484:     /* Create IS and VecScatter */
485:     PetscCall(ISCreateGeneral(PETSC_COMM_SELF, n_obs_local_total, obs_indices, PETSC_COPY_VALUES, &impl->obs_is_local));

487:     /* Create global-to-local map for observations */
488:     PetscCall(PetscHMapICreate(&impl->obs_g2l));
489:     for (i = 0; i < n_obs_local_total; i++) {
490:       PetscCall(PetscHMapIPut(impl->obs_g2l, obs_indices[i], &iter, &missing));
491:       PetscCall(PetscHMapIIterSet(impl->obs_g2l, iter, i));
492:     }

494:     PetscCall(PetscFree(obs_indices));
495:     PetscCall(PetscHMapIDestroy(&ht));
496:   }

498:   /* Create work vectors and scatter context */
499:   {
500:     PetscInt n_obs_local_total;
501:     PetscCall(ISGetLocalSize(impl->obs_is_local, &n_obs_local_total));

503:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->obs_work));
504:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->y_mean_work));
505:     PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->r_inv_sqrt_work));

507:     Vec gvec;
508:     IS  is_to;
509:     PetscCall(MatCreateVecs(H, NULL, &gvec)); /* Create template global vector (left vector = rows = observations) */
510:     PetscCall(ISCreateStride(PETSC_COMM_SELF, n_obs_local_total, 0, 1, &is_to));
511:     PetscCall(VecScatterCreate(gvec, impl->obs_is_local, impl->obs_work, is_to, &impl->obs_scat));
512:     PetscCall(VecDestroy(&gvec));
513:     PetscCall(ISDestroy(&is_to));
514:   }

516:   /* Define View types */
517:   using view_1d_int    = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
518:   using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

520:   /* Allocate device views */
521:   view_1d_int    *d_Q_i = new view_1d_int("Q_i", nrows + 1);
522:   view_1d_int    *d_Q_j = new view_1d_int("Q_j", nrows * impl->n_obs_vertex);
523:   view_1d_scalar *d_Q_a = new view_1d_scalar("Q_a", nrows * impl->n_obs_vertex);

525:   /* Create host mirrors */
526:   auto h_Q_i = Kokkos::create_mirror_view(*d_Q_i);
527:   auto h_Q_j = Kokkos::create_mirror_view(*d_Q_j);
528:   auto h_Q_a = Kokkos::create_mirror_view(*d_Q_a);

530:   /* Fill host mirrors with LOCAL indices into obs_work */
531:   h_Q_i(0) = 0;
532:   for (i = 0; i < nrows; i++) {
533:     const PetscInt    *cols;
534:     const PetscScalar *vals;
535:     PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
536:     h_Q_i(i + 1) = h_Q_i(i) + nnz;
537:     for (PetscInt k = 0; k < nnz; k++) {
538:       PetscInt local_idx;
539:       PetscCall(ISLocate(impl->obs_is_local, cols[k], &local_idx));
540:       PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local IS", cols[k]);
541:       h_Q_j(h_Q_i(i) + k) = local_idx;
542:       h_Q_a(h_Q_i(i) + k) = vals[k];
543:     }
544:     PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
545:   }

547:   /* Copy to device */
548:   Kokkos::deep_copy(*d_Q_i, h_Q_i);
549:   Kokkos::deep_copy(*d_Q_j, h_Q_j);
550:   Kokkos::deep_copy(*d_Q_a, h_Q_a);

552:   /* Store in impl */
553:   PetscCheck(!impl->Q_device_i, PETSC_COMM_SELF, PETSC_ERR_LIB, "impl->Q = 0");
554:   impl->Q_device_i = static_cast<void *>(d_Q_i);
555:   impl->Q_device_j = static_cast<void *>(d_Q_j);
556:   impl->Q_device_a = static_cast<void *>(d_Q_a);
557:   PetscFunctionReturn(PETSC_SUCCESS);
558: }

/* Tears down everything PetscDALETKFSetupLocalization_Kokkos (and the solver
   setup) created: work vectors, scatter, hash map, device CSR views, the
   eigensolver workspace, and finally the backend solver handle. Order matters:
   the SYCL workspace must be freed through the queue before the queue itself
   is deleted. */
PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *impl)
{
  PetscFunctionBegin;
  PetscCall(VecDestroy(&impl->obs_work));
  PetscCall(VecDestroy(&impl->y_mean_work));
  PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
  PetscCall(VecScatterDestroy(&impl->obs_scat));
  PetscCall(MatDestroy(&impl->Z_work));
  PetscCall(PetscHMapIDestroy(&impl->obs_g2l));
  /* Device CSR views were heap-allocated as Kokkos::View objects; deleting the
     wrapper releases the device memory via the view's reference counting */
  if (impl->Q_device_i) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_i);
    impl->Q_device_i = NULL;
  }
  if (impl->Q_device_j) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_j);
    impl->Q_device_j = NULL;
  }
  if (impl->Q_device_a) {
    using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_scalar *>(impl->Q_device_a);
    impl->Q_device_a = NULL;
  }

  /* Destroy solver handle and workspace */
  if (impl->eigen_work) {
    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);

#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
  #if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(work->d_A_contig));
    PetscCallCUDA(cudaFree(work->d_W_contig));
    PetscCallCUDA(cudaFree(work->d_work));
    PetscCallCUDA(cudaFree(work->d_info));
    /* NOTE(review): destroy status is ignored here — intentional best-effort teardown? */
    if (work->syevj_params) cusolverDnDestroySyevjInfo(work->syevj_params);
  #elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(work->d_A_contig));
    PetscCallHIP(hipFree(work->d_W_contig));
    PetscCallHIP(hipFree(work->d_work));
    PetscCallHIP(hipFree(work->d_info));
  #elif defined(KOKKOS_ENABLE_SYCL)
    /* sycl::free requires the owning queue; if solver_handle is already NULL
       the USM allocations are leaked — NOTE(review): verify callers never
       clear solver_handle before this runs */
    if (impl->solver_handle) {
      sycl::queue *q = static_cast<sycl::queue *>(impl->solver_handle);
      if (work->d_A_contig) sycl::free(work->d_A_contig, *q);
      if (work->d_W_contig) sycl::free(work->d_W_contig, *q);
      if (work->d_work) sycl::free(work->d_work, *q);
      if (work->d_info) sycl::free(work->d_info, *q);
    }
  #endif
#else
  /* Host build: the LAPACK workspace arrays were allocated together, so they
     must be freed with the matching PetscFreeN call */
  #if defined(PETSC_USE_COMPLEX)
    PetscCall(PetscFree4(work->all_v, work->all_lambda, work->all_work, work->all_rwork));
  #else
    PetscCall(PetscFree3(work->all_v, work->all_lambda, work->all_work));
  #endif
#endif

    delete work;
    impl->eigen_work = NULL;
  }

  /* Backend solver handle is destroyed last (the SYCL branch above needs it) */
  if (impl->solver_handle) {
#if defined(KOKKOS_ENABLE_CUDA)
    cusolverDnDestroy(static_cast<cusolverDnHandle_t>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_HIP)
    rocblas_destroy_handle(static_cast<rocblas_handle>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_SYCL)
    delete static_cast<sycl::queue *>(impl->solver_handle);
#endif
    impl->solver_handle = NULL;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

635: /* ========================================================================== */
636: /*                    LETKF Local Analysis (Main Function)                    */
637: /* ========================================================================== */

639: /*
640:   PetscDALETKFLocalAnalysis_GPU - Performs local LETKF analysis for all grid points (Kokkos version)

642:   Input Parameters:
643: + da             - the PetscDA context
644: . impl           - LETKF implementation data
645: . m              - ensemble size
646: . n_vertices     - number of grid points
647: . X              - global anomaly matrix (state_size x m)
648: . observation    - observation vector
649: . Z_global       - global observation ensemble (obs_size x m)
650: . y_mean_global  - global observation mean
651: - r_inv_sqrt_global - global R^{-1/2}

653:   Output:
654: . da->ensemble - updated with analysis ensemble

656:   Notes:
657:   This function performs the local analysis loop for LETKF, processing each grid point
658:   independently using its local observations defined by the localization matrix Q.
659:   This is the CPU version that does not use Kokkos acceleration.

661:   All local analysis workspace objects (Z_local, S_local, T_sqrt_local, G_local, y_local,
662:   y_mean_local, delta_scaled_local, r_inv_sqrt_local, w_local, s_transpose_delta) are
663:   created with PETSC_COMM_SELF because the analysis at each vertex is serial and independent.
664: */
665: PetscErrorCode PetscDALETKFLocalAnalysis_GPU(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
666: {
667:   PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
668:   PetscInt          ndof;
669:   PetscReal         sqrt_m_minus_1, scale, inflation_inv;

671:   PetscFunctionBegin;
672:   ndof           = da->ndof;
673:   scale          = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
674:   sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
675:   inflation_inv  = 1.0 / en->inflation; /* (1/rho) for T matrix: T = (1/rho)I + S^T*S */

677:   /* ===================================================================== */
678:   /* Step 2.1.1: Create batched workspace for ALL grid points            */
679:   /* ===================================================================== */
680:   /*
681:      NOTE ON PARALLELISM STRATEGY:
682:      We use Kokkos::RangePolicy over grid points (n_vertices) combined with KokkosBatched::Serial kernels.
683:      Since the data layout is LayoutLeft (Column-Major) to match PETSc/LAPACK, the index 'i' (grid point)
684:      is the fastest varying index (stride 1).

686:      RangePolicy maps consecutive threads to consecutive 'i', ensuring perfect memory coalescing
687:      when accessing arrays like S_batch(i, p, j).

689:      Using TeamPolicy/TeamVectorRange to parallelize inner loops (m or p) would assign a team to 'i',
690:      causing threads within the team to access S_batch with stride 'n_vertices', which leads to
691:      uncoalesced memory access and poor performance on GPUs.

693:      Therefore, RangePolicy + SerialGemm is the optimal strategy for this data layout.
694:   */
695:   using exec_space = Kokkos::DefaultExecutionSpace;
696:   using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
697:   using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

699:   /* ===================================================================== */
700:   /* Step 2.1.2a: Pre-extract Q matrix CSR data for device access        */
701:   /* ===================================================================== */
702:   using view_1d_int_const    = Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
703:   using view_1d_scalar_const = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
704:   using view_1d_int          = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
705:   using view_1d_scalar       = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

707:   view_1d_int_const    Q_i_view;
708:   view_1d_int_const    Q_j_view;
709:   view_1d_scalar_const Q_a_view;

711:   if (impl->Q_device_i) {
712:     /* Use pre-allocated device views */
713:     view_1d_int    *d_Q_i = static_cast<view_1d_int *>(impl->Q_device_i);
714:     view_1d_int    *d_Q_j = static_cast<view_1d_int *>(impl->Q_device_j);
715:     view_1d_scalar *d_Q_a = static_cast<view_1d_scalar *>(impl->Q_device_a);

717:     Q_i_view = view_1d_int_const(d_Q_i->data(), d_Q_i->extent(0));
718:     Q_j_view = view_1d_int_const(d_Q_j->data(), d_Q_j->extent(0));
719:     Q_a_view = view_1d_scalar_const(d_Q_a->data(), d_Q_a->extent(0));
720:   } else {
721:     /* Fallback to host pointers (unsafe if not UVM) */
722:     PetscCheck(PETSC_FALSE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Q matrix must be setup with PetscDALETKFSetupLocalization_Kokkos");
723:   }

725:   /* Get global observation data arrays */
726:   const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
727:   PetscInt           lda_z_global;
728:   PetscMemType       z_mem_type, y_mem_type, y_mean_mem_type, r_inv_sqrt_mem_type;

730:   PetscCall(MatDenseGetArrayReadAndMemType(Z_global, &z_global_array, &z_mem_type));
731:   PetscCall(VecGetArrayReadAndMemType(observation, &y_global_array, &y_mem_type));
732:   PetscCall(VecGetArrayReadAndMemType(y_mean_global, &y_mean_global_array, &y_mean_mem_type));
733:   PetscCall(VecGetArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array, &r_inv_sqrt_mem_type));
734:   PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));

736:   /* Handle memory mirroring for observation data */
737:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> z_managed;
738:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_managed;
739:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_mean_managed;
740:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  r_inv_sqrt_managed;

742:   const PetscScalar *z_ptr          = z_global_array;
743:   const PetscScalar *y_ptr          = y_global_array;
744:   const PetscScalar *y_mean_ptr     = y_mean_global_array;
745:   const PetscScalar *r_inv_sqrt_ptr = r_inv_sqrt_global_array;

747:   if (z_mem_type == PETSC_MEMTYPE_HOST) {
748:     z_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("z_managed", lda_z_global, m);
749:     Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(z_global_array, lda_z_global, m);
750:     Kokkos::deep_copy(z_managed, src);
751:     z_ptr = z_managed.data();
752:   }
753:   if (y_mem_type == PETSC_MEMTYPE_HOST) {
754:     y_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_managed", lda_z_global);
755:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_global_array, lda_z_global);
756:     Kokkos::deep_copy(y_managed, src);
757:     y_ptr = y_managed.data();
758:   }
759:   if (y_mean_mem_type == PETSC_MEMTYPE_HOST) {
760:     y_mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_mean_managed", lda_z_global);
761:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_mean_global_array, lda_z_global);
762:     Kokkos::deep_copy(y_mean_managed, src);
763:     y_mean_ptr = y_mean_managed.data();
764:   }
765:   if (r_inv_sqrt_mem_type == PETSC_MEMTYPE_HOST) {
766:     r_inv_sqrt_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("r_inv_sqrt_managed", lda_z_global);
767:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(r_inv_sqrt_global_array, lda_z_global);
768:     Kokkos::deep_copy(r_inv_sqrt_managed, src);
769:     r_inv_sqrt_ptr = r_inv_sqrt_managed.data();
770:   }

772:   /* Create unmanaged Kokkos views for global observation data */
773:   using view_2d_unmanaged = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
774:   using view_1d_unmanaged = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

776:   view_2d_unmanaged Z_global_view(z_ptr, lda_z_global, m);
777:   view_1d_unmanaged y_global_view(y_ptr, lda_z_global);
778:   view_1d_unmanaged y_mean_global_view(y_mean_ptr, lda_z_global);
779:   view_1d_unmanaged r_inv_sqrt_global_view(r_inv_sqrt_ptr, lda_z_global);

781:   /* Get access to global X matrix and mean vector */
782:   const PetscScalar *x_array, *mean_array;
783:   PetscScalar       *e_array;
784:   PetscInt           lda_x, lda_e;
785:   PetscMemType       x_mem_type, mean_mem_type, e_mem_type;

787:   PetscCall(MatDenseGetArrayReadAndMemType(X, &x_array, &x_mem_type));
788:   PetscCall(VecGetArrayReadAndMemType(impl->mean, &mean_array, &mean_mem_type));
789:   PetscCall(MatDenseGetArrayWriteAndMemType(en->ensemble, &e_array, &e_mem_type));
790:   PetscCall(MatDenseGetLDA(X, &lda_x));
791:   PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));

793:   /* Handle memory mirroring for state data */
794:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> x_managed;
795:   Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  mean_managed;
796:   Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> e_managed;

798:   const PetscScalar *x_ptr     = x_array;
799:   const PetscScalar *mean_ptr  = mean_array;
800:   PetscScalar       *e_ptr     = e_array;
801:   bool               e_is_copy = false;

803:   if (x_mem_type == PETSC_MEMTYPE_HOST) {
804:     x_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("x_managed", lda_x, m);
805:     Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(x_array, lda_x, m);
806:     Kokkos::deep_copy(x_managed, src);
807:     x_ptr = x_managed.data();
808:   }
809:   if (mean_mem_type == PETSC_MEMTYPE_HOST) {
810:     mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("mean_managed", lda_x);
811:     Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(mean_array, lda_x);
812:     Kokkos::deep_copy(mean_managed, src);
813:     mean_ptr = mean_managed.data();
814:   }
815:   if (e_mem_type == PETSC_MEMTYPE_HOST) {
816:     e_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("e_managed", lda_e, m);
817:     Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(e_array, lda_e, m);
818:     Kokkos::deep_copy(e_managed, src);
819:     e_ptr     = e_managed.data();
820:     e_is_copy = true;
821:   }

823:   /* Create unmanaged Kokkos views for global data */
824:   using view_2d_unmanaged_write = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
825:   view_2d_unmanaged       X_view(const_cast<PetscScalar *>(x_ptr), lda_x, m);
826:   view_1d_unmanaged       mean_view(mean_ptr, lda_x);
827:   view_2d_unmanaged_write E_view(e_ptr, lda_e, m);

829:   /* Determine chunk size to avoid OOM on large grids */
830:   PetscInt chunk_size;
831:   if (impl->batch_size > 0) {
832:     chunk_size = impl->batch_size;
833:   } else {
834:     /* Target ~2GB workspace. Approx memory per point: m*m*8 (T) + p*m*8 (Z) */
835:     /* With reuse: m*m*8 + p*m*8 */
836:     PetscInt mem_per_point = sizeof(PetscScalar) * (m * m + impl->n_obs_vertex * m);
837:     chunk_size             = (PetscInt)(2.0 * 1024 * 1024 * 1024 / mem_per_point);
838:     /* Clamp to reasonable max to avoid huge allocations even if memory allows */
839:     if (chunk_size > 32768) chunk_size = 32768;
840:   }

842:   if (chunk_size < 1) chunk_size = 1;
843:   if (chunk_size > n_vertices) chunk_size = n_vertices;

845:   /* OPTIMIZATION: Create device solver handle once, reuse across chunks */
846: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
847:   #if defined(KOKKOS_ENABLE_CUDA)
848:   cusolverDnHandle_t device_handle = nullptr;
849:   cusolverStatus_t   cusolver_status;
850:   if (impl->solver_handle) {
851:     device_handle = static_cast<cusolverDnHandle_t>(impl->solver_handle);
852:   } else {
853:     cusolver_status = cusolverDnCreate(&device_handle);
854:     PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreate failed");
855:     impl->solver_handle = static_cast<void *>(device_handle);
856:   }
857:   #elif defined(KOKKOS_ENABLE_HIP)
858:   rocblas_handle device_handle = nullptr;
859:   if (impl->solver_handle) {
860:     device_handle = static_cast<rocblas_handle>(impl->solver_handle);
861:   } else {
862:     rocblas_status hip_status = rocblas_create_handle(&device_handle);
863:     PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocblas_create_handle failed");
864:     impl->solver_handle = static_cast<void *>(device_handle);
865:   }
866:   #elif defined(KOKKOS_ENABLE_SYCL)
867:   sycl::queue *device_handle = nullptr;
868:   if (impl->solver_handle) {
869:     device_handle = static_cast<sycl::queue *>(impl->solver_handle);
870:   } else {
871:     device_handle       = new sycl::queue(sycl::gpu_selector_v);
872:     impl->solver_handle = static_cast<void *>(device_handle);
873:   }
874:   #endif
875: #endif

877:   /* ===================================================================== */
878:   /* OPTIMIZATION: Hoist allocations outside the chunk loop                */
879:   /* ===================================================================== */
880:   /* Allocate Kokkos Views once for the maximum chunk size */
881:   PetscInt n_obs_vertex_copy = impl->n_obs_vertex;

883:   EigenWorkspace *eigen_work = static_cast<EigenWorkspace *>(impl->eigen_work);
884:   if (!eigen_work) {
885:     eigen_work       = new EigenWorkspace();
886:     impl->eigen_work = static_cast<void *>(eigen_work);
887:   }

889:   /* Check if reallocation is needed */
890:   if (eigen_work->max_chunk_size < chunk_size || eigen_work->m != m || eigen_work->n_obs_vertex != n_obs_vertex_copy) {
891:     /* Free old device workspace if exists */
892: #if defined(KOKKOS_ENABLE_CUDA)
893:     PetscCallCUDA(cudaFree(eigen_work->d_work));
894:     PetscCallCUDA(cudaFree(eigen_work->d_info));
895:     PetscCallCUDA(cudaFree(eigen_work->d_A_contig));
896:     PetscCallCUDA(cudaFree(eigen_work->d_W_contig));
897:     if (eigen_work->syevj_params) cusolverDnDestroySyevjInfo(eigen_work->syevj_params);
898:     eigen_work->syevj_params = nullptr;
899: #elif defined(KOKKOS_ENABLE_HIP)
900:     PetscCallHIP(hipFree(eigen_work->d_work));
901:     PetscCallHIP(hipFree(eigen_work->d_info));
902:     PetscCallHIP(hipFree(eigen_work->d_A_contig));
903:     PetscCallHIP(hipFree(eigen_work->d_W_contig));
904: #elif defined(KOKKOS_ENABLE_SYCL)
905:     if (eigen_work->d_work) sycl::free(eigen_work->d_work, *device_handle);
906:     if (eigen_work->d_info) sycl::free(eigen_work->d_info, *device_handle);
907:     if (eigen_work->d_A_contig) sycl::free(eigen_work->d_A_contig, *device_handle);
908:     if (eigen_work->d_W_contig) sycl::free(eigen_work->d_W_contig, *device_handle);
909: #endif

911: #if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
912:   #if defined(PETSC_USE_COMPLEX)
913:     if (eigen_work->all_v) PetscCall(PetscFree4(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work, eigen_work->all_rwork));
914:   #else
915:     if (eigen_work->all_v) PetscCall(PetscFree3(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work));
916:   #endif
917: #endif

919:     /* Update dimensions */
920:     eigen_work->max_chunk_size = chunk_size;
921:     eigen_work->m              = m;
922:     eigen_work->n_obs_vertex   = n_obs_vertex_copy;

924:     /* Allocate Kokkos Views */
925:     eigen_work->Z_batch               = view_3d("Z_batch", chunk_size, n_obs_vertex_copy, m);
926:     eigen_work->S_batch               = eigen_work->Z_batch;
927:     eigen_work->T_batch               = view_3d("T_batch", chunk_size, m, m);
928:     eigen_work->V_batch               = eigen_work->T_batch;
929:     eigen_work->Lambda_batch          = view_2d("Lambda_batch", chunk_size, m);
930:     eigen_work->T_sqrt_batch          = view_3d("T_sqrt_batch", chunk_size, m, m);
931:     eigen_work->w_batch               = view_2d("w_batch", chunk_size, m);
932:     eigen_work->delta_batch           = view_2d("delta_batch", chunk_size, n_obs_vertex_copy);
933:     eigen_work->y_batch               = view_2d("y_batch", chunk_size, n_obs_vertex_copy);
934:     eigen_work->y_mean_batch          = view_2d("y_mean_batch", chunk_size, n_obs_vertex_copy);
935:     eigen_work->r_inv_sqrt_batch      = view_2d("r_inv_sqrt_batch", chunk_size, n_obs_vertex_copy);
936:     eigen_work->temp1_batch           = view_2d("temp1_batch", chunk_size, m);
937:     eigen_work->temp2_batch           = view_2d("temp2_batch", chunk_size, m);
938:     eigen_work->inv_sqrt_lambda_batch = view_2d("inv_sqrt_lambda_batch", chunk_size, m);

940:     /* Allocate solver workspace */
941: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
942:   #if defined(KOKKOS_ENABLE_CUDA)
943:     {
944:       /* Create syevj params */
945:       cusolver_status = cusolverDnCreateSyevjInfo(&eigen_work->syevj_params);
946:       PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreateSyevjInfo failed");

948:       /* Set default params */
949:       cusolverDnXsyevjSetTolerance(eigen_work->syevj_params, 1e-7);
950:       cusolverDnXsyevjSetMaxSweeps(eigen_work->syevj_params, 100);
951:       cusolverDnXsyevjSetSortEig(eigen_work->syevj_params, 1); /* Sort eigenvalues */

953:       /* Query workspace size */
954:       PetscScalar *d_A = eigen_work->T_batch.data();
955:       PetscScalar *d_W = eigen_work->Lambda_batch.data();
956:       int          lwork;
957:     #if defined(PETSC_USE_REAL_SINGLE)
958:       cusolver_status = cusolverDnSsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
959:     #else
960:       cusolver_status = cusolverDnDsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
961:     #endif
962:       PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched_bufferSize failed");
963:       eigen_work->lwork_device = lwork;

965:       /* Allocate workspace */
966:       PetscCallCUDA(cudaMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
967:       PetscCallCUDA(cudaMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
968:       PetscCallCUDA(cudaMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
969:       PetscCallCUDA(cudaMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
970:     }
971:   #elif defined(KOKKOS_ENABLE_HIP)
972:     {
973:         /* rocsolver_dsyevd does not support size query via -1.
974:          We use a safe upper bound estimate based on LAPACK dsyevd requirements.
975:       */
976:     #if defined(PETSC_USE_COMPLEX)
977:       int lwork = 0; /* Complex not supported on device */
978:     #else
979:       int lwork = 1 + 6 * m + 2 * m * m;
980:     #endif
981:       eigen_work->lwork_device = lwork;

983:       /* Allocate workspace */
984:       if (lwork > 0) {
985:         PetscCallHIP(hipMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
986:         PetscCallHIP(hipMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
987:         PetscCallHIP(hipMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
988:         PetscCallHIP(hipMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
989:       }
990:     }
991:   #elif defined(KOKKOS_ENABLE_SYCL)
992:     {
993:       /* Query workspace size for oneapi::mkl::lapack::syevd */
994:       /* For syevd, workspace size is typically: */
995:       /* lwork >= 1 + 6*n + 2*n*n for real, or */
996:       /* lwork >= 2*n + n*n for complex */
997:       int lwork;
998:     #if defined(PETSC_USE_COMPLEX)
999:       lwork = 2 * m + m * m;
1000:     #else
1001:       lwork = 1 + 6 * m + 2 * m * m;
1002:     #endif
1003:       eigen_work->lwork_device = lwork;

1005:       /* Allocate workspace using SYCL malloc_device */
1006:       eigen_work->d_work     = sycl::malloc_device<PetscScalar>(lwork, *device_handle);
1007:       eigen_work->d_info     = sycl::malloc_device<int>(chunk_size, *device_handle);
1008:       eigen_work->d_A_contig = sycl::malloc_device<PetscScalar>(chunk_size * m * m, *device_handle);
1009:       eigen_work->d_W_contig = sycl::malloc_device<PetscScalar>(chunk_size * m, *device_handle);
1010:       PetscCheck(eigen_work->d_work && eigen_work->d_info && eigen_work->d_A_contig && eigen_work->d_W_contig, PETSC_COMM_SELF, PETSC_ERR_MEM, "SYCL memory allocation failed");
1011:     }
1012:   #endif
1013: #else
1014:     {
1015:       PetscBLASInt n_blas;
1016:       PetscCall(PetscBLASIntCast(m, &n_blas));
1017:       eigen_work->n_blas = n_blas;

1019:       /* Query workspace size */
1020:       PetscBLASInt lwork_query = -1;
1021:       PetscScalar  work_query;
1022:       PetscBLASInt info;
1023:   #if defined(PETSC_USE_COMPLEX)
1024:       PetscReal rwork_query;
1025:       LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &rwork_query, &work_query, &lwork_query, &rwork_query, &info);
1026:   #else
1027:       LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &work_query, &work_query, &lwork_query, &info);
1028:   #endif
1029:       PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "LAPACK workspace query failed");
1030:       eigen_work->lwork = (PetscBLASInt)PetscRealPart(work_query);

1032:       /* Allocate workspace */
1033:   #if defined(PETSC_USE_COMPLEX)
1034:       PetscCall(PetscMalloc4(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work, chunk_size * (3 * m - 2), &eigen_work->all_rwork));
1035:   #else
1036:       PetscCall(PetscMalloc3(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work));
1037:   #endif
1038:     }
1039: #endif
1040:   }

1042:   /* Create aliases for current function use */
1043:   view_3d Z_batch_alloc               = eigen_work->Z_batch;
1044:   view_3d S_batch_alloc               = eigen_work->S_batch;
1045:   view_3d T_batch_alloc               = eigen_work->T_batch;
1046:   view_3d V_batch_alloc               = eigen_work->V_batch;
1047:   view_2d Lambda_batch_alloc          = eigen_work->Lambda_batch;
1048:   view_3d T_sqrt_batch_alloc          = eigen_work->T_sqrt_batch;
1049:   view_2d w_batch_alloc               = eigen_work->w_batch;
1050:   view_2d delta_batch_alloc           = eigen_work->delta_batch;
1051:   view_2d y_batch_alloc               = eigen_work->y_batch;
1052:   view_2d y_mean_batch_alloc          = eigen_work->y_mean_batch;
1053:   view_2d r_inv_sqrt_batch_alloc      = eigen_work->r_inv_sqrt_batch;
1054:   view_2d temp1_batch_alloc           = eigen_work->temp1_batch;
1055:   view_2d temp2_batch_alloc           = eigen_work->temp2_batch;
1056:   view_2d inv_sqrt_lambda_batch_alloc = eigen_work->inv_sqrt_lambda_batch;

1058:   /* Loop over chunks */
1059:   for (PetscInt chunk_start = 0; chunk_start < n_vertices; chunk_start += chunk_size) {
1060:     PetscInt chunk_end       = (chunk_start + chunk_size > n_vertices) ? n_vertices : chunk_start + chunk_size;
1061:     PetscInt n_batch_current = chunk_end - chunk_start;

1063:     /* Create subviews for current batch size */
1064:     auto Z_batch               = Kokkos::subview(Z_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1065:     auto S_batch               = Kokkos::subview(S_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1066:     auto T_batch               = Kokkos::subview(T_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1067:     auto V_batch               = Kokkos::subview(V_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1068:     auto Lambda_batch          = Kokkos::subview(Lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1069:     auto T_sqrt_batch          = Kokkos::subview(T_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
1070:     auto w_batch               = Kokkos::subview(w_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1071:     auto delta_batch           = Kokkos::subview(delta_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1072:     auto y_batch               = Kokkos::subview(y_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1073:     auto y_mean_batch          = Kokkos::subview(y_mean_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1074:     auto r_inv_sqrt_batch      = Kokkos::subview(r_inv_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1075:     auto temp1_batch           = Kokkos::subview(temp1_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1076:     auto temp2_batch           = Kokkos::subview(temp2_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
1077:     auto inv_sqrt_lambda_batch = Kokkos::subview(inv_sqrt_lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());

1079:     /* ===================================================================== */
1080:     /* Step 2.1.2: Fused observation extraction and S/Delta computation     */
1081:     /* ===================================================================== */
1082:     /* Extract local observations and immediately compute S and delta       */
1083:     /* This fusion eliminates one kernel launch and improves cache locality */
1084:     Kokkos::parallel_for(
1085:       "ExtractAndComputeSAndDelta", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
1086:         PetscInt i_global = chunk_start + i_local;
1087:         /* Get Q row for this grid point using CSR format */
1088:         PetscInt row_start = Q_i_view(i_global);
1089:         PetscInt row_end   = Q_i_view(i_global + 1);
1090:         PetscInt ncols     = row_end - row_start;

1092:         /* Extract observations and compute S/delta for this grid point */
1093:         for (PetscInt k = 0; k < ncols; k++) {
1094:           PetscInt    obs_idx = Q_j_view(row_start + k);
1095:           PetscScalar weight  = Q_a_view(row_start + k);

1097:           /* Extract observation vectors */
1098:           PetscScalar y_val      = y_global_view(obs_idx);
1099:           PetscScalar y_mean_val = y_mean_global_view(obs_idx);
1100:           PetscScalar r_inv_sqrt = r_inv_sqrt_global_view(obs_idx) * Kokkos::sqrt(PetscRealPart(weight));

1102:           /* Store for later use if needed */
1103:           y_batch(i_local, k)          = y_val;
1104:           y_mean_batch(i_local, k)     = y_mean_val;
1105:           r_inv_sqrt_batch(i_local, k) = r_inv_sqrt;

1107:           /* Compute delta immediately: delta = R^{-1/2}(y - y_mean) */
1108:           delta_batch(i_local, k) = (y_val - y_mean_val) * r_inv_sqrt;

1110:           /* Compute S row: S = R^{-1/2}(Z - y_mean * 1')/sqrt(m-1) */
1111:           PetscScalar scale_factor = scale * r_inv_sqrt;
1112:           for (int j = 0; j < m; j++) {
1113:             PetscScalar z_val      = Z_global_view(obs_idx, j);
1114:             Z_batch(i_local, k, j) = z_val; /* Store Z for potential later use */
1115:             S_batch(i_local, k, j) = (z_val - y_mean_val) * scale_factor;
1116:           }
1117:         }
1118:       });
1119:     Kokkos::fence();

1121:     /* DEBUG: Check S for NaNs */
1122:     if (PetscDefined(USE_DEBUG)) {
1123:       PetscInt nan_count = 0;
1124:       Kokkos::parallel_reduce(
1125:         "CheckS", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1126:         KOKKOS_LAMBDA(const int i, int &l_count) {
1127:           for (int j = 0; j < n_obs_vertex_copy; j++) {
1128:             for (int k = 0; k < m; k++) {
1129:               if (S_batch(i, j, k) != S_batch(i, j, k)) l_count++;
1130:             }
1131:           }
1132:         },
1133:         nan_count);
1134:       PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in S_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
1135:     }

1137:     /* ===================================================================== */
1138:     /* Step 2.1.4: Optimized T matrix formation (T = (1/rho)I + S^T * S)    */
1139:     /* ===================================================================== */
1140:     /* Compute T_i = (1/rho)I + S_i^T * S_i for current chunk */
1141:     /* Exploit symmetry: only compute upper triangle, then copy to lower */
1142:     /* This reduces operations by ~50% */
1143:     Kokkos::parallel_for(
1144:       "ComputeAllTMatrices", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
1145:         auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
1146:         auto T_i = Kokkos::subview(T_batch, i, Kokkos::ALL(), Kokkos::ALL());

1148:         /* Compute upper triangle of T_i = (1/rho)I + S_i^T * S_i */
1149:         /* T_i(j,k) = (1/rho)*delta_jk + sum_p S_i(p,j) * S_i(p,k) for j <= k */
1150:         for (int j = 0; j < m; j++) {
1151:           for (int k = j; k < m; k++) {
1152:             PetscScalar sum = (j == k) ? inflation_inv : 0.0;
1153:             for (int p = 0; p < n_obs_vertex_copy; p++) sum += S_i(p, j) * S_i(p, k);
1154:             T_i(j, k) = sum;
1155:           }
1156:         }

1158:         /* Copy upper triangle to lower triangle (T is symmetric) */
1159:         for (int j = 0; j < m; j++) {
1160:           for (int k = 0; k < j; k++) T_i(j, k) = T_i(k, j);
1161:         }
1162:       });
1163:     Kokkos::fence();

1165:     /* DEBUG: Check T for NaNs */
1166:     if (PetscDefined(USE_DEBUG)) {
1167:       PetscInt nan_count = 0;
1168:       Kokkos::parallel_reduce(
1169:         "CheckT", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1170:         KOKKOS_LAMBDA(const int i, int &l_count) {
1171:           for (int j = 0; j < m; j++) {
1172:             for (int k = 0; k < m; k++) {
1173:               if (T_batch(i, j, k) != T_batch(i, j, k)) l_count++;
1174:             }
1175:           }
1176:         },
1177:         nan_count);
1178:       PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in T_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
1179:     }

1181:     /* ===================================================================== */
1182:     /* Step 3.1.1: Batched eigendecomposition for current chunk            */
1183:     /* ===================================================================== */
1184:     /* Compute T_i = V_i * Lambda_i * V_i^T for current chunk */
1185: #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
1186:     PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, device_handle, eigen_work));
1187: #else
1188:     PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, eigen_work));
1189: #endif

1191:     /* DEBUG: Check Lambda for NaNs or negative values */
1192:     if (PetscDefined(USE_DEBUG)) {
1193:       PetscInt bad_lambda = 0;
1194:       Kokkos::parallel_reduce(
1195:         "CheckLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
1196:         KOKKOS_LAMBDA(const int i, int &l_count) {
1197:           for (int k = 0; k < m; k++) {
1198:             if (Lambda_batch(i, k) != Lambda_batch(i, k) || PetscRealPart(Lambda_batch(i, k)) < -1e-8) l_count++;
1199:           }
1200:         },
1201:         bad_lambda);
1202:       PetscCheck(bad_lambda == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " bad eigenvalues (NaN or negative) at chunk_start %" PetscInt_FMT, bad_lambda, chunk_start);
1203:     }

1205:     /* ===================================================================== */
1206:     /* Step 3.1.2: Precompute w and inv_sqrt_lambda for ensemble update    */
1207:     /* ===================================================================== */
1208:     /* Compute w_i = T_i^{-1} * (S_i^T * delta_i) using eigendecomposition */
1209:     /* Precompute 1/sqrt(Lambda) for use in ensemble update */
1210:     Kokkos::parallel_for(
1211:       "ComputeWeightsAndInvSqrtLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
1212:         auto S_i               = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
1213:         auto V_i               = Kokkos::subview(V_batch, i, Kokkos::ALL(), Kokkos::ALL());
1214:         auto Lambda_i          = Kokkos::subview(Lambda_batch, i, Kokkos::ALL());
1215:         auto delta_i           = Kokkos::subview(delta_batch, i, Kokkos::ALL());
1216:         auto w_i               = Kokkos::subview(w_batch, i, Kokkos::ALL());
1217:         auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i, Kokkos::ALL());
1218:         auto temp1             = Kokkos::subview(temp1_batch, i, Kokkos::ALL());
1219:         auto temp2             = Kokkos::subview(temp2_batch, i, Kokkos::ALL());

1221:         /* 1. Compute w_i = V * L^-1 * V^T * S^T * delta */
1222:         /* Step 1a: temp1 = S^T * delta using KokkosBlas::gemv for better vectorization */
1223:         KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, S_i, delta_i, 0.0, temp1);

1225:         /* Step 1b: temp2 = V^T * temp1 using KokkosBlas::gemv for better vectorization */
1226:         KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp1, 0.0, temp2);

1228:         /* Step 1c: temp2 = temp2 / Lambda */
1229:         for (int j = 0; j < m; j++) temp2(j) /= (Lambda_i(j) + 1.0e-14);

1231:         /* Step 1d: w = V * temp2 using KokkosBlas::gemv for better vectorization */
1232:         KokkosBlas::SerialGemv<KokkosBlas::Trans::NoTranspose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp2, 0.0, w_i);

1234:         /* 2. Precompute 1/sqrt(Lambda) for ensemble update */
1235:         for (int p = 0; p < m; p++) inv_sqrt_lambda_i(p) = 1.0 / Kokkos::sqrt(PetscRealPart(Lambda_i(p)) + 1.0e-14);
1236:       });
1237:     Kokkos::fence();

    /* ===================================================================== */
    /* Step 3.1.3: Fused G computation and ensemble update                  */
    /* ===================================================================== */
    /* Compute E[i,:] = mean[i] + X[i,:] * G_i on-the-fly */
    /* G_i is computed column-by-column and immediately applied */
    /* This eliminates the need to store G_batch, saving m*m*n_batch memory */
    /* One thread per local analysis point: each builds its own symmetric   */
    /* T^{1/2} and applies G = w*1^T + sqrt(m-1)*T^{1/2} to its ndof rows.  */
    /* NOTE: the flop estimate logged at the end of this routine mirrors    */
    /* this exact loop structure — keep them in sync if either changes.     */
    Kokkos::parallel_for(
      "FusedGComputeAndEnsembleUpdate", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
        /* Map the chunk-local batch index to the global vertex index */
        PetscInt i_global = chunk_start + i_local;

        /* Per-vertex row slices (ndof rows each) of the full state arrays */
        auto X_i    = Kokkos::subview(X_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
        auto E_i    = Kokkos::subview(E_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
        auto mean_i = Kokkos::subview(mean_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof));

        /* Per-vertex eigendecomposition products from Step 3.1.2 */
        auto V_i               = Kokkos::subview(V_batch, i_local, Kokkos::ALL(), Kokkos::ALL());
        auto w_i               = Kokkos::subview(w_batch, i_local, Kokkos::ALL());
        auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i_local, Kokkos::ALL());
        auto T_sqrt_i          = Kokkos::subview(T_sqrt_batch, i_local, Kokkos::ALL(), Kokkos::ALL());

        /* Initialize E_i with mean */
        for (int row = 0; row < ndof; row++) {
          PetscScalar m_val = mean_i(row);
          for (int col = 0; col < m; col++) E_i(row, col) = m_val;
        }

        /* Compute T_sqrt = V * diag(1/sqrt(Lambda)) * V^T */
        /* Optimized: Exploit symmetry - only compute upper triangle, then copy to lower */
        /* T_sqrt(j,k) = sum_p V(j,p) * V(k,p) / sqrt(Lambda(p)) for j <= k */
        /* NOTE(review): the product V(j,p)*V(k,p) is unconjugated — correct for
           real PetscScalar; confirm intent if complex scalars are ever built. */
        for (int j = 0; j < m; j++) {
          for (int k = j; k < m; k++) {
            PetscScalar sum = 0.0;
            for (int p = 0; p < m; p++) sum += V_i(j, p) * V_i(k, p) * inv_sqrt_lambda_i(p);
            T_sqrt_i(j, k) = sum;
          }
        }
        /* Copy upper triangle to lower triangle (T_sqrt is symmetric) */
        for (int j = 0; j < m; j++) {
          for (int k = 0; k < j; k++) T_sqrt_i(j, k) = T_sqrt_i(k, j);
        }

        /* Compute E_i += X_i * G_i column-by-column */
        /* G_i(:,k) = w_i + sqrt(m-1) * T_sqrt_i(:,k) */
        for (int k = 0; k < m; k++) {
          /* Compute column k of G on-the-fly */
          for (int row = 0; row < ndof; row++) {
            PetscScalar sum = 0.0;
            for (int j = 0; j < m; j++) {
              /* G_i(j,k) = w_i(j) + sqrt(m-1) * T_sqrt_i(j,k) */
              PetscScalar G_jk = w_i(j) + sqrt_m_minus_1 * T_sqrt_i(j, k);
              sum += X_i(row, j) * G_jk;
            }
            E_i(row, k) += sum;
          }
        }
      });
    Kokkos::fence();
  }

  /* Cleanup workspace */
  /* NOTE: Workspace is now persistent in impl->eigen_work and impl->solver_handle */
  /* It will be destroyed in PetscDALETKFDestroyLocalization_Kokkos */

  /* Copy back updated ensemble if needed */
  /* (e_is_copy presumably marks the case where e_managed is a separate device
     copy of the host ensemble array — confirm against the acquisition code) */
  if (e_is_copy) {
    Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> dst(e_array, lda_e, m);
    Kokkos::deep_copy(dst, e_managed);
  }

  /* Restore arrays: hand the raw device/host pointers back to PETSc */
  PetscCall(MatDenseRestoreArrayWriteAndMemType(en->ensemble, &e_array));
  PetscCall(VecRestoreArrayReadAndMemType(impl->mean, &mean_array));
  PetscCall(MatDenseRestoreArrayReadAndMemType(X, &x_array));

  /* Restore global observation arrays */
  PetscCall(VecRestoreArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array));
  PetscCall(VecRestoreArrayReadAndMemType(y_mean_global, &y_mean_global_array));
  PetscCall(VecRestoreArrayReadAndMemType(observation, &y_global_array));
  PetscCall(MatDenseRestoreArrayReadAndMemType(Z_global, &z_global_array));

  /* Ensemble has been updated in batched form above */
  PetscCall(MatAssemblyBegin(en->ensemble, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(en->ensemble, MAT_FINAL_ASSEMBLY));

  /* Log an estimate of the floating-point work done on the device so that
     -log_view reports meaningful flop rates for this routine. The per-step
     formulas below mirror the kernels' exact loop structures. */
  {
    MatInfo   info;
    PetscReal flops = 0.0;
    PetscReal n_obs_total;

    /* Total local observation couplings = local nonzeros of the localization
       operator Q; zero when no localization operator is present */
    if (impl->Q) {
      PetscCall(MatGetInfo(impl->Q, MAT_LOCAL, &info));
      n_obs_total = info.nz_used;
    } else {
      n_obs_total = 0.0;
    }

    /* Step 2.1.2: Fused observation extraction and S/Delta computation */
    flops += n_obs_total * (2.0 + 2.0 * m);

    /* Step 2.1.4: Optimized T matrix formation */
    flops += (PetscReal)n_vertices * m * (m + 1) * impl->n_obs_vertex;

    /* Step 3.1.2: Precompute w and inv_sqrt_lambda */
    flops += (PetscReal)n_vertices * (2.0 * m * impl->n_obs_vertex + 4.0 * m * m + 3.0 * m);

    /* Step 3.1.3: Fused G computation and ensemble update */
    /* T_sqrt: 1.5*m^3 + 1.5*m^2 */
    flops += (PetscReal)n_vertices * (1.5 * m * m * m + 1.5 * m * m);
    /* E update: ndof * m * (4*m + 1) */
    /* Note: G_jk computation (2 flops) is inside the inner loop, so it's 2*m*ndof*m */
    /* Matrix product X*G (2 flops) is also 2*m*ndof*m */
    flops += (PetscReal)n_vertices * ndof * m * (4.0 * m + 1.0);

    PetscCall(PetscLogGpuFlops(flops));
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}