/* srv9p/gnet.c - libmisc/net.h implementation for libcr + GNU libc
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#define _GNU_SOURCE     /* for pthread_sigqueue(3gnu) */
/* misc */
#include <assert.h>     /* for assert() */
#include <errno.h>      /* for errno, EAGAIN, EINVAL */
#include <error.h>      /* for error(3gnu) */
#include <stdlib.h>     /* for abs() */
#include <unistd.h>     /* for read(), write() */
/* net */
#include <arpa/inet.h>  /* for htons(3p) */
#include <netinet/in.h> /* for struct sockaddr_in */
#include <sys/socket.h> /* for struct sockaddr{,_storage}, SOCK_*, SOL_*, SO_*, SHUT_*, socket(), setsockopt(), bind(), listen(), accept(), shutdown(), sendto(), recvfrom() */
/* async */
#include <pthread.h>    /* for pthread_* */
#include <signal.h>     /* for siginfo_t, struct sigaction, union sigval, sigaction(), SIGRTMIN, SIGRTMAX, SA_SIGINFO */

#include <libcr/coroutine.h>
#include <libmisc/vcall.h>

#include "gnet.h"

/* common *********************************************************************/

/* Marks a parameter as intentionally unused.  As written it expands to
 * nothing (the attribute-based spelling is commented out), which leaves the
 * parameter unnamed in the function definition. */
#define UNUSED(name) /* name __attribute__ ((unused)) */

/* Signal number used to wake coroutines; stays 0 until gnet_init() picks one. */
static int gnet_sig_io = 0;

/* Handler for the gnet_sig_io signal.  The integer payload queued with the
 * signal is the ID of the coroutine to wake; pass it on to
 * cr_unpause_from_intrhandler(). */
static void gnet_handle_sig_io(int UNUSED(sig), siginfo_t *info, void *UNUSED(ucontext)) {
	cid_t waiting = (cid_t)info->si_value.sival_int;

	cr_unpause_from_intrhandler(waiting);
}

static void gnet_init(void) {
	struct sigaction action = {0};

	if (gnet_sig_io)
		return;

	gnet_sig_io = SIGRTMIN;
	if (gnet_sig_io > SIGRTMAX)
		error(1, 0, "SIGRTMAX exceeded");

	action.sa_flags = SA_SIGINFO;
	action.sa_sigaction = gnet_handle_sig_io;
	if (sigaction(gnet_sig_io, &action, NULL) < 0)
		error(1, errno, "sigaction");
}

/* Wake the coroutine recorded in `args` (a pointer to any struct that has
 * `cr_thread` and `cr_coroutine` members): queue gnet_sig_io to
 * `args->cr_thread` with the coroutine ID as the signal's integer payload.
 * The call is retried for as long as pthread_sigqueue() returns EAGAIN; any
 * other non-zero result trips the assert. */
#define WAKE_COROUTINE(args) do { \
		union sigval payload = {0}; \
		int r; \
		payload.sival_int = (int)((args)->cr_coroutine); \
		for (;;) { \
			r = pthread_sigqueue((args)->cr_thread, gnet_sig_io, payload); \
			assert(r == 0 || r == EAGAIN); \
			if (r != EAGAIN) \
				break; \
		} \
	} while (0)

/* Run `fn(args)` on a freshly created pthread, pause and yield the calling
 * coroutine, and join the thread once the coroutine resumes.  The thread
 * functions in this file invoke WAKE_COROUTINE just before returning.
 *
 * Returns true if the thread could not be created or joined, false on
 * success. */
static inline bool RUN_PTHREAD(void *(*fn)(void *), void *args) {
	pthread_t worker;

	if (pthread_create(&worker, NULL, fn, args) != 0)
		return true;
	cr_pause_and_yield();
	return pthread_join(worker, NULL) != 0;
}

/* Translate a syscall-style result into this module's error convention.
 *
 * v: either a non-negative success value, or a negated errno value (the
 *    callers in this file pass `-errno`).
 *
 * Returns `v` unchanged when it is non-negative, otherwise a negated NET_E*
 * code, consistent with the `-NET_ETHREAD` returned by the wrappers in this
 * file.
 *
 * NOTE(review): assumes the NET_E* constants are positive -- confirm against
 * their definition. */
