Bulletproofing operator overloads using generated data

This commit is contained in:
Krzosa Karol
2022-09-30 13:36:48 +02:00
parent 62faf8a78c
commit 4ca3ab95df
10 changed files with 79 additions and 159 deletions

View File

@@ -271,6 +271,7 @@ struct Ast_Decl: Ast{
Intern_String unique_name; // For code generation, currently only present on lambdas
U64 operator_overload_arguments_hash;
Operator_Info *overload_op_info;
Ast_Scope *scope;
Ast_Expr *typespec;

View File

@@ -14,14 +14,12 @@ print(f'l->interns.last_keyword = keyword_{meta.keywords[-1].lower()}.str;')
for i in meta.interns:
print(f'intern_{i.lower()} = l->intern("{i}"_s);')
index = 0
for i in meta.token_simple_expr:
if i[1] != "SPECIAL":
print("op_" + meta.pascal_to_snake(i[0]) + f' = l->intern("{i[1]}"_s);')
print(f'op_info_table[{index}].op = l->intern("{i[1]}"_s);')
index += 1
first = "op_" + meta.pascal_to_snake(meta.token_simple_expr[0][0])
last = "op_" + meta.pascal_to_snake(meta.token_simple_expr[-1][0])
print(f"l->first_op = {first};")
print(f"l->last_op = {last};")
*/
keyword_struct = l->intern("struct"_s);
keyword_union = l->intern("union"_s);
@@ -48,32 +46,26 @@ intern_strict = l->intern("strict"_s);
intern_void = l->intern("void"_s);
intern_flag = l->intern("flag"_s);
intern_it = l->intern("it"_s);
op_mul = l->intern("*"_s);
op_div = l->intern("/"_s);
op_mod = l->intern("%"_s);
op_left_shift = l->intern("<<"_s);
op_right_shift = l->intern(">>"_s);
op_add = l->intern("+"_s);
op_sub = l->intern("-"_s);
op_equals = l->intern("=="_s);
op_lesser_then_or_equal = l->intern("<="_s);
op_greater_then_or_equal = l->intern(">="_s);
op_lesser_then = l->intern("<"_s);
op_greater_then = l->intern(">"_s);
op_not_equals = l->intern("!="_s);
op_bit_and = l->intern("&"_s);
op_bit_or = l->intern("|"_s);
op_bit_xor = l->intern("^"_s);
op_and = l->intern("&&"_s);
op_or = l->intern("||"_s);
op_neg = l->intern("~"_s);
op_not = l->intern("!"_s);
op_decrement = l->intern("--"_s);
op_increment = l->intern("++"_s);
op_post_decrement = l->intern("--"_s);
op_post_increment = l->intern("++"_s);
l->first_op = op_mul;
l->last_op = op_post_increment;
op_info_table[0].op = l->intern("*"_s);
op_info_table[1].op = l->intern("/"_s);
op_info_table[2].op = l->intern("%"_s);
op_info_table[3].op = l->intern("<<"_s);
op_info_table[4].op = l->intern(">>"_s);
op_info_table[5].op = l->intern("+"_s);
op_info_table[6].op = l->intern("-"_s);
op_info_table[7].op = l->intern("=="_s);
op_info_table[8].op = l->intern("<="_s);
op_info_table[9].op = l->intern(">="_s);
op_info_table[10].op = l->intern("<"_s);
op_info_table[11].op = l->intern(">"_s);
op_info_table[12].op = l->intern("!="_s);
op_info_table[13].op = l->intern("&"_s);
op_info_table[14].op = l->intern("|"_s);
op_info_table[15].op = l->intern("^"_s);
op_info_table[16].op = l->intern("&&"_s);
op_info_table[17].op = l->intern("||"_s);
op_info_table[18].op = l->intern("~"_s);
op_info_table[19].op = l->intern("!"_s);
/*END*/
}
@@ -333,4 +325,4 @@ compile_file(String filename, U32 compile_flags = COMPILE_NULL){
}
destroy_compiler();
}
}

View File

