summaryrefslogblamecommitdiff
path: root/fs/nfs_common/nfslocalio.c
blob: f0bff023bb5e75962c6654bf9a7bc3c80c850464 (plain) (tree)



















































































































                                                                  
// SPDX-License-Identifier: GPL-2.0-only
/*
 * Copyright (C) 2024 Mike Snitzer <snitzer@hammerspace.com>
 * Copyright (C) 2024 NeilBrown <neilb@suse.de>
 */

#include <linux/module.h>
#include <linux/rculist.h>
#include <linux/nfslocalio.h>
#include <net/netns/generic.h>

/* Module metadata for the nfs_common LOCALIO helper module. */
MODULE_LICENSE("GPL");
MODULE_DESCRIPTION("NFS localio protocol bypass support");

/*
 * Serializes changes to the nfs_uuids list and to the lists handed to
 * nfs_uuid_is_local() / nfs_uuid_invalidate_clients().  The ->net and
 * ->dom updates in nfs_uuid_is_local() and nfs_uuid_put_locked() are
 * also made while it is held.
 */
static DEFINE_SPINLOCK(nfs_uuid_lock);

/*
 * Global list of nfs_uuid_t instances added by nfs_uuid_begin() and
 * not yet matched by nfs_uuid_is_local() (which moves a matched entry
 * to a caller-supplied list) or removed by nfs_uuid_end().
 * Protected by nfs_uuid_lock.
 */
static LIST_HEAD(nfs_uuids);

/**
 * nfs_uuid_begin - publish a fresh UUID for a LOCALIO handshake
 * @nfs_uuid: state to initialise and publish
 *
 * Resets @nfs_uuid to the "not local" state (no net, no auth_domain),
 * generates a new random UUID for it and appends it to the global
 * nfs_uuids list, where nfs_uuid_is_local() can look it up.
 */
void nfs_uuid_begin(nfs_uuid_t *nfs_uuid)
{
	/* Start from a clean, unmatched state with a new identity. */
	nfs_uuid->net = NULL;
	nfs_uuid->dom = NULL;
	uuid_gen(&nfs_uuid->uuid);

	/* Only now make it discoverable by UUID. */
	spin_lock(&nfs_uuid_lock);
	list_add_tail_rcu(&nfs_uuid->list, &nfs_uuids);
	spin_unlock(&nfs_uuid_lock);
}
EXPORT_SYMBOL_GPL(nfs_uuid_begin);

/**
 * nfs_uuid_end - finish a handshake started by nfs_uuid_begin()
 * @nfs_uuid: the state previously passed to nfs_uuid_begin()
 *
 * If nfs_uuid_is_local() never matched @nfs_uuid (->net is still NULL),
 * unlink it from the global nfs_uuids list.  If it was matched, it now
 * sits on the list given to nfs_uuid_is_local() and is left there so
 * that nfs_uuid_invalidate_clients() can still release it.
 */
void nfs_uuid_end(nfs_uuid_t *nfs_uuid)
{
	/* Unlocked fast path: a matched entry needs no work here. */
	if (nfs_uuid->net == NULL) {
		spin_lock(&nfs_uuid_lock);
		/*
		 * Re-check under the lock: nfs_uuid_is_local() may have
		 * matched @nfs_uuid (moved it to another list and set ->net)
		 * after the unlocked test above.  Unlinking it in that case
		 * would hide it from nfs_uuid_invalidate_clients() while it
		 * still holds ->net, a ->dom reference and an nfsd module
		 * reference.
		 */
		if (nfs_uuid->net == NULL)
			list_del_init(&nfs_uuid->list);
		spin_unlock(&nfs_uuid_lock);
	}
}
EXPORT_SYMBOL_GPL(nfs_uuid_end);

static nfs_uuid_t * nfs_uuid_lookup_locked(const uuid_t *uuid)
{
	nfs_uuid_t *nfs_uuid;

	list_for_each_entry(nfs_uuid, &nfs_uuids, list)
		if (uuid_equal(&nfs_uuid->uuid, uuid))
			return nfs_uuid;

	return NULL;
}

/*
 * Module passed to nfs_uuid_is_local(): a reference is taken on it for
 * every matched nfs_uuid and dropped in nfs_uuid_put_locked().  Each
 * successful nfs_uuid_is_local() call overwrites it, so only the most
 * recent caller's module is remembered.
 */
static struct module *nfsd_mod;

void nfs_uuid_is_local(const uuid_t *uuid, struct list_head *list,
		       struct net *net, struct auth_domain *dom,
		       struct module *mod)
{
	nfs_uuid_t *nfs_uuid;

	spin_lock(&nfs_uuid_lock);
	nfs_uuid = nfs_uuid_lookup_locked(uuid);
	if (nfs_uuid) {
		kref_get(&dom->ref);
		nfs_uuid->dom = dom;
		/*
		 * We don't hold a ref on the net, but instead put
		 * ourselves on a list so the net pointer can be
		 * invalidated.
		 */
		list_move(&nfs_uuid->list, list);
		nfs_uuid->net = net;

		__module_get(mod);
		nfsd_mod = mod;
	}
	spin_unlock(&nfs_uuid_lock);
}
EXPORT_SYMBOL_GPL(nfs_uuid_is_local);

/*
 * nfs_uuid_put_locked - undo what nfs_uuid_is_local() set up
 * @nfs_uuid: entry to release
 *
 * Called with nfs_uuid_lock held.  Drops the nfsd_mod reference if ->net
 * is set, and the auth_domain reference if ->dom is set, clearing each
 * pointer.  Always unlinks @nfs_uuid from whatever list it is on,
 * re-initialising its list node.
 */
static void nfs_uuid_put_locked(nfs_uuid_t *nfs_uuid)
{
	/* ->net non-NULL means nfs_uuid_is_local() pinned nfsd_mod. */
	if (nfs_uuid->net != NULL) {
		module_put(nfsd_mod);
		nfs_uuid->net = NULL;
	}

	/* Balance the kref_get() taken on the auth_domain. */
	if (nfs_uuid->dom != NULL) {
		auth_domain_put(nfs_uuid->dom);
		nfs_uuid->dom = NULL;
	}

	list_del_init(&nfs_uuid->list);
}

void nfs_uuid_invalidate_clients(struct list_head *list)
{
	nfs_uuid_t *nfs_uuid, *tmp;

	spin_lock(&nfs_uuid_lock);
	list_for_each_entry_safe(nfs_uuid, tmp, list, list)
		nfs_uuid_put_locked(nfs_uuid);
	spin_unlock(&nfs_uuid_lock);
}
EXPORT_SYMBOL_GPL(nfs_uuid_invalidate_clients);

/**
 * nfs_uuid_invalidate_one_client - release a single matched entry
 * @nfs_uuid: entry to release
 *
 * Does nothing when ->net is NULL (tested without the lock).  Otherwise
 * releases @nfs_uuid with nfs_uuid_put_locked() under nfs_uuid_lock.
 */
void nfs_uuid_invalidate_one_client(nfs_uuid_t *nfs_uuid)
{
	/* Never matched by nfs_uuid_is_local(), or already released. */
	if (nfs_uuid->net == NULL)
		return;

	spin_lock(&nfs_uuid_lock);
	nfs_uuid_put_locked(nfs_uuid);
	spin_unlock(&nfs_uuid_lock);
}
EXPORT_SYMBOL_GPL(nfs_uuid_invalidate_one_client);