Actual source code: letkf_local_analysis.kokkos.cxx
#include "../src/ml/da/impls/ensemble/letkf/letkf.h"
#include <petscblaslapack.h>
#include <Kokkos_Core.hpp>
#include <KokkosBlas.hpp>
#include <KokkosBatched_SVD_Decl.hpp>
#include <KokkosBatched_SVD_Serial_Impl.hpp>
#include <KokkosBatched_Gemm_Decl.hpp>
#include <KokkosBatched_Gemm_Serial_Impl.hpp>
#include <KokkosBatched_Util.hpp>

#if defined(KOKKOS_ENABLE_CUDA)
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <petscdevice_cuda.h>
#elif defined(KOKKOS_ENABLE_HIP)
#include <rocsolver/rocsolver.h>
#include <hip/hip_runtime.h>
#include <petscdevice_hip.h>
#elif defined(KOKKOS_ENABLE_SYCL)
#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>
#include <petscdevice_sycl.h>
#endif

/* ========================================================================== */
/*                 Batched Eigendecomposition for LETKF                       */
/* ========================================================================== */

/* Structure to hold reusable workspace for eigensolvers */
struct EigenWorkspace {
  /* Tracking for reuse */
  PetscInt max_chunk_size;
  PetscInt m;
  PetscInt n_obs_vertex;

  /* Persistent Kokkos Views */
  using exec_space = Kokkos::DefaultExecutionSpace;
  using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
  using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

  view_3d Z_batch;
  view_3d S_batch;
  view_3d T_batch;
  view_3d V_batch;
  view_2d Lambda_batch;
  view_3d T_sqrt_batch;
  view_2d w_batch;
  view_2d delta_batch;
  view_2d y_batch;
  view_2d y_mean_batch;
  view_2d r_inv_sqrt_batch;
  view_2d temp1_batch;
  view_2d temp2_batch;
  view_2d inv_sqrt_lambda_batch;

  /* Host workspace */
  PetscScalar *all_v;
  PetscReal   *all_lambda;
  PetscScalar *all_work;
#if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork;
#endif
  PetscBLASInt lwork;
  PetscBLASInt n_blas;

  /* Device workspace (identical layout on all three backends; CUDA additionally
     carries the syevj parameter object) */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
  syevjInfo_t syevj_params;
#endif
  PetscScalar *d_work;
  int         *d_info;
  PetscScalar *d_A_contig;
  PetscScalar *d_W_contig;
  int          lwork_device;
#endif

  EigenWorkspace() : max_chunk_size(0), m(0), n_obs_vertex(0), all_v(nullptr), all_lambda(nullptr), all_work(nullptr)
  {
#if defined(PETSC_USE_COMPLEX)
    all_rwork = nullptr;
#endif
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
    d_work     = nullptr;
    d_info     = nullptr;
    d_A_contig = nullptr;
    d_W_contig = nullptr;
#endif
#if defined(KOKKOS_ENABLE_CUDA)
    syevj_params = nullptr;
#endif
  }
};
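
/*
  Usage sketch (illustrative only, not part of the build): the workspace is
  created lazily, grown when the chunk geometry changes, reused across chunks,
  and freed in PetscDALETKFDestroyLocalization_Kokkos(). A hypothetical caller
  would do:

    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);
    if (!work) {                                // first call: create and cache
      work             = new EigenWorkspace();
      impl->eigen_work = static_cast<void *>(work);
    }
    // ...(re)allocate views if chunk_size, m, or n_obs_vertex grew, then reuse...

  The actual grow-on-demand logic lives in PetscDALETKFLocalAnalysis_GPU below.
*/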

/*
  BatchedEigenSolve_Host - Compute eigendecomposition for a batch of symmetric matrices (CPU version)

  Input Parameters:
+ T_batch - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch - number of matrices in the batch
. n_size  - size of each matrix (n_size x n_size)
- work    - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Uses LAPACK's syev routine; each matrix is solved serially, with the batch processed in parallel over host threads.
*/
#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
static PetscErrorCode BatchedEigenSolve_Host(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Create host mirrors and copy data in one operation */
  /* This is required for HIP+complex where create_mirror_view + deep_copy fails */
  auto T_host      = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), T_batch);
  auto Lambda_host = Kokkos::create_mirror_view(Kokkos::HostSpace(), Lambda_batch);
  auto V_host      = Kokkos::create_mirror_view(Kokkos::HostSpace(), V_batch);

  /* Use pre-allocated workspace */
  PetscScalar *all_v      = work->all_v;
  PetscReal   *all_lambda = work->all_lambda;
  PetscScalar *all_work   = work->all_work;
  PetscBLASInt lwork      = work->lwork;
  PetscBLASInt n_blas     = work->n_blas;
#if defined(PETSC_USE_COMPLEX)
  PetscReal *all_rwork = work->all_rwork;
#endif

  /* Process each matrix in parallel on the host using LAPACK */
  Kokkos::parallel_for(
    "BatchedEigenSolve_Host", Kokkos::RangePolicy<Kokkos::DefaultHostExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      PetscBLASInt n   = n_blas;
      PetscBLASInt lda = n;
      PetscBLASInt info;
      PetscBLASInt lw = lwork;

      /* Pointers for this matrix */
      PetscScalar *v_ptr      = all_v + i * n_size * n_size;
      PetscReal   *lambda_ptr = all_lambda + i * n_size;
      PetscScalar *work_ptr   = all_work + i * lwork;
#if defined(PETSC_USE_COMPLEX)
      PetscReal *rwork_ptr = all_rwork + i * (3 * n_size - 2);
#endif

      /* Copy T_host(i, :, :) to v_ptr (column-major) */
      for (PetscInt j = 0; j < n_size; j++) {
        for (PetscInt k = 0; k < n_size; k++) v_ptr[k + j * n_size] = T_host(i, k, j);
      }

      /* Compute eigendecomposition: T = V * Lambda * V^T */
#if defined(PETSC_USE_COMPLEX)
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, rwork_ptr, &info);
#else
      LAPACKsyev_("V", "U", &n, v_ptr, &lda, lambda_ptr, work_ptr, &lw, &info);
#endif

      if (info != 0) {
        /* We cannot return an error code from a lambda, so we abort.
           In production code, a reduction should be used to report errors. */
        Kokkos::abort("LAPACK eigendecomposition failed in parallel region");
      }

      /* Copy results back to host views */
      for (PetscInt j = 0; j < n_size; j++) {
        Lambda_host(i, j) = (PetscScalar)lambda_ptr[j];
        for (PetscInt k = 0; k < n_size; k++) V_host(i, k, j) = v_ptr[k + j * n_size];
      }
    });

  /* Copy results back to device */
  Kokkos::deep_copy(Lambda_batch, Lambda_host);
  Kokkos::deep_copy(V_batch, V_host);
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
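
/*
  Minimal standalone sketch of the LAPACK call used above, for one n x n real
  symmetric matrix (real scalars assumed; the 3x3 matrix A is hypothetical and
  stored column-major):

    PetscBLASInt n = 3, lda = 3, lwork = -1, info;
    PetscScalar  A[9] = {2, 1, 0, 1, 2, 1, 0, 1, 2}, wkopt;
    PetscReal    lambda[3];
    LAPACKsyev_("V", "U", &n, A, &lda, lambda, &wkopt, &lwork, &info); // workspace query
    lwork           = (PetscBLASInt)PetscRealPart(wkopt);
    PetscScalar *wk = new PetscScalar[lwork];
    LAPACKsyev_("V", "U", &n, A, &lda, lambda, wk, &lwork, &info);     // A is overwritten with V
    delete[] wk;

  On success (info == 0) lambda holds the eigenvalues in ascending order and
  the columns of A hold the corresponding eigenvectors, so A_in = V * diag(lambda) * V^T.
*/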

/*
  BatchedEigenSolve_Device - Compute eigendecomposition for a batch of symmetric matrices (device version)

  Input Parameters:
+ T_batch       - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch       - number of matrices in the batch
. n_size        - size of each matrix (n_size x n_size)
. device_handle - device-specific solver handle (cusolverDnHandle_t, rocblas_handle, or sycl::queue*)
- work          - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Uses vendor-specific symmetric eigensolvers:
  - CUDA: cuSOLVER's syevjBatched (natively batched)
  - HIP:  rocSOLVER's syevd, called in a loop over the batch
  - SYCL: oneMKL's syevd, called in a loop over the batch
*/
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t cusolverH, EigenWorkspace *work)
{
  cusolverStatus_t cusolver_status;

  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  syevjInfo_t  syevj_params = work->syevj_params;
  PetscScalar *d_work       = work->d_work;
  int         *d_info       = work->d_info;
  PetscScalar *d_A_contig   = work->d_A_contig;
  PetscScalar *d_W_contig   = work->d_W_contig;
  int          lwork        = work->lwork_device;

  /* Copy T_batch to contiguous column-major layout for cuSOLVER */
  Kokkos::parallel_for(
    "ReorganizeForCuSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* Solve batched eigendecomposition. Complex scalars would require the
     Cheevj/Zheevj variants and a real-valued eigenvalue array, which this
     path does not provide. */
#if defined(PETSC_USE_COMPLEX)
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on CUDA backend for LETKF");
#elif defined(PETSC_USE_REAL_SINGLE)
  cusolver_status = cusolverDnSsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
#else
  cusolver_status = cusolverDnDsyevjBatched(cusolverH, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, n_size, d_A_contig, n_size, d_W_contig, d_work, lwork, d_info, syevj_params, n_batch);
#endif
  PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched failed");

  /* Check info */
  int *h_info;
  PetscCall(PetscMalloc1(n_batch, &h_info));
  PetscCallCUDA(cudaMemcpy(h_info, d_info, sizeof(int) * n_batch, cudaMemcpyDeviceToHost));
  for (int i = 0; i < n_batch; i++) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "cuSOLVER eigendecomposition failed for matrix %d: info=%d", i, h_info[i]);
  PetscCall(PetscFree(h_info));

  /* Copy results back from the contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_HIP)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle rocblasH, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  PetscScalar *d_work = work->d_work;
  (void)d_work; /* unused in complex builds */
  int         *d_info     = work->d_info;
  PetscScalar *d_A_contig = work->d_A_contig;
  PetscScalar *d_W_contig = work->d_W_contig;

  /* Copy T_batch to contiguous column-major layout for rocSOLVER */
  Kokkos::parallel_for(
    "ReorganizeForRocSOLVER", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* Loop over the batch calling rocsolver_*syevd for each matrix; newer
     rocSOLVER releases also provide strided-batched eigensolvers that could
     replace this loop */
#if defined(PETSC_USE_COMPLEX)
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on HIP backend for LETKF");
#else
  for (int i = 0; i < n_batch; i++) {
    PetscScalar   *A_ptr    = d_A_contig + i * n_size * n_size;
    PetscScalar   *W_ptr    = d_W_contig + i * n_size;
    int           *info_ptr = d_info + i;
    rocblas_status hip_status;

#if defined(PETSC_USE_REAL_SINGLE)
    hip_status = rocsolver_ssyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
#else
    hip_status = rocsolver_dsyevd(rocblasH, rocblas_evect_original, rocblas_fill_upper, n_size, A_ptr, n_size, W_ptr, d_work, info_ptr);
#endif
    PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocsolver_*syevd failed for batch %d", i);
  }
#endif

  /* Check info */
  int *h_info;
  PetscCall(PetscMalloc1(n_batch, &h_info));
  PetscCallHIP(hipMemcpy(h_info, d_info, sizeof(int) * n_batch, hipMemcpyDeviceToHost));
  for (int i = 0; i < n_batch; i++) PetscCheck(h_info[i] == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocSOLVER eigendecomposition failed for matrix %d: info=%d", i, h_info[i]);
  PetscCall(PetscFree(h_info));

  /* Copy results back from the contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_SYCL)
static PetscErrorCode BatchedEigenSolve_Device(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *q, EigenWorkspace *work)
{
  PetscFunctionBegin;
  /* Use pre-allocated workspace */
  PetscScalar *d_work     = work->d_work;
  PetscScalar *d_A_contig = work->d_A_contig;
  PetscScalar *d_W_contig = work->d_W_contig;

  /* Copy T_batch to contiguous column-major layout for oneMKL */
  Kokkos::parallel_for(
    "ReorganizeForOneMKL", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        for (int k = 0; k < n_size; k++) d_A_contig[i * n_size * n_size + k * n_size + j] = T_batch(i, j, k);
      }
    });
  Kokkos::fence();

  /* oneMKL doesn't have a native batched syevd, so we loop over the batch.
     Note that oneapi::mkl::lapack routines report failures by throwing
     exceptions rather than through an info parameter, hence the try/catch. */
#if defined(PETSC_USE_COMPLEX)
  SETERRQ(PETSC_COMM_SELF, PETSC_ERR_SUP, "Complex numbers not supported on SYCL backend for LETKF");
#else
  for (int i = 0; i < n_batch; i++) {
    PetscScalar *A_ptr = d_A_contig + i * n_size * n_size;
    PetscScalar *W_ptr = d_W_contig + i * n_size;

    try {
      oneapi::mkl::lapack::syevd(*q, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, n_size, A_ptr, n_size, W_ptr, d_work, work->lwork_device);
      q->wait();
    } catch (std::exception const &e) {
      SETERRQ(PETSC_COMM_SELF, PETSC_ERR_LIB, "oneMKL syevd failed for batch %d: %s", i, e.what());
    }
  }
#endif

  /* Copy results back from the contiguous layout to V_batch */
  Kokkos::parallel_for(
    "CopyResultsBack", Kokkos::RangePolicy<Kokkos::DefaultExecutionSpace>(0, n_batch), KOKKOS_LAMBDA(const int i) {
      for (int j = 0; j < n_size; j++) {
        Lambda_batch(i, j) = d_W_contig[i * n_size + j];
        for (int k = 0; k < n_size; k++) V_batch(i, j, k) = d_A_contig[i * n_size * n_size + k * n_size + j];
      }
    });
  Kokkos::fence();
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
#endif

/*
  BatchedEigenSolve - Compute eigendecomposition for a batch of symmetric matrices

  Input Parameters:
+ T_batch       - batch of symmetric matrices (n_batch x n_size x n_size)
. n_batch       - number of matrices in the batch
. n_size        - size of each matrix (n_size x n_size)
. device_handle - device-specific solver handle (device builds only)
- work          - reusable workspace structure

  Output Parameters:
+ Lambda_batch - eigenvalues for each matrix (n_batch x n_size)
- V_batch      - eigenvectors for each matrix (n_batch x n_size x n_size)

  Notes:
  Dispatcher that calls the appropriate backend (device or host).
*/
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, cusolverDnHandle_t device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_HIP)
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, rocblas_handle device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#elif defined(KOKKOS_ENABLE_SYCL)
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, sycl::queue *device_handle, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Device(T_batch, Lambda_batch, V_batch, n_batch, n_size, device_handle, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
#else
static PetscErrorCode BatchedEigenSolve(Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> T_batch, Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> Lambda_batch, Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, Kokkos::DefaultExecutionSpace> V_batch, PetscInt n_batch, PetscInt n_size, EigenWorkspace *work)
{
  PetscFunctionBegin;
  PetscCall(BatchedEigenSolve_Host(T_batch, Lambda_batch, V_batch, n_batch, n_size, work));
  PetscFunctionReturn(PETSC_SUCCESS);
}
#endif
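
/*
  Call-site sketch: on device builds the caller supplies the backend handle,
  on host builds it does not,

    #if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
      PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch, m, device_handle, eigen_work));
    #else
      PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch, m, eigen_work));
    #endif

  which is exactly the pattern used in PetscDALETKFLocalAnalysis_GPU below.
*/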

/*
  PetscDALETKFSetupLocalization_Kokkos - Prepares device views for the localization matrix Q
*/
PetscErrorCode PetscDALETKFSetupLocalization_Kokkos(PetscDA_LETKF *impl, Mat H)
{
  PetscInt nrows;

  PetscFunctionBegin;
  PetscCheck(impl->Q, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Localization matrix impl->Q has not been created");
  PetscCall(PetscKokkosInitializeCheck());

  /* Get CSR data */
  PetscInt rstart, rend, i, nnz;
  PetscCall(MatGetOwnershipRange(impl->Q, &rstart, &rend));
  nrows = rend - rstart;

  /* Create IS for local observations needed by this process */
  /* We need to find all unique column indices in the local rows of Q */
  {
    PetscInt     *obs_indices;
    PetscInt      n_obs_local_total = 0;
    PetscInt      max_obs           = nrows * impl->n_obs_vertex;
    PetscInt      count             = 0;
    PetscHMapI    ht;
    PetscHashIter iter;
    PetscBool     missing;

    PetscCall(PetscHMapICreate(&ht));
    PetscCall(PetscMalloc1(max_obs, &obs_indices));

    for (i = 0; i < nrows; i++) {
      const PetscInt    *cols;
      const PetscScalar *vals;
      PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
      for (PetscInt k = 0; k < nnz; k++) {
        PetscCall(PetscHMapIPut(ht, cols[k], &iter, &missing));
        if (missing) {
          obs_indices[count] = cols[k];
          count++;
        }
      }
      PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
    }
    n_obs_local_total = count;

    /* Sort indices for consistent ordering */
    PetscCall(PetscSortInt(n_obs_local_total, obs_indices));

    /* Create IS and VecScatter */
    PetscCall(ISCreateGeneral(PETSC_COMM_SELF, n_obs_local_total, obs_indices, PETSC_COPY_VALUES, &impl->obs_is_local));

    /* Create global-to-local map for observations */
    PetscCall(PetscHMapICreate(&impl->obs_g2l));
    for (i = 0; i < n_obs_local_total; i++) {
      PetscCall(PetscHMapIPut(impl->obs_g2l, obs_indices[i], &iter, &missing));
      PetscCall(PetscHMapIIterSet(impl->obs_g2l, iter, i));
    }

    PetscCall(PetscFree(obs_indices));
    PetscCall(PetscHMapIDestroy(&ht));
  }

  /* Create work vectors and scatter context */
  {
    PetscInt n_obs_local_total;
    PetscCall(ISGetLocalSize(impl->obs_is_local, &n_obs_local_total));

    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->obs_work));
    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->y_mean_work));
    PetscCall(VecCreateSeq(PETSC_COMM_SELF, n_obs_local_total, &impl->r_inv_sqrt_work));

    Vec gvec;
    IS  is_to;
    PetscCall(MatCreateVecs(H, NULL, &gvec)); /* Create template global vector (left vector = rows = observations) */
    PetscCall(ISCreateStride(PETSC_COMM_SELF, n_obs_local_total, 0, 1, &is_to));
    PetscCall(VecScatterCreate(gvec, impl->obs_is_local, impl->obs_work, is_to, &impl->obs_scat));
    PetscCall(VecDestroy(&gvec));
    PetscCall(ISDestroy(&is_to));
  }

  /* Define View types */
  using view_1d_int    = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
  using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

  /* Allocate device views */
  view_1d_int    *d_Q_i = new view_1d_int("Q_i", nrows + 1);
  view_1d_int    *d_Q_j = new view_1d_int("Q_j", nrows * impl->n_obs_vertex);
  view_1d_scalar *d_Q_a = new view_1d_scalar("Q_a", nrows * impl->n_obs_vertex);

  /* Create host mirrors */
  auto h_Q_i = Kokkos::create_mirror_view(*d_Q_i);
  auto h_Q_j = Kokkos::create_mirror_view(*d_Q_j);
  auto h_Q_a = Kokkos::create_mirror_view(*d_Q_a);

  /* Fill host mirrors with LOCAL indices into obs_work */
  h_Q_i(0) = 0;
  for (i = 0; i < nrows; i++) {
    const PetscInt    *cols;
    const PetscScalar *vals;
    PetscCall(MatGetRow(impl->Q, rstart + i, &nnz, &cols, &vals));
    h_Q_i(i + 1) = h_Q_i(i) + nnz;
    for (PetscInt k = 0; k < nnz; k++) {
      PetscInt local_idx;
      PetscCall(ISLocate(impl->obs_is_local, cols[k], &local_idx));
      PetscCheck(local_idx >= 0, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Observation index %" PetscInt_FMT " not found in local IS", cols[k]);
      h_Q_j(h_Q_i(i) + k) = local_idx;
      h_Q_a(h_Q_i(i) + k) = vals[k];
    }
    PetscCall(MatRestoreRow(impl->Q, rstart + i, &nnz, &cols, &vals));
  }

  /* Copy to device */
  Kokkos::deep_copy(*d_Q_i, h_Q_i);
  Kokkos::deep_copy(*d_Q_j, h_Q_j);
  Kokkos::deep_copy(*d_Q_a, h_Q_a);

  /* Store in impl */
  PetscCheck(!impl->Q_device_i, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONGSTATE, "Device views for Q have already been set up");
  impl->Q_device_i = static_cast<void *>(d_Q_i);
  impl->Q_device_j = static_cast<void *>(d_Q_j);
  impl->Q_device_a = static_cast<void *>(d_Q_a);
  PetscFunctionReturn(PETSC_SUCCESS);
}
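
/*
  Lifecycle sketch (illustrative; the loop body and the vector/matrix names y,
  Z, y_mean, r_inv_sqrt are hypothetical stand-ins for the caller's data):

    PetscCall(PetscDALETKFSetupLocalization_Kokkos(impl, H)); // once, after Q is assembled
    for (PetscInt step = 0; step < n_steps; step++) {
      // ...scatter observations for this cycle, then run the batched analysis...
      PetscCall(PetscDALETKFLocalAnalysis_GPU(da, impl, m, n_vertices, X, y, Z, y_mean, r_inv_sqrt));
    }
    PetscCall(PetscDALETKFDestroyLocalization_Kokkos(impl));  // frees views, handles, workspace

  The device CSR views, the solver handle, and the eigensolver workspace all
  persist across analysis calls, so setup costs are paid once per Q.
*/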

PetscErrorCode PetscDALETKFDestroyLocalization_Kokkos(PetscDA_LETKF *impl)
{
  PetscFunctionBegin;
  PetscCall(VecDestroy(&impl->obs_work));
  PetscCall(VecDestroy(&impl->y_mean_work));
  PetscCall(VecDestroy(&impl->r_inv_sqrt_work));
  PetscCall(VecScatterDestroy(&impl->obs_scat));
  PetscCall(MatDestroy(&impl->Z_work));
  PetscCall(PetscHMapIDestroy(&impl->obs_g2l));
  if (impl->Q_device_i) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_i);
    impl->Q_device_i = NULL;
  }
  if (impl->Q_device_j) {
    using view_1d_int = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_int *>(impl->Q_device_j);
    impl->Q_device_j = NULL;
  }
  if (impl->Q_device_a) {
    using view_1d_scalar = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;
    delete static_cast<view_1d_scalar *>(impl->Q_device_a);
    impl->Q_device_a = NULL;
  }

  /* Destroy solver handle and workspace */
  if (impl->eigen_work) {
    EigenWorkspace *work = static_cast<EigenWorkspace *>(impl->eigen_work);

#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(work->d_A_contig));
    PetscCallCUDA(cudaFree(work->d_W_contig));
    PetscCallCUDA(cudaFree(work->d_work));
    PetscCallCUDA(cudaFree(work->d_info));
    if (work->syevj_params) cusolverDnDestroySyevjInfo(work->syevj_params);
#elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(work->d_A_contig));
    PetscCallHIP(hipFree(work->d_W_contig));
    PetscCallHIP(hipFree(work->d_work));
    PetscCallHIP(hipFree(work->d_info));
#elif defined(KOKKOS_ENABLE_SYCL)
    if (impl->solver_handle) {
      sycl::queue *q = static_cast<sycl::queue *>(impl->solver_handle);
      if (work->d_A_contig) sycl::free(work->d_A_contig, *q);
      if (work->d_W_contig) sycl::free(work->d_W_contig, *q);
      if (work->d_work) sycl::free(work->d_work, *q);
      if (work->d_info) sycl::free(work->d_info, *q);
    }
#endif
#else
#if defined(PETSC_USE_COMPLEX)
    PetscCall(PetscFree4(work->all_v, work->all_lambda, work->all_work, work->all_rwork));
#else
    PetscCall(PetscFree3(work->all_v, work->all_lambda, work->all_work));
#endif
#endif

    delete work;
    impl->eigen_work = NULL;
  }

  if (impl->solver_handle) {
#if defined(KOKKOS_ENABLE_CUDA)
    cusolverDnDestroy(static_cast<cusolverDnHandle_t>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_HIP)
    rocblas_destroy_handle(static_cast<rocblas_handle>(impl->solver_handle));
#elif defined(KOKKOS_ENABLE_SYCL)
    delete static_cast<sycl::queue *>(impl->solver_handle);
#endif
    impl->solver_handle = NULL;
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* ========================================================================== */
/*                   LETKF Local Analysis (Main Function)                     */
/* ========================================================================== */

/*
  PetscDALETKFLocalAnalysis_GPU - Performs the local LETKF analysis for all grid points (Kokkos version)

  Input Parameters:
+ da                - the PetscDA context
. impl              - LETKF implementation data
. m                 - ensemble size
. n_vertices        - number of grid points
. X                 - global anomaly matrix (state_size x m)
. observation       - observation vector
. Z_global          - global observation ensemble (obs_size x m)
. y_mean_global     - global observation mean
- r_inv_sqrt_global - global R^{-1/2}

  Output:
. da->ensemble - updated with the analysis ensemble

  Notes:
  This function performs the local analysis loop for LETKF, processing each grid point
  independently using its local observations as defined by the localization matrix Q.
  The per-vertex problems are grouped into chunks and solved with batched Kokkos kernels
  plus a batched eigensolver on the default execution space (the GPU when a device
  backend is enabled).

  All per-vertex workspace lives in the batched device Views held in EigenWorkspace;
  because the analysis at each vertex is serial and independent, no communication is
  needed inside the loop.
*/
PetscErrorCode PetscDALETKFLocalAnalysis_GPU(PetscDA da, PetscDA_LETKF *impl, PetscInt m, PetscInt n_vertices, Mat X, Vec observation, Mat Z_global, Vec y_mean_global, Vec r_inv_sqrt_global)
{
  PetscDA_Ensemble *en = (PetscDA_Ensemble *)da->data;
  PetscInt          ndof;
  PetscReal         sqrt_m_minus_1, scale, inflation_inv;

  PetscFunctionBegin;
  ndof           = da->ndof;
  scale          = 1.0 / PetscSqrtReal((PetscReal)(m - 1));
  sqrt_m_minus_1 = PetscSqrtReal((PetscReal)(m - 1));
  inflation_inv  = 1.0 / en->inflation; /* (1/rho) for the T matrix: T = (1/rho)I + S^T*S */

  /* ===================================================================== */
  /* Step 2.1.1: Create batched workspace for ALL grid points              */
  /* ===================================================================== */
  /*
    NOTE ON PARALLELISM STRATEGY:
    We use Kokkos::RangePolicy over grid points (n_vertices) combined with KokkosBatched::Serial kernels.
    Since the data layout is LayoutLeft (column-major) to match PETSc/LAPACK, the index 'i' (grid point)
    is the fastest varying index (stride 1).

    RangePolicy maps consecutive threads to consecutive 'i', ensuring perfect memory coalescing
    when accessing arrays like S_batch(i, p, j).

    Using TeamPolicy/TeamVectorRange to parallelize the inner loops (m or p) would assign a team to 'i',
    causing threads within a team to access S_batch with stride 'n_vertices', which leads to
    uncoalesced memory access and poor performance on GPUs.

    Therefore, RangePolicy + SerialGemm is the optimal strategy for this data layout.
  */
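
  /*
    Concretely (illustrative): for a LayoutLeft view S_batch with extents
    (chunk, p_max, m), element S_batch(i, p, j) sits at linear offset
    i + p*chunk + j*chunk*p_max, so for fixed (p, j) consecutive threads i
    touch consecutive addresses, which is the coalesced pattern described above.
  */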
  using exec_space = Kokkos::DefaultExecutionSpace;
  using view_3d    = Kokkos::View<PetscScalar ***, Kokkos::LayoutLeft, exec_space>;
  using view_2d    = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>;

  /* ===================================================================== */
  /* Step 2.1.2a: Pre-extract Q matrix CSR data for device access          */
  /* ===================================================================== */
  using view_1d_int_const    = Kokkos::View<const PetscInt *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_scalar_const = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_int          = Kokkos::View<PetscInt *, Kokkos::LayoutLeft>;
  using view_1d_scalar       = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft>;

  view_1d_int_const    Q_i_view;
  view_1d_int_const    Q_j_view;
  view_1d_scalar_const Q_a_view;

  if (impl->Q_device_i) {
    /* Use pre-allocated device views */
    view_1d_int    *d_Q_i = static_cast<view_1d_int *>(impl->Q_device_i);
    view_1d_int    *d_Q_j = static_cast<view_1d_int *>(impl->Q_device_j);
    view_1d_scalar *d_Q_a = static_cast<view_1d_scalar *>(impl->Q_device_a);

    Q_i_view = view_1d_int_const(d_Q_i->data(), d_Q_i->extent(0));
    Q_j_view = view_1d_int_const(d_Q_j->data(), d_Q_j->extent(0));
    Q_a_view = view_1d_scalar_const(d_Q_a->data(), d_Q_a->extent(0));
  } else {
    /* No device views: falling back to host pointers would be unsafe without
       unified memory, so require the dedicated setup routine instead */
    PetscCheck(PETSC_FALSE, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Q matrix must be set up with PetscDALETKFSetupLocalization_Kokkos");
  }

  /* Get global observation data arrays */
  const PetscScalar *z_global_array, *y_global_array, *y_mean_global_array, *r_inv_sqrt_global_array;
  PetscInt           lda_z_global;
  PetscMemType       z_mem_type, y_mem_type, y_mean_mem_type, r_inv_sqrt_mem_type;

  PetscCall(MatDenseGetArrayReadAndMemType(Z_global, &z_global_array, &z_mem_type));
  PetscCall(VecGetArrayReadAndMemType(observation, &y_global_array, &y_mem_type));
  PetscCall(VecGetArrayReadAndMemType(y_mean_global, &y_mean_global_array, &y_mean_mem_type));
  PetscCall(VecGetArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array, &r_inv_sqrt_mem_type));
  PetscCall(MatDenseGetLDA(Z_global, &lda_z_global));

  /* Handle memory mirroring for observation data */
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> z_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  y_mean_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  r_inv_sqrt_managed;

  const PetscScalar *z_ptr          = z_global_array;
  const PetscScalar *y_ptr          = y_global_array;
  const PetscScalar *y_mean_ptr     = y_mean_global_array;
  const PetscScalar *r_inv_sqrt_ptr = r_inv_sqrt_global_array;

  if (z_mem_type == PETSC_MEMTYPE_HOST) {
    z_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("z_managed", lda_z_global, m);
    Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(z_global_array, lda_z_global, m);
    Kokkos::deep_copy(z_managed, src);
    z_ptr = z_managed.data();
  }
  if (y_mem_type == PETSC_MEMTYPE_HOST) {
    y_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_global_array, lda_z_global);
    Kokkos::deep_copy(y_managed, src);
    y_ptr = y_managed.data();
  }
  if (y_mean_mem_type == PETSC_MEMTYPE_HOST) {
    y_mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("y_mean_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(y_mean_global_array, lda_z_global);
    Kokkos::deep_copy(y_mean_managed, src);
    y_mean_ptr = y_mean_managed.data();
  }
  if (r_inv_sqrt_mem_type == PETSC_MEMTYPE_HOST) {
    r_inv_sqrt_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("r_inv_sqrt_managed", lda_z_global);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(r_inv_sqrt_global_array, lda_z_global);
    Kokkos::deep_copy(r_inv_sqrt_managed, src);
    r_inv_sqrt_ptr = r_inv_sqrt_managed.data();
  }

  /* Create unmanaged Kokkos views for global observation data */
  using view_2d_unmanaged = Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  using view_1d_unmanaged = Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

  view_2d_unmanaged Z_global_view(z_ptr, lda_z_global, m);
  view_1d_unmanaged y_global_view(y_ptr, lda_z_global);
  view_1d_unmanaged y_mean_global_view(y_mean_ptr, lda_z_global);
  view_1d_unmanaged r_inv_sqrt_global_view(r_inv_sqrt_ptr, lda_z_global);

  /* Get access to global X matrix and mean vector */
  const PetscScalar *x_array, *mean_array;
  PetscScalar       *e_array;
  PetscInt           lda_x, lda_e;
  PetscMemType       x_mem_type, mean_mem_type, e_mem_type;

  PetscCall(MatDenseGetArrayReadAndMemType(X, &x_array, &x_mem_type));
  PetscCall(VecGetArrayReadAndMemType(impl->mean, &mean_array, &mean_mem_type));
  PetscCall(MatDenseGetArrayWriteAndMemType(en->ensemble, &e_array, &e_mem_type));
  PetscCall(MatDenseGetLDA(X, &lda_x));
  PetscCall(MatDenseGetLDA(en->ensemble, &lda_e));

  /* Handle memory mirroring for state data */
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> x_managed;
  Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>  mean_managed;
  Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space> e_managed;

  const PetscScalar *x_ptr     = x_array;
  const PetscScalar *mean_ptr  = mean_array;
  PetscScalar       *e_ptr     = e_array;
  bool               e_is_copy = false;

  if (x_mem_type == PETSC_MEMTYPE_HOST) {
    x_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("x_managed", lda_x, m);
    Kokkos::View<const PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(x_array, lda_x, m);
    Kokkos::deep_copy(x_managed, src);
    x_ptr = x_managed.data();
  }
  if (mean_mem_type == PETSC_MEMTYPE_HOST) {
    mean_managed = Kokkos::View<PetscScalar *, Kokkos::LayoutLeft, exec_space>("mean_managed", lda_x);
    Kokkos::View<const PetscScalar *, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(mean_array, lda_x);
    Kokkos::deep_copy(mean_managed, src);
    mean_ptr = mean_managed.data();
  }
  if (e_mem_type == PETSC_MEMTYPE_HOST) {
    e_managed = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, exec_space>("e_managed", lda_e, m);
    Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> src(e_array, lda_e, m);
    Kokkos::deep_copy(e_managed, src);
    e_ptr     = e_managed.data();
    e_is_copy = true;
  }

  /* Create unmanaged Kokkos views for global data */
  using view_2d_unmanaged_write = Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
  view_2d_unmanaged       X_view(const_cast<PetscScalar *>(x_ptr), lda_x, m);
  view_1d_unmanaged       mean_view(mean_ptr, lda_x);
  view_2d_unmanaged_write E_view(e_ptr, lda_e, m);

  /* Determine chunk size to avoid OOM on large grids */
  PetscInt chunk_size;
  if (impl->batch_size > 0) {
    chunk_size = impl->batch_size;
  } else {
    /* Target ~2 GiB of workspace; with view reuse the dominant cost per grid
       point is m*m scalars (T) plus n_obs_vertex*m scalars (Z/S) */
    PetscInt mem_per_point = sizeof(PetscScalar) * (m * m + impl->n_obs_vertex * m);
    chunk_size = (PetscInt)(2.0 * 1024 * 1024 * 1024 / mem_per_point);
    /* Clamp to a reasonable maximum to avoid huge allocations even if memory allows */
    if (chunk_size > 32768) chunk_size = 32768;
  }

  if (chunk_size < 1) chunk_size = 1;
  if (chunk_size > n_vertices) chunk_size = n_vertices;
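
  /* Worked example (assuming double precision, m = 64 members, n_obs_vertex = 100):
     mem_per_point = 8 * (64*64 + 100*64) = 83968 bytes, so the 2 GiB target gives
     chunk_size ~ 2147483648 / 83968 ~ 25575 grid points, below the 32768 clamp. */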

  /* OPTIMIZATION: Create the device solver handle once and reuse it across chunks */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
  cusolverDnHandle_t device_handle = nullptr;
  cusolverStatus_t   cusolver_status;
  if (impl->solver_handle) {
    device_handle = static_cast<cusolverDnHandle_t>(impl->solver_handle);
  } else {
    cusolver_status = cusolverDnCreate(&device_handle);
    PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreate failed");
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#elif defined(KOKKOS_ENABLE_HIP)
  rocblas_handle device_handle = nullptr;
  if (impl->solver_handle) {
    device_handle = static_cast<rocblas_handle>(impl->solver_handle);
  } else {
    rocblas_status hip_status = rocblas_create_handle(&device_handle);
    PetscCheck(hip_status == rocblas_status_success, PETSC_COMM_SELF, PETSC_ERR_LIB, "rocblas_create_handle failed");
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#elif defined(KOKKOS_ENABLE_SYCL)
  sycl::queue *device_handle = nullptr;
  if (impl->solver_handle) {
    device_handle = static_cast<sycl::queue *>(impl->solver_handle);
  } else {
    device_handle       = new sycl::queue(sycl::gpu_selector_v);
    impl->solver_handle = static_cast<void *>(device_handle);
  }
#endif
#endif

  /* ===================================================================== */
  /* OPTIMIZATION: Hoist allocations outside the chunk loop                */
  /* ===================================================================== */
  /* Allocate Kokkos Views once for the maximum chunk size */
  PetscInt n_obs_vertex_copy = impl->n_obs_vertex;

  EigenWorkspace *eigen_work = static_cast<EigenWorkspace *>(impl->eigen_work);
  if (!eigen_work) {
    eigen_work       = new EigenWorkspace();
    impl->eigen_work = static_cast<void *>(eigen_work);
  }

  /* Check if reallocation is needed */
  if (eigen_work->max_chunk_size < chunk_size || eigen_work->m != m || eigen_work->n_obs_vertex != n_obs_vertex_copy) {
    /* Free the old device workspace if it exists */
#if defined(KOKKOS_ENABLE_CUDA)
    PetscCallCUDA(cudaFree(eigen_work->d_work));
    PetscCallCUDA(cudaFree(eigen_work->d_info));
    PetscCallCUDA(cudaFree(eigen_work->d_A_contig));
    PetscCallCUDA(cudaFree(eigen_work->d_W_contig));
    if (eigen_work->syevj_params) cusolverDnDestroySyevjInfo(eigen_work->syevj_params);
    eigen_work->syevj_params = nullptr;
#elif defined(KOKKOS_ENABLE_HIP)
    PetscCallHIP(hipFree(eigen_work->d_work));
    PetscCallHIP(hipFree(eigen_work->d_info));
    PetscCallHIP(hipFree(eigen_work->d_A_contig));
    PetscCallHIP(hipFree(eigen_work->d_W_contig));
#elif defined(KOKKOS_ENABLE_SYCL)
    if (eigen_work->d_work) sycl::free(eigen_work->d_work, *device_handle);
    if (eigen_work->d_info) sycl::free(eigen_work->d_info, *device_handle);
    if (eigen_work->d_A_contig) sycl::free(eigen_work->d_A_contig, *device_handle);
    if (eigen_work->d_W_contig) sycl::free(eigen_work->d_W_contig, *device_handle);
#endif

#if !defined(KOKKOS_ENABLE_CUDA) && !defined(KOKKOS_ENABLE_HIP) && !defined(KOKKOS_ENABLE_SYCL)
#if defined(PETSC_USE_COMPLEX)
    if (eigen_work->all_v) PetscCall(PetscFree4(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work, eigen_work->all_rwork));
#else
    if (eigen_work->all_v) PetscCall(PetscFree3(eigen_work->all_v, eigen_work->all_lambda, eigen_work->all_work));
#endif
#endif

    /* Update dimensions */
    eigen_work->max_chunk_size = chunk_size;
    eigen_work->m              = m;
    eigen_work->n_obs_vertex   = n_obs_vertex_copy;

    /* Allocate Kokkos Views. Note the deliberate aliasing: S_batch shares memory
       with Z_batch (the raw Z values are not needed once the scaled anomalies are
       formed) and V_batch shares memory with T_batch (T is consumed by the
       eigensolve that produces V). This halves the 3-D workspace footprint. */
    eigen_work->Z_batch               = view_3d("Z_batch", chunk_size, n_obs_vertex_copy, m);
    eigen_work->S_batch               = eigen_work->Z_batch;
    eigen_work->T_batch               = view_3d("T_batch", chunk_size, m, m);
    eigen_work->V_batch               = eigen_work->T_batch;
    eigen_work->Lambda_batch          = view_2d("Lambda_batch", chunk_size, m);
    eigen_work->T_sqrt_batch          = view_3d("T_sqrt_batch", chunk_size, m, m);
    eigen_work->w_batch               = view_2d("w_batch", chunk_size, m);
    eigen_work->delta_batch           = view_2d("delta_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->y_batch               = view_2d("y_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->y_mean_batch          = view_2d("y_mean_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->r_inv_sqrt_batch      = view_2d("r_inv_sqrt_batch", chunk_size, n_obs_vertex_copy);
    eigen_work->temp1_batch           = view_2d("temp1_batch", chunk_size, m);
    eigen_work->temp2_batch           = view_2d("temp2_batch", chunk_size, m);
    eigen_work->inv_sqrt_lambda_batch = view_2d("inv_sqrt_lambda_batch", chunk_size, m);

    /* Allocate solver workspace */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
#if defined(KOKKOS_ENABLE_CUDA)
    {
      /* Create syevj params */
      cusolver_status = cusolverDnCreateSyevjInfo(&eigen_work->syevj_params);
      PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDnCreateSyevjInfo failed");

      /* Set default params */
      cusolverDnXsyevjSetTolerance(eigen_work->syevj_params, 1e-7);
      cusolverDnXsyevjSetMaxSweeps(eigen_work->syevj_params, 100);
      cusolverDnXsyevjSetSortEig(eigen_work->syevj_params, 1); /* Sort eigenvalues */

      /* Query workspace size */
      PetscScalar *d_A = eigen_work->T_batch.data();
      PetscScalar *d_W = eigen_work->Lambda_batch.data();
      int          lwork;
#if defined(PETSC_USE_REAL_SINGLE)
      cusolver_status = cusolverDnSsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
#else
      cusolver_status = cusolverDnDsyevjBatched_bufferSize(device_handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_UPPER, m, d_A, m, d_W, &lwork, eigen_work->syevj_params, chunk_size);
#endif
      PetscCheck(cusolver_status == CUSOLVER_STATUS_SUCCESS, PETSC_COMM_SELF, PETSC_ERR_LIB, "cusolverDn*syevjBatched_bufferSize failed");
      eigen_work->lwork_device = lwork;

      /* Allocate workspace */
      PetscCallCUDA(cudaMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
      PetscCallCUDA(cudaMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
    }
#elif defined(KOKKOS_ENABLE_HIP)
    {
      /* rocsolver_dsyevd does not support a workspace-size query via lwork = -1,
         so use a safe upper bound based on the LAPACK dsyevd requirement */
#if defined(PETSC_USE_COMPLEX)
      int lwork = 0; /* Complex not supported on this device backend */
#else
      int lwork = 1 + 6 * m + 2 * m * m;
#endif
      eigen_work->lwork_device = lwork;

      /* Allocate workspace */
      if (lwork > 0) {
        PetscCallHIP(hipMalloc(&eigen_work->d_work, sizeof(PetscScalar) * lwork));
        PetscCallHIP(hipMalloc(&eigen_work->d_info, sizeof(int) * chunk_size));
        PetscCallHIP(hipMalloc(&eigen_work->d_A_contig, sizeof(PetscScalar) * chunk_size * m * m));
        PetscCallHIP(hipMalloc(&eigen_work->d_W_contig, sizeof(PetscScalar) * chunk_size * m));
      }
    }
#elif defined(KOKKOS_ENABLE_SYCL)
    {
      std::int64_t lwork;
#if defined(PETSC_USE_COMPLEX)
      lwork = 2 * m + m * m; /* Complex is rejected in the solver; placeholder size */
#else
      /* Query the scratchpad size required by oneapi::mkl::lapack::syevd rather
         than relying on the LAPACK lwork formula, which is only a lower bound
         for the oneMKL implementation */
      lwork = oneapi::mkl::lapack::syevd_scratchpad_size<PetscScalar>(*device_handle, oneapi::mkl::job::vec, oneapi::mkl::uplo::upper, m, m);
#endif
      eigen_work->lwork_device = (int)lwork;

      /* Allocate workspace using SYCL malloc_device */
      eigen_work->d_work     = sycl::malloc_device<PetscScalar>(lwork, *device_handle);
      eigen_work->d_info     = sycl::malloc_device<int>(chunk_size, *device_handle);
      eigen_work->d_A_contig = sycl::malloc_device<PetscScalar>(chunk_size * m * m, *device_handle);
      eigen_work->d_W_contig = sycl::malloc_device<PetscScalar>(chunk_size * m, *device_handle);
      PetscCheck(eigen_work->d_work && eigen_work->d_info && eigen_work->d_A_contig && eigen_work->d_W_contig, PETSC_COMM_SELF, PETSC_ERR_MEM, "SYCL memory allocation failed");
    }
#endif
#else
    {
      PetscBLASInt n_blas;
      PetscCall(PetscBLASIntCast(m, &n_blas));
      eigen_work->n_blas = n_blas;

      /* Query workspace size */
      PetscBLASInt lwork_query = -1;
      PetscScalar  work_query;
      PetscBLASInt info;
#if defined(PETSC_USE_COMPLEX)
      PetscReal rwork_query;
      LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &rwork_query, &work_query, &lwork_query, &rwork_query, &info);
#else
      LAPACKsyev_("V", "U", &n_blas, &work_query, &n_blas, &work_query, &work_query, &lwork_query, &info);
#endif
      PetscCheck(info == 0, PETSC_COMM_SELF, PETSC_ERR_LIB, "LAPACK workspace query failed");
      eigen_work->lwork = (PetscBLASInt)PetscRealPart(work_query);

      /* Allocate workspace */
#if defined(PETSC_USE_COMPLEX)
      PetscCall(PetscMalloc4(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work, chunk_size * (3 * m - 2), &eigen_work->all_rwork));
#else
      PetscCall(PetscMalloc3(chunk_size * m * m, &eigen_work->all_v, chunk_size * m, &eigen_work->all_lambda, chunk_size * eigen_work->lwork, &eigen_work->all_work));
#endif
    }
#endif
  }

  /* Create aliases for current function use */
  view_3d Z_batch_alloc               = eigen_work->Z_batch;
  view_3d S_batch_alloc               = eigen_work->S_batch;
  view_3d T_batch_alloc               = eigen_work->T_batch;
  view_3d V_batch_alloc               = eigen_work->V_batch;
  view_2d Lambda_batch_alloc          = eigen_work->Lambda_batch;
  view_3d T_sqrt_batch_alloc          = eigen_work->T_sqrt_batch;
  view_2d w_batch_alloc               = eigen_work->w_batch;
  view_2d delta_batch_alloc           = eigen_work->delta_batch;
  view_2d y_batch_alloc               = eigen_work->y_batch;
  view_2d y_mean_batch_alloc          = eigen_work->y_mean_batch;
  view_2d r_inv_sqrt_batch_alloc      = eigen_work->r_inv_sqrt_batch;
  view_2d temp1_batch_alloc           = eigen_work->temp1_batch;
  view_2d temp2_batch_alloc           = eigen_work->temp2_batch;
  view_2d inv_sqrt_lambda_batch_alloc = eigen_work->inv_sqrt_lambda_batch;

  /* Loop over chunks */
  for (PetscInt chunk_start = 0; chunk_start < n_vertices; chunk_start += chunk_size) {
    PetscInt chunk_end       = (chunk_start + chunk_size > n_vertices) ? n_vertices : chunk_start + chunk_size;
    PetscInt n_batch_current = chunk_end - chunk_start;

    /* Create subviews for the current batch size */
    auto Z_batch               = Kokkos::subview(Z_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto S_batch               = Kokkos::subview(S_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto T_batch               = Kokkos::subview(T_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto V_batch               = Kokkos::subview(V_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto Lambda_batch          = Kokkos::subview(Lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto T_sqrt_batch          = Kokkos::subview(T_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL(), Kokkos::ALL());
    auto w_batch               = Kokkos::subview(w_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto delta_batch           = Kokkos::subview(delta_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto y_batch               = Kokkos::subview(y_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto y_mean_batch          = Kokkos::subview(y_mean_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto r_inv_sqrt_batch      = Kokkos::subview(r_inv_sqrt_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto temp1_batch           = Kokkos::subview(temp1_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto temp2_batch           = Kokkos::subview(temp2_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());
    auto inv_sqrt_lambda_batch = Kokkos::subview(inv_sqrt_lambda_batch_alloc, Kokkos::make_pair(0, (int)n_batch_current), Kokkos::ALL());

    /* ===================================================================== */
    /* Step 2.1.2: Fused observation extraction and S/delta computation      */
    /* ===================================================================== */
    /* Extract local observations and immediately compute S and delta;       */
    /* fusing the two steps saves a kernel launch and improves locality      */
    Kokkos::parallel_for(
      "ExtractAndComputeSAndDelta", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
        PetscInt i_global = chunk_start + i_local;
        /* Get the Q row for this grid point using the CSR format */
        PetscInt row_start = Q_i_view(i_global);
        PetscInt row_end   = Q_i_view(i_global + 1);
        PetscInt ncols     = row_end - row_start;

        /* Extract observations and compute S/delta for this grid point */
        for (PetscInt k = 0; k < ncols; k++) {
          PetscInt    obs_idx = Q_j_view(row_start + k);
          PetscScalar weight  = Q_a_view(row_start + k);

          /* Extract observation vectors */
          PetscScalar y_val      = y_global_view(obs_idx);
          PetscScalar y_mean_val = y_mean_global_view(obs_idx);
          PetscScalar r_inv_sqrt = r_inv_sqrt_global_view(obs_idx) * Kokkos::sqrt(PetscRealPart(weight));

          /* Store for later use */
          y_batch(i_local, k)          = y_val;
          y_mean_batch(i_local, k)     = y_mean_val;
          r_inv_sqrt_batch(i_local, k) = r_inv_sqrt;

          /* Compute delta immediately: delta = R^{-1/2}(y - y_mean) */
          delta_batch(i_local, k) = (y_val - y_mean_val) * r_inv_sqrt;

          /* Compute the S row: S = R^{-1/2}(Z - y_mean * 1')/sqrt(m-1).
             Since S_batch deliberately aliases Z_batch (see the allocation above),
             the raw Z value is not stored separately. */
          PetscScalar scale_factor = scale * r_inv_sqrt;
          for (int j = 0; j < m; j++) S_batch(i_local, k, j) = (Z_global_view(obs_idx, j) - y_mean_val) * scale_factor;
        }
      });
    Kokkos::fence();

    /* DEBUG: Check S for NaNs */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt nan_count = 0;
      Kokkos::parallel_reduce(
        "CheckS", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int j = 0; j < n_obs_vertex_copy; j++) {
            for (int k = 0; k < m; k++) {
              if (S_batch(i, j, k) != S_batch(i, j, k)) l_count++;
            }
          }
        },
        nan_count);
      PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in S_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
    }

    /* ===================================================================== */
    /* Step 2.1.4: Optimized T matrix formation (T = (1/rho)I + S^T * S)     */
    /* ===================================================================== */
    /* Compute T_i = (1/rho)I + S_i^T * S_i for the current chunk.           */
    /* Exploit symmetry: compute only the upper triangle, then copy it to    */
    /* the lower triangle, reducing the work by ~50%.                        */
    Kokkos::parallel_for(
      "ComputeAllTMatrices", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
        auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
        auto T_i = Kokkos::subview(T_batch, i, Kokkos::ALL(), Kokkos::ALL());

        /* Compute the upper triangle of T_i = (1/rho)I + S_i^T * S_i */
        /* T_i(j,k) = (1/rho)*delta_jk + sum_p S_i(p,j) * S_i(p,k) for j <= k */
        for (int j = 0; j < m; j++) {
          for (int k = j; k < m; k++) {
            PetscScalar sum = (j == k) ? inflation_inv : 0.0;
            for (int p = 0; p < n_obs_vertex_copy; p++) sum += S_i(p, j) * S_i(p, k);
            T_i(j, k) = sum;
          }
        }

        /* Copy the upper triangle to the lower triangle (T is symmetric) */
        for (int j = 0; j < m; j++) {
          for (int k = 0; k < j; k++) T_i(j, k) = T_i(k, j);
        }
      });
    Kokkos::fence();
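
    /*
      Background (standard LETKF formulation): with S the R-whitened, scaled
      observation anomalies and rho the multiplicative inflation, the analysis
      weight covariance in ensemble space is

        P_tilde = [ (1/rho) I + S^T S ]^{-1} = T^{-1},

      so the eigendecomposition of T computed next yields both the mean-update
      weights w = T^{-1} S^T delta and the inverse square root of T used for
      the perturbation update.
    */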

    /* DEBUG: Check T for NaNs */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt nan_count = 0;
      Kokkos::parallel_reduce(
        "CheckT", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int j = 0; j < m; j++) {
            for (int k = 0; k < m; k++) {
              if (T_batch(i, j, k) != T_batch(i, j, k)) l_count++;
            }
          }
        },
        nan_count);
      PetscCheck(nan_count == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " NaNs in T_batch at chunk_start %" PetscInt_FMT, nan_count, chunk_start);
    }

    /* ===================================================================== */
    /* Step 3.1.1: Batched eigendecomposition for the current chunk          */
    /* ===================================================================== */
    /* Compute T_i = V_i * Lambda_i * V_i^T for the current chunk */
#if defined(KOKKOS_ENABLE_CUDA) || defined(KOKKOS_ENABLE_HIP) || defined(KOKKOS_ENABLE_SYCL)
    PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, device_handle, eigen_work));
#else
    PetscCall(BatchedEigenSolve(T_batch, Lambda_batch, V_batch, n_batch_current, m, eigen_work));
#endif

    /* DEBUG: Check Lambda for NaNs or negative values */
    if (PetscDefined(USE_DEBUG)) {
      PetscInt bad_lambda = 0;
      Kokkos::parallel_reduce(
        "CheckLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current),
        KOKKOS_LAMBDA(const int i, PetscInt &l_count) {
          for (int k = 0; k < m; k++) {
            if (Lambda_batch(i, k) != Lambda_batch(i, k) || PetscRealPart(Lambda_batch(i, k)) < -1e-8) l_count++;
          }
        },
        bad_lambda);
      PetscCheck(bad_lambda == 0, PETSC_COMM_SELF, PETSC_ERR_FP, "Found %" PetscInt_FMT " bad eigenvalues (NaN or negative) at chunk_start %" PetscInt_FMT, bad_lambda, chunk_start);
    }
    /* ===================================================================== */
    /* Step 3.1.2: Precompute w and inv_sqrt_lambda for ensemble update */
    /* ===================================================================== */
    /* Compute w_i = T_i^{-1} * (S_i^T * delta_i) using the eigendecomposition */
    /* Precompute 1/sqrt(Lambda) for use in the ensemble update */
    Kokkos::parallel_for(
      "ComputeWeightsAndInvSqrtLambda", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i) {
        auto S_i = Kokkos::subview(S_batch, i, Kokkos::ALL(), Kokkos::ALL());
        auto V_i = Kokkos::subview(V_batch, i, Kokkos::ALL(), Kokkos::ALL());
        auto Lambda_i = Kokkos::subview(Lambda_batch, i, Kokkos::ALL());
        auto delta_i = Kokkos::subview(delta_batch, i, Kokkos::ALL());
        auto w_i = Kokkos::subview(w_batch, i, Kokkos::ALL());
        auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i, Kokkos::ALL());
        auto temp1 = Kokkos::subview(temp1_batch, i, Kokkos::ALL());
        auto temp2 = Kokkos::subview(temp2_batch, i, Kokkos::ALL());

        /* 1. Compute w_i = V * Lambda^{-1} * V^T * S^T * delta */
        /* Step 1a: temp1 = S^T * delta via KokkosBlas::SerialGemv for better vectorization */
        KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, S_i, delta_i, 0.0, temp1);

        /* Step 1b: temp2 = V^T * temp1 via KokkosBlas::SerialGemv for better vectorization */
        KokkosBlas::SerialGemv<KokkosBlas::Trans::Transpose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp1, 0.0, temp2);

        /* Step 1c: temp2 = temp2 / Lambda (element-wise; the 1e-14 shift guards against division by a near-zero eigenvalue) */
        for (int j = 0; j < m; j++) temp2(j) /= (Lambda_i(j) + 1.0e-14);

        /* Step 1d: w = V * temp2 via KokkosBlas::SerialGemv for better vectorization */
        KokkosBlas::SerialGemv<KokkosBlas::Trans::NoTranspose, KokkosBlas::Algo::Gemv::Unblocked>::invoke(1.0, V_i, temp2, 0.0, w_i);

        /* 2. Precompute 1/sqrt(Lambda) for the ensemble update */
        for (int p = 0; p < m; p++) inv_sqrt_lambda_i(p) = 1.0 / Kokkos::sqrt(PetscRealPart(Lambda_i(p)) + 1.0e-14);
      });
    Kokkos::fence();
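    /* Note: with T_i = V_i * Lambda_i * V_i^T, applying T_i^{-1} reduces to
         T_i^{-1} x = V_i * diag(1/lambda) * V_i^T * x,
       i.e. two m x m GEMVs plus an O(m) scaling once temp1 = S^T * delta is
       formed; this reuses the factorization already required for T^{-1/2} in
       the ensemble update below, so no separate direct solve is needed. */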
    /* ===================================================================== */
    /* Step 3.1.3: Fused G computation and ensemble update */
    /* ===================================================================== */
    /* Compute E[i,:] = mean[i] + X[i,:] * G_i on the fly: G_i is formed */
    /* column-by-column and applied immediately, which avoids storing G_batch */
    /* and saves n_batch*m*m scalars of workspace */
    Kokkos::parallel_for(
      "FusedGComputeAndEnsembleUpdate", Kokkos::RangePolicy<exec_space>(0, n_batch_current), KOKKOS_LAMBDA(const int i_local) {
        PetscInt i_global = chunk_start + i_local;

        auto X_i    = Kokkos::subview(X_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
        auto E_i    = Kokkos::subview(E_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof), Kokkos::ALL());
        auto mean_i = Kokkos::subview(mean_view, Kokkos::make_pair(i_global * ndof, (i_global + 1) * ndof));

        auto V_i = Kokkos::subview(V_batch, i_local, Kokkos::ALL(), Kokkos::ALL());
        auto w_i = Kokkos::subview(w_batch, i_local, Kokkos::ALL());
        auto inv_sqrt_lambda_i = Kokkos::subview(inv_sqrt_lambda_batch, i_local, Kokkos::ALL());
        auto T_sqrt_i = Kokkos::subview(T_sqrt_batch, i_local, Kokkos::ALL(), Kokkos::ALL());

        /* Initialize E_i with the mean */
        for (int row = 0; row < ndof; row++) {
          PetscScalar m_val = mean_i(row);
          for (int col = 0; col < m; col++) E_i(row, col) = m_val;
        }

        /* Compute T_sqrt = V * diag(1/sqrt(Lambda)) * V^T */
        /* Exploit symmetry: only compute the upper triangle, then copy to lower */
        /* T_sqrt(j,k) = sum_p V(j,p) * V(k,p) / sqrt(Lambda(p)) for j <= k */
        for (int j = 0; j < m; j++) {
          for (int k = j; k < m; k++) {
            PetscScalar sum = 0.0;
            for (int p = 0; p < m; p++) sum += V_i(j, p) * V_i(k, p) * inv_sqrt_lambda_i(p);
            T_sqrt_i(j, k) = sum;
          }
        }
        /* Copy upper triangle to lower triangle (T_sqrt is symmetric) */
        for (int j = 0; j < m; j++) {
          for (int k = 0; k < j; k++) T_sqrt_i(j, k) = T_sqrt_i(k, j);
        }

        /* Accumulate E_i += X_i * G_i column-by-column, where */
        /* G_i(:,k) = w_i + sqrt(m-1) * T_sqrt_i(:,k) */
        for (int k = 0; k < m; k++) {
          /* Compute column k of G on the fly */
          for (int row = 0; row < ndof; row++) {
            PetscScalar sum = 0.0;
            for (int j = 0; j < m; j++) {
              /* G_i(j,k) = w_i(j) + sqrt(m-1) * T_sqrt_i(j,k) */
              PetscScalar G_jk = w_i(j) + sqrt_m_minus_1 * T_sqrt_i(j, k);
              sum += X_i(row, j) * G_jk;
            }
            E_i(row, k) += sum;
          }
        }
      });
    Kokkos::fence();
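    /* Note: since T_sqrt_i = V_i * diag(1/sqrt(lambda)) * V_i^T = T_i^{-1/2},
       the kernel above applies the square-root ensemble transform
         E_i = mean_i * 1^T + X_i * (w_i * 1^T + sqrt(m-1) * T_i^{-1/2}),
       updating the mean and the perturbations in a single pass over X_i. */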
  }

  /* Cleanup: the solver workspace is persistent in impl->eigen_work and */
  /* impl->solver_handle; it is destroyed in PetscDALETKFDestroyLocalization_Kokkos */
  /* Copy the updated ensemble back if a staging copy was used */
  if (e_is_copy) {
    Kokkos::View<PetscScalar **, Kokkos::LayoutLeft, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> dst(e_array, lda_e, m);
    Kokkos::deep_copy(dst, e_managed);
  }
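  /* (dst wraps the raw host array e_array, with leading dimension lda_e, as an
     unmanaged view so deep_copy can write the device results directly into the
     ensemble matrix storage without an intermediate allocation.) */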
  /* Restore arrays */
  PetscCall(MatDenseRestoreArrayWriteAndMemType(en->ensemble, &e_array));
  PetscCall(VecRestoreArrayReadAndMemType(impl->mean, &mean_array));
  PetscCall(MatDenseRestoreArrayReadAndMemType(X, &x_array));

  /* Restore global observation arrays */
  PetscCall(VecRestoreArrayReadAndMemType(r_inv_sqrt_global, &r_inv_sqrt_global_array));
  PetscCall(VecRestoreArrayReadAndMemType(y_mean_global, &y_mean_global_array));
  PetscCall(VecRestoreArrayReadAndMemType(observation, &y_global_array));
  PetscCall(MatDenseRestoreArrayReadAndMemType(Z_global, &z_global_array));

  /* Ensemble has been updated in batched form above */
  PetscCall(MatAssemblyBegin(en->ensemble, MAT_FINAL_ASSEMBLY));
  PetscCall(MatAssemblyEnd(en->ensemble, MAT_FINAL_ASSEMBLY));
  {
    MatInfo   info;
    PetscReal flops = 0.0;
    PetscReal n_obs_total;

    if (impl->Q) {
      PetscCall(MatGetInfo(impl->Q, MAT_LOCAL, &info));
      n_obs_total = info.nz_used;
    } else {
      n_obs_total = 0.0;
    }

    /* Step 2.1.2: Fused observation extraction and S/Delta computation */
    flops += n_obs_total * (2.0 + 2.0 * m);

    /* Step 2.1.4: Optimized T matrix formation */
    flops += (PetscReal)n_vertices * m * (m + 1) * impl->n_obs_vertex;

    /* Step 3.1.2: Precompute w and inv_sqrt_lambda */
    flops += (PetscReal)n_vertices * (2.0 * m * impl->n_obs_vertex + 4.0 * m * m + 3.0 * m);

    /* Step 3.1.3: Fused G computation and ensemble update */
    /* T_sqrt: 1.5*m^3 + 1.5*m^2 */
    flops += (PetscReal)n_vertices * (1.5 * m * m * m + 1.5 * m * m);
    /* E update: ndof * m * (4*m + 1); the G_jk evaluation (2 flops) and the */
    /* X*G multiply-add (2 flops) both sit inside the innermost loop over j */
    flops += (PetscReal)n_vertices * ndof * m * (4.0 * m + 1.0);

    PetscCall(PetscLogGpuFlops(flops));
  }
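  /* Note: the model above counts each multiply-add pair as 2 flops and only
     the dominant loops of each kernel; the flops are logged via
     PetscLogGpuFlops on the assumption that the default Kokkos execution
     space is a device. */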
  PetscFunctionReturn(PETSC_SUCCESS);
}