// cudamatrix/cu-common.h

// Copyright 2009-2011  Karel Vesely
//                      Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.


#ifndef KALDI_CUDAMATRIX_CU_COMMON_H_
#define KALDI_CUDAMATRIX_CU_COMMON_H_

#include <iostream>
#include <sstream>

#include "base/kaldi-error.h"
#include "cudamatrix/cu-matrixdim.h" // for CU1DBLOCK and CU2DBLOCK
#include "matrix/matrix-common.h"

#if HAVE_CUDA

#ifdef __IS_HIP_COMPILE__
#include <hip/hip_runtime_api.h>
#include <hipblas/hipblas.h>
#include <hiprand/hiprand.h>
#include <hipsparse/hipsparse.h>
#include <roctracer/roctx.h>

#include "hipify.h"
#else
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <curand.h>
#include <cusparse.h>
#include <nvtx3/nvToolsExt.h>

// Compile-time GPU limits used on this (non-HIP) branch of the build.
// NOTE(review): these are only defined here when __IS_HIP_COMPILE__ is not
// set; the HIP build presumably gets equivalents from "hipify.h" -- confirm.

// Warp width assumed by code that uses this constant.
#define GPU_WARP_SIZE 32
// Largest thread count per block assumed by code that uses this constant.
#define GPU_MAX_THREADS_PER_BLOCK 1024
// Derived from the two constants above: 1024 / 32 = 32 warps per block.
#define GPU_MAX_WARPS_PER_BLOCK (GPU_MAX_THREADS_PER_BLOCK / GPU_WARP_SIZE)
#endif

// Evaluates 'fun' (an expression yielding a CUDA runtime status code) exactly
// once.  If the result is non-zero, it is reported through KALDI_ERR together
// with the text returned by cudaGetErrorString() and the stringified call.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// The macro expands to a brace-enclosed block, so the ';' written after it at
// a call site is a separate empty statement; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CU_SAFE_CALL(fun) \
{ \
  const int32 cu_safe_call_ret_ = (fun); \
  if (cu_safe_call_ret_ != 0) { \
    KALDI_ERR << "cudaError_t " << cu_safe_call_ret_ << " : \"" << cudaGetErrorString((cudaError_t)cu_safe_call_ret_) << "\" returned from '" << #fun << "'"; \
  } \
}

// Evaluates 'fun' (an expression yielding a cuFFT status code) exactly once.
// If the result differs from CUFFT_SUCCESS, the numeric code and the
// stringified call are reported through KALDI_ERR.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// Note: this file does not directly include a cuFFT header, so code that uses
// this macro must make CUFFT_SUCCESS visible itself.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CUFFT_SAFE_CALL(fun) \
{ \
  const int32 cufft_safe_call_ret_ = (fun); \
  if (cufft_safe_call_ret_ != CUFFT_SUCCESS) { \
    KALDI_ERR << "cufftResult " << cufft_safe_call_ret_ << " returned from '" << #fun << "'"; \
  } \
}

// Evaluates 'fun' (an expression yielding a cuBLAS status code) exactly once.
// If the result is non-zero, it is reported through KALDI_ERR together with
// the text from cublasGetStatusStringK() and the stringified call.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// cublasGetStatusStringK is declared further down in this header inside
// namespace kaldi, so the name must be reachable where the macro is expanded.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CUBLAS_SAFE_CALL(fun) \
{ \
  const int32 cublas_safe_call_ret_ = (fun); \
  if (cublas_safe_call_ret_ != 0) { \
    KALDI_ERR << "cublasStatus_t " << cublas_safe_call_ret_ << " : \"" << cublasGetStatusStringK((cublasStatus_t)cublas_safe_call_ret_) << "\" returned from '" << #fun << "'"; \
  } \
}

// Evaluates 'fun' (an expression yielding a cuSOLVER status code) exactly
// once.  If the result is non-zero, it is reported through KALDI_ERR with the
// stringified call.  No status-to-string helper is used here, so the numeric
// code appears in both positions of the message.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CUSOLVER_SAFE_CALL(fun) \
{ \
  const int32 cusolver_safe_call_ret_ = (fun); \
  if (cusolver_safe_call_ret_ != 0) { \
    KALDI_ERR << "cusolverStatus_t " << cusolver_safe_call_ret_ << " : \"" << cusolver_safe_call_ret_ << "\" returned from '" << #fun << "'"; \
  } \
}


