diff options
Diffstat (limited to 'net/tls')
-rw-r--r-- | net/tls/tls_sw.c | 238 |
1 files changed, 142 insertions, 96 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c index 83d67df33f0c..52fbe727d7c1 100644 --- a/net/tls/tls_sw.c +++ b/net/tls/tls_sw.c @@ -48,19 +48,13 @@ static int tls_do_decryption(struct sock *sk, struct scatterlist *sgout, char *iv_recv, size_t data_len, - struct sk_buff *skb, - gfp_t flags) + struct aead_request *aead_req) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - struct aead_request *aead_req; - int ret; - aead_req = aead_request_alloc(ctx->aead_recv, flags); - if (!aead_req) - return -ENOMEM; - + aead_request_set_tfm(aead_req, ctx->aead_recv); aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); aead_request_set_crypt(aead_req, sgin, sgout, data_len + tls_ctx->rx.tag_size, @@ -69,8 +63,6 @@ static int tls_do_decryption(struct sock *sk, crypto_req_done, &ctx->async_wait); ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); - - aead_request_free(aead_req); return ret; } @@ -657,8 +649,132 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags, return skb; } +/* This function decrypts the input skb into either out_iov or in out_sg + * or in skb buffers itself. The input parameter 'zc' indicates if + * zero-copy mode needs to be tried or not. With zero-copy mode, either + * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are + * NULL, then the decryption happens inside skb buffers itself, i.e. + * zero-copy gets disabled and 'zc' is updated. + */ + +static int decrypt_internal(struct sock *sk, struct sk_buff *skb, + struct iov_iter *out_iov, + struct scatterlist *out_sg, + int *chunk, bool *zc) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); + struct strp_msg *rxm = strp_msg(skb); + int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; + struct aead_request *aead_req; + struct sk_buff *unused; + u8 *aad, *iv, *mem = NULL; + struct scatterlist *sgin = NULL; + struct scatterlist *sgout = NULL; + const int data_len = rxm->full_len - tls_ctx->rx.overhead_size; + + if (*zc && (out_iov || out_sg)) { + if (out_iov) + n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1; + else + n_sgout = sg_nents(out_sg); + } else { + n_sgout = 0; + *zc = false; + } + + n_sgin = skb_cow_data(skb, 0, &unused); + if (n_sgin < 1) + return -EBADMSG; + + /* Increment to accommodate AAD */ + n_sgin = n_sgin + 1; + + nsg = n_sgin + n_sgout; + + aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); + mem_size = aead_size + (nsg * sizeof(struct scatterlist)); + mem_size = mem_size + TLS_AAD_SPACE_SIZE; + mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv); + + /* Allocate a single block of memory which contains + * aead_req || sgin[] || sgout[] || aad || iv. + * This order achieves correct alignment for aead_req, sgin, sgout. + */ + mem = kmalloc(mem_size, sk->sk_allocation); + if (!mem) + return -ENOMEM; + + /* Segment the allocated memory */ + aead_req = (struct aead_request *)mem; + sgin = (struct scatterlist *)(mem + aead_size); + sgout = sgin + n_sgin; + aad = (u8 *)(sgout + n_sgout); + iv = aad + TLS_AAD_SPACE_SIZE; + + /* Prepare IV */ + err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, + iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, + tls_ctx->rx.iv_size); + if (err < 0) { + kfree(mem); + return err; + } + memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); + + /* Prepare AAD */ + tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size, + tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size, + ctx->control); + + /* Prepare sgin */ + sg_init_table(sgin, n_sgin); + sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE); + err = skb_to_sgvec(skb, &sgin[1], + rxm->offset + tls_ctx->rx.prepend_size, + rxm->full_len - tls_ctx->rx.prepend_size); + if (err < 0) { + kfree(mem); + return err; + } + + if (n_sgout) { + if (out_iov) { + sg_init_table(sgout, n_sgout); + sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE); + + *chunk = 0; + err = zerocopy_from_iter(sk, out_iov, data_len, &pages, + chunk, &sgout[1], + (n_sgout - 1), false); + if (err < 0) + goto fallback_to_reg_recv; + } else if (out_sg) { + memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); + } else { + goto fallback_to_reg_recv; + } + } else { +fallback_to_reg_recv: + sgout = sgin; + pages = 0; + *chunk = 0; + *zc = false; + } + + /* Prepare and submit AEAD request */ + err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req); + + /* Release the pages in case iov was mapped to pages */ + for (; pages > 0; pages--) + put_page(sg_page(&sgout[pages])); + + kfree(mem); + return err; +} + static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, - struct scatterlist *sgout, bool *zc) + struct iov_iter *dest, int *chunk, bool *zc) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); @@ -671,7 +787,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, return err; #endif if (!ctx->decrypted) { - err = decrypt_skb(sk, skb, sgout); + err = decrypt_internal(sk, skb, dest, NULL, chunk, zc); if (err < 0) return err; } else { @@ -690,54 +806,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, int decrypt_skb(struct sock *sk, struct sk_buff *skb, struct scatterlist *sgout) { - struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); - char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE]; - struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; - struct scatterlist *sgin = &sgin_arr[0]; - struct strp_msg *rxm = strp_msg(skb); - int ret, nsg = ARRAY_SIZE(sgin_arr); - struct sk_buff *unused; - - ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, - iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, - tls_ctx->rx.iv_size); - if (ret < 0) - return ret; - - memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); - if (!sgout) { - nsg = skb_cow_data(skb, 0, &unused) + 1; - sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); - sgout = sgin; - } - - sg_init_table(sgin, nsg); - sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE); - - nsg = skb_to_sgvec(skb, &sgin[1], - rxm->offset + tls_ctx->rx.prepend_size, - rxm->full_len - tls_ctx->rx.prepend_size); - if (nsg < 0) { - ret = nsg; - goto out; - } - - tls_make_aad(ctx->rx_aad_ciphertext, - rxm->full_len - tls_ctx->rx.overhead_size, - tls_ctx->rx.rec_seq, - tls_ctx->rx.rec_seq_size, - ctx->control); - - ret = tls_do_decryption(sk, sgin, sgout, iv, - rxm->full_len - tls_ctx->rx.overhead_size, - skb, sk->sk_allocation); - -out: - if (sgin != &sgin_arr[0]) - kfree(sgin); + bool zc = true; + int chunk; - return ret; + return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc); } static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, @@ -816,43 +888,17 @@ int tls_sw_recvmsg(struct sock *sk, } if (!ctx->decrypted) { - int page_count; - int to_copy; - - page_count = iov_iter_npages(&msg->msg_iter, - MAX_SKB_FRAGS); - to_copy = rxm->full_len - tls_ctx->rx.overhead_size; - if (!is_kvec && to_copy <= len && page_count < MAX_SKB_FRAGS && - likely(!(flags & MSG_PEEK))) { - struct scatterlist sgin[MAX_SKB_FRAGS + 1]; - int pages = 0; + int to_copy = rxm->full_len - tls_ctx->rx.overhead_size; + if (!is_kvec && to_copy <= len && + likely(!(flags & MSG_PEEK))) zc = true; - sg_init_table(sgin, MAX_SKB_FRAGS + 1); - sg_set_buf(&sgin[0], ctx->rx_aad_plaintext, - TLS_AAD_SPACE_SIZE); - - err = zerocopy_from_iter(sk, &msg->msg_iter, - to_copy, &pages, - &chunk, &sgin[1], - MAX_SKB_FRAGS, false); - if (err < 0) - goto fallback_to_reg_recv; - - err = decrypt_skb_update(sk, skb, sgin, &zc); - for (; pages > 0; pages--) - put_page(sg_page(&sgin[pages])); - if (err < 0) { - tls_err_abort(sk, EBADMSG); - goto recv_end; - } - } else { -fallback_to_reg_recv: - err = decrypt_skb_update(sk, skb, NULL, &zc); - if (err < 0) { - tls_err_abort(sk, EBADMSG); - goto recv_end; - } + + err = decrypt_skb_update(sk, skb, &msg->msg_iter, + &chunk, &zc); + if (err < 0) { + tls_err_abort(sk, EBADMSG); + goto recv_end; } ctx->decrypted = true; } @@ -903,7 +949,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, int err = 0; long timeo; int chunk; - bool zc; + bool zc = false; lock_sock(sk); @@ -920,7 +966,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, } if (!ctx->decrypted) { - err = decrypt_skb_update(sk, skb, NULL, &zc); + err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc); if (err < 0) { tls_err_abort(sk, EBADMSG); |