Actual source code: ex1.c

  1: const char help[] = "TAOTERMCALLBACKS coverage tests";

  3: #include <petsctao.h>

  5: typedef struct {
  6:   PetscInt obj_count;
  7:   PetscInt grad_count;
  8:   PetscInt obj_and_grad_count;
  9:   PetscInt hess_count;
 10: } AppCtx;

 12: static PetscErrorCode objective(Tao tao, Vec x, PetscReal *value, void *ctx)
 13: {
 14:   AppCtx *app = (AppCtx *)ctx;

 16:   PetscFunctionBeginUser;
 17:   *value = 0.0;
 18:   app->obj_count++;
 19:   PetscFunctionReturn(PETSC_SUCCESS);
 20: }

 22: static PetscErrorCode gradient(Tao tao, Vec x, Vec g, void *ctx)
 23: {
 24:   AppCtx *app = (AppCtx *)ctx;

 26:   PetscFunctionBeginUser;
 27:   PetscCall(VecZeroEntries(g));
 28:   app->grad_count++;
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: static PetscErrorCode objective_and_gradient(Tao tao, Vec x, PetscReal *value, Vec g, void *ctx)
 33: {
 34:   AppCtx *app = (AppCtx *)ctx;

 36:   PetscFunctionBeginUser;
 37:   *value = 0.0;
 38:   PetscCall(VecZeroEntries(g));
 39:   app->obj_and_grad_count++;
 40:   PetscFunctionReturn(PETSC_SUCCESS);
 41: }

 43: static PetscErrorCode hessian(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
 44: {
 45:   AppCtx *app = (AppCtx *)ctx;

 47:   PetscFunctionBeginUser;
 48:   if (H) {
 49:     PetscCall(MatZeroEntries(H));
 50:     PetscCall(MatAssemblyBegin(H, MAT_FINAL_ASSEMBLY));
 51:     PetscCall(MatAssemblyEnd(H, MAT_FINAL_ASSEMBLY));
 52:   }
 53:   if (Hpre && Hpre != H) {
 54:     PetscCall(MatZeroEntries(Hpre));
 55:     PetscCall(MatAssemblyBegin(Hpre, MAT_FINAL_ASSEMBLY));
 56:     PetscCall(MatAssemblyEnd(Hpre, MAT_FINAL_ASSEMBLY));
 57:   }
 58:   app->hess_count++;
 59:   PetscFunctionReturn(PETSC_SUCCESS);
 60: }

 62: static PetscErrorCode testCallbacks(PetscBool separate)
 63: {
 64:   MPI_Comm    comm = PETSC_COMM_WORLD;
 65:   Tao         tao;
 66:   TaoTerm     term;
 67:   TaoTermType type;
 68:   PetscBool   same;
 69:   PetscErrorCode (*_hessian)(Tao, Vec, Mat, Mat, void *);
 70:   AppCtx    app;
 71:   Vec       sol, grad;
 72:   Mat       H, Hpre;
 73:   PetscInt  N = 10;
 74:   PetscReal value;

 76:   PetscFunctionBeginUser;
 77:   app.obj_count          = 0;
 78:   app.grad_count         = 0;
 79:   app.obj_and_grad_count = 0;
 80:   app.hess_count         = 0;
 81:   PetscCall(VecCreateMPI(comm, PETSC_DECIDE, N, &sol));
 82:   PetscCall(VecZeroEntries(sol));
 83:   PetscCall(VecDuplicate(sol, &grad));
 84:   PetscCall(MatCreateAIJ(comm, PETSC_DECIDE, PETSC_DECIDE, N, N, 1, NULL, 0, NULL, &H));
 85:   PetscCall(MatDuplicate(H, MAT_DO_NOT_COPY_VALUES, &Hpre));
 86:   PetscCall(TaoCreate(comm, &tao));
 87:   PetscCall(TaoSetSolution(tao, sol));
 88:   PetscCall(TaoGetTerm(tao, NULL, &term, NULL, NULL));
 89:   PetscCall(TaoTermGetType(term, &type));
 90:   PetscCall(PetscStrcmp(type, TAOTERMCALLBACKS, &same));
 91:   PetscCheck(same, comm, PETSC_ERR_PLIB, "wrong TaoTermType");

 93:   if (separate) {
 94:     PetscCall(TaoSetObjective(tao, objective, (void *)&app));
 95:     PetscCall(TaoSetGradient(tao, grad, gradient, (void *)&app));
 96:   } else {
 97:     PetscCall(TaoSetObjectiveAndGradient(tao, grad, objective_and_gradient, (void *)&app));
 98:   }
 99:   PetscCall(TaoSetHessian(tao, H, Hpre, hessian, (void *)&app));

101:   {
102:     PetscBool is_defined;

104:     PetscCall(TaoTermIsHessianDefined(term, &is_defined));
105:     PetscCheck(is_defined == PETSC_TRUE, comm, PETSC_ERR_PLIB, "Hessian should be defined after setting it");
106:   }

108:   if (separate) {
109:     PetscErrorCode (*_objective)(Tao, Vec, PetscReal *, void *);
110:     PetscErrorCode (*_gradient)(Tao, Vec, Vec, void *);

112:     PetscCall(TaoGetObjective(tao, &_objective, NULL));
113:     PetscCall(TaoGetGradient(tao, NULL, &_gradient, NULL));
114:     PetscCheck(_objective == objective, comm, PETSC_ERR_PLIB, "wrong objective callback");
115:     PetscCheck(_gradient == gradient, comm, PETSC_ERR_PLIB, "wrong gradient callback");
116:   } else {
117:     PetscErrorCode (*_objective_and_gradient)(Tao, Vec, PetscReal *, Vec, void *);

119:     PetscCall(TaoGetObjectiveAndGradient(tao, NULL, &_objective_and_gradient, NULL));
120:     PetscCheck(_objective_and_gradient == objective_and_gradient, comm, PETSC_ERR_PLIB, "wrong objective and gradient callback");
121:   }
122:   PetscCall(TaoGetHessian(tao, NULL, NULL, &_hessian, NULL));
123:   PetscCheck(_hessian == hessian, comm, PETSC_ERR_PLIB, "wrong hessian callback");

125:   PetscCall(TaoComputeObjective(tao, sol, &value));
126:   (void)value;
127:   PetscCall(TaoComputeGradient(tao, sol, grad));
128:   PetscCall(TaoComputeObjectiveAndGradient(tao, sol, &value, grad));
129:   (void)value;
130:   PetscCall(TaoComputeHessian(tao, sol, H, Hpre));

132:   if (separate) {
133:     PetscCheck(app.obj_count == 2, comm, PETSC_ERR_PLIB, "Incorrect number of objective evaluations");
134:     PetscCheck(app.grad_count == 2, comm, PETSC_ERR_PLIB, "Incorrect number of gradient evaluations");
135:     PetscCheck(app.obj_and_grad_count == 0, comm, PETSC_ERR_PLIB, "Incorrect number of objective+gradient evaluations");
136:   } else {
137:     PetscCheck(app.obj_count == 0, comm, PETSC_ERR_PLIB, "Incorrect number of objective evaluations");
138:     PetscCheck(app.grad_count == 0, comm, PETSC_ERR_PLIB, "Incorrect number of gradient evaluations");
139:     PetscCheck(app.obj_and_grad_count == 3, comm, PETSC_ERR_PLIB, "Incorrect number of objective+gradient evaluations");
140:   }
141:   PetscCheck(app.hess_count == 1, comm, PETSC_ERR_PLIB, "Incorrect number of hessian evaluations");

143:   PetscCall(TaoDestroy(&tao));
144:   PetscCall(MatDestroy(&Hpre));
145:   PetscCall(MatDestroy(&H));
146:   PetscCall(VecDestroy(&grad));
147:   PetscCall(VecDestroy(&sol));
148:   PetscFunctionReturn(PETSC_SUCCESS);
149: }

151: int main(int argc, char **argv)
152: {
153:   PetscFunctionBeginUser;
154:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
155:   PetscCall(testCallbacks(PETSC_FALSE));
156:   PetscCall(testCallbacks(PETSC_TRUE));
157:   PetscCall(PetscFinalize());
158:   return 0;
159: }

161: /*TEST

163:   test:
164:     suffix: 0
165:     output_file: output/empty.out

167: TEST*/