Actual source code: linear.c

  1: #include <../src/ml/regressor/impls/linear/linearimpl.h>

  3: const char *const PetscRegressorLinearTypes[] = {"ols", "lasso", "ridge", "RegressorLinearType", "REGRESSOR_LINEAR_", NULL};

  5: static PetscErrorCode PetscRegressorLinearSetFitIntercept_Linear(PetscRegressor regressor, PetscBool flg)
  6: {
  7:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

  9:   PetscFunctionBegin;
 10:   linear->fit_intercept = flg;
 11:   PetscFunctionReturn(PETSC_SUCCESS);
 12: }

 14: static PetscErrorCode PetscRegressorLinearSetType_Linear(PetscRegressor regressor, PetscRegressorLinearType type)
 15: {
 16:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 18:   PetscFunctionBegin;
 19:   linear->type = type;
 20:   PetscFunctionReturn(PETSC_SUCCESS);
 21: }

 23: static PetscErrorCode PetscRegressorLinearGetType_Linear(PetscRegressor regressor, PetscRegressorLinearType *type)
 24: {
 25:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 27:   PetscFunctionBegin;
 28:   *type = linear->type;
 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: static PetscErrorCode PetscRegressorLinearGetIntercept_Linear(PetscRegressor regressor, PetscScalar *intercept)
 33: {
 34:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 36:   PetscFunctionBegin;
 37:   *intercept = linear->intercept;
 38:   PetscFunctionReturn(PETSC_SUCCESS);
 39: }

 41: static PetscErrorCode PetscRegressorLinearGetCoefficients_Linear(PetscRegressor regressor, Vec *coefficients)
 42: {
 43:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 45:   PetscFunctionBegin;
 46:   *coefficients = linear->coefficients;
 47:   PetscFunctionReturn(PETSC_SUCCESS);
 48: }

 50: static PetscErrorCode PetscRegressorLinearGetKSP_Linear(PetscRegressor regressor, KSP *ksp)
 51: {
 52:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 54:   PetscFunctionBegin;
 55:   if (!linear->ksp) {
 56:     PetscCall(KSPCreate(PetscObjectComm((PetscObject)regressor), &linear->ksp));
 57:     PetscCall(PetscObjectIncrementTabLevel((PetscObject)linear->ksp, (PetscObject)regressor, 1));
 58:     PetscCall(PetscObjectSetOptions((PetscObject)linear->ksp, ((PetscObject)regressor)->options));
 59:   }
 60:   *ksp = linear->ksp;
 61:   PetscFunctionReturn(PETSC_SUCCESS);
 62: }

 64: static PetscErrorCode PetscRegressorLinearSetUseKSP_Linear(PetscRegressor regressor, PetscBool flg)
 65: {
 66:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

 68:   PetscFunctionBegin;
 69:   linear->use_ksp = flg;
 70:   PetscFunctionReturn(PETSC_SUCCESS);
 71: }

 73: static PetscErrorCode EvaluateResidual(Tao tao, Vec x, Vec f, void *ptr)
 74: {
 75:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)ptr;

 77:   PetscFunctionBegin;
 78:   /* Evaluate f = A * x - b */
 79:   PetscCall(MatMult(linear->X, x, f));
 80:   PetscCall(VecAXPY(f, -1.0, linear->rhs));
 81:   PetscFunctionReturn(PETSC_SUCCESS);
 82: }

 84: static PetscErrorCode EvaluateJacobian(Tao tao, Vec x, Mat J, Mat Jpre, void *ptr)
 85: {
 86:   /* The TAOBRGN API expects us to pass an EvaluateJacobian() routine to it, but in this case it is a dummy function.
 87:      Denoting our data matrix as X, for linear least squares J[m][n] = df[m]/dx[n] = X[m][n]. */
 88:   PetscFunctionBegin;
 89:   PetscFunctionReturn(PETSC_SUCCESS);
 90: }

 92: static PetscErrorCode PetscRegressorSetUp_Linear(PetscRegressor regressor)
 93: {
 94:   PetscInt               M, N;
 95:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
 96:   KSP                    ksp;
 97:   Tao                    tao;

 99:   PetscFunctionBegin;
100:   PetscCall(MatGetSize(regressor->training, &M, &N));

102:   if (linear->fit_intercept) {
103:     /* If we are fitting the intercept, we need to make A a composite matrix using MATCENTERING to preserve sparsity.
104:      * Though there might be some cases we don't want to do this for, depending on what kind of matrix is passed in. (Probably bad idea for dense?)
105:      * We will also need to ensure that the right-hand side passed to the KSP is also mean-centered, since we
106:      * intend to compute the intercept separately from regression coefficients (that is, we will not be adding a
107:      * column of all 1s to our design matrix). */
108:     PetscCall(MatCreateCentering(PetscObjectComm((PetscObject)regressor), PETSC_DECIDE, M, &linear->C));
109:     PetscCall(MatCreate(PetscObjectComm((PetscObject)regressor), &linear->X));
110:     PetscCall(MatSetSizes(linear->X, PETSC_DECIDE, PETSC_DECIDE, M, N));
111:     PetscCall(MatSetType(linear->X, MATCOMPOSITE));
112:     PetscCall(MatCompositeSetType(linear->X, MAT_COMPOSITE_MULTIPLICATIVE));
113:     PetscCall(MatCompositeAddMat(linear->X, regressor->training));
114:     PetscCall(MatCompositeAddMat(linear->X, linear->C));
115:     PetscCall(VecDuplicate(regressor->target, &linear->rhs));
116:     PetscCall(MatMult(linear->C, regressor->target, linear->rhs));
117:   } else {
118:     // When not fitting intercept, we assume that the input data are already centered.
119:     linear->X   = regressor->training;
120:     linear->rhs = regressor->target;

122:     PetscCall(PetscObjectReference((PetscObject)linear->X));
123:     PetscCall(PetscObjectReference((PetscObject)linear->rhs));
124:   }

126:   if (linear->coefficients) PetscCall(VecDestroy(&linear->coefficients));

128:   if (linear->use_ksp) {
129:     PetscCheck(linear->type == REGRESSOR_LINEAR_OLS, PetscObjectComm((PetscObject)regressor), PETSC_ERR_ARG_WRONGSTATE, "KSP can be used to fit a linear regressor only when its type is OLS");

131:     if (!linear->ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
132:     ksp = linear->ksp;

134:     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, NULL));
135:     /* Set up the KSP to solve the least squares problem (without solving for intercept, as this is done separately) using KSPLSQR. */
136:     PetscCall(MatCreateNormal(linear->X, &linear->XtX));
137:     PetscCall(KSPSetType(ksp, KSPLSQR));
138:     PetscCall(KSPSetOperators(ksp, linear->X, linear->XtX));
139:     PetscCall(KSPSetOptionsPrefix(ksp, ((PetscObject)regressor)->prefix));
140:     PetscCall(KSPAppendOptionsPrefix(ksp, "regressor_linear_"));
141:     PetscCall(KSPSetFromOptions(ksp));
142:   } else {
143:     /* Note: Currently implementation creates TAO inside of implementations.
144:       * Thus, all the prefix jobs are done inside implementations, not in interface */
145:     const char *prefix;

147:     if (!regressor->tao) PetscCall(PetscRegressorGetTao(regressor, &tao));

149:     PetscCall(MatCreateVecs(linear->X, &linear->coefficients, &linear->residual));
150:     /* Set up the TAO object to solve the (regularized) least squares problem (without solving for intercept, which is done separately) using TAOBRGN. */
151:     PetscCall(TaoSetType(tao, TAOBRGN));
152:     PetscCall(TaoSetSolution(tao, linear->coefficients));
153:     PetscCall(TaoSetResidualRoutine(tao, linear->residual, EvaluateResidual, linear));
154:     PetscCall(TaoSetJacobianResidualRoutine(tao, linear->X, linear->X, EvaluateJacobian, linear));
155:     // Set the regularization type and weight for the BRGN as linear->type dictates:
156:     // TODO BRGN needs to be BRGNSetRegularizationType
157:     // PetscOptionsSetValue no longer works due to functioning prefix system
158:     PetscCall(PetscRegressorGetOptionsPrefix(regressor, &prefix));
159:     PetscCall(TaoSetOptionsPrefix(regressor->tao, prefix));
160:     PetscCall(TaoAppendOptionsPrefix(tao, "regressor_linear_"));
161:     switch (linear->type) {
162:     case REGRESSOR_LINEAR_OLS:
163:       regressor->regularizer_weight = 0.0; // OLS, by definition, uses a regularizer weight of 0
164:       break;
165:     case REGRESSOR_LINEAR_LASSO:
166:       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L1DICT));
167:       break;
168:     case REGRESSOR_LINEAR_RIDGE:
169:       PetscCall(TaoBRGNSetRegularizationType(regressor->tao, TAOBRGN_REGULARIZATION_L2PURE));
170:       break;
171:     default:
172:       break;
173:     }
174:     if (!linear->use_ksp) PetscCall(TaoBRGNSetRegularizerWeight(tao, regressor->regularizer_weight));
175:     PetscCall(TaoSetFromOptions(tao));
176:   }
177:   PetscFunctionReturn(PETSC_SUCCESS);
178: }

