Bulletproofing operator overloads using generated data
This commit is contained in:
@@ -271,6 +271,7 @@ struct Ast_Decl: Ast{
|
|||||||
Intern_String unique_name; // For code generation, currently only present on lambdas
|
Intern_String unique_name; // For code generation, currently only present on lambdas
|
||||||
|
|
||||||
U64 operator_overload_arguments_hash;
|
U64 operator_overload_arguments_hash;
|
||||||
|
Operator_Info *overload_op_info;
|
||||||
|
|
||||||
Ast_Scope *scope;
|
Ast_Scope *scope;
|
||||||
Ast_Expr *typespec;
|
Ast_Expr *typespec;
|
||||||
|
|||||||
@@ -14,14 +14,12 @@ print(f'l->interns.last_keyword = keyword_{meta.keywords[-1].lower()}.str;')
|
|||||||
for i in meta.interns:
|
for i in meta.interns:
|
||||||
print(f'intern_{i.lower()} = l->intern("{i}"_s);')
|
print(f'intern_{i.lower()} = l->intern("{i}"_s);')
|
||||||
|
|
||||||
|
index = 0
|
||||||
for i in meta.token_simple_expr:
|
for i in meta.token_simple_expr:
|
||||||
if i[1] != "SPECIAL":
|
if i[1] != "SPECIAL":
|
||||||
print("op_" + meta.pascal_to_snake(i[0]) + f' = l->intern("{i[1]}"_s);')
|
print(f'op_info_table[{index}].op = l->intern("{i[1]}"_s);')
|
||||||
|
index += 1
|
||||||
|
|
||||||
first = "op_" + meta.pascal_to_snake(meta.token_simple_expr[0][0])
|
|
||||||
last = "op_" + meta.pascal_to_snake(meta.token_simple_expr[-1][0])
|
|
||||||
print(f"l->first_op = {first};")
|
|
||||||
print(f"l->last_op = {last};")
|
|
||||||
*/
|
*/
|
||||||
keyword_struct = l->intern("struct"_s);
|
keyword_struct = l->intern("struct"_s);
|
||||||
keyword_union = l->intern("union"_s);
|
keyword_union = l->intern("union"_s);
|
||||||
@@ -48,32 +46,26 @@ intern_strict = l->intern("strict"_s);
|
|||||||
intern_void = l->intern("void"_s);
|
intern_void = l->intern("void"_s);
|
||||||
intern_flag = l->intern("flag"_s);
|
intern_flag = l->intern("flag"_s);
|
||||||
intern_it = l->intern("it"_s);
|
intern_it = l->intern("it"_s);
|
||||||
op_mul = l->intern("*"_s);
|
op_info_table[0].op = l->intern("*"_s);
|
||||||
op_div = l->intern("/"_s);
|
op_info_table[1].op = l->intern("/"_s);
|
||||||
op_mod = l->intern("%"_s);
|
op_info_table[2].op = l->intern("%"_s);
|
||||||
op_left_shift = l->intern("<<"_s);
|
op_info_table[3].op = l->intern("<<"_s);
|
||||||
op_right_shift = l->intern(">>"_s);
|
op_info_table[4].op = l->intern(">>"_s);
|
||||||
op_add = l->intern("+"_s);
|
op_info_table[5].op = l->intern("+"_s);
|
||||||
op_sub = l->intern("-"_s);
|
op_info_table[6].op = l->intern("-"_s);
|
||||||
op_equals = l->intern("=="_s);
|
op_info_table[7].op = l->intern("=="_s);
|
||||||
op_lesser_then_or_equal = l->intern("<="_s);
|
op_info_table[8].op = l->intern("<="_s);
|
||||||
op_greater_then_or_equal = l->intern(">="_s);
|
op_info_table[9].op = l->intern(">="_s);
|
||||||
op_lesser_then = l->intern("<"_s);
|
op_info_table[10].op = l->intern("<"_s);
|
||||||
op_greater_then = l->intern(">"_s);
|
op_info_table[11].op = l->intern(">"_s);
|
||||||
op_not_equals = l->intern("!="_s);
|
op_info_table[12].op = l->intern("!="_s);
|
||||||
op_bit_and = l->intern("&"_s);
|
op_info_table[13].op = l->intern("&"_s);
|
||||||
op_bit_or = l->intern("|"_s);
|
op_info_table[14].op = l->intern("|"_s);
|
||||||
op_bit_xor = l->intern("^"_s);
|
op_info_table[15].op = l->intern("^"_s);
|
||||||
op_and = l->intern("&&"_s);
|
op_info_table[16].op = l->intern("&&"_s);
|
||||||
op_or = l->intern("||"_s);
|
op_info_table[17].op = l->intern("||"_s);
|
||||||
op_neg = l->intern("~"_s);
|
op_info_table[18].op = l->intern("~"_s);
|
||||||
op_not = l->intern("!"_s);
|
op_info_table[19].op = l->intern("!"_s);
|
||||||
op_decrement = l->intern("--"_s);
|
|
||||||
op_increment = l->intern("++"_s);
|
|
||||||
op_post_decrement = l->intern("--"_s);
|
|
||||||
op_post_increment = l->intern("++"_s);
|
|
||||||
l->first_op = op_mul;
|
|
||||||
l->last_op = op_post_increment;
|
|
||||||
/*END*/
|
/*END*/
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -141,9 +141,6 @@ struct Lexer{
|
|||||||
S64 token_iter;
|
S64 token_iter;
|
||||||
U32 token_debug_ids;
|
U32 token_debug_ids;
|
||||||
|
|
||||||
Intern_String first_op;
|
|
||||||
Intern_String last_op ;
|
|
||||||
|
|
||||||
Intern_String intern(String string){
|
Intern_String intern(String string){
|
||||||
return intern_string(&interns, string);
|
return intern_string(&interns, string);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,6 @@ global Token null_token; // @todo: memes, why the above is called null?
|
|||||||
for i in meta.keywords: print(f'Intern_String keyword_{i.lower()};')
|
for i in meta.keywords: print(f'Intern_String keyword_{i.lower()};')
|
||||||
for i in meta.interns: print(f'Intern_String intern_{i.lower()};')
|
for i in meta.interns: print(f'Intern_String intern_{i.lower()};')
|
||||||
|
|
||||||
for i in meta.token_simple_expr:
|
|
||||||
if i[1] != "SPECIAL":
|
|
||||||
print("Intern_String op_" + meta.pascal_to_snake(i[0]) + ";")
|
|
||||||
*/
|
*/
|
||||||
Intern_String keyword_struct;
|
Intern_String keyword_struct;
|
||||||
Intern_String keyword_union;
|
Intern_String keyword_union;
|
||||||
@@ -38,30 +35,6 @@ Intern_String intern_strict;
|
|||||||
Intern_String intern_void;
|
Intern_String intern_void;
|
||||||
Intern_String intern_flag;
|
Intern_String intern_flag;
|
||||||
Intern_String intern_it;
|
Intern_String intern_it;
|
||||||
Intern_String op_mul;
|
|
||||||
Intern_String op_div;
|
|
||||||
Intern_String op_mod;
|
|
||||||
Intern_String op_left_shift;
|
|
||||||
Intern_String op_right_shift;
|
|
||||||
Intern_String op_add;
|
|
||||||
Intern_String op_sub;
|
|
||||||
Intern_String op_equals;
|
|
||||||
Intern_String op_lesser_then_or_equal;
|
|
||||||
Intern_String op_greater_then_or_equal;
|
|
||||||
Intern_String op_lesser_then;
|
|
||||||
Intern_String op_greater_then;
|
|
||||||
Intern_String op_not_equals;
|
|
||||||
Intern_String op_bit_and;
|
|
||||||
Intern_String op_bit_or;
|
|
||||||
Intern_String op_bit_xor;
|
|
||||||
Intern_String op_and;
|
|
||||||
Intern_String op_or;
|
|
||||||
Intern_String op_neg;
|
|
||||||
Intern_String op_not;
|
|
||||||
Intern_String op_decrement;
|
|
||||||
Intern_String op_increment;
|
|
||||||
Intern_String op_post_decrement;
|
|
||||||
Intern_String op_post_increment;
|
|
||||||
/*END*/
|
/*END*/
|
||||||
|
|
||||||
//-----------------------------------------------------------------------------
|
//-----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -72,12 +72,6 @@ lex_is_keyword(Intern_Table *lexer, Intern_String keyword){
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
function B32
|
|
||||||
is_valid_operator_overload(Lexer *lexer, Intern_String op){
|
|
||||||
B32 result = op.str >= lexer->first_op.str && op.str <= lexer->last_op.str;
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
function void
|
function void
|
||||||
token_error(Token *t, String error_val){
|
token_error(Token *t, String error_val){
|
||||||
t->kind = TK_Error;
|
t->kind = TK_Error;
|
||||||
|
|||||||
@@ -237,6 +237,7 @@ For modules it's a bit different cause they should be distributed as valid.
|
|||||||
#include "core_compiler.h"
|
#include "core_compiler.h"
|
||||||
#include "core_types.h"
|
#include "core_types.h"
|
||||||
#include "core_globals.cpp"
|
#include "core_globals.cpp"
|
||||||
|
#include "core_generated.cpp"
|
||||||
#include "c3_big_int.cpp"
|
#include "c3_big_int.cpp"
|
||||||
#include "core_lexing.cpp"
|
#include "core_lexing.cpp"
|
||||||
#include "core_ast.cpp"
|
#include "core_ast.cpp"
|
||||||
|
|||||||
@@ -847,15 +847,27 @@ parse_decl(B32 is_global){
|
|||||||
if(!expr->scope){
|
if(!expr->scope){
|
||||||
compiler_error(tname, "Operator overload doesn't have body");
|
compiler_error(tname, "Operator overload doesn't have body");
|
||||||
}
|
}
|
||||||
if(!is_valid_operator_overload(pctx, tname->intern_val)){
|
Operator_Info *op_info = get_operator_info(tname->intern_val);
|
||||||
|
if(!op_info){
|
||||||
compiler_error(tname, "This operator cannot be overloaded");
|
compiler_error(tname, "This operator cannot be overloaded");
|
||||||
}
|
}
|
||||||
// if(is_binary && expr->args.len == 2){
|
|
||||||
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
if(expr->args.len == 1){
|
||||||
|
if(!op_info->valid_unary_expr){
|
||||||
|
compiler_error(tname, "This operator cannot have a unary expression");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if(expr->args.len == 2){
|
||||||
|
if(!op_info->valid_binary_expr){
|
||||||
|
compiler_error(tname, "This operator cannot have a binary expression");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
compiler_error(tname, "Invalid argument count for operator overload, unhandled operator");
|
||||||
|
}
|
||||||
|
|
||||||
result = ast_const(tname, tname->intern_val, expr);
|
result = ast_const(tname, tname->intern_val, expr);
|
||||||
|
result->overload_op_info = op_info;
|
||||||
result->kind = AST_LAMBDA;
|
result->kind = AST_LAMBDA;
|
||||||
result->flags = set_flag(result->flags, AST_OPERATOR_OVERLOAD);
|
result->flags = set_flag(result->flags, AST_OPERATOR_OVERLOAD);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -855,70 +855,15 @@ resolve_name(Ast_Scope *scope, Token *pos, Intern_String name, Search_Flag searc
|
|||||||
return decl;
|
return decl;
|
||||||
}
|
}
|
||||||
|
|
||||||
function Intern_String
|
|
||||||
map_operator_to_intern(Token_Kind op){
|
|
||||||
switch(op){
|
|
||||||
case TK_Add: return op_add; break;
|
|
||||||
case TK_Mul: return op_mul; break;
|
|
||||||
case TK_Div: return op_div; break;
|
|
||||||
case TK_Sub: return op_sub; break;
|
|
||||||
case TK_And: return op_and; break;
|
|
||||||
case TK_BitAnd: return op_bit_and; break;
|
|
||||||
case TK_Or: return op_or; break;
|
|
||||||
case TK_BitOr: return op_bit_or; break;
|
|
||||||
case TK_BitXor: return op_bit_xor; break;
|
|
||||||
case TK_Equals: return op_equals; break;
|
|
||||||
case TK_NotEquals: return op_not_equals; break;
|
|
||||||
case TK_LesserThenOrEqual: return op_lesser_then_or_equal; break;
|
|
||||||
case TK_GreaterThenOrEqual: return op_greater_then_or_equal; break;
|
|
||||||
case TK_LesserThen: return op_lesser_then; break;
|
|
||||||
case TK_GreaterThen: return op_greater_then; break;
|
|
||||||
case TK_LeftShift: return op_left_shift; break;
|
|
||||||
case TK_RightShift: return op_right_shift; break;
|
|
||||||
case TK_Not: return op_not; break;
|
|
||||||
case TK_Neg: return op_neg; break;
|
|
||||||
case TK_Decrement: return op_decrement; break;
|
|
||||||
case TK_Increment: return op_increment; break;
|
|
||||||
default: return {};
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function String
|
|
||||||
map_operator_intern_to_identifier_name(Intern_String op){
|
|
||||||
if(op.str == op_add.str) return "ADD"_s;
|
|
||||||
if(op.str == op_mul.str) return "MUL"_s;
|
|
||||||
if(op.str == op_div.str) return "DIV"_s;
|
|
||||||
if(op.str == op_sub.str) return "SUB"_s;
|
|
||||||
if(op.str == op_and.str) return "AND"_s;
|
|
||||||
if(op.str == op_bit_and.str) return "BITAND"_s;
|
|
||||||
if(op.str == op_or.str) return "OR"_s;
|
|
||||||
if(op.str == op_bit_or.str) return "BITOR"_s;
|
|
||||||
if(op.str == op_bit_xor.str) return "XOR"_s;
|
|
||||||
if(op.str == op_equals.str) return "EQUALS"_s;
|
|
||||||
if(op.str == op_not_equals.str) return "NOT_EQUALS"_s;
|
|
||||||
if(op.str == op_lesser_then_or_equal.str) return "LESSER_THEN_OR_EQUAL"_s;
|
|
||||||
if(op.str == op_greater_then_or_equal.str) return "GREATER_THEN_OR_EQUAL"_s;
|
|
||||||
if(op.str == op_lesser_then.str) return "LESSER_THEN"_s;
|
|
||||||
if(op.str == op_greater_then.str) return "GREATER_THEN"_s;
|
|
||||||
if(op.str == op_left_shift.str) return "LEFT_SHIFT"_s;
|
|
||||||
if(op.str == op_right_shift.str) return "RIGHT_SHIFT"_s;
|
|
||||||
if(op.str == op_not.str) return "NOT"_s;
|
|
||||||
if(op.str == op_neg.str) return "NEG"_s;
|
|
||||||
if(op.str == op_decrement.str) return "DECREMENT"_s;
|
|
||||||
if(op.str == op_increment.str) return "INCREMENT"_s;
|
|
||||||
invalid_codepath;
|
|
||||||
return "INVALID_OPERATOR_OVERLOAD"_s;
|
|
||||||
}
|
|
||||||
|
|
||||||
function Ast_Decl *
|
function Ast_Decl *
|
||||||
resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Token *pos, Token_Kind op, U64 argument_hash){
|
resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Token *pos, Token_Kind op, U64 argument_hash){
|
||||||
Intern_String name = map_operator_to_intern(op);
|
Operator_Info *op_info = get_operator_info(op);
|
||||||
if(name.str == 0) return 0;
|
if(op_info == 0) return 0;
|
||||||
|
|
||||||
// Search for all possible candidates in three scopes
|
// Search for all possible candidates in three scopes
|
||||||
// The current module, left type definition module, right type definition module
|
// The current module, left type definition module, right type definition module
|
||||||
Scratch scratch;
|
Scratch scratch;
|
||||||
Scope_Search search = make_scope_search(scratch, scope, name);
|
Scope_Search search = make_scope_search(scratch, scope, op_info->op);
|
||||||
if( left->ast && left->ast->parent_scope) search.scopes.add(left->ast->parent_scope);
|
if( left->ast && left->ast->parent_scope) search.scopes.add(left->ast->parent_scope);
|
||||||
if(right && right->ast && right->ast->parent_scope) search.scopes.add(right->ast->parent_scope);
|
if(right && right->ast && right->ast->parent_scope) search.scopes.add(right->ast->parent_scope);
|
||||||
search.exit_on_find = false;
|
search.exit_on_find = false;
|
||||||
@@ -934,7 +879,7 @@ resolve_operator_overload(Ast_Scope *scope, Ast_Type *left, Ast_Type *right, Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(matching_ops.len > 1){
|
if(matching_ops.len > 1){
|
||||||
compiler_error(pos, "Found multiple matching operator overloads for [%s]", name.str);
|
compiler_error(pos, "Found multiple matching operator overloads for [%Q]", op_info->op);
|
||||||
}
|
}
|
||||||
|
|
||||||
if(matching_ops.len == 1){
|
if(matching_ops.len == 1){
|
||||||
@@ -1976,8 +1921,7 @@ resolve_decl(Ast_Decl *ast){
|
|||||||
}
|
}
|
||||||
|
|
||||||
if(is_flag_set(node->flags, AST_OPERATOR_OVERLOAD)){
|
if(is_flag_set(node->flags, AST_OPERATOR_OVERLOAD)){
|
||||||
String n = map_operator_intern_to_identifier_name(node->name);
|
node->unique_name = pctx->intern(string_fmt(scratch, "CORE_OPERATOR_%Q%d", node->overload_op_info->name, pctx->lambda_ids++));
|
||||||
node->unique_name = pctx->intern(string_fmt(scratch, "CORE_OPERATOR_%Q%d", n, pctx->lambda_ids++));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BREAK();
|
BREAK();
|
||||||
|
|||||||
56
meta.py
56
meta.py
@@ -5,40 +5,46 @@ def pascal_to_snake(v):
|
|||||||
name = snake_case_pattern.sub('_', v).lower()
|
name = snake_case_pattern.sub('_', v).lower()
|
||||||
return name
|
return name
|
||||||
|
|
||||||
|
BINARY_EXPR = 1
|
||||||
|
UNARY_EXPR = 2
|
||||||
|
|
||||||
token_simple_expr = [
|
token_simple_expr = [
|
||||||
["Mul", "*"],
|
["Mul", "*", BINARY_EXPR],
|
||||||
["Div", "/"],
|
["Div", "/", BINARY_EXPR],
|
||||||
["Mod", "%"],
|
["Mod", "%", BINARY_EXPR],
|
||||||
["LeftShift", "<<"],
|
["LeftShift", "<<", BINARY_EXPR],
|
||||||
["RightShift", ">>"],
|
["RightShift", ">>", BINARY_EXPR],
|
||||||
["FirstMul = TK_Mul", "SPECIAL"],
|
["FirstMul = TK_Mul", "SPECIAL"],
|
||||||
["LastMul = TK_RightShift", "SPECIAL"],
|
["LastMul = TK_RightShift", "SPECIAL"],
|
||||||
["Add", "+"],
|
["Add", "+", BINARY_EXPR | UNARY_EXPR],
|
||||||
["Sub", "-"],
|
["Sub", "-", BINARY_EXPR | UNARY_EXPR],
|
||||||
["FirstAdd = TK_Add", "SPECIAL"],
|
["FirstAdd = TK_Add", "SPECIAL"],
|
||||||
["LastAdd = TK_Sub", "SPECIAL"],
|
["LastAdd = TK_Sub", "SPECIAL"],
|
||||||
["Equals", "=="],
|
["Equals", "==", BINARY_EXPR],
|
||||||
["LesserThenOrEqual", "<="],
|
["LesserThenOrEqual", "<=", BINARY_EXPR],
|
||||||
["GreaterThenOrEqual", ">="],
|
["GreaterThenOrEqual", ">=", BINARY_EXPR],
|
||||||
["LesserThen", "<"],
|
["LesserThen", "<", BINARY_EXPR],
|
||||||
["GreaterThen", ">"],
|
["GreaterThen", ">", BINARY_EXPR],
|
||||||
["NotEquals", "!="],
|
["NotEquals", "!=", BINARY_EXPR],
|
||||||
["FirstCompare = TK_Equals", "SPECIAL"],
|
["FirstCompare = TK_Equals", "SPECIAL"],
|
||||||
["LastCompare = TK_NotEquals", "SPECIAL"],
|
["LastCompare = TK_NotEquals", "SPECIAL"],
|
||||||
["BitAnd", "&"],
|
["BitAnd", "&", BINARY_EXPR],
|
||||||
["BitOr", "|"],
|
["BitOr", "|", BINARY_EXPR],
|
||||||
["BitXor", "^"],
|
["BitXor", "^", BINARY_EXPR],
|
||||||
["And", "&&"],
|
["And", "&&", BINARY_EXPR],
|
||||||
["Or", "||"],
|
["Or", "||", BINARY_EXPR],
|
||||||
["FirstLogical = TK_BitAnd", "SPECIAL"],
|
["FirstLogical = TK_BitAnd", "SPECIAL"],
|
||||||
["LastLogical = TK_Or", "SPECIAL"],
|
["LastLogical = TK_Or", "SPECIAL"],
|
||||||
|
|
||||||
["Neg", "~"],
|
["Neg", "~", UNARY_EXPR],
|
||||||
["Not", "!"],
|
["Not", "!", UNARY_EXPR],
|
||||||
["Decrement", "--"],
|
]
|
||||||
["Increment", "++"],
|
|
||||||
["PostDecrement", "--"],
|
token_inc_expr = [
|
||||||
["PostIncrement", "++"],
|
["Decrement", "--", UNARY_EXPR],
|
||||||
|
["Increment", "++", UNARY_EXPR],
|
||||||
|
["PostDecrement", "--", UNARY_EXPR],
|
||||||
|
["PostIncrement", "++", UNARY_EXPR],
|
||||||
]
|
]
|
||||||
|
|
||||||
token_assign_expr = [
|
token_assign_expr = [
|
||||||
@@ -89,7 +95,7 @@ token_rest = [
|
|||||||
["Keyword", "[Keyword]"],
|
["Keyword", "[Keyword]"],
|
||||||
]
|
]
|
||||||
|
|
||||||
token_kinds = token_simple_expr + token_assign_expr + token_rest
|
token_kinds = token_simple_expr + token_inc_expr + token_assign_expr + token_rest
|
||||||
|
|
||||||
keywords = [
|
keywords = [
|
||||||
"struct",
|
"struct",
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import sys
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
files = ["core_compiler.cpp", "core_compiler.h", "core_globals.cpp", "core_lexing.cpp"]
|
files = ["core_compiler.cpp", "core_compiler.h", "core_globals.cpp", "core_lexing.cpp", "core_generated.cpp"]
|
||||||
|
|
||||||
for file_to_modify in files:
|
for file_to_modify in files:
|
||||||
fd = open(file_to_modify, "r+")
|
fd = open(file_to_modify, "r+")
|
||||||
|
|||||||
Reference in New Issue
Block a user