Actual source code: elastic_net_regularization.c

  1: const char help[] = "Demonstration of elastic net regularization (https://en.wikipedia.org/wiki/Elastic_net_regularization) using TAO";

  3: #include <petsctao.h>

  5: int main(int argc, char **argv)
  6: {
  7:   /*
  8:     This example demonstrates the solution of an elastic net regularized least squares problem

 10:     (1/2) || Ax - b ||_W^2 + lambda_2 (1/2) || x ||_2^2 + lambda_1 || Dx - y ||_1
 11:    */

 13:   MPI_Comm    comm;
 14:   Mat         A;                // data matrix
 15:   Mat         D;                // dictionary matrix
 16:   Mat         W;                // weight matrix
 17:   Vec         w;                // observation vector
 18:   Vec         b;                // observation vector
 19:   Vec         y;                // dictionary vector
 20:   Vec         x;                // solution vector
 21:   PetscInt    m          = 100; // data size
 22:   PetscInt    n          = 20;  // model size
 23:   PetscInt    k          = 10;  // dictionary size
 24:   PetscBool   set_prefix = PETSC_TRUE;
 25:   PetscBool   set_name   = PETSC_FALSE;
 26:   PetscBool   check_eps  = PETSC_FALSE;
 27:   TaoTerm     data_term;
 28:   TaoTerm     l2_reg_term;
 29:   TaoTerm     l1_reg_term;
 30:   TaoTerm     full_objective;
 31:   PetscRandom rand;
 32:   PetscReal   lambda_1 = 0.1;
 33:   PetscReal   lambda_2 = 0.1;
 34:   Tao         tao;

 36:   PetscFunctionBeginUser;
 37:   PetscCall(PetscInitialize(&argc, &argv, NULL, help));
 38:   comm = PETSC_COMM_WORLD;

 40:   PetscOptionsBegin(comm, "", help, "none");
 41:   PetscCall(PetscOptionsBoundedInt("-m", "data size", "", m, &m, NULL, 0));
 42:   PetscCall(PetscOptionsBoundedInt("-n", "model size", "", n, &n, NULL, 0));
 43:   PetscCall(PetscOptionsBoundedInt("-k", "dictionary size", "", k, &k, NULL, 0));
 44:   PetscCall(PetscOptionsBool("-set_term_prefix", "Set prefix to terms", NULL, set_prefix, &set_prefix, NULL));
 45:   PetscCall(PetscOptionsBool("-set_term_name", "Set name to terms", NULL, set_name, &set_name, NULL));
 46:   PetscCall(PetscOptionsBool("-check_l1_eps", "Check epsilon of L1 term", NULL, check_eps, &check_eps, NULL));
 47:   PetscOptionsEnd();

 49:   PetscCall(TaoCreate(comm, &tao));

 51:   PetscCall(PetscRandomCreate(comm, &rand));
 52:   PetscCall(PetscRandomSetInterval(rand, -1.0, 1.0));
 53:   PetscCall(PetscRandomSetFromOptions(rand));

 55:   // create the model data, A, W and b
 56:   PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, m, n, NULL, &A));
 57:   PetscCall(MatSetRandom(A, rand));
 58:   PetscCall(MatCreateVecs(A, NULL, &b));
 59:   PetscCall(VecSetRandom(b, rand));
 60:   PetscCall(VecDuplicate(b, &w));
 61:   PetscCall(VecSetRandom(w, rand));
 62:   PetscCall(VecAbs(w));
 63:   PetscCall(VecShift(w, 1.0));
 64:   PetscCall(MatCreateDiagonal(w, &W));
 65:   PetscCall(VecDestroy(&w));

 67:   // create the dictionary data, D and y
 68:   PetscCall(MatCreateDense(comm, PETSC_DECIDE, PETSC_DECIDE, k, n, NULL, &D));
 69:   PetscCall(MatSetRandom(D, rand));
 70:   PetscCall(MatCreateVecs(D, NULL, &y));
 71:   PetscCall(VecSetRandom(y, rand));

 73:   // the model term,  (1/2) || Ax - b ||_W^2
 74:   PetscCall(TaoTermCreateQuadratic(W, &data_term));
 75:   if (set_prefix) {
 76:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)data_term, "data_"));
 77:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)b, "bvec_"));
 78:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)A, "Amat_"));
 79:   }
 80:   if (set_name) PetscCall(PetscObjectSetName((PetscObject)data_term, "Data TaoTerm"));
 81:   PetscCall(TaoAddTerm(tao, "data_", 1.0, data_term, b, A));
 82:   PetscCall(TaoTermDestroy(&data_term));

 84:   // the L2 term,  (1/2) lambda_2 || x ||_2^2
 85:   PetscCall(TaoTermCreateHalfL2Squared(comm, PETSC_DECIDE, n, &l2_reg_term));
 86:   if (set_prefix) PetscCall(PetscObjectSetOptionsPrefix((PetscObject)l2_reg_term, "ridge_"));
 87:   if (set_name) PetscCall(PetscObjectSetName((PetscObject)l2_reg_term, "Ridge TaoTerm"));
 88:   PetscCall(TaoAddTerm(tao, "ridge_", lambda_2, l2_reg_term, NULL, NULL)); // Note: no parameter vector, no map matrix needed
 89:   PetscCall(TaoTermDestroy(&l2_reg_term));

 91:   // the L1 term,  lambda_1 || Dx - y ||_1
 92:   PetscCall(TaoTermCreateL1(comm, PETSC_DECIDE, k, 0.0, &l1_reg_term));
 93:   if (set_prefix) {
 94:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)l1_reg_term, "lasso_"));
 95:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)y, "yvec_"));
 96:     PetscCall(PetscObjectSetOptionsPrefix((PetscObject)D, "Dmat_"));
 97:   }
 98:   if (set_name) PetscCall(PetscObjectSetName((PetscObject)l1_reg_term, "Lasso TaoTerm"));
 99:   PetscCall(TaoAddTerm(tao, "lasso_", lambda_1, l1_reg_term, y, D));
