Actual source code: sfnvshmem.cu

  1: #include <petsc/private/cudavecimpl.h>
  2: #include <../src/vec/is/sf/impls/basic/sfpack.h>
  3: #include <mpi.h>
  4: #include <nvshmem.h>
  5: #include <nvshmemx.h>

  7: PetscErrorCode PetscNvshmemInitializeCheck(void)
  8: {
  9:   PetscFunctionBegin;
 10:   if (!PetscNvshmemInitialized) { /* Note NVSHMEM does not provide a routine to check whether it is initialized */
 11:     nvshmemx_init_attr_t attr;
 12:     attr.mpi_comm = &PETSC_COMM_WORLD;
 13:     PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
 14:     PetscCall(nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr));
 15:     PetscNvshmemInitialized = PETSC_TRUE;
 16:     PetscBeganNvshmem       = PETSC_TRUE;
 17:   }
 18:   PetscFunctionReturn(PETSC_SUCCESS);
 19: }

 21: PetscErrorCode PetscNvshmemMalloc(size_t size, void **ptr)
 22: {
 23:   PetscFunctionBegin;
 24:   PetscCall(PetscNvshmemInitializeCheck());
 25:   *ptr = nvshmem_malloc(size);
 26:   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "nvshmem_malloc() failed to allocate %zu bytes", size);
 27:   PetscFunctionReturn(PETSC_SUCCESS);
 28: }

 30: PetscErrorCode PetscNvshmemCalloc(size_t size, void **ptr)
 31: {
 32:   PetscFunctionBegin;
 33:   PetscCall(PetscNvshmemInitializeCheck());
 34:   *ptr = nvshmem_calloc(size, 1);
 35:   PetscCheck(*ptr, PETSC_COMM_SELF, PETSC_ERR_ARG_WRONG, "nvshmem_calloc() failed to allocate %zu bytes", size);
 36:   PetscFunctionReturn(PETSC_SUCCESS);
 37: }

 39: PetscErrorCode PetscNvshmemFree_Private(void *ptr)
 40: {
 41:   PetscFunctionBegin;
 42:   nvshmem_free(ptr);
 43:   PetscFunctionReturn(PETSC_SUCCESS);
 44: }

 46: /*@C
 47:   PetscNvshmemFinalize - Tear down NVSHMEM after PETSc is done using it

 49:   Collective

 51:   Level: developer

 53:   Note:
 54:   Called internally during `PetscFinalize()` if PETSc previously initialized NVSHMEM. Users normally
 55:   do not need to call this directly.

 57: .seealso: `PetscFinalize()`, `PetscSF`
 58: @*/
 59: PetscErrorCode PetscNvshmemFinalize(void)
 60: {
 61:   PetscFunctionBegin;
 62:   nvshmem_finalize();
 63:   PetscFunctionReturn(PETSC_SUCCESS);
 64: }

 66: /* Free nvshmem related fields in the SF */
 67: PetscErrorCode PetscSFReset_Basic_NVSHMEM(PetscSF sf)
 68: {
 69:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;

 71:   PetscFunctionBegin;
 72:   PetscCall(PetscFree2(bas->leafsigdisp, bas->leafbufdisp));
 73:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafbufdisp_d));
 74:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->leafsigdisp_d));
 75:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->iranks_d));
 76:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, bas->ioffset_d));

 78:   PetscCall(PetscFree2(sf->rootsigdisp, sf->rootbufdisp));
 79:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootbufdisp_d));
 80:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->rootsigdisp_d));
 81:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->ranks_d));
 82:   PetscCall(PetscSFFree(sf, PETSC_MEMTYPE_CUDA, sf->roffset_d));
 83:   PetscFunctionReturn(PETSC_SUCCESS);
 84: }

 86: /* Set up NVSHMEM related fields for an SF of type SFBASIC (only after PetscSFSetup_Basic() already set up dependent fields) */
 87: static PetscErrorCode PetscSFSetUp_Basic_NVSHMEM(PetscSF sf)
 88: {
 89:   cudaError_t    cerr;
 90:   PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
 91:   PetscInt       i, nRemoteRootRanks, nRemoteLeafRanks;
 92:   PetscMPIInt    tag;
 93:   MPI_Comm       comm;
 94:   MPI_Request   *rootreqs, *leafreqs;
 95:   PetscInt       tmp, stmp[4], rtmp[4]; /* tmps for send/recv buffers */

 97:   PetscFunctionBegin;
 98:   PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
 99:   PetscCall(PetscObjectGetNewTag((PetscObject)sf, &tag));

101:   nRemoteRootRanks      = sf->nranks - sf->ndranks;
102:   nRemoteLeafRanks      = bas->niranks - bas->ndiranks;
103:   sf->nRemoteRootRanks  = nRemoteRootRanks;
104:   bas->nRemoteLeafRanks = nRemoteLeafRanks;

106:   PetscCall(PetscMalloc2(nRemoteLeafRanks, &rootreqs, nRemoteRootRanks, &leafreqs));

108:   stmp[0] = nRemoteRootRanks;
109:   stmp[1] = sf->leafbuflen[PETSCSF_REMOTE];
110:   stmp[2] = nRemoteLeafRanks;
111:   stmp[3] = bas->rootbuflen[PETSCSF_REMOTE];

113:   PetscCallMPI(MPIU_Allreduce(stmp, rtmp, 4, MPIU_INT, MPI_MAX, comm));

115:   sf->nRemoteRootRanksMax  = rtmp[0];
116:   sf->leafbuflen_rmax      = rtmp[1];
117:   bas->nRemoteLeafRanksMax = rtmp[2];
118:   bas->rootbuflen_rmax     = rtmp[3];

120:   /* Total four rounds of MPI communications to set up the nvshmem fields */

122:   /* Root ranks to leaf ranks: send info about rootsigdisp[] and rootbufdisp[] */
123:   PetscCall(PetscMalloc2(nRemoteRootRanks, &sf->rootsigdisp, nRemoteRootRanks, &sf->rootbufdisp));
124:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootsigdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
125:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm));                              /* Roots send. Note i changes, so we use MPI_Send. */
126:   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));

128:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPIU_Irecv(&sf->rootbufdisp[i], 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm, &leafreqs[i])); /* Leaves recv */
129:   for (i = 0; i < nRemoteLeafRanks; i++) {
130:     tmp = bas->ioffset[i + bas->ndiranks] - bas->ioffset[bas->ndiranks];
131:     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm)); /* Roots send. Note tmp changes, so we use MPI_Send. */
132:   }
133:   PetscCallMPI(MPI_Waitall(nRemoteRootRanks, leafreqs, MPI_STATUSES_IGNORE));

