Actual source code: ex1.c
/* Help string shown by -help; internal linkage since nothing outside this file uses it. */
static const char help[] = "Coverage tests for TAOTERMSHELL";
3: #include <petsctaoterm.h>
5: static PetscErrorCode TaoTermCreateSolutionVec_Test(TaoTerm term, Vec *solution)
6: {
7: Mat A;
9: PetscFunctionBeginUser;
10: PetscCall(TaoTermShellGetContext(term, &A));
11: PetscCall(MatCreateVecs(A, NULL, solution));
12: PetscFunctionReturn(PETSC_SUCCESS);
13: }
15: static PetscErrorCode TaoTermCreateParametersVec_Test(TaoTerm term, Vec *params)
16: {
17: Mat A;
19: PetscFunctionBeginUser;
20: PetscCall(TaoTermShellGetContext(term, &A));
21: PetscCall(MatCreateVecs(A, params, NULL));
22: PetscFunctionReturn(PETSC_SUCCESS);
23: }
25: static PetscErrorCode TaoTermView_Test(TaoTerm term, PetscViewer viewer)
26: {
27: PetscFunctionBeginUser;
28: PetscCall(PetscViewerASCIIPrintf(viewer, "TaoTermView_Test()\n"));
29: PetscFunctionReturn(PETSC_SUCCESS);
30: }
32: static PetscErrorCode TaoTermComputeObjective_Test(TaoTerm term, Vec x, Vec params, PetscReal *value)
33: {
34: Mat A;
35: Vec r;
37: PetscFunctionBeginUser;
38: PetscCall(TaoTermShellGetContext(term, &A));
39: PetscCall(VecDuplicate(x, &r));
40: PetscCall(MatMult(A, params, r));
41: PetscCall(VecDotRealPart(x, r, value));
42: PetscCall(VecDestroy(&r));
43: PetscFunctionReturn(PETSC_SUCCESS);
44: }
46: static PetscErrorCode TaoTermComputeGradient_Test(TaoTerm term, Vec x, Vec params, Vec g)
47: {
48: Mat A;
50: PetscFunctionBeginUser;
51: PetscCall(TaoTermShellGetContext(term, &A));
52: PetscCall(MatMult(A, params, g));
53: PetscFunctionReturn(PETSC_SUCCESS);
54: }
56: static PetscErrorCode TaoTermComputeObjectiveAndGradient_Test(TaoTerm term, Vec x, Vec params, PetscReal *value, Vec g)
57: {
58: Mat A;
60: PetscFunctionBeginUser;
61: PetscCall(TaoTermShellGetContext(term, &A));
62: PetscCall(MatMult(A, params, g));
63: PetscCall(VecDotRealPart(x, g, value));
64: PetscFunctionReturn(PETSC_SUCCESS);
65: }
67: static PetscErrorCode testShell(MPI_Comm comm, PetscBool separate)
68: {
69: PetscRandom rand;
70: Mat A;
71: TaoTerm term;
72: PetscInt m = 23, n = 11;
73: Vec x, params, g;
74: PetscInt test_m, test_n;
75: PetscReal value, g_norm;
77: PetscFunctionBeginUser;
78: PetscCall(PetscRandomCreate(comm, &rand));
79: PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
80: PetscCall(MatSetUp(A));
81: PetscCall(MatSetRandom(A, rand));
82: PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
83: PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
85: PetscCall(TaoTermCreateShell(comm, A, NULL, &term));
86: PetscCall(TaoTermSetParametersMode(term, TAOTERM_PARAMETERS_REQUIRED));
88: if (separate) {
89: PetscCall(MatCreateVecs(A, ¶ms, &x));
90: PetscCall(TaoTermSetSolutionTemplate(term, x));
91: PetscCall(TaoTermSetParametersTemplate(term, params));
92: PetscCall(VecDestroy(¶ms));
93: PetscCall(VecDestroy(&x));
94: } else {
95: PetscCall(TaoTermShellSetCreateSolutionVec(term, TaoTermCreateSolutionVec_Test));
96: PetscCall(TaoTermShellSetCreateParametersVec(term, TaoTermCreateParametersVec_Test));
97: }
99: PetscCall(TaoTermSetUp(term));
101: PetscCall(TaoTermGetSolutionSizes(term, NULL, &test_m, NULL));
102: PetscCall(TaoTermGetParametersSizes(term, NULL, &test_n, NULL));
103: PetscCheck(test_m == m, comm, PETSC_ERR_PLIB, "Inconsistent solution size");
104: PetscCheck(test_n == n, comm, PETSC_ERR_PLIB, "Inconsistent parameters size");
106: if (separate) {
107: PetscCall(TaoTermShellSetObjective(term, TaoTermComputeObjective_Test));
108: PetscCall(TaoTermShellSetGradient(term, TaoTermComputeGradient_Test));
109: } else {
110: PetscCall(TaoTermShellSetObjectiveAndGradient(term, TaoTermComputeObjectiveAndGradient_Test));
111: PetscCall(TaoTermShellSetView(term, TaoTermView_Test));
112: }
114: PetscCall(TaoTermView(term, PETSC_VIEWER_STDOUT_(comm)));
116: PetscCall(TaoTermCreateSolutionVec(term, &x));
117: PetscCall(TaoTermCreateParametersVec(term, ¶ms));
119: PetscCall(VecSetRandom(x, rand));
120: PetscCall(VecSetRandom(params, rand));
121: PetscCall(VecDuplicate(x, &g));
122: PetscCall(TaoTermComputeObjectiveAndGradient(term, x, params, &value, g));
123: PetscCall(VecNorm(g, NORM_2, &g_norm));
124: PetscCall(PetscPrintf(comm, "objective: %g, gradient norm %g\n", (double)value, (double)g_norm));
126: PetscCall(VecDestroy(&g));
127: PetscCall(VecDestroy(¶ms));
128: PetscCall(VecDestroy(&x));
129: PetscCall(TaoTermDestroy(&term));
130: PetscCall(MatDestroy(&A));
131: PetscCall(PetscRandomDestroy(&rand));
132: PetscFunctionReturn(PETSC_SUCCESS);
133: }
135: int main(int argc, char **argv)
136: {
137: PetscFunctionBeginUser;
138: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
139: PetscCall(testShell(PETSC_COMM_WORLD, PETSC_TRUE));
140: PetscCall(testShell(PETSC_COMM_WORLD, PETSC_FALSE));
141: PetscCall(PetscFinalize());
142: return 0;
143: }
/*TEST

   test:
      suffix: 0

TEST*/