package main

import "core:fmt"
import "core:strconv"

// A single named field of a struct definition, as collected by
// find_function_definitions from a `.Struct` node's VariableDeclaration children.
StructField :: struct {
	name:          [dynamic]u8,
	type:          ^Type,
	default_value: ^Node,
}

// Allocates a zero-initialised StructField.
struct_field_create :: proc() -> ^StructField {
	return new(StructField)
}

// A struct definition: its fields in declaration order (the index of a field in
// `fields` is what type checking stores in `Type.struct_index` for field accesses).
Struct :: struct {
	fields: [dynamic]^StructField,
}

// Allocates an empty Struct definition.
struct_create :: proc() -> ^Struct {
	s := new(Struct)
	s.fields = [dynamic]^StructField{}
	return s
}

// Returns the field of `s` whose name equals `name`, or nil if there is none.
struct_find_field :: proc(s: ^Struct, name: [dynamic]u8) -> ^StructField {
	name_ := name
	for field in s.fields {
		if compare_dyn_arrs(&field.name, &name_) {
			return field
		}
	}
	return nil
}

// One lexical scope of the type checker. Every map is keyed by
// get_character_sum_of_dyn_arr(name).
// NOTE(review): if that helper is a plain byte sum, anagram names (e.g. `ab` and
// `ba`) share a key and shadow each other — confirm against its definition.
Scope :: struct {
	function_definitions:            map[int]^FunctionType, // Function signatures visible in this scope
	variable_definitions:            map[int]^Type, // Variable name -> declared/inferred type
	variable_mutability_definitions: map[int]bool, // Variable name -> whether it may be assigned to
	function_return_type:            ^Type, // Return type of the enclosing function (nil outside functions)
	structure_definitions:           map[int]^Struct,
	enum_definitions:                map[int]^EnumValue,
}

// Looks a struct definition up by name, innermost scope first.
// Kept for existing callers; identical to scope_struct_lookup.
find_struct :: proc(name: [dynamic]u8) -> ^Struct {
	return scope_struct_lookup(name)
}

// Assigns a default type to an untyped literal `child` (i32 for integers, f32 for
// floats, an unsized u8 array for strings, u32 for characters). If `child` is
// already typed, its type is propagated up to `parent` (when there is one).
@(private = "file")
infer_type :: proc(parent: ^Node, child: ^Node) {
	if child.return_type == nil {
		#partial switch child.kind {
		case .Integer:
			child.return_type = type_create_integer(32, true)
		case .Float:
			child.return_type = type_create_float(32)
		case .String:
			child.return_type = type_create_array(type_create_integer(8, false), 0)
		case .Character:
			child.return_type = type_create_integer(32, false)
		}
	} else {
		if parent != nil {
			parent.return_type = child.return_type
		}
	}
}

// Reports whether `node` is a numeric literal.
@(private = "file")
is_number :: proc(node: ^Node) -> bool {
	return node.kind == .Integer || node.kind == .Float
}

// Parses a sized primitive type name: `u<N>` (unsigned integer), `i<N>` (signed
// integer) or `f<N>` (float), where N is a decimal bit size that fits in a u8.
// Returns ok = false for anything else, so names that merely start with u/i/f
// (e.g. a struct called `user`) are left for the struct lookup.
@(private = "file")
parse_primitive_type :: proc(name: [dynamic]u8) -> (type_: ^Type, ok: bool) {
	if len(name) < 2 {
		return nil, false
	}
	prefix := name[0]
	if prefix != 'u' && prefix != 'i' && prefix != 'f' {
		return nil, false
	}
	bit_size, parsed := strconv.parse_u64_of_base(string(name[1:]), 10)
	// Bit sizes are stored as u8; reject instead of silently truncating (u256 -> 0).
	if !parsed || bit_size > 255 {
		return nil, false
	}
	switch prefix {
	case 'u':
		return type_create_integer(u8(bit_size), false), true
	case 'i':
		return type_create_integer(u8(bit_size), true), true
	}
	return type_create_float(u8(bit_size)), true
}

