diff options
author | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 18:37:44 -0700 |
---|---|---|
committer | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 18:37:44 -0700 |
commit | bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (patch) | |
tree | a01ed945c348027480a9d0cefb6698dfd7259fb1 /csrc/ops.cuh | |
parent | 8400b58cbbc06e0a434cfa71f76c2efd713473fc (diff) |
Initial plumbing for skip_zeros.
Diffstat (limited to 'csrc/ops.cuh')
-rw-r--r-- | csrc/ops.cuh | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/csrc/ops.cuh b/csrc/ops.cuh index e6033cb..465b4a4 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -49,7 +49,7 @@ template<typename T> void dequantizeBlockwise(float *code, unsigned char *A, flo template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, float beta1, float beta2, float eps, float weight_decay, - int step, float lr, const float gnorm_scale, int n); + int step, float lr, const float gnorm_scale, bool skip_zeros, int n); template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigned char* state1, unsigned char* state2, float *unorm, float max_unorm, float param_norm, @@ -62,7 +62,8 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, unsigne template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, int n); + float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, + bool skip_zeros, int n); template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n); |