Skip to content

Commit 38c99fb

Browse files
committed
feat: Add support for operator overloading and magic methods in type system
1 parent 65eb910 commit 38c99fb

File tree

7 files changed

+149
-18
lines changed

7 files changed

+149
-18
lines changed

jac/jaclang/compiler/type_system/operations.jac

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,12 @@ def get_type_of_binary_operation(
6767
}
6868
(magic, rmagic) = BINARY_OPERATOR_MAP[expr.op.name];
6969
return (
70-
evaluator.get_type_of_magic_method_call(left_type, magic)
71-
or evaluator.get_type_of_magic_method_call(right_type, rmagic)
70+
evaluator.get_type_of_magic_method_call(
71+
left_type, magic, [expr.right], expr
72+
)
73+
or evaluator.get_type_of_magic_method_call(
74+
right_type, rmagic, [expr.left], expr
75+
)
7276
or jtypes.UnknownType()
7377
);
7478
}

jac/jaclang/compiler/type_system/type_evaluator.impl/parameter_type_check.impl.jac

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ impl TypeEvaluator.validate_call_args(
4141
caller_type = self.get_type_of_expression(expr.target);
4242
# 1. Call to a function.
4343
if isinstance(caller_type, types.FunctionType) {
44-
arg_param_match = self.match_args_to_params(expr, caller_type);
44+
arg_param_match = self.match_args_to_params(expr.params, expr, caller_type);
4545
if not arg_param_match.argument_errors {
4646
self.validate_arg_types(arg_param_match);
4747
}
@@ -51,7 +51,7 @@ impl TypeEvaluator.validate_call_args(
5151
if isinstance(caller_type, types.OverloadedType) {
5252
for overload in caller_type.overloads {
5353
arg_param_match = self.match_args_to_params(
54-
expr, overload, checking_overload=True
54+
expr.params, expr, overload, checking_overload=True
5555
);
5656
if not arg_param_match.argument_errors {
5757
if self.validate_arg_types(arg_param_match, checking_overload=True) {
@@ -83,14 +83,18 @@ impl TypeEvaluator.validate_call_args(
8383
if init_method := self._lookup_class_member(caller_type, "__init__") {
8484
init_fn_type = self.get_type_of_symbol(init_method.symbol);
8585
if isinstance(init_fn_type, types.FunctionType) {
86-
arg_param_match = self.match_args_to_params(expr, init_fn_type);
86+
arg_param_match = self.match_args_to_params(
87+
expr.params, expr, init_fn_type
88+
);
8789
if not arg_param_match.argument_errors {
8890
self.validate_arg_types(arg_param_match);
8991
}
9092
}
9193
} elif caller_type.is_data_class() {
9294
init_fn_type = self._create_dataclass_init_method(caller_type);
93-
arg_param_match = self.match_args_to_params(expr, init_fn_type);
95+
arg_param_match = self.match_args_to_params(
96+
expr.params, expr, init_fn_type
97+
);
9498
if not arg_param_match.argument_errors {
9599
self.validate_arg_types(arg_param_match);
96100
}
@@ -99,8 +103,9 @@ impl TypeEvaluator.validate_call_args(
99103
}
100104
# 5. Call to a callable object (__call__).
101105
if caller_type.is_class_instance() {
102-
# TODO: validate args.
103-
magic_call_ret = self.get_type_of_magic_method_call(caller_type, "__call__");
106+
magic_call_ret = self.get_type_of_magic_method_call(
107+
caller_type, "__call__", expr.params, expr,
108+
);
104109
if magic_call_ret {
105110
return magic_call_ret;
106111
}
@@ -117,7 +122,8 @@ This logic is based on PEP 3102: https://www.python.org/dev/peps/pep-3102/
117122
"""
118123
impl TypeEvaluator.match_args_to_params(
119124
self: TypeEvaluator,
120-
expr: uni.FuncCall,
125+
arg_nodes: list[uni.Expr | uni.KWPair],
126+
node_for_error: uni.Expr | None,
121127
func_type: types.FunctionType,
122128
checking_overload: bool = False,
123129
) -> MatchArgsToParamsResult {
@@ -144,7 +150,7 @@ impl TypeEvaluator.match_args_to_params(
144150
# We match positional with | We match named arguments with
145151
# tracked parameter index. | param name lookup.
146152
#
147-
for arg in expr.params {
153+
for arg in arg_nodes {
148154
try {
149155
if isinstance(arg, uni.KWPair) {
150156
# Match parameter based on name lookup.
@@ -165,9 +171,9 @@ impl TypeEvaluator.match_args_to_params(
165171
if unmatched_params := param_tracker.get_unmatched_required_params() {
166172
argument_errors = True;
167173
names = ", ".join([f"'{p.name}'" for p in unmatched_params]);
168-
if not checking_overload {
174+
if not checking_overload and node_for_error {
169175
self.add_diagnostic(
170-
expr,
176+
node_for_error,
171177
f"Not all required parameters were provided in the function call: {names}",
172178
);
173179
}

jac/jaclang/compiler/type_system/type_evaluator.impl/type_evaluator.impl.jac

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ impl TypeEvaluator._get_type_of_expression_core(
159159
return base_type.specialize_generics(type_args);
160160
}
161161
# Regular <expr>[<expr>] case, we need to call __getitem__()
162-
return self.get_type_of_magic_method_call(base_type, "__getitem__")
162+
return self.get_type_of_magic_method_call(
163+
base_type, "__getitem__", [expr.right], expr
164+
)
163165
or types.UnknownType();
164166
}
165167

@@ -320,7 +322,10 @@ impl TypeEvaluator._get_type_of_symbol(
320322
# iter_type = self.get_type_of_magic_method_call(collection_type, "__iter__");
321323
#
322324
iter_type = self.get_type_of_magic_method_call(
323-
collection_type, "__getitem__"
325+
collection_type,
326+
"__getitem__",
327+
[node_.parent.collection],
328+
node_.parent
324329
);
325330
return iter_type or types.UnknownType();
326331
}
@@ -399,7 +404,11 @@ impl TypeEvaluator.is_sub_class(
399404

400405
"""Return the effective return type of a magic method call."""
401406
impl TypeEvaluator.get_type_of_magic_method_call(
402-
self: TypeEvaluator, obj_type: TypeBase, method_name: str
407+
self: TypeEvaluator,
408+
obj_type: TypeBase,
409+
method_name: str,
410+
arg_nodes: list[uni.Expr],
411+
node_for_error: uni.Expr | None = None,
403412
) -> TypeBase | None {
404413
if obj_type.category == types.TypeCategory.Class {
405414
# TODO: getTypeOfBoundMember() <-- Implement this if needed, for the simple case
@@ -410,8 +419,39 @@ impl TypeEvaluator.get_type_of_magic_method_call(
410419
assert isinstance(obj_type, types.ClassType); # <-- To make typecheck happy.
411420
if member := self._lookup_class_member(obj_type, method_name) {
412421
member_ty = self.get_type_of_symbol(member.symbol);
422+
# This list contains all functions including the overloads.
423+
overloaded_fns: list[types.FunctionType] = [];
413424
if isinstance(member_ty, types.FunctionType) {
414-
return member_ty.specialize(obj_type).return_type;
425+
member_ty = member_ty.specialize(obj_type);
426+
overloaded_fns.append(member_ty);
427+
for sym in member.overloads {
428+
overload_ty = self.get_type_of_symbol(sym);
429+
if isinstance(overload_ty, types.FunctionType) {
430+
overload_ty = overload_ty.specialize(obj_type);
431+
overloaded_fns.append(overload_ty);
432+
}
433+
}
434+
# Validate arguments for each overloaded function.
435+
for fn in overloaded_fns {
436+
arg_param_match = self.match_args_to_params(
437+
arg_nodes, node_for_error, fn, checking_overload=True
438+
);
439+
if not arg_param_match.argument_errors {
440+
if self.validate_arg_types(
441+
arg_param_match, checking_overload=True
442+
) {
443+
return fn.return_type or types.UnknownType();
444+
}
445+
}
446+
}
447+
# If we reached here, none of the overloads matched.
448+
if node_for_error {
449+
self.add_diagnostic(
450+
node_for_error,
451+
f'No matching overload found for method "{method_name}" with the given arguments',
452+
);
453+
}
454+
return types.UnknownType();
415455
}
416456
# If we reached here, magic method is not a function.
417457
# 1. recursively check __call__() on the type, TODO

jac/jaclang/compiler/type_system/type_evaluator.jac

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ class TypeEvaluator {
165165

166166
# TODO: This should take an argument list as parameter.
167167
def get_type_of_magic_method_call(
168-
self: TypeEvaluator, obj_type: TypeBase, method_name: str
168+
self: TypeEvaluator,
169+
obj_type: TypeBase,
170+
method_name: str,
171+
arg_nodes: list[uni.Expr],
172+
node_for_error: uni.Expr | None = None,
169173
) -> TypeBase | None;
170174

171175
def is_sub_class(
@@ -222,7 +226,8 @@ class TypeEvaluator {
222226

223227
def match_args_to_params(
224228
self: TypeEvaluator,
225-
expr: uni.FuncCall,
229+
arg_nodes: list[uni.Expr | uni.KWPair],
230+
node_for_error: uni.Expr | None,
226231
func_type: types.FunctionType,
227232
checking_overload: bool = False,
228233
) -> MatchArgsToParamsResult;
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import from typing { Any, overload }
2+
3+
obj Foo {
4+
5+
@overload
6+
def do_something(value: int) -> str {}
7+
8+
@overload
9+
def do_something(value: list[int]) -> str {}
10+
11+
def do_something(value: Any) -> str {
12+
return "Value is: " + (value |> str);
13+
}
14+
}
15+
16+
with entry {
17+
18+
# @overload
19+
# def __pow__(self, value: int, mod: None = None, /) -> float: ...
20+
#
21+
# return type must be Any as `float | complex` causes too many false-positive errors
22+
# @overload
23+
# def __pow__(self, value: float, mod: None = None, /) -> Any: ...
24+
x: float = 3.14 ** 42; # <-- Ok
25+
y: Any = 3.14 ** 2.71; # <-- Ok
26+
27+
28+
foo: Foo = Foo();
29+
foo.do_something(10); # <-- Ok
30+
foo.do_something([1, 2, 3]); # <-- Ok
31+
foo.do_something("hello"); # <-- Error
32+
33+
}

jac/tests/compiler/passes/main/test_checker_pass.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,3 +916,17 @@ def test_dict_pop(fixture_path: Callable[[str], str]) -> None:
916916
""",
917917
program.errors_had[2].pretty_print(),
918918
)
919+
920+
921+
def test_overload(fixture_path: Callable[[str], str]) -> None:
922+
program = JacProgram()
923+
mod = program.compile(fixture_path("overload_test.jac"))
924+
TypeCheckPass(ir_in=mod, prog=program)
925+
assert len(program.errors_had) == 1
926+
_assert_error_pretty_found(
927+
"""
928+
foo.do_something("hello"); # <-- Error
929+
^^^^^^^^^^^^^^^^^^^^^^^^^
930+
""",
931+
program.errors_had[0].pretty_print(),
932+
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import from typing { Any, overload }
2+
3+
obj Foo {
4+
5+
@overload
6+
def do_something(value: int) -> str {}
7+
8+
@overload
9+
def do_something(value: list[int]) -> str {}
10+
}
11+
12+
with entry {
13+
14+
# @overload
15+
# def __pow__(self, value: int, mod: None = None, /) -> float: ...
16+
#
17+
# return type must be Any as `float | complex` causes too many false-positive errors
18+
# @overload
19+
# def __pow__(self, value: float, mod: None = None, /) -> Any: ...
20+
x: float = 3.14 ** 42; # <-- Ok
21+
y: Any = 3.14 ** 2.71; # <-- Ok
22+
23+
24+
foo: Foo = Foo();
25+
foo.do_something(10); # <-- Ok
26+
foo.do_something([1, 2, 3]); # <-- Ok
27+
foo.do_something("hello"); # <-- Error
28+
29+
}

0 commit comments

Comments
 (0)