Actual source code: lmbasis.c

  1: #include <petsc/private/petscimpl.h>
  2: #include "lmbasis.h"
  3: #include "blas_cyclic/blas_cyclic.h"

  /* Log event handles used to bracket LMBasisGEMMH(), LMBasisGEMV() and LMBasisGEMVH() in this file.
     NOTE(review): they are not registered in this file; presumably registration happens elsewhere -- confirm. */
  5: PetscLogEvent LMBASIS_GEMM, LMBASIS_GEMV, LMBASIS_GEMVH;

  7: PetscErrorCode LMBasisCreate(Vec v, PetscInt m, LMBasis *basis_p)
  8: {
  9:   PetscInt    n, N;
 10:   PetscMPIInt rank;
 11:   Mat         backing;
 12:   VecType     type;
 13:   LMBasis     basis;

 15:   PetscFunctionBegin;
 18:   PetscCheck(m >= 0, PetscObjectComm((PetscObject)v), PETSC_ERR_ARG_OUTOFRANGE, "Requested window size %" PetscInt_FMT " is not >= 0", m);
 19:   PetscCall(VecGetLocalSize(v, &n));
 20:   PetscCall(VecGetSize(v, &N));
 21:   PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)v), &rank));
 22:   PetscCall(VecGetType(v, &type));
 23:   PetscCall(MatCreateDenseFromVecType(PetscObjectComm((PetscObject)v), type, n, rank == 0 ? m : 0, N, m, n, NULL, &backing));
 24:   PetscCall(PetscNew(&basis));
 25:   *basis_p    = basis;
 26:   basis->m    = m;
 27:   basis->k    = 0;
 28:   basis->vecs = backing;
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: static PetscErrorCode LMBasisGetVec_Internal(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single, PetscBool check_idx)
 33: {
 34:   PetscFunctionBegin;
 35:   PetscAssertPointer(basis, 1);
 36:   if (check_idx) {
 38:     PetscCheck(idx < basis->k, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " >= number of inserted vecs %" PetscInt_FMT, idx, basis->k);
 39:     PetscInt earliest = PetscMax(0, basis->k - basis->m);
 40:     PetscCheck(idx >= earliest, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for index %" PetscInt_FMT " < the earliest retained index % " PetscInt_FMT, idx, earliest);
 41:   }
 42:   PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
 43:   if (mode == PETSC_MEMORY_ACCESS_READ) PetscCall(MatDenseGetColumnVecRead(basis->vecs, idx % basis->m, single));
 44:   else PetscCall(MatDenseGetColumnVecWrite(basis->vecs, idx % basis->m, single));
 45:   PetscFunctionReturn(PETSC_SUCCESS);
 46: }

 48: PETSC_INTERN PetscErrorCode LMBasisGetVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
 49: {
 50:   PetscFunctionBegin;
 51:   PetscCall(LMBasisGetVec_Internal(basis, idx, mode, single, PETSC_TRUE));
 52:   PetscFunctionReturn(PETSC_SUCCESS);
 53: }

 55: PETSC_INTERN PetscErrorCode LMBasisRestoreVec(LMBasis basis, PetscInt idx, PetscMemoryAccessMode mode, Vec *single)
 56: {
 57:   PetscFunctionBegin;
 58:   PetscAssertPointer(basis, 1);
 59:   PetscAssert(mode == PETSC_MEMORY_ACCESS_READ || mode == PETSC_MEMORY_ACCESS_WRITE, PETSC_COMM_SELF, PETSC_ERR_PLIB, "READ_WRITE access not implemented");
 60:   if (mode == PETSC_MEMORY_ACCESS_READ) {
 61:     PetscCall(MatDenseRestoreColumnVecRead(basis->vecs, idx % basis->m, single));
 62:   } else {
 63:     PetscCall(MatDenseRestoreColumnVecWrite(basis->vecs, idx % basis->m, single));
 64:   }
 65:   *single = NULL;
 66:   PetscFunctionReturn(PETSC_SUCCESS);
 67: }

 69: PETSC_INTERN PetscErrorCode LMBasisGetVecRead(LMBasis B, PetscInt i, Vec *b)
 70: {
 71:   return LMBasisGetVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
 72: }
 73: PETSC_INTERN PetscErrorCode LMBasisRestoreVecRead(LMBasis B, PetscInt i, Vec *b)
 74: {
 75:   return LMBasisRestoreVec(B, i, PETSC_MEMORY_ACCESS_READ, b);
 76: }

 78: PETSC_INTERN PetscErrorCode LMBasisGetNextVec(LMBasis basis, Vec *single)
 79: {
 80:   PetscFunctionBegin;
 81:   PetscCall(LMBasisGetVec_Internal(basis, basis->k, PETSC_MEMORY_ACCESS_WRITE, single, PETSC_FALSE));
 82:   PetscFunctionReturn(PETSC_SUCCESS);
 83: }

 85: PETSC_INTERN PetscErrorCode LMBasisRestoreNextVec(LMBasis basis, Vec *single)
 86: {
 87:   PetscFunctionBegin;
 88:   PetscAssertPointer(basis, 1);
 89:   PetscCall(LMBasisRestoreVec(basis, basis->k++, PETSC_MEMORY_ACCESS_WRITE, single));
 90:   // basis is updated, invalidate cached product
 91:   basis->cached_vec_id    = 0;
 92:   basis->cached_vec_state = 0;
 93:   PetscFunctionReturn(PETSC_SUCCESS);
 94: }

 96: PETSC_INTERN PetscErrorCode LMBasisSetNextVec(LMBasis basis, Vec single)
 97: {
 98:   Vec next;

100:   PetscFunctionBegin;
101:   PetscCall(LMBasisGetNextVec(basis, &next));
102:   PetscCall(VecCopy(single, next));
103:   PetscCall(LMBasisRestoreNextVec(basis, &next));
104:   PetscFunctionReturn(PETSC_SUCCESS);
105: }

107: PETSC_INTERN PetscErrorCode LMBasisDestroy(LMBasis *basis_p)
108: {
109:   LMBasis basis = *basis_p;

111:   PetscFunctionBegin;
112:   *basis_p = NULL;
113:   if (basis == NULL) PetscFunctionReturn(PETSC_SUCCESS);
114:   PetscCall(LMBasisReset(basis));
115:   PetscCheck(basis->work_vecs_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work vecs are still checked out at destruction");
116:   {
117:     VecLink head = basis->work_vecs_available;

119:     while (head) {
120:       VecLink next = head->next;

122:       PetscCall(VecDestroy(&head->vec));
123:       PetscCall(PetscFree(head));
124:       head = next;
125:     }
126:   }
127:   PetscCheck(basis->work_rows_in_use == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Work rows are still checked out at destruction");
128:   {
129:     VecLink head = basis->work_rows_available;

131:     while (head) {
132:       VecLink next = head->next;

134:       PetscCall(VecDestroy(&head->vec));
135:       PetscCall(PetscFree(head));
136:       head = next;
137:     }
138:   }
139:   PetscCall(MatDestroy(&basis->vecs));
140:   PetscCall(PetscFree(basis));
141:   PetscFunctionReturn(PETSC_SUCCESS);
142: }

144: PETSC_INTERN PetscErrorCode LMBasisGetWorkVec(LMBasis basis, Vec *vec_p)
145: {
146:   VecLink link;

148:   PetscFunctionBegin;
149:   if (!basis->work_vecs_available) {
150:     PetscCall(PetscNew(&basis->work_vecs_available));
151:     PetscCall(MatCreateVecs(basis->vecs, NULL, &basis->work_vecs_available->vec));
152:   }
153:   link                       = basis->work_vecs_available;
154:   basis->work_vecs_available = link->next;
155:   link->next                 = basis->work_vecs_in_use;
156:   basis->work_vecs_in_use    = link;

158:   *vec_p    = link->vec;
159:   link->vec = NULL;
160:   PetscFunctionReturn(PETSC_SUCCESS);
161: }

163: PETSC_INTERN PetscErrorCode LMBasisRestoreWorkVec(LMBasis basis, Vec *vec_p)
164: {
165:   Vec     v    = *vec_p;
166:   VecLink link = NULL;

168:   PetscFunctionBegin;
169:   *vec_p = NULL;
170:   PetscCheck(basis->work_vecs_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a vec that wasn't checked out");
171:   link                       = basis->work_vecs_in_use;
172:   basis->work_vecs_in_use    = link->next;
173:   link->next                 = basis->work_vecs_available;
174:   basis->work_vecs_available = link;

176:   PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
177:   link->vec = v;
178:   PetscFunctionReturn(PETSC_SUCCESS);
179: }

181: PETSC_INTERN PetscErrorCode LMBasisCreateRow(LMBasis basis, Vec *row_p)
182: {
183:   PetscFunctionBegin;
184:   PetscCall(MatCreateVecs(basis->vecs, row_p, NULL));
185:   PetscFunctionReturn(PETSC_SUCCESS);
186: }

188: PETSC_INTERN PetscErrorCode LMBasisGetWorkRow(LMBasis basis, Vec *row_p)
189: {
190:   VecLink link;

192:   PetscFunctionBegin;
193:   if (!basis->work_rows_available) {
194:     PetscCall(PetscNew(&basis->work_rows_available));
195:     PetscCall(MatCreateVecs(basis->vecs, &basis->work_rows_available->vec, NULL));
196:   }
197:   link                       = basis->work_rows_available;
198:   basis->work_rows_available = link->next;
199:   link->next                 = basis->work_rows_in_use;
200:   basis->work_rows_in_use    = link;

202:   PetscCall(VecZeroEntries(link->vec));
203:   *row_p    = link->vec;
204:   link->vec = NULL;
205:   PetscFunctionReturn(PETSC_SUCCESS);
206: }

208: PETSC_INTERN PetscErrorCode LMBasisRestoreWorkRow(LMBasis basis, Vec *row_p)
209: {
210:   Vec     v    = *row_p;
211:   VecLink link = NULL;

213:   PetscFunctionBegin;
214:   *row_p = NULL;
215:   PetscCheck(basis->work_rows_in_use, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_ARG_WRONGSTATE, "Trying to check in a row that wasn't checked out");
216:   link                       = basis->work_rows_in_use;
217:   basis->work_rows_in_use    = link->next;
218:   link->next                 = basis->work_rows_available;
219:   basis->work_rows_available = link;

221:   PetscAssert(link->vec == NULL, PetscObjectComm((PetscObject)basis->vecs), PETSC_ERR_PLIB, "Link not ready to return vector");
222:   link->vec = v;
223:   PetscFunctionReturn(PETSC_SUCCESS);
224: }

226: PETSC_INTERN PetscErrorCode LMBasisCopy(LMBasis basis_a, LMBasis basis_b)
227: {
228:   PetscFunctionBegin;
229:   PetscCheck(basis_a->m == basis_b->m, PetscObjectComm((PetscObject)basis_a), PETSC_ERR_ARG_SIZ, "Copy target has different number of vecs, %" PetscInt_FMT " != %" PetscInt_FMT, basis_b->m, basis_a->m);
230:   basis_b->k = basis_a->k;
231:   PetscCall(MatCopy(basis_a->vecs, basis_b->vecs, SAME_NONZERO_PATTERN));
232:   basis_b->cached_vec_id    = basis_a->cached_vec_id;
233:   basis_b->cached_vec_state = basis_a->cached_vec_state;
234:   if (basis_a->cached_product) {
235:     if (!basis_b->cached_product) PetscCall(VecDuplicate(basis_a->cached_product, &basis_b->cached_product));
236:     PetscCall(VecCopy(basis_a->cached_product, basis_b->cached_product));
237:   }
238:   PetscFunctionReturn(PETSC_SUCCESS);
239: }

241: PETSC_INTERN PetscErrorCode LMBasisGetRange(LMBasis basis, PetscInt *oldest, PetscInt *next)
242: {
243:   PetscFunctionBegin;
244:   *next   = basis->k;
245:   *oldest = PetscMax(0, basis->k - basis->m);
246:   PetscFunctionReturn(PETSC_SUCCESS);
247: }

249: static PetscErrorCode LMBasisMultCheck(LMBasis A, PetscInt oldest, PetscInt next)
250: {
251:   PetscInt basis_oldest, basis_next;

253:   PetscFunctionBegin;
254:   PetscCall(LMBasisGetRange(A, &basis_oldest, &basis_next));
255:   PetscCheck(oldest >= basis_oldest && next <= basis_next, PetscObjectComm((PetscObject)A->vecs), PETSC_ERR_ARG_OUTOFRANGE, "Asked for vec that hasn't been computed or is no longer stored");
256:   PetscFunctionReturn(PETSC_SUCCESS);
257: }

259: PETSC_INTERN PetscErrorCode LMBasisGEMV(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
260: {
261:   PetscInt lim        = next - oldest;
262:   PetscInt next_idx   = ((next - 1) % A->m) + 1;
263:   PetscInt oldest_idx = oldest % A->m;
264:   Vec      x_work     = NULL;
265:   Vec      x_         = x;

267:   PetscFunctionBegin;
268:   if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
269:   PetscCall(PetscLogEventBegin(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
270:   PetscCall(LMBasisMultCheck(A, oldest, next));
271:   if (alpha != 1.0) {
272:     PetscCall(LMBasisGetWorkRow(A, &x_work));
273:     PetscCall(VecAXPBYCyclic(oldest, next, alpha, x, 0.0, x_work));
274:     x_ = x_work;
275:   }
276:   if (beta != 1.0 && beta != 0.0) PetscCall(VecScale(y, beta));
277:   if (lim == A->m) {
278:     // all vectors are used
279:     if (beta == 0.0) PetscCall(MatMult(A->vecs, x_, y));
280:     else PetscCall(MatMultAdd(A->vecs, x_, y, y));
281:   } else if (oldest_idx < next_idx) {
282:     // contiguous vectors are used
283:     if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, oldest_idx, next_idx));
284:     else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, next_idx));
285:   } else {
286:     if (beta == 0.0) PetscCall(MatMultColumnRange(A->vecs, x_, y, 0, next_idx));
287:     else PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, 0, next_idx));
288:     PetscCall(MatMultAddColumnRange(A->vecs, x_, y, y, oldest_idx, A->m));
289:   }
290:   if (alpha != 1.0) PetscCall(LMBasisRestoreWorkRow(A, &x_work));
291:   PetscCall(PetscLogEventEnd(LMBASIS_GEMV, NULL, NULL, NULL, NULL));
292:   PetscFunctionReturn(PETSC_SUCCESS);
293: }

295: PETSC_INTERN PetscErrorCode LMBasisGEMVH(LMBasis A, PetscInt oldest, PetscInt next, PetscScalar alpha, Vec x, PetscScalar beta, Vec y)
296: {
297:   PetscInt lim        = next - oldest;
298:   PetscInt next_idx   = ((next - 1) % A->m) + 1;
299:   PetscInt oldest_idx = oldest % A->m;
300:   Vec      y_         = y;

302:   PetscFunctionBegin;
303:   if (lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
304:   PetscCall(LMBasisMultCheck(A, oldest, next));
305:   if (A->cached_product && A->cached_vec_id != 0 && A->cached_vec_state != 0) {
306:     // see if x is the cached input vector
307:     PetscObjectId    x_id;
308:     PetscObjectState x_state;

310:     PetscCall(PetscObjectGetId((PetscObject)x, &x_id));
311:     PetscCall(PetscObjectStateGet((PetscObject)x, &x_state));
312:     if (x_id == A->cached_vec_id && x_state == A->cached_vec_state) {
313:       PetscCall(VecAXPBYCyclic(oldest, next, alpha, A->cached_product, beta, y));
314:       PetscFunctionReturn(PETSC_SUCCESS);
315:     }
316:   }
317:   PetscCall(PetscLogEventBegin(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
318:   if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) PetscCall(LMBasisGetWorkRow(A, &y_));
319:   if (lim == A->m) {
320:     // all vectors are used
321:     if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAdd(A->vecs, x, y_, y_));
322:     else PetscCall(MatMultHermitianTranspose(A->vecs, x, y_));
323:   } else if (oldest_idx < next_idx) {
324:     // contiguous vectors are used
325:     if (alpha == 1.0 && beta == 1.0) PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, next_idx));
326:     else PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, next_idx));
327:   } else {
328:     if (alpha == 1.0 && beta == 1.0) {
329:       PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, 0, next_idx));
330:       PetscCall(MatMultHermitianTransposeAddColumnRange(A->vecs, x, y_, y_, oldest_idx, A->m));
331:     } else {
332:       PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, 0, next_idx));
333:       PetscCall(MatMultHermitianTransposeColumnRange(A->vecs, x, y_, oldest_idx, A->m));
334:     }
335:   }
336:   if (alpha != 1.0 || (beta != 1.0 && beta != 0.0)) {
337:     PetscCall(VecAXPBYCyclic(oldest, next, alpha, y_, beta, y));
338:     PetscCall(LMBasisRestoreWorkRow(A, &y_));
339:   }
340:   PetscCall(PetscLogEventEnd(LMBASIS_GEMVH, NULL, NULL, NULL, NULL));
341:   PetscFunctionReturn(PETSC_SUCCESS);
342: }

