"""Tests for the experimental native mypy parser.

To run these, you will need to manually install ast_serialize from
https://github.com/mypyc/ast_serialize first (see the README for the details).
"""

from __future__ import annotations

import contextlib
import os
import tempfile
import unittest
from collections.abc import Iterator
from typing import Any

from mypy import defaults, nodes
from mypy.cache import (
    END_TAG,
    LIST_GEN,
    LIST_INT,
    LITERAL_INT,
    LITERAL_NONE,
    LITERAL_STR,
    LOCATION,
)
from mypy.config_parser import parse_mypy_comments
from mypy.errors import CompileError
from mypy.nodes import MypyFile
from mypy.options import Options
from mypy.test.data import DataDrivenTestCase, DataSuite
from mypy.test.helpers import assert_string_arrays_equal
from mypy.util import get_mypy_comments

# If the experimental ast_serialize module isn't installed, the following import will fail
# and we won't run any native parser tests.
try:
    from mypy.nativeparse import native_parse, parse_to_binary_ast

    has_nativeparse = True
except ImportError:
    has_nativeparse = False


# Python version to use for a test case, keyed by the suffix of the test data
# file name. Files matching none of these use defaults.PYTHON3_VERSION.
_PYTHON_VERSION_BY_FILE_SUFFIX: tuple[tuple[str, tuple[int, int]], ...] = (
    ("python310.test", (3, 10)),
    ("python312.test", (3, 12)),
    ("python313.test", (3, 13)),
    ("python314.test", (3, 14)),
)


class NativeParserSuite(DataSuite):
    """Data-driven suite that runs test_parser() on cases from native-parser.test.

    No test files are registered when the native parser could not be imported.
    """

    required_out_section = True
    base_path = "."
    files = ["native-parser.test"] if has_nativeparse else []

    def run_case(self, testcase: DataDrivenTestCase) -> None:
        test_parser(testcase)


class NativeParserImportsSuite(DataSuite):
    """Data-driven suite that runs test_parser_imports() on native-parser-imports.test.

    No test files are registered when the native parser could not be imported.
    """

    required_out_section = True
    base_path = "."
    files = ["native-parser-imports.test"] if has_nativeparse else []

    def run_case(self, testcase: DataDrivenTestCase) -> None:
        test_parser_imports(testcase)


def _python_version_for(test_file: str) -> tuple[int, int]:
    """Return the Python version implied by the name of a test data file."""
    for suffix, version in _PYTHON_VERSION_BY_FILE_SUFFIX:
        if test_file.endswith(suffix):
            return version
    return defaults.PYTHON3_VERSION


def test_parser(testcase: DataDrivenTestCase) -> None:
    """Perform a single native parser test case.

    The argument contains the description of the test case. The source is
    parsed with the native parser, and the formatted type ignores, errors and
    the string dump of the tree (in that order) are compared against the
    expected output. If parsing raises CompileError, its messages are compared
    instead.
    """
    options = Options()
    options.hide_error_codes = True
    options.python_version = _python_version_for(testcase.file)

    source = "\n".join(testcase.input)

    # Apply mypy: comments to options.
    comments = get_mypy_comments(source)
    changes, _ = parse_mypy_comments(comments, options)
    options = options.apply_changes(changes)

    # Check if we should skip function bodies (when ignoring errors)
    skip_function_bodies = "# mypy: ignore-errors=True" in source

    try:
        with temp_source(source) as fnam:
            node, errors, type_ignores = native_parse(fnam, options, skip_function_bodies)
            # Use a fixed path so the dump doesn't depend on the temporary file name.
            node.path = "main"
            a = node.str_with_options(options).split("\n")
            a = [format_error(err) for err in errors] + a
            a = [format_ignore(ignore) for ignore in type_ignores] + a
    except CompileError as e:
        a = e.messages
    assert_string_arrays_equal(
        testcase.output, a, f"Invalid parser output ({testcase.file}, line {testcase.line})"
    )


def format_error(err: dict[str, Any]) -> str:
    """Format a parse error dict (keys 'line', 'column', 'message') as one output line."""
    return f"{err['line']}:{err['column']}: error: {err['message']}"


def format_ignore(ignore: tuple[int, list[str]]) -> str:
    """Format a (line, error codes) type-ignore entry as one output line.

    The bracketed code list is omitted when there are no error codes.
    """
    line, codes = ignore
    if not codes:
        return f"ignore: {line}"
    return f"ignore: {line} [{', '.join(codes)}]"


