/* libhw/host_net.c - <libhw/generic/net.h> implementation for hosted glibc
 *
 * Copyright (C) 2024-2025  Luke T. Shumaker <lukeshu@lukeshu.com>
 * SPDX-License-Identifier: AGPL-3.0-or-later
 */

#define _GNU_SOURCE     /* for pthread_sigqueue(3gnu) */
/* misc */
#include <errno.h>      /* for errno, EAGAIN, EINVAL */
#include <error.h>      /* for error(3gnu) */
#include <stdlib.h>     /* for abs(), shutdown(), SHUT_RD, SHUT_WR, SHUT_RDWR */
#include <unistd.h>     /* for read(), write() */
/* net */
#include <arpa/inet.h>  /* for htons(3p) */
#include <netinet/in.h> /* for struct sockaddr_in */
#include <sys/socket.h> /* for struct sockaddr{,_storage}, SOCK_*, SOL_*, SO_*, socket(), setsockopt(), bind(), listen(), accept() */
/* async */
#include <pthread.h>    /* for pthread_* */
#include <signal.h>     /* for siginfo_t, struct sigaction, enum sigval, sigaction(), SA_SIGINFO */

#include <libcr/coroutine.h>
#include <libmisc/assert.h>
#include <libmisc/macro.h>
#include <libmisc/vcall.h>

#include <libhw/generic/alarmclock.h>

#define IMPLEMENTATION_FOR_LIBHW_HOST_NET_H YES
#include <libhw/host_net.h>

#include "host_util.h" /* for host_sigrt_alloc(), ns_to_host_us_time() */

/* common *********************************************************************/

/* Realtime signal that helper pthreads use to wake the coroutine waiting on
 * them; allocated by hostnet_init() (0 means "not yet allocated"). */
static int hostnet_sig_io;

/* SA_SIGINFO handler for hostnet_sig_io.  The sender (WAKE_COROUTINE) puts
 * the waiting coroutine's cid in si_value.sival_int; un-pause it. */
static void hostnet_handle_sig_io(int LM_UNUSED(sig), siginfo_t *info, void *LM_UNUSED(ucontext)) {
	const cid_t waiter = (cid_t)info->si_value.sival_int;
	cr_unpause_from_intrhandler(waiter);
}

/* One-time setup shared by the TCP and UDP constructors: allocate a realtime
 * signal for pthread->coroutine wakeups and install hostnet_handle_sig_io()
 * as its SA_SIGINFO handler.  Calls after the first are no-ops.  A failing
 * sigaction() is fatal (error(3) exits with status 1). */
static void hostnet_init(void) {
	if (hostnet_sig_io != 0)
		return;

	hostnet_sig_io = host_sigrt_alloc();

	struct sigaction action = {0};
	action.sa_sigaction = hostnet_handle_sig_io;
	action.sa_flags     = SA_SIGINFO;
	if (sigaction(hostnet_sig_io, &action, NULL) < 0)
		error(1, errno, "sigaction");
}

/* WAKE_COROUTINE(args): called from a helper pthread once its work is done.
 * `args` must point to a struct with `cr_thread` and `cr_coroutine` members.
 * Queues hostnet_sig_io to args->cr_thread with args->cr_coroutine as the
 * signal payload, which hostnet_handle_sig_io() uses to un-pause that
 * coroutine.  EAGAIN from pthread_sigqueue() (queued-signal limit reached)
 * is retried; any other error fails the assert. */
#define WAKE_COROUTINE(args) do {                                        \
		union sigval wake_val = {0};                             \
		int wake_err;                                            \
		wake_val.sival_int = (int)((args)->cr_coroutine);        \
		for (;;) {                                               \
			wake_err = pthread_sigqueue((args)->cr_thread,   \
			                            hostnet_sig_io,      \
			                            wake_val);           \
			assert(wake_err == 0 || wake_err == EAGAIN);     \
			if (wake_err != EAGAIN)                          \
				break;                                   \
		}                                                        \
	} while (0)

