Actual source code: taotermtest2.c

  1: #include <petsctao.h>

  3: static char help[] = "Using TaoTermShell with mapping matrices that are not diagonal.\n";

  5: typedef struct {
  6:   Vec pdiff_work; /* Work vector for x - params */
  7: } HalfL2Ctx;

  9: typedef struct {
 10:   Mat A;    /* Mapping matrix A */
 11:   Vec p;    /* Target vector p */
 12:   Vec Ax;   /* Work vector for A*x */
 13:   Vec Ax_p; /* Work vector for A*x - p */
 14: } CallbackCtx;

 16: static PetscErrorCode FormFunctionGradient(TaoTerm, Vec, Vec, PetscReal *, Vec);
 17: static PetscErrorCode FormHessian(TaoTerm, Vec, Vec, Mat, Mat);
 18: static PetscErrorCode CtxDestroy(PetscCtxRt ctx);

 20: /* Callback functions for traditional TAO interface */
 21: static PetscErrorCode FormObjectiveGradient_Callback(Tao, Vec, PetscReal *, Vec, void *);
 22: static PetscErrorCode FormHessian_Callback(Tao, Vec, Mat, Mat, void *);

 24: int main(int argc, char **argv)
 25: {
 26:   TaoTerm      objective;
 27:   Tao          tao, tao2;
 28:   PetscMPIInt  size;
 29:   HalfL2Ctx   *ctx;
 30:   MPI_Comm     comm;
 31:   PetscInt     n = 10, m = 10;
 32:   Mat          A;
 33:   Vec          target;
 34:   CallbackCtx *cb_ctx;
 35:   Vec          x_term, x_callback, x2, diff;
 36:   Mat          H2;
 37:   PetscReal    norm_diff, diag_val = 1.1;
 38:   PetscBool    opt, is_diag, is_cdiag, is_aij, is_dense, fd_notpossible;
 39:   const char  *mtype         = MATAIJ;
 40:   char         typeName[256] = "";

 42:   PetscFunctionBeginUser;
 43:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
 44:   comm = PETSC_COMM_WORLD;
 45:   PetscCallMPI(MPI_Comm_size(comm, &size));
 46:   PetscCheck(size == 1, comm, PETSC_ERR_WRONG_MPI_SIZE, "Incorrect number of processors");

 48:   fd_notpossible = PETSC_FALSE;

 50:   PetscOptionsBegin(comm, "", help, "none");
 51:   PetscCall(PetscOptionsBool("-fd_notpossible", "Set TaoTermShell ComputeHessianFDPossible as false", "", fd_notpossible, &fd_notpossible, NULL));
 52:   PetscCall(PetscOptionsInt("-n", "Problem size", "", n, &n, NULL));
 53:   PetscCall(PetscOptionsInt("-m", "Mapping matrix row size", "", m, &m, NULL));
 54:   PetscCall(PetscOptionsReal("-diag_val", "Value of constant diagonal matrix", NULL, diag_val, &diag_val, NULL));
 55:   PetscCall(PetscOptionsFList("-mapping_mtype", "Mapping matrix type", "", MatList, mtype, typeName, 256, &opt));
 56:   PetscOptionsEnd();

 58:   PetscCall(PetscNew(&ctx));

 60:   /* Initialize typeName to default if option was not set */
 61:   if (!opt) PetscCall(PetscStrcpy(typeName, mtype));

 63:   PetscCall(PetscStrcmp(typeName, MATDIAGONAL, &is_diag));
 64:   PetscCall(PetscStrcmp(typeName, MATCONSTANTDIAGONAL, &is_cdiag));
 65:   PetscCall(PetscStrcmp(typeName, MATAIJ, &is_aij));
 66:   PetscCall(PetscStrcmp(typeName, MATDENSE, &is_dense));
 67:   /* Create mapping matrix A: m x n (maps from solution space to term space) */
 68:   if (is_diag) {
 69:     /* Create a diagonal matrix */
 70:     Vec      diag_vec;
 71:     PetscInt diag_size;

 73:     PetscCheck(m == n, comm, PETSC_ERR_ARG_INCOMP, "For diagonal matrix, m and n must be equal (got m=%" PetscInt_FMT ", n=%" PetscInt_FMT ")", m, n);
 74:     diag_size = m;
 75:     PetscCall(VecCreate(comm, &diag_vec));
 76:     PetscCall(VecSetSizes(diag_vec, PETSC_DECIDE, diag_size));
 77:     PetscCall(VecSetFromOptions(diag_vec));
 78:     PetscCall(VecSetRandom(diag_vec, NULL));
 79:     PetscCall(MatCreateDiagonal(diag_vec, &A));
 80:     PetscCall(VecDestroy(&diag_vec));
 81:   } else if (is_cdiag) {
 82:     /* Create a constant diagonal matrix */
 83:     PetscCheck(m == n, comm, PETSC_ERR_ARG_INCOMP, "For constant diagonal matrix, m and n must be equal (got m=%" PetscInt_FMT ", n=%" PetscInt_FMT ")", m, n);
 84:     PetscCall(MatCreateConstantDiagonal(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, diag_val, &A));
 85:   } else if (is_dense) {
 86:     /* Create a dense matrix */
 87:     PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
 88:     PetscCall(MatSetFromOptions(A));
 89:     PetscCall(MatSetRandom(A, NULL));
 90:     PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
 91:     PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
 92:   } else {
 93:     /* Create an AIJ matrix (default) */
 94:     PetscCall(MatCreateSeqAIJ(comm, m, n, PETSC_DEFAULT, NULL, &A));
 95:     PetscCall(MatSetFromOptions(A));
 96:     PetscCall(MatSetRandom(A, NULL));
 97:     PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
 98:     PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
 99:   }

101:   /* Create shell term that computes f(x) = 0.5 ||x||_2^2 */
102:   PetscCall(TaoTermCreateShell(comm, ctx, CtxDestroy, &objective));

104:   /* Set solution and parameter sizes to match the mapped space (m) */
105:   PetscCall(TaoTermSetSolutionSizes(objective, PETSC_DECIDE, m, 1));
106:   PetscCall(TaoTermSetParametersSizes(objective, PETSC_DECIDE, m, 1));

108:   PetscCall(TaoTermShellSetObjectiveAndGradient(objective, FormFunctionGradient));
109:   PetscCall(TaoTermShellSetCreateHessianMatrices(objective, TaoTermCreateHessianMatricesDefault));
110:   PetscCall(TaoTermSetCreateHessianMode(objective, PETSC_TRUE /* H == Hpre */, MATAIJ, NULL));
111:   PetscCall(TaoTermShellSetHessian(objective, FormHessian));
112:   PetscCall(TaoTermSetFromOptions(objective));
113:   if (fd_notpossible) PetscCall(TaoTermShellSetIsComputeHessianFDPossible(objective, PETSC_BOOL3_FALSE));

115:   PetscCall(TaoTermSetUp(objective));

117:   /* Create target vector for least squares problem (parameters) */
118:   PetscCall(TaoTermCreateParametersVec(objective, &target));
119:   PetscCall(VecSetRandom(target, NULL));

121:   PetscCall(TaoCreate(comm, &tao));
122:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)tao, "shell_"));
123:   PetscCall(TaoSetType(tao, TAOLMVM));