// Converts a type-annotation AST node into a ^Type.
// - nil node: the zero-bit integer type (used as "no value"/void).
// - Identifier: a sized primitive (see parse_primitive_type) or a known struct;
//   otherwise an "Unknown type" error is recorded and nil is returned.
// - Pointer / Array: built recursively from children[0] (arrays also read their
//   u64 length from the node value).
// Panics on any other node kind.
@(private = "file")
ast_to_type :: proc(node: ^Node) -> ^Type {
	if node == nil {
		return type_create_integer(0, false)
	}
	if node.kind == .Identifier {
		value := node.value.([dynamic]u8)
		if primitive, is_primitive := parse_primitive_type(value); is_primitive {
			return primitive
		}
		if scope_struct_lookup(value) != nil {
			return type_create_struct(value)
		}
		append(&g_message_list, message_create(.Error, fmt.aprintf("Unknown type: %s", value), node.range))
		return nil
	} else if node.kind == .Pointer {
		return type_create_pointer(ast_to_type(node.children[0]))
	} else if node.kind == .Array {
		return type_create_array(ast_to_type(node.children[0]), node.value.(u64))
	} else {
		fmt.panicf("Unhandled node kind in ast_to_type: {}", node.kind)
	}
}

// Stack of active scopes; the last element is the innermost scope.
scope_stack := [dynamic]Scope{}

// Pushes a new, empty scope.
scope_enter :: proc() {
	append(
		&scope_stack,
		Scope {
			function_definitions = make(map[int]^FunctionType),
			variable_definitions = make(map[int]^Type),
			variable_mutability_definitions = make(map[int]bool),
			function_return_type = nil,
			structure_definitions = make(map[int]^Struct),
			enum_definitions = make(map[int]^EnumValue),
		},
	)
}

// Pops the innermost scope and frees all of its maps (the ^Struct/^Type values
// they point at are not freed). Panics if no scope is active.
scope_leave :: proc() {
	if len(scope_stack) == 0 {
		fmt.panicf("Tried to leave scope when there are no scopes")
	}
	scope := pop(&scope_stack)
	delete(scope.function_definitions)
	delete(scope.variable_definitions)
	delete(scope.variable_mutability_definitions)
	delete(scope.structure_definitions)
	delete(scope.enum_definitions)
}

// Returns the type of the innermost variable called `name`, or nil if none.
scope_variable_lookup :: proc(name: [dynamic]u8) -> ^Type {
	name_ := name
	#reverse for &scope in scope_stack {
		type, ok := scope.variable_definitions[get_character_sum_of_dyn_arr(&name_)]
		if ok {
			return type
		}
	}
	return nil
}

// Returns whether the innermost variable called `name` is mutable; false if the
// variable is not declared in any scope.
scope_variable_lookup_mutable :: proc(name: [dynamic]u8) -> bool {
	name_ := name
	#reverse for &scope in scope_stack {
		is_mutable, ok := scope.variable_mutability_definitions[get_character_sum_of_dyn_arr(&name_)]
		if ok {
			return is_mutable
		}
	}
	return false
}

// Returns the innermost function signature called `name`, or nil if none.
scope_function_lookup :: proc(name: [dynamic]u8) -> ^FunctionType {
	name_ := name
	#reverse for &scope in scope_stack {
		type, ok := scope.function_definitions[get_character_sum_of_dyn_arr(&name_)]
		if ok {
			return type
		}
	}
	return nil
}

// Returns the innermost struct definition called `name`, or nil if none.
scope_struct_lookup :: proc(name: [dynamic]u8) -> ^Struct {
	name_ := name
	#reverse for &scope in scope_stack {
		struct_, ok := scope.structure_definitions[get_character_sum_of_dyn_arr(&name_)]
		if ok {
			return struct_
		}
	}
	return nil
}

// Returns the return type of the innermost enclosing function, or nil when not
// inside a function.
scope_function_return_type_lookup :: proc() -> ^Type {
	#reverse for &scope in scope_stack {
		if scope.function_return_type != nil {
			return scope.function_return_type
		}
	}
	return nil
}

// Resolves the function referenced by `ast`: the callee identifier
// (children[0]) of a FunctionCall, or the node's own name otherwise.
// Returns nil when no such function exists; an "Undefined function" error is
// recorded only if `must_be_function` is true. `parent_ast` is currently unused.
type_check_function_call :: proc(ast: ^Node, parent_ast: ^Node, must_be_function := true) -> ^FunctionType {
	name: [dynamic]u8
	if ast.kind == .FunctionCall {
		name = ast.children[0].value.([dynamic]u8)
	} else {
		name = ast.value.([dynamic]u8)
	}

	fn := scope_function_lookup(name)
	if fn == nil {
		if must_be_function {
			append(&g_message_list, message_create(.Error, fmt.aprintf("Undefined function: %s", name), ast.range))
		}
		return nil
	}

	return fn
}