/* Run `fn(args)` on a fresh pthread and pause the calling coroutine until
 * that pthread wakes it (via WAKE_COROUTINE), then join the pthread.
 *
 * Interrupts are disabled before the pthread is created and only restored
 * after cr_pause_and_yield() returns -- presumably so the wakeup signal
 * cannot be handled before this coroutine has paused.
 *
 * Returns false on success, true if the pthread could not be created or
 * joined.  Interrupts are restored to their saved state on every path. */
static inline bool RUN_PTHREAD(void *(*fn)(void *), void *args) {
	pthread_t thread;
	bool saved = cr_save_and_disable_interrupts();
	if (pthread_create(&thread, NULL, fn, args)) {
		/* Previously this returned with interrupts still disabled. */
		cr_restore_interrupts(saved);
		return true;
	}
	cr_pause_and_yield();
	cr_restore_interrupts(saved);
	if (pthread_join(thread, NULL))
		return true;
	return false;
}

/* Which kind of operation a failure came from; hostnet_map_negerrno() uses
 * this to pick the matching NET_E* timeout code. */
enum hostnet_timeout_op {
	OP_NONE = 0, /* neither send nor receive (close, shutdown) */
	OP_SEND = 1,
	OP_RECV = 2,
};

/* Translate a host result into this library's error space.
 *
 * @param v   a non-negative success value (returned unchanged), or a
 *            negated errno value
 * @param op  the kind of operation that produced `v`, used to choose the
 *            right NET_E* timeout code
 * @return    `v` if non-negative, otherwise a negative NET_E* code */
static inline ssize_t hostnet_map_negerrno(ssize_t v, enum hostnet_timeout_op op) {
	if (v >= 0)
		return v;
	switch (v) {
	case -EHOSTUNREACH:
		return -NET_EARP_TIMEOUT;
	case -ETIMEDOUT:
		switch (op) {
		case OP_NONE:
			assert_notreached("impossible ETIMEDOUT");
		case OP_SEND:
			return -NET_EACK_TIMEOUT;
		case OP_RECV:
			return -NET_ERECV_TIMEOUT;
		}
		assert_notreached("invalid timeout op");
	case -EAGAIN:
#if EWOULDBLOCK != EAGAIN
	case -EWOULDBLOCK:
#endif
		/* An expired SO_RCVTIMEO is reported as EAGAIN/EWOULDBLOCK, not
		 * ETIMEDOUT (socket(7)).  Only receive timeouts are configured
		 * in this file, so only OP_RECV maps to a timeout code. */
		return (op == OP_RECV) ? -NET_ERECV_TIMEOUT : -NET_EOTHER;
	case -EBADF:
		return -NET_ECLOSED;
	case -EMSGSIZE:
		return -NET_EMSGSIZE;
	default:
		return -NET_EOTHER;
	}
}

/* TCP init() ( AKA socket(3) + listen(3) )************************************/

/* Forward declarations of the TCP method implementations so that the
 * vtables below can refer to them. */
static implements_net_stream_conn *hostnet_tcplist_accept(implements_net_stream_listener *listener);
static int hostnet_tcplist_close(implements_net_stream_listener *listener);

static void hostnet_tcp_set_read_deadline(implements_net_stream_conn *c, uint64_t ts_ns);
static ssize_t hostnet_tcp_read(implements_net_stream_conn *c, void *buf, size_t count);
static ssize_t hostnet_tcp_write(implements_net_stream_conn *c, void *buf, size_t count);
static int hostnet_tcp_close(implements_net_stream_conn *c, bool rd, bool wr);

/* Installed by hostnet_tcp_listener_init(). */
static struct net_stream_listener_vtable hostnet_tcp_listener_vtable = {
	.close  = hostnet_tcplist_close,
	.accept = hostnet_tcplist_accept,
};

/* Installed on the connection returned by hostnet_tcplist_accept(). */
static struct net_stream_conn_vtable hostnet_tcp_conn_vtable = {
	.close             = hostnet_tcp_close,
	.write             = hostnet_tcp_write,
	.read              = hostnet_tcp_read,
	.set_read_deadline = hostnet_tcp_set_read_deadline,
};

