- 8 DB models (services, incidents, monitors, subscribers, etc.) - Full CRUD API for services, incidents, monitors - Public status page with live data - Incident detail page with timeline - API key authentication - Uptime monitoring scheduler - 13 tests passing - TECHNICAL_DESIGN.md with full spec
2094 lines
71 KiB
Python
2094 lines
71 KiB
Python
# mypy: allow-redefinition-new, local-partial-types
|
|
"""Python parser that directly constructs a native AST (when compiled).
|
|
|
|
Use a Rust extension to generate a serialized AST, and deserialize the AST directly
|
|
to a mypy AST.
|
|
|
|
NOTE: This is work in progress. To use this, you need to manually build the
|
|
ast_serialize Rust extension. See the README at https://github.com/mypyc/ast_serialize.
|
|
|
|
Expected benefits over mypy.fastparse:
|
|
* No intermediate non-mypyc Python-level AST created, to improve performance
|
|
* Parsing doesn't need GIL => use multithreading to construct serialized ASTs in parallel
|
|
* Produce import dependencies without having to build an AST => helps parallel type checking
|
|
* Support all Python syntax even if mypy is running on an older Python version
|
|
* Generate an AST even if there are syntax errors
|
|
* Potential to support incremental parsing (quickly process modified sections in a file)
|
|
* Stripping function bodies in third-party code can happen earlier, for extra performance
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Any, Final, cast
|
|
|
|
import ast_serialize # type: ignore[import-untyped, import-not-found, unused-ignore]
|
|
from librt.internal import (
|
|
read_float as read_float_bare,
|
|
read_int as read_int_bare,
|
|
read_str as read_str_bare,
|
|
)
|
|
|
|
from mypy import message_registry, nodes, types
|
|
from mypy.cache import (
|
|
DICT_STR_GEN,
|
|
END_TAG,
|
|
LIST_GEN,
|
|
LIST_INT,
|
|
LITERAL_FLOAT,
|
|
LITERAL_NONE,
|
|
LITERAL_STR,
|
|
LOCATION,
|
|
ReadBuffer,
|
|
Tag,
|
|
read_bool,
|
|
read_int,
|
|
read_str,
|
|
read_str_opt,
|
|
read_tag,
|
|
)
|
|
from mypy.nodes import (
|
|
ARG_KINDS,
|
|
ARG_POS,
|
|
IMPORT_METADATA,
|
|
IMPORTALL_METADATA,
|
|
IMPORTFROM_METADATA,
|
|
MISSING_FALLBACK,
|
|
Argument,
|
|
AssertStmt,
|
|
AssignmentExpr,
|
|
AssignmentStmt,
|
|
AwaitExpr,
|
|
Block,
|
|
BreakStmt,
|
|
BytesExpr,
|
|
CallExpr,
|
|
ClassDef,
|
|
ComparisonExpr,
|
|
ComplexExpr,
|
|
ConditionalExpr,
|
|
Context,
|
|
ContinueStmt,
|
|
Decorator,
|
|
DelStmt,
|
|
DictExpr,
|
|
DictionaryComprehension,
|
|
EllipsisExpr,
|
|
Expression,
|
|
ExpressionStmt,
|
|
FileRawData,
|
|
FloatExpr,
|
|
ForStmt,
|
|
FuncDef,
|
|
GeneratorExpr,
|
|
GlobalDecl,
|
|
IfStmt,
|
|
Import,
|
|
ImportAll,
|
|
ImportBase,
|
|
ImportFrom,
|
|
IndexExpr,
|
|
IntExpr,
|
|
LambdaExpr,
|
|
ListComprehension,
|
|
ListExpr,
|
|
MatchStmt,
|
|
MemberExpr,
|
|
MypyFile,
|
|
NameExpr,
|
|
NonlocalDecl,
|
|
OperatorAssignmentStmt,
|
|
OpExpr,
|
|
OverloadedFuncDef,
|
|
OverloadPart,
|
|
PassStmt,
|
|
RaiseStmt,
|
|
RefExpr,
|
|
ReturnStmt,
|
|
SetComprehension,
|
|
SetExpr,
|
|
SliceExpr,
|
|
StarExpr,
|
|
Statement,
|
|
StrExpr,
|
|
SuperExpr,
|
|
TemplateStrExpr,
|
|
TempNode,
|
|
TryStmt,
|
|
TupleExpr,
|
|
TypeAliasStmt,
|
|
TypeParam,
|
|
UnaryExpr,
|
|
Var,
|
|
WhileStmt,
|
|
WithStmt,
|
|
YieldExpr,
|
|
YieldFromExpr,
|
|
)
|
|
from mypy.options import Options
|
|
from mypy.patterns import (
|
|
AsPattern,
|
|
ClassPattern,
|
|
MappingPattern,
|
|
OrPattern,
|
|
Pattern,
|
|
SequencePattern,
|
|
SingletonPattern,
|
|
StarredPattern,
|
|
ValuePattern,
|
|
)
|
|
from mypy.reachability import infer_reachability_of_if_statement
|
|
from mypy.sharedparse import special_function_elide_names
|
|
from mypy.types import (
|
|
AnyType,
|
|
CallableArgument,
|
|
CallableType,
|
|
EllipsisType,
|
|
Instance,
|
|
ProperType,
|
|
RawExpressionType,
|
|
TupleType,
|
|
Type,
|
|
TypedDictType,
|
|
TypeList,
|
|
TypeOfAny,
|
|
UnboundType,
|
|
UnionType,
|
|
UnpackType,
|
|
)
|
|
from mypy.util import unnamed_function
|
|
|
|
# Per-file "# type: ignore[...]" data: list of (line number, list of error codes).
TypeIgnores = list[tuple[int, list[str]]]


# There is no way to create reasonable fallbacks at this stage,
# they must be patched later.
_dummy_fallback: Final = Instance(MISSING_FALLBACK, [], -1)
|
|
|
|
|
|
class State:
    """Mutable state shared by the deserialization helpers.

    Collects errors found while deserializing and tracks how many
    functions have been read (used to detect potential overload groups).
    """

    def __init__(self, options: Options) -> None:
        self.options = options
        # Errors found during deserialization; merged with parse errors later.
        self.errors: list[dict[str, Any]] = []
        # Count of function definitions read so far.
        self.num_funcs = 0

    def add_error(
        self,
        message: str,
        line: int,
        column: int,
        *,
        blocker: bool = False,
        code: str | None = None,
    ) -> None:
        """Report an error at a specific location.

        Args:
            message: Error message to display
            line: Line number where error occurred
            column: Column number where error occurred
            blocker: If True, this error blocks further analysis
            code: Error code for categorization
        """
        record = {
            "line": line,
            "column": column,
            "message": message,
            "blocker": blocker,
            "code": code,
        }
        self.errors.append(record)
|
|
|
|
|
|
def native_parse(
    filename: str, options: Options, skip_function_bodies: bool = False, imports_only: bool = False
) -> tuple[MypyFile, list[dict[str, Any]], TypeIgnores]:
    """Parse a Python file using the native Rust-based parser.

    Uses the ast_serialize Rust extension to parse Python code and deserialize
    the resulting AST directly into mypy's native AST representation.

    Args:
        filename: Path to the Python source file to parse
        options: Mypy options affecting parsing behavior (e.g., Python version)
        skip_function_bodies: If True, many function and method bodies are omitted from
            the AST, useful for parsing stubs or extracting signatures without full
            implementation details
        imports_only: If True create an empty MypyFile with actual serialized defs
            stored in binary_data.

    Returns:
        A tuple containing:
          - MypyFile: The parsed AST as a mypy AST node
          - list[dict[str, Any]]: List of parse errors and deserialization errors
          - TypeIgnores: List of (line_number, ignored_codes) tuples for type: ignore comments
    """
    # If the path is a directory, return empty AST (matching fastparse behavior)
    # This can happen for packages that only contain .pyc files without source
    if os.path.isdir(filename):
        node = MypyFile([], [])
        node.path = filename
        return node, [], []

    b, errors, ignores, import_bytes, is_partial_package, uses_template_strings = (
        parse_to_binary_ast(filename, options, skip_function_bodies)
    )
    data = ReadBuffer(b)
    # Number of top-level statements in the serialized AST.
    n = read_int(data)
    state = State(options)
    if imports_only:
        defs = []
    else:
        defs = read_statements(state, data, n)

    imports = deserialize_imports(import_bytes)

    node = MypyFile(defs, imports)
    node.path = filename
    node.is_partial_stub_package = is_partial_package
    if imports_only:
        # Keep the raw serialized data so the full AST can be deserialized
        # later without re-parsing the source file.
        node.raw_data = FileRawData(
            b, import_bytes, errors, dict(ignores), is_partial_package, uses_template_strings
        )
    node.uses_template_strings = uses_template_strings
    # Merge deserialization errors with parsing errors
    all_errors = errors + state.errors
    return node, all_errors, ignores
|
|
|
|
|
|
def expect_end_tag(data: ReadBuffer) -> None:
    """Consume the next tag and assert that it terminates the current node."""
    tag = read_tag(data)
    assert tag == END_TAG
|
|
|
|
|
|
def expect_tag(data: ReadBuffer, tag: Tag) -> None:
    """Consume the next tag and assert that it matches the expected one."""
    actual = read_tag(data)
    assert actual == tag, actual
|
|
|
|
|
|
def read_statements(state: State, data: ReadBuffer, n: int) -> list[Statement]:
    """Read n serialized statements, merging overload variants when needed."""
    funcs_before = state.num_funcs
    defs: list[Statement] = [read_statement(state, data) for _ in range(n)]
    if state.num_funcs - funcs_before > 1:
        # At least two functions were read, so we may need to merge overloads.
        defs = fix_function_overloads(state, defs)
    return defs
|
|
|
|
|
|
def parse_to_binary_ast(
    filename: str, options: Options, skip_function_bodies: bool = False
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool, bool]:
    """Run the Rust parser and unpack its result.

    Returns (ast bytes, errors, type-ignore data, serialized imports,
    is_partial_package flag, uses_template_strings flag).
    """
    result = ast_serialize.parse(
        filename,
        skip_function_bodies=skip_function_bodies,
        python_version=options.python_version,
        platform=options.platform,
        always_true=options.always_true,
        always_false=options.always_false,
    )
    ast_bytes, raw_errors, ignores, import_bytes, ast_data = result
    # The extension returns untyped objects; pin down the error list type.
    errors = cast("list[dict[str, Any]]", raw_errors)
    is_partial_package = ast_data["is_partial_package"]
    uses_template_strings = ast_data["uses_template_strings"]
    return ast_bytes, errors, ignores, import_bytes, is_partial_package, uses_template_strings
|
|
|
|
|
|
def read_statement(state: State, data: ReadBuffer) -> Statement:
    """Deserialize a single statement from the binary buffer.

    Dispatches on the leading tag byte. Each branch consumes that statement's
    serialized fields in the exact order the Rust extension wrote them,
    typically followed by a location record (read_loc) and an end tag.
    """
    tag = read_tag(data)
    stmt: Statement
    if tag == nodes.FUNC_DEF_STMT:
        return read_func_def(state, data)
    elif tag == nodes.DECORATOR:
        expect_tag(data, LIST_GEN)
        n_decorators = read_int_bare(data)
        decorators = [read_expression(state, data) for i in range(n_decorators)]
        line = read_int(data)
        column = read_int(data)
        # The decorated definition follows the decorator expressions.
        fdef = read_statement(state, data)
        assert isinstance(fdef, FuncDef)
        fdef.is_decorated = True
        var = Var(fdef.name)
        var.line = fdef.line
        var.is_ready = False
        stmt = Decorator(fdef, decorators, var)
        stmt.line = line
        stmt.column = column
        stmt.end_line = fdef.end_line
        stmt.end_column = fdef.end_column
        # TODO: Adjust funcdef location to start after decorator?
        expect_end_tag(data)
        return stmt
    elif tag == nodes.EXPR_STMT:
        es = ExpressionStmt(read_expression(state, data))
        set_line_column_range(es, es.expr)
        expect_end_tag(data)
        return es
    elif tag == nodes.ASSIGNMENT_STMT:
        lvalues = read_expression_list(state, data)
        rvalue = read_expression(state, data)
        has_type = read_bool(data)
        if has_type:
            type_annotation = read_type(state, data)
        else:
            type_annotation = None
        new_syntax = read_bool(data)
        a = AssignmentStmt(lvalues, rvalue, type=type_annotation, new_syntax=new_syntax)
        read_loc(data, a)
        # If rvalue is TempNode, copy location from AssignmentStmt
        if isinstance(rvalue, TempNode):
            set_line_column_range(rvalue, a)
        expect_end_tag(data)
        return a
    elif tag == nodes.OPERATOR_ASSIGNMENT_STMT:
        # Read operator string
        op = read_str(data)
        # Read lvalue (target)
        lvalue = read_expression(state, data)
        # Read rvalue (value)
        rvalue = read_expression(state, data)
        stmt = OperatorAssignmentStmt(op, lvalue, rvalue)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IF_STMT:
        # Read the main if condition and body
        expr = read_expression(state, data)
        body = read_block(state, data)

        # Read elif clauses
        num_elif = read_int(data)
        elif_exprs = []
        elif_bodies = []
        for i in range(num_elif):
            elif_exprs.append(read_expression(state, data))
            elif_bodies.append(read_block(state, data))

        has_else = read_bool(data)
        if has_else:
            else_body = read_block(state, data)
        else:
            else_body = None

        # Normalize elif into nested if/else statements
        # Build from the bottom up, starting with the final else body
        current_else = else_body

        # Process elif clauses in reverse order
        for i in range(len(elif_exprs) - 1, -1, -1):
            elif_stmt = IfStmt([elif_exprs[i]], [elif_bodies[i]], current_else)
            # Set location from the elif expression
            elif_stmt.line = elif_exprs[i].line
            elif_stmt.column = elif_exprs[i].column
            # Set end location based on what follows
            if current_else is not None:
                elif_stmt.end_line = current_else.end_line
                elif_stmt.end_column = current_else.end_column
            else:
                elif_stmt.end_line = elif_bodies[i].end_line
                elif_stmt.end_column = elif_bodies[i].end_column

            # Wrap in a Block to become the else clause for the outer if
            current_else = Block([elif_stmt])
            set_line_column_range(current_else, elif_stmt)

        if_stmt = IfStmt([expr], [body], current_else)
        read_loc(data, if_stmt)
        expect_end_tag(data)
        return if_stmt
    elif tag == nodes.RETURN_STMT:
        has_value = read_bool(data)
        if has_value:
            value = read_expression(state, data)
        else:
            value = None
        stmt = ReturnStmt(value)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.RAISE_STMT:
        has_exc = read_bool(data)
        if has_exc:
            exc = read_expression(state, data)
        else:
            exc = None
        has_from = read_bool(data)
        if has_from:
            # "raise ... from ..." cause expression.
            from_expr = read_expression(state, data)
        else:
            from_expr = None
        stmt = RaiseStmt(exc, from_expr)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.ASSERT_STMT:
        test = read_expression(state, data)
        has_msg = read_bool(data)
        if has_msg:
            msg = read_expression(state, data)
        else:
            msg = None
        stmt = AssertStmt(test, msg)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.WHILE_STMT:
        expr = read_expression(state, data)
        body = read_block(state, data)
        else_body = read_optional_block(state, data)
        stmt = WhileStmt(expr, body, else_body)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.FOR_STMT:
        index = read_expression(state, data)
        expr = read_expression(state, data)
        body = read_block(state, data)
        else_body = read_optional_block(state, data)
        is_async = read_bool(data)
        stmt = ForStmt(index, expr, body, else_body)
        stmt.is_async = is_async
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.WITH_STMT:
        # n context managers, each with an optional "as" target.
        n = read_int(data)
        expr_list = []
        target_list: list[Expression | None] = []
        for _ in range(n):
            context_expr = read_expression(state, data)
            expr_list.append(context_expr)
            has_target = read_bool(data)
            if has_target:
                target = read_expression(state, data)
                target_list.append(target)
            else:
                target_list.append(None)
        body = read_block(state, data)
        is_async = read_bool(data)
        stmt = WithStmt(expr_list, target_list, body)
        stmt.is_async = is_async
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.PASS_STMT:
        stmt = PassStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.BREAK_STMT:
        stmt = BreakStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.CONTINUE_STMT:
        stmt = ContinueStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT:
        # "import a.b as c, d": n (name, optional asname) pairs.
        n = read_int(data)
        ids = []
        for _ in range(n):
            name = read_str(data)
            has_asname = read_bool(data)
            if has_asname:
                asname = read_str(data)
            else:
                asname = None
            ids.append((name, asname))
        stmt = Import(ids)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT_FROM:
        # Number of leading dots in a relative import (0 for absolute).
        relative = read_int(data)
        module_id = read_str(data)  # Empty string for "from . import x"
        n = read_int(data)
        names = []
        for _ in range(n):
            name = read_str(data)
            has_asname = read_bool(data)
            if has_asname:
                asname = read_str(data)
            else:
                asname = None
            names.append((name, asname))

        stmt = ImportFrom(module_id, relative, names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT_ALL:
        module_id = read_str(data)  # Empty string for "from . import *"
        relative = read_int(data)

        stmt = ImportAll(module_id, relative)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.CLASS_DEF:
        return read_class_def(state, data)
    elif tag == nodes.TYPE_ALIAS_STMT:
        return read_type_alias_stmt(state, data)
    elif tag == nodes.TRY_STMT:
        return read_try_stmt(state, data)
    elif tag == nodes.DEL_STMT:
        expr = read_expression(state, data)
        stmt = DelStmt(expr)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.GLOBAL_DECL:
        n = read_int(data)
        decl_names = []
        for _ in range(n):
            decl_names.append(read_str(data))
        stmt = GlobalDecl(decl_names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.NONLOCAL_DECL:
        n = read_int(data)
        decl_names = []
        for _ in range(n):
            decl_names.append(read_str(data))
        stmt = NonlocalDecl(decl_names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.MATCH_STMT:
        subject = read_expression(state, data)
        n_cases = read_int(data)
        patterns = []
        guards: list[Expression | None] = []
        bodies = []
        # Each case: pattern, optional guard expression, body block.
        for _ in range(n_cases):
            pattern = read_pattern(state, data)
            patterns.append(pattern)
            has_guard = read_bool(data)
            if has_guard:
                guard = read_expression(state, data)
                guards.append(guard)
            else:
                guards.append(None)
            body = read_block(state, data)
            bodies.append(body)
        stmt = MatchStmt(subject, patterns, guards, bodies)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    else:
        assert False, tag
|
|
|
|
|
|
def read_parameters(state: State, data: ReadBuffer) -> tuple[list[Argument], bool]:
    """Read function/lambda parameters from the buffer.

    Each parameter is serialized as: name, argument kind, optional type
    annotation, optional default expression, positional-only flag, location.

    Returns:
        A tuple of (arguments list, has_annotations flag)
    """
    expect_tag(data, LIST_GEN)
    n_args = read_int_bare(data)
    arguments = []
    has_ann = False
    for _ in range(n_args):
        arg_name = read_str(data)
        arg_kind_int = read_int(data)
        arg_kind = ARG_KINDS[arg_kind_int]
        has_type = read_bool(data)
        if has_type:
            ann = read_type(state, data)
            has_ann = True
        else:
            ann = None
        has_default = read_bool(data)
        if has_default:
            default = read_expression(state, data)
        else:
            default = None
        pos_only = read_bool(data)

        # Apply implicit_optional if enabled and default is None
        if state.options.implicit_optional and ann is not None:
            optional = isinstance(default, NameExpr) and default.name == "None"
            if isinstance(ann, UnboundType):
                ann.optional = optional

        var = Var(arg_name, ann)
        # Annotation (if any) came from the source, not from inference.
        var.is_inferred = False
        var.is_argument = True
        arg = Argument(var, ann, default, arg_kind, pos_only)
        read_loc(data, arg)
        set_line_column_range(var, arg)
        arguments.append(arg)

    return arguments, has_ann
|
|
|
|
|
|
def read_type_params(state: State, data: ReadBuffer) -> list[TypeParam]:
    """Read type parameters (PEP 695 generics).

    Each parameter is serialized as: kind, name, optional upper bound,
    a list of value constraints (for constrained TypeVars), and an
    optional default type.
    """
    type_params: list[TypeParam] = []
    n = read_int_bare(data)
    for _ in range(n):
        kind = read_int(data)
        name = read_str(data)
        has_bound = read_bool(data)
        if has_bound:
            upper_bound = read_type(state, data)
        else:
            upper_bound = None

        # Value constraints (only non-empty for constrained TypeVars).
        expect_tag(data, LIST_GEN)
        n_values = read_int_bare(data)
        values = [read_type(state, data) for _ in range(n_values)]

        has_default = read_bool(data)
        if has_default:
            default = read_type(state, data)
        else:
            default = None

        type_params.append(TypeParam(name, kind, upper_bound, values, default))

    return type_params
|
|
|
|
|
|
def read_func_def(state: State, data: ReadBuffer) -> FuncDef:
    """Deserialize a function definition (sync or async).

    If any parameter or the return value is annotated, an unanalyzed
    CallableType is built, with Any filling the unannotated slots;
    otherwise the function type is left as None.
    """
    # Track function count so the caller can detect overload groups.
    state.num_funcs += 1

    name = read_str(data)
    arguments, has_ann = read_parameters(state, data)

    if special_function_elide_names(name):
        # Names of e.g. dunder method arguments are not usable as keywords.
        for arg in arguments:
            arg.pos_only = True

    body = read_block(state, data)
    is_async = read_bool(data)

    # Type parameters (PEP 695)
    has_type_params = read_bool(data)
    if has_type_params:
        type_params = read_type_params(state, data)
    else:
        type_params = None

    has_return_type = read_bool(data)
    if has_return_type:
        return_type = read_type(state, data)
        has_ann = True
    else:
        return_type = None

    if has_ann:
        # At least one annotation is present: build the callable type,
        # substituting Any for any unannotated slot.
        typ = CallableType(
            [
                arg.type_annotation if arg.type_annotation else AnyType(TypeOfAny.unannotated)
                for arg in arguments
            ],
            [arg.kind for arg in arguments],
            [None if arg.pos_only else arg.variable.name for arg in arguments],
            return_type if return_type else AnyType(TypeOfAny.unannotated),
            _dummy_fallback,
        )
    else:
        typ = None

    func_def = FuncDef(name, arguments, body, typ=typ, type_args=type_params)
    if is_async:
        func_def.is_coroutine = True
    read_loc(data, func_def)
    if typ:
        typ.line = func_def.line
        typ.column = func_def.column
        typ.definition = func_def
        # TODO: This seems wasteful, can we avoid it?
        func_def.unanalyzed_type = typ.copy_modified()
    expect_end_tag(data)
    return func_def
|
|
|
|
|
|
def read_class_def(state: State, data: ReadBuffer) -> ClassDef:
    """Deserialize a class definition.

    Serialized order: name, body, base class expressions, decorators,
    optional PEP 695 type parameters, keyword arguments (including
    metaclass), location.
    """
    name = read_str(data)
    body = read_block(state, data)
    base_type_exprs = read_expression_list(state, data)

    expect_tag(data, LIST_GEN)
    n_decorators = read_int_bare(data)
    decorators = [read_expression(state, data) for _ in range(n_decorators)]

    # Type parameters (PEP 695)
    has_type_params = read_bool(data)
    if has_type_params:
        type_params = read_type_params(state, data)
    else:
        type_params = None

    # Keywords (all keyword arguments including metaclass)
    expect_tag(data, DICT_STR_GEN)
    n_keywords = read_int_bare(data)
    keywords = []
    for _ in range(n_keywords):
        key = read_str(data)
        value = read_expression(state, data)
        keywords.append((key, value))

    # Extract metaclass from keywords if present
    metaclass = dict(keywords).get("metaclass") if keywords else None
    # Remove metaclass from keywords since it's passed as a separate field
    filtered_keywords = [(k, v) for k, v in keywords if k != "metaclass"] if keywords else None

    class_def = ClassDef(
        name,
        body,
        base_type_exprs=base_type_exprs if base_type_exprs else None,
        metaclass=metaclass,
        keywords=filtered_keywords,
        type_args=type_params,
    )
    class_def.decorators = decorators
    read_loc(data, class_def)
    expect_end_tag(data)
    return class_def
|
|
|
|
|
|
def read_type_alias_stmt(state: State, data: ReadBuffer) -> TypeAliasStmt:
    """Read a PEP 695 type alias statement ("type X[T] = ...").

    The alias value is wrapped in a synthetic LambdaExpr whose body is a
    Block containing a single ReturnStmt, which is the representation
    TypeAliasStmt expects (the alias value is evaluated lazily).
    """
    name = read_expression(state, data)
    assert isinstance(name, NameExpr), f"Expected NameExpr for type alias name, got {type(name)}"

    # Type parameters use the same wire format here as for functions and
    # classes (count, then kind/name/bound/values/default per parameter),
    # so reuse the shared reader instead of duplicating its logic inline.
    type_params = read_type_params(state, data)

    value_expr = read_expression(state, data)

    # Wrap the value expression in a LambdaExpr as expected by TypeAliasStmt
    # The LambdaExpr body is a Block with a single ReturnStmt
    return_stmt = ReturnStmt(value_expr)
    set_line_column_range(return_stmt, value_expr)

    block = Block([return_stmt])
    block.line = -1  # Synthetic block
    block.column = 0
    block.end_line = -1
    block.end_column = 0

    lambda_expr = LambdaExpr([], block)
    set_line_column_range(lambda_expr, value_expr)

    stmt = TypeAliasStmt(name, type_params, lambda_expr)
    read_loc(data, stmt)
    expect_end_tag(data)
    return stmt
|
|
|
|
|
|
def read_try_stmt(state: State, data: ReadBuffer) -> TryStmt:
    """Deserialize a try statement.

    Handler data is serialized in three parallel passes over the handlers:
    first all exception type expressions, then all "as" names, then all
    handler bodies; followed by optional else/finally blocks and the
    except* flag.
    """
    body = read_block(state, data)
    num_handlers = read_int(data)

    # Pass 1: per-handler exception type expressions (None for bare "except:").
    types_list: list[Expression | None] = []
    for _ in range(num_handlers):
        has_type = read_bool(data)
        if has_type:
            exc_type = read_expression(state, data)
            types_list.append(exc_type)
        else:
            types_list.append(None)

    # Pass 2: per-handler "except ... as name" variables.
    vars_list: list[NameExpr | None] = []
    for _ in range(num_handlers):
        has_name = read_bool(data)
        if has_name:
            var_name = read_str(data)
            var_expr = NameExpr(var_name)
            vars_list.append(var_expr)
        else:
            vars_list.append(None)

    # Pass 3: per-handler bodies.
    handlers = []
    for _ in range(num_handlers):
        handler_body = read_block(state, data)
        handlers.append(handler_body)

    has_else = read_bool(data)
    if has_else:
        else_body = read_block(state, data)
    else:
        else_body = None

    has_finally = read_bool(data)
    if has_finally:
        finally_body = read_block(state, data)
    else:
        finally_body = None

    # Read is_star flag (for except* in Python 3.11+)
    is_star = read_bool(data)

    stmt = TryStmt(body, vars_list, types_list, handlers, else_body, finally_body)
    stmt.is_star = is_star
    read_loc(data, stmt)
    expect_end_tag(data)
    return stmt
|
|
|
|
|
|
def read_type(state: State, data: ReadBuffer) -> Type:
    """Deserialize an unanalyzed type annotation from the buffer.

    Dispatches on the leading tag; each branch mirrors the serialization
    order used by the Rust extension and ends with a location record and
    an end tag. Types that need a class fallback get _dummy_fallback,
    which must be patched later.
    """
    tag = read_tag(data)
    if tag == types.UNBOUND_TYPE:
        name = read_str(data)
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        args = tuple(read_type(state, data) for i in range(n))
        empty_tuple_index = read_bool(data)
        # Read optional original_str_expr
        t = read_tag(data)
        if t == LITERAL_NONE:
            original_str_expr = None
        elif t == LITERAL_STR:
            original_str_expr = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for original_str_expr: {t}"
        # Read optional original_str_fallback
        t = read_tag(data)
        if t == LITERAL_NONE:
            original_str_fallback = None
        elif t == LITERAL_STR:
            original_str_fallback = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for original_str_fallback: {t}"
        unbound = UnboundType(
            name,
            args,
            empty_tuple_index=empty_tuple_index,
            original_str_expr=original_str_expr,
            original_str_fallback=original_str_fallback,
        )
        read_loc(data, unbound)
        expect_end_tag(data)
        return unbound
    elif tag == types.UNION_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        # Read uses_pep604_syntax flag
        uses_pep604_syntax = read_bool(data)
        # Read optional original_str_expr
        t = read_tag(data)
        if t == LITERAL_NONE:
            original_str_expr = None
        elif t == LITERAL_STR:
            original_str_expr = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for original_str_expr: {t}"
        # Read optional original_str_fallback
        t = read_tag(data)
        if t == LITERAL_NONE:
            original_str_fallback = None
        elif t == LITERAL_STR:
            original_str_fallback = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for original_str_fallback: {t}"
        union = UnionType(items, uses_pep604_syntax=uses_pep604_syntax)
        union.original_str_expr = original_str_expr
        union.original_str_fallback = original_str_fallback
        union.is_evaluated = read_bool(data)
        read_loc(data, union)
        expect_end_tag(data)
        return union
    elif tag == types.LIST_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        type_list = TypeList(items)
        read_loc(data, type_list)
        expect_end_tag(data)
        return type_list
    elif tag == types.TUPLE_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        implicit = read_bool(data)
        tuple_type = TupleType(items, _dummy_fallback, implicit=implicit)
        read_loc(data, tuple_type)
        expect_end_tag(data)
        return tuple_type
    elif tag == types.TYPED_DICT_TYPE:
        # Keys and value types are serialized as two parallel lists.
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        keys = [read_str_opt(data) for i in range(n)]
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        values = [read_type(state, data) for i in range(n)]
        td_items = {}
        extra_items_from = []
        for key, val in zip(keys, values):
            if key is None:
                # A None key marks an item merged in from another TypedDict.
                assert isinstance(val, ProperType)
                extra_items_from.append(val)
            else:
                td_items[key] = val
        typeddict_type = TypedDictType(td_items, set(), set(), _dummy_fallback)
        typeddict_type.extra_items_from = extra_items_from
        read_loc(data, typeddict_type)
        expect_end_tag(data)
        return typeddict_type
    elif tag == types.ELLIPSIS_TYPE:
        # EllipsisType has no attributes
        ellipsis_type = EllipsisType()
        read_loc(data, ellipsis_type)
        expect_end_tag(data)
        return ellipsis_type
    elif tag == types.RAW_EXPRESSION_TYPE:
        # A literal value appearing in a type position (e.g. Literal[2]).
        type_name = read_str(data)
        value: types.LiteralValue | str | None
        if type_name == "builtins.bool":
            value = read_bool(data)
        elif type_name == "builtins.int":
            value = read_int(data)
        elif type_name == "builtins.str":
            value = read_str(data)
        elif type_name == "builtins.bytes":
            # Bytes literals are serialized as escaped strings
            value = read_str(data)
        elif type_name == "typing.Any":
            # Invalid type - read None value
            tag = read_tag(data)
            assert tag == LITERAL_NONE, f"Expected LITERAL_NONE for invalid type, got {tag}"
            value = None
        else:
            assert False, f"Unsupported RawExpressionType: {type_name}"
        raw_type = RawExpressionType(value, type_name)
        read_loc(data, raw_type)
        expect_end_tag(data)
        return raw_type
    elif tag == types.UNPACK_TYPE:
        inner_type = read_type(state, data)
        from_star_syntax = read_bool(data)
        unpack = UnpackType(inner_type, from_star_syntax=from_star_syntax)
        read_loc(data, unpack)
        expect_end_tag(data)
        return unpack
    elif tag == types.CALL_TYPE:
        # Call expression in type context (Arg/DefaultArg/VarArg/KwArg).
        return read_call_type(state, data)
    else:
        assert False, tag
|
|
|
|
|
|
def stringify_type_name(typ: Type) -> str | None:
    """Extract qualified name from a type (for Arg constructor detection)."""
    return typ.name if isinstance(typ, UnboundType) else None
|
|
|
|
|
|
def extract_arg_name(typ: Type) -> str | None:
    """Extract argument name from a type (for Arg name parameter)."""
    if isinstance(typ, RawExpressionType) and typ.base_type_name == "builtins.str":
        return typ.literal_value  # type: ignore[return-value]
    if isinstance(typ, UnboundType):
        # String literals in type context are parsed as UnboundType (forward
        # references); for Arg names these are typically simple names without
        # dots, and a literal "None" means the argument is unnamed.
        return None if typ.name == "None" else typ.name
    return None  # Invalid, but let validation handle it
|
|
|
|
|
|
def read_call_type(state: State, data: ReadBuffer) -> Type:
    """Read Call in type context - check if it's an Arg/DefaultArg/VarArg/KwArg constructor.

    This performs validation and error reporting similar to mypy/fastparse.py.

    Wire format: callee type, LIST_GEN of positional argument types, LIST_GEN of
    (optional name, type) keyword pairs, then a LOCATION record and END tag.
    Returns a CallableArgument on success, or AnyType(from_error) after
    reporting a blocker error.
    """
    callee_type = read_type(state, data)

    # Read positional arguments
    expect_tag(data, LIST_GEN)
    n_args = read_int_bare(data)
    args = [read_type(state, data) for _ in range(n_args)]

    # Read keyword arguments: each is an optional (None or str) name plus a type.
    expect_tag(data, LIST_GEN)
    n_kwargs = read_int_bare(data)
    kwargs = []
    for _ in range(n_kwargs):
        tag_kw = read_tag(data)
        if tag_kw == LITERAL_NONE:
            kw_name = None
        elif tag_kw == LITERAL_STR:
            kw_name = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for keyword name: {tag_kw}"
        kw_value = read_type(state, data)
        kwargs.append((kw_name, kw_value))

    # Try to detect Arg/DefaultArg/VarArg/KwArg pattern (constructor is the
    # callee's name if it is a simple unbound type, else None).
    constructor = stringify_type_name(callee_type)

    # We'll read location before processing errors so we can report them correctly.
    # The location/end-tag must be consumed even on the error paths below so the
    # buffer stays in sync.
    invalid = AnyType(TypeOfAny.from_error)
    read_loc(data, invalid)
    expect_end_tag(data)

    if not constructor:
        # ARG_CONSTRUCTOR_NAME_EXPECTED
        state.add_error(
            message_registry.ARG_CONSTRUCTOR_NAME_EXPECTED.value,
            invalid.line,
            invalid.column,
            blocker=True,
            code="misc",
        )
        return invalid

    # Extract type and name from arguments
    name: str | None = None
    name_set_from_positional = False
    default_type = AnyType(TypeOfAny.special_form)
    typ: Type = default_type
    typ_set_from_positional = False

    # Process positional arguments: Arg(type, name) takes at most two.
    for i, arg in enumerate(args):
        if i == 0:
            typ = arg
            typ_set_from_positional = True
        elif i == 1:
            name = extract_arg_name(arg)
            name_set_from_positional = True
        else:
            # ARG_CONSTRUCTOR_TOO_MANY_ARGS
            state.add_error(
                message_registry.ARG_CONSTRUCTOR_TOO_MANY_ARGS.value,
                invalid.line,
                invalid.column,
                blocker=True,
                code="misc",
            )

    # Process keyword arguments; a kwarg duplicating a positional is an error.
    for kw_name, kw_value in kwargs:
        if kw_name == "name":
            # MULTIPLE_VALUES_FOR_NAME_KWARG
            if name is not None and name_set_from_positional:
                state.add_error(
                    message_registry.MULTIPLE_VALUES_FOR_NAME_KWARG.format(constructor).value,
                    invalid.line,
                    invalid.column,
                    blocker=True,
                    code="misc",
                )
            name = extract_arg_name(kw_value)
        elif kw_name == "type":
            # MULTIPLE_VALUES_FOR_TYPE_KWARG
            if typ is not default_type and typ_set_from_positional:
                state.add_error(
                    message_registry.MULTIPLE_VALUES_FOR_TYPE_KWARG.format(constructor).value,
                    invalid.line,
                    invalid.column,
                    blocker=True,
                    code="misc",
                )
            typ = kw_value
        else:
            # ARG_CONSTRUCTOR_UNEXPECTED_ARG
            state.add_error(
                message_registry.ARG_CONSTRUCTOR_UNEXPECTED_ARG.format(kw_name).value,
                invalid.line,
                invalid.column,
                blocker=True,
                code="misc",
            )

    # Create CallableArgument (errors above are blockers, so an inconsistent
    # node here won't be type checked further).
    call_arg = CallableArgument(typ, name, constructor)
    set_line_column_range(call_arg, invalid)
    return call_arg
|
|
|
|
|
|
def read_pattern(state: State, data: ReadBuffer) -> Pattern:
    """Read a pattern node from the buffer.

    Dispatches on the leading tag to one of the match-statement pattern kinds.
    Each branch consumes the pattern payload, a LOCATION record, and an END tag.
    """
    tag = read_tag(data)
    if tag == nodes.AS_PATTERN:
        # "<pattern> as <name>"; both parts are optional on the wire.
        has_pattern = read_bool(data)
        if has_pattern:
            pattern = read_pattern(state, data)
        else:
            pattern = None
        has_name = read_bool(data)
        if has_name:
            name_str = read_str(data)
            name = NameExpr(name_str)
            read_loc(data, name)
        else:
            name = None
        as_pattern = AsPattern(pattern, name)
        read_loc(data, as_pattern)
        expect_end_tag(data)
        return as_pattern
    elif tag == nodes.OR_PATTERN:
        # "p1 | p2 | ..."
        n = read_int(data)
        patterns = [read_pattern(state, data) for _ in range(n)]
        or_pattern = OrPattern(patterns)
        read_loc(data, or_pattern)
        expect_end_tag(data)
        return or_pattern
    elif tag == nodes.VALUE_PATTERN:
        # A dotted-name value to compare against.
        expr = read_expression(state, data)
        value_pattern = ValuePattern(expr)
        read_loc(data, value_pattern)
        expect_end_tag(data)
        return value_pattern
    elif tag == nodes.SINGLETON_PATTERN:
        # None / True / False
        singleton_tag = read_tag(data)
        if singleton_tag == LITERAL_NONE:
            value = None
        else:
            # It's a boolean
            value = singleton_tag == 1  # TAG_LITERAL_TRUE
        singleton_pattern = SingletonPattern(value)
        read_loc(data, singleton_pattern)
        expect_end_tag(data)
        return singleton_pattern
    elif tag == nodes.SEQUENCE_PATTERN:
        n = read_int(data)
        patterns = [read_pattern(state, data) for _ in range(n)]
        sequence_pattern = SequencePattern(patterns)
        read_loc(data, sequence_pattern)
        expect_end_tag(data)
        return sequence_pattern
    elif tag == nodes.STARRED_PATTERN:
        # Read optional capture name ("*name" vs "*_" serialized without a name)
        has_name = read_bool(data)
        if has_name:
            name_str = read_str(data)
            name = NameExpr(name_str)
            read_loc(data, name)
        else:
            name = None
        starred_pattern = StarredPattern(name)
        read_loc(data, starred_pattern)
        expect_end_tag(data)
        return starred_pattern
    elif tag == nodes.MAPPING_PATTERN:
        # n (key expression, value pattern) pairs, then an optional "**rest" name.
        n = read_int(data)
        keys = []
        values = []
        for _ in range(n):
            key = read_expression(state, data)
            value = read_pattern(state, data)
            keys.append(key)
            values.append(value)
        has_rest = read_bool(data)
        if has_rest:
            rest_str = read_str(data)
            rest = NameExpr(rest_str)
            read_loc(data, rest)
        else:
            rest = None
        mapping_pattern = MappingPattern(keys, values, rest)
        read_loc(data, mapping_pattern)
        expect_end_tag(data)
        return mapping_pattern
    elif tag == nodes.CLASS_PATTERN:
        # Class ref, positional sub-patterns, then (keyword name, pattern) pairs.
        class_ref = cast(RefExpr, read_expression(state, data))
        n_positional = read_int(data)
        positionals = [read_pattern(state, data) for _ in range(n_positional)]
        n_keywords = read_int(data)
        keyword_keys = []
        keyword_values = []
        for _ in range(n_keywords):
            key = read_str(data)
            value = read_pattern(state, data)
            keyword_keys.append(key)
            keyword_values.append(value)
        class_pattern = ClassPattern(class_ref, positionals, keyword_keys, keyword_values)
        read_loc(data, class_pattern)
        expect_end_tag(data)
        return class_pattern
    else:
        assert False, f"Unknown pattern tag: {tag}"
|
|
|
|
|
|
def read_block(state: State, data: ReadBuffer) -> Block:
    """Read a (possibly empty) statement block.

    An empty block carries an explicit LOCATION record; a non-empty block's
    span is derived from its first and last statements instead.
    """
    expect_tag(data, nodes.BLOCK)
    expect_tag(data, LIST_GEN)
    n = read_int_bare(data)
    is_unreachable = read_bool(data)
    if n == 0:
        # Empty block - read explicit location
        b = Block([], is_unreachable=is_unreachable)
        read_loc(data, b)
        expect_end_tag(data)
        return b
    else:
        # Non-empty block - read statements and set location from them
        a = read_statements(state, data, n)
        expect_end_tag(data)
        b = Block(a, is_unreachable=is_unreachable)
        b.line = a[0].line
        b.column = a[0].column
        b.end_line = a[-1].end_line
        b.end_column = a[-1].end_column
        return b
|
|
|
|
|
|
def read_optional_block(state: State, data: ReadBuffer) -> Block | None:
    """Read a block that may be absent (e.g. an else clause); None when empty.

    NOTE(review): unlike read_block, the empty case consumes no LOCATION
    record before the END tag — verify the serializer omits the location for
    optional blocks in this context.
    """
    expect_tag(data, nodes.BLOCK)
    expect_tag(data, LIST_GEN)
    n = read_int_bare(data)
    is_unreachable = read_bool(data)
    if n == 0:
        b = None
    else:
        a = [read_statement(state, data) for i in range(n)]
        b = Block(a, is_unreachable=is_unreachable)
        # Derive the block's span from its first and last statements.
        b.line = a[0].line
        b.column = a[0].column
        b.end_line = a[-1].end_line
        b.end_column = a[-1].end_column
    expect_end_tag(data)
    return b
|
|
|
|
|
|
# Operator lookup tables. Operators are serialized as integer indices into
# these lists, so the order here must match the serializer exactly.
bin_ops: Final = ["+", "-", "*", "@", "/", "%", "**", "<<", ">>", "|", "^", "&", "//"]
bool_ops: Final = ["and", "or"]
cmp_ops: Final = ["==", "!=", "<", "<=", ">", ">=", "is", "is not", "in", "not in"]
unary_ops: Final = ["~", "not", "+", "-"]
|
|
|
|
|
|
def read_expression(state: State, data: ReadBuffer) -> Expression:
    """Read a single expression node from the buffer.

    Dispatches on the leading tag. Each branch consumes its payload in the
    exact order the serializer wrote it, followed (in most cases) by a
    LOCATION record and an END tag. Some branches derive locations from
    child nodes instead of reading an explicit record.
    """
    tag = read_tag(data)
    expr: Expression
    if tag == nodes.CALL_EXPR:
        callee = read_expression(state, data)
        args = read_expression_list(state, data)
        # Read argument kinds (serialized as indices into ARG_KINDS)
        expect_tag(data, LIST_INT)
        n_kinds = read_int_bare(data)
        arg_kinds = [ARG_KINDS[read_int_bare(data)] for _ in range(n_kinds)]
        # Read argument names (None for positional arguments)
        expect_tag(data, LIST_GEN)
        n_names = read_int_bare(data)
        arg_names: list[str | None] = []
        for _ in range(n_names):
            tag = read_tag(data)
            if tag == LITERAL_NONE:
                arg_names.append(None)
            elif tag == LITERAL_STR:
                arg_names.append(read_str_bare(data))
            else:
                assert False, f"Unexpected tag for arg_name: {tag}"
        ce = CallExpr(callee, args, arg_kinds, arg_names)
        read_loc(data, ce)
        expect_end_tag(data)
        return ce
    elif tag == nodes.NAME_EXPR:
        s = read_str(data)
        ne = NameExpr(s)
        read_loc(data, ne)
        expect_end_tag(data)
        return ne
    elif tag == nodes.MEMBER_EXPR:
        e = read_expression(state, data)
        attr = read_str(data)
        m = MemberExpr(e, attr)
        # Check if this is a super() call - if so, convert to SuperExpr
        if isinstance(e, CallExpr) and isinstance(e.callee, NameExpr) and e.callee.name == "super":
            result: Expression = SuperExpr(attr, e)
        else:
            result = m
        read_loc(data, result)
        expect_end_tag(data)
        return result
    elif tag == nodes.STR_EXPR:
        se = StrExpr(read_str(data))
        read_loc(data, se)
        expect_end_tag(data)
        return se
    elif tag == nodes.INT_EXPR:
        ie = IntExpr(read_int(data))
        read_loc(data, ie)
        expect_end_tag(data)
        return ie
    elif tag == nodes.FLOAT_EXPR:
        expect_tag(data, LITERAL_FLOAT)
        value = read_float_bare(data)
        fe = FloatExpr(value)
        read_loc(data, fe)
        expect_end_tag(data)
        return fe
    elif tag == nodes.LIST_EXPR:
        items = read_expression_list(state, data)
        expr = ListExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TUPLE_EXPR:
        items = read_expression_list(state, data)
        t = TupleExpr(items)
        read_loc(data, t)
        expect_end_tag(data)
        return t
    elif tag == nodes.SET_EXPR:
        items = read_expression_list(state, data)
        expr = SetExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.GENERATOR_EXPR:
        expr = read_generator_expr(state, data)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.LIST_COMPREHENSION:
        generator = read_generator_expr(state, data)
        expr = ListComprehension(generator)
        read_loc(data, expr)
        # Also copy location to the inner generator
        set_line_column_range(generator, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.SET_COMPREHENSION:
        generator = read_generator_expr(state, data)
        expr = SetComprehension(generator)
        read_loc(data, expr)
        # Also copy location to the inner generator
        set_line_column_range(generator, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.DICT_COMPREHENSION:
        # Like a generator, but with separate key and value expressions.
        key = read_expression(state, data)
        value = read_expression(state, data)
        n_generators = read_int(data)
        indices = [read_expression(state, data) for _ in range(n_generators)]
        sequences = [read_expression(state, data) for _ in range(n_generators)]
        condlists = [read_expression_list(state, data) for _ in range(n_generators)]
        is_async = [read_bool(data) for _ in range(n_generators)]
        expr = DictionaryComprehension(key, value, indices, sequences, condlists, is_async)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.YIELD_EXPR:
        has_value = read_bool(data)
        if has_value:
            value = read_expression(state, data)
        else:
            value = None
        expr = YieldExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.YIELD_FROM_EXPR:
        value = read_expression(state, data)
        expr = YieldFromExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.OP_EXPR:
        op = bin_ops[read_int(data)]
        left = read_expression(state, data)
        right = read_expression(state, data)
        o = OpExpr(op, left, right)
        # TODO: Store these explicitly?
        # Span is derived from the operands (no LOCATION record on the wire).
        o.line = left.line
        o.column = left.column
        o.end_line = right.end_line
        o.end_column = right.end_column
        expect_end_tag(data)
        return o
    elif tag == nodes.INDEX_EXPR:
        base = read_expression(state, data)
        index = read_expression(state, data)
        expr = IndexExpr(base, index)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BOOL_OP_EXPR:
        op = bool_ops[read_int(data)]
        values = read_expression_list(state, data)
        # Convert list of values to nested OpExpr nodes
        # E.g., [a, b, c] with "and" becomes OpExpr("and", OpExpr("and", a, b), c)
        assert len(values) >= 2
        result = values[0]
        for val in values[1:]:
            result = OpExpr(op, result, val)
            result.line = values[0].line
            result.column = values[0].column
            result.end_line = val.end_line
            result.end_column = val.end_column
        read_loc(data, result)
        expect_end_tag(data)
        return result
    elif tag == nodes.COMPARISON_EXPR:
        left = read_expression(state, data)
        expect_tag(data, LIST_INT)
        n_ops = read_int_bare(data)
        ops = [cmp_ops[read_int_bare(data)] for _ in range(n_ops)]
        comparators = read_expression_list(state, data)
        assert len(ops) == len(comparators)
        expr = ComparisonExpr(ops, [left] + comparators)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.UNARY_EXPR:
        op = unary_ops[read_int(data)]
        operand = read_expression(state, data)
        expr = UnaryExpr(op, operand)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.DICT_EXPR:
        # Keys first (None for "**" spreads), then the values, zipped into pairs.
        expect_tag(data, LIST_GEN)
        n_keys = read_int_bare(data)
        keys: list[Expression | None] = []
        for _ in range(n_keys):
            has_key = read_bool(data)
            if has_key:
                keys.append(read_expression(state, data))
            else:
                keys.append(None)
        values = read_expression_list(state, data)
        # Zip keys and values into items
        items = list(zip(keys, values))
        expr = DictExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.COMPLEX_EXPR:
        expect_tag(data, LITERAL_FLOAT)
        real = read_float_bare(data)
        expect_tag(data, LITERAL_FLOAT)
        imag = read_float_bare(data)
        value = complex(real, imag)
        expr = ComplexExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.SLICE_EXPR:
        has_begin = read_bool(data)
        begin_index = read_expression(state, data) if has_begin else None
        has_end = read_bool(data)
        end_index = read_expression(state, data) if has_end else None
        has_stride = read_bool(data)
        stride = read_expression(state, data) if has_stride else None
        expr = SliceExpr(begin_index, end_index, stride)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TEMP_NODE:
        # TempNode with no attributes (no payload or location on the wire)
        temp = TempNode(AnyType(TypeOfAny.special_form), no_rhs=True)
        expect_end_tag(data)
        return temp
    elif tag == nodes.ELLIPSIS_EXPR:
        expr = EllipsisExpr()
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.CONDITIONAL_EXPR:
        # Serialized in source order: "<if_expr> if <cond> else <else_expr>".
        if_expr = read_expression(state, data)
        cond = read_expression(state, data)
        else_expr = read_expression(state, data)
        expr = ConditionalExpr(cond, if_expr, else_expr)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.FSTRING_EXPR:
        # F-strings are converted into nodes representing "".join([...]), to match
        # pre-existing behavior.
        nparts = read_int(data)
        fitems = []
        for _ in range(nparts):
            b = read_bool(data)
            if b:
                # Interpolated part: a nested sequence of items
                n = read_int(data)
                for i in range(n):
                    fitems.append(read_fstring_item(state, data))
            else:
                # Literal string part
                s = StrExpr(read_str(data))
                read_loc(data, s)
                fitems.append(s)
        expr = build_fstring_join(state, data, fitems)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TSTRING_EXPR:
        # Template strings keep their parts structured, unlike f-strings.
        nparts = read_int(data)
        titems: list[Expression | tuple[Expression, str, str | None, Expression | None]] = []
        for _ in range(nparts):
            if read_bool(data):
                # Interpolation: (expression, source text, conversion, format spec)
                e = read_expression(state, data)
                s = read_str(data)
                if read_bool(data):
                    conv = read_str(data)
                else:
                    conv = None
                if read_bool(data):
                    # Parse format spec as a JoinedStr, this matches the old parser behavior.
                    format_spec = read_fstring_items(state, data)
                else:
                    format_spec = None
                titems.append((e, s, conv, format_spec))
            else:
                s = StrExpr(read_str(data))
                read_loc(data, s)
                titems.append(s)
        expr = TemplateStrExpr(titems)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.LAMBDA_EXPR:
        arguments, has_ann = read_parameters(state, data)
        body = read_block(state, data)

        # Only build an explicit callable type if any parameter was annotated.
        if has_ann:
            typ = CallableType(
                [
                    arg.type_annotation if arg.type_annotation else AnyType(TypeOfAny.unannotated)
                    for arg in arguments
                ],
                [arg.kind for arg in arguments],
                [None if arg.pos_only else arg.variable.name for arg in arguments],
                AnyType(TypeOfAny.unannotated),
                _dummy_fallback,
            )
        else:
            typ = None

        expr = LambdaExpr(arguments, body)
        expr.type = typ
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.NAMED_EXPR:
        target = read_expression(state, data)
        value = read_expression(state, data)
        # AssignmentExpr expects target to be a NameExpr
        if not isinstance(target, NameExpr):
            # In case target is not a NameExpr, we need to handle this
            # For now, we'll assert since the grammar should ensure it's a NameExpr
            assert isinstance(
                target, NameExpr
            ), f"Expected NameExpr for target, got {type(target)}"
        expr = AssignmentExpr(target, value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.STAR_EXPR:
        wrapped_expr = read_expression(state, data)
        expr = StarExpr(wrapped_expr)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BYTES_EXPR:
        # Read bytes literal as string (serialized in escaped form)
        value = read_str(data)
        expr = BytesExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.AWAIT_EXPR:
        value = read_expression(state, data)
        expr = AwaitExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BIG_INT_EXPR:
        # Integers too large for the compact encoding travel as strings;
        # base=0 honors 0x/0o/0b prefixes.
        strval = read_str(data)
        ie = IntExpr(int(strval, base=0))
        read_loc(data, ie)
        expect_end_tag(data)
        return ie
    else:
        assert False, tag
|
|
|
|
|
|
def read_fstring_items(state: State, data: ReadBuffer) -> Expression:
    """Read a counted sequence of f-string items and join them into one expression.

    Reads an item count followed by that many items, then delegates to
    build_fstring_join (which also consumes the trailing LOCATION record).
    """
    # Note: the original initialized `items = []` and then immediately rebound
    # it; the dead assignment has been removed.
    n = read_int(data)
    items = [read_fstring_item(state, data) for _ in range(n)]
    return build_fstring_join(state, data, items)
|
|
|
|
|
|
def build_fstring_join(state: State, data: ReadBuffer, items: list[Expression]) -> Expression:
    """Combine f-string pieces into a single expression.

    A lone piece is returned as-is; an all-literal sequence is folded into one
    StrExpr; otherwise a node equivalent to '"".join([...])' is built. In every
    case exactly one LOCATION record is consumed from the buffer.
    """
    if len(items) == 1:
        only = items[0]
        read_loc(data, only)
        return only
    if all(isinstance(piece, StrExpr) for piece in items):
        merged = StrExpr("".join(cast(StrExpr, piece).value for piece in items))
        read_loc(data, merged)
        return merged
    arg_list = ListExpr(items)
    separator = StrExpr("")
    join_member = MemberExpr(separator, "join")
    join_call = CallExpr(join_member, [arg_list], [ARG_POS], [None])
    read_loc(data, join_call)
    # Propagate the call's position onto the synthesized helper nodes.
    for helper in (arg_list, separator, join_member):
        set_line_column(helper, join_call)
    return join_call
|
|
|
|
|
|
def read_fstring_item(state: State, data: ReadBuffer) -> Expression:
    """Read one f-string part: a literal string or an interpolation.

    Interpolations are lowered into '"{...}".format(expr, spec)' calls,
    matching the old parser's representation of f-strings.
    """
    t = read_tag(data)
    if t == LITERAL_STR:
        str_expr = StrExpr(read_str_bare(data))
        read_loc(data, str_expr)
        return str_expr
    elif t == nodes.FSTRING_INTERPOLATION:
        expr = read_expression(state, data)

        # Read conversion flag such as !r
        has_conv = read_bool(data)
        if has_conv:
            c = read_str(data)
            fmt = "{" + c + ":{}}"
        else:
            fmt = "{:{}}"

        # Read format spec such as <30 (which may have nested {...})
        has_spec = read_bool(data)
        if has_spec:
            spec = read_fstring_items(state, data)
        else:
            spec = StrExpr("")

        # Lower to "<fmt>".format(expr, spec); the spec is spliced via the
        # nested "{}" in fmt.
        member = MemberExpr(StrExpr(fmt), "format")
        set_line_column(member, expr)
        call = CallExpr(member, [expr, spec], [ARG_POS, ARG_POS], [None, None])
        set_line_column(call, expr)
        expect_end_tag(data)
        return call
    else:
        raise ValueError(f"Unexpected tag {t}")
|
|
|
|
|
|
def set_line_column(target: Context, src: Context) -> None:
    """Copy the start position (line and column) from src onto target."""
    target.line, target.column = src.line, src.column
|
|
|
|
|
|
def set_line_column_range(target: Context, src: Context) -> None:
    """Copy the full source span (start and end positions) from src onto target."""
    for attr in ("line", "column", "end_line", "end_column"):
        setattr(target, attr, getattr(src, attr))
|
|
|
|
|
|
def read_expression_list(state: State, data: ReadBuffer) -> list[Expression]:
    """Read a LIST_GEN header followed by that many expression nodes."""
    expect_tag(data, LIST_GEN)
    item_count = read_int_bare(data)
    return [read_expression(state, data) for _ in range(item_count)]
|
|
|
|
|
|
def read_generator_expr(state: State, data: ReadBuffer) -> GeneratorExpr:
    """Helper function to read comprehension data (shared by Generator, ListComp, SetComp)

    Reads the element expression, then for each of the n generators the index
    targets, iterated sequences, condition lists, and async flags — each group
    serialized as a separate run, in this exact order.
    """
    left_expr = read_expression(state, data)
    n_generators = read_int(data)
    indices = [read_expression(state, data) for _ in range(n_generators)]
    sequences = [read_expression(state, data) for _ in range(n_generators)]
    condlists = [read_expression_list(state, data) for _ in range(n_generators)]
    is_async = [read_bool(data) for _ in range(n_generators)]
    return GeneratorExpr(left_expr, indices, sequences, condlists, is_async)
|
|
|
|
|
|
def read_loc(data: ReadBuffer, node: Context) -> None:
    """Read a LOCATION record and apply it to node.

    The record holds the start line and column followed by the end position
    encoded as deltas from the start.
    """
    expect_tag(data, LOCATION)
    start_line = read_int_bare(data)
    node.line = start_line
    start_column = read_int_bare(data)
    node.column = start_column
    node.end_line = start_line + read_int_bare(data)
    node.end_column = start_column + read_int_bare(data)
|
|
|
|
|
|
def strip_contents_from_if_stmt(stmt: IfStmt) -> None:
    """Remove contents from IfStmt.

    Needed to still be able to check the conditions after the contents
    have been merged with the surrounding function overloads.

    Recurses into else bodies so that elif chains (stored as nested IfStmt
    in else_body) are stripped all the way down.
    """
    if len(stmt.body) == 1:
        stmt.body[0].body = []
    if stmt.else_body and len(stmt.else_body.body) == 1:
        if isinstance(stmt.else_body.body[0], IfStmt):
            # elif: recurse into the nested IfStmt
            strip_contents_from_if_stmt(stmt.else_body.body[0])
        else:
            stmt.else_body.body = []
|
|
|
|
|
|
def is_stripped_if_stmt(stmt: Statement) -> bool:
    """Check stmt to make sure it is a stripped IfStmt.

    See also: strip_contents_from_if_stmt
    """
    if not isinstance(stmt, IfStmt):
        return False

    if not (len(stmt.body) == 1 and len(stmt.body[0].body) == 0):
        # Body not empty
        return False

    if not stmt.else_body or len(stmt.else_body.body) == 0:
        # No or empty else_body
        return True

    # For elif, IfStmt are stored recursively in else_body
    return is_stripped_if_stmt(stmt.else_body.body[0])
|
|
|
|
|
|
def fail_merge_overload(state: State, node: IfStmt) -> None:
    """Report an error when overloads cannot be merged due to unknown condition."""
    line, column = node.line, node.column
    state.add_error(
        message_registry.FAILED_TO_MERGE_OVERLOADS.value,
        line,
        column,
        blocker=False,
        code="misc",
    )
|
|
|
|
|
|
def check_ifstmt_for_overloads(
    stmt: IfStmt, current_overload_name: str | None = None
) -> str | None:
    """Check if IfStmt contains only overloads with the same name.

    Return overload_name if found, None otherwise.

    current_overload_name is the name of the overload group currently being
    accumulated by the caller, if any; a bare FuncDef in the if-body only
    counts as part of a group when such a group is already open.
    """
    # Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
    # Multiple overloads have already been merged as OverloadedFuncDef.
    if not (
        len(stmt.body[0].body) == 1
        and (
            isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
            or current_overload_name is not None
            and isinstance(stmt.body[0].body[0], FuncDef)
        )
        or len(stmt.body[0].body) > 1
        and isinstance(stmt.body[0].body[-1], OverloadedFuncDef)
        and all(is_stripped_if_stmt(if_stmt) for if_stmt in stmt.body[0].body[:-1])
    ):
        return None

    overload_name = cast(Decorator | FuncDef | OverloadedFuncDef, stmt.body[0].body[-1]).name
    if stmt.else_body is None or stmt.else_body.is_unreachable:
        # No else branch to contradict the if-branch's overload name.
        return overload_name

    if len(stmt.else_body.body) == 1:
        # For elif: else_body contains an IfStmt itself -> do a recursive check.
        if (
            isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef))
            and stmt.else_body.body[0].name == overload_name
        ):
            return overload_name
        if (
            isinstance(stmt.else_body.body[0], IfStmt)
            and check_ifstmt_for_overloads(stmt.else_body.body[0], current_overload_name)
            == overload_name
        ):
            return overload_name

    return None
|
|
|
|
|
|
def get_executable_if_block_with_overloads(
    stmt: IfStmt, options: Options
) -> tuple[Block | None, IfStmt | None]:
    """Return block from IfStmt that will get executed.

    Return
        0 -> A block if sure that alternative blocks are unreachable.
        1 -> An IfStmt if the reachability of it can't be inferred,
             i.e. the truth value is unknown.

    At most one element of the returned tuple is non-None.
    """
    # This marks body/else_body as unreachable where the condition's truth
    # value can be determined statically (e.g. sys.version_info checks).
    infer_reachability_of_if_statement(stmt, options)
    if stmt.else_body is None and stmt.body[0].is_unreachable is True:
        # always False condition with no else
        return None, None
    if (
        stmt.else_body is None
        or stmt.body[0].is_unreachable is False
        and stmt.else_body.is_unreachable is False
    ):
        # The truth value is unknown, thus not conclusive
        return None, stmt
    if stmt.else_body.is_unreachable:
        # else_body will be set unreachable if condition is always True
        return stmt.body[0], None
    if stmt.body[0].is_unreachable is True:
        # body will be set unreachable if condition is always False
        # else_body can contain an IfStmt itself (for elif) -> do a recursive check
        if isinstance(stmt.else_body.body[0], IfStmt):
            return get_executable_if_block_with_overloads(stmt.else_body.body[0], options)
        return stmt.else_body, None
    return None, stmt
|
|
|
|
|
|
def fix_function_overloads(state: State, stmts: list[Statement]) -> list[Statement]:
    """Merge consecutive function overloads into OverloadedFuncDef nodes.

    This function processes a list of statements and combines function overloads
    (marked with @overload decorator) that have the same name into a single
    OverloadedFuncDef node. It also handles conditional overloads (overloads
    inside if statements) when the condition can be evaluated.
    """
    ret: list[Statement] = []
    # Overload parts accumulated so far and the name they share.
    current_overload: list[OverloadPart] = []
    current_overload_name: str | None = None
    # Name of the most recent plain (undecorated, unconditional) FuncDef.
    last_unconditional_func_def: str | None = None
    # State describing the most recently seen conditional-overload IfStmt.
    last_if_stmt: IfStmt | None = None
    last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
    last_if_stmt_overload_name: str | None = None
    last_if_unknown_truth_value: IfStmt | None = None
    # IfStmts consumed while merging; re-emitted stripped so conditions are
    # still checked.
    skipped_if_stmts: list[IfStmt] = []
    for stmt in stmts:
        if_overload_name: str | None = None
        if_block_with_overload: Block | None = None
        if_unknown_truth_value: IfStmt | None = None
        if isinstance(stmt, IfStmt):
            # Check IfStmt block to determine if function overloads can be merged
            if_overload_name = check_ifstmt_for_overloads(stmt, current_overload_name)
            if if_overload_name is not None:
                if_block_with_overload, if_unknown_truth_value = (
                    get_executable_if_block_with_overloads(stmt, state.options)
                )

        if (
            current_overload_name is not None
            and isinstance(stmt, (Decorator, FuncDef))
            and stmt.name == current_overload_name
        ):
            # Direct continuation of the current overload group.
            if last_if_stmt is not None:
                skipped_if_stmts.append(last_if_stmt)
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if last_if_unknown_truth_value:
                fail_merge_overload(state, last_if_unknown_truth_value)
                last_if_unknown_truth_value = None
            current_overload.append(stmt)
            if isinstance(stmt, FuncDef):
                # This is, strictly speaking, wrong: there might be a decorated
                # implementation. However, it only affects the error message we show:
                # ideally it's "already defined", but "implementation must come last"
                # is also reasonable.
                # TODO: can we get rid of this completely and just always emit
                # "implementation must come last" instead?
                last_unconditional_func_def = stmt.name
        elif (
            current_overload_name is not None
            and isinstance(stmt, IfStmt)
            and if_overload_name == current_overload_name
            and last_unconditional_func_def != current_overload_name
        ):
            # IfStmt only contains stmts relevant to current_overload.
            # Check if stmts are reachable and add them to current_overload,
            # otherwise skip IfStmt to allow subsequent overload
            # or function definitions.
            skipped_if_stmts.append(stmt)
            if if_block_with_overload is None:
                if if_unknown_truth_value is not None:
                    fail_merge_overload(state, if_unknown_truth_value)
                continue
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
                skipped_if_stmts.extend(cast(list[IfStmt], if_block_with_overload.body[:-1]))
                current_overload.extend(if_block_with_overload.body[-1].items)
            else:
                current_overload.append(cast(Decorator | FuncDef, if_block_with_overload.body[0]))
        else:
            # The current statement does not extend the group: flush what we
            # have accumulated, then start fresh (possibly a new group).
            if last_if_stmt is not None:
                ret.append(last_if_stmt)
                last_if_stmt_overload_name = current_overload_name
                last_if_stmt, last_if_overload = None, None
                last_if_unknown_truth_value = None

            if current_overload and current_overload_name == last_if_stmt_overload_name:
                # Remove last stmt (IfStmt) from ret if the overload names matched
                # Only happens if no executable block had been found in IfStmt
                popped = ret.pop()
                assert isinstance(popped, IfStmt)
                skipped_if_stmts.append(popped)
            if current_overload and skipped_if_stmts:
                # Add bare IfStmt (without overloads) to ret
                # Required for mypy to be able to still check conditions
                for if_stmt in skipped_if_stmts:
                    strip_contents_from_if_stmt(if_stmt)
                    ret.append(if_stmt)
                skipped_if_stmts = []
            if len(current_overload) == 1:
                ret.append(current_overload[0])
            elif len(current_overload) > 1:
                ret.append(OverloadedFuncDef(current_overload))

            # If we have multiple decorated functions named "_" next to each, we want to treat
            # them as a series of regular FuncDefs instead of one OverloadedFuncDef because
            # most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
            # related, but multiple underscore functions next to each other aren't necessarily
            # related
            last_unconditional_func_def = None
            if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
                current_overload = [stmt]
                current_overload_name = stmt.name
            elif isinstance(stmt, IfStmt) and if_overload_name is not None:
                current_overload = []
                current_overload_name = if_overload_name
                last_if_stmt = stmt
                last_if_stmt_overload_name = None
                if if_block_with_overload is not None:
                    skipped_if_stmts.extend(cast(list[IfStmt], if_block_with_overload.body[:-1]))
                    last_if_overload = cast(
                        Decorator | FuncDef | OverloadedFuncDef, if_block_with_overload.body[-1]
                    )
                last_if_unknown_truth_value = if_unknown_truth_value
            else:
                current_overload = []
                current_overload_name = None
                ret.append(stmt)

    # Final flush for a group still open at end of the statement list.
    if current_overload and skipped_if_stmts:
        # Add bare IfStmt (without overloads) to ret
        # Required for mypy to be able to still check conditions
        for if_stmt in skipped_if_stmts:
            strip_contents_from_if_stmt(if_stmt)
            ret.append(if_stmt)
    if len(current_overload) == 1:
        ret.append(current_overload[0])
    elif len(current_overload) > 1:
        ret.append(OverloadedFuncDef(current_overload))
    elif last_if_overload is not None:
        ret.append(last_if_overload)
    elif last_if_stmt is not None:
        ret.append(last_if_stmt)
    return ret
|
|
|
|
|
|
def deserialize_imports(import_bytes: bytes) -> list[ImportBase]:
    """Turn serialized import metadata into mypy import AST nodes.

    Args:
        import_bytes: Serialized import metadata produced by the Rust parser

    Returns:
        List of Import/ImportFrom/ImportAll nodes, each populated with
        location information and metadata flags
    """
    if not import_bytes:
        return []

    buf = ReadBuffer(import_bytes)

    # The payload is a generic list of tagged import records.
    expect_tag(buf, LIST_GEN)
    count = read_int_bare(buf)

    result: list[ImportBase] = []
    for _ in range(count):
        tag = read_tag(buf)
        node: Import | ImportFrom | ImportAll

        if tag == IMPORT_METADATA:
            target = read_str(buf)
            # Relative level is serialized but always 0 for plain imports;
            # relative imports are represented as ImportFrom. Read it to
            # keep the buffer position in sync.
            read_int(buf)
            alias = read_str(buf) if read_bool(buf) else None
            node = Import([(target, alias)])
        elif tag == IMPORTFROM_METADATA:
            module = read_str(buf)
            relative = read_int(buf)
            # Imported names are a generic list of (name, optional alias).
            expect_tag(buf, LIST_GEN)
            pairs: list[tuple[str, str | None]] = []
            for _ in range(read_int_bare(buf)):
                imported = read_str(buf)
                pairs.append((imported, read_str(buf) if read_bool(buf) else None))
            node = ImportFrom(module, relative, pairs)
        elif tag == IMPORTALL_METADATA:
            # Arguments evaluate left-to-right, matching the wire order:
            # module name first, then relative level.
            node = ImportAll(read_str(buf), read_int(buf))
        else:
            raise ValueError(f"Unexpected tag in import metadata: {tag}")

        # Every record kind is followed by the same location + flags trailer.
        _read_and_set_import_metadata(buf, node)
        result.append(node)

    return result
|
|
|
|
|
|
def _read_and_set_import_metadata(data: ReadBuffer, stmt: Import | ImportFrom | ImportAll) -> None:
    """Consume the location + flags trailer and apply it to an import node.

    Args:
        data: Buffer positioned at the trailer of a serialized import record
        stmt: Import node to receive the location info and metadata flags
    """
    read_loc(data, stmt)

    # The metadata flags are packed into one integer bitfield:
    #   bit 0 -> is_top_level
    #   bit 1 -> is_unreachable
    #   bit 2 -> is_mypy_only
    bits = read_int(data)
    stmt.is_top_level = bool(bits & 1)
    stmt.is_unreachable = bool(bits & (1 << 1))
    stmt.is_mypy_only = bool(bits & (1 << 2))