summaryrefslogtreecommitdiff
path: root/net/9p/trans_fd.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/9p/trans_fd.c')
-rw-r--r--net/9p/trans_fd.c73
1 files changed, 34 insertions, 39 deletions
diff --git a/net/9p/trans_fd.c b/net/9p/trans_fd.c
index 196060dc6138..339ec4e54778 100644
--- a/net/9p/trans_fd.c
+++ b/net/9p/trans_fd.c
@@ -11,6 +11,7 @@
#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
#include <linux/in.h>
+#include <linux/in6.h>
#include <linux/module.h>
#include <linux/net.h>
#include <linux/ipv6.h>
@@ -191,12 +192,13 @@ static void p9_conn_cancel(struct p9_conn *m, int err)
spin_lock(&m->req_lock);
- if (m->err) {
+ if (READ_ONCE(m->err)) {
spin_unlock(&m->req_lock);
return;
}
- m->err = err;
+ WRITE_ONCE(m->err, err);
+ ASSERT_EXCLUSIVE_WRITER(m->err);
list_for_each_entry_safe(req, rtmp, &m->req_list, req_list) {
list_move(&req->req_list, &cancel_list);
@@ -283,7 +285,7 @@ static void p9_read_work(struct work_struct *work)
m = container_of(work, struct p9_conn, rq);
- if (m->err < 0)
+ if (READ_ONCE(m->err) < 0)
return;
p9_debug(P9_DEBUG_TRANS, "start mux %p pos %zd\n", m, m->rc.offset);
@@ -450,7 +452,7 @@ static void p9_write_work(struct work_struct *work)
m = container_of(work, struct p9_conn, wq);
- if (m->err < 0) {
+ if (READ_ONCE(m->err) < 0) {
clear_bit(Wworksched, &m->wsched);
return;
}
@@ -622,7 +624,7 @@ static void p9_poll_mux(struct p9_conn *m)
__poll_t n;
int err = -ECONNRESET;
- if (m->err < 0)
+ if (READ_ONCE(m->err) < 0)
return;
n = p9_fd_poll(m->client, NULL, &err);
@@ -665,6 +667,7 @@ static void p9_poll_mux(struct p9_conn *m)
static int p9_fd_request(struct p9_client *client, struct p9_req_t *req)
{
__poll_t n;
+ int err;
struct p9_trans_fd *ts = client->trans;
struct p9_conn *m = &ts->conn;
@@ -673,9 +676,10 @@ static int p9_fd_request(struct p9_client *client, struct p9_req_t *req)
spin_lock(&m->req_lock);
- if (m->err < 0) {
+ err = READ_ONCE(m->err);
+ if (err < 0) {
spin_unlock(&m->req_lock);
- return m->err;
+ return err;
}
WRITE_ONCE(req->status, REQ_STATUS_UNSENT);
@@ -954,64 +958,55 @@ static void p9_fd_close(struct p9_client *client)
kfree(ts);
}
-/*
- * stolen from NFS - maybe should be made a generic function?
- */
-static inline int valid_ipaddr4(const char *buf)
-{
- int rc, count, in[4];
-
- rc = sscanf(buf, "%d.%d.%d.%d", &in[0], &in[1], &in[2], &in[3]);
- if (rc != 4)
- return -EINVAL;
- for (count = 0; count < 4; count++) {
- if (in[count] > 255)
- return -EINVAL;
- }
- return 0;
-}
-
static int p9_bind_privport(struct socket *sock)
{
- struct sockaddr_in cl;
+ struct sockaddr_storage stor = { 0 };
int port, err = -EINVAL;
- memset(&cl, 0, sizeof(cl));
- cl.sin_family = AF_INET;
- cl.sin_addr.s_addr = htonl(INADDR_ANY);
+ stor.ss_family = sock->ops->family;
+ if (stor.ss_family == AF_INET)
+ ((struct sockaddr_in *)&stor)->sin_addr.s_addr = htonl(INADDR_ANY);
+ else
+ ((struct sockaddr_in6 *)&stor)->sin6_addr = in6addr_any;
for (port = p9_ipport_resv_max; port >= p9_ipport_resv_min; port--) {
- cl.sin_port = htons((ushort)port);
- err = kernel_bind(sock, (struct sockaddr *)&cl, sizeof(cl));
+ if (stor.ss_family == AF_INET)
+ ((struct sockaddr_in *)&stor)->sin_port = htons((ushort)port);
+ else
+ ((struct sockaddr_in6 *)&stor)->sin6_port = htons((ushort)port);
+ err = kernel_bind(sock, (struct sockaddr *)&stor, sizeof(stor));
if (err != -EADDRINUSE)
break;
}
return err;
}
-
static int
p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
{
int err;
+ char port_str[6];
struct socket *csocket;
- struct sockaddr_in sin_server;
+ struct sockaddr_storage stor = { 0 };
struct p9_fd_opts opts;
err = parse_opts(args, &opts);
if (err < 0)
return err;
- if (addr == NULL || valid_ipaddr4(addr) < 0)
+ if (!addr)
return -EINVAL;
+ sprintf(port_str, "%u", opts.port);
+ err = inet_pton_with_scope(current->nsproxy->net_ns, AF_UNSPEC, addr,
+ port_str, &stor);
+ if (err < 0)
+ return err;
+
csocket = NULL;
client->trans_opts.tcp.port = opts.port;
client->trans_opts.tcp.privport = opts.privport;
- sin_server.sin_family = AF_INET;
- sin_server.sin_addr.s_addr = in_aton(addr);
- sin_server.sin_port = htons(opts.port);
- err = __sock_create(current->nsproxy->net_ns, PF_INET,
+ err = __sock_create(current->nsproxy->net_ns, stor.ss_family,
SOCK_STREAM, IPPROTO_TCP, &csocket, 1);
if (err) {
pr_err("%s (%d): problem creating socket\n",
@@ -1030,8 +1025,8 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
}
err = READ_ONCE(csocket->ops)->connect(csocket,
- (struct sockaddr *)&sin_server,
- sizeof(struct sockaddr_in), 0);
+ (struct sockaddr *)&stor,
+ sizeof(stor), 0);
if (err < 0) {
pr_err("%s (%d): problem connecting socket to %s\n",
__func__, task_pid_nr(current), addr);