From 4a0f3256347527764f90681c97f151e7ca3594a9 Mon Sep 17 00:00:00 2001
From: night0721 <night@night0721.xyz>
Date: Thu, 16 Jan 2025 00:04:09 +0000
Subject: [PATCH] Variable assign, printing

---
 include/ast.h         |  1 +
 include/env.h         | 21 ++++++++++
 include/interpreter.h |  1 +
 include/stmt.h        |  6 +++
 src/ast.c             | 13 ++++++
 src/env.c             | 93 +++++++++++++++++++++++++++++++++++++++++++
 src/interpreter.c     | 35 ++++++++++++++--
 src/parser.c          | 84 +++++++++++++++++++++++++++++---------
 8 files changed, 232 insertions(+), 22 deletions(-)
 create mode 100644 include/env.h
 create mode 100644 src/env.c

diff --git a/include/ast.h b/include/ast.h
index e853408..5632dca 100644
--- a/include/ast.h
+++ b/include/ast.h
@@ -103,6 +103,7 @@ expr_t *create_unary_expr(token_t *operator, expr_t *right);
 expr_t *create_literal_expr(token_t *token);
 expr_t *create_grouping_expr(expr_t *expression);
 expr_t *create_variable_expr(token_t *name);
+expr_t *create_assign_expr(token_t *name, expr_t *value);
 void print_ast(expr_t *expr);
 
 #endif
diff --git a/include/env.h b/include/env.h
new file mode 100644
index 0000000..2034819
--- /dev/null
+++ b/include/env.h
@@ -0,0 +1,21 @@
+#ifndef ENV_H
+#define ENV_H
+
+#include "ast.h"
+#include "lexer.h"
+
+typedef struct {
+	char *name;
+	value_t value;
+} ht_t;
+
+#define DEFAULT_HT_SIZE 50
+
+ht_t *ht_init(void);
+void ht_add(ht_t *ht, char *name, value_t value);
+value_t *ht_get(ht_t *ht, token_t *name);
+void ht_replace(ht_t *ht, char *name, value_t value);
+void ht_assign(ht_t *ht, token_t *name, value_t value);
+void ht_free(ht_t *ht);
+
+#endif
diff --git a/include/interpreter.h b/include/interpreter.h
index 9d3c03e..c2a3084 100644
--- a/include/interpreter.h
+++ b/include/interpreter.h
@@ -4,6 +4,7 @@
 #include "ast.h"
 #include "stmt.h"
 
+void runtime_error(const char *message, int line);
 value_t evaluate(expr_t *expr);
 void print_value(value_t *value);
 void print_statements(stmt_array_t *array);
diff --git a/include/stmt.h b/include/stmt.h
index 24a6ae6..44607b5 100644
--- a/include/stmt.h
+++ b/include/stmt.h
@@ -6,6 +6,12 @@
 
 #define DEFAULT_STMTS_SIZE 512
 
+/*
+ program        → statement* EOF ;
+ statement      → exprStmt | printStmt ;
+ exprStmt       → expression ";" ;
+ printStmt      → "print" expression ";" ;
+*/
 typedef enum {
 	STMT_BLOCK,
 	STMT_CLASS,
diff --git a/src/ast.c b/src/ast.c
index 56c61b5..426d9bf 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -93,6 +93,19 @@ expr_t *create_variable_expr(token_t *name)
 	return expr;
 }
 
