Actual source code: ex1.c

  1: const char help[] = "Coverage tests for TAOTERMSHELL";

  3: #include <petsctaoterm.h>

  5: static PetscErrorCode TaoTermCreateSolutionVec_Test(TaoTerm term, Vec *solution)
  6: {
  7:   Mat A;

  9:   PetscFunctionBeginUser;
 10:   PetscCall(TaoTermShellGetContext(term, &A));
 11:   PetscCall(MatCreateVecs(A, NULL, solution));
 12:   PetscFunctionReturn(PETSC_SUCCESS);
 13: }

 15: static PetscErrorCode TaoTermCreateParametersVec_Test(TaoTerm term, Vec *params)
 16: {
 17:   Mat A;

 19:   PetscFunctionBeginUser;
 20:   PetscCall(TaoTermShellGetContext(term, &A));
 21:   PetscCall(MatCreateVecs(A, params, NULL));
 22:   PetscFunctionReturn(PETSC_SUCCESS);
 23: }

 25: static PetscErrorCode TaoTermView_Test(TaoTerm term, PetscViewer viewer)
 26: {
 27:   PetscFunctionBeginUser;
 28:   PetscCall(PetscViewerASCIIPrintf(viewer, "TaoTermView_Test()\n"));
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: static PetscErrorCode TaoTermComputeObjective_Test(TaoTerm term, Vec x, Vec params, PetscReal *value)
 33: {
 34:   Mat A;
 35:   Vec r;

 37:   PetscFunctionBeginUser;
 38:   PetscCall(TaoTermShellGetContext(term, &A));
 39:   PetscCall(VecDuplicate(x, &r));
 40:   PetscCall(MatMult(A, params, r));
 41:   PetscCall(VecDotRealPart(x, r, value));
 42:   PetscCall(VecDestroy(&r));
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }

 46: static PetscErrorCode TaoTermComputeGradient_Test(TaoTerm term, Vec x, Vec params, Vec g)
 47: {
 48:   Mat A;

 50:   PetscFunctionBeginUser;
 51:   PetscCall(TaoTermShellGetContext(term, &A));
 52:   PetscCall(MatMult(A, params, g));
 53:   PetscFunctionReturn(PETSC_SUCCESS);
 54: }

 56: static PetscErrorCode TaoTermComputeObjectiveAndGradient_Test(TaoTerm term, Vec x, Vec params, PetscReal *value, Vec g)
 57: {
 58:   Mat A;

 60:   PetscFunctionBeginUser;
 61:   PetscCall(TaoTermShellGetContext(term, &A));
 62:   PetscCall(MatMult(A, params, g));
 63:   PetscCall(VecDotRealPart(x, g, value));
 64:   PetscFunctionReturn(PETSC_SUCCESS);
 65: }

 67: static PetscErrorCode testShell(MPI_Comm comm, PetscBool separate)
 68: {
 69:   PetscRandom rand;
 70:   Mat         A;
 71:   TaoTerm     term;
 72:   PetscInt    m = 23, n = 11;
 73:   Vec         x, params, g;
 74:   PetscInt    test_m, test_n;
 75:   PetscReal   value, g_norm;

 77:   PetscFunctionBeginUser;
 78:   PetscCall(PetscRandomCreate(comm, &rand));
 79:   PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
 80:   PetscCall(MatSetUp(A));
 81:   PetscCall(MatSetRandom(A, rand));
 82:   PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
 83:   PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));

 85:   PetscCall(TaoTermCreateShell(comm, A, NULL, &term));
 86:   PetscCall(TaoTermSetParametersMode(term, TAOTERM_PARAMETERS_REQUIRED));

 88:   if (separate) {
 89:     PetscCall(MatCreateVecs(A, &params, &x));
 90:     PetscCall(TaoTermSetSolutionTemplate(term, x));
 91:     PetscCall(TaoTermSetParametersTemplate(term, params));
 92:     PetscCall(VecDestroy(&params));
 93:     PetscCall(VecDestroy(&x));
 94:   } else {
 95:     PetscCall(TaoTermShellSetCreateSolutionVec(term, TaoTermCreateSolutionVec_Test));
 96:     PetscCall(TaoTermShellSetCreateParametersVec(term, TaoTermCreateParametersVec_Test));
 97:   }

 99:   PetscCall(TaoTermSetUp(term));

101:   PetscCall(TaoTermGetSolutionSizes(term, NULL, &test_m, NULL));
102:   PetscCall(TaoTermGetParametersSizes(term, NULL, &test_n, NULL));
103:   PetscCheck(test_m == m, comm, PETSC_ERR_PLIB, "Inconsistent solution size");
104:   PetscCheck(test_n == n, comm, PETSC_ERR_PLIB, "Inconsistent parameters size");

106:   if (separate) {
107:     PetscCall(TaoTermShellSetObjective(term, TaoTermComputeObjective_Test));
108:     PetscCall(TaoTermShellSetGradient(term, TaoTermComputeGradient_Test));
109:   } else {
110:     PetscCall(TaoTermShellSetObjectiveAndGradient(term, TaoTermComputeObjectiveAndGradient_Test));
111:     PetscCall(TaoTermShellSetView(term, TaoTermView_Test));
112:   }

114:   PetscCall(TaoTermView(term, PETSC_VIEWER_STDOUT_(comm)));

116:   PetscCall(TaoTermCreateSolutionVec(term, &x));
117:   PetscCall(TaoTermCreateParametersVec(term, &params));

119:   PetscCall(VecSetRandom(x, rand));
120:   PetscCall(VecSetRandom(params, rand));
121:   PetscCall(VecDuplicate(x, &g));
122:   PetscCall(TaoTermComputeObjectiveAndGradient(term, x, params, &value, g));
123:   PetscCall(VecNorm(g, NORM_2, &g_norm));
124:   PetscCall(PetscPrintf(comm, "objective: %g, gradient norm %g\n", (double)value, (double)g_norm));

126:   PetscCall(VecDestroy(&g));
127:   PetscCall(VecDestroy(&params));
128:   PetscCall(VecDestroy(&x));
129:   PetscCall(TaoTermDestroy(&term));
130:   PetscCall(MatDestroy(&A));
131:   PetscCall(PetscRandomDestroy(&rand));
132:   PetscFunctionReturn(PETSC_SUCCESS);
133: }

135: int main(int argc, char **argv)
136: {
137:   PetscFunctionBeginUser;
138:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
139:   PetscCall(testShell(PETSC_COMM_WORLD, PETSC_TRUE));
140:   PetscCall(testShell(PETSC_COMM_WORLD, PETSC_FALSE));
141:   PetscCall(PetscFinalize());
142:   return 0;
143: }

145: /*TEST

147:   test:
148:     suffix: 0

150: TEST*/