static inline ssize_t gnet_map_errno(ssize_t v) {
	if (v >= 0)
		return v;
	/* `v` is negative from here on, so the case labels must be negated
	 * errno values; a bare `case ETIMEDOUT:` could never match. */
	switch (v) {
	case -ETIMEDOUT:
		return -NET_ETIMEDOUT;
	default:
		return -NET_EOTHER;
	}
}

/* TCP init() ( AKA socket(3) + listen(3) )************************************/

/* Forward declarations of the TCP method implementations, so that the
 * vtables below can refer to them ahead of their definitions. */
static implements_net_stream_conn *gnet_tcp_accept(implements_net_stream_listener *_listener);
static ssize_t gnet_tcp_read(implements_net_stream_conn *conn, void *buf, size_t count);
static ssize_t gnet_tcp_write(implements_net_stream_conn *conn, void *buf, size_t count);
static int gnet_tcp_close(implements_net_stream_conn *conn, bool rd, bool wr);

/* Method table installed on a listener by gnet_tcp_listener_init(). */
static struct net_stream_listener_vtable gnet_tcp_listener_vtable = {
	.accept = gnet_tcp_accept,
};

/* Method table installed on the listener's connection by gnet_tcp_accept(). */
static struct net_stream_conn_vtable gnet_tcp_conn_vtable = {
	.read  = gnet_tcp_read,
	.write = gnet_tcp_write,
	.close = gnet_tcp_close,
};

/* Create an IPv4 TCP socket, enable SO_REUSEADDR and SO_REUSEPORT on it, bind
 * it to `port` (the address field is left zeroed), start listening with a
 * backlog of 0, and initialize `self` to wrap the resulting fd.
 *
 * Every failure is reported through error(3gnu) with exit status 1. */
void gnet_tcp_listener_init(struct gnet_tcp_listener *self, uint16_t port) {
	gnet_init();

	union {
		struct sockaddr_in in;
		struct sockaddr gen;
	} addr = { 0 };
	addr.in.sin_family = AF_INET;
	addr.in.sin_port = htons(port);

	const int enable = 1;
	int listenerfd = socket(AF_INET, SOCK_STREAM, 0);
	if (listenerfd < 0)
		error(1, errno, "socket");
	if (setsockopt(listenerfd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof(enable)) < 0)
		error(1, errno, "setsockopt(fd=%d, SO_REUSEADDR=1)", listenerfd);
	if (setsockopt(listenerfd, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof(enable)) < 0)
		error(1, errno, "setsockopt(fd=%d, SO_REUSEPORT=1)", listenerfd);
	if (bind(listenerfd, &addr.gen, sizeof addr) < 0)
		error(1, errno, "bind(fd=%d)", listenerfd);
	if (listen(listenerfd, 0) < 0)
		error(1, errno, "listen(fd=%d)", listenerfd);

	self->fd = listenerfd;
	self->vtable = &gnet_tcp_listener_vtable;
}

/* TCP accept() ***************************************************************/

/* Arguments handed to the gnet_pthread_accept() worker thread. */
struct gnet_pthread_accept_args {
	/* Thread and coroutine that WAKE_COROUTINE signals when the work is done. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* Listening socket to accept() on. */
	int              listenerfd;

	/* Out: the accepted fd, or -errno if accept() failed. */
	int             *ret_connfd;
};

/* Worker-thread body for gnet_tcp_accept(): perform the blocking accept(),
 * store the new fd (or -errno on failure) through `args->ret_connfd`, then
 * wake the waiting coroutine.  Always returns NULL. */
static void *gnet_pthread_accept(void *_args) {
	struct gnet_pthread_accept_args *args = _args;
	int connfd = accept(args->listenerfd, NULL, NULL);

	*(args->ret_connfd) = (connfd < 0) ? -errno : connfd;
	WAKE_COROUTINE(args);
	return NULL;
}

/* Accept one connection on the listener, without blocking other coroutines:
 * the accept() call itself runs on a helper thread.
 *
 * Returns a pointer to the listener's embedded `active_conn` (so each
 * successful call overwrites the previously returned connection), or NULL if
 * the helper thread could not be run or accept() failed. */
