feat: indie status page MVP -- FastAPI + SQLite

- 8 DB models (services, incidents, monitors, subscribers, etc.)
- Full CRUD API for services, incidents, monitors
- Public status page with live data
- Incident detail page with timeline
- API key authentication
- Uptime monitoring scheduler
- 13 tests passing
- TECHNICAL_DESIGN.md with full spec
This commit is contained in:
IndieStatusBot 2026-04-25 05:00:00 +00:00
commit 902133edd3
4655 changed files with 1342691 additions and 0 deletions

View file

@ -0,0 +1,72 @@
"""Mypyc command-line tool.
Usage:
$ mypyc foo.py [...]
$ python3 -c 'import foo' # Uses compiled 'foo'
This is just a thin wrapper that generates a setup.py file that uses
mypycify, suitable for prototyping and testing.
"""
from __future__ import annotations
import os
import os.path
import subprocess
import sys
base_path = os.path.join(os.path.dirname(__file__), "..")
setup_format = """\
from setuptools import setup
from mypyc.build import mypycify
setup(
name='mypyc_output',
ext_modules=mypycify(
{},
opt_level="{}",
debug_level="{}",
strict_dunder_typing={},
log_trace={},
),
)
"""
def main() -> None:
    """Generate build/setup.py for the given sources and compile them in place.

    Source file paths come from sys.argv[1:]; build options come from the
    MYPYC_* environment variables. Exits with setup.py's return code.
    """
    build_dir = "build"  # can this be overridden??
    try:
        os.mkdir(build_dir)
    except FileExistsError:
        pass
    opt_level = os.getenv("MYPYC_OPT_LEVEL", "3")
    debug_level = os.getenv("MYPYC_DEBUG_LEVEL", "1")
    strict_dunder_typing = bool(int(os.getenv("MYPYC_STRICT_DUNDER_TYPING", "0")))
    # If enabled, compiled code writes a sampled log of executed ops (or
    # events) to mypyc_trace.txt.
    log_trace = bool(int(os.getenv("MYPYC_LOG_TRACE", "0")))
    setup_file = os.path.join(build_dir, "setup.py")
    script = setup_format.format(
        sys.argv[1:], opt_level, debug_level, strict_dunder_typing, log_trace
    )
    with open(setup_file, "w") as f:
        f.write(script)
    # We don't use run_setup (like we do in the test suite) because it throws
    # away the error code from distutils, and we don't care about the slight
    # performance loss here.
    env = os.environ.copy()
    repo_root = os.path.join(os.path.dirname(__file__), "..")
    env["PYTHONPATH"] = repo_root + os.pathsep + env.get("PYTHONPATH", "")
    completed = subprocess.run([sys.executable, setup_file, "build_ext", "--inplace"], env=env)
    sys.exit(completed.returncode)


if __name__ == "__main__":
    main()

View file

@ -0,0 +1,437 @@
"""Always defined attribute analysis.
An always defined attribute has some statements in __init__ or the
class body that cause the attribute to be always initialized when an
instance is constructed. It must also not be possible to read the
attribute before initialization, and it can't be deletable.
We can assume that the value is always defined when reading an always
defined attribute. Otherwise we'll need to raise AttributeError if the
value is undefined (i.e. has the error value).
We use data flow analysis to figure out attributes that are always
defined. Example:
class C:
def __init__(self) -> None:
self.x = 0
if func():
self.y = 1
else:
self.y = 2
self.z = 3
In this example, the attributes 'x' and 'y' are always defined, but 'z'
is not. The analysis assumes that we know that there won't be any subclasses.
The analysis also works if there is a known, closed set of subclasses.
An attribute defined in a base class can only be always defined if it's
also always defined in all subclasses.
As soon as __init__ contains an op that can 'leak' self to another
function, we will stop inferring always defined attributes, since the
analysis is mostly intra-procedural and only looks at __init__ methods.
The called code could read an uninitialized attribute. Example:
class C:
def __init__(self) -> None:
self.x = self.foo()
def foo(self) -> int:
...
Now we won't infer 'x' as always defined, since 'foo' might read 'x'
before initialization.
As an exception to the above limitation, we perform inter-procedural
analysis of super().__init__ calls, since these are very common.
Our analysis is somewhat optimistic. We assume that nobody calls a
method of a partially uninitialized object through gc.get_objects(), in
particular. Code like this could potentially cause a segfault with a null
pointer dereference. This seems very unlikely to be an issue in practice,
however.
Accessing an attribute via getattr always checks for undefined attributes
and thus works if the object is partially uninitialized. This can be used
as a workaround if somebody ever needs to inspect partially uninitialized
objects via gc.get_objects().
The analysis runs after IR building as a separate pass. Since we only
run this on __init__ methods, this analysis pass will be fairly quick.
"""
from __future__ import annotations
from typing import Final
from mypyc.analysis.dataflow import (
CFG,
MAYBE_ANALYSIS,
AnalysisResult,
BaseAnalysisVisitor,
get_cfg,
run_analysis,
)
from mypyc.analysis.selfleaks import analyze_self_leaks
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Branch,
Call,
ControlOp,
GetAttr,
Register,
RegisterOp,
Return,
SetAttr,
SetMem,
Unreachable,
)
from mypyc.ir.rtypes import RInstance
# Debug switch: if True, print out all always-defined attributes of native
# classes as they are computed (to aid debugging and testing).
dump_always_defined: Final = False
def analyze_always_defined_attrs(class_irs: list[ClassIR]) -> None:
    """Find always defined attributes in all classes of a compilation unit.

    Also tag attribute initialization ops to not decref the previous
    value (as this would read a NULL pointer and segfault).

    Update the _always_initialized_attrs, _sometimes_initialized_attrs
    and init_self_leak attributes in ClassIR instances.

    This is the main entry point.
    """
    # Pass 1: only look at each target class and the classes in its MRO.
    visited: set[ClassIR] = set()
    for class_ir in class_irs:
        analyze_always_defined_attrs_in_class(class_ir, visited)

    # Pass 2: narrow the results by looking at all derived classes.
    visited = set()
    for class_ir in class_irs:
        update_always_defined_attrs_using_subclasses(class_ir, visited)

    # Final pass: detect attributes that need a bitmap to track definedness.
    visited = set()
    for class_ir in class_irs:
        detect_undefined_bitmap(class_ir, visited)
def analyze_always_defined_attrs_in_class(cl: ClassIR, seen: set[ClassIR]) -> None:
    """Compute always/sometimes-initialized attribute sets for one class.

    Results are stored on 'cl' (_always_initialized_attrs,
    _sometimes_initialized_attrs, init_self_leak). 'seen' guards against
    analyzing the same class twice across the compilation unit.
    """
    if cl in seen:
        return
    seen.add(cl)
    if (
        cl.is_trait
        or cl.inherits_python
        or cl.allow_interpreted_subclasses
        or cl.builtin_base is not None
        or cl.children is None
        or cl.is_serializable()
        or cl.has_method("__new__")
    ):
        # Give up -- we can't enforce that attributes are always defined.
        return
    # First analyze all base classes. Track seen classes to avoid duplicate work.
    for base in cl.mro[1:]:
        analyze_always_defined_attrs_in_class(base, seen)
    m = cl.get_method("__init__")
    if m is None:
        # No __init__: only attributes with class-body defaults are defined.
        cl._always_initialized_attrs = cl.attrs_with_defaults.copy()
        cl._sometimes_initialized_attrs = cl.attrs_with_defaults.copy()
        return
    self_reg = m.arg_regs[0]
    cfg = get_cfg(m.blocks)
    # 'dirty' marks program points after which 'self' may have leaked, making
    # further inference unreliable.
    dirty = analyze_self_leaks(m.blocks, self_reg, cfg)
    maybe_defined = analyze_maybe_defined_attrs_in_init(
        m.blocks, self_reg, cl.attrs_with_defaults, cfg
    )
    all_attrs: set[str] = set()
    for base in cl.mro:
        all_attrs.update(base.attributes)
    maybe_undefined = analyze_maybe_undefined_attrs_in_init(
        m.blocks, self_reg, initial_undefined=all_attrs - cl.attrs_with_defaults, cfg=cfg
    )
    always_defined = find_always_defined_attributes(
        m.blocks, self_reg, all_attrs, maybe_defined, maybe_undefined, dirty
    )
    # Deletable attributes can become undefined again, so exclude them.
    always_defined = {a for a in always_defined if not cl.is_deletable(a)}
    cl._always_initialized_attrs = always_defined
    if dump_always_defined:
        print(cl.name, sorted(always_defined))
    cl._sometimes_initialized_attrs = find_sometimes_defined_attributes(
        m.blocks, self_reg, maybe_defined, dirty
    )
    mark_attr_initialization_ops(m.blocks, self_reg, maybe_defined, dirty)
    # Check if __init__ can run unpredictable code (leak 'self').
    any_dirty = False
    for b in m.blocks:
        for i, op in enumerate(b.ops):
            if dirty.after[b, i] and not isinstance(op, Return):
                any_dirty = True
                break
    cl.init_self_leak = any_dirty
def find_always_defined_attributes(
    blocks: list[BasicBlock],
    self_reg: Register,
    all_attrs: set[str],
    maybe_defined: AnalysisResult[str],
    maybe_undefined: AnalysisResult[str],
    dirty: AnalysisResult[None],
) -> set[str]:
    """Find attributes that are always initialized in some basic blocks.

    The analysis results are expected to be up-to-date for the blocks.

    Starts from 'all_attrs' and removes attributes that can be observed
    undefined, based on the maybe-defined/maybe-undefined analyses and the
    self-leak ('dirty') analysis.

    Return a set of always defined attributes.
    """
    attrs = all_attrs.copy()
    for block in blocks:
        for i, op in enumerate(block.ops):
            # If an attribute we *read* may be undefined, it isn't always defined.
            if isinstance(op, GetAttr) and op.obj is self_reg:
                if op.attr in maybe_undefined.before[block, i]:
                    attrs.discard(op.attr)
            # If an attribute we *set* may be sometimes undefined and
            # sometimes defined, don't consider it always defined. Unlike
            # the get case, it's fine for the attribute to be undefined.
            # The set operation will then be treated as initialization.
            if isinstance(op, SetAttr) and op.obj is self_reg:
                if (
                    op.attr in maybe_undefined.before[block, i]
                    and op.attr in maybe_defined.before[block, i]
                ):
                    attrs.discard(op.attr)
            # Treat an op that might run arbitrary code as an "exit"
            # in terms of the analysis -- we can't do any inference
            # afterwards reliably.
            if dirty.after[block, i]:
                if not dirty.before[block, i]:
                    # This op is the first dirty point: keep only attributes
                    # definitely defined (and not possibly undefined) here.
                    attrs = attrs & (
                        maybe_defined.after[block, i] - maybe_undefined.after[block, i]
                    )
                break
            if isinstance(op, ControlOp):
                for target in op.targets():
                    # Gotos/branches can also be "exits".
                    if not dirty.after[block, i] and dirty.before[target, 0]:
                        attrs = attrs & (
                            maybe_defined.after[target, 0] - maybe_undefined.after[target, 0]
                        )
    return attrs
def find_sometimes_defined_attributes(
    blocks: list[BasicBlock],
    self_reg: Register,
    maybe_defined: AnalysisResult[str],
    dirty: AnalysisResult[None],
) -> set[str]:
    """Find attributes that are sometimes initialized in some basic blocks."""
    found: set[str] = set()
    for block in blocks:
        for idx, op in enumerate(block.ops):
            # Only look at possibly defined attributes at exits.
            if dirty.after[block, idx]:
                if not dirty.before[block, idx]:
                    # First dirty op of the block acts as an exit.
                    found |= maybe_defined.after[block, idx]
                break
            if isinstance(op, ControlOp):
                for target in op.targets():
                    if not dirty.after[block, idx] and dirty.before[target, 0]:
                        # Jumping into dirty code is also an exit.
                        found |= maybe_defined.after[target, 0]
    return found
def mark_attr_initialization_ops(
    blocks: list[BasicBlock],
    self_reg: Register,
    maybe_defined: AnalysisResult[str],
    dirty: AnalysisResult[None],
) -> None:
    """Tag all SetAttr ops in the basic blocks that initialize attributes.

    Initialization ops assume that the previous attribute value is the error
    value, so there's no need to decref or check for definedness.
    """
    for block in blocks:
        for idx, op in enumerate(block.ops):
            # Only stores through 'self' can be initializations.
            if not isinstance(op, SetAttr) or op.obj is not self_reg:
                continue
            # If the attribute may already be defined, this is a plain store.
            if op.attr in maybe_defined.before[block, idx]:
                continue
            # Don't trust inference once 'self' may have leaked.
            if dirty.after[block, idx]:
                continue
            op.mark_as_initializer()
# (gen, kill) pair for a single op: attribute names made defined / undefined.
# NOTE(review): appears unused in this module's visible code -- the visitors
# below spell the tuple type out explicitly; confirm before removing.
GenAndKill = tuple[set[str], set[str]]
def attributes_initialized_by_init_call(op: Call) -> set[str]:
    """Calculate attributes that are always initialized by a super().__init__ call."""
    # The self argument's type identifies which class's __init__ is called.
    self_type = op.fn.sig.args[0].type
    assert isinstance(self_type, RInstance), self_type
    result: set[str] = set()
    for base in self_type.class_ir.mro:
        for name in base.attributes:
            if base.is_always_defined(name):
                result.add(name)
    return result
def attributes_maybe_initialized_by_init_call(op: Call) -> set[str]:
    """Calculate attributes that may be initialized by a super().__init__ call."""
    self_type = op.fn.sig.args[0].type
    assert isinstance(self_type, RInstance), self_type
    # Always-initialized attributes plus those only set on some paths.
    always = attributes_initialized_by_init_call(op)
    return always | self_type.class_ir._sometimes_initialized_attrs
class AttributeMaybeDefinedVisitor(BaseAnalysisVisitor[str]):
    """Find attributes that may have been defined via some code path.

    Consider initializations in class body and assignments to 'self.x'
    and calls to base class '__init__'.
    """

    def __init__(self, self_reg: Register) -> None:
        self.self_reg = self_reg

    def visit_branch(self, op: Branch) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_return(self, op: Return) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_unreachable(self, op: Unreachable) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_register_op(self, op: RegisterOp) -> tuple[set[str], set[str]]:
        # A store through 'self' makes the attribute maybe-defined.
        if isinstance(op, SetAttr) and op.obj is self.self_reg:
            return ({op.attr}, set())
        # A super().__init__ call may define attributes of the base classes.
        if isinstance(op, Call) and op.fn.class_name and op.fn.name == "__init__":
            return (attributes_maybe_initialized_by_init_call(op), set())
        return (set(), set())

    def visit_assign(self, op: Assign) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_assign_multi(self, op: AssignMulti) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_set_mem(self, op: SetMem) -> tuple[set[str], set[str]]:
        return (set(), set())
def analyze_maybe_defined_attrs_in_init(
    blocks: list[BasicBlock], self_reg: Register, attrs_with_defaults: set[str], cfg: CFG
) -> AnalysisResult[str]:
    """Run a forward 'maybe' analysis of attribute definedness in __init__."""
    visitor = AttributeMaybeDefinedVisitor(self_reg)
    # Positional args: blocks, cfg, gen_and_kill, initial, kind, backward.
    return run_analysis(blocks, cfg, visitor, attrs_with_defaults, MAYBE_ANALYSIS, False)
class AttributeMaybeUndefinedVisitor(BaseAnalysisVisitor[str]):
    """Find attributes that may be undefined via some code path.

    Consider initializations in class body, assignments to 'self.x'
    and calls to base class '__init__'.
    """

    def __init__(self, self_reg: Register) -> None:
        self.self_reg = self_reg

    def visit_branch(self, op: Branch) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_return(self, op: Return) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_unreachable(self, op: Unreachable) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_register_op(self, op: RegisterOp) -> tuple[set[str], set[str]]:
        # A store through 'self' kills the attribute's undefined status.
        if isinstance(op, SetAttr) and op.obj is self.self_reg:
            return (set(), {op.attr})
        # A super().__init__ call definitely defines some base attributes.
        if isinstance(op, Call) and op.fn.class_name and op.fn.name == "__init__":
            return (set(), attributes_initialized_by_init_call(op))
        return (set(), set())

    def visit_assign(self, op: Assign) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_assign_multi(self, op: AssignMulti) -> tuple[set[str], set[str]]:
        return (set(), set())

    def visit_set_mem(self, op: SetMem) -> tuple[set[str], set[str]]:
        return (set(), set())
def analyze_maybe_undefined_attrs_in_init(
    blocks: list[BasicBlock], self_reg: Register, initial_undefined: set[str], cfg: CFG
) -> AnalysisResult[str]:
    """Run a forward 'maybe' analysis of attribute undefinedness in __init__."""
    visitor = AttributeMaybeUndefinedVisitor(self_reg)
    # Positional args: blocks, cfg, gen_and_kill, initial, kind, backward.
    return run_analysis(blocks, cfg, visitor, initial_undefined, MAYBE_ANALYSIS, False)
def update_always_defined_attrs_using_subclasses(cl: ClassIR, seen: set[ClassIR]) -> None:
    """Remove attributes not defined in all subclasses from always defined attrs.

    Recursively finalizes all children first, so that each child's
    _always_initialized_attrs set is final before the parent intersects
    against it. 'seen' prevents reprocessing a class.
    """
    if cl in seen:
        return
    if cl.children is None:
        # Subclasses are unknown
        return
    # Finalize every child first. This call is loop-invariant with respect to
    # the attribute loop below, so hoist it out (the original re-invoked it
    # once per attribute; the 'seen' guard made that cheap but redundant).
    for child in cl.children:
        update_always_defined_attrs_using_subclasses(child, seen)
    removed = set()
    for attr in cl._always_initialized_attrs:
        # An attribute stays always-defined only if every child agrees.
        for child in cl.children:
            if attr not in child._always_initialized_attrs:
                removed.add(attr)
    cl._always_initialized_attrs -= removed
    seen.add(cl)
def detect_undefined_bitmap(cl: ClassIR, seen: set[ClassIR]) -> None:
    """Collect attributes of 'cl' that need a bitmap to track definedness.

    An attribute needs a bitmap bit if its error value overlaps valid values
    (t.error_overlap) and it isn't always defined, since then no special
    value can mark it as undefined. Appends names to cl.bitmap_attrs;
    bases are processed first so inherited entries keep their positions.
    """
    if cl.is_trait:
        return
    if cl in seen:
        return
    seen.add(cl)
    # Process non-trait bases first -- the bitmap layout must be a prefix of
    # the base class's layout.
    for base in cl.base_mro[1:]:
        detect_undefined_bitmap(base, seen)
    if len(cl.base_mro) > 1:
        # Inherit the immediate base class's bitmap attributes.
        cl.bitmap_attrs.extend(cl.base_mro[1].bitmap_attrs)

    # Attributes declared directly on this class.
    for n, t in cl.attributes.items():
        if t.error_overlap and not cl.is_always_defined(n):
            cl.bitmap_attrs.append(n)

    # Attributes inherited from traits (not covered by base_mro above);
    # skip duplicates already added.
    for base in cl.mro[1:]:
        if base.is_trait:
            for n, t in base.attributes.items():
                if t.error_overlap and not cl.is_always_defined(n) and n not in cl.bitmap_attrs:
                    cl.bitmap_attrs.append(n)

View file

@ -0,0 +1,32 @@
"""Find basic blocks that are likely to be executed frequently.
For example, this would not include blocks that have exception handlers.
We can use different optimization heuristics for common and rare code. For
example, we can make IR fast to compile instead of fast to execute for rare
code.
"""
from __future__ import annotations
from mypyc.ir.ops import BasicBlock, Branch, Goto
def frequently_executed_blocks(entry_point: BasicBlock) -> set[BasicBlock]:
    """Return blocks reachable from the entry point along non-rare paths.

    Branches marked rare or carrying a traceback entry only contribute
    their false target; other branches contribute both targets.
    """
    hot: set[BasicBlock] = set()
    stack = [entry_point]
    while stack:
        current = stack.pop()
        if current in hot:
            continue
        hot.add(current)
        terminator = current.terminator
        if isinstance(terminator, Goto):
            stack.append(terminator.label)
        elif isinstance(terminator, Branch):
            if terminator.rare or terminator.traceback_entry is not None:
                # Only the fall-through side is considered hot.
                stack.append(terminator.false)
            else:
                stack.extend((terminator.true, terminator.false))
    return hot

View file

@ -0,0 +1,79 @@
from __future__ import annotations
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.deps import Dependency
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.ops import Assign, CallC, PrimitiveOp
from mypyc.ir.rtypes import RStruct, RTuple, RType, RUnion, RVec
def find_implicit_op_dependencies(fn: FuncIR) -> set[Dependency] | None:
    """Find implicit dependencies that need to be imported.

    Using primitives or types defined in librt submodules such as
    "librt.base64" requires dependency imports (e.g., capsule imports).

    Note that a module can depend on a librt module even if it doesn't
    explicitly import it, for example via re-exported names or via return
    types of functions defined in other modules.
    """
    # Start with dependencies implied by the signature types.
    deps = find_type_dependencies(fn, None)
    for block in fn.blocks:
        for op in block.ops:
            assert not isinstance(op, PrimitiveOp), "Lowered IR is expected"
            if isinstance(op, CallC) and op.dependencies is not None:
                if deps is None:
                    deps = set()
                deps.update(op.dependencies)
                # The C call's result type can also pull in dependencies.
                deps = collect_type_deps(op.type, deps)
            if isinstance(op, Assign):
                deps = collect_type_deps(op.dest.type, deps)
    return deps
def find_type_dependencies(fn: FuncIR, deps: set[Dependency] | None) -> set[Dependency] | None:
    """Find dependencies from RTypes in function signatures.

    Some RTypes (e.g., those for librt types) have associated dependencies
    that need to be imported when the type is used.
    """
    # Parameter types first, then the return type.
    for arg in fn.decl.sig.args:
        deps = collect_type_deps(arg.type, deps)
    return collect_type_deps(fn.decl.sig.ret_type, deps)
def find_class_dependencies(cl: ClassIR) -> set[Dependency] | None:
    """Find dependencies from class attribute types (including inherited ones)."""
    deps: set[Dependency] | None = None
    for base in cl.mro:
        for attribute_type in base.attributes.values():
            deps = collect_type_deps(attribute_type, deps)
    return deps
def collect_type_deps(typ: RType, deps: set[Dependency] | None) -> set[Dependency] | None:
    """Collect dependencies from an RType, recursively checking compound types."""
    if typ.dependencies is not None:
        if deps is None:
            deps = set()
        deps.update(typ.dependencies)
    # Determine the component types to recurse into, if any.
    if isinstance(typ, RUnion):
        components = list(typ.items)
    elif isinstance(typ, (RTuple, RStruct)):
        components = list(typ.types)
    elif isinstance(typ, RVec):
        components = [typ.item_type]
    else:
        components = []
    for component in components:
        deps = collect_type_deps(component, deps)
    return deps

View file

@ -0,0 +1,645 @@
"""Data-flow analyses."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Iterable, Iterator, Set as AbstractSet
from typing import Any, Generic, TypeVar
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElement,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
RegisterOp,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
class CFG:
    """Control-flow graph.

    Node 0 is always assumed to be the entry point. There must be a
    non-empty set of exits.
    """

    def __init__(
        self,
        succ: dict[BasicBlock, list[BasicBlock]],
        pred: dict[BasicBlock, list[BasicBlock]],
        exits: set[BasicBlock],
    ) -> None:
        # A CFG with no exits is malformed.
        assert exits
        self.succ = succ
        self.pred = pred
        self.exits = exits

    def __str__(self) -> str:
        ordered_exits = sorted(self.exits, key=lambda e: int(e.label))
        return "exits: {}\nsucc: {}\npred: {}".format(ordered_exits, self.succ, self.pred)
def get_cfg(blocks: list[BasicBlock], *, use_yields: bool = False) -> CFG:
    """Calculate basic block control-flow graph.

    If use_yields is set, then we treat returns inserted by yields as gotos
    instead of exits.

    Returns a CFG with successor/predecessor maps and the set of exit blocks
    (blocks with no successors).
    """
    succ_map = {}
    pred_map: dict[BasicBlock, list[BasicBlock]] = {}
    exits = set()
    for block in blocks:
        assert not any(
            isinstance(op, ControlOp) for op in block.ops[:-1]
        ), "Control-flow ops must be at the end of blocks"

        if use_yields and isinstance(block.terminator, Return) and block.terminator.yield_target:
            # A yield-inserted return continues at its yield target.
            succ = [block.terminator.yield_target]
        else:
            succ = list(block.terminator.targets())
        if not succ:
            exits.add(block)

        # Errors can occur anywhere inside a block, which means that
        # we can't assume that the entire block has executed before
        # jumping to the error handler. In our CFG construction, we
        # model this as saying that a block can jump to its error
        # handler or the error handlers of any of its normal
        # successors (to represent an error before that next block
        # completes). This works well for analyses like "must
        # defined", where it implies that registers assigned in a
        # block may be undefined in its error handler, but is in
        # general not a precise representation of reality; any
        # analyses that require more fidelity must wait until after
        # exception insertion.
        for error_point in [block] + succ:
            if error_point.error_handler:
                succ.append(error_point.error_handler)

        succ_map[block] = succ
        pred_map[block] = []
    # Invert the successor map to build the predecessor map.
    for prev, nxt in succ_map.items():
        for label in nxt:
            pred_map[label].append(prev)
    return CFG(succ_map, pred_map, exits)
def get_real_target(label: BasicBlock) -> BasicBlock:
    """Follow a block that contains only a single Goto to its destination (one hop)."""
    if len(label.ops) == 1 and isinstance(label.ops[-1], Goto):
        return label.ops[-1].label
    return label
def cleanup_cfg(blocks: list[BasicBlock]) -> None:
    """Cleanup the control flow graph.

    This eliminates obviously dead basic blocks and eliminates blocks that contain
    nothing but a single jump.

    Mutates 'blocks' in place; iterates to a fixed point.

    There is a lot more that could be done.
    """
    changed = True
    while changed:
        # First collapse any jumps to basic block that only contain a goto
        for block in blocks:
            for i, tgt in enumerate(block.terminator.targets()):
                block.terminator.set_target(i, get_real_target(tgt))

        # Then delete any blocks that have no predecessors
        changed = False
        cfg = get_cfg(blocks)
        orig_blocks = blocks.copy()
        blocks.clear()
        for i, block in enumerate(orig_blocks):
            # Block 0 is the entry point and is always kept.
            if i == 0 or cfg.pred[block]:
                blocks.append(block)
            else:
                # Dropping a block may expose more dead blocks; loop again.
                changed = True
# Item type tracked by an analysis (e.g. Value for registers, str for attrs).
T = TypeVar("T")

# Analysis state: a set of items per (basic block, op index) location.
AnalysisDict = dict[tuple[BasicBlock, int], set[T]]
class AnalysisResult(Generic[T]):
    """Dataflow facts before and after each op, keyed by (block, op index)."""

    def __init__(self, before: AnalysisDict[T], after: AnalysisDict[T]) -> None:
        self.before = before
        self.after = after

    def __str__(self) -> str:
        return "before: {}\nafter: {}\n".format(self.before, self.after)
# (gen, kill) result of visiting a single op.
GenAndKill = tuple[AbstractSet[T], AbstractSet[T]]

# Shared empty (gen, kill) pair; frozensets so it can't be mutated by accident.
_EMPTY: tuple[frozenset[Any], frozenset[Any]] = (frozenset(), frozenset())
class BaseAnalysisVisitor(OpVisitor[GenAndKill[T]]):
    """Base visitor producing (gen, kill) sets for each op.

    Subclasses implement the abstract methods (visit_register_op,
    visit_assign, visit_assign_multi, visit_set_mem) plus the branch/
    return/unreachable visitors; every concrete RegisterOp subclass below
    is routed to visit_register_op by default.
    """

    def visit_goto(self, op: Goto) -> GenAndKill[T]:
        # A plain jump neither defines nor uses values.
        return _EMPTY

    # Hooks that concrete analyses must provide.

    @abstractmethod
    def visit_register_op(self, op: RegisterOp) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_assign(self, op: Assign) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[T]:
        raise NotImplementedError

    @abstractmethod
    def visit_set_mem(self, op: SetMem) -> GenAndKill[T]:
        raise NotImplementedError

    # All remaining register ops delegate to visit_register_op.

    def visit_inc_ref(self, op: IncRef) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_dec_ref(self, op: DecRef) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_call(self, op: Call) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_method_call(self, op: MethodCall) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_static(self, op: LoadStatic) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_init_static(self, op: InitStatic) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_tuple_get(self, op: TupleGet) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_tuple_set(self, op: TupleSet) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_box(self, op: Box) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_unbox(self, op: Unbox) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_cast(self, op: Cast) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_call_c(self, op: CallC) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_truncate(self, op: Truncate) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_extend(self, op: Extend) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_global(self, op: LoadGlobal) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_int_op(self, op: IntOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_op(self, op: FloatOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_neg(self, op: FloatNeg) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_mem(self, op: LoadMem) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_get_element(self, op: GetElement) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_set_element(self, op: SetElement) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_load_address(self, op: LoadAddress) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_keep_alive(self, op: KeepAlive) -> GenAndKill[T]:
        return self.visit_register_op(op)

    def visit_unborrow(self, op: Unborrow) -> GenAndKill[T]:
        return self.visit_register_op(op)
class DefinedVisitor(BaseAnalysisVisitor[Value]):
    """Visitor for finding defined registers.

    Note that this only deals with registers and not temporaries, on
    the assumption that we never access temporaries when they might be
    undefined.

    If strict_errors is True, then we regard any use of LoadErrorValue
    as making a register undefined. Otherwise we only do if
    `undefines` is set on the error value.

    This lets us only consider the things we care about during
    uninitialized variable checking while capturing all possibly
    undefined things for refcounting.
    """

    def __init__(self, strict_errors: bool = False) -> None:
        self.strict_errors = strict_errors

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return _EMPTY

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return _EMPTY

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return _EMPTY

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        return _EMPTY

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        source = op.src
        # Loading an error value may undefine the register.
        undefines = isinstance(source, LoadErrorValue) and (
            source.undefines or self.strict_errors
        )
        if undefines:
            return set(), {op.dest}
        return {op.dest}, set()

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        # Array registers are special and we don't track the definedness of them.
        return _EMPTY

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return _EMPTY
def analyze_maybe_defined_regs(
    blocks: list[BasicBlock], cfg: CFG, initial_defined: set[Value]
) -> AnalysisResult[Value]:
    """Calculate potentially defined registers at each CFG location.

    A register is defined if it has a value along some path from the initial location.
    """
    # Positional args: blocks, cfg, gen_and_kill, initial, kind, backward.
    return run_analysis(blocks, cfg, DefinedVisitor(), initial_defined, MAYBE_ANALYSIS, False)
def analyze_must_defined_regs(
    blocks: list[BasicBlock],
    cfg: CFG,
    initial_defined: set[Value],
    regs: Iterable[Value],
    strict_errors: bool = False,
) -> AnalysisResult[Value]:
    """Calculate always defined registers at each CFG location.

    This analysis can work before exception insertion, since it is a
    sound assumption that registers defined in a block might not be
    initialized in its error handler.

    A register is defined if it has a value along all paths from the
    initial location.
    """
    visitor = DefinedVisitor(strict_errors=strict_errors)
    return run_analysis(
        blocks, cfg, visitor, initial_defined, MUST_ANALYSIS, False, universe=set(regs)
    )
class BorrowedArgumentsVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill visitor: an assignment over an argument kills its borrow."""

    def __init__(self, args: set[Value]) -> None:
        self.args = args

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return _EMPTY

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return _EMPTY

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return _EMPTY

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        return _EMPTY

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        # Only assignments targeting an argument register matter here.
        if op.dest not in self.args:
            return _EMPTY
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return _EMPTY

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return _EMPTY
def analyze_borrowed_arguments(
    blocks: list[BasicBlock], cfg: CFG, borrowed: set[Value]
) -> AnalysisResult[Value]:
    """Calculate arguments that can use references borrowed from the caller.

    When assigning to an argument, it no longer is borrowed.
    """
    visitor = BorrowedArgumentsVisitor(borrowed)
    # Must-analysis: an argument stays borrowed only if it is borrowed on
    # all paths; the universe is the borrowed set itself.
    return run_analysis(blocks, cfg, visitor, borrowed, MUST_ANALYSIS, False, universe=borrowed)
class UndefinedVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill visitor: any op that produces a value kills (defines) it."""

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return _EMPTY

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        return _EMPTY

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return _EMPTY

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        # Void ops produce no value and thus define nothing.
        if op.is_void:
            return set(), set()
        return set(), {op}

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        return set(), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return set(), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return _EMPTY
def non_trivial_sources(op: Op) -> set[Value]:
    """Return op's source values, excluding trivial constants (Integer/Float/Undef)."""
    return {source for source in op.sources() if not isinstance(source, (Integer, Float, Undef))}
class LivenessVisitor(BaseAnalysisVisitor[Value]):
    """Gen/kill visitor for liveness: reads gen a value, writes kill it."""

    def visit_branch(self, op: Branch) -> GenAndKill[Value]:
        return non_trivial_sources(op), set()

    def visit_return(self, op: Return) -> GenAndKill[Value]:
        # Trivial constants are never tracked as live.
        if isinstance(op.value, (Integer, Float)):
            return _EMPTY
        return {op.value}, set()

    def visit_unreachable(self, op: Unreachable) -> GenAndKill[Value]:
        return _EMPTY

    def visit_register_op(self, op: RegisterOp) -> GenAndKill[Value]:
        used = non_trivial_sources(op)
        if op.is_void:
            return used, set()
        return used, {op}

    def visit_assign(self, op: Assign) -> GenAndKill[Value]:
        return non_trivial_sources(op), {op.dest}

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill[Value]:
        return non_trivial_sources(op), {op.dest}

    def visit_set_mem(self, op: SetMem) -> GenAndKill[Value]:
        return non_trivial_sources(op), set()

    def visit_inc_ref(self, op: IncRef) -> GenAndKill[Value]:
        # Reference count ops don't affect liveness.
        return _EMPTY

    def visit_dec_ref(self, op: DecRef) -> GenAndKill[Value]:
        return _EMPTY
def analyze_live_regs(blocks: list[BasicBlock], cfg: CFG) -> AnalysisResult[Value]:
    """Calculate live registers at each CFG location.

    A register is live at a location if it can be read along some CFG path
    starting from the location.
    """
    visitor = LivenessVisitor()
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=visitor,
        initial=set(),
        kind=MAYBE_ANALYSIS,
        backward=True,
    )
# Analysis kinds (see run_analysis): a must-analysis intersects the 'after'
# sets of predecessors at a join, a maybe-analysis unions them.
MUST_ANALYSIS = 0
MAYBE_ANALYSIS = 1
def run_analysis(
    blocks: list[BasicBlock],
    cfg: CFG,
    gen_and_kill: OpVisitor[GenAndKill[T]],
    initial: set[T],
    kind: int,
    backward: bool,
    universe: set[T] | None = None,
) -> AnalysisResult[T]:
    """Run a general set-based data flow analysis.

    Args:
        blocks: All basic blocks
        cfg: Control-flow graph for the code
        gen_and_kill: Implementation of gen and kill functions for each op
        initial: Value of analysis for the entry points (for a forward analysis) or the
            exit points (for a backward analysis)
        kind: MUST_ANALYSIS or MAYBE_ANALYSIS
        backward: If False, the analysis is a forward analysis; it's backward otherwise
        universe: For a must analysis, the set of all possible values. This is the starting
            value for the work list algorithm, which will narrow this down until reaching a
            fixed point. For a maybe analysis the iteration always starts from an empty set
            and this argument is ignored.

    Return analysis results: (before, after)
    """
    block_gen = {}
    block_kill = {}
    # Calculate kill and gen sets for entire basic blocks.
    for block in blocks:
        gen: set[T] = set()
        kill: set[T] = set()
        ops = block.ops
        if backward:
            # For a backward analysis the transfer functions compose in
            # reverse op order.
            ops = list(reversed(ops))
        for op in ops:
            opgen, opkill = op.accept(gen_and_kill)
            # Standard gen/kill composition: a later kill cancels an earlier
            # gen, and a later gen cancels an earlier kill.
            if opkill:
                gen -= opkill
            if opgen:
                gen |= opgen
                kill -= opgen
            if opkill:
                kill |= opkill
        block_gen[block] = gen
        block_kill[block] = kill
    # Set up initial state for worklist algorithm.
    worklist = list(blocks)
    if not backward:
        worklist.reverse()  # Reverse for a small performance improvement
    # workset mirrors worklist, giving O(1) membership tests.
    workset = set(worklist)
    before: dict[BasicBlock, set[T]] = {}
    after: dict[BasicBlock, set[T]] = {}
    for block in blocks:
        if kind == MAYBE_ANALYSIS:
            # A maybe-analysis grows from the empty set.
            before[block] = set()
            after[block] = set()
        else:
            # A must-analysis shrinks from the full universe.
            assert universe is not None, "Universe must be defined for a must analysis"
            before[block] = set(universe)
            after[block] = set(universe)
    # For a backward analysis, swap the roles of predecessors and successors;
    # the rest of the algorithm is direction-agnostic.
    if backward:
        pred_map = cfg.succ
        succ_map = cfg.pred
    else:
        pred_map = cfg.pred
        succ_map = cfg.succ
    # Run work list algorithm to generate in and out sets for each basic block.
    while worklist:
        label = worklist.pop()
        workset.remove(label)
        if pred_map[label]:
            # Join predecessors' 'after' sets: union for a maybe-analysis,
            # intersection for a must-analysis.
            new_before: set[T] | None = None
            for pred in pred_map[label]:
                if new_before is None:
                    new_before = set(after[pred])
                elif kind == MAYBE_ANALYSIS:
                    new_before |= after[pred]
                else:
                    new_before &= after[pred]
            assert new_before is not None
        else:
            # No predecessors: an entry point (or exit, if backward), so use
            # the caller-supplied initial value.
            new_before = set(initial)
        before[label] = new_before
        # Apply the whole-block transfer function.
        new_after = (new_before - block_kill[label]) | block_gen[label]
        if new_after != after[label]:
            # The result changed, so successors must be revisited.
            for succ in succ_map[label]:
                if succ not in workset:
                    worklist.append(succ)
                    workset.add(succ)
        after[label] = new_after
    # Run algorithm for each basic block to generate opcode-level sets.
    op_before: dict[tuple[BasicBlock, int], set[T]] = {}
    op_after: dict[tuple[BasicBlock, int], set[T]] = {}
    for block in blocks:
        label = block
        cur = before[label]
        ops_enum: Iterator[tuple[int, Op]] = enumerate(block.ops)
        if backward:
            ops_enum = reversed(list(ops_enum))
        for idx, op in ops_enum:
            op_before[label, idx] = cur
            opgen, opkill = op.accept(gen_and_kill)
            if opkill:
                cur = cur - opkill
            if opgen:
                cur = cur | opgen
            op_after[label, idx] = cur
    if backward:
        # The per-op pass ran in reverse, so the recorded "before" sets are
        # really the after-sets in source order (and vice versa); swap them.
        op_after, op_before = op_before, op_after
    return AnalysisResult(op_before, op_after)

View file

@ -0,0 +1,498 @@
"""Utilities for checking that internal ir is valid and consistent."""
from __future__ import annotations
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR
from mypyc.ir.ops import (
Assign,
AssignMulti,
BaseAssign,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElement,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
from mypyc.ir.pprint import format_func
from mypyc.ir.rtypes import (
KNOWN_NATIVE_TYPES,
RArray,
RInstance,
RPrimitive,
RType,
RUnion,
RVec,
bytes_rprimitive,
dict_rprimitive,
int_rprimitive,
is_c_py_ssize_t_rprimitive,
is_fixed_width_rtype,
is_float_rprimitive,
is_object_rprimitive,
is_pointer_rprimitive,
list_rprimitive,
pointer_rprimitive,
range_rprimitive,
set_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
class FnError:
    """A single validation error found in a function's IR.

    source: the op the error is attached to (or the block itself, when the
        block has no ops).
    desc: human-readable description of the problem.
    """

    def __init__(self, source: Op | BasicBlock, desc: str) -> None:
        self.source = source
        self.desc = desc

    def __eq__(self, other: object) -> bool:
        return (
            isinstance(other, FnError) and self.source == other.source and self.desc == other.desc
        )

    def __hash__(self) -> int:
        # Defining __eq__ implicitly sets __hash__ to None; restore a hash
        # that is consistent with __eq__ so errors can live in sets/dicts.
        return hash((self.source, self.desc))

    def __repr__(self) -> str:
        return f"FnError(source={self.source}, desc={self.desc})"
def check_func_ir(fn: FuncIR) -> list[FnError]:
    """Applies validations to a given function ir and returns a list of errors found.

    Structural checks (block termination, duplicate ops, dangling source
    references) run first; the per-op checks in OpChecker only run when the
    structure is sound, since they assume well-formed blocks.
    """
    errors = []
    op_set = set()
    for block in fn.blocks:
        if not block.terminated:
            errors.append(
                FnError(source=block.ops[-1] if block.ops else block, desc="Block not terminated")
            )
        # Only a block's final op may be a control op.
        for op in block.ops[:-1]:
            if isinstance(op, ControlOp):
                errors.append(FnError(source=op, desc="Block has operations after control op"))
        # Check *all* ops for duplicates, including each block's terminator.
        # (Previously the terminator was skipped, so a terminator shared
        # between blocks went undetected.)
        for op in block.ops:
            if op in op_set:
                errors.append(FnError(source=op, desc="Func has a duplicate op"))
            op_set.add(op)
    errors.extend(check_op_sources_valid(fn))
    if errors:
        return errors
    op_checker = OpChecker(fn)
    for block in fn.blocks:
        for op in block.ops:
            op.accept(op_checker)
    return op_checker.errors
class IrCheckException(Exception):
    """Raised when generated function IR fails validation."""

    pass
def assert_func_ir_valid(fn: FuncIR) -> None:
    """Raise IrCheckException if fn's IR fails validation; otherwise do nothing."""
    errors = check_func_ir(fn)
    if not errors:
        return
    annotated = format_func(fn, [(e.source, e.desc) for e in errors])
    raise IrCheckException("Internal error: Generated invalid IR: \n" + "\n".join(annotated))
def check_op_sources_valid(fn: FuncIR) -> list[FnError]:
    """Check that each op only references ops and registers that exist in fn."""
    errors = []
    valid_ops: set[Op] = set()
    valid_registers: set[Register] = set()

    # First pass: collect every op and every register that is ever written
    # (or whose address is taken), plus the argument registers.
    for block in fn.blocks:
        valid_ops.update(block.ops)
        for op in block.ops:
            if isinstance(op, BaseAssign):
                valid_registers.add(op.dest)
            elif isinstance(op, LoadAddress) and isinstance(op.src, Register):
                valid_registers.add(op.src)
    valid_registers.update(fn.arg_regs)

    # Second pass: every non-trivial source must be one of the collected values.
    for block in fn.blocks:
        for op in block.ops:
            for source in op.sources():
                if isinstance(source, (Integer, Float, Undef)):
                    # Trivial constants are always valid.
                    continue
                if isinstance(source, Op):
                    if source not in valid_ops:
                        errors.append(
                            FnError(
                                source=op,
                                desc=f"Invalid op reference to op of type {type(source).__name__}",
                            )
                        )
                elif isinstance(source, Register):
                    if source not in valid_registers:
                        errors.append(
                            FnError(
                                source=op, desc=f"Invalid op reference to register {source.name!r}"
                            )
                        )
    return errors
# Names of primitive types treated as mutually exclusive by can_coerce_to:
# a coercion between two *different* members of this set is rejected.
disjoint_types = {
    int_rprimitive.name,
    bytes_rprimitive.name,
    str_rprimitive.name,
    dict_rprimitive.name,
    list_rprimitive.name,
    set_rprimitive.name,
    tuple_rprimitive.name,
    range_rprimitive.name,
} | set(KNOWN_NATIVE_TYPES)
def can_coerce_to(src: RType, dest: RType) -> bool:
    """Check if src can be assigned to dest_rtype.

    Currently okay to have false positives.
    """
    if isinstance(dest, RUnion):
        # Assignable to a union if assignable to any of its items.
        return any(can_coerce_to(src, d) for d in dest.items)
    if isinstance(dest, RPrimitive):
        if isinstance(src, RPrimitive):
            # If either src or dest is a disjoint type, then they must both be.
            if src.name in disjoint_types and dest.name in disjoint_types:
                return src.name == dest.name
            # Otherwise only require matching sizes.
            return src.size == dest.size
        if isinstance(src, (RInstance, RVec)):
            # An instance/vector may only coerce to the generic object primitive.
            return is_object_rprimitive(dest)
        if isinstance(src, RUnion):
            # IR doesn't have the ability to narrow unions based on
            # control flow, so cannot be a strict all() here.
            return any(can_coerce_to(s, dest) for s in src.items)
        return False
    # Other destination kinds are not validated here (false positives are
    # acceptable, per the docstring).
    return True
def is_valid_ptr_displacement_type(rtype: RType) -> bool:
    """Check if rtype is a valid displacement type for pointer arithmetic."""
    if is_fixed_width_rtype(rtype) or is_c_py_ssize_t_rprimitive(rtype):
        assert isinstance(rtype, RPrimitive)
        # The displacement must be exactly pointer-sized.
        return rtype.size == pointer_rprimitive.size
    return False
def is_pointer_arithmetic(op: IntOp) -> bool:
    """Check if op is add/subtract targeting pointer_rprimitive and integer of the same size."""
    if op.op in (IntOp.ADD, IntOp.SUB) and is_pointer_rprimitive(op.type):
        lhs_type = op.lhs.type
        rhs_type = op.rhs.type
        # Exactly one side is the pointer; the other is the displacement.
        if is_pointer_rprimitive(lhs_type):
            return is_valid_ptr_displacement_type(rhs_type)
        if is_pointer_rprimitive(rhs_type):
            return is_valid_ptr_displacement_type(lhs_type)
    return False
class OpChecker(OpVisitor[None]):
    """Validate individual ops of a function (used via check_func_ir).

    Errors are accumulated in self.errors rather than raised, so that all
    problems in a function can be reported at once.
    """

    def __init__(self, parent_fn: FuncIR) -> None:
        # Function being checked; used for control-flow target validation
        # and for the declared return type.
        self.parent_fn = parent_fn
        self.errors: list[FnError] = []

    def fail(self, source: Op, desc: str) -> None:
        # Record a validation error attached to the given op.
        self.errors.append(FnError(source=source, desc=desc))

    def check_control_op_targets(self, op: ControlOp) -> None:
        # Every goto/branch target must be a block of this same function.
        for target in op.targets():
            if target not in self.parent_fn.blocks:
                self.fail(source=op, desc=f"Invalid control operation target: {target.label}")

    def check_type_coercion(self, op: Op, src: RType, dest: RType) -> None:
        if not can_coerce_to(src, dest):
            self.fail(
                source=op, desc=f"Cannot coerce source type {src.name} to dest type {dest.name}"
            )

    def check_compatibility(self, op: Op, t: RType, s: RType) -> None:
        # Compatible means coercible in both directions.
        if not can_coerce_to(t, s) or not can_coerce_to(s, t):
            self.fail(source=op, desc=f"{t.name} and {s.name} are not compatible")

    def expect_float(self, op: Op, v: Value) -> None:
        if not is_float_rprimitive(v.type):
            self.fail(op, f"Float expected (actual type is {v.type})")

    def expect_non_float(self, op: Op, v: Value) -> None:
        if is_float_rprimitive(v.type):
            self.fail(op, "Float not expected")

    def expect_primitive_type(self, op: Op, v: Value) -> None:
        if not isinstance(v.type, RPrimitive):
            self.fail(op, f"RPrimitive expected, got {type(v.type).__name__}")

    def visit_goto(self, op: Goto) -> None:
        self.check_control_op_targets(op)

    def visit_branch(self, op: Branch) -> None:
        self.check_control_op_targets(op)

    def visit_return(self, op: Return) -> None:
        # Returned value must be coercible to the declared return type.
        self.check_type_coercion(op, op.value.type, self.parent_fn.decl.sig.ret_type)

    def visit_unreachable(self, op: Unreachable) -> None:
        # Unreachables are checked at a higher level since validation
        # requires access to the entire basic block.
        pass

    def visit_assign(self, op: Assign) -> None:
        self.check_type_coercion(op, op.src.type, op.dest.type)

    def visit_assign_multi(self, op: AssignMulti) -> None:
        # Every source must be coercible to the array's item type.
        for src in op.src:
            assert isinstance(op.dest.type, RArray)
            self.check_type_coercion(op, src.type, op.dest.type.item_type)

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        # Currently it is assumed that all types have an error value.
        # Once this is fixed we can validate that the rtype here actually
        # has an error value.
        pass

    def check_tuple_items_valid_literals(self, op: LoadLiteral, t: tuple[object, ...]) -> None:
        # Tuple literals may only contain literal-representable values
        # (checked recursively for nested tuples).
        for x in t:
            if x is not None and not isinstance(x, (str, bytes, bool, int, float, complex, tuple)):
                self.fail(op, f"Invalid type for item of tuple literal: {type(x)})")
            if isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)

    def check_frozenset_items_valid_literals(self, op: LoadLiteral, s: frozenset[object]) -> None:
        for x in s:
            if x is None or isinstance(x, (str, bytes, bool, int, float, complex)):
                pass
            elif isinstance(x, tuple):
                self.check_tuple_items_valid_literals(op, x)
            else:
                self.fail(op, f"Invalid type for item of frozenset literal: {type(x)})")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        # Map the literal's Python type to the rprimitive name the op is
        # required to have (or builtins.object, which is always allowed).
        expected_type = None
        if op.value is None:
            expected_type = "builtins.object"
        elif isinstance(op.value, int):
            expected_type = "builtins.int"
        elif isinstance(op.value, str):
            expected_type = "builtins.str"
        elif isinstance(op.value, bytes):
            expected_type = "builtins.bytes"
        elif isinstance(op.value, float):
            expected_type = "builtins.float"
        elif isinstance(op.value, complex):
            expected_type = "builtins.object"
        elif isinstance(op.value, tuple):
            expected_type = "builtins.tuple"
            self.check_tuple_items_valid_literals(op, op.value)
        elif isinstance(op.value, frozenset):
            # There's no frozenset_rprimitive type since it'd be pretty useless so we just pretend
            # it's a set (when it's really a frozenset).
            expected_type = "builtins.set"
            self.check_frozenset_items_valid_literals(op, op.value)

        assert expected_type is not None, "Missed a case for LoadLiteral check"

        if op.type.name not in [expected_type, "builtins.object"]:
            self.fail(
                op,
                f"Invalid literal value for type: value has "
                f"type {expected_type}, but op has type {op.type.name}",
            )

    def visit_get_attr(self, op: GetAttr) -> None:
        # Nothing to do.
        pass

    def visit_set_attr(self, op: SetAttr) -> None:
        # Nothing to do.
        pass

    # Static operations cannot be checked at the function level.
    def visit_load_static(self, op: LoadStatic) -> None:
        pass

    def visit_init_static(self, op: InitStatic) -> None:
        pass

    def visit_tuple_get(self, op: TupleGet) -> None:
        # Nothing to do.
        pass

    def visit_tuple_set(self, op: TupleSet) -> None:
        # Nothing to do.
        pass

    def visit_inc_ref(self, op: IncRef) -> None:
        # Nothing to do.
        pass

    def visit_dec_ref(self, op: DecRef) -> None:
        # Nothing to do.
        pass

    def visit_call(self, op: Call) -> None:
        # Length is checked in constructor, and return type is set
        # in a way that can't be incorrect
        for arg_value, arg_runtime in zip(op.args, op.fn.sig.args):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_method_call(self, op: MethodCall) -> None:
        # Similar to above, but we must look up method first.
        method_decl = op.receiver_type.class_ir.method_decl(op.method)
        if method_decl.kind == FUNC_STATICMETHOD:
            # Static methods have no receiver in their declared signature.
            decl_index = 0
        else:
            decl_index = 1

        if len(op.args) + decl_index != len(method_decl.sig.args):
            self.fail(op, "Incorrect number of args for method call.")

        # Skip the receiver argument (self)
        for arg_value, arg_runtime in zip(op.args, method_decl.sig.args[decl_index:]):
            self.check_type_coercion(op, arg_value.type, arg_runtime.type)

    def visit_cast(self, op: Cast) -> None:
        pass

    def visit_box(self, op: Box) -> None:
        pass

    def visit_unbox(self, op: Unbox) -> None:
        pass

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        pass

    def visit_call_c(self, op: CallC) -> None:
        pass

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        pass

    def visit_truncate(self, op: Truncate) -> None:
        pass

    def visit_extend(self, op: Extend) -> None:
        pass

    def visit_load_global(self, op: LoadGlobal) -> None:
        pass

    def visit_int_op(self, op: IntOp) -> None:
        self.expect_primitive_type(op, op.lhs)
        self.expect_primitive_type(op, op.rhs)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)
        left = op.lhs.type
        right = op.rhs.type
        op_str = op.op_str[op.op]
        # Mixed signedness is an error for arithmetic ops always, and for
        # other non-shift ops when operand sizes differ; shifts and valid
        # pointer arithmetic are exempt.
        if (
            isinstance(left, RPrimitive)
            and isinstance(right, RPrimitive)
            and left.is_signed != right.is_signed
            and (
                op_str in ("+", "-", "*", "/", "%")
                or (op_str not in ("<<", ">>") and left.size != right.size)
            )
            and not is_pointer_arithmetic(op)
        ):
            self.fail(op, f"Operand types have incompatible signs: {left}, {right}")

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        self.check_compatibility(op, op.lhs.type, op.rhs.type)
        self.expect_non_float(op, op.lhs)
        self.expect_non_float(op, op.rhs)
        left = op.lhs.type
        right = op.rhs.type
        if (
            isinstance(left, RPrimitive)
            and isinstance(right, RPrimitive)
            and left.is_signed != right.is_signed
        ):
            self.fail(op, f"Operand types have incompatible signs: {left}, {right}")

    def visit_float_op(self, op: FloatOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_float_neg(self, op: FloatNeg) -> None:
        self.expect_float(op, op.src)

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        self.expect_float(op, op.lhs)
        self.expect_float(op, op.rhs)

    def visit_load_mem(self, op: LoadMem) -> None:
        pass

    def visit_set_mem(self, op: SetMem) -> None:
        pass

    def visit_get_element(self, op: GetElement) -> None:
        pass

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        pass

    def visit_set_element(self, op: SetElement) -> None:
        pass

    def visit_load_address(self, op: LoadAddress) -> None:
        pass

    def visit_keep_alive(self, op: KeepAlive) -> None:
        pass

    def visit_unborrow(self, op: Unborrow) -> None:
        pass

View file

@ -0,0 +1,231 @@
from __future__ import annotations
from mypyc.analysis.dataflow import (
CFG,
MAYBE_ANALYSIS,
AnalysisResult,
GenAndKill as _DataflowGenAndKill,
run_analysis,
)
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
DecRef,
Extend,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElement,
GetElementPtr,
Goto,
IncRef,
InitStatic,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
RegisterOp,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Unreachable,
)
from mypyc.ir.rtypes import RInstance
# The analysis tracks a single abstract fact ("self has leaked"), so the
# dataflow set item type is just None: an empty set means clean, {None}
# means dirty.
GenAndKill = _DataflowGenAndKill[None]

CLEAN: GenAndKill = (set(), set())
DIRTY: GenAndKill = ({None}, {None})
class SelfLeakedVisitor(OpVisitor[GenAndKill]):
    """Analyze whether 'self' may be seen by arbitrary code in '__init__'.

    More formally, the set is not empty if along some path from IR entry point
    arbitrary code could have been executed that has access to 'self'.

    (We don't consider access via 'gc.get_objects()'.)
    """

    def __init__(self, self_reg: Register) -> None:
        # Register holding 'self' in the function being analyzed.
        self.self_reg = self_reg

    def visit_goto(self, op: Goto) -> GenAndKill:
        return CLEAN

    def visit_branch(self, op: Branch) -> GenAndKill:
        return CLEAN

    def visit_return(self, op: Return) -> GenAndKill:
        # Consider all exits from the function 'dirty' since they implicitly
        # cause 'self' to be returned.
        return DIRTY

    def visit_unreachable(self, op: Unreachable) -> GenAndKill:
        return CLEAN

    def visit_assign(self, op: Assign) -> GenAndKill:
        # Copying 'self' (or overwriting it) makes further tracking unsound,
        # so treat either as a leak.
        if op.src is self.self_reg or op.dest is self.self_reg:
            return DIRTY
        return CLEAN

    def visit_assign_multi(self, op: AssignMulti) -> GenAndKill:
        return CLEAN

    def visit_set_mem(self, op: SetMem) -> GenAndKill:
        return CLEAN

    def visit_inc_ref(self, op: IncRef) -> GenAndKill:
        return CLEAN

    def visit_dec_ref(self, op: DecRef) -> GenAndKill:
        return CLEAN

    def visit_call(self, op: Call) -> GenAndKill:
        fn = op.fn
        # Calling another native __init__ is only a leak if that class's
        # __init__ is itself known to possibly leak self.
        if fn.class_name and fn.name == "__init__":
            self_type = op.fn.sig.args[0].type
            assert isinstance(self_type, RInstance), self_type
            cl = self_type.class_ir
            if not cl.init_self_leak:
                return CLEAN
        return self.check_register_op(op)

    def visit_method_call(self, op: MethodCall) -> GenAndKill:
        return self.check_register_op(op)

    def visit_load_error_value(self, op: LoadErrorValue) -> GenAndKill:
        return CLEAN

    def visit_load_literal(self, op: LoadLiteral) -> GenAndKill:
        return CLEAN

    def visit_get_attr(self, op: GetAttr) -> GenAndKill:
        cl = op.class_type.class_ir
        if cl.get_method(op.attr):
            # Property -- calls a function
            return self.check_register_op(op)
        return CLEAN

    def visit_set_attr(self, op: SetAttr) -> GenAndKill:
        cl = op.class_type.class_ir
        if cl.get_method(op.attr):
            # Property - calls a function
            return self.check_register_op(op)
        return CLEAN

    def visit_load_static(self, op: LoadStatic) -> GenAndKill:
        return CLEAN

    def visit_init_static(self, op: InitStatic) -> GenAndKill:
        return self.check_register_op(op)

    def visit_tuple_get(self, op: TupleGet) -> GenAndKill:
        return CLEAN

    def visit_tuple_set(self, op: TupleSet) -> GenAndKill:
        return self.check_register_op(op)

    def visit_box(self, op: Box) -> GenAndKill:
        return self.check_register_op(op)

    def visit_unbox(self, op: Unbox) -> GenAndKill:
        return self.check_register_op(op)

    def visit_cast(self, op: Cast) -> GenAndKill:
        return self.check_register_op(op)

    def visit_raise_standard_error(self, op: RaiseStandardError) -> GenAndKill:
        return CLEAN

    def visit_call_c(self, op: CallC) -> GenAndKill:
        return self.check_register_op(op)

    def visit_primitive_op(self, op: PrimitiveOp) -> GenAndKill:
        return self.check_register_op(op)

    def visit_truncate(self, op: Truncate) -> GenAndKill:
        return CLEAN

    def visit_extend(self, op: Extend) -> GenAndKill:
        return CLEAN

    def visit_load_global(self, op: LoadGlobal) -> GenAndKill:
        return CLEAN

    def visit_int_op(self, op: IntOp) -> GenAndKill:
        return CLEAN

    def visit_comparison_op(self, op: ComparisonOp) -> GenAndKill:
        return CLEAN

    def visit_float_op(self, op: FloatOp) -> GenAndKill:
        return CLEAN

    def visit_float_neg(self, op: FloatNeg) -> GenAndKill:
        return CLEAN

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> GenAndKill:
        return CLEAN

    def visit_load_mem(self, op: LoadMem) -> GenAndKill:
        return CLEAN

    def visit_get_element(self, op: GetElement) -> GenAndKill:
        return CLEAN

    def visit_get_element_ptr(self, op: GetElementPtr) -> GenAndKill:
        return CLEAN

    def visit_set_element(self, op: SetElement) -> GenAndKill:
        return CLEAN

    def visit_load_address(self, op: LoadAddress) -> GenAndKill:
        return CLEAN

    def visit_keep_alive(self, op: KeepAlive) -> GenAndKill:
        return CLEAN

    def visit_unborrow(self, op: Unborrow) -> GenAndKill:
        return CLEAN

    def check_register_op(self, op: RegisterOp) -> GenAndKill:
        # Generic fallback: any op that takes 'self' as an operand may
        # expose it to arbitrary code.
        if any(src is self.self_reg for src in op.sources()):
            return DIRTY
        return CLEAN
def analyze_self_leaks(
    blocks: list[BasicBlock], self_reg: Register, cfg: CFG
) -> AnalysisResult[None]:
    """Analyze at which program points 'self' may have leaked.

    Forward maybe-analysis: the result set at a location is non-empty if
    'self' may have been exposed to arbitrary code along some path from the
    entry point (see SelfLeakedVisitor).
    """
    return run_analysis(
        blocks=blocks,
        cfg=cfg,
        gen_and_kill=SelfLeakedVisitor(self_reg),
        initial=set(),
        backward=False,
        kind=MAYBE_ANALYSIS,
    )

View file

@ -0,0 +1,472 @@
"""Generate source code formatted as HTML, with bottlenecks annotated and highlighted.
Various heuristics are used to detect common issues that cause slower than
expected performance.
"""
from __future__ import annotations
import os.path
import sys
from html import escape
from typing import Final
from mypy.build import BuildResult
from mypy.nodes import (
AssignmentStmt,
CallExpr,
ClassDef,
Decorator,
DictionaryComprehension,
Expression,
ForStmt,
FuncDef,
GeneratorExpr,
IndexExpr,
LambdaExpr,
MemberExpr,
MypyFile,
NamedTupleExpr,
NameExpr,
NewTypeExpr,
Node,
OpExpr,
RefExpr,
TupleExpr,
TypedDictExpr,
TypeInfo,
TypeVarExpr,
Var,
WithStmt,
)
from mypy.traverser import TraverserVisitor
from mypy.types import AnyType, Instance, ProperType, Type, TypeOfAny, get_proper_type
from mypy.util import FancyFormatter
from mypyc.ir.func_ir import FuncIR
from mypyc.ir.module_ir import ModuleIR
from mypyc.ir.ops import CallC, LoadLiteral, LoadStatic, Value
from mypyc.irbuild.mapper import Mapper
class Annotation:
    """HTML annotation for compiled source code.

    message: issue description and/or fix hint, as an HTML fragment.
        Multiple messages attached to a single line may be concatenated.
    priority: when several annotations are generated for one line, only
        the highest-priority ones are reported; this keeps some use cases
        that generate many annotations from getting too verbose.
    """

    def __init__(self, message: str, priority: int = 1) -> None:
        self.message = message
        self.priority = priority
# Annotations for generic C-API calls, keyed by C function name.
op_hints: Final = {
    "PyNumber_Add": Annotation('Generic "+" operation.'),
    "PyNumber_Subtract": Annotation('Generic "-" operation.'),
    "PyNumber_Multiply": Annotation('Generic "*" operation.'),
    "PyNumber_TrueDivide": Annotation('Generic "/" operation.'),
    "PyNumber_FloorDivide": Annotation('Generic "//" operation.'),
    "PyNumber_Positive": Annotation('Generic unary "+" operation.'),
    "PyNumber_Negative": Annotation('Generic unary "-" operation.'),
    "PyNumber_And": Annotation('Generic "&" operation.'),
    "PyNumber_Or": Annotation('Generic "|" operation.'),
    "PyNumber_Xor": Annotation('Generic "^" operation.'),
    "PyNumber_Lshift": Annotation('Generic "<<" operation.'),
    "PyNumber_Rshift": Annotation('Generic ">>" operation.'),
    "PyNumber_Invert": Annotation('Generic "~" operation.'),
    "PyObject_Call": Annotation("Generic call operation."),
    "PyObject_CallObject": Annotation("Generic call operation."),
    "PyObject_RichCompare": Annotation("Generic comparison operation."),
    "PyObject_GetItem": Annotation("Generic indexing operation."),
    "PyObject_SetItem": Annotation("Generic indexed assignment."),
}

# Annotations for stdlib features known to be slow in compiled code,
# keyed by fully qualified name.
stdlib_hints: Final = {
    "functools.partial": Annotation(
        '"functools.partial" is inefficient in compiled code.', priority=3
    ),
    "itertools.chain": Annotation(
        '"itertools.chain" is inefficient in compiled code (hint: replace with for loops).',
        priority=3,
    ),
    "itertools.groupby": Annotation(
        '"itertools.groupby" is inefficient in compiled code.', priority=3
    ),
    "itertools.islice": Annotation(
        '"itertools.islice" is inefficient in compiled code (hint: replace with for loop over index range).',
        priority=3,
    ),
    "copy.deepcopy": Annotation(
        '"copy.deepcopy" tends to be slow. Make a shallow copy if possible.', priority=2
    ),
}
CSS = """\
.collapsible {
cursor: pointer;
}
.content {
display: block;
margin-top: 10px;
margin-bottom: 10px;
}
.hint {
display: inline;
border: 1px solid #ccc;
padding: 5px;
}
"""
JS = """\
document.querySelectorAll('.collapsible').forEach(function(collapsible) {
collapsible.addEventListener('click', function() {
const content = this.nextElementSibling;
if (content.style.display === 'none') {
content.style.display = 'block';
} else {
content.style.display = 'none';
}
});
});
"""
class AnnotatedSource:
    """Annotations for a single compiled source file.

    path: path of the annotated source file.
    annotations: maps a source line number to the annotations for that line.
    """

    def __init__(self, path: str, annotations: dict[int, list[Annotation]]) -> None:
        self.path = path
        self.annotations = annotations
def generate_annotated_html(
    html_fnam: str, result: BuildResult, modules: dict[str, ModuleIR], mapper: Mapper
) -> None:
    """Write an annotated-source HTML report for all compiled modules.

    Collects annotations per module, renders them to HTML, writes the file
    to html_fnam, and prints the (styled) output path to stdout.
    """
    annotations = []
    for mod, mod_ir in modules.items():
        path = result.graph[mod].path
        tree = result.graph[mod].tree
        assert tree is not None
        annotations.append(
            generate_annotations(path or "<source>", tree, mod_ir, result.types, mapper)
        )
    html = generate_html_report(annotations)
    with open(html_fnam, "w") as f:
        f.write(html)

    formatter = FancyFormatter(sys.stdout, sys.stderr, False)
    formatted = formatter.style(os.path.abspath(html_fnam), "none", underline=True, bold=True)
    print(f"\nWrote {formatted} -- open in browser to view\n")
def generate_annotations(
    path: str, tree: MypyFile, ir: ModuleIR, type_map: dict[Expression, Type], mapper: Mapper
) -> AnnotatedSource:
    """Collect all annotations for one module, from both IR and AST passes."""
    anns: dict[int, list[Annotation]] = {}
    # IR-based annotations (generic ops, attribute accesses, ...).
    for func_ir in ir.functions:
        anns.update(function_annotations(func_ir, tree))
    # AST-based annotations (nested functions, slow iteration, ...).
    visitor = ASTAnnotateVisitor(type_map, mapper)
    for defn in tree.defs:
        defn.accept(visitor)
    anns.update(visitor.anns)
    # Drop annotations on lines the AST pass marked as ignorable.
    for line in visitor.ignored_lines:
        anns.pop(line, None)
    return AnnotatedSource(path, anns)
def function_annotations(func_ir: FuncIR, tree: MypyFile) -> dict[int, list[Annotation]]:
    """Generate annotations based on mypyc IR."""
    # TODO: check if func_ir.line is -1
    anns: dict[int, list[Annotation]] = {}
    for block in func_ir.blocks:
        for op in block.ops:
            if isinstance(op, CallC):
                name = op.function_name
                ann: str | Annotation | None = None
                if name == "CPyObject_GetAttr":
                    attr_name = get_str_literal(op.args[1])
                    if attr_name in ("__prepare__", "GeneratorExit", "StopIteration"):
                        # These attributes are internal to mypyc/CPython, and/or accessed
                        # implicitly in generated code. The user has little control over
                        # them.
                        ann = None
                    elif attr_name:
                        ann = f'Get non-native attribute "{attr_name}".'
                    else:
                        ann = "Dynamic attribute lookup."
                elif name == "PyObject_SetAttr":
                    attr_name = get_str_literal(op.args[1])
                    if attr_name == "__mypyc_attrs__":
                        # This is set implicitly and can't be avoided.
                        ann = None
                    elif attr_name:
                        ann = f'Set non-native attribute "{attr_name}".'
                    else:
                        ann = "Dynamic attribute set."
                elif name == "PyObject_VectorcallMethod":
                    method_name = get_str_literal(op.args[0])
                    if method_name:
                        ann = f'Call non-native method "{method_name}" (it may be defined in a non-native class, or decorated).'
                    else:
                        ann = "Dynamic method call."
                elif name in op_hints:
                    ann = op_hints[name]
                elif name in ("CPyDict_GetItem", "CPyDict_SetItem"):
                    # A globals-dict lookup compiles to a dict get/set on the
                    # module's globals static (skip the module top level,
                    # where this is expected).
                    if (
                        isinstance(op.args[0], LoadStatic)
                        and isinstance(op.args[1], LoadLiteral)
                        and func_ir.name != "__top_level__"
                    ):
                        load = op.args[0]
                        # Rebind 'name' to the global being accessed.
                        name = str(op.args[1].value)
                        sym = tree.names.get(name)
                        if (
                            sym
                            and sym.node
                            and load.namespace == "static"
                            and load.identifier == "globals"
                        ):
                            if sym.node.fullname in stdlib_hints:
                                ann = stdlib_hints[sym.node.fullname]
                            elif isinstance(sym.node, Var):
                                ann = (
                                    f'Access global "{name}" through namespace '
                                    + "dictionary (hint: access is faster if you can make it Final)."
                                )
                            else:
                                ann = f'Access "{name}" through global namespace dictionary.'
                if ann:
                    if isinstance(ann, str):
                        ann = Annotation(ann)
                    anns.setdefault(op.line, []).append(ann)
    return anns
class ASTAnnotateVisitor(TraverserVisitor):
"""Generate annotations from mypy AST and inferred types."""
def __init__(self, type_map: dict[Expression, Type], mapper: Mapper) -> None:
self.anns: dict[int, list[Annotation]] = {}
self.ignored_lines: set[int] = set()
self.func_depth = 0
self.type_map = type_map
self.mapper = mapper
def visit_func_def(self, o: FuncDef, /) -> None:
if self.func_depth > 0:
self.annotate(
o,
"A nested function object is allocated each time statement is executed. "
+ "A module-level function would be faster.",
)
self.func_depth += 1
super().visit_func_def(o)
self.func_depth -= 1
def visit_for_stmt(self, o: ForStmt, /) -> None:
self.check_iteration([o.expr], "For loop")
super().visit_for_stmt(o)
def visit_dictionary_comprehension(self, o: DictionaryComprehension, /) -> None:
self.check_iteration(o.sequences, "Comprehension")
super().visit_dictionary_comprehension(o)
def visit_generator_expr(self, o: GeneratorExpr, /) -> None:
self.check_iteration(o.sequences, "Comprehension or generator")
super().visit_generator_expr(o)
def check_iteration(self, expressions: list[Expression], kind: str) -> None:
for expr in expressions:
typ = self.get_type(expr)
if isinstance(typ, AnyType):
self.annotate(expr, f'{kind} uses generic operations (iterable has type "Any").')
elif isinstance(typ, Instance) and typ.type.fullname in (
"typing.Iterable",
"typing.Iterator",
"typing.Sequence",
"typing.MutableSequence",
):
self.annotate(
expr,
f'{kind} uses generic operations (iterable has the abstract type "{typ.type.fullname}").',
)
def visit_class_def(self, o: ClassDef, /) -> None:
super().visit_class_def(o)
if self.func_depth == 0:
# Don't complain about base classes at top level
for base in o.base_type_exprs:
self.ignored_lines.add(base.line)
for s in o.defs.body:
if isinstance(s, AssignmentStmt):
# Don't complain about attribute initializers
self.ignored_lines.add(s.line)
elif isinstance(s, Decorator):
# Don't complain about decorator definitions that generate some
# dynamic operations. This is a bit heavy-handed.
self.ignored_lines.add(s.func.line)
def visit_with_stmt(self, o: WithStmt, /) -> None:
for expr in o.expr:
if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
node = expr.callee.node
if isinstance(node, Decorator):
if any(
isinstance(d, RefExpr)
and d.node
and d.node.fullname == "contextlib.contextmanager"
for d in node.decorators
):
self.annotate(
expr,
f'"{node.name}" uses @contextmanager, which is slow '
+ "in compiled code. Use a native class with "
+ '"__enter__" and "__exit__" methods instead.',
priority=3,
)
super().visit_with_stmt(o)
def visit_assignment_stmt(self, o: AssignmentStmt, /) -> None:
special_form = False
if self.func_depth == 0:
analyzed: Expression | None = o.rvalue
if isinstance(o.rvalue, (CallExpr, IndexExpr, OpExpr)):
analyzed = o.rvalue.analyzed
if o.is_alias_def or isinstance(
analyzed, (TypeVarExpr, NamedTupleExpr, TypedDictExpr, NewTypeExpr)
):
special_form = True
if special_form:
# TODO: Ignore all lines if multi-line
self.ignored_lines.add(o.line)
super().visit_assignment_stmt(o)
def visit_name_expr(self, o: NameExpr, /) -> None:
if ann := stdlib_hints.get(o.fullname):
self.annotate(o, ann)
def visit_member_expr(self, o: MemberExpr, /) -> None:
super().visit_member_expr(o)
if ann := stdlib_hints.get(o.fullname):
self.annotate(o, ann)
def visit_call_expr(self, o: CallExpr, /) -> None:
    """Annotate calls that are slow in compiled code.

    Covers isinstance() checks against protocols, instantiation of
    non-native or partially native classes, and calls through decorated
    functions.
    """
    super().visit_call_expr(o)
    if (
        isinstance(o.callee, RefExpr)
        and o.callee.fullname == "builtins.isinstance"
        and len(o.args) == 2
    ):
        # Second argument is the type (or tuple of types) being checked.
        arg = o.args[1]
        self.check_isinstance_arg(arg)
    elif isinstance(o.callee, RefExpr) and isinstance(o.callee.node, TypeInfo):
        # Direct constructor call of a class.
        info = o.callee.node
        class_ir = self.mapper.type_to_ir.get(info)
        # No IR for a non-builtins class means it wasn't compiled by mypyc.
        if (class_ir and not class_ir.is_ext_class) or (
            class_ir is None and not info.fullname.startswith("builtins.")
        ):
            self.annotate(
                o, f'Creating an instance of non-native class "{info.name}" ' + "is slow.", 2
            )
        elif class_ir and class_ir.is_augmented:
            self.annotate(
                o,
                f'Class "{info.name}" is only partially native, and '
                + "constructing an instance is slow.",
                2,
            )
    elif isinstance(o.callee, RefExpr) and isinstance(o.callee.node, Decorator):
        # Calls through a decorator go through the wrapper object.
        decorator = o.callee.node
        if self.mapper.is_native_ref_expr(o.callee):
            self.annotate(
                o,
                f'Calling a decorated function ("{decorator.name}") is inefficient, even if it\'s native.',
                2,
            )
def check_isinstance_arg(self, arg: Expression) -> None:
    """Warn about isinstance() checks against protocols, recursing into tuples."""
    if isinstance(arg, TupleExpr):
        for member in arg.items:
            self.check_isinstance_arg(member)
    elif isinstance(arg, RefExpr):
        if isinstance(arg.node, TypeInfo) and arg.node.is_protocol:
            self.annotate(
                arg, f'Expensive isinstance() check against protocol "{arg.node.name}".'
            )
def visit_lambda_expr(self, o: LambdaExpr, /) -> None:
    """Flag lambdas, which allocate a fresh object on every evaluation."""
    message = (
        "A new object is allocated for lambda each time it is evaluated. "
        + "A module-level function would be faster."
    )
    self.annotate(o, message)
    super().visit_lambda_expr(o)
def annotate(self, o: Node, ann: str | Annotation, priority: int = 1) -> None:
    """Record an annotation for the line of node *o* (strings are wrapped)."""
    entry = Annotation(ann, priority=priority) if isinstance(ann, str) else ann
    self.anns.setdefault(o.line, []).append(entry)
def get_type(self, e: Expression) -> ProperType:
    """Return the inferred proper type of *e*, or Any if no type was recorded."""
    inferred = self.type_map.get(e)
    if not inferred:
        return AnyType(TypeOfAny.unannotated)
    return get_proper_type(inferred)
def get_str_literal(v: Value) -> str | None:
    """Return the string value if *v* loads a str literal, else None."""
    is_str_load = isinstance(v, LoadLiteral) and isinstance(v.value, str)
    return v.value if is_str_load else None
def get_max_prio(anns: list[Annotation]) -> list[Annotation]:
    """Return only the annotations sharing the highest priority in *anns*."""
    top = max(ann.priority for ann in anns)
    return [ann for ann in anns if ann.priority == top]
def generate_html_report(sources: list[AnnotatedSource]) -> str:
    """Render annotated sources as a single self-contained HTML page."""
    html = []
    html.append("<html>\n<head>\n")
    html.append(f"<style>\n{CSS}\n</style>")
    html.append("</head>\n")
    html.append("<body>\n")
    for src in sources:
        html.append(f"<h2><tt>{src.path}</tt></h2>\n")
        html.append("<pre>")
        src_anns = src.annotations
        # Re-read the original source so we can interleave annotations.
        with open(src.path) as f:
            lines = f.readlines()
        for i, s in enumerate(lines):
            s = escape(s)
            line = i + 1  # annotations are keyed by 1-based line number
            linenum = "%5d" % line
            if line in src_anns:
                # Only surface the highest-priority annotations for this line.
                anns = get_max_prio(src_anns[line])
                ann_strs = [a.message for a in anns]
                hint = " ".join(ann_strs)
                s = colorize_line(linenum, s, hint_html=hint)
            else:
                s = linenum + " " + s
            html.append(s)
        html.append("</pre>")
    html.append("<script>")
    html.append(JS)
    html.append("</script>")
    html.append("</body></html>\n")
    return "".join(html)
def colorize_line(linenum: str, s: str, hint_html: str) -> str:
    """Render a highlighted source line followed by a collapsible hint box."""
    indent = " " * len(linenum) + " "
    highlighted = f'<div class="collapsible" style="background-color: #fcc">{linenum} {s}</div>'
    hint_box = f'<div class="content">{indent}<div class="hint">{hint_html}</div></div>'
    return f"<span>{highlighted}{hint_box}</span>"

View file

@ -0,0 +1,842 @@
"""Support for building extensions using mypyc with distutils or setuptools
The main entry point is mypycify, which produces a list of extension
modules to be passed to setup. A trivial setup.py for a mypyc built
project, then, looks like:
from setuptools import setup
from mypyc.build import mypycify
setup(name='test_module',
ext_modules=mypycify(['foo.py']),
)
See the mypycify docs for additional arguments.
mypycify can integrate with either distutils or setuptools, but needs
to know at import-time whether it is using distutils or setuputils. We
hackily decide based on whether setuptools has been imported already.
"""
from __future__ import annotations
import hashlib
import os.path
import re
import sys
import time
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, NamedTuple, NoReturn, cast
import mypyc.build_setup # noqa: F401
from mypy.build import BuildSource
from mypy.errors import CompileError
from mypy.fscache import FileSystemCache
from mypy.main import process_options
from mypy.options import Options
from mypy.util import write_junit_xml
from mypyc.annotate import generate_annotated_html
from mypyc.codegen import emitmodule
from mypyc.common import IS_FREE_THREADED, RUNTIME_C_FILES, shared_lib_name
from mypyc.errors import Errors
from mypyc.ir.deps import SourceDep
from mypyc.ir.pprint import format_modules
from mypyc.namegen import exported_name
from mypyc.options import CompilerOptions
class ModDesc(NamedTuple):
    """Build description of one librt extension module."""

    # Dotted module name, e.g. "librt.internal".
    module: str
    # C source files compiled into the module (paths relative to lib-rt).
    c_files: list[str]
    # Extra files shipped alongside but not compiled directly (headers,
    # templates, sources #included by other sources).
    other_files: list[str]
    # Additional include directories (relative to lib-rt).
    include_dirs: list[str]


# All librt extension modules that can be built with install_librt=True.
LIBRT_MODULES = [
    ModDesc("librt.internal", ["internal/librt_internal.c"], [], ["internal"]),
    ModDesc("librt.strings", ["strings/librt_strings.c"], [], ["strings"]),
    ModDesc(
        "librt.base64",
        [
            "base64/librt_base64.c",
            "base64/lib.c",
            "base64/codec_choose.c",
            "base64/tables/tables.c",
            "base64/arch/generic/codec.c",
            "base64/arch/ssse3/codec.c",
            "base64/arch/sse41/codec.c",
            "base64/arch/sse42/codec.c",
            "base64/arch/avx/codec.c",
            "base64/arch/avx2/codec.c",
            "base64/arch/avx512/codec.c",
            "base64/arch/neon32/codec.c",
            "base64/arch/neon64/codec.c",
        ],
        [
            "base64/arch/avx/enc_loop_asm.c",
            "base64/arch/avx2/enc_loop.c",
            "base64/arch/avx2/enc_loop_asm.c",
            "base64/arch/avx2/enc_reshuffle.c",
            "base64/arch/avx2/enc_translate.c",
            "base64/arch/avx2/dec_loop.c",
            "base64/arch/avx2/dec_reshuffle.c",
            "base64/arch/generic/32/enc_loop.c",
            "base64/arch/generic/64/enc_loop.c",
            "base64/arch/generic/32/dec_loop.c",
            "base64/arch/generic/enc_head.c",
            "base64/arch/generic/enc_tail.c",
            "base64/arch/generic/dec_head.c",
            "base64/arch/generic/dec_tail.c",
            "base64/arch/ssse3/dec_reshuffle.c",
            "base64/arch/ssse3/dec_loop.c",
            "base64/arch/ssse3/enc_loop_asm.c",
            "base64/arch/ssse3/enc_translate.c",
            "base64/arch/ssse3/enc_reshuffle.c",
            "base64/arch/ssse3/enc_loop.c",
            "base64/arch/neon64/dec_loop.c",
            "base64/arch/neon64/enc_loop_asm.c",
            "base64/codecs.h",
            "base64/env.h",
            "base64/lib_openmp.c",
            "base64/tables/tables.h",
            "base64/tables/table_dec_32bit.h",
            "base64/tables/table_enc_12bit.h",
        ],
        ["base64"],
    ),
    ModDesc(
        "librt.vecs",
        [
            "vecs/librt_vecs.c",
            "vecs/vec_i64.c",
            "vecs/vec_i32.c",
            "vecs/vec_i16.c",
            "vecs/vec_u8.c",
            "vecs/vec_float.c",
            "vecs/vec_bool.c",
            "vecs/vec_t.c",
            "vecs/vec_nested.c",
        ],
        ["vecs/librt_vecs.h", "vecs/vec_template.c"],
        ["vecs"],
    ),
    ModDesc("librt.time", ["time/librt_time.c"], ["time/librt_time.h"], []),
]
try:
    # Import setuptools so that it monkey-patch overrides distutils
    import setuptools
except ImportError:
    pass

if TYPE_CHECKING:
    # Give "Extension" a static type that covers both the setuptools and
    # distutils variants (only relevant before 3.12, where both may exist).
    if sys.version_info >= (3, 12):
        from setuptools import Extension
    else:
        from distutils.core import Extension as _distutils_Extension
        from typing import TypeAlias

        from setuptools import Extension as _setuptools_Extension

        Extension: TypeAlias = _setuptools_Extension | _distutils_Extension

if sys.version_info >= (3, 12):
    # From setuptools' monkeypatch
    from distutils import ccompiler, sysconfig  # type: ignore[import-not-found]
else:
    from distutils import ccompiler, sysconfig
def get_extension() -> type[Extension]:
    """Return the Extension class to use.

    We can work with either setuptools or distutils, and pick setuptools
    if it has been imported.
    """
    prefer_setuptools = "setuptools" in sys.modules
    if not prefer_setuptools:
        if sys.version_info < (3, 12):
            import distutils.core

            return distutils.core.Extension
        sys.exit("error: setuptools not installed")
    return setuptools.Extension
def setup_mypycify_vars() -> None:
    """Rewrite a bunch of config vars in pretty dubious ways."""
    # There has to be a better approach to this.
    # The vars can contain ints but we only work with str ones
    config = cast(dict[str, str], sysconfig.get_config_vars())
    if sys.platform != "darwin":
        return
    # Disable building 32-bit binaries, since we generate too much code
    # for a 32-bit Mach-O object. There has to be a better way to do this.
    for key in ("LDSHARED", "LDFLAGS", "CFLAGS"):
        config[key] = config[key].replace("-arch i386", "")
def fail(message: str) -> NoReturn:
    """Abort the build by raising SystemExit with *message*."""
    # TODO: Is there something else we should do to fail?
    raise SystemExit(message)
def emit_messages(options: Options, messages: list[str], dt: float, serious: bool = False) -> None:
    """Report build messages to stdout, optionally also writing junit XML."""
    # ... you know, just in case.
    if options.junit_xml:
        major, minor = options.python_version[0], options.python_version[1]
        write_junit_xml(
            dt,
            serious,
            {None: messages} if messages else {},
            options.junit_xml,
            f"{major}_{minor}",
            options.platform,
        )
    if messages:
        print("\n".join(messages))
def get_mypy_config(
    mypy_options: list[str],
    only_compile_paths: Iterable[str] | None,
    compiler_options: CompilerOptions,
    fscache: FileSystemCache | None,
) -> tuple[list[BuildSource], list[BuildSource], Options]:
    """Construct mypy BuildSources and Options from file and options lists"""
    all_sources, options = process_options(mypy_options, fscache=fscache, mypyc=True)
    if only_compile_paths is not None:
        # Restrict compilation to the requested paths; the rest are
        # still type checked but not compiled.
        paths_set = set(only_compile_paths)
        mypyc_sources = [s for s in all_sources if s.path in paths_set]
    else:
        mypyc_sources = all_sources

    if compiler_options.separate:
        # Separate compilation needs a real file path for every module.
        mypyc_sources = [src for src in mypyc_sources if src.path]

    if not mypyc_sources:
        return mypyc_sources, all_sources, options

    # Override whatever python_version is inferred from the .ini file,
    # and set the python_version to be the currently used version.
    options.python_version = sys.version_info[:2]

    if options.python_version[0] == 2:
        fail("Python 2 not supported")
    if not options.strict_optional:
        fail("Disabling strict optional checking not supported")
    options.show_traceback = True
    # Needed to get types for all AST nodes
    options.export_types = True
    # We use mypy incremental mode when doing separate/incremental mypyc compilation
    options.incremental = compiler_options.separate
    options.preserve_asts = True

    # Mark every compiled module so that mypy applies mypyc-specific checks.
    for source in mypyc_sources:
        options.per_module_options.setdefault(source.module, {})["mypyc"] = True

    return mypyc_sources, all_sources, options
def is_package_source(source: BuildSource) -> bool:
    """Return True if *source* is a package __init__.py file."""
    if source.path is None:
        return False
    return os.path.split(source.path)[1] == "__init__.py"
def generate_c_extension_shim(
    full_module_name: str, module_name: str, dir_name: str, group_name: str
) -> str:
    """Create a C extension shim with a passthrough PyInit function.

    Arguments:
        full_module_name: the dotted full module name
        module_name: the final component of the module name
        dir_name: the directory to place source code
        group_name: the name of the group

    Returns the path of the written C file.
    """
    cname = "%s.c" % full_module_name.replace(".", os.sep)
    cpath = os.path.join(dir_name, cname)

    if IS_FREE_THREADED:
        # We use multi-phase init in free-threaded builds to enable free threading.
        shim_name = "module_shim_no_gil_multiphase.tmpl"
    else:
        shim_name = "module_shim.tmpl"
    # We load the C extension shim template from a file.
    # (So that the file could be reused as a bazel template also.)
    with open(os.path.join(include_dir(), shim_name)) as f:
        shim_template = f.read()

    write_file(
        cpath,
        shim_template.format(
            modname=module_name,
            libname=shared_lib_name(group_name),
            full_modname=exported_name(full_module_name),
        ),
    )
    return cpath
def group_name(modules: list[str]) -> str:
    """Produce a probably unique name for a group from a list of module names."""
    if len(modules) == 1:
        return modules[0]
    # Hash the comma-joined module names; 20 hex digits is plenty unique.
    joined = ",".join(modules).encode()
    return hashlib.sha1(joined).hexdigest()[:20]
def include_dir() -> str:
    """Find the path of the lib-rt dir that needs to be included"""
    here = os.path.abspath(os.path.dirname(__file__))
    return os.path.join(here, "lib-rt")
def generate_c(
    sources: list[BuildSource],
    options: Options,
    groups: emitmodule.Groups,
    fscache: FileSystemCache,
    compiler_options: CompilerOptions,
) -> tuple[list[list[tuple[str, str]]], str, list[SourceDep]]:
    """Drive the actual core compilation step.

    The groups argument describes how modules are assigned to C
    extension modules. See the comments on the Groups type in
    mypyc.emitmodule for details.

    Returns the C source code, (for debugging) the pretty printed IR, and list of SourceDeps.
    """
    t0 = time.time()

    try:
        result = emitmodule.parse_and_typecheck(
            sources, options, compiler_options, groups, fscache
        )
    except CompileError as e:
        # Blocking mypy errors: report and abort the build.
        emit_messages(options, e.messages, time.time() - t0, serious=(not e.use_stdout))
        sys.exit(1)

    t1 = time.time()
    if result.errors:
        emit_messages(options, result.errors, t1 - t0)
        sys.exit(1)

    if compiler_options.verbose:
        print(f"Parsed and typechecked in {t1 - t0:.3f}s")

    errors = Errors(options)
    modules, ctext, mapper = emitmodule.compile_modules_to_c(
        result, compiler_options=compiler_options, errors=errors, groups=groups
    )
    t2 = time.time()
    emit_messages(options, errors.new_messages(), t2 - t1)
    if errors.num_errors:
        # No need to stop the build if only warnings were emitted.
        sys.exit(1)

    if compiler_options.verbose:
        print(f"Compiled to C in {t2 - t1:.3f}s")

    if options.mypyc_annotation_file:
        generate_annotated_html(options.mypyc_annotation_file, result, modules, mapper)

    # Collect SourceDep dependencies
    source_deps = sorted(emitmodule.collect_source_dependencies(modules), key=lambda d: d.path)

    return ctext, "\n".join(format_modules(modules)), source_deps
def build_using_shared_lib(
    sources: list[BuildSource],
    group_name: str,
    cfiles: list[str],
    deps: list[str],
    build_dir: str,
    extra_compile_args: list[str],
) -> list[Extension]:
    """Produce the list of extension modules when a shared library is needed.

    This creates one shared library extension module that all the
    others import, and one shim extension module for each
    module in the build. Each shim simply calls an initialization function
    in the shared library.

    The shared library (which lib_name is the name of) is a Python
    extension module that exports the real initialization functions in
    Capsules stored in module attributes.
    """
    extensions = [
        get_extension()(
            shared_lib_name(group_name),
            sources=cfiles,
            include_dirs=[include_dir(), build_dir],
            depends=deps,
            extra_compile_args=extra_compile_args,
        )
    ]
    for source in sources:
        module_name = source.module.split(".")[-1]
        shim_file = generate_c_extension_shim(source.module, module_name, build_dir, group_name)

        # We include the __init__ in the "module name" we stick in the Extension,
        # since this seems to be needed for it to end up in the right place.
        full_module_name = source.module
        assert source.path
        if is_package_source(source):
            full_module_name += ".__init__"
        extensions.append(
            get_extension()(
                full_module_name, sources=[shim_file], extra_compile_args=extra_compile_args
            )
        )
    return extensions
def build_single_module(
    sources: list[BuildSource], cfiles: list[str], extra_compile_args: list[str]
) -> list[Extension]:
    """Produce the list of extension modules for a standalone extension.

    This contains just one module, since there is no need for a shared module.
    """
    ext = get_extension()(
        sources[0].module,
        sources=cfiles,
        include_dirs=[include_dir()],
        extra_compile_args=extra_compile_args,
    )
    return [ext]
def write_file(path: str, contents: str) -> None:
    """Write data into a file.

    If the file already exists and has the same contents we
    want to write, skip writing so as to preserve the mtime
    and avoid triggering recompilation.
    """
    # We encode it ourselves and open the files as binary to avoid windows
    # newline translation
    new_data = contents.encode("utf-8")
    try:
        with open(path, "rb") as f:
            existing: bytes | None = f.read()
    except OSError:
        existing = None
    if existing == new_data:
        return
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as g:
        g.write(new_data)

    # Fudge the mtime forward because otherwise when two builds happen close
    # together (like in a test) setuptools might not realize the source is newer
    # than the new artifact.
    # XXX: This is bad though.
    bumped = os.stat(path).st_mtime + 1
    os.utime(path, times=(bumped, bumped))
def construct_groups(
    sources: list[BuildSource],
    separate: bool | list[tuple[list[str], str | None]],
    use_shared_lib: bool,
    group_name_override: str | None,
) -> emitmodule.Groups:
    """Compute Groups given the input source list and separate configs.

    separate is the user-specified configuration for how to assign
    modules to compilation groups (see mypycify docstring for details).

    This takes that and expands it into our internal representation of
    group configuration, documented in mypyc.emitmodule's definition
    of Group.
    """

    if separate is True:
        # Every module in its own group.
        groups: emitmodule.Groups = [([source], None) for source in sources]
    elif isinstance(separate, list):
        groups = []
        used_sources = set()
        for files, name in separate:
            normalized_files = {os.path.normpath(f) for f in files}
            group_sources = [
                src
                for src in sources
                if src.path is not None and os.path.normpath(src.path) in normalized_files
            ]
            groups.append((group_sources, name))
            used_sources.update(group_sources)
        # Sources not mentioned in any explicit group each get their own.
        unused_sources = [src for src in sources if src not in used_sources]
        if unused_sources:
            groups.extend([([source], None) for source in unused_sources])
    else:
        groups = [(sources, None)]

    # Generate missing names
    for i, (group, name) in enumerate(groups):
        if use_shared_lib and not name:
            if group_name_override is not None:
                name = group_name_override
            else:
                name = group_name([source.module for source in group])
        groups[i] = (group, name)

    return groups
def get_header_deps(cfiles: list[tuple[str, str]]) -> list[str]:
    """Find all the headers used by a group of cfiles.

    We do this by just regexping the source, which is a bit simpler than
    properly plumbing the data through.

    Arguments:
        cfiles: A list of (file name, file contents) pairs.
    """
    seen: set[str] = set()
    for _name, text in cfiles:
        seen.update(re.findall(r'#include "(.*)"', text))
    return sorted(seen)
def mypyc_build(
    paths: list[str],
    compiler_options: CompilerOptions,
    *,
    separate: bool | list[tuple[list[str], str | None]] = False,
    only_compile_paths: Iterable[str] | None = None,
    skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
    always_use_shared_lib: bool = False,
) -> tuple[emitmodule.Groups, list[tuple[list[str], list[str]]], list[SourceDep]]:
    """Do the front and middle end of mypyc building, producing and writing out C source."""
    fscache = FileSystemCache()
    mypyc_sources, all_sources, options = get_mypy_config(
        paths, only_compile_paths, compiler_options, fscache
    )

    # We generate a shared lib if there are multiple modules or if any
    # of the modules are in package. (Because I didn't want to fuss
    # around with making the single module code handle packages.)
    use_shared_lib = (
        len(mypyc_sources) > 1
        or any("." in x.module for x in mypyc_sources)
        or any(is_package_source(x) for x in mypyc_sources)
        or always_use_shared_lib
    )

    groups = construct_groups(mypyc_sources, separate, use_shared_lib, compiler_options.group_name)
    if compiler_options.group_name is not None:
        assert len(groups) == 1, "If using custom group_name, only one group is expected"

    # We let the test harness just pass in the c file contents instead
    # so that it can do a corner-cutting version without full stubs.
    source_deps: list[SourceDep] = []
    if not skip_cgen_input:
        group_cfiles, ops_text, source_deps = generate_c(
            all_sources, options, groups, fscache, compiler_options=compiler_options
        )
        # TODO: unique names?
        write_file(os.path.join(compiler_options.target_dir, "ops.txt"), ops_text)
    else:
        group_cfiles = skip_cgen_input[0]
        source_deps = [SourceDep(d) for d in skip_cgen_input[1]]

    # Write out the generated C and collect the files for each group
    # Should this be here??
    group_cfilenames: list[tuple[list[str], list[str]]] = []
    for cfiles in group_cfiles:
        cfilenames = []
        for cfile, ctext in cfiles:
            cfile = os.path.join(compiler_options.target_dir, cfile)
            if not options.mypyc_skip_c_generation:
                write_file(cfile, ctext)
            # Only .c files are compiled; headers become dependencies below.
            if os.path.splitext(cfile)[1] == ".c":
                cfilenames.append(cfile)

        deps = [os.path.join(compiler_options.target_dir, dep) for dep in get_header_deps(cfiles)]
        group_cfilenames.append((cfilenames, deps))

    return groups, group_cfilenames, source_deps
def get_cflags(
*,
compiler_type: str | None = None,
opt_level: str = "3",
debug_level: str = "1",
multi_file: bool = False,
experimental_features: bool = False,
log_trace: bool = False,
) -> list[str]:
"""Get C compiler flags for the given configuration.
Args:
compiler_type: Compiler type, e.g. "unix" or "msvc". If None, detected automatically.
opt_level: Optimization level as string ("0", "1", "2", or "3").
debug_level: Debug level as string ("0", "1", "2", or "3").
multi_file: Whether multi-file compilation mode is enabled.
experimental_features: Whether experimental features are enabled.
log_trace: Whether trace logging is enabled.
Returns:
List of compiler flags.
"""
if compiler_type is None:
compiler: Any = ccompiler.new_compiler()
sysconfig.customize_compiler(compiler)
compiler_type = compiler.compiler_type
cflags: list[str] = []
if compiler_type == "unix":
cflags += [
f"-O{opt_level}",
f"-g{debug_level}",
"-Werror",
"-Wno-unused-function",
"-Wno-unused-label",
"-Wno-unreachable-code",
"-Wno-unused-variable",
"-Wno-unused-command-line-argument",
"-Wno-unknown-warning-option",
"-Wno-unused-but-set-variable",
"-Wno-ignored-optimization-argument",
# GCC at -O3 false-positives on struct hack (items[1]) in vec buffers
"-Wno-array-bounds",
"-Wno-stringop-overread",
"-Wno-stringop-overflow",
# Disables C Preprocessor (cpp) warnings
# See https://github.com/mypyc/mypyc/issues/956
"-Wno-cpp",
]
if log_trace:
cflags.append("-DMYPYC_LOG_TRACE")
if experimental_features:
cflags.append("-DMYPYC_EXPERIMENTAL")
if opt_level == "0":
cflags.append("-UNDEBUG")
elif compiler_type == "msvc":
# msvc doesn't have levels, '/O2' is full and '/Od' is disable
if opt_level == "0":
opt_level = "d"
cflags.append("/UNDEBUG")
elif opt_level in ("1", "2", "3"):
opt_level = "2"
if debug_level == "0":
debug_level = "NONE"
elif debug_level == "1":
debug_level = "FASTLINK"
elif debug_level in ("2", "3"):
debug_level = "FULL"
cflags += [
f"/O{opt_level}",
f"/DEBUG:{debug_level}",
"/wd4102", # unreferenced label
"/wd4101", # unreferenced local variable
"/wd4146", # negating unsigned int
]
if multi_file:
# Disable whole program optimization in multi-file mode so
# that we actually get the compilation speed and memory
# use wins that multi-file mode is intended for.
cflags += ["/GL-", "/wd9025"] # warning about overriding /GL
if log_trace:
cflags.append("/DMYPYC_LOG_TRACE")
if experimental_features:
cflags.append("/DMYPYC_EXPERIMENTAL")
return cflags
def mypycify(
    paths: list[str],
    *,
    only_compile_paths: Iterable[str] | None = None,
    verbose: bool = False,
    opt_level: str = "3",
    debug_level: str = "1",
    strip_asserts: bool = False,
    multi_file: bool = False,
    separate: bool | list[tuple[list[str], str | None]] = False,
    skip_cgen_input: tuple[list[list[tuple[str, str]]], list[str]] | None = None,
    target_dir: str | None = None,
    include_runtime_files: bool | None = None,
    strict_dunder_typing: bool = False,
    group_name: str | None = None,
    log_trace: bool = False,
    depends_on_librt_internal: bool = False,
    install_librt: bool = False,
    experimental_features: bool = False,
) -> list[Extension]:
    """Main entry point to building using mypyc.

    This produces a list of Extension objects that should be passed as the
    ext_modules parameter to setup.

    Arguments:
        paths: A list of file paths to build. It may also contain mypy options.
        only_compile_paths: If not None, an iterable of paths that are to be
                            the only modules compiled, even if other modules
                            appear in the mypy command line given to paths.
                            (These modules must still be passed to paths.)

        verbose: Should mypyc be more verbose. Defaults to false.

        opt_level: The optimization level, as a string. Defaults to '3' (meaning '-O3').
        debug_level: The debug level, as a string. Defaults to '1' (meaning '-g1').
        strip_asserts: Should asserts be stripped from the generated code.

        multi_file: Should each Python module be compiled into its own C source file.
                    This can reduce compile time and memory requirements at the likely
                    cost of runtime performance of compiled code. Defaults to false.
        separate: Should compiled modules be placed in separate extension modules.
                  If False, all modules are placed in a single shared library.
                  If True, every module is placed in its own library.
                  Otherwise, separate should be a list of
                  (file name list, optional shared library name) pairs specifying
                  groups of files that should be placed in the same shared library
                  (while all other modules will be placed in its own library).

                  Each group can be compiled independently, which can
                  speed up compilation, but calls between groups can
                  be slower than calls within a group and can't be
                  inlined.
        target_dir: The directory to write C output files. Defaults to 'build'.
        include_runtime_files: If not None, whether the mypyc runtime library
                               should be directly #include'd instead of linked
                               separately in order to reduce compiler invocations.
                               Defaults to False in multi_file mode, True otherwise.
        strict_dunder_typing: If True, force dunder methods to have the return type
                              of the method strictly, which can lead to more
                              optimization opportunities. Defaults to False.
        group_name: If set, override the default group name derived from
                    the hash of module names. This is used for the names of the
                    output C files and the shared library. This is only supported
                    if there is a single group. [Experimental]
        log_trace: If True, compiled code writes a trace log of events in
                   mypyc_trace.txt (derived from executed operations). This is
                   useful for performance analysis, such as analyzing which
                   primitive ops are used the most and on which lines.
        depends_on_librt_internal: This is True only for mypy itself.
        install_librt: If True, also build the librt extension modules. Normally,
                       those are build and published on PyPI separately, but during
                       tests, we want to use their development versions (i.e. from
                       current commit).
        experimental_features: Enable experimental features (install_librt=True is
                               also needed if using experimental librt features). These
                               have no backward compatibility guarantees!
    """

    # Figure out our configuration
    compiler_options = CompilerOptions(
        strip_asserts=strip_asserts,
        multi_file=multi_file,
        verbose=verbose,
        separate=separate is not False,
        target_dir=target_dir,
        include_runtime_files=include_runtime_files,
        strict_dunder_typing=strict_dunder_typing,
        group_name=group_name,
        log_trace=log_trace,
        depends_on_librt_internal=depends_on_librt_internal,
        experimental_features=experimental_features,
    )

    # Generate all the actual important C code
    groups, group_cfilenames, source_deps = mypyc_build(
        paths,
        only_compile_paths=only_compile_paths,
        compiler_options=compiler_options,
        separate=separate,
        skip_cgen_input=skip_cgen_input,
    )

    # Mess around with setuptools and actually get the thing built
    setup_mypycify_vars()

    # Create a compiler object so we can make decisions based on what
    # compiler is being used. typeshed is missing some attributes on the
    # compiler object so we give it type Any
    compiler: Any = ccompiler.new_compiler()
    sysconfig.customize_compiler(compiler)

    build_dir = compiler_options.target_dir

    cflags = get_cflags(
        compiler_type=compiler.compiler_type,
        opt_level=opt_level,
        debug_level=debug_level,
        multi_file=multi_file,
        experimental_features=experimental_features,
        log_trace=log_trace,
    )

    # If configured to (defaults to yes in multi-file mode), copy the
    # runtime library in. Otherwise it just gets #included to save on
    # compiler invocations.
    shared_cfilenames = []
    if not compiler_options.include_runtime_files:
        # Collect all files to copy: runtime files + conditional source files
        files_to_copy = list(RUNTIME_C_FILES)
        for source_dep in source_deps:
            files_to_copy.append(source_dep.path)
            files_to_copy.append(source_dep.get_header())
        # Copy all files
        for name in files_to_copy:
            rt_file = os.path.join(build_dir, name)
            with open(os.path.join(include_dir(), name), encoding="utf-8") as f:
                write_file(rt_file, f.read())
            if name.endswith(".c"):
                shared_cfilenames.append(rt_file)

    extensions = []
    for (group_sources, lib_name), (cfilenames, deps) in zip(groups, group_cfilenames):
        if lib_name:
            extensions.extend(
                build_using_shared_lib(
                    group_sources,
                    lib_name,
                    cfilenames + shared_cfilenames,
                    deps,
                    build_dir,
                    cflags,
                )
            )
        else:
            extensions.extend(
                build_single_module(group_sources, cfilenames + shared_cfilenames, cflags)
            )

    if install_librt:
        # Copy the runtime and librt sources into the build dir and build
        # the librt extension modules from there.
        for name in RUNTIME_C_FILES:
            rt_file = os.path.join(build_dir, name)
            with open(os.path.join(include_dir(), name), encoding="utf-8") as f:
                write_file(rt_file, f.read())
        for mod, file_names, addit_files, includes in LIBRT_MODULES:
            for file_name in file_names + addit_files:
                rt_file = os.path.join(build_dir, file_name)
                with open(os.path.join(include_dir(), file_name), encoding="utf-8") as f:
                    write_file(rt_file, f.read())
            extensions.append(
                get_extension()(
                    mod,
                    sources=[
                        os.path.join(build_dir, file) for file in file_names + RUNTIME_C_FILES
                    ],
                    include_dirs=[include_dir()]
                    + [os.path.join(include_dir(), d) for d in includes],
                    extra_compile_args=cflags,
                )
            )

    return extensions

View file

@ -0,0 +1,69 @@
# This file must have the same content for mypyc/build_setup.py and lib-rt/build_setup.py,
# it exists to work around absence of support for per-file compile flags in setuptools.
# The version in mypyc/ is the source of truth, and should be copied to lib-rt if modified.
import os
import platform
import sys
try:
# Import setuptools so that it monkey-patch overrides distutils
import setuptools # noqa: F401
except ImportError:
pass
if sys.version_info >= (3, 12):
# From setuptools' monkeypatch
from distutils import ccompiler # type: ignore[import-not-found]
else:
from distutils import ccompiler
# Per-compiler extra compile flags keyed by a path component of the source
# file; e.g. sources under base64/arch/avx2 need /arch:AVX2 with MSVC.
EXTRA_FLAGS_PER_COMPILER_TYPE_PER_PATH_COMPONENT = {
    "msvc": {
        "base64/arch/sse42": ["/arch:SSE4.2"],
        "base64/arch/avx2": ["/arch:AVX2"],
        "base64/arch/avx": ["/arch:AVX"],
    }
}

# Save the original spawn so the patched version below can delegate to it.
ccompiler.CCompiler.__spawn = ccompiler.CCompiler.spawn  # type: ignore[attr-defined]

# True on Intel/AMD 64-bit hosts, where the x86 SIMD flags above apply.
X86_64 = platform.machine() in ("x86_64", "AMD64", "amd64")
# Building under Pyodide (WebAssembly) when this env var is present.
PYODIDE = "PYODIDE" in os.environ
# Escape hatch to disable all extra per-file flags.
NO_EXTRA_FLAGS = "MYPYC_NO_EXTRA_FLAGS" in os.environ
def spawn(self, cmd, **kwargs) -> None:  # type: ignore[no-untyped-def]
    """Patched CCompiler.spawn that injects per-file SIMD compile flags.

    Appends arch-specific options (or -msimd128 under Pyodide) when the
    command compiles a source file whose path matches a configured
    component, then delegates to the saved original spawn.
    """
    new_cmd = list(cmd)
    if PYODIDE:
        for argument in reversed(new_cmd):
            if not str(argument).endswith(".c"):
                continue
            if "base64/arch/" in str(argument):
                new_cmd.extend(["-msimd128"])
    elif not NO_EXTRA_FLAGS:
        compiler_type: str = self.compiler_type
        extra_options = EXTRA_FLAGS_PER_COMPILER_TYPE_PER_PATH_COMPONENT.get(compiler_type, None)
        if X86_64 and extra_options is not None:
            # filenames are closer to the end of command line
            for argument in reversed(new_cmd):
                # Check if the matching argument contains a source filename.
                if not str(argument).endswith(".c"):
                    continue
                for path in extra_options.keys():
                    if path in str(argument):
                        if compiler_type == "bcpp":
                            compiler = new_cmd.pop()
                            # Borland accepts a source file name at the end,
                            # insert the options before it
                            new_cmd.extend(extra_options[path])
                            new_cmd.append(compiler)
                        else:
                            new_cmd.extend(extra_options[path])
                        # path component is found, no need to search any further
                        break
    self.__spawn(new_cmd, **kwargs)


# Install the patched spawn on all CCompiler instances.
ccompiler.CCompiler.spawn = spawn  # type: ignore[method-assign]

View file

@ -0,0 +1,54 @@
"""Encode valid C string literals from Python strings.
If a character is not allowed in C string literals, it is either emitted
as a simple escape sequence (e.g. '\\n'), or an octal escape sequence
with exactly three digits ('\\oXXX'). Question marks are escaped to
prevent trigraphs in the string literal from being interpreted. Note
that '\\?' is an invalid escape sequence in Python.
Consider the string literal "AB\\xCDEF". As one would expect, Python
parses it as ['A', 'B', 0xCD, 'E', 'F']. However, the C standard
specifies that all hexadecimal digits immediately following '\\x' will
be interpreted as part of the escape sequence. Therefore, it is
unexpectedly parsed as ['A', 'B', 0xCDEF].
Emitting ("AB\\xCD" "EF") would avoid this behaviour. However, we opt
for simplicity and use octal escape sequences instead. They do not
suffer from the same issue as they are defined to parse at most three
octal digits.
"""
from __future__ import annotations
import string
from typing import Final
# Default representation: octal-escape every byte ('\000' .. '\377').
# Printable and specially-escaped characters are patched in below.
CHAR_MAP: Final = ["\\%03o" % code for code in range(256)]
# string.printable is locale-independent (always the C locale), so these
# characters are safe to emit verbatim.
for ch in string.printable:
    CHAR_MAP[ord(ch)] = ch
# Simple escape sequences win over any other representation, so they are
# written last.
for ch in ("'", '"', "\\", "a", "b", "f", "n", "r", "t", "v"):
    seq = "\\" + ch
    target = seq.encode("ascii").decode("unicode_escape")
    CHAR_MAP[ord(target)] = seq
# '?' is escaped to defuse C trigraphs; note '\?' is not a Python escape.
CHAR_MAP[ord("?")] = "\\?"
def encode_bytes_as_c_string(b: bytes) -> str:
    """Produce contents of a C string literal for a byte string, without quotes."""
    # Each byte maps straight through the precomputed CHAR_MAP table.
    return "".join(map(CHAR_MAP.__getitem__, b))
def c_string_initializer(value: bytes) -> str:
    """Create initializer for a C char[] / char * variable from a string.

    For example, if value is b'foo', the result is '"foo"'.
    """
    body = encode_bytes_as_c_string(value)
    return f'"{body}"'

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,993 @@
"""Code generation for native function bodies."""
from __future__ import annotations
from typing import Final
from mypyc.analysis.blockfreq import frequently_executed_blocks
from mypyc.codegen.emit import (
DEBUG_ERRORS,
PREFIX_MAP,
Emitter,
TracebackAndGotoHandler,
c_array_initializer,
)
from mypyc.common import GENERATOR_ATTRIBUTE_PREFIX, HAVE_IMMORTAL, NATIVE_PREFIX, REG_PREFIX
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_CLASSMETHOD, FUNC_STATICMETHOD, FuncDecl, FuncIR, all_values
from mypyc.ir.ops import (
ERR_FALSE,
NAMESPACE_TYPE,
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
CString,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElement,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
from mypyc.ir.pprint import generate_names_for_ir
from mypyc.ir.rtypes import (
RArray,
RInstance,
RStruct,
RTuple,
RType,
RVec,
is_bool_or_bit_rprimitive,
is_int32_rprimitive,
is_int64_rprimitive,
is_int_rprimitive,
is_none_rprimitive,
is_pointer_rprimitive,
is_tagged,
)
def native_function_type(fn: FuncIR, emitter: Emitter) -> str:
    """Return the C function-pointer type matching fn's native signature."""
    ret_ctype = emitter.ctype(fn.ret_type)
    # An empty parameter list is spelled "void" in C.
    arg_list = ", ".join(emitter.ctype(a.type) for a in fn.args) or "void"
    return "{} (*)({})".format(ret_ctype, arg_list)
def native_function_header(fn: FuncDecl, emitter: Emitter) -> str:
    """Return the C declaration header of a native function (no trailing ';')."""
    decls = [f"{emitter.ctype_spaced(arg.type)}{REG_PREFIX}{arg.name}" for arg in fn.sig.args]
    ret_part = emitter.ctype_spaced(fn.sig.ret_type)
    name_part = emitter.native_function_name(fn)
    # An empty parameter list is spelled "void" in C.
    arg_part = ", ".join(decls) or "void"
    return f"{ret_part}{name_part}({arg_part})"
def generate_native_function(
    fn: FuncIR, emitter: Emitter, source_path: str, module_name: str
) -> None:
    """Emit the C implementation of one IR function.

    Writes the function header and local register declarations into a
    separate declarations emitter and the translated basic blocks into a
    body emitter, then copies both into the main emitter.
    """
    declarations = Emitter(emitter.context)
    names = generate_names_for_ir(fn.arg_regs, fn.blocks)
    body = Emitter(emitter.context, names)
    visitor = FunctionEmitterVisitor(body, declarations, source_path, module_name)
    declarations.emit_line(f"{native_function_header(fn.decl, emitter)} {{")
    body.indent()
    # Declare a C local for every IR value that is not a function argument.
    for r in all_values(fn.arg_regs, fn.blocks):
        if isinstance(r.type, RTuple):
            emitter.declare_tuple_struct(r.type)
        if isinstance(r.type, RArray):
            continue # Special: declared on first assignment
        if r in fn.arg_regs:
            continue # Skip the arguments
        ctype = emitter.ctype_spaced(r.type)
        init = ""
        declarations.emit_line(
            "{ctype}{prefix}{name}{init};".format(
                ctype=ctype, prefix=REG_PREFIX, name=names[r], init=init
            )
        )
    # Before we emit the blocks, give them all labels
    blocks = fn.blocks
    for i, block in enumerate(blocks):
        block.label = i
    # Find blocks that are never jumped to or are only jumped to from the
    # block directly above it. This allows for more labels and gotos to be
    # eliminated during code generation.
    for block in fn.blocks:
        terminator = block.terminator
        assert isinstance(terminator, ControlOp), terminator
        for target in terminator.targets():
            is_next_block = target.label == block.label + 1
            # Always emit labels for GetAttr error checks since the emit code that
            # generates them will add instructions between the branch and the
            # next label, causing the label to be wrongly removed. A better
            # solution would be to change the IR so that it adds a basic block
            # in between the calls.
            is_problematic_op = isinstance(terminator, Branch) and any(
                isinstance(s, GetAttr) for s in terminator.sources()
            )
            if not is_next_block or is_problematic_op:
                fn.blocks[target.label].referenced = True
    # Blocks off the common execution path are marked "rare" so that
    # ref-count operations in them can be emitted accordingly.
    common = frequently_executed_blocks(fn.blocks[0])
    for i in range(len(blocks)):
        block = blocks[i]
        visitor.rare = block not in common
        next_block = None
        if i + 1 < len(blocks):
            next_block = blocks[i + 1]
        body.emit_label(block)
        visitor.next_block = next_block
        ops = block.ops
        visitor.ops = ops
        visitor.op_index = 0
        # Visit methods may advance op_index themselves to fuse the
        # following op into the current one (e.g. a Branch after a Cast).
        while visitor.op_index < len(ops):
            ops[visitor.op_index].accept(visitor)
            visitor.op_index += 1
    body.emit_line("}")
    emitter.emit_from_emitter(declarations)
    emitter.emit_from_emitter(body)
class FunctionEmitterVisitor(OpVisitor[None]):
    """Translate the IR ops of one function into C source lines.

    One instance visits all ops of a function, writing C via the shared
    Emitter. The caller (generate_native_function) updates next_block,
    ops, op_index and rare between basic blocks. Some visit methods bump
    op_index themselves to fuse the following Branch op into the code
    they emit.
    """

    def __init__(
        self, emitter: Emitter, declarations: Emitter, source_path: str, module_name: str
    ) -> None:
        self.emitter = emitter
        self.names = emitter.names
        self.declarations = declarations
        self.source_path = source_path
        self.module_name = module_name
        self.literals = emitter.context.literals
        # True while emitting a block off the common path (set by caller).
        self.rare = False
        # Next basic block to be processed after the current one (if any), set by caller
        self.next_block: BasicBlock | None = None
        # Ops in the basic block currently being processed, set by caller
        self.ops: list[Op] = []
        # Current index within ops; visit methods can increment this to skip/merge ops
        self.op_index = 0

    def temp_name(self) -> str:
        """Return a fresh temporary C variable name."""
        return self.emitter.temp_name()

    def visit_goto(self, op: Goto) -> None:
        # Fall through instead of jumping when the target is the next block.
        if op.label is not self.next_block:
            self.emit_line("goto %s;" % self.label(op.label))

    def error_value_check(self, value: Value, compare: str) -> str:
        """Return a C condition comparing value against its error value.

        compare is "==" (is error) or "!=" (is not error).
        """
        typ = value.type
        if isinstance(typ, RTuple):
            # TODO: What about empty tuple?
            return self.emitter.tuple_undefined_check_cond(
                typ, self.reg(value), self.c_error_value, compare
            )
        elif isinstance(typ, RVec):
            # Error values for vecs are represented by a negative length.
            vec_compare = ">=" if compare == "!=" else "<"
            return f"{self.reg(value)}.len {vec_compare} 0"
        else:
            return f"{self.reg(value)} {compare} {self.c_error_value(typ)}"

    def visit_branch(self, op: Branch) -> None:
        """Emit a conditional jump, minimizing gotos and else blocks."""
        true, false = op.true, op.false
        negated = op.negated
        negated_rare = False
        if true is self.next_block and op.traceback_entry is None:
            # Switch true/false since it avoids an else block.
            true, false = false, true
            negated = not negated
            negated_rare = True
        neg = "!" if negated else ""
        cond = ""
        if op.op == Branch.BOOL:
            expr_result = self.reg(op.value)
            cond = f"{neg}{expr_result}"
        elif op.op == Branch.IS_ERROR:
            compare = "!=" if negated else "=="
            cond = self.error_value_check(op.value, compare)
        else:
            assert False, "Invalid branch"
        # For error checks, tell the compiler the branch is unlikely
        if op.traceback_entry is not None or op.rare:
            if not negated_rare:
                cond = f"unlikely({cond})"
            else:
                cond = f"likely({cond})"
        if false is self.next_block:
            if op.traceback_entry is None:
                if true is not self.next_block:
                    self.emit_line(f"if ({cond}) goto {self.label(true)};")
            else:
                self.emit_line(f"if ({cond}) {{")
                self.emit_traceback(op)
                self.emit_lines("goto %s;" % self.label(true), "}")
        else:
            self.emit_line(f"if ({cond}) {{")
            self.emit_traceback(op)
            if true is not self.next_block:
                self.emit_line("goto %s;" % self.label(true))
            self.emit_lines("} else", " goto %s;" % self.label(false))

    def visit_return(self, op: Return) -> None:
        value_str = self.reg(op.value)
        self.emit_line("return %s;" % value_str)

    def visit_tuple_set(self, op: TupleSet) -> None:
        """Emit field-by-field construction of an unboxed tuple struct."""
        dest = self.reg(op)
        tuple_type = op.tuple_type
        self.emitter.declare_tuple_struct(tuple_type)
        if len(op.items) == 0: # empty tuple
            self.emit_line(f"{dest}.empty_struct_error_flag = 0;")
        else:
            for i, item in enumerate(op.items):
                self.emit_line(f"{dest}.f{i} = {self.reg(item)};")

    def visit_assign(self, op: Assign) -> None:
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            src_type = op.src.type
            dest_type = op.dest.type
            if src_type.is_unboxed and not dest_type.is_unboxed:
                # We sometimes assign from an integer representation of a pointer
                # to a real pointer, and C compilers insist on a cast.
                src = f"(void *){src}"
            elif not src_type.is_unboxed and dest_type.is_unboxed:
                # We sometimes assign a pointer to an integer type (e.g. to create
                # tagged pointers), and here we need an explicit cast.
                src = f"({self.emitter.ctype(dest_type)}){src}"
            self.emit_line(f"{dest} = {src};")

    def visit_assign_multi(self, op: AssignMulti) -> None:
        typ = op.dest.type
        assert isinstance(typ, RArray), typ
        dest = self.reg(op.dest)
        # RArray values can only be assigned to once, so we can always
        # declare them on initialization.
        self.emit_line(
            "%s%s[%d] = %s;"
            % (
                self.emitter.ctype_spaced(typ.item_type),
                dest,
                len(op.src),
                c_array_initializer([self.reg(s) for s in op.src], indented=True),
            )
        )

    def visit_load_error_value(self, op: LoadErrorValue) -> None:
        """Store the type-specific error sentinel into the destination."""
        reg = self.reg(op)
        if isinstance(op.type, RTuple):
            values = [self.c_undefined_value(item) for item in op.type.types]
            tmp = self.temp_name()
            self.emit_line("{} {} = {{ {} }};".format(self.ctype(op.type), tmp, ", ".join(values)))
            self.emit_line(f"{reg} = {tmp};")
        elif isinstance(op.type, RVec):
            self.emitter.set_undefined_value(reg, op.type)
        else:
            self.emit_line(f"{self.reg(op)} = {self.c_error_value(op.type)};")

    def visit_load_literal(self, op: LoadLiteral) -> None:
        """Load a literal from the CPyStatics array (tagging ints)."""
        index = self.literals.literal_index(op.value)
        if not is_int_rprimitive(op.type):
            self.emit_line("%s = CPyStatics[%d];" % (self.reg(op), index), ann=op.value)
        else:
            # Tagged int: set the low bit to mark a boxed PyObject pointer.
            self.emit_line(
                "%s = (CPyTagged)CPyStatics[%d] | 1;" % (self.reg(op), index), ann=op.value
            )

    def get_attr_expr(self, obj: str, op: GetAttr | SetAttr, decl_cl: ClassIR) -> str:
        """Generate attribute accessor for normal (non-property) access.

        This either has a form like obj->attr_name for attributes defined in non-trait
        classes, and *(obj + attr_offset) for attributes defined by traits. We also
        insert all necessary C casts here.
        """
        cast = f"({op.class_type.struct_name(self.emitter.names)} *)"
        if decl_cl.is_trait and op.class_type.class_ir.is_trait:
            # For pure trait access find the offset first, offsets
            # are ordered by attribute position in the cl.attributes dict.
            # TODO: pre-calculate the mapping to make this faster.
            trait_attr_index = list(decl_cl.attributes).index(op.attr)
            # TODO: reuse these names somehow?
            offset = self.emitter.temp_name()
            self.declarations.emit_line(f"size_t {offset};")
            self.emitter.emit_line(
                "{} = {};".format(
                    offset,
                    "CPy_FindAttrOffset({}, {}, {})".format(
                        self.emitter.type_struct_name(decl_cl),
                        f"({cast}{obj})->vtable",
                        trait_attr_index,
                    ),
                )
            )
            attr_cast = f"({self.ctype(op.class_type.attr_type(op.attr))} *)"
            return f"*{attr_cast}((char *){obj} + {offset})"
        else:
            # Cast to something non-trait. Note: for this to work, all struct
            # members for non-trait classes must obey monotonic linear growth.
            if op.class_type.class_ir.is_trait:
                assert not decl_cl.is_trait
                cast = f"({decl_cl.struct_name(self.emitter.names)} *)"
            return f"({cast}{obj})->{self.emitter.attr(op.attr)}"

    def visit_get_attr(self, op: GetAttr) -> None:
        """Emit attribute read, undefined-check, and AttributeError raise.

        May fuse the following IS_ERROR Branch into the undefined check
        (advancing op_index) to avoid a redundant branch.
        """
        if op.allow_error_value:
            self.get_attr_with_allow_error_value(op)
            return
        dest = self.reg(op)
        obj = self.reg(op.obj)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        prefer_method = cl.is_trait and attr_rtype.error_overlap
        if cl.get_method(op.attr, prefer_method=prefer_method):
            # Properties are essentially methods, so use vtable access for them.
            if cl.is_method_final(op.attr):
                self.emit_method_call(f"{dest} = ", op.obj, op.attr, [])
            else:
                version = "_TRAIT" if cl.is_trait else ""
                self.emit_line(
                    "%s = CPY_GET_ATTR%s(%s, %s, %d, %s, %s); /* %s */"
                    % (
                        dest,
                        version,
                        obj,
                        self.emitter.type_struct_name(rtype.class_ir),
                        rtype.getter_index(op.attr),
                        rtype.struct_name(self.names),
                        self.ctype(rtype.attr_type(op.attr)),
                        op.attr,
                    )
                )
        else:
            # Otherwise, use direct or offset struct access.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            self.emitter.emit_line(f"{dest} = {attr_expr};")
            always_defined = cl.is_always_defined(op.attr)
            merged_branch = None
            if not always_defined:
                self.emitter.emit_undefined_attr_check(
                    attr_rtype, dest, "==", obj, op.attr, cl, unlikely=True
                )
                branch = self.next_branch()
                if branch is not None:
                    if (
                        branch.value is op
                        and branch.op == Branch.IS_ERROR
                        and branch.traceback_entry is not None
                        and not branch.negated
                    ):
                        # Generate code for the following branch here to avoid
                        # redundant branches in the generated code.
                        self.emit_attribute_error(branch, cl.name, op.attr)
                        self.emit_line("goto %s;" % self.label(branch.true))
                        merged_branch = branch
                        self.emitter.emit_line("}")
                if not merged_branch:
                    exc_class = "PyExc_AttributeError"
                    self.emitter.emit_line(
                        'PyErr_SetString({}, "attribute {} of {} undefined");'.format(
                            exc_class,
                            repr(op.attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX)),
                            repr(cl.name),
                        )
                    )
            if attr_rtype.is_refcounted and not op.is_borrowed:
                if not merged_branch and not always_defined:
                    self.emitter.emit_line("} else {")
                self.emitter.emit_inc_ref(dest, attr_rtype)
            if merged_branch:
                if merged_branch.false is not self.next_block:
                    self.emit_line("goto %s;" % self.label(merged_branch.false))
                # The fused Branch op has been fully emitted; skip it.
                self.op_index += 1
            elif not always_defined:
                self.emitter.emit_line("}")

    def get_attr_with_allow_error_value(self, op: GetAttr) -> None:
        """Handle GetAttr with allow_error_value=True.

        This allows NULL or other error value without raising AttributeError.
        """
        dest = self.reg(op)
        obj = self.reg(op.obj)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        # Direct struct access without NULL check
        attr_expr = self.get_attr_expr(obj, op, decl_cl)
        self.emitter.emit_line(f"{dest} = {attr_expr};")
        # Only emit inc_ref if not NULL
        if attr_rtype.is_refcounted and not op.is_borrowed:
            check = self.error_value_check(op, "!=")
            self.emitter.emit_line(f"if ({check}) {{")
            self.emitter.emit_inc_ref(dest, attr_rtype)
            self.emitter.emit_line("}")

    def next_branch(self) -> Branch | None:
        """Return the next op if it is a Branch, else None (for fusing)."""
        if self.op_index + 1 < len(self.ops):
            next_op = self.ops[self.op_index + 1]
            if isinstance(next_op, Branch):
                return next_op
        return None

    def visit_set_attr(self, op: SetAttr) -> None:
        """Emit attribute write via property setter or direct struct store."""
        if op.error_kind == ERR_FALSE:
            # Attribute writes can fail (e.g. a property setter raising).
            dest = self.reg(op)
        obj = self.reg(op.obj)
        src = self.reg(op.src)
        rtype = op.class_type
        cl = rtype.class_ir
        attr_rtype, decl_cl = cl.attr_details(op.attr)
        if op.is_propset:
            # Again, use vtable access for properties...
            assert not op.is_init and op.error_kind == ERR_FALSE, "%s %d %d %s" % (
                op.attr,
                op.is_init,
                op.error_kind,
                rtype,
            )
            version = "_TRAIT" if cl.is_trait else ""
            self.emit_line(
                "%s = CPY_SET_ATTR%s(%s, %s, %d, %s, %s, %s); /* %s */"
                % (
                    dest,
                    version,
                    obj,
                    self.emitter.type_struct_name(rtype.class_ir),
                    rtype.setter_index(op.attr),
                    src,
                    rtype.struct_name(self.names),
                    self.ctype(rtype.attr_type(op.attr)),
                    op.attr,
                )
            )
        else:
            # ...and struct access for normal attributes.
            attr_expr = self.get_attr_expr(obj, op, decl_cl)
            if not op.is_init and attr_rtype.is_refcounted:
                # This is not an initialization (where we know that the attribute was
                # previously undefined), so decref the old value.
                always_defined = cl.is_always_defined(op.attr)
                if not always_defined:
                    self.emitter.emit_undefined_attr_check(
                        attr_rtype, attr_expr, "!=", obj, op.attr, cl
                    )
                self.emitter.emit_dec_ref(attr_expr, attr_rtype)
                if not always_defined:
                    self.emitter.emit_line("}")
            elif attr_rtype.error_overlap and not cl.is_always_defined(op.attr):
                # If there is overlap with the error value, update bitmap to mark
                # attribute as defined.
                self.emitter.emit_attr_bitmap_set(src, obj, attr_rtype, cl, op.attr)
            # This steals the reference to src, so we don't need to increment the arg
            self.emitter.emit_line(f"{attr_expr} = {src};")
            if op.error_kind == ERR_FALSE:
                self.emitter.emit_line(f"{dest} = 1;")

    def visit_load_static(self, op: LoadStatic) -> None:
        dest = self.reg(op)
        prefix = PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            name = "(PyObject *)%s" % name
        self.emit_line(f"{dest} = {name};", ann=op.ann)

    def visit_init_static(self, op: InitStatic) -> None:
        value = self.reg(op.value)
        prefix = PREFIX_MAP[op.namespace]
        name = self.emitter.static_name(op.identifier, op.module_name, prefix)
        if op.namespace == NAMESPACE_TYPE:
            value = "(PyTypeObject *)%s" % value
        self.emit_line(f"{name} = {value};")
        # The static holds its own reference to the stored value.
        self.emit_inc_ref(name, op.value.type)

    def visit_tuple_get(self, op: TupleGet) -> None:
        dest = self.reg(op)
        src = self.reg(op.src)
        self.emit_line(f"{dest} = {src}.f{op.index};")
        if not op.is_borrowed:
            self.emit_inc_ref(dest, op.type)

    def get_dest_assign(self, dest: Value) -> str:
        """Return 'dest = ' prefix, or '' when the result is unused."""
        if not dest.is_void:
            return self.reg(dest) + " = "
        else:
            return ""

    def visit_call(self, op: Call) -> None:
        """Call native function."""
        dest = self.get_dest_assign(op)
        args = ", ".join(self.reg(arg) for arg in op.args)
        lib = self.emitter.get_group_prefix(op.fn)
        cname = op.fn.cname(self.names)
        self.emit_line(f"{dest}{lib}{NATIVE_PREFIX}{cname}({args});")

    def visit_method_call(self, op: MethodCall) -> None:
        """Call native method."""
        dest = self.get_dest_assign(op)
        self.emit_method_call(dest, op.obj, op.method, op.args)

    def emit_method_call(self, dest: str, op_obj: Value, name: str, op_args: list[Value]) -> None:
        """Emit a native method call, direct when final, else via vtable."""
        obj = self.reg(op_obj)
        rtype = op_obj.type
        assert isinstance(rtype, RInstance), rtype
        class_ir = rtype.class_ir
        method = rtype.class_ir.get_method(name)
        assert method is not None
        # Can we call the method directly, bypassing vtable?
        is_direct = class_ir.is_method_final(name)
        # The first argument gets omitted for static methods and
        # turned into the class for class methods
        obj_args = (
            []
            if method.decl.kind == FUNC_STATICMETHOD
            else [f"(PyObject *)Py_TYPE({obj})"] if method.decl.kind == FUNC_CLASSMETHOD else [obj]
        )
        args = ", ".join(obj_args + [self.reg(arg) for arg in op_args])
        mtype = native_function_type(method, self.emitter)
        version = "_TRAIT" if rtype.class_ir.is_trait else ""
        if is_direct:
            # Directly call method, without going through the vtable.
            lib = self.emitter.get_group_prefix(method.decl)
            self.emit_line(f"{dest}{lib}{NATIVE_PREFIX}{method.cname(self.names)}({args});")
        else:
            # Call using vtable.
            method_idx = rtype.method_index(name)
            self.emit_line(
                "{}CPY_GET_METHOD{}({}, {}, {}, {}, {})({}); /* {} */".format(
                    dest,
                    version,
                    obj,
                    self.emitter.type_struct_name(rtype.class_ir),
                    method_idx,
                    rtype.struct_name(self.names),
                    mtype,
                    args,
                    name,
                )
            )

    def visit_inc_ref(self, op: IncRef) -> None:
        """Emit an incref, skipping values known to be immortal."""
        if (
            isinstance(op.src, Box)
            and (is_none_rprimitive(op.src.src.type) or is_bool_or_bit_rprimitive(op.src.src.type))
            and HAVE_IMMORTAL
        ):
            # On Python 3.12+, None/True/False are immortal, and we can skip inc ref
            return
        if isinstance(op.src, LoadLiteral) and HAVE_IMMORTAL:
            value = op.src.value
            # We can skip inc ref for immortal literals on Python 3.12+
            if type(value) is int and -5 <= value <= 256:
                # Small integers are immortal
                return
        src = self.reg(op.src)
        self.emit_inc_ref(src, op.src.type)

    def visit_dec_ref(self, op: DecRef) -> None:
        src = self.reg(op.src)
        self.emit_dec_ref(src, op.src.type, is_xdec=op.is_xdec)

    def visit_box(self, op: Box) -> None:
        self.emitter.emit_box(self.reg(op.src), self.reg(op), op.src.type, can_borrow=True)

    def visit_cast(self, op: Cast) -> None:
        """Emit a (possibly checked) cast; may fuse the following Branch."""
        if op.is_unchecked and op.is_borrowed:
            # No runtime check and no refcount adjustment: plain copy.
            self.emit_line(f"{self.reg(op)} = {self.reg(op.src)};")
            return
        branch = self.next_branch()
        handler = None
        if branch is not None:
            if (
                branch.value is op
                and branch.op == Branch.IS_ERROR
                and branch.traceback_entry is not None
                and not branch.negated
                and branch.false is self.next_block
            ):
                # Generate code also for the following branch here to avoid
                # redundant branches in the generated code.
                handler = TracebackAndGotoHandler(
                    self.label(branch.true),
                    self.source_path,
                    self.module_name,
                    branch.traceback_entry,
                )
                # The fused Branch op is emitted via the handler; skip it.
                self.op_index += 1
        self.emitter.emit_cast(
            self.reg(op.src), self.reg(op), op.type, src_type=op.src.type, error=handler
        )

    def visit_unbox(self, op: Unbox) -> None:
        self.emitter.emit_unbox(self.reg(op.src), self.reg(op), op.type)

    def visit_unreachable(self, op: Unreachable) -> None:
        self.emitter.emit_line("CPy_Unreachable();")

    def visit_raise_standard_error(self, op: RaiseStandardError) -> None:
        # TODO: Better escaping of backspaces and such
        if op.value is not None:
            if isinstance(op.value, str):
                message = op.value.replace('"', '\\"')
                self.emitter.emit_line(f'PyErr_SetString(PyExc_{op.class_name}, "{message}");')
            elif isinstance(op.value, Value):
                self.emitter.emit_line(
                    "PyErr_SetObject(PyExc_{}, {});".format(
                        op.class_name, self.emitter.reg(op.value)
                    )
                )
            else:
                assert False, "op value type must be either str or Value"
        else:
            self.emitter.emit_line(f"PyErr_SetNone(PyExc_{op.class_name});")
        self.emitter.emit_line(f"{self.reg(op)} = 0;")

    def visit_call_c(self, op: CallC) -> None:
        """Call a C helper function from the runtime library."""
        if op.is_void:
            dest = ""
        else:
            dest = self.get_dest_assign(op)
        args = ", ".join(self.reg(arg) for arg in op.args)
        self.emitter.emit_line(f"{dest}{op.function_name}({args});")

    def visit_primitive_op(self, op: PrimitiveOp) -> None:
        raise RuntimeError(
            f"unexpected PrimitiveOp {op.desc.name}: they must be lowered before codegen"
        )

    def visit_truncate(self, op: Truncate) -> None:
        dest = self.reg(op)
        value = self.reg(op.src)
        # for C backend the generated code are straight assignments
        self.emit_line(f"{dest} = {value};")

    def visit_extend(self, op: Extend) -> None:
        dest = self.reg(op)
        value = self.reg(op.src)
        if op.signed:
            src_cast = self.emit_signed_int_cast(op.src.type)
        else:
            src_cast = self.emit_unsigned_int_cast(op.src.type)
        self.emit_line(f"{dest} = {src_cast}{value};")

    def visit_load_global(self, op: LoadGlobal) -> None:
        dest = self.reg(op)
        self.emit_line(f"{dest} = {op.identifier};", ann=op.ann)

    def visit_int_op(self, op: IntOp) -> None:
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        if op.op == IntOp.RIGHT_SHIFT:
            # Signed right shift
            lhs = self.emit_signed_int_cast(op.lhs.type) + lhs
            rhs = self.emit_signed_int_cast(op.rhs.type) + rhs
        self.emit_line(f"{dest} = {lhs} {op.op_str[op.op]} {rhs};")

    def visit_comparison_op(self, op: ComparisonOp) -> None:
        """Emit integer comparison, adding casts to force signedness."""
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        lhs_cast = ""
        rhs_cast = ""
        if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
            # Always signed comparison op
            lhs_cast = self.emit_signed_int_cast(op.lhs.type)
            rhs_cast = self.emit_signed_int_cast(op.rhs.type)
        elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
            # Always unsigned comparison op
            lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
            rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
        elif isinstance(op.lhs, Integer) and op.lhs.value < 0:
            # Force signed ==/!= with negative operand
            rhs_cast = self.emit_signed_int_cast(op.rhs.type)
        elif isinstance(op.rhs, Integer) and op.rhs.value < 0:
            # Force signed ==/!= with negative operand
            lhs_cast = self.emit_signed_int_cast(op.lhs.type)
        self.emit_line(f"{dest} = {lhs_cast}{lhs} {op.op_str[op.op]} {rhs_cast}{rhs};")

    def visit_float_op(self, op: FloatOp) -> None:
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        if op.op != FloatOp.MOD:
            self.emit_line(f"{dest} = {lhs} {op.op_str[op.op]} {rhs};")
        else:
            # TODO: This may set errno as a side effect, that is a little sketchy.
            self.emit_line(f"{dest} = fmod({lhs}, {rhs});")

    def visit_float_neg(self, op: FloatNeg) -> None:
        dest = self.reg(op)
        src = self.reg(op.src)
        self.emit_line(f"{dest} = -{src};")

    def visit_float_comparison_op(self, op: FloatComparisonOp) -> None:
        dest = self.reg(op)
        lhs = self.reg(op.lhs)
        rhs = self.reg(op.rhs)
        self.emit_line(f"{dest} = {lhs} {op.op_str[op.op]} {rhs};")

    def visit_load_mem(self, op: LoadMem) -> None:
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: we shouldn't dereference to type that are pointer type so far
        type = self.ctype(op.type)
        self.emit_line(f"{dest} = *({type} *){src};")
        if not op.is_borrowed:
            self.emit_inc_ref(dest, op.type)

    def visit_set_mem(self, op: SetMem) -> None:
        dest = self.reg(op.dest)
        src = self.reg(op.src)
        dest_type = self.ctype(op.dest_type)
        # clang whines about self assignment (which we might generate
        # for some casts), so don't emit it.
        if dest != src:
            self.emit_line(f"*({dest_type} *){dest} = {src};")

    def visit_get_element(self, op: GetElement) -> None:
        dest = self.reg(op)
        src = self.reg(op.src)
        dest_type = self.ctype(op.type)
        self.emit_line(f"{dest} = ({dest_type}){src}.{op.field};")

    def visit_get_element_ptr(self, op: GetElementPtr) -> None:
        dest = self.reg(op)
        src = self.reg(op.src)
        # TODO: support tuple type
        assert isinstance(op.src_type, RStruct), op.src_type
        assert op.field in op.src_type.names, "Invalid field name."
        # Use offsetof to avoid undefined behavior when src is NULL
        # (e.g., vec buf pointer for empty vecs). The &((T*)p)->field
        # pattern is UB when p is NULL, which GCC -O3 can exploit.
        self.emit_line(
            "{} = ({})((CPyPtr){} + offsetof({}, {}));".format(
                dest, op.type._ctype, src, op.src_type.name, op.field
            )
        )

    def visit_set_element(self, op: SetElement) -> None:
        dest = self.reg(op)
        item = self.reg(op.item)
        field = op.field
        if isinstance(op.src, Undef):
            # First assignment to an undefined struct is trivial.
            self.emit_line(f"{dest}.{field} = {item};")
        else:
            # In the general case create a copy of the struct with a single
            # item modified.
            #
            # TODO: Can we do better if only a subset of fields are initialized?
            # TODO: Make this less verbose in the common case
            # TODO: Support tuples (or use RStruct for tuples)?
            src = self.reg(op.src)
            src_type = op.src.type
            assert isinstance(src_type, RStruct), src_type
            init_items = []
            for n in src_type.names:
                if n != field:
                    init_items.append(f"{src}.{n}")
                else:
                    init_items.append(item)
            self.emit_line(f"{dest} = ({self.ctype(src_type)}) {{ {', '.join(init_items)} }};")

    def visit_load_address(self, op: LoadAddress) -> None:
        typ = op.type
        dest = self.reg(op)
        if isinstance(op.src, Register):
            src = self.reg(op.src)
        elif isinstance(op.src, LoadStatic):
            prefix = PREFIX_MAP[op.src.namespace]
            src = self.emitter.static_name(op.src.identifier, op.src.module_name, prefix)
        else:
            # Otherwise op.src is a raw C name string.
            src = op.src
        self.emit_line(f"{dest} = ({typ._ctype})&{src};")

    def visit_keep_alive(self, op: KeepAlive) -> None:
        # This is a no-op.
        pass

    def visit_unborrow(self, op: Unborrow) -> None:
        # This is a no-op that propagates the source value.
        dest = self.reg(op)
        src = self.reg(op.src)
        self.emit_line(f"{dest} = {src};")

    # Helpers

    def label(self, label: BasicBlock) -> str:
        return self.emitter.label(label)

    def reg(self, reg: Value) -> str:
        """Return the C expression for an IR value (literal or register)."""
        if isinstance(reg, Integer):
            val = reg.value
            if val == 0 and is_pointer_rprimitive(reg.type):
                return "NULL"
            s = str(val)
            if val >= (1 << 31):
                # Avoid overflowing signed 32-bit int
                if val >= (1 << 63):
                    s += "ULL"
                else:
                    s += "LL"
            elif val == -(1 << 63):
                # Avoid overflowing C integer literal
                s = "(-9223372036854775807LL - 1)"
            elif val <= -(1 << 31):
                s += "LL"
            return s
        elif isinstance(reg, Float):
            r = repr(reg.value)
            if r == "inf":
                return "INFINITY"
            elif r == "-inf":
                return "-INFINITY"
            elif r == "nan":
                return "NAN"
            return r
        elif isinstance(reg, CString):
            return '"' + encode_c_string_literal(reg.value) + '"'
        else:
            return self.emitter.reg(reg)

    def ctype(self, rtype: RType) -> str:
        return self.emitter.ctype(rtype)

    def c_error_value(self, rtype: RType) -> str:
        return self.emitter.c_error_value(rtype)

    def c_undefined_value(self, rtype: RType) -> str:
        return self.emitter.c_undefined_value(rtype)

    def emit_line(self, line: str, *, ann: object = None) -> None:
        self.emitter.emit_line(line, ann=ann)

    def emit_lines(self, *lines: str) -> None:
        self.emitter.emit_lines(*lines)

    def emit_inc_ref(self, dest: str, rtype: RType) -> None:
        self.emitter.emit_inc_ref(dest, rtype, rare=self.rare)

    def emit_dec_ref(self, dest: str, rtype: RType, is_xdec: bool) -> None:
        self.emitter.emit_dec_ref(dest, rtype, is_xdec=is_xdec, rare=self.rare)

    def emit_declaration(self, line: str) -> None:
        self.declarations.emit_line(line)

    def emit_traceback(self, op: Branch) -> None:
        if op.traceback_entry is not None:
            self.emitter.emit_traceback(self.source_path, self.module_name, op.traceback_entry)

    def emit_attribute_error(self, op: Branch, class_name: str, attr: str) -> None:
        """Emit a CPy_AttributeError call with traceback information."""
        assert op.traceback_entry is not None
        if self.emitter.context.strict_traceback_checks:
            assert (
                op.traceback_entry[1] >= 0
            ), "AttributeError traceback cannot have a negative line number"
        globals_static = self.emitter.static_name("globals", self.module_name)
        self.emit_line(
            'CPy_AttributeError("%s", "%s", "%s", "%s", %d, %s);'
            % (
                self.source_path.replace("\\", "\\\\"),
                op.traceback_entry[0],
                class_name,
                attr.removeprefix(GENERATOR_ATTRIBUTE_PREFIX),
                op.traceback_entry[1],
                globals_static,
            )
        )
        if DEBUG_ERRORS:
            self.emit_line('assert(PyErr_Occurred() != NULL && "failure w/o err!");')

    def emit_signed_int_cast(self, type: RType) -> str:
        """Return a cast forcing signed arithmetic for tagged ints."""
        if is_tagged(type):
            return "(Py_ssize_t)"
        else:
            return ""

    def emit_unsigned_int_cast(self, type: RType) -> str:
        """Return a cast forcing unsigned arithmetic for fixed-width ints."""
        if is_int32_rprimitive(type):
            return "(uint32_t)"
        elif is_int64_rprimitive(type):
            return "(uint64_t)"
        else:
            return ""
# Byte value -> C source representation, filled lazily on first use.
_translation_table: Final[dict[int, str]] = {}


def encode_c_string_literal(b: bytes) -> str:
    """Convert bytestring to the C string literal syntax (with necessary escaping).

    For example, b'foo\\n' gets converted to 'foo\\\\n' (note that double
    quotes are not added).
    """
    if not _translation_table:
        # Build the table on first call. A few characters get their simple
        # escape sequence; everything else is either emitted verbatim
        # (printable ASCII) or hex-escaped.
        special = {
            ord("\n"): "\\n",
            ord("\r"): "\\r",
            ord("\t"): "\\t",
            ord('"'): '\\"',
            ord("\\"): "\\\\",
        }
        table = {
            code: special.get(
                code, chr(code) if 32 <= code < 127 else "\\x%.2x" % code
            )
            for code in range(256)
        }
        _translation_table.update(str.maketrans(table))
    # latin1 maps each byte 0..255 to the same code point, so translate()
    # sees the raw byte values.
    return b.decode("latin1").translate(_translation_table)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,978 @@
"""Generate CPython API wrapper functions for native functions.
The wrapper functions are used by the CPython runtime when calling
native functions from interpreted code, and when the called function
can't be determined statically in compiled code. They validate, match,
unbox and type check function arguments, and box return values as
needed. All wrappers accept and return 'PyObject *' (boxed) values.
The wrappers aren't used for most calls between two native functions
or methods in a single compilation unit.
"""
from __future__ import annotations
from collections.abc import Sequence
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR, ARG_STAR2, ArgKind
from mypy.operators import op_methods_to_symbols, reverse_op_method_names, reverse_op_methods
from mypyc.codegen.emit import AssignHandler, Emitter, ErrorHandler, GotoHandler, ReturnHandler
from mypyc.common import (
BITMAP_BITS,
BITMAP_TYPE,
DUNDER_PREFIX,
NATIVE_PREFIX,
PREFIX,
bitmap_name,
)
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FUNC_STATICMETHOD, FuncIR, RuntimeArg
from mypyc.ir.rtypes import (
RInstance,
RType,
is_bool_rprimitive,
is_int_rprimitive,
is_object_rprimitive,
object_rprimitive,
)
from mypyc.namegen import NameGenerator
# Generic vectorcall wrapper functions (Python 3.7+)
#
# A wrapper function has a signature like this:
#
# PyObject *fn(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
#
# The function takes a self object, pointer to an array of arguments,
# the number of positional arguments, and a tuple of keyword argument
# names (that are stored starting in args[nargs]).
#
# It returns the returned object, or NULL on an exception.
#
# These are more efficient than legacy wrapper functions, since
# usually no tuple or dict objects need to be created for the
# arguments. Vectorcalls also use pre-constructed str objects for
# keyword argument names and other pre-computed information, instead
# of processing the argument format string on each call.
def wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str:
    """Return the C declaration header of a vectorcall wrapper function.

    The wrapper takes the self object, a pointer to an array of arguments,
    the number of positional arguments, and a tuple of keyword argument
    names (see the comment block above for details).
    """
    assert not fn.internal
    arg_list = "PyObject *self, PyObject *const *args, size_t nargs, PyObject *kwnames"
    return f"PyObject *{PREFIX}{fn.cname(names)}({arg_list})"
def generate_traceback_code(
    fn: FuncIR, emitter: Emitter, source_path: str, module_name: str
) -> str:
    """Return a C statement that appends a traceback frame for this wrapper.

    If we hit an error while processing arguments, we emit a traceback
    frame to make it possible to debug where it happened. Unlike traceback
    frames added for exceptions seen in IR, this happens even if there is
    no `traceback_name`, because the error will have originated here and
    so we need it in the traceback.
    """
    globals_static = emitter.static_name("globals", module_name)
    escaped_path = source_path.replace("\\", "\\\\")
    func_label = fn.traceback_name or fn.name
    return f'CPy_AddTraceback("{escaped_path}", "{func_label}", {fn.line}, {globals_static});'
def make_arg_groups(args: list[RuntimeArg]) -> dict[ArgKind, list[RuntimeArg]]:
    """Group arguments by kind (every kind gets an entry, possibly empty)."""
    grouped: dict[ArgKind, list[RuntimeArg]] = {kind: [] for kind in ArgKind}
    for arg in args:
        grouped[arg.kind].append(arg)
    return grouped
def reorder_arg_groups(groups: dict[ArgKind, list[RuntimeArg]]) -> list[RuntimeArg]:
    """Reorder argument groups to match their order in a format string."""
    ordered: list[RuntimeArg] = []
    for kind in (ARG_POS, ARG_OPT, ARG_NAMED_OPT, ARG_NAMED):
        ordered.extend(groups[kind])
    return ordered
def make_static_kwlist(args: list[RuntimeArg]) -> str:
    """Return C code declaring a static, NULL-terminated keyword-name array.

    The trailing 0 terminates the array, as required by the CPython
    argument-parsing helpers that consume it.
    """
    entries = [f'"{arg.name}"' for arg in args]
    entries.append("0")
    return "static const char * const kwlist[] = {%s};" % ", ".join(entries)
def make_format_string(func_name: str | None, groups: dict[ArgKind, list[RuntimeArg]]) -> str:
    """Return a format string that specifies the accepted arguments.

    The format string is an extended subset of what is supported by
    PyArg_ParseTupleAndKeywords(). Only the type 'O' is used, and we
    also support some extensions:

    - Required keyword-only arguments are introduced after '@'
    - If the function receives *args or **kwargs, we add a '%' prefix

    Each group requires the previous groups' delimiters to be present
    first.

    These are used by both vectorcall and legacy wrapper functions.
    """
    parts: list[str] = []
    if groups[ARG_STAR] or groups[ARG_STAR2]:
        parts.append("%")
    parts.append("O" * len(groups[ARG_POS]))
    if groups[ARG_OPT] or groups[ARG_NAMED_OPT] or groups[ARG_NAMED]:
        parts.append("|" + "O" * len(groups[ARG_OPT]))
    if groups[ARG_NAMED_OPT] or groups[ARG_NAMED]:
        parts.append("$" + "O" * len(groups[ARG_NAMED_OPT]))
    if groups[ARG_NAMED]:
        parts.append("@" + "O" * len(groups[ARG_NAMED]))
    if func_name is not None:
        parts.append(f":{func_name}")
    return "".join(parts)
def generate_wrapper_function(
    fn: FuncIR, emitter: Emitter, source_path: str, module_name: str
) -> None:
    """Generate a CPython-compatible vectorcall wrapper for a native function.

    In particular, this handles unboxing the arguments, calling the native function, and
    then boxing the return value.

    Args:
        fn: The native function to wrap
        emitter: Where the generated C code is written
        source_path: Path of the compiled source file (for tracebacks)
        module_name: Fully qualified module name (for tracebacks)
    """
    emitter.emit_line(f"{wrapper_function_header(fn, emitter.names)} {{")
    # If fn is a method, then the first argument is a self param
    real_args = list(fn.args)
    # Trailing bitmap args track which optional args were provided; they
    # aren't part of the Python-level signature, so strip them here.
    if fn.sig.num_bitmap_args:
        real_args = real_args[: -fn.sig.num_bitmap_args]
    if fn.class_name and fn.decl.kind != FUNC_STATICMETHOD:
        arg = real_args.pop(0)
        emitter.emit_line(f"PyObject *obj_{arg.name} = self;")
    # Need to order args as: required, optional, kwonly optional, kwonly required
    # This is because CPyArg_ParseStackAndKeywords format string requires
    # them grouped in that way.
    groups = make_arg_groups(real_args)
    reordered_args = reorder_arg_groups(groups)
    emitter.emit_line(make_static_kwlist(reordered_args))
    fmt = make_format_string(fn.name, groups)
    # Define the arguments the function accepts (but no types yet)
    emitter.emit_line(f'static CPyArg_Parser parser = {{"{fmt}", kwlist, 0}};')
    for arg in real_args:
        emitter.emit_line(
            "PyObject *obj_{}{};".format(arg.name, " = NULL" if arg.optional else "")
        )
    # *args/**kwargs objects are created by the parser and must be freed.
    cleanups = [f"CPy_DECREF(obj_{arg.name});" for arg in groups[ARG_STAR] + groups[ARG_STAR2]]
    # Pointers handed to the parser, in the order its format string expects.
    arg_ptrs: list[str] = []
    if groups[ARG_STAR] or groups[ARG_STAR2]:
        arg_ptrs += [f"&obj_{groups[ARG_STAR][0].name}" if groups[ARG_STAR] else "NULL"]
        arg_ptrs += [f"&obj_{groups[ARG_STAR2][0].name}" if groups[ARG_STAR2] else "NULL"]
    arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args]
    if fn.name == "__call__":
        nargs = "PyVectorcall_NARGS(nargs)"
    else:
        nargs = "nargs"
    parse_fn = "CPyArg_ParseStackAndKeywords"
    # Special case some common signatures
    if not real_args:
        # No args
        parse_fn = "CPyArg_ParseStackAndKeywordsNoArgs"
    elif len(real_args) == 1 and len(groups[ARG_POS]) == 1:
        # Single positional arg
        parse_fn = "CPyArg_ParseStackAndKeywordsOneArg"
    elif len(real_args) == len(groups[ARG_POS]) + len(groups[ARG_OPT]):
        # No keyword-only args, *args or **kwargs
        parse_fn = "CPyArg_ParseStackAndKeywordsSimple"
    emitter.emit_lines(
        "if (!{}(args, {}, kwnames, &parser{})) {{".format(
            parse_fn, nargs, "".join(", " + n for n in arg_ptrs)
        ),
        "return NULL;",
        "}",
    )
    # Bitmaps start zeroed; generate_arg_check sets bits for provided args.
    for i in range(fn.sig.num_bitmap_args):
        name = bitmap_name(i)
        emitter.emit_line(f"{BITMAP_TYPE} {name} = 0;")
    traceback_code = generate_traceback_code(fn, emitter, source_path, module_name)
    generate_wrapper_core(
        fn,
        emitter,
        groups[ARG_OPT] + groups[ARG_NAMED_OPT],
        cleanups=cleanups,
        traceback_code=traceback_code,
    )
    emitter.emit_line("}")
# Legacy generic wrapper functions
#
# These take a self object, a Python tuple of positional arguments,
# and a dict of keyword arguments. These are a lot slower than
# vectorcall wrappers, especially in calls involving keyword
# arguments.
def legacy_wrapper_function_header(fn: FuncIR, names: NameGenerator) -> str:
    """Return the C declaration header of a legacy (tp_call style) wrapper."""
    arg_list = "PyObject *self, PyObject *args, PyObject *kw"
    return f"PyObject *{PREFIX}{fn.cname(names)}({arg_list})"
def generate_legacy_wrapper_function(
    fn: FuncIR, emitter: Emitter, source_path: str, module_name: str
) -> None:
    """Generates a CPython-compatible legacy wrapper for a native function.

    In particular, this handles unboxing the arguments, calling the native function, and
    then boxing the return value.

    Args:
        fn: The native function to wrap
        emitter: Where the generated C code is written
        source_path: Path of the compiled source file (for tracebacks)
        module_name: Fully qualified module name (for tracebacks)
    """
    emitter.emit_line(f"{legacy_wrapper_function_header(fn, emitter.names)} {{")
    # If fn is a method, then the first argument is a self param
    real_args = list(fn.args)
    # Trailing bitmap args track which optional args were provided; they
    # aren't part of the Python-level signature, so strip them here.
    if fn.sig.num_bitmap_args:
        real_args = real_args[: -fn.sig.num_bitmap_args]
    if fn.class_name and (fn.decl.name == "__new__" or fn.decl.kind != FUNC_STATICMETHOD):
        arg = real_args.pop(0)
        emitter.emit_line(f"PyObject *obj_{arg.name} = self;")
    # Need to order args as: required, optional, kwonly optional, kwonly required
    # This is because CPyArg_ParseTupleAndKeywords format string requires
    # them grouped in that way.
    groups = make_arg_groups(real_args)
    reordered_args = reorder_arg_groups(groups)
    emitter.emit_line(make_static_kwlist(reordered_args))
    for arg in real_args:
        emitter.emit_line(
            "PyObject *obj_{}{};".format(arg.name, " = NULL" if arg.optional else "")
        )
    # *args/**kwargs objects are created by the parser and must be freed.
    cleanups = [f"CPy_DECREF(obj_{arg.name});" for arg in groups[ARG_STAR] + groups[ARG_STAR2]]
    # Pointers handed to the parser, in the order its format string expects.
    arg_ptrs: list[str] = []
    if groups[ARG_STAR] or groups[ARG_STAR2]:
        arg_ptrs += [f"&obj_{groups[ARG_STAR][0].name}" if groups[ARG_STAR] else "NULL"]
        arg_ptrs += [f"&obj_{groups[ARG_STAR2][0].name}" if groups[ARG_STAR2] else "NULL"]
    arg_ptrs += [f"&obj_{arg.name}" for arg in reordered_args]
    emitter.emit_lines(
        'if (!CPyArg_ParseTupleAndKeywords(args, kw, "{}", "{}", kwlist{})) {{'.format(
            make_format_string(None, groups), fn.name, "".join(", " + n for n in arg_ptrs)
        ),
        "return NULL;",
        "}",
    )
    # Bitmaps start zeroed; generate_arg_check sets bits for provided args.
    for i in range(fn.sig.num_bitmap_args):
        name = bitmap_name(i)
        emitter.emit_line(f"{BITMAP_TYPE} {name} = 0;")
    traceback_code = generate_traceback_code(fn, emitter, source_path, module_name)
    generate_wrapper_core(
        fn,
        emitter,
        groups[ARG_OPT] + groups[ARG_NAMED_OPT],
        cleanups=cleanups,
        traceback_code=traceback_code,
    )
    emitter.emit_line("}")
# Specialized wrapper functions
def generate_dunder_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper so a native __dunder__ method fits a mapping protocol slot.

    This specifically means that the arguments are taken as PyObject * values
    and the result is returned as a PyObject * as well.

    Returns the name of the generated C wrapper function.
    """
    wrapper = WrapperGenerator(cl, emitter)
    wrapper.set_target(fn)
    wrapper.emit_header()
    wrapper.emit_arg_processing()
    wrapper.emit_call()
    wrapper.finish()
    return wrapper.wrapper_name()
def generate_ipow_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for native __ipow__.

    Since __ipow__ fills a ternary slot, but almost no one defines __ipow__
    to take three arguments, the wrapper needs to be tweaked to force it to
    accept three arguments.

    Returns the name of the generated C wrapper function.
    """
    assert len(fn.args) in (2, 3), "__ipow__ should only take 2 or 3 arguments"
    wrapper = WrapperGenerator(cl, emitter)
    wrapper.set_target(fn)
    wrapper.arg_names = ["self", "exp", "mod"]
    wrapper.emit_header()
    wrapper.emit_arg_processing()
    unsupported_code = [
        'PyErr_SetString(PyExc_TypeError, "__ipow__ takes 2 positional arguments but 3 were given");',
        "return NULL;",
    ]
    handle_third_pow_argument(fn, emitter, wrapper, if_unsupported=unsupported_code)
    wrapper.emit_call()
    wrapper.finish()
    return wrapper.wrapper_name()
def generate_bin_op_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for a native binary dunder method.

    The same wrapper that handles the forward method (e.g. __add__) also handles
    the corresponding reverse method (e.g. __radd__), if defined.

    Both arguments and the return value are PyObject *.

    Returns the name of the generated C wrapper function.
    """
    gen = WrapperGenerator(cl, emitter)
    gen.set_target(fn)
    if fn.name in ("__pow__", "__rpow__"):
        # Power dunders fill a ternary slot, so the wrapper always
        # receives a third (mod) argument.
        gen.arg_names = ["left", "right", "mod"]
    else:
        gen.arg_names = ["left", "right"]
    wrapper_name = gen.wrapper_name()
    gen.emit_header()
    if fn.name not in reverse_op_methods and fn.name in reverse_op_method_names:
        # There's only a reverse operator method.
        generate_bin_op_reverse_only_wrapper(fn, emitter, gen)
    else:
        rmethod = reverse_op_methods[fn.name]
        fn_rev = cl.get_method(rmethod)
        if fn_rev is None:
            # There's only a forward operator method.
            generate_bin_op_forward_only_wrapper(fn, emitter, gen)
        else:
            # There's both a forward and a reverse operator method.
            generate_bin_op_both_wrappers(cl, fn, fn_rev, emitter, gen)
    return wrapper_name
def generate_bin_op_forward_only_wrapper(
    fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
    """Emit the body of a binary-operator wrapper when only the forward method exists."""
    gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
    handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
    gen.emit_call(not_implemented_handler="goto typefail;")
    gen.emit_error_handling()
    emitter.emit_label("typefail")
    # If some argument has an incompatible type, treat this the same as
    # returning NotImplemented, and try to call the reverse operator method.
    #
    # Note that in normal Python you'd return NotImplemented explicitly
    # instead, but that doesn't generally work here, since the body won't
    # be executed at all if there is an argument type check failure.
    #
    # The recommended way is to still use a type check in the
    # body. This will only be used in interpreted mode:
    #
    #    def __add__(self, other: int) -> Foo:
    #        if not isinstance(other, int):
    #            return NotImplemented
    #        ...
    generate_bin_op_reverse_dunder_call(fn, emitter, reverse_op_methods[fn.name])
    gen.finish()
def generate_bin_op_reverse_only_wrapper(
    fn: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
    """Emit the body of a binary-operator wrapper when only the reverse method exists.

    The operand names are swapped so the reverse method sees its own instance
    first; a failed argument check falls through to returning NotImplemented.
    """
    gen.arg_names = ["right", "left"]
    gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
    handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail;"])
    gen.emit_call()
    gen.emit_error_handling()
    emitter.emit_label("typefail")
    emitter.emit_lines("Py_INCREF(Py_NotImplemented);", "return Py_NotImplemented;")
    gen.finish()
def generate_bin_op_both_wrappers(
    cl: ClassIR, fn: FuncIR, fn_rev: FuncIR, emitter: Emitter, gen: WrapperGenerator
) -> None:
    """Emit the body of a binary-operator wrapper with both forward and reverse methods."""
    # There's both a forward and a reverse operator method. First
    # check if we should try calling the forward one. If the
    # argument type check fails, fall back to the reverse method.
    #
    # Similar to above, we can't perfectly match Python semantics.
    # In regular Python code you'd return NotImplemented if the
    # operand has the wrong type, but in compiled code we'll never
    # get to execute the type check.
    emitter.emit_line(
        "if (PyObject_IsInstance(obj_left, (PyObject *){})) {{".format(
            emitter.type_struct_name(cl)
        )
    )
    gen.emit_arg_processing(error=GotoHandler("typefail"), raise_exception=False)
    handle_third_pow_argument(fn, emitter, gen, if_unsupported=["goto typefail2;"])
    # Ternary __rpow__ calls aren't a thing so immediately bail
    # if ternary __pow__ returns NotImplemented.
    if fn.name == "__pow__" and len(fn.args) == 3:
        fwd_not_implemented_handler = "goto typefail2;"
    else:
        fwd_not_implemented_handler = "goto typefail;"
    gen.emit_call(not_implemented_handler=fwd_not_implemented_handler)
    gen.emit_error_handling()
    emitter.emit_line("}")
    emitter.emit_label("typefail")
    emitter.emit_line(
        "if (PyObject_IsInstance(obj_right, (PyObject *){})) {{".format(
            emitter.type_struct_name(cl)
        )
    )
    # Retarget the generator at the reverse method with swapped operands,
    # so the reverse method sees its own instance first.
    gen.set_target(fn_rev)
    gen.arg_names = ["right", "left"]
    gen.emit_arg_processing(error=GotoHandler("typefail2"), raise_exception=False)
    handle_third_pow_argument(fn_rev, emitter, gen, if_unsupported=["goto typefail2;"])
    gen.emit_call()
    gen.emit_error_handling()
    emitter.emit_line("} else {")
    # Right operand isn't an instance of this class: fall back to calling
    # the reverse dunder generically through the C helper.
    generate_bin_op_reverse_dunder_call(fn, emitter, fn_rev.name)
    emitter.emit_line("}")
    emitter.emit_label("typefail2")
    emitter.emit_line("Py_INCREF(Py_NotImplemented);")
    emitter.emit_line("return Py_NotImplemented;")
    gen.finish()
def generate_bin_op_reverse_dunder_call(fn: FuncIR, emitter: Emitter, rmethod: str) -> None:
    """Emit a call to the reverse dunder of a binary operator via the C helper.

    For the pow family the call is guarded so that it only happens when the
    third (mod) argument is the default; otherwise NotImplemented is returned.
    """
    is_pow_family = fn.name in ("__pow__", "__rpow__")
    if is_pow_family:
        # Ternary pow() will never call the reverse dunder.
        emitter.emit_line("if (obj_mod == Py_None) {")
    emitter.emit_line(
        'return CPy_CallReverseOpMethod(obj_left, obj_right, "{}", mypyc_interned_str.{});'.format(
            op_methods_to_symbols[fn.name], rmethod
        )
    )
    if is_pow_family:
        emitter.emit_lines(
            "} else {",
            "Py_INCREF(Py_NotImplemented);",
            "return Py_NotImplemented;",
            "}",
        )
def handle_third_pow_argument(
    fn: FuncIR, emitter: Emitter, gen: WrapperGenerator, *, if_unsupported: list[str]
) -> None:
    """Emit handling of the third (mod) argument of power dunders.

    No-op for methods other than __pow__, __rpow__ and __ipow__.

    Args:
        if_unsupported: C statements emitted when the wrapper receives a
            non-default third argument but the method can't accept one.
    """
    if fn.name not in ("__pow__", "__rpow__", "__ipow__"):
        return
    if (fn.name in ("__pow__", "__ipow__") and len(fn.args) == 2) or fn.name == "__rpow__":
        # If the power dunder only supports two arguments and the third
        # argument (AKA mod) is set to a non-default value, simply bail.
        #
        # Importantly, this prevents any ternary __rpow__ calls from
        # happening (as per the language specification).
        emitter.emit_line("if (obj_mod != Py_None) {")
        for line in if_unsupported:
            emitter.emit_line(line)
        emitter.emit_line("}")
        # The slot wrapper will receive three arguments, but the call only
        # supports two so make sure that the third argument isn't passed
        # along. This is needed as two-argument __(i)pow__ is allowed and
        # rather common.
        if len(gen.arg_names) == 3:
            gen.arg_names.pop()
# Map a rich comparison dunder name to the matching CPython comparison op
# constant (the 'op' value received by the richcompare wrapper below).
RICHCOMPARE_OPS = {
    "__lt__": "Py_LT",
    "__gt__": "Py_GT",
    "__le__": "Py_LE",
    "__ge__": "Py_GE",
    "__eq__": "Py_EQ",
    "__ne__": "Py_NE",
}
def generate_richcompare_wrapper(cl: ClassIR, emitter: Emitter) -> str | None:
    """Generates a wrapper for richcompare dunder methods.

    A single C function dispatches on the 'op' argument to whichever
    comparison dunders the class defines.

    Returns the wrapper name, or None if the class defines no comparison
    dunders.
    """
    # Sort for determinism on Python 3.5
    matches = sorted(name for name in RICHCOMPARE_OPS if cl.has_method(name))
    if not matches:
        return None
    name = f"{DUNDER_PREFIX}_RichCompare_{cl.name_prefix(emitter.names)}"
    emitter.emit_line(
        "static PyObject *{name}(PyObject *obj_lhs, PyObject *obj_rhs, int op) {{".format(
            name=name
        )
    )
    emitter.emit_line("switch (op) {")
    for func in matches:
        emitter.emit_line(f"case {RICHCOMPARE_OPS[func]}: {{")
        method = cl.get_method(func)
        assert method is not None
        generate_wrapper_core(method, emitter, arg_names=["lhs", "rhs"])
        emitter.emit_line("}")
    emitter.emit_line("}")
    # Ops with no matching method fall through to NotImplemented.
    emitter.emit_line("Py_INCREF(Py_NotImplemented);")
    emitter.emit_line("return Py_NotImplemented;")
    emitter.emit_line("}")
    return name
def generate_get_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for native __get__ methods.

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}"
    header = f"static PyObject *{name}(PyObject *self, PyObject *instance, PyObject *owner) {{"
    emitter.emit_line(header)
    # 'instance' may be NULL; substitute Py_None before the native call.
    emitter.emit_line("instance = instance ? instance : Py_None;")
    native_call = f"return {NATIVE_PREFIX}{fn.cname(emitter.names)}(self, instance, owner);"
    emitter.emit_line(native_call)
    emitter.emit_line("}")
    return name
def generate_hash_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for native __hash__ methods.

    The wrapper returns a Py_ssize_t and reports errors by returning -1.

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}"
    emitter.emit_line(f"static Py_ssize_t {name}(PyObject *self) {{")
    emitter.emit_line(
        "{}retval = {}{}{}(self);".format(
            emitter.ctype_spaced(fn.ret_type),
            emitter.get_group_prefix(fn.decl),
            NATIVE_PREFIX,
            fn.cname(emitter.names),
        )
    )
    emitter.emit_error_check("retval", fn.ret_type, "return -1;")
    # Convert the native result to Py_ssize_t; tagged ints use a dedicated
    # conversion helper, other return types go through PyLong.
    if is_int_rprimitive(fn.ret_type):
        emitter.emit_line("Py_ssize_t val = CPyTagged_AsSsize_t(retval);")
    else:
        emitter.emit_line("Py_ssize_t val = PyLong_AsSsize_t(retval);")
    emitter.emit_dec_ref("retval", fn.ret_type)
    emitter.emit_line("if (PyErr_Occurred()) return -1;")
    # We can't return -1 from a hash function (it signals an error), so
    # map a legitimate -1 hash value to -2.
    emitter.emit_line("if (val == -1) return -2;")
    emitter.emit_line("return val;")
    emitter.emit_line("}")
    return name
def generate_len_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for native __len__ methods.

    The wrapper returns a Py_ssize_t and reports errors by returning -1.

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}"
    emitter.emit_line(f"static Py_ssize_t {name}(PyObject *self) {{")
    emitter.emit_line(
        "{}retval = {}{}{}(self);".format(
            emitter.ctype_spaced(fn.ret_type),
            emitter.get_group_prefix(fn.decl),
            NATIVE_PREFIX,
            fn.cname(emitter.names),
        )
    )
    emitter.emit_error_check("retval", fn.ret_type, "return -1;")
    # Convert the native result to Py_ssize_t; tagged ints use a dedicated
    # conversion helper, other return types go through PyLong.
    if is_int_rprimitive(fn.ret_type):
        emitter.emit_line("Py_ssize_t val = CPyTagged_AsSsize_t(retval);")
    else:
        emitter.emit_line("Py_ssize_t val = PyLong_AsSsize_t(retval);")
    emitter.emit_dec_ref("retval", fn.ret_type)
    emitter.emit_line("if (PyErr_Occurred()) return -1;")
    emitter.emit_line("return val;")
    emitter.emit_line("}")
    return name
def generate_bool_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for native __bool__ methods.

    The wrapper returns a C int and reports errors by returning -1.

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}"
    emitter.emit_line(f"static int {name}(PyObject *self) {{")
    ret_ctype = emitter.ctype_spaced(fn.ret_type)
    native_name = fn.cname(emitter.names)
    emitter.emit_line(f"{ret_ctype}val = {NATIVE_PREFIX}{native_name}(self);")
    emitter.emit_error_check("val", fn.ret_type, "return -1;")
    # Supporting a non-bool return wouldn't be that hard, but it seems
    # unimportant and getting error handling and unboxing right would be
    # fiddly. (And way easier to do in IR!)
    assert is_bool_rprimitive(fn.ret_type), "Only bool return supported for __bool__"
    emitter.emit_line("return val;")
    emitter.emit_line("}")
    return name
def generate_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generate a wrapper for native __delitem__.

    This is only called from a combined __delitem__/__setitem__ wrapper.

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}__delitem__{cl.name_prefix(emitter.names)}"
    arg_decls = [f"PyObject *obj_{arg.name}" for arg in fn.args]
    emitter.emit_line("static int {}({}) {{".format(name, ", ".join(arg_decls)))
    # The shared helper emits the body and the closing brace.
    generate_set_del_item_wrapper_inner(fn, emitter, fn.args)
    return name
def generate_set_del_item_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for native __setitem__ method (also works for __delitem__).

    This is used with the mapping protocol slot. Arguments are taken as *PyObjects and we
    return a negative C int on error.

    Create a separate wrapper function for __delitem__ as needed and have the
    __setitem__ wrapper call it if the value is NULL. Return the name
    of the outer (__setitem__) wrapper.
    """
    method_cls = cl.get_method_and_class("__delitem__")
    del_name = None
    if method_cls and method_cls[1] == cl:
        # Generate a separate wrapper for __delitem__
        del_name = generate_del_item_wrapper(cl, method_cls[0], emitter)
    args = fn.args
    if fn.name == "__delitem__":
        # Add an extra argument for value that we expect to be NULL.
        args = list(args) + [RuntimeArg("___value", object_rprimitive, ARG_POS)]
    name = "{}{}{}".format(DUNDER_PREFIX, "__setitem__", cl.name_prefix(emitter.names))
    input_args = ", ".join(f"PyObject *obj_{arg.name}" for arg in args)
    emitter.emit_line(f"static int {name}({input_args}) {{")
    # First check if this is __delitem__ (the slot passes NULL as the value
    # for deletion).
    emitter.emit_line(f"if (obj_{args[2].name} == NULL) {{")
    if del_name is not None:
        # We have a native implementation, so call it
        emitter.emit_line(f"return {del_name}(obj_{args[0].name}, obj_{args[1].name});")
    else:
        # Try to call superclass method instead
        emitter.emit_line(f"PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});")
        emitter.emit_line("if (super == NULL) return -1;")
        emitter.emit_line(
            'PyObject *result = PyObject_CallMethod(super, "__delitem__", "O", obj_{});'.format(
                args[1].name
            )
        )
        emitter.emit_line("Py_DECREF(super);")
        emitter.emit_line("Py_XDECREF(result);")
        emitter.emit_line("return result == NULL ? -1 : 0;")
    emitter.emit_line("}")
    method_cls = cl.get_method_and_class("__setitem__")
    if method_cls and method_cls[1] == cl:
        # Native __setitem__ defined on this class: emit the body (the
        # helper also emits the closing brace).
        generate_set_del_item_wrapper_inner(fn, emitter, args)
    else:
        # No native __setitem__ here: delegate to the superclass, or raise
        # a TypeError if no base could possibly provide one.
        emitter.emit_line(f"PyObject *super = CPy_Super(CPyModule_builtins, obj_{args[0].name});")
        emitter.emit_line("if (super == NULL) return -1;")
        emitter.emit_line("PyObject *result;")
        if method_cls is None and cl.builtin_base is None:
            msg = f"'{cl.name}' object does not support item assignment"
            emitter.emit_line(f'PyErr_SetString(PyExc_TypeError, "{msg}");')
            emitter.emit_line("result = NULL;")
        else:
            # A base class may have __setitem__
            emitter.emit_line(
                'result = PyObject_CallMethod(super, "__setitem__", "OO", obj_{}, obj_{});'.format(
                    args[1].name, args[2].name
                )
            )
        emitter.emit_line("Py_DECREF(super);")
        emitter.emit_line("Py_XDECREF(result);")
        emitter.emit_line("return result == NULL ? -1 : 0;")
        emitter.emit_line("}")
    return name
def generate_set_del_item_wrapper_inner(
    fn: FuncIR, emitter: Emitter, args: Sequence[RuntimeArg]
) -> None:
    """Emit the body (including the closing brace) of a set/del item wrapper.

    Checks/unboxes each argument, calls the native method, and maps the
    result to the C slot convention: 0 on success, -1 on error.
    """
    for arg in args:
        generate_arg_check(arg.name, arg.type, emitter, GotoHandler("fail"))
    native_args = ", ".join(f"arg_{arg.name}" for arg in args)
    emitter.emit_line(
        "{}val = {}{}({});".format(
            emitter.ctype_spaced(fn.ret_type), NATIVE_PREFIX, fn.cname(emitter.names), native_args
        )
    )
    emitter.emit_error_check("val", fn.ret_type, "goto fail;")
    # The native return value itself is discarded; only success matters.
    emitter.emit_dec_ref("val", fn.ret_type)
    emitter.emit_line("return 0;")
    emitter.emit_label("fail")
    emitter.emit_line("return -1;")
    emitter.emit_line("}")
def generate_contains_wrapper(cl: ClassIR, fn: FuncIR, emitter: Emitter) -> str:
    """Generates a wrapper for a native __contains__ method.

    The wrapper returns a C int (1/0 for the result, -1 on error).

    Returns the name of the generated C wrapper function.
    """
    name = f"{DUNDER_PREFIX}{fn.name}{cl.name_prefix(emitter.names)}"
    emitter.emit_line(f"static int {name}(PyObject *self, PyObject *obj_item) {{")
    generate_arg_check("item", fn.args[1].type, emitter, ReturnHandler("-1"))
    emitter.emit_line(
        "{}val = {}{}(self, arg_item);".format(
            emitter.ctype_spaced(fn.ret_type), NATIVE_PREFIX, fn.cname(emitter.names)
        )
    )
    emitter.emit_error_check("val", fn.ret_type, "return -1;")
    if is_bool_rprimitive(fn.ret_type):
        # A native bool maps directly to the C int result.
        emitter.emit_line("return val;")
    else:
        # Coerce an arbitrary boxed result to a truth value.
        emitter.emit_line("int boolval = PyObject_IsTrue(val);")
        emitter.emit_dec_ref("val", fn.ret_type)
        emitter.emit_line("return boolval;")
    emitter.emit_line("}")
    return name
# Helpers
def generate_wrapper_core(
    fn: FuncIR,
    emitter: Emitter,
    optional_args: list[RuntimeArg] | None = None,
    arg_names: list[str] | None = None,
    cleanups: list[str] | None = None,
    traceback_code: str | None = None,
) -> None:
    """Generates the core part of a wrapper function for a native function.

    This expects each argument as a PyObject * named obj_{arg} as a precondition.
    It converts the PyObject *s to the necessary types, checking and unboxing if necessary,
    makes the call, then boxes the result if necessary and returns it.

    Args:
        optional_args: Arguments the caller may leave unprovided (NULL)
        arg_names: Overrides the names of the incoming obj_{} variables
        cleanups: Extra C statements emitted on both success and error paths
        traceback_code: C code emitted on the error path (traceback entry)
    """
    gen = WrapperGenerator(None, emitter)
    gen.set_target(fn)
    if arg_names:
        gen.arg_names = arg_names
    gen.cleanups = cleanups or []
    gen.optional_args = optional_args or []
    gen.traceback_code = traceback_code or ""
    # With cleanups/traceback code a goto-based error path is needed;
    # otherwise a plain 'return NULL' suffices.
    error = ReturnHandler("NULL") if not gen.use_goto() else GotoHandler("fail")
    gen.emit_arg_processing(error=error)
    gen.emit_call()
    gen.emit_error_handling()
def generate_arg_check(
    name: str,
    typ: RType,
    emitter: Emitter,
    error: ErrorHandler | None = None,
    *,
    optional: bool = False,
    raise_exception: bool = True,
    bitmap_arg_index: int = 0,
) -> None:
    """Insert a runtime check for argument and unbox if necessary.

    The object is named PyObject *obj_{}. This is expected to generate
    a value of name arg_{} (unboxed if necessary). For each primitive a runtime
    check ensures the correct type.

    Args:
        name: Logical argument name ('obj_{name}' in, 'arg_{name}' out)
        typ: Target native type of the argument
        error: What to emit when the check/unbox fails (default: assign
            the error value and continue)
        optional: True if the argument may be left unprovided (obj_ is NULL)
        raise_exception: Whether a failed check raises a Python exception
        bitmap_arg_index: Slot of this argument in the 'provided arguments'
            bitmap (only used for optional unboxed args whose error value
            overlaps valid values)
    """
    error = error or AssignHandler()
    if typ.is_unboxed:
        if typ.error_overlap and optional:
            # Update bitmap if value is provided.
            init = emitter.c_undefined_value(typ)
            emitter.emit_line(f"{emitter.ctype(typ)} arg_{name} = {init};")
            emitter.emit_line(f"if (obj_{name} != NULL) {{")
            bitmap = bitmap_name(bitmap_arg_index // BITMAP_BITS)
            emitter.emit_line(f"{bitmap} |= 1 << {bitmap_arg_index & (BITMAP_BITS - 1)};")
            emitter.emit_unbox(
                f"obj_{name}",
                f"arg_{name}",
                typ,
                declare_dest=False,
                raise_exception=raise_exception,
                error=error,
                borrow=True,
            )
            emitter.emit_line("}")
        else:
            # Borrow when unboxing to avoid reference count manipulation.
            emitter.emit_unbox(
                f"obj_{name}",
                f"arg_{name}",
                typ,
                declare_dest=True,
                raise_exception=raise_exception,
                error=error,
                borrow=True,
                optional=optional,
            )
    elif is_object_rprimitive(typ):
        # Object is trivial since any object is valid
        if optional:
            emitter.emit_line(f"PyObject *arg_{name};")
            emitter.emit_line(f"if (obj_{name} == NULL) {{")
            emitter.emit_line(f"arg_{name} = {emitter.c_error_value(typ)};")
            emitter.emit_lines("} else {", f"arg_{name} = obj_{name}; ", "}")
        else:
            emitter.emit_line(f"PyObject *arg_{name} = obj_{name};")
    else:
        # Boxed native type: emit a runtime type check (cast).
        emitter.emit_cast(
            f"obj_{name}",
            f"arg_{name}",
            typ,
            declare_dest=True,
            raise_exception=raise_exception,
            error=error,
            optional=optional,
        )
class WrapperGenerator:
    """Helper that simplifies the generation of wrapper functions."""

    # TODO: Use this for more wrappers

    def __init__(self, cl: ClassIR | None, emitter: Emitter) -> None:
        # Class the wrapped function belongs to (None for plain functions)
        self.cl = cl
        self.emitter = emitter
        # Extra C statements emitted on both success and error paths
        self.cleanups: list[str] = []
        # Arguments that the caller may leave unprovided (obj_ is NULL)
        self.optional_args: list[RuntimeArg] = []
        # C code emitted on the error path (appends a traceback entry)
        self.traceback_code = ""

    def set_target(self, fn: FuncIR) -> None:
        """Set the wrapped function.

        It's fine to modify the attributes initialized here later to customize
        the wrapper function.
        """
        self.target_name = fn.name
        self.target_cname = fn.cname(self.emitter.names)
        # Trailing bitmap args aren't part of the Python-level signature.
        self.num_bitmap_args = fn.sig.num_bitmap_args
        if self.num_bitmap_args:
            self.args = fn.args[: -self.num_bitmap_args]
        else:
            self.args = fn.args
        self.arg_names = [arg.name for arg in self.args]
        self.ret_type = fn.ret_type

    def wrapper_name(self) -> str:
        """Return the name of the wrapper function."""
        return "{}{}{}".format(
            DUNDER_PREFIX,
            self.target_name,
            self.cl.name_prefix(self.emitter.names) if self.cl else "",
        )

    def use_goto(self) -> bool:
        """Do we use a goto for error handling (instead of straight return)?"""
        return bool(self.cleanups or self.traceback_code)

    def emit_header(self) -> None:
        """Emit the function header of the wrapper implementation."""
        input_args = ", ".join(f"PyObject *obj_{arg}" for arg in self.arg_names)
        self.emitter.emit_line(
            "static PyObject *{name}({input_args}) {{".format(
                name=self.wrapper_name(), input_args=input_args
            )
        )

    def emit_arg_processing(
        self, error: ErrorHandler | None = None, raise_exception: bool = True
    ) -> None:
        """Emit validation and unboxing of arguments."""
        error = error or self.error()
        bitmap_arg_index = 0
        for arg_name, arg in zip(self.arg_names, self.args):
            # Suppress the argument check for *args/**kwargs, since we know it must be right.
            typ = arg.type if arg.kind not in (ARG_STAR, ARG_STAR2) else object_rprimitive
            optional = arg in self.optional_args
            generate_arg_check(
                arg_name,
                typ,
                self.emitter,
                error,
                raise_exception=raise_exception,
                optional=optional,
                bitmap_arg_index=bitmap_arg_index,
            )
            # Only optional args whose native error value overlaps valid
            # values consume a bitmap slot.
            if optional and typ.error_overlap:
                bitmap_arg_index += 1

    def emit_call(self, not_implemented_handler: str = "") -> None:
        """Emit call to the wrapper function.

        If not_implemented_handler is non-empty, use this C code to handle
        a NotImplemented return value (if it's possible based on the return type).
        """
        native_args = ", ".join(f"arg_{arg}" for arg in self.arg_names)
        if self.num_bitmap_args:
            # Bitmap args are passed last, in reverse index order.
            bitmap_args = ", ".join(
                [bitmap_name(i) for i in reversed(range(self.num_bitmap_args))]
            )
            native_args = f"{native_args}, {bitmap_args}"
        ret_type = self.ret_type
        emitter = self.emitter
        if ret_type.is_unboxed or self.use_goto():
            # TODO: The Py_RETURN macros return the correct PyObject * with reference count
            # handling. Are they relevant?
            emitter.emit_line(
                "{}retval = {}{}({});".format(
                    emitter.ctype_spaced(ret_type), NATIVE_PREFIX, self.target_cname, native_args
                )
            )
            emitter.emit_lines(*self.cleanups)
            if ret_type.is_unboxed:
                emitter.emit_error_check("retval", ret_type, "return NULL;")
                emitter.emit_box("retval", "retbox", ret_type, declare_dest=True)
            emitter.emit_line("return {};".format("retbox" if ret_type.is_unboxed else "retval"))
        else:
            if not_implemented_handler and not isinstance(ret_type, RInstance):
                # The return value type may overlap with NotImplemented.
                emitter.emit_line(
                    "PyObject *retbox = {}{}({});".format(
                        NATIVE_PREFIX, self.target_cname, native_args
                    )
                )
                emitter.emit_lines(
                    "if (retbox == Py_NotImplemented) {",
                    not_implemented_handler,
                    "}",
                    "return retbox;",
                )
            else:
                # Boxed return value with no special handling: tail call.
                emitter.emit_line(f"return {NATIVE_PREFIX}{self.target_cname}({native_args});")
            # TODO: Tracebacks?

    def error(self) -> ErrorHandler:
        """Figure out how to deal with errors in the wrapper."""
        if self.cleanups or self.traceback_code:
            # We'll have a label at the end with error handling code.
            return GotoHandler("fail")
        else:
            # Nothing special needs to done to handle errors, so just return.
            return ReturnHandler("NULL")

    def emit_error_handling(self) -> None:
        """Emit error handling block at the end of the wrapper, if needed."""
        emitter = self.emitter
        if self.use_goto():
            emitter.emit_label("fail")
            emitter.emit_lines(*self.cleanups)
            if self.traceback_code:
                emitter.emit_line(self.traceback_code)
            emitter.emit_line("return NULL;")

    def finish(self) -> None:
        """Emit the closing brace of the wrapper function."""
        self.emitter.emit_line("}")

View file

@ -0,0 +1,301 @@
from __future__ import annotations
from typing import Final, TypeGuard
# Supported Python literal types. All tuple / frozenset items must have supported
# literal types as well, but we can't represent the type precisely.
LiteralValue = (
    str | bytes | int | bool | float | complex | tuple[object, ...] | frozenset[object] | None
)
def _is_literal_value(obj: object) -> TypeGuard[LiteralValue]:
return isinstance(obj, (str, bytes, int, float, complex, tuple, frozenset, type(None)))
# Some literals are singletons and handled specially (None, False and True).
# They occupy the first three slots of the literals array (see literal_index).
NUM_SINGLETONS: Final = 3
class Literals:
"""Collection of literal values used in a compilation group and related helpers."""
    def __init__(self) -> None:
        """Initialize empty per-type literal tables."""
        # Each dict maps value to literal index (0, 1, ...)
        self.str_literals: dict[str, int] = {}
        self.bytes_literals: dict[bytes, int] = {}
        self.int_literals: dict[int, int] = {}
        self.float_literals: dict[float, int] = {}
        self.complex_literals: dict[complex, int] = {}
        self.tuple_literals: dict[tuple[object, ...], int] = {}
        self.frozenset_literals: dict[frozenset[object], int] = {}
def record_literal(self, value: LiteralValue) -> None:
"""Ensure that the literal value is available in generated code."""
if value is None or value is True or value is False:
# These are special cased and always present
return
if isinstance(value, str):
str_literals = self.str_literals
if value not in str_literals:
str_literals[value] = len(str_literals)
elif isinstance(value, bytes):
bytes_literals = self.bytes_literals
if value not in bytes_literals:
bytes_literals[value] = len(bytes_literals)
elif isinstance(value, int):
int_literals = self.int_literals
if value not in int_literals:
int_literals[value] = len(int_literals)
elif isinstance(value, float):
float_literals = self.float_literals
if value not in float_literals:
float_literals[value] = len(float_literals)
elif isinstance(value, complex):
complex_literals = self.complex_literals
if value not in complex_literals:
complex_literals[value] = len(complex_literals)
elif isinstance(value, tuple):
tuple_literals = self.tuple_literals
if value not in tuple_literals:
for item in value:
assert _is_literal_value(item)
self.record_literal(item)
tuple_literals[value] = len(tuple_literals)
elif isinstance(value, frozenset):
frozenset_literals = self.frozenset_literals
if value not in frozenset_literals:
for item in value:
assert _is_literal_value(item)
self.record_literal(item)
frozenset_literals[value] = len(frozenset_literals)
else:
assert False, "invalid literal: %r" % value
def literal_index(self, value: LiteralValue) -> int:
"""Return the index to the literals array for given value."""
# The array contains first None and booleans, followed by all str values,
# followed by bytes values, etc.
if value is None:
return 0
elif value is False:
return 1
elif value is True:
return 2
n = NUM_SINGLETONS
if isinstance(value, str):
return n + self.str_literals[value]
n += len(self.str_literals)
if isinstance(value, bytes):
return n + self.bytes_literals[value]
n += len(self.bytes_literals)
if isinstance(value, int):
return n + self.int_literals[value]
n += len(self.int_literals)
if isinstance(value, float):
return n + self.float_literals[value]
n += len(self.float_literals)
if isinstance(value, complex):
return n + self.complex_literals[value]
n += len(self.complex_literals)
if isinstance(value, tuple):
return n + self.tuple_literals[value]
n += len(self.tuple_literals)
if isinstance(value, frozenset):
return n + self.frozenset_literals[value]
assert False, "invalid literal: %r" % value
def num_literals(self) -> int:
# The first three are for None, True and False
return (
NUM_SINGLETONS
+ len(self.str_literals)
+ len(self.bytes_literals)
+ len(self.int_literals)
+ len(self.float_literals)
+ len(self.complex_literals)
+ len(self.tuple_literals)
+ len(self.frozenset_literals)
)
# The following methods return the C encodings of literal values
# of different types
def encoded_str_values(self) -> list[bytes]:
return _encode_str_values(self.str_literals)
def encoded_int_values(self) -> list[bytes]:
return _encode_int_values(self.int_literals)
def encoded_bytes_values(self) -> list[bytes]:
return _encode_bytes_values(self.bytes_literals)
def encoded_float_values(self) -> list[str]:
return _encode_float_values(self.float_literals)
def encoded_complex_values(self) -> list[str]:
return _encode_complex_values(self.complex_literals)
def encoded_tuple_values(self) -> list[str]:
return self._encode_collection_values(self.tuple_literals)
def encoded_frozenset_values(self) -> list[str]:
return self._encode_collection_values(self.frozenset_literals)
def _encode_collection_values(
self, values: dict[tuple[object, ...], int] | dict[frozenset[object], int]
) -> list[str]:
"""Encode tuple/frozenset values into a C array.
The format of the result is like this:
<number of collections>
<length of the first collection>
<literal index of first item>
...
<literal index of last item>
<length of the second collection>
...
"""
value_by_index = {index: value for value, index in values.items()}
result = []
count = len(values)
result.append(str(count))
for i in range(count):
value = value_by_index[i]
result.append(str(len(value)))
for item in value:
assert _is_literal_value(item)
index = self.literal_index(item)
result.append(str(index))
return result
def _encode_str_values(values: dict[str, int]) -> list[bytes]:
    """Encode str literals into lines of length-prefixed UTF-8 chunks.

    Each output line starts with an item-count prefix; an empty bytes
    object terminates the list. Lines are kept to roughly 70 bytes.
    """
    by_index = {index: value for value, index in values.items()}
    result: list[bytes] = []
    current: list[bytes] = []
    width = 0
    for index in range(len(by_index)):
        encoded = format_str_literal(by_index[index])
        if width > 0 and width + len(encoded) > 70:
            # Line is full: flush it, prefixed with its item count.
            result.append(format_int(len(current)) + b"".join(current))
            current = []
            width = 0
        current.append(encoded)
        width += len(encoded)
    if current:
        result.append(format_int(len(current)) + b"".join(current))
    result.append(b"")
    return result
def _encode_bytes_values(values: dict[bytes, int]) -> list[bytes]:
    """Encode bytes literals into lines of length-prefixed chunks.

    Same line format as _encode_str_values: an item-count prefix per line
    and a terminating empty bytes object.
    """
    by_index = {index: value for value, index in values.items()}
    result: list[bytes] = []
    current: list[bytes] = []
    width = 0
    for index in range(len(by_index)):
        value = by_index[index]
        prefix = format_int(len(value))
        chunk_len = len(prefix) + len(value)
        if width > 0 and width + chunk_len > 70:
            result.append(format_int(len(current)) + b"".join(current))
            current = []
            width = 0
        current.append(prefix + value)
        width += chunk_len
    if current:
        result.append(format_int(len(current)) + b"".join(current))
    result.append(b"")
    return result
def format_int(n: int) -> bytes:
    """Encode a non-negative integer using a variable-length binary encoding.

    Digits are 7 bits each, most significant first. Every digit except the
    last has its high bit set to signal that more digits follow.
    """
    if n < 128:
        # Fits in a single 7-bit digit (this also covers n == 0).
        digits = [n]
    else:
        digits = []
        remaining = n
        while remaining > 0:
            digits.insert(0, remaining & 0x7F)
            remaining >>= 7
        # Mark every digit but the last as a continuation digit.
        for pos in range(len(digits) - 1):
            digits[pos] |= 0x80
    return bytes(digits)
def format_str_literal(s: str) -> bytes:
    """Encode a string as length-prefixed UTF-8 (lone surrogates allowed)."""
    encoded = s.encode("utf-8", errors="surrogatepass")
    return format_int(len(encoded)) + encoded
def _encode_int_values(values: dict[int, int]) -> list[bytes]:
    """Encode int literals into C strings.

    Values are stored in base 10, NUL-separated within a line; each line
    starts with an item-count prefix and an empty bytes object terminates
    the list.
    """
    by_index = {index: value for value, index in values.items()}
    result: list[bytes] = []
    current: list[bytes] = []
    width = 0
    for index in range(len(by_index)):
        encoded = b"%d" % by_index[index]
        if width > 0 and width + len(encoded) > 70:
            result.append(format_int(len(current)) + b"\0".join(current))
            current = []
            width = 0
        current.append(encoded)
        width += len(encoded)
    if current:
        result.append(format_int(len(current)) + b"\0".join(current))
    result.append(b"")
    return result
def float_to_c(x: float) -> str:
    """Return a C literal representation of a float value.

    Infinities and NaN map to the C99 <math.h> macros.
    """
    text = repr(x)
    special = {"inf": "INFINITY", "-inf": "-INFINITY", "nan": "NAN"}
    return special.get(text, text)
def _encode_float_values(values: dict[float, int]) -> list[str]:
    """Encode float literals into a C array.

    The result contains the number of values followed by each value in
    index order.
    """
    by_index = {index: value for value, index in values.items()}
    result = [str(len(by_index))]
    for index in range(len(by_index)):
        result.append(float_to_c(by_index[index]))
    return result
def _encode_complex_values(values: dict[complex, int]) -> list[str]:
    """Encode complex literals into a C array.

    The result contains the number of values followed by pairs of doubles
    (real part, then imaginary part) for each complex number.
    """
    by_index = {index: value for value, index in values.items()}
    result = [str(len(by_index))]
    for index in range(len(by_index)):
        value = by_index[index]
        result.append(float_to_c(value.real))
        result.append(float_to_c(value.imag))
    return result

View file

@ -0,0 +1,149 @@
from __future__ import annotations
import importlib.machinery
import sys
import sysconfig
from typing import Any, Final
from mypy.util import unnamed_function
# Naming prefixes used in generated C code.
PREFIX: Final = "CPyPy_"  # Python wrappers
NATIVE_PREFIX: Final = "CPyDef_"  # Native functions etc.
DUNDER_PREFIX: Final = "CPyDunder_"  # Wrappers for exposing dunder methods to the API
REG_PREFIX: Final = "cpy_r_"  # Registers
STATIC_PREFIX: Final = "CPyStatic_"  # Static variables (for literals etc.)
TYPE_PREFIX: Final = "CPyType_"  # Type object struct
MODULE_PREFIX: Final = "CPyModule_"  # Cached modules
TYPE_VAR_PREFIX: Final = "CPyTypeVar_"  # Type variables when using new-style Python 3.12 syntax
ATTR_PREFIX: Final = "_"  # Attributes
FAST_PREFIX: Final = "__mypyc_fast_"  # Optimized methods in non-extension classes

# Special attribute names used by generated code.
ENV_ATTR_NAME: Final = "__mypyc_env__"
NEXT_LABEL_ATTR_NAME: Final = "__mypyc_next_label__"
TEMP_ATTR_NAME: Final = "__mypyc_temp__"
LAMBDA_NAME: Final = "__mypyc_lambda__"
PROPSET_PREFIX: Final = "__mypyc_setter__"
SELF_NAME: Final = "__mypyc_self__"
GENERATOR_ATTRIBUTE_PREFIX: Final = "__mypyc_generator_attribute__"
CPYFUNCTION_NAME: Final = "__cpyfunction__"

TOP_LEVEL_NAME: Final = "__top_level__"  # Special function representing module top level

# Maximal number of subclasses for a class to trigger fast path in isinstance() checks.
FAST_ISINSTANCE_MAX_SUBCLASSES: Final = 2

# Size of size_t, if configured.
SIZEOF_SIZE_T_SYSCONFIG: Final = sysconfig.get_config_var("SIZEOF_SIZE_T")

# Fall back to deriving the pointer size from sys.maxsize if sysconfig
# doesn't provide SIZEOF_SIZE_T.
SIZEOF_SIZE_T: Final = (
    int(SIZEOF_SIZE_T_SYSCONFIG)
    if SIZEOF_SIZE_T_SYSCONFIG is not None
    else (sys.maxsize + 1).bit_length() // 8
)

IS_32_BIT_PLATFORM: Final = int(SIZEOF_SIZE_T) == 4

PLATFORM_SIZE: Final = 4 if IS_32_BIT_PLATFORM else 8

# Maximum value for a short tagged integer.
MAX_SHORT_INT: Final = 2 ** (8 * int(SIZEOF_SIZE_T) - 2) - 1

# Minimum value for a short tagged integer.
MIN_SHORT_INT: Final = -(MAX_SHORT_INT) - 1

# Maximum value for a short tagged integer represented as a C integer literal.
# Max short int we accept as a literal is based on 32-bit platforms,
# so that we can just always emit the same code.
#
# Note: Assume that the compiled code uses the same bit width as mypyc
MAX_LITERAL_SHORT_INT: Final = MAX_SHORT_INT
MIN_LITERAL_SHORT_INT: Final = -MAX_LITERAL_SHORT_INT - 1

# Description of the C type used to track the definedness of attributes and
# the presence of argument default values that have types with overlapping
# error values. Each tracked attribute/argument has a dedicated bit in the
# relevant bitmap.
BITMAP_TYPE: Final = "uint32_t"
BITMAP_BITS: Final = 32

# Runtime C library files that are always included (some ops may bring
# extra dependencies via mypyc.ir.SourceDep)
RUNTIME_C_FILES: Final = [
    "init.c",
    "getargs.c",
    "getargsfast.c",
    "int_ops.c",
    "float_ops.c",
    "str_ops.c",
    "bytes_ops.c",
    "list_ops.c",
    "dict_ops.c",
    "set_ops.c",
    "tuple_ops.c",
    "exc_ops.c",
    "misc_ops.c",
    "generic_ops.c",
    "pythonsupport.c",
    "function_wrapper.c",
]

# Python 3.12 introduced immortal objects, specified via a special reference count
# value. The reference counts of immortal objects are normally not modified, but it's
# not strictly wrong to modify them. See PEP 683 for more information, but note that
# some details in the PEP are out of date.
HAVE_IMMORTAL: Final = sys.version_info >= (3, 12)

# Are we running on a free-threaded build (GIL disabled)? This implies that
# we are on Python 3.13 or later.
IS_FREE_THREADED: Final = bool(sysconfig.get_config_var("Py_GIL_DISABLED"))

# The file extension suffix for C extension modules on the current platform
# (e.g. ".cpython-312-x86_64-linux-gnu.so" or ".pyd").
_EXT_SUFFIXES: Final = importlib.machinery.EXTENSION_SUFFIXES
EXT_SUFFIX: Final = _EXT_SUFFIXES[0] if _EXT_SUFFIXES else ".so"

# Alias for JSON-compatible dicts used throughout serialization code.
JsonDict = dict[str, Any]
def shared_lib_name(group_name: str) -> str:
    """Given a group name, return the actual name of its extension module.

    (This just adds a suffix to the final component.)
    """
    return group_name + "__mypyc"
def short_name(name: str) -> str:
    """Strip a leading 'builtins.' qualifier from a fullname, if present.

    Names from other modules are returned unchanged.
    """
    # str.removeprefix is a no-op when the prefix is absent, replacing the
    # startswith() check plus an easy-to-break hard-coded slice offset (9).
    return name.removeprefix("builtins.")
def get_id_from_name(name: str, fullname: str, line: int) -> str:
    """Create a unique id for a function.

    This creates an id that is unique for any given function definition,
    so that it can be used as a dictionary key. This is usually the
    fullname of the function, but functions named '_' may occur many
    times, so the line number is appended to keep their ids distinct.
    """
    if not unnamed_function(name):
        return fullname
    return f"{fullname}.{line}"
def short_id_from_name(func_name: str, shortname: str, line: int | None) -> str:
    """Like get_id_from_name, but based on the short name.

    For unnamed ('_') functions a line number is required to keep the
    result unique.
    """
    if not unnamed_function(func_name):
        return shortname
    assert line is not None
    return f"{shortname}.{line}"
def bitmap_name(index: int) -> str:
    """Name of the Nth attribute-definedness bitmap field.

    The first bitmap has no numeric suffix; later ones are numbered from 2.
    """
    return "__bitmap" if index == 0 else f"__bitmap{index + 1}"

View file

@ -0,0 +1,32 @@
from __future__ import annotations
import sys
import traceback
from collections.abc import Iterator
from contextlib import contextmanager
from typing import NoReturn
@contextmanager
def catch_errors(module_path: str, line: int) -> Iterator[None]:
    """Context manager that turns any exception into a crash report.

    Used by test code: crash_report prints a combined traceback and exits
    with status 2 instead of letting the exception propagate normally.
    """
    try:
        yield
    except Exception:
        crash_report(module_path, line)
def crash_report(module_path: str, line: int) -> NoReturn:
    """Print a combined traceback for an internal error and exit(2).

    Adapted from report_internal_error in mypy. Must be called from inside
    an 'except' block so the active exception can be inspected.
    """
    err = sys.exc_info()[1]
    stack = traceback.extract_stack()[:-4]
    # Excise all the traceback from the test runner
    for idx, frame in enumerate(stack):
        if frame.name == "pytest_runtest_call":
            stack = stack[idx + 1 :]
            break
    exc_frames = traceback.extract_tb(sys.exc_info()[2])[1:]
    print("Traceback (most recent call last):")
    for entry in traceback.format_list(stack + exc_frames):
        print(entry.rstrip("\n"))
    print(f"{module_path}:{line}: {type(err).__name__}: {err}")
    raise SystemExit(2)

View file

@ -0,0 +1,32 @@
from __future__ import annotations
import mypy.errors
from mypy.options import Options
class Errors:
    """Collects mypyc diagnostics and forwards them to mypy's error machinery.

    Tracks running counts of errors and warnings so callers can decide
    whether compilation failed.
    """

    def __init__(self, options: Options) -> None:
        self.num_errors = 0
        self.num_warnings = 0
        self._errors = mypy.errors.Errors(options, hide_error_codes=True)

    def _report(self, msg: str, path: str, line: int, severity: str) -> None:
        # All severities funnel through the same mypy reporting path.
        self._errors.set_file(path, None, self._errors.options)
        self._errors.report(line, None, msg, severity=severity)

    def error(self, msg: str, path: str, line: int) -> None:
        self._report(msg, path, line, "error")
        self.num_errors += 1

    def note(self, msg: str, path: str, line: int) -> None:
        self._report(msg, path, line, "note")

    def warning(self, msg: str, path: str, line: int) -> None:
        self._report(msg, path, line, "warning")
        self.num_warnings += 1

    def new_messages(self) -> list[str]:
        return self._errors.new_messages()

    def flush_errors(self) -> None:
        for message in self.new_messages():
            print(message)

View file

@ -0,0 +1,550 @@
"""Intermediate representation of classes."""
from __future__ import annotations
from typing import NamedTuple
from mypyc.common import PROPSET_PREFIX, JsonDict
from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg
from mypyc.ir.ops import DeserMaps, Value
from mypyc.ir.rtypes import RInstance, RType, deserialize_type, object_rprimitive
from mypyc.namegen import NameGenerator, exported_name
# Some notes on the vtable layout: Each concrete class has a vtable
# that contains function pointers for its methods. So that subclasses
# may be efficiently used when their parent class is expected, the
# layout of child vtables must be an extension of their base class's
# vtable.
#
# This makes multiple inheritance tricky, since obviously we cannot be
# an extension of multiple parent classes. We solve this by requiring
# all but one parent to be "traits", which we can operate on in a
# somewhat less efficient way. For each trait implemented by a class,
# we generate a separate vtable for the methods in that trait.
# We then store an array of (trait type, trait vtable) pointers alongside
# a class's main vtable. When we want to call a trait method, we
# (at runtime!) search the array of trait vtables to find the correct one,
# then call through it.
# Trait vtables additionally need entries for attribute getters and setters,
# since they can't always be in the same location.
#
# To keep down the number of indirections necessary, we store the
# array of trait vtables in the memory *before* the class vtable, and
# search it backwards. (This is a trick we can only do once---there
# are only two directions to store data in---but I don't think we'll
# need it again.)
# There are some tricks we could try in the future to store the trait
# vtables inline in the trait table (which would cut down one indirection),
# but this seems good enough for now.
#
# As an example:
# Imagine that we have a class B that inherits from a concrete class A
# and traits T1 and T2, and that A has methods foo() and
# bar() and B overrides bar() with a more specific type.
# Then B's vtable will look something like:
#
# T1 type object
# ptr to B's T1 trait vtable
# T2 type object
# ptr to B's T2 trait vtable
# -> | A.foo
# | Glue function that converts between A.bar's type and B.bar
# B.bar
# B.baz
#
# The arrow points to the "start" of the vtable (what vtable pointers
# point to) and the bars indicate which parts correspond to the parent
# class A's vtable layout.
#
# Classes that allow interpreted code to subclass them also have a
# "shadow vtable" that contains implementations that delegate to
# making a pycall, so that overridden methods in interpreted children
# will be called. (A better strategy could dynamically generate these
# vtables based on which methods are overridden in the children.)
# Descriptions of method and attribute entries in class vtables.
# The 'cls' field is the class that the method/attr was defined in,
# which might be a parent class.
# The 'shadow_method', if present, contains the method that should be
# placed in the class's shadow vtable (if it has one).
class VTableMethod(NamedTuple):
    """One method entry in a class vtable.

    'cls' is the class the method/attr was defined in (may be a parent
    class); 'shadow_method', if present, is the implementation placed in
    the class's shadow vtable (used when interpreted subclassing is
    allowed).
    """

    cls: "ClassIR"  # noqa: UP037
    name: str
    method: FuncIR
    shadow_method: FuncIR | None


# A vtable is simply an ordered list of method entries.
VTableEntries = list[VTableMethod]
class ClassIR:
"""Intermediate representation of a class.
This also describes the runtime structure of native instances.
"""
    def __init__(
        self,
        name: str,
        module_name: str,
        is_trait: bool = False,
        is_generated: bool = False,
        is_abstract: bool = False,
        is_ext_class: bool = True,
        is_final_class: bool = False,
    ) -> None:
        """Initialize the IR for a class with mostly-empty metadata.

        Most attributes are filled in later by IR building passes (e.g.
        mro, vtables, attribute/method maps).
        """
        self.name = name
        self.module_name = module_name
        self.is_trait = is_trait
        self.is_generated = is_generated
        self.is_abstract = is_abstract
        self.is_ext_class = is_ext_class
        self.is_final_class = is_final_class
        # An augmented class has additional methods separate from what mypyc generates.
        # Right now the only one is dataclasses.
        self.is_augmented = False
        # Does this inherit from a Python class?
        self.inherits_python = False
        # Do instances of this class have __dict__?
        self.has_dict = False
        # Do we allow interpreted subclasses? Derived from a mypyc_attr.
        self.allow_interpreted_subclasses = False
        # Does this class need getseters to be generated for its attributes? (getseters are also
        # added if is_generated is False)
        self.needs_getseters = False
        # Is this class declared as serializable (supports copy.copy
        # and pickle) using @mypyc_attr(serializable=True)?
        #
        # Additionally, any class with this attribute False but with
        # an __init__ that can be called without any arguments is
        # *implicitly serializable*. In this case __init__ will be
        # called during deserialization without arguments. If this is
        # True, we match Python semantics and __init__ won't be called
        # during deserialization.
        #
        # This impacts also all subclasses. Use is_serializable() to
        # also consider base classes.
        self._serializable = False
        # If this a subclass of some built-in python class, the name
        # of the object for that class. We currently only support this
        # in a few ad-hoc cases.
        self.builtin_base: str | None = None
        # Default empty constructor
        self.ctor = FuncDecl(name, None, module_name, FuncSignature([], RInstance(self)))
        # Declare setup method that allocates and initializes an object. type is the
        # type of the class being initialized, which could be another class if there
        # is an interpreted subclass.
        # TODO: Make it a regular method and generate its body in IR
        self.setup = FuncDecl(
            "__mypyc__" + name + "_setup",
            None,
            module_name,
            FuncSignature([RuntimeArg("type", object_rprimitive)], RInstance(self)),
        )
        # Attributes defined in the class (not inherited)
        self.attributes: dict[str, RType] = {}
        # Deletable attributes
        self.deletable: list[str] = []
        # We populate method_types with the signatures of every method before
        # we generate methods, and we rely on this information being present.
        self.method_decls: dict[str, FuncDecl] = {}
        # Map of methods that are actually present in an extension class
        self.methods: dict[str, FuncIR] = {}
        # Glue methods for boxing/unboxing when a class changes the type
        # while overriding a method. Maps from (parent class overridden, method)
        # to IR of glue method.
        self.glue_methods: dict[tuple[ClassIR, str], FuncIR] = {}
        # Properties are accessed like attributes, but have behavior like method calls.
        # They don't belong in the methods dictionary, since we don't want to expose them to
        # Python's method API. But we want to put them into our own vtable as methods, so that
        # they are properly handled and overridden. The property dictionary values are a tuple
        # containing a property getter and an optional property setter.
        self.properties: dict[str, tuple[FuncIR, FuncIR | None]] = {}
        # We generate these in prepare_class_def so that we have access to them when generating
        # other methods and properties that rely on these types.
        self.property_types: dict[str, RType] = {}

        self.vtable: dict[str, int] | None = None
        self.vtable_entries: VTableEntries = []
        self.trait_vtables: dict[ClassIR, VTableEntries] = {}
        # N.B: base might not actually quite be the direct base.
        # It is the nearest concrete base, but we allow a trait in between.
        self.base: ClassIR | None = None
        self.traits: list[ClassIR] = []
        # Supply a working mro for most generated classes. Real classes will need to
        # fix it up.
        self.mro: list[ClassIR] = [self]
        # base_mro is the chain of concrete (non-trait) ancestors
        self.base_mro: list[ClassIR] = [self]
        # Direct subclasses of this class (use subclasses() to also include non-direct ones)
        # None if separate compilation prevents this from working.
        #
        # Often it's better to use has_no_subclasses() or subclasses() instead.
        self.children: list[ClassIR] | None = []
        # Instance attributes that are initialized in the class body.
        self.attrs_with_defaults: set[str] = set()
        # Attributes that are always initialized in __init__ or class body
        # (inferred in mypyc.analysis.attrdefined using interprocedural analysis).
        # These can never raise AttributeError when accessed. If an attribute
        # is *not* always initialized, we normally use the error value for
        # an undefined value. If the attribute byte has an overlapping error value
        # (the error_overlap attribute is true for the RType), we use a bitmap
        # to track if the attribute is defined instead (see bitmap_attrs).
        self._always_initialized_attrs: set[str] = set()
        # Attributes that are sometimes initialized in __init__
        self._sometimes_initialized_attrs: set[str] = set()
        # If True, __init__ can make 'self' visible to unanalyzed/arbitrary code
        self.init_self_leak = False
        # Definedness of these attributes is backed by a bitmap. Index in the list
        # indicates the bit number. Includes inherited attributes. We need the
        # bitmap for types such as native ints (i64 etc.) that can't have a dedicated
        # error value that doesn't overlap a valid value. The bitmap is used if the
        # value of an attribute is the same as the error value.
        self.bitmap_attrs: list[str] = []
        # If this is a generator environment class, what is the actual method for it
        self.env_user_function: FuncIR | None = None
        # If True, keep one freed, cleared instance available for immediate reuse to
        # speed up allocations. This helps if many objects are freed quickly, before
        # other instances of the same class are allocated. This is effectively a
        # per-type free "list" of up to length 1.
        self.reuse_freed_instance = False
        # If True, the class does not participate in cyclic garbage collection.
        # This can improve performance but is only safe if instances can never
        # be part of reference cycles. Derived from @mypyc_attr(acyclic=True).
        self.is_acyclic = False
        # Is this a class inheriting from enum.Enum? Such classes can be special-cased.
        self.is_enum = False
        # Name of the function if this a callable class representing a coroutine.
        self.coroutine_name: str | None = None
def __repr__(self) -> str:
return (
"ClassIR("
"name={self.name}, module_name={self.module_name}, "
"is_trait={self.is_trait}, is_generated={self.is_generated}, "
"is_abstract={self.is_abstract}, is_ext_class={self.is_ext_class}, "
"is_final_class={self.is_final_class}"
")".format(self=self)
)
@property
def fullname(self) -> str:
return f"{self.module_name}.{self.name}"
def real_base(self) -> ClassIR | None:
"""Return the actual concrete base class, if there is one."""
if len(self.mro) > 1 and not self.mro[1].is_trait:
return self.mro[1]
return None
def vtable_entry(self, name: str) -> int:
assert self.vtable is not None, "vtable not computed yet"
assert name in self.vtable, f"{self.name!r} has no attribute {name!r}"
return self.vtable[name]
def attr_details(self, name: str) -> tuple[RType, ClassIR]:
for ir in self.mro:
if name in ir.attributes:
return ir.attributes[name], ir
if name in ir.property_types:
return ir.property_types[name], ir
raise KeyError(f"{self.name!r} has no attribute {name!r}")
def attr_type(self, name: str) -> RType:
return self.attr_details(name)[0]
def method_decl(self, name: str) -> FuncDecl:
for ir in self.mro:
if name in ir.method_decls:
return ir.method_decls[name]
raise KeyError(f"{self.name!r} has no attribute {name!r}")
def method_sig(self, name: str) -> FuncSignature:
return self.method_decl(name).sig
def has_method(self, name: str) -> bool:
try:
self.method_decl(name)
except KeyError:
return False
return True
def is_method_final(self, name: str) -> bool:
subs = self.subclasses()
if subs is None:
return self.is_final_class
if self.has_method(name):
method_decl = self.method_decl(name)
for subc in subs:
if subc.method_decl(name) != method_decl:
return False
return True
else:
return not any(subc.has_method(name) for subc in subs)
def has_attr(self, name: str) -> bool:
try:
self.attr_type(name)
except KeyError:
return False
return True
def is_deletable(self, name: str) -> bool:
return any(name in ir.deletable for ir in self.mro)
def is_always_defined(self, name: str) -> bool:
if self.is_deletable(name):
return False
return name in self._always_initialized_attrs
def name_prefix(self, names: NameGenerator) -> str:
return names.private_name(self.module_name, self.name)
def struct_name(self, names: NameGenerator) -> str:
return f"{exported_name(self.fullname)}Object"
def get_method_and_class(
self, name: str, *, prefer_method: bool = False
) -> tuple[FuncIR, ClassIR] | None:
for ir in self.mro:
if name in ir.methods:
func_ir = ir.methods[name]
if not prefer_method and func_ir.decl.implicit:
# This is an implicit accessor, so there is also an attribute definition
# which the caller prefers. This happens if an attribute overrides a
# property.
return None
return func_ir, ir
return None
def get_method(self, name: str, *, prefer_method: bool = False) -> FuncIR | None:
res = self.get_method_and_class(name, prefer_method=prefer_method)
return res[0] if res else None
def has_method_decl(self, name: str) -> bool:
return any(name in ir.method_decls for ir in self.mro)
def has_no_subclasses(self) -> bool:
return self.children == [] and not self.allow_interpreted_subclasses
def subclasses(self) -> set[ClassIR] | None:
"""Return all subclasses of this class, both direct and indirect.
Return None if it is impossible to identify all subclasses, for example
because we are performing separate compilation.
"""
if self.children is None or self.allow_interpreted_subclasses:
return None
result = set(self.children)
for child in self.children:
if child.children:
child_subs = child.subclasses()
if child_subs is None:
return None
result.update(child_subs)
return result
def concrete_subclasses(self) -> list[ClassIR] | None:
"""Return all concrete (i.e. non-trait and non-abstract) subclasses.
Include both direct and indirect subclasses. Place classes with no children first.
"""
subs = self.subclasses()
if subs is None:
return None
concrete = {c for c in subs if not (c.is_trait or c.is_abstract)}
# We place classes with no children first because they are more likely
# to appear in various isinstance() checks. We then sort leaves by name
# to get stable order.
return sorted(concrete, key=lambda c: (len(c.children or []), c.name))
def is_serializable(self) -> bool:
return any(ci._serializable for ci in self.mro)
    def serialize(self) -> JsonDict:
        """Serialize this ClassIR into a JSON-compatible dict.

        Cross-references to functions and other classes are stored as
        function ids / class fullnames that deserialize() resolves again
        via the DeserMaps context.
        """
        return {
            "name": self.name,
            "module_name": self.module_name,
            "is_trait": self.is_trait,
            "is_ext_class": self.is_ext_class,
            "is_abstract": self.is_abstract,
            "is_generated": self.is_generated,
            "is_augmented": self.is_augmented,
            "is_final_class": self.is_final_class,
            "inherits_python": self.inherits_python,
            "has_dict": self.has_dict,
            "allow_interpreted_subclasses": self.allow_interpreted_subclasses,
            "needs_getseters": self.needs_getseters,
            "_serializable": self._serializable,
            "builtin_base": self.builtin_base,
            "ctor": self.ctor.serialize(),
            # We serialize dicts as lists to ensure order is preserved
            "attributes": [(k, t.serialize()) for k, t in self.attributes.items()],
            # We try to serialize a name reference, but if the decl isn't in methods
            # then we can't be sure that will work so we serialize the whole decl.
            "method_decls": [
                (k, d.id if k in self.methods else d.serialize())
                for k, d in self.method_decls.items()
            ],
            # We serialize method fullnames out and put methods in a separate dict
            "methods": [(k, m.id) for k, m in self.methods.items()],
            "glue_methods": [
                ((cir.fullname, k), m.id) for (cir, k), m in self.glue_methods.items()
            ],
            # We serialize properties and property_types separately out of an
            # abundance of caution about preserving dict ordering...
            "property_types": [(k, t.serialize()) for k, t in self.property_types.items()],
            "properties": list(self.properties),
            "vtable": self.vtable,
            "vtable_entries": serialize_vtable(self.vtable_entries),
            "trait_vtables": [
                (cir.fullname, serialize_vtable(v)) for cir, v in self.trait_vtables.items()
            ],
            # References to class IRs are all just names
            "base": self.base.fullname if self.base else None,
            "traits": [cir.fullname for cir in self.traits],
            "mro": [cir.fullname for cir in self.mro],
            "base_mro": [cir.fullname for cir in self.base_mro],
            "children": (
                [cir.fullname for cir in self.children] if self.children is not None else None
            ),
            "deletable": self.deletable,
            "attrs_with_defaults": sorted(self.attrs_with_defaults),
            "_always_initialized_attrs": sorted(self._always_initialized_attrs),
            "_sometimes_initialized_attrs": sorted(self._sometimes_initialized_attrs),
            "init_self_leak": self.init_self_leak,
            "env_user_function": self.env_user_function.id if self.env_user_function else None,
            "reuse_freed_instance": self.reuse_freed_instance,
            # NOTE: the key is "is_coroutine" although it stores the coroutine
            # *name* (or None); kept as-is for serialization compatibility.
            "is_coroutine": self.coroutine_name,
        }
    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ClassIR:
        """Populate a ClassIR from serialized JSON data and return it.

        The (skeleton) ClassIR instance must already be registered in
        ctx.classes before this is called (see deserialize_modules); this
        fills in its attributes in place.
        """
        fullname = data["module_name"] + "." + data["name"]
        assert fullname in ctx.classes, "Class %s not in deser class map" % fullname
        ir = ctx.classes[fullname]
        ir.is_trait = data["is_trait"]
        ir.is_generated = data["is_generated"]
        ir.is_abstract = data["is_abstract"]
        ir.is_ext_class = data["is_ext_class"]
        ir.is_augmented = data["is_augmented"]
        ir.is_final_class = data["is_final_class"]
        ir.inherits_python = data["inherits_python"]
        ir.has_dict = data["has_dict"]
        ir.allow_interpreted_subclasses = data["allow_interpreted_subclasses"]
        ir.needs_getseters = data["needs_getseters"]
        ir._serializable = data["_serializable"]
        ir.builtin_base = data["builtin_base"]
        ir.ctor = FuncDecl.deserialize(data["ctor"], ctx)
        ir.attributes = {k: deserialize_type(t, ctx) for k, t in data["attributes"]}
        # Method decls were serialized either as a bare id string (when the
        # method is also in "methods") or as a full FuncDecl dict.
        ir.method_decls = {
            k: ctx.functions[v].decl if isinstance(v, str) else FuncDecl.deserialize(v, ctx)
            for k, v in data["method_decls"]
        }
        ir.methods = {k: ctx.functions[v] for k, v in data["methods"]}
        ir.glue_methods = {
            (ctx.classes[c], k): ctx.functions[v] for (c, k), v in data["glue_methods"]
        }
        ir.property_types = {k: deserialize_type(t, ctx) for k, t in data["property_types"]}
        # Properties were serialized as just the getter names; the getter and
        # (optional) setter FuncIRs are looked up in the restored methods dict.
        ir.properties = {
            k: (ir.methods[k], ir.methods.get(PROPSET_PREFIX + k)) for k in data["properties"]
        }
        ir.vtable = data["vtable"]
        ir.vtable_entries = deserialize_vtable(data["vtable_entries"], ctx)
        ir.trait_vtables = {
            ctx.classes[k]: deserialize_vtable(v, ctx) for k, v in data["trait_vtables"]
        }
        base = data["base"]
        ir.base = ctx.classes[base] if base else None
        ir.traits = [ctx.classes[s] for s in data["traits"]]
        ir.mro = [ctx.classes[s] for s in data["mro"]]
        ir.base_mro = [ctx.classes[s] for s in data["base_mro"]]
        # "children" may be serialized as None; the 'and' keeps None as None
        # while mapping fullnames back to ClassIRs for a real list.
        ir.children = data["children"] and [ctx.classes[s] for s in data["children"]]
        ir.deletable = data["deletable"]
        ir.attrs_with_defaults = set(data["attrs_with_defaults"])
        ir._always_initialized_attrs = set(data["_always_initialized_attrs"])
        ir._sometimes_initialized_attrs = set(data["_sometimes_initialized_attrs"])
        ir.init_self_leak = data["init_self_leak"]
        ir.env_user_function = (
            ctx.functions[data["env_user_function"]] if data["env_user_function"] else None
        )
        ir.reuse_freed_instance = data["reuse_freed_instance"]
        # .get() with a default -- presumably for compatibility with data
        # serialized before this field existed; TODO confirm.
        ir.is_acyclic = data.get("is_acyclic", False)
        ir.is_enum = data["is_enum"]
        # NOTE: the JSON key is "is_coroutine" but it carries coroutine_name
        # (see serialize()).
        ir.coroutine_name = data["is_coroutine"]
        return ir
class NonExtClassInfo:
    """Information needed to construct a non-extension class (Python class).

    Bundles the values used to build the class object: its dictionary, the
    tuple of base classes, the annotations dictionary, and the metaclass.
    """

    def __init__(self, dict: Value, bases: Value, anns: Value, metaclass: Value) -> None:
        # Parameter names are kept as-is (callers may pass them by keyword).
        self.metaclass = metaclass
        self.anns = anns
        self.bases = bases
        self.dict = dict
def serialize_vtable_entry(entry: VTableMethod) -> JsonDict:
    """Convert one vtable entry into JSON-serializable form.

    Classes are referenced by fullname and methods by their stable ids.
    """
    shadow = entry.shadow_method
    return {
        ".class": "VTableMethod",
        "cls": entry.cls.fullname,
        "name": entry.name,
        "method": entry.method.decl.id,
        "shadow_method": shadow.decl.id if shadow is not None else None,
    }
def serialize_vtable(vtable: VTableEntries) -> list[JsonDict]:
    """Serialize every entry of a vtable, preserving order."""
    result: list[JsonDict] = []
    for entry in vtable:
        result.append(serialize_vtable_entry(entry))
    return result
def deserialize_vtable_entry(data: JsonDict, ctx: DeserMaps) -> VTableMethod:
    """Reconstruct a VTableMethod from its serialized form."""
    kind = data[".class"]
    assert kind == "VTableMethod", "Bogus vtable .class: %s" % kind
    shadow = data["shadow_method"]
    return VTableMethod(
        ctx.classes[data["cls"]],
        data["name"],
        ctx.functions[data["method"]],
        ctx.functions[shadow] if shadow else None,
    )
def deserialize_vtable(data: list[JsonDict], ctx: DeserMaps) -> VTableEntries:
    """Reconstruct a full vtable from its serialized entries."""
    entries: list[VTableMethod] = []
    for item in data:
        entries.append(deserialize_vtable_entry(item, ctx))
    return entries
def all_concrete_classes(class_ir: ClassIR) -> list[ClassIR] | None:
    """Return all concrete classes among the class itself and its subclasses.

    Return None when the set of subclasses is not known.
    """
    result = class_ir.concrete_subclasses()
    if result is None:
        return None
    if class_ir.is_abstract or class_ir.is_trait:
        return result
    # The class itself is concrete too; add it in place, as before.
    result.append(class_ir)
    return result

View file

@ -0,0 +1,59 @@
from typing import Final
class Capsule:
    """Defines a C extension capsule that a primitive may require."""

    def __init__(self, name: str) -> None:
        # Module fullname, e.g. 'librt.base64'
        self.name: Final = name

    def __repr__(self) -> str:
        return f"Capsule(name={self.name!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Capsule):
            return False
        return other.name == self.name

    def __hash__(self) -> int:
        # Tag with the class name so Capsule and SourceDep hashes don't
        # collide for equal payload strings.
        return hash(("Capsule", self.name))
class SourceDep:
    """Defines a C source file that a primitive may require.

    Each source file must also have a corresponding .h file (replace .c with .h)
    that gets implicitly #included if the source is used.
    """

    def __init__(self, path: str) -> None:
        # Relative path from mypyc/lib-rt, e.g. 'bytes_extra_ops.c'
        self.path: Final = path

    def __repr__(self) -> str:
        return f"SourceDep(path={self.path!r})"

    def __eq__(self, other: object) -> bool:
        return isinstance(other, SourceDep) and self.path == other.path

    def __hash__(self) -> int:
        # Tag with the class name so SourceDep and Capsule hashes don't
        # collide for equal payload strings.
        return hash(("SourceDep", self.path))

    def get_header(self) -> str:
        """Return the path of the matching header file (trailing .c -> .h).

        Only the final ".c" suffix is rewritten; a ".c" elsewhere in the
        path (e.g. inside a directory name) is left alone. The previous
        str.replace() implementation rewrote every occurrence.
        """
        if self.path.endswith(".c"):
            return self.path[: -len(".c")] + ".h"
        # Not a .c file: nothing to rewrite.
        return self.path
# A module dependency is either a capsule or an extra C source file.
Dependency = Capsule | SourceDep

# Capsule dependencies on librt modules.
LIBRT_STRINGS: Final = Capsule("librt.strings")
LIBRT_BASE64: Final = Capsule("librt.base64")
LIBRT_VECS: Final = Capsule("librt.vecs")
LIBRT_TIME: Final = Capsule("librt.time")

# Extra C source file dependencies (paths relative to mypyc/lib-rt).
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c")
STRING_WRITER_EXTRA_OPS: Final = SourceDep("stringwriter_extra_ops.c")
BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c")
STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c")
VECS_EXTRA_OPS: Final = SourceDep("vecs_extra_ops.c")

View file

@ -0,0 +1,484 @@
"""Intermediate representation of functions."""
from __future__ import annotations
import inspect
from collections.abc import Sequence
from typing import Final
from mypy.nodes import ARG_POS, ArgKind, Block, FuncDef
from mypyc.common import BITMAP_BITS, JsonDict, bitmap_name, get_id_from_name, short_id_from_name
from mypyc.ir.ops import (
Assign,
AssignMulti,
BasicBlock,
Box,
ControlOp,
DeserMaps,
Float,
Integer,
LoadAddress,
LoadLiteral,
Register,
TupleSet,
Value,
)
from mypyc.ir.rtypes import (
RType,
bitmap_rprimitive,
deserialize_type,
is_bool_rprimitive,
is_none_rprimitive,
)
from mypyc.namegen import NameGenerator
class RuntimeArg:
    """Description of a function argument in IR.

    Argument kind is one of ARG_* constants defined in mypy.nodes.
    """

    def __init__(
        self, name: str, typ: RType, kind: ArgKind = ARG_POS, pos_only: bool = False
    ) -> None:
        self.name = name
        self.type = typ
        self.kind = kind
        self.pos_only = pos_only

    @property
    def optional(self) -> bool:
        """Return whether the argument kind is optional (has a default)."""
        return self.kind.is_optional()

    def __repr__(self) -> str:
        return (
            f"RuntimeArg(name={self.name}, type={self.type}, "
            f"optional={self.optional!r}, pos_only={self.pos_only!r})"
        )

    def serialize(self) -> JsonDict:
        """Convert to JSON-serializable form."""
        data: JsonDict = {
            "name": self.name,
            "type": self.type.serialize(),
            "kind": int(self.kind.value),
            "pos_only": self.pos_only,
        }
        return data

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> RuntimeArg:
        """Reconstruct from the output of serialize()."""
        return RuntimeArg(
            data["name"],
            deserialize_type(data["type"], ctx),
            ArgKind(data["kind"]),
            data["pos_only"],
        )
class FuncSignature:
    """Signature of a function in IR."""

    # TODO: Track if method?

    def __init__(self, args: Sequence[RuntimeArg], ret_type: RType) -> None:
        self.args = tuple(args)
        self.ret_type = ret_type
        # Bitmap arguments are used to mark default values for arguments that
        # have types with overlapping error values.
        self.num_bitmap_args = num_bitmap_args(self.args)
        if self.num_bitmap_args:
            extra = [
                RuntimeArg(bitmap_name(i), bitmap_rprimitive, pos_only=True)
                for i in range(self.num_bitmap_args)
            ]
            # Synthetic bitmap args go at the end, in reverse order.
            self.args = self.args + tuple(reversed(extra))

    def real_args(self) -> tuple[RuntimeArg, ...]:
        """Return arguments without any synthetic bitmap arguments."""
        if self.num_bitmap_args:
            return self.args[: -self.num_bitmap_args]
        return self.args

    def bound_sig(self) -> FuncSignature:
        """Return the signature with the first (self/cls) argument dropped.

        Bitmap args are excluded as well; the FuncSignature constructor
        recomputes them for the reduced argument list.
        """
        if self.num_bitmap_args:
            return FuncSignature(self.args[1 : -self.num_bitmap_args], self.ret_type)
        else:
            return FuncSignature(self.args[1:], self.ret_type)

    def __repr__(self) -> str:
        return f"FuncSignature(args={self.args!r}, ret={self.ret_type!r})"

    def serialize(self) -> JsonDict:
        """Convert to JSON-serializable form.

        Bitmap args are not serialized; they get recomputed when the
        signature is deserialized (see __init__).
        """
        if self.num_bitmap_args:
            args = self.args[: -self.num_bitmap_args]
        else:
            args = self.args
        return {"args": [t.serialize() for t in args], "ret_type": self.ret_type.serialize()}

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncSignature:
        """Reconstruct from the output of serialize()."""
        return FuncSignature(
            [RuntimeArg.deserialize(arg, ctx) for arg in data["args"]],
            deserialize_type(data["ret_type"], ctx),
        )
def num_bitmap_args(args: tuple[RuntimeArg, ...]) -> int:
    """Return the number of synthetic bitmap arguments needed for 'args'.

    One bit is needed per optional argument whose type has an overlapping
    error value; bits are packed BITMAP_BITS per bitmap arg, rounding up.
    """
    bits = sum(1 for arg in args if arg.type.error_overlap and arg.kind.is_optional())
    return (bits + (BITMAP_BITS - 1)) // BITMAP_BITS
# Values for FuncDecl.kind: how the callable is declared on its class (if any).
FUNC_NORMAL: Final = 0  # Plain function or ordinary method
FUNC_STATICMETHOD: Final = 1  # @staticmethod
FUNC_CLASSMETHOD: Final = 2  # @classmethod
class FuncDecl:
    """Declaration of a function in IR (without body or implementation).

    A function can be a regular module-level function, a method, a
    static method, a class method, or a property getter/setter.
    """

    def __init__(
        self,
        name: str,
        class_name: str | None,
        module_name: str,
        sig: FuncSignature,
        kind: int = FUNC_NORMAL,
        *,
        is_prop_setter: bool = False,
        is_prop_getter: bool = False,
        is_generator: bool = False,
        is_coroutine: bool = False,
        implicit: bool = False,
        internal: bool = False,
    ) -> None:
        self.name = name
        # None for module-level functions; otherwise the enclosing class name.
        self.class_name = class_name
        self.module_name = module_name
        self.sig = sig
        # One of FUNC_NORMAL, FUNC_STATICMETHOD, FUNC_CLASSMETHOD.
        self.kind = kind
        self.is_prop_setter = is_prop_setter
        self.is_prop_getter = is_prop_getter
        self.is_generator = is_generator
        self.is_coroutine = is_coroutine
        if class_name is None:
            self.bound_sig: FuncSignature | None = None
        else:
            if kind == FUNC_STATICMETHOD:
                # Static methods have no self/cls, so nothing to drop.
                self.bound_sig = sig
            else:
                self.bound_sig = sig.bound_sig()
        # If True, not present in the mypy AST and must be synthesized during irbuild
        # Currently only supported for property getters/setters
        self.implicit = implicit
        # If True, only direct C level calls are supported (no wrapper function)
        self.internal = internal
        # This is optional because this will be set to the line number when the corresponding
        # FuncIR is created
        self._line: int | None = None

    @property
    def line(self) -> int:
        # Asserts that the line has been set (by FuncIR creation).
        assert self._line is not None
        return self._line

    @line.setter
    def line(self, line: int) -> None:
        self._line = line

    @property
    def id(self) -> str:
        """Stable id used to reference this decl across serialization."""
        assert self.line is not None
        return get_id_from_name(self.name, self.fullname, self.line)

    @staticmethod
    def compute_shortname(class_name: str | None, name: str) -> str:
        """Return 'Class.name' for methods, plain 'name' otherwise."""
        return class_name + "." + name if class_name else name

    @property
    def shortname(self) -> str:
        return FuncDecl.compute_shortname(self.class_name, self.name)

    @property
    def fullname(self) -> str:
        return self.module_name + "." + self.shortname

    def cname(self, names: NameGenerator) -> str:
        """Return the C name to use for this function.

        NOTE(review): uses _line directly, which may still be None here --
        presumably short_id_from_name tolerates that; confirm.
        """
        partial_name = short_id_from_name(self.name, self.shortname, self._line)
        return names.private_name(self.module_name, partial_name)

    def serialize(self) -> JsonDict:
        """Convert to JSON-serializable form (signature included, no body)."""
        return {
            "name": self.name,
            "class_name": self.class_name,
            "module_name": self.module_name,
            "sig": self.sig.serialize(),
            "kind": self.kind,
            "is_prop_setter": self.is_prop_setter,
            "is_prop_getter": self.is_prop_getter,
            "is_generator": self.is_generator,
            "is_coroutine": self.is_coroutine,
            "implicit": self.implicit,
            "internal": self.internal,
        }

    # TODO: move this to FuncIR?
    @staticmethod
    def get_id_from_json(func_ir: JsonDict) -> str:
        """Get the id from the serialized FuncIR associated with this FuncDecl"""
        decl = func_ir["decl"]
        shortname = FuncDecl.compute_shortname(decl["class_name"], decl["name"])
        fullname = decl["module_name"] + "." + shortname
        return get_id_from_name(decl["name"], fullname, func_ir["line"])

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncDecl:
        """Reconstruct from the output of serialize()."""
        return FuncDecl(
            data["name"],
            data["class_name"],
            data["module_name"],
            FuncSignature.deserialize(data["sig"], ctx),
            data["kind"],
            is_prop_setter=data["is_prop_setter"],
            is_prop_getter=data["is_prop_getter"],
            is_generator=data["is_generator"],
            is_coroutine=data["is_coroutine"],
            implicit=data["implicit"],
            internal=data["internal"],
        )
class FuncIR:
    """Intermediate representation of a function with contextual information.

    Unlike FuncDecl, this includes the IR of the body (basic blocks).
    Most attributes simply delegate to the wrapped FuncDecl.
    """

    def __init__(
        self,
        decl: FuncDecl,
        arg_regs: list[Register],
        blocks: list[BasicBlock],
        line: int = -1,
        traceback_name: str | None = None,
    ) -> None:
        # Declaration of the function, including the signature
        self.decl = decl
        # Registers for all the arguments to the function
        self.arg_regs = arg_regs
        # Body of the function
        self.blocks = blocks
        # Record the line on the decl (FuncDecl._line starts out unset).
        self.decl.line = line
        # The name that should be displayed for tracebacks that
        # include this function. Function will be omitted from
        # tracebacks if None.
        self.traceback_name = traceback_name

    @property
    def line(self) -> int:
        """Line number of the function definition (stored on the decl)."""
        return self.decl.line

    @property
    def args(self) -> Sequence[RuntimeArg]:
        """All signature arguments (including synthetic bitmap args)."""
        return self.decl.sig.args

    @property
    def ret_type(self) -> RType:
        return self.decl.sig.ret_type

    @property
    def class_name(self) -> str | None:
        return self.decl.class_name

    @property
    def sig(self) -> FuncSignature:
        return self.decl.sig

    @property
    def name(self) -> str:
        return self.decl.name

    @property
    def fullname(self) -> str:
        return self.decl.fullname

    @property
    def id(self) -> str:
        """Stable id of the underlying decl."""
        return self.decl.id

    @property
    def internal(self) -> bool:
        return self.decl.internal

    def cname(self, names: NameGenerator) -> str:
        return self.decl.cname(names)

    def __repr__(self) -> str:
        if self.class_name:
            return f"<FuncIR {self.class_name}.{self.name}>"
        else:
            return f"<FuncIR {self.name}>"

    def serialize(self) -> JsonDict:
        # We don't include blocks in the serialized version
        return {
            "decl": self.decl.serialize(),
            "line": self.line,
            "traceback_name": self.traceback_name,
        }

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> FuncIR:
        # Bodies are not serialized, so arg_regs and blocks come back empty.
        return FuncIR(
            FuncDecl.deserialize(data["decl"], ctx), [], [], data["line"], data["traceback_name"]
        )
# Sentinel FuncDef used where no real mypy AST node is available.
INVALID_FUNC_DEF: Final = FuncDef("<INVALID_FUNC_DEF>", [], Block([]))
def all_values(args: list[Register], blocks: list[BasicBlock]) -> list[Value]:
    """Return the set of all values that may be initialized in the blocks.

    This omits registers that are only read. The result is ordered:
    arguments first, then values in the order they appear in the blocks.
    """
    # Argument registers are considered initialized up front.
    values: list[Value] = list(args)
    seen_registers = set(args)

    for block in blocks:
        for op in block.ops:
            # Control ops don't produce values.
            if not isinstance(op, ControlOp):
                if isinstance(op, (Assign, AssignMulti)):
                    # Record each assigned-to register only once.
                    if op.dest not in seen_registers:
                        values.append(op.dest)
                        seen_registers.add(op.dest)
                elif op.is_void:
                    continue
                else:
                    # If we take the address of a register, it might get initialized.
                    if (
                        isinstance(op, LoadAddress)
                        and isinstance(op.src, Register)
                        and op.src not in seen_registers
                    ):
                        values.append(op.src)
                        seen_registers.add(op.src)
                    # Non-void ops are values themselves.
                    values.append(op)

    return values
def all_values_full(args: list[Register], blocks: list[BasicBlock]) -> list[Value]:
    """Return set of all values that are initialized or accessed.

    Like all_values(), but also includes registers that are only read
    (e.g. uninitialized registers accessed before assignment).
    """
    values: list[Value] = list(args)
    seen_registers = set(args)

    for block in blocks:
        for op in block.ops:
            for source in op.sources():
                # Look for uninitialized registers that are accessed. Ignore
                # non-registers since we don't allow ops outside basic blocks.
                if isinstance(source, Register) and source not in seen_registers:
                    values.append(source)
                    seen_registers.add(source)
            # Control ops don't produce values.
            if not isinstance(op, ControlOp):
                if isinstance(op, (Assign, AssignMulti)):
                    if op.dest not in seen_registers:
                        values.append(op.dest)
                        seen_registers.add(op.dest)
                elif op.is_void:
                    continue
                else:
                    # Non-void ops are values themselves.
                    values.append(op)

    return values
# Map mypy argument kinds to the closest inspect.Parameter kind; used by
# get_text_signature() when rendering a __text_signature__-style string.
_ARG_KIND_TO_INSPECT: Final = {
    ArgKind.ARG_POS: inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ArgKind.ARG_OPT: inspect.Parameter.POSITIONAL_OR_KEYWORD,
    ArgKind.ARG_STAR: inspect.Parameter.VAR_POSITIONAL,
    ArgKind.ARG_NAMED: inspect.Parameter.KEYWORD_ONLY,
    ArgKind.ARG_STAR2: inspect.Parameter.VAR_KEYWORD,
    ArgKind.ARG_NAMED_OPT: inspect.Parameter.KEYWORD_ONLY,
}

# Sentinel indicating a value that cannot be represented in a text signature.
_NOT_REPRESENTABLE = object()
def get_text_signature(fn: FuncIR, *, bound: bool = False) -> str | None:
    """Return a text signature in CPython's internal doc format, or None
    if the function's signature cannot be represented.

    If 'bound' is true, use the bound signature (self/cls dropped) when
    one is available.
    """
    parameters = []
    # For unbound non-static methods, the first parameter gets a leading
    # '$' marker below (CPython's convention for self in text signatures).
    mark_self = (fn.class_name is not None) and (fn.decl.kind != FUNC_STATICMETHOD) and not bound
    sig = fn.decl.bound_sig if bound and fn.decl.bound_sig is not None else fn.decl.sig

    # Pre-scan for end of positional-only parameters.
    # This is needed to handle signatures like 'def foo(self, __x)', where mypy
    # currently sees 'self' as being positional-or-keyword and '__x' as positional-only.
    pos_only_idx = -1
    for idx, arg in enumerate(sig.args):
        if arg.pos_only and arg.kind in (ArgKind.ARG_POS, ArgKind.ARG_OPT):
            pos_only_idx = idx

    for idx, arg in enumerate(sig.args):
        # Skip compiler-synthesized arguments (identified by name prefix).
        if arg.name.startswith(("__bitmap", "__mypyc")):
            continue
        kind = (
            inspect.Parameter.POSITIONAL_ONLY
            if idx <= pos_only_idx
            else _ARG_KIND_TO_INSPECT[arg.kind]
        )
        default: object = inspect.Parameter.empty
        if arg.optional:
            default = _find_default_argument(arg.name, fn.blocks)
            if default is _NOT_REPRESENTABLE:
                # This default argument cannot be represented in a __text_signature__
                return None

        curr_param = inspect.Parameter(arg.name, kind, default=default)
        parameters.append(curr_param)

        if mark_self:
            # Parameter.__init__/Parameter.replace do not accept $
            curr_param._name = f"${arg.name}"  # type: ignore[attr-defined]
            mark_self = False

    return f"{fn.name}{inspect.Signature(parameters)}"
def _find_default_argument(name: str, blocks: list[BasicBlock]) -> object:
    """Return the Python literal form of the named register's default value.

    Finds the assignment inserted by gen_arg_defaults (assumed to be the
    first assignment to the register); returns _NOT_REPRESENTABLE when no
    such assignment exists.
    """
    for op in (op for block in blocks for op in block.ops):
        if isinstance(op, Assign) and op.dest.name == name:
            return _extract_python_literal(op.src)
    return _NOT_REPRESENTABLE
def _extract_python_literal(value: Value) -> object:
    """Try to map an IR value back to the Python literal it represents.

    Return _NOT_REPRESENTABLE if the value has no representable literal
    form (used when building text signatures).
    """
    if isinstance(value, Integer):
        # An Integer with a None rprimitive type represents the literal None.
        if is_none_rprimitive(value.type):
            return None
        val = value.numeric_value()
        if is_bool_rprimitive(value.type):
            return bool(val)
        return val
    elif isinstance(value, Float):
        return value.value
    elif isinstance(value, LoadLiteral):
        return value.value
    elif isinstance(value, Box):
        # Look through the box at the underlying unboxed value.
        return _extract_python_literal(value.src)
    elif isinstance(value, TupleSet):
        items = tuple(_extract_python_literal(item) for item in value.items)
        # A tuple is representable only if every item is.
        if any(itm is _NOT_REPRESENTABLE for itm in items):
            return _NOT_REPRESENTABLE
        return items
    return _NOT_REPRESENTABLE

View file

@ -0,0 +1,115 @@
"""Intermediate representation of modules."""
from __future__ import annotations
from mypyc.common import JsonDict
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.deps import Capsule, Dependency, SourceDep
from mypyc.ir.func_ir import FuncDecl, FuncIR
from mypyc.ir.ops import DeserMaps
from mypyc.ir.rtypes import RType, deserialize_type
class ModuleIR:
    """Intermediate representation of a module."""

    def __init__(
        self,
        fullname: str,
        imports: list[str],
        functions: list[FuncIR],
        classes: list[ClassIR],
        final_names: list[tuple[str, RType]],
        type_var_names: list[str],
    ) -> None:
        self.fullname = fullname
        # Copy so later mutation of the caller's list doesn't affect us.
        self.imports = imports.copy()
        self.functions = functions
        self.classes = classes
        self.final_names = final_names
        # Names of C statics used for Python 3.12 type variable objects.
        # These are only visible in the module that defined them, so no need
        # to serialize.
        self.type_var_names = type_var_names
        # Dependencies needed by the module (such as capsules or source files)
        self.dependencies: set[Dependency] = set()

    def serialize(self) -> JsonDict:
        """Convert to JSON-serializable form (function bodies excluded)."""
        # Serialize dependencies as a list of dicts with type information.
        # Sort them so the output is deterministic (sets are unordered).
        serialized_deps = []
        for dep in sorted(self.dependencies, key=lambda d: (type(d).__name__, str(d))):
            if isinstance(dep, Capsule):
                serialized_deps.append({"type": "Capsule", "name": dep.name})
            elif isinstance(dep, SourceDep):
                serialized_deps.append({"type": "SourceDep", "path": dep.path})
            else:
                # Fail loudly rather than silently dropping an unsupported
                # dependency kind from the serialized data.
                assert False, "Unsupported dependency type: %r" % (dep,)
        return {
            "fullname": self.fullname,
            "imports": self.imports,
            "functions": [f.serialize() for f in self.functions],
            "classes": [c.serialize() for c in self.classes],
            "final_names": [(k, t.serialize()) for k, t in self.final_names],
            "dependencies": serialized_deps,
        }

    @classmethod
    def deserialize(cls, data: JsonDict, ctx: DeserMaps) -> ModuleIR:
        """Reconstruct a ModuleIR.

        Functions and classes must already be registered in ctx
        (see deserialize_modules).
        """
        module = ModuleIR(
            data["fullname"],
            data["imports"],
            [ctx.functions[FuncDecl.get_id_from_json(f)] for f in data["functions"]],
            [ClassIR.deserialize(c, ctx) for c in data["classes"]],
            [(k, deserialize_type(t, ctx)) for k, t in data["final_names"]],
            # type_var_names are module-local and never serialized.
            [],
        )
        # Deserialize dependencies
        deps: set[Dependency] = set()
        for dep_dict in data["dependencies"]:
            if dep_dict["type"] == "Capsule":
                deps.add(Capsule(dep_dict["name"]))
            elif dep_dict["type"] == "SourceDep":
                deps.add(SourceDep(dep_dict["path"]))
            else:
                # Reject unknown dependency kinds instead of ignoring them.
                assert False, "Unsupported dependency type: %r" % (dep_dict["type"],)
        module.dependencies = deps
        return module
def deserialize_modules(data: dict[str, JsonDict], ctx: DeserMaps) -> dict[str, ModuleIR]:
    """Deserialize a collection of modules.

    The modules can contain dependencies on each other.

    Arguments:
        data: A dict containing the modules to deserialize.
        ctx: The deserialization maps to use and to populate.
             They are populated with information from the deserialized
             modules and as a precondition must have been populated by
             deserializing any dependencies of the modules being deserialized
             (outside of dependencies between the modules themselves).

    Returns a map containing the deserialized modules.
    """
    # Two preparatory passes are needed because classes, functions and
    # modules reference each other freely across module boundaries.
    for mod in data.values():
        # First create ClassIRs for every class so that we can construct types and whatnot
        for cls in mod["classes"]:
            ir = ClassIR(cls["name"], cls["module_name"])
            assert ir.fullname not in ctx.classes, "Class %s already in map" % ir.fullname
            ctx.classes[ir.fullname] = ir

    for mod in data.values():
        # Then deserialize all of the functions so that methods are available
        # to the class deserialization.
        for method in mod["functions"]:
            func = FuncIR.deserialize(method, ctx)
            assert func.decl.id not in ctx.functions, (
                "Method %s already in map" % func.decl.fullname
            )
            ctx.functions[func.decl.id] = func

    return {k: ModuleIR.deserialize(v, ctx) for k, v in data.items()}
# ModulesIRs should also always be an *OrderedDict*, but if we
# declared it that way we would need to put it in quotes everywhere...
# (Plain dicts preserve insertion order on all supported Python versions.)
ModuleIRs = dict[str, ModuleIR]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,530 @@
"""Utilities for pretty-printing IR in a human-readable form."""
from __future__ import annotations
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Final
from mypyc.common import short_name
from mypyc.ir.func_ir import FuncIR, all_values_full
from mypyc.ir.module_ir import ModuleIRs
from mypyc.ir.ops import (
ERR_NEVER,
Assign,
AssignMulti,
BasicBlock,
Box,
Branch,
Call,
CallC,
Cast,
ComparisonOp,
ControlOp,
CString,
DecRef,
Extend,
Float,
FloatComparisonOp,
FloatNeg,
FloatOp,
GetAttr,
GetElement,
GetElementPtr,
Goto,
IncRef,
InitStatic,
Integer,
IntOp,
KeepAlive,
LoadAddress,
LoadErrorValue,
LoadGlobal,
LoadLiteral,
LoadMem,
LoadStatic,
MethodCall,
Op,
OpVisitor,
PrimitiveOp,
RaiseStandardError,
Register,
Return,
SetAttr,
SetElement,
SetMem,
Truncate,
TupleGet,
TupleSet,
Unborrow,
Unbox,
Undef,
Unreachable,
Value,
)
from mypyc.ir.rtypes import RType, is_bool_rprimitive, is_int_rprimitive
# Location that an error message can be attached to when printing IR:
# either a whole basic block or an individual op (see format_blocks).
ErrorSource = BasicBlock | Op
class IRPrettyPrintVisitor(OpVisitor[str]):
"""Internal visitor that pretty-prints ops."""
def __init__(self, names: dict[Value, str]) -> None:
# This should contain a name for all values that are shown as
# registers in the output. This is not just for Register
# instances -- all Ops that produce values need (generated) names.
self.names = names
def visit_goto(self, op: Goto) -> str:
return self.format("goto %l", op.label)
branch_op_names: Final = {Branch.BOOL: ("%r", "bool"), Branch.IS_ERROR: ("is_error(%r)", "")}
def visit_branch(self, op: Branch) -> str:
fmt, typ = self.branch_op_names[op.op]
if op.negated:
fmt = f"not {fmt}"
cond = self.format(fmt, op.value)
tb = ""
if op.traceback_entry:
tb = " (error at %s:%d)" % op.traceback_entry
fmt = f"if {cond} goto %l{tb} else goto %l"
if typ:
fmt += f" :: {typ}"
return self.format(fmt, op.true, op.false)
def visit_return(self, op: Return) -> str:
return self.format("return %r", op.value)
def visit_unreachable(self, op: Unreachable) -> str:
return "unreachable"
def visit_assign(self, op: Assign) -> str:
return self.format("%r = %r", op.dest, op.src)
def visit_assign_multi(self, op: AssignMulti) -> str:
return self.format("%r = [%s]", op.dest, ", ".join(self.format("%r", v) for v in op.src))
def visit_load_error_value(self, op: LoadErrorValue) -> str:
return self.format("%r = <error> :: %s", op, op.type)
def visit_load_literal(self, op: LoadLiteral) -> str:
prefix = ""
# For values that have a potential unboxed representation, make
# it explicit that this is a Python object.
if isinstance(op.value, int):
prefix = "object "
rvalue = repr(op.value)
if isinstance(op.value, frozenset):
# We need to generate a string representation that won't vary
# run-to-run because sets are unordered, otherwise we may get
# spurious irbuild test failures.
#
# Sorting by the item's string representation is a bit of a
# hack, but it's stable and won't cause TypeErrors.
formatted_items = [repr(i) for i in sorted(op.value, key=str)]
rvalue = "frozenset({" + ", ".join(formatted_items) + "})"
return self.format("%r = %s%s", op, prefix, rvalue)
def visit_get_attr(self, op: GetAttr) -> str:
return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr)
def borrow_prefix(self, op: Op) -> str:
if op.is_borrowed:
return "borrow "
return ""
def visit_set_attr(self, op: SetAttr) -> str:
if op.is_init:
assert op.error_kind == ERR_NEVER
if op.error_kind == ERR_NEVER:
# Initialization and direct struct access can never fail
return self.format("%r.%s = %r", op.obj, op.attr, op.src)
else:
return self.format("%r.%s = %r; %r = is_error", op.obj, op.attr, op.src, op)
def visit_load_static(self, op: LoadStatic) -> str:
ann = f" ({repr(op.ann)})" if op.ann else ""
name = op.identifier
if op.module_name is not None:
name = f"{op.module_name}.{name}"
return self.format("%r = %s :: %s%s", op, name, op.namespace, ann)
def visit_init_static(self, op: InitStatic) -> str:
name = op.identifier
if op.module_name is not None:
name = f"{op.module_name}.{name}"
return self.format("%s = %r :: %s", name, op.value, op.namespace)
def visit_tuple_get(self, op: TupleGet) -> str:
return self.format("%r = %s%r[%d]", op, self.borrow_prefix(op), op.src, op.index)
def visit_tuple_set(self, op: TupleSet) -> str:
item_str = ", ".join(self.format("%r", item) for item in op.items)
return self.format("%r = (%s)", op, item_str)
def visit_inc_ref(self, op: IncRef) -> str:
s = self.format("inc_ref %r", op.src)
# TODO: Remove bool check (it's unboxed)
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
s += f" :: {short_name(op.src.type.name)}"
return s
def visit_dec_ref(self, op: DecRef) -> str:
s = self.format("%sdec_ref %r", "x" if op.is_xdec else "", op.src)
# TODO: Remove bool check (it's unboxed)
if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type):
s += f" :: {short_name(op.src.type.name)}"
return s
def visit_call(self, op: Call) -> str:
args = ", ".join(self.format("%r", arg) for arg in op.args)
# TODO: Display long name?
short_name = op.fn.shortname
s = f"{short_name}({args})"
if not op.is_void:
s = self.format("%r = ", op) + s
return s
def visit_method_call(self, op: MethodCall) -> str:
args = ", ".join(self.format("%r", arg) for arg in op.args)
s = self.format("%r.%s(%s)", op.obj, op.method, args)
if not op.is_void:
s = self.format("%r = ", op) + s
return s
def visit_cast(self, op: Cast) -> str:
if op.is_unchecked:
prefix = "unchecked "
else:
prefix = ""
return self.format(
"%r = %s%scast(%s, %r)", op, prefix, self.borrow_prefix(op), op.type, op.src
)
def visit_box(self, op: Box) -> str:
return self.format("%r = box(%s, %r)", op, op.src.type, op.src)
def visit_unbox(self, op: Unbox) -> str:
return self.format("%r = unbox(%s, %r)", op, op.type, op.src)
def visit_raise_standard_error(self, op: RaiseStandardError) -> str:
if op.value is not None:
if isinstance(op.value, str):
return self.format("%r = raise %s(%s)", op, op.class_name, repr(op.value))
elif isinstance(op.value, Value):
return self.format("%r = raise %s(%r)", op, op.class_name, op.value)
else:
assert False, "value type must be either str or Value"
else:
return self.format("%r = raise %s", op, op.class_name)
def visit_call_c(self, op: CallC) -> str:
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
if op.is_void:
return self.format("%s(%s)", op.function_name, args_str)
else:
return self.format("%r = %s(%s)", op, op.function_name, args_str)
def visit_primitive_op(self, op: PrimitiveOp) -> str:
args_str = ", ".join(self.format("%r", arg) for arg in op.args)
if op.is_void:
return self.format("%s %s", op.desc.name, args_str)
else:
return self.format("%r = %s %s", op, op.desc.name, args_str)
def visit_truncate(self, op: Truncate) -> str:
return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type)
def visit_extend(self, op: Extend) -> str:
if op.signed:
extra = " signed"
else:
extra = ""
return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type)
def visit_load_global(self, op: LoadGlobal) -> str:
ann = f" ({repr(op.ann)})" if op.ann else ""
return self.format("%r = load_global %s :: static%s", op, op.identifier, ann)
def visit_int_op(self, op: IntOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, IntOp.op_str[op.op], op.rhs)
def visit_comparison_op(self, op: ComparisonOp) -> str:
if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
sign_format = " :: signed"
elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
sign_format = " :: unsigned"
else:
sign_format = ""
return self.format(
"%r = %r %s %r%s", op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format
)
def visit_float_op(self, op: FloatOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, FloatOp.op_str[op.op], op.rhs)
def visit_float_neg(self, op: FloatNeg) -> str:
return self.format("%r = -%r", op, op.src)
def visit_float_comparison_op(self, op: FloatComparisonOp) -> str:
return self.format("%r = %r %s %r", op, op.lhs, op.op_str[op.op], op.rhs)
def visit_load_mem(self, op: LoadMem) -> str:
return self.format(
"%r = %sload_mem %r :: %t*", op, self.borrow_prefix(op), op.src, op.type
)
def visit_set_mem(self, op: SetMem) -> str:
return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type)
def visit_get_element(self, op: GetElement) -> str:
return self.format("%r = %r.%s", op, op.src, op.field)
def visit_get_element_ptr(self, op: GetElementPtr) -> str:
return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type)
def visit_set_element(self, op: SetElement) -> str:
return self.format("%r = set_element %r, %s, %r", op, op.src, op.field, op.item)
def visit_load_address(self, op: LoadAddress) -> str:
if isinstance(op.src, Register):
return self.format("%r = load_address %r", op, op.src)
elif isinstance(op.src, LoadStatic):
name = op.src.identifier
if op.src.module_name is not None:
name = f"{op.src.module_name}.{name}"
return self.format("%r = load_address %s :: %s", op, name, op.src.namespace)
else:
return self.format("%r = load_address %s", op, op.src)
def visit_keep_alive(self, op: KeepAlive) -> str:
if op.steal:
steal = "steal "
else:
steal = ""
return self.format(
"keep_alive {}{}".format(steal, ", ".join(self.format("%r", v) for v in op.src))
)
def visit_unborrow(self, op: Unborrow) -> str:
return self.format("%r = unborrow %r", op, op.src)
# Helpers
    def format(self, fmt: str, *args: Any) -> str:
        """Helper for formatting strings.

        These format sequences are supported in fmt:

          %s: arbitrary object converted to string using str()
          %r: name of IR value/register
          %d: int
          %f: float
          %l: BasicBlock (formatted as label 'Ln')
          %t: RType
        """
        result = []
        i = 0
        # Arguments are consumed left-to-right, one per format sequence.
        arglist = list(args)
        while i < len(fmt):
            # Copy literal text up to the next '%' (or to the end of fmt).
            n = fmt.find("%", i)
            if n < 0:
                n = len(fmt)
            result.append(fmt[i:n])
            if n < len(fmt):
                typespec = fmt[n + 1]
                arg = arglist.pop(0)
                if typespec == "r":
                    # Register/value: constants are printed inline, other
                    # values by their generated name.
                    assert isinstance(arg, Value)
                    if isinstance(arg, Integer):
                        result.append(str(arg.value))
                    elif isinstance(arg, Float):
                        result.append(repr(arg.value))
                    elif isinstance(arg, CString):
                        result.append(f"CString({arg.value!r})")
                    elif isinstance(arg, Undef):
                        result.append(f"undef {arg.type.name}")
                    else:
                        result.append(self.names[arg])
                elif typespec == "d":
                    # Integer
                    result.append("%d" % arg)
                elif typespec == "f":
                    # Float
                    result.append("%f" % arg)
                elif typespec == "l":
                    # Basic block (label)
                    assert isinstance(arg, BasicBlock)
                    result.append("L%s" % arg.label)
                elif typespec == "t":
                    # RType
                    result.append(arg.name)
                elif typespec == "s":
                    # String
                    result.append(str(arg))
                else:
                    raise ValueError(f"Invalid format sequence %{typespec}")
                # Skip past the '%' and its type character.
                i = n + 2
            else:
                i = n
        return "".join(result)
def format_registers(func_ir: FuncIR, names: dict[Value, str]) -> list[str]:
    """Format the register declaration lines of a function (e.g. 'r0, r1 :: int').

    Consecutive registers that share a type are grouped onto one line.
    """
    regs = all_values_full(func_ir.arg_regs, func_ir.blocks)
    lines = []
    start = 0
    while start < len(regs):
        # Extend the group while the type stays the same.
        end = start + 1
        while end < len(regs) and regs[end].type == regs[start].type:
            end += 1
        group = ", ".join(names[reg] for reg in regs[start:end])
        lines.append(f"{group} :: {regs[start].type}")
        start = end
    return lines
def format_blocks(
    blocks: list[BasicBlock],
    names: dict[Value, str],
    source_to_error: dict[ErrorSource, list[str]],
) -> list[str]:
    """Format a list of IR basic blocks into a human-readable form."""
    # First label all of the blocks
    for i, block in enumerate(blocks):
        block.label = i

    # Map each error-handler block to the blocks it handles, for annotation.
    handler_map: dict[BasicBlock, list[BasicBlock]] = {}
    for b in blocks:
        if b.error_handler:
            handler_map.setdefault(b.error_handler, []).append(b)

    visitor = IRPrettyPrintVisitor(names)

    lines = []
    for i, block in enumerate(blocks):
        handler_msg = ""
        if block in handler_map:
            labels = sorted("L%d" % b.label for b in handler_map[block])
            handler_msg = " (handler for {})".format(", ".join(labels))

        lines.append("L%d:%s" % (block.label, handler_msg))
        if block in source_to_error:
            for error in source_to_error[block]:
                lines.append(f"  ERROR: {error}")
        ops = block.ops
        if (
            isinstance(ops[-1], Goto)
            and i + 1 < len(blocks)
            and ops[-1].label == blocks[i + 1]
            and not source_to_error.get(ops[-1], [])
        ):
            # Hide the last goto if it just goes to the next basic block,
            # and there are no associated errors with the op.
            ops = ops[:-1]
        for op in ops:
            line = "    " + op.accept(visitor)
            lines.append(line)
            if op in source_to_error:
                first = len(lines) - 1
                # Use emojis to highlight the error
                for error in source_to_error[op]:
                    lines.append(f"      \U0001f446 ERROR: {error}")
                lines[first] = "  \U0000274c " + lines[first][4:]

        if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)):
            # Each basic block needs to exit somewhere.
            lines.append("    [MISSING BLOCK EXIT OPCODE]")
    return lines
def format_func(fn: FuncIR, errors: Sequence[tuple[ErrorSource, str]] = ()) -> list[str]:
    """Format a function's IR as human-readable text lines.

    If 'errors' is given, the associated blocks/ops are annotated with
    the error messages.
    """
    prefix = f"{fn.class_name}." if fn.class_name else ""
    arg_list = ", ".join(arg.name for arg in fn.args)
    lines = [f"def {prefix}{fn.name}({arg_list}):"]
    names = generate_names_for_ir(fn.arg_regs, fn.blocks)
    lines.extend("    " + reg_line for reg_line in format_registers(fn, names))
    # Group errors by the op/block they are attached to.
    source_to_error = defaultdict(list)
    for source, error in errors:
        source_to_error[source].append(error)
    lines.extend(format_blocks(fn.blocks, names, source_to_error))
    return lines
def format_modules(modules: ModuleIRs) -> list[str]:
    """Format the IR of every function of every module, blank-line separated."""
    lines: list[str] = []
    for module_ir in modules.values():
        for func_ir in module_ir.functions:
            lines.extend(format_func(func_ir))
            lines.append("")
    return lines
def generate_names_for_ir(args: list[Register], blocks: list[BasicBlock]) -> dict[Value, str]:
    """Generate unique names for IR values.

    Give names such as 'r5' to temp values in IR which are useful when
    pretty-printing or generating C. Ensure generated names are unique.
    """
    names: dict[Value, str] = {}
    used_names = set()

    temp_index = 0

    # Argument registers keep their declared names.
    for arg in args:
        names[arg] = arg.name
        used_names.add(arg.name)

    for block in blocks:
        for op in block.ops:
            # Collect the values that this op introduces and that still need names.
            values = []

            for source in op.sources():
                if source not in names:
                    values.append(source)

            if isinstance(op, (Assign, AssignMulti)):
                values.append(op.dest)
            elif isinstance(op, ControlOp) or op.is_void:
                # These ops produce no result value, so there is nothing to name.
                continue
            elif op not in names:
                values.append(op)

            for value in values:
                if value in names:
                    continue
                if isinstance(value, Register) and value.name:
                    name = value.name
                elif isinstance(value, (Integer, Float, Undef)):
                    # Constants are printed inline and never get a name.
                    continue
                else:
                    name = "r%d" % temp_index
                    temp_index += 1

                # Append _2, _3, ... if needed to make the name unique.
                if name in used_names:
                    n = 2
                    while True:
                        candidate = "%s_%d" % (name, n)
                        if candidate not in used_names:
                            name = candidate
                            break
                        n += 1

                names[value] = name
                used_names.add(name)

    return names

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,123 @@
"""IRBuilder AST transform helpers shared between expressions and statements.
Shared code that is tightly coupled to mypy ASTs can be put here instead of
making mypyc.irbuild.builder larger.
"""
from __future__ import annotations
from mypy.nodes import (
LDEF,
BytesExpr,
ComparisonExpr,
Expression,
FloatExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)
from mypyc.ir.ops import BasicBlock
from mypyc.ir.rtypes import is_fixed_width_rtype, is_tagged
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.constant_fold import constant_fold_expr
def process_conditional(
    self: IRBuilder, e: Expression, true: BasicBlock, false: BasicBlock
) -> None:
    """Generate branching IR for an expression used in a boolean context.

    Control jumps to 'true' if 'e' evaluates to a true value and to
    'false' otherwise. Short-circuits and/or/not without materializing
    an intermediate bool where possible.
    """
    if isinstance(e, OpExpr) and e.op in ["and", "or"]:
        if e.op == "and":
            # Short circuit 'and' in a conditional context.
            new = BasicBlock()
            process_conditional(self, e.left, new, false)
            self.activate_block(new)
            process_conditional(self, e.right, true, false)
        else:
            # Short circuit 'or' in a conditional context.
            new = BasicBlock()
            process_conditional(self, e.left, true, new)
            self.activate_block(new)
            process_conditional(self, e.right, true, false)
    elif isinstance(e, UnaryExpr) and e.op == "not":
        # 'not x' just swaps the branch targets.
        process_conditional(self, e.expr, false, true)
    else:
        res = maybe_process_conditional_comparison(self, e, true, false)
        if res:
            return
        # Catch-all for arbitrary expressions.
        reg = self.accept(e)
        self.add_bool_branch(reg, true, false)
def maybe_process_conditional_comparison(
    self: IRBuilder, e: Expression, true: BasicBlock, false: BasicBlock
) -> bool:
    """Transform simple tagged integer comparisons in a conditional context.

    Return True if the operation is supported (and was transformed). Otherwise,
    do nothing and return False.

    Args:
        self: IR form Builder
        e: Arbitrary expression
        true: Branch target if comparison is true
        false: Branch target if comparison is false
    """
    if not isinstance(e, ComparisonExpr) or len(e.operands) != 2:
        return False
    ltype = self.node_type(e.operands[0])
    rtype = self.node_type(e.operands[1])
    # Only handle operands that are tagged or fixed-width native ints.
    if not (
        (is_tagged(ltype) or is_fixed_width_rtype(ltype))
        and (is_tagged(rtype) or is_fixed_width_rtype(rtype))
    ):
        return False
    op = e.operators[0]
    if op not in ("==", "!=", "<", "<=", ">", ">="):
        return False
    left_expr = e.operands[0]
    right_expr = e.operands[1]
    # Borrow the left operand only when the right operand is itself
    # borrow-friendly -- presumably so evaluating the right expression
    # cannot invalidate the borrowed left value (TODO confirm).
    borrow_left = is_borrow_friendly_expr(self, right_expr)
    left = self.accept(left_expr, can_borrow=borrow_left)
    right = self.accept(right_expr, can_borrow=True)
    if is_fixed_width_rtype(ltype) or is_fixed_width_rtype(rtype):
        # Coerce a tagged operand (if any) to the fixed-width type.
        if not is_fixed_width_rtype(ltype):
            left = self.coerce(left, rtype, e.line)
        elif not is_fixed_width_rtype(rtype):
            right = self.coerce(right, ltype, e.line)
        # NOTE(review): this branch uses self.binary_op +
        # self.builder.flush_keep_alives, while the tagged branch below uses
        # self.builder.binary_op + self.flush_keep_alives. Presumably each
        # pair delegates to the same implementation -- confirm the asymmetry
        # is intentional.
        reg = self.binary_op(left, right, op, e.line)
        self.builder.flush_keep_alives(e.line)
        self.add_bool_branch(reg, true, false)
    else:
        # "left op right" for two tagged integers
        reg = self.builder.binary_op(left, right, op, e.line)
        self.flush_keep_alives(e.line)
        self.add_bool_branch(reg, true, false)
    return True
def is_borrow_friendly_expr(self: IRBuilder, expr: Expression) -> bool:
    """Can the result of the expression borrowed temporarily?

    Borrowing means keeping a reference without incrementing the reference count.
    """
    # Plain literals are immortal and can always be borrowed.
    if isinstance(expr, (IntExpr, FloatExpr, StrExpr, BytesExpr)):
        return True
    # Expressions that constant-fold behave like literals.
    if isinstance(
        expr, (UnaryExpr, OpExpr, NameExpr, MemberExpr)
    ) and constant_fold_expr(self, expr) is not None:
        return True
    # A reference to a local variable can be borrowed.
    if isinstance(expr, NameExpr) and isinstance(expr.node, Var) and expr.kind == LDEF:
        return True
    # Native attribute references can be borrowed as well.
    return isinstance(expr, MemberExpr) and self.is_native_attr_ref(expr)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,244 @@
"""Generate a class that represents a nested function.
The class defines __call__ for calling the function and allows access to
non-local variables defined in outer scopes.
"""
from __future__ import annotations
from mypyc.common import CPYFUNCTION_NAME, ENV_ATTR_NAME, PROPSET_PREFIX, SELF_NAME
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature, RuntimeArg
from mypyc.ir.ops import BasicBlock, Call, GetAttr, Integer, Register, Return, SetAttr, Value
from mypyc.ir.rtypes import RInstance, c_pointer_rprimitive, int_rprimitive, object_rprimitive
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.context import FuncInfo, ImplicitClass
from mypyc.primitives.misc_ops import (
cpyfunction_get_annotations,
cpyfunction_get_code,
cpyfunction_get_defaults,
cpyfunction_get_kwdefaults,
cpyfunction_get_name,
cpyfunction_set_annotations,
cpyfunction_set_name,
method_new_op,
)
def setup_callable_class(builder: IRBuilder) -> None:
    """Generate an (incomplete) callable class representing a function.

    This can be a nested function or a function within a non-extension
    class. Also set up the 'self' variable for that class.

    This takes the most recently visited function and returns a
    ClassIR to represent that function. Each callable class contains
    an environment attribute which points to another ClassIR
    representing the environment class where some of its variables can
    be accessed.

    Note that some methods, such as '__call__', are not yet
    created here. Use additional functions, such as
    add_call_to_callable_class(), to add them.

    Return a newly constructed ClassIR representing the callable
    class for the nested function.
    """
    # Check to see that the name has not already been taken. If so,
    # rename the class. We allow multiple uses of the same function
    # name because this is valid in if-else blocks. Example:
    #
    #     if True:
    #         def foo():          ---->    foo_obj()
    #             return True
    #     else:
    #         def foo():          ---->    foo_obj_0()
    #             return False
    name = base_name = f"{builder.fn_info.namespaced_name()}_obj"
    count = 0
    while name in builder.callable_class_names:
        name = base_name + "_" + str(count)
        count += 1
    builder.callable_class_names.add(name)

    # Define the actual callable class ClassIR, and set its
    # environment to point at the previously defined environment
    # class.
    callable_class_ir = ClassIR(name, builder.module_name, is_generated=True, is_final_class=True)
    callable_class_ir.reuse_freed_instance = True

    # The functools @wraps decorator attempts to call setattr on
    # nested functions, so we create a dict for these nested
    # functions.
    # https://github.com/python/cpython/blob/3.7/Lib/functools.py#L58
    if builder.fn_info.is_nested:
        callable_class_ir.has_dict = True

    # If the enclosing class doesn't contain nested (which will happen if
    # this is a toplevel lambda), don't set up an environment.
    if builder.fn_infos[-2].contains_nested:
        callable_class_ir.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class)
    callable_class_ir.mro = [callable_class_ir]
    builder.fn_info.callable_class = ImplicitClass(callable_class_ir)
    builder.classes.append(callable_class_ir)

    # Add a 'self' variable to the environment of the callable class,
    # and store that variable in a register to be accessed later.
    self_target = builder.add_self_to_env(callable_class_ir)
    builder.fn_info.callable_class.self_reg = builder.read(self_target, builder.fn_info.fitem.line)

    # Coroutines additionally get properties that make them look like
    # regular Python functions (see add_coroutine_properties).
    if not builder.fn_info.in_non_ext and builder.fn_info.is_coroutine:
        add_coroutine_properties(builder, callable_class_ir, builder.fn_info.name)
def add_coroutine_properties(
    builder: IRBuilder, callable_class_ir: ClassIR, coroutine_name: str
) -> None:
    """Adds properties to the class to make it look like a regular python function.

    Needed to make introspection functions like inspect.iscoroutinefunction work.
    """
    callable_class_ir.coroutine_name = coroutine_name
    # The wrapped CPython function object is stored as an instance attribute.
    callable_class_ir.attributes[CPYFUNCTION_NAME] = object_rprimitive

    # Read-only dunders, each backed by a getter primitive on the wrapped function.
    properties = {
        "__name__": cpyfunction_get_name,
        "__code__": cpyfunction_get_code,
        "__annotations__": cpyfunction_get_annotations,
        "__defaults__": cpyfunction_get_defaults,
        "__kwdefaults__": cpyfunction_get_kwdefaults,
    }
    # Dunders that also get a setter.
    writable_props = {
        "__name__": cpyfunction_set_name,
        "__annotations__": cpyfunction_set_annotations,
    }
    line = builder.fn_info.fitem.line

    def get_func_wrapper() -> Value:
        # Load the function object stored on the current instance.
        return builder.add(GetAttr(builder.self(), CPYFUNCTION_NAME, line))

    for name, primitive in properties.items():
        with builder.enter_method(callable_class_ir, name, object_rprimitive, internal=True):
            func = get_func_wrapper()
            val = builder.primitive_op(primitive, [func, Integer(0, c_pointer_rprimitive)], line)
            builder.add(Return(val))

    for name, primitive in writable_props.items():
        with builder.enter_method(
            callable_class_ir, f"{PROPSET_PREFIX}{name}", int_rprimitive, internal=True
        ):
            value = builder.add_argument("value", object_rprimitive)
            func = get_func_wrapper()
            rv = builder.primitive_op(
                primitive, [func, value, Integer(0, c_pointer_rprimitive)], line
            )
            builder.add(Return(rv))

    # Register each getter (and setter, if one was generated) as a property.
    for name in properties:
        getter = callable_class_ir.get_method(name)
        assert getter
        setter = callable_class_ir.get_method(f"{PROPSET_PREFIX}{name}")
        callable_class_ir.properties[name] = (getter, setter)
def add_call_to_callable_class(
    builder: IRBuilder,
    args: list[Register],
    blocks: list[BasicBlock],
    sig: FuncSignature,
    fn_info: FuncInfo,
) -> FuncIR:
    """Generate a '__call__' method for a callable class representing a nested function.

    This takes the blocks and signature associated with a function
    definition and uses those to build the '__call__' method of a
    given callable class, used to represent that function.
    """
    # Since we create a method, prepend a 'self' parameter; drop any
    # trailing bitmap args from the visible signature.
    visible_args = sig.args[: len(sig.args) - sig.num_bitmap_args]
    self_arg = RuntimeArg(SELF_NAME, object_rprimitive)
    call_sig = FuncSignature((self_arg,) + visible_args, sig.ret_type)
    class_ir = fn_info.callable_class.ir
    call_decl = FuncDecl("__call__", class_ir.name, builder.module_name, call_sig)
    fitem = fn_info.fitem
    call_ir = FuncIR(call_decl, args, blocks, fitem.line, traceback_name=fitem.name)
    class_ir.methods["__call__"] = call_ir
    class_ir.method_decls["__call__"] = call_decl
    return call_ir
def add_get_to_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
    """Generate the '__get__' method for a callable class.

    This implements the descriptor protocol: access through the class
    returns the callable itself, while access through an instance
    returns a bound method object.
    """
    line = fn_info.fitem.line
    with builder.enter_method(
        fn_info.callable_class.ir,
        "__get__",
        object_rprimitive,
        fn_info,
        self_type=object_rprimitive,
    ):
        instance = builder.add_argument("instance", object_rprimitive)
        builder.add_argument("owner", object_rprimitive)

        # If accessed through the class, just return the callable
        # object. If accessed through an object, create a new bound
        # instance method object.
        instance_block, class_block = BasicBlock(), BasicBlock()
        comparison = builder.translate_is_op(
            builder.read(instance), builder.none_object(line), "is", line
        )
        builder.add_bool_branch(comparison, class_block, instance_block)

        builder.activate_block(class_block)
        builder.add(Return(builder.self()))

        builder.activate_block(instance_block)
        builder.add(
            Return(builder.call_c(method_new_op, [builder.self(), builder.read(instance)], line))
        )
def instantiate_callable_class(builder: IRBuilder, fn_info: FuncInfo) -> Value:
    """Create an instance of a callable class for a function.

    Calls to the function will actually call this instance.

    Note that fn_info refers to the function being assigned, whereas
    builder.fn_info refers to the function encapsulating the function
    being turned into a callable class.
    """
    fitem = fn_info.fitem
    func_reg = builder.add(Call(fn_info.callable_class.ir.ctor, [], fitem.line))

    # Set the environment attribute of the callable class to point at
    # the environment class defined in the callable class' immediate
    # outer scope. Note that there are three possible environment
    # class registers we may use. This depends on what the encapsulating
    # (parent) function is:
    #
    # - A nested function: the callable class is instantiated
    #   from the current callable class' '__call__' function, and hence
    #   the callable class' environment register is used.
    # - A generator function: the callable class is instantiated
    #   from the '__next__' method of the generator class, and hence the
    #   environment of the generator class is used.
    # - Regular function or comprehension scope: we use the environment
    #   of the original function. Comprehension scopes are inlined (no
    #   callable class), so they fall into this case despite is_nested.
    curr_env_reg: Value | None = None
    if builder.fn_info.is_generator:
        curr_env_reg = builder.fn_info.generator_class.curr_env_reg
    elif builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
        curr_env_reg = builder.fn_info.callable_class.curr_env_reg
    elif builder.fn_info.contains_nested:
        curr_env_reg = builder.fn_info.curr_env_reg

    if curr_env_reg:
        builder.add(SetAttr(func_reg, ENV_ATTR_NAME, curr_env_reg, fitem.line))

    # Initialize function wrapper for callable classes. As opposed to regular functions,
    # each instance of a callable class needs its own wrapper because they might be instantiated
    # inside other functions.
    if not fn_info.in_non_ext and fn_info.is_coroutine:
        builder.add_coroutine_setup_call(fn_info.callable_class.ir.name, func_reg)
    return func_reg

View file

@ -0,0 +1,967 @@
"""Transform class definitions from the mypy AST form to IR."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Callable
from typing import Final
from mypy.nodes import (
EXCLUDED_ENUM_ATTRIBUTES,
TYPE_VAR_TUPLE_KIND,
AssignmentStmt,
CallExpr,
ClassDef,
Decorator,
EllipsisExpr,
ExpressionStmt,
FuncDef,
Lvalue,
MemberExpr,
NameExpr,
OverloadedFuncDef,
PassStmt,
RefExpr,
StrExpr,
TempNode,
TypeInfo,
TypeParam,
is_class_var,
)
from mypy.types import Instance, UnboundType, get_proper_type
from mypyc.common import PROPSET_PREFIX
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
from mypyc.ir.func_ir import FuncDecl, FuncSignature
from mypyc.ir.ops import (
NAMESPACE_TYPE,
BasicBlock,
Branch,
Call,
InitStatic,
LoadAddress,
LoadErrorValue,
LoadStatic,
MethodCall,
Register,
Return,
SetAttr,
TupleSet,
Value,
)
from mypyc.ir.rtypes import (
RType,
bool_rprimitive,
dict_rprimitive,
is_none_rprimitive,
is_object_rprimitive,
is_optional_type,
object_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder, create_type_params
from mypyc.irbuild.function import (
gen_property_getter_ir,
gen_property_setter_ir,
handle_ext_method,
handle_non_ext_method,
load_type,
)
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
from mypyc.irbuild.util import dataclass_type, get_func_def, is_constant, is_dataclass_decorator
from mypyc.primitives.dict_ops import dict_new_op, exact_dict_set_item_op
from mypyc.primitives.generic_ops import (
iter_op,
next_op,
py_get_item_op,
py_hasattr_op,
py_setattr_op,
)
from mypyc.primitives.misc_ops import (
dataclass_sleight_of_hand,
import_op,
not_implemented_op,
py_calc_meta_op,
py_init_subclass_op,
pytype_from_template_op,
type_object_op,
)
from mypyc.subtype import is_subtype
def transform_class_def(builder: IRBuilder, cdef: ClassDef) -> None:
    """Create IR for a class definition.

    This can generate both extension (native) and non-extension
    classes. These are generated in very different ways. In the
    latter case we construct a Python type object at runtime by doing
    the equivalent of "type(name, bases, dict)" in IR. Extension
    classes are defined via C structs that are generated later in
    mypyc.codegen.emitclass.

    This is the main entry point to this module.
    """
    if cdef.info not in builder.mapper.type_to_ir:
        builder.error("Nested class definitions not supported", cdef.line)
        return

    ir = builder.mapper.type_to_ir[cdef.info]

    # We do this check here because the base field of parent
    # classes aren't necessarily populated yet at
    # prepare_class_def time.
    if any(ir.base_mro[i].base != ir.base_mro[i + 1] for i in range(len(ir.base_mro) - 1)):
        builder.error("Multiple inheritance is not supported (except for traits)", cdef.line)

    if ir.allow_interpreted_subclasses:
        for parent in ir.mro:
            if not parent.allow_interpreted_subclasses:
                builder.error(
                    'Base class "{}" does not allow interpreted subclasses'.format(
                        parent.fullname
                    ),
                    cdef.line,
                )

    # Currently, we only create non-extension classes for classes that are
    # decorated or inherit from Enum. Classes decorated with @trait do not
    # apply here, and are handled in a different way.
    if ir.is_ext_class:
        # Pick a builder based on whether the class is a dataclass/attrs class.
        cls_type = dataclass_type(cdef)
        if cls_type is None:
            cls_builder: ClassBuilder = ExtClassBuilder(builder, cdef)
        elif cls_type in ["dataclasses", "attr-auto"]:
            cls_builder = DataClassBuilder(builder, cdef)
        elif cls_type == "attr":
            cls_builder = AttrsClassBuilder(builder, cdef)
        else:
            raise ValueError(cls_type)
    else:
        cls_builder = NonExtClassBuilder(builder, cdef)

    # Set up class body context so that intra-class ClassVar references
    # (e.g. C = A | B where A is defined earlier in the same class) can be
    # resolved from the class being built instead of module globals.
    builder.class_body_classvars = {}
    builder.class_body_obj = cls_builder.class_body_obj()
    builder.class_body_ir = ir

    for stmt in cdef.defs.body:
        if (
            isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef))
            and stmt.name == GENERATOR_HELPER_NAME
        ):
            builder.error(
                f'Method name "{stmt.name}" is reserved for mypyc internal use', stmt.line
            )

        if isinstance(stmt, OverloadedFuncDef) and stmt.is_property:
            if isinstance(cls_builder, NonExtClassBuilder):
                # properties with both getters and setters in non_extension
                # classes not supported
                builder.error("Property setters not supported in non-extension classes", stmt.line)
            for item in stmt.items:
                with builder.catch_errors(stmt.line):
                    cls_builder.add_method(get_func_def(item))
        elif isinstance(stmt, (FuncDef, Decorator, OverloadedFuncDef)):
            # Ignore plugin generated methods (since they have no
            # bodies to compile and will need to have the bodies
            # provided by some other mechanism.)
            if cdef.info.names[stmt.name].plugin_generated:
                continue
            with builder.catch_errors(stmt.line):
                cls_builder.add_method(get_func_def(stmt))
        elif isinstance(stmt, PassStmt) or (
            isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, EllipsisExpr)
        ):
            continue
        elif isinstance(stmt, AssignmentStmt):
            if len(stmt.lvalues) != 1:
                builder.error("Multiple assignment in class bodies not supported", stmt.line)
                continue
            lvalue = stmt.lvalues[0]
            if not isinstance(lvalue, NameExpr):
                builder.error(
                    "Only assignment to variables is supported in class bodies", stmt.line
                )
                continue
            # We want to collect class variables in a dictionary for both real
            # non-extension classes and fake dataclass ones.
            cls_builder.add_attr(lvalue, stmt)
            # Track this ClassVar so subsequent class body statements can reference it.
            if is_class_var(lvalue) or stmt.is_final_def:
                builder.class_body_classvars[lvalue.name] = None
        elif isinstance(stmt, ExpressionStmt) and isinstance(stmt.expr, StrExpr):
            # Docstring. Ignore
            pass
        else:
            builder.error("Unsupported statement in class body", stmt.line)

    # Clear class body context (nested classes are rejected above, so no need to save/restore).
    builder.class_body_classvars = {}
    builder.class_body_obj = None
    builder.class_body_ir = None

    # Generate implicit property setters/getters
    for name, decl in ir.method_decls.items():
        if decl.implicit and decl.is_prop_getter:
            getter_ir = gen_property_getter_ir(builder, decl, cdef, ir.is_trait)
            builder.functions.append(getter_ir)
            ir.methods[getter_ir.decl.name] = getter_ir

            setter_ir = None
            setter_name = PROPSET_PREFIX + name
            if setter_name in ir.method_decls:
                setter_ir = gen_property_setter_ir(
                    builder, ir.method_decls[setter_name], cdef, ir.is_trait
                )
                builder.functions.append(setter_ir)
                ir.methods[setter_name] = setter_ir

            ir.properties[name] = (getter_ir, setter_ir)
            # TODO: Generate glue method if needed?

    # TODO: Do we need interpreted glue methods? Maybe not?

    cls_builder.finalize(ir)
class ClassBuilder:
    """Create IR for a class definition.

    This is an abstract base class.
    """

    def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
        self.builder = builder
        self.cdef = cdef
        # Attributes collected here are later handed to cache_class_attrs()
        # by subclasses that support attribute caching.
        self.attrs_to_cache: list[tuple[Lvalue, RType]] = []

    @abstractmethod
    def add_method(self, fdef: FuncDef) -> None:
        """Add a method to the class IR"""

    @abstractmethod
    def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
        """Add an attribute to the class IR"""

    @abstractmethod
    def finalize(self, ir: ClassIR) -> None:
        """Perform any final operations to complete the class IR"""

    def class_body_obj(self) -> Value | None:
        """Return the object to use for loading class attributes during class body init.

        For extension classes, this is the type object. For non-extension classes,
        this is the class dict. Returns None if not applicable.
        """
        return None
class NonExtClassBuilder(ClassBuilder):
    """Create IR for a non-extension class (built at runtime via the type constructor)."""

    def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
        super().__init__(builder, cdef)
        self.non_ext = self.create_non_ext_info()

    def class_body_obj(self) -> Value | None:
        # Class body attribute lookups go through the class dict.
        return self.non_ext.dict

    def create_non_ext_info(self) -> NonExtClassInfo:
        """Collect the bases, metaclass, dict and annotations for the class."""
        non_ext_bases = populate_non_ext_bases(self.builder, self.cdef)
        non_ext_metaclass = find_non_ext_metaclass(self.builder, self.cdef, non_ext_bases)
        non_ext_dict = setup_non_ext_dict(
            self.builder, self.cdef, non_ext_metaclass, non_ext_bases
        )
        # We populate __annotations__ for non-extension classes
        # because dataclasses uses it to determine which attributes to compute on.
        # TODO: Maybe generate more precise types for annotations
        non_ext_anns = self.builder.call_c(dict_new_op, [], self.cdef.line)
        return NonExtClassInfo(non_ext_dict, non_ext_bases, non_ext_anns, non_ext_metaclass)

    def add_method(self, fdef: FuncDef) -> None:
        handle_non_ext_method(self.builder, self.non_ext, self.cdef, fdef)

    def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
        add_non_ext_class_attr_ann(self.builder, self.non_ext, lvalue, stmt)
        add_non_ext_class_attr(
            self.builder, self.non_ext, lvalue, stmt, self.cdef, self.attrs_to_cache
        )

    def finalize(self, ir: ClassIR) -> None:
        # Dynamically create the class via the type constructor
        non_ext_class = load_non_ext_class(self.builder, ir, self.non_ext, self.cdef.line)
        non_ext_class = load_decorated_class(self.builder, self.cdef, non_ext_class)

        # Try to avoid contention when using free threading.
        self.builder.set_immortal_if_free_threaded(non_ext_class, self.cdef.line)

        # Save the decorated class
        self.builder.add(
            InitStatic(non_ext_class, self.cdef.name, self.builder.module_name, NAMESPACE_TYPE)
        )

        # Add the non-extension class to the dict
        self.builder.call_c(
            exact_dict_set_item_op,
            [
                self.builder.load_globals_dict(),
                self.builder.load_str(self.cdef.name),
                non_ext_class,
            ],
            self.cdef.line,
        )

        # Cache any cacheable class attributes
        cache_class_attrs(self.builder, self.attrs_to_cache, self.cdef)
class ExtClassBuilder(ClassBuilder):
    """Create IR for an extension (native) class."""

    def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
        super().__init__(builder, cdef)
        # If the class is not decorated, generate an extension class for it.
        self.type_obj: Value = allocate_class(builder, cdef)

    def class_body_obj(self) -> Value | None:
        # Class body attribute lookups go through the type object.
        return self.type_obj

    def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
        """Controls whether to skip generating a default for an attribute."""
        return False

    def add_method(self, fdef: FuncDef) -> None:
        handle_ext_method(self.builder, self.cdef, fdef)

    def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
        # Variable declaration with no body
        if isinstance(stmt.rvalue, TempNode):
            return
        # Only treat marked class variables as class variables.
        if not (is_class_var(lvalue) or stmt.is_final_def):
            return
        typ = self.builder.load_native_type_object(self.cdef.fullname)
        value = self.builder.accept(stmt.rvalue)
        self.builder.primitive_op(
            py_setattr_op, [typ, self.builder.load_str(lvalue.name), value], stmt.line
        )
        if self.builder.non_function_scope() and stmt.is_final_def:
            self.builder.init_final_static(lvalue, value, self.cdef.name)

    def finalize(self, ir: ClassIR) -> None:
        # Call __init_subclass__ after class attributes have been set
        self.builder.call_c(py_init_subclass_op, [self.type_obj], self.cdef.line)
        attrs_with_defaults, default_assignments = find_attr_initializers(
            self.builder, self.cdef, self.skip_attr_default
        )
        ir.attrs_with_defaults.update(attrs_with_defaults)
        generate_attr_defaults_init(self.builder, self.cdef, default_assignments)
        create_ne_from_eq(self.builder, self.cdef)
class DataClassBuilder(ExtClassBuilder):
    """Create IR for an extension class processed by the dataclass (or attrs auto_attribs) decorator."""

    # controls whether an __annotations__ attribute should be added to the class
    # __dict__. This is not desirable for attrs classes where auto_attribs is
    # disabled, as attrs will reject it.
    add_annotations_to_dict = True

    def __init__(self, builder: IRBuilder, cdef: ClassDef) -> None:
        super().__init__(builder, cdef)
        self.non_ext = self.create_non_ext_info()

    def create_non_ext_info(self) -> NonExtClassInfo:
        """Set up a NonExtClassInfo to track dataclass attributes.

        In addition to setting up a normal extension class for dataclasses,
        we also collect its class attributes like a non-extension class so
        that we can hand them to the dataclass decorator.
        """
        return NonExtClassInfo(
            self.builder.call_c(dict_new_op, [], self.cdef.line),
            self.builder.add(TupleSet([], self.cdef.line)),
            self.builder.call_c(dict_new_op, [], self.cdef.line),
            self.builder.add(LoadAddress(type_object_op.type, type_object_op.src, self.cdef.line)),
        )

    def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
        # Skip defaults for annotated assignments -- presumably the dataclass
        # decorator generates these itself (see finalize); confirm.
        return stmt.type is not None

    def get_type_annotation(self, stmt: AssignmentStmt) -> TypeInfo | None:
        """Return the TypeInfo of the annotated type, if it is a plain Instance."""
        # We populate __annotations__ because dataclasses uses it to determine
        # which attributes to compute on.
        ann_type = get_proper_type(stmt.type)
        if isinstance(ann_type, Instance):
            return ann_type.type
        return None

    def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None:
        add_non_ext_class_attr_ann(
            self.builder, self.non_ext, lvalue, stmt, self.get_type_annotation
        )
        add_non_ext_class_attr(
            self.builder, self.non_ext, lvalue, stmt, self.cdef, self.attrs_to_cache
        )
        super().add_attr(lvalue, stmt)

    def finalize(self, ir: ClassIR) -> None:
        """Generate code to finish instantiating a dataclass.

        This works by replacing all of the attributes on the class
        (which will be descriptors) with whatever they would be in a
        non-extension class, calling dataclass, then switching them back.

        The resulting class is an extension class and instances of it do not
        have a __dict__ (unless something else requires it).
        All methods written explicitly in the source are compiled and
        may be called through the vtable while the methods generated
        by dataclasses are interpreted and may not be.

        (If we just called dataclass without doing this, it would think that all
        of the descriptors for our attributes are default values and generate an
        incorrect constructor. We need to do the switch so that dataclass gets the
        appropriate defaults.)
        """
        super().finalize(ir)
        assert self.type_obj
        add_dunders_to_non_ext_dict(
            self.builder, self.non_ext, self.cdef.line, self.add_annotations_to_dict
        )
        dec = self.builder.accept(
            next(d for d in self.cdef.decorators if is_dataclass_decorator(d))
        )
        dataclass_type_val = self.builder.load_str(dataclass_type(self.cdef) or "unknown")
        self.builder.call_c(
            dataclass_sleight_of_hand,
            [dec, self.type_obj, self.non_ext.dict, self.non_ext.anns, dataclass_type_val],
            self.cdef.line,
        )
class AttrsClassBuilder(DataClassBuilder):
    """Create IR for an attrs class where auto_attribs=False (the default).

    When auto_attribs is enabled, attrs classes behave similarly to dataclasses
    (i.e. types are stored as annotations on the class) and are thus handled
    by DataClassBuilder, but when auto_attribs is disabled the types are
    provided via attr.ib(type=...)
    """

    add_annotations_to_dict = False

    def skip_attr_default(self, name: str, stmt: AssignmentStmt) -> bool:
        """Never generate defaults: attr.ib() handles them itself."""
        return True

    def get_type_annotation(self, stmt: AssignmentStmt) -> TypeInfo | None:
        """Extract the TypeInfo from the `type=` keyword of an attr.ib(...) call."""
        rvalue = stmt.rvalue
        if not isinstance(rvalue, CallExpr):
            return None
        callee = rvalue.callee
        if not isinstance(callee, MemberExpr):
            return None
        if callee.fullname not in ["attr.ib", "attr.attr"]:
            return None
        if "type" not in rvalue.arg_names:
            return None
        # Find the type arg in `attr.ib(type=str)`.
        type_arg = rvalue.args[rvalue.arg_names.index("type")]
        if isinstance(type_arg, NameExpr) and isinstance(type_arg.node, TypeInfo):
            lvalue = stmt.lvalues[0]
            assert isinstance(lvalue, NameExpr), lvalue
            return type_arg.node
        return None
def allocate_class(builder: IRBuilder, cdef: ClassDef) -> Value:
    """Allocate the type object of an extension class and register it.

    Creates the type from its compiled template, fixes up trait vtables,
    attaches '__mypyc_attrs__', stores the type in a module static and in
    the module globals dict, and returns the new type object.
    """
    # OK AND NOW THE FUN PART
    base_exprs = cdef.base_type_exprs + cdef.removed_base_type_exprs
    new_style_type_args = cdef.type_args
    if new_style_type_args:
        # PEP 695 generic class: synthesize a Generic[...] base first.
        bases = [make_generic_base_class(builder, cdef.fullname, new_style_type_args, cdef.line)]
    else:
        bases = []
    if base_exprs or new_style_type_args:
        bases.extend([builder.accept(x) for x in base_exprs])
        tp_bases = builder.new_tuple(bases, cdef.line)
    else:
        # No explicit bases: pass a borrowed error value instead of a tuple
        # (presumably interpreted by pytype_from_template_op as "no bases" — see its impl).
        tp_bases = builder.add(LoadErrorValue(object_rprimitive, is_borrowed=True))
    modname = builder.load_str(builder.module_name)
    template = builder.add(
        LoadStatic(object_rprimitive, cdef.name + "_template", builder.module_name, NAMESPACE_TYPE)
    )
    # Create the class
    tp = builder.call_c(pytype_from_template_op, [template, tp_bases, modname], cdef.line)
    # Set type object to be immortal if free threaded, as otherwise reference count contention
    # can cause a big performance hit.
    builder.set_immortal_if_free_threaded(tp, cdef.line)
    # Immediately fix up the trait vtables, before doing anything with the class.
    ir = builder.mapper.type_to_ir[cdef.info]
    if not ir.is_trait and not ir.builtin_base:
        # Call the generated <name>_trait_vtable_setup() function.
        builder.add(
            Call(
                FuncDecl(
                    cdef.name + "_trait_vtable_setup",
                    None,
                    builder.module_name,
                    FuncSignature([], bool_rprimitive),
                ),
                [],
                cdef.line,
            )
        )
    builder.add_coroutine_setup_call(cdef.name, tp)
    # Populate a '__mypyc_attrs__' field containing the list of attrs
    builder.primitive_op(
        py_setattr_op,
        [
            tp,
            builder.load_str("__mypyc_attrs__"),
            create_mypyc_attrs_tuple(builder, builder.mapper.type_to_ir[cdef.info], cdef.line),
        ],
        cdef.line,
    )
    # Save the class
    builder.add(InitStatic(tp, cdef.name, builder.module_name, NAMESPACE_TYPE))
    # Add it to the dict
    builder.call_c(
        exact_dict_set_item_op,
        [builder.load_globals_dict(), builder.load_str(cdef.name), tp],
        cdef.line,
    )
    return tp
def make_generic_base_class(
    builder: IRBuilder, fullname: str, type_args: list[TypeParam], line: int
) -> Value:
    """Construct Generic[...] base class object for a new-style generic class (Python 3.12)."""
    typing_mod = builder.call_c(import_op, [builder.load_str("_typing")], line)
    type_vars = create_type_params(builder, typing_mod, type_args, line)
    subscript_args = []
    for value, param in zip(type_vars, type_args):
        if param.kind == TYPE_VAR_TUPLE_KIND:
            # A TypeVarTuple is subscripted unpacked (*Ts): take the single item.
            iterator = builder.primitive_op(iter_op, [value], line)
            value = builder.call_c(next_op, [iterator], line)
        subscript_args.append(value)
    generic_obj = builder.py_get_attr(typing_mod, "Generic", line)
    if len(subscript_args) == 1:
        index = subscript_args[0]
    else:
        index = builder.new_tuple(subscript_args, line)
    return builder.primitive_op(py_get_item_op, [generic_obj, index], line)
# Mypy uses these internally as base classes of TypedDict classes. These are
# lies and don't have any runtime equivalent, so occurrences of them must be
# replaced when building the runtime bases tuple (see populate_non_ext_bases).
MAGIC_TYPED_DICT_CLASSES: Final[tuple[str, ...]] = (
    "typing._TypedDict",
    "typing_extensions._TypedDict",
)
def populate_non_ext_bases(builder: IRBuilder, cdef: ClassDef) -> Value:
    """Create base class tuple of a non-extension class.

    The tuple is passed to the metaclass constructor.
    """
    is_named_tuple = cdef.info.is_named_tuple
    ir = builder.mapper.type_to_ir[cdef.info]
    bases = []
    for cls in (b.type for b in cdef.info.bases):
        if cls.fullname == "builtins.object":
            # Skip the universal 'object' base.
            continue
        if is_named_tuple and cls.fullname in (
            "typing.Sequence",
            "typing.Iterable",
            "typing.Collection",
            "typing.Reversible",
            "typing.Container",
            "typing.Sized",
        ):
            # HAX: Synthesized base classes added by mypy don't exist at runtime, so skip them.
            # This could break if they were added explicitly, though...
            continue
        # Add the current class to the base classes list of concrete subclasses
        if cls in builder.mapper.type_to_ir:
            base_ir = builder.mapper.type_to_ir[cls]
            if base_ir.children is not None:
                base_ir.children.append(ir)
        if cls.fullname in MAGIC_TYPED_DICT_CLASSES:
            # HAX: Mypy internally represents TypedDict classes differently from what
            # should happen at runtime. Replace with something that works.
            module = "typing"
            name = "_TypedDict"
            base = builder.get_module_attr(module, name, cdef.line)
        elif is_named_tuple and cls.fullname == "builtins.tuple":
            # Class-based NamedTuple: substitute typing._NamedTuple for plain tuple.
            name = "_NamedTuple"
            base = builder.get_module_attr("typing", name, cdef.line)
        else:
            # Load the base class object: from our globals when it's defined in the
            # current module, otherwise through its defining module.
            cls_module = cls.fullname.rsplit(".", 1)[0]
            if cls_module == builder.current_module:
                base = builder.load_global_str(cls.name, cdef.line)
            else:
                base = builder.load_module_attr_by_fullname(cls.fullname, cdef.line)
        bases.append(base)
        if cls.fullname in MAGIC_TYPED_DICT_CLASSES:
            # The remaining base classes are synthesized by mypy and should be ignored.
            break
    return builder.new_tuple(bases, cdef.line)
def find_non_ext_metaclass(builder: IRBuilder, cdef: ClassDef, bases: Value) -> Value:
    """Find the metaclass of a class from its defs and bases."""
    if cdef.metaclass:
        # An explicit metaclass was declared; still reconcile it with the bases.
        declared = builder.accept(cdef.metaclass)
        return builder.call_c(py_calc_meta_op, [declared, bases], cdef.line)
    if cdef.info.typeddict_type is not None:
        # In Python 3.9, the metaclass for class-based TypedDict is typing._TypedDictMeta.
        # We can't easily calculate it generically, so special case it.
        return builder.get_module_attr("typing", "_TypedDictMeta", cdef.line)
    if cdef.info.is_named_tuple:
        # In Python 3.9, the metaclass for class-based NamedTuple is typing.NamedTupleMeta.
        # We can't easily calculate it generically, so special case it.
        return builder.get_module_attr("typing", "NamedTupleMeta", cdef.line)
    # Default: start from 'type' itself and let py_calc_meta_op pick the winner.
    declared = builder.add(LoadAddress(type_object_op.type, type_object_op.src, cdef.line))
    return builder.call_c(py_calc_meta_op, [declared, bases], cdef.line)
def setup_non_ext_dict(
    builder: IRBuilder, cdef: ClassDef, metaclass: Value, bases: Value
) -> Value:
    """Initialize the class dictionary for a non-extension class.

    This class dictionary is passed to the metaclass constructor.
    """
    # Check if the metaclass defines a __prepare__ method, and if so, call it.
    has_prepare = builder.primitive_op(
        py_hasattr_op, [metaclass, builder.load_str("__prepare__")], cdef.line
    )
    non_ext_dict = Register(dict_rprimitive)
    true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()
    builder.add_bool_branch(has_prepare, true_block, false_block)
    builder.activate_block(true_block)
    # __prepare__(name, bases) supplies the initial class namespace.
    cls_name = builder.load_str(cdef.name)
    prepare_meth = builder.py_get_attr(metaclass, "__prepare__", cdef.line)
    prepare_dict = builder.py_call(prepare_meth, [cls_name, bases], cdef.line)
    builder.assign(non_ext_dict, prepare_dict, cdef.line)
    builder.goto(exit_block)
    builder.activate_block(false_block)
    # No __prepare__: fall back to a plain empty dict.
    builder.assign(non_ext_dict, builder.call_c(dict_new_op, [], cdef.line), cdef.line)
    builder.goto(exit_block)
    builder.activate_block(exit_block)
    return non_ext_dict
def add_non_ext_class_attr_ann(
    builder: IRBuilder,
    non_ext: NonExtClassInfo,
    lvalue: NameExpr,
    stmt: AssignmentStmt,
    get_type_info: Callable[[AssignmentStmt], TypeInfo | None] | None = None,
) -> None:
    """Add a class attribute to __annotations__ of a non-extension class.

    The annotation value is resolved, in order of preference, from the
    get_type_info callback, a forward-reference string, the analyzed
    Instance type, or finally plain 'type' as a catch-all.
    """
    # FIXME: try to better preserve the special forms and type parameters of generics.
    typ: Value | None = None
    if get_type_info is not None:
        type_info = get_type_info(stmt)
        if type_info:
            # NOTE: Using string type information is similar to using
            # `from __future__ import annotations` in standard python.
            # NOTE: For string types we need to use the fullname since it
            # includes the module. If string type doesn't have the module,
            # @dataclass will try to get the current module and fail since the
            # current module is not in sys.modules.
            if builder.current_module == type_info.module_name and stmt.line < type_info.line:
                # Same-module type defined *after* this statement: a real forward
                # reference, so store its fullname string instead of the type.
                typ = builder.load_str(type_info.fullname)
            else:
                typ = load_type(builder, type_info, stmt.unanalyzed_type, stmt.line)
    if typ is None:
        # FIXME: if get_type_info is not provided, don't fall back to stmt.type?
        ann_type = get_proper_type(stmt.type)
        if (
            isinstance(stmt.unanalyzed_type, UnboundType)
            and stmt.unanalyzed_type.original_str_expr is not None
        ):
            # Annotation is a forward reference, so don't attempt to load the actual
            # type and load the string instead.
            #
            # TODO: is it possible to determine whether a non-string annotation is
            # actually a forward reference due to the __annotations__ future?
            typ = builder.load_str(stmt.unanalyzed_type.original_str_expr)
        elif isinstance(ann_type, Instance):
            typ = load_type(builder, ann_type.type, stmt.unanalyzed_type, stmt.line)
        else:
            # Fallback: annotate with plain 'type'.
            typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line))
    key = builder.load_str(lvalue.name)
    builder.call_c(exact_dict_set_item_op, [non_ext.anns, key, typ], stmt.line)
def add_non_ext_class_attr(
    builder: IRBuilder,
    non_ext: NonExtClassInfo,
    lvalue: NameExpr,
    stmt: AssignmentStmt,
    cdef: ClassDef,
    attr_to_cache: list[tuple[Lvalue, RType]],
) -> None:
    """Add a class attribute to __dict__ of a non-extension class.

    Enum members are additionally recorded in attr_to_cache so their lookups
    can later be served from the global cache.
    """
    # Only add the attribute to the __dict__ if the assignment is of the form:
    # x: type = value (don't add attributes of the form 'x: type' to the __dict__).
    if not isinstance(stmt.rvalue, TempNode):
        rvalue = builder.accept(stmt.rvalue)
        builder.add_to_non_ext_dict(non_ext, lvalue.name, rvalue, stmt.line)
    # We cache enum attributes to speed up enum attribute lookup since they
    # are final.
    if (
        cdef.info.bases
        # Enum class must be the last parent class.
        and cdef.info.bases[-1].type.is_enum
        # Skip these since Enum will remove it
        and lvalue.name not in EXCLUDED_ENUM_ATTRIBUTES
    ):
        # Enum values are always boxed, so use object_rprimitive.
        attr_to_cache.append((lvalue, object_rprimitive))
def find_attr_initializers(
    builder: IRBuilder, cdef: ClassDef, skip: Callable[[str, AssignmentStmt], bool] | None = None
) -> tuple[set[str], list[tuple[AssignmentStmt, str]]]:
    """Find initializers of attributes in a class body.

    If provided, the skip arg should be a callable which will return whether
    to skip generating a default for an attribute. It will be passed the name of
    the attribute and the corresponding AssignmentStmt.

    Returns the set of attribute names that have defaults, and the list of
    (assignment, defining module name) pairs in MRO order (bases first).
    """
    cls = builder.mapper.type_to_ir[cdef.info]
    if cls.builtin_base:
        # Classes with a builtin base don't get generated defaults.
        return set(), []
    attrs_with_defaults = set()
    # Pull out all assignments in classes in the mro so we can initialize them
    # TODO: Support nested statements
    default_assignments: list[tuple[AssignmentStmt, str]] = []
    for info in reversed(cdef.info.mro):
        if info not in builder.mapper.type_to_ir:
            # Non-compiled ancestors are not our responsibility.
            continue
        for stmt in info.defn.defs.body:
            if (
                isinstance(stmt, AssignmentStmt)
                and isinstance(stmt.lvalues[0], NameExpr)
                and not is_class_var(stmt.lvalues[0])
                and not isinstance(stmt.rvalue, TempNode)
            ):
                name = stmt.lvalues[0].name
                if name == "__slots__":
                    continue
                if name == "__deletable__":
                    # Not an attribute default: validate the declaration instead.
                    check_deletable_declaration(builder, cls, stmt.line)
                    continue
                if skip is not None and skip(name, stmt):
                    continue
                attr_type = cls.attr_type(name)
                # If the attribute is initialized to None and type isn't optional,
                # doesn't initialize it to anything (special case for "# type:" comments).
                if isinstance(stmt.rvalue, RefExpr) and stmt.rvalue.fullname == "builtins.None":
                    if (
                        not is_optional_type(attr_type)
                        and not is_object_rprimitive(attr_type)
                        and not is_none_rprimitive(attr_type)
                    ):
                        continue
                attrs_with_defaults.add(name)
                default_assignments.append((stmt, info.module_name))
    return attrs_with_defaults, default_assignments
def generate_attr_defaults_init(
    builder: IRBuilder, cdef: ClassDef, default_assignments: list[tuple[AssignmentStmt, str]]
) -> None:
    """Generate an initialization method for default attr values (from class vars).

    The generated method is named '__mypyc_defaults_setup' and returns a bool.
    """
    if not default_assignments:
        return
    cls = builder.mapper.type_to_ir[cdef.info]
    if cls.builtin_base:
        return
    with builder.enter_method(cls, "__mypyc_defaults_setup", bool_rprimitive):
        self_var = builder.self()
        for stmt, origin_module in default_assignments:
            lvalue = stmt.lvalues[0]
            assert isinstance(lvalue, NameExpr), lvalue
            if not stmt.is_final_def and not is_constant(stmt.rvalue):
                builder.warning("Unsupported default attribute value", stmt.rvalue.line)
            attr_type = cls.attr_type(lvalue.name)
            # When the default comes from a parent in a different module,
            # set the globals lookup module so NameExpr references resolve
            # against the correct module's globals dict.
            builder.globals_lookup_module = (
                origin_module if origin_module != builder.module_name else None
            )
            try:
                val = builder.coerce(builder.accept(stmt.rvalue), attr_type, stmt.line)
            finally:
                # Always restore the default lookup module, even if accept() errors.
                builder.globals_lookup_module = None
            init = SetAttr(self_var, lvalue.name, val, stmt.rvalue.line)
            init.mark_as_initializer()
            builder.add(init)
        builder.add(Return(builder.true()))
def check_deletable_declaration(builder: IRBuilder, cl: ClassIR, line: int) -> None:
    """Validate the attribute names declared in a class's __deletable__.

    Reports an error for each name that is not an attribute at all, that is a
    property (properties can't be made deletable), or that is defined only in
    a base class (deletability must be declared where the attribute is defined).
    """
    for attr in cl.deletable:
        if attr not in cl.attributes:
            if not cl.has_attr(attr):
                builder.error(f'Attribute "{attr}" not defined', line)
                continue
            for base in cl.mro:
                if attr in base.property_types:
                    builder.error(f'Cannot make property "{attr}" deletable', line)
                    break
            else:
                # Attribute exists, but only in a base class.
                _, base = cl.attr_details(attr)
                builder.error(
                    f'Attribute "{attr}" not defined in "{cl.name}" (defined in "{base.name}")',
                    line,
                )
def create_ne_from_eq(builder: IRBuilder, cdef: ClassDef) -> None:
    """Create a "__ne__" method from a "__eq__" method (if only latter exists)."""
    cls = builder.mapper.type_to_ir[cdef.info]
    needs_glue = cls.has_method("__eq__") and not cls.has_method("__ne__")
    if needs_glue:
        gen_glue_ne_method(builder, cls, cdef.line)
def gen_glue_ne_method(builder: IRBuilder, cls: ClassIR, line: int) -> None:
    """Generate a "__ne__" method from a "__eq__" method.

    The generated method calls __eq__ and negates the result, propagating
    NotImplemented unchanged when __eq__ may return it.
    """
    func_ir = cls.get_method("__eq__")
    assert func_ir
    eq_sig = func_ir.decl.sig
    strict_typing = builder.options.strict_dunders_typing
    with builder.enter_method(cls, "__ne__", eq_sig.ret_type):
        # With strict dunder typing we mirror __eq__'s declared rhs type;
        # otherwise accept any object.
        rhs_type = eq_sig.args[0].type if strict_typing else object_rprimitive
        rhs_arg = builder.add_argument("rhs", rhs_type)
        eqval = builder.add(MethodCall(builder.self(), "__eq__", [rhs_arg], line))
        can_return_not_implemented = is_subtype(not_implemented_op.type, eq_sig.ret_type)
        return_bool = is_subtype(eq_sig.ret_type, bool_rprimitive)
        if not strict_typing or can_return_not_implemented:
            # If __eq__ returns NotImplemented, then __ne__ should also
            not_implemented_block, regular_block = BasicBlock(), BasicBlock()
            not_implemented = builder.add(
                LoadAddress(not_implemented_op.type, not_implemented_op.src, line)
            )
            # Branch on identity with NotImplemented.
            builder.add(
                Branch(
                    builder.translate_is_op(eqval, not_implemented, "is", line),
                    not_implemented_block,
                    regular_block,
                    Branch.BOOL,
                )
            )
            builder.activate_block(regular_block)
            # Normal path: return the negation of __eq__'s result.
            rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
            retval = builder.coerce(
                builder.builder.unary_not(eqval, line, likely_bool=True), rettype, line
            )
            builder.add(Return(retval))
            builder.activate_block(not_implemented_block)
            builder.add(Return(not_implemented))
        else:
            # __eq__ can't return NotImplemented: a plain negation suffices.
            rettype = bool_rprimitive if return_bool and strict_typing else object_rprimitive
            retval = builder.coerce(builder.unary_op(eqval, "not", line), rettype, line)
            builder.add(Return(retval))
def load_non_ext_class(
    builder: IRBuilder, ir: ClassIR, non_ext: NonExtClassInfo, line: int
) -> Value:
    """Create a non-extension class object by invoking its metaclass."""
    name_obj = builder.load_str(ir.name)
    add_dunders_to_non_ext_dict(builder, non_ext, line)
    return builder.py_call(non_ext.metaclass, [name_obj, non_ext.bases, non_ext.dict], line)
def load_decorated_class(builder: IRBuilder, cdef: ClassDef, type_obj: Value) -> Value:
    """Apply class decorators to create a decorated (non-extension) class object.

    Given a decorated ClassDef and a register containing a
    non-extension representation of the ClassDef created via the type
    constructor, applies the corresponding decorator functions on that
    decorated ClassDef and returns a register containing the decorated
    ClassDef.
    """
    decorated = type_obj
    # Decorators apply innermost-first, hence the reversal.
    for dec_expr in reversed(cdef.decorators):
        dec_value = dec_expr.accept(builder.visitor)
        assert isinstance(dec_value, Value), dec_value
        decorated = builder.py_call(dec_value, [decorated], decorated.line)
    return decorated
def cache_class_attrs(
    builder: IRBuilder, attrs_to_cache: list[tuple[Lvalue, RType]], cdef: ClassDef
) -> None:
    """Add class attributes to be cached to the global cache."""
    type_obj = builder.load_native_type_object(cdef.info.fullname)
    for target, cached_rtype in attrs_to_cache:
        assert isinstance(target, NameExpr), target
        attr_value = builder.py_get_attr(type_obj, target.name, cdef.line)
        builder.init_final_static(target, attr_value, cdef.name, type_override=cached_rtype)
def create_mypyc_attrs_tuple(builder: IRBuilder, ir: ClassIR, line: int) -> Value:
    """Build the '__mypyc_attrs__' tuple: every attribute name in the MRO."""
    names = [name for ancestor in ir.mro for name in ancestor.attributes]
    if ir.inherits_python:
        names.append("__dict__")
    return builder.new_tuple([builder.load_str(name) for name in names], line)
def add_dunders_to_non_ext_dict(
    builder: IRBuilder, non_ext: NonExtClassInfo, line: int, add_annotations: bool = True
) -> None:
    """Insert the standard dunder entries into a non-extension class dict."""
    if add_annotations:
        # Add __annotations__ to the class dict.
        builder.add_to_non_ext_dict(non_ext, "__annotations__", non_ext.anns, line)
    # We add a __doc__ attribute so if the non-extension class is decorated with the
    # dataclass decorator, dataclass will not try to look for __text_signature__.
    # https://github.com/python/cpython/blob/3.7/Lib/dataclasses.py#L957
    doc_value = builder.load_str("mypyc filler docstring")
    builder.add_to_non_ext_dict(non_ext, "__doc__", doc_value, line)
    module_value = builder.load_str(builder.module_name)
    builder.add_to_non_ext_dict(non_ext, "__module__", module_value, line)

View file

@ -0,0 +1,97 @@
"""Constant folding of IR values.
For example, 3 + 5 can be constant folded into 8.
This is mostly like mypy.constant_fold, but we can bind some additional
NameExpr and MemberExpr references here, since we have more knowledge
about which definitions can be trusted -- we constant fold only references
to other compiled modules in the same compilation unit.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from mypy.constant_fold import constant_fold_binary_op, constant_fold_unary_op
from mypy.nodes import (
BytesExpr,
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
MemberExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)
from mypyc.irbuild.util import bytes_from_str
if TYPE_CHECKING:
from mypyc.irbuild.builder import IRBuilder
# All possible result types of constant folding
ConstantValue = int | float | complex | str | bytes
CONST_TYPES: Final = (int, float, complex, str, bytes)
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> ConstantValue | None:
    """Return the constant value of an expression for supported operations.

    Return None otherwise.
    """
    # Literals fold to themselves (bytes literals need decoding first).
    if isinstance(expr, (IntExpr, FloatExpr, StrExpr, ComplexExpr)):
        return expr.value
    if isinstance(expr, BytesExpr):
        return bytes_from_str(expr.value)
    if isinstance(expr, NameExpr):
        # A reference to a final variable with a known constant value.
        node = expr.node
        if isinstance(node, Var) and node.is_final:
            value = node.final_value
            if isinstance(value, CONST_TYPES):
                return value
        return None
    if isinstance(expr, MemberExpr):
        # Only references resolvable to a final var in the compilation unit fold.
        final = builder.get_final_ref(expr)
        if final is not None:
            _fn, final_var, _native = final
            if final_var.is_final:
                value = final_var.final_value
                if isinstance(value, CONST_TYPES):
                    return value
        return None
    if isinstance(expr, OpExpr):
        left_val = constant_fold_expr(builder, expr.left)
        right_val = constant_fold_expr(builder, expr.right)
        if left_val is None or right_val is None:
            return None
        return constant_fold_binary_op_extended(expr.op, left_val, right_val)
    if isinstance(expr, UnaryExpr):
        operand = constant_fold_expr(builder, expr.expr)
        if operand is None or isinstance(operand, bytes):
            return None
        return constant_fold_unary_op(expr.op, operand)
    return None
def constant_fold_binary_op_extended(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Like mypy's constant_fold_binary_op(), but includes bytes support.

    mypy cannot use constant folded bytes easily so it's simpler to only support them in mypyc.
    """
    left_is_bytes = isinstance(left, bytes)
    right_is_bytes = isinstance(right, bytes)
    if not left_is_bytes and not right_is_bytes:
        # No bytes involved: defer to mypy's folder.
        return constant_fold_binary_op(op, left, right)
    if op == "+" and left_is_bytes and right_is_bytes:
        return left + right
    if op == "*" and left_is_bytes and isinstance(right, int):
        return left * right
    if op == "*" and isinstance(left, int) and right_is_bytes:
        return left * right
    return None

View file

@ -0,0 +1,206 @@
"""Helpers that store information about functions and the related classes."""
from __future__ import annotations
from mypy.nodes import FuncItem
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import INVALID_FUNC_DEF
from mypyc.ir.ops import BasicBlock, Value
from mypyc.irbuild.targets import AssignmentTarget
class FuncInfo:
    """Contains information about functions as they are generated."""

    def __init__(
        self,
        fitem: FuncItem = INVALID_FUNC_DEF,
        name: str = "",
        class_name: str | None = None,
        namespace: str = "",
        is_nested: bool = False,
        contains_nested: bool = False,
        is_decorated: bool = False,
        in_non_ext: bool = False,
        add_nested_funcs_to_env: bool = False,
        is_comprehension_scope: bool = False,
    ) -> None:
        self.fitem = fitem
        self.name = name
        self.class_name = class_name
        self.ns = namespace
        # Callable classes implement the '__call__' method, and are used to represent functions
        # that are nested inside of other functions.
        self._callable_class: ImplicitClass | None = None
        # Environment classes are ClassIR instances that contain attributes representing the
        # variables in the environment of the function they correspond to. Environment classes are
        # generated for functions that contain nested functions.
        self._env_class: ClassIR | None = None
        # Generator classes implement the '__next__' method, and are used to represent generators
        # returned by generator functions.
        self._generator_class: GeneratorClass | None = None
        # Environment class registers are the local registers associated with instances of an
        # environment class, used for getting and setting attributes. curr_env_reg is the register
        # associated with the current environment.
        self._curr_env_reg: Value | None = None
        # These are flags denoting whether a given function is nested, contains a nested function,
        # is decorated, or is within a non-extension class.
        self.is_nested = is_nested
        self.contains_nested = contains_nested
        self.is_decorated = is_decorated
        self.in_non_ext = in_non_ext
        self.add_nested_funcs_to_env = add_nested_funcs_to_env
        # Comprehension scopes are lightweight scope boundaries created when
        # a comprehension body contains a lambda. The comprehension is still
        # inlined (same basic blocks), but we push a new FuncInfo so the
        # closure machinery can capture loop variables through env classes.
        self.is_comprehension_scope = is_comprehension_scope
        # TODO: add field for ret_type: RType = none_rprimitive

    def namespaced_name(self) -> str:
        """Return a unique name built from function name, class name and namespace."""
        return "_".join(x for x in [self.name, self.class_name, self.ns] if x)

    @property
    def is_generator(self) -> bool:
        # Coroutines are compiled through the same machinery as generators.
        return self.fitem.is_generator or self.fitem.is_coroutine

    @property
    def is_coroutine(self) -> bool:
        return self.fitem.is_coroutine

    @property
    def callable_class(self) -> ImplicitClass:
        # Only valid after the callable class has been set up.
        assert self._callable_class is not None
        return self._callable_class

    @callable_class.setter
    def callable_class(self, cls: ImplicitClass) -> None:
        self._callable_class = cls

    @property
    def env_class(self) -> ClassIR:
        # Only valid after the environment class has been set up.
        assert self._env_class is not None
        return self._env_class

    @env_class.setter
    def env_class(self, ir: ClassIR) -> None:
        self._env_class = ir

    @property
    def generator_class(self) -> GeneratorClass:
        # Only valid after the generator class has been set up.
        assert self._generator_class is not None
        return self._generator_class

    @generator_class.setter
    def generator_class(self, cls: GeneratorClass) -> None:
        self._generator_class = cls

    @property
    def curr_env_reg(self) -> Value:
        # Only valid after the environment register has been initialized.
        assert self._curr_env_reg is not None
        return self._curr_env_reg

    def can_merge_generator_and_env_classes(self) -> bool:
        """Return whether the environment can live directly on the generator class."""
        # In simple cases we can place the environment into the generator class,
        # instead of having two separate classes.
        if self._generator_class and not self._generator_class.ir.is_final_class:
            result = False
        else:
            result = self.is_generator and not self.is_nested and not self.contains_nested
        return result
class ImplicitClass:
    """Contains information regarding implicitly generated classes.

    Implicit classes are generated for nested functions and generator
    functions. They are not explicitly defined in the source code.

    NOTE: This is both a concrete class and used as a base class.
    """

    def __init__(self, ir: ClassIR) -> None:
        # The ClassIR instance associated with this class.
        self.ir = ir
        # The register associated with the 'self' instance for this generator class.
        self._self_reg: Value | None = None
        # Environment class registers are the local registers associated with instances of an
        # environment class, used for getting and setting attributes. curr_env_reg is the register
        # associated with the current environment. prev_env_reg is the self.__mypyc_env__ field
        # associated with the previous environment.
        self._curr_env_reg: Value | None = None
        self._prev_env_reg: Value | None = None

    @property
    def self_reg(self) -> Value:
        # Only valid once set by the code generating the class body.
        assert self._self_reg is not None
        return self._self_reg

    @self_reg.setter
    def self_reg(self, reg: Value) -> None:
        self._self_reg = reg

    @property
    def curr_env_reg(self) -> Value:
        # Only valid once the current environment register has been loaded.
        assert self._curr_env_reg is not None
        return self._curr_env_reg

    @curr_env_reg.setter
    def curr_env_reg(self, reg: Value) -> None:
        self._curr_env_reg = reg

    @property
    def prev_env_reg(self) -> Value:
        # Only valid once the enclosing environment register has been loaded.
        assert self._prev_env_reg is not None
        return self._prev_env_reg

    @prev_env_reg.setter
    def prev_env_reg(self, reg: Value) -> None:
        self._prev_env_reg = reg
class GeneratorClass(ImplicitClass):
    """Contains information about implicit generator function classes."""

    def __init__(self, ir: ClassIR) -> None:
        super().__init__(ir)
        # This register holds the label number that the '__next__' function should go to the next
        # time it is called.
        self._next_label_reg: Value | None = None
        self._next_label_target: AssignmentTarget | None = None
        # These registers hold the error values for the generator object for the case that the
        # 'throw' function is called.
        self.exc_regs: tuple[Value, Value, Value] | None = None
        # Holds the arg passed to send
        self.send_arg_reg: Value | None = None
        # Holds the PyObject ** pointer through which return value can be passed
        # instead of raising StopIteration(ret_value) (only if not NULL). This
        # is used for faster native-to-native calls.
        self.stop_iter_value_reg: Value | None = None
        # The switch block is used to decide which instruction to go using the value held in the
        # next-label register.
        self.switch_block = BasicBlock()
        # One continuation block per resume point (yield/await) in the generator.
        self.continuation_blocks: list[BasicBlock] = []

    @property
    def next_label_reg(self) -> Value:
        # Only valid once the next-label register has been set up.
        assert self._next_label_reg is not None
        return self._next_label_reg

    @next_label_reg.setter
    def next_label_reg(self, reg: Value) -> None:
        self._next_label_reg = reg

    @property
    def next_label_target(self) -> AssignmentTarget:
        # Only valid once the next-label target has been set up.
        assert self._next_label_target is not None
        return self._next_label_target

    @next_label_target.setter
    def next_label_target(self, target: AssignmentTarget) -> None:
        self._next_label_target = target

View file

@ -0,0 +1,310 @@
"""Generate classes representing function environments (+ related operations).
If we have a nested function that has non-local (free) variables, access to the
non-locals is via an instance of an environment class. Example:
def f() -> int:
x = 0 # Make 'x' an attribute of an environment class instance
def g() -> int:
# We have access to the environment class instance to
# allow accessing 'x'
return x + 2
x = x + 1 # Modify the attribute
return g()
"""
from __future__ import annotations
from mypy.nodes import Argument, FuncDef, SymbolNode, Var
from mypyc.common import (
BITMAP_BITS,
ENV_ATTR_NAME,
GENERATOR_ATTRIBUTE_PREFIX,
SELF_NAME,
bitmap_name,
)
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.ops import Call, GetAttr, SetAttr, Value
from mypyc.ir.rtypes import RInstance, bitmap_rprimitive, object_rprimitive
from mypyc.irbuild.builder import IRBuilder, SymbolTarget
from mypyc.irbuild.context import FuncInfo, GeneratorClass, ImplicitClass
from mypyc.irbuild.targets import AssignmentTargetAttr
def setup_env_class(builder: IRBuilder) -> ClassIR:
    """Generate a class representing a function environment.

    Note that the variables in the function environment are not
    actually populated here. This is because when the environment
    class is generated, the function environment has not yet been
    visited. This behavior is allowed so that when the compiler visits
    nested functions, it can use the returned ClassIR instance to
    figure out free variables it needs to access. The remaining
    attributes of the environment class are populated when the
    environment registers are loaded.

    Return a ClassIR representing an environment for a function
    containing a nested function.
    """
    env_class = ClassIR(
        f"{builder.fn_info.namespaced_name()}_env",
        builder.module_name,
        is_generated=True,
        is_final_class=True,
    )
    # NOTE(review): allows freed env instances to be reused — presumably an
    # allocation optimization for hot closure code; confirm against ClassIR docs.
    env_class.reuse_freed_instance = True
    # The env object keeps a reference to itself under SELF_NAME.
    env_class.attributes[SELF_NAME] = RInstance(env_class)
    if builder.fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
        # If the function is nested, its environment class must contain an environment
        # attribute pointing to its encapsulating functions' environment class.
        env_class.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_infos[-2].env_class)
    env_class.mro = [env_class]
    builder.fn_info.env_class = env_class
    builder.classes.append(env_class)
    return env_class
def finalize_env_class(builder: IRBuilder, prefix: str = "") -> None:
    """Generate, instantiate, and set up the environment of an environment class."""
    fn_info = builder.fn_info
    if not fn_info.can_merge_generator_and_env_classes():
        instantiate_env_class(builder)
    # Iterate through the function arguments and replace local definitions (using registers)
    # that were previously added to the environment with references to the function's
    # environment class. Comprehension scopes have no arguments to add.
    if fn_info.is_comprehension_scope:
        return
    arg_base = fn_info.callable_class if fn_info.is_nested else fn_info
    add_args_to_env(builder, local=False, base=arg_base, prefix=prefix)
def instantiate_env_class(builder: IRBuilder) -> Value:
    """Assign an environment class to a register named after the given function definition.

    Returns the register holding the freshly constructed environment instance.
    """
    curr_env_reg = builder.add(
        Call(builder.fn_info.env_class.ctor, [], builder.fn_info.fitem.line)
    )
    if builder.fn_info.is_nested and not builder.fn_info.is_comprehension_scope:
        # Nested function: the env register lives on the callable class, and the
        # new env links back to the enclosing one via ENV_ATTR_NAME.
        builder.fn_info.callable_class._curr_env_reg = curr_env_reg
        builder.add(
            SetAttr(
                curr_env_reg,
                ENV_ATTR_NAME,
                builder.fn_info.callable_class.prev_env_reg,
                builder.fn_info.fitem.line,
            )
        )
    else:
        # Top-level functions and comprehension scopes store env reg directly.
        builder.fn_info._curr_env_reg = curr_env_reg
        # Comprehension scopes link to parent env if it exists.
        if (
            builder.fn_info.is_nested
            and builder.fn_infos[-2]._env_class is not None
            and builder.fn_infos[-2]._curr_env_reg is not None
        ):
            builder.add(
                SetAttr(
                    curr_env_reg,
                    ENV_ATTR_NAME,
                    builder.fn_infos[-2].curr_env_reg,
                    builder.fn_info.fitem.line,
                )
            )
    return curr_env_reg
def load_env_registers(builder: IRBuilder, prefix: str = "") -> None:
    """Load the registers for the current FuncItem being visited.

    Adds the arguments of the FuncItem to the environment. If the
    FuncItem is nested inside of another function, then this also
    loads all of the outer environments of the FuncItem into registers
    so that they can be used when accessing free variables.
    """
    # Bind the function's own arguments to local registers first.
    add_args_to_env(builder, local=True, prefix=prefix)

    fn_info = builder.fn_info
    fitem = fn_info.fitem
    # Only load outer environments when the enclosing function actually has an
    # environment class to load from.
    if fn_info.is_nested and builder.fn_infos[-2]._env_class is not None:
        load_outer_envs(builder, fn_info.callable_class)
        # If this is a FuncDef, then make sure to load the FuncDef into its own environment
        # class so that the function can be called recursively.
        if isinstance(fitem, FuncDef) and fn_info.add_nested_funcs_to_env:
            setup_func_for_recursive_call(builder, fitem, fn_info.callable_class, prefix=prefix)
def load_outer_env(
    builder: IRBuilder, base: Value, outer_env: dict[SymbolNode, SymbolTarget]
) -> Value:
    """Load the environment class for a given base into a register.

    Additionally, iterates through all of the SymbolNode and
    AssignmentTarget instances of the environment at the given index's
    symtable, and adds those instances to the environment of the
    current environment. This is done so that the current environment
    can access outer environment variables without having to reload
    all of the environment registers.

    Returns the register where the environment class was loaded.
    """
    env = builder.add(GetAttr(base, ENV_ATTR_NAME, builder.fn_info.fitem.line))
    assert isinstance(env.type, RInstance), f"{env} must be of type RInstance"

    for symbol, target in outer_env.items():
        # An AssignmentTargetAttr may use an attribute name that differs from
        # the symbol's own name; respect it.
        name = target.attr if isinstance(target, AssignmentTargetAttr) else symbol.name
        env.type.class_ir.attributes[name] = target.type
        builder.add_target(symbol, AssignmentTargetAttr(env, name))

    return env
def load_outer_envs(builder: IRBuilder, base: ImplicitClass) -> None:
    """Load the chain of outer environments of the current function into registers.

    Walks from the innermost enclosing function outward, loading each
    enclosing environment class via load_outer_env(). Stops early when an
    enclosing function has no environment class.
    """
    # Index of the innermost enclosing function's builder/fn_info.
    index = len(builder.builders) - 2

    # Load the first outer environment. This one is special because it gets saved in the
    # FuncInfo instance's prev_env_reg field.
    has_outer = index > 1 or (index == 1 and builder.fn_infos[1].contains_nested)
    if has_outer and builder.fn_infos[index]._env_class is not None:
        # outer_env = builder.fn_infos[index].environment
        outer_env = builder.symtables[index]
        if isinstance(base, GeneratorClass):
            # Generators hold the current environment on curr_env_reg rather
            # than on the self register.
            base.prev_env_reg = load_outer_env(builder, base.curr_env_reg, outer_env)
        else:
            base.prev_env_reg = load_outer_env(builder, base.self_reg, outer_env)
        env_reg = base.prev_env_reg
        index -= 1

        # Load the remaining outer environments into registers.
        while index > 1:
            if builder.fn_infos[index]._env_class is None:
                break
            # outer_env = builder.fn_infos[index].environment
            outer_env = builder.symtables[index]
            env_reg = load_outer_env(builder, env_reg, outer_env)
            index -= 1
def num_bitmap_args(builder: IRBuilder, args: list[Argument]) -> int:
    """Return the number of bitmap registers needed for the given arguments.

    One bit is needed per optional argument whose rtype has error overlap;
    bits are packed BITMAP_BITS per register (rounded up).
    """
    needed_bits = sum(
        1
        for arg in args
        if builder.type_to_rtype(arg.variable.type).error_overlap and arg.kind.is_optional()
    )
    # Round up to whole bitmap registers.
    return (needed_bits + BITMAP_BITS - 1) // BITMAP_BITS
def add_args_to_env(
    builder: IRBuilder,
    local: bool = True,
    base: FuncInfo | ImplicitClass | None = None,
    reassign: bool = True,
    prefix: str = "",
) -> None:
    """Add the current function's arguments to its environment.

    If 'local' is True, bind each argument (plus any error-bitmap arguments)
    to a local register. Otherwise, add free-variable arguments (or all
    arguments for generators/coroutines) to the environment class given by
    'base', optionally reassigning from the local register and name-mangling
    with 'prefix'.
    """
    fn_info = builder.fn_info
    args = fn_info.fitem.arguments
    nb = num_bitmap_args(builder, args)
    if local:
        for arg in args:
            rtype = builder.type_to_rtype(arg.variable.type)
            builder.add_local_reg(arg.variable, rtype, is_arg=True)
        # Bitmap args are appended after the regular args, in reverse order.
        for i in reversed(range(nb)):
            builder.add_local_reg(Var(bitmap_name(i)), bitmap_rprimitive, is_arg=True)
    else:
        for arg in args:
            # Generators/coroutines persist all args in the environment class,
            # since execution can be suspended and resumed.
            if (
                is_free_variable(builder, arg.variable)
                or fn_info.is_generator
                or fn_info.is_coroutine
            ):
                rtype = builder.type_to_rtype(arg.variable.type)
                assert base is not None, "base cannot be None for adding nonlocal args"
                builder.add_var_to_env_class(
                    arg.variable, rtype, base, reassign=reassign, prefix=prefix
                )
def add_vars_to_env(builder: IRBuilder, prefix: str = "") -> None:
    """Add relevant local variables and nested functions to the environment class.

    Add all variables and functions that are declared/defined within current
    function and are referenced in functions nested within this one to this
    function's environment class so the nested functions can reference
    them even if they are declared after the nested function's definition.

    Note that this is done before visiting the body of the function.
    """
    # Pick the object that owns the environment for this function.
    env_for_func: FuncInfo | ImplicitClass = builder.fn_info
    if builder.fn_info.is_generator:
        env_for_func = builder.fn_info.generator_class
    elif (
        builder.fn_info.is_nested or builder.fn_info.in_non_ext
    ) and not builder.fn_info.is_comprehension_scope:
        env_for_func = builder.fn_info.callable_class

    if builder.fn_info.fitem in builder.free_variables:
        # Sort the variables to keep things deterministic
        for var in sorted(builder.free_variables[builder.fn_info.fitem], key=lambda x: x.name):
            if isinstance(var, Var):
                rtype = builder.type_to_rtype(var.type)
                builder.add_var_to_env_class(
                    var, rtype, env_for_func, reassign=False, prefix=prefix
                )

    if builder.fn_info.fitem in builder.encapsulating_funcs:
        for nested_fn in builder.encapsulating_funcs[builder.fn_info.fitem]:
            if isinstance(nested_fn, FuncDef):
                # The return type is 'object' instead of an RInstance of the
                # callable class because differently defined functions with
                # the same name and signature across conditional blocks
                # will generate different callable classes, so the callable
                # class that gets instantiated must be generic.
                #
                # Compute the prefix per nested function rather than
                # reassigning the 'prefix' parameter: previously, once a
                # generator/coroutine was seen, the generator prefix leaked
                # to every subsequent non-generator sibling in this loop.
                fn_prefix = (
                    GENERATOR_ATTRIBUTE_PREFIX
                    if nested_fn.is_generator or nested_fn.is_coroutine
                    else prefix
                )
                builder.add_var_to_env_class(
                    nested_fn, object_rprimitive, env_for_func, reassign=False, prefix=fn_prefix
                )
def setup_func_for_recursive_call(
    builder: IRBuilder, fdef: FuncDef, base: ImplicitClass, prefix: str = ""
) -> None:
    """Enable calling a nested function (with a callable class) recursively.

    Adds the instance of the callable class representing the given
    FuncDef to a register in the environment so that the function can
    be called recursively. Note that this needs to be done only for
    nested functions.
    """
    # First, set the attribute of the environment class so that GetAttr can be called on it.
    prev_env = builder.fn_infos[-2].env_class
    attr_name = prefix + fdef.name
    prev_env.attributes[attr_name] = builder.type_to_rtype(fdef.type)

    line = fdef.line
    if isinstance(base, GeneratorClass):
        # If we are dealing with a generator class, then we need to first get the register
        # holding the current environment class, and load the previous environment class from
        # there.
        prev_env_reg = builder.add(GetAttr(base.curr_env_reg, ENV_ATTR_NAME, line))
    else:
        prev_env_reg = base.prev_env_reg

    # Obtain the instance of the callable class representing the FuncDef, and add it to the
    # current environment.
    val = builder.add(GetAttr(prev_env_reg, attr_name, line))
    # 'object' rather than the callable class' RInstance, to stay generic
    # across conditionally redefined functions.
    target = builder.add_local_reg(fdef, object_rprimitive)
    builder.assign(target, val, line)
def is_free_variable(builder: IRBuilder, symbol: SymbolNode) -> bool:
    """Check whether 'symbol' is a free variable of the current function."""
    free_vars = builder.free_variables.get(builder.fn_info.fitem)
    return free_vars is not None and symbol in free_vars

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,260 @@
"""Tokenizers for three string formatting methods"""
from __future__ import annotations
from enum import Enum, unique
from typing import Final
from mypy.checkstrformat import (
ConversionSpecifier,
parse_conversion_specifiers,
parse_format_value,
)
from mypy.errors import Errors
from mypy.messages import MessageBuilder
from mypy.nodes import Context, Expression
from mypy.options import Options
from mypyc.ir.ops import Integer, Value
from mypyc.ir.rtypes import (
c_pyssize_t_rprimitive,
is_bytes_rprimitive,
is_int_rprimitive,
is_short_int_rprimitive,
is_str_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.constant_fold import constant_fold_expr
from mypyc.primitives.bytes_ops import bytes_build_op
from mypyc.primitives.int_ops import int_to_ascii_op, int_to_str_op
from mypyc.primitives.str_ops import str_build_op, str_op
@unique
class FormatOp(Enum):
    """FormatOp represents conversion operations of string formatting during
    compile time.

    Compare to ConversionSpecifier, FormatOp has fewer attributes.
    For example, to mark a conversion from any object to string,
    ConversionSpecifier may have several representations, like '%s', '{}'
    or '{:{}}'. However, there would only exist one corresponding FormatOp.
    """

    # Convert any object to its string representation.
    STR = "s"
    # Convert an integer to its decimal string representation.
    INT = "d"
    # Insert a bytes object as-is (bytes formatting only).
    BYTES = "b"
def generate_format_ops(specifiers: list[ConversionSpecifier]) -> list[FormatOp] | None:
    """Convert ConversionSpecifier to FormatOp.

    Different ConversionSpecifiers may share a same FormatOp.
    Return None if any specifier is unsupported.
    """
    format_ops = []
    for spec in specifiers:
        # TODO: Match specifiers instead of using whole_seq
        seq = spec.whole_seq
        if seq in ("%s", "{:{}}") or not seq:
            op = FormatOp.STR
        elif seq == "%d":
            op = FormatOp.INT
        elif seq == "%b":
            op = FormatOp.BYTES
        else:
            # Unsupported conversion specifier.
            return None
        format_ops.append(op)
    return format_ops
def tokenizer_printf_style(format_str: str) -> tuple[list[str], list[FormatOp]] | None:
    """Tokenize a printf-style format string using regex.

    Return:
        A list of string literals and a list of FormatOps. The literals
        interleave with the conversions; there is exactly one more literal
        than there are FormatOps. Return None on unsupported specifiers.
    """
    specifiers = parse_conversion_specifiers(format_str)
    format_ops = generate_format_ops(specifiers)
    if format_ops is None:
        return None

    # Slice out the literal text between consecutive conversion specifiers.
    literals: list[str] = []
    pos = 0
    for spec in specifiers:
        literals.append(format_str[pos : spec.start_pos])
        pos = spec.start_pos + len(spec.whole_seq)
    literals.append(format_str[pos:])

    return literals, format_ops
# An empty Context used as the context argument for parse_format_value().
# It is never actually consulted, since the code has already passed type checking.
EMPTY_CONTEXT: Final = Context()
def tokenizer_format_call(format_str: str) -> tuple[list[str], list[FormatOp]] | None:
    """Tokenize a str.format() format string.

    The core function parse_format_value() is shared with mypy.
    With these specifiers, we then parse the literal substrings
    of the original format string and convert `ConversionSpecifier`
    to `FormatOp`.

    Return:
        A list of string literals and a list of FormatOps. The literals
        are interleaved with FormatOps and the length of returned literals
        should be exactly one more than FormatOps.
        Return None if it cannot parse the string.
    """
    # Create a throwaway MessageBuilder: it is never consulted, since the
    # code has already passed type checking.
    specifiers = parse_format_value(
        format_str, EMPTY_CONTEXT, MessageBuilder(Errors(Options()), {})
    )
    if specifiers is None:
        return None
    format_ops = generate_format_ops(specifiers)
    if format_ops is None:
        return None

    literals: list[str] = []
    last_end = 0
    for spec in specifiers:
        # Skip { and }
        literals.append(format_str[last_end : spec.start_pos - 1])
        last_end = spec.start_pos + len(spec.whole_seq) + 1
    literals.append(format_str[last_end:])
    # Deal with escaped {{
    literals = [x.replace("{{", "{").replace("}}", "}") for x in literals]

    return literals, format_ops
def convert_format_expr_to_str(
    builder: IRBuilder, format_ops: list[FormatOp], exprs: list[Expression], line: int
) -> list[Value] | None:
    """Convert expressions into string literal objects with the guidance
    of FormatOps. Return None when fails."""
    if len(format_ops) != len(exprs):
        return None

    converted = []
    for x, format_op in zip(exprs, format_ops):
        node_type = builder.node_type(x)
        if format_op == FormatOp.STR:
            # Prefer compile-time constant folding; otherwise use the value
            # directly (str), a fast int-to-str primitive, or generic str().
            if isinstance(folded := constant_fold_expr(builder, x), str):
                var_str = builder.load_literal_value(folded)
            elif is_str_rprimitive(node_type):
                var_str = builder.accept(x)
            elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
                var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
            else:
                var_str = builder.primitive_op(str_op, [builder.accept(x)], line)
        elif format_op == FormatOp.INT:
            if isinstance(folded := constant_fold_expr(builder, x), int):
                var_str = builder.load_literal_value(str(folded))
            elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
                var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)
            else:
                # Not known to be an int: can't specialize '%d'.
                return None
        else:
            # No other conversion ops are supported for str formatting.
            return None
        converted.append(var_str)
    return converted
def join_formatted_strings(
    builder: IRBuilder, literals: list[str] | None, substitutions: list[Value], line: int
) -> Value:
    """Merge the list of literals and the list of substitutions
    alternatively using 'str_build_op'.

    `substitutions` is the result value of formatting conversions.

    If the `literals` is set to None, we simply join the substitutions;
    Otherwise, the `literals` is the literal substrings of the original
    format string and its length should be exactly one more than
    substitutions.

    For example:
        (1) 'This is a %s and the value is %d'
            -> literals: ['This is a ', ' and the value is', '']
        (2) '{} and the value is {}'
            -> literals: ['', ' and the value is', '']
    """
    # The first parameter for str_build_op is the total size of
    # the following PyObject*
    result_list: list[Value] = [Integer(0, c_pyssize_t_rprimitive)]

    if literals is not None:
        # Interleave literal pieces with substitution values, skipping empty
        # literals to avoid useless concatenation operands.
        for a, b in zip(literals, substitutions):
            if a:
                result_list.append(builder.load_str(a))
            result_list.append(b)
        if literals[-1]:
            result_list.append(builder.load_str(literals[-1]))
    else:
        result_list.extend(substitutions)

    # Special case for empty string and literal string
    if len(result_list) == 1:
        return builder.load_str("")
    if not substitutions and len(result_list) == 2:
        return result_list[1]

    # Patch in the real operand count now that it is known.
    result_list[0] = Integer(len(result_list) - 1, c_pyssize_t_rprimitive)
    return builder.call_c(str_build_op, result_list, line)
def convert_format_expr_to_bytes(
    builder: IRBuilder, format_ops: list[FormatOp], exprs: list[Expression], line: int
) -> list[Value] | None:
    """Convert expressions into bytes literal objects with the guidance
    of FormatOps. Return None when fails."""
    if len(format_ops) != len(exprs):
        return None

    converted = []
    for x, format_op in zip(exprs, format_ops):
        node_type = builder.node_type(x)

        # conversion type 's' is an alias of 'b' in bytes formatting
        if format_op == FormatOp.BYTES or format_op == FormatOp.STR:
            if is_bytes_rprimitive(node_type):
                var_bytes = builder.accept(x)
            else:
                return None
        elif format_op == FormatOp.INT:
            if isinstance(folded := constant_fold_expr(builder, x), int):
                var_bytes = builder.load_literal_value(str(folded).encode("ascii"))
            elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
                var_bytes = builder.call_c(int_to_ascii_op, [builder.accept(x)], line)
            else:
                return None
        else:
            # Unsupported conversion op: bail out rather than fall through with
            # 'var_bytes' unbound (which raised UnboundLocalError). Mirrors the
            # final else branch in convert_format_expr_to_str().
            return None
        converted.append(var_bytes)
    return converted
def join_formatted_bytes(
    builder: IRBuilder, literals: list[str], substitutions: list[Value], line: int
) -> Value:
    """Merge the list of literals and the list of substitutions
    alternatively using 'bytes_build_op'.

    'literals' must contain exactly one more element than 'substitutions'
    (the literal pieces surrounding each conversion).
    """
    # First element is a placeholder for the operand count (patched below).
    result_list: list[Value] = [Integer(0, c_pyssize_t_rprimitive)]

    # Interleave literal pieces with substitution values, skipping empties.
    for a, b in zip(literals, substitutions):
        if a:
            result_list.append(builder.load_bytes_from_str_literal(a))
        result_list.append(b)
    if literals[-1]:
        result_list.append(builder.load_bytes_from_str_literal(literals[-1]))

    # Special case for empty bytes and literal
    if len(result_list) == 1:
        return builder.load_bytes_from_str_literal("")
    if not substitutions and len(result_list) == 2:
        return result_list[1]

    result_list[0] = Integer(len(result_list) - 1, c_pyssize_t_rprimitive)
    return builder.call_c(bytes_build_op, result_list, line)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,438 @@
"""Generate IR for generator functions.
A generator function is represented by a class that implements the
generator protocol and keeps track of the generator state, including
local variables.
The top-level logic for dealing with generator functions is in
mypyc.irbuild.function.
"""
from __future__ import annotations
from collections.abc import Callable
from mypy.nodes import ARG_OPT, FuncDef, Var
from mypyc.common import ENV_ATTR_NAME, GENERATOR_ATTRIBUTE_PREFIX, NEXT_LABEL_ATTR_NAME
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncDecl, FuncIR
from mypyc.ir.ops import (
NO_TRACEBACK_LINE_NO,
BasicBlock,
Branch,
Call,
Goto,
Integer,
MethodCall,
RaiseStandardError,
Register,
Return,
SetAttr,
TupleSet,
Unreachable,
Value,
)
from mypyc.ir.rtypes import (
RInstance,
int32_rprimitive,
object_pointer_rprimitive,
object_rprimitive,
)
from mypyc.irbuild.builder import IRBuilder, calculate_arg_defaults, gen_arg_defaults
from mypyc.irbuild.context import FuncInfo
from mypyc.irbuild.env_class import (
add_args_to_env,
add_vars_to_env,
finalize_env_class,
load_env_registers,
load_outer_env,
load_outer_envs,
setup_func_for_recursive_call,
)
from mypyc.irbuild.nonlocalcontrol import ExceptNonlocalControl
from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME
from mypyc.primitives.exc_ops import (
error_catch_op,
exc_matches_op,
raise_exception_with_tb_op,
reraise_exception_op,
restore_exc_info_op,
)
def gen_generator_func(
    builder: IRBuilder,
    gen_func_ir: Callable[
        [list[Register], list[BasicBlock], FuncInfo], tuple[FuncIR, Value | None]
    ],
) -> tuple[FuncIR, Value | None]:
    """Generate IR for generator function that returns generator object.

    'gen_func_ir' is a callback that assembles the final FuncIR from the
    argument registers, basic blocks, and FuncInfo produced here.
    """
    setup_generator_class(builder)
    load_env_registers(builder, prefix=GENERATOR_ATTRIBUTE_PREFIX)
    gen_arg_defaults(builder)
    if builder.fn_info.can_merge_generator_and_env_classes():
        # The generator instance itself doubles as the environment class.
        gen = instantiate_generator_class(builder)
        builder.fn_info._curr_env_reg = gen
        finalize_env_class(builder, prefix=GENERATOR_ATTRIBUTE_PREFIX)
    else:
        # Environment must be fully set up before the generator is created,
        # since the generator links to the environment instance.
        finalize_env_class(builder, prefix=GENERATOR_ATTRIBUTE_PREFIX)
        gen = instantiate_generator_class(builder)
    builder.add(Return(gen))
    args, _, blocks, ret_type, fn_info = builder.leave()
    func_ir, func_reg = gen_func_ir(args, blocks, fn_info)
    return func_ir, func_reg
def gen_generator_func_body(builder: IRBuilder, fn_info: FuncInfo, func_reg: Value | None) -> None:
    """Generate IR based on the body of a generator function.

    Add "__next__", "__iter__" and other generator methods to the generator
    class that implements the function (each function gets a separate class).

    Also evaluates argument defaults in the surrounding scope, using the
    body's symbol table. (Returns None, despite earlier docs claiming the
    symbol table was returned.)
    """
    builder.enter(fn_info, ret_type=object_rprimitive)
    setup_env_for_generator_class(builder)

    load_outer_envs(builder, builder.fn_info.generator_class)
    top_level = builder.top_level_fn_info()
    fitem = fn_info.fitem
    if (
        builder.fn_info.is_nested
        and isinstance(fitem, FuncDef)
        and top_level
        and top_level.add_nested_funcs_to_env
    ):
        setup_func_for_recursive_call(
            builder, fitem, builder.fn_info.generator_class, prefix=GENERATOR_ATTRIBUTE_PREFIX
        )
    create_switch_for_generator_class(builder)
    add_raise_exception_blocks_to_generator_class(builder, fitem.line)

    add_vars_to_env(builder, prefix=GENERATOR_ATTRIBUTE_PREFIX)

    builder.accept(fitem.body)
    builder.maybe_add_implicit_return()

    populate_switch_for_generator_class(builder)

    # Hang on to the local symbol table, since the caller will use it
    # to calculate argument defaults.
    symtable = builder.symtables[-1]

    args, _, blocks, ret_type, fn_info = builder.leave()

    add_methods_to_generator_class(builder, fn_info, args, blocks, fitem.is_coroutine)

    # Evaluate argument defaults in the surrounding scope, since we
    # calculate them *once* when the function definition is evaluated.
    calculate_arg_defaults(builder, fn_info, func_reg, symtable)
def instantiate_generator_class(builder: IRBuilder) -> Value:
    """Instantiate the generator class and initialize its state.

    Returns the register holding the new generator instance.
    """
    fitem = builder.fn_info.fitem
    generator_reg = builder.add(Call(builder.fn_info.generator_class.ir.ctor, [], fitem.line))

    if builder.fn_info.can_merge_generator_and_env_classes():
        # Set the generator instance to the initial state (zero).
        zero = Integer(0)
        builder.add(SetAttr(generator_reg, NEXT_LABEL_ATTR_NAME, zero, fitem.line))
    else:
        # Get the current environment register. If the current function is nested, then the
        # generator class gets instantiated from the callable class' '__call__' method, and hence
        # we use the callable class' environment register. Otherwise, we use the original
        # function's environment register.
        if builder.fn_info.is_nested:
            curr_env_reg = builder.fn_info.callable_class.curr_env_reg
        else:
            curr_env_reg = builder.fn_info.curr_env_reg

        # Set the generator class' environment attribute to point at the environment class
        # defined in the current scope.
        builder.add(SetAttr(generator_reg, ENV_ATTR_NAME, curr_env_reg, fitem.line))

        # Set the generator instance's environment to the initial state (zero).
        zero = Integer(0)
        builder.add(SetAttr(curr_env_reg, NEXT_LABEL_ATTR_NAME, zero, fitem.line))
    return generator_reg
def setup_generator_class(builder: IRBuilder) -> ClassIR:
    """Look up and register the ClassIR implementing the current generator."""
    fitem = builder.fn_info.fitem
    assert isinstance(fitem, FuncDef), fitem
    gen_class_ir = builder.mapper.fdef_to_generator[fitem]
    if builder.fn_info.can_merge_generator_and_env_classes():
        # The generator class doubles as the environment class.
        builder.fn_info.env_class = gen_class_ir
    else:
        # Otherwise it just holds a reference to the separate environment.
        gen_class_ir.attributes[ENV_ATTR_NAME] = RInstance(builder.fn_info.env_class)

    builder.classes.append(gen_class_ir)
    return gen_class_ir
def create_switch_for_generator_class(builder: IRBuilder) -> None:
    """Jump to the generator's dispatch switch and open the first continuation block."""
    gen_cls = builder.fn_info.generator_class
    builder.add(Goto(gen_cls.switch_block))
    entry = BasicBlock()
    gen_cls.continuation_blocks.append(entry)
    builder.activate_block(entry)
def populate_switch_for_generator_class(builder: IRBuilder) -> None:
    """Fill in the generator's dispatch switch over continuation labels.

    Compares the stored next-label value against each continuation block's
    label in turn; if none match, the generator is exhausted and
    StopIteration is raised.
    """
    cls = builder.fn_info.generator_class
    line = builder.fn_info.fitem.line

    builder.activate_block(cls.switch_block)
    for label, true_block in enumerate(cls.continuation_blocks):
        false_block = BasicBlock()
        comparison = builder.binary_op(cls.next_label_reg, Integer(label), "==", line)
        builder.add_bool_branch(comparison, true_block, false_block)
        builder.activate_block(false_block)

    # No label matched: the generator has already finished.
    builder.add(RaiseStandardError(RaiseStandardError.STOP_ITERATION, None, line))
    builder.add(Unreachable())
def add_raise_exception_blocks_to_generator_class(builder: IRBuilder, line: int) -> None:
    """Add error handling blocks to a generator class.

    Generates blocks to check if error flags are set while calling the
    helper method for generator functions, and raises an exception if
    those flags are set.
    """
    cls = builder.fn_info.generator_class
    assert cls.exc_regs is not None
    exc_type, exc_val, exc_tb = cls.exc_regs

    # Check to see if an exception was raised.
    error_block = BasicBlock()
    ok_block = BasicBlock()
    # A non-None exception type signals that 'throw' passed an exception in.
    comparison = builder.translate_is_op(exc_type, builder.none_object(), "is not", line)
    builder.add_bool_branch(comparison, error_block, ok_block)

    builder.activate_block(error_block)
    builder.call_c(raise_exception_with_tb_op, [exc_type, exc_val, exc_tb], line)
    builder.add(Unreachable())
    builder.goto_and_activate(ok_block)
def add_methods_to_generator_class(
    builder: IRBuilder,
    fn_info: FuncInfo,
    arg_regs: list[Register],
    blocks: list[BasicBlock],
    is_coroutine: bool,
) -> None:
    """Add all generator-protocol methods to the generator class.

    The helper method holds the actual body; the protocol methods
    ('__next__', 'send', 'throw', ...) delegate to it. '__await__' is only
    added for coroutines.
    """
    helper_fn_decl = add_helper_to_generator_class(builder, arg_regs, blocks, fn_info)
    add_next_to_generator_class(builder, fn_info, helper_fn_decl)
    add_send_to_generator_class(builder, fn_info, helper_fn_decl)
    add_iter_to_generator_class(builder, fn_info)
    add_throw_to_generator_class(builder, fn_info, helper_fn_decl)
    add_close_to_generator_class(builder, fn_info)
    if is_coroutine:
        add_await_to_generator_class(builder, fn_info)
def add_helper_to_generator_class(
    builder: IRBuilder, arg_regs: list[Register], blocks: list[BasicBlock], fn_info: FuncInfo
) -> FuncDecl:
    """Generates a helper method for a generator class, called by '__next__' and 'throw'."""
    # The decl was created during the prepare phase; attach the body here.
    helper_fn_decl = fn_info.generator_class.ir.method_decls[GENERATOR_HELPER_NAME]
    helper_fn_ir = FuncIR(
        helper_fn_decl, arg_regs, blocks, fn_info.fitem.line, traceback_name=fn_info.fitem.name
    )
    fn_info.generator_class.ir.methods[GENERATOR_HELPER_NAME] = helper_fn_ir
    builder.functions.append(helper_fn_ir)
    # Record which function's body actually uses this environment class.
    fn_info.env_class.env_user_function = helper_fn_ir
    return helper_fn_decl
def add_iter_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
    """Generate the '__iter__' method: a generator is its own iterator."""
    gen_ir = fn_info.generator_class.ir
    with builder.enter_method(gen_ir, "__iter__", object_rprimitive, fn_info):
        builder.add(Return(builder.self()))
def add_next_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
    """Generates the '__next__' method for a generator class."""
    with builder.enter_method(fn_info.generator_class.ir, "__next__", object_rprimitive, fn_info):
        none_reg = builder.none_object()
        # Call the helper function with error flags set to Py_None, and return that result.
        # Argument order: self, exc type, exc value, exc traceback, send value,
        # stop-iteration value pointer (NULL = raise StopIteration normally).
        result = builder.add(
            Call(
                fn_decl,
                [
                    builder.self(),
                    none_reg,
                    none_reg,
                    none_reg,
                    none_reg,
                    Integer(0, object_pointer_rprimitive),
                ],
                fn_info.fitem.line,
            )
        )
        builder.add(Return(result))
def add_send_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
    """Generates the 'send' method for a generator class."""
    with builder.enter_method(fn_info.generator_class.ir, "send", object_rprimitive, fn_info):
        arg = builder.add_argument("arg", object_rprimitive)
        none_reg = builder.none_object()
        # Call the helper function with error flags set to Py_None, and return that result.
        # Same as '__next__', but passing the sent value as the 5th argument.
        result = builder.add(
            Call(
                fn_decl,
                [
                    builder.self(),
                    none_reg,
                    none_reg,
                    none_reg,
                    builder.read(arg),
                    Integer(0, object_pointer_rprimitive),
                ],
                fn_info.fitem.line,
            )
        )
        builder.add(Return(result))
def add_throw_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl: FuncDecl) -> None:
    """Generates the 'throw' method for a generator class."""
    with builder.enter_method(fn_info.generator_class.ir, "throw", object_rprimitive, fn_info):
        typ = builder.add_argument("type", object_rprimitive)
        val = builder.add_argument("value", object_rprimitive, ARG_OPT)
        tb = builder.add_argument("traceback", object_rprimitive, ARG_OPT)

        # Because the value and traceback arguments are optional and hence
        # can be NULL if not passed in, we have to assign them Py_None if
        # they are not passed in.
        none_reg = builder.none_object()
        builder.assign_if_null(val, lambda: none_reg, fn_info.fitem.line)
        builder.assign_if_null(tb, lambda: none_reg, fn_info.fitem.line)

        # Call the helper function using the arguments passed in, and return that result.
        result = builder.add(
            Call(
                fn_decl,
                [
                    builder.self(),
                    builder.read(typ),
                    builder.read(val),
                    builder.read(tb),
                    none_reg,
                    Integer(0, object_pointer_rprimitive),
                ],
                fn_info.fitem.line,
            )
        )
        builder.add(Return(result))
def add_close_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
    """Generates the 'close' method for a generator class.

    Implements the standard generator close protocol: throw GeneratorExit
    into the generator; treat GeneratorExit/StopIteration as a clean close,
    re-raise other exceptions, and raise RuntimeError if the generator
    yields again instead of exiting.
    """
    with builder.enter_method(fn_info.generator_class.ir, "close", object_rprimitive, fn_info):
        except_block, else_block = BasicBlock(), BasicBlock()
        builder.builder.push_error_handler(except_block)
        builder.goto_and_activate(BasicBlock())
        generator_exit = builder.load_module_attr_by_fullname(
            "builtins.GeneratorExit", fn_info.fitem.line
        )
        # Throw GeneratorExit into the suspended generator.
        builder.add(
            MethodCall(
                builder.self(),
                "throw",
                [generator_exit, builder.none_object(), builder.none_object()],
                fn_info.fitem.line,
            )
        )
        builder.goto(else_block)
        builder.builder.pop_error_handler()

        builder.activate_block(except_block)
        old_exc = builder.call_c(error_catch_op, [], fn_info.fitem.line)
        builder.nonlocal_control.append(
            ExceptNonlocalControl(builder.nonlocal_control[-1], old_exc)
        )
        stop_iteration = builder.load_module_attr_by_fullname(
            "builtins.StopIteration", fn_info.fitem.line
        )
        exceptions = builder.add(TupleSet([generator_exit, stop_iteration], fn_info.fitem.line))
        matches = builder.call_c(exc_matches_op, [exceptions], fn_info.fitem.line)

        match_block, non_match_block = BasicBlock(), BasicBlock()
        builder.add(Branch(matches, match_block, non_match_block, Branch.BOOL))

        # GeneratorExit or StopIteration: the generator closed cleanly.
        builder.activate_block(match_block)
        builder.call_c(restore_exc_info_op, [builder.read(old_exc)], fn_info.fitem.line)
        builder.add(Return(builder.none_object()))

        # Any other exception propagates to the caller.
        builder.activate_block(non_match_block)
        builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
        builder.add(Unreachable())

        builder.nonlocal_control.pop()

        # The throw returned normally, meaning the generator yielded a value
        # instead of exiting: that is an error per the close protocol.
        builder.activate_block(else_block)
        builder.add(
            RaiseStandardError(
                RaiseStandardError.RUNTIME_ERROR,
                "generator ignored GeneratorExit",
                fn_info.fitem.line,
            )
        )
        builder.add(Unreachable())
def add_await_to_generator_class(builder: IRBuilder, fn_info: FuncInfo) -> None:
    """Generate the '__await__' method: a coroutine awaits on itself."""
    gen_ir = fn_info.generator_class.ir
    with builder.enter_method(gen_ir, "__await__", object_rprimitive, fn_info):
        builder.add(Return(builder.self()))
def setup_env_for_generator_class(builder: IRBuilder) -> None:
    """Populates the environment for a generator class."""
    fitem = builder.fn_info.fitem
    cls = builder.fn_info.generator_class
    self_target = builder.add_self_to_env(cls.ir)

    # Add the type, value, and traceback variables to the environment.
    exc_type = builder.add_local(Var("type"), object_rprimitive, is_arg=True)
    exc_val = builder.add_local(Var("value"), object_rprimitive, is_arg=True)
    exc_tb = builder.add_local(Var("traceback"), object_rprimitive, is_arg=True)
    # TODO: Use the right type here instead of object?
    exc_arg = builder.add_local(Var("arg"), object_rprimitive, is_arg=True)
    # Parameter that can be used to pass a pointer which can be used instead of
    # raising StopIteration(value). If the value is NULL, this won't be used.
    stop_iter_value_arg = builder.add_local(
        Var("stop_iter_ptr"), object_pointer_rprimitive, is_arg=True
    )

    cls.exc_regs = (exc_type, exc_val, exc_tb)
    cls.send_arg_reg = exc_arg
    cls.stop_iter_value_reg = stop_iter_value_arg

    cls.self_reg = builder.read(self_target, fitem.line)
    if builder.fn_info.can_merge_generator_and_env_classes():
        # Merged case: the generator instance itself is the environment.
        cls.curr_env_reg = cls.self_reg
    else:
        cls.curr_env_reg = load_outer_env(builder, cls.self_reg, builder.symtables[-1])

    # Define a variable representing the label to go to the next time
    # the '__next__' function of the generator is called, and add it
    # as an attribute to the environment class.
    cls.next_label_target = builder.add_var_to_env_class(
        Var(NEXT_LABEL_ATTR_NAME), int32_rprimitive, cls, reassign=False, always_defined=True
    )

    # Add arguments from the original generator function to the
    # environment of the generator class.
    add_args_to_env(
        builder, local=False, base=cls, reassign=False, prefix=GENERATOR_ATTRIBUTE_PREFIX
    )

    # Set the next label register for the generator class.
    cls.next_label_reg = builder.read(cls.next_label_target, fitem.line)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,173 @@
"""Transform a mypy AST to the IR form (Intermediate Representation).
For example, consider a function like this:
def f(x: int) -> int:
return x * 2 + 1
It would be translated to something that conceptually looks like this:
r0 = 2
r1 = 1
r2 = x * r0 :: int
r3 = r2 + r1 :: int
return r3
This module deals with the module-level IR transformation logic and
putting it all together. The actual IR is implemented in mypyc.ir.
For the core of the IR transform implementation, look at build_ir()
below, mypyc.irbuild.builder, and mypyc.irbuild.visitor.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any, TypeVar, cast
from mypy.build import Graph
from mypy.nodes import ClassDef, Expression, FuncDef, MypyFile
from mypy.state import state
from mypy.types import Type
from mypyc.analysis.attrdefined import analyze_always_defined_attrs
from mypyc.common import TOP_LEVEL_NAME
from mypyc.errors import Errors
from mypyc.ir.func_ir import FuncDecl, FuncIR, FuncSignature
from mypyc.ir.module_ir import ModuleIR, ModuleIRs
from mypyc.ir.rtypes import none_rprimitive
from mypyc.irbuild.builder import IRBuilder
from mypyc.irbuild.mapper import Mapper
from mypyc.irbuild.prebuildvisitor import PreBuildVisitor
from mypyc.irbuild.prepare import (
adjust_generator_classes_of_methods,
build_type_map,
create_generator_class_for_func,
find_singledispatch_register_impls,
)
from mypyc.irbuild.visitor import IRBuilderVisitor
from mypyc.irbuild.vtable import compute_vtable
from mypyc.options import CompilerOptions
# The stubs for callable contextmanagers are busted so cast it to the
# right type...
F = TypeVar("F", bound=Callable[..., Any])
# Decorator that runs the decorated function with mypy's strict optional
# checking enabled (strict_optional_set is a context manager used here as
# a decorator; the cast works around the imprecise stubs mentioned above).
strict_optional_dec = cast(Callable[[F], F], state.strict_optional_set(True))
@strict_optional_dec  # Turn on strict optional for any type manipulations we do
def build_ir(
    modules: list[MypyFile],
    graph: Graph,
    types: dict[Expression, Type],
    mapper: Mapper,
    options: CompilerOptions,
    errors: Errors,
) -> ModuleIRs:
    """Build basic IR for a set of modules that have been type-checked by mypy.

    The returned IR is not complete and requires additional
    transformations, such as the insertion of refcount handling.

    Args:
        modules: modules to generate IR for
        graph: mypy build graph for the program
        types: mypy-inferred types for expressions
        mapper: shared mypy-concept-to-IR-concept state (mutated in place)
        options: mypyc compiler options
        errors: error reporter; checked between preparation and codegen

    Returns:
        Mapping from module full name to its ModuleIR; empty if errors
        were reported during preparation.
    """
    build_type_map(mapper, modules, graph, types, options, errors)
    adjust_generator_classes_of_methods(mapper)
    singledispatch_info = find_singledispatch_register_impls(modules, errors)
    result: ModuleIRs = {}
    if errors.num_errors > 0:
        # Preparation failed; don't attempt IR generation.
        return result
    # Generate IR for all modules.
    class_irs = []
    for module in modules:
        # First pass to determine free symbols.
        pbv = PreBuildVisitor(errors, module, singledispatch_info.decorators_to_remove, types)
        module.accept(pbv)
        # Declare generator classes for nested async functions and generators.
        for fdef in pbv.nested_funcs:
            if isinstance(fdef, FuncDef):
                # Make generator class name sufficiently unique.
                suffix = f"___{fdef.line}"
                if fdef.is_coroutine or fdef.is_generator:
                    create_generator_class_for_func(
                        module.fullname, None, fdef, mapper, name_suffix=suffix
                    )
        # Construct and configure builder objects (cyclic runtime dependency).
        visitor = IRBuilderVisitor()
        builder = IRBuilder(
            module.fullname,
            types,
            graph,
            errors,
            mapper,
            pbv,
            visitor,
            options,
            singledispatch_info.singledispatch_impls,
        )
        visitor.builder = builder
        # Second pass does the bulk of the work.
        transform_mypy_file(builder, module)
        module_ir = ModuleIR(
            module.fullname,
            list(builder.imports),
            builder.functions,
            builder.classes,
            builder.final_names,
            builder.type_var_names,
        )
        result[module.fullname] = module_ir
        class_irs.extend(builder.classes)
    # Analyses below operate on classes from *all* modules at once.
    analyze_always_defined_attrs(class_irs)
    # Compute vtables.
    for cir in class_irs:
        if cir.is_ext_class:
            compute_vtable(cir)
    return result
def transform_mypy_file(builder: IRBuilder, mypyfile: MypyFile) -> None:
    """Generate IR for a single module.

    Appends class IRs and function IRs (including a synthetic function
    for the module top level) to the builder's output lists.
    """
    if mypyfile.fullname in ("typing", "abc"):
        # These modules are special; their contents are currently all
        # built-in primitives.
        return
    builder.set_module(mypyfile.fullname, mypyfile.path)
    classes = [node for node in mypyfile.defs if isinstance(node, ClassDef)]
    # Collect all classes.
    for cls in classes:
        ir = builder.mapper.type_to_ir[cls.info]
        builder.classes.append(ir)
    builder.enter("<module>")
    # Make sure we have a builtins import
    builder.gen_import("builtins", 1)
    # Generate ops.
    for node in mypyfile.defs:
        builder.accept(node)
    builder.maybe_add_implicit_return()
    # Generate special function representing module top level.
    args, _, blocks, ret_type, _ = builder.leave()
    sig = FuncSignature([], none_rprimitive)
    func_ir = FuncIR(
        FuncDecl(TOP_LEVEL_NAME, None, builder.module_name, sig),
        args,
        blocks,
        traceback_name="<module>",
    )
    builder.functions.append(func_ir)

View file

@ -0,0 +1,244 @@
"""Maintain a mapping from mypy concepts to IR/compiled concepts."""
from __future__ import annotations
from mypy.nodes import ARG_STAR, ARG_STAR2, GDEF, ArgKind, FuncDef, RefExpr, SymbolNode, TypeInfo
from mypy.types import (
AnyType,
CallableType,
Instance,
LiteralType,
NoneTyp,
Overloaded,
PartialType,
TupleType,
Type,
TypedDictType,
TypeType,
TypeVarLikeType,
UnboundType,
UninhabitedType,
UnionType,
find_unpack_in_list,
get_proper_type,
)
from mypyc.ir.class_ir import ClassIR
from mypyc.ir.func_ir import FuncDecl, FuncSignature, RuntimeArg
from mypyc.ir.rtypes import (
KNOWN_NATIVE_TYPES,
RInstance,
RTuple,
RType,
RUnion,
RVec,
bool_rprimitive,
bytearray_rprimitive,
bytes_rprimitive,
dict_rprimitive,
float_rprimitive,
frozenset_rprimitive,
int16_rprimitive,
int32_rprimitive,
int64_rprimitive,
int_rprimitive,
list_rprimitive,
none_rprimitive,
object_rprimitive,
range_rprimitive,
set_rprimitive,
str_rprimitive,
tuple_rprimitive,
uint8_rprimitive,
)
class Mapper:
    """Keep track of mappings from mypy concepts to IR concepts.

    For example, we keep track of how the mypy TypeInfos of compiled
    classes map to class IR objects.

    This state is shared across all modules being compiled in all
    compilation groups.
    """
    def __init__(self, group_map: dict[str, str | None]) -> None:
        # Module name -> name of its compilation group (membership means
        # the module is compiled by mypyc; see is_native_module).
        self.group_map = group_map
        # Compiled class TypeInfo -> its class IR.
        self.type_to_ir: dict[TypeInfo, ClassIR] = {}
        # Compiled function/method symbol -> its IR declaration.
        self.func_to_decl: dict[SymbolNode, FuncDecl] = {}
        # Full names treated as native containers when resolving
        # references (see is_native_ref_expr).
        self.symbol_fullnames: set[str] = set()
        # The corresponding generator class that implements a generator/async function
        self.fdef_to_generator: dict[FuncDef, ClassIR] = {}
    def type_to_rtype(self, typ: Type | None) -> RType:
        """Map a mypy type to the RType used to represent its values.

        A missing type (None) maps to object_rprimitive. Types we don't
        specially support also fall back to object_rprimitive.
        """
        if typ is None:
            return object_rprimitive
        typ = get_proper_type(typ)
        if isinstance(typ, Instance):
            if typ.type.is_newtype:
                # Unwrap NewType to its base type for rprimitive mapping
                assert len(typ.type.bases) == 1, typ.type.bases
                return self.type_to_rtype(typ.type.bases[0])
            if typ.type.fullname == "builtins.int":
                return int_rprimitive
            elif typ.type.fullname == "builtins.float":
                return float_rprimitive
            elif typ.type.fullname == "builtins.bool":
                return bool_rprimitive
            elif typ.type.fullname == "builtins.str":
                return str_rprimitive
            elif typ.type.fullname == "builtins.bytes":
                return bytes_rprimitive
            elif typ.type.fullname == "builtins.bytearray":
                return bytearray_rprimitive
            elif typ.type.fullname == "builtins.list":
                return list_rprimitive
            # Dict subclasses are at least somewhat common and we
            # specifically support them, so make sure that dict operations
            # get optimized on them.
            elif any(cls.fullname == "builtins.dict" for cls in typ.type.mro):
                return dict_rprimitive
            elif typ.type.fullname == "builtins.set":
                return set_rprimitive
            elif typ.type.fullname == "builtins.frozenset":
                return frozenset_rprimitive
            elif typ.type.fullname == "builtins.tuple":
                return tuple_rprimitive  # Varying-length tuple
            elif typ.type.fullname == "builtins.range":
                return range_rprimitive
            elif typ.type in self.type_to_ir:
                inst = RInstance(self.type_to_ir[typ.type])
                # Treat protocols as Union[protocol, object], so that we can do fast
                # method calls in the cases where the protocol is explicitly inherited from
                # and fall back to generic operations when it isn't.
                if typ.type.is_protocol:
                    return RUnion([inst, object_rprimitive])
                else:
                    return inst
            elif typ.type.fullname == "mypy_extensions.i64":
                return int64_rprimitive
            elif typ.type.fullname == "mypy_extensions.i32":
                return int32_rprimitive
            elif typ.type.fullname == "mypy_extensions.i16":
                return int16_rprimitive
            elif typ.type.fullname == "mypy_extensions.u8":
                return uint8_rprimitive
            elif typ.type.fullname == "librt.vecs.vec":
                return RVec(self.type_to_rtype(typ.args[0]))
            elif typ.type.fullname in KNOWN_NATIVE_TYPES:
                return KNOWN_NATIVE_TYPES[typ.type.fullname]
            else:
                return object_rprimitive
        elif isinstance(typ, TupleType):
            # Use our unboxed tuples for raw tuples but fall back to
            # being boxed for NamedTuple or for variadic tuples.
            if (
                typ.partial_fallback.type.fullname == "builtins.tuple"
                and find_unpack_in_list(typ.items) is None
            ):
                return RTuple([self.type_to_rtype(t) for t in typ.items])
            else:
                return tuple_rprimitive
        elif isinstance(typ, CallableType):
            return object_rprimitive
        elif isinstance(typ, NoneTyp):
            return none_rprimitive
        elif isinstance(typ, UnionType):
            return RUnion.make_simplified_union([self.type_to_rtype(item) for item in typ.items])
        elif isinstance(typ, AnyType):
            return object_rprimitive
        elif isinstance(typ, TypeType):
            return object_rprimitive
        elif isinstance(typ, TypeVarLikeType):
            # Erase type variable to upper bound.
            # TODO: Erase to union if object has value restriction?
            return self.type_to_rtype(typ.upper_bound)
        elif isinstance(typ, PartialType):
            assert typ.var.type is not None
            return self.type_to_rtype(typ.var.type)
        elif isinstance(typ, Overloaded):
            return object_rprimitive
        elif isinstance(typ, TypedDictType):
            return dict_rprimitive
        elif isinstance(typ, LiteralType):
            return self.type_to_rtype(typ.fallback)
        elif isinstance(typ, (UninhabitedType, UnboundType)):
            # Sure, whatever!
            return object_rprimitive
        # I think we've covered everything that is supposed to
        # actually show up, so anything else is a bug somewhere.
        assert False, "unexpected type %s" % type(typ)
    def get_arg_rtype(self, typ: Type, kind: ArgKind) -> RType:
        """Return the RType of an argument, accounting for its kind.

        *args is received as a tuple and **kwargs as a dict, regardless
        of the annotated item type.
        """
        if kind == ARG_STAR:
            return tuple_rprimitive
        elif kind == ARG_STAR2:
            return dict_rprimitive
        else:
            return self.type_to_rtype(typ)
    def fdef_to_sig(self, fdef: FuncDef, strict_dunders_typing: bool) -> FuncSignature:
        """Compute the IR signature for a function definition.

        With strict_dunders_typing disabled, certain comparison dunders
        are forced to return object (see the comment near the end).
        """
        if isinstance(fdef.type, CallableType):
            arg_types = [
                self.get_arg_rtype(typ, kind)
                for typ, kind in zip(fdef.type.arg_types, fdef.type.arg_kinds)
            ]
            arg_pos_onlys = [name is None for name in fdef.type.arg_names]
            ret = self.type_to_rtype(fdef.type.ret_type)
        else:
            # Handle unannotated functions
            arg_types = [object_rprimitive for _ in fdef.arguments]
            arg_pos_onlys = [arg.pos_only for arg in fdef.arguments]
            # We at least know the return type for __init__ methods will be None.
            is_init_method = fdef.name == "__init__" and bool(fdef.info)
            if is_init_method:
                ret = none_rprimitive
            else:
                ret = object_rprimitive
        # mypyc FuncSignatures (unlike mypy types) want to have a name
        # present even when the argument is position only, since it is
        # the sole way that FuncDecl arguments are tracked. This is
        # generally fine except in some cases (like for computing
        # init_sig) we need to produce FuncSignatures from a
        # deserialized FuncDef that lacks arguments. We won't ever
        # need to use those inside of a FuncIR, so we just make up
        # some crap.
        if hasattr(fdef, "arguments"):
            arg_names = [arg.variable.name for arg in fdef.arguments]
        else:
            arg_names = [name or "" for name in fdef.arg_names]
        args = [
            RuntimeArg(arg_name, arg_type, arg_kind, arg_pos_only)
            for arg_name, arg_kind, arg_type, arg_pos_only in zip(
                arg_names, fdef.arg_kinds, arg_types, arg_pos_onlys
            )
        ]
        if not strict_dunders_typing:
            # We force certain dunder methods to return objects to support letting them
            # return NotImplemented. It also avoids some pointless boxing and unboxing,
            # since tp_richcompare needs an object anyways.
            # However, it also prevents some optimizations.
            if fdef.name in ("__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"):
                ret = object_rprimitive
        return FuncSignature(args, ret)
    def is_native_module(self, module: str) -> bool:
        """Is the given module one compiled by mypyc?"""
        return module in self.group_map
    def is_native_ref_expr(self, expr: RefExpr) -> bool:
        """Does this reference refer to a native (compiled) symbol?

        References without a dotted full name are treated as native.
        """
        if expr.node is None:
            return False
        if "." in expr.node.fullname:
            name = expr.node.fullname.rpartition(".")[0]
            return self.is_native_module(name) or name in self.symbol_fullnames
        return True
    def is_native_module_ref_expr(self, expr: RefExpr) -> bool:
        """Is this a native reference at module scope (GDEF)?"""
        return self.is_native_ref_expr(expr) and expr.kind == GDEF

View file

@ -0,0 +1,367 @@
from __future__ import annotations
from collections.abc import Generator
from contextlib import contextmanager
from mypy.nodes import MatchStmt, NameExpr, TypeInfo
from mypy.patterns import (
AsPattern,
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
ValuePattern,
)
from mypy.traverser import TraverserVisitor
from mypy.types import Instance, LiteralType, TupleType, get_proper_type
from mypyc.ir.ops import BasicBlock, Value
from mypyc.ir.rtypes import object_rprimitive
from mypyc.irbuild.builder import IRBuilder
from mypyc.primitives.dict_ops import (
dict_copy,
dict_del_item,
mapping_has_key,
supports_mapping_protocol,
)
from mypyc.primitives.generic_ops import generic_ssize_t_len_op
from mypyc.primitives.list_ops import (
sequence_get_item,
sequence_get_slice,
supports_sequence_protocol,
)
from mypyc.primitives.misc_ops import fast_isinstance_op, slow_isinstance_op
# From: https://peps.python.org/pep-0634/#class-patterns
# Built-in classes whose class patterns accept a single positional
# subpattern that is matched against the subject itself (e.g.
# ``case int(x)``) instead of against __match_args__ attributes.
MATCHABLE_BUILTINS = {
    "builtins.bool",
    "builtins.bytearray",
    "builtins.bytes",
    "builtins.dict",
    "builtins.float",
    "builtins.frozenset",
    "builtins.int",
    "builtins.list",
    "builtins.set",
    "builtins.str",
    "builtins.tuple",
}
class MatchVisitor(TraverserVisitor):
    """Generate IR for a single ``match`` statement.

    Patterns are visited with ``code_block`` as the success target and
    ``next_block`` as the failure target; ``build_match_body`` emits the
    guard and body of each case, and all bodies jump to ``final_block``.
    """

    builder: IRBuilder
    # Block to jump to when the current (sub)pattern matches.
    code_block: BasicBlock
    # Block to jump to when the current (sub)pattern fails to match.
    next_block: BasicBlock
    # Block following the whole match statement.
    final_block: BasicBlock
    # Value currently being matched (replaced while visiting subpatterns).
    subject: Value
    match: MatchStmt
    # Innermost enclosing "as" pattern whose name still needs binding.
    as_pattern: AsPattern | None = None
    def __init__(self, builder: IRBuilder, match_node: MatchStmt) -> None:
        self.builder = builder
        self.code_block = BasicBlock()
        self.next_block = BasicBlock()
        self.final_block = BasicBlock()
        self.match = match_node
        # Evaluate the match subject exactly once, up front.
        self.subject = builder.accept(match_node.subject)
    def build_match_body(self, index: int) -> None:
        """Emit the guard (if any) and body of case ``index``."""
        self.builder.activate_block(self.code_block)
        guard = self.match.guards[index]
        if guard:
            self.code_block = BasicBlock()
            cond = self.builder.accept(guard)
            # A failing guard falls through to the next case's pattern.
            self.builder.add_bool_branch(cond, self.code_block, self.next_block)
            self.builder.activate_block(self.code_block)
        self.builder.accept(self.match.bodies[index])
        self.builder.goto(self.final_block)
    def visit_match_stmt(self, m: MatchStmt) -> None:
        """Emit IR for all cases of the match statement, in order."""
        for i, pattern in enumerate(m.patterns):
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()
            pattern.accept(self)
            self.build_match_body(i)
            self.builder.activate_block(self.next_block)
        self.builder.goto_and_activate(self.final_block)
    def visit_value_pattern(self, pattern: ValuePattern) -> None:
        """Match the subject against a value pattern via ``==``."""
        value = self.builder.accept(pattern.expr)
        cond = self.builder.binary_op(self.subject, value, "==", pattern.expr.line)
        self.bind_as_pattern(value)
        self.builder.add_bool_branch(cond, self.code_block, self.next_block)
    def visit_or_pattern(self, pattern: OrPattern) -> None:
        """Match any of the alternatives: success in any branch jumps to
        the shared success block."""
        code_block = self.code_block
        next_block = self.next_block
        for p in pattern.patterns:
            self.code_block = BasicBlock()
            self.next_block = BasicBlock()
            # Hack to ensure the as pattern is bound to each pattern in the
            # "or" pattern, but not every subpattern
            backup = self.as_pattern
            p.accept(self)
            self.as_pattern = backup
            self.builder.activate_block(self.code_block)
            self.builder.goto(code_block)
            self.builder.activate_block(self.next_block)
        self.code_block = code_block
        self.next_block = next_block
        # All alternatives failed: the whole "or" pattern fails.
        self.builder.goto(self.next_block)
    def visit_class_pattern(self, pattern: ClassPattern) -> None:
        """Match via isinstance, then positional/keyword subpatterns."""
        # TODO: use faster instance check for native classes (while still
        # making sure to account for inheritance)
        isinstance_op = (
            fast_isinstance_op
            if self.builder.is_builtin_ref_expr(pattern.class_ref)
            else slow_isinstance_op
        )
        cond = self.builder.primitive_op(
            isinstance_op, [self.subject, self.builder.accept(pattern.class_ref)], pattern.line
        )
        self.builder.add_bool_branch(cond, self.code_block, self.next_block)
        self.bind_as_pattern(self.subject, new_block=True)
        if pattern.positionals:
            if pattern.class_ref.fullname in MATCHABLE_BUILTINS:
                # Per PEP 634, these builtins match their single positional
                # subpattern against the subject itself.
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()
                pattern.positionals[0].accept(self)
                return
            node = pattern.class_ref.node
            assert isinstance(node, TypeInfo), node
            match_args = extract_dunder_match_args_names(node)
            for i, expr in enumerate(pattern.positionals):
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()
                # TODO: use faster "get_attr" method instead when calling on native or
                # builtin objects
                positional = self.builder.py_get_attr(self.subject, match_args[i], expr.line)
                with self.enter_subpattern(positional):
                    expr.accept(self)
        for key, value in zip(pattern.keyword_keys, pattern.keyword_values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            # TODO: same as above "get_attr" comment
            attr = self.builder.py_get_attr(self.subject, key, value.line)
            with self.enter_subpattern(attr):
                value.accept(self)
    def visit_as_pattern(self, pattern: AsPattern) -> None:
        """Handle ``<pattern> as <name>`` and bare capture patterns."""
        if pattern.pattern:
            # Defer the binding: it happens when the inner pattern matches
            # (see bind_as_pattern).
            old_pattern = self.as_pattern
            self.as_pattern = pattern
            pattern.pattern.accept(self)
            self.as_pattern = old_pattern
        elif pattern.name:
            # Bare capture pattern: always matches, just bind the subject.
            target = self.builder.get_assignment_target(pattern.name)
            self.builder.assign(target, self.subject, pattern.line)
            self.builder.goto(self.code_block)
    def visit_singleton_pattern(self, pattern: SingletonPattern) -> None:
        """Match None/True/False by identity, per PEP 634."""
        if pattern.value is None:
            obj = self.builder.none_object()
        elif pattern.value is True:
            obj = self.builder.true()
        else:
            obj = self.builder.false()
        cond = self.builder.binary_op(self.subject, obj, "is", pattern.line)
        self.builder.add_bool_branch(cond, self.code_block, self.next_block)
    def visit_mapping_pattern(self, pattern: MappingPattern) -> None:
        """Match a mapping: protocol check, then per-key presence and
        value subpatterns, then the optional ``**rest`` capture."""
        is_dict = self.builder.call_c(supports_mapping_protocol, [self.subject], pattern.line)
        self.builder.add_bool_branch(is_dict, self.code_block, self.next_block)
        keys: list[Value] = []
        for key, value in zip(pattern.keys, pattern.values):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            key_value = self.builder.accept(key)
            keys.append(key_value)
            exists = self.builder.call_c(mapping_has_key, [self.subject, key_value], pattern.line)
            self.builder.add_bool_branch(exists, self.code_block, self.next_block)
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            item = self.builder.gen_method_call(
                self.subject, "__getitem__", [key_value], object_rprimitive, pattern.line
            )
            with self.enter_subpattern(item):
                value.accept(self)
        if pattern.rest:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            # ``**rest`` gets a copy of the mapping minus the matched keys.
            rest = self.builder.primitive_op(dict_copy, [self.subject], pattern.rest.line)
            target = self.builder.get_assignment_target(pattern.rest)
            self.builder.assign(target, rest, pattern.rest.line)
            for i, key_name in enumerate(keys):
                self.builder.call_c(dict_del_item, [rest, key_name], pattern.keys[i].line)
            self.builder.goto(self.code_block)
    def visit_sequence_pattern(self, seq_pattern: SequencePattern) -> None:
        """Match a sequence: protocol check, length check, per-item
        subpatterns, then the optional starred capture slice."""
        star_index, capture, patterns = prep_sequence_pattern(seq_pattern)
        is_list = self.builder.call_c(supports_sequence_protocol, [self.subject], seq_pattern.line)
        self.builder.add_bool_branch(is_list, self.code_block, self.next_block)
        self.builder.activate_block(self.code_block)
        self.code_block = BasicBlock()
        actual_len = self.builder.call_c(generic_ssize_t_len_op, [self.subject], seq_pattern.line)
        min_len = len(patterns)
        # Exact length without a star; at-least length with one.
        is_long_enough = self.builder.binary_op(
            actual_len,
            self.builder.load_int(min_len),
            "==" if star_index is None else ">=",
            seq_pattern.line,
        )
        self.builder.add_bool_branch(is_long_enough, self.code_block, self.next_block)
        for i, pattern in enumerate(patterns):
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            if star_index is not None and i >= star_index:
                # Items after the star are indexed from the end.
                current = self.builder.binary_op(
                    actual_len, self.builder.load_int(min_len - i), "-", pattern.line
                )
            else:
                current = self.builder.load_int(i)
            item = self.builder.call_c(sequence_get_item, [self.subject, current], pattern.line)
            with self.enter_subpattern(item):
                pattern.accept(self)
        if capture and star_index is not None:
            self.builder.activate_block(self.code_block)
            self.code_block = BasicBlock()
            capture_end = self.builder.binary_op(
                actual_len, self.builder.load_int(min_len - star_index), "-", capture.line
            )
            rest = self.builder.call_c(
                sequence_get_slice,
                [self.subject, self.builder.load_int(star_index), capture_end],
                capture.line,
            )
            target = self.builder.get_assignment_target(capture)
            self.builder.assign(target, rest, capture.line)
            self.builder.goto(self.code_block)
    def bind_as_pattern(self, value: Value, new_block: bool = False) -> None:
        """Bind ``value`` to a pending enclosing "as" pattern, if any."""
        if self.as_pattern and self.as_pattern.pattern and self.as_pattern.name:
            if new_block:
                self.builder.activate_block(self.code_block)
                self.code_block = BasicBlock()
            target = self.builder.get_assignment_target(self.as_pattern.name)
            self.builder.assign(target, value, self.as_pattern.pattern.line)
            # Consume the binding so nested subpatterns don't re-bind it.
            self.as_pattern = None
            if new_block:
                self.builder.goto(self.code_block)
    @contextmanager
    def enter_subpattern(self, subject: Value) -> Generator[None]:
        """Temporarily replace the match subject while visiting a subpattern."""
        old_subject = self.subject
        self.subject = subject
        yield
        self.subject = old_subject
def prep_sequence_pattern(
    seq_pattern: SequencePattern,
) -> tuple[int | None, NameExpr | None, list[Pattern]]:
    """Split a sequence pattern into its star component and fixed items.

    Returns a triple of (index of the starred pattern or None, its capture
    name or None, the remaining non-starred subpatterns in order).
    """
    star_index: int | None = None
    capture: NameExpr | None = None
    fixed_patterns: list[Pattern] = []
    for position, sub_pattern in enumerate(seq_pattern.patterns):
        if not isinstance(sub_pattern, StarredPattern):
            fixed_patterns.append(sub_pattern)
        else:
            # Note the star's position in the *original* pattern list.
            star_index = position
            capture = sub_pattern.capture
    return star_index, capture, fixed_patterns
def extract_dunder_match_args_names(info: TypeInfo) -> list[str]:
    """Return the attribute names from a class's ``__match_args__`` tuple.

    Asserts that the symbol exists, that its type is a TupleType, and
    that every item is a string literal (or an instance carrying a known
    string literal value).
    """
    sym = info.names.get("__match_args__")
    assert sym
    match_args_type = get_proper_type(sym.type)
    assert isinstance(match_args_type, TupleType), match_args_type
    names: list[str] = []
    for item in match_args_type.items:
        proper_item = get_proper_type(item)
        name = None
        if isinstance(proper_item, LiteralType):
            name = proper_item.value
        elif isinstance(proper_item, Instance) and proper_item.last_known_value:
            name = proper_item.last_known_value.value
        assert isinstance(name, str), f"Unrecognized __match_args__ item: {item}"
        names.append(name)
    return names

View file

@ -0,0 +1,20 @@
from __future__ import annotations
from mypy.nodes import Expression, Node
from mypy.traverser import ExtendedTraverserVisitor
from mypy.types import AnyType, Type, TypeOfAny
class MissingTypesVisitor(ExtendedTraverserVisitor):
    """AST visitor that fills in missing expression types with a generic AnyType."""

    def __init__(self, types: dict[Expression, Type]) -> None:
        super().__init__()
        # Shared type map; entries are added in place as nodes are visited.
        self.types: dict[Expression, Type] = types

    def visit(self, o: Node) -> bool:
        needs_type = isinstance(o, Expression) and o not in self.types
        if needs_type:
            self.types[o] = AnyType(TypeOfAny.special_form)
        # Returning True tells the traverser to descend into child nodes.
        return True

View file

@ -0,0 +1,216 @@
"""Helpers for dealing with nonlocal control such as 'break' and 'return'.
Model how these behave differently in different contexts.
"""
from __future__ import annotations
from abc import abstractmethod
from typing import TYPE_CHECKING
from mypyc.ir.ops import (
NO_TRACEBACK_LINE_NO,
BasicBlock,
Branch,
Goto,
Integer,
Register,
Return,
SetMem,
Unreachable,
Value,
)
from mypyc.ir.rtypes import object_rprimitive
from mypyc.irbuild.targets import AssignmentTarget
from mypyc.primitives.exc_ops import restore_exc_info_op, set_stop_iteration_value
if TYPE_CHECKING:
from mypyc.irbuild.builder import IRBuilder
class NonlocalControl:
    """ABC representing a stack frame of constructs that modify nonlocal control flow.

    The nonlocal control flow constructs are break, continue, and
    return, and their behavior is modified by a number of other
    constructs. The most obvious is loop, which override where break
    and continue jump to, but also `except` (which needs to clear
    exc_info when left) and (eventually) finally blocks (which need to
    ensure that the finally block is always executed when leaving the
    try/except blocks).
    """
    @abstractmethod
    def gen_break(self, builder: IRBuilder, line: int) -> None:
        """Emit IR for a 'break' in this context."""
        pass
    @abstractmethod
    def gen_continue(self, builder: IRBuilder, line: int) -> None:
        """Emit IR for a 'continue' in this context."""
        pass
    @abstractmethod
    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        """Emit IR for a 'return' of ``value`` in this context."""
        pass
class BaseNonlocalControl(NonlocalControl):
    """Default nonlocal control outside any statements that affect it."""
    def gen_break(self, builder: IRBuilder, line: int) -> None:
        # Unreachable for valid input; the assert guards against compiler bugs.
        assert False, "break outside of loop"
    def gen_continue(self, builder: IRBuilder, line: int) -> None:
        # Unreachable for valid input; the assert guards against compiler bugs.
        assert False, "continue outside of loop"
    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        # With nothing to clean up, a return is just a Return op.
        builder.add(Return(value, line))
class LoopNonlocalControl(NonlocalControl):
    """Nonlocal control within a loop.

    Redirects break/continue to the loop's own blocks and defers
    return handling to the enclosing context.
    """

    def __init__(
        self, outer: NonlocalControl, continue_block: BasicBlock, break_block: BasicBlock
    ) -> None:
        self.outer = outer
        self.continue_block = continue_block
        self.break_block = break_block

    def _jump(self, builder: IRBuilder, target: BasicBlock) -> None:
        # Inside a loop, break and continue are plain unconditional jumps.
        builder.add(Goto(target))

    def gen_break(self, builder: IRBuilder, line: int) -> None:
        self._jump(builder, self.break_block)

    def gen_continue(self, builder: IRBuilder, line: int) -> None:
        self._jump(builder, self.continue_block)

    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        # 'return' is not affected by loops; the outer context handles it.
        self.outer.gen_return(builder, value, line)
class GeneratorNonlocalControl(BaseNonlocalControl):
    """Default nonlocal control in a generator function outside statements."""
    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        """Emit the StopIteration/fast-return protocol for a generator return."""
        # Assign an invalid next label number so that the next time
        # __next__ is called, we jump to the case in which
        # StopIteration is raised.
        builder.assign(builder.fn_info.generator_class.next_label_target, Integer(-1), line)
        # Raise a StopIteration containing a field for the value that
        # should be returned. Before doing so, create a new block
        # without an error handler set so that the implicitly thrown
        # StopIteration isn't caught by except blocks inside of the
        # generator function.
        builder.builder.push_error_handler(None)
        builder.goto_and_activate(BasicBlock())
        # Skip creating a traceback frame when we raise here, because
        # we don't care about the traceback frame and it is kind of
        # expensive since raising StopIteration is an extremely common
        # case. Also we call a special internal function to set
        # StopIteration instead of using RaiseStandardError because
        # the obvious thing doesn't work if the value is a tuple
        # (???).
        true, false = BasicBlock(), BasicBlock()
        stop_iter_reg = builder.fn_info.generator_class.stop_iter_value_reg
        assert stop_iter_reg is not None
        # Branch on whether a caller-provided result pointer is available.
        builder.add(Branch(stop_iter_reg, true, false, Branch.IS_ERROR))
        builder.activate_block(true)
        # The default/slow path is to raise a StopIteration exception with
        # return value.
        builder.call_c(set_stop_iteration_value, [value], NO_TRACEBACK_LINE_NO)
        builder.add(Unreachable())
        builder.builder.pop_error_handler()
        builder.activate_block(false)
        # The fast path is to store return value via caller-provided pointer
        # instead of raising an exception. This can only be used when the
        # caller is a native function.
        builder.add(SetMem(object_rprimitive, stop_iter_reg, value))
        builder.add(Return(Integer(0, object_rprimitive)))
class CleanupNonlocalControl(NonlocalControl):
    """Abstract nonlocal control that runs some cleanup code.

    Subclasses define gen_cleanup; it runs before the control transfer
    is delegated to the outer context.
    """
    def __init__(self, outer: NonlocalControl) -> None:
        self.outer = outer
    @abstractmethod
    def gen_cleanup(self, builder: IRBuilder, line: int) -> None: ...
    def gen_break(self, builder: IRBuilder, line: int) -> None:
        self.gen_cleanup(builder, line)
        self.outer.gen_break(builder, line)
    def gen_continue(self, builder: IRBuilder, line: int) -> None:
        self.gen_cleanup(builder, line)
        self.outer.gen_continue(builder, line)
    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        self.gen_cleanup(builder, line)
        self.outer.gen_return(builder, value, line)
class TryFinallyNonlocalControl(NonlocalControl):
    """Nonlocal control within try/finally."""
    def __init__(self, target: BasicBlock) -> None:
        # Block that runs the finally body.
        self.target = target
        # Lazily-created location that stashes the pending return value
        # while the finally body executes.
        self.ret_reg: None | Register | AssignmentTarget = None
    def gen_break(self, builder: IRBuilder, line: int) -> None:
        builder.error("break inside try/finally block is unimplemented", line)
    def gen_continue(self, builder: IRBuilder, line: int) -> None:
        builder.error("continue inside try/finally block is unimplemented", line)
    def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
        """Stash the return value and jump to the finally body."""
        if self.ret_reg is None:
            if builder.fn_info.is_generator:
                # Generators spill the value into the generator object.
                self.ret_reg = builder.make_spill_target(builder.ret_types[-1])
            else:
                self.ret_reg = Register(builder.ret_types[-1])
        # assert needed because of apparent mypy bug... it loses track of the union
        # and infers the type as object
        assert isinstance(self.ret_reg, (Register, AssignmentTarget)), self.ret_reg
        builder.assign(self.ret_reg, value, line)
        builder.add(Goto(self.target))
class ExceptNonlocalControl(CleanupNonlocalControl):
    """Nonlocal control for except blocks.

    Just makes sure that sys.exc_info always gets restored when we leave.
    This is super annoying.
    """
    def __init__(self, outer: NonlocalControl, saved: Value | AssignmentTarget) -> None:
        super().__init__(outer)
        # The exc_info state captured when the except block was entered.
        self.saved = saved
    def gen_cleanup(self, builder: IRBuilder, line: int) -> None:
        builder.call_c(restore_exc_info_op, [builder.read(self.saved, line)], line)
class FinallyNonlocalControl(CleanupNonlocalControl):
    """Nonlocal control for finally blocks.

    Just makes sure that sys.exc_info always gets restored when we
    leave and the return register is decrefed if it isn't null.
    """
    def __init__(self, outer: NonlocalControl, saved: Value) -> None:
        super().__init__(outer)
        # The exc_info state to restore; may hold an error value when
        # there was no exception in flight.
        self.saved = saved
    def gen_cleanup(self, builder: IRBuilder, line: int) -> None:
        # Restore the old exc_info
        target, cleanup = BasicBlock(), BasicBlock()
        # Only restore when 'saved' actually holds saved exc state.
        builder.add(Branch(self.saved, target, cleanup, Branch.IS_ERROR))
        builder.activate_block(cleanup)
        builder.call_c(restore_exc_info_op, [self.saved], line)
        builder.goto_and_activate(target)

View file

@ -0,0 +1,305 @@
from __future__ import annotations
from mypy.nodes import (
AssignmentStmt,
Block,
Decorator,
DictionaryComprehension,
Expression,
FuncDef,
FuncItem,
GeneratorExpr,
Import,
LambdaExpr,
MemberExpr,
MypyFile,
NameExpr,
Node,
SymbolNode,
Var,
)
from mypy.traverser import ExtendedTraverserVisitor, TraverserVisitor
from mypy.types import Type
from mypyc.errors import Errors
from mypyc.irbuild.missingtypevisitor import MissingTypesVisitor
class _LambdaChecker(TraverserVisitor):
    """Check whether an AST subtree contains a lambda expression."""
    # Latched to True once any lambda expression is visited.
    found = False
    def visit_lambda_expr(self, _o: LambdaExpr) -> None:
        self.found = True
def _comprehension_has_lambda(node: GeneratorExpr | DictionaryComprehension) -> bool:
    """Return True if a comprehension body contains a lambda.

    Only the body expressions are inspected -- left_expr for a generator,
    key/value for a dict comprehension, plus all conditions. The iterated
    sequences are excluded since they are evaluated in the enclosing scope.
    """
    checker = _LambdaChecker()
    if isinstance(node, GeneratorExpr):
        body_exprs = [node.left_expr]
    else:
        body_exprs = [node.key, node.value]
    for cond_list in node.condlists:
        body_exprs.extend(cond_list)
    for expr in body_exprs:
        expr.accept(checker)
    return checker.found
class PreBuildVisitor(ExtendedTraverserVisitor):
"""Mypy file AST visitor run before building the IR.
This collects various things, including:
* Determine relationships between nested functions and functions that
contain nested functions
* Find non-local variables (free variables)
* Find property setters
* Find decorators of functions
* Find module import groups
The main IR build pass uses this information.
"""
def __init__(
self,
errors: Errors,
current_file: MypyFile,
decorators_to_remove: dict[FuncDef, list[int]],
types: dict[Expression, Type],
) -> None:
super().__init__()
# Dict from a function to symbols defined directly in the
# function that are used as non-local (free) variables within a
# nested function.
self.free_variables: dict[FuncItem, set[SymbolNode]] = {}
# Intermediate data structure used to find the function where
# a SymbolNode is declared. Initially this may point to a
# function nested inside the function with the declaration,
# but we'll eventually update this to refer to the function
# with the declaration.
self.symbols_to_funcs: dict[SymbolNode, FuncItem] = {}
# Stack representing current function nesting.
self.funcs: list[FuncItem] = []
# All property setters encountered so far.
self.prop_setters: set[FuncDef] = set()
# A map from any function that contains nested functions to
# a set of all the functions that are nested within it.
self.encapsulating_funcs: dict[FuncItem, list[FuncItem]] = {}
# Map nested function to its parent/encapsulating function.
self.nested_funcs: dict[FuncItem, FuncItem] = {}
# Map function to its non-special decorators.
self.funcs_to_decorators: dict[FuncDef, list[Expression]] = {}
# Map function to indices of decorators to remove
self.decorators_to_remove: dict[FuncDef, list[int]] = decorators_to_remove
# A mapping of import groups (a series of Import nodes with
# nothing in between) where each group is keyed by its first
# import node.
self.module_import_groups: dict[Import, list[Import]] = {}
self._current_import_group: Import | None = None
self.errors: Errors = errors
self.current_file: MypyFile = current_file
self.missing_types_visitor = MissingTypesVisitor(types)
# Synthetic FuncDef representing the module scope, created on demand
# when a comprehension at module/class level contains a lambda.
self._module_fitem: FuncDef | None = None
# Counter for generating unique synthetic comprehension scope names.
self._comprehension_counter = 0
# Map comprehension AST nodes to synthetic FuncDefs representing
# their scope (only for comprehensions that contain lambdas).
self.comprehension_to_fitem: dict[GeneratorExpr | DictionaryComprehension, FuncDef] = {}
def visit(self, o: Node) -> bool:
    """Generic pre-visit hook for every node.

    Any node that is not an Import terminates the current run of
    consecutive imports, so a later Import starts a fresh group.
    Returns True so traversal continues into children.
    """
    is_import_node = isinstance(o, Import)
    if not is_import_node:
        self._current_import_group = None
    return True
def visit_assignment_stmt(self, stmt: AssignmentStmt) -> None:
    """Visit an assignment, patching in missing types for alias definitions.

    mypy may not record types for some expressions in a type alias
    definition, but mypyc requires some form of type to exist, so the
    rvalue is run through the missing-types visitor first.
    """
    if stmt.is_alias_def:
        rvalue = stmt.rvalue
        rvalue.accept(self.missing_types_visitor)
    return super().visit_assignment_stmt(stmt)
def visit_block(self, block: Block) -> None:
    """Visit a block; an import group never spans a block boundary.

    The current import group is reset both before and after the block
    so imports inside it cannot join a group started outside it (and
    vice versa).
    """
    self._current_import_group = None
    super().visit_block(block)
    self._current_import_group = None
def visit_decorator(self, dec: Decorator) -> None:
    """Classify a decorated function as a property setter or a decorated func.

    Only functions with ordinary decorators remaining (after mypy strips
    special ones such as @property/@abstractmethod, and after removing the
    indices listed in decorators_to_remove) are recorded in
    funcs_to_decorators. Property setters are tracked separately and not
    treated as decorated methods by the IR builder. If every decorator is
    removed, the function is not treated as decorated at all and its body
    is not traversed here (early return, matching the original behavior).
    """
    if dec.decorators:
        first = dec.decorators[0]
        if isinstance(first, MemberExpr) and first.name == "setter":
            # Property setters get their own bookkeeping.
            self.prop_setters.add(dec.func)
        else:
            kept = list(dec.decorators)
            # Remove flagged decorators back-to-front so indices stay valid.
            for idx in reversed(self.decorators_to_remove.get(dec.func, [])):
                del kept[idx]
            if not kept:
                # Nothing left to apply: not a decorated function.
                return
            self.funcs_to_decorators[dec.func] = kept
    super().visit_decorator(dec)
def visit_func_def(self, fdef: FuncDef) -> None:
    """Record a function definition as both a scope and a symbol.

    Order matters: visit_func first traverses the body with fdef on the
    function stack; visit_symbol_node then registers fdef itself in the
    enclosing scope, so a nested def referenced from a deeper function
    can be detected as a free variable.
    """
    # TODO: What about overloaded functions?
    self.visit_func(fdef)
    self.visit_symbol_node(fdef)
def visit_lambda_expr(self, expr: LambdaExpr) -> None:
    """Treat a lambda like any other function scope (push, traverse, pop)."""
    self.visit_func(expr)
def visit_func(self, func: FuncItem) -> None:
    """Traverse a function body while maintaining the nesting bookkeeping.

    If we are already inside some function, record the parent/child
    relationship in both directions (encapsulating_funcs and
    nested_funcs) before pushing func onto the stack and visiting its
    body.
    """
    if self.funcs:
        enclosing = self.funcs[-1]
        # func is nested inside the function on top of the stack.
        self.encapsulating_funcs.setdefault(enclosing, []).append(func)
        self.nested_funcs[func] = enclosing
    self.funcs.append(func)
    super().visit_func(func)
    self.funcs.pop()
def _visit_comprehension_with_scope(self, o: GeneratorExpr | DictionaryComprehension) -> None:
    """Traverse a lambda-containing comprehension inside a synthetic scope.

    A synthetic FuncDef stands in for the comprehension's own scope: it
    is recorded in comprehension_to_fitem, wired into the nesting
    hierarchy, and kept on the function stack while the comprehension
    body is visited. At module/class level a synthetic module-scope
    FuncDef is pushed first so the comprehension scope has a parent.
    """
    module_scope_pushed = False
    if not self.funcs:
        # Module level: lazily create and push the synthetic module scope.
        if self._module_fitem is None:
            module_fitem = FuncDef("__mypyc_module__")
            module_fitem.line = 1
            self._module_fitem = module_fitem
        self.funcs.append(self._module_fitem)
        module_scope_pushed = True

    # Build the synthetic FuncDef representing this comprehension's scope.
    scope = FuncDef(f"__comprehension_{self._comprehension_counter}__")
    self._comprehension_counter += 1
    scope.line = o.line
    self.comprehension_to_fitem[o] = scope

    # Register the scope as nested within the enclosing function.
    parent = self.funcs[-1]
    self.encapsulating_funcs.setdefault(parent, []).append(scope)
    self.nested_funcs[scope] = parent

    # Traverse the comprehension body with the synthetic scope active.
    self.funcs.append(scope)
    if isinstance(o, GeneratorExpr):
        super().visit_generator_expr(o)
    else:
        super().visit_dictionary_comprehension(o)
    self.funcs.pop()

    if module_scope_pushed:
        self.funcs.pop()
def visit_generator_expr(self, o: GeneratorExpr) -> None:
    """Visit a generator expression, adding a synthetic scope if it has lambdas."""
    if not _comprehension_has_lambda(o):
        super().visit_generator_expr(o)
        return
    self._visit_comprehension_with_scope(o)
def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
    """Visit a dict comprehension, adding a synthetic scope if it has lambdas."""
    if not _comprehension_has_lambda(o):
        super().visit_dictionary_comprehension(o)
        return
    self._visit_comprehension_with_scope(o)
def visit_import(self, imp: Import) -> None:
    """Group this import with the preceding run of consecutive imports.

    Groups are keyed by their first Import node; _current_import_group
    holds that head (or None when no group is open). Note it is only
    assigned when a new group is started, so the key stays the head
    import for the whole run.
    """
    group_head = self._current_import_group
    if group_head is None:
        # Start a new group keyed by this import.
        self._current_import_group = imp
        self.module_import_groups[imp] = [imp]
    else:
        self.module_import_groups[group_head].append(imp)
    super().visit_import(imp)
def visit_name_expr(self, expr: NameExpr) -> None:
    """Track references to variables and functions for free-variable analysis."""
    node = expr.node
    if isinstance(node, (Var, FuncDef)):
        self.visit_symbol_node(node)
def visit_var(self, var: Var) -> None:
    """Track a variable declaration for free-variable analysis."""
    self.visit_symbol_node(var)
def visit_symbol_node(self, symbol: SymbolNode) -> None:
    """Update free-variable bookkeeping for a symbol seen in the current scope.

    Three cases based on where the symbol was previously seen:
    * never seen: remember the current function as its (tentative)
      declaring scope;
    * previously seen in a function nested inside the current one: the
      current function is a better candidate for the declaring scope,
      and the symbol is non-local to the nested use;
    * previously seen in an ancestor of the current function: the symbol
      is a free variable here, charged to its declaring scope.
    Symbols seen only in unrelated (sibling) scopes are left alone, and
    nothing is done outside any function.
    """
    if not self.funcs:
        # Not inside a function; free variables are irrelevant here.
        return
    current = self.funcs[-1]
    prev_func = self.symbols_to_funcs.get(symbol)
    if prev_func is None:
        # First sighting: tentatively attribute the symbol to this scope.
        self.symbols_to_funcs[symbol] = current
    elif self.is_parent(current, prev_func):
        # The earlier sighting was in a function nested within this one,
        # so this scope is a better candidate for the declaration.
        self.symbols_to_funcs[symbol] = current
        # TODO: Remove from the prev_func free_variables set?
        self.free_variables.setdefault(current, set()).add(symbol)
    elif self.is_parent(prev_func, current):
        # Already seen in an ancestor scope: this is a non-local symbol.
        self.add_free_variable(symbol)
def is_parent(self, fitem: FuncItem, child: FuncItem) -> bool:
    """Return True if child is nested within fitem, directly or indirectly.

    Walks the nested_funcs parent chain upwards from child, looking for
    fitem among its ancestors.
    """
    node = child
    while node in self.nested_funcs:
        node = self.nested_funcs[node]
        if node == fitem:
            return True
    return False
def add_free_variable(self, symbol: SymbolNode) -> None:
    """Mark symbol as a non-local (free) variable of its declaring function.

    The declaring function is whatever scope symbols_to_funcs currently
    attributes the symbol to (the best candidate found so far).
    """
    owner = self.symbols_to_funcs[symbol]
    self.free_variables.setdefault(owner, set()).add(symbol)

Some files were not shown because too many files have changed in this diff Show more