/* Initialize `self` as a TCP listener on `port`, bound to every IPv4
 * interface (the address is left zeroed, i.e. INADDR_ANY), with
 * SO_REUSEADDR and SO_REUSEPORT enabled.  Every failure is fatal: error(3)
 * reports it and exits with status 1. */
void hostnet_tcp_listener_init(struct hostnet_tcp_listener *self, uint16_t port) {
	const int enable = 1;
	union {
		struct sockaddr_in in;
		struct sockaddr gen;
	} addr = { 0 };

	hostnet_init();

	addr.in.sin_family = AF_INET;
	addr.in.sin_port   = htons(port);

	int fd = socket(AF_INET, SOCK_STREAM, 0);
	if (fd < 0)
		error(1, errno, "socket");
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &enable, sizeof enable) < 0)
		error(1, errno, "setsockopt(fd=%d, SO_REUSEADDR=1)", fd);
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &enable, sizeof enable) < 0)
		error(1, errno, "setsockopt(fd=%d, SO_REUSEPORT=1)", fd);
	if (bind(fd, &addr.gen, sizeof addr) < 0)
		error(1, errno, "bind(fd=%d)", fd);
	/* NOTE(review): listen backlog is 0; confirm that is intended. */
	if (listen(fd, 0) < 0)
		error(1, errno, "listen(fd=%d)", fd);

	self->fd     = fd;
	self->vtable = &hostnet_tcp_listener_vtable;
}

/* TCP listener accept() ******************************************************/

/* Hand-off between hostnet_tcplist_accept() and its helper pthread. */
struct hostnet_pthread_accept_args {
	pthread_t        cr_thread;     /* thread to signal when done */
	cid_t            cr_coroutine;  /* coroutine to wake when done */

	int              listenerfd;

	int             *ret_connfd;    /* out: new fd, or -errno on failure */
};

/* Helper-pthread body: block in accept(2), publish the new fd (or -errno),
 * then wake the waiting coroutine. */
static void *hostnet_pthread_accept(void *_args) {
	struct hostnet_pthread_accept_args *args = _args;
	int fd = accept(args->listenerfd, NULL, NULL);
	*(args->ret_connfd) = (fd < 0) ? -errno : fd;
	WAKE_COROUTINE(args);
	return NULL;
}

/* Wait for a client to connect, doing the blocking accept(2) on a helper
 * pthread so that only this coroutine waits.  On success, set up the
 * listener's active_conn slot for the new socket (no read deadline) and
 * return it.  Returns NULL if the pthread could not be run or accept(2)
 * failed.
 *
 * NOTE(review): every accepted connection reuses listener->active_conn, so a
 * previously returned connection is overwritten by the next accept. */
static implements_net_stream_conn *hostnet_tcplist_accept(implements_net_stream_listener *_listener) {
	struct hostnet_tcp_listener *listener =
		VCALL_SELF(struct hostnet_tcp_listener, implements_net_stream_listener, _listener);
	assert(listener);

	int ret_connfd;
	struct hostnet_pthread_accept_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.listenerfd   = listener->fd,
		.ret_connfd   = &ret_connfd,
	};
	if (RUN_PTHREAD(hostnet_pthread_accept, &args))
		return NULL;
	/* accept(2) failed: ret_connfd holds -errno.  Report the failure
	 * instead of handing out a connection with an invalid fd. */
	if (ret_connfd < 0)
		return NULL;

	listener->active_conn.vtable = &hostnet_tcp_conn_vtable;
	listener->active_conn.fd = ret_connfd;
	listener->active_conn.read_deadline_ns = 0;
	return &listener->active_conn;
}

/* TCP listener close() *******************************************************/

/* Close the listening socket.  Returns 0, or a negative NET_E* code if
 * close(2) failed. */
static int hostnet_tcplist_close(implements_net_stream_listener *_listener) {
	struct hostnet_tcp_listener *self =
		VCALL_SELF(struct hostnet_tcp_listener, implements_net_stream_listener, _listener);
	assert(self);

	int rc = close(self->fd);
	if (rc != 0)
		rc = -errno;
	return hostnet_map_negerrno(rc, OP_NONE);
}