125:   /* Add term with mapping matrix A: f(Ax; p) = 0.5 ||Ax - p||_2^2 */
126:   PetscCall(TaoAddTerm(tao, NULL, 1.0, objective, target, A));

128:   PetscCall(TaoSetFromOptions(tao));
129:   PetscCall(TaoSolve(tao));

131:   /* Allocate callback context */
132:   PetscCall(PetscNew(&cb_ctx));
133:   cb_ctx->A = A;
134:   cb_ctx->p = target;

136:   /* Create work vectors */
137:   PetscCall(MatCreateVecs(A, NULL, &cb_ctx->Ax));
138:   PetscCall(VecDuplicate(target, &cb_ctx->Ax_p));

140:   PetscCall(MatCreateVecs(A, &x2, NULL));
141:   PetscCall(VecZeroEntries(x2));

143:   /* Create Hessian matrix A^T * A */
144:   if (is_diag) {
145:     Vec A_diag, H2_diag;

147:     PetscCall(MatCreateVecs(A, &A_diag, NULL));
148:     PetscCall(MatGetDiagonal(A, A_diag));
149:     PetscCall(VecDuplicate(A_diag, &H2_diag));
150:     PetscCall(VecPointwiseMult(H2_diag, A_diag, A_diag));
151:     PetscCall(MatCreateDiagonal(H2_diag, &H2));
152:     PetscCall(VecDestroy(&A_diag));
153:     PetscCall(VecDestroy(&H2_diag));
154:   } else if (is_cdiag) {
155:     PetscCall(MatCreateConstantDiagonal(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, diag_val * diag_val, &H2));
156:   } else {
157:     Mat       Htest, Hpretest;
158:     PetscBool is_h_dense;

160:     PetscCall(MatTransposeMatMult(A, A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &H2));
161:     PetscCall(MatAssemblyBegin(H2, MAT_FINAL_ASSEMBLY));
162:     PetscCall(MatAssemblyEnd(H2, MAT_FINAL_ASSEMBLY));

164:     PetscCall(TaoGetHessianMatrices(tao, &Htest, &Hpretest));
165:     PetscCall(PetscObjectBaseTypeCompare((PetscObject)Htest, MATSEQDENSE, &is_h_dense));
166:     if (is_h_dense) PetscCall(MatConvert(H2, MATDENSE, MAT_INPLACE_MATRIX, &H2));
167:   }
168:   /* Create second TAO solver */
169:   PetscCall(TaoCreate(comm, &tao2));
170:   PetscCall(PetscObjectSetOptionsPrefix((PetscObject)tao2, "regular_"));
171:   PetscCall(TaoSetType(tao2, TAOLMVM));
172:   PetscCall(TaoSetSolution(tao2, x2));
173:   PetscCall(TaoSetObjectiveAndGradient(tao2, NULL, FormObjectiveGradient_Callback, cb_ctx));
174:   PetscCall(TaoSetHessian(tao2, H2, H2, FormHessian_Callback, cb_ctx));
175:   PetscCall(TaoSetFromOptions(tao2));
176:   PetscCall(TaoSolve(tao2));