static implements_net_stream_conn *gnet_tcp_accept(implements_net_stream_listener *_listener) {
	struct gnet_tcp_listener *listener =
		VCALL_SELF(struct gnet_tcp_listener, implements_net_stream_listener, _listener);
	assert(listener);

	int ret_connfd;
	struct gnet_pthread_accept_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.listenerfd   = listener->fd,
		.ret_connfd   = &ret_connfd,
	};
	if (RUN_PTHREAD(gnet_pthread_accept, &args))
		return NULL;
	/* The worker reports an accept() failure as -errno; do not hand the
	 * caller a connection object that wraps an invalid fd. */
	if (ret_connfd < 0)
		return NULL;

	listener->active_conn.vtable = &gnet_tcp_conn_vtable;
	listener->active_conn.fd = ret_connfd;
	return &listener->active_conn;
}

/* TCP read() *****************************************************************/

/* Arguments handed to the gnet_pthread_read() worker thread. */
struct gnet_pthread_read_args {
	/* Thread and coroutine that WAKE_COROUTINE signals when the work is done. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* read() parameters: source fd, destination buffer, and its size. */
	int              connfd;
	void            *buf;
	size_t           count;

	/* Out: the read() result, or the gnet_map_errno() mapping on failure. */
	ssize_t         *ret;
};

/* Worker-thread body for gnet_tcp_read(): perform one blocking read(), store
 * its result (or the mapped error on failure) through `args->ret`, then wake
 * the waiting coroutine.  Always returns NULL. */
static void *gnet_pthread_read(void *_args) {
	struct gnet_pthread_read_args *args = _args;
	ssize_t nread = read(args->connfd, args->buf, args->count);

	*(args->ret) = (nread < 0) ? gnet_map_errno(-errno) : nread;
	WAKE_COROUTINE(args);
	return NULL;
}

/* Read up to `count` bytes from the connection into `buf`; the read() itself
 * runs on a helper thread while this coroutine is paused.
 *
 * Returns whatever the worker stored (the read() result or a mapped error),
 * or -NET_ETHREAD if the helper thread could not be run. */
