Actual source code: elastic_net_regularization.c
/* Help text handed to PetscInitialize() and used as the title of this example's options section.
   Written as adjacent literals so the long URL stays readable; the compiler joins them. */
const char help[] = "Demonstration of elastic net regularization "
                    "(https://en.wikipedia.org/wiki/Elastic_net_regularization) "
                    "using TAO";
3: #include <petsctao.h>
5: int main(int argc, char **argv)
6: {
7: /*
8: This example demonstrates the solution of an elastic net regularized least squares problem
10: (1/2) || Ax - b ||_W^2 + lambda_2 (1/2) || x ||_2^2 + lambda_1 || Dx - y ||_1
11: */
13: MPI_Comm comm;
14: Mat A; // data matrix
15: Mat D; // dictionary matrix
16: Mat W; // weight matrix
17: Vec w; // observation vector
18: Vec b; // observation vector
19: Vec y; // dictionary vector
20: Vec x; // solution vector
21: PetscInt m = 100; // data size
22: PetscInt n = 20; // model size
23: PetscInt k = 10; // dictionary size
24: PetscBool set_prefix = PETSC_TRUE;
25: PetscBool set_name = PETSC_FALSE;
26: PetscBool check_eps = PETSC_FALSE;
27: TaoTerm data_term;
28: TaoTerm l2_reg_term;
29: TaoTerm l1_reg_term;
30: TaoTerm full_objective;
31: PetscRandom rand;
32: PetscReal lambda_1 = 0.1;
33: PetscReal lambda_2 = 0.1;
34: Tao tao;
36: PetscFunctionBeginUser;
37: PetscCall(PetscInitialize(&argc, &argv, NULL, help));
38: comm = PETSC_COMM_WORLD;
40: PetscOptionsBegin(comm, "", help, "none");
41: PetscCall(PetscOptionsBoundedInt("-m", "data size", "", m, &m, NULL, 0));
42: PetscCall(PetscOptionsBoundedInt("-n", "model size", "", n, &n, NULL, 0));
43: PetscCall(PetscOptionsBoundedInt("-k", "dictionary size", "", k, &k, NULL, 0));
44: PetscCall(PetscOptionsBool("-set_term_prefix", "Set prefix to terms", NULL, set_prefix, &set_prefix, NULL));
45: PetscCall(PetscOptionsBool("-set_term_name", "Set name to terms", NULL, set_name, &set_name, NULL));
46: PetscCall(PetscOptionsBool("-check_l1_eps", "Check epsilon of L1 term", NULL, check_eps, &check_eps, NULL));
47: PetscOptionsEnd();
49: PetscCall(TaoCreate(comm, &tao));
51: PetscCall(PetscRandomCreate(comm, &rand));
52: PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
53: PetscCall(PetscRandomSetFromOptions(rand));
55: // create the model data, A, W and b
56: PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
57: PetscCall(MatSetRandom(A, rand));
58: PetscCall(MatCreateVecs(A, NULL, &b));
59: PetscCall(VecSetRandom(b, rand));
60: PetscCall(VecDuplicate(b, &w));
61: PetscCall(VecSetRandom(w, rand));
62: PetscCall(VecAbs(w));
63: PetscCall(VecShift(w, 1.0));
64: PetscCall(MatCreateDiagonal(w, &W));
65: PetscCall(VecDestroy(&w));
67: // create the dictionary data, D and y
68: PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, k, n, NULL, &D));
69: PetscCall(MatSetRandom(D, rand));
70: PetscCall(MatCreateVecs(D, NULL, &y));
71: PetscCall(VecSetRandom(y, rand));
73: // the model term, (1/2) || Ax - b ||_W^2
74: PetscCall(TaoTermCreateQuadratic(W, &data_term));
75: if (set_prefix) {
76: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)data_term, "data_"));
77: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)b, "bvec_"));
78: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)A, "Amat_"));
79: }
80: if (set_name) PetscCall(PetscObjectSetName((PetscObject)data_term, "Data TaoTerm"));
81: PetscCall(TaoAddTerm(tao, "data_", 1.0, data_term, b, A));
82: PetscCall(TaoTermDestroy(&data_term));
84: // the L2 term, (1/2) lambda_2 || x ||_2^2
85: PetscCall(TaoTermCreateHalfL2Squared(comm, PETSC_DECIDE, n, &l2_reg_term));
86: if (set_prefix) PetscCall(PetscObjectSetOptionsPrefix((PetscObject)l2_reg_term, "ridge_"));
87: if (set_name) PetscCall(PetscObjectSetName((PetscObject)l2_reg_term, "Ridge TaoTerm"));
88: PetscCall(TaoAddTerm(tao, "ridge_", lambda_2, l2_reg_term, NULL, NULL)); // Note: no parameter vector, no map matrix needed
89: PetscCall(TaoTermDestroy(&l2_reg_term));
91: // the L1 term, lambda_1 || Dx - y ||_1
92: PetscCall(TaoTermCreateL1(comm, PETSC_DECIDE, k, 0.0, &l1_reg_term));
93: if (set_prefix) {
94: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)l1_reg_term, "lasso_"));
95: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)y, "yvec_"));
96: PetscCall(PetscObjectSetOptionsPrefix((PetscObject)D, "Dmat_"));
97: }
98: if (set_name) PetscCall(PetscObjectSetName((PetscObject)l1_reg_term, "Lasso TaoTerm"));
99: PetscCall(TaoAddTerm(tao, "lasso_", lambda_1, l1_reg_term, y, D));
100: PetscCall(TaoTermDestroy(&l1_reg_term));
102: PetscCall(TaoGetTerm(tao, NULL, &full_objective, NULL, NULL));
103: PetscCall(TaoTermCreateSolutionVec(full_objective, &x));
104: PetscCall(VecSetRandom(x, rand));
105: PetscCall(TaoSetSolution(tao, x));
106: PetscCall(TaoSetFromOptions(tao));
107: PetscCall(TaoSolve(tao));
109: {
110: PetscReal scale_get;
111: TaoTerm get_term;
112: Vec get_vec, p2, p1;
113: Mat get_mat;
115: PetscCall(TaoGetTerm(tao, &scale_get, &get_term, &get_vec, &get_mat));
116: PetscCall(VecNestGetTaoTermSumParameters(get_vec, 0, &p1));
117: PetscCall(VecNestGetTaoTermSumParameters(get_vec, 1, &p2));
118: PetscCheck(p1 == b, PETSC_COMM_SELF, PETSC_ERR_COR, "First parameter vector is not same as what was set");
119: PetscCheck(p2 == NULL, PETSC_COMM_SELF, PETSC_ERR_COR, "Second parameter vector is not none");
120: }
122: if (check_eps) {
123: PetscReal scale_get;
124: TaoTerm get_term;
125: Vec get_vec;
126: Mat get_mat;
127: PetscInt n_terms;
128: PetscInt last_index;
129: PetscBool is_l1;
130: TaoTerm last_subterm;
131: PetscReal epsilon;
133: PetscCall(TaoGetTerm(tao, &scale_get, &get_term, &get_vec, &get_mat));
134: PetscCall(TaoTermSumGetNumberTerms(get_term, &n_terms));
135: last_index = n_terms - 1;
136: PetscCall(TaoTermSumGetTerm(get_term, last_index, NULL, NULL, &last_subterm, NULL));
137: PetscCall(PetscObjectTypeCompare((PetscObject)last_subterm, TAOTERML1, &is_l1));
138: PetscCheck(is_l1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Last term is not L1");
139: PetscCall(TaoTermL1GetEpsilon(last_subterm, &epsilon));
140: PetscCheck(PetscAbsReal(epsilon - 0.1) < PETSC_SMALL, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "L1 epsilon is not 0.1, got: %g", (double)epsilon);
141: }
143: PetscCall(VecDestroy(&x));
144: PetscCall(VecDestroy(&y));
145: PetscCall(MatDestroy(&D));
146: PetscCall(VecDestroy(&b));
147: PetscCall(MatDestroy(&W));
148: PetscCall(MatDestroy(&A));
149: PetscCall(PetscRandomDestroy(&rand));
150: PetscCall(TaoDestroy(&tao));
151: PetscCall(PetscFinalize());
152: return 0;
153: }
155: /*TEST
157: build:
158: requires: !complex !single !quad !defined(PETSC_USE_64BIT_INDICES) !__float128
160: test:
161: suffix: 0
162: args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -check_l1_eps 1
164: test:
165: suffix: 1
166: args: -tao_type nls -lasso_tao_term_hessian_mat_type aij -tao_view ::ascii_info_detail
168: test:
169: suffix: sum_hpre_is_not_h
170: args: -tao_type nls -tao_view ::ascii_info_detail -tao_term_hessian_pre_is_hessian 0
172: test:
173: suffix: data_hpre_is_not_h
174: args: -tao_type nls -tao_view ::ascii_info_detail -data_tao_term_hessian_pre_is_hessian 0
176: test:
177: suffix: ridge_hpre_is_not_h
178: args: -tao_type nls -tao_view ::ascii_info_detail -ridge_tao_term_hessian_pre_is_hessian 0
180: test:
181: suffix: lasso_hpre_is_not_h
182: args: -tao_type nls -tao_view ::ascii_info_detail -lasso_tao_term_hessian_pre_is_hessian 0
184: test:
185: suffix: hpre_is_not_h
186: args: -tao_type nls -tao_view ::ascii_info_detail -lasso_tao_term_hessian_pre_is_hessian 0
187: args: -ridge_tao_term_hessian_pre_is_hessian 0 -data_tao_term_hessian_pre_is_hessian 0
189: test:
190: suffix: data_ridge_hpre_is_not_h
191: args: -tao_type nls -tao_view ::ascii_info_detail
192: args: -ridge_tao_term_hessian_pre_is_hessian 0 -data_tao_term_hessian_pre_is_hessian 0
194: test:
195: suffix: no_prefix
196: args: -tao_monitor_short -tao_view -tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 0
198: test:
199: suffix: no_prefix_yes_name
200: args: -tao_monitor_short -tao_view -tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 0 -set_term_name 1
202: test:
203: suffix: yes_prefix_yes_name
204: args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 1 -set_term_name 1
206: test:
207: suffix: mask_failure
208: args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls
209: args: -tao_term_sum_ridge_mask objective -tao_term_sum_lasso_mask gradient
210: args: -tao_view ::ascii_info_detail
212: test:
213: suffix: assembled
214: args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -ridge_tao_term_hessian_mat_type constantdiagonal -lasso_tao_term_hessian_mat_type diagonal
216: test:
217: suffix: snes
218: args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type snes
220: test:
221: suffix: extra_info_view
222: args: -tao_type nls -tao_add_terms extra_ -extra_tao_term_type halfl2squared -tao_term_sum_extra_scale 1.0 -tao_view
224: TEST*/