From ee325f02157cd23b37059e3dce5fb17cb1c1b137 Mon Sep 17 00:00:00 2001
From: dbaranchuk <dmitrybaranchuk@gmail.com>
Date: Sun, 11 Sep 2022 06:18:44 +0300
Subject: clarified an exception message

---
 bitsandbytes/autograd/_functions.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

(limited to 'bitsandbytes/autograd')

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 271c690..008655d 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -373,7 +373,7 @@ class MatMul8bitLt(torch.autograd.Function):
                 grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            if state.CBt:
+            if state.CBt is not None:
                 C32grad, Sgrad = F.transform(Cgrad, "col32")
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(
@@ -381,13 +381,13 @@ class MatMul8bitLt(torch.autograd.Function):
                     )
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
                 grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
-            elif state.CB:
+            elif state.CB is not None:
                 CB = state.CB.half()
                 SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                 CB *= SCB
                 grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
             else:
-                raise Exception('State must contain either CBt or CB matrix')
+                raise Exception('State must contain either CBt or CB matrix for backward')
 
         if req_gradBias:
             grad_bias = grad_output.sum(0)
-- 
cgit v1.2.3