static ssize_t gnet_tcp_read(implements_net_stream_conn *_conn, void *buf, size_t count) {
	struct _gnet_tcp_conn *conn =
		VCALL_SELF(struct _gnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(conn);

	ssize_t nread;
	struct gnet_pthread_read_args job;

	job.cr_thread    = pthread_self();
	job.cr_coroutine = cr_getcid();
	job.connfd       = conn->fd;
	job.buf          = buf;
	job.count        = count;
	job.ret          = &nread;

	return RUN_PTHREAD(gnet_pthread_read, &job) ? -NET_ETHREAD : nread;
}

/* TCP write() ****************************************************************/

/* Arguments handed to the gnet_pthread_write() worker thread. */
struct gnet_pthread_write_args {
	/* Thread and coroutine that WAKE_COROUTINE signals when the work is done. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* write() parameters: destination fd, source buffer, and its length. */
	int              connfd;
	void            *buf;
	size_t           count;

	/* Out: result of the write, stored by the worker before it wakes the coroutine. */
	ssize_t         *ret;
};

static void *gnet_pthread_write(void *_args) {
	struct gnet_pthread_read_args *args = _args;
	size_t done = 0;
	while (done < args->count) {
		ssize_t r = write(args->connfd, args->buf, args->count);
		if (r < 0) {
			gnet_map_errno(-errno);
			break;
		}
		done += r;
	}
	if (done == args->count)
		*(args->ret) = done;
	WAKE_COROUTINE(args);
	return NULL;
};

/* Write `count` bytes from `buf` to the connection; the write runs on a
 * helper thread while this coroutine is paused.
 *
 * Returns whatever the worker stored through its result pointer, or
 * -NET_ETHREAD if the helper thread could not be run. */
static ssize_t gnet_tcp_write(implements_net_stream_conn *_conn, void *buf, size_t count) {
	struct _gnet_tcp_conn *conn =
		VCALL_SELF(struct _gnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(conn);

	ssize_t nwritten;
	struct gnet_pthread_write_args job;

	job.cr_thread    = pthread_self();
	job.cr_coroutine = cr_getcid();
	job.connfd       = conn->fd;
	job.buf          = buf;
	job.count        = count;
	job.ret          = &nwritten;

	return RUN_PTHREAD(gnet_pthread_write, &job) ? -NET_ETHREAD : nwritten;
}

/* TCP close() ****************************************************************/

/* Shut down the read side (`rd`), the write side (`wr`), or both sides of the
 * connection using shutdown().
 *
 * Asking for neither direction is treated as a caller bug (asserted).  In
 * builds with assertions disabled it is a no-op that returns 0, instead of
 * passing an uninitialized `how` to shutdown().
 *
 * Returns 0 on success, otherwise the gnet_map_errno() mapping of the
 * shutdown() failure. */
static int gnet_tcp_close(implements_net_stream_conn *_conn, bool rd, bool wr) {
	struct _gnet_tcp_conn *conn =
		VCALL_SELF(struct _gnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(conn);

	assert(rd || wr);
	if (!rd && !wr)
		return 0;

	int how;
	if (rd && wr)
		how = SHUT_RDWR;
	else if (rd)
		how = SHUT_RD;
	else
		how = SHUT_WR;
	return gnet_map_errno(shutdown(conn->fd, how) ? -errno : 0);
}

/* UDP init() *****************************************************************/

/* Forward declarations of the UDP method implementations, so that the vtable
 * below can refer to them ahead of their definitions. */
static ssize_t gnet_udp_sendto(struct net_packet_conn *self, void *buf, size_t len,
                               struct net_ip4_addr addr, uint16_t port);
static ssize_t gnet_udp_recvfrom(struct net_packet_conn *self, void *buf, size_t len,
                                 struct net_ip4_addr *ret_addr, uint16_t *ret_port);
static int     gnet_udp_close(struct net_packet_conn *self);

/* Method table installed on a UDP connection by gnet_udp_conn_init(). */
static struct net_packet_conn_vtable gnet_udp_conn_vtable = {
	.sendto   = gnet_udp_sendto,
	.recvfrom = gnet_udp_recvfrom,
	.close    = gnet_udp_close,
};

/* Create an IPv4 UDP socket bound to `port` (the address field is left
 * zeroed) and initialize `self` to wrap it.
 *
 * Every failure is reported through error(3gnu) with exit status 1. */
void gnet_udp_conn_init(struct gnet_udp_conn *self, uint16_t port) {
	gnet_init();

	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
		struct sockaddr_storage stor;
	} addr = { 0 };
	addr.in.sin_family = AF_INET;
	addr.in.sin_port = htons(port);

	int fd = socket(AF_INET, SOCK_DGRAM, 0);
	if (fd < 0)
		error(1, errno, "socket");
	if (bind(fd, &addr.gen, sizeof addr) < 0)
		error(1, errno, "bind");

	self->fd = fd;
	self->vtable = &gnet_udp_conn_vtable;
}

/* UDP sendto() ***************************************************************/

/* Arguments handed to the gnet_pthread_sendto() worker thread. */
struct gnet_pthread_sendto_args {
	/* Thread and coroutine that WAKE_COROUTINE signals when the work is done. */
	pthread_t                cr_thread;
	cid_t                    cr_coroutine;

	/* sendto() parameters: socket, payload, and destination address/port
	 * (`port` is converted with htons() by the worker). */
	int                      connfd;
	void                    *buf;
	size_t                   count;
	struct net_ip4_addr      node;
	uint16_t                 port;

	/* Out: the sendto() result, or the gnet_map_errno() mapping on failure. */
	ssize_t                 *ret;
};

/* Worker-thread body for gnet_udp_sendto(): build the destination sockaddr
 * from `args->node` / `args->port`, perform the blocking sendto(), store its
 * result (or the mapped error on failure) through `args->ret`, then wake the
 * waiting coroutine.  Always returns NULL. */
static void *gnet_pthread_sendto(void *_args) {
	struct gnet_pthread_sendto_args *args = _args;
	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
		struct sockaddr_storage stor;
	} addr = { 0 };

	addr.in.sin_family = AF_INET;
	/* The shifts assemble the address as a host-order integer with
	 * octets[0] as the most significant byte; s_addr must hold it in
	 * network byte order, hence the htonl(). */
	addr.in.sin_addr.s_addr = htonl(
		(((uint32_t)args->node.octets[0])<<24) |
		(((uint32_t)args->node.octets[1])<<16) |
		(((uint32_t)args->node.octets[2])<< 8) |
		(((uint32_t)args->node.octets[3])<< 0) );
	addr.in.sin_port = htons(args->port);
	*(args->ret) = sendto(args->connfd, args->buf, args->count, 0, &addr.gen, sizeof(addr));
	if (*(args->ret) < 0)
		*(args->ret) = gnet_map_errno(-errno);
	WAKE_COROUTINE(args);
	return NULL;
}

static ssize_t gnet_udp_sendto(struct net_packet_conn *_conn, void *buf, size_t count,
                               struct net_ip4_addr node, uint16_t port) {
	struct gnet_udp_conn *conn =
		VCALL_SELF(struct gnet_udp_conn, implements_net_packet_conn, _conn);
	assert(conn);

	ssize_t ret;
	struct gnet_pthread_sendto_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),

		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,
		.node         = node,
		.port         = port,

		.ret          = &ret,
	};
	if (RUN_PTHREAD(gnet_pthread_sendto, &args))
		return -NET_ETHREAD;
	return ret;
}

