Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r-- | bitsandbytes/functional.py | 20
1 file changed, 10 insertions, 10 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 076414d..49d4db1 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -130,10 +130,10 @@ class Cusparse_Context(object):
         return cls._instance
 
 
-def create_linear_map(signed=True, bits=8):
+def create_linear_map(signed=True, total_bits=8):
     sign = (-1.0 if signed else 0.0)
 
-    values = torch.linspace(sign, 1.0, 2**bits)
+    values = torch.linspace(sign, 1.0, 2**total_bits)
     gap = 256 - values.numel()
     if gap == 0:
         return values
@@ -457,6 +457,7 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         The quantization state to undo the quantization.
     """
 
+
     if code is None:
         if "dynamic" not in name2qmap:
             name2qmap["dynamic"] = create_dynamic_map().to(A.device)
@@ -474,8 +475,11 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         out = torch.zeros_like(A, dtype=torch.uint8)
 
     if A.device.type != 'cpu':
+        assert blocksize in [4096, 2048, 1024, 512]
         is_on_gpu([code, A, absmax, out, rand])
+        cblocksize = ct.c_int32(blocksize)
         if rand is not None:
+            assert blocksize==4096
            assert rand.numel() >= 1024
            rand_offset = random.randint(0, 1023)
            if A.dtype == torch.float32:
@@ -483,18 +487,14 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
            elif A.dtype == torch.float16:
                lib.cquantize_blockwise_stochastic_fp16(get_ptr(code), get_ptr(A),get_ptr(absmax), get_ptr(out), get_ptr(rand), ct.c_int32(rand_offset), ct.c_int(A.numel()))
            else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
        else:
            if A.dtype == torch.float32:
-                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            elif A.dtype == torch.float16:
-                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out),ct.c_int(A.numel()))
+                lib.cquantize_blockwise_fp16(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), cblocksize, ct.c_int(A.numel()))
            else:
-                raise ValueError(
-                    f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}"
-                )
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    else:
        # cpu
        assert rand is None
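
The change threads a blocksize argument through the CUDA path (as cblocksize), restricts it to 4096, 2048, 1024, or 512, and keeps stochastic rounding (rand) pinned to a blocksize of 4096. The following is a minimal usage sketch of the updated entry points, not code from this commit: it assumes blocksize is exposed as a keyword argument of quantize_blockwise (its signature is truncated in the hunk header above) and that a CUDA device plus the compiled bitsandbytes library are available.

# Sketch only: illustrates the blocksize constraints added in this diff.
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, device="cuda")

# Renamed parameter on create_linear_map (bits -> total_bits); the code table
# must live on the same device as A to pass the is_on_gpu check.
code = F.create_linear_map(signed=True, total_bits=8).to(A.device)

# On the GPU path, blocksize is asserted to be one of 4096, 2048, 1024, 512
# and is forwarded to cquantize_blockwise_fp32/fp16 as a C int32.
out, state = F.quantize_blockwise(A, code=code, blocksize=2048)

# Stochastic rounding (rand is not None) is only asserted for blocksize 4096,
# and rand must hold at least 1024 values.
rand = torch.rand(1024, device="cuda")
out_sr, state_sr = F.quantize_blockwise(A, rand=rand, blocksize=4096)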