+expr_t *create_assign_expr(token_t *name, expr_t *value)
+{
+	expr_t *expr = malloc(sizeof(expr_t));
+	expr->type = EXPR_ASSIGN;
+	expr->line = name->line;
+	expr->as.assign.name.type = name->type;
+	expr->as.assign.name.value = name->value;
+	expr->as.assign.name.line = name->line;
+	expr->as.assign.value = value;
+
+	return expr;
+}
+
 void print_ast(expr_t *expr)
 {
 	if (!expr)
diff --git a/src/env.c b/src/env.c
new file mode 100644
index 0000000..e59853f
--- /dev/null
+++ b/src/env.c
@@ -0,0 +1,93 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "env.h"
+#include "interpreter.h"
+
+ht_t *ht_init(void)
+{
+	ht_t *ht = malloc(sizeof(ht_t) * DEFAULT_HT_SIZE);
+	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		ht[i].name = NULL;
+	}
+
+	return ht;
+}
+
+unsigned int hash(char *key)
+{
+	unsigned int h = 0;
+	for (; *key; key++)
+		h = 31 * h + *key;
+	return h;
+}
+
+void ht_add(ht_t *ht, char *name, value_t value)
+{
+	unsigned int idx = hash(name) % DEFAULT_HT_SIZE;
+	/* Linear probing for collision resolution */
+	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		int probe_idx = (idx + i) % DEFAULT_HT_SIZE;
+		if (!ht[probe_idx].name) {
+			ht[probe_idx].name = name;
+			memcpy(&ht[probe_idx].value, &value, sizeof(value));
+			return;
+		} else {
+			ht_replace(ht, name, value);
+			return;
+		}
+	}
+}
+
+value_t *ht_get(ht_t *ht, token_t *name)
+{
+	unsigned int idx = hash(name->value) % DEFAULT_HT_SIZE;
+	/* Linear probing to search for the key */
+	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		int probe_idx = (idx + i) % DEFAULT_HT_SIZE;
+		if (ht[probe_idx].name && !strcmp(ht[probe_idx].name, name->value))
+			return &ht[probe_idx].value;
+	}
+	char err[512];
+	snprintf(err, 512, "Undefined variable '%s'.", name->value);
+	runtime_error(err, name->line);
+	return NULL;
+}
+
+void ht_replace(ht_t *ht, char *name, value_t value)
+{
+	unsigned int idx = hash(name) % DEFAULT_HT_SIZE;
+
+	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		int probe_idx = (idx + i) % DEFAULT_HT_SIZE;
+		if (!ht[probe_idx].name)
+			break;
+		if (!strcmp(ht[probe_idx].name, name)) {
+			memcpy(&ht[probe_idx].value, &value, sizeof(value));
+			return;
+		}
+	}
+}
+
+void ht_assign(ht_t *ht, token_t *name, value_t value)
+{
+	if (ht_get(ht, name)) {
+		ht_replace(ht, name->value, value);
+		return;
+	}
+	char err[512];
+	snprintf(err, 512, "Undefined variable '%s'.", name->value);
+	runtime_error(err, name->line);
+}
+
+void ht_free(ht_t *ht)
+{
+	for (int i = 0; i < DEFAULT_HT_SIZE; i++) {
+		if (ht[i].value.type != VAL_NIL) {
+			free(ht[i].name);
+		}
+	}
+	free(ht);
+}
+
diff --git a/src/interpreter.c b/src/interpreter.c
index 2706484..9a1b086 100644
--- a/src/interpreter.c
+++ b/src/interpreter.c
@@ -4,13 +4,11 @@
 #include <errno.h>
 
 #include "ast.h"
+#include "env.h"
 #include "interpreter.h"
 #include "lexer.h"
 
-value_t visit_literal(expr_t *expr);
-value_t visit_grouping(expr_t *expr);
-value_t visit_binary(expr_t *expr);
-value_t visit_unary(expr_t *expr);
+ht_t *ht;
 
 value_t visit_literal(expr_t *expr)
 {
@@ -184,6 +182,23 @@ value_t visit_unary(expr_t *expr)
 	return (value_t){.type = VAL_NIL};
 }
 
+value_t visit_variable(expr_t *expr)
+{
+	value_t *val = ht_get(ht, &expr->as.variable.name);
+	if (val) {
+		return *val;
+	} else {
+		return (value_t) {.type = VAL_NIL};
+	}
+}
+
+value_t visit_assign(expr_t *expr)
+{
+	value_t value = evaluate(expr->as.assign.value);
+	ht_assign(ht, &expr->as.assign.name, value);
+    return value;
+}
+
 value_t evaluate(expr_t *expr)
 {
 	if (!expr) {
@@ -199,6 +214,10 @@ value_t evaluate(expr_t *expr)
 			return visit_unary(expr);
 		case EXPR_GROUPING:
 			return visit_grouping(expr);
+		case EXPR_VARIABLE:
+			return visit_variable(expr);
+		case EXPR_ASSIGN:
+			return visit_assign(expr);
 		default:
 			exit(65);
 			break;
@@ -240,10 +259,18 @@ void print_statement(stmt_t stmt)
 		print_value(&obj);
 	} else if (stmt.type == STMT_EXPR) {
 		evaluate(stmt.as.expr.expression);
+	} else if (stmt.type == STMT_VAR) {
+		value_t value = {.type = VAL_NIL};
+		if (stmt.as.variable.initializer) {
+			value = evaluate(stmt.as.variable.initializer);
+		}
+		ht_add(ht, stmt.as.variable.name.value, value);
 	}
 }