100:   PetscCall(TaoTermDestroy(&l1_reg_term));

102:   PetscCall(TaoGetTerm(tao, NULL, &full_objective, NULL, NULL));
103:   PetscCall(TaoTermCreateSolutionVec(full_objective, &x));
104:   PetscCall(VecSetRandom(x, rand));
105:   PetscCall(TaoSetSolution(tao, x));
106:   PetscCall(TaoSetFromOptions(tao));
107:   PetscCall(TaoSolve(tao));

109:   {
110:     PetscReal scale_get;
111:     TaoTerm   get_term;
112:     Vec       get_vec, p2, p1;
113:     Mat       get_mat;

115:     PetscCall(TaoGetTerm(tao, &scale_get, &get_term, &get_vec, &get_mat));
116:     PetscCall(VecNestGetTaoTermSumParameters(get_vec, 0, &p1));
117:     PetscCall(VecNestGetTaoTermSumParameters(get_vec, 1, &p2));
118:     PetscCheck(p1 == b, PETSC_COMM_SELF, PETSC_ERR_COR, "First parameter vector is not same as what was set");
119:     PetscCheck(p2 == NULL, PETSC_COMM_SELF, PETSC_ERR_COR, "Second parameter vector is not none");
120:   }

122:   if (check_eps) {
123:     PetscReal scale_get;
124:     TaoTerm   get_term;
125:     Vec       get_vec;
126:     Mat       get_mat;
127:     PetscInt  n_terms;
128:     PetscInt  last_index;
129:     PetscBool is_l1;
130:     TaoTerm   last_subterm;
131:     PetscReal epsilon;

133:     PetscCall(TaoGetTerm(tao, &scale_get, &get_term, &get_vec, &get_mat));
134:     PetscCall(TaoTermSumGetNumberTerms(get_term, &n_terms));
135:     last_index = n_terms - 1;
136:     PetscCall(TaoTermSumGetTerm(get_term, last_index, NULL, NULL, &last_subterm, NULL));
137:     PetscCall(PetscObjectTypeCompare((PetscObject)last_subterm, TAOTERML1, &is_l1));
138:     PetscCheck(is_l1, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "Last term is not L1");
139:     PetscCall(TaoTermL1GetEpsilon(last_subterm, &epsilon));
140:     PetscCheck(PetscAbsReal(epsilon - 0.1) < PETSC_SMALL, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "L1 epsilon is not 0.1, got: %g", (double)epsilon);
141:   }

