summaryrefslogtreecommitdiff
path: root/tools/testing/selftests/net/lib/py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/testing/selftests/net/lib/py')
-rw-r--r--tools/testing/selftests/net/lib/py/__init__.py40
-rw-r--r--tools/testing/selftests/net/lib/py/bpf.py68
-rw-r--r--tools/testing/selftests/net/lib/py/ksft.py271
-rw-r--r--tools/testing/selftests/net/lib/py/netns.py18
-rw-r--r--tools/testing/selftests/net/lib/py/nsim.py2
-rw-r--r--tools/testing/selftests/net/lib/py/utils.py250
-rw-r--r--tools/testing/selftests/net/lib/py/ynl.py34
7 files changed, 612 insertions, 71 deletions
diff --git a/tools/testing/selftests/net/lib/py/__init__.py b/tools/testing/selftests/net/lib/py/__init__.py
index 54d8f5eba810..7c81d86a7e97 100644
--- a/tools/testing/selftests/net/lib/py/__init__.py
+++ b/tools/testing/selftests/net/lib/py/__init__.py
@@ -1,9 +1,37 @@
# SPDX-License-Identifier: GPL-2.0
+"""
+Python selftest helpers for netdev.
+"""
+
from .consts import KSRC
-from .ksft import *
-from .netns import NetNS
-from .nsim import *
-from .utils import *
-from .ynl import NlError, YnlFamily, EthtoolFamily, NetdevFamily, RtnlFamily
-from .ynl import NetshaperFamily
+from .ksft import KsftFailEx, KsftSkipEx, KsftXfailEx, ksft_pr, ksft_eq, \
+ ksft_ne, ksft_true, ksft_not_none, ksft_in, ksft_not_in, ksft_is, \
+ ksft_ge, ksft_gt, ksft_lt, ksft_raises, ksft_busy_wait, \
+ ktap_result, ksft_disruptive, ksft_setup, ksft_run, ksft_exit, \
+ ksft_variants, KsftNamedVariant
+from .netns import NetNS, NetNSEnter
+from .nsim import NetdevSim, NetdevSimDev
+from .utils import CmdExitFailure, fd_read_timeout, cmd, bkg, defer, \
+ bpftool, ip, ethtool, bpftrace, rand_port, rand_ports, wait_port_listen, \
+ wait_file, tool
+from .bpf import bpf_map_set, bpf_map_dump, bpf_prog_map_ids
+from .ynl import NlError, NlctrlFamily, YnlFamily, \
+ EthtoolFamily, NetdevFamily, RtnlFamily, RtnlAddrFamily
+from .ynl import NetshaperFamily, DevlinkFamily, PSPFamily, Netlink
+
+__all__ = ["KSRC",
+ "KsftFailEx", "KsftSkipEx", "KsftXfailEx", "ksft_pr", "ksft_eq",
+ "ksft_ne", "ksft_true", "ksft_not_none", "ksft_in", "ksft_not_in",
+ "ksft_is", "ksft_ge", "ksft_gt", "ksft_lt", "ksft_raises",
+ "ksft_busy_wait", "ktap_result", "ksft_disruptive", "ksft_setup",
+ "ksft_run", "ksft_exit", "ksft_variants", "KsftNamedVariant",
+ "NetNS", "NetNSEnter",
+ "CmdExitFailure", "fd_read_timeout", "cmd", "bkg", "defer",
+ "bpftool", "ip", "ethtool", "bpftrace", "rand_port", "rand_ports",
+ "wait_port_listen", "wait_file", "tool",
+ "bpf_map_set", "bpf_map_dump", "bpf_prog_map_ids",
+ "NetdevSim", "NetdevSimDev",
+ "NetshaperFamily", "DevlinkFamily", "PSPFamily", "NlError",
+ "YnlFamily", "EthtoolFamily", "NetdevFamily", "RtnlFamily",
+ "NlctrlFamily", "RtnlAddrFamily", "Netlink"]
diff --git a/tools/testing/selftests/net/lib/py/bpf.py b/tools/testing/selftests/net/lib/py/bpf.py
new file mode 100644
index 000000000000..beb6bf2896a8
--- /dev/null
+++ b/tools/testing/selftests/net/lib/py/bpf.py
@@ -0,0 +1,68 @@
+# SPDX-License-Identifier: GPL-2.0
+
+"""
+BPF helper utilities for kernel selftests.
+
+Provides common operations for interacting with BPF maps and programs
+via bpftool, used by XDP and other BPF-based test files.
+"""
+
+from .utils import bpftool
+
+def _format_hex_bytes(value):
+ """
+ Helper function that converts an integer into a formatted hexadecimal byte string.
+
+ Args:
+ value: An integer representing the number to be converted.
+
+ Returns:
+ A string representing hexadecimal equivalent of value, with bytes separated by spaces.
+ """
+ hex_str = value.to_bytes(4, byteorder='little', signed=True)
+ return ' '.join(f'{byte:02x}' for byte in hex_str)
+
+
+def bpf_map_set(map_name, key, value):
+ """
+ Updates an XDP map with a given key-value pair using bpftool.
+
+ Args:
+ map_name: The name of the XDP map to update.
+ key: The key to update in the map, formatted as a hexadecimal string.
+ value: The value to associate with the key, formatted as a hexadecimal string.
+ """
+ key_formatted = _format_hex_bytes(key)
+ value_formatted = _format_hex_bytes(value)
+ bpftool(
+ f"map update name {map_name} key hex {key_formatted} value hex {value_formatted}"
+ )
+
+def bpf_map_dump(map_id):
+ """Dump all entries of a BPF array map.
+
+ Args:
+ map_id: Numeric map ID (as returned by bpftool prog show).
+
+ Returns:
+ A dict mapping formatted key (int) to formatted value (int).
+ """
+ raw = bpftool(f"map dump id {map_id}", json=True)
+ return {e["formatted"]["key"]: e["formatted"]["value"] for e in raw}
+
+
+def bpf_prog_map_ids(prog_id):
+ """Get the map name-to-ID mapping for a loaded BPF program.
+
+ Args:
+ prog_id: Numeric program ID.
+
+ Returns:
+ A dict mapping map name (str) to map ID (int).
+ """
+ map_ids = bpftool(f"prog show id {prog_id}", json=True)["map_ids"]
+ maps = {}
+ for mid in map_ids:
+ name = bpftool(f"map show id {mid}", json=True)["name"]
+ maps[name] = mid
+ return maps
diff --git a/tools/testing/selftests/net/lib/py/ksft.py b/tools/testing/selftests/net/lib/py/ksft.py
index 3efe005436cd..81287c2daff0 100644
--- a/tools/testing/selftests/net/lib/py/ksft.py
+++ b/tools/testing/selftests/net/lib/py/ksft.py
@@ -1,13 +1,17 @@
# SPDX-License-Identifier: GPL-2.0
-import builtins
+import fnmatch
import functools
+import getopt
import inspect
+import os
+import signal
import sys
import time
import traceback
+from collections import namedtuple
from .consts import KSFT_MAIN_NAME
-from .utils import global_defer_queue
+from . import utils
KSFT_RESULT = None
KSFT_RESULT_ALL = True
@@ -26,8 +30,67 @@ class KsftXfailEx(Exception):
pass
+class KsftTerminate(KeyboardInterrupt):
+ pass
+
+
+class _KsftArgs:
+ def __init__(self):
+ self.list_tests = False
+ self.filters = []
+
+ try:
+ opts, _ = getopt.getopt(sys.argv[1:], 'hlt:T:')
+ except getopt.GetoptError as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)
+
+ for opt, val in opts:
+ if opt == '-h':
+ print(f"Usage: {sys.argv[0]} [-h|-l] [-t|-T name]\n"
+ f"\t-h print help\n"
+ f"\t-l list tests (filtered, if filters were specified)\n"
+ f"\t-t name include test\n"
+ f"\t-T name exclude test",
+ file=sys.stderr)
+ sys.exit(0)
+ elif opt == '-l':
+ self.list_tests = True
+ elif opt == '-t':
+ self.filters.append((True, val))
+ elif opt == '-T':
+ self.filters.append((False, val))
+
+
+@functools.lru_cache()
+def _ksft_supports_color():
+ if os.environ.get("NO_COLOR") is not None:
+ return False
+ if not hasattr(sys.stdout, "isatty") or not sys.stdout.isatty():
+ return False
+ if os.environ.get("TERM") == "dumb":
+ return False
+ return True
+
+
def ksft_pr(*objs, **kwargs):
- print("#", *objs, **kwargs)
+ """
+ Print logs to stdout.
+
+ Behaves like print() but log lines will be prefixed
+ with # to prevent breaking the TAP output formatting.
+
+ Extra arguments (on top of what print() supports):
+ line_pfx - add extra string before each line
+ """
+ sep = kwargs.pop("sep", " ")
+ pfx = kwargs.pop("line_pfx", "")
+ pfx = "#" + (" " + pfx if pfx else "")
+ kwargs["flush"] = True
+
+ text = sep.join(str(obj) for obj in objs)
+ prefixed = f"\n{pfx} ".join(text.split('\n'))
+ print(pfx, prefixed, **kwargs)
def _fail(*args):
@@ -66,11 +129,21 @@ def ksft_true(a, comment=""):
_fail("Check failed", a, "does not eval to True", comment)
+def ksft_not_none(a, comment=""):
+ if a is None:
+ _fail("Check failed", a, "is None", comment)
+
+
def ksft_in(a, b, comment=""):
if a not in b:
_fail("Check failed", a, "not in", b, comment)
+def ksft_not_in(a, b, comment=""):
+ if a in b:
+ _fail("Check failed", a, "in", b, comment)
+
+
def ksft_is(a, b, comment=""):
if a is not b:
_fail("Check failed", a, "is not", b, comment)
@@ -81,6 +154,11 @@ def ksft_ge(a, b, comment=""):
_fail("Check failed", a, "<", b, comment)
+def ksft_gt(a, b, comment=""):
+ if a <= b:
+ _fail("Check failed", a, "<=", b, comment)
+
+
def ksft_lt(a, b, comment=""):
if a >= b:
_fail("Check failed", a, ">=", b, comment)
@@ -115,7 +193,7 @@ def ksft_busy_wait(cond, sleep=0.005, deadline=1, comment=""):
time.sleep(sleep)
-def ktap_result(ok, cnt=1, case="", comment=""):
+def ktap_result(ok, cnt=1, case_name="", comment=""):
global KSFT_RESULT_ALL
KSFT_RESULT_ALL = KSFT_RESULT_ALL and ok
@@ -125,31 +203,46 @@ def ktap_result(ok, cnt=1, case="", comment=""):
res += "ok "
res += str(cnt) + " "
res += KSFT_MAIN_NAME
- if case:
- res += "." + str(case.__name__)
+ if case_name:
+ res += "." + case_name
if comment:
res += " # " + comment
- print(res)
+ if _ksft_supports_color():
+ if comment.startswith(("SKIP", "XFAIL")):
+ color = "\033[33m"
+ elif ok:
+ color = "\033[32m"
+ else:
+ color = "\033[31m"
+ res = color + res + "\033[0m"
+ print(res, flush=True)
+
+
+def _ksft_defer_arm(state):
+ """ Allow or disallow the use of defer() """
+ utils.GLOBAL_DEFER_ARMED = state
def ksft_flush_defer():
global KSFT_RESULT
i = 0
- qlen_start = len(global_defer_queue)
- while global_defer_queue:
+ qlen_start = len(utils.GLOBAL_DEFER_QUEUE)
+ while utils.GLOBAL_DEFER_QUEUE:
i += 1
- entry = global_defer_queue.pop()
+ entry = utils.GLOBAL_DEFER_QUEUE.pop()
try:
entry.exec_only()
- except:
+ except Exception:
ksft_pr(f"Exception while handling defer / cleanup (callback {i} of {qlen_start})!")
- tb = traceback.format_exc()
- for line in tb.strip().split('\n'):
- ksft_pr("Defer Exception|", line)
+ ksft_pr(traceback.format_exc(), line_pfx="Defer Exception|")
KSFT_RESULT = False
+KsftCaseFunction = namedtuple("KsftCaseFunction",
+ ['name', 'original_func', 'variants'])
+
+
def ksft_disruptive(func):
"""
Decorator that marks the test as disruptive (e.g. the test
@@ -160,11 +253,47 @@ def ksft_disruptive(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not KSFT_DISRUPTIVE:
- raise KsftSkipEx(f"marked as disruptive")
+ raise KsftSkipEx("marked as disruptive")
return func(*args, **kwargs)
return wrapper
+class KsftNamedVariant:
+ """ Named string name + argument list tuple for @ksft_variants """
+
+ def __init__(self, name, *params):
+ self.params = params
+ self.name = name or "_".join([str(x) for x in self.params])
+
+
+def ksft_variants(params):
+ """
+ Decorator defining the sets of inputs for a test.
+ The parameters will be included in the name of the resulting sub-case.
+ Parameters can be either single object, tuple or a KsftNamedVariant.
+ The argument can be a list or a generator.
+
+ Example:
+
+ @ksft_variants([
+ (1, "a"),
+ (2, "b"),
+ KsftNamedVariant("three", 3, "c"),
+ ])
+ def my_case(cfg, a, b):
+ pass # ...
+
+ ksft_run(cases=[my_case], args=(cfg, ))
+
+ Will generate cases:
+ my_case.1_a
+ my_case.2_b
+ my_case.three
+ """
+
+ return lambda func: KsftCaseFunction(func.__name__, func, params)
+
+
def ksft_setup(env):
"""
Setup test framework global state from the environment.
@@ -178,7 +307,7 @@ def ksft_setup(env):
return False
try:
return bool(int(value))
- except:
+ except Exception:
raise Exception(f"failed to parse {name}")
if "DISRUPTIVE" in env:
@@ -188,9 +317,42 @@ def ksft_setup(env):
return env
-def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
+def _ksft_intr(signum, frame):
+ # ksft runner.sh sends 2 SIGTERMs in a row on a timeout
+ # if we don't ignore the second one it will stop us from handling cleanup
+ global term_cnt
+ term_cnt += 1
+ if term_cnt == 1:
+ raise KsftTerminate()
+ else:
+ ksft_pr(f"Ignoring SIGTERM (cnt: {term_cnt}), already exiting...")
+
+
+def _ksft_name_matches(name, pattern):
+ if '*' in pattern or '?' in pattern or '[' in pattern:
+ return fnmatch.fnmatchcase(name, pattern)
+ return name == pattern
+
+
+def _ksft_test_enabled(name, filters):
+ has_positive = False
+ for include, pattern in filters:
+ has_positive |= include
+ if _ksft_name_matches(name, pattern):
+ return include
+ return not has_positive
+
+
+def _ksft_generate_test_cases(cases, globs, case_pfx, args, cli_args):
+ """Generate a filtered list of (func, args, name) tuples.
+
+ If -l is given, prints matching test names and exits.
+ """
+
cases = cases or []
+ test_cases = []
+ # If using the globs method find all relevant functions
if globs and case_pfx:
for key, value in globs.items():
if not callable(value):
@@ -200,22 +362,62 @@ def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
cases.append(value)
break
- totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}
+ for func in cases:
+ if isinstance(func, KsftCaseFunction):
+ # Parametrized test - create case for each param
+ for param in func.variants:
+ if not isinstance(param, KsftNamedVariant):
+ if not isinstance(param, tuple):
+ param = (param, )
+ param = KsftNamedVariant(None, *param)
+
+ test_cases.append((func.original_func,
+ (*args, *param.params),
+ func.name + "." + param.name))
+ else:
+ test_cases.append((func, args, func.__name__))
- print("KTAP version 1")
- print("1.." + str(len(cases)))
+ if cli_args.filters:
+ test_cases = [tc for tc in test_cases
+ if _ksft_test_enabled(tc[2], cli_args.filters)]
+
+ if cli_args.list_tests:
+ for _, _, name in test_cases:
+ print(name)
+ sys.exit(0)
+
+ return test_cases
+
+
+def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
+ cli_args = _KsftArgs()
+ test_cases = _ksft_generate_test_cases(cases, globs, case_pfx, args,
+ cli_args)
+
+ global term_cnt
+ term_cnt = 0
+ prev_sigterm = signal.signal(signal.SIGTERM, _ksft_intr)
+
+ totals = {"pass": 0, "fail": 0, "skip": 0, "xfail": 0}
global KSFT_RESULT
+ if KSFT_RESULT is not None:
+ raise RuntimeError("ksft_run() can't be called multiple times.")
+
+ print("TAP version 13", flush=True)
+ print("1.." + str(len(test_cases)), flush=True)
+
cnt = 0
stop = False
- for case in cases:
+ for func, args, name in test_cases:
KSFT_RESULT = True
cnt += 1
comment = ""
cnt_key = ""
+ _ksft_defer_arm(True)
try:
- case(*args)
+ func(*args)
except KsftSkipEx as e:
comment = "SKIP " + str(e)
cnt_key = 'skip'
@@ -224,25 +426,38 @@ def ksft_run(cases=None, globs=None, case_pfx=None, args=()):
cnt_key = 'xfail'
except BaseException as e:
stop |= isinstance(e, KeyboardInterrupt)
- tb = traceback.format_exc()
- for line in tb.strip().split('\n'):
- ksft_pr("Exception|", line)
+ ksft_pr(traceback.format_exc(), line_pfx="Exception|")
if stop:
- ksft_pr("Stopping tests due to KeyboardInterrupt.")
+ ksft_pr(f"Stopping tests due to {type(e).__name__}.")
KSFT_RESULT = False
cnt_key = 'fail'
+ _ksft_defer_arm(False)
- ksft_flush_defer()
+ try:
+ ksft_flush_defer()
+ except BaseException as e:
+ ksft_pr(traceback.format_exc(), line_pfx="Exception|")
+ if isinstance(e, KeyboardInterrupt):
+ ksft_pr()
+ ksft_pr("WARN: defer() interrupted, cleanup may be incomplete.")
+ ksft_pr(" Attempting to finish cleanup before exiting.")
+ ksft_pr(" Interrupt again to exit immediately.")
+ ksft_pr()
+ stop = True
+ # Flush was interrupted, try to finish the job best we can
+ ksft_flush_defer()
if not cnt_key:
cnt_key = 'pass' if KSFT_RESULT else 'fail'
- ktap_result(KSFT_RESULT, cnt, case, comment=comment)
+ ktap_result(KSFT_RESULT, cnt, name, comment=comment)
totals[cnt_key] += 1
if stop:
break
+ signal.signal(signal.SIGTERM, prev_sigterm)
+
print(
f"# Totals: pass:{totals['pass']} fail:{totals['fail']} xfail:{totals['xfail']} xpass:0 skip:{totals['skip']} error:0"
)
diff --git a/tools/testing/selftests/net/lib/py/netns.py b/tools/testing/selftests/net/lib/py/netns.py
index ecff85f9074f..8e9317044eef 100644
--- a/tools/testing/selftests/net/lib/py/netns.py
+++ b/tools/testing/selftests/net/lib/py/netns.py
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: GPL-2.0
from .utils import ip
+import ctypes
import random
import string
+libc = ctypes.cdll.LoadLibrary('libc.so.6')
+
class NetNS:
def __init__(self, name=None):
@@ -29,3 +32,18 @@ class NetNS:
def __repr__(self):
return f"NetNS({self.name})"
+
+
+class NetNSEnter:
+ def __init__(self, ns_name):
+ self.ns_path = f"/run/netns/{ns_name}"
+
+ def __enter__(self):
+ self.saved = open("/proc/thread-self/ns/net")
+ with open(self.ns_path) as ns_file:
+ libc.setns(ns_file.fileno(), 0)
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ libc.setns(self.saved.fileno(), 0)
+ self.saved.close()
diff --git a/tools/testing/selftests/net/lib/py/nsim.py b/tools/testing/selftests/net/lib/py/nsim.py
index 1a8cbe9acc48..7c640ed64c0b 100644
--- a/tools/testing/selftests/net/lib/py/nsim.py
+++ b/tools/testing/selftests/net/lib/py/nsim.py
@@ -27,7 +27,7 @@ class NetdevSim:
self.port_index = port_index
self.ns = ns
self.dfs_dir = "%s/ports/%u/" % (nsimdev.dfs_dir, port_index)
- ret = ip("-j link show dev %s" % ifname, ns=ns)
+ ret = ip("-d -j link show dev %s" % ifname, ns=ns)
self.dev = json.loads(ret.stdout)[0]
self.ifindex = self.dev["ifindex"]
diff --git a/tools/testing/selftests/net/lib/py/utils.py b/tools/testing/selftests/net/lib/py/utils.py
index 9e3bcddcf3e8..6c44a3d2bbf7 100644
--- a/tools/testing/selftests/net/lib/py/utils.py
+++ b/tools/testing/selftests/net/lib/py/utils.py
@@ -1,80 +1,193 @@
# SPDX-License-Identifier: GPL-2.0
-import errno
import json as _json
-import random
+import os
import re
+import select
import socket
import subprocess
import time
+class CmdInitFailure(Exception):
+ """ Command failed to start. Only raised by bkg(). """
+ def __init__(self, msg, cmd_obj):
+ super().__init__(msg + "\n" + repr(cmd_obj))
+ self.cmd = cmd_obj
+
+
class CmdExitFailure(Exception):
+ """ Command failed (returned non-zero exit code). """
def __init__(self, msg, cmd_obj):
- super().__init__(msg)
+ super().__init__(msg + "\n" + repr(cmd_obj))
self.cmd = cmd_obj
+def fd_read_timeout(fd, timeout):
+ rlist, _, _ = select.select([fd], [], [], timeout)
+ if rlist:
+ return os.read(fd, 1024)
+ raise TimeoutError("Timeout waiting for fd read")
+
+
class cmd:
- def __init__(self, comm, shell=True, fail=True, ns=None, background=False, host=None, timeout=5):
+ """
+ Execute a command on local or remote host.
+
+ @shell defaults to false, and class will try to split @comm into a list
+ if it's a string with spaces.
+
+ Use bkg() instead to run a command in the background.
+ """
+ def __init__(self, comm, shell=None, fail=True, ns=None, background=False,
+ host=None, timeout=5, ksft_ready=None, ksft_wait=None):
if ns:
comm = f'ip netns exec {ns} ' + comm
self.stdout = None
self.stderr = None
self.ret = None
+ self.ksft_term_fd = None
+ self.host = host
self.comm = comm
+
if host:
self.proc = host.cmd(comm)
else:
+ # If user doesn't explicitly request shell try to avoid it.
+ if shell is None and isinstance(comm, str) and ' ' in comm:
+ comm = comm.split()
+
+ # ksft_wait lets us wait for the background process to fully start,
+ # we pass an FD to the child process, and wait for it to write back.
+ # Similarly term_fd tells child it's time to exit.
+ pass_fds = []
+ env = os.environ.copy()
+ if ksft_wait is not None:
+ wait_fd, self.ksft_term_fd = os.pipe()
+ pass_fds.append(wait_fd)
+ env["KSFT_WAIT_FD"] = str(wait_fd)
+ ksft_ready = True # ksft_wait implies ready
+ if ksft_ready is not None:
+ rfd, ready_fd = os.pipe()
+ pass_fds.append(ready_fd)
+ env["KSFT_READY_FD"] = str(ready_fd)
+
self.proc = subprocess.Popen(comm, shell=shell, stdout=subprocess.PIPE,
- stderr=subprocess.PIPE)
+ stderr=subprocess.PIPE, pass_fds=pass_fds,
+ env=env)
+ if ksft_wait is not None:
+ os.close(wait_fd)
+ if ksft_ready is not None:
+ os.close(ready_fd)
+ msg = fd_read_timeout(rfd, ksft_wait)
+ os.close(rfd)
+ if not msg:
+ terminate = self.proc.poll() is None
+ self._process_terminate(terminate=terminate, timeout=1)
+ raise CmdInitFailure("Did not receive ready message", self)
if not background:
self.process(terminate=False, fail=fail, timeout=timeout)
- def process(self, terminate=True, fail=None, timeout=5):
- if fail is None:
- fail = not terminate
-
+ def _process_terminate(self, terminate, timeout):
if terminate:
self.proc.terminate()
- stdout, stderr = self.proc.communicate(timeout)
+ stdout, stderr = self.proc.communicate(timeout=timeout)
self.stdout = stdout.decode("utf-8")
self.stderr = stderr.decode("utf-8")
self.proc.stdout.close()
self.proc.stderr.close()
self.ret = self.proc.returncode
+ return stdout, stderr
+
+ def process(self, terminate=True, fail=None, timeout=5):
+ if fail is None:
+ fail = not terminate
+
+ if self.ksft_term_fd:
+ os.write(self.ksft_term_fd, b"1")
+
+ stdout, stderr = self._process_terminate(terminate=terminate,
+ timeout=timeout)
if self.proc.returncode != 0 and fail:
if len(stderr) > 0 and stderr[-1] == "\n":
stderr = stderr[:-1]
- raise CmdExitFailure("Command failed: %s\nSTDOUT: %s\nSTDERR: %s" %
- (self.proc.args, stdout, stderr), self)
+ raise CmdExitFailure("Command failed", self)
+
+ def __repr__(self):
+ def str_fmt(name, s):
+ name += ': '
+ return (name + s.strip().replace('\n', '\n' + ' ' * len(name)))
+
+ ret = "CMD"
+ if self.host:
+ ret += "[remote]"
+ if self.ret is None:
+ ret += f" (unterminated): {self.comm}\n"
+ elif self.ret == 0:
+ ret += f" (success): {self.comm}\n"
+ else:
+ ret += f": {self.comm}\n"
+ ret += f" EXIT: {self.ret}\n"
+ if self.stdout:
+ ret += str_fmt(" STDOUT", self.stdout) + "\n"
+ if self.stderr:
+ ret += str_fmt(" STDERR", self.stderr) + "\n"
+ return ret.strip()
class bkg(cmd):
- def __init__(self, comm, shell=True, fail=None, ns=None, host=None,
- exit_wait=False):
+ """
+ Run a command in the background.
+
+ Examples usage:
+
+ Run a command on remote host, and wait for it to finish.
+ This is usually paired with wait_port_listen() to make sure
+ the command has initialized:
+
+ with bkg("socat ...", exit_wait=True, host=cfg.remote) as nc:
+ ...
+
+ Run a command and expect it to let us know that it's ready
+ by writing to a special file descriptor passed via KSFT_READY_FD.
+ Command will be terminated when we exit the context manager:
+
+ with bkg("my_binary", ksft_wait=5):
+ """
+ def __init__(self, comm, shell=None, fail=None, ns=None, host=None,
+ exit_wait=False, ksft_ready=None, ksft_wait=None):
super().__init__(comm, background=True,
- shell=shell, fail=fail, ns=ns, host=host)
- self.terminate = not exit_wait
+ shell=shell, fail=fail, ns=ns, host=host,
+ ksft_ready=ksft_ready, ksft_wait=ksft_wait)
+ self.terminate = not exit_wait and not ksft_wait
+ self._exit_wait = exit_wait
self.check_fail = fail
+ if shell and self.terminate:
+ print("# Warning: combining shell and terminate is risky!")
+ print("# SIGTERM may not reach the child on zsh/ksh!")
+
def __enter__(self):
return self
def __exit__(self, ex_type, ex_value, ex_tb):
- return self.process(terminate=self.terminate, fail=self.check_fail)
+ terminate = self.terminate
+ # Force termination on exception, but only if bkg() didn't already exit
+ # since forcing termination silences failures with fail=None
+ if self.proc.poll() is None:
+ terminate = terminate or (self._exit_wait and ex_type is not None)
+ return self.process(terminate=terminate, fail=self.check_fail)
-global_defer_queue = []
+GLOBAL_DEFER_QUEUE = []
+GLOBAL_DEFER_ARMED = False
class defer:
def __init__(self, func, *args, **kwargs):
- global global_defer_queue
-
if not callable(func):
raise Exception("defer created with un-callable object, did you call the function instead of passing its name?")
@@ -82,7 +195,9 @@ class defer:
self.args = args
self.kwargs = kwargs
- self._queue = global_defer_queue
+ if not GLOBAL_DEFER_ARMED:
+ raise Exception("defer queue not armed, did you use defer() outside of a test case?")
+ self._queue = GLOBAL_DEFER_QUEUE
self._queue.append(self)
def __enter__(self):
@@ -113,6 +228,10 @@ def tool(name, args, json=None, ns=None, host=None):
return cmd_obj
+def bpftool(args, json=None, ns=None, host=None):
+ return tool('bpftool', args, json=json, ns=ns, host=host)
+
+
def ip(args, json=None, ns=None, host=None):
if ns:
args = f'-netns {ns} ' + args
@@ -123,20 +242,67 @@ def ethtool(args, json=None, ns=None, host=None):
return tool('ethtool', args, json=json, ns=ns, host=host)
-def rand_port():
+def bpftrace(expr, json=None, ns=None, host=None, timeout=None):
+ """
+ Run bpftrace and return map data (if json=True).
+ The output of bpftrace is inconvenient, so the helper converts
+ to a dict indexed by map name, e.g.:
+ {
+ "@": { ... },
+ "@map2": { ... },
+ }
+ """
+ cmd_arr = ['bpftrace']
+ # Throw in --quiet if json, otherwise the output has two objects
+ if json:
+ cmd_arr += ['-f', 'json', '-q']
+ if timeout:
+ expr += ' interval:s:' + str(timeout) + ' { exit(); }'
+ timeout += 20
+ cmd_arr += ['-e', expr]
+ cmd_obj = cmd(cmd_arr, ns=ns, host=host, shell=False, timeout=timeout)
+ if json:
+ # bpftrace prints objects as lines
+ ret = {}
+ for l in cmd_obj.stdout.split('\n'):
+ if not l.strip():
+ continue
+ one = _json.loads(l)
+ if one.get('type') != 'map':
+ continue
+ for k, v in one["data"].items():
+ if k.startswith('@'):
+ k = k.lstrip('@')
+ ret[k] = v
+ return ret
+ return cmd_obj
+
+
+def rand_port(stype=socket.SOCK_STREAM):
+ """
+ Get a random unprivileged port.
"""
- Get a random unprivileged port, try to make sure it's not already used.
+ return rand_ports(1, stype)[0]
+
+
+def rand_ports(count, stype=socket.SOCK_STREAM):
"""
- for _ in range(1000):
- port = random.randint(10000, 65535)
- try:
- with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
- s.bind(("", port))
- return port
- except OSError as e:
- if e.errno != errno.EADDRINUSE:
- raise
- raise Exception("Can't find any free unprivileged port")
+ Get a unique set of random unprivileged ports.
+ """
+ sockets = []
+ ports = []
+
+ try:
+ for _ in range(count):
+ s = socket.socket(socket.AF_INET6, stype)
+ sockets.append(s)
+ s.bind(("", 0))
+ ports.append(s.getsockname()[1])
+ finally:
+ for s in sockets:
+ s.close()
+
+ return ports
def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadline=5):
@@ -155,3 +321,21 @@ def wait_port_listen(port, proto="tcp", ns=None, host=None, sleep=0.005, deadlin
if time.monotonic() > end:
raise Exception("Waiting for port listen timed out")
time.sleep(sleep)
+
+
+def wait_file(fname, test_fn, sleep=0.005, deadline=5, encoding='utf-8'):
+ """
+ Wait for file contents on the local system to satisfy a condition.
+ test_fn() should take one argument (file contents) and return whether
+ condition is met.
+ """
+ end = time.monotonic() + deadline
+
+ with open(fname, "r", encoding=encoding) as fp:
+ while True:
+ if test_fn(fp.read()):
+ break
+ fp.seek(0)
+ if time.monotonic() > end:
+ raise TimeoutError("Wait for file contents failed", fname)
+ time.sleep(sleep)
diff --git a/tools/testing/selftests/net/lib/py/ynl.py b/tools/testing/selftests/net/lib/py/ynl.py
index ad1e36baee2a..2e567062aa6c 100644
--- a/tools/testing/selftests/net/lib/py/ynl.py
+++ b/tools/testing/selftests/net/lib/py/ynl.py
@@ -13,20 +13,27 @@ try:
SPEC_PATH = KSFT_DIR / "net/lib/specs"
sys.path.append(tools_full_path.as_posix())
- from net.lib.ynl.pyynl.lib import YnlFamily, NlError
+ from net.lib.ynl.pyynl.lib import YnlFamily, NlError, NlPolicy, Netlink
else:
# Running in tree
tools_full_path = KSRC / "tools"
SPEC_PATH = KSRC / "Documentation/netlink/specs"
sys.path.append(tools_full_path.as_posix())
- from net.ynl.pyynl.lib import YnlFamily, NlError
+ from net.ynl.pyynl.lib import YnlFamily, NlError, NlPolicy, Netlink
except ModuleNotFoundError as e:
ksft_pr("Failed importing `ynl` library from kernel sources")
ksft_pr(str(e))
ktap_result(True, comment="SKIP")
sys.exit(4)
+__all__ = [
+ "NlError", "NlPolicy", "Netlink", "YnlFamily", "SPEC_PATH",
+ "EthtoolFamily", "RtnlFamily", "RtnlAddrFamily",
+ "NetdevFamily", "NetshaperFamily", "NlctrlFamily", "DevlinkFamily",
+ "PSPFamily",
+]
+
#
# Wrapper classes, loading the right specs
# Set schema='' to avoid jsonschema validation, it's slow
@@ -39,9 +46,13 @@ class EthtoolFamily(YnlFamily):
class RtnlFamily(YnlFamily):
def __init__(self, recv_size=0):
- super().__init__((SPEC_PATH / Path('rt_link.yaml')).as_posix(),
+ super().__init__((SPEC_PATH / Path('rt-link.yaml')).as_posix(),
schema='', recv_size=recv_size)
+class RtnlAddrFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('rt-addr.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
class NetdevFamily(YnlFamily):
def __init__(self, recv_size=0):
@@ -52,3 +63,20 @@ class NetshaperFamily(YnlFamily):
def __init__(self, recv_size=0):
super().__init__((SPEC_PATH / Path('net_shaper.yaml')).as_posix(),
schema='', recv_size=recv_size)
+
+
+class NlctrlFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('nlctrl.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
+
+
+class DevlinkFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('devlink.yaml')).as_posix(),
+ schema='', recv_size=recv_size)
+
+class PSPFamily(YnlFamily):
+ def __init__(self, recv_size=0):
+ super().__init__((SPEC_PATH / Path('psp.yaml')).as_posix(),
+ schema='', recv_size=recv_size)