// Evaluates 'fun' (an expression yielding a cuSPARSE status code) exactly
// once.  If the result is non-zero, it is reported through KALDI_ERR together
// with the text from cusparseGetStatusString() and the stringified call.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// cusparseGetStatusString is declared further down in this header inside
// namespace kaldi, so the name must be reachable where the macro is expanded.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CUSPARSE_SAFE_CALL(fun) \
{ \
  const int32 cusparse_safe_call_ret_ = (fun); \
  if (cusparse_safe_call_ret_ != 0) { \
    KALDI_ERR << "cusparseStatus_t " << cusparse_safe_call_ret_ << " : \"" << cusparseGetStatusString((cusparseStatus_t)cusparse_safe_call_ret_) << "\" returned from '" << #fun << "'"; \
  } \
}

// Evaluates 'fun' (an expression yielding a cuRAND status code) exactly once.
// If the result is non-zero, it is reported through KALDI_ERR together with
// the text from curandGetStatusString() and the stringified call.
//
// The temporary has a macro-specific name so that it cannot capture or shadow
// a caller variable (for example one called 'ret') that appears inside 'fun'.
//
// curandGetStatusString is declared further down in this header inside
// namespace kaldi, so the name must be reachable where the macro is expanded.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define CURAND_SAFE_CALL(fun) \
{ \
  const int32 curand_safe_call_ret_ = (fun); \
  if (curand_safe_call_ret_ != 0) { \
    KALDI_ERR << "curandStatus_t " << curand_safe_call_ret_ << " : \"" << curandGetStatusString((curandStatus_t)curand_safe_call_ret_) << "\" returned from '" << #fun << "'"; \
  } \
}

// Reports a CUDA error code with a caller-supplied message.
//
//   ret  expression yielding a CUDA runtime status code.  It is evaluated
//        exactly once and parenthesised, so passing a call expression does
//        not re-run the call and passing an operator expression parses as
//        intended.
//   msg  streamed into KALDI_ERR ahead of the diagnostics.  It is deliberately
//        left unparenthesised so that a chain such as "text " << value is
//        accepted.
//
// Nothing happens when the code is zero.  Otherwise the message, the numeric
// code, the text from cudaGetErrorString() and the expansion site
// (__FILE__:__LINE__) are sent to KALDI_ERR.
//
// The macro expands to a brace-enclosed block; do not use it as the unbraced
// body of an 'if' that has an 'else'.
#define KALDI_CUDA_ERR(ret, msg) \
{ \
  const int32 kaldi_cuda_err_code_ = (ret); \
  if (kaldi_cuda_err_code_ != 0) { \
    KALDI_ERR << msg << ", diagnostics: cudaError_t " << kaldi_cuda_err_code_ << " : \"" << cudaGetErrorString((cudaError_t)kaldi_cuda_err_code_) << "\", in " << __FILE__ << ":" << __LINE__; \
  } \
}


