From 7ccc3e32bae4f7978aa604d0bcb0ef5fd95846f0 Mon Sep 17 00:00:00 2001 From: night0721 <night@night0721.xyz> Date: Tue, 21 Jan 2025 03:25:17 +0000 Subject: [PATCH] Return statements --- include/ast.h | 4 ++- include/env.h | 2 +- src/env.c | 3 ++ src/interpreter.c | 78 ++++++++++++++++++++++++++++++++----------- src/parser.c | 84 ++++++++++++++++++++++++++++++++--------------- 5 files changed, 122 insertions(+), 49 deletions(-) diff --git a/include/ast.h b/include/ast.h index be33c53..8a90486 100644 --- a/include/ast.h +++ b/include/ast.h @@ -51,6 +51,7 @@ typedef enum { STMT_PRINT, STMT_VAR, STMT_WHILE, + STMT_RETURN, } stmt_type_t; typedef enum { @@ -96,8 +97,9 @@ typedef struct ht_t ht_t; struct fn_t { fn_type_t type; int arity; + ht_t *env; stmt_t *stmt; - value_t *(*call)(stmt_t *stmt, val_array_t *arguments, ht_t *env); + value_t *(*call)(struct fn_t *stmt, val_array_t *arguments, ht_t *env); }; struct expr_t { diff --git a/include/env.h b/include/env.h index 25bcaa3..7063564 100644 --- a/include/env.h +++ b/include/env.h @@ -10,7 +10,7 @@ typedef struct ht_t { struct ht_t *enclosing; } ht_t; -#define DEFAULT_HT_SIZE 50 +#define DEFAULT_HT_SIZE 500 ht_t *ht_init(ht_t *env); void ht_add(ht_t *ht, char *name, value_t *value); diff --git a/src/env.c b/src/env.c index 9af008f..e7d957c 100644 --- a/src/env.c +++ b/src/env.c @@ -57,6 +57,9 @@ void ht_add(ht_t *ht, char *name, value_t *value) value_t *ht_get(ht_t *ht, token_t *name, int check_enclosing) { + if (!ht) { + return NULL; + } unsigned int idx = hash(name->value) % DEFAULT_HT_SIZE; /* Linear probing to search for the key */ for (int i = 0; i < DEFAULT_HT_SIZE; i++) { diff --git a/src/interpreter.c b/src/interpreter.c index cdb8fea..ac1dad4 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -10,7 +10,12 @@ #include "lexer.h" #include "parser.h" -void evaluate_statement(stmt_t *stmt, ht_t *env); +typedef struct { + int has_returned; + value_t *value; +} return_state_t; + +void evaluate_statement(stmt_t *stmt, ht_t *env, return_state_t *state); void free_val(value_t *value) { @@ -30,6 +35,9 @@ value_t *visit_literal(expr_t *expr) { value_t *val = malloc(sizeof(value_t)); memcpy(val, expr->as.literal.value, sizeof(value_t)); + if (val->type == VAL_STRING) { + val->as.string = strdup(expr->as.literal.value->as.string); + } return val; } @@ -312,7 +320,7 @@ value_t *visit_call(expr_t *expr, ht_t *env) snprintf(err, 512, "Expected %d arguments but got %d.", callee->as.function->arity, arguments->length); runtime_error(err, expr->line); } - value_t *res = callee->as.function->call(callee->as.function->stmt, arguments, env); + value_t *res = callee->as.function->call(callee->as.function, arguments, callee->as.function->env); free_vals(arguments); free_val(callee); return res; @@ -350,6 +358,11 @@ value_t *evaluate(expr_t *expr, ht_t *env) void print_value(value_t *value) { + if (!value) { + printf("nil\n"); + return; + } + switch (value->type) { case VAL_BOOL: printf("%s\n", value->as.boolean == 1 ? "true" : "false"); @@ -384,23 +397,23 @@ void print_value(value_t *value) } } -void evaluate_statements(stmt_array_t *array, ht_t *env) +void evaluate_statements(stmt_array_t *array, ht_t *env, return_state_t *state) { for (int i = 0; i < array->length; i++) { - evaluate_statement(array->statements[i], env); + evaluate_statement(array->statements[i], env, state); } } -void evaluate_block(stmt_array_t *array, ht_t *cur_env, ht_t *scope_env) +void evaluate_block(stmt_array_t *array, ht_t *cur_env, ht_t *scope_env, return_state_t *state) { ht_t *previous = cur_env; cur_env = scope_env; - evaluate_statements(array, cur_env); - ht_free(scope_env); + evaluate_statements(array, cur_env, state); +/* ht_free(scope_env); */ cur_env = previous; } -value_t *_clock(stmt_t *stmt, val_array_t *arguments, ht_t *env) +value_t *_clock(fn_t *fn, val_array_t *arguments, ht_t *env) { value_t *val = malloc(sizeof(value_t)); val->type = VAL_NUMBER; @@ -408,26 +421,34 @@ value_t *_clock(stmt_t *stmt, val_array_t *arguments, ht_t *env) return val; } -value_t *_call(stmt_t *stmt, val_array_t *arguments, ht_t *env) +value_t *_call(fn_t *fn, val_array_t *arguments, ht_t *env) { - ht_t *fn_env = ht_init(env); - for (int i = 0; i < stmt->as.function.params->length; i++) { - ht_add(fn_env, stmt->as.function.params->tokens[i].value, arguments->arguments[i]); + ht_t *fn_env = ht_init(fn->env); + for (int i = 0; i < fn->stmt->as.function.params->length; i++) { + ht_add(fn_env, fn->stmt->as.function.params->tokens[i].value, arguments->arguments[i]); } - evaluate_block(stmt->as.function.body->as.block.statements, env, fn_env); + return_state_t state = { 0, NULL }; + evaluate_block(fn->stmt->as.function.body->as.block.statements, env, fn_env, &state); + +/* ht_free(fn_env); */ + if (state.has_returned) { + return state.value; + } return NULL; } -void evaluate_statement(stmt_t *stmt, ht_t *env) +void evaluate_statement(stmt_t *stmt, ht_t *env, return_state_t *state) { + if (state && state->has_returned) + return; switch (stmt->type) { case STMT_IF: if (is_truthy(evaluate(stmt->as._if.condition, env))) { - evaluate_statement(stmt->as._if.then_branch, env); + evaluate_statement(stmt->as._if.then_branch, env, state); } else if (stmt->as._if.else_branch) { - evaluate_statement(stmt->as._if.else_branch, env); + evaluate_statement(stmt->as._if.else_branch, env, state); } break; @@ -454,15 +475,20 @@ void evaluate_statement(stmt_t *stmt, ht_t *env) break; } - case STMT_BLOCK: - evaluate_block(stmt->as.block.statements, env, ht_init(env)); + case STMT_BLOCK:; + ht_t *cp_env = ht_init(env); + evaluate_block(stmt->as.block.statements, env, cp_env, state); +/* ht_free(cp_env); */ break; case STMT_WHILE:; value_t *cond = evaluate(stmt->as._while.condition, env); while (is_truthy(cond)) { - evaluate_statement(stmt->as._while.body, env); + evaluate_statement(stmt->as._while.body, env, state); free_val(cond); + if (state->has_returned) { + return; + } cond = evaluate(stmt->as._while.condition, env); } free_val(cond); @@ -472,6 +498,7 @@ void evaluate_statement(stmt_t *stmt, ht_t *env) fn_t *fn = malloc(sizeof(fn_t)); fn->type = FN_CUSTOM; fn->arity = stmt->as.function.params->length; + fn->env = env; fn->stmt = stmt; fn->call = _call; @@ -482,6 +509,14 @@ void evaluate_statement(stmt_t *stmt, ht_t *env) free_val(fn_val); free(fn); break; + + case STMT_RETURN:; + value_t *value = NULL; + if (stmt->as._return.value) { + value = evaluate(stmt->as._return.value, env); + } + state->has_returned = 1; + state->value = value; default: break; @@ -502,7 +537,10 @@ void interpret(stmt_array_t *array) clock_fn->as.function = fn; ht_add(env, "clock", clock_fn); - evaluate_statements(array, env); + + return_state_t state = { 0, NULL }; + + evaluate_statements(array, env, &state); ht_free(env); free(clock_fn); free(fn); diff --git a/src/parser.c b/src/parser.c index 321e291..a0d7c3f 100644 --- a/src/parser.c +++ b/src/parser.c @@ -29,6 +29,7 @@ void error(token_t *token, char *message) fprintf(stderr, "[line %d] at '%s': %s\n", token->line, token->value, message); } errno = 65; + exit(65); synchronize(); } @@ -330,34 +331,42 @@ void free_statement(stmt_t *stmt) if (!stmt) { return; } - if (stmt->type == STMT_PRINT) { - free_expr(stmt->as.print.expression); - free(stmt); - } else if (stmt->type == STMT_EXPR) { - free_expr(stmt->as.expr.expression); - free(stmt); - } else if (stmt->type == STMT_VAR) { - free(stmt->as.variable.name.value); - free_expr(stmt->as.variable.initializer); - free(stmt); - } else if (stmt->type == STMT_BLOCK) { - free_statements(stmt->as.block.statements); - free(stmt); - } else if (stmt->type == STMT_IF) { - free_expr(stmt->as._if.condition); - free_statement(stmt->as._if.then_branch); - free_statement(stmt->as._if.else_branch); - free(stmt); - } else if (stmt->type == STMT_WHILE) { - free_expr(stmt->as._while.condition); - free_statement(stmt->as._while.body); - free(stmt); - } else if (stmt->type == STMT_FUN) { - free(stmt->as.function.name.value); - free_array(stmt->as.function.params); - free_statement(stmt->as.function.body); - free(stmt); + switch (stmt->type) { + case STMT_PRINT: + free_expr(stmt->as.print.expression); + break; + case STMT_EXPR: + free_expr(stmt->as.expr.expression); + break; + case STMT_VAR: + free(stmt->as.variable.name.value); + free_expr(stmt->as.variable.initializer); + break; + case STMT_BLOCK: + free_statements(stmt->as.block.statements); + break; + case STMT_IF: + free_expr(stmt->as._if.condition); + free_statement(stmt->as._if.then_branch); + free_statement(stmt->as._if.else_branch); + break; + case STMT_WHILE: + free_expr(stmt->as._while.condition); + free_statement(stmt->as._while.body); + break; + case STMT_FUN: + free(stmt->as.function.name.value); + free_array(stmt->as.function.params); + free_statement(stmt->as.function.body); + break; + case STMT_RETURN: + free(stmt->as._return.keyword.value); + free_expr(stmt->as._return.value); + break; + default: + break; } + free(stmt); } void free_statements(stmt_array_t *array) @@ -467,6 +476,24 @@ stmt_t *print_stmt(void) return stmt; } +stmt_t *return_stmt(void) +{ + token_t *keyword = previous(); + expr_t *value = NULL; + if (!check(TOKEN_SEMICOLON)) { + value = expression(); + } + + consume(TOKEN_SEMICOLON, "Expect ';' after return value."); + stmt_t *stmt = malloc(sizeof(stmt_t)); + stmt->type = STMT_RETURN; + stmt->as._return.keyword.type = keyword->type; + stmt->as._return.keyword.value = strdup(keyword->value); + stmt->as._return.keyword.line = keyword->line; + stmt->as._return.value = value; + return stmt; +} + stmt_t *while_stmt(void) { consume(TOKEN_LEFT_PAREN, "Expect '(' after 'while'."); @@ -522,6 +549,9 @@ stmt_t *statement(void) if (match(TOKEN_PRINT)) { return print_stmt(); } + if (match(TOKEN_RETURN)) { + return return_stmt(); + } if (match(TOKEN_WHILE)) { return while_stmt(); }