summaryrefslogtreecommitdiff
path: root/mesonbuild
diff options
context:
space:
mode:
authorEli Schwartz <eschwartz93@gmail.com>2023-10-15 19:01:34 -0400
committerEli Schwartz <eschwartz93@gmail.com>2024-02-12 23:12:10 -0500
commitcf35d9b4cebecd3e565d49e4d9f4a5366429463a (patch)
treefd75808aa51f396afdac6de4ca33e3320236960d /mesonbuild
parent6f7e74505246ac0363a83c49339026e977a1c03d (diff)
downloadmeson-cf35d9b4cebecd3e565d49e4d9f4a5366429463a.tar.gz
cuda module: use typed_kwargs
This officially only ever accepted string or array of strings.
Diffstat (limited to 'mesonbuild')
-rw-r--r--mesonbuild/modules/cuda.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py
index 525010839..7cfd3a04f 100644
--- a/mesonbuild/modules/cuda.py
+++ b/mesonbuild/modules/cuda.py
@@ -8,18 +8,26 @@ import re
from ..mesonlib import version_compare
from ..compilers.cuda import CudaCompiler
+from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
- flatten, permittedKwargs, noKwargs,
- InvalidArguments
+ ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
)
if T.TYPE_CHECKING:
+ from typing_extensions import TypedDict
+
from . import ModuleState
from ..compilers import Compiler
+ class ArchFlagsKwargs(TypedDict):
+ detected: T.Optional[T.List[str]]
+
+
+DETECTED_KW: KwargInfo[T.Union[None, T.List[str]]] = KwargInfo('detected', (ContainerTypeInfo(list, str), NoneType), listify=True)
+
class CudaModule(NewExtensionModule):
INFO = ModuleInfo('CUDA', '0.50.0', unstable=True)
@@ -87,18 +95,18 @@ class CudaModule(NewExtensionModule):
return driver_version
- @permittedKwargs(['detected'])
+ @typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
- kwargs: T.Dict[str, T.Any]) -> T.List[str]:
+ kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret
- @permittedKwargs(['detected'])
+ @typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
- kwargs: T.Dict[str, T.Any]) -> T.List[str]:
+ kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
return ret
@@ -110,10 +118,10 @@ class CudaModule(NewExtensionModule):
return s
@staticmethod
- def _detected_cc_from_compiler(c):
+ def _detected_cc_from_compiler(c) -> T.List[str]:
if isinstance(c, CudaCompiler):
- return c.detected_cc
- return ''
+ return [c.detected_cc]
+ return []
@staticmethod
def _version_from_compiler(c):
@@ -123,7 +131,7 @@ class CudaModule(NewExtensionModule):
return c
return 'unknown'
- def _validate_nvcc_arch_args(self, args, kwargs):
+ def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
if len(args) < 1:
@@ -141,8 +149,7 @@ class CudaModule(NewExtensionModule):
raise InvalidArguments('''The special architectures 'All', 'Common' and 'Auto' must appear alone, as a positional argument!''')
arch_list = arch_list[0] if len(arch_list) == 1 else arch_list
- detected = kwargs.get('detected', self._detected_cc_from_compiler(compiler))
- detected = flatten([detected])
+ detected = kwargs['detected'] if kwargs['detected'] is not None else self._detected_cc_from_compiler(compiler)
detected = [self._break_arch_string(a) for a in detected]
detected = flatten(detected)
if not set(detected).isdisjoint({'All', 'Common', 'Auto'}):