author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 08:03:00 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 08:03:00 -0700 |
commit | 758c7175a24df307c40b743b1def8b4c34f68674 (patch) | |
tree | d7046117149950c2e97a5af6bd99d87f7688a357 | /bitsandbytes/autograd |
parent | 96bc209baf55f2e05e649e555c2de5fc478c24dc (diff) | |
parent | ab72a1294fda03a0fd4ec297562fdab806349752 (diff) | |
Merge branch 'debug' into cuda-bin-switch-and-cli
Diffstat (limited to 'bitsandbytes/autograd')
-rw-r--r-- | bitsandbytes/autograd/_functions.py | 17 |
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index b56b2ee..14f2660 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 
 import torch
-
+import math
 import bitsandbytes as bnb
 import bitsandbytes.functional as F
 
@@ -199,6 +199,17 @@ class MatmulLtState:
 class MatMul8bitLt(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, state=MatmulLtState()):
+        # default to pytorch behavior if inputs are empty
+        ctx.is_empty = False
+        if math.prod(A.shape) == 0:
+            ctx.is_empty = True
+            ctx.A = A
+            ctx.B = B
+            if A.shape[-1] == B.shape[0]:
+                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
+            else:
+                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)
+
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -339,6 +350,8 @@ class MatMul8bitLt(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
+        if ctx.is_empty:
+            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
         req_gradA, req_gradB = ctx.req_grads
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 ctx.grad_shape
             )
 
-        return grad_A, grad_B, None, None, None, None, None
+        return grad_A, grad_B, None, None
 
 
 matmul = MatMul8bitLt.apply
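Net effect of the patch: `forward` now short-circuits when `A` has zero elements, returning an empty fp16 tensor with the shape a regular matmul (or a matmul against a transposed `B`) would produce, and `backward` mirrors this with all-zero gradients. The final return is also trimmed to four values, one per `forward` argument (`A`, `B`, `out`, `state`), as `torch.autograd.Function` requires. Below is a minimal self-contained sketch of the same pattern, with a plain float matmul standing in for the 8-bit kernels (`ToyMatMul` is illustrative only, not bitsandbytes code):

```python
import math

import torch


class ToyMatMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=None):
        # Same guard as the patch: fall back to plain-PyTorch behavior for
        # empty inputs instead of invoking the quantized kernels.
        ctx.is_empty = False
        if math.prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            # (The real patch pins dtype=torch.float16 here; omitted so the
            # sketch runs on CPU with default float32 throughout.)
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1] + B.shape[1:], device=A.device)
            return torch.empty(A.shape[:-1] + B.shape[:1], device=A.device)
        ctx.save_for_backward(A, B)
        return A @ B  # stand-in for the 8-bit matmul

    @staticmethod
    def backward(ctx, grad_output):
        # Exactly one gradient per forward argument: A, B, out, state.
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
        A, B = ctx.saved_tensors
        return grad_output @ B.t(), A.t() @ grad_output, None, None


A = torch.empty(0, 4, requires_grad=True)
B = torch.randn(4, 8, requires_grad=True)
out = ToyMatMul.apply(A, B)   # shape (0, 8); no kernel work done
out.sum().backward()
print(out.shape, A.grad.shape, B.grad.shape)
# torch.Size([0, 8]) torch.Size([0, 4]) torch.Size([4, 8])
```

The `else` branch (`B.shape[:1]`) covers the layout where `B` is stored transposed, so its first dimension is the output dimension; the zero-filled gradients in `backward` keep autograd consistent while skipping quantization entirely.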