diff options
Diffstat (limited to 'drivers/virt')
-rw-r--r-- | drivers/virt/coco/sev-guest/Kconfig | 4 | ||||
-rw-r--r-- | drivers/virt/coco/sev-guest/sev-guest.c | 175 |
2 files changed, 40 insertions, 139 deletions
diff --git a/drivers/virt/coco/sev-guest/Kconfig b/drivers/virt/coco/sev-guest/Kconfig index 1cffc72c41cb..0b772bd921d8 100644 --- a/drivers/virt/coco/sev-guest/Kconfig +++ b/drivers/virt/coco/sev-guest/Kconfig @@ -2,9 +2,7 @@ config SEV_GUEST tristate "AMD SEV Guest driver" default m depends on AMD_MEM_ENCRYPT - select CRYPTO - select CRYPTO_AEAD2 - select CRYPTO_GCM + select CRYPTO_LIB_AESGCM select TSM_REPORTS help SEV-SNP firmware provides the guest a mechanism to communicate with diff --git a/drivers/virt/coco/sev-guest/sev-guest.c b/drivers/virt/coco/sev-guest/sev-guest.c index 89754b019be2..a33daff516ed 100644 --- a/drivers/virt/coco/sev-guest/sev-guest.c +++ b/drivers/virt/coco/sev-guest/sev-guest.c @@ -17,8 +17,7 @@ #include <linux/set_memory.h> #include <linux/fs.h> #include <linux/tsm.h> -#include <crypto/aead.h> -#include <linux/scatterlist.h> +#include <crypto/gcm.h> #include <linux/psp-sev.h> #include <linux/sockptr.h> #include <linux/cleanup.h> @@ -31,26 +30,18 @@ #include <asm/sev.h> #define DEVICE_NAME "sev-guest" -#define AAD_LEN 48 -#define MSG_HDR_VER 1 #define SNP_REQ_MAX_RETRY_DURATION (60*HZ) #define SNP_REQ_RETRY_DELAY (2*HZ) #define SVSM_MAX_RETRIES 3 -struct snp_guest_crypto { - struct crypto_aead *tfm; - u8 *iv, *authtag; - int iv_len, a_len; -}; - struct snp_guest_dev { struct device *dev; struct miscdevice misc; void *certs_data; - struct snp_guest_crypto *crypto; + struct aesgcm_ctx *ctx; /* request and response are in unencrypted memory */ struct snp_guest_msg *request, *response; @@ -169,132 +160,31 @@ static inline struct snp_guest_dev *to_snp_dev(struct file *file) return container_of(dev, struct snp_guest_dev, misc); } -static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen) +static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen) { - struct snp_guest_crypto *crypto; + struct aesgcm_ctx *ctx; - crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT); - if (!crypto) + ctx = kzalloc(sizeof(*ctx), GFP_KERNEL_ACCOUNT); + if (!ctx) return NULL; - crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0); - if (IS_ERR(crypto->tfm)) - goto e_free; - - if (crypto_aead_setkey(crypto->tfm, key, keylen)) - goto e_free_crypto; - - crypto->iv_len = crypto_aead_ivsize(crypto->tfm); - crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT); - if (!crypto->iv) - goto e_free_crypto; - - if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) { - if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) { - dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN); - goto e_free_iv; - } + if (aesgcm_expandkey(ctx, key, keylen, AUTHTAG_LEN)) { + pr_err("Crypto context initialization failed\n"); + kfree(ctx); + return NULL; } - crypto->a_len = crypto_aead_authsize(crypto->tfm); - crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT); - if (!crypto->authtag) - goto e_free_iv; - - return crypto; - -e_free_iv: - kfree(crypto->iv); -e_free_crypto: - crypto_free_aead(crypto->tfm); -e_free: - kfree(crypto); - - return NULL; -} - -static void deinit_crypto(struct snp_guest_crypto *crypto) -{ - crypto_free_aead(crypto->tfm); - kfree(crypto->iv); - kfree(crypto->authtag); - kfree(crypto); -} - -static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg, - u8 *src_buf, u8 *dst_buf, size_t len, bool enc) -{ - struct snp_guest_msg_hdr *hdr = &msg->hdr; - struct scatterlist src[3], dst[3]; - DECLARE_CRYPTO_WAIT(wait); - struct aead_request *req; - int ret; - - req = aead_request_alloc(crypto->tfm, GFP_KERNEL); - if (!req) - return -ENOMEM; - - /* - * AEAD memory operations: - * +------ AAD -------+------- DATA -----+---- AUTHTAG----+ - * | msg header | plaintext | hdr->authtag | - * | bytes 30h - 5Fh | or | | - * | | cipher | | - * +------------------+------------------+----------------+ - */ - sg_init_table(src, 3); - sg_set_buf(&src[0], &hdr->algo, AAD_LEN); - sg_set_buf(&src[1], src_buf, hdr->msg_sz); - sg_set_buf(&src[2], hdr->authtag, crypto->a_len); - - sg_init_table(dst, 3); - sg_set_buf(&dst[0], &hdr->algo, AAD_LEN); - sg_set_buf(&dst[1], dst_buf, hdr->msg_sz); - sg_set_buf(&dst[2], hdr->authtag, crypto->a_len); - - aead_request_set_ad(req, AAD_LEN); - aead_request_set_tfm(req, crypto->tfm); - aead_request_set_callback(req, 0, crypto_req_done, &wait); - - aead_request_set_crypt(req, src, dst, len, crypto->iv); - ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait); - - aead_request_free(req); - return ret; -} - -static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg, - void *plaintext, size_t len) -{ - struct snp_guest_crypto *crypto = snp_dev->crypto; - struct snp_guest_msg_hdr *hdr = &msg->hdr; - - memset(crypto->iv, 0, crypto->iv_len); - memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno)); - - return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true); -} - -static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg, - void *plaintext, size_t len) -{ - struct snp_guest_crypto *crypto = snp_dev->crypto; - struct snp_guest_msg_hdr *hdr = &msg->hdr; - - /* Build IV with response buffer sequence number */ - memset(crypto->iv, 0, crypto->iv_len); - memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno)); - - return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false); + return ctx; } static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz) { - struct snp_guest_crypto *crypto = snp_dev->crypto; struct snp_guest_msg *resp_msg = &snp_dev->secret_response; struct snp_guest_msg *req_msg = &snp_dev->secret_request; struct snp_guest_msg_hdr *req_msg_hdr = &req_msg->hdr; struct snp_guest_msg_hdr *resp_msg_hdr = &resp_msg->hdr; + struct aesgcm_ctx *ctx = snp_dev->ctx; + u8 iv[GCM_AES_IV_SIZE] = {}; pr_debug("response [seqno %lld type %d version %d sz %d]\n", resp_msg_hdr->msg_seqno, resp_msg_hdr->msg_type, resp_msg_hdr->msg_version, @@ -316,11 +206,16 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, * If the message size is greater than our buffer length then return * an error. */ - if (unlikely((resp_msg_hdr->msg_sz + crypto->a_len) > sz)) + if (unlikely((resp_msg_hdr->msg_sz + ctx->authsize) > sz)) return -EBADMSG; /* Decrypt the payload */ - return dec_payload(snp_dev, resp_msg, payload, resp_msg_hdr->msg_sz + crypto->a_len); + memcpy(iv, &resp_msg_hdr->msg_seqno, min(sizeof(iv), sizeof(resp_msg_hdr->msg_seqno))); + if (!aesgcm_decrypt(ctx, payload, resp_msg->payload, resp_msg_hdr->msg_sz, + &resp_msg_hdr->algo, AAD_LEN, iv, resp_msg_hdr->authtag)) + return -EBADMSG; + + return 0; } static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type, @@ -328,6 +223,8 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 { struct snp_guest_msg *msg = &snp_dev->secret_request; struct snp_guest_msg_hdr *hdr = &msg->hdr; + struct aesgcm_ctx *ctx = snp_dev->ctx; + u8 iv[GCM_AES_IV_SIZE] = {}; memset(msg, 0, sizeof(*msg)); @@ -347,7 +244,14 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 pr_debug("request [seqno %lld type %d version %d sz %d]\n", hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz); - return __enc_payload(snp_dev, msg, payload, sz); + if (WARN_ON((sz + ctx->authsize) > sizeof(msg->payload))) + return -EBADMSG; + + memcpy(iv, &hdr->msg_seqno, min(sizeof(iv), sizeof(hdr->msg_seqno))); + aesgcm_encrypt(ctx, msg->payload, payload, sz, &hdr->algo, AAD_LEN, + iv, hdr->authtag); + + return 0; } static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code, @@ -495,7 +399,6 @@ struct snp_req_resp { static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) { - struct snp_guest_crypto *crypto = snp_dev->crypto; struct snp_report_req *report_req = &snp_dev->req.report; struct snp_report_resp *report_resp; int rc, resp_len; @@ -513,7 +416,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io * response payload. Make sure that it has enough space to cover the * authtag. */ - resp_len = sizeof(report_resp->data) + crypto->a_len; + resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); if (!report_resp) return -ENOMEM; @@ -534,7 +437,6 @@ e_free: static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg) { struct snp_derived_key_req *derived_key_req = &snp_dev->req.derived_key; - struct snp_guest_crypto *crypto = snp_dev->crypto; struct snp_derived_key_resp derived_key_resp = {0}; int rc, resp_len; /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */ @@ -550,7 +452,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque * response payload. Make sure that it has enough space to cover the * authtag. */ - resp_len = sizeof(derived_key_resp.data) + crypto->a_len; + resp_len = sizeof(derived_key_resp.data) + snp_dev->ctx->authsize; if (sizeof(buf) < resp_len) return -ENOMEM; @@ -579,7 +481,6 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques { struct snp_ext_report_req *report_req = &snp_dev->req.ext_report; - struct snp_guest_crypto *crypto = snp_dev->crypto; struct snp_report_resp *report_resp; int ret, npages = 0, resp_len; sockptr_t certs_address; @@ -622,7 +523,7 @@ cmd: * response payload. Make sure that it has enough space to cover the * authtag. */ - resp_len = sizeof(report_resp->data) + crypto->a_len; + resp_len = sizeof(report_resp->data) + snp_dev->ctx->authsize; report_resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT); if (!report_resp) return -ENOMEM; @@ -1147,8 +1048,8 @@ static int __init sev_guest_probe(struct platform_device *pdev) goto e_free_response; ret = -EIO; - snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN); - if (!snp_dev->crypto) + snp_dev->ctx = snp_init_crypto(snp_dev->vmpck, VMPCK_KEY_LEN); + if (!snp_dev->ctx) goto e_free_cert_data; misc = &snp_dev->misc; @@ -1174,11 +1075,13 @@ static int __init sev_guest_probe(struct platform_device *pdev) ret = misc_register(misc); if (ret) - goto e_free_cert_data; + goto e_free_ctx; dev_info(dev, "Initialized SEV guest driver (using VMPCK%d communication key)\n", vmpck_id); return 0; +e_free_ctx: + kfree(snp_dev->ctx); e_free_cert_data: free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); e_free_response: @@ -1197,7 +1100,7 @@ static void __exit sev_guest_remove(struct platform_device *pdev) free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE); free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg)); free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg)); - deinit_crypto(snp_dev->crypto); + kfree(snp_dev->ctx); misc_deregister(&snp_dev->misc); } |