Slightly working prototype, able to relay messages

This commit is contained in:
Night Kaly 2024-09-16 13:11:01 +01:00
parent 2897e18d47
commit 82ee7785f7
Signed by: night0721
GPG key ID: 957D67B8DB7A119B
21 changed files with 1818 additions and 548 deletions

3
.gitignore vendored
View file

@ -1,4 +1,3 @@
zsm
zsmc
bin
*.o
*.tar.gz

View file

@ -3,31 +3,36 @@
CC = cc
VERSION = 1.0
SERVER = zsm
CLIENT = zsmc
SERVER = zmr
CLIENT = zen
TARGET = zsm
MANPAGE = $(TARGET).1
PREFIX ?= /usr/local
BINDIR = $(PREFIX)/bin
MANDIR = $(PREFIX)/share/man/man1
# Flags
LDFLAGS = $(shell pkg-config --libs libsodium libnotify ncurses)
CFLAGS = -O3 -mtune=native -march=native -pipe -g -std=c99 -Wno-pointer-sign -Wpedantic -Wall -D_DEFAULT_SOURCE -D_XOPEN_SOURCE=600 $(shell pkg-config --cflags libsodium libnotify ncurses) -lpthread
LDFLAGS = $(shell pkg-config --libs libsodium libnotify ncurses sqlite3)
CFLAGS = -O3 -mtune=native -march=native -pipe -g -std=c99 -Wno-pointer-sign -pedantic -Wall -D_DEFAULT_SOURCE -D_XOPEN_SOURCE=600 $(shell pkg-config --cflags libsodium libnotify ncurses sqlite3) -lpthread
SERVERSRC = src/server/*.c
CLIENTSRC = src/client/*.c
LIBSRC = lib/*.c
INCLUDE = -Iinclude/
all: $(SERVER) $(CLIENT)
$(SERVER): $(SERVERSRC) $(LIBSRC)
$(CC) $(SERVERSRC) $(LIBSRC) $(INCLUDE) -o $@ $(CFLAGS) $(LDFLAGS)
mkdir -p bin
$(CC) $(SERVERSRC) $(LIBSRC) $(INCLUDE) -o bin/$@ $(CFLAGS) $(LDFLAGS)
$(CLIENT): $(CLIENTSRC) $(LIBSRC)
$(CC) $(CLIENTSRC) $(LIBSRC) $(INCLUDE) -o $@ $(CFLAGS) $(LDFLAGS)
mkdir -p bin
$(CC) $(CLIENTSRC) $(LIBSRC) $(INCLUDE) -o bin/$@ $(CFLAGS) $(LDFLAGS)
dist:
mkdir -p $(TARGET)-$(VERSION)
cp -R README.md $(MANPAGE) $(TARGET) $(TARGET)-$(VERSION)
cp -R README.md $(MANPAGE) $(SERVER) $(CLIENT) $(TARGET)-$(VERSION)
tar -cf $(TARGET)-$(VERSION).tar $(TARGET)-$(VERSION)
gzip $(TARGET)-$(VERSION).tar
rm -rf $(TARGET)-$(VERSION)
@ -35,18 +40,22 @@ dist:
install: $(TARGET)
mkdir -p $(DESTDIR)$(BINDIR)
mkdir -p $(DESTDIR)$(MANDIR)
cp -p $(TARGET) $(DESTDIR)$(BINDIR)/$(TARGET)
chmod 755 $(DESTDIR)$(BINDIR)/$(TARGET)
cp -p bin/$(TARGET) $(DESTDIR)$(BINDIR)/$(SERVER)
chmod 755 $(DESTDIR)$(BINDIR)/$(SERVER)
cp -p bin/$(TARGET) $(DESTDIR)$(BINDIR)/$(CLIENT)
chmod 755 $(DESTDIR)$(BINDIR)/$(CLIENT)
cp -p $(MANPAGE) $(DESTDIR)$(MANDIR)/$(MANPAGE)
chmod 644 $(DESTDIR)$(MANDIR)/$(MANPAGE)
uninstall:
$(RM) $(DESTDIR)$(BINDIR)/$(TARGET)
$(RM) $(DESTDIR)$(BINDIR)/$(SERVER)
$(RM) $(DESTDIR)$(BINDIR)/$(CLIENT)
$(RM) $(DESTDIR)$(MANDIR)/$(MANPAGE)
clean:
$(RM) $(TARGET)
$(RM) $(SERVER) $(CLIENT)
all: $(TARGET)
run:
./bin/zen
.PHONY: all dist install uninstall clean

8
include/client/db.h Normal file
View file

@ -0,0 +1,8 @@
#ifndef DB_H_
#define DB_H_
#include <sqlite3.h>
int sqlite_init();
#endif

12
include/client/ui.h Normal file
View file

@ -0,0 +1,12 @@
#ifndef UI_H_
#define UI_H_
#include <ncurses.h>
void ncurses_init();
void windows_init();
void draw_border(WINDOW *window, bool active);
void add_username(char *username);
void ui();
#endif

27
include/client/user.h Normal file
View file

@ -0,0 +1,27 @@
#ifndef USER_H_
#define USER_H_
#include <stdio.h>
#include <stdbool.h>
#include <stddef.h>
typedef struct user {
char *name;
wchar_t *icon;
int color;
} user;
typedef struct ArrayList {
size_t length;
size_t capacity;
user *items;
} ArrayList;
ArrayList *arraylist_init(size_t capacity);
void arraylist_free(ArrayList *list);
long arraylist_search(ArrayList *list, char *username);
void arraylist_remove(ArrayList *list, long index);
void arraylist_add(ArrayList *list, char *name, wchar_t *icon, int color, bool marked, bool force);
char *get_line(ArrayList *list, long index, bool icons);
#endif

25
include/config.h Normal file
View file

@ -0,0 +1,25 @@
/* Server */
#define DEBUG 0
#define PORT 20247
#define MAX_NAME 32 /* Max username length */
#define DATABASE_NAME "test.db"
#define MAX_MESSAGE_LENGTH 8192
/* Don't touch unless you know what you are doing */
#define MAX_CONNECTION_QUEUE 128
#define MAX_THREADS 8
#define MAX_EVENTS 64 /* Max events can be returned simulataneouly by epoll */
#define MAX_CLIENTS_PER_THREAD 1024
/* Client */
#define DOMAIN "127.0.0.1"
/* UI */
#define PANEL_HEIGHT 1
#define DRAW_PREVIEW 1
#define CLIENT_DATA_DIR "~/.local/share/zsm/zen"
/* Keybindings */
#define DOWN 0x102
#define UP 0x103

18
include/ht.h Normal file
View file

@ -0,0 +1,18 @@
#ifndef HT_H_
#define HT_H_
#define TABLE_SIZE 100
typedef struct client {
int id;
char name[32];
//pthread_t thread;
} client;
unsigned int hash(char *name);
void hashtable_init();
void hashtable_print();
int hashtable_add(client *p);
client *hashtable_search(char *name);
#endif

39
include/key.h Normal file
View file

@ -0,0 +1,39 @@
#ifndef KEY_H_
#define KEY_H_
#include <sodium.h>
#include "config.h"
#define TIME_SIZE sizeof(time_t)
#define SIGN_SIZE crypto_sign_BYTES
#define PK_BIN_SIZE crypto_kx_PUBLICKEYBYTES
#define SK_BIN_SIZE crypto_sign_SECRETKEYBYTES
#define METADATA_SIZE MAX_NAME + TIME_SIZE
#define PK_SIZE PK_BIN_SIZE + METADATA_SIZE + SIGN_SIZE
#define SK_SIZE SK_BIN_SIZE + METADATA_SIZE + SIGN_SIZE
#define SHARED_SIZE crypto_kx_SESSIONKEYBYTES
typedef struct public_key {
uint8_t bin[PK_BIN_SIZE];
uint8_t username[MAX_NAME];
time_t creation;
uint8_t signature[SIGN_SIZE];
} public_key;
typedef struct secret_key {
uint8_t bin[SK_BIN_SIZE];
uint8_t username[MAX_NAME];
time_t creation;
uint8_t signature[SIGN_SIZE];
} secret_key;
typedef struct key_pair {
public_key pk;
secret_key sk;
} key_pair;
key_pair *create_key_pair(char *username);
key_pair *get_key_pair(char *username);
#endif

View file

@ -1,8 +1,9 @@
#ifndef NOTIFICATION_H
#define NOTIFICATION_H
#include <stdint.h>
#include <libnotify/notify.h>
void send_notification(const char *content);
void send_notification(uint8_t *content);
#endif

View file

