```  1: #include <petsctao.h>
2: /*
3: Description:   ADMM tomography reconstruction example .
4:                0.5*||Ax-b||^2 + lambda*g(x)
5: Reference:     BRGN Tomography Example
6: */

/* Passed to PetscInitialize as the text shown by -help. */
static char help[] = "Finds the ADMM solution to the underdetermined linear model Ax = b, with regularizer. \n\
                      A is an M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
                      We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
                      g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
                      natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Or user can use regular TAO solver for  \n\
                      either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
                      Then, we augment both f and g, and solve it via ADMM. \n\
                      D is the N*N transform matrix so that D*x is sparse. \n";

17: typedef struct {
18:   PetscInt  M, N, K, reg;
19:   PetscReal lambda, eps, mumin;
20:   Mat       A, ATA, H, Hx, D, Hz, DTD, HF;
21:   Vec       c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22: } AppCtx;

24: /*------------------------------------------------------------*/

26: PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr)
27: {
28:   PetscFunctionBegin;
29:   PetscFunctionReturn(PETSC_SUCCESS);
30: }

32: /*------------------------------------------------------------*/

34: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
35: {
36:   PetscReal lambda, mu;
37:   AppCtx   *user;
38:   Vec       out, work, y, x;

41:   PetscFunctionBegin;
42:   user = NULL;
43:   mu   = 0;
47:   PetscCall(TaoShellGetContext(tao, &user));

50:   work = user->workN;
51:   PetscCall(TaoGetSolution(tao, &out));
52:   PetscCall(TaoGetSolution(misfit, &x));

55:   /* Dx + y/mu */
56:   PetscCall(MatMult(user->D, x, work));
57:   PetscCall(VecAXPY(work, 1 / mu, y));

59:   /* soft thresholding */
60:   PetscCall(TaoSoftThreshold(work, -lambda / mu, lambda / mu, out));
61:   PetscFunctionReturn(PETSC_SUCCESS);
62: }

64: /*------------------------------------------------------------*/

66: PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
67: {
68:   AppCtx *user = (AppCtx *)ptr;

70:   PetscFunctionBegin;
71:   /* Objective  0.5*||Ax-b||_2^2 */
72:   PetscCall(MatMult(user->A, X, user->workM));
73:   PetscCall(VecAXPY(user->workM, -1, user->b));
74:   PetscCall(VecDot(user->workM, user->workM, f));
75:   *f *= 0.5;
77:   PetscCall(MatMult(user->ATA, X, user->workN));
78:   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
79:   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
80:   PetscFunctionReturn(PETSC_SUCCESS);
81: }

83: /*------------------------------------------------------------*/

85: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
86: {
87:   AppCtx   *user = (AppCtx *)ptr;
88:   PetscReal lambda;

91:   PetscFunctionBegin;
92:   /* compute regularizer objective
93:    * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
94:   PetscCall(VecCopy(X, user->workN2));
95:   PetscCall(VecPow(user->workN2, 2.));
96:   PetscCall(VecShift(user->workN2, user->eps * user->eps));
97:   PetscCall(VecSqrtAbs(user->workN2));
98:   PetscCall(VecCopy(user->workN2, user->workN3));
99:   PetscCall(VecShift(user->workN2, -user->eps));
100:   PetscCall(VecSum(user->workN2, f_reg));
103:   *f_reg *= lambda;
104:   /* compute regularizer gradient = lambda*x */
105:   PetscCall(VecPointwiseDivide(G_reg, X, user->workN3));
106:   PetscCall(VecScale(G_reg, lambda));
107:   PetscFunctionReturn(PETSC_SUCCESS);
108: }

110: /*------------------------------------------------------------*/

112: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
113: {
114:   PetscReal temp, lambda;

117:   PetscFunctionBegin;
118:   /* compute regularizer objective = lambda*|z|_2^2 */
119:   PetscCall(VecDot(X, X, &temp));
122:   *f_reg = 0.5 * lambda * temp;
123:   /* compute regularizer gradient = lambda*z */
124:   PetscCall(VecCopy(X, G_reg));
125:   PetscCall(VecScale(G_reg, lambda));
126:   PetscFunctionReturn(PETSC_SUCCESS);
127: }

129: /*------------------------------------------------------------*/

131: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
132: {
133:   PetscFunctionBegin;
134:   PetscFunctionReturn(PETSC_SUCCESS);
135: }

137: /*------------------------------------------------------------*/

139: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
140: {
141:   AppCtx *user = (AppCtx *)ptr;

143:   PetscFunctionBegin;
144:   PetscCall(MatMult(user->D, x, user->workN));
145:   PetscCall(VecPow(user->workN2, 2.));
146:   PetscCall(VecShift(user->workN2, user->eps * user->eps));
147:   PetscCall(VecSqrtAbs(user->workN2));
148:   PetscCall(VecShift(user->workN2, -user->eps));
149:   PetscCall(VecReciprocal(user->workN2));
150:   PetscCall(VecScale(user->workN2, user->eps * user->eps));
151:   PetscCall(MatDiagonalSet(H, user->workN2, INSERT_VALUES));
152:   PetscFunctionReturn(PETSC_SUCCESS);
153: }

155: /*------------------------------------------------------------*/

157: PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
158: {
159:   AppCtx   *user = (AppCtx *)ptr;
160:   PetscReal f_reg, lambda;

163:   PetscFunctionBegin;
164:   /* Objective  0.5*||Ax-b||_2^2 + lambda*||x||_{1,2}^2*/
165:   PetscCall(MatMult(user->A, X, user->workM));
166:   PetscCall(VecAXPY(user->workM, -1, user->b));
167:   PetscCall(VecDot(user->workM, user->workM, f));
168:   if (user->reg == 1) {
169:     PetscCall(VecNorm(X, NORM_1, &f_reg));
170:   } else {
171:     PetscCall(VecNorm(X, NORM_2, &f_reg));
172:   }
176:   } else {
177:     lambda = user->lambda;
178:   }
179:   *f *= 0.5;
180:   *f += lambda * f_reg * f_reg;
181:   /* Gradient. ATAx-ATb + 2*lambda*x */
182:   PetscCall(MatMult(user->ATA, X, user->workN));
183:   PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
184:   PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
185:   PetscCall(VecAXPY(g, 2 * lambda, X));
186:   PetscFunctionReturn(PETSC_SUCCESS);
187: }
188: /*------------------------------------------------------------*/

190: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
191: {
192:   PetscFunctionBegin;
193:   PetscFunctionReturn(PETSC_SUCCESS);
194: }
195: /*------------------------------------------------------------*/

197: PetscErrorCode InitializeUserData(AppCtx *user)
198: {
199:   char        dataFile[] = "tomographyData_A_b_xGT"; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
200:   PetscViewer fd;                                    /* used to load data from file */
201:   PetscInt    k, n;
202:   PetscScalar v;

204:   PetscFunctionBegin;
205:   /* Load the A matrix, b vector, and xGT vector from a binary file. */
207:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->A));
208:   PetscCall(MatSetType(user->A, MATAIJ));
210:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->b));
212:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->xGT));
214:   PetscCall(PetscViewerDestroy(&fd));

216:   PetscCall(MatGetSize(user->A, &user->M, &user->N));

218:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->D));
219:   PetscCall(MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
220:   PetscCall(MatSetFromOptions(user->D));
221:   PetscCall(MatSetUp(user->D));
222:   for (k = 0; k < user->N; k++) {
223:     v = 1.0;
224:     n = k + 1;
225:     if (k < user->N - 1) PetscCall(MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES));
226:     v = -1.0;
227:     PetscCall(MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES));
228:   }
229:   PetscCall(MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY));
230:   PetscCall(MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY));

232:   PetscCall(MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->DTD));

234:   PetscCall(MatCreate(PETSC_COMM_WORLD, &user->Hz));
235:   PetscCall(MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
236:   PetscCall(MatSetFromOptions(user->Hz));
237:   PetscCall(MatSetUp(user->Hz));
238:   PetscCall(MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY));
239:   PetscCall(MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY));

241:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->x));
242:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workM));
243:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workN));
244:   PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workN2));
245:   PetscCall(VecSetSizes(user->x, PETSC_DECIDE, user->N));
246:   PetscCall(VecSetSizes(user->workM, PETSC_DECIDE, user->M));
247:   PetscCall(VecSetSizes(user->workN, PETSC_DECIDE, user->N));
248:   PetscCall(VecSetSizes(user->workN2, PETSC_DECIDE, user->N));
249:   PetscCall(VecSetFromOptions(user->x));
250:   PetscCall(VecSetFromOptions(user->workM));
251:   PetscCall(VecSetFromOptions(user->workN));
252:   PetscCall(VecSetFromOptions(user->workN2));

254:   PetscCall(VecDuplicate(user->workN, &user->workN3));
255:   PetscCall(VecDuplicate(user->x, &user->xlb));
256:   PetscCall(VecDuplicate(user->x, &user->xub));
257:   PetscCall(VecDuplicate(user->x, &user->c));
258:   PetscCall(VecSet(user->xlb, 0.0));
259:   PetscCall(VecSet(user->c, 0.0));
260:   PetscCall(VecSet(user->xub, PETSC_INFINITY));

262:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->ATA));
263:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->Hx));
264:   PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DEFAULT, &user->HF));

266:   PetscCall(MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY));
267:   PetscCall(MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY));
268:   PetscCall(MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY));
269:   PetscCall(MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY));
270:   PetscCall(MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY));
271:   PetscCall(MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY));

273:   user->lambda = 1.e-8;
274:   user->eps    = 1.e-3;
275:   user->reg    = 2;
276:   user->mumin  = 5.e-6;

278:   PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
279:   PetscCall(PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &user->reg, NULL));
280:   PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &user->lambda, NULL));
282:   PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &user->mumin, NULL));
283:   PetscOptionsEnd();
284:   PetscFunctionReturn(PETSC_SUCCESS);
285: }

287: /*------------------------------------------------------------*/

289: PetscErrorCode DestroyContext(AppCtx *user)
290: {
291:   PetscFunctionBegin;
292:   PetscCall(MatDestroy(&user->A));
293:   PetscCall(MatDestroy(&user->ATA));
294:   PetscCall(MatDestroy(&user->Hx));
295:   PetscCall(MatDestroy(&user->Hz));
296:   PetscCall(MatDestroy(&user->HF));
297:   PetscCall(MatDestroy(&user->D));
298:   PetscCall(MatDestroy(&user->DTD));
299:   PetscCall(VecDestroy(&user->xGT));
300:   PetscCall(VecDestroy(&user->xlb));
301:   PetscCall(VecDestroy(&user->xub));
302:   PetscCall(VecDestroy(&user->b));
303:   PetscCall(VecDestroy(&user->x));
304:   PetscCall(VecDestroy(&user->c));
305:   PetscCall(VecDestroy(&user->workN3));
306:   PetscCall(VecDestroy(&user->workN2));
307:   PetscCall(VecDestroy(&user->workN));
308:   PetscCall(VecDestroy(&user->workM));
309:   PetscFunctionReturn(PETSC_SUCCESS);
310: }

312: /*------------------------------------------------------------*/

314: int main(int argc, char **argv)
315: {
316:   Tao         tao, misfit, reg;
317:   PetscReal   v1, v2;
318:   AppCtx     *user;
319:   PetscViewer fd;
320:   char        resultFile[] = "tomographyResult_x";

322:   PetscFunctionBeginUser;
323:   PetscCall(PetscInitialize(&argc, &argv, (char *)0, help));
324:   PetscCall(PetscNew(&user));
325:   PetscCall(InitializeUserData(user));

327:   PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
329:   PetscCall(TaoSetSolution(tao, user->x));
330:   /* f(x) + g(x) for parent tao */
333:   PetscCall(MatShift(user->HF, user->lambda));
334:   PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user));

336:   /* f(x) for misfit tao */
338:   PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user));
340:   PetscCall(TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user));

342:   /* g(x) for regularizer tao */
343:   if (user->reg == 1) {
345:     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user));
347:   } else if (user->reg == 2) {
349:     PetscCall(MatShift(user->Hz, 1));
350:     PetscCall(MatScale(user->Hz, user->lambda));
351:     PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user));
353:   } else PetscCheck(user->reg == 3, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */

355:   /* Set type for the misfit solver */
358:   PetscCall(TaoSetType(misfit, TAONLS));
359:   if (user->reg == 3) {
360:     PetscCall(TaoSetType(reg, TAOSHELL));
361:     PetscCall(TaoShellSetContext(reg, (void *)user));
362:     PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold));
363:   } else {
364:     PetscCall(TaoSetType(reg, TAONLS));
365:   }
366:   PetscCall(TaoSetVariableBounds(misfit, user->xlb, user->xub));

368:   /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
370:   PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user));

374:   PetscCall(TaoSetFromOptions(tao));
375:   PetscCall(TaoSolve(tao));

377:   /* Save x (reconstruction of object) vector to a binary file, which maybe read from MATLAB and convert to a 2D image for comparison. */
378:   PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd));
379:   PetscCall(VecView(user->x, fd));
380:   PetscCall(PetscViewerDestroy(&fd));

382:   /* compute the error */
383:   PetscCall(VecAXPY(user->x, -1, user->xGT));
384:   PetscCall(VecNorm(user->x, NORM_2, &v1));
385:   PetscCall(VecNorm(user->xGT, NORM_2, &v2));
386:   PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2)));

388:   /* Free TAO data structures */
389:   PetscCall(TaoDestroy(&tao));
390:   PetscCall(DestroyContext(user));
391:   PetscCall(PetscFree(user));
392:   PetscCall(PetscFinalize());
393:   return 0;
394: }

396: /*TEST

398:    build:
399:       requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)

401:    test:
402:       suffix: 1
403:       localrunfiles: tomographyData_A_b_xGT
404:       args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc

406:    test:
407:       suffix: 2
408:       localrunfiles: tomographyData_A_b_xGT
409:       args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor

411:    test:
412:       suffix: 3
413:       localrunfiles: tomographyData_A_b_xGT

416:    test:
417:       suffix: 4
418:       localrunfiles: tomographyData_A_b_xGT