From 7243d2057c2ae8cd7be7c2b53fd5c644d032a6dc Mon Sep 17 00:00:00 2001
From: night0721 <night@night0721.xyz>
Date: Thu, 16 Jan 2025 22:43:45 +0000
Subject: [PATCH] Support scoping

---
 include/ast.h         |  4 +-
 include/env.h         |  7 ++--
 include/interpreter.h |  5 ++-
 include/stmt.h        | 23 ++++++-----
 src/ast.c             |  8 ++--
 src/env.c             | 51 +++++++++++++++++++----
 src/interpreter.c     | 90 ++++++++++++++++++++++++----------------
 src/parser.c          | 95 +++++++++++++++++++++++++++++--------------
 src/rd.c              |  6 ++-
 9 files changed, 190 insertions(+), 99 deletions(-)

diff --git a/include/ast.h b/include/ast.h
index 5632dca..cd2a187 100644
--- a/include/ast.h
+++ b/include/ast.h
@@ -48,7 +48,7 @@ typedef struct expr_t {
 	int line;
 	union {
 		struct {
-			token_t name;
+			struct expr_t *name;
 			struct expr_t *value;
 		} assign;
 		struct {
@@ -103,7 +103,7 @@ expr_t *create_unary_expr(token_t *operator, expr_t *right);
 expr_t *create_literal_expr(token_t *token);
 expr_t *create_grouping_expr(expr_t *expression);
 expr_t *create_variable_expr(token_t *name);
-expr_t *create_assign_expr(token_t *name, expr_t *value);
+expr_t *create_assign_expr(expr_t *name, expr_t *value);
 void print_ast(expr_t *expr);
 
 #endif
diff --git a/include/env.h b/include/env.h
index 2034819..9da5d39 100644
--- a/include/env.h
+++ b/include/env.h
@@ -4,16 +4,17 @@
 #include "ast.h"
 #include "lexer.h"
 
-typedef struct {
+typedef struct ht_t {
 	char *name;
 	value_t value;
+	struct ht_t *enclosing;
 } ht_t;
 
 #define DEFAULT_HT_SIZE 50
 
-ht_t *ht_init(void);
+ht_t *ht_init(ht_t *env);
 void ht_add(ht_t *ht, char *name, value_t value);
-value_t *ht_get(ht_t *ht, token_t *name);
+value_t *ht_get(ht_t *ht, token_t *name, int check_enclosing);
 void ht_replace(ht_t *ht, char *name, value_t value);
 void ht_assign(ht_t *ht, token_t *name, value_t value);
 void ht_free(ht_t *ht);
diff --git a/include/interpreter.h b/include/interpreter.h
index c2a3084..decb592 100644
--- a/include/interpreter.h
+++ b/include/interpreter.h
@@ -2,11 +2,12 @@
 #define INTERPRETER_H
 
 #include "ast.h"
+#include "env.h"
 #include "stmt.h"
 
 void runtime_error(const char *message, int line);
-value_t evaluate(expr_t *expr);
+value_t evaluate(expr_t *expr, ht_t *env);
 void print_value(value_t *value);
-void print_statements(stmt_array_t *array);
+void evaluate_statements(stmt_array_t *array, ht_t *env);
 
 #endif
diff --git a/include/stmt.h b/include/stmt.h
index 44607b5..b32aa7b 100644
--- a/include/stmt.h
+++ b/include/stmt.h
@@ -7,11 +7,10 @@
 #define DEFAULT_STMTS_SIZE 512
 
 /*
- program        → statement* EOF ;
- statement      → exprStmt | printStmt ;
- exprStmt       → expression ";" ;
- printStmt      → "print" expression ";" ;
+ statement      → exprStmt | printStmt | block ;
+ block          → "{" declaration* "}" ;
 */
+
 typedef enum {
 	STMT_BLOCK,
 	STMT_CLASS,
@@ -23,11 +22,19 @@ typedef enum {
 	STMT_WHILE,
 } stmt_type_t;
 
+typedef struct stmt_t stmt_t;
+
+typedef struct {
+	struct stmt_t *statements;
+	int length;
+	int capacity;
+} stmt_array_t;
+
 typedef struct stmt_t {
 	stmt_type_t type;
 	union {
 		struct {
-			struct stmt_t **statements;
+			stmt_array_t *statements;
 		} block;
 		struct {
 			token_t name;
@@ -65,10 +72,4 @@ typedef struct stmt_t {
 	} as;
 } stmt_t;
 
-typedef struct {
-	struct stmt_t *statements;
-	int length;
-	int capacity;
-} stmt_array_t;
-
 #endif 
diff --git a/src/ast.c b/src/ast.c
index 426d9bf..f953528 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -93,14 +93,12 @@ expr_t *create_variable_expr(token_t *name)
 	return expr;
 }
 
-expr_t *create_assign_expr(token_t *name, expr_t *value)
+expr_t *create_assign_expr(expr_t *name, expr_t *value)
 {
 	expr_t *expr = malloc(sizeof(expr_t));
 	expr->type = EXPR_ASSIGN;
-	expr->line = name->line;
-	expr->as.assign.name.type = name->type;
-	expr->as.assign.name.value = name->value;
-	expr->as.assign.name.line = name->line;
+	expr->line = name->as.variable.name.line;
+	expr->as.assign.name =  name;
 	expr->as.assign.value = value;
 
 	return expr;
diff --git a/src/env.c b/src/env.c
index e59853f..6fdb5a6 100644
--- a/src/env.c
+++ b/src/env.c
@@ -5,13 +5,18 @@
 #include "env.h"
 #include "interpreter.h"
 
-ht_t *ht_init(void)
-{
+ht_t *ht_init(ht_t *env)
+{	
 	ht_t *ht = malloc(sizeof(ht_t) * DEFAULT_HT_SIZE);
 	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		ht[i].value.type = VAL_NIL;
 		ht[i].name = NULL;
 	}
-
+	if (env) {
+		ht->enclosing = env;	
+	} else {
+		ht->enclosing = NULL;
+	}
 	return ht;
 }
 
@@ -30,8 +35,13 @@ void ht_add(ht_t *ht, char *name, value_t value)
 	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
 		int probe_idx = (idx + i) % DEFAULT_HT_SIZE;
 		if (!ht[probe_idx].name) {
-			ht[probe_idx].name = name;
-			memcpy(&ht[probe_idx].value, &value, sizeof(value));
+			ht[probe_idx].name = strdup(name);
+			ht[probe_idx].value.type = value.type;
+			if (value.type == VAL_STRING) {
+				ht[probe_idx].value.as.string = strdup(value.as.string);
+			} else {
+				ht[probe_idx].value.as = value.as;
+			}
 			return;
 		} else {
 			ht_replace(ht, name, value);
@@ -40,7 +50,7 @@ void ht_add(ht_t *ht, char *name, value_t value)
 	}
 }
 
-value_t *ht_get(ht_t *ht, token_t *name)
+value_t *ht_get(ht_t *ht, token_t *name, int check_enclosing)
 {
 	unsigned int idx = hash(name->value) % DEFAULT_HT_SIZE;
 	/* Linear probing to search for the key */
@@ -49,6 +59,14 @@ value_t *ht_get(ht_t *ht, token_t *name)
 		if (ht[probe_idx].name && !strcmp(ht[probe_idx].name, name->value))
 			return &ht[probe_idx].value;
 	}
+	if (check_enclosing) {
+		if (ht->enclosing) {
+			return ht_get(ht->enclosing, name, 1);
+		}
+	} else {
+		return NULL;
+	}
+
 	char err[512];
 	snprintf(err, 512, "Undefined variable '%s'.", name->value);
 	runtime_error(err, name->line);
@@ -61,10 +79,18 @@ void ht_replace(ht_t *ht, char *name, value_t value)
 
 	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
 		int probe_idx = (idx + i) % DEFAULT_HT_SIZE;
-		if (!ht[probe_idx].name)
+		if (!ht[probe_idx].name) {
+			ht_replace(ht->enclosing, name, value);
 			break;
+		}
 		if (!strcmp(ht[probe_idx].name, name)) {
-			memcpy(&ht[probe_idx].value, &value, sizeof(value));
+			ht[probe_idx].value.type = value.type;
+			if (value.type == VAL_STRING) {
+				free(ht[probe_idx].value.as.string);
+				ht[probe_idx].value.as.string = strdup(value.as.string);
+			} else {
+				ht[probe_idx].value.as = value.as;
+			}
 			return;
 		}
 	}
@@ -72,10 +98,14 @@ void ht_replace(ht_t *ht, char *name, value_t value)
 
 void ht_assign(ht_t *ht, token_t *name, value_t value)
 {
-	if (ht_get(ht, name)) {
+	if (ht_get(ht, name, 0)) {
 		ht_replace(ht, name->value, value);
 		return;
 	}
+	if (ht->enclosing) {
+		ht_assign(ht->enclosing, name, value);
+		return;
+	}
 	char err[512];
 	snprintf(err, 512, "Undefined variable '%s'.", name->value);
 	runtime_error(err, name->line);
@@ -86,6 +116,9 @@ void ht_free(ht_t *ht)
 	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
 		if (ht[i].value.type != VAL_NIL) {
 			free(ht[i].name);
+			if (ht[i].value.type == VAL_STRING) {
+				free(ht[i].value.as.string);
+			}
 		}
 	}
 	free(ht);
diff --git a/src/interpreter.c b/src/interpreter.c
index 9a1b086..a13be7c 100644
--- a/src/interpreter.c
+++ b/src/interpreter.c
@@ -7,17 +7,16 @@
 #include "env.h"
 #include "interpreter.h"
 #include "lexer.h"
-
-ht_t *ht;
+#include "parser.h"
 
 value_t visit_literal(expr_t *expr)
 {
 	return expr->as.literal.value;
 }
 
-value_t visit_grouping(expr_t *expr)
+value_t visit_grouping(expr_t *expr, ht_t *env)
 {
-	return evaluate(expr->as.grouping.expression);
+	return evaluate(expr->as.grouping.expression, env);
 }
 
 void runtime_error(const char *message, int line)
@@ -27,11 +26,11 @@ void runtime_error(const char *message, int line)
 	exit(70);
 }
 
-value_t visit_binary(expr_t *expr)
+value_t visit_binary(expr_t *expr, ht_t *env)
 {
 	token_type_t op_type = expr->as.binary.operator.type;
-	value_t right = evaluate(expr->as.binary.right);
-	value_t left = evaluate(expr->as.binary.left);
+	value_t right = evaluate(expr->as.binary.right, env);
+	value_t left = evaluate(expr->as.binary.left, env);
 
 	// Arithmetic
 	if (left.type == VAL_NUMBER && right.type == VAL_NUMBER) {
@@ -163,9 +162,9 @@ int is_truthy(value_t *value)
 	}
 }
 
-value_t visit_unary(expr_t *expr)
+value_t visit_unary(expr_t *expr, ht_t *env)
 {
-	value_t operand = evaluate(expr->as.unary.right);
+	value_t operand = evaluate(expr->as.unary.right, env);
 
 	if (expr->as.unary.operator.type == TOKEN_MINUS) {
 		if (operand.type == VAL_NUMBER) {
@@ -182,9 +181,9 @@ value_t visit_unary(expr_t *expr)
 	return (value_t){.type = VAL_NIL};
 }
 
-value_t visit_variable(expr_t *expr)
+value_t visit_variable(expr_t *expr, ht_t *env)
 {
-	value_t *val = ht_get(ht, &expr->as.variable.name);
+	value_t *val = ht_get(env, &expr->as.variable.name, 1);
 	if (val) {
 		return *val;
 	} else {
@@ -192,14 +191,14 @@ value_t visit_variable(expr_t *expr)
 	}
 }
 
-value_t visit_assign(expr_t *expr)
+value_t visit_assign(expr_t *expr, ht_t *env)
 {
-	value_t value = evaluate(expr->as.assign.value);
-	ht_assign(ht, &expr->as.assign.name, value);
+	value_t value = evaluate(expr->as.assign.value, env);
+	ht_assign(env, &expr->as.assign.name->as.variable.name, value);
     return value;
 }
 
-value_t evaluate(expr_t *expr)
+value_t evaluate(expr_t *expr, ht_t *env)
 {
 	if (!expr) {
 		value_t nil_value = {.type = VAL_NIL };
@@ -209,15 +208,15 @@ value_t evaluate(expr_t *expr)
 		case EXPR_LITERAL:
 			return visit_literal(expr);
 		case EXPR_BINARY:
-			return visit_binary(expr);
+			return visit_binary(expr, env);
 		case EXPR_UNARY:
-			return visit_unary(expr);
+			return visit_unary(expr, env);
 		case EXPR_GROUPING:
-			return visit_grouping(expr);
+			return visit_grouping(expr, env);
 		case EXPR_VARIABLE:
-			return visit_variable(expr);
+			return visit_variable(expr, env);
 		case EXPR_ASSIGN:
-			return visit_assign(expr);
+			return visit_assign(expr, env);
 		default:
 			exit(65);
 			break;
@@ -252,26 +251,47 @@ void print_value(value_t *value)
 	}
 }
 
-void print_statement(stmt_t stmt)
+void evaluate_block(stmt_array_t *array, ht_t *cur_env, ht_t *scope_env)
 {
-	if (stmt.type == STMT_PRINT) {
-		value_t obj = evaluate(stmt.as.print.expression);
-		print_value(&obj);
-	} else if (stmt.type == STMT_EXPR) {
-		evaluate(stmt.as.expr.expression);
-	} else if (stmt.type == STMT_VAR) {
-		value_t value = {.type = VAL_NIL};
-		if (stmt.as.variable.initializer) {
-			value = evaluate(stmt.as.variable.initializer);
-		}
-		ht_add(ht, stmt.as.variable.name.value, value);
+	ht_t *previous = cur_env;
+	cur_env = scope_env;
+	evaluate_statements(array, cur_env);
+	ht_free(scope_env);
+	cur_env = previous;
+}
+
+void evaluate_statement(stmt_t stmt, ht_t *env)
+{
+	switch (stmt.type) {
+		case STMT_PRINT:;
+			value_t obj = evaluate(stmt.as.print.expression, env);
+			print_value(&obj);
+			break;
+
+		case STMT_EXPR:
+			evaluate(stmt.as.expr.expression, env);
+			break;
+
+		case STMT_VAR:;
+			value_t value = {.type = VAL_NIL};
+			if (stmt.as.variable.initializer) {
+				value = evaluate(stmt.as.variable.initializer, env);
+			}
+			ht_add(env, stmt.as.variable.name.value, value);
+			break;
+
+		case STMT_BLOCK:
+			evaluate_block(stmt.as.block.statements, env, ht_init(env));
+			break;
+
+		default:
+			break;
 	}
 }
 
-void print_statements(stmt_array_t *array)
+void evaluate_statements(stmt_array_t *array, ht_t *env)
 {
-	ht = ht_init();
 	for (int i = 0; i < array->length; i++) {
-		print_statement(array->statements[i]);
+		evaluate_statement(array->statements[i], env);
 	}
 }
diff --git a/src/parser.c b/src/parser.c
index f9b2455..87474c3 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -1,5 +1,6 @@
 #include <stdio.h>
 #include <stdlib.h>
+#include <string.h>
 #include <errno.h>
 
 #include "ast.h"
@@ -9,6 +10,7 @@
 int current = 0;
 token_t *tokens;
 expr_t *expression(void);
+stmt_t declaration(void);
 void synchronize(void);
 
 /*
@@ -56,13 +58,16 @@ void free_expr(expr_t *expr)
 			break;
 	
 		case EXPR_VARIABLE:
+			printf("hereR?\n");
 			free(expr->as.variable.name.value);
 			free(expr);
 			break;
 
 		case EXPR_ASSIGN:
-			free(expr->as.assign.name.value);
+			free_expr(expr->as.assign.name);
+			printf("hiiii\n");
 			free_expr(expr->as.assign.value);
+			free(expr);
 			break;
 
 		default:
@@ -212,8 +217,7 @@ expr_t *assignment(void)
 		expr_t *value = assignment();
 
 		if (expr->type == EXPR_VARIABLE) {
-			token_t name = expr->as.variable.name;
-			return create_assign_expr(&name, value);
+			return create_assign_expr(expr, value);
 		}
 		error(equals, "Invalid assignment target.");
 	}
@@ -226,6 +230,40 @@ expr_t *expression(void)
 	return assignment();
 }
 
+void stmt_add(stmt_array_t *array, stmt_t stmt)
+{
+	if (array->length == array->capacity) {
+		array->capacity *= 2;
+		array->statements = realloc(array->statements, array->capacity * sizeof(stmt_t));
+	}
+	array->statements[array->length++] = stmt;
+}
+
+void free_statements(stmt_array_t *array)
+{
+	for (int i = 0; i < array->length; i++) {
+		if (array->statements[i].type == STMT_PRINT) {
+			printf("this should go fifth\n");
+			free_expr(array->statements[i].as.print.expression);
+		}
+		if (array->statements[i].type == STMT_EXPR) {
+			printf("third\n");
+			free_expr(array->statements[i].as.expr.expression);
+		}
+		if (array->statements[i].type == STMT_VAR) {
+			printf("this should go second/forth\n");
+			free(array->statements[i].as.variable.name.value);
+			free_expr(array->statements[i].as.variable.initializer);
+		}
+		if (array->statements[i].type == STMT_BLOCK) {
+			printf("this should go first/third\n");
+			free_statements(array->statements[i].as.block.statements);
+		}
+	}
+	free(array->statements);
+	free(array);
+}
+
 stmt_t print_stmt(void)
 {
 	expr_t *value = expression();
@@ -236,6 +274,25 @@ stmt_t print_stmt(void)
 	};
 }
 
+stmt_t block_stmt(void)
+{
+	stmt_array_t *statements = malloc(sizeof(stmt_array_t));
+	statements->statements = malloc(DEFAULT_STMTS_SIZE * sizeof(stmt_t));
+	statements->length = 0;
+	statements->capacity = DEFAULT_STMTS_SIZE;
+
+
+    while (!check(TOKEN_RIGHT_BRACE) && !end()) {
+		stmt_add(statements, declaration());
+    }
+
+    consume(TOKEN_RIGHT_BRACE, "Expect '}' after block.");
+    return (stmt_t) {
+		.type = STMT_BLOCK,
+		.as.block.statements = statements,
+	};
+}
+
 stmt_t expression_stmt(void)
 {
 	expr_t *expr = expression();
@@ -252,6 +309,10 @@ stmt_t statement(void)
 		advance();
 		return print_stmt();
 	}
+	if (check(TOKEN_LEFT_BRACE)) {
+		advance();
+		return block_stmt();
+	}
 	return expression_stmt();
 }
 
@@ -269,7 +330,7 @@ stmt_t var_declaration(void)
 	return (stmt_t) {
 		.type = STMT_VAR,
 		.as.variable.name.type = name->type,
-		.as.variable.name.value = name->value,
+		.as.variable.name.value = strdup(name->value),
 		.as.variable.name.line = name->line,
 		.as.variable.initializer = initializer,
 	};
@@ -285,32 +346,6 @@ stmt_t declaration(void)
 	return statement();
 }
 
-void stmt_add(stmt_array_t *array, stmt_t stmt)
-{
-	if (array->length == array->capacity) {
-		array->capacity *= 2;
-		array->statements = realloc(array->statements, array->capacity * sizeof(stmt_t));
-	}
-	array->statements[array->length++] = stmt;
-}
-
-void free_statements(stmt_array_t *array)
-{
-	for (int i = 0; i < array->length; i++) {
-		if (array->statements[i].type == STMT_PRINT) {
-			free_expr(array->statements[i].as.print.expression);
-		}
-		if (array->statements[i].type == STMT_EXPR) {
-			free_expr(array->statements[i].as.expr.expression);
-		}
-		if (array->statements[i].type == STMT_VAR) {
-			free_expr(array->statements[i].as.variable.initializer);
-		}
-	}
-	free(array->statements);
-	free(array);
-}
-
 stmt_array_t *parse(token_t *tks)
 {
 	tokens = tks;
diff --git a/src/rd.c b/src/rd.c
index 5798e9e..bf1b214 100644
--- a/src/rd.c
+++ b/src/rd.c
@@ -34,14 +34,16 @@ int main(int argc, char **argv)
 		free_expr(expr);
 	} else if (!strcmp(command, "evaluate")) {
 		expr_t *expr = parse_expr(array->tokens);
-		value_t val = evaluate(expr);
+		value_t val = evaluate(expr, NULL);
 		print_value(&val);
 		free_array(array);
 		free_expr(expr);
 	} else if (!strcmp(command, "run")) {
+		ht_t *env = ht_init(NULL);
 		stmt_array_t *stmts = parse(array->tokens);
 		if (errno != 65) {
-			print_statements(stmts);
+			evaluate_statements(stmts, env);
+			ht_free(env);
 			free_array(array);
 			free_statements(stmts);
 		}