Add support for struct field accesses

Signed-off-by: Slendi <slendi@socopon.com>
This commit is contained in:
Slendi 2024-04-09 02:24:45 +03:00
parent a372c4420a
commit 66faf71ec7
8 changed files with 134 additions and 53 deletions

View File

@ -6,4 +6,4 @@ LLVMC=llvm-config
LLVM_LINKER="-lc++ $($LLVMC --libs core --cxxflags --ldflags --system-libs|tr '\n' ' ')"
odin run src -o:none -debug -out:speedcat -extra-linker-flags:"$LLVM_LINKER" -- test_type_checker.cat
clang -I/opt/homebrew/include /opt/homebrew/lib/libraylib.a -lm -framework Cocoa -framework OpenGL -framework IOKit test_type_checker.ll test.c -o raylib
clang -I/opt/local/include -L/opt/local/lib -lraylib -lm -framework Cocoa -framework OpenGL -framework IOKit test_type_checker.ll test.c -o raylib

View File

@ -187,6 +187,9 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
) {
function_args_type := [dynamic]LLVMTypeRef{}
for arg in fn.children {
if arg == nil {
continue
}
if arg.kind != .VariableDeclaration {
continue
}
@ -214,6 +217,9 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
// Add function arguments to the scope
arg_index: uint = 0
for arg in fn.children {
if arg == nil {
continue
}
if arg.kind != .VariableDeclaration {
continue
}
@ -230,6 +236,9 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
// Generate function body
arg_index = 0
for &arg in fn.children {
if arg == nil {
continue
}
if arg.kind != .VariableDeclaration {
continue
}
@ -274,7 +283,6 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
}
if node.return_type == nil || (node.return_type.kind == .Integer && node.return_type.bit_size == 0) {
fmt.printf("Void function call: %s\n", name)
return LLVMBuildCall2(builder, fn_type, fn_value, raw_data(fn_args[:]), len(fn_args), cstring(""))
}
@ -323,6 +331,30 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
}
struct_value := LLVMConstNamedStruct(struct_type, raw_data(struct_values[:]), len(struct_values))
return struct_value
} else if node.kind == .FieldAccess {
if node.children[0].kind == .FieldAccess {
append(
&g_message_list,
message_create(.FIXME, fmt.aprintf("Nested field accesses are not implemented."), node.range),
)
return nil
}
def_struct := llvm_scope_find_definition(&node.children[0].value.([dynamic]u8))
type_struct := llvm_scope_find_type(&node.children[0].value.([dynamic]u8))
type_ref := generate_llvm_type_from_node(ctx, mod, builder, node.return_type)
struct_field_data := LLVMBuildAlloca(builder, type_ref, "struct_field")
def := LLVMBuildStructGEP2(
builder,
type_struct,
def_struct,
uint(node.return_type.struct_index),
"struct_ptr",
)
def_value := LLVMBuildLoad2(builder, type_ref, def, "loadtmp")
return def_value
}
fmt.panicf("FIXME: Implement other node kinds. Got: {}", node.kind)
@ -404,6 +436,9 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
if node.kind == .BinaryExpression {
return generate_llvm_binary_expression(ctx, mod, builder, node)
}
if node.kind == .UnaryExpression {
return generate_llvm_unary_expression(ctx, mod, builder, node)
}
if node.kind == .Cast {
return generate_llvm_cast(ctx, mod, builder, node)
}
@ -411,7 +446,8 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
node.kind == .Float ||
node.kind == .FunctionCall ||
node.kind == .Identifier ||
node.kind == .StructInitializer {
node.kind == .StructInitializer ||
node.kind == .FieldAccess {
return generate_llvm_value(ctx, mod, builder, node)
}
if node.kind == .FunctionCall {
@ -420,6 +456,23 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
fmt.panicf("FIXME: Implement other node kinds. Got: {}", node.kind)
}
generate_llvm_unary_expression :: proc(
ctx: LLVMContextRef,
mod: LLVMModuleRef,
builder: LLVMBuilderRef,
node: ^Node,
) -> LLVMValueRef {
op := node.value_token_kind
lhs_node := node.children[0]
lhs := generate_llvm_expression(ctx, mod, builder, lhs_node)
if op == .Subtract {
neg := LLVMBuildNeg(builder, lhs, "tmpneg")
return neg
} else {
fmt.panicf("unsupported unary operation")
}
}
generate_llvm_struct_type :: proc(
ctx: LLVMContextRef,
mod: LLVMModuleRef,
@ -599,7 +652,10 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
) {
condition_node := node.children[0]
true_node := node.children[1]
false_node := node.children[2]
false_node: ^Node = nil
if len(node.children) > 2 {
false_node = node.children[2]
}
bb := LLVMGetLastBasicBlock(function)
@ -620,22 +676,35 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
comparison_res := LLVMBuildICmp(builder, .LLVMIntNE, condition, LLVMConstInt(int_32_type, 0, LLVMBool(1)), "")
true_block := LLVMAppendBasicBlockInContext(ctx, function, "")
false_block := LLVMCreateBasicBlockInContext(ctx, "")
false_block: LLVMBasicBlockRef
if false_node != nil {
false_block = LLVMCreateBasicBlockInContext(ctx, "")
}
end_block := LLVMCreateBasicBlockInContext(ctx, "")
cond_br := LLVMBuildCondBr(builder, comparison_res, true_block, false_block)
cond_br: LLVMValueRef
if false_node != nil {
cond_br = LLVMBuildCondBr(builder, comparison_res, true_block, false_block)
} else {
cond_br = LLVMBuildCondBr(builder, comparison_res, true_block, end_block)
}
LLVMPositionBuilderAtEnd(builder, true_block)
generate_llvm_scope(ctx, mod, builder, function, true_node, scope_number, true_block)
LLVMBuildBr(builder, end_block)
true_block = LLVMGetInsertBlock(builder)
LLVMAppendExistingBasicBlock(function, false_block)
if false_node != nil {
LLVMAppendExistingBasicBlock(function, false_block)
LLVMPositionBuilderAtEnd(builder, false_block)
generate_llvm_scope(ctx, mod, builder, function, false_node, scope_number, false_block)
}
LLVMPositionBuilderAtEnd(builder, false_block)
generate_llvm_scope(ctx, mod, builder, function, false_node, scope_number, false_block)
LLVMBuildBr(builder, end_block)
false_block = LLVMGetInsertBlock(builder)
if false_node != nil {
false_block = LLVMGetInsertBlock(builder)
}
LLVMPositionBuilderAtEnd(builder, end_block)
LLVMAppendExistingBasicBlock(function, end_block)
@ -681,7 +750,6 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
llvm_scope_leave()
LLVMBuildBr(builder, basic_block)
} else if loop_type == LoopType.While {
fmt.println("In while")
condition_block := LLVMAppendBasicBlockInContext(ctx, function, "while_condition")
body_block := LLVMAppendBasicBlockInContext(ctx, function, "while_body")
end_block := LLVMAppendBasicBlockInContext(ctx, function, "while_end")
@ -742,6 +810,8 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil
fallthrough
case .FunctionCall:
fallthrough
case .UnaryExpression:
fallthrough
case .BinaryExpression:
generate_llvm_expression(ctx, mod, builder, node)
case .VariableDeclaration:

View File

@ -170,6 +170,8 @@ foreign llvmc {
LLVMBuildAShr :: proc(Builder: LLVMBuilderRef, LHS: LLVMValueRef, RHS: LLVMValueRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildShr :: proc(Builder: LLVMBuilderRef, LHS: LLVMValueRef, RHS: LLVMValueRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildNeg :: proc(Builder: LLVMBuilderRef, LHS: LLVMValueRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildPhi :: proc(Builder: LLVMBuilderRef, Ty: LLVMTypeRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildFPToSI :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef ---
@ -180,6 +182,7 @@ foreign llvmc {
LLVMBuildFPTrunc :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildFPExt :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef ---
LLVMBuildStructGEP2 :: proc(Builder: LLVMBuilderRef, Ty: LLVMTypeRef, Pointer: LLVMValueRef, Idx: uint, Name: cstring) -> LLVMValueRef ---
LLVMTypeOf :: proc(Val: LLVMValueRef) -> LLVMTypeRef ---

View File

@ -47,11 +47,9 @@ main :: proc() {
return
}
}
fmt.println("After parse:")
//node_print(ast)
clear(&g_message_list)
type_check(ast, nil)
fmt.println("After type check:")
//node_print(ast)
if len(g_message_list) > 0 {
contains_errors := false
@ -67,7 +65,7 @@ main :: proc() {
}
}
//node_print(ast)
// node_print(ast)
name: string
if handle == os.stdin {

View File

@ -560,6 +560,8 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) {
type_check(ast.children[1], ast)
if len(ast.children) == 3 {
type_check(ast.children[2], ast)
} else {
append(&ast.children, node_create_block(ast.range, {}))
}
case .ExternFunction:
in_extern = true

46
test.c
View File

@ -1,24 +1,26 @@
#include <stdint.h>
void ClearBackground(uint64_t rgba);
void ClearBackgroundWrap(uint8_t r, uint8_t g, uint8_t b, uint8_t a) {
ClearBackground((uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
(uint64_t)a);
}
void DrawRectangle(uint32_t x, uint32_t y, uint32_t width, uint32_t height,
uint64_t rgba);
void DrawRectangleWrap(uint32_t x, uint32_t y, uint32_t width, uint32_t height,
uint8_t r, uint8_t g, uint8_t b, uint8_t a) {
DrawRectangle(x, y, width, height,
(uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
(uint64_t)a);
}
void DrawCircle(uint32_t x, uint32_t y, float radius, uint64_t rgba);
void DrawCircleWrap(uint32_t x, uint32_t y, float radius, uint8_t r, uint8_t g,
uint8_t b, uint8_t a) {
DrawCircle(x, y, radius,
(uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
(uint64_t)a);
}
// void ClearBackground(uint64_t rgba);
// void ClearBackgroundWrap(uint8_t r, uint8_t g, uint8_t b, uint8_t a) {
// ClearBackground((uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
// (uint64_t)a);
// }
//
// void DrawRectangle(uint32_t x, uint32_t y, uint32_t width, uint32_t height,
// uint64_t rgba);
// void DrawRectangleWrap(uint32_t x, uint32_t y, uint32_t width, uint32_t
// height,
// uint8_t r, uint8_t g, uint8_t b, uint8_t a) {
// DrawRectangle(x, y, width, height,
// (uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
// (uint64_t)a);
// }
//
// void DrawCircle(uint32_t x, uint32_t y, float radius, uint64_t rgba);
// void DrawCircleWrap(uint32_t x, uint32_t y, float radius, uint8_t r, uint8_t
// g,
// uint8_t b, uint8_t a) {
// DrawCircle(x, y, radius,
// (uint64_t)r << 24 | (uint64_t)g << 16 | (uint64_t)b << 8 |
// (uint64_t)a);
// }

View File

@ -41,7 +41,7 @@
\(meow (add 60 9)) -> meow
struct Color {
a b g r: u8,
r g b a: u8,
}
fn ClearBackground(c: u32)
@ -56,34 +56,43 @@ fn EndDrawing
fn DrawFPS(x y: i32)
fn WindowShouldClose i32
1 << 3 | 2
fn ColorToRaylib(c: Color) u32 {
ret c.a as u32 << 24 as u32 | c.b as u32 << 16 as u32 | c.g as u32 << 8 as u32 | c.r
ret (c.a as u32 << 24 as u32) | (c.b as u32 << 16 as u32) | (c.g as u32 << 8 as u32) | c.r
}
fn ClearBackgroundWrap(c: Color) {
fn ClearBackgroundWrap(c: Color) i32 {
ClearBackground (ColorToRaylib c)
ret 0
}
fn DrawRectangleWrap(x y w h: i32, c: Color) {
fn DrawRectangleWrap(x y w h: i32, c: Color) i32 {
DrawRectangle x y w h (ColorToRaylib c)
ret 0
}
fn DrawCircleWrap(x y: i32, r: f32, c: Color) {
fn DrawCircleWrap(x y: i32, r: f32, c: Color) i32 {
DrawCircle x y r (ColorToRaylib c)
ret 0
}
let white :: .Color{255 69 69 69}
let white :: .Color{69 69 69 255}
let red :: .Color{255 0 0 255}
let blue :: .Color{255 0 255 255}
let blue :: .Color{0 0 255 255}
let x := 0
let dir := 1
fn GetScreenHeight i32
InitWindow 640 480 0
SetTargetFPS 30
SetTargetFPS 60
for WindowShouldClose == 0 {
x = x + 1
x = x + dir
if x > GetScreenHeight - 200 {
dir = -1
} elif x < 0 {
dir = 1
}
BeginDrawing
ClearBackgroundWrap white

View File

@ -1,8 +1,5 @@
struct A {
a b: i32,
let inst := 123
for {
inst = inst + 1
}
let inst := .A{ 1 2 }
inst.a as u32