namespace kaldi {

#ifdef USE_NVTX
// Scoped tracing object: constructed with a range name and torn down by its
// destructor.  Only the declaration is visible here; the constructor and
// destructor are defined elsewhere.
// NOTE(review): presumably the constructor opens an NVTX range and the
// destructor closes it -- confirm against the definition.
class NvtxTracer {
public:
    NvtxTracer(const char* name);
    ~NvtxTracer();
};

// Two-level token pasting so that __LINE__ is expanded before concatenation.
#define KALDI_NVTX_CONCAT_IMPL_(a, b) a##b
#define KALDI_NVTX_CONCAT_(a, b) KALDI_NVTX_CONCAT_IMPL_(a, b)

// Declares a local NvtxTracer that lives until the end of the enclosing scope.
// The variable name carries the source line number, so several NVTX_RANGE
// uses may appear in the same scope (on different lines) without a
// redefinition error.  The class name is fully qualified so the macro also
// works outside namespace kaldi.  The trailing ';' is part of the expansion,
// matching the empty variant used when tracing is disabled.
#define NVTX_RANGE(name) \
  ::kaldi::NvtxTracer KALDI_NVTX_CONCAT_(uniq_name_using_macros, __LINE__)(name);
#else
// Tracing disabled: NVTX_RANGE expands to nothing.
#define NVTX_RANGE(name)
#endif

/** Number of blocks into which a task of size 'size' is split when each block
    handles 'block_size' items: the integer quotient, plus one extra block
    whenever the division leaves a remainder.  'block_size' must be non-zero. **/
inline int32 n_blocks(int32 size, int32 block_size) {
  const int32 full_blocks = size / block_size;
  const bool has_partial_block = (size % block_size) != 0;
  if (has_partial_block) {
    return full_blocks + 1;
  }
  return full_blocks;
}

/** Takes a Kaldi MatrixTransposeType and returns a cublasOperation_t.  Only
    the declaration is visible here; the definition lives in another file.
    NOTE(review): presumably this maps the Kaldi transpose flag onto the
    matching cuBLAS operation constant -- confirm against the definition. **/
cublasOperation_t KaldiTransToCuTrans(MatrixTransposeType kaldi_trans);


/*
  This function gives you suitable dimBlock and dimGrid sizes for a simple
  matrix operation (one that applies to each element of the matrix).  The x
  indexes will be interpreted as column indexes, and the y indexes will be
  interpreted as row indexes; this is based on our interpretation of a matrix
  as being row-major, i.e. having column-stride = 1, not based on CuBLAS's
  opposite interpretation.  There is a good reason for associating the column
  index with x and not y; this helps memory locality in adjacent kernels.

  num_rows, num_cols: dimensions of the matrix the operation is applied to.
  dimGrid, dimBlock:  pointers through which the chosen sizes are handed back.
  NOTE(review): presumably both pointers must be non-NULL -- confirm against
  the definition, which is not in this file.
 */
void GetBlockSizesForSimpleMatrixOperation(int32 num_rows,
                                           int32 num_cols,
                                           dim3 *dimGrid,
                                           dim3 *dimBlock);

/** Takes a cublasStatus_t and returns a C string for it; this is analogous to
    the CUDA function cudaGetErrorString().  Used by CUBLAS_SAFE_CALL to build
    its error message.  Declared here only; defined elsewhere.
    NOTE(review): the trailing 'K' presumably avoids a clash with a
    same-named function in the cuBLAS library itself -- confirm. **/
const char* cublasGetStatusStringK(cublasStatus_t status);

/** Takes a cusparseStatus_t and returns a C string for it; this is analogous
    to the CUDA function cudaGetErrorString().  Used by CUSPARSE_SAFE_CALL to
    build its error message.  Declared here only; defined elsewhere. **/
const char* cusparseGetStatusString(cusparseStatus_t status);

/** Takes a curandStatus_t and returns a C string for it; this is analogous to
    the CUDA function cudaGetErrorString().  Used by CURAND_SAFE_CALL to build
    its error message.  Declared here only; defined elsewhere. **/
const char* curandGetStatusString(curandStatus_t status);

}  // namespace kaldi

#else  // HAVE_CUDA
// Built without CUDA: NVTX_RANGE expands to nothing.
#define NVTX_RANGE(name)
#endif  // HAVE_CUDA

namespace kaldi {
// Some forward declarations, needed for friend declarations.
// These sit outside the HAVE_CUDA conditional, so they are visible whether or
// not CUDA support is compiled in.  Each class is a template over the element
// type 'Real'.

// Vector types.
template<typename Real> class CuVectorBase;
template<typename Real> class CuVector;
template<typename Real> class CuSubVector;
// Random number generation.
template<typename Real> class CuRand;
// General matrix types.
template<typename Real> class CuMatrixBase;
template<typename Real> class CuMatrix;
template<typename Real> class CuSubMatrix;
// Packed and sparse matrix types.
template<typename Real> class CuPackedMatrix;
template<typename Real> class CuSpMatrix;
template<typename Real> class CuTpMatrix;
template<typename Real> class CuSparseMatrix;

template<typename Real> class CuBlockMatrix; // this has no non-CU counterpart.

}  // namespace kaldi

#endif  // KALDI_CUDAMATRIX_CU_COMMON_H_
