/* srv9p/gnet.c - libmisc/net.h implementation for libcr + GNU libc
 *
 * Copyright (C) 2024  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#define _GNU_SOURCE     /* for pthread_sigqueue(3gnu) */
/* misc */
#include <assert.h>     /* for assert() */
#include <errno.h>      /* for errno, EAGAIN, EINVAL */
#include <error.h>      /* for error(3gnu) */
#include <stdlib.h>     /* for abs(), shutdown(), SHUT_RD, SHUT_WR, SHUT_RDWR */
#include <unistd.h>     /* for read(), write() */
/* net */
#include <arpa/inet.h>  /* for htons(3p) */
#include <netinet/in.h> /* for struct sockaddr_in */
#include <sys/socket.h> /* for struct sockaddr, socket(), SOCK_* flags, setsockopt(), SOL_SOCKET, SO_REUSEADDR, bind(), listen(), accept() */
/* async */
#include <pthread.h>    /* for pthread_* */
#include <signal.h>     /* for siginfo_t, struct sigaction, enum sigval, sigaction(), SIGRTMIN, SIGRTMAX, SA_SIGINFO */

#include <libcr/coroutine.h>
#include <libmisc/vcall.h>

#include "gnet.h"

/* common *********************************************************************/

/* Expands to nothing, so a parameter written as `int UNUSED(x)` is left
 * unnamed, which tells the compiler it is intentionally unused.  (Unnamed
 * parameters in a function definition need C23, or a compiler that accepts
 * them as an extension; `name __attribute__ ((unused))` is the pre-C23
 * alternative.) */
#define UNUSED(name)

/* Real-time signal used by helper pthreads to wake paused coroutines;
 * 0 means gnet_init() has not picked one yet. */
static int gnet_sig_io;

/* SA_SIGINFO handler for gnet_sig_io: the sender (WAKE_COROUTINE) packs the
 * coroutine id into the signal's integer payload; unpause that coroutine. */
static void gnet_handle_sig_io(int UNUSED(sig), siginfo_t *info, void *UNUSED(ucontext)) {
	cid_t target = (cid_t)info->si_value.sival_int;
	cr_unpause_from_intrhandler(target);
}

/* One-time setup: pick the real-time signal used to wake coroutines and
 * install gnet_handle_sig_io() as its SA_SIGINFO handler.  Subsequent calls
 * are no-ops.  Any failure terminates the process via error(3). */
static void gnet_init(void) {
	struct sigaction action = {0};

	if (gnet_sig_io)
		return;

	gnet_sig_io = SIGRTMIN;
	if (gnet_sig_io > SIGRTMAX)
		error(1, 0, "SIGRTMAX exceeded");

	/* POSIX requires a sigset_t to be initialized with sigemptyset() (or
	 * sigfillset()) before use; a zero-filled struct is not guaranteed to
	 * be an empty set. */
	if (sigemptyset(&action.sa_mask) < 0)
		error(1, errno, "sigemptyset");
	action.sa_flags = SA_SIGINFO;
	action.sa_sigaction = gnet_handle_sig_io;
	if (sigaction(gnet_sig_io, &action, NULL) < 0)
		error(1, errno, "sigaction");
}

/* Wake the coroutine described by `args` (any of the _pthread_*_args structs
 * in this file) from a helper pthread: queue gnet_sig_io to args->cr_thread
 * with args->cr_coroutine as the signal payload, which gnet_handle_sig_io()
 * hands to cr_unpause_from_intrhandler().
 *
 * EAGAIN (signal queue full) is retried.  Any other pthread_sigqueue() error
 * would leave the coroutine paused forever, so it is fatal -- also when
 * NDEBUG compiles assert() away. */
#define WAKE_COROUTINE(args) do {                                              \
		int r;                                                         \
		union sigval val = {0};                                        \
		val.sival_int = (int)((args)->cr_coroutine);                   \
		do {                                                           \
			r = pthread_sigqueue((args)->cr_thread, gnet_sig_io, val); \
		} while (r == EAGAIN);                                         \
		if (r != 0)                                                    \
			error(1, r, "pthread_sigqueue");                       \
	} while (0)

/* Run fn(args) on a fresh pthread, pause the calling coroutine until it is
 * woken (every worker in this file ends with WAKE_COROUTINE(args)), then
 * join the thread.
 *
 * Returns true on failure (thread could not be created or joined), false on
 * success.
 *
 * NOTE(review): if the worker's wake-up signal is delivered before
 * cr_pause_and_yield() is reached, correctness depends on how libcr treats
 * unpausing a coroutine that is still running -- confirm. */
