Operator overloads for unary expressions

This commit is contained in:
Krzosa Karol
2022-09-29 17:36:15 +02:00
parent 9ee22abbd2
commit 37489b2730
6 changed files with 61 additions and 22 deletions

View File

@@ -386,11 +386,19 @@ gen_expr(Ast_Expr *ast, Ast_Type *type_of_var){
} }
CASE(UNARY, Unary){ CASE(UNARY, Unary){
if(node->resolved_operator_overload){
gen("%Q(", node->resolved_operator_overload->unique_name);
gen_expr(node->expr);
gen(")");
}
else {
gen("("); gen("(");
if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op)); if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op));
gen_expr(node->expr); gen_expr(node->expr);
if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op)); if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op));
gen(")"); gen(")");
}
BREAK(); BREAK();
} }

View File

@@ -5,6 +5,7 @@ struct Ast_File;
struct Ast_Module; struct Ast_Module;
struct Ast_Type; struct Ast_Type;
struct Ast; struct Ast;
struct Ast_Expr;
enum Token_Kind{ enum Token_Kind{
TK_End, TK_End,
@@ -95,7 +96,6 @@ enum Token_Kind{
TK_Pointer = TK_Mul, TK_Pointer = TK_Mul,
TK_Dereference = TK_BitAnd, TK_Dereference = TK_BitAnd,
OPEN_SCOPE = 128, OPEN_SCOPE = 128,
CLOSE_SCOPE, CLOSE_SCOPE,
SAME_SCOPE, SAME_SCOPE,
@@ -122,7 +122,6 @@ struct Token{
S32 line; S32 line;
U8 *line_begin; U8 *line_begin;
}; };
global Token null_token;
struct Lex_Stream{ struct Lex_Stream{
String stream; String stream;
@@ -198,3 +197,4 @@ function Ast_Module *ast_module(Token *pos, Intern_String filename);
function void insert_builtin_types_into_scope(Ast_Scope *p); function void insert_builtin_types_into_scope(Ast_Scope *p);
function void insert_into_scope(Ast_Scope *scope, Ast_Decl *decl); function void insert_into_scope(Ast_Scope *scope, Ast_Decl *decl);
function Ast_Type *type_incomplete(Ast *ast); function Ast_Type *type_incomplete(Ast *ast);
function Ast_Expr *parse_expr(S64 minbp = 0);

View File

@@ -5,11 +5,12 @@ Allocator *bigint_allocator;
global S64 bigint_allocation_count; global S64 bigint_allocation_count;
global Token token_null = {SAME_SCOPE}; global Token token_null = {SAME_SCOPE};
global Token null_token; // @todo: memes, why the above is called null?
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Interns / keywords // Interns / keywords
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
Intern_String keyword_struct; Intern_String keyword_struct; // first
Intern_String keyword_union; Intern_String keyword_union;
Intern_String keyword_return; Intern_String keyword_return;
Intern_String keyword_if; Intern_String keyword_if;
@@ -23,7 +24,7 @@ Intern_String keyword_switch;
Intern_String keyword_break; Intern_String keyword_break;
Intern_String keyword_elif; Intern_String keyword_elif;
Intern_String keyword_assert; Intern_String keyword_assert;
Intern_String keyword_enum; Intern_String keyword_enum; // last
Intern_String intern_sizeof; Intern_String intern_sizeof;
Intern_String intern_alignof; Intern_String intern_alignof;
@@ -34,7 +35,7 @@ Intern_String intern_it;
Intern_String intern_strict; Intern_String intern_strict;
Intern_String intern_flag; Intern_String intern_flag;
Intern_String op_add; Intern_String op_add; // first
Intern_String op_mul; Intern_String op_mul;
Intern_String op_div; Intern_String op_div;
Intern_String op_sub; Intern_String op_sub;
@@ -51,10 +52,11 @@ Intern_String op_lesser_then;
Intern_String op_greater_then; Intern_String op_greater_then;
Intern_String op_left_shift; Intern_String op_left_shift;
Intern_String op_right_shift; Intern_String op_right_shift;
Intern_String op_not; Intern_String op_not;
Intern_String op_neg; Intern_String op_neg;
Intern_String op_decrement; Intern_String op_decrement;
Intern_String op_increment; Intern_String op_increment; // last
//----------------------------------------------------------------------------- //-----------------------------------------------------------------------------
// Type globals // Type globals

View File

@@ -165,8 +165,6 @@ token_expect(Token_Kind kind){
return 0; return 0;
} }
function Ast_Expr *parse_expr(S64 minbp = 0);
function Ast_Expr * function Ast_Expr *
parse_init_stmt(Ast_Expr *expr){ parse_init_stmt(Ast_Expr *expr){
Token *token = token_get(); Token *token = token_get();
@@ -852,6 +850,10 @@ parse_decl(B32 is_global){
if(!is_valid_operator_overload(pctx, tname->intern_val)){ if(!is_valid_operator_overload(pctx, tname->intern_val)){
compiler_error(tname, "This operator cannot be overloaded"); compiler_error(tname, "This operator cannot be overloaded");
} }
// if(is_binary && expr->args.len == 2){
// }
result = ast_const(tname, tname->intern_val, expr); result = ast_const(tname, tname->intern_val, expr);
result->kind = AST_LAMBDA; result->kind = AST_LAMBDA;

View File

@@ -1701,12 +1701,34 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){
return operand_lvalue(node->resolved_type); return operand_lvalue(node->resolved_type);
} }
else{ else{
// Try finding a operator overload
B32 proceed_to_default_operator_handler = true;
if(!value.is_const){
U64 hash = calculate_hash_for_arguments(value.type);
Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash);
if(operator_overload){
proceed_to_default_operator_handler = false;
if(operator_overload->lambda->ret.len != 1){
compiler_error(operator_overload->pos, "Operator overload is required to have exactly 1 return value");
}
node->resolved_type = operator_overload->type->func.ret;
node->resolved_operator_overload = operator_overload;
}
}
if(proceed_to_default_operator_handler){
eval_unary(node->pos, node->op, &value); eval_unary(node->pos, node->op, &value);
node->resolved_type = value.type;
}
if(value.is_const){ if(value.is_const){
rewrite_into_const(node, Ast_Unary, value.value); rewrite_into_const(node, Ast_Unary, value.value);
return operand_const_rvalue(value.value); return operand_const_rvalue(value.value);
} }
return operand_rvalue(value.value.type);
return operand_rvalue(node->resolved_type);
} }
BREAK(); BREAK();

View File

@@ -1,14 +1,19 @@
Vec3 :: struct;; x: F32; y: F32; z: F32 Vec3 :: struct;; x: F32; y: F32; z: F32
"+" :: (a: Vec3, b: Vec3): Vec3 ;; return Vec3{a.x+b.x, a.y+b.y, a.z+b.z}
// We can define operator overloads for arbitrary types
// these are just regular lambdas/functions
"+" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x+b.x, a.y+b.y, a.z+b.z}
"-" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x-b.x, a.y-b.y, a.z-b.z}
"-" :: (a: Vec3): Vec3 ;; return {-a.x, -a.y, -a.z}
main :: (): int main :: (): int
a := Vec3{1,1,1} a := Vec3{1,1,1}
b := Vec3{2,3,4} b := Vec3{2,3,4}
c := a + b c := a + b
Assert(c.x == 3) Assert(c.x == 3 && c.y == 4 && c.z == 5)
Assert(c.y == 4) d := -c
Assert(c.z == 5) Assert(d.x == -3 && d.y == -4 && d.z == -5)
e := c - d
Assert(e.x == 6 && e.y == 8 && e.z == 10)
return 0 return 0