/* TCP read() *****************************************************************/

/* Record an absolute read deadline, compared against bootclock nanoseconds
 * by hostnet_tcp_read(); 0 means "no deadline". */
static void hostnet_tcp_set_read_deadline(implements_net_stream_conn *_conn, uint64_t ts_ns) {
	struct _hostnet_tcp_conn *self =
		VCALL_SELF(struct _hostnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(self);
	self->read_deadline_ns = ts_ns;
}

struct hostnet_pthread_read_args {
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	int              connfd;
	struct timeval   timeout;
	void            *buf;
	size_t           count;

	ssize_t         *ret;
};

static void *hostnet_pthread_read(void *_args) {
	struct hostnet_pthread_read_args *args = _args;

	*(args->ret) = setsockopt(args->connfd, SOL_SOCKET, SO_RCVTIMEO,
	                          &args->timeout, sizeof(args->timeout));
	if (*(args->ret) < 0)
		goto end;

	*(args->ret) = read(args->connfd, args->buf, args->count);
	if (*(args->ret) < 0)
		goto end;

 end:
	if (*(args->ret) < 0)
		*(args->ret) = hostnet_map_negerrno(-errno, OP_SEND);
	WAKE_COROUTINE(args);
	return NULL;
}

/* Read up to `count` bytes from the connection into `buf`, doing the
 * blocking read(2) on a helper pthread so that only this coroutine waits.
 *
 * A non-zero read_deadline_ns (absolute, in bootclock ns) is converted to a
 * relative SO_RCVTIMEO.  A deadline that has already arrived returns
 * -NET_ERECV_TIMEOUT without reading.
 *
 * Returns the byte count, a negative NET_E* code from the pthread, or
 * -NET_ETHREAD if the pthread could not be run. */
static ssize_t hostnet_tcp_read(implements_net_stream_conn *_conn, void *buf, size_t count) {
	struct _hostnet_tcp_conn *conn =
		VCALL_SELF(struct _hostnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(conn);

	ssize_t ret;
	struct hostnet_pthread_read_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),

		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,

		.ret          = &ret,
	};
	if (conn->read_deadline_ns) {
		uint64_t now_ns = VCALL(bootclock, get_time_ns);
		/* "<=": with exactly 0 ns left the timeout below would be
		 * {0,0}, which SO_RCVTIMEO treats as "never time out". */
		if (conn->read_deadline_ns <= now_ns)
			return -NET_ERECV_TIMEOUT;
		args.timeout = ns_to_host_us_time(conn->read_deadline_ns-now_ns);
		/* Likewise, in case a sub-microsecond remainder converts to a
		 * zero timeval, don't let it become an unbounded wait. */
		if (args.timeout.tv_sec == 0 && args.timeout.tv_usec == 0)
			args.timeout.tv_usec = 1;
	} else {
		/* No deadline: a zero SO_RCVTIMEO blocks indefinitely. */
		args.timeout = (host_us_time_t){0};
	}

	if (RUN_PTHREAD(hostnet_pthread_read, &args))
		return -NET_ETHREAD;
	return ret;
}

/* TCP write() ****************************************************************/

struct hostnet_pthread_write_args {
	pthread_t        cr_thread;
	cid_t            cr_coroutine;

	int              connfd;
	void            *buf;
	size_t           count;

	ssize_t         *ret;
};

static void *hostnet_pthread_write(void *_args) {
	struct hostnet_pthread_write_args *args = _args;
	size_t done = 0;
	while (done < args->count) {
		ssize_t r = write(args->connfd, args->buf, args->count);
		if (r < 0) {
			hostnet_map_negerrno(-errno, OP_RECV);
			break;
		}
		done += r;
	}
	if (done == args->count)
		*(args->ret) = done;
	WAKE_COROUTINE(args);
	return NULL;
}

/* Write `count` bytes of `buf` to the connection.  The work is handed to
 * hostnet_pthread_write() on a helper pthread, so only this coroutine
 * waits.  Returns whatever that pthread stored, or -NET_ETHREAD if the
 * pthread could not be run. */