// Type checks `ast` (and its subtree) in place: fills in `return_type`, inserts
// implicit Cast nodes where compare_types asks for one, rewrites some nodes
// (identifier <-> call, ++/-- into assignments) and appends diagnostics to
// g_message_list. `parent_ast` is the node's parent (may be nil). nil is a no-op.
// Panics on node kinds it does not know.
type_check :: proc(ast: ^Node, parent_ast: ^Node) {
	if ast == nil {
		return
	}

	#partial switch ast.kind {
	case .Integer, .Float, .String:
		infer_type(parent_ast, ast)
	case .Block:
		tc_block(ast)
	case .FieldAccess:
		tc_field_access(ast)
	case .FunctionCall:
		tc_function_call(ast, parent_ast)
	case .Identifier:
		tc_identifier(ast, parent_ast)
	case .BinaryExpression:
		tc_binary_expression(ast)
	case .UnaryExpression:
		tc_unary_expression(ast)
	case .Ret:
		tc_ret(ast)
	case .Cast:
		tc_cast(ast)
	case .BitwiseCast:
		type_check(ast.children[0], ast)
		// FIXME: Check if they are both the same bit size
		append(
			&g_message_list,
			message_create(.FIXME, fmt.aprintf("BitwiseCast bit size check not implemented."), ast.range),
		)
		ast.return_type = ast_to_type(ast.children[1])
	case .VariableDeclaration:
		tc_variable_declaration(ast)
	case .If:
		tc_if(ast)
	case .ExternFunction:
		tc_function(ast, true)
	case .Function:
		tc_function(ast, false)
	case .For:
		tc_for(ast)
	case .Struct, .Enum:
	// Nothing: these are registered by find_function_definitions when the
	// enclosing Block is entered.
	case .StructInitializer:
		tc_struct_initializer(ast)
	case:
		fmt.panicf("Unhandled node kind in type_check: {}", ast.kind)
	}
}

// Block: opens a scope, pre-registers the functions/structs/enums declared
// directly in the block (so they can be used before their definition), then
// checks every child.
@(private = "file")
tc_block :: proc(ast: ^Node) {
	scope_enter()
	functions := find_function_definitions(ast)
	for fn in functions {
		scope_stack[len(scope_stack) - 1].function_definitions[get_character_sum_of_dyn_arr(&fn.name)] = fn
	}
	// The signatures are now owned by the scope map; only the list is freed.
	delete(functions)
	for child in ast.children {
		type_check(child, ast)
	}
	scope_leave()
}

// FieldAccess `lhs.rhs`: lhs is a struct-typed variable or a nested field access,
// rhs must be an identifier naming a field. On success the node's type is the
// field's type, and that type's `struct_index` is set to the field's position.
@(private = "file")
tc_field_access :: proc(ast: ^Node) {
	lhs := ast.children[0]
	rhs := ast.children[1]

	struct_: ^Struct
	if lhs.kind == .Identifier {
		struct_var := scope_variable_lookup(lhs.value.([dynamic]u8))
		if struct_var == nil {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Cannot find struct of name: `%s`", lhs.value.([dynamic]u8)),
					lhs.range,
				),
			)
			return
		}
		if struct_var.kind != .Struct {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Variable `%s` is not a struct", lhs.value.([dynamic]u8)),
					lhs.range,
				),
			)
			return
		}
		struct_ = scope_struct_lookup(struct_var.struct_type.name)
		if struct_ == nil {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Cannot find struct of type name: `%s`", struct_var.struct_type.name),
					lhs.range,
				),
			)
			return
		}
	}

	if rhs.kind != .Identifier {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf("Field access rhs is not an identifier or field access: {}", rhs.kind),
				rhs.range,
			),
		)
		return
	}

	if lhs.kind == .FieldAccess {
		type_check(lhs, ast)
		if lhs.return_type == nil {
			// The nested access has already reported why it failed.
			return
		}
		if lhs.return_type.kind != .Struct {
			append(&g_message_list, message_create(.Error, fmt.aprintf("LHS is not a Struct type"), lhs.range))
			return
		}
		struct_ = scope_struct_lookup(lhs.return_type.struct_type.name)
		if struct_ == nil {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Cannot find struct of type name: `%s`", lhs.return_type.struct_type.name),
					lhs.range,
				),
			)
			return
		}
	} else if lhs.kind != .Identifier {
		append(
			&g_message_list,
			message_create(.Error, fmt.aprintf("Field access lhs is not an identifier or FieldAccess"), lhs.range),
		)
		return
	}

	field_name := rhs.value.([dynamic]u8)
	for field, index in struct_.fields {
		if !compare_dyn_arrs(&field.name, &field_name) {
			continue
		}
		// NOTE: the field's Type object is shared with the struct definition;
		// struct_index is written onto it (always the same index for a field).
		ast.return_type = field.type
		if field.type != nil {
			field.type.struct_index = u64(index)
		}
		return
	}

	append(
		&g_message_list,
		message_create(.Error, fmt.aprintf("Cannot find field of name: `%s`", field_name), rhs.range),
	)
}

