diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-07-26 19:38:01 -0700 |
commit | 5737f2b027a1e0ec8540a3aa914632d44ad9c62d (patch) | |
tree | b288c905eaba75dc6b43a8bcebc82720c16e4816 /csrc/kernels.cu | |
parent | 47a73d94c3d3284f6073b0ff189ed5bc9e3a8762 (diff) | |
parent | dc8c9efdb33130f960adc864916b67d0cb744dbb (diff) |
Merge branch 'patch_merge' into extract_outliers
Diffstat (limited to 'csrc/kernels.cu')
-rw-r--r-- | csrc/kernels.cu | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 79ad5de..d4eb56c 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1768,7 +1768,6 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_ __shared__ float smem_row_absmax_values[ITEMS_PER_THREAD*THREADS]; __shared__ int smem_row_nnz_values[TILE_ROWS]; - //__shared__ float smem_col_absmax_values[ITEMS_PER_THREAD*THREADS]; half local_data[ITEMS_PER_THREAD]; float local_data_fp32[ITEMS_PER_THREAD]; @@ -1828,13 +1827,14 @@ template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_ local_col_absmax_values[j] = fmaxf(local_col_absmax_values[j], __half2float(local_data[j])); // 3. compute row max (per block); store in smem to accumulate full global mem transation - __syncthreads(); // this is slow as it uses extra registers, but we need this to be compatible with Kepler and Maxwell (no fp16 units) #pragma unroll ITEMS_PER_THREAD for(int j = 0; j < ITEMS_PER_THREAD; j++) local_data_fp32[j] = local_data[j]; + __syncthreads(); + row_absmax = (float)BlockRowReduce(temp_storage.rowreduce).Reduce(local_data_fp32, cub::Max()); if(SPARSE_DECOMP) { @@ -2166,7 +2166,6 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T __shared__ char smem_data[32*33*ITEMS_PER_THREAD]; char local_data[ITEMS_PER_THREAD]; typedef cub::BlockExchange<char, THREADS, ITEMS_PER_THREAD> BlockExchange; - __shared__ typename BlockExchange::TempStorage temp_storage; // we load row after row from the base_position // Load data row by row @@ -2446,7 +2445,7 @@ template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int T #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 template <typename T, int SPMM_ITEMS, int BITS> -__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz, int rowsA, int rowsB, int colsB) +__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) { // 0. load balancing: We process rows with most columns first (count_vec)and we process one row per block @@ -2500,7 +2499,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o { for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x) if((idx_col_B+i-local_idx_col_B_offset) < colsB) - smem_dequant_stats[i] = __ldg(&dequant_stats[idx_col_B+i-local_idx_col_B_offset]); + smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset]; __syncthreads(); } |