diff options
Diffstat (limited to 'net/tls/tls_device.c')
-rw-r--r-- | net/tls/tls_device.c | 304 |
1 files changed, 268 insertions, 36 deletions
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c index a7a8f8e..292742e 100644 --- a/net/tls/tls_device.c +++ b/net/tls/tls_device.c @@ -52,9 +52,12 @@ static DEFINE_SPINLOCK(tls_device_lock); static void tls_device_free_ctx(struct tls_context *ctx) { - struct tls_offload_context *offload_ctx = tls_offload_ctx(ctx); + if (ctx->tx_conf == TLS_HW) + kfree(tls_offload_ctx_tx(ctx)); + + if (ctx->rx_conf == TLS_HW) + kfree(tls_offload_ctx_rx(ctx)); - kfree(offload_ctx); kfree(ctx); } @@ -71,10 +74,11 @@ static void tls_device_gc_task(struct work_struct *work) list_for_each_entry_safe(ctx, tmp, &gc_list, list) { struct net_device *netdev = ctx->netdev; - if (netdev) { + if (netdev && ctx->tx_conf == TLS_HW) { netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX); dev_put(netdev); + ctx->netdev = NULL; } list_del(&ctx->list); @@ -82,6 +86,22 @@ static void tls_device_gc_task(struct work_struct *work) } } +static void tls_device_attach(struct tls_context *ctx, struct sock *sk, + struct net_device *netdev) +{ + if (sk->sk_destruct != tls_device_sk_destruct) { + refcount_set(&ctx->refcount, 1); + dev_hold(netdev); + ctx->netdev = netdev; + spin_lock_irq(&tls_device_lock); + list_add_tail(&ctx->list, &tls_device_list); + spin_unlock_irq(&tls_device_lock); + + ctx->sk_destruct = sk->sk_destruct; + sk->sk_destruct = tls_device_sk_destruct; + } +} + static void tls_device_queue_ctx_destruction(struct tls_context *ctx) { unsigned long flags; @@ -125,7 +145,7 @@ static void destroy_record(struct tls_record_info *record) kfree(record); } -static void delete_all_records(struct tls_offload_context *offload_ctx) +static void delete_all_records(struct tls_offload_context_tx *offload_ctx) { struct tls_record_info *info, *temp; @@ -141,14 +161,14 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) { struct tls_context *tls_ctx = tls_get_ctx(sk); struct tls_record_info *info, *temp; - struct tls_offload_context *ctx; + struct tls_offload_context_tx *ctx; u64 deleted_records = 0; unsigned long flags; if (!tls_ctx) return; - ctx = tls_offload_ctx(tls_ctx); + ctx = tls_offload_ctx_tx(tls_ctx); spin_lock_irqsave(&ctx->lock, flags); info = ctx->retransmit_hint; @@ -179,15 +199,17 @@ static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) void tls_device_sk_destruct(struct sock *sk) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); - if (ctx->open_record) - destroy_record(ctx->open_record); + tls_ctx->sk_destruct(sk); - delete_all_records(ctx); - crypto_free_aead(ctx->aead_send); - ctx->sk_destruct(sk); - clean_acked_data_disable(inet_csk(sk)); + if (tls_ctx->tx_conf == TLS_HW) { + if (ctx->open_record) + destroy_record(ctx->open_record); + delete_all_records(ctx); + crypto_free_aead(ctx->aead_send); + clean_acked_data_disable(inet_csk(sk)); + } if (refcount_dec_and_test(&tls_ctx->refcount)) tls_device_queue_ctx_destruction(tls_ctx); @@ -219,7 +241,7 @@ static void tls_append_frag(struct tls_record_info *record, static int tls_push_record(struct sock *sk, struct tls_context *ctx, - struct tls_offload_context *offload_ctx, + struct tls_offload_context_tx *offload_ctx, struct tls_record_info *record, struct page_frag *pfrag, int flags, @@ -264,7 +286,7 @@ static int tls_push_record(struct sock *sk, return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags); } -static int tls_create_new_record(struct tls_offload_context *offload_ctx, +static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx, struct page_frag *pfrag, size_t prepend_size) { @@ -290,7 +312,7 @@ static int tls_create_new_record(struct tls_offload_context *offload_ctx, } static int tls_do_allocation(struct sock *sk, - struct tls_offload_context *offload_ctx, + struct tls_offload_context_tx *offload_ctx, struct page_frag *pfrag, size_t prepend_size) { @@ -324,7 +346,7 @@ static int tls_push_data(struct sock *sk, unsigned char record_type) { struct tls_context *tls_ctx = tls_get_ctx(sk); - struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); + struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx); int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); struct tls_record_info *record = ctx->open_record; @@ -477,7 +499,7 @@ out: return rc; } -struct tls_record_info *tls_get_record(struct tls_offload_context *context, +struct tls_record_info *tls_get_record(struct tls_offload_context_tx *context, u32 seq, u64 *p_record_sn) { u64 record_sn = context->hint_record_sn; @@ -520,11 +542,123 @@ static int tls_device_push_pending_record(struct sock *sk, int flags) return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); } +void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct net_device *netdev = tls_ctx->netdev; + struct tls_offload_context_rx *rx_ctx; + u32 is_req_pending; + s64 resync_req; + u32 req_seq; + + if (tls_ctx->rx_conf != TLS_HW) + return; + + rx_ctx = tls_offload_ctx_rx(tls_ctx); + resync_req = atomic64_read(&rx_ctx->resync_req); + req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1); + is_req_pending = resync_req; + + if (unlikely(is_req_pending) && req_seq == seq && + atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0)) + netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk, + seq + TLS_HEADER_SIZE - 1, + rcd_sn); +} + +static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb) +{ + struct strp_msg *rxm = strp_msg(skb); + int err = 0, offset = rxm->offset, copy, nsg; + struct sk_buff *skb_iter, *unused; + struct scatterlist sg[1]; + char *orig_buf, *buf; + + orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation); + if (!orig_buf) + return -ENOMEM; + buf = orig_buf; + + nsg = skb_cow_data(skb, 0, &unused); + if (unlikely(nsg < 0)) { + err = nsg; + goto free_buf; + } + + sg_init_table(sg, 1); + sg_set_buf(&sg[0], buf, + rxm->full_len + TLS_HEADER_SIZE + + TLS_CIPHER_AES_GCM_128_IV_SIZE); + skb_copy_bits(skb, offset, buf, + TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE); + + /* We are interested only in the decrypted data not the auth */ + err = decrypt_skb(sk, skb, sg); + if (err != -EBADMSG) + goto free_buf; + else + err = 0; + + copy = min_t(int, skb_pagelen(skb) - offset, + rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE); + + if (skb->decrypted) + skb_store_bits(skb, offset, buf, copy); + + offset += copy; + buf += copy; + + skb_walk_frags(skb, skb_iter) { + copy = min_t(int, skb_iter->len, + rxm->full_len - offset + rxm->offset - + TLS_CIPHER_AES_GCM_128_TAG_SIZE); + + if (skb_iter->decrypted) + skb_store_bits(skb_iter, offset, buf, copy); + + offset += copy; + buf += copy; + } + +free_buf: + kfree(orig_buf); + return err; +} + +int tls_device_decrypted(struct sock *sk, struct sk_buff *skb) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx); + int is_decrypted = skb->decrypted; + int is_encrypted = !is_decrypted; + struct sk_buff *skb_iter; + + /* Skip if it is already decrypted */ + if (ctx->sw.decrypted) + return 0; + + /* Check if all the data is decrypted already */ + skb_walk_frags(skb, skb_iter) { + is_decrypted &= skb_iter->decrypted; + is_encrypted &= !skb_iter->decrypted; + } + + ctx->sw.decrypted |= is_decrypted; + + /* Return immedeatly if the record is either entirely plaintext or + * entirely ciphertext. Otherwise handle reencrypt partially decrypted + * record. + */ + return (is_encrypted || is_decrypted) ? 0 : + tls_device_reencrypt(sk, skb); +} + int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) { u16 nonce_size, tag_size, iv_size, rec_seq_size; struct tls_record_info *start_marker_record; - struct tls_offload_context *offload_ctx; + struct tls_offload_context_tx *offload_ctx; struct tls_crypto_info *crypto_info; struct net_device *netdev; char *iv, *rec_seq; @@ -546,7 +680,7 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) goto out; } - offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL); + offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL); if (!offload_ctx) { rc = -ENOMEM; goto free_marker_record; @@ -582,12 +716,11 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); ctx->tx.rec_seq_size = rec_seq_size; - ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); + ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); if (!ctx->tx.rec_seq) { rc = -ENOMEM; goto free_iv; } - memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size); rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info); if (rc) @@ -609,7 +742,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) clean_acked_data_enable(inet_csk(sk), &tls_icsk_clean_acked); ctx->push_pending_record = tls_device_push_pending_record; - offload_ctx->sk_destruct = sk->sk_destruct; /* TLS offload is greatly simplified if we don't send * SKBs where only part of the payload needs to be encrypted. @@ -619,8 +751,6 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) if (skb) TCP_SKB_CB(skb)->eor = 1; - refcount_set(&ctx->refcount, 1); - /* We support starting offload on multiple sockets * concurrently, so we only need a read lock here. * This lock must precede get_netdev_for_sock to prevent races between @@ -655,19 +785,14 @@ int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) if (rc) goto release_netdev; - ctx->netdev = netdev; - - spin_lock_irq(&tls_device_lock); - list_add_tail(&ctx->list, &tls_device_list); - spin_unlock_irq(&tls_device_lock); + tls_device_attach(ctx, sk, netdev); - sk->sk_validate_xmit_skb = tls_validate_xmit_skb; /* following this assignment tls_is_sk_tx_device_offloaded * will return true and the context might be accessed * by the netdev's xmit function. */ - smp_store_release(&sk->sk_destruct, - &tls_device_sk_destruct); + smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb); + dev_put(netdev); up_read(&device_offload_lock); goto out; @@ -690,6 +815,105 @@ out: return rc; } +int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx) +{ + struct tls_offload_context_rx *context; + struct net_device *netdev; + int rc = 0; + + /* We support starting offload on multiple sockets + * concurrently, so we only need a read lock here. + * This lock must precede get_netdev_for_sock to prevent races between + * NETDEV_DOWN and setsockopt. + */ + down_read(&device_offload_lock); + netdev = get_netdev_for_sock(sk); + if (!netdev) { + pr_err_ratelimited("%s: netdev not found\n", __func__); + rc = -EINVAL; + goto release_lock; + } + + if (!(netdev->features & NETIF_F_HW_TLS_RX)) { + pr_err_ratelimited("%s: netdev %s with no TLS offload\n", + __func__, netdev->name); + rc = -ENOTSUPP; + goto release_netdev; + } + + /* Avoid offloading if the device is down + * We don't want to offload new flows after + * the NETDEV_DOWN event + */ + if (!(netdev->flags & IFF_UP)) { + rc = -EINVAL; + goto release_netdev; + } + + context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL); + if (!context) { + rc = -ENOMEM; + goto release_netdev; + } + + ctx->priv_ctx_rx = context; + rc = tls_set_sw_offload(sk, ctx, 0); + if (rc) + goto release_ctx; + + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX, + &ctx->crypto_recv, + tcp_sk(sk)->copied_seq); + if (rc) { + pr_err_ratelimited("%s: The netdev has refused to offload this socket\n", + __func__); + goto free_sw_resources; + } + + tls_device_attach(ctx, sk, netdev); + goto release_netdev; + +free_sw_resources: + tls_sw_free_resources_rx(sk); +release_ctx: + ctx->priv_ctx_rx = NULL; +release_netdev: + dev_put(netdev); +release_lock: + up_read(&device_offload_lock); + return rc; +} + +void tls_device_offload_cleanup_rx(struct sock *sk) +{ + struct tls_context *tls_ctx = tls_get_ctx(sk); + struct net_device *netdev; + + down_read(&device_offload_lock); + netdev = tls_ctx->netdev; + if (!netdev) + goto out; + + if (!(netdev->features & NETIF_F_HW_TLS_RX)) { + pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n", + __func__); + goto out; + } + + netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx, + TLS_OFFLOAD_CTX_DIR_RX); + + if (tls_ctx->tx_conf != TLS_HW) { + dev_put(netdev); + tls_ctx->netdev = NULL; + } +out: + up_read(&device_offload_lock); + kfree(tls_ctx->rx.rec_seq); + kfree(tls_ctx->rx.iv); + tls_sw_release_resources_rx(sk); +} + static int tls_device_down(struct net_device *netdev) { struct tls_context *ctx, *tmp; @@ -710,8 +934,12 @@ static int tls_device_down(struct net_device *netdev) spin_unlock_irqrestore(&tls_device_lock, flags); list_for_each_entry_safe(ctx, tmp, &list, list) { - netdev->tlsdev_ops->tls_dev_del(netdev, ctx, - TLS_OFFLOAD_CTX_DIR_TX); + if (ctx->tx_conf == TLS_HW) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_TX); + if (ctx->rx_conf == TLS_HW) + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, + TLS_OFFLOAD_CTX_DIR_RX); ctx->netdev = NULL; dev_put(netdev); list_del_init(&ctx->list); @@ -732,12 +960,16 @@ static int tls_dev_event(struct notifier_block *this, unsigned long event, { struct net_device *dev = netdev_notifier_info_to_dev(ptr); - if (!(dev->features & NETIF_F_HW_TLS_TX)) + if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX))) return NOTIFY_DONE; switch (event) { case NETDEV_REGISTER: case NETDEV_FEAT_CHANGE: + if ((dev->features & NETIF_F_HW_TLS_RX) && + !dev->tlsdev_ops->tls_dev_resync_rx) + return NOTIFY_BAD; + if (dev->tlsdev_ops && dev->tlsdev_ops->tls_dev_add && dev->tlsdev_ops->tls_dev_del) |