diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-03 19:49:50 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-11-03 19:49:50 -0700 |
commit | 1efb87d89d1c3fe532eb97847c3b48fd1a8e5d83 (patch) | |
tree | dd6b1ca29464d6c419b5c169f3d5ea946e7fce50 /bitsandbytes/functional.py | |
parent | 8d87c0b85214c07756b5dcdb09ceb26b0bb1cb7a (diff) |
Added FP8 quantization map.
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r-- | bitsandbytes/functional.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index c104ebd..d7e186f 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -6,6 +6,7 @@ import ctypes as ct import operator import random import torch +import itertools from typing import Tuple from torch import Tensor @@ -136,6 +137,39 @@ def create_linear_map(signed=True): return torch.linspace(0.0, 1.0, 256) +def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2): + e = exponent_bits + p = precision_bits + assert e+p == 7 + # the exponent is biased to 2^(e-1) -1 == 0 + evalues = [] + pvalues = [] + for i, val in enumerate(range(-((2**(exponent_bits-1))), 2**(exponent_bits-1), 1)): + evalues.append(2**val) + + + lst = list(itertools.product([0, 1], repeat=precision_bits)) + for bit_pattern in lst: + value = 1 + for i, pval in enumerate(list(bit_pattern)): + value += pval*(2**-(i+1)) + pvalues.append(value) + + assert len(evalues)*len(pvalues) == 128 + values = [] + for ev in evalues: + for pv in pvalues: + values.append(-ev*pv) + values.append(ev*pv) + values.sort() + code = torch.Tensor(values) + code /= code.max() + code[127] = 0 + + return code + + + def create_dynamic_map(signed=True, n=7): """ Creates the dynamic quantiztion map. |