feat: indie status page MVP -- FastAPI + SQLite

- 8 DB models (services, incidents, monitors, subscribers, etc.)
- Full CRUD API for services, incidents, monitors
- Public status page with live data
- Incident detail page with timeline
- API key authentication
- Uptime monitoring scheduler
- 13 tests passing
- TECHNICAL_DESIGN.md with full spec
This commit is contained in:
IndieStatusBot 2026-04-25 05:00:00 +00:00
commit 902133edd3
4655 changed files with 1342691 additions and 0 deletions

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,439 @@
from __future__ import annotations
from typing import NamedTuple
from mypy.argmap import map_actuals_to_formals
from mypy.fixup import TypeFixer
from mypy.nodes import (
ARG_POS,
MDEF,
SYMBOL_FUNCBASE_TYPES,
Argument,
Block,
CallExpr,
ClassDef,
Decorator,
Expression,
FuncDef,
JsonDict,
NameExpr,
Node,
OverloadedFuncDef,
PassStmt,
RefExpr,
SymbolTableNode,
TypeInfo,
Var,
)
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.semanal_shared import (
ALLOW_INCOMPATIBLE_OVERRIDE,
parse_bool,
require_bool_literal_argument,
set_callable_name,
)
from mypy.typeops import try_getting_str_literals as try_getting_str_literals
from mypy.types import (
AnyType,
CallableType,
Instance,
LiteralType,
NoneType,
Overloaded,
Type,
TypeOfAny,
TypeType,
TypeVarType,
deserialize_type,
get_proper_type,
)
from mypy.types_utils import is_overlapping_none
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
def _get_decorator_bool_argument(ctx: ClassDefContext, name: str, default: bool) -> bool:
    """Return the bool argument for the decorator.

    This handles both @decorator(...) and @decorator.
    """
    reason = ctx.reason
    if not isinstance(reason, CallExpr):
        # Bare @decorator form: no arguments were supplied, use the default.
        return default
    return _get_bool_argument(ctx, reason, name, default)
def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr, name: str, default: bool) -> bool:
    """Return the boolean value for an argument to a call or the
    default if it's not found.
    """
    value_expr = _get_argument(expr, name)
    if value_expr is None:
        return default
    return require_bool_literal_argument(ctx.api, value_expr, name, default)
def _get_argument(call: CallExpr, name: str) -> Expression | None:
    """Return the expression for the specific argument."""
    # Use the CallableType of the callee to find the FormalArgument, then walk
    # the actual CallExpr looking for a matching actual.  The index is not
    # hard-coded so that other attrib and class makers can be supported later.
    callee_type = _get_callee_type(call)
    if callee_type is None:
        return None
    formal = callee_type.argument_by_name(name)
    if formal is None:
        return None
    assert formal.name
    for idx, (actual_name, actual_value) in enumerate(zip(call.arg_names, call.args)):
        # Positional actual landing on the formal's position, or an explicit
        # keyword actual with the formal's name.
        if formal.pos is not None and not actual_name and idx == formal.pos:
            return actual_value
        if actual_name == formal.name:
            return actual_value
    return None
def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) -> CallableType:
    """Perform limited lookup of a matching overload item.

    Full overload resolution is only supported during type checking, but plugins
    sometimes need to resolve overloads. This can be used in some such use cases.

    Resolve overloads based on these things only:

    * Match using argument kinds and names
    * If formal argument has type None, only accept the "None" expression in the callee
    * If formal argument has type Literal[True] or Literal[False], only accept the
      relevant bool literal

    Return the first matching overload item, or the last one if nothing matches.
    """
    for item in overload.items[:-1]:
        ok = True
        # Map each actual argument of the call onto the formal parameters of
        # this overload item (by kind and name only; no type information).
        mapped = map_actuals_to_formals(
            call.arg_kinds,
            call.arg_names,
            item.arg_kinds,
            item.arg_names,
            lambda i: AnyType(TypeOfAny.special_form),
        )

        # Look for extra actuals
        matched_actuals = set()
        for actuals in mapped:
            matched_actuals.update(actuals)
        if any(i not in matched_actuals for i in range(len(call.args))):
            ok = False

        for arg_type, kind, actuals in zip(item.arg_types, item.arg_kinds, mapped):
            if kind.is_required() and not actuals:
                # Missing required argument
                ok = False
                break
            elif actuals:
                # One or more actuals were mapped to this formal; apply the
                # shallow expression-level checks described in the docstring.
                args = [call.args[i] for i in actuals]
                arg_type = get_proper_type(arg_type)
                arg_none = any(isinstance(arg, NameExpr) and arg.name == "None" for arg in args)
                if isinstance(arg_type, NoneType):
                    # Formal is exactly None: only a literal "None" actual matches.
                    if not arg_none:
                        ok = False
                        break
                elif (
                    # A "None" actual only matches a formal that can hold None
                    # (optional/object/Any).
                    arg_none
                    and not is_overlapping_none(arg_type)
                    and not (
                        isinstance(arg_type, Instance)
                        and arg_type.type.fullname == "builtins.object"
                    )
                    and not isinstance(arg_type, AnyType)
                ):
                    ok = False
                    break
                elif isinstance(arg_type, LiteralType) and isinstance(arg_type.value, bool):
                    # Formal is Literal[True]/Literal[False]: require the
                    # matching bool literal among the actuals.
                    if not any(parse_bool(arg) == arg_type.value for arg in args):
                        ok = False
                        break
        if ok:
            return item
    # Fall back to the last item when nothing matched.
    return overload.items[-1]
def _get_callee_type(call: CallExpr) -> CallableType | None:
    """Return the type of the callee, regardless of its syntactic form."""
    target: Node | None = call.callee
    if isinstance(target, RefExpr):
        target = target.node
    # Some decorators may be using typing.dataclass_transform, which is itself
    # a decorator, so unwrap to get at the true callee.
    if isinstance(target, Decorator):
        target = target.func
    if not isinstance(target, (Var, SYMBOL_FUNCBASE_TYPES)) or not target.type:
        return None
    proper = get_proper_type(target.type)
    if isinstance(proper, Overloaded):
        return find_shallow_matching_overload_item(proper, call)
    if isinstance(proper, CallableType):
        return proper
    return None
def add_method(
    ctx: ClassDefContext,
    name: str,
    args: list[Argument],
    return_type: Type,
    self_type: Type | None = None,
    tvar_def: TypeVarType | None = None,
    is_classmethod: bool = False,
    is_staticmethod: bool = False,
) -> None:
    """
    Adds a new method to a class.

    Deprecated, use add_method_to_class() instead.
    """
    # Thin compatibility shim: forward everything to the new entry point.
    add_method_to_class(
        ctx.api,
        ctx.cls,
        name,
        args,
        return_type,
        self_type=self_type,
        tvar_def=tvar_def,
        is_classmethod=is_classmethod,
        is_staticmethod=is_staticmethod,
    )
class MethodSpec(NamedTuple):
    """Represents a method signature to be added, except for `name`."""

    # Explicit arguments (the implicit self/cls argument is added separately).
    args: list[Argument]
    # Declared return type of the method.
    return_type: Type
    # Explicit type of the self/cls argument, if any; a default is filled in
    # from the class's type variables otherwise.
    self_type: Type | None = None
    # Type variables the method signature is generic over, if any.
    tvar_defs: list[TypeVarType] | None = None
def add_method_to_class(
    api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
    cls: ClassDef,
    name: str,
    # MethodSpec items kept for backward compatibility:
    args: list[Argument],
    return_type: Type,
    self_type: Type | None = None,
    tvar_def: list[TypeVarType] | TypeVarType | None = None,
    is_classmethod: bool = False,
    is_staticmethod: bool = False,
) -> FuncDef | Decorator:
    """Adds a new method to a class definition."""
    _prepare_class_namespace(cls, name)

    # Normalize a single type variable into the list form MethodSpec expects.
    tvar_defs: list[TypeVarType] | None
    if tvar_def is None or isinstance(tvar_def, list):
        tvar_defs = tvar_def
    else:
        tvar_defs = [tvar_def]

    spec = MethodSpec(args=args, return_type=return_type, self_type=self_type, tvar_defs=tvar_defs)
    func, sym = _add_method_by_spec(
        api,
        cls.info,
        name,
        spec,
        is_classmethod=is_classmethod,
        is_staticmethod=is_staticmethod,
    )
    # Register the generated node both in the symbol table and in the class body.
    cls.info.names[name] = sym
    cls.info.defn.defs.body.append(func)
    return func
def add_overloaded_method_to_class(
    api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
    cls: ClassDef,
    name: str,
    items: list[MethodSpec],
    is_classmethod: bool = False,
    is_staticmethod: bool = False,
) -> OverloadedFuncDef:
    """Adds a new overloaded method to a class definition.

    Each entry in `items` describes one overload variant; at least two are
    required.  Returns the generated OverloadedFuncDef node.
    """
    assert len(items) >= 2, "Overloads must contain at least two cases"

    # Save old definition, if it exists.
    _prepare_class_namespace(cls, name)

    # Create function bodies for each passed method spec.
    funcs: list[Decorator | FuncDef] = []
    for item in items:
        func, _sym = _add_method_by_spec(
            api,
            cls.info,
            name=name,
            spec=item,
            is_classmethod=is_classmethod,
            is_staticmethod=is_staticmethod,
        )
        if isinstance(func, FuncDef):
            # Wrap plain functions in a Decorator node so every overload item
            # has the same node shape (mirrors how @overload items look).
            var = Var(func.name, func.type)
            var.set_line(func.line)
            func.is_decorated = True
            deco = Decorator(func, [], var)
        else:
            deco = func
        deco.is_overload = True
        funcs.append(deco)

    # Create the final OverloadedFuncDef node:
    overload_def = OverloadedFuncDef(funcs)
    overload_def.info = cls.info
    overload_def.is_class = is_classmethod
    overload_def.is_static = is_staticmethod
    sym = SymbolTableNode(MDEF, overload_def)
    sym.plugin_generated = True

    # Register the overload both in the symbol table and in the class body.
    cls.info.names[name] = sym
    cls.info.defn.defs.body.append(overload_def)
    return overload_def