// FunctionCall: children[0] is the callee, children[1..] the arguments.
// - A FieldAccess callee replaces the call node with the checked field access.
// - A callee that names a variable turns the node into that Identifier.
// - Otherwise the argument count and types are checked against the signature,
//   inserting casts where compare_types requires one.
@(private = "file")
tc_function_call :: proc(ast: ^Node, parent_ast: ^Node) {
	callee := ast.children[0]
	if callee.kind == .FieldAccess {
		// FIXME: This is some temporary shitfuckery, check if a function is part
		// of a struct or namespace first, then do this shit
		// (the call's arguments are currently dropped here).
		type_check(callee, ast)
		lowered := callee^
		free(callee)
		// The node is overwritten below, so release the old children array.
		delete(ast.children)
		ast^ = lowered
		return
	}

	callee_name := callee.value.([dynamic]u8)
	if var_type := scope_variable_lookup(callee_name); var_type != nil {
		free(callee)
		clear(&ast.children)
		ast.return_type = var_type
		ast.kind = .Identifier
		ast.value = callee_name
		type_check(ast, parent_ast)
		return
	}

	fn := type_check_function_call(ast, parent_ast)
	if fn == nil {
		return
	}

	argument_count := len(ast.children) - 1
	if len(fn.parameter_types) != argument_count {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf(
					"Function call parameter count mismatch for function `%s`: {} and {}",
					fn.name,
					len(fn.parameter_types),
					argument_count,
				),
				ast.range,
			),
		)
		return
	}

	for param, i in fn.parameter_types {
		arg_index := i + 1
		type_check(ast.children[arg_index], ast)
		ok, cast_required := compare_types(param, ast.children[arg_index].return_type)
		if cast_required {
			cast_ := node_create_cast({}, ast.children[arg_index], nil)
			cast_.return_type = param
			ast.children[arg_index] = cast_
		}
		if !ok {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Type mismatch in function call for `%s`: Wanted {}, got {}",
						fn.name,
						type_to_string(param),
						type_to_string(ast.children[arg_index].return_type),
					),
					ast.range,
				),
			)
		}
	}
	ast.return_type = fn.return_type
}

// Identifier: a known variable takes the variable's type; a known function name
// becomes a zero-argument FunctionCall; anything else is turned into a
// NUL-terminated String literal (with a warning).
@(private = "file")
tc_identifier :: proc(ast: ^Node, parent_ast: ^Node) {
	if var_type := scope_variable_lookup(ast.value.([dynamic]u8)); var_type != nil {
		ast.return_type = var_type
		return
	}

	fn := type_check_function_call(ast, parent_ast, false)
	if fn == nil {
		append(&g_message_list, message_create(.Warning, "Variable name treated as string", ast.range))
		ast.kind = .String
		append(&ast.value.([dynamic]u8), 0)
		type_check(ast, parent_ast)
		return
	}

	ast.kind = .FunctionCall
	append(&ast.children, node_create_value(.Identifier, ast.range, ast.value))
	ast.return_type = fn.return_type
	ast.value = nil
}

