diff options
Diffstat (limited to 'bitsandbytes/cuda_setup/compute_capability.py')
-rw-r--r-- | bitsandbytes/cuda_setup/compute_capability.py | 79 |
1 files changed, 79 insertions, 0 deletions
"""Query the compute capability of every CUDA device through the driver API."""
import ctypes
from dataclasses import dataclass, field
from typing import List


@dataclass
class CudaLibVals:
    """Load ``libcuda`` and record the compute capability of each device.

    Instantiating the class loads the CUDA driver library, calls ``cuInit``,
    counts the devices and stores one ``"major.minor"`` string per device in
    ``ccs``, highest capability first.

    The ctypes-typed fields are scratch buffers that the driver writes into
    through ``ctypes.byref``; each instance gets its own buffers.

    Raises:
        OSError: if the CUDA driver library cannot be loaded.
        RuntimeError: if any driver API call returns a non-zero status.
    """

    # code bits taken from
    # https://gist.github.com/f0k/63a664160d016a491b2cbea15913d549

    # default_factory (not default=) so instances never share one mutable
    # ctypes object that the driver writes into.
    nGpus: ctypes.c_int = field(default_factory=ctypes.c_int)
    cc_major: ctypes.c_int = field(default_factory=ctypes.c_int)
    cc_minor: ctypes.c_int = field(default_factory=ctypes.c_int)
    device: ctypes.c_int = field(default_factory=ctypes.c_int)
    error_str: ctypes.c_char_p = field(default_factory=ctypes.c_char_p)
    # Handle to the loaded driver library; set by _load_cuda_lib().
    cuda: ctypes.CDLL = field(init=False, repr=False)
    # "major.minor" strings, one per device; set by __post_init__().
    ccs: List[str] = field(init=False)

    def _initialize_driver_API(self):
        """Call ``cuInit(0)`` and raise if the driver reports an error."""
        self.check_cuda_result(self.cuda.cuInit(0))

    def _load_cuda_lib(self):
        """
        1. find libcuda.so library (GPU driver) (/usr/lib)
           init_device -> init variables -> call function by reference

        Stores the loaded library in ``self.cuda``.

        Raises:
            OSError: if none of the candidate library names can be loaded.
        """
        # Must be a tuple: iterating a bare string would try to load
        # "l", "i", "b", ... one character at a time.
        # "libcuda.so.1" is the runtime soname, present even when the
        # unversioned development symlink is not installed.
        libnames = ("libcuda.so", "libcuda.so.1")
        for libname in libnames:
            try:
                self.cuda = ctypes.CDLL(libname)
            except OSError:
                continue
            else:
                break
        else:
            raise OSError("could not load any of: " + " ".join(libnames))

    def call_cuda_func(self, function_obj, **kwargs):
        """Placeholder for a generic checked driver call; currently a no-op.

        Accepts a function object and keyword arguments but does not call
        the function, and returns ``None``.
        """
        return None

    def check_cuda_result(self, result):
        """Return ``result`` if it signals success, otherwise raise.

        Args:
            result: status code returned by a CUDA driver API call.

        Returns:
            ``result`` unchanged when it equals ``CUDA_SUCCESS`` (0).

        Raises:
            RuntimeError: when ``result`` is non-zero. The message carries
                the text obtained from ``cuGetErrorString`` when available.
        """
        CUDA_SUCCESS = 0  # constant taken from cuda.h

        if result == CUDA_SUCCESS:
            return result

        self.cuda.cuGetErrorString(result, ctypes.byref(self.error_str))
        # The pointer stays NULL (value is None) if no string was written.
        raw_message = self.error_str.value
        message = raw_message.decode() if raw_message else "unknown error"
        print("Could not initialize CUDA - failure!")
        raise RuntimeError(f"CUDA exception! Error code {result}: {message}")

    def _error_handle(self, cuda_lib_call_return_value):
        """Backward-compatible alias for :meth:`check_cuda_result`."""
        return self.check_cuda_result(cuda_lib_call_return_value)

    def __post_init__(self):
        """
        2. call extern C function to determine CC
           (see https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__DEVICE__DEPRECATED.html)

        Loads the driver, initializes it, then queries every device and
        fills ``self.ccs``.
        """
        self._load_cuda_lib()
        self._initialize_driver_API()
        self.check_cuda_result(
            self.cuda.cuDeviceGetCount(ctypes.byref(self.nGpus))
        )
        capabilities = []
        for gpu_index in range(self.nGpus.value):
            self.check_cuda_result(
                self.cuda.cuDeviceGet(ctypes.byref(self.device), gpu_index)
            )
            self.check_cuda_result(
                self.cuda.cuDeviceComputeCapability(
                    ctypes.byref(self.cc_major),
                    ctypes.byref(self.cc_minor),
                    self.device,
                )
            )
            capabilities.append((self.cc_major.value, self.cc_minor.value))
        # Sort numerically before formatting: as strings, "9.0" would rank
        # above "10.0".
        capabilities.sort(reverse=True)
        self.ccs = [f"{major}.{minor}" for major, minor in capabilities]