diff options
| author | Volker Weißmann <volker.weissmann@gmx.de> | 2025-03-09 18:57:14 +0100 |
|---|---|---|
| committer | Dylan Baker <dylan@pnwbakers.com> | 2025-05-29 09:20:27 -0700 |
| commit | 4bbbd33d641e8761244342dc850e1b0ff97002b7 (patch) | |
| tree | 09a7c35bce52cd748c359d0be7f8977200da642e | |
| parent | 792db9439b30b64449097eb5e5fc940148cacaab (diff) | |
| download | meson-4bbbd33d641e8761244342dc850e1b0ff97002b7.tar.gz | |
rewriter.py: make type safe
| -rw-r--r-- | mesonbuild/ast/interpreter.py | 4 | ||||
| -rw-r--r-- | mesonbuild/ast/introspection.py | 10 | ||||
| -rw-r--r-- | mesonbuild/rewriter.py | 226 | ||||
| -rwxr-xr-x | run_mypy.py | 1 |
4 files changed, 137 insertions, 104 deletions
diff --git a/mesonbuild/ast/interpreter.py b/mesonbuild/ast/interpreter.py index cde578a9a..b18bb5882 100644 --- a/mesonbuild/ast/interpreter.py +++ b/mesonbuild/ast/interpreter.py @@ -49,7 +49,7 @@ from ..mparser import ( if T.TYPE_CHECKING: from .visitor import AstVisitor from ..interpreter import Interpreter - from ..interpreterbase import SubProject, TYPE_nkwargs, TYPE_var + from ..interpreterbase import SubProject, TYPE_nkwargs, TYPE_var, TYPE_nvar from ..mparser import ( AndNode, ComparisonNode, @@ -387,7 +387,7 @@ class AstInterpreter(InterpreterBase): return result - def flatten_args(self, args_raw: T.Union[TYPE_var, T.Sequence[TYPE_var]], include_unknown_args: bool = False, id_loop_detect: T.Optional[T.List[str]] = None) -> T.List[TYPE_var]: + def flatten_args(self, args_raw: T.Union[TYPE_nvar, T.Sequence[TYPE_nvar]], include_unknown_args: bool = False, id_loop_detect: T.Optional[T.List[str]] = None) -> T.List[TYPE_var]: # Make sure we are always dealing with lists if isinstance(args_raw, list): args = args_raw diff --git a/mesonbuild/ast/introspection.py b/mesonbuild/ast/introspection.py index 88ca72c84..b86b4e1e0 100644 --- a/mesonbuild/ast/introspection.py +++ b/mesonbuild/ast/introspection.py @@ -67,7 +67,7 @@ class IntrospectionInterpreter(AstInterpreter): self.project_data: T.Dict[str, T.Any] = {} self.targets: T.List[IntrospectionBuildTarget] = [] self.dependencies: T.List[IntrospectionDependency] = [] - self.project_node: BaseNode = None + self.project_node: FunctionNode = None self.funcs.update({ 'add_languages': self.func_add_languages, @@ -85,6 +85,7 @@ class IntrospectionInterpreter(AstInterpreter): def func_project(self, node: BaseNode, args: T.List[TYPE_var], kwargs: T.Dict[str, TYPE_var]) -> None: if self.project_node: raise InvalidArguments('Second call to project()') + assert isinstance(node, FunctionNode) self.project_node = node if len(args) < 1: raise InvalidArguments('Not enough arguments to project(). Needs at least the project name.') @@ -211,16 +212,20 @@ class IntrospectionInterpreter(AstInterpreter): self.coredata.process_compiler_options(lang, comp, self.subproject) def func_dependency(self, node: BaseNode, args: T.List[TYPE_var], kwargs: T.Dict[str, TYPE_var]) -> None: + assert isinstance(node, FunctionNode) args = self.flatten_args(args) kwargs = self.flatten_kwargs(kwargs) if not args: return name = args[0] + assert isinstance(name, str) has_fallback = 'fallback' in kwargs required = kwargs.get('required', True) version = kwargs.get('version', []) if not isinstance(version, list): version = [version] + assert all(isinstance(el, str) for el in version) + version = T.cast(T.List[str], version) if isinstance(required, ElementaryNode): required = required.value if not isinstance(required, bool): @@ -235,11 +240,12 @@ class IntrospectionInterpreter(AstInterpreter): self.dependencies += [newdep] def build_target(self, node: BaseNode, args: T.List[TYPE_var], kwargs_raw: T.Dict[str, TYPE_var], targetclass: T.Type[BuildTarget]) -> IntrospectionBuildTarget: + assert isinstance(node, FunctionNode) args = self.flatten_args(args) if not args or not isinstance(args[0], str): return None name = args[0] - srcqueue = [node] + srcqueue: T.List[BaseNode] = [node] extra_queue = [] # Process the sources BEFORE flattening the kwargs, to preserve the original nodes diff --git a/mesonbuild/rewriter.py b/mesonbuild/rewriter.py index 253e3f2fb..fff3f9cfc 100644 --- a/mesonbuild/rewriter.py +++ b/mesonbuild/rewriter.py @@ -10,24 +10,26 @@ from __future__ import annotations from .ast import IntrospectionInterpreter, BUILD_TARGET_FUNCTIONS, AstConditionLevel, AstIDGenerator, AstIndentationGenerator, AstPrinter -from .ast.interpreter import IntrospectionDependency +from .ast.interpreter import IntrospectionBuildTarget, IntrospectionDependency +from .interpreterbase import TV_func from mesonbuild.mesonlib import MesonException, setup_vsenv from . import mlog, environment from functools import wraps -from .mparser import Token, ArrayNode, ArgumentNode, AssignmentNode, StringNode, BooleanNode, ElementaryNode, IdNode, FunctionNode, SymbolNode +from .mparser import Token, ArrayNode, ArgumentNode, AssignmentNode, BaseNode, StringNode, BooleanNode, ElementaryNode, IdNode, FunctionNode, SymbolNode import json, os, re, sys import typing as T if T.TYPE_CHECKING: - from argparse import ArgumentParser, HelpFormatter - from .mparser import BaseNode + import argparse + from argparse import ArgumentParser, _FormatterClass + from .mlog import AnsiDecorator class RewriterException(MesonException): pass # Note: when adding arguments, please also add them to the completion # scripts in $MESONSRC/data/shell-completions/ -def add_arguments(parser: ArgumentParser, formatter: T.Callable[[str], HelpFormatter]) -> None: +def add_arguments(parser: ArgumentParser, formatter: _FormatterClass) -> None: parser.add_argument('-s', '--sourcedir', type=str, default='.', metavar='SRCDIR', help='Path to source directory.') parser.add_argument('-V', '--verbose', action='store_true', default=False, help='Enable verbose output') parser.add_argument('-S', '--skip-errors', dest='skip', action='store_true', default=False, help='Skip errors instead of aborting') @@ -63,12 +65,14 @@ def add_arguments(parser: ArgumentParser, formatter: T.Callable[[str], HelpForma cmd_parser.add_argument('json', help='JSON string or file to execute') class RequiredKeys: - def __init__(self, keys): + keys: T.Dict[str, T.Any] + + def __init__(self, keys: T.Dict[str, T.Any]): self.keys = keys - def __call__(self, f): + def __call__(self, f: TV_func) -> TV_func: @wraps(f) - def wrapped(*wrapped_args, **wrapped_kwargs): + def wrapped(*wrapped_args: T.Any, **wrapped_kwargs: T.Any) -> T.Any: assert len(wrapped_args) >= 2 cmd = wrapped_args[1] for key, val in self.keys.items(): @@ -91,12 +95,14 @@ class RequiredKeys: .format(key, choices, cmd[key])) return f(*wrapped_args, **wrapped_kwargs) - return wrapped + return T.cast('TV_func', wrapped) def _symbol(val: str) -> SymbolNode: return SymbolNode(Token('', '', 0, 0, 0, (0, 0), val)) class MTypeBase: + node: BaseNode + def __init__(self, node: T.Optional[BaseNode] = None): if node is None: self.node = self.new_node() @@ -108,30 +114,30 @@ class MTypeBase: self.node_type = i @classmethod - def new_node(cls, value=None): + def new_node(cls, value: T.Any = None) -> BaseNode: # Overwrite in derived class raise RewriterException('Internal error: new_node of MTypeBase was called') @classmethod - def supported_nodes(cls): + def supported_nodes(cls) -> T.List[type]: # Overwrite in derived class return [] - def can_modify(self): + def can_modify(self) -> bool: return self.node_type is not None - def get_node(self): + def get_node(self) -> BaseNode: return self.node - def add_value(self, value): + def add_value(self, value: T.Any) -> None: # Overwrite in derived class mlog.warning('Cannot add a value of type', mlog.bold(type(self).__name__), '--> skipping') - def remove_value(self, value): + def remove_value(self, value: T.Any) -> None: # Overwrite in derived class mlog.warning('Cannot remove a value of type', mlog.bold(type(self).__name__), '--> skipping') - def remove_regex(self, value): + def remove_regex(self, value: T.Any) -> None: # Overwrite in derived class mlog.warning('Cannot remove a regex in type', mlog.bold(type(self).__name__), '--> skipping') @@ -140,13 +146,13 @@ class MTypeStr(MTypeBase): super().__init__(node) @classmethod - def new_node(cls, value=None): + def new_node(cls, value: T.Optional[str] = None) -> BaseNode: if value is None: value = '' return StringNode(Token('string', '', 0, 0, 0, None, str(value))) @classmethod - def supported_nodes(cls): + def supported_nodes(cls) -> T.List[type]: return [StringNode] class MTypeBool(MTypeBase): @@ -154,11 +160,11 @@ class MTypeBool(MTypeBase): super().__init__(node) @classmethod - def new_node(cls, value=None): + def new_node(cls, value: T.Optional[str] = None) -> BaseNode: return BooleanNode(Token('', '', 0, 0, 0, None, bool(value))) @classmethod - def supported_nodes(cls): + def supported_nodes(cls) -> T.List[type]: return [BooleanNode] class MTypeID(MTypeBase): @@ -166,21 +172,23 @@ class MTypeID(MTypeBase): super().__init__(node) @classmethod - def new_node(cls, value=None): + def new_node(cls, value: T.Optional[str] = None) -> BaseNode: if value is None: value = '' return IdNode(Token('', '', 0, 0, 0, None, str(value))) @classmethod - def supported_nodes(cls): + def supported_nodes(cls) -> T.List[type]: return [IdNode] class MTypeList(MTypeBase): + node: ArrayNode + def __init__(self, node: T.Optional[BaseNode] = None): super().__init__(node) @classmethod - def new_node(cls, value=None): + def new_node(cls, value: T.Optional[T.List[T.Any]] = None) -> ArrayNode: if value is None: value = [] elif not isinstance(value, list): @@ -190,50 +198,52 @@ class MTypeList(MTypeBase): return ArrayNode(_symbol('['), args, _symbol(']')) @classmethod - def _new_element_node(cls, value): + def _new_element_node(cls, value: T.Any) -> BaseNode: # Overwrite in derived class raise RewriterException('Internal error: _new_element_node of MTypeList was called') - def _ensure_array_node(self): + def _ensure_array_node(self) -> None: if not isinstance(self.node, ArrayNode): tmp = self.node self.node = self.new_node() self.node.args.arguments = [tmp] @staticmethod - def _check_is_equal(node, value) -> bool: + def _check_is_equal(node: BaseNode, value: str) -> bool: # Overwrite in derived class return False @staticmethod - def _check_regex_matches(node, regex: str) -> bool: + def _check_regex_matches(node: BaseNode, regex: str) -> bool: # Overwrite in derived class return False - def get_node(self): + def get_node(self) -> BaseNode: if isinstance(self.node, ArrayNode): if len(self.node.args.arguments) == 1: return self.node.args.arguments[0] return self.node @classmethod - def supported_element_nodes(cls): + def supported_element_nodes(cls) -> T.List[T.Type]: # Overwrite in derived class return [] @classmethod - def supported_nodes(cls): + def supported_nodes(cls) -> T.List[T.Type]: return [ArrayNode] + cls.supported_element_nodes() - def add_value(self, value): + def add_value(self, value: T.Any) -> None: if not isinstance(value, list): value = [value] self._ensure_array_node() for i in value: + assert hasattr(self.node, 'args') # For mypy + assert isinstance(self.node.args, ArgumentNode) # For mypy self.node.args.arguments += [self._new_element_node(i)] - def _remove_helper(self, value, equal_func): - def check_remove_node(node): + def _remove_helper(self, value: T.Any, equal_func: T.Callable[[T.Any, T.Any], bool]) -> None: + def check_remove_node(node: BaseNode) -> bool: for j in value: if equal_func(i, j): return True @@ -242,16 +252,18 @@ class MTypeList(MTypeBase): if not isinstance(value, list): value = [value] self._ensure_array_node() + assert hasattr(self.node, 'args') # For mypy + assert isinstance(self.node.args, ArgumentNode) # For mypy removed_list = [] for i in self.node.args.arguments: if not check_remove_node(i): removed_list += [i] self.node.args.arguments = removed_list - def remove_value(self, value): + def remove_value(self, value: T.Any) -> None: self._remove_helper(value, self._check_is_equal) - def remove_regex(self, regex: str): + def remove_regex(self, regex: str) -> None: self._remove_helper(regex, self._check_regex_matches) class MTypeStrList(MTypeList): @@ -259,23 +271,23 @@ class MTypeStrList(MTypeList): super().__init__(node) @classmethod - def _new_element_node(cls, value): + def _new_element_node(cls, value: str) -> StringNode: return StringNode(Token('string', '', 0, 0, 0, None, str(value))) @staticmethod - def _check_is_equal(node, value) -> bool: + def _check_is_equal(node: BaseNode, value: str) -> bool: if isinstance(node, StringNode): - return node.value == value + return bool(node.value == value) return False @staticmethod - def _check_regex_matches(node, regex: str) -> bool: + def _check_regex_matches(node: BaseNode, regex: str) -> bool: if isinstance(node, StringNode): return re.match(regex, node.value) is not None return False @classmethod - def supported_element_nodes(cls): + def supported_element_nodes(cls) -> T.List[T.Type]: return [StringNode] class MTypeIDList(MTypeList): @@ -283,26 +295,26 @@ class MTypeIDList(MTypeList): super().__init__(node) @classmethod - def _new_element_node(cls, value): + def _new_element_node(cls, value: str) -> IdNode: return IdNode(Token('', '', 0, 0, 0, None, str(value))) @staticmethod - def _check_is_equal(node, value) -> bool: + def _check_is_equal(node: BaseNode, value: str) -> bool: if isinstance(node, IdNode): - return node.value == value + return bool(node.value == value) return False @staticmethod - def _check_regex_matches(node, regex: str) -> bool: + def _check_regex_matches(node: BaseNode, regex: str) -> bool: if isinstance(node, StringNode): return re.match(regex, node.value) is not None return False @classmethod - def supported_element_nodes(cls): + def supported_element_nodes(cls) -> T.List[T.Type]: return [IdNode] -rewriter_keys = { +rewriter_keys: T.Dict[str, T.Dict[str, T.Any]] = { 'default_options': { 'operation': (str, None, ['set', 'delete']), 'options': (dict, {}, None) @@ -356,13 +368,15 @@ rewriter_func_kwargs = { } class Rewriter: + info_dump: T.Optional[T.Dict[str, T.Dict[str, T.Any]]] + def __init__(self, sourcedir: str, generator: str = 'ninja', skip_errors: bool = False): self.sourcedir = sourcedir self.interpreter = IntrospectionInterpreter(sourcedir, '', generator, visitors = [AstIDGenerator(), AstIndentationGenerator(), AstConditionLevel()]) self.skip_errors = skip_errors - self.modified_nodes = [] - self.to_remove_nodes = [] - self.to_add_nodes = [] + self.modified_nodes: T.List[BaseNode] = [] + self.to_remove_nodes: T.List[BaseNode] = [] + self.to_add_nodes: T.List[BaseNode] = [] self.functions = { 'default_options': self.process_default_options, 'kwargs': self.process_kwargs, @@ -370,36 +384,36 @@ class Rewriter: } self.info_dump = None - def analyze_meson(self): + def analyze_meson(self) -> None: mlog.log('Analyzing meson file:', mlog.bold(os.path.join(self.sourcedir, environment.build_filename))) self.interpreter.analyze() mlog.log(' -- Project:', mlog.bold(self.interpreter.project_data['descriptive_name'])) mlog.log(' -- Version:', mlog.cyan(self.interpreter.project_data['version'])) - def add_info(self, cmd_type: str, cmd_id: str, data: dict): + def add_info(self, cmd_type: str, cmd_id: str, data: dict) -> None: if self.info_dump is None: self.info_dump = {} if cmd_type not in self.info_dump: self.info_dump[cmd_type] = {} self.info_dump[cmd_type][cmd_id] = data - def print_info(self): + def print_info(self) -> None: if self.info_dump is None: return sys.stdout.write(json.dumps(self.info_dump, indent=2)) - def on_error(self): + def on_error(self) -> T.Tuple[AnsiDecorator, AnsiDecorator]: if self.skip_errors: return mlog.cyan('-->'), mlog.yellow('skipping') return mlog.cyan('-->'), mlog.red('aborting') - def handle_error(self): + def handle_error(self) -> None: if self.skip_errors: return None raise MesonException('Rewriting the meson.build failed') - def find_target(self, target: str): - def check_list(name: str) -> T.List[BaseNode]: + def find_target(self, target: str) -> T.Optional[IntrospectionBuildTarget]: + def check_list(name: str) -> T.List[IntrospectionBuildTarget]: result = [] for i in self.interpreter.targets: if name in {i.name, i.id}: @@ -426,10 +440,11 @@ class Rewriter: if node.func_name.value in {'executable', 'jar', 'library', 'shared_library', 'shared_module', 'static_library', 'both_libraries'}: tgt = self.interpreter.assign_vals[target] + assert isinstance(tgt, (type(None), IntrospectionBuildTarget)) return tgt def find_dependency(self, dependency: str) -> T.Optional[IntrospectionDependency]: - def check_list(name: str): + def check_list(name: str) -> T.Optional[IntrospectionDependency]: for i in self.interpreter.dependencies: if name == i.name: return i @@ -445,14 +460,15 @@ class Rewriter: if isinstance(node, FunctionNode): if node.func_name.value == 'dependency': name = self.interpreter.flatten_args(node.args)[0] + assert isinstance(name, str) dep = check_list(name) return dep @RequiredKeys(rewriter_keys['default_options']) - def process_default_options(self, cmd): + def process_default_options(self, cmd: T.Dict[str, T.Any]) -> None: # First, remove the old values - kwargs_cmd = { + kwargs_cmd: T.Dict[str, T.Any] = { 'function': 'project', 'id': "/", 'operation': 'remove_regex', @@ -496,7 +512,7 @@ class Rewriter: self.process_kwargs(kwargs_cmd) @RequiredKeys(rewriter_keys['kwargs']) - def process_kwargs(self, cmd): + def process_kwargs(self, cmd: T.Dict[str, T.Any]) -> None: mlog.log('Processing function type', mlog.bold(cmd['function']), 'with id', mlog.cyan("'" + cmd['id'] + "'")) if cmd['function'] not in rewriter_func_kwargs: mlog.error('Unknown function type', cmd['function'], *self.on_error()) @@ -531,12 +547,12 @@ class Rewriter: assert isinstance(node, FunctionNode) assert isinstance(arg_node, ArgumentNode) # Transform the key nodes to plain strings - arg_node.kwargs = {k.value: v for k, v in arg_node.kwargs.items()} + kwargs = {T.cast(IdNode, k).value: v for k, v in arg_node.kwargs.items()} # Print kwargs info if cmd['operation'] == 'info': - info_data = {} - for key, val in sorted(arg_node.kwargs.items()): + info_data: T.Dict[str, T.Any] = {} + for key, val in sorted(kwargs.items()): info_data[key] = None if isinstance(val, ElementaryNode): info_data[key] = val.value @@ -562,21 +578,21 @@ class Rewriter: if cmd['operation'] == 'delete': # Remove the key from the kwargs - if key not in arg_node.kwargs: + if key not in kwargs: mlog.log(' -- Key', mlog.bold(key), 'is already deleted') continue mlog.log(' -- Deleting', mlog.bold(key), 'from the kwargs') - del arg_node.kwargs[key] + del kwargs[key] elif cmd['operation'] == 'set': # Replace the key from the kwargs mlog.log(' -- Setting', mlog.bold(key), 'to', mlog.yellow(str(val))) - arg_node.kwargs[key] = kwargs_def[key].new_node(val) + kwargs[key] = kwargs_def[key].new_node(val) else: # Modify the value from the kwargs - if key not in arg_node.kwargs: - arg_node.kwargs[key] = None - modifier = kwargs_def[key](arg_node.kwargs[key]) + if key not in kwargs: + kwargs[key] = None + modifier = kwargs_def[key](kwargs[key]) if not modifier.can_modify(): mlog.log(' -- Skipping', mlog.bold(key), 'because it is too complex to modify') continue @@ -594,12 +610,12 @@ class Rewriter: modifier.remove_regex(val) # Write back the result - arg_node.kwargs[key] = modifier.get_node() + kwargs[key] = modifier.get_node() num_changed += 1 # Convert the keys back to IdNode's - arg_node.kwargs = {IdNode(Token('', '', 0, 0, 0, None, k)): v for k, v in arg_node.kwargs.items()} + arg_node.kwargs = {IdNode(Token('', '', 0, 0, 0, None, k)): v for k, v in kwargs.items()} for k, v in arg_node.kwargs.items(): k.level = v.level if num_changed > 0 and node not in self.modified_nodes: @@ -607,11 +623,13 @@ class Rewriter: def find_assignment_node(self, node: BaseNode) -> AssignmentNode: if node.ast_id and node.ast_id in self.interpreter.reverse_assignment: - return self.interpreter.reverse_assignment[node.ast_id] + ret = self.interpreter.reverse_assignment[node.ast_id] + assert isinstance(ret, AssignmentNode) + return ret return None @RequiredKeys(rewriter_keys['target']) - def process_target(self, cmd): + def process_target(self, cmd: T.Dict[str, T.Any]) -> None: mlog.log('Processing target', mlog.bold(cmd['target']), 'operation', mlog.cyan(cmd['operation'])) target = self.find_target(cmd['target']) if target is None and cmd['operation'] != 'target_add': @@ -632,7 +650,7 @@ class Rewriter: cmd['sources'] = [rel_source(x) for x in cmd['sources']] # Utility function to get a list of the sources from a node - def arg_list_from_node(n): + def arg_list_from_node(n: BaseNode) -> T.List[BaseNode]: args = [] if isinstance(n, FunctionNode): args = list(n.args.arguments) @@ -656,8 +674,8 @@ class Rewriter: # Generate the current source list src_list = [] - for i in target.source_nodes: - for j in arg_list_from_node(i): + for src_node in target.source_nodes: + for j in arg_list_from_node(src_node): if isinstance(j, StringNode): src_list += [j.value] @@ -689,7 +707,7 @@ class Rewriter: elif cmd['operation'] == 'src_rm': # Helper to find the exact string node and its parent - def find_node(src): + def find_node(src: str) -> T.Tuple[BaseNode, StringNode]: for i in target.source_nodes: for j in arg_list_from_node(i): if isinstance(j, StringNode): @@ -779,7 +797,7 @@ class Rewriter: elif cmd['operation'] == 'extra_files_rm': # Helper to find the exact string node and its parent - def find_node(src): + def find_node(src: str) -> T.Tuple[BaseNode, StringNode]: for i in target.extra_files: for j in arg_list_from_node(i): if isinstance(j, StringNode): @@ -794,6 +812,9 @@ class Rewriter: mlog.warning(' -- Unable to find extra file', mlog.green(i), 'in the target') continue + if not isinstance(root, (FunctionNode, ArrayNode)): + raise NotImplementedError # I'm lazy + # Remove the found string node from the argument list arg_node = root.args mlog.log(' -- Removing extra file', mlog.green(i), 'from', @@ -839,7 +860,7 @@ class Rewriter: self.to_add_nodes += [src_ass_node, tgt_ass_node] elif cmd['operation'] == 'target_rm': - to_remove = self.find_assignment_node(target.node) + to_remove: BaseNode = self.find_assignment_node(target.node) if to_remove is None: to_remove = target.node self.to_remove_nodes += [to_remove] @@ -867,16 +888,21 @@ class Rewriter: # Sort files for i in to_sort_nodes: - convert = lambda text: int(text) if text.isdigit() else text.lower() - alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - path_sorter = lambda key: ([(key.count('/') <= idx, alphanum_key(x)) for idx, x in enumerate(key.split('/'))]) + def convert(text: str) -> T.Union[int, str]: + return int(text) if text.isdigit() else text.lower() + + def alphanum_key(key: str) -> T.List[T.Union[int, str]]: + return [convert(c) for c in re.split('([0-9]+)', key)] + + def path_sorter(key: str) -> T.List[T.Tuple[bool, T.List[T.Union[int, str]]]]: + return [(key.count('/') <= idx, alphanum_key(x)) for idx, x in enumerate(key.split('/'))] unknown = [x for x in i.arguments if not isinstance(x, StringNode)] sources = [x for x in i.arguments if isinstance(x, StringNode)] sources = sorted(sources, key=lambda x: path_sorter(x.value)) - i.arguments = unknown + sources + i.arguments = unknown + T.cast(T.List[BaseNode], sources) - def process(self, cmd): + def process(self, cmd: T.Dict[str, T.Any]) -> None: if 'type' not in cmd: raise RewriterException('Command has no key "type"') if cmd['type'] not in self.functions: @@ -884,7 +910,7 @@ class Rewriter: .format(cmd['type'], list(self.functions.keys()))) self.functions[cmd['type']](cmd) - def apply_changes(self): + def apply_changes(self) -> None: assert all(hasattr(x, 'lineno') and hasattr(x, 'colno') and hasattr(x, 'filename') for x in self.modified_nodes) assert all(hasattr(x, 'lineno') and hasattr(x, 'colno') and hasattr(x, 'filename') for x in self.to_remove_nodes) assert all(isinstance(x, (ArrayNode, FunctionNode)) for x in self.modified_nodes) @@ -892,7 +918,7 @@ class Rewriter: # Sort based on line and column in reversed order work_nodes = [{'node': x, 'action': 'modify'} for x in self.modified_nodes] work_nodes += [{'node': x, 'action': 'rm'} for x in self.to_remove_nodes] - work_nodes = sorted(work_nodes, key=lambda x: (x['node'].lineno, x['node'].colno), reverse=True) + work_nodes = sorted(work_nodes, key=lambda x: (T.cast(BaseNode, x['node']).lineno, T.cast(BaseNode, x['node']).colno), reverse=True) work_nodes += [{'node': x, 'action': 'add'} for x in self.to_add_nodes] # Generating the new replacement string @@ -901,11 +927,11 @@ class Rewriter: new_data = '' if i['action'] == 'modify' or i['action'] == 'add': printer = AstPrinter() - i['node'].accept(printer) + T.cast(BaseNode, i['node']).accept(printer) printer.post_process() new_data = printer.result.strip() data = { - 'file': i['node'].filename, + 'file': T.cast(BaseNode, i['node']).filename, 'str': new_data, 'node': i['node'], 'action': i['action'] @@ -913,11 +939,11 @@ class Rewriter: str_list += [data] # Load build files - files = {} + files: T.Dict[str, T.Any] = {} for i in str_list: if i['file'] in files: continue - fpath = os.path.realpath(os.path.join(self.sourcedir, i['file'])) + fpath = os.path.realpath(os.path.join(self.sourcedir, T.cast(str, i['file']))) fdata = '' # Create an empty file if it does not exist if not os.path.exists(fpath): @@ -934,14 +960,14 @@ class Rewriter: line_offsets += [offset] offset += len(j) - files[i['file']] = { + files[T.cast(str, i['file'])] = { 'path': fpath, 'raw': fdata, 'offsets': line_offsets } # Replace in source code - def remove_node(i): + def remove_node(i: T.Dict[str, T.Any]) -> None: offsets = files[i['file']]['offsets'] raw = files[i['file']]['raw'] node = i['node'] @@ -969,7 +995,7 @@ class Rewriter: if i['action'] in {'modify', 'rm'}: remove_node(i) elif i['action'] == 'add': - files[i['file']]['raw'] += i['str'] + '\n' + files[T.cast(str, i['file'])]['raw'] += T.cast(str, i['str']) + '\n' # Write the files back for key, val in files.items(): @@ -1000,7 +1026,7 @@ def list_to_dict(in_list: T.List[str]) -> T.Dict[str, str]: raise TypeError('in_list parameter of list_to_dict must have an even length.') return result -def generate_target(options) -> T.List[dict]: +def generate_target(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]: return [{ 'type': 'target', 'target': options.target, @@ -1010,7 +1036,7 @@ def generate_target(options) -> T.List[dict]: 'target_type': options.tgt_type, }] -def generate_kwargs(options) -> T.List[dict]: +def generate_kwargs(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]: return [{ 'type': 'kwargs', 'function': options.function, @@ -1019,19 +1045,19 @@ def generate_kwargs(options) -> T.List[dict]: 'kwargs': list_to_dict(options.kwargs), }] -def generate_def_opts(options) -> T.List[dict]: +def generate_def_opts(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]: return [{ 'type': 'default_options', 'operation': options.operation, 'options': list_to_dict(options.options), }] -def generate_cmd(options) -> T.List[dict]: +def generate_cmd(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]: if os.path.exists(options.json): with open(options.json, encoding='utf-8') as fp: - return json.load(fp) + return T.cast(T.List[T.Dict[str, T.Any]], json.load(fp)) else: - return json.loads(options.json) + return T.cast(T.List[T.Dict[str, T.Any]], json.loads(options.json)) # Map options.type to the actual type name cli_type_map = { @@ -1044,7 +1070,7 @@ cli_type_map = { 'cmd': generate_cmd, } -def run(options): +def run(options: argparse.Namespace) -> int: mlog.redirect(True) if not options.verbose: mlog.set_quiet() diff --git a/run_mypy.py b/run_mypy.py index d7d3aaade..545328c67 100755 --- a/run_mypy.py +++ b/run_mypy.py @@ -83,6 +83,7 @@ modules = [ 'mesonbuild/optinterpreter.py', 'mesonbuild/options.py', 'mesonbuild/programs.py', + 'mesonbuild/rewriter.py', ] additional = [ 'run_mypy.py', |
