Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions hamilton/htypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@
from collections.abc import Iterable
from typing import Any, Literal, Protocol, TypeVar, Union

import typing_extensions
import typing_inspect

from hamilton.registry import COLUMN_TYPE, DF_TYPE_AND_COLUMN_TYPES

_TYPE_ALIAS_TYPES = tuple(
t
for t in (
getattr(typing, "TypeAliasType", None),
getattr(typing_extensions, "TypeAliasType", None),
)
if t is not None
)

BASE_ARGS_FOR_GENERICS = (typing.T,)


Expand Down Expand Up @@ -336,6 +346,8 @@ def check_input_type(node_type: type, input_value: Any) -> bool:
:param input_value: Value to check.
:return: True if the input value is of the correct type, False otherwise.
"""
if _TYPE_ALIAS_TYPES and isinstance(node_type, _TYPE_ALIAS_TYPES):
return check_input_type(node_type.__value__, input_value)
if node_type == Any:
return True
# In the case of dict[str, Any] (or equivalent) in python 3.9 +
Expand All @@ -362,17 +374,21 @@ def check_input_type(node_type: type, input_value: Any) -> bool:
node_type
):
return True
# iterable (set, dict) is super class over sequence (list, tuple)
elif (
typing_inspect.is_generic_type(node_type)
and typing_inspect.get_origin(node_type)
in (list, tuple, typing_inspect.get_origin(typing.Sequence))
and isinstance(input_value, (list, tuple, typing_inspect.get_origin(typing.Sequence)))
):
if typing_inspect.get_args(node_type):
elif typing.get_origin(node_type) in (
list,
tuple,
typing_inspect.get_origin(typing.Sequence),
) and isinstance(input_value, (list, tuple, typing_inspect.get_origin(typing.Sequence))):
args = typing.get_args(node_type)
if args:
origin = typing.get_origin(node_type)
if origin is tuple and not (len(args) == 2 and args[1] is Ellipsis):
if not isinstance(input_value, tuple) or len(input_value) != len(args):
return False
return all(check_input_type(t, v) for t, v in zip(args, input_value, strict=True))
# check first value in sequence -- if the type is specified.
for i in input_value: # this handles empty input case, e.g. [] or (), set()
return check_input_type(typing_inspect.get_args(node_type)[0], i)
return check_input_type(args[0], i)
return True
elif (
typing_inspect.is_generic_type(node_type)
Expand Down
35 changes: 35 additions & 0 deletions tests/test_type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

import collections
import sys
import typing
from typing import Annotated, Any, Union

Expand Down Expand Up @@ -313,6 +314,40 @@ def test_check_input_types_subscripted_generics_list_Any():
assert actual is True


def test_check_input_type_parameterized_tuple_match():
assert htypes.check_input_type(tuple[float, float, float, float], (1.0, 2.0, 3.0, 4.0))
assert htypes.check_input_type(tuple[int, str], (1, "a"))


def test_check_input_type_parameterized_tuple_legacy_form():
assert htypes.check_input_type(tuple[float, float], (1.0, 2.0))


def test_check_input_type_parameterized_tuple_wrong_length():
assert htypes.check_input_type(tuple[int, str], (1,)) is False
assert htypes.check_input_type(tuple[int, str], (1, "a", 2)) is False


def test_check_input_type_parameterized_tuple_wrong_element_type():
assert htypes.check_input_type(tuple[int, str], (1, 2)) is False


def test_check_input_type_variable_length_tuple():
assert htypes.check_input_type(tuple[int, ...], (1, 2, 3))
assert htypes.check_input_type(tuple[int, ...], ())


@pytest.mark.skipif(
sys.version_info < (3, 12), reason="PEP 695 `type X = ...` syntax requires Python 3.12+"
)
def test_check_input_type_pep695_type_alias():
namespace: dict = {}
exec("type Bbox = tuple[float, float, float, float]", namespace)
Bbox = namespace["Bbox"]
assert htypes.check_input_type(Bbox, (1.0, 2.0, 3.0, 4.0))
assert htypes.check_input_type(Bbox, (1, 2, 3)) is False


def test_check_instance_with_non_generic_type():
assert check_instance(5, int)
assert not check_instance("5", int)
Expand Down
Loading