diff options
| author | Dylan Baker <dylan@pnwbakers.com> | 2025-12-03 09:49:17 -0800 |
|---|---|---|
| committer | Nirbheek Chauhan <nirbheek.chauhan@gmail.com> | 2025-12-05 01:35:06 +0530 |
| commit | a41bd69cb65d121702c03551860d0665347ebeae (patch) | |
| tree | 2514a9e8e12d8440567a8e52af47e23b8785bbb1 /mesonbuild/modules | |
| parent | 05105325bd22ad96b0ba932cd8f1a58293d11d13 (diff) | |
| download | meson-a41bd69cb65d121702c03551860d0665347ebeae.tar.gz | |
modules/cuda: Pull driver table out of class body
This makes use of a small class to simplify the implementation
Diffstat (limited to 'mesonbuild/modules')
| -rw-r--r-- | mesonbuild/modules/cuda.py | 99 |
1 files changed, 57 insertions, 42 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py index eb73a5770..0c5df8c58 100644 --- a/mesonbuild/modules/cuda.py +++ b/mesonbuild/modules/cuda.py @@ -3,6 +3,7 @@ from __future__ import annotations +import dataclasses import re import typing as T @@ -31,6 +32,56 @@ if T.TYPE_CHECKING: DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True) + +@dataclasses.dataclass +class _CudaVersion: + + meson: str + windows: str + linux: str + + def compare(self, version: str, machine: str) -> T.Optional[str]: + if version_compare(version, f'>={self.meson}'): + return self.windows if machine == 'windows' else self.linux + return None + + +_DRIVER_TABLE_VERSION: T.List[_CudaVersion] = [ + _CudaVersion('12.0.1', '528.41', '525.60.13'), + _CudaVersion('12.0.0', '527.41', '525.60.13'), + _CudaVersion('12.0.0', '527.41', '525.60.13'), + _CudaVersion('11.8.0', '522.06', '520.61.05'), + _CudaVersion('11.7.1', '516.31', '515.48.07'), + _CudaVersion('11.7.0', '516.01', '515.43.04'), + _CudaVersion('11.6.1', '511.65', '510.47.03'), + _CudaVersion('11.6.0', '511.23', '510.39.01'), + _CudaVersion('11.5.1', '496.13', '495.29.05'), + _CudaVersion('11.5.0', '496.04', '495.29.05'), + _CudaVersion('11.4.3', '472.50', '470.82.01'), + _CudaVersion('11.4.1', '471.41', '470.57.02'), + _CudaVersion('11.4.0', '471.11', '470.42.01'), + _CudaVersion('11.3.0', '465.89', '465.19.01'), + _CudaVersion('11.2.2', '461.33', '460.32.03'), + _CudaVersion('11.2.1', '461.09', '460.32.03'), + _CudaVersion('11.2.0', '460.82', '460.27.03'), + _CudaVersion('11.1.1', '456.81', '455.32'), + _CudaVersion('11.1.0', '456.38', '455.23'), + _CudaVersion('11.0.3', '451.82', '450.51.06'), + _CudaVersion('11.0.2', '451.48', '450.51.05'), + _CudaVersion('11.0.1', '451.22', '450.36.06'), + _CudaVersion('10.2.89', '441.22', '440.33'), + _CudaVersion('10.1.105', '418.96', '418.39'), + _CudaVersion('10.0.130', '411.31', '410.48'), + _CudaVersion('148', '398.26', '396.37'), + _CudaVersion('9.2.88', '397.44', '396.26'), + _CudaVersion('9.1.85', '391.29', '390.46'), + _CudaVersion('9.0.76', '385.54', '384.81'), + _CudaVersion('8.0.61', '376.51', '375.26'), + _CudaVersion('8.0.44', '369.30', '367.48'), + _CudaVersion('7.5.16', '353.66', '352.31'), + _CudaVersion('7.0.28', '347.62', '346.46'), +] + class CudaModule(NewExtensionModule): INFO = ModuleInfo('CUDA', '0.50.0', unstable=True) @@ -51,52 +102,16 @@ class CudaModule(NewExtensionModule): 'a CUDA Toolkit version string. Beware that, since CUDA 11.0, ' + 'the CUDA Toolkit\'s components (including NVCC) are versioned ' + 'independently from each other (and the CUDA Toolkit as a whole).') - if len(args) != 1 or not isinstance(args[0], str): raise argerror cuda_version = args[0] - driver_version_table = [ - {'cuda_version': '>=12.0.0', 'windows': '527.41', 'linux': '525.60.13'}, - {'cuda_version': '>=11.8.0', 'windows': '522.06', 'linux': '520.61.05'}, - {'cuda_version': '>=11.7.1', 'windows': '516.31', 'linux': '515.48.07'}, - {'cuda_version': '>=11.7.0', 'windows': '516.01', 'linux': '515.43.04'}, - {'cuda_version': '>=11.6.1', 'windows': '511.65', 'linux': '510.47.03'}, - {'cuda_version': '>=11.6.0', 'windows': '511.23', 'linux': '510.39.01'}, - {'cuda_version': '>=11.5.1', 'windows': '496.13', 'linux': '495.29.05'}, - {'cuda_version': '>=11.5.0', 'windows': '496.04', 'linux': '495.29.05'}, - {'cuda_version': '>=11.4.3', 'windows': '472.50', 'linux': '470.82.01'}, - {'cuda_version': '>=11.4.1', 'windows': '471.41', 'linux': '470.57.02'}, - {'cuda_version': '>=11.4.0', 'windows': '471.11', 'linux': '470.42.01'}, - {'cuda_version': '>=11.3.0', 'windows': '465.89', 'linux': '465.19.01'}, - {'cuda_version': '>=11.2.2', 'windows': '461.33', 'linux': '460.32.03'}, - {'cuda_version': '>=11.2.1', 'windows': '461.09', 'linux': '460.32.03'}, - {'cuda_version': '>=11.2.0', 'windows': '460.82', 'linux': '460.27.03'}, - {'cuda_version': '>=11.1.1', 'windows': '456.81', 'linux': '455.32'}, - {'cuda_version': '>=11.1.0', 'windows': '456.38', 'linux': '455.23'}, - {'cuda_version': '>=11.0.3', 'windows': '451.82', 'linux': '450.51.06'}, - {'cuda_version': '>=11.0.2', 'windows': '451.48', 'linux': '450.51.05'}, - {'cuda_version': '>=11.0.1', 'windows': '451.22', 'linux': '450.36.06'}, - {'cuda_version': '>=10.2.89', 'windows': '441.22', 'linux': '440.33'}, - {'cuda_version': '>=10.1.105', 'windows': '418.96', 'linux': '418.39'}, - {'cuda_version': '>=10.0.130', 'windows': '411.31', 'linux': '410.48'}, - {'cuda_version': '>=9.2.148', 'windows': '398.26', 'linux': '396.37'}, - {'cuda_version': '>=9.2.88', 'windows': '397.44', 'linux': '396.26'}, - {'cuda_version': '>=9.1.85', 'windows': '391.29', 'linux': '390.46'}, - {'cuda_version': '>=9.0.76', 'windows': '385.54', 'linux': '384.81'}, - {'cuda_version': '>=8.0.61', 'windows': '376.51', 'linux': '375.26'}, - {'cuda_version': '>=8.0.44', 'windows': '369.30', 'linux': '367.48'}, - {'cuda_version': '>=7.5.16', 'windows': '353.66', 'linux': '352.31'}, - {'cuda_version': '>=7.0.28', 'windows': '347.62', 'linux': '346.46'}, - ] - - driver_version = 'unknown' - for d in driver_version_table: - if version_compare(cuda_version, d['cuda_version']): - driver_version = d.get(state.environment.machines.host.system, d['linux']) - break - - return driver_version + + for d in _DRIVER_TABLE_VERSION: + driver_version = d.compare(cuda_version, state.environment.machines.host.system) + if driver_version is not None: + return driver_version + return 'unknown' @typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str) @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW) |
