diff --git a/mypy/checker.py b/mypy/checker.py index 59571954e0f7..620db7078585 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -424,6 +424,9 @@ class TypeChecker(NodeVisitor[None], TypeCheckerSharedApi, SplittingVisitor): # Short names of Var nodes whose previous inferred type has been widened via assignment. # NOTE: The names might not be unique, they are only for debugging purposes. widened_vars: list[str] + # Global variables widened inside a function body, to be propagated to + # the module-level binder after the function is type checked (with --allow-redefinition-new). + _globals_widened_in_func: list[tuple[NameExpr, Type]] globals: SymbolTable modules: dict[str, MypyFile] # Nodes that couldn't be checked because some types weren't available. We'll run @@ -488,6 +491,7 @@ def __init__( self.tscope = Scope() self.scope = CheckerScope(tree) self.binder = ConditionalTypeBinder(options) + self.globals_binder = self.binder self.globals = tree.names self.return_types = [] self.dynamic_funcs = [] @@ -496,6 +500,7 @@ def __init__( self.var_decl_frames = {} self.deferred_nodes = [] self.widened_vars = [] + self._globals_widened_in_func = [] self._type_maps = [{}] self.module_refs = set() self.pass_num = 0 @@ -1616,6 +1621,16 @@ def check_func_def( self.return_types.pop() + # Propagate any global variable widenings directly to the + # module-level binder (skipping any intermediate class binders). + if self._globals_widened_in_func: + for lvalue, widened_type in self._globals_widened_in_func: + self.globals_binder.put(lvalue, widened_type) + lit = literal_hash(lvalue) + if lit is not None: + self.globals_binder.declarations[lit] = widened_type + self._globals_widened_in_func = [] + self.binder = old_binder def check_funcdef_item( @@ -4897,6 +4912,7 @@ def check_simple_assignment( not self.refers_to_different_scope(lvalue) and not isinstance(inferred.type, PartialType) and not is_proper_subtype(new_inferred, inferred.type) + and self.can_widen_in_scope(lvalue, inferred.type) ): lvalue_type = make_simplified_union([inferred.type, new_inferred]) # Widen the type to the union of original and new type. @@ -4904,6 +4920,10 @@ def check_simple_assignment( # Skip index variables as they are reset on each loop. self.widened_vars.append(inferred.name) self.set_inferred_type(inferred, lvalue, lvalue_type) + if lvalue.kind == GDEF and self.scope.top_level_function() is not None: + # Widening a global inside a function -- record for + # propagation to the module-level binder afterwards. + self._globals_widened_in_func.append((lvalue, lvalue_type)) self.binder.put(lvalue, rvalue_type) # TODO: A bit hacky, maybe add a binder method that does put and # updates declaration? @@ -4932,7 +4952,7 @@ def refers_to_different_scope(self, name: NameExpr) -> bool: if name.kind == LDEF: # TODO: Consider reference to outer function as a different scope? return False - elif self.scope.top_level_function() is not None: + elif self.scope.top_level_function() is not None and name.kind != GDEF: # A non-local reference from within a function must refer to a different scope return True elif name.kind == GDEF and name.fullname.rpartition(".")[0] != self.tree.fullname: @@ -4940,6 +4960,22 @@ def refers_to_different_scope(self, name: NameExpr) -> bool: return True return False + def can_widen_in_scope(self, name: NameExpr, orig_type: Type) -> bool: + """Can a variable type be widened via assignment in the current scope? + + Globals can only be widened from within a function if the original type + is None (backward compat with partial type handling of `x = None`). + + See test cases testNewRedefineGlobalVariableNoneInit[1-4], for example. + """ + if ( + name.kind == GDEF + and self.scope.top_level_function() is not None + and not isinstance(get_proper_type(orig_type), NoneType) + ): + return False + return True + def check_member_assignment( self, lvalue: MemberExpr, @@ -8337,7 +8373,16 @@ def visit_nonlocal_decl(self, o: NonlocalDecl, /) -> None: return None def visit_global_decl(self, o: GlobalDecl, /) -> None: - return None + if self.options.allow_redefinition_new: + # Add names to binder, since their types could be widened + for name in o.names: + sym = self.globals.get(name) + if sym and isinstance(sym.node, Var) and sym.node.type is not None: + n = NameExpr(name) + n.node = sym.node + n.kind = GDEF + n.fullname = sym.node.fullname + self.binder.assign_type(n, sym.node.type, sym.node.type) class TypeCheckerAsSemanticAnalyzer(SemanticAnalyzerCoreInterface): diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index c08908eb8a3a..3550d6967ead 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -4350,20 +4350,22 @@ def g() -> None: x = "" # E: Incompatible types in assignment (expression has type "str", variable has type "int | None") reveal_type(x) # N: Revealed type is "builtins.int | None" +reveal_type(x) # N: Revealed type is "builtins.int | None" + [case testGlobalVariableNoneInitMultipleFuncsRedefine] # flags: --allow-redefinition-new --local-partial-types -# Widening this is intentionally prohibited (for now). +# Widening None is supported, as a special case x = None def f() -> None: global x reveal_type(x) # N: Revealed type is "None" - x = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "None") - reveal_type(x) # N: Revealed type is "None" + x = 1 + reveal_type(x) # N: Revealed type is "builtins.int" def g() -> None: global x - reveal_type(x) # N: Revealed type is "None" - x = "" # E: Incompatible types in assignment (expression has type "str", variable has type "None") - reveal_type(x) # N: Revealed type is "None" + reveal_type(x) # N: Revealed type is "None | builtins.int" + x = "" # E: Incompatible types in assignment (expression has type "str", variable has type "int | None") + reveal_type(x) # N: Revealed type is "None | builtins.int" diff --git a/test-data/unit/check-redefine2.test b/test-data/unit/check-redefine2.test index d8a7ccbfc4a4..ea357b4628bb 100644 --- a/test-data/unit/check-redefine2.test +++ b/test-data/unit/check-redefine2.test @@ -168,17 +168,76 @@ def f2() -> None: reveal_type(x) # N: Revealed type is "builtins.int | builtins.str" -[case testNewRedefineGlobalVariableNoneInit] +[case testNewRedefineGlobalVariableNoneInit1] # flags: --allow-redefinition-new --local-partial-types x = None def f() -> None: global x reveal_type(x) # N: Revealed type is "None" - x = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "None") - reveal_type(x) # N: Revealed type is "None" + x = 1 + reveal_type(x) # N: Revealed type is "builtins.int" + +reveal_type(x) # N: Revealed type is "None | builtins.int" + +[case testNewRedefineGlobalVariableNoneInit2] +# flags: --allow-redefinition-new --local-partial-types +x = None + +def deco(f): return f + +@deco +def f() -> None: + global x + if int(): + x = 1 + reveal_type(x) # N: Revealed type is "None | builtins.int" + +reveal_type(x) # N: Revealed type is "None | builtins.int" + +[case testNewRedefineGlobalVariableNoneInit3] +# flags: --allow-redefinition-new --local-partial-types +from typing import overload + +x = None + +class C: + @overload + def f(self) -> None: ... + @overload + def f(self, n: int) -> None: ... + + def f(self, n: int = 0) -> None: + global x + x = n + a = [x] + x = None + reveal_type(x) # N: Revealed type is "None" + +reveal_type(x) # N: Revealed type is "None | builtins.int" + +[case testNewRedefineGlobalVariableNoneInit4] +# flags: --allow-redefinition-new --local-partial-types +x = None + +def f() -> None: + def nested() -> None: + global x + x = 1 -reveal_type(x) # N: Revealed type is "None" + nested() + +reveal_type(x) # N: Revealed type is "None | builtins.int" + +[case testNewRedefineGlobalVariableWithUnsupportedType] +# flags: --allow-redefinition-new --local-partial-types +x = 1 + +def f() -> None: + global x + x = "a" # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +reveal_type(x) # N: Revealed type is "builtins.int" [case testNewRedefineParameterTypes] # flags: --allow-redefinition-new --local-partial-types @@ -641,6 +700,7 @@ def f5() -> None: continue x = "" reveal_type(x) # N: Revealed type is "builtins.str" + [case testNewRedefineWhileLoopSimple] # flags: --allow-redefinition-new --local-partial-types def f() -> None: diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index 671a20b66779..9a9e4fffe715 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -6743,7 +6743,7 @@ class D: class D: y: int [out] -b2.py:7: error: Argument 1 to "f" has incompatible type "D"; expected "P2" (diff) +b2.py:7: error: Argument 1 to "f" has incompatible type "D"; expected "P2" b1.py:7: error: Argument 1 to "f" has incompatible type "D"; expected "P1" [case testProtocolsInvalidateByRemovingBase] @@ -11682,3 +11682,83 @@ def f() -> str: ... [out] == main:4: error: Incompatible return value type (got "str | None", expected "int | None") + +[case testGlobalNoneWidenedInFuncWithRedefinition1] +import m +reveal_type(m.x) + +[file m.py] +# mypy: allow-redefinition-new +import m2 + +x = None + +def foo() -> None: + global x + x = m2.bar() + +[file m2.py] +def bar() -> int: return 0 + +[file m2.py.2] +def bar() -> str: return "a" + +[out] +main:2: note: Revealed type is "None | builtins.int" +== +main:2: note: Revealed type is "None | builtins.str" + +[case testGlobalNoneWidenedInFuncWithRedefinition2] +import m +reveal_type(m.x) + +[file m.py] +# mypy: allow-redefinition-new +import m2 + +x = None + +def deco(f): return f + +class C: + @deco + def foo(self) -> None: + global x + x = m2.bar() + +[file m2.py] +def bar() -> int: return 0 + +[file m2.py.2] +def bar() -> str: return "a" + +[out] +main:2: note: Revealed type is "None | builtins.int" +== +main:2: note: Revealed type is "None | builtins.str" + +[case testGlobalNoneWidenedInFuncWithRedefinition3] +import m +reveal_type(m.x) + +[file m.py] +# mypy: allow-redefinition-new +import m2 + +x = None + +def foo() -> None: + def nested(self) -> None: + global x + x = m2.bar() + +[file m2.py] +def bar() -> int: return 0 + +[file m2.py.2] +def bar() -> str: return "a" + +[out] +main:2: note: Revealed type is "None | builtins.int" +== +main:2: note: Revealed type is "None | builtins.str"