summaryrefslogtreecommitdiff
path: root/tools/testing/selftests/bpf/prog_tests/lwt_helpers.h
blob: e9190574e79f39fe50a2d2d512883fb94d3987c9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
/* SPDX-License-Identifier: GPL-2.0 */

#ifndef __LWT_HELPERS_H
#define __LWT_HELPERS_H

#include <time.h>
#include <net/if.h>
#include <linux/if_tun.h>
#include <linux/icmp.h>

#include "test_progs.h"

/* log_err - report a failure on stderr.
 *
 * MSG must be a string literal: it is pasted into the format string, and
 * any extra arguments are formatted by it. Each line is prefixed with the
 * call site (file:line) and strerror(errno) as of the call.
 */
#define log_err(MSG, ...)						\
	fprintf(stderr,							\
		"(%s:%d: errno: %s) " MSG "\n",				\
		__FILE__,						\
		__LINE__,						\
		strerror(errno),					\
		##__VA_ARGS__)

/* RUN_TEST(name) - run test_<name>() as a subtest inside a fresh netns.
 *
 * Nothing happens when test__start_subtest() returns false. Otherwise NETNS
 * is created, entered with open_netns(), test_<name>() is invoked, the
 * namespace token is released with close_netns() and NETNS is deleted.
 * A failure to create or enter the namespace is reported through
 * ASSERT_OK()/ASSERT_OK_PTR() and skips the test body.
 */
#define RUN_TEST(name)							\
	({								\
		if (test__start_subtest(#name) &&			\
		    ASSERT_OK(netns_create(), "netns_create")) {	\
			struct nstoken *tok = open_netns(NETNS);	\
									\
			if (ASSERT_OK_PTR(tok, "setns")) {		\
				test_ ## name();			\
				close_netns(tok);			\
			}						\
			netns_delete();					\
		}							\
	})

/* Name of the network namespace RUN_TEST creates for every subtest. */
#define NETNS "ns_lwt"

/* Create NETNS with iproute2. Returns the raw status from system(). */
static inline int netns_create(void)
{
	static const char add_cmd[] = "ip netns add " NETNS;

	return system(add_cmd);
}

/* Delete NETNS with iproute2, discarding the command's stdout/stderr.
 * Returns the raw status from system().
 */
static inline int netns_delete(void)
{
	static const char del_cmd[] = "ip netns del " NETNS ">/dev/null 2>&1";

	return system(del_cmd);
}

/* open_tuntap - open /dev/net/tun and bind it to a TUN/TAP interface
 *
 * @dev_name: interface name; copied into ifr_name and truncated to
 *            IFNAMSIZ - 1 characters if longer
 * @need_mac: true requests IFF_TAP, false requests IFF_TUN; IFF_NO_PI is
 *            always set
 *
 * The returned fd is switched to non-blocking mode. Returns the fd on
 * success, or -1 on failure; the failing step is reported through the
 * ASSERT_* macros and any fd already opened is closed.
 */
static int open_tuntap(const char *dev_name, bool need_mac)
{
	int err = 0;
	struct ifreq ifr;
	int fd = open("/dev/net/tun", O_RDWR);

	/* Any non-negative value is a valid descriptor (including 0 when
	 * stdin has been closed); only -1 signals failure.
	 */
	if (!ASSERT_GE(fd, 0, "open(/dev/net/tun)"))
		return -1;

	/* Start from a zeroed request so no stack garbage reaches the
	 * kernel in the members we do not set explicitly.
	 */
	memset(&ifr, 0, sizeof(ifr));
	ifr.ifr_flags = IFF_NO_PI | (need_mac ? IFF_TAP : IFF_TUN);
	strncpy(ifr.ifr_name, dev_name, IFNAMSIZ - 1);
	ifr.ifr_name[IFNAMSIZ - 1] = '\0';

	err = ioctl(fd, TUNSETIFF, &ifr);
	if (!ASSERT_OK(err, "ioctl(TUNSETIFF)"))
		goto close_fd;

	err = fcntl(fd, F_SETFL, O_NONBLOCK);
	if (!ASSERT_OK(err, "fcntl(O_NONBLOCK)"))
		goto close_fd;

	return fd;

close_fd:
	close(fd);
	return -1;
}

#define ICMP_PAYLOAD_SIZE     100

/* __expect_icmp_ipv4 - match an IPv4 ICMP echo request whose payload is
 * exactly ICMP_PAYLOAD_SIZE bytes.
 *
 * @buf: raw packet, starting at the IPv4 header
 * @len: number of valid bytes in @buf
 *
 * Returns 1 on a match, 0 for an IPv4 ICMP echo request of any other size,
 * and -1 if the packet is truncated, not IPv4, not ICMP, or not an echo
 * request.
 */
static int __expect_icmp_ipv4(char *buf, ssize_t len)
{
	struct iphdr *ip = (struct iphdr *)buf;
	struct icmphdr *icmp;
	ssize_t ip_hdr_len;

	if (len < (ssize_t)sizeof(*ip))
		return -1;

	/* A TUN fd can deliver non-IPv4 traffic (e.g. IPv6), whose bytes
	 * must not be interpreted through struct iphdr.
	 */
	if (ip->version != 4 || ip->ihl < 5)
		return -1;

	/* Honour IP options: the ICMP header starts after ihl 32-bit words,
	 * not necessarily right after a 20-byte header.
	 */
	ip_hdr_len = ip->ihl * 4;
	if (len < ip_hdr_len + (ssize_t)sizeof(*icmp))
		return -1;

	if (ip->protocol != IPPROTO_ICMP)
		return -1;

	icmp = (struct icmphdr *)(buf + ip_hdr_len);
	if (icmp->type != ICMP_ECHO)
		return -1;

	return len == ip_hdr_len + (ssize_t)sizeof(*icmp) + ICMP_PAYLOAD_SIZE;
}

typedef int (*filter_t) (char *, ssize_t);

/* wait_for_packet - wait for a packet that matches the filter
 *
 * @fd: tun fd/packet socket to read packet
 * @filter: filter function; a packet matches when it returns a positive
 *          value (e.g. 1). May be NULL, in which case nothing matches and
 *          every packet read is discarded.
 * @timeout: timeout for each select() call. It is copied and re-applied on
 *           every attempt (up to 5, counting reads of non-matching packets
 *           and EINTR restarts), so the total wait can exceed it.
 *
 * Returns 1 if a matching packet is read, 0 if a select() timed out or the
 * attempts were exhausted, -1 on error.
 */
static int wait_for_packet(int fd, filter_t filter, struct timeval *timeout)
{
	char buf[4096];
	int max_retry = 5; /* in case we read some spurious packets */
	fd_set fds;

	while (max_retry--) {
		/* Linux modifies timeout arg... So make a copy */
		struct timeval copied_timeout = *timeout;
		ssize_t ret = -1;

		/* select() rewrites the set on return; rebuild it every pass. */
		FD_ZERO(&fds);
		FD_SET(fd, &fds);

		ret = select(1 + fd, &fds, NULL, NULL, &copied_timeout);
		/* 0 means the timeout expired. errno is only meaningful when
		 * select() returns -1, so ret must be checked before errno.
		 */
		if (ret == 0)
			return 0;
		if (ret < 0) {
			if (errno == EINTR)
				continue;
			if (errno == EAGAIN)
				return 0;

			log_err("select failed");
			return -1;
		}

		ret = read(fd, buf, sizeof(buf));
		if (ret <= 0) {
			log_err("read(dev): %zd", ret);
			return -1;
		}

		if (filter && filter(buf, ret) > 0)
			return 1;
	}

	return 0;
}

#endif /* __LWT_HELPERS_H */