def _prepare_class_namespace(cls: ClassDef, name: str) -> None:
    """Make room in the class namespace for a plugin-generated `name`."""
    info = cls.info
    assert info

    # First remove any previously plugin-generated method with the same name
    # from the class body, to avoid clashes and problems in the semantic
    # analyzer.
    existing = info.names.get(name)
    if existing is not None and existing.plugin_generated and isinstance(existing.node, FuncDef):
        cls.defs.body.remove(existing.node)

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed,
    # so stash them under a unique name.
    if name in info.names:
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]
def _add_method_by_spec(
    api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
    info: TypeInfo,
    name: str,
    spec: MethodSpec,
    *,
    is_classmethod: bool,
    is_staticmethod: bool,
) -> tuple[FuncDef | Decorator, SymbolTableNode]:
    """Build the AST node and symbol-table entry for one method given by `spec`.

    Returns the generated node (a plain FuncDef, or a Decorator wrapping it for
    static methods) plus its SymbolTableNode.  The caller is responsible for
    inserting both into the class.
    """
    args, return_type, self_type, tvar_defs = spec

    assert not (
        is_classmethod is True and is_staticmethod is True
    ), "Can't add a new method that's both staticmethod and classmethod."

    # The two plugin interfaces expose different lookup helpers for the
    # fallback `builtins.function` type.
    if isinstance(api, SemanticAnalyzerPluginInterface):
        function_type = api.named_type("builtins.function")
    else:
        function_type = api.named_generic_type("builtins.function", [])

    # Synthesize the implicit first argument (if any).
    if is_classmethod:
        self_type = self_type or TypeType(fill_typevars(info))
        first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)]
    elif is_staticmethod:
        first = []
    else:
        self_type = self_type or fill_typevars(info)
        first = [Argument(Var("self"), self_type, None, ARG_POS)]
    args = first + args

    # Flatten the Argument objects into the parallel lists CallableType wants.
    arg_types, arg_names, arg_kinds = [], [], []
    for arg in args:
        assert arg.type_annotation, "All arguments must be fully typed."
        arg_types.append(arg.type_annotation)
        arg_names.append(arg.variable.name)
        arg_kinds.append(arg.kind)

    signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
    if tvar_defs:
        signature.variables = tuple(tvar_defs)

    # The body is a single `pass`; only the signature matters for plugins.
    func = FuncDef(name, args, Block([PassStmt()]))
    func.info = info
    func.type = set_callable_name(signature, func)
    func.is_class = is_classmethod
    func.is_static = is_staticmethod
    func._fullname = info.fullname + "." + name
    func.line = info.line

    # Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
    if is_staticmethod:
        func.is_decorated = True
        v = Var(name, func.type)
        v.info = info
        v._fullname = func._fullname
        v.is_staticmethod = True
        dec = Decorator(func, [], v)
        dec.line = info.line
        sym = SymbolTableNode(MDEF, dec)
        sym.plugin_generated = True
        return dec, sym

    sym = SymbolTableNode(MDEF, func)
    sym.plugin_generated = True
    return func, sym
def add_attribute_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    name: str,
    typ: Type,
    final: bool = False,
    no_serialize: bool = False,
    override_allow_incompatible: bool = False,
    fullname: str | None = None,
    is_classvar: bool = False,
    overwrite_existing: bool = False,
) -> Var:
    """
    Adds a new attribute to a class definition.
    This currently only generates the symbol table entry and no corresponding AssignmentStatement
    """
    info = cls.info

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed;
    # stash them under a fresh unique name.
    if not overwrite_existing and name in info.names:
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]

    node = Var(name, typ)
    node.info = info
    node.is_final = final
    node.is_classvar = is_classvar
    # Some well-known attribute names are always allowed to override
    # incompatibly; otherwise defer to the caller's flag.
    node.allow_incompatible_override = (
        True if name in ALLOW_INCOMPATIBLE_OVERRIDE else override_allow_incompatible
    )
    node._fullname = fullname if fullname else info.fullname + "." + name

    info.names[name] = SymbolTableNode(
        MDEF, node, plugin_generated=True, no_serialize=no_serialize
    )
    return node
def deserialize_and_fixup_type(data: str | JsonDict, api: SemanticAnalyzerPluginInterface) -> Type:
    """Deserialize a type from its JSON form and resolve its cross-references."""
    fixer = TypeFixer(api.modules, allow_missing=False)
    typ = deserialize_type(data)
    typ.accept(fixer)
    return typ

View file

@ -0,0 +1,20 @@
"""Constant definitions for plugins kept here to help with import cycles."""
from typing import Final
from mypy.semanal_enum import ENUM_BASES
SINGLEDISPATCH_TYPE: Final = "functools._SingleDispatchCallable"
SINGLEDISPATCH_REGISTER_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.register"
SINGLEDISPATCH_CALLABLE_CALL_METHOD: Final = f"{SINGLEDISPATCH_TYPE}.__call__"
SINGLEDISPATCH_REGISTER_RETURN_CLASS: Final = "_SingleDispatchRegisterCallable"
SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD: Final = (
f"functools.{SINGLEDISPATCH_REGISTER_RETURN_CLASS}.__call__"
)
ENUM_NAME_ACCESS: Final = {f"{prefix}.name" for prefix in ENUM_BASES} | {
f"{prefix}._name_" for prefix in ENUM_BASES
}
ENUM_VALUE_ACCESS: Final = {f"{prefix}.value" for prefix in ENUM_BASES} | {
f"{prefix}._value_" for prefix in ENUM_BASES
}

View file

@ -0,0 +1,245 @@
"""Plugin to provide accurate types for some parts of the ctypes module."""
from __future__ import annotations
# Fully qualified instead of "from mypy.plugin import ..." to avoid circular import problems.
import mypy.plugin
from mypy import nodes
from mypy.maptype import map_instance_to_supertype
from mypy.messages import format_type
from mypy.subtypes import is_subtype
from mypy.typeops import make_simplified_union
from mypy.types import (
AnyType,
CallableType,
Instance,
NoneType,
ProperType,
Type,
TypeOfAny,
UnionType,
flatten_nested_unions,
get_proper_type,
)
def _find_simplecdata_base_arg(
    tp: Instance, api: mypy.plugin.CheckerPluginInterface
) -> ProperType | None:
    """Try to find a parametrized _SimpleCData in tp's bases and return its single type argument.

    None is returned if _SimpleCData appears nowhere in tp's (direct or indirect) bases.
    """
    if not tp.type.has_base("_ctypes._SimpleCData"):
        return None
    placeholder = AnyType(TypeOfAny.special_form)
    base_info = api.named_generic_type("_ctypes._SimpleCData", [placeholder]).type
    simplecdata_base = map_instance_to_supertype(tp, base_info)
    assert len(simplecdata_base.args) == 1, "_SimpleCData takes exactly one type argument"
    return get_proper_type(simplecdata_base.args[0])
def _autoconvertible_to_cdata(tp: Type, api: mypy.plugin.CheckerPluginInterface) -> Type:
    """Get a type that is compatible with all types that can be implicitly converted to the given
    CData type.

    Examples:
    * c_int -> Union[c_int, int]
    * c_char_p -> Union[c_char_p, bytes, int, NoneType]
    * MyStructure -> MyStructure
    """
    # If tp is a union, we allow all types that are convertible to at least one of the union
    # items. This is not quite correct - strictly speaking, only types convertible to *all* of the
    # union items should be allowed. This may be worth changing in the future, but the more
    # correct algorithm could be too strict to be useful.
    allowed_types = []
    for member in flatten_nested_unions([tp]):
        member = get_proper_type(member)
        # Every type can be converted from itself (obviously).
        allowed_types.append(member)
        if not isinstance(member, Instance):
            continue
        unboxed = _find_simplecdata_base_arg(member, api)
        if unboxed is not None:
            # If _SimpleCData appears in the bases, its type argument is the
            # "unboxed" version, which can always be converted back to the
            # original "boxed" type.
            allowed_types.append(unboxed)
        if member.type.has_base("ctypes._PointerLike"):
            # Pointer-like _SimpleCData subclasses can also be converted from
            # an int or None.
            allowed_types.append(api.named_generic_type("builtins.int", []))
            allowed_types.append(NoneType())
    return make_simplified_union(allowed_types)
def _autounboxed_cdata(tp: Type) -> ProperType:
    """Get the auto-unboxed version of a CData type, if applicable.

    For *direct* _SimpleCData subclasses, the only type argument of _SimpleCData in the bases list
    is returned.
    For all other CData types, including indirect _SimpleCData subclasses, tp is returned as-is.
    """
    proper = get_proper_type(tp)
    if isinstance(proper, UnionType):
        # Unbox each union item individually.
        return make_simplified_union([_autounboxed_cdata(item) for item in proper.items])
    if isinstance(proper, Instance):
        for base in proper.type.bases:
            if base.type.fullname != "_ctypes._SimpleCData":
                continue
            # Direct _SimpleCData base: the auto-unboxed type is its single
            # type argument.
            assert len(base.args) == 1
            return get_proper_type(base.args[0])
    # Not a concrete type, or no direct _SimpleCData base: not auto-unboxed.
    return proper
def _get_array_element_type(tp: Type) -> ProperType | None:
    """Get the element type of the Array type tp, or None if not specified."""
    proper = get_proper_type(tp)
    if not isinstance(proper, Instance):
        return None
    assert proper.type.fullname == "_ctypes.Array"
    if len(proper.args) != 1:
        return None
    return get_proper_type(proper.args[0])
def array_constructor_callback(ctx: mypy.plugin.FunctionContext) -> Type:
    """Callback to provide an accurate signature for the ctypes.Array constructor."""
    # Extract the element type from the constructor's return type, i. e. the type of the array
    # being constructed.
    et = _get_array_element_type(ctx.default_return_type)
    if et is not None:
        # Anything implicitly convertible to the element type is an acceptable
        # constructor argument.
        allowed = _autoconvertible_to_cdata(et, ctx.api)
        assert (
            len(ctx.arg_types) == 1
        ), "The stub of the ctypes.Array constructor should have a single vararg parameter"
        # Check each actual (1-based numbering for error messages).
        for arg_num, (arg_kind, arg_type) in enumerate(zip(ctx.arg_kinds[0], ctx.arg_types[0]), 1):
            if arg_kind == nodes.ARG_POS and not is_subtype(arg_type, allowed):
                ctx.api.msg.fail(
                    "Array constructor argument {} of type {}"
                    " is not convertible to the array element type {}".format(
                        arg_num,
                        format_type(arg_type, ctx.api.options),
                        format_type(et, ctx.api.options),
                    ),
                    ctx.context,
                )
            elif arg_kind == nodes.ARG_STAR:
                # A *args actual must be an iterable of convertible values; the
                # error message shows Iterable[et] rather than the internal
                # "allowed" union for readability.
                ty = ctx.api.named_generic_type("typing.Iterable", [allowed])
                if not is_subtype(arg_type, ty):
                    it = ctx.api.named_generic_type("typing.Iterable", [et])
                    ctx.api.msg.fail(
                        "Array constructor argument {} of type {}"
                        " is not convertible to the array element type {}".format(
                            arg_num,
                            format_type(arg_type, ctx.api.options),
                            format_type(it, ctx.api.options),
                        ),
                        ctx.context,
                    )
    return ctx.default_return_type
def array_getitem_callback(ctx: mypy.plugin.MethodContext) -> Type:
    """Callback to provide an accurate return type for ctypes.Array.__getitem__."""
    et = _get_array_element_type(ctx.type)
    if et is None:
        return ctx.default_return_type
    unboxed = _autounboxed_cdata(et)
    assert (
        len(ctx.arg_types) == 1
    ), "The stub of ctypes.Array.__getitem__ should have exactly one parameter"
    assert (
        len(ctx.arg_types[0]) == 1
    ), "ctypes.Array.__getitem__'s parameter should not be variadic"
    index_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(index_type, Instance):
        if index_type.type.has_base("builtins.int"):
            # Indexing with an int yields a single auto-unboxed element.
            return unboxed
        if index_type.type.has_base("builtins.slice"):
            # Indexing with a slice yields a list of auto-unboxed elements.
            return ctx.api.named_generic_type("builtins.list", [unboxed])
    return ctx.default_return_type
