Diffstat (limited to 'tests/test_autograd.py')
-rw-r--r--  tests/test_autograd.py  273
1 file changed, 198 insertions, 75 deletions
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index 1b6c2ab..8ebe8c8 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -1,27 +1,44 @@
-import pytest
+from itertools import product
+
+import pytest
 import torch
-import bitsandbytes as bnb
-from itertools import product
+
+import bitsandbytes as bnb
 
 n = 1
 k = 25
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
 funcs = [(torch.bmm, bnb.bmm_cublas), (torch.matmul, bnb.matmul_cublas)]
-str_funcs = ['bmm', 'matmul']
+str_funcs = ["bmm", "matmul"]
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, False), (False, True), (True, True), (True, False)]
-str_transpose = ['FF', 'FT', 'TT', 'TF']
+str_transpose = ["FF", "FT", "TT", "TF"]
 dtype = [torch.float32, torch.float16]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose", values, ids=names)
+values = list(
+    product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose)
+)
+str_values = list(
+    product(
+        dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose
+    )
+)
+names = [
+    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}".format(
+        *vals
+    )
+    for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose",
+    values,
+    ids=names,
+)
 def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
     if dim2 > 0:
         dim2 = dim2 - (dim2 % 16)
@@ -33,9 +50,11 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
         if funcs[0] in [torch.mm, torch.matmul]:
             dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
             dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0])
-            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
-            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1])
+            A = torch.randn(size=dimA, device="cuda", requires_grad=req_grad[0])
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+            target = torch.randn(
+                size=(dim2, dim4), device="cuda", requires_grad=req_grad[1]
+            )
             torch.nn.init.xavier_uniform_(B)
 
             if not transpose[0] and not transpose[1]:
@@ -53,9 +72,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
 
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx==0).sum().item() < n*0.0175
+            assert (idx == 0).sum().item() < n * 0.0175
            idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-            assert (idx==0).sum().item() < n*0.001
+            assert (idx == 0).sum().item() < n * 0.001
 
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
@@ -67,7 +86,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch = torch.nn.functional.mse_loss(
+                    out_torch, target
+                ).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -75,20 +96,36 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.1
+                assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.02
-                torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                assert (idx == 0).sum().item() < n * 0.02
+                torch.testing.assert_allclose(
+                    gradB1, gradB2, atol=0.18, rtol=0.3
+                )
 
         # batched matrix multiply
         if funcs[0] in [torch.bmm, torch.matmul]:
-            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
-            B = torch.randn(size=(dim1, dim3, dim4), device='cuda', requires_grad=req_grad[1])
-            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+            A = torch.randn(
+                size=(dim1, dim2, dim3),
+                device="cuda",
+                requires_grad=req_grad[0],
+            )
+            B = torch.randn(
+                size=(dim1, dim3, dim4),
+                device="cuda",
+                requires_grad=req_grad[1],
+            )
+            target = torch.randn(
+                size=(dim1, dim2, dim4),
+                device="cuda",
+                requires_grad=req_grad[1],
+            )
             torch.nn.init.xavier_uniform_(B)
 
             out_torch = funcs[0](A, B)
@@ -96,8 +133,10 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
 
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx==0).sum().item() < n*0.01
-            torch.testing.assert_allclose(out_bnb, out_torch, atol=0.027, rtol=0.2)
+            assert (idx == 0).sum().item() < n * 0.01
+            torch.testing.assert_allclose(
+                out_bnb, out_torch, atol=0.027, rtol=0.2
+            )
 
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
@@ -109,7 +148,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch = torch.nn.functional.mse_loss(
+                    out_torch, target
+                ).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -117,20 +158,30 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.1
+                assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.02
+                assert (idx == 0).sum().item() < n * 0.02
 
         if funcs[0] in [torch.matmul]:
             dim1 = dim1 - (dim1 % 16)
-            A = torch.randn(size=(dim1, dim2, dim3), device='cuda', requires_grad=req_grad[0])
+            A = torch.randn(
+                size=(dim1, dim2, dim3),
+                device="cuda",
+                requires_grad=req_grad[0],
+            )
             dimB = (dim4, dim3) if transpose[1] else (dim3, dim4)
-            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1])
-            target = torch.randn(size=(dim1, dim2, dim4), device='cuda', requires_grad=req_grad[1])
+            B = torch.randn(size=dimB, device="cuda", requires_grad=req_grad[1])
+            target = torch.randn(
+                size=(dim1, dim2, dim4),
+                device="cuda",
+                requires_grad=req_grad[1],
+            )
             torch.nn.init.xavier_uniform_(B)
 
             if transpose[1]:
@@ -142,9 +193,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
 
             n = out_bnb.numel()
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx==0).sum().item() < n*0.0175
+            assert (idx == 0).sum().item() < n * 0.0175
             idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-            assert (idx==0).sum().item() < n*0.001
+            assert (idx == 0).sum().item() < n * 0.001
 
             if any(req_grad):
                 out_bnb.data.copy_(out_torch)
@@ -156,7 +207,9 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 A.grad = None
                 B.grad = None
 
-                loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                loss_torch = torch.nn.functional.mse_loss(
+                    out_torch, target
+                ).mean()
                 loss_torch.backward()
                 gradA2 = A.grad
                 gradB2 = B.grad
@@ -164,56 +217,111 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
                 B.grad = None
 
             if req_grad[0]:
-                torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                torch.testing.assert_allclose(
+                    gradA1, gradA2, atol=0.015, rtol=0.1
+                )
             if req_grad[1]:
                 n = gradB1.numel()
                 idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.1
+                assert (idx == 0).sum().item() < n * 0.1
                 idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                assert (idx==0).sum().item() < n*0.02
+                assert (idx == 0).sum().item() < n * 0.02
 
 
 n = 1
 k = 3
-dim1 = torch.randint(16,64, size=(n,)).tolist()
-dim2 = torch.randint(32,96, size=(n,)).tolist()
-dim3 = torch.randint(32,96, size=(n,)).tolist()
-dim4 = torch.randint(32,96, size=(n,)).tolist()
+dim1 = torch.randint(16, 64, size=(n,)).tolist()
+dim2 = torch.randint(32, 96, size=(n,)).tolist()
+dim3 = torch.randint(32, 96, size=(n,)).tolist()
+dim4 = torch.randint(32, 96, size=(n,)).tolist()
 dim2.append(0)
-#dim1 = (17,)
-#dim2 = (7,)
-#dim3 = (37,)
-#dim4 = (23,)
 decomp = [0.0, 6.0]
 funcs = [(torch.matmul, bnb.matmul)]
-str_funcs = ['matmul']
+str_funcs = ["matmul"]
 req_grad = [(False, False), (True, False), (True, True), (False, True)]
-req_grad_str = ['FF', 'TF', 'TT', 'FT']
+req_grad_str = ["FF", "TF", "TT", "FT"]
 transpose = [(False, True), (False, False)]
-str_transpose = ['NT', 'NN']
+str_transpose = ["NT", "NN"]
 dtype = [torch.float16]
 has_fp16_weights = [True, False]
-values = list(product(dim1,dim2,dim3,dim4,funcs, dtype, req_grad, transpose, decomp, has_fp16_weights))
-str_values = list(product(dim1,dim2,dim3,dim4,str_funcs, dtype, req_grad_str, str_transpose, decomp, has_fp16_weights))
-names = ['dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}'.format(*vals) for vals in str_values]
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights", values, ids=names)
-def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights):
+values = list(
+    product(
+        dim1,
+        dim2,
+        dim3,
+        dim4,
+        funcs,
+        dtype,
+        req_grad,
+        transpose,
+        decomp,
+        has_fp16_weights,
+    )
+)
+str_values = list(
+    product(
+        dim1,
+        dim2,
+        dim3,
+        dim4,
+        str_funcs,
+        dtype,
+        req_grad_str,
+        str_transpose,
+        decomp,
+        has_fp16_weights,
+    )
+)
+names = [
+    "dim1_{0}_dim2_{1}_dim3_{2}_dim4_{3}_func_{4}_dtype_{5}_requires_grad_{6}_transpose_{7}_decomp_{8}_has_fp16_weights_{9}".format(
+        *vals
+    )
+    for vals in str_values
+]
+
+
+@pytest.mark.parametrize(
+    "dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights",
+    values,
+    ids=names,
+)
+def test_matmullt(
+    dim1,
+    dim2,
+    dim3,
+    dim4,
+    funcs,
+    dtype,
+    req_grad,
+    transpose,
+    decomp,
+    has_fp16_weights,
+):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
-    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1]//8,), device='cuda')
+    outlier_dim = torch.randint(0, dimA[1], size=(dimA[1] // 8,), device="cuda")
 
     for i in range(k):
 
         # normal multiply
         if funcs[0] in [torch.mm, torch.matmul]:
-            A = torch.randn(size=dimA, device='cuda', requires_grad=req_grad[0], dtype=dtype)
+            A = torch.randn(
+                size=dimA, device="cuda", requires_grad=req_grad[0], dtype=dtype
+            )
             if decomp == 6.0:
                 with torch.no_grad():
                     A[:, outlier_dim] = 6.0
-            B = torch.randn(size=dimB, device='cuda', requires_grad=req_grad[1], dtype=dtype)
-            target = torch.randn(size=(dim2, dim4), device='cuda', requires_grad=req_grad[1], dtype=dtype)
+            B = torch.randn(
+                size=dimB, device="cuda", requires_grad=req_grad[1], dtype=dtype
+            )
+            target = torch.randn(
+                size=(dim2, dim4),
+                device="cuda",
+                requires_grad=req_grad[1],
+                dtype=dtype,
+            )
             torch.nn.init.xavier_uniform_(B)
             B2 = B.clone()
@@ -221,8 +329,15 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
             state.threshold = decomp
             state.has_fp16_weights = has_fp16_weights
             if not has_fp16_weights:
-                if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous()
-                state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B2)
+                if not transpose[0] and not transpose[1]:
+                    B2 = B2.t().contiguous()
+                (
+                    state.CB,
+                    CBt,
+                    state.SCB,
+                    SCBt,
+                    coo_tensorB,
+                ) = bnb.functional.double_quant(B2)
                 B2 = state.CB
 
             if not transpose[0] and transpose[1]:
@@ -233,25 +348,29 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                 out_bnb = funcs[1](A, B2.t(), state=state)
 
             n = out_bnb.numel()
-            err = torch.abs(out_bnb-out_torch).mean().item()
-            #print(f'abs error {err:.4f}')
+            err = torch.abs(out_bnb - out_torch).mean().item()
+            # print(f'abs error {err:.4f}')
             idx = torch.isclose(out_bnb, out_torch, atol=0.01, rtol=0.1)
-            assert (idx==0).sum().item() <= n*0.0175
+            assert (idx == 0).sum().item() < n * 0.0175
             idx = torch.isclose(out_bnb, out_torch, atol=0.035, rtol=0.2)
-            assert (idx==0).sum().item() <= n*0.001
+            assert (idx == 0).sum().item() < n * 0.001
 
             if has_fp16_weights:
                 if any(req_grad):
                     out_bnb.data.copy_(out_torch)
                     torch.cuda.synchronize()
-                    loss_bnb = torch.nn.functional.mse_loss(out_bnb, target).mean()
+                    loss_bnb = torch.nn.functional.mse_loss(
+                        out_bnb, target
+                    ).mean()
                    loss_bnb.backward()
                    gradA1 = A.grad
                    gradB1 = B.grad
                    A.grad = None
                    B.grad = None
-                    loss_torch = torch.nn.functional.mse_loss(out_torch, target).mean()
+                    loss_torch = torch.nn.functional.mse_loss(
+                        out_torch, target
+                    ).mean()
                    loss_torch.backward()
                    gradA2 = A.grad
                    gradB2 = B.grad
@@ -259,7 +378,9 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                     B.grad = None
 
                     if req_grad[0]:
-                        torch.testing.assert_allclose(gradA1, gradA2, atol=0.015, rtol=0.1)
+                        torch.testing.assert_allclose(
+                            gradA1, gradA2, atol=0.015, rtol=0.1
+                        )
                     if req_grad[1]:
                         n = gradB1.numel()
                         if dim2 > 0:
@@ -269,8 +390,10 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec
                             assert torch.abs(gradB1).sum() == 0.0
                             assert torch.abs(gradB2).sum() == 0.0
                         idx = torch.isclose(gradB1, gradB2, atol=0.06, rtol=0.3)
-                        assert (idx==0).sum().item() <= n*0.1
-                        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
-                        assert (idx==0).sum().item() <= n*0.02
-                        torch.testing.assert_allclose(gradB1, gradB2, atol=0.18, rtol=0.3)
+                        assert (idx == 0).sum().item() < n * 0.1
+                        idx = torch.isclose(gradB1, gradB2, atol=0.10, rtol=0.3)
+                        assert (idx == 0).sum().item() < n * 0.02
+                        torch.testing.assert_allclose(
+                            gradB1, gradB2, atol=0.18, rtol=0.3
+                        )
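
Both tests in this diff build their parameter grids the same way: itertools.product over the dimension and flag lists, with a parallel str_values grid used only to render readable pytest IDs. A minimal sketch of that pattern with toy parameters (the names dims, flags, and test_sketch are illustrative, not from the file):

from itertools import product

import pytest

dims = [16, 32]
flags = ["NT", "NN"]
values = list(product(dims, flags))
# One human-readable id per case, e.g. test_sketch[dim_16_layout_NT]
names = ["dim_{0}_layout_{1}".format(*vals) for vals in values]


@pytest.mark.parametrize("dim, layout", values, ids=names)
def test_sketch(dim, layout):
    assert dim > 0

The payoff is in test selection: a failing case can be re-run by its id, e.g. pytest -k "dim_16_layout_NT".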
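The repeated torch.isclose / (idx == 0).sum() assertions implement a mismatch budget: elements within atol/rtol of the reference count as matches, and only a bounded fraction of the tensor may fall outside, since the 8-bit kernels are lossy by design. A sketch of that check as a standalone helper (the name assert_mostly_close is hypothetical, not part of the test file):

import torch


def assert_mostly_close(out_test, out_ref, atol, rtol, max_frac):
    # torch.isclose marks elements within atol/rtol of the reference;
    # the check passes if the mismatching remainder stays under budget.
    idx = torch.isclose(out_test, out_ref, atol=atol, rtol=rtol)
    n = out_test.numel()
    assert (idx == 0).sum().item() < n * max_frac


# The tests above apply tiered budgets, e.g. at most 1.75% of elements
# outside (atol=0.01, rtol=0.1) and at most 0.1% outside (0.035, 0.2).
assert_mostly_close(torch.ones(100) + 0.001, torch.ones(100), 0.01, 0.1, 0.0175)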
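Before any gradient comparison, every branch runs out_bnb.data.copy_(out_torch). That in-place copy feeds identical forward values into both losses while leaving the autograd graph of the bnb output intact, so a gradient mismatch reflects the backward implementation rather than accumulated forward error. A self-contained illustration (the noise term standing in for a lossy custom matmul is ours):

import torch

A = torch.randn(8, 4, requires_grad=True)
B = torch.randn(4, 8, requires_grad=True)
target = torch.randn(8, 8)

out_ref = A.detach() @ B.detach()             # reference forward, no graph
out_test = A @ B + 1e-3 * torch.randn(8, 8)   # stand-in for a lossy kernel
out_test.data.copy_(out_ref)                  # overwrite values, keep the graph
loss = torch.nn.functional.mse_loss(out_test, target).mean()
loss.backward()                               # grads flow through out_test's graph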
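The has_fp16_weights=False branch of test_matmullt exercises inference with pre-quantized weights: bnb.functional.double_quant produces the int8 weight CB and its scale SCB, which are stashed on the state object that bnb.matmul consumes. A rough sketch of that path on CUDA, mirroring the NT case in the diff (shapes are illustrative, and the MatmulLtState construction is assumed from the test code outside these hunks):

import torch

import bitsandbytes as bnb

state = bnb.MatmulLtState()
state.threshold = 6.0            # outlier decomposition threshold, cf. `decomp`
state.has_fp16_weights = False   # take the pre-quantized int8 weight path

B = torch.randn(64, 32, device="cuda", dtype=torch.float16)  # weight, (out, in)
# double_quant returns row-wise and column-wise int8 data, their scales,
# and an optional sparse outlier tensor, unpacked exactly as in the diff.
state.CB, CBt, state.SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)

A = torch.randn(16, 32, device="cuda", dtype=torch.float16)  # activations
out = bnb.matmul(A, state.CB.t(), state=state)               # fp16, shape (16, 64)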