summaryrefslogtreecommitdiff
path: root/net/ipv4/udp_bpf.c
blob: bbe6569c9ad346b4b8afaadaceef507a40047289 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2020 Cloudflare Ltd https://cloudflare.com */

#include <linux/skmsg.h>
#include <net/sock.h>
#include <net/udp.h>
#include <net/inet_common.h>

#include "udp_impl.h"

/* Base IPv6 UDP proto that udp_bpf_prots[UDP_BPF_IPV6] was last built from.
 * Written with smp_store_release() under udpv6_prot_lock by
 * udp_bpf_check_v6_needs_rebuild(); sk_udp_recvmsg() calls through it for
 * AF_INET6 sockets.
 */
static struct proto *udpv6_prot_saved __read_mostly;

/* Forward a receive request to the unmodified UDP recvmsg implementation:
 * the saved IPv6 proto for AF_INET6 sockets (when IPv6 is enabled), udp_prot
 * for everything else.
 */
static int sk_udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			  int noblock, int flags, int *addr_len)
{
	const struct proto *base = &udp_prot;

#if IS_ENABLED(CONFIG_IPV6)
	if (sk->sk_family == AF_INET6)
		base = udpv6_prot_saved;
#endif
	return base->recvmsg(sk, msg, len, noblock, flags, addr_len);
}

static bool udp_sk_has_data(struct sock *sk)
{
	return !skb_queue_empty(&udp_sk(sk)->reader_queue) ||
	       !skb_queue_empty(&sk->sk_receive_queue);
}

static bool psock_has_data(struct sk_psock *psock)
{
	return !skb_queue_empty(&psock->ingress_skb) ||
	       !sk_psock_queue_empty(psock);
}

/* Evaluates to non-zero when either the socket's own queues
 * (udp_sk_has_data()) or the psock (psock_has_data()) hold data. The socket
 * is checked first; the psock is only consulted when the socket is empty.
 */
#define udp_msg_has_data(__sk, __psock)	\
		({ udp_sk_has_data(__sk) || psock_has_data(__psock); })

/* Wait for data to show up on either the socket or its psock.
 *
 * @sk:    socket being read
 * @psock: psock attached to @sk
 * @timeo: timeout handed to wait_woken(); 0 means do not wait at all
 *
 * Returns 1 right away when the receive side is shut down, 0 when @timeo is
 * zero, and otherwise the result of udp_msg_has_data() after at most one
 * sleep.
 */
static int udp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	/* Reported as "data available" so the caller proceeds instead of
	 * sleeping on a socket whose receive side is shut down.
	 */
	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	/* Zero timeout: nothing to wait for, report no data. */
	if (!timeo)
		return ret;

	/* Get on the wait queue before checking for data, so the check and
	 * the sleep both happen while registered.
	 */
	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = udp_msg_has_data(sk, psock);
	if (!ret) {
		/* Sleep interruptibly, then look again exactly once. */
		wait_woken(&wait, TASK_INTERRUPTIBLE, timeo);
		ret = udp_msg_has_data(sk, psock);
	}
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

/* recvmsg handler installed on UDP sockets that have a psock attached.
 *
 * @sk:       socket being read
 * @msg:      destination message header
 * @len:      maximum number of bytes to receive
 * @nonblock: non-zero for a non-blocking receive
 * @flags:    MSG_* receive flags
 * @addr_len: passed through to the underlying recvmsg / error-queue handler
 *
 * MSG_ERRQUEUE requests go straight to inet_recv_error(). Sockets without a
 * psock, or whose psock has nothing queued, are served by the regular UDP
 * recvmsg via sk_udp_recvmsg(). Otherwise data is taken from the psock with
 * sk_msg_recvmsg().
 *
 * Returns the value produced by whichever of those paths handled the
 * request, 0 for a zero-length request, or -EAGAIN when nothing was copied
 * and udp_msg_wait_data() reported no data.
 */