180: static PetscErrorCode PetscRegressorReset_Linear(PetscRegressor regressor)
181: {
182:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

184:   PetscFunctionBegin;
185:   /* Destroy the PETSc objects associated with the linear regressor implementation. */
186:   linear->ksp_its     = 0;
187:   linear->ksp_tot_its = 0;

189:   PetscCall(MatDestroy(&linear->X));
190:   PetscCall(MatDestroy(&linear->XtX));
191:   PetscCall(MatDestroy(&linear->C));
192:   PetscCall(KSPDestroy(&linear->ksp));
193:   PetscCall(VecDestroy(&linear->coefficients));
194:   PetscCall(VecDestroy(&linear->rhs));
195:   PetscCall(VecDestroy(&linear->residual));
196:   PetscFunctionReturn(PETSC_SUCCESS);
197: }

199: static PetscErrorCode PetscRegressorDestroy_Linear(PetscRegressor regressor)
200: {
201:   PetscFunctionBegin;
202:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", NULL));
203:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", NULL));
204:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", NULL));
205:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", NULL));
206:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", NULL));
207:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", NULL));
208:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", NULL));
209:   PetscCall(PetscRegressorReset_Linear(regressor));
210:   PetscCall(PetscFree(regressor->data));
211:   PetscFunctionReturn(PETSC_SUCCESS);
212: }

