Actual source code: tomographyADMM.c
1: #include <petsctao.h>
/*
Description: ADMM tomography reconstruction example.
             Minimizes 0.5*||Ax-b||^2 + lambda*g(x).
Reference:   BRGN Tomography Example
*/
/* Help text handed to PetscInitialize() in main(); printed by -help. */
static char help[] = "Finds the ADMM solution to the underdetermined linear model Ax = b, with regularizer. \n\
A is a M*N real matrix (M<N), x is sparse. A good regularizer is an L1 regularizer. \n\
We first split the operator into 0.5*||Ax-b||^2, f(x), and lambda*||x||_1, g(z), where lambda is user specified weight. \n\
g(z) could be either ||z||_1, or ||z||_2^2. Default closed form solution for NORM1 would be soft-threshold, which is \n\
natively supported in admm.c with -tao_admm_regularizer_type regularizer_soft_thresh. Or user can use regular TAO solver for \n\
either NORM1 or NORM2 or TAOSHELL, with -reg {1,2,3} \n\
Then, we augment both f and g, and solve it via ADMM. \n\
D is the N*N transform matrix so that D*x is sparse. \n";
17: typedef struct {
18: PetscInt M, N, K, reg;
19: PetscReal lambda, eps, mumin;
20: Mat A, ATA, H, Hx, D, Hz, DTD, HF;
21: Vec c, xlb, xub, x, b, workM, workN, workN2, workN3, xGT; /* observation b, ground truth xGT, the lower bound and upper bound of x*/
22: } AppCtx;
24: /*------------------------------------------------------------*/
26: PetscErrorCode NullJacobian(Tao tao, Vec X, Mat J, Mat Jpre, void *ptr)
27: {
28: PetscFunctionBegin;
29: PetscFunctionReturn(PETSC_SUCCESS);
30: }
32: /*------------------------------------------------------------*/
34: static PetscErrorCode TaoShellSolve_SoftThreshold(Tao tao)
35: {
36: PetscReal lambda, mu;
37: AppCtx *user;
38: Vec out, work, y, x;
39: Tao admm_tao, misfit;
41: PetscFunctionBegin;
42: user = NULL;
43: mu = 0;
44: PetscCall(TaoGetADMMParentTao(tao, &admm_tao));
45: PetscCall(TaoADMMGetMisfitSubsolver(admm_tao, &misfit));
46: PetscCall(TaoADMMGetSpectralPenalty(admm_tao, &mu));
47: PetscCall(TaoShellGetContext(tao, &user));
48: PetscCall(TaoADMMGetRegularizerCoefficient(admm_tao, &lambda));
50: work = user->workN;
51: PetscCall(TaoGetSolution(tao, &out));
52: PetscCall(TaoGetSolution(misfit, &x));
53: PetscCall(TaoADMMGetDualVector(admm_tao, &y));
55: /* Dx + y/mu */
56: PetscCall(MatMult(user->D, x, work));
57: PetscCall(VecAXPY(work, 1 / mu, y));
59: /* soft thresholding */
60: PetscCall(TaoSoftThreshold(work, -lambda / mu, lambda / mu, out));
61: PetscFunctionReturn(PETSC_SUCCESS);
62: }
64: /*------------------------------------------------------------*/
66: PetscErrorCode MisfitObjectiveAndGradient(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
67: {
68: AppCtx *user = (AppCtx *)ptr;
70: PetscFunctionBegin;
71: /* Objective 0.5*||Ax-b||_2^2 */
72: PetscCall(MatMult(user->A, X, user->workM));
73: PetscCall(VecAXPY(user->workM, -1, user->b));
74: PetscCall(VecDot(user->workM, user->workM, f));
75: *f *= 0.5;
76: /* Gradient. ATAx-ATb */
77: PetscCall(MatMult(user->ATA, X, user->workN));
78: PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
79: PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
80: PetscFunctionReturn(PETSC_SUCCESS);
81: }
83: /*------------------------------------------------------------*/
85: PetscErrorCode RegularizerObjectiveAndGradient1(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
86: {
87: AppCtx *user = (AppCtx *)ptr;
88: PetscReal lambda;
89: Tao admm_tao;
91: PetscFunctionBegin;
92: /* compute regularizer objective
93: * f = f + lambda*sum(sqrt(y.^2+epsilon^2) - epsilon), where y = D*x */
94: PetscCall(VecCopy(X, user->workN2));
95: PetscCall(VecPow(user->workN2, 2.));
96: PetscCall(VecShift(user->workN2, user->eps * user->eps));
97: PetscCall(VecSqrtAbs(user->workN2));
98: PetscCall(VecCopy(user->workN2, user->workN3));
99: PetscCall(VecShift(user->workN2, -user->eps));
100: PetscCall(VecSum(user->workN2, f_reg));
101: PetscCall(TaoGetADMMParentTao(tao, &admm_tao));
102: PetscCall(TaoADMMGetRegularizerCoefficient(admm_tao, &lambda));
103: *f_reg *= lambda;
104: /* compute regularizer gradient = lambda*x */
105: PetscCall(VecPointwiseDivide(G_reg, X, user->workN3));
106: PetscCall(VecScale(G_reg, lambda));
107: PetscFunctionReturn(PETSC_SUCCESS);
108: }
110: /*------------------------------------------------------------*/
112: PetscErrorCode RegularizerObjectiveAndGradient2(Tao tao, Vec X, PetscReal *f_reg, Vec G_reg, void *ptr)
113: {
114: PetscReal temp, lambda;
115: Tao admm_tao;
117: PetscFunctionBegin;
118: /* compute regularizer objective = lambda*|z|_2^2 */
119: PetscCall(VecDot(X, X, &temp));
120: PetscCall(TaoGetADMMParentTao(tao, &admm_tao));
121: PetscCall(TaoADMMGetRegularizerCoefficient(admm_tao, &lambda));
122: *f_reg = 0.5 * lambda * temp;
123: /* compute regularizer gradient = lambda*z */
124: PetscCall(VecCopy(X, G_reg));
125: PetscCall(VecScale(G_reg, lambda));
126: PetscFunctionReturn(PETSC_SUCCESS);
127: }
129: /*------------------------------------------------------------*/
131: static PetscErrorCode HessianMisfit(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
132: {
133: PetscFunctionBegin;
134: PetscFunctionReturn(PETSC_SUCCESS);
135: }
137: /*------------------------------------------------------------*/
139: static PetscErrorCode HessianReg(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
140: {
141: AppCtx *user = (AppCtx *)ptr;
143: PetscFunctionBegin;
144: PetscCall(MatMult(user->D, x, user->workN));
145: PetscCall(VecPow(user->workN2, 2.));
146: PetscCall(VecShift(user->workN2, user->eps * user->eps));
147: PetscCall(VecSqrtAbs(user->workN2));
148: PetscCall(VecShift(user->workN2, -user->eps));
149: PetscCall(VecReciprocal(user->workN2));
150: PetscCall(VecScale(user->workN2, user->eps * user->eps));
151: PetscCall(MatDiagonalSet(H, user->workN2, INSERT_VALUES));
152: PetscFunctionReturn(PETSC_SUCCESS);
153: }
155: /*------------------------------------------------------------*/
157: PetscErrorCode FullObjGrad(Tao tao, Vec X, PetscReal *f, Vec g, void *ptr)
158: {
159: AppCtx *user = (AppCtx *)ptr;
160: PetscReal f_reg, lambda;
161: PetscBool is_admm;
163: PetscFunctionBegin;
164: /* Objective 0.5*||Ax-b||_2^2 + lambda*||x||_{1,2}^2*/
165: PetscCall(MatMult(user->A, X, user->workM));
166: PetscCall(VecAXPY(user->workM, -1, user->b));
167: PetscCall(VecDot(user->workM, user->workM, f));
168: if (user->reg == 1) {
169: PetscCall(VecNorm(X, NORM_1, &f_reg));
170: } else {
171: PetscCall(VecNorm(X, NORM_2, &f_reg));
172: }
173: PetscCall(PetscObjectTypeCompare((PetscObject)tao, TAOADMM, &is_admm));
174: if (is_admm) {
175: PetscCall(TaoADMMGetRegularizerCoefficient(tao, &lambda));
176: } else {
177: lambda = user->lambda;
178: }
179: *f *= 0.5;
180: *f += lambda * f_reg * f_reg;
181: /* Gradient. ATAx-ATb + 2*lambda*x */
182: PetscCall(MatMult(user->ATA, X, user->workN));
183: PetscCall(MatMultTranspose(user->A, user->b, user->workN2));
184: PetscCall(VecWAXPY(g, -1., user->workN2, user->workN));
185: PetscCall(VecAXPY(g, 2 * lambda, X));
186: PetscFunctionReturn(PETSC_SUCCESS);
187: }
188: /*------------------------------------------------------------*/
190: static PetscErrorCode HessianFull(Tao tao, Vec x, Mat H, Mat Hpre, void *ptr)
191: {
192: PetscFunctionBegin;
193: PetscFunctionReturn(PETSC_SUCCESS);
194: }
195: /*------------------------------------------------------------*/
197: PetscErrorCode InitializeUserData(AppCtx *user)
198: {
199: char dataFile[PETSC_MAX_PATH_LEN], path[PETSC_MAX_PATH_LEN]; /* Matrix A and vectors b, xGT(ground truth) binary files generated by Matlab. Debug: change from "tomographyData_A_b_xGT" to "cs1Data_A_b_xGT". */
200: PetscViewer fd; /* used to load data from file */
201: PetscInt k, n;
202: PetscScalar v;
203: PetscBool flg;
205: PetscFunctionBegin;
206: PetscCall(PetscOptionsGetString(NULL, NULL, "-path", path, sizeof(path), &flg));
207: PetscCheck(flg, PETSC_COMM_WORLD, PETSC_ERR_USER, "Must specify -path ${DATAFILESPATH}/tao/tomography");
208: /* Load the A matrix, b vector, and xGT vector from a binary file. */
209: PetscCall(PetscSNPrintf(dataFile, sizeof(dataFile), "%s/tomographyData_A_b_xGT", path));
210: PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, dataFile, FILE_MODE_READ, &fd));
211: PetscCall(MatCreate(PETSC_COMM_WORLD, &user->A));
212: PetscCall(MatSetType(user->A, MATAIJ));
213: PetscCall(MatLoad(user->A, fd));
214: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->b));
215: PetscCall(VecLoad(user->b, fd));
216: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->xGT));
217: PetscCall(VecLoad(user->xGT, fd));
218: PetscCall(PetscViewerDestroy(&fd));
220: PetscCall(MatGetSize(user->A, &user->M, &user->N));
222: PetscCall(MatCreate(PETSC_COMM_WORLD, &user->D));
223: PetscCall(MatSetSizes(user->D, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
224: PetscCall(MatSetFromOptions(user->D));
225: PetscCall(MatSetUp(user->D));
226: for (k = 0; k < user->N; k++) {
227: v = 1.0;
228: n = k + 1;
229: if (k < user->N - 1) PetscCall(MatSetValues(user->D, 1, &k, 1, &n, &v, INSERT_VALUES));
230: v = -1.0;
231: PetscCall(MatSetValues(user->D, 1, &k, 1, &k, &v, INSERT_VALUES));
232: }
233: PetscCall(MatAssemblyBegin(user->D, MAT_FINAL_ASSEMBLY));
234: PetscCall(MatAssemblyEnd(user->D, MAT_FINAL_ASSEMBLY));
236: PetscCall(MatTransposeMatMult(user->D, user->D, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &user->DTD));
238: PetscCall(MatCreate(PETSC_COMM_WORLD, &user->Hz));
239: PetscCall(MatSetSizes(user->Hz, PETSC_DECIDE, PETSC_DECIDE, user->N, user->N));
240: PetscCall(MatSetFromOptions(user->Hz));
241: PetscCall(MatSetUp(user->Hz));
242: PetscCall(MatAssemblyBegin(user->Hz, MAT_FINAL_ASSEMBLY));
243: PetscCall(MatAssemblyEnd(user->Hz, MAT_FINAL_ASSEMBLY));
245: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->x));
246: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workM));
247: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workN));
248: PetscCall(VecCreate(PETSC_COMM_WORLD, &user->workN2));
249: PetscCall(VecSetSizes(user->x, PETSC_DECIDE, user->N));
250: PetscCall(VecSetSizes(user->workM, PETSC_DECIDE, user->M));
251: PetscCall(VecSetSizes(user->workN, PETSC_DECIDE, user->N));
252: PetscCall(VecSetSizes(user->workN2, PETSC_DECIDE, user->N));
253: PetscCall(VecSetFromOptions(user->x));
254: PetscCall(VecSetFromOptions(user->workM));
255: PetscCall(VecSetFromOptions(user->workN));
256: PetscCall(VecSetFromOptions(user->workN2));
258: PetscCall(VecDuplicate(user->workN, &user->workN3));
259: PetscCall(VecDuplicate(user->x, &user->xlb));
260: PetscCall(VecDuplicate(user->x, &user->xub));
261: PetscCall(VecDuplicate(user->x, &user->c));
262: PetscCall(VecSet(user->xlb, 0.0));
263: PetscCall(VecSet(user->c, 0.0));
264: PetscCall(VecSet(user->xub, PETSC_INFINITY));
266: PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &user->ATA));
267: PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &user->Hx));
268: PetscCall(MatTransposeMatMult(user->A, user->A, MAT_INITIAL_MATRIX, PETSC_DETERMINE, &user->HF));
270: PetscCall(MatAssemblyBegin(user->ATA, MAT_FINAL_ASSEMBLY));
271: PetscCall(MatAssemblyEnd(user->ATA, MAT_FINAL_ASSEMBLY));
272: PetscCall(MatAssemblyBegin(user->Hx, MAT_FINAL_ASSEMBLY));
273: PetscCall(MatAssemblyEnd(user->Hx, MAT_FINAL_ASSEMBLY));
274: PetscCall(MatAssemblyBegin(user->HF, MAT_FINAL_ASSEMBLY));
275: PetscCall(MatAssemblyEnd(user->HF, MAT_FINAL_ASSEMBLY));
277: user->lambda = 1.e-8;
278: user->eps = 1.e-3;
279: user->reg = 2;
280: user->mumin = 5.e-6;
282: PetscOptionsBegin(PETSC_COMM_WORLD, NULL, "Configure separable objection example", "tomographyADMM.c");
283: PetscCall(PetscOptionsInt("-reg", "Regularization scheme for z solver (1,2)", "tomographyADMM.c", user->reg, &user->reg, NULL));
284: PetscCall(PetscOptionsReal("-lambda", "The regularization multiplier. 1 default", "tomographyADMM.c", user->lambda, &user->lambda, NULL));
285: PetscCall(PetscOptionsReal("-eps", "L1 norm epsilon padding", "tomographyADMM.c", user->eps, &user->eps, NULL));
286: PetscCall(PetscOptionsReal("-mumin", "Minimum value for ADMM spectral penalty", "tomographyADMM.c", user->mumin, &user->mumin, NULL));
287: PetscOptionsEnd();
288: PetscFunctionReturn(PETSC_SUCCESS);
289: }
291: /*------------------------------------------------------------*/
293: PetscErrorCode DestroyContext(AppCtx *user)
294: {
295: PetscFunctionBegin;
296: PetscCall(MatDestroy(&user->A));
297: PetscCall(MatDestroy(&user->ATA));
298: PetscCall(MatDestroy(&user->Hx));
299: PetscCall(MatDestroy(&user->Hz));
300: PetscCall(MatDestroy(&user->HF));
301: PetscCall(MatDestroy(&user->D));
302: PetscCall(MatDestroy(&user->DTD));
303: PetscCall(VecDestroy(&user->xGT));
304: PetscCall(VecDestroy(&user->xlb));
305: PetscCall(VecDestroy(&user->xub));
306: PetscCall(VecDestroy(&user->b));
307: PetscCall(VecDestroy(&user->x));
308: PetscCall(VecDestroy(&user->c));
309: PetscCall(VecDestroy(&user->workN3));
310: PetscCall(VecDestroy(&user->workN2));
311: PetscCall(VecDestroy(&user->workN));
312: PetscCall(VecDestroy(&user->workM));
313: PetscFunctionReturn(PETSC_SUCCESS);
314: }
316: /*------------------------------------------------------------*/
318: int main(int argc, char **argv)
319: {
320: Tao tao, misfit, reg;
321: PetscReal v1, v2;
322: AppCtx *user;
323: PetscViewer fd;
324: char resultFile[] = "tomographyResult_x";
326: PetscFunctionBeginUser;
327: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
328: PetscCall(PetscNew(&user));
329: PetscCall(InitializeUserData(user));
331: PetscCall(TaoCreate(PETSC_COMM_WORLD, &tao));
332: PetscCall(TaoSetType(tao, TAOADMM));
333: PetscCall(TaoSetSolution(tao, user->x));
334: /* f(x) + g(x) for parent tao */
335: PetscCall(TaoADMMSetSpectralPenalty(tao, 1.));
336: PetscCall(TaoSetObjectiveAndGradient(tao, NULL, FullObjGrad, (void *)user));
337: PetscCall(MatShift(user->HF, user->lambda));
338: PetscCall(TaoSetHessian(tao, user->HF, user->HF, HessianFull, (void *)user));
340: /* f(x) for misfit tao */
341: PetscCall(TaoADMMSetMisfitObjectiveAndGradientRoutine(tao, MisfitObjectiveAndGradient, (void *)user));
342: PetscCall(TaoADMMSetMisfitHessianRoutine(tao, user->Hx, user->Hx, HessianMisfit, (void *)user));
343: PetscCall(TaoADMMSetMisfitHessianChangeStatus(tao, PETSC_FALSE));
344: PetscCall(TaoADMMSetMisfitConstraintJacobian(tao, user->D, user->D, NullJacobian, (void *)user));
346: /* g(x) for regularizer tao */
347: if (user->reg == 1) {
348: PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient1, (void *)user));
349: PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianReg, (void *)user));
350: PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
351: } else if (user->reg == 2) {
352: PetscCall(TaoADMMSetRegularizerObjectiveAndGradientRoutine(tao, RegularizerObjectiveAndGradient2, (void *)user));
353: PetscCall(MatShift(user->Hz, 1));
354: PetscCall(MatScale(user->Hz, user->lambda));
355: PetscCall(TaoADMMSetRegularizerHessianRoutine(tao, user->Hz, user->Hz, HessianMisfit, (void *)user));
356: PetscCall(TaoADMMSetRegHessianChangeStatus(tao, PETSC_TRUE));
357: } else PetscCheck(user->reg == 3, PETSC_COMM_WORLD, PETSC_ERR_ARG_UNKNOWN_TYPE, "Incorrect Reg type"); /* TaoShell case */
359: /* Set type for the misfit solver */
360: PetscCall(TaoADMMGetMisfitSubsolver(tao, &misfit));
361: PetscCall(TaoADMMGetRegularizationSubsolver(tao, ®));
362: PetscCall(TaoSetType(misfit, TAONLS));
363: if (user->reg == 3) {
364: PetscCall(TaoSetType(reg, TAOSHELL));
365: PetscCall(TaoShellSetContext(reg, (void *)user));
366: PetscCall(TaoShellSetSolve(reg, TaoShellSolve_SoftThreshold));
367: } else {
368: PetscCall(TaoSetType(reg, TAONLS));
369: }
370: PetscCall(TaoSetVariableBounds(misfit, user->xlb, user->xub));
372: /* Soft Thresholding solves the ADMM problem with the L1 regularizer lambda*||z||_1 and the x-z=0 constraint */
373: PetscCall(TaoADMMSetRegularizerCoefficient(tao, user->lambda));
374: PetscCall(TaoADMMSetRegularizerConstraintJacobian(tao, NULL, NULL, NullJacobian, (void *)user));
375: PetscCall(TaoADMMSetMinimumSpectralPenalty(tao, user->mumin));
377: PetscCall(TaoADMMSetConstraintVectorRHS(tao, user->c));
378: PetscCall(TaoSetFromOptions(tao));
379: PetscCall(TaoSolve(tao));
381: /* Save x (reconstruction of object) vector to a binary file, which maybe read from MATLAB and convert to a 2D image for comparison. */
382: PetscCall(PetscViewerBinaryOpen(PETSC_COMM_WORLD, resultFile, FILE_MODE_WRITE, &fd));
383: PetscCall(VecView(user->x, fd));
384: PetscCall(PetscViewerDestroy(&fd));
386: /* compute the error */
387: PetscCall(VecAXPY(user->x, -1, user->xGT));
388: PetscCall(VecNorm(user->x, NORM_2, &v1));
389: PetscCall(VecNorm(user->xGT, NORM_2, &v2));
390: PetscCall(PetscPrintf(PETSC_COMM_WORLD, "relative reconstruction error: ||x-xGT||/||xGT|| = %6.4e.\n", (double)(v1 / v2)));
392: /* Free TAO data structures */
393: PetscCall(TaoDestroy(&tao));
394: PetscCall(DestroyContext(user));
395: PetscCall(PetscFree(user));
396: PetscCall(PetscFinalize());
397: return 0;
398: }
400: /*TEST
402: build:
403: requires: !complex !single !__float128 !defined(PETSC_USE_64BIT_INDICES)
405: testset:
406: requires: datafilespath
407: args: -path ${DATAFILESPATH}/tao/tomography
409: test:
410: suffix: 1
411: args: -lambda 1.e-8 -tao_monitor -tao_type nls -tao_nls_pc_type icc
413: test:
414: suffix: 2
415: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor -reg_tao_monitor
417: test:
418: suffix: 3
419: args: -lambda 1.e-8 -tao_admm_dual_update update_basic -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_nls_pc_type icc -misfit_tao_monitor
421: test:
422: suffix: 4
423: args: -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_soft_thresh -tao_max_it 20 -tao_monitor -misfit_tao_monitor -misfit_tao_nls_pc_type icc
425: test:
426: suffix: 5
427: args: -reg 2 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
429: test:
430: suffix: 6
431: args: -reg 3 -lambda 1.e-8 -tao_admm_dual_update update_adaptive -tao_admm_regularizer_type regularizer_user -tao_max_it 20 -tao_monitor -tao_admm_tolerance_update_factor 1.e-8 -misfit_tao_monitor -reg_tao_monitor -misfit_tao_nls_pc_type icc
433: TEST*/