static inline bool RUN_PTHREAD(void *(*fn)(void *), void *args) {
	pthread_t worker;

	if (pthread_create(&worker, NULL, fn, args) != 0)
		return true;
	cr_pause_and_yield();
	return pthread_join(worker, NULL) != 0;
}

/* init() ( AKA socket(3) + listen(3) )****************************************/

/* Method implementations, defined further down. */
static implements_net_conn *gnet_accept(implements_net_listener *_listener);
static ssize_t              gnet_read(implements_net_conn *_conn, void *buf, size_t count);
static ssize_t              gnet_write(implements_net_conn *_conn, void *buf, size_t count);
static int                  gnet_close(implements_net_conn *_conn, bool rd, bool wr);

/* Dispatch table installed into each listener by gnet_listener_init(). */
static struct net_listener_vtable gnet_listener_vtable = {
	.accept = gnet_accept,
};

/* Dispatch table installed into each accepted connection by gnet_accept(). */
static struct net_conn_vtable gnet_conn_vtable = {
	.close = gnet_close,
	.read  = gnet_read,
	.write = gnet_write,
};

/* Create an IPv4 TCP socket bound to all local addresses on `port` (network
 * byte order is handled here), start listening, and initialize `self` to use
 * it.  Any failure terminates the process via error(3).
 *
 * NOTE(review): the listen backlog is 0, presumably because only one
 * connection (the listener's single active_conn) is served at a time --
 * confirm. */
void gnet_listener_init(struct gnet_listener *self, uint16_t port) {
	union {
		struct sockaddr_in in;
		struct sockaddr gen;
	} addr = {
		.in = {
			.sin_family = AF_INET,
			.sin_port   = htons(port),
			/* sin_addr left zeroed: INADDR_ANY */
		},
	};
	int fd;
	int reuse = 1;

	gnet_init();

	fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0)
		error(1, errno, "socket");
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof reuse) < 0)
		error(1, errno, "setsockopt");
	if (bind(fd, &addr.gen, sizeof addr) < 0)
		error(1, errno, "bind");
	if (listen(fd, 0) < 0)
		error(1, errno, "listen");

	self->vtable = &gnet_listener_vtable;
	self->fd = fd;
}

/* accept() *******************************************************************/

/* Arguments and result slot for _pthread_accept(), which runs accept(2) on a
 * helper pthread so that the calling coroutine can yield while it blocks. */
struct _pthread_accept_args {
	/* Coroutine to wake (via WAKE_COROUTINE) once accept(2) returns. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* Input: the listening socket. */
	int              listenerfd;

	/* Output: the new connection's fd, or -errno on failure. */
	int             *ret_connfd;
};

static void *_pthread_accept(void *_args) {
	struct _pthread_accept_args *a = _args;
	int fd = accept(a->listenerfd, NULL, NULL);

	*a->ret_connfd = (fd < 0) ? -errno : fd;
	WAKE_COROUTINE(a);
	return NULL;
}

/* Accept one connection on the listener without blocking other coroutines.
 *
 * accept(2) runs on a helper pthread (via RUN_PTHREAD) while the calling
 * coroutine pauses.  On success the listener's single embedded active_conn
 * is (re)initialized with the new fd and returned.
 *
 * Returns NULL if the helper thread could not be run, or if accept(2) failed;
 * in the latter case errno is set to accept(2)'s error.
 *
 * NOTE(review): active_conn.fd is overwritten without closing any previous
 * fd, and gnet_close() only shutdown(2)s it -- confirm who releases old fds. */
static implements_net_conn *gnet_accept(implements_net_listener *_listener) {
	struct gnet_listener *listener = VCALL_SELF(struct gnet_listener, implements_net_listener, _listener);
	assert(listener);

	int ret_connfd;
	struct _pthread_accept_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.listenerfd   = listener->fd,
		.ret_connfd   = &ret_connfd,
	};
	if (RUN_PTHREAD(_pthread_accept, &args))
		return NULL;

	/* _pthread_accept() reports failure as -errno; never hand out a
	 * connection wrapping a negative fd. */
	if (ret_connfd < 0) {
		errno = -ret_connfd;
		return NULL;
	}

	listener->active_conn.vtable = &gnet_conn_vtable;
	listener->active_conn.fd = ret_connfd;
	return &listener->active_conn;
}

