speedcat/type_checker.odin
Slendi 7fe927a683 Add type checking for functions
Signed-off-by: Slendi <slendi@socopon.com>
2024-02-28 15:07:06 +02:00

295 lines
9.8 KiB
Odin

package main
import "core:fmt"
import "core:strconv"
// Scope holds the definitions introduced at one nesting level of scope_stack.
// Every map is keyed by get_character_sum_of_dyn_arr() of the defined name.
Scope :: struct {
function_definitions: map[int]^FunctionType, // Signature (FunctionType) of each function declared at this level
variable_definitions: map[int]^Type, // Type of each variable declared at this level
variable_mutability_definitions: map[int]bool, // Whether each variable may be assigned to; NOTE(review): nothing visible in this file writes to it yet
}
// infer_type gives an untyped literal node its default type.
//
// If `child` has no return_type yet, it is assigned one based on its kind:
// .Integer -> signed 32-bit integer, .Float -> 32-bit float, .String -> string,
// .Character -> unsigned 32-bit integer. Other kinds stay untyped, and
// `parent` is not touched in this case.
//
// If `child` is already typed, that type is copied onto `parent` (when
// `parent` is non-nil) and `child` is left unchanged.
@(private = "file")
infer_type :: proc(parent: ^Node, child: ^Node) {
	if child.return_type != nil {
		// Already typed: pass the type up instead of re-inferring it.
		if parent != nil {
			parent.return_type = child.return_type
		}
		return
	}

	#partial switch child.kind {
	case .Integer:
		child.return_type = type_create_integer(32, true)
	case .Float:
		child.return_type = type_create_float(32)
	case .String:
		child.return_type = type_create_string()
	case .Character:
		child.return_type = type_create_integer(32, false)
	}
}
// is_number reports whether `node` is an integer or float literal node.
@(private = "file")
is_number :: proc(node: ^Node) -> bool {
	#partial switch node.kind {
	case .Integer, .Float:
		return true
	}
	return false
}
// ast_to_type converts a type-name node into a ^Type.
//
// Only .Identifier nodes of the form <prefix><bits> are understood:
//   u<bits> -> unsigned integer, i<bits> -> signed integer, f<bits> -> float.
//
// Panics (fmt.panicf) in these cases:
//   - a nil node;
//   - any other node kind;
//   - an empty or unrecognised identifier;
//   - a bit size that does not parse as a base-10 integer;
//   - a bit size that does not fit in the u8 the type constructors take.
@(private = "file")
ast_to_type :: proc(node: ^Node) -> ^Type {
	if node == nil {
		fmt.panicf("ast_to_type called with a nil node")
	}
	if node.kind != .Identifier {
		fmt.panicf("Unhandled node kind in ast_to_type: {}", node.kind)
	}

	value := node.value.([dynamic]u8)
	// The length check guards the value[0] accesses; an empty name is not a type either.
	if len(value) == 0 || (value[0] != 'u' && value[0] != 'i' && value[0] != 'f') {
		fmt.panicf("Unhandled identifier in ast_to_type: %s", value)
	}

	bit_size, ok := strconv.parse_u64_of_base(string(value[1:]), 10)
	if !ok {
		fmt.panicf("Failed to parse integer: %s", value)
	}
	// The constructors take a u8; without this check e.g. `i256` would silently become a 0-bit type.
	if bit_size > 255 {
		fmt.panicf("Bit size out of range in ast_to_type: %s", value)
	}

	switch value[0] {
	case 'u':
		return type_create_integer(u8(bit_size), false)
	case 'i':
		return type_create_integer(u8(bit_size), true)
	}
	// Only 'f' remains after the prefix check above.
	return type_create_float(u8(bit_size))
}
// scope_stack is the stack of lexical scopes used while type checking.
// scope_enter appends a scope and scope_leave pops one, so the last element is the innermost scope.
scope_stack: [dynamic]Scope
// scope_enter pushes a new, empty scope onto scope_stack.
// The function and variable maps are allocated here.
// variable_mutability_definitions is left zero-valued.
// Every call should be paired with scope_leave.
scope_enter :: proc() {
	scope := Scope{
		function_definitions = make(map[int]^FunctionType),
		variable_definitions = make(map[int]^Type),
	}
	append(&scope_stack, scope)
}
// scope_leave pops the innermost scope off scope_stack and frees all of its maps.
// Panics if there is no scope to leave.
scope_leave :: proc() {
	if len(scope_stack) == 0 {
		fmt.panicf("Tried to leave scope when there are no scopes")
	}
	scope := pop(&scope_stack)
	delete(scope.function_definitions)
	delete(scope.variable_definitions)
	// Previously never freed; deleting a zero-valued map is harmless.
	delete(scope.variable_mutability_definitions)
}
// scope_variable_lookup returns the type recorded for variable `name`, or nil if no scope defines it.
//
// The search starts at the innermost scope (end of scope_stack) and moves outwards.
// This lets an inner declaration shadow an outer one with the same name.
//
// NOTE(review): definitions are keyed by get_character_sum_of_dyn_arr(name).
// If that is a plain sum of bytes, different names with the same sum
// (e.g. "ab" and "ba") collide. Confirm.
scope_variable_lookup :: proc(name: [dynamic]u8) -> ^Type {
	name_ := name // parameters are not addressable in Odin; the helper takes a pointer
	key := get_character_sum_of_dyn_arr(&name_) // same for every scope, so compute it once
	for i := len(scope_stack) - 1; i >= 0; i -= 1 {
		if found, ok := scope_stack[i].variable_definitions[key]; ok {
			return found
		}
	}
	return nil
}
// scope_function_lookup returns the signature recorded for function `name`, or nil if no scope defines it.
//
// The search starts at the innermost scope (end of scope_stack) and moves outwards.
// This lets a function declared in an inner block shadow an outer one.
//
// NOTE(review): keyed by get_character_sum_of_dyn_arr(name). Names with an
// equal key (possible if that is a plain byte sum) collide. Confirm.
scope_function_lookup :: proc(name: [dynamic]u8) -> ^FunctionType {
	name_ := name // parameters are not addressable in Odin; the helper takes a pointer
	key := get_character_sum_of_dyn_arr(&name_) // same for every scope, so compute it once
	for i := len(scope_stack) - 1; i >= 0; i -= 1 {
		if found, ok := scope_stack[i].function_definitions[key]; ok {
			return found
		}
	}
	return nil
}
// type_check_function_call resolves the function that `ast` refers to.
//
// For a .FunctionCall node the name comes from the callee in children[0].
// For any other node it comes from the node's own value.
//
// Returns the function's FunctionType, or nil if no enclosing scope defines it.
// On a miss, an "Undefined function" error is appended to g_message_list,
// but only when `must_be_function` is true.
// `parent_ast` is currently unused.
type_check_function_call :: proc(ast: ^Node, parent_ast: ^Node, must_be_function := true) -> ^FunctionType {
	name_node := ast
	if ast.kind == .FunctionCall {
		name_node = ast.children[0]
	}
	name := name_node.value.([dynamic]u8)

	found := scope_function_lookup(name)
	if found == nil && must_be_function {
		append(&g_message_list,
			message_create(
				.Error,
				fmt.aprintf("Undefined function: %s", name),
				ast.range,
			),
		)
	}
	return found
}
// type_check walks `ast` and sets `return_type` on the nodes it understands.
// It appends an .Error message to g_message_list for each type problem it finds.
//
// `parent_ast` is the node containing `ast`. It is forwarded to infer_type
// for literals and to type_check_function_call.
// Each .Block opens a new scope for the functions and variables declared in it.
//
// Panics (fmt.panicf) on node kinds that are not handled yet.
type_check :: proc(ast: ^Node, parent_ast: ^Node) {
	#partial switch ast.kind {
	case .Integer, .Float, .String, .Character:
		// Literals take the default type infer_type assigns to their kind.
		infer_type(parent_ast, ast)
	case .Block:
		scope_enter()
		// Register every function declared directly in this block before checking
		// any statement, so a call may appear before the definition.
		functions := find_function_definitions(ast)
		for fn in functions {
			scope_stack[len(scope_stack) - 1].function_definitions[get_character_sum_of_dyn_arr(&fn.name)] = fn
		}
		// Only the temporary list is freed; the FunctionType pointers now live in the scope map.
		delete(functions)
		for child in ast.children {
			// Blocks may contain nil children (find_function_definitions skips them too).
			if child == nil {
				continue
			}
			type_check(child, ast)
		}
		scope_leave()
	case .FunctionCall:
		// children[0] names the callee; children[1:] are the arguments.
		fn := type_check_function_call(ast, parent_ast)
		if fn == nil {
			// type_check_function_call has already reported the undefined function.
			break
		}
		argument_count := len(ast.children) - 1
		if len(fn.parameter_types) != argument_count {
			append(&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Function call parameter count mismatch for function `%s`: {} and {}", fn.name, len(fn.parameter_types), argument_count),
					ast.range,
				),
			)
		} else {
			for param, i in fn.parameter_types {
				argument := ast.children[i + 1]
				type_check(argument, ast)
				if !compare_types(param, argument.return_type) {
					append(&g_message_list,
						message_create(
							.Error,
							fmt.aprintf("Type mismatch: {} and {}", param, argument.return_type),
							ast.range,
						),
					)
				}
			}
		}
		// The call evaluates to the callee's return type, as a bare function name does in .Identifier.
		// It is set after the arguments because infer_type may write an already-typed literal's type onto its parent.
		ast.return_type = fn.return_type
	case .Identifier:
		name := ast.value.([dynamic]u8)
		if variable_type := scope_variable_lookup(name); variable_type != nil {
			ast.return_type = variable_type
			break
		}
		// Not a variable: a function name evaluates to that function's return type.
		fn := type_check_function_call(ast, parent_ast, false)
		if fn == nil {
			append(&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Undefined variable: %s", name),
					ast.range,
				),
			)
			ast.return_type = nil
		} else {
			ast.return_type = fn.return_type
		}
	case .BinaryExpression:
		lhs := ast.children[0]
		rhs := ast.children[1]
		type_check(lhs, ast)
		type_check(rhs, ast)
		if !compare_types(lhs.return_type, rhs.return_type) {
			append(&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Type mismatch: {} and {}", lhs.return_type, rhs.return_type),
					ast.range,
				),
			)
		}
		if ast.value_token_kind == .Assign {
			// NOTE(review): nothing visible in this file writes variable_mutability_definitions.
			// Only the innermost scope is consulted, and a missing key reads as false.
			// As written, every assignment is therefore reported as not mutable.
			// Confirm where mutability is meant to be recorded.
			if !scope_stack[len(scope_stack) - 1].variable_mutability_definitions[get_character_sum_of_dyn_arr(&lhs.value.([dynamic]u8))] {
				append(&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Variable is not mutable: %s", lhs.value.([dynamic]u8)),
						ast.range,
					),
				)
			}
		}
		// FIXME: Verify that the operation is possible
		ast.return_type = lhs.return_type
	case .UnaryExpression:
		// FIXME: Verify that the operation is possible
		type_check(ast.children[0], ast)
		// Like .BinaryExpression, take the operand's type so parents do not see nil.
		ast.return_type = ast.children[0].return_type
	case .Cast:
		type_check(ast.children[0], ast)
		type_to := ast_to_type(ast.children[1])
		// FIXME: Check if compatible
		ast.return_type = type_to
	case .BitwiseCast:
		type_check(ast.children[0], ast)
		// FIXME: Check if they are both the same bit size
		ast.return_type = ast_to_type(ast.children[1])
	case .VariableDeclaration:
		// children[0]: name, children[1]: optional type annotation (nil = infer), children[2]: initializer.
		annotation := ast.children[1]
		initializer := ast.children[2]
		if annotation != nil && ast.return_type == nil {
			// Resolve the annotation (as find_function_definitions does for parameter types).
			// Otherwise the initializer would be compared against nil.
			ast.return_type = ast_to_type(annotation)
		}
		type_check(initializer, ast)
		if annotation == nil {
			ast.return_type = initializer.return_type
		}
		if !compare_types(ast.return_type, initializer.return_type) {
			append(&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Type mismatch: {} and {}", ast.return_type, initializer.return_type),
					ast.range,
				),
			)
		}
		scope_stack[len(scope_stack) - 1].variable_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] = ast.return_type
	case .Function:
		// FIXME: Declare variables from params
		type_check(ast.children[1], ast)
	case:
		fmt.panicf("Unhandled node kind in type_check: {}", ast.kind)
	}
}
// find_function_definitions builds a FunctionType for every direct child of `ast_` that is a .Function node.
// It does not recurse into nested blocks.
// Returns an empty array when `ast_` is not a .Block. The caller owns the returned array.
//
// A function whose name matches one already collected from this block is
// reported as "Function already defined" and skipped, so the first definition wins.
//
// Layout of a .Function node, as read here:
//   value        - the function name ([dynamic]u8)
//   children[0]  - return type node, or nil
//                  (nil becomes type_create_integer(0, false), presumably "void" - TODO confirm)
//   children[1]  - not inspected here
//   children[2:] - parameter declarations; decl.children[1] is each parameter's type node
find_function_definitions :: proc(ast_: ^Node) -> (ret: [dynamic]^FunctionType) {
	if ast_.kind != .Block {
		return
	}
	for ast in ast_.children {
		if ast == nil || ast.kind != .Function {
			continue
		}

		// Report and skip duplicates.
		// A bare `continue` inside the inner loop would only advance that loop,
		// so a flag is used to skip the outer iteration.
		already_defined := false
		for fn in ret {
			if compare_dyn_arrs(&fn.name, &ast.value.([dynamic]u8)) {
				append(&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Function already defined: %s", ast.value.([dynamic]u8)),
						ast.range,
					),
				)
				already_defined = true
				break
			}
		}
		if already_defined {
			continue
		}

		fn := function_type_create()
		fn.name = ast.value.([dynamic]u8)
		// Store the return type on the signature; it was previously computed and then dropped.
		if ast.children[0] == nil {
			fn.return_type = type_create_integer(0, false)
		} else {
			fn.return_type = ast_to_type(ast.children[0])
		}
		for i := 2; i < len(ast.children); i += 1 {
			decl := ast.children[i]
			append(&fn.parameter_types, ast_to_type(decl.children[1]))
		}
		append(&ret, fn)
	}
	return
}