diff --git a/speedcat.pdb b/speedcat.pdb new file mode 100644 index 0000000..7bf97f6 Binary files /dev/null and b/speedcat.pdb differ diff --git a/test_type_checker.cat b/test_type_checker.cat new file mode 100644 index 0000000..6dfcfcf --- /dev/null +++ b/test_type_checker.cat @@ -0,0 +1,9 @@ +let asdf := 123.0 +let poop :: 12 as f32 + 2.0 * asdf + +\fn name {} +fn name(a b: i32) i32 { } +\fn name() {} +\fn name(param1 param2 param3: i32, param4: u32) u32 { } + +name 123 456 diff --git a/type.odin b/type.odin new file mode 100644 index 0000000..2cb2a17 --- /dev/null +++ b/type.odin @@ -0,0 +1,64 @@ +package main + +TypeKind :: enum { + Integer, + Float, + String, +} + +Type :: struct { + kind: TypeKind, + bit_size: u8, + is_signed: bool, +} + +FunctionType :: struct { + name: [dynamic]u8, + return_type: ^Type, + parameter_types: [dynamic]^Type, +} + +compare_types :: proc(a: ^Type, b: ^Type) -> (ret: bool) { + ret = a != nil && b != nil && a.kind == b.kind && a.bit_size == b.bit_size && a.is_signed == b.is_signed + return +} + +compare_function_types :: proc(a: ^FunctionType, b: ^FunctionType) -> (ret: bool) { + ret = a != nil && b != nil && compare_types(a.return_type, b.return_type) + if ret { + for &arg, i in a.parameter_types { + if !compare_types(arg, b.parameter_types[i]) { + ret = false + break + } + } + } + return +} + +function_type_create :: proc() -> (ret: ^FunctionType) { + ret = new(FunctionType) + return +} + +type_create_integer :: proc(bit_size: u8, signed: bool) -> (ret: ^Type) { + ret = new(Type) + ret.kind = .Integer + ret.bit_size = bit_size + ret.is_signed = signed + return +} + +type_create_float :: proc(bit_size: u8) -> (ret: ^Type) { + ret = new(Type) + ret.kind = .Float + ret.bit_size = bit_size + ret.is_signed = true + return +} + +type_create_string :: proc() -> (ret: ^Type) { + ret = new(Type) + ret.kind = .String + return +} diff --git a/type_checker.odin b/type_checker.odin new file mode 100644 index 0000000..2bba899 --- /dev/null +++ b/type_checker.odin @@ -0,0 +1,294 @@ +package main + +import "core:fmt" +import "core:strconv" + +Scope :: struct { + function_definitions: map[int]^FunctionType, // A map to nodes which are the function definitions + variable_definitions: map[int]^Type, // A map to types + variable_mutability_definitions: map[int]bool, // A map to a variable's mutability +} + +@(private = "file") +infer_type :: proc(parent: ^Node, child: ^Node) { + if child.return_type == nil { + #partial switch child.kind { + case .Integer: child.return_type = type_create_integer(32, true) + case .Float: child.return_type = type_create_float(32) + case .String: child.return_type = type_create_string() + case .Character: child.return_type = type_create_integer(32, false) + } + } else { + if parent != nil { + parent.return_type = child.return_type + } + } +} + +@(private = "file") +is_number :: proc(node: ^Node) -> bool { + return node.kind == .Integer || node.kind == .Float +} + +@(private = "file") +ast_to_type :: proc(node: ^Node) -> ^Type { + if node.kind == .Identifier { + value := node.value.([dynamic]u8) + if value[0] == 'u' { + bit_size, ok := strconv.parse_u64_of_base(string(value[1:]), 10) + if !ok { + fmt.panicf("Failed to parse integer: %s", value) + } + return type_create_integer(u8(bit_size), false) + } else if value[0] == 'i' { + bit_size, ok := strconv.parse_u64_of_base(string(value[1:]), 10) + if !ok { + fmt.panicf("Failed to parse integer: %s", value) + } + return type_create_integer(u8(bit_size), true) + } else if value[0] == 'f' { + bit_size, ok := strconv.parse_u64_of_base(string(value[1:]), 10) + if !ok { + fmt.panicf("Failed to parse integer: %s", value) + } + return type_create_float(u8(bit_size)) + } else { + fmt.panicf("Unhandled identifier in ast_to_type: %s", value) + } + } else { + fmt.panicf("Unhandled node kind in ast_to_type: {}", node.kind) + } +} + +scope_stack := [dynamic]Scope {} + +scope_enter :: proc() { + append(&scope_stack, Scope{}) + scope_stack[len(scope_stack) - 1].function_definitions = make(map[int]^FunctionType) + scope_stack[len(scope_stack) - 1].variable_definitions = make(map[int]^Type) +} + +scope_leave :: proc() { + if len(scope_stack) == 0 { + fmt.panicf("Tried to leave scope when there are no scopes") + } + delete(scope_stack[len(scope_stack) - 1].function_definitions) + delete(scope_stack[len(scope_stack) - 1].variable_definitions) + pop(&scope_stack) +} + +scope_variable_lookup :: proc(name: [dynamic]u8) -> ^Type { + name_ := name + for &scope in scope_stack { + type, ok := scope.variable_definitions[get_character_sum_of_dyn_arr(&name_)] + if ok { + return type + } + } + return nil +} + +scope_function_lookup :: proc(name: [dynamic]u8) -> ^FunctionType { + name_ := name + for &scope in scope_stack { + type, ok := scope.function_definitions[get_character_sum_of_dyn_arr(&name_)] + if ok { + return type + } + } + return nil +} + +type_check_function_call :: proc(ast: ^Node, parent_ast: ^Node, must_be_function := true) -> ^FunctionType { + name : [dynamic]u8 + if ast.kind == .FunctionCall { + name = ast.children[0].value.([dynamic]u8) + } else { + name = ast.value.([dynamic]u8) + } + fn := scope_function_lookup(name) + if fn == nil { + if must_be_function { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Undefined function: %s", name), + ast.range, + ), + ) + } + return nil + } + + return fn +} + +type_check :: proc(ast: ^Node, parent_ast: ^Node) { + #partial switch (ast.kind) { + case .Integer: fallthrough + case .Float: fallthrough + case .String: + infer_type(parent_ast, ast) + case .Block: + scope_enter() + functions := find_function_definitions(ast) + for fn, i in functions { + scope_stack[len(scope_stack) - 1].function_definitions[get_character_sum_of_dyn_arr(&fn.name)] = fn + } + for child in ast.children { + type_check(child, ast) + } + scope_leave() + case .FunctionCall: + fn := type_check_function_call(ast, parent_ast) + if fn != nil { + if len(fn.parameter_types) != len(ast.children) - 1 { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Function call parameter count mismatch for function `%s`: {} and {}", fn.name, len(fn.parameter_types), len(ast.children) - 1), + ast.range, + ), + ) + break + } + + for param, i in fn.parameter_types { + type_check(ast.children[i + 1], ast) + if !compare_types(param, ast.children[i + 1].return_type) { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Type mismatch: {} and {}", param, ast.children[i + 1].return_type), + ast.range, + ), + ) + } + } + } + case .Identifier: + type := scope_variable_lookup(ast.value.([dynamic]u8)) + if type == nil { + fn := type_check_function_call(ast, parent_ast, false) + if fn == nil { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Undefined variable: %s", ast.value.([dynamic]u8)), + ast.range, + ), + ) + } else { + ast.return_type = fn.return_type + } + } + ast.return_type = type + case .BinaryExpression: + type_check(ast.children[0], ast) + type_check(ast.children[1], ast) + + if !compare_types(ast.children[0].return_type, ast.children[1].return_type) { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Type mismatch: {} and {}", ast.children[0].return_type, ast.children[1].return_type), + ast.range, + ), + ) + } + + if ast.value_token_kind == .Assign { + if !scope_stack[len(scope_stack) - 1].variable_mutability_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Variable is not mutable: {}", ast.children[0].value.([dynamic]u8)), + ast.range, + ), + ) + } + } + + ast.return_type = ast.children[0].return_type + + // FIXME: Verify that the operation is possible + case .UnaryExpression: + // FIXME: Verify that the operation is possible + type_check(ast.children[0], ast) + case .Cast: + type_check(ast.children[0], ast) + type_to := ast_to_type(ast.children[1]) + // FIXME: Check if compatible + ast.return_type = type_to + case .BitwiseCast: + type_check(ast.children[0], ast) + // FIXME: Check if they are both the same bit size + ast.return_type = ast_to_type(ast.children[1]) + case .VariableDeclaration: + type_check(ast.children[2], ast) + if ast.children[1] == nil { + ast.return_type = ast.children[2].return_type + } + if !compare_types(ast.return_type, ast.children[2].return_type) { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Type mismatch: {} and {}", ast.return_type, ast.children[2].return_type), + ast.range, + ), + ) + } + scope_stack[len(scope_stack) - 1].variable_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] = ast.return_type + case .Function: + // FIXME: Declare variables from params + type_check(ast.children[1], ast) + case: + fmt.panicf("Unhandled node kind in type_check: {}", ast.kind) + } +} + +find_function_definitions :: proc(ast_: ^Node) -> (ret: [dynamic]^FunctionType) { + if ast_.kind != .Block { + return + } + for ast in ast_.children { + if ast == nil { + continue + } + #partial switch (ast.kind) { + case .Function: + for fn in ret { + if compare_dyn_arrs(&fn.name, &ast.value.([dynamic]u8)) { + append(&g_message_list, + message_create( + .Error, + fmt.aprintf("Function already defined: {}", ast.value.([dynamic]u8)), + ast.range, + ), + ) + continue + } + } + fn := function_type_create() + fn.name = ast.value.([dynamic]u8) + return_type : ^Type + if ast.children[0] == nil { + return_type = type_create_integer(0, false) + } else { + return_type = ast_to_type(ast.children[0]) + } + node_print(ast) + for decl, i in ast.children { + if i < 2 { + continue + } + type := ast_to_type(decl.children[1]) + append(&fn.parameter_types, type) + } + fmt.printf("Added: %s\n", fn.name) + append(&ret, fn) + case: + } + } + return +} diff --git a/util.odin b/util.odin index 48319ed..154b550 100644 --- a/util.odin +++ b/util.odin @@ -12,6 +12,18 @@ compare_dyn_arr_string :: proc(a: ^[dynamic]u8, b: string) -> bool { return true } +compare_dyn_arrs :: proc(a: ^[dynamic]u8, b: ^[dynamic]u8) -> bool { + if len(a) != len(b) { + return false + } + for c, i in a { + if c != b[i] { + return false + } + } + return true +} + get_character_sum_of_dyn_arr :: proc(a: ^[dynamic]u8) -> int { sum := 0 for c in a {