/* $OpenBSD: tls13_handshake_msg.c,v 1.2 2019/11/20 16:21:20 beck Exp $ */ /* * Copyright (c) 2018, 2019 Joel Sing <jsing@openbsd.org> * * Permission to use, copy, modify, and distribute this software for any * purpose with or without fee is hereby granted, provided that the above * copyright notice and this permission notice appear in all copies. * * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ #include "bytestring.h" #include "ssl_locl.h" #include "tls13_internal.h" #define TLS13_HANDSHAKE_MSG_HEADER_LEN 4 #define TLS13_HANDSHAKE_MSG_INITIAL_LEN 256 #define TLS13_HANDSHAKE_MSG_MAX_LEN (256 * 1024) struct tls13_handshake_msg { uint8_t msg_type; uint32_t msg_len; uint8_t *data; size_t data_len; struct tls13_buffer *buf; CBS cbs; CBB cbb; }; struct tls13_handshake_msg * tls13_handshake_msg_new() { struct tls13_handshake_msg *msg = NULL; if ((msg = calloc(1, sizeof(struct tls13_handshake_msg))) == NULL) goto err; if ((msg->buf = tls13_buffer_new(0)) == NULL) goto err; return msg; err: tls13_handshake_msg_free(msg); return NULL; } void tls13_handshake_msg_free(struct tls13_handshake_msg *msg) { if (msg == NULL) return; tls13_buffer_free(msg->buf); CBB_cleanup(&msg->cbb); freezero(msg->data, msg->data_len); freezero(msg, sizeof(struct tls13_handshake_msg)); } void tls13_handshake_msg_data(struct tls13_handshake_msg *msg, CBS *cbs) { CBS_init(cbs, msg->data, msg->data_len); } int tls13_handshake_msg_set_buffer(struct tls13_handshake_msg *msg, CBS *cbs) { return tls13_buffer_set_data(msg->buf, cbs); } uint8_t tls13_handshake_msg_type(struct tls13_handshake_msg *msg) { return msg->msg_type; } int tls13_handshake_msg_content(struct tls13_handshake_msg *msg, CBS *cbs) { tls13_handshake_msg_data(msg, cbs); return CBS_skip(cbs, TLS13_HANDSHAKE_MSG_HEADER_LEN); } int tls13_handshake_msg_start(struct tls13_handshake_msg *msg, CBB *body, uint8_t msg_type) { if (!CBB_init(&msg->cbb, TLS13_HANDSHAKE_MSG_INITIAL_LEN)) return 0; if (!CBB_add_u8(&msg->cbb, msg_type)) return 0; if (!CBB_add_u24_length_prefixed(&msg->cbb, body)) return 0; return 1; } int tls13_handshake_msg_finish(struct tls13_handshake_msg *msg) { if (!CBB_finish(&msg->cbb, &msg->data, &msg->data_len)) return 0; CBS_init(&msg->cbs, msg->data, msg->data_len); return 1; } static ssize_t tls13_handshake_msg_read_cb(void *buf, size_t n, void *cb_arg) { struct tls13_record_layer *rl = cb_arg; return tls13_read_handshake_data(rl, buf, n); } int tls13_handshake_msg_recv(struct tls13_handshake_msg *msg, struct tls13_record_layer *rl) { uint8_t msg_type; uint32_t msg_len; CBS cbs; int ret; if (msg->data != NULL) return TLS13_IO_FAILURE; if (msg->msg_type == 0) { if ((ret = tls13_buffer_extend(msg->buf, TLS13_HANDSHAKE_MSG_HEADER_LEN, tls13_handshake_msg_read_cb, rl)) <= 0) return ret; tls13_buffer_cbs(msg->buf, &cbs); if (!CBS_get_u8(&cbs, &msg_type)) return TLS13_IO_FAILURE; if (!CBS_get_u24(&cbs, &msg_len)) return TLS13_IO_FAILURE; /* XXX - do we want to make this variable on message type? */ if (msg_len > TLS13_HANDSHAKE_MSG_MAX_LEN) return TLS13_IO_FAILURE; msg->msg_type = msg_type; msg->msg_len = msg_len; } if ((ret = tls13_buffer_extend(msg->buf, TLS13_HANDSHAKE_MSG_HEADER_LEN + msg->msg_len, tls13_handshake_msg_read_cb, rl)) <= 0) return ret; if (!tls13_buffer_finish(msg->buf, &msg->data, &msg->data_len)) return TLS13_IO_FAILURE; return TLS13_IO_SUCCESS; } int tls13_handshake_msg_send(struct tls13_handshake_msg *msg, struct tls13_record_layer *rl) { ssize_t ret; if (msg->data == NULL) return TLS13_IO_FAILURE; if (CBS_len(&msg->cbs) == 0) return TLS13_IO_FAILURE; while (CBS_len(&msg->cbs) > 0) { if ((ret = tls13_write_handshake_data(rl, CBS_data(&msg->cbs), CBS_len(&msg->cbs))) <= 0) return ret; if (!CBS_skip(&msg->cbs, ret)) return TLS13_IO_FAILURE; } return TLS13_IO_SUCCESS; }