def array_setitem_callback(ctx: mypy.plugin.MethodSigContext) -> CallableType:
    """Callback to provide an accurate signature for ctypes.Array.__setitem__."""
    et = _get_array_element_type(ctx.type)
    if et is not None:
        # Accept any value implicitly convertible to the element type.
        allowed = _autoconvertible_to_cdata(et, ctx.api)
        assert len(ctx.default_signature.arg_types) == 2
        index_type = get_proper_type(ctx.default_signature.arg_types[0])
        if isinstance(index_type, Instance):
            arg_type = None
            if index_type.type.has_base("builtins.int"):
                # a[i] = value takes a single convertible value.
                arg_type = allowed
            elif index_type.type.has_base("builtins.slice"):
                # a[i:j] = values takes a list of convertible values.
                arg_type = ctx.api.named_generic_type("builtins.list", [allowed])
            if arg_type is not None:
                # Note: arg_type can only be None if index_type is invalid, in which case we use
                # the default signature and let mypy report an error about it.
                return ctx.default_signature.copy_modified(
                    arg_types=ctx.default_signature.arg_types[:1] + [arg_type]
                )
    return ctx.default_signature
def array_iter_callback(ctx: mypy.plugin.MethodContext) -> Type:
    """Callback to provide an accurate return type for ctypes.Array.__iter__."""
    et = _get_array_element_type(ctx.type)
    if et is None:
        return ctx.default_return_type
    # Iteration yields auto-unboxed elements.
    return ctx.api.named_generic_type("typing.Iterator", [_autounboxed_cdata(et)])
def array_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
    """Callback to provide an accurate type for ctypes.Array.value."""
    et = _get_array_element_type(ctx.type)
    if et is not None:
        types: list[Type] = []
        # Handle each union member of the element type individually; only
        # c_char (-> bytes) and c_wchar (-> str) support .value.
        for tp in flatten_nested_unions([et]):
            tp = get_proper_type(tp)
            if isinstance(tp, AnyType):
                types.append(AnyType(TypeOfAny.from_another_any, source_any=tp))
            elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_char":
                types.append(ctx.api.named_generic_type("builtins.bytes", []))
            elif isinstance(tp, Instance) and tp.type.fullname == "ctypes.c_wchar":
                types.append(ctx.api.named_generic_type("builtins.str", []))
            else:
                ctx.api.msg.fail(
                    'Array attribute "value" is only available'
                    ' with element type "c_char" or "c_wchar", not {}'.format(
                        format_type(et, ctx.api.options)
                    ),
                    ctx.context,
                )
        return make_simplified_union(types)
    return ctx.default_attr_type
def array_raw_callback(ctx: mypy.plugin.AttributeContext) -> Type:
    """Callback to provide an accurate type for ctypes.Array.raw."""
    et = _get_array_element_type(ctx.type)
    if et is None:
        return ctx.default_attr_type
    types: list[Type] = []
    # Handle each union member of the element type individually; only c_char
    # (and Any) arrays support .raw.
    for member in flatten_nested_unions([et]):
        member = get_proper_type(member)
        is_any = isinstance(member, AnyType)
        is_c_char = isinstance(member, Instance) and member.type.fullname == "ctypes.c_char"
        if is_any or is_c_char:
            types.append(ctx.api.named_generic_type("builtins.bytes", []))
        else:
            ctx.api.msg.fail(
                'Array attribute "raw" is only available'
                ' with element type "c_char", not {}'.format(format_type(et, ctx.api.options)),
                ctx.context,
            )
    return make_simplified_union(types)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,612 @@
from __future__ import annotations
from collections.abc import Callable
from functools import partial
from typing import Final
import mypy.errorcodes as codes
from mypy import message_registry
from mypy.nodes import DictExpr, Expression, IntExpr, StrExpr, UnaryExpr
from mypy.plugin import (
AttributeContext,
ClassDefContext,
FunctionContext,
FunctionSigContext,
MethodContext,
MethodSigContext,
Plugin,
)
from mypy.plugins.attrs import (
attr_class_maker_callback,
attr_class_makers,
attr_dataclass_makers,
attr_define_makers,
attr_frozen_makers,
attr_tag_callback,
evolve_function_sig_callback,
fields_function_sig_callback,
)
from mypy.plugins.common import try_getting_str_literals
from mypy.plugins.constants import (
ENUM_NAME_ACCESS,
ENUM_VALUE_ACCESS,
SINGLEDISPATCH_CALLABLE_CALL_METHOD,
SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD,
SINGLEDISPATCH_REGISTER_METHOD,
)
from mypy.plugins.ctypes import (
array_constructor_callback,
array_getitem_callback,
array_iter_callback,
array_raw_callback,
array_setitem_callback,
array_value_callback,
)
from mypy.plugins.dataclasses import (
dataclass_class_maker_callback,
dataclass_makers,
dataclass_tag_callback,
replace_function_sig_callback,
)
from mypy.plugins.enums import enum_member_callback, enum_name_callback, enum_value_callback
from mypy.plugins.functools import (
functools_total_ordering_maker_callback,
functools_total_ordering_makers,
partial_call_callback,
partial_new_callback,
)
from mypy.plugins.singledispatch import (
call_singledispatch_function_after_register_argument,
call_singledispatch_function_callback,
create_singledispatch_function_callback,
singledispatch_register_callback,
)
from mypy.subtypes import is_subtype
from mypy.typeops import is_literal_type_like, make_simplified_union
from mypy.types import (
TPDICT_FB_NAMES,
AnyType,
CallableType,
FunctionLike,
Instance,
LiteralType,
NoneType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
TypeVarType,
UnionType,
get_proper_type,
get_proper_types,
)
TD_SETDEFAULT_NAMES: Final = {n + ".setdefault" for n in TPDICT_FB_NAMES}
TD_POP_NAMES: Final = {n + ".pop" for n in TPDICT_FB_NAMES}
TD_DELITEM_NAMES: Final = {n + ".__delitem__" for n in TPDICT_FB_NAMES}
TD_UPDATE_METHOD_NAMES: Final = (
{n + ".update" for n in TPDICT_FB_NAMES}
| {n + ".__or__" for n in TPDICT_FB_NAMES}
| {n + ".__ror__" for n in TPDICT_FB_NAMES}
| {n + ".__ior__" for n in TPDICT_FB_NAMES}
)
class DefaultPlugin(Plugin):
    """Type checker plugin that is enabled by default."""

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Return a hook adjusting the return type of calls to special functions."""
        if fullname == "_ctypes.Array":
            return array_constructor_callback
        elif fullname == "functools.singledispatch":
            return create_singledispatch_function_callback
        elif fullname == "functools.partial":
            return partial_new_callback
        elif fullname == "enum.member":
            return enum_member_callback
        elif fullname == "builtins.len":
            return len_callback
        return None

    def get_function_signature_hook(
        self, fullname: str
    ) -> Callable[[FunctionSigContext], FunctionLike] | None:
        """Return a hook adjusting the signature of special function calls."""
        if fullname in ("attr.evolve", "attrs.evolve", "attr.assoc", "attrs.assoc"):
            return evolve_function_sig_callback
        elif fullname in ("attr.fields", "attrs.fields"):
            return fields_function_sig_callback
        elif fullname == "dataclasses.replace":
            return replace_function_sig_callback
        return None

    def get_method_signature_hook(
        self, fullname: str
    ) -> Callable[[MethodSigContext], FunctionLike] | None:
        """Return a hook adjusting the signature of special method calls."""
        if fullname == "typing.Mapping.get":
            return typed_dict_get_signature_callback
        elif fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_signature_callback
        elif fullname in TD_POP_NAMES:
            return typed_dict_pop_signature_callback
        elif fullname == "_ctypes.Array.__setitem__":
            return array_setitem_callback
        elif fullname == SINGLEDISPATCH_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_callback
        elif fullname in TD_UPDATE_METHOD_NAMES:
            return typed_dict_update_signature_callback
        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Return a hook adjusting the return type of special method calls."""
        if fullname == "typing.Mapping.get":
            return typed_dict_get_callback
        elif fullname == "builtins.int.__pow__":
            return int_pow_callback
        elif fullname == "builtins.int.__neg__":
            return int_neg_callback
        elif fullname == "builtins.int.__pos__":
            return int_pos_callback
        elif fullname in ("builtins.tuple.__mul__", "builtins.tuple.__rmul__"):
            return tuple_mul_callback
        elif fullname in TD_SETDEFAULT_NAMES:
            return typed_dict_setdefault_callback
        elif fullname in TD_POP_NAMES:
            return typed_dict_pop_callback
        elif fullname in TD_DELITEM_NAMES:
            return typed_dict_delitem_callback
        elif fullname == "_ctypes.Array.__getitem__":
            return array_getitem_callback
        elif fullname == "_ctypes.Array.__iter__":
            return array_iter_callback
        elif fullname == SINGLEDISPATCH_REGISTER_METHOD:
            return singledispatch_register_callback
        elif fullname == SINGLEDISPATCH_REGISTER_CALLABLE_CALL_METHOD:
            return call_singledispatch_function_after_register_argument
        elif fullname == "functools.partial.__call__":
            return partial_call_callback
        return None

    def get_attribute_hook(self, fullname: str) -> Callable[[AttributeContext], Type] | None:
        """Return a hook adjusting the type of special attribute accesses."""
        if fullname == "_ctypes.Array.value":
            return array_value_callback
        elif fullname == "_ctypes.Array.raw":
            return array_raw_callback
        elif fullname in ENUM_NAME_ACCESS:
            return enum_name_callback
        elif fullname in ENUM_VALUE_ACCESS:
            return enum_value_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Return a first-pass hook for dataclass/attrs class decorators."""
        # These dataclass and attrs hooks run in the main semantic analysis pass
        # and only tag known dataclasses/attrs classes, so that the second
        # hooks (in get_class_decorator_hook_2) can detect dataclasses/attrs classes
        # in the MRO.
        if fullname in dataclass_makers:
            return dataclass_tag_callback
        if (
            fullname in attr_class_makers
            or fullname in attr_dataclass_makers
            or fullname in attr_frozen_makers
            or fullname in attr_define_makers
        ):
            return attr_tag_callback
        return None

    def get_class_decorator_hook_2(
        self, fullname: str
    ) -> Callable[[ClassDefContext], bool] | None:
        """Return the second-pass hook that actually transforms decorated classes."""
        if fullname in dataclass_makers:
            return dataclass_class_maker_callback
        elif fullname in functools_total_ordering_makers:
            return functools_total_ordering_maker_callback
        elif fullname in attr_class_makers:
            return attr_class_maker_callback
        elif fullname in attr_dataclass_makers:
            # attr.dataclass implies auto_attribs=True.
            return partial(attr_class_maker_callback, auto_attribs_default=True)
        elif fullname in attr_frozen_makers:
            # attr.frozen/attrs.frozen imply frozen=True.
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, frozen_default=True
            )
        elif fullname in attr_define_makers:
            # attr.define/attrs.define imply slots=True.
            return partial(
                attr_class_maker_callback, auto_attribs_default=None, slots_default=True
            )
        return None
def len_callback(ctx: FunctionContext) -> Type:
    """Infer a better return type for 'len'."""
    if len(ctx.arg_types) != 1 or len(ctx.arg_types[0]) != 1:
        return ctx.default_return_type
    arg_type = get_proper_type(ctx.arg_types[0][0])
    if isinstance(arg_type, Instance) and arg_type.type.fullname == "librt.vecs.vec":
        # The length of vec is a fixed-width integer, for more
        # low-level optimization potential.
        return ctx.api.named_generic_type("mypy_extensions.i64", [])
    return ctx.default_return_type
def typed_dict_get_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.get.

    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    # Only specialize the two-argument form called on a TypedDict with a
    # literal string key, when the default signature has the expected shape
    # (two parameters and exactly one type variable).
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = get_proper_type(ctx.type.items.get(key))
        ret_type = signature.ret_type
        if value_type:
            default_arg = ctx.args[1][0]
            if (
                isinstance(value_type, TypedDictType)
                and isinstance(default_arg, DictExpr)
                and len(default_arg.items) == 0
            ):
                # Caller has empty dict {} as default for typed dict.
                value_type = value_type.copy_modified(required_keys=set())
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            tv = signature.variables[0]
            assert isinstance(tv, TypeVarType)
            return signature.copy_modified(
                arg_types=[signature.arg_types[0], make_simplified_union([value_type, tv])],
                ret_type=ret_type,
            )
    return signature
def typed_dict_get_callback(ctx: MethodContext) -> Type:
    """Infer a precise return type for TypedDict.get with literal first argument."""
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        # The key argument may be a single literal or a union of literals.
        keys = try_getting_str_literals(ctx.args[0][0], ctx.arg_types[0][0])
        if keys is None:
            return ctx.default_return_type

        # Determine the default (second argument), if any; an absent default
        # means .get() returns None for missing keys.
        default_type: Type
        default_arg: Expression | None
        if len(ctx.arg_types) <= 1 or not ctx.arg_types[1]:
            default_arg = None
            default_type = NoneType()
        elif len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            default_arg = ctx.args[1][0]
            default_type = ctx.arg_types[1][0]
        else:
            # Unexpected argument shape; fall back to the default return type.
            return ctx.default_return_type

        output_types: list[Type] = []
        for key in keys:
            value_type: Type | None = ctx.type.items.get(key)
            if value_type is None:
                return ctx.default_return_type
            if key in ctx.type.required_keys:
                # Required keys are always present, so the default can't apply.
                output_types.append(value_type)
            else:
                # HACK to deal with get(key, {})
                if (
                    isinstance(default_arg, DictExpr)
                    and len(default_arg.items) == 0
                    and isinstance(vt := get_proper_type(value_type), TypedDictType)
                ):
                    output_types.append(vt.copy_modified(required_keys=set()))
                else:
                    output_types.append(value_type)
                output_types.append(default_type)

        # for nicer reveal_type, put default at the end, if it is present
        if default_type in output_types:
            output_types = [t for t in output_types if t != default_type] + [default_type]
        return make_simplified_union(output_types)
    return ctx.default_return_type
def typed_dict_pop_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.pop.
    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(signature.variables) == 1
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            # Tweak the signature to include the value type as context. It's
            # only needed for type inference since there's a union with a type
            # variable that accepts everything.
            tv = signature.variables[0]
            assert isinstance(tv, TypeVarType)
            typ = make_simplified_union([value_type, tv])
            return signature.copy_modified(arg_types=[str_type, typ], ret_type=typ)
    # Even without a known value type, pin the key argument to str.
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_pop_callback(ctx: MethodContext) -> Type:
    """Type check and infer a precise return type for TypedDict.pop.

    Reports an error (and returns Any) for non-literal keys and unknown keys;
    also flags pops of required or read-only keys.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) >= 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            # The key must be a string literal for precise checking.
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)
        value_types = []
        for key in keys:
            # Required and read-only keys may never be removed.
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)
            value_type = ctx.type.items.get(key)
            if value_type:
                value_types.append(value_type)
            else:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)
        if len(ctx.args[1]) == 0:
            # No default argument: result is just the union of value types.
            return make_simplified_union(value_types)
        elif len(ctx.arg_types) == 2 and len(ctx.arg_types[1]) == 1 and len(ctx.args[1]) == 1:
            # With a default, the default's type joins the union.
            return make_simplified_union([*value_types, ctx.arg_types[1][0]])
    return ctx.default_return_type