// BinaryExpression: both operands are checked, the rhs is cast to the lhs type
// when needed, comparisons yield an i1, and assignments require an assignable,
// mutable target (the root variable for field-access targets).
@(private = "file")
tc_binary_expression :: proc(ast: ^Node) {
	type_check(ast.children[0], ast)
	type_check(ast.children[1], ast)

	ok, cast_required := compare_types(ast.children[0].return_type, ast.children[1].return_type)
	if cast_required {
		cast_ := node_create_cast(ast.children[1].range, ast.children[1], nil)
		cast_.return_type = ast.children[0].return_type
		ast.children[1] = cast_
	}
	if !ok {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf(
					"Type mismatch: {} and {}",
					type_to_string(ast.children[0].return_type),
					type_to_string(ast.children[1].return_type),
				),
				ast.range,
			),
		)
	}
	ast.return_type = ast.children[1].return_type

	#partial switch ast.value_token_kind {
	case .Assign:
		target := ast.children[0]
		if target.kind != .Identifier && target.kind != .FieldAccess {
			append(&g_message_list, message_create(.Error, fmt.aprintf("LHS of assignment is invalid"), ast.range))
		}
		// `a.b.c = x` mutates variable `a`; FieldAccess nodes carry no name themselves.
		root := target
		for root.kind == .FieldAccess {
			root = root.children[0]
		}
		if root.kind == .Identifier && !scope_variable_lookup_mutable(root.value.([dynamic]u8)) {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Variable is not mutable: %s", root.value.([dynamic]u8)),
					ast.range,
				),
			)
		}
	case .Equals, .NotEquals, .GreaterThan, .GreaterThanOrEqual, .LessThan, .LessThanOrEqual:
		ast.return_type = type_create_integer(1, true)
	}
	// FIXME: Verify that the operation is possible
}

// UnaryExpression: takes the operand's type. Prefix ++/-- is rewritten into
// `x = x +/- 1` (a BinaryExpression with token kind Assign); postfix ++/--
// (node value true) is currently replaced by the bare operand (FIXME).
@(private = "file")
tc_unary_expression :: proc(ast: ^Node) {
	// FIXME: Verify that the operation is possible
	type_check(ast.children[0], ast)
	append(&g_message_list, message_create(.FIXME, fmt.aprintf("Check type in unary expression"), ast.range))
	ast.return_type = ast.children[0].return_type

	if ast.value_token_kind != .Increment && ast.value_token_kind != .Decrement {
		return
	}

	if ast.value.(bool) {
		ast^ = ast.children[0]^
		append(&g_message_list, message_create(.FIXME, fmt.aprintf("Implement postfix inc/dec"), ast.range))
		return
	}

	ast.kind = .BinaryExpression
	operand := ast.children[0]
	op: ^Node
	if ast.value_token_kind == .Increment {
		op = node_create_binary(.Add, ast.range, operand, node_create_value(.Integer, ast.range, 1))
	} else {
		op = node_create_binary(.Subtract, ast.range, operand, node_create_value(.Integer, ast.range, 1))
	}
	append(&ast.children, op)
	type_check(ast.children[1], ast)
	ast.value_token_kind = .Assign
}

// Ret: must appear inside a function. A missing value is only allowed when the
// function returns the zero-bit (void) integer type; otherwise the value is
// checked against, and cast to, the function's return type.
@(private = "file")
tc_ret :: proc(ast: ^Node) {
	function_return_type := scope_function_return_type_lookup()
	if function_return_type == nil {
		append(
			&g_message_list,
			message_create(.Error, fmt.aprintf("Return statement outside of function"), ast.range),
		)
		return
	}

	returns_void := function_return_type.kind == .Integer && function_return_type.bit_size == 0
	value: ^Node = nil
	if len(ast.children) > 0 {
		value = ast.children[0]
	}
	if value == nil {
		if !returns_void {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("Missing return value: function returns {}", type_to_string(function_return_type)),
					ast.range,
				),
			)
		}
		return
	}

	type_check(value, ast)
	ok, cast_required := compare_types(function_return_type, value.return_type)
	if cast_required {
		cast_ := node_create_cast({}, value, nil)
		cast_.return_type = function_return_type
		ast.children[0] = cast_
	}
	if !ok {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf(
					"Type mismatch: {} and {}",
					type_to_string(function_return_type),
					type_to_string(ast.children[0].return_type),
				),
				ast.range,
			),
		)
	}
}

// Cast: children[0] is the value, children[1] the target type annotation.
// Casts to/from structs are rejected; other casts are not validated yet (FIXME).
@(private = "file")
tc_cast :: proc(ast: ^Node) {
	type_check(ast.children[0], ast)
	type_to := ast_to_type(ast.children[1])
	ast.return_type = type_to

	type_from := ast.children[0].return_type
	if type_from == nil || type_to == nil {
		// Operand or target type failed to resolve; nothing further to check.
		return
	}
	if type_from.kind == .Struct || type_to.kind == .Struct {
		append(&g_message_list, message_create(.Error, "Cannot cast to/from Struct type.", ast.range))
		return
	}
	// FIXME: Check if compatible
	append(
		&g_message_list,
		message_create(.FIXME, fmt.aprintf("Cast to type not checked: {}.", type_to_string(type_to)), ast.range),
	)
}

