zsm

Zen Secure Messaging
git clone https://codeberg.org/night0721/zsm
Log | Files | Refs | README | LICENSE

packet.c (7486B)


      1 #include "packet.h"
      2 
      3 /*
      4  * msg is the error message to print to stderr
      5  * will include error message from function if errno isn't 0
      6  * end program is fatal is 1
      7  */
      8 void error(int fatal, const char *fmt, ...)
      9 {
     10     va_list args;
     11     va_start(args, fmt);
     12 
     13     /* to preserve errno */
     14     int errsv = errno;
     15 
     16     /* Determine the length of the formatted error message */
     17     va_list args_copy;
     18     va_copy(args_copy, args);
     19     size_t error_len = vsnprintf(NULL, 0, fmt, args_copy);
     20     va_end(args_copy);
     21 
     22     /* 7 for [zsm], space and null */
     23     char errorstr[error_len + 1];
     24     vsnprintf(errorstr, error_len + 1, fmt, args);
     25     fprintf(stderr, "[zsm] ");
     26 
     27     if (errsv != 0) {
     28         perror(errorstr);
     29         errno = 0;
     30     } else {
     31         fprintf(stderr, "%s\n", errorstr);
     32     }
     33     
     34     va_end(args);
     35     if (fatal) exit(1);
     36 }
     37 
     38 void *memalloc(size_t size)
     39 {
     40     void *ptr = malloc(size);
     41     if (!ptr) {
     42         error(0, "Error allocating memory"); 
     43         return NULL;
     44     }
     45     return ptr;
     46 }
     47 
     48 void *estrdup(void *str)
     49 {
     50     void *modstr = strdup(str);
     51     if (modstr == NULL) {
     52         error(0, "Error allocating memory");
     53         return NULL;
     54     }
     55     return modstr;
     56 }
     57 
     58 uint8_t *get_public_key(int sockfd)
     59 {
     60     message keyex_msg;
     61     if (recv_packet(&keyex_msg, sockfd) != ZSM_STA_SUCCESS) {
     62         /* We can't do anything if key exchange already failed */
     63         close(sockfd);
     64         return NULL;
     65     } else {
     66         int status = 0;
     67         /* Check to see if the content is actually a key */
     68         if (keyex_msg.type != ZSM_TYP_KEY) {
     69             status = ZSM_STA_INVALID_TYPE;
     70         }
     71         if (keyex_msg.length != PUBLIC_KEY_SIZE) {
     72             status = ZSM_STA_WRONG_KEY_LENGTH;
     73         }
     74         if (status != 0) {
     75             free(keyex_msg.data);
     76             message *error_msg = create_error_packet(status);
     77             send_packet(error_msg, sockfd);
     78             free_packet(error_msg);
     79             close(sockfd);
     80             return NULL;
     81         }
     82     }
     83     /* Obtain public key from packet */
     84     uint8_t *pk = memalloc(PUBLIC_KEY_SIZE * sizeof(char));
     85     memcpy(pk, keyex_msg.data, PUBLIC_KEY_SIZE);
     86     if (pk == NULL) {
     87         free(keyex_msg.data);
     88         /* Fatal, we couldn't complete key exchange */
     89         close(sockfd);
     90         return NULL;
     91     }
     92     free(keyex_msg.data);
     93     return pk;
     94 }
     95 
     96 int send_public_key(int sockfd, uint8_t *pk)
     97 {
     98     /* send_packet requires heap allocated buffer */
     99     uint8_t *pk_dup = memalloc(PUBLIC_KEY_SIZE * sizeof(char));
    100     memcpy(pk_dup, pk, PUBLIC_KEY_SIZE);
    101     if (pk_dup == NULL) {
    102         close(sockfd);
    103         return -1;
    104     }
    105 
    106     /* Sending our public key to client */
    107     /* option???? */
    108     message *keyex = create_packet(1, ZSM_TYP_KEY, PUBLIC_KEY_SIZE, pk_dup);
    109     send_packet(keyex, sockfd);
    110     free_packet(keyex);
    111     return 0;
    112 }
    113 
    114 void print_packet(message *msg)
    115 {
    116     printf("Option: %d\n", msg->option);
    117     printf("Type: %d\n", msg->type);
    118     printf("Length: %lld\n", msg->length);
    119     printf("Data: %s\n\n", msg->data);
    120 }
    121 
    122 /*
    123  * Requires manually free message data
    124  */
    125 int recv_packet(message *msg, int fd)
    126 {
    127     int status = ZSM_STA_SUCCESS;
    128 
    129     /* Read the message components */
    130     if (recv(fd, &msg->option, sizeof(msg->option), 0) < 0 ||
    131         recv(fd, &msg->type, sizeof(msg->type), 0) < 0 ||
    132         recv(fd, &msg->length, sizeof(msg->length), 0) < 0) {
    133         status = ZSM_STA_READING_SOCKET;
    134         error(0, "Error reading from socket");
    135     }
    136     #if DEBUG == 1
    137         printf("==========PACKET RECEIVED==========\n");
    138     #endif
    139     #if DEBUG == 1
    140         printf("Option: %d\n", msg->option);
    141     #endif
    142 
    143     if (msg->type > 0xFF || msg->type < 0x0) {
    144         status = ZSM_STA_INVALID_TYPE;
    145         error(0, "Invalid message type");
    146         goto failure;
    147     }
    148     #if DEBUG == 1
    149         printf("Type: %d\n", msg->type);
    150     #endif
    151 
    152     /* Convert message length from network byte order to host byte order */
    153     if (msg->length > MAX_MESSAGE_LENGTH) {
    154         status = ZSM_STA_TOO_LONG;
    155         error(0, "Message too long: %lld", msg->length);
    156         goto failure;
    157     }
    158     #if DEBUG == 1
    159         printf("Length: %lld\n", msg->length);
    160     #endif
    161 
    162     // Allocate memory for message data
    163     msg->data = memalloc((msg->length + 1) * sizeof(char));
    164     if (msg->data == NULL) {
    165         status = ZSM_STA_MEMORY_ALLOCATION;
    166         goto failure;
    167     }
    168 
    169     /* Read message data from the socket */
    170     size_t bytes_read = 0;
    171     if ((bytes_read = recv(fd, msg->data, msg->length, 0)) < 0) {
    172         status = ZSM_STA_READING_SOCKET;
    173         error(0, "Error reading from socket");
    174         free(msg->data);
    175         goto failure;
    176     }
    177     if (bytes_read != msg->length) {
    178         status = ZSM_STA_INVALID_LENGTH;
    179         error(0, "Invalid message length: bytes_read=%ld != msg->length=%lld", bytes_read, msg->length);
    180         free(msg->data);
    181         goto failure;
    182     }
    183     msg->data[msg->length] = '\0';
    184 
    185     #if DEBUG == 1
    186         printf("Data: %s\n\n", msg->data);
    187     #endif
    188 
    189     return status;
    190 failure:;
    191     message *error_msg = create_error_packet(status);
    192     if (send_packet(error_msg, fd) != ZSM_STA_SUCCESS) {
    193         /* Resend it? */
    194         error(0, "Failed to send error packet to peer. Error status => %d", status);
    195     }
    196     free_packet(error_msg);
    197     return status;
    198 }
    199 
    200 message *create_error_packet(int code)
    201 {
    202     char *err = memalloc(ERROR_LENGTH * sizeof(char));
    203     switch (code) {
    204         case ZSM_STA_INVALID_TYPE:
    205             strcpy(err, "Invalid message type     ");
    206             break;
    207         case ZSM_STA_INVALID_LENGTH:
    208             strcpy(err, "Invalid message length   ");
    209             break;
    210         case ZSM_STA_TOO_LONG:
    211             strcpy(err, "Message too long         ");
    212             break;
    213         case ZSM_STA_READING_SOCKET:
    214             strcpy(err, "Error reading from socket");
    215             break;
    216         case ZSM_STA_WRITING_SOCKET:
    217             strcpy(err, "Error writing to socket  ");
    218             break;
    219         case ZSM_STA_UNKNOWN_USER:
    220             strcpy(err, "Unknwon user             ");
    221             break;
    222         case ZSM_STA_WRONG_KEY_LENGTH:
    223             strcpy(err, "Wrong public key length  ");
    224             break;
    225     }
    226     return create_packet(1, ZSM_TYP_ERROR, ERROR_LENGTH, err);
    227 }
    228 
    229 /*
    230  * Requires heap allocated msg data
    231  */
    232 message *create_packet(uint8_t option, uint8_t type, uint32_t length, char *data)
    233 {
    234     message *msg = memalloc(sizeof(message));
    235     msg->option = option;
    236     msg->type = type;
    237     msg->length = length;
    238     msg->data = data;
    239     return msg;
    240 }
    241 
    242 /*
    243  * Requires heap allocated msg data
    244  */
    245 int send_packet(message *msg, int fd)
    246 {
    247     int status = ZSM_STA_SUCCESS;
    248     uint32_t length = msg->length;
    249     // Send the message back to the client
    250     if (send(fd, &msg->option, sizeof(msg->option), 0) <= 0 ||
    251         send(fd, &msg->type, sizeof(msg->type), 0) <= 0 ||
    252         send(fd, &msg->length, sizeof(msg->length), 0) <= 0 ||
    253         send(fd, msg->data, length, 0) <= 0) {
    254         status = ZSM_STA_WRITING_SOCKET;
    255         error(0, "Error writing to socket");
    256         //free(msg->data);
    257         close(fd); // Close the socket and continue accepting connections
    258     }
    259     #if DEBUG == 1
    260         printf("==========PACKET SENT==========\n");
    261         print_packet(msg);
    262     #endif
    263     return status;
    264 }
    265 
    266 void free_packet(message *msg)
    267 {
    268     if (msg->type != 0x10) {
    269         /* temp solution, dont use stack allocated msg to send to client */
    270         free(msg->data);
    271     }
    272     free(msg);
    273 }