def typed_dict_setdefault_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for TypedDict.setdefault.
    This is used to get better type context for the second argument that
    depends on a TypedDict value type.
    """
    signature = ctx.default_signature
    str_type = ctx.api.named_generic_type("builtins.str", [])
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.args) == 2
        and len(ctx.args[0]) == 1
        and isinstance(ctx.args[0][0], StrExpr)
        and len(signature.arg_types) == 2
        and len(ctx.args[1]) == 1
    ):
        key = ctx.args[0][0].value
        value_type = ctx.type.items.get(key)
        if value_type:
            # Constrain the default argument to the key's value type.
            return signature.copy_modified(arg_types=[str_type, value_type])
    # Even without a known value type, pin the key argument to str.
    return signature.copy_modified(arg_types=[str_type, signature.arg_types[1]])
def typed_dict_setdefault_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.setdefault and infer a precise return type.

    Requires a literal string key; verifies the default is assignable to every
    candidate key's value type and flags writes to read-only keys.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 2
        and len(ctx.arg_types[0]) == 1
        and len(ctx.arg_types[1]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)
        # setdefault may write, so read-only keys cannot be targets.
        assigned_readonly_keys = ctx.type.readonly_keys & set(keys)
        if assigned_readonly_keys:
            ctx.api.msg.readonly_keys_mutated(assigned_readonly_keys, context=key_expr)
        default_type = ctx.arg_types[1][0]
        default_expr = ctx.args[1][0]
        value_types = []
        for key in keys:
            value_type = ctx.type.items.get(key)
            if value_type is None:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
                return AnyType(TypeOfAny.from_error)
            # The signature_callback above can't always infer the right signature
            # (e.g. when the expression is a variable that happens to be a Literal str)
            # so we need to handle the check ourselves here and make sure the provided
            # default can be assigned to all key-value pairs we're updating.
            if not is_subtype(default_type, value_type):
                ctx.api.msg.typeddict_setdefault_arguments_inconsistent(
                    default_type, value_type, default_expr
                )
                return AnyType(TypeOfAny.from_error)
            value_types.append(value_type)
        return make_simplified_union(value_types)
    return ctx.default_return_type
def typed_dict_delitem_callback(ctx: MethodContext) -> Type:
    """Type check TypedDict.__delitem__.

    Requires a literal string key and rejects deletion of required, read-only,
    or unknown keys. The return type is never refined.
    """
    if (
        isinstance(ctx.type, TypedDictType)
        and len(ctx.arg_types) == 1
        and len(ctx.arg_types[0]) == 1
    ):
        key_expr = ctx.args[0][0]
        keys = try_getting_str_literals(key_expr, ctx.arg_types[0][0])
        if keys is None:
            ctx.api.fail(
                message_registry.TYPEDDICT_KEY_MUST_BE_STRING_LITERAL,
                key_expr,
                code=codes.LITERAL_REQ,
            )
            return AnyType(TypeOfAny.from_error)
        for key in keys:
            # Required and read-only keys must always remain present.
            if key in ctx.type.required_keys or key in ctx.type.readonly_keys:
                ctx.api.msg.typeddict_key_cannot_be_deleted(ctx.type, key, key_expr)
            elif key not in ctx.type.items:
                ctx.api.msg.typeddict_key_not_found(ctx.type, key, key_expr)
    return ctx.default_return_type
# Signature names (as reported by CallableType.name) of TypedDict methods that
# mutate the receiver in place; used to trigger the read-only mutation check.
_TP_DICT_MUTATING_METHODS: Final = frozenset({"update of TypedDict", "__ior__ of TypedDict"})
def typed_dict_update_signature_callback(ctx: MethodSigContext) -> CallableType:
    """Try to infer a better signature type for methods that update `TypedDict`.
    This includes: `TypedDict.update`, `TypedDict.__or__`, `TypedDict.__ror__`,
    and `TypedDict.__ior__`.
    """
    signature = ctx.default_signature
    if isinstance(ctx.type, TypedDictType) and len(signature.arg_types) == 1:
        arg_type = get_proper_type(signature.arg_types[0])
        if not isinstance(arg_type, TypedDictType):
            return signature
        # Start from the receiver's shape with no required keys: any subset of
        # keys may be supplied to update().
        arg_type = ctx.type.copy_modified(
            fallback=arg_type.create_anonymous_fallback(), required_keys=set()
        )
        if ctx.args and ctx.args[0]:
            if signature.name in _TP_DICT_MUTATING_METHODS:
                # If we want to mutate this object in place, we need to set this flag,
                # it will trigger an extra check in TypedDict's checker.
                arg_type.to_be_mutated = True
            with ctx.api.msg.filter_errors(
                filter_errors=lambda name, info: info.code != codes.TYPEDDICT_READONLY_MUTATED,
                save_filtered_errors=True,
            ):
                # Pre-check the actual argument to learn which keys it carries.
                inferred = get_proper_type(
                    ctx.api.get_expression_type(ctx.args[0][0], type_context=arg_type)
                )
            if arg_type.to_be_mutated:
                arg_type.to_be_mutated = False # Done!
            possible_tds = []
            if isinstance(inferred, TypedDictType):
                possible_tds = [inferred]
            elif isinstance(inferred, UnionType):
                possible_tds = [
                    t
                    for t in get_proper_types(inferred.relevant_items())
                    if isinstance(t, TypedDictType)
                ]
            items = []
            for td in possible_tds:
                # Keys required by the actual argument stay required (when the
                # receiver knows them at all).
                item = arg_type.copy_modified(
                    required_keys=(arg_type.required_keys | td.required_keys)
                    & arg_type.items.keys()
                )
                if not ctx.api.options.extra_checks:
                    item = item.copy_modified(item_names=list(td.items))
                items.append(item)
            if items:
                arg_type = make_simplified_union(items)
        return signature.copy_modified(arg_types=[arg_type])
    return signature
def int_pow_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pow__.

    A literal non-negative exponent yields ``int``; a literal negative
    exponent yields ``float``. Anything else keeps the default return type.
    """
    # int.__pow__ carries an optional modulo argument, hence two argument
    # positions; only refine when the modulo slot is empty.
    if len(ctx.arg_types) != 2 or len(ctx.arg_types[0]) != 1 or len(ctx.arg_types[1]) != 0:
        return ctx.default_return_type
    exponent_expr = ctx.args[0][0]
    if isinstance(exponent_expr, IntExpr):
        exponent = exponent_expr.value
    elif (
        isinstance(exponent_expr, UnaryExpr)
        and exponent_expr.op == "-"
        and isinstance(exponent_expr.expr, IntExpr)
    ):
        exponent = -exponent_expr.expr.value
    else:
        # Right operand is not a (possibly negated) int literal -- give up.
        return ctx.default_return_type
    result_name = "builtins.int" if exponent >= 0 else "builtins.float"
    return ctx.api.named_generic_type(result_name, [])
