summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--net/tls/tls.h2
-rw-r--r--net/tls/tls_sw.c40
2 files changed, 31 insertions, 11 deletions
diff --git a/net/tls/tls.h b/net/tls/tls.h
index 0e840a0c3437..804c3880d028 100644
--- a/net/tls/tls.h
+++ b/net/tls/tls.h
@@ -70,6 +70,8 @@ struct tls_rec {
char content_type;
struct scatterlist sg_content_type;
+ struct sock *sk;
+
char aad_space[TLS_AAD_SPACE_SIZE];
u8 iv_data[MAX_IV_SIZE];
struct aead_request aead_req;
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 9ed978634125..5b7f67a7d394 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -38,6 +38,7 @@
#include <linux/bug.h>
#include <linux/sched/signal.h>
#include <linux/module.h>
+#include <linux/kernel.h>
#include <linux/splice.h>
#include <crypto/aead.h>
@@ -57,6 +58,7 @@ struct tls_decrypt_arg {
};
struct tls_decrypt_ctx {
+ struct sock *sk;
u8 iv[MAX_IV_SIZE];
u8 aad[TLS_MAX_AAD_SIZE];
u8 tail;
@@ -177,18 +179,25 @@ static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb,
return sub;
}
-static void tls_decrypt_done(struct crypto_async_request *req, int err)
+static void tls_decrypt_done(crypto_completion_data_t *data, int err)
{
- struct aead_request *aead_req = (struct aead_request *)req;
+ struct aead_request *aead_req = crypto_get_completion_data(data);
+ struct crypto_aead *aead = crypto_aead_reqtfm(aead_req);
struct scatterlist *sgout = aead_req->dst;
struct scatterlist *sgin = aead_req->src;
struct tls_sw_context_rx *ctx;
+ struct tls_decrypt_ctx *dctx;
struct tls_context *tls_ctx;
struct scatterlist *sg;
unsigned int pages;
struct sock *sk;
+ int aead_size;
- sk = (struct sock *)req->data;
+ aead_size = sizeof(*aead_req) + crypto_aead_reqsize(aead);
+ aead_size = ALIGN(aead_size, __alignof__(*dctx));
+ dctx = (void *)((u8 *)aead_req + aead_size);
+
+ sk = dctx->sk;
tls_ctx = tls_get_ctx(sk);
ctx = tls_sw_ctx_rx(tls_ctx);
@@ -240,7 +249,7 @@ static int tls_do_decryption(struct sock *sk,
if (darg->async) {
aead_request_set_callback(aead_req,
CRYPTO_TFM_REQ_MAY_BACKLOG,
- tls_decrypt_done, sk);
+ tls_decrypt_done, aead_req);
atomic_inc(&ctx->decrypt_pending);
} else {
aead_request_set_callback(aead_req,
@@ -336,6 +345,8 @@ static struct tls_rec *tls_get_rec(struct sock *sk)
sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
sg_unmark_end(&rec->sg_aead_out[1]);
+ rec->sk = sk;
+
return rec;
}
@@ -417,22 +428,27 @@ tx_err:
return rc;
}
-static void tls_encrypt_done(struct crypto_async_request *req, int err)
+static void tls_encrypt_done(crypto_completion_data_t *data, int err)
{
- struct aead_request *aead_req = (struct aead_request *)req;
- struct sock *sk = req->data;
- struct tls_context *tls_ctx = tls_get_ctx(sk);
- struct tls_prot_info *prot = &tls_ctx->prot_info;
- struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+ struct aead_request *aead_req = crypto_get_completion_data(data);
+ struct tls_sw_context_tx *ctx;
+ struct tls_context *tls_ctx;
+ struct tls_prot_info *prot;
struct scatterlist *sge;
struct sk_msg *msg_en;
struct tls_rec *rec;
bool ready = false;
+ struct sock *sk;
int pending;
rec = container_of(aead_req, struct tls_rec, aead_req);
msg_en = &rec->msg_encrypted;
+ sk = rec->sk;
+ tls_ctx = tls_get_ctx(sk);
+ prot = &tls_ctx->prot_info;
+ ctx = tls_sw_ctx_tx(tls_ctx);
+
sge = sk_msg_elem(msg_en, msg_en->sg.curr);
sge->offset -= prot->prepend_size;
sge->length += prot->prepend_size;
@@ -520,7 +536,7 @@ static int tls_do_encryption(struct sock *sk,
data_len, rec->iv_data);
aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
- tls_encrypt_done, sk);
+ tls_encrypt_done, aead_req);
/* Add the record in tx_list */
list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
@@ -1485,6 +1501,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
* Both structs are variable length.
*/
aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
+ aead_size = ALIGN(aead_size, __alignof__(*dctx));
mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout),
sk->sk_allocation);
if (!mem) {
@@ -1495,6 +1512,7 @@ static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov,
/* Segment the allocated memory */
aead_req = (struct aead_request *)mem;
dctx = (struct tls_decrypt_ctx *)(mem + aead_size);
+ dctx->sk = sk;
sgin = &dctx->sg[0];
sgout = &dctx->sg[n_sgin];