/* read() *********************************************************************/

/* Arguments and result slot for _pthread_read(), which runs a single read(2)
 * on a helper pthread so that the calling coroutine can yield while it
 * blocks. */
struct _pthread_read_args {
	/* Coroutine to wake (via WAKE_COROUTINE) once read(2) returns. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* Inputs, passed straight through to read(2). */
	int              connfd;
	void            *buf;
	size_t           count;

	/* Output: read(2)'s byte count, or -errno on failure. */
	ssize_t         *ret;
};

static void *_pthread_read(void *_args) {
	struct _pthread_read_args *a = _args;
	ssize_t n = read(a->connfd, a->buf, a->count);

	*a->ret = (n < 0) ? -errno : n;
	WAKE_COROUTINE(a);
	return NULL;
}

/* Read up to `count` bytes from the connection into `buf`, yielding to other
 * coroutines while a helper pthread performs one read(2).
 *
 * Returns read(2)'s result (0 at end-of-stream), -errno if read(2) failed,
 * or -1 if the helper thread could not be created or joined. */
static ssize_t gnet_read(implements_net_conn *_conn, void *buf, size_t count) {
	struct _gnet_conn *conn = VCALL_SELF(struct _gnet_conn, implements_net_conn, _conn);
	assert(conn);

	ssize_t nread;
	struct _pthread_read_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,
		.ret          = &nread,
	};

	return RUN_PTHREAD(_pthread_read, &args) ? -1 : nread;
}

/* write() ********************************************************************/

/* Arguments and result slot for _pthread_write(), which runs write(2) on a
 * helper pthread so that the calling coroutine can yield while it blocks. */
struct _pthread_write_args {
	/* Coroutine to wake (via WAKE_COROUTINE) once writing finishes. */
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	/* Inputs: write all `count` bytes of `buf` to `connfd`. */
	int              connfd;
	void            *buf;
	size_t           count;

	/* Output: `count` on success, or -errno from the failing write(2). */
	ssize_t         *ret;
};

/* Write every byte of args->buf, looping over short writes.  Stores `count`
 * in *args->ret on success, or -errno on the first failed write(2) (bytes
 * already written at that point are not reported), then wakes the waiting
 * coroutine. */
static void *_pthread_write(void *_args) {
	struct _pthread_write_args *args = _args;
	const char *buf = args->buf;
	size_t done = 0;
	ssize_t ret = 0;

	while (done < args->count) {
		/* Resume after whatever a short write already sent; re-sending
		 * from the start of buf would duplicate data and could push
		 * `done` past `count`, leaving *ret unset. */
		ssize_t r = write(args->connfd, buf + done, args->count - done);
		if (r < 0) {
			ret = -errno;
			break;
		}
		done += (size_t)r;
	}
	*(args->ret) = (ret < 0) ? ret : (ssize_t)done;
	WAKE_COROUTINE(args);
	return NULL;
}

/* Write `count` bytes from `buf` to the connection, yielding to other
 * coroutines while a helper pthread performs the write(2) calls.
 *
 * Returns the value _pthread_write() stores (byte count, or -errno), or -1
 * if the helper thread could not be created or joined. */
static ssize_t gnet_write(implements_net_conn *_conn, void *buf, size_t count) {
	struct _gnet_conn *conn = VCALL_SELF(struct _gnet_conn, implements_net_conn, _conn);
	assert(conn);

	ssize_t nwritten;
	struct _pthread_write_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,
		.ret          = &nwritten,
	};

	return RUN_PTHREAD(_pthread_write, &args) ? -1 : nwritten;
}

/* close() ********************************************************************/

/* Shut down the read side (`rd`), the write side (`wr`), or both directions
 * of the connection with shutdown(2).
 *
 * Returns 0 on success, -errno if shutdown(2) fails, or -EINVAL if neither
 * direction was requested.
 *
 * NOTE(review): the fd itself is never close(2)d here, even when both
 * directions are shut down -- confirm where it is released. */
static int gnet_close(implements_net_conn *_conn, bool rd, bool wr) {
	struct _gnet_conn *conn = VCALL_SELF(struct _gnet_conn, implements_net_conn, _conn);
	assert(conn);

	if (!rd && !wr)
		return -EINVAL;

	int how = rd ? (wr ? SHUT_RDWR : SHUT_RD) : SHUT_WR;
	if (shutdown(conn->fd, how) != 0)
		return -errno;
	return 0;
}