Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions jac/jaclang/compiler/type_system/operations.jac
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ def get_type_of_binary_operation(
}
(magic, rmagic) = BINARY_OPERATOR_MAP[expr.op.name];
return (
evaluator.get_type_of_magic_method_call(left_type, magic)
or evaluator.get_type_of_magic_method_call(right_type, rmagic)
evaluator.get_type_of_magic_method_call(
left_type, magic, [expr.right], expr
)
or evaluator.get_type_of_magic_method_call(
right_type, rmagic, [expr.left], expr
)
or jtypes.UnknownType()
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl TypeEvaluator.validate_call_args(
caller_type = self.get_type_of_expression(expr.target);
# 1. Call to a function.
if isinstance(caller_type, types.FunctionType) {
arg_param_match = self.match_args_to_params(expr, caller_type);
arg_param_match = self.match_args_to_params(expr.params, expr, caller_type);
if not arg_param_match.argument_errors {
self.validate_arg_types(arg_param_match);
}
Expand All @@ -51,7 +51,7 @@ impl TypeEvaluator.validate_call_args(
if isinstance(caller_type, types.OverloadedType) {
for overload in caller_type.overloads {
arg_param_match = self.match_args_to_params(
expr, overload, checking_overload=True
expr.params, expr, overload, checking_overload=True
);
if not arg_param_match.argument_errors {
if self.validate_arg_types(arg_param_match, checking_overload=True) {
Expand Down Expand Up @@ -83,14 +83,18 @@ impl TypeEvaluator.validate_call_args(
if init_method := self._lookup_class_member(caller_type, "__init__") {
init_fn_type = self.get_type_of_symbol(init_method.symbol);
if isinstance(init_fn_type, types.FunctionType) {
arg_param_match = self.match_args_to_params(expr, init_fn_type);
arg_param_match = self.match_args_to_params(
expr.params, expr, init_fn_type
);
if not arg_param_match.argument_errors {
self.validate_arg_types(arg_param_match);
}
}
} elif caller_type.is_data_class() {
init_fn_type = self._create_dataclass_init_method(caller_type);
arg_param_match = self.match_args_to_params(expr, init_fn_type);
arg_param_match = self.match_args_to_params(
expr.params, expr, init_fn_type
);
if not arg_param_match.argument_errors {
self.validate_arg_types(arg_param_match);
}
Expand All @@ -99,8 +103,9 @@ impl TypeEvaluator.validate_call_args(
}
# 5. Call to a callable object (__call__).
if caller_type.is_class_instance() {
# TODO: validate args.
magic_call_ret = self.get_type_of_magic_method_call(caller_type, "__call__");
magic_call_ret = self.get_type_of_magic_method_call(
caller_type, "__call__", expr.params, expr,
);
if magic_call_ret {
return magic_call_ret;
}
Expand All @@ -117,7 +122,8 @@ This logic is based on PEP 3102: https://www.python.org/dev/peps/pep-3102/
"""
impl TypeEvaluator.match_args_to_params(
self: TypeEvaluator,
expr: uni.FuncCall,
arg_nodes: list[uni.Expr | uni.KWPair],
node_for_error: uni.Expr | None,
func_type: types.FunctionType,
checking_overload: bool = False,
) -> MatchArgsToParamsResult {
Expand All @@ -144,7 +150,7 @@ impl TypeEvaluator.match_args_to_params(
# We match positional with | We match named arguments with
# tracked parameter index. | param name lookup.
#
for arg in expr.params {
for arg in arg_nodes {
try {
if isinstance(arg, uni.KWPair) {
# Match parameter based on name lookup.
Expand All @@ -165,9 +171,9 @@ impl TypeEvaluator.match_args_to_params(
if unmatched_params := param_tracker.get_unmatched_required_params() {
argument_errors = True;
names = ", ".join([f"'{p.name}'" for p in unmatched_params]);
if not checking_overload {
if not checking_overload and node_for_error {
self.add_diagnostic(
expr,
node_for_error,
f"Not all required parameters were provided in the function call: {names}",
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ impl TypeEvaluator._get_type_of_expression_core(
return base_type.specialize_generics(type_args);
}
# Regular <expr>[<expr>] case, we need to call __getitem__()
return self.get_type_of_magic_method_call(base_type, "__getitem__")
return self.get_type_of_magic_method_call(
base_type, "__getitem__", [expr.right], expr
)
or types.UnknownType();
}

Expand Down Expand Up @@ -320,7 +322,10 @@ impl TypeEvaluator._get_type_of_symbol(
# iter_type = self.get_type_of_magic_method_call(collection_type, "__iter__");
#
iter_type = self.get_type_of_magic_method_call(
collection_type, "__getitem__"
collection_type,
"__getitem__",
[node_.parent.collection],
node_.parent
);
return iter_type or types.UnknownType();
}
Expand Down Expand Up @@ -399,7 +404,11 @@ impl TypeEvaluator.is_sub_class(

"""Return the effective return type of a magic method call."""
impl TypeEvaluator.get_type_of_magic_method_call(
self: TypeEvaluator, obj_type: TypeBase, method_name: str
self: TypeEvaluator,
obj_type: TypeBase,
method_name: str,
arg_nodes: list[uni.Expr],
node_for_error: uni.Expr | None = None,
) -> TypeBase | None {
if obj_type.category == types.TypeCategory.Class {
# TODO: getTypeOfBoundMember() <-- Implement this if needed, for the simple case
Expand All @@ -410,8 +419,39 @@ impl TypeEvaluator.get_type_of_magic_method_call(
assert isinstance(obj_type, types.ClassType); # <-- To make typecheck happy.
if member := self._lookup_class_member(obj_type, method_name) {
member_ty = self.get_type_of_symbol(member.symbol);
# This list contains all functions including the overloads.
overloaded_fns: list[types.FunctionType] = [];
if isinstance(member_ty, types.FunctionType) {
return member_ty.specialize(obj_type).return_type;
member_ty = member_ty.specialize(obj_type);
overloaded_fns.append(member_ty);
for sym in member.overloads {
overload_ty = self.get_type_of_symbol(sym);
if isinstance(overload_ty, types.FunctionType) {
overload_ty = overload_ty.specialize(obj_type);
overloaded_fns.append(overload_ty);
}
}
# Validate arguments for each overloaded function.
for fn in overloaded_fns {
arg_param_match = self.match_args_to_params(
arg_nodes, node_for_error, fn, checking_overload=True
);
if not arg_param_match.argument_errors {
if self.validate_arg_types(
arg_param_match, checking_overload=True
) {
return fn.return_type or types.UnknownType();
}
}
}
# If we reached here, none of the overloads matched.
if node_for_error {
self.add_diagnostic(
node_for_error,
f'No matching overload found for method "{method_name}" with the given arguments',
);
}
return types.UnknownType();
}
# If we reached here, magic method is not a function.
# 1. recursively check __call__() on the type, TODO
Expand Down
9 changes: 7 additions & 2 deletions jac/jaclang/compiler/type_system/type_evaluator.jac
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,11 @@ class TypeEvaluator {

# TODO: This should take an argument list as parameter.
def get_type_of_magic_method_call(
self: TypeEvaluator, obj_type: TypeBase, method_name: str
self: TypeEvaluator,
obj_type: TypeBase,
method_name: str,
arg_nodes: list[uni.Expr],
node_for_error: uni.Expr | None = None,
) -> TypeBase | None;

def is_sub_class(
Expand Down Expand Up @@ -222,7 +226,8 @@ class TypeEvaluator {

def match_args_to_params(
self: TypeEvaluator,
expr: uni.FuncCall,
arg_nodes: list[uni.Expr | uni.KWPair],
node_for_error: uni.Expr | None,
func_type: types.FunctionType,
checking_overload: bool = False,
) -> MatchArgsToParamsResult;
Expand Down
40 changes: 40 additions & 0 deletions jac/tests/compiler/passes/main/fixtures/overload_test.jac
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import from typing { Any, overload }

obj Foo {

@overload
def do_something(value: int) -> str {}

@overload
def do_something(value: list[int]) -> str {}

@overload
def __add__(other: int) -> bool {}

@overload
def __add__(other: list[int]) -> bool {}

}

with entry {

# @overload
# def __pow__(self, value: int, mod: None = None, /) -> float: ...
#
# return type must be Any as `float | complex` causes too many false-positive errors
# @overload
# def __pow__(self, value: float, mod: None = None, /) -> Any: ...
x: float = 3.14 ** 42; # <-- Ok
y: Any = 3.14 ** 2.71; # <-- Ok


foo: Foo = Foo();
foo.do_something(10); # <-- Ok
foo.do_something([1, 2, 3]); # <-- Ok
foo.do_something("hello"); # <-- Error

foo + 42; # <-- Ok
foo + [1, 2, 3]; # <-- Ok
foo + "hello"; # <-- Error

}
21 changes: 21 additions & 0 deletions jac/tests/compiler/passes/main/test_checker_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,3 +916,24 @@ def test_dict_pop(fixture_path: Callable[[str], str]) -> None:
""",
program.errors_had[2].pretty_print(),
)


def test_overload(fixture_path: Callable[[str], str]) -> None:
program = JacProgram()
mod = program.compile(fixture_path("overload_test.jac"))
TypeCheckPass(ir_in=mod, prog=program)
assert len(program.errors_had) == 2
_assert_error_pretty_found(
"""
foo.do_something("hello"); # <-- Error
^^^^^^^^^^^^^^^^^^^^^^^^^
""",
program.errors_had[0].pretty_print(),
)
_assert_error_pretty_found(
"""
foo + "hello"; # <-- Error
^^^^^^^^^^^^^^^
""",
program.errors_had[1].pretty_print(),
)
Loading