Return statements

This commit is contained in:
Night Kaly 2025-01-21 03:25:17 +00:00
parent 6d9d68d693
commit 7ccc3e32ba
Signed by: night0721
SSH key fingerprint: SHA256:B/hgVwUoBpx5vdNsXl9w8XwZljA9766uk6T4ubZp5HM
5 changed files with 122 additions and 49 deletions

View file

@ -51,6 +51,7 @@ typedef enum {
STMT_PRINT,
STMT_VAR,
STMT_WHILE,
STMT_RETURN,
} stmt_type_t;
typedef enum {
@ -96,8 +97,9 @@ typedef struct ht_t ht_t;
struct fn_t {
fn_type_t type;
int arity;
ht_t *env;
stmt_t *stmt;
value_t *(*call)(stmt_t *stmt, val_array_t *arguments, ht_t *env);
value_t *(*call)(struct fn_t *stmt, val_array_t *arguments, ht_t *env);
};
struct expr_t {

View file

@ -10,7 +10,7 @@ typedef struct ht_t {
struct ht_t *enclosing;
} ht_t;
#define DEFAULT_HT_SIZE 50
#define DEFAULT_HT_SIZE 500
ht_t *ht_init(ht_t *env);
void ht_add(ht_t *ht, char *name, value_t *value);

View file

@ -57,6 +57,9 @@ void ht_add(ht_t *ht, char *name, value_t *value)
value_t *ht_get(ht_t *ht, token_t *name, int check_enclosing)
{
if (!ht) {
return NULL;
}
unsigned int idx = hash(name->value) % DEFAULT_HT_SIZE;
/* Linear probing to search for the key */
for (int i = 0; i < DEFAULT_HT_SIZE; i++) {

View file

@ -10,7 +10,12 @@
#include "lexer.h"
#include "parser.h"
void evaluate_statement(stmt_t *stmt, ht_t *env);
typedef struct {
int has_returned;
value_t *value;
} return_state_t;
void evaluate_statement(stmt_t *stmt, ht_t *env, return_state_t *state);
void free_val(value_t *value)
{
@ -30,6 +35,9 @@ value_t *visit_literal(expr_t *expr)
{
value_t *val = malloc(sizeof(value_t));
memcpy(val, expr->as.literal.value, sizeof(value_t));
if (val->type == VAL_STRING) {
val->as.string = strdup(expr->as.literal.value->as.string);
}
return val;
}
@ -312,7 +320,7 @@ value_t *visit_call(expr_t *expr, ht_t *env)
snprintf(err, 512, "Expected %d arguments but got %d.", callee->as.function->arity, arguments->length);
runtime_error(err, expr->line);
}
value_t *res = callee->as.function->call(callee->as.function->stmt, arguments, env);
value_t *res = callee->as.function->call(callee->as.function, arguments, callee->as.function->env);
free_vals(arguments);
free_val(callee);
return res;
@ -350,6 +358,11 @@ value_t *evaluate(expr_t *expr, ht_t *env)
void print_value(value_t *value)
{
if (!value) {
printf("nil\n");
return;
}
switch (value->type) {
case VAL_BOOL:
printf("%s\n", value->as.boolean == 1 ? "true" : "false");
@ -384,23 +397,23 @@ void print_value(value_t *value)
}
}
void evaluate_statements(stmt_array_t *array, ht_t *env)
void evaluate_statements(stmt_array_t *array, ht_t *env, return_state_t *state)
{
for (int i = 0; i < array->length; i++) {
evaluate_statement(array->statements[i], env);
evaluate_statement(array->statements[i], env, state);
}
}
void evaluate_block(stmt_array_t *array, ht_t *cur_env, ht_t *scope_env)
void evaluate_block(stmt_array_t *array, ht_t *cur_env, ht_t *scope_env, return_state_t *state)
{
ht_t *previous = cur_env;
cur_env = scope_env;
evaluate_statements(array, cur_env);
ht_free(scope_env);
evaluate_statements(array, cur_env, state);
/* ht_free(scope_env); */
cur_env = previous;
}
value_t *_clock(stmt_t *stmt, val_array_t *arguments, ht_t *env)
value_t *_clock(fn_t *fn, val_array_t *arguments, ht_t *env)
{
value_t *val = malloc(sizeof(value_t));
val->type = VAL_NUMBER;
@ -408,26 +421,34 @@ value_t *_clock(stmt_t *stmt, val_array_t *arguments, ht_t *env)
return val;
}
value_t *_call(stmt_t *stmt, val_array_t *arguments, ht_t *env)
value_t *_call(fn_t *fn, val_array_t *arguments, ht_t *env)
{
ht_t *fn_env = ht_init(env);
for (int i = 0; i < stmt->as.function.params->length; i++) {
ht_add(fn_env, stmt->as.function.params->tokens[i].value, arguments->arguments[i]);
ht_t *fn_env = ht_init(fn->env);
for (int i = 0; i < fn->stmt->as.function.params->length; i++) {
ht_add(fn_env, fn->stmt->as.function.params->tokens[i].value, arguments->arguments[i]);
}
evaluate_block(stmt->as.function.body->as.block.statements, env, fn_env);
return_state_t state = { 0, NULL };
evaluate_block(fn->stmt->as.function.body->as.block.statements, env, fn_env, &state);
/* ht_free(fn_env); */
if (state.has_returned) {
return state.value;
}
return NULL;
}
void evaluate_statement(stmt_t *stmt, ht_t *env)
void evaluate_statement(stmt_t *stmt, ht_t *env, return_state_t *state)
{
if (state && state->has_returned)
return;
switch (stmt->type) {
case STMT_IF:
if (is_truthy(evaluate(stmt->as._if.condition, env))) {
evaluate_statement(stmt->as._if.then_branch, env);
evaluate_statement(stmt->as._if.then_branch, env, state);
} else if (stmt->as._if.else_branch) {
evaluate_statement(stmt->as._if.else_branch, env);
evaluate_statement(stmt->as._if.else_branch, env, state);
}
break;
@ -454,15 +475,20 @@ void evaluate_statement(stmt_t *stmt, ht_t *env)
break;
}
case STMT_BLOCK:
evaluate_block(stmt->as.block.statements, env, ht_init(env));
case STMT_BLOCK:;
ht_t *cp_env = ht_init(env);
evaluate_block(stmt->as.block.statements, env, cp_env, state);
/* ht_free(cp_env); */
break;
case STMT_WHILE:;
value_t *cond = evaluate(stmt->as._while.condition, env);
while (is_truthy(cond)) {
evaluate_statement(stmt->as._while.body, env);
evaluate_statement(stmt->as._while.body, env, state);
free_val(cond);
if (state->has_returned) {
return;
}
cond = evaluate(stmt->as._while.condition, env);
}
free_val(cond);
@ -472,6 +498,7 @@ void evaluate_statement(stmt_t *stmt, ht_t *env)
fn_t *fn = malloc(sizeof(fn_t));
fn->type = FN_CUSTOM;
fn->arity = stmt->as.function.params->length;
fn->env = env;
fn->stmt = stmt;
fn->call = _call;
@ -482,6 +509,14 @@ void evaluate_statement(stmt_t *stmt, ht_t *env)
free_val(fn_val);
free(fn);
break;
case STMT_RETURN:;
value_t *value = NULL;
if (stmt->as._return.value) {
value = evaluate(stmt->as._return.value, env);
}
state->has_returned = 1;
state->value = value;
default:
break;
@ -502,7 +537,10 @@ void interpret(stmt_array_t *array)
clock_fn->as.function = fn;
ht_add(env, "clock", clock_fn);
evaluate_statements(array, env);
return_state_t state = { 0, NULL };
evaluate_statements(array, env, &state);
ht_free(env);
free(clock_fn);
free(fn);

View file

@ -29,6 +29,7 @@ void error(token_t *token, char *message)
fprintf(stderr, "[line %d] at '%s': %s\n", token->line, token->value, message);
}
errno = 65;
exit(65);
synchronize();
}
@ -330,34 +331,42 @@ void free_statement(stmt_t *stmt)
if (!stmt) {
return;
}
if (stmt->type == STMT_PRINT) {
free_expr(stmt->as.print.expression);
free(stmt);
} else if (stmt->type == STMT_EXPR) {
free_expr(stmt->as.expr.expression);
free(stmt);
} else if (stmt->type == STMT_VAR) {
free(stmt->as.variable.name.value);
free_expr(stmt->as.variable.initializer);
free(stmt);
} else if (stmt->type == STMT_BLOCK) {
free_statements(stmt->as.block.statements);
free(stmt);
} else if (stmt->type == STMT_IF) {
free_expr(stmt->as._if.condition);
free_statement(stmt->as._if.then_branch);
free_statement(stmt->as._if.else_branch);
free(stmt);
} else if (stmt->type == STMT_WHILE) {
free_expr(stmt->as._while.condition);
free_statement(stmt->as._while.body);
free(stmt);
} else if (stmt->type == STMT_FUN) {
free(stmt->as.function.name.value);
free_array(stmt->as.function.params);
free_statement(stmt->as.function.body);
free(stmt);
switch (stmt->type) {
case STMT_PRINT:
free_expr(stmt->as.print.expression);
break;
case STMT_EXPR:
free_expr(stmt->as.expr.expression);
break;
case STMT_VAR:
free(stmt->as.variable.name.value);
free_expr(stmt->as.variable.initializer);
break;
case STMT_BLOCK:
free_statements(stmt->as.block.statements);
break;
case STMT_IF:
free_expr(stmt->as._if.condition);
free_statement(stmt->as._if.then_branch);
free_statement(stmt->as._if.else_branch);
break;
case STMT_WHILE:
free_expr(stmt->as._while.condition);
free_statement(stmt->as._while.body);
break;
case STMT_FUN:
free(stmt->as.function.name.value);
free_array(stmt->as.function.params);
free_statement(stmt->as.function.body);
break;
case STMT_RETURN:
free(stmt->as._return.keyword.value);
free_expr(stmt->as._return.value);
break;
default:
break;
}
free(stmt);
}
void free_statements(stmt_array_t *array)
@ -467,6 +476,24 @@ stmt_t *print_stmt(void)
return stmt;
}
stmt_t *return_stmt(void)
{
token_t *keyword = previous();
expr_t *value = NULL;
if (!check(TOKEN_SEMICOLON)) {
value = expression();
}
consume(TOKEN_SEMICOLON, "Expect ';' after return value.");
stmt_t *stmt = malloc(sizeof(stmt_t));
stmt->type = STMT_RETURN;
stmt->as._return.keyword.type = keyword->type;
stmt->as._return.keyword.value = strdup(keyword->value);
stmt->as._return.keyword.line = keyword->line;
stmt->as._return.value = value;
return stmt;
}
stmt_t *while_stmt(void)
{
consume(TOKEN_LEFT_PAREN, "Expect '(' after 'while'.");
@ -522,6 +549,9 @@ stmt_t *statement(void)
if (match(TOKEN_PRINT)) {
return print_stmt();
}
if (match(TOKEN_RETURN)) {
return return_stmt();
}
if (match(TOKEN_WHILE)) {
return while_stmt();
}