#include #include #include #include #include #include #include #include #include #include "transport.h" #include "serial.h" struct Transport_Conn { int fd; pthread_mutex_t write_mutex; uint32_t max_payload; Transport_Frame_Cb on_frame; Transport_Disconnect_Cb on_disconnect; void *userdata; struct Transport_Server *server; /* NULL for outbound connections */ }; struct Transport_Server { int listen_fd; uint16_t bound_port; /* actual port after bind */ struct Transport_Server_Config config; pthread_t accept_thread; pthread_mutex_t count_mutex; int connection_count; atomic_int running; }; static int write_exact(int fd, const uint8_t *buf, size_t n) { while (n > 0) { ssize_t r = write(fd, buf, n); if (r <= 0) { return -1; } buf += r; n -= (size_t)r; } return 0; } static int read_exact(int fd, uint8_t *buf, size_t n) { while (n > 0) { ssize_t r = read(fd, buf, n); if (r <= 0) { return -1; } buf += r; n -= (size_t)r; } return 0; } static void *conn_read_thread_fn(void *arg) { struct Transport_Conn *conn = arg; uint8_t header_buf[TRANSPORT_FRAME_HEADER_SIZE]; while (1) { if (read_exact(conn->fd, header_buf, TRANSPORT_FRAME_HEADER_SIZE) != 0) { break; } struct Transport_Frame frame; frame.message_type = get_u16(header_buf, 0); frame.payload_length = get_u32(header_buf, 2); if (frame.payload_length > conn->max_payload) { break; } if (frame.payload_length > 0) { frame.payload = malloc(frame.payload_length); if (!frame.payload) { break; } if (read_exact(conn->fd, frame.payload, frame.payload_length) != 0) { free(frame.payload); break; } } else { frame.payload = NULL; } conn->on_frame(conn, &frame, conn->userdata); } if (conn->on_disconnect) { conn->on_disconnect(conn, conn->userdata); } if (conn->server) { pthread_mutex_lock(&conn->server->count_mutex); conn->server->connection_count--; pthread_mutex_unlock(&conn->server->count_mutex); } close(conn->fd); pthread_mutex_destroy(&conn->write_mutex); free(conn); return NULL; } static struct Transport_Conn *conn_create(int fd, uint32_t max_payload, Transport_Frame_Cb on_frame, Transport_Disconnect_Cb on_disconnect, void *userdata, struct Transport_Server *server) { struct Transport_Conn *conn = malloc(sizeof(*conn)); if (!conn) { return NULL; } conn->fd = fd; conn->max_payload = max_payload; conn->on_frame = on_frame; conn->on_disconnect = on_disconnect; conn->userdata = userdata; conn->server = server; pthread_mutex_init(&conn->write_mutex, NULL); return conn; } static int conn_start_thread(struct Transport_Conn *conn) { pthread_t thread; pthread_attr_t attr; pthread_attr_init(&attr); pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED); int r = pthread_create(&thread, &attr, conn_read_thread_fn, conn); pthread_attr_destroy(&attr); return r; } static void *accept_thread_fn(void *arg) { struct Transport_Server *server = arg; while (atomic_load(&server->running)) { int fd = accept(server->listen_fd, NULL, NULL); if (fd < 0) { if (!atomic_load(&server->running)) { break; } continue; } pthread_mutex_lock(&server->count_mutex); if (server->connection_count >= server->config.max_connections) { pthread_mutex_unlock(&server->count_mutex); close(fd); continue; } server->connection_count++; pthread_mutex_unlock(&server->count_mutex); struct Transport_Conn *conn = conn_create(fd, server->config.max_payload, server->config.on_frame, server->config.on_disconnect, server->config.userdata, server); if (!conn) { pthread_mutex_lock(&server->count_mutex); server->connection_count--; pthread_mutex_unlock(&server->count_mutex); close(fd); continue; } if (server->config.on_connect) { server->config.on_connect(conn, server->config.userdata); } if (conn_start_thread(conn) != 0) { pthread_mutex_lock(&server->count_mutex); server->connection_count--; pthread_mutex_unlock(&server->count_mutex); close(conn->fd); pthread_mutex_destroy(&conn->write_mutex); free(conn); } } return NULL; } struct App_Error transport_server_create(struct Transport_Server **out, struct Transport_Server_Config *config) { struct Transport_Server *server = malloc(sizeof(*server)); if (!server) { return APP_SYSCALL_ERROR(); } server->config = *config; server->connection_count = 0; server->listen_fd = -1; atomic_init(&server->running, 0); pthread_mutex_init(&server->count_mutex, NULL); *out = server; return APP_OK; } struct App_Error transport_server_start(struct Transport_Server *server) { int fd = socket(AF_INET, SOCK_STREAM | SOCK_CLOEXEC, 0); if (fd < 0) { return APP_SYSCALL_ERROR(); } int opt = 1; setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); struct sockaddr_in addr = {0}; addr.sin_family = AF_INET; addr.sin_addr.s_addr = htonl(INADDR_ANY); addr.sin_port = htons(server->config.port); if (bind(fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { close(fd); return APP_SYSCALL_ERROR(); } /* Read back the actual port (matters when config.port == 0) */ struct sockaddr_in bound = {0}; socklen_t bound_len = sizeof(bound); if (getsockname(fd, (struct sockaddr *)&bound, &bound_len) == 0) { server->bound_port = ntohs(bound.sin_port); } else { server->bound_port = server->config.port; } if (listen(fd, SOMAXCONN) < 0) { close(fd); return APP_SYSCALL_ERROR(); } server->listen_fd = fd; atomic_store(&server->running, 1); if (pthread_create(&server->accept_thread, NULL, accept_thread_fn, server) != 0) { atomic_store(&server->running, 0); close(fd); server->listen_fd = -1; return APP_SYSCALL_ERROR(); } return APP_OK; } void transport_server_destroy(struct Transport_Server *server) { atomic_store(&server->running, 0); close(server->listen_fd); pthread_join(server->accept_thread, NULL); pthread_mutex_destroy(&server->count_mutex); free(server); } uint16_t transport_server_get_port(const struct Transport_Server *server) { return server->bound_port; } struct App_Error transport_connect(struct Transport_Conn **out, const char *host, uint16_t port, uint32_t max_payload, Transport_Frame_Cb on_frame, Transport_Disconnect_Cb on_disconnect, void *userdata) { char port_str[8]; snprintf(port_str, sizeof(port_str), "%u", port); struct addrinfo hints = {0}; hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; struct addrinfo *res; if (getaddrinfo(host, port_str, &hints, &res) != 0) { return APP_SYSCALL_ERROR(); } int fd = -1; for (struct addrinfo *ai = res; ai; ai = ai->ai_next) { fd = socket(ai->ai_family, ai->ai_socktype | SOCK_CLOEXEC, ai->ai_protocol); if (fd < 0) { continue; } if (connect(fd, ai->ai_addr, ai->ai_addrlen) == 0) { break; } close(fd); fd = -1; } freeaddrinfo(res); if (fd < 0) { return APP_SYSCALL_ERROR(); } struct Transport_Conn *conn = conn_create(fd, max_payload, on_frame, on_disconnect, userdata, NULL); if (!conn) { close(fd); return APP_SYSCALL_ERROR(); } if (conn_start_thread(conn) != 0) { close(fd); pthread_mutex_destroy(&conn->write_mutex); free(conn); return APP_SYSCALL_ERROR(); } *out = conn; return APP_OK; } struct App_Error transport_send_frame(struct Transport_Conn *conn, uint16_t message_type, const uint8_t *payload, uint32_t length) { uint8_t header[TRANSPORT_FRAME_HEADER_SIZE]; put_u16(header, 0, message_type); put_u32(header, 2, length); pthread_mutex_lock(&conn->write_mutex); int ok = write_exact(conn->fd, header, TRANSPORT_FRAME_HEADER_SIZE); if (ok == 0 && length > 0) { ok = write_exact(conn->fd, payload, length); } pthread_mutex_unlock(&conn->write_mutex); if (ok != 0) { return APP_SYSCALL_ERROR(); } return APP_OK; } void transport_conn_close(struct Transport_Conn *conn) { close(conn->fd); }