Actual source code: mpiaijcupm.hpp

  1: #pragma once

  3: /* Shared CUPM (CUDA/HIP) implementations for MPIAIJCUSPARSE and MPIAIJHIPSPARSE
  4:    that do not depend on the cuSPARSE/hipSPARSE library proper.

  6:    Include ordering requirement: the vendor-specific MPI impl header
  7:    (mpicusparsematimpl.h or mpihipsparsematimpl.h) must be included before
  8:    this header so that Mat_MPIAIJCUSPARSE/Mat_MPIAIJHIPSPARSE types are visible.

 10:    Instantiated by:
 11:      mpiaijcusparse.cu      (DeviceType::CUDA, using MatMPIAIJCUSPARSE_Policy)
 12:      mpiaijhipsparse.hip.cxx (DeviceType::HIP,  using MatMPIAIJHIPSPARSE_Policy) */

 14: #include <petsc/private/cupmobject.hpp>
 15: #include <petsc/private/matimpl.h>
 16: #include <../src/mat/impls/aij/mpi/mpiaij.h>
 17: #include <petscsf.h>

 19: namespace Petsc
 20: {

 22: namespace mat
 23: {

 25: namespace aij
 26: {

 28: namespace cupm
 29: {

 31: namespace impl
 32: {

 34: /* --------------------------------------------------------------------------
 35:    Shared __global__ kernel: pack entries to be sent to remote.
 36:    -------------------------------------------------------------------------- */
 37: __global__ static void MatPackCOOValues_MPI(const PetscScalar kv[], PetscCount nnz, const PetscCount perm[], PetscScalar buf[])
 38: {
 39:   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
 40:   const PetscCount grid_size = gridDim.x * blockDim.x;
 41:   for (; i < nnz; i += grid_size) buf[i] = kv[perm[i]];
 42: }

 44: /* --------------------------------------------------------------------------
 45:    Shared __global__ kernel: add local COO values to diagonal and off-diagonal.
 46:    -------------------------------------------------------------------------- */
 47: __global__ static void MatAddLocalCOOValues_MPI(const PetscScalar kv[], InsertMode imode, PetscCount Annz, const PetscCount Ajmap1[], const PetscCount Aperm1[], PetscScalar Aa[], PetscCount Bnnz, const PetscCount Bjmap1[], const PetscCount Bperm1[], PetscScalar Ba[])
 48: {
 49:   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
 50:   const PetscCount grid_size = gridDim.x * blockDim.x;
 51:   for (; i < Annz + Bnnz; i += grid_size) {
 52:     PetscScalar sum = 0.0;
 53:     if (i < Annz) {
 54:       for (PetscCount k = Ajmap1[i]; k < Ajmap1[i + 1]; k++) sum += kv[Aperm1[k]];
 55:       Aa[i] = (imode == INSERT_VALUES ? 0.0 : Aa[i]) + sum;
 56:     } else {
 57:       i -= Annz;
 58:       for (PetscCount k = Bjmap1[i]; k < Bjmap1[i + 1]; k++) sum += kv[Bperm1[k]];
 59:       Ba[i] = (imode == INSERT_VALUES ? 0.0 : Ba[i]) + sum;
 60:     }
 61:   }
 62: }

 64: /* --------------------------------------------------------------------------
 65:    Shared __global__ kernel: add remote COO values to diagonal and off-diagonal.
 66:    -------------------------------------------------------------------------- */
 67: __global__ static void MatAddRemoteCOOValues_MPI(const PetscScalar kv[], PetscCount Annz2, const PetscCount Aimap2[], const PetscCount Ajmap2[], const PetscCount Aperm2[], PetscScalar Aa[], PetscCount Bnnz2, const PetscCount Bimap2[], const PetscCount Bjmap2[], const PetscCount Bperm2[], PetscScalar Ba[])
 68: {
 69:   PetscCount       i         = blockIdx.x * blockDim.x + threadIdx.x;
 70:   const PetscCount grid_size = gridDim.x * blockDim.x;
 71:   for (; i < Annz2 + Bnnz2; i += grid_size) {
 72:     if (i < Annz2) {
 73:       for (PetscCount k = Ajmap2[i]; k < Ajmap2[i + 1]; k++) Aa[Aimap2[i]] += kv[Aperm2[k]];
 74:     } else {
 75:       i -= Annz2;
 76:       for (PetscCount k = Bjmap2[i]; k < Bjmap2[i + 1]; k++) Ba[Bimap2[i]] += kv[Bperm2[k]];
 77:     }
 78:   }
 79: }

 81: /* ==========================================================================
 82:    MatMPIAIJCUSPARSE_CUPM<T, Policy>

 84:    Policy (C++11 traits class) requirements - all static:

 86:      typedef ... mat_struct_type;  // Mat_MPIAIJCUSPARSE / Mat_MPIAIJHIPSPARSE

 88:      static const char *mpi_mat_type;   // MATMPIAIJCUSPARSE / MATMPIAIJHIPSPARSE
 89:      static const char *seq_mat_type;   // MATSEQAIJCUSPARSE / MATSEQAIJHIPSPARSE
 90:      static const char *vec_seq_type;   // VECSEQCUDA / VECSEQHIP

 92:      // Seq sub-matrix device copy
 93:      static PetscErrorCode CopyToGPU(Mat);

 95:      // Seq sub-matrix merge (for GetLocalMatMerge)
 96:      static PetscErrorCode MergeMats(Mat, Mat, MatReuse, Mat *);

 98:      // Seq sub-matrix device array access (for SetValuesCOO)
 99:      static PetscErrorCode GetArray     (Mat, PetscScalar **);
100:      static PetscErrorCode GetArrayWrite(Mat, PetscScalar **);
101:      static PetscErrorCode RestoreArray     (Mat, PetscScalar **);
102:      static PetscErrorCode RestoreArrayWrite(Mat, PetscScalar **);

104:      // Set cuSPARSE/hipSPARSE storage format on both sub-matrices
105:      static PetscErrorCode SetSubMatFormats(Mat, Mat, mat_struct_type *);

107:      // Compose-function keys that differ between CUDA and HIP
108:      static const char *set_format_c;          // "MatCUSPARSESetFormat_C"            / "MatHIPSPARSESetFormat_C"
109:      static const char *mpi_convert_hypre_c;   // "MatConvert_mpiaijcusparse_hypre_C" / "_mpiaijhipsparse_hypre_C"
110:    ========================================================================== */

template <device::cupm::DeviceType T, typename Policy>
struct MatMPIAIJCUSPARSE_CUPM : device::cupm::impl::CUPMObject<T> {
  PETSC_CUPMOBJECT_HEADER(T);

  typedef typename Policy::mat_struct_type MatStructType;

  /* COOStructDestroy - container destructor for the device-side MatCOOStruct_MPIAIJ

     data points to the MatCOOStruct_MPIAIJ pointer.  Destroys the struct's PetscSF handle, releases every
     array of the struct with cupmFree(), then frees the struct itself with PetscFree(). */
  static PetscErrorCode COOStructDestroy(PetscCtxRt data) noexcept
  {
    MatCOOStruct_MPIAIJ *coo = *(MatCOOStruct_MPIAIJ **)data;

    PetscFunctionBegin;
    PetscCall(PetscSFDestroy(&coo->sf));
    PetscCallCUPM(cupmFree(coo->Ajmap1));
    PetscCallCUPM(cupmFree(coo->Aperm1));
    PetscCallCUPM(cupmFree(coo->Bjmap1));
    PetscCallCUPM(cupmFree(coo->Bperm1));
    PetscCallCUPM(cupmFree(coo->Aimap2));
    PetscCallCUPM(cupmFree(coo->Ajmap2));
    PetscCallCUPM(cupmFree(coo->Aperm2));
    PetscCallCUPM(cupmFree(coo->Bimap2));
    PetscCallCUPM(cupmFree(coo->Bjmap2));
    PetscCallCUPM(cupmFree(coo->Bperm2));
    PetscCallCUPM(cupmFree(coo->Cperm1));
    PetscCallCUPM(cupmFree(coo->sendbuf));
    PetscCallCUPM(cupmFree(coo->recvbuf));
    PetscCall(PetscFree(coo));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatSetPreallocationCOO: copy MPIAIJ COO bookkeeping struct to device

     Steps:
       1. drop the existing garray, lvec, colmap and Mvctx and mark the matrix unassembled;
       2. if coo_i is device memory (per PetscGetMemType(); coo_j is not queried separately and is
          treated the same way), stage both index arrays in temporary host buffers;
       3. run MatSetPreallocationCOO_MPIAIJ() on host indices;
       4. have Policy::CopyToGPU() process both sequential sub-matrices;
       5. duplicate the "__PETSc_MatCOOStruct_Host" struct: scalar fields are copied as-is, the SF is
          shared through an extra reference, and every map/permutation array is re-allocated with
          cupmMalloc() and filled from the host copy (sendbuf/recvbuf are only allocated);
       6. attach the copy to mat as "__PETSc_MatCOOStruct_Device", with COOStructDestroy() as destructor.

     Input Parameters:
       mat   - the matrix
       coo_n - number of COO entries
       coo_i - row indices, host or device memory
       coo_j - column indices, same memory kind as coo_i */
  static PetscErrorCode SetPreallocationCOO(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) noexcept
  {
    Mat_MPIAIJ          *mpiaij = (Mat_MPIAIJ *)mat->data;
    PetscBool            dev_ij = PETSC_FALSE;
    PetscMemType         mtype  = PETSC_MEMTYPE_HOST;
    PetscInt            *i, *j;
    PetscContainer       container_h;
    MatCOOStruct_MPIAIJ *coo_h, *coo_d;

    PetscFunctionBegin;
    PetscCall(PetscFree(mpiaij->garray));
    PetscCall(VecDestroy(&mpiaij->lvec));
#if defined(PETSC_USE_CTABLE)
    PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
#else
    PetscCall(PetscFree(mpiaij->colmap));
#endif
    PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
    mat->assembled     = PETSC_FALSE;
    mat->was_assembled = PETSC_FALSE;
    PetscCall(PetscGetMemType(coo_i, &mtype));
    if (PetscMemTypeDevice(mtype)) {
      /* The host preallocation routine is handed host copies of the indices */
      dev_ij = PETSC_TRUE;
      PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
      PetscCallCUPM(cupmMemcpy(i, coo_i, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
      PetscCallCUPM(cupmMemcpy(j, coo_j, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
    } else {
      i = coo_i;
      j = coo_j;
    }
    PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, i, j));
    if (dev_ij) PetscCall(PetscFree2(i, j));
    mat->offloadmask = PETSC_OFFLOAD_CPU;
    /* Create the GPU memory */
    PetscCall(Policy::CopyToGPU(mpiaij->A));
    PetscCall(Policy::CopyToGPU(mpiaij->B));

    /* Copy the COO struct to device */
    PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
    /* Same guard as in SetValuesCOO(): fail with a clear message instead of handing a NULL container on */
    PetscCheck(container_h, PetscObjectComm((PetscObject)mat), PETSC_ERR_PLIB, "Not found MatCOOStruct on this matrix");
    PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
    PetscCall(PetscMalloc1(1, &coo_d));
    *coo_d = *coo_h; /* shallow copy; device fields amended below */
    PetscCall(PetscObjectReference((PetscObject)coo_d->sf)); /* both structs hold the SF, and COOStructDestroy() releases one reference */
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Aperm1, coo_h->Atot1 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Bperm1, coo_h->Btot1 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Aimap2, coo_h->Annz2 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Aperm2, coo_h->Atot2 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Bperm2, coo_h->Btot2 * sizeof(PetscCount)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->Cperm1, coo_h->sendlen * sizeof(PetscCount)));
    /* Communication buffers: allocated here, filled on every SetValuesCOO() call */
    PetscCallCUPM(cupmMalloc((void **)&coo_d->sendbuf, coo_h->sendlen * sizeof(PetscScalar)));
    PetscCallCUPM(cupmMalloc((void **)&coo_d->recvbuf, coo_h->recvlen * sizeof(PetscScalar)));
    PetscCallCUPM(cupmMemcpy(coo_d->Ajmap1, coo_h->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Aperm1, coo_h->Aperm1, coo_h->Atot1 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Bjmap1, coo_h->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Bperm1, coo_h->Bperm1, coo_h->Btot1 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Aimap2, coo_h->Aimap2, coo_h->Annz2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Ajmap2, coo_h->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Aperm2, coo_h->Aperm2, coo_h->Atot2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Bimap2, coo_h->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Bjmap2, coo_h->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Bperm2, coo_h->Bperm2, coo_h->Btot2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
    PetscCallCUPM(cupmMemcpy(coo_d->Cperm1, coo_h->Cperm1, coo_h->sendlen * sizeof(PetscCount), cupmMemcpyHostToDevice));
    /* Put the COO struct in a container and attach it to the matrix */
    PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatMPIAIJCUSPARSE_CUPM::COOStructDestroy));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatSetValuesCOO: launch CUPM kernels for packing and adding local/remote COO values

     Requires the "__PETSc_MatCOOStruct_Device" container composed by SetPreallocationCOO(); errors with
     PETSC_ERR_PLIB if it is missing.

     Steps:
       1. if v is host memory, copy its coo->n values into a temporary device buffer;
       2. obtain the device value arrays of both sub-matrices (write-only access for INSERT_VALUES,
          read-write access otherwise);
       3. pack the entries destined for other ranks into coo->sendbuf and start the SF reduction into
          coo->recvbuf;
       4. between the SF begin and end calls, launch the kernel that applies the local entries;
       5. after the SF end call, launch the kernel that accumulates the received entries;
       6. restore the arrays, free the temporary buffer and mark the matrix data as being on the GPU.

     Input Parameters:
       mat   - the matrix
       v     - COO values, host or device memory
       imode - INSERT_VALUES, or any other mode to accumulate into the existing values */
  static PetscErrorCode SetValuesCOO(Mat mat, const PetscScalar v[], InsertMode imode) noexcept
  {
    Mat_MPIAIJ          *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
    Mat                  A = mpiaij->A, B = mpiaij->B;
    PetscScalar         *Aa, *Ba;
    const PetscScalar   *v1 = v; /* device-side view of v; replaced by a temporary when v lives on the host */
    PetscMemType         memtype;
    PetscContainer       container;
    MatCOOStruct_MPIAIJ *coo;
    cupmStream_t         stream;

    PetscFunctionBegin;
    PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
    PetscCheck(container, PetscObjectComm((PetscObject)mat), PETSC_ERR_PLIB, "Not found MatCOOStruct on this matrix");
    PetscCall(PetscContainerGetPointer(container, (void **)&coo));

    /* Short aliases for the struct fields handed to the kernels */
    const auto &Annz   = coo->Annz;
    const auto &Annz2  = coo->Annz2;
    const auto &Bnnz   = coo->Bnnz;
    const auto &Bnnz2  = coo->Bnnz2;
    const auto &vsend  = coo->sendbuf;
    const auto &v2     = coo->recvbuf;
    const auto &Ajmap1 = coo->Ajmap1;
    const auto &Ajmap2 = coo->Ajmap2;
    const auto &Aimap2 = coo->Aimap2;
    const auto &Bjmap1 = coo->Bjmap1;
    const auto &Bjmap2 = coo->Bjmap2;
    const auto &Bimap2 = coo->Bimap2;
    const auto &Aperm1 = coo->Aperm1;
    const auto &Aperm2 = coo->Aperm2;
    const auto &Bperm1 = coo->Bperm1;
    const auto &Bperm2 = coo->Bperm2;
    const auto &Cperm1 = coo->Cperm1;

    PetscCall(PetscGetMemType(v, &memtype));
    if (PetscMemTypeHost(memtype)) {
      /* The kernels dereference v1, so host values are staged in a temporary device buffer (freed below) */
      PetscCallCUPM(cupmMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
      PetscCallCUPM(cupmMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), cupmMemcpyHostToDevice));
    }

    if (imode == INSERT_VALUES) {
      PetscCall(Policy::GetArrayWrite(A, &Aa));
      PetscCall(Policy::GetArrayWrite(B, &Ba));
    } else {
      PetscCall(Policy::GetArray(A, &Aa));
      PetscCall(Policy::GetArray(B, &Ba));
    }

    PetscCall(GetHandles_(&stream));
    PetscCall(PetscLogGpuTimeBegin());
    /* Pack entries to be sent to remote */
    if (coo->sendlen) {
      PetscCallCUPM(cupmLaunchKernel(MatPackCOOValues_MPI, (unsigned int)((coo->sendlen + 255) / 256), 256u, (size_t)0, stream, v1, (PetscCount)coo->sendlen, Cperm1, vsend));
      PetscCallCUPM(cupmGetLastError());
    }
    /* Send remote entries and overlap communication with local computation */
    PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), vsend, PETSC_MEMTYPE_CUPM(), v2, MPI_REPLACE));
    /* Add local entries to A and B */
    if (Annz + Bnnz > 0) {
      PetscCallCUPM(cupmLaunchKernel(MatAddLocalCOOValues_MPI, (unsigned int)((Annz + Bnnz + 255) / 256), 256u, (size_t)0, stream, v1, imode, Annz, Ajmap1, Aperm1, Aa, Bnnz, Bjmap1, Bperm1, Ba));
      PetscCallCUPM(cupmGetLastError());
    }
    PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend, v2, MPI_REPLACE));
    /* Add received remote entries to A and B */
    if (Annz2 + Bnnz2 > 0) {
      PetscCallCUPM(cupmLaunchKernel(MatAddRemoteCOOValues_MPI, (unsigned int)((Annz2 + Bnnz2 + 255) / 256), 256u, (size_t)0, stream, v2, Annz2, Aimap2, Ajmap2, Aperm2, Aa, Bnnz2, Bimap2, Bjmap2, Bperm2, Ba));
      PetscCallCUPM(cupmGetLastError());
    }
    PetscCall(PetscLogGpuTimeEnd());

    /* Restore with the call matching the access mode used above */
    if (imode == INSERT_VALUES) {
      PetscCall(Policy::RestoreArrayWrite(A, &Aa));
      PetscCall(Policy::RestoreArrayWrite(B, &Ba));
    } else {
      PetscCall(Policy::RestoreArray(A, &Aa));
      PetscCall(Policy::RestoreArray(B, &Ba));
    }
    if (PetscMemTypeHost(memtype)) {
      void *v1_device = (void *)v1;
      PetscCallCUPM(cupmFree(v1_device));
    }
    mat->offloadmask = PETSC_OFFLOAD_GPU;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatMPIAIJGetLocalMatMerge

     Fetches the two sequential blocks and the off-diagonal column map of A with MatMPIAIJGetSeqAIJ()
     and hands the blocks to Policy::MergeMats() to produce A_loc.

     Input Parameters:
       A     - the parallel matrix
       scall - reuse flag, forwarded to Policy::MergeMats()

     Output Parameters:
       glob  - optional (may be NULL); index set of dn + on indices, where dn and on are the local
               column counts of the two blocks: first cst, cst+1, ..., cst+dn-1 with cst the start of
               A's column ownership range, then the on entries of the off-diagonal column map
       A_loc - the merged matrix */
  static PetscErrorCode GetLocalMatMerge(Mat A, MatReuse scall, IS *glob, Mat *A_loc) noexcept
  {
    Mat             Ad, Ao;
    const PetscInt *cmap;

    PetscFunctionBegin;
    PetscCall(MatMPIAIJGetSeqAIJ(A, &Ad, &Ao, &cmap));
    PetscCall(Policy::MergeMats(Ad, Ao, scall, A_loc));
    if (glob) {
      PetscInt cst, i, dn, on, *gidx;

      PetscCall(MatGetLocalSize(Ad, NULL, &dn));
      PetscCall(MatGetLocalSize(Ao, NULL, &on));
      PetscCall(MatGetOwnershipRangeColumn(A, &cst, NULL));
      PetscCall(PetscMalloc1(dn + on, &gidx));
      for (i = 0; i < dn; i++) gidx[i] = cst + i;
      for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
      /* PETSC_OWN_POINTER: the index set takes over gidx, so it is not freed here */
      PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
    }
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatMPIAIJSetPreallocation

     Rebuilds the two sequential sub-matrices of B as Policy::seq_mat_type and preallocates them.
     Any existing colmap, garray, lvec and Mvctx are discarded, and B is left flagged as preallocated
     but not assembled.

     Input Parameters:
       B     - the parallel matrix
       d_nz  - forwarded to MatSeqAIJSetPreallocation() for the diagonal block
       d_nnz - forwarded to MatSeqAIJSetPreallocation() for the diagonal block, or NULL
       o_nz  - forwarded to MatSeqAIJSetPreallocation() for the off-diagonal block
       o_nnz - forwarded to MatSeqAIJSetPreallocation() for the off-diagonal block, or NULL */
  static PetscErrorCode SetPreallocation(Mat B, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) noexcept
  {
    Mat_MPIAIJ    *b     = (Mat_MPIAIJ *)B->data;
    MatStructType *spptr = (MatStructType *)b->spptr;
    PetscInt       i;

    PetscFunctionBegin;
    if (B->hash_active) {
      /* Put back the operations table saved in b->cops */
      B->ops[0]      = b->cops;
      B->hash_active = PETSC_FALSE;
    }
    PetscCall(PetscLayoutSetUp(B->rmap));
    PetscCall(PetscLayoutSetUp(B->cmap));
    /* Per-row counts are only validated when PETSc is configured with USE_DEBUG */
    if (PetscDefined(USE_DEBUG) && d_nnz) {
      for (i = 0; i < B->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
    }
    if (PetscDefined(USE_DEBUG) && o_nnz) {
      for (i = 0; i < B->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
    }
#if defined(PETSC_USE_CTABLE)
    PetscCall(PetscHMapIDestroy(&b->colmap));
#else
    PetscCall(PetscFree(b->colmap));
#endif
    PetscCall(PetscFree(b->garray));
    PetscCall(VecDestroy(&b->lvec));
    PetscCall(VecScatterDestroy(&b->Mvctx));
    /* Because the B will have been resized we simply destroy it and create a new one each time */
    PetscCall(MatDestroy(&b->B));
    if (!b->A) {
      PetscCall(MatCreate(PETSC_COMM_SELF, &b->A));
      PetscCall(MatSetSizes(b->A, B->rmap->n, B->cmap->n, B->rmap->n, B->cmap->n));
    }
    if (!b->B) {
      PetscMPIInt size;

      /* The off-diagonal block gets the full global column count, or no columns at all on a single rank */
      PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)B), &size));
      PetscCall(MatCreate(PETSC_COMM_SELF, &b->B));
      PetscCall(MatSetSizes(b->B, B->rmap->n, size > 1 ? B->cmap->N : 0, B->rmap->n, size > 1 ? B->cmap->N : 0));
    }
    PetscCall(MatSetType(b->A, Policy::seq_mat_type));
    PetscCall(MatSetType(b->B, Policy::seq_mat_type));
    PetscCall(MatBindToCPU(b->A, B->boundtocpu));
    PetscCall(MatBindToCPU(b->B, B->boundtocpu));
    PetscCall(MatSeqAIJSetPreallocation(b->A, d_nz, d_nnz));
    PetscCall(MatSeqAIJSetPreallocation(b->B, o_nz, o_nnz));
    PetscCall(Policy::SetSubMatFormats(b->A, b->B, spptr));
    B->preallocated  = PETSC_TRUE;
    B->was_assembled = PETSC_FALSE;
    B->assembled     = PETSC_FALSE;
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatMult: identical in both CUDA and HIP

     Starts the forward scatter of xx into a->lvec, applies the mult operation of a->A to (xx, yy)
     between the scatter begin and end calls, then applies the multadd operation of a->B to
     (a->lvec, yy, yy). */
  static PetscErrorCode Mult(Mat A, Vec xx, Vec yy) noexcept
  {
    Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;

    PetscFunctionBegin;
    PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
    PetscUseTypeMethod(a->A, mult, xx, yy);
    PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
    PetscUseTypeMethod(a->B, multadd, a->lvec, yy, yy);
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatZeroEntries: identical in both CUDA and HIP

     Calls MatZeroEntries() on both sequential sub-matrices. */
  static PetscErrorCode ZeroEntries(Mat A) noexcept
  {
    Mat_MPIAIJ *l = (Mat_MPIAIJ *)A->data;

    PetscFunctionBegin;
    PetscCall(MatZeroEntries(l->A));
    PetscCall(MatZeroEntries(l->B));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatMultAdd: identical in both CUDA and HIP

     Same structure as Mult(): the multadd operation of a->A is applied to (xx, yy, zz) between the
     scatter begin and end calls, then the multadd operation of a->B is applied to (a->lvec, zz, zz). */
  static PetscErrorCode MultAdd(Mat A, Vec xx, Vec yy, Vec zz) noexcept
  {
    Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;

    PetscFunctionBegin;
    PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
    PetscUseTypeMethod(a->A, multadd, xx, yy, zz);
    PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
    PetscUseTypeMethod(a->B, multadd, a->lvec, zz, zz);
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatMultTranspose: identical in both CUDA and HIP

     Applies the multtranspose operation of a->B to (xx, a->lvec) and that of a->A to (xx, yy), then
     scatters a->lvec in reverse into yy with ADD_VALUES. */
  static PetscErrorCode MultTranspose(Mat A, Vec xx, Vec yy) noexcept
  {
    Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;

    PetscFunctionBegin;
    PetscUseTypeMethod(a->B, multtranspose, xx, a->lvec);
    PetscUseTypeMethod(a->A, multtranspose, xx, yy);
    PetscCall(VecScatterBegin(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
    PetscCall(VecScatterEnd(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatAssemblyEnd: set lvec type to vendor-appropriate VECSEQ type

     Runs MatAssemblyEnd_MPIAIJ() and then, if an lvec exists, sets its type to Policy::vec_seq_type. */
  static PetscErrorCode AssemblyEnd(Mat A, MatAssemblyType mode) noexcept
  {
    Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;

    PetscFunctionBegin;
    PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
    if (mpiaij->lvec) PetscCall(VecSetType(mpiaij->lvec, Policy::vec_seq_type));
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatCreateAIJ: allocate and preallocate a parallel sparse matrix of this type

     On a communicator with more than one rank the matrix gets Policy::mpi_mat_type and is preallocated
     with MatMPIAIJSetPreallocation(); otherwise it gets Policy::seq_mat_type and is preallocated with
     MatSeqAIJSetPreallocation(), in which case o_nz and o_nnz are not used.

     Input Parameters:
       comm         - communicator the matrix is created on
       m, n         - local sizes, forwarded to MatSetSizes()
       M, N         - global sizes, forwarded to MatSetSizes()
       d_nz, d_nnz  - preallocation arguments for the diagonal part
       o_nz, o_nnz  - preallocation arguments for the off-diagonal part

     Output Parameter:
       A            - the new matrix */
  static PetscErrorCode CreateAIJ(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) noexcept
  {
    PetscMPIInt size;

    PetscFunctionBegin;
    PetscCall(MatCreate(comm, A));
    PetscCall(MatSetSizes(*A, m, n, M, N));
    PetscCallMPI(MPI_Comm_size(comm, &size));
    if (size > 1) {
      PetscCall(MatSetType(*A, Policy::mpi_mat_type));
      PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
    } else {
      PetscCall(MatSetType(*A, Policy::seq_mat_type));
      PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
    }
    PetscFunctionReturn(PETSC_SUCCESS);
  }

  /* MatDestroy: free vendor-specific state, deregister composed functions

     Errors with PETSC_ERR_COR if the matrix has no spptr.  Otherwise deletes it, removes the composed
     functions listed below by composing NULL in their place, and finishes with MatDestroy_MPIAIJ(). */
  static PetscErrorCode Destroy(Mat A) noexcept
  {
    Mat_MPIAIJ    *aij       = (Mat_MPIAIJ *)A->data;
    MatStructType *mpiStruct = (MatStructType *)aij->spptr;

    PetscFunctionBegin;
    PetscCheck(mpiStruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
    PetscCallCXX(delete mpiStruct);
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::set_format_c, NULL));
    PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::mpi_convert_hypre_c, NULL));
    PetscCall(MatDestroy_MPIAIJ(A));
    PetscFunctionReturn(PETSC_SUCCESS);
  }
};

478: } // namespace impl

480: } // namespace cupm

482: } // namespace aij

484: } // namespace mat

486: } // namespace Petsc