summaryrefslogtreecommitdiff
path: root/tools/testing/selftests/net/lib/py/utils.py
blob: 34470d65d871a3f4f05f3ba1a9d2527774942652 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# SPDX-License-Identifier: GPL-2.0

import errno
import json as _json
import os
import random
import re
import select
import socket
import subprocess
import time


class CmdExitFailure(Exception):
    """
    Exception carrying the command object it was raised for.

    The message is handed to Exception as usual; the second argument
    is stored as-is in the 'cmd' attribute so handlers can inspect it.
    """
    def __init__(self, msg, cmd_obj):
        self.cmd = cmd_obj
        super().__init__(msg)


def fd_read_timeout(fd, timeout):
    """
    Read up to 1024 bytes from fd, waiting at most timeout seconds
    for it to become readable.

    Returns the bytes read (empty bytes on end-of-file).
    Raises TimeoutError if fd did not become readable in time.
    """
    readable = select.select([fd], [], [], timeout)[0]
    if not readable:
        raise TimeoutError("Timeout waiting for fd read")
    return os.read(fd, 1024)


class cmd:
    """
    Execute a command on local or remote host.

    Use bkg() instead to run a command in the background.

    Arguments:
      comm       -- command line to execute
      shell      -- passed through to subprocess.Popen() (local commands only)
      fail       -- raise CmdExitFailure when the command exits non-zero
      ns         -- if set, the command is prefixed with 'ip netns exec <ns>'
      background -- if true, only start the command; the caller is then
                    responsible for calling process()
      host       -- if set, the command is started with host.cmd() instead
                    of a local subprocess.Popen()
      timeout    -- seconds to wait for a foreground command to finish
      ksft_wait  -- if not None, seconds to wait for the (local) command to
                    report readiness by writing to the fd named by the
                    KSFT_READY_FD environment variable

    After process() has run, the decoded output is available in the
    'stdout' and 'stderr' attributes and the exit code in 'ret'.
    """
    def __init__(self, comm, shell=True, fail=True, ns=None, background=False,
                 host=None, timeout=5, ksft_wait=None):
        if ns:
            comm = f'ip netns exec {ns} ' + comm

        self.stdout = None
        self.stderr = None
        self.ret = None
        self.ksft_term_fd = None

        self.comm = comm
        if host:
            # NOTE(review): host.cmd() is presumably returning a Popen-like
            # object, process() uses communicate(), returncode and args on it.
            self.proc = host.cmd(comm)
        else:
            # ksft_wait lets us wait for the background process to fully start,
            # we pass an FD to the child process, and wait for it to write back.
            # Similarly term_fd tells child it's time to exit.
            pass_fds = ()
            env = os.environ.copy()
            if ksft_wait is not None:
                rfd, ready_fd = os.pipe()
                wait_fd, self.ksft_term_fd = os.pipe()
                pass_fds = (ready_fd, wait_fd, )
                env["KSFT_READY_FD"] = str(ready_fd)
                env["KSFT_WAIT_FD"]  = str(wait_fd)

            self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE, pass_fds=pass_fds,
                                         env=env)
            if ksft_wait is not None:
                # The child owns its ends of the pipes now, drop our copies
                os.close(ready_fd)
                os.close(wait_fd)
                try:
                    msg = fd_read_timeout(rfd, ksft_wait)
                finally:
                    # Don't leak the read end if the wait times out
                    os.close(rfd)
                if not msg:
                    raise Exception("Did not receive ready message")
        if not background:
            self.process(terminate=False, fail=fail, timeout=timeout)

    def process(self, terminate=True, fail=None, timeout=5):
        """
        Wait for the command to finish and collect its output.

        Arguments:
          terminate -- send SIGTERM to the command before waiting for it
          fail      -- raise CmdExitFailure on non-zero exit code; when None
                       failures are only checked if we did not terminate
          timeout   -- seconds to wait for the command to exit

        Raises subprocess.TimeoutExpired if a local command does not exit
        within the timeout (the command is not killed in that case), and
        CmdExitFailure if the exit code is checked and is non-zero.
        """
        if fail is None:
            fail = not terminate

        if self.ksft_term_fd is not None:
            # Tell the child it's time to exit; the fd is of no use after
            # that so close it rather than leaking it.
            term_fd, self.ksft_term_fd = self.ksft_term_fd, None
            try:
                os.write(term_fd, b"1")
            finally:
                os.close(term_fd)
        if terminate:
            self.proc.terminate()
        # timeout must be passed by keyword, the first positional argument
        # of communicate() is the data to feed to stdin.
        stdout, stderr = self.proc.communicate(timeout=timeout)
        self.stdout = stdout.decode("utf-8")
        self.stderr = stderr.decode("utf-8")
        self.proc.stdout.close()
        self.proc.stderr.close()
        self.ret = self.proc.returncode

        if self.ret != 0 and fail:
            # Report the decoded text, communicate() hands us raw bytes
            err_text = self.stderr
            if err_text.endswith("\n"):
                err_text = err_text[:-1]
            raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" %
                                 (self.proc.args, self.stdout, err_text), self)