// VariableDeclaration: children[0] is the name, children[1] the optional type
// annotation, children[2] the optional initializer; the node value is true for
// immutable declarations. An explicit annotation wins over the initializer's
// type (the initializer is cast to it when needed). Redeclaring a name in the
// same scope is an error.
@(private = "file")
tc_variable_declaration :: proc(ast: ^Node) {
	name_sum := get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))
	if name_sum in scope_stack[len(scope_stack) - 1].variable_definitions {
		append(
			&g_message_list,
			message_create(.Error, "A variable is already declared with the same name", ast.range),
		)
		return
	}

	if ast.children[2] != nil {
		type_check(ast.children[2], ast)
		if ast.children[1] == nil {
			ast.return_type = ast.children[2].return_type
		} else {
			ast.return_type = ast_to_type(ast.children[1])
		}

		ok, cast_required := compare_types(ast.return_type, ast.children[2].return_type)
		if cast_required {
			cast_ := node_create_cast({}, ast.children[2], nil)
			cast_.return_type = ast.return_type
			ast.children[2] = cast_
		}
		if !ok {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Type mismatch: {} and {}",
						type_to_string(ast.return_type),
						type_to_string(ast.children[2].return_type),
					),
					ast.range,
				),
			)
			return
		}
	} else {
		ast.return_type = ast_to_type(ast.children[1])
	}

	// Indexed (not cached) because type_check above may have grown scope_stack.
	scope_stack[len(scope_stack) - 1].variable_definitions[name_sum] = ast.return_type
	scope_stack[len(scope_stack) - 1].variable_mutability_definitions[name_sum] = !ast.value.(bool)
}

// If: children[0] is the condition (must be an integer), children[1] the then
// branch, children[2] the optional else branch (an empty block is appended when
// absent).
@(private = "file")
tc_if :: proc(ast: ^Node) {
	type_check(ast.children[0], ast)
	condition_type := ast.children[0].return_type
	if condition_type == nil || condition_type.kind != .Integer {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf("If condition must be a signed/unsigned integer"),
				ast.children[0].range,
			),
		)
		return
	}

	type_check(ast.children[1], ast)
	if len(ast.children) == 3 {
		type_check(ast.children[2], ast)
	} else {
		append(&ast.children, node_create_block(ast.range, {}))
	}
}

// Function / ExternFunction: children[0] is the return type annotation; regular
// functions have their body at children[1] and parameters from children[2],
// extern functions have parameters from children[1]. Parameters are declared
// mutable in a new scope. A void function's body gets an explicit trailing `ret`.
@(private = "file")
tc_function :: proc(ast: ^Node, is_extern: bool) {
	scope_enter()
	ast.return_type = ast_to_type(ast.children[0])
	scope_stack[len(scope_stack) - 1].function_return_type = ast.return_type

	first_param := is_extern ? 1 : 2
	for i in first_param ..< len(ast.children) {
		param := ast.children[i]
		type_check(param, ast)
		param_sum := get_character_sum_of_dyn_arr(&param.children[0].value.([dynamic]u8))
		scope_stack[len(scope_stack) - 1].variable_definitions[param_sum] = param.return_type
		scope_stack[len(scope_stack) - 1].variable_mutability_definitions[param_sum] = true
	}

	if !is_extern {
		type_check(ast.children[1], ast)
		returns_void := ast.return_type != nil && ast.return_type.kind == .Integer && ast.return_type.bit_size == 0
		if returns_void {
			append(&ast.children[1].children, node_create_ret(ast.children[1].range, nil))
		}
	}
	scope_leave()
}

// For: every non-nil child is checked inside a new scope; children[1] is the
// condition and must be an integer, otherwise the remaining children are skipped.
@(private = "file")
tc_for :: proc(ast: ^Node) {
	scope_enter()
	for child, i in ast.children {
		if child == nil {
			continue
		}
		type_check(child, ast)
		if i != 1 {
			continue
		}
		if child.return_type == nil || child.return_type.kind != .Integer {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf("For condition must be a signed/unsigned integer"),
					child.range,
				),
			)
			break
		}
	}
	scope_leave()
}