+
 void print_statements(stmt_array_t *array)
 {
+	ht = ht_init();
 	for (int i = 0; i < array->length; i++) {
 		print_statement(array->statements[i]);
 	}
diff --git a/src/parser.c b/src/parser.c
index ab3eedb..f9b2455 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -11,6 +11,9 @@ token_t *tokens;
 expr_t *expression(void);
 void synchronize(void);
 
+/*
+ * Syntax error
+ */
 void error(token_t *token, char *message)
 {
 	if (token->type == TOKEN_EOF) {
@@ -51,6 +54,16 @@ void free_expr(expr_t *expr)
 			}
 			free(expr);
 			break;
+	
+		case EXPR_VARIABLE:
+			free(expr->as.variable.name.value);
+			free(expr);
+			break;
+
+		case EXPR_ASSIGN:
+			free(expr->as.assign.name.value);
+			free_expr(expr->as.assign.value);
+			break;
 
 		default:
 			break;
@@ -80,19 +93,17 @@ void advance(void)
 
 int check(token_type_t type)
 {
-	if (tokens[current].type == type) {
-		advance();
-		return 1;
-	} else {
-		return 0;
-	}
+	return tokens[current].type == type;
 }
 
-token_t *consume(token_type_t type, char *message) {
+token_t *consume(token_type_t type, char *message)
+{
 	if (!check(type)) {
 		error(peek(), message);
 	} else {
-		return peek();
+		token_t *tok = peek();
+		advance();
+		return tok;
 	}
 	return NULL;
 }
@@ -101,14 +112,19 @@ expr_t *primary(void)
 {
 	if (check(TOKEN_FALSE) || check(TOKEN_TRUE) || check(TOKEN_NIL) ||
 			check(TOKEN_NUMBER) || check(TOKEN_STRING)) {
-		return create_literal_expr(previous());
+		token_t *tok = peek();
+		advance();
+		return create_literal_expr(tok);
 	}
 
 	if (check(TOKEN_IDENTIFIER)) {
-		return create_variable_expr(previous());
+		token_t *tok = peek();
+		advance();
+		return create_variable_expr(tok);
 	}
 
 	if (check(TOKEN_LEFT_PAREN)) {
+		advance();
 		expr_t *expr = expression();
 		consume(TOKEN_RIGHT_PAREN, "Expect ')' after expression.");
 		return create_grouping_expr(expr);
@@ -120,7 +136,8 @@ expr_t *primary(void)
 expr_t *unary(void)
 {
 	if (check(TOKEN_BANG) || check(TOKEN_MINUS)) {
-		token_t *operator = previous();
+		token_t *operator = peek();
+		advance();
 		expr_t *right = unary();
 		return create_unary_expr(operator, right);
 	}
@@ -133,7 +150,8 @@ expr_t *factor(void)
 	expr_t *expr = unary();
 
 	while (check(TOKEN_SLASH) || check(TOKEN_STAR)) {
-		token_t *operator = previous();
+		token_t *operator = peek();
+		advance();
 		expr_t *right = unary();
 		expr = create_binary_expr(operator, expr, right);
 	}
@@ -146,7 +164,8 @@ expr_t *term(void)
 	expr_t *expr = factor();
 
 	while (check(TOKEN_MINUS) || check(TOKEN_PLUS)) {
-		token_t *operator = previous();
+		token_t *operator = peek();
+		advance();
 		expr_t *right = factor();
 		expr = create_binary_expr(operator, expr, right);
 	}
@@ -160,7 +179,8 @@ expr_t *comparison(void)
 
 	while (check(TOKEN_GREATER) || check(TOKEN_GREATER_EQUAL) || check(TOKEN_LESS)
 			|| check(TOKEN_LESS_EQUAL)) {
-		token_t *operator = previous();
+		token_t *operator = peek();
+		advance();
 		expr_t *right = term();
 		expr = create_binary_expr(operator, expr, right);
 	}
@@ -173,7 +193,8 @@ expr_t *equality(void)
 	expr_t *expr = comparison();
 
 	while (check(TOKEN_BANG_EQUAL) || check(TOKEN_EQUAL_EQUAL)) {
-		token_t *operator = previous();
+		token_t *operator = peek();
+		advance();
 		expr_t *right = comparison();
 		expr = create_binary_expr(operator, expr, right);
 	}
@@ -181,9 +202,28 @@ expr_t *equality(void)
 	return expr;
 }
 
+expr_t *assignment(void)
+{
+	expr_t *expr = equality();
+
+	if (check(TOKEN_EQUAL)) {
+		token_t *equals = peek();
+		advance();
+		expr_t *value = assignment();
+
+		if (expr->type == EXPR_VARIABLE) {
+			token_t name = expr->as.variable.name;
+			return create_assign_expr(&name, value);
+		}
+		error(equals, "Invalid assignment target.");
+	}
+
+    return expr;
+}
+
 expr_t *expression(void)
 {
-	return equality();
+	return assignment();
 }
 
 stmt_t print_stmt(void)
@@ -208,8 +248,10 @@ stmt_t expression_stmt(void)
 
 stmt_t statement(void)
 {
-	if (check(TOKEN_PRINT))
+	if (check(TOKEN_PRINT)) {
+		advance();
 		return print_stmt();
+	}
 	return expression_stmt();
 }
 
@@ -219,6 +261,7 @@ stmt_t var_declaration(void)
 
 	expr_t *initializer = NULL;
 	if (check(TOKEN_EQUAL)) {
+		advance();
 		initializer = expression();
 	}
 
@@ -234,8 +277,10 @@ stmt_t var_declaration(void)
 
 stmt_t declaration(void)
 {
-	if (check(TOKEN_VAR))
+	if (check(TOKEN_VAR)) {
+		advance();
 		return var_declaration();
+	}
 
 	return statement();
 }
@@ -258,6 +303,9 @@ void free_statements(stmt_array_t *array)
 		if (array->statements[i].type == STMT_EXPR) {
 			free_expr(array->statements[i].as.expr.expression);
 		}
+		if (array->statements[i].type == STMT_VAR) {
+			free_expr(array->statements[i].as.variable.initializer);
+		}
 	}
 	free(array->statements);
 	free(array);