Operator overloads for unary expressions

This commit is contained in:
Krzosa Karol
2022-09-29 17:36:15 +02:00
parent 9ee22abbd2
commit 37489b2730
6 changed files with 61 additions and 22 deletions

View File

@@ -386,11 +386,19 @@ gen_expr(Ast_Expr *ast, Ast_Type *type_of_var){
}
CASE(UNARY, Unary){
gen("(");
if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op));
gen_expr(node->expr);
if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op));
gen(")");
if(node->resolved_operator_overload){
gen("%Q(", node->resolved_operator_overload->unique_name);
gen_expr(node->expr);
gen(")");
}
else {
gen("(");
if(node->op != TK_PostIncrement && node->op != TK_PostDecrement) gen("%s", name(node->op));
gen_expr(node->expr);
if(node->op == TK_PostIncrement || node->op == TK_PostDecrement) gen("%s", name(node->op));
gen(")");
}
BREAK();
}

View File

@@ -5,6 +5,7 @@ struct Ast_File;
struct Ast_Module;
struct Ast_Type;
struct Ast;
struct Ast_Expr;
enum Token_Kind{
TK_End,
@@ -95,7 +96,6 @@ enum Token_Kind{
TK_Pointer = TK_Mul,
TK_Dereference = TK_BitAnd,
OPEN_SCOPE = 128,
CLOSE_SCOPE,
SAME_SCOPE,
@@ -122,7 +122,6 @@ struct Token{
S32 line;
U8 *line_begin;
};
global Token null_token;
struct Lex_Stream{
String stream;
@@ -197,4 +196,5 @@ function String compile_to_c_code();
function Ast_Module *ast_module(Token *pos, Intern_String filename);
function void insert_builtin_types_into_scope(Ast_Scope *p);
function void insert_into_scope(Ast_Scope *scope, Ast_Decl *decl);
function Ast_Type *type_incomplete(Ast *ast);
function Ast_Type *type_incomplete(Ast *ast);
function Ast_Expr *parse_expr(S64 minbp = 0);

View File

@@ -5,11 +5,12 @@ Allocator *bigint_allocator;
global S64 bigint_allocation_count;
global Token token_null = {SAME_SCOPE};
global Token null_token; // @todo: memes, why the above is called null?
//-----------------------------------------------------------------------------
// Interns / keywords
//-----------------------------------------------------------------------------
Intern_String keyword_struct;
Intern_String keyword_struct; // first
Intern_String keyword_union;
Intern_String keyword_return;
Intern_String keyword_if;
@@ -23,7 +24,7 @@ Intern_String keyword_switch;
Intern_String keyword_break;
Intern_String keyword_elif;
Intern_String keyword_assert;
Intern_String keyword_enum;
Intern_String keyword_enum; // last
Intern_String intern_sizeof;
Intern_String intern_alignof;
@@ -34,7 +35,7 @@ Intern_String intern_it;
Intern_String intern_strict;
Intern_String intern_flag;
Intern_String op_add;
Intern_String op_add; // first
Intern_String op_mul;
Intern_String op_div;
Intern_String op_sub;
@@ -51,10 +52,11 @@ Intern_String op_lesser_then;
Intern_String op_greater_then;
Intern_String op_left_shift;
Intern_String op_right_shift;
Intern_String op_not;
Intern_String op_neg;
Intern_String op_decrement;
Intern_String op_increment;
Intern_String op_increment; // last
//-----------------------------------------------------------------------------
// Type globals

View File

@@ -165,8 +165,6 @@ token_expect(Token_Kind kind){
return 0;
}
function Ast_Expr *parse_expr(S64 minbp = 0);
function Ast_Expr *
parse_init_stmt(Ast_Expr *expr){
Token *token = token_get();
@@ -852,6 +850,10 @@ parse_decl(B32 is_global){
if(!is_valid_operator_overload(pctx, tname->intern_val)){
compiler_error(tname, "This operator cannot be overloaded");
}
// if(is_binary && expr->args.len == 2){
// }
result = ast_const(tname, tname->intern_val, expr);
result->kind = AST_LAMBDA;

View File

@@ -1701,12 +1701,34 @@ resolve_expr(Ast_Expr *ast, Resolve_Flag flags, Ast_Type *compound_context){
return operand_lvalue(node->resolved_type);
}
else{
eval_unary(node->pos, node->op, &value);
// Try finding a operator overload
B32 proceed_to_default_operator_handler = true;
if(!value.is_const){
U64 hash = calculate_hash_for_arguments(value.type);
Ast_Decl *operator_overload = resolve_operator_overload(node->parent_scope, node->pos, node->op, hash);
if(operator_overload){
proceed_to_default_operator_handler = false;
if(operator_overload->lambda->ret.len != 1){
compiler_error(operator_overload->pos, "Operator overload is required to have exactly 1 return value");
}
node->resolved_type = operator_overload->type->func.ret;
node->resolved_operator_overload = operator_overload;
}
}
if(proceed_to_default_operator_handler){
eval_unary(node->pos, node->op, &value);
node->resolved_type = value.type;
}
if(value.is_const){
rewrite_into_const(node, Ast_Unary, value.value);
return operand_const_rvalue(value.value);
}
return operand_rvalue(value.value.type);
return operand_rvalue(node->resolved_type);
}
BREAK();

View File

@@ -1,14 +1,19 @@
Vec3 :: struct;; x: F32; y: F32; z: F32
"+" :: (a: Vec3, b: Vec3): Vec3 ;; return Vec3{a.x+b.x, a.y+b.y, a.z+b.z}
Vec3 :: struct;; x: F32; y: F32; z: F32
// We can define operator overloads for arbitrary types
// these are just regular lambdas/functions
"+" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x+b.x, a.y+b.y, a.z+b.z}
"-" :: (a: Vec3, b: Vec3): Vec3 ;; return {a.x-b.x, a.y-b.y, a.z-b.z}
"-" :: (a: Vec3): Vec3 ;; return {-a.x, -a.y, -a.z}
main :: (): int
a := Vec3{1,1,1}
b := Vec3{2,3,4}
c := a + b
Assert(c.x == 3)
Assert(c.y == 4)
Assert(c.z == 5)
Assert(c.x == 3 && c.y == 4 && c.z == 5)
d := -c
Assert(d.x == -3 && d.y == -4 && d.z == -5)
e := c - d
Assert(e.x == 6 && e.y == 8 && e.z == 10)
return 0