diff options
author | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 19:15:47 -0700 |
---|---|---|
committer | Tim Dettmers <dettmers@cs.washington.edu> | 2021-10-20 19:15:47 -0700 |
commit | a6eae2e7f2bf03f268fcb6b055201ff6827684c4 (patch) | |
tree | d2f72792251c9feaef1cf9dcddc3c79e6312a93a /bitsandbytes/optim | |
parent | bb34fd50a1fec74e62beb6e23d51f0142c7d0ab6 (diff) |
Added skip_zeros; tests are passing.
Diffstat (limited to 'bitsandbytes/optim')
-rw-r--r-- | bitsandbytes/optim/optimizer.py | 8 |
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 25512b1..4b70b5c 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -336,7 +336,7 @@ class Optimizer2State(Optimizer8bit):
         if state['state1'].dtype == torch.float:
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     state['state2'], config['betas'][1], config['weight_decay'], gnorm_scale,
-                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'])
+                    state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'], skip_zeros=config['skip_zeros'])
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
@@ -352,7 +352,7 @@ class Optimizer2State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], state['state2'], config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'], state['qmap1'], state['qmap2'], state['absmax1'], state['absmax2'],
-                    config['weight_decay'], gnorm_scale=gnorm_scale)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])


 class Optimizer1State(Optimizer8bit):
@@ -450,7 +450,7 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_32bit(self.optimizer_name, grad, p, state['state1'], config['betas'][0], config['eps'], step, config['lr'],
                     None, 0.0, config['weight_decay'], gnorm_scale,
                     state['unorm_vec'] if config['max_unorm'] > 0.0 else None, max_unorm=config['max_unorm'],
-                    skip_zeros=False)
+                    skip_zeros=config['skip_zeros'])
         elif state['state1'].dtype == torch.uint8 and not config['block_wise']:
             F.optimizer_update_8bit(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
@@ -463,4 +463,4 @@ class Optimizer1State(Optimizer8bit):
             F.optimizer_update_8bit_blockwise(self.optimizer_name, grad, p, state['state1'], None, config['betas'][0], config['betas'][1],
                     config['eps'], step, config['lr'], state['qmap1'], None, state['absmax1'], None,
-                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=False)
+                    config['weight_decay'], gnorm_scale=gnorm_scale, skip_zeros=config['skip_zeros'])