135:   PetscCallCUDA(cudaMalloc((void **)&sf->rootbufdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
136:   PetscCallCUDA(cudaMalloc((void **)&sf->rootsigdisp_d, nRemoteRootRanks * sizeof(PetscInt)));
137:   PetscCallCUDA(cudaMalloc((void **)&sf->ranks_d, nRemoteRootRanks * sizeof(PetscMPIInt)));
138:   PetscCallCUDA(cudaMalloc((void **)&sf->roffset_d, (nRemoteRootRanks + 1) * sizeof(PetscInt)));

140:   PetscCallCUDA(cudaMemcpyAsync(sf->rootbufdisp_d, sf->rootbufdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
141:   PetscCallCUDA(cudaMemcpyAsync(sf->rootsigdisp_d, sf->rootsigdisp, nRemoteRootRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
142:   PetscCallCUDA(cudaMemcpyAsync(sf->ranks_d, sf->ranks + sf->ndranks, nRemoteRootRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
143:   PetscCallCUDA(cudaMemcpyAsync(sf->roffset_d, sf->roffset + sf->ndranks, (nRemoteRootRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

145:   /* Leaf ranks to root ranks: send info about leafsigdisp[] and leafbufdisp[] */
146:   PetscCall(PetscMalloc2(nRemoteLeafRanks, &bas->leafsigdisp, nRemoteLeafRanks, &bas->leafbufdisp));
147:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafsigdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
148:   for (i = 0; i < nRemoteRootRanks; i++) PetscCallMPI(MPI_Send(&i, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
149:   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

151:   for (i = 0; i < nRemoteLeafRanks; i++) PetscCallMPI(MPIU_Irecv(&bas->leafbufdisp[i], 1, MPIU_INT, bas->iranks[i + bas->ndiranks], tag, comm, &rootreqs[i]));
152:   for (i = 0; i < nRemoteRootRanks; i++) {
153:     tmp = sf->roffset[i + sf->ndranks] - sf->roffset[sf->ndranks];
154:     PetscCallMPI(MPI_Send(&tmp, 1, MPIU_INT, sf->ranks[i + sf->ndranks], tag, comm));
155:   }
156:   PetscCallMPI(MPI_Waitall(nRemoteLeafRanks, rootreqs, MPI_STATUSES_IGNORE));

158:   PetscCallCUDA(cudaMalloc((void **)&bas->leafbufdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
159:   PetscCallCUDA(cudaMalloc((void **)&bas->leafsigdisp_d, nRemoteLeafRanks * sizeof(PetscInt)));
160:   PetscCallCUDA(cudaMalloc((void **)&bas->iranks_d, nRemoteLeafRanks * sizeof(PetscMPIInt)));
161:   PetscCallCUDA(cudaMalloc((void **)&bas->ioffset_d, (nRemoteLeafRanks + 1) * sizeof(PetscInt)));

163:   PetscCallCUDA(cudaMemcpyAsync(bas->leafbufdisp_d, bas->leafbufdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
164:   PetscCallCUDA(cudaMemcpyAsync(bas->leafsigdisp_d, bas->leafsigdisp, nRemoteLeafRanks * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
165:   PetscCallCUDA(cudaMemcpyAsync(bas->iranks_d, bas->iranks + bas->ndiranks, nRemoteLeafRanks * sizeof(PetscMPIInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));
166:   PetscCallCUDA(cudaMemcpyAsync(bas->ioffset_d, bas->ioffset + bas->ndiranks, (nRemoteLeafRanks + 1) * sizeof(PetscInt), cudaMemcpyHostToDevice, PetscDefaultCudaStream));

168:   PetscCall(PetscFree2(rootreqs, leafreqs));
169:   PetscFunctionReturn(PETSC_SUCCESS);
170: }

/*
  PetscSFLinkNvshmemCheck - Decide whether a communication on this SF should use NVSHMEM

  Input Parameters:
+ sf        - the star forest
. rootmtype - memory type of rootdata
. rootdata  - root data pointer (may be NULL)
. leafmtype - memory type of leafdata
- leafdata  - leaf data pointer (may be NULL)

  Output Parameter:
. use_nvshmem - PETSC_TRUE if NVSHMEM should be used, PETSC_FALSE otherwise

  Notes:
  The one-time eligibility check runs only while sf->use_nvshmem is set and
  sf->checked_nvshmem_eligibility is not; it clears sf->use_nvshmem when the SF is not eligible. The
  checked flag must be set only AFTER that check has run, otherwise the check would never execute.

  When the answer is PETSC_TRUE, NVSHMEM is initialized if needed and the NVSHMEM fields of the SF are
  set up on first use.
*/
PetscErrorCode PetscSFLinkNvshmemCheck(PetscSF sf, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, PetscBool *use_nvshmem)
{
  MPI_Comm    comm;
  PetscBool   isBasic;
  PetscMPIInt result = MPI_UNEQUAL;

  PetscFunctionBegin;
  *use_nvshmem = PETSC_FALSE; /* the answer unless every test below passes */
  PetscCall(PetscObjectGetComm((PetscObject)sf, &comm));
  /* Check if the sf is eligible for NVSHMEM, if we have not checked yet.
     Note the check result <use_nvshmem> must be the same over comm, since an SFLink must be collectively either NVSHMEM or MPI.
  */
  if (sf->use_nvshmem && !sf->checked_nvshmem_eligibility) {
    /* Only use NVSHMEM for SFBASIC on PETSC_COMM_WORLD  */
    PetscCall(PetscObjectTypeCompare((PetscObject)sf, PETSCSFBASIC, &isBasic));
    if (isBasic) PetscCallMPI(MPI_Comm_compare(PETSC_COMM_WORLD, comm, &result));
    if (!isBasic || (result != MPI_IDENT && result != MPI_CONGRUENT)) sf->use_nvshmem = PETSC_FALSE; /* If not eligible, clear the flag so that we don't try again */

    /* Do further check: If on a rank, both rootdata and leafdata are NULL, we might think they are PETSC_MEMTYPE_CUDA (or HOST)
       and then use NVSHMEM. But if root/leafmtypes on other ranks are PETSC_MEMTYPE_HOST (or DEVICE), this would lead to
       inconsistency on the return value <use_nvshmem>. To be safe, we simply disable nvshmem on these rare SFs.
    */
    if (sf->use_nvshmem) {
      PetscInt hasNullRank = (!rootdata && !leafdata) ? 1 : 0;
      PetscCallMPI(MPIU_Allreduce(MPI_IN_PLACE, &hasNullRank, 1, MPIU_INT, MPI_LOR, comm));
      if (hasNullRank) sf->use_nvshmem = PETSC_FALSE;
    }
    sf->checked_nvshmem_eligibility = PETSC_TRUE; /* If eligible, don't do above check again */
  }

  /* Check if rootmtype and leafmtype collectively are PETSC_MEMTYPE_CUDA */
  if (sf->use_nvshmem) {
    PetscInt oneCuda = (!rootdata || PetscMemTypeCUDA(rootmtype)) && (!leafdata || PetscMemTypeCUDA(leafmtype)) ? 1 : 0; /* Do I use cuda for both root&leafmtype? */
    PetscInt allCuda = oneCuda;                                                                                          /* Assume the same for all ranks. But if not, in opt mode, return value <use_nvshmem> won't be collective! */
#if defined(PETSC_USE_DEBUG) /* Check in debug mode. Note MPI_Allreduce is expensive, so only in debug mode */
    PetscCallMPI(MPIU_Allreduce(&oneCuda, &allCuda, 1, MPIU_INT, MPI_LAND, comm));
    PetscCheck(allCuda == oneCuda, comm, PETSC_ERR_SUP, "root/leaf mtypes are inconsistent among ranks, which may lead to SF nvshmem failure in opt mode. Add -use_nvshmem 0 to disable it.");
#endif
    if (allCuda) {
      PetscCall(PetscNvshmemInitializeCheck());
      if (!sf->setup_nvshmem) { /* Set up nvshmem related fields on this SF on-demand */
        PetscCall(PetscSFSetUp_Basic_NVSHMEM(sf));
        sf->setup_nvshmem = PETSC_TRUE;
      }
      *use_nvshmem = PETSC_TRUE;
    }
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Build dependence between <stream> and <remoteCommStream> at the entry of NVSHMEM communication.

   An event (link->dataReady) is recorded on link->stream and link->remoteCommStream is made to wait on it.
   Nothing is done when the remote buffer examined here has zero length: the remote root buffer for
   PETSCSF_ROOT2LEAF, the remote leaf buffer otherwise.
*/
static PetscErrorCode PetscSFLinkBuildDependenceBegin(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       sendlen;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) sendlen = bas->rootbuflen[PETSCSF_REMOTE];
  else sendlen = sf->leafbuflen[PETSCSF_REMOTE];

  if (!sendlen) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCallCUDA(cudaEventRecord(link->dataReady, link->stream));
  PetscCallCUDA(cudaStreamWaitEvent(link->remoteCommStream, link->dataReady, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Build dependence between <stream> and <remoteCommStream> at the exit of NVSHMEM communication.

   An event (link->endRemoteComm) is recorded on link->remoteCommStream and link->stream is made to wait
   on it. Nothing is done when the remote buffer examined here has zero length: the remote leaf buffer
   for PETSCSF_ROOT2LEAF, the remote root buffer otherwise.
*/
static PetscErrorCode PetscSFLinkBuildDependenceEnd(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       recvlen;

  PetscFunctionBegin;
  if (direction == PETSCSF_ROOT2LEAF) recvlen = sf->leafbuflen[PETSCSF_REMOTE];
  else recvlen = bas->rootbuflen[PETSCSF_REMOTE];

  /* Only when there is a non-empty buffer to unpack do we need the endRemoteComm dependence */
  if (!recvlen) PetscFunctionReturn(PETSC_SUCCESS);

  PetscCallCUDA(cudaEventRecord(link->endRemoteComm, link->remoteCommStream));
  PetscCallCUDA(cudaStreamWaitEvent(link->stream, link->endRemoteComm, 0));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Send/Put signals to remote ranks, one signal per thread

 Input parameters:
  + n        - Number of remote ranks
  . sig      - Signal address in symmetric heap
  . sigdisp  - To i-th rank, use its signal at offset sigdisp[i]
  . ranks    - remote ranks
  - newval   - Set signals to this value

 Launch with a 1D grid providing at least n threads in total; threads with a global index >= n do nothing.
*/
__global__ static void NvshmemSendSignals(PetscInt n, uint64_t *sig, PetscInt *sigdisp, PetscMPIInt *ranks, uint64_t newval)
{
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid >= n) return; /* surplus threads of the last block */
  nvshmemx_uint64_signal(sig + sigdisp[tid], newval, ranks[tid]);
}

/* Wait until local signals equal to the expected value and then set them to a new value

 Input parameters:
  + n        - Number of signals
  . sig      - Local signal address
  . expval   - expected value
  - newval   - Set signals to this new value

 Every thread that runs this kernel waits on all n signals and then rewrites all of them; the callers
 in this file launch it with a single thread.

 An alternative with one thread per signal (nvshmem_signal_wait_until() on sig+i followed by
 sig[i] = newval) was tried; Akhil Langer@NVIDIA said using 1 thread and nvshmem_uint64_wait_until_all
 is better, so that is what is done here.
*/
__global__ static void NvshmemWaitSignals(PetscInt n, uint64_t *sig, uint64_t expval, uint64_t newval)
{
  int k;

  nvshmem_uint64_wait_until_all(sig, n, NULL /*no mask*/, NVSHMEM_CMP_EQ, expval);
  for (k = 0; k < n; ++k) sig[k] = newval;
}

297: /* ===========================================================================================================

299:    A set of routines to support receiver initiated communication using the get method

301:     The getting protocol is:

303:     Sender has a send buf (sbuf) and a signal variable (ssig);  Receiver has a recv buf (rbuf) and a signal variable (rsig);
304:     All signal variables have an initial value 0.

306:     Sender:                                 |  Receiver:
307:   1.  Wait ssig be 0, then set it to 1
308:   2.  Pack data into stand alone sbuf       |
309:   3.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
310:                                             |   2. Get data from remote sbuf to local rbuf
311:                                             |   3. Put 0 to sender's ssig
312:                                             |   4. Unpack data from local rbuf
313:    ===========================================================================================================*/
/* PrePack operation -- the sender is about to overwrite its send buffer, which receivers might still be getting data from.
   So the sender first waits for signals (from receivers) indicating that they have finished getting data.

   A one-thread kernel is queued on link->remoteCommStream that waits for all of the sender's send signals
   to be 0 and then sets them to 1. Nothing is queued when there are no remote receiver ranks.
*/
static PetscErrorCode PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas       = (PetscSF_Basic *)sf->data;
  const bool     root2leaf = (direction == PETSCSF_ROOT2LEAF);
  /* ROOT2LEAF: leaf ranks are getting data and they set my rootSendSig; LEAF2ROOT: the converse */
  uint64_t *sendsig    = root2leaf ? link->rootSendSig : link->leafSendSig;
  PetscInt  nreceivers = root2leaf ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;

  PetscFunctionBegin;
  if (!nreceivers) PetscFunctionReturn(PETSC_SUCCESS);

  NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nreceivers, sendsig, 0, 1); /* wait the signals to be 0, then set them to 1 */
  PetscCallCUDA(cudaGetLastError());
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* n thread blocks. Each takes in charge one remote rank, selected by blockIdx.x.

   For a source rank whose <src> is NOT directly addressable from this PE (nvshmem_ptr() returns NULL),
   start a non-blocking get of that rank's data into the local <dst>. Other ranks are skipped here.

   Input parameters:
    + nsrcranks - number of source ranks (not referenced in the kernel body)
    . srcranks  - the source ranks
    . src       - send buffer address, used on the remote PE
    . srcdisp   - for the i-th rank, read its buffer at element offset srcdisp[i]
    . dst       - local receive buffer
    . dstdisp   - element offsets into dst; dstdisp[0] is subtracted, so it does not need to be 0
    - unitbytes - number of bytes per element
*/
__global__ static void GetDataFromRemotelyAccessible(PetscInt nsrcranks, PetscMPIInt *srcranks, const char *src, PetscInt *srcdisp, char *dst, PetscInt *dstdisp, PetscInt unitbytes)
{
  const int         blk = blockIdx.x;
  const PetscMPIInt pe  = srcranks[blk];

  if (nvshmem_ptr(src, pe)) return; /* directly addressable PE: nothing to do in this kernel */

  PetscInt    nbytes = (dstdisp[blk + 1] - dstdisp[blk]) * unitbytes;
  char       *to     = dst + (dstdisp[blk] - dstdisp[0]) * unitbytes;
  const char *from   = src + srcdisp[blk] * unitbytes;

  nvshmem_getmem_nbi(to, from, nbytes, pe);
}

/* Start communication -- Get data in the given direction

   Everything is queued on link->remoteCommStream, in this order:
    1. make remoteCommStream depend on link->stream (PetscSFLinkBuildDependenceBegin);
    2. as a sender, put 1 to the recv signals of my destination ranks: they may now get data;
    3. as a receiver, wait for my recv signals to be 1, then reset them to 0;
    4. start non-blocking gets: one kernel for source PEs that are not directly addressable, and
       host-side on-stream gets for those that are (nvshmem_ptr() returns non-NULL).
   The gets are finished in PetscSFLinkGetDataEnd_NVSHMEM().
*/
static PetscErrorCode PetscSFLinkGetDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       nsrcranks, ndstranks, nlocal = 0;
  char          *src, *dst;
  PetscInt      *srcdisp_h, *srcdisp_d;
  PetscInt      *dstdisp_h, *dstdisp_d;
  PetscInt      *dstsigdisp_d;
  PetscMPIInt   *srcranks_h, *srcranks_d, *dstranks_d;
  uint64_t      *dstsig;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));

  if (direction == PETSCSF_ROOT2LEAF) {
    /* src is root, dst is leaf; data moves from root buf (send buf, in symmetric heap) to the local leaf buf (also in symmetric heap) */
    src = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    /* the ranks I get from: my (remote) root ranks */
    nsrcranks  = sf->nRemoteRootRanks;
    srcranks_h = sf->ranks + sf->ndranks;
    srcranks_d = sf->ranks_d;
    srcdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its buf at offset rootbufdisp[i] */
    srcdisp_d  = sf->rootbufdisp_d;

    /* where the data lands in the local leaf buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_h = sf->roffset + sf->ndranks;
    dstdisp_d = sf->roffset_d;

    /* the ranks I signal: my (remote) leaf ranks */
    ndstranks    = bas->nRemoteLeafRanks;
    dstranks_d   = bas->iranks_d;
    dstsig       = link->leafRecvSig;
    dstsigdisp_d = bas->leafsigdisp_d;
  } else {
    /* src is leaf, dst is root; data moves from leaf buf (send buf) to the local root buf (recv buf) */
    src = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    /* the ranks I get from: my (remote) leaf ranks */
    nsrcranks  = bas->nRemoteLeafRanks;
    srcranks_h = bas->iranks + bas->ndiranks;
    srcranks_d = bas->iranks_d;
    srcdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its buf at offset leafbufdisp[i] */
    srcdisp_d  = bas->leafbufdisp_d;

    /* where the data lands in the local root buf. Note dstdisp[0] is not necessarily 0 */
    dstdisp_h = bas->ioffset + bas->ndiranks;
    dstdisp_d = bas->ioffset_d;

    /* the ranks I signal: my (remote) root ranks */
    ndstranks    = sf->nRemoteRootRanks;
    dstranks_d   = sf->ranks_d;
    dstsig       = link->rootRecvSig;
    dstsigdisp_d = sf->rootsigdisp_d;
  }

  /* After Pack operation -- src tells dst ranks that they are allowed to get data */
  if (ndstranks) {
    NvshmemSendSignals<<<(ndstranks + 255) / 256, 256, 0, link->remoteCommStream>>>(ndstranks, dstsig, dstsigdisp_d, dstranks_d, 1); /* set signals to 1 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst waits for signals (permissions) from src ranks to start getting data */
  if (nsrcranks) {
    NvshmemWaitSignals<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, dstsig, 1, 0); /* wait the signals to be 1, then set them to 0 */
    PetscCallCUDA(cudaGetLastError());
  }

  /* dst gets data from src ranks using non-blocking nvshmem_gets, which are finished in PetscSFLinkGetDataEnd_NVSHMEM() */

  /* Count number of locally accessible src ranks, which should be a small number */
  for (int r = 0; r < nsrcranks; r++) nlocal += nvshmem_ptr(src, srcranks_h[r]) ? 1 : 0;

  /* Get data from remotely accessible PEs: one single-thread block per src rank; blocks of locally accessible ranks return at once */
  if (nlocal < nsrcranks) {
    GetDataFromRemotelyAccessible<<<nsrcranks, 1, 0, link->remoteCommStream>>>(nsrcranks, srcranks_d, src, srcdisp_d, dst, dstdisp_d, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* Get data from locally accessible PEs with the host API */
  for (int r = 0; nlocal && r < nsrcranks; r++) {
    int pe = srcranks_h[r];

    if (!nvshmem_ptr(src, pe)) continue; /* taken care of by the kernel above */
    size_t nelems = (dstdisp_h[r + 1] - dstdisp_h[r]) * link->unitbytes;
    nvshmemx_getmem_nbi_on_stream(dst + (dstdisp_h[r] - dstdisp_h[0]) * link->unitbytes, src + srcdisp_h[r] * link->unitbytes, nelems, pe, link->remoteCommStream);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* Finish the communication (can be done before Unpack)
   Receiver tells its senders that they are allowed to reuse their send buffer (since receiver has got data from their send buffer)

   On link->remoteCommStream: complete the outstanding non-blocking gets, then put 0 to the send signal of
   every remote source rank. Finally make link->stream depend on remoteCommStream (PetscSFLinkBuildDependenceEnd).
*/
static PetscErrorCode PetscSFLinkGetDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas       = (PetscSF_Basic *)sf->data;
  const bool     root2leaf = (direction == PETSCSF_ROOT2LEAF);
  /* ROOT2LEAF: leaf ranks are getting data, so the signals to set live on my root ranks; LEAF2ROOT: the converse */
  PetscInt     nsrcranks  = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  uint64_t    *srcsig     = root2leaf ? link->rootSendSig : link->leafSendSig;     /* their send signal */
  PetscInt    *srcsigdisp = root2leaf ? sf->rootsigdisp_d : bas->leafsigdisp_d;    /* offset of each signal */
  PetscMPIInt *srcranks   = root2leaf ? sf->ranks_d : bas->iranks_d;               /* the source ranks */

  PetscFunctionBegin;
  if (nsrcranks) {
    /* Finish the nonblocking get, so that we can unpack afterwards */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
    PetscCallCUDA(cudaGetLastError());

    /* set signals to 0 */
    NvshmemSendSignals<<<(nsrcranks + 511) / 512, 512, 0, link->remoteCommStream>>>(nsrcranks, srcsig, srcsigdisp, srcranks, 0);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}

479: /* ===========================================================================================================

481:    A set of routines to support sender initiated communication using the put-based method (the default)

483:     The putting protocol is:

485:     Sender has a send buf (sbuf) and a send signal var (ssig);  Receiver has a stand-alone recv buf (rbuf)
486:     and a recv signal var (rsig); All signal variables have an initial value 0. rbuf is allocated by SF and
487:     is in nvshmem space.

489:     Sender:                                 |  Receiver:
490:                                             |
491:   1.  Pack data into sbuf                   |
492:   2.  Wait ssig be 0, then set it to 1      |
493:   3.  Put data to remote stand-alone rbuf   |
494:   4.  Fence // make sure 5 happens after 3  |
495:   5.  Put 1 to receiver's rsig              |   1. Wait rsig to be 1, then set it 0
496:                                             |   2. Unpack data from local rbuf
497:                                             |   3. Put 0 to sender's ssig
498:    ===========================================================================================================*/

/* n thread blocks. Each takes in charge one remote rank, selected by blockIdx.x.

   For a destination rank whose <dst> is NOT directly addressable from this PE (nvshmem_ptr() returns NULL):
   wait until my send signal for that rank is 0, set it to 1, then start a non-blocking put of that rank's
   data. Other ranks are skipped here.

   Input parameters:
    + ndstranks - number of destination ranks (not referenced in the kernel body)
    . dstranks  - the destination ranks
    . dst       - receive buffer address, used on the remote PE
    . dstdisp   - for the i-th rank, write its buffer at element offset dstdisp[i]
    . src       - local send buffer
    . srcdisp   - element offsets into src; srcdisp[0] is subtracted, so it does not need to be 0
    . srcsig    - my send signals, one per destination rank
    - unitbytes - number of bytes per element
*/
__global__ static void WaitAndPutDataToRemotelyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, char *dst, PetscInt *dstdisp, const char *src, PetscInt *srcdisp, uint64_t *srcsig, PetscInt unitbytes)
{
  const int         blk = blockIdx.x;
  const PetscMPIInt pe  = dstranks[blk];

  if (nvshmem_ptr(dst, pe)) return; /* directly addressable PE: nothing to do in this kernel */

  PetscInt nbytes = (srcdisp[blk + 1] - srcdisp[blk]) * unitbytes;

  nvshmem_uint64_wait_until(srcsig + blk, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
  srcsig[blk] = 1;
  nvshmem_putmem_nbi(dst + dstdisp[blk] * unitbytes, src + (srcdisp[blk] - srcdisp[0]) * unitbytes, nbytes, pe);
}

/* one-thread kernel, which takes in charge all locally accessible PEs

   Walks over all destination ranks in order; for each one whose <dst> is directly addressable from this
   PE (nvshmem_ptr() returns non-NULL), waits until my send signal for that rank is 0 and then sets it to 1.
*/
__global__ static void WaitSignalsFromLocallyAccessible(PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *srcsig, const char *dst)
{
  for (int r = 0; r < ndstranks; r++) {
    const int pe = dstranks[r];

    if (!nvshmem_ptr(dst, pe)) continue;                      /* not locally accessible: skip */
    nvshmem_uint64_wait_until(srcsig + r, NVSHMEM_CMP_EQ, 0); /* Wait until the sig = 0 */
    srcsig[r] = 1;
  }
}

/* Put data in the given direction

   Everything is queued on link->remoteCommStream, in this order:
    1. make remoteCommStream depend on link->stream (PetscSFLinkBuildDependenceBegin);
    2. for destination PEs that are not directly addressable: one kernel that, per rank, waits for my send
       signal to be 0, sets it to 1 and starts a non-blocking put;
    3. for destination PEs that are directly addressable (nvshmem_ptr() returns non-NULL): a one-thread
       kernel that waits for/sets their signals, then host-side on-stream non-blocking puts, then a
       quiet on the stream.
   The communication is finished in PetscSFLinkPutDataEnd_NVSHMEM().
*/
static PetscErrorCode PetscSFLinkPutDataBegin_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscInt       ndstranks, nLocallyAccessible = 0;
  char          *src, *dst;
  PetscInt      *srcdisp_h, *dstdisp_h;
  PetscInt      *srcdisp_d, *dstdisp_d;
  PetscMPIInt   *dstranks_h;
  PetscMPIInt   *dstranks_d;
  uint64_t      *srcsig;

  PetscFunctionBegin;
  PetscCall(PetscSFLinkBuildDependenceBegin(sf, link, direction));
  if (direction == PETSCSF_ROOT2LEAF) {                              /* put data in rootbuf to leafbuf  */
    ndstranks = bas->nRemoteLeafRanks;                               /* number of (remote) leaf ranks */
    src       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]; /* Both src & dst must be symmetric */
    dst       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = bas->ioffset + bas->ndiranks; /* offsets of rootbuf. srcdisp[0] is not necessarily zero */
    srcdisp_d = bas->ioffset_d;
    srcsig    = link->rootSendSig;

    dstdisp_h  = bas->leafbufdisp; /* for my i-th remote leaf rank, I will access its leaf buf at offset leafbufdisp[i] */
    dstdisp_d  = bas->leafbufdisp_d;
    dstranks_h = bas->iranks + bas->ndiranks; /* remote leaf ranks */
    dstranks_d = bas->iranks_d;
  } else { /* put data in leafbuf to rootbuf */
    ndstranks = sf->nRemoteRootRanks;
    src       = link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
    dst       = link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];

    srcdisp_h = sf->roffset + sf->ndranks; /* offsets of leafbuf */
    srcdisp_d = sf->roffset_d;
    srcsig    = link->leafSendSig;

    dstdisp_h  = sf->rootbufdisp; /* for my i-th remote root rank, I will access its root buf at offset rootbufdisp[i] */
    dstdisp_d  = sf->rootbufdisp_d;
    dstranks_h = sf->ranks + sf->ndranks; /* remote root ranks */
    dstranks_d = sf->ranks_d;
  }

  /* Wait for signals and then put data to dst ranks using non-blocking nvshmem_put, which are finished in PetscSFLinkPutDataEnd_NVSHMEM */

  /* Count number of locally accessible neighbors, which should be a small number */
  for (int i = 0; i < ndstranks; i++) {
    if (nvshmem_ptr(dst, dstranks_h[i])) nLocallyAccessible++;
  }

  /* For remotely accessible PEs, send data to them in one kernel call */
  if (nLocallyAccessible < ndstranks) {
    WaitAndPutDataToRemotelyAccessible<<<ndstranks, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, dst, dstdisp_d, src, srcdisp_d, srcsig, link->unitbytes);
    PetscCallCUDA(cudaGetLastError());
  }

  /* For locally accessible PEs, use host API, which uses CUDA copy-engines and is much faster than device API */
  if (nLocallyAccessible) {
    WaitSignalsFromLocallyAccessible<<<1, 1, 0, link->remoteCommStream>>>(ndstranks, dstranks_d, srcsig, dst);
    PetscCallCUDA(cudaGetLastError()); /* check the launch, as is done after every other kernel launch in this file */
    for (int i = 0; i < ndstranks; i++) {
      int pe = dstranks_h[i];
      if (nvshmem_ptr(dst, pe)) { /* If return a non-null pointer, then <pe> is locally accessible */
        size_t nelems = (srcdisp_h[i + 1] - srcdisp_h[i]) * link->unitbytes;
        /* Initiate the nonblocking communication */
        nvshmemx_putmem_nbi_on_stream(dst + dstdisp_h[i] * link->unitbytes, src + (srcdisp_h[i] - srcdisp_h[0]) * link->unitbytes, nelems, pe, link->remoteCommStream);
      }
    }
    /* Calling nvshmem_fence/quiet() does not fence the above nvshmemx_putmem_nbi_on_stream! */
    nvshmemx_quiet_on_stream(link->remoteCommStream);
  }
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* A one-thread kernel. The thread takes in charge all remote PEs

   Input parameters:
    + nsrcranks  - number of ranks that put data to me
    . ndstranks  - number of ranks I put data to
    . dstranks   - the ranks I put data to
    . dstsig     - recv signal address; used remotely (with dstsigdisp) when signaling and locally when waiting
    - dstsigdisp - for the i-th dst rank, use its signal at offset dstsigdisp[i]
*/
__global__ static void PutDataEnd(PetscInt nsrcranks, PetscInt ndstranks, PetscMPIInt *dstranks, uint64_t *dstsig, PetscInt *dstsigdisp)
{
  int k;

  /* TODO: Shall we finished the non-blocking remote puts? */

  /* 1. Send a signal to each dst rank */

  /* According to Akhil@NVIDIA, IB is ordered, so no fence is needed for remote PEs.
     For local PEs, we already called nvshmemx_quiet_on_stream(). Therefore, we are good to send signals to all dst ranks now.
  */
  for (k = 0; k < ndstranks; k++) {
    nvshmemx_uint64_signal(dstsig + dstsigdisp[k], 1, dstranks[k]); /* set sig to 1 */
  }

  /* 2. Wait for signals from src ranks (if any) */
  if (!nsrcranks) return;

  /* wait sigs to be 1, then set them to 0 */
  nvshmem_uint64_wait_until_all(dstsig, nsrcranks, NULL /*no mask*/, NVSHMEM_CMP_EQ, 1);
  for (k = 0; k < nsrcranks; k++) dstsig[k] = 0;
}

/* Finish the communication -- A receiver waits until it can access its receive buffer

   Queues the one-thread PutDataEnd kernel on link->remoteCommStream (unless this rank has neither remote
   source nor remote destination ranks): it puts 1 to the recv signal of every destination rank and waits
   for the local recv signals set by the source ranks. Then makes link->stream depend on remoteCommStream
   (PetscSFLinkBuildDependenceEnd).
*/
static PetscErrorCode PetscSFLinkPutDataEnd_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic *bas       = (PetscSF_Basic *)sf->data;
  const bool     root2leaf = (direction == PETSCSF_ROOT2LEAF);
  /* ROOT2LEAF puts root data to leaves: sources are my root ranks, destinations my leaf ranks. LEAF2ROOT is the converse */
  PetscInt     nsrcranks  = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  PetscInt     ndstranks  = root2leaf ? bas->nRemoteLeafRanks : sf->nRemoteRootRanks;
  PetscMPIInt *dstranks   = root2leaf ? bas->iranks_d : sf->ranks_d;
  uint64_t    *dstsig     = root2leaf ? link->leafRecvSig : link->rootRecvSig;   /* I will set my dst ranks's RecvSig */
  PetscInt    *dstsigdisp = root2leaf ? bas->leafsigdisp_d : sf->rootsigdisp_d;  /* for my i-th dst rank, its signal is at this offset */

  PetscFunctionBegin;
  if (nsrcranks || ndstranks) {
    PutDataEnd<<<1, 1, 0, link->remoteCommStream>>>(nsrcranks, ndstranks, dstranks, dstsig, dstsigdisp);
    PetscCallCUDA(cudaGetLastError());
  }
  PetscCall(PetscSFLinkBuildDependenceEnd(sf, link, direction));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/* PostUnpack step of the put protocol: the receiver tells each of its senders that they may put data
   here again (i.e., the receive buffer is free to take new data) by setting their send signals to 0. */
static PetscErrorCode PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM(PetscSF sf, PetscSFLink link, PetscSFDirection direction)
{
  PetscSF_Basic  *bas       = (PetscSF_Basic *)sf->data;
  const PetscBool root2leaf = (direction == PETSCSF_ROOT2LEAF) ? PETSC_TRUE : PETSC_FALSE;
  /* ROOT2LEAF: my root ranks are the senders I unblock; LEAF2ROOT: my leaf ranks are */
  const PetscInt nsenders     = root2leaf ? sf->nRemoteRootRanks : bas->nRemoteLeafRanks;
  uint64_t      *sendsig      = root2leaf ? link->rootSendSig : link->leafSendSig;   /* the senders' send signals I want to set */
  PetscInt      *sigdisp_d    = root2leaf ? sf->rootsigdisp_d : bas->leafsigdisp_d;  /* offset of each sender's signal */
  PetscMPIInt   *senderrank_d = root2leaf ? sf->ranks_d : bas->iranks_d;             /* ranks of the senders */
  const int      nthreads     = 256;

  PetscFunctionBegin;
  if (!nsenders) PetscFunctionReturn(PETSC_SUCCESS); /* nobody to notify */

  /* One thread per sender, rounded up to whole blocks; the kernel sets the remote signals to 0 */
  NvshmemSendSignals<<<(nsenders + nthreads - 1) / nthreads, nthreads, 0, link->remoteCommStream>>>(nsenders, sendsig, sigdisp_d, senderrank_d, 0);
  PetscCallCUDA(cudaGetLastError());
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscSFLinkDestroy_NVSHMEM - Destructor used when the link communicates through NVSHMEM

  Input Parameters:
+ sf   - the star forest (not referenced here)
- link - the link whose NVSHMEM/CUDA resources are released

  Notes:
  Destroys the two CUDA events and the remote communication stream of the link, then frees the remote
  device root/leaf buffers and the four signal arrays with PetscNvshmemFree().
*/
static PetscErrorCode PetscSFLinkDestroy_NVSHMEM(PetscSF sf, PetscSFLink link)
{
  PetscFunctionBegin;
  PetscCallCUDA(cudaEventDestroy(link->dataReady));
  PetscCallCUDA(cudaEventDestroy(link->endRemoteComm));
  PetscCallCUDA(cudaStreamDestroy(link->remoteCommStream));

  /* nvshmem does not need buffers on host, which should be NULL */
  PetscCall(PetscNvshmemFree(link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->leafSendSig));
  PetscCall(PetscNvshmemFree(link->leafRecvSig));
  PetscCall(PetscNvshmemFree(link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
  PetscCall(PetscNvshmemFree(link->rootSendSig));
  PetscCall(PetscNvshmemFree(link->rootRecvSig));
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscSFLinkCreate_NVSHMEM - Get a link that communicates through NVSHMEM, reusing a cached link when one matches

  Input Parameters:
+ sf        - the star forest
. unit      - MPI datatype of one data item
. rootmtype - memory type of rootdata
. rootdata  - root data; also stored in the link as a lookup key
. leafmtype - memory type of leafdata
. leafdata  - leaf data; also stored in the link as a lookup key
. op        - the reduction operation
- sfop      - PETSCSF_BCAST or PETSCSF_REDUCE; anything else is handled as PETSCSF_FETCH

  Output Parameter:
. mylink - the link, already inserted at the head of the in-use list of <sf>

  Notes:
  A cached link is reused if it has use_nvshmem set and its unit matches <unit>; in that case only the remote
  root/leaf buffers and the bookkeeping fields at the end of this routine are (re)assigned. Otherwise a new link
  is created, set up by the host and backend setup routines, given zero-initialized signal arrays, NVSHMEM
  specific callbacks, a highest-priority non-blocking stream and two events without timing.
*/
PetscErrorCode PetscSFLinkCreate_NVSHMEM(PetscSF sf, MPI_Datatype unit, PetscMemType rootmtype, const void *rootdata, PetscMemType leafmtype, const void *leafdata, MPI_Op op, PetscSFOperation sfop, PetscSFLink *mylink)
{
  PetscSF_Basic *bas = (PetscSF_Basic *)sf->data;
  PetscSFLink   *p, link;
  PetscBool      match, rootdirect[2], leafdirect[2];
  int            greatestPriority;

  PetscFunctionBegin;
  /* Check to see if we can directly send/recv root/leafdata with the given sf, sfop and op.
     We only care about root/leafdirect[PETSCSF_REMOTE], since we never need intermediate buffers in local communication with NVSHMEM.
  */
  if (sfop == PETSCSF_BCAST) { /* Move data from rootbuf to leafbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* send buffer has to be stand-alone (can't be rootdata) */
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* Our put-protocol always needs a nvshmem alloc'ed recv buffer */
    }
  } else if (sfop == PETSCSF_REDUCE) { /* Move data from leafbuf to rootbuf */
    if (sf->use_nvshmem_get) {
      rootdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(rootmtype) && bas->rootcontig[PETSCSF_REMOTE] && op == MPI_REPLACE) ? PETSC_TRUE : PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = PETSC_FALSE;
    } else {
      rootdirect[PETSCSF_REMOTE] = PETSC_FALSE;
      leafdirect[PETSCSF_REMOTE] = (PetscMemTypeNVSHMEM(leafmtype) && sf->leafcontig[PETSCSF_REMOTE]) ? PETSC_TRUE : PETSC_FALSE;
    }
  } else {                                    /* PETSCSF_FETCH */
    rootdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* FETCH always needs a separate rootbuf */
    leafdirect[PETSCSF_REMOTE] = PETSC_FALSE; /* We also force allocating a separate leafbuf so that leafdata and leafupdate can share mpi requests */
  }

  /* Look for free nvshmem links in cache */
  for (p = &bas->avail; (link = *p); p = &link->next) {
    if (link->use_nvshmem) {
      PetscCall(MPIPetsc_Type_compare(unit, link->unit, &match));
      if (match) {
        *p = link->next; /* Remove from available list */
        goto found;      /* A cached link is already fully set up; only bind its buffers below */
      }
    }
  }

  /* No reusable link: build a new one from scratch */
  PetscCall(PetscNew(&link));
  PetscCall(PetscSFLinkSetUp_Host(sf, link, unit));                                          /* Compute link->unitbytes, dup link->unit etc. */
  if (sf->backend == PETSCSF_BACKEND_CUDA) PetscCall(PetscSFLinkSetUp_CUDA(sf, link, unit)); /* Setup pack routines, streams etc */
#if defined(PETSC_HAVE_KOKKOS)
  else if (sf->backend == PETSCSF_BACKEND_KOKKOS) PetscCall(PetscSFLinkSetUp_Kokkos(sf, link, unit));
#endif

  link->rootdirect[PETSCSF_LOCAL] = PETSC_TRUE; /* For the local part we directly use root/leafdata */
  link->leafdirect[PETSCSF_LOCAL] = PETSC_TRUE;

  /* Init signals to zero (PetscNvshmemCalloc() returns zeroed memory) */
  if (!link->rootSendSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootSendSig));
  if (!link->rootRecvSig) PetscCall(PetscNvshmemCalloc(bas->nRemoteLeafRanksMax * sizeof(uint64_t), (void **)&link->rootRecvSig));
  if (!link->leafSendSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafSendSig));
  if (!link->leafRecvSig) PetscCall(PetscNvshmemCalloc(sf->nRemoteRootRanksMax * sizeof(uint64_t), (void **)&link->leafRecvSig));

  link->use_nvshmem = PETSC_TRUE;
  link->rootmtype   = PETSC_MEMTYPE_DEVICE; /* Only need 0/1-based mtype from now on */
  link->leafmtype   = PETSC_MEMTYPE_DEVICE;
  /* Overwrite some function pointers set by PetscSFLinkSetUp_CUDA */
  link->Destroy = PetscSFLinkDestroy_NVSHMEM;
  if (sf->use_nvshmem_get) { /* get-based protocol */
    link->PrePack             = PetscSFLinkWaitSignalsOfCompletionOfGettingData_NVSHMEM;
    link->StartCommunication  = PetscSFLinkGetDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkGetDataEnd_NVSHMEM;
  } else { /* put-based protocol */
    link->StartCommunication  = PetscSFLinkPutDataBegin_NVSHMEM;
    link->FinishCommunication = PetscSFLinkPutDataEnd_NVSHMEM;
    link->PostUnpack          = PetscSFLinkSendSignalsToAllowPuttingData_NVSHMEM;
  }

  /* Remote communication runs on its own non-blocking stream with the greatest available priority */
  PetscCallCUDA(cudaDeviceGetStreamPriorityRange(NULL, &greatestPriority));
  PetscCallCUDA(cudaStreamCreateWithPriority(&link->remoteCommStream, cudaStreamNonBlocking, greatestPriority));

  /* Events are used for ordering only, so timing is disabled */
  PetscCallCUDA(cudaEventCreateWithFlags(&link->dataReady, cudaEventDisableTiming));
  PetscCallCUDA(cudaEventCreateWithFlags(&link->endRemoteComm, cudaEventDisableTiming));

found:
  /* Bind the remote root buffer: either point into rootdata, or use a lazily allocated nvshmem buffer */
  if (rootdirect[PETSCSF_REMOTE]) {
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)rootdata + bas->rootstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(bas->rootbuflen_rmax * link->unitbytes, (void **)&link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->rootbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->rootbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  /* Same for the remote leaf buffer */
  if (leafdirect[PETSCSF_REMOTE]) {
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = (char *)leafdata + sf->leafstart[PETSCSF_REMOTE] * link->unitbytes;
  } else {
    if (!link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]) PetscCall(PetscNvshmemMalloc(sf->leafbuflen_rmax * link->unitbytes, (void **)&link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE]));
    link->leafbuf[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE] = link->leafbuf_alloc[PETSCSF_REMOTE][PETSC_MEMTYPE_DEVICE];
  }

  link->rootdirect[PETSCSF_REMOTE] = rootdirect[PETSCSF_REMOTE];
  link->leafdirect[PETSCSF_REMOTE] = leafdirect[PETSCSF_REMOTE];
  link->rootdata                   = rootdata; /* root/leafdata are keys to look up links in PetscSFXxxEnd */
  link->leafdata                   = leafdata;
  link->next                       = bas->inuse; /* push the link on the front of the in-use list */
  bas->inuse                       = link;
  *mylink                          = link;
  PetscFunctionReturn(PETSC_SUCCESS);
}

805: #if defined(PETSC_USE_REAL_SINGLE)
/*
  PetscNvshmemSum - Sum-reduce single-precision values over NVSHMEM_TEAM_WORLD

  Input Parameters:
+ count - number of elements to reduce; must fit in a PetscMPIInt
- src   - source array

  Output Parameter:
. dst - destination array

  Notes:
  The reduction is enqueued on PetscDefaultCudaStream. An error is raised if <count> does not fit in a
  PetscMPIInt or if NVSHMEM reports a nonzero status.
  NOTE(review): NVSHMEM presumably requires dst/src to be symmetric memory -- confirm against the callers.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
  int         status;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  status = nvshmemx_float_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  /* Team-based reductions return zero on success; do not drop a failure silently */
  PetscCheck(!status, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_float_sum_reduce_on_stream() failed with status %d", status);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscNvshmemMax - Max-reduce single-precision values over NVSHMEM_TEAM_WORLD

  Input Parameters:
+ count - number of elements to reduce; must fit in a PetscMPIInt
- src   - source array

  Output Parameter:
. dst - destination array

  Notes:
  The reduction is enqueued on PetscDefaultCudaStream. An error is raised if <count> does not fit in a
  PetscMPIInt or if NVSHMEM reports a nonzero status.
  NOTE(review): NVSHMEM presumably requires dst/src to be symmetric memory -- confirm against the callers.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count, float *dst, const float *src)
{
  PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
  int         status;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  status = nvshmemx_float_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  /* Team-based reductions return zero on success; do not drop a failure silently */
  PetscCheck(!status, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_float_max_reduce_on_stream() failed with status %d", status);
  PetscFunctionReturn(PETSC_SUCCESS);
}
825: #elif defined(PETSC_USE_REAL_DOUBLE)
/*
  PetscNvshmemSum - Sum-reduce double-precision values over NVSHMEM_TEAM_WORLD

  Input Parameters:
+ count - number of elements to reduce; must fit in a PetscMPIInt
- src   - source array

  Output Parameter:
. dst - destination array

  Notes:
  The reduction is enqueued on PetscDefaultCudaStream. An error is raised if <count> does not fit in a
  PetscMPIInt or if NVSHMEM reports a nonzero status.
  NOTE(review): NVSHMEM presumably requires dst/src to be symmetric memory -- confirm against the callers.
*/
PetscErrorCode PetscNvshmemSum(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
  int         status;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  status = nvshmemx_double_sum_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  /* Team-based reductions return zero on success; do not drop a failure silently */
  PetscCheck(!status, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_double_sum_reduce_on_stream() failed with status %d", status);
  PetscFunctionReturn(PETSC_SUCCESS);
}

/*
  PetscNvshmemMax - Max-reduce double-precision values over NVSHMEM_TEAM_WORLD

  Input Parameters:
+ count - number of elements to reduce; must fit in a PetscMPIInt
- src   - source array

  Output Parameter:
. dst - destination array

  Notes:
  The reduction is enqueued on PetscDefaultCudaStream. An error is raised if <count> does not fit in a
  PetscMPIInt or if NVSHMEM reports a nonzero status.
  NOTE(review): NVSHMEM presumably requires dst/src to be symmetric memory -- confirm against the callers.
*/
PetscErrorCode PetscNvshmemMax(PetscInt count, double *dst, const double *src)
{
  PetscMPIInt num; /* Assume nvshmem's int is MPI's int */
  int         status;

  PetscFunctionBegin;
  PetscCall(PetscMPIIntCast(count, &num));
  status = nvshmemx_double_max_reduce_on_stream(NVSHMEM_TEAM_WORLD, dst, src, num, PetscDefaultCudaStream);
  /* Team-based reductions return zero on success; do not drop a failure silently */
  PetscCheck(!status, PETSC_COMM_SELF, PETSC_ERR_LIB, "nvshmemx_double_max_reduce_on_stream() failed with status %d", status);
  PetscFunctionReturn(PETSC_SUCCESS);
}
845: #endif