def int_neg_callback(ctx: MethodContext, multiplier: int = -1) -> Type:
    """Infer a more precise return type for int.__neg__ and int.__pos__.
    This is mainly used to infer the return type as LiteralType
    if the original underlying object is a LiteralType object.

    ``multiplier`` is applied to the known int value (-1 for __neg__,
    +1 for __pos__).
    """
    if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None:
        value = ctx.type.last_known_value.value
        fallback = ctx.type.last_known_value.fallback
        if isinstance(value, int):
            if is_literal_type_like(ctx.api.type_context[-1]):
                # In a Literal[...] context, return a plain literal type.
                return LiteralType(value=multiplier * value, fallback=fallback)
            else:
                # Otherwise keep the Instance and only update its known value.
                return ctx.type.copy_modified(
                    last_known_value=LiteralType(
                        value=multiplier * value,
                        fallback=fallback,
                        line=ctx.type.line,
                        column=ctx.type.column,
                    )
                )
    elif isinstance(ctx.type, LiteralType):
        value = ctx.type.value
        fallback = ctx.type.fallback
        if isinstance(value, int):
            return LiteralType(value=multiplier * value, fallback=fallback)
    return ctx.default_return_type
def int_pos_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for int.__pos__.

    Same literal refinement as int.__neg__, but the value keeps its sign.
    """
    return int_neg_callback(ctx, multiplier=1)
def tuple_mul_callback(ctx: MethodContext) -> Type:
    """Infer a more precise return type for tuple.__mul__ and tuple.__rmul__.

    When the multiplier is a known int literal, return a fixed-size tuple with
    the items repeated that many times.
    """
    if not isinstance(ctx.type, TupleType):
        return ctx.default_return_type
    multiplier = get_proper_type(ctx.arg_types[0][0])
    # The literal multiplier may arrive either as an Instance with a known
    # value or directly as a LiteralType.
    literal: LiteralType | None = None
    if isinstance(multiplier, Instance):
        literal = multiplier.last_known_value
    elif isinstance(multiplier, LiteralType):
        literal = multiplier
    if literal is not None and isinstance(literal.value, int):
        return ctx.type.copy_modified(items=ctx.type.items * literal.value)
    return ctx.default_return_type

View file

@ -0,0 +1,299 @@
"""
This file contains a variety of plugins for refining how mypy infers types of
expressions involving Enums.
Currently, this file focuses on providing better inference for expressions like
'SomeEnum.FOO.name' and 'SomeEnum.FOO.value'. Note that the type of both expressions
will vary depending on exactly which instance of SomeEnum we're looking at.
Note that this file does *not* contain all special-cased logic related to enums:
we actually bake some of it directly in to the semantic analysis layer (see
semanal_enum.py).
"""
from __future__ import annotations
from collections.abc import Iterable, Sequence
from typing import TypeVar, cast
import mypy.plugin # To avoid circular imports.
from mypy.checker_shared import TypeCheckerSharedApi
from mypy.nodes import TypeInfo, Var
from mypy.subtypes import is_equivalent
from mypy.typeops import fixup_partial_type, make_simplified_union
from mypy.types import (
ELLIPSIS_TYPE_NAMES,
CallableType,
Instance,
LiteralType,
ProperType,
Type,
get_proper_type,
is_named_instance,
)
def enum_name_callback(ctx: mypy.plugin.AttributeContext) -> Type:
    """Refine the 'name' attribute of enum members as if it were Final.

    'MyEnum.FOO.name' is normally inferred as plain 'str'. When the specific
    member is known, this instead yields 'str' whose last known value is
    'Literal["FOO"]', so the expression is usable wherever a Literal is
    expected, like any other Final attribute.

    Assumes the context is an attribute access matching one of the strings
    in 'ENUM_NAME_ACCESS'.
    """
    member_name = _extract_underlying_field_name(ctx.type)
    if member_name is None:
        # Unknown member -- keep the default 'str' inference.
        return ctx.default_attr_type
    fallback = ctx.api.named_generic_type("builtins.str", [])
    return fallback.copy_modified(
        last_known_value=LiteralType(member_name, fallback=fallback)
    )
_T = TypeVar("_T")
def _first(it: Iterable[_T]) -> _T | None:
"""Return the first value from any iterable.
Returns ``None`` if the iterable is empty.
"""
for val in it:
return val
return None
def _infer_value_type_with_auto_fallback(
    ctx: mypy.plugin.AttributeContext, proper_type: ProperType | None
) -> Type | None:
    """Figure out the type of an enum value accounting for `auto()`.
    This method is a no-op for a `None` proper_type and also in the case where
    the type is not "enum.auto"
    """
    if proper_type is None:
        return None
    proper_type = get_proper_type(fixup_partial_type(proper_type))
    # Enums in stubs may have ... instead of actual values. If `_value_` is annotated
    # (manually or inherited from IntEnum, for example), it is a more reasonable guess
    # than literal ellipsis type.
    if (
        _is_defined_in_stub(ctx)
        and isinstance(proper_type, Instance)
        and proper_type.type.fullname in ELLIPSIS_TYPE_NAMES
        and isinstance(ctx.type, Instance)
    ):
        value_type = ctx.type.type.get("_value_")
        if value_type is not None and isinstance(var := value_type.node, Var):
            return var.type
        return proper_type
    if not (isinstance(proper_type, Instance) and proper_type.type.fullname == "enum.auto"):
        # `member(x)` wraps the value type in its single type argument.
        if is_named_instance(proper_type, "enum.member") and proper_type.args:
            return proper_type.args[0]
        return proper_type
    # From here on the value is `auto()`.
    assert isinstance(ctx.type, Instance), "An incorrect ctx.type was passed."
    info = ctx.type.type
    # Find the first _generate_next_value_ on the mro. We need to know
    # if it is `Enum` because `Enum` types say that the return-value of
    # `_generate_next_value_` is `Any`. In reality the default `auto()`
    # returns an `int` (presumably the `Any` in typeshed is to make it
    # easier to subclass and change the returned type).
    type_with_gnv = _first(ti for ti in info.mro if ti.names.get("_generate_next_value_"))
    if type_with_gnv is None:
        return ctx.default_attr_type
    stnode = type_with_gnv.names["_generate_next_value_"]
    # This should be a `CallableType`
    node_type = get_proper_type(stnode.type)
    if isinstance(node_type, CallableType):
        if type_with_gnv.fullname == "enum.Enum":
            int_type = ctx.api.named_generic_type("builtins.int", [])
            return int_type
        return get_proper_type(node_type.ret_type)
    return ctx.default_attr_type
def _is_defined_in_stub(ctx: mypy.plugin.AttributeContext) -> bool:
    """Return True when ctx.type is an Instance defined in a stub file."""
    assert isinstance(ctx.api, TypeCheckerSharedApi)
    if not isinstance(ctx.type, Instance):
        return False
    return ctx.api.is_defined_in_stub(ctx.type)
def _implements_new(info: TypeInfo) -> bool:
    """Report whether ``__new__`` was implemented in a user subclass of Enum.

    Walk the MRO for the first *enum* class that defines ``__new__``. If it
    is one of the stdlib bases (``enum.Enum``/``IntEnum``/``StrEnum``), or no
    enum class defines ``__new__`` at all, answer False. A ``__new__`` coming
    from a non-enum class (the data type) is not counted either, since then we
    should in general infer that an enum entry's value has that type. When a
    user enum subclass defines ``__new__``, mypy cannot infer ``_value_`` from
    assignments inside it, so callers must fall back to Any.
    """
    for candidate in info.mro:
        if candidate.is_enum and candidate.names.get("__new__"):
            return candidate.fullname not in ("enum.Enum", "enum.IntEnum", "enum.StrEnum")
    return False
def enum_member_callback(ctx: mypy.plugin.FunctionContext) -> Type:
    """Sharpen the inferred type of ``member(...)`` calls.

    By default ``member(1)`` is inferred as ``member[int]``; when the argument
    carries a known literal value, substitute it so the result becomes e.g.
    ``member[Literal[1]]``.
    """
    if not (ctx.arg_types and ctx.arg_types[0]):
        return ctx.default_return_type
    wrapped = get_proper_type(ctx.arg_types[0][0])
    result = get_proper_type(ctx.default_return_type)
    refinable = (
        isinstance(wrapped, Instance)
        and wrapped.last_known_value
        and isinstance(result, Instance)
        and len(result.args) == 1
    )
    if refinable:
        assert isinstance(result, Instance)
        return result.copy_modified(args=[wrapped])
    return ctx.default_return_type
def enum_value_callback(ctx: mypy.plugin.AttributeContext) -> Type:
    """This plugin refines the 'value' attribute in enums to refer to
    the original underlying value. For example, suppose we have the
    following:
    class SomeEnum:
        FOO = A()
        BAR = B()
    By default, mypy will infer that 'SomeEnum.FOO.value' and
    'SomeEnum.BAR.value' both are of type 'Any'. This plugin refines
    this inference so that mypy understands the expressions are
    actually of types 'A' and 'B' respectively. This better reflects
    the actual runtime behavior.
    This plugin works simply by looking up the original value assigned
    to the enum. For example, when this plugin sees 'SomeEnum.BAR.value',
    it will look up whatever type 'BAR' had in the SomeEnum TypeInfo and
    use that as the inferred type of the overall expression.
    This plugin assumes that the provided context is an attribute access
    matching one of the strings found in 'ENUM_VALUE_ACCESS'.
    """
    enum_field_name = _extract_underlying_field_name(ctx.type)
    if enum_field_name is None:
        # We do not know the enum field name (perhaps it was passed to a
        # function and we only know that it _is_ a member). All is not lost
        # however, if we can prove that the all of the enum members have the
        # same value-type, then it doesn't matter which member was passed in.
        # The value-type is still known.
        if isinstance(ctx.type, Instance):
            info = ctx.type.type
            # As long as mypy doesn't understand attribute creation in __new__,
            # there is no way to predict the value type if the enum class has a
            # custom implementation
            if _implements_new(info):
                return ctx.default_attr_type
            stnodes = (info.get(name) for name in info.names)
            # Enums _can_ have methods, instance attributes, and `nonmember`s.
            # Omit methods and attributes created by assigning to self.*
            # for our value inference.
            node_types = (
                get_proper_type(n.type) if n else None
                for n in stnodes
                if n is None or not n.implicit
            )
            proper_types = [
                _infer_value_type_with_auto_fallback(ctx, t)
                for t in node_types
                if t is None
                or (not isinstance(t, CallableType) and not is_named_instance(t, "enum.nonmember"))
            ]
            underlying_type = _first(proper_types)
            if underlying_type is None:
                return ctx.default_attr_type
            # At first we try to predict future `value` type if all other items
            # have the same type. For example, `int`.
            # If this is the case, we simply return this type.
            # See https://github.com/python/mypy/pull/9443
            all_same_value_type = all(
                proper_type is not None and proper_type == underlying_type
                for proper_type in proper_types
            )
            if all_same_value_type:
                if underlying_type is not None:
                    return underlying_type
            # But, after we started treating all `Enum` values as `Final`,
            # we start to infer types in
            # `item = 1` as `Literal[1]`, not just `int`.
            # So, for example types in this `Enum` will all be different:
            #
            # class Ordering(IntEnum):
            #     one = 1
            #     two = 2
            #     three = 3
            #
            # We will infer three `Literal` types here.
            # They are not the same, but they are equivalent.
            # So, we unify them to make sure `.value` prediction still works.
            # Result will be `Literal[1] | Literal[2] | Literal[3]` for this case.
            all_equivalent_types = all(
                proper_type is not None and is_equivalent(proper_type, underlying_type)
                for proper_type in proper_types
            )
            if all_equivalent_types:
                return make_simplified_union(cast(Sequence[Type], proper_types))
        return ctx.default_attr_type
    # The specific member is known: look up the type assigned to that name.
    assert isinstance(ctx.type, Instance)
    info = ctx.type.type
    # As long as mypy doesn't understand attribute creation in __new__,
    # there is no way to predict the value type if the enum class has a
    # custom implementation
    if _implements_new(info):
        return ctx.default_attr_type
    stnode = info.get(enum_field_name)
    if stnode is None:
        return ctx.default_attr_type
    underlying_type = _infer_value_type_with_auto_fallback(ctx, get_proper_type(stnode.type))
    if underlying_type is None:
        return ctx.default_attr_type
    return underlying_type
def _extract_underlying_field_name(typ: Type) -> str | None:
    """Return the member name for a type that denotes a specific enum member.

    For the type corresponding to 'SomeEnum.FOO' this returns the string
    "FOO"; for anything that is not a known single enum member it returns
    None.

    This relies on the fact that enum members are valid inside Literal[...]
    types: an expression like 'SomeEnum.FOO' is represented as an Instance
    whose last known value is such a Literal, and that Literal stores the
    member name as a string.
    """
    proper = get_proper_type(typ)
    if not isinstance(proper, Instance) or not proper.type.is_enum:
        return None
    literal = proper.last_known_value
    if literal is None:
        return None
    # For an enum Instance, the literal's value is the member name string.
    assert isinstance(literal.value, str)
    return literal.value

View file

@ -0,0 +1,398 @@
"""Plugin for supporting the functools standard library module."""
from __future__ import annotations
from typing import Final, NamedTuple
import mypy.checker
import mypy.plugin
import mypy.semanal
from mypy.argmap import map_actuals_to_formals
from mypy.erasetype import erase_typevars
from mypy.nodes import (
ARG_POS,
ARG_STAR2,
SYMBOL_FUNCBASE_TYPES,
ArgKind,
Argument,
CallExpr,
NameExpr,
Var,
)
from mypy.plugins.common import add_method_to_class
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Overloaded,
ParamSpecFlavor,
ParamSpecType,
Type,
TypeOfAny,
TypeVarType,
UnboundType,
UnionType,
get_proper_type,
)
# Decorators handled by functools_total_ordering_maker_callback.
functools_total_ordering_makers: Final = {"functools.total_ordering"}
# Rich-comparison dunders that total_ordering can derive from one another.
_ORDERING_METHODS: Final = {"__lt__", "__le__", "__gt__", "__ge__"}
# Fully-qualified name of the functools.partial type.
PARTIAL: Final = "functools.partial"
class _MethodInfo(NamedTuple):
    """A comparison method found on a class: staticness plus its signature."""

    # True when the method is a static method (no bound first argument).
    is_static: bool
    # The method's callable signature.
    type: CallableType
def functools_total_ordering_maker_callback(
    ctx: mypy.plugin.ClassDefContext, auto_attribs_default: bool = False
) -> bool:
    """Add dunder methods to classes decorated with functools.total_ordering.

    Finds an existing ordering method on the class (or its bases), derives the
    'other' argument and return types from it, and synthesizes the missing
    ordering dunders. Always returns True (the class was handled).
    """
    comparison_methods = _analyze_class(ctx)
    if not comparison_methods:
        ctx.api.fail(
            'No ordering operation defined when using "functools.total_ordering": < > <= >=',
            ctx.reason,
        )
        return True
    # prefer __lt__ to __le__ to __gt__ to __ge__
    # NOTE: the key sorts unanalysable (None) entries first under max(), so if
    # any found method could not be analysed, root_method is None and we bail.
    root = max(comparison_methods, key=lambda k: (comparison_methods[k] is None, k))
    root_method = comparison_methods[root]
    if not root_method:
        # None of the defined comparison methods can be analysed
        return True
    other_type = _find_other_type(root_method)
    bool_type = ctx.api.named_type("builtins.bool")
    ret_type: Type = bool_type
    # If the root method does not return bool (nor an unbound name ending in
    # "bool"), fall back to Any for the synthesized methods' return type.
    if root_method.type.ret_type != ctx.api.named_type("builtins.bool"):
        proper_ret_type = get_proper_type(root_method.type.ret_type)
        if not (
            isinstance(proper_ret_type, UnboundType)
            and proper_ret_type.name.split(".")[-1] == "bool"
        ):
            ret_type = AnyType(TypeOfAny.implementation_artifact)
    for additional_op in _ORDERING_METHODS:
        # Either the method is not implemented
        # or has an unknown signature that we can now extrapolate.
        if not comparison_methods.get(additional_op):
            args = [Argument(Var("other", other_type), other_type, None, ARG_POS)]
            add_method_to_class(ctx.api, ctx.cls, additional_op, args, ret_type)
    return True
def _find_other_type(method: _MethodInfo) -> Type:
    """Find the type of the ``other`` argument in a comparison method.

    For instance methods that is the second positional argument (after self),
    for static methods the first. A non-positional, non-**kwargs argument
    encountered earlier is used as a fallback; otherwise Any is returned.
    """
    # Index of the 'other' argument among positional args: skip 'self'
    # unless the method is static.
    first_arg_pos = 0 if method.is_static else 1
    cur_pos_arg = 0
    other_arg = None
    for arg_kind, arg_type in zip(method.type.arg_kinds, method.type.arg_types):
        if arg_kind.is_positional():
            if cur_pos_arg == first_arg_pos:
                other_arg = arg_type
                break
            cur_pos_arg += 1
        elif arg_kind != ARG_STAR2:
            other_arg = arg_type
            break
    if other_arg is None:
        return AnyType(TypeOfAny.implementation_artifact)
    return other_arg
def _analyze_class(ctx: mypy.plugin.ClassDefContext) -> dict[str, _MethodInfo | None]:
    """Analyze the class body, its parents, and return the comparison methods found.

    Maps each ordering dunder name to its _MethodInfo, or to None when the
    definition exists but cannot be analysed. The first definition found in
    MRO order wins.
    """
    # Traverse the MRO and collect ordering methods.
    comparison_methods: dict[str, _MethodInfo | None] = {}
    # Skip object because total_ordering does not use methods from object
    for cls in ctx.cls.info.mro[:-1]:
        for name in _ORDERING_METHODS:
            if name in cls.names and name not in comparison_methods:
                node = cls.names[name].node
                # Plain functions/overloads with a callable type.
                if isinstance(node, SYMBOL_FUNCBASE_TYPES) and isinstance(node.type, CallableType):
                    comparison_methods[name] = _MethodInfo(node.is_static, node.type)
                    continue
                # Callable attributes (e.g. assigned lambdas/decorated callables).
                if isinstance(node, Var):
                    proper_type = get_proper_type(node.type)
                    if isinstance(proper_type, CallableType):
                        comparison_methods[name] = _MethodInfo(node.is_staticmethod, proper_type)
                        continue
                # Present but unanalysable.
                comparison_methods[name] = None
    return comparison_methods
def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
    """Infer a more precise return type for functools.partial."""
    # We rely on TypeChecker internals, and expect exactly the
    # (fn, *args, **kwargs) argument layout with a single callee.
    if (
        not isinstance(ctx.api, mypy.checker.TypeChecker)
        or len(ctx.arg_types) != 3
        or len(ctx.arg_types[0]) != 1
    ):
        return ctx.default_return_type
    callee = ctx.arg_types[0][0]
    if isinstance(get_proper_type(callee), Overloaded):
        # TODO: handle overloads, just fall back to whatever the non-plugin code does
        return ctx.default_return_type
    return handle_partial_with_callee(ctx, callee=callee)
def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type:
    """Build the precise ``functools.partial[...]`` type for one callee.

    Type-checks the supplied arguments against a fully-defaulted copy of the
    callee's signature, then constructs the signature of the resulting partial
    object (positional args consumed, keyword args defaulted) and stores it in
    the ``__mypy_partial`` extra attribute for partial.__call__ to use.
    Unions are handled by recursing per item.
    """
    if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
        return ctx.default_return_type
    if isinstance(callee_proper := get_proper_type(callee), UnionType):
        return UnionType.make_union(
            [handle_partial_with_callee(ctx, item) for item in callee_proper.items]
        )
    fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type)
    if fn_type is None:
        return ctx.default_return_type
    # We must normalize from the start to have coherent view together with TypeChecker.
    fn_type = fn_type.with_unpacked_kwargs().with_normalized_var_args()
    last_context = ctx.api.type_context[-1]
    if not fn_type.is_type_obj():
        # We wrap the return type to get use of a possible type context provided by caller.
        # We cannot do this in case of class objects, since otherwise the plugin may get
        # falsely triggered when evaluating the constructed call itself.
        ret_type: Type = ctx.api.named_generic_type(PARTIAL, [fn_type.ret_type])
        wrapped_return = True
    else:
        ret_type = fn_type.ret_type
        # Instead, for class objects we ignore any type context to avoid spurious errors,
        # since the type context will be partial[X] etc., not X.
        ctx.api.type_context[-1] = None
        wrapped_return = False
    # Flatten actual to formal mapping, since this is what check_call() expects.
    actual_args = []
    actual_arg_kinds = []
    actual_arg_names = []
    actual_types = []
    seen_args = set()
    for i, param in enumerate(ctx.args[1:], start=1):
        for j, a in enumerate(param):
            if a in seen_args:
                # Same actual arg can map to multiple formals, but we need to include
                # each one only once.
                continue
            # Here we rely on the fact that expressions are essentially immutable, so
            # they can be compared by identity.
            seen_args.add(a)
            actual_args.append(a)
            actual_arg_kinds.append(ctx.arg_kinds[i][j])
            actual_arg_names.append(ctx.arg_names[i][j])
            actual_types.append(ctx.arg_types[i][j])
    formal_to_actual = map_actuals_to_formals(
        actual_kinds=actual_arg_kinds,
        actual_names=actual_arg_names,
        formal_kinds=fn_type.arg_kinds,
        formal_names=fn_type.arg_names,
        actual_arg_type=lambda i: actual_types[i],
    )
    # We need to remove any type variables that appear only in formals that have
    # no actuals, to avoid eagerly binding them in check_call() below.
    can_infer_ids = set()
    for i, arg_type in enumerate(fn_type.arg_types):
        if not formal_to_actual[i]:
            continue
        can_infer_ids.update({tv.id for tv in get_all_type_vars(arg_type)})
    # special_sig="partial" allows omission of args/kwargs typed with ParamSpec
    defaulted = fn_type.copy_modified(
        arg_kinds=[
            (
                ArgKind.ARG_OPT
                if k == ArgKind.ARG_POS
                else (ArgKind.ARG_NAMED_OPT if k == ArgKind.ARG_NAMED else k)
            )
            for k in fn_type.arg_kinds
        ],
        ret_type=ret_type,
        variables=[
            tv
            for tv in fn_type.variables
            # Keep TypeVarTuple/ParamSpec to avoid spurious errors on empty args.
            if tv.id in can_infer_ids or not isinstance(tv, TypeVarType)
        ],
        special_sig="partial",
    )
    if defaulted.line < 0:
        # Make up a line number if we don't have one
        defaulted.set_line(ctx.default_return_type)
    # Create a valid context for various ad-hoc inspections in check_call().
    call_expr = CallExpr(
        callee=ctx.args[0][0],
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        analyzed=ctx.context.analyzed if isinstance(ctx.context, CallExpr) else None,
    )
    call_expr.set_line(ctx.context)
    _, bound = ctx.api.expr_checker.check_call(
        callee=defaulted,
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        context=call_expr,
    )
    if not wrapped_return:
        # Restore previously ignored context.
        ctx.api.type_context[-1] = last_context
    bound = get_proper_type(bound)
    if not isinstance(bound, CallableType):
        return ctx.default_return_type
    if wrapped_return:
        # Reverse the wrapping we did above.
        ret_type = get_proper_type(bound.ret_type)
        if not isinstance(ret_type, Instance) or ret_type.type.fullname != PARTIAL:
            return ctx.default_return_type
        bound = bound.copy_modified(ret_type=ret_type.args[0])
    partial_kinds = []
    partial_types = []
    partial_names = []
    # We need to fully apply any positional arguments (they cannot be respecified)
    # However, keyword arguments can be respecified, so just give them a default
    for i, actuals in enumerate(formal_to_actual):
        if len(bound.arg_types) == len(fn_type.arg_types):
            arg_type = bound.arg_types[i]
            if not mypy.checker.is_valid_inferred_type(arg_type, ctx.api.options):
                arg_type = fn_type.arg_types[i] # bit of a hack
        else:
            # TODO: I assume that bound and fn_type have the same arguments. It appears this isn't
            # true when PEP 646 things are happening. See testFunctoolsPartialTypeVarTuple
            arg_type = fn_type.arg_types[i]
        if not actuals or fn_type.arg_kinds[i] in (ArgKind.ARG_STAR, ArgKind.ARG_STAR2):
            partial_kinds.append(fn_type.arg_kinds[i])
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])
        else:
            assert actuals
            if any(actual_arg_kinds[j] in (ArgKind.ARG_POS, ArgKind.ARG_STAR) for j in actuals):
                # Don't add params for arguments passed positionally
                continue
            # Add defaulted params for arguments passed via keyword
            kind = actual_arg_kinds[actuals[0]]
            if kind == ArgKind.ARG_NAMED or kind == ArgKind.ARG_STAR2:
                kind = ArgKind.ARG_NAMED_OPT
            partial_kinds.append(kind)
            partial_types.append(arg_type)
            partial_names.append(fn_type.arg_names[i])
    ret_type = bound.ret_type
    if not mypy.checker.is_valid_inferred_type(ret_type, ctx.api.options):
        ret_type = fn_type.ret_type # same kind of hack as above
    # Technically, we should set definition to None here, since it will not be recovered
    # on warm cache runs in fixup.py. This however may hide some helpful info in error
    # messages, so we are keeping it for now. See also issue #20640.
    partially_applied = fn_type.copy_modified(
        arg_types=partial_types,
        arg_kinds=partial_kinds,
        arg_names=partial_names,
        ret_type=ret_type,
        special_sig="partial",
    )
    # Do not leak typevars from generic functions - they cannot be usable.
    # Keep them in the wrapped callable, but avoid `partial[SomeStrayTypeVar]`
    erased_ret_type = erase_typevars(ret_type, [tv.id for tv in fn_type.variables])
    ret = ctx.api.named_generic_type(PARTIAL, [erased_ret_type])
    ret = ret.copy_with_extra_attr("__mypy_partial", partially_applied)
    if partially_applied.param_spec():
        # Record whether *args / **kwargs were already bound so that
        # partial.__call__ can enforce the ParamSpec parts.
        assert ret.extra_attrs is not None # copy_with_extra_attr above ensures this
        attrs = ret.extra_attrs.copy()
        if ArgKind.ARG_STAR in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_args_bound")
        if ArgKind.ARG_STAR2 in actual_arg_kinds:
            attrs.immutable.add("__mypy_partial_paramspec_kwargs_bound")
        ret.extra_attrs = attrs
    return ret
def partial_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
    """Infer a more precise return type for functools.partial.__call__.

    Re-checks the call against the wrapped signature stored in the
    ``__mypy_partial`` extra attribute, and for ParamSpec-typed partials
    verifies that ``*args: P.args`` / ``**kwargs: P.kwargs`` are passed exactly
    once overall (either already bound at construction time or at this call).
    """
    if (
        not isinstance(ctx.api, mypy.checker.TypeChecker) # use internals
        or not isinstance(ctx.type, Instance)
        or ctx.type.type.fullname != PARTIAL
        or not ctx.type.extra_attrs
        or "__mypy_partial" not in ctx.type.extra_attrs.attrs
    ):
        return ctx.default_return_type
    extra_attrs = ctx.type.extra_attrs
    partial_type = get_proper_type(extra_attrs.attrs["__mypy_partial"])
    if len(ctx.arg_types) != 2: # *args, **kwargs
        return ctx.default_return_type
    # See comments for similar actual to formal code above
    actual_args = []
    actual_arg_kinds = []
    actual_arg_names = []
    seen_args = set()
    for i, param in enumerate(ctx.args):
        for j, a in enumerate(param):
            if a in seen_args:
                continue
            seen_args.add(a)
            actual_args.append(a)
            actual_arg_kinds.append(ctx.arg_kinds[i][j])
            actual_arg_names.append(ctx.arg_names[i][j])
    result, _ = ctx.api.expr_checker.check_call(
        callee=partial_type,
        args=actual_args,
        arg_kinds=actual_arg_kinds,
        arg_names=actual_arg_names,
        context=ctx.context,
    )
    if not isinstance(partial_type, CallableType) or partial_type.param_spec() is None:
        return result
    # Flags set at construction time when *args/**kwargs were already bound.
    args_bound = "__mypy_partial_paramspec_args_bound" in extra_attrs.immutable
    kwargs_bound = "__mypy_partial_paramspec_kwargs_bound" in extra_attrs.immutable
    passed_paramspec_parts = [
        arg.node.type
        for arg in actual_args
        if isinstance(arg, NameExpr)
        and isinstance(arg.node, Var)
        and isinstance(arg.node.type, ParamSpecType)
    ]
    # ensure *args: P.args
    args_passed = any(part.flavor == ParamSpecFlavor.ARGS for part in passed_paramspec_parts)
    if not args_bound and not args_passed:
        ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
    elif args_bound and args_passed:
        ctx.api.expr_checker.msg.too_many_arguments(partial_type, ctx.context)
    # ensure **kwargs: P.kwargs
    kwargs_passed = any(part.flavor == ParamSpecFlavor.KWARGS for part in passed_paramspec_parts)
    if not kwargs_bound and not kwargs_passed:
        ctx.api.expr_checker.msg.too_few_arguments(partial_type, ctx.context, actual_arg_names)
    return result

View file

@ -0,0 +1,178 @@
"""
This plugin is helpful for mypy development itself.
By default, it is not enabled for mypy users.
It also can be used by plugin developers as a part of their CI checks.
It finds missing ``get_proper_type()`` calls, which can lead to multiple errors.
"""
from __future__ import annotations
from collections.abc import Callable
from mypy.checker import TypeChecker
from mypy.nodes import TypeInfo
from mypy.plugin import FunctionContext, Plugin
from mypy.subtypes import is_proper_subtype
from mypy.types import (
AnyType,
FunctionLike,
Instance,
NoneTyp,
ProperType,
TupleType,
Type,
UnionType,
get_proper_type,
get_proper_types,
)
class ProperTypePlugin(Plugin):
    """
    A plugin to ensure that every type is expanded before doing any special-casing.
    This solves the problem that we have hundreds of call sites like:
    if isinstance(typ, UnionType):
        ... # special-case union
    But after introducing a new type TypeAliasType (and removing immediate expansion)
    all these became dangerous because typ may be e.g. an alias to union.
    """

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        # Dispatch table: fully-qualified callee name -> check hook.
        hooks = {
            "builtins.isinstance": isinstance_proper_hook,
            "mypy.types.get_proper_type": proper_type_hook,
            "mypy.types.get_proper_types": proper_types_hook,
        }
        return hooks.get(fullname)
def isinstance_proper_hook(ctx: FunctionContext) -> Type:
    """Flag isinstance() calls whose first argument may be an unexpanded type."""
    if len(ctx.arg_types) != 2 or not ctx.arg_types[1]:
        return ctx.default_return_type
    right = get_proper_type(ctx.arg_types[1][0])
    for left in ctx.arg_types[0]:
        # An argument is suspicious if it may still be a TypeAliasType
        # (not a ProperType subtype) or is Any.
        suspicious = is_improper_type(left) or isinstance(get_proper_type(left), AnyType)
        if not (suspicious and is_dangerous_target(right)):
            continue
        if is_special_target(right):
            return ctx.default_return_type
        ctx.api.fail(
            "Never apply isinstance() to unexpanded types;"
            " use mypy.types.get_proper_type() first",
            ctx.context,
        )
        ctx.api.note(  # type: ignore[attr-defined]
            "If you pass on the original type"
            " after the check, always use its unexpanded version",
            ctx.context,
        )
    return ctx.default_return_type
def is_special_target(right: ProperType) -> bool:
    """Whitelist some special cases for use in isinstance() with improper types."""
    if isinstance(right, TupleType):
        # A tuple target is safe only if every member is.
        return all(is_special_target(member) for member in get_proper_types(right.items))
    if not (isinstance(right, FunctionLike) and right.is_type_obj()):
        return False
    fullname = right.type_object().fullname
    if fullname == "builtins.tuple":
        # Used with Union[Type, Tuple[Type, ...]].
        return True
    if fullname in ("mypy.types.Type", "mypy.types.ProperType", "mypy.types.TypeAliasType"):
        # Special case: things like assert isinstance(typ, ProperType) are always OK.
        return True
    # Special case: these are not valid targets for a type alias and thus safe.
    # TODO: introduce a SyntheticType base to simplify this?
    return fullname in (
        "mypy.types.UnboundType",
        "mypy.types.TypeVarLikeType",
        "mypy.types.TypeVarType",
        "mypy.types.UnpackType",
        "mypy.types.TypeVarTupleType",
        "mypy.types.ParamSpecType",
        "mypy.types.Parameters",
        "mypy.types.RawExpressionType",
        "mypy.types.EllipsisType",
        "mypy.types.StarType",
        "mypy.types.TypeList",
        "mypy.types.CallableArgument",
        "mypy.types.PartialType",
        "mypy.types.ErasedType",
        "mypy.types.DeletedType",
        "mypy.types.RequiredType",
        "mypy.types.ReadOnlyType",
        "mypy.types.TypeGuardedType",
        "mypy.types.PlaceholderType",
    )
def is_improper_type(typ: Type) -> bool:
    """Is this a type that is not a subtype of ProperType?"""
    expanded = get_proper_type(typ)
    if isinstance(expanded, UnionType):
        # A union is improper if any of its members is.
        return any(is_improper_type(item) for item in expanded.items)
    if isinstance(expanded, Instance):
        info = expanded.type
        return info.has_base("mypy.types.Type") and not info.has_base("mypy.types.ProperType")
    return False
def is_dangerous_target(typ: ProperType) -> bool:
    """Is this a dangerous target (right argument) for an isinstance() check?"""
    if isinstance(typ, TupleType):
        # A tuple target is dangerous if any of its members is.
        return any(is_dangerous_target(get_proper_type(item)) for item in typ.items)
    return (
        isinstance(typ, FunctionLike)
        and typ.is_type_obj()
        and typ.type_object().has_base("mypy.types.Type")
    )
def proper_type_hook(ctx: FunctionContext) -> Type:
    """Check if this get_proper_type() call is not redundant."""
    if ctx.arg_types[0]:
        arg_type = get_proper_type(ctx.arg_types[0][0])
        allowed = UnionType.make_union([NoneTyp(), get_proper_type_instance(ctx)])
        # Minimize amount of spurious errors from overload machinery.
        # TODO: call the hook on the overload as a whole?
        if is_proper_subtype(arg_type, allowed) and isinstance(arg_type, (UnionType, Instance)):
            ctx.api.fail("Redundant call to get_proper_type()", ctx.context)
    return ctx.default_return_type
def proper_types_hook(ctx: FunctionContext) -> Type:
    """Check if this get_proper_types() call is not redundant."""
    if ctx.arg_types[0]:
        arg_type = ctx.arg_types[0][0]
        item_type = UnionType.make_union([NoneTyp(), get_proper_type_instance(ctx)])
        iterable_of_proper = ctx.api.named_generic_type("typing.Iterable", [item_type])
        if is_proper_subtype(arg_type, iterable_of_proper):
            ctx.api.fail("Redundant call to get_proper_types()", ctx.context)
    return ctx.default_return_type
def get_proper_type_instance(ctx: FunctionContext) -> Instance:
    """Build an Instance of mypy.types.ProperType from the checker's symbol tables."""
    checker = ctx.api
    assert isinstance(checker, TypeChecker)  # we need access to loaded modules
    node = checker.modules["mypy.types"].names["ProperType"].node
    assert isinstance(node, TypeInfo)
    return Instance(node, [])
