diff options
Diffstat (limited to 'drivers/nvme/target/tcp.c')
-rw-r--r-- | drivers/nvme/target/tcp.c | 93 |
1 files changed, 80 insertions, 13 deletions
diff --git a/drivers/nvme/target/tcp.c b/drivers/nvme/target/tcp.c index 58a4a737be77..c9bca7d61605 100644 --- a/drivers/nvme/target/tcp.c +++ b/drivers/nvme/target/tcp.c @@ -14,6 +14,7 @@ #include <net/sock.h> #include <net/tcp.h> #include <net/tls.h> +#include <net/tls_prot.h> #include <net/handshake.h> #include <linux/inet.h> #include <linux/llist.h> @@ -118,6 +119,7 @@ struct nvmet_tcp_cmd { u32 pdu_len; u32 pdu_recv; int sg_idx; + char recv_cbuf[CMSG_LEN(sizeof(char))]; struct msghdr recv_msg; struct bio_vec *iov; u32 flags; @@ -1121,20 +1123,65 @@ static inline bool nvmet_tcp_pdu_valid(u8 type) return false; } +static int nvmet_tcp_tls_record_ok(struct nvmet_tcp_queue *queue, + struct msghdr *msg, char *cbuf) +{ + struct cmsghdr *cmsg = (struct cmsghdr *)cbuf; + u8 ctype, level, description; + int ret = 0; + + ctype = tls_get_record_type(queue->sock->sk, cmsg); + switch (ctype) { + case 0: + break; + case TLS_RECORD_TYPE_DATA: + break; + case TLS_RECORD_TYPE_ALERT: + tls_alert_recv(queue->sock->sk, msg, &level, &description); + if (level == TLS_ALERT_LEVEL_FATAL) { + pr_err("queue %d: TLS Alert desc %u\n", + queue->idx, description); + ret = -ENOTCONN; + } else { + pr_warn("queue %d: TLS Alert desc %u\n", + queue->idx, description); + ret = -EAGAIN; + } + break; + default: + /* discard this record type */ + pr_err("queue %d: TLS record %d unhandled\n", + queue->idx, ctype); + ret = -EAGAIN; + break; + } + return ret; +} + static int nvmet_tcp_try_recv_pdu(struct nvmet_tcp_queue *queue) { struct nvme_tcp_hdr *hdr = &queue->pdu.cmd.hdr; - int len; + int len, ret; struct kvec iov; + char cbuf[CMSG_LEN(sizeof(char))] = {}; struct msghdr msg = { .msg_flags = MSG_DONTWAIT }; recv: iov.iov_base = (void *)&queue->pdu + queue->offset; iov.iov_len = queue->left; + if (queue->tls_pskid) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } len = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); if (unlikely(len < 0)) return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + if (ret < 0) + return ret; + } queue->offset += len; queue->left -= len; @@ -1187,16 +1234,22 @@ static void nvmet_tcp_prep_recv_ddgst(struct nvmet_tcp_cmd *cmd) static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue) { struct nvmet_tcp_cmd *cmd = queue->cmd; - int ret; + int len, ret; while (msg_data_left(&cmd->recv_msg)) { - ret = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg, + len = sock_recvmsg(cmd->queue->sock, &cmd->recv_msg, cmd->recv_msg.msg_flags); - if (ret <= 0) - return ret; + if (len <= 0) + return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(cmd->queue, + &cmd->recv_msg, cmd->recv_cbuf); + if (ret < 0) + return ret; + } - cmd->pdu_recv += ret; - cmd->rbytes_done += ret; + cmd->pdu_recv += len; + cmd->rbytes_done += len; } if (queue->data_digest) { @@ -1214,20 +1267,30 @@ static int nvmet_tcp_try_recv_data(struct nvmet_tcp_queue *queue) static int nvmet_tcp_try_recv_ddgst(struct nvmet_tcp_queue *queue) { struct nvmet_tcp_cmd *cmd = queue->cmd; - int ret; + int ret, len; + char cbuf[CMSG_LEN(sizeof(char))] = {}; struct msghdr msg = { .msg_flags = MSG_DONTWAIT }; struct kvec iov = { .iov_base = (void *)&cmd->recv_ddgst + queue->offset, .iov_len = queue->left }; - ret = kernel_recvmsg(queue->sock, &msg, &iov, 1, + if (queue->tls_pskid) { + msg.msg_control = cbuf; + msg.msg_controllen = sizeof(cbuf); + } + len = kernel_recvmsg(queue->sock, &msg, &iov, 1, iov.iov_len, msg.msg_flags); - if (unlikely(ret < 0)) - return ret; + if (unlikely(len < 0)) + return len; + if (queue->tls_pskid) { + ret = nvmet_tcp_tls_record_ok(queue, &msg, cbuf); + if (ret < 0) + return ret; + } - queue->offset += ret; - queue->left -= ret; + queue->offset += len; + queue->left -= len; if (queue->left) return -EAGAIN; @@ -1407,6 +1470,10 @@ static int nvmet_tcp_alloc_cmd(struct nvmet_tcp_queue *queue, if (!c->r2t_pdu) goto out_free_data; + if (queue->state == NVMET_TCP_Q_TLS_HANDSHAKE) { + c->recv_msg.msg_control = c->recv_cbuf; + c->recv_msg.msg_controllen = sizeof(c->recv_cbuf); + } c->recv_msg.msg_flags = MSG_DONTWAIT | MSG_NOSIGNAL; list_add_tail(&c->entry, &queue->free_list); |