diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-22 14:41:05 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-22 14:41:05 -0700 |
commit | c771b3a75a6ebbfbfc398a028a477246b0799cf0 (patch) | |
tree | 158353d531766ed133be34d3c5085da6e8a4d01e /csrc/kernels.cuh | |
parent | 4cd7ea62b2f51c68aacde2f62e7141765e476111 (diff) |
Most tests passing.
Diffstat (limited to 'csrc/kernels.cuh')
-rw-r--r-- | csrc/kernels.cuh | 12 |
1 files changed, 12 insertions, 0 deletions
```diff
diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh
index 0a3676c..cbfbeba 100644
--- a/csrc/kernels.cuh
+++ b/csrc/kernels.cuh
@@ -106,6 +106,18 @@ template<typename T, int BLOCK_SIZE, int NUM_VALS> __global__ void kPercentileCl
 __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index2, float *src, const int maxidx1, const int n);
+
+template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
+
+template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
+ int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
+ half *out, float* newRowStats, float* newcolStats, const int numRows, const int numCols, const int tileCols, const int n);
+
+template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+
+template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
+
 #endif
```