178:   /* Compare solutions */
179:   PetscCall(TaoGetSolution(tao, &x_term));
180:   PetscCall(TaoGetSolution(tao2, &x_callback));
181:   PetscCall(VecDuplicate(x_term, &diff));
182:   PetscCall(VecCopy(x_term, diff));
183:   PetscCall(VecAXPY(diff, -1.0, x_callback));
184:   PetscCall(VecNorm(diff, NORM_2, &norm_diff));
185:   if (norm_diff <= 1.e-12) PetscCall(PetscPrintf(comm, "Relative difference < 1e-12\n"));
186:   else PetscCall(PetscPrintf(comm, "Relative difference > 1e-12: %6.10e\n", (double)norm_diff));
187:   PetscCall(VecDestroy(&x2));
188:   PetscCall(VecDestroy(&diff));
189:   PetscCall(VecDestroy(&cb_ctx->Ax));
190:   PetscCall(VecDestroy(&cb_ctx->Ax_p));
191:   PetscCall(PetscFree(cb_ctx));
192:   PetscCall(VecDestroy(&target));
193:   PetscCall(MatDestroy(&A));
194:   PetscCall(MatDestroy(&H2));
195:   PetscCall(TaoDestroy(&tao2));
196:   PetscCall(TaoDestroy(&tao));
197:   PetscCall(TaoTermDestroy(&objective));
198:   PetscCall(PetscFinalize());
199:   return 0;
200: }

202: /*
203:   FormFunctionGradient - Evaluates the function, f(X), and gradient, G(X).

205:   Input Parameters:
206: + term      - the `TaoTerm` for the objective function
207: . x         - input vector
208: - params    - optional vector of parameters

210:   Output Parameters:
211: + f - function value
212: - G - vector containing the newly evaluated gradient

214:   Note:
215:   Computes f = 0.5 * ||x - params||_2^2 and g = x - params, matching TAOTERMHALFL2SQUARED.
216: */
217: static PetscErrorCode FormFunctionGradient(TaoTerm term, Vec x, Vec params, PetscReal *f, Vec G)
218: {
219:   HalfL2Ctx  *ctx;
220:   PetscScalar v;

222:   PetscFunctionBeginUser;
223:   PetscCall(TaoTermShellGetContext(term, &ctx));
224:   if (params) {
225:     PetscCall(VecWAXPY(G, -1.0, params, x));
226:     PetscCall(VecDot(G, G, &v));
227:   } else {
228:     PetscCall(VecCopy(x, G));
229:     PetscCall(VecDot(G, G, &v));
230:   }
231:   *f = 0.5 * PetscRealPart(v);
232:   PetscFunctionReturn(PETSC_SUCCESS);
233: }