214: /*@
215:   PetscRegressorLinearSetFitIntercept - Set a flag to indicate that the intercept (also known as the "bias" or "offset") should
216:   be calculated; data are assumed to be mean-centered if false.

218:   Logically Collective

220:   Input Parameters:
221: + regressor - the `PetscRegressor` context
222: - flg       - `PETSC_TRUE` to calculate the intercept, `PETSC_FALSE` to assume mean-centered data (default is `PETSC_TRUE`)

224:   Level: intermediate

226:   Options Database Key:
227: . regressor_linear_fit_intercept <true,false> - fit the intercept

229:   Note:
230:   If the user indicates that the intercept should not be calculated, the intercept will be set to zero.

232: .seealso: `PetscRegressor`, `PetscRegressorFit()`
233: @*/
234: PetscErrorCode PetscRegressorLinearSetFitIntercept(PetscRegressor regressor, PetscBool flg)
235: {
236:   PetscFunctionBegin;
237:   /* TODO: Add companion PetscRegressorLinearGetFitIntercept(), and put it in the .seealso: */
240:   PetscTryMethod(regressor, "PetscRegressorLinearSetFitIntercept_C", (PetscRegressor, PetscBool), (regressor, flg));
241:   PetscFunctionReturn(PETSC_SUCCESS);
242: }