/* UDP recvfrom() *************************************************************/

/* Arguments handed to the gnet_pthread_recvfrom() worker thread. */
struct gnet_pthread_recvfrom_args {
	/* Thread and coroutine that WAKE_COROUTINE signals when the work is done. */
	pthread_t                cr_thread;
	cid_t                    cr_coroutine;

	/* recvfrom() parameters: socket, destination buffer, and its size. */
	int                      connfd;
	void                    *buf;
	size_t                   count;

	/* Out: the recvfrom() result, or the gnet_map_errno() mapping on failure. */
	ssize_t                 *ret_size;
	/* Out, each may be NULL: sender address and port, written only on success. */
	struct net_ip4_addr     *ret_node;
	uint16_t                *ret_port;
};

/* Worker-thread body for gnet_udp_recvfrom(): perform the blocking
 * recvfrom(), store its result (or the mapped error on failure) through
 * `args->ret_size`, and on success report the sender's address and port
 * through `args->ret_node` / `args->ret_port` when those are non-NULL.  Then
 * wake the waiting coroutine.  Always returns NULL. */
static void *gnet_pthread_recvfrom(void *_args) {
	struct gnet_pthread_recvfrom_args *args = _args;

	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
		struct sockaddr_storage stor;
	} addr = { 0 };
	/* recvfrom()'s length argument is value-result: on entry it must hold
	 * the size of the address buffer. */
	socklen_t addr_size = sizeof(addr);

	*(args->ret_size) = recvfrom(args->connfd, args->buf, args->count, 0, &addr.gen, &addr_size);
	if (*(args->ret_size) < 0)
		*(args->ret_size) = gnet_map_errno(-errno);
	else {
		assert(addr.in.sin_family == AF_INET);
		if (args->ret_node) {
			/* s_addr is in network byte order; convert to host
			 * order so that the shifts below pick out the octets
			 * in address order (octets[0] first). */
			uint32_t host_addr = ntohl(addr.in.sin_addr.s_addr);
			args->ret_node->octets[0] = (host_addr >> 24) & 0xFF;
			args->ret_node->octets[1] = (host_addr >> 16) & 0xFF;
			args->ret_node->octets[2] = (host_addr >>  8) & 0xFF;
			args->ret_node->octets[3] = (host_addr >>  0) & 0xFF;
		}
		if (args->ret_port)
			(*args->ret_port) = ntohs(addr.in.sin_port);
	}
	WAKE_COROUTINE(args);
	return NULL;
}

static ssize_t gnet_udp_recvfrom(struct net_packet_conn *_conn, void *buf, size_t count,
                                 struct net_ip4_addr *ret_node, uint16_t *ret_port) {
	struct gnet_udp_conn *conn =
		VCALL_SELF(struct gnet_udp_conn, implements_net_packet_conn, _conn);
	assert(conn);

	ssize_t ret;
	struct gnet_pthread_recvfrom_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),

		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,

		.ret_size     = &ret,
		.ret_node     = ret_node,
		.ret_port     = ret_port,
	};
	if (RUN_PTHREAD(gnet_pthread_recvfrom, &args))
		return -NET_ETHREAD;
	return ret;
}

/* UDP close() ****************************************************************/

/* Close the UDP socket's file descriptor.
 *
 * Returns 0 on success, otherwise the gnet_map_errno() mapping of the close()
 * failure. */
static int gnet_udp_close(struct net_packet_conn *_conn) {
	struct gnet_udp_conn *conn =
		VCALL_SELF(struct gnet_udp_conn, implements_net_packet_conn, _conn);
	assert(conn);

	ssize_t rc = 0;
	if (close(conn->fd))
		rc = -errno;
	return gnet_map_errno(rc);
}