def plugin(version: str) -> type[ProperTypePlugin]:
    """Plugin entry point: *version* is the mypy version string (unused here)."""
    return ProperTypePlugin

View file

@ -0,0 +1,213 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import NamedTuple, TypeAlias as _TypeAlias, TypeVar
from mypy.messages import format_type
from mypy.nodes import ARG_POS, Argument, Block, ClassDef, Context, SymbolTable, TypeInfo, Var
from mypy.options import Options
from mypy.plugin import CheckerPluginInterface, FunctionContext, MethodContext, MethodSigContext
from mypy.plugins.common import add_method_to_class
from mypy.plugins.constants import SINGLEDISPATCH_REGISTER_RETURN_CLASS
from mypy.subtypes import is_subtype
from mypy.types import (
AnyType,
CallableType,
FunctionLike,
Instance,
NoneType,
Overloaded,
Type,
TypeOfAny,
get_proper_type,
)
class SingledispatchTypeVars(NamedTuple):
    """The two type arguments this plugin stores on a singledispatch Instance."""

    return_type: Type  # declared return type of the singledispatch function
    fallback: CallableType  # signature of the original (fallback) implementation
class RegisterCallableInfo(NamedTuple):
    """Type args carried by the synthetic SingleDispatchRegisterCallable instance."""

    register_type: Type  # the type that was explicitly passed to register()
    singledispatch_obj: Instance  # the singledispatch callable being registered on
