diff options
Diffstat (limited to 'net/xfrm')
-rw-r--r-- | net/xfrm/xfrm_algo.c | 16 | ||||
-rw-r--r-- | net/xfrm/xfrm_input.c | 2 | ||||
-rw-r--r-- | net/xfrm/xfrm_ipcomp.c | 16 | ||||
-rw-r--r-- | net/xfrm/xfrm_policy.c | 120 | ||||
-rw-r--r-- | net/xfrm/xfrm_proc.c | 6 | ||||
-rw-r--r-- | net/xfrm/xfrm_state.c | 106 | ||||
-rw-r--r-- | net/xfrm/xfrm_sysctl.c | 4 | ||||
-rw-r--r-- | net/xfrm/xfrm_user.c | 125 |
8 files changed, 290 insertions, 105 deletions
diff --git a/net/xfrm/xfrm_algo.c b/net/xfrm/xfrm_algo.c index 743c0134a6a9..8b4d6e3246e5 100644 --- a/net/xfrm/xfrm_algo.c +++ b/net/xfrm/xfrm_algo.c @@ -125,6 +125,22 @@ static struct xfrm_algo_desc aead_list[] = { .sadb_alg_maxbits = 256 } }, +{ + .name = "rfc4543(gcm(aes))", + + .uinfo = { + .aead = { + .icv_truncbits = 128, + } + }, + + .desc = { + .sadb_alg_id = SADB_X_EALG_NULL_AES_GMAC, + .sadb_alg_ivlen = 8, + .sadb_alg_minbits = 128, + .sadb_alg_maxbits = 256 + } +}, }; static struct xfrm_algo_desc aalg_list[] = { diff --git a/net/xfrm/xfrm_input.c b/net/xfrm/xfrm_input.c index e0009c17d809..45f1c98d4fce 100644 --- a/net/xfrm/xfrm_input.c +++ b/net/xfrm/xfrm_input.c @@ -152,7 +152,7 @@ int xfrm_input(struct sk_buff *skb, int nexthdr, __be32 spi, int encap_type) goto drop; } - x = xfrm_state_lookup(net, daddr, spi, nexthdr, family); + x = xfrm_state_lookup(net, skb->mark, daddr, spi, nexthdr, family); if (x == NULL) { XFRM_INC_STATS(net, LINUX_MIB_XFRMINNOSTATES); xfrm_audit_state_notfound(skb, family, spi, seq); diff --git a/net/xfrm/xfrm_ipcomp.c b/net/xfrm/xfrm_ipcomp.c index 42cd18391f46..0fc5ff66d1fa 100644 --- a/net/xfrm/xfrm_ipcomp.c +++ b/net/xfrm/xfrm_ipcomp.c @@ -30,12 +30,12 @@ struct ipcomp_tfms { struct list_head list; - struct crypto_comp **tfms; + struct crypto_comp * __percpu *tfms; int users; }; static DEFINE_MUTEX(ipcomp_resource_mutex); -static void **ipcomp_scratches; +static void * __percpu *ipcomp_scratches; static int ipcomp_scratch_users; static LIST_HEAD(ipcomp_tfms_list); @@ -200,7 +200,7 @@ EXPORT_SYMBOL_GPL(ipcomp_output); static void ipcomp_free_scratches(void) { int i; - void **scratches; + void * __percpu *scratches; if (--ipcomp_scratch_users) return; @@ -215,10 +215,10 @@ static void ipcomp_free_scratches(void) free_percpu(scratches); } -static void **ipcomp_alloc_scratches(void) +static void * __percpu *ipcomp_alloc_scratches(void) { int i; - void **scratches; + void * __percpu *scratches; if (ipcomp_scratch_users++) return ipcomp_scratches; @@ -239,7 +239,7 @@ static void **ipcomp_alloc_scratches(void) return scratches; } -static void ipcomp_free_tfms(struct crypto_comp **tfms) +static void ipcomp_free_tfms(struct crypto_comp * __percpu *tfms) { struct ipcomp_tfms *pos; int cpu; @@ -267,10 +267,10 @@ static void ipcomp_free_tfms(struct crypto_comp **tfms) free_percpu(tfms); } -static struct crypto_comp **ipcomp_alloc_tfms(const char *alg_name) +static struct crypto_comp * __percpu *ipcomp_alloc_tfms(const char *alg_name) { struct ipcomp_tfms *pos; - struct crypto_comp **tfms; + struct crypto_comp * __percpu *tfms; int cpu; /* This can be any valid CPU ID so we don't need locking. */ diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index cb81ca35b0d6..34a5ef8316e7 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -469,16 +469,16 @@ static inline int xfrm_byidx_should_resize(struct net *net, int total) return 0; } -void xfrm_spd_getinfo(struct xfrmk_spdinfo *si) +void xfrm_spd_getinfo(struct net *net, struct xfrmk_spdinfo *si) { read_lock_bh(&xfrm_policy_lock); - si->incnt = init_net.xfrm.policy_count[XFRM_POLICY_IN]; - si->outcnt = init_net.xfrm.policy_count[XFRM_POLICY_OUT]; - si->fwdcnt = init_net.xfrm.policy_count[XFRM_POLICY_FWD]; - si->inscnt = init_net.xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX]; - si->outscnt = init_net.xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX]; - si->fwdscnt = init_net.xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX]; - si->spdhcnt = init_net.xfrm.policy_idx_hmask; + si->incnt = net->xfrm.policy_count[XFRM_POLICY_IN]; + si->outcnt = net->xfrm.policy_count[XFRM_POLICY_OUT]; + si->fwdcnt = net->xfrm.policy_count[XFRM_POLICY_FWD]; + si->inscnt = net->xfrm.policy_count[XFRM_POLICY_IN+XFRM_POLICY_MAX]; + si->outscnt = net->xfrm.policy_count[XFRM_POLICY_OUT+XFRM_POLICY_MAX]; + si->fwdscnt = net->xfrm.policy_count[XFRM_POLICY_FWD+XFRM_POLICY_MAX]; + si->spdhcnt = net->xfrm.policy_idx_hmask; si->spdhmcnt = xfrm_policy_hashmax; read_unlock_bh(&xfrm_policy_lock); } @@ -556,6 +556,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) struct hlist_head *chain; struct hlist_node *entry, *newpos; struct dst_entry *gc_list; + u32 mark = policy->mark.v & policy->mark.m; write_lock_bh(&xfrm_policy_lock); chain = policy_hash_bysel(net, &policy->selector, policy->family, dir); @@ -564,6 +565,7 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) hlist_for_each_entry(pol, entry, chain, bydst) { if (pol->type == policy->type && !selector_cmp(&pol->selector, &policy->selector) && + (mark & pol->mark.m) == pol->mark.v && xfrm_sec_ctx_match(pol->security, policy->security) && !WARN_ON(delpol)) { if (excl) { @@ -635,8 +637,8 @@ int xfrm_policy_insert(int dir, struct xfrm_policy *policy, int excl) } EXPORT_SYMBOL(xfrm_policy_insert); -struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir, - struct xfrm_selector *sel, +struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u32 mark, u8 type, + int dir, struct xfrm_selector *sel, struct xfrm_sec_ctx *ctx, int delete, int *err) { @@ -650,6 +652,7 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir, ret = NULL; hlist_for_each_entry(pol, entry, chain, bydst) { if (pol->type == type && + (mark & pol->mark.m) == pol->mark.v && !selector_cmp(sel, &pol->selector) && xfrm_sec_ctx_match(ctx, pol->security)) { xfrm_pol_hold(pol); @@ -676,8 +679,8 @@ struct xfrm_policy *xfrm_policy_bysel_ctx(struct net *net, u8 type, int dir, } EXPORT_SYMBOL(xfrm_policy_bysel_ctx); -struct xfrm_policy *xfrm_policy_byid(struct net *net, u8 type, int dir, u32 id, - int delete, int *err) +struct xfrm_policy *xfrm_policy_byid(struct net *net, u32 mark, u8 type, + int dir, u32 id, int delete, int *err) { struct xfrm_policy *pol, *ret; struct hlist_head *chain; @@ -692,7 +695,8 @@ struct xfrm_policy *xfrm_policy_byid(struct net *net, u8 type, int dir, u32 id, chain = net->xfrm.policy_byidx + idx_hash(net, id); ret = NULL; hlist_for_each_entry(pol, entry, chain, byidx) { - if (pol->type == type && pol->index == id) { + if (pol->type == type && pol->index == id && + (mark & pol->mark.m) == pol->mark.v) { xfrm_pol_hold(pol); if (delete) { *err = security_xfrm_policy_delete( @@ -771,7 +775,8 @@ xfrm_policy_flush_secctx_check(struct net *net, u8 type, struct xfrm_audit *audi int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info) { - int dir, err = 0; + int dir, err = 0, cnt = 0; + struct xfrm_policy *dp; write_lock_bh(&xfrm_policy_lock); @@ -789,8 +794,10 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info) &net->xfrm.policy_inexact[dir], bydst) { if (pol->type != type) continue; - __xfrm_policy_unlink(pol, dir); + dp = __xfrm_policy_unlink(pol, dir); write_unlock_bh(&xfrm_policy_lock); + if (dp) + cnt++; xfrm_audit_policy_delete(pol, 1, audit_info->loginuid, audit_info->sessionid, @@ -809,8 +816,10 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info) bydst) { if (pol->type != type) continue; - __xfrm_policy_unlink(pol, dir); + dp = __xfrm_policy_unlink(pol, dir); write_unlock_bh(&xfrm_policy_lock); + if (dp) + cnt++; xfrm_audit_policy_delete(pol, 1, audit_info->loginuid, @@ -824,6 +833,8 @@ int xfrm_policy_flush(struct net *net, u8 type, struct xfrm_audit *audit_info) } } + if (!cnt) + err = -ESRCH; atomic_inc(&flow_cache_genid); out: write_unlock_bh(&xfrm_policy_lock); @@ -909,6 +920,7 @@ static int xfrm_policy_match(struct xfrm_policy *pol, struct flowi *fl, int match, ret = -ESRCH; if (pol->family != family || + (fl->mark & pol->mark.m) != pol->mark.v || pol->type != type) return ret; @@ -1033,6 +1045,10 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(struct sock *sk, int dir, struc int err = 0; if (match) { + if ((sk->sk_mark & pol->mark.m) != pol->mark.v) { + pol = NULL; + goto out; + } err = security_xfrm_policy_lookup(pol->security, fl->secid, policy_to_flow_dir(dir)); @@ -1045,6 +1061,7 @@ static struct xfrm_policy *xfrm_sk_policy_lookup(struct sock *sk, int dir, struc } else pol = NULL; } +out: read_unlock_bh(&xfrm_policy_lock); return pol; } @@ -1137,6 +1154,7 @@ static struct xfrm_policy *clone_policy(struct xfrm_policy *old, int dir) } newp->lft = old->lft; newp->curlft = old->curlft; + newp->mark = old->mark; newp->action = old->action; newp->flags = old->flags; newp->xfrm_nr = old->xfrm_nr; @@ -1309,15 +1327,28 @@ static inline int xfrm_get_tos(struct flowi *fl, int family) return tos; } -static inline struct xfrm_dst *xfrm_alloc_dst(int family) +static inline struct xfrm_dst *xfrm_alloc_dst(struct net *net, int family) { struct xfrm_policy_afinfo *afinfo = xfrm_policy_get_afinfo(family); + struct dst_ops *dst_ops; struct xfrm_dst *xdst; if (!afinfo) return ERR_PTR(-EINVAL); - xdst = dst_alloc(afinfo->dst_ops) ?: ERR_PTR(-ENOBUFS); + switch (family) { + case AF_INET: + dst_ops = &net->xfrm.xfrm4_dst_ops; + break; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: + dst_ops = &net->xfrm.xfrm6_dst_ops; + break; +#endif + default: + BUG(); + } + xdst = dst_alloc(dst_ops) ?: ERR_PTR(-ENOBUFS); xfrm_policy_put_afinfo(afinfo); @@ -1366,6 +1397,7 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, struct flowi *fl, struct dst_entry *dst) { + struct net *net = xp_net(policy); unsigned long now = jiffies; struct net_device *dev; struct dst_entry *dst_prev = NULL; @@ -1389,7 +1421,7 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, dst_hold(dst); for (; i < nx; i++) { - struct xfrm_dst *xdst = xfrm_alloc_dst(family); + struct xfrm_dst *xdst = xfrm_alloc_dst(net, family); struct dst_entry *dst1 = &xdst->u.dst; err = PTR_ERR(xdst); @@ -1445,7 +1477,7 @@ static struct dst_entry *xfrm_bundle_create(struct xfrm_policy *policy, if (!dev) goto free_dst; - /* Copy neighbout for reachability confirmation */ + /* Copy neighbour for reachability confirmation */ dst0->neighbour = neigh_clone(dst->neighbour); xfrm_init_path((struct xfrm_dst *)dst0, dst, nfheader_len); @@ -2031,8 +2063,7 @@ int __xfrm_route_forward(struct sk_buff *skb, unsigned short family) int res; if (xfrm_decode_session(skb, &fl, family) < 0) { - /* XXX: we should have something like FWDHDRERROR here. */ - XFRM_INC_STATS(net, LINUX_MIB_XFRMINHDRERROR); + XFRM_INC_STATS(net, LINUX_MIB_XFRMFWDHDRERROR); return 0; } @@ -2279,6 +2310,7 @@ EXPORT_SYMBOL(xfrm_bundle_ok); int xfrm_policy_register_afinfo(struct xfrm_policy_afinfo *afinfo) { + struct net *net; int err = 0; if (unlikely(afinfo == NULL)) return -EINVAL; @@ -2302,6 +2334,27 @@ int xfrm_policy_register_afinfo(struct xfrm_policy_afinfo *afinfo) xfrm_policy_afinfo[afinfo->family] = afinfo; } write_unlock_bh(&xfrm_policy_afinfo_lock); + + rtnl_lock(); + for_each_net(net) { + struct dst_ops *xfrm_dst_ops; + + switch (afinfo->family) { + case AF_INET: + xfrm_dst_ops = &net->xfrm.xfrm4_dst_ops; + break; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + case AF_INET6: + xfrm_dst_ops = &net->xfrm.xfrm6_dst_ops; + break; +#endif + default: + BUG(); + } + *xfrm_dst_ops = *afinfo->dst_ops; + } + rtnl_unlock(); + return err; } EXPORT_SYMBOL(xfrm_policy_register_afinfo); @@ -2332,6 +2385,22 @@ int xfrm_policy_unregister_afinfo(struct xfrm_policy_afinfo *afinfo) } EXPORT_SYMBOL(xfrm_policy_unregister_afinfo); +static void __net_init xfrm_dst_ops_init(struct net *net) +{ + struct xfrm_policy_afinfo *afinfo; + + read_lock_bh(&xfrm_policy_afinfo_lock); + afinfo = xfrm_policy_afinfo[AF_INET]; + if (afinfo) + net->xfrm.xfrm4_dst_ops = *afinfo->dst_ops; +#if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE) + afinfo = xfrm_policy_afinfo[AF_INET6]; + if (afinfo) + net->xfrm.xfrm6_dst_ops = *afinfo->dst_ops; +#endif + read_unlock_bh(&xfrm_policy_afinfo_lock); +} + static struct xfrm_policy_afinfo *xfrm_policy_get_afinfo(unsigned short family) { struct xfrm_policy_afinfo *afinfo; @@ -2369,19 +2438,19 @@ static int __net_init xfrm_statistics_init(struct net *net) { int rv; - if (snmp_mib_init((void **)net->mib.xfrm_statistics, + if (snmp_mib_init((void __percpu **)net->mib.xfrm_statistics, sizeof(struct linux_xfrm_mib)) < 0) return -ENOMEM; rv = xfrm_proc_init(net); if (rv < 0) - snmp_mib_free((void **)net->mib.xfrm_statistics); + snmp_mib_free((void __percpu **)net->mib.xfrm_statistics); return rv; } static void xfrm_statistics_fini(struct net *net) { xfrm_proc_fini(net); - snmp_mib_free((void **)net->mib.xfrm_statistics); + snmp_mib_free((void __percpu **)net->mib.xfrm_statistics); } #else static int __net_init xfrm_statistics_init(struct net *net) @@ -2494,6 +2563,7 @@ static int __net_init xfrm_net_init(struct net *net) rv = xfrm_policy_init(net); if (rv < 0) goto out_policy; + xfrm_dst_ops_init(net); rv = xfrm_sysctl_init(net); if (rv < 0) goto out_sysctl; diff --git a/net/xfrm/xfrm_proc.c b/net/xfrm/xfrm_proc.c index fef8db553e8d..58d9ae005597 100644 --- a/net/xfrm/xfrm_proc.c +++ b/net/xfrm/xfrm_proc.c @@ -15,7 +15,7 @@ #include <net/snmp.h> #include <net/xfrm.h> -static struct snmp_mib xfrm_mib_list[] = { +static const struct snmp_mib xfrm_mib_list[] = { SNMP_MIB_ITEM("XfrmInError", LINUX_MIB_XFRMINERROR), SNMP_MIB_ITEM("XfrmInBufferError", LINUX_MIB_XFRMINBUFFERERROR), SNMP_MIB_ITEM("XfrmInHdrError", LINUX_MIB_XFRMINHDRERROR), @@ -41,6 +41,7 @@ static struct snmp_mib xfrm_mib_list[] = { SNMP_MIB_ITEM("XfrmOutPolBlock", LINUX_MIB_XFRMOUTPOLBLOCK), SNMP_MIB_ITEM("XfrmOutPolDead", LINUX_MIB_XFRMOUTPOLDEAD), SNMP_MIB_ITEM("XfrmOutPolError", LINUX_MIB_XFRMOUTPOLERROR), + SNMP_MIB_ITEM("XfrmFwdHdrError", LINUX_MIB_XFRMFWDHDRERROR), SNMP_MIB_SENTINEL }; @@ -50,7 +51,8 @@ static int xfrm_statistics_seq_show(struct seq_file *seq, void *v) int i; for (i=0; xfrm_mib_list[i].name; i++) seq_printf(seq, "%-24s\t%lu\n", xfrm_mib_list[i].name, - snmp_fold_field((void **)net->mib.xfrm_statistics, + snmp_fold_field((void __percpu **) + net->mib.xfrm_statistics, xfrm_mib_list[i].entry)); return 0; } diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c index d847f1a52b44..17d5b96f2fc8 100644 --- a/net/xfrm/xfrm_state.c +++ b/net/xfrm/xfrm_state.c @@ -603,13 +603,14 @@ xfrm_state_flush_secctx_check(struct net *net, u8 proto, struct xfrm_audit *audi int xfrm_state_flush(struct net *net, u8 proto, struct xfrm_audit *audit_info) { - int i, err = 0; + int i, err = 0, cnt = 0; spin_lock_bh(&xfrm_state_lock); err = xfrm_state_flush_secctx_check(net, proto, audit_info); if (err) goto out; + err = -ESRCH; for (i = 0; i <= net->xfrm.state_hmask; i++) { struct hlist_node *entry; struct xfrm_state *x; @@ -626,13 +627,16 @@ restart: audit_info->sessionid, audit_info->secid); xfrm_state_put(x); + if (!err) + cnt++; spin_lock_bh(&xfrm_state_lock); goto restart; } } } - err = 0; + if (cnt) + err = 0; out: spin_unlock_bh(&xfrm_state_lock); @@ -641,11 +645,11 @@ out: } EXPORT_SYMBOL(xfrm_state_flush); -void xfrm_sad_getinfo(struct xfrmk_sadinfo *si) +void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si) { spin_lock_bh(&xfrm_state_lock); - si->sadcnt = init_net.xfrm.state_num; - si->sadhcnt = init_net.xfrm.state_hmask; + si->sadcnt = net->xfrm.state_num; + si->sadhcnt = net->xfrm.state_hmask; si->sadhmcnt = xfrm_state_hashmax; spin_unlock_bh(&xfrm_state_lock); } @@ -665,7 +669,7 @@ xfrm_init_tempsel(struct xfrm_state *x, struct flowi *fl, return 0; } -static struct xfrm_state *__xfrm_state_lookup(struct net *net, xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family) +static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark, xfrm_address_t *daddr, __be32 spi, u8 proto, unsigned short family) { unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family); struct xfrm_state *x; @@ -678,6 +682,8 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, xfrm_address_t *d xfrm_addr_cmp(&x->id.daddr, daddr, family)) continue; + if ((mark & x->mark.m) != x->mark.v) + continue; xfrm_state_hold(x); return x; } @@ -685,7 +691,7 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, xfrm_address_t *d return NULL; } -static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family) +static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark, xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family) { unsigned int h = xfrm_src_hash(net, daddr, saddr, family); struct xfrm_state *x; @@ -698,6 +704,8 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, xfrm_addre xfrm_addr_cmp(&x->props.saddr, saddr, family)) continue; + if ((mark & x->mark.m) != x->mark.v) + continue; xfrm_state_hold(x); return x; } @@ -709,12 +717,14 @@ static inline struct xfrm_state * __xfrm_state_locate(struct xfrm_state *x, int use_spi, int family) { struct net *net = xs_net(x); + u32 mark = x->mark.v & x->mark.m; if (use_spi) - return __xfrm_state_lookup(net, &x->id.daddr, x->id.spi, - x->id.proto, family); + return __xfrm_state_lookup(net, mark, &x->id.daddr, + x->id.spi, x->id.proto, family); else - return __xfrm_state_lookup_byaddr(net, &x->id.daddr, + return __xfrm_state_lookup_byaddr(net, mark, + &x->id.daddr, &x->props.saddr, x->id.proto, family); } @@ -779,6 +789,7 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr, int acquire_in_progress = 0; int error = 0; struct xfrm_state *best = NULL; + u32 mark = pol->mark.v & pol->mark.m; to_put = NULL; @@ -787,6 +798,7 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr, hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) { if (x->props.family == family && x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && !(x->props.flags & XFRM_STATE_WILDRECV) && xfrm_state_addr_check(x, daddr, saddr, family) && tmpl->mode == x->props.mode && @@ -802,6 +814,7 @@ xfrm_state_find(xfrm_address_t *daddr, xfrm_address_t *saddr, hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h_wildcard, bydst) { if (x->props.family == family && x->props.reqid == tmpl->reqid && + (mark & x->mark.m) == x->mark.v && !(x->props.flags & XFRM_STATE_WILDRECV) && xfrm_state_addr_check(x, daddr, saddr, family) && tmpl->mode == x->props.mode && @@ -815,7 +828,7 @@ found: x = best; if (!x && !error && !acquire_in_progress) { if (tmpl->id.spi && - (x0 = __xfrm_state_lookup(net, daddr, tmpl->id.spi, + (x0 = __xfrm_state_lookup(net, mark, daddr, tmpl->id.spi, tmpl->id.proto, family)) != NULL) { to_put = x0; error = -EEXIST; @@ -829,6 +842,7 @@ found: /* Initialize temporary selector matching only * to current session. */ xfrm_init_tempsel(x, fl, tmpl, daddr, saddr, family); + memcpy(&x->mark, &pol->mark, sizeof(x->mark)); error = security_xfrm_state_alloc_acquire(x, pol->security, fl->secid); if (error) { @@ -871,7 +885,7 @@ out: } struct xfrm_state * -xfrm_stateonly_find(struct net *net, +xfrm_stateonly_find(struct net *net, u32 mark, xfrm_address_t *daddr, xfrm_address_t *saddr, unsigned short family, u8 mode, u8 proto, u32 reqid) { @@ -884,6 +898,7 @@ xfrm_stateonly_find(struct net *net, hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) { if (x->props.family == family && x->props.reqid == reqid && + (mark & x->mark.m) == x->mark.v && !(x->props.flags & XFRM_STATE_WILDRECV) && xfrm_state_addr_check(x, daddr, saddr, family) && mode == x->props.mode && @@ -946,11 +961,13 @@ static void __xfrm_state_bump_genids(struct xfrm_state *xnew) struct xfrm_state *x; struct hlist_node *entry; unsigned int h; + u32 mark = xnew->mark.v & xnew->mark.m; h = xfrm_dst_hash(net, &xnew->id.daddr, &xnew->props.saddr, reqid, family); hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) { if (x->props.family == family && x->props.reqid == reqid && + (mark & x->mark.m) == x->mark.v && !xfrm_addr_cmp(&x->id.daddr, &xnew->id.daddr, family) && !xfrm_addr_cmp(&x->props.saddr, &xnew->props.saddr, family)) x->genid = xfrm_state_genid; @@ -967,11 +984,12 @@ void xfrm_state_insert(struct xfrm_state *x) EXPORT_SYMBOL(xfrm_state_insert); /* xfrm_state_lock is held */ -static struct xfrm_state *__find_acq_core(struct net *net, unsigned short family, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create) +static struct xfrm_state *__find_acq_core(struct net *net, struct xfrm_mark *m, unsigned short family, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create) { unsigned int h = xfrm_dst_hash(net, daddr, saddr, reqid, family); struct hlist_node *entry; struct xfrm_state *x; + u32 mark = m->v & m->m; hlist_for_each_entry(x, entry, net->xfrm.state_bydst+h, bydst) { if (x->props.reqid != reqid || @@ -980,6 +998,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, unsigned short family x->km.state != XFRM_STATE_ACQ || x->id.spi != 0 || x->id.proto != proto || + (mark & x->mark.m) != x->mark.v || xfrm_addr_cmp(&x->id.daddr, daddr, family) || xfrm_addr_cmp(&x->props.saddr, saddr, family)) continue; @@ -1022,6 +1041,8 @@ static struct xfrm_state *__find_acq_core(struct net *net, unsigned short family x->props.family = family; x->props.mode = mode; x->props.reqid = reqid; + x->mark.v = m->v; + x->mark.m = m->m; x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires; xfrm_state_hold(x); tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL); @@ -1038,7 +1059,7 @@ static struct xfrm_state *__find_acq_core(struct net *net, unsigned short family return x; } -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq); +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq); int xfrm_state_add(struct xfrm_state *x) { @@ -1046,6 +1067,7 @@ int xfrm_state_add(struct xfrm_state *x) struct xfrm_state *x1, *to_put; int family; int err; + u32 mark = x->mark.v & x->mark.m; int use_spi = xfrm_id_proto_match(x->id.proto, IPSEC_PROTO_ANY); family = x->props.family; @@ -1063,7 +1085,7 @@ int xfrm_state_add(struct xfrm_state *x) } if (use_spi && x->km.seq) { - x1 = __xfrm_find_acq_byseq(net, x->km.seq); + x1 = __xfrm_find_acq_byseq(net, mark, x->km.seq); if (x1 && ((x1->id.proto != x->id.proto) || xfrm_addr_cmp(&x1->id.daddr, &x->id.daddr, family))) { to_put = x1; @@ -1072,8 +1094,8 @@ int xfrm_state_add(struct xfrm_state *x) } if (use_spi && !x1) - x1 = __find_acq_core(net, family, x->props.mode, x->props.reqid, - x->id.proto, + x1 = __find_acq_core(net, &x->mark, family, x->props.mode, + x->props.reqid, x->id.proto, &x->id.daddr, &x->props.saddr, 0); __xfrm_state_bump_genids(x); @@ -1102,7 +1124,7 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, int *errp) int err = -ENOMEM; struct xfrm_state *x = xfrm_state_alloc(net); if (!x) - goto error; + goto out; memcpy(&x->id, &orig->id, sizeof(x->id)); memcpy(&x->sel, &orig->sel, sizeof(x->sel)); @@ -1147,6 +1169,8 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, int *errp) goto error; } + memcpy(&x->mark, &orig->mark, sizeof(x->mark)); + err = xfrm_init_state(x); if (err) goto error; @@ -1160,16 +1184,10 @@ static struct xfrm_state *xfrm_state_clone(struct xfrm_state *orig, int *errp) return x; error: + xfrm_state_put(x); +out: if (errp) *errp = err; - if (x) { - kfree(x->aalg); - kfree(x->ealg); - kfree(x->calg); - kfree(x->encap); - kfree(x->coaddr); - } - kfree(x); return NULL; } @@ -1344,41 +1362,41 @@ int xfrm_state_check_expire(struct xfrm_state *x) EXPORT_SYMBOL(xfrm_state_check_expire); struct xfrm_state * -xfrm_state_lookup(struct net *net, xfrm_address_t *daddr, __be32 spi, u8 proto, - unsigned short family) +xfrm_state_lookup(struct net *net, u32 mark, xfrm_address_t *daddr, __be32 spi, + u8 proto, unsigned short family) { struct xfrm_state *x; spin_lock_bh(&xfrm_state_lock); - x = __xfrm_state_lookup(net, daddr, spi, proto, family); + x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family); spin_unlock_bh(&xfrm_state_lock); return x; } EXPORT_SYMBOL(xfrm_state_lookup); struct xfrm_state * -xfrm_state_lookup_byaddr(struct net *net, +xfrm_state_lookup_byaddr(struct net *net, u32 mark, xfrm_address_t *daddr, xfrm_address_t *saddr, u8 proto, unsigned short family) { struct xfrm_state *x; spin_lock_bh(&xfrm_state_lock); - x = __xfrm_state_lookup_byaddr(net, daddr, saddr, proto, family); + x = __xfrm_state_lookup_byaddr(net, mark, daddr, saddr, proto, family); spin_unlock_bh(&xfrm_state_lock); return x; } EXPORT_SYMBOL(xfrm_state_lookup_byaddr); struct xfrm_state * -xfrm_find_acq(struct net *net, u8 mode, u32 reqid, u8 proto, +xfrm_find_acq(struct net *net, struct xfrm_mark *mark, u8 mode, u32 reqid, u8 proto, xfrm_address_t *daddr, xfrm_address_t *saddr, int create, unsigned short family) { struct xfrm_state *x; spin_lock_bh(&xfrm_state_lock); - x = __find_acq_core(net, family, mode, reqid, proto, daddr, saddr, create); + x = __find_acq_core(net, mark, family, mode, reqid, proto, daddr, saddr, create); spin_unlock_bh(&xfrm_state_lock); return x; @@ -1425,7 +1443,7 @@ EXPORT_SYMBOL(xfrm_state_sort); /* Silly enough, but I'm lazy to build resolution list */ -static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq) +static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) { int i; @@ -1435,6 +1453,7 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq) hlist_for_each_entry(x, entry, net->xfrm.state_bydst+i, bydst) { if (x->km.seq == seq && + (mark & x->mark.m) == x->mark.v && x->km.state == XFRM_STATE_ACQ) { xfrm_state_hold(x); return x; @@ -1444,12 +1463,12 @@ static struct xfrm_state *__xfrm_find_acq_byseq(struct net *net, u32 seq) return NULL; } -struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 seq) +struct xfrm_state *xfrm_find_acq_byseq(struct net *net, u32 mark, u32 seq) { struct xfrm_state *x; spin_lock_bh(&xfrm_state_lock); - x = __xfrm_find_acq_byseq(net, seq); + x = __xfrm_find_acq_byseq(net, mark, seq); spin_unlock_bh(&xfrm_state_lock); return x; } @@ -1458,12 +1477,12 @@ EXPORT_SYMBOL(xfrm_find_acq_byseq); u32 xfrm_get_acqseq(void) { u32 res; - static u32 acqseq; - static DEFINE_SPINLOCK(acqseq_lock); + static atomic_t acqseq; + + do { + res = atomic_inc_return(&acqseq); + } while (!res); - spin_lock_bh(&acqseq_lock); - res = (++acqseq ? : ++acqseq); - spin_unlock_bh(&acqseq_lock); return res; } EXPORT_SYMBOL(xfrm_get_acqseq); @@ -1476,6 +1495,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) int err = -ENOENT; __be32 minspi = htonl(low); __be32 maxspi = htonl(high); + u32 mark = x->mark.v & x->mark.m; spin_lock_bh(&x->lock); if (x->km.state == XFRM_STATE_DEAD) @@ -1488,7 +1508,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) err = -ENOENT; if (minspi == maxspi) { - x0 = xfrm_state_lookup(net, &x->id.daddr, minspi, x->id.proto, x->props.family); + x0 = xfrm_state_lookup(net, mark, &x->id.daddr, minspi, x->id.proto, x->props.family); if (x0) { xfrm_state_put(x0); goto unlock; @@ -1498,7 +1518,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high) u32 spi = 0; for (h=0; h<high-low+1; h++) { spi = low + net_random()%(high-low+1); - x0 = xfrm_state_lookup(net, &x->id.daddr, htonl(spi), x->id.proto, x->props.family); + x0 = xfrm_state_lookup(net, mark, &x->id.daddr, htonl(spi), x->id.proto, x->props.family); if (x0 == NULL) { x->id.spi = htonl(spi); break; diff --git a/net/xfrm/xfrm_sysctl.c b/net/xfrm/xfrm_sysctl.c index 2e221f2cad7e..2c4d6cdcba49 100644 --- a/net/xfrm/xfrm_sysctl.c +++ b/net/xfrm/xfrm_sysctl.c @@ -2,7 +2,7 @@ #include <net/net_namespace.h> #include <net/xfrm.h> -static void __xfrm_sysctl_init(struct net *net) +static void __net_init __xfrm_sysctl_init(struct net *net) { net->xfrm.sysctl_aevent_etime = XFRM_AE_ETIME; net->xfrm.sysctl_aevent_rseqth = XFRM_AE_SEQT_SIZE; @@ -64,7 +64,7 @@ out_kmemdup: return -ENOMEM; } -void xfrm_sysctl_fini(struct net *net) +void __net_exit xfrm_sysctl_fini(struct net *net) { struct ctl_table *table; diff --git a/net/xfrm/xfrm_user.c b/net/xfrm/xfrm_user.c index 1ada6186933c..6106b72826d3 100644 --- a/net/xfrm/xfrm_user.c +++ b/net/xfrm/xfrm_user.c @@ -446,6 +446,8 @@ static struct xfrm_state *xfrm_state_construct(struct net *net, goto error; } + xfrm_mark_get(attrs, &x->mark); + err = xfrm_init_state(x); if (err) goto error; @@ -526,11 +528,13 @@ static struct xfrm_state *xfrm_user_state_lookup(struct net *net, int *errp) { struct xfrm_state *x = NULL; + struct xfrm_mark m; int err; + u32 mark = xfrm_mark_get(attrs, &m); if (xfrm_id_proto_match(p->proto, IPSEC_PROTO_ANY)) { err = -ESRCH; - x = xfrm_state_lookup(net, &p->daddr, p->spi, p->proto, p->family); + x = xfrm_state_lookup(net, mark, &p->daddr, p->spi, p->proto, p->family); } else { xfrm_address_t *saddr = NULL; @@ -541,7 +545,8 @@ static struct xfrm_state *xfrm_user_state_lookup(struct net *net, } err = -ESRCH; - x = xfrm_state_lookup_byaddr(net, &p->daddr, saddr, + x = xfrm_state_lookup_byaddr(net, mark, + &p->daddr, saddr, p->proto, p->family); } @@ -683,6 +688,9 @@ static int copy_to_user_state_extra(struct xfrm_state *x, if (x->encap) NLA_PUT(skb, XFRMA_ENCAP, sizeof(*x->encap), x->encap); + if (xfrm_mark_put(skb, &x->mark)) + goto nla_put_failure; + if (x->security && copy_sec_ctx(x->security, skb) < 0) goto nla_put_failure; @@ -781,7 +789,8 @@ static inline size_t xfrm_spdinfo_msgsize(void) + nla_total_size(sizeof(struct xfrmu_spdhinfo)); } -static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags) +static int build_spdinfo(struct sk_buff *skb, struct net *net, + u32 pid, u32 seq, u32 flags) { struct xfrmk_spdinfo si; struct xfrmu_spdinfo spc; @@ -795,7 +804,7 @@ static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags) f = nlmsg_data(nlh); *f = flags; - xfrm_spd_getinfo(&si); + xfrm_spd_getinfo(net, &si); spc.incnt = si.incnt; spc.outcnt = si.outcnt; spc.fwdcnt = si.fwdcnt; @@ -828,7 +837,7 @@ static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh, if (r_skb == NULL) return -ENOMEM; - if (build_spdinfo(r_skb, spid, seq, *flags) < 0) + if (build_spdinfo(r_skb, net, spid, seq, *flags) < 0) BUG(); return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid); @@ -841,7 +850,8 @@ static inline size_t xfrm_sadinfo_msgsize(void) + nla_total_size(4); /* XFRMA_SAD_CNT */ } -static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags) +static int build_sadinfo(struct sk_buff *skb, struct net *net, + u32 pid, u32 seq, u32 flags) { struct xfrmk_sadinfo si; struct xfrmu_sadhinfo sh; @@ -854,7 +864,7 @@ static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags) f = nlmsg_data(nlh); *f = flags; - xfrm_sad_getinfo(&si); + xfrm_sad_getinfo(net, &si); sh.sadhmcnt = si.sadhmcnt; sh.sadhcnt = si.sadhcnt; @@ -882,7 +892,7 @@ static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh, if (r_skb == NULL) return -ENOMEM; - if (build_sadinfo(r_skb, spid, seq, *flags) < 0) + if (build_sadinfo(r_skb, net, spid, seq, *flags) < 0) BUG(); return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid); @@ -945,6 +955,8 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, xfrm_address_t *daddr; int family; int err; + u32 mark; + struct xfrm_mark m; p = nlmsg_data(nlh); err = verify_userspi_info(p); @@ -955,8 +967,10 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, daddr = &p->info.id.daddr; x = NULL; + + mark = xfrm_mark_get(attrs, &m); if (p->info.seq) { - x = xfrm_find_acq_byseq(net, p->info.seq); + x = xfrm_find_acq_byseq(net, mark, p->info.seq); if (x && xfrm_addr_cmp(&x->id.daddr, daddr, family)) { xfrm_state_put(x); x = NULL; @@ -964,7 +978,7 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh, } if (!x) - x = xfrm_find_acq(net, p->info.mode, p->info.reqid, + x = xfrm_find_acq(net, &m, p->info.mode, p->info.reqid, p->info.id.proto, daddr, &p->info.saddr, 1, family); @@ -1218,6 +1232,8 @@ static struct xfrm_policy *xfrm_policy_construct(struct net *net, struct xfrm_us if (err) goto error; + xfrm_mark_get(attrs, &xp->mark); + return xp; error: *errp = err; @@ -1364,10 +1380,13 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr goto nlmsg_failure; if (copy_to_user_policy_type(xp->type, skb) < 0) goto nlmsg_failure; + if (xfrm_mark_put(skb, &xp->mark)) + goto nla_put_failure; nlmsg_end(skb, nlh); return 0; +nla_put_failure: nlmsg_failure: nlmsg_cancel(skb, nlh); return -EMSGSIZE; @@ -1439,6 +1458,8 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh, int err; struct km_event c; int delete; + struct xfrm_mark m; + u32 mark = xfrm_mark_get(attrs, &m); p = nlmsg_data(nlh); delete = nlh->nlmsg_type == XFRM_MSG_DELPOLICY; @@ -1452,7 +1473,7 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh, return err; if (p->index) - xp = xfrm_policy_byid(net, type, p->dir, p->index, delete, &err); + xp = xfrm_policy_byid(net, mark, type, p->dir, p->index, delete, &err); else { struct nlattr *rt = attrs[XFRMA_SEC_CTX]; struct xfrm_sec_ctx *ctx; @@ -1469,8 +1490,8 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh, if (err) return err; } - xp = xfrm_policy_bysel_ctx(net, type, p->dir, &p->sel, ctx, - delete, &err); + xp = xfrm_policy_bysel_ctx(net, mark, type, p->dir, &p->sel, + ctx, delete, &err); security_xfrm_policy_free(ctx); } if (xp == NULL) @@ -1522,8 +1543,11 @@ static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh, audit_info.sessionid = NETLINK_CB(skb).sessionid; audit_info.secid = NETLINK_CB(skb).sid; err = xfrm_state_flush(net, p->proto, &audit_info); - if (err) + if (err) { + if (err == -ESRCH) /* empty table */ + return 0; return err; + } c.data.proto = p->proto; c.event = nlh->nlmsg_type; c.seq = nlh->nlmsg_seq; @@ -1539,6 +1563,7 @@ static inline size_t xfrm_aevent_msgsize(void) return NLMSG_ALIGN(sizeof(struct xfrm_aevent_id)) + nla_total_size(sizeof(struct xfrm_replay_state)) + nla_total_size(sizeof(struct xfrm_lifetime_cur)) + + nla_total_size(sizeof(struct xfrm_mark)) + nla_total_size(4) /* XFRM_AE_RTHR */ + nla_total_size(4); /* XFRM_AE_ETHR */ } @@ -1571,6 +1596,9 @@ static int build_aevent(struct sk_buff *skb, struct xfrm_state *x, struct km_eve NLA_PUT_U32(skb, XFRMA_ETIMER_THRESH, x->replay_maxage * 10 / HZ); + if (xfrm_mark_put(skb, &x->mark)) + goto nla_put_failure; + return nlmsg_end(skb, nlh); nla_put_failure: @@ -1586,6 +1614,8 @@ static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh, struct sk_buff *r_skb; int err; struct km_event c; + u32 mark; + struct xfrm_mark m; struct xfrm_aevent_id *p = nlmsg_data(nlh); struct xfrm_usersa_id *id = &p->sa_id; @@ -1593,7 +1623,9 @@ static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh, if (r_skb == NULL) return -ENOMEM; - x = xfrm_state_lookup(net, &id->daddr, id->spi, id->proto, id->family); + mark = xfrm_mark_get(attrs, &m); + + x = xfrm_state_lookup(net, mark, &id->daddr, id->spi, id->proto, id->family); if (x == NULL) { kfree_skb(r_skb); return -ESRCH; @@ -1624,6 +1656,8 @@ static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh, struct xfrm_state *x; struct km_event c; int err = - EINVAL; + u32 mark = 0; + struct xfrm_mark m; struct xfrm_aevent_id *p = nlmsg_data(nlh); struct nlattr *rp = attrs[XFRMA_REPLAY_VAL]; struct nlattr *lt = attrs[XFRMA_LTIME_VAL]; @@ -1635,7 +1669,9 @@ static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh, if (!(nlh->nlmsg_flags&NLM_F_REPLACE)) return err; - x = xfrm_state_lookup(net, &p->sa_id.daddr, p->sa_id.spi, p->sa_id.proto, p->sa_id.family); + mark = xfrm_mark_get(attrs, &m); + + x = xfrm_state_lookup(net, mark, &p->sa_id.daddr, p->sa_id.spi, p->sa_id.proto, p->sa_id.family); if (x == NULL) return -ESRCH; @@ -1674,8 +1710,12 @@ static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh, audit_info.sessionid = NETLINK_CB(skb).sessionid; audit_info.secid = NETLINK_CB(skb).sid; err = xfrm_policy_flush(net, type, &audit_info); - if (err) + if (err) { + if (err == -ESRCH) /* empty table */ + return 0; return err; + } + c.data.type = type; c.event = nlh->nlmsg_type; c.seq = nlh->nlmsg_seq; @@ -1694,13 +1734,15 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh, struct xfrm_userpolicy_info *p = &up->pol; u8 type = XFRM_POLICY_TYPE_MAIN; int err = -ENOENT; + struct xfrm_mark m; + u32 mark = xfrm_mark_get(attrs, &m); err = copy_from_user_policy_type(&type, attrs); if (err) return err; if (p->index) - xp = xfrm_policy_byid(net, type, p->dir, p->index, 0, &err); + xp = xfrm_policy_byid(net, mark, type, p->dir, p->index, 0, &err); else { struct nlattr *rt = attrs[XFRMA_SEC_CTX]; struct xfrm_sec_ctx *ctx; @@ -1717,7 +1759,8 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh, if (err) return err; } - xp = xfrm_policy_bysel_ctx(net, type, p->dir, &p->sel, ctx, 0, &err); + xp = xfrm_policy_bysel_ctx(net, mark, type, p->dir, + &p->sel, ctx, 0, &err); security_xfrm_policy_free(ctx); } if (xp == NULL) @@ -1757,8 +1800,10 @@ static int xfrm_add_sa_expire(struct sk_buff *skb, struct nlmsghdr *nlh, int err; struct xfrm_user_expire *ue = nlmsg_data(nlh); struct xfrm_usersa_info *p = &ue->state; + struct xfrm_mark m; + u32 mark = xfrm_mark_get(attrs, &m);; - x = xfrm_state_lookup(net, &p->id.daddr, p->id.spi, p->id.proto, p->family); + x = xfrm_state_lookup(net, mark, &p->id.daddr, p->id.spi, p->id.proto, p->family); err = -ENOENT; if (x == NULL) @@ -1792,6 +1837,7 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh, struct xfrm_user_tmpl *ut; int i; struct nlattr *rt = attrs[XFRMA_TMPL]; + struct xfrm_mark mark; struct xfrm_user_acquire *ua = nlmsg_data(nlh); struct xfrm_state *x = xfrm_state_alloc(net); @@ -1800,6 +1846,8 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh, if (!x) goto nomem; + xfrm_mark_get(attrs, &mark); + err = verify_newpolicy_info(&ua->policy); if (err) goto bad_policy; @@ -1812,7 +1860,8 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh, memcpy(&x->id, &ua->id, sizeof(ua->id)); memcpy(&x->props.saddr, &ua->saddr, sizeof(ua->saddr)); memcpy(&x->sel, &ua->sel, sizeof(ua->sel)); - + xp->mark.m = x->mark.m = mark.m; + xp->mark.v = x->mark.v = mark.v; ut = nla_data(rt); /* extract the templates and for each call km_key */ for (i = 0; i < xp->xfrm_nr; i++, ut++) { @@ -2052,6 +2101,10 @@ static const int xfrm_msg_min[XFRM_NR_MSGTYPES] = { #undef XMSGSIZE static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { + [XFRMA_SA] = { .len = sizeof(struct xfrm_usersa_info)}, + [XFRMA_POLICY] = { .len = sizeof(struct xfrm_userpolicy_info)}, + [XFRMA_LASTUSED] = { .type = NLA_U64}, + [XFRMA_ALG_AUTH_TRUNC] = { .len = sizeof(struct xfrm_algo_auth)}, [XFRMA_ALG_AEAD] = { .len = sizeof(struct xfrm_algo_aead) }, [XFRMA_ALG_AUTH] = { .len = sizeof(struct xfrm_algo) }, [XFRMA_ALG_CRYPT] = { .len = sizeof(struct xfrm_algo) }, @@ -2068,6 +2121,7 @@ static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = { [XFRMA_POLICY_TYPE] = { .len = sizeof(struct xfrm_userpolicy_type)}, [XFRMA_MIGRATE] = { .len = sizeof(struct xfrm_user_migrate) }, [XFRMA_KMADDRESS] = { .len = sizeof(struct xfrm_user_kmaddress) }, + [XFRMA_MARK] = { .len = sizeof(struct xfrm_mark) }, }; static struct xfrm_link { @@ -2147,7 +2201,8 @@ static void xfrm_netlink_rcv(struct sk_buff *skb) static inline size_t xfrm_expire_msgsize(void) { - return NLMSG_ALIGN(sizeof(struct xfrm_user_expire)); + return NLMSG_ALIGN(sizeof(struct xfrm_user_expire)) + + nla_total_size(sizeof(struct xfrm_mark)); } static int build_expire(struct sk_buff *skb, struct xfrm_state *x, struct km_event *c) @@ -2163,7 +2218,13 @@ static int build_expire(struct sk_buff *skb, struct xfrm_state *x, struct km_eve copy_to_user_state(x, &ue->state); ue->hard = (c->data.hard != 0) ? 1 : 0; + if (xfrm_mark_put(skb, &x->mark)) + goto nla_put_failure; + return nlmsg_end(skb, nlh); + +nla_put_failure: + return -EMSGSIZE; } static int xfrm_exp_state_notify(struct xfrm_state *x, struct km_event *c) @@ -2175,8 +2236,10 @@ static int xfrm_exp_state_notify(struct xfrm_state *x, struct km_event *c) if (skb == NULL) return -ENOMEM; - if (build_expire(skb, x, c) < 0) - BUG(); + if (build_expire(skb, x, c) < 0) { + kfree_skb(skb); + return -EMSGSIZE; + } return nlmsg_multicast(net->xfrm.nlsk, skb, 0, XFRMNLGRP_EXPIRE, GFP_ATOMIC); } @@ -2264,6 +2327,7 @@ static int xfrm_notify_sa(struct xfrm_state *x, struct km_event *c) if (c->event == XFRM_MSG_DELSA) { len += nla_total_size(headlen); headlen = sizeof(*id); + len += nla_total_size(sizeof(struct xfrm_mark)); } len += NLMSG_ALIGN(headlen); @@ -2334,6 +2398,7 @@ static inline size_t xfrm_acquire_msgsize(struct xfrm_state *x, { return NLMSG_ALIGN(sizeof(struct xfrm_user_acquire)) + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr) + + nla_total_size(sizeof(struct xfrm_mark)) + nla_total_size(xfrm_user_sec_ctx_size(x->security)) + userpolicy_type_attrsize(); } @@ -2366,9 +2431,12 @@ static int build_acquire(struct sk_buff *skb, struct xfrm_state *x, goto nlmsg_failure; if (copy_to_user_policy_type(xp->type, skb) < 0) goto nlmsg_failure; + if (xfrm_mark_put(skb, &xp->mark)) + goto nla_put_failure; return nlmsg_end(skb, nlh); +nla_put_failure: nlmsg_failure: nlmsg_cancel(skb, nlh); return -EMSGSIZE; @@ -2455,6 +2523,7 @@ static inline size_t xfrm_polexpire_msgsize(struct xfrm_policy *xp) return NLMSG_ALIGN(sizeof(struct xfrm_user_polexpire)) + nla_total_size(sizeof(struct xfrm_user_tmpl) * xp->xfrm_nr) + nla_total_size(xfrm_user_sec_ctx_size(xp->security)) + + nla_total_size(sizeof(struct xfrm_mark)) + userpolicy_type_attrsize(); } @@ -2477,10 +2546,13 @@ static int build_polexpire(struct sk_buff *skb, struct xfrm_policy *xp, goto nlmsg_failure; if (copy_to_user_policy_type(xp->type, skb) < 0) goto nlmsg_failure; + if (xfrm_mark_put(skb, &xp->mark)) + goto nla_put_failure; upe->hard = !!hard; return nlmsg_end(skb, nlh); +nla_put_failure: nlmsg_failure: nlmsg_cancel(skb, nlh); return -EMSGSIZE; @@ -2517,6 +2589,7 @@ static int xfrm_notify_policy(struct xfrm_policy *xp, int dir, struct km_event * headlen = sizeof(*id); } len += userpolicy_type_attrsize(); + len += nla_total_size(sizeof(struct xfrm_mark)); len += NLMSG_ALIGN(headlen); skb = nlmsg_new(len, GFP_ATOMIC); @@ -2552,10 +2625,14 @@ static int xfrm_notify_policy(struct xfrm_policy *xp, int dir, struct km_event * if (copy_to_user_policy_type(xp->type, skb) < 0) goto nlmsg_failure; + if (xfrm_mark_put(skb, &xp->mark)) + goto nla_put_failure; + nlmsg_end(skb, nlh); return nlmsg_multicast(net->xfrm.nlsk, skb, 0, XFRMNLGRP_POLICY, GFP_ATOMIC); +nla_put_failure: nlmsg_failure: kfree_skb(skb); return -1; |