@@ -141,9 +141,6 @@ struct Lexer{
S64 token_iter;
U32 token_debug_ids;
Intern_String first_op;
Intern_String last_op ;
// Interns `string` into this lexer's `interns` table (via intern_string)
// and hands back the resulting Intern_String.
Intern_String intern(String string){
  Intern_String interned = intern_string(&interns, string);
  return interned;
}

View File

@@ -11,9 +11,6 @@ global Token null_token; // @todo: memes, why the above is called null?
for i in meta.keywords: print(f'Intern_String keyword_{i.lower()};')
for i in meta.interns: print(f'Intern_String intern_{i.lower()};')
for i in meta.token_simple_expr:
if i[1] != "SPECIAL":
print("Intern_String op_" + meta.pascal_to_snake(i[0]) + ";")
*/
Intern_String keyword_struct;
Intern_String keyword_union;
@@ -38,30 +35,6 @@ Intern_String intern_strict;
Intern_String intern_void;
Intern_String intern_flag;
Intern_String intern_it;
Intern_String op_mul;
Intern_String op_div;
Intern_String op_mod;
Intern_String op_left_shift;
Intern_String op_right_shift;
Intern_String op_add;
Intern_String op_sub;
Intern_String op_equals;
Intern_String op_lesser_then_or_equal;
Intern_String op_greater_then_or_equal;
Intern_String op_lesser_then;
Intern_String op_greater_then;
Intern_String op_not_equals;
Intern_String op_bit_and;
Intern_String op_bit_or;
Intern_String op_bit_xor;
Intern_String op_and;
Intern_String op_or;
Intern_String op_neg;
Intern_String op_not;
Intern_String op_decrement;
Intern_String op_increment;
Intern_String op_post_decrement;
Intern_String op_post_increment;
/*END*/
//-----------------------------------------------------------------------------

View File

@@ -72,12 +72,6 @@ lex_is_keyword(Intern_Table *lexer, Intern_String keyword){
return result;
}
// Reports whether `op` is an interned operator string that may be overloaded.
// The decision is a pointer range test: op.str must lie between
// lexer->first_op.str and lexer->last_op.str (both ends inclusive).
// NOTE(review): this presumably relies on the operator strings being interned
// back-to-back so their addresses form one contiguous range - confirm against
// the intern table's allocation strategy.
function B32
is_valid_operator_overload(Lexer *lexer, Intern_String op){
  if(op.str < lexer->first_op.str) return false; // below the first operator intern
  if(op.str > lexer->last_op.str) return false;  // above the last operator intern
  return true;
}
function void
token_error(Token *t, String error_val){
t->kind = TK_Error;

View File

@@ -237,6 +237,7 @@ For modules it's a bit different cause they should be distributed as valid.
#include "core_compiler.h"
#include "core_types.h"
#include "core_globals.cpp"
#include "core_generated.cpp"
#include "c3_big_int.cpp"
#include "core_lexing.cpp"
#include "core_ast.cpp"

View File

@@ -847,15 +847,27 @@ parse_decl(B32 is_global){
if(!expr->scope){
compiler_error(tname, "Operator overload doesn't have body");
}
if(!is_valid_operator_overload(pctx, tname->intern_val)){
Operator_Info *op_info = get_operator_info(tname->intern_val);
if(!op_info){
compiler_error(tname, "This operator cannot be overloaded");
}
// if(is_binary && expr->args.len == 2){
// }
if(expr->args.len == 1){
if(!op_info->valid_unary_expr){
compiler_error(tname, "This operator cannot have a unary expression");
}
}
else if(expr->args.len == 2){
if(!op_info->valid_binary_expr){
compiler_error(tname, "This operator cannot have a binary expression");
}
}
else {
compiler_error(tname, "Invalid argument count for operator overload, unhandled operator");
}
result = ast_const(tname, tname->intern_val, expr);
result->overload_op_info = op_info;
result->kind = AST_LAMBDA;
result->flags = set_flag(result->flags, AST_OPERATOR_OVERLOAD);
}

View File

@@ -855,70 +855,15 @@ resolve_name(Ast_Scope *scope, Token *pos, Intern_String name, Search_Flag searc
return decl;
}
// Translates an operator token kind into the interned operator string that
// operator overloads are looked up under.
// Returns a value-initialized Intern_String for token kinds that have no
// overloadable operator; callers detect that case by checking `.str == 0`.
function Intern_String
map_operator_to_intern(Token_Kind op){
  switch(op){
    // Multiplicative
    case TK_Mul: return op_mul;
    case TK_Div: return op_div;
    case TK_Mod: return op_mod; // was missing: "%" overloads could never be resolved
    case TK_LeftShift: return op_left_shift;
    case TK_RightShift: return op_right_shift;

    // Additive
    case TK_Add: return op_add;
    case TK_Sub: return op_sub;

    // Comparison
    case TK_Equals: return op_equals;
    case TK_NotEquals: return op_not_equals;
    case TK_LesserThenOrEqual: return op_lesser_then_or_equal;
    case TK_GreaterThenOrEqual: return op_greater_then_or_equal;
    case TK_LesserThen: return op_lesser_then;
    case TK_GreaterThen: return op_greater_then;

    // Bitwise / logical
    case TK_BitAnd: return op_bit_and;
    case TK_BitOr: return op_bit_or;
    case TK_BitXor: return op_bit_xor;
    case TK_And: return op_and;
    case TK_Or: return op_or;

    // Unary
    case TK_Not: return op_not;
    case TK_Neg: return op_neg;
    case TK_Decrement: return op_decrement;
    case TK_Increment: return op_increment;

    // Anything else is not an overloadable operator.
    default: return {};
  }
}
// Maps an interned operator string to an identifier-safe name. The result is
// used to build the generated "CORE_OPERATOR_<NAME><id>" unique name of an
// operator overload.
// Interned strings are compared by pointer (`.str`), so `op` must come from
// the same intern table as the global op_* values.
// An unknown operator triggers invalid_codepath and, should execution
// continue, yields "INVALID_OPERATOR_OVERLOAD".
function String
map_operator_intern_to_identifier_name(Intern_String op){
  // Multiplicative
  if(op.str == op_mul.str) return "MUL"_s;
  if(op.str == op_div.str) return "DIV"_s;
  if(op.str == op_mod.str) return "MOD"_s; // was missing: "%" overloads hit invalid_codepath
  if(op.str == op_left_shift.str) return "LEFT_SHIFT"_s;
  if(op.str == op_right_shift.str) return "RIGHT_SHIFT"_s;

  // Additive
  if(op.str == op_add.str) return "ADD"_s;
  if(op.str == op_sub.str) return "SUB"_s;

  // Comparison
  if(op.str == op_equals.str) return "EQUALS"_s;
  if(op.str == op_not_equals.str) return "NOT_EQUALS"_s;
  if(op.str == op_lesser_then_or_equal.str) return "LESSER_THEN_OR_EQUAL"_s;
  if(op.str == op_greater_then_or_equal.str) return "GREATER_THEN_OR_EQUAL"_s;
  if(op.str == op_lesser_then.str) return "LESSER_THEN"_s;
  if(op.str == op_greater_then.str) return "GREATER_THEN"_s;

  // Bitwise / logical
  if(op.str == op_bit_and.str) return "BITAND"_s;
  if(op.str == op_bit_or.str) return "BITOR"_s;
  if(op.str == op_bit_xor.str) return "XOR"_s;
  if(op.str == op_and.str) return "AND"_s;
  if(op.str == op_or.str) return "OR"_s;

  // Unary
  if(op.str == op_not.str) return "NOT"_s;
  if(op.str == op_neg.str) return "NEG"_s;
  if(op.str == op_decrement.str) return "DECREMENT"_s;
  if(op.str == op_increment.str) return "INCREMENT"_s;

  invalid_codepath;
  return "INVALID_OPERATOR_OVERLOAD"_s;
}
function Ast_Decl *
resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Token *pos, Token_Kind op, U64 argument_hash){
Intern_String name = map_operator_to_intern(op);
if(name.str == 0) return 0;
Operator_Info *op_info = get_operator_info(op);
if(op_info == 0) return 0;
// Search for all possible candidates in three scopes
// The current module, left type definition module, right type definition module
Scratch scratch;
Scope_Search search = make_scope_search(scratch, scope, name);
Scope_Search search = make_scope_search(scratch, scope, op_info->op);
if( left->ast && left->ast->parent_scope) search.scopes.add(left->ast->parent_scope);
if(right && right->ast && right->ast->parent_scope) search.scopes.add(right->ast->parent_scope);
search.exit_on_find = false;
@@ -934,7 +879,7 @@ resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Tok
}
if(matching_ops.len > 1){
compiler_error(pos, "Found multiple matching operator overloads for [%s]", name.str);
compiler_error(pos, "Found multiple matching operator overloads for [%Q]", op_info->op);
}
if(matching_ops.len == 1){
@@ -1976,8 +1921,7 @@ resolve_decl(Ast_Decl *ast){
}
if(is_flag_set(node->flags, AST_OPERATOR_OVERLOAD)){
String n = map_operator_intern_to_identifier_name(node->name);
node->unique_name = pctx->intern(string_fmt(scratch, "CORE_OPERATOR_%Q%d", n, pctx->lambda_ids++));
node->unique_name = pctx->intern(string_fmt(scratch, "CORE_OPERATOR_%Q%d", node->overload_op_info->name, pctx->lambda_ids++));
}
BREAK();

56
meta.py
View File

@@ -5,40 +5,46 @@ def pascal_to_snake(v):
name = snake_case_pattern.sub('_', v).lower()
return name
BINARY_EXPR = 1
UNARY_EXPR = 2
token_simple_expr = [
["Mul", "*"],
["Div", "/"],
["Mod", "%"],
["LeftShift", "<<"],
["RightShift", ">>"],
["Mul", "*", BINARY_EXPR],
["Div", "/", BINARY_EXPR],
["Mod", "%", BINARY_EXPR],
["LeftShift", "<<", BINARY_EXPR],
["RightShift", ">>", BINARY_EXPR],
["FirstMul = TK_Mul", "SPECIAL"],
["LastMul = TK_RightShift", "SPECIAL"],
["Add", "+"],
["Sub", "-"],
["Add", "+", BINARY_EXPR | UNARY_EXPR],
["Sub", "-", BINARY_EXPR | UNARY_EXPR],
["FirstAdd = TK_Add", "SPECIAL"],
["LastAdd = TK_Sub", "SPECIAL"],
["Equals", "=="],
["LesserThenOrEqual", "<="],
["GreaterThenOrEqual", ">="],
["LesserThen", "<"],
["GreaterThen", ">"],
["NotEquals", "!="],
["Equals", "==", BINARY_EXPR],
["LesserThenOrEqual", "<=", BINARY_EXPR],
["GreaterThenOrEqual", ">=", BINARY_EXPR],
["LesserThen", "<", BINARY_EXPR],
["GreaterThen", ">", BINARY_EXPR],
["NotEquals", "!=", BINARY_EXPR],
["FirstCompare = TK_Equals", "SPECIAL"],
["LastCompare = TK_NotEquals", "SPECIAL"],
["BitAnd", "&"],
["BitOr", "|"],
["BitXor", "^"],
["And", "&&"],
["Or", "||"],
["BitAnd", "&", BINARY_EXPR],
["BitOr", "|", BINARY_EXPR],
["BitXor", "^", BINARY_EXPR],
["And", "&&", BINARY_EXPR],
["Or", "||", BINARY_EXPR],
["FirstLogical = TK_BitAnd", "SPECIAL"],
["LastLogical = TK_Or", "SPECIAL"],
["Neg", "~"],
["Not", "!"],
["Decrement", "--"],
["Increment", "++"],
["PostDecrement", "--"],
["PostIncrement", "++"],
["Neg", "~", UNARY_EXPR],
["Not", "!", UNARY_EXPR],
]
token_inc_expr = [
["Decrement", "--", UNARY_EXPR],
["Increment", "++", UNARY_EXPR],
["PostDecrement", "--", UNARY_EXPR],
["PostIncrement", "++", UNARY_EXPR],
]
token_assign_expr = [
@@ -89,7 +95,7 @@ token_rest = [
["Keyword", "[Keyword]"],
]
token_kinds = token_simple_expr + token_assign_expr + token_rest
token_kinds = token_simple_expr + token_inc_expr + token_assign_expr + token_rest
keywords = [
"struct",

View File

@@ -3,7 +3,7 @@ import sys
import os
files = ["core_compiler.cpp", "core_compiler.h", "core_globals.cpp", "core_lexing.cpp"]
files = ["core_compiler.cpp", "core_compiler.h", "core_globals.cpp", "core_lexing.cpp", "core_generated.cpp"]
for file_to_modify in files:
fd = open(file_to_modify, "r+")