diff --git a/core_typechecking.cpp b/core_typechecking.cpp index d15594a..17fc196 100644 --- a/core_typechecking.cpp +++ b/core_typechecking.cpp @@ -792,19 +792,20 @@ function void scope_search(Scope_Search *search){ // Climb the scope tree and search each scope in module and also // search in implicitly imported scopes - Ast_Scope *scope = search->scope; - for(Ast_Scope *it = scope; it; it=it->parent_scope){ - inside_scope_search(search, it, 0); + For_Named(search->scopes, scope){ + for(Ast_Scope *it = scope; it; it=it->parent_scope){ + inside_scope_search(search, it, 0); - if(search->exit_on_find && search->results.len){ - return; - } + if(search->exit_on_find && search->results.len){ + return; + } - if(search->search_only_current_scope){ - return; + if(search->search_only_current_scope){ + return; + } } + inside_scope_search(search, scope->module, 0); } - inside_scope_search(search, scope->module, 0); } function Scope_Search @@ -813,7 +814,8 @@ make_scope_search(Arena *arena, Ast_Scope *scope, Intern_String name){ result.results.allocator = arena; result.name = name; result.scope_visit_id = ++pctx->scope_visit_id; - result.scope = scope; + result.scopes = {arena}; + result.scopes.add(scope); result.exit_on_find = true; return result; } @@ -909,13 +911,16 @@ map_operator_intern_to_identifier_name(Intern_String op){ } function Ast_Decl * -resolve_operator_overload(Ast_Scope *scope, Token *pos, Token_Kind op, U64 argument_hash){ +resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Token *pos, Token_Kind op, U64 argument_hash){ Intern_String name = map_operator_to_intern(op); if(name.str == 0) return 0; - // Search for all possible candidates + // Search for all possible candidates in three scopes + // The current module, left type definition module, right type definition module Scratch scratch; Scope_Search search = make_scope_search(scratch, scope, name); + if( left->ast && left->ast->parent_scope) search.scopes.add(left->ast->parent_scope); + if(right && right->ast && right->ast->parent_scope) search.scopes.add(right->ast->parent_scope); search.exit_on_find = false; scope_search(&search); @@ -1639,7 +1644,7 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){ // Try finding a operator overload if(!is_const){ U64 hash = calculate_hash_for_arguments(left.type, right.type); - Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash); + Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, left.type, right.type, node->pos, node->op, hash); if(operator_overload){ proceed_to_default_operator_handler = false; if(operator_overload->lambda->ret.len != 1){ @@ -1706,7 +1711,7 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){ B32 proceed_to_default_operator_handler = true; if(!value.is_const){ U64 hash = calculate_hash_for_arguments(value.type); - Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash); + Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, value.type, 0, node->pos, node->op, hash); if(operator_overload){ proceed_to_default_operator_handler = false; if(operator_overload->lambda->ret.len != 1){ diff --git a/core_typechecking.h b/core_typechecking.h index fbbd7ff..2b6a972 100644 --- a/core_typechecking.h +++ b/core_typechecking.h @@ -10,7 +10,7 @@ struct Operand{ struct Scope_Search { Array results; Intern_String name; - Ast_Scope *scope; + Array scopes; bool exit_on_find; bool search_only_current_scope; diff --git a/examples/operator_overloading.kl b/examples/operator_overloading.kl index 8999d2f..7145b3f 100644 --- a/examples/operator_overloading.kl +++ b/examples/operator_overloading.kl @@ -10,6 +10,8 @@ Vec3 :: struct;; x: F32; y: F32; z: F32 main :: (): int a := Vec3{1,1,1} b := Vec3{2,3,4} + + // The expressions are replaced with the defined lambdas c := a + b Assert(c.x == 3 && c.y == 4 && c.z == 5) d := -c