static ssize_t hostnet_tcp_write(implements_net_stream_conn *_conn, void *buf, size_t count) {
	struct _hostnet_tcp_conn *self =
		VCALL_SELF(struct _hostnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(self);

	ssize_t result;
	struct hostnet_pthread_write_args job = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.connfd       = self->fd,
		.buf          = buf,
		.count        = count,
		.ret          = &result,
	};

	if (!RUN_PTHREAD(hostnet_pthread_write, &job))
		return result;
	return -NET_ETHREAD;
}

/* TCP close() ****************************************************************/

/* Shut down the read side, the write side, or both, of the connection with
 * shutdown(2).  Passing neither rd nor wr is a programming error.  Returns
 * 0, or a negative NET_E* code.
 *
 * NOTE(review): conn->fd is never passed to close(2) here, and
 * hostnet_tcplist_accept() overwrites the single active_conn slot on the
 * next accept, so each accepted descriptor appears to leak.  Confirm the
 * intended contract before adding a close(2). */
static int hostnet_tcp_close(implements_net_stream_conn *_conn, bool rd, bool wr) {
	struct _hostnet_tcp_conn *self =
		VCALL_SELF(struct _hostnet_tcp_conn, implements_net_stream_conn, _conn);
	assert(self);

	if (!rd && !wr)
		assert_notreached("invalid arguments to stream_conn.close()");

	int how = rd ? (wr ? SHUT_RDWR : SHUT_RD) : SHUT_WR;
	int rc = shutdown(self->fd, how);
	if (rc != 0)
		rc = -errno;
	return hostnet_map_negerrno(rc, OP_NONE);
}

/* UDP init() *****************************************************************/

/* Forward declarations of the UDP method implementations so that the
 * vtable below can refer to them. */
static void hostnet_udp_set_read_deadline(implements_net_packet_conn *c, uint64_t ts_ns);
static ssize_t hostnet_udp_sendto(implements_net_packet_conn *c, void *buf, size_t len,
                                  struct net_ip4_addr addr, uint16_t port);
static ssize_t hostnet_udp_recvfrom(implements_net_packet_conn *c, void *buf, size_t len,
                                    struct net_ip4_addr *ret_addr, uint16_t *ret_port);
static int hostnet_udp_close(implements_net_packet_conn *c);

/* Installed by hostnet_udp_conn_init(). */
static struct net_packet_conn_vtable hostnet_udp_conn_vtable = {
	.close             = hostnet_udp_close,
	.recvfrom          = hostnet_udp_recvfrom,
	.sendto            = hostnet_udp_sendto,
	.set_read_deadline = hostnet_udp_set_read_deadline,
};

/* Initialize `self` as a UDP socket bound to `port` on every IPv4 interface
 * (the address is left zeroed, i.e. INADDR_ANY), with no read deadline.
 * Every failure is fatal: error(3) reports it and exits with status 1. */
void hostnet_udp_conn_init(struct hostnet_udp_conn *self, uint16_t port) {
	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
		struct sockaddr_storage stor;
	} addr = { 0 };

	hostnet_init();

	int sock = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock < 0)
		error(1, errno, "socket");

	addr.in.sin_family = AF_INET;
	addr.in.sin_port   = htons(port);
	if (bind(sock, &addr.gen, sizeof addr) < 0)
		error(1, errno, "bind");

	self->read_deadline_ns = 0;
	self->fd               = sock;
	self->vtable           = &hostnet_udp_conn_vtable;
}

/* UDP sendto() ***************************************************************/

/* Hand-off between hostnet_udp_sendto() and its helper pthread. */
struct hostnet_pthread_sendto_args {
	pthread_t                cr_thread;
	cid_t                    cr_coroutine;

	int                      connfd;
	void                    *buf;
	size_t                   count;
	struct net_ip4_addr      node;
	uint16_t                 port;

	ssize_t                 *ret;   /* out: bytes sent, or negative NET_E* */
};

/* Helper-pthread body for hostnet_udp_sendto(): send one datagram to
 * node:port, store the sendto(2) result (or a negative NET_E* code) in
 * *args->ret, then wake the waiting coroutine. */