244: /*@
245:   PetscRegressorLinearSetUseKSP - Set a flag to indicate that a `KSP` object, instead of a `Tao` one, should be used
246:   to fit the linear regressor

248:   Logically Collective

250:   Input Parameters:
251: + regressor - the `PetscRegressor` context
252: - flg       - `PETSC_TRUE` to use a `KSP`, `PETSC_FALSE` to use a `Tao` object (default is false)

254:   Options Database Key:
255: . regressor_linear_use_ksp <true,false> - use `KSP`

257:   Level: intermediate

259:   Notes:
260:   `KSPLSQR` with no preconditioner is used to solve the normal equations by default.

262:   For sequential `MATSEQAIJ` sparse matrices QR factorization a `PCType` of `PCQR` can be used to solve the least-squares system with a `MatSolverType` of
263:   `MATSOLVERSPQR`, using, for example,
264: .vb
265:   -ksp_type none -pc_type qr -pc_factor_mat_solver_type sp
266: .ve
267:   if centering, `PetscRegressorLinearSetFitIntercept()`, is not used.

269:   Developer Notes:
270:   It should be possible to use Cholesky (and any other preconditioners) to solve the normal equations.

272:   It should be possible to use QR if centering is used. See ml/regressor/ex1.c and ex2.c

274:   It should be possible to use dense SVD `PCSVD` and dense qr directly on the rectangular matrix to solve the least squares problem.

276:   Adding the above support seems to require a refactorization of how least squares problems are solved with PETSc in `KSPLSQR`

278: .seealso: `PetscRegressor`, `PetscRegressorLinearGetKSP()`, `KSPLSQR`, `PCQR`, `MATSOLVERSPQR`, `MatSolverType`, `MATSEQDENSE`, `PCSVD`
279: @*/
280: PetscErrorCode PetscRegressorLinearSetUseKSP(PetscRegressor regressor, PetscBool flg)
281: {
282:   PetscFunctionBegin;
283:   /* TODO: Add companion PetscRegressorLinearGetUseKSP(), and put it in the .seealso: */
286:   PetscTryMethod(regressor, "PetscRegressorLinearSetUseKSP_C", (PetscRegressor, PetscBool), (regressor, flg));
287:   PetscFunctionReturn(PETSC_SUCCESS);
288: }

