Actual source code: mpiaijcupm.hpp
1: #pragma once
3: /* Shared CUPM (CUDA/HIP) implementations for MPIAIJCUSPARSE and MPIAIJHIPSPARSE
4: that do not depend on the cuSPARSE/hipSPARSE library proper.
6: Include ordering requirement: the vendor-specific MPI impl header
7: (mpicusparsematimpl.h or mpihipsparsematimpl.h) must be included before
8: this header so that Mat_MPIAIJCUSPARSE/Mat_MPIAIJHIPSPARSE types are visible.
10: Instantiated by:
11: mpiaijcusparse.cu (DeviceType::CUDA, using MatMPIAIJCUSPARSE_Policy)
12: mpiaijhipsparse.hip.cxx (DeviceType::HIP, using MatMPIAIJHIPSPARSE_Policy) */
14: #include <petsc/private/cupmobject.hpp>
15: #include <petsc/private/matimpl.h>
16: #include <../src/mat/impls/aij/mpi/mpiaij.h>
17: #include <petscsf.h>
19: namespace Petsc
20: {
22: namespace mat
23: {
25: namespace aij
26: {
28: namespace cupm
29: {
31: namespace impl
32: {
34: /* --------------------------------------------------------------------------
35: Shared __global__ kernel: pack entries to be sent to remote.
36: -------------------------------------------------------------------------- */
37: __global__ static void MatPackCOOValues_MPI(const PetscScalar kv[], PetscCount nnz, const PetscCount perm[], PetscScalar buf[])
38: {
39: PetscCount i = blockIdx.x * blockDim.x + threadIdx.x;
40: const PetscCount grid_size = gridDim.x * blockDim.x;
41: for (; i < nnz; i += grid_size) buf[i] = kv[perm[i]];
42: }
44: /* --------------------------------------------------------------------------
45: Shared __global__ kernel: add local COO values to diagonal and off-diagonal.
46: -------------------------------------------------------------------------- */
47: __global__ static void MatAddLocalCOOValues_MPI(const PetscScalar kv[], InsertMode imode, PetscCount Annz, const PetscCount Ajmap1[], const PetscCount Aperm1[], PetscScalar Aa[], PetscCount Bnnz, const PetscCount Bjmap1[], const PetscCount Bperm1[], PetscScalar Ba[])
48: {
49: PetscCount i = blockIdx.x * blockDim.x + threadIdx.x;
50: const PetscCount grid_size = gridDim.x * blockDim.x;
51: for (; i < Annz + Bnnz; i += grid_size) {
52: PetscScalar sum = 0.0;
53: if (i < Annz) {
54: for (PetscCount k = Ajmap1[i]; k < Ajmap1[i + 1]; k++) sum += kv[Aperm1[k]];
55: Aa[i] = (imode == INSERT_VALUES ? 0.0 : Aa[i]) + sum;
56: } else {
57: i -= Annz;
58: for (PetscCount k = Bjmap1[i]; k < Bjmap1[i + 1]; k++) sum += kv[Bperm1[k]];
59: Ba[i] = (imode == INSERT_VALUES ? 0.0 : Ba[i]) + sum;
60: }
61: }
62: }
64: /* --------------------------------------------------------------------------
65: Shared __global__ kernel: add remote COO values to diagonal and off-diagonal.
66: -------------------------------------------------------------------------- */
67: __global__ static void MatAddRemoteCOOValues_MPI(const PetscScalar kv[], PetscCount Annz2, const PetscCount Aimap2[], const PetscCount Ajmap2[], const PetscCount Aperm2[], PetscScalar Aa[], PetscCount Bnnz2, const PetscCount Bimap2[], const PetscCount Bjmap2[], const PetscCount Bperm2[], PetscScalar Ba[])
68: {
69: PetscCount i = blockIdx.x * blockDim.x + threadIdx.x;
70: const PetscCount grid_size = gridDim.x * blockDim.x;
71: for (; i < Annz2 + Bnnz2; i += grid_size) {
72: if (i < Annz2) {
73: for (PetscCount k = Ajmap2[i]; k < Ajmap2[i + 1]; k++) Aa[Aimap2[i]] += kv[Aperm2[k]];
74: } else {
75: i -= Annz2;
76: for (PetscCount k = Bjmap2[i]; k < Bjmap2[i + 1]; k++) Ba[Bimap2[i]] += kv[Bperm2[k]];
77: }
78: }
79: }
81: /* ==========================================================================
82: MatMPIAIJCUSPARSE_CUPM<T, Policy>
84: Policy (C++11 traits class) requirements - all static:
86: typedef ... mat_struct_type; // Mat_MPIAIJCUSPARSE / Mat_MPIAIJHIPSPARSE
88: static const char *mpi_mat_type; // MATMPIAIJCUSPARSE / MATMPIAIJHIPSPARSE
89: static const char *seq_mat_type; // MATSEQAIJCUSPARSE / MATSEQAIJHIPSPARSE
90: static const char *vec_seq_type; // VECSEQCUDA / VECSEQHIP
92: // Seq sub-matrix device copy
93: static PetscErrorCode CopyToGPU(Mat);
95: // Seq sub-matrix merge (for GetLocalMatMerge)
96: static PetscErrorCode MergeMats(Mat, Mat, MatReuse, Mat *);
98: // Seq sub-matrix device array access (for SetValuesCOO)
99: static PetscErrorCode GetArray (Mat, PetscScalar **);
100: static PetscErrorCode GetArrayWrite(Mat, PetscScalar **);
101: static PetscErrorCode RestoreArray (Mat, PetscScalar **);
102: static PetscErrorCode RestoreArrayWrite(Mat, PetscScalar **);
104: // Set cuSPARSE/hipSPARSE storage format on both sub-matrices
105: static PetscErrorCode SetSubMatFormats(Mat, Mat, mat_struct_type *);
107: // Compose-function keys that differ between CUDA and HIP
108: static const char *set_format_c; // "MatCUSPARSESetFormat_C" / "MatHIPSPARSESetFormat_C"
109: static const char *mpi_convert_hypre_c; // "MatConvert_mpiaijcusparse_hypre_C" / "_mpiaijhipsparse_hypre_C"
110: ========================================================================== */
112: template <device::cupm::DeviceType T, typename Policy>
113: struct MatMPIAIJCUSPARSE_CUPM : device::cupm::impl::CUPMObject<T> {
114: PETSC_CUPMOBJECT_HEADER(T);
116: typedef typename Policy::mat_struct_type MatStructType;
118: /* MatCOOStructDestroy: release all device-side COO arrays */
119: static PetscErrorCode COOStructDestroy(PetscCtxRt data) noexcept
120: {
121: MatCOOStruct_MPIAIJ *coo = *(MatCOOStruct_MPIAIJ **)data;
123: PetscFunctionBegin;
124: PetscCall(PetscSFDestroy(&coo->sf));
125: PetscCallCUPM(cupmFree(coo->Ajmap1));
126: PetscCallCUPM(cupmFree(coo->Aperm1));
127: PetscCallCUPM(cupmFree(coo->Bjmap1));
128: PetscCallCUPM(cupmFree(coo->Bperm1));
129: PetscCallCUPM(cupmFree(coo->Aimap2));
130: PetscCallCUPM(cupmFree(coo->Ajmap2));
131: PetscCallCUPM(cupmFree(coo->Aperm2));
132: PetscCallCUPM(cupmFree(coo->Bimap2));
133: PetscCallCUPM(cupmFree(coo->Bjmap2));
134: PetscCallCUPM(cupmFree(coo->Bperm2));
135: PetscCallCUPM(cupmFree(coo->Cperm1));
136: PetscCallCUPM(cupmFree(coo->sendbuf));
137: PetscCallCUPM(cupmFree(coo->recvbuf));
138: PetscCall(PetscFree(coo));
139: PetscFunctionReturn(PETSC_SUCCESS);
140: }
142: /* MatSetPreallocationCOO: copy MPIAIJ COO bookkeeping struct to device */
143: static PetscErrorCode SetPreallocationCOO(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[]) noexcept
144: {
145: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
146: PetscBool dev_ij = PETSC_FALSE;
147: PetscMemType mtype = PETSC_MEMTYPE_HOST;
148: PetscInt *i, *j;
149: PetscContainer container_h;
150: MatCOOStruct_MPIAIJ *coo_h, *coo_d;
152: PetscFunctionBegin;
153: PetscCall(PetscFree(mpiaij->garray));
154: PetscCall(VecDestroy(&mpiaij->lvec));
155: #if defined(PETSC_USE_CTABLE)
156: PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
157: #else
158: PetscCall(PetscFree(mpiaij->colmap));
159: #endif
160: PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
161: mat->assembled = PETSC_FALSE;
162: mat->was_assembled = PETSC_FALSE;
163: PetscCall(PetscGetMemType(coo_i, &mtype));
164: if (PetscMemTypeDevice(mtype)) {
165: dev_ij = PETSC_TRUE;
166: PetscCall(PetscMalloc2(coo_n, &i, coo_n, &j));
167: PetscCallCUPM(cupmMemcpy(i, coo_i, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
168: PetscCallCUPM(cupmMemcpy(j, coo_j, coo_n * sizeof(PetscInt), cupmMemcpyDeviceToHost));
169: } else {
170: i = coo_i;
171: j = coo_j;
172: }
173: PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, i, j));
174: if (dev_ij) PetscCall(PetscFree2(i, j));
175: mat->offloadmask = PETSC_OFFLOAD_CPU;
176: /* Create the GPU memory */
177: PetscCall(Policy::CopyToGPU(mpiaij->A));
178: PetscCall(Policy::CopyToGPU(mpiaij->B));
180: /* Copy the COO struct to device */
181: PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Host", (PetscObject *)&container_h));
182: PetscCall(PetscContainerGetPointer(container_h, (void **)&coo_h));
183: PetscCall(PetscMalloc1(1, &coo_d));
184: *coo_d = *coo_h; /* shallow copy; device fields amended below */
185: PetscCall(PetscObjectReference((PetscObject)coo_d->sf));
186: PetscCallCUPM(cupmMalloc((void **)&coo_d->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount)));
187: PetscCallCUPM(cupmMalloc((void **)&coo_d->Aperm1, coo_h->Atot1 * sizeof(PetscCount)));
188: PetscCallCUPM(cupmMalloc((void **)&coo_d->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount)));
189: PetscCallCUPM(cupmMalloc((void **)&coo_d->Bperm1, coo_h->Btot1 * sizeof(PetscCount)));
190: PetscCallCUPM(cupmMalloc((void **)&coo_d->Aimap2, coo_h->Annz2 * sizeof(PetscCount)));
191: PetscCallCUPM(cupmMalloc((void **)&coo_d->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount)));
192: PetscCallCUPM(cupmMalloc((void **)&coo_d->Aperm2, coo_h->Atot2 * sizeof(PetscCount)));
193: PetscCallCUPM(cupmMalloc((void **)&coo_d->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount)));
194: PetscCallCUPM(cupmMalloc((void **)&coo_d->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount)));
195: PetscCallCUPM(cupmMalloc((void **)&coo_d->Bperm2, coo_h->Btot2 * sizeof(PetscCount)));
196: PetscCallCUPM(cupmMalloc((void **)&coo_d->Cperm1, coo_h->sendlen * sizeof(PetscCount)));
197: PetscCallCUPM(cupmMalloc((void **)&coo_d->sendbuf, coo_h->sendlen * sizeof(PetscScalar)));
198: PetscCallCUPM(cupmMalloc((void **)&coo_d->recvbuf, coo_h->recvlen * sizeof(PetscScalar)));
199: PetscCallCUPM(cupmMemcpy(coo_d->Ajmap1, coo_h->Ajmap1, (coo_h->Annz + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
200: PetscCallCUPM(cupmMemcpy(coo_d->Aperm1, coo_h->Aperm1, coo_h->Atot1 * sizeof(PetscCount), cupmMemcpyHostToDevice));
201: PetscCallCUPM(cupmMemcpy(coo_d->Bjmap1, coo_h->Bjmap1, (coo_h->Bnnz + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
202: PetscCallCUPM(cupmMemcpy(coo_d->Bperm1, coo_h->Bperm1, coo_h->Btot1 * sizeof(PetscCount), cupmMemcpyHostToDevice));
203: PetscCallCUPM(cupmMemcpy(coo_d->Aimap2, coo_h->Aimap2, coo_h->Annz2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
204: PetscCallCUPM(cupmMemcpy(coo_d->Ajmap2, coo_h->Ajmap2, (coo_h->Annz2 + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
205: PetscCallCUPM(cupmMemcpy(coo_d->Aperm2, coo_h->Aperm2, coo_h->Atot2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
206: PetscCallCUPM(cupmMemcpy(coo_d->Bimap2, coo_h->Bimap2, coo_h->Bnnz2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
207: PetscCallCUPM(cupmMemcpy(coo_d->Bjmap2, coo_h->Bjmap2, (coo_h->Bnnz2 + 1) * sizeof(PetscCount), cupmMemcpyHostToDevice));
208: PetscCallCUPM(cupmMemcpy(coo_d->Bperm2, coo_h->Bperm2, coo_h->Btot2 * sizeof(PetscCount), cupmMemcpyHostToDevice));
209: PetscCallCUPM(cupmMemcpy(coo_d->Cperm1, coo_h->Cperm1, coo_h->sendlen * sizeof(PetscCount), cupmMemcpyHostToDevice));
210: /* Put the COO struct in a container and attach it to the matrix */
211: PetscCall(PetscObjectContainerCompose((PetscObject)mat, "__PETSc_MatCOOStruct_Device", coo_d, MatMPIAIJCUSPARSE_CUPM::COOStructDestroy));
212: PetscFunctionReturn(PETSC_SUCCESS);
213: }
215: /* MatSetValuesCOO: launch CUPM kernels for packing and adding local/remote COO values */
216: static PetscErrorCode SetValuesCOO(Mat mat, const PetscScalar v[], InsertMode imode) noexcept
217: {
218: Mat_MPIAIJ *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
219: Mat A = mpiaij->A, B = mpiaij->B;
220: PetscScalar *Aa, *Ba;
221: const PetscScalar *v1 = v;
222: PetscMemType memtype;
223: PetscContainer container;
224: MatCOOStruct_MPIAIJ *coo;
225: cupmStream_t stream;
227: PetscFunctionBegin;
228: PetscCall(PetscObjectQuery((PetscObject)mat, "__PETSc_MatCOOStruct_Device", (PetscObject *)&container));
229: PetscCheck(container, PetscObjectComm((PetscObject)mat), PETSC_ERR_PLIB, "Not found MatCOOStruct on this matrix");
230: PetscCall(PetscContainerGetPointer(container, (void **)&coo));
232: const auto &Annz = coo->Annz;
233: const auto &Annz2 = coo->Annz2;
234: const auto &Bnnz = coo->Bnnz;
235: const auto &Bnnz2 = coo->Bnnz2;
236: const auto &vsend = coo->sendbuf;
237: const auto &v2 = coo->recvbuf;
238: const auto &Ajmap1 = coo->Ajmap1;
239: const auto &Ajmap2 = coo->Ajmap2;
240: const auto &Aimap2 = coo->Aimap2;
241: const auto &Bjmap1 = coo->Bjmap1;
242: const auto &Bjmap2 = coo->Bjmap2;
243: const auto &Bimap2 = coo->Bimap2;
244: const auto &Aperm1 = coo->Aperm1;
245: const auto &Aperm2 = coo->Aperm2;
246: const auto &Bperm1 = coo->Bperm1;
247: const auto &Bperm2 = coo->Bperm2;
248: const auto &Cperm1 = coo->Cperm1;
250: PetscCall(PetscGetMemType(v, &memtype));
251: if (PetscMemTypeHost(memtype)) {
252: PetscCallCUPM(cupmMalloc((void **)&v1, coo->n * sizeof(PetscScalar)));
253: PetscCallCUPM(cupmMemcpy((void *)v1, v, coo->n * sizeof(PetscScalar), cupmMemcpyHostToDevice));
254: }
256: if (imode == INSERT_VALUES) {
257: PetscCall(Policy::GetArrayWrite(A, &Aa));
258: PetscCall(Policy::GetArrayWrite(B, &Ba));
259: } else {
260: PetscCall(Policy::GetArray(A, &Aa));
261: PetscCall(Policy::GetArray(B, &Ba));
262: }
264: PetscCall(GetHandles_(&stream));
265: PetscCall(PetscLogGpuTimeBegin());
266: /* Pack entries to be sent to remote */
267: if (coo->sendlen) {
268: PetscCallCUPM(cupmLaunchKernel(MatPackCOOValues_MPI, (unsigned int)((coo->sendlen + 255) / 256), 256u, (size_t)0, stream, v1, (PetscCount)coo->sendlen, Cperm1, vsend));
269: PetscCallCUPM(cupmGetLastError());
270: }
271: /* Send remote entries and overlap communication with local computation */
272: PetscCall(PetscSFReduceWithMemTypeBegin(coo->sf, MPIU_SCALAR, PETSC_MEMTYPE_CUPM(), vsend, PETSC_MEMTYPE_CUPM(), v2, MPI_REPLACE));
273: /* Add local entries to A and B */
274: if (Annz + Bnnz > 0) {
275: PetscCallCUPM(cupmLaunchKernel(MatAddLocalCOOValues_MPI, (unsigned int)((Annz + Bnnz + 255) / 256), 256u, (size_t)0, stream, v1, imode, Annz, Ajmap1, Aperm1, Aa, Bnnz, Bjmap1, Bperm1, Ba));
276: PetscCallCUPM(cupmGetLastError());
277: }
278: PetscCall(PetscSFReduceEnd(coo->sf, MPIU_SCALAR, vsend, v2, MPI_REPLACE));
279: /* Add received remote entries to A and B */
280: if (Annz2 + Bnnz2 > 0) {
281: PetscCallCUPM(cupmLaunchKernel(MatAddRemoteCOOValues_MPI, (unsigned int)((Annz2 + Bnnz2 + 255) / 256), 256u, (size_t)0, stream, v2, Annz2, Aimap2, Ajmap2, Aperm2, Aa, Bnnz2, Bimap2, Bjmap2, Bperm2, Ba));
282: PetscCallCUPM(cupmGetLastError());
283: }
284: PetscCall(PetscLogGpuTimeEnd());
286: if (imode == INSERT_VALUES) {
287: PetscCall(Policy::RestoreArrayWrite(A, &Aa));
288: PetscCall(Policy::RestoreArrayWrite(B, &Ba));
289: } else {
290: PetscCall(Policy::RestoreArray(A, &Aa));
291: PetscCall(Policy::RestoreArray(B, &Ba));
292: }
293: if (PetscMemTypeHost(memtype)) {
294: void *v1_device = (void *)v1;
295: PetscCallCUPM(cupmFree(v1_device));
296: }
297: mat->offloadmask = PETSC_OFFLOAD_GPU;
298: PetscFunctionReturn(PETSC_SUCCESS);
299: }
301: /* MatMPIAIJGetLocalMatMerge */
302: static PetscErrorCode GetLocalMatMerge(Mat A, MatReuse scall, IS *glob, Mat *A_loc) noexcept
303: {
304: Mat Ad, Ao;
305: const PetscInt *cmap;
307: PetscFunctionBegin;
308: PetscCall(MatMPIAIJGetSeqAIJ(A, &Ad, &Ao, &cmap));
309: PetscCall(Policy::MergeMats(Ad, Ao, scall, A_loc));
310: if (glob) {
311: PetscInt cst, i, dn, on, *gidx;
313: PetscCall(MatGetLocalSize(Ad, NULL, &dn));
314: PetscCall(MatGetLocalSize(Ao, NULL, &on));
315: PetscCall(MatGetOwnershipRangeColumn(A, &cst, NULL));
316: PetscCall(PetscMalloc1(dn + on, &gidx));
317: for (i = 0; i < dn; i++) gidx[i] = cst + i;
318: for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
319: PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
320: }
321: PetscFunctionReturn(PETSC_SUCCESS);
322: }
324: /* MatMPIAIJSetPreallocation */
325: static PetscErrorCode SetPreallocation(Mat B, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[]) noexcept
326: {
327: Mat_MPIAIJ *b = (Mat_MPIAIJ *)B->data;
328: MatStructType *spptr = (MatStructType *)b->spptr;
329: PetscInt i;
331: PetscFunctionBegin;
332: if (B->hash_active) {
333: B->ops[0] = b->cops;
334: B->hash_active = PETSC_FALSE;
335: }
336: PetscCall(PetscLayoutSetUp(B->rmap));
337: PetscCall(PetscLayoutSetUp(B->cmap));
338: if (PetscDefined(USE_DEBUG) && d_nnz) {
339: for (i = 0; i < B->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
340: }
341: if (PetscDefined(USE_DEBUG) && o_nnz) {
342: for (i = 0; i < B->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
343: }
344: #if defined(PETSC_USE_CTABLE)
345: PetscCall(PetscHMapIDestroy(&b->colmap));
346: #else
347: PetscCall(PetscFree(b->colmap));
348: #endif
349: PetscCall(PetscFree(b->garray));
350: PetscCall(VecDestroy(&b->lvec));
351: PetscCall(VecScatterDestroy(&b->Mvctx));
352: /* Because the B will have been resized we simply destroy it and create a new one each time */
353: PetscCall(MatDestroy(&b->B));
354: if (!b->A) {
355: PetscCall(MatCreate(PETSC_COMM_SELF, &b->A));
356: PetscCall(MatSetSizes(b->A, B->rmap->n, B->cmap->n, B->rmap->n, B->cmap->n));
357: }
358: if (!b->B) {
359: PetscMPIInt size;
361: PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)B), &size));
362: PetscCall(MatCreate(PETSC_COMM_SELF, &b->B));
363: PetscCall(MatSetSizes(b->B, B->rmap->n, size > 1 ? B->cmap->N : 0, B->rmap->n, size > 1 ? B->cmap->N : 0));
364: }
365: PetscCall(MatSetType(b->A, Policy::seq_mat_type));
366: PetscCall(MatSetType(b->B, Policy::seq_mat_type));
367: PetscCall(MatBindToCPU(b->A, B->boundtocpu));
368: PetscCall(MatBindToCPU(b->B, B->boundtocpu));
369: PetscCall(MatSeqAIJSetPreallocation(b->A, d_nz, d_nnz));
370: PetscCall(MatSeqAIJSetPreallocation(b->B, o_nz, o_nnz));
371: PetscCall(Policy::SetSubMatFormats(b->A, b->B, spptr));
372: B->preallocated = PETSC_TRUE;
373: B->was_assembled = PETSC_FALSE;
374: B->assembled = PETSC_FALSE;
375: PetscFunctionReturn(PETSC_SUCCESS);
376: }
378: /* MatMult: identical in both CUDA and HIP */
379: static PetscErrorCode Mult(Mat A, Vec xx, Vec yy) noexcept
380: {
381: Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
383: PetscFunctionBegin;
384: PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
385: PetscUseTypeMethod(a->A, mult, xx, yy);
386: PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
387: PetscUseTypeMethod(a->B, multadd, a->lvec, yy, yy);
388: PetscFunctionReturn(PETSC_SUCCESS);
389: }
391: /* MatZeroEntries: identical in both CUDA and HIP */
392: static PetscErrorCode ZeroEntries(Mat A) noexcept
393: {
394: Mat_MPIAIJ *l = (Mat_MPIAIJ *)A->data;
396: PetscFunctionBegin;
397: PetscCall(MatZeroEntries(l->A));
398: PetscCall(MatZeroEntries(l->B));
399: PetscFunctionReturn(PETSC_SUCCESS);
400: }
402: /* MatMultAdd: identical in both CUDA and HIP */
403: static PetscErrorCode MultAdd(Mat A, Vec xx, Vec yy, Vec zz) noexcept
404: {
405: Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
407: PetscFunctionBegin;
408: PetscCall(VecScatterBegin(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
409: PetscUseTypeMethod(a->A, multadd, xx, yy, zz);
410: PetscCall(VecScatterEnd(a->Mvctx, xx, a->lvec, INSERT_VALUES, SCATTER_FORWARD));
411: PetscUseTypeMethod(a->B, multadd, a->lvec, zz, zz);
412: PetscFunctionReturn(PETSC_SUCCESS);
413: }
415: /* MatMultTranspose: identical in both CUDA and HIP */
416: static PetscErrorCode MultTranspose(Mat A, Vec xx, Vec yy) noexcept
417: {
418: Mat_MPIAIJ *a = (Mat_MPIAIJ *)A->data;
420: PetscFunctionBegin;
421: PetscUseTypeMethod(a->B, multtranspose, xx, a->lvec);
422: PetscUseTypeMethod(a->A, multtranspose, xx, yy);
423: PetscCall(VecScatterBegin(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
424: PetscCall(VecScatterEnd(a->Mvctx, a->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
425: PetscFunctionReturn(PETSC_SUCCESS);
426: }
428: /* MatAssemblyEnd: set lvec type to vendor-appropriate VECSEQ type */
429: static PetscErrorCode AssemblyEnd(Mat A, MatAssemblyType mode) noexcept
430: {
431: Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;
433: PetscFunctionBegin;
434: PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
435: if (mpiaij->lvec) PetscCall(VecSetType(mpiaij->lvec, Policy::vec_seq_type));
436: PetscFunctionReturn(PETSC_SUCCESS);
437: }
439: /* MatCreateAIJ: allocate and preallocate a parallel sparse matrix of this type */
440: static PetscErrorCode CreateAIJ(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A) noexcept
441: {
442: PetscMPIInt size;
444: PetscFunctionBegin;
445: PetscCall(MatCreate(comm, A));
446: PetscCall(MatSetSizes(*A, m, n, M, N));
447: PetscCallMPI(MPI_Comm_size(comm, &size));
448: if (size > 1) {
449: PetscCall(MatSetType(*A, Policy::mpi_mat_type));
450: PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
451: } else {
452: PetscCall(MatSetType(*A, Policy::seq_mat_type));
453: PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
454: }
455: PetscFunctionReturn(PETSC_SUCCESS);
456: }
458: /* MatDestroy: free vendor-specific state, deregister composed functions */
459: static PetscErrorCode Destroy(Mat A) noexcept
460: {
461: Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
462: MatStructType *mpiStruct = (MatStructType *)aij->spptr;
464: PetscFunctionBegin;
465: PetscCheck(mpiStruct, PETSC_COMM_SELF, PETSC_ERR_COR, "Missing spptr");
466: PetscCallCXX(delete mpiStruct);
467: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
468: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
469: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
470: PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
471: PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::set_format_c, NULL));
472: PetscCall(PetscObjectComposeFunction((PetscObject)A, Policy::mpi_convert_hypre_c, NULL));
473: PetscCall(MatDestroy_MPIAIJ(A));
474: PetscFunctionReturn(PETSC_SUCCESS);
475: }
476: };
478: } // namespace impl
480: } // namespace cupm
482: } // namespace aij
484: } // namespace mat
486: } // namespace Petsc