static int udp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	/* A zero-length request cannot make progress through the psock path:
	 * sk_msg_recvmsg() is expected to copy nothing, the psock would still
	 * hold its data, and the retry loop below would jump back to
	 * msg_bytes_ready forever on a blocking socket. Bail out up front.
	 */
	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	if (!psock_has_data(psock)) {
		ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		goto out;
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = udp_msg_wait_data(sk, psock, timeo);
		if (data) {
			/* Data showed up: retry the psock if it is the one
			 * holding it, otherwise let plain UDP handle it.
			 */
			if (psock_has_data(psock))
				goto msg_bytes_ready;
			ret = sk_udp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	/* Drop the reference taken by sk_psock_get(). */
	sk_psock_put(sk, psock);
	return ret;
}

/* Indices into udp_bpf_prots[]. */
enum {
	UDP_BPF_IPV4,		/* slot used for AF_INET sockets */
	UDP_BPF_IPV6,		/* slot used for every other family */
	UDP_BPF_NUM_PROTS,	/* number of slots; sizes udp_bpf_prots[] */
};

/* Held while the IPv6 slot of udp_bpf_prots[] is rebuilt and
 * udpv6_prot_saved is updated.
 */
static DEFINE_SPINLOCK(udpv6_prot_lock);
/* Copies of the base UDP protos with the BPF hooks installed, indexed by
 * UDP_BPF_IPV4 / UDP_BPF_IPV6.
 */
static struct proto udp_bpf_prots[UDP_BPF_NUM_PROTS];

/* Fill @prot with a copy of @base and then point the close, recvmsg and
 * sock_is_readable hooks at their sockmap-aware replacements.
 */
static void udp_bpf_rebuild_protos(struct proto *prot, const struct proto *base)
{
	/* Inherit every operation from the base protocol first. */
	*prot = *base;

	/* Then override only the hooks that need psock awareness. */
	prot->recvmsg          = udp_bpf_recvmsg;
	prot->sock_is_readable = sk_msg_is_readable;
	prot->close            = sock_map_close;
}

/* Make sure udp_bpf_prots[UDP_BPF_IPV6] was derived from @ops, rebuilding it
 * under udpv6_prot_lock when @ops differs from the recorded base proto.
 */
static void udp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	/* Lockless fast path: the slot already matches @ops. */
	if (likely(ops == smp_load_acquire(&udpv6_prot_saved)))
		return;

	spin_lock_bh(&udpv6_prot_lock);
	/* Look again with the lock held before rebuilding. */
	if (likely(ops != udpv6_prot_saved)) {
		udp_bpf_rebuild_protos(&udp_bpf_prots[UDP_BPF_IPV6], ops);
		/* Record the new base only after the slot is fully built. */
		smp_store_release(&udpv6_prot_saved, ops);
	}
	spin_unlock_bh(&udpv6_prot_lock);
}

/* Init-time hook: derive the IPv4 slot of udp_bpf_prots[] from udp_prot.
 * Always returns 0.
 */
static int __init udp_bpf_v4_build_proto(void)
{
	struct proto *v4_slot = &udp_bpf_prots[UDP_BPF_IPV4];

	udp_bpf_rebuild_protos(v4_slot, &udp_prot);
	return 0;
}
late_initcall(udp_bpf_v4_build_proto);

/* Switch @sk between its original proto and the BPF-enabled one.
 *
 * @sk:      socket whose sk_prot is updated
 * @psock:   psock holding the saved proto and write_space callback
 * @restore: true to put back the saved proto and write_space callback,
 *           false to install the matching udp_bpf_prots[] entry
 *
 * Always returns 0.
 */
int udp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int slot;

	if (restore) {
		sk->sk_write_space = psock->saved_write_space;
		WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		return 0;
	}

	/* AF_INET uses the IPv4 slot; anything else uses the IPv6 slot. */
	if (sk->sk_family == AF_INET)
		slot = UDP_BPF_IPV4;
	else
		slot = UDP_BPF_IPV6;

	/* The IPv6 slot is built on demand from the socket's saved proto. */
	if (sk->sk_family == AF_INET6)
		udp_bpf_check_v6_needs_rebuild(psock->sk_proto);

	WRITE_ONCE(sk->sk_prot, &udp_bpf_prots[slot]);
	return 0;
}
EXPORT_SYMBOL_GPL(udp_bpf_update_proto);