diff options
-rw-r--r-- | net/l2tp/l2tp_core.c | 240 | ||||
-rw-r--r-- | net/l2tp/l2tp_core.h | 18 | ||||
-rw-r--r-- | net/l2tp/l2tp_ip.c | 2 | ||||
-rw-r--r-- | net/l2tp/l2tp_ip6.c | 2 |
4 files changed, 188 insertions, 74 deletions
diff --git a/net/l2tp/l2tp_core.c b/net/l2tp/l2tp_core.c index 69f8c9f5cdc7..d6bffdb16466 100644 --- a/net/l2tp/l2tp_core.c +++ b/net/l2tp/l2tp_core.c @@ -107,11 +107,17 @@ struct l2tp_net { /* Lock for write access to l2tp_tunnel_idr */ spinlock_t l2tp_tunnel_idr_lock; struct idr l2tp_tunnel_idr; - struct hlist_head l2tp_session_hlist[L2TP_HASH_SIZE_2]; - /* Lock for write access to l2tp_session_hlist */ - spinlock_t l2tp_session_hlist_lock; + /* Lock for write access to l2tp_v3_session_idr/htable */ + spinlock_t l2tp_session_idr_lock; + struct idr l2tp_v3_session_idr; + struct hlist_head l2tp_v3_session_htable[16]; }; +static inline unsigned long l2tp_v3_session_hashkey(struct sock *sk, u32 session_id) +{ + return ((unsigned long)sk) + session_id; +} + #if IS_ENABLED(CONFIG_IPV6) static bool l2tp_sk_is_v6(struct sock *sk) { @@ -125,17 +131,6 @@ static inline struct l2tp_net *l2tp_pernet(const struct net *net) return net_generic(net, l2tp_net_id); } -/* Session hash global list for L2TPv3. - * The session_id SHOULD be random according to RFC3931, but several - * L2TP implementations use incrementing session_ids. So we do a real - * hash on the session_id, rather than a simple bitmask. - */ -static inline struct hlist_head * -l2tp_session_id_hash_2(struct l2tp_net *pn, u32 session_id) -{ - return &pn->l2tp_session_hlist[hash_32(session_id, L2TP_HASH_BITS_2)]; -} - /* Session hash list. * The session_id SHOULD be random according to RFC2661, but several * L2TP implementations (Cisco and Microsoft) use incrementing @@ -262,26 +257,40 @@ struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, } EXPORT_SYMBOL_GPL(l2tp_tunnel_get_session); -struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id) +struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id) { - struct hlist_head *session_list; + const struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_session *session; - session_list = l2tp_session_id_hash_2(l2tp_pernet(net), session_id); - rcu_read_lock_bh(); - hlist_for_each_entry_rcu(session, session_list, global_hlist) - if (session->session_id == session_id) { - l2tp_session_inc_refcount(session); - rcu_read_unlock_bh(); + session = idr_find(&pn->l2tp_v3_session_idr, session_id); + if (session && !hash_hashed(&session->hlist) && + refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock_bh(); + return session; + } - return session; + /* If we get here and session is non-NULL, the session_id + * collides with one in another tunnel. If sk is non-NULL, + * find the session matching sk. + */ + if (session && sk) { + unsigned long key = l2tp_v3_session_hashkey(sk, session->session_id); + + hash_for_each_possible_rcu(pn->l2tp_v3_session_htable, session, + hlist, key) { + if (session->tunnel->sock == sk && + refcount_inc_not_zero(&session->ref_count)) { + rcu_read_unlock_bh(); + return session; + } } + } rcu_read_unlock_bh(); return NULL; } -EXPORT_SYMBOL_GPL(l2tp_session_get); +EXPORT_SYMBOL_GPL(l2tp_v3_session_get); struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth) { @@ -313,12 +322,12 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, const char *ifname) { struct l2tp_net *pn = l2tp_pernet(net); - int hash; + unsigned long session_id, tmp; struct l2tp_session *session; rcu_read_lock_bh(); - for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) { - hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) { + idr_for_each_entry_ul(&pn->l2tp_v3_session_idr, session, tmp, session_id) { + if (session) { if (!strcmp(session->ifname, ifname)) { l2tp_session_inc_refcount(session); rcu_read_unlock_bh(); @@ -334,13 +343,106 @@ struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, } EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname); +static void l2tp_session_coll_list_add(struct l2tp_session_coll_list *clist, + struct l2tp_session *session) +{ + l2tp_session_inc_refcount(session); + WARN_ON_ONCE(session->coll_list); + session->coll_list = clist; + spin_lock(&clist->lock); + list_add(&session->clist, &clist->list); + spin_unlock(&clist->lock); +} + +static int l2tp_session_collision_add(struct l2tp_net *pn, + struct l2tp_session *session1, + struct l2tp_session *session2) +{ + struct l2tp_session_coll_list *clist; + + lockdep_assert_held(&pn->l2tp_session_idr_lock); + + if (!session2) + return -EEXIST; + + /* If existing session is in IP-encap tunnel, refuse new session */ + if (session2->tunnel->encap == L2TP_ENCAPTYPE_IP) + return -EEXIST; + + clist = session2->coll_list; + if (!clist) { + /* First collision. Allocate list to manage the collided sessions + * and add the existing session to the list. + */ + clist = kmalloc(sizeof(*clist), GFP_ATOMIC); + if (!clist) + return -ENOMEM; + + spin_lock_init(&clist->lock); + INIT_LIST_HEAD(&clist->list); + refcount_set(&clist->ref_count, 1); + l2tp_session_coll_list_add(clist, session2); + } + + /* If existing session isn't already in the session hlist, add it. */ + if (!hash_hashed(&session2->hlist)) + hash_add(pn->l2tp_v3_session_htable, &session2->hlist, + session2->hlist_key); + + /* Add new session to the hlist and collision list */ + hash_add(pn->l2tp_v3_session_htable, &session1->hlist, + session1->hlist_key); + refcount_inc(&clist->ref_count); + l2tp_session_coll_list_add(clist, session1); + + return 0; +} + +static void l2tp_session_collision_del(struct l2tp_net *pn, + struct l2tp_session *session) +{ + struct l2tp_session_coll_list *clist = session->coll_list; + unsigned long session_key = session->session_id; + struct l2tp_session *session2; + + lockdep_assert_held(&pn->l2tp_session_idr_lock); + + hash_del(&session->hlist); + + if (clist) { + /* Remove session from its collision list. If there + * are other sessions with the same ID, replace this + * session's IDR entry with that session, otherwise + * remove the IDR entry. If this is the last session, + * the collision list data is freed. + */ + spin_lock(&clist->lock); + list_del_init(&session->clist); + session2 = list_first_entry_or_null(&clist->list, struct l2tp_session, clist); + if (session2) { + void *old = idr_replace(&pn->l2tp_v3_session_idr, session2, session_key); + + WARN_ON_ONCE(IS_ERR_VALUE(old)); + } else { + void *removed = idr_remove(&pn->l2tp_v3_session_idr, session_key); + + WARN_ON_ONCE(removed != session); + } + session->coll_list = NULL; + spin_unlock(&clist->lock); + if (refcount_dec_and_test(&clist->ref_count)) + kfree(clist); + l2tp_session_dec_refcount(session); + } +} + int l2tp_session_register(struct l2tp_session *session, struct l2tp_tunnel *tunnel) { + struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net); struct l2tp_session *session_walk; - struct hlist_head *g_head; struct hlist_head *head; - struct l2tp_net *pn; + u32 session_key; int err; head = l2tp_session_id_hash(tunnel, session->session_id); @@ -358,39 +460,45 @@ int l2tp_session_register(struct l2tp_session *session, } if (tunnel->version == L2TP_HDR_VER_3) { - pn = l2tp_pernet(tunnel->l2tp_net); - g_head = l2tp_session_id_hash_2(pn, session->session_id); - - spin_lock_bh(&pn->l2tp_session_hlist_lock); - + session_key = session->session_id; + spin_lock_bh(&pn->l2tp_session_idr_lock); + err = idr_alloc_u32(&pn->l2tp_v3_session_idr, NULL, + &session_key, session_key, GFP_ATOMIC); /* IP encap expects session IDs to be globally unique, while - * UDP encap doesn't. + * UDP encap doesn't. This isn't per the RFC, which says that + * sessions are identified only by the session ID, but is to + * support existing userspace which depends on it. */ - hlist_for_each_entry(session_walk, g_head, global_hlist) - if (session_walk->session_id == session->session_id && - (session_walk->tunnel->encap == L2TP_ENCAPTYPE_IP || - tunnel->encap == L2TP_ENCAPTYPE_IP)) { - err = -EEXIST; - goto err_tlock_pnlock; - } + if (err == -ENOSPC && tunnel->encap == L2TP_ENCAPTYPE_UDP) { + struct l2tp_session *session2; - l2tp_tunnel_inc_refcount(tunnel); - hlist_add_head_rcu(&session->global_hlist, g_head); - - spin_unlock_bh(&pn->l2tp_session_hlist_lock); - } else { - l2tp_tunnel_inc_refcount(tunnel); + session2 = idr_find(&pn->l2tp_v3_session_idr, + session_key); + err = l2tp_session_collision_add(pn, session, session2); + } + spin_unlock_bh(&pn->l2tp_session_idr_lock); + if (err == -ENOSPC) + err = -EEXIST; } + if (err) + goto err_tlock; + + l2tp_tunnel_inc_refcount(tunnel); + hlist_add_head_rcu(&session->hlist, head); spin_unlock_bh(&tunnel->hlist_lock); + if (tunnel->version == L2TP_HDR_VER_3) { + spin_lock_bh(&pn->l2tp_session_idr_lock); + idr_replace(&pn->l2tp_v3_session_idr, session, session_key); + spin_unlock_bh(&pn->l2tp_session_idr_lock); + } + trace_register_session(session); return 0; -err_tlock_pnlock: - spin_unlock_bh(&pn->l2tp_session_hlist_lock); err_tlock: spin_unlock_bh(&tunnel->hlist_lock); @@ -1218,13 +1326,19 @@ static void l2tp_session_unhash(struct l2tp_session *session) hlist_del_init_rcu(&session->hlist); spin_unlock_bh(&tunnel->hlist_lock); - /* For L2TPv3 we have a per-net hash: remove from there, too */ - if (tunnel->version != L2TP_HDR_VER_2) { + /* For L2TPv3 we have a per-net IDR: remove from there, too */ + if (tunnel->version == L2TP_HDR_VER_3) { struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net); - - spin_lock_bh(&pn->l2tp_session_hlist_lock); - hlist_del_init_rcu(&session->global_hlist); - spin_unlock_bh(&pn->l2tp_session_hlist_lock); + struct l2tp_session *removed = session; + + spin_lock_bh(&pn->l2tp_session_idr_lock); + if (hash_hashed(&session->hlist)) + l2tp_session_collision_del(pn, session); + else + removed = idr_remove(&pn->l2tp_v3_session_idr, + session->session_id); + WARN_ON_ONCE(removed && removed != session); + spin_unlock_bh(&pn->l2tp_session_idr_lock); } synchronize_rcu(); @@ -1649,8 +1763,9 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn skb_queue_head_init(&session->reorder_q); + session->hlist_key = l2tp_v3_session_hashkey(tunnel->sock, session->session_id); INIT_HLIST_NODE(&session->hlist); - INIT_HLIST_NODE(&session->global_hlist); + INIT_LIST_HEAD(&session->clist); if (cfg) { session->pwtype = cfg->pw_type; @@ -1683,15 +1798,12 @@ EXPORT_SYMBOL_GPL(l2tp_session_create); static __net_init int l2tp_init_net(struct net *net) { struct l2tp_net *pn = net_generic(net, l2tp_net_id); - int hash; idr_init(&pn->l2tp_tunnel_idr); spin_lock_init(&pn->l2tp_tunnel_idr_lock); - for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) - INIT_HLIST_HEAD(&pn->l2tp_session_hlist[hash]); - - spin_lock_init(&pn->l2tp_session_hlist_lock); + idr_init(&pn->l2tp_v3_session_idr); + spin_lock_init(&pn->l2tp_session_idr_lock); return 0; } @@ -1701,7 +1813,6 @@ static __net_exit void l2tp_exit_net(struct net *net) struct l2tp_net *pn = l2tp_pernet(net); struct l2tp_tunnel *tunnel = NULL; unsigned long tunnel_id, tmp; - int hash; rcu_read_lock_bh(); idr_for_each_entry_ul(&pn->l2tp_tunnel_idr, tunnel, tmp, tunnel_id) { @@ -1714,8 +1825,7 @@ static __net_exit void l2tp_exit_net(struct net *net) flush_workqueue(l2tp_wq); rcu_barrier(); - for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) - WARN_ON_ONCE(!hlist_empty(&pn->l2tp_session_hlist[hash])); + idr_destroy(&pn->l2tp_v3_session_idr); idr_destroy(&pn->l2tp_tunnel_idr); } diff --git a/net/l2tp/l2tp_core.h b/net/l2tp/l2tp_core.h index 54dfba1eb91c..bfccc4ca2644 100644 --- a/net/l2tp/l2tp_core.h +++ b/net/l2tp/l2tp_core.h @@ -23,10 +23,6 @@ #define L2TP_HASH_BITS 4 #define L2TP_HASH_SIZE BIT(L2TP_HASH_BITS) -/* System-wide session hash table size */ -#define L2TP_HASH_BITS_2 8 -#define L2TP_HASH_SIZE_2 BIT(L2TP_HASH_BITS_2) - struct sk_buff; struct l2tp_stats { @@ -61,6 +57,12 @@ struct l2tp_session_cfg { char *ifname; }; +struct l2tp_session_coll_list { + spinlock_t lock; /* for access to list */ + struct list_head list; + refcount_t ref_count; +}; + /* Represents a session (pseudowire) instance. * Tracks runtime state including cookies, dataplane packet sequencing, and IO statistics. * Is linked into a per-tunnel session hashlist; and in the case of an L2TPv3 session into @@ -88,8 +90,11 @@ struct l2tp_session { u32 nr_oos; /* NR of last OOS packet */ int nr_oos_count; /* for OOS recovery */ int nr_oos_count_max; - struct hlist_node hlist; /* hash list node */ refcount_t ref_count; + struct hlist_node hlist; /* per-net session hlist */ + unsigned long hlist_key; /* key for session hlist */ + struct l2tp_session_coll_list *coll_list; /* session collision list */ + struct list_head clist; /* for coll_list */ char name[L2TP_SESSION_NAME_MAX]; /* for logging */ char ifname[IFNAMSIZ]; @@ -102,7 +107,6 @@ struct l2tp_session { int reorder_skip; /* set if skip to next nr */ enum l2tp_pwtype pwtype; struct l2tp_stats stats; - struct hlist_node global_hlist; /* global hash list node */ /* Session receive handler for data packets. * Each pseudowire implementation should implement this callback in order to @@ -226,7 +230,7 @@ struct l2tp_tunnel *l2tp_tunnel_get_nth(const struct net *net, int nth); struct l2tp_session *l2tp_tunnel_get_session(struct l2tp_tunnel *tunnel, u32 session_id); -struct l2tp_session *l2tp_session_get(const struct net *net, u32 session_id); +struct l2tp_session *l2tp_v3_session_get(const struct net *net, struct sock *sk, u32 session_id); struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth); struct l2tp_session *l2tp_session_get_by_ifname(const struct net *net, const char *ifname); diff --git a/net/l2tp/l2tp_ip.c b/net/l2tp/l2tp_ip.c index 19c8cc5289d5..e48aa177d74c 100644 --- a/net/l2tp/l2tp_ip.c +++ b/net/l2tp/l2tp_ip.c @@ -140,7 +140,7 @@ static int l2tp_ip_recv(struct sk_buff *skb) } /* Ok, this is a data packet. Lookup the session. */ - session = l2tp_session_get(net, session_id); + session = l2tp_v3_session_get(net, NULL, session_id); if (!session) goto discard; diff --git a/net/l2tp/l2tp_ip6.c b/net/l2tp/l2tp_ip6.c index 8780ec64f376..d217ff1f229e 100644 --- a/net/l2tp/l2tp_ip6.c +++ b/net/l2tp/l2tp_ip6.c @@ -150,7 +150,7 @@ static int l2tp_ip6_recv(struct sk_buff *skb) } /* Ok, this is a data packet. Lookup the session. */ - session = l2tp_session_get(net, session_id); + session = l2tp_v3_session_get(net, NULL, session_id); if (!session) goto discard; |