revisit hash_table, add interning

This commit is contained in:
Krzosa Karol
2025-01-20 10:24:03 +01:00
parent 63cc0d92f1
commit f863f847dc
7 changed files with 95 additions and 82 deletions

View File

@@ -16,6 +16,8 @@ struct thread_ctx_t {
// Context instance `tcx` with its logging defaults filled in.
// NOTE(review): `gb_thread` is presumably a thread-local storage qualifier,
// which would make this one context per thread — confirm against its definition.
gb_thread thread_ctx_t tcx = {
.log = {
// All three break_on_* flags start out enabled.
.break_on_fatal = true,
.break_on_error = true,
.break_on_warning = true,
// Log callback installed by default.
.log_proc = default_log_proc,
}
};

View File

@@ -1,47 +1,12 @@
// Key/value payload carried by every hash-table node.
// The key and the value are each an anonymous union, so the members of one
// union share storage; which member is meaningful depends on the ht_insert_*
// helper that filled the entry.
typedef struct ht_key_value_t ht_key_value_t;
struct ht_key_value_t {
// Key, viewed as a string, a 64-bit integer, or a pointer.
union {
s8_t key_string;
u64 key_u64;
void *key_ptr;
};
// Value, viewed as a string, a pointer, or a 64-bit integer.
union {
s8_t value_string;
void *value_ptr;
u64 value_u64;
};
};
// One entry of the hash table, chained through `next`.
typedef struct ht_node_t ht_node_t;
struct ht_node_t {
// Following node in the same bucket chain; the search loops walk this
// link until it is NULL.
ht_node_t *next;
// Stored key/value pair.
ht_key_value_t kv;
};
// Head/tail pair for the node chain of one bucket.
typedef struct ht_bucket_t ht_bucket_t;
struct ht_bucket_t {
// Node where lookups start walking the chain.
ht_node_t *first;
// NOTE(review): presumably the chain tail used for appending — the linking
// code is not visible here, confirm.
ht_node_t *last;
};
// Hash table header: an arena plus an array of buckets.
typedef struct ht_dict_t ht_dict_t;
struct ht_dict_t {
// Arena the table was created from; nodes are pushed onto it on insert.
ma_arena_t *arena;
// NOTE(review): presumably the number of stored entries — no update of this
// field is visible here, confirm.
i32 item_count;
// Number of elements in `buckets`; the divisor in `hash % bucket_count`.
i32 bucket_count;
// Bucket array allocated by ht_create.
ht_bucket_t *buckets;
};
ht_dict_t *ht_create(ma_arena_t *arena, i32 size) {
ht_dict_t *result = ma_push_type(arena, ht_dict_t);
// Allocates a hash table header and its bucket array from `arena`.
//
// arena: backs the header and the bucket array, and is stored on the table so
//        later insertions can allocate their nodes from it.
// size:  requested number of buckets. Values below 1 are raised to 1, because
//        insert and search pick a bucket with `hash % bucket_count` and a
//        zero divisor is undefined behavior in C.
//
// Returns the newly created table.
fn ht_t *ht_create(ma_arena_t *arena, i32 size) {
    if (size < 1) {
        size = 1;
    }
    ht_t *result = ma_push_type(arena, ht_t);
    result->buckets = ma_push_array(arena, ht_bucket_t, size);
    result->bucket_count = size;
    result->arena = arena;
    return result;
}
ht_node_t *ht_insert_kv(ht_dict_t *ht, u64 hash, ht_key_value_t kv) {
fn ht_node_t *ht_insert_kv(ht_t *ht, u64 hash, ht_key_value_t kv) {
ht_node_t *result = ma_push_type(ht->arena, ht_node_t);
result->kv = kv;
i64 idx = hash % ht->bucket_count;
@@ -50,21 +15,26 @@ ht_node_t *ht_insert_kv(ht_dict_t *ht, u64 hash, ht_key_value_t kv) {
return result;
}
ht_node_t *ht_insert_u64_u64(ht_dict_t *ht, u64 key, u64 value) {
// Inserts a u64 key with a u64 value.
// The hash is taken over the key via s8_struct/hash_data and the pair is
// forwarded to ht_insert_kv; the node it returns is passed back to the caller.
fn ht_node_t *ht_insert_u64(ht_t *ht, u64 key, u64 value) {
    ht_key_value_t entry = {.key_u64 = key, .value_u64 = value};
    u64 key_hash = hash_data(s8_struct(key));
    return ht_insert_kv(ht, key_hash, entry);
}
ht_node_t *ht_insert_ptr(ht_dict_t *ht, void *key, void *value) {
return ht_insert_u64_u64(ht, (u64)key, (u64)value);
// Inserts a pointer key with a pointer value.
// Both pointers are converted to u64 and stored through ht_insert_u64, so
// pointer entries live in the table as 64-bit integers.
fn ht_node_t *ht_insert_ptr(ht_t *ht, void *key, void *value) {
    u64 key_bits = (u64)key;
    u64 value_bits = (u64)value;
    return ht_insert_u64(ht, key_bits, value_bits);
}
ht_node_t *ht_insert_string(ht_dict_t *ht, s8_t key, s8_t value) {
// Inserts a string key with a string value.
// The key string is hashed with hash_data; the s8_t values are stored as
// given (no copy is made in this function).
fn ht_node_t *ht_insert_string(ht_t *ht, s8_t key, s8_t value) {
    ht_key_value_t entry = {.key_string = key, .value_string = value};
    u64 key_hash = hash_data(key);
    return ht_insert_kv(ht, key_hash, entry);
}
ht_node_t *ht_search_u64(ht_dict_t *ht, u64 key) {
// Inserts a string key with a pointer value.
// Hashing matches ht_insert_string (hash_data over the key), only the value
// member that gets filled differs.
fn ht_node_t *ht_insert_string_ptr(ht_t *ht, s8_t key, void *value) {
    ht_key_value_t entry = {.key_string = key, .value_ptr = value};
    u64 key_hash = hash_data(key);
    return ht_insert_kv(ht, key_hash, entry);
}
fn ht_node_t *ht_search_u64_ex(ht_t *ht, u64 key) {
u64 hash = hash_data(s8_struct(key));
i64 idx = hash % ht->bucket_count;
ht_bucket_t *bucket = ht->buckets + idx;
@@ -76,11 +46,19 @@ ht_node_t *ht_search_u64(ht_dict_t *ht, u64 key) {
return NULL;
}
ht_node_t *ht_search_ptr(ht_dict_t *ht, void *key_ptr) {
return ht_search_u64(ht, (u64)key_ptr);
// Finds the node whose key is the pointer `key`; returns NULL when the bucket
// chain holds no such key.
//
// ht_insert_ptr stores pointer keys as u64 (it forwards (u64)key to
// ht_insert_u64), so the lookup widens the pointer to u64 before hashing and
// comparing. Hashing the `void *` object directly would only agree with the
// insert path when sizeof(void *) == sizeof(u64).
// NOTE(review): this assumes s8_struct spans the bytes of its argument —
// confirm against its definition.
fn ht_node_t *ht_search_ptr_ex(ht_t *ht, void *key) {
    u64 key_bits = (u64)key;
    u64 hash = hash_data(s8_struct(key_bits));
    i64 idx = hash % ht->bucket_count;
    ht_bucket_t *bucket = ht->buckets + idx;
    for (ht_node_t *it = bucket->first; it; it = it->next) {
        if (it->kv.key_u64 == key_bits) {
            return it;
        }
    }
    return NULL;
}
ht_node_t *ht_search_string(ht_dict_t *ht, s8_t key) {
fn ht_node_t *ht_search_string_ex(ht_t *ht, s8_t key) {
u64 hash = hash_data(key);
i64 idx = hash % ht->bucket_count;
ht_bucket_t *bucket = ht->buckets + idx;
@@ -90,4 +68,45 @@ ht_node_t *ht_search_string(ht_dict_t *ht, s8_t key) {
}
}
return NULL;
}
}
// Looks up a u64 key and returns a pointer to the stored u64 value inside the
// node, or NULL when ht_search_u64_ex finds nothing.
fn u64 *ht_search_u64(ht_t *ht, u64 key) {
    ht_node_t *hit = ht_search_u64_ex(ht, key);
    return hit ? &hit->kv.value_u64 : NULL;
}
// Looks up a pointer key and returns the address of the stored pointer value
// inside the node, or NULL when ht_search_ptr_ex finds nothing.
fn void **ht_search_ptr(ht_t *ht, void *key) {
    ht_node_t *hit = ht_search_ptr_ex(ht, key);
    if (hit == NULL) {
        return NULL;
    }
    return &hit->kv.value_ptr;
}
// Looks up a string key and returns a pointer to the stored string value
// inside the node, or NULL when ht_search_string_ex finds nothing.
fn s8_t *ht_search_string(ht_t *ht, s8_t key) {
    ht_node_t *hit = ht_search_string_ex(ht, key);
    return hit ? &hit->kv.value_string : NULL;
}
// Looks up a string key in a table whose values are pointers and returns the
// address of the stored pointer, or NULL when ht_search_string_ex finds
// nothing.
fn void **ht_search_string_ptr(ht_t *ht, s8_t key) {
    ht_node_t *hit = ht_search_string_ex(ht, key);
    if (hit == NULL) {
        return NULL;
    }
    return &hit->kv.value_ptr;
}
///////////////////////////////
// string interning
// Returns the interned entry for `string` in `ht`.
// If the table already has the string, the pointer to its stored value is
// returned. Otherwise the string is copied into the table's arena, inserted
// with the copy as both key and value, and the pointer to the new node's
// value is returned.
fn s8i_t *intern_string(ht_t *ht, s8_t string) {
    s8_t *found = ht_search_string(ht, string);
    if (found) {
        return found;
    }
    s8_t owned = s8_copy(ht->arena, string);
    ht_node_t *node = ht_insert_string(ht, owned, owned);
    return &node->kv.value_string;
}
// printf-style front end for intern_string: builds a string from `str` and
// the variadic arguments, then interns it in `ht`.
// S8_FMT receives ht->arena, the format parameter and the name `string`; the
// following line reads `string`, so the macro presumably declares it as the
// formatted s8_t allocated from that arena — TODO confirm against the macro.
// NOTE(review): if so, the formatted temporary stays in the table's arena even
// when the string was already interned — confirm that this is intended.
fn s8i_t *internf(ht_t *ht, char *str, ...) {
S8_FMT(ht->arena, str, string);
return intern_string(ht, string);
}

View File

@@ -12,4 +12,5 @@
#include "core_type_info.h"
#include "core_intrin.h"
#include "core_platform.h"
#include "core_hash_table.h"
#include "core_ctx.h"

View File

@@ -99,41 +99,54 @@ void test_s8(void) {
void test_hash_table(void) {
ma_temp_t scratch = ma_begin_scratch();
{
ht_dict_t *ht = ht_create(scratch.arena, 16);
ht_t *ht = ht_create(scratch.arena, 16);
for (u64 i = 0; i < 128; i += 1) {
ht_insert_u64_u64(ht, i, i);
ht_node_t *node = ht_search_u64(ht, i);
ht_insert_u64(ht, i, i);
ht_node_t *node = ht_search_u64_ex(ht, i);
assert(node->kv.value_u64 == i);
assert(node->kv.key_u64 == i);
}
for (u64 i = 0; i < 128; i += 1) {
ht_node_t *node = ht_search_u64(ht, i);
ht_node_t *node = ht_search_u64_ex(ht, i);
assert(node->kv.value_u64 == i);
assert(node->kv.key_u64 == i);
}
ht_node_t *node = ht_search_u64(ht, 1111);
ht_node_t *node = ht_search_u64_ex(ht, 1111);
assert(node == NULL);
}
{
ht_dict_t *ht = ht_create(scratch.arena, 16);
ht_t *ht = ht_create(scratch.arena, 16);
for (i32 i = 0; i < 128; i += 1) {
s8_t s = s8_printf(scratch.arena, "%d", i);
ht_insert_string(ht, s, s);
ht_node_t *node = ht_search_string(ht, s);
ht_node_t *node = ht_search_string_ex(ht, s);
assert(s8_are_equal(node->kv.value_string, s));
assert(s8_are_equal(node->kv.key_string, s));
}
ht_node_t *node = ht_search_string(ht, s8_lit("memes"));
ht_node_t *node = ht_search_string_ex(ht, s8_lit("memes"));
assert(node == NULL);
}
ma_end_scratch(scratch);
}
// Checks the interning contract through internf: equal strings must yield the
// same pointer and different strings must yield different pointers.
void test_intern_table(void) {
ma_temp_t scratch = ma_begin_scratch();
// A table of only 4 buckets for the 6 distinct strings interned below.
ht_t *ht = ht_create(scratch.arena, 4);
// Same text twice -> same pointer; a longer string sharing a prefix -> different.
assert(internf(ht, "asd") == internf(ht, "asd"));
assert(internf(ht, "asdf") != internf(ht, "asd"));
assert(internf(ht, "asdf") == internf(ht, "asdf"));
// Strings of equal length that differ in one character must not collapse.
assert(internf(ht, "123asdf") == internf(ht, "123asdf"));
assert(internf(ht, "123asdf") != internf(ht, "133asdf"));
// The empty string is interned like any other string.
assert(internf(ht, "") == internf(ht, ""));
assert(internf(ht, "") != internf(ht, "a"));
ma_end_scratch(scratch);
}
#include <stdio.h>
int main(int argc, char **argv) {
@@ -149,6 +162,7 @@ int main(int argc, char **argv) {
test_s8();
test_hash_table();
test_intern_table();
printf("all done!\n");
}