From 9bb7b0dc96c7a189a2e3b214f884d233dffffe5a Mon Sep 17 00:00:00 2001 From: Krzosa Karol Date: Thu, 29 Sep 2022 14:50:48 +0200 Subject: [PATCH] Operator overload found! --- core_ast.cpp | 1 + core_typechecking.cpp | 69 ++++++++++++++++++++++++++------ examples/function_overloading.kl | 26 ++++-------- modules/math.kl | 3 -- 4 files changed, 65 insertions(+), 34 deletions(-) diff --git a/core_ast.cpp b/core_ast.cpp index 1a56955..d9e5847 100644 --- a/core_ast.cpp +++ b/core_ast.cpp @@ -79,6 +79,7 @@ struct Ast{ struct Ast_Type; struct Ast_Expr:Ast{ Ast_Type *resolved_type; + Ast_Decl *resolved_operator_overload; union{ Ast_Type *index_original_type; Ast_Type *cast_after_type; diff --git a/core_typechecking.cpp b/core_typechecking.cpp index 1c7a676..2e40be9 100644 --- a/core_typechecking.cpp +++ b/core_typechecking.cpp @@ -252,16 +252,25 @@ type_array(Ast_Type *base, S64 size){ } inline U64 -calculate_hash_for_arguments(Array args){ +calculate_hash_for_arguments(Ast_Type *a, Ast_Type *b){ U64 result = 13; - For(args) result = hash_mix(result, hash_ptr(it)); + result = hash_mix(result, hash_ptr(a)); + result = hash_mix(result, hash_ptr(b)); + return result; +} + +inline U64 +calculate_hash_for_arguments(Ast_Type *a){ + U64 result = 13; + result = hash_mix(result, hash_ptr(a)); return result; } function Ast_Type * type_lambda(Ast *ast, Array return_vals, Array args){ Ast_Type *ret = type_try_tupling(return_vals, ast); - U64 hash_without_ret = calculate_hash_for_arguments(args); + U64 hash_without_ret = 13; + For(args) hash_without_ret = hash_mix(hash_without_ret, hash_ptr(it)); U64 hash = hash_mix(hash_ptr(ret), hash_without_ret); Ast_Type *result = (Ast_Type *)map_get(&pctx->type_map, hash); @@ -844,8 +853,38 @@ resolve_name(Ast_Scope *scope, Token *pos, Intern_String name, Search_Flag searc return decl; } +function Intern_String +map_operator_to_intern(Token_Kind op){ + switch(op){ + case TK_Add: return op_add; break; + case TK_Mul: return op_mul; break; + case TK_Div: return op_div; break; + case TK_Sub: return op_sub; break; + case TK_And: return op_and; break; + case TK_BitAnd: return op_bitand; break; + case TK_Or: return op_or; break; + case TK_BitOr: return op_bitor; break; + case TK_BitXor: return op_xor; break; + case TK_Equals: return op_equals; break; + case TK_NotEquals: return op_not_equals; break; + case TK_LesserThenOrEqual: return op_lesser_then_or_equal; break; + case TK_GreaterThenOrEqual: return op_greater_then_or_equal; break; + case TK_LesserThen: return op_lesser_then; break; + case TK_GreaterThen: return op_greater_then; break; + case TK_LeftShift: return op_left_shift; break; + case TK_RightShift: return op_right_shift; break; + case TK_Not: return op_not; break; + case TK_Neg: return op_neg; break; + case TK_Decrement: return op_decrement; break; + case TK_Increment: return op_increment; break; + default: return {}; + } +} + function Ast_Decl * -resolve_operator_overload(Ast_Scope *scope, Token *pos, Intern_String name, U64 argument_hash){ +resolve_operator_overload(Ast_Scope *scope, Token *pos, Token_Kind op, U64 argument_hash){ + Intern_String name = map_operator_to_intern(op); + if(name.str == 0) return 0; // Search for all possible candidates Scratch scratch; @@ -853,10 +892,6 @@ resolve_operator_overload(Ast_Scope *scope, Token *pos, Intern_String name, U64 search.exit_on_find = false; scope_search(&search); - if(search.results.len == 0){ - compiler_error(pos, "Failed to find matching operator overload for [%s] No operator of this kind is defined", name.str); - } - // Resolve them until we hit a match Array matching_ops = {scratch}; For(search.results){ @@ -866,15 +901,15 @@ resolve_operator_overload(Ast_Scope *scope, Token *pos, Intern_String name, U64 } } - if(matching_ops.len == 0){ - compiler_error(pos, "Failed to find matching operator overload for [%s]", name.str); - } - if(matching_ops.len > 1){ compiler_error(pos, "Found multiple matching operator overloads for [%s]", name.str); } - return matching_ops.data[0]; + if(matching_ops.len == 1){ + return matching_ops.data[0]; + } + + return 0; } function void @@ -1579,6 +1614,14 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){ return operand_const_rvalue(value); } else { + + + U64 hash = calculate_hash_for_arguments(left.type, right.type); + Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash); + if(operator_overload){ + __debugbreak(); + } + try_propagating_resolved_type_to_untyped_literals(node->left, value.type); try_propagating_resolved_type_to_untyped_literals(node->right, value.type); return operand_rvalue(value.type); diff --git a/examples/function_overloading.kl b/examples/function_overloading.kl index 1ac997f..0463fc4 100644 --- a/examples/function_overloading.kl +++ b/examples/function_overloading.kl @@ -1,24 +1,14 @@ -/* -@todo: Add function overloading -Current plan: - * allow insert_into_scope to insert multiple lambdas - * change resolve_name and search_for_decl to something - that can seek multiple lambda declarations - resolve them and return a match to hash or type - * change the order of lambda call resolution, probably would have to - hash the arguments first to match the lambda call +Vec3 :: struct;; x: F32; y: F32; z: F32 +"+" :: (a: Vec3, b: Vec3): Vec3 ;; return Vec3{a.x+b.x, a.y+b.y, a.z+b.z} -*/ -// Add :: (a: int, b: int): int ;; return a + b -// Add :: (a: F32, b: F32): F32 ;; return a + b main :: (): int - a: F32 = 3 - b: F32 = 2 - c: int = 4 - d: int = 10 - // e := Add(a, b) - // f := Add(c, d) + a := Vec3{1,1,1} + b := Vec3{2,3,4} + c := a + b + Assert(a.x == 3) + Assert(a.y == 4) + Assert(a.z == 5) return 0 \ No newline at end of file diff --git a/modules/math.kl b/modules/math.kl index 00591b4..afd5619 100644 --- a/modules/math.kl +++ b/modules/math.kl @@ -6,9 +6,6 @@ Vec2I :: struct;; x: S64; y: S64 Vec2 :: struct;; x: F32; y: F32 Vec3 :: struct;; x: F32; y: F32; z: F32 -"+" :: (a: Vec3, b: Vec3): Vec3 - return Vec3_Add(a,b) - Vec3_Cross :: (a: Vec3, b: Vec3): Vec3 result := Vec3{ a.y * b.z - a.z * b.y,