feat: indie status page MVP -- FastAPI + SQLite

- 8 DB models (services, incidents, monitors, subscribers, etc.)
- Full CRUD API for services, incidents, monitors
- Public status page with live data
- Incident detail page with timeline
- API key authentication
- Uptime monitoring scheduler
- 13 tests passing
- TECHNICAL_DESIGN.md with full spec
This commit is contained in:
IndieStatusBot 2026-04-25 05:00:00 +00:00
commit 902133edd3
4655 changed files with 1342691 additions and 0 deletions

View file

@ -0,0 +1 @@
# This page intentionally left blank

View file

@ -0,0 +1,50 @@
"""Mypy type checker command line tool."""
from __future__ import annotations
import os
import sys
import traceback
from mypy.main import main, process_options
from mypy.util import FancyFormatter
def console_entry() -> None:
    """Console-script entry point: run the checker and map failures to exit codes.

    Exit status 2 is used for every abnormal termination handled here
    (broken pipe, keyboard interrupt, internal error), so wrappers can
    distinguish these from ordinary type-check results.
    """
    try:
        main()
        sys.stdout.flush()
        sys.stderr.flush()
    except BrokenPipeError:
        # Python flushes standard streams on exit; redirect remaining output
        # to devnull to avoid another BrokenPipeError at shutdown
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(2)
    except KeyboardInterrupt:
        # Re-parse the command line so --show-traceback is honored even when
        # the run was interrupted part-way through.
        _, options = process_options(args=sys.argv[1:])
        if options.show_traceback:
            sys.stdout.write(traceback.format_exc())
        formatter = FancyFormatter(sys.stdout, sys.stderr, False)
        msg = "Interrupted\n"
        sys.stdout.write(formatter.style(msg, color="red", bold=True))
        sys.stdout.flush()
        sys.stderr.flush()
        sys.exit(2)
    except Exception as e:
        # Try reporting any uncaught error canonically, otherwise just flush the traceback.
        try:
            import mypy.errors

            _, options = process_options(args=sys.argv[1:])
            mypy.errors.report_internal_error(e, None, 0, None, options)
        except Exception:
            # Best effort only -- fall through to the plain traceback below.
            pass
        sys.stdout.write(traceback.format_exc())
        sys.stdout.flush()
        sys.stderr.flush()
        sys.exit(2)
if __name__ == "__main__":
    # Support direct execution of this module (e.g. `python -m mypy.__main__`).
    console_entry()

View file

@ -0,0 +1,95 @@
"""This module makes it possible to use mypy as part of a Python application.
Since mypy still changes, the API was kept utterly simple and non-intrusive.
It just mimics command line activation without starting a new interpreter.
So the normal docs about the mypy command line apply.
Changes in the command line version of mypy will be immediately usable.
Just import this module and then call the 'run' function with a parameter of
type List[str], containing what normally would have been the command line
arguments to mypy.
Function 'run' returns a Tuple[str, str, int], namely
(<normal_report>, <error_report>, <exit_status>),
in which <normal_report> is what mypy normally writes to sys.stdout,
<error_report> is what mypy normally writes to sys.stderr and exit_status is
the exit status mypy normally returns to the operating system.
Any pretty formatting is left to the caller.
The 'run_dmypy' function is similar, but instead mimics invocation of
dmypy. Note that run_dmypy is not thread-safe and modifies sys.stdout
and sys.stderr during its invocation.
Note that these APIs don't support incremental generation of error
messages.
Trivial example of code using this module:
import sys
from mypy import api
result = api.run(sys.argv[1:])
if result[0]:
print('\nType checking report:\n')
print(result[0]) # stdout
if result[1]:
print('\nError report:\n')
print(result[1]) # stderr
print('\nExit status:', result[2])
"""
from __future__ import annotations
import sys
from collections.abc import Callable
from io import StringIO
from typing import TextIO
def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]:
stdout = StringIO()
stderr = StringIO()
try:
main_wrapper(stdout, stderr)
exit_status = 0
except SystemExit as system_exit:
assert isinstance(system_exit.code, int)
exit_status = system_exit.code
return stdout.getvalue(), stderr.getvalue(), exit_status
def run(args: list[str]) -> tuple[str, str, int]:
    """Type check with mypy as if from the command line.

    Returns (stdout text, stderr text, exit status).
    """
    # Lazy import to avoid needing to import all of mypy to call run_dmypy
    from mypy.main import main

    def invoke(stdout: TextIO, stderr: TextIO) -> None:
        main(args=args, stdout=stdout, stderr=stderr, clean_exit=True)

    return _run(invoke)
def run_dmypy(args: list[str]) -> tuple[str, str, int]:
    """Run the dmypy client as if from the command line.

    Returns (stdout text, stderr text, exit status). Not thread-safe:
    sys.stdout and sys.stderr are temporarily replaced for the duration
    of the call.
    """
    from mypy.dmypy.client import main

    # A bunch of effort has been put into threading stdout and stderr
    # through the main API to avoid the threadsafety problems of
    # modifying sys.stdout/sys.stderr, but that hasn't been done for
    # the dmypy client, so we just do the non-threadsafe thing.
    def swap_streams(stdout: TextIO, stderr: TextIO) -> None:
        saved_out, saved_err = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = stdout, stderr
        try:
            main(args)
        finally:
            sys.stdout, sys.stderr = saved_out, saved_err

    return _run(swap_streams)

View file

@ -0,0 +1,303 @@
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context, TypeInfo
from mypy.type_visitor import TypeTranslator
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
Type,
TypeAliasType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnpackType,
get_proper_type,
remove_dups,
)
def get_target_type(
    tvar: TypeVarLikeType,
    type: Type,
    callable: CallableType,
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool,
) -> Type | None:
    """Validate the type argument `type` against the type variable `tvar`.

    Returns the type to substitute (possibly adjusted, e.g. promoted to an
    allowed value of a value-restricted TypeVar), or None when the argument
    is unsatisfiable and skip_unsatisfied is true. An incompatible value may
    also be reported via report_incompatible_typevar_value while still
    returning the type, so checking can continue.
    """
    p_type = get_proper_type(type)
    # An uninhabited inferred value with a declared default means "no
    # information" -- fall back to the default.
    if isinstance(p_type, UninhabitedType) and tvar.has_default():
        return tvar.default
    # ParamSpec and TypeVarTuple arguments are accepted as-is here.
    if isinstance(tvar, ParamSpecType):
        return type
    if isinstance(tvar, TypeVarTupleType):
        return type
    assert isinstance(tvar, TypeVarType)
    values = tvar.values
    if values:
        # Value-restricted TypeVar (e.g. AnyStr-style).
        if isinstance(p_type, AnyType):
            return type
        if isinstance(p_type, TypeVarType) and p_type.values:
            # Allow substituting T1 for T if every allowed value of T1
            # is also a legal value of T.
            if all(any(mypy.subtypes.is_same_type(v, v1) for v in values) for v1 in p_type.values):
                return type
        matching = []
        for value in values:
            if mypy.subtypes.is_subtype(type, value):
                matching.append(value)
        if matching:
            best = matching[0]
            # If there are more than one matching value, we select the narrowest
            for match in matching[1:]:
                if mypy.subtypes.is_subtype(match, best):
                    best = match
            return best
        if skip_unsatisfied:
            return None
        report_incompatible_typevar_value(callable, type, tvar.name, context)
    else:
        # Ordinary TypeVar: check against the upper bound.
        upper_bound = tvar.upper_bound
        if tvar.name == "Self":
            # Internally constructed Self-types contain class type variables in upper bound,
            # so we need to erase them to avoid false positives. This is safe because we do
            # not support type variables in upper bounds of user defined types.
            upper_bound = erase_typevars(upper_bound)
        if not mypy.subtypes.is_subtype(type, upper_bound):
            if skip_unsatisfied:
                return None
            report_incompatible_typevar_value(callable, type, tvar.name, context)
    return type
def apply_generic_arguments(
    callable: CallableType,
    orig_types: Sequence[Type | None],
    report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
    context: Context,
    skip_unsatisfied: bool = False,
) -> CallableType:
    """Apply generic type arguments to a callable type.

    For example, applying [int] to 'def [T] (T) -> T' results in
    'def (int) -> int'.

    Note that each type can be None; in this case, it will not be applied.

    If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
    bound or constraints, instead of giving an error.
    """
    tvars = callable.variables
    assert len(orig_types) <= len(tvars)
    # Check that inferred type variable values are compatible with allowed
    # values and bounds. Also, promote subtype values to allowed values.
    # Create a map from type variable id to target type.
    id_to_type: dict[TypeVarId, Type] = {}

    for tvar, type in zip(tvars, orig_types):
        assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
        if type is None:
            continue

        target_type = get_target_type(
            tvar, type, callable, report_incompatible_typevar_value, context, skip_unsatisfied
        )
        if target_type is not None:
            id_to_type[tvar.id] = target_type

    # TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
    # not just type variable bounds above.
    param_spec = callable.param_spec()
    if param_spec is not None:
        nt = id_to_type.get(param_spec.id)
        if nt is not None:
            # ParamSpec expansion is special-cased, so we need to always expand callable
            # as a whole, not expanding arguments individually.
            callable = expand_type(callable, id_to_type)
            assert isinstance(callable, CallableType)
            return callable.copy_modified(
                variables=[tv for tv in tvars if tv.id not in id_to_type]
            )

    # Apply arguments to argument types.
    var_arg = callable.var_arg()
    if var_arg is not None and isinstance(var_arg.typ, UnpackType):
        # Same as for ParamSpec, callable with variadic types needs to be expanded as a whole.
        callable = expand_type(callable, id_to_type)
        assert isinstance(callable, CallableType)
        return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
    else:
        callable = callable.copy_modified(
            arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
        )

    # Apply arguments to TypeGuard and TypeIs if any.
    if callable.type_guard is not None:
        type_guard = expand_type(callable.type_guard, id_to_type)
    else:
        type_guard = None
    if callable.type_is is not None:
        type_is = expand_type(callable.type_is, id_to_type)
    else:
        type_is = None

    # The callable may retain some type vars if only some were applied.
    # TODO: move apply_poly() logic here when new inference
    # becomes universally used (i.e. in all passes + in unification).
    # With this new logic we can actually *add* some new free variables.
    remaining_tvars: list[TypeVarLikeType] = []
    for tv in tvars:
        if tv.id in id_to_type:
            continue
        if not tv.has_default():
            remaining_tvars.append(tv)
            continue
        # TypeVarLike isn't in id_to_type mapping.
        # Only expand the TypeVar default here.
        typ = expand_type(tv, id_to_type)
        assert isinstance(typ, TypeVarLikeType)
        remaining_tvars.append(typ)

    return callable.copy_modified(
        ret_type=expand_type(callable.ret_type, id_to_type),
        variables=remaining_tvars,
        type_guard=type_guard,
        type_is=type_is,
    )
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
    """Make free type variables generic in the type if possible.

    This will translate the type `tp` while trying to create valid bindings for
    type variables `poly_tvars` while traversing the type. This follows the same rules
    as we do during semantic analysis phase, examples:
      * Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
      * Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
      * List[T] -> None (not possible)
    """
    try:
        # A fresh translator per type keeps per-traversal state isolated,
        # matching how each argument is translated independently.
        translated_args = [arg.accept(PolyTranslator(poly_tvars)) for arg in tp.arg_types]
        translated_ret = tp.ret_type.accept(PolyTranslator(poly_tvars))
    except PolyTranslationError:
        return None
    return tp.copy_modified(arg_types=translated_args, ret_type=translated_ret, variables=[])
class PolyTranslationError(Exception):
    """Raised when a type cannot be generalized over the requested type variables."""
class PolyTranslator(TypeTranslator):
    """Make free type variables generic in the type if possible.

    See docstring for apply_poly() for details.
    """

    def __init__(
        self,
        poly_tvars: Iterable[TypeVarLikeType],
        bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
        seen_aliases: frozenset[TypeInfo] = frozenset(),
    ) -> None:
        super().__init__()
        # Type variables we are trying to bind somewhere in the translated type.
        self.poly_tvars = set(poly_tvars)
        # This is a simplified version of TypeVarScope used during semantic analysis.
        self.bound_tvars = bound_tvars
        # Callback-protocol TypeInfos already visited (cycle protection).
        self.seen_aliases = seen_aliases

    def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
        # Collect requested type variables appearing free in t's argument types;
        # these become bound by t itself (see visit_callable_type).
        found_vars = []
        for arg in t.arg_types:
            for tv in get_all_type_vars(arg):
                if isinstance(tv, ParamSpecType):
                    # Normalize so that differently-flavored occurrences of the
                    # same ParamSpec compare equal.
                    normalized: TypeVarLikeType = tv.copy_modified(
                        flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
                    )
                else:
                    normalized = tv
                if normalized in self.poly_tvars and normalized not in self.bound_tvars:
                    found_vars.append(normalized)
        return remove_dups(found_vars)

    def visit_callable_type(self, t: CallableType) -> Type:
        found_vars = self.collect_vars(t)
        # Treat the collected vars as bound while translating the callable's
        # components, then restore the previous state.
        self.bound_tvars |= set(found_vars)
        result = super().visit_callable_type(t)
        self.bound_tvars -= set(found_vars)

        assert isinstance(result, ProperType) and isinstance(result, CallableType)
        result.variables = result.variables + tuple(found_vars)
        return result

    def visit_type_var(self, t: TypeVarType) -> Type:
        # A requested variable seen outside any callable that binds it cannot
        # be generalized -- give up.
        if t in self.poly_tvars and t not in self.bound_tvars:
            raise PolyTranslationError()
        return super().visit_type_var(t)

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        if t in self.poly_tvars and t not in self.bound_tvars:
            raise PolyTranslationError()
        return super().visit_param_spec(t)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        if t in self.poly_tvars and t not in self.bound_tvars:
            raise PolyTranslationError()
        return super().visit_type_var_tuple(t)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        if not t.args:
            return t.copy_modified()
        if not t.is_recursive:
            return get_proper_type(t).accept(self)
        # We can't handle polymorphic application for recursive generic aliases
        # without risking an infinite recursion, just give up for now.
        raise PolyTranslationError()

    def visit_instance(self, t: Instance) -> Type:
        if t.type.has_param_spec_type:
            # We need this special-casing to preserve the possibility to store a
            # generic function in an instance type. Things like
            #     forall T . Foo[[x: T], T]
            # are not really expressible in current type system, but this looks like
            # a useful feature, so let's keep it.
            param_spec_index = next(
                i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
            )
            p = get_proper_type(t.args[param_spec_index])
            if isinstance(p, Parameters):
                found_vars = self.collect_vars(p)
                self.bound_tvars |= set(found_vars)
                new_args = [a.accept(self) for a in t.args]
                self.bound_tvars -= set(found_vars)

                repl = new_args[param_spec_index]
                assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
                repl.variables = list(repl.variables) + list(found_vars)
                return t.copy_modified(args=new_args)
        # There is the same problem with callback protocols as with aliases
        # (callback protocols are essentially more flexible aliases to callables).
        if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
            if t.type in self.seen_aliases:
                raise PolyTranslationError()
            call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
            assert call is not None
            return call.accept(
                PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
            )
        return super().visit_instance(t)

View file

@ -0,0 +1,269 @@
"""Utilities for mapping between actual and formal arguments (and their types)."""
from __future__ import annotations
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING
from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.types import (
AnyType,
Instance,
ParamSpecType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
TypeVarTupleType,
UnpackType,
get_proper_type,
)
if TYPE_CHECKING:
from mypy.infer import ArgumentInferContext
def map_actuals_to_formals(
    actual_kinds: list[nodes.ArgKind],
    actual_names: Sequence[str | None] | None,
    formal_kinds: list[nodes.ArgKind],
    formal_names: Sequence[str | None],
    actual_arg_type: Callable[[int], Type],
) -> list[list[int]]:
    """Calculate mapping between actual (caller) args and formals.

    The result contains a list of caller argument indexes mapping to each
    callee argument index, indexed by callee index.

    The actual_arg_type argument should evaluate to the type of the actual
    argument with the given index.
    """
    nformals = len(formal_kinds)
    formal_to_actual: list[list[int]] = [[] for i in range(nformals)]
    # Keyword actuals from a **kwargs of unknown (non-TypedDict) shape; these
    # are matched against leftover formals at the end.
    ambiguous_actual_kwargs: list[int] = []
    # Index of the next formal that can be consumed positionally.
    fi = 0
    for ai, actual_kind in enumerate(actual_kinds):
        if actual_kind == nodes.ARG_POS:
            if fi < nformals:
                if not formal_kinds[fi].is_star():
                    formal_to_actual[fi].append(ai)
                    fi += 1
                elif formal_kinds[fi] == nodes.ARG_STAR:
                    # A *args formal absorbs positional actuals without advancing.
                    formal_to_actual[fi].append(ai)
        elif actual_kind == nodes.ARG_STAR:
            # We need to know the actual type to map varargs.
            actualt = get_proper_type(actual_arg_type(ai))
            if isinstance(actualt, TupleType):
                # A tuple actual maps to a fixed number of formals.
                for _ in range(len(actualt.items)):
                    if fi < nformals:
                        if formal_kinds[fi] != nodes.ARG_STAR2:
                            formal_to_actual[fi].append(ai)
                        else:
                            break
                        if formal_kinds[fi] != nodes.ARG_STAR:
                            fi += 1
            else:
                # Assume that it is an iterable (if it isn't, there will be
                # an error later).
                while fi < nformals:
                    if formal_kinds[fi].is_named(star=True):
                        break
                    else:
                        formal_to_actual[fi].append(ai)
                        if formal_kinds[fi] == nodes.ARG_STAR:
                            break
                        fi += 1
        elif actual_kind.is_named():
            assert actual_names is not None, "Internal error: named kinds without names given"
            name = actual_names[ai]
            if name in formal_names and formal_kinds[formal_names.index(name)] != nodes.ARG_STAR:
                formal_to_actual[formal_names.index(name)].append(ai)
            elif nodes.ARG_STAR2 in formal_kinds:
                # No matching named formal: absorb into **kwargs if present.
                formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
        else:
            assert actual_kind == nodes.ARG_STAR2
            actualt = get_proper_type(actual_arg_type(ai))
            if isinstance(actualt, TypedDictType):
                # A TypedDict **kwargs has a known set of keys; map each one.
                for name in actualt.items:
                    if name in formal_names:
                        formal_to_actual[formal_names.index(name)].append(ai)
                    elif nodes.ARG_STAR2 in formal_kinds:
                        formal_to_actual[formal_kinds.index(nodes.ARG_STAR2)].append(ai)
            else:
                # We don't exactly know which **kwargs are provided by the
                # caller, so we'll defer until all the other unambiguous
                # actuals have been processed
                ambiguous_actual_kwargs.append(ai)

    if ambiguous_actual_kwargs:
        # Assume the ambiguous kwargs will fill the remaining arguments.
        #
        # TODO: If there are also tuple varargs, we might be missing some potential
        #       matches if the tuple was short enough to not match everything.
        unmatched_formals = [
            fi
            for fi in range(nformals)
            if (
                formal_names[fi]
                and (
                    not formal_to_actual[fi]
                    or actual_kinds[formal_to_actual[fi][0]] == nodes.ARG_STAR
                )
                and formal_kinds[fi] != nodes.ARG_STAR
            )
            or formal_kinds[fi] == nodes.ARG_STAR2
        ]
        for ai in ambiguous_actual_kwargs:
            for fi in unmatched_formals:
                formal_to_actual[fi].append(ai)

    return formal_to_actual
def map_formals_to_actuals(
    actual_kinds: list[nodes.ArgKind],
    actual_names: Sequence[str | None] | None,
    formal_kinds: list[nodes.ArgKind],
    formal_names: list[str | None],
    actual_arg_type: Callable[[int], Type],
) -> list[list[int]]:
    """Calculate the reverse mapping of map_actuals_to_formals."""
    forward = map_actuals_to_formals(
        actual_kinds, actual_names, formal_kinds, formal_names, actual_arg_type
    )
    # Invert the mapping: record each formal index against every actual it
    # was mapped to.
    reverse: list[list[int]] = [[] for _ in actual_kinds]
    for formal_index, actual_indexes in enumerate(forward):
        for actual_index in actual_indexes:
            reverse[actual_index].append(formal_index)
    return reverse
class ArgTypeExpander:
    """Utility class for mapping actual argument types to formal arguments.

    One of the main responsibilities is to expand caller tuple *args and TypedDict
    **kwargs, and to keep track of which tuple/TypedDict items have already been
    consumed.

    Example:

        def f(x: int, *args: str) -> None: ...
        f(*(1, 'x', 1.1))

    We'd call expand_actual_type three times:

      1. The first call would provide 'int' as the actual type of 'x' (from '1').
      2. The second call would provide 'str' as one of the actual types for '*args'.
      3. The third call would provide 'float' as one of the actual types for '*args'.

    A single instance can process all the arguments for a single call. Each call
    needs a separate instance since instances have per-call state.
    """

    def __init__(self, context: ArgumentInferContext) -> None:
        # Next tuple *args index to use.
        self.tuple_index = 0
        # Keyword arguments in TypedDict **kwargs used.
        self.kwargs_used: set[str] | None = None
        # Type context for `*` and `**` arg kinds.
        self.context = context

    def expand_actual_type(
        self,
        actual_type: Type,
        actual_kind: nodes.ArgKind,
        formal_name: str | None,
        formal_kind: nodes.ArgKind,
        allow_unpack: bool = False,
    ) -> Type:
        """Return the actual (caller) type(s) of a formal argument with the given kinds.

        If the actual argument is a tuple *args, return the next individual tuple item that
        maps to the formal arg.

        If the actual argument is a TypedDict **kwargs, return the next matching typed dict
        value type based on formal argument name and kind.

        This is supposed to be called for each formal, in order. Call multiple times per
        formal if multiple actuals map to a formal.
        """
        original_actual = actual_type
        actual_type = get_proper_type(actual_type)
        if actual_kind == nodes.ARG_STAR:
            if isinstance(actual_type, TypeVarTupleType):
                # This code path is hit when *Ts is passed to a callable and various
                # special-handling didn't catch this. The best thing we can do is to use
                # the upper bound.
                actual_type = get_proper_type(actual_type.upper_bound)
            if isinstance(actual_type, Instance) and actual_type.args:
                from mypy.subtypes import is_subtype

                if is_subtype(actual_type, self.context.iterable_type):
                    # Unpacking an iterable yields its element type.
                    return map_instance_to_supertype(
                        actual_type, self.context.iterable_type.type
                    ).args[0]
                else:
                    # We cannot properly unpack anything other
                    # than `Iterable` type with `*`.
                    # Just return `Any`, other parts of code would raise
                    # a different error for improper use.
                    return AnyType(TypeOfAny.from_error)
            elif isinstance(actual_type, TupleType):
                # Get the next tuple item of a tuple *arg.
                if self.tuple_index >= len(actual_type.items):
                    # Exhausted a tuple -- continue to the next *args.
                    self.tuple_index = 1
                else:
                    self.tuple_index += 1
                item = actual_type.items[self.tuple_index - 1]
                if isinstance(item, UnpackType) and not allow_unpack:
                    # An unpack item that doesn't have special handling, use upper bound as above.
                    unpacked = get_proper_type(item.type)
                    if isinstance(unpacked, TypeVarTupleType):
                        fallback = get_proper_type(unpacked.upper_bound)
                    else:
                        fallback = unpacked
                    assert (
                        isinstance(fallback, Instance)
                        and fallback.type.fullname == "builtins.tuple"
                    )
                    item = fallback.args[0]
                return item
            elif isinstance(actual_type, ParamSpecType):
                # ParamSpec is valid in *args but it can't be unpacked.
                return actual_type
            else:
                return AnyType(TypeOfAny.from_error)
        elif actual_kind == nodes.ARG_STAR2:
            from mypy.subtypes import is_subtype

            if isinstance(actual_type, TypedDictType):
                # Lazily initialize the set of consumed TypedDict keys.
                if self.kwargs_used is None:
                    self.kwargs_used = set()
                if formal_kind != nodes.ARG_STAR2 and formal_name in actual_type.items:
                    # Lookup type based on keyword argument name.
                    assert formal_name is not None
                else:
                    # Pick an arbitrary item if no specified keyword is expected.
                    formal_name = (set(actual_type.items.keys()) - self.kwargs_used).pop()
                self.kwargs_used.add(formal_name)
                return actual_type.items[formal_name]
            elif isinstance(actual_type, Instance) and is_subtype(
                actual_type, self.context.mapping_type
            ):
                # Only `Mapping` type can be unpacked with `**`.
                # Other types will produce an error somewhere else.
                return map_instance_to_supertype(actual_type, self.context.mapping_type.type).args[
                    1
                ]
            elif isinstance(actual_type, ParamSpecType):
                # ParamSpec is valid in **kwargs but it can't be unpacked.
                return actual_type
            else:
                return AnyType(TypeOfAny.from_error)
        else:
            # No translation for other kinds -- 1:1 mapping.
            return original_actual

View file

@ -0,0 +1,709 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Literal, NamedTuple, TypeAlias as _TypeAlias
from mypy.erasetype import remove_instance_last_known_values
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys
from mypy.nodes import (
LITERAL_NO,
Expression,
IndexExpr,
MemberExpr,
NameExpr,
RefExpr,
TypeInfo,
Var,
)
from mypy.options import Options
from mypy.subtypes import is_same_type, is_subtype
from mypy.typeops import make_simplified_union
from mypy.types import (
AnyType,
Instance,
NoneType,
PartialType,
ProperType,
TupleType,
Type,
TypeOfAny,
TypeType,
TypeVarType,
UnionType,
UnpackType,
find_unpack_in_list,
flatten_nested_unions,
get_proper_type,
)
from mypy.typevars import fill_typevars_with_any
# Expression node kinds whose narrowed types the binder can track.
BindableExpression: _TypeAlias = IndexExpr | MemberExpr | NameExpr
class CurrentType(NamedTuple):
    """A narrowed type together with how it was established."""

    # The type currently known for the expression.
    type: Type
    # True if the type came from an assignment; False if it came from a
    # narrowing check (e.g. isinstance).
    from_assignment: bool
class Frame:
    """A Frame represents a specific point in the execution of a program.

    It carries information about the current types of expressions at
    that point, arising either from assignments to those expressions
    or the result of isinstance checks and other type narrowing
    operations. It also records whether it is possible to reach that
    point at all.

    We add a new frame whenever there is a new scope or control flow
    branching.

    This information is not copied into a new Frame when it is pushed
    onto the stack, so a given Frame only has information about types
    that were assigned in that frame.

    Expressions are stored in dicts using 'literal hashes' as keys (type
    "Key"). These are hashable values derived from expression AST nodes
    (only those that can be narrowed). literal_hash(expr) is used to
    calculate the hashes. Note that this isn't directly related to literal
    types -- the concept predates literal types.
    """

    def __init__(self, id: int, conditional_frame: bool = False) -> None:
        self.id = id
        # Narrowed types for tracked expressions, keyed by literal hash.
        self.types: dict[Key, CurrentType] = {}
        # True once control flow provably cannot reach the end of this frame.
        self.unreachable = False
        self.conditional_frame = conditional_frame
        # When True, unreachable-code warnings are silenced for this frame.
        self.suppress_unreachable_warnings = False

    def __repr__(self) -> str:
        return f"Frame({self.id}, {self.types}, {self.unreachable}, {self.conditional_frame})"
# Per-expression lists of (type, declared type) pairs collected while
# processing assignments -- see ConditionalTypeBinder.type_assignments.
Assigns = defaultdict[Expression, list[tuple[Type, Type | None]]]
class FrameContext:
    """Context manager pushing a Frame to ConditionalTypeBinder.

    See frame_context() below for documentation on parameters. We use this class
    instead of @contextmanager as a mypyc-specific performance optimization.
    """

    def __init__(
        self,
        binder: ConditionalTypeBinder,
        can_skip: bool,
        fall_through: int,
        break_frame: int,
        continue_frame: int,
        conditional_frame: bool,
        try_frame: bool,
        discard: bool,
    ) -> None:
        self.binder = binder
        self.can_skip = can_skip
        self.fall_through = fall_through
        self.break_frame = break_frame
        self.continue_frame = continue_frame
        self.conditional_frame = conditional_frame
        self.try_frame = try_frame
        self.discard = discard

    def __enter__(self) -> Frame:
        assert len(self.binder.frames) > 1
        # break_frame/continue_frame are relative offsets from the top of the
        # stack; 0 (falsy) means "no break/continue target here".
        if self.break_frame:
            self.binder.break_frames.append(len(self.binder.frames) - self.break_frame)
        if self.continue_frame:
            self.binder.continue_frames.append(len(self.binder.frames) - self.continue_frame)
        if self.try_frame:
            self.binder.try_frames.add(len(self.binder.frames) - 1)
        new_frame = self.binder.push_frame(self.conditional_frame)
        if self.try_frame:
            # An exception may occur immediately
            self.binder.allow_jump(-1)
        return new_frame

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]:
        self.binder.pop_frame(self.can_skip, self.fall_through, discard=self.discard)
        # Undo the bookkeeping done in __enter__, in matching order.
        if self.break_frame:
            self.binder.break_frames.pop()
        if self.continue_frame:
            self.binder.continue_frames.pop()
        if self.try_frame:
            self.binder.try_frames.remove(len(self.binder.frames) - 1)
        return False
class ConditionalTypeBinder:
"""Keep track of conditional types of variables.
NB: Variables are tracked by literal hashes of expressions, so it is
possible to confuse the binder when there is aliasing. Example:
class A:
a: int | str
x = A()
lst = [x]
reveal_type(x.a) # int | str
x.a = 1
reveal_type(x.a) # int
reveal_type(lst[0].a) # int | str
lst[0].a = 'a'
reveal_type(x.a) # int
reveal_type(lst[0].a) # str
"""
# Stored assignments for situations with tuple/list lvalue and rvalue of union type.
# This maps an expression to a list of bound types for every item in the union type.
type_assignments: Assigns | None = None
    def __init__(self, options: Options) -> None:
        """Create an empty binder with a single root frame."""
        # Each frame gets an increasing, distinct id.
        self.next_id = 1

        # The stack of frames currently used. These map
        # literal_hash(expr) -- literals like 'foo.bar' --
        # to types. The last element of this list is the
        # top-most, current frame. Each earlier element
        # records the state as of when that frame was last
        # on top of the stack.
        self.frames = [Frame(self._get_id())]

        # For frames higher in the stack, we record the set of
        # Frames that can escape there, either by falling off
        # the end of the frame or by a loop control construct
        # or raised exception. The last element of self.frames
        # has no corresponding element in this list.
        self.options_on_return: list[list[Frame]] = []

        # Maps literal_hash(expr) to get_declaration(expr)
        # for every expr stored in the binder
        self.declarations: dict[Key, Type | None] = {}

        # Set of other keys to invalidate if a key is changed, e.g. x -> {x.a, x[0]}
        # Whenever a new key (e.g. x.a.b) is added, we update this
        self.dependencies: dict[Key, set[Key]] = {}

        # Whether the last pop changed the newly top frame on exit
        self.last_pop_changed = False

        # These are used to track control flow in try statements and loops.
        self.try_frames: set[int] = set()
        self.break_frames: list[int] = []
        self.continue_frames: list[int] = []

        # If True, initial assignment to a simple variable (e.g. "x", but not "x.y")
        # is added to the binder. This allows more precise narrowing and more
        # flexible inference of variable types (--allow-redefinition-new).
        self.bind_all = options.allow_redefinition_new

        # This tracks any externally visible changes in binder to invalidate
        # expression caches when needed.
        self.version = 0
def _get_id(self) -> int:
self.next_id += 1
return self.next_id
def _add_dependencies(self, key: Key, value: Key | None = None) -> None:
if value is None:
value = key
else:
self.dependencies.setdefault(key, set()).add(value)
for elt in subkeys(key):
self._add_dependencies(elt, value)
def push_frame(self, conditional_frame: bool = False) -> Frame:
"""Push a new frame into the binder."""
f = Frame(self._get_id(), conditional_frame)
self.frames.append(f)
self.options_on_return.append([])
return f
def _put(self, key: Key, type: Type, from_assignment: bool, index: int = -1) -> None:
self.version += 1
self.frames[index].types[key] = CurrentType(type, from_assignment)
def _get(self, key: Key, index: int = -1) -> CurrentType | None:
if index < 0:
index += len(self.frames)
for i in range(index, -1, -1):
if key in self.frames[i].types:
return self.frames[i].types[key]
return None
@classmethod
def can_put_directly(cls, expr: Expression) -> bool:
"""Will `.put()` on this expression be successful?
This is inlined in `.put()` because the logic is rather hot and must be kept
in sync.
"""
return isinstance(expr, (IndexExpr, MemberExpr, NameExpr)) and literal(expr) > LITERAL_NO
    def put(self, expr: Expression, typ: Type, *, from_assignment: bool = True) -> None:
        """Directly set the narrowed type of expression (if it supports it).

        This is used for isinstance() etc. Assignments should go through assign_type().
        """
        if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
            return
        if not literal(expr):
            return
        key = literal_hash(expr)
        assert key is not None, "Internal error: binder tried to put non-literal"
        if key not in self.declarations:
            # First time this key is seen: remember its declared type and
            # register invalidation dependencies for its subexpressions.
            self.declarations[key] = get_declaration(expr)
            self._add_dependencies(key)
        self._put(key, typ, from_assignment)
def unreachable(self) -> None:
self.version += 1
self.frames[-1].unreachable = True
    def suppress_unreachable_warnings(self) -> None:
        # Silence "statement is unreachable" diagnostics within the current frame.
        self.frames[-1].suppress_unreachable_warnings = True
def get(self, expr: Expression) -> Type | None:
key = literal_hash(expr)
assert key is not None, "Internal error: binder tried to get non-literal"
found = self._get(key)
if found is None:
return None
return found.type
def is_unreachable(self) -> bool:
# TODO: Copy the value of unreachable into new frames to avoid
# this traversal on every statement?
return any(f.unreachable for f in self.frames)
def is_unreachable_warning_suppressed(self) -> bool:
return any(f.suppress_unreachable_warnings for f in self.frames)
    def cleanse(self, expr: Expression) -> None:
        """Remove all references to a Node from the binder."""
        key = literal_hash(expr)
        assert key is not None, "Internal error: binder tried cleanse non-literal"
        # Delegate to the key-based helper, which scans every frame.
        self._cleanse_key(key)
def _cleanse_key(self, key: Key) -> None:
"""Remove all references to a key from the binder."""
for frame in self.frames:
if key in frame.types:
del frame.types[key]
def update_from_options(self, frames: list[Frame]) -> bool:
    """Update the frame to reflect that each key will be updated
    as in one of the frames. Return whether any item changes.

    If a key is declared as AnyType, only update it if all the
    options are the same.
    """
    all_reachable = all(not f.unreachable for f in frames)
    if not all_reachable:
        # Unreachable frames cannot contribute narrowed types; drop them.
        frames = [f for f in frames if not f.unreachable]
    changed = False
    keys = [key for f in frames for key in f.types]
    if len(keys) > 1:
        # Deduplicate keys only when there could actually be duplicates.
        keys = list(set(keys))
    for key in keys:
        current_value = self._get(key)
        # Per frame: what this key becomes if that frame was taken
        # (falling back to the current value when the frame is silent).
        resulting_values = [f.types.get(key, current_value) for f in frames]
        # Keys can be narrowed using two different semantics. The new semantics
        # is enabled for inferred variables when bind_all is true, and it allows
        # variable types to be widened using subsequent assignments. This is
        # not allowed for instance attributes and annotated variables.
        var = extract_var_from_literal_hash(key)
        old_semantics = (
            not self.bind_all or var is None or not var.is_inferred and not var.is_argument
        )
        if old_semantics and any(x is None for x in resulting_values):
            # We didn't know anything about key before
            # (current_value must be None), and we still don't
            # know anything about key in at least one possible frame.
            continue
        resulting_values = [x for x in resulting_values if x is not None]
        if all_reachable and all(not x.from_assignment for x in resulting_values):
            # Do not synthesize a new type if we encountered a conditional block
            # (if, while or match-case) without assignments.
            # See check-isinstance.test::testNoneCheckDoesNotMakeTypeVarOptional
            # This is a safe assumption: the fact that we checked something with `is`
            # or `isinstance` does not change the type of the value.
            continue
        # Remove exact duplicates to save pointless work later, this is
        # a micro-optimization for --allow-redefinition-new.
        seen_types = set()
        resulting_types = []
        for rv in resulting_values:
            assert rv is not None
            if rv.type in seen_types:
                continue
            resulting_types.append(rv.type)
            seen_types.add(rv.type)
        # NOTE: `type` deliberately shadows the builtin here (pre-existing style).
        type = resulting_types[0]
        declaration_type = get_proper_type(self.declarations.get(key))
        if isinstance(declaration_type, AnyType):
            # At this point resulting values can't contain None, see continue above
            if not all(is_same_type(type, t) for t in resulting_types[1:]):
                type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type)
        else:
            possible_types = []
            for t in resulting_types:
                assert t is not None
                possible_types.append(t)
            if len(possible_types) == 1:
                # This is to avoid calling get_proper_type() unless needed, as this may
                # interfere with our (hacky) TypeGuard support.
                type = possible_types[0]
            else:
                type = make_simplified_union(possible_types)
            # Legacy guard for corner case when the original type is TypeVarType.
            if isinstance(declaration_type, TypeVarType) and not is_subtype(
                type, declaration_type
            ):
                type = declaration_type
            # Try simplifying resulting type for unions involving variadic tuples.
            # Technically, everything is still valid without this step, but if we do
            # not do this, this may create long unions after exiting an if check like:
            #     x: tuple[int, ...]
            #     if len(x) < 10:
            #         ...
            # We want the type of x to be tuple[int, ...] after this block (if it is
            # still equivalent to such type).
            if isinstance(type, UnionType):
                type = collapse_variadic_union(type)
            if (
                old_semantics
                and isinstance(type, ProperType)
                and isinstance(type, UnionType)
            ):
                # Simplify away any extra Any's that were added to the declared
                # type when popping a frame.
                simplified = UnionType.make_union(
                    [t for t in type.items if not isinstance(get_proper_type(t), AnyType)]
                )
                if simplified == self.declarations[key]:
                    type = simplified
        if (
            current_value is None
            or not is_same_type(type, current_value.type)
            # Manually carry over any narrowing from hasattr() from inner frames. This is
            # a bit ad-hoc, but our handling of hasattr() is on best effort basis anyway.
            or isinstance(p_type := get_proper_type(type), Instance)
            and p_type.extra_attrs
        ):
            self._put(key, type, from_assignment=True)
            if current_value is not None or extract_var_from_literal_hash(key) is None:
                # We definitely learned something new
                changed = True
            elif not changed:
                # If there is no current value compare with the declaration. This prevents
                # reporting false changes in cases like this:
                #     x: int
                #     if foo():
                #         x = 1
                #     else:
                #         x = 2
                # We check partial types and widening in accept_loop() separately, so
                # this should be safe.
                changed = declaration_type is not None and not is_same_type(
                    type, declaration_type
                )
    # If every option was unreachable, the merged frame is unreachable too.
    self.frames[-1].unreachable = not frames
    return changed
def pop_frame(self, can_skip: bool, fall_through: int, *, discard: bool = False) -> Frame:
    """Pop a frame and return it.

    See frame_context() for documentation of fall_through and discard.
    """
    if fall_through > 0:
        self.allow_jump(-fall_through)
    popped = self.frames.pop()
    exit_options = self.options_on_return.pop()
    if discard:
        # Throw-away frame: leave the parent frame untouched.
        self.last_pop_changed = False
    else:
        if can_skip:
            # The frame could have been skipped entirely, so the state at
            # entry is one of the possible outcomes.
            exit_options.insert(0, self.frames[-1])
        self.last_pop_changed = self.update_from_options(exit_options)
    return popped
@contextmanager
def accumulate_type_assignments(self) -> Iterator[Assigns]:
    """Push a new map to collect assigned types in multiassign from union.

    If this map is not None, actual binding is deferred until all items in
    the union are processed (a union of collected items is later bound
    manually by the caller).
    """
    # Save whatever map is currently active (None when we are not nested).
    old_assignments = self.type_assignments
    self.type_assignments = defaultdict(list)
    try:
        yield self.type_assignments
    finally:
        # Restore even if the body raises; otherwise the binder would keep
        # deferring assignments into a stale map for the rest of the run.
        self.type_assignments = old_assignments
def assign_type(self, expr: Expression, type: Type, declared_type: Type | None) -> None:
    """Narrow type of expression through an assignment.

    Do nothing if the expression doesn't support narrowing.

    When not narrowing though an assignment (isinstance() etc.), use put()
    directly. This omits some special-casing logic for assignments.
    """
    # We should erase last known value in binder, because if we are using it,
    # it means that the target is not final, and therefore can't hold a literal.
    type = remove_instance_last_known_values(type)
    if self.type_assignments is not None:
        # We are in a multiassign from union, defer the actual binding,
        # just collect the types.
        self.type_assignments[expr].append((type, declared_type))
        return
    if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
        return
    if not literal(expr):
        return
    self.invalidate_dependencies(expr)
    if declared_type is None:
        # Not sure why this happens. It seems to mainly happen in
        # member initialization.
        return
    if not is_subtype(type, declared_type):
        # Pretty sure this only happens when there's a type error.
        # Ideally this function wouldn't be called if the
        # expression has a type error, though -- do other kinds of
        # errors cause this function to get called at invalid
        # times?
        return
    p_declared = get_proper_type(declared_type)
    p_type = get_proper_type(type)
    if isinstance(p_type, AnyType):
        # Any type requires some special casing, for both historical reasons,
        # and to optimise user experience without sacrificing correctness too much.
        if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.is_inferred:
            # First case: a local/global variable without explicit annotation,
            # in this case we just assign Any (essentially following the SSA logic).
            self.put(expr, type)
        elif isinstance(p_declared, UnionType):
            all_items = flatten_nested_unions(p_declared.items)
            if any(isinstance(get_proper_type(item), NoneType) for item in all_items):
                # Second case: explicit optional type, in this case we optimize for
                # a common pattern when an untyped value used as a fallback replacing None.
                new_items = [
                    type if isinstance(get_proper_type(item), NoneType) else item
                    for item in all_items
                ]
                self.put(expr, UnionType(new_items))
            elif any(isinstance(get_proper_type(item), AnyType) for item in all_items):
                # Third case: a union already containing Any (most likely from
                # an un-imported name), in this case we allow assigning Any as well.
                self.put(expr, type)
            else:
                # In all other cases we don't narrow to Any to minimize false negatives.
                self.put(expr, declared_type)
        else:
            self.put(expr, declared_type)
    elif isinstance(p_declared, AnyType):
        # Mirroring the first case above, we don't narrow to a precise type if the variable
        # has an explicit `Any` type annotation.
        if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.is_inferred:
            self.put(expr, type)
        else:
            self.put(expr, declared_type)
    else:
        # Common case: neither side is Any, just bind the assigned type.
        self.put(expr, type)
    for i in self.try_frames:
        # XXX This should probably not copy the entire frame, but
        # just copy this variable into a single stored frame.
        self.allow_jump(i)
def invalidate_dependencies(self, expr: BindableExpression) -> None:
    """Invalidate knowledge of types that include expr, but not expr itself.

    For example, when expr is foo.bar, invalidate foo.bar.baz.

    It is overly conservative: it invalidates globally, including
    in code paths unreachable from here.
    """
    hashed = literal_hash(expr)
    assert hashed is not None
    dependents = self.dependencies.get(hashed, set())
    for dep in dependents:
        self._cleanse_key(dep)
def allow_jump(self, index: int) -> None:
    """Record that control flow may jump out to options_on_return[index]."""
    # self.frames and self.options_on_return have different lengths,
    # so make sure the index is non-negative before slicing.
    if index < 0:
        index += len(self.options_on_return)
    snapshot = Frame(self._get_id())
    # Collapse everything above the jump target into a single frame.
    for inner in self.frames[index + 1 :]:
        snapshot.types.update(inner.types)
        if inner.unreachable:
            snapshot.unreachable = True
    self.options_on_return[index].append(snapshot)
def handle_break(self) -> None:
    """Register a 'break': jump to the innermost break target, then mark dead."""
    target = self.break_frames[-1]
    self.allow_jump(target)
    self.unreachable()
def handle_continue(self) -> None:
    """Register a 'continue': jump to the innermost continue target, then mark dead."""
    target = self.continue_frames[-1]
    self.allow_jump(target)
    self.unreachable()
def frame_context(
    self,
    *,
    can_skip: bool,
    fall_through: int = 1,
    break_frame: int = 0,
    continue_frame: int = 0,
    conditional_frame: bool = False,
    try_frame: bool = False,
    discard: bool = False,
) -> FrameContext:
    """Return a context manager that pushes/pops frames on enter/exit.

    can_skip: when True, control flow is allowed to bypass the
    newly-created frame.

    fall_through: when > 0, control flow that falls off the end of the
    frame escapes to its ancestor `fall_through` levels higher; otherwise
    control flow ends at the end of the frame.

    break_frame / continue_frame: 'break' (resp. 'continue') statements
    within this frame jump out to the frame that many levels higher than
    the frame created by this call.

    try_frame: when true, execution is allowed to jump at any point within
    the newly created frame (or its descendants) to its parent, i.e. the
    frame that was on top before this call.

    discard: when True this is a temporary throw-away frame (used e.g. for
    isolation) and its effect is discarded on pop.

    After the context manager exits, self.last_pop_changed indicates
    whether any types changed in the newly-topmost frame as a result of
    popping this frame.
    """
    ctx = FrameContext(
        self,
        can_skip=can_skip,
        fall_through=fall_through,
        break_frame=break_frame,
        continue_frame=continue_frame,
        conditional_frame=conditional_frame,
        try_frame=try_frame,
        discard=discard,
    )
    return ctx
@contextmanager
def top_frame_context(self) -> Iterator[Frame]:
    """A variant of frame_context for use at the top level of
    a namespace (module, function, or class).
    """
    assert len(self.frames) == 1
    top = self.push_frame()
    yield top
    self.pop_frame(True, 0)
    assert len(self.frames) == 1
def get_declaration(expr: BindableExpression) -> Type | None:
    """Get the declared or inferred type of a RefExpr expression.

    Return None if there is no type or the expression is not a RefExpr.
    This can return None if the type hasn't been inferred yet.
    """
    if not isinstance(expr, RefExpr):
        return None
    node = expr.node
    if isinstance(node, Var):
        declared = node.type
        # A PartialType means inference is still in progress.
        if isinstance(get_proper_type(declared), PartialType):
            return None
        return declared
    if isinstance(node, TypeInfo):
        return TypeType(fill_typevars_with_any(node))
    return None
def collapse_variadic_union(typ: UnionType) -> Type:
    """Simplify a union involving variadic tuple if possible.

    This will collapse a type like e.g.
        tuple[X, Z] | tuple[X, Y, Z] | tuple[X, Y, Y, *tuple[Y, ...], Z]
    back to
        tuple[X, *tuple[Y, ...], Z]
    which is equivalent, but much simpler form of the same type.
    """
    # Split the union into tuple members and everything else.
    tuple_items = []
    other_items = []
    for t in typ.items:
        p_t = get_proper_type(t)
        if isinstance(p_t, TupleType):
            tuple_items.append(p_t)
        else:
            other_items.append(t)
    if len(tuple_items) <= 1:
        # This type cannot be simplified further.
        return typ
    # Order by arity; the shortest gives the prefix/suffix, the longest
    # must be the one carrying the unpack.
    tuple_items = sorted(tuple_items, key=lambda t: len(t.items))
    first = tuple_items[0]
    last = tuple_items[-1]
    unpack_index = find_unpack_in_list(last.items)
    if unpack_index is None:
        return typ
    unpack = last.items[unpack_index]
    assert isinstance(unpack, UnpackType)
    unpacked = get_proper_type(unpack.type)
    if not isinstance(unpacked, Instance):
        return typ
    assert unpacked.type.fullname == "builtins.tuple"
    suffix = last.items[unpack_index + 1 :]
    # Check that first item matches the expected pattern and infer prefix.
    if len(first.items) < len(suffix):
        return typ
    if suffix and first.items[-len(suffix) :] != suffix:
        return typ
    if suffix:
        prefix = first.items[: -len(suffix)]
    else:
        prefix = first.items
    # Check that all middle types match the expected pattern as well:
    # each adds exactly one more repetition of the unpacked item type.
    arg = unpacked.args[0]
    for i, it in enumerate(tuple_items[1:-1]):
        if it.items != prefix + [arg] * (i + 1) + suffix:
            return typ
    # Check the last item (the one with unpack), and choose an appropriate simplified type.
    if last.items != prefix + [arg] * (len(typ.items) - 1) + [unpack] + suffix:
        return typ
    if len(first.items) == 0:
        # No fixed prefix/suffix: the whole thing is just tuple[Y, ...].
        simplified: Type = unpacked.copy_modified()
    else:
        simplified = TupleType(prefix + [unpack] + suffix, fallback=last.partial_fallback)
    return UnionType.make_union([simplified] + other_items)

View file

@ -0,0 +1,27 @@
"""A Bogus[T] type alias for marking when we subvert the type system
We need this for compiling with mypyc, which inserts runtime
typechecks that cause problems when we subvert the type system. So
when compiling with mypyc, we turn those places into Any, while
keeping the types around for normal typechecks.
Since this causes the runtime types to be Any, this is best used
in places where efficient access to properties is not important.
For those cases some other technique should be used.
"""
from __future__ import annotations
from typing import Any, TypeVar
from mypy_extensions import FlexibleAlias
T = TypeVar("T")
# This won't ever be true at runtime, but we consider it true during
# mypyc compilations.
MYPYC = False
if MYPYC:
Bogus = FlexibleAlias[T, Any]
else:
Bogus = FlexibleAlias[T, T]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,6 @@
from __future__ import annotations

from mypy.build_worker.worker import console_entry

if __name__ == "__main__":
    # Allow launching the worker with `python -m mypy.build_worker`.
    console_entry()

View file

@ -0,0 +1,249 @@
"""
Mypy parallel build worker.
The protocol of communication with the coordinator is as follows:
* Read (pickled) build options from command line.
* Populate status file with pid and socket address.
* Receive build sources from coordinator.
* Load graph using the sources, and send ack to coordinator.
* Receive SCC structure from coordinator, and ack it.
* Receive an SCC id from coordinator, process it, and send back the results.
* When prompted by coordinator (with a scc_id=None message), cleanup and shutdown.
"""
from __future__ import annotations
import argparse
import gc
import json
import os
import pickle
import platform
import sys
import time
from typing import NamedTuple
from librt.base64 import b64decode
from mypy import util
from mypy.build import (
SCC,
AckMessage,
BuildManager,
Graph,
GraphMessage,
SccRequestMessage,
SccResponseMessage,
SccsDataMessage,
SourcesDataMessage,
load_plugins,
process_stale_scc,
)
from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT
from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error
from mypy.fscache import FileSystemCache
from mypy.ipc import IPCException, IPCServer, receive, send
from mypy.modulefinder import BuildSource, BuildSourceSet, compute_search_paths
from mypy.nodes import FileRawData
from mypy.options import Options
from mypy.util import read_py_file
from mypy.version import __version__
# Command line interface the coordinator uses when spawning workers.
parser = argparse.ArgumentParser(prog="mypy_worker", description="Mypy build worker")
parser.add_argument("--status-file", help="status file to communicate worker details")
parser.add_argument("--options-data", help="serialized mypy options")

# Base name for the IPC connection between the coordinator and this worker.
CONNECTION_NAME = "build_worker"
class ServerContext(NamedTuple):
    """Per-process state shared by the worker's serve loop."""

    options: Options  # build options; error codes applied later, after plugin load
    disable_error_code: list[str]  # raw disabled error codes, applied after plugins load
    enable_error_code: list[str]  # raw enabled error codes, applied after plugins load
    errors: Errors  # shared error reporter
    fscache: FileSystemCache  # shared filesystem cache for source reads
def main(argv: list[str]) -> None:
    """Worker process entry point.

    Parses command line arguments, writes the status file so the coordinator
    can find us, then serves build requests until the coordinator shuts us down.
    """
    # Set recursion limit and GC thresholds consistent with mypy/main.py
    sys.setrecursionlimit(RECURSION_LIMIT)
    if platform.python_implementation() == "CPython":
        gc.set_threshold(200 * 1000, 30, 30)
    args = parser.parse_args(argv)
    # This mimics how daemon receives the options. Note we need to postpone
    # processing error codes after plugins are loaded, because plugins can add
    # custom error codes.
    # NOTE: pickle is only safe here because the data comes from our own
    # coordinator process, never from untrusted input.
    options_dict = pickle.loads(b64decode(args.options_data))
    options_obj = Options()
    disable_error_code = options_dict.pop("disable_error_code", [])
    enable_error_code = options_dict.pop("enable_error_code", [])
    options = options_obj.apply_changes(options_dict)
    status_file = args.status_file
    server = IPCServer(CONNECTION_NAME, WORKER_CONNECTION_TIMEOUT)
    try:
        # The coordinator polls this file to learn our pid and socket address.
        with open(status_file, "w") as f:
            json.dump({"pid": os.getpid(), "connection_name": server.connection_name}, f)
            f.write("\n")
    except Exception as exc:
        print(f"Error writing status file {status_file}:", exc)
        raise
    fscache = FileSystemCache()
    cached_read = fscache.read
    errors = Errors(options, read_source=lambda path: read_py_file(path, cached_read))
    ctx = ServerContext(options, disable_error_code, enable_error_code, errors, fscache)
    try:
        with server:
            serve(server, ctx)
    except (OSError, IPCException) as exc:
        # Communication failures are expected when the coordinator dies first.
        if options.verbosity >= 1:
            print("Error communicating with coordinator:", exc)
    except Exception as exc:
        report_internal_error(exc, errors.file, 0, errors, options)
    finally:
        server.cleanup()
    if options.fast_exit:
        # Exit fast if allowed, since coordinator is waiting on us.
        util.hard_exit(0)
def serve(server: IPCServer, ctx: ServerContext) -> None:
    """Main server loop of the worker.

    Receive initial state from the coordinator, then process each
    SCC checking request and reply to client (coordinator). See module
    docstring for more details on the protocol.
    """
    sources = SourcesDataMessage.read(receive(server)).sources
    manager = setup_worker_manager(sources, ctx)
    if manager is None:
        # Plugin loading failed; the coordinator reports the error itself.
        return
    # Notify coordinator we are done with setup.
    send(server, AckMessage())
    graph_data = GraphMessage.read(receive(server), manager)
    # Update some manager data in-place as it has been passed to semantic analyzer.
    manager.missing_modules |= graph_data.missing_modules
    graph = graph_data.graph
    for id in graph:
        manager.import_map[id] = graph[id].dependencies_set
    # Link modules dicts, so that plugins will get access to ASTs as we parse them.
    manager.plugin.set_modules(manager.modules)
    # Notify coordinator we are ready to receive computed graph SCC structure.
    send(server, AckMessage())
    sccs = SccsDataMessage.read(receive(server)).sccs
    manager.scc_by_id = {scc.id: scc for scc in sccs}
    manager.top_order = [scc.id for scc in sccs]
    # Notify coordinator we are ready to start processing SCCs.
    send(server, AckMessage())
    while True:
        scc_message = SccRequestMessage.read(receive(server))
        scc_id = scc_message.scc_id
        if scc_id is None:
            # scc_id=None is the coordinator's shutdown signal.
            manager.dump_stats()
            break
        scc = manager.scc_by_id[scc_id]
        t0 = time.time()
        try:
            if platform.python_implementation() == "CPython":
                # Since we are splitting the GC freeze hack into multiple smaller freezes,
                # we should collect young generations to not accumulate accidental garbage.
                gc.collect(generation=1)
                gc.collect(generation=0)
                gc.disable()
            load_states(scc, graph, manager, scc_message.import_errors, scc_message.mod_data)
            if platform.python_implementation() == "CPython":
                gc.freeze()
                gc.unfreeze()
                gc.enable()
            result = process_stale_scc(graph, scc, manager, from_cache=graph_data.from_cache)
            # We must commit after each SCC, otherwise we break --sqlite-cache.
            manager.metastore.commit()
        except CompileError as blocker:
            send(server, SccResponseMessage(scc_id=scc_id, blocker=blocker))
        else:
            send(server, SccResponseMessage(scc_id=scc_id, result=result))
        manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1)
def load_states(
    scc: SCC,
    graph: Graph,
    manager: BuildManager,
    import_errors: dict[str, list[ErrorInfo]],
    mod_data: dict[str, tuple[bytes, FileRawData | None]],
) -> None:
    """Re-create full state of an SCC as it would have been in coordinator."""
    for id in scc.mod_ids:
        state = graph[id]
        # Re-clone options since we don't send them, it is usually faster than deserializing.
        state.options = state.options.clone_for_module(state.id)
        suppressed_deps_opts, raw_data = mod_data[id]
        state.parse_file(raw_data=raw_data)
        # Set data that is needed to be written to cache meta.
        state.known_suppressed_deps_opts = suppressed_deps_opts
        assert state.tree is not None
        # Remember which ignore comments sit on import lines (needed for cache meta).
        import_lines = {imp.line for imp in state.tree.imports}
        state.imports_ignored = {
            line: codes for line, codes in state.tree.ignored_lines.items() if line in import_lines
        }
        # Replay original errors encountered during graph loading in coordinator.
        if id in import_errors:
            manager.errors.set_file(state.xpath, id, state.options)
            for err_info in import_errors[id]:
                manager.errors.add_error_info(err_info)
def setup_worker_manager(sources: list[BuildSource], ctx: ServerContext) -> BuildManager | None:
    """Create a BuildManager mirroring the coordinator's configuration.

    Returns None when plugins fail to load (the coordinator reports the
    corresponding CompileError itself).
    """
    data_dir = os.path.dirname(os.path.dirname(__file__))
    # This is used for testing only now.
    alt_lib_path = os.environ.get("MYPY_ALT_LIB_PATH")
    search_paths = compute_search_paths(sources, ctx.options, data_dir, alt_lib_path)
    source_set = BuildSourceSet(sources)
    try:
        plugin, snapshot = load_plugins(ctx.options, ctx.errors, sys.stdout, [])
    except CompileError:
        # CompileError while importing plugins will be reported by the coordinator.
        return None
    # Process the rest of the options when plugins are loaded.
    options = ctx.options
    options.disable_error_code = ctx.disable_error_code
    options.enable_error_code = ctx.enable_error_code
    options.process_error_codes(error_callback=lambda msg: None)

    def flush_errors(filename: str | None, new_messages: list[str], is_serious: bool) -> None:
        # We never flush errors in the worker, we send them back to coordinator.
        pass

    return BuildManager(
        data_dir,
        search_paths,
        ignore_prefix=os.getcwd(),
        source_set=source_set,
        reports=None,
        options=options,
        version_id=__version__,
        plugin=plugin,
        plugins_snapshot=snapshot,
        errors=ctx.errors,
        error_formatter=None,
        flush_errors=flush_errors,
        fscache=ctx.fscache,
        stdout=sys.stdout,
        stderr=sys.stderr,
        parallel_worker=True,
    )
def console_entry() -> None:
    """Entry point for the mypy_worker console script."""
    argv = sys.argv[1:]
    main(argv)

View file

@ -0,0 +1,532 @@
"""
This module contains high-level logic for fixed format serialization.
Lower-level parts are implemented in C in mypyc/lib-rt/internal/librt_internal.c
Short summary of low-level functionality:
* integers are automatically serialized as 1, 2, or 4 bytes, or arbitrary length.
* str/bytes are serialized as size (1, 2, or 4 bytes) followed by bytes buffer.
* floats are serialized as C doubles.
At high-level we add type tags as needed so that our format is self-descriptive.
More precisely:
* False, True, and None are stored as just a tag: 0, 1, 2 correspondingly.
* builtin primitives like int/str/bytes/float are stored as their type tag followed
by bare (low-level) representation of the value. Reserved tag range for primitives is
3 ... 19.
* generic (heterogeneous) list are stored as tag, followed by bare size, followed by
sequence of tagged values.
* homogeneous lists of primitives are stored as tag, followed by bare size, followed
by sequence of bare values.
* reserved tag range for sequence-like builtins is 20 ... 29
* currently we have only one mapping-like format: string-keyed dictionary with heterogeneous
values. It is stored as tag, followed by bare size, followed by sequence of pairs: bare
string key followed by tagged value.
* reserved tag range for mapping-like builtins is 30 ... 39
* there is an additional reserved tag range 40 ... 49 for any other builtin collections.
* custom classes (like types, symbols etc.) are stored as tag, followed by a sequence of
tagged field values, followed by a special end tag 255. Names of class fields are
*not* stored, the caller should know the field names and order for the given class tag.
* reserved tag range for symbols (TypeInfo, Var, etc) is 50 ... 79.
* class Instance is the only exception from the above format (since it is the most common one).
It has two extra formats: few most common instances like "builtins.object" are stored as
instance tag followed by a secondary tag, other plain non-generic instances are stored as
instance tag followed by secondary tag followed by fullname as bare string. All generic
readers must handle these.
* reserved tag range for Instance type formats is 80 ... 99, for other types it is 100 ... 149.
* tag 254 is reserved in case we ever need to extend the tag range by indicating a second tag
page. Tags 150 ... 253 are free for everything else (e.g. AST nodes etc).
General convention is that custom classes implement write() and read() methods for FF
serialization. The write method should write both class tag and end tag. The read method
conventionally *does not* read the start tag (to simplify logic for unions). Known exceptions
are MypyFile.read() and SymbolTableNode.read(), since those two never appear in a union.
If any of these details change, or if the structure of CacheMeta changes please
bump CACHE_VERSION below.
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Final, TypeAlias as _TypeAlias
from librt.internal import (
ReadBuffer as ReadBuffer,
WriteBuffer as WriteBuffer,
read_bool as read_bool,
read_bytes as read_bytes_bare,
read_float as read_float_bare,
read_int as read_int_bare,
read_str as read_str_bare,
read_tag as read_tag,
write_bool as write_bool,
write_bytes as write_bytes_bare,
write_float as write_float_bare,
write_int as write_int_bare,
write_str as write_str_bare,
write_tag as write_tag,
)
from mypy_extensions import u8
# High-level cache layout format
CACHE_VERSION: Final = 7
# Type used internally to represent errors:
# (path, line, column, end_line, end_column, severity, message, code)
ErrorTuple: _TypeAlias = tuple[str | None, int, int, int, int, str, str, str | None]
class CacheMeta:
    """Class representing cache metadata for a module."""

    def __init__(
        self,
        *,
        id: str,
        path: str,
        mtime: int,
        size: int,
        hash: str,
        dependencies: list[str],
        data_mtime: int,
        data_file: str,
        suppressed: list[str],
        imports_ignored: dict[int, list[str]],
        options: dict[str, object],
        suppressed_deps_opts: bytes,
        dep_prios: list[int],
        dep_lines: list[int],
        dep_hashes: list[bytes],
        interface_hash: bytes,
        trans_dep_hash: bytes,
        version_id: str,
        ignore_all: bool,
        plugin_data: Any,
    ) -> None:
        self.id = id
        self.path = path
        self.mtime = mtime  # source file mtime
        self.size = size  # source file size
        self.hash = hash  # source file hash (as a hex string for historical reasons)
        self.dependencies = dependencies  # names of imported modules
        self.data_mtime = data_mtime  # mtime of data_file
        self.data_file = data_file  # path of <id>.data.json or <id>.data.ff
        self.suppressed = suppressed  # dependencies that weren't imported
        self.imports_ignored = imports_ignored  # type ignore codes by line
        self.options = options  # build options snapshot
        self.suppressed_deps_opts = suppressed_deps_opts  # hash of import-related options
        # dep_prios and dep_lines are both aligned with dependencies + suppressed
        self.dep_prios = dep_prios
        self.dep_lines = dep_lines
        # dep_hashes list is aligned with dependencies only
        self.dep_hashes = dep_hashes  # list of interface_hash for dependencies
        self.interface_hash = interface_hash  # hash representing the public interface
        self.trans_dep_hash = trans_dep_hash  # hash of import structure (transitive)
        self.version_id = version_id  # mypy version for cache invalidation
        self.ignore_all = ignore_all  # if errors were ignored
        self.plugin_data = plugin_data  # config data from plugins

    def serialize(self) -> dict[str, Any]:
        """Convert to a JSON-compatible dict; bytes fields become hex strings."""
        return {
            "id": self.id,
            "path": self.path,
            "mtime": self.mtime,
            "size": self.size,
            "hash": self.hash,
            "data_mtime": self.data_mtime,
            "dependencies": self.dependencies,
            "suppressed": self.suppressed,
            "imports_ignored": {str(line): codes for line, codes in self.imports_ignored.items()},
            "options": self.options,
            "suppressed_deps_opts": self.suppressed_deps_opts.hex(),
            "dep_prios": self.dep_prios,
            "dep_lines": self.dep_lines,
            "dep_hashes": [dep.hex() for dep in self.dep_hashes],
            "interface_hash": self.interface_hash.hex(),
            "trans_dep_hash": self.trans_dep_hash.hex(),
            "version_id": self.version_id,
            "ignore_all": self.ignore_all,
            "plugin_data": self.plugin_data,
        }

    @classmethod
    def deserialize(cls, meta: dict[str, Any], data_file: str) -> CacheMeta | None:
        """Inverse of serialize(). Returns None if the dict is missing keys
        or contains malformed hex values (treated as a cache miss)."""
        try:
            return CacheMeta(
                id=meta["id"],
                path=meta["path"],
                mtime=meta["mtime"],
                size=meta["size"],
                hash=meta["hash"],
                dependencies=meta["dependencies"],
                data_mtime=meta["data_mtime"],
                data_file=data_file,
                suppressed=meta["suppressed"],
                imports_ignored={
                    int(line): codes for line, codes in meta["imports_ignored"].items()
                },
                options=meta["options"],
                suppressed_deps_opts=bytes.fromhex(meta["suppressed_deps_opts"]),
                dep_prios=meta["dep_prios"],
                dep_lines=meta["dep_lines"],
                dep_hashes=[bytes.fromhex(dep) for dep in meta["dep_hashes"]],
                interface_hash=bytes.fromhex(meta["interface_hash"]),
                trans_dep_hash=bytes.fromhex(meta["trans_dep_hash"]),
                version_id=meta["version_id"],
                ignore_all=meta["ignore_all"],
                plugin_data=meta["plugin_data"],
            )
        except (KeyError, ValueError):
            return None

    def write(self, data: WriteBuffer) -> None:
        """Serialize into the fixed-format buffer (see module docstring).

        Field order here must stay in sync with read() below.
        """
        write_str(data, self.id)
        write_str(data, self.path)
        write_int(data, self.mtime)
        write_int(data, self.size)
        write_str(data, self.hash)
        write_str_list(data, self.dependencies)
        write_int(data, self.data_mtime)
        write_str_list(data, self.suppressed)
        write_int_bare(data, len(self.imports_ignored))
        for line, codes in self.imports_ignored.items():
            write_int(data, line)
            write_str_list(data, codes)
        write_json(data, self.options)
        write_bytes(data, self.suppressed_deps_opts)
        write_int_list(data, self.dep_prios)
        write_int_list(data, self.dep_lines)
        write_bytes_list(data, self.dep_hashes)
        write_bytes(data, self.interface_hash)
        write_bytes(data, self.trans_dep_hash)
        write_str(data, self.version_id)
        write_bool(data, self.ignore_all)
        # Plugin data may be not a dictionary, so we use
        # a more generic write_json_value() here.
        write_json_value(data, self.plugin_data)

    @classmethod
    def read(cls, data: ReadBuffer, data_file: str) -> CacheMeta | None:
        """Inverse of write(). Returns None on a malformed buffer.

        NOTE: keyword arguments are evaluated left-to-right, so their order
        here must match the field order written by write().
        """
        try:
            return CacheMeta(
                id=read_str(data),
                path=read_str(data),
                mtime=read_int(data),
                size=read_int(data),
                hash=read_str(data),
                dependencies=read_str_list(data),
                data_mtime=read_int(data),
                data_file=data_file,
                suppressed=read_str_list(data),
                imports_ignored={
                    read_int(data): read_str_list(data) for _ in range(read_int_bare(data))
                },
                options=read_json(data),
                suppressed_deps_opts=read_bytes(data),
                dep_prios=read_int_list(data),
                dep_lines=read_int_list(data),
                dep_hashes=read_bytes_list(data),
                interface_hash=read_bytes(data),
                trans_dep_hash=read_bytes(data),
                version_id=read_str(data),
                ignore_all=read_bool(data),
                plugin_data=read_json_value(data),
            )
        except (ValueError, AssertionError):
            return None
# Always use this type alias to refer to type tags.
Tag = u8

# Primitives (reserved tag range 3 ... 19, see module docstring).
LITERAL_FALSE: Final[Tag] = 0
LITERAL_TRUE: Final[Tag] = 1
LITERAL_NONE: Final[Tag] = 2
LITERAL_INT: Final[Tag] = 3
LITERAL_STR: Final[Tag] = 4
LITERAL_BYTES: Final[Tag] = 5
LITERAL_FLOAT: Final[Tag] = 6
LITERAL_COMPLEX: Final[Tag] = 7

# Collections (sequence-like 20 ... 29, mapping-like 30 ... 39).
LIST_GEN: Final[Tag] = 20
LIST_INT: Final[Tag] = 21
LIST_STR: Final[Tag] = 22
LIST_BYTES: Final[Tag] = 23
TUPLE_GEN: Final[Tag] = 24
DICT_STR_GEN: Final[Tag] = 30
DICT_INT_GEN: Final[Tag] = 31

# Misc classes.
EXTRA_ATTRS: Final[Tag] = 150
DT_SPEC: Final[Tag] = 151
# Four integers representing source file (line, column) range.
LOCATION: Final[Tag] = 152

# Sentinel closing every custom class record (see module docstring).
END_TAG: Final[Tag] = 255
def read_literal(data: ReadBuffer, tag: Tag) -> int | str | bool | float:
    """Read the payload of a literal value whose tag was already consumed.

    Note: complex and None literals are intentionally not handled here (see
    the comment next to write_literal()).
    """
    if tag == LITERAL_FALSE:
        return False
    if tag == LITERAL_TRUE:
        return True
    if tag == LITERAL_INT:
        return read_int_bare(data)
    if tag == LITERAL_STR:
        return read_str_bare(data)
    if tag == LITERAL_FLOAT:
        return read_float_bare(data)
    assert False, f"Unknown literal tag {tag}"
# There is an intentional asymmetry between read and write for literals because
# None and/or complex values are only allowed in some contexts but not in others.
def write_literal(data: WriteBuffer, value: int | str | bool | float | complex | None) -> None:
    """Write a tagged literal value; None is encoded as LITERAL_NONE."""
    # bool must be tested before int, since bool is a subclass of int.
    if isinstance(value, bool):
        write_bool(data, value)
        return
    if isinstance(value, int):
        write_tag(data, LITERAL_INT)
        write_int_bare(data, value)
        return
    if isinstance(value, str):
        write_tag(data, LITERAL_STR)
        write_str_bare(data, value)
        return
    if isinstance(value, float):
        write_tag(data, LITERAL_FLOAT)
        write_float_bare(data, value)
        return
    if isinstance(value, complex):
        # A complex value is stored as two bare floats: real then imaginary.
        write_tag(data, LITERAL_COMPLEX)
        write_float_bare(data, value.real)
        write_float_bare(data, value.imag)
        return
    write_tag(data, LITERAL_NONE)
def read_int(data: ReadBuffer) -> int:
    """Read a tagged int."""
    tag = read_tag(data)
    assert tag == LITERAL_INT
    return read_int_bare(data)


def read_str(data: ReadBuffer) -> str:
    """Read a tagged str."""
    tag = read_tag(data)
    assert tag == LITERAL_STR
    return read_str_bare(data)


def read_bytes(data: ReadBuffer) -> bytes:
    """Read a tagged bytes object."""
    tag = read_tag(data)
    assert tag == LITERAL_BYTES
    return read_bytes_bare(data)


def write_int(data: WriteBuffer, value: int) -> None:
    """Write an int preceded by its tag."""
    write_tag(data, LITERAL_INT)
    write_int_bare(data, value)


def write_str(data: WriteBuffer, value: str) -> None:
    """Write a str preceded by its tag."""
    write_tag(data, LITERAL_STR)
    write_str_bare(data, value)


def write_bytes(data: WriteBuffer, value: bytes) -> None:
    """Write a bytes object preceded by its tag."""
    write_tag(data, LITERAL_BYTES)
    write_bytes_bare(data, value)
def read_int_opt(data: ReadBuffer) -> int | None:
    """Read an optional int, encoded as LITERAL_NONE or a tagged int."""
    tag = read_tag(data)
    if tag == LITERAL_INT:
        return read_int_bare(data)
    assert tag == LITERAL_NONE
    return None


def write_int_opt(data: WriteBuffer, value: int | None) -> None:
    """Write an optional int, encoded as LITERAL_NONE or a tagged int."""
    if value is None:
        write_tag(data, LITERAL_NONE)
    else:
        write_tag(data, LITERAL_INT)
        write_int_bare(data, value)


def read_str_opt(data: ReadBuffer) -> str | None:
    """Read an optional str, encoded as LITERAL_NONE or a tagged str."""
    tag = read_tag(data)
    if tag == LITERAL_STR:
        return read_str_bare(data)
    assert tag == LITERAL_NONE
    return None


def write_str_opt(data: WriteBuffer, value: str | None) -> None:
    """Write an optional str, encoded as LITERAL_NONE or a tagged str."""
    if value is None:
        write_tag(data, LITERAL_NONE)
    else:
        write_tag(data, LITERAL_STR)
        write_str_bare(data, value)
def read_int_list(data: ReadBuffer) -> list[int]:
    """Read a homogeneous int list; items are stored untagged."""
    assert read_tag(data) == LIST_INT
    count = read_int_bare(data)
    result: list[int] = []
    for _ in range(count):
        result.append(read_int_bare(data))
    return result


def write_int_list(data: WriteBuffer, value: list[int]) -> None:
    """Write a homogeneous int list; items are stored untagged."""
    write_tag(data, LIST_INT)
    write_int_bare(data, len(value))
    for number in value:
        write_int_bare(data, number)


def read_str_list(data: ReadBuffer) -> list[str]:
    """Read a homogeneous str list; items are stored untagged."""
    assert read_tag(data) == LIST_STR
    count = read_int_bare(data)
    result: list[str] = []
    for _ in range(count):
        result.append(read_str_bare(data))
    return result


def write_str_list(data: WriteBuffer, value: Sequence[str]) -> None:
    """Write a homogeneous str sequence; items are stored untagged."""
    write_tag(data, LIST_STR)
    write_int_bare(data, len(value))
    for text in value:
        write_str_bare(data, text)


def read_bytes_list(data: ReadBuffer) -> list[bytes]:
    """Read a homogeneous bytes list; items are stored untagged."""
    assert read_tag(data) == LIST_BYTES
    count = read_int_bare(data)
    result: list[bytes] = []
    for _ in range(count):
        result.append(read_bytes_bare(data))
    return result


def write_bytes_list(data: WriteBuffer, value: Sequence[bytes]) -> None:
    """Write a homogeneous bytes sequence; items are stored untagged."""
    write_tag(data, LIST_BYTES)
    write_int_bare(data, len(value))
    for blob in value:
        write_bytes_bare(data, blob)


def read_str_opt_list(data: ReadBuffer) -> list[str | None]:
    """Read a list of optional strs (stored as a generic, per-item-tagged list)."""
    assert read_tag(data) == LIST_GEN
    count = read_int_bare(data)
    return [read_str_opt(data) for _ in range(count)]


def write_str_opt_list(data: WriteBuffer, value: list[str | None]) -> None:
    """Write a list of optional strs (stored as a generic, per-item-tagged list)."""
    write_tag(data, LIST_GEN)
    write_int_bare(data, len(value))
    for text in value:
        write_str_opt(data, text)
# Scalar values representable in our JSON-like serialization format.
Value: _TypeAlias = None | int | str | bool
# Our JSON format is somewhat non-standard as we distinguish lists and tuples.
# This is convenient for some internal things, like mypyc plugin and error serialization.
JsonValue: _TypeAlias = (
    Value | list["JsonValue"] | dict[str, "JsonValue"] | tuple["JsonValue", ...]
)
def read_json_value(data: ReadBuffer) -> JsonValue:
    """Read one JSON-like value, dispatching on its leading tag."""
    marker = read_tag(data)
    if marker == LITERAL_NONE:
        return None
    if marker == LITERAL_FALSE:
        return False
    if marker == LITERAL_TRUE:
        return True
    if marker == LITERAL_INT:
        return read_int_bare(data)
    if marker == LITERAL_STR:
        return read_str_bare(data)
    if marker == LIST_GEN:
        items: list[JsonValue] = []
        for _ in range(read_int_bare(data)):
            items.append(read_json_value(data))
        return items
    if marker == TUPLE_GEN:
        length = read_int_bare(data)
        return tuple(read_json_value(data) for _ in range(length))
    if marker == DICT_STR_GEN:
        result: dict[str, JsonValue] = {}
        for _ in range(read_int_bare(data)):
            # Each entry is stored as the key followed by the value.
            key = read_str_bare(data)
            result[key] = read_json_value(data)
        return result
    assert False, f"Invalid JSON tag: {marker}"
def write_json_value(data: WriteBuffer, value: JsonValue) -> None:
    """Write one JSON-like value preceded by its tag.

    Dict keys are written in sorted order, making the output deterministic.
    """
    if value is None:
        write_tag(data, LITERAL_NONE)
        return
    # bool must be tested before int, since bool is a subclass of int.
    if isinstance(value, bool):
        write_bool(data, value)
        return
    if isinstance(value, int):
        write_tag(data, LITERAL_INT)
        write_int_bare(data, value)
        return
    if isinstance(value, str):
        write_tag(data, LITERAL_STR)
        write_str_bare(data, value)
        return
    if isinstance(value, list):
        write_tag(data, LIST_GEN)
        write_int_bare(data, len(value))
        for item in value:
            write_json_value(data, item)
        return
    if isinstance(value, tuple):
        write_tag(data, TUPLE_GEN)
        write_int_bare(data, len(value))
        for item in value:
            write_json_value(data, item)
        return
    if isinstance(value, dict):
        write_tag(data, DICT_STR_GEN)
        write_int_bare(data, len(value))
        for key in sorted(value):
            write_str_bare(data, key)
            write_json_value(data, value[key])
        return
    assert False, f"Invalid JSON value: {value}"
# These are functions for JSON *dictionaries* specifically. Unfortunately, we
# must use imprecise types here, because the callers use imprecise types.
def read_json(data: ReadBuffer) -> dict[str, Any]:
    """Read a str-keyed JSON-like dictionary."""
    assert read_tag(data) == DICT_STR_GEN
    result: dict[str, Any] = {}
    for _ in range(read_int_bare(data)):
        key = read_str_bare(data)
        result[key] = read_json_value(data)
    return result


def write_json(data: WriteBuffer, value: dict[str, Any]) -> None:
    """Write a str-keyed JSON-like dictionary with deterministic key order."""
    write_tag(data, DICT_STR_GEN)
    write_int_bare(data, len(value))
    for key in sorted(value):
        write_str_bare(data, key)
        write_json_value(data, value[key])
def write_errors(data: WriteBuffer, errs: list[ErrorTuple]) -> None:
    """Write a list of error tuples as a generic list of generic tuples."""
    write_tag(data, LIST_GEN)
    write_int_bare(data, len(errs))
    for err in errs:
        path, line, column, end_line, end_column, severity, message, code = err
        write_tag(data, TUPLE_GEN)
        write_str_opt(data, path)
        for position in (line, column, end_line, end_column):
            write_int(data, position)
        write_str(data, severity)
        write_str(data, message)
        write_str_opt(data, code)
def read_errors(data: ReadBuffer) -> list[ErrorTuple]:
    """Read back a list of error tuples written by write_errors()."""
    assert read_tag(data) == LIST_GEN
    count = read_int_bare(data)
    errors: list[ErrorTuple] = []
    for _ in range(count):
        assert read_tag(data) == TUPLE_GEN
        # Fields must be read in exactly the order write_errors() emits them.
        path = read_str_opt(data)
        line = read_int(data)
        column = read_int(data)
        end_line = read_int(data)
        end_column = read_int(data)
        severity = read_str(data)
        message = read_str(data)
        code = read_str_opt(data)
        errors.append((path, line, column, end_line, end_column, severity, message, code))
    return errors

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,365 @@
"""Shared definitions used by different parts of type checker."""
from __future__ import annotations
from abc import abstractmethod
from collections.abc import Iterator, Sequence, Set as AbstractSet
from contextlib import contextmanager
from typing import NamedTuple, overload
from mypy_extensions import trait
from mypy.errorcodes import ErrorCode
from mypy.errors import ErrorWatcher
from mypy.message_registry import ErrorMessage
from mypy.nodes import (
ArgKind,
Context,
Expression,
FuncItem,
LambdaExpr,
MypyFile,
Node,
RefExpr,
SymbolNode,
TypeInfo,
Var,
)
from mypy.plugin import CheckerPluginInterface, Plugin
from mypy.types import (
CallableType,
Instance,
LiteralValue,
Overloaded,
PartialType,
TupleType,
Type,
TypedDictType,
TypeType,
)
from mypy.typevars import fill_typevars
# An object that represents either a precise type or a type with an upper bound;
# it is important for correct type inference with isinstance.
class TypeRange(NamedTuple):
    """A precise type, or a type known only up to an upper bound."""

    item: Type
    is_upper_bound: bool  # False => precise type
@trait
class ExpressionCheckerSharedApi:
    """Abstract interface of the expression checker shared with other checker parts.

    All methods are abstract; the concrete implementation lives elsewhere, and
    other checker components program against this trait.
    """

    @abstractmethod
    def accept(
        self,
        node: Expression,
        type_context: Type | None = None,
        allow_none_return: bool = False,
        always_allow_any: bool = False,
        is_callee: bool = False,
    ) -> Type:
        """Type check *node*, optionally against *type_context*, and return its type."""
        raise NotImplementedError

    @abstractmethod
    def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
        """Return the type of a name or member reference expression."""
        raise NotImplementedError

    @abstractmethod
    def check_call(
        self,
        callee: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        arg_names: Sequence[str | None] | None = None,
        callable_node: Expression | None = None,
        callable_name: str | None = None,
        object_type: Type | None = None,
        original_type: Type | None = None,
    ) -> tuple[Type, Type]:
        """Type check a call of *callee* with the given arguments.

        Returns a pair of types; see the implementation for their exact roles.
        """
        raise NotImplementedError

    @abstractmethod
    def transform_callee_type(
        self,
        callable_name: str | None,
        callee: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        arg_names: Sequence[str | None] | None = None,
        object_type: Type | None = None,
    ) -> Type:
        """Return a possibly adjusted callee type to use for a call (hook point)."""
        raise NotImplementedError

    @abstractmethod
    def method_fullname(self, object_type: Type, method_name: str) -> str | None:
        """Return the fully qualified name of *method_name* on *object_type*, if known."""
        raise NotImplementedError

    @abstractmethod
    def check_method_call_by_name(
        self,
        method: str,
        base_type: Type,
        args: list[Expression],
        arg_kinds: list[ArgKind],
        context: Context,
        original_type: Type | None = None,
    ) -> tuple[Type, Type]:
        """Type check calling method *method* on a value of *base_type*."""
        raise NotImplementedError

    @abstractmethod
    def visit_typeddict_index_expr(
        self, td_type: TypedDictType, index: Expression, setitem: bool = False
    ) -> tuple[Type, set[str]]:
        """Check indexing a TypedDict; returns a value type and a set of string keys."""
        raise NotImplementedError

    @abstractmethod
    def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Type:
        """Infer the type of literal *value* using *fallback_name* as its fallback type."""
        raise NotImplementedError

    @abstractmethod
    def analyze_static_reference(
        self,
        node: SymbolNode,
        ctx: Context,
        is_lvalue: bool,
        *,
        include_modules: bool = True,
        suppress_errors: bool = False,
    ) -> Type:
        """Return the type of a direct reference to symbol *node*."""
        raise NotImplementedError
@trait
class TypeCheckerSharedApi(CheckerPluginInterface):
    """Abstract interface of the type checker shared with other checker parts.

    All methods are abstract; the concrete implementation lives elsewhere, and
    other checker components (expression/pattern checkers etc.) program
    against this trait.
    """

    plugin: Plugin
    module_refs: set[str]
    scope: CheckerScope
    checking_missing_await: bool
    allow_constructor_cache: bool

    @property
    @abstractmethod
    def expr_checker(self) -> ExpressionCheckerSharedApi:
        """The expression checker associated with this type checker."""
        raise NotImplementedError

    @abstractmethod
    def named_type(self, name: str) -> Instance:
        """Return an Instance for the fully qualified type *name*."""
        raise NotImplementedError

    @abstractmethod
    def lookup_typeinfo(self, fullname: str) -> TypeInfo:
        """Look up the TypeInfo for a fully qualified class name."""
        raise NotImplementedError

    @abstractmethod
    def lookup_type(self, node: Expression) -> Type:
        """Return the type recorded for expression *node*."""
        raise NotImplementedError

    @abstractmethod
    def handle_cannot_determine_type(self, name: str, context: Context) -> None:
        """Handle the case where the type of *name* cannot be determined."""
        raise NotImplementedError

    @abstractmethod
    def handle_partial_var_type(
        self, typ: PartialType, is_lvalue: bool, node: Var, context: Context
    ) -> Type:
        """Resolve the partial type of variable *node* at a use site."""
        raise NotImplementedError

    @overload
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: str,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        code: ErrorCode | None = None,
        outer_context: Context | None = None,
    ) -> bool: ...

    @overload
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: ErrorMessage,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        outer_context: Context | None = None,
    ) -> bool: ...

    # Unfortunately, mypyc doesn't support abstract overloads yet.
    @abstractmethod
    def check_subtype(
        self,
        subtype: Type,
        supertype: Type,
        context: Context,
        msg: str | ErrorMessage,
        subtype_label: str | None = None,
        supertype_label: str | None = None,
        *,
        notes: list[str] | None = None,
        code: ErrorCode | None = None,
        outer_context: Context | None = None,
    ) -> bool:
        """Check that *subtype* is a subtype of *supertype*; report *msg* on failure.

        Returns whether the check succeeded.
        """
        raise NotImplementedError

    @abstractmethod
    def get_final_context(self) -> bool:
        """Return the current "final" context flag (see implementation)."""
        raise NotImplementedError

    @overload
    @abstractmethod
    def conditional_types_with_intersection(
        self,
        expr_type: Type,
        type_ranges: list[TypeRange] | None,
        ctx: Context,
        default: None = None,
    ) -> tuple[Type | None, Type | None]: ...

    @overload
    @abstractmethod
    def conditional_types_with_intersection(
        self, expr_type: Type, type_ranges: list[TypeRange] | None, ctx: Context, default: Type
    ) -> tuple[Type, Type]: ...

    # Unfortunately, mypyc doesn't support abstract overloads yet.
    @abstractmethod
    def conditional_types_with_intersection(
        self,
        expr_type: Type,
        type_ranges: list[TypeRange] | None,
        ctx: Context,
        default: Type | None = None,
    ) -> tuple[Type | None, Type | None]:
        """Narrow *expr_type* against *type_ranges*, returning (yes, no) types.

        See the implementation for the exact intersection semantics.
        """
        raise NotImplementedError

    @abstractmethod
    def narrow_type_by_identity_equality(
        self,
        operator: str,
        operands: list[Expression],
        operand_types: list[Type],
        expr_indices: list[int],
        narrowable_indices: AbstractSet[int],
    ) -> tuple[dict[Expression, Type] | None, dict[Expression, Type] | None]:
        """Narrow operand types using an identity/equality comparison.

        Returns (if-map, else-map); a None map appears to mark an impossible
        branch — confirm against the implementation.
        """
        raise NotImplementedError

    @abstractmethod
    def check_deprecated(self, node: Node | None, context: Context) -> None:
        raise NotImplementedError

    @abstractmethod
    def warn_deprecated(self, node: Node | None, context: Context) -> None:
        raise NotImplementedError

    @abstractmethod
    def type_is_iterable(self, type: Type) -> bool:
        """Return whether *type* is iterable."""
        raise NotImplementedError

    @abstractmethod
    def iterable_item_type(
        self, it: Instance | CallableType | TypeType | Overloaded, context: Context
    ) -> Type:
        """Return the item type of iterable type *it*."""
        raise NotImplementedError

    @abstractmethod
    @contextmanager
    def checking_await_set(self) -> Iterator[None]:
        """Context manager toggling ``checking_missing_await`` (see implementation)."""
        raise NotImplementedError

    @abstractmethod
    def get_precise_awaitable_type(self, typ: Type, local_errors: ErrorWatcher) -> Type | None:
        """Return a precise awaitable result type for *typ*, or None."""
        raise NotImplementedError

    @abstractmethod
    def add_any_attribute_to_type(self, typ: Type, name: str) -> Type:
        """Return *typ* adjusted so that attribute *name* is allowed (typed as Any)."""
        raise NotImplementedError

    @abstractmethod
    def is_defined_in_stub(self, typ: Instance, /) -> bool:
        """Return whether the class of *typ* is defined in a stub file."""
        raise NotImplementedError
class CheckerScope:
    """Tracks the stack of active scopes (module, classes, functions).

    Classes and functions are kept in one combined stack, to maintain
    their relative order. The module is always at index 0.
    """

    # We keep two stacks combined, to maintain the relative order
    stack: list[TypeInfo | FuncItem | MypyFile]

    def __init__(self, module: MypyFile) -> None:
        self.stack = [module]

    def current_function(self) -> FuncItem | None:
        """Return the innermost enclosing function, if any."""
        for e in reversed(self.stack):
            if isinstance(e, FuncItem):
                return e
        return None

    def top_level_function(self) -> FuncItem | None:
        """Return top-level non-lambda function."""
        for e in self.stack:
            if isinstance(e, FuncItem) and not isinstance(e, LambdaExpr):
                return e
        return None

    def active_class(self) -> TypeInfo | None:
        """Return the class whose body is the innermost active scope, if any."""
        if isinstance(self.stack[-1], TypeInfo):
            return self.stack[-1]
        return None

    def enclosing_class(self, func: FuncItem | None = None) -> TypeInfo | None:
        """Is there a class *directly* enclosing this function?"""
        func = func or self.current_function()
        assert func, "This method must be called from inside a function"
        index = self.stack.index(func)
        assert index, "CheckerScope stack must always start with a module"
        enclosing = self.stack[index - 1]
        if isinstance(enclosing, TypeInfo):
            return enclosing
        return None

    def active_self_type(self) -> Instance | TupleType | None:
        """An instance or tuple type representing the current class.

        This returns None unless we are in class body or in a method.
        In particular, inside a function nested in method this returns None.
        """
        info = self.active_class()
        if not info and self.current_function():
            info = self.enclosing_class()
        if info:
            return fill_typevars(info)
        return None

    def current_self_type(self) -> Instance | TupleType | None:
        """Same as active_self_type() but handle functions nested in methods."""
        for item in reversed(self.stack):
            if isinstance(item, TypeInfo):
                return fill_typevars(item)
        return None

    def is_top_level(self) -> bool:
        """Is current scope top-level (no classes or functions)?"""
        return len(self.stack) == 1

    @contextmanager
    def push_function(self, item: FuncItem) -> Iterator[None]:
        """Push a function scope for the duration of the context."""
        self.stack.append(item)
        try:
            yield
        finally:
            # Pop even if an exception propagates, so the scope stack stays
            # consistent for any subsequent checking.
            self.stack.pop()

    @contextmanager
    def push_class(self, info: TypeInfo) -> Iterator[None]:
        """Push a class scope for the duration of the context."""
        self.stack.append(info)
        try:
            yield
        finally:
            self.stack.pop()

View file

@ -0,0 +1,30 @@
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Final
from mypy.checker_shared import TypeCheckerSharedApi
# This is global mutable state. Don't add anything here unless there's a very
# good reason.
class TypeCheckerState:
    """Holder for the currently active type checker (global mutable state).

    Wrapped in a class because instance attribute access is faster than
    a module-level attribute. Don't add anything here unless there's a
    very good reason.
    """

    def __init__(self, type_checker: TypeCheckerSharedApi | None) -> None:
        # Value varies by file being processed
        self.type_checker = type_checker

    @contextmanager
    def set(self, value: TypeCheckerSharedApi) -> Iterator[None]:
        """Temporarily install *value* as the active type checker."""
        previous = self.type_checker
        self.type_checker = value
        try:
            yield
        finally:
            self.type_checker = previous


checker_state: Final = TypeCheckerState(type_checker=None)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,885 @@
"""Pattern checker. This file is conceptually part of TypeChecker."""
from __future__ import annotations
from collections import defaultdict
from typing import Final, NamedTuple
from mypy import message_registry
from mypy.checker_shared import TypeCheckerSharedApi, TypeRange
from mypy.checkmember import analyze_member_access
from mypy.expandtype import expand_type_by_instance
from mypy.join import join_types
from mypy.literals import literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy.messages import MessageBuilder
from mypy.nodes import ARG_POS, Expression, NameExpr, TempNode, TypeAlias, Var
from mypy.options import Options
from mypy.patterns import (
AsPattern,
ClassPattern,
MappingPattern,
OrPattern,
Pattern,
SequencePattern,
SingletonPattern,
StarredPattern,
ValuePattern,
)
from mypy.plugin import Plugin
from mypy.subtypes import is_subtype
from mypy.typeops import (
coerce_to_literal,
make_simplified_union,
try_getting_str_literals_from_type,
tuple_fallback,
)
from mypy.types import (
AnyType,
FunctionLike,
Instance,
NoneType,
ProperType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnionType,
UnpackType,
callable_with_ellipsis,
find_unpack_in_list,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevars import fill_typevars, fill_typevars_with_any
from mypy.visitor import PatternVisitor
# Built-in types whose class patterns match against the subject itself rather
# than via __match_args__ (see PatternChecker.self_match_types below).
self_match_type_names: Final = [
    "builtins.bool",
    "builtins.bytearray",
    "builtins.bytes",
    "builtins.dict",
    "builtins.float",
    "builtins.frozenset",
    "builtins.int",
    "builtins.list",
    "builtins.set",
    "builtins.str",
    "builtins.tuple",
]

# Types that are sequences but nevertheless don't match sequence patterns
# (see PatternChecker.non_sequence_match_types below).
non_sequence_match_type_names: Final = ["builtins.str", "builtins.bytes", "builtins.bytearray"]
# For every Pattern a PatternType can be calculated. This requires recursively calculating
# the PatternTypes of the sub-patterns first.
# Using the data in the PatternType the match subject and captured names can be narrowed/inferred.
class PatternType(NamedTuple):
    """Result of checking a pattern against a subject type."""

    type: Type  # The type the match subject can be narrowed to
    rest_type: Type  # The remaining type if the pattern didn't match
    captures: dict[Expression, Type]  # The variables captured by the pattern
class PatternChecker(PatternVisitor[PatternType]):
    """Pattern checker.

    This class checks if a pattern can match a type, what the type can be narrowed to, and what
    type capture patterns should be inferred as.
    """

    # Some services are provided by a TypeChecker instance.
    chk: TypeCheckerSharedApi
    # This is shared with TypeChecker, but stored also here for convenience.
    msg: MessageBuilder
    # Currently unused
    plugin: Plugin
    # The expression being matched against the pattern
    subject: Expression
    # Type of ``subject`` (assigned outside the code visible here)
    subject_type: Type
    # Type of the subject to check the (sub)pattern against
    type_context: list[Type]
    # Types that match against self instead of their __match_args__ if used as a class pattern
    # Filled in from self_match_type_names
    self_match_types: list[Type]
    # Types that are sequences, but don't match sequence patterns. Filled in from
    # non_sequence_match_type_names
    non_sequence_match_types: list[Type]
    options: Options
    def __init__(
        self, chk: TypeCheckerSharedApi, msg: MessageBuilder, plugin: Plugin, options: Options
    ) -> None:
        """Initialize with the owning type checker and its message builder."""
        self.chk = chk
        self.msg = msg
        self.plugin = plugin
        self.type_context = []
        # Resolve the module-level name lists into actual types once, up front.
        self.self_match_types = self.generate_types_from_names(self_match_type_names)
        self.non_sequence_match_types = self.generate_types_from_names(
            non_sequence_match_type_names
        )
        self.options = options
    def accept(self, o: Pattern, type_context: Type) -> PatternType:
        """Check pattern *o* against *type_context* and return its PatternType."""
        self.type_context.append(type_context)
        result = o.accept(self)
        self.type_context.pop()
        return result
    def visit_as_pattern(self, o: AsPattern) -> PatternType:
        """Check an ``as`` pattern (a bare capture is an AsPattern with no subpattern)."""
        current_type = self.type_context[-1]
        if o.pattern is not None:
            pattern_type = self.accept(o.pattern, current_type)
            typ, rest_type, type_map = pattern_type
        else:
            # A bare capture always matches; nothing remains for later patterns.
            typ, rest_type, type_map = current_type, UninhabitedType(), {}
        if not is_uninhabited(typ) and o.name is not None:
            # The captured name gets the narrowed (possibly intersected) type.
            typ, _ = self.chk.conditional_types_with_intersection(
                current_type, [get_type_range(typ)], o, default=current_type
            )
            if not is_uninhabited(typ):
                type_map[o.name] = typ
        return PatternType(typ, rest_type, type_map)
    def visit_or_pattern(self, o: OrPattern) -> PatternType:
        """Check an or-pattern (``p1 | p2 | ...``) alternative by alternative."""
        current_type = self.type_context[-1]
        #
        # Check all the subpatterns
        #
        pattern_types = []
        for pattern in o.patterns:
            pattern_type = self.accept(pattern, current_type)
            pattern_types.append(pattern_type)
            if not is_uninhabited(pattern_type.type):
                # A later alternative is only reached when earlier ones failed,
                # so check it against the previous alternative's rest type.
                current_type = pattern_type.rest_type
        #
        # Collect the final type
        #
        types = []
        for pattern_type in pattern_types:
            if not is_uninhabited(pattern_type.type):
                types.append(pattern_type.type)
        #
        # Check the capture types
        #
        capture_types: dict[Var, list[tuple[Expression, Type]]] = defaultdict(list)
        # Collect captures from the first subpattern
        for expr, typ in pattern_types[0].captures.items():
            node = get_var(expr)
            capture_types[node].append((expr, typ))
        # Check if other subpatterns capture the same names
        for i, pattern_type in enumerate(pattern_types[1:]):
            vars = {get_var(expr) for expr, _ in pattern_type.captures.items()}
            if capture_types.keys() != vars:
                # NOTE(review): ``i`` indexes ``pattern_types[1:]``, so
                # ``o.patterns[i]`` points at the alternative *before* the
                # mismatching one — possibly should be ``o.patterns[i + 1]``;
                # confirm the intended error location.
                self.msg.fail(message_registry.OR_PATTERN_ALTERNATIVE_NAMES, o.patterns[i])
            for expr, typ in pattern_type.captures.items():
                node = get_var(expr)
                capture_types[node].append((expr, typ))
        captures: dict[Expression, Type] = {}
        for capture_list in capture_types.values():
            # Each captured variable gets the union of its types across alternatives.
            typ = UninhabitedType()
            for _, other in capture_list:
                typ = make_simplified_union([typ, other])
            captures[capture_list[0][0]] = typ
        union_type = make_simplified_union(types)
        return PatternType(union_type, current_type, captures)
    def visit_value_pattern(self, o: ValuePattern) -> PatternType:
        """Check a value pattern (a literal or dotted-name expression)."""
        current_type = self.type_context[-1]
        typ = self.chk.expr_checker.accept(o.expr)
        typ = coerce_to_literal(typ)
        node = TempNode(current_type)
        # Value patterns are essentially a syntactic sugar on top of `if x == Value`.
        # They should be treated equivalently.
        ok_map, rest_map = self.chk.narrow_type_by_identity_equality(
            "==", [node, TempNode(typ)], [current_type, typ], [0, 1], {0}
        )
        # A missing map means that branch is impossible: narrow to Uninhabited.
        ok_type = ok_map.get(node, current_type) if ok_map is not None else UninhabitedType()
        rest_type = rest_map.get(node, current_type) if rest_map is not None else UninhabitedType()
        return PatternType(ok_type, rest_type, {})
def visit_singleton_pattern(self, o: SingletonPattern) -> PatternType:
current_type = self.type_context[-1]
value: bool | None = o.value
if isinstance(value, bool):
typ = self.chk.expr_checker.infer_literal_expr_type(value, "builtins.bool")
elif value is None:
typ = NoneType()
else:
assert False
narrowed_type, rest_type = self.chk.conditional_types_with_intersection(
current_type, [get_type_range(typ)], o, default=current_type
)
return PatternType(narrowed_type, rest_type, {})
    def visit_sequence_pattern(self, o: SequencePattern) -> PatternType:
        """Check a sequence pattern such as ``[a, *b, c]`` against the subject type."""
        #
        # Step 1. Check for existence of a starred pattern
        #
        current_type = get_proper_type(self.type_context[-1])
        if not self.can_match_sequence(current_type):
            return self.early_non_match()
        star_positions = [i for i, p in enumerate(o.patterns) if isinstance(p, StarredPattern)]
        star_position: int | None = None
        if len(star_positions) == 1:
            star_position = star_positions[0]
        elif len(star_positions) >= 2:
            assert False, "Parser should prevent multiple starred patterns"
        # With a star, one fewer concrete item is required to match.
        required_patterns = len(o.patterns)
        if star_position is not None:
            required_patterns -= 1
        #
        # Step 2. If we have a union, recurse and return the combined result
        #
        if isinstance(current_type, UnionType):
            match_types: list[Type] = []
            rest_types: list[Type] = []
            captures_list: dict[Expression, list[Type]] = {}
            if star_position is not None:
                star_pattern = o.patterns[star_position]
                assert isinstance(star_pattern, StarredPattern)
                star_expr = star_pattern.capture
            else:
                star_expr = None
            for t in current_type.items:
                match_type, rest_type, captures = self.accept(o, t)
                match_types.append(match_type)
                rest_types.append(rest_type)
                if not is_uninhabited(match_type):
                    for expr, typ in captures.items():
                        p_typ = get_proper_type(typ)
                        if expr not in captures_list:
                            captures_list[expr] = []
                        # Avoid adding in a list[Never] for empty list captures
                        if (
                            expr == star_expr
                            and isinstance(p_typ, Instance)
                            and p_typ.type.fullname == "builtins.list"
                            and is_uninhabited(p_typ.args[0])
                        ):
                            continue
                        captures_list[expr].append(typ)
            return PatternType(
                make_simplified_union(match_types),
                make_simplified_union(rest_types),
                {expr: make_simplified_union(types) for expr, types in captures_list.items()},
            )
        #
        # Step 3. Get inner types of original type
        #
        unpack_index = None
        if isinstance(current_type, TupleType):
            inner_types: list[Type] = current_type.items
            unpack_index = find_unpack_in_list(inner_types)
            if unpack_index is None:
                # Fixed-length tuple: the item count must fit the pattern exactly
                # (or exceed it when a star can absorb the difference).
                size_diff = len(inner_types) - required_patterns
                if size_diff < 0:
                    return self.early_non_match()
                elif size_diff > 0 and star_position is None:
                    return self.early_non_match()
            else:
                normalized_inner_types = []
                for it in inner_types:
                    # Unfortunately, it is not possible to "split" the TypeVarTuple
                    # into individual items, so we just use its upper bound for the whole
                    # analysis instead.
                    if isinstance(it, UnpackType) and isinstance(it.type, TypeVarTupleType):
                        it = UnpackType(it.type.upper_bound)
                    normalized_inner_types.append(it)
                inner_types = normalized_inner_types
                current_type = current_type.copy_modified(items=normalized_inner_types)
                if len(inner_types) - 1 > required_patterns and star_position is None:
                    return self.early_non_match()
        elif isinstance(current_type, AnyType):
            inner_type = AnyType(TypeOfAny.from_another_any, current_type)
            inner_types = [inner_type] * len(o.patterns)
        elif isinstance(current_type, Instance) and self.chk.type_is_iterable(current_type):
            inner_type = self.chk.iterable_item_type(current_type, o)
            inner_types = [inner_type] * len(o.patterns)
        else:
            inner_type = self.chk.named_type("builtins.object")
            inner_types = [inner_type] * len(o.patterns)
        #
        # Step 4. Match inner patterns
        #
        contracted_new_inner_types: list[Type] = []
        contracted_rest_inner_types: list[Type] = []
        captures = {}  # dict[Expression, Type]
        # Contract the inner types so there is exactly one per subpattern.
        contracted_inner_types = self.contract_starred_pattern_types(
            inner_types, star_position, required_patterns
        )
        for p, t in zip(o.patterns, contracted_inner_types):
            pattern_type = self.accept(p, t)
            typ, rest, type_map = pattern_type
            contracted_new_inner_types.append(typ)
            contracted_rest_inner_types.append(rest)
            self.update_type_map(captures, type_map)
        new_inner_types = self.expand_starred_pattern_types(
            contracted_new_inner_types, star_position, len(inner_types), unpack_index is not None
        )
        rest_inner_types = self.expand_starred_pattern_types(
            contracted_rest_inner_types, star_position, len(inner_types), unpack_index is not None
        )
        #
        # Step 5. Calculate new type
        #
        new_type: Type
        rest_type = current_type
        if isinstance(current_type, TupleType) and unpack_index is None:
            if any(is_uninhabited(typ) for typ in new_inner_types):
                new_type = UninhabitedType()
            else:
                new_type = TupleType(new_inner_types, current_type.partial_fallback)
            num_always_match = sum(is_uninhabited(typ) for typ in rest_inner_types)
            if num_always_match == len(rest_inner_types):
                # All subpatterns always match, so we can apply negative narrowing
                rest_type = UninhabitedType()
            elif num_always_match == len(rest_inner_types) - 1:
                # Exactly one subpattern may conditionally match, the rest always match.
                # We can apply negative narrowing to this one position.
                rest_type = TupleType(
                    [
                        curr if is_uninhabited(rest) else rest
                        for curr, rest in zip(inner_types, rest_inner_types)
                    ],
                    current_type.partial_fallback,
                )
        elif isinstance(current_type, TupleType):
            # For variadic tuples it is too tricky to match individual items like for fixed
            # tuples, so we instead try to narrow the entire type.
            # TODO: use more precise narrowing when possible (e.g. for identical shapes).
            new_tuple_type = TupleType(new_inner_types, current_type.partial_fallback)
            new_type, _ = self.chk.conditional_types_with_intersection(
                new_tuple_type, [get_type_range(current_type)], o, default=new_tuple_type
            )
            if (
                star_position is not None
                and required_patterns <= len(inner_types) - 1
                and all(is_uninhabited(rest) for rest in rest_inner_types)
            ):
                rest_type = UninhabitedType()
        else:
            # Non-tuple sequence: join the item types and rebuild the sequence type.
            new_inner_type = UninhabitedType()
            for typ in new_inner_types:
                new_inner_type = join_types(new_inner_type, typ)
            new_type = self.construct_sequence_child(current_type, new_inner_type)
            new_type, possible_rest_type = self.chk.conditional_types_with_intersection(
                current_type, [get_type_range(new_type)], o, default=current_type
            )
            if star_position is not None and len(o.patterns) == 1:
                # Match cannot be refuted, so narrow the remaining type
                rest_type = possible_rest_type
        return PatternType(new_type, rest_type, captures)
    def contract_starred_pattern_types(
        self, types: list[Type], star_pos: int | None, num_patterns: int
    ) -> list[Type]:
        """
        Contracts a list of types in a sequence pattern depending on the position of a starred
        capture pattern.

        For example if the sequence pattern [a, *b, c] is matched against types [bool, int, str,
        bytes] the contracted types are [bool, Union[int, str], bytes].

        If star_pos is None the types are returned unchanged.
        """
        unpack_index = find_unpack_in_list(types)
        if unpack_index is not None:
            # Variadic tuples require "re-shaping" to match the requested pattern.
            unpack = types[unpack_index]
            assert isinstance(unpack, UnpackType)
            unpacked = get_proper_type(unpack.type)
            # This should be guaranteed by the normalization in the caller.
            assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
            if star_pos is None:
                # No star: pad with the unpacked item type up to the pattern length.
                missing = num_patterns - len(types) + 1
                new_types = types[:unpack_index]
                new_types += [unpacked.args[0]] * missing
                new_types += types[unpack_index + 1 :]
                return new_types
            prefix, middle, suffix = split_with_prefix_and_suffix(
                tuple([UnpackType(unpacked) if isinstance(t, UnpackType) else t for t in types]),
                star_pos,
                num_patterns - star_pos,
            )
            new_middle = []
            for m in middle:
                # The existing code expects the star item type, rather than the type of
                # the whole tuple "slice".
                if isinstance(m, UnpackType):
                    new_middle.append(unpacked.args[0])
                else:
                    new_middle.append(m)
            return list(prefix) + [make_simplified_union(new_middle)] + list(suffix)
        else:
            if star_pos is None:
                return types
            # The star absorbs star_length consecutive items, as their union.
            new_types = types[:star_pos]
            star_length = len(types) - num_patterns
            new_types.append(make_simplified_union(types[star_pos : star_pos + star_length]))
            new_types += types[star_pos + star_length :]
            return new_types
    def expand_starred_pattern_types(
        self, types: list[Type], star_pos: int | None, num_types: int, original_unpack: bool
    ) -> list[Type]:
        """Undoes the contraction done by contract_starred_pattern_types.

        For example if the sequence pattern is [a, *b, c] and types [bool, int, str] are extended
        to length 4 the result is [bool, int, int, str].
        """
        if star_pos is None:
            return types
        if original_unpack:
            # In the case where original tuple type has an unpack item, it is not practical
            # to coerce pattern type back to the original shape (and may not even be possible),
            # so we only restore the type of the star item.
            res = []
            for i, t in enumerate(types):
                if i != star_pos or is_uninhabited(t):
                    res.append(t)
                else:
                    res.append(UnpackType(self.chk.named_generic_type("builtins.tuple", [t])))
            return res
        # Repeat the contracted star item type to restore the original length.
        new_types = types[:star_pos]
        star_length = num_types - len(types) + 1
        new_types += [types[star_pos]] * star_length
        new_types += types[star_pos + 1 :]
        return new_types
def visit_starred_pattern(self, o: StarredPattern) -> PatternType:
captures: dict[Expression, Type] = {}
if o.capture is not None:
list_type = self.chk.named_generic_type("builtins.list", [self.type_context[-1]])
captures[o.capture] = list_type
return PatternType(self.type_context[-1], UninhabitedType(), captures)
    def visit_mapping_pattern(self, o: MappingPattern) -> PatternType:
        """Type check a mapping pattern like {"key": pat, **rest}.

        Each key's value sub-pattern is checked against the inferred item type,
        and a **rest capture (if present) is typed as a dict.
        """
        current_type = get_proper_type(self.type_context[-1])
        can_match = True
        captures: dict[Expression, Type] = {}
        for key, value in zip(o.keys, o.values):
            inner_type = self.get_mapping_item_type(o, current_type, key)
            if inner_type is None:
                # Value type could not be determined (e.g. unknown TypedDict key);
                # the pattern cannot match, but keep checking with object.
                can_match = False
                inner_type = self.chk.named_type("builtins.object")
            pattern_type = self.accept(value, inner_type)
            if is_uninhabited(pattern_type.type):
                can_match = False
            else:
                self.update_type_map(captures, pattern_type.captures)
        if o.rest is not None:
            mapping = self.chk.named_type("typing.Mapping")
            if is_subtype(current_type, mapping) and isinstance(current_type, Instance):
                # **rest is captured as a dict with the subject's key/value type arguments.
                mapping_inst = map_instance_to_supertype(current_type, mapping.type)
                dict_typeinfo = self.chk.lookup_typeinfo("builtins.dict")
                rest_type = Instance(dict_typeinfo, mapping_inst.args)
            else:
                # Subject is not a concrete Mapping instance; fall back to dict[object, object].
                object_type = self.chk.named_type("builtins.object")
                rest_type = self.chk.named_generic_type(
                    "builtins.dict", [object_type, object_type]
                )
            captures[o.rest] = rest_type
        else_type = current_type
        if can_match:
            # We can't narrow the type here, as Mapping key is invariant.
            new_type = self.type_context[-1]
            if not o.keys:
                # Match cannot be refuted, so narrow the remaining type
                mapping = self.chk.named_type("typing.Mapping")
                if_type, else_type = self.chk.conditional_types_with_intersection(
                    current_type,
                    [TypeRange(mapping, is_upper_bound=False)],
                    o,
                    default=current_type,
                )
                if not isinstance(current_type, AnyType):
                    new_type = if_type
        else:
            new_type = UninhabitedType()
        return PatternType(new_type, else_type, captures)
    def get_mapping_item_type(
        self, pattern: MappingPattern, mapping_type: Type, key: Expression
    ) -> Type | None:
        """Infer the value type of a mapping-pattern key, or None if undeterminable.

        TypedDict subjects are indexed statically first; on failure we fall back
        to a plain ``__getitem__`` based lookup.
        """
        mapping_type = get_proper_type(mapping_type)
        if isinstance(mapping_type, TypedDictType):
            with self.msg.filter_errors() as local_errors:
                result: Type | None = self.chk.expr_checker.visit_typeddict_index_expr(
                    mapping_type, key
                )[0]
            has_local_errors = local_errors.has_new_errors()
            # If we can't determine the type statically fall back to treating it as a normal
            # mapping
            if has_local_errors:
                with self.msg.filter_errors() as local_errors:
                    result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
                    if local_errors.has_new_errors():
                        result = None
        else:
            # Errors are suppressed: a failed lookup just means no extra narrowing.
            with self.msg.filter_errors():
                result = self.get_simple_mapping_item_type(pattern, mapping_type, key)
        return result
def get_simple_mapping_item_type(
self, pattern: MappingPattern, mapping_type: Type, key: Expression
) -> Type:
result, _ = self.chk.expr_checker.check_method_call_by_name(
"__getitem__", mapping_type, [key], [ARG_POS], pattern
)
return result
    def visit_class_pattern(self, o: ClassPattern) -> PatternType:
        """Type check a class pattern like Cls(pos_pat, kw=pat), narrowing the subject."""
        current_type = get_proper_type(self.type_context[-1])

        #
        # Check class type
        #
        type_info = o.class_ref.node
        if isinstance(type_info, TypeAlias) and not type_info.no_args:
            self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
            return self.early_non_match()
        typ = self.chk.expr_checker.accept(o.class_ref)
        type_ranges = self.get_class_pattern_type_ranges(typ, o)
        if type_ranges is None:
            # Not a valid class pattern subject; an error was already reported.
            return self.early_non_match()
        typ = UnionType.make_union([t.item for t in type_ranges])
        new_type, rest_type = self.chk.conditional_types_with_intersection(
            current_type, type_ranges, o, default=current_type
        )
        if is_uninhabited(new_type):
            return self.early_non_match()
        # TODO: Do I need this?
        narrowed_type = narrow_declared_type(current_type, new_type)

        #
        # Convert positional to keyword patterns
        #
        keyword_pairs: list[tuple[str | None, Pattern]] = []
        match_arg_set: set[str] = set()
        captures: dict[Expression, Type] = {}
        if len(o.positionals) != 0:
            if self.should_self_match(typ):
                # E.g. `case int(x)`: the single positional captures the subject itself.
                if len(o.positionals) > 1:
                    self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                pattern_type = self.accept(o.positionals[0], narrowed_type)
                if not is_uninhabited(pattern_type.type):
                    return PatternType(
                        pattern_type.type,
                        join_types(rest_type, pattern_type.rest_type),
                        pattern_type.captures,
                    )
                captures = pattern_type.captures
            else:
                # Positional sub-patterns are mapped onto the class's __match_args__ names.
                with self.msg.filter_errors() as local_errors:
                    match_args_type = analyze_member_access(
                        "__match_args__",
                        typ,
                        o,
                        is_lvalue=False,
                        is_super=False,
                        is_operator=False,
                        original_type=typ,
                        chk=self.chk,
                    )
                has_local_errors = local_errors.has_new_errors()
                if has_local_errors:
                    self.msg.fail(
                        message_registry.MISSING_MATCH_ARGS.format(
                            typ.str_with_options(self.options)
                        ),
                        o,
                    )
                    return self.early_non_match()
                proper_match_args_type = get_proper_type(match_args_type)
                if isinstance(proper_match_args_type, TupleType):
                    match_arg_names = get_match_arg_names(proper_match_args_type)
                    if len(o.positionals) > len(match_arg_names):
                        self.msg.fail(message_registry.CLASS_PATTERN_TOO_MANY_POSITIONAL_ARGS, o)
                        return self.early_non_match()
                else:
                    # __match_args__ shape unknown; treat every positional as unnamed.
                    match_arg_names = [None] * len(o.positionals)
                for arg_name, pos in zip(match_arg_names, o.positionals):
                    keyword_pairs.append((arg_name, pos))
                    if arg_name is not None:
                        match_arg_set.add(arg_name)

        #
        # Check for duplicate patterns
        #
        keyword_arg_set = set()
        has_duplicates = False
        for key, value in zip(o.keyword_keys, o.keyword_values):
            keyword_pairs.append((key, value))
            if key in match_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_KEYWORD_MATCHES_POSITIONAL.format(key), value
                )
                has_duplicates = True
            elif key in keyword_arg_set:
                self.msg.fail(
                    message_registry.CLASS_PATTERN_DUPLICATE_KEYWORD_PATTERN.format(key), value
                )
                has_duplicates = True
            keyword_arg_set.add(key)
        if has_duplicates:
            return self.early_non_match()

        #
        # Check keyword patterns
        #
        can_match = True
        for keyword, pattern in keyword_pairs:
            key_type: Type | None = None
            with self.msg.filter_errors() as local_errors:
                if keyword is not None:
                    key_type = analyze_member_access(
                        keyword,
                        narrowed_type,
                        pattern,
                        is_lvalue=False,
                        is_super=False,
                        is_operator=False,
                        original_type=new_type,
                        chk=self.chk,
                    )
                else:
                    key_type = AnyType(TypeOfAny.from_error)
            has_local_errors = local_errors.has_new_errors()
            if has_local_errors or key_type is None:
                key_type = AnyType(TypeOfAny.from_error)
                # Matching against builtins.object never complains about unknown keywords.
                if not (type_info and type_info.fullname == "builtins.object"):
                    self.msg.fail(
                        message_registry.CLASS_PATTERN_UNKNOWN_KEYWORD.format(
                            typ.str_with_options(self.options), keyword
                        ),
                        pattern,
                    )
                elif keyword is not None:
                    new_type = self.chk.add_any_attribute_to_type(new_type, keyword)
            inner_type, inner_rest_type, inner_captures = self.accept(pattern, key_type)
            if is_uninhabited(inner_type):
                can_match = False
            else:
                self.update_type_map(captures, inner_captures)
                if not is_uninhabited(inner_rest_type):
                    # Some values of the attribute don't match; the subject can't be excluded.
                    rest_type = current_type
        if not can_match:
            new_type = UninhabitedType()
        return PatternType(new_type, rest_type, captures)
    def get_class_pattern_type_ranges(self, typ: Type, o: ClassPattern) -> list[TypeRange] | None:
        """Convert the class-reference expression type into matchable TypeRanges.

        Returns None (after reporting an error) when the expression cannot be
        used as a class pattern subject.
        """
        p_typ = get_proper_type(typ)
        if isinstance(p_typ, UnionType):
            # Collect ranges from every union item that yields some.
            type_ranges = []
            for item in p_typ.items:
                type_range = self.get_class_pattern_type_ranges(item, o)
                if type_range is not None:
                    type_ranges.extend(type_range)
            if not type_ranges:
                return None
            return type_ranges
        if isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
            # A plain class object: match its instances (type variables become Any).
            typ = fill_typevars_with_any(p_typ.type_object())
            return [TypeRange(typ, is_upper_bound=False)]
        if (
            isinstance(o.class_ref.node, Var)
            and o.class_ref.node.type is not None
            and o.class_ref.node.fullname == "typing.Callable"
        ):
            # Create a `Callable[..., Any]`
            fallback = self.chk.named_type("builtins.function")
            any_type = AnyType(TypeOfAny.unannotated)
            typ = callable_with_ellipsis(any_type, ret_type=any_type, fallback=fallback)
            return [TypeRange(typ, is_upper_bound=False)]
        if isinstance(p_typ, TypeType):
            # type[X] matches any subclass of X, hence the upper-bound range.
            typ = p_typ.item
            return [TypeRange(p_typ.item, is_upper_bound=True)]
        if isinstance(p_typ, AnyType):
            return [TypeRange(p_typ, is_upper_bound=False)]
        self.msg.fail(
            message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
                typ.str_with_options(self.options)
            ),
            o,
        )
        return None
def should_self_match(self, typ: Type) -> bool:
typ = get_proper_type(typ)
if isinstance(typ, TupleType):
typ = typ.partial_fallback
if isinstance(typ, AnyType):
return False
if isinstance(typ, Instance) and typ.type.get("__match_args__") is not None:
# Named tuples and other subtypes of builtins that define __match_args__
# should not self match.
return False
for other in self.self_match_types:
if is_subtype(typ, other):
return True
return False
def can_match_sequence(self, typ: ProperType) -> bool:
if isinstance(typ, AnyType):
return True
if isinstance(typ, UnionType):
return any(self.can_match_sequence(get_proper_type(item)) for item in typ.items)
for other in self.non_sequence_match_types:
# We have to ignore promotions, as memoryview should match, but bytes,
# which it can be promoted to, shouldn't
if is_subtype(typ, other, ignore_promotions=True):
return False
sequence = self.chk.named_type("typing.Sequence")
# If the static type is more general than sequence the actual type could still match
return is_subtype(typ, sequence) or is_subtype(sequence, typ)
def generate_types_from_names(self, type_names: list[str]) -> list[Type]:
types: list[Type] = []
for name in type_names:
try:
types.append(self.chk.named_type(name))
except KeyError as e:
# Some built in types are not defined in all test cases
if not name.startswith("builtins."):
raise e
return types
def update_type_map(
self, original_type_map: dict[Expression, Type], extra_type_map: dict[Expression, Type]
) -> None:
# Calculating this would not be needed if TypeMap directly used literal hashes instead of
# expressions, as suggested in the TODO above it's definition
already_captured = {literal_hash(expr) for expr in original_type_map}
for expr, typ in extra_type_map.items():
if literal_hash(expr) in already_captured:
node = get_var(expr)
self.msg.fail(
message_registry.MULTIPLE_ASSIGNMENTS_IN_PATTERN.format(node.name), expr
)
else:
original_type_map[expr] = typ
    def construct_sequence_child(self, outer_type: Type, inner_type: Type) -> Type:
        """
        If outer_type is a child class of typing.Sequence returns a new instance of
        outer_type, that is a Sequence of inner_type. If outer_type is not a child class of
        typing.Sequence just returns a Sequence of inner_type

        For example:
        construct_sequence_child(List[int], str) = List[str]

        TODO: this doesn't make sense. For example if one has class S(Sequence[int], Generic[T])
        or class T(Sequence[Tuple[T, T]]), there is no way any of those can map to Sequence[str].
        """
        proper_type = get_proper_type(outer_type)
        if isinstance(proper_type, TypeVarType):
            # Recurse into the bound so the result stays a type variable.
            new_bound = self.construct_sequence_child(proper_type.upper_bound, inner_type)
            return proper_type.copy_modified(upper_bound=new_bound)
        if isinstance(proper_type, AnyType):
            return outer_type
        if isinstance(proper_type, UnionType):
            # Rebuild only the union items that can actually match a sequence pattern.
            types = [
                self.construct_sequence_child(item, inner_type)
                for item in proper_type.items
                if self.can_match_sequence(get_proper_type(item))
            ]
            return make_simplified_union(types)
        sequence = self.chk.named_generic_type("typing.Sequence", [inner_type])
        if is_subtype(outer_type, self.chk.named_type("typing.Sequence")):
            if isinstance(proper_type, TupleType):
                proper_type = tuple_fallback(proper_type)
            assert isinstance(proper_type, Instance)
            # Map the new item type through the subclass's own type parameters.
            empty_type = fill_typevars(proper_type.type)
            partial_type = expand_type_by_instance(empty_type, sequence)
            return expand_type_by_instance(partial_type, proper_type)
        else:
            return sequence
def early_non_match(self) -> PatternType:
return PatternType(UninhabitedType(), self.type_context[-1], {})
def get_match_arg_names(typ: TupleType) -> list[str | None]:
    """Extract attribute names from a __match_args__ tuple type (None where unknown)."""
    names: list[str | None] = []
    for item in typ.items:
        literals = try_getting_str_literals_from_type(item)
        # Only a single, unambiguous string literal yields a usable name.
        names.append(literals[0] if literals is not None and len(literals) == 1 else None)
    return names
def get_var(expr: Expression) -> Var:
    """Return the Var node behind a capture name expression.

    Warning: this is only valid for expressions captured by a match statement.
    Don't call it from anywhere else.
    """
    assert isinstance(expr, NameExpr), expr
    var_node = expr.node
    assert isinstance(var_node, Var), var_node
    return var_node
def get_type_range(typ: Type) -> TypeRange:
    """Build a non-upper-bound TypeRange, preferring a known bool literal type."""
    proper = get_proper_type(typ)
    if (
        isinstance(proper, Instance)
        and proper.last_known_value is not None
        and isinstance(proper.last_known_value.value, bool)
    ):
        # Keep the literal so e.g. `case True:` narrows more precisely.
        proper = proper.last_known_value
    return TypeRange(proper, is_upper_bound=False)
def is_uninhabited(typ: Type) -> bool:
    """Whether the proper form of the type is the bottom (uninhabited) type."""
    proper = get_proper_type(typ)
    return isinstance(proper, UninhabitedType)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,734 @@
from __future__ import annotations
import argparse
import configparser
import glob as fileglob
import os
import re
import sys
from io import StringIO
if sys.version_info >= (3, 11):
import tomllib
else:
import tomli as tomllib
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, Final, TextIO, TypeAlias
from typing_extensions import Never
from mypy import defaults
from mypy.options import PER_MODULE_OPTIONS, Options
_CONFIG_VALUE_TYPES: TypeAlias = (
str | bool | int | float | dict[str, str] | list[str] | tuple[int, int]
)
_INI_PARSER_CALLABLE: TypeAlias = Callable[[Any], _CONFIG_VALUE_TYPES]
class VersionTypeError(argparse.ArgumentTypeError):
    """Provide a fallback value if the Python version is unsupported."""

    def __init__(self, *args: Any, fallback: tuple[int, int]) -> None:
        # Remember the version to fall back to when the requested one is unsupported.
        self.fallback = fallback
        super().__init__(*args)
def parse_version(v: str | float) -> tuple[int, int]:
    """Parse an 'x.y' Python version string into a (major, minor) tuple.

    Raises argparse.ArgumentTypeError for malformed versions, and
    VersionTypeError (carrying a fallback) for too-old Python 3 minors.
    """
    match = re.match(r"\A(\d)\.(\d+)\Z", str(v))
    if match is None:
        raise argparse.ArgumentTypeError(f"Invalid python version '{v}' (expected format: 'x.y')")
    major, minor = int(match.group(1)), int(match.group(2))
    if major == 3:
        if minor < defaults.PYTHON3_VERSION_MIN[1]:
            msg = "Python 3.{} is not supported (must be {}.{} or higher)".format(
                minor, *defaults.PYTHON3_VERSION_MIN
            )
            if isinstance(v, float):
                # TOML floats like 3.1 silently drop trailing zeros (3.10 -> 3.1).
                msg += ". You may need to put quotes around your Python version"
            raise VersionTypeError(msg, fallback=defaults.PYTHON3_VERSION_MIN)
    elif not (major == 2 and minor == 7):
        # 2.7 is let through here; the "unsupported" error for it is raised elsewhere.
        raise argparse.ArgumentTypeError(
            f"Python major version '{major}' out of range (must be 3)"
        )
    return major, minor
def try_split(v: str | Sequence[str] | object, split_regex: str = ",") -> list[str]:
    """Split and trim a str or sequence (eg: list) of str into a list of str.

    If an element of the input is not str, a type error will be raised.
    """

    def complain(x: object, additional_info: str = "") -> Never:
        raise argparse.ArgumentTypeError(
            f"Expected a list or a stringified version thereof, but got: '{x}', of type {type(x).__name__}.{additional_info}"
        )

    if isinstance(v, str):
        pieces = [piece.strip() for piece in re.split(split_regex, v)]
        # Tolerate a trailing separator by dropping the final empty piece.
        if pieces and pieces[-1] == "":
            del pieces[-1]
        return pieces
    if not isinstance(v, Sequence):
        complain(v)
    stripped: list[str] = []
    for piece in v:
        if not isinstance(piece, str):
            complain(piece, additional_info=" (As an element of the list.)")
        stripped.append(piece.strip())
    return stripped
def validate_package_allow_list(allow_list: list[str]) -> list[str]:
    """Check that each allow-list entry is a plain dotted package prefix.

    Returns the list unchanged; raises argparse.ArgumentTypeError for bad entries.
    """
    for entry in allow_list:
        msg = f"Invalid allow list entry: {entry}"
        if "*" in entry:
            # Entries are treated as prefixes already, so globs are redundant.
            raise argparse.ArgumentTypeError(
                f"{msg} (entries are already prefixes so must not contain *)"
            )
        if "\\" in entry or "/" in entry:
            raise argparse.ArgumentTypeError(
                f"{msg} (entries must be packages like foo.bar not directories or files)"
            )
    return allow_list
def expand_path(path: str) -> str:
    """Expand the user home directory and any environment variables contained within
    the provided path.
    """
    home_expanded = os.path.expanduser(path)
    return os.path.expandvars(home_expanded)
def str_or_array_as_list(v: str | Sequence[str]) -> list[str]:
    """Normalize a string or sequence of strings into a list of non-empty, stripped items."""
    if isinstance(v, str):
        stripped = v.strip()
        return [stripped] if stripped else []
    return [item.strip() for item in v if item.strip()]
def split_and_match_files_list(paths: Sequence[str]) -> list[str]:
    """Expand a list of files/directories, supporting recursive globs.

    Where a path/glob matches no file, we still include the raw path in the
    resulting list. Returns a list of file paths.
    """
    expanded: list[str] = []
    for raw in paths:
        candidate = expand_path(raw.strip())
        matches = fileglob.glob(candidate, recursive=True)
        # Keep the unmatched path so mypy can later report it as missing.
        expanded.extend(matches if matches else [candidate])
    return expanded
def split_and_match_files(paths: str) -> list[str]:
    """Take a string representing a list of files/directories (with support for globbing
    through the glob library).

    Where a path/glob matches no file, we still include the raw path in the resulting list.

    Returns a list of file paths
    """
    # Comma-split first, then expand each piece via the shared glob helper.
    return split_and_match_files_list(split_commas(paths))
def check_follow_imports(choice: str) -> str:
    """Validate the follow_imports option, returning the choice unchanged if legal."""
    valid = ["normal", "silent", "skip", "error"]
    if choice in valid:
        return choice
    raise argparse.ArgumentTypeError(
        "invalid choice '{}' (choose from {})".format(choice, ", ".join(f"'{x}'" for x in valid))
    )
def check_junit_format(choice: str) -> str:
    """Validate the junit_format option, returning the choice unchanged if legal."""
    valid = ["global", "per_file"]
    if choice in valid:
        return choice
    raise argparse.ArgumentTypeError(
        "invalid choice '{}' (choose from {})".format(choice, ", ".join(f"'{x}'" for x in valid))
    )
def split_commas(value: str) -> list[str]:
    """Split on commas, tolerating a single trailing comma.

    Only the last empty piece (from a trailing comma or empty input) is dropped;
    interior empty pieces are preserved.
    """
    pieces = value.split(",")
    if pieces and pieces[-1] == "":
        del pieces[-1]
    return pieces
# For most options, the type of the default value set in options.py is
# sufficient, and we don't have to do anything here. This table
# exists to specify types for values initialized to None or container
# types.
ini_config_types: Final[dict[str, _INI_PARSER_CALLABLE]] = {
    "python_version": parse_version,
    "custom_typing_module": str,
    "custom_typeshed_dir": expand_path,
    # INI values are strings; mypy_path entries may be separated by commas or colons.
    "mypy_path": lambda s: [expand_path(p.strip()) for p in re.split("[,:]", s)],
    "files": split_and_match_files,
    "quickstart_file": expand_path,
    "junit_xml": expand_path,
    "junit_format": check_junit_format,
    "follow_imports": check_follow_imports,
    "no_site_packages": bool,
    "plugins": lambda s: [p.strip() for p in split_commas(s)],
    "always_true": lambda s: [p.strip() for p in split_commas(s)],
    "always_false": lambda s: [p.strip() for p in split_commas(s)],
    "untyped_calls_exclude": lambda s: validate_package_allow_list(
        [p.strip() for p in split_commas(s)]
    ),
    "enable_incomplete_feature": lambda s: [p.strip() for p in split_commas(s)],
    "disable_error_code": lambda s: [p.strip() for p in split_commas(s)],
    "enable_error_code": lambda s: [p.strip() for p in split_commas(s)],
    "package_root": lambda s: [p.strip() for p in split_commas(s)],
    "cache_dir": expand_path,
    "python_executable": expand_path,
    "strict": bool,
    "exclude": lambda s: [s.strip()],
    "packages": try_split,
    "modules": try_split,
}

# Reuse the ini_config_types and overwrite the diff
# (TOML values can already be native lists, so these converters accept str or sequence).
toml_config_types: Final[dict[str, _INI_PARSER_CALLABLE]] = ini_config_types.copy()
toml_config_types.update(
    {
        "python_version": parse_version,
        "mypy_path": lambda s: [expand_path(p) for p in try_split(s, "[,:]")],
        "files": lambda s: split_and_match_files_list(try_split(s)),
        "junit_format": lambda s: check_junit_format(str(s)),
        "follow_imports": lambda s: check_follow_imports(str(s)),
        "plugins": try_split,
        "always_true": try_split,
        "always_false": try_split,
        "untyped_calls_exclude": lambda s: validate_package_allow_list(try_split(s)),
        "enable_incomplete_feature": try_split,
        "disable_error_code": lambda s: try_split(s),
        "enable_error_code": lambda s: try_split(s),
        "package_root": try_split,
        "exclude": str_or_array_as_list,
        "packages": try_split,
        "modules": try_split,
    }
)
def _parse_individual_file(
    config_file: str, stderr: TextIO | None = None
) -> tuple[MutableMapping[str, Any], dict[str, _INI_PARSER_CALLABLE], str] | None:
    """Parse one candidate config file.

    Returns (parsed sections, value-converter table, file path), or None when
    the file doesn't exist, fails to parse, or contains no mypy configuration.
    """
    if not os.path.exists(config_file):
        return None

    parser: MutableMapping[str, Any]
    try:
        if is_toml(config_file):
            with open(config_file, "rb") as f:
                toml_data = tomllib.load(f)
            # Filter down to just mypy relevant toml keys
            toml_data = toml_data.get("tool", {})
            if "mypy" not in toml_data:
                return None
            toml_data = {"mypy": toml_data["mypy"]}
            parser = destructure_overrides(toml_data)
            config_types = toml_config_types
        else:
            parser = configparser.RawConfigParser()
            parser.read(config_file)
            config_types = ini_config_types
    except (tomllib.TOMLDecodeError, configparser.Error, ConfigTOMLValueError) as err:
        # Parse errors are reported but not fatal; the caller just tries the next file.
        print(f"{config_file}: {err}", file=stderr)
        return None

    if os.path.basename(config_file) in defaults.SHARED_CONFIG_NAMES and "mypy" not in parser:
        # Shared files like setup.cfg are ignored unless they opt in with a [mypy] section.
        return None

    return parser, config_types, config_file
def _find_config_file(
    stderr: TextIO | None = None,
) -> tuple[MutableMapping[str, Any], dict[str, _INI_PARSER_CALLABLE], str] | None:
    """Search for a usable config file.

    Walks upward from the current directory (stopping at a .git/.hg root or the
    filesystem root), then falls back to per-user config files.
    """
    current_dir = os.path.abspath(os.getcwd())

    while True:
        for name in defaults.CONFIG_NAMES + defaults.SHARED_CONFIG_NAMES:
            config_file = os.path.relpath(os.path.join(current_dir, name))
            ret = _parse_individual_file(config_file, stderr)
            if ret is None:
                continue
            return ret

        if any(
            os.path.exists(os.path.join(current_dir, cvs_root)) for cvs_root in (".git", ".hg")
        ):
            # A repository root bounds the upward search.
            break
        parent_dir = os.path.dirname(current_dir)
        if parent_dir == current_dir:
            # Reached the filesystem root.
            break
        current_dir = parent_dir

    for config_file in defaults.USER_CONFIG_FILES:
        ret = _parse_individual_file(config_file, stderr)
        if ret is None:
            continue
        return ret

    return None
def parse_config_file(
    options: Options,
    set_strict_flags: Callable[[], None],
    filename: str | None,
    stdout: TextIO | None = None,
    stderr: TextIO | None = None,
) -> None:
    """Parse a config file into an Options object.

    Errors are written to stderr but are not fatal.

    If filename is None, fall back to default config files.
    """
    stdout = stdout or sys.stdout
    stderr = stderr or sys.stderr

    ret = (
        _parse_individual_file(filename, stderr)
        if filename is not None
        else _find_config_file(stderr)
    )
    if ret is None:
        return
    parser, config_types, file_read = ret

    options.config_file = file_read
    # Expose the config file's directory for substitution in config values.
    os.environ["MYPY_CONFIG_FILE_DIR"] = os.path.dirname(os.path.abspath(file_read))

    if "mypy" not in parser:
        if filename or os.path.basename(file_read) not in defaults.SHARED_CONFIG_NAMES:
            print(f"{file_read}: No [mypy] section in config file", file=stderr)
    else:
        # Apply the global [mypy] section directly onto the Options object.
        section = parser["mypy"]
        prefix = f"{file_read}: [mypy]: "
        updates, report_dirs = parse_section(
            prefix, options, set_strict_flags, section, config_types, stderr
        )
        for k, v in updates.items():
            setattr(options, k, v)
        options.report_dirs.update(report_dirs)

    # Per-module [mypy-<glob>] sections are stored for later pattern matching.
    for name, section in parser.items():
        if name.startswith("mypy-"):
            prefix = get_prefix(file_read, name)
            updates, report_dirs = parse_section(
                prefix, options, set_strict_flags, section, config_types, stderr
            )
            if report_dirs:
                print(
                    prefix,
                    "Per-module sections should not specify reports ({})".format(
                        ", ".join(s + "_report" for s in sorted(report_dirs))
                    ),
                    file=stderr,
                )
            if set(updates) - PER_MODULE_OPTIONS:
                print(
                    prefix,
                    "Per-module sections should only specify per-module flags ({})".format(
                        ", ".join(sorted(set(updates) - PER_MODULE_OPTIONS))
                    ),
                    file=stderr,
                )
                # Drop the non-per-module flags but keep the rest of the section.
                updates = {k: v for k, v in updates.items() if k in PER_MODULE_OPTIONS}
            globs = name[5:]
            for glob in globs.split(","):
                # For backwards compatibility, replace (back)slashes with dots.
                glob = glob.replace(os.sep, ".")
                if os.altsep:
                    glob = glob.replace(os.altsep, ".")

                if any(c in glob for c in "?[]!") or any(
                    "*" in x and x != "*" for x in glob.split(".")
                ):
                    print(
                        prefix,
                        "Patterns must be fully-qualified module names, optionally "
                        "with '*' in some components (e.g spam.*.eggs.*)",
                        file=stderr,
                    )
                else:
                    options.per_module_options[glob] = updates
def get_prefix(file_read: str, name: str) -> str:
    """Build the error-message prefix for a per-module config section."""
    if is_toml(file_read):
        # TOML overrides were flattened to "mypy-<module>"; show the original module key.
        module_name_str = 'module = "%s"' % name.partition("-")[2]
    else:
        module_name_str = name
    return f"{file_read}: [{module_name_str}]:"
def is_toml(filename: str) -> bool:
    """Whether the config filename has a .toml extension (case-insensitive)."""
    return filename.lower().endswith(".toml")
def destructure_overrides(toml_data: dict[str, Any]) -> dict[str, Any]:
    """Take the new [[tool.mypy.overrides]] section array in the pyproject.toml file,
    and convert it back to a flatter structure that the existing config_parser can handle.

    E.g. the following pyproject.toml file:

        [[tool.mypy.overrides]]
        module = [
            "a.b",
            "b.*"
        ]
        disallow_untyped_defs = true

        [[tool.mypy.overrides]]
        module = 'c'
        disallow_untyped_defs = false

    Would map to the following config dict that it would have gotten from parsing an equivalent
    ini file:

        {
            "mypy-a.b": {
                disallow_untyped_defs = true,
            },
            "mypy-b.*": {
                disallow_untyped_defs = true,
            },
            "mypy-c": {
                disallow_untyped_defs: false,
            },
        }
    """
    if "overrides" not in toml_data["mypy"]:
        return toml_data
    if not isinstance(toml_data["mypy"]["overrides"], list):
        raise ConfigTOMLValueError(
            "tool.mypy.overrides sections must be an array. Please make "
            "sure you are using double brackets like so: [[tool.mypy.overrides]]"
        )

    # NOTE(review): this is a shallow copy, so the final `del` below also mutates
    # the caller's toml_data["mypy"] — confirm callers always re-load per file.
    result = toml_data.copy()
    for override in result["mypy"]["overrides"]:
        if "module" not in override:
            raise ConfigTOMLValueError(
                "toml config file contains a [[tool.mypy.overrides]] "
                "section, but no module to override was specified."
            )

        if isinstance(override["module"], str):
            modules = [override["module"]]
        elif isinstance(override["module"], list):
            modules = override["module"]
        else:
            raise ConfigTOMLValueError(
                "toml config file contains a [[tool.mypy.overrides]] "
                "section with a module value that is not a string or a list of "
                "strings"
            )

        for module in modules:
            module_overrides = override.copy()
            del module_overrides["module"]
            old_config_name = f"mypy-{module}"
            if old_config_name not in result:
                result[old_config_name] = module_overrides
            else:
                # A module may appear in several overrides; merge them, rejecting conflicts.
                for new_key, new_value in module_overrides.items():
                    if (
                        new_key in result[old_config_name]
                        and result[old_config_name][new_key] != new_value
                    ):
                        raise ConfigTOMLValueError(
                            "toml config file contains "
                            "[[tool.mypy.overrides]] sections with conflicting "
                            f"values. Module '{module}' has two different values for '{new_key}'"
                        )
                    result[old_config_name][new_key] = new_value

    del result["mypy"]["overrides"]
    return result
def parse_section(
    prefix: str,
    template: Options,
    set_strict_flags: Callable[[], None],
    section: Mapping[str, Any],
    config_types: dict[str, Any],
    stderr: TextIO = sys.stderr,
) -> tuple[dict[str, object], dict[str, str]]:
    """Parse one section of a config file.

    Returns a dict of option values encountered, and a dict of report directories.

    Unrecognized or uninvertible keys produce a message on stderr and are skipped.
    """
    results: dict[str, object] = {}
    report_dirs: dict[str, str] = {}

    # Because these fields exist on Options, without proactive checking, we would accept them
    # and crash later
    invalid_options = {
        "enabled_error_codes": "enable_error_code",
        "disabled_error_codes": "disable_error_code",
    }

    for key in section:
        invert = False
        # Here we use `key` for original config section key, and `options_key` for
        # the corresponding Options attribute.
        options_key = key

        # Match aliasing for command line flag.
        if key.endswith("allow_redefinition"):
            options_key += "_old"

        if key in config_types:
            ct = config_types[key]
        elif key in invalid_options:
            print(
                f"{prefix}Unrecognized option: {key} = {section[key]}"
                f" (did you mean {invalid_options[key]}?)",
                file=stderr,
            )
            continue
        else:
            # Fall back to inferring the type from the Options attribute's default.
            dv = getattr(template, options_key, None)
            if dv is None:
                if key.endswith("_report"):
                    report_type = key[:-7].replace("_", "-")
                    if report_type in defaults.REPORTER_NAMES:
                        report_dirs[report_type] = str(section[key])
                    else:
                        print(f"{prefix}Unrecognized report type: {key}", file=stderr)
                    continue
                if key.startswith("x_"):
                    pass  # Don't complain about `x_blah` flags
                elif key.startswith("no_") and hasattr(template, options_key[3:]):
                    options_key = options_key[3:]
                    invert = True
                elif key.startswith("allow") and hasattr(template, "dis" + options_key):
                    options_key = "dis" + options_key
                    invert = True
                elif key.startswith("disallow") and hasattr(template, options_key[3:]):
                    options_key = options_key[3:]
                    invert = True
                elif key.startswith("show_") and hasattr(template, "hide_" + options_key[5:]):
                    options_key = "hide_" + options_key[5:]
                    invert = True
                elif key == "strict":
                    pass  # Special handling below
                else:
                    print(f"{prefix}Unrecognized option: {key} = {section[key]}", file=stderr)
                if invert:
                    dv = getattr(template, options_key, None)
                else:
                    continue
            ct = type(dv) if dv is not None else None
        v: Any = None
        try:
            if ct is bool:
                if isinstance(section, dict):
                    # TOML sections are plain dicts and may hold native booleans.
                    v = convert_to_boolean(section.get(key))
                else:
                    v = section.getboolean(key)  # type: ignore[attr-defined]  # Until better stub
                if invert:
                    v = not v
            elif callable(ct):
                if invert:
                    print(f"{prefix}Can not invert non-boolean key {options_key}", file=stderr)
                    continue
                try:
                    v = ct(section.get(key))
                except VersionTypeError as err_version:
                    # Unsupported python_version: report, then continue with the fallback.
                    print(f"{prefix}{key}: {err_version}", file=stderr)
                    v = err_version.fallback
                except argparse.ArgumentTypeError as err:
                    print(f"{prefix}{key}: {err}", file=stderr)
                    continue
            else:
                print(f"{prefix}Don't know what type {key} should have", file=stderr)
                continue
        except ValueError as err:
            print(f"{prefix}{key}: {err}", file=stderr)
            continue
        if key == "strict":
            if v:
                set_strict_flags()
            continue
        results[options_key] = v

    # These two flags act as per-module overrides, so store the empty defaults.
    if "disable_error_code" not in results:
        results["disable_error_code"] = []
    if "enable_error_code" not in results:
        results["enable_error_code"] = []

    return results, report_dirs
def convert_to_boolean(value: Any | None) -> bool:
    """Return a boolean value translating from other types if necessary."""
    if isinstance(value, bool):
        return value
    text = value if isinstance(value, str) else str(value)
    states = configparser.RawConfigParser.BOOLEAN_STATES
    lowered = text.lower()
    if lowered not in states:
        raise ValueError(f"Not a boolean: {text}")
    return states[lowered]
def split_directive(s: str) -> tuple[list[str], list[str]]:
    """Split s on commas, except during quoted sections.

    Returns the parts and a list of error messages.
    """
    parts: list[str] = []
    buf: list[str] = []
    errors: list[str] = []
    pos = 0
    length = len(s)
    while pos < length:
        ch = s[pos]
        if ch == ",":
            # A comma outside quotes terminates the current part.
            parts.append("".join(buf).strip())
            buf = []
        elif ch == '"':
            # Copy quoted content verbatim so embedded commas don't split.
            pos += 1
            while pos < length and s[pos] != '"':
                buf.append(s[pos])
                pos += 1
            if pos == length:
                errors.append("Unterminated quote in configuration comment")
                buf.clear()
        else:
            buf.append(ch)
        pos += 1
    if buf:
        parts.append("".join(buf).strip())

    return parts, errors
def mypy_comments_to_config_map(line: str, template: Options) -> tuple[dict[str, str], list[str]]:
    """Rewrite the mypy comment syntax into ini file syntax."""
    options: dict[str, str] = {}
    entries, errors = split_directive(line)
    for entry in entries:
        if "=" in entry:
            raw_name, raw_value = entry.split("=", 1)
            name, value = raw_name.strip(), raw_value.strip()
        else:
            # A bare flag with no '=' is treated as enabling the option.
            name, value = entry, "True"
        # Accept dashed spellings by normalizing to underscores.
        options[name.replace("-", "_")] = value
    return options, errors
def parse_mypy_comments(
    args: list[tuple[int, str]], template: Options
) -> tuple[dict[str, object], list[tuple[int, str]]]:
    """Parse a collection of inline mypy: configuration comments.

    Returns a dictionary of options to be applied and a list of error messages
    generated, each paired with the comment's line number.
    """
    errors: list[tuple[int, str]] = []
    sections: dict[str, object] = {"enable_error_code": [], "disable_error_code": []}

    for lineno, line in args:
        # In order to easily match the behavior for bools, we abuse configparser.
        # Oddly, the only way to get the SectionProxy object with the getboolean
        # method is to create a config parser.
        parser = configparser.RawConfigParser()
        options, parse_errors = mypy_comments_to_config_map(line, template)
        if "python_version" in options:
            errors.append((lineno, "python_version not supported in inline configuration"))
            del options["python_version"]
        parser["dummy"] = options
        errors.extend((lineno, x) for x in parse_errors)

        stderr = StringIO()
        strict_found = False

        def set_strict_flags() -> None:
            # Closure flag so we can reject "strict" after parse_section runs.
            nonlocal strict_found
            strict_found = True

        new_sections, reports = parse_section(
            "", template, set_strict_flags, parser["dummy"], ini_config_types, stderr=stderr
        )
        # parse_section reports problems on stderr; convert them to line-tagged errors.
        errors.extend((lineno, x) for x in stderr.getvalue().strip().split("\n") if x)
        if reports:
            errors.append((lineno, "Reports not supported in inline configuration"))
        if strict_found:
            errors.append(
                (
                    lineno,
                    'Setting "strict" not supported in inline configuration: specify it in '
                    "a configuration file instead, or set individual inline flags "
                    '(see "mypy -h" for the list of flags enabled in strict mode)',
                )
            )

        # Because this is currently special-cased
        # (the new_sections for an inline config *always* includes 'disable_error_code' and
        # 'enable_error_code' fields, usually empty, which overwrite the old ones),
        # we have to manipulate them specially.
        # This could use a refactor, but so could the whole subsystem.
        if (
            "enable_error_code" in new_sections
            and isinstance(neec := new_sections["enable_error_code"], list)
            and isinstance(eec := sections.get("enable_error_code", []), list)
        ):
            # Accumulate codes across comment lines instead of overwriting.
            new_sections["enable_error_code"] = sorted(set(neec + eec))
        if (
            "disable_error_code" in new_sections
            and isinstance(ndec := new_sections["disable_error_code"], list)
            and isinstance(dec := sections.get("disable_error_code", []), list)
        ):
            new_sections["disable_error_code"] = sorted(set(ndec + dec))

        sections.update(new_sections)

    return sections, errors
def get_config_module_names(filename: str | None, modules: list[str]) -> str:
    """Render *modules* as per-module config sections for the given config file.

    Produces ini-style "[mypy-mod]" section headers, or a TOML
    "module = [...]" list when *filename* is a TOML config.  Returns an
    empty string when there is no config file or no modules.
    """
    if not filename or not modules:
        return ""
    if is_toml(filename):
        quoted = "', '".join(sorted(modules))
        return f"module = ['{quoted}']"
    return ", ".join(f"[mypy-{mod}]" for mod in modules)
class ConfigTOMLValueError(ValueError):
    """ValueError subclass used to flag bad values in TOML configuration."""

    pass

View file

@ -0,0 +1,187 @@
"""Constant folding of expressions.
For example, 3 + 5 can be constant folded into 8.
"""
from __future__ import annotations
from typing import Final
from mypy.nodes import (
ComplexExpr,
Expression,
FloatExpr,
IntExpr,
NameExpr,
OpExpr,
StrExpr,
UnaryExpr,
Var,
)
# All possible result types of constant folding
ConstantValue = int | bool | float | complex | str
# Runtime isinstance() counterpart of ConstantValue (bool is technically
# redundant as a subclass of int, but is listed for clarity).
CONST_TYPES: Final = (int, bool, float, complex, str)
def constant_fold_expr(expr: Expression, cur_mod_id: str) -> ConstantValue | None:
    """Return the constant value of an expression for supported operations.

    Among other things, support int arithmetic and string
    concatenation. For example, the expression 3 + 5 has the constant
    value 8.

    Also bind simple references to final constants defined in the
    current module (cur_mod_id). Binding to references is best effort
    -- we don't bind references to other modules. Mypyc trusts these
    to be correct in compiled modules, so that it can replace a
    constant expression (or a reference to one) with the statically
    computed value. We don't want to infer constant values based on
    stubs, in particular, as these might not match the implementation
    (due to version skew, for example).

    Return None if unsuccessful.
    """
    # Literal expressions carry their value directly.
    if isinstance(expr, IntExpr):
        return expr.value
    if isinstance(expr, StrExpr):
        return expr.value
    if isinstance(expr, FloatExpr):
        return expr.value
    if isinstance(expr, ComplexExpr):
        return expr.value
    elif isinstance(expr, NameExpr):
        # True/False are NameExprs rather than literals.
        if expr.name == "True":
            return True
        elif expr.name == "False":
            return False
        node = expr.node
        # Only bind Final variables declared in the current module.
        if (
            isinstance(node, Var)
            and node.is_final
            and node.fullname.rsplit(".", 1)[0] == cur_mod_id
        ):
            value = node.final_value
            if isinstance(value, (CONST_TYPES)):
                return value
    elif isinstance(expr, OpExpr):
        # Fold a binary op only when both operands fold to constants.
        left = constant_fold_expr(expr.left, cur_mod_id)
        right = constant_fold_expr(expr.right, cur_mod_id)
        if left is not None and right is not None:
            return constant_fold_binary_op(expr.op, left, right)
    elif isinstance(expr, UnaryExpr):
        value = constant_fold_expr(expr.expr, cur_mod_id)
        if value is not None:
            return constant_fold_unary_op(expr.op, value)
    return None
def constant_fold_binary_op(
    op: str, left: ConstantValue, right: ConstantValue
) -> ConstantValue | None:
    """Fold a binary operation over two constant operands, if supported.

    Dispatches on the operand types: int arithmetic, float (and mixed
    int/float) arithmetic, string concatenation/repetition, and simple
    complex-number construction.  Return None when the combination of
    operator and operand types is not supported.
    """
    # bool is a subclass of int, so boolean operands also take this path.
    if isinstance(left, int) and isinstance(right, int):
        return constant_fold_binary_int_op(op, left, right)

    # Float and mixed int/float arithmetic.
    if isinstance(left, float) and isinstance(right, (int, float)):
        return constant_fold_binary_float_op(op, left, right)
    if isinstance(right, float) and isinstance(left, (int, float)):
        return constant_fold_binary_float_op(op, left, right)

    # String concatenation and multiplication.
    if isinstance(left, str) and isinstance(right, str):
        if op == "+":
            return left + right
    elif isinstance(left, str) and isinstance(right, int):
        if op == "*":
            return left * right
    elif isinstance(left, int) and isinstance(right, str):
        if op == "*":
            return left * right

    # Complex construction: a real number combined with a complex literal.
    if op in ("+", "-"):
        if isinstance(left, (int, float)) and isinstance(right, complex):
            return left + right if op == "+" else left - right
        if isinstance(left, complex) and isinstance(right, (int, float)):
            return left + right if op == "+" else left - right
    return None
def constant_fold_binary_int_op(op: str, left: int, right: int) -> int | float | None:
    """Fold a binary arithmetic/bitwise operation over two int constants.

    Division-like operators fold only for a nonzero divisor; shifts and
    power only for a nonnegative right operand.  Return None for an
    unsupported operator or operand values that would raise at runtime.
    """
    if op == "+":
        return left + right
    elif op == "-":
        return left - right
    elif op == "*":
        return left * right
    elif op in ("/", "//", "%"):
        if right == 0:
            # Folding would raise ZeroDivisionError; leave it to runtime.
            return None
        if op == "/":
            return left / right
        if op == "//":
            return left // right
        return left % right
    elif op == "&":
        return left & right
    elif op == "|":
        return left | right
    elif op == "^":
        return left ^ right
    elif op in ("<<", ">>"):
        if right < 0:
            # Negative shift counts raise ValueError at runtime.
            return None
        return left << right if op == "<<" else left >> right
    elif op == "**":
        if right >= 0:
            result = left**right
            # int ** nonnegative-int always yields an int.
            assert isinstance(result, int)
            return result
    return None
def constant_fold_binary_float_op(op: str, left: int | float, right: int | float) -> float | None:
    """Fold a float (or mixed int/float) binary arithmetic operation.

    At least one operand must be a float; pure int/int pairs are folded
    by constant_fold_binary_int_op instead.  Return None for unsupported
    operators or operand values that would raise at runtime.
    """
    assert not (isinstance(left, int) and isinstance(right, int)), (op, left, right)
    if op == "+":
        return left + right
    if op == "-":
        return left - right
    if op == "*":
        return left * right
    if op in ("/", "//", "%"):
        if right == 0:
            # Folding would raise ZeroDivisionError.
            return None
        if op == "/":
            return left / right
        if op == "//":
            return left // right
        return left % right
    if op == "**":
        # Only fold cases that cannot raise or go complex: a negative base
        # needs an integer exponent, and a zero base is skipped entirely
        # (0.0 ** negative raises ZeroDivisionError).
        if (left < 0 and isinstance(right, int)) or left > 0:
            try:
                result = left**right
            except OverflowError:
                return None
            assert isinstance(result, float), result
            return result
    return None
def constant_fold_unary_op(op: str, value: ConstantValue) -> int | float | None:
    """Fold a unary operation over a constant, or return None if unsupported."""
    if isinstance(value, (int, float)):
        # Negation and unary plus work for both ints and floats.
        if op == "-":
            return -value
        if op == "+":
            return value
    if op == "~" and isinstance(value, int):
        # Bitwise inversion is only meaningful for ints.
        return ~value
    return None

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,134 @@
from __future__ import annotations
from typing import Any, cast
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
TypeAliasType,
TypedDictType,
TypeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
)
# type_visitor needs to be imported after types
from mypy.type_visitor import TypeVisitor # ruff: isort: skip
def copy_type(t: ProperType) -> ProperType:
    """Create a shallow copy of a type.

    This can be used to mutate the copy with truthiness information.

    Classes compiled with mypyc don't support copy.copy(), so we need
    a custom implementation.
    """
    # Each visit_* method builds a new instance of the same type kind and
    # copies the shared location/truthiness flags via copy_common().
    return t.accept(TypeShallowCopier())
class TypeShallowCopier(TypeVisitor[ProperType]):
    """Visitor that produces a shallow copy of each kind of ProperType.

    Each visit method constructs a fresh instance with the same attributes
    (sub-types are shared, not copied) and then copies the common
    line/column and truthiness flags in copy_common().
    """

    def visit_unbound_type(self, t: UnboundType) -> ProperType:
        # Unbound types are returned as-is rather than copied.
        return t

    def visit_any(self, t: AnyType) -> ProperType:
        return self.copy_common(t, AnyType(t.type_of_any, t.source_any, t.missing_import_name))

    def visit_none_type(self, t: NoneType) -> ProperType:
        return self.copy_common(t, NoneType())

    def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
        dup = UninhabitedType()
        dup.ambiguous = t.ambiguous
        return self.copy_common(t, dup)

    def visit_erased_type(self, t: ErasedType) -> ProperType:
        return self.copy_common(t, ErasedType())

    def visit_deleted_type(self, t: DeletedType) -> ProperType:
        return self.copy_common(t, DeletedType(t.source))

    def visit_instance(self, t: Instance) -> ProperType:
        dup = Instance(t.type, t.args, last_known_value=t.last_known_value)
        return self.copy_common(t, dup)

    def visit_type_var(self, t: TypeVarType) -> ProperType:
        return self.copy_common(t, t.copy_modified())

    def visit_param_spec(self, t: ParamSpecType) -> ProperType:
        dup = ParamSpecType(
            t.name, t.fullname, t.id, t.flavor, t.upper_bound, t.default, prefix=t.prefix
        )
        return self.copy_common(t, dup)

    def visit_parameters(self, t: Parameters) -> ProperType:
        dup = Parameters(
            t.arg_types,
            t.arg_kinds,
            t.arg_names,
            variables=t.variables,
            is_ellipsis_args=t.is_ellipsis_args,
        )
        return self.copy_common(t, dup)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
        dup = TypeVarTupleType(
            t.name, t.fullname, t.id, t.upper_bound, t.tuple_fallback, t.default
        )
        return self.copy_common(t, dup)

    def visit_unpack_type(self, t: UnpackType) -> ProperType:
        dup = UnpackType(t.type)
        return self.copy_common(t, dup)

    def visit_partial_type(self, t: PartialType) -> ProperType:
        return self.copy_common(t, PartialType(t.type, t.var, t.value_type))

    def visit_callable_type(self, t: CallableType) -> ProperType:
        return self.copy_common(t, t.copy_modified())

    def visit_tuple_type(self, t: TupleType) -> ProperType:
        return self.copy_common(t, TupleType(t.items, t.partial_fallback, implicit=t.implicit))

    def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
        return self.copy_common(
            t, TypedDictType(t.items, t.required_keys, t.readonly_keys, t.fallback)
        )

    def visit_literal_type(self, t: LiteralType) -> ProperType:
        return self.copy_common(t, LiteralType(value=t.value, fallback=t.fallback))

    def visit_union_type(self, t: UnionType) -> ProperType:
        return self.copy_common(t, UnionType(t.items))

    def visit_overloaded(self, t: Overloaded) -> ProperType:
        return self.copy_common(t, Overloaded(items=t.items))

    def visit_type_type(self, t: TypeType) -> ProperType:
        # Use cast since the type annotations in TypeType are imprecise.
        return self.copy_common(t, TypeType(cast(Any, t.item), is_type_form=t.is_type_form))

    def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
        assert False, "only ProperTypes supported"

    def copy_common(self, t: ProperType, t2: ProperType) -> ProperType:
        # Copy the flags shared by all ProperTypes onto the duplicate.
        t2.line = t.line
        t2.column = t.column
        t2.can_be_false = t.can_be_false
        t2.can_be_true = t.can_be_true
        return t2

View file

@ -0,0 +1,51 @@
from __future__ import annotations
import os
from typing import Final
# Earliest fully supported Python 3.x version. Used as the default Python
# version in tests. Mypy wheels should be built starting with this version,
# and CI tests should be run on this version (and later versions).
PYTHON3_VERSION: Final = (3, 10)

# Earliest Python 3.x version supported via --python-version 3.x. To run
# mypy, at least version PYTHON3_VERSION is needed.
PYTHON3_VERSION_MIN: Final = (3, 9)  # Keep in sync with typeshed's python support

# Default directory for mypy's incremental cache.
CACHE_DIR: Final = ".mypy_cache"
# Config files dedicated to mypy, and shared config files that may contain
# a mypy section, searched in the listed order.
CONFIG_NAMES: Final = ["mypy.ini", ".mypy.ini"]
SHARED_CONFIG_NAMES: Final = ["pyproject.toml", "setup.cfg"]

# Per-user fallback config locations; XDG_CONFIG_HOME takes priority when set.
USER_CONFIG_FILES: list[str] = ["~/.config/mypy/config", "~/.mypy.ini"]
if os.environ.get("XDG_CONFIG_HOME"):
    USER_CONFIG_FILES.insert(0, os.path.join(os.environ["XDG_CONFIG_HOME"], "mypy/config"))

USER_CONFIG_FILES = [os.path.expanduser(f) for f in USER_CONFIG_FILES]

# This must include all reporters defined in mypy.report. This is defined here
# to make reporter names available without importing mypy.report -- this speeds
# up startup.
REPORTER_NAMES: Final = [
    "linecount",
    "any-exprs",
    "linecoverage",
    "memory-xml",
    "cobertura-xml",
    "xml",
    "xslt-html",
    "xslt-txt",
    "html",
    "txt",
    "lineprecision",
]

# Threshold after which we sometimes filter out most errors to avoid very
# verbose output. The default is to show all errors.
MANY_ERRORS_THRESHOLD: Final = -1

# Python recursion limit used while type checking (16384; deep ASTs need
# more headroom than the interpreter default).
RECURSION_LIMIT: Final = 2**14

# Timing constants (seconds).
# NOTE(review): presumably used by the parallel worker/daemon machinery --
# confirm exact semantics at the call sites.
WORKER_START_INTERVAL: Final = 0.01
WORKER_START_TIMEOUT: Final = 3
WORKER_CONNECTION_TIMEOUT: Final = 10
WORKER_DONE_TIMEOUT: Final = 600

View file

@ -0,0 +1,6 @@
from __future__ import annotations
from mypy.dmypy.client import console_entry
# Allow running the dmypy client via "python -m".
if __name__ == "__main__":
    console_entry()

View file

@ -0,0 +1,732 @@
"""Client for mypy daemon mode.
This manages a daemon process which keeps useful state in memory
rather than having to read it back from disk on each run.
"""
from __future__ import annotations
import argparse
import json
import os
import pickle
import sys
import time
import traceback
from collections.abc import Callable, Mapping
from typing import Any, NoReturn
from librt.base64 import b64decode
from mypy.defaults import RECURSION_LIMIT
from mypy.dmypy_os import alive, kill
from mypy.dmypy_util import DEFAULT_STATUS_FILE, receive, send
from mypy.ipc import BadStatus, IPCClient, IPCException, read_status
from mypy.util import check_python_version, get_terminal_width, should_force_color
from mypy.version import __version__
# Argument parser. Subparsers are tied to action functions by the
# @action(subparse) decorator.
class AugmentedHelpFormatter(argparse.RawDescriptionHelpFormatter):
    """Raw-description help formatter with a wider help column (30 chars)."""

    def __init__(self, prog: str, **kwargs: Any) -> None:
        # Widen max_help_position so longer option strings line up nicely.
        super().__init__(prog=prog, max_help_position=30, **kwargs)
# Top-level argument parser for the dmypy client.  Subcommands are added as
# subparsers below; each handler function is attached with @action.
parser = argparse.ArgumentParser(
    prog="dmypy", description="Client for mypy daemon mode", fromfile_prefix_chars="@"
)
parser.set_defaults(action=None)  # args.action stays None when no subcommand is given
parser.add_argument(
    "--status-file", default=DEFAULT_STATUS_FILE, help="status file to retrieve daemon details"
)
parser.add_argument(
    "-V",
    "--version",
    action="version",
    version="%(prog)s " + __version__,
    help="Show program's version number and exit",
)
subparsers = parser.add_subparsers()
# "start": launch the daemon (fails if one is already running).
start_parser = p = subparsers.add_parser("start", help="Start daemon")
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
    "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument(
    "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)

# "restart": stop any running daemon, then start a fresh one.
restart_parser = p = subparsers.add_parser(
    "restart", help="Restart daemon (stop or kill followed by start)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
    "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument(
    "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)
# "status": report whether the daemon is up and responsive.
status_parser = p = subparsers.add_parser("status", help="Show daemon status")
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("--fswatcher-dump-file", help="Collect information about the current file state")

# "stop" asks the daemon to exit cleanly; "kill" terminates the process.
stop_parser = p = subparsers.add_parser("stop", help="Stop daemon (asks it politely to go away)")
kill_parser = p = subparsers.add_parser("kill", help="Kill daemon (kills the process)")
# "check": type check the given files using an already-running daemon.
check_parser = p = subparsers.add_parser(
    "check", formatter_class=AugmentedHelpFormatter, help="Check some files (requires daemon)"
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS)  # Deprecated
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument("files", metavar="FILE", nargs="+", help="File (or directory) to check")
p.add_argument(
    "--export-types",
    action="store_true",
    help="Store types of all expressions in a shared location (useful for inspections)",
)
# "run": like "check", but transparently starts/restarts the daemon as needed.
run_parser = p = subparsers.add_parser(
    "run",
    formatter_class=AugmentedHelpFormatter,
    help="Check some files, [re]starting daemon if necessary",
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument(
    "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
    "--export-types",
    action="store_true",
    help="Store types of all expressions in a shared location (useful for inspections)",
)
p.add_argument(
    "flags",
    metavar="ARG",
    nargs="*",
    type=str,
    help="Regular mypy flags and files (precede with --)",
)
# "recheck": re-run on the previous file list, optionally adding/removing files.
recheck_parser = p = subparsers.add_parser(
    "recheck",
    formatter_class=AugmentedHelpFormatter,
    help="Re-check the previous list of files, with optional modifications (requires daemon)",
)
p.add_argument("-v", "--verbose", action="store_true", help="Print detailed status")
p.add_argument("-q", "--quiet", action="store_true", help=argparse.SUPPRESS)  # Deprecated
p.add_argument("--junit-xml", help="Write junit.xml to the given file")
p.add_argument("--perf-stats-file", help="write performance information to the given file")
p.add_argument(
    "--export-types",
    action="store_true",
    help="Store types of all expressions in a shared location (useful for inspections)",
)
p.add_argument(
    "--update",
    metavar="FILE",
    nargs="*",
    help="Files in the run to add or check again (default: all from previous run)",
)
p.add_argument("--remove", metavar="FILE", nargs="*", help="Files to remove from the run")
# "suggest": infer a signature (or list call sites) for one function.
suggest_parser = p = subparsers.add_parser(
    "suggest", help="Suggest a signature or show call sites for a specific function"
)
p.add_argument(
    "function",
    metavar="FUNCTION",
    type=str,
    help="Function specified as '[package.]module.[class.]function'",
)
p.add_argument(
    "--json",
    action="store_true",
    help="Produce json that pyannotate can use to apply a suggestion",
)
p.add_argument(
    "--no-errors", action="store_true", help="Only produce suggestions that cause no errors"
)
p.add_argument(
    "--no-any", action="store_true", help="Only produce suggestions that don't contain Any"
)
p.add_argument(
    "--flex-any",
    type=float,
    help="Allow anys in types if they go above a certain score (scores are from 0-1)",
)
p.add_argument(
    "--callsites", action="store_true", help="Find callsites instead of suggesting a type"
)
p.add_argument(
    "--use-fixme",
    metavar="NAME",
    type=str,
    help="A dummy name to use instead of Any for types that can't be inferred",
)
p.add_argument(
    "--max-guesses",
    type=int,
    help="Set the maximum number of types to try for a function (default 64)",
)
# "inspect": statically inspect the expression(s) at a source location.
inspect_parser = p = subparsers.add_parser(
    "inspect", help="Locate and statically inspect expression(s)"
)
p.add_argument(
    "location",
    metavar="LOCATION",
    type=str,
    help="Location specified as path/to/file.py:line:column[:end_line:end_column]."
    " If position is given (i.e. only line and column), this will return all"
    " enclosing expressions",
)
p.add_argument(
    "--show",
    metavar="INSPECTION",
    type=str,
    default="type",
    choices=["type", "attrs", "definition"],
    help="What kind of inspection to run",
)
p.add_argument(
    "--verbose",
    "-v",
    action="count",
    default=0,
    help="Increase verbosity of the type string representation (can be repeated)",
)
p.add_argument(
    "--limit",
    metavar="NUM",
    type=int,
    default=0,
    help="Return at most NUM innermost expressions (if position is given); 0 means no limit",
)
p.add_argument(
    "--include-span",
    action="store_true",
    help="Prepend each inspection result with the span of corresponding expression"
    ' (e.g. 1:2:3:4:"int")',
)
p.add_argument(
    "--include-kind",
    action="store_true",
    help="Prepend each inspection result with the kind of corresponding expression"
    ' (e.g. NameExpr:"int")',
)
p.add_argument(
    "--include-object-attrs",
    action="store_true",
    help='Include attributes of "object" in "attrs" inspection',
)
p.add_argument(
    "--union-attrs",
    action="store_true",
    help="Include attributes valid for some of possible expression types"
    " (by default an intersection is returned)",
)
p.add_argument(
    "--force-reload",
    action="store_true",
    help="Re-parse and re-type-check file before inspection (may be slow)",
)
# "hang" is a debugging aid; "daemon" runs the server in the foreground.
hang_parser = p = subparsers.add_parser("hang", help="Hang for 100 seconds")

daemon_parser = p = subparsers.add_parser("daemon", help="Run daemon in foreground")
p.add_argument(
    "--timeout", metavar="TIMEOUT", type=int, help="Server shutdown timeout (in seconds)"
)
p.add_argument("--log-file", metavar="FILE", type=str, help="Direct daemon stdout/stderr to FILE")
p.add_argument(
    "flags", metavar="FLAG", nargs="*", type=str, help="Regular mypy flags (precede with --)"
)
# Internal: base64-encoded pickled Options (consumed by do_daemon()).
p.add_argument("--options-data", help=argparse.SUPPRESS)
help_parser = p = subparsers.add_parser("help")

# p is only a cursor while building parsers; drop it from the module namespace.
del p
def main(argv: list[str]) -> None:
    """Entry point for the dmypy client: parse argv and dispatch to an action.

    Each subcommand registers its handler via the @action decorator, so the
    parsed namespace carries it as ``args.action``.
    """
    check_python_version("dmypy")
    # set recursion limit consistent with mypy/main.py
    sys.setrecursionlimit(RECURSION_LIMIT)
    args = parser.parse_args(argv)
    if not args.action:
        # No subcommand given; show usage instead of failing.
        parser.print_usage()
    else:
        try:
            args.action(args)
        except BadStatus as err:
            fail(err.args[0])
        except Exception:
            # We do this explicitly to avoid exceptions percolating up
            # through mypy.api invocations
            traceback.print_exc()
            sys.exit(2)
def fail(msg: str) -> NoReturn:
    """Print *msg* to stderr and exit with the client's error code (2)."""
    sys.stderr.write(msg + "\n")
    sys.exit(2)
ActionFunction = Callable[[argparse.Namespace], None]


def action(subparser: argparse.ArgumentParser) -> Callable[[ActionFunction], ActionFunction]:
    """Return a decorator that ties an action function to *subparser*.

    The handler becomes the subparser's ``action`` default, so parsing that
    subcommand yields a namespace whose ``action`` attribute is the handler.
    """

    def register(handler: ActionFunction) -> ActionFunction:
        # parse_args() will surface the handler as namespace.action.
        subparser.set_defaults(action=handler)
        return handler

    return register
# Action functions (run in client from command line).
@action(start_parser)
def do_start(args: argparse.Namespace) -> None:
    """Start daemon (it must not already be running).

    This is where mypy flags are set from the command line.

    Setting flags is a bit awkward; you have to use e.g.:

      dmypy start -- --strict

    since we don't want to duplicate mypy's huge list of flags.
    """
    try:
        # get_status() raises BadStatus when there is no live daemon.
        get_status(args.status_file)
    except BadStatus:
        # Bad or missing status file or dead process; good to start.
        pass
    else:
        fail("Daemon is still alive")
    start_server(args)
@action(restart_parser)
def do_restart(args: argparse.Namespace) -> None:
    """Restart daemon (it may or may not be running; but not hanging).

    We first try to stop it politely if it's running. This also sets
    mypy flags from the command line (see do_start()).
    """
    restart_server(args)
def restart_server(args: argparse.Namespace, allow_sources: bool = False) -> None:
    """Restart daemon (it may or may not be running; but not hanging)."""
    try:
        do_stop(args)
    except BadStatus:
        # Bad or missing status file or dead process; good to start.
        pass
    start_server(args, allow_sources)
def start_server(args: argparse.Namespace, allow_sources: bool = False) -> None:
    """Start the server from command arguments and wait for it."""
    # Lazy import so this import doesn't slow down other commands.
    from mypy.dmypy_server import daemonize, process_start_options

    start_options = process_start_options(args.flags, allow_sources)
    # daemonize() returns a nonzero value on failure to launch.
    if daemonize(start_options, args.status_file, timeout=args.timeout, log_file=args.log_file):
        sys.exit(2)
    wait_for_server(args.status_file)
def wait_for_server(status_file: str, timeout: float = 5.0) -> None:
    """Wait until the server is up.

    Poll the status file every 0.1s until it appears and validates, or
    exit if that doesn't happen within *timeout* seconds.
    """
    endtime = time.time() + timeout
    while time.time() < endtime:
        try:
            data = read_status(status_file)
        except BadStatus:
            # If the file isn't there yet, retry later.
            time.sleep(0.1)
            continue
        # If the file's content is bogus or the process is dead, fail.
        check_status(data)
        print("Daemon started")
        return
    fail("Timed out waiting for daemon to start")
@action(run_parser)
def do_run(args: argparse.Namespace) -> None:
    """Do a check, starting (or restarting) the daemon as necessary

    Restarts the daemon if the running daemon reports that it is
    required (due to a configuration change, for example).

    Setting flags is a bit awkward; you have to use e.g.:

      dmypy run -- --strict a.py b.py ...

    since we don't want to duplicate mypy's huge list of flags.
    (The -- is only necessary if flags are specified.)
    """
    if not is_running(args.status_file):
        # Bad or missing status file or dead process; good to start.
        start_server(args, allow_sources=True)
    t0 = time.time()
    response = request(
        args.status_file,
        "run",
        version=__version__,
        args=args.flags,
        export_types=args.export_types,
    )
    # If the daemon signals that a restart is necessary, do it
    if "restart" in response:
        print(f"Restarting: {response['restart']}")
        restart_server(args, allow_sources=True)
        # Re-issue the same request against the fresh daemon.
        response = request(
            args.status_file,
            "run",
            version=__version__,
            args=args.flags,
            export_types=args.export_types,
        )

    t1 = time.time()
    response["roundtrip_time"] = t1 - t0
    check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(status_parser)
def do_status(args: argparse.Namespace) -> None:
    """Print daemon status.

    This verifies that it is responsive to requests.
    """
    status = read_status(args.status_file)
    if args.verbose:
        show_stats(status)
    # Both check_status() and request() may raise BadStatus,
    # which will be handled by main().
    check_status(status)
    response = request(
        args.status_file, "status", fswatcher_dump_file=args.fswatcher_dump_file, timeout=5
    )
    if args.verbose or "error" in response:
        show_stats(response)
    if "error" in response:
        fail(f"Daemon may be busy processing; if this persists, consider {sys.argv[0]} kill")
    print("Daemon is up and running")
@action(stop_parser)
def do_stop(args: argparse.Namespace) -> None:
    """Stop daemon via a 'stop' request."""
    # May raise BadStatus, which will be handled by main().
    response = request(args.status_file, "stop", timeout=5)
    if "error" in response:
        # The daemon answered but could not comply; suggest a hard kill.
        show_stats(response)
        fail(f"Daemon may be busy processing; if this persists, consider {sys.argv[0]} kill")
    else:
        print("Daemon stopped")
@action(kill_parser)
def do_kill(args: argparse.Namespace) -> None:
    """Kill daemon process with SIGKILL."""
    pid, _ = get_status(args.status_file)
    try:
        kill(pid)
    except OSError as err:
        fail(str(err))
    else:
        print("Daemon killed")
@action(check_parser)
def do_check(args: argparse.Namespace) -> None:
    """Ask the daemon to check a list of files."""
    t0 = time.time()
    response = request(args.status_file, "check", files=args.files, export_types=args.export_types)
    t1 = time.time()
    # Record total client-side latency for stats/junit reporting.
    response["roundtrip_time"] = t1 - t0
    check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(recheck_parser)
def do_recheck(args: argparse.Namespace) -> None:
    """Ask the daemon to recheck the previous list of files, with optional modifications.

    If at least one of --remove or --update is given, the server will
    update the list of files to check accordingly and assume that any other files
    are unchanged.  If none of these flags are given, the server will call stat()
    on each file last checked to determine its status.

    Files given in --update ought to exist.  Files given in --remove need not exist;
    if they don't they will be ignored.
    The lists may be empty but oughtn't contain duplicates or overlap.

    NOTE: The list of files is lost when the daemon is restarted.
    """
    t0 = time.time()
    if args.remove is not None or args.update is not None:
        # Explicit file-set modifications; the daemon trusts them without stat().
        response = request(
            args.status_file,
            "recheck",
            export_types=args.export_types,
            remove=args.remove,
            update=args.update,
        )
    else:
        response = request(args.status_file, "recheck", export_types=args.export_types)
    t1 = time.time()
    response["roundtrip_time"] = t1 - t0
    check_output(response, args.verbose, args.junit_xml, args.perf_stats_file)
@action(suggest_parser)
def do_suggest(args: argparse.Namespace) -> None:
    """Ask the daemon for a suggested signature.

    This just prints whatever the daemon reports as output.
    For now it may be closer to a list of call sites.
    """
    response = request(
        args.status_file,
        "suggest",
        function=args.function,
        json=args.json,
        callsites=args.callsites,
        no_errors=args.no_errors,
        no_any=args.no_any,
        flex_any=args.flex_any,
        use_fixme=args.use_fixme,
        max_guesses=args.max_guesses,
    )
    check_output(response, verbose=False, junit_xml=None, perf_stats_file=None)
@action(inspect_parser)
def do_inspect(args: argparse.Namespace) -> None:
    """Ask daemon to print the type of an expression."""
    response = request(
        args.status_file,
        "inspect",
        show=args.show,
        location=args.location,
        verbosity=args.verbose,
        limit=args.limit,
        include_span=args.include_span,
        include_kind=args.include_kind,
        include_object_attrs=args.include_object_attrs,
        union_attrs=args.union_attrs,
        force_reload=args.force_reload,
    )
    check_output(response, verbose=False, junit_xml=None, perf_stats_file=None)
def check_output(
    response: dict[str, Any], verbose: bool, junit_xml: str | None, perf_stats_file: str | None
) -> None:
    """Print the output from a check or recheck command.

    Optionally writes a junit.xml file and/or a JSON perf-stats file.
    Call sys.exit() unless the status code is zero.
    """
    if os.name == "nt":
        # Enable ANSI color codes for Windows cmd using this strange workaround
        # ( see https://github.com/python/cpython/issues/74261 )
        os.system("")
    if "error" in response:
        fail(response["error"])
    try:
        out, err, status_code = response["out"], response["err"], response["status"]
    except KeyError:
        # Malformed daemon response; show it verbatim for debugging.
        fail(f"Response: {str(response)}")
    sys.stdout.write(out)
    sys.stdout.flush()
    sys.stderr.write(err)
    sys.stderr.flush()
    if verbose:
        show_stats(response)
    if junit_xml:
        # Lazy import so this import doesn't slow things down when not writing junit
        from mypy.util import write_junit_xml

        messages = (out + err).splitlines()
        write_junit_xml(
            response["roundtrip_time"],
            bool(err),
            {None: messages} if messages else {},
            junit_xml,
            response["python_version"],
            response["platform"],
        )
    if perf_stats_file:
        telemetry = response.get("stats", {})
        with open(perf_stats_file, "w") as f:
            json.dump(telemetry, f)
    if status_code:
        sys.exit(status_code)
def show_stats(response: Mapping[str, object]) -> None:
    """Pretty-print the key/value pairs of a daemon response, sorted by key."""
    text_keys = ("out", "err", "stdout", "stderr")
    for key in sorted(response):
        value = response[key]
        if key in text_keys:
            # Text payloads can be huge; show an escaped 40-character preview.
            shown = repr(value)[1:-1]
            if len(shown) > 50:
                shown = f"{shown[:40]} ... {len(shown)-40} more characters"
            print("%-24s: %s" % (key, shown))
        else:
            # Floats get three decimals; everything else prints as-is.
            formatted = "%.3f" % value if isinstance(value, float) else value
            print("%-24s: %10s" % (key, formatted))
@action(hang_parser)
def do_hang(args: argparse.Namespace) -> None:
    """Hang for 100 seconds, as a debug hack."""
    # The 1-second timeout makes the client report the hang promptly.
    print(request(args.status_file, "hang", timeout=1))
@action(daemon_parser)
def do_daemon(args: argparse.Namespace) -> None:
    """Serve requests in the foreground."""
    # Lazy import so this import doesn't slow down other commands.
    from mypy.dmypy_server import Server, process_start_options

    if args.log_file:
        # Line-buffered log file; also redirect the real fd 1/2 so output
        # from C extensions and subprocesses lands in the log too.
        sys.stdout = sys.stderr = open(args.log_file, "a", buffering=1)
        fd = sys.stdout.fileno()
        os.dup2(fd, 2)
        os.dup2(fd, 1)
    if args.options_data:
        from mypy.options import Options

        # --options-data carries base64-encoded pickled option changes
        # (internal flag, passed by the spawning client).
        options_dict = pickle.loads(b64decode(args.options_data))
        options_obj = Options()
        options = options_obj.apply_changes(options_dict)
    else:
        options = process_start_options(args.flags, allow_sources=False)
    Server(options, args.status_file, timeout=args.timeout).serve()
@action(help_parser)
def do_help(args: argparse.Namespace) -> None:
    """Print full help (same as dmypy --help)."""
    parser.print_help()
# Client-side infrastructure.
def request(
    status_file: str, command: str, *, timeout: int | None = None, **kwds: object
) -> dict[str, Any]:
    """Send a request to the daemon.

    Return the JSON dict with the response.

    Raise BadStatus if there is something wrong with the status file
    or if the process whose pid is in the status file has died.

    Return {'error': <message>} if an IPC operation or receive()
    raised OSError. This covers cases such as connection refused or
    closed prematurely as well as invalid JSON received.
    """
    response: dict[str, str] = {}
    args = dict(kwds)
    args["command"] = command
    # Tell the server whether this request was initiated from a human-facing terminal,
    # so that it can format the type checking output accordingly.
    args["is_tty"] = sys.stdout.isatty() or should_force_color()
    args["terminal_width"] = get_terminal_width()
    # Raises BadStatus if the status file is invalid or the daemon died.
    _, name = get_status(status_file)
    try:
        with IPCClient(name, timeout) as client:
            send(client, args)
            # The server may stream several frames; the last one carries
            # "final": True and holds the actual command result.
            final = False
            while not final:
                response = receive(client)
                final = bool(response.pop("final", False))
                # Display debugging output written to stdout/stderr in the server process for convenience.
                # This should not be confused with "out" and "err" fields in the response.
                # Those fields hold the output of the "check" command, and are handled in check_output().
                stdout = response.pop("stdout", None)
                if stdout:
                    sys.stdout.write(stdout)
                stderr = response.pop("stderr", None)
                if stderr:
                    sys.stderr.write(stderr)
    except (OSError, IPCException) as err:
        return {"error": str(err)}
    # TODO: Other errors, e.g. ValueError, UnicodeError
    return response
def get_status(status_file: str) -> tuple[int, str]:
    """Read the status file and confirm the daemon process is alive.

    Return (pid, connection_name) on success; raise BadStatus otherwise.
    """
    return check_status(read_status(status_file))
def check_status(data: dict[str, Any]) -> tuple[int, str]:
    """Validate a parsed status file and confirm the daemon is alive.

    Return (pid, connection_name) on success; raise BadStatus on any problem.
    """
    if "pid" not in data:
        raise BadStatus("Invalid status file (no pid field)")
    pid = data["pid"]
    if not isinstance(pid, int):
        raise BadStatus("pid field is not an int")
    # A well-formed status file is useless if the daemon itself is gone.
    if not alive(pid):
        raise BadStatus("Daemon has died")
    if "connection_name" not in data:
        raise BadStatus("Invalid status file (no connection_name field)")
    conn_name = data["connection_name"]
    if not isinstance(conn_name, str):
        raise BadStatus("connection_name field is not a string")
    return pid, conn_name
def is_running(status_file: str) -> bool:
    """Check if the server is running cleanly"""
    try:
        get_status(status_file)
        return True
    except BadStatus:
        return False
# Run main().
def console_entry() -> None:
    # Entry point for the `dmypy` console script; drops argv[0] (program name).
    main(sys.argv[1:])

View file

@ -0,0 +1,43 @@
from __future__ import annotations
import sys
from collections.abc import Callable
from typing import Any
# Platform-specific process primitives: on Windows bind the needed kernel32
# functions via ctypes; on POSIX plain os/signal suffice.
if sys.platform == "win32":
    import ctypes
    import subprocess
    from ctypes.wintypes import DWORD, HANDLE
    # Minimal access right required to call GetExitCodeProcess on a handle.
    PROCESS_QUERY_LIMITED_INFORMATION = ctypes.c_ulong(0x1000)
    kernel32 = ctypes.windll.kernel32
    OpenProcess: Callable[[DWORD, int, int], HANDLE] = kernel32.OpenProcess
    GetExitCodeProcess: Callable[[HANDLE, Any], int] = kernel32.GetExitCodeProcess
else:
    import os
    import signal
def alive(pid: int) -> bool:
    """Return True if the process with the given pid is still running."""
    if sys.platform == "win32":
        # Query the exit code: STILL_ACTIVE (259) means the process is running.
        # NOTE(review): the handle from OpenProcess is never closed here.
        status = DWORD()
        handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
        GetExitCodeProcess(handle, ctypes.byref(status))
        return status.value == 259  # STILL_ACTIVE
    # POSIX: signal 0 is an existence/permission check; nothing is delivered.
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    return True
def kill(pid: int) -> None:
    """Forcibly terminate the process with the given pid."""
    if sys.platform != "win32":
        os.kill(pid, signal.SIGKILL)
    else:
        # /f = force, /t = also terminate the whole process tree.
        subprocess.check_output(f"taskkill /pid {pid} /f /t")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,117 @@
"""Shared code between dmypy.py and dmypy_server.py.
This should be pretty lightweight and not depend on other mypy code (other than ipc).
"""
from __future__ import annotations
import io
import json
from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import Any, Final, TextIO
from mypy.ipc import IPCBase
DEFAULT_STATUS_FILE: Final = ".dmypy.json"
def receive(connection: IPCBase) -> Any:
    """Read one JSON frame from the connection and decode it.

    Raise OSError if nothing was received, if the payload is not valid
    JSON, or if the decoded value is not a dict.
    """
    raw = connection.read()
    if not raw:
        raise OSError("No data received")
    try:
        decoded = json.loads(raw)
    except Exception as e:
        raise OSError("Data received is not valid JSON") from e
    if not isinstance(decoded, dict):
        raise OSError(f"Data received is not a dict ({type(decoded)})")
    return decoded
def send(connection: IPCBase, data: Any) -> None:
    """Serialize *data* as JSON and write it to the connection as one frame.

    The data must be JSON-serializable; a single call produces a single frame.
    """
    frame = json.dumps(data)
    connection.write(frame)
class WriteToConn(TextIO):
    """Helper class to write to a connection instead of standard output.

    Each write() is framed as {output_key: text} and sent over the IPC
    connection; all read/seek operations are unsupported.
    """

    def __init__(self, server: IPCBase, output_key: str, isatty: bool) -> None:
        # output_key is the frame field name, e.g. "stdout" or "stderr".
        self.server = server
        self.output_key = output_key
        self._isatty = isatty

    def __enter__(self) -> TextIO:
        return self

    def __exit__(
        self,
        t: type[BaseException] | None,
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        pass

    # The TextIO protocol requires the following members; this stream is
    # write-only, so reading/seeking raise and close/flush are no-ops.
    def __iter__(self) -> Iterator[str]:
        raise io.UnsupportedOperation

    def __next__(self) -> str:
        raise io.UnsupportedOperation

    def close(self) -> None:
        pass

    def fileno(self) -> int:
        raise OSError

    def flush(self) -> None:
        pass

    def isatty(self) -> bool:
        return self._isatty

    def read(self, n: int = 0) -> str:
        raise io.UnsupportedOperation

    def readable(self) -> bool:
        return False

    def readline(self, limit: int = 0) -> str:
        raise io.UnsupportedOperation

    def readlines(self, hint: int = 0) -> list[str]:
        raise io.UnsupportedOperation

    def seek(self, offset: int, whence: int = 0) -> int:
        raise io.UnsupportedOperation

    def seekable(self) -> bool:
        return False

    def tell(self) -> int:
        raise io.UnsupportedOperation

    def truncate(self, size: int | None = 0) -> int:
        raise io.UnsupportedOperation

    def write(self, output: str) -> int:
        # Frame the text under the configured key and send it immediately.
        resp: dict[str, Any] = {self.output_key: output}
        send(self.server, resp)
        return len(output)

    def writable(self) -> bool:
        return True

    def writelines(self, lines: Iterable[str]) -> None:
        for s in lines:
            self.write(s)

View file

@ -0,0 +1,301 @@
from __future__ import annotations
from collections.abc import Callable, Container
from typing import cast
from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeTranslator,
TypeType,
TypeVarId,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
get_proper_type,
get_proper_types,
)
from mypy.typevartuples import erased_vars
def erase_type(typ: Type) -> ProperType:
    """Erase any type variables from a type.

    Also replace tuple types with the corresponding concrete types.

    Examples:
      A -> A
      B[X] -> B[Any]
      Tuple[A, B] -> tuple
      Callable[[A1, A2, ...], R] -> Callable[..., Any]
      Type[X] -> Type[Any]
    """
    return get_proper_type(typ).accept(EraseTypeVisitor())
class EraseTypeVisitor(TypeVisitor[ProperType]):
    """Implements erase_type(): replaces type variables and structured types
    with erased/fallback forms (see erase_type's docstring for examples)."""

    def visit_unbound_type(self, t: UnboundType) -> ProperType:
        # TODO: replace with an assert after UnboundType can't leak from semantic analysis.
        return AnyType(TypeOfAny.from_error)

    def visit_any(self, t: AnyType) -> ProperType:
        return t

    def visit_none_type(self, t: NoneType) -> ProperType:
        return t

    def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
        return t

    def visit_erased_type(self, t: ErasedType) -> ProperType:
        return t

    def visit_partial_type(self, t: PartialType) -> ProperType:
        # Should not get here.
        raise RuntimeError("Cannot erase partial types")

    def visit_deleted_type(self, t: DeletedType) -> ProperType:
        return t

    def visit_instance(self, t: Instance) -> ProperType:
        # B[X] -> B[Any]: replace all type arguments with erased variables.
        args = erased_vars(t.type.defn.type_vars, TypeOfAny.special_form)
        return Instance(t.type, args, t.line)

    def visit_type_var(self, t: TypeVarType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_param_spec(self, t: ParamSpecType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_parameters(self, t: Parameters) -> ProperType:
        raise RuntimeError("Parameters should have been bound to a class")

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
        # Likely, we can never get here because of aggressive erasure of types that
        # can contain this, but better still return a valid replacement.
        return t.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])

    def visit_unpack_type(self, t: UnpackType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_callable_type(self, t: CallableType) -> ProperType:
        # We must preserve the fallback type for overload resolution to work.
        any_type = AnyType(TypeOfAny.special_form)
        return CallableType(
            arg_types=[any_type, any_type],
            arg_kinds=[ARG_STAR, ARG_STAR2],
            arg_names=[None, None],
            ret_type=any_type,
            fallback=t.fallback,
            is_ellipsis_args=True,
            implicit=True,
        )

    def visit_overloaded(self, t: Overloaded) -> ProperType:
        return t.fallback.accept(self)

    def visit_tuple_type(self, t: TupleType) -> ProperType:
        # Tuple[A, B] -> plain tuple (the fallback instance, itself erased).
        return t.partial_fallback.accept(self)

    def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
        return t.fallback.accept(self)

    def visit_literal_type(self, t: LiteralType) -> ProperType:
        # The fallback for literal types should always be either
        # something like int or str, or an enum class -- types that
        # don't contain any TypeVars. So there's no need to visit it.
        return t

    def visit_union_type(self, t: UnionType) -> ProperType:
        erased_items = [erase_type(item) for item in t.items]
        from mypy.typeops import make_simplified_union

        return make_simplified_union(erased_items)

    def visit_type_type(self, t: TypeType) -> ProperType:
        return TypeType.make_normalized(
            t.item.accept(self), line=t.line, is_type_form=t.is_type_form
        )

    def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
        raise RuntimeError("Type aliases should be expanded before accepting this visitor")
def erase_typevars(t: Type, ids_to_erase: Container[TypeVarId] | None = None) -> Type:
    """Replace all type variables in a type with any,
    or just the ones in the provided collection.
    """
    replacement = AnyType(TypeOfAny.special_form)
    if ids_to_erase is None:
        # No filter: erase every type variable.
        return t.accept(TypeVarEraser(None, replacement))
    return t.accept(TypeVarEraser(lambda tid: tid in ids_to_erase, replacement))
def erase_meta_id(id: TypeVarId) -> bool:
    # Predicate for replace_meta_vars(): select unification (meta) variables only.
    return id.is_meta_var()


def replace_meta_vars(t: Type, target_type: Type) -> Type:
    """Replace unification variables in a type with the target type."""
    return t.accept(TypeVarEraser(erase_meta_id, target_type))
class TypeVarEraser(TypeTranslator):
    """Implementation of type erasure

    Replaces type variables selected by `erase_id` with `replacement`;
    an `erase_id` of None means "erase all of them".
    """

    def __init__(self, erase_id: Callable[[TypeVarId], bool] | None, replacement: Type) -> None:
        super().__init__()
        self.erase_id = erase_id
        self.replacement = replacement

    def visit_type_var(self, t: TypeVarType) -> Type:
        if self.erase_id is None or self.erase_id(t.id):
            return self.replacement
        return t

    # TODO: below two methods duplicate some logic with expand_type().
    # In fact, we may want to refactor this whole visitor to use expand_type().
    def visit_instance(self, t: Instance) -> Type:
        result = super().visit_instance(t)
        assert isinstance(result, ProperType) and isinstance(result, Instance)
        if t.type.fullname == "builtins.tuple":
            # Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
            arg = result.args[0]
            if isinstance(arg, UnpackType):
                unpacked = get_proper_type(arg.type)
                if isinstance(unpacked, Instance):
                    assert unpacked.type.fullname == "builtins.tuple"
                    return unpacked
        return result

    def visit_tuple_type(self, t: TupleType) -> Type:
        result = super().visit_tuple_type(t)
        assert isinstance(result, ProperType) and isinstance(result, TupleType)
        if len(result.items) == 1:
            # Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
            item = result.items[0]
            if isinstance(item, UnpackType):
                unpacked = get_proper_type(item.type)
                if isinstance(unpacked, Instance):
                    assert unpacked.type.fullname == "builtins.tuple"
                    if result.partial_fallback.type.fullname != "builtins.tuple":
                        # If it is a subtype (like named tuple) we need to preserve it,
                        # this essentially mimics the logic in tuple_fallback().
                        return result.partial_fallback.accept(self)
                    return unpacked
        return result

    def visit_callable_type(self, t: CallableType) -> Type:
        result = super().visit_callable_type(t)
        assert isinstance(result, ProperType) and isinstance(result, CallableType)
        # Usually this is done in semanal_typeargs.py, but erasure can create
        # a non-normal callable from normal one.
        result.normalize_trivial_unpack()
        return result

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        if self.erase_id is None or self.erase_id(t.id):
            return t.tuple_fallback.copy_modified(args=[self.replacement])
        return t

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        # TODO: we should probably preserve prefix here.
        if self.erase_id is None or self.erase_id(t.id):
            return self.replacement
        return t

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Type alias target can't contain bound type variables (not bound by the type
        # alias itself), so it is safe to just erase the arguments.
        return t.copy_modified(args=[a.accept(self) for a in t.args])
def remove_instance_last_known_values(t: Type) -> Type:
    """Recursively drop `last_known_value` from Instance types inside *t*."""
    return t.accept(LastKnownValueEraser())
class LastKnownValueEraser(TypeTranslator):
    """Removes the Literal[...] type that may be associated with any
    Instance types."""

    def visit_instance(self, t: Instance) -> Type:
        # Fast path: nothing to strip and no args to recurse into.
        if not t.last_known_value and not t.args:
            return t
        return t.copy_modified(args=[a.accept(self) for a in t.args], last_known_value=None)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Type aliases can't contain literal values, because they are
        # always constructed as explicit types.
        return t

    def visit_union_type(self, t: UnionType) -> Type:
        new = cast(UnionType, super().visit_union_type(t))
        # Erasure can result in many duplicate items; merge them.
        # Call make_simplified_union only on lists of instance types
        # that all have the same fullname, to avoid simplifying too
        # much.
        instances = [item for item in new.items if isinstance(get_proper_type(item), Instance)]
        # Avoid merge in simple cases such as optional types.
        if len(instances) > 1:
            # Group no-arg instances by their class fullname.
            instances_by_name: dict[str, list[Instance]] = {}
            p_new_items = get_proper_types(new.items)
            for p_item in p_new_items:
                if isinstance(p_item, Instance) and not p_item.args:
                    instances_by_name.setdefault(p_item.type.fullname, []).append(p_item)
            merged: list[Type] = []
            for item in new.items:
                orig_item = item
                item = get_proper_type(item)
                if isinstance(item, Instance) and not item.args:
                    types = instances_by_name.get(item.type.fullname)
                    if types is not None:
                        if len(types) == 1:
                            merged.append(item)
                        else:
                            from mypy.typeops import make_simplified_union

                            merged.append(make_simplified_union(types))
                        # Each group is merged at most once (on first occurrence).
                        del instances_by_name[item.type.fullname]
                else:
                    merged.append(orig_item)
            return UnionType.make_union(merged)
        return new
def shallow_erase_type_for_equality(typ: Type) -> ProperType:
    """Erase type variables from Instance's"""
    p = get_proper_type(typ)
    if isinstance(p, UnionType):
        # Recurse into union items; each is erased independently.
        return UnionType.make_union([shallow_erase_type_for_equality(item) for item in p.items])
    if not isinstance(p, Instance) or not p.args:
        return p
    erased_args = erased_vars(p.type.defn.type_vars, TypeOfAny.special_form)
    return Instance(p.type, erased_args, p.line)

View file

@ -0,0 +1,39 @@
"""Defines the different custom formats in which mypy can output."""
import json
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from mypy.errors import MypyError
class ErrorFormatter(ABC):
    """Base class to define how errors are formatted before being printed."""

    @abstractmethod
    def report_error(self, error: "MypyError") -> str:
        """Return a single-line rendering of *error*."""
        raise NotImplementedError


class JSONFormatter(ErrorFormatter):
    """Formatter for basic JSON output format."""

    def report_error(self, error: "MypyError") -> str:
        """Render *error* as one static JSON line."""
        hints = error.hints
        payload = {
            "file": error.file_path,
            "line": error.line,
            "column": error.column,
            "end_line": error.end_line,
            "end_column": error.end_column,
            "message": error.message,
            "hint": "\n".join(hints) if hints else None,
            "code": error.errorcode,
            "severity": error.severity,
        }
        return json.dumps(payload)


# Maps each supported --output choice to its formatter instance.
OUTPUT_CHOICES = {"json": JSONFormatter()}

View file

@ -0,0 +1,333 @@
"""Classification of possible errors mypy can detect.
These can be used for filtering specific errors.
"""
from __future__ import annotations
from collections import defaultdict
from typing import Final
from mypy_extensions import mypyc_attr
# Global registry of all known error codes, keyed by code string.
error_codes: dict[str, ErrorCode] = {}
# Maps a parent code to the set of codes declared as its sub-codes.
sub_code_map: dict[str, set[str]] = defaultdict(set)


@mypyc_attr(allow_interpreted_subclasses=True)
class ErrorCode:
    """A named category of errors mypy can report (e.g. "attr-defined").

    Instantiating an ErrorCode registers it in the module-level registries.
    """

    def __init__(
        self,
        code: str,
        description: str,
        category: str,
        default_enabled: bool = True,
        sub_code_of: ErrorCode | None = None,
    ) -> None:
        self.code = code
        self.description = description
        self.category = category
        self.default_enabled = default_enabled
        self.sub_code_of = sub_code_of
        if sub_code_of is not None:
            # Only one level of nesting is allowed.
            assert sub_code_of.sub_code_of is None, "Nested subcategories are not supported"
            sub_code_map[sub_code_of.code].add(code)
        error_codes[code] = self

    def __str__(self) -> str:
        return f"<ErrorCode {self.code}>"

    def __repr__(self) -> str:
        """This doesn't fulfill the goals of repr but it's better than the default view."""
        return f"<ErrorCode {self.category}: {self.code}>"

    def __eq__(self, other: object) -> bool:
        # Identity of an error code is its code string alone.
        return isinstance(other, ErrorCode) and self.code == other.code

    def __hash__(self) -> int:
        return hash((self.code,))
ATTR_DEFINED: Final = ErrorCode("attr-defined", "Check that attribute exists", "General")
NAME_DEFINED: Final = ErrorCode("name-defined", "Check that name is defined", "General")
CALL_ARG: Final = ErrorCode(
"call-arg", "Check number, names and kinds of arguments in calls", "General"
)
ARG_TYPE: Final = ErrorCode("arg-type", "Check argument types in calls", "General")
CALL_OVERLOAD: Final = ErrorCode(
"call-overload", "Check that an overload variant matches arguments", "General"
)
VALID_TYPE: Final = ErrorCode("valid-type", "Check that type (annotation) is valid", "General")
NONETYPE_TYPE: Final = ErrorCode(
"nonetype-type", "Check that type (annotation) is not NoneType", "General"
)
VAR_ANNOTATED: Final = ErrorCode(
"var-annotated", "Require variable annotation if type can't be inferred", "General"
)
OVERRIDE: Final = ErrorCode(
"override", "Check that method override is compatible with base class", "General"
)
RETURN: Final = ErrorCode("return", "Check that function always returns a value", "General")
RETURN_VALUE: Final = ErrorCode(
"return-value", "Check that return value is compatible with signature", "General"
)
ASSIGNMENT: Final = ErrorCode(
"assignment", "Check that assigned value is compatible with target", "General"
)
METHOD_ASSIGN: Final = ErrorCode(
"method-assign",
"Check that assignment target is not a method",
"General",
sub_code_of=ASSIGNMENT,
)
TYPE_ARG: Final = ErrorCode("type-arg", "Check that generic type arguments are present", "General")
TYPE_VAR: Final = ErrorCode("type-var", "Check that type variable values are valid", "General")
UNION_ATTR: Final = ErrorCode(
"union-attr", "Check that attribute exists in each item of a union", "General"
)
INDEX: Final = ErrorCode("index", "Check indexing operations", "General")
OPERATOR: Final = ErrorCode("operator", "Check that operator is valid for operands", "General")
LIST_ITEM: Final = ErrorCode(
"list-item", "Check list items in a list expression [item, ...]", "General"
)
DICT_ITEM: Final = ErrorCode(
"dict-item", "Check dict items in a dict expression {key: value, ...}", "General"
)
TYPEDDICT_ITEM: Final = ErrorCode(
"typeddict-item", "Check items when constructing TypedDict", "General"
)
TYPEDDICT_UNKNOWN_KEY: Final = ErrorCode(
"typeddict-unknown-key",
"Check unknown keys when constructing TypedDict",
"General",
sub_code_of=TYPEDDICT_ITEM,
)
HAS_TYPE: Final = ErrorCode(
"has-type", "Check that type of reference can be determined", "General"
)
IMPORT: Final = ErrorCode(
"import", "Require that imported module can be found or has stubs", "General"
)
IMPORT_NOT_FOUND: Final = ErrorCode(
"import-not-found", "Require that imported module can be found", "General", sub_code_of=IMPORT
)
IMPORT_UNTYPED: Final = ErrorCode(
"import-untyped", "Require that imported module has stubs", "General", sub_code_of=IMPORT
)
NO_REDEF: Final = ErrorCode("no-redef", "Check that each name is defined once", "General")
FUNC_RETURNS_VALUE: Final = ErrorCode(
"func-returns-value", "Check that called function returns a value in value context", "General"
)
ABSTRACT: Final = ErrorCode(
"abstract", "Prevent instantiation of classes with abstract attributes", "General"
)
TYPE_ABSTRACT: Final = ErrorCode(
"type-abstract", "Require only concrete classes where Type[...] is expected", "General"
)
VALID_NEWTYPE: Final = ErrorCode(
"valid-newtype", "Check that argument 2 to NewType is valid", "General"
)
STRING_FORMATTING: Final = ErrorCode(
"str-format", "Check that string formatting/interpolation is type-safe", "General"
)
STR_BYTES_PY3: Final = ErrorCode(
"str-bytes-safe", "Warn about implicit coercions related to bytes and string types", "General"
)
EXIT_RETURN: Final = ErrorCode(
"exit-return", "Warn about too general return type for '__exit__'", "General"
)
LITERAL_REQ: Final = ErrorCode("literal-required", "Check that value is a literal", "General")
UNUSED_COROUTINE: Final = ErrorCode(
"unused-coroutine", "Ensure that all coroutines are used", "General"
)
EMPTY_BODY: Final = ErrorCode(
"empty-body",
"A dedicated error code to opt out return errors for empty/trivial bodies",
"General",
)
SAFE_SUPER: Final = ErrorCode(
"safe-super", "Warn about calls to abstract methods with empty/trivial bodies", "General"
)
TOP_LEVEL_AWAIT: Final = ErrorCode(
"top-level-await", "Warn about top level await expressions", "General"
)
AWAIT_NOT_ASYNC: Final = ErrorCode(
"await-not-async", 'Warn about "await" outside coroutine ("async def")', "General"
)
# These error codes aren't enabled by default.
NO_UNTYPED_DEF: Final = ErrorCode(
"no-untyped-def", "Check that every function has an annotation", "General"
)
NO_UNTYPED_CALL: Final = ErrorCode(
"no-untyped-call",
"Disallow calling functions without type annotations from annotated functions",
"General",
)
REDUNDANT_CAST: Final = ErrorCode(
"redundant-cast", "Check that cast changes type of expression", "General"
)
ASSERT_TYPE: Final = ErrorCode("assert-type", "Check that assert_type() call succeeds", "General")
COMPARISON_OVERLAP: Final = ErrorCode(
"comparison-overlap", "Check that types in comparisons and 'in' expressions overlap", "General"
)
NO_ANY_UNIMPORTED: Final = ErrorCode(
"no-any-unimported", 'Reject "Any" types from unfollowed imports', "General"
)
NO_ANY_RETURN: Final = ErrorCode(
"no-any-return",
'Reject returning value with "Any" type if return type is not "Any"',
"General",
)
UNREACHABLE: Final = ErrorCode(
"unreachable", "Warn about unreachable statements or expressions", "General"
)
ANNOTATION_UNCHECKED: Final = ErrorCode(
"annotation-unchecked", "Notify about type annotations in unchecked functions", "General"
)
TYPEDDICT_READONLY_MUTATED: Final = ErrorCode(
"typeddict-readonly-mutated", "TypedDict's ReadOnly key is mutated", "General"
)
POSSIBLY_UNDEFINED: Final = ErrorCode(
"possibly-undefined",
"Warn about variables that are defined only in some execution paths",
"General",
default_enabled=False,
)
REDUNDANT_EXPR: Final = ErrorCode(
"redundant-expr", "Warn about redundant expressions", "General", default_enabled=False
)
TRUTHY_BOOL: Final = ErrorCode(
"truthy-bool",
"Warn about expressions that could always evaluate to true in boolean contexts",
"General",
default_enabled=False,
)
TRUTHY_FUNCTION: Final = ErrorCode(
"truthy-function",
"Warn about function that always evaluate to true in boolean contexts",
"General",
)
TRUTHY_ITERABLE: Final = ErrorCode(
"truthy-iterable",
"Warn about Iterable expressions that could always evaluate to true in boolean contexts",
"General",
default_enabled=False,
)
STR_UNPACK: Final[ErrorCode] = ErrorCode(
"str-unpack", "Warn about expressions that unpack str", "General"
)
NAME_MATCH: Final = ErrorCode(
"name-match", "Check that type definition has consistent naming", "General"
)
NO_OVERLOAD_IMPL: Final = ErrorCode(
"no-overload-impl",
"Check that overloaded functions outside stub files have an implementation",
"General",
)
IGNORE_WITHOUT_CODE: Final = ErrorCode(
"ignore-without-code",
"Warn about '# type: ignore' comments which do not have error codes",
"General",
default_enabled=False,
)
UNUSED_AWAITABLE: Final = ErrorCode(
"unused-awaitable",
"Ensure that all awaitable values are used",
"General",
default_enabled=False,
)
REDUNDANT_SELF_TYPE: Final = ErrorCode(
"redundant-self",
"Warn about redundant Self type annotations on method first argument",
"General",
default_enabled=False,
)
USED_BEFORE_DEF: Final = ErrorCode(
"used-before-def", "Warn about variables that are used before they are defined", "General"
)
UNUSED_IGNORE: Final = ErrorCode(
"unused-ignore", "Ensure that all type ignores are used", "General", default_enabled=False
)
EXPLICIT_OVERRIDE_REQUIRED: Final = ErrorCode(
"explicit-override",
"Require @override decorator if method is overriding a base class method",
"General",
default_enabled=False,
)
UNIMPORTED_REVEAL: Final = ErrorCode(
"unimported-reveal",
"Require explicit import from typing or typing_extensions for reveal_type",
"General",
default_enabled=False,
)
MUTABLE_OVERRIDE: Final = ErrorCode(
"mutable-override",
"Reject covariant overrides for mutable attributes",
"General",
default_enabled=False,
)
EXHAUSTIVE_MATCH: Final = ErrorCode(
"exhaustive-match",
"Reject match statements that are not exhaustive",
"General",
default_enabled=False,
)
METACLASS: Final = ErrorCode("metaclass", "Ensure that metaclass is valid", "General")
MAYBE_UNRECOGNIZED_STR_TYPEFORM: Final = ErrorCode(
"maybe-unrecognized-str-typeform",
"Error when a string is used where a TypeForm is expected but a string annotation cannot be recognized",
"General",
)
# Syntax errors are often blocking.
SYNTAX: Final = ErrorCode("syntax", "Report syntax errors", "General")
# This is a catch-all for remaining uncategorized errors.
MISC: Final = ErrorCode("misc", "Miscellaneous other checks", "General")
OVERLOAD_CANNOT_MATCH: Final = ErrorCode(
"overload-cannot-match",
"Warn if an @overload signature can never be matched",
"General",
sub_code_of=MISC,
)
OVERLOAD_OVERLAP: Final = ErrorCode(
"overload-overlap",
"Warn if multiple @overload variants overlap in unsafe ways",
"General",
sub_code_of=MISC,
)
PROPERTY_DECORATOR: Final = ErrorCode(
"prop-decorator",
"Decorators on top of @property are not supported",
"General",
sub_code_of=MISC,
)
UNTYPED_DECORATOR: Final = ErrorCode(
"untyped-decorator", "Error if an untyped decorator makes a typed function untyped", "General"
)
NARROWED_TYPE_NOT_SUBTYPE: Final = ErrorCode(
"narrowed-type-not-subtype",
"Warn if a TypeIs function's narrowed type is not a subtype of the original type",
"General",
)
EXPLICIT_ANY: Final = ErrorCode(
"explicit-any", "Warn about explicit Any type annotations", "General"
)
DEPRECATED: Final = ErrorCode(
"deprecated",
"Warn when importing or using deprecated (overloaded) functions, methods or classes",
"General",
default_enabled=False,
)
# This copy will not include any error codes defined later in the plugins.
mypy_error_codes = error_codes.copy()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,211 @@
"""
Evaluate an expression.
Used by stubtest; in a separate file because things break if we don't
put it in a mypyc-compiled file.
"""
import ast
from typing import Final
import mypy.nodes
from mypy.visitor import ExpressionVisitor
# Sentinel returned when an expression's value cannot be determined statically.
UNKNOWN = object()


class _NodeEvaluator(ExpressionVisitor[object]):
    """Best-effort constant evaluator for mypy AST expressions.

    Literals, simple containers, unary operations and a few pass-through
    wrappers (cast, assert_type, walrus) are evaluated; everything else
    yields the UNKNOWN sentinel.
    """

    def visit_int_expr(self, o: mypy.nodes.IntExpr) -> int:
        return o.value

    def visit_str_expr(self, o: mypy.nodes.StrExpr) -> str:
        return o.value

    def visit_bytes_expr(self, o: mypy.nodes.BytesExpr) -> object:
        # The value of a BytesExpr is a string created from the repr()
        # of the bytes object. Get the original bytes back.
        try:
            return ast.literal_eval(f"b'{o.value}'")
        except SyntaxError:
            # The value itself contains a single quote; use double quotes.
            return ast.literal_eval(f'b"{o.value}"')

    def visit_float_expr(self, o: mypy.nodes.FloatExpr) -> float:
        return o.value

    def visit_complex_expr(self, o: mypy.nodes.ComplexExpr) -> object:
        return o.value

    def visit_ellipsis(self, o: mypy.nodes.EllipsisExpr) -> object:
        return Ellipsis

    def visit_star_expr(self, o: mypy.nodes.StarExpr) -> object:
        return UNKNOWN

    def visit_name_expr(self, o: mypy.nodes.NameExpr) -> object:
        if o.name == "True":
            return True
        elif o.name == "False":
            return False
        elif o.name == "None":
            return None
        # TODO: Handle more names by figuring out a way to hook into the
        # symbol table.
        return UNKNOWN

    def visit_member_expr(self, o: mypy.nodes.MemberExpr) -> object:
        return UNKNOWN

    def visit_yield_from_expr(self, o: mypy.nodes.YieldFromExpr) -> object:
        return UNKNOWN

    def visit_yield_expr(self, o: mypy.nodes.YieldExpr) -> object:
        return UNKNOWN

    def visit_call_expr(self, o: mypy.nodes.CallExpr) -> object:
        return UNKNOWN

    def visit_op_expr(self, o: mypy.nodes.OpExpr) -> object:
        return UNKNOWN

    def visit_comparison_expr(self, o: mypy.nodes.ComparisonExpr) -> object:
        return UNKNOWN

    def visit_cast_expr(self, o: mypy.nodes.CastExpr) -> object:
        # cast() does not change the runtime value.
        return o.expr.accept(self)

    def visit_type_form_expr(self, o: mypy.nodes.TypeFormExpr) -> object:
        return UNKNOWN

    def visit_assert_type_expr(self, o: mypy.nodes.AssertTypeExpr) -> object:
        # assert_type() does not change the runtime value.
        return o.expr.accept(self)

    def visit_reveal_expr(self, o: mypy.nodes.RevealExpr) -> object:
        return UNKNOWN

    def visit_super_expr(self, o: mypy.nodes.SuperExpr) -> object:
        return UNKNOWN

    def visit_unary_expr(self, o: mypy.nodes.UnaryExpr) -> object:
        operand = o.expr.accept(self)
        if operand is UNKNOWN:
            return UNKNOWN
        # Only apply the operator to operand types it is defined for.
        if o.op == "-":
            if isinstance(operand, (int, float, complex)):
                return -operand
        elif o.op == "+":
            if isinstance(operand, (int, float, complex)):
                return +operand
        elif o.op == "~":
            if isinstance(operand, int):
                return ~operand
        elif o.op == "not":
            if isinstance(operand, (bool, int, float, str, bytes)):
                return not operand
        return UNKNOWN

    def visit_assignment_expr(self, o: mypy.nodes.AssignmentExpr) -> object:
        # A walrus expression evaluates to its right-hand side.
        return o.value.accept(self)

    def visit_list_expr(self, o: mypy.nodes.ListExpr) -> object:
        items = [item.accept(self) for item in o.items]
        if all(item is not UNKNOWN for item in items):
            return items
        return UNKNOWN

    def visit_dict_expr(self, o: mypy.nodes.DictExpr) -> object:
        # A None key marks a `**spread` entry, which we can't evaluate.
        items = [
            (UNKNOWN if key is None else key.accept(self), value.accept(self))
            for key, value in o.items
        ]
        # Bug fix: the value check must compare against UNKNOWN, not None --
        # otherwise unevaluable values leak through and a literal None value
        # incorrectly blocks evaluation.
        if all(key is not UNKNOWN and value is not UNKNOWN for key, value in items):
            return dict(items)
        return UNKNOWN

    def visit_tuple_expr(self, o: mypy.nodes.TupleExpr) -> object:
        items = [item.accept(self) for item in o.items]
        if all(item is not UNKNOWN for item in items):
            return tuple(items)
        return UNKNOWN

    def visit_set_expr(self, o: mypy.nodes.SetExpr) -> object:
        items = [item.accept(self) for item in o.items]
        if all(item is not UNKNOWN for item in items):
            return set(items)
        return UNKNOWN

    def visit_index_expr(self, o: mypy.nodes.IndexExpr) -> object:
        return UNKNOWN

    def visit_type_application(self, o: mypy.nodes.TypeApplication) -> object:
        return UNKNOWN

    def visit_lambda_expr(self, o: mypy.nodes.LambdaExpr) -> object:
        return UNKNOWN

    def visit_list_comprehension(self, o: mypy.nodes.ListComprehension) -> object:
        return UNKNOWN

    def visit_set_comprehension(self, o: mypy.nodes.SetComprehension) -> object:
        return UNKNOWN

    def visit_dictionary_comprehension(self, o: mypy.nodes.DictionaryComprehension) -> object:
        return UNKNOWN

    def visit_generator_expr(self, o: mypy.nodes.GeneratorExpr) -> object:
        return UNKNOWN

    def visit_slice_expr(self, o: mypy.nodes.SliceExpr) -> object:
        return UNKNOWN

    def visit_conditional_expr(self, o: mypy.nodes.ConditionalExpr) -> object:
        return UNKNOWN

    def visit_type_var_expr(self, o: mypy.nodes.TypeVarExpr) -> object:
        return UNKNOWN

    def visit_paramspec_expr(self, o: mypy.nodes.ParamSpecExpr) -> object:
        return UNKNOWN

    def visit_type_var_tuple_expr(self, o: mypy.nodes.TypeVarTupleExpr) -> object:
        return UNKNOWN

    def visit_type_alias_expr(self, o: mypy.nodes.TypeAliasExpr) -> object:
        return UNKNOWN

    def visit_namedtuple_expr(self, o: mypy.nodes.NamedTupleExpr) -> object:
        return UNKNOWN

    def visit_enum_call_expr(self, o: mypy.nodes.EnumCallExpr) -> object:
        return UNKNOWN

    def visit_typeddict_expr(self, o: mypy.nodes.TypedDictExpr) -> object:
        return UNKNOWN

    def visit_newtype_expr(self, o: mypy.nodes.NewTypeExpr) -> object:
        return UNKNOWN

    def visit__promote_expr(self, o: mypy.nodes.PromoteExpr) -> object:
        return UNKNOWN

    def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> object:
        return UNKNOWN

    def visit_template_str_expr(self, o: mypy.nodes.TemplateStrExpr) -> object:
        return UNKNOWN

    def visit_temp_node(self, o: mypy.nodes.TempNode) -> object:
        return UNKNOWN
# Singleton evaluator shared by all calls to evaluate_expression(); the
# visitor is stateless, so reuse is safe.
_evaluator: Final = _NodeEvaluator()

def evaluate_expression(expr: mypy.nodes.Expression) -> object:
    """Evaluate an expression at runtime.

    Return the result of the expression, or UNKNOWN if the expression cannot be
    evaluated.
    """
    return expr.accept(_evaluator)

View file

@ -0,0 +1,662 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping
from typing import Final, TypeVar, cast, overload
from mypy.nodes import ARG_STAR, ArgKind, FakeInfo, Var
from mypy.state import state
from mypy.types import (
ANY_STRATEGY,
AnyType,
BoolTypeQuery,
CallableType,
DeletedType,
ErasedType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
TrivialSyntheticTypeTranslator,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
flatten_nested_unions,
get_proper_type,
split_with_prefix_and_suffix,
)
from mypy.typevartuples import split_with_instance
# Solving the import cycle:
import mypy.type_visitor # ruff: isort: skip
# WARNING: these functions should never (directly or indirectly) depend on
# is_subtype(), meet_types(), join_types() etc.
# TODO: add a static dependency test for this.
# The overloads encode that expansion preserves the "kind" of the input type:
# expanding a CallableType/ProperType yields the same kind of type back.
@overload
def expand_type(typ: CallableType, env: Mapping[TypeVarId, Type]) -> CallableType: ...
@overload
def expand_type(typ: ProperType, env: Mapping[TypeVarId, Type]) -> ProperType: ...
@overload
def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type: ...

def expand_type(typ: Type, env: Mapping[TypeVarId, Type]) -> Type:
    """Substitute any type variable references in a type given by a type
    environment.
    """
    return typ.accept(ExpandTypeVisitor(env))
@overload
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType: ...
@overload
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType: ...
@overload
def expand_type_by_instance(typ: Type, instance: Instance) -> Type: ...

def expand_type_by_instance(typ: Type, instance: Instance) -> Type:
    """Substitute type variables in type using values from an Instance.
    Type variables are considered to be bound by the class declaration."""
    if not instance.args and not instance.type.has_type_var_tuple_type:
        # Non-generic class: nothing to substitute.
        return typ
    else:
        variables: dict[TypeVarId, Type] = {}
        if instance.type.has_type_var_tuple_type:
            # Split both the instance args and the declared tvars around the
            # TypeVarTuple, map the middle args to it as a tuple, and pair up
            # the remaining prefix/suffix positionally below.
            assert instance.type.type_var_tuple_prefix is not None
            assert instance.type.type_var_tuple_suffix is not None
            args_prefix, args_middle, args_suffix = split_with_instance(instance)
            tvars_prefix, tvars_middle, tvars_suffix = split_with_prefix_and_suffix(
                tuple(instance.type.defn.type_vars),
                instance.type.type_var_tuple_prefix,
                instance.type.type_var_tuple_suffix,
            )
            tvar = tvars_middle[0]
            assert isinstance(tvar, TypeVarTupleType)
            variables = {tvar.id: TupleType(list(args_middle), tvar.tuple_fallback)}
            instance_args = args_prefix + args_suffix
            tvars = tvars_prefix + tvars_suffix
        else:
            tvars = tuple(instance.type.defn.type_vars)
            instance_args = instance.args
        for binder, arg in zip(tvars, instance_args):
            assert isinstance(binder, TypeVarLikeType)
            variables[binder.id] = arg
        return expand_type(typ, variables)
F = TypeVar("F", bound=FunctionLike)

def freshen_function_type_vars(callee: F) -> F:
    """Substitute fresh type variables for generic function type variables."""
    if isinstance(callee, CallableType):
        if not callee.is_generic():
            return callee
        tvs = []
        tvmap: dict[TypeVarId, Type] = {}
        for v in callee.variables:
            # new_unification_variable() allocates a fresh meta-level id.
            tv = v.new_unification_variable(v)
            tvs.append(tv)
            tvmap[v.id] = tv
        # Expand occurrences of the old variables, then re-declare the fresh ones.
        fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
        return cast(F, fresh)
    else:
        # FunctionLike is either CallableType or Overloaded; freshen each item.
        assert isinstance(callee, Overloaded)
        fresh_overload = Overloaded([freshen_function_type_vars(item) for item in callee.items])
        return cast(F, fresh_overload)
class HasGenericCallable(BoolTypeQuery):
    """Type query: does this type contain a generic callable anywhere?"""

    def __init__(self) -> None:
        super().__init__(ANY_STRATEGY)

    def visit_callable_type(self, t: CallableType) -> bool:
        # A callable counts if it is generic itself, or if any of its
        # component types contains a generic callable.
        if t.is_generic():
            return True
        return super().visit_callable_type(t)


# Share a singleton since this is performance sensitive
has_generic_callable: Final = HasGenericCallable()
T = TypeVar("T", bound=Type)

def freshen_all_functions_type_vars(t: T) -> T:
    """Freshen type variables of every generic callable nested inside `t`."""
    # Fast path: scan first with the shared query visitor and avoid the
    # expensive freshening translation when there is nothing to do.
    has_generic_callable.reset()
    if not t.accept(has_generic_callable):
        return t
    translated: Type = t.accept(FreshenCallableVisitor())
    assert isinstance(translated, type(t))
    return translated
class FreshenCallableVisitor(mypy.type_visitor.TypeTranslator):
    """Translator that replaces each generic callable's type variables with fresh ones."""

    def visit_callable_type(self, t: CallableType) -> Type:
        result = super().visit_callable_type(t)
        assert isinstance(result, ProperType) and isinstance(result, CallableType)
        return freshen_function_type_vars(result)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Same as for ExpandTypeVisitor
        return t.copy_modified(args=[arg.accept(self) for arg in t.args])
class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
    """Visitor that substitutes type variables with values."""

    # TypeVar id -> TypeVar value
    variables: Mapping[TypeVarId, Type]

    def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
        super().__init__()
        self.variables = variables
        # Guards against infinite recursion when a TypeVar default (directly
        # or indirectly) refers back to itself; allocated lazily.
        self.recursive_tvar_guard: dict[TypeVarId, Type | None] | None = None

    def visit_unbound_type(self, t: UnboundType) -> Type:
        return t

    def visit_any(self, t: AnyType) -> Type:
        return t

    def visit_none_type(self, t: NoneType) -> Type:
        return t

    def visit_uninhabited_type(self, t: UninhabitedType) -> Type:
        return t

    def visit_deleted_type(self, t: DeletedType) -> Type:
        return t

    def visit_erased_type(self, t: ErasedType) -> Type:
        # This may happen during type inference if some function argument
        # type is a generic callable, and its erased form will appear in inferred
        # constraints, then solver may check subtyping between them, which will trigger
        # unify_generic_callables(), this is why we can get here. Another example is
        # when inferring type of lambda in generic context, the lambda body contains
        # a generic method in generic class.
        return t

    def visit_instance(self, t: Instance) -> Type:
        if len(t.args) == 0:
            return t
        args = self.expand_type_tuple_with_unpack(t.args)

        if isinstance(t.type, FakeInfo):
            # The type checker expands function definitions and bodies
            # if they depend on constrained type variables but the body
            # might contain a tuple type comment (e.g., # type: (int, float)),
            # in which case 't.type' is not yet available.
            #
            # See: https://github.com/python/mypy/issues/16649
            return t.copy_modified(args=args)

        if t.type.fullname == "builtins.tuple":
            # Normalize Tuple[*Tuple[X, ...], ...] -> Tuple[X, ...]
            arg = args[0]
            if isinstance(arg, UnpackType):
                unpacked = get_proper_type(arg.type)
                if isinstance(unpacked, Instance):
                    # TODO: this and similar asserts below may be unsafe because get_proper_type()
                    # may be called during semantic analysis before all invalid types are removed.
                    assert unpacked.type.fullname == "builtins.tuple"
                    args = list(unpacked.args)
        return t.copy_modified(args=args)

    def visit_type_var(self, t: TypeVarType) -> Type:
        # Normally upper bounds can't contain other type variables, the only exception is
        # special type variable Self`0 <: C[T, S], where C is the class where Self is used.
        if t.id.is_self():
            t = t.copy_modified(upper_bound=t.upper_bound.accept(self))
        repl = self.variables.get(t.id, t)
        if isinstance(repl, ProperType) and isinstance(repl, Instance):
            # TODO: do we really need to do this?
            # If I try to remove this special-casing ~40 tests fail on reveal_type().
            return repl.copy_modified(last_known_value=None)
        if isinstance(repl, TypeVarType) and repl.has_default():
            # Expand the default too, guarding against self-referential defaults.
            if self.recursive_tvar_guard is None:
                self.recursive_tvar_guard = {}
            if (tvar_id := repl.id) in self.recursive_tvar_guard:
                return self.recursive_tvar_guard[tvar_id] or repl
            self.recursive_tvar_guard[tvar_id] = None
            repl.default = repl.default.accept(self)
            expanded = repl.accept(self)  # Note: `expanded is repl` may be true.
            repl = repl if isinstance(expanded, TypeVarType) else expanded
            self.recursive_tvar_guard[tvar_id] = repl
        return repl

    def visit_param_spec(self, t: ParamSpecType) -> Type:
        # Set prefix to something empty, so we don't duplicate it below.
        repl = self.variables.get(t.id, t.copy_modified(prefix=Parameters([], [], [])))
        if isinstance(repl, ParamSpecType):
            # ParamSpec substituted by another ParamSpec: concatenate the prefixes.
            return repl.copy_modified(
                flavor=t.flavor,
                prefix=t.prefix.copy_modified(
                    arg_types=self.expand_types(t.prefix.arg_types) + repl.prefix.arg_types,
                    arg_kinds=t.prefix.arg_kinds + repl.prefix.arg_kinds,
                    arg_names=t.prefix.arg_names + repl.prefix.arg_names,
                ),
            )
        elif isinstance(repl, Parameters):
            assert isinstance(t.upper_bound, ProperType) and isinstance(t.upper_bound, Instance)
            if t.flavor == ParamSpecFlavor.BARE:
                return Parameters(
                    self.expand_types(t.prefix.arg_types) + repl.arg_types,
                    t.prefix.arg_kinds + repl.arg_kinds,
                    t.prefix.arg_names + repl.arg_names,
                    variables=[*t.prefix.variables, *repl.variables],
                    imprecise_arg_kinds=repl.imprecise_arg_kinds,
                )
            elif t.flavor == ParamSpecFlavor.ARGS:
                assert all(k.is_positional() for k in t.prefix.arg_kinds)
                return self._possible_callable_varargs(
                    repl, list(t.prefix.arg_types), t.upper_bound
                )
            else:
                assert t.flavor == ParamSpecFlavor.KWARGS
                return self._possible_callable_kwargs(repl, t.upper_bound)
        else:
            # We could encode Any as trivial parameters etc., but it would be too verbose.
            # TODO: assert this is a trivial type, like Any, Never, or object.
            return repl

    @classmethod
    def _possible_callable_varargs(
        cls, repl: Parameters, required_prefix: list[Type], tuple_type: Instance
    ) -> ProperType:
        """Given a callable, extract all parameters that can be passed as `*args`.

        This builds a union of all (possibly variadic) tuples representing all possible
        argument sequences that can be passed positionally. Each such tuple starts with
        all required (pos-only without a default) arguments, followed by some prefix
        of other arguments that can be passed positionally.
        """
        # NOTE: `required_prefix` is mutated below; callers pass a fresh list.
        required_posargs = required_prefix
        if repl.variables:
            # We will tear the callable apart, do not leak type variables
            return tuple_type
        optional_posargs: list[Type] = []
        for kind, name, type in zip(repl.arg_kinds, repl.arg_names, repl.arg_types):
            if kind == ArgKind.ARG_POS and name is None:
                if optional_posargs:
                    # May happen following Unpack expansion without kinds correction
                    required_posargs += optional_posargs
                    optional_posargs = []
                required_posargs.append(type)
            elif kind.is_positional():
                optional_posargs.append(type)
            elif kind == ArgKind.ARG_STAR:
                if isinstance(type, UnpackType):
                    optional_posargs.append(type)
                else:
                    optional_posargs.append(UnpackType(Instance(tuple_type.type, [type])))
                break
        # Union over every possible length of the optional tail.
        return UnionType.make_union(
            [
                TupleType(required_posargs + optional_posargs[:i], fallback=tuple_type)
                for i in range(len(optional_posargs) + 1)
            ]
        )

    @classmethod
    def _possible_callable_kwargs(cls, repl: Parameters, dict_type: Instance) -> ProperType:
        """Given a callable, extract all parameters that can be passed as `**kwargs`.

        If the function only accepts **kwargs, this will be a `dict[str, KwargsValueType]`.
        Otherwise, this will be a `TypedDict` containing all explicit args and ignoring
        `**kwargs` (until PEP 728 `extra_items` is supported). TypedDict entries will
        be required iff the corresponding argument is kw-only and has no default.
        """
        if repl.variables:
            # We will tear the callable apart, do not leak type variables
            return dict_type
        kwargs = {}
        required_names = set()
        extra_items: Type = UninhabitedType()
        for kind, name, type in zip(repl.arg_kinds, repl.arg_names, repl.arg_types):
            if kind == ArgKind.ARG_NAMED and name is not None:
                kwargs[name] = type
                required_names.add(name)
            elif kind == ArgKind.ARG_STAR2:
                # Unpack[TypedDict] is normalized early, it isn't stored as Unpack
                extra_items = type
            elif not kind.is_star() and name is not None:
                kwargs[name] = type
        if not kwargs:
            return Instance(dict_type.type, [dict_type.args[0], extra_items])
        # TODO: when PEP 728 is implemented, pass extra_items below.
        return TypedDictType(kwargs, required_names, set(), fallback=dict_type)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
        # Sometimes solver may need to expand a type variable with (a copy of) itself
        # (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).
        repl = self.variables.get(t.id, t)
        if isinstance(repl, TypeVarTupleType):
            return repl
        elif isinstance(repl, ProperType) and isinstance(repl, (AnyType, UninhabitedType)):
            # Some failed inference scenarios will try to set all type variables to Never.
            # Instead of being picky and require all the callers to wrap them,
            # do this here instead.
            # Note: most cases when this happens are handled in expand unpack below, but
            # in rare cases (e.g. ParamSpec containing Unpack star args) it may be skipped.
            return t.tuple_fallback.copy_modified(args=[repl])
        raise NotImplementedError

    def visit_unpack_type(self, t: UnpackType) -> Type:
        # It is impossible to reasonably implement visit_unpack_type, because
        # unpacking inherently expands to something more like a list of types.
        #
        # Relevant sections that can call unpack should call expand_unpack()
        # instead.
        # However, if the item is a variadic tuple, we can simply carry it over.
        # In particular, if we expand A[*tuple[T, ...]] with substitutions {T: str},
        # it is hard to assert this without getting proper type. Another important
        # example is non-normalized types when called from semanal.py.
        return UnpackType(t.type.accept(self))

    def expand_unpack(self, t: UnpackType) -> list[Type]:
        """Expand an Unpack[Ts] into the list of types it stands for."""
        assert isinstance(t.type, TypeVarTupleType)
        repl = get_proper_type(self.variables.get(t.type.id, t.type))
        if isinstance(repl, UnpackType):
            repl = get_proper_type(repl.type)
        if isinstance(repl, TupleType):
            return repl.items
        elif (
            isinstance(repl, Instance)
            and repl.type.fullname == "builtins.tuple"
            or isinstance(repl, TypeVarTupleType)
        ):
            return [UnpackType(typ=repl)]
        elif isinstance(repl, (AnyType, UninhabitedType)):
            # Replace *Ts = Any with *Ts = *tuple[Any, ...] and same for Never.
            # These types may appear here as a result of user error or failed inference.
            return [UnpackType(t.type.tuple_fallback.copy_modified(args=[repl]))]
        else:
            raise RuntimeError(f"Invalid type replacement to expand: {repl}")

    def visit_parameters(self, t: Parameters) -> Type:
        return t.copy_modified(arg_types=self.expand_types(t.arg_types))

    def interpolate_args_for_unpack(self, t: CallableType, var_arg: UnpackType) -> list[Type]:
        """Expand the argument list of a callable whose *args is an Unpack."""
        star_index = t.arg_kinds.index(ARG_STAR)
        prefix = self.expand_types(t.arg_types[:star_index])
        suffix = self.expand_types(t.arg_types[star_index + 1 :])

        var_arg_type = get_proper_type(var_arg.type)
        new_unpack: Type
        if isinstance(var_arg_type, TupleType):
            # We have something like Unpack[Tuple[Unpack[Ts], X1, X2]]
            expanded_tuple = var_arg_type.accept(self)
            assert isinstance(expanded_tuple, ProperType) and isinstance(expanded_tuple, TupleType)
            expanded_items = expanded_tuple.items
            fallback = var_arg_type.partial_fallback
            new_unpack = UnpackType(TupleType(expanded_items, fallback))
        elif isinstance(var_arg_type, TypeVarTupleType):
            # We have plain Unpack[Ts]
            fallback = var_arg_type.tuple_fallback
            expanded_items = self.expand_unpack(var_arg)
            new_unpack = UnpackType(TupleType(expanded_items, fallback))
        # Since get_proper_type() may be called in semanal.py before callable
        # normalization happens, we need to also handle non-normal cases here.
        elif isinstance(var_arg_type, Instance):
            # we have something like Unpack[Tuple[Any, ...]]
            new_unpack = UnpackType(var_arg.type.accept(self))
        else:
            # We have invalid type in Unpack. This can happen when expanding aliases
            # to Callable[[*Invalid], Ret]
            new_unpack = AnyType(TypeOfAny.from_error, line=var_arg.line, column=var_arg.column)
        return prefix + [new_unpack] + suffix

    def visit_callable_type(self, t: CallableType) -> CallableType:
        param_spec = t.param_spec()
        if param_spec is not None:
            repl = self.variables.get(param_spec.id)
            # If a ParamSpec in a callable type is substituted with a
            # callable type, we can't use normal substitution logic,
            # since ParamSpec is actually split into two components
            # *P.args and **P.kwargs in the original type. Instead, we
            # must expand both of them with all the argument types,
            # kinds and names in the replacement. The return type in
            # the replacement is ignored.
            if isinstance(repl, Parameters):
                # We need to expand both the types in the prefix and the ParamSpec itself
                expanded = t.copy_modified(
                    arg_types=self.expand_types(t.arg_types[:-2]) + repl.arg_types,
                    arg_kinds=t.arg_kinds[:-2] + repl.arg_kinds,
                    arg_names=t.arg_names[:-2] + repl.arg_names,
                    ret_type=t.ret_type.accept(self),
                    type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
                    type_is=(t.type_is.accept(self) if t.type_is is not None else None),
                    imprecise_arg_kinds=(t.imprecise_arg_kinds or repl.imprecise_arg_kinds),
                    variables=[*repl.variables, *t.variables],
                )
                var_arg = expanded.var_arg()
                if var_arg is not None and isinstance(var_arg.typ, UnpackType):
                    # Sometimes we get new unpacks after expanding ParamSpec.
                    expanded.normalize_trivial_unpack()
                return expanded
            elif isinstance(repl, ParamSpecType):
                # We're substituting one ParamSpec for another; this can mean that the prefix
                # changes, e.g. substitute Concatenate[int, P] in place of Q.
                prefix = repl.prefix
                clean_repl = repl.copy_modified(prefix=Parameters([], [], []))
                return t.copy_modified(
                    arg_types=self.expand_types(t.arg_types[:-2])
                    + prefix.arg_types
                    + [
                        clean_repl.with_flavor(ParamSpecFlavor.ARGS),
                        clean_repl.with_flavor(ParamSpecFlavor.KWARGS),
                    ],
                    arg_kinds=t.arg_kinds[:-2] + prefix.arg_kinds + t.arg_kinds[-2:],
                    arg_names=t.arg_names[:-2] + prefix.arg_names + t.arg_names[-2:],
                    ret_type=t.ret_type.accept(self),
                    from_concatenate=t.from_concatenate or bool(repl.prefix.arg_types),
                    imprecise_arg_kinds=(t.imprecise_arg_kinds or prefix.imprecise_arg_kinds),
                )

        var_arg = t.var_arg()
        needs_normalization = False
        if var_arg is not None and isinstance(var_arg.typ, UnpackType):
            needs_normalization = True
            arg_types = self.interpolate_args_for_unpack(t, var_arg.typ)
        else:
            arg_types = self.expand_types(t.arg_types)
        expanded = t.copy_modified(
            arg_types=arg_types,
            ret_type=t.ret_type.accept(self),
            type_guard=(t.type_guard.accept(self) if t.type_guard is not None else None),
            type_is=(t.type_is.accept(self) if t.type_is is not None else None),
        )
        if needs_normalization:
            return expanded.with_normalized_var_args()
        return expanded

    def visit_overloaded(self, t: Overloaded) -> Type:
        items: list[CallableType] = []
        for item in t.items:
            new_item = item.accept(self)
            assert isinstance(new_item, ProperType)
            assert isinstance(new_item, CallableType)
            items.append(new_item)
        return Overloaded(items)

    def expand_type_list_with_unpack(self, typs: list[Type]) -> list[Type]:
        """Expands a list of types that has an unpack."""
        items: list[Type] = []
        for item in typs:
            if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
                items.extend(self.expand_unpack(item))
            else:
                items.append(item.accept(self))
        return items

    def expand_type_tuple_with_unpack(self, typs: tuple[Type, ...]) -> list[Type]:
        """Expands a tuple of types that has an unpack."""
        # Micro-optimization: Specialized variant of expand_type_list_with_unpack
        items: list[Type] = []
        for item in typs:
            if isinstance(item, UnpackType) and isinstance(item.type, TypeVarTupleType):
                items.extend(self.expand_unpack(item))
            else:
                items.append(item.accept(self))
        return items

    def visit_tuple_type(self, t: TupleType) -> Type:
        items = self.expand_type_list_with_unpack(t.items)
        if len(items) == 1:
            # Normalize Tuple[*Tuple[X, ...]] -> Tuple[X, ...]
            item = items[0]
            if isinstance(item, UnpackType):
                unpacked = get_proper_type(item.type)
                if isinstance(unpacked, Instance):
                    # expand_type() may be called during semantic analysis, before invalid unpacks are fixed.
                    if unpacked.type.fullname != "builtins.tuple":
                        return t.partial_fallback.accept(self)
                    if t.partial_fallback.type.fullname != "builtins.tuple":
                        # If it is a subtype (like named tuple) we need to preserve it,
                        # this essentially mimics the logic in tuple_fallback().
                        return t.partial_fallback.accept(self)
                    return unpacked
        fallback = t.partial_fallback.accept(self)
        assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
        return t.copy_modified(items=items, fallback=fallback)

    def visit_typeddict_type(self, t: TypedDictType) -> Type:
        # Cache results: TypedDicts can be nested/shared and re-expansion is wasteful.
        if cached := self.get_cached(t):
            return cached
        fallback = t.fallback.accept(self)
        assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
        result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
        self.set_cached(t, result)
        return result

    def visit_literal_type(self, t: LiteralType) -> Type:
        # TODO: Verify this implementation is correct
        return t

    def visit_union_type(self, t: UnionType) -> Type:
        # Use cache to avoid O(n**2) or worse expansion of types during translation
        # (only for large unions, since caching adds overhead)
        use_cache = len(t.items) > 3
        if use_cache and (cached := self.get_cached(t)):
            return cached

        expanded = self.expand_types(t.items)
        # After substituting for type variables in t.items, some resulting types
        # might be subtypes of others, however calling make_simplified_union()
        # can cause recursion, so we just remove strict duplicates.
        simplified = UnionType.make_union(
            remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
        )
        # This call to get_proper_type() is unfortunate but is required to preserve
        # the invariant that ProperType will stay ProperType after applying expand_type(),
        # otherwise a single item union of a type alias will break it. Note this should not
        # cause infinite recursion since pathological aliases like A = Union[A, B] are
        # banned at the semantic analysis level.
        result = get_proper_type(simplified)

        if use_cache:
            self.set_cached(t, result)
        return result

    def visit_partial_type(self, t: PartialType) -> Type:
        return t

    def visit_type_type(self, t: TypeType) -> Type:
        # TODO: Verify that the new item type is valid (instance or
        # union of instances or Any). Sadly we can't report errors
        # here yet.
        item = t.item.accept(self)
        return TypeType.make_normalized(item, is_type_form=t.is_type_form)

    def visit_type_alias_type(self, t: TypeAliasType) -> Type:
        # Target of the type alias cannot contain type variables (not bound by the type
        # alias itself), so we just expand the arguments.
        if len(t.args) == 0:
            return t
        args = self.expand_type_list_with_unpack(t.args)
        # TODO: normalize if target is Tuple, and args are [*tuple[X, ...]]?
        return t.copy_modified(args=args)

    def expand_types(self, types: Iterable[Type]) -> list[Type]:
        """Expand each type in an iterable, preserving order."""
        a: list[Type] = []
        for t in types:
            a.append(t.accept(self))
        return a
@overload
def expand_self_type(var: Var, typ: ProperType, replacement: ProperType) -> ProperType: ...
@overload
def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type: ...

def expand_self_type(var: Var, typ: Type, replacement: Type) -> Type:
    """Expand appearances of Self type in a variable type."""
    # Properties handle Self on their own; plain attributes use the class-level
    # self_type registered on the enclosing TypeInfo.
    if var.info.self_type is not None and not var.is_property:
        return expand_type(typ, {var.info.self_type.id: replacement})
    return typ
def remove_trivial(types: Iterable[Type]) -> list[Type]:
    """Make trivial simplifications on a list of types without calling is_subtype().

    This makes following simplifications:
        * Remove bottom types (taking into account strict optional setting)
        * Remove everything else if there is an `object`
        * Remove strict duplicate types
    """
    kept: list[Type] = []
    seen: set[ProperType] = set()
    dropped_none = False
    for typ in types:
        proper = get_proper_type(typ)
        if isinstance(proper, UninhabitedType):
            # Never is a no-op member of any union.
            continue
        if not state.strict_optional and isinstance(proper, NoneType):
            dropped_none = True
            continue
        if isinstance(proper, Instance) and proper.type.fullname == "builtins.object":
            # `object` absorbs every other member.
            return [proper]
        if proper not in seen:
            seen.add(proper)
            kept.append(typ)
    if kept:
        return kept
    # Everything was removed: keep the union inhabited in a sensible way.
    return [NoneType()] if dropped_none else [UninhabitedType()]

View file

@ -0,0 +1,612 @@
"""Tool to convert binary mypy cache files (.ff) to JSON (.ff.json).
Usage:
python -m mypy.exportjson .mypy_cache/.../my_module.data.ff
The idea is to make caches introspectable once we've switched to a binary
cache format and removed support for the older JSON cache format.
This is primarily to support existing use cases that need to inspect
cache files, and to support debugging mypy caching issues. This means that
this doesn't necessarily need to be kept 1:1 up to date with changes in the
binary cache format (to simplify maintenance -- we don't want this to slow
down mypy development).
"""
import argparse
import json
import sys
from typing import Any, TypeAlias as _TypeAlias
from librt.internal import ReadBuffer, cache_version
from mypy.cache import CACHE_VERSION, CacheMeta
from mypy.nodes import (
FUNCBASE_FLAGS,
FUNCDEF_FLAGS,
VAR_FLAGS,
ClassDef,
DataclassTransformSpec,
Decorator,
FuncDef,
MypyFile,
OverloadedFuncDef,
OverloadPart,
ParamSpecExpr,
SymbolNode,
SymbolTable,
SymbolTableNode,
TypeAlias,
TypeInfo,
TypeVarExpr,
TypeVarTupleExpr,
Var,
get_flags,
node_kinds,
)
from mypy.types import (
NOT_READY,
AnyType,
CallableType,
ExtraAttrs,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeType,
TypeVarTupleType,
TypeVarType,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
get_proper_type,
)
# A JSON-converted node: either a dict payload or a plain string (cross-reference).
Json: _TypeAlias = dict[str, Any] | str
class Config:
    """Options controlling how cache data is rendered as JSON."""

    def __init__(self, *, implicit_names: bool = True) -> None:
        # When False, implicit module-level dunders (__doc__, __name__, ...)
        # are omitted from symbol tables.
        self.implicit_names = implicit_names
def convert_binary_cache_to_json(data: bytes, *, implicit_names: bool = True) -> Json:
    """Deserialize a binary cache blob into a MypyFile and convert it to JSON form."""
    tree = MypyFile.read(ReadBuffer(data))
    return convert_mypy_file_to_json(tree, Config(implicit_names=implicit_names))
def convert_mypy_file_to_json(self: MypyFile, cfg: Config) -> Json:
    """Convert a deserialized MypyFile into a JSON-friendly dict."""
    # Key order is part of the output format; do not reorder.
    return {
        ".class": "MypyFile",
        "_fullname": self._fullname,
        "names": convert_symbol_table(self.names, cfg),
        "is_stub": self.is_stub,
        "path": self.path,
        "is_partial_stub_package": self.is_partial_stub_package,
        "future_import_flags": sorted(self.future_import_flags),
    }
def convert_symbol_table(self: SymbolTable, cfg: Config) -> Json:
    """Convert a SymbolTable to a JSON-friendly dict, skipping synthetic entries."""
    implicit_module_names = {
        "__spec__",
        "__package__",
        "__file__",
        "__doc__",
        "__annotations__",
        "__name__",
    }
    result: dict[str, Any] = {".class": "SymbolTable"}
    for name, node in self.items():
        # Skip __builtins__: it's a reference to the builtins
        # module that gets added to every module by
        # SemanticAnalyzerPass2.visit_file(), but it shouldn't be
        # accessed by users of the module.
        if name == "__builtins__" or node.no_serialize:
            continue
        if not cfg.implicit_names and name in implicit_module_names:
            continue
        result[name] = convert_symbol_table_node(node, cfg)
    return result
def convert_symbol_table_node(self: SymbolTableNode, cfg: Config) -> Json:
    """Convert one SymbolTableNode; boolean flags are emitted only when non-default."""
    data: dict[str, Any] = {".class": "SymbolTableNode", "kind": node_kinds[self.kind]}
    if self.module_hidden:
        data["module_hidden"] = True
    if not self.module_public:
        data["module_public"] = False
    if self.implicit:
        data["implicit"] = True
    if self.plugin_generated:
        data["plugin_generated"] = True
    # A node is either a cross-reference (string fullname) or an inline node.
    if self.cross_ref:
        data["cross_ref"] = self.cross_ref
    elif self.node is not None:
        data["node"] = convert_symbol_node(self.node, cfg)
    return data
def convert_symbol_node(self: SymbolNode, cfg: Config) -> Json:
    """Dispatch a SymbolNode to the converter for its concrete class."""
    if isinstance(self, TypeInfo):
        # TypeInfo is the only kind that needs the config (nested symbol tables).
        return convert_type_info(self, cfg)
    converters: tuple[tuple[type, Any], ...] = (
        (FuncDef, convert_func_def),
        (OverloadedFuncDef, convert_overloaded_func_def),
        (Decorator, convert_decorator),
        (Var, convert_var),
        (TypeAlias, convert_type_alias),
        (TypeVarExpr, convert_type_var_expr),
        (ParamSpecExpr, convert_param_spec_expr),
        (TypeVarTupleExpr, convert_type_var_tuple_expr),
    )
    for node_class, converter in converters:
        if isinstance(self, node_class):
            return converter(self)
    return {"ERROR": f"{type(self)!r} unrecognized"}
def convert_func_def(self: FuncDef) -> Json:
    """Convert a FuncDef to a JSON-friendly dict."""
    return {
        ".class": "FuncDef",
        "name": self._name,
        "fullname": self._fullname,
        "arg_names": self.arg_names,
        "arg_kinds": [int(x.value) for x in self.arg_kinds],
        "type": None if self.type is None else convert_type(self.type),
        "flags": get_flags(self, FUNCDEF_FLAGS),
        "abstract_status": self.abstract_status,
        # TODO: Do we need expanded, original_def?
        "dataclass_transform_spec": (
            None
            if self.dataclass_transform_spec is None
            else convert_dataclass_transform_spec(self.dataclass_transform_spec)
        ),
        "deprecated": self.deprecated,
        "original_first_arg": self.original_first_arg,
    }
def convert_dataclass_transform_spec(self: DataclassTransformSpec) -> Json:
    """Convert a PEP 681 dataclass_transform spec to a JSON-friendly dict."""
    return {
        "eq_default": self.eq_default,
        "order_default": self.order_default,
        "kw_only_default": self.kw_only_default,
        "frozen_default": self.frozen_default,
        "field_specifiers": list(self.field_specifiers),
    }
def convert_overloaded_func_def(self: OverloadedFuncDef) -> Json:
    """Convert an OverloadedFuncDef (overload variants plus optional impl)."""
    return {
        ".class": "OverloadedFuncDef",
        "items": [convert_overload_part(i) for i in self.items],
        "type": None if self.type is None else convert_type(self.type),
        "fullname": self._fullname,
        "impl": None if self.impl is None else convert_overload_part(self.impl),
        "flags": get_flags(self, FUNCBASE_FLAGS),
        "deprecated": self.deprecated,
        "setter_index": self.setter_index,
    }
def convert_overload_part(self: OverloadPart) -> Json:
    """Convert one overload item: a plain FuncDef or a decorated one."""
    return convert_func_def(self) if isinstance(self, FuncDef) else convert_decorator(self)
def convert_decorator(self: Decorator) -> Json:
    """Convert a Decorator: the wrapped function plus its resulting Var."""
    return {
        ".class": "Decorator",
        "func": convert_func_def(self.func),
        "var": convert_var(self.var),
        "is_overload": self.is_overload,
    }
def convert_var(self: Var) -> Json:
    """Convert a Var; final_value is included only when present."""
    data: dict[str, Any] = {
        ".class": "Var",
        "name": self._name,
        "fullname": self._fullname,
        "type": None if self.type is None else convert_type(self.type),
        "setter_type": None if self.setter_type is None else convert_type(self.setter_type),
        "flags": get_flags(self, VAR_FLAGS),
    }
    if self.final_value is not None:
        data["final_value"] = self.final_value
    return data
def convert_type_info(self: TypeInfo, cfg: Config) -> Json:
    """Serialize a TypeInfo (class object) to a JSON-compatible dict.

    Nested members are serialized via convert_symbol_table; the MRO is
    emitted as the list of fullname strings held in self._mro_refs rather
    than as nested TypeInfos.
    """
    data = {
        ".class": "TypeInfo",
        "module_name": self.module_name,
        "fullname": self.fullname,
        "names": convert_symbol_table(self.names, cfg),
        "defn": convert_class_def(self.defn),
        "abstract_attributes": self.abstract_attributes,
        "type_vars": self.type_vars,
        "has_param_spec_type": self.has_param_spec_type,
        "bases": [convert_type(b) for b in self.bases],
        "mro": self._mro_refs,
        "_promote": [convert_type(p) for p in self._promote],
        "alt_promote": None if self.alt_promote is None else convert_type(self.alt_promote),
        "declared_metaclass": (
            None if self.declared_metaclass is None else convert_type(self.declared_metaclass)
        ),
        "metaclass_type": (
            None if self.metaclass_type is None else convert_type(self.metaclass_type)
        ),
        "tuple_type": None if self.tuple_type is None else convert_type(self.tuple_type),
        "typeddict_type": (
            None if self.typeddict_type is None else convert_typeddict_type(self.typeddict_type)
        ),
        "flags": get_flags(self, TypeInfo.FLAGS),
        "metadata": self.metadata,
        # slots is a set in memory; sorted for deterministic JSON output.
        "slots": sorted(self.slots) if self.slots is not None else None,
        "deletable_attributes": self.deletable_attributes,
        "self_type": convert_type(self.self_type) if self.self_type is not None else None,
        "dataclass_transform_spec": (
            convert_dataclass_transform_spec(self.dataclass_transform_spec)
            if self.dataclass_transform_spec is not None
            else None
        ),
        "deprecated": self.deprecated,
    }
    return data
def convert_class_def(self: ClassDef) -> Json:
    """Serialize a ClassDef header (name and type variables only)."""
    type_var_json = [convert_type(tv) for tv in self.type_vars]
    return {
        ".class": "ClassDef",
        "name": self.name,
        "fullname": self.fullname,
        "type_vars": type_var_json,
    }
def convert_type_alias(self: TypeAlias) -> Json:
    """Serialize a TypeAlias symbol node."""
    result: dict[str, Any] = {".class": "TypeAlias"}
    result["fullname"] = self._fullname
    result["module"] = self.module
    result["target"] = convert_type(self.target)
    result["alias_tvars"] = [convert_type(tv) for tv in self.alias_tvars]
    result["no_args"] = self.no_args
    result["normalized"] = self.normalized
    result["python_3_12_type_alias"] = self.python_3_12_type_alias
    return result
def convert_type_var_expr(self: TypeVarExpr) -> Json:
    """Serialize a TypeVarExpr (the definition site of a TypeVar)."""
    result: dict[str, Any] = {".class": "TypeVarExpr"}
    result["name"] = self._name
    result["fullname"] = self._fullname
    result["values"] = [convert_type(value) for value in self.values]
    result["upper_bound"] = convert_type(self.upper_bound)
    result["default"] = convert_type(self.default)
    result["variance"] = self.variance
    return result
def convert_param_spec_expr(self: ParamSpecExpr) -> Json:
    """Serialize a ParamSpecExpr (the definition site of a ParamSpec)."""
    result: dict[str, Any] = {".class": "ParamSpecExpr"}
    result["name"] = self._name
    result["fullname"] = self._fullname
    result["upper_bound"] = convert_type(self.upper_bound)
    result["default"] = convert_type(self.default)
    result["variance"] = self.variance
    return result
def convert_type_var_tuple_expr(self: TypeVarTupleExpr) -> Json:
    """Serialize a TypeVarTupleExpr (the definition site of a TypeVarTuple)."""
    result: dict[str, Any] = {".class": "TypeVarTupleExpr"}
    result["name"] = self._name
    result["fullname"] = self._fullname
    result["upper_bound"] = convert_type(self.upper_bound)
    result["tuple_fallback"] = convert_type(self.tuple_fallback)
    result["default"] = convert_type(self.default)
    result["variance"] = self.variance
    return result
def convert_type(typ: Type) -> Json:
    """Dispatch serialization over the concrete Type subclass.

    TypeAliasType is checked first by exact type, before get_proper_type()
    is applied, so alias references are serialized as references rather than
    being expanded. An unrecognized type produces an {"ERROR": ...} marker
    instead of raising.
    """
    if type(typ) is TypeAliasType:
        return convert_type_alias_type(typ)
    typ = get_proper_type(typ)
    if isinstance(typ, Instance):
        return convert_instance(typ)
    elif isinstance(typ, AnyType):
        return convert_any_type(typ)
    elif isinstance(typ, NoneType):
        return convert_none_type(typ)
    elif isinstance(typ, UnionType):
        return convert_union_type(typ)
    elif isinstance(typ, TupleType):
        return convert_tuple_type(typ)
    elif isinstance(typ, CallableType):
        return convert_callable_type(typ)
    elif isinstance(typ, Overloaded):
        return convert_overloaded(typ)
    elif isinstance(typ, LiteralType):
        return convert_literal_type(typ)
    elif isinstance(typ, TypeVarType):
        return convert_type_var_type(typ)
    elif isinstance(typ, TypeType):
        return convert_type_type(typ)
    elif isinstance(typ, UninhabitedType):
        return convert_uninhabited_type(typ)
    elif isinstance(typ, UnpackType):
        return convert_unpack_type(typ)
    elif isinstance(typ, ParamSpecType):
        return convert_param_spec_type(typ)
    elif isinstance(typ, TypeVarTupleType):
        return convert_type_var_tuple_type(typ)
    elif isinstance(typ, Parameters):
        return convert_parameters(typ)
    elif isinstance(typ, TypedDictType):
        return convert_typeddict_type(typ)
    elif isinstance(typ, UnboundType):
        return convert_unbound_type(typ)
    return {"ERROR": f"{type(typ)!r} unrecognized"}
def convert_instance(self: Instance) -> Json:
    """Serialize an Instance.

    A plain instance with no args, no last-known literal value and no extra
    attrs collapses to a bare fullname string (or, when the target TypeInfo
    is still NOT_READY, to the unresolved type_ref string). Otherwise a full
    dict is emitted.
    """
    ready = self.type is not NOT_READY
    if not self.args and not self.last_known_value and not self.extra_attrs:
        if ready:
            return self.type.fullname
        elif self.type_ref:
            return self.type_ref
    data: dict[str, Any] = {
        ".class": "Instance",
        "type_ref": self.type.fullname if ready else self.type_ref,
        "args": [convert_type(arg) for arg in self.args],
    }
    if self.last_known_value is not None:
        data["last_known_value"] = convert_type(self.last_known_value)
    # Unlike last_known_value, extra_attrs is always emitted (possibly None).
    data["extra_attrs"] = convert_extra_attrs(self.extra_attrs) if self.extra_attrs else None
    return data
def convert_extra_attrs(self: ExtraAttrs) -> Json:
    """Serialize ExtraAttrs attached to an Instance."""
    attr_json = {name: convert_type(typ) for name, typ in self.attrs.items()}
    return {
        ".class": "ExtraAttrs",
        "attrs": attr_json,
        # Sorted for deterministic output (immutable is a set in memory).
        "immutable": sorted(self.immutable),
        "mod_name": self.mod_name,
    }
def convert_type_alias_type(self: TypeAliasType) -> Json:
    """Serialize a (possibly recursive) reference to a type alias."""
    arg_json = [convert_type(arg) for arg in self.args]
    return {".class": "TypeAliasType", "type_ref": self.type_ref, "args": arg_json}
def convert_any_type(self: AnyType) -> Json:
    """Serialize AnyType, including provenance (why this became Any)."""
    if self.source_any is not None:
        source = convert_type(self.source_any)
    else:
        source = None
    return {
        ".class": "AnyType",
        "type_of_any": self.type_of_any,
        "source_any": source,
        "missing_import_name": self.missing_import_name,
    }
def convert_none_type(self: NoneType) -> Json:
    """Serialize NoneType; it carries no payload beyond its tag."""
    tag: Json = {".class": "NoneType"}
    return tag
def convert_union_type(self: UnionType) -> Json:
    """Serialize a UnionType and whether it was written with PEP 604 syntax."""
    item_json = [convert_type(item) for item in self.items]
    return {
        ".class": "UnionType",
        "items": item_json,
        "uses_pep604_syntax": self.uses_pep604_syntax,
    }
def convert_tuple_type(self: TupleType) -> Json:
    """Serialize a TupleType (fixed-length tuple with per-item types)."""
    item_json = [convert_type(item) for item in self.items]
    return {
        ".class": "TupleType",
        "items": item_json,
        "partial_fallback": convert_type(self.partial_fallback),
        "implicit": self.implicit,
    }
def convert_literal_type(self: LiteralType) -> Json:
    """Serialize a LiteralType (value plus its fallback instance)."""
    return {
        ".class": "LiteralType",
        "value": self.value,
        "fallback": convert_type(self.fallback),
    }
def convert_type_var_type(self: TypeVarType) -> Json:
    """Serialize a bound TypeVarType; meta variables must never reach the cache."""
    assert not self.id.is_meta_var()
    result: dict[str, Any] = {".class": "TypeVarType"}
    result["name"] = self.name
    result["fullname"] = self.fullname
    result["id"] = self.id.raw_id
    result["namespace"] = self.id.namespace
    result["values"] = [convert_type(value) for value in self.values]
    result["upper_bound"] = convert_type(self.upper_bound)
    result["default"] = convert_type(self.default)
    result["variance"] = self.variance
    return result
def convert_callable_type(self: CallableType) -> Json:
    """Serialize a CallableType (one concrete function signature)."""
    result: dict[str, Any] = {".class": "CallableType"}
    result["arg_types"] = [convert_type(arg) for arg in self.arg_types]
    result["arg_kinds"] = [int(kind.value) for kind in self.arg_kinds]
    result["arg_names"] = self.arg_names
    result["ret_type"] = convert_type(self.ret_type)
    result["fallback"] = convert_type(self.fallback)
    result["name"] = self.name
    # The definition node is intentionally not serialized (only used for
    # error messages).
    result["variables"] = [convert_type(var) for var in self.variables]
    result["is_ellipsis_args"] = self.is_ellipsis_args
    result["implicit"] = self.implicit
    result["is_bound"] = self.is_bound
    result["type_guard"] = convert_type(self.type_guard) if self.type_guard is not None else None
    result["type_is"] = convert_type(self.type_is) if self.type_is is not None else None
    result["from_concatenate"] = self.from_concatenate
    result["imprecise_arg_kinds"] = self.imprecise_arg_kinds
    result["unpack_kwargs"] = self.unpack_kwargs
    return result
def convert_overloaded(self: Overloaded) -> Json:
    """Serialize an Overloaded type (the full list of overload signatures)."""
    item_json = [convert_type(item) for item in self.items]
    return {".class": "Overloaded", "items": item_json}
def convert_type_type(self: TypeType) -> Json:
    """Serialize Type[X] (the type of a class object)."""
    item_json = convert_type(self.item)
    return {".class": "TypeType", "item": item_json}
def convert_uninhabited_type(self: UninhabitedType) -> Json:
    """Serialize the bottom type; it carries no payload beyond its tag."""
    tag: Json = {".class": "UninhabitedType"}
    return tag
def convert_unpack_type(self: UnpackType) -> Json:
    """Serialize Unpack[X]."""
    inner_json = convert_type(self.type)
    return {".class": "UnpackType", "type": inner_json}
def convert_param_spec_type(self: ParamSpecType) -> Json:
    """Serialize a ParamSpecType; meta variables must never reach the cache."""
    assert not self.id.is_meta_var()
    result: dict[str, Any] = {".class": "ParamSpecType"}
    result["name"] = self.name
    result["fullname"] = self.fullname
    result["id"] = self.id.raw_id
    result["namespace"] = self.id.namespace
    result["flavor"] = self.flavor
    result["upper_bound"] = convert_type(self.upper_bound)
    result["default"] = convert_type(self.default)
    result["prefix"] = convert_type(self.prefix)
    return result
def convert_type_var_tuple_type(self: TypeVarTupleType) -> Json:
    """Serialize a TypeVarTupleType; meta variables must never reach the cache."""
    assert not self.id.is_meta_var()
    result: dict[str, Any] = {".class": "TypeVarTupleType"}
    result["name"] = self.name
    result["fullname"] = self.fullname
    result["id"] = self.id.raw_id
    result["namespace"] = self.id.namespace
    result["upper_bound"] = convert_type(self.upper_bound)
    result["tuple_fallback"] = convert_type(self.tuple_fallback)
    result["default"] = convert_type(self.default)
    result["min_len"] = self.min_len
    return result
def convert_parameters(self: Parameters) -> Json:
    """Serialize a Parameters type (the argument list bound to a ParamSpec)."""
    result: dict[str, Any] = {".class": "Parameters"}
    result["arg_types"] = [convert_type(arg) for arg in self.arg_types]
    result["arg_kinds"] = [int(kind.value) for kind in self.arg_kinds]
    result["arg_names"] = self.arg_names
    result["variables"] = [convert_type(var) for var in self.variables]
    result["imprecise_arg_kinds"] = self.imprecise_arg_kinds
    return result
def convert_typeddict_type(self: TypedDictType) -> Json:
    """Serialize a TypedDictType; items become [name, type] pairs."""
    item_json = [[name, convert_type(typ)] for name, typ in self.items.items()]
    return {
        ".class": "TypedDictType",
        "items": item_json,
        # Key sets are sorted for deterministic output.
        "required_keys": sorted(self.required_keys),
        "readonly_keys": sorted(self.readonly_keys),
        "fallback": convert_type(self.fallback),
    }
def convert_unbound_type(self: UnboundType) -> Json:
    """Serialize an UnboundType (a type not yet semantically analyzed)."""
    arg_json = [convert_type(arg) for arg in self.args]
    return {
        ".class": "UnboundType",
        "name": self.name,
        "args": arg_json,
        "expr": self.original_str_expr,
        "expr_fallback": self.original_str_fallback,
    }
def convert_binary_cache_meta_to_json(data: bytes, data_file: str) -> Json:
    """Deserialize a binary .meta.ff blob and re-emit it as a JSON dict.

    data_file is the path of the companion .data.ff file, passed through to
    CacheMeta.read. The first two bytes of data are format/version tags.
    """
    # NOTE(review): asserts are stripped under -O; acceptable here since this
    # is a debugging tool, but real validation would raise instead.
    assert (
        data[0] == cache_version() and data[1] == CACHE_VERSION
    ), "Cache file created by an incompatible mypy version"
    meta = CacheMeta.read(ReadBuffer(data[2:]), data_file)
    assert meta is not None, f"Error reading meta cache file associated with {data_file}"
    return {
        "id": meta.id,
        "path": meta.path,
        "mtime": meta.mtime,
        "size": meta.size,
        "hash": meta.hash,
        "data_mtime": meta.data_mtime,
        "dependencies": meta.dependencies,
        "suppressed": meta.suppressed,
        "options": meta.options,
        "dep_prios": meta.dep_prios,
        "dep_lines": meta.dep_lines,
        # Binary digests are hex-encoded so the result is JSON-serializable.
        "dep_hashes": [dep.hex() for dep in meta.dep_hashes],
        "interface_hash": meta.interface_hash.hex(),
        "version_id": meta.version_id,
        "ignore_all": meta.ignore_all,
        "plugin_data": meta.plugin_data,
    }
def main() -> None:
    """CLI entry point: convert each given .data.ff/.meta.ff file to JSON.

    Each input file produces a sibling file with an extra .json extension.
    Exits with an error message on unrecognized extensions.
    """
    parser = argparse.ArgumentParser(
        description="Convert binary cache files to JSON. "
        "Create files in the same directory with extra .json extension."
    )
    parser.add_argument(
        "path", nargs="+", help="mypy cache data file to convert (.data.ff extension)"
    )
    args = parser.parse_args()
    fnams: list[str] = args.path
    for fnam in fnams:
        # The extension decides which converter to use.
        if fnam.endswith(".data.ff"):
            is_data = True
        elif fnam.endswith(".meta.ff"):
            is_data = False
        else:
            sys.exit(f"error: Expected .data.ff or .meta.ff extension, but got {fnam}")
        with open(fnam, "rb") as f:
            data = f.read()
        if is_data:
            json_data = convert_binary_cache_to_json(data)
        else:
            # Meta conversion needs the path of the companion data file.
            data_file = fnam.removesuffix(".meta.ff") + ".data.ff"
            json_data = convert_binary_cache_meta_to_json(data, data_file)
        new_fnam = fnam + ".json"
        with open(new_fnam, "w") as f:
            json.dump(json_data, f)
        print(f"{fnam} -> {new_fnam}")
# Allow invoking the converter directly as a script.
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,289 @@
"""Translate an Expression to a Type value."""
from __future__ import annotations
from collections.abc import Callable
from mypy.fastparse import parse_type_string
from mypy.nodes import (
MISSING_FALLBACK,
BytesExpr,
CallExpr,
ComplexExpr,
Context,
DictExpr,
EllipsisExpr,
Expression,
FloatExpr,
IndexExpr,
IntExpr,
ListExpr,
MemberExpr,
NameExpr,
OpExpr,
RefExpr,
StarExpr,
StrExpr,
SymbolTableNode,
TupleExpr,
UnaryExpr,
get_member_expr_fullname,
)
from mypy.options import Options
from mypy.types import (
ANNOTATED_TYPE_NAMES,
AnyType,
CallableArgument,
EllipsisType,
Instance,
ProperType,
RawExpressionType,
Type,
TypedDictType,
TypeList,
TypeOfAny,
UnboundType,
UnionType,
UnpackType,
)
class TypeTranslationError(Exception):
    """Exception raised when an expression is not valid as a type.

    Raised by expr_to_unanalyzed_type (and its helpers) when an expression
    cannot be interpreted as an unanalyzed type.
    """
def _extract_argument_name(expr: Expression) -> str | None:
    """Extract an argument name from an Arg(...)-style constructor argument.

    A string literal gives that name; a bare ``None`` means the argument is
    anonymous. Anything else is not a valid name expression.
    """
    if isinstance(expr, StrExpr):
        return expr.value
    if isinstance(expr, NameExpr) and expr.name == "None":
        return None
    raise TypeTranslationError()
def expr_to_unanalyzed_type(
    expr: Expression,
    options: Options,
    allow_new_syntax: bool = False,
    _parent: Expression | None = None,
    allow_unpack: bool = False,
    lookup_qualified: Callable[[str, Context], SymbolTableNode | None] | None = None,
) -> ProperType:
    """Translate an expression to the corresponding type.

    The result is not semantically analyzed. It can be UnboundType or TypeList.
    Raise TypeTranslationError if the expression cannot represent a type.

    If lookup_qualified is not provided, the expression is expected to be semantically
    analyzed.

    If allow_new_syntax is True, allow all type syntax independent of the target
    Python version (used in stubs).

    # TODO: a lot of code here is duplicated in fastparse.py, refactor this.
    """
    # The `parent` parameter is used in recursive calls to provide context for
    # understanding whether an CallableArgument is ok.
    name: str | None = None
    if isinstance(expr, NameExpr):
        name = expr.name
        if name == "True":
            return RawExpressionType(True, "builtins.bool", line=expr.line, column=expr.column)
        elif name == "False":
            return RawExpressionType(False, "builtins.bool", line=expr.line, column=expr.column)
        else:
            return UnboundType(name, line=expr.line, column=expr.column)
    elif isinstance(expr, MemberExpr):
        fullname = get_member_expr_fullname(expr)
        if fullname:
            return UnboundType(fullname, line=expr.line, column=expr.column)
        else:
            raise TypeTranslationError()
    # Subscripted type, e.g. List[int] or Annotated[...].
    elif isinstance(expr, IndexExpr):
        base = expr_to_unanalyzed_type(
            expr.base, options, allow_new_syntax, expr, lookup_qualified=lookup_qualified
        )
        if isinstance(base, UnboundType):
            if base.args:
                raise TypeTranslationError()
            if isinstance(expr.index, TupleExpr):
                args = expr.index.items
            else:
                args = [expr.index]
            if isinstance(expr.base, RefExpr):
                # Check if the type is Annotated[...]. For this we need the fullname,
                # which must be looked up if the expression hasn't been semantically analyzed.
                base_fullname = None
                if lookup_qualified is not None:
                    sym = lookup_qualified(base.name, expr)
                    if sym and sym.node:
                        base_fullname = sym.node.fullname
                else:
                    base_fullname = expr.base.fullname
                if base_fullname is not None and base_fullname in ANNOTATED_TYPE_NAMES:
                    # TODO: this is not the optimal solution as we are basically getting rid
                    # of the Annotation definition and only returning the type information,
                    # losing all the annotations.
                    return expr_to_unanalyzed_type(
                        args[0], options, allow_new_syntax, expr, lookup_qualified=lookup_qualified
                    )
            base.args = tuple(
                expr_to_unanalyzed_type(
                    arg,
                    options,
                    allow_new_syntax,
                    expr,
                    allow_unpack=True,
                    lookup_qualified=lookup_qualified,
                )
                for arg in args
            )
            if not base.args:
                base.empty_tuple_index = True
            return base
        else:
            raise TypeTranslationError()
    # PEP 604 union syntax: X | Y.
    elif (
        isinstance(expr, OpExpr)
        and expr.op == "|"
        and ((options.python_version >= (3, 10)) or allow_new_syntax)
    ):
        return UnionType(
            [
                expr_to_unanalyzed_type(
                    expr.left, options, allow_new_syntax, lookup_qualified=lookup_qualified
                ),
                expr_to_unanalyzed_type(
                    expr.right, options, allow_new_syntax, lookup_qualified=lookup_qualified
                ),
            ],
            uses_pep604_syntax=True,
        )
    # Callable argument constructor (e.g. Arg(int, "x")); only valid directly
    # inside a list expression.
    elif isinstance(expr, CallExpr) and isinstance(_parent, ListExpr):
        c = expr.callee
        names = []
        # Go through the dotted member expr chain to get the full arg
        # constructor name to look up
        while True:
            if isinstance(c, NameExpr):
                names.append(c.name)
                break
            elif isinstance(c, MemberExpr):
                names.append(c.name)
                c = c.expr
            else:
                raise TypeTranslationError()
        arg_const = ".".join(reversed(names))
        # Go through the constructor args to get its name and type.
        name = None
        default_type = AnyType(TypeOfAny.unannotated)
        typ: Type = default_type
        for i, arg in enumerate(expr.args):
            if expr.arg_names[i] is not None:
                if expr.arg_names[i] == "name":
                    if name is not None:
                        # Two names
                        raise TypeTranslationError()
                    name = _extract_argument_name(arg)
                    continue
                elif expr.arg_names[i] == "type":
                    if typ is not default_type:
                        # Two types
                        raise TypeTranslationError()
                    typ = expr_to_unanalyzed_type(
                        arg, options, allow_new_syntax, expr, lookup_qualified=lookup_qualified
                    )
                    continue
                else:
                    raise TypeTranslationError()
            elif i == 0:
                typ = expr_to_unanalyzed_type(
                    arg, options, allow_new_syntax, expr, lookup_qualified=lookup_qualified
                )
            elif i == 1:
                name = _extract_argument_name(arg)
            else:
                raise TypeTranslationError()
        return CallableArgument(typ, name, arg_const, expr.line, expr.column)
    elif isinstance(expr, ListExpr):
        return TypeList(
            [
                expr_to_unanalyzed_type(
                    t,
                    options,
                    allow_new_syntax,
                    expr,
                    allow_unpack=True,
                    lookup_qualified=lookup_qualified,
                )
                for t in expr.items
            ],
            line=expr.line,
            column=expr.column,
        )
    elif isinstance(expr, StrExpr):
        return parse_type_string(expr.value, "builtins.str", expr.line, expr.column)
    elif isinstance(expr, BytesExpr):
        return parse_type_string(expr.value, "builtins.bytes", expr.line, expr.column)
    elif isinstance(expr, UnaryExpr):
        # Only +/- applied to an int literal is meaningful (e.g. Literal[-1]).
        typ = expr_to_unanalyzed_type(
            expr.expr, options, allow_new_syntax, lookup_qualified=lookup_qualified
        )
        if isinstance(typ, RawExpressionType):
            if isinstance(typ.literal_value, int):
                if expr.op == "-":
                    typ.literal_value *= -1
                    return typ
                elif expr.op == "+":
                    return typ
        raise TypeTranslationError()
    elif isinstance(expr, IntExpr):
        return RawExpressionType(expr.value, "builtins.int", line=expr.line, column=expr.column)
    elif isinstance(expr, FloatExpr):
        # Floats are not valid parameters for RawExpressionType , so we just
        # pass in 'None' for now. We'll report the appropriate error at a later stage.
        return RawExpressionType(None, "builtins.float", line=expr.line, column=expr.column)
    elif isinstance(expr, ComplexExpr):
        # Same thing as above with complex numbers.
        return RawExpressionType(None, "builtins.complex", line=expr.line, column=expr.column)
    elif isinstance(expr, EllipsisExpr):
        return EllipsisType(expr.line)
    elif allow_unpack and isinstance(expr, StarExpr):
        return UnpackType(
            expr_to_unanalyzed_type(
                expr.expr, options, allow_new_syntax, lookup_qualified=lookup_qualified
            ),
            from_star_syntax=True,
        )
    # Inline TypedDict syntax: {"x": int, **Other}.
    elif isinstance(expr, DictExpr):
        if not expr.items:
            raise TypeTranslationError()
        items: dict[str, Type] = {}
        extra_items_from = []
        for item_name, value in expr.items:
            if not isinstance(item_name, StrExpr):
                if item_name is None:
                    # A ** entry: its items are merged in from another TypedDict.
                    extra_items_from.append(
                        expr_to_unanalyzed_type(
                            value,
                            options,
                            allow_new_syntax,
                            expr,
                            lookup_qualified=lookup_qualified,
                        )
                    )
                    continue
                raise TypeTranslationError()
            items[item_name.value] = expr_to_unanalyzed_type(
                value, options, allow_new_syntax, expr, lookup_qualified=lookup_qualified
            )
        result = TypedDictType(
            items, set(), set(), Instance(MISSING_FALLBACK, ()), expr.line, expr.column
        )
        result.extra_items_from = extra_items_from
        return result
    else:
        raise TypeTranslationError()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,257 @@
"""Routines for finding the sources that mypy will check"""
from __future__ import annotations
import functools
import os
from collections.abc import Sequence
from typing import Final
from mypy.fscache import FileSystemCache
from mypy.modulefinder import (
PYTHON_EXTENSIONS,
BuildSource,
matches_exclude,
matches_gitignore,
mypy_path,
)
from mypy.options import Options
PY_EXTENSIONS: Final = tuple(PYTHON_EXTENSIONS)
class InvalidSourceList(Exception):
    """Exception indicating a problem in the list of sources given to mypy.

    Raised, e.g., for empty source directories or invalid package names.
    """
def create_source_list(
    paths: Sequence[str],
    options: Options,
    fscache: FileSystemCache | None = None,
    allow_empty_dir: bool = False,
) -> list[BuildSource]:
    """From a list of source files/directories, makes a list of BuildSources.

    Raises InvalidSourceList on errors.
    """
    fscache = fscache or FileSystemCache()
    finder = SourceFinder(fscache, options)
    sources = []
    for path in paths:
        path = os.path.normpath(path)
        if path.endswith(PY_EXTENSIONS):
            # A .py/.pyi file: derive its module name by crawling upwards.
            # Can raise InvalidSourceList if a directory doesn't have a valid module name.
            name, base_dir = finder.crawl_up(path)
            sources.append(BuildSource(path, name, None, base_dir))
        elif fscache.isdir(path):
            sub_sources = finder.find_sources_in_dir(path)
            if not sub_sources and not allow_empty_dir:
                raise InvalidSourceList(f"There are no .py[i] files in directory '{path}'")
            sources.extend(sub_sources)
        else:
            # Anything else (e.g. an extensionless script) is treated as a
            # module only when scripts_are_modules is enabled.
            mod = os.path.basename(path) if options.scripts_are_modules else None
            sources.append(BuildSource(path, mod, None))
    return sources
def keyfunc(name: str) -> tuple[bool, int, str]:
    """Determines sort order for directory listing.

    The desirable properties are:

    1) foo < foo.pyi < foo.py
    2) __init__.py[i] < foo
    """
    base, suffix = os.path.splitext(name)
    try:
        rank = PY_EXTENSIONS.index(suffix)
    except ValueError:
        # Not a Python file: sort by full name, before Python files with the
        # same base (rank -1 sorts before any extension index).
        return (base != "__init__", -1, name)
    return (base != "__init__", rank, base)
def normalise_package_base(root: str) -> str:
    """Normalize a package base directory to an absolute path without a
    trailing separator, so bases can be compared as plain strings."""
    absolute = os.path.abspath(root or os.curdir)
    return absolute.removesuffix(os.sep)
def get_explicit_package_bases(options: Options) -> list[str] | None:
    """Returns explicit package bases to use if the option is enabled, or None if disabled.

    We currently use MYPYPATH and the current directory as the package bases. In the future,
    when --namespace-packages is the default could also use the values passed with the
    --package-root flag, see #9632.

    Values returned are normalised so we can use simple string comparisons in
    SourceFinder.is_explicit_package_base
    """
    if not options.explicit_package_bases:
        return None
    # MYPYPATH (env + option) entries plus the current working directory.
    roots = mypy_path() + options.mypy_path + [os.getcwd()]
    return [normalise_package_base(root) for root in roots]
class SourceFinder:
    """Discovers source files under directories and maps files to module ids.

    Results of the upward crawl are memoized per instance (_crawl_up_helper).
    """

    def __init__(self, fscache: FileSystemCache, options: Options) -> None:
        self.fscache = fscache
        self.explicit_package_bases = get_explicit_package_bases(options)
        self.namespace_packages = options.namespace_packages
        self.exclude = options.exclude
        self.exclude_gitignore = options.exclude_gitignore
        self.verbosity = options.verbosity

    def is_explicit_package_base(self, path: str) -> bool:
        """Whether path is one of the configured explicit package bases."""
        assert self.explicit_package_bases
        return normalise_package_base(path) in self.explicit_package_bases

    def find_sources_in_dir(self, path: str) -> list[BuildSource]:
        """Recursively collect BuildSources for all .py[i] files under path,
        honoring exclude patterns and (optionally) .gitignore files."""
        sources = []
        seen: set[str] = set()
        names = sorted(self.fscache.listdir(path), key=keyfunc)
        for name in names:
            # Skip certain names altogether
            if name in ("__pycache__", "site-packages", "node_modules") or name.startswith("."):
                continue
            subpath = os.path.join(path, name)
            if matches_exclude(subpath, self.exclude, self.fscache, self.verbosity >= 2):
                continue
            if self.exclude_gitignore and matches_gitignore(
                subpath, self.fscache, self.verbosity >= 2
            ):
                continue
            if self.fscache.isdir(subpath):
                sub_sources = self.find_sources_in_dir(subpath)
                if sub_sources:
                    seen.add(name)
                    sources.extend(sub_sources)
            else:
                stem, suffix = os.path.splitext(name)
                # keyfunc ordering makes .pyi win over .py for the same stem.
                if stem not in seen and suffix in PY_EXTENSIONS:
                    seen.add(stem)
                    module, base_dir = self.crawl_up(subpath)
                    sources.append(BuildSource(subpath, module, None, base_dir))
        return sources

    def crawl_up(self, path: str) -> tuple[str, str]:
        """Given a .py[i] filename, return module and base directory.

        For example, given "xxx/yyy/foo/bar.py", we might return something like:
        ("foo.bar", "xxx/yyy")

        If namespace packages is off, we crawl upwards until we find a directory without
        an __init__.py

        If namespace packages is on, we crawl upwards until the nearest explicit base directory.
        Failing that, we return one past the highest directory containing an __init__.py

        We won't crawl past directories with invalid package names.
        The base directory returned is an absolute path.
        """
        path = os.path.abspath(path)
        parent, filename = os.path.split(path)
        module_name = strip_py(filename) or filename
        parent_module, base_dir = self.crawl_up_dir(parent)
        if module_name == "__init__":
            return parent_module, base_dir
        # Note that module_name might not actually be a valid identifier, but that's okay
        # Ignoring this possibility sidesteps some search path confusion
        module = module_join(parent_module, module_name)
        return module, base_dir

    def crawl_up_dir(self, dir: str) -> tuple[str, str]:
        """Like _crawl_up_helper, but falls back to ("", dir) when the crawl
        finds no enclosing package or explicit base."""
        return self._crawl_up_helper(dir) or ("", dir)

    @functools.lru_cache  # noqa: B019
    def _crawl_up_helper(self, dir: str) -> tuple[str, str] | None:
        """Given a directory, maybe returns module and base directory.

        We return a non-None value if we were able to find something clearly intended as a base
        directory (as adjudicated by being an explicit base directory or by containing a package
        with __init__.py).

        This distinction is necessary for namespace packages, so that we know when to treat
        ourselves as a subpackage.
        """
        # stop crawling if we're an explicit base directory
        if self.explicit_package_bases is not None and self.is_explicit_package_base(dir):
            return "", dir
        parent, name = os.path.split(dir)
        name = name.removesuffix("-stubs")  # PEP-561 stub-only directory
        # recurse if there's an __init__.py
        init_file = self.get_init_file(dir)
        if init_file is not None:
            if not name.isidentifier():
                # in most cases the directory name is invalid, we'll just stop crawling upwards
                # but if there's an __init__.py in the directory, something is messed up
                raise InvalidSourceList(
                    f"{name} contains {os.path.basename(init_file)} "
                    "but is not a valid Python package name"
                )
            # we're definitely a package, so we always return a non-None value
            mod_prefix, base_dir = self.crawl_up_dir(parent)
            return module_join(mod_prefix, name), base_dir
        # stop crawling if we're out of path components or our name is an invalid identifier
        if not name or not parent or not name.isidentifier():
            return None
        # stop crawling if namespace packages is off (since we don't have an __init__.py)
        if not self.namespace_packages:
            return None
        # at this point: namespace packages is on, we don't have an __init__.py and we're not an
        # explicit base directory
        result = self._crawl_up_helper(parent)
        if result is None:
            # we're not an explicit base directory and we don't have an __init__.py
            # and none of our parents are either, so return
            return None
        # one of our parents was an explicit base directory or had an __init__.py, so we're
        # definitely a subpackage! chain our name to the module.
        mod_prefix, base_dir = result
        return module_join(mod_prefix, name), base_dir

    def get_init_file(self, dir: str) -> str | None:
        """Check whether a directory contains a file named __init__.py[i].

        If so, return the file's name (with dir prefixed). If not, return None.

        This prefers .pyi over .py (because of the ordering of PY_EXTENSIONS).
        """
        for ext in PY_EXTENSIONS:
            f = os.path.join(dir, "__init__" + ext)
            if self.fscache.isfile(f):
                return f
            if ext == ".py" and self.fscache.init_under_package_root(f):
                return f
        return None
def module_join(parent: str, child: str) -> str:
    """Join module ids, accounting for a possibly empty parent."""
    return f"{parent}.{child}" if parent else child
def strip_py(arg: str) -> str | None:
    """Strip a trailing .py or .pyi suffix.

    Return None if no such suffix is found.
    """
    return next(
        (arg.removesuffix(ext) for ext in PY_EXTENSIONS if arg.endswith(ext)), None
    )

View file

@ -0,0 +1,444 @@
"""Fix up various things after deserialization."""
from __future__ import annotations
from typing import Any, Final
from mypy.lookup import lookup_fully_qualified
from mypy.nodes import (
Block,
ClassDef,
Decorator,
FuncDef,
MypyFile,
OverloadedFuncDef,
ParamSpecExpr,
SymbolTable,
TypeAlias,
TypeInfo,
TypeVarExpr,
TypeVarTupleExpr,
Var,
)
from mypy.types import (
NOT_READY,
AnyType,
CallableType,
Instance,
LiteralType,
Overloaded,
Parameters,
ParamSpecType,
ProperType,
TupleType,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UnionType,
UnpackType,
)
from mypy.visitor import NodeVisitor
# N.B: we do a allow_missing fixup when fixing up a fine-grained
# incremental cache load (since there may be cross-refs into deleted
# modules)
def fixup_module(tree: MypyFile, modules: dict[str, MypyFile], allow_missing: bool) -> None:
    """Resolve cross-references and patch up a freshly deserialized module tree."""
    fixer = NodeFixer(modules, allow_missing)
    fixer.visit_symbol_table(tree.names, tree.fullname)
class NodeFixer(NodeVisitor[None]):
    """AST visitor that repairs a deserialized tree in place.

    Resolves serialized cross-references to live nodes from `modules` and
    delegates type fixups to a companion TypeFixer.
    """

    # The class currently being visited (set while inside visit_type_info),
    # so member functions/vars can point back to it via .info.
    current_info: TypeInfo | None = None

    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        self.modules = modules
        self.allow_missing = allow_missing
        self.type_fixer = TypeFixer(self.modules, allow_missing)

    # NOTE: This method isn't (yet) part of the NodeVisitor API.
    def visit_type_info(self, info: TypeInfo) -> None:
        """Fix up a TypeInfo: its defn, members, bases, metaclass, special
        aliases and MRO (rebuilt from the serialized fullname refs)."""
        save_info = self.current_info
        try:
            self.current_info = info
            if info.defn:
                info.defn.accept(self)
            if info.names:
                self.visit_symbol_table(info.names, info.fullname)
            if info.bases:
                for base in info.bases:
                    base.accept(self.type_fixer)
            if info._promote:
                for p in info._promote:
                    p.accept(self.type_fixer)
            if info.tuple_type:
                info.tuple_type.accept(self.type_fixer)
                info.update_tuple_type(info.tuple_type)
                if info.special_alias:
                    info.special_alias.alias_tvars = list(info.defn.type_vars)
                    for i, t in enumerate(info.defn.type_vars):
                        if isinstance(t, TypeVarTupleType):
                            info.special_alias.tvar_tuple_index = i
            if info.typeddict_type:
                info.typeddict_type.accept(self.type_fixer)
                info.update_typeddict_type(info.typeddict_type)
                if info.special_alias:
                    info.special_alias.alias_tvars = list(info.defn.type_vars)
                    for i, t in enumerate(info.defn.type_vars):
                        if isinstance(t, TypeVarTupleType):
                            info.special_alias.tvar_tuple_index = i
            if info.declared_metaclass:
                info.declared_metaclass.accept(self.type_fixer)
            if info.metaclass_type:
                info.metaclass_type.accept(self.type_fixer)
            if info.self_type:
                info.self_type.accept(self.type_fixer)
            if info.alt_promote:
                info.alt_promote.accept(self.type_fixer)
                instance = Instance(info, [])
                # Hack: We may also need to add a backwards promotion (from int to native int),
                # since it might not be serialized.
                if instance not in info.alt_promote.type._promote:
                    info.alt_promote.type._promote.append(instance)
            if info._mro_refs:
                info.mro = [
                    lookup_fully_qualified_typeinfo(
                        self.modules, name, allow_missing=self.allow_missing
                    )
                    for name in info._mro_refs
                ]
                info._mro_refs = None
        finally:
            self.current_info = save_info

    # NOTE: This method *definitely* isn't part of the NodeVisitor API.
    def visit_symbol_table(self, symtab: SymbolTable, table_fullname: str) -> None:
        """Resolve cross-references in a symbol table, then visit the
        resolved nodes. May mutate symtab entries in place."""
        # Copy the items because we may mutate symtab.
        for key in list(symtab):
            value = symtab[key]
            cross_ref = value.cross_ref
            if cross_ref is not None:  # Fix up cross-reference.
                value.cross_ref = None
                if cross_ref in self.modules:
                    value.node = self.modules[cross_ref]
                else:
                    stnode = lookup_fully_qualified(
                        cross_ref, self.modules, raise_on_missing=not self.allow_missing
                    )
                    if stnode is not None:
                        if stnode is value:
                            # The node seems to refer to itself, which can mean that
                            # the target is a deleted submodule of the current module,
                            # and thus lookup falls back to the symbol table of the parent
                            # package. Here's how this may happen:
                            #
                            #   pkg/__init__.py:
                            #     from pkg import sub
                            #
                            # Now if pkg.sub is deleted, the pkg.sub symbol table entry
                            # appears to refer to itself. Replace the entry with a
                            # placeholder to avoid a crash. We can't delete the entry,
                            # as it would stop dependency propagation.
                            value.node = Var(key + "@deleted")
                        else:
                            assert stnode.node is not None, (table_fullname + "." + key, cross_ref)
                            value.node = stnode.node
                    elif not self.allow_missing:
                        assert False, f"Could not find cross-ref {cross_ref}"
                    else:
                        # We have a missing crossref in allow missing mode, need to put something
                        value.node = missing_info(self.modules)
            else:
                if isinstance(value.node, TypeInfo):
                    # TypeInfo has no accept(). TODO: Add it?
                    self.visit_type_info(value.node)
                elif value.node is not None:
                    value.node.accept(self)
                else:
                    assert False, f"Unexpected empty node {key!r}: {value}"

    def visit_func_def(self, func: FuncDef) -> None:
        """Attach the enclosing class and fix up the function's type."""
        if self.current_info is not None:
            func.info = self.current_info
        if func.type is not None:
            func.type.accept(self.type_fixer)
            if isinstance(func.type, CallableType):
                func.type.definition = func

    def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
        """Fix up each overload item, the impl, and the overall type."""
        if self.current_info is not None:
            o.info = self.current_info
        if o.type:
            o.type.accept(self.type_fixer)
        for item in o.items:
            item.accept(self)
        if o.impl:
            o.impl.accept(self)
        if isinstance(o.type, Overloaded):
            # For error messages we link the original definition for each item.
            for typ, item in zip(o.type.items, o.items):
                typ.definition = item

    def visit_decorator(self, d: Decorator) -> None:
        """Fix up the decorated function, its Var, and the decorator exprs."""
        if self.current_info is not None:
            d.var.info = self.current_info
        if d.func:
            d.func.accept(self)
        if d.var:
            d.var.accept(self)
        for node in d.decorators:
            node.accept(self)
        typ = d.var.type
        if isinstance(typ, ProperType) and isinstance(typ, CallableType):
            typ.definition = d.func

    def visit_class_def(self, c: ClassDef) -> None:
        """Fix up the class's type variables."""
        for v in c.type_vars:
            v.accept(self.type_fixer)

    def visit_type_var_expr(self, tv: TypeVarExpr) -> None:
        """Fix up a TypeVar definition's values, bound and default."""
        for value in tv.values:
            value.accept(self.type_fixer)
        tv.upper_bound.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_paramspec_expr(self, p: ParamSpecExpr) -> None:
        """Fix up a ParamSpec definition's bound and default."""
        p.upper_bound.accept(self.type_fixer)
        p.default.accept(self.type_fixer)

    def visit_type_var_tuple_expr(self, tv: TypeVarTupleExpr) -> None:
        """Fix up a TypeVarTuple definition's bound, fallback and default."""
        tv.upper_bound.accept(self.type_fixer)
        tv.tuple_fallback.accept(self.type_fixer)
        tv.default.accept(self.type_fixer)

    def visit_var(self, v: Var) -> None:
        """Attach the enclosing class and fix up the variable's type(s)."""
        if self.current_info is not None:
            v.info = self.current_info
        if v.type is not None:
            v.type.accept(self.type_fixer)
        if v.setter_type is not None:
            v.setter_type.accept(self.type_fixer)

    def visit_type_alias(self, a: TypeAlias) -> None:
        """Fix up an alias target and its type variables."""
        a.target.accept(self.type_fixer)
        for v in a.alias_tvars:
            v.accept(self.type_fixer)
class TypeFixer(TypeVisitor[None]):
    """Type visitor that restores deserialized type references in place.

    Replaces the string ``type_ref`` left on Instance and TypeAliasType
    nodes with real TypeInfo/TypeAlias objects looked up from ``modules``,
    recursing into all component types.  In ``allow_missing`` mode,
    unresolved references are replaced with placeholders.
    """
    def __init__(self, modules: dict[str, MypyFile], allow_missing: bool) -> None:
        self.modules = modules
        self.allow_missing = allow_missing
    def visit_instance(self, inst: Instance) -> None:
        # TODO: Combine Instances that are exactly the same?
        type_ref = inst.type_ref
        if type_ref is None:
            return  # We've already been here.
        inst.type_ref = None
        inst.type = lookup_fully_qualified_typeinfo(
            self.modules, type_ref, allow_missing=self.allow_missing
        )
        # TODO: Is this needed or redundant?
        # Also fix up the bases, just in case.
        for base in inst.type.bases:
            if base.type is NOT_READY:
                base.accept(self)
        for a in inst.args:
            a.accept(self)
        if inst.last_known_value is not None:
            inst.last_known_value.accept(self)
        if inst.extra_attrs:
            for v in inst.extra_attrs.attrs.values():
                v.accept(self)
    def visit_type_alias_type(self, t: TypeAliasType) -> None:
        type_ref = t.type_ref
        if type_ref is None:
            return  # We've already been here.
        t.type_ref = None
        t.alias = lookup_fully_qualified_alias(
            self.modules, type_ref, allow_missing=self.allow_missing
        )
        for a in t.args:
            a.accept(self)
    def visit_any(self, o: Any) -> None:
        pass  # Nothing to descend into.
    def visit_callable_type(self, ct: CallableType) -> None:
        # Fix up the fallback, argument/return types, variables, and guards.
        if ct.fallback:
            ct.fallback.accept(self)
        for argt in ct.arg_types:
            # argt may be None, e.g. for __self in NamedTuple constructors.
            if argt is not None:
                argt.accept(self)
        if ct.ret_type is not None:
            ct.ret_type.accept(self)
        for v in ct.variables:
            v.accept(self)
        if ct.type_guard is not None:
            ct.type_guard.accept(self)
        if ct.type_is is not None:
            ct.type_is.accept(self)
    def visit_overloaded(self, t: Overloaded) -> None:
        for ct in t.items:
            ct.accept(self)
    def visit_erased_type(self, o: Any) -> None:
        # This type should exist only temporarily during type inference
        raise RuntimeError("Shouldn't get here", o)
    def visit_deleted_type(self, o: Any) -> None:
        pass  # Nothing to descend into.
    def visit_none_type(self, o: Any) -> None:
        pass  # Nothing to descend into.
    def visit_uninhabited_type(self, o: Any) -> None:
        pass  # Nothing to descend into.
    def visit_partial_type(self, o: Any) -> None:
        # Partial types should have been resolved before serialization.
        raise RuntimeError("Shouldn't get here", o)
    def visit_tuple_type(self, tt: TupleType) -> None:
        if tt.items:
            for it in tt.items:
                it.accept(self)
        if tt.partial_fallback is not None:
            tt.partial_fallback.accept(self)
    def visit_typeddict_type(self, tdt: TypedDictType) -> None:
        if tdt.items:
            for it in tdt.items.values():
                it.accept(self)
        if tdt.fallback is not None:
            if tdt.fallback.type_ref is not None:
                if (
                    lookup_fully_qualified(
                        tdt.fallback.type_ref,
                        self.modules,
                        raise_on_missing=not self.allow_missing,
                    )
                    is None
                ):
                    # We reject fake TypeInfos for TypedDict fallbacks because
                    # the latter are used in type checking and must be valid.
                    tdt.fallback.type_ref = "typing._TypedDict"
            tdt.fallback.accept(self)
    def visit_literal_type(self, lt: LiteralType) -> None:
        lt.fallback.accept(self)
    def visit_type_var(self, tvt: TypeVarType) -> None:
        if tvt.values:
            for vt in tvt.values:
                vt.accept(self)
        tvt.upper_bound.accept(self)
        tvt.default.accept(self)
    def visit_param_spec(self, p: ParamSpecType) -> None:
        p.upper_bound.accept(self)
        p.default.accept(self)
        p.prefix.accept(self)
    def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
        t.tuple_fallback.accept(self)
        t.upper_bound.accept(self)
        t.default.accept(self)
    def visit_unpack_type(self, u: UnpackType) -> None:
        u.type.accept(self)
    def visit_parameters(self, p: Parameters) -> None:
        for argt in p.arg_types:
            if argt is not None:
                argt.accept(self)
        for var in p.variables:
            var.accept(self)
    def visit_unbound_type(self, o: UnboundType) -> None:
        for a in o.args:
            a.accept(self)
    def visit_union_type(self, ut: UnionType) -> None:
        if ut.items:
            for it in ut.items:
                it.accept(self)
    def visit_type_type(self, t: TypeType) -> None:
        t.item.accept(self)
def lookup_fully_qualified_typeinfo(
modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeInfo:
stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
node = stnode.node if stnode else None
if isinstance(node, TypeInfo):
return node
else:
# Looks like a missing TypeInfo during an initial daemon load, put something there
assert (
allow_missing
), "Should never get here in normal mode, got {}:{} instead of TypeInfo".format(
type(node).__name__, node.fullname if node else ""
)
return missing_info(modules)
def lookup_fully_qualified_alias(
    modules: dict[str, MypyFile], name: str, *, allow_missing: bool
) -> TypeAlias:
    """Look up a TypeAlias by fully qualified name.

    A TypeInfo with a tuple or TypedDict type is converted to (and cached
    as) its special alias form.  In ``allow_missing`` mode a placeholder is
    returned when nothing suitable is found; otherwise such a failure is an
    internal error.
    """
    stnode = lookup_fully_qualified(name, modules, raise_on_missing=not allow_missing)
    node = stnode.node if stnode else None
    if isinstance(node, TypeAlias):
        return node
    if isinstance(node, TypeInfo):
        existing = node.special_alias
        if existing:
            # Already fixed up.
            return existing
        if node.tuple_type:
            fresh = TypeAlias.from_tuple_type(node)
        elif node.typeddict_type:
            fresh = TypeAlias.from_typeddict_type(node)
        else:
            assert allow_missing
            return missing_alias()
        node.special_alias = fresh
        return fresh
    # Looks like a missing TypeAlias during an initial daemon load; substitute
    # a placeholder so processing can continue.
    assert (
        allow_missing
    ), "Should never get here in normal mode, got {}:{} instead of TypeAlias".format(
        type(node).__name__, node.fullname if node else ""
    )
    return missing_alias()
# Name template for placeholder nodes synthesized for unresolved references;
# these should be replaced during a later fine-grained update.
_SUGGESTION: Final = "<missing {}: *should* have gone away during fine-grained update>"
def missing_info(modules: dict[str, MypyFile]) -> TypeInfo:
    """Create a placeholder TypeInfo for a reference that could not be resolved.

    The placeholder subclasses ``builtins.object`` so it behaves like a
    minimal valid class until a fine-grained update replaces it.
    """
    name = _SUGGESTION.format("info")
    fake_def = ClassDef(name, Block([]))
    fake_def.fullname = name
    placeholder = TypeInfo(SymbolTable(), fake_def, "<missing>")
    obj = lookup_fully_qualified_typeinfo(modules, "builtins.object", allow_missing=False)
    placeholder.bases = [Instance(obj, [])]
    placeholder.mro = [placeholder, obj]
    return placeholder
def missing_alias() -> TypeAlias:
    """Create a placeholder TypeAlias (targeting Any) for an unresolved alias."""
    name = _SUGGESTION.format("alias")
    return TypeAlias(AnyType(TypeOfAny.special_form), name, "<missing>", line=-1, column=-1)

View file

@ -0,0 +1,23 @@
"""Generic node traverser visitor"""
from __future__ import annotations
from mypy.nodes import Block, MypyFile
from mypy.traverser import TraverserVisitor
class TreeFreer(TraverserVisitor):
    """Visitor that empties every block body it visits, releasing the AST."""
    def visit_block(self, block: Block) -> None:
        # Traverse the statements first so nested blocks are cleared too,
        # then drop this block's own statements.
        super().visit_block(block)
        del block.body[:]
def free_tree(tree: MypyFile) -> None:
    """Free all the ASTs associated with a module.

    This needs to be done recursively, since symbol tables contain
    references to definitions, so those won't be freed but we want their
    contents to be.
    """
    freer = TreeFreer()
    tree.accept(freer)
    del tree.defs[:]

View file

@ -0,0 +1,307 @@
"""Interface for accessing the file system with automatic caching.
The idea is to cache the results of any file system state reads during
a single transaction. This has two main benefits:
* This avoids redundant syscalls, as we won't perform the same OS
operations multiple times.
* This makes it easier to reason about concurrent FS updates, as different
operations targeting the same paths can't report different state during
a transaction.
Note that this only deals with reading state, not writing.
Properties maintained by the API:
* The contents of the file are always from the same or later time compared
to the reported mtime of the file, even if mtime is queried after reading
a file.
* Repeating an operation produces the same result as the first one during
a transaction.
* Call flush() to start a new transaction (flush the caches).
The API is a bit limited. It's easy to add new cached operations, however.
You should perform all file system reads through the API to actually take
advantage of the benefits.
"""
from __future__ import annotations
import os
import stat
from mypy_extensions import mypyc_attr
from mypy.util import hash_digest
@mypyc_attr(allow_interpreted_subclasses=True)  # for tests
class FileSystemCache:
    """Cache of file system state reads for a single transaction.

    All results (including errors) are cached until flush() is called; see
    the module docstring for the properties this guarantees.
    """
    def __init__(self) -> None:
        # The package root is not flushed with the caches.
        # It is set by set_package_root() below.
        self.package_root: list[str] = []
        self.flush()
    def set_package_root(self, package_root: list[str]) -> None:
        """Set package roots used for Bazel-style fake __init__.py handling."""
        self.package_root = package_root
    def flush(self) -> None:
        """Start another transaction and empty all caches."""
        self.stat_or_none_cache: dict[str, os.stat_result | None] = {}
        self.listdir_cache: dict[str, list[str]] = {}
        self.listdir_error_cache: dict[str, OSError] = {}
        self.isfile_case_cache: dict[str, bool] = {}
        self.exists_case_cache: dict[str, bool] = {}
        self.read_cache: dict[str, bytes] = {}
        self.read_error_cache: dict[str, Exception] = {}
        self.hash_cache: dict[str, str] = {}
        self.fake_package_cache: set[str] = set()
    def stat_or_none(self, path: str) -> os.stat_result | None:
        """Return the cached stat() result for path, or None if stat() fails.

        For a missing __init__.py under a configured package root, a fake
        stat result may be fabricated (see init_under_package_root()).
        """
        if path in self.stat_or_none_cache:
            return self.stat_or_none_cache[path]
        st = None
        try:
            st = os.stat(path)
        except OSError:
            if self.init_under_package_root(path):
                try:
                    st = self._fake_init(path)
                except OSError:
                    pass
        self.stat_or_none_cache[path] = st
        return st
    def init_under_package_root(self, path: str) -> bool:
        """Is this path an __init__.py under a package root?

        This is used to detect packages that don't contain __init__.py
        files, which is needed to support Bazel. The function should
        only be called for non-existing files.

        It will return True if it refers to a __init__.py file that
        Bazel would create, so that at runtime Python would think the
        directory containing it is a package. For this to work you
        must pass one or more package roots using the --package-root
        flag.

        As an exceptional case, any directory that is a package root
        itself will not be considered to contain a __init__.py file.
        This is different from the rules Bazel itself applies, but is
        necessary for mypy to properly distinguish packages from other
        directories.

        See https://docs.bazel.build/versions/master/be/python.html,
        where this behavior is described under legacy_create_init.
        """
        if not self.package_root:
            return False
        dirname, basename = os.path.split(path)
        if basename != "__init__.py":
            return False
        if not os.path.basename(dirname).isidentifier():
            # Can't put an __init__.py in a place that's not an identifier
            return False
        st = self.stat_or_none(dirname)
        if st is None:
            return False
        else:
            if not stat.S_ISDIR(st.st_mode):
                # The parent of an __init__.py must be a directory.
                return False
        ok = False
        # skip if on a different drive
        current_drive, _ = os.path.splitdrive(os.getcwd())
        drive, _ = os.path.splitdrive(path)
        if drive != current_drive:
            return False
        if os.path.isabs(path):
            path = os.path.relpath(path)
        path = os.path.normpath(path)
        for root in self.package_root:
            if path.startswith(root):
                # NOTE(review): this assumes each configured root ends with a
                # path separator, so that root + basename names the root's own
                # __init__.py -- confirm with set_package_root() callers.
                if path == root + basename:
                    # A package root itself is never a package.
                    ok = False
                    break
                else:
                    ok = True
        return ok
    def _fake_init(self, path: str) -> os.stat_result:
        """Prime the cache with a fake __init__.py file.

        This makes code that looks for path believe an empty file by
        that name exists. Should only be called after
        init_under_package_root() returns True.
        """
        dirname, basename = os.path.split(path)
        assert basename == "__init__.py", path
        assert not os.path.exists(path), path  # Not cached!
        dirname = os.path.normpath(dirname)
        st = os.stat(dirname)  # May raise OSError
        # Get stat result as a list so we can modify it.
        seq: list[float] = list(st)
        seq[stat.ST_MODE] = stat.S_IFREG | 0o444
        seq[stat.ST_INO] = 1
        seq[stat.ST_NLINK] = 1
        seq[stat.ST_SIZE] = 0
        st = os.stat_result(seq)
        # Make listdir() and read() also pretend this file exists.
        self.fake_package_cache.add(dirname)
        return st
    def listdir(self, path: str) -> list[str]:
        """Return the cached directory listing for path; may raise OSError.

        Listing errors are cached too (as copies, to reduce memory use), and
        a fake "__init__.py" entry is injected for fake packages.
        """
        path = os.path.normpath(path)
        if path in self.listdir_cache:
            res = self.listdir_cache[path]
            # Check the fake cache.
            if path in self.fake_package_cache and "__init__.py" not in res:
                res.append("__init__.py")  # Updates the result as well as the cache
            return res
        if path in self.listdir_error_cache:
            raise copy_os_error(self.listdir_error_cache[path])
        try:
            results = os.listdir(path)
        except OSError as err:
            # Like above, take a copy to reduce memory use.
            self.listdir_error_cache[path] = copy_os_error(err)
            raise err
        self.listdir_cache[path] = results
        # Check the fake cache.
        if path in self.fake_package_cache and "__init__.py" not in results:
            results.append("__init__.py")
        return results
    def isfile(self, path: str) -> bool:
        """Return whether path exists and is a regular file."""
        st = self.stat_or_none(path)
        if st is None:
            return False
        return stat.S_ISREG(st.st_mode)
    def isfile_case(self, path: str, prefix: str) -> bool:
        """Return whether path exists and is a file.

        On case-insensitive filesystems (like Mac or Windows) this returns
        False if the case of path's last component does not exactly match
        the case found in the filesystem.

        We check also the case of other path components up to prefix.
        For example, if path is 'user-stubs/pack/mod.pyi' and prefix is 'user-stubs',
        we check that the case of 'pack' and 'mod.py' matches exactly, 'user-stubs' will be
        case insensitive on case insensitive filesystems.

        The caller must ensure that prefix is a valid file system prefix of path.
        """
        if not self.isfile(path):
            # Fast path
            return False
        if path in self.isfile_case_cache:
            return self.isfile_case_cache[path]
        head, tail = os.path.split(path)
        if not tail:
            self.isfile_case_cache[path] = False
            return False
        try:
            names = self.listdir(head)
            # This allows one to check file name case sensitively in
            # case-insensitive filesystems.
            res = tail in names
        except OSError:
            res = False
        if res:
            # Also recursively check the other path components in case sensitive way.
            res = self.exists_case(head, prefix)
        self.isfile_case_cache[path] = res
        return res
    def exists_case(self, path: str, prefix: str) -> bool:
        """Return whether path exists - checking path components in case sensitive
        fashion, up to prefix.
        """
        if path in self.exists_case_cache:
            return self.exists_case_cache[path]
        head, tail = os.path.split(path)
        if not head.startswith(prefix) or not tail:
            # Only perform the check for paths under prefix.
            self.exists_case_cache[path] = True
            return True
        try:
            names = self.listdir(head)
            # This allows one to check file name case sensitively in
            # case-insensitive filesystems.
            res = tail in names
        except OSError:
            res = False
        if res:
            # Also recursively check other path components.
            res = self.exists_case(head, prefix)
        self.exists_case_cache[path] = res
        return res
    def isdir(self, path: str) -> bool:
        """Return whether path exists and is a directory."""
        st = self.stat_or_none(path)
        if st is None:
            return False
        return stat.S_ISDIR(st.st_mode)
    def exists(self, path: str) -> bool:
        """Return whether path exists (i.e. stat() succeeds for it)."""
        st = self.stat_or_none(path)
        return st is not None
    def read(self, path: str) -> bytes:
        """Return the cached contents of a file; may raise OSError.

        Read errors are cached, and the content hash is computed as a side
        effect (see hash_digest()).
        """
        if path in self.read_cache:
            return self.read_cache[path]
        if path in self.read_error_cache:
            raise self.read_error_cache[path]
        # Need to stat first so that the contents of file are from no
        # earlier instant than the mtime reported by self.stat().
        self.stat_or_none(path)
        dirname, basename = os.path.split(path)
        dirname = os.path.normpath(dirname)
        # Check the fake cache.
        if basename == "__init__.py" and dirname in self.fake_package_cache:
            data = b""
        else:
            try:
                with open(path, "rb") as f:
                    data = f.read()
            except OSError as err:
                self.read_error_cache[path] = err
                raise
        self.read_cache[path] = data
        self.hash_cache[path] = hash_digest(data)
        return data
    def hash_digest(self, path: str) -> str:
        """Return the content hash of a file, reading the file if needed."""
        if path not in self.hash_cache:
            self.read(path)
        return self.hash_cache[path]
    def samefile(self, f1: str, f2: str) -> bool:
        """Return whether two paths refer to the same file (by stat identity)."""
        s1 = self.stat_or_none(f1)
        s2 = self.stat_or_none(f2)
        if s1 is None or s2 is None:
            return False
        return os.path.samestat(s1, s2)
def copy_os_error(e: OSError) -> OSError:
    """Return a fresh OSError carrying the same args, errno, and filenames.

    Cached errors are stored as copies (see FileSystemCache.listdir) to
    reduce memory use.
    """
    duplicate = OSError(*e.args)
    duplicate.errno = e.errno
    duplicate.strerror = e.strerror
    duplicate.filename = e.filename
    if e.filename2:
        duplicate.filename2 = e.filename2
    return duplicate

View file

@ -0,0 +1,106 @@
"""Watch parts of the file system for changes."""
from __future__ import annotations
import os
from collections.abc import Iterable, Set as AbstractSet
from typing import NamedTuple
from mypy.fscache import FileSystemCache
class FileData(NamedTuple):
    # Snapshot of a watched file's state, used to detect changes.
    st_mtime: float  # modification time from os.stat()
    st_size: int  # size in bytes from os.stat()
    hash: str  # content digest from FileSystemCache.hash_digest()
class FileSystemWatcher:
    """Watcher for file system changes among specific paths.

    All file system access is performed using FileSystemCache. We
    detect changed files by stat()ing them all and comparing hashes
    of potentially changed files. If a file has both size and mtime
    unmodified, the file is assumed to be unchanged.

    An important goal of this class is to make it easier to eventually
    use file system events to detect file changes.

    Note: This class doesn't flush the file system cache. If you don't
    manually flush it, changes won't be seen.
    """
    # TODO: Watching directories?
    # TODO: Handle non-files
    def __init__(self, fs: FileSystemCache) -> None:
        self.fs = fs
        # Paths currently being watched.
        self._paths: set[str] = set()
        # Last known snapshot per path; None means not yet seen or deleted.
        self._file_data: dict[str, FileData | None] = {}
    def dump_file_data(self) -> dict[str, tuple[float, int, str]]:
        """Return the current snapshots for all paths that have one."""
        return {k: v for k, v in self._file_data.items() if v is not None}
    def set_file_data(self, path: str, data: FileData) -> None:
        """Set the snapshot for a path directly (e.g. when restoring state)."""
        self._file_data[path] = data
    def add_watched_paths(self, paths: Iterable[str]) -> None:
        """Start watching the given paths (already-watched paths are unaffected)."""
        for path in paths:
            if path not in self._paths:
                # By storing None this path will get reported as changed by
                # find_changed if it exists.
                self._file_data[path] = None
        self._paths |= set(paths)
    def remove_watched_paths(self, paths: Iterable[str]) -> None:
        """Stop watching the given paths and drop their snapshots."""
        for path in paths:
            if path in self._file_data:
                del self._file_data[path]
        self._paths -= set(paths)
    def _update(self, path: str, st: os.stat_result) -> None:
        # Record a fresh (mtime, size, content hash) snapshot for path.
        hash_digest = self.fs.hash_digest(path)
        self._file_data[path] = FileData(st.st_mtime, st.st_size, hash_digest)
    def _find_changed(self, paths: Iterable[str]) -> AbstractSet[str]:
        """Return the subset of paths whose state differs from the snapshot."""
        changed = set()
        for path in paths:
            old = self._file_data[path]
            st = self.fs.stat_or_none(path)
            if st is None:
                if old is not None:
                    # File was deleted.
                    changed.add(path)
                    self._file_data[path] = None
            else:
                if old is None:
                    # File is new.
                    changed.add(path)
                    self._update(path, st)
                # Round mtimes down, to match the mtimes we write to meta files
                elif st.st_size != old.st_size or int(st.st_mtime) != int(old.st_mtime):
                    # Only look for changes if size or mtime has changed as an
                    # optimization, since calculating hash is expensive.
                    new_hash = self.fs.hash_digest(path)
                    self._update(path, st)
                    if st.st_size != old.st_size or new_hash != old.hash:
                        # Changed file.
                        changed.add(path)
        return changed
    def find_changed(self) -> AbstractSet[str]:
        """Return paths that have changes since the last call, in the watched set."""
        return self._find_changed(self._paths)
    def update_changed(self, remove: list[str], update: list[str]) -> AbstractSet[str]:
        """Alternative to find_changed() given explicit changes.

        This only calls self.fs.stat() on added or updated files, not
        on all files. It believes all other files are unchanged!

        Implies add_watched_paths() for add and update, and
        remove_watched_paths() for remove.
        """
        self.remove_watched_paths(remove)
        self.add_watched_paths(update)
        return self._find_changed(update)

View file

@ -0,0 +1,48 @@
from __future__ import annotations
import gc
import time
from collections.abc import Mapping
class GcLogger:
    """Context manager to log GC stats and overall time."""
    def __enter__(self) -> GcLogger:
        # Timestamp of the in-progress collection, or None when idle.
        self.gc_start_time: float | None = None
        self.gc_time = 0.0
        self.gc_calls = 0
        self.gc_collected = 0
        self.gc_uncollectable = 0
        gc.callbacks.append(self.gc_callback)
        self.start_time = time.time()
        return self
    def gc_callback(self, phase: str, info: Mapping[str, int]) -> None:
        """Accumulate timing and collection counts on each GC start/stop event."""
        now = time.time()
        if phase == "start":
            assert self.gc_start_time is None, "Start phase out of sequence"
            self.gc_start_time = now
        elif phase == "stop":
            assert self.gc_start_time is not None, "Stop phase out of sequence"
            self.gc_calls += 1
            self.gc_time += now - self.gc_start_time
            self.gc_start_time = None
            self.gc_collected += info["collected"]
            self.gc_uncollectable += info["uncollectable"]
        else:
            assert False, f"Unrecognized gc phase ({phase!r})"
    def __exit__(self, *args: object) -> None:
        # Remove every registration of our callback (defensively, in case
        # it was somehow registered more than once).
        callback = self.gc_callback
        while callback in gc.callbacks:
            gc.callbacks.remove(callback)
    def get_stats(self) -> Mapping[str, float]:
        """Return accumulated GC statistics plus total elapsed time so far."""
        return {
            "gc_time": self.gc_time,
            "gc_calls": self.gc_calls,
            "gc_collected": self.gc_collected,
            "gc_uncollectable": self.gc_uncollectable,
            "build_time": time.time() - self.start_time,
        }

View file

@ -0,0 +1,34 @@
"""Git utilities."""
# Used also from setup.py, so don't pull in anything additional here (like mypy or typing):
from __future__ import annotations
import os
import subprocess
def is_git_repo(dir: str) -> bool:
    """Is the given directory version-controlled with git?"""
    git_metadata = os.path.join(dir, ".git")
    return os.path.exists(git_metadata)
def have_git() -> bool:
    """Can we run the git executable?"""
    try:
        # Any failure to launch or run git means "no".
        subprocess.check_output(["git", "--help"])
    except (subprocess.CalledProcessError, OSError):
        return False
    return True
def git_revision(dir: str) -> bytes:
    """Get the SHA-1 of the HEAD of a git repository."""
    output = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dir)
    return output.strip()
def is_dirty(dir: str) -> bool:
    """Check whether a git repository has uncommitted changes."""
    status = subprocess.check_output(["git", "status", "-uno", "--porcelain"], cwd=dir)
    return status.strip() != b""

View file

@ -0,0 +1,161 @@
"""Helpers for manipulations with graphs."""
from __future__ import annotations
from collections.abc import Iterator, Set as AbstractSet
from typing import TypeVar
T = TypeVar("T")
def strongly_connected_components(
    vertices: AbstractSet[T], edges: dict[T, list[T]]
) -> Iterator[set[T]]:
    """Compute Strongly Connected Components of a directed graph.

    Args:
      vertices: the labels for the vertices
      edges: for each vertex, gives the target vertices of its outgoing edges

    Returns:
      An iterator yielding strongly connected components, each
      represented as a set of vertices.  Each input vertex will occur
      exactly once; vertices not part of a SCC are returned as
      singleton sets.

    From https://code.activestate.com/recipes/578507/.
    """
    # Path-based SCC: `path` is the DFS stack, `position` records where each
    # vertex sits on it, and `roots` holds candidate component roots.
    assigned: set[T] = set()
    path: list[T] = []
    position: dict[T, int] = {}
    roots: list[int] = []

    def visit(vertex: T) -> Iterator[set[T]]:
        position[vertex] = len(path)
        path.append(vertex)
        roots.append(position[vertex])
        for succ in edges[vertex]:
            if succ not in position:
                yield from visit(succ)
            elif succ not in assigned:
                # Found a back edge: merge root candidates down to succ.
                while position[succ] < roots[-1]:
                    roots.pop()
        if roots[-1] == position[vertex]:
            # vertex is the root of a component: pop it off the path.
            roots.pop()
            component = set(path[position[vertex] :])
            del path[position[vertex] :]
            assigned.update(component)
            yield component

    for vertex in vertices:
        if vertex not in position:
            yield from visit(vertex)
def prepare_sccs(
    sccs: list[set[T]], edges: dict[T, list[T]]
) -> dict[AbstractSet[T], set[AbstractSet[T]]]:
    """Use original edges to organize SCCs in a graph by dependencies between them."""
    # Map each vertex to the (frozen) SCC containing it.
    vertex_to_scc: dict[T, frozenset[T]] = {}
    for component in sccs:
        frozen = frozenset(component)
        for vertex in component:
            vertex_to_scc[vertex] = frozen
    # For each SCC, collect the SCCs reachable through its members' edges.
    result: dict[AbstractSet[T], set[AbstractSet[T]]] = {}
    for component in sccs:
        dependencies: set[AbstractSet[T]] = {
            vertex_to_scc[target] for vertex in component for target in edges[vertex]
        }
        result[frozenset(component)] = dependencies
    return result
class topsort(Iterator[set[T]]):  # noqa: N801
    """Topological sort using Kahn's algorithm.

    Uses in-degree counters and a reverse adjacency list, so the total work
    is O(V + E).  Implemented as a class rather than a generator for better
    mypyc compilation.

    Args:
      data: A map from vertices to all vertices that it has an edge
            connecting it to.  NOTE: dependency sets in this data
            structure are modified in place to remove self-dependencies.
            Orphans are handled internally and are not added to `data`.

    Yields sets of vertices with an equivalent ordering; e.g. for
    {A: {B, C}, B: {D}, C: {D}} the results are {D}, then {B, C}, then {A}.
    """
    def __init__(self, data: dict[T, set[T]]) -> None:
        # One pass over data: strip self-dependencies, build the reverse
        # adjacency list, count in-degrees, register orphans, and seed the
        # initial ready set.
        counts: dict[T, int] = {}
        dependents: dict[T, list[T]] = {}
        initial: set[T] = set()
        for node, deps in data.items():
            deps.discard(node)  # Ignore self dependencies.
            degree = len(deps)
            counts[node] = degree
            if degree == 0:
                initial.add(node)
            if node not in dependents:
                dependents[node] = []
            for dep in deps:
                if dep in dependents:
                    dependents[dep].append(node)
                else:
                    dependents[dep] = [node]
                    if dep not in data:
                        # Orphan: referenced as a dependency but absent from data.
                        counts[dep] = 0
                        initial.add(dep)
        self.in_degree = counts
        self.rev = dependents
        self.ready = initial
        self.remaining = len(counts) - len(initial)
    def __iter__(self) -> Iterator[set[T]]:
        return self
    def __next__(self) -> set[T]:
        current = self.ready
        if not current:
            # Either everything was emitted, or a cycle blocks progress.
            assert self.remaining == 0, (
                f"A cyclic dependency exists amongst "
                f"{[k for k, deg in self.in_degree.items() if deg > 0]!r}"
            )
            raise StopIteration
        counts = self.in_degree
        upcoming: set[T] = set()
        for node in current:
            for dependent in self.rev[node]:
                counts[dependent] -= 1
                if counts[dependent] == 0:
                    upcoming.add(dependent)
        self.remaining -= len(upcoming)
        self.ready = upcoming
        return current

View file

@ -0,0 +1,170 @@
from __future__ import annotations
from collections.abc import Iterable
import mypy.types as types
from mypy.types import TypeVisitor
class TypeIndirectionVisitor(TypeVisitor[None]):
    """Returns all module references within a particular type."""
    def __init__(self) -> None:
        # Module references are collected here
        self.modules: set[str] = set()
        # Used to avoid infinite recursion with recursive types
        self.seen_types: set[types.TypeAliasType | types.Instance] = set()
    def find_modules(self, typs: Iterable[types.Type]) -> set[str]:
        """Return the set of module names referenced by the given types."""
        self.modules = set()
        self.seen_types = set()
        for typ in typs:
            self._visit(typ)
        return self.modules
    def _visit(self, typ: types.Type) -> None:
        # Note: instances are needed for `class str(Sequence[str]): ...`
        if (
            isinstance(typ, types.TypeAliasType)
            or isinstance(typ, types.ProperType)
            and isinstance(typ, types.Instance)
        ):
            # Avoid infinite recursion for recursive types.
            if typ in self.seen_types:
                return
            self.seen_types.add(typ)
        typ.accept(self)
    def _visit_type_tuple(self, typs: tuple[types.Type, ...]) -> None:
        # Micro-optimization: Specialized version of _visit for tuples
        for typ in typs:
            if (
                isinstance(typ, types.TypeAliasType)
                or isinstance(typ, types.ProperType)
                and isinstance(typ, types.Instance)
            ):
                # Avoid infinite recursion for recursive types.
                if typ in self.seen_types:
                    continue
                self.seen_types.add(typ)
            typ.accept(self)
    def _visit_type_list(self, typs: list[types.Type]) -> None:
        # Micro-optimization: Specialized version of _visit for lists
        for typ in typs:
            if (
                isinstance(typ, types.TypeAliasType)
                or isinstance(typ, types.ProperType)
                and isinstance(typ, types.Instance)
            ):
                # Avoid infinite recursion for recursive types.
                if typ in self.seen_types:
                    continue
                self.seen_types.add(typ)
            typ.accept(self)
    def visit_unbound_type(self, t: types.UnboundType) -> None:
        self._visit_type_tuple(t.args)
    def visit_any(self, t: types.AnyType) -> None:
        pass
    def visit_none_type(self, t: types.NoneType) -> None:
        pass
    def visit_uninhabited_type(self, t: types.UninhabitedType) -> None:
        pass
    def visit_erased_type(self, t: types.ErasedType) -> None:
        pass
    def visit_deleted_type(self, t: types.DeletedType) -> None:
        pass
    def visit_type_var(self, t: types.TypeVarType) -> None:
        self._visit_type_list(t.values)
        self._visit(t.upper_bound)
        self._visit(t.default)
    def visit_param_spec(self, t: types.ParamSpecType) -> None:
        self._visit(t.upper_bound)
        self._visit(t.default)
        self._visit(t.prefix)
    def visit_type_var_tuple(self, t: types.TypeVarTupleType) -> None:
        self._visit(t.upper_bound)
        self._visit(t.default)
    def visit_unpack_type(self, t: types.UnpackType) -> None:
        t.type.accept(self)
    def visit_parameters(self, t: types.Parameters) -> None:
        self._visit_type_list(t.arg_types)
    def visit_instance(self, t: types.Instance) -> None:
        # Instance is named, record its definition and continue digging into
        # components that constitute semantic meaning of this type: bases, metaclass,
        # tuple type, and typeddict type.
        # Note: we cannot simply record the MRO, in case an intermediate base contains
        # a reference to type alias, this affects meaning of map_instance_to_supertype(),
        # see e.g. testDoubleReexportGenericUpdated.
        self._visit_type_tuple(t.args)
        if t.type:
            # Important optimization: instead of simply recording the definition and
            # recursing into bases, record the MRO and only traverse generic bases.
            for s in t.type.mro:
                self.modules.add(s.module_name)
                for base in s.bases:
                    if base.args:
                        self._visit_type_tuple(base.args)
            if t.type.metaclass_type:
                self._visit(t.type.metaclass_type)
            if t.type.typeddict_type:
                self._visit(t.type.typeddict_type)
            if t.type.tuple_type:
                self._visit(t.type.tuple_type)
            if t.type.is_protocol:
                # For protocols, member types constitute the semantic meaning of the type.
                # TODO: this doesn't cover some edge cases, like setter types and exotic nodes.
                for m in t.type.protocol_members:
                    node = t.type.names.get(m)
                    if node and node.type:
                        self._visit(node.type)
    def visit_callable_type(self, t: types.CallableType) -> None:
        self._visit_type_list(t.arg_types)
        self._visit(t.ret_type)
        self._visit_type_tuple(t.variables)
    def visit_overloaded(self, t: types.Overloaded) -> None:
        for item in t.items:
            self._visit(item)
        self._visit(t.fallback)
    def visit_tuple_type(self, t: types.TupleType) -> None:
        self._visit_type_list(t.items)
        self._visit(t.partial_fallback)
    def visit_typeddict_type(self, t: types.TypedDictType) -> None:
        self._visit_type_list(list(t.items.values()))
        self._visit(t.fallback)
    def visit_literal_type(self, t: types.LiteralType) -> None:
        self._visit(t.fallback)
    def visit_union_type(self, t: types.UnionType) -> None:
        self._visit_type_list(t.items)
    def visit_partial_type(self, t: types.PartialType) -> None:
        pass
    def visit_type_type(self, t: types.TypeType) -> None:
        self._visit(t.item)
    def visit_type_alias_type(self, t: types.TypeAliasType) -> None:
        # Type alias is named, record its definition and continue digging into
        # components that constitute semantic meaning of this type: target and args.
        if t.alias:
            self.modules.add(t.alias.module)
            self._visit(t.alias.target)
        self._visit_type_list(t.args)

View file

@ -0,0 +1,76 @@
"""Utilities for type argument inference."""
from __future__ import annotations
from collections.abc import Sequence
from typing import NamedTuple
from mypy.constraints import (
SUBTYPE_OF,
SUPERTYPE_OF,
infer_constraints,
infer_constraints_for_callable,
)
from mypy.nodes import ArgKind
from mypy.solve import solve_constraints
from mypy.types import CallableType, Instance, Type, TypeVarLikeType
class ArgumentInferContext(NamedTuple):
    """Type argument inference context.

    We need this because we pass around ``Mapping`` and ``Iterable`` types.
    These types are only known by ``TypeChecker`` itself.
    It is required for ``*`` and ``**`` argument inference.
    https://github.com/python/mypy/issues/11144
    """

    # The checker's canonical Mapping instance (used for ``**`` arguments).
    mapping_type: Instance
    # The checker's canonical Iterable instance (used for ``*`` arguments).
    iterable_type: Instance
def infer_function_type_arguments(
    callee_type: CallableType,
    arg_types: Sequence[Type | None],
    arg_kinds: list[ArgKind],
    arg_names: Sequence[str | None] | None,
    formal_to_actual: list[list[int]],
    context: ArgumentInferContext,
    strict: bool = True,
    allow_polymorphic: bool = False,
) -> tuple[list[Type | None], list[TypeVarLikeType]]:
    """Infer type arguments for a generic callable at a call site.

    The result's first item is a list of inferred lower bounds, one per
    type variable (-1 at index 0, -2 at index 1, ...); an entry is None
    when no value could be inferred for that variable.

    Arguments:
        callee_type: the target generic function
        arg_types: argument types at the call site (each optional; if None,
            we are not considering this argument in the current pass)
        arg_kinds: nodes.ARG_* values for arg_types
        formal_to_actual: mapping from formal to actual variable indices
    """
    # Step 1: derive constraints on the callee's type variables from the
    # actual arguments supplied at the call site.
    call_constraints = infer_constraints_for_callable(
        callee_type, arg_types, arg_kinds, arg_names, formal_to_actual, context
    )
    # Step 2: solve those constraints for the callee's own type variables.
    return solve_constraints(
        callee_type.variables, call_constraints, strict, allow_polymorphic
    )
def infer_type_arguments(
    type_vars: Sequence[TypeVarLikeType],
    template: Type,
    actual: Type,
    is_supertype: bool = False,
    skip_unsatisfied: bool = False,
) -> list[Type | None]:
    """Match a single template type against an actual type.

    Like infer_function_type_arguments, but for one template/actual pair
    instead of a whole call.  Returns only the inferred lower bounds.
    """
    direction = SUPERTYPE_OF if is_supertype else SUBTYPE_OF
    constraints = infer_constraints(template, actual, direction)
    bounds, _free_vars = solve_constraints(
        type_vars, constraints, skip_unsatisfied=skip_unsatisfied
    )
    return bounds

View file

@ -0,0 +1,626 @@
from __future__ import annotations
import os
from collections import defaultdict
from collections.abc import Callable
from functools import cmp_to_key
from mypy.build import State
from mypy.messages import format_type
from mypy.modulefinder import PYTHON_EXTENSIONS
from mypy.nodes import (
LDEF,
Decorator,
Expression,
FuncBase,
MemberExpr,
MypyFile,
Node,
OverloadedFuncDef,
RefExpr,
SymbolNode,
TypeInfo,
Var,
)
from mypy.server.update import FineGrainedBuildManager
from mypy.traverser import ExtendedTraverserVisitor
from mypy.typeops import tuple_fallback
from mypy.types import (
FunctionLike,
Instance,
LiteralType,
ProperType,
TupleType,
TypedDictType,
TypeVarType,
UnionType,
get_proper_type,
)
from mypy.typevars import fill_typevars_with_any
def node_starts_after(o: Node, line: int, column: int) -> bool:
    """Return True if node *o* begins strictly after position (line, column)."""
    # Positions compare lexicographically: line first, then column.
    return (o.line, o.column) > (line, column)
def node_ends_before(o: Node, line: int, column: int) -> bool:
    """Return True if node *o* provably ends strictly before (line, column).

    End positions for some statements are a mess (e.g. overloaded
    functions), so when either end coordinate is missing we conservatively
    answer False.
    """
    if o.end_line is None or o.end_column is None:
        return False
    return (o.end_line, o.end_column) < (line, column)
def expr_span(expr: Expression) -> str:
    """Format expression span as in mypy error messages."""
    # Start column is shifted to 1-based; end coordinates are used as stored.
    start = f"{expr.line}:{expr.column + 1}"
    end = f"{expr.end_line}:{expr.end_column}"
    return f"{start}:{end}"
def get_instance_fallback(typ: ProperType) -> list[Instance]:
    """Return the Instance fallback(s) for this type (empty list if none).

    Unions and value-restricted type variables may contribute several
    instances; all other supported types contribute at most one.
    """
    if isinstance(typ, Instance):
        return [typ]
    elif isinstance(typ, TupleType):
        return [tuple_fallback(typ)]
    elif isinstance(typ, TypedDictType):
        return [typ.fallback]
    elif isinstance(typ, FunctionLike):
        return [typ.fallback]
    elif isinstance(typ, LiteralType):
        return [typ.fallback]
    elif isinstance(typ, TypeVarType):
        if typ.values:
            # Value-restricted type variable: gather fallbacks of all values.
            res = []
            for t in typ.values:
                res.extend(get_instance_fallback(get_proper_type(t)))
            return res
        # Otherwise fall back to the upper bound (object by default).
        return get_instance_fallback(get_proper_type(typ.upper_bound))
    elif isinstance(typ, UnionType):
        res = []
        for t in typ.items:
            res.extend(get_instance_fallback(get_proper_type(t)))
        return res
    return []
def find_node(name: str, info: TypeInfo) -> Var | FuncBase | None:
    """Find the node defining member 'name' in given TypeInfo.

    Returns None when the member is neither a method nor a Var.
    """
    # TODO: this code shares some logic with checkmember.py
    method = info.get_method(name)
    if method:
        if isinstance(method, Decorator):
            # For a decorated method the underlying Var carries the definition.
            return method.var
        if method.is_property:
            # Properties are stored as OverloadedFuncDef; take the first item
            # (presumably the getter — confirm against checkmember.py).
            assert isinstance(method, OverloadedFuncDef)
            dec = method.items[0]
            assert isinstance(dec, Decorator)
            return dec.var
        return method
    else:
        # don't have such method, maybe variable?
        node = info.get(name)
        v = node.node if node else None
        if isinstance(v, Var):
            return v
    return None
def find_module_by_fullname(fullname: str, modules: dict[str, State]) -> State | None:
    """Find module by a node fullname.

    This logic mimics the one we use in fixup, so should be good enough.
    """
    # Special case: a module symbol is considered to be defined in itself, not in enclosing
    # package, since this is what users want when clicking go to definition on a module.
    if fullname in modules:
        return modules[fullname]
    # Otherwise strip trailing components, longest prefix first, until we
    # hit a known module.
    parts = fullname.split(".")
    for cut in range(len(parts) - 1, 0, -1):
        candidate = ".".join(parts[:cut])
        if candidate in modules:
            return modules[candidate]
    return None
class SearchVisitor(ExtendedTraverserVisitor):
    """Visitor that locates an expression whose span matches a given one exactly."""

    def __init__(self, line: int, column: int, end_line: int, end_column: int) -> None:
        self.line = line
        self.column = column
        self.end_line = end_line
        self.end_column = end_column
        self.result: Expression | None = None

    def visit(self, o: Node) -> bool:
        # Prune subtrees that lie entirely before or after the target span.
        if node_starts_after(o, self.line, self.column):
            return False
        if node_ends_before(o, self.end_line, self.end_column):
            return False
        found = (o.line, o.column, o.end_line, o.end_column)
        wanted = (self.line, self.column, self.end_line, self.end_column)
        if found == wanted and isinstance(o, Expression):
            self.result = o
        # Keep traversing only until a match has been recorded.
        return self.result is None
def find_by_location(
    tree: MypyFile, line: int, column: int, end_line: int, end_column: int
) -> Expression | None:
    """Find an expression matching given span, or None if not found."""
    # Reject inverted or empty spans before walking the tree.
    if end_line < line:
        raise ValueError('"end_line" must not be before "line"')
    if end_line == line and end_column <= column:
        raise ValueError('"end_column" must be after "column"')
    finder = SearchVisitor(line, column, end_line, end_column)
    tree.accept(finder)
    return finder.result
class SearchAllVisitor(ExtendedTraverserVisitor):
    """Visitor collecting every expression whose span encloses a position."""

    def __init__(self, line: int, column: int) -> None:
        self.line = line
        self.column = column
        self.result: list[Expression] = []

    def visit(self, o: Node) -> bool:
        # Skip nodes lying entirely before or after the position.
        if node_starts_after(o, self.line, self.column) or node_ends_before(
            o, self.line, self.column
        ):
            return False
        if isinstance(o, Expression):
            self.result.append(o)
        return True
def find_all_by_location(tree: MypyFile, line: int, column: int) -> list[Expression]:
    """Find all expressions enclosing given position, innermost first."""
    collector = SearchAllVisitor(line, column)
    tree.accept(collector)
    # Traversal appends outermost-first; callers want innermost-first.
    return collector.result[::-1]
class InspectionEngine:
    """Engine for locating and statically inspecting expressions.

    Backed by a fine-grained build manager: modules are (re)loaded on
    demand with type export enabled, then expressions are located by
    source span and formatted for daemon responses.
    """

    def __init__(
        self,
        fg_manager: FineGrainedBuildManager,
        *,
        verbosity: int = 0,
        limit: int = 0,
        include_span: bool = False,
        include_kind: bool = False,
        include_object_attrs: bool = False,
        union_attrs: bool = False,
        force_reload: bool = False,
    ) -> None:
        self.fg_manager = fg_manager
        self.verbosity = verbosity
        self.limit = limit
        self.include_span = include_span
        self.include_kind = include_kind
        self.include_object_attrs = include_object_attrs
        self.union_attrs = union_attrs
        self.force_reload = force_reload
        # Module for which inspection was requested.
        self.module: State | None = None

    def reload_module(self, state: State) -> None:
        """Reload given module while temporarily exporting types."""
        old = self.fg_manager.manager.options.export_types
        self.fg_manager.manager.options.export_types = True
        try:
            self.fg_manager.flush_cache()
            assert state.path is not None
            self.fg_manager.update([(state.id, state.path)], [])
        finally:
            # Always restore the previous export_types setting.
            self.fg_manager.manager.options.export_types = old

    def expr_type(self, expression: Expression) -> tuple[str, bool]:
        """Format type for an expression using current options.

        If type is known, second item returned is True. If type is not known, an error
        message is returned instead, and second item returned is False.
        """
        expr_type = self.fg_manager.manager.all_types.get(expression)
        if expr_type is None:
            return self.missing_type(expression), False
        type_str = format_type(
            expr_type, self.fg_manager.manager.options, verbosity=self.verbosity
        )
        return self.add_prefixes(type_str, expression), True

    def object_type(self) -> Instance:
        """Return the builtins.object instance (the universal fallback)."""
        builtins = self.fg_manager.graph["builtins"].tree
        assert builtins is not None
        object_node = builtins.names["object"].node
        assert isinstance(object_node, TypeInfo)
        return Instance(object_node, [])

    def collect_attrs(self, instances: list[Instance]) -> dict[TypeInfo, list[str]]:
        """Collect attributes from all union/typevar variants."""

        def item_attrs(attr_dict: dict[TypeInfo, list[str]]) -> set[str]:
            # Flatten one variant's per-base attribute lists into a set.
            attrs = set()
            for base in attr_dict:
                attrs |= set(attr_dict[base])
            return attrs

        def cmp_types(x: TypeInfo, y: TypeInfo) -> int:
            # Order subclasses before their bases; unrelated classes tie.
            if x in y.mro:
                return 1
            if y in x.mro:
                return -1
            return 0

        # First gather all attributes for every union variant.
        assert instances
        all_attrs = []
        for instance in instances:
            attrs = {}
            mro = instance.type.mro
            if not self.include_object_attrs:
                # Drop object (always the last MRO entry) unless requested.
                mro = mro[:-1]
            for base in mro:
                attrs[base] = sorted(base.names)
            all_attrs.append(attrs)
        # Find attributes valid for all variants in a union or type variable.
        intersection = item_attrs(all_attrs[0])
        for item in all_attrs[1:]:
            intersection &= item_attrs(item)
        # Combine attributes from all variants into a single dict while
        # also removing invalid attributes (unless using --union-attrs).
        combined_attrs = defaultdict(list)
        for item in all_attrs:
            for base in item:
                if base in combined_attrs:
                    continue
                for name in item[base]:
                    if self.union_attrs or name in intersection:
                        combined_attrs[base].append(name)
        # Sort bases by MRO, unrelated will appear in the order they appeared as union variants.
        sorted_bases = sorted(combined_attrs.keys(), key=cmp_to_key(cmp_types))
        result = {}
        for base in sorted_bases:
            if not combined_attrs[base]:
                # Skip bases where everything was filtered out.
                continue
            result[base] = combined_attrs[base]
        return result

    def _fill_from_dict(
        self, attrs_strs: list[str], attrs_dict: dict[TypeInfo, list[str]]
    ) -> None:
        """Append JSON-ish '"Class": [...]' entries to attrs_strs in place."""
        for base in attrs_dict:
            cls_name = base.name if self.verbosity < 1 else base.fullname
            attrs = [f'"{attr}"' for attr in attrs_dict[base]]
            attrs_strs.append(f'"{cls_name}": [{", ".join(attrs)}]')

    def expr_attrs(self, expression: Expression) -> tuple[str, bool]:
        """Format attributes that are valid for a given expression.

        If expression type is not an Instance, try using fallback. Attributes are
        returned as a JSON (ordered by MRO) that maps base class name to list of
        attributes. Attributes may appear in multiple bases if overridden (we simply
        follow usual mypy logic for creating new Vars etc).
        """
        expr_type = self.fg_manager.manager.all_types.get(expression)
        if expr_type is None:
            return self.missing_type(expression), False
        expr_type = get_proper_type(expr_type)
        instances = get_instance_fallback(expr_type)
        if not instances:
            # Everything is an object in Python.
            instances = [self.object_type()]
        attrs_dict = self.collect_attrs(instances)
        # Special case: modules have names apart from those from ModuleType.
        if isinstance(expression, RefExpr) and isinstance(expression.node, MypyFile):
            node = expression.node
            names = sorted(node.names)
            if "__builtins__" in names:
                # This is just to make tests stable. No one will really need this name.
                names.remove("__builtins__")
            mod_dict = {f'"<{node.fullname}>"': [f'"{name}"' for name in names]}
        else:
            mod_dict = {}
        # Special case: for class callables, prepend with the class attributes.
        # TODO: also handle cases when such callable appears in a union.
        if isinstance(expr_type, FunctionLike) and expr_type.is_type_obj():
            template = fill_typevars_with_any(expr_type.type_object())
            class_dict = self.collect_attrs(get_instance_fallback(template))
        else:
            class_dict = {}
        # We don't use JSON dump to be sure keys order is always preserved.
        base_attrs = []
        if mod_dict:
            for mod in mod_dict:
                base_attrs.append(f'{mod}: [{", ".join(mod_dict[mod])}]')
        self._fill_from_dict(base_attrs, class_dict)
        self._fill_from_dict(base_attrs, attrs_dict)
        return self.add_prefixes(f'{{{", ".join(base_attrs)}}}', expression), True

    def format_node(self, module: State, node: FuncBase | SymbolNode) -> str:
        """Format a definition location as path:line:column:name (1-based column)."""
        return f"{module.path}:{node.line}:{node.column + 1}:{node.name}"

    def collect_nodes(self, expression: RefExpr) -> list[FuncBase | SymbolNode]:
        """Collect nodes that can be referred to by an expression.

        Note: it can be more than one for example in case of a union attribute.
        """
        node: FuncBase | SymbolNode | None = expression.node
        nodes: list[FuncBase | SymbolNode]
        if node is None:
            # Tricky case: instance attribute
            if isinstance(expression, MemberExpr) and expression.kind is None:
                base_type = self.fg_manager.manager.all_types.get(expression.expr)
                if base_type is None:
                    return []
                # Now we use the base type to figure out where the attribute is defined.
                base_type = get_proper_type(base_type)
                instances = get_instance_fallback(base_type)
                nodes = []
                for instance in instances:
                    node = find_node(expression.name, instance.type)
                    if node:
                        nodes.append(node)
                if not nodes:
                    # Try checking class namespace if attribute is on a class object.
                    if isinstance(base_type, FunctionLike) and base_type.is_type_obj():
                        instances = get_instance_fallback(
                            fill_typevars_with_any(base_type.type_object())
                        )
                        for instance in instances:
                            node = find_node(expression.name, instance.type)
                            if node:
                                nodes.append(node)
                    else:
                        # Still no luck, give up.
                        return []
            else:
                return []
        else:
            # Easy case: a module-level definition
            nodes = [node]
        return nodes

    def modules_for_nodes(
        self, nodes: list[FuncBase | SymbolNode], expression: RefExpr
    ) -> tuple[dict[FuncBase | SymbolNode, State], bool]:
        """Gather modules where given nodes were defined.

        Also check if they need to be refreshed (cached nodes may have
        lines/columns missing).
        """
        modules = {}
        reload_needed = False
        for node in nodes:
            module = find_module_by_fullname(node.fullname, self.fg_manager.graph)
            if not module:
                # Locals have no importable fullname; fall back to the module
                # the inspection was requested for.
                if expression.kind == LDEF and self.module:
                    module = self.module
                else:
                    continue
            modules[node] = module
            if not module.tree or module.tree.is_cache_skeleton or self.force_reload:
                reload_needed |= not module.tree or module.tree.is_cache_skeleton
                self.reload_module(module)
        return modules, reload_needed

    def expression_def(self, expression: Expression) -> tuple[str, bool]:
        """Find and format definition location for an expression.

        If it is not a RefExpr, it is effectively skipped by returning an
        empty result.
        """
        if not isinstance(expression, RefExpr):
            # If there are no suitable matches at all, we return error later.
            return "", True
        nodes = self.collect_nodes(expression)
        if not nodes:
            return self.missing_node(expression), False
        modules, reload_needed = self.modules_for_nodes(nodes, expression)
        if reload_needed:
            # TODO: line/column are not stored in cache for vast majority of symbol nodes.
            # Adding them will make things faster, but will have visible memory impact.
            nodes = self.collect_nodes(expression)
            modules, reload_needed = self.modules_for_nodes(nodes, expression)
            assert not reload_needed
        result = []
        for node in modules:
            result.append(self.format_node(modules[node], node))
        if not result:
            return self.missing_node(expression), False
        return self.add_prefixes(", ".join(result), expression), True

    def missing_type(self, expression: Expression) -> str:
        """Error text for an expression whose type was not exported."""
        alt_suggestion = ""
        if not self.force_reload:
            alt_suggestion = " or try --force-reload"
        return (
            f'No known type available for "{type(expression).__name__}"'
            f" (maybe unreachable{alt_suggestion})"
        )

    def missing_node(self, expression: Expression) -> str:
        """Error text when no definition could be found for an expression."""
        return (
            f'Cannot find definition for "{type(expression).__name__}" at {expr_span(expression)}'
        )

    def add_prefixes(self, result: str, expression: Expression) -> str:
        """Optionally prepend expression kind and/or span per engine options."""
        prefixes = []
        if self.include_kind:
            prefixes.append(f"{type(expression).__name__}")
        if self.include_span:
            prefixes.append(expr_span(expression))
        if prefixes:
            prefix = ":".join(prefixes) + " -> "
        else:
            prefix = ""
        return prefix + result

    def run_inspection_by_exact_location(
        self,
        tree: MypyFile,
        line: int,
        column: int,
        end_line: int,
        end_column: int,
        method: Callable[[Expression], tuple[str, bool]],
    ) -> dict[str, object]:
        """Get type of an expression matching a span.

        Type or error is returned as a standard daemon response dict.
        """
        try:
            # User-facing columns are 1-based; internal node columns are 0-based.
            expression = find_by_location(tree, line, column - 1, end_line, end_column)
        except ValueError as err:
            return {"error": str(err)}
        if expression is None:
            span = f"{line}:{column}:{end_line}:{end_column}"
            return {"out": f"Can't find expression at span {span}", "err": "", "status": 1}
        inspection_str, success = method(expression)
        return {"out": inspection_str, "err": "", "status": 0 if success else 1}

    def run_inspection_by_position(
        self,
        tree: MypyFile,
        line: int,
        column: int,
        method: Callable[[Expression], tuple[str, bool]],
    ) -> dict[str, object]:
        """Get types of all expressions enclosing a position.

        Types and/or errors are returned as a standard daemon response dict.
        """
        expressions = find_all_by_location(tree, line, column - 1)
        if not expressions:
            position = f"{line}:{column}"
            return {
                "out": f"Can't find any expressions at position {position}",
                "err": "",
                "status": 1,
            }
        inspection_strs = []
        status = 0
        for expression in expressions:
            inspection_str, success = method(expression)
            if not success:
                status = 1
            if inspection_str:
                inspection_strs.append(inspection_str)
        if self.limit:
            # Keep only the innermost `limit` results.
            inspection_strs = inspection_strs[: self.limit]
        return {"out": "\n".join(inspection_strs), "err": "", "status": status}

    def find_module(self, file: str) -> tuple[State | None, dict[str, object]]:
        """Find module by path, or return a suitable error message.

        Note we don't use exceptions to simplify handling 1 vs 2 statuses.
        """
        if not any(file.endswith(ext) for ext in PYTHON_EXTENSIONS):
            return None, {"error": "Source file is not a Python file"}
        # We are using a bit slower but robust way to find a module by path,
        # to be sure that namespace packages are handled properly.
        abs_path = os.path.abspath(file)
        state = next((s for s in self.fg_manager.graph.values() if s.abspath == abs_path), None)
        self.module = state
        return (
            state,
            {"out": f"Unknown module: {file}", "err": "", "status": 1} if state is None else {},
        )

    def run_inspection(
        self, location: str, method: Callable[[Expression], tuple[str, bool]]
    ) -> dict[str, object]:
        """Top-level logic to inspect expression(s) at a location.

        This can be reused by various simple inspections.
        """
        try:
            file, pos = parse_location(location)
        except ValueError as err:
            return {"error": str(err)}
        state, err_dict = self.find_module(file)
        if state is None:
            assert err_dict
            return err_dict
        # Force reloading to load from cache, account for any edits, etc.
        if not state.tree or state.tree.is_cache_skeleton or self.force_reload:
            self.reload_module(state)
        assert state.tree is not None
        if len(pos) == 4:
            # Full span, return an exact match only.
            line, column, end_line, end_column = pos
            return self.run_inspection_by_exact_location(
                state.tree, line, column, end_line, end_column, method
            )
        assert len(pos) == 2
        # Inexact location, return all expressions.
        line, column = pos
        return self.run_inspection_by_position(state.tree, line, column, method)

    def get_type(self, location: str) -> dict[str, object]:
        """Get types of expression(s) at a location."""
        return self.run_inspection(location, self.expr_type)

    def get_attrs(self, location: str) -> dict[str, object]:
        """Get attributes of expression(s) at a location."""
        return self.run_inspection(location, self.expr_attrs)

    def get_definition(self, location: str) -> dict[str, object]:
        """Get symbol definitions of expression(s) at a location."""
        result = self.run_inspection(location, self.expression_def)
        if "out" in result and not result["out"]:
            # None of the expressions found turns out to be a RefExpr.
            _, location = location.split(":", maxsplit=1)
            result["out"] = f"No name or member expressions at {location}"
            result["status"] = 1
        return result
def parse_location(location: str) -> tuple[str, list[int]]:
    """Split a "file:line:column[:end_line:end_column]" string.

    Returns the file part plus the list of parsed integers.  A Windows
    drive prefix such as ``C:`` in the file part is tolerated.
    """
    if location.count(":") < 2:
        raise ValueError("Format should be file:line:column[:end_line:end_column]")
    # First try interpreting only the last two components as line:column.
    prefix, *coords = location.rsplit(":", maxsplit=2)
    if prefix.count(":") < 2:
        return prefix, [int(p) for p in coords]
    # Otherwise assume a full span: peel off two more trailing components.
    prefix, *more = prefix.rsplit(":", maxsplit=2)
    if prefix.count(":") < 2:
        return prefix, [int(p) for p in more + coords]
    raise ValueError("Format should be file:line:column[:end_line:end_column]")

View file

@ -0,0 +1,457 @@
"""Cross-platform abstractions for inter-process communication
On Unix, this uses AF_UNIX sockets.
On Windows, this uses NamedPipes.
"""
from __future__ import annotations
import json
import os
import shutil
import struct
import sys
import tempfile
from abc import abstractmethod
from collections.abc import Callable
from select import select
from types import TracebackType
from typing import Final
from typing_extensions import Self
from librt.base64 import urlsafe_b64encode
from librt.internal import ReadBuffer, WriteBuffer
if sys.platform == "win32":
# This may be private, but it is needed for IPC on Windows, and is basically stable
import _winapi
import ctypes
_IPCHandle = int
kernel32 = ctypes.windll.kernel32
DisconnectNamedPipe: Callable[[_IPCHandle], int] = kernel32.DisconnectNamedPipe
FlushFileBuffers: Callable[[_IPCHandle], int] = kernel32.FlushFileBuffers
else:
import socket
_IPCHandle = socket.socket
# Size of the message packed as !L, i.e. 4 bytes in network order (big-endian).
HEADER_SIZE = 4
# TODO: we should make sure consistent exceptions are raised on different platforms.
# Currently we raise either IPCException or OSError for equivalent conditions.
class IPCException(Exception):
    """Exception for IPC issues (timeouts, bad handles, broken pipes, etc.)."""
class IPCBase:
    """Base class for communication between the dmypy client and server.

    This contains logic shared between the client and server, such as reading
    and writing.
    We want to be able to send multiple "messages" over a single connection and
    to be able to separate the messages. We do this by prefixing each message
    with its size in a fixed format.
    """

    connection: _IPCHandle

    def __init__(self, name: str, timeout: float | None) -> None:
        self.name = name
        self.timeout = timeout
        # Length of the frame currently being assembled (None between frames).
        self.message_size: int | None = None
        # Raw bytes received but not yet consumed as complete frames.
        self.buffer = bytearray()

    def frame_from_buffer(self) -> bytes | None:
        """Return a full frame from the bytes we have in the buffer."""
        size = len(self.buffer)
        if size < HEADER_SIZE:
            return None
        if self.message_size is None:
            # Header is a 4-byte big-endian unsigned length (see HEADER_SIZE).
            self.message_size = struct.unpack("!L", self.buffer[:HEADER_SIZE])[0]
        if size < self.message_size + HEADER_SIZE:
            return None
        # We have a full frame, avoid extra copy in case we get a large frame.
        bdata = memoryview(self.buffer)[HEADER_SIZE : HEADER_SIZE + self.message_size]
        self.buffer = self.buffer[HEADER_SIZE + self.message_size :]
        self.message_size = None
        return bytes(bdata)

    def read(self, size: int = 100000) -> str:
        """Read one frame and decode it as UTF-8."""
        return self.read_bytes(size).decode("utf-8")

    def read_bytes(self, size: int = 100000) -> bytes:
        """Read bytes from an IPC connection until we have a full frame."""
        if sys.platform == "win32":
            while True:
                # Check if we already have a message in the buffer before
                # receiving any more data from the socket.
                bdata = self.frame_from_buffer()
                if bdata is not None:
                    break
                # Receive more data into the buffer.
                ov, err = _winapi.ReadFile(self.connection, size, overlapped=True)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                except BaseException:
                    # Cancel the overlapped read before propagating.
                    ov.cancel()
                    raise
                _, err = ov.GetOverlappedResult(True)
                more = ov.getbuffer()
                if more:
                    self.buffer.extend(more)
                    bdata = self.frame_from_buffer()
                    if bdata is not None:
                        break
                if err == 0:
                    # we are done!
                    break
                elif err == _winapi.ERROR_MORE_DATA:
                    # read again
                    continue
                elif err == _winapi.ERROR_OPERATION_ABORTED:
                    raise IPCException("ReadFile operation aborted.")
        else:
            while True:
                # Check if we already have a message in the buffer before
                # receiving any more data from the socket.
                bdata = self.frame_from_buffer()
                if bdata is not None:
                    break
                # Receive more data into the buffer.
                more = self.connection.recv(size)
                if not more:
                    # Connection closed
                    break
                self.buffer.extend(more)
        if not bdata:
            # Socket was empty, and we didn't get any frame.
            # This should only happen if the socket was closed.
            return b""
        return bdata

    def write(self, data: str) -> None:
        """Encode a string as UTF-8 and write it as one frame."""
        self.write_bytes(data.encode("utf-8"))

    def write_bytes(self, data: bytes) -> None:
        """Write to an IPC connection."""
        # Frame the data by adding fixed size header.
        encoded_data = struct.pack("!L", len(data)) + data
        if sys.platform == "win32":
            try:
                ov, err = _winapi.WriteFile(self.connection, encoded_data, overlapped=True)
                try:
                    if err == _winapi.ERROR_IO_PENDING:
                        timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                        res = _winapi.WaitForSingleObject(ov.event, timeout)
                        if res != _winapi.WAIT_OBJECT_0:
                            raise IPCException(f"Bad result from I/O wait: {res}")
                    elif err != 0:
                        raise IPCException(f"Failed writing to pipe with error: {err}")
                except BaseException:
                    ov.cancel()
                    raise
                bytes_written, err = ov.GetOverlappedResult(True)
                assert err == 0, err
                assert bytes_written == len(encoded_data)
            except OSError as e:
                raise IPCException(f"Failed to write with error: {e.winerror}") from e
        else:
            self.connection.sendall(encoded_data)

    def close(self) -> None:
        """Close the underlying pipe handle / socket."""
        if sys.platform == "win32":
            if self.connection != _winapi.NULL:
                _winapi.CloseHandle(self.connection)
        else:
            self.connection.close()
class IPCClient(IPCBase):
    """The client side of an IPC connection."""

    def __init__(self, name: str, timeout: float | None) -> None:
        super().__init__(name, timeout)
        if sys.platform == "win32":
            # Wait for the named pipe to accept a connection (ms granularity).
            timeout = int(self.timeout * 1000) if self.timeout else _winapi.NMPWAIT_WAIT_FOREVER
            try:
                _winapi.WaitNamedPipe(self.name, timeout)
            except FileNotFoundError as e:
                raise IPCException(f"The NamedPipe at {self.name} was not found.") from e
            except OSError as e:
                if e.winerror == _winapi.ERROR_SEM_TIMEOUT:
                    raise IPCException("Timed out waiting for connection.") from e
                else:
                    raise
            try:
                self.connection = _winapi.CreateFile(
                    self.name,
                    _winapi.GENERIC_READ | _winapi.GENERIC_WRITE,
                    0,
                    _winapi.NULL,
                    _winapi.OPEN_EXISTING,
                    _winapi.FILE_FLAG_OVERLAPPED,
                    _winapi.NULL,
                )
            except OSError as e:
                if e.winerror == _winapi.ERROR_PIPE_BUSY:
                    raise IPCException("The connection is busy.") from e
                else:
                    raise
            _winapi.SetNamedPipeHandleState(
                self.connection, _winapi.PIPE_READMODE_MESSAGE, None, None
            )
        else:
            # Unix: connect to the server's AF_UNIX socket.
            self.connection = socket.socket(socket.AF_UNIX)
            self.connection.settimeout(timeout)
            self.connection.connect(name)

    def __enter__(self) -> IPCClient:
        return self

    def __exit__(
        self,
        exc_ty: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: TracebackType | None = None,
    ) -> None:
        self.close()
class IPCServer(IPCBase):
    """The server side of an IPC connection (one client at a time)."""

    BUFFER_SIZE: Final = 2**16

    def __init__(self, name: str, timeout: float | None = None) -> None:
        if sys.platform == "win32":
            # Randomize the pipe name so concurrent servers don't collide.
            name = r"\\.\pipe\{}-{}.pipe".format(name, urlsafe_b64encode(os.urandom(6)).decode())
        else:
            name = f"{name}.sock"
        super().__init__(name, timeout)
        if sys.platform == "win32":
            self.connection = _winapi.CreateNamedPipe(
                self.name,
                _winapi.PIPE_ACCESS_DUPLEX
                | _winapi.FILE_FLAG_FIRST_PIPE_INSTANCE
                | _winapi.FILE_FLAG_OVERLAPPED,
                _winapi.PIPE_READMODE_MESSAGE
                | _winapi.PIPE_TYPE_MESSAGE
                | _winapi.PIPE_WAIT
                | 0x8,  # PIPE_REJECT_REMOTE_CLIENTS
                1,  # one instance
                self.BUFFER_SIZE,
                self.BUFFER_SIZE,
                _winapi.NMPWAIT_WAIT_FOREVER,
                0,  # Use default security descriptor
            )
            if self.connection == -1:  # INVALID_HANDLE_VALUE
                err = _winapi.GetLastError()
                raise IPCException(f"Invalid handle to pipe: {err}")
        else:
            # Bind the socket inside a fresh private directory.
            self.sock_directory = tempfile.mkdtemp()
            sockfile = os.path.join(self.sock_directory, self.name)
            self.sock = socket.socket(socket.AF_UNIX)
            self.sock.bind(sockfile)
            self.sock.listen(1)
            if timeout is not None:
                self.sock.settimeout(timeout)

    def __enter__(self) -> IPCServer:
        if sys.platform == "win32":
            # NOTE: It is theoretically possible that this will hang forever if the
            # client never connects, though this can be "solved" by killing the server
            try:
                ov = _winapi.ConnectNamedPipe(self.connection, overlapped=True)
            except OSError as e:
                # Don't raise if the client already exists, or the client already connected
                if e.winerror not in (_winapi.ERROR_PIPE_CONNECTED, _winapi.ERROR_NO_DATA):
                    raise
            else:
                try:
                    timeout = int(self.timeout * 1000) if self.timeout else _winapi.INFINITE
                    res = _winapi.WaitForSingleObject(ov.event, timeout)
                    assert res == _winapi.WAIT_OBJECT_0
                except BaseException:
                    ov.cancel()
                    _winapi.CloseHandle(self.connection)
                    raise
                _, err = ov.GetOverlappedResult(True)
                assert err == 0
        else:
            try:
                self.connection, _ = self.sock.accept()
            except TimeoutError as e:
                raise IPCException("The socket timed out") from e
        return self

    def __exit__(
        self,
        exc_ty: type[BaseException] | None = None,
        exc_val: BaseException | None = None,
        exc_tb: TracebackType | None = None,
    ) -> None:
        if sys.platform == "win32":
            try:
                # Wait for the client to finish reading the last write before disconnecting
                if not FlushFileBuffers(self.connection):
                    raise IPCException(
                        "Failed to flush NamedPipe buffer, maybe the client hung up?"
                    )
            finally:
                DisconnectNamedPipe(self.connection)
        else:
            self.close()

    def cleanup(self) -> None:
        """Release OS resources: the pipe handle (win32) or socket dir (Unix)."""
        if sys.platform == "win32":
            self.close()
        else:
            shutil.rmtree(self.sock_directory)

    @property
    def connection_name(self) -> str:
        """Name clients should use to connect to this server."""
        if sys.platform == "win32":
            return self.name
        elif sys.platform == "gnu0":
            # GNU/Hurd returns empty string from getsockname()
            # for AF_UNIX sockets
            return os.path.join(self.sock_directory, self.name)
        else:
            name = self.sock.getsockname()
            assert isinstance(name, str)
            return name
class BadStatus(Exception):
    """Exception raised when there is something wrong with the status file.

    For example:
    - No status file found
    - Status file malformed
    - Process whose pid is in the status file does not exist
    """
def read_status(status_file: str) -> dict[str, object]:
    """Parse the daemon status file into a dict.

    Raise BadStatus if the status file doesn't exist or contains
    invalid JSON or the JSON is not a dict.
    """
    if not os.path.isfile(status_file):
        raise BadStatus("No status file found")
    with open(status_file) as f:
        raw = f.read()
    try:
        data = json.loads(raw)
    except Exception as e:
        raise BadStatus(f"Malformed status file: {str(e)}") from e
    if not isinstance(data, dict):
        raise BadStatus(f"Invalid status file (not a dict): {data}")
    return data
def ready_to_read(conns: list[IPCClient], timeout: float | None = None) -> list[int]:
    """Wait until some connections are readable.

    Return index of each readable connection in the original list.
    """
    if sys.platform == "win32":
        # Windows doesn't support select() on named pipes. Instead, start an overlapped
        # ReadFile on each pipe (which internally creates an event via CreateEventW),
        # then WaitForMultipleObjects on those events for efficient OS-level waiting.
        # Any data consumed by the probe reads is stored into each connection's buffer
        # so the subsequent read_bytes() call will find it via frame_from_buffer().
        WAIT_FAILED = 0xFFFFFFFF
        pending: list[tuple[int, _winapi.Overlapped]] = []
        events: list[int] = []
        ready: list[int] = []
        for i, conn in enumerate(conns):
            try:
                # Probe-read one byte asynchronously; may complete immediately.
                ov, err = _winapi.ReadFile(conn.connection, 1, overlapped=True)
            except OSError:
                # Broken/closed pipe. Mimic Linux behavior here, caller will get
                # the exception when trying to read from this socket.
                ready.append(i)
                continue
            if err == _winapi.ERROR_IO_PENDING:
                events.append(ov.event)
                pending.append((i, ov))
            else:
                # Data was immediately available (err == 0 or ERROR_MORE_DATA)
                _, err = ov.GetOverlappedResult(True)
                data = ov.getbuffer()
                if data:
                    conn.buffer.extend(data)
                ready.append(i)
        # Wait only if nothing is immediately ready and there are pending operations
        if not ready and events:
            timeout_ms = int(timeout * 1000) if timeout is not None else _winapi.INFINITE
            res = _winapi.WaitForMultipleObjects(events, False, timeout_ms)
            if res == WAIT_FAILED:
                # Cancel all outstanding probe reads before reporting the failure.
                for _, ov in pending:
                    ov.cancel()
                raise IPCException(f"Failed to wait for connections: {_winapi.GetLastError()}")
        # Check which pending operations completed, cancel the rest
        for i, ov in pending:
            if _winapi.WaitForSingleObject(ov.event, 0) == _winapi.WAIT_OBJECT_0:
                _, err = ov.GetOverlappedResult(True)
                data = ov.getbuffer()
                if data:
                    conns[i].buffer.extend(data)
                ready.append(i)
            else:
                ov.cancel()
        return ready
    else:
        # POSIX: plain select() on the underlying socket objects.
        connections = [conn.connection for conn in conns]
        ready, _, _ = select(connections, [], [], timeout)
        return [connections.index(r) for r in ready]
def send(connection: IPCBase, data: IPCMessage) -> None:
    """Serialize a single IPCMessage and write it as one frame.

    The data must be a non-abstract IPCMessage. We assume that a single send
    call is a single frame to be sent.
    """
    frame = WriteBuffer()
    data.write(frame)
    connection.write_bytes(frame.getvalue())
def receive(connection: IPCBase) -> ReadBuffer:
    """Read a single encoded IPCMessage frame from a connection.

    Raise OSError if the data received is not valid.
    """
    raw = connection.read_bytes()
    if not raw:
        raise OSError("No data received")
    return ReadBuffer(raw)
class IPCMessage:
    """Abstract base class for messages exchanged over the IPC channel.

    Subclasses implement symmetric (de)serialization: read() reconstructs an
    instance from a buffer, write() appends the encoded form to a buffer.
    """

    @classmethod
    @abstractmethod
    def read(cls, buf: ReadBuffer) -> Self:
        raise NotImplementedError

    @abstractmethod
    def write(self, buf: WriteBuffer) -> None:
        raise NotImplementedError

View file

@ -0,0 +1,916 @@
"""Calculation of the least upper bound types (joins)."""
from __future__ import annotations
from collections.abc import Sequence
from typing import overload
import mypy.typeops
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo
from mypy.state import state
from mypy.subtypes import (
SubtypeContext,
find_member,
is_equivalent,
is_proper_subtype,
is_protocol_implementation,
is_subtype,
)
from mypy.types import (
AnyType,
CallableType,
DeletedType,
ErasedType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Overloaded,
Parameters,
ParamSpecType,
PartialType,
ProperType,
TupleType,
Type,
TypeAliasType,
TypedDictType,
TypeOfAny,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
UnboundType,
UninhabitedType,
UnionType,
UnpackType,
find_unpack_in_list,
get_proper_type,
get_proper_types,
split_with_prefix_and_suffix,
)
class InstanceJoiner:
    """Computes joins of Instance types while guarding against infinite recursion.

    Pairs of instances currently being joined are tracked in ``seen_instances``;
    revisiting a pair (in either order) breaks the cycle by returning ``object``.
    """

    def __init__(self) -> None:
        # Stack of (t, s) pairs currently being joined, for cycle detection.
        self.seen_instances: list[tuple[Instance, Instance]] = []

    def join_instances(self, t: Instance, s: Instance) -> ProperType:
        """Return the least upper bound of two instance types."""
        if (t, s) in self.seen_instances or (s, t) in self.seen_instances:
            # Recursive join of the same pair: fall back to object.
            return object_from_instance(t)

        self.seen_instances.append((t, s))

        # Calculate the join of two instance types
        if t.type == s.type:
            # Simplest case: join two types with the same base type (but
            # potentially different arguments).

            # Combine type arguments.
            args: list[Type] = []
            # N.B: We use zip instead of indexing because the lengths might have
            # mismatches during daemon reprocessing.
            if t.type.has_type_var_tuple_type:
                # We handle joins of variadic instances by simply creating correct mapping
                # for type arguments and compute the individual joins same as for regular
                # instances. All the heavy lifting is done in the join of tuple types.
                assert s.type.type_var_tuple_prefix is not None
                assert s.type.type_var_tuple_suffix is not None
                prefix = s.type.type_var_tuple_prefix
                suffix = s.type.type_var_tuple_suffix
                tvt = s.type.defn.type_vars[prefix]
                assert isinstance(tvt, TypeVarTupleType)
                fallback = tvt.tuple_fallback
                s_prefix, s_middle, s_suffix = split_with_prefix_and_suffix(s.args, prefix, suffix)
                t_prefix, t_middle, t_suffix = split_with_prefix_and_suffix(t.args, prefix, suffix)
                s_args = s_prefix + (TupleType(list(s_middle), fallback),) + s_suffix
                t_args = t_prefix + (TupleType(list(t_middle), fallback),) + t_suffix
            else:
                t_args = t.args
                s_args = s.args
            for ta, sa, type_var in zip(t_args, s_args, t.type.defn.type_vars):
                ta_proper = get_proper_type(ta)
                sa_proper = get_proper_type(sa)
                new_type: Type | None = None
                if isinstance(ta_proper, AnyType):
                    new_type = AnyType(TypeOfAny.from_another_any, ta_proper)
                elif isinstance(sa_proper, AnyType):
                    new_type = AnyType(TypeOfAny.from_another_any, sa_proper)
                elif isinstance(type_var, TypeVarType):
                    if type_var.variance in (COVARIANT, VARIANCE_NOT_READY):
                        new_type = join_types(ta, sa, self)
                        # A joined value outside the TypeVar's declared values or
                        # bound cannot be represented; give up and use object.
                        if len(type_var.values) != 0 and new_type not in type_var.values:
                            self.seen_instances.pop()
                            return object_from_instance(t)
                        if not is_subtype(new_type, type_var.upper_bound):
                            self.seen_instances.pop()
                            return object_from_instance(t)
                    # TODO: contravariant case should use meet but pass seen instances as
                    # an argument to keep track of recursive checks.
                    elif type_var.variance in (INVARIANT, CONTRAVARIANT):
                        if isinstance(ta_proper, UninhabitedType) and ta_proper.ambiguous:
                            new_type = sa
                        elif isinstance(sa_proper, UninhabitedType) and sa_proper.ambiguous:
                            new_type = ta
                        elif not is_equivalent(ta, sa):
                            self.seen_instances.pop()
                            return object_from_instance(t)
                        else:
                            # If the types are different but equivalent, then an Any is involved
                            # so using a join in the contravariant case is also OK.
                            new_type = join_types(ta, sa, self)
                elif isinstance(type_var, TypeVarTupleType):
                    new_type = get_proper_type(join_types(ta, sa, self))
                    # Put the joined arguments back into instance in the normal form:
                    #   a) Tuple[X, Y, Z] -> [X, Y, Z]
                    #   b) tuple[X, ...] -> [*tuple[X, ...]]
                    if isinstance(new_type, Instance):
                        assert new_type.type.fullname == "builtins.tuple"
                        new_type = UnpackType(new_type)
                    else:
                        assert isinstance(new_type, TupleType)
                        args.extend(new_type.items)
                        continue
                else:
                    # ParamSpec type variables behave the same, independent of variance
                    if not is_equivalent(ta, sa):
                        return get_proper_type(type_var.upper_bound)
                    new_type = join_types(ta, sa, self)
                assert new_type is not None
                args.append(new_type)
            result: ProperType = Instance(t.type, args)
        elif t.type.bases and is_proper_subtype(
            t, s, subtype_context=SubtypeContext(ignore_type_params=True)
        ):
            result = self.join_instances_via_supertype(t, s)
        else:
            # Now t is not a subtype of s, and t != s. Now s could be a subtype
            # of t; alternatively, we need to find a common supertype. This works
            # in of the both cases.
            result = self.join_instances_via_supertype(s, t)
        self.seen_instances.pop()
        return result

    def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
        """Join t and s by lifting t to one of its supertypes or promotions."""
        # Give preference to joins via duck typing relationship, so that
        # join(int, float) == float, for example.
        for p in t.type._promote:
            if is_subtype(p, s):
                return join_types(p, s, self)
        for p in s.type._promote:
            if is_subtype(p, t):
                return join_types(t, p, self)

        # Compute the "best" supertype of t when joined with s.
        # The definition of "best" may evolve; for now it is the one with
        # the longest MRO.  Ties are broken by using the earlier base.

        # Go over both sets of bases in case there's an explicit Protocol base. This is important
        # to ensure commutativity of join (although in cases where both classes have relevant
        # Protocol bases this maybe might still not be commutative)
        base_types: dict[TypeInfo, None] = {}  # dict to deduplicate but preserve order
        for base in t.type.bases:
            base_types[base.type] = None
        for base in s.type.bases:
            if base.type.is_protocol and is_subtype(t, base):
                base_types[base.type] = None

        best: ProperType | None = None
        for base_type in base_types:
            mapped = map_instance_to_supertype(t, base_type)
            res = self.join_instances(mapped, s)
            if best is None or is_better(res, best):
                best = res
        assert best is not None
        for promote in t.type._promote:
            if isinstance(promote, Instance):
                res = self.join_instances(promote, s)
                if is_better(res, best):
                    best = res
        return best
def trivial_join(s: Type, t: Type) -> Type:
    """Return one of types (expanded) if it is a supertype of other, otherwise top type."""
    if is_subtype(s, t):
        return t
    if is_subtype(t, s):
        return s
    # Neither is a supertype of the other: fall back to object (or Any).
    return object_or_any_from_type(get_proper_type(t))
@overload
def join_types(
    s: ProperType, t: ProperType, instance_joiner: InstanceJoiner | None = None
) -> ProperType: ...
@overload
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type: ...
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
    """Return the least upper bound of s and t.

    For example, the join of 'int' and 'object' is 'object'.
    """
    if mypy.typeops.is_recursive_pair(s, t):
        # This case can trigger an infinite recursion, general support for this will be
        # tricky so we use a trivial join (like for protocols).
        return trivial_join(s, t)
    s = get_proper_type(s)
    t = get_proper_type(t)

    if (s.can_be_true, s.can_be_false) != (t.can_be_true, t.can_be_false):
        # if types are restricted in different ways, use the more general versions
        s = mypy.typeops.true_or_false(s)
        t = mypy.typeops.true_or_false(t)

    # The swaps below normalize the operands so the "special" one ends up in s;
    # the visitor (driven by t) then handles each special case only once.
    if isinstance(s, UnionType) and not isinstance(t, UnionType):
        s, t = t, s

    if isinstance(s, AnyType):
        return s

    if isinstance(s, ErasedType):
        return t

    if isinstance(s, NoneType) and not isinstance(t, NoneType):
        s, t = t, s

    if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
        s, t = t, s

    # Meets/joins require callable type normalization.
    s, t = normalize_callables(s, t)

    # Use a visitor to handle non-trivial cases.
    return t.accept(TypeJoinVisitor(s, instance_joiner))
class TypeJoinVisitor(TypeVisitor[ProperType]):
    """Implementation of the least upper bound algorithm.

    Attributes:
      s: The other (left) type operand.
    """

    def __init__(self, s: ProperType, instance_joiner: InstanceJoiner | None = None) -> None:
        self.s = s
        self.instance_joiner = instance_joiner

    def visit_unbound_type(self, t: UnboundType) -> ProperType:
        return AnyType(TypeOfAny.special_form)

    def visit_union_type(self, t: UnionType) -> ProperType:
        if is_proper_subtype(self.s, t):
            return t
        else:
            return mypy.typeops.make_simplified_union([self.s, t])

    def visit_any(self, t: AnyType) -> ProperType:
        return t

    def visit_none_type(self, t: NoneType) -> ProperType:
        if state.strict_optional:
            if isinstance(self.s, (NoneType, UninhabitedType)):
                return t
            elif isinstance(self.s, (UnboundType, AnyType)):
                return AnyType(TypeOfAny.special_form)
            else:
                return mypy.typeops.make_simplified_union([self.s, t])
        else:
            return self.s

    def visit_uninhabited_type(self, t: UninhabitedType) -> ProperType:
        return self.s

    def visit_deleted_type(self, t: DeletedType) -> ProperType:
        return self.s

    def visit_erased_type(self, t: ErasedType) -> ProperType:
        return self.s

    def visit_type_var(self, t: TypeVarType) -> ProperType:
        if isinstance(self.s, TypeVarType):
            if self.s.id == t.id:
                if self.s.upper_bound == t.upper_bound:
                    return self.s
                return self.s.copy_modified(
                    upper_bound=join_types(self.s.upper_bound, t.upper_bound)
                )
            # Fix non-commutative joins
            return get_proper_type(join_types(self.s.upper_bound, t.upper_bound))
        else:
            return self.default(self.s)

    def visit_param_spec(self, t: ParamSpecType) -> ProperType:
        if self.s == t:
            return t
        return self.default(self.s)

    def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType:
        if self.s == t:
            return t
        if isinstance(self.s, Instance) and is_subtype(t.upper_bound, self.s):
            # TODO: should we do this more generally and for all TypeVarLikeTypes?
            return self.s
        return self.default(self.s)

    def visit_unpack_type(self, t: UnpackType) -> UnpackType:
        # Unpack items are only joined as part of tuple joins (see join_tuples).
        raise NotImplementedError

    def visit_parameters(self, t: Parameters) -> ProperType:
        if isinstance(self.s, Parameters):
            if not is_similar_params(t, self.s):
                # TODO: it would be prudent to return [*object, **object] instead of Any.
                return self.default(self.s)
            from mypy.meet import meet_types

            # Argument types are contravariant, hence combined with a meet.
            return t.copy_modified(
                arg_types=[
                    meet_types(s_a, t_a) for s_a, t_a in zip(self.s.arg_types, t.arg_types)
                ],
                arg_names=combine_arg_names(self.s, t),
            )
        else:
            return self.default(self.s)

    def visit_instance(self, t: Instance) -> ProperType:
        if isinstance(self.s, Instance):
            if self.instance_joiner is None:
                self.instance_joiner = InstanceJoiner()
            nominal = self.instance_joiner.join_instances(t, self.s)
            structural: Instance | None = None
            if t.type.is_protocol and is_protocol_implementation(self.s, t):
                structural = t
            elif self.s.type.is_protocol and is_protocol_implementation(t, self.s):
                structural = self.s
            # Structural join is preferred in the case where we have found both
            # structural and nominal and they have same MRO length (see two comments
            # in join_instances_via_supertype). Otherwise, just return the nominal join.
            if not structural or is_better(nominal, structural):
                return nominal
            return structural
        elif isinstance(self.s, FunctionLike):
            if t.type.is_protocol:
                call = unpack_callback_protocol(t)
                if call:
                    return join_types(call, self.s)
            return join_types(t, self.s.fallback)
        elif isinstance(self.s, TypeType):
            return join_types(t, self.s)
        elif isinstance(self.s, TypedDictType):
            return join_types(t, self.s)
        elif isinstance(self.s, TupleType):
            return join_types(t, self.s)
        elif isinstance(self.s, LiteralType):
            return join_types(t, self.s)
        elif isinstance(self.s, TypeVarTupleType) and is_subtype(self.s.upper_bound, t):
            return t
        else:
            return self.default(self.s)

    def visit_callable_type(self, t: CallableType) -> ProperType:
        if isinstance(self.s, CallableType):
            if is_similar_callables(t, self.s):
                if is_equivalent(t, self.s):
                    return combine_similar_callables(t, self.s)
                result = join_similar_callables(t, self.s)
                if any(
                    isinstance(tp, (NoneType, UninhabitedType))
                    for tp in get_proper_types(result.arg_types)
                ):
                    # We don't want to return unusable Callable, attempt fallback instead.
                    return join_types(t.fallback, self.s)
                # We set the from_type_type flag to suppress error when a collection of
                # concrete class objects gets inferred as their common abstract superclass.
                if not (
                    (t.is_type_obj() and t.type_object().is_abstract)
                    or (self.s.is_type_obj() and self.s.type_object().is_abstract)
                ):
                    result.from_type_type = True
                return result
            else:
                # Not similar: if one is a subtype of the other, use the supertype.
                s2, t2 = self.s, t
                if t2.is_var_arg:
                    s2, t2 = t2, s2
                if is_subtype(s2, t2):
                    return t2.copy_modified()
                elif is_subtype(t2, s2):
                    return s2.copy_modified()
                return join_types(t.fallback, self.s)
        elif isinstance(self.s, Overloaded):
            # Switch the order of arguments so that we'll get to visit_overloaded.
            return join_types(t, self.s)
        elif isinstance(self.s, Instance) and self.s.type.is_protocol:
            call = unpack_callback_protocol(self.s)
            if call:
                return join_types(t, call)
        return join_types(t.fallback, self.s)

    def visit_overloaded(self, t: Overloaded) -> ProperType:
        # This is more complex than most other cases. Here are some
        # examples that illustrate how this works.
        #
        # First let's define a concise notation:
        #  - Cn are callable types (for n in 1, 2, ...)
        #  - Ov(C1, C2, ...) is an overloaded type with items C1, C2, ...
        #  - Callable[[T, ...], S] is written as [T, ...] -> S.
        #
        # We want some basic properties to hold (assume Cn are all
        # unrelated via Any-similarity):
        #
        #   join(Ov(C1, C2), C1) == C1
        #   join(Ov(C1, C2), Ov(C1, C2)) == Ov(C1, C2)
        #   join(Ov(C1, C2), Ov(C1, C3)) == C1
        #   join(Ov(C2, C2), C3) == join of fallback types
        #
        # The presence of Any types makes things more interesting. The join is the
        # most general type we can get with respect to Any:
        #
        #   join(Ov([int] -> int, [str] -> str), [Any] -> str) == Any -> str
        #
        # We could use a simplification step that removes redundancies, but that's not
        # implemented right now. Consider this example, where we get a redundancy:
        #
        #   join(Ov([int, Any] -> Any, [str, Any] -> Any), [Any, int] -> Any) ==
        #       Ov([Any, int] -> Any, [Any, int] -> Any)
        #
        # TODO: Consider more cases of callable subtyping.
        result: list[CallableType] = []
        s = self.s
        if isinstance(s, FunctionLike):
            # The interesting case where both types are function types.
            for t_item in t.items:
                for s_item in s.items:
                    if is_similar_callables(t_item, s_item):
                        if is_equivalent(t_item, s_item):
                            result.append(combine_similar_callables(t_item, s_item))
                        elif is_subtype(t_item, s_item):
                            result.append(s_item)
            if result:
                # TODO: Simplify redundancies from the result.
                if len(result) == 1:
                    return result[0]
                else:
                    return Overloaded(result)
            return join_types(t.fallback, s.fallback)
        elif isinstance(s, Instance) and s.type.is_protocol:
            call = unpack_callback_protocol(s)
            if call:
                return join_types(t, call)
        return join_types(t.fallback, s)

    def join_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
        """Join two tuple types while handling variadic entries.

        This is surprisingly tricky, and we don't handle some tricky corner cases.
        Most of the trickiness comes from the variadic tuple items like *tuple[X, ...]
        since they can have arbitrary partial overlaps (while *Ts can't be split).
        """
        s_unpack_index = find_unpack_in_list(s.items)
        t_unpack_index = find_unpack_in_list(t.items)
        if s_unpack_index is None and t_unpack_index is None:
            # Both tuples are fixed-length: elementwise join (same arity only).
            if s.length() == t.length():
                items: list[Type] = []
                for i in range(t.length()):
                    items.append(join_types(t.items[i], s.items[i]))
                return items
            return None
        if s_unpack_index is not None and t_unpack_index is not None:
            # The most complex case: both tuples have an unpack item.
            s_unpack = s.items[s_unpack_index]
            assert isinstance(s_unpack, UnpackType)
            s_unpacked = get_proper_type(s_unpack.type)
            t_unpack = t.items[t_unpack_index]
            assert isinstance(t_unpack, UnpackType)
            t_unpacked = get_proper_type(t_unpack.type)
            if s.length() == t.length() and s_unpack_index == t_unpack_index:
                # We can handle a case where arity is perfectly aligned, e.g.
                # join(Tuple[X1, *tuple[Y1, ...], Z1], Tuple[X2, *tuple[Y2, ...], Z2]).
                # We can essentially perform the join elementwise.
                prefix_len = t_unpack_index
                suffix_len = t.length() - t_unpack_index - 1
                items = []
                for si, ti in zip(s.items[:prefix_len], t.items[:prefix_len]):
                    items.append(join_types(si, ti))
                joined = join_types(s_unpacked, t_unpacked)
                if isinstance(joined, TypeVarTupleType):
                    items.append(UnpackType(joined))
                elif isinstance(joined, Instance) and joined.type.fullname == "builtins.tuple":
                    items.append(UnpackType(joined))
                else:
                    # Join of the unpacked parts is not variadic-representable;
                    # widen the middle to *tuple[object, ...].
                    if isinstance(t_unpacked, Instance):
                        assert t_unpacked.type.fullname == "builtins.tuple"
                        tuple_instance = t_unpacked
                    else:
                        assert isinstance(t_unpacked, TypeVarTupleType)
                        tuple_instance = t_unpacked.tuple_fallback
                    items.append(
                        UnpackType(
                            tuple_instance.copy_modified(
                                args=[object_from_instance(tuple_instance)]
                            )
                        )
                    )
                if suffix_len:
                    for si, ti in zip(s.items[-suffix_len:], t.items[-suffix_len:]):
                        items.append(join_types(si, ti))
                return items
            if s.length() == 1 or t.length() == 1:
                # Another case we can handle is when one of tuple is purely variadic
                # (i.e. a non-normalized form of tuple[X, ...]), in this case the join
                # will be again purely variadic.
                if not (isinstance(s_unpacked, Instance) and isinstance(t_unpacked, Instance)):
                    return None
                assert s_unpacked.type.fullname == "builtins.tuple"
                assert t_unpacked.type.fullname == "builtins.tuple"
                mid_joined = join_types(s_unpacked.args[0], t_unpacked.args[0])
                t_other = [a for i, a in enumerate(t.items) if i != t_unpack_index]
                s_other = [a for i, a in enumerate(s.items) if i != s_unpack_index]
                other_joined = join_type_list(s_other + t_other)
                mid_joined = join_types(mid_joined, other_joined)
                return [UnpackType(s_unpacked.copy_modified(args=[mid_joined]))]
            # TODO: are there other case we can handle (e.g. both prefix/suffix are shorter)?
            return None
        if s_unpack_index is not None:
            variadic = s
            unpack_index = s_unpack_index
            fixed = t
        else:
            assert t_unpack_index is not None
            variadic = t
            unpack_index = t_unpack_index
            fixed = s
        # Case where one tuple has variadic item and the other one doesn't. The join will
        # be variadic, since fixed tuple is a subtype of variadic, but not vice versa.
        unpack = variadic.items[unpack_index]
        assert isinstance(unpack, UnpackType)
        unpacked = get_proper_type(unpack.type)
        if not isinstance(unpacked, Instance):
            return None
        if fixed.length() < variadic.length() - 1:
            # There are no non-trivial types that are supertype of both.
            return None
        prefix_len = unpack_index
        suffix_len = variadic.length() - prefix_len - 1
        prefix, middle, suffix = split_with_prefix_and_suffix(
            tuple(fixed.items), prefix_len, suffix_len
        )
        items = []
        for fi, vi in zip(prefix, variadic.items[:prefix_len]):
            items.append(join_types(fi, vi))
        mid_joined = join_type_list(list(middle))
        mid_joined = join_types(mid_joined, unpacked.args[0])
        items.append(UnpackType(unpacked.copy_modified(args=[mid_joined])))
        if suffix_len:
            for fi, vi in zip(suffix, variadic.items[-suffix_len:]):
                items.append(join_types(fi, vi))
        return items

    def visit_tuple_type(self, t: TupleType) -> ProperType:
        # When given two fixed-length tuples:
        # * If they have the same length, join their subtypes item-wise:
        #   Tuple[int, bool] + Tuple[bool, bool] becomes Tuple[int, bool]
        # * If lengths do not match, return a variadic tuple:
        #   Tuple[bool, int] + Tuple[bool] becomes Tuple[int, ...]
        #
        # Otherwise, `t` is a fixed-length tuple but `self.s` is NOT:
        # * Joining with a variadic tuple returns variadic tuple:
        #   Tuple[int, bool] + Tuple[bool, ...] becomes Tuple[int, ...]
        # * Joining with any Sequence also returns a Sequence:
        #   Tuple[int, bool] + List[bool] becomes Sequence[int]
        if isinstance(self.s, TupleType):
            if self.instance_joiner is None:
                self.instance_joiner = InstanceJoiner()
            fallback = self.instance_joiner.join_instances(
                mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t)
            )
            assert isinstance(fallback, Instance)
            items = self.join_tuples(self.s, t)
            if items is not None:
                if len(items) == 1 and isinstance(item := items[0], UnpackType):
                    if isinstance(unpacked := get_proper_type(item.type), Instance):
                        # Avoid double-wrapping tuple[*tuple[X, ...]]
                        return unpacked
                return TupleType(items, fallback)
            else:
                # TODO: should this be a default fallback behaviour like for meet?
                if is_proper_subtype(self.s, t):
                    return t
                if is_proper_subtype(t, self.s):
                    return self.s
                return fallback
        else:
            return join_types(self.s, mypy.typeops.tuple_fallback(t))

    def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
        if isinstance(self.s, TypedDictType):
            items = {
                item_name: s_item_type
                for (item_name, s_item_type, t_item_type) in self.s.zip(t)
                if (
                    is_equivalent(s_item_type, t_item_type)
                    and (item_name in t.required_keys) == (item_name in self.s.required_keys)
                )
            }
            fallback = self.s.create_anonymous_fallback()
            all_keys = set(items.keys())
            # We need to filter by items.keys() since some required keys present in both t and
            # self.s might be missing from the join if the types are incompatible.
            required_keys = all_keys & t.required_keys & self.s.required_keys
            # If one type has a key as readonly, we mark it as readonly for both.
            # Bug fix: this previously computed ``t.readonly_keys | t.readonly_keys``,
            # ignoring self.s entirely.
            readonly_keys = (t.readonly_keys | self.s.readonly_keys) & all_keys
            return TypedDictType(items, required_keys, readonly_keys, fallback)
        elif isinstance(self.s, Instance):
            return join_types(self.s, t.fallback)
        else:
            return self.default(self.s)

    def visit_literal_type(self, t: LiteralType) -> ProperType:
        if isinstance(self.s, LiteralType):
            if t == self.s:
                return t
            if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
                return mypy.typeops.make_simplified_union([self.s, t])
            return join_types(self.s.fallback, t.fallback)
        elif isinstance(self.s, Instance) and self.s.last_known_value == t:
            return t
        else:
            return join_types(self.s, t.fallback)

    def visit_partial_type(self, t: PartialType) -> ProperType:
        # We only have partial information so we can't decide the join result. We should
        # never get here.
        assert False, "Internal error"

    def visit_type_type(self, t: TypeType) -> ProperType:
        if isinstance(self.s, TypeType):
            return TypeType.make_normalized(
                join_types(t.item, self.s.item),
                line=t.line,
                is_type_form=self.s.is_type_form or t.is_type_form,
            )
        elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type":
            return self.s
        else:
            return self.default(self.s)

    def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
        assert False, f"This should be never called, got {t}"

    def default(self, typ: Type) -> ProperType:
        """Fallback join: reduce typ to its most general representative."""
        typ = get_proper_type(typ)
        if isinstance(typ, Instance):
            return object_from_instance(typ)
        elif isinstance(typ, TypeType):
            return self.default(typ.item)
        elif isinstance(typ, UnboundType):
            return AnyType(TypeOfAny.special_form)
        elif isinstance(typ, TupleType):
            return self.default(mypy.typeops.tuple_fallback(typ))
        elif isinstance(typ, TypedDictType):
            return self.default(typ.fallback)
        elif isinstance(typ, FunctionLike):
            return self.default(typ.fallback)
        elif isinstance(typ, TypeVarType):
            return self.default(typ.upper_bound)
        elif isinstance(typ, ParamSpecType):
            return self.default(typ.upper_bound)
        else:
            return AnyType(TypeOfAny.special_form)
def is_better(t: Type, s: Type) -> bool:
    """Return True if t is the better of two join_instances_via_supertype() results.

    Instances beat non-instances; between a protocol and a concrete class
    (neither being object) the concrete class wins; otherwise the longer MRO wins.
    """
    t_proper = get_proper_type(t)
    s_proper = get_proper_type(s)
    if not isinstance(t_proper, Instance):
        return False
    if not isinstance(s_proper, Instance):
        return True
    if t_proper.type.is_protocol != s_proper.type.is_protocol:
        if (
            t_proper.type.fullname != "builtins.object"
            and s_proper.type.fullname != "builtins.object"
        ):
            # mro of protocol is not really relevant
            return not t_proper.type.is_protocol
    # Use len(mro) as a proxy for the better choice.
    return len(t_proper.type.mro) > len(s_proper.type.mro)
def normalize_callables(s: ProperType, t: ProperType) -> tuple[ProperType, ProperType]:
    """Return (s, t) with callable/overloaded operands converted to the
    unpacked-kwargs form, as required by meets and joins."""
    normalized = []
    for operand in (s, t):
        if isinstance(operand, (CallableType, Overloaded)):
            operand = operand.with_unpacked_kwargs()
        normalized.append(operand)
    return normalized[0], normalized[1]
def is_similar_callables(t: CallableType, s: CallableType) -> bool:
    """Return True if t and s have identical numbers of
    arguments, default arguments and varargs.
    """
    if len(t.arg_types) != len(s.arg_types):
        return False
    if t.min_args != s.min_args:
        return False
    return t.is_var_arg == s.is_var_arg
def is_similar_params(t: Parameters, s: Parameters) -> bool:
    """Parameters analogue of is_similar_callables() above."""
    same_arity = len(t.arg_types) == len(s.arg_types)
    same_required = t.min_args == s.min_args
    same_var_arg = (t.var_arg() is not None) == (s.var_arg() is not None)
    return same_arity and same_required and same_var_arg
def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
    """Return a copy of c whose type variables are renumbered to the given ids.

    The renumbering is applied both to the declared variables and, via
    expand_type(), to every occurrence inside the callable.
    """
    fresh_vars = [tv.copy_modified(id=new_id) for tv, new_id in zip(c.variables, ids)]
    substitution = {old.id: new for old, new in zip(c.variables, fresh_vars)}
    return expand_type(c, substitution).copy_modified(variables=fresh_vars)
def match_generic_callables(t: CallableType, s: CallableType) -> tuple[CallableType, CallableType]:
    """Renumber the type variables of two generic callables into a shared namespace.

    The case where we combine/join/meet similar callables, situation where both are
    generic requires special care. A more principled solution may involve
    unify_generic_callable(), but it would have two problems:
      * This adds risk of infinite recursion: e.g. join -> unification -> solver -> join
      * Using unification is an incorrect thing for meets, as it "widens" the types
    Finally, this effectively falls back to an old behaviour before namespaces were
    added to type variables, and it worked relatively well.
    """
    if not t.variables or not s.variables:
        # At most one operand is generic; nothing to reconcile.
        return t, s
    shared_ids = [
        TypeVarId.new(meta_level=0) for _ in range(max(len(t.variables), len(s.variables)))
    ]
    # Note: this relies on variables being in order they appear in function definition.
    return update_callable_ids(t, shared_ids), update_callable_ids(s, shared_ids)
def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
    """Join two similar callables: meet the argument types (contravariant)
    and join the return type (covariant).

    Precondition: is_similar_callables(t, s).
    """
    t, s = match_generic_callables(t, s)
    arg_types: list[Type] = [
        safe_meet(t_arg, s_arg) for t_arg, s_arg in zip(t.arg_types, s.arg_types)
    ]
    # TODO in combine_similar_callables also applies here (names and kinds; user metaclasses)
    # The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operands are using it.
    fallback = t.fallback if t.fallback.type.fullname == "builtins.function" else s.fallback
    return t.copy_modified(
        arg_types=arg_types,
        arg_names=combine_arg_names(t, s),
        ret_type=join_types(t.ret_type, s.ret_type),
        fallback=fallback,
        name=None,
    )
def safe_join(t: Type, s: Type) -> Type:
    """join_types() wrapper that tolerates UnpackType operands.

    This is a temporary solution to prevent crashes in combine_similar_callables()
    etc., until relevant TODOs on handling arg_kinds will be addressed there.
    """
    t_is_unpack = isinstance(t, UnpackType)
    s_is_unpack = isinstance(s, UnpackType)
    if not t_is_unpack and not s_is_unpack:
        return join_types(t, s)
    if t_is_unpack and s_is_unpack:
        return UnpackType(join_types(t.type, s.type))
    # Mixed unpack/non-unpack operands: fall back to a top type.
    return object_or_any_from_type(get_proper_type(t))
def safe_meet(t: Type, s: Type) -> Type:
    """meet_types() wrapper that tolerates UnpackType operands (see safe_join above)."""
    # Similar to above but for meet_types().
    from mypy.meet import meet_types

    if not isinstance(t, UnpackType) and not isinstance(s, UnpackType):
        return meet_types(t, s)
    if isinstance(t, UnpackType) and isinstance(s, UnpackType):
        # Determine the tuple fallback TypeInfo of the unpacked operand, so an
        # uninhabited meet can still be wrapped as a variadic tuple below.
        unpacked = get_proper_type(t.type)
        if isinstance(unpacked, TypeVarTupleType):
            fallback_type = unpacked.tuple_fallback.type
        elif isinstance(unpacked, TupleType):
            fallback_type = unpacked.partial_fallback.type
        else:
            assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple"
            fallback_type = unpacked.type
        res = meet_types(t.type, s.type)
        if isinstance(res, UninhabitedType):
            res = Instance(fallback_type, [res])
        return UnpackType(res)
    # Mixed unpack/non-unpack operands have no common values.
    return UninhabitedType()
def combine_similar_callables(t: CallableType, s: CallableType) -> CallableType:
    """Combine two equivalent callables by joining both argument and return types.

    Unlike join_similar_callables(), argument types are joined rather than met;
    this is used when t and s are already known to be equivalent.
    """
    t, s = match_generic_callables(t, s)
    arg_types: list[Type] = [
        safe_join(t_arg, s_arg) for t_arg, s_arg in zip(t.arg_types, s.arg_types)
    ]
    # TODO kinds and argument names
    # TODO what should happen if one fallback is 'type' and the other is a user-provided metaclass?
    # The fallback type can be either 'function', 'type', or some user-provided metaclass.
    # The result should always use 'function' as a fallback if either operands are using it.
    fallback = t.fallback if t.fallback.type.fullname == "builtins.function" else s.fallback
    return t.copy_modified(
        arg_types=arg_types,
        arg_names=combine_arg_names(t, s),
        ret_type=join_types(t.ret_type, s.ret_type),
        fallback=fallback,
        name=None,
    )
def combine_arg_names(
    t: CallableType | Parameters, s: CallableType | Parameters
) -> list[str | None]:
    """Produces a list of argument names compatible with both callables.

    For example, suppose 't' and 's' have the following signatures:

    - t: (a: int, b: str, X: str) -> None
    - s: (a: int, b: str, Y: str) -> None

    This function would return ["a", "b", None]. This information
    is then used above to compute the join of t and s, which results
    in a signature of (a: int, b: str, str) -> None.

    Note that the third argument's name is omitted and 't' and 's'
    are both valid subtypes of this inferred signature.

    Precondition: is_similar_types(t, s) is true.
    """
    new_names: list[str | None] = []
    # Walk both argument lists in lockstep instead of indexing by position.
    for t_name, s_name, t_kind, s_kind in zip(t.arg_names, s.arg_names, t.arg_kinds, s.arg_kinds):
        # Keep the name when both sides agree, or when either side requires the
        # argument to be passed by name (named arguments must keep their name).
        if t_name == s_name or t_kind.is_named() or s_kind.is_named():
            new_names.append(t_name)
        else:
            new_names.append(None)
    return new_names
def object_from_instance(instance: Instance) -> Instance:
    """Construct the type 'builtins.object' from an instance type."""
    # 'object' always sits at the very end of every MRO.
    object_info = instance.type.mro[-1]
    return Instance(object_info, [])
def object_or_any_from_type(typ: ProperType) -> ProperType:
    """Return 'builtins.object' derived from typ when possible, otherwise Any.

    Similar to object_from_instance() but tries hard for all types.
    TODO: find a better way to get object, or make this more reliable.
    """
    if isinstance(typ, Instance):
        return object_from_instance(typ)
    elif isinstance(typ, (CallableType, TypedDictType, LiteralType)):
        return object_from_instance(typ.fallback)
    elif isinstance(typ, TupleType):
        return object_from_instance(typ.partial_fallback)
    elif isinstance(typ, TypeType):
        return object_or_any_from_type(typ.item)
    elif isinstance(typ, TypeVarLikeType) and isinstance(typ.upper_bound, ProperType):
        return object_or_any_from_type(typ.upper_bound)
    elif isinstance(typ, UnionType):
        for item in typ.items:
            if isinstance(item, ProperType):
                candidate = object_or_any_from_type(item)
                if isinstance(candidate, Instance):
                    return candidate
    elif isinstance(typ, UnpackType):
        # Bug fix: the result of this recursive call was previously discarded,
        # making the branch a no-op that always fell through to Any. Mirror the
        # UnionType branch: use the result when it resolves to an Instance.
        candidate = object_or_any_from_type(get_proper_type(typ.type))
        if isinstance(candidate, Instance):
            return candidate
    return AnyType(TypeOfAny.implementation_artifact)
def join_type_list(types: Sequence[Type]) -> Type:
    """Fold join_types() over a sequence of types.

    An empty sequence yields UninhabitedType. This is a little arbitrary but
    reasonable: any empty tuple should be compatible with all variable length
    tuples, and this makes it possible.
    """
    iterator = iter(types)
    try:
        joined = next(iterator)
    except StopIteration:
        return UninhabitedType()
    for current in iterator:
        joined = join_types(joined, current)
    return joined
def unpack_callback_protocol(t: Instance) -> ProperType | None:
    """Return the '__call__' member type of a single-method callback protocol,
    or None if t is not such a protocol."""
    assert t.type.is_protocol
    if t.type.protocol_members != ["__call__"]:
        return None
    return get_proper_type(find_member("__call__", t, t, is_operator=True))

Some files were not shown because too many files have changed in this diff Show more