indie-status-page/venv/lib/python3.11/site-packages/mypy/nativeparse.py
IndieStatusBot 902133edd3 feat: indie status page MVP -- FastAPI + SQLite
- 8 DB models (services, incidents, monitors, subscribers, etc.)
- Full CRUD API for services, incidents, monitors
- Public status page with live data
- Incident detail page with timeline
- API key authentication
- Uptime monitoring scheduler
- 13 tests passing
- TECHNICAL_DESIGN.md with full spec
2026-04-25 05:00:00 +00:00

2094 lines
71 KiB
Python

# mypy: allow-redefinition-new, local-partial-types
"""Python parser that directly constructs a native AST (when compiled).
Use a Rust extension to generate a serialized AST, and deserialize the AST directly
to a mypy AST.
NOTE: This is work in progress. To use this, you need to manually build the
ast_serialize Rust extension. See the README at https://github.com/mypyc/ast_serialize.
Expected benefits over mypy.fastparse:
* No intermediate non-mypyc Python-level AST created, to improve performance
* Parsing doesn't need GIL => use multithreading to construct serialized ASTs in parallel
* Produce import dependencies without having to build an AST => helps parallel type checking
* Support all Python syntax even if mypy is running on an older Python version
* Generate an AST even if there are syntax errors
* Potential to support incremental parsing (quickly process modified sections in a file)
* Stripping function bodies in third-party code can happen earlier, for extra performance
"""
from __future__ import annotations
import os
from typing import Any, Final, cast
import ast_serialize # type: ignore[import-untyped, import-not-found, unused-ignore]
from librt.internal import (
read_float as read_float_bare,
read_int as read_int_bare,
read_str as read_str_bare,
)
from mypy import message_registry, nodes, types
from mypy.cache import (
DICT_STR_GEN,
END_TAG,
LIST_GEN,
LIST_INT,
LITERAL_FLOAT,
LITERAL_NONE,
LITERAL_STR,
LOCATION,
ReadBuffer,
Tag,
read_bool,
read_int,
read_str,
read_str_opt,
read_tag,
)
from mypy.nodes import (
ARG_KINDS,
ARG_POS,
IMPORT_METADATA,
IMPORTALL_METADATA,
IMPORTFROM_METADATA,
MISSING_FALLBACK,
Argument,
AssertStmt,
AssignmentExpr,
AssignmentStmt,
AwaitExpr,
Block,
BreakStmt,
BytesExpr,
CallExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
ConditionalExpr,
Context,
ContinueStmt,
Decorator,
DelStmt,
DictExpr,
DictionaryComprehension,
EllipsisExpr,
Expression,
ExpressionStmt,
FileRawData,
FloatExpr,
ForStmt,
FuncDef,
GeneratorExpr,
GlobalDecl,
IfStmt,
Import,
ImportAll,
ImportBase,
ImportFrom,
IndexExpr,
IntExpr,
LambdaExpr,
ListComprehension,
ListExpr,
MatchStmt,
MemberExpr,
MypyFile,
NameExpr,
NonlocalDecl,
OperatorAssignmentStmt,
OpExpr,
OverloadedFuncDef,
OverloadPart,
PassStmt,
RaiseStmt,
RefExpr,
ReturnStmt,
SetComprehension,
SetExpr,
SliceExpr,
StarExpr,
Statement,
StrExpr,
SuperExpr,
TemplateStrExpr,
TempNode,
TryStmt,
TupleExpr,
TypeAliasStmt,
TypeParam,
UnaryExpr,
Var,
WhileStmt,
WithStmt,
YieldExpr,
YieldFromExpr,
)
from mypy.options import Options
from mypy.patterns import (
AsPattern,
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
ValuePattern,
)
from mypy.reachability import infer_reachability_of_if_statement
from mypy.sharedparse import special_function_elide_names
from mypy.types import (
AnyType,
CallableArgument,
CallableType,
EllipsisType,
Instance,
ProperType,
RawExpressionType,
TupleType,
Type,
TypedDictType,
TypeList,
TypeOfAny,
UnboundType,
UnionType,
UnpackType,
)
from mypy.util import unnamed_function
# Type ignore comments collected during parsing: (line number, list of error codes).
TypeIgnores = list[tuple[int, list[str]]]
# There is no way to create reasonable fallbacks at this stage,
# they must be patched later.
_dummy_fallback: Final = Instance(MISSING_FALLBACK, [], -1)
class State:
    """Mutable state shared across the deserialization helpers.

    Tracks the active options, errors produced while deserializing, and a
    running count of functions read (used to decide whether overload
    merging may be needed).
    """

    def __init__(self, options: Options) -> None:
        self.options = options
        # Errors raised during deserialization; merged with parse errors.
        self.errors: list[dict[str, Any]] = []
        # Number of functions deserialized so far.
        self.num_funcs = 0

    def add_error(
        self,
        message: str,
        line: int,
        column: int,
        *,
        blocker: bool = False,
        code: str | None = None,
    ) -> None:
        """Record an error at a specific source location.

        Args:
            message: Error message to display
            line: Line number where error occurred
            column: Column number where error occurred
            blocker: If True, this error blocks further analysis
            code: Error code for categorization
        """
        entry = {
            "line": line,
            "column": column,
            "message": message,
            "blocker": blocker,
            "code": code,
        }
        self.errors.append(entry)
def native_parse(
    filename: str, options: Options, skip_function_bodies: bool = False, imports_only: bool = False
) -> tuple[MypyFile, list[dict[str, Any]], TypeIgnores]:
    """Parse a Python file using the native Rust-based parser.

    Uses the ast_serialize Rust extension to parse Python code and deserialize
    the resulting AST directly into mypy's native AST representation.

    Args:
        filename: Path to the Python source file to parse
        options: Mypy options affecting parsing behavior (e.g., Python version)
        skip_function_bodies: If True, many function and method bodies are omitted from
            the AST, useful for parsing stubs or extracting signatures without full
            implementation details
        imports_only: If True create an empty MypyFile with actual serialized defs
            stored in binary_data.

    Returns:
        A tuple containing:
        - MypyFile: The parsed AST as a mypy AST node
        - list[dict[str, Any]]: List of parse errors and deserialization errors
        - TypeIgnores: List of (line_number, ignored_codes) tuples for type: ignore comments
    """
    # If the path is a directory, return empty AST (matching fastparse behavior)
    # This can happen for packages that only contain .pyc files without source
    if os.path.isdir(filename):
        node = MypyFile([], [])
        node.path = filename
        return node, [], []
    b, errors, ignores, import_bytes, is_partial_package, uses_template_strings = (
        parse_to_binary_ast(filename, options, skip_function_bodies)
    )
    data = ReadBuffer(b)
    # The serialized stream starts with the number of top-level statements.
    n = read_int(data)
    state = State(options)
    if imports_only:
        defs = []
    else:
        defs = read_statements(state, data, n)
    imports = deserialize_imports(import_bytes)
    node = MypyFile(defs, imports)
    node.path = filename
    node.is_partial_stub_package = is_partial_package
    if imports_only:
        # Keep the raw serialized AST around so the defs can be deserialized later.
        node.raw_data = FileRawData(
            b, import_bytes, errors, dict(ignores), is_partial_package, uses_template_strings
        )
    node.uses_template_strings = uses_template_strings
    # Merge deserialization errors with parsing errors
    all_errors = errors + state.errors
    return node, all_errors, ignores
def expect_end_tag(data: ReadBuffer) -> None:
    """Consume the next tag and assert it is END_TAG.

    Include the actual tag in the assertion message (consistent with
    expect_tag) so format mismatches are easier to diagnose.
    """
    assert (actual := read_tag(data)) == END_TAG, actual
def expect_tag(data: ReadBuffer, tag: Tag) -> None:
    """Consume the next tag and assert it equals the expected one."""
    actual = read_tag(data)
    assert actual == tag, actual
def read_statements(state: State, data: ReadBuffer, n: int) -> list[Statement]:
    """Read n serialized statements, merging runs of overloads if needed."""
    funcs_before = state.num_funcs
    stmts: list[Statement] = [read_statement(state, data) for _ in range(n)]
    if state.num_funcs - funcs_before >= 2:
        # At least two functions were read, so overload merging may apply.
        stmts = fix_function_overloads(state, stmts)
    return stmts
def parse_to_binary_ast(
    filename: str, options: Options, skip_function_bodies: bool = False
) -> tuple[bytes, list[dict[str, Any]], TypeIgnores, bytes, bool, bool]:
    """Invoke the Rust extension to parse a file into a serialized AST.

    Returns a tuple of (serialized AST bytes, parse errors, type-ignore
    comments, serialized import bytes, is_partial_package flag,
    uses_template_strings flag).
    """
    result = ast_serialize.parse(
        filename,
        skip_function_bodies=skip_function_bodies,
        python_version=options.python_version,
        platform=options.platform,
        always_true=options.always_true,
        always_false=options.always_false,
    )
    ast_bytes, errors, ignores, import_bytes, ast_data = result
    typed_errors = cast("list[dict[str, Any]]", errors)
    return (
        ast_bytes,
        typed_errors,
        ignores,
        import_bytes,
        ast_data["is_partial_package"],
        ast_data["uses_template_strings"],
    )