@ -6,29 +6,34 @@
#include <stdlib.h>
#include <stdint.h>
#include <stdarg.h>
#include <stdbool.h>
#include <unistd.h>
#include <errno.h>
#include <signal.h>
#include <time.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/epoll.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netdb.h>
#include <fcntl.h>
#include <pthread.h>
#include <sodium.h>
#include <libgen.h>
#include <wchar.h>
#include "config.h"
#define DEBUG 1
#define DOMAIN "127.0.0.1"
#define PORT 20247
#define MAX_CONNECTION 5
#define MAX_MESSAGE_LENGTH 8192
#define ERROR_LENGTH 26
#define ZSM_TYP_KEY 0x1
#define ZSM_TYP_SEND_MESSAGE 0x2
#define ZSM_TYP_AUTH 0x1
#define ZSM_TYP_MESSAGE 0x2
#define ZSM_TYP_UPDATE_MESSAGE 0x3
#define ZSM_TYP_DELETE_MESSAGE 0x4
#define ZSM_TYP_PRESENCE 0x5
#define ZSM_TYP_TYPING 0x6
#define ZSM_TYP_ERROR 0x7 /* Error message */
#define ZSM_TYP_B 0x8
#define ZSM_TYP_ERROR 0x5
#define ZSM_TYP_INFO 0x6
#define ZSM_STA_SUCCESS 0x1
#define ZSM_STA_INVALID_TYPE 0x2
@ -38,32 +43,40 @@
#define ZSM_STA_WRITING_SOCKET 0x6
#define ZSM_STA_UNKNOWN_USER 0x7
#define ZSM_STA_MEMORY_ALLOCATION 0x8
#define ZSM_STA_WRONG_KEY_LENGTH 0x9
#define ZSM_STA_ERROR_ENCRYPT 0x9
#define ZSM_STA_ERROR_DECRYPT 0xA
#define ZSM_STA_ERROR_AUTHENTICATE 0xB
#define ZSM_STA_ERROR_INTEGRITY 0xC
#define ZSM_STA_UNAUTHORISED 0xD
#define ZSM_STA_AUTHORISED 0xE
#define PUBLIC_KEY_SIZE crypto_kx_PUBLICKEYBYTES
#define PRIVATE_KEY_SIZE crypto_kx_SECRETKEYBYTES
#define SHARED_KEY_SIZE crypto_kx_SESSIONKEYBYTES
#define ADDRESS_SIZE MAX_NAME + 1 + 255 /* 1 for @, 255 for domain, defined in RFC 5321, Section 4.5.3.1.2 */
#define CHALLENGE_SIZE 32
#define HASH_SIZE crypto_generichash_BYTES
#define NONCE_SIZE crypto_aead_xchacha20poly1305_ietf_NPUBBYTES
#define ADDITIONAL_SIZE crypto_aead_xchacha20poly1305_ietf_ABYTES
typedef struct message {
uint8_t option;
typedef struct packet {
uint8_t status;
uint8_t type;
unsigned long long length;
unsigned char *data;
} message;
uint32_t length;
uint8_t *data;
uint8_t *signature;
} packet;
#include "key.h"
/* Utilities functions */
void error(int fatal, const char *fmt, ...);
void *memalloc(size_t size);
void *estrdup(void *str);
unsigned char *get_public_key(int sockfd);
int send_public_key(int sockfd, unsigned char *pk);
void print_packet(message *msg);
int recv_packet(message *msg, int fd);
message *create_error_packet(int code);
message *create_packet(uint8_t option, uint8_t type, uint32_t length, char *data);
int send_packet(message *msg, int fd);
void free_packet(message *msg);
void print_packet(packet *msg);
int recv_packet(packet *pkt, int fd, uint8_t required_type);
packet *create_packet(uint8_t option, uint8_t type, uint32_t length, uint8_t *data, uint8_t *signature);
int send_packet(packet *msg, int fd);
void free_packet(packet *msg);
int encrypt_packet(int sockfd, key_pair *kp);
packet *verify_packet(packet *pkt, int fd);
uint8_t *encrypt_data(uint8_t *from, uint8_t *to, uint8_t *raw, uint32_t raw_length, uint32_t *length);
uint8_t *decrypt_data(packet *pkt);
int verify_integrity(packet *pkt, public_key *pk);
uint8_t *create_signature(uint8_t *data, uint32_t length, secret_key *sk);
#endif

18
include/util.h Normal file
View file

@ -0,0 +1,18 @@
#ifndef UTIL_H
#define UTIL_H
#define LOG_ERROR 1
#define LOG_INFO 2
#define PATH_MAX 4096
void error(int fatal, const char *fmt, ...);
void *memalloc(size_t size);
void *estrdup(void *str);
int set_nonblocking(int fd);
char *replace_home(char *str);
void mkdir_p(const char *destdir);
void write_log(int type, const char *fmt, ...);
void print_bin(const unsigned char *ptr, size_t length);
#endif

97
lib/key.c Normal file
View file

@ -0,0 +1,97 @@
#include "packet.h"
#include "key.h"
#include "util.h"
key_pair *create_key_pair(char *username)
{
uint8_t cl_pk_bin[PK_BIN_SIZE], cl_sk_bin[SK_BIN_SIZE];
crypto_sign_keypair(cl_pk_bin, cl_sk_bin);
char pk_path[PATH_MAX], sk_path[PATH_MAX];
/* USE DB INSTEAD OF FILES */
sprintf(pk_path, "/home/night/%s_pk", username);
sprintf(sk_path, "/home/night/%s_sk", username);
FILE *pkf = fopen(pk_path, "w+");
FILE *skf = fopen(sk_path, "w+");
uint8_t pk_content[PK_SIZE], sk_content[SK_SIZE], metadata[METADATA_SIZE];
time_t current_time = time(NULL);
uint8_t *username_padded = memalloc(MAX_NAME * sizeof(uint8_t));
strcpy(username_padded, username);
size_t length = strlen(username);
if (length < MAX_NAME) {
/* Pad with null characters up to max length */
memset(username_padded + length, 0, MAX_NAME - length);
}
memcpy(metadata, username_padded, MAX_NAME);
memcpy(metadata + MAX_NAME, &current_time, TIME_SIZE);
uint8_t *hash = memalloc(HASH_SIZE * sizeof(uint8_t));
uint8_t *sign = memalloc(SIGN_SIZE * sizeof(uint8_t));
crypto_generichash(hash, HASH_SIZE, metadata, METADATA_SIZE, NULL, 0);
crypto_sign_detached(sign, NULL, hash, HASH_SIZE, cl_sk_bin);
memcpy(pk_content, cl_pk_bin, PK_BIN_SIZE);
memcpy(pk_content + PK_BIN_SIZE, metadata, METADATA_SIZE);
memcpy(pk_content + PK_BIN_SIZE + METADATA_SIZE, sign, SIGN_SIZE);
memcpy(sk_content, cl_sk_bin, SK_BIN_SIZE);
memcpy(sk_content + SK_BIN_SIZE, metadata, METADATA_SIZE);
memcpy(sk_content + SK_BIN_SIZE + METADATA_SIZE, sign, SIGN_SIZE);
free(hash);
fwrite(pk_content, 1, PK_SIZE, pkf);
fwrite(sk_content, 1, SK_SIZE, skf);
fclose(pkf);
fclose(skf);
key_pair *kp = memalloc(sizeof(key_pair));
memcpy(kp->pk.bin, cl_pk_bin, PK_BIN_SIZE);
memcpy(kp->pk.username, username_padded, MAX_NAME);
kp->pk.creation = current_time;
memcpy(kp->pk.signature, sign, SIGN_SIZE);
memcpy(kp->sk.bin, cl_sk_bin, SK_BIN_SIZE);
memcpy(kp->sk.username, username_padded, MAX_NAME);
kp->sk.creation = current_time;
memcpy(kp->sk.signature, sign, SIGN_SIZE);
free(username_padded);
free(sign);
return kp;
}
key_pair *get_key_pair(char *username)
{
char pk_path[PATH_MAX], sk_path[PATH_MAX];
sprintf(pk_path, "/home/night/%s_pk", username);
sprintf(sk_path, "/home/night/%s_sk", username);
FILE *pkf = fopen(pk_path, "r");
FILE *skf = fopen(sk_path, "r");
if (!pkf || !skf) {
printf("Error opening key files.\n");
return NULL;
}
uint8_t pk_content[PK_SIZE], sk_content[SK_SIZE];
fread(pk_content, 1, PK_SIZE, pkf);
fread(sk_content, 1, SK_SIZE, skf);
fclose(pkf);
fclose(skf);
key_pair *kp = memalloc(sizeof(key_pair));
memcpy(kp->pk.bin, pk_content, PK_BIN_SIZE);
memcpy(kp->pk.username, pk_content + PK_BIN_SIZE, MAX_NAME);
memcpy(&kp->pk.creation, pk_content + PK_BIN_SIZE + MAX_NAME, TIME_SIZE);
memcpy(kp->pk.signature, pk_content + PK_BIN_SIZE + MAX_NAME + TIME_SIZE, SIGN_SIZE);
memcpy(kp->sk.bin, sk_content, SK_BIN_SIZE);
memcpy(kp->sk.username, sk_content + SK_BIN_SIZE, MAX_NAME);
memcpy(&kp->sk.creation, sk_content + SK_BIN_SIZE + MAX_NAME, TIME_SIZE);
memcpy(kp->sk.signature, sk_content + SK_BIN_SIZE + MAX_NAME + TIME_SIZE, SIGN_SIZE);
return kp;
}

View file

@ -1,8 +1,16 @@
#include "notification.h"
#include "util.h"
void send_notification(const char *content)
void send_notification(uint8_t *content)
{
NotifyNotification *noti = notify_notification_new("Client", content, "dialog-information");
notify_notification_show(noti, NULL);
g_object_unref(G_OBJECT(noti));
printf("Content: %s\n", content);
NotifyNotification *notification = notify_notification_new("Client",
(char *) content, "dialog-information");
if (notification == NULL) {
error(0, "Cannot create notification");
}
if (!notify_notification_show(notification, NULL)) {
error(0, "Cannot show notifcation");
}
g_object_unref(G_OBJECT(notification));
}

View file

