summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVolker Weißmann <volker.weissmann@gmx.de>2025-03-09 18:57:14 +0100
committerDylan Baker <dylan@pnwbakers.com>2025-05-29 09:20:27 -0700
commit4bbbd33d641e8761244342dc850e1b0ff97002b7 (patch)
tree09a7c35bce52cd748c359d0be7f8977200da642e
parent792db9439b30b64449097eb5e5fc940148cacaab (diff)
downloadmeson-4bbbd33d641e8761244342dc850e1b0ff97002b7.tar.gz
rewriter.py: make type safe
-rw-r--r--mesonbuild/ast/interpreter.py4
-rw-r--r--mesonbuild/ast/introspection.py10
-rw-r--r--mesonbuild/rewriter.py226
-rwxr-xr-xrun_mypy.py1
4 files changed, 137 insertions, 104 deletions
diff --git a/mesonbuild/ast/interpreter.py b/mesonbuild/ast/interpreter.py
index cde578a9a..b18bb5882 100644
--- a/mesonbuild/ast/interpreter.py
+++ b/mesonbuild/ast/interpreter.py
@@ -49,7 +49,7 @@ from ..mparser import (
if T.TYPE_CHECKING:
from .visitor import AstVisitor
from ..interpreter import Interpreter
- from ..interpreterbase import SubProject, TYPE_nkwargs, TYPE_var
+ from ..interpreterbase import SubProject, TYPE_nkwargs, TYPE_var, TYPE_nvar
from ..mparser import (
AndNode,
ComparisonNode,
@@ -387,7 +387,7 @@ class AstInterpreter(InterpreterBase):
return result
- def flatten_args(self, args_raw: T.Union[TYPE_var, T.Sequence[TYPE_var]], include_unknown_args: bool = False, id_loop_detect: T.Optional[T.List[str]] = None) -> T.List[TYPE_var]:
+ def flatten_args(self, args_raw: T.Union[TYPE_nvar, T.Sequence[TYPE_nvar]], include_unknown_args: bool = False, id_loop_detect: T.Optional[T.List[str]] = None) -> T.List[TYPE_var]:
# Make sure we are always dealing with lists
if isinstance(args_raw, list):
args = args_raw
diff --git a/mesonbuild/ast/introspection.py b/mesonbuild/ast/introspection.py
index 88ca72c84..b86b4e1e0 100644
--- a/mesonbuild/ast/introspection.py
+++ b/mesonbuild/ast/introspection.py
@@ -67,7 +67,7 @@ class IntrospectionInterpreter(AstInterpreter):
self.project_data: T.Dict[str, T.Any] = {}
self.targets: T.List[IntrospectionBuildTarget] = []
self.dependencies: T.List[IntrospectionDependency] = []
- self.project_node: BaseNode = None
+ self.project_node: FunctionNode = None
self.funcs.update({
'add_languages': self.func_add_languages,
@@ -85,6 +85,7 @@ class IntrospectionInterpreter(AstInterpreter):
def func_project(self, node: BaseNode, args: T.List[TYPE_var], kwargs: T.Dict[str, TYPE_var]) -> None:
if self.project_node:
raise InvalidArguments('Second call to project()')
+ assert isinstance(node, FunctionNode)
self.project_node = node
if len(args) < 1:
raise InvalidArguments('Not enough arguments to project(). Needs at least the project name.')
@@ -211,16 +212,20 @@ class IntrospectionInterpreter(AstInterpreter):
self.coredata.process_compiler_options(lang, comp, self.subproject)
def func_dependency(self, node: BaseNode, args: T.List[TYPE_var], kwargs: T.Dict[str, TYPE_var]) -> None:
+ assert isinstance(node, FunctionNode)
args = self.flatten_args(args)
kwargs = self.flatten_kwargs(kwargs)
if not args:
return
name = args[0]
+ assert isinstance(name, str)
has_fallback = 'fallback' in kwargs
required = kwargs.get('required', True)
version = kwargs.get('version', [])
if not isinstance(version, list):
version = [version]
+ assert all(isinstance(el, str) for el in version)
+ version = T.cast(T.List[str], version)
if isinstance(required, ElementaryNode):
required = required.value
if not isinstance(required, bool):
@@ -235,11 +240,12 @@ class IntrospectionInterpreter(AstInterpreter):
self.dependencies += [newdep]
def build_target(self, node: BaseNode, args: T.List[TYPE_var], kwargs_raw: T.Dict[str, TYPE_var], targetclass: T.Type[BuildTarget]) -> IntrospectionBuildTarget:
+ assert isinstance(node, FunctionNode)
args = self.flatten_args(args)
if not args or not isinstance(args[0], str):
return None
name = args[0]
- srcqueue = [node]
+ srcqueue: T.List[BaseNode] = [node]
extra_queue = []
# Process the sources BEFORE flattening the kwargs, to preserve the original nodes
diff --git a/mesonbuild/rewriter.py b/mesonbuild/rewriter.py
index 253e3f2fb..fff3f9cfc 100644
--- a/mesonbuild/rewriter.py
+++ b/mesonbuild/rewriter.py
@@ -10,24 +10,26 @@
from __future__ import annotations
from .ast import IntrospectionInterpreter, BUILD_TARGET_FUNCTIONS, AstConditionLevel, AstIDGenerator, AstIndentationGenerator, AstPrinter
-from .ast.interpreter import IntrospectionDependency
+from .ast.interpreter import IntrospectionBuildTarget, IntrospectionDependency
+from .interpreterbase import TV_func
from mesonbuild.mesonlib import MesonException, setup_vsenv
from . import mlog, environment
from functools import wraps
-from .mparser import Token, ArrayNode, ArgumentNode, AssignmentNode, StringNode, BooleanNode, ElementaryNode, IdNode, FunctionNode, SymbolNode
+from .mparser import Token, ArrayNode, ArgumentNode, AssignmentNode, BaseNode, StringNode, BooleanNode, ElementaryNode, IdNode, FunctionNode, SymbolNode
import json, os, re, sys
import typing as T
if T.TYPE_CHECKING:
- from argparse import ArgumentParser, HelpFormatter
- from .mparser import BaseNode
+ import argparse
+ from argparse import ArgumentParser, _FormatterClass
+ from .mlog import AnsiDecorator
class RewriterException(MesonException):
pass
# Note: when adding arguments, please also add them to the completion
# scripts in $MESONSRC/data/shell-completions/
-def add_arguments(parser: ArgumentParser, formatter: T.Callable[[str], HelpFormatter]) -> None:
+def add_arguments(parser: ArgumentParser, formatter: _FormatterClass) -> None:
parser.add_argument('-s', '--sourcedir', type=str, default='.', metavar='SRCDIR', help='Path to source directory.')
parser.add_argument('-V', '--verbose', action='store_true', default=False, help='Enable verbose output')
parser.add_argument('-S', '--skip-errors', dest='skip', action='store_true', default=False, help='Skip errors instead of aborting')
@@ -63,12 +65,14 @@ def add_arguments(parser: ArgumentParser, formatter: T.Callable[[str], HelpForma
cmd_parser.add_argument('json', help='JSON string or file to execute')
class RequiredKeys:
- def __init__(self, keys):
+ keys: T.Dict[str, T.Any]
+
+ def __init__(self, keys: T.Dict[str, T.Any]):
self.keys = keys
- def __call__(self, f):
+ def __call__(self, f: TV_func) -> TV_func:
@wraps(f)
- def wrapped(*wrapped_args, **wrapped_kwargs):
+ def wrapped(*wrapped_args: T.Any, **wrapped_kwargs: T.Any) -> T.Any:
assert len(wrapped_args) >= 2
cmd = wrapped_args[1]
for key, val in self.keys.items():
@@ -91,12 +95,14 @@ class RequiredKeys:
.format(key, choices, cmd[key]))
return f(*wrapped_args, **wrapped_kwargs)
- return wrapped
+ return T.cast('TV_func', wrapped)
def _symbol(val: str) -> SymbolNode:
return SymbolNode(Token('', '', 0, 0, 0, (0, 0), val))
class MTypeBase:
+ node: BaseNode
+
def __init__(self, node: T.Optional[BaseNode] = None):
if node is None:
self.node = self.new_node()
@@ -108,30 +114,30 @@ class MTypeBase:
self.node_type = i
@classmethod
- def new_node(cls, value=None):
+ def new_node(cls, value: T.Any = None) -> BaseNode:
# Overwrite in derived class
raise RewriterException('Internal error: new_node of MTypeBase was called')
@classmethod
- def supported_nodes(cls):
+ def supported_nodes(cls) -> T.List[type]:
# Overwrite in derived class
return []
- def can_modify(self):
+ def can_modify(self) -> bool:
return self.node_type is not None
- def get_node(self):
+ def get_node(self) -> BaseNode:
return self.node
- def add_value(self, value):
+ def add_value(self, value: T.Any) -> None:
# Overwrite in derived class
mlog.warning('Cannot add a value of type', mlog.bold(type(self).__name__), '--> skipping')
- def remove_value(self, value):
+ def remove_value(self, value: T.Any) -> None:
# Overwrite in derived class
mlog.warning('Cannot remove a value of type', mlog.bold(type(self).__name__), '--> skipping')
- def remove_regex(self, value):
+ def remove_regex(self, value: T.Any) -> None:
# Overwrite in derived class
mlog.warning('Cannot remove a regex in type', mlog.bold(type(self).__name__), '--> skipping')
@@ -140,13 +146,13 @@ class MTypeStr(MTypeBase):
super().__init__(node)
@classmethod
- def new_node(cls, value=None):
+ def new_node(cls, value: T.Optional[str] = None) -> BaseNode:
if value is None:
value = ''
return StringNode(Token('string', '', 0, 0, 0, None, str(value)))
@classmethod
- def supported_nodes(cls):
+ def supported_nodes(cls) -> T.List[type]:
return [StringNode]
class MTypeBool(MTypeBase):
@@ -154,11 +160,11 @@ class MTypeBool(MTypeBase):
super().__init__(node)
@classmethod
- def new_node(cls, value=None):
+ def new_node(cls, value: T.Optional[str] = None) -> BaseNode:
return BooleanNode(Token('', '', 0, 0, 0, None, bool(value)))
@classmethod
- def supported_nodes(cls):
+ def supported_nodes(cls) -> T.List[type]:
return [BooleanNode]
class MTypeID(MTypeBase):
@@ -166,21 +172,23 @@ class MTypeID(MTypeBase):
super().__init__(node)
@classmethod
- def new_node(cls, value=None):
+ def new_node(cls, value: T.Optional[str] = None) -> BaseNode:
if value is None:
value = ''
return IdNode(Token('', '', 0, 0, 0, None, str(value)))
@classmethod
- def supported_nodes(cls):
+ def supported_nodes(cls) -> T.List[type]:
return [IdNode]
class MTypeList(MTypeBase):
+ node: ArrayNode
+
def __init__(self, node: T.Optional[BaseNode] = None):
super().__init__(node)
@classmethod
- def new_node(cls, value=None):
+ def new_node(cls, value: T.Optional[T.List[T.Any]] = None) -> ArrayNode:
if value is None:
value = []
elif not isinstance(value, list):
@@ -190,50 +198,52 @@ class MTypeList(MTypeBase):
return ArrayNode(_symbol('['), args, _symbol(']'))
@classmethod
- def _new_element_node(cls, value):
+ def _new_element_node(cls, value: T.Any) -> BaseNode:
# Overwrite in derived class
raise RewriterException('Internal error: _new_element_node of MTypeList was called')
- def _ensure_array_node(self):
+ def _ensure_array_node(self) -> None:
if not isinstance(self.node, ArrayNode):
tmp = self.node
self.node = self.new_node()
self.node.args.arguments = [tmp]
@staticmethod
- def _check_is_equal(node, value) -> bool:
+ def _check_is_equal(node: BaseNode, value: str) -> bool:
# Overwrite in derived class
return False
@staticmethod
- def _check_regex_matches(node, regex: str) -> bool:
+ def _check_regex_matches(node: BaseNode, regex: str) -> bool:
# Overwrite in derived class
return False
- def get_node(self):
+ def get_node(self) -> BaseNode:
if isinstance(self.node, ArrayNode):
if len(self.node.args.arguments) == 1:
return self.node.args.arguments[0]
return self.node
@classmethod
- def supported_element_nodes(cls):
+ def supported_element_nodes(cls) -> T.List[T.Type]:
# Overwrite in derived class
return []
@classmethod
- def supported_nodes(cls):
+ def supported_nodes(cls) -> T.List[T.Type]:
return [ArrayNode] + cls.supported_element_nodes()
- def add_value(self, value):
+ def add_value(self, value: T.Any) -> None:
if not isinstance(value, list):
value = [value]
self._ensure_array_node()
for i in value:
+ assert hasattr(self.node, 'args') # For mypy
+ assert isinstance(self.node.args, ArgumentNode) # For mypy
self.node.args.arguments += [self._new_element_node(i)]
- def _remove_helper(self, value, equal_func):
- def check_remove_node(node):
+ def _remove_helper(self, value: T.Any, equal_func: T.Callable[[T.Any, T.Any], bool]) -> None:
+ def check_remove_node(node: BaseNode) -> bool:
for j in value:
if equal_func(i, j):
return True
@@ -242,16 +252,18 @@ class MTypeList(MTypeBase):
if not isinstance(value, list):
value = [value]
self._ensure_array_node()
+ assert hasattr(self.node, 'args') # For mypy
+ assert isinstance(self.node.args, ArgumentNode) # For mypy
removed_list = []
for i in self.node.args.arguments:
if not check_remove_node(i):
removed_list += [i]
self.node.args.arguments = removed_list
- def remove_value(self, value):
+ def remove_value(self, value: T.Any) -> None:
self._remove_helper(value, self._check_is_equal)
- def remove_regex(self, regex: str):
+ def remove_regex(self, regex: str) -> None:
self._remove_helper(regex, self._check_regex_matches)
class MTypeStrList(MTypeList):
@@ -259,23 +271,23 @@ class MTypeStrList(MTypeList):
super().__init__(node)
@classmethod
- def _new_element_node(cls, value):
+ def _new_element_node(cls, value: str) -> StringNode:
return StringNode(Token('string', '', 0, 0, 0, None, str(value)))
@staticmethod
- def _check_is_equal(node, value) -> bool:
+ def _check_is_equal(node: BaseNode, value: str) -> bool:
if isinstance(node, StringNode):
- return node.value == value
+ return bool(node.value == value)
return False
@staticmethod
- def _check_regex_matches(node, regex: str) -> bool:
+ def _check_regex_matches(node: BaseNode, regex: str) -> bool:
if isinstance(node, StringNode):
return re.match(regex, node.value) is not None
return False
@classmethod
- def supported_element_nodes(cls):
+ def supported_element_nodes(cls) -> T.List[T.Type]:
return [StringNode]
class MTypeIDList(MTypeList):
@@ -283,26 +295,26 @@ class MTypeIDList(MTypeList):
super().__init__(node)
@classmethod
- def _new_element_node(cls, value):
+ def _new_element_node(cls, value: str) -> IdNode:
return IdNode(Token('', '', 0, 0, 0, None, str(value)))
@staticmethod
- def _check_is_equal(node, value) -> bool:
+ def _check_is_equal(node: BaseNode, value: str) -> bool:
if isinstance(node, IdNode):
- return node.value == value
+ return bool(node.value == value)
return False
@staticmethod
- def _check_regex_matches(node, regex: str) -> bool:
+ def _check_regex_matches(node: BaseNode, regex: str) -> bool:
if isinstance(node, StringNode):
return re.match(regex, node.value) is not None
return False
@classmethod
- def supported_element_nodes(cls):
+ def supported_element_nodes(cls) -> T.List[T.Type]:
return [IdNode]
-rewriter_keys = {
+rewriter_keys: T.Dict[str, T.Dict[str, T.Any]] = {
'default_options': {
'operation': (str, None, ['set', 'delete']),
'options': (dict, {}, None)
@@ -356,13 +368,15 @@ rewriter_func_kwargs = {
}
class Rewriter:
+ info_dump: T.Optional[T.Dict[str, T.Dict[str, T.Any]]]
+
def __init__(self, sourcedir: str, generator: str = 'ninja', skip_errors: bool = False):
self.sourcedir = sourcedir
self.interpreter = IntrospectionInterpreter(sourcedir, '', generator, visitors = [AstIDGenerator(), AstIndentationGenerator(), AstConditionLevel()])
self.skip_errors = skip_errors
- self.modified_nodes = []
- self.to_remove_nodes = []
- self.to_add_nodes = []
+ self.modified_nodes: T.List[BaseNode] = []
+ self.to_remove_nodes: T.List[BaseNode] = []
+ self.to_add_nodes: T.List[BaseNode] = []
self.functions = {
'default_options': self.process_default_options,
'kwargs': self.process_kwargs,
@@ -370,36 +384,36 @@ class Rewriter:
}
self.info_dump = None
- def analyze_meson(self):
+ def analyze_meson(self) -> None:
mlog.log('Analyzing meson file:', mlog.bold(os.path.join(self.sourcedir, environment.build_filename)))
self.interpreter.analyze()
mlog.log(' -- Project:', mlog.bold(self.interpreter.project_data['descriptive_name']))
mlog.log(' -- Version:', mlog.cyan(self.interpreter.project_data['version']))
- def add_info(self, cmd_type: str, cmd_id: str, data: dict):
+ def add_info(self, cmd_type: str, cmd_id: str, data: dict) -> None:
if self.info_dump is None:
self.info_dump = {}
if cmd_type not in self.info_dump:
self.info_dump[cmd_type] = {}
self.info_dump[cmd_type][cmd_id] = data
- def print_info(self):
+ def print_info(self) -> None:
if self.info_dump is None:
return
sys.stdout.write(json.dumps(self.info_dump, indent=2))
- def on_error(self):
+ def on_error(self) -> T.Tuple[AnsiDecorator, AnsiDecorator]:
if self.skip_errors:
return mlog.cyan('-->'), mlog.yellow('skipping')
return mlog.cyan('-->'), mlog.red('aborting')
- def handle_error(self):
+ def handle_error(self) -> None:
if self.skip_errors:
return None
raise MesonException('Rewriting the meson.build failed')
- def find_target(self, target: str):
- def check_list(name: str) -> T.List[BaseNode]:
+ def find_target(self, target: str) -> T.Optional[IntrospectionBuildTarget]:
+ def check_list(name: str) -> T.List[IntrospectionBuildTarget]:
result = []
for i in self.interpreter.targets:
if name in {i.name, i.id}:
@@ -426,10 +440,11 @@ class Rewriter:
if node.func_name.value in {'executable', 'jar', 'library', 'shared_library', 'shared_module', 'static_library', 'both_libraries'}:
tgt = self.interpreter.assign_vals[target]
+ assert isinstance(tgt, (type(None), IntrospectionBuildTarget))
return tgt
def find_dependency(self, dependency: str) -> T.Optional[IntrospectionDependency]:
- def check_list(name: str):
+ def check_list(name: str) -> T.Optional[IntrospectionDependency]:
for i in self.interpreter.dependencies:
if name == i.name:
return i
@@ -445,14 +460,15 @@ class Rewriter:
if isinstance(node, FunctionNode):
if node.func_name.value == 'dependency':
name = self.interpreter.flatten_args(node.args)[0]
+ assert isinstance(name, str)
dep = check_list(name)
return dep
@RequiredKeys(rewriter_keys['default_options'])
- def process_default_options(self, cmd):
+ def process_default_options(self, cmd: T.Dict[str, T.Any]) -> None:
# First, remove the old values
- kwargs_cmd = {
+ kwargs_cmd: T.Dict[str, T.Any] = {
'function': 'project',
'id': "/",
'operation': 'remove_regex',
@@ -496,7 +512,7 @@ class Rewriter:
self.process_kwargs(kwargs_cmd)
@RequiredKeys(rewriter_keys['kwargs'])
- def process_kwargs(self, cmd):
+ def process_kwargs(self, cmd: T.Dict[str, T.Any]) -> None:
mlog.log('Processing function type', mlog.bold(cmd['function']), 'with id', mlog.cyan("'" + cmd['id'] + "'"))
if cmd['function'] not in rewriter_func_kwargs:
mlog.error('Unknown function type', cmd['function'], *self.on_error())
@@ -531,12 +547,12 @@ class Rewriter:
assert isinstance(node, FunctionNode)
assert isinstance(arg_node, ArgumentNode)
# Transform the key nodes to plain strings
- arg_node.kwargs = {k.value: v for k, v in arg_node.kwargs.items()}
+ kwargs = {T.cast(IdNode, k).value: v for k, v in arg_node.kwargs.items()}
# Print kwargs info
if cmd['operation'] == 'info':
- info_data = {}
- for key, val in sorted(arg_node.kwargs.items()):
+ info_data: T.Dict[str, T.Any] = {}
+ for key, val in sorted(kwargs.items()):
info_data[key] = None
if isinstance(val, ElementaryNode):
info_data[key] = val.value
@@ -562,21 +578,21 @@ class Rewriter:
if cmd['operation'] == 'delete':
# Remove the key from the kwargs
- if key not in arg_node.kwargs:
+ if key not in kwargs:
mlog.log(' -- Key', mlog.bold(key), 'is already deleted')
continue
mlog.log(' -- Deleting', mlog.bold(key), 'from the kwargs')
- del arg_node.kwargs[key]
+ del kwargs[key]
elif cmd['operation'] == 'set':
# Replace the key from the kwargs
mlog.log(' -- Setting', mlog.bold(key), 'to', mlog.yellow(str(val)))
- arg_node.kwargs[key] = kwargs_def[key].new_node(val)
+ kwargs[key] = kwargs_def[key].new_node(val)
else:
# Modify the value from the kwargs
- if key not in arg_node.kwargs:
- arg_node.kwargs[key] = None
- modifier = kwargs_def[key](arg_node.kwargs[key])
+ if key not in kwargs:
+ kwargs[key] = None
+ modifier = kwargs_def[key](kwargs[key])
if not modifier.can_modify():
mlog.log(' -- Skipping', mlog.bold(key), 'because it is too complex to modify')
continue
@@ -594,12 +610,12 @@ class Rewriter:
modifier.remove_regex(val)
# Write back the result
- arg_node.kwargs[key] = modifier.get_node()
+ kwargs[key] = modifier.get_node()
num_changed += 1
# Convert the keys back to IdNode's
- arg_node.kwargs = {IdNode(Token('', '', 0, 0, 0, None, k)): v for k, v in arg_node.kwargs.items()}
+ arg_node.kwargs = {IdNode(Token('', '', 0, 0, 0, None, k)): v for k, v in kwargs.items()}
for k, v in arg_node.kwargs.items():
k.level = v.level
if num_changed > 0 and node not in self.modified_nodes:
@@ -607,11 +623,13 @@ class Rewriter:
def find_assignment_node(self, node: BaseNode) -> AssignmentNode:
if node.ast_id and node.ast_id in self.interpreter.reverse_assignment:
- return self.interpreter.reverse_assignment[node.ast_id]
+ ret = self.interpreter.reverse_assignment[node.ast_id]
+ assert isinstance(ret, AssignmentNode)
+ return ret
return None
@RequiredKeys(rewriter_keys['target'])
- def process_target(self, cmd):
+ def process_target(self, cmd: T.Dict[str, T.Any]) -> None:
mlog.log('Processing target', mlog.bold(cmd['target']), 'operation', mlog.cyan(cmd['operation']))
target = self.find_target(cmd['target'])
if target is None and cmd['operation'] != 'target_add':
@@ -632,7 +650,7 @@ class Rewriter:
cmd['sources'] = [rel_source(x) for x in cmd['sources']]
# Utility function to get a list of the sources from a node
- def arg_list_from_node(n):
+ def arg_list_from_node(n: BaseNode) -> T.List[BaseNode]:
args = []
if isinstance(n, FunctionNode):
args = list(n.args.arguments)
@@ -656,8 +674,8 @@ class Rewriter:
# Generate the current source list
src_list = []
- for i in target.source_nodes:
- for j in arg_list_from_node(i):
+ for src_node in target.source_nodes:
+ for j in arg_list_from_node(src_node):
if isinstance(j, StringNode):
src_list += [j.value]
@@ -689,7 +707,7 @@ class Rewriter:
elif cmd['operation'] == 'src_rm':
# Helper to find the exact string node and its parent
- def find_node(src):
+ def find_node(src: str) -> T.Tuple[BaseNode, StringNode]:
for i in target.source_nodes:
for j in arg_list_from_node(i):
if isinstance(j, StringNode):
@@ -779,7 +797,7 @@ class Rewriter:
elif cmd['operation'] == 'extra_files_rm':
# Helper to find the exact string node and its parent
- def find_node(src):
+ def find_node(src: str) -> T.Tuple[BaseNode, StringNode]:
for i in target.extra_files:
for j in arg_list_from_node(i):
if isinstance(j, StringNode):
@@ -794,6 +812,9 @@ class Rewriter:
mlog.warning(' -- Unable to find extra file', mlog.green(i), 'in the target')
continue
+ if not isinstance(root, (FunctionNode, ArrayNode)):
+ raise NotImplementedError # I'm lazy
+
# Remove the found string node from the argument list
arg_node = root.args
mlog.log(' -- Removing extra file', mlog.green(i), 'from',
@@ -839,7 +860,7 @@ class Rewriter:
self.to_add_nodes += [src_ass_node, tgt_ass_node]
elif cmd['operation'] == 'target_rm':
- to_remove = self.find_assignment_node(target.node)
+ to_remove: BaseNode = self.find_assignment_node(target.node)
if to_remove is None:
to_remove = target.node
self.to_remove_nodes += [to_remove]
@@ -867,16 +888,21 @@ class Rewriter:
# Sort files
for i in to_sort_nodes:
- convert = lambda text: int(text) if text.isdigit() else text.lower()
- alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
- path_sorter = lambda key: ([(key.count('/') <= idx, alphanum_key(x)) for idx, x in enumerate(key.split('/'))])
+ def convert(text: str) -> T.Union[int, str]:
+ return int(text) if text.isdigit() else text.lower()
+
+ def alphanum_key(key: str) -> T.List[T.Union[int, str]]:
+ return [convert(c) for c in re.split('([0-9]+)', key)]
+
+ def path_sorter(key: str) -> T.List[T.Tuple[bool, T.List[T.Union[int, str]]]]:
+ return [(key.count('/') <= idx, alphanum_key(x)) for idx, x in enumerate(key.split('/'))]
unknown = [x for x in i.arguments if not isinstance(x, StringNode)]
sources = [x for x in i.arguments if isinstance(x, StringNode)]
sources = sorted(sources, key=lambda x: path_sorter(x.value))
- i.arguments = unknown + sources
+ i.arguments = unknown + T.cast(T.List[BaseNode], sources)
- def process(self, cmd):
+ def process(self, cmd: T.Dict[str, T.Any]) -> None:
if 'type' not in cmd:
raise RewriterException('Command has no key "type"')
if cmd['type'] not in self.functions:
@@ -884,7 +910,7 @@ class Rewriter:
.format(cmd['type'], list(self.functions.keys())))
self.functions[cmd['type']](cmd)
- def apply_changes(self):
+ def apply_changes(self) -> None:
assert all(hasattr(x, 'lineno') and hasattr(x, 'colno') and hasattr(x, 'filename') for x in self.modified_nodes)
assert all(hasattr(x, 'lineno') and hasattr(x, 'colno') and hasattr(x, 'filename') for x in self.to_remove_nodes)
assert all(isinstance(x, (ArrayNode, FunctionNode)) for x in self.modified_nodes)
@@ -892,7 +918,7 @@ class Rewriter:
# Sort based on line and column in reversed order
work_nodes = [{'node': x, 'action': 'modify'} for x in self.modified_nodes]
work_nodes += [{'node': x, 'action': 'rm'} for x in self.to_remove_nodes]
- work_nodes = sorted(work_nodes, key=lambda x: (x['node'].lineno, x['node'].colno), reverse=True)
+ work_nodes = sorted(work_nodes, key=lambda x: (T.cast(BaseNode, x['node']).lineno, T.cast(BaseNode, x['node']).colno), reverse=True)
work_nodes += [{'node': x, 'action': 'add'} for x in self.to_add_nodes]
# Generating the new replacement string
@@ -901,11 +927,11 @@ class Rewriter:
new_data = ''
if i['action'] == 'modify' or i['action'] == 'add':
printer = AstPrinter()
- i['node'].accept(printer)
+ T.cast(BaseNode, i['node']).accept(printer)
printer.post_process()
new_data = printer.result.strip()
data = {
- 'file': i['node'].filename,
+ 'file': T.cast(BaseNode, i['node']).filename,
'str': new_data,
'node': i['node'],
'action': i['action']
@@ -913,11 +939,11 @@ class Rewriter:
str_list += [data]
# Load build files
- files = {}
+ files: T.Dict[str, T.Any] = {}
for i in str_list:
if i['file'] in files:
continue
- fpath = os.path.realpath(os.path.join(self.sourcedir, i['file']))
+ fpath = os.path.realpath(os.path.join(self.sourcedir, T.cast(str, i['file'])))
fdata = ''
# Create an empty file if it does not exist
if not os.path.exists(fpath):
@@ -934,14 +960,14 @@ class Rewriter:
line_offsets += [offset]
offset += len(j)
- files[i['file']] = {
+ files[T.cast(str, i['file'])] = {
'path': fpath,
'raw': fdata,
'offsets': line_offsets
}
# Replace in source code
- def remove_node(i):
+ def remove_node(i: T.Dict[str, T.Any]) -> None:
offsets = files[i['file']]['offsets']
raw = files[i['file']]['raw']
node = i['node']
@@ -969,7 +995,7 @@ class Rewriter:
if i['action'] in {'modify', 'rm'}:
remove_node(i)
elif i['action'] == 'add':
- files[i['file']]['raw'] += i['str'] + '\n'
+ files[T.cast(str, i['file'])]['raw'] += T.cast(str, i['str']) + '\n'
# Write the files back
for key, val in files.items():
@@ -1000,7 +1026,7 @@ def list_to_dict(in_list: T.List[str]) -> T.Dict[str, str]:
raise TypeError('in_list parameter of list_to_dict must have an even length.')
return result
-def generate_target(options) -> T.List[dict]:
+def generate_target(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]:
return [{
'type': 'target',
'target': options.target,
@@ -1010,7 +1036,7 @@ def generate_target(options) -> T.List[dict]:
'target_type': options.tgt_type,
}]
-def generate_kwargs(options) -> T.List[dict]:
+def generate_kwargs(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]:
return [{
'type': 'kwargs',
'function': options.function,
@@ -1019,19 +1045,19 @@ def generate_kwargs(options) -> T.List[dict]:
'kwargs': list_to_dict(options.kwargs),
}]
-def generate_def_opts(options) -> T.List[dict]:
+def generate_def_opts(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]:
return [{
'type': 'default_options',
'operation': options.operation,
'options': list_to_dict(options.options),
}]
-def generate_cmd(options) -> T.List[dict]:
+def generate_cmd(options: argparse.Namespace) -> T.List[T.Dict[str, T.Any]]:
if os.path.exists(options.json):
with open(options.json, encoding='utf-8') as fp:
- return json.load(fp)
+ return T.cast(T.List[T.Dict[str, T.Any]], json.load(fp))
else:
- return json.loads(options.json)
+ return T.cast(T.List[T.Dict[str, T.Any]], json.loads(options.json))
# Map options.type to the actual type name
cli_type_map = {
@@ -1044,7 +1070,7 @@ cli_type_map = {
'cmd': generate_cmd,
}
-def run(options):
+def run(options: argparse.Namespace) -> int:
mlog.redirect(True)
if not options.verbose:
mlog.set_quiet()
diff --git a/run_mypy.py b/run_mypy.py
index d7d3aaade..545328c67 100755
--- a/run_mypy.py
+++ b/run_mypy.py
@@ -83,6 +83,7 @@ modules = [
'mesonbuild/optinterpreter.py',
'mesonbuild/options.py',
'mesonbuild/programs.py',
+ 'mesonbuild/rewriter.py',
]
additional = [
'run_mypy.py',