def get_singledispatch_info(typ: Instance) -> SingledispatchTypeVars | None:
    """Unpack (return_type, fallback) from a singledispatch Instance, or None."""
    if len(typ.args) != 2:
        return None
    return SingledispatchTypeVars(*typ.args)  # type: ignore[arg-type]
T = TypeVar("T")
def get_first_arg(args: list[list[T]]) -> T | None:
    """Get the element that corresponds to the first argument passed to the function"""
    first_group = args[0] if args else None
    return first_group[0] if first_group else None
def make_fake_register_class_instance(
    api: CheckerPluginInterface, type_args: Sequence[Type]
) -> Instance:
    """Construct an instance of a synthetic SingleDispatchRegisterCallable class.

    The class is built by hand (ClassDef + TypeInfo with an object base and a
    minimal MRO) and given a single `__call__(name)` method returning None, so
    that a later method hook can intercept the call.  `type_args` become the
    instance's type arguments (see RegisterCallableInfo).
    """
    defn = ClassDef(SINGLEDISPATCH_REGISTER_RETURN_CLASS, Block([]))
    defn.fullname = f"functools.{SINGLEDISPATCH_REGISTER_RETURN_CLASS}"
    info = TypeInfo(SymbolTable(), defn, "functools")
    obj_type = api.named_generic_type("builtins.object", []).type
    info.bases = [Instance(obj_type, [])]
    # Minimal MRO: just this class and object.
    info.mro = [info, obj_type]
    defn.info = info
    # Single positional parameter with an Any type; the real checking happens
    # in the plugin hook for __call__, not via this signature.
    func_arg = Argument(Var("name"), AnyType(TypeOfAny.implementation_artifact), None, ARG_POS)
    add_method_to_class(api, defn, "__call__", [func_arg], NoneType())
    return Instance(info, type_args)
