diff --git a/core_codegen_c_language.cpp b/core_codegen_c_language.cpp index d95327c..2ef3e3d 100644 --- a/core_codegen_c_language.cpp +++ b/core_codegen_c_language.cpp @@ -386,11 +386,19 @@ gen_expr(Ast_Expr *ast, Ast_Type *type_of_var){ } CASE(UNARY, Unary){ - gen("("); - if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op)); - gen_expr(node->expr); - if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op)); - gen(")"); + if(node->resolved_operator_overload){ + gen("%Q(", node->resolved_operator_overload->unique_name); + gen_expr(node->expr); + gen(")"); + } + + else { + gen("("); + if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op)); + gen_expr(node->expr); + if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op)); + gen(")"); + } BREAK(); } diff --git a/core_compiler.h b/core_compiler.h index f148a9d..7461429 100644 --- a/core_compiler.h +++ b/core_compiler.h @@ -5,6 +5,7 @@ struct Ast_File; struct Ast_Module; struct Ast_Type; struct Ast; +struct Ast_Expr; enum Token_Kind{ TK_End, @@ -95,7 +96,6 @@ enum Token_Kind{ TK_Pointer = TK_Mul, TK_Dereference = TK_BitAnd, - OPEN_SCOPE = 128, CLOSE_SCOPE, SAME_SCOPE, @@ -122,7 +122,6 @@ struct Token{ S32 line; U8 *line_begin; }; -global Token null_token; struct Lex_Stream{ String stream; @@ -197,4 +196,5 @@ function String compile_to_c_code(); function Ast_Module *ast_module(Token *pos, Intern_String filename); function void insert_builtin_types_into_scope(Ast_Scope *p); function void insert_into_scope(Ast_Scope *scope, Ast_Decl *decl); -function Ast_Type *type_incomplete(Ast *ast); \ No newline at end of file +function Ast_Type *type_incomplete(Ast *ast); +function Ast_Expr *parse_expr(S64 minbp = 0); \ No newline at end of file diff --git a/core_globals.cpp b/core_globals.cpp index 50610df..f439a46 100644 --- a/core_globals.cpp +++ b/core_globals.cpp @@ -5,11 +5,12 @@ Allocator *bigint_allocator; global S64 bigint_allocation_count; global Token token_null = {SAME_SCOPE}; +global Token null_token; // @todo: memes, why the above is called null? //----------------------------------------------------------------------------- // Interns / keywords //----------------------------------------------------------------------------- -Intern_String keyword_struct; +Intern_String keyword_struct; // first Intern_String keyword_union; Intern_String keyword_return; Intern_String keyword_if; @@ -23,7 +24,7 @@ Intern_String keyword_switch; Intern_String keyword_break; Intern_String keyword_elif; Intern_String keyword_assert; -Intern_String keyword_enum; +Intern_String keyword_enum; // last Intern_String intern_sizeof; Intern_String intern_alignof; @@ -34,7 +35,7 @@ Intern_String intern_it; Intern_String intern_strict; Intern_String intern_flag; -Intern_String op_add; +Intern_String op_add; // first Intern_String op_mul; Intern_String op_div; Intern_String op_sub; @@ -51,10 +52,11 @@ Intern_String op_lesser_then; Intern_String op_greater_then; Intern_String op_left_shift; Intern_String op_right_shift; + Intern_String op_not; Intern_String op_neg; Intern_String op_decrement; -Intern_String op_increment; +Intern_String op_increment; // last //----------------------------------------------------------------------------- // Type globals diff --git a/core_parsing.cpp b/core_parsing.cpp index f9f9bed..321ce2e 100644 --- a/core_parsing.cpp +++ b/core_parsing.cpp @@ -165,8 +165,6 @@ token_expect(Token_Kind kind){ return 0; } -function Ast_Expr *parse_expr(S64 minbp = 0); - function Ast_Expr * parse_init_stmt(Ast_Expr *expr){ Token *token = token_get(); @@ -852,6 +850,10 @@ parse_decl(B32 is_global){ if(!is_valid_operator_overload(pctx, tname->intern_val)){ compiler_error(tname, "This operator cannot be overloaded"); } + // if(is_binary && expr->args.len == 2){ + + // } + result = ast_const(tname, tname->intern_val, expr); result->kind = AST_LAMBDA; diff --git a/core_typechecking.cpp b/core_typechecking.cpp index a829ea1..d15594a 100644 --- a/core_typechecking.cpp +++ b/core_typechecking.cpp @@ -1701,12 +1701,34 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){ return operand_lvalue(node->resolved_type); } else{ - eval_unary(node->pos, node->op, &value); + + // Try finding a operator overload + B32 proceed_to_default_operator_handler = true; + if(!value.is_const){ + U64 hash = calculate_hash_for_arguments(value.type); + Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash); + if(operator_overload){ + proceed_to_default_operator_handler = false; + if(operator_overload->lambda->ret.len != 1){ + compiler_error(operator_overload->pos, "Operator overload is required to have exactly 1 return value"); + } + + node->resolved_type = operator_overload->type->func.ret; + node->resolved_operator_overload = operator_overload; + } + } + + if(proceed_to_default_operator_handler){ + eval_unary(node->pos, node->op, &value); + node->resolved_type = value.type; + } + if(value.is_const){ rewrite_into_const(node, Ast_Unary, value.value); return operand_const_rvalue(value.value); } - return operand_rvalue(value.value.type); + + return operand_rvalue(node->resolved_type); } BREAK(); diff --git a/examples/operator_overloading.kl b/examples/operator_overloading.kl index 364545d..8999d2f 100644 --- a/examples/operator_overloading.kl +++ b/examples/operator_overloading.kl @@ -1,14 +1,19 @@ -Vec3 :: struct;; x: F32; y: F32; z: F32 -"+" :: (a: Vec3, b: Vec3): Vec3 ;; return Vec3{a.x+b.x, a.y+b.y, a.z+b.z} +Vec3 :: struct;; x: F32; y: F32; z: F32 +// We can define operator overloads for arbitrary types +// these are just regular lambdas/functions +"+" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x+b.x, a.y+b.y, a.z+b.z} +"-" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x-b.x, a.y-b.y, a.z-b.z} +"-" :: (a: Vec3): Vec3 ;; return {-a.x, -a.y, -a.z} main :: (): int a := Vec3{1,1,1} b := Vec3{2,3,4} c := a + b - Assert(c.x == 3) - Assert(c.y == 4) - Assert(c.z == 5) - + Assert(c.x == 3 && c.y == 4 && c.z == 5) + d := -c + Assert(d.x == -3 && d.y == -4 && d.z == -5) + e := c - d + Assert(e.x == 6 && e.y == 8 && e.z == 10) return 0 \ No newline at end of file