def read_statement(state: State, data: ReadBuffer) -> Statement:
    """Deserialize a single statement from the buffer.

    Dispatches on the statement tag. Each branch reads its fields in the
    exact order the serializer wrote them and finally consumes an END_TAG.
    """
    tag = read_tag(data)
    stmt: Statement
    if tag == nodes.FUNC_DEF_STMT:
        return read_func_def(state, data)
    elif tag == nodes.DECORATOR:
        expect_tag(data, LIST_GEN)
        n_decorators = read_int_bare(data)
        decorators = [read_expression(state, data) for i in range(n_decorators)]
        line = read_int(data)
        column = read_int(data)
        fdef = read_statement(state, data)
        assert isinstance(fdef, FuncDef)
        fdef.is_decorated = True
        # Placeholder Var for the decorated name (marked not ready).
        var = Var(fdef.name)
        var.line = fdef.line
        var.is_ready = False
        stmt = Decorator(fdef, decorators, var)
        stmt.line = line
        stmt.column = column
        stmt.end_line = fdef.end_line
        stmt.end_column = fdef.end_column
        # TODO: Adjust funcdef location to start after decorator?
        expect_end_tag(data)
        return stmt
    elif tag == nodes.EXPR_STMT:
        es = ExpressionStmt(read_expression(state, data))
        set_line_column_range(es, es.expr)
        expect_end_tag(data)
        return es
    elif tag == nodes.ASSIGNMENT_STMT:
        lvalues = read_expression_list(state, data)
        rvalue = read_expression(state, data)
        has_type = read_bool(data)
        if has_type:
            type_annotation = read_type(state, data)
        else:
            type_annotation = None
        new_syntax = read_bool(data)
        a = AssignmentStmt(lvalues, rvalue, type=type_annotation, new_syntax=new_syntax)
        read_loc(data, a)
        # If rvalue is TempNode, copy location from AssignmentStmt
        if isinstance(rvalue, TempNode):
            set_line_column_range(rvalue, a)
        expect_end_tag(data)
        return a
    elif tag == nodes.OPERATOR_ASSIGNMENT_STMT:
        # Read operator string
        op = read_str(data)
        # Read lvalue (target)
        lvalue = read_expression(state, data)
        # Read rvalue (value)
        rvalue = read_expression(state, data)
        stmt = OperatorAssignmentStmt(op, lvalue, rvalue)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IF_STMT:
        # Read the main if condition and body
        expr = read_expression(state, data)
        body = read_block(state, data)
        # Read elif clauses
        num_elif = read_int(data)
        elif_exprs = []
        elif_bodies = []
        for i in range(num_elif):
            elif_exprs.append(read_expression(state, data))
            elif_bodies.append(read_block(state, data))
        has_else = read_bool(data)
        if has_else:
            else_body = read_block(state, data)
        else:
            else_body = None
        # Normalize elif into nested if/else statements
        # Build from the bottom up, starting with the final else body
        current_else = else_body
        # Process elif clauses in reverse order
        for i in range(len(elif_exprs) - 1, -1, -1):
            elif_stmt = IfStmt([elif_exprs[i]], [elif_bodies[i]], current_else)
            # Set location from the elif expression
            elif_stmt.line = elif_exprs[i].line
            elif_stmt.column = elif_exprs[i].column
            # Set end location based on what follows
            if current_else is not None:
                elif_stmt.end_line = current_else.end_line
                elif_stmt.end_column = current_else.end_column
            else:
                elif_stmt.end_line = elif_bodies[i].end_line
                elif_stmt.end_column = elif_bodies[i].end_column
            # Wrap in a Block to become the else clause for the outer if
            current_else = Block([elif_stmt])
            set_line_column_range(current_else, elif_stmt)
        if_stmt = IfStmt([expr], [body], current_else)
        read_loc(data, if_stmt)
        expect_end_tag(data)
        return if_stmt
    elif tag == nodes.RETURN_STMT:
        has_value = read_bool(data)
        if has_value:
            value = read_expression(state, data)
        else:
            value = None
        stmt = ReturnStmt(value)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.RAISE_STMT:
        has_exc = read_bool(data)
        if has_exc:
            exc = read_expression(state, data)
        else:
            exc = None
        has_from = read_bool(data)
        if has_from:
            from_expr = read_expression(state, data)
        else:
            from_expr = None
        stmt = RaiseStmt(exc, from_expr)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.ASSERT_STMT:
        test = read_expression(state, data)
        has_msg = read_bool(data)
        if has_msg:
            msg = read_expression(state, data)
        else:
            msg = None
        stmt = AssertStmt(test, msg)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.WHILE_STMT:
        expr = read_expression(state, data)
        body = read_block(state, data)
        else_body = read_optional_block(state, data)
        stmt = WhileStmt(expr, body, else_body)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.FOR_STMT:
        index = read_expression(state, data)
        expr = read_expression(state, data)
        body = read_block(state, data)
        else_body = read_optional_block(state, data)
        is_async = read_bool(data)
        stmt = ForStmt(index, expr, body, else_body)
        stmt.is_async = is_async
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.WITH_STMT:
        # One (context expression, optional target) pair per with item.
        n = read_int(data)
        expr_list = []
        target_list: list[Expression | None] = []
        for _ in range(n):
            context_expr = read_expression(state, data)
            expr_list.append(context_expr)
            has_target = read_bool(data)
            if has_target:
                target = read_expression(state, data)
                target_list.append(target)
            else:
                target_list.append(None)
        body = read_block(state, data)
        is_async = read_bool(data)
        stmt = WithStmt(expr_list, target_list, body)
        stmt.is_async = is_async
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.PASS_STMT:
        stmt = PassStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.BREAK_STMT:
        stmt = BreakStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.CONTINUE_STMT:
        stmt = ContinueStmt()
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT:
        # One (name, optional asname) pair per imported module.
        n = read_int(data)
        ids = []
        for _ in range(n):
            name = read_str(data)
            has_asname = read_bool(data)
            if has_asname:
                asname = read_str(data)
            else:
                asname = None
            ids.append((name, asname))
        stmt = Import(ids)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT_FROM:
        # NOTE: field order differs from IMPORT_ALL (relative comes first here).
        relative = read_int(data)
        module_id = read_str(data)  # Empty string for "from . import x"
        n = read_int(data)
        names = []
        for _ in range(n):
            name = read_str(data)
            has_asname = read_bool(data)
            if has_asname:
                asname = read_str(data)
            else:
                asname = None
            names.append((name, asname))
        stmt = ImportFrom(module_id, relative, names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.IMPORT_ALL:
        module_id = read_str(data)  # Empty string for "from . import *"
        relative = read_int(data)
        stmt = ImportAll(module_id, relative)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.CLASS_DEF:
        return read_class_def(state, data)
    elif tag == nodes.TYPE_ALIAS_STMT:
        return read_type_alias_stmt(state, data)
    elif tag == nodes.TRY_STMT:
        return read_try_stmt(state, data)
    elif tag == nodes.DEL_STMT:
        expr = read_expression(state, data)
        stmt = DelStmt(expr)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.GLOBAL_DECL:
        n = read_int(data)
        decl_names = []
        for _ in range(n):
            decl_names.append(read_str(data))
        stmt = GlobalDecl(decl_names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.NONLOCAL_DECL:
        n = read_int(data)
        decl_names = []
        for _ in range(n):
            decl_names.append(read_str(data))
        stmt = NonlocalDecl(decl_names)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    elif tag == nodes.MATCH_STMT:
        subject = read_expression(state, data)
        # One (pattern, optional guard, body) triple per case clause.
        n_cases = read_int(data)
        patterns = []
        guards: list[Expression | None] = []
        bodies = []
        for _ in range(n_cases):
            pattern = read_pattern(state, data)
            patterns.append(pattern)
            has_guard = read_bool(data)
            if has_guard:
                guard = read_expression(state, data)
                guards.append(guard)
            else:
                guards.append(None)
            body = read_block(state, data)
            bodies.append(body)
        stmt = MatchStmt(subject, patterns, guards, bodies)
        read_loc(data, stmt)
        expect_end_tag(data)
        return stmt
    else:
        assert False, tag
def read_parameters(state: State, data: ReadBuffer) -> tuple[list[Argument], bool]:
    """Read function/lambda parameters from the buffer.

    Returns:
        A tuple of (arguments list, has_annotations flag)
    """
    expect_tag(data, LIST_GEN)
    n_args = read_int_bare(data)
    arguments = []
    has_ann = False
    for _ in range(n_args):
        # Per argument: name, kind, optional annotation, optional default,
        # positional-only flag, then the location record.
        arg_name = read_str(data)
        arg_kind_int = read_int(data)
        arg_kind = ARG_KINDS[arg_kind_int]
        has_type = read_bool(data)
        if has_type:
            ann = read_type(state, data)
            has_ann = True
        else:
            ann = None
        has_default = read_bool(data)
        if has_default:
            default = read_expression(state, data)
        else:
            default = None
        pos_only = read_bool(data)
        # Apply implicit_optional if enabled and default is None
        if state.options.implicit_optional and ann is not None:
            optional = isinstance(default, NameExpr) and default.name == "None"
            if isinstance(ann, UnboundType):
                ann.optional = optional
        var = Var(arg_name, ann)
        var.is_inferred = False
        var.is_argument = True
        arg = Argument(var, ann, default, arg_kind, pos_only)
        read_loc(data, arg)
        # The argument's Var shares the argument's location.
        set_line_column_range(var, arg)
        arguments.append(arg)
    return arguments, has_ann
def read_type_params(state: State, data: ReadBuffer) -> list[TypeParam]:
    """Read type parameters (PEP 695 generics).

    Wire format: a bare count, then per parameter: kind, name, optional
    upper bound, list of constraint values, optional default.
    """
    type_params: list[TypeParam] = []
    n = read_int_bare(data)
    for _ in range(n):
        kind = read_int(data)
        name = read_str(data)
        has_bound = read_bool(data)
        if has_bound:
            upper_bound = read_type(state, data)
        else:
            upper_bound = None
        # Constraint values (for constrained TypeVars).
        expect_tag(data, LIST_GEN)
        n_values = read_int_bare(data)
        values = [read_type(state, data) for _ in range(n_values)]
        has_default = read_bool(data)
        if has_default:
            default = read_type(state, data)
        else:
            default = None
        type_params.append(TypeParam(name, kind, upper_bound, values, default))
    return type_params
def read_func_def(state: State, data: ReadBuffer) -> FuncDef:
    """Deserialize a (possibly async) function definition.

    If any parameter or the return carries an annotation, a CallableType
    is synthesized with Any filling unannotated slots; the fallback is a
    dummy Instance that must be patched later.
    """
    state.num_funcs += 1
    name = read_str(data)
    arguments, has_ann = read_parameters(state, data)
    # Names of arguments to special functions (e.g. dunders) are elided.
    if special_function_elide_names(name):
        for arg in arguments:
            arg.pos_only = True
    body = read_block(state, data)
    is_async = read_bool(data)
    # Type parameters (PEP 695)
    has_type_params = read_bool(data)
    if has_type_params:
        type_params = read_type_params(state, data)
    else:
        type_params = None
    has_return_type = read_bool(data)
    if has_return_type:
        return_type = read_type(state, data)
        has_ann = True
    else:
        return_type = None
    if has_ann:
        typ = CallableType(
            [
                arg.type_annotation if arg.type_annotation else AnyType(TypeOfAny.unannotated)
                for arg in arguments
            ],
            [arg.kind for arg in arguments],
            [None if arg.pos_only else arg.variable.name for arg in arguments],
            return_type if return_type else AnyType(TypeOfAny.unannotated),
            _dummy_fallback,
        )
    else:
        typ = None
    func_def = FuncDef(name, arguments, body, typ=typ, type_args=type_params)
    if is_async:
        func_def.is_coroutine = True
    read_loc(data, func_def)
    if typ:
        typ.line = func_def.line
        typ.column = func_def.column
        typ.definition = func_def
        # TODO: This seems wasteful, can we avoid it?
        func_def.unanalyzed_type = typ.copy_modified()
    expect_end_tag(data)
    return func_def
def read_class_def(state: State, data: ReadBuffer) -> ClassDef:
    """Deserialize a class definition.

    The metaclass keyword (if present) is split out of the keyword list,
    since ClassDef takes it as a separate field.
    """
    name = read_str(data)
    body = read_block(state, data)
    base_type_exprs = read_expression_list(state, data)
    expect_tag(data, LIST_GEN)
    n_decorators = read_int_bare(data)
    decorators = [read_expression(state, data) for _ in range(n_decorators)]
    # Type parameters (PEP 695)
    has_type_params = read_bool(data)
    if has_type_params:
        type_params = read_type_params(state, data)
    else:
        type_params = None
    # Keywords (all keyword arguments including metaclass)
    expect_tag(data, DICT_STR_GEN)
    n_keywords = read_int_bare(data)
    keywords = []
    for _ in range(n_keywords):
        key = read_str(data)
        value = read_expression(state, data)
        keywords.append((key, value))
    # Extract metaclass from keywords if present
    metaclass = dict(keywords).get("metaclass") if keywords else None
    # Remove metaclass from keywords since it's passed as a separate field
    filtered_keywords = [(k, v) for k, v in keywords if k != "metaclass"] if keywords else None
    class_def = ClassDef(
        name,
        body,
        base_type_exprs=base_type_exprs if base_type_exprs else None,
        metaclass=metaclass,
        keywords=filtered_keywords,
        type_args=type_params,
    )
    class_def.decorators = decorators
    read_loc(data, class_def)
    expect_end_tag(data)
    return class_def
def read_type_alias_stmt(state: State, data: ReadBuffer) -> TypeAliasStmt:
    """Read a PEP 695 ``type X[...] = ...`` alias statement.

    Wire layout: alias name expression, type parameter list (same format
    as read_type_params: a bare count followed by per-parameter fields),
    then the aliased value expression.
    """
    name = read_expression(state, data)
    assert isinstance(name, NameExpr), f"Expected NameExpr for type alias name, got {type(name)}"
    # The serialized parameter list is identical to the one consumed by
    # read_type_params (count, then kind/name/bound/values/default per
    # parameter), so reuse it instead of duplicating the decoding inline.
    # read_type_params returns [] for a zero count, matching the old
    # explicit n_type_params > 0 branch.
    type_params = read_type_params(state, data)
    value_expr = read_expression(state, data)
    # Wrap the value expression in a LambdaExpr as expected by TypeAliasStmt
    # The LambdaExpr body is a Block with a single ReturnStmt
    return_stmt = ReturnStmt(value_expr)
    set_line_column_range(return_stmt, value_expr)
    block = Block([return_stmt])
    block.line = -1  # Synthetic block
    block.column = 0
    block.end_line = -1
    block.end_column = 0
    lambda_expr = LambdaExpr([], block)
    set_line_column_range(lambda_expr, value_expr)
    stmt = TypeAliasStmt(name, type_params, lambda_expr)
    read_loc(data, stmt)
    expect_end_tag(data)
    return stmt
def read_try_stmt(state: State, data: ReadBuffer) -> TryStmt:
    """Deserialize a try statement (including try/except*).

    Handler fields are serialized column-wise: all exception types first,
    then all capture names, then all handler bodies.
    """
    body = read_block(state, data)
    num_handlers = read_int(data)
    types_list: list[Expression | None] = []
    for _ in range(num_handlers):
        has_type = read_bool(data)
        if has_type:
            exc_type = read_expression(state, data)
            types_list.append(exc_type)
        else:
            types_list.append(None)
    vars_list: list[NameExpr | None] = []
    for _ in range(num_handlers):
        has_name = read_bool(data)
        if has_name:
            var_name = read_str(data)
            var_expr = NameExpr(var_name)
            vars_list.append(var_expr)
        else:
            vars_list.append(None)
    handlers = []
    for _ in range(num_handlers):
        handler_body = read_block(state, data)
        handlers.append(handler_body)
    has_else = read_bool(data)
    if has_else:
        else_body = read_block(state, data)
    else:
        else_body = None
    has_finally = read_bool(data)
    if has_finally:
        finally_body = read_block(state, data)
    else:
        finally_body = None
    # Read is_star flag (for except* in Python 3.11+)
    is_star = read_bool(data)
    stmt = TryStmt(body, vars_list, types_list, handlers, else_body, finally_body)
    stmt.is_star = is_star
    read_loc(data, stmt)
    expect_end_tag(data)
    return stmt
def _read_optional_str(data: ReadBuffer, what: str) -> str | None:
    """Read an optional string field: LITERAL_NONE -> None, LITERAL_STR -> value.

    ``what`` names the field for the assertion message on a bad tag.
    """
    t = read_tag(data)
    if t == LITERAL_NONE:
        return None
    elif t == LITERAL_STR:
        return read_str_bare(data)
    else:
        assert False, f"Unexpected tag for {what}: {t}"


def read_type(state: State, data: ReadBuffer) -> Type:
    """Deserialize a type annotation from the buffer.

    Dispatches on the type tag; each branch reads fields in the exact
    order the serializer wrote them and consumes the trailing END_TAG.
    Fallback instances are dummies that must be patched later.
    """
    tag = read_tag(data)
    if tag == types.UNBOUND_TYPE:
        name = read_str(data)
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        args = tuple(read_type(state, data) for i in range(n))
        empty_tuple_index = read_bool(data)
        # Optional original string expression / fallback (string literal types).
        original_str_expr = _read_optional_str(data, "original_str_expr")
        original_str_fallback = _read_optional_str(data, "original_str_fallback")
        unbound = UnboundType(
            name,
            args,
            empty_tuple_index=empty_tuple_index,
            original_str_expr=original_str_expr,
            original_str_fallback=original_str_fallback,
        )
        read_loc(data, unbound)
        expect_end_tag(data)
        return unbound
    elif tag == types.UNION_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        # Read uses_pep604_syntax flag
        uses_pep604_syntax = read_bool(data)
        # Optional original string expression / fallback, as for UnboundType.
        original_str_expr = _read_optional_str(data, "original_str_expr")
        original_str_fallback = _read_optional_str(data, "original_str_fallback")
        union = UnionType(items, uses_pep604_syntax=uses_pep604_syntax)
        union.original_str_expr = original_str_expr
        union.original_str_fallback = original_str_fallback
        union.is_evaluated = read_bool(data)
        read_loc(data, union)
        expect_end_tag(data)
        return union
    elif tag == types.LIST_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        type_list = TypeList(items)
        read_loc(data, type_list)
        expect_end_tag(data)
        return type_list
    elif tag == types.TUPLE_TYPE:
        # Read items list
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        items = [read_type(state, data) for i in range(n)]
        implicit = read_bool(data)
        tuple_type = TupleType(items, _dummy_fallback, implicit=implicit)
        read_loc(data, tuple_type)
        expect_end_tag(data)
        return tuple_type
    elif tag == types.TYPED_DICT_TYPE:
        # Keys and values are serialized as two parallel lists; a None key
        # marks an item contributing extra items (e.g. from another TypedDict).
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        keys = [read_str_opt(data) for i in range(n)]
        expect_tag(data, LIST_GEN)
        n = read_int_bare(data)
        values = [read_type(state, data) for i in range(n)]
        td_items = {}
        extra_items_from = []
        for key, val in zip(keys, values):
            if key is None:
                assert isinstance(val, ProperType)
                extra_items_from.append(val)
            else:
                td_items[key] = val
        typeddict_type = TypedDictType(td_items, set(), set(), _dummy_fallback)
        typeddict_type.extra_items_from = extra_items_from
        read_loc(data, typeddict_type)
        expect_end_tag(data)
        return typeddict_type
    elif tag == types.ELLIPSIS_TYPE:
        # EllipsisType has no attributes
        ellipsis_type = EllipsisType()
        read_loc(data, ellipsis_type)
        expect_end_tag(data)
        return ellipsis_type
    elif tag == types.RAW_EXPRESSION_TYPE:
        type_name = read_str(data)
        value: types.LiteralValue | str | None
        if type_name == "builtins.bool":
            value = read_bool(data)
        elif type_name == "builtins.int":
            value = read_int(data)
        elif type_name == "builtins.str":
            value = read_str(data)
        elif type_name == "builtins.bytes":
            # Bytes literals are serialized as escaped strings
            value = read_str(data)
        elif type_name == "typing.Any":
            # Invalid type - read None value
            tag = read_tag(data)
            assert tag == LITERAL_NONE, f"Expected LITERAL_NONE for invalid type, got {tag}"
            value = None
        else:
            assert False, f"Unsupported RawExpressionType: {type_name}"
        raw_type = RawExpressionType(value, type_name)
        read_loc(data, raw_type)
        expect_end_tag(data)
        return raw_type
    elif tag == types.UNPACK_TYPE:
        inner_type = read_type(state, data)
        from_star_syntax = read_bool(data)
        unpack = UnpackType(inner_type, from_star_syntax=from_star_syntax)
        read_loc(data, unpack)
        expect_end_tag(data)
        return unpack
    elif tag == types.CALL_TYPE:
        return read_call_type(state, data)
    else:
        assert False, tag
def stringify_type_name(typ: Type) -> str | None:
    """Return the qualified name of an UnboundType, else None.

    Used to detect Arg/DefaultArg/VarArg/KwArg constructor calls.
    """
    return typ.name if isinstance(typ, UnboundType) else None
def extract_arg_name(typ: Type) -> str | None:
    """Extract an argument name from a type (for the Arg name parameter)."""
    if isinstance(typ, RawExpressionType):
        if typ.base_type_name == "builtins.str":
            return typ.literal_value  # type: ignore[return-value]
        return None  # Invalid, but let validation handle it
    if isinstance(typ, UnboundType):
        # String literals in a type context are parsed as UnboundType
        # (forward references); for Arg names these are simple names.
        # "None" means no name; anything else is the argument name itself.
        return None if typ.name == "None" else typ.name
    return None  # Invalid, but let validation handle it
def read_call_type(state: State, data: ReadBuffer) -> Type:
    """Read Call in type context - check if it's an Arg/DefaultArg/VarArg/KwArg constructor.

    This performs validation and error reporting similar to mypy/fastparse.py.
    """
    callee_type = read_type(state, data)
    # Read positional arguments
    expect_tag(data, LIST_GEN)
    n_args = read_int_bare(data)
    args = [read_type(state, data) for _ in range(n_args)]
    # Read keyword arguments
    expect_tag(data, LIST_GEN)
    n_kwargs = read_int_bare(data)
    kwargs = []
    for _ in range(n_kwargs):
        tag_kw = read_tag(data)
        if tag_kw == LITERAL_NONE:
            kw_name = None
        elif tag_kw == LITERAL_STR:
            kw_name = read_str_bare(data)
        else:
            assert False, f"Unexpected tag for keyword name: {tag_kw}"
        kw_value = read_type(state, data)
        kwargs.append((kw_name, kw_value))
    # Try to detect Arg/DefaultArg/VarArg/KwArg pattern
    constructor = stringify_type_name(callee_type)
    # We'll read location before processing errors so we can report them correctly
    # (the location is stored on `invalid`, which doubles as the error result).
    invalid = AnyType(TypeOfAny.from_error)
    read_loc(data, invalid)
    expect_end_tag(data)
    if not constructor:
        # ARG_CONSTRUCTOR_NAME_EXPECTED
        state.add_error(
            message_registry.ARG_CONSTRUCTOR_NAME_EXPECTED.value,
            invalid.line,
            invalid.column,
            blocker=True,
            code="misc",
        )
        return invalid
    # Extract type and name from arguments
    name: str | None = None
    name_set_from_positional = False
    default_type = AnyType(TypeOfAny.special_form)
    typ: Type = default_type
    typ_set_from_positional = False
    # Process positional arguments: first is the type, second the name.
    for i, arg in enumerate(args):
        if i == 0:
            typ = arg
            typ_set_from_positional = True
        elif i == 1:
            name = extract_arg_name(arg)
            name_set_from_positional = True
        else:
            # ARG_CONSTRUCTOR_TOO_MANY_ARGS
            state.add_error(
                message_registry.ARG_CONSTRUCTOR_TOO_MANY_ARGS.value,
                invalid.line,
                invalid.column,
                blocker=True,
                code="misc",
            )
    # Process keyword arguments, reporting duplicates of positional values.
    for kw_name, kw_value in kwargs:
        if kw_name == "name":
            # MULTIPLE_VALUES_FOR_NAME_KWARG
            if name is not None and name_set_from_positional:
                state.add_error(
                    message_registry.MULTIPLE_VALUES_FOR_NAME_KWARG.format(constructor).value,
                    invalid.line,
                    invalid.column,
                    blocker=True,
                    code="misc",
                )
            name = extract_arg_name(kw_value)
        elif kw_name == "type":
            # MULTIPLE_VALUES_FOR_TYPE_KWARG
            if typ is not default_type and typ_set_from_positional:
                state.add_error(
                    message_registry.MULTIPLE_VALUES_FOR_TYPE_KWARG.format(constructor).value,
                    invalid.line,
                    invalid.column,
                    blocker=True,
                    code="misc",
                )
            typ = kw_value
        else:
            # ARG_CONSTRUCTOR_UNEXPECTED_ARG
            state.add_error(
                message_registry.ARG_CONSTRUCTOR_UNEXPECTED_ARG.format(kw_name).value,
                invalid.line,
                invalid.column,
                blocker=True,
                code="misc",
            )
    # Create CallableArgument
    call_arg = CallableArgument(typ, name, constructor)
    set_line_column_range(call_arg, invalid)
    return call_arg
def read_pattern(state: State, data: ReadBuffer) -> Pattern:
    """Read a pattern node from the buffer."""
    tag = read_tag(data)
    if tag == nodes.AS_PATTERN:
        has_pattern = read_bool(data)
        if has_pattern:
            pattern = read_pattern(state, data)
        else:
            pattern = None
        has_name = read_bool(data)
        if has_name:
            name_str = read_str(data)
            name = NameExpr(name_str)
            read_loc(data, name)
        else:
            name = None
        as_pattern = AsPattern(pattern, name)
        read_loc(data, as_pattern)
        expect_end_tag(data)
        return as_pattern
    elif tag == nodes.OR_PATTERN:
        n = read_int(data)
        patterns = [read_pattern(state, data) for _ in range(n)]
        or_pattern = OrPattern(patterns)
        read_loc(data, or_pattern)
        expect_end_tag(data)
        return or_pattern
    elif tag == nodes.VALUE_PATTERN:
        expr = read_expression(state, data)
        value_pattern = ValuePattern(expr)
        read_loc(data, value_pattern)
        expect_end_tag(data)
        return value_pattern
    elif tag == nodes.SINGLETON_PATTERN:
        # The singleton (None/True/False) is encoded directly in the tag.
        singleton_tag = read_tag(data)
        if singleton_tag == LITERAL_NONE:
            value = None
        else:
            # It's a boolean
            # NOTE(review): tag value 1 is assumed to be TAG_LITERAL_TRUE, so
            # any other non-None tag decodes as False — confirm against the
            # serializer's tag table.
            value = singleton_tag == 1  # TAG_LITERAL_TRUE
        singleton_pattern = SingletonPattern(value)
        read_loc(data, singleton_pattern)
        expect_end_tag(data)
        return singleton_pattern
    elif tag == nodes.SEQUENCE_PATTERN:
        n = read_int(data)
        patterns = [read_pattern(state, data) for _ in range(n)]
        sequence_pattern = SequencePattern(patterns)
        read_loc(data, sequence_pattern)
        expect_end_tag(data)
        return sequence_pattern
    elif tag == nodes.STARRED_PATTERN:
        # Read optional capture name
        has_name = read_bool(data)
        if has_name:
            name_str = read_str(data)
            name = NameExpr(name_str)
            read_loc(data, name)
        else:
            name = None
        starred_pattern = StarredPattern(name)
        read_loc(data, starred_pattern)
        expect_end_tag(data)
        return starred_pattern
    elif tag == nodes.MAPPING_PATTERN:
        n = read_int(data)
        keys = []
        values = []
        for _ in range(n):
            key = read_expression(state, data)
            value = read_pattern(state, data)
            keys.append(key)
            values.append(value)
        # Optional **rest capture name.
        has_rest = read_bool(data)
        if has_rest:
            rest_str = read_str(data)
            rest = NameExpr(rest_str)
            read_loc(data, rest)
        else:
            rest = None
        mapping_pattern = MappingPattern(keys, values, rest)
        read_loc(data, mapping_pattern)
        expect_end_tag(data)
        return mapping_pattern
    elif tag == nodes.CLASS_PATTERN:
        class_ref = cast(RefExpr, read_expression(state, data))
        n_positional = read_int(data)
        positionals = [read_pattern(state, data) for _ in range(n_positional)]
        n_keywords = read_int(data)
        keyword_keys = []
        keyword_values = []
        for _ in range(n_keywords):
            key = read_str(data)
            value = read_pattern(state, data)
            keyword_keys.append(key)
            keyword_values.append(value)
        class_pattern = ClassPattern(class_ref, positionals, keyword_keys, keyword_values)
        read_loc(data, class_pattern)
        expect_end_tag(data)
        return class_pattern
    else:
        assert False, f"Unknown pattern tag: {tag}"
def read_block(state: State, data: ReadBuffer) -> Block:
    """Deserialize a Block node from the buffer.

    An empty block carries an explicit serialized location record; a
    non-empty block derives its location span from its first and last
    statements instead.
    """
    expect_tag(data, nodes.BLOCK)
    expect_tag(data, LIST_GEN)
    num_stmts = read_int_bare(data)
    unreachable = read_bool(data)
    if num_stmts:
        # Non-empty block: read the statements, then copy the span from them.
        stmts = read_statements(state, data, num_stmts)
        expect_end_tag(data)
        block = Block(stmts, is_unreachable=unreachable)
        first, last = stmts[0], stmts[-1]
        block.line = first.line
        block.column = first.column
        block.end_line = last.end_line
        block.end_column = last.end_column
    else:
        # Empty block: an explicit location record follows.
        block = Block([], is_unreachable=unreachable)
        read_loc(data, block)
        expect_end_tag(data)
    return block
def read_optional_block(state: State, data: ReadBuffer) -> Block | None:
    """Deserialize a Block, mapping an empty one to None (e.g. a missing else)."""
    expect_tag(data, nodes.BLOCK)
    expect_tag(data, LIST_GEN)
    num_stmts = read_int_bare(data)
    unreachable = read_bool(data)
    result: Block | None = None
    if num_stmts:
        stmts = [read_statement(state, data) for _ in range(num_stmts)]
        result = Block(stmts, is_unreachable=unreachable)
        # Derive the block's location span from its statements.
        result.line = stmts[0].line
        result.column = stmts[0].column
        result.end_line = stmts[-1].end_line
        result.end_column = stmts[-1].end_column
    expect_end_tag(data)
    return result
# Operator spelling tables, indexed by the integer codes emitted by the Rust
# serializer (see e.g. bin_ops[read_int(data)] in read_expression).  The
# element order must stay in sync with the serializer's encoding.
bin_ops: Final = ["+", "-", "*", "@", "/", "%", "**", "<<", ">>", "|", "^", "&", "//"]
bool_ops: Final = ["and", "or"]
cmp_ops: Final = ["==", "!=", "<", "<=", ">", ">=", "is", "is not", "in", "not in"]
unary_ops: Final = ["~", "not", "+", "-"]
def read_expression(state: State, data: ReadBuffer) -> Expression:
    """Deserialize a single expression node from the buffer.

    Dispatches on the serialized node tag and reconstructs the corresponding
    mypy Expression.  Each branch consumes the node's payload, its location
    record (where one is serialized), and the trailing end tag.  The order of
    buffer reads in each branch must match the serializer exactly.
    """
    tag = read_tag(data)
    expr: Expression
    if tag == nodes.CALL_EXPR:
        callee = read_expression(state, data)
        args = read_expression_list(state, data)
        # Read argument kinds
        expect_tag(data, LIST_INT)
        n_kinds = read_int_bare(data)
        arg_kinds = [ARG_KINDS[read_int_bare(data)] for _ in range(n_kinds)]
        # Read argument names
        expect_tag(data, LIST_GEN)
        n_names = read_int_bare(data)
        arg_names: list[str | None] = []
        for _ in range(n_names):
            # NOTE: "tag" is reused for the per-name entry tags; the outer
            # node tag is no longer needed in this branch.
            tag = read_tag(data)
            if tag == LITERAL_NONE:
                arg_names.append(None)
            elif tag == LITERAL_STR:
                arg_names.append(read_str_bare(data))
            else:
                assert False, f"Unexpected tag for arg_name: {tag}"
        ce = CallExpr(callee, args, arg_kinds, arg_names)
        read_loc(data, ce)
        expect_end_tag(data)
        return ce
    elif tag == nodes.NAME_EXPR:
        s = read_str(data)
        ne = NameExpr(s)
        read_loc(data, ne)
        expect_end_tag(data)
        return ne
    elif tag == nodes.MEMBER_EXPR:
        e = read_expression(state, data)
        attr = read_str(data)
        m = MemberExpr(e, attr)
        # Check if this is a super() call - if so, convert to SuperExpr
        if isinstance(e, CallExpr) and isinstance(e.callee, NameExpr) and e.callee.name == "super":
            result: Expression = SuperExpr(attr, e)
        else:
            result = m
        read_loc(data, result)
        expect_end_tag(data)
        return result
    elif tag == nodes.STR_EXPR:
        se = StrExpr(read_str(data))
        read_loc(data, se)
        expect_end_tag(data)
        return se
    elif tag == nodes.INT_EXPR:
        ie = IntExpr(read_int(data))
        read_loc(data, ie)
        expect_end_tag(data)
        return ie
    elif tag == nodes.FLOAT_EXPR:
        expect_tag(data, LITERAL_FLOAT)
        value = read_float_bare(data)
        fe = FloatExpr(value)
        read_loc(data, fe)
        expect_end_tag(data)
        return fe
    elif tag == nodes.LIST_EXPR:
        items = read_expression_list(state, data)
        expr = ListExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TUPLE_EXPR:
        items = read_expression_list(state, data)
        t = TupleExpr(items)
        read_loc(data, t)
        expect_end_tag(data)
        return t
    elif tag == nodes.SET_EXPR:
        items = read_expression_list(state, data)
        expr = SetExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.GENERATOR_EXPR:
        expr = read_generator_expr(state, data)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.LIST_COMPREHENSION:
        generator = read_generator_expr(state, data)
        expr = ListComprehension(generator)
        read_loc(data, expr)
        # Also copy location to the inner generator
        set_line_column_range(generator, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.SET_COMPREHENSION:
        generator = read_generator_expr(state, data)
        expr = SetComprehension(generator)
        read_loc(data, expr)
        # Also copy location to the inner generator
        set_line_column_range(generator, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.DICT_COMPREHENSION:
        # Same payload layout as read_generator_expr, but with separate
        # key and value result expressions.
        key = read_expression(state, data)
        value = read_expression(state, data)
        n_generators = read_int(data)
        indices = [read_expression(state, data) for _ in range(n_generators)]
        sequences = [read_expression(state, data) for _ in range(n_generators)]
        condlists = [read_expression_list(state, data) for _ in range(n_generators)]
        is_async = [read_bool(data) for _ in range(n_generators)]
        expr = DictionaryComprehension(key, value, indices, sequences, condlists, is_async)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.YIELD_EXPR:
        has_value = read_bool(data)
        if has_value:
            value = read_expression(state, data)
        else:
            value = None
        expr = YieldExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.YIELD_FROM_EXPR:
        value = read_expression(state, data)
        expr = YieldFromExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.OP_EXPR:
        op = bin_ops[read_int(data)]
        left = read_expression(state, data)
        right = read_expression(state, data)
        o = OpExpr(op, left, right)
        # No serialized location; the span is derived from the operands.
        # TODO: Store these explicitly?
        o.line = left.line
        o.column = left.column
        o.end_line = right.end_line
        o.end_column = right.end_column
        expect_end_tag(data)
        return o
    elif tag == nodes.INDEX_EXPR:
        base = read_expression(state, data)
        index = read_expression(state, data)
        expr = IndexExpr(base, index)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BOOL_OP_EXPR:
        op = bool_ops[read_int(data)]
        values = read_expression_list(state, data)
        # Convert list of values to nested OpExpr nodes
        # E.g., [a, b, c] with "and" becomes OpExpr("and", OpExpr("and", a, b), c)
        assert len(values) >= 2
        result = values[0]
        for val in values[1:]:
            result = OpExpr(op, result, val)
            result.line = values[0].line
            result.column = values[0].column
            result.end_line = val.end_line
            result.end_column = val.end_column
        read_loc(data, result)
        expect_end_tag(data)
        return result
    elif tag == nodes.COMPARISON_EXPR:
        left = read_expression(state, data)
        expect_tag(data, LIST_INT)
        n_ops = read_int_bare(data)
        ops = [cmp_ops[read_int_bare(data)] for _ in range(n_ops)]
        comparators = read_expression_list(state, data)
        assert len(ops) == len(comparators)
        expr = ComparisonExpr(ops, [left] + comparators)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.UNARY_EXPR:
        op = unary_ops[read_int(data)]
        operand = read_expression(state, data)
        expr = UnaryExpr(op, operand)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.DICT_EXPR:
        # Keys and values are serialized as two parallel lists; a missing
        # key (None) represents a **-unpacked item.
        expect_tag(data, LIST_GEN)
        n_keys = read_int_bare(data)
        keys: list[Expression | None] = []
        for _ in range(n_keys):
            has_key = read_bool(data)
            if has_key:
                keys.append(read_expression(state, data))
            else:
                keys.append(None)
        values = read_expression_list(state, data)
        # Zip keys and values into items
        items = list(zip(keys, values))
        expr = DictExpr(items)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.COMPLEX_EXPR:
        expect_tag(data, LITERAL_FLOAT)
        real = read_float_bare(data)
        expect_tag(data, LITERAL_FLOAT)
        imag = read_float_bare(data)
        value = complex(real, imag)
        expr = ComplexExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.SLICE_EXPR:
        has_begin = read_bool(data)
        begin_index = read_expression(state, data) if has_begin else None
        has_end = read_bool(data)
        end_index = read_expression(state, data) if has_end else None
        has_stride = read_bool(data)
        stride = read_expression(state, data) if has_stride else None
        expr = SliceExpr(begin_index, end_index, stride)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TEMP_NODE:
        # TempNode with no attributes
        temp = TempNode(AnyType(TypeOfAny.special_form), no_rhs=True)
        expect_end_tag(data)
        return temp
    elif tag == nodes.ELLIPSIS_EXPR:
        expr = EllipsisExpr()
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.CONDITIONAL_EXPR:
        # Serialized in source order: "if_expr if cond else else_expr".
        if_expr = read_expression(state, data)
        cond = read_expression(state, data)
        else_expr = read_expression(state, data)
        expr = ConditionalExpr(cond, if_expr, else_expr)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.FSTRING_EXPR:
        # F-strings are converted into nodes representing "".join([...]), to match
        # pre-existing behavior.
        nparts = read_int(data)
        fitems = []
        for _ in range(nparts):
            b = read_bool(data)
            if b:
                n = read_int(data)
                for i in range(n):
                    fitems.append(read_fstring_item(state, data))
            else:
                s = StrExpr(read_str(data))
                read_loc(data, s)
                fitems.append(s)
        expr = build_fstring_join(state, data, fitems)
        expect_end_tag(data)
        return expr
    elif tag == nodes.TSTRING_EXPR:
        nparts = read_int(data)
        titems: list[Expression | tuple[Expression, str, str | None, Expression | None]] = []
        for _ in range(nparts):
            if read_bool(data):
                e = read_expression(state, data)
                s = read_str(data)
                if read_bool(data):
                    conv = read_str(data)
                else:
                    conv = None
                if read_bool(data):
                    # Parse format spec as a JoinedStr, this matches the old parser behavior.
                    format_spec = read_fstring_items(state, data)
                else:
                    format_spec = None
                titems.append((e, s, conv, format_spec))
            else:
                s = StrExpr(read_str(data))
                read_loc(data, s)
                titems.append(s)
        expr = TemplateStrExpr(titems)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.LAMBDA_EXPR:
        arguments, has_ann = read_parameters(state, data)
        body = read_block(state, data)
        # If any parameter is annotated, synthesize a callable type with
        # Any filled in for the unannotated pieces.
        if has_ann:
            typ = CallableType(
                [
                    arg.type_annotation if arg.type_annotation else AnyType(TypeOfAny.unannotated)
                    for arg in arguments
                ],
                [arg.kind for arg in arguments],
                [None if arg.pos_only else arg.variable.name for arg in arguments],
                AnyType(TypeOfAny.unannotated),
                _dummy_fallback,
            )
        else:
            typ = None
        expr = LambdaExpr(arguments, body)
        expr.type = typ
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.NAMED_EXPR:
        target = read_expression(state, data)
        value = read_expression(state, data)
        # AssignmentExpr expects target to be a NameExpr
        # NOTE(review): the isinstance guard is redundant with the assert
        # inside it; both exist only to produce a clear failure message.
        if not isinstance(target, NameExpr):
            # In case target is not a NameExpr, we need to handle this
            # For now, we'll assert since the grammar should ensure it's a NameExpr
            assert isinstance(
                target, NameExpr
            ), f"Expected NameExpr for target, got {type(target)}"
        expr = AssignmentExpr(target, value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.STAR_EXPR:
        wrapped_expr = read_expression(state, data)
        expr = StarExpr(wrapped_expr)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BYTES_EXPR:
        # Read bytes literal as string
        value = read_str(data)
        expr = BytesExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.AWAIT_EXPR:
        value = read_expression(state, data)
        expr = AwaitExpr(value)
        read_loc(data, expr)
        expect_end_tag(data)
        return expr
    elif tag == nodes.BIG_INT_EXPR:
        # Integers too large for the compact encoding travel as strings;
        # base=0 honors any 0x/0o/0b prefix in the literal.
        strval = read_str(data)
        ie = IntExpr(int(strval, base=0))
        read_loc(data, ie)
        expect_end_tag(data)
        return ie
    else:
        assert False, tag
def read_fstring_items(state: State, data: ReadBuffer) -> Expression:
    """Read a counted sequence of f-string items and join them into one expression.

    Reads the item count, then that many items, and delegates to
    build_fstring_join (which also consumes the trailing location record).
    """
    # The original initialized `items = []` and immediately overwrote it with
    # the comprehension; the dead assignment has been removed.
    n = read_int(data)
    items = [read_fstring_item(state, data) for _ in range(n)]
    return build_fstring_join(state, data, items)
def build_fstring_join(state: State, data: ReadBuffer, items: list[Expression]) -> Expression:
    """Combine f-string parts into a single expression.

    A single part is returned as-is; all-literal parts are folded into one
    StrExpr; mixed parts become a ``"".join([...])`` call, matching the
    behavior of the old parser.  The location record is consumed here and
    applied to the resulting node.
    """
    if len(items) == 1:
        only = items[0]
        read_loc(data, only)
        return only
    if all(isinstance(part, StrExpr) for part in items):
        merged = StrExpr("".join([cast(StrExpr, part).value for part in items]))
        read_loc(data, merged)
        return merged
    arg_list = ListExpr(items)
    empty_str = StrExpr("")
    join_method = MemberExpr(empty_str, "join")
    join_call = CallExpr(join_method, [arg_list], [ARG_POS], [None])
    read_loc(data, join_call)
    # The synthesized nodes inherit the call's start position.
    for synthetic in (arg_list, empty_str, join_method):
        set_line_column(synthetic, join_call)
    return join_call
def read_fstring_item(state: State, data: ReadBuffer) -> Expression:
    """Read one f-string component: a literal chunk or an interpolation.

    Interpolations are lowered to ``"{...:{}}".format(expr, spec)`` calls,
    matching the behavior of the old parser.
    """
    item_tag = read_tag(data)
    if item_tag == LITERAL_STR:
        literal = StrExpr(read_str_bare(data))
        read_loc(data, literal)
        return literal
    if item_tag != nodes.FSTRING_INTERPOLATION:
        raise ValueError(f"Unexpected tag {item_tag}")
    value = read_expression(state, data)
    # Optional conversion flag such as !r.
    if read_bool(data):
        conversion = read_str(data)
        fmt = "{" + conversion + ":{}}"
    else:
        fmt = "{:{}}"
    # Optional format spec such as <30 (which may have nested {...}).
    if read_bool(data):
        spec = read_fstring_items(state, data)
    else:
        spec = StrExpr("")
    format_attr = MemberExpr(StrExpr(fmt), "format")
    set_line_column(format_attr, value)
    format_call = CallExpr(format_attr, [value, spec], [ARG_POS, ARG_POS], [None, None])
    set_line_column(format_call, value)
    expect_end_tag(data)
    return format_call
def set_line_column(target: Context, src: Context) -> None:
    """Copy the start line/column from src to target (end markers untouched)."""
    target.line, target.column = src.line, src.column
def set_line_column_range(target: Context, src: Context) -> None:
    """Copy the full location span (start and end) from src to target."""
    for attr in ("line", "column", "end_line", "end_column"):
        setattr(target, attr, getattr(src, attr))
def read_expression_list(state: State, data: ReadBuffer) -> list[Expression]:
    """Read a generic-list tag followed by that many serialized expressions."""
    expect_tag(data, LIST_GEN)
    count = read_int_bare(data)
    return [read_expression(state, data) for _ in range(count)]
def read_generator_expr(state: State, data: ReadBuffer) -> GeneratorExpr:
    """Read the comprehension payload shared by generator/list/set comprehensions.

    Layout: result expression, generator count, then per-generator index
    targets, iterables, condition lists, and async flags (each as a full run
    of all generators).
    """
    result_expr = read_expression(state, data)
    count = read_int(data)
    index_targets = [read_expression(state, data) for _ in range(count)]
    iterables = [read_expression(state, data) for _ in range(count)]
    condition_lists = [read_expression_list(state, data) for _ in range(count)]
    async_flags = [read_bool(data) for _ in range(count)]
    return GeneratorExpr(result_expr, index_targets, iterables, condition_lists, async_flags)
def read_loc(data: ReadBuffer, node: Context) -> None:
    """Read a location record from the buffer and apply it to node.

    The end line/column are serialized as deltas from the start position.
    """
    expect_tag(data, LOCATION)
    start_line = read_int_bare(data)
    start_column = read_int_bare(data)
    line_delta = read_int_bare(data)
    column_delta = read_int_bare(data)
    node.line = start_line
    node.column = start_column
    node.end_line = start_line + line_delta
    node.end_column = start_column + column_delta
def strip_contents_from_if_stmt(stmt: IfStmt) -> None:
    """Remove contents from IfStmt.

    Needed to still be able to check the conditions after the contents
    have been merged with the surrounding function overloads.
    """
    if len(stmt.body) == 1:
        stmt.body[0].body = []
    else_block = stmt.else_body
    if else_block and len(else_block.body) == 1:
        first = else_block.body[0]
        if isinstance(first, IfStmt):
            # elif arms are stored as a nested IfStmt inside else_body.
            strip_contents_from_if_stmt(first)
        else:
            else_block.body = []
def is_stripped_if_stmt(stmt: Statement) -> bool:
    """Check stmt to make sure it is a stripped IfStmt.

    See also: strip_contents_from_if_stmt
    """
    if not isinstance(stmt, IfStmt):
        return False
    body_is_stripped = len(stmt.body) == 1 and not stmt.body[0].body
    if not body_is_stripped:
        # Body not empty
        return False
    if not stmt.else_body or not stmt.else_body.body:
        # No or empty else_body
        return True
    # For elif, IfStmt are stored recursively in else_body
    return is_stripped_if_stmt(stmt.else_body.body[0])
def fail_merge_overload(state: State, node: IfStmt) -> None:
    """Report an error when overloads cannot be merged due to unknown condition."""
    message = message_registry.FAILED_TO_MERGE_OVERLOADS.value
    state.add_error(message, node.line, node.column, blocker=False, code="misc")
def check_ifstmt_for_overloads(
    stmt: IfStmt, current_overload_name: str | None = None
) -> str | None:
    """Check if IfStmt contains only overloads with the same name.
    Return overload_name if found, None otherwise.
    """
    # Check that block only contains a single Decorator, FuncDef, or OverloadedFuncDef.
    # Multiple overloads have already been merged as OverloadedFuncDef.
    # The body may also hold several stripped IfStmts followed by a final
    # OverloadedFuncDef (the result of earlier merging).
    if not (
        len(stmt.body[0].body) == 1
        and (
            isinstance(stmt.body[0].body[0], (Decorator, OverloadedFuncDef))
            or current_overload_name is not None
            and isinstance(stmt.body[0].body[0], FuncDef)
        )
        or len(stmt.body[0].body) > 1
        and isinstance(stmt.body[0].body[-1], OverloadedFuncDef)
        and all(is_stripped_if_stmt(if_stmt) for if_stmt in stmt.body[0].body[:-1])
    ):
        return None
    overload_name = cast(Decorator | FuncDef | OverloadedFuncDef, stmt.body[0].body[-1]).name
    if stmt.else_body is None or stmt.else_body.is_unreachable:
        # No (reachable) else arm: the if body alone decides the name.
        return overload_name
    if len(stmt.else_body.body) == 1:
        # For elif: else_body contains an IfStmt itself -> do a recursive check.
        if (
            isinstance(stmt.else_body.body[0], (Decorator, FuncDef, OverloadedFuncDef))
            and stmt.else_body.body[0].name == overload_name
        ):
            return overload_name
        if (
            isinstance(stmt.else_body.body[0], IfStmt)
            and check_ifstmt_for_overloads(stmt.else_body.body[0], current_overload_name)
            == overload_name
        ):
            return overload_name
    return None
def get_executable_if_block_with_overloads(
    stmt: IfStmt, options: Options
) -> tuple[Block | None, IfStmt | None]:
    """Return block from IfStmt that will get executed.

    Returns a pair (block, if_stmt):
      block   -> the Block that will run, if the alternative arms are
                 provably unreachable (otherwise None).
      if_stmt -> the IfStmt itself when its reachability can't be inferred,
                 i.e. the truth value of the condition is unknown
                 (otherwise None).
    """
    # Marks body / else_body as unreachable when the condition is decidable.
    infer_reachability_of_if_statement(stmt, options)
    if stmt.else_body is None and stmt.body[0].is_unreachable is True:
        # always False condition with no else
        return None, None
    if (
        stmt.else_body is None
        or stmt.body[0].is_unreachable is False
        and stmt.else_body.is_unreachable is False
    ):
        # The truth value is unknown, thus not conclusive
        return None, stmt
    if stmt.else_body.is_unreachable:
        # else_body will be set unreachable if condition is always True
        return stmt.body[0], None
    if stmt.body[0].is_unreachable is True:
        # body will be set unreachable if condition is always False
        # else_body can contain an IfStmt itself (for elif) -> do a recursive check
        if isinstance(stmt.else_body.body[0], IfStmt):
            return get_executable_if_block_with_overloads(stmt.else_body.body[0], options)
        return stmt.else_body, None
    return None, stmt
def fix_function_overloads(state: State, stmts: list[Statement]) -> list[Statement]:
    """Merge consecutive function overloads into OverloadedFuncDef nodes.
    This function processes a list of statements and combines function overloads
    (marked with @overload decorator) that have the same name into a single
    OverloadedFuncDef node. It also handles conditional overloads (overloads
    inside if statements) when the condition can be evaluated.
    """
    ret: list[Statement] = []
    # Overload parts collected so far, and the name they share.
    current_overload: list[OverloadPart] = []
    current_overload_name: str | None = None
    last_unconditional_func_def: str | None = None
    # State for an IfStmt seen on the previous iteration that may contribute
    # overloads to the current group.
    last_if_stmt: IfStmt | None = None
    last_if_overload: Decorator | FuncDef | OverloadedFuncDef | None = None
    last_if_stmt_overload_name: str | None = None
    last_if_unknown_truth_value: IfStmt | None = None
    # IfStmts whose contents were merged; re-emitted stripped so their
    # conditions can still be checked.
    skipped_if_stmts: list[IfStmt] = []
    for stmt in stmts:
        if_overload_name: str | None = None
        if_block_with_overload: Block | None = None
        if_unknown_truth_value: IfStmt | None = None
        if isinstance(stmt, IfStmt):
            # Check IfStmt block to determine if function overloads can be merged
            if_overload_name = check_ifstmt_for_overloads(stmt, current_overload_name)
            if if_overload_name is not None:
                if_block_with_overload, if_unknown_truth_value = (
                    get_executable_if_block_with_overloads(stmt, state.options)
                )
        if (
            current_overload_name is not None
            and isinstance(stmt, (Decorator, FuncDef))
            and stmt.name == current_overload_name
        ):
            # Another part of the current overload group.
            if last_if_stmt is not None:
                skipped_if_stmts.append(last_if_stmt)
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if last_if_unknown_truth_value:
                fail_merge_overload(state, last_if_unknown_truth_value)
                last_if_unknown_truth_value = None
            current_overload.append(stmt)
            if isinstance(stmt, FuncDef):
                # This is, strictly speaking, wrong: there might be a decorated
                # implementation. However, it only affects the error message we show:
                # ideally it's "already defined", but "implementation must come last"
                # is also reasonable.
                # TODO: can we get rid of this completely and just always emit
                # "implementation must come last" instead?
                last_unconditional_func_def = stmt.name
        elif (
            current_overload_name is not None
            and isinstance(stmt, IfStmt)
            and if_overload_name == current_overload_name
            and last_unconditional_func_def != current_overload_name
        ):
            # IfStmt only contains stmts relevant to current_overload.
            # Check if stmts are reachable and add them to current_overload,
            # otherwise skip IfStmt to allow subsequent overload
            # or function definitions.
            skipped_if_stmts.append(stmt)
            if if_block_with_overload is None:
                if if_unknown_truth_value is not None:
                    fail_merge_overload(state, if_unknown_truth_value)
                continue
            if last_if_overload is not None:
                # Last stmt was an IfStmt with same overload name
                # Add overloads to current_overload
                if isinstance(last_if_overload, OverloadedFuncDef):
                    current_overload.extend(last_if_overload.items)
                else:
                    current_overload.append(last_if_overload)
                last_if_stmt, last_if_overload = None, None
            if isinstance(if_block_with_overload.body[-1], OverloadedFuncDef):
                skipped_if_stmts.extend(cast(list[IfStmt], if_block_with_overload.body[:-1]))
                current_overload.extend(if_block_with_overload.body[-1].items)
            else:
                current_overload.append(cast(Decorator | FuncDef, if_block_with_overload.body[0]))
        else:
            # Statement does not extend the current group: flush the group.
            if last_if_stmt is not None:
                ret.append(last_if_stmt)
                last_if_stmt_overload_name = current_overload_name
                last_if_stmt, last_if_overload = None, None
                last_if_unknown_truth_value = None
            if current_overload and current_overload_name == last_if_stmt_overload_name:
                # Remove last stmt (IfStmt) from ret if the overload names matched
                # Only happens if no executable block had been found in IfStmt
                popped = ret.pop()
                assert isinstance(popped, IfStmt)
                skipped_if_stmts.append(popped)
            if current_overload and skipped_if_stmts:
                # Add bare IfStmt (without overloads) to ret
                # Required for mypy to be able to still check conditions
                for if_stmt in skipped_if_stmts:
                    strip_contents_from_if_stmt(if_stmt)
                    ret.append(if_stmt)
                skipped_if_stmts = []
            if len(current_overload) == 1:
                ret.append(current_overload[0])
            elif len(current_overload) > 1:
                ret.append(OverloadedFuncDef(current_overload))
            # If we have multiple decorated functions named "_" next to each, we want to treat
            # them as a series of regular FuncDefs instead of one OverloadedFuncDef because
            # most of mypy/mypyc assumes that all the functions in an OverloadedFuncDef are
            # related, but multiple underscore functions next to each other aren't necessarily
            # related
            last_unconditional_func_def = None
            if isinstance(stmt, Decorator) and not unnamed_function(stmt.name):
                # Start a new overload group with this decorator.
                current_overload = [stmt]
                current_overload_name = stmt.name
            elif isinstance(stmt, IfStmt) and if_overload_name is not None:
                # Conditional overload: defer until we see what follows.
                current_overload = []
                current_overload_name = if_overload_name
                last_if_stmt = stmt
                last_if_stmt_overload_name = None
                if if_block_with_overload is not None:
                    skipped_if_stmts.extend(cast(list[IfStmt], if_block_with_overload.body[:-1]))
                    last_if_overload = cast(
                        Decorator | FuncDef | OverloadedFuncDef, if_block_with_overload.body[-1]
                    )
                last_if_unknown_truth_value = if_unknown_truth_value
            else:
                # Ordinary statement; pass it through unchanged.
                current_overload = []
                current_overload_name = None
                ret.append(stmt)
    # Flush any group still open at the end of the statement list.
    if current_overload and skipped_if_stmts:
        # Add bare IfStmt (without overloads) to ret
        # Required for mypy to be able to still check conditions
        for if_stmt in skipped_if_stmts:
            strip_contents_from_if_stmt(if_stmt)
            ret.append(if_stmt)
    if len(current_overload) == 1:
        ret.append(current_overload[0])
    elif len(current_overload) > 1:
        ret.append(OverloadedFuncDef(current_overload))
    elif last_if_overload is not None:
        ret.append(last_if_overload)
    elif last_if_stmt is not None:
        ret.append(last_if_stmt)
    return ret
def deserialize_imports(import_bytes: bytes) -> list[ImportBase]:
    """Deserialize import metadata from bytes into mypy AST nodes.

    Args:
        import_bytes: Serialized import metadata from the Rust parser

    Returns:
        List of Import/ImportFrom/ImportAll AST nodes with location and
        metadata flags populated
    """
    if not import_bytes:
        return []
    buf = ReadBuffer(import_bytes)
    expect_tag(buf, LIST_GEN)
    import_count = read_int_bare(buf)
    result: list[ImportBase] = []
    for _ in range(import_count):
        kind = read_tag(buf)
        node: Import | ImportFrom | ImportAll
        if kind == IMPORT_METADATA:
            name = read_str(buf)
            # The relative level is consumed but unused here: relative
            # imports are handled via ImportFrom, so it should be 0.
            relative = read_int(buf)
            asname = read_str(buf) if read_bool(buf) else None
            node = Import([(name, asname)])
        elif kind == IMPORTFROM_METADATA:
            module = read_str(buf)
            relative = read_int(buf)
            expect_tag(buf, LIST_GEN)
            name_count = read_int_bare(buf)
            names: list[tuple[str, str | None]] = []
            for _ in range(name_count):
                imported = read_str(buf)
                alias = read_str(buf) if read_bool(buf) else None
                names.append((imported, alias))
            node = ImportFrom(module, relative, names)
        elif kind == IMPORTALL_METADATA:
            module = read_str(buf)
            relative = read_int(buf)
            node = ImportAll(module, relative)
        else:
            raise ValueError(f"Unexpected tag in import metadata: {kind}")
        _read_and_set_import_metadata(buf, node)
        result.append(node)
    return result
def _read_and_set_import_metadata(data: ReadBuffer, stmt: Import | ImportFrom | ImportAll) -> None:
    """Read location and metadata flags from buffer and set them on the import statement.

    The metadata is a single integer bitfield:
      bit 0 -> is_top_level
      bit 1 -> is_unreachable
      bit 2 -> is_mypy_only

    Args:
        data: Buffer containing serialized data
        stmt: Import, ImportFrom, or ImportAll statement to populate
    """
    read_loc(data, stmt)
    bitfield = read_int(data)
    stmt.is_top_level = bool(bitfield & 0x01)
    stmt.is_unreachable = bool(bitfield & 0x02)
    stmt.is_mypy_only = bool(bitfield & 0x04)