def test_parser_imports(testcase: DataDrivenTestCase) -> None:
    """Perform a single native parser imports test case.

    The argument contains the description of the test case.

    This test outputs only reachable import information.
    """
    options = Options()
    options.hide_error_codes = True
    options.python_version = (3, 10)

    source = "\n".join(testcase.input)

    try:
        with temp_source(source) as fnam:
            # Type ignores are not part of the expected output of this test.
            node, errors, _ = native_parse(fnam, options)

        # Extract and format reachable imports
        a = format_reachable_imports(node)
        a = [format_error(err) for err in errors] + a
    except CompileError as e:
        a = e.messages
    assert_string_arrays_equal(
        testcase.output, a, f"Invalid parser output ({testcase.file}, line {testcase.line})"
    )


def _format_import_module(relative: int, module_id: str) -> str:
    """Return the module part of a 'from ... import' line.

    For relative imports, the module id (if any) is prefixed with one dot per
    level; otherwise the module id is returned unchanged.
    """
    if relative > 0:
        prefix = "." * relative
        return f"{prefix}{module_id}" if module_id else prefix
    return module_id


def format_reachable_imports(node: MypyFile) -> list[str]:
    """Format reachable imports from a MypyFile node.

    Returns a list of strings representing reachable imports with line numbers and flags.
    """
    from mypy.nodes import Import, ImportAll, ImportFrom

    output: list[str] = []

    for imp in node.imports:
        # Only reachable imports are reported.
        if imp.is_unreachable:
            continue

        line_num = imp.line

        # Collect flags. A flag is shown only when the import is not at the
        # top level or is mypy-only; otherwise nothing is appended.
        flags: list[str] = []
        if not imp.is_top_level:
            flags.append("not top_level")
        if imp.is_mypy_only:
            flags.append("mypy_only")

        flags_str = " [" + ", ".join(flags) + "]" if flags else ""

        if isinstance(imp, Import):
            # Format: line: import foo [as bar] [flags]
            # Each imported module gets its own output line.
            for module_id, as_id in imp.ids:
                if as_id:
                    output.append(f"{line_num}: import {module_id} as {as_id}{flags_str}")
                else:
                    output.append(f"{line_num}: import {module_id}{flags_str}")
        elif isinstance(imp, ImportFrom):
            # Format: line: from foo import bar, baz [as b] [flags]
            module = _format_import_module(imp.relative, imp.id)

            # Group all names together
            names_str = ", ".join(
                f"{name} as {as_name}" if as_name else name for name, as_name in imp.names
            )
            output.append(f"{line_num}: from {module} import {names_str}{flags_str}")
        elif isinstance(imp, ImportAll):
            # Format: line: from foo import * [flags]
            module = _format_import_module(imp.relative, imp.id)
            output.append(f"{line_num}: from {module} import *{flags_str}")

    return output


@unittest.skipUnless(has_nativeparse, "nativeparse not available")
class TestNativeParserBinaryFormat(unittest.TestCase):
    """Sanity checks for the bytes produced by parse_to_binary_ast()."""

    def test_trivial_binary_data(self) -> None:
        # A quick sanity check to ensure the serialized data looks as expected. Only covers
        # a few AST nodes.

        def int_enc(n: int) -> int:
            # Expected single-byte form of a small int in the serialized data, as
            # asserted below. Note that int_enc(1) == 22.
            return (n + 10) << 1

        def locs(start_line: int, start_column: int, end_line: int, end_column: int) -> list[int]:
            # Expected location record: the end position is given relative to the start.
            return [
                LOCATION,
                int_enc(start_line),
                int_enc(start_column),
                int_enc(end_line - start_line),
                int_enc(end_column - start_column),
            ]

        with temp_source("print('hello')") as fnam:
            b, _, _, _, _, _ = parse_to_binary_ast(fnam, Options())
            assert list(b) == (
                [LITERAL_INT, 22, nodes.EXPR_STMT, nodes.CALL_EXPR]
                + [nodes.NAME_EXPR, LITERAL_STR]
                + [int_enc(5)]
                + list(b"print")
                + locs(1, 0, 1, 5)
                + [END_TAG, LIST_GEN, 22, nodes.STR_EXPR]
                + [LITERAL_STR, int_enc(5)]
                + list(b"hello")
                + locs(1, 6, 1, 13)
                + [END_TAG]
                # arg_kinds: [ARG_POS]
                + [LIST_INT, 22, int_enc(0)]
                # arg_names: [None]
                + [LIST_GEN, 22, LITERAL_NONE]
                + locs(1, 0, 1, 14)
                + [END_TAG, END_TAG]
            )


@contextlib.contextmanager
def temp_source(text: str) -> Iterator[str]:
    """Write text to 't.py' in a fresh temporary directory and yield the file's path.

    The temporary directory (and the file in it) is removed when the context exits.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = os.path.join(temp_dir, "t.py")
        # Write as UTF-8 explicitly (the default encoding of Python source files)
        # so that non-ASCII test input doesn't depend on the locale encoding.
        with open(temp_path, "w", encoding="utf-8") as f:
            f.write(text)
        yield temp_path