/* lib9p/9p.c - Base 9P protocol utilities for both clients and servers * * Copyright (C) 2024 Luke T. Shumaker * SPDX-License-Identifier: AGPL-3.0-or-later */ #include /* for PRIu{n} */ #include /* for va_* */ #include /* for vsnprintf() */ #include /* for strncpy() */ #include #include "internal.h" /* ctx ************************************************************************/ void lib9p_ctx_clear_error(struct lib9p_ctx *ctx) { assert(ctx); #ifdef CONFIG_9P_ENABLE_9P2000_u ctx->err_num = 0; #endif ctx->err_msg[0] = '\0'; } bool lib9p_ctx_has_error(struct lib9p_ctx *ctx) { assert(ctx); return ctx->err_msg[0]; } int lib9p_error(struct lib9p_ctx *ctx, uint32_t linux_errno, char const *msg) { if (lib9p_ctx_has_error(ctx)) return -1; strncpy(ctx->err_msg, msg, sizeof(ctx->err_msg)); ctx->err_msg[sizeof(ctx->err_msg)-1] = '\0'; #ifdef CONFIG_9P_ENABLE_9P2000_u ctx->err_num = linux_errno; #else (void)(linux_errno); #endif return -1; } int lib9p_errorf(struct lib9p_ctx *ctx, uint32_t linux_errno, char const *fmt, ...) { int n; va_list args; if (lib9p_ctx_has_error(ctx)) return -1; va_start(args, fmt); n = vsnprintf(ctx->err_msg, sizeof(ctx->err_msg), fmt, args); va_end(args); if ((size_t)(n+1) < sizeof(ctx->err_msg)) memset(&ctx->err_msg[n+1], 0, sizeof(ctx->err_msg)-(n+1)); #ifdef CONFIG_9P_ENABLE_9P2000_u ctx->err_num = linux_errno; #else (void)(linux_errno); #endif return -1; } const char *lib9p_msg_type_str(struct lib9p_ctx *ctx, enum lib9p_msg_type typ) { assert(0 <= typ && typ <= 0xFF); return _lib9p_versions[ctx->version].msgs[typ].name; } /* main message functions *****************************************************/ ssize_t lib9p_validate(struct lib9p_ctx *ctx, uint8_t *net_bytes) { /* Inspect the first 5 bytes ourselves. */ struct _validate_ctx subctx = { .ctx = ctx, .net_size = decode_u32le(net_bytes), .net_bytes = net_bytes, .net_offset = 0, .host_extra = 0, }; if (subctx.net_size < 5) return lib9p_error(ctx, LINUX_EBADMSG, "message is impossibly short"); uint8_t typ = net_bytes[4]; struct _table_msg table = _lib9p_versions[ctx->version].msgs[typ]; if (!table.validate) return lib9p_errorf(ctx, LINUX_EOPNOTSUPP, "unknown message type: %s (protocol_version=%s)", lib9p_msg_type_str(ctx, typ), lib9p_version_str(ctx->version)); /* Now use the message-type-specific table to process the whole thing. */ if (table.validate(&subctx)) return -1; assert(subctx.net_offset <= subctx.net_size); if (subctx.net_offset < subctx.net_size) return lib9p_errorf(ctx, LINUX_EBADMSG, "message has %"PRIu32" extra bytes", subctx.net_size - subctx.net_offset); /* Return. */ ssize_t ret; if (__builtin_add_overflow(table.basesize, subctx.host_extra, &ret)) return lib9p_error(ctx, LINUX_EMSGSIZE, "unmarshalled payload overflows SSIZE_MAX"); return ret; } void lib9p_unmarshal(struct lib9p_ctx *ctx, uint8_t *net_bytes, enum lib9p_msg_type *ret_typ, void *ret_body) { struct _unmarshal_ctx subctx = { .ctx = ctx, .net_bytes = net_bytes, .net_offset = 0, }; *ret_typ = net_bytes[4]; struct _table_msg table = _lib9p_versions[ctx->version].msgs[*ret_typ]; subctx.extra = ret_body + table.basesize; table.unmarshal(&subctx, ret_body); } bool lib9p_marshal(struct lib9p_ctx *ctx, enum lib9p_msg_type typ, void *body, uint8_t *ret_bytes) { struct _marshal_ctx subctx = { .ctx = ctx, .net_bytes = ret_bytes, .net_offset = 0, }; struct _table_msg table = _lib9p_versions[ctx->version].msgs[typ]; return table.marshal(&subctx, body); } /* `struct lib9p_stat` helpers ************************************************/ bool lib9p_validate_stat(struct lib9p_ctx *ctx, uint32_t net_size, uint8_t *net_bytes, uint32_t *ret_net_size, ssize_t *ret_host_size) { struct _validate_ctx subctx = { .ctx = ctx, .net_size = net_size, .net_bytes = net_bytes, .net_offset = 0, .host_extra = 0, }; if (_lib9p_validate_stat(&subctx)) return true; if (ret_net_size) *ret_net_size = subctx.net_offset; if (ret_host_size) if (__builtin_add_overflow(sizeof(struct lib9p_stat), subctx.host_extra, ret_host_size)) return lib9p_error(ctx, LINUX_EMSGSIZE, "unmarshalled stat object overflows SSIZE_MAX"); return false; } uint32_t lib9p_unmarshal_stat(struct lib9p_ctx *ctx, uint8_t *net_bytes, struct lib9p_stat *ret_obj, void *ret_extra) { struct _unmarshal_ctx subctx = { .ctx = ctx, .net_bytes = net_bytes, .net_offset = 0, .extra = ret_extra, }; _lib9p_unmarshal_stat(&subctx, ret_obj); return subctx.net_offset; } uint32_t lib9p_marshal_stat(struct lib9p_ctx *ctx, uint32_t max_net_size, struct lib9p_stat *obj, uint8_t *ret_bytes) { struct lib9p_ctx _ctx = *ctx; _ctx.max_msg_size = max_net_size; struct _marshal_ctx subctx = { .ctx = &_ctx, .net_bytes = ret_bytes, .net_offset = 0, }; if (_lib9p_marshal_stat(&subctx, obj)) return 0; return subctx.net_offset; }