diff --git a/.gitignore b/.gitignore index 4f8f440..0b1fb0c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,3 @@ -zsm -zsmc +bin *.o *.tar.gz diff --git a/Makefile b/Makefile index 89efcad..130729e 100644 --- a/Makefile +++ b/Makefile @@ -3,31 +3,36 @@ CC = cc VERSION = 1.0 -SERVER = zsm -CLIENT = zsmc +SERVER = zmr +CLIENT = zen +TARGET = zsm MANPAGE = $(TARGET).1 PREFIX ?= /usr/local BINDIR = $(PREFIX)/bin MANDIR = $(PREFIX)/share/man/man1 # Flags -LDFLAGS = $(shell pkg-config --libs libsodium libnotify ncurses) -CFLAGS = -O3 -mtune=native -march=native -pipe -g -std=c99 -Wno-pointer-sign -Wpedantic -Wall -D_DEFAULT_SOURCE -D_XOPEN_SOURCE=600 $(shell pkg-config --cflags libsodium libnotify ncurses) -lpthread +LDFLAGS = $(shell pkg-config --libs libsodium libnotify ncurses sqlite3) +CFLAGS = -O3 -mtune=native -march=native -pipe -g -std=c99 -Wno-pointer-sign -pedantic -Wall -D_DEFAULT_SOURCE -D_XOPEN_SOURCE=600 $(shell pkg-config --cflags libsodium libnotify ncurses sqlite3) -lpthread SERVERSRC = src/server/*.c CLIENTSRC = src/client/*.c LIBSRC = lib/*.c INCLUDE = -Iinclude/ +all: $(SERVER) $(CLIENT) + $(SERVER): $(SERVERSRC) $(LIBSRC) - $(CC) $(SERVERSRC) $(LIBSRC) $(INCLUDE) -o $@ $(CFLAGS) $(LDFLAGS) + mkdir -p bin + $(CC) $(SERVERSRC) $(LIBSRC) $(INCLUDE) -o bin/$@ $(CFLAGS) $(LDFLAGS) $(CLIENT): $(CLIENTSRC) $(LIBSRC) - $(CC) $(CLIENTSRC) $(LIBSRC) $(INCLUDE) -o $@ $(CFLAGS) $(LDFLAGS) + mkdir -p bin + $(CC) $(CLIENTSRC) $(LIBSRC) $(INCLUDE) -o bin/$@ $(CFLAGS) $(LDFLAGS) dist: mkdir -p $(TARGET)-$(VERSION) - cp -R README.md $(MANPAGE) $(TARGET) $(TARGET)-$(VERSION) + cp -R README.md $(MANPAGE) $(SERVER) $(CLIENT) $(TARGET)-$(VERSION) tar -cf $(TARGET)-$(VERSION).tar $(TARGET)-$(VERSION) gzip $(TARGET)-$(VERSION).tar rm -rf $(TARGET)-$(VERSION) @@ -35,18 +40,22 @@ dist: install: $(TARGET) mkdir -p $(DESTDIR)$(BINDIR) mkdir -p $(DESTDIR)$(MANDIR) - cp -p $(TARGET) $(DESTDIR)$(BINDIR)/$(TARGET) - chmod 755 $(DESTDIR)$(BINDIR)/$(TARGET) + cp -p bin/$(TARGET) $(DESTDIR)$(BINDIR)/$(SERVER) + chmod 755 $(DESTDIR)$(BINDIR)/$(SERVER) + cp -p bin/$(TARGET) $(DESTDIR)$(BINDIR)/$(CLIENT) + chmod 755 $(DESTDIR)$(BINDIR)/$(CLIENT) cp -p $(MANPAGE) $(DESTDIR)$(MANDIR)/$(MANPAGE) chmod 644 $(DESTDIR)$(MANDIR)/$(MANPAGE) uninstall: - $(RM) $(DESTDIR)$(BINDIR)/$(TARGET) + $(RM) $(DESTDIR)$(BINDIR)/$(SERVER) + $(RM) $(DESTDIR)$(BINDIR)/$(CLIENT) $(RM) $(DESTDIR)$(MANDIR)/$(MANPAGE) clean: - $(RM) $(TARGET) + $(RM) $(SERVER) $(CLIENT) -all: $(TARGET) +run: + ./bin/zen .PHONY: all dist install uninstall clean diff --git a/include/client/db.h b/include/client/db.h new file mode 100644 index 0000000..3d201ad --- /dev/null +++ b/include/client/db.h @@ -0,0 +1,8 @@ +#ifndef DB_H_ +#define DB_H_ + +#include + +int sqlite_init(); + +#endif diff --git a/include/client/ui.h b/include/client/ui.h new file mode 100644 index 0000000..37f37e8 --- /dev/null +++ b/include/client/ui.h @@ -0,0 +1,12 @@ +#ifndef UI_H_ +#define UI_H_ + +#include + +void ncurses_init(); +void windows_init(); +void draw_border(WINDOW *window, bool active); +void add_username(char *username); +void ui(); + +#endif diff --git a/include/client/user.h b/include/client/user.h new file mode 100644 index 0000000..967baa4 --- /dev/null +++ b/include/client/user.h @@ -0,0 +1,27 @@ +#ifndef USER_H_ +#define USER_H_ + +#include +#include +#include + +typedef struct user { + char *name; + wchar_t *icon; + int color; +} user; + +typedef struct ArrayList { + size_t length; + size_t capacity; + user *items; +} ArrayList; + +ArrayList *arraylist_init(size_t capacity); +void arraylist_free(ArrayList *list); +long arraylist_search(ArrayList *list, char *username); +void arraylist_remove(ArrayList *list, long index); +void arraylist_add(ArrayList *list, char *name, wchar_t *icon, int color, bool marked, bool force); +char *get_line(ArrayList *list, long index, bool icons); + +#endif diff --git a/include/config.h b/include/config.h new file mode 100644 index 0000000..6cf23fb --- /dev/null +++ b/include/config.h @@ -0,0 +1,25 @@ +/* Server */ +#define DEBUG 0 +#define PORT 20247 +#define MAX_NAME 32 /* Max username length */ +#define DATABASE_NAME "test.db" +#define MAX_MESSAGE_LENGTH 8192 + +/* Don't touch unless you know what you are doing */ +#define MAX_CONNECTION_QUEUE 128 +#define MAX_THREADS 8 +#define MAX_EVENTS 64 /* Max events can be returned simulataneouly by epoll */ +#define MAX_CLIENTS_PER_THREAD 1024 + +/* Client */ +#define DOMAIN "127.0.0.1" + +/* UI */ +#define PANEL_HEIGHT 1 +#define DRAW_PREVIEW 1 + +#define CLIENT_DATA_DIR "~/.local/share/zsm/zen" + +/* Keybindings */ +#define DOWN 0x102 +#define UP 0x103 diff --git a/include/ht.h b/include/ht.h new file mode 100644 index 0000000..5eaad5f --- /dev/null +++ b/include/ht.h @@ -0,0 +1,18 @@ +#ifndef HT_H_ +#define HT_H_ + +#define TABLE_SIZE 100 + +typedef struct client { + int id; + char name[32]; + //pthread_t thread; +} client; + +unsigned int hash(char *name); +void hashtable_init(); +void hashtable_print(); +int hashtable_add(client *p); +client *hashtable_search(char *name); + +#endif diff --git a/include/key.h b/include/key.h new file mode 100644 index 0000000..d736c23 --- /dev/null +++ b/include/key.h @@ -0,0 +1,39 @@ +#ifndef KEY_H_ +#define KEY_H_ + +#include + +#include "config.h" + +#define TIME_SIZE sizeof(time_t) +#define SIGN_SIZE crypto_sign_BYTES +#define PK_BIN_SIZE crypto_kx_PUBLICKEYBYTES +#define SK_BIN_SIZE crypto_sign_SECRETKEYBYTES +#define METADATA_SIZE MAX_NAME + TIME_SIZE +#define PK_SIZE PK_BIN_SIZE + METADATA_SIZE + SIGN_SIZE +#define SK_SIZE SK_BIN_SIZE + METADATA_SIZE + SIGN_SIZE +#define SHARED_SIZE crypto_kx_SESSIONKEYBYTES + +typedef struct public_key { + uint8_t bin[PK_BIN_SIZE]; + uint8_t username[MAX_NAME]; + time_t creation; + uint8_t signature[SIGN_SIZE]; +} public_key; + +typedef struct secret_key { + uint8_t bin[SK_BIN_SIZE]; + uint8_t username[MAX_NAME]; + time_t creation; + uint8_t signature[SIGN_SIZE]; +} secret_key; + +typedef struct key_pair { + public_key pk; + secret_key sk; +} key_pair; + +key_pair *create_key_pair(char *username); +key_pair *get_key_pair(char *username); + +#endif diff --git a/include/notification.h b/include/notification.h index 23e9a54..8a81a50 100644 --- a/include/notification.h +++ b/include/notification.h @@ -1,8 +1,9 @@ #ifndef NOTIFICATION_H #define NOTIFICATION_H +#include #include -void send_notification(const char *content); +void send_notification(uint8_t *content); #endif diff --git a/include/packet.h b/include/packet.h index 2748b2a..eba22ac 100644 --- a/include/packet.h +++ b/include/packet.h @@ -6,29 +6,34 @@ #include #include #include +#include #include #include #include +#include #include +#include +#include +#include #include +#include #include +#include +#include #include +#include +#include + +#include "config.h" -#define DEBUG 1 -#define DOMAIN "127.0.0.1" -#define PORT 20247 -#define MAX_CONNECTION 5 -#define MAX_MESSAGE_LENGTH 8192 #define ERROR_LENGTH 26 -#define ZSM_TYP_KEY 0x1 -#define ZSM_TYP_SEND_MESSAGE 0x2 +#define ZSM_TYP_AUTH 0x1 +#define ZSM_TYP_MESSAGE 0x2 #define ZSM_TYP_UPDATE_MESSAGE 0x3 #define ZSM_TYP_DELETE_MESSAGE 0x4 -#define ZSM_TYP_PRESENCE 0x5 -#define ZSM_TYP_TYPING 0x6 -#define ZSM_TYP_ERROR 0x7 /* Error message */ -#define ZSM_TYP_B 0x8 +#define ZSM_TYP_ERROR 0x5 +#define ZSM_TYP_INFO 0x6 #define ZSM_STA_SUCCESS 0x1 #define ZSM_STA_INVALID_TYPE 0x2 @@ -38,32 +43,40 @@ #define ZSM_STA_WRITING_SOCKET 0x6 #define ZSM_STA_UNKNOWN_USER 0x7 #define ZSM_STA_MEMORY_ALLOCATION 0x8 -#define ZSM_STA_WRONG_KEY_LENGTH 0x9 +#define ZSM_STA_ERROR_ENCRYPT 0x9 +#define ZSM_STA_ERROR_DECRYPT 0xA +#define ZSM_STA_ERROR_AUTHENTICATE 0xB +#define ZSM_STA_ERROR_INTEGRITY 0xC +#define ZSM_STA_UNAUTHORISED 0xD +#define ZSM_STA_AUTHORISED 0xE -#define PUBLIC_KEY_SIZE crypto_kx_PUBLICKEYBYTES -#define PRIVATE_KEY_SIZE crypto_kx_SECRETKEYBYTES -#define SHARED_KEY_SIZE crypto_kx_SESSIONKEYBYTES +#define ADDRESS_SIZE MAX_NAME + 1 + 255 /* 1 for @, 255 for domain, defined in RFC 5321, Section 4.5.3.1.2 */ +#define CHALLENGE_SIZE 32 +#define HASH_SIZE crypto_generichash_BYTES #define NONCE_SIZE crypto_aead_xchacha20poly1305_ietf_NPUBBYTES #define ADDITIONAL_SIZE crypto_aead_xchacha20poly1305_ietf_ABYTES -typedef struct message { - uint8_t option; +typedef struct packet { + uint8_t status; uint8_t type; - unsigned long long length; - unsigned char *data; -} message; + uint32_t length; + uint8_t *data; + uint8_t *signature; +} packet; + +#include "key.h" /* Utilities functions */ -void error(int fatal, const char *fmt, ...); -void *memalloc(size_t size); -void *estrdup(void *str); -unsigned char *get_public_key(int sockfd); -int send_public_key(int sockfd, unsigned char *pk); -void print_packet(message *msg); -int recv_packet(message *msg, int fd); -message *create_error_packet(int code); -message *create_packet(uint8_t option, uint8_t type, uint32_t length, char *data); -int send_packet(message *msg, int fd); -void free_packet(message *msg); +void print_packet(packet *msg); +int recv_packet(packet *pkt, int fd, uint8_t required_type); +packet *create_packet(uint8_t option, uint8_t type, uint32_t length, uint8_t *data, uint8_t *signature); +int send_packet(packet *msg, int fd); +void free_packet(packet *msg); +int encrypt_packet(int sockfd, key_pair *kp); +packet *verify_packet(packet *pkt, int fd); +uint8_t *encrypt_data(uint8_t *from, uint8_t *to, uint8_t *raw, uint32_t raw_length, uint32_t *length); +uint8_t *decrypt_data(packet *pkt); +int verify_integrity(packet *pkt, public_key *pk); +uint8_t *create_signature(uint8_t *data, uint32_t length, secret_key *sk); #endif diff --git a/include/util.h b/include/util.h new file mode 100644 index 0000000..c4defb2 --- /dev/null +++ b/include/util.h @@ -0,0 +1,18 @@ +#ifndef UTIL_H +#define UTIL_H + +#define LOG_ERROR 1 +#define LOG_INFO 2 + +#define PATH_MAX 4096 + +void error(int fatal, const char *fmt, ...); +void *memalloc(size_t size); +void *estrdup(void *str); +int set_nonblocking(int fd); +char *replace_home(char *str); +void mkdir_p(const char *destdir); +void write_log(int type, const char *fmt, ...); +void print_bin(const unsigned char *ptr, size_t length); + +#endif diff --git a/lib/key.c b/lib/key.c new file mode 100644 index 0000000..e14fe4c --- /dev/null +++ b/lib/key.c @@ -0,0 +1,97 @@ +#include "packet.h" +#include "key.h" +#include "util.h" + +key_pair *create_key_pair(char *username) +{ + uint8_t cl_pk_bin[PK_BIN_SIZE], cl_sk_bin[SK_BIN_SIZE]; + crypto_sign_keypair(cl_pk_bin, cl_sk_bin); + char pk_path[PATH_MAX], sk_path[PATH_MAX]; + + /* USE DB INSTEAD OF FILES */ + sprintf(pk_path, "/home/night/%s_pk", username); + sprintf(sk_path, "/home/night/%s_sk", username); + FILE *pkf = fopen(pk_path, "w+"); + FILE *skf = fopen(sk_path, "w+"); + + uint8_t pk_content[PK_SIZE], sk_content[SK_SIZE], metadata[METADATA_SIZE]; + time_t current_time = time(NULL); + + uint8_t *username_padded = memalloc(MAX_NAME * sizeof(uint8_t)); + strcpy(username_padded, username); + size_t length = strlen(username); + if (length < MAX_NAME) { + /* Pad with null characters up to max length */ + memset(username_padded + length, 0, MAX_NAME - length); + } + memcpy(metadata, username_padded, MAX_NAME); + memcpy(metadata + MAX_NAME, ¤t_time, TIME_SIZE); + uint8_t *hash = memalloc(HASH_SIZE * sizeof(uint8_t)); + uint8_t *sign = memalloc(SIGN_SIZE * sizeof(uint8_t)); + crypto_generichash(hash, HASH_SIZE, metadata, METADATA_SIZE, NULL, 0); + crypto_sign_detached(sign, NULL, hash, HASH_SIZE, cl_sk_bin); + memcpy(pk_content, cl_pk_bin, PK_BIN_SIZE); + memcpy(pk_content + PK_BIN_SIZE, metadata, METADATA_SIZE); + memcpy(pk_content + PK_BIN_SIZE + METADATA_SIZE, sign, SIGN_SIZE); + memcpy(sk_content, cl_sk_bin, SK_BIN_SIZE); + memcpy(sk_content + SK_BIN_SIZE, metadata, METADATA_SIZE); + memcpy(sk_content + SK_BIN_SIZE + METADATA_SIZE, sign, SIGN_SIZE); + free(hash); + + fwrite(pk_content, 1, PK_SIZE, pkf); + fwrite(sk_content, 1, SK_SIZE, skf); + + fclose(pkf); + fclose(skf); + + key_pair *kp = memalloc(sizeof(key_pair)); + memcpy(kp->pk.bin, cl_pk_bin, PK_BIN_SIZE); + memcpy(kp->pk.username, username_padded, MAX_NAME); + kp->pk.creation = current_time; + memcpy(kp->pk.signature, sign, SIGN_SIZE); + + memcpy(kp->sk.bin, cl_sk_bin, SK_BIN_SIZE); + memcpy(kp->sk.username, username_padded, MAX_NAME); + kp->sk.creation = current_time; + memcpy(kp->sk.signature, sign, SIGN_SIZE); + + free(username_padded); + free(sign); + return kp; +} + +key_pair *get_key_pair(char *username) +{ + char pk_path[PATH_MAX], sk_path[PATH_MAX]; + sprintf(pk_path, "/home/night/%s_pk", username); + sprintf(sk_path, "/home/night/%s_sk", username); + + FILE *pkf = fopen(pk_path, "r"); + FILE *skf = fopen(sk_path, "r"); + + if (!pkf || !skf) { + printf("Error opening key files.\n"); + return NULL; + } + + uint8_t pk_content[PK_SIZE], sk_content[SK_SIZE]; + fread(pk_content, 1, PK_SIZE, pkf); + fread(sk_content, 1, SK_SIZE, skf); + + fclose(pkf); + fclose(skf); + + key_pair *kp = memalloc(sizeof(key_pair)); + + memcpy(kp->pk.bin, pk_content, PK_BIN_SIZE); + memcpy(kp->pk.username, pk_content + PK_BIN_SIZE, MAX_NAME); + memcpy(&kp->pk.creation, pk_content + PK_BIN_SIZE + MAX_NAME, TIME_SIZE); + memcpy(kp->pk.signature, pk_content + PK_BIN_SIZE + MAX_NAME + TIME_SIZE, SIGN_SIZE); + + memcpy(kp->sk.bin, sk_content, SK_BIN_SIZE); + memcpy(kp->sk.username, sk_content + SK_BIN_SIZE, MAX_NAME); + memcpy(&kp->sk.creation, sk_content + SK_BIN_SIZE + MAX_NAME, TIME_SIZE); + memcpy(kp->sk.signature, sk_content + SK_BIN_SIZE + MAX_NAME + TIME_SIZE, SIGN_SIZE); + + return kp; +} diff --git a/lib/notification.c b/lib/notification.c index a906f38..840757a 100644 --- a/lib/notification.c +++ b/lib/notification.c @@ -1,8 +1,16 @@ #include "notification.h" +#include "util.h" -void send_notification(const char *content) +void send_notification(uint8_t *content) { - NotifyNotification *noti = notify_notification_new("Client", content, "dialog-information"); - notify_notification_show(noti, NULL); - g_object_unref(G_OBJECT(noti)); + printf("Content: %s\n", content); + NotifyNotification *notification = notify_notification_new("Client", + (char *) content, "dialog-information"); + if (notification == NULL) { + error(0, "Cannot create notification"); + } + if (!notify_notification_show(notification, NULL)) { + error(0, "Cannot show notifcation"); + } + g_object_unref(G_OBJECT(notification)); } diff --git a/lib/packet.c b/lib/packet.c index 686e8f1..21fec35 100644 --- a/lib/packet.c +++ b/lib/packet.c @@ -1,273 +1,436 @@ #include "packet.h" +#include "key.h" +#include "util.h" -/* - * msg is the error message to print to stderr - * will include error message from function if errno isn't 0 - * end program is fatal is 1 - */ -void error(int fatal, const char *fmt, ...) +void print_packet(packet *pkt) { - va_list args; - va_start(args, fmt); - - /* to preserve errno */ - int errsv = errno; - - /* Determine the length of the formatted error message */ - va_list args_copy; - va_copy(args_copy, args); - size_t error_len = vsnprintf(NULL, 0, fmt, args_copy); - va_end(args_copy); - - /* 7 for [zsm], space and null */ - char errorstr[error_len + 1]; - vsnprintf(errorstr, error_len + 1, fmt, args); - fprintf(stderr, "[zsm] "); - - if (errsv != 0) { - perror(errorstr); - errno = 0; - } else { - fprintf(stderr, "%s\n", errorstr); - } - - va_end(args); - if (fatal) exit(1); -} - -void *memalloc(size_t size) -{ - void *ptr = malloc(size); - if (!ptr) { - error(0, "Error allocating memory"); - return NULL; - } - return ptr; -} - -void *estrdup(void *str) -{ - void *modstr = strdup(str); - if (modstr == NULL) { - error(0, "Error allocating memory"); - return NULL; - } - return modstr; -} - -uint8_t *get_public_key(int sockfd) -{ - message keyex_msg; - if (recv_packet(&keyex_msg, sockfd) != ZSM_STA_SUCCESS) { - /* We can't do anything if key exchange already failed */ - close(sockfd); - return NULL; - } else { - int status = 0; - /* Check to see if the content is actually a key */ - if (keyex_msg.type != ZSM_TYP_KEY) { - status = ZSM_STA_INVALID_TYPE; - } - if (keyex_msg.length != PUBLIC_KEY_SIZE) { - status = ZSM_STA_WRONG_KEY_LENGTH; - } - if (status != 0) { - free(keyex_msg.data); - message *error_msg = create_error_packet(status); - send_packet(error_msg, sockfd); - free_packet(error_msg); - close(sockfd); - return NULL; - } - } - /* Obtain public key from packet */ - uint8_t *pk = memalloc(PUBLIC_KEY_SIZE * sizeof(char)); - memcpy(pk, keyex_msg.data, PUBLIC_KEY_SIZE); - if (pk == NULL) { - free(keyex_msg.data); - /* Fatal, we couldn't complete key exchange */ - close(sockfd); - return NULL; - } - free(keyex_msg.data); - return pk; -} - -int send_public_key(int sockfd, uint8_t *pk) -{ - /* send_packet requires heap allocated buffer */ - uint8_t *pk_dup = memalloc(PUBLIC_KEY_SIZE * sizeof(char)); - memcpy(pk_dup, pk, PUBLIC_KEY_SIZE); - if (pk_dup == NULL) { - close(sockfd); - return -1; - } - - /* Sending our public key to client */ - /* option???? */ - message *keyex = create_packet(1, ZSM_TYP_KEY, PUBLIC_KEY_SIZE, pk_dup); - send_packet(keyex, sockfd); - free_packet(keyex); - return 0; -} - -void print_packet(message *msg) -{ - printf("Option: %d\n", msg->option); - printf("Type: %d\n", msg->type); - printf("Length: %lld\n", msg->length); - printf("Data: %s\n\n", msg->data); + printf("Status: %d\n", pkt->status); + printf("Type: %d\n", pkt->type); + printf("Length: %d\n", pkt->length); + if (pkt->length > 0) { + printf("Data:\n"); + for (int i = 0; i < pkt->length; i++) { + printf("%d ", pkt->data[i]); + } + printf("\n"); + printf("Signature:\n"); + for (int i = 0; i < SIGN_SIZE; i++) { + printf("%d ", pkt->signature[i]); + } + printf("\n"); + } } /* * Requires manually free message data + * pkt: packet to fill data in (must be created via create_packet) + * fd: file descriptor to read data from + * required_type: Required packet type to receive, set 0 to not check */ -int recv_packet(message *msg, int fd) +int recv_packet(packet *pkt, int fd, uint8_t required_type) { int status = ZSM_STA_SUCCESS; /* Read the message components */ - if (recv(fd, &msg->option, sizeof(msg->option), 0) < 0 || - recv(fd, &msg->type, sizeof(msg->type), 0) < 0 || - recv(fd, &msg->length, sizeof(msg->length), 0) < 0) { + if (recv(fd, &pkt->status, sizeof(pkt->status), 0) < 0 || + recv(fd, &pkt->type, sizeof(pkt->type), 0) < 0 || + recv(fd, &pkt->length, sizeof(pkt->length), 0) < 0) { status = ZSM_STA_READING_SOCKET; error(0, "Error reading from socket"); } #if DEBUG == 1 - printf("==========PACKET RECEIVED==========\n"); - #endif - #if DEBUG == 1 - printf("Option: %d\n", msg->option); + printf("==========PACKET RECEIVED==========\n"); + printf("Status: %d\n", pkt->status); #endif - if (msg->type > 0xFF || msg->type < 0x0) { + if (pkt->type > 0xFF || pkt->type < 0x0) { status = ZSM_STA_INVALID_TYPE; error(0, "Invalid message type"); goto failure; } #if DEBUG == 1 - printf("Type: %d\n", msg->type); + printf("Type: %d\n", pkt->type); #endif - /* Convert message length from network byte order to host byte order */ - if (msg->length > MAX_MESSAGE_LENGTH) { + /* Not the same type as wanted to receive */ + if (pkt->type != required_type) { + status = ZSM_STA_INVALID_TYPE; + error(0, "Invalid message type"); + goto failure; + } + + if (pkt->length > MAX_MESSAGE_LENGTH) { status = ZSM_STA_TOO_LONG; - error(0, "Message too long: %lld", msg->length); + error(0, "Message too long: %d", pkt->length); goto failure; } #if DEBUG == 1 - printf("Length: %lld\n", msg->length); + printf("Length: %d\n", pkt->length); #endif + size_t bytes_read = 0; - // Allocate memory for message data - msg->data = memalloc((msg->length + 1) * sizeof(char)); - if (msg->data == NULL) { - status = ZSM_STA_MEMORY_ALLOCATION; - goto failure; - } + /* If packet's length is 0, ignore its data and signature as it is information from server */ + if (pkt->type != ZSM_TYP_INFO && pkt->length > 0) { + pkt->data = memalloc((pkt->length + 1) * sizeof(char)); + if (pkt->data == NULL) { + status = ZSM_STA_MEMORY_ALLOCATION; + goto failure; + } - /* Read message data from the socket */ - size_t bytes_read = 0; - if ((bytes_read = recv(fd, msg->data, msg->length, 0)) < 0) { - status = ZSM_STA_READING_SOCKET; - error(0, "Error reading from socket"); - free(msg->data); - goto failure; - } - if (bytes_read != msg->length) { - status = ZSM_STA_INVALID_LENGTH; - error(0, "Invalid message length: bytes_read=%ld != msg->length=%lld", bytes_read, msg->length); - free(msg->data); - goto failure; - } - msg->data[msg->length] = '\0'; + /* Read message data from the socket */ + if ((bytes_read = recv(fd, pkt->data, pkt->length, 0)) < 0) { + status = ZSM_STA_READING_SOCKET; + error(0, "Error reading from socket"); + free(pkt->data); + goto failure; + } + if (bytes_read != pkt->length) { + status = ZSM_STA_INVALID_LENGTH; + error(0, "Invalid message length: bytes_read=%ld != pkt->length=%d", bytes_read, pkt->length); + free(pkt->data); + goto failure; + } + pkt->data[pkt->length] = '\0'; + #if DEBUG == 1 + printf("Data:\n"); + for (int i = 0; i < pkt->length; i++) { + printf("%d ", pkt->data[i]); + } + printf("\n"); + #endif - #if DEBUG == 1 - printf("Data: %s\n\n", msg->data); + pkt->signature = memalloc((SIGN_SIZE + 1) * sizeof(char)); + if (pkt->signature == NULL) { + status = ZSM_STA_MEMORY_ALLOCATION; + goto failure; + } + + if ((bytes_read = recv(fd, pkt->signature, SIGN_SIZE, 0)) < 0) { + status = ZSM_STA_READING_SOCKET; + error(0, "Error reading from socket"); + free(pkt->data); + goto failure; + } + /* Don't check signature if the packet is emtpy */ + if (pkt->length > 0 && bytes_read != SIGN_SIZE) { + status = ZSM_STA_INVALID_LENGTH; + error(0, "Invalid signature length: bytes_read=%ld != SIGN_SIZE(32)", bytes_read); + free(pkt->data); + goto failure; + } + pkt->signature[SIGN_SIZE] = '\0'; + + #if DEBUG == 1 + printf("Signature:\n"); + for (int i = 0; i < SIGN_SIZE; i++) { + printf("%d ", pkt->signature[i]); + } + printf("\n"); + #endif + } + #if DEBUG == 1 + printf("==========END RECEIVING============\n"); #endif return status; + failure:; - message *error_msg = create_error_packet(status); - if (send_packet(error_msg, fd) != ZSM_STA_SUCCESS) { + packet *error_pkt = create_packet(status, ZSM_TYP_ERROR, 0, NULL, + create_signature(NULL, 0, NULL)); + + if (send_packet(error_pkt, fd) != ZSM_STA_SUCCESS) { /* Resend it? */ - error(0, "Failed to send error packet to peer. Error status => %d", status); + error(0, "Failed to send error packet. Error status => %d", status); } - free_packet(error_msg); + free_packet(error_pkt); return status; } -message *create_error_packet(int code) +/* + * Creates a packet for receive or send + * Requires heap allocated data + */ +packet *create_packet(uint8_t status, uint8_t type, uint32_t length, uint8_t *data, uint8_t *signature) { - char *err = memalloc(ERROR_LENGTH * sizeof(char)); - switch (code) { - case ZSM_STA_INVALID_TYPE: - strcpy(err, "Invalid message type "); - break; - case ZSM_STA_INVALID_LENGTH: - strcpy(err, "Invalid message length "); - break; - case ZSM_STA_TOO_LONG: - strcpy(err, "Message too long "); - break; - case ZSM_STA_READING_SOCKET: - strcpy(err, "Error reading from socket"); - break; - case ZSM_STA_WRITING_SOCKET: - strcpy(err, "Error writing to socket "); - break; - case ZSM_STA_UNKNOWN_USER: - strcpy(err, "Unknwon user "); - break; - case ZSM_STA_WRONG_KEY_LENGTH: - strcpy(err, "Wrong public key length "); - break; - } - return create_packet(1, ZSM_TYP_ERROR, ERROR_LENGTH, err); + packet *pkt = memalloc(sizeof(packet)); + pkt->status = status; + pkt->type = type; + pkt->length = length; + pkt->data = data; + pkt->signature = signature; + return pkt; } /* - * Requires heap allocated msg data + * Sends packet to fd + * Requires heap allocated data + * Close file descriptor and free data on failure */ -message *create_packet(uint8_t option, uint8_t type, uint32_t length, char *data) -{ - message *msg = memalloc(sizeof(message)); - msg->option = option; - msg->type = type; - msg->length = length; - msg->data = data; - return msg; -} - -/* - * Requires heap allocated msg data - */ -int send_packet(message *msg, int fd) +int send_packet(packet *pkt, int fd) { int status = ZSM_STA_SUCCESS; - uint32_t length = msg->length; - // Send the message back to the client - if (send(fd, &msg->option, sizeof(msg->option), 0) <= 0 || - send(fd, &msg->type, sizeof(msg->type), 0) <= 0 || - send(fd, &msg->length, sizeof(msg->length), 0) <= 0 || - send(fd, msg->data, length, 0) <= 0) { - status = ZSM_STA_WRITING_SOCKET; - error(0, "Error writing to socket"); - //free(msg->data); - close(fd); // Close the socket and continue accepting connections - } + uint32_t length = pkt->length; + if (send(fd, &pkt->status, sizeof(pkt->status), 0) <= 0 || + send(fd, &pkt->type, sizeof(pkt->type), 0) <= 0 || + send(fd, &pkt->length, sizeof(pkt->length), 0) <= 0) + { + goto failure; + } + if (pkt->type != ZSM_TYP_INFO && pkt->length > 0 && pkt->data != NULL) { + if (send(fd, pkt->data, length, 0) <= 0) goto failure; + if (send(fd, pkt->signature, SIGN_SIZE, 0) <= 0) goto failure; + } + #if DEBUG == 1 - printf("==========PACKET SENT==========\n"); - print_packet(msg); + printf("==========PACKET SENT============\n"); + print_packet(pkt); + printf("==========END SENT===============\n"); #endif return status; + +failure: + /* Or we could resend it? */ + status = ZSM_STA_WRITING_SOCKET; + error(0, "Error writing to socket"); + free(pkt->data); + close(fd); + return status; } -void free_packet(message *msg) +/* + * Free allocated memory in packet + */ +void free_packet(packet *pkt) { - if (msg->type != 0x10) { - /* temp solution, dont use stack allocated msg to send to client */ - free(msg->data); + if (pkt->type != ZSM_TYP_AUTH) { + if (pkt->signature != NULL) { + free(pkt->signature); + } } - free(msg); + free(pkt->data); + free(pkt); +} + +/* + * not going to stay + */ +char *getuserinput() +{ + printf("Enter message to send: "); + fflush(stdout); + char *line = memalloc(1024); + line[0] = '\0'; + size_t length = strlen(line); + while (length <= 1) { + fgets(line, 1024, stdin); + length = strlen(line); + } + length -= 1; + line[length] = '\0'; + return line; +} + +/* + * not going to stay + */ +char *getrecipient() +{ + printf("Enter recipient: "); + fflush(stdout); + char *line = memalloc(32); + line[0] = '\0'; + size_t length = strlen(line); + while (length <= 1) { + fgets(line, 1024, stdin); + length = strlen(line); + } + length -= 1; + line[length] = '\0'; + return line; +} + + +int encrypt_packet(int sockfd, key_pair *kp) +{ + int status = ZSM_STA_SUCCESS; + char *line = getuserinput(); + uint8_t *recipient = getrecipient(); + uint32_t data_len; + uint8_t *raw_data = memalloc(8192); + size_t length = strlen(recipient); + size_t length_line = strlen(line); + if (length < MAX_NAME) { + /* Pad with null characters up to max length */ + memset(recipient + length, 0, MAX_NAME - length); + } + memcpy(raw_data, kp->pk.username, MAX_NAME); + memcpy(raw_data + MAX_NAME, recipient, MAX_NAME); + memcpy(raw_data + MAX_NAME * 2, line, length_line); + size_t raw_data_size = MAX_NAME * 2 + strlen(line); + + uint8_t *data = encrypt_data(kp->pk.username, recipient, raw_data, raw_data_size, &data_len); + uint8_t *signature = create_signature(data, data_len, &kp->sk); + packet *pkt = create_packet(1, ZSM_TYP_MESSAGE, data_len, data, signature); + if ((status = send_packet(pkt, sockfd)) != ZSM_STA_SUCCESS) { + close(sockfd); + return status; + } + free(recipient); + free(line); + free_packet(pkt); + return status; +} + +/* + * Wrapper for recv_packet to verify packet + * Reads packet from fd, stores in pkt + * TODO: pkt is unncessary + */ +packet *verify_packet(packet *pkt, int fd) +{ + if (recv_packet(pkt, fd, ZSM_TYP_MESSAGE) != ZSM_STA_SUCCESS) { + close(fd); + return NULL; + } + + uint8_t from[MAX_NAME], to[MAX_NAME]; + memcpy(from, pkt->data, MAX_NAME); + /* TODO: replace with db operations */ + key_pair *kp_from = get_key_pair(from); + + if (verify_integrity(pkt, &kp_from->pk) != ZSM_STA_SUCCESS) { + free(pkt->data); + free(pkt->signature); + packet *error_pkt = create_packet(ZSM_STA_ERROR_INTEGRITY, ZSM_TYP_ERROR, 0, NULL, + create_signature(NULL, 0, NULL)); + send_packet(error_pkt, fd); + free_packet(error_pkt); + return NULL; + } + return pkt; +} + +/* + * Encrypt raw with raw_length using to + * length is set to sum length of random bytes and scrambled data + */ +uint8_t *encrypt_data(uint8_t *from, uint8_t *to, uint8_t *raw, uint32_t raw_length, uint32_t *length) +{ + key_pair *kp_from = get_key_pair(from); + key_pair *kp_to = get_key_pair(to); + + uint8_t shared_key[SHARED_SIZE]; + if (crypto_kx_client_session_keys(shared_key, NULL, + kp_from->pk.bin, kp_from->sk.bin, kp_to->pk.bin) != 0) { + /* Suspicious server public key, bail out */ + error(0, "Error performing key exchange"); + } + + uint8_t nonce[NONCE_SIZE]; + uint32_t encrypted_len = raw_length + ADDITIONAL_SIZE; + uint8_t encrypted[encrypted_len]; + + /* Generate random nonce(number used once) */ + printf("raw: %s\n", raw); + randombytes_buf(nonce, sizeof(nonce)); + crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, NULL, raw, + raw_length, NULL, 0, NULL, nonce, shared_key); + size_t data_len = MAX_NAME * 2 + NONCE_SIZE + encrypted_len; + *length = data_len; + uint8_t *data = memalloc(data_len * sizeof(uint8_t)); + memcpy(data, kp_from->sk.username, MAX_NAME); + memcpy(data + MAX_NAME, kp_to->sk.username, MAX_NAME); + memcpy(data + MAX_NAME * 2, nonce, NONCE_SIZE); + memcpy(data + MAX_NAME * 2 + NONCE_SIZE, encrypted, encrypted_len); + + return data; +} + +/* + * Should be used by clients + */ +uint8_t *decrypt_data(packet *pkt) +{ + size_t encrypted_len = pkt->length - NONCE_SIZE - MAX_NAME * 2; + size_t data_len = encrypted_len - ADDITIONAL_SIZE; + uint8_t nonce[NONCE_SIZE], from[MAX_NAME], to[MAX_NAME], encrypted[encrypted_len]; + uint8_t *decrypted = memalloc((data_len + 1) * sizeof(uint8_t)); + memcpy(from, pkt->data, MAX_NAME); + memcpy(to, pkt->data + MAX_NAME, MAX_NAME); + memcpy(nonce, pkt->data + MAX_NAME * 2, NONCE_SIZE); + memcpy(encrypted, pkt->data + MAX_NAME * 2 + NONCE_SIZE, encrypted_len); + + key_pair *kp_from = get_key_pair(from); + key_pair *kp_to = get_key_pair(to); + + uint8_t shared_key[SHARED_SIZE]; + if (crypto_kx_client_session_keys(shared_key, NULL, + kp_from->pk.bin, kp_from->sk.bin, kp_to->pk.bin) != 0) { + /* Suspicious server public key, bail out */ + error(0, "Error performing key exchange"); + } + + /* We don't need it anymore */ + free(pkt->data); + if (crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted, NULL, + NULL, encrypted, + encrypted_len, + NULL, 0, + nonce, shared_key) != 0) { + free(decrypted); + error(0, "Cannot decrypt message"); + return NULL; + } else { + /* Terminate decrypted message so we don't print random bytes */ + decrypted[data_len] = '\0'; + printf("<%s> to <%s>: %s\n", from, to, decrypted + MAX_NAME * 2); + return decrypted; + } +} + + +/* + * Verify message integrity + confidentiality + * Verify using public key against hashed message + */ +int verify_integrity(packet *pkt, public_key *pk) +{ + uint8_t hash[HASH_SIZE]; + + /* Hash data to check if matches user provided correct signature */ + + crypto_generichash(hash, HASH_SIZE, + pkt->data, pkt->length, + NULL, 0); + + if (crypto_sign_verify_detached(pkt->signature, hash, HASH_SIZE, pk->bin) != 0) { + /* Not match */ + error(0, "Cannot verify message integrity"); + return ZSM_STA_ERROR_INTEGRITY; + } + + return ZSM_STA_SUCCESS; +} + +/* + * Create signature for packet + * When data, secret is null, length is 0, empty siganture is created + */ +uint8_t *create_signature(uint8_t *data, uint32_t length, secret_key *sk) +{ + uint8_t *signature = memalloc(SIGN_SIZE * sizeof(uint8_t)); + if (data == NULL && length == 0 && sk == NULL) { + /* From server, give fake signature */ + memset(signature, 0, SIGN_SIZE * sizeof(uint8_t)); + } else { + uint8_t hash[HASH_SIZE]; + + /* Hash data to check if matches user provided correct signature */ + crypto_generichash(hash, HASH_SIZE, + data, length, + NULL, 0); + crypto_sign_detached(signature, NULL, hash, HASH_SIZE, sk->bin); + } + + return signature; } diff --git a/lib/util.c b/lib/util.c new file mode 100644 index 0000000..a4e563a --- /dev/null +++ b/lib/util.c @@ -0,0 +1,173 @@ +#include "config.h" +#include "packet.h" +#include "util.h" + +/* + * will include error message from function if errno isn't 0 + * end program is fatal is 1 + */ +void error(int fatal, const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + + /* to preserve errno */ + int errsv = errno; + + /* Determine the length of the formatted error message */ + va_list args_copy; + va_copy(args_copy, args); + size_t error_len = vsnprintf(NULL, 0, fmt, args_copy); + va_end(args_copy); + + /* 7 for [zsm], space and null */ + char errorstr[error_len + 1]; + vsnprintf(errorstr, error_len + 1, fmt, args); + fprintf(stderr, "[zsm] "); + + if (errsv != 0) { + perror(errorstr); + errno = 0; + } else { + fprintf(stderr, "%s\n", errorstr); + } + + va_end(args); + if (fatal) exit(1); +} + +void *memalloc(size_t size) +{ + void *ptr = malloc(size); + if (!ptr) { + write_log(LOG_ERROR, "Error allocating memory\n"); + return NULL; + } + return ptr; +} + +void *estrdup(void *str) +{ + void *modstr = strdup(str); + if (modstr == NULL) { + write_log(LOG_ERROR, "Error allocating memory\n"); + return NULL; + } + return modstr; +} + +/* + * Set socket to non blocking so epoll can use EPOLLET flag + */ +int set_nonblocking(int fd) +{ + int flags = fcntl(fd, F_GETFL, 0); + if (flags == -1) { + return -1; + } + return fcntl(fd, F_SETFL, flags | O_NONBLOCK); +} + +/* + * Takes heap-allocated str and replace ~ with home path + * Returns heap-allocated newstr + */ +char *replace_home(char *str) +{ + char *home = getenv("HOME"); + if (home == NULL) { + write_log(LOG_ERROR, "$HOME not defined\n"); + return str; + } + char *newstr = memalloc((strlen(str) + strlen(home)) * sizeof(char)); + /* replace ~ with home */ + snprintf(newstr, strlen(str) + strlen(home), "%s%s", home, str + 1); + free(str); + return newstr; +} + +/* + * Recursively create directory by creating each subdirectory + * like mkdir -p + */ +void mkdir_p(const char *destdir) +{ + char *path = memalloc(PATH_MAX * sizeof(char)); + char dir_path[PATH_MAX] = ""; + + if (destdir[0] == '~') { + char *home = getenv("HOME"); + if (home == NULL) { + write_log(LOG_ERROR, "$HOME not defined\n"); + return; + } + /* replace ~ with home */ + snprintf(path, PATH_MAX, "%s%s", home, destdir + 1); + } else { + strcpy(path, destdir); + } + + /* fix first / not appearing in the string */ + if (path[0] == '/') + strcat(dir_path, "/"); + + char *token = strtok(path, "/"); + while (token != NULL) { + strcat(dir_path, token); + strcat(dir_path, "/"); + + if (mkdir(dir_path, 0755) == -1) { + struct stat st; + if (stat(dir_path, &st) == 0 && S_ISDIR(st.st_mode)) { + /* Directory already exists, continue to the next dir */ + token = strtok(NULL, "/"); + continue; + } + + write_log(LOG_ERROR, "mkdir failed: %s\n", strerror(errno)); + free(path); + return; + } + token = strtok(NULL, "/"); + } + + free(path); + return; +} + +void write_log(int type, const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + char *client_data_dir = estrdup(CLIENT_DATA_DIR); + mkdir_p(client_data_dir); + client_data_dir = replace_home(client_data_dir); + char *client_log = memalloc(PATH_MAX); + snprintf(client_log, PATH_MAX, "%s/%s", client_data_dir, "zen.log"); + free(client_data_dir); + FILE *log = fopen(client_log, "a"); + if (log != NULL) { + time_t now = time(NULL); + struct tm *t = localtime(&now); + /* either info or error */ + int type_len = type == LOG_INFO ? 4 : 5; + char logtype[4 + type_len]; + snprintf(logtype, 4 + type_len, "[%s] ", type == LOG_INFO ? "INFO" : "ERROR"); + char time[21]; + strftime(time, 22, "%Y-%m-%d %H:%M:%S ", t); + char details[2 + type_len + 22]; + snprintf(details, 2 + type_len + 22, "%s%s", logtype, time); + fprintf(log, details); + vfprintf(log, fmt, args); + } + fclose(log); + va_end(args); +} + +void print_bin(const unsigned char *ptr, size_t length) +{ + for (size_t i = 0; i < length; i++) { + printf("%02x ", ptr[i]); + } + printf("\n"); +} diff --git a/src/client/client.c b/src/client/client.c index 332f924..25b6037 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -1,15 +1,93 @@ +#include "config.h" #include "packet.h" -#include +#include "key.h" +#include "util.h" +#include "client/ui.h" +#include "client/db.h" -uint8_t shared_key[SHARED_KEY_SIZE]; int sockfd; /* - * Connect to socket server + * Authenticate with server by signing a challenge */ -int socket_init() +int authenticate_server(key_pair *kp) { - int sockfd = socket(AF_INET, SOCK_STREAM, 0); + packet server_auth_pkt; + int status; + if ((status = recv_packet(&server_auth_pkt, sockfd, ZSM_TYP_AUTH) != ZSM_STA_SUCCESS)) { + return status; + } + uint8_t *challenge = server_auth_pkt.data; + + uint8_t *sig = memalloc(SIGN_SIZE * sizeof(uint8_t)); + crypto_sign_detached(sig, NULL, challenge, CHALLENGE_SIZE, kp->sk.bin); + + uint8_t *pk_content = memalloc(PK_SIZE); + memcpy(pk_content, kp->pk.bin, PK_BIN_SIZE); + memcpy(pk_content + PK_BIN_SIZE, kp->pk.username, MAX_NAME); + memcpy(pk_content + PK_BIN_SIZE + MAX_NAME, &kp->pk.creation, TIME_SIZE); + memcpy(pk_content + PK_BIN_SIZE + METADATA_SIZE, kp->pk.signature, SIGN_SIZE); + + packet *auth_pkt = create_packet(1, ZSM_TYP_AUTH, SIGN_SIZE, pk_content, sig); + if (send_packet(auth_pkt, sockfd) != ZSM_STA_SUCCESS) { + /* fd already closed */ + error(0, "Could not authenticate with server"); + free(sig); + free_packet(auth_pkt); + return ZSM_STA_ERROR_AUTHENTICATE; + } + + free_packet(auth_pkt); + packet response; + status = recv_packet(&response, sockfd, ZSM_TYP_INFO); + return (response.status == ZSM_STA_AUTHORISED ? ZSM_STA_SUCCESS : ZSM_STA_ERROR_AUTHENTICATE); +} + +/* + * For sending packets to server + */ +void *send_message(void *arg) +{ + key_pair *kp = (key_pair *) arg; + + while (1) { + int status = encrypt_packet(sockfd, kp); + if (status != ZSM_STA_SUCCESS) { + error(1, "Error encrypting packet %x", status); + } + } + + return NULL; +} + +/* + * For receiving packets from server + */ +void *receive_message(void *arg) +{ + key_pair *kp = (key_pair *) arg; + + while (1) { + packet pkt; + + if (verify_packet(&pkt, sockfd) == 0) { + error(0, "Error verifying packet"); + } + uint8_t *decrypted = decrypt_data(&pkt); + free(decrypted); + } + + return NULL; +} + +int main() +{ + if (sodium_init() < 0) { + write_log(LOG_ERROR, "Error initializing libsodium\n"); + } + //ui(); + + sockfd = socket(AF_INET, SOCK_STREAM, 0); if (sockfd < 0) { error(1, "Error on opening socket"); } @@ -19,129 +97,58 @@ int socket_init() error(1, "No such host %s", DOMAIN); } - struct sockaddr_in sv_addr; - memset(&sv_addr, 0, sizeof(sv_addr)); - sv_addr.sin_family = AF_INET; - sv_addr.sin_port = htons(PORT); - memcpy(&sv_addr.sin_addr.s_addr, server->h_addr, server->h_length); + struct sockaddr_in server_addr; + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(PORT); + memcpy(&server_addr.sin_addr.s_addr, server->h_addr, server->h_length); -/* free(server); */ - if (connect(sockfd, (struct sockaddr *) &sv_addr, sizeof(sv_addr)) < 0) { - error(1, "Error on connect"); - close(sockfd); - return 0; - } - printf("Connected to server at %s\n", DOMAIN); - return sockfd; -} - -/* - * Performs key exchange with server - */ -int key_exchange(int sockfd) -{ - /* Generate the client's key pair */ - uint8_t cl_pk[PUBLIC_KEY_SIZE], cl_sk[PRIVATE_KEY_SIZE]; - crypto_kx_keypair(cl_pk, cl_sk); - - /* Send our public key */ - if (send_public_key(sockfd, cl_pk) < 0) { - return -1; - } - - /* Get public key from server */ - uint8_t *pk; - if ((pk = get_public_key(sockfd)) == NULL) { - return -1; - } - - /* Compute a shared key using the server's public key and our secret key */ - if (crypto_kx_client_session_keys(NULL, shared_key, cl_pk, cl_sk, pk) != 0) { - error(1, "Server public key is not acceptable"); - free(pk); - close(sockfd); - return -1; - } - free(pk); - return 0; -} - -void *sender() -{ - while (1) { - printf("Enter message to send to server: "); - fflush(stdout); - char line[1024]; - line[0] = '\0'; - size_t length = strlen(line); - while (length <= 1) { - fgets(line, sizeof(line), stdin); - length = strlen(line); - } - length -= 1; - line[length] = '\0'; - - uint8_t nonce[NONCE_SIZE]; - uint8_t encrypted[length + ADDITIONAL_SIZE]; - unsigned long long encrypted_len; - - randombytes_buf(nonce, sizeof(nonce)); - crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, &encrypted_len, - line, length, - NULL, 0, NULL, nonce, shared_key); - size_t payload_t = NONCE_SIZE + encrypted_len; - uint8_t encryptedwithnonce[payload_t]; - memcpy(encryptedwithnonce, nonce, NONCE_SIZE); - memcpy(encryptedwithnonce + NONCE_SIZE, encrypted, encrypted_len); - - message *msg = create_packet(1, 0x10, payload_t, encryptedwithnonce); - if (send_packet(msg, sockfd) != ZSM_STA_SUCCESS) { - close(sockfd); - } - free_packet(msg); - } - close(sockfd); - -} - -void *receiver() -{ - while (1) { - message servermsg; - if (recv_packet(&servermsg, sockfd) != ZSM_STA_SUCCESS) { +/* free(server); Can't be freed seems */ + if (connect(sockfd, (struct sockaddr *) &server_addr, sizeof(server_addr) + ) < 0) { + if (errno != EINPROGRESS) { + /* Connection is in progress, shouldn't be treated as error */ + error(1, "Error on connect"); close(sockfd); return 0; } - free(servermsg.data); - } - return NULL; -} + } + write_log(LOG_INFO, "Connected to server at %s\n", DOMAIN); -int main() -{ - if (sodium_init() < 0) { - error(1, "Error initializing libsodium"); - } - sockfd = socket_init(); - if (key_exchange(sockfd) < 0) { +/* set_nonblocking(sockfd); */ + /* + key_pair *kpp = create_key_pair("palanix"); + key_pair *kpn = create_key_pair("night"); +*/ + key_pair *kpp = get_key_pair("palanix"); + key_pair *kpn = get_key_pair("night"); + + if (authenticate_server(kpp) != ZSM_STA_SUCCESS) { /* Fatal */ - error(1, "Error performing key exchange with server"); - } - pthread_t recv_worker, send_worker; - if (pthread_create(&recv_worker, NULL, sender, NULL) != 0) { - fprintf(stderr, "Error creating incoming thread\n"); - return 1; + error(1, "Error authenticating with server"); + } else { + write_log(LOG_INFO, "Authenticated with server\n"); + printf("Authenticated as palanix\n"); + } + + /* Create threads for sending and receiving messages */ + pthread_t send_thread, receive_thread; + + if (pthread_create(&send_thread, NULL, send_message, kpp) != 0) { + close(sockfd); + error(1, "Failed to create send thread"); + exit(EXIT_FAILURE); } - if (pthread_create(&send_worker, NULL, receiver, NULL) != 0) { - fprintf(stderr, "Error creating outgoing thread\n"); - return 1; + if (pthread_create(&receive_thread, NULL, receive_message, kpp) != 0) { + close(sockfd); + error(1, "Failed to create receive thread"); } - // Join threads - pthread_join(recv_worker, NULL); - pthread_join(send_worker, NULL); + /* Wait for threads to finish */ + pthread_join(send_thread, NULL); + pthread_join(receive_thread, NULL); + close(sockfd); return 0; } - diff --git a/src/client/db.c b/src/client/db.c new file mode 100644 index 0000000..1226208 --- /dev/null +++ b/src/client/db.c @@ -0,0 +1,65 @@ +#include +#include + +#include "config.h" +#include "packet.h" +#include "util.h" +#include "client/ui.h" +#include "client/db.h" +#include "client/user.h" + +static int callback(void *NotUsed, int argc, char **argv, char **azColName) +{ + char *username = memalloc(32 * sizeof(char)); + strcpy(username, argv[0]); + add_username(username); + + /* + for(int i = 0; i < argc; i++) { + printf("%s = %s\n", azColName[i], argv[i] ? argv[i] : "NULL"); + } + printf("\n"); + */ + return 0; +} + +int sqlite_init() +{ + sqlite3 *db; + char *err_msg = 0; + + int rc = sqlite3_open(DATABASE_NAME, &db); + + if (rc != SQLITE_OK) { + fprintf(stderr, "Cannot open database: %s\n", sqlite3_errmsg(db)); + sqlite3_close(db); + return 1; + } + + char *users_statement = "CREATE TABLE IF NOT EXISTS Users(Username TEXT, SecretKey TEXT, Test TEXT);"; + char *messages_statement = "CREATE TABLE IF NOT EXISTS Messages(Username TEXT, );"; + //"INSERT INTO Users VALUES('night', 'test', '1');"; + + rc = sqlite3_exec(db, users_statement, 0, 0, &err_msg); + + if (rc != SQLITE_OK) { + error(0, "SQL error: %s", err_msg); + sqlite3_free(err_msg); + } else { +/* error(0, "Table created successfully"); */ + } + + // Select and print all entries + const char* data = "Callback function called"; + rc = sqlite3_exec(db, "SELECT * FROM Users", callback, (void*) data, &err_msg); + + if (rc != SQLITE_OK ) { + error(0, "SQL error: %s\n", err_msg); + sqlite3_free(err_msg); + } + + sqlite3_close(db); + + return 0; +} + diff --git a/src/client/ui.c b/src/client/ui.c new file mode 100644 index 0000000..08c76dd --- /dev/null +++ b/src/client/ui.c @@ -0,0 +1,337 @@ +#include "config.h" +#include "packet.h" +#include "util.h" +#include "client/ui.h" +#include "client/db.h" +#include "client/user.h" + +typedef struct windows { + WINDOW *users_border; + WINDOW *users_content; + WINDOW *chat_border; + WINDOW *chat_content; +} windows; + +WINDOW *panel; +WINDOW *users_border; +WINDOW *chat_border; + +WINDOW *users_content; +WINDOW *chat_content; + +ArrayList *users; +ArrayList *marked; +long current_selection = 0; + +bool show_icons; + +void show_chat(); + +void ncurses_init() +{ + /* check if it is interactive shell */ + if (!isatty(STDIN_FILENO)) { + error(1, "No tty detected. zsm requires an interactive shell to run"); + } + + /* initialize screen, don't print special chars, + * make ctrl + c work, don't show cursor + * enable arrow keys */ + initscr(); + noecho(); + cbreak(); + curs_set(0); + keypad(stdscr, TRUE); + /* check terminal has colors */ + if (!has_colors()) { + endwin(); + error(1, "Color is not supported in your terminal"); + } else { + use_default_colors(); + start_color(); + } + /* colors */ + init_pair(1, COLOR_BLACK, -1); /* */ + init_pair(2, COLOR_RED, -1); /* */ + init_pair(3, COLOR_GREEN, -1); /* */ + init_pair(4, COLOR_YELLOW, -1); /* */ + init_pair(5, COLOR_BLUE, -1); /* */ + init_pair(6, COLOR_MAGENTA, -1); /* */ + init_pair(7, COLOR_CYAN, -1); /* */ + init_pair(8, COLOR_WHITE, -1); /* */ +} + +/* + * Draw windows + */ +void windows_init() +{ + int users_width = 32; + int chat_width = COLS - 32; + + /*------------------------------+ + |----border(0)--||--border(2)--|| + || || || + || content (1) || content (3) || + || (users) || (chat) || + || || || + |---------------||-------------|| + +==========panel (4)===========*/ + + /* lines, cols, y, x */ + panel = newwin(PANEL_HEIGHT, COLS, LINES - PANEL_HEIGHT, 0 ); + /* draw border around windows */ + users_border = newwin(LINES - PANEL_HEIGHT, users_width + 2, 0, 0 ); + chat_border = newwin(LINES - PANEL_HEIGHT, chat_width - 2, 0, users_width + 2); + + users_content = newwin(LINES - PANEL_HEIGHT - 2, users_width, 1, 1 ); + chat_content = newwin(LINES - PANEL_HEIGHT - 2, chat_width - 4, 1, users_width + 3); + + refresh(); + draw_border(users_border, true); + draw_border(chat_border, false); + + scrollok(users_content, true); + scrollok(chat_content, true); + refresh(); +} + +/* + * Draw the border of the window depending if it's active or not, + */ +void draw_border(WINDOW *window, bool active) +{ + int width; + if (window == users_border) { + width = 34; + } else { + width = COLS - 34; + } + + /* turn on color depends on active */ + if (active) { + wattron(window, COLOR_PAIR(3)); + } else { + wattron(window, COLOR_PAIR(5)); + } + /* draw top border */ + mvwaddch(window, 0, 0, ACS_ULCORNER); /* upper left corner */ + mvwhline(window, 0, 1, ACS_HLINE, COLS - 2); /* top horizontal line */ + mvwaddch(window, 0, width - 1, ACS_URCORNER); /* upper right corner */ + + /* draw side border */ + mvwvline(window, 1, 0, ACS_VLINE, LINES - 2); /* left vertical line */ + mvwvline(window, 1, width - 1, ACS_VLINE, LINES - 2); /* right vertical line */ + + /* draw bottom border + * make space for the panel */ + mvwaddch(window, LINES - PANEL_HEIGHT - 1, 0, ACS_LLCORNER); /* lower left corner */ + mvwhline(window, LINES - PANEL_HEIGHT - 1, 1, ACS_HLINE, width - 2); /* bottom horizontal line */ + mvwaddch(window, LINES - PANEL_HEIGHT - 1, width - 1, ACS_LRCORNER); /* lower right corner */ + + /* turn color off after turning it on */ + if (active) { + wattroff(window, COLOR_PAIR(3)); + } else { + wattroff(window, COLOR_PAIR(5)); + } + wrefresh(window); /* Refresh the window to see the colored border and title */ +} + +/* + * Print line to the panel + */ +void wpprintw(const char *fmt, ...) +{ + va_list args; + va_start(args, fmt); + wclear(panel); + vw_printw(panel, fmt, args); + va_end(args); + wrefresh(panel); +} + +/* + * Highlight current line by reversing the color + */ +void highlight_current_line() +{ + long overflow = 0; + if (current_selection > LINES - 4) { + /* overflown */ + overflow = current_selection - (LINES - 4); + } + + /* calculate range of files to show */ + long range = users->length; + /* not highlight if no files in directory */ + if (range == 0 && errno == 0) { + #if DRAW_PREVIEW + wprintw(chat_content, "No users. Start a converstation."); + wrefresh(chat_content); + #endif + return; + } + + if (range > LINES - 3) { + /* if there are more files than lines available to display + * shrink range to avaiable lines to display with + * overflow to keep the number of iterations to be constant */ + range = LINES - 3 + overflow; + } + + wclear(users_content); + long line_count = 0; + for (long i = overflow; i < range; i++) { + if ((overflow == 0 && i == current_selection) || (overflow != 0 && i == current_selection)) { + wattron(users_content, A_REVERSE); + + /* check for marked user */ + long num_marked = marked->length; + if (num_marked > 0) { + /* Determine length of formatted string */ + int m_len = snprintf(NULL, 0, "[%ld] selected", num_marked); + char *selected = memalloc((m_len + 1) * sizeof(char)); + + snprintf(selected, m_len + 1, "[%ld] selected", num_marked); + wpprintw("(%ld/%ld) %s", current_selection + 1, users->length, selected); + } else { + wpprintw("(%ld/%ld)", current_selection + 1, users->length); + } + } + /* print the actual filename and stats */ + char *line = get_line(users, i, show_icons); + int color = users->items[i].color; + /* check is user marked for action */ + bool is_marked = arraylist_search(marked, users->items[i].name) != -1; + if (is_marked) { + /* show user is selected */ + wattron(users_content, COLOR_PAIR(7)); + } else { + /* print the whole directory with default colors */ + wattron(users_content, COLOR_PAIR(color)); + } + + if (overflow > 0) + mvwprintw(users_content, line_count, 0, "%s", line); + else + mvwprintw(users_content, i, 0, "%s", line); + + if (is_marked) { + wattroff(users_content, COLOR_PAIR(7)); + } else { + wattroff(users_content, COLOR_PAIR(color)); + } + + wattroff(users_content, A_REVERSE); + //free(line); + line_count++; + } + + wrefresh(users_content); + wrefresh(panel); + /* show chat conversation every time cursor changes */ + #if DRAW_PREVIEW + show_chat(); + #endif + #if DRAW_BORDERS + draw_border_title(preview_border, true); + #endif + wrefresh(chat_content); +} + +/* + * Add message to chat window + * user_color is the color defined above at ncurses_init + */ + +void add_message(time_t rawtime, char *username, int user_color, char *content) +{ + struct tm *timeinfo = localtime(&rawtime); + char buffer[21]; + strftime(buffer, sizeof(buffer), "%b %d %Y %H:%M:%S", timeinfo); + + wprintw(chat_content, "%s ", buffer); + wattron(chat_content, A_BOLD); + wattron(chat_content, COLOR_PAIR(user_color)); + + wprintw(chat_content, "<%s> ", username); + + wattroff(chat_content, A_BOLD); + wattroff(chat_content, COLOR_PAIR(user_color)); + wprintw(chat_content, "%s", content); +} + +/* + * Get chat conversation into buffer and show it to chat window + */ +void show_chat() +{ + add_message(1725932011, "night", 1, "I go to school by bus.\n"); + add_message(1725933011, "night", 2, "I go to school by tram.\n"); + add_message(1725934011, "night", 3, "I go to school by train.\n"); + add_message(1725935011, "night", 4, "I go to school by car.\n"); + add_message(1725936011, "night", 5, "I go to school by run.\n"); + add_message(1725937011, "night", 6, "I go to school by bike.\n"); + add_message(1725938011, "night", 7, "I go to school by plane.\n"); +} +/* + * Require heap allocated username + */ +void add_username(char *username) +{ + wchar_t *icon_str = memalloc(2 * sizeof(wchar_t)); + wcsncpy(icon_str, L"", 2); + + arraylist_add(users, username, icon_str, 7, false, false); +} + +/* + * Main loop of user interface + */ +void ui() +{ + ncurses_init(); + windows_init(); + users = arraylist_init(LINES); + marked = arraylist_init(100); + show_icons = true; + sqlite_init(); + highlight_current_line(); + refresh(); + int ch; + while (1) { + if (COLS < 80 || LINES < 24) { + endwin(); + error(1, "Terminal size needs to be at least 80x24"); + } + ch = getch(); + switch (ch) { + case 'q': + goto cleanup; + + /* go up by k or up arrow */ + case UP: + case 'k': + if (current_selection > 0) + current_selection--; + + highlight_current_line(); + break; + + /* go down by j or down arrow */ + case DOWN: + case 'j': + if (current_selection < (users->length - 1)) + current_selection++; + + highlight_current_line(); + break; + } + } +cleanup: + arraylist_free(users); + arraylist_free(marked); + endwin(); + return; +} diff --git a/src/client/user.c b/src/client/user.c new file mode 100644 index 0000000..7d000f9 --- /dev/null +++ b/src/client/user.c @@ -0,0 +1,119 @@ +#include "packet.h" +#include "util.h" +#include "client/user.h" + +ArrayList *arraylist_init(size_t capacity) +{ + ArrayList *list = memalloc(sizeof(ArrayList)); + list->length = 0; + list->capacity = capacity; + list->items = memalloc(capacity * sizeof(user)); + + return list; +} + +void arraylist_free(ArrayList *list) +{ + for (size_t i = 0; i < list->length; i++) { + if (list->items[i].name != NULL) + free(list->items[i].name); + if (list->items[i].icon != NULL) + free(list->items[i].icon); + } + + free(list->items); + free(list); +} + +/* + * Check if the user is in the arraylist + */ +long arraylist_search(ArrayList *list, char *username) +{ + for (long i = 0; i < list->length; i++) { + if (strcmp(list->items[i].name, username) == 0) { + return i; + } + } + return -1; +} + +void arraylist_remove(ArrayList *list, long index) +{ + if (index >= list->length) + return; + + free(list->items[index].name); + free(list->items[index].icon); + + for (long i = index; i < list->length - 1; i++) + list->items[i] = list->items[i + 1]; + + list->length--; +} + +/* + * Force will not remove duplicate marked users, instead it just skip adding + */ +void arraylist_add(ArrayList *list, char *name, wchar_t *icon, int color, bool marked, bool force) +{ + user new_user = { name, icon, color }; + + if (list->capacity != list->length) { + if (marked) { + for (int i = 0; i < list->length; i++) { + if (strcmp(list->items[i].name, new_user.name) == 0) { + if (!force) + arraylist_remove(list, i); + return; + } + } + } + list->items[list->length] = new_user; + } else { + int new_cap = list->capacity * 2; + user *new_items = memalloc(new_cap * sizeof(user)); + user *old_items = list->items; + list->capacity = new_cap; + list->items = new_items; + + for (int i = 0; i < list->length; i++) + new_items[i] = old_items[i]; + + free(old_items); + list->items[list->length] = new_user; + } + list->length++; +} + +/* + * Construct a formatted line for display + */ +char *get_line(ArrayList *list, long index, bool icons) +{ + user seluser = list->items[index]; + + size_t name_len = strlen(seluser.name); + size_t length; + + if (icons) { + length = name_len + 10; /* 8 for icon, 1 for space and 1 for null */ + } else { + length = name_len; + } + + char *line = memalloc(length * sizeof(char)); + line[0] = '\0'; + + if (icons) { + char *tmp = memalloc(9 * sizeof(char)); + snprintf(tmp, 8, "%ls", seluser.icon); + + strcat(line, tmp); + strcat(line, " "); + free(tmp); + } + strcat(line, seluser.name); + + return line; +} diff --git a/src/server/ht.c b/src/server/ht.c new file mode 100644 index 0000000..c6634a4 --- /dev/null +++ b/src/server/ht.c @@ -0,0 +1,75 @@ +#include "ht.h" +#include "config.h" +#include "packet.h" + +client *hash_table[TABLE_SIZE]; + +/* Hashes every name with: name and TABLE_SIZE */ +unsigned int hash(char *name) +{ + int length = strnlen(name, MAX_NAME); + unsigned int hash_value = 0; + + for (int i = 0; i < length; i++) { + hash_value += name[i]; + hash_value = (hash_value * name[i]) % TABLE_SIZE; + } + + return hash_value; +} + +void hashtable_init() +{ + for (int i = 0; i < TABLE_SIZE; i++) + hash_table[i] = NULL; +} + +void hashtable_print() +{ + int i = 0; + + for (; i < TABLE_SIZE; i++) { + if (hash_table[i] == NULL) { + printf("%i. ---\n", i); + } else { + printf("%i. | Name %s\n", i, hash_table[i]->name); + } + } +} + +/* Gets hashed name and tries to store the client struct in that place */ +int hashtable_add(client *p) +{ + if (p == NULL) return 0; + + int index = hash(p->name); + int initial_index = index; + /* linear probing until an empty slot is found */ + while (hash_table[index] != NULL) { + index = (index + 1) % TABLE_SIZE; /* move to next item */ + /* the hash table is full as no available index back to initial index, cannot fit new item */ + if (index == initial_index) return 1; + } + + hash_table[index] = p; + return 0; +} + +/* Rehashes the name and then looks in this spot, if found returns client */ +client *hashtable_search(char *name) +{ + int index = hash(name); + int initial_index = index; + + /* Linear probing until an empty slot or the desired item is found */ + while (hash_table[index] != NULL) { + if (strncmp(hash_table[index]->name, name, MAX_NAME) == 0) + return hash_table[index]; + + index = (index + 1) % TABLE_SIZE; /* Move to the next slot */ + /* back to same item */ + if (index == initial_index) break; + } + + return NULL; +} diff --git a/src/server/server.c b/src/server/server.c index bafc2e0..1117bc2 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,79 +1,82 @@ #include "packet.h" +#include "util.h" +#include "config.h" #include "notification.h" -#include -socklen_t clilen; -struct sockaddr_in cli_address; -uint8_t shared_key[SHARED_KEY_SIZE]; -int clientfd; +typedef struct client_t { + int fd; /* File descriptor for client socket */ + uint8_t *shared_key; + char username[MAX_NAME]; /* Username of client */ +} client_t; + +typedef struct thread_t { + int epoll_fd; /* epoll instance for each thread */ + pthread_t thread; /* POSIX thread */ + int num_clients; /* Number of active clients in thread */ + client_t clients[MAX_CLIENTS_PER_THREAD]; /* Active clients */ +} thread_t; + +thread_t threads[MAX_THREADS]; +int num_thread = 0; + +void *thread_worker(void *arg); +int set_nonblocking(int fd); /* - * Initialise socket server + * Authenticate client before starting communication */ -int socket_init() +int authenticate_client(int clientfd, uint8_t *username) { - int serverfd = socket(AF_INET, SOCK_STREAM, 0); - if (serverfd < 0) { - error(1, "Error on opening socket"); + /* Create a challenge */ + uint8_t *challenge = memalloc(CHALLENGE_SIZE * sizeof(uint8_t)); + randombytes_buf(challenge, CHALLENGE_SIZE); + + /* Sending fake signature as structure requires it */ + uint8_t *fake_sig = create_signature(NULL, 0, NULL); + + packet *auth_pkt = create_packet(1, ZSM_TYP_AUTH, CHALLENGE_SIZE, + challenge, fake_sig); + if (send_packet(auth_pkt, clientfd) != ZSM_STA_SUCCESS) { + /* fd already closed */ + error(0, "Could not authenticate client"); + free(challenge); + free_packet(auth_pkt); + goto failure; + } + free(fake_sig); + + packet client_auth_pkt; + int status; + if ((status = recv_packet(&client_auth_pkt, clientfd, ZSM_TYP_AUTH) + != ZSM_STA_SUCCESS)) { + error(0, "Could not authenticate client"); + goto failure; } - /* Reuse addr(for debug) */ - int optval = 1; - if (setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) { + uint8_t pk_bin[PK_BIN_SIZE], pk_username[MAX_NAME]; + memcpy(pk_bin, client_auth_pkt.data, PK_BIN_SIZE); + memcpy(pk_username, client_auth_pkt.data + PK_BIN_SIZE, MAX_NAME); - error(1, "Error at setting SO_REUSEADDR"); + if (crypto_sign_verify_detached(client_auth_pkt.signature, challenge, CHALLENGE_SIZE, pk_bin) != 0) { + error(0, "Incorrect signature, could not authenticate client"); + free(client_auth_pkt.data); + goto failure; + } else { + packet *ok_pkt = create_packet(ZSM_STA_AUTHORISED, ZSM_TYP_INFO + , 0, NULL, NULL); + send_packet(ok_pkt, clientfd); + free_packet(ok_pkt); + strcpy(username, pk_username); + return ZSM_STA_SUCCESS; } +failure:; + packet *error_pkt = create_packet(ZSM_STA_UNAUTHORISED, ZSM_TYP_ERROR, + 0, NULL, create_signature(NULL, 0, NULL)); - struct sockaddr_in sv_addr; - memset(&sv_addr, 0, sizeof(sv_addr)); - sv_addr.sin_family = AF_INET; - sv_addr.sin_addr.s_addr = INADDR_ANY; - sv_addr.sin_port = htons(PORT); - - if (bind(serverfd, (struct sockaddr *) &sv_addr, sizeof(sv_addr)) < 0) { - error(1, "Error on bind"); - } - - if (listen(serverfd, MAX_CONNECTION) < 0) { - error(1, "Error on listen"); - } - - printf("Listening on port %d\n", PORT); - clilen = sizeof(cli_address); - return serverfd; -} - -/* - * Performs key exchange with client - */ -int key_exchange(int clientfd) -{ - /* Generate the server's key pair */ - uint8_t sv_pk[PUBLIC_KEY_SIZE], sv_sk[PRIVATE_KEY_SIZE]; - crypto_kx_keypair(sv_pk, sv_sk); - - /* Get public key from client */ - uint8_t *pk; - if ((pk = get_public_key(clientfd)) == NULL) { - return -1; - } - - /* Send our public key */ - if (send_public_key(clientfd, sv_pk) < 0) { - free(pk); - return -1; - } - - /* Compute a shared key using the client's public key and our secret key. */ - if (crypto_kx_server_session_keys(NULL, shared_key, sv_pk, sv_sk, pk) != 0) { - error(0, "Client public key is not acceptable"); - free(pk); - close(clientfd); - return -1; - } - - free(pk); - return 0; + send_packet(error_pkt, clientfd); + free_packet(error_pkt); + close(clientfd); + return ZSM_STA_ERROR_AUTHENTICATE; } void signal_handler(int signal) @@ -91,98 +94,48 @@ void signal_handler(int signal) } } -void *receiver() +/* + * Takes thread_t as argument to use its epoll instance to wait new messages + * Thread worker to relay messages + */ +void *thread_worker(void *arg) { - int serverfd = socket_init(); - clientfd = accept(serverfd, (struct sockaddr *) &cli_address, &clilen); - if (clientfd < 0) { - error(0, "Error on accepting client"); - /* Continue accpeting connections */ -/* continue; */ - } + thread_t *thread = (thread_t *) arg; + struct epoll_event events[MAX_EVENTS]; + + while (1) { + int num_events = epoll_wait(thread->epoll_fd, events, MAX_EVENTS, -1); + if (num_events == -1) { + error(0, "epoll_wait"); + } + for (int i = 0; i < num_events; i++) { + client_t *client = (client_t *) events[i].data.ptr; - if (key_exchange(clientfd) < 0) { - error(0, "Error performing key exchange with client"); -/* continue; */ - } - while (1) { - message msg; - memset(&msg, 0, sizeof(msg)); - if (recv_packet(&msg, clientfd) != ZSM_STA_SUCCESS) { - close(clientfd); - break; -/* continue; */ + if (events[i].events & EPOLLIN) { + /* handle message */ + packet pkt; + packet *verified_pkt = verify_packet(&pkt, client->fd); + if (verified_pkt == NULL) { + error(0, "Error verifying packet"); + } + + /* Message relay */ + uint8_t to[MAX_NAME]; + memcpy(to, verified_pkt->data + MAX_NAME, MAX_NAME); + + for (int i = 0; i < MAX_THREADS; i++) { + thread_t thread = threads[i]; + for (int j = 0; j < thread.num_clients; j++) { + client_t client = thread.clients[j]; + if (strcmp(client.username, to) == 0) { + error(0, "Relaying message to %s\n", client.username); + send_packet(verified_pkt, client.fd); + } + } + } + } } - - size_t encrypted_len = msg.length - NONCE_SIZE; - size_t msg_len = encrypted_len - ADDITIONAL_SIZE; - uint8_t nonce[NONCE_SIZE]; - uint8_t encrypted[encrypted_len]; - uint8_t decrypted[msg_len + 1]; - unsigned long long decrypted_len; - memcpy(nonce, msg.data, NONCE_SIZE); - memcpy(encrypted, msg.data + NONCE_SIZE, encrypted_len); - - free(msg.data); - if (crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted, &decrypted_len, - NULL, - encrypted, encrypted_len, - NULL, 0, - nonce, shared_key) != 0) { - error(0, "Cannot decrypt message"); - } else { - /* Decrypted message */ - decrypted[msg_len] = '\0'; - printf("Decrypted: %s\n", decrypted); - send_notification(decrypted); - msg.data = malloc(14); - strcpy(msg.data, "Received data"); - msg.length = 14; - send_packet(&msg, clientfd); - free(msg.data); - } - } - close(clientfd); - close(serverfd); - return NULL; -} - -void *sender() -{ - while (1) { - printf("Enter message to send to client: "); - fflush(stdout); - char line[1024]; - line[0] = '\0'; - size_t length = strlen(line); - while (length <= 1) { - fgets(line, sizeof(line), stdin); - length = strlen(line); - } - length -= 1; - line[length] = '\0'; - - uint8_t nonce[NONCE_SIZE]; - uint8_t encrypted[length + ADDITIONAL_SIZE]; - unsigned long long encrypted_len; - - randombytes_buf(nonce, sizeof(nonce)); - crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, &encrypted_len, - line, length, - NULL, 0, NULL, nonce, shared_key); - size_t payload_t = NONCE_SIZE + encrypted_len; - uint8_t encryptedwithnonce[payload_t]; - memcpy(encryptedwithnonce, nonce, NONCE_SIZE); - memcpy(encryptedwithnonce + NONCE_SIZE, encrypted, encrypted_len); - - message *msg = create_packet(1, 0x10, payload_t, encryptedwithnonce); - if (send_packet(msg, clientfd) != ZSM_STA_SUCCESS) { - close(clientfd); - } - free_packet(msg); - } - close(clientfd); - return NULL; + } } int main() @@ -198,21 +151,125 @@ int main() signal(SIGABRT, signal_handler); signal(SIGINT, signal_handler); signal(SIGTERM, signal_handler); - - pthread_t recv_worker, send_worker; - if (pthread_create(&recv_worker, NULL, sender, NULL) != 0) { - fprintf(stderr, "Error creating incoming thread\n"); - return 1; + + /* Start server and epoll */ + int serverfd, clientfd; + + /* Create socket */ + serverfd = socket(AF_INET, SOCK_STREAM, 0); + if (serverfd < 0) { + error(1, "Error on opening socket"); } - if (pthread_create(&send_worker, NULL, receiver, NULL) != 0) { - fprintf(stderr, "Error creating outgoing thread\n"); - return 1; + /* Reuse address (for debug) */ + int opt = 1; + if (setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + error(1, "Error at setting SO_REUSEADDR"); } - // Join threads - pthread_join(recv_worker, NULL); - pthread_join(send_worker, NULL); + struct sockaddr_in server_addr, client_addr; + socklen_t client_addr_len = sizeof(client_addr); + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(PORT); + if (bind(serverfd, (struct sockaddr *) &server_addr + , sizeof(server_addr)) < 0) { + close(serverfd); + error(1, "Error on bind"); + } + + if (listen(serverfd, MAX_CONNECTION_QUEUE) < 0) { + close(serverfd); + error(1, "Error on listen"); + } + + + /* Creating thread pool */ + for (int i = 0; i < MAX_THREADS; i++) { + /* Create epoll instance for each thread */ + threads[i].epoll_fd = epoll_create1(0); + if (threads[i].epoll_fd < 0) { + error(1, "Error on creating epoll instance"); + } + threads[i].num_clients = 0; + + /* Start a new thread and pass thread_t struct to thread */ + if (pthread_create(&threads[i].thread, NULL, thread_worker, + &threads[i]) != 0) { + error(1, "Error on creating threads"); + } else { + error(0, "Thread %d created", i); + } + } + + error(0, "Listening on port %d", PORT); + + /* Server loop to accept clients and load balance */ + while (1) { + clientfd = accept(serverfd, (struct sockaddr *) &client_addr, + &client_addr_len); + if (clientfd < 0) { + error(0, "Error on accepting client"); + continue; + } + + /* Assign new client to a thread + * Clients distributed by a rotation(round-robin) + */ + thread_t *thread = &threads[num_thread]; + if (thread->num_clients >= MAX_CLIENTS_PER_THREAD) { + error(0, "Thread %d is already full, rejecting connection\n", + num_thread); + close(clientfd); + continue; + } + + client_t *client = &thread->clients[thread->num_clients]; + + uint8_t username[MAX_NAME]; + /* User logins, authenticate them */ + if (authenticate_client(clientfd, username) != ZSM_STA_SUCCESS) { + error(0, "Error authenticating with client"); + continue; + } + + /* To use EPOLLET, an nonblocking fd is required */ + /* + if (set_nonblocking(clientfd) == -1) { + perror("Failed to set client socket to non-blocking"); + close(clientfd); + continue; + } + */ + + /* Add the new client to the thread's epoll instance */ + struct epoll_event event; + event.data.ptr = client; + event.events = EPOLLIN; + + if (epoll_ctl(thread->epoll_fd, EPOLL_CTL_ADD, clientfd, &event) == -1) { + perror("Failed to add client to epoll"); + close(clientfd); + continue; + } + + /* Assign fd to client in a thread */ + client->fd = clientfd; + strcpy(client->username, username); + thread->num_clients++; + + /* Rotate num_thread back to start if it is larder than MAX_THREADS */ + num_thread = (num_thread + 1) % MAX_THREADS; + } + + /* End the thread */ + for (int i = 0; i < MAX_THREADS; i++) { + if (pthread_join(threads[i].thread, NULL) != 0) { + error(0, "pthread_join"); + } + } + close(serverfd); return 0; }