143:   PetscCall(VecDestroy(&x));
144:   PetscCall(VecDestroy(&y));
145:   PetscCall(MatDestroy(&D));
146:   PetscCall(VecDestroy(&b));
147:   PetscCall(MatDestroy(&W));
148:   PetscCall(MatDestroy(&A));
149:   PetscCall(PetscRandomDestroy(&rand));
150:   PetscCall(TaoDestroy(&tao));
151:   PetscCall(PetscFinalize());
152:   return 0;
153: }

155: /*TEST

157:   build:
158:     requires: !complex !single !quad !defined(PETSC_USE_64BIT_INDICES) !__float128

160:   test:
161:     suffix: 0
162:     args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -check_l1_eps 1

164:   test:
165:     suffix: 1
166:     args: -tao_type nls -lasso_tao_term_hessian_mat_type aij -tao_view ::ascii_info_detail

168:   test:
169:     suffix: sum_hpre_is_not_h
170:     args: -tao_type nls -tao_view ::ascii_info_detail -tao_term_hessian_pre_is_hessian 0

172:   test:
173:     suffix: data_hpre_is_not_h
174:     args: -tao_type nls -tao_view ::ascii_info_detail -data_tao_term_hessian_pre_is_hessian 0

176:   test:
177:     suffix: ridge_hpre_is_not_h
178:     args: -tao_type nls -tao_view ::ascii_info_detail -ridge_tao_term_hessian_pre_is_hessian 0

180:   test:
181:     suffix: lasso_hpre_is_not_h
182:     args: -tao_type nls -tao_view ::ascii_info_detail -lasso_tao_term_hessian_pre_is_hessian 0

184:   test:
185:     suffix: hpre_is_not_h
186:     args: -tao_type nls -tao_view ::ascii_info_detail -lasso_tao_term_hessian_pre_is_hessian 0
187:     args: -ridge_tao_term_hessian_pre_is_hessian 0 -data_tao_term_hessian_pre_is_hessian 0

189:   test:
190:     suffix: data_ridge_hpre_is_not_h
191:     args: -tao_type nls -tao_view ::ascii_info_detail
192:     args: -ridge_tao_term_hessian_pre_is_hessian 0 -data_tao_term_hessian_pre_is_hessian 0

194:   test:
195:     suffix: no_prefix
196:     args: -tao_monitor_short -tao_view -tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 0

198:   test:
199:     suffix: no_prefix_yes_name
200:     args: -tao_monitor_short -tao_view -tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 0 -set_term_name 1

202:   test:
203:     suffix: yes_prefix_yes_name
204:     args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -set_term_prefix 1 -set_term_name 1

206:   test:
207:     suffix: mask_failure
208:     args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls
209:     args: -tao_term_sum_ridge_mask objective -tao_term_sum_lasso_mask gradient
210:     args: -tao_view ::ascii_info_detail

212:   test:
213:     suffix: assembled
214:     args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type nls -ridge_tao_term_hessian_mat_type constantdiagonal -lasso_tao_term_hessian_mat_type diagonal

216:   test:
217:     suffix: snes
218:     args: -tao_monitor_short -tao_view -lasso_tao_term_l1_epsilon 0.1 -tao_type snes

220:   test:
221:     suffix: extra_info_view
222:     args: -tao_type nls -tao_add_terms extra_ -extra_tao_term_type halfl2squared -tao_term_sum_extra_scale 1.0 -tao_view

224: TEST*/