static void *hostnet_pthread_sendto(void *_args) {
	struct hostnet_pthread_sendto_args *args = _args;
	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
	} addr = { 0 };

	addr.in.sin_family = AF_INET;
	/* octets[0] is treated as the most significant byte.  sin_addr.s_addr
	 * must be in network byte order, so convert the host-order value;
	 * without htonl() a little-endian host sends to the reversed address. */
	addr.in.sin_addr.s_addr = htonl(
		(((uint32_t)args->node.octets[0])<<24) |
		(((uint32_t)args->node.octets[1])<<16) |
		(((uint32_t)args->node.octets[2])<< 8) |
		(((uint32_t)args->node.octets[3])<< 0) );
	addr.in.sin_port = htons(args->port);
	*(args->ret) = sendto(args->connfd, args->buf, args->count, 0, &addr.gen, sizeof(addr.in));
	if (*(args->ret) < 0)
		*(args->ret) = hostnet_map_negerrno(-errno, OP_SEND);
	WAKE_COROUTINE(args);
	return NULL;
}

/* Send one `count`-byte datagram from `buf` to node:port.  The sendto(2) is
 * done by hostnet_pthread_sendto() on a helper pthread, so only this
 * coroutine waits.  Returns that pthread's result, or -NET_ETHREAD if the
 * pthread could not be run. */
static ssize_t hostnet_udp_sendto(implements_net_packet_conn *_conn, void *buf, size_t count,
                                  struct net_ip4_addr node, uint16_t port) {
	struct hostnet_udp_conn *self =
		VCALL_SELF(struct hostnet_udp_conn, implements_net_packet_conn, _conn);
	assert(self);

	ssize_t result;
	struct hostnet_pthread_sendto_args job = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),
		.connfd       = self->fd,
		.buf          = buf,
		.count        = count,
		.node         = node,
		.port         = port,
		.ret          = &result,
	};

	if (!RUN_PTHREAD(hostnet_pthread_sendto, &job))
		return result;
	return -NET_ETHREAD;
}

/* UDP recvfrom() *************************************************************/

/* Record an absolute read deadline, compared against bootclock nanoseconds
 * by hostnet_udp_recvfrom(); 0 means "no deadline". */
static void hostnet_udp_set_read_deadline(implements_net_packet_conn *_conn,
                                          uint64_t ts_ns) {
	struct hostnet_udp_conn *self =
		VCALL_SELF(struct hostnet_udp_conn, implements_net_packet_conn, _conn);
	assert(self);
	self->read_deadline_ns = ts_ns;
}

/* Hand-off between hostnet_udp_recvfrom() and its helper pthread. */
struct hostnet_pthread_recvfrom_args {
	pthread_t                cr_thread;
	cid_t                    cr_coroutine;

	int                      connfd;
	struct timeval           timeout;   /* SO_RCVTIMEO; {0,0} = none */
	void                    *buf;
	size_t                   count;

	ssize_t                 *ret_size;  /* out: datagram length, or negative NET_E* */
	struct net_ip4_addr     *ret_node;  /* out (optional): sender address */
	uint16_t                *ret_port;  /* out (optional): sender port */
};

/* Helper-pthread body for hostnet_udp_recvfrom().  It applies the receive
 * timeout and receives one datagram.  Because of MSG_TRUNC, the stored size
 * is the datagram's real length, which may exceed args->count (recv(2)).
 * It fills in the optional sender address/port, stores the size (or a
 * negative NET_E* code) in *args->ret_size, and wakes the waiting
 * coroutine. */
