speedcat/src/type_checker.odin
Slendi 5af9845f93 Add command line arguments and make time printing nicer
This patch adds the following arguments:

1. `--dump-ast` or `-d`: This dumps the Abstract Syntax Tree
2. `--dont-emit-llvm` or `-L`: This skips the LLVM generation step, so only
   parsing and type checking are performed. Useful for debugging.

Besides this, the formatting of the time each compiler step took has also
been improved and is now easier to read.

Signed-off-by: Slendi <slendi@socopon.com>
2024-05-04 13:27:25 +02:00

830 lines
22 KiB
Odin

package main
import "core:fmt"
import "core:strconv"
// A single field of a user-defined struct, built by find_function_definitions.
StructField :: struct {
name: [dynamic]u8, // Field name, taken from the field declaration's name node.
type: ^Type, // Resolved field type (from ast_to_type); nil if the type name was unknown.
default_value: ^Node, // NOTE(review): neither assigned nor read anywhere in this file.
}
// struct_field_create heap-allocates a zero-initialised StructField and returns it.
struct_field_create :: proc() -> ^StructField {
	field := new(StructField)
	return field
}
// A user-defined struct type: its fields in declaration order and their total size.
Struct :: struct {
fields: [dynamic]^StructField, // Declaration order; a field's index here is used as its struct_index.
bit_size: u64, // Sum of type_get_bit_size over the fields, accumulated in find_function_definitions.
}
// struct_create returns a freshly heap-allocated Struct with no fields and a size of zero bits.
struct_create :: proc() -> ^Struct {
	result := new(Struct)
	result^ = Struct{
		fields   = [dynamic]^StructField{},
		bit_size = 0,
	}
	return result
}
// struct_find_field returns the first field of `s` (in declaration order) whose name
// compare_dyn_arrs considers equal to `name`, or nil when there is none.
struct_find_field :: proc(s: ^Struct, name: [dynamic]u8) -> ^StructField {
	wanted := name
	for idx := 0; idx < len(s.fields); idx += 1 {
		candidate := s.fields[idx]
		if compare_dyn_arrs(&candidate.name, &wanted) {
			return candidate
		}
	}
	return nil
}
// One lexical scope of the type checker. Every table is keyed by
// get_character_sum_of_dyn_arr(name) rather than by the name itself.
// NOTE(review): if that helper is a plain byte sum (as its name suggests), distinct
// names with the same sum (e.g. anagrams such as `ab` / `ba`) collide here — confirm.
Scope :: struct {
function_definitions: map[int]^FunctionType, // Function signatures (FunctionType values, not AST nodes) visible in this scope
variable_definitions: map[int]^Type, // Declared type of each variable
variable_mutability_definitions: map[int]bool, // true when the variable may be assigned to
function_return_type: ^Type, // Return type of the function this scope belongs to; nil for non-function scopes
structure_definitions: map[int]^Struct, // Structs declared in this scope (not created by scope_enter; nil until first insert)
enum_definitions: map[int]^EnumValue, // Points at the enum_value stored inside the Enum AST node (nil until first insert)
}
// find_struct resolves a struct type by name, searching from the innermost scope outwards,
// and returns nil when no enclosing scope defines it.
//
// It is kept as an alias of scope_struct_lookup for existing callers. Delegating instead of
// duplicating that loop keeps the two lookups from drifting apart.
find_struct :: proc(name: [dynamic]u8) -> ^Struct {
	return scope_struct_lookup(name)
}
// infer_type gives literal nodes a default type the first time they are visited:
// Integer -> signed 32-bit, Float -> 32-bit float, String -> array of unsigned 8-bit
// with length 0, Character -> unsigned 32-bit. Other kinds without a type are left alone.
// If `child` already has a type, that type is instead copied onto `parent` (when non-nil).
@(private = "file")
infer_type :: proc(parent: ^Node, child: ^Node) {
	if child.return_type != nil {
		if parent != nil {
			parent.return_type = child.return_type
		}
		return
	}
	#partial switch child.kind {
	case .Character:
		child.return_type = type_create_integer(32, false)
	case .String:
		child.return_type = type_create_array(type_create_integer(8, false), 0)
	case .Float:
		child.return_type = type_create_float(32)
	case .Integer:
		child.return_type = type_create_integer(32, true)
	}
}
// is_number reports whether `node` is a numeric literal (Integer or Float).
@(private = "file")
is_number :: proc(node: ^Node) -> bool {
	#partial switch node.kind {
	case .Integer, .Float:
		return true
	}
	return false
}
// ast_to_type converts a type expression from the AST into a ^Type.
//
// - nil node -> 0-bit unsigned integer (what the checker treats as "no value").
// - Identifier `uN` / `iN` / `fN`, where N is a decimal number -> unsigned integer,
//   signed integer or float of N bits.
// - any other Identifier -> the struct of that name, found through the scope stack.
// - Pointer / Array -> built recursively from children[0]; an Array's length is node.value (u64).
//
// Unknown type names and bit sizes that do not fit in a u8 are reported to
// g_message_list, and nil is returned. Panics on node kinds it does not handle.
@(private = "file")
ast_to_type :: proc(node: ^Node) -> ^Type {
	if node == nil {
		return type_create_integer(0, false)
	}
	#partial switch node.kind {
	case .Identifier:
		value := node.value.([dynamic]u8)
		// Only treat the name as a primitive when everything after the prefix is a decimal
		// number. User types such as `foo` or `item` start with the same letters and must
		// fall through to the struct lookup instead of failing to parse as a bit size.
		if len(value) > 1 && (value[0] == 'u' || value[0] == 'i' || value[0] == 'f') {
			if bit_size, ok := strconv.parse_u64_of_base(string(value[1:]), 10); ok {
				// The type constructors take a u8 size; larger values would silently wrap.
				if bit_size > 255 {
					append(
						&g_message_list,
						message_create(.Error, fmt.aprintf("Unsupported bit size in type: %s", value), node.range),
					)
					return nil
				}
				if value[0] == 'u' {
					return type_create_integer(u8(bit_size), false)
				} else if value[0] == 'i' {
					return type_create_integer(u8(bit_size), true)
				}
				return type_create_float(u8(bit_size))
			}
		}
		res := scope_struct_lookup(value)
		if res != nil {
			return type_create_struct(value, res.bit_size)
		}
		append(&g_message_list, message_create(.Error, fmt.aprintf("Unknown type: %s", value), node.range))
		return nil
	case .Pointer:
		return type_create_pointer(ast_to_type(node.children[0]))
	case .Array:
		return type_create_array(ast_to_type(node.children[0]), node.value.(u64))
	}
	fmt.panicf("Unhandled node kind in ast_to_type: {}", node.kind)
}
// Stack of active scopes; the last element is the innermost one. Pushed and popped by
// scope_enter / scope_leave, and searched innermost-first by the scope_*_lookup procs.
scope_stack := [dynamic]Scope{}
// scope_enter pushes a new innermost scope. Its function, variable and mutability tables
// are freshly made, and it has no enclosing-function return type. The struct and enum
// tables start out nil.
scope_enter :: proc() {
	fresh := Scope{
		function_definitions            = make(map[int]^FunctionType),
		variable_definitions            = make(map[int]^Type),
		variable_mutability_definitions = make(map[int]bool),
		function_return_type            = nil,
	}
	append(&scope_stack, fresh)
}
// scope_leave pops the innermost scope and frees every lookup table it owns.
// Panics if there is no scope to leave.
scope_leave :: proc() {
	if len(scope_stack) == 0 {
		fmt.panicf("Tried to leave scope when there are no scopes")
	}
	scope := pop(&scope_stack)
	// Free all five maps, not just two, so none of them leaks. Only the map storage is
	// released; the FunctionType/Type/Struct values the maps point at may still be
	// referenced from AST nodes (e.g. via return_type) and are left alone.
	// Deleting a map that was never created (nil) is a no-op.
	delete(scope.function_definitions)
	delete(scope.variable_definitions)
	delete(scope.variable_mutability_definitions)
	delete(scope.structure_definitions)
	delete(scope.enum_definitions)
}
// scope_variable_lookup returns the declared type of the innermost visible variable
// called `name`, or nil if no active scope declares it.
scope_variable_lookup :: proc(name: [dynamic]u8) -> ^Type {
	key_name := name
	for depth := len(scope_stack) - 1; depth >= 0; depth -= 1 {
		if found, present := scope_stack[depth].variable_definitions[get_character_sum_of_dyn_arr(&key_name)]; present {
			return found
		}
	}
	return nil
}
// scope_variable_lookup_mutable reports whether the innermost visible variable called
// `name` may be assigned to. Unknown variables are reported as not mutable (false).
scope_variable_lookup_mutable :: proc(name: [dynamic]u8) -> bool {
	key_name := name
	for depth := len(scope_stack) - 1; depth >= 0; depth -= 1 {
		if is_mutable, declared := scope_stack[depth].variable_mutability_definitions[get_character_sum_of_dyn_arr(&key_name)]; declared {
			return is_mutable
		}
	}
	return false
}
// scope_function_lookup returns the signature of the innermost visible function called
// `name`, or nil if none of the active scopes defines it.
scope_function_lookup :: proc(name: [dynamic]u8) -> ^FunctionType {
	key_name := name
	for depth := len(scope_stack) - 1; depth >= 0; depth -= 1 {
		scope := &scope_stack[depth]
		signature, present := scope.function_definitions[get_character_sum_of_dyn_arr(&key_name)]
		if present {
			return signature
		}
	}
	return nil
}
// scope_struct_lookup returns the innermost visible struct definition called `name`,
// or nil if none of the active scopes declares it.
scope_struct_lookup :: proc(name: [dynamic]u8) -> ^Struct {
	key_name := name
	for depth := len(scope_stack) - 1; depth >= 0; depth -= 1 {
		scope := &scope_stack[depth]
		definition, present := scope.structure_definitions[get_character_sum_of_dyn_arr(&key_name)]
		if present {
			return definition
		}
	}
	return nil
}
// scope_function_return_type_lookup returns the return type of the innermost enclosing
// function: the first non-nil function_return_type when walking scopes inside-out.
// Returns nil outside any function.
scope_function_return_type_lookup :: proc() -> ^Type {
	for depth := len(scope_stack) - 1; depth >= 0; depth -= 1 {
		declared := scope_stack[depth].function_return_type
		if declared != nil {
			return declared
		}
	}
	return nil
}
// type_check_function_call resolves the callee of `ast` to a FunctionType through the
// scope stack. For a FunctionCall node the callee name is children[0].value; for any other
// node (e.g. a bare Identifier) it is the node's own value.
// Returns nil when no function is found, and then reports "Undefined function" only if
// `must_be_function` is set. `parent_ast` is currently unused.
type_check_function_call :: proc(ast: ^Node, parent_ast: ^Node, must_be_function := true) -> ^FunctionType {
	callee_name: [dynamic]u8
	#partial switch ast.kind {
	case .FunctionCall:
		callee_name = ast.children[0].value.([dynamic]u8)
	case:
		callee_name = ast.value.([dynamic]u8)
	}
	resolved := scope_function_lookup(callee_name)
	if resolved == nil && must_be_function {
		append(&g_message_list, message_create(.Error, fmt.aprintf("Undefined function: %s", callee_name), ast.range))
	}
	return resolved
}
// type_check checks `ast` (a child of `parent_ast`, which is nil at the root). It stores
// each node's resolved type in its `return_type` and appends diagnostics to g_message_list.
//
// Besides checking, it rewrites the tree in place:
// - implicit conversions become explicit Cast nodes around function arguments, the RHS of
//   binary expressions, return values, variable initializers and struct initializer fields;
// - a FunctionCall whose callee names a variable becomes a plain Identifier, and an
//   Identifier naming a function becomes a FunctionCall with no arguments;
// - an unknown Identifier becomes a NUL-terminated String (with a warning);
// - prefix `++` / `--` is lowered to an assignment of `x + 1` / `x - 1`;
// - an `if` without an `else` gets an empty else Block;
// - a function returning the 0-bit "void" type gets a trailing `ret` appended to its body.
//
// Block, Function/ExternFunction and For nodes each get their own scope.
type_check :: proc(ast: ^Node, parent_ast: ^Node) {
	in_extern := false
	if ast == nil {
		return
	}
	#partial switch ast.kind {
	case .Integer, .Float, .String, .Character:
		// Literals get a default type (infer_type also knows about Character literals).
		infer_type(parent_ast, ast)
	case .Block:
		scope_enter()
		// Register the functions declared directly in this block before checking any
		// statement, so calls may precede definitions. find_function_definitions also
		// registers the block's structs and enums into the scope just entered.
		functions := find_function_definitions(ast)
		for fn in functions {
			scope_stack[len(scope_stack) - 1].function_definitions[get_character_sum_of_dyn_arr(&fn.name)] = fn
		}
		// The signatures now live in the scope's map; only the temporary array is freed.
		delete(functions)
		for child in ast.children {
			type_check(child, ast)
		}
		scope_leave()
	case .FieldAccess:
		lhs := ast.children[0]
		rhs := ast.children[1]
		// Reject unsupported LHS kinds before anything reads lhs.value as a name;
		// that type assertion would panic for other node kinds.
		if lhs.kind != .Identifier && lhs.kind != .FieldAccess {
			append(
				&g_message_list,
				message_create(.Error, fmt.aprintf("Field access lhs is not an identifier or FieldAccess"), lhs.range),
			)
			break
		}
		struct_: ^Struct
		if lhs.kind == .Identifier {
			struct_var := scope_variable_lookup(lhs.value.([dynamic]u8))
			if struct_var == nil {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Cannot find struct of name: `%s`", lhs.value.([dynamic]u8)),
						lhs.range,
					),
				)
				break
			}
			// Only struct-typed variables have a struct_type to resolve.
			if struct_var.kind == .Struct {
				struct_ = scope_struct_lookup(struct_var.struct_type.name)
			}
			if struct_ == nil {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Cannot find struct of type name: `%s`", lhs.value.([dynamic]u8)),
						lhs.range,
					),
				)
				break
			}
		}
		if rhs.kind != .Identifier {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Field access rhs is not an identifier or field access: {}", rhs.kind),
					rhs.range,
				),
			)
			break
		}
		if lhs.kind == .FieldAccess {
			// Nested access such as `a.b.c`: check the inner access first, then continue
			// from the struct type it produced. Stop here if it did not produce a struct.
			type_check(lhs, ast)
			if lhs.return_type == nil || lhs.return_type.kind != .Struct {
				append(&g_message_list, message_create(.Error, fmt.aprintf("LHS is not a Struct type"), lhs.range))
				break
			}
			struct_ = scope_struct_lookup(lhs.return_type.struct_type.name)
			if struct_ == nil {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Cannot find struct of type name: `%s`", lhs.return_type.struct_type.name),
						lhs.range,
					),
				)
				break
			}
		}
		struct_index: u64 = 0
		found_field := false
		for &field, i in struct_.fields {
			if compare_dyn_arrs(&field.name, &rhs.value.([dynamic]u8)) {
				ast.return_type = field.type
				found_field = true
				struct_index = u64(i)
				break
			}
		}
		if !found_field {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Cannot find field of name: `%s`", rhs.value.([dynamic]u8)),
					rhs.range,
				),
			)
			break
		}
		// A field whose type name was unknown has a nil type (already reported).
		if ast.return_type != nil {
			ast.return_type.struct_index = struct_index
		}
	case .FunctionCall:
		if ast.children[0].kind == .FieldAccess {
			// FIXME: Temporary hack. Check whether the callee is a member of a struct or a
			// namespace first; for now the call node is simply replaced by the field access.
			type_check(ast.children[0], ast)
			child := ast.children[0]^
			free(ast.children[0])
			clear(&ast.children)
			ast^ = child
			return
		}
		// `name(...)` where `name` is a variable: turn the node into a plain variable reference.
		type := scope_variable_lookup(ast.children[0].value.([dynamic]u8))
		if type != nil {
			name := ast.children[0].value.([dynamic]u8)
			free(ast.children[0])
			clear(&ast.children)
			ast.return_type = type
			ast.kind = .Identifier
			ast.value = name
			type_check(ast, parent_ast)
		} else {
			fn := type_check_function_call(ast, parent_ast)
			if fn != nil {
				// children[0] is the callee; the arguments follow it.
				if len(fn.parameter_types) != len(ast.children) - 1 {
					append(
						&g_message_list,
						message_create(
							.Error,
							fmt.aprintf(
								"Function call parameter count mismatch for function `%s`: {} and {}",
								fn.name,
								len(fn.parameter_types),
								len(ast.children) - 1,
							),
							ast.range,
						),
					)
					break
				}
				for param, i in fn.parameter_types {
					type_check(ast.children[i + 1], ast)
					ok, cast_required := compare_types(param, ast.children[i + 1].return_type)
					if cast_required {
						cast_ := node_create_cast({}, ast.children[i + 1], nil)
						cast_.return_type = param
						ast.children[i + 1] = cast_
					}
					if !ok {
						append(
							&g_message_list,
							message_create(
								.Error,
								fmt.aprintf(
									"Type mismatch in function call for `%s`: Wanted {}, got {}",
									fn.name,
									type_to_string(param),
									type_to_string(ast.children[i + 1].return_type),
								),
								ast.range,
							),
						)
					}
				}
				ast.return_type = fn.return_type
			}
		}
	case .Identifier:
		type := scope_variable_lookup(ast.value.([dynamic]u8))
		if type == nil {
			fn := type_check_function_call(ast, parent_ast, false)
			if fn == nil {
				// Unknown name: fall back to treating it as a NUL-terminated string literal.
				append(&g_message_list, message_create(.Warning, "Variable name treated as string", ast.range))
				ast.kind = .String
				append(&ast.value.([dynamic]u8), 0)
				type_check(ast, parent_ast)
			} else {
				// A bare function name is a call with no arguments.
				ast.kind = .FunctionCall
				append(&ast.children, node_create_value(.Identifier, ast.range, ast.value))
				ast.return_type = fn.return_type
				ast.value = nil
			}
		} else {
			ast.return_type = type
		}
	case .BinaryExpression:
		type_check(ast.children[0], ast)
		type_check(ast.children[1], ast)
		ok, cast_required := compare_types(ast.children[0].return_type, ast.children[1].return_type)
		if cast_required {
			cast_ := node_create_cast(ast.children[1].range, ast.children[1], nil)
			cast_.return_type = ast.children[0].return_type
			ast.children[1] = cast_
		}
		if !ok {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Type mismatch: {} and {}",
						type_to_string(ast.children[0].return_type),
						type_to_string(ast.children[1].return_type),
					),
					ast.range,
				),
			)
		}
		ast.return_type = ast.children[1].return_type
		if ast.value_token_kind == .Assign {
			target := ast.children[0]
			if target.kind != .Identifier && target.kind != .FieldAccess {
				append(&g_message_list, message_create(.Error, fmt.aprintf("LHS of assignment is invalid"), ast.range))
			} else {
				// Mutability belongs to the variable at the root of a field-access chain
				// (`a` in `a.b.c = x`). Only a node whose value is a name can be looked
				// up, so asserting the value type on a FieldAccess node is avoided.
				root := target
				for root.kind == .FieldAccess {
					root = root.children[0]
				}
				if name, is_name := root.value.([dynamic]u8); is_name && !scope_variable_lookup_mutable(name) {
					append(
						&g_message_list,
						message_create(.Error, fmt.aprintf("Variable is not mutable: %s", name), ast.range),
					)
				}
			}
		} else if ast.value_token_kind == .Equals ||
		   ast.value_token_kind == .NotEquals ||
		   ast.value_token_kind == .GreaterThan ||
		   ast.value_token_kind == .GreaterThanOrEqual ||
		   ast.value_token_kind == .LessThan ||
		   ast.value_token_kind == .LessThanOrEqual {
			// Comparisons yield a 1-bit integer.
			ast.return_type = type_create_integer(1, true)
		}
	// FIXME: Verify that the operation is possible
	case .UnaryExpression:
		// FIXME: Verify that the operation is possible
		type_check(ast.children[0], ast)
		append(&g_message_list, message_create(.FIXME, fmt.aprintf("Check type in unary expression"), ast.range))
		ast.return_type = ast.children[0].return_type
		if ast.value_token_kind == .Increment || ast.value_token_kind == .Decrement {
			// ast.value is true for the postfix form, which is not lowered yet.
			if ast.value.(bool) {
				ast^ = ast.children[0]^
				append(&g_message_list, message_create(.FIXME, fmt.aprintf("Implement postfix inc/dec"), ast.range))
			} else {
				// Prefix form: rewrite `++x` / `--x` into `x = x + 1` / `x = x - 1`.
				ast.kind = .BinaryExpression
				var := ast.children[0]
				op: ^Node
				if ast.value_token_kind == .Increment {
					op = node_create_binary(.Add, ast.range, var, node_create_value(.Integer, ast.range, 1))
				} else {
					op = node_create_binary(.Subtract, ast.range, var, node_create_value(.Integer, ast.range, 1))
				}
				append(&ast.children, op)
				type_check(ast.children[1], ast)
				ast.value_token_kind = .Assign
			}
		}
	case .Ret:
		function_return_type := scope_function_return_type_lookup()
		if function_return_type == nil {
			append(
				&g_message_list,
				message_create(.Error, fmt.aprintf("Return statement outside of function"), ast.range),
			)
			break
		}
		returns_void := function_return_type.kind == .Integer && function_return_type.bit_size == 0
		has_value := len(ast.children) > 0 && ast.children[0] != nil
		if !has_value {
			// A bare `ret` is only valid in a void function. Report it otherwise instead of
			// dereferencing the missing value below.
			if !returns_void {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Missing return value, expected {}", type_to_string(function_return_type)),
						ast.range,
					),
				)
			}
			break
		}
		type_check(ast.children[0], ast)
		ok, cast_required := compare_types(function_return_type, ast.children[0].return_type)
		if cast_required {
			cast_ := node_create_cast({}, ast.children[0], nil)
			cast_.return_type = function_return_type
			ast.children[0] = cast_
		}
		if !ok {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Type mismatch: {} and {}",
						type_to_string(function_return_type),
						type_to_string(ast.children[0].return_type),
					),
					ast.range,
				),
			)
		}
	case .Cast:
		type_check(ast.children[0], ast)
		type_to := ast_to_type(ast.children[1])
		type_from := ast.children[0].return_type
		// A nil type on either side normally means an error was already reported for it
		// (unknown target type, or an operand that failed to check).
		if type_to != nil && type_from != nil {
			if type_from.kind == .Struct || type_to.kind == .Struct {
				append(&g_message_list, message_create(.Error, "Cannot cast to/from Struct type.", ast.range))
			} else {
				// FIXME: Check if compatible
				// type_to_string is used because pointer/array targets have no name value.
				append(
					&g_message_list,
					message_create(
						.FIXME,
						fmt.aprintf("Cast to type not checked: %s.", type_to_string(type_to)),
						ast.range,
					),
				)
			}
		}
		ast.return_type = type_to
	case .BitwiseCast:
		type_check(ast.children[0], ast)
		// FIXME: Check if they are both the same bit size
		append(
			&g_message_list,
			message_create(.FIXME, fmt.aprintf("BitwiseCast bit size check not implemented."), ast.range),
		)
		ast.return_type = ast_to_type(ast.children[1])
	case .VariableDeclaration:
		// children: [0] name, [1] declared type (may be nil), [2] initializer (may be nil).
		name_sum := get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))
		if name_sum in scope_stack[len(scope_stack) - 1].variable_definitions {
			append(
				&g_message_list,
				message_create(.Error, "A variable is already declared with the same name", ast.range),
			)
			return
		}
		if ast.children[2] != nil {
			type_check(ast.children[2], ast)
			if ast.children[1] == nil {
				// No annotation: the variable takes the initializer's type.
				ast.return_type = ast.children[2].return_type
			} else {
				// Annotated: the initializer is checked (and cast) against the declared type.
				// This is resolved after checking the initializer so infer_type cannot
				// overwrite it through the parent link.
				ast.return_type = ast_to_type(ast.children[1])
			}
			init_type := ast.children[2].return_type
			// A nil type on either side normally means an error was already reported.
			if ast.return_type != nil && init_type != nil {
				ok, cast_required := compare_types(ast.return_type, init_type)
				if cast_required {
					cast_ := node_create_cast({}, ast.children[2], nil)
					cast_.return_type = ast.return_type
					ast.children[2] = cast_
				}
				if !ok {
					append(
						&g_message_list,
						message_create(
							.Error,
							fmt.aprintf(
								"Type mismatch: {} and {}",
								type_to_string(ast.return_type),
								type_to_string(init_type),
							),
							ast.range,
						),
					)
					return
				}
			}
		} else {
			ast.return_type = ast_to_type(ast.children[1])
		}
		scope_stack[len(scope_stack) - 1].variable_definitions[name_sum] = ast.return_type
		// ast.value is true for immutable declarations.
		scope_stack[len(scope_stack) - 1].variable_mutability_definitions[name_sum] = !ast.value.(bool)
	case .If:
		type_check(ast.children[0], ast)
		if ast.children[0].return_type == nil || ast.children[0].return_type.kind != .Integer {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("If condition must be a signed/unsigned integer"),
					ast.children[0].range,
				),
			)
			break
		}
		type_check(ast.children[1], ast)
		if len(ast.children) == 3 {
			type_check(ast.children[2], ast)
		} else {
			// Normalise: every `if` ends up with an else branch.
			append(&ast.children, node_create_block(ast.range, {}))
		}
	case .ExternFunction:
		in_extern = true
		fallthrough
	case .Function:
		scope_enter()
		ast.return_type = ast_to_type(ast.children[0])
		scope_stack[len(scope_stack) - 1].function_return_type = ast.return_type
		// children[0] is the return type and, for a non-extern Function, children[1] is the
		// body; every later child is a parameter declaration.
		first_param := 1 if in_extern else 2
		for child, i in ast.children {
			if i < first_param {
				continue
			}
			type_check(child, ast)
			param_sum := get_character_sum_of_dyn_arr(&child.children[0].value.([dynamic]u8))
			scope_stack[len(scope_stack) - 1].variable_definitions[param_sum] = child.return_type
			// Parameters are always assignable inside the body.
			scope_stack[len(scope_stack) - 1].variable_mutability_definitions[param_sum] = true
		}
		if !in_extern {
			type_check(ast.children[1], ast)
			if ast.return_type != nil && ast.return_type.kind == .Integer && ast.return_type.bit_size == 0 {
				append(&ast.children[1].children, node_create_ret(ast.children[1].range, nil))
			}
		}
		scope_leave()
	case .For:
		scope_enter()
		for child, i in ast.children {
			if child == nil {
				continue
			}
			if i == 1 {
				// children[1] is the loop condition.
				type_check(child, ast)
				should_error := false
				if child.return_type == nil {
					should_error = true
				} else if child.return_type.kind != .Integer {
					should_error = true
				}
				if should_error {
					append(
						&g_message_list,
						message_create(
							.Error,
							fmt.aprintf("For condition must be a signed/unsigned integer"),
							child.range,
						),
					)
					break
				}
			} else {
				type_check(child, ast)
			}
		}
		scope_leave()
	case .Struct: // Registered by find_function_definitions when the enclosing Block is checked.
	case .Enum: // Registered by find_function_definitions when the enclosing Block is checked.
	case .StructInitializer:
		for child in ast.children {
			type_check(child, ast)
		}
		struct_ := find_struct(ast.value.([dynamic]u8))
		if struct_ == nil {
			append(
				&g_message_list,
				message_create(.Error, fmt.aprintf("Undefined struct: %s", ast.value.([dynamic]u8)), ast.range),
			)
			break
		}
		if len(ast.children) != len(struct_.fields) {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Struct initializer field count mismatch: Wanted {}, got {}",
						len(struct_.fields),
						len(ast.children),
					),
					ast.range,
				),
			)
			break
		}
		// Initializer values are positional: the n-th child initialises the n-th field.
		for &child, idx in ast.children {
			field_type := struct_.fields[idx].type
			// Expected (field) type first, matching every other compare_types call site.
			ok, cast_required := compare_types(field_type, child.return_type)
			if cast_required {
				cast_ := node_create_cast({}, child, {})
				cast_.return_type = field_type
				child = cast_
			}
			if !ok {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf(
							"Type mismatch in struct initializer: Wanted {}, got {}",
							type_to_string(field_type),
							type_to_string(child.return_type),
						),
						child.range,
					),
				)
			}
		}
		ast.return_type = type_create_struct(ast.value.([dynamic]u8), struct_.bit_size)
	case:
		fmt.panicf("Unhandled node kind in type_check: {}", ast.kind)
	}
}
// find_function_definitions pre-scans the direct children of a Block node so that
// declarations can be referenced before they appear in source order.
//
// - Function / ExternFunction children become FunctionType signatures in the returned
//   array. A second definition with an already-seen name is reported and skipped.
// - Struct children are resolved into a Struct and registered directly in the innermost
//   scope's structure_definitions.
// - Enum children have their backing type validated and are registered in the innermost
//   scope's enum_definitions.
//
// Returns an empty array for non-Block nodes. Struct/Enum registration indexes the last
// element of scope_stack, so at least one scope must be active.
find_function_definitions :: proc(ast_: ^Node) -> (ret: [dynamic]^FunctionType) {
	if ast_.kind != .Block {
		return
	}
	for ast in ast_.children {
		if ast == nil {
			continue
		}
		is_extern := false
		#partial switch ast.kind {
		case .ExternFunction:
			is_extern = true
			fallthrough
		case .Function:
			// Look for an earlier definition with the same name. The `continue` below must
			// skip this whole child, so the inner search only sets a flag and breaks.
			already_defined := false
			for fn in ret {
				if compare_dyn_arrs(&fn.name, &ast.value.([dynamic]u8)) {
					already_defined = true
					break
				}
			}
			if already_defined {
				append(
					&g_message_list,
					message_create(
						.Error,
						fmt.aprintf("Function already defined: %s", ast.value.([dynamic]u8)),
						ast.range,
					),
				)
				// Keep the first definition; the duplicate is not registered.
				continue
			}
			fn := function_type_create()
			fn.name = ast.value.([dynamic]u8)
			// No declared return type means a 0-bit integer (no value).
			if ast.children[0] == nil {
				fn.return_type = type_create_integer(0, false)
			} else {
				fn.return_type = ast_to_type(ast.children[0])
			}
			// children[0] is the return type and, for a non-extern Function, children[1] is
			// the body; every later child is a parameter declaration whose children[1] is
			// its type.
			first_param := 1 if is_extern else 2
			for decl, i in ast.children {
				if i < first_param {
					continue
				}
				append(&fn.parameter_types, ast_to_type(decl.children[1]))
			}
			append(&ret, fn)
		case .Struct:
			struct_ := struct_create()
			// The first non-nil child is the struct's name (also read below as
			// children[0]); every later non-nil child must be a field declaration.
			should_ignore := true
			for field in ast.children {
				if field == nil {
					continue
				}
				if should_ignore {
					should_ignore = false
					continue
				}
				if field.kind != .VariableDeclaration {
					fmt.panicf("Expected VariableDeclaration in struct")
				}
				struct_field := struct_field_create()
				struct_field.name = field.children[0].value.([dynamic]u8)
				struct_field.type = ast_to_type(field.children[1])
				// An unknown field type was already reported by ast_to_type (nil type);
				// don't ask for the size of a nil type.
				if struct_field.type != nil {
					struct_.bit_size += type_get_bit_size(struct_field.type)
				}
				field.return_type = struct_field.type
				append(&struct_.fields, struct_field)
			}
			scope_stack[len(scope_stack) - 1].structure_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] =
				struct_
		case .Enum:
			// The backing type must be a plain integer type name such as `u8` or `i32`.
			if ast.enum_value.type.kind != .Identifier {
				append(
					&g_message_list,
					message_create(.Error, "The type of this enum must be a scalar", ast.enum_value.type.range),
				)
			} else {
				type_ptr := ast_to_type(ast.enum_value.type)
				// nil means ast_to_type already reported the type name as unknown.
				if type_ptr != nil && type_ptr.kind != .Integer {
					append(
						&g_message_list,
						message_create(.Error, "The type of this enum must be an integer", ast.enum_value.type.range),
					)
				}
			}
			scope_stack[len(scope_stack) - 1].enum_definitions[get_character_sum_of_dyn_arr(&ast.value.([dynamic]u8))] =
				&ast.enum_value
		case:
		}
	}
	return
}