@ -1,273 +1,436 @@
#include "packet.h"
#include "key.h"
#include "util.h"
/*
* msg is the error message to print to stderr
* will include error message from function if errno isn't 0
* end program is fatal is 1
*/
void error(int fatal, const char *fmt, ...)
void print_packet(packet *pkt)
{
va_list args;
va_start(args, fmt);
/* to preserve errno */
int errsv = errno;
/* Determine the length of the formatted error message */
va_list args_copy;
va_copy(args_copy, args);
size_t error_len = vsnprintf(NULL, 0, fmt, args_copy);
va_end(args_copy);
/* 7 for [zsm], space and null */
char errorstr[error_len + 1];
vsnprintf(errorstr, error_len + 1, fmt, args);
fprintf(stderr, "[zsm] ");
if (errsv != 0) {
perror(errorstr);
errno = 0;
} else {
fprintf(stderr, "%s\n", errorstr);
}
va_end(args);
if (fatal) exit(1);
}
void *memalloc(size_t size)
{
void *ptr = malloc(size);
if (!ptr) {
error(0, "Error allocating memory");
return NULL;
}
return ptr;
}
void *estrdup(void *str)
{
void *modstr = strdup(str);
if (modstr == NULL) {
error(0, "Error allocating memory");
return NULL;
}
return modstr;
}
uint8_t *get_public_key(int sockfd)
{
message keyex_msg;
if (recv_packet(&keyex_msg, sockfd) != ZSM_STA_SUCCESS) {
/* We can't do anything if key exchange already failed */
close(sockfd);
return NULL;
} else {
int status = 0;
/* Check to see if the content is actually a key */
if (keyex_msg.type != ZSM_TYP_KEY) {
status = ZSM_STA_INVALID_TYPE;
}
if (keyex_msg.length != PUBLIC_KEY_SIZE) {
status = ZSM_STA_WRONG_KEY_LENGTH;
}
if (status != 0) {
free(keyex_msg.data);
message *error_msg = create_error_packet(status);
send_packet(error_msg, sockfd);
free_packet(error_msg);
close(sockfd);
return NULL;
}
}
/* Obtain public key from packet */
uint8_t *pk = memalloc(PUBLIC_KEY_SIZE * sizeof(char));
memcpy(pk, keyex_msg.data, PUBLIC_KEY_SIZE);
if (pk == NULL) {
free(keyex_msg.data);
/* Fatal, we couldn't complete key exchange */
close(sockfd);
return NULL;
}
free(keyex_msg.data);
return pk;
}
int send_public_key(int sockfd, uint8_t *pk)
{
/* send_packet requires heap allocated buffer */
uint8_t *pk_dup = memalloc(PUBLIC_KEY_SIZE * sizeof(char));
memcpy(pk_dup, pk, PUBLIC_KEY_SIZE);
if (pk_dup == NULL) {
close(sockfd);
return -1;
}
/* Sending our public key to client */
/* option???? */
message *keyex = create_packet(1, ZSM_TYP_KEY, PUBLIC_KEY_SIZE, pk_dup);
send_packet(keyex, sockfd);
free_packet(keyex);
return 0;
}
void print_packet(message *msg)
{
printf("Option: %d\n", msg->option);
printf("Type: %d\n", msg->type);
printf("Length: %lld\n", msg->length);
printf("Data: %s\n\n", msg->data);
printf("Status: %d\n", pkt->status);
printf("Type: %d\n", pkt->type);
printf("Length: %d\n", pkt->length);
if (pkt->length > 0) {
printf("Data:\n");
for (int i = 0; i < pkt->length; i++) {
printf("%d ", pkt->data[i]);
}
printf("\n");
printf("Signature:\n");
for (int i = 0; i < SIGN_SIZE; i++) {
printf("%d ", pkt->signature[i]);
}
printf("\n");
}
}
/*
* Requires manually free message data
* pkt: packet to fill data in (must be created via create_packet)
* fd: file descriptor to read data from
* required_type: Required packet type to receive, set 0 to not check
*/
int recv_packet(message *msg, int fd)
int recv_packet(packet *pkt, int fd, uint8_t required_type)
{
int status = ZSM_STA_SUCCESS;
/* Read the message components */
if (recv(fd, &msg->option, sizeof(msg->option), 0) < 0 ||
recv(fd, &msg->type, sizeof(msg->type), 0) < 0 ||
recv(fd, &msg->length, sizeof(msg->length), 0) < 0) {
if (recv(fd, &pkt->status, sizeof(pkt->status), 0) < 0 ||
recv(fd, &pkt->type, sizeof(pkt->type), 0) < 0 ||
recv(fd, &pkt->length, sizeof(pkt->length), 0) < 0) {
status = ZSM_STA_READING_SOCKET;
error(0, "Error reading from socket");
}
#if DEBUG == 1
printf("==========PACKET RECEIVED==========\n");
#endif
#if DEBUG == 1
printf("Option: %d\n", msg->option);
printf("==========PACKET RECEIVED==========\n");
printf("Status: %d\n", pkt->status);
#endif
if (msg->type > 0xFF || msg->type < 0x0) {
if (pkt->type > 0xFF || pkt->type < 0x0) {
status = ZSM_STA_INVALID_TYPE;
error(0, "Invalid message type");
goto failure;
}
#if DEBUG == 1
printf("Type: %d\n", msg->type);
printf("Type: %d\n", pkt->type);
#endif
/* Convert message length from network byte order to host byte order */
if (msg->length > MAX_MESSAGE_LENGTH) {
/* Not the same type as wanted to receive */
if (pkt->type != required_type) {
status = ZSM_STA_INVALID_TYPE;
error(0, "Invalid message type");
goto failure;
}
if (pkt->length > MAX_MESSAGE_LENGTH) {
status = ZSM_STA_TOO_LONG;
error(0, "Message too long: %lld", msg->length);
error(0, "Message too long: %d", pkt->length);
goto failure;
}
#if DEBUG == 1
printf("Length: %lld\n", msg->length);
printf("Length: %d\n", pkt->length);
#endif
size_t bytes_read = 0;
// Allocate memory for message data
msg->data = memalloc((msg->length + 1) * sizeof(char));
if (msg->data == NULL) {
status = ZSM_STA_MEMORY_ALLOCATION;
goto failure;
}
/* If packet's length is 0, ignore its data and signature as it is information from server */
if (pkt->type != ZSM_TYP_INFO && pkt->length > 0) {
pkt->data = memalloc((pkt->length + 1) * sizeof(char));
if (pkt->data == NULL) {
status = ZSM_STA_MEMORY_ALLOCATION;
goto failure;
}
/* Read message data from the socket */
size_t bytes_read = 0;
if ((bytes_read = recv(fd, msg->data, msg->length, 0)) < 0) {
status = ZSM_STA_READING_SOCKET;
error(0, "Error reading from socket");
free(msg->data);
goto failure;
}
if (bytes_read != msg->length) {
status = ZSM_STA_INVALID_LENGTH;
error(0, "Invalid message length: bytes_read=%ld != msg->length=%lld", bytes_read, msg->length);
free(msg->data);
goto failure;
}
msg->data[msg->length] = '\0';
/* Read message data from the socket */
if ((bytes_read = recv(fd, pkt->data, pkt->length, 0)) < 0) {
status = ZSM_STA_READING_SOCKET;
error(0, "Error reading from socket");
free(pkt->data);
goto failure;
}
if (bytes_read != pkt->length) {
status = ZSM_STA_INVALID_LENGTH;
error(0, "Invalid message length: bytes_read=%ld != pkt->length=%d", bytes_read, pkt->length);
free(pkt->data);
goto failure;
}
pkt->data[pkt->length] = '\0';
#if DEBUG == 1
printf("Data:\n");
for (int i = 0; i < pkt->length; i++) {
printf("%d ", pkt->data[i]);
}
printf("\n");
#endif
#if DEBUG == 1
printf("Data: %s\n\n", msg->data);
pkt->signature = memalloc((SIGN_SIZE + 1) * sizeof(char));
if (pkt->signature == NULL) {
status = ZSM_STA_MEMORY_ALLOCATION;
goto failure;
}
if ((bytes_read = recv(fd, pkt->signature, SIGN_SIZE, 0)) < 0) {
status = ZSM_STA_READING_SOCKET;
error(0, "Error reading from socket");
free(pkt->data);
goto failure;
}
/* Don't check signature if the packet is emtpy */
if (pkt->length > 0 && bytes_read != SIGN_SIZE) {
status = ZSM_STA_INVALID_LENGTH;
error(0, "Invalid signature length: bytes_read=%ld != SIGN_SIZE(32)", bytes_read);
free(pkt->data);
goto failure;
}
pkt->signature[SIGN_SIZE] = '\0';
#if DEBUG == 1
printf("Signature:\n");
for (int i = 0; i < SIGN_SIZE; i++) {
printf("%d ", pkt->signature[i]);
}
printf("\n");
#endif
}
#if DEBUG == 1
printf("==========END RECEIVING============\n");
#endif
return status;
failure:;
message *error_msg = create_error_packet(status);
if (send_packet(error_msg, fd) != ZSM_STA_SUCCESS) {
packet *error_pkt = create_packet(status, ZSM_TYP_ERROR, 0, NULL,
create_signature(NULL, 0, NULL));
if (send_packet(error_pkt, fd) != ZSM_STA_SUCCESS) {
/* Resend it? */
error(0, "Failed to send error packet to peer. Error status => %d", status);
error(0, "Failed to send error packet. Error status => %d", status);
}
free_packet(error_msg);
free_packet(error_pkt);
return status;
}
message *create_error_packet(int code)
/*
* Creates a packet for receive or send
* Requires heap allocated data
*/
packet *create_packet(uint8_t status, uint8_t type, uint32_t length, uint8_t *data, uint8_t *signature)
{
char *err = memalloc(ERROR_LENGTH * sizeof(char));
switch (code) {
case ZSM_STA_INVALID_TYPE:
strcpy(err, "Invalid message type ");
break;
case ZSM_STA_INVALID_LENGTH:
strcpy(err, "Invalid message length ");
break;
case ZSM_STA_TOO_LONG:
strcpy(err, "Message too long ");
break;
case ZSM_STA_READING_SOCKET:
strcpy(err, "Error reading from socket");
break;
case ZSM_STA_WRITING_SOCKET:
strcpy(err, "Error writing to socket ");
break;
case ZSM_STA_UNKNOWN_USER:
strcpy(err, "Unknwon user ");
break;
case ZSM_STA_WRONG_KEY_LENGTH:
strcpy(err, "Wrong public key length ");
break;
}
return create_packet(1, ZSM_TYP_ERROR, ERROR_LENGTH, err);
packet *pkt = memalloc(sizeof(packet));
pkt->status = status;
pkt->type = type;
pkt->length = length;
pkt->data = data;
pkt->signature = signature;
return pkt;
}
/*
* Requires heap allocated msg data
* Sends packet to fd
* Requires heap allocated data
* Close file descriptor and free data on failure
*/
message *create_packet(uint8_t option, uint8_t type, uint32_t length, char *data)
{
message *msg = memalloc(sizeof(message));
msg->option = option;
msg->type = type;
msg->length = length;
msg->data = data;
return msg;
}
/*
* Requires heap allocated msg data
*/
int send_packet(message *msg, int fd)
int send_packet(packet *pkt, int fd)
{
int status = ZSM_STA_SUCCESS;
uint32_t length = msg->length;
// Send the message back to the client
if (send(fd, &msg->option, sizeof(msg->option), 0) <= 0 ||
send(fd, &msg->type, sizeof(msg->type), 0) <= 0 ||
send(fd, &msg->length, sizeof(msg->length), 0) <= 0 ||
send(fd, msg->data, length, 0) <= 0) {
status = ZSM_STA_WRITING_SOCKET;
error(0, "Error writing to socket");
//free(msg->data);
close(fd); // Close the socket and continue accepting connections
}
uint32_t length = pkt->length;
if (send(fd, &pkt->status, sizeof(pkt->status), 0) <= 0 ||
send(fd, &pkt->type, sizeof(pkt->type), 0) <= 0 ||
send(fd, &pkt->length, sizeof(pkt->length), 0) <= 0)
{
goto failure;
}
if (pkt->type != ZSM_TYP_INFO && pkt->length > 0 && pkt->data != NULL) {
if (send(fd, pkt->data, length, 0) <= 0) goto failure;
if (send(fd, pkt->signature, SIGN_SIZE, 0) <= 0) goto failure;
}
#if DEBUG == 1
printf("==========PACKET SENT==========\n");
print_packet(msg);
printf("==========PACKET SENT============\n");
print_packet(pkt);
printf("==========END SENT===============\n");
#endif
return status;
failure:
/* Or we could resend it? */
status = ZSM_STA_WRITING_SOCKET;
error(0, "Error writing to socket");
free(pkt->data);
close(fd);
return status;
}
void free_packet(message *msg)
/*
* Free allocated memory in packet
*/
void free_packet(packet *pkt)
{
if (msg->type != 0x10) {
/* temp solution, dont use stack allocated msg to send to client */
free(msg->data);
if (pkt->type != ZSM_TYP_AUTH) {
if (pkt->signature != NULL) {
free(pkt->signature);
}
}
free(msg);
free(pkt->data);
free(pkt);
}
/*
* not going to stay
*/
char *getuserinput()
{
printf("Enter message to send: ");
fflush(stdout);
char *line = memalloc(1024);
line[0] = '\0';
size_t length = strlen(line);
while (length <= 1) {
fgets(line, 1024, stdin);
length = strlen(line);
}
length -= 1;
line[length] = '\0';
return line;
}
/*
* not going to stay
*/
char *getrecipient()
{
printf("Enter recipient: ");
fflush(stdout);
char *line = memalloc(32);
line[0] = '\0';
size_t length = strlen(line);
while (length <= 1) {
fgets(line, 1024, stdin);
length = strlen(line);
}
length -= 1;
line[length] = '\0';
return line;
}
int encrypt_packet(int sockfd, key_pair *kp)
{
int status = ZSM_STA_SUCCESS;
char *line = getuserinput();
uint8_t *recipient = getrecipient();
uint32_t data_len;
uint8_t *raw_data = memalloc(8192);
size_t length = strlen(recipient);
size_t length_line = strlen(line);
if (length < MAX_NAME) {
/* Pad with null characters up to max length */
memset(recipient + length, 0, MAX_NAME - length);
}
memcpy(raw_data, kp->pk.username, MAX_NAME);
memcpy(raw_data + MAX_NAME, recipient, MAX_NAME);
memcpy(raw_data + MAX_NAME * 2, line, length_line);
size_t raw_data_size = MAX_NAME * 2 + strlen(line);
uint8_t *data = encrypt_data(kp->pk.username, recipient, raw_data, raw_data_size, &data_len);
uint8_t *signature = create_signature(data, data_len, &kp->sk);
packet *pkt = create_packet(1, ZSM_TYP_MESSAGE, data_len, data, signature);
if ((status = send_packet(pkt, sockfd)) != ZSM_STA_SUCCESS) {
close(sockfd);
return status;
}
free(recipient);
free(line);
free_packet(pkt);
return status;
}
/*
* Wrapper for recv_packet to verify packet
* Reads packet from fd, stores in pkt
* TODO: pkt is unncessary
*/
packet *verify_packet(packet *pkt, int fd)
{
if (recv_packet(pkt, fd, ZSM_TYP_MESSAGE) != ZSM_STA_SUCCESS) {
close(fd);
return NULL;
}
uint8_t from[MAX_NAME], to[MAX_NAME];
memcpy(from, pkt->data, MAX_NAME);
/* TODO: replace with db operations */
key_pair *kp_from = get_key_pair(from);
if (verify_integrity(pkt, &kp_from->pk) != ZSM_STA_SUCCESS) {
free(pkt->data);
free(pkt->signature);
packet *error_pkt = create_packet(ZSM_STA_ERROR_INTEGRITY, ZSM_TYP_ERROR, 0, NULL,
create_signature(NULL, 0, NULL));
send_packet(error_pkt, fd);
free_packet(error_pkt);
return NULL;
}
return pkt;
}
/*
* Encrypt raw with raw_length using to
* length is set to sum length of random bytes and scrambled data
*/
uint8_t *encrypt_data(uint8_t *from, uint8_t *to, uint8_t *raw, uint32_t raw_length, uint32_t *length)
{
key_pair *kp_from = get_key_pair(from);
key_pair *kp_to = get_key_pair(to);
uint8_t shared_key[SHARED_SIZE];
if (crypto_kx_client_session_keys(shared_key, NULL,
kp_from->pk.bin, kp_from->sk.bin, kp_to->pk.bin) != 0) {
/* Suspicious server public key, bail out */
error(0, "Error performing key exchange");
}
uint8_t nonce[NONCE_SIZE];
uint32_t encrypted_len = raw_length + ADDITIONAL_SIZE;
uint8_t encrypted[encrypted_len];
/* Generate random nonce(number used once) */
printf("raw: %s\n", raw);
randombytes_buf(nonce, sizeof(nonce));
crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, NULL, raw,
raw_length, NULL, 0, NULL, nonce, shared_key);
size_t data_len = MAX_NAME * 2 + NONCE_SIZE + encrypted_len;
*length = data_len;
uint8_t *data = memalloc(data_len * sizeof(uint8_t));
memcpy(data, kp_from->sk.username, MAX_NAME);
memcpy(data + MAX_NAME, kp_to->sk.username, MAX_NAME);
memcpy(data + MAX_NAME * 2, nonce, NONCE_SIZE);
memcpy(data + MAX_NAME * 2 + NONCE_SIZE, encrypted, encrypted_len);
return data;
}
/*
* Should be used by clients
*/
uint8_t *decrypt_data(packet *pkt)
{
size_t encrypted_len = pkt->length - NONCE_SIZE - MAX_NAME * 2;
size_t data_len = encrypted_len - ADDITIONAL_SIZE;
uint8_t nonce[NONCE_SIZE], from[MAX_NAME], to[MAX_NAME], encrypted[encrypted_len];
uint8_t *decrypted = memalloc((data_len + 1) * sizeof(uint8_t));
memcpy(from, pkt->data, MAX_NAME);
memcpy(to, pkt->data + MAX_NAME, MAX_NAME);
memcpy(nonce, pkt->data + MAX_NAME * 2, NONCE_SIZE);
memcpy(encrypted, pkt->data + MAX_NAME * 2 + NONCE_SIZE, encrypted_len);
key_pair *kp_from = get_key_pair(from);
key_pair *kp_to = get_key_pair(to);
uint8_t shared_key[SHARED_SIZE];
if (crypto_kx_client_session_keys(shared_key, NULL,
kp_from->pk.bin, kp_from->sk.bin, kp_to->pk.bin) != 0) {
/* Suspicious server public key, bail out */
error(0, "Error performing key exchange");
}
/* We don't need it anymore */
free(pkt->data);
if (crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted, NULL,
NULL, encrypted,
encrypted_len,
NULL, 0,
nonce, shared_key) != 0) {
free(decrypted);
error(0, "Cannot decrypt message");
return NULL;
} else {
/* Terminate decrypted message so we don't print random bytes */
decrypted[data_len] = '\0';
printf("<%s> to <%s>: %s\n", from, to, decrypted + MAX_NAME * 2);
return decrypted;
}
}
/*
* Verify message integrity + confidentiality
* Verify using public key against hashed message
*/
int verify_integrity(packet *pkt, public_key *pk)
{
uint8_t hash[HASH_SIZE];
/* Hash data to check if matches user provided correct signature */
crypto_generichash(hash, HASH_SIZE,
pkt->data, pkt->length,
NULL, 0);
if (crypto_sign_verify_detached(pkt->signature, hash, HASH_SIZE, pk->bin) != 0) {
/* Not match */
error(0, "Cannot verify message integrity");
return ZSM_STA_ERROR_INTEGRITY;
}
return ZSM_STA_SUCCESS;
}
/*
* Create signature for packet
* When data, secret is null, length is 0, empty siganture is created
*/
uint8_t *create_signature(uint8_t *data, uint32_t length, secret_key *sk)
{
uint8_t *signature = memalloc(SIGN_SIZE * sizeof(uint8_t));
if (data == NULL && length == 0 && sk == NULL) {
/* From server, give fake signature */
memset(signature, 0, SIGN_SIZE * sizeof(uint8_t));
} else {
uint8_t hash[HASH_SIZE];
/* Hash data to check if matches user provided correct signature */
crypto_generichash(hash, HASH_SIZE,
data, length,
NULL, 0);
crypto_sign_detached(signature, NULL, hash, HASH_SIZE, sk->bin);
}
return signature;
}

173
lib/util.c Normal file
View file

@ -0,0 +1,173 @@
#include "config.h"
#include "packet.h"
#include "util.h"
/*
* will include error message from function if errno isn't 0
* end program is fatal is 1
*/
void error(int fatal, const char *fmt, ...)
{
va_list args;
va_start(args, fmt);
/* to preserve errno */
int errsv = errno;
/* Determine the length of the formatted error message */
va_list args_copy;
va_copy(args_copy, args);
size_t error_len = vsnprintf(NULL, 0, fmt, args_copy);
va_end(args_copy);
/* 7 for [zsm], space and null */
char errorstr[error_len + 1];
vsnprintf(errorstr, error_len + 1, fmt, args);
fprintf(stderr, "[zsm] ");
if (errsv != 0) {
perror(errorstr);
errno = 0;
} else {
fprintf(stderr, "%s\n", errorstr);
}
va_end(args);
if (fatal) exit(1);
}
void *memalloc(size_t size)
{
void *ptr = malloc(size);
if (!ptr) {
write_log(LOG_ERROR, "Error allocating memory\n");
return NULL;
}
return ptr;
}
void *estrdup(void *str)
{
void *modstr = strdup(str);
if (modstr == NULL) {
write_log(LOG_ERROR, "Error allocating memory\n");
return NULL;
}
return modstr;
}
/*
* Set socket to non blocking so epoll can use EPOLLET flag
*/
int set_nonblocking(int fd)
{
int flags = fcntl(fd, F_GETFL, 0);
if (flags == -1) {
return -1;
}
return fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
/*
* Takes heap-allocated str and replace ~ with home path
* Returns heap-allocated newstr
*/
char *replace_home(char *str)
{
char *home = getenv("HOME");
if (home == NULL) {
write_log(LOG_ERROR, "$HOME not defined\n");
return str;
}
char *newstr = memalloc((strlen(str) + strlen(home)) * sizeof(char));
/* replace ~ with home */
snprintf(newstr, strlen(str) + strlen(home), "%s%s", home, str + 1);
free(str);
return newstr;
}
/*
* Recursively create directory by creating each subdirectory
* like mkdir -p
*/
void mkdir_p(const char *destdir)
{
char *path = memalloc(PATH_MAX * sizeof(char));
char dir_path[PATH_MAX] = "";
if (destdir[0] == '~') {
char *home = getenv("HOME");
if (home == NULL) {
write_log(LOG_ERROR, "$HOME not defined\n");
return;
}
/* replace ~ with home */
snprintf(path, PATH_MAX, "%s%s", home, destdir + 1);
} else {
strcpy(path, destdir);
}
/* fix first / not appearing in the string */
if (path[0] == '/')
strcat(dir_path, "/");
char *token = strtok(path, "/");
while (token != NULL) {
strcat(dir_path, token);
strcat(dir_path, "/");
if (mkdir(dir_path, 0755) == -1) {
struct stat st;
if (stat(dir_path, &st) == 0 && S_ISDIR(st.st_mode)) {
/* Directory already exists, continue to the next dir */
token = strtok(NULL, "/");
continue;
}
write_log(LOG_ERROR, "mkdir failed: %s\n", strerror(errno));
free(path);
return;
}
token = strtok(NULL, "/");
}
free(path);
return;
}
void write_log(int type, const char *fmt, ...)
{
va_list args;
va_start(args, fmt);
char *client_data_dir = estrdup(CLIENT_DATA_DIR);
mkdir_p(client_data_dir);
client_data_dir = replace_home(client_data_dir);
char *client_log = memalloc(PATH_MAX);
snprintf(client_log, PATH_MAX, "%s/%s", client_data_dir, "zen.log");
free(client_data_dir);
FILE *log = fopen(client_log, "a");
if (log != NULL) {
time_t now = time(NULL);
struct tm *t = localtime(&now);
/* either info or error */
int type_len = type == LOG_INFO ? 4 : 5;
char logtype[4 + type_len];
snprintf(logtype, 4 + type_len, "[%s] ", type == LOG_INFO ? "INFO" : "ERROR");
char time[21];
strftime(time, 22, "%Y-%m-%d %H:%M:%S ", t);
char details[2 + type_len + 22];
snprintf(details, 2 + type_len + 22, "%s%s", logtype, time);
fprintf(log, details);
vfprintf(log, fmt, args);
}
fclose(log);
va_end(args);
}
void print_bin(const unsigned char *ptr, size_t length)
{
for (size_t i = 0; i < length; i++) {
printf("%02x ", ptr[i]);
}
printf("\n");
}

View file

@ -1,15 +1,93 @@
#include "config.h"
#include "packet.h"
#include <pthread.h>
#include "key.h"
#include "util.h"
#include "client/ui.h"
#include "client/db.h"
uint8_t shared_key[SHARED_KEY_SIZE];
int sockfd;
/*
* Connect to socket server
* Authenticate with server by signing a challenge
*/
int socket_init()
int authenticate_server(key_pair *kp)
{
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
packet server_auth_pkt;
int status;
if ((status = recv_packet(&server_auth_pkt, sockfd, ZSM_TYP_AUTH) != ZSM_STA_SUCCESS)) {
return status;
}
uint8_t *challenge = server_auth_pkt.data;
uint8_t *sig = memalloc(SIGN_SIZE * sizeof(uint8_t));
crypto_sign_detached(sig, NULL, challenge, CHALLENGE_SIZE, kp->sk.bin);
uint8_t *pk_content = memalloc(PK_SIZE);
memcpy(pk_content, kp->pk.bin, PK_BIN_SIZE);
memcpy(pk_content + PK_BIN_SIZE, kp->pk.username, MAX_NAME);
memcpy(pk_content + PK_BIN_SIZE + MAX_NAME, &kp->pk.creation, TIME_SIZE);
memcpy(pk_content + PK_BIN_SIZE + METADATA_SIZE, kp->pk.signature, SIGN_SIZE);
packet *auth_pkt = create_packet(1, ZSM_TYP_AUTH, SIGN_SIZE, pk_content, sig);
if (send_packet(auth_pkt, sockfd) != ZSM_STA_SUCCESS) {
/* fd already closed */
error(0, "Could not authenticate with server");
free(sig);
free_packet(auth_pkt);
return ZSM_STA_ERROR_AUTHENTICATE;
}
free_packet(auth_pkt);
packet response;
status = recv_packet(&response, sockfd, ZSM_TYP_INFO);
return (response.status == ZSM_STA_AUTHORISED ? ZSM_STA_SUCCESS : ZSM_STA_ERROR_AUTHENTICATE);
}
/*
* For sending packets to server
*/
void *send_message(void *arg)
{
key_pair *kp = (key_pair *) arg;
while (1) {
int status = encrypt_packet(sockfd, kp);
if (status != ZSM_STA_SUCCESS) {
error(1, "Error encrypting packet %x", status);
}
}
return NULL;
}
/*
* For receiving packets from server
*/
void *receive_message(void *arg)
{
key_pair *kp = (key_pair *) arg;
while (1) {
packet pkt;
if (verify_packet(&pkt, sockfd) == 0) {
error(0, "Error verifying packet");
}
uint8_t *decrypted = decrypt_data(&pkt);
free(decrypted);
}
return NULL;
}
int main()
{
if (sodium_init() < 0) {
write_log(LOG_ERROR, "Error initializing libsodium\n");
}
//ui();
sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd < 0) {
error(1, "Error on opening socket");
}
@ -19,129 +97,58 @@ int socket_init()
error(1, "No such host %s", DOMAIN);
}
struct sockaddr_in sv_addr;
memset(&sv_addr, 0, sizeof(sv_addr));
sv_addr.sin_family = AF_INET;
sv_addr.sin_port = htons(PORT);
memcpy(&sv_addr.sin_addr.s_addr, server->h_addr, server->h_length);
struct sockaddr_in server_addr;
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_port = htons(PORT);
memcpy(&server_addr.sin_addr.s_addr, server->h_addr, server->h_length);
/* free(server); */
if (connect(sockfd, (struct sockaddr *) &sv_addr, sizeof(sv_addr)) < 0) {
error(1, "Error on connect");
close(sockfd);
return 0;
}
printf("Connected to server at %s\n", DOMAIN);
return sockfd;
}
/*
* Performs key exchange with server
*/
int key_exchange(int sockfd)
{
/* Generate the client's key pair */
uint8_t cl_pk[PUBLIC_KEY_SIZE], cl_sk[PRIVATE_KEY_SIZE];
crypto_kx_keypair(cl_pk, cl_sk);
/* Send our public key */
if (send_public_key(sockfd, cl_pk) < 0) {
return -1;
}
/* Get public key from server */
uint8_t *pk;
if ((pk = get_public_key(sockfd)) == NULL) {
return -1;
}
/* Compute a shared key using the server's public key and our secret key */
if (crypto_kx_client_session_keys(NULL, shared_key, cl_pk, cl_sk, pk) != 0) {
error(1, "Server public key is not acceptable");
free(pk);
close(sockfd);
return -1;
}
free(pk);
return 0;
}
void *sender()
{
while (1) {
printf("Enter message to send to server: ");
fflush(stdout);
char line[1024];
line[0] = '\0';
size_t length = strlen(line);
while (length <= 1) {
fgets(line, sizeof(line), stdin);
length = strlen(line);
}
length -= 1;
line[length] = '\0';
uint8_t nonce[NONCE_SIZE];
uint8_t encrypted[length + ADDITIONAL_SIZE];
unsigned long long encrypted_len;
randombytes_buf(nonce, sizeof(nonce));
crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, &encrypted_len,
line, length,
NULL, 0, NULL, nonce, shared_key);
size_t payload_t = NONCE_SIZE + encrypted_len;
uint8_t encryptedwithnonce[payload_t];
memcpy(encryptedwithnonce, nonce, NONCE_SIZE);
memcpy(encryptedwithnonce + NONCE_SIZE, encrypted, encrypted_len);
message *msg = create_packet(1, 0x10, payload_t, encryptedwithnonce);
if (send_packet(msg, sockfd) != ZSM_STA_SUCCESS) {
close(sockfd);
}
free_packet(msg);
}
close(sockfd);
}
void *receiver()
{
while (1) {
message servermsg;
if (recv_packet(&servermsg, sockfd) != ZSM_STA_SUCCESS) {
/* free(server); Can't be freed seems */
if (connect(sockfd, (struct sockaddr *) &server_addr, sizeof(server_addr)
) < 0) {
if (errno != EINPROGRESS) {
/* Connection is in progress, shouldn't be treated as error */
error(1, "Error on connect");
close(sockfd);
return 0;
}
free(servermsg.data);
}
return NULL;
}
}
write_log(LOG_INFO, "Connected to server at %s\n", DOMAIN);
int main()
{
if (sodium_init() < 0) {
error(1, "Error initializing libsodium");
}
sockfd = socket_init();
if (key_exchange(sockfd) < 0) {
/* set_nonblocking(sockfd); */
/*
key_pair *kpp = create_key_pair("palanix");
key_pair *kpn = create_key_pair("night");
*/
key_pair *kpp = get_key_pair("palanix");
key_pair *kpn = get_key_pair("night");
if (authenticate_server(kpp) != ZSM_STA_SUCCESS) {
/* Fatal */
error(1, "Error performing key exchange with server");
}
pthread_t recv_worker, send_worker;
if (pthread_create(&recv_worker, NULL, sender, NULL) != 0) {
fprintf(stderr, "Error creating incoming thread\n");
return 1;
error(1, "Error authenticating with server");
} else {
write_log(LOG_INFO, "Authenticated with server\n");
printf("Authenticated as palanix\n");
}
/* Create threads for sending and receiving messages */
pthread_t send_thread, receive_thread;
if (pthread_create(&send_thread, NULL, send_message, kpp) != 0) {
close(sockfd);
error(1, "Failed to create send thread");
exit(EXIT_FAILURE);
}
if (pthread_create(&send_worker, NULL, receiver, NULL) != 0) {
fprintf(stderr, "Error creating outgoing thread\n");
return 1;
if (pthread_create(&receive_thread, NULL, receive_message, kpp) != 0) {
close(sockfd);
error(1, "Failed to create receive thread");
}
// Join threads
pthread_join(recv_worker, NULL);
pthread_join(send_worker, NULL);
/* Wait for threads to finish */
pthread_join(send_thread, NULL);
pthread_join(receive_thread, NULL);
close(sockfd);
return 0;
}

65
src/client/db.c Normal file
View file

@ -0,0 +1,65 @@
#include <stdio.h>
#include <string.h>
#include "config.h"
#include "packet.h"
#include "util.h"
#include "client/ui.h"
#include "client/db.h"
#include "client/user.h"
static int callback(void *NotUsed, int argc, char **argv, char **azColName)
{
char *username = memalloc(32 * sizeof(char));
strcpy(username, argv[0]);
add_username(username);
/*
for(int i = 0; i < argc; i++) {
printf("%s = %s\n", azColName[i], argv[i] ? argv[i] : "NULL");
}
printf("\n");
*/
return 0;
}
int sqlite_init()
{
sqlite3 *db;
char *err_msg = 0;
int rc = sqlite3_open(DATABASE_NAME, &db);
if (rc != SQLITE_OK) {
fprintf(stderr, "Cannot open database: %s\n", sqlite3_errmsg(db));
sqlite3_close(db);
return 1;
}
char *users_statement = "CREATE TABLE IF NOT EXISTS Users(Username TEXT, SecretKey TEXT, Test TEXT);";
char *messages_statement = "CREATE TABLE IF NOT EXISTS Messages(Username TEXT, );";
//"INSERT INTO Users VALUES('night', 'test', '1');";
rc = sqlite3_exec(db, users_statement, 0, 0, &err_msg);
if (rc != SQLITE_OK) {
error(0, "SQL error: %s", err_msg);
sqlite3_free(err_msg);
} else {
/* error(0, "Table created successfully"); */
}
// Select and print all entries
const char* data = "Callback function called";
rc = sqlite3_exec(db, "SELECT * FROM Users", callback, (void*) data, &err_msg);
if (rc != SQLITE_OK ) {
error(0, "SQL error: %s\n", err_msg);
sqlite3_free(err_msg);
}
sqlite3_close(db);
return 0;
}

337
src/client/ui.c Normal file
View file

@ -0,0 +1,337 @@
#include "config.h"
#include "packet.h"
#include "util.h"
#include "client/ui.h"
#include "client/db.h"
#include "client/user.h"
typedef struct windows {
WINDOW *users_border;
WINDOW *users_content;
WINDOW *chat_border;
WINDOW *chat_content;
} windows;
WINDOW *panel;
WINDOW *users_border;
WINDOW *chat_border;
WINDOW *users_content;
WINDOW *chat_content;
ArrayList *users;
ArrayList *marked;
long current_selection = 0;
bool show_icons;
void show_chat();
void ncurses_init()
{
/* check if it is interactive shell */
if (!isatty(STDIN_FILENO)) {
error(1, "No tty detected. zsm requires an interactive shell to run");
}
/* initialize screen, don't print special chars,
* make ctrl + c work, don't show cursor
* enable arrow keys */
initscr();
noecho();
cbreak();
curs_set(0);
keypad(stdscr, TRUE);
/* check terminal has colors */
if (!has_colors()) {
endwin();
error(1, "Color is not supported in your terminal");
} else {
use_default_colors();
start_color();
}
/* colors */
init_pair(1, COLOR_BLACK, -1); /* */
init_pair(2, COLOR_RED, -1); /* */
init_pair(3, COLOR_GREEN, -1); /* */
init_pair(4, COLOR_YELLOW, -1); /* */
init_pair(5, COLOR_BLUE, -1); /* */
init_pair(6, COLOR_MAGENTA, -1); /* */
init_pair(7, COLOR_CYAN, -1); /* */
init_pair(8, COLOR_WHITE, -1); /* */
}
/*
* Draw windows
*/
void windows_init()
{
int users_width = 32;
int chat_width = COLS - 32;
/*------------------------------+
|----border(0)--||--border(2)--||
|| || ||
|| content (1) || content (3) ||
|| (users) || (chat) ||
|| || ||
|---------------||-------------||
+==========panel (4)===========*/
/* lines, cols, y, x */
panel = newwin(PANEL_HEIGHT, COLS, LINES - PANEL_HEIGHT, 0 );
/* draw border around windows */
users_border = newwin(LINES - PANEL_HEIGHT, users_width + 2, 0, 0 );
chat_border = newwin(LINES - PANEL_HEIGHT, chat_width - 2, 0, users_width + 2);
users_content = newwin(LINES - PANEL_HEIGHT - 2, users_width, 1, 1 );
chat_content = newwin(LINES - PANEL_HEIGHT - 2, chat_width - 4, 1, users_width + 3);
refresh();
draw_border(users_border, true);
draw_border(chat_border, false);
scrollok(users_content, true);
scrollok(chat_content, true);
refresh();
}
/*
* Draw the border of the window depending if it's active or not,
*/
void draw_border(WINDOW *window, bool active)
{
int width;
if (window == users_border) {
width = 34;
} else {
width = COLS - 34;
}
/* turn on color depends on active */
if (active) {
wattron(window, COLOR_PAIR(3));
} else {
wattron(window, COLOR_PAIR(5));
}
/* draw top border */
mvwaddch(window, 0, 0, ACS_ULCORNER); /* upper left corner */
mvwhline(window, 0, 1, ACS_HLINE, COLS - 2); /* top horizontal line */
mvwaddch(window, 0, width - 1, ACS_URCORNER); /* upper right corner */
/* draw side border */
mvwvline(window, 1, 0, ACS_VLINE, LINES - 2); /* left vertical line */
mvwvline(window, 1, width - 1, ACS_VLINE, LINES - 2); /* right vertical line */
/* draw bottom border
* make space for the panel */
mvwaddch(window, LINES - PANEL_HEIGHT - 1, 0, ACS_LLCORNER); /* lower left corner */
mvwhline(window, LINES - PANEL_HEIGHT - 1, 1, ACS_HLINE, width - 2); /* bottom horizontal line */
mvwaddch(window, LINES - PANEL_HEIGHT - 1, width - 1, ACS_LRCORNER); /* lower right corner */
/* turn color off after turning it on */
if (active) {
wattroff(window, COLOR_PAIR(3));
} else {
wattroff(window, COLOR_PAIR(5));
}
wrefresh(window); /* Refresh the window to see the colored border and title */
}
/*
* Print line to the panel
*/
void wpprintw(const char *fmt, ...)
{
va_list args;
va_start(args, fmt);
wclear(panel);
vw_printw(panel, fmt, args);
va_end(args);
wrefresh(panel);
}
/*
* Highlight current line by reversing the color
*/
void highlight_current_line()
{
long overflow = 0;
if (current_selection > LINES - 4) {
/* overflown */
overflow = current_selection - (LINES - 4);
}
/* calculate range of files to show */
long range = users->length;
/* not highlight if no files in directory */
if (range == 0 && errno == 0) {
#if DRAW_PREVIEW
wprintw(chat_content, "No users. Start a converstation.");
wrefresh(chat_content);
#endif
return;
}
if (range > LINES - 3) {
/* if there are more files than lines available to display
* shrink range to avaiable lines to display with
* overflow to keep the number of iterations to be constant */
range = LINES - 3 + overflow;
}
wclear(users_content);
long line_count = 0;
for (long i = overflow; i < range; i++) {
if ((overflow == 0 && i == current_selection) || (overflow != 0 && i == current_selection)) {
wattron(users_content, A_REVERSE);
/* check for marked user */
long num_marked = marked->length;
if (num_marked > 0) {
/* Determine length of formatted string */
int m_len = snprintf(NULL, 0, "[%ld] selected", num_marked);
char *selected = memalloc((m_len + 1) * sizeof(char));
snprintf(selected, m_len + 1, "[%ld] selected", num_marked);
wpprintw("(%ld/%ld) %s", current_selection + 1, users->length, selected);
} else {
wpprintw("(%ld/%ld)", current_selection + 1, users->length);
}
}
/* print the actual filename and stats */
char *line = get_line(users, i, show_icons);
int color = users->items[i].color;
/* check is user marked for action */
bool is_marked = arraylist_search(marked, users->items[i].name) != -1;
if (is_marked) {
/* show user is selected */
wattron(users_content, COLOR_PAIR(7));
} else {
/* print the whole directory with default colors */
wattron(users_content, COLOR_PAIR(color));
}
if (overflow > 0)
mvwprintw(users_content, line_count, 0, "%s", line);
else
mvwprintw(users_content, i, 0, "%s", line);
if (is_marked) {
wattroff(users_content, COLOR_PAIR(7));
} else {
wattroff(users_content, COLOR_PAIR(color));
}
wattroff(users_content, A_REVERSE);
//free(line);
line_count++;
}
wrefresh(users_content);
wrefresh(panel);
/* show chat conversation every time cursor changes */
#if DRAW_PREVIEW
show_chat();
#endif
#if DRAW_BORDERS
draw_border_title(preview_border, true);
#endif
wrefresh(chat_content);
}
/*
* Add message to chat window
* user_color is the color defined above at ncurses_init
*/
void add_message(time_t rawtime, char *username, int user_color, char *content)
{
struct tm *timeinfo = localtime(&rawtime);
char buffer[21];
strftime(buffer, sizeof(buffer), "%b %d %Y %H:%M:%S", timeinfo);
wprintw(chat_content, "%s ", buffer);
wattron(chat_content, A_BOLD);
wattron(chat_content, COLOR_PAIR(user_color));
wprintw(chat_content, "<%s> ", username);
wattroff(chat_content, A_BOLD);
wattroff(chat_content, COLOR_PAIR(user_color));
wprintw(chat_content, "%s", content);
}
/*
* Get chat conversation into buffer and show it to chat window
*/
void show_chat()
{
add_message(1725932011, "night", 1, "I go to school by bus.\n");
add_message(1725933011, "night", 2, "I go to school by tram.\n");
add_message(1725934011, "night", 3, "I go to school by train.\n");
add_message(1725935011, "night", 4, "I go to school by car.\n");
add_message(1725936011, "night", 5, "I go to school by run.\n");
add_message(1725937011, "night", 6, "I go to school by bike.\n");
add_message(1725938011, "night", 7, "I go to school by plane.\n");
}
/*
* Require heap allocated username
*/
void add_username(char *username)
{
wchar_t *icon_str = memalloc(2 * sizeof(wchar_t));
wcsncpy(icon_str, L"", 2);
arraylist_add(users, username, icon_str, 7, false, false);
}
/*
* Main loop of user interface
*/
void ui()
{
ncurses_init();
windows_init();
users = arraylist_init(LINES);
marked = arraylist_init(100);
show_icons = true;
sqlite_init();
highlight_current_line();
refresh();
int ch;
while (1) {
if (COLS < 80 || LINES < 24) {
endwin();
error(1, "Terminal size needs to be at least 80x24");
}
ch = getch();
switch (ch) {
case 'q':
goto cleanup;
/* go up by k or up arrow */
case UP:
case 'k':
if (current_selection > 0)
current_selection--;
highlight_current_line();
break;
/* go down by j or down arrow */
case DOWN:
case 'j':
if (current_selection < (users->length - 1))
current_selection++;
highlight_current_line();
break;
}
}
cleanup:
arraylist_free(users);
arraylist_free(marked);
endwin();
return;
}

119
src/client/user.c Normal file
View file

@ -0,0 +1,119 @@
#include "packet.h"
#include "util.h"
#include "client/user.h"
ArrayList *arraylist_init(size_t capacity)
{
ArrayList *list = memalloc(sizeof(ArrayList));
list->length = 0;
list->capacity = capacity;
list->items = memalloc(capacity * sizeof(user));
return list;
}
void arraylist_free(ArrayList *list)
{
for (size_t i = 0; i < list->length; i++) {
if (list->items[i].name != NULL)
free(list->items[i].name);
if (list->items[i].icon != NULL)
free(list->items[i].icon);
}
free(list->items);
free(list);
}
/*
* Check if the user is in the arraylist
*/
long arraylist_search(ArrayList *list, char *username)
{
for (long i = 0; i < list->length; i++) {
if (strcmp(list->items[i].name, username) == 0) {
return i;
}
}
return -1;
}
void arraylist_remove(ArrayList *list, long index)
{
if (index >= list->length)
return;
free(list->items[index].name);
free(list->items[index].icon);
for (long i = index; i < list->length - 1; i++)
list->items[i] = list->items[i + 1];
list->length--;
}
/*
* Force will not remove duplicate marked users, instead it just skip adding
*/
void arraylist_add(ArrayList *list, char *name, wchar_t *icon, int color, bool marked, bool force)
{
user new_user = { name, icon, color };
if (list->capacity != list->length) {
if (marked) {
for (int i = 0; i < list->length; i++) {
if (strcmp(list->items[i].name, new_user.name) == 0) {
if (!force)
arraylist_remove(list, i);
return;
}
}
}
list->items[list->length] = new_user;
} else {
int new_cap = list->capacity * 2;
user *new_items = memalloc(new_cap * sizeof(user));
user *old_items = list->items;
list->capacity = new_cap;
list->items = new_items;
for (int i = 0; i < list->length; i++)
new_items[i] = old_items[i];
free(old_items);
list->items[list->length] = new_user;
}
list->length++;
}
/*
* Construct a formatted line for display
*/
char *get_line(ArrayList *list, long index, bool icons)
{
user seluser = list->items[index];
size_t name_len = strlen(seluser.name);
size_t length;
if (icons) {
length = name_len + 10; /* 8 for icon, 1 for space and 1 for null */
} else {
length = name_len;
}
char *line = memalloc(length * sizeof(char));
line[0] = '\0';
if (icons) {
char *tmp = memalloc(9 * sizeof(char));
snprintf(tmp, 8, "%ls", seluser.icon);
strcat(line, tmp);
strcat(line, " ");
free(tmp);
}
strcat(line, seluser.name);
return line;
}

75
src/server/ht.c Normal file
View file

@ -0,0 +1,75 @@
#include "ht.h"
#include "config.h"
#include "packet.h"
client *hash_table[TABLE_SIZE];
/* Hashes every name with: name and TABLE_SIZE */
unsigned int hash(char *name)
{
int length = strnlen(name, MAX_NAME);
unsigned int hash_value = 0;
for (int i = 0; i < length; i++) {
hash_value += name[i];
hash_value = (hash_value * name[i]) % TABLE_SIZE;
}
return hash_value;
}
void hashtable_init()
{
for (int i = 0; i < TABLE_SIZE; i++)
hash_table[i] = NULL;
}
void hashtable_print()
{
int i = 0;
for (; i < TABLE_SIZE; i++) {
if (hash_table[i] == NULL) {
printf("%i. ---\n", i);
} else {
printf("%i. | Name %s\n", i, hash_table[i]->name);
}
}
}
/* Gets hashed name and tries to store the client struct in that place */
int hashtable_add(client *p)
{
if (p == NULL) return 0;
int index = hash(p->name);
int initial_index = index;
/* linear probing until an empty slot is found */
while (hash_table[index] != NULL) {
index = (index + 1) % TABLE_SIZE; /* move to next item */
/* the hash table is full as no available index back to initial index, cannot fit new item */
if (index == initial_index) return 1;
}
hash_table[index] = p;
return 0;
}
/* Rehashes the name and then looks in this spot, if found returns client */
client *hashtable_search(char *name)
{
int index = hash(name);
int initial_index = index;
/* Linear probing until an empty slot or the desired item is found */
while (hash_table[index] != NULL) {
if (strncmp(hash_table[index]->name, name, MAX_NAME) == 0)
return hash_table[index];
index = (index + 1) % TABLE_SIZE; /* Move to the next slot */
/* back to same item */
if (index == initial_index) break;
}
return NULL;
}

View file

@ -1,79 +1,82 @@
#include "packet.h"
#include "util.h"
#include "config.h"
#include "notification.h"
#include <pthread.h>
socklen_t clilen;
struct sockaddr_in cli_address;
uint8_t shared_key[SHARED_KEY_SIZE];
int clientfd;
typedef struct client_t {
int fd; /* File descriptor for client socket */
uint8_t *shared_key;
char username[MAX_NAME]; /* Username of client */
} client_t;
typedef struct thread_t {
int epoll_fd; /* epoll instance for each thread */
pthread_t thread; /* POSIX thread */
int num_clients; /* Number of active clients in thread */
client_t clients[MAX_CLIENTS_PER_THREAD]; /* Active clients */
} thread_t;
thread_t threads[MAX_THREADS];
int num_thread = 0;
void *thread_worker(void *arg);
int set_nonblocking(int fd);
/*
* Initialise socket server
* Authenticate client before starting communication
*/
int socket_init()
int authenticate_client(int clientfd, uint8_t *username)
{
int serverfd = socket(AF_INET, SOCK_STREAM, 0);
if (serverfd < 0) {
error(1, "Error on opening socket");
/* Create a challenge */
uint8_t *challenge = memalloc(CHALLENGE_SIZE * sizeof(uint8_t));
randombytes_buf(challenge, CHALLENGE_SIZE);
/* Sending fake signature as structure requires it */
uint8_t *fake_sig = create_signature(NULL, 0, NULL);
packet *auth_pkt = create_packet(1, ZSM_TYP_AUTH, CHALLENGE_SIZE,
challenge, fake_sig);
if (send_packet(auth_pkt, clientfd) != ZSM_STA_SUCCESS) {
/* fd already closed */
error(0, "Could not authenticate client");
free(challenge);
free_packet(auth_pkt);
goto failure;
}
free(fake_sig);
packet client_auth_pkt;
int status;
if ((status = recv_packet(&client_auth_pkt, clientfd, ZSM_TYP_AUTH)
!= ZSM_STA_SUCCESS)) {
error(0, "Could not authenticate client");
goto failure;
}
/* Reuse addr(for debug) */
int optval = 1;
if (setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)) < 0) {
uint8_t pk_bin[PK_BIN_SIZE], pk_username[MAX_NAME];
memcpy(pk_bin, client_auth_pkt.data, PK_BIN_SIZE);
memcpy(pk_username, client_auth_pkt.data + PK_BIN_SIZE, MAX_NAME);
error(1, "Error at setting SO_REUSEADDR");
if (crypto_sign_verify_detached(client_auth_pkt.signature, challenge, CHALLENGE_SIZE, pk_bin) != 0) {
error(0, "Incorrect signature, could not authenticate client");
free(client_auth_pkt.data);
goto failure;
} else {
packet *ok_pkt = create_packet(ZSM_STA_AUTHORISED, ZSM_TYP_INFO
, 0, NULL, NULL);
send_packet(ok_pkt, clientfd);
free_packet(ok_pkt);
strcpy(username, pk_username);
return ZSM_STA_SUCCESS;
}
failure:;
packet *error_pkt = create_packet(ZSM_STA_UNAUTHORISED, ZSM_TYP_ERROR,
0, NULL, create_signature(NULL, 0, NULL));
struct sockaddr_in sv_addr;
memset(&sv_addr, 0, sizeof(sv_addr));
sv_addr.sin_family = AF_INET;
sv_addr.sin_addr.s_addr = INADDR_ANY;
sv_addr.sin_port = htons(PORT);
if (bind(serverfd, (struct sockaddr *) &sv_addr, sizeof(sv_addr)) < 0) {
error(1, "Error on bind");
}
if (listen(serverfd, MAX_CONNECTION) < 0) {
error(1, "Error on listen");
}
printf("Listening on port %d\n", PORT);
clilen = sizeof(cli_address);
return serverfd;
}
/*
* Performs key exchange with client
*/
int key_exchange(int clientfd)
{
/* Generate the server's key pair */
uint8_t sv_pk[PUBLIC_KEY_SIZE], sv_sk[PRIVATE_KEY_SIZE];
crypto_kx_keypair(sv_pk, sv_sk);
/* Get public key from client */
uint8_t *pk;
if ((pk = get_public_key(clientfd)) == NULL) {
return -1;
}
/* Send our public key */
if (send_public_key(clientfd, sv_pk) < 0) {
free(pk);
return -1;
}
/* Compute a shared key using the client's public key and our secret key. */
if (crypto_kx_server_session_keys(NULL, shared_key, sv_pk, sv_sk, pk) != 0) {
error(0, "Client public key is not acceptable");
free(pk);
close(clientfd);
return -1;
}
free(pk);
return 0;
send_packet(error_pkt, clientfd);
free_packet(error_pkt);
close(clientfd);
return ZSM_STA_ERROR_AUTHENTICATE;
}
void signal_handler(int signal)
@ -91,98 +94,48 @@ void signal_handler(int signal)
}
}
void *receiver()
/*
* Takes thread_t as argument to use its epoll instance to wait new messages
* Thread worker to relay messages
*/
void *thread_worker(void *arg)
{
int serverfd = socket_init();
clientfd = accept(serverfd, (struct sockaddr *) &cli_address, &clilen);
if (clientfd < 0) {
error(0, "Error on accepting client");
/* Continue accpeting connections */
/* continue; */
}
thread_t *thread = (thread_t *) arg;
struct epoll_event events[MAX_EVENTS];
if (key_exchange(clientfd) < 0) {
error(0, "Error performing key exchange with client");
/* continue; */
}
while (1) {
message msg;
memset(&msg, 0, sizeof(msg));
if (recv_packet(&msg, clientfd) != ZSM_STA_SUCCESS) {
close(clientfd);
break;
/* continue; */
while (1) {
int num_events = epoll_wait(thread->epoll_fd, events, MAX_EVENTS, -1);
if (num_events == -1) {
error(0, "epoll_wait");
}
for (int i = 0; i < num_events; i++) {
client_t *client = (client_t *) events[i].data.ptr;
if (events[i].events & EPOLLIN) {
/* handle message */
packet pkt;
packet *verified_pkt = verify_packet(&pkt, client->fd);
if (verified_pkt == NULL) {
error(0, "Error verifying packet");
}
/* Message relay */
uint8_t to[MAX_NAME];
memcpy(to, verified_pkt->data + MAX_NAME, MAX_NAME);
for (int i = 0; i < MAX_THREADS; i++) {
thread_t thread = threads[i];
for (int j = 0; j < thread.num_clients; j++) {
client_t client = thread.clients[j];
if (strcmp(client.username, to) == 0) {
error(0, "Relaying message to %s\n", client.username);
send_packet(verified_pkt, client.fd);
}
}
}
}
}
size_t encrypted_len = msg.length - NONCE_SIZE;
size_t msg_len = encrypted_len - ADDITIONAL_SIZE;
uint8_t nonce[NONCE_SIZE];
uint8_t encrypted[encrypted_len];
uint8_t decrypted[msg_len + 1];
unsigned long long decrypted_len;
memcpy(nonce, msg.data, NONCE_SIZE);
memcpy(encrypted, msg.data + NONCE_SIZE, encrypted_len);
free(msg.data);
if (crypto_aead_xchacha20poly1305_ietf_decrypt(decrypted, &decrypted_len,
NULL,
encrypted, encrypted_len,
NULL, 0,
nonce, shared_key) != 0) {
error(0, "Cannot decrypt message");
} else {
/* Decrypted message */
decrypted[msg_len] = '\0';
printf("Decrypted: %s\n", decrypted);
send_notification(decrypted);
msg.data = malloc(14);
strcpy(msg.data, "Received data");
msg.length = 14;
send_packet(&msg, clientfd);
free(msg.data);
}
}
close(clientfd);
close(serverfd);
return NULL;
}
void *sender()
{
while (1) {
printf("Enter message to send to client: ");
fflush(stdout);
char line[1024];
line[0] = '\0';
size_t length = strlen(line);
while (length <= 1) {
fgets(line, sizeof(line), stdin);
length = strlen(line);
}
length -= 1;
line[length] = '\0';
uint8_t nonce[NONCE_SIZE];
uint8_t encrypted[length + ADDITIONAL_SIZE];
unsigned long long encrypted_len;
randombytes_buf(nonce, sizeof(nonce));
crypto_aead_xchacha20poly1305_ietf_encrypt(encrypted, &encrypted_len,
line, length,
NULL, 0, NULL, nonce, shared_key);
size_t payload_t = NONCE_SIZE + encrypted_len;
uint8_t encryptedwithnonce[payload_t];
memcpy(encryptedwithnonce, nonce, NONCE_SIZE);
memcpy(encryptedwithnonce + NONCE_SIZE, encrypted, encrypted_len);
message *msg = create_packet(1, 0x10, payload_t, encryptedwithnonce);
if (send_packet(msg, clientfd) != ZSM_STA_SUCCESS) {
close(clientfd);
}
free_packet(msg);
}
close(clientfd);
return NULL;
}
}
int main()
@ -199,20 +152,124 @@ int main()
signal(SIGINT, signal_handler);
signal(SIGTERM, signal_handler);
pthread_t recv_worker, send_worker;
if (pthread_create(&recv_worker, NULL, sender, NULL) != 0) {
fprintf(stderr, "Error creating incoming thread\n");
return 1;
/* Start server and epoll */
int serverfd, clientfd;
/* Create socket */
serverfd = socket(AF_INET, SOCK_STREAM, 0);
if (serverfd < 0) {
error(1, "Error on opening socket");
}
if (pthread_create(&send_worker, NULL, receiver, NULL) != 0) {
fprintf(stderr, "Error creating outgoing thread\n");
return 1;
/* Reuse address (for debug) */
int opt = 1;
if (setsockopt(serverfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
error(1, "Error at setting SO_REUSEADDR");
}
// Join threads
pthread_join(recv_worker, NULL);
pthread_join(send_worker, NULL);
struct sockaddr_in server_addr, client_addr;
socklen_t client_addr_len = sizeof(client_addr);
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = INADDR_ANY;
server_addr.sin_port = htons(PORT);
if (bind(serverfd, (struct sockaddr *) &server_addr
, sizeof(server_addr)) < 0) {
close(serverfd);
error(1, "Error on bind");
}
if (listen(serverfd, MAX_CONNECTION_QUEUE) < 0) {
close(serverfd);
error(1, "Error on listen");
}
/* Creating thread pool */
for (int i = 0; i < MAX_THREADS; i++) {
/* Create epoll instance for each thread */
threads[i].epoll_fd = epoll_create1(0);
if (threads[i].epoll_fd < 0) {
error(1, "Error on creating epoll instance");
}
threads[i].num_clients = 0;
/* Start a new thread and pass thread_t struct to thread */
if (pthread_create(&threads[i].thread, NULL, thread_worker,
&threads[i]) != 0) {
error(1, "Error on creating threads");
} else {
error(0, "Thread %d created", i);
}
}
error(0, "Listening on port %d", PORT);
/* Server loop to accept clients and load balance */
while (1) {
clientfd = accept(serverfd, (struct sockaddr *) &client_addr,
&client_addr_len);
if (clientfd < 0) {
error(0, "Error on accepting client");
continue;
}
/* Assign new client to a thread
* Clients distributed by a rotation(round-robin)
*/
thread_t *thread = &threads[num_thread];
if (thread->num_clients >= MAX_CLIENTS_PER_THREAD) {
error(0, "Thread %d is already full, rejecting connection\n",
num_thread);
close(clientfd);
continue;
}
client_t *client = &thread->clients[thread->num_clients];
uint8_t username[MAX_NAME];
/* User logins, authenticate them */
if (authenticate_client(clientfd, username) != ZSM_STA_SUCCESS) {
error(0, "Error authenticating with client");
continue;
}
/* To use EPOLLET, an nonblocking fd is required */
/*
if (set_nonblocking(clientfd) == -1) {
perror("Failed to set client socket to non-blocking");
close(clientfd);
continue;
}
*/
/* Add the new client to the thread's epoll instance */
struct epoll_event event;
event.data.ptr = client;
event.events = EPOLLIN;
if (epoll_ctl(thread->epoll_fd, EPOLL_CTL_ADD, clientfd, &event) == -1) {
perror("Failed to add client to epoll");
close(clientfd);
continue;
}
/* Assign fd to client in a thread */
client->fd = clientfd;
strcpy(client->username, username);
thread->num_clients++;
/* Rotate num_thread back to start if it is larder than MAX_THREADS */
num_thread = (num_thread + 1) % MAX_THREADS;
}
/* End the thread */
for (int i = 0; i < MAX_THREADS; i++) {
if (pthread_join(threads[i].thread, NULL) != 0) {
error(0, "pthread_join");
}
}
close(serverfd);
return 0;
}