235: /*
236:   FormHessian - Evaluates Hessian matrix.

238:   Input Parameters:
239: + term      - the `TaoTerm` for the objective function
240: . x         - input vector
241: . params    - optional vector of parameters
242: - Hpre      - optional preconditioner matrix

244:   Output Parameters:
245: + H    - Hessian matrix
246: - Hpre - Preconditioning matrix

248:   Note:
249:   Computes H = I (identity matrix), matching TAOTERMHALFL2SQUARED.
250: */
251: static PetscErrorCode FormHessian(TaoTerm term, Vec x, Vec params, Mat H, Mat Hpre)
252: {
253:   PetscFunctionBeginUser;
254:   if (H) {
255:     PetscCall(MatZeroEntries(H));
256:     PetscCall(MatAssemblyBegin(H, MAT_FINAL_ASSEMBLY));
257:     PetscCall(MatAssemblyEnd(H, MAT_FINAL_ASSEMBLY));
258:     PetscCall(MatShift(H, 1.0));
259:   }
260:   if (Hpre && Hpre != H) {
261:     PetscCall(MatZeroEntries(Hpre));
262:     PetscCall(MatAssemblyBegin(Hpre, MAT_FINAL_ASSEMBLY));
263:     PetscCall(MatAssemblyEnd(Hpre, MAT_FINAL_ASSEMBLY));
264:     PetscCall(MatShift(Hpre, 1.0));
265:   }
266:   PetscFunctionReturn(PETSC_SUCCESS);
267: }

269: static PetscErrorCode CtxDestroy(PetscCtxRt ctx_ptr)
270: {
271:   HalfL2Ctx *ctx = *(HalfL2Ctx **)ctx_ptr;

273:   PetscFunctionBeginUser;
274:   if (ctx) {
275:     PetscCall(VecDestroy(&ctx->pdiff_work));
276:     PetscCall(PetscFree(ctx));
277:     *(void **)ctx_ptr = NULL;
278:   }
279:   PetscFunctionReturn(PETSC_SUCCESS);
280: }

282: /*
283:   FormObjectiveGradient_Callback - Evaluates the objective and gradient for traditional TAO callback interface.

285:   Input Parameters:
286: + tao  - the Tao solver context
287: . x    - input vector (size n)
288: - ctx  - user context containing A and p

290:   Output Parameters:
291: + f - function value: 0.5 * ||Ax - p||_2^2
292: - g - gradient vector: A^T (Ax - p)

294:   Note:
295:   Computes f = 0.5 * ||Ax - p||_2^2 and g = A^T (Ax - p)
296: */
297: static PetscErrorCode FormObjectiveGradient_Callback(Tao tao, Vec x, PetscReal *f, Vec g, void *ctx)
298: {
299:   CallbackCtx *cb_ctx = (CallbackCtx *)ctx;
300:   PetscScalar  v;

302:   PetscFunctionBeginUser;
303:   /* Compute Ax */
304:   PetscCall(MatMult(cb_ctx->A, x, cb_ctx->Ax));
305:   /* Compute Ax - p */
306:   PetscCall(VecCopy(cb_ctx->Ax, cb_ctx->Ax_p));
307:   PetscCall(VecAXPY(cb_ctx->Ax_p, -1.0, cb_ctx->p));
308:   /* Compute objective: 0.5 * ||Ax - p||_2^2 */
309:   PetscCall(VecDot(cb_ctx->Ax_p, cb_ctx->Ax_p, &v));
310:   *f = 0.5 * PetscRealPart(v);
311:   /* Compute gradient: A^T (Ax - p) */
312:   PetscCall(MatMultTranspose(cb_ctx->A, cb_ctx->Ax_p, g));
313:   PetscFunctionReturn(PETSC_SUCCESS);
314: }