PluginContext: _TypeAlias = FunctionContext | MethodContext
def fail(ctx: PluginContext, msg: str, context: Context | None) -> None:
    """Emit an error message at *context*, or at ``ctx.context`` when it is None.

    Useful when the preferred location (e.g. ``CallableType.definition``) may be
    missing; the calling function's location is an acceptable fallback.
    """
    # TODO: figure out if there is some more reliable way of getting context information, so this
    # function isn't necessary
    ctx.api.fail(msg, context if context is not None else ctx.context)
def create_singledispatch_function_callback(ctx: FunctionContext) -> Type:
    """Called for functools.singledispatch"""
    func_type = get_proper_type(get_first_arg(ctx.arg_types))
    if not isinstance(func_type, CallableType):
        return ctx.default_return_type
    if len(func_type.arg_kinds) < 1:
        fail(ctx, "Singledispatch function requires at least one argument", func_type.definition)
        return ctx.default_return_type
    if not func_type.arg_kinds[0].is_positional(star=True):
        fail(
            ctx,
            "First argument to singledispatch function must be a positional argument",
            func_type.definition,
        )
        return ctx.default_return_type
    # singledispatch returns an instance of functools._SingleDispatchCallable according to
    # typeshed; stash the fallback's signature in its type arguments so register()
    # hooks can retrieve it later.
    singledispatch_obj = get_proper_type(ctx.default_return_type)
    assert isinstance(singledispatch_obj, Instance)
    singledispatch_obj.args += (func_type,)
    return ctx.default_return_type
def singledispatch_register_callback(ctx: MethodContext) -> Type:
    """Called for functools._SingleDispatchCallable.register"""
    assert isinstance(ctx.type, Instance)
    # TODO: check that there's only one argument
    first_arg_type = get_proper_type(get_first_arg(ctx.arg_types))
    if not isinstance(first_arg_type, (CallableType, Overloaded)):
        # fallback in case we don't recognize the arguments
        return ctx.default_return_type
    if first_arg_type.is_type_obj():
        # HACK: We received a class as an argument to register. We need to be able
        # to access the function that register is being applied to, and the typeshed definition
        # of register has it return a generic Callable, so we create a new
        # SingleDispatchRegisterCallable class, define a __call__ method, and then add a
        # plugin hook for that.
        # is_subtype doesn't work when the right type is Overloaded, so we need the
        # actual type
        register_type = first_arg_type.items[0].ret_type
        type_args = RegisterCallableInfo(register_type, ctx.type)
        return make_fake_register_class_instance(ctx.api, type_args)
    if isinstance(first_arg_type, CallableType):
        # TODO: do more checking for registered functions
        register_function(ctx, ctx.type, first_arg_type, ctx.api.options)
        # The typeshed stubs for register say that the function returned is Callable[..., T], even
        # though the function returned is the same as the one passed in. We return the type of the
        # function so that mypy can properly type check cases where the registered function is used
        # directly (instead of through singledispatch)
        return first_arg_type
    # fallback in case we don't recognize the arguments
    return ctx.default_return_type
def register_function(
    ctx: PluginContext,
    singledispatch_obj: Instance,
    func: Type,
    options: Options,
    register_arg: Type | None = None,
) -> None:
    """Record *func* as an implementation registered on a singledispatch object.

    Reports an error when the dispatch type is not a subtype of the fallback
    function's first argument.
    """
    func = get_proper_type(func)
    if not isinstance(func, CallableType):
        return
    metadata = get_singledispatch_info(singledispatch_obj)
    if metadata is None:
        # if we never added the fallback to the type variables, we already reported an error, so
        # just don't do anything here
        return
    dispatch_type = get_dispatch_type(func, register_arg)
    if dispatch_type is None:
        # TODO: report an error here that singledispatch requires at least one argument
        # (might want to do the error reporting in get_dispatch_type)
        return
    fallback_dispatch_type = metadata.fallback.arg_types[0]
    if not is_subtype(dispatch_type, fallback_dispatch_type):
        fail(
            ctx,
            f"Dispatch type {format_type(dispatch_type, options)} must be subtype of "
            f"fallback function first argument {format_type(fallback_dispatch_type, options)}",
            func.definition,
        )
def get_dispatch_type(func: CallableType, register_arg: Type | None) -> Type | None:
    """Return the type to dispatch on: an explicit register() argument wins,
    otherwise the function's first parameter type (None when it has none)."""
    if register_arg is not None:
        return register_arg
    return func.arg_types[0] if func.arg_types else None
def call_singledispatch_function_after_register_argument(ctx: MethodContext) -> Type:
    """Called on the function after passing a type to register"""
    register_callable = ctx.type
    if not isinstance(register_callable, Instance):
        return ctx.default_return_type
    info = RegisterCallableInfo(*register_callable.args)  # type: ignore[arg-type]
    func = get_first_arg(ctx.arg_types)
    if func is None:
        return ctx.default_return_type
    register_function(ctx, info.singledispatch_obj, func, ctx.api.options, info.register_type)
    # see call to register_function in the callback for register
    return func
def call_singledispatch_function_callback(ctx: MethodSigContext) -> FunctionLike:
    """Called for functools._SingleDispatchCallable.__call__"""
    if isinstance(ctx.type, Instance):
        metadata = get_singledispatch_info(ctx.type)
        if metadata is not None:
            # Use the stored fallback signature instead of the typeshed one.
            return metadata.fallback
    return ctx.default_signature