344: static PetscErrorCode LMBasisGEMMH_Internal(Mat A, Mat B, PetscScalar alpha, PetscScalar beta, Mat G)
345: {
346:   PetscFunctionBegin;
347:   if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
348:   if (beta != 0.0) {
349:     Mat G_alloc;

351:     if (beta != 1.0) PetscCall(MatScale(G, beta));
352:     PetscCall(MatTransposeMatMult(A, B, MAT_INITIAL_MATRIX, PETSC_DECIDE, &G_alloc));
353:     PetscCall(MatAXPY(G, alpha, G_alloc, DIFFERENT_NONZERO_PATTERN));
354:     PetscCall(MatDestroy(&G_alloc));
355:   } else {
356:     PetscCall(MatProductClear(G));
357:     PetscCall(MatProductCreateWithMat(A, B, NULL, G));
358:     PetscCall(MatProductSetType(G, MATPRODUCT_AtB));
359:     PetscCall(MatProductSetFromOptions(G));
360:     PetscCall(MatProductSymbolic(G));
361:     PetscCall(MatProductNumeric(G));
362:     if (alpha != 1.0) PetscCall(MatScale(G, alpha));
363:   }
364:   if (PetscDefined(USE_COMPLEX)) PetscCall(MatConjugate(A));
365:   PetscFunctionReturn(PETSC_SUCCESS);
366: }

