diff --git a/src/ast.odin b/src/ast.odin index e1c737d..4604658 100644 --- a/src/ast.odin +++ b/src/ast.odin @@ -387,6 +387,7 @@ node_create_struct_enum_or_union :: proc( range = range, value = name, } + append(&ret.children, node_create_value(.Identifier, range, name)) for field in fields { append(&ret.children, field) } diff --git a/src/llvm_emitter.odin b/src/llvm_emitter.odin index 7179680..f38d559 100644 --- a/src/llvm_emitter.odin +++ b/src/llvm_emitter.odin @@ -129,6 +129,16 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil case .Pointer: pointer_of := generate_llvm_type_from_node(ctx, mod, builder, type.pointer_to) return LLVMPointerType(pointer_of, 0) + case .Struct: + struct_type := llvm_scope_find_type(&type.struct_type.name) + if struct_type == nil { + append( + &g_message_list, + message_create(.Error, fmt.aprintf("Struct '%s' not found", type.struct_type.name), {}), + ) + return nil + } + return struct_type } panic("LLVM-IR: Invalid type") } @@ -295,11 +305,96 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil type := llvm_scope_find_type(&node.value.([dynamic]u8)) def_value := LLVMBuildLoad2(builder, type, def, "loadtmp") return def_value + } else if node.kind == .StructInitializer { + struct_name := &node.value.([dynamic]u8) + struct_type := llvm_scope_find_type(struct_name) + if struct_type == nil { + append( + &g_message_list, + message_create(.Error, fmt.aprintf("Struct '%s' not found", struct_name), node.range), + ) + return nil + } + + struct_values := [dynamic]LLVMValueRef{} + for &field, i in node.children[:] { + field_value := generate_llvm_expression(ctx, mod, builder, field) + append(&struct_values, field_value) + } + struct_value := LLVMConstNamedStruct(struct_type, raw_data(struct_values[:]), len(struct_values)) + return struct_value } fmt.panicf("FIXME: Implement other node kinds. Got: {}", node.kind) } + generate_llvm_cast :: proc( + ctx: LLVMContextRef, + mod: LLVMModuleRef, + builder: LLVMBuilderRef, + node: ^Node, + ) -> LLVMValueRef { + value := generate_llvm_expression(ctx, mod, builder, node.children[0]) + if value == nil { + return nil + } + if node.return_type.kind == .Integer { + if node.children[0].return_type.kind == .Float { + if node.return_type.bit_size == 32 { + return LLVMBuildFPToSI(builder, value, LLVMInt32TypeInContext(ctx), "casttmp") + } else if node.return_type.bit_size == 64 { + return LLVMBuildFPToSI(builder, value, LLVMInt64TypeInContext(ctx), "casttmp") + } + } else if node.children[0].return_type.kind == .Integer { + if node.return_type.is_signed && node.children[0].return_type.is_signed { + return LLVMBuildSExt( + builder, + value, + LLVMIntTypeInContext(ctx, uint(node.return_type.bit_size)), + "casttmp", + ) + } else if !node.return_type.is_signed && !node.children[0].return_type.is_signed { + return LLVMBuildZExt( + builder, + value, + LLVMIntTypeInContext(ctx, uint(node.return_type.bit_size)), + "casttmp", + ) + } else { + return LLVMBuildIntCast( + builder, + value, + LLVMIntTypeInContext(ctx, uint(node.return_type.bit_size)), + "casttmp", + ) + } + } + } else if node.return_type.kind == .Float { + if node.children[0].return_type.kind == .Integer { + if node.children[0].return_type.bit_size == 32 { + return LLVMBuildSIToFP(builder, value, LLVMFloatTypeInContext(ctx), "casttmp") + } else if node.children[0].return_type.bit_size == 64 { + return LLVMBuildSIToFP(builder, value, LLVMDoubleTypeInContext(ctx), "casttmp") + } + } else if node.children[0].return_type.kind == .Float { + if node.return_type.bit_size == 32 { + if node.children[0].return_type.bit_size == 64 { + return LLVMBuildFPTrunc(builder, value, LLVMFloatTypeInContext(ctx), "casttmp") + } else if node.children[0].return_type.bit_size == 32 { + return value + } + } else if node.return_type.bit_size == 64 { + if node.children[0].return_type.bit_size == 32 { + return LLVMBuildFPExt(builder, value, LLVMDoubleTypeInContext(ctx), "casttmp") + } else if node.children[0].return_type.bit_size == 64 { + return value + } + } + } + } + fmt.panicf("LLVM-C: Unsupported cast: {} to {}", node.children[0].return_type.kind, node.return_type.kind) + } + generate_llvm_expression :: proc( ctx: LLVMContextRef, mod: LLVMModuleRef, @@ -309,13 +404,45 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil if node.kind == .BinaryExpression { return generate_llvm_binary_expression(ctx, mod, builder, node) } - if node.kind == .Integer || node.kind == .Float || node.kind == .FunctionCall || node.kind == .Identifier { + if node.kind == .Cast { + return generate_llvm_cast(ctx, mod, builder, node) + } + if node.kind == .Integer || + node.kind == .Float || + node.kind == .FunctionCall || + node.kind == .Identifier || + node.kind == .StructInitializer { return generate_llvm_value(ctx, mod, builder, node) } if node.kind == .FunctionCall { return generate_llvm_function_call(ctx, mod, builder, node) } - return nil + fmt.panicf("FIXME: Implement other node kinds. Got: {}", node.kind) + } + + generate_llvm_struct_type :: proc( + ctx: LLVMContextRef, + mod: LLVMModuleRef, + builder: LLVMBuilderRef, + node: ^Node, + ) -> LLVMTypeRef { + struct_name := strings.clone_to_cstring(string(node.children[0].value.([dynamic]u8)[:])) + struct_type := LLVMStructCreateNamed(ctx, struct_name) + + if len(node.children) == 1 { + llvm_top_scope().types[get_character_sum_of_dyn_arr(&node.children[0].value.([dynamic]u8))] = struct_type + return struct_type + } + + struct_body := [dynamic]LLVMTypeRef{} + for &field in node.children[1:] { + field_type := generate_llvm_type_from_node(ctx, mod, builder, field.return_type) + append(&struct_body, field_type) + } + + LLVMStructSetBody(struct_type, raw_data(struct_body[:]), len(struct_body), LLVMBool(node.value.(bool))) + llvm_top_scope().types[get_character_sum_of_dyn_arr(&node.children[0].value.([dynamic]u8))] = struct_type + return struct_type } generate_llvm_binary_expression :: proc( @@ -637,6 +764,8 @@ generate_llvm :: proc(ctx: LLVMContextRef, mod: LLVMModuleRef, builder: LLVMBuil case .Ret: value := generate_llvm_expression(ctx, mod, builder, node.children[0]) LLVMBuildRet(builder, value) + case .Struct: + type := generate_llvm_struct_type(ctx, mod, builder, node) case: fmt.panicf("FIXME: Implement other node kinds. Got: {}", node.kind) } diff --git a/src/llvmc.odin b/src/llvmc.odin index 7fb4601..85045b6 100644 --- a/src/llvmc.odin +++ b/src/llvmc.odin @@ -120,6 +120,11 @@ foreign llvmc { LLVMConstInt :: proc(IntTy: LLVMTypeRef, N: u64, SignExtend: LLVMBool) -> LLVMValueRef --- LLVMConstReal :: proc(RealTy: LLVMTypeRef, N: f64) -> LLVMValueRef --- + LLVMConstNamedStruct :: proc(StructTy: LLVMTypeRef, ConstantVals: [^]LLVMValueRef, Count: uint) -> LLVMValueRef --- + + LLVMStructTypeInContext :: proc(C: LLVMContextRef, ElementTypes: [^]LLVMTypeRef, ElementCount: uint, Packed: LLVMBool) -> LLVMTypeRef --- + LLVMStructCreateNamed :: proc(C: LLVMContextRef, Name: cstring) -> LLVMTypeRef --- + LLVMStructSetBody :: proc(StructTy: LLVMTypeRef, ElementTypes: [^]LLVMTypeRef, ElementCount: uint, Packed: LLVMBool) --- LLVMFunctionType :: proc(ReturnType: LLVMTypeRef, ParamTypes: [^]LLVMTypeRef, ParamCount: uint, IsVarArg: LLVMBool) -> LLVMTypeRef --- LLVMAddFunction :: proc(M: LLVMModuleRef, Name: cstring, FunctionType: LLVMTypeRef) -> LLVMValueRef --- @@ -167,6 +172,15 @@ foreign llvmc { LLVMBuildPhi :: proc(Builder: LLVMBuilderRef, Ty: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildFPToSI :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildSExt :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildZExt :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildIntCast :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildSIToFP :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildFPTrunc :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + LLVMBuildFPExt :: proc(Builder: LLVMBuilderRef, Val: LLVMValueRef, DestTy: LLVMTypeRef, Name: cstring) -> LLVMValueRef --- + + LLVMTypeOf :: proc(Val: LLVMValueRef) -> LLVMTypeRef --- LLVMAddIncoming :: proc(PhiNode: LLVMValueRef, IncomingValues: [^]LLVMValueRef, IncomingBlocks: [^]LLVMBasicBlockRef, Count: uint) --- diff --git a/src/parser.odin b/src/parser.odin index 566108c..fca62d1 100644 --- a/src/parser.odin +++ b/src/parser.odin @@ -62,9 +62,11 @@ parser_parse_block :: proc(parser: ^Parser, end: TokenKind) -> (ret: ^Node) { if accept(parser, .Let) { ret := parser_parse_definitions(parser) expect(parser, .Semicolon) - for stmt in ret { - if stmt != nil { - append(&statements, stmt) + if ret != nil { + for stmt in ret { + if stmt != nil { + append(&statements, stmt) + } } } } else { @@ -138,10 +140,22 @@ parser_parse_struct_definition :: proc(parser: ^Parser) -> ^Node { expect(parser, .Identifier) return nil } + is_packed := false + if parser.tok.kind == .String { + value := parser.tok.value.([dynamic]u8) + if compare_dyn_arr_string(&value, "packed") { + is_packed = true + } else { + panic("TODO, unknown struct attribute") + } + parser_next(parser) + } expect(parser, .OpenBrace) fields := parser_parse_definitions(parser, .CloseBrace) expect(parser, .CloseBrace) - return node_create_struct_enum_or_union(range, .Struct, name, fields) + n := node_create_struct_enum_or_union(range, .Struct, name, fields) + n.value = is_packed + return n } @(private = "file") @@ -269,8 +283,20 @@ parser_parse_definitions :: proc(parser: ^Parser, end := TokenKind.Semicolon) -> type: ^Node = nil are_constants := false uninitialized := false + if parser.tok.kind != .Identifier { + append( + &g_message_list, + message_create(.Error, "Built-in keywords cannot be used for variable names", parser.tok.range), + ) + for parser.tok.kind != end && parser.tok.kind != .EOF { + parser_next(parser) + } + return nil + } + for parser.tok.kind != end && parser.tok.kind != .EOF { names := [dynamic][dynamic]u8{} + for parser.tok.kind == .Identifier { tok := parser.tok if !expect(parser, .Identifier) { diff --git a/src/type.odin b/src/type.odin index c10091a..6405023 100644 --- a/src/type.odin +++ b/src/type.odin @@ -1,86 +1,111 @@ package main TypeKind :: enum { - Integer, - Float, - Pointer, - Array, + Integer, + Float, + Pointer, + Array, + Struct, +} + +StructType :: struct { + name: [dynamic]u8, } Type :: struct { - kind: TypeKind, - bit_size: u8, - is_signed: bool, - pointer_to: ^Type, - array_of: ^Type, - array_size: u64, + kind: TypeKind, + bit_size: u8, + is_signed: bool, + pointer_to: ^Type, + array_of: ^Type, + array_size: u64, + struct_type: ^StructType, } FunctionType :: struct { - name: [dynamic]u8, - return_type: ^Type, - parameter_types: [dynamic]^Type, + name: [dynamic]u8, + return_type: ^Type, + parameter_types: [dynamic]^Type, } -compare_types :: proc(a: ^Type, b: ^Type) -> (ret: bool) { - ret = a != nil && b != nil && a.kind == b.kind && a.bit_size == b.bit_size && a.is_signed == b.is_signed - if ret == false { - return - } +compare_types :: proc(a: ^Type, b: ^Type) -> (ret: bool, cast_required: bool) { + cast_required = false - if a.kind == .Pointer { - ret = compare_types(a.pointer_to, b.pointer_to) - } else if a.kind == .Array { - ret = a.array_size == b.array_size && compare_types(a.array_of, b.array_of) - } - return + if (a.kind == .Integer || a.kind == .Float) && (a.bit_size > b.bit_size) { + ret = true + cast_required = true + return + } + + ret = a != nil && b != nil && a.kind == b.kind && a.bit_size == b.bit_size && a.is_signed == b.is_signed + if ret == false { + return + } + + if a.kind == .Pointer { + ret, _ = compare_types(a.pointer_to, b.pointer_to) + } else if a.kind == .Array { + ret, _ = compare_types(a.array_of, b.array_of) + ret = a.array_size == b.array_size && ret + } + return } compare_function_types :: proc(a: ^FunctionType, b: ^FunctionType) -> (ret: bool) { - ret = a != nil && b != nil && compare_types(a.return_type, b.return_type) - if ret { - for &arg, i in a.parameter_types { - if !compare_types(arg, b.parameter_types[i]) { - ret = false - break - } - } - } - return + ok, cast_ := compare_types(a.return_type, b.return_type) + ret = a != nil && b != nil && (ok && !cast_) + if ret { + for &arg, i in a.parameter_types { + ok, cast_ := compare_types(arg, b.parameter_types[i]) + if !(ok && !cast_) { + ret = false + break + } + } + } + return } function_type_create :: proc() -> (ret: ^FunctionType) { - ret = new(FunctionType) - return + ret = new(FunctionType) + return } type_create_integer :: proc(bit_size: u8, signed: bool) -> (ret: ^Type) { - ret = new(Type) - ret.kind = .Integer - ret.bit_size = bit_size - ret.is_signed = signed - return + ret = new(Type) + ret.kind = .Integer + ret.bit_size = bit_size + ret.is_signed = signed + return } type_create_float :: proc(bit_size: u8) -> (ret: ^Type) { - ret = new(Type) - ret.kind = .Float - ret.bit_size = bit_size - ret.is_signed = true - return + ret = new(Type) + ret.kind = .Float + ret.bit_size = bit_size + ret.is_signed = true + return } type_create_pointer :: proc(to: ^Type) -> (ret: ^Type) { - ret = new(Type) - ret.kind = .Pointer - ret.pointer_to = to - return + ret = new(Type) + ret.kind = .Pointer + ret.pointer_to = to + return } type_create_array :: proc(of: ^Type, size: u64) -> (ret: ^Type) { - ret = new(Type) - ret.kind = .Array - ret.array_of = of - ret.array_size = size - return + ret = new(Type) + ret.kind = .Array + ret.array_of = of + ret.array_size = size + return +} + +type_create_struct :: proc(name: [dynamic]u8) -> (ret: ^Type) { + ret = new(Type) + ret.kind = .Struct + ret.struct_type = new(StructType) + ret.struct_type.name = name + return } diff --git a/src/type_checker.odin b/src/type_checker.odin index d30f136..c595d05 100644 --- a/src/type_checker.odin +++ b/src/type_checker.odin @@ -103,6 +103,10 @@ ast_to_type :: proc(node: ^Node) -> ^Type { } return type_create_float(u8(bit_size)) } else { + res := scope_struct_lookup(value) + if res != nil { + return type_create_struct(value) + } fmt.panicf("Unhandled identifier in ast_to_type: %s", value) } } else if node.kind == .Pointer { @@ -166,6 +170,17 @@ scope_function_lookup :: proc(name: [dynamic]u8) -> ^FunctionType { return nil } +scope_struct_lookup :: proc(name: [dynamic]u8) -> ^Struct { + name_ := name + #reverse for &scope in scope_stack { + struct_, ok := scope.structure_definitions[get_character_sum_of_dyn_arr(&name_)] + if ok { + return struct_ + } + } + return nil +} + scope_function_return_type_lookup :: proc() -> ^Type { #reverse for &scope in scope_stack { if scope.function_return_type != nil { @@ -248,7 +263,13 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { for param, i in fn.parameter_types { type_check(ast.children[i + 1], ast) - if !compare_types(param, ast.children[i + 1].return_type) { + ok, cast_required := compare_types(param, ast.children[i + 1].return_type) + if cast_required { + cast_ := node_create_cast({}, ast.children[i + 1], nil) + cast_.return_type = param + ast.children[i + 1] = cast_ + } + if !ok { append( &g_message_list, message_create( @@ -289,7 +310,13 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { type_check(ast.children[0], ast) type_check(ast.children[1], ast) - if !compare_types(ast.children[0].return_type, ast.children[1].return_type) { + ok, cast_required := compare_types(ast.children[0].return_type, ast.children[1].return_type) + if cast_required { + cast_ := node_create_cast(ast.children[1].range, ast.children[1], nil) + cast_.return_type = ast.children[0].return_type + ast.children[1] = cast_ + } + if !ok { append( &g_message_list, message_create( @@ -337,7 +364,13 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { ) } else { type_check(ast.children[0], ast) - if !compare_types(function_return_type, ast.children[0].return_type) { + ok, cast_required := compare_types(function_return_type, ast.children[0].return_type) + if cast_required { + cast_ := node_create_cast({}, ast.children[0], nil) + cast_.return_type = function_return_type + ast.children[0] = cast_ + } + if !ok { append( &g_message_list, message_create( @@ -351,6 +384,14 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { case .Cast: type_check(ast.children[0], ast) type_to := ast_to_type(ast.children[1]) + append( + &g_message_list, + message_create( + .Warning, + fmt.aprintf("Cast to type not checked: {}", ast.children[1].value.([dynamic]u8)), + ast.children[1].range, + ), + ) // FIXME: Check if compatible ast.return_type = type_to case .BitwiseCast: @@ -363,7 +404,13 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { if ast.children[1] == nil { ast.return_type = ast.children[2].return_type } - if !compare_types(ast.return_type, ast.children[2].return_type) { + ok, cast_required := compare_types(ast.return_type, ast.children[2].return_type) + if cast_required { + cast_ := node_create_cast({}, ast.children[2], nil) + cast_.return_type = ast.return_type + ast.children[2] = cast_ + } + if !ok { append( &g_message_list, message_create( @@ -486,8 +533,14 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { } idx := 0 - for child in ast.children { - if compare_types(child.return_type, struct_.fields[idx].type) == false { + for &child in ast.children { + ok, cast_required := compare_types(child.return_type, struct_.fields[idx].type) + if cast_required { + cast_ := node_create_cast({}, child, {}) + cast_.return_type = struct_.fields[idx].type + child = cast_ + } + if !ok { append( &g_message_list, message_create( @@ -503,6 +556,8 @@ type_check :: proc(ast: ^Node, parent_ast: ^Node) { } idx += 1 } + + ast.return_type = type_create_struct(ast.value.([dynamic]u8)) case: fmt.panicf("Unhandled node kind in type_check: {}", ast.kind) } @@ -561,19 +616,25 @@ find_function_definitions :: proc(ast_: ^Node) -> (ret: [dynamic]^FunctionType) case .Struct: fmt.printf("Struct found\n") struct_ := struct_create() + should_ignore := true for field in ast.children { if field == nil { continue } + if should_ignore { + should_ignore = false + continue + } if field.kind != .VariableDeclaration { fmt.panicf("Expected VariableDeclaration in struct") } struct_field := struct_field_create() struct_field.name = field.children[0].value.([dynamic]u8) struct_field.type = ast_to_type(field.children[1]) + field.return_type = struct_field.type append(&struct_.fields, struct_field) } - scope_stack[len(scope_stack) - 1].structure_definitions[get_character_sum_of_dyn_arr(&ast.value.([dynamic]u8))] = + scope_stack[len(scope_stack) - 1].structure_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] = struct_ case: } diff --git a/test_type_checker.cat b/test_type_checker.cat index ca234eb..520eec8 100644 --- a/test_type_checker.cat +++ b/test_type_checker.cat @@ -39,31 +39,33 @@ \} \ \(meow (add 60 9)) -> meow -\ -\fn InitWindow(w h: i32, title: i32) -\fn CloseWindow -\fn ClearBackground(c: i32) -\fn BeginDrawing -\fn EndDrawing -\fn DrawFPS(x y: i32) -\fn DrawRectangle(x y w h c: i32) -\fn DrawCircle(x y: i32, r: f32, c: i32) -\fn WindowShouldClose i32 -\ -\InitWindow 640 480 0 -\for WindowShouldClose == 0 { -\ BeginDrawing -\ ClearBackground 0 -\ DrawFPS 20 20 -\ DrawRectangle 80 80 100 200 4294967295 -\ DrawCircle 90 90 100.0 4278255615 -\ EndDrawing -\} -\CloseWindow -\ -struct MyStruct { - a b c: i32, + +struct Color "packed" { + r g b a: u8, } -.MyStruct{ 1 2 3 } +fn InitWindow(w h: i32, title: i32) +fn CloseWindow +fn ClearBackground(c: Color) +fn BeginDrawing +fn EndDrawing +fn DrawFPS(x y: i32) +fn DrawRectangle(x y w h: i32, c: Color) +fn DrawCircle(x y: i32, r: f32, c: Color) +fn WindowShouldClose i32 + +let white :: .Color{255 255 255 255} +let red :: .Color{255 0 0 255} +let blue :: .Color{0 0 255 255} + +InitWindow 640 480 0 +for WindowShouldClose == 0 { + BeginDrawing + ClearBackground white + DrawFPS 20 20 + \DrawRectangle 80 80 100 200 red + \DrawCircle 90 90 100.0 blue + EndDrawing +} +CloseWindow