diff options
author | Jakub Kicinski <kuba@kernel.org> | 2022-07-14 22:22:30 -0700 |
---|---|---|
committer | David S. Miller <davem@davemloft.net> | 2022-07-18 11:24:11 +0100 |
commit | 541cc48be3b141e8529fef05ad6cedbca83f9e80 (patch) | |
tree | 37bbdfc174258b2baa5342a4d74ffddcc159f1c8 /net/tls/tls_sw.c | |
parent | 8a958732818bc27f7da4d41ecf2c5c99d9aa8b0e (diff) | |
download | lwn-541cc48be3b141e8529fef05ad6cedbca83f9e80.tar.gz lwn-541cc48be3b141e8529fef05ad6cedbca83f9e80.zip |
tls: rx: read the input skb from ctx->recv_pkt
Callers always pass ctx->recv_pkt into decrypt_skb_update(),
and it propagates it to its callees. This may give someone
the false impression that those functions can accept any valid
skb containing a TLS record. That's not the case, the record
sequence number is read from the context, and they can only
take the next record coming out of the strp.
Let the functions get the skb from the context instead of
passing it in. This will also make it cleaner to return
a different skb than ctx->recv_pkt as the decrypted one
later on.
Since we're touching the definition of decrypt_skb_update()
use this as an opportunity to rename it.
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r-- | net/tls/tls_sw.c | 37 |
1 files changed, 18 insertions, 19 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 5ef78e75c463..6205ad1a84c7 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -1421,8 +1421,7 @@ out: * NULL, then the decryption happens inside skb buffers itself, i.e. * zero-copy gets disabled and 'darg->zc' is updated. */ -static int tls_decrypt_sg(struct sock *sk, struct sk_buff *skb, - struct iov_iter *out_iov, +static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, struct scatterlist *out_sg, struct tls_decrypt_arg *darg) { @@ -1430,6 +1429,7 @@ static int tls_decrypt_sg(struct sock *sk, struct sk_buff *skb, struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_prot_info *prot = &tls_ctx->prot_info; int n_sgin, n_sgout, aead_size, err, pages = 0; + struct sk_buff *skb = tls_strp_msg(ctx); struct strp_msg *rxm = strp_msg(skb); struct tls_msg *tlm = tls_msg(skb); struct aead_request *aead_req; @@ -1567,14 +1567,14 @@ exit_free: static int tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, - struct sk_buff *skb, struct tls_decrypt_arg *darg) + struct tls_decrypt_arg *darg) { int err; if (tls_ctx->rx_conf != TLS_HW) return 0; - err = tls_device_decrypted(sk, tls_ctx, skb, strp_msg(skb)); + err = tls_device_decrypted(sk, tls_ctx); if (err <= 0) return err; @@ -1583,22 +1583,22 @@ tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, return 1; } -static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, - struct iov_iter *dest, - struct tls_decrypt_arg *darg) +static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, + struct tls_decrypt_arg *darg) { struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); struct tls_prot_info *prot = &tls_ctx->prot_info; - struct strp_msg *rxm = strp_msg(skb); + struct strp_msg *rxm; int pad, err; - err = tls_decrypt_device(sk, tls_ctx, skb, darg); + err = tls_decrypt_device(sk, tls_ctx, darg); if (err < 0) return err; if (err) goto decrypt_done; - err = tls_decrypt_sg(sk, skb, dest, NULL, darg); + err = tls_decrypt_sg(sk, dest, NULL, darg); if (err < 0) { if (err == -EBADMSG) TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); @@ -1613,14 +1613,15 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, if (!darg->tail) TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); - return decrypt_skb_update(sk, skb, dest, darg); + return tls_rx_one_record(sk, dest, darg); } decrypt_done: - pad = tls_padding_length(prot, skb, darg); + pad = tls_padding_length(prot, ctx->recv_pkt, darg); if (pad < 0) return pad; + rxm = strp_msg(ctx->recv_pkt); rxm->full_len -= pad; rxm->offset += prot->prepend_size; rxm->full_len -= prot->overhead_size; @@ -1630,12 +1631,11 @@ decrypt_next: return 0; } -int decrypt_skb(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout) +int decrypt_skb(struct sock *sk, struct scatterlist *sgout) { struct tls_decrypt_arg darg = { .zc = true, }; - return tls_decrypt_sg(sk, skb, NULL, sgout, &darg); + return tls_decrypt_sg(sk, NULL, sgout, &darg); } static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, @@ -1905,7 +1905,7 @@ int tls_sw_recvmsg(struct sock *sk, else darg.async = false; - err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg); + err = tls_rx_one_record(sk, &msg->msg_iter, &darg); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto recv_end; @@ -2058,14 +2058,13 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, if (err <= 0) goto splice_read_end; - skb = ctx->recv_pkt; - - err = decrypt_skb_update(sk, skb, NULL, &darg); + err = tls_rx_one_record(sk, NULL, &darg); if (err < 0) { tls_err_abort(sk, -EBADMSG); goto splice_read_end; } + skb = ctx->recv_pkt; tls_rx_rec_done(ctx); } |