// StructInitializer: the node value names the struct; its children are the field
// values in declaration order and must match the field count and (after any
// implicit cast) the field types.
@(private = "file")
tc_struct_initializer :: proc(ast: ^Node) {
	for child in ast.children {
		type_check(child, ast)
	}

	struct_name := ast.value.([dynamic]u8)
	struct_ := find_struct(struct_name)
	if struct_ == nil {
		append(
			&g_message_list,
			message_create(.Error, fmt.aprintf("Undefined struct: %s", struct_name), ast.range),
		)
		return
	}

	if len(ast.children) != len(struct_.fields) {
		append(
			&g_message_list,
			message_create(
				.Error,
				fmt.aprintf(
					"Struct initializer field count mismatch: Wanted {}, got {}",
					len(struct_.fields),
					len(ast.children),
				),
				ast.range,
			),
		)
		return
	}

	for &child, idx in ast.children {
		field_type := struct_.fields[idx].type
		ok, cast_required := compare_types(child.return_type, field_type)
		if cast_required {
			cast_ := node_create_cast({}, child, nil)
			cast_.return_type = field_type
			child = cast_
		}
		if !ok {
			append(
				&g_message_list,
				message_create(
					.Error,
					fmt.aprintf(
						"Type mismatch in struct initializer: Wanted {}, got {}",
						type_to_string(field_type),
						type_to_string(child.return_type),
					),
					child.range,
				),
			)
		}
	}

	ast.return_type = type_create_struct(struct_name)
}

// Scans the direct children of a Block node and returns the signatures of the
// functions (and extern functions) it declares. As a side effect, struct and
// enum declarations are registered in the innermost scope. Duplicate function
// names are reported and only the first definition is returned.
// Returns an empty array for non-Block nodes; the caller owns the array.
find_function_definitions :: proc(ast_: ^Node) -> (ret: [dynamic]^FunctionType) {
	if ast_.kind != .Block {
		return
	}

	for ast in ast_.children {
		if ast == nil {
			continue
		}

		#partial switch ast.kind {
		case .ExternFunction, .Function:
			is_extern := ast.kind == .ExternFunction
			name := ast.value.([dynamic]u8)

			already_defined := false
			for fn in ret {
				if compare_dyn_arrs(&fn.name, &name) {
					already_defined = true
					break
				}
			}
			if already_defined {
				append(
					&g_message_list,
					message_create(.Error, fmt.aprintf("Function already defined: %s", name), ast.range),
				)
				continue
			}

			fn := function_type_create()
			fn.name = name
			// ast_to_type(nil) yields the zero-bit (void) integer type.
			fn.return_type = ast_to_type(ast.children[0])

			// Regular: [0] return type, [1] body, [2..] params; extern: [0] return type, [1..] params.
			first_param := is_extern ? 1 : 2
			for i in first_param ..< len(ast.children) {
				append(&fn.parameter_types, ast_to_type(ast.children[i].children[1]))
			}
			append(&ret, fn)

		case .Struct:
			struct_ := struct_create()
			// children[0] is the struct's name; the remaining non-nil children are fields.
			for field, i in ast.children {
				if i == 0 || field == nil {
					continue
				}
				if field.kind != .VariableDeclaration {
					fmt.panicf("Expected VariableDeclaration in struct")
				}
				struct_field := struct_field_create()
				struct_field.name = field.children[0].value.([dynamic]u8)
				struct_field.type = ast_to_type(field.children[1])
				field.return_type = struct_field.type
				append(&struct_.fields, struct_field)
			}
			scope_stack[len(scope_stack) - 1].structure_definitions[get_character_sum_of_dyn_arr(&ast.children[0].value.([dynamic]u8))] = struct_

		case .Enum:
			type_node := ast.enum_value.type
			if type_node.kind != .Identifier {
				append(
					&g_message_list,
					message_create(.Error, "The type of this enum must be a scalar", type_node.range),
				)
			}
			type_ptr := ast_to_type(type_node)
			// nil means ast_to_type already reported an unknown type.
			if type_ptr != nil && type_ptr.kind != .Integer {
				append(
					&g_message_list,
					message_create(.Error, "The type of this enum must be an integer", type_node.range),
				)
			}
			scope_stack[len(scope_stack) - 1].enum_definitions[get_character_sum_of_dyn_arr(&ast.value.([dynamic]u8))] = &ast.enum_value
		case:
		}
	}

	return
}