summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mesonbuild/modules/cuda.py27
1 files changed, 11 insertions, 16 deletions
diff --git a/mesonbuild/modules/cuda.py b/mesonbuild/modules/cuda.py
index b52288ab5..690053868 100644
--- a/mesonbuild/modules/cuda.py
+++ b/mesonbuild/modules/cuda.py
@@ -13,14 +13,13 @@ from ..interpreter.type_checking import NoneType
from . import NewExtensionModule, ModuleInfo
from ..interpreterbase import (
- ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs,
+ ContainerTypeInfo, InvalidArguments, KwargInfo, flatten, noKwargs, typed_kwargs, typed_pos_args,
)
if T.TYPE_CHECKING:
from typing_extensions import TypedDict
from . import ModuleState
- from ..compilers import Compiler
class ArchFlagsKwargs(TypedDict):
detected: T.Optional[T.List[str]]
@@ -95,17 +94,19 @@ class CudaModule(NewExtensionModule):
return driver_version
+ @typed_pos_args('cuda.nvcc_arch_flags', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_flags', DETECTED_KW)
def nvcc_arch_flags(self, state: 'ModuleState',
- args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
+ args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[0]
return ret
+ @typed_pos_args('cuda.nvcc_arch_readable', (str, CudaCompiler), varargs=str)
@typed_kwargs('cuda.nvcc_arch_readable', DETECTED_KW)
def nvcc_arch_readable(self, state: 'ModuleState',
- args: T.Tuple[T.Union[Compiler, CudaCompiler, str]],
+ args: T.Tuple[T.Union[CudaCompiler, str], T.List[str]],
kwargs: ArchFlagsKwargs) -> T.List[str]:
nvcc_arch_args = self._validate_nvcc_arch_args(args, kwargs)
ret = self._nvcc_arch_flags(*nvcc_arch_args)[1]
@@ -123,21 +124,15 @@ class CudaModule(NewExtensionModule):
return [c.detected_cc]
return []
- def _validate_nvcc_arch_args(self, args, kwargs: ArchFlagsKwargs):
- argerror = InvalidArguments('The first argument must be an NVCC compiler object, or its version string!')
+ def _validate_nvcc_arch_args(self, args: T.Tuple[T.Union[str, CudaCompiler], T.List[str]], kwargs: ArchFlagsKwargs):
- if len(args) < 1:
- raise argerror
+ compiler = args[0]
+ if isinstance(compiler, CudaCompiler):
+ cuda_version = compiler.version
else:
- compiler = args[0]
- if isinstance(compiler, CudaCompiler):
- cuda_version = compiler.version
- elif isinstance(compiler, str):
- cuda_version = compiler
- else:
- raise argerror
+ cuda_version = compiler
- arch_list = [] if len(args) <= 1 else flatten(args[1:])
+ arch_list = args[1]
arch_list = [self._break_arch_string(a) for a in arch_list]
arch_list = flatten(arch_list)
if len(arch_list) > 1 and not set(arch_list).isdisjoint({'All', 'Common', 'Auto'}):