static void *hostnet_pthread_recvfrom(void *_args) {
	struct hostnet_pthread_recvfrom_args *args = _args;

	union {
		struct sockaddr_in      in;
		struct sockaddr         gen;
		struct sockaddr_storage stor;
	} addr = { 0 };
	/* Value-result argument: recvfrom(2) reads it as the size of `addr`,
	 * so it must be initialized. */
	socklen_t addr_size = sizeof(addr);

	*(args->ret_size) = setsockopt(args->connfd, SOL_SOCKET, SO_RCVTIMEO,
	                               &args->timeout, sizeof(args->timeout));
	if (*(args->ret_size) < 0)
		goto end;

	*(args->ret_size) = recvfrom(args->connfd, args->buf, args->count,
	                             MSG_TRUNC, &addr.gen, &addr_size);
	if (*(args->ret_size) < 0)
		goto end;

	assert(addr.in.sin_family == AF_INET);
	if (args->ret_node) {
		/* s_addr is in network byte order; convert it to host order
		 * before splitting it so that octets[0] is the most
		 * significant byte. */
		uint32_t host_addr = ntohl(addr.in.sin_addr.s_addr);
		args->ret_node->octets[0] = (host_addr >> 24) & 0xFF;
		args->ret_node->octets[1] = (host_addr >> 16) & 0xFF;
		args->ret_node->octets[2] = (host_addr >>  8) & 0xFF;
		args->ret_node->octets[3] = (host_addr >>  0) & 0xFF;
	}
	if (args->ret_port) {
		(*args->ret_port) = ntohs(addr.in.sin_port);
	}

 end:
	if (*(args->ret_size) < 0)
		*(args->ret_size) = hostnet_map_negerrno(-errno, OP_RECV);
	WAKE_COROUTINE(args);
	return NULL;
}

/* Receive one datagram into `buf` (up to `count` bytes), doing the blocking
 * recvfrom(2) on a helper pthread so that only this coroutine waits.  The
 * sender's address and port are stored through `ret_node` and `ret_port`
 * when those are non-NULL.
 *
 * A non-zero read_deadline_ns (absolute, in bootclock ns) is converted to a
 * relative SO_RCVTIMEO.  A deadline that has already arrived returns
 * -NET_ERECV_TIMEOUT without receiving.
 *
 * Returns the pthread's result (datagram length or negative NET_E* code),
 * or -NET_ETHREAD if the pthread could not be run. */
static ssize_t hostnet_udp_recvfrom(implements_net_packet_conn *_conn, void *buf, size_t count,
                                    struct net_ip4_addr *ret_node, uint16_t *ret_port) {
	struct hostnet_udp_conn *conn =
		VCALL_SELF(struct hostnet_udp_conn, implements_net_packet_conn, _conn);
	assert(conn);

	ssize_t ret;
	struct hostnet_pthread_recvfrom_args args = {
		.cr_thread    = pthread_self(),
		.cr_coroutine = cr_getcid(),

		.connfd       = conn->fd,
		.buf          = buf,
		.count        = count,

		.ret_size     = &ret,
		.ret_node     = ret_node,
		.ret_port     = ret_port,
	};
	if (conn->read_deadline_ns) {
		uint64_t now_ns = VCALL(bootclock, get_time_ns);
		/* "<=": with exactly 0 ns left the timeout below would be
		 * {0,0}, which SO_RCVTIMEO treats as "never time out". */
		if (conn->read_deadline_ns <= now_ns)
			return -NET_ERECV_TIMEOUT;
		args.timeout = ns_to_host_us_time(conn->read_deadline_ns-now_ns);
		/* Likewise, in case a sub-microsecond remainder converts to a
		 * zero timeval, don't let it become an unbounded wait. */
		if (args.timeout.tv_sec == 0 && args.timeout.tv_usec == 0)
			args.timeout.tv_usec = 1;
	} else {
		/* No deadline: a zero SO_RCVTIMEO blocks indefinitely. */
		args.timeout = (host_us_time_t){0};
	}

	if (RUN_PTHREAD(hostnet_pthread_recvfrom, &args))
		return -NET_ETHREAD;
	return ret;
}

/* UDP close() ****************************************************************/

/* Close the UDP socket.  Returns 0, or a negative NET_E* code if close(2)
 * failed. */
static int hostnet_udp_close(implements_net_packet_conn *_conn) {
	struct hostnet_udp_conn *self =
		VCALL_SELF(struct hostnet_udp_conn, implements_net_packet_conn, _conn);
	assert(self);

	int rc = close(self->fd);
	if (rc != 0)
		rc = -errno;
	return hostnet_map_negerrno(rc, OP_NONE);
}