Actual source code: curand.c

  1: #include <petsc/private/deviceimpl.h>
  2: #include <petsc/private/randomimpl.h>
  3: #include <petscdevice_cuda.h>

  5: typedef struct {
  6:   curandGenerator_t gen;
  7: } PetscRandom_CURAND;

  9: static PetscErrorCode PetscRandomSeed_CURAND(PetscRandom r)
 10: {
 11:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

 13:   PetscFunctionBegin;
 14:   PetscCallCURAND(curandSetPseudoRandomGeneratorSeed(curand->gen, r->seed));
 15:   PetscFunctionReturn(PETSC_SUCCESS);
 16: }

 18: PETSC_INTERN PetscErrorCode PetscRandomCurandScale_Private(PetscRandom, size_t, PetscReal *, PetscBool);

 20: static PetscErrorCode PetscRandomGetValuesReal_CURAND(PetscRandom r, PetscInt n, PetscReal *val)
 21: {
 22:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;
 23:   size_t              nn     = n < 0 ? (size_t)(-2 * n) : (size_t)n; /* handle complex case */

 25:   PetscFunctionBegin;
 26: #if defined(PETSC_USE_REAL_SINGLE)
 27:   PetscCallCURAND(curandGenerateUniform(curand->gen, val, nn));
 28: #else
 29:   PetscCallCURAND(curandGenerateUniformDouble(curand->gen, val, nn));
 30: #endif
 31:   if (r->iset) PetscCall(PetscRandomCurandScale_Private(r, nn, val, (PetscBool)(n < 0)));
 32:   PetscFunctionReturn(PETSC_SUCCESS);
 33: }

 35: static PetscErrorCode PetscRandomGetValues_CURAND(PetscRandom r, PetscInt n, PetscScalar *val)
 36: {
 37:   PetscFunctionBegin;
 38: #if defined(PETSC_USE_COMPLEX)
 39:   /* pass negative size to flag complex scaling (if needed) */
 40:   PetscCall(PetscRandomGetValuesReal_CURAND(r, -n, (PetscReal *)val));
 41: #else
 42:   PetscCall(PetscRandomGetValuesReal_CURAND(r, n, val));
 43: #endif
 44:   PetscFunctionReturn(PETSC_SUCCESS);
 45: }

 47: static PetscErrorCode PetscRandomDestroy_CURAND(PetscRandom r)
 48: {
 49:   PetscRandom_CURAND *curand = (PetscRandom_CURAND *)r->data;

 51:   PetscFunctionBegin;
 52:   PetscCallCURAND(curandDestroyGenerator(curand->gen));
 53:   PetscCall(PetscFree(r->data));
 54:   PetscFunctionReturn(PETSC_SUCCESS);
 55: }

 57: static struct _PetscRandomOps PetscRandomOps_Values = {
 58:   PetscDesignatedInitializer(seed, PetscRandomSeed_CURAND),
 59:   PetscDesignatedInitializer(getvalue, NULL),
 60:   PetscDesignatedInitializer(getvaluereal, NULL),
 61:   PetscDesignatedInitializer(getvalues, PetscRandomGetValues_CURAND),
 62:   PetscDesignatedInitializer(getvaluesreal, PetscRandomGetValuesReal_CURAND),
 63:   PetscDesignatedInitializer(destroy, PetscRandomDestroy_CURAND),
 64: };

 66: /*MC
   PETSCCURAND - access to the NVIDIA cuRAND pseudorandom number generator from a `PetscRandom` object

 69:   Level: beginner

 71:   Note:
  This random number generator is available only when PETSc is configured with CUDA support, e.g. ``./configure --with-cuda=1``.

 74: .seealso: `PetscRandomCreate()`, `PetscRandomSetType()`, `PetscRandomType`
 75: M*/

 77: PETSC_EXTERN PetscErrorCode PetscRandomCreate_CURAND(PetscRandom r)
 78: {
 79:   PetscRandom_CURAND *curand;
 80:   PetscDeviceContext  dctx;
 81:   cudaStream_t       *stream;

 83:   PetscFunctionBegin;
 84:   PetscCall(PetscDeviceInitialize(PETSC_DEVICE_CUDA));
 85:   PetscCall(PetscDeviceContextGetCurrentContextAssertType_Internal(&dctx, PETSC_DEVICE_CUDA));
 86:   PetscCall(PetscDeviceContextGetStreamHandle(dctx, (void **)&stream));
 87:   PetscCall(PetscNew(&curand));
 88:   PetscCallCURAND(curandCreateGenerator(&curand->gen, CURAND_RNG_PSEUDO_DEFAULT));
 89:   PetscCallCURAND(curandSetStream(curand->gen, *stream));
 90:   /* https://docs.nvidia.com/cuda/curand/host-api-overview.html#performance-notes2 */
 91:   PetscCallCURAND(curandSetGeneratorOrdering(curand->gen, CURAND_ORDERING_PSEUDO_SEEDED));
 92:   r->ops[0] = PetscRandomOps_Values;
 93:   PetscCall(PetscObjectChangeTypeName((PetscObject)r, PETSCCURAND));
 94:   r->data = curand;
 95:   PetscCall(PetscRandomSeed_CURAND(r));
 96:   PetscFunctionReturn(PETSC_SUCCESS);
 97: }