summaryrefslogtreecommitdiff
path: root/kernel/bpf/tnum.c
diff options
context:
space:
mode:
Diffstat (limited to 'kernel/bpf/tnum.c')
-rw-r--r--kernel/bpf/tnum.c132
1 files changed, 119 insertions, 13 deletions
diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index 9dbc31b25e3d..ec9c310cf5d7 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -8,6 +8,7 @@
*/
#include <linux/kernel.h>
#include <linux/tnum.h>
+#include <linux/swab.h>
#define TNUM(_v, _m) (struct tnum){.value = _v, .mask = _m}
/* A completely unknown value */
@@ -83,6 +84,11 @@ struct tnum tnum_sub(struct tnum a, struct tnum b)
return TNUM(dv & ~mu, mu);
}
+struct tnum tnum_neg(struct tnum a)
+{
+ return tnum_sub(TNUM(0, 0), a);
+}
+
struct tnum tnum_and(struct tnum a, struct tnum b)
{
u64 alpha, beta, v;
@@ -111,31 +117,55 @@ struct tnum tnum_xor(struct tnum a, struct tnum b)
return TNUM(v & ~mu, mu);
}
-/* Generate partial products by multiplying each bit in the multiplier (tnum a)
- * with the multiplicand (tnum b), and add the partial products after
- * appropriately bit-shifting them. Instead of directly performing tnum addition
- * on the generated partial products, equivalenty, decompose each partial
- * product into two tnums, consisting of the value-sum (acc_v) and the
- * mask-sum (acc_m) and then perform tnum addition on them. The following paper
- * explains the algorithm in more detail: https://arxiv.org/abs/2105.05398.
+/* Perform long multiplication, iterating through the bits in a using rshift:
+ * - if LSB(a) is a known 0, keep current accumulator
+ * - if LSB(a) is a known 1, add b to current accumulator
+ * - if LSB(a) is unknown, take a union of the above cases.
+ *
+ * For example:
+ *
+ * acc_0: acc_1:
+ *
+ * 11 * -> 11 * -> 11 * -> union(0011, 1001) == x0x1
+ * x1 01 11
+ * ------ ------ ------
+ * 11 11 11
+ * xx 00 11
+ * ------ ------ ------
+ * ???? 0011 1001
*/
struct tnum tnum_mul(struct tnum a, struct tnum b)
{
- u64 acc_v = a.value * b.value;
- struct tnum acc_m = TNUM(0, 0);
+ struct tnum acc = TNUM(0, 0);
while (a.value || a.mask) {
/* LSB of tnum a is a certain 1 */
if (a.value & 1)
- acc_m = tnum_add(acc_m, TNUM(0, b.mask));
+ acc = tnum_add(acc, b);
/* LSB of tnum a is uncertain */
- else if (a.mask & 1)
- acc_m = tnum_add(acc_m, TNUM(0, b.value | b.mask));
+ else if (a.mask & 1) {
+ /* acc = tnum_union(acc_0, acc_1), where acc_0 and
+ * acc_1 are partial accumulators for cases
+ * LSB(a) = certain 0 and LSB(a) = certain 1.
+ * acc_0 = acc + 0 * b = acc.
+ * acc_1 = acc + 1 * b = tnum_add(acc, b).
+ */
+
+ acc = tnum_union(acc, tnum_add(acc, b));
+ }
/* Note: no case for LSB is certain 0 */
a = tnum_rshift(a, 1);
b = tnum_lshift(b, 1);
}
- return tnum_add(TNUM(acc_v, 0), acc_m);
+ return acc;
+}
+
+bool tnum_overlap(struct tnum a, struct tnum b)
+{
+ u64 mu;
+
+ mu = ~a.mask & ~b.mask;
+ return (a.value & mu) == (b.value & mu);
}
/* Note that if a and b disagree - i.e. one has a 'known 1' where the other has
@@ -150,6 +180,19 @@ struct tnum tnum_intersect(struct tnum a, struct tnum b)
return TNUM(v & ~mu, mu);
}
+/* Returns a tnum with the uncertainty from both a and b, and in addition, new
+ * uncertainty at any position that a and b disagree. This represents a
+ * superset of the union of the concrete sets of both a and b. Despite the
+ * overapproximation, it is optimal.
+ */
+struct tnum tnum_union(struct tnum a, struct tnum b)
+{
+ u64 v = a.value & b.value;
+ u64 mu = (a.value ^ b.value) | a.mask | b.mask;
+
+ return TNUM(v & ~mu, mu);
+}
+
struct tnum tnum_cast(struct tnum a, u8 size)
{
a.value &= (1ULL << (size * 8)) - 1;
@@ -211,3 +254,66 @@ struct tnum tnum_const_subreg(struct tnum a, u32 value)
{
return tnum_with_subreg(a, tnum_const(value));
}
+
+struct tnum tnum_bswap16(struct tnum a)
+{
+ return TNUM(swab16(a.value & 0xFFFF), swab16(a.mask & 0xFFFF));
+}
+
+struct tnum tnum_bswap32(struct tnum a)
+{
+ return TNUM(swab32(a.value & 0xFFFFFFFF), swab32(a.mask & 0xFFFFFFFF));
+}
+
+struct tnum tnum_bswap64(struct tnum a)
+{
+ return TNUM(swab64(a.value), swab64(a.mask));
+}
+
+/* Given tnum t, and a number z such that tmin <= z < tmax, where tmin
+ * is the smallest member of the t (= t.value) and tmax is the largest
+ * member of t (= t.value | t.mask), returns the smallest member of t
+ * larger than z.
+ *
+ * For example,
+ * t = x11100x0
+ * z = 11110001 (241)
+ * result = 11110010 (242)
+ *
+ * Note: if this function is called with z >= tmax, it just returns
+ * early with tmax; if this function is called with z < tmin, the
+ * algorithm already returns tmin.
+ */
+u64 tnum_step(struct tnum t, u64 z)
+{
+ u64 tmax, d, carry_mask, filled, inc;
+
+ tmax = t.value | t.mask;
+
+ /* if z >= largest member of t, return largest member of t */
+ if (z >= tmax)
+ return tmax;
+
+ /* if z < smallest member of t, return smallest member of t */
+ if (z < t.value)
+ return t.value;
+
+ /*
+ * Let r be the result tnum member, z = t.value + d.
+ * Every tnum member is t.value | s for some submask s of t.mask,
+ * and since t.value & t.mask == 0, t.value | s == t.value + s.
+ * So r > z becomes s > d where d = z - t.value.
+ *
+ * Find the smallest submask s of t.mask greater than d by
+ * "incrementing d within the mask": fill every non-mask
+ * position with 1 (`filled`) so +1 ripples through the gaps,
+ * then keep only mask bits. `carry_mask` additionally fills
+ * positions below the highest non-mask 1 in d, preventing
+ * it from trapping the carry.
+ */
+ d = z - t.value;
+ carry_mask = (1ULL << fls64(d & ~t.mask)) - 1;
+ filled = d | carry_mask | ~t.mask;
+ inc = (filled + 1) & t.mask;
+ return t.value | inc;
+}