316: /*
317:   FormHessian_Callback - Evaluates the Hessian matrix for traditional TAO callback interface.

319:   Input Parameters:
320: + tao  - the Tao solver context
321: . x    - input vector
322: . H    - Hessian matrix (should be pre-allocated as A^T * A)
323: . Hpre - preconditioner matrix
324: - ctx  - user context containing A and p

326:   Output Parameters:
327: + H    - Hessian matrix (A^T * A)
328: - Hpre - Preconditioning matrix

330:   Note:
331:   The Hessian for 0.5 * ||Ax - p||_2^2 is constant: H = A^T * A
332: */
333: static PetscErrorCode FormHessian_Callback(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
334: {
335:   PetscFunctionBeginUser;
336:   /* Hessian is constant: A^T * A, which should already be set in H */
337:   if (Hpre && Hpre != H) PetscCall(MatCopy(H, Hpre, SAME_NONZERO_PATTERN));
338:   PetscFunctionReturn(PETSC_SUCCESS);
339: }

/* Note: For dense variations, the relative error may be greater than 1.e-12.
   That is acceptable, because it results from the KSP and PC using AIJ
   matrices instead of dense ones. */

345: /*TEST

347:    build:
348:      requires: !complex !single !quad !defined(PETSC_USE_64BIT_INDICES) !__float128

350:    test:
351:      suffix: diag_diag
352:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
353:      args: -tao_term_hessian_mat_type diagonal -mapping_mtype diagonal

355:    test:
356:      suffix: diag_cdiag
357:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
358:      args: -tao_term_hessian_mat_type diagonal -mapping_mtype constantdiagonal

360:    test:
361:      suffix: diag_dense
362:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
363:      args: -tao_term_hessian_mat_type diagonal -mapping_mtype dense

365:    test:
366:      suffix: diag_dense_nsq
367:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
368:      args: -tao_term_hessian_mat_type diagonal -mapping_mtype dense -m 15

370:    test:
371:      suffix: diag_aij
372:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
373:      args: -tao_term_hessian_mat_type diagonal -mapping_mtype aij

375:    test:
376:      suffix: cdiag_diag
377:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
378:      args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype diagonal

380:    test:
381:      suffix: cdiag_cdiag
382:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
383:      args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype constantdiagonal

385:    test:
386:      suffix: cdiag_dense
387:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
388:      args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype dense

390:    test:
391:      suffix: cdiag_dense_nsq
392:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
393:      args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype dense -m 15

395:    test:
396:      suffix: cdiag_aij
397:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
398:      args: -tao_term_hessian_mat_type constantdiagonal -mapping_mtype aij

400:    test:
401:      suffix: dense_diag
402:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
403:      args: -tao_term_hessian_mat_type dense -mapping_mtype diagonal

405:    test:
406:      suffix: dense_cdiag
407:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
408:      args: -tao_term_hessian_mat_type dense -mapping_mtype constantdiagonal

410:    test:
411:      suffix: dense_dense
412:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
413:      args: -tao_term_hessian_mat_type dense -mapping_mtype dense -fd_notpossible {{0 1}}

415:    test:
416:      suffix: dense_dense_nsq
417:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
418:      args: -tao_term_hessian_mat_type dense -mapping_mtype dense -m 15

420:    test:
421:      suffix: dense_aij
422:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
423:      args: -tao_term_hessian_mat_type dense -mapping_mtype aij

425:    test:
426:      suffix: aij_diag
427:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
428:      args: -tao_term_hessian_mat_type aij -mapping_mtype diagonal

430:    test:
431:      suffix: aij_cdiag
432:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
433:      args: -tao_term_hessian_mat_type aij -mapping_mtype constantdiagonal

435:    test:
436:      suffix: aij_dense
437:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
438:      args: -tao_term_hessian_mat_type aij -mapping_mtype dense

440:    test:
441:      suffix: aij_dense_nsq
442:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
443:      args: -tao_term_hessian_mat_type aij -mapping_mtype dense -m 15

445:    test:
446:      suffix: aij_aij
447:      args: -shell_tao_type nls -shell_tao_view ::ascii_info_detail -regular_tao_type nls -regular_tao_view ::ascii_info_detail
448:      args: -tao_term_hessian_mat_type aij -mapping_mtype aij

450: TEST*/