diff options
author | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:40:48 -0700 |
---|---|---|
committer | Tim Dettmers <tim.dettmers@gmail.com> | 2022-08-04 07:40:48 -0700 |
commit | cc5b323876392658b1d91655f30840d24be6d821 (patch) | |
tree | 8e23e961709a3cc082a707ebc8ea0f52baee6923 /bitsandbytes/functional.py | |
parent | 6101a8fb9f76c2cc4018452b4420dd52e946d52b (diff) | |
parent | bd515328d70f344f935075f359c5aefc616878d5 (diff) |
Merge branch 'extract_outliers' into debug
Diffstat (limited to 'bitsandbytes/functional.py')
-rw-r--r-- | bitsandbytes/functional.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index e7261bc..08c108c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -1435,3 +1435,29 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half): x *= SA[1]/127 x +=offset return x.to(dtype) + +def extract_outliers(A, SA, idx): + shapeA = SA[0] + formatA = SA[1] + assert formatA in ['col_turing', 'col_ampere'] + assert A.device.type == 'cuda' + + out = torch.zeros((shapeA[0], idx.numel()), dtype=torch.int8, device=A.device) + + idx_size = ct.c_int32(idx.numel()) + rows = ct.c_int32(shapeA[0]) + cols = ct.c_int32(shapeA[1]) + ptrA = get_ptr(A) + ptrIdx = get_ptr(idx) + ptrOut = get_ptr(out) + + if formatA == 'col_turing': + lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + elif formatA == 'col_ampere': + lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols) + + return out + + + + |