summaryrefslogtreecommitdiff
path: root/net/tls/tls_sw.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/tls/tls_sw.c')
-rw-r--r--net/tls/tls_sw.c140
1 files changed, 106 insertions, 34 deletions
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 7bcc9b4408a2..914d4e1516a3 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -1314,6 +1314,10 @@ tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock,
int ret = 0;
long timeo;
+ /* a rekey is pending, let userspace deal with it */
+ if (unlikely(ctx->key_update_pending))
+ return -EKEYEXPIRED;
+
timeo = sock_rcvtimeo(sk, nonblock);
while (!tls_strp_msg_ready(ctx)) {
@@ -1720,6 +1724,36 @@ tls_decrypt_device(struct sock *sk, struct msghdr *msg,
return 1;
}
+static int tls_check_pending_rekey(struct sock *sk, struct tls_context *ctx,
+ struct sk_buff *skb)
+{
+ const struct strp_msg *rxm = strp_msg(skb);
+ const struct tls_msg *tlm = tls_msg(skb);
+ char hs_type;
+ int err;
+
+ if (likely(tlm->control != TLS_RECORD_TYPE_HANDSHAKE))
+ return 0;
+
+ if (rxm->full_len < 1)
+ return 0;
+
+ err = skb_copy_bits(skb, rxm->offset, &hs_type, 1);
+ if (err < 0) {
+ DEBUG_NET_WARN_ON_ONCE(1);
+ return err;
+ }
+
+ if (hs_type == TLS_HANDSHAKE_KEYUPDATE) {
+ struct tls_sw_context_rx *rx_ctx = ctx->priv_ctx_rx;
+
+ WRITE_ONCE(rx_ctx->key_update_pending, true);
+ TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXREKEYRECEIVED);
+ }
+
+ return 0;
+}
+
static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
struct tls_decrypt_arg *darg)
{
@@ -1739,7 +1773,7 @@ static int tls_rx_one_record(struct sock *sk, struct msghdr *msg,
rxm->full_len -= prot->overhead_size;
tls_advance_record_sn(sk, prot, &tls_ctx->rx);
- return 0;
+ return tls_check_pending_rekey(sk, tls_ctx, darg->skb);
}
int decrypt_skb(struct sock *sk, struct scatterlist *sgout)
@@ -2684,12 +2718,22 @@ int init_prot_info(struct tls_prot_info *prot,
return 0;
}
-int tls_set_sw_offload(struct sock *sk, int tx)
+static void tls_finish_key_update(struct sock *sk, struct tls_context *tls_ctx)
+{
+ struct tls_sw_context_rx *ctx = tls_ctx->priv_ctx_rx;
+
+ WRITE_ONCE(ctx->key_update_pending, false);
+ /* wake-up pre-existing poll() */
+ ctx->saved_data_ready(sk);
+}
+
+int tls_set_sw_offload(struct sock *sk, int tx,
+ struct tls_crypto_info *new_crypto_info)
{
+ struct tls_crypto_info *crypto_info, *src_crypto_info;
struct tls_sw_context_tx *sw_ctx_tx = NULL;
struct tls_sw_context_rx *sw_ctx_rx = NULL;
const struct tls_cipher_desc *cipher_desc;
- struct tls_crypto_info *crypto_info;
char *iv, *rec_seq, *key, *salt;
struct cipher_context *cctx;
struct tls_prot_info *prot;
@@ -2701,44 +2745,47 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx = tls_get_ctx(sk);
prot = &ctx->prot_info;
- if (tx) {
- ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
- if (!ctx->priv_ctx_tx)
- return -ENOMEM;
+ /* new_crypto_info != NULL means rekey */
+ if (!new_crypto_info) {
+ if (tx) {
+ ctx->priv_ctx_tx = init_ctx_tx(ctx, sk);
+ if (!ctx->priv_ctx_tx)
+ return -ENOMEM;
+ } else {
+ ctx->priv_ctx_rx = init_ctx_rx(ctx);
+ if (!ctx->priv_ctx_rx)
+ return -ENOMEM;
+ }
+ }
+ if (tx) {
sw_ctx_tx = ctx->priv_ctx_tx;
crypto_info = &ctx->crypto_send.info;
cctx = &ctx->tx;
aead = &sw_ctx_tx->aead_send;
} else {
- ctx->priv_ctx_rx = init_ctx_rx(ctx);
- if (!ctx->priv_ctx_rx)
- return -ENOMEM;
-
sw_ctx_rx = ctx->priv_ctx_rx;
crypto_info = &ctx->crypto_recv.info;
cctx = &ctx->rx;
aead = &sw_ctx_rx->aead_recv;
}
- cipher_desc = get_cipher_desc(crypto_info->cipher_type);
+ src_crypto_info = new_crypto_info ?: crypto_info;
+
+ cipher_desc = get_cipher_desc(src_crypto_info->cipher_type);
if (!cipher_desc) {
rc = -EINVAL;
goto free_priv;
}
- rc = init_prot_info(prot, crypto_info, cipher_desc);
+ rc = init_prot_info(prot, src_crypto_info, cipher_desc);
if (rc)
goto free_priv;
- iv = crypto_info_iv(crypto_info, cipher_desc);
- key = crypto_info_key(crypto_info, cipher_desc);
- salt = crypto_info_salt(crypto_info, cipher_desc);
- rec_seq = crypto_info_rec_seq(crypto_info, cipher_desc);
-
- memcpy(cctx->iv, salt, cipher_desc->salt);
- memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
- memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+ iv = crypto_info_iv(src_crypto_info, cipher_desc);
+ key = crypto_info_key(src_crypto_info, cipher_desc);
+ salt = crypto_info_salt(src_crypto_info, cipher_desc);
+ rec_seq = crypto_info_rec_seq(src_crypto_info, cipher_desc);
if (!*aead) {
*aead = crypto_alloc_aead(cipher_desc->cipher_name, 0, 0);
@@ -2751,20 +2798,30 @@ int tls_set_sw_offload(struct sock *sk, int tx)
ctx->push_pending_record = tls_sw_push_pending_record;
+ /* setkey is the last operation that could fail during a
+ * rekey. if it succeeds, we can start modifying the
+ * context.
+ */
rc = crypto_aead_setkey(*aead, key, cipher_desc->key);
- if (rc)
- goto free_aead;
+ if (rc) {
+ if (new_crypto_info)
+ goto out;
+ else
+ goto free_aead;
+ }
- rc = crypto_aead_setauthsize(*aead, prot->tag_size);
- if (rc)
- goto free_aead;
+ if (!new_crypto_info) {
+ rc = crypto_aead_setauthsize(*aead, prot->tag_size);
+ if (rc)
+ goto free_aead;
+ }
- if (sw_ctx_rx) {
+ if (!tx && !new_crypto_info) {
tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
tls_update_rx_zc_capable(ctx);
sw_ctx_rx->async_capable =
- crypto_info->version != TLS_1_3_VERSION &&
+ src_crypto_info->version != TLS_1_3_VERSION &&
!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);
rc = tls_strp_init(&sw_ctx_rx->strp, sk);
@@ -2772,18 +2829,33 @@ int tls_set_sw_offload(struct sock *sk, int tx)
goto free_aead;
}
+ memcpy(cctx->iv, salt, cipher_desc->salt);
+ memcpy(cctx->iv + cipher_desc->salt, iv, cipher_desc->iv);
+ memcpy(cctx->rec_seq, rec_seq, cipher_desc->rec_seq);
+
+ if (new_crypto_info) {
+ unsafe_memcpy(crypto_info, new_crypto_info,
+ cipher_desc->crypto_info,
+ /* size was checked in do_tls_setsockopt_conf */);
+ memzero_explicit(new_crypto_info, cipher_desc->crypto_info);
+ if (!tx)
+ tls_finish_key_update(sk, ctx);
+ }
+
goto out;
free_aead:
crypto_free_aead(*aead);
*aead = NULL;
free_priv:
- if (tx) {
- kfree(ctx->priv_ctx_tx);
- ctx->priv_ctx_tx = NULL;
- } else {
- kfree(ctx->priv_ctx_rx);
- ctx->priv_ctx_rx = NULL;
+ if (!new_crypto_info) {
+ if (tx) {
+ kfree(ctx->priv_ctx_tx);
+ ctx->priv_ctx_tx = NULL;
+ } else {
+ kfree(ctx->priv_ctx_rx);
+ ctx->priv_ctx_rx = NULL;
+ }
}
out:
return rc;