diff options
author | Tim Dettmers <dettmers@g3036.hyak.local> | 2021-11-10 15:10:02 -0800 |
---|---|---|
committer | Tim Dettmers <dettmers@g3036.hyak.local> | 2021-11-10 15:10:02 -0800 |
commit | 8b3c0f355c779170d55a1975df981df9e53b59fa (patch) | |
tree | 0ebc5f8e869fb02e7dec90f809fbf07d778f9aca /csrc/ops.cu | |
parent | 22b2877c7f8277317a073ea7cf49231d33fe79fd (diff) |
Added adagrad with tests (no clipping).
Diffstat (limited to 'csrc/ops.cu')
-rw-r--r-- | csrc/ops.cu | 8 |
1 file changed, 8 insertions, 0 deletions
diff --git a/csrc/ops.cu b/csrc/ops.cu index 182d6e6..9691241 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -199,6 +199,8 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: + if(max_unorm > 0.0f) { CUDA_CHECK_RETURN(cudaMemset(unorm, 0, 1*sizeof(float))); @@ -240,6 +242,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, break; case MOMENTUM: case RMSPROP: + case ADAGRAD: CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float))); kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<blocks, 256>>>(p, g, state1, unorm, beta1, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); @@ -274,6 +277,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g break; case MOMENTUM: case RMSPROP: + case ADAGRAD: blocks = n/BLOCKSIZE_1STATE; blocks = n % BLOCKSIZE_1STATE == 0 ? blocks : blocks + 1; kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr, @@ -321,6 +325,8 @@ MAKE_optimizer32bit(MOMENTUM, half) MAKE_optimizer32bit(MOMENTUM, float) MAKE_optimizer32bit(RMSPROP, half) MAKE_optimizer32bit(RMSPROP, float) +MAKE_optimizer32bit(ADAGRAD, half) +MAKE_optimizer32bit(ADAGRAD, float) #define MAKE_optimizerStatic8bit(name, gtype) \ template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \ @@ -350,6 +356,8 @@ MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM); MAKE_optimizerStatic8bitBlockwise(half, RMSPROP); MAKE_optimizerStatic8bitBlockwise(float, RMSPROP); +MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD); +MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD); template void percentileClipping(float * g, float *gnorm_vec, int step, const int n); template void percentileClipping(half 
* g, float *gnorm_vec, int step, const int n); |