/*
  LMBasisGEMMH - Accumulate G <- alpha * A^H * B + beta * G on the block of G whose rows correspond
  to the vectors of A with logical indices in [a_oldest, a_next) and whose columns correspond to
  the vectors of B with logical indices in [b_oldest, b_next).

  Input Parameters:
+ A        - basis providing the rows of the product
. a_oldest - first logical index of A used
. a_next   - one past the last logical index of A used
. B        - basis providing the columns of the product
. b_oldest - first logical index of B used
. b_next   - one past the last logical index of B used
. alpha    - scale applied to the product
- beta     - scale applied to the incoming entries of G

  Output Parameter:
. G - dense matrix receiving the result; logical indices are mapped cyclically (i -> i % m)

  Notes:
  If either range is empty the routine returns immediately and G is left untouched.

  Single-vector ranges are dispatched to LMBasisGEMVH(). Otherwise each cyclic range is split into
  one or two contiguous column intervals and LMBasisGEMMH_Internal() is applied to every pair of
  sub-blocks.
*/
368: PETSC_INTERN PetscErrorCode LMBasisGEMMH(LMBasis A, PetscInt a_oldest, PetscInt a_next, LMBasis B, PetscInt b_oldest, PetscInt b_next, PetscScalar alpha, PetscScalar beta, Mat G)
369: {
370:   PetscInt a_lim = a_next - a_oldest;
371:   PetscInt b_lim = b_next - b_oldest;

373:   PetscFunctionBegin;
374:   if (a_lim <= 0 || b_lim <= 0) PetscFunctionReturn(PETSC_SUCCESS);
375:   PetscCall(PetscLogEventBegin(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
376:   PetscCall(LMBasisMultCheck(A, a_oldest, a_next));
377:   PetscCall(LMBasisMultCheck(B, b_oldest, b_next));
378:   if (b_lim == 1) {
       // B contributes a single vector: update one column of G with a matrix-vector product
379:     Vec b;
380:     Vec g;

382:     PetscCall(LMBasisGetVecRead(B, b_oldest, &b));
383:     PetscCall(MatDenseGetColumnVec(G, b_oldest % B->m, &g));
384:     PetscCall(LMBasisGEMVH(A, a_oldest, a_next, alpha, b, beta, g));
385:     PetscCall(MatDenseRestoreColumnVec(G, b_oldest % B->m, &g));
386:     PetscCall(LMBasisRestoreVecRead(B, b_oldest, &b));
387:   } else if (a_lim == 1) {
       // A contributes a single vector: compute B^H a into a work row, conjugate it in complex builds,
       // then combine it into G through MatSeqDenseRowAXPBYCyclic() with row argument a_oldest
388:     Vec a;
389:     Vec g;

391:     PetscCall(LMBasisGetVecRead(A, a_oldest, &a));
392:     PetscCall(LMBasisGetWorkRow(B, &g));
393:     PetscCall(LMBasisGEMVH(B, b_oldest, b_next, 1.0, a, 0.0, g));
394:     if (PetscDefined(USE_COMPLEX)) PetscCall(VecConjugate(g));
395:     PetscCall(MatSeqDenseRowAXPBYCyclic(b_oldest, b_next, alpha, g, beta, G, a_oldest));
396:     PetscCall(LMBasisRestoreWorkRow(B, &g));
397:     PetscCall(LMBasisRestoreVecRead(A, a_oldest, &a));
398:   } else {
       // general case: describe each cyclic range by contiguous column intervals of the backing matrix;
       // the default below is the wrapped layout [0, next_idx) followed by [oldest_idx, m)
399:     PetscInt a_next_idx        = ((a_next - 1) % A->m) + 1;
400:     PetscInt a_oldest_idx      = a_oldest % A->m;
401:     PetscInt b_next_idx        = ((b_next - 1) % B->m) + 1;
402:     PetscInt b_oldest_idx      = b_oldest % B->m;
403:     PetscInt a_intervals[2][2] = {
404:       {0,            a_next_idx},
405:       {a_oldest_idx, A->m      }
406:     };
407:     PetscInt b_intervals[2][2] = {
408:       {0,            b_next_idx},
409:       {b_oldest_idx, B->m      }
410:     };
411:     PetscInt a_num_intervals = 2;
412:     PetscInt b_num_intervals = 2;

         // a full or non-wrapping range of A needs only one interval
414:     if (a_lim == A->m || a_oldest_idx < a_next_idx) {
415:       a_num_intervals = 1;
416:       if (a_lim == A->m) {
417:         a_intervals[0][0] = 0;
418:         a_intervals[0][1] = A->m;
419:       } else {
420:         a_intervals[0][0] = a_oldest_idx;
421:         a_intervals[0][1] = a_next_idx;
422:       }
423:     }
         // same reduction for B
424:     if (b_lim == B->m || b_oldest_idx < b_next_idx) {
425:       b_num_intervals = 1;
426:       if (b_lim == B->m) {
427:         b_intervals[0][0] = 0;
428:         b_intervals[0][1] = B->m;
429:       } else {
430:         b_intervals[0][0] = b_oldest_idx;
431:         b_intervals[0][1] = b_next_idx;
432:       }
433:     }
434:     for (PetscInt i = 0; i < a_num_intervals; i++) {
           // sub_A tracks what is checked out of A->vecs (or A->vecs itself when the interval is full);
           // sub_A_ is the matrix actually multiplied, which may become a private copy below
435:       Mat sub_A = A->vecs;
436:       Mat sub_A_;

438:       if (a_intervals[i][0] != 0 || a_intervals[i][1] != A->m) PetscCall(MatDenseGetSubMatrix(A->vecs, PETSC_DECIDE, PETSC_DECIDE, a_intervals[i][0], a_intervals[i][1], &sub_A));
439:       sub_A_ = sub_A;

441:       for (PetscInt j = 0; j < b_num_intervals; j++) {
442:         Mat sub_B = B->vecs;
443:         Mat sub_G = G;

445:         if (b_intervals[j][0] != 0 || b_intervals[j][1] != B->m) {
446:           if (sub_A_ == sub_A && sub_A != A->vecs && B->vecs == A->vecs) {
447:             /* We're hampered by the fact that you can only get one submatrix from a MatDense at a time.  This case
448:              * should not happen often, copying here is acceptable */
449:             PetscCall(MatDuplicate(sub_A, MAT_COPY_VALUES, &sub_A_));
450:             PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
451:             sub_A = A->vecs;
452:           }
453:           PetscCall(MatDenseGetSubMatrix(B->vecs, PETSC_DECIDE, PETSC_DECIDE, b_intervals[j][0], b_intervals[j][1], &sub_B));
454:         }

             // restrict G to the matching block whenever either operand is not the full backing matrix
456:         if (sub_A_ != A->vecs || sub_B != B->vecs) PetscCall(MatDenseGetSubMatrix(G, a_intervals[i][0], a_intervals[i][1], b_intervals[j][0], b_intervals[j][1], &sub_G));

458:         PetscCall(LMBasisGEMMH_Internal(sub_A_, sub_B, alpha, beta, sub_G));

460:         if (sub_G != G) PetscCall(MatDenseRestoreSubMatrix(G, &sub_G));
461:         if (sub_B != B->vecs) PetscCall(MatDenseRestoreSubMatrix(B->vecs, &sub_B));
462:       }

           // release the private copy (if one was made) and any submatrix still checked out of A->vecs
464:       if (sub_A_ != sub_A) PetscCall(MatDestroy(&sub_A_));
465:       if (sub_A != A->vecs) PetscCall(MatDenseRestoreSubMatrix(A->vecs, &sub_A));
466:     }
467:   }
468:   PetscCall(PetscLogEventEnd(LMBASIS_GEMM, NULL, NULL, NULL, NULL));
469:   PetscFunctionReturn(PETSC_SUCCESS);
470: }

472: PETSC_INTERN PetscErrorCode LMBasisReset(LMBasis basis)
473: {
474:   PetscFunctionBegin;
475:   if (basis) {
476:     basis->k = 0;
477:     PetscCall(VecDestroy(&basis->cached_product));
478:     basis->cached_vec_id    = 0;
479:     basis->cached_vec_state = 0;
480:     basis->operator_id      = 0;
481:     basis->operator_state   = 0;
482:   }
483:   PetscFunctionReturn(PETSC_SUCCESS);
484: }

486: PETSC_INTERN PetscErrorCode LMBasisSetCachedProduct(LMBasis A, Vec x, Vec Ax)
487: {
488:   PetscFunctionBegin;
489:   if (x == NULL) {
490:     A->cached_vec_id    = 0;
491:     A->cached_vec_state = 0;
492:   } else {
493:     PetscCall(PetscObjectGetId((PetscObject)x, &A->cached_vec_id));
494:     PetscCall(PetscObjectStateGet((PetscObject)x, &A->cached_vec_state));
495:   }
496:   PetscCall(PetscObjectReference((PetscObject)Ax));
497:   PetscCall(VecDestroy(&A->cached_product));
498:   A->cached_product = Ax;
499:   PetscFunctionReturn(PETSC_SUCCESS);
500: }