290: static PetscErrorCode PetscRegressorSetFromOptions_Linear(PetscRegressor regressor, PetscOptionItems PetscOptionsObject)
291: {
292:   PetscBool              set, flg = PETSC_FALSE;
293:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

295:   PetscFunctionBegin;
296:   PetscOptionsHeadBegin(PetscOptionsObject, "PetscRegressor options for linear regressors");
297:   PetscCall(PetscOptionsBool("-regressor_linear_fit_intercept", "Calculate intercept for linear model", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
298:   if (set) PetscCall(PetscRegressorLinearSetFitIntercept(regressor, flg));
299:   PetscCall(PetscOptionsBool("-regressor_linear_use_ksp", "Use KSP instead of TAO for linear model fitting problem", "PetscRegressorLinearSetFitIntercept", flg, &flg, &set));
300:   if (set) PetscCall(PetscRegressorLinearSetUseKSP(regressor, flg));
301:   PetscCall(PetscOptionsEnum("-regressor_linear_type", "Linear regression method", "PetscRegressorLinearTypes", PetscRegressorLinearTypes, (PetscEnum)linear->type, (PetscEnum *)&linear->type, NULL));
302:   PetscOptionsHeadEnd();
303:   PetscFunctionReturn(PETSC_SUCCESS);
304: }

306: static PetscErrorCode PetscRegressorView_Linear(PetscRegressor regressor, PetscViewer viewer)
307: {
308:   PetscBool              isascii;
309:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

311:   PetscFunctionBegin;
312:   PetscCall(PetscObjectTypeCompare((PetscObject)viewer, PETSCVIEWERASCII, &isascii));
313:   if (isascii) {
314:     PetscCall(PetscViewerASCIIPushTab(viewer));
315:     PetscCall(PetscViewerASCIIPrintf(viewer, "PetscRegressor Linear Type: %s\n", PetscRegressorLinearTypes[linear->type]));
316:     if (linear->ksp) {
317:       PetscCall(KSPView(linear->ksp, viewer));
318:       PetscCall(PetscViewerASCIIPrintf(viewer, "total KSP iterations: %" PetscInt_FMT "\n", linear->ksp_tot_its));
319:     }
320:     if (linear->fit_intercept) PetscCall(PetscViewerASCIIPrintf(viewer, "Intercept=%g\n", (double)linear->intercept));
321:     PetscCall(PetscViewerASCIIPopTab(viewer));
322:   }
323:   PetscFunctionReturn(PETSC_SUCCESS);
324: }

326: /*@
327:   PetscRegressorLinearGetKSP - Returns the `KSP` context for a `PETSCREGRESSORLINEAR` object.

329:   Not Collective, but if the `PetscRegressor` is parallel, then the `KSP` object is parallel

331:   Input Parameter:
332: . regressor - the `PetscRegressor` context

334:   Output Parameter:
335: . ksp - the `KSP` context

337:   Level: beginner

339:   Note:
340:   This routine will always return a `KSP`, but, depending on the type of the linear regressor and the options that are set, the regressor may actually use a `Tao` object instead of this `KSP`.

342: .seealso: `PetscRegressorGetTao()`
343: @*/
344: PetscErrorCode PetscRegressorLinearGetKSP(PetscRegressor regressor, KSP *ksp)
345: {
346:   PetscFunctionBegin;
348:   PetscAssertPointer(ksp, 2);
349:   PetscUseMethod(regressor, "PetscRegressorLinearGetKSP_C", (PetscRegressor, KSP *), (regressor, ksp));
350:   PetscFunctionReturn(PETSC_SUCCESS);
351: }

353: /*@
354:   PetscRegressorLinearGetCoefficients - Get a vector of the fitted coefficients from a linear regression model

356:   Not Collective but the vector is parallel

358:   Input Parameter:
359: . regressor - the `PetscRegressor` context

361:   Output Parameter:
362: . coefficients - the vector of the coefficients

364:   Level: beginner

366: .seealso: `PetscRegressor`, `PetscRegressorLinearGetIntercept()`, `PETSCREGRESSORLINEAR`, `Vec`
367: @*/
368: PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetCoefficients(PetscRegressor regressor, Vec *coefficients)
369: {
370:   PetscFunctionBegin;
372:   PetscAssertPointer(coefficients, 2);
373:   PetscUseMethod(regressor, "PetscRegressorLinearGetCoefficients_C", (PetscRegressor, Vec *), (regressor, coefficients));
374:   PetscFunctionReturn(PETSC_SUCCESS);
375: }

377: /*@
378:   PetscRegressorLinearGetIntercept - Get the intercept from a linear regression model

380:   Not Collective

382:   Input Parameter:
383: . regressor - the `PetscRegressor` context

385:   Output Parameter:
386: . intercept - the intercept

388:   Level: beginner

390: .seealso: `PetscRegressor`, `PetscRegressorLinearSetFitIntercept()`, `PetscRegressorLinearGetCoefficients()`, `PETSCREGRESSORLINEAR`
391: @*/
392: PETSC_EXTERN PetscErrorCode PetscRegressorLinearGetIntercept(PetscRegressor regressor, PetscScalar *intercept)
393: {
394:   PetscFunctionBegin;
396:   PetscAssertPointer(intercept, 2);
397:   PetscUseMethod(regressor, "PetscRegressorLinearGetIntercept_C", (PetscRegressor, PetscScalar *), (regressor, intercept));
398:   PetscFunctionReturn(PETSC_SUCCESS);
399: }

401: /*@C
402:   PetscRegressorLinearSetType - Sets the type of linear regression to be performed

404:   Logically Collective

406:   Input Parameters:
407: + regressor - the `PetscRegressor` context (should be of type `PETSCREGRESSORLINEAR`)
408: - type      - a known linear regression method

410:   Options Database Key:
411: . -regressor_linear_type - Sets the linear regression method; use -help for a list of available methods
412:    (for instance "-regressor_linear_type ols" or "-regressor_linear_type lasso")

414:   Level: intermediate

416: .seealso: `PetscRegressorLinearGetType()`, `PetscRegressorLinearType`, `PetscRegressorSetType()`, `REGRESSOR_LINEAR_OLS`,
417:           `REGRESSOR_LINEAR_LASSO`, `REGRESSOR_LINEAR_RIDGE`
418: @*/
419: PetscErrorCode PetscRegressorLinearSetType(PetscRegressor regressor, PetscRegressorLinearType type)
420: {
421:   PetscFunctionBegin;
424:   PetscTryMethod(regressor, "PetscRegressorLinearSetType_C", (PetscRegressor, PetscRegressorLinearType), (regressor, type));
425:   PetscFunctionReturn(PETSC_SUCCESS);
426: }

428: /*@
429:   PetscRegressorLinearGetType - Return the type for the `PETSCREGRESSORLINEAR` solver

431:   Input Parameter:
432: . regressor - the `PetscRegressor` solver context

434:   Output Parameter:
435: . type - `PETSCREGRESSORLINEAR` type

437:   Level: advanced

439: .seealso: `PetscRegressor`, `PETSCREGRESSORLINEAR`, `PetscRegressorLinearSetType()`, `PetscRegressorLinearType`
440: @*/
441: PetscErrorCode PetscRegressorLinearGetType(PetscRegressor regressor, PetscRegressorLinearType *type)
442: {
443:   PetscFunctionBegin;
445:   PetscAssertPointer(type, 2);
446:   PetscUseMethod(regressor, "PetscRegressorLinearGetType_C", (PetscRegressor, PetscRegressorLinearType *), (regressor, type));
447:   PetscFunctionReturn(PETSC_SUCCESS);
448: }

450: static PetscErrorCode PetscRegressorFit_Linear(PetscRegressor regressor)
451: {
452:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;
453:   KSP                    ksp;
454:   PetscScalar            target_mean, *column_means_global, *column_means_local, column_means_dot_coefficients;
455:   Vec                    column_means;
456:   PetscInt               m, N, istart, i, kspits;

458:   PetscFunctionBegin;
459:   if (linear->use_ksp) PetscCall(PetscRegressorLinearGetKSP(regressor, &linear->ksp));
460:   ksp = linear->ksp;

462:   /* Solve the least-squares problem (previously set up in PetscRegressorSetUp_Linear()) without finding the intercept. */
463:   if (linear->use_ksp) {
464:     PetscCall(KSPSolve(ksp, linear->rhs, linear->coefficients));
465:     PetscCall(KSPGetIterationNumber(ksp, &kspits));
466:     linear->ksp_its += kspits;
467:     linear->ksp_tot_its += kspits;
468:   } else {
469:     PetscCall(TaoSolve(regressor->tao));
470:   }

472:   /* Calculate the intercept. */
473:   if (linear->fit_intercept) {
474:     PetscCall(MatGetSize(regressor->training, NULL, &N));
475:     PetscCall(PetscMalloc1(N, &column_means_global));
476:     PetscCall(VecMean(regressor->target, &target_mean));
477:     /* We need the means of all columns of regressor->training, placed into a Vec compatible with linear->coefficients.
478:      * Note the potential scalability issue: MatGetColumnMeans() computes means of ALL colummns. */
479:     PetscCall(MatGetColumnMeans(regressor->training, column_means_global));
480:     /* TODO: Calculation of the Vec and matrix column means should probably go into the SetUp phase, and also be placed
481:      *       into a routine that is callable from outside of PetscRegressorFit_Linear(), because we'll want to do the same
482:      *       thing for other models, such as ridge and LASSO regression, and should avoid code duplication.
483:      *       What we are calling 'target_mean' and 'column_means' should be stashed in the base linear regressor struct,
484:      *       and perhaps renamed to make it clear they are offsets that should be applied (though the current naming
485:      *       makes sense since it makes it clear where these come from.) */
486:     PetscCall(VecDuplicate(linear->coefficients, &column_means));
487:     PetscCall(VecGetLocalSize(column_means, &m));
488:     PetscCall(VecGetOwnershipRange(column_means, &istart, NULL));
489:     PetscCall(VecGetArrayWrite(column_means, &column_means_local));
490:     for (i = 0; i < m; i++) column_means_local[i] = column_means_global[istart + i];
491:     PetscCall(VecRestoreArrayWrite(column_means, &column_means_local));
492:     PetscCall(VecDot(column_means, linear->coefficients, &column_means_dot_coefficients));
493:     PetscCall(VecDestroy(&column_means));
494:     PetscCall(PetscFree(column_means_global));
495:     linear->intercept = target_mean - column_means_dot_coefficients;
496:   } else {
497:     linear->intercept = 0.0;
498:   }
499:   PetscFunctionReturn(PETSC_SUCCESS);
500: }

502: static PetscErrorCode PetscRegressorPredict_Linear(PetscRegressor regressor, Mat X, Vec y)
503: {
504:   PetscRegressor_Linear *linear = (PetscRegressor_Linear *)regressor->data;

506:   PetscFunctionBegin;
507:   PetscCall(MatMult(X, linear->coefficients, y));
508:   PetscCall(VecShift(y, linear->intercept));
509:   PetscFunctionReturn(PETSC_SUCCESS);
510: }

512: /*MC
513:      PETSCREGRESSORLINEAR - Linear regression model (ordinary least squares or regularized variants)

515:    Options Database:
516: +  -regressor_linear_fit_intercept - Calculate the intercept for the linear model
517: -  -regressor_linear_use_ksp       - Use `KSP` instead of `Tao` for linear model fitting (non-regularized variants only)

519:    Level: beginner

521:    Note:
522:    This is the default regressor in `PetscRegressor`.

524: .seealso: `PetscRegressorCreate()`, `PetscRegressor`, `PetscRegressorSetType()`
525: M*/
526: PETSC_EXTERN PetscErrorCode PetscRegressorCreate_Linear(PetscRegressor regressor)
527: {
528:   PetscRegressor_Linear *linear;

530:   PetscFunctionBegin;
531:   PetscCall(PetscNew(&linear));
532:   regressor->data = (void *)linear;

534:   regressor->ops->setup          = PetscRegressorSetUp_Linear;
535:   regressor->ops->reset          = PetscRegressorReset_Linear;
536:   regressor->ops->destroy        = PetscRegressorDestroy_Linear;
537:   regressor->ops->setfromoptions = PetscRegressorSetFromOptions_Linear;
538:   regressor->ops->view           = PetscRegressorView_Linear;
539:   regressor->ops->fit            = PetscRegressorFit_Linear;
540:   regressor->ops->predict        = PetscRegressorPredict_Linear;

542:   linear->intercept     = 0.0;
543:   linear->fit_intercept = PETSC_TRUE;  /* Default to calculating the intercept. */
544:   linear->use_ksp       = PETSC_FALSE; /* Do not default to using KSP for solving the model-fitting problem (use TAO instead). */
545:   linear->type          = REGRESSOR_LINEAR_OLS;
546:   /* Above, manually set the default linear regressor type.
547:        We don't use PetscRegressorLinearSetType() here, because that expects the SetUp event to already have happened. */

549:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetFitIntercept_C", PetscRegressorLinearSetFitIntercept_Linear));
550:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetUseKSP_C", PetscRegressorLinearSetUseKSP_Linear));
551:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetKSP_C", PetscRegressorLinearGetKSP_Linear));
552:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetCoefficients_C", PetscRegressorLinearGetCoefficients_Linear));
553:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetIntercept_C", PetscRegressorLinearGetIntercept_Linear));
554:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearSetType_C", PetscRegressorLinearSetType_Linear));
555:   PetscCall(PetscObjectComposeFunction((PetscObject)regressor, "PetscRegressorLinearGetType_C", PetscRegressorLinearGetType_Linear));
556:   PetscFunctionReturn(PETSC_SUCCESS);
557: }