author     justheuristic <justheuristic@gmail.com>   2022-09-18 00:55:53 +0300
committer  justheuristic <justheuristic@gmail.com>   2022-09-18 00:55:53 +0300
commit     2cd047e35da3a421c4b491ff1a137e19b9c6c919 (patch)
tree       030a92de50216760a1fcc454719b9df7782397f4
parent     591f60395a1e9c62f291e23c91af45cc699f072c (diff)
run backward
-rw-r--r--   tests/test_modules.py   11
1 files changed, 11 insertions, 0 deletions
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 53a675f..d3992a9 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -554,11 +554,22 @@ def test_linear8bitlt_no_fp16_weights(threshold, memory_efficient_backward):
             assert mlp.fc1.state.idx is not None
         if threshold > 0:
             assert mlp.fc2.state.idx is not None
+
     assert mlp.fc1.weight.dtype == torch.int8
     assert mlp.fc2.weight.dtype == torch.int8
     assert mlp.fc1.weight.device.type == "cuda"
     assert mlp.fc2.weight.device.type == "cuda"
 
+    if memory_efficient_backward:
+        b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
+        o1 = mlp(b1)
+        assert o1.dtype == torch.float16
+        assert o1.requires_grad
+        grad_proj = torch.randn_like(o1)
+
+        (o1 * grad_proj).sum().backward()
+
+
 def test_linear8bitlt_fp32_bias():
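For context, the new block drives a full backward pass through the int8 layers instead of only checking the forward output dtype. Below is a minimal standalone sketch of the same pattern, not the test itself: it assumes a CUDA device and a bitsandbytes release that still accepts the memory_efficient_backward flag on Linear8bitLt (the flag was removed in later versions), and the MLP8bit class here is a stand-in for the helper defined earlier in tests/test_modules.py.

import torch
import bitsandbytes as bnb

# Stand-in for the MLP8bit helper in tests/test_modules.py:
# two stacked Linear8bitLt layers.
class MLP8bit(torch.nn.Module):
    def __init__(self, dim1, dim2, **kwargs):
        super().__init__()
        self.fc1 = bnb.nn.Linear8bitLt(dim1, dim2, **kwargs)
        self.fc2 = bnb.nn.Linear8bitLt(dim2, dim1, **kwargs)

    def forward(self, x):
        return self.fc2(self.fc1(x))

# has_fp16_weights=False keeps the weights quantized to int8;
# memory_efficient_backward (assumption: present only in older
# bitsandbytes releases) enables gradient flow through them.
mlp = MLP8bit(32, 64, threshold=2.0, has_fp16_weights=False,
              memory_efficient_backward=True).half().to("cuda")

b1 = torch.randn(16, 8, 32, device="cuda", requires_grad=True, dtype=torch.half)
o1 = mlp(b1)

# .backward() needs a scalar, so project the output onto a random
# direction and sum, exactly as the test does with grad_proj.
grad_proj = torch.randn_like(o1)
(o1 * grad_proj).sum().backward()

assert b1.grad is not None and b1.grad.dtype == torch.half

Projecting onto a random grad_proj rather than calling o1.sum().backward() feeds a generic upstream gradient into the layer, which is closer to what a real training step produces than an all-ones gradient.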