class bkg(cmd):
    """
    Run a command in the background.

    Examples usage:

    Run a command on remote host, and wait for it to finish.
    This is usually paired with wait_port_listen() to make sure
    the command has initialized:

        with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
            ...

    Run a command and expect it to let us know that it's ready
    by writing to a special file descriptor passed via KSFT_READY_FD.
    Command will be terminated when we exit the context manager:

        with bkg("my_binary", ksft_wait=5):

    Arguments:
      comm      -- command line to execute
      shell     -- passed through to cmd
      fail      -- whether a non-zero exit code raises when leaving the
                   context; None lets cmd.process() pick its default
      ns, host  -- passed through to cmd
      exit_wait -- wait for the command to exit by itself instead of
                   terminating it when leaving the context
      ksft_wait -- passed through to cmd

    If the body of the 'with' statement raises while exit_wait is set,
    the command is terminated instead of being waited for, and its exit
    code is not checked, so the original exception is the one reported.
    """
    def __init__(self, comm, shell=True, fail=None, ns=None, host=None,
                 exit_wait=False, ksft_wait=None):
        super().__init__(comm, background=True,
                         shell=shell, fail=fail, ns=ns, host=host,
                         ksft_wait=ksft_wait)
        self.terminate = not exit_wait and not ksft_wait
        self._exit_wait = exit_wait
        self.check_fail = fail

        if shell and self.terminate:
            print("# Warning: combining shell and terminate is risky!")
            print("#          SIGTERM may not reach the child on zsh/ksh!")

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        if ex_type is not None and self._exit_wait:
            # The test body failed, the command may never see whatever
            # it is waiting for. Force termination, and skip the exit code
            # check since dying from our SIGTERM is expected here.
            return self.process(terminate=True, fail=False)
        return self.process(terminate=self.terminate, fail=self.check_fail)


# Pending defer objects. Each defer appends itself here when created and
# removes itself again in cancel() (which exec() calls before running).
global_defer_queue = []


class defer:
    """
    Deferred call: remembers a callable with its arguments, to run later.

    Every new object is appended to global_defer_queue. It can also be
    used as a context manager, in which case the call is executed when
    the context is left.
    """
    def __init__(self, func, *args, **kwargs):
        if not callable(func):
            raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")

        self.func, self.args, self.kwargs = func, args, kwargs

        # Remember which queue we sit on, so that cancel() can find us
        queue = global_defer_queue
        self._queue = queue
        queue.append(self)

    def __enter__(self):
        return self

    def __exit__(self, ex_type, ex_value, ex_tb):
        return self.exec()

    def exec_only(self):
        """Run the stored call, without touching the queue."""
        call, pos_args, kw_args = self.func, self.args, self.kwargs
        call(*pos_args, **kw_args)

    def cancel(self):
        """Drop this object from its queue, without running the call."""
        self._queue.remove(self)

    def exec(self):
        """Drop this object from its queue, then run the stored call."""
        self.cancel()
        self.exec_only()


def tool(name, args, json=None, ns=None, host=None):
    """
    Run "<name> [--json] <args>" via cmd() and return the result.

    When json is set, '--json' is inserted in front of args and the
    command's stdout is returned parsed as JSON. Otherwise the finished
    cmd object is returned. ns and host are passed through to cmd().
    """
    pieces = [name + ' ']
    if json:
        pieces.append('--json ')
    pieces.append(args)

    result = cmd(''.join(pieces), ns=ns, host=host)
    return _json.loads(result.stdout) if json else result


def ip(args, json=None, ns=None, host=None):
    """
    Run the 'ip' command through tool().

    When ns is set, it's selected with ip's own '-netns <ns>' option
    placed in front of args, rather than being passed on to tool().
    """
    ip_args = f'-netns {ns} ' + args if ns else args
    return tool('ip', ip_args, json=json, host=host)


def ethtool(args, json=None, ns=None, host=None):
    """
    Run the 'ethtool' command through tool(), all arguments are passed on.
    """
    return tool(
        'ethtool',
        args,
        json=json,
        ns=ns,
        host=host,
    )


def rand_port(type=socket.SOCK_STREAM):
    """
    Get a random unprivileged port.

    An IPv6 socket of the requested type is bound to port 0, and the port
    number it ended up with is returned. The socket is closed again before
    returning.
    """
    sock = socket.socket(socket.AF_INET6, type)
    try:
        sock.bind(("", 0))
        bound_addr = sock.getsockname()
        return bound_addr[1]
    finally:
        sock.close()


def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
    """
    Poll /proc/net/<proto>* until a socket on the given port shows up.

    The files are read with cmd() (so ns and host are honored) and each
    row is matched against the port printed as 4 upper-case hex digits.
    For tcp the row must also contain the "0A" state field.

    Returns as soon as a matching row is found. Raises Exception once
    more than deadline seconds have passed without a match; sleeps
    for 'sleep' seconds between attempts.
    """
    give_up_at = time.monotonic() + deadline

    expr = f":{port:04X} .* "
    # for tcp protocol additionally check the socket state
    if proto == "tcp":
        expr = expr + "0A"
    matcher = re.compile(expr)

    while True:
        output = cmd(f'cat /proc/net/{proto}*', ns=ns, host=host, shell=True).stdout
        if any(matcher.search(line) for line in output.split("\n")):
            return
        if time.monotonic() > give_up_at:
            raise Exception("Waiting for port listen timed out")
        time.sleep(sleep)