summaryrefslogtreecommitdiff
path: root/tools/testing/selftests/net/lib
diff options
context:
space:
mode:
Diffstat (limited to 'tools/testing/selftests/net/lib')
-rw-r--r--tools/testing/selftests/net/lib/.gitignore2
-rw-r--r--tools/testing/selftests/net/lib/Makefile16
-rw-r--r--tools/testing/selftests/net/lib/csum.c2
-rw-r--r--tools/testing/selftests/net/lib/gro.c1837
-rw-r--r--tools/testing/selftests/net/lib/ksft.h58
-rwxr-xr-xtools/testing/selftests/net/lib/ksft_setup_loopback.sh111
-rw-r--r--tools/testing/selftests/net/lib/py/__init__.py40
-rw-r--r--tools/testing/selftests/net/lib/py/bpf.py68
-rw-r--r--tools/testing/selftests/net/lib/py/ksft.py271
-rw-r--r--tools/testing/selftests/net/lib/py/netns.py18
-rw-r--r--tools/testing/selftests/net/lib/py/nsim.py2
-rw-r--r--tools/testing/selftests/net/lib/py/utils.py250
-rw-r--r--tools/testing/selftests/net/lib/py/ynl.py34
-rw-r--r--tools/testing/selftests/net/lib/sh/defer.sh20
-rw-r--r--tools/testing/selftests/net/lib/xdp_dummy.bpf.c19
-rw-r--r--tools/testing/selftests/net/lib/xdp_helper.c131
-rw-r--r--tools/testing/selftests/net/lib/xdp_metadata.bpf.c163
-rw-r--r--tools/testing/selftests/net/lib/xdp_native.bpf.c685
18 files changed, 3650 insertions, 77 deletions
diff --git a/tools/testing/selftests/net/lib/.gitignore b/tools/testing/selftests/net/lib/.gitignore
index 1ebc6187f421..6cd2b762af5d 100644
--- a/tools/testing/selftests/net/lib/.gitignore
+++ b/tools/testing/selftests/net/lib/.gitignore
@@ -1,2 +1,4 @@
# SPDX-License-Identifier: GPL-2.0-only
csum
+gro
+xdp_helper
diff --git a/tools/testing/selftests/net/lib/Makefile b/tools/testing/selftests/net/lib/Makefile
index bc6b6762baf3..ff83603397d0 100644
--- a/tools/testing/selftests/net/lib/Makefile
+++ b/tools/testing/selftests/net/lib/Makefile
@@ -5,11 +5,21 @@ CFLAGS += -I../../../../../usr/include/ $(KHDR_INCLUDES)
# Additional include paths needed by kselftest.h
CFLAGS += -I../../
-TEST_FILES := ../../../../../Documentation/netlink/specs
-TEST_FILES += ../../../../net/ynl
+TEST_FILES := \
+ ../../../../net/ynl \
+ ../../../../../Documentation/netlink/specs \
+ ksft_setup_loopback.sh \
+# end of TEST_FILES
-TEST_GEN_FILES += csum
+TEST_GEN_FILES := \
+ $(patsubst %.c,%.o,$(wildcard *.bpf.c)) \
+ csum \
+ gro \
+ xdp_helper \
+# end of TEST_GEN_FILES
TEST_INCLUDES := $(wildcard py/*.py sh/*.sh)
include ../../lib.mk
+
+include ../bpf.mk
diff --git a/tools/testing/selftests/net/lib/csum.c b/tools/testing/selftests/net/lib/csum.c
index 27437590eeb5..e28884ce3ab3 100644
--- a/tools/testing/selftests/net/lib/csum.c
+++ b/tools/testing/selftests/net/lib/csum.c
@@ -707,7 +707,7 @@ static uint32_t recv_get_packet_csum_status(struct msghdr *msg)
cm->cmsg_level, cm->cmsg_type);
if (cm->cmsg_len != CMSG_LEN(sizeof(struct tpacket_auxdata)))
- error(1, 0, "cmsg: len=%lu expected=%lu",
+ error(1, 0, "cmsg: len=%zu expected=%zu",
cm->cmsg_len, CMSG_LEN(sizeof(struct tpacket_auxdata)));
aux = (void *)CMSG_DATA(cm);
diff --git a/tools/testing/selftests/net/lib/gro.c b/tools/testing/selftests/net/lib/gro.c
new file mode 100644
index 000000000000..11b16ae5f0e8
--- /dev/null
+++ b/tools/testing/selftests/net/lib/gro.c
@@ -0,0 +1,1837 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * This testsuite provides conformance testing for GRO coalescing.
+ *
+ * Test cases:
+ *
+ * data_*:
+ * Data packets of the same size and same header setup with correct
+ * sequence numbers coalesce. The one exception being the last data
+ * packet coalesced: it can be smaller than the rest and coalesced
+ * as long as it is in the same flow.
+ * - data_same: same size packets coalesce
+ * - data_lrg_sml: large then small coalesces
+ * - data_lrg_1byte: large then 1 byte coalesces (Ethernet padding)
+ * - data_sml_lrg: small then large doesn't coalesce
+ * - data_burst: two bursts of two, separated by 100ms
+ *
+ * ack:
+ * Pure ACK does not coalesce.
+ *
+ * flags_*:
+ * No packets with PSH, SYN, URG, RST, CWR set will be coalesced.
+ * - flags_psh, flags_syn, flags_rst, flags_urg, flags_cwr
+ *
+ * tcp_*:
+ * Packets with incorrect checksum, non-consecutive seqno and
+ * different TCP header options shouldn't coalesce. Nit: given that
+ * some extension headers have paddings, such as timestamp, headers
+ * that are padded differently would not be coalesced.
+ * - tcp_csum: incorrect checksum
+ * - tcp_seq: non-consecutive sequence numbers
+ * - tcp_ts: different timestamps
+ * - tcp_opt: different TCP options
+ *
+ * ip_*:
+ * Packets with different (ECN, TTL, TOS) header, IP options or
+ * IP fragments shouldn't coalesce.
+ * - ip_ecn, ip_tos: shared between IPv4/IPv6
+ * - ip_csum: IPv4 only, bad IP header checksum
+ * - ip_ttl, ip_opt, ip_frag4: IPv4 only
+ * - ip_id_df*: IPv4 IP ID field coalescing tests
+ * - ip_frag6, ip_v6ext_*: IPv6 only
+ *
+ * large_*:
+ * Packets that would grow the coalesced segment beyond GRO_MAX_SIZE shouldn't coalesce.
+ * - large_max: exceeding max size
+ * - large_rem: remainder handling
+ *
+ * single, capacity:
+ * Boring cases used to test coalescing machinery itself and stats
+ * more than protocol behavior.
+ *
+ * MSS is defined as 4096 - header because if it is too small
+ * (e.g. 1500 MTU - header), it will result in many packets,
+ * increasing the "large" test case's flakiness. Because the
+ * coalescing window is time sensitive, the receiver may not
+ * coalesce all of the packets.
+ *
+ * Note the timing issue applies to all of the test cases, so some
+ * flakiness is to be expected.
+ *
+ */
+
+#define _GNU_SOURCE
+
+#include <arpa/inet.h>
+#include <errno.h>
+#include <error.h>
+#include <getopt.h>
+#include <linux/filter.h>
+#include <linux/if_packet.h>
+#include <linux/ipv6.h>
+#include <linux/net_tstamp.h>
+#include <net/ethernet.h>
+#include <net/if.h>
+#include <netinet/in.h>
+#include <netinet/ip.h>
+#include <netinet/ip6.h>
+#include <netinet/tcp.h>
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdio.h>
+#include <stdarg.h>
+#include <string.h>
+#include <time.h>
+#include <unistd.h>
+
+#include "kselftest.h"
+#include "ksft.h"
+
+/* TCP ports and sequence/ack bases written into every generated TCP header */
+#define DPORT 8000
+#define SPORT 1500
+/* Default TCP payload size, in bytes, used by most senders below */
+#define PAYLOAD_LEN 100
+#define NUM_PACKETS 4
+#define START_SEQ 100
+#define START_ACK 100
+#define ETH_P_NONE 0
+/* MTU assumed when sizing segments, see calc_mss() */
+#define ASSUMED_MTU 4096
+#define MAX_MSS (ASSUMED_MTU - sizeof(struct iphdr) - sizeof(struct tcphdr))
+/* Largest header stack the static buffers are sized for:
+ * Ethernet + two IPv6 headers + TCP.
+ */
+#define MAX_HDR_LEN \
+ (ETH_HLEN + sizeof(struct ipv6hdr) * 2 + sizeof(struct tcphdr))
+/* Number of slots in the pkts[] array of send_large() */
+#define MAX_LARGE_PKT_CNT ((IP_MAXPACKET - (MAX_HDR_LEN - ETH_HLEN)) / \
+ (ASSUMED_MTU - (MAX_HDR_LEN - ETH_HLEN)))
+/* Size, in bytes, of the IPv6 extension header inserted by add_ipv6_exthdr() */
+#define MIN_EXTHDR_SIZE 8
+/* Two distinct 6-byte extension header bodies; presumably passed as the
+ * ext_data arguments of send_ipv6_exthdr() - caller not shown here.
+ */
+#define EXT_PAYLOAD_1 "\x00\x00\x00\x00\x00\x00"
+#define EXT_PAYLOAD_2 "\x11\x11\x11\x11\x11\x11"
+
+#define ipv6_optlen(p) (((p)->hdrlen+1) << 3) /* calculate IPv6 extension header len */
+/* Compile-time assertion: the array size becomes negative when condition is true */
+#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))
+
+/* IPv4 ID / DF-bit scenarios generated by send_flush_id_case() */
+enum flush_id_case {
+ FLUSH_ID_DF1_INC, /* DF=1, incrementing IDs */
+ FLUSH_ID_DF1_FIXED, /* DF=1, fixed ID */
+ FLUSH_ID_DF0_INC, /* DF=0, incrementing IDs */
+ FLUSH_ID_DF0_FIXED, /* DF=0, fixed ID */
+ FLUSH_ID_DF1_INC_FIXED, /* DF=1, two incrementing then one fixed (3 packets) */
+ FLUSH_ID_DF1_FIXED_INC, /* DF=1, two fixed then one incrementing (3 packets) */
+};
+
+/* Source/destination addresses written into the IP headers and TCP pseudo headers */
+static const char *addr6_src = "fdaa::2";
+static const char *addr6_dst = "fdaa::1";
+static const char *addr4_src = "192.168.1.200";
+static const char *addr4_dst = "192.168.1.100";
+/* Compared against PF_INET / PF_INET6 by the packet builders; -1 here,
+ * presumably assigned during option parsing (not visible in this chunk).
+ */
+static int proto = -1;
+static uint8_t src_mac[ETH_ALEN], dst_mac[ETH_ALEN];
+static char *testname = "data";
+static char *ifname = "eth0";
+static char *smac = "aa:00:00:00:00:02";
+static char *dmac = "aa:00:00:00:00:01";
+/* When set, vlog() prints to stderr */
+static bool verbose;
+static bool tx_socket = true;
+/* Byte offset of the TCP header from the start of the frame */
+static int tcp_offset = -1;
+/* Number of header bytes preceding the TCP payload */
+static int total_hdr_len = -1;
+/* Copied verbatim into ethhdr.h_proto by fill_datalinklayer() */
+static int ethhdr_proto = -1;
+static bool ipip;
+static bool ip6ip6;
+/* When non-zero, write_packet() attaches it as an SCM_TXTIME control message */
+static uint64_t txtime_ns;
+/* Number of flows send_capacity() iterates over */
+static int num_flows = 4;
+static bool order_check;
+
+/* TCP payload size, in bytes, of each capacity-test segment */
+#define CAPACITY_PAYLOAD_LEN 200
+
+#define TXTIME_DELAY_MS 5
+
+/* Largest TCP payload GRO will coalesce for the current header stack:
+ * IP_MAXPACKET minus every header byte that follows the Ethernet header,
+ * so deeper encapsulation leaves less room for payload.
+ */
+static int max_payload(void)
+{
+	const int post_eth_hdr_len = total_hdr_len - ETH_HLEN;
+
+	return IP_MAXPACKET - post_eth_hdr_len;
+}
+
+/* Per-segment payload size: ASSUMED_MTU minus the headers after Ethernet */
+static int calc_mss(void)
+{
+	const int post_eth_hdr_len = total_hdr_len - ETH_HLEN;
+
+	return ASSUMED_MTU - post_eth_hdr_len;
+}
+
+/* Number of full-MSS segments that fit into max_payload() */
+static int num_large_pkt(void)
+{
+	const int limit = max_payload();
+	const int mss = calc_mss();
+
+	return limit / mss;
+}
+
+/* printf-style logging to stderr; a no-op unless the verbose flag is set */
+static void vlog(const char *fmt, ...)
+{
+	va_list ap;
+
+	if (!verbose)
+		return;
+
+	va_start(ap, fmt);
+	vfprintf(stderr, fmt, ap);
+	va_end(ap);
+}
+
+/* Attach a classic BPF filter to @fd that accepts only frames whose
+ * EtherType matches ethhdr_proto, that carry TCP, and whose TCP destination
+ * port is DPORT. Exits via error() if the filter cannot be attached.
+ */
+static void setup_sock_filter(int fd)
+{
+ const int dport_off = tcp_offset + offsetof(struct tcphdr, dest);
+ const int ethproto_off = offsetof(struct ethhdr, h_proto);
+ int optlen = 0;
+ int ipproto_off, opt_ipproto_off;
+
+ /* Offset of the protocol / next-header byte of the IP header that
+ * immediately precedes the TCP header.
+ */
+ if (proto == PF_INET)
+ ipproto_off = tcp_offset - sizeof(struct iphdr) +
+ offsetof(struct iphdr, protocol);
+ else
+ ipproto_off = tcp_offset - sizeof(struct ipv6hdr) +
+ offsetof(struct ipv6hdr, nexthdr);
+
+ /* Overridden later if exthdrs are used: */
+ opt_ipproto_off = ipproto_off;
+
+ /* optlen is how far the TCP header is shifted in packets that carry
+ * the extra IPv4 option or IPv6 extension header of the running test.
+ */
+ if (strcmp(testname, "ip_opt") == 0) {
+ optlen = sizeof(struct ip_timestamp);
+ } else if (strcmp(testname, "ip_frag6") == 0 ||
+ strcmp(testname, "ip_v6ext_same") == 0 ||
+ strcmp(testname, "ip_v6ext_diff") == 0) {
+ BUILD_BUG_ON(sizeof(struct ip6_hbh) > MIN_EXTHDR_SIZE);
+ BUILD_BUG_ON(sizeof(struct ip6_dest) > MIN_EXTHDR_SIZE);
+ BUILD_BUG_ON(sizeof(struct ip6_frag) > MIN_EXTHDR_SIZE);
+
+ /* same size for HBH and Fragment extension header types */
+ optlen = MIN_EXTHDR_SIZE;
+ opt_ipproto_off = ETH_HLEN + sizeof(struct ipv6hdr)
+ + offsetof(struct ip6_ext, ip6e_nxt);
+ }
+
+ /* this filter validates the following:
+ * - packet is IPv4/IPv6 according to the running test.
+ * - packet is TCP. Also handles the case of one extension header and then TCP.
+ * - checks the packet tcp dport equals to DPORT. Also handles the case of one
+ * extension header and then TCP.
+ *
+ * Jump offsets are relative to the next instruction: a failed EtherType or
+ * protocol check lands on the final "return 0" (drop), and a port match at
+ * either offset lands on "return 0xFFFFFFFF" (accept).
+ */
+ struct sock_filter filter[] = {
+ BPF_STMT(BPF_LD + BPF_H + BPF_ABS, ethproto_off),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, ntohs(ethhdr_proto), 0, 9),
+ BPF_STMT(BPF_LD + BPF_B + BPF_ABS, ipproto_off),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, IPPROTO_TCP, 2, 0),
+ BPF_STMT(BPF_LD + BPF_B + BPF_ABS, opt_ipproto_off),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, IPPROTO_TCP, 0, 5),
+ BPF_STMT(BPF_LD + BPF_H + BPF_ABS, dport_off),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, DPORT, 2, 0),
+ BPF_STMT(BPF_LD + BPF_H + BPF_ABS, dport_off + optlen),
+ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, DPORT, 0, 1),
+ BPF_STMT(BPF_RET + BPF_K, 0xFFFFFFFF),
+ BPF_STMT(BPF_RET + BPF_K, 0),
+ };
+
+ struct sock_fprog bpf = {
+ .len = ARRAY_SIZE(filter),
+ .filter = filter,
+ };
+
+ if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER, &bpf, sizeof(bpf)) < 0)
+ error(1, errno, "error setting filter");
+}
+
+/* Accumulate the unfolded 16-bit one's-complement sum of @len bytes at @data
+ * on top of @sum and return the 32-bit running total.
+ *
+ * Words are read in memory order as uint16_t. A trailing odd byte is added
+ * as an unsigned value; reading it through plain char would sign-extend any
+ * byte >= 0x80 and corrupt the sum.
+ */
+static uint32_t checksum_nofold(void *data, size_t len, uint32_t sum)
+{
+	const uint16_t *words = data;
+	size_t i;
+
+	for (i = 0; i < len / 2; i++)
+		sum += words[i];
+	if (len & 1)
+		sum += ((const uint8_t *)data)[len - 1];
+	return sum;
+}
+
+/* Internet checksum of @len bytes at @data seeded with @sum: fold the 32-bit
+ * running sum into 16 bits and return its one's complement.
+ */
+static uint16_t checksum_fold(void *data, size_t len, uint32_t sum)
+{
+	uint32_t acc = checksum_nofold(data, len, sum);
+
+	/* keep adding the carries back in until none are left */
+	while (acc >> 16)
+		acc = (acc >> 16) + (acc & 0xFFFF);
+
+	return ~acc;
+}
+
+/* Compute the TCP checksum for the header at @buf followed in memory by
+ * @payload_len bytes of payload. The pseudo header is built from the
+ * configured source/destination addresses for the current proto; when proto
+ * is neither PF_INET nor PF_INET6 no pseudo header is added.
+ */
+static uint16_t tcp_checksum(void *buf, int payload_len)
+{
+	struct pseudo_header6 {
+		struct in6_addr saddr;
+		struct in6_addr daddr;
+		uint16_t protocol;
+		uint16_t payload_len;
+	} ph6;
+	struct pseudo_header4 {
+		struct in_addr saddr;
+		struct in_addr daddr;
+		uint16_t protocol;
+		uint16_t payload_len;
+	} ph4;
+	const size_t seg_len = sizeof(struct tcphdr) + payload_len;
+	uint32_t pseudo_sum = 0;
+
+	switch (proto) {
+	case PF_INET6:
+		if (inet_pton(AF_INET6, addr6_src, &ph6.saddr) != 1)
+			error(1, errno, "inet_pton6 source ip pseudo");
+		if (inet_pton(AF_INET6, addr6_dst, &ph6.daddr) != 1)
+			error(1, errno, "inet_pton6 dest ip pseudo");
+		ph6.protocol = htons(IPPROTO_TCP);
+		ph6.payload_len = htons(seg_len);
+
+		pseudo_sum = checksum_nofold(&ph6, sizeof(ph6), 0);
+		break;
+	case PF_INET:
+		if (inet_pton(AF_INET, addr4_src, &ph4.saddr) != 1)
+			error(1, errno, "inet_pton source ip pseudo");
+		if (inet_pton(AF_INET, addr4_dst, &ph4.daddr) != 1)
+			error(1, errno, "inet_pton dest ip pseudo");
+		ph4.protocol = htons(IPPROTO_TCP);
+		ph4.payload_len = htons(seg_len);
+
+		pseudo_sum = checksum_nofold(&ph4, sizeof(ph4), 0);
+		break;
+	default:
+		break;
+	}
+
+	return checksum_fold(buf, seg_len, pseudo_sum);
+}
+
+/* Parse the colon-separated hex MAC string @mac into the six bytes at
+ * @mac_addr; exits via error() unless all six octets are parsed.
+ */
+static void read_MAC(uint8_t *mac_addr, char *mac)
+{
+	int parsed;
+
+	parsed = sscanf(mac, "%hhx:%hhx:%hhx:%hhx:%hhx:%hhx",
+			mac_addr, mac_addr + 1, mac_addr + 2,
+			mac_addr + 3, mac_addr + 4, mac_addr + 5);
+	if (parsed != 6)
+		error(1, 0, "sscanf");
+}
+
+/* Write the Ethernet header at @buf from the configured MACs and EtherType */
+static void fill_datalinklayer(void *buf)
+{
+	struct ethhdr *l2 = buf;
+
+	l2->h_proto = ethhdr_proto;
+	memcpy(l2->h_source, src_mac, ETH_ALEN);
+	memcpy(l2->h_dest, dst_mac, ETH_ALEN);
+}
+
+/* Write an IPv4 or IPv6 header (selected by proto) at @buf.
+ * @payload_len: bytes following the TCP header; the length fields are
+ *               computed as sizeof(struct tcphdr) + @payload_len.
+ * @protocol:    value for the protocol / next-header field.
+ * Nothing is written when proto is neither PF_INET nor PF_INET6.
+ */
+static void fill_networklayer(void *buf, int payload_len, int protocol)
+{
+	if (proto == PF_INET6) {
+		struct ipv6hdr *hdr6 = buf;
+
+		memset(hdr6, 0, sizeof(*hdr6));
+
+		hdr6->version = 6;
+		hdr6->hop_limit = 8;
+		hdr6->nexthdr = protocol;
+		hdr6->payload_len = htons(sizeof(struct tcphdr) + payload_len);
+		if (inet_pton(AF_INET6, addr6_src, &hdr6->saddr) != 1)
+			error(1, errno, "inet_pton source ip6");
+		if (inet_pton(AF_INET6, addr6_dst, &hdr6->daddr) != 1)
+			error(1, errno, "inet_pton dest ip6");
+		return;
+	}
+
+	if (proto == PF_INET) {
+		struct iphdr *hdr4 = buf;
+
+		memset(hdr4, 0, sizeof(*hdr4));
+
+		hdr4->version = 4;
+		hdr4->ihl = 5;
+		hdr4->ttl = 8;
+		hdr4->protocol = protocol;
+		hdr4->frag_off = htons(0x4000); /* DF = 1, MF = 0 */
+		hdr4->tot_len = htons(sizeof(struct tcphdr) +
+				      payload_len + sizeof(struct iphdr));
+		if (inet_pton(AF_INET, addr4_src, &hdr4->saddr) != 1)
+			error(1, errno, "inet_pton source ip");
+		if (inet_pton(AF_INET, addr4_dst, &hdr4->daddr) != 1)
+			error(1, errno, "inet_pton dest ip");
+		/* check is still zero from the memset above */
+		hdr4->check = checksum_fold(buf, sizeof(struct iphdr), 0);
+	}
+}
+
+/* Write an option-less TCP header at @buf.
+ * @seq_offset / @ack_offset: added to START_SEQ / START_ACK.
+ * @payload_len: number of payload bytes that already follow the header in
+ *               memory; they are covered by the checksum computed here.
+ * @fin: value for the FIN flag. ACK is always set.
+ */
+static void fill_transportlayer(void *buf, int seq_offset, int ack_offset,
+				int payload_len, int fin)
+{
+	struct tcphdr *tcph = buf;
+
+	memset(tcph, 0, sizeof(*tcph));
+
+	tcph->source = htons(SPORT);
+	tcph->dest = htons(DPORT);
+	/* host-to-network conversion: these are on-the-wire fields */
+	tcph->seq = htonl(START_SEQ + seq_offset);
+	tcph->ack_seq = htonl(START_ACK + ack_offset);
+	tcph->ack = 1;
+	tcph->fin = fin;
+	tcph->doff = 5;
+	tcph->window = htons(TCP_MAXWIN);
+	tcph->urg_ptr = 0;
+	/* check is zero from the memset, as the checksum computation requires */
+	tcph->check = tcp_checksum(tcph, payload_len);
+}
+
+/* Send the @len byte frame at @buf to @daddr over @fd with sendmsg().
+ * When txtime_ns is non-zero it is attached as an SCM_TXTIME control
+ * message. Exits via error() on failure or on a short send.
+ */
+static void write_packet(int fd, char *buf, int len, struct sockaddr_ll *daddr)
+{
+	char control[CMSG_SPACE(sizeof(uint64_t))];
+	struct iovec iov = {
+		.iov_base = buf,
+		.iov_len = len,
+	};
+	struct msghdr msg = {
+		.msg_name = daddr,
+		.msg_namelen = sizeof(*daddr),
+		.msg_iov = &iov,
+		.msg_iovlen = 1,
+	};
+	int sent;
+
+	if (txtime_ns) {
+		struct cmsghdr *cm;
+
+		memset(control, 0, sizeof(control));
+		msg.msg_control = control;
+		msg.msg_controllen = sizeof(control);
+
+		cm = CMSG_FIRSTHDR(&msg);
+		cm->cmsg_level = SOL_SOCKET;
+		cm->cmsg_type = SCM_TXTIME;
+		cm->cmsg_len = CMSG_LEN(sizeof(uint64_t));
+		memcpy(CMSG_DATA(cm), &txtime_ns, sizeof(txtime_ns));
+	}
+
+	sent = sendmsg(fd, &msg, 0);
+	if (sent == -1)
+		error(1, errno, "sendmsg failure");
+	if (sent != len)
+		error(1, 0, "sendmsg wrong length: %d vs %d", sent, len);
+}
+
+/* Build a complete frame at @buf: Ethernet header, IP header (two of them
+ * when the inner IP header does not start right after Ethernet), TCP header
+ * and @payload_len bytes of 'a'.
+ */
+static void create_packet(void *buf, int seq_offset, int ack_offset,
+			  int payload_len, int fin)
+{
+	int l3_len, inner_off;
+
+	if (proto == PF_INET)
+		l3_len = sizeof(struct iphdr);
+	else
+		l3_len = sizeof(struct ipv6hdr);
+	inner_off = tcp_offset - l3_len;
+
+	/* the payload has to be in place before the TCP checksum is taken */
+	memset(buf, 0, total_hdr_len);
+	memset(buf + total_hdr_len, 'a', payload_len);
+
+	fill_transportlayer(buf + tcp_offset, seq_offset, ack_offset,
+			    payload_len, fin);
+	fill_networklayer(buf + inner_off, payload_len, IPPROTO_TCP);
+
+	if (inner_off > ETH_HLEN) {
+		int outer_proto;
+
+		if (proto == PF_INET)
+			outer_proto = IPPROTO_IPIP;
+		else
+			outer_proto = IPPROTO_IPV6;
+
+		/* the outer header also carries the inner IP header */
+		fill_networklayer(buf + ETH_HLEN,
+				  payload_len + l3_len, outer_proto);
+	}
+
+	fill_datalinklayer(buf);
+}
+
+/* Build capacity-test segment number @pkt_idx of flow @flow_id at @buf.
+ * Each flow gets its own source port (SPORT + @flow_id) and payload byte
+ * ('a' + @flow_id); @psh sets the PSH flag.
+ */
+static void create_capacity_packet(void *buf, int flow_id, int pkt_idx, int psh)
+{
+	struct tcphdr *th = buf + tcp_offset;
+
+	create_packet(buf, pkt_idx * CAPACITY_PAYLOAD_LEN, 0,
+		      CAPACITY_PAYLOAD_LEN, 0);
+
+	/* Customize for this flow id */
+	memset(buf + total_hdr_len, 'a' + flow_id, CAPACITY_PAYLOAD_LEN);
+	th->source = htons(SPORT + flow_id);
+	th->psh = psh;
+
+	/* port, flag and payload changed, so redo the checksum from zero */
+	th->check = 0;
+	th->check = tcp_checksum(th, CAPACITY_PAYLOAD_LEN);
+}
+
+/* Capacity test: two segments per flow, all first segments before any
+ * second segment:
+ *   A1 B1 C1 D1 ... A2 B2 C2 D2 ...
+ * Only the second round carries PSH (to flush).
+ */
+static void send_capacity(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + CAPACITY_PAYLOAD_LEN];
+	const int frame_len = total_hdr_len + CAPACITY_PAYLOAD_LEN;
+	int round, flow;
+
+	/* round 0: first segment, PSH clear; round 1: second segment, PSH set */
+	for (round = 0; round < 2; round++) {
+		for (flow = 0; flow < num_flows; flow++) {
+			create_capacity_packet(buf, flow, round, round);
+			write_packet(fd, buf, frame_len, daddr);
+		}
+	}
+}
+
+#ifndef TH_CWR
+#define TH_CWR 0x80
+#endif
+/* Set the PSH/SYN/RST/URG/CWR flags of @tcph to the given values and
+ * recompute the TCP checksum over the header plus @payload_len bytes.
+ */
+static void set_flags(struct tcphdr *tcph, int payload_len, int psh, int syn,
+		      int rst, int urg, int cwr)
+{
+	tcph->th_flags = cwr ? (tcph->th_flags | TH_CWR) :
+			       (tcph->th_flags & ~TH_CWR);
+	tcph->urg = urg;
+	tcph->rst = rst;
+	tcph->syn = syn;
+	tcph->psh = psh;
+
+	tcph->check = 0;
+	tcph->check = tcp_checksum(tcph, payload_len);
+}
+
+/* Send NUM_PACKETS + 1 segments where the one at index NUM_PACKETS / 2
+ * carries the requested flags; for cwr the one at index NUM_PACKETS / 2 - 1
+ * is flagged as well. Neither the first nor the last packet is flagged.
+ *
+ * Flagged packets carry PAYLOAD_LEN bytes only for psh or cwr; for the
+ * other flags they are sent with no payload.
+ */
+static void send_flags(int fd, struct sockaddr_ll *daddr, int psh, int syn,
+ int rst, int urg, int cwr)
+{
+ static char flag_buf[2][MAX_HDR_LEN + PAYLOAD_LEN];
+ static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+ int payload_len, pkt_size, i;
+ struct tcphdr *tcph;
+ int flag[2];
+
+ /* (psh || cwr) evaluates to 0 or 1, so this is either 0 or PAYLOAD_LEN */
+ payload_len = PAYLOAD_LEN * (psh || cwr);
+ pkt_size = total_hdr_len + payload_len;
+ flag[0] = NUM_PACKETS / 2;
+ flag[1] = NUM_PACKETS / 2 - 1;
+
+ /* Create and configure packets with flags
+ */
+ for (i = 0; i < 2; i++) {
+ if (flag[i] > 0) {
+ create_packet(flag_buf[i], flag[i] * payload_len, 0,
+ payload_len, 0);
+ tcph = (struct tcphdr *)(flag_buf[i] + tcp_offset);
+ set_flags(tcph, payload_len, psh, syn, rst, urg, cwr);
+ }
+ }
+
+ /* Plain data segments everywhere except at the flagged positions */
+ for (i = 0; i < NUM_PACKETS + 1; i++) {
+ if (i == flag[0]) {
+ write_packet(fd, flag_buf[0], pkt_size, daddr);
+ continue;
+ } else if (i == flag[1] && cwr) {
+ write_packet(fd, flag_buf[1], pkt_size, daddr);
+ continue;
+ }
+ create_packet(buf, i * PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+ write_packet(fd, buf, total_hdr_len + PAYLOAD_LEN, daddr);
+ }
+}
+
+/* Send two consecutive data segments of @payload_len1 and @payload_len2
+ * bytes; used for the same-length, large-then-small and small-then-large
+ * cases.
+ */
+static void send_data_pkts(int fd, struct sockaddr_ll *daddr,
+			   int payload_len1, int payload_len2)
+{
+	static char buf[ETH_HLEN + IP_MAXPACKET];
+	const int lens[2] = { payload_len1, payload_len2 };
+	int seq_off = 0;
+	int i;
+
+	for (i = 0; i < 2; i++) {
+		create_packet(buf, seq_off, 0, lens[i], 0);
+		write_packet(fd, buf, total_hdr_len + lens[i], daddr);
+		seq_off += lens[i];
+	}
+}
+
+/* Segments that would push the tracked segment past the legal IP datagram
+ * length must not be coalesced: send num_large_pkt() full-MSS segments
+ * followed by two segments of @remainder bytes each.
+ */
+static void send_large(int fd, struct sockaddr_ll *daddr, int remainder)
+{
+	static char pkts[MAX_LARGE_PKT_CNT][MAX_HDR_LEN + MAX_MSS];
+	static char new_seg[MAX_HDR_LEN + MAX_MSS];
+	static char last[MAX_HDR_LEN + MAX_MSS];
+	const int mss = calc_mss();
+	const int full_cnt = num_large_pkt();
+	const int full_len = total_hdr_len + mss;
+	const int tail_len = total_hdr_len + remainder;
+	int idx;
+
+	/* Build every frame up front, then transmit them in order */
+	for (idx = 0; idx < full_cnt; idx++)
+		create_packet(pkts[idx], idx * mss, 0, mss, 0);
+	create_packet(last, full_cnt * mss, 0, remainder, 0);
+	create_packet(new_seg, (full_cnt + 1) * mss, 0, remainder, 0);
+
+	for (idx = 0; idx < full_cnt; idx++)
+		write_packet(fd, pkts[idx], full_len, daddr);
+	write_packet(fd, last, tail_len, daddr);
+	write_packet(fd, new_seg, tail_len, daddr);
+}
+
+/* Pure ACKs and duplicate ACKs don't coalesce: send the same payload-less
+ * ACK twice, then one with the ack number advanced by one.
+ */
+static void send_ack(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN];
+	int dup;
+
+	create_packet(buf, 0, 0, 0, 0);
+	for (dup = 0; dup < 2; dup++)
+		write_packet(fd, buf, total_hdr_len, daddr);
+
+	create_packet(buf, 0, 1, 0, 0);
+	write_packet(fd, buf, total_hdr_len, daddr);
+}
+
+/* Assemble in @buf a copy of the option-less packet @no_ext with @extlen
+ * bytes of TCP options between the headers and the PAYLOAD_LEN byte payload.
+ * The option bytes themselves must already be in @buf at total_hdr_len.
+ * Fixes up the TCP data offset and checksum, and the length field (plus the
+ * IPv4 header checksum) of every IP header in front of the TCP header.
+ */
+static void recompute_packet(char *buf, char *no_ext, int extlen)
+{
+	struct tcphdr *th = (struct tcphdr *)(buf + tcp_offset);
+	const bool is_v4 = (proto == PF_INET);
+	int off;
+
+	memmove(buf, no_ext, total_hdr_len);
+	memmove(buf + total_hdr_len + extlen,
+		no_ext + total_hdr_len, PAYLOAD_LEN);
+
+	/* doff counts 32-bit words */
+	th->doff += extlen / 4;
+	th->check = 0;
+	th->check = tcp_checksum(th, PAYLOAD_LEN + extlen);
+
+	off = ETH_HLEN;
+	while (off < tcp_offset) {
+		if (is_v4) {
+			struct iphdr *ip4 = (struct iphdr *)(buf + off);
+
+			ip4->tot_len = htons(ntohs(ip4->tot_len) + extlen);
+			ip4->check = 0;
+			ip4->check = checksum_fold(ip4, sizeof(struct iphdr), 0);
+			off += sizeof(struct iphdr);
+		} else {
+			struct ipv6hdr *ip6 = (struct ipv6hdr *)(buf + off);
+
+			ip6->payload_len =
+				htons(ntohs(ip6->payload_len) + extlen);
+			off += sizeof(struct ipv6hdr);
+		}
+	}
+}
+
+/* Write a single TCP option of type @kind at @buf.
+ *
+ * TCPOPT_NOP:       1 byte.
+ * TCPOPT_WINDOW:    TCPOLEN_WINDOW (3) bytes, shift count 0.
+ * TCPOPT_TIMESTAMP: TCPOLEN_TIMESTAMP (10) bytes: kind, length, TSval = @ts
+ *                   in network byte order, TSecr = 0.
+ * Any other @kind exits via error(). @ts is ignored for the other kinds.
+ *
+ * The option is written byte by byte: a naturally aligned struct with two
+ * uint8_t followed by uint32_t members contains padding, so it would put
+ * TSval at offset 4 instead of 2, occupy 12 bytes instead of 10, and require
+ * an alignment that @buf does not guarantee.
+ */
+static void tcp_write_options(char *buf, int kind, int ts)
+{
+	uint32_t tsval;
+
+	switch (kind) {
+	case TCPOPT_NOP:
+		buf[0] = TCPOPT_NOP;
+		break;
+	case TCPOPT_WINDOW:
+		buf[0] = TCPOPT_WINDOW;
+		buf[1] = TCPOLEN_WINDOW;
+		buf[2] = 0; /* shift count */
+		break;
+	case TCPOPT_TIMESTAMP:
+		tsval = htonl(ts);
+		memset(buf, 0, TCPOLEN_TIMESTAMP);
+		buf[0] = TCPOPT_TIMESTAMP;
+		buf[1] = TCPOLEN_TIMESTAMP;
+		memcpy(buf + 2, &tsval, sizeof(tsval));
+		/* TSecr, bytes 6..9, stays zero */
+		break;
+	default:
+		error(1, 0, "unimplemented TCP option");
+		break;
+	}
+}
+
+/* TCP with options is always a permutation of {TS, NOP, NOP}.
+ * Implement different orders to verify coalescing stops.
+ *
+ * Writes the TCPOLEN_TSTAMP_APPA (12) option bytes into @buf at
+ * total_hdr_len, then calls recompute_packet() to place the headers and
+ * payload of @no_ext around them.
+ * @ts:    timestamp value passed to tcp_write_options()
+ * @order: 0 = NOP NOP TS, 1 = NOP TS NOP, 2 = TS NOP NOP
+ */
+static void add_standard_tcp_options(char *buf, char *no_ext, int ts, int order)
+{
+	char *opts = buf + total_hdr_len;
+
+	switch (order) {
+	case 0:
+		tcp_write_options(opts, TCPOPT_NOP, 0);
+		tcp_write_options(opts + 1, TCPOPT_NOP, 0);
+		tcp_write_options(opts + 2 /* two NOP opts */,
+				  TCPOPT_TIMESTAMP, ts);
+		break;
+	case 1:
+		tcp_write_options(opts, TCPOPT_NOP, 0);
+		tcp_write_options(opts + 1, TCPOPT_TIMESTAMP, ts);
+		tcp_write_options(opts + 1 + TCPOLEN_TIMESTAMP,
+				  TCPOPT_NOP, 0);
+		break;
+	case 2:
+		/* The two NOPs directly follow the 10-byte TS option so that
+		 * all three options stay inside the 12-byte option area.
+		 */
+		tcp_write_options(opts, TCPOPT_TIMESTAMP, ts);
+		tcp_write_options(opts + TCPOLEN_TIMESTAMP, TCPOPT_NOP, 0);
+		tcp_write_options(opts + TCPOLEN_TIMESTAMP + 1,
+				  TCPOPT_NOP, 0);
+		break;
+	default:
+		error(1, 0, "unknown order");
+		break;
+	}
+	recompute_packet(buf, no_ext, TCPOLEN_TSTAMP_APPA);
+}
+
+/* Packets with invalid checksum don't coalesce: send one good segment, then
+ * the next one with its TCP checksum decremented by one.
+ */
+static void send_changed_checksum(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	const int frame_len = total_hdr_len + PAYLOAD_LEN;
+	struct tcphdr *th = (struct tcphdr *)(buf + tcp_offset);
+
+	create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, buf, frame_len, daddr);
+
+	create_packet(buf, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	th->check -= 1;
+	write_packet(fd, buf, frame_len, daddr);
+}
+
+/* Packets with incorrect IPv4 header checksum don't coalesce: send three
+ * consecutive segments, the middle one with its IP checksum decremented.
+ */
+static void send_changed_ip_checksum(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	const int frame_len = total_hdr_len + PAYLOAD_LEN;
+	struct iphdr *ip4 = (struct iphdr *)(buf + ETH_HLEN);
+	int n;
+
+	for (n = 0; n < 3; n++) {
+		create_packet(buf, PAYLOAD_LEN * n, 0, PAYLOAD_LEN, 0);
+		if (n == 1)
+			ip4->check -= 1;
+		write_packet(fd, buf, frame_len, daddr);
+	}
+}
+
+/* Packets with non-consecutive sequence number don't coalesce: the second
+ * segment's sequence number is one higher than the consecutive value.
+ */
+static void send_changed_seq(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	struct tcphdr *tcph = (struct tcphdr *)(buf + tcp_offset);
+	int pkt_size = total_hdr_len + PAYLOAD_LEN;
+
+	create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, buf, pkt_size, daddr);
+
+	create_packet(buf, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	/* seq is stored in network order: convert to host order, bump by
+	 * one, convert back. Doing the arithmetic on the network-order
+	 * value would change a different byte on little-endian hosts.
+	 */
+	tcph->seq = htonl(ntohl(tcph->seq) + 1);
+	tcph->check = 0;
+	tcph->check = tcp_checksum(tcph, PAYLOAD_LEN);
+	write_packet(fd, buf, pkt_size, daddr);
+}
+
+/* Packets with a different timestamp option layout or different timestamp
+ * values don't coalesce. Sends five consecutive segments: two with ts 0,
+ * then three with ts 100 using option orders 0, 1 and 2.
+ */
+static void send_changed_ts(int fd, struct sockaddr_ll *daddr)
+{
+	static const struct {
+		int ts;
+		int order;
+	} variants[] = {
+		{ 0, 0 },
+		{ 0, 0 },
+		{ 100, 0 },
+		{ 100, 1 },
+		{ 100, 2 },
+	};
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	static char extpkt[sizeof(buf) + TCPOLEN_TSTAMP_APPA];
+	const int frame_len = total_hdr_len + PAYLOAD_LEN + TCPOLEN_TSTAMP_APPA;
+	int n;
+
+	for (n = 0; n < (int)ARRAY_SIZE(variants); n++) {
+		create_packet(buf, PAYLOAD_LEN * n, 0, PAYLOAD_LEN, 0);
+		add_standard_tcp_options(extpkt, buf, variants[n].ts,
+					 variants[n].order);
+		write_packet(fd, extpkt, frame_len, daddr);
+	}
+}
+
+/* Packets with different TCP options don't coalesce: send two segments
+ * carrying {NOP, NOP, TS} followed by one carrying {NOP, WINDOW}.
+ */
+static void send_diff_opt(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	static char extpkt1[sizeof(buf) + TCPOLEN_TSTAMP_APPA];
+	static char extpkt2[sizeof(buf) + TCPOLEN_MAXSEG];
+	int extpkt1_size = total_hdr_len + PAYLOAD_LEN + TCPOLEN_TSTAMP_APPA;
+	int extpkt2_size = total_hdr_len + PAYLOAD_LEN + TCPOLEN_MAXSEG;
+
+	create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+	add_standard_tcp_options(extpkt1, buf, 0, 0);
+	write_packet(fd, extpkt1, extpkt1_size, daddr);
+
+	create_packet(buf, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	add_standard_tcp_options(extpkt1, buf, 0, 0);
+	write_packet(fd, extpkt1, extpkt1_size, daddr);
+
+	create_packet(buf, PAYLOAD_LEN * 2, 0, PAYLOAD_LEN, 0);
+	/* The option area starts right after the headers actually in use
+	 * (total_hdr_len), which is where recompute_packet() leaves the gap;
+	 * MAX_HDR_LEN only matches that for the largest header stack.
+	 */
+	tcp_write_options(extpkt2 + total_hdr_len, TCPOPT_NOP, 0);
+	tcp_write_options(extpkt2 + total_hdr_len + 1, TCPOPT_WINDOW, 0);
+	recompute_packet(extpkt2, buf, TCPOLEN_WINDOW + 1);
+	write_packet(fd, extpkt2, extpkt2_size, daddr);
+}
+
+/* Build in @optpkt a copy of the packet at @buf with an IPv4 timestamp
+ * option (struct ip_timestamp) inserted at tcp_offset, then fix up the
+ * header length, total length and checksum of the IPv4 header located
+ * right after the Ethernet header.
+ */
+static void add_ipv4_ts_option(void *buf, void *optpkt)
+{
+	const int optlen = sizeof(struct ip_timestamp);
+	struct iphdr *ip4 = (struct iphdr *)(optpkt + ETH_HLEN);
+	struct ip_timestamp *opt;
+
+	if (optlen % 4 != 0)
+		error(1, 0, "ipv4 timestamp length is not a multiple of 4B");
+
+	/* headers first, then TCP header + payload shifted by optlen */
+	memcpy(optpkt, buf, tcp_offset);
+	memcpy(optpkt + tcp_offset + optlen, buf + tcp_offset,
+	       sizeof(struct tcphdr) + PAYLOAD_LEN);
+
+	opt = (struct ip_timestamp *)(optpkt + tcp_offset);
+	opt->ipt_code = IPOPT_TS;
+	opt->ipt_len = optlen;
+	opt->ipt_ptr = 5;
+	opt->ipt_flg = IPOPT_TS_TSONLY;
+
+	ip4->ihl = 5 + (optlen / 4);
+	ip4->tot_len = htons(ntohs(ip4->tot_len) + optlen);
+	ip4->check = 0;
+	ip4->check = checksum_fold(ip4, sizeof(struct iphdr) + optlen, 0);
+}
+
+/* Build in @optpkt a copy of the packet at @buf with one MIN_EXTHDR_SIZE
+ * byte IPv6 extension header of type @exthdr_type inserted at tcp_offset.
+ * The bytes after the 2-byte extension header start are copied from
+ * @ext_payload. The IPv6 header right after the Ethernet header gets its
+ * next-header and payload length updated.
+ */
+static void add_ipv6_exthdr(void *buf, void *optpkt, __u8 exthdr_type, char *ext_payload)
+{
+	struct ipv6hdr *ip6 = (struct ipv6hdr *)(optpkt + ETH_HLEN);
+	struct ipv6_opt_hdr *ext;
+
+	/* headers first, then TCP header + payload shifted past the exthdr */
+	memcpy(optpkt, buf, tcp_offset);
+	memcpy(optpkt + tcp_offset + MIN_EXTHDR_SIZE, buf + tcp_offset,
+	       sizeof(struct tcphdr) + PAYLOAD_LEN);
+
+	ext = (struct ipv6_opt_hdr *)(optpkt + tcp_offset);
+	ext->nexthdr = IPPROTO_TCP;
+	ext->hdrlen = 0;
+	memcpy((char *)(ext + 1), ext_payload, MIN_EXTHDR_SIZE - sizeof(*ext));
+
+	ip6->nexthdr = exthdr_type;
+	ip6->payload_len = htons(ntohs(ip6->payload_len) + MIN_EXTHDR_SIZE);
+}
+
+/* Recompute the checksum of the option-less IPv4 header @iph in place */
+static void fix_ip4_checksum(struct iphdr *iph)
+{
+	uint16_t csum;
+
+	/* the checksum field itself must be zero while summing */
+	iph->check = 0;
+	csum = checksum_fold(iph, sizeof(*iph), 0);
+	iph->check = csum;
+}
+
+/* Send two (or, for the last two cases, three) consecutive segments whose
+ * IPv4 DF bit and ID fields are set up according to @tcase. The IPv4 header
+ * checksums are recomputed after the fields are modified.
+ */
+static void send_flush_id_case(int fd, struct sockaddr_ll *daddr,
+ enum flush_id_case tcase)
+{
+ static char buf1[MAX_HDR_LEN + PAYLOAD_LEN];
+ static char buf2[MAX_HDR_LEN + PAYLOAD_LEN];
+ static char buf3[MAX_HDR_LEN + PAYLOAD_LEN];
+ bool send_three = false;
+ struct iphdr *iph1;
+ struct iphdr *iph2;
+ struct iphdr *iph3;
+
+ /* IPv4 header located directly after the Ethernet header of each frame */
+ iph1 = (struct iphdr *)(buf1 + ETH_HLEN);
+ iph2 = (struct iphdr *)(buf2 + ETH_HLEN);
+ iph3 = (struct iphdr *)(buf3 + ETH_HLEN);
+
+ create_packet(buf1, 0, 0, PAYLOAD_LEN, 0);
+ create_packet(buf2, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+ create_packet(buf3, PAYLOAD_LEN * 2, 0, PAYLOAD_LEN, 0);
+
+ switch (tcase) {
+ case FLUSH_ID_DF1_INC: /* DF=1, Incrementing - should coalesce */
+ iph1->frag_off |= htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off |= htons(IP_DF);
+ iph2->id = htons(9);
+ break;
+
+ case FLUSH_ID_DF1_FIXED: /* DF=1, Fixed - should coalesce */
+ iph1->frag_off |= htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off |= htons(IP_DF);
+ iph2->id = htons(8);
+ break;
+
+ case FLUSH_ID_DF0_INC: /* DF=0, Incrementing - should coalesce */
+ iph1->frag_off &= ~htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off &= ~htons(IP_DF);
+ iph2->id = htons(9);
+ break;
+
+ case FLUSH_ID_DF0_FIXED: /* DF=0, Fixed - should coalesce */
+ iph1->frag_off &= ~htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off &= ~htons(IP_DF);
+ iph2->id = htons(8);
+ break;
+
+ case FLUSH_ID_DF1_INC_FIXED: /* DF=1, two packets incrementing, and
+ * one fixed - should coalesce only the
+ * first two packets
+ */
+ iph1->frag_off |= htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off |= htons(IP_DF);
+ iph2->id = htons(9);
+
+ iph3->frag_off |= htons(IP_DF);
+ iph3->id = htons(9);
+ send_three = true;
+ break;
+
+ case FLUSH_ID_DF1_FIXED_INC: /* DF=1, two packets fixed, and one
+ * incrementing - should coalesce only
+ * the first two packets
+ */
+ iph1->frag_off |= htons(IP_DF);
+ iph1->id = htons(8);
+
+ iph2->frag_off |= htons(IP_DF);
+ iph2->id = htons(8);
+
+ iph3->frag_off |= htons(IP_DF);
+ iph3->id = htons(9);
+ send_three = true;
+ break;
+ }
+
+ /* frag_off/id changed above, so the header checksums must be redone */
+ fix_ip4_checksum(iph1);
+ fix_ip4_checksum(iph2);
+ write_packet(fd, buf1, total_hdr_len + PAYLOAD_LEN, daddr);
+ write_packet(fd, buf2, total_hdr_len + PAYLOAD_LEN, daddr);
+
+ if (send_three) {
+ fix_ip4_checksum(iph3);
+ write_packet(fd, buf3, total_hdr_len + PAYLOAD_LEN, daddr);
+ }
+}
+
+/* Send two consecutive segments, each carrying an IPv6 Destination Options
+ * extension header; the first uses @ext_data1 as the header body, the
+ * second @ext_data2.
+ */
+static void send_ipv6_exthdr(int fd, struct sockaddr_ll *daddr, char *ext_data1, char *ext_data2)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	static char exthdr_pck[sizeof(buf) + MIN_EXTHDR_SIZE];
+	char *bodies[2] = { ext_data1, ext_data2 };
+	const int frame_len = total_hdr_len + PAYLOAD_LEN + MIN_EXTHDR_SIZE;
+	int n;
+
+	for (n = 0; n < 2; n++) {
+		create_packet(buf, PAYLOAD_LEN * n, 0, PAYLOAD_LEN, 0);
+		add_ipv6_exthdr(buf, exthdr_pck, IPPROTO_DSTOPTS, bodies[n]);
+		write_packet(fd, exthdr_pck, frame_len, daddr);
+	}
+}
+
+/* IPv4 options shouldn't coalesce: send three consecutive segments where
+ * only the middle one carries an IPv4 timestamp option.
+ */
+static void send_ip_options(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+	static char optpkt[sizeof(buf) + sizeof(struct ip_timestamp)];
+	const int plain_len = total_hdr_len + PAYLOAD_LEN;
+	const int opt_len = plain_len + (int)sizeof(struct ip_timestamp);
+	int n;
+
+	for (n = 0; n < 3; n++) {
+		create_packet(buf, PAYLOAD_LEN * n, 0, PAYLOAD_LEN, 0);
+		if (n != 1) {
+			write_packet(fd, buf, plain_len, daddr);
+			continue;
+		}
+		add_ipv4_ts_option(buf, optpkt);
+		write_packet(fd, optpkt, opt_len, daddr);
+	}
+}
+
+/* IPv4 fragments shouldn't coalesce: send one regular segment, then a
+ * second frame of the same size marked as a first fragment (MF set).
+ */
+static void send_fragment4(int fd, struct sockaddr_ll *daddr)
+{
+	static char buf[IP_MAXPACKET];
+	const int frame_len = total_hdr_len + PAYLOAD_LEN;
+	struct iphdr *ip4 = (struct iphdr *)(buf + ETH_HLEN);
+
+	create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, buf, frame_len, daddr);
+
+	/* Once fragmented, packet would retain the total_len.
+	 * Tcp header is prepared as if rest of data is in follow-up frags,
+	 * but follow up frags aren't actually sent.
+	 */
+	memset(buf + total_hdr_len, 'a', PAYLOAD_LEN * 2);
+	fill_transportlayer(buf + tcp_offset, PAYLOAD_LEN, 0, PAYLOAD_LEN * 2, 0);
+	fill_networklayer(buf + ETH_HLEN, PAYLOAD_LEN, IPPROTO_TCP);
+	fill_datalinklayer(buf);
+
+	ip4->frag_off = htons(0x6000); /* DF = 1, MF = 1 */
+	fix_ip4_checksum(ip4);
+	write_packet(fd, buf, frame_len, daddr);
+}
+
+/* IPv4 packets with different ttl don't coalesce: the second packet is sent
+ * with its ttl overwritten to 7.
+ */
+static void send_changed_ttl(int fd, struct sockaddr_ll *daddr)
+{
+	static char frame[MAX_HDR_LEN + PAYLOAD_LEN];
+	struct iphdr *ip4 = (struct iphdr *)(frame + ETH_HLEN);
+	int frame_len = total_hdr_len + PAYLOAD_LEN;
+
+	create_packet(frame, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, frame, frame_len, daddr);
+
+	create_packet(frame, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	ip4->ttl = 7;
+	/* the ttl is covered by the IP header checksum, recompute it */
+	ip4->check = 0;
+	ip4->check = checksum_fold(ip4, sizeof(struct iphdr), 0);
+	write_packet(fd, frame, frame_len, daddr);
+}
+
+/* Packets with different tos don't coalesce: the second packet gets
+ * tos = 1 (IPv4) or priority = 0xf (IPv6).
+ */
+static void send_changed_tos(int fd, struct sockaddr_ll *daddr)
+{
+	static char frame[MAX_HDR_LEN + PAYLOAD_LEN];
+	/* both views alias the same network header; proto picks the valid one */
+	struct iphdr *ip4 = (struct iphdr *)(frame + ETH_HLEN);
+	struct ipv6hdr *ip6 = (struct ipv6hdr *)(frame + ETH_HLEN);
+	int frame_len = total_hdr_len + PAYLOAD_LEN;
+
+	create_packet(frame, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, frame, frame_len, daddr);
+
+	create_packet(frame, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	switch (proto) {
+	case PF_INET:
+		ip4->tos = 1;
+		ip4->check = 0;
+		ip4->check = checksum_fold(ip4, sizeof(struct iphdr), 0);
+		break;
+	case PF_INET6:
+		ip6->priority = 0xf;
+		break;
+	default:
+		break;
+	}
+	write_packet(fd, frame, frame_len, daddr);
+}
+
+/* Packets with different ECN don't coalesce: the second packet has one ECN
+ * bit flipped in the second byte of its network header.
+ */
+static void send_changed_ECN(int fd, struct sockaddr_ll *daddr)
+{
+	static char frame[MAX_HDR_LEN + PAYLOAD_LEN];
+	struct iphdr *ip4 = (struct iphdr *)(frame + ETH_HLEN);
+	char *ecn_byte = &frame[ETH_HLEN + 1];
+	int frame_len = total_hdr_len + PAYLOAD_LEN;
+
+	create_packet(frame, 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, frame, frame_len, daddr);
+
+	create_packet(frame, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+	if (proto == PF_INET) {
+		*ecn_byte ^= 0x2; // ECN set to 10
+		ip4->check = 0;
+		ip4->check = checksum_fold(ip4, sizeof(struct iphdr), 0);
+	} else {
+		*ecn_byte ^= 0x20; // ECN set to 10
+	}
+	write_packet(fd, frame, frame_len, daddr);
+}
+
+/* IPv6 fragments and packets with extensions don't coalesce.
+ * Sends two plain packets, waits one second, then sends a packet with an
+ * IPv6 fragment header inserted, followed by one more plain packet.
+ */
+static void send_fragment6(int fd, struct sockaddr_ll *daddr)
+{
+	static char plain[MAX_HDR_LEN + PAYLOAD_LEN];
+	static char fragpkt[MAX_HDR_LEN + PAYLOAD_LEN +
+			    sizeof(struct ip6_frag)];
+	struct ipv6hdr *ip6h = (struct ipv6hdr *)(plain + ETH_HLEN);
+	struct ip6_frag *frag_hdr = (void *)(fragpkt + tcp_offset);
+	int frag_hdr_len = sizeof(struct ip6_frag);
+	int plain_len = total_hdr_len + PAYLOAD_LEN;
+	int fragpkt_len = plain_len + frag_hdr_len;
+
+	/* two plain packets */
+	create_packet(plain, PAYLOAD_LEN * 0, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, plain, plain_len, daddr);
+	create_packet(plain, PAYLOAD_LEN * 1, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, plain, plain_len, daddr);
+
+	sleep(1);
+	create_packet(plain, PAYLOAD_LEN * 2, 0, PAYLOAD_LEN, 0);
+	memset(fragpkt, 0, fragpkt_len);
+
+	/* chain the fragment header between the IPv6 and TCP headers */
+	ip6h->nexthdr = IPPROTO_FRAGMENT;
+	ip6h->payload_len = htons(ntohs(ip6h->payload_len) + frag_hdr_len);
+	frag_hdr->ip6f_nxt = IPPROTO_TCP;
+
+	/* headers up to tcp_offset first, then TCP + payload after the gap */
+	memcpy(fragpkt, plain, tcp_offset);
+	memcpy(fragpkt + tcp_offset + frag_hdr_len, plain + tcp_offset,
+	       sizeof(struct tcphdr) + PAYLOAD_LEN);
+	write_packet(fd, fragpkt, fragpkt_len, daddr);
+
+	create_packet(plain, PAYLOAD_LEN * 3, 0, PAYLOAD_LEN, 0);
+	write_packet(fd, plain, plain_len, daddr);
+}
+
+/* Bind the packet socket fd to interface ifname for protocol ethhdr_proto.
+ * Exits through error() if the interface lookup or bind() fails.
+ */
+static void bind_packetsocket(int fd)
+{
+	struct sockaddr_ll addr = {
+		.sll_family = AF_PACKET,
+		.sll_protocol = ethhdr_proto,
+		.sll_ifindex = if_nametoindex(ifname),
+	};
+
+	if (!addr.sll_ifindex)
+		error(1, errno, "if_nametoindex");
+
+	if (bind(fd, (void *)&addr, sizeof(addr)) < 0)
+		error(1, errno, "could not bind socket");
+}
+
+/* Give fd a 3 second receive timeout (SO_RCVTIMEO); exits on failure. */
+static void set_timeout(int fd)
+{
+	struct timeval tv = {
+		.tv_sec = 3,
+		.tv_usec = 0,
+	};
+
+	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (char *)&tv,
+		       sizeof(tv)) < 0)
+		error(1, errno, "cannot set timeout, setsockopt failed");
+}
+
+/* Request a 1 MB receive buffer (SO_RCVBUF) on fd; exits on failure. */
+static void set_rcvbuf(int fd)
+{
+	int rcvbuf_bytes = 1 * 1024 * 1024; /* 1 MB */
+	int ret;
+
+	ret = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &rcvbuf_bytes,
+			 sizeof(rcvbuf_bytes));
+	if (ret)
+		error(1, errno, "cannot set rcvbuf size, setsockopt failed");
+}
+
+/* Report a failed recv(): print the packet socket's packet/drop counters to
+ * stderr, then exit through error() using rcv_errno.
+ */
+static void recv_error(int fd, int rcv_errno)
+{
+	struct tpacket_stats pkt_stats;
+	socklen_t optlen = sizeof(pkt_stats);
+
+	if (getsockopt(fd, SOL_PACKET, PACKET_STATISTICS, &pkt_stats,
+		       &optlen) != 0)
+		error(1, errno, "can't get stats");
+
+	fprintf(stderr, "Socket stats: packets=%u, drops=%u\n",
+		pkt_stats.tp_packets, pkt_stats.tp_drops);
+	error(1, rcv_errno, "could not receive");
+}
+
+/* Receive packets on fd until one with the TCP FIN flag set arrives, and
+ * compare the payload length of each packet against the expectation.
+ *
+ * @fd: packet socket to receive from
+ * @correct_payload: expected payload length of each packet, in arrival
+ *	order; only the first @correct_num_pkts entries are consulted
+ * @correct_num_pkts: number of packets expected before the FIN
+ *
+ * Exits through error() when the packet count or any payload length is
+ * wrong, and through recv_error() when recv() fails.
+ */
+static void check_recv_pkts(int fd, int *correct_payload,
+			    int correct_num_pkts)
+{
+	static char buffer[IP_MAXPACKET + ETH_HLEN + 1];
+	struct iphdr *iph = (struct iphdr *)(buffer + ETH_HLEN);
+	struct ipv6hdr *ip6h = (struct ipv6hdr *)(buffer + ETH_HLEN);
+	struct tcphdr *tcph;
+	bool bad_packet = false;
+	int tcp_ext_len = 0;
+	int ip_ext_len = 0;
+	int pkt_size = -1;
+	int data_len = 0;
+	int num_pkt = 0;
+	int i;
+
+	vlog("Expected {");
+	for (i = 0; i < correct_num_pkts; i++)
+		vlog("%d ", correct_payload[i]);
+	vlog("}, Total %d packets\nReceived {", correct_num_pkts);
+
+	while (1) {
+		ip_ext_len = 0;
+		pkt_size = recv(fd, buffer, IP_MAXPACKET + ETH_HLEN + 1, 0);
+		if (pkt_size < 0)
+			recv_error(fd, errno);
+
+		/* account for IPv4 options / a minimal IPv6 extension header */
+		if (iph->version == 4)
+			ip_ext_len = (iph->ihl - 5) * 4;
+		else if (ip6h->version == 6 && !ip6ip6 &&
+			 ip6h->nexthdr != IPPROTO_TCP)
+			ip_ext_len = MIN_EXTHDR_SIZE;
+
+		tcph = (struct tcphdr *)(buffer + tcp_offset + ip_ext_len);
+
+		/* the sender terminates every test with a FIN packet */
+		if (tcph->fin)
+			break;
+
+		tcp_ext_len = (tcph->doff - 5) * 4;
+		data_len = pkt_size - total_hdr_len - tcp_ext_len - ip_ext_len;
+		/* Min ethernet frame payload is 46 (ETH_ZLEN - ETH_HLEN) by
+		 * IEEE 802.3. Ipv4/tcp packets without at least 6 bytes of
+		 * data will be padded. Packet sockets are protocol agnostic,
+		 * and will not trim the padding.
+		 */
+		if (pkt_size == ETH_ZLEN && iph->version == 4) {
+			data_len = ntohs(iph->tot_len)
+				- sizeof(struct tcphdr) - sizeof(struct iphdr);
+		}
+		vlog("%d ", data_len);
+		if (num_pkt >= correct_num_pkts) {
+			/* More packets than expected (e.g. nothing was
+			 * coalesced): never index past the expectations the
+			 * caller provided. The count check below fails the
+			 * test.
+			 */
+			vlog("[unexpected]");
+			bad_packet = true;
+		} else if (data_len != correct_payload[num_pkt]) {
+			vlog("[!=%d]", correct_payload[num_pkt]);
+			bad_packet = true;
+		}
+		num_pkt++;
+	}
+	vlog("}, Total %d packets.\n", num_pkt);
+	if (num_pkt != correct_num_pkts)
+		error(1, 0, "incorrect number of packets");
+	if (bad_packet)
+		error(1, 0, "incorrect packet geometry");
+
+	printf("Test succeeded\n\n");
+}
+
+/* Receiver side of the "capacity" test.
+ *
+ * Receives packets on fd until one with the TCP FIN flag set arrives. Each
+ * packet is attributed to a flow by its TCP source port (sport - SPORT) and
+ * counts as coalesced when its payload is CAPACITY_PAYLOAD_LEN * 2 bytes.
+ * When order_check is set, the arrival order of the flows is verified too.
+ *
+ * Always prints a "STATS:" line; exits through error() if any check failed,
+ * and through recv_error() when recv() fails.
+ */
+static void check_capacity_pkts(int fd)
+{
+ static char buffer[IP_MAXPACKET + ETH_HLEN + 1];
+ struct iphdr *iph = (struct iphdr *)(buffer + ETH_HLEN);
+ struct ipv6hdr *ip6h = (struct ipv6hdr *)(buffer + ETH_HLEN);
+ int num_pkt = 0, num_coal = 0, pkt_idx;
+ const char *fail_reason = NULL;
+ /* flow id of each received packet, in arrival order (-1 = unused) */
+ int flow_order[num_flows * 2];
+ /* payload length of the last packet seen for each flow */
+ int coalesced[num_flows];
+ struct tcphdr *tcph;
+ int ip_ext_len = 0;
+ int total_data = 0;
+ int pkt_size = -1;
+ int data_len = 0;
+ int flow_id;
+ int sport;
+
+ memset(coalesced, 0, sizeof(coalesced));
+ memset(flow_order, -1, sizeof(flow_order));
+
+ while (1) {
+ ip_ext_len = 0;
+ pkt_size = recv(fd, buffer, IP_MAXPACKET + ETH_HLEN + 1, 0);
+ if (pkt_size < 0)
+ recv_error(fd, errno);
+
+ if (iph->version == 4)
+ ip_ext_len = (iph->ihl - 5) * 4;
+ else if (ip6h->version == 6 && !ip6ip6 &&
+ ip6h->nexthdr != IPPROTO_TCP)
+ ip_ext_len = MIN_EXTHDR_SIZE;
+
+ tcph = (struct tcphdr *)(buffer + tcp_offset + ip_ext_len);
+
+ /* a FIN packet ends the receive loop */
+ if (tcph->fin)
+ break;
+
+ sport = ntohs(tcph->source);
+ flow_id = sport - SPORT;
+
+ /* skip packets that map outside [0, num_flows); the first such packet
+ * records the failure reason
+ */
+ if (flow_id < 0 || flow_id >= num_flows) {
+ vlog("Invalid flow_id %d from sport %d\n",
+ flow_id, sport);
+ fail_reason = fail_reason ?: "invalid packet";
+ continue;
+ }
+
+ /* Calculate payload length */
+ if (pkt_size == ETH_ZLEN && iph->version == 4) {
+ data_len = ntohs(iph->tot_len)
+ - sizeof(struct tcphdr) - sizeof(struct iphdr);
+ } else {
+ data_len = pkt_size - total_hdr_len - ip_ext_len;
+ }
+
+ /* flow_order only has room for num_flows * 2 entries; the overflow is
+ * reported once, when the first excess packet arrives
+ */
+ if (num_pkt < num_flows * 2) {
+ flow_order[num_pkt] = flow_id;
+ } else if (num_pkt == num_flows * 2) {
+ vlog("More packets than expected (%d)\n",
+ num_flows * 2);
+ fail_reason = fail_reason ?: "too many packets";
+ }
+ coalesced[flow_id] = data_len;
+
+ if (data_len == CAPACITY_PAYLOAD_LEN * 2) {
+ num_coal++;
+ } else {
+ vlog("Pkt %d: flow %d, sport %d, len %d (expected %d)\n",
+ num_pkt, flow_id, sport, data_len,
+ CAPACITY_PAYLOAD_LEN * 2);
+ fail_reason = fail_reason ?: "not coalesced";
+ }
+
+ num_pkt++;
+ total_data += data_len;
+ }
+
+ /* Check flow ordering. We expect to see all non-coalesced first segs
+ * then interleaved coalesced and non-coalesced second frames.
+ */
+ pkt_idx = 0;
+ /* first pass: one leading packet for every flow that did not coalesce */
+ for (flow_id = 0; order_check && flow_id < num_flows; flow_id++) {
+ bool coaled = coalesced[flow_id] > CAPACITY_PAYLOAD_LEN;
+
+ if (coaled)
+ continue;
+
+ if (flow_order[pkt_idx] != flow_id) {
+ vlog("Flow order mismatch (non-coalesced) at position %d: expected flow %d, got flow %d\n",
+ pkt_idx, flow_id, flow_order[pkt_idx]);
+ fail_reason = fail_reason ?: "bad packet order (1)";
+ }
+ pkt_idx++;
+ }
+ /* second pass: pkt_idx carries on from the first pass; one packet per
+ * flow is expected, in flow order
+ */
+ for (flow_id = 0; order_check && flow_id < num_flows; flow_id++) {
+ bool coaled = coalesced[flow_id] > CAPACITY_PAYLOAD_LEN;
+
+ if (flow_order[pkt_idx] != flow_id) {
+ vlog("Flow order mismatch at position %d: expected flow %d, got flow %d, coalesced: %d\n",
+ pkt_idx, flow_id, flow_order[pkt_idx], coaled);
+ fail_reason = fail_reason ?: "bad packet order (2)";
+ }
+ pkt_idx++;
+ }
+
+ if (!fail_reason) {
+ vlog("All %d flows coalesced correctly\n", num_flows);
+ printf("Test succeeded\n\n");
+ } else {
+ printf("FAILED\n");
+ }
+
+ /* Always print stats for external validation */
+ printf("STATS: received=%d wire=%d coalesced=%d\n",
+ num_pkt, num_pkt + num_coal, num_coal);
+
+ if (fail_reason)
+ error(1, 0, "capacity test failed %s", fail_reason);
+}
+
+/* Transmit side of the test.
+ *
+ * Opens a raw packet socket, enlarges its send buffer, builds the link-layer
+ * destination from ifname and dst_mac, and then sends the packet sequence
+ * selected by testname. Every sequence is terminated by a FIN packet
+ * (fin_pkt); some sub-tests first wait fin_delay_us (100 ms) before sending
+ * it. An unknown testname, or any failing socket call, exits through
+ * error().
+ */
+static void gro_sender(void)
+{
+ int bufsize = 4 * 1024 * 1024; /* 4 MB */
+ const int fin_delay_us = 100 * 1000;
+ static char fin_pkt[MAX_HDR_LEN];
+ struct sockaddr_ll daddr = {};
+ int txfd = -1;
+
+ txfd = socket(PF_PACKET, SOCK_RAW, IPPROTO_RAW);
+ if (txfd < 0)
+ error(1, errno, "socket creation");
+
+ if (setsockopt(txfd, SOL_SOCKET, SO_SNDBUF, &bufsize, sizeof(bufsize)))
+ error(1, errno, "cannot set sndbuf size, setsockopt failed");
+
+ /* Enable SO_TXTIME unless test case generates more than one flow
+ * SO_TXTIME could result in qdisc layer sorting the packets at sender.
+ */
+ if (strcmp(testname, "single") && strcmp(testname, "capacity")) {
+ struct sock_txtime so_txtime = { .clockid = CLOCK_MONOTONIC, };
+ struct timespec ts;
+
+ if (setsockopt(txfd, SOL_SOCKET, SO_TXTIME,
+ &so_txtime, sizeof(so_txtime)))
+ error(1, errno, "setsockopt SO_TXTIME");
+
+ if (clock_gettime(CLOCK_MONOTONIC, &ts))
+ error(1, errno, "clock_gettime");
+
+ /* txtime_ns = current monotonic time plus TXTIME_DELAY_MS, in ns */
+ txtime_ns = ts.tv_sec * 1000000000ULL + ts.tv_nsec;
+ txtime_ns += TXTIME_DELAY_MS * 1000000ULL;
+ }
+
+ memset(&daddr, 0, sizeof(daddr));
+ daddr.sll_ifindex = if_nametoindex(ifname);
+ if (daddr.sll_ifindex == 0)
+ error(1, errno, "if_nametoindex");
+ daddr.sll_family = AF_PACKET;
+ memcpy(daddr.sll_addr, dst_mac, ETH_ALEN);
+ daddr.sll_halen = ETH_ALEN;
+ /* header-only packet with the last create_packet() argument set to 1;
+ * the receive loops stop at the first packet with tcph->fin set
+ */
+ create_packet(fin_pkt, PAYLOAD_LEN * 2, 0, 0, 1);
+
+ /* data sub-tests */
+ if (strcmp(testname, "data_same") == 0) {
+ send_data_pkts(txfd, &daddr, PAYLOAD_LEN, PAYLOAD_LEN);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "data_lrg_sml") == 0) {
+ send_data_pkts(txfd, &daddr, PAYLOAD_LEN, PAYLOAD_LEN / 2);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "data_lrg_1byte") == 0) {
+ send_data_pkts(txfd, &daddr, PAYLOAD_LEN, 1);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "data_sml_lrg") == 0) {
+ send_data_pkts(txfd, &daddr, PAYLOAD_LEN / 2, PAYLOAD_LEN);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "data_burst") == 0) {
+ static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+
+ create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+ write_packet(txfd, buf, total_hdr_len + PAYLOAD_LEN, &daddr);
+ create_packet(buf, PAYLOAD_LEN, 0, PAYLOAD_LEN, 0);
+ write_packet(txfd, buf, total_hdr_len + PAYLOAD_LEN, &daddr);
+
+ usleep(100 * 1000); /* 100ms */
+ create_packet(buf, PAYLOAD_LEN * 2, 0, PAYLOAD_LEN, 0);
+ write_packet(txfd, buf, total_hdr_len + PAYLOAD_LEN, &daddr);
+ create_packet(buf, PAYLOAD_LEN * 3, 0, PAYLOAD_LEN, 0);
+ write_packet(txfd, buf, total_hdr_len + PAYLOAD_LEN, &daddr);
+
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* ack test */
+ } else if (strcmp(testname, "ack") == 0) {
+ send_ack(txfd, &daddr);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* flags sub-tests */
+ } else if (strcmp(testname, "flags_psh") == 0) {
+ send_flags(txfd, &daddr, 1, 0, 0, 0, 0);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "flags_syn") == 0) {
+ send_flags(txfd, &daddr, 0, 1, 0, 0, 0);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "flags_rst") == 0) {
+ send_flags(txfd, &daddr, 0, 0, 1, 0, 0);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "flags_urg") == 0) {
+ send_flags(txfd, &daddr, 0, 0, 0, 1, 0);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "flags_cwr") == 0) {
+ send_flags(txfd, &daddr, 0, 0, 0, 0, 1);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* tcp sub-tests */
+ } else if (strcmp(testname, "tcp_csum") == 0) {
+ send_changed_checksum(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "tcp_seq") == 0) {
+ send_changed_seq(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "tcp_ts") == 0) {
+ send_changed_ts(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "tcp_opt") == 0) {
+ send_diff_opt(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* ip sub-tests - shared between IPv4 and IPv6 */
+ } else if (strcmp(testname, "ip_ecn") == 0) {
+ send_changed_ECN(txfd, &daddr);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_tos") == 0) {
+ send_changed_tos(txfd, &daddr);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* ip sub-tests - IPv4 only */
+ } else if (strcmp(testname, "ip_csum") == 0) {
+ send_changed_ip_checksum(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_ttl") == 0) {
+ send_changed_ttl(txfd, &daddr);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_opt") == 0) {
+ send_ip_options(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_frag4") == 0) {
+ send_fragment4(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df1_inc") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF1_INC);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df1_fixed") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF1_FIXED);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df0_inc") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF0_INC);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df0_fixed") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF0_FIXED);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df1_inc_fixed") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF1_INC_FIXED);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_id_df1_fixed_inc") == 0) {
+ send_flush_id_case(txfd, &daddr, FLUSH_ID_DF1_FIXED_INC);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* ip sub-tests - IPv6 only */
+ } else if (strcmp(testname, "ip_frag6") == 0) {
+ send_fragment6(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_v6ext_same") == 0) {
+ send_ipv6_exthdr(txfd, &daddr, EXT_PAYLOAD_1, EXT_PAYLOAD_1);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "ip_v6ext_diff") == 0) {
+ send_ipv6_exthdr(txfd, &daddr, EXT_PAYLOAD_1, EXT_PAYLOAD_2);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* large sub-tests */
+ } else if (strcmp(testname, "large_max") == 0) {
+ int remainder = max_payload() % calc_mss();
+
+ send_large(txfd, &daddr, remainder);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "large_rem") == 0) {
+ int remainder = max_payload() % calc_mss();
+
+ send_large(txfd, &daddr, remainder + 1);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ /* machinery sub-tests */
+ } else if (strcmp(testname, "single") == 0) {
+ static char buf[MAX_HDR_LEN + PAYLOAD_LEN];
+
+ create_packet(buf, 0, 0, PAYLOAD_LEN, 0);
+ write_packet(txfd, buf, total_hdr_len + PAYLOAD_LEN, &daddr);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+ } else if (strcmp(testname, "capacity") == 0) {
+ send_capacity(txfd, &daddr);
+ usleep(fin_delay_us);
+ write_packet(txfd, fin_pkt, total_hdr_len, &daddr);
+
+ } else {
+ error(1, 0, "Unknown testcase: %s", testname);
+ }
+
+ if (close(txfd))
+ error(1, errno, "socket close");
+}
+
+/* Receive side of the test.
+ *
+ * Opens a packet socket, installs the socket filter, receive timeout and
+ * receive buffer size, binds it to the test interface and signals readiness
+ * through ksft_ready(). It then fills in the payload lengths expected for
+ * testname and hands them to check_recv_pkts() (or runs
+ * check_capacity_pkts() for the "capacity" test).
+ *
+ * Exits through error() on an unknown testname or a failing socket call;
+ * the check helpers exit on any mismatch.
+ */
+static void gro_receiver(void)
+{
+	static int correct_payload[NUM_PACKETS];
+	int rxfd = -1;
+
+	rxfd = socket(PF_PACKET, SOCK_RAW, htons(ETH_P_NONE));
+	if (rxfd < 0)
+		error(1, errno, "socket creation");
+	setup_sock_filter(rxfd);
+	set_timeout(rxfd);
+	set_rcvbuf(rxfd);
+	bind_packetsocket(rxfd);
+
+	/* the socket is bound: report readiness before waiting for packets */
+	ksft_ready();
+
+	memset(correct_payload, 0, sizeof(correct_payload));
+
+	/* data sub-tests */
+	if (strcmp(testname, "data_same") == 0) {
+		printf("pure data packet of same size: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "data_lrg_sml") == 0) {
+		printf("large data packets followed by a smaller one: ");
+		/* sender transmits PAYLOAD_LEN then PAYLOAD_LEN / 2 bytes */
+		correct_payload[0] = PAYLOAD_LEN + PAYLOAD_LEN / 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "data_lrg_1byte") == 0) {
+		printf("large data packet followed by a 1 byte one: ");
+		correct_payload[0] = PAYLOAD_LEN + 1;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "data_sml_lrg") == 0) {
+		printf("small data packets followed by a larger one: ");
+		correct_payload[0] = PAYLOAD_LEN / 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "data_burst") == 0) {
+		printf("two bursts of two data packets: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 2);
+
+	/* ack test */
+	} else if (strcmp(testname, "ack") == 0) {
+		printf("duplicate ack and pure ack: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+
+	/* flags sub-tests */
+	} else if (strcmp(testname, "flags_psh") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 3;
+		correct_payload[1] = PAYLOAD_LEN * 2;
+		printf("psh flag ends coalescing: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "flags_syn") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = 0;
+		correct_payload[2] = PAYLOAD_LEN * 2;
+		printf("syn flag ends coalescing: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "flags_rst") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = 0;
+		correct_payload[2] = PAYLOAD_LEN * 2;
+		printf("rst flag ends coalescing: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "flags_urg") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = 0;
+		correct_payload[2] = PAYLOAD_LEN * 2;
+		printf("urg flag ends coalescing: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "flags_cwr") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN * 2;
+		correct_payload[2] = PAYLOAD_LEN * 2;
+		printf("cwr flag ends coalescing: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+
+	/* tcp sub-tests */
+	} else if (strcmp(testname, "tcp_csum") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("changed checksum does not coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "tcp_seq") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("Wrong Seq number doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "tcp_ts") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		correct_payload[2] = PAYLOAD_LEN;
+		correct_payload[3] = PAYLOAD_LEN;
+		printf("Different timestamp doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 4);
+	} else if (strcmp(testname, "tcp_opt") == 0) {
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("Different options doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+
+	/* ip sub-tests - shared between IPv4 and IPv6 */
+	} else if (strcmp(testname, "ip_ecn") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("different ECN doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "ip_tos") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("different tos doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+
+	/* ip sub-tests - IPv4 only */
+	} else if (strcmp(testname, "ip_csum") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		correct_payload[2] = PAYLOAD_LEN;
+		printf("bad ip checksum doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "ip_ttl") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("different ttl doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "ip_opt") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		correct_payload[2] = PAYLOAD_LEN;
+		printf("ip options doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "ip_frag4") == 0) {
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		printf("fragmented ip4 doesn't coalesce: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "ip_id_df1_inc") == 0) {
+		printf("DF=1, Incrementing - should coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "ip_id_df1_fixed") == 0) {
+		printf("DF=1, Fixed - should coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "ip_id_df0_inc") == 0) {
+		printf("DF=0, Incrementing - should coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "ip_id_df0_fixed") == 0) {
+		printf("DF=0, Fixed - should coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "ip_id_df1_inc_fixed") == 0) {
+		printf("DF=1, 2 Incrementing and one fixed - should coalesce only first 2 packets: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "ip_id_df1_fixed_inc") == 0) {
+		printf("DF=1, 2 Fixed and one incrementing - should coalesce only first 2 packets: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 2);
+
+	/* ip sub-tests - IPv6 only */
+	} else if (strcmp(testname, "ip_frag6") == 0) {
+		/* GRO doesn't check for ipv6 hop limit when flushing.
+		 * Hence no corresponding test to the ipv4 case.
+		 */
+		printf("fragmented ip6 doesn't coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		correct_payload[1] = PAYLOAD_LEN;
+		correct_payload[2] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 3);
+	} else if (strcmp(testname, "ip_v6ext_same") == 0) {
+		printf("ipv6 with ext header does coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN * 2;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "ip_v6ext_diff") == 0) {
+		printf("ipv6 with ext header with different payloads doesn't coalesce: ");
+		correct_payload[0] = PAYLOAD_LEN;
+		correct_payload[1] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 2);
+
+	/* large sub-tests */
+	} else if (strcmp(testname, "large_max") == 0) {
+		int remainder = max_payload() % calc_mss();
+
+		correct_payload[0] = max_payload();
+		correct_payload[1] = remainder;
+		printf("Shouldn't coalesce if exceed IP max pkt size: ");
+		check_recv_pkts(rxfd, correct_payload, 2);
+	} else if (strcmp(testname, "large_rem") == 0) {
+		int remainder = max_payload() % calc_mss();
+
+		/* last segment sent individually, doesn't start new segment */
+		correct_payload[0] = max_payload() - remainder;
+		correct_payload[1] = remainder + 1;
+		correct_payload[2] = remainder + 1;
+		printf("last segment sent individually: ");
+		check_recv_pkts(rxfd, correct_payload, 3);
+
+	/* machinery sub-tests */
+	} else if (strcmp(testname, "single") == 0) {
+		printf("single data packet: ");
+		correct_payload[0] = PAYLOAD_LEN;
+		check_recv_pkts(rxfd, correct_payload, 1);
+	} else if (strcmp(testname, "capacity") == 0) {
+		check_capacity_pkts(rxfd);
+
+	} else {
+		error(1, 0, "Test case error: unknown testname %s", testname);
+	}
+
+	if (close(rxfd))
+		error(1, errno, "socket close");
+}
+
+/* Parse the command line into the file-level test configuration.
+ *
+ * -4/-6 select the address family; -e (ipip) and -E (ip6ip6) additionally
+ * enable the corresponding encapsulation flag. -d/-s set the destination and
+ * source addresses, -D/-S the MAC addresses, -i the interface, -t the test
+ * name, -n the number of flows, -r switches to receive mode, -o enables the
+ * order check and -v verbose logging.
+ *
+ * Exits through error() on an unknown option or an invalid --num-flows.
+ */
+static void parse_args(int argc, char **argv)
+{
+	static const struct option opts[] = {
+		{ "daddr", required_argument, NULL, 'd' },
+		{ "dmac", required_argument, NULL, 'D' },
+		{ "iface", required_argument, NULL, 'i' },
+		{ "ipv4", no_argument, NULL, '4' },
+		{ "ipv6", no_argument, NULL, '6' },
+		{ "ipip", no_argument, NULL, 'e' },
+		{ "ip6ip6", no_argument, NULL, 'E' },
+		{ "num-flows", required_argument, NULL, 'n' },
+		{ "rx", no_argument, NULL, 'r' },
+		{ "saddr", required_argument, NULL, 's' },
+		{ "smac", required_argument, NULL, 'S' },
+		{ "test", required_argument, NULL, 't' },
+		{ "order-check", no_argument, NULL, 'o' },
+		{ "verbose", no_argument, NULL, 'v' },
+		{ 0, 0, 0, 0 }
+	};
+	int c;
+
+	while ((c = getopt_long(argc, argv, "46d:D:eEi:n:rs:S:t:ov", opts, NULL)) != -1) {
+		switch (c) {
+		case '4':
+			proto = PF_INET;
+			ethhdr_proto = htons(ETH_P_IP);
+			break;
+		case '6':
+			proto = PF_INET6;
+			ethhdr_proto = htons(ETH_P_IPV6);
+			break;
+		case 'e':
+			ipip = true;
+			proto = PF_INET;
+			ethhdr_proto = htons(ETH_P_IP);
+			break;
+		case 'E':
+			ip6ip6 = true;
+			proto = PF_INET6;
+			ethhdr_proto = htons(ETH_P_IPV6);
+			break;
+		case 'd':
+			addr4_dst = addr6_dst = optarg;
+			break;
+		case 'D':
+			dmac = optarg;
+			break;
+		case 'i':
+			ifname = optarg;
+			break;
+		case 'n': {
+			char *end;
+			long val;
+
+			/* num_flows sizes on-stack arrays in the capacity
+			 * check and flows are told apart by their 16-bit TCP
+			 * source port, so only 1..65535 can be meaningful.
+			 */
+			errno = 0;
+			val = strtol(optarg, &end, 10);
+			if (errno || end == optarg || *end != '\0' ||
+			    val <= 0 || val > 0xffff)
+				error(1, 0, "invalid --num-flows '%s'", optarg);
+			num_flows = val;
+			break;
+		}
+		case 'r':
+			tx_socket = false;
+			break;
+		case 's':
+			addr4_src = addr6_src = optarg;
+			break;
+		case 'S':
+			smac = optarg;
+			break;
+		case 't':
+			testname = optarg;
+			break;
+		case 'o':
+			order_check = true;
+			break;
+		case 'v':
+			verbose = true;
+			break;
+		default:
+			error(1, 0, "%s invalid option %c\n", __func__, c);
+			break;
+		}
+	}
+}
+
+/* Entry point: parse options, derive the header offsets for the selected
+ * encapsulation, read the MAC addresses, then run as sender or receiver.
+ */
+int main(int argc, char **argv)
+{
+	parse_args(argc, argv);
+
+	/* offset of the TCP header depends on the (possibly doubled) IP header */
+	if (ipip)
+		tcp_offset = ETH_HLEN + sizeof(struct iphdr) * 2;
+	else if (ip6ip6)
+		tcp_offset = ETH_HLEN + sizeof(struct ipv6hdr) * 2;
+	else if (proto == PF_INET)
+		tcp_offset = ETH_HLEN + sizeof(struct iphdr);
+	else if (proto == PF_INET6)
+		tcp_offset = ETH_HLEN + sizeof(struct ipv6hdr);
+	else
+		error(1, 0, "Protocol family is not ipv4 or ipv6");
+
+	total_hdr_len = tcp_offset + sizeof(struct tcphdr);
+
+	read_MAC(src_mac, smac);
+	read_MAC(dst_mac, dmac);
+
+	if (!tx_socket) {
+		/* Only the receiver exit status determines test success. */
+		gro_receiver();
+		fprintf(stderr, "Gro::%s test passed.\n", testname);
+		return 0;
+	}
+
+	gro_sender();
+	return 0;
+}
diff --git a/tools/testing/selftests/net/lib/ksft.h b/tools/testing/selftests/net/lib/ksft.h
new file mode 100644
index 000000000000..03912902a6d3
--- /dev/null
+++ b/tools/testing/selftests/net/lib/ksft.h
@@ -0,0 +1,58 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+#if !defined(__NET_KSFT_H__)
+#define __NET_KSFT_H__
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+/* Signal readiness by writing "ready\n" (with its terminating NUL byte) to
+ * the descriptor named by the KSFT_READY_FD environment variable, or to
+ * stdout when the variable is not set. A descriptor taken from the
+ * environment is closed afterwards.
+ */
+static inline void ksft_ready(void)
+{
+	const char msg[7] = "ready\n";
+	const char *fd_env = getenv("KSFT_READY_FD");
+	int fd = STDOUT_FILENO;
+
+	if (fd_env) {
+		fd = atoi(fd_env);
+		/* atoi() yields 0 for non-numeric input; treat 0 as invalid */
+		if (fd == 0) {
+			fprintf(stderr, "invalid KSFT_READY_FD = '%s'\n",
+				fd_env);
+			return;
+		}
+	}
+
+	if (write(fd, msg, sizeof(msg)) < 0)
+		perror("write()");
+	if (fd != STDOUT_FILENO)
+		close(fd);
+}
+
+/* Block until one byte can be read from the descriptor named by the
+ * KSFT_WAIT_FD environment variable, or from stdin when the variable is not
+ * set. A descriptor taken from the environment is closed afterwards.
+ */
+static inline void ksft_wait(void)
+{
+	const char *fd_env = getenv("KSFT_WAIT_FD");
+	/* Not running in KSFT env: wait for input from STDIN instead */
+	int fd = STDIN_FILENO;
+	char byte;
+
+	if (fd_env) {
+		fd = atoi(fd_env);
+		/* atoi() yields 0 for non-numeric input; treat 0 as invalid */
+		if (fd == 0) {
+			fprintf(stderr, "invalid KSFT_WAIT_FD = '%s'\n",
+				fd_env);
+			return;
+		}
+	}
+
+	if (read(fd, &byte, sizeof(byte)) < 0)
+		perror("read()");
+	if (fd != STDIN_FILENO)
+		close(fd);
+}
+
+#endif
diff --git a/tools/testing/selftests/net/lib/ksft_setup_loopback.sh b/tools/testing/selftests/net/lib/ksft_setup_loopback.sh
new file mode 100755
index 000000000000..3defbb1919c5
--- /dev/null
+++ b/tools/testing/selftests/net/lib/ksft_setup_loopback.sh
@@ -0,0 +1,111 @@
+#!/bin/bash
+# SPDX-License-Identifier: GPL-2.0
+
+# Setup script for running ksft tests over a real interface in loopback mode.
+# This script replaces the historical setup_loopback.sh. It puts
+# a (presumably) real hardware interface into loopback mode, creates macvlan
+# interfaces on top and places them in a network namespace for isolation.
+#
+# NETIF env variable must be exported to indicate the real target device.
+# Note that the test will override NETIF with one of the macvlans, the
+# actual ksft test will only see the macvlans.
+#
+# Example use:
+# export NETIF=eth0
+# ./net/lib/ksft_setup_loopback.sh ./drivers/net/gro.py
+
+if [ -z "$NETIF" ]; then
+ echo "Error: NETIF variable not set"
+ exit 1
+fi
+if ! [ -d "/sys/class/net/$NETIF" ]; then
+ echo "Error: Can't find $NETIF, invalid netdevice"
+ exit 1
+fi
+
+# Save original settings for cleanup
+readonly FLUSH_PATH="/sys/class/net/${NETIF}/gro_flush_timeout"
+readonly IRQ_PATH="/sys/class/net/${NETIF}/napi_defer_hard_irqs"
+FLUSH_TIMEOUT="$(< "${FLUSH_PATH}")"
+readonly FLUSH_TIMEOUT
+HARD_IRQS="$(< "${IRQ_PATH}")"
+readonly HARD_IRQS
+
+SERVER_NS=$(mktemp -u server-XXXXXXXX)
+readonly SERVER_NS
+CLIENT_NS=$(mktemp -u client-XXXXXXXX)
+readonly CLIENT_NS
+readonly SERVER_MAC="aa:00:00:00:00:02"
+readonly CLIENT_MAC="aa:00:00:00:00:01"
+
+# ksft expects addresses to communicate with remote
+export LOCAL_V6=2001:db8:1::1
+export REMOTE_V6=2001:db8:1::2
+
+# Undo the setup done by this script: delete the macvlans and namespaces,
+# turn loopback off on ${NETIF} and restore its saved gro_flush_timeout and
+# napi_defer_hard_irqs values. Exits with the status that was current when
+# the trap fired.
+cleanup() {
+	# Must stay the first command, anything else would overwrite $?
+	local exit_code=$?
+
+	# This function is trapped on EXIT as well as INT/TERM, so the "exit"
+	# at the end would fire the EXIT trap and run the whole cleanup a
+	# second time after a signal. Disarm the traps to run only once.
+	trap - EXIT INT TERM
+
+	echo "Cleaning up..."
+
+	# Remove macvlan interfaces and namespaces
+	ip -netns "${SERVER_NS}" link del dev server 2>/dev/null || true
+	ip netns del "${SERVER_NS}" 2>/dev/null || true
+	ip -netns "${CLIENT_NS}" link del dev client 2>/dev/null || true
+	ip netns del "${CLIENT_NS}" 2>/dev/null || true
+
+	# Disable loopback
+	ethtool -K "${NETIF}" loopback off 2>/dev/null || true
+	sleep 1
+
+	# Put back the values saved before the test changed them
+	echo "${FLUSH_TIMEOUT}" >"${FLUSH_PATH}"
+	echo "${HARD_IRQS}" >"${IRQ_PATH}"
+
+	exit "$exit_code"
+}
+
+trap cleanup EXIT INT TERM
+
+# Enable loopback mode
+echo "Enabling loopback on ${NETIF}..."
+ethtool -K "${NETIF}" loopback on || {
+ echo "Failed to enable loopback mode"
+ exit 1
+}
+# The interface may need time to get carrier back, but selftests
+# will wait for carrier, so no need to wait / sleep here.
+
+# Use timer on host to trigger the network stack
+# Also disable device interrupt to not depend on NIC interrupt
+# Reduce test flakiness caused by unexpected interrupts
+echo 100000 >"${FLUSH_PATH}"
+echo 50 >"${IRQ_PATH}"
+
+# Create server namespace with macvlan
+ip netns add "${SERVER_NS}"
+ip link add link "${NETIF}" dev server address "${SERVER_MAC}" type macvlan
+ip link set dev server netns "${SERVER_NS}"
+ip -netns "${SERVER_NS}" link set dev server up
+ip -netns "${SERVER_NS}" addr add $LOCAL_V6/64 dev server
+ip -netns "${SERVER_NS}" link set dev lo up
+
+# Create client namespace with macvlan
+ip netns add "${CLIENT_NS}"
+ip link add link "${NETIF}" dev client address "${CLIENT_MAC}" type macvlan
+ip link set dev client netns "${CLIENT_NS}"
+ip -netns "${CLIENT_NS}" link set dev client up
+ip -netns "${CLIENT_NS}" addr add $REMOTE_V6/64 dev client
+ip -netns "${CLIENT_NS}" link set dev lo up
+
+echo "Setup complete!"
+echo " Device: ${NETIF}"
+echo " Server NS: ${SERVER_NS}"
+echo " Client NS: ${CLIENT_NS}"
+echo ""
+
+# Setup environment variables for tests
+export NETIF=server
+export REMOTE_TYPE=netns
+export REMOTE_ARGS="${CLIENT_NS}"
+
+# Run the command
+ip netns exec "${SERVER_NS}" "$@"
diff --git a/tools/testing/selftests/net/lib/py/__init__.py b/tools/testing/selftests/net/lib/py/__init__.py
index 54d8f5eba810..7c81d86a7e97 100644
--- a/tools/testing/selftests/net/lib/py/__init__.py
+++ b/tools/testing/selftests/net/lib/py/__init__.py
@@ -1,9 +1,37 @@
# SPDX-License-Identifier: GPL-2.0
+"""
+Python selftest helpers for netdev.
+"""
+
from .consts import KSRC
-from .ksft import *
-from .netns import NetNS
-from .nsim import *
-from .utils import *
-from .ynl import NlError, YnlFamily, EthtoolFamily, NetdevFamily, RtnlFamily
-from .ynl import NetshaperFamily
+from .ksft import KsftFailEx, KsftSkipEx, KsftXfailEx, ksft_pr, ksft_eq, \
+ ksft_ne, ksft_true, ksft_not_none, ksft_in, ksft_not_in, ksft_is, \
+ ksft_ge, ksft_gt, ksft_lt, ksft_raises, ksft_busy_wait, \
+ ktap_result, ksft_disruptive, ksft_setup, ksft_run, ksft_exit, \
+ ksft_variants, KsftNamedVariant
+from .netns import NetNS, NetNSEnter
+from .nsim import NetdevSim, NetdevSimDev
+from .utils import CmdExitFailure, fd_read_timeout, cmd, bkg, defer, \
+ bpftool, ip, ethtool, bpftrace, rand_port, rand_ports, wait_port_listen, \
+ wait_file, tool
+from .bpf import bpf_map_set, bpf_map_dump, bpf_prog_map_ids
+from .ynl import NlError, NlctrlFamily, YnlFamily, \
+ EthtoolFamily, NetdevFamily, RtnlFamily, RtnlAddrFamily
+from .ynl import NetshaperFamily, DevlinkFamily, PSPFamily, Netlink
+
+__all__ = ["KSRC",
+ "KsftFailEx", "KsftSkipEx", "KsftXfailEx", "ksft_pr", "ksft_eq",
+ "ksft_ne", "ksft_true", "ksft_not_none", "ksft_in", "ksft_not_in",
+ "ksft_is", "ksft_ge", "ksft_gt", "ksft_lt", "ksft_raises",
+ "ksft_busy_wait", "ktap_result", "ksft_disruptive", "ksft_setup",
+ "ksft_run", "ksft_exit", "ksft_variants", "KsftNamedVariant",
+ "NetNS", "NetNSEnter",
+ "CmdExitFailure", "fd_read_timeout", "cmd", "bkg", "defer",
+ "bpftool", "ip", "ethtool", "bpftrace", "rand_port", "rand_ports",
+ "wait_port_listen", "wait_file", "tool",
+ "bpf_map_set", "bpf_map_dump", "bpf_prog_map_ids",
+ "NetdevSim", "NetdevSimDev",
+ "NetshaperFamily", "DevlinkFamily", "PSPFamily", "NlError",
+ "YnlFamily", "EthtoolFamily", "NetdevFamily", "RtnlFamily",
+ "NlctrlFamily", "RtnlAddrFamily", "Netlink"]
diff --git a/tools/testing/selftests/net/lib/py/bpf.py b/tools/testing/selftests/net/lib/py/bpf.py
new file mode 100644
index 000000000000..beb6bf2896a8
--- /dev/null
+++ b/tools/testing/selftests/net/lib/py/bpf.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: GPL-2.0
+
+"""
+BPF helper utilities for kernel selftests.
+
+Provides common operations for interacting with BPF maps and programs
+via bpftool, used by XDP and other BPF-based test files.
+"""
+
+from .utils import bpftool
+
+def _format_hex_bytes(value):
+    """
+    Render an integer as the space-separated hex bytes of its 4-byte,
+    signed, little-endian encoding (e.g. 1 becomes "01 00 00 00").
+
+    Args:
+        value: Integer to encode; int.to_bytes() raises OverflowError if it
+               does not fit in 4 signed bytes.
+
+    Returns:
+        A string of four two-digit lowercase hex bytes separated by spaces.
+    """
+    raw = value.to_bytes(4, byteorder='little', signed=True)
+    return ' '.join(map('{:02x}'.format, raw))
+
+
+def bpf_map_set(map_name, key, value):
+    """
+    Update one entry of a BPF map, selected by map name, using bpftool.
+
+    Args:
+        map_name: The name of the map to update.
+        key: Integer key. It is handed to bpftool as the hex bytes of its
+             4-byte signed little-endian encoding.
+        value: Integer value to store under the key, encoded like the key.
+    """
+    key_hex = _format_hex_bytes(key)
+    value_hex = _format_hex_bytes(value)
+    bpftool(
+        f"map update name {map_name} key hex {key_hex} value hex {value_hex}"
+    )
+
+def bpf_map_dump(map_id):
+    """Read all entries of a BPF map via "bpftool map dump".
+
+    Args:
+        map_id: Numeric map ID (as returned by bpftool prog show).
+
+    Returns:
+        A dict built from the "formatted" part of every dumped entry,
+        mapping the formatted key to the formatted value.
+    """
+    entries = bpftool(f"map dump id {map_id}", json=True)
+    dump = {}
+    for entry in entries:
+        fmt = entry["formatted"]
+        dump[fmt["key"]] = fmt["value"]
+    return dump
+
+
+def bpf_prog_map_ids(prog_id):
+    """Look up the maps used by a loaded BPF program.
+
+    Runs "bpftool prog show" to get the program's map IDs, then one
+    "bpftool map show" per ID to learn each map's name.
+
+    Args:
+        prog_id: Numeric program ID.
+
+    Returns:
+        A dict mapping map name to map ID. If two maps report the same
+        name, the one listed later wins.
+    """
+    prog_info = bpftool(f"prog show id {prog_id}", json=True)
+    return {
+        bpftool(f"map show id {mid}", json=True)["name"]: mid
+        for mid in prog_info["map_ids"]
+    }
diff --git a/tools/testing/selftests/net/lib/py/ksft.py b/tools/testing/selftests/net/lib/py/ksft.py
index 3efe005436cd..81287c2daff0 100644
--- a/tools/testing/selftests/net/lib/py/ksft.py
+++ b/tools/testing/selftests/net/lib/py/ksft.py
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: GPL-2.0
-import builtins
+import fnmatch
import functools
+import getopt
import inspect
+import os
+import signal
import sys
import time
import traceback
+from collections import namedtuple
from .consts import KSFT_MAIN_NAME
-from .utils import global_defer_queue
+from . import utils
KSFT_RESULT = None
KSFT_RESULT_ALL = True
@@ -26,8 +30,67 @@ class KsftXfailEx(Exception):
pass
+class KsftTerminate(KeyboardInterrupt):
+ pass
+
+
+class _KsftArgs:
+ def __init__(self):
+ self.list_tests = False
+ self.filters = []
+
+ try:
+ opts, _ = getopt.getopt(sys.argv[1:], 'hlt:T:')
+ except getopt.GetoptError as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)
+
+ for opt, val in opts:
+ if opt == '-h':
+ print(f"Usage: {sys.argv[0]} [-h|-l] [-t|-T name]\n"
+ f"\t-h print help\n"
+ f"\t-l list tests (filtered, if filters were specified)\n"
+ f"\t-t name include test\n"
+ f"\t-T name exclude test",
+ file=sys.stderr)
+ sys.exit(0)
+ elif opt == '-l':
+ self.list_tests = True
+ elif opt == '-t':
+ self.filters.append((True, val))
+ elif opt == '-T':
+ self.filters.append((False, val))
+
+
+@functools.lru_cache()
+def _ksft_supports_color():
+ if os.environ.get("NO_COLOR") is not None:
+ return False
+ if not hasattr(sys.stdout, "isatty") or not sys.stdout.isatty():
+ return False
+ if os.environ.get("TERM") == "dumb":
+ return False
+ return True
+
+
def ksft_pr(*objs, **kwargs):
- print("#", *objs, **kwargs)
+ """
+ Print logs to stdout.
+
+ Behaves like print() but log lines will be prefixed
+ with # to prevent breaking the TAP output formatting.
+
+ Extra arguments (on top of what print() supports):
+ line_pfx - add extra string before each line
+ """
+ sep = kwargs.pop("sep", " ")
+ pfx = kwargs.pop("line_pfx", "")
+ pfx = "#" + (" " + pfx if pfx else "")
+ kwargs["flush"] = True
+
+ text = sep.join(str(obj) for obj in objs)
+ prefixed = f"\n{pfx} ".join(text.split('\n'))
+ print(pfx, prefixed, **kwargs)
def _fail(*args):
@@ -66,11 +129,21 @@ def ksft_true(a, comment=""):
_fail("Check failed", a, "does not eval to True", comment)
+def ksft_not_none(a, comment=""):
+ if a is None:
+ _fail("Check failed", a, "is None", comment)
+
+
def ksft_in(a, b, comment=""):
if a not in b:
_fail("Check failed", a, "not in", b, comment)
+def ksft_not_in(a, b, comment=""):
+ if a in b:
+ _fail("Check failed", a, "in", b, comment)
+
+
def ksft_is(a, b, comment=""):
if a is not b:
_fail("Check failed", a, "is not", b, comment)
@@ -81,6 +154,11 @@ def ksft_ge(a, b, comment=""):
_fail("Check failed", a, "<", b, comment)
+def ksft_gt(a, b, comment=""):
+ if a <= b:
+ _fail("Check failed", a, "<=", b, comment)
+
+
def ksft_lt(a, b, comment=""):
if a >= b:
_fail("Check failed", a, ">=", b, comment)
@@ -115,7 +193,7 @@ def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
time.sleep(sleep)
-def ktap_result(ok, cnt=1, case="", comment=""):
+def ktap_result(ok, cnt=1, case_name="", comment=""):
global KSFT_RESULT_ALL
KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok
@@ -125,31 +203,46 @@ def ktap_result(ok, cnt=1, case="", comment=""):
res += "ok "
res += str(cnt) + " "
res += KSFT_MAIN_NAME
- if case:
- res += "." + str(case.__name__)
+ if case_name:
+ res += "." + case_name
if comment:
res += " # " + comment
- print(res)
+ if _ksft_supports_color():
+ if comment.startswith(("SKIP", "XFAIL")):
+ color = "\033[33m"
+ elif ok:
+ color = "\033[32m"
+ else:
+ color = "\033[31m"
+ res = color + res + "\033[0m"
+ print(res, flush=True)
+
+
+def _ksft_defer_arm(state):
+ """ Allow or disallow the use of defer() """
+ utils.GLOBAL_DEFER_ARMED = state
def ksft_flush_defer():
global KSFT_RESULT
i = 0
- qlen_start = len(global_defer_queue)
- while global_defer_queue:
+ qlen_start = len(utils.GLOBAL_DEFER_QUEUE)
+ while utils.GLOBAL_DEFER_QUEUE:
i += 1
- entry = global_defer_queue.pop()
+ entry = utils.GLOBAL_DEFER_QUEUE.pop()
try:
entry.exec_only()
- except:
+ except Exception:
ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
- tb = traceback.format_exc()
- for line in tb.strip().split('\n'):
- ksft_pr("Defer Exception|", line)
+ ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
KSFT_RESULT = False
+KsftCaseFunction = namedtuple("KsftCaseFunction",
+ ['name', 'original_func', 'variants'])
+
+
def ksft_disruptive(func):
"""
Decorator that marks the test as disruptive (e.g. the test
@@ -160,11 +253,47 @@ def ksft_disruptive(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not KSFT_DISRUPTIVE:
- raise KsftSkipEx(f"marked as disruptive")
+ raise KsftSkipEx("marked as disruptive")
return func(*args, **kwargs)
return wrapper
+class KsftNamedVariant:
+ """ Named string name + argument list tuple for @ksft_variants """
+
+ def __init__(self, name, *params):
+ self.params = params
+ self.name = name or "_".join([str(x) for x in self.params])
+
+
+def ksft_variants(params):
+ """
+ Decorator defining the sets of inputs for a test.
+ The parameters will be included in the name of the resulting sub-case.
+ Parameters can be either single object, tuple or a KsftNamedVariant.
+ The argument can be a list or a generator.
+
+ Example:
+
+ @ksft_variants([
+ (1, "a"),
+ (2, "b"),
+ KsftNamedVariant("three", 3, "c"),
+ ])
+ def my_case(cfg, a, b):
+ pass # ...
+
+ ksft_run(cases=[my_case], args=(cfg, ))
+
+ Will generate cases:
+ my_case.1_a
+ my_case.2_b
+ my_case.three
+ """
+
+ return lambda func: KsftCaseFunction(func.__name__, func, params)
+
+
def ksft_setup(env):
"""
Setup test framework global state from the environment.
@@ -178,7 +307,7 @@ def ksft_setup(env):
return False
try:
return bool(int(value))
- except:
+ except Exception:
raise Exception(f"failed to parse {name}")
if "DISRUPTIVE" in env:
@@ -188,9 +317,42 @@ def ksft_setup(env):
return env
-def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
+def _ksft_intr(signum, frame):
+ # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
+ # if we don't ignore the second one it will stop us from handling cleanup
+ global term_cnt
+ term_cnt += 1
+ if term_cnt == 1:
+ raise KsftTerminate()
+ else:
+ ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
+
+
+def _ksft_name_matches(name, pattern):
+    """
+    Check a test name against a filter pattern.
+
+    Patterns containing any of the characters * ? [ are matched as
+    case-sensitive globs, all other patterns must equal the name exactly.
+    """
+    if any(ch in pattern for ch in "*?["):
+        return fnmatch.fnmatchcase(name, pattern)
+    return name == pattern
+
+
+def _ksft_test_enabled(name, filters):
+    """
+    Decide whether a test should run given the command line filters.
+
+    @filters is a list of (include, pattern) tuples in command line order.
+    The first filter whose pattern matches @name decides: the test runs if
+    that filter is an include, and is dropped if it is an exclude. A test
+    matched by no filter runs only if no include filter was given at all.
+    """
+    seen_include = False
+    for include, pattern in filters:
+        if include:
+            seen_include = True
+        if _ksft_name_matches(name, pattern):
+            return include
+    return not seen_include
+
+
+def _ksft_generate_test_cases(cases, globs, case_pfx, args, cli_args):
+ """Generate a filtered list of (func, args, name) tuples.
+
+ If -l is given, prints matching test names and exits.
+ """
+
cases = cases or []
+ test_cases = []
+ # If using the globs method find all relevant functions
if globs and case_pfx:
for key, value in globs.items():
if not callable(value):
@@ -200,22 +362,62 @@ def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
cases.append(value)
break
- totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}
+ for func in cases:
+ if isinstance(func, KsftCaseFunction):
+ # Parametrized test - create case for each param
+ for param in func.variants:
+ if not isinstance(param, KsftNamedVariant):
+ if not isinstance(param, tuple):
+ param = (param, )
+ param = KsftNamedVariant(None, *param)
+
+ test_cases.append((func.original_func,
+ (*args, *param.params),
+ func.name + "." + param.name))
+ else:
+ test_cases.append((func, args, func.__name__))
- print("KTAP version 1")
- print("1.." + str(len(cases)))
+ if cli_args.filters:
+ test_cases = [tc for tc in test_cases
+ if _ksft_test_enabled(tc[2], cli_args.filters)]
+
+ if cli_args.list_tests:
+ for _, _, name in test_cases:
+ print(name)
+ sys.exit(0)
+
+ return test_cases
+
+
+def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
+ cli_args = _KsftArgs()
+ test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args,
+ cli_args)
+
+ global term_cnt
+ term_cnt = 0
+ prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)
+
+ totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}
global KSFT_RESULT
+ if KSFT_RESULT is not None:
+ raise RuntimeError("ksft_run() can't be called multiple times.")
+
+ print("TAP version 13", flush=True)
+ print("1.." + str(len(test_cases)), flush=True)
+
cnt = 0
stop = False
- for case in cases:
+ for func, args, name in test_cases:
KSFT_RESULT = True
cnt += 1
comment = ""
cnt_key = ""
+ _ksft_defer_arm(True)
try:
- case(*args)
+ func(*args)
except KsftSkipEx as e:
comment = "SKIP " + str(e)
cnt_key = 'skip'
@@ -224,25 +426,38 @@ def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
cnt_key = 'xfail'
except BaseException as e:
stop |= isinstance(e, KeyboardInterrupt)
- tb = traceback.format_exc()
- for line in tb.strip().split('\n'):
- ksft_pr("Exception|", line)
+ ksft_pr(traceback.format_exc(), line_pfx="Exception|")
if stop:
- ksft_pr("Stopping tests due to KeyboardInterrupt.")
+ ksft_pr(f"Stopping tests due to {type(e).__name__}.")
KSFT_RESULT = False
cnt_key = 'fail'
+ _ksft_defer_arm(False)
- ksft_flush_defer()
+ try:
+ ksft_flush_defer()
+ except BaseException as e:
+ ksft_pr(traceback.format_exc(), line_pfx="Exception|")
+ if isinstance(e, KeyboardInterrupt):
+ ksft_pr()
+ ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
+ ksft_pr(" Attempting to finish cleanup before exiting.")
+ ksft_pr(" Interrupt again to exit immediately.")
+ ksft_pr()
+ stop = True
+ # Flush was interrupted, try to finish the job best we can
+ ksft_flush_defer()
if not cnt_key:
cnt_key = 'pass' if KSFT_RESULT else 'fail'
- ktap_result(KSFT_RESULT, cnt, case, comment=comment)
+ ktap_result(KSFT_RESULT, cnt, name, comment=comment)
totals[cnt_key] += 1
if stop:
break
+ signal.signal(signal.SIGTERM, prev_sigterm)
+
print(
f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
)
diff --git a/tools/testing/selftests/net/lib/py/netns.py b/tools/testing/selftests/net/lib/py/netns.py
index ecff85f9074f..8e9317044eef 100644
--- a/tools/testing/selftests/net/lib/py/netns.py
+++ b/tools/testing/selftests/net/lib/py/netns.py
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: GPL-2.0
from .utils import ip
+import ctypes
import random
import string
+libc = ctypes.cdll.LoadLibrary('libc.so.6')
+
class NetNS:
def __init__(self, name=None):
@@ -29,3 +32,18 @@ class NetNS:
def __repr__(self):
return f"NetNS({self.name})"
+
+
+class NetNSEnter:
+    """
+    Context manager which switches to the network namespace @ns_name
+    (looked up under /run/netns/) using setns(), and switches back to the
+    namespace that was current before when the context is left.
+
+    Raises OSError if either switch fails.
+    """
+    def __init__(self, ns_name):
+        self.ns_path = f"/run/netns/{ns_name}"
+        self.saved = None
+
+    @staticmethod
+    def _setns(ns_file, what):
+        # ctypes does not turn C-level failures into exceptions, so check
+        # the return code or we'd silently keep running in the wrong netns.
+        if libc.setns(ns_file.fileno(), 0) != 0:
+            raise OSError(f"setns() failed for {what}")
+
+    def __enter__(self):
+        self.saved = open("/proc/thread-self/ns/net")
+        try:
+            with open(self.ns_path) as ns_file:
+                self._setns(ns_file, self.ns_path)
+        except BaseException:
+            # __exit__ is not called when __enter__ raises, don't leak the fd
+            self.saved.close()
+            raise
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        try:
+            self._setns(self.saved, "the original netns")
+        finally:
+            self.saved.close()
diff --git a/tools/testing/selftests/net/lib/py/nsim.py b/tools/testing/selftests/net/lib/py/nsim.py
index 1a8cbe9acc48..7c640ed64c0b 100644
--- a/tools/testing/selftests/net/lib/py/nsim.py
+++ b/tools/testing/selftests/net/lib/py/nsim.py
@@ -27,7 +27,7 @@ class NetdevSim:
self.port_index = port_index
self.ns = ns
self.dfs_dir = "%s/ports/%u/" % (nsimdev.dfs_dir, port_index)
- ret = ip("-j link show dev %s" % ifname, ns=ns)
+ ret = ip("-d -j link show dev %s" % ifname, ns=ns)
self.dev = json.loads(ret.stdout)[0]
self.ifindex = self.dev["ifindex"]
diff --git a/tools/testing/selftests/net/lib/py/utils.py b/tools/testing/selftests/net/lib/py/utils.py
index 9e3bcddcf3e8..6c44a3d2bbf7 100644
--- a/tools/testing/selftests/net/lib/py/utils.py
+++ b/tools/testing/selftests/net/lib/py/utils.py
@@ -1,80 +1,193 @@
# SPDX-License-Identifier: GPL-2.0
-import errno
import json as _json
-import random
+import os
import re
+import select
import socket
import subprocess
import time
+class CmdInitFailure(Exception):
+ """ Command failed to start. Only raised by bkg(). """
+ def __init__(self, msg, cmd_obj):
+ super().__init__(msg + "\n" + repr(cmd_obj))
+ self.cmd = cmd_obj
+
+
class CmdExitFailure(Exception):
+ """ Command failed (returned non-zero exit code). """
def __init__(self, msg, cmd_obj):
- super().__init__(msg)
+ super().__init__(msg + "\n" + repr(cmd_obj))
self.cmd = cmd_obj
+def fd_read_timeout(fd, timeout):
+ rlist, _, _ = select.select([fd], [], [], timeout)
+ if rlist:
+ return os.read(fd, 1024)
+ raise TimeoutError("Timeout waiting for fd read")
+
+
class cmd:
- def __init__(self, comm, shell=True, fail=True, ns=None, background=False, host=None, timeout=5):
+ """
+ Execute a command on local or remote host.
+
+ @shell defaults to false, and class will try to split @comm into a list
+ if it's a string with spaces.
+
+ Use bkg() instead to run a command in the background.
+ """
+ def __init__(self, comm, shell=None, fail=True, ns=None, background=False,
+ host=None, timeout=5, ksft_ready=None, ksft_wait=None):
if ns:
comm = f'ip netns exec {ns} ' + comm
self.stdout = None
self.stderr = None
self.ret = None
+ self.ksft_term_fd = None
+ self.host = host
self.comm = comm
+
if host:
self.proc = host.cmd(comm)
else:
+ # If user doesn't explicitly request shell try to avoid it.
+ if shell is None and isinstance(comm, str) and ' ' in comm:
+ comm = comm.split()
+
+ # ksft_wait lets us wait for the background process to fully start,
+ # we pass an FD to the child process, and wait for it to write back.
+ # Similarly term_fd tells child it's time to exit.
+ pass_fds = []
+ env = os.environ.copy()
+ if ksft_wait is not None:
+ wait_fd, self.ksft_term_fd = os.pipe()
+ pass_fds.append(wait_fd)
+ env["KSFT_WAIT_FD"] = str(wait_fd)
+ ksft_ready = True # ksft_wait implies ready
+ if ksft_ready is not None:
+ rfd, ready_fd = os.pipe()
+ pass_fds.append(ready_fd)
+ env["KSFT_READY_FD"] = str(ready_fd)
+
self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
+ stderr=subprocess.PIPE, pass_fds=pass_fds,
+ env=env)
+ if ksft_wait is not None:
+ os.close(wait_fd)
+ if ksft_ready is not None:
+ os.close(ready_fd)
+ msg = fd_read_timeout(rfd, ksft_wait)
+ os.close(rfd)
+ if not msg:
+ terminate = self.proc.poll() is None
+ self._process_terminate(terminate=terminate, timeout=1)
+ raise CmdInitFailure("Did not receive ready message", self)
if not background:
self.process(terminate=False, fail=fail, timeout=timeout)
- def process(self, terminate=True, fail=None, timeout=5):
- if fail is None:
- fail = not terminate
-
+ def _process_terminate(self, terminate, timeout):
if terminate:
self.proc.terminate()
- stdout, stderr = self.proc.communicate(timeout)
+ stdout, stderr = self.proc.communicate(timeout=timeout)
self.stdout = stdout.decode("utf-8")
self.stderr = stderr.decode("utf-8")
self.proc.stdout.close()
self.proc.stderr.close()
self.ret = self.proc.returncode
+ return stdout, stderr
+
+ def process(self, terminate=True, fail=None, timeout=5):
+ if fail is None:
+ fail = not terminate
+
+ if self.ksft_term_fd:
+ os.write(self.ksft_term_fd, b"1")
+
+ stdout, stderr = self._process_terminate(terminate=terminate,
+ timeout=timeout)
if self.proc.returncode != 0 and fail:
if len(stderr) > 0 and stderr[-1] == "\n":
stderr = stderr[:-1]
- raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" %
- (self.proc.args, stdout, stderr), self)
+ raise CmdExitFailure("Command failed", self)
+
+ def __repr__(self):
+ def str_fmt(name, s):
+ name += ': '
+ return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))
+
+ ret = "CMD"
+ if self.host:
+ ret += "[remote]"
+ if self.ret is None:
+ ret += f" (unterminated): {self.comm}\n"
+ elif self.ret == 0:
+ ret += f" (success): {self.comm}\n"
+ else:
+ ret += f": {self.comm}\n"
+ ret += f" EXIT: {self.ret}\n"
+ if self.stdout:
+ ret += str_fmt(" STDOUT", self.stdout) + "\n"
+ if self.stderr:
+ ret += str_fmt(" STDERR", self.stderr) + "\n"
+ return ret.strip()
class bkg(cmd):
- def __init__(self, comm, shell=True, fail=None, ns=None, host=None,
- exit_wait=False):
+ """
+ Run a command in the background.
+
+ Examples usage:
+
+ Run a command on remote host, and wait for it to finish.
+ This is usually paired with wait_port_listen() to make sure
+ the command has initialized:
+
+ with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
+ ...
+
+ Run a command and expect it to let us know that it's ready
+ by writing to a special file descriptor passed via KSFT_READY_FD.
+ Command will be terminated when we exit the context manager:
+
+ with bkg("my_binary", ksft_wait=5):
+ """
+ def __init__(self, comm, shell=None, fail=None, ns=None, host=None,
+ exit_wait=False, ksft_ready=None, ksft_wait=None):
super().__init__(comm, background=True,
- shell=shell, fail=fail, ns=ns, host=host)
- self.terminate = not exit_wait
+ shell=shell, fail=fail, ns=ns, host=host,
+ ksft_ready=ksft_ready, ksft_wait=ksft_wait)
+ self.terminate = not exit_wait and not ksft_wait
+ self._exit_wait = exit_wait
self.check_fail = fail
+ if shell and self.terminate:
+ print("# Warning: combining shell and terminate is risky!")
+ print("# SIGTERM may not reach the child on zsh/ksh!")
+
def __enter__(self):
return self
def __exit__(self, ex_type, ex_value, ex_tb):
- return self.process(terminate=self.terminate, fail=self.check_fail)
+ terminate = self.terminate
+ # Force termination on exception, but only if bkg() didn't already exit
+ # since forcing termination silences failures with fail=None
+ if self.proc.poll() is None:
+ terminate = terminate or (self._exit_wait and ex_type is not None)
+ return self.process(terminate=terminate, fail=self.check_fail)
-global_defer_queue = []
+GLOBAL_DEFER_QUEUE = []
+GLOBAL_DEFER_ARMED = False
class defer:
def __init__(self, func, *args, **kwargs):
- global global_defer_queue
-
if not callable(func):
raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")
@@ -82,7 +195,9 @@ class defer:
self.args = args
self.kwargs = kwargs
- self._queue = global_defer_queue
+ if not GLOBAL_DEFER_ARMED:
+ raise Exception("defer queue not armed, did you use defer() outside of a test case?")
+ self._queue = GLOBAL_DEFER_QUEUE
self._queue.append(self)
def __enter__(self):
@@ -113,6 +228,10 @@ def tool(name, args, json=None, ns=None, host=None):
return cmd_obj
+def bpftool(args, json=None, ns=None, host=None):
+ return tool('bpftool', args, json=json, ns=ns, host=host)
+
+
def ip(args, json=None, ns=None, host=None):
if ns:
args = f'-netns {ns} ' + args
@@ -123,20 +242,67 @@ def ethtool(args, json=None, ns=None, host=None):
return tool('ethtool', args, json=json, ns=ns, host=host)
-def rand_port():
+def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
+ """
+ Run bpftrace and return map data (if json=True).
+ The output of bpftrace is inconvenient, so the helper converts
+ to a dict indexed by map name, e.g.:
+ {
+ "@": { ... },
+ "@map2": { ... },
+ }
+ """
+ cmd_arr = ['bpftrace']
+ # Throw in --quiet if json, otherwise the output has two objects
+ if json:
+ cmd_arr += ['-f', 'json', '-q']
+ if timeout:
+ expr += ' interval:s:' + str(timeout) + ' { exit(); }'
+ timeout += 20
+ cmd_arr += ['-e', expr]
+ cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout)
+ if json:
+ # bpftrace prints objects as lines
+ ret = {}
+ for l in cmd_obj.stdout.split('\n'):
+ if not l.strip():
+ continue
+ one = _json.loads(l)
+ if one.get('type') != 'map':
+ continue
+ for k, v in one["data"].items():
+ if k.startswith('@'):
+ k = k.lstrip('@')
+ ret[k] = v
+ return ret
+ return cmd_obj
+
+
+def rand_port(stype=socket.SOCK_STREAM):
+ """
+ Get a random unprivileged port.
"""
- Get a random unprivileged port, try to make sure it's not already used.
+ return rand_ports(1, stype)[0]
+
+
+def rand_ports(count, stype=socket.SOCK_STREAM):
"""
- for _ in range(1000):
- port = random.randint(10000, 65535)
- try:
- with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
- s.bind(("", port))
- return port
- except OSError as e:
- if e.errno != errno.EADDRINUSE:
- raise
- raise Exception("Can't find any free unprivileged port")
+ Get a unique set of random unprivileged ports.
+ """
+ sockets = []
+ ports = []
+
+ try:
+ for _ in range(count):
+ s = socket.socket(socket.AF_INET6, stype)
+ sockets.append(s)
+ s.bind(("", 0))
+ ports.append(s.getsockname()[1])
+ finally:
+ for s in sockets:
+ s.close()
+
+ return ports
def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
@@ -155,3 +321,21 @@ def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadlin
if time.monotonic() > end:
raise Exception("Waiting for port listen timed out")
time.sleep(sleep)
+
+
+def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
+ """
+ Wait for file contents on the local system to satisfy a condition.
+ test_fn() should take one argument (file contents) and return whether
+ condition is met.
+ """
+ end = time.monotonic() + deadline
+
+ with open(fname, "r", encoding=encoding) as fp:
+ while True:
+ if test_fn(fp.read()):
+ break
+ fp.seek(0)
+ if time.monotonic() > end:
+ raise TimeoutError("Wait for file contents failed", fname)
+ time.sleep(sleep)
diff --git a/tools/testing/selftests/net/lib/py/ynl.py b/tools/testing/selftests/net/lib/py/ynl.py
index ad1e36baee2a..2e567062aa6c 100644
--- a/tools/testing/selftests/net/lib/py/ynl.py
+++ b/tools/testing/selftests/net/lib/py/ynl.py
@@ -13,20 +13,27 @@ try:
SPEC_PATH = KSFT_DIR / "net/lib/specs"
sys.path.append(tools_full_path.as_posix())
- from net.lib.ynl.pyynl.lib import YnlFamily, NlError
+ from net.lib.ynl.pyynl.lib import YnlFamily, NlError, NlPolicy, Netlink
else:
# Running in tree
tools_full_path = KSRC / "tools"
SPEC_PATH = KSRC / "Documentation/netlink/specs"
sys.path.append(tools_full_path.as_posix())
- from net.ynl.pyynl.lib import YnlFamily, NlError
+ from net.ynl.pyynl.lib import YnlFamily, NlError, NlPolicy, Netlink
except ModuleNotFoundError as e:
ksft_pr("Failed importing `ynl` library from kernel sources")
ksft_pr(str(e))
ktap_result(True, comment="SKIP")
sys.exit(4)
+__all__ = [
+ "NlError", "NlPolicy", "Netlink", "YnlFamily", "SPEC_PATH",
+ "EthtoolFamily", "RtnlFamily", "RtnlAddrFamily",
+ "NetdevFamily", "NetshaperFamily", "NlctrlFamily", "DevlinkFamily",
+ "PSPFamily",
+]
+
#
# Wrapper classes, loading the right specs
# Set schema='' to avoid jsonschema validation, it's slow
@@ -39,9 +46,13 @@ class EthtoolFamily(YnlFamily):
class RtnlFamily(YnlFamily):
def __init__(self, recv_size=0):
- super().__init__((SPEC_PATH / Path('rt_link.yaml')).as_posix(),
+ super().__init__((SPEC_PATH / Path('rt-link.yaml')).as_posix(),
schema='', recv_size=recv_size)
+class RtnlAddrFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('rt-addr.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
class NetdevFamily(YnlFamily):
def __init__(self, recv_size=0):
@@ -52,3 +63,20 @@ class NetshaperFamily(YnlFamily):
def __init__(self, recv_size=0):
super().__init__((SPEC_PATH / Path('net_shaper.yaml')).as_posix(),
schema='', recv_size=recv_size)
+
+
+class NlctrlFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('nlctrl.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
+
+
+class DevlinkFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('devlink.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
+
+class PSPFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('psp.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
diff --git a/tools/testing/selftests/net/lib/sh/defer.sh b/tools/testing/selftests/net/lib/sh/defer.sh
index 082f5d38321b..47ab78c4d465 100644
--- a/tools/testing/selftests/net/lib/sh/defer.sh
+++ b/tools/testing/selftests/net/lib/sh/defer.sh
@@ -1,6 +1,10 @@
#!/bin/bash
# SPDX-License-Identifier: GPL-2.0
+# Whether to pause and allow debugging when an executed deferred command has a
+# non-zero exit code.
+: "${DEFER_PAUSE_ON_FAIL:=no}"
+
# map[(scope_id,track,cleanup_id) -> cleanup_command]
# track={d=default | p=priority}
declare -A __DEFER__JOBS
@@ -38,8 +42,20 @@ __defer__run()
local track=$1; shift
local defer_ix=$1; shift
local defer_key=$(__defer__defer_key $track $defer_ix)
+ local ret
+
+ eval ${__DEFER__JOBS[$defer_key]}
+ ret=$?
+
+ if [[ "$DEFER_PAUSE_ON_FAIL" == yes && "$ret" -ne 0 ]]; then
+ echo "Deferred command (track $track index $defer_ix):"
+ echo " ${__DEFER__JOBS[$defer_key]}"
+ echo "... ended with an exit status of $ret"
+ echo "Hit enter to continue, 'q' to quit"
+ read a
+ [[ "$a" == q ]] && exit 1
+ fi
- ${__DEFER__JOBS[$defer_key]}
unset __DEFER__JOBS[$defer_key]
}
@@ -49,7 +65,7 @@ __defer__schedule()
local ndefers=$(__defer__ndefers $track)
local ndefers_key=$(__defer__ndefer_key $track)
local defer_key=$(__defer__defer_key $track $ndefers)
- local defer="$@"
+ local defer="${@@Q}"
__DEFER__JOBS[$defer_key]="$defer"
__DEFER__NJOBS[$ndefers_key]=$((ndefers + 1))
diff --git a/tools/testing/selftests/net/lib/xdp_dummy.bpf.c b/tools/testing/selftests/net/lib/xdp_dummy.bpf.c
new file mode 100644
index 000000000000..e73fab3edd9f
--- /dev/null
+++ b/tools/testing/selftests/net/lib/xdp_dummy.bpf.c
@@ -0,0 +1,19 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#define KBUILD_MODNAME "xdp_dummy"
+#include <linux/bpf.h>
+#include <bpf/bpf_helpers.h>
+/* No-op XDP program: unconditionally returns XDP_PASS. */
+SEC("xdp")
+int xdp_dummy_prog(struct xdp_md *ctx)
+{
+ return XDP_PASS;
+}
+/* No-op program placed in the "xdp.frags" section; also returns XDP_PASS. */
+SEC("xdp.frags")
+int xdp_dummy_prog_frags(struct xdp_md *ctx)
+{
+ return XDP_PASS;
+}
+
+char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/net/lib/xdp_helper.c b/tools/testing/selftests/net/lib/xdp_helper.c
new file mode 100644
index 000000000000..eb025a9f35b1
--- /dev/null
+++ b/tools/testing/selftests/net/lib/xdp_helper.c
@@ -0,0 +1,131 @@
+// SPDX-License-Identifier: GPL-2.0
+#include <errno.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <unistd.h>
+#include <sys/mman.h>
+#include <sys/socket.h>
+#include <linux/if_xdp.h>
+#include <linux/if_link.h>
+#include <net/if.h>
+#include <inttypes.h>
+
+#include "ksft.h"
+
+#define UMEM_SZ (1U << 16)
+#define NUM_DESC (UMEM_SZ / 2048)
+
+
+/* Print the command line synopsis for @bin on stderr. */
+static void print_usage(const char *bin)
+{
+	/* Terminate the last line so the shell prompt (or the next log line)
+	 * does not end up glued to the usage text.
+	 */
+	fprintf(stderr, "Usage: %s ifindex queue_id [-z]\n\n"
+		"where:\n\t-z: force zerocopy mode\n", bin);
+}
+
+/* this is a simple helper program that creates an XDP socket and does the
+ * minimum necessary to get bind() to succeed.
+ *
+ * this test program is not intended to actually process packets, but could be
+ * extended in the future if that is actually needed.
+ *
+ * it is used by queues.py to ensure the xsk netlink attribute is set
+ * correctly.
+ */
+int main(int argc, char **argv)
+{
+	struct xdp_umem_reg umem_reg = { 0 };
+	struct sockaddr_xdp sxdp = { 0 };
+	void *umem_area = MAP_FAILED;
+	int num_desc = NUM_DESC;
+	int saved_errno;
+	int retry = 0;
+	int sock_fd;
+	int ret = 1;
+
+	if (argc != 3 && argc != 4) {
+		print_usage(argv[0]);
+		return 1;
+	}
+
+	sock_fd = socket(AF_XDP, SOCK_RAW, 0);
+	if (sock_fd < 0) {
+		/* perror() is allowed to modify errno, sample it first */
+		saved_errno = errno;
+		perror("socket creation failed");
+		/* if the kernel doesn't support AF_XDP, let the test program
+		 * know with -1. All other error paths return 1.
+		 */
+		if (saved_errno == EAFNOSUPPORT)
+			return -1;
+		return 1;
+	}
+
+	/* "Probing mode", just checking if AF_XDP sockets are supported */
+	if (!strcmp(argv[1], "-") && !strcmp(argv[2], "-")) {
+		printf("AF_XDP support detected\n");
+		ret = 0;
+		goto out;
+	}
+
+	umem_area = mmap(NULL, UMEM_SZ, PROT_READ | PROT_WRITE, MAP_PRIVATE |
+			 MAP_ANONYMOUS, -1, 0);
+	if (umem_area == MAP_FAILED) {
+		perror("mmap failed");
+		goto out;
+	}
+
+	umem_reg.addr = (uintptr_t)umem_area;
+	umem_reg.len = UMEM_SZ;
+	umem_reg.chunk_size = 2048;
+	umem_reg.headroom = 0;
+
+	/* Report the step that failed instead of letting bind() fail later
+	 * with a less specific error.
+	 */
+	if (setsockopt(sock_fd, SOL_XDP, XDP_UMEM_REG, &umem_reg,
+		       sizeof(umem_reg)) < 0) {
+		perror("setsockopt(XDP_UMEM_REG) failed");
+		goto out;
+	}
+	if (setsockopt(sock_fd, SOL_XDP, XDP_UMEM_FILL_RING, &num_desc,
+		       sizeof(num_desc)) < 0) {
+		perror("setsockopt(XDP_UMEM_FILL_RING) failed");
+		goto out;
+	}
+	if (setsockopt(sock_fd, SOL_XDP, XDP_UMEM_COMPLETION_RING, &num_desc,
+		       sizeof(num_desc)) < 0) {
+		perror("setsockopt(XDP_UMEM_COMPLETION_RING) failed");
+		goto out;
+	}
+	if (setsockopt(sock_fd, SOL_XDP, XDP_RX_RING, &num_desc,
+		       sizeof(num_desc)) < 0) {
+		perror("setsockopt(XDP_RX_RING) failed");
+		goto out;
+	}
+
+	sxdp.sxdp_family = AF_XDP;
+	sxdp.sxdp_ifindex = atoi(argv[1]);
+	sxdp.sxdp_queue_id = atoi(argv[2]);
+	sxdp.sxdp_flags = 0;
+
+	if (argc > 3) {
+		if (strcmp(argv[3], "-z")) {
+			print_usage(argv[0]);
+			goto out;
+		}
+		sxdp.sxdp_flags = XDP_ZEROCOPY;
+	}
+
+	/* Retry a few times while the queue is still busy (EBUSY) */
+	while (bind(sock_fd, (struct sockaddr *)&sxdp, sizeof(sxdp)) != 0) {
+		if (errno != EBUSY || retry >= 3) {
+			perror("bind failed");
+			goto out;
+		}
+		retry++;
+		sleep(1);
+	}
+
+	ksft_ready();
+	/* parent program will write a byte to stdin when it's ready for this
+	 * helper to exit
+	 */
+	ksft_wait();
+
+	ret = 0;
+out:
+	if (umem_area != MAP_FAILED)
+		munmap(umem_area, UMEM_SZ);
+	close(sock_fd);
+	return ret;
+}
diff --git a/tools/testing/selftests/net/lib/xdp_metadata.bpf.c b/tools/testing/selftests/net/lib/xdp_metadata.bpf.c
new file mode 100644
index 000000000000..f71f59215239
--- /dev/null
+++ b/tools/testing/selftests/net/lib/xdp_metadata.bpf.c
@@ -0,0 +1,163 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <stddef.h>
+#include <linux/bpf.h>
+#include <linux/in.h>
+#include <linux/if_ether.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/udp.h>
+#include <linux/tcp.h>
+#include <bpf/bpf_endian.h>
+#include <bpf/bpf_helpers.h>
+
+enum {
+ XDP_PORT = 1,
+ XDP_PROTO = 4,
+} xdp_map_setup_keys;
+
+struct {
+ __uint(type, BPF_MAP_TYPE_ARRAY);
+ __uint(max_entries, 5);
+ __type(key, __u32);
+ __type(value, __s32);
+} map_xdp_setup SEC(".maps");
+
+/* RSS hash results: key 0 = hash, key 1 = hash type,
+ * key 2 = packet count, key 3 = error count.
+ */
+enum {
+ RSS_KEY_HASH = 0,
+ RSS_KEY_TYPE = 1,
+ RSS_KEY_PKT_CNT = 2,
+ RSS_KEY_ERR_CNT = 3,
+};
+
+struct {
+ __uint(type, BPF_MAP_TYPE_ARRAY);
+ __type(key, __u32);
+ __type(value, __u32);
+ __uint(max_entries, 4);
+} map_rss SEC(".maps");
+
+/* Mirror of enum xdp_rss_hash_type from include/net/xdp.h.
+ * Needed because the enum is not part of UAPI headers.
+ */
+enum xdp_rss_hash_type {
+ XDP_RSS_L3_IPV4 = 1U << 0,
+ XDP_RSS_L3_IPV6 = 1U << 1,
+ XDP_RSS_L3_DYNHDR = 1U << 2,
+ XDP_RSS_L4 = 1U << 3,
+ XDP_RSS_L4_TCP = 1U << 4,
+ XDP_RSS_L4_UDP = 1U << 5,
+ XDP_RSS_L4_SCTP = 1U << 6,
+ XDP_RSS_L4_IPSEC = 1U << 7,
+ XDP_RSS_L4_ICMP = 1U << 8,
+};
+
+extern int bpf_xdp_metadata_rx_hash(const struct xdp_md *ctx, __u32 *hash,
+ enum xdp_rss_hash_type *rss_type) __ksym;
+
+/* Return the destination port field of the UDP or TCP header at @l4, exactly
+ * as stored in the packet (no byte swap). Returns 0 when @protocol is neither
+ * UDP nor TCP, or when the header would extend past @data_end.
+ */
+static __always_inline __u16 get_dest_port(void *l4, void *data_end,
+					   __u8 protocol)
+{
+	struct udphdr *udp_hdr;
+	struct tcphdr *tcp_hdr;
+
+	switch (protocol) {
+	case IPPROTO_UDP:
+		udp_hdr = l4;
+		if ((void *)(udp_hdr + 1) > data_end)
+			return 0;
+		return udp_hdr->dest;
+	case IPPROTO_TCP:
+		tcp_hdr = l4;
+		if ((void *)(tcp_hdr + 1) > data_end)
+			return 0;
+		return tcp_hdr->dest;
+	default:
+		return 0;
+	}
+}
+/* Record the RX hash reported by bpf_xdp_metadata_rx_hash() in map_rss.
+ *
+ * Only IPv4 and IPv6 frames are considered. Packets can additionally be
+ * filtered by L4 protocol and destination port through map_xdp_setup. On
+ * success the hash, hash type and packet counter entries are updated; when
+ * the kfunc returns a negative value only the error counter is incremented.
+ * Every path returns XDP_PASS.
+ */
+SEC("xdp")
+int xdp_rss_hash(struct xdp_md *ctx)
+{
+ void *data_end = (void *)(long)ctx->data_end;
+ void *data = (void *)(long)ctx->data;
+ enum xdp_rss_hash_type rss_type = 0;
+ struct ethhdr *eth = data;
+ __u8 l4_proto = 0;
+ __u32 hash = 0;
+ __u32 key, val;
+ void *l4 = NULL;
+ __u32 *cnt;
+ int ret;
+
+ if ((void *)(eth + 1) > data_end)
+ return XDP_PASS;
+
+ if (eth->h_proto == bpf_htons(ETH_P_IP)) {
+ struct iphdr *iph = (void *)(eth + 1);
+
+ if ((void *)(iph + 1) > data_end)
+ return XDP_PASS;
+ l4_proto = iph->protocol;
+ l4 = (void *)(iph + 1);
+ } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
+ struct ipv6hdr *ip6h = (void *)(eth + 1);
+
+ if ((void *)(ip6h + 1) > data_end)
+ return XDP_PASS;
+ l4_proto = ip6h->nexthdr;
+ l4 = (void *)(ip6h + 1);
+ }
+
+ if (!l4)
+ return XDP_PASS;
+
+ /* Filter on the configured protocol (map_xdp_setup key XDP_PROTO).
+ * When set, only process packets matching the requested L4 protocol.
+ * A stored value of 0 disables the filter.
+ */
+ key = XDP_PROTO;
+ __s32 *proto_cfg = bpf_map_lookup_elem(&map_xdp_setup, &key);
+
+ if (proto_cfg && *proto_cfg != 0 && l4_proto != (__u8)*proto_cfg)
+ return XDP_PASS;
+
+ /* Filter on the configured port (map_xdp_setup key XDP_PORT).
+ * Only applies to protocols with ports (UDP, TCP).
+ * get_dest_port() returns 0 for anything else, so such packets are
+ * skipped whenever a port is configured.
+ */
+ key = XDP_PORT;
+ __s32 *port_cfg = bpf_map_lookup_elem(&map_xdp_setup, &key);
+
+ if (port_cfg && *port_cfg != 0) {
+ __u16 dest = get_dest_port(l4, data_end, l4_proto);
+
+ if (!dest || bpf_ntohs(dest) != (__u16)*port_cfg)
+ return XDP_PASS;
+ }
+
+ ret = bpf_xdp_metadata_rx_hash(ctx, &hash, &rss_type);
+ if (ret < 0) {
+ key = RSS_KEY_ERR_CNT;
+ cnt = bpf_map_lookup_elem(&map_rss, &key);
+ if (cnt)
+ __sync_fetch_and_add(cnt, 1);
+ return XDP_PASS;
+ }
+
+ key = RSS_KEY_HASH;
+ bpf_map_update_elem(&map_rss, &key, &hash, BPF_ANY);
+
+ key = RSS_KEY_TYPE;
+ val = (__u32)rss_type;
+ bpf_map_update_elem(&map_rss, &key, &val, BPF_ANY);
+
+ key = RSS_KEY_PKT_CNT;
+ cnt = bpf_map_lookup_elem(&map_rss, &key);
+ if (cnt)
+ __sync_fetch_and_add(cnt, 1);
+
+ return XDP_PASS;
+}
+
+char _license[] SEC("license") = "GPL";
diff --git a/tools/testing/selftests/net/lib/xdp_native.bpf.c b/tools/testing/selftests/net/lib/xdp_native.bpf.c
new file mode 100644
index 000000000000..ded3f896e622
--- /dev/null
+++ b/tools/testing/selftests/net/lib/xdp_native.bpf.c
@@ -0,0 +1,685 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <stddef.h>
+#include <linux/bpf.h>
+#include <linux/in.h>
+#include <linux/if_ether.h>
+#include <linux/ip.h>
+#include <linux/ipv6.h>
+#include <linux/udp.h>
+#include <bpf/bpf_endian.h>
+#include <bpf/bpf_helpers.h>
+
+#define MAX_ADJST_OFFSET 256
+#define MAX_PAYLOAD_LEN 5000
+#define MAX_HDR_LEN 64
+
+extern int bpf_xdp_pull_data(struct xdp_md *xdp, __u32 len) __ksym __weak;
+
+enum {
+ XDP_MODE = 0,
+ XDP_PORT = 1,
+ XDP_ADJST_OFFSET = 2,
+ XDP_ADJST_TAG = 3,
+} xdp_map_setup_keys;
+
+enum {
+ XDP_MODE_PASS = 0,
+ XDP_MODE_DROP = 1,
+ XDP_MODE_TX = 2,
+ XDP_MODE_TAIL_ADJST = 3,
+ XDP_MODE_HEAD_ADJST = 4,
+} xdp_map_modes;
+
+enum {
+ STATS_RX = 0,
+ STATS_PASS = 1,
+ STATS_DROP = 2,
+ STATS_TX = 3,
+ STATS_ABORT = 4,
+} xdp_stats;
+
+struct {
+ __uint(type, BPF_MAP_TYPE_ARRAY);
+ __uint(max_entries, 5);
+ __type(key, __u32);
+ __type(value, __s32);
+} map_xdp_setup SEC(".maps");
+
+struct {
+ __uint(type, BPF_MAP_TYPE_ARRAY);
+ __uint(max_entries, 5);
+ __type(key, __u32);
+ __type(value, __u64);
+} map_xdp_stats SEC(".maps");
+
+/* Return the smaller of the two unsigned 32-bit values @a and @b. */
+static __u32 min(__u32 a, __u32 b)
+{
+	if (b <= a)
+		return b;
+
+	return a;
+}
+
+/* Atomically increment the map_xdp_stats entry indexed by @stat_type.
+ * Nothing happens when the lookup fails. @ctx is not used.
+ */
+static void record_stats(struct xdp_md *ctx, __u32 stat_type)
+{
+	__u64 *counter = bpf_map_lookup_elem(&map_xdp_stats, &stat_type);
+
+	if (!counter)
+		return;
+
+	__sync_fetch_and_add(counter, 1);
+}
+/* Return the UDP header when the frame is an IPv4 or IPv6 UDP packet whose
+ * destination port equals @port, NULL otherwise.
+ *
+ * bpf_xdp_pull_data() is called before each group of headers is read, and
+ * the data/data_end pointers are re-read from @ctx after every call. The
+ * UDP header is located at a fixed offset, i.e. IPv4 options and IPv6
+ * extension headers are not skipped. STATS_RX is recorded for every match.
+ */
+static struct udphdr *filter_udphdr(struct xdp_md *ctx, __u16 port)
+{
+ struct udphdr *udph = NULL;
+ void *data, *data_end;
+ struct ethhdr *eth;
+ int err;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth));
+ if (err)
+ return NULL;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = eth = (void *)(long)ctx->data;
+
+ if (data + sizeof(*eth) > data_end)
+ return NULL;
+
+ if (eth->h_proto == bpf_htons(ETH_P_IP)) {
+ struct iphdr *iph;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) +
+ sizeof(*udph));
+ if (err)
+ return NULL;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = (void *)(long)ctx->data;
+
+ iph = data + sizeof(*eth);
+
+ if (iph + 1 > (struct iphdr *)data_end ||
+ iph->protocol != IPPROTO_UDP)
+ return NULL;
+
+ udph = data + sizeof(*iph) + sizeof(*eth);
+ } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
+ struct ipv6hdr *ipv6h;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) +
+ sizeof(*udph));
+ if (err)
+ return NULL;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = (void *)(long)ctx->data;
+
+ ipv6h = data + sizeof(*eth);
+
+ if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
+ ipv6h->nexthdr != IPPROTO_UDP)
+ return NULL;
+
+ udph = data + sizeof(*ipv6h) + sizeof(*eth);
+ } else {
+ return NULL;
+ }
+
+ if (udph + 1 > (struct udphdr *)data_end)
+ return NULL;
+
+ if (udph->dest != bpf_htons(port))
+ return NULL;
+
+ record_stats(ctx, STATS_RX);
+
+ return udph;
+}
+
+/* XDP_MODE_PASS handler: account packets matched by filter_udphdr() under
+ * STATS_PASS. XDP_PASS is returned whether or not the packet matched.
+ */
+static int xdp_mode_pass(struct xdp_md *ctx, __u16 port)
+{
+	if (filter_udphdr(ctx, port))
+		record_stats(ctx, STATS_PASS);
+
+	return XDP_PASS;
+}
+
+/* XDP_MODE_DROP handler: packets matched by filter_udphdr() are accounted
+ * under STATS_DROP and get XDP_DROP; all other packets get XDP_PASS.
+ */
+static int xdp_mode_drop_handler(struct xdp_md *ctx, __u16 port)
+{
+	struct udphdr *match = filter_udphdr(ctx, port);
+
+	if (match) {
+		record_stats(ctx, STATS_DROP);
+		return XDP_DROP;
+	}
+
+	return XDP_PASS;
+}
+
+/* Exchange the source and destination MAC addresses of the Ethernet header
+ * that @data points to. The caller must have bounds-checked the header.
+ */
+static void swap_machdr(void *data)
+{
+	struct ethhdr *hdr = data;
+	__u8 saved_src[ETH_ALEN];
+
+	__builtin_memcpy(saved_src, hdr->h_source, sizeof(saved_src));
+	__builtin_memcpy(hdr->h_source, hdr->h_dest, sizeof(saved_src));
+	__builtin_memcpy(hdr->h_dest, saved_src, sizeof(saved_src));
+}
+/* XDP_MODE_TX handler: swap the addresses of matching UDP packets and
+ * return XDP_TX.
+ *
+ * For IPv4 or IPv6 UDP packets whose destination port equals @port, the
+ * MAC addresses and the IP source/destination addresses are swapped and
+ * both STATS_RX and STATS_TX are recorded. Everything else, including
+ * packets for which bpf_xdp_pull_data() fails, gets XDP_PASS. UDP ports
+ * and checksums are not modified here.
+ */
+static int xdp_mode_tx_handler(struct xdp_md *ctx, __u16 port)
+{
+ struct udphdr *udph = NULL;
+ void *data, *data_end;
+ struct ethhdr *eth;
+ int err;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth));
+ if (err)
+ return XDP_PASS;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = eth = (void *)(long)ctx->data;
+
+ if (data + sizeof(*eth) > data_end)
+ return XDP_PASS;
+
+ if (eth->h_proto == bpf_htons(ETH_P_IP)) {
+ struct iphdr *iph;
+ __be32 tmp_ip;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*iph) +
+ sizeof(*udph));
+ if (err)
+ return XDP_PASS;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = (void *)(long)ctx->data;
+
+ iph = data + sizeof(*eth);
+
+ if (iph + 1 > (struct iphdr *)data_end ||
+ iph->protocol != IPPROTO_UDP)
+ return XDP_PASS;
+
+ udph = data + sizeof(*iph) + sizeof(*eth);
+
+ if (udph + 1 > (struct udphdr *)data_end)
+ return XDP_PASS;
+ if (udph->dest != bpf_htons(port))
+ return XDP_PASS;
+
+ record_stats(ctx, STATS_RX);
+ eth = data;
+ swap_machdr((void *)eth);
+
+ tmp_ip = iph->saddr;
+ iph->saddr = iph->daddr;
+ iph->daddr = tmp_ip;
+
+ record_stats(ctx, STATS_TX);
+
+ return XDP_TX;
+
+ } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
+ struct in6_addr tmp_ipv6;
+ struct ipv6hdr *ipv6h;
+
+ err = bpf_xdp_pull_data(ctx, sizeof(*eth) + sizeof(*ipv6h) +
+ sizeof(*udph));
+ if (err)
+ return XDP_PASS;
+
+ data_end = (void *)(long)ctx->data_end;
+ data = (void *)(long)ctx->data;
+
+ ipv6h = data + sizeof(*eth);
+
+ if (ipv6h + 1 > (struct ipv6hdr *)data_end ||
+ ipv6h->nexthdr != IPPROTO_UDP)
+ return XDP_PASS;
+
+ udph = data + sizeof(*ipv6h) + sizeof(*eth);
+
+ if (udph + 1 > (struct udphdr *)data_end)
+ return XDP_PASS;
+ if (udph->dest != bpf_htons(port))
+ return XDP_PASS;
+
+ record_stats(ctx, STATS_RX);
+ eth = data;
+ swap_machdr((void *)eth);
+
+ __builtin_memcpy(&tmp_ipv6, &ipv6h->saddr, sizeof(tmp_ipv6));
+ __builtin_memcpy(&ipv6h->saddr, &ipv6h->daddr,
+ sizeof(tmp_ipv6));
+ __builtin_memcpy(&ipv6h->daddr, &tmp_ipv6, sizeof(tmp_ipv6));
+
+ record_stats(ctx, STATS_TX);
+
+ return XDP_TX;
+ }
+
+ return XDP_PASS;
+}
+
+/* Fold the 32-bit checksum accumulator @csum into 16 bits (two rounds of
+ * adding the upper half to the lower half) and return its complement.
+ */
+static __always_inline __u16 csum_fold_helper(__u32 csum)
+{
+	__u32 folded = (csum & 0xffff) + (csum >> 16);
+
+	folded = (folded & 0xffff) + (folded >> 16);
+
+	return ~folded;
+}
+
+/* Like csum_fold_helper(), but a folded result of 0 is returned as 0xffff. */
+static __always_inline __u16 csum_fold_udp_helper(__u32 csum)
+{
+	__u16 folded = csum_fold_helper(csum);
+
+	if (!folded)
+		return 0xffff;
+
+	return folded;
+}
+/* Adjust the length fields of the IP and UDP headers by @offset bytes.
+ *
+ * For IPv4 the total length is rewritten and the header checksum is
+ * updated through bpf_csum_diff(); for IPv6 only payload_len is rewritten.
+ * The UDP length is rewritten too, and *@udp_csum is set to the
+ * complemented UDP checksum with the length change applied twice
+ * (presumably once for the pseudo-header and once for the UDP header
+ * itself - TODO confirm). udph->check itself is left for the caller.
+ *
+ * Returns the UDP header, or NULL when the headers do not fit between
+ * ctx->data and ctx->data_end or the EtherType is neither IPv4 nor IPv6.
+ */
+static void *update_pkt(struct xdp_md *ctx, __s16 offset, __u32 *udp_csum)
+{
+ void *data_end = (void *)(long)ctx->data_end;
+ void *data = (void *)(long)ctx->data;
+ struct udphdr *udph = NULL;
+ struct ethhdr *eth = data;
+ __u32 len, len_new;
+
+ if (data + sizeof(*eth) > data_end)
+ return NULL;
+
+ if (eth->h_proto == bpf_htons(ETH_P_IP)) {
+ struct iphdr *iph = data + sizeof(*eth);
+
+ if (iph + 1 > (struct iphdr *)data_end)
+ return NULL;
+
+ udph = (void *)eth + sizeof(*iph) + sizeof(*eth);
+ if (!udph || udph + 1 > (struct udphdr *)data_end)
+ return NULL;
+
+ len = iph->tot_len;
+ len_new = bpf_htons(bpf_ntohs(len) + offset);
+ iph->tot_len = len_new;
+ iph->check = csum_fold_helper(
+ bpf_csum_diff(&len, sizeof(len), &len_new,
+ sizeof(len_new), ~((__u32)iph->check)));
+ } else if (eth->h_proto == bpf_htons(ETH_P_IPV6)) {
+ struct ipv6hdr *ipv6h = data + sizeof(*eth);
+
+ if (ipv6h + 1 > (struct ipv6hdr *)data_end)
+ return NULL;
+
+ udph = (void *)eth + sizeof(*ipv6h) + sizeof(*eth);
+ if (!udph || udph + 1 > (struct udphdr *)data_end)
+ return NULL;
+
+ len = ipv6h->payload_len;
+ len_new = bpf_htons(bpf_ntohs(len) + offset);
+ ipv6h->payload_len = len_new;
+ } else {
+ return NULL;
+ }
+
+ len = udph->len;
+ len_new = bpf_htons(bpf_ntohs(len) + offset);
+
+ *udp_csum = ~((__u32)udph->check);
+ *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
+ sizeof(len_new), *udp_csum);
+ *udp_csum = bpf_csum_diff(&len, sizeof(len), &len_new,
+ sizeof(len_new), *udp_csum);
+
+ udph->len = len_new;
+
+ return udph;
+}
+/* Trim @offset bytes from the end of the packet.
+ *
+ * The headers are updated through update_pkt() first. The bytes about to
+ * be removed are then read back so they can be subtracted from the UDP
+ * checksum before bpf_xdp_adjust_tail() is called. @offset is reduced to
+ * the range [0, MAX_ADJST_OFFSET] by masking. Returns 0 on success and -1
+ * on any failure, including when fewer than @hdr_len bytes would remain.
+ */
+static int xdp_adjst_tail_shrnk_data(struct xdp_md *ctx, __u16 offset,
+ unsigned long hdr_len)
+{
+ char tmp_buff[MAX_ADJST_OFFSET];
+ __u32 buff_pos, udp_csum = 0;
+ struct udphdr *udph = NULL;
+ __u32 buff_len;
+
+ udph = update_pkt(ctx, 0 - offset, &udp_csum);
+ if (!udph)
+ return -1;
+
+ buff_len = bpf_xdp_get_buff_len(ctx);
+
+ offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
+ offset & 0xff;
+ if (offset == 0)
+ return -1;
+
+ /* Make sure we have enough data to avoid eating the header */
+ if (buff_len - offset < hdr_len)
+ return -1;
+
+ buff_pos = buff_len - offset;
+ if (bpf_xdp_load_bytes(ctx, buff_pos, tmp_buff, offset) < 0)
+ return -1;
+
+ udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
+ udph->check = (__u16)csum_fold_udp_helper(udp_csum);
+
+ if (bpf_xdp_adjust_tail(ctx, 0 - offset) < 0)
+ return -1;
+
+ return 0;
+}
+/* Append @offset bytes to the end of the packet, each set to the
+ * XDP_ADJST_TAG value read from map_xdp_setup. @offset is reduced to the
+ * range [0, MAX_ADJST_OFFSET] by masking. Returns 0 on success, -1 on
+ * failure.
+ */
+static int xdp_adjst_tail_grow_data(struct xdp_md *ctx, __u16 offset)
+{
+ char tmp_buff[MAX_ADJST_OFFSET];
+ __u32 buff_pos, udp_csum = 0;
+ __u32 buff_len, hdr_len, key;
+ struct udphdr *udph;
+ __s32 *val;
+ __u8 tag;
+
+ /* Proceed to update the packet headers before attempting to adjust
+ * the tail. Once the tail is adjusted we lose access to the offset
+ * amount of data at the end of the packet which is crucial to update
+ * the checksum.
+ * Since any failure beyond this would abort the packet, we should
+ * not worry about passing a packet up the stack with wrong headers
+ */
+ udph = update_pkt(ctx, offset, &udp_csum);
+ if (!udph)
+ return -1;
+
+ key = XDP_ADJST_TAG;
+ val = bpf_map_lookup_elem(&map_xdp_setup, &key);
+ if (!val)
+ return -1;
+
+ tag = (__u8)(*val);
+
+ for (int i = 0; i < MAX_ADJST_OFFSET; i++)
+ __builtin_memcpy(&tmp_buff[i], &tag, 1);
+
+ offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
+ offset & 0xff;
+ if (offset == 0)
+ return -1;
+
+ udp_csum = bpf_csum_diff(0, 0, (__be32 *)tmp_buff, offset, udp_csum);
+ udph->check = (__u16)csum_fold_udp_helper(udp_csum);
+
+ buff_len = bpf_xdp_get_buff_len(ctx);
+
+ if (bpf_xdp_adjust_tail(ctx, offset) < 0) {
+ bpf_printk("Failed to adjust tail\n");
+ return -1;
+ }
+
+ if (bpf_xdp_store_bytes(ctx, buff_len, tmp_buff, offset) < 0)
+ return -1;
+
+ return 0;
+}
+
+/* XDP_MODE_TAIL_ADJST handler.
+ *
+ * Packets that do not match @port, or for which the XDP_ADJST_OFFSET entry
+ * cannot be looked up, are passed untouched. Otherwise the tail is shrunk
+ * (negative offset) or grown (non-negative offset) by the configured
+ * amount. Returns XDP_PASS and records STATS_PASS on success, XDP_ABORTED
+ * and STATS_ABORT when the adjustment fails.
+ */
+static int xdp_adjst_tail(struct xdp_md *ctx, __u16 port)
+{
+	unsigned long hdr_len;
+	__s32 *adjust_offset;
+	struct udphdr *udph;
+	__u32 key;
+	int ret;
+
+	udph = filter_udphdr(ctx, port);
+	if (!udph)
+		return XDP_PASS;
+
+	/* Bytes from the start of the frame up to the end of the UDP header */
+	hdr_len = (void *)udph - (void *)(long)ctx->data +
+		  sizeof(struct udphdr);
+	key = XDP_ADJST_OFFSET;
+	adjust_offset = bpf_map_lookup_elem(&map_xdp_setup, &key);
+	if (!adjust_offset)
+		return XDP_PASS;
+
+	if (*adjust_offset < 0)
+		ret = xdp_adjst_tail_shrnk_data(ctx,
+						(__u16)(0 - *adjust_offset),
+						hdr_len);
+	else
+		ret = xdp_adjst_tail_grow_data(ctx, (__u16)(*adjust_offset));
+	if (ret)
+		goto abort_pkt;
+
+	record_stats(ctx, STATS_PASS);
+	return XDP_PASS;
+
+abort_pkt:
+	record_stats(ctx, STATS_ABORT);
+	return XDP_ABORTED;
+}
+
+/* Remove @offset bytes located right after the first @hdr_len bytes.
+ *
+ * The removed bytes are subtracted from the UDP checksum, the start of the
+ * packet is saved, the head is moved forward by @offset and the saved
+ * headers are written back at the new start. Returns 0 on success and -1
+ * on failure.
+ */
+static int xdp_adjst_head_shrnk_data(struct xdp_md *ctx, __u64 hdr_len,
+				     __u32 offset)
+{
+	char tmp_buff[MAX_ADJST_OFFSET];
+	struct udphdr *udph;
+	__u32 udp_csum = 0;
+
+	/* Update the length information in the IP and UDP headers before
+	 * adjusting the headroom. This simplifies accessing the relevant
+	 * fields in the IP and UDP headers for fragmented packets. Any
+	 * failure beyond this point will result in the packet being aborted,
+	 * so we don't need to worry about incorrect length information for
+	 * passed packets.
+	 */
+	udph = update_pkt(ctx, (__s16)(0 - offset), &udp_csum);
+	if (!udph)
+		return -1;
+
+	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
+		 offset & 0xff;
+	if (offset == 0)
+		return -1;
+
+	/* Fetch the bytes being dropped to take them out of the checksum */
+	if (bpf_xdp_load_bytes(ctx, hdr_len, tmp_buff, offset) < 0)
+		return -1;
+
+	udp_csum = bpf_csum_diff((__be32 *)tmp_buff, offset, 0, 0, udp_csum);
+	udph->check = (__u16)csum_fold_udp_helper(udp_csum);
+
+	/* Save the start of the packet; only hdr_len bytes are restored */
+	if (bpf_xdp_load_bytes(ctx, 0, tmp_buff, MAX_ADJST_OFFSET) < 0)
+		return -1;
+
+	if (bpf_xdp_adjust_head(ctx, offset) < 0)
+		return -1;
+
+	if (offset > MAX_ADJST_OFFSET)
+		return -1;
+
+	if (hdr_len > MAX_ADJST_OFFSET || hdr_len == 0)
+		return -1;
+
+	/* Added here to handle clang's complaint about a negative value */
+	hdr_len = hdr_len & 0xff;
+
+	if (hdr_len == 0)
+		return -1;
+
+	if (bpf_xdp_store_bytes(ctx, 0, tmp_buff, hdr_len) < 0)
+		return -1;
+
+	return 0;
+}
+
+/* Insert @offset bytes, each set to the XDP_ADJST_TAG value read from
+ * map_xdp_setup, right after the first @hdr_len bytes of the packet.
+ *
+ * The headers are saved, the head is moved back by @offset, and then the
+ * headers followed by the tag bytes are written at the new start. Returns
+ * 0 on success and -1 on failure.
+ */
+static int xdp_adjst_head_grow_data(struct xdp_md *ctx, __u64 hdr_len,
+				    __u32 offset)
+{
+	char hdr_buff[MAX_HDR_LEN];
+	char data_buff[MAX_ADJST_OFFSET];
+	struct udphdr *udph;
+	__u32 udp_csum = 0;
+	__s32 *val;
+	__u32 key;
+	__u8 tag;
+
+	udph = update_pkt(ctx, (__s16)(offset), &udp_csum);
+	if (!udph)
+		return -1;
+
+	key = XDP_ADJST_TAG;
+	val = bpf_map_lookup_elem(&map_xdp_setup, &key);
+	if (!val)
+		return -1;
+
+	tag = (__u8)(*val);
+	for (int i = 0; i < MAX_ADJST_OFFSET; i++)
+		__builtin_memcpy(&data_buff[i], &tag, 1);
+
+	offset = (offset & 0x1ff) >= MAX_ADJST_OFFSET ? MAX_ADJST_OFFSET :
+		 offset & 0xff;
+	if (offset == 0)
+		return -1;
+
+	/* Account for the inserted bytes in the UDP checksum */
+	udp_csum = bpf_csum_diff(0, 0, (__be32 *)data_buff, offset, udp_csum);
+	udph->check = (__u16)csum_fold_udp_helper(udp_csum);
+
+	/* hdr_buff only holds MAX_HDR_LEN bytes, so that is the bound that
+	 * must hold before the headers are copied into it.
+	 */
+	if (hdr_len > MAX_HDR_LEN || hdr_len == 0)
+		return -1;
+
+	/* Added here to handle clang's complaint about a negative value */
+	hdr_len = hdr_len & 0xff;
+
+	if (hdr_len == 0)
+		return -1;
+
+	if (bpf_xdp_load_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
+		return -1;
+
+	if (offset > MAX_ADJST_OFFSET)
+		return -1;
+
+	if (bpf_xdp_adjust_head(ctx, 0 - offset) < 0)
+		return -1;
+
+	if (bpf_xdp_store_bytes(ctx, 0, hdr_buff, hdr_len) < 0)
+		return -1;
+
+	if (bpf_xdp_store_bytes(ctx, hdr_len, data_buff, offset) < 0)
+		return -1;
+
+	return 0;
+}
+/* XDP_MODE_HEAD_ADJST handler.
+ *
+ * Packets that do not match @port, or for which the XDP_ADJST_OFFSET entry
+ * cannot be looked up, are passed untouched. Only offsets of +/-16, 32,
+ * 64, 128 and 256 are accepted: a negative value grows the packet at the
+ * head, a positive one shrinks it. Returns XDP_PASS and records STATS_PASS
+ * on success; an unsupported offset or a failed adjustment records
+ * STATS_ABORT and returns XDP_ABORTED.
+ */
+static int xdp_head_adjst(struct xdp_md *ctx, __u16 port)
+{
+ struct udphdr *udph_ptr = NULL;
+ __u32 key, size, hdr_len;
+ __s32 *val;
+ int res;
+
+ /* Filter packets based on UDP port */
+ udph_ptr = filter_udphdr(ctx, port);
+ if (!udph_ptr)
+ return XDP_PASS;
+
+ hdr_len = (void *)udph_ptr - (void *)(long)ctx->data +
+ sizeof(struct udphdr);
+
+ key = XDP_ADJST_OFFSET;
+ val = bpf_map_lookup_elem(&map_xdp_setup, &key);
+ if (!val)
+ return XDP_PASS;
+
+ switch (*val) {
+ case -16:
+ case 16:
+ size = 16;
+ break;
+ case -32:
+ case 32:
+ size = 32;
+ break;
+ case -64:
+ case 64:
+ size = 64;
+ break;
+ case -128:
+ case 128:
+ size = 128;
+ break;
+ case -256:
+ case 256:
+ size = 256;
+ break;
+ default:
+ bpf_printk("Invalid adjustment offset: %d\n", *val);
+ goto abort;
+ }
+
+ if (*val < 0)
+ res = xdp_adjst_head_grow_data(ctx, hdr_len, size);
+ else
+ res = xdp_adjst_head_shrnk_data(ctx, hdr_len, size);
+
+ if (res)
+ goto abort;
+
+ record_stats(ctx, STATS_PASS);
+ return XDP_PASS;
+
+abort:
+ record_stats(ctx, STATS_ABORT);
+ return XDP_ABORTED;
+}
+
+/* Shared entry point of both XDP programs: look up the XDP_MODE and
+ * XDP_PORT entries of map_xdp_setup and dispatch to the matching handler.
+ * XDP_PASS is returned when either lookup fails or the mode is unknown.
+ */
+static int xdp_prog_common(struct xdp_md *ctx)
+{
+	__u32 mode_key = XDP_MODE;
+	__u32 port_key = XDP_PORT;
+	__u16 udp_port;
+	__s32 *mode;
+	__u32 *port;
+
+	mode = bpf_map_lookup_elem(&map_xdp_setup, &mode_key);
+	if (!mode)
+		return XDP_PASS;
+
+	port = bpf_map_lookup_elem(&map_xdp_setup, &port_key);
+	if (!port)
+		return XDP_PASS;
+
+	udp_port = (__u16)(*port);
+
+	switch (*mode) {
+	case XDP_MODE_PASS:
+		return xdp_mode_pass(ctx, udp_port);
+	case XDP_MODE_DROP:
+		return xdp_mode_drop_handler(ctx, udp_port);
+	case XDP_MODE_TX:
+		return xdp_mode_tx_handler(ctx, udp_port);
+	case XDP_MODE_TAIL_ADJST:
+		return xdp_adjst_tail(ctx, udp_port);
+	case XDP_MODE_HEAD_ADJST:
+		return xdp_head_adjst(ctx, udp_port);
+	default:
+		/* Default action is to simply pass */
+		return XDP_PASS;
+	}
+}
+/* Program in the "xdp" section; all the work is done by xdp_prog_common(). */
+SEC("xdp")
+int xdp_prog(struct xdp_md *ctx)
+{
+ return xdp_prog_common(ctx);
+}
+/* Program in the "xdp.frags" section; also defers to xdp_prog_common(). */
+SEC("xdp.frags")
+int xdp_prog_frags(struct xdp_md *ctx)
+{
+ return xdp_prog_common(ctx);
+}
+
+char _license[] SEC("license") = "GPL";