diff options
author | Timothy Pearson <tpearson@raptorengineering.com> | 2017-08-23 14:45:25 -0500 |
---|---|---|
committer | Timothy Pearson <tpearson@raptorengineering.com> | 2017-08-23 14:45:25 -0500 |
commit | fcbb27b0ec6dcbc5a5108cb8fb19eae64593d204 (patch) | |
tree | 22962a4387943edc841c72a4e636a068c66d58fd /net/sunrpc | |
download | ast2050-linux-kernel-fcbb27b0ec6dcbc5a5108cb8fb19eae64593d204.zip ast2050-linux-kernel-fcbb27b0ec6dcbc5a5108cb8fb19eae64593d204.tar.gz |
Initial import of modified Linux 2.6.28 tree
Original upstream URL:
git://git.kernel.org/pub/scm/linux/kernel/git/stable/linux-stable.git | branch linux-2.6.28.y
Diffstat (limited to 'net/sunrpc')
48 files changed, 29630 insertions, 0 deletions
diff --git a/net/sunrpc/Makefile b/net/sunrpc/Makefile new file mode 100644 index 0000000..5369aa3 --- /dev/null +++ b/net/sunrpc/Makefile @@ -0,0 +1,17 @@ +# +# Makefile for Linux kernel SUN RPC +# + + +obj-$(CONFIG_SUNRPC) += sunrpc.o +obj-$(CONFIG_SUNRPC_GSS) += auth_gss/ +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma/ + +sunrpc-y := clnt.o xprt.o socklib.o xprtsock.o sched.o \ + auth.o auth_null.o auth_unix.o auth_generic.o \ + svc.o svcsock.o svcauth.o svcauth_unix.o \ + rpcb_clnt.o timer.o xdr.o \ + sunrpc_syms.o cache.o rpc_pipe.o \ + svc_xprt.o +sunrpc-$(CONFIG_PROC_FS) += stats.o +sunrpc-$(CONFIG_SYSCTL) += sysctl.o diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c new file mode 100644 index 0000000..cb216b2 --- /dev/null +++ b/net/sunrpc/auth.c @@ -0,0 +1,584 @@ +/* + * linux/net/sunrpc/auth.c + * + * Generic RPC client authentication API. + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/errno.h> +#include <linux/hash.h> +#include <linux/sunrpc/clnt.h> +#include <linux/spinlock.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static DEFINE_SPINLOCK(rpc_authflavor_lock); +static const struct rpc_authops *auth_flavors[RPC_AUTH_MAXFLAVOR] = { + &authnull_ops, /* AUTH_NULL */ + &authunix_ops, /* AUTH_UNIX */ + NULL, /* others can be loadable modules */ +}; + +static LIST_HEAD(cred_unused); +static unsigned long number_cred_unused; + +static u32 +pseudoflavor_to_flavor(u32 flavor) { + if (flavor >= RPC_AUTH_MAXFLAVOR) + return RPC_AUTH_GSS; + return flavor; +} + +int +rpcauth_register(const struct rpc_authops *ops) +{ + rpc_authflavor_t flavor; + int ret = -EPERM; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + spin_lock(&rpc_authflavor_lock); + if (auth_flavors[flavor] == NULL) { + auth_flavors[flavor] = ops; + ret = 0; + } + spin_unlock(&rpc_authflavor_lock); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_register); + +int +rpcauth_unregister(const struct rpc_authops *ops) +{ + rpc_authflavor_t flavor; + int ret = -EPERM; + + if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR) + return -EINVAL; + spin_lock(&rpc_authflavor_lock); + if (auth_flavors[flavor] == ops) { + auth_flavors[flavor] = NULL; + ret = 0; + } + spin_unlock(&rpc_authflavor_lock); + return ret; +} +EXPORT_SYMBOL_GPL(rpcauth_unregister); + +struct rpc_auth * +rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) +{ + struct rpc_auth *auth; + const struct rpc_authops *ops; + u32 flavor = pseudoflavor_to_flavor(pseudoflavor); + + auth = ERR_PTR(-EINVAL); + if (flavor >= RPC_AUTH_MAXFLAVOR) + goto out; + + if ((ops = auth_flavors[flavor]) == NULL) + request_module("rpc-auth-%u", flavor); + spin_lock(&rpc_authflavor_lock); + ops = auth_flavors[flavor]; + if (ops == NULL || !try_module_get(ops->owner)) { + spin_unlock(&rpc_authflavor_lock); + goto out; + } + spin_unlock(&rpc_authflavor_lock); + auth = ops->create(clnt, pseudoflavor); + module_put(ops->owner); + if (IS_ERR(auth)) + return auth; + if (clnt->cl_auth) + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = auth; + +out: + return auth; +} +EXPORT_SYMBOL_GPL(rpcauth_create); + +void +rpcauth_release(struct rpc_auth *auth) +{ + if (!atomic_dec_and_test(&auth->au_count)) + return; + auth->au_ops->destroy(auth); +} + +static DEFINE_SPINLOCK(rpc_credcache_lock); + +static void +rpcauth_unhash_cred_locked(struct rpc_cred *cred) +{ + hlist_del_rcu(&cred->cr_hash); + smp_mb__before_clear_bit(); + clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); +} + +static void +rpcauth_unhash_cred(struct rpc_cred *cred) +{ + spinlock_t *cache_lock; + + cache_lock = &cred->cr_auth->au_credcache->lock; + spin_lock(cache_lock); + if (atomic_read(&cred->cr_count) == 0) + rpcauth_unhash_cred_locked(cred); + spin_unlock(cache_lock); +} + +/* + * Initialize RPC credential cache + */ +int +rpcauth_init_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *new; + int i; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOMEM; + for (i = 0; i < RPC_CREDCACHE_NR; i++) + INIT_HLIST_HEAD(&new->hashtable[i]); + spin_lock_init(&new->lock); + auth->au_credcache = new; + return 0; +} +EXPORT_SYMBOL_GPL(rpcauth_init_credcache); + +/* + * Destroy a list of credentials + */ +static inline +void rpcauth_destroy_credlist(struct list_head *head) +{ + struct rpc_cred *cred; + + while (!list_empty(head)) { + cred = list_entry(head->next, struct rpc_cred, cr_lru); + list_del_init(&cred->cr_lru); + put_rpccred(cred); + } +} + +/* + * Clear the RPC credential cache, and delete those credentials + * that are not referenced. + */ +void +rpcauth_clear_credcache(struct rpc_cred_cache *cache) +{ + LIST_HEAD(free); + struct hlist_head *head; + struct rpc_cred *cred; + int i; + + spin_lock(&rpc_credcache_lock); + spin_lock(&cache->lock); + for (i = 0; i < RPC_CREDCACHE_NR; i++) { + head = &cache->hashtable[i]; + while (!hlist_empty(head)) { + cred = hlist_entry(head->first, struct rpc_cred, cr_hash); + get_rpccred(cred); + if (!list_empty(&cred->cr_lru)) { + list_del(&cred->cr_lru); + number_cred_unused--; + } + list_add_tail(&cred->cr_lru, &free); + rpcauth_unhash_cred_locked(cred); + } + } + spin_unlock(&cache->lock); + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); +} + +/* + * Destroy the RPC credential cache + */ +void +rpcauth_destroy_credcache(struct rpc_auth *auth) +{ + struct rpc_cred_cache *cache = auth->au_credcache; + + if (cache) { + auth->au_credcache = NULL; + rpcauth_clear_credcache(cache); + kfree(cache); + } +} +EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache); + + +#define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ) + +/* + * Remove stale credentials. Avoid sleeping inside the loop. + */ +static int +rpcauth_prune_expired(struct list_head *free, int nr_to_scan) +{ + spinlock_t *cache_lock; + struct rpc_cred *cred, *next; + unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM; + + list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) { + + /* Enforce a 60 second garbage collection moratorium */ + if (time_in_range(cred->cr_expire, expired, jiffies) && + test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) + continue; + + list_del_init(&cred->cr_lru); + number_cred_unused--; + if (atomic_read(&cred->cr_count) != 0) + continue; + + cache_lock = &cred->cr_auth->au_credcache->lock; + spin_lock(cache_lock); + if (atomic_read(&cred->cr_count) == 0) { + get_rpccred(cred); + list_add_tail(&cred->cr_lru, free); + rpcauth_unhash_cred_locked(cred); + nr_to_scan--; + } + spin_unlock(cache_lock); + if (nr_to_scan == 0) + break; + } + return nr_to_scan; +} + +/* + * Run memory cache shrinker. + */ +static int +rpcauth_cache_shrinker(int nr_to_scan, gfp_t gfp_mask) +{ + LIST_HEAD(free); + int res; + + if (list_empty(&cred_unused)) + return 0; + spin_lock(&rpc_credcache_lock); + nr_to_scan = rpcauth_prune_expired(&free, nr_to_scan); + res = (number_cred_unused / 100) * sysctl_vfs_cache_pressure; + spin_unlock(&rpc_credcache_lock); + rpcauth_destroy_credlist(&free); + return res; +} + +/* + * Look up a process' credentials in the authentication cache + */ +struct rpc_cred * +rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, + int flags) +{ + LIST_HEAD(free); + struct rpc_cred_cache *cache = auth->au_credcache; + struct hlist_node *pos; + struct rpc_cred *cred = NULL, + *entry, *new; + unsigned int nr; + + nr = hash_long(acred->uid, RPC_CREDCACHE_HASHBITS); + + rcu_read_lock(); + hlist_for_each_entry_rcu(entry, pos, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + spin_lock(&cache->lock); + if (test_bit(RPCAUTH_CRED_HASHED, &entry->cr_flags) == 0) { + spin_unlock(&cache->lock); + continue; + } + cred = get_rpccred(entry); + spin_unlock(&cache->lock); + break; + } + rcu_read_unlock(); + + if (cred != NULL) + goto found; + + new = auth->au_ops->crcreate(auth, acred, flags); + if (IS_ERR(new)) { + cred = new; + goto out; + } + + spin_lock(&cache->lock); + hlist_for_each_entry(entry, pos, &cache->hashtable[nr], cr_hash) { + if (!entry->cr_ops->crmatch(acred, entry, flags)) + continue; + cred = get_rpccred(entry); + break; + } + if (cred == NULL) { + cred = new; + set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags); + hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]); + } else + list_add_tail(&new->cr_lru, &free); + spin_unlock(&cache->lock); +found: + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) + && cred->cr_ops->cr_init != NULL + && !(flags & RPCAUTH_LOOKUP_NEW)) { + int res = cred->cr_ops->cr_init(auth, cred); + if (res < 0) { + put_rpccred(cred); + cred = ERR_PTR(res); + } + } + rpcauth_destroy_credlist(&free); +out: + return cred; +} +EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache); + +struct rpc_cred * +rpcauth_lookupcred(struct rpc_auth *auth, int flags) +{ + struct auth_cred acred = { + .uid = current->fsuid, + .gid = current->fsgid, + .group_info = current->group_info, + }; + struct rpc_cred *ret; + + dprintk("RPC: looking up %s cred\n", + auth->au_ops->au_name); + get_group_info(acred.group_info); + ret = auth->au_ops->lookup_cred(auth, &acred, flags); + put_group_info(acred.group_info); + return ret; +} + +void +rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred, + struct rpc_auth *auth, const struct rpc_credops *ops) +{ + INIT_HLIST_NODE(&cred->cr_hash); + INIT_LIST_HEAD(&cred->cr_lru); + atomic_set(&cred->cr_count, 1); + cred->cr_auth = auth; + cred->cr_ops = ops; + cred->cr_expire = jiffies; +#ifdef RPC_DEBUG + cred->cr_magic = RPCAUTH_CRED_MAGIC; +#endif + cred->cr_uid = acred->uid; +} +EXPORT_SYMBOL_GPL(rpcauth_init_cred); + +void +rpcauth_generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred) +{ + task->tk_msg.rpc_cred = get_rpccred(cred); + dprintk("RPC: %5u holding %s cred %p\n", task->tk_pid, + cred->cr_auth->au_ops->au_name, cred); +} +EXPORT_SYMBOL_GPL(rpcauth_generic_bind_cred); + +static void +rpcauth_bind_root_cred(struct rpc_task *task) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred acred = { + .uid = 0, + .gid = 0, + }; + struct rpc_cred *ret; + + dprintk("RPC: %5u looking up %s cred\n", + task->tk_pid, task->tk_client->cl_auth->au_ops->au_name); + ret = auth->au_ops->lookup_cred(auth, &acred, 0); + if (!IS_ERR(ret)) + task->tk_msg.rpc_cred = ret; + else + task->tk_status = PTR_ERR(ret); +} + +static void +rpcauth_bind_new_cred(struct rpc_task *task) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct rpc_cred *ret; + + dprintk("RPC: %5u looking up %s cred\n", + task->tk_pid, auth->au_ops->au_name); + ret = rpcauth_lookupcred(auth, 0); + if (!IS_ERR(ret)) + task->tk_msg.rpc_cred = ret; + else + task->tk_status = PTR_ERR(ret); +} + +void +rpcauth_bindcred(struct rpc_task *task, struct rpc_cred *cred, int flags) +{ + if (cred != NULL) + cred->cr_ops->crbind(task, cred); + else if (flags & RPC_TASK_ROOTCREDS) + rpcauth_bind_root_cred(task); + else + rpcauth_bind_new_cred(task); +} + +void +put_rpccred(struct rpc_cred *cred) +{ + /* Fast path for unhashed credentials */ + if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) + goto need_lock; + + if (!atomic_dec_and_test(&cred->cr_count)) + return; + goto out_destroy; +need_lock: + if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock)) + return; + if (!list_empty(&cred->cr_lru)) { + number_cred_unused--; + list_del_init(&cred->cr_lru); + } + if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0) + rpcauth_unhash_cred(cred); + if (test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags) != 0) { + cred->cr_expire = jiffies; + list_add_tail(&cred->cr_lru, &cred_unused); + number_cred_unused++; + spin_unlock(&rpc_credcache_lock); + return; + } + spin_unlock(&rpc_credcache_lock); +out_destroy: + cred->cr_ops->crdestroy(cred); +} +EXPORT_SYMBOL_GPL(put_rpccred); + +void +rpcauth_unbindcred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u releasing %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + put_rpccred(cred); + task->tk_msg.rpc_cred = NULL; +} + +__be32 * +rpcauth_marshcred(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u marshaling %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + return cred->cr_ops->crmarshal(task, p); +} + +__be32 * +rpcauth_checkverf(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u validating %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + return cred->cr_ops->crvalidate(task, p); +} + +int +rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, + __be32 *data, void *obj) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u using %s cred %p to wrap rpc data\n", + task->tk_pid, cred->cr_ops->cr_name, cred); + if (cred->cr_ops->crwrap_req) + return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj); + /* By default, we encode the arguments normally. */ + return rpc_call_xdrproc(encode, rqstp, data, obj); +} + +int +rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, + __be32 *data, void *obj) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n", + task->tk_pid, cred->cr_ops->cr_name, cred); + if (cred->cr_ops->crunwrap_resp) + return cred->cr_ops->crunwrap_resp(task, decode, rqstp, + data, obj); + /* By default, we decode the arguments normally. */ + return rpc_call_xdrproc(decode, rqstp, data, obj); +} + +int +rpcauth_refreshcred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + int err; + + dprintk("RPC: %5u refreshing %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + + err = cred->cr_ops->crrefresh(task); + if (err < 0) + task->tk_status = err; + return err; +} + +void +rpcauth_invalcred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + dprintk("RPC: %5u invalidating %s cred %p\n", + task->tk_pid, cred->cr_auth->au_ops->au_name, cred); + if (cred) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); +} + +int +rpcauth_uptodatecred(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + + return cred == NULL || + test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0; +} + +static struct shrinker rpc_cred_shrinker = { + .shrink = rpcauth_cache_shrinker, + .seeks = DEFAULT_SEEKS, +}; + +void __init rpcauth_init_module(void) +{ + rpc_init_authunix(); + rpc_init_generic_auth(); + register_shrinker(&rpc_cred_shrinker); +} + +void __exit rpcauth_remove_module(void) +{ + unregister_shrinker(&rpc_cred_shrinker); +} diff --git a/net/sunrpc/auth_generic.c b/net/sunrpc/auth_generic.c new file mode 100644 index 0000000..4028502 --- /dev/null +++ b/net/sunrpc/auth_generic.c @@ -0,0 +1,193 @@ +/* + * Generic RPC credential + * + * Copyright (C) 2008, Trond Myklebust <Trond.Myklebust@netapp.com> + */ + +#include <linux/err.h> +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sched.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/sched.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define RPC_MACHINE_CRED_USERID ((uid_t)0) +#define RPC_MACHINE_CRED_GROUPID ((gid_t)0) + +struct generic_cred { + struct rpc_cred gc_base; + struct auth_cred acred; +}; + +static struct rpc_auth generic_auth; +static struct rpc_cred_cache generic_cred_cache; +static const struct rpc_credops generic_credops; + +/* + * Public call interface + */ +struct rpc_cred *rpc_lookup_cred(void) +{ + return rpcauth_lookupcred(&generic_auth, 0); +} +EXPORT_SYMBOL_GPL(rpc_lookup_cred); + +/* + * Public call interface for looking up machine creds. + */ +struct rpc_cred *rpc_lookup_machine_cred(void) +{ + struct auth_cred acred = { + .uid = RPC_MACHINE_CRED_USERID, + .gid = RPC_MACHINE_CRED_GROUPID, + .machine_cred = 1, + }; + + dprintk("RPC: looking up machine cred\n"); + return generic_auth.au_ops->lookup_cred(&generic_auth, &acred, 0); +} +EXPORT_SYMBOL_GPL(rpc_lookup_machine_cred); + +static void +generic_bind_cred(struct rpc_task *task, struct rpc_cred *cred) +{ + struct rpc_auth *auth = task->tk_client->cl_auth; + struct auth_cred *acred = &container_of(cred, struct generic_cred, gc_base)->acred; + struct rpc_cred *ret; + + ret = auth->au_ops->lookup_cred(auth, acred, 0); + if (!IS_ERR(ret)) + task->tk_msg.rpc_cred = ret; + else + task->tk_status = PTR_ERR(ret); +} + +/* + * Lookup generic creds for current process + */ +static struct rpc_cred * +generic_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(&generic_auth, acred, flags); +} + +static struct rpc_cred * +generic_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + struct generic_cred *gcred; + + gcred = kmalloc(sizeof(*gcred), GFP_KERNEL); + if (gcred == NULL) + return ERR_PTR(-ENOMEM); + + rpcauth_init_cred(&gcred->gc_base, acred, &generic_auth, &generic_credops); + gcred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + + gcred->acred.uid = acred->uid; + gcred->acred.gid = acred->gid; + gcred->acred.group_info = acred->group_info; + if (gcred->acred.group_info != NULL) + get_group_info(gcred->acred.group_info); + gcred->acred.machine_cred = acred->machine_cred; + + dprintk("RPC: allocated %s cred %p for uid %d gid %d\n", + gcred->acred.machine_cred ? "machine" : "generic", + gcred, acred->uid, acred->gid); + return &gcred->gc_base; +} + +static void +generic_free_cred(struct rpc_cred *cred) +{ + struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); + + dprintk("RPC: generic_free_cred %p\n", gcred); + if (gcred->acred.group_info != NULL) + put_group_info(gcred->acred.group_info); + kfree(gcred); +} + +static void +generic_free_cred_callback(struct rcu_head *head) +{ + struct rpc_cred *cred = container_of(head, struct rpc_cred, cr_rcu); + generic_free_cred(cred); +} + +static void +generic_destroy_cred(struct rpc_cred *cred) +{ + call_rcu(&cred->cr_rcu, generic_free_cred_callback); +} + +/* + * Match credentials against current process creds. + */ +static int +generic_match(struct auth_cred *acred, struct rpc_cred *cred, int flags) +{ + struct generic_cred *gcred = container_of(cred, struct generic_cred, gc_base); + int i; + + if (gcred->acred.uid != acred->uid || + gcred->acred.gid != acred->gid || + gcred->acred.machine_cred != acred->machine_cred) + goto out_nomatch; + + /* Optimisation in the case where pointers are identical... */ + if (gcred->acred.group_info == acred->group_info) + goto out_match; + + /* Slow path... */ + if (gcred->acred.group_info->ngroups != acred->group_info->ngroups) + goto out_nomatch; + for (i = 0; i < gcred->acred.group_info->ngroups; i++) { + if (GROUP_AT(gcred->acred.group_info, i) != + GROUP_AT(acred->group_info, i)) + goto out_nomatch; + } +out_match: + return 1; +out_nomatch: + return 0; +} + +void __init rpc_init_generic_auth(void) +{ + spin_lock_init(&generic_cred_cache.lock); +} + +void __exit rpc_destroy_generic_auth(void) +{ + rpcauth_clear_credcache(&generic_cred_cache); +} + +static struct rpc_cred_cache generic_cred_cache = { + {{ NULL, },}, +}; + +static const struct rpc_authops generic_auth_ops = { + .owner = THIS_MODULE, + .au_name = "Generic", + .lookup_cred = generic_lookup_cred, + .crcreate = generic_create_cred, +}; + +static struct rpc_auth generic_auth = { + .au_ops = &generic_auth_ops, + .au_count = ATOMIC_INIT(0), + .au_credcache = &generic_cred_cache, +}; + +static const struct rpc_credops generic_credops = { + .cr_name = "Generic cred", + .crdestroy = generic_destroy_cred, + .crbind = generic_bind_cred, + .crmatch = generic_match, +}; diff --git a/net/sunrpc/auth_gss/Makefile b/net/sunrpc/auth_gss/Makefile new file mode 100644 index 0000000..4de8bcf --- /dev/null +++ b/net/sunrpc/auth_gss/Makefile @@ -0,0 +1,18 @@ +# +# Makefile for Linux kernel rpcsec_gss implementation +# + +obj-$(CONFIG_SUNRPC_GSS) += auth_rpcgss.o + +auth_rpcgss-objs := auth_gss.o gss_generic_token.o \ + gss_mech_switch.o svcauth_gss.o + +obj-$(CONFIG_RPCSEC_GSS_KRB5) += rpcsec_gss_krb5.o + +rpcsec_gss_krb5-objs := gss_krb5_mech.o gss_krb5_seal.o gss_krb5_unseal.o \ + gss_krb5_seqnum.o gss_krb5_wrap.o gss_krb5_crypto.o + +obj-$(CONFIG_RPCSEC_GSS_SPKM3) += rpcsec_gss_spkm3.o + +rpcsec_gss_spkm3-objs := gss_spkm3_mech.o gss_spkm3_seal.o gss_spkm3_unseal.o \ + gss_spkm3_token.o diff --git a/net/sunrpc/auth_gss/auth_gss.c b/net/sunrpc/auth_gss/auth_gss.c new file mode 100644 index 0000000..853a414 --- /dev/null +++ b/net/sunrpc/auth_gss/auth_gss.c @@ -0,0 +1,1372 @@ +/* + * linux/net/sunrpc/auth_gss/auth_gss.c + * + * RPCSEC_GSS client authentication. + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Dug Song <dugsong@monkey.org> + * Andy Adamson <andros@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + + +#include <linux/module.h> +#include <linux/init.h> +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sched.h> +#include <linux/pagemap.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/gss_api.h> +#include <asm/uaccess.h> + +static const struct rpc_authops authgss_ops; + +static const struct rpc_credops gss_credops; +static const struct rpc_credops gss_nullops; + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +#define GSS_CRED_SLACK 1024 +/* length of a krb5 verifier (48), plus data added before arguments when + * using integrity (two 4-byte integers): */ +#define GSS_VERF_SLACK 100 + +struct gss_auth { + struct kref kref; + struct rpc_auth rpc_auth; + struct gss_api_mech *mech; + enum rpc_gss_svc service; + struct rpc_clnt *client; + struct dentry *dentry; +}; + +static void gss_free_ctx(struct gss_cl_ctx *); +static struct rpc_pipe_ops gss_upcall_ops; + +static inline struct gss_cl_ctx * +gss_get_ctx(struct gss_cl_ctx *ctx) +{ + atomic_inc(&ctx->count); + return ctx; +} + +static inline void +gss_put_ctx(struct gss_cl_ctx *ctx) +{ + if (atomic_dec_and_test(&ctx->count)) + gss_free_ctx(ctx); +} + +/* gss_cred_set_ctx: + * called by gss_upcall_callback and gss_create_upcall in order + * to set the gss context. The actual exchange of an old context + * and a new one is protected by the inode->i_lock. + */ +static void +gss_cred_set_ctx(struct rpc_cred *cred, struct gss_cl_ctx *ctx) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + return; + gss_get_ctx(ctx); + rcu_assign_pointer(gss_cred->gc_ctx, ctx); + set_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + smp_mb__before_clear_bit(); + clear_bit(RPCAUTH_CRED_NEW, &cred->cr_flags); +} + +static const void * +simple_get_bytes(const void *p, const void *end, void *res, size_t len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static inline const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest) +{ + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + dest->data = kmemdup(p, len, GFP_NOFS); + if (unlikely(dest->data == NULL)) + return ERR_PTR(-ENOMEM); + dest->len = len; + return q; +} + +static struct gss_cl_ctx * +gss_cred_get_ctx(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_cl_ctx *ctx = NULL; + + rcu_read_lock(); + if (gss_cred->gc_ctx) + ctx = gss_get_ctx(gss_cred->gc_ctx); + rcu_read_unlock(); + return ctx; +} + +static struct gss_cl_ctx * +gss_alloc_context(void) +{ + struct gss_cl_ctx *ctx; + + ctx = kzalloc(sizeof(*ctx), GFP_NOFS); + if (ctx != NULL) { + ctx->gc_proc = RPC_GSS_PROC_DATA; + ctx->gc_seq = 1; /* NetApp 6.4R1 doesn't accept seq. no. 0 */ + spin_lock_init(&ctx->gc_seq_lock); + atomic_set(&ctx->count,1); + } + return ctx; +} + +#define GSSD_MIN_TIMEOUT (60 * 60) +static const void * +gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct gss_api_mech *gm) +{ + const void *q; + unsigned int seclen; + unsigned int timeout; + u32 window_size; + int ret; + + /* First unsigned int gives the lifetime (in seconds) of the cred */ + p = simple_get_bytes(p, end, &timeout, sizeof(timeout)); + if (IS_ERR(p)) + goto err; + if (timeout == 0) + timeout = GSSD_MIN_TIMEOUT; + ctx->gc_expiry = jiffies + (unsigned long)timeout * HZ * 3 / 4; + /* Sequence number window. Determines the maximum number of simultaneous requests */ + p = simple_get_bytes(p, end, &window_size, sizeof(window_size)); + if (IS_ERR(p)) + goto err; + ctx->gc_win = window_size; + /* gssd signals an error by passing ctx->gc_win = 0: */ + if (ctx->gc_win == 0) { + /* in which case, p points to an error code which we ignore */ + p = ERR_PTR(-EACCES); + goto err; + } + /* copy the opaque wire context */ + p = simple_get_netobj(p, end, &ctx->gc_wire_ctx); + if (IS_ERR(p)) + goto err; + /* import the opaque security context */ + p = simple_get_bytes(p, end, &seclen, sizeof(seclen)); + if (IS_ERR(p)) + goto err; + q = (const void *)((const char *)p + seclen); + if (unlikely(q > end || q < p)) { + p = ERR_PTR(-EFAULT); + goto err; + } + ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx); + if (ret < 0) { + p = ERR_PTR(ret); + goto err; + } + return q; +err: + dprintk("RPC: gss_fill_context returning %ld\n", -PTR_ERR(p)); + return p; +} + + +struct gss_upcall_msg { + atomic_t count; + uid_t uid; + struct rpc_pipe_msg msg; + struct list_head list; + struct gss_auth *auth; + struct rpc_wait_queue rpc_waitqueue; + wait_queue_head_t waitqueue; + struct gss_cl_ctx *ctx; +}; + +static void +gss_release_msg(struct gss_upcall_msg *gss_msg) +{ + if (!atomic_dec_and_test(&gss_msg->count)) + return; + BUG_ON(!list_empty(&gss_msg->list)); + if (gss_msg->ctx != NULL) + gss_put_ctx(gss_msg->ctx); + rpc_destroy_wait_queue(&gss_msg->rpc_waitqueue); + kfree(gss_msg); +} + +static struct gss_upcall_msg * +__gss_find_upcall(struct rpc_inode *rpci, uid_t uid) +{ + struct gss_upcall_msg *pos; + list_for_each_entry(pos, &rpci->in_downcall, list) { + if (pos->uid != uid) + continue; + atomic_inc(&pos->count); + dprintk("RPC: gss_find_upcall found msg %p\n", pos); + return pos; + } + dprintk("RPC: gss_find_upcall found nothing\n"); + return NULL; +} + +/* Try to add an upcall to the pipefs queue. + * If an upcall owned by our uid already exists, then we return a reference + * to that upcall instead of adding the new upcall. + */ +static inline struct gss_upcall_msg * +gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg) +{ + struct inode *inode = gss_auth->dentry->d_inode; + struct rpc_inode *rpci = RPC_I(inode); + struct gss_upcall_msg *old; + + spin_lock(&inode->i_lock); + old = __gss_find_upcall(rpci, gss_msg->uid); + if (old == NULL) { + atomic_inc(&gss_msg->count); + list_add(&gss_msg->list, &rpci->in_downcall); + } else + gss_msg = old; + spin_unlock(&inode->i_lock); + return gss_msg; +} + +static void +__gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + list_del_init(&gss_msg->list); + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); + wake_up_all(&gss_msg->waitqueue); + atomic_dec(&gss_msg->count); +} + +static void +gss_unhash_msg(struct gss_upcall_msg *gss_msg) +{ + struct gss_auth *gss_auth = gss_msg->auth; + struct inode *inode = gss_auth->dentry->d_inode; + + if (list_empty(&gss_msg->list)) + return; + spin_lock(&inode->i_lock); + if (!list_empty(&gss_msg->list)) + __gss_unhash_msg(gss_msg); + spin_unlock(&inode->i_lock); +} + +static void +gss_upcall_callback(struct rpc_task *task) +{ + struct gss_cred *gss_cred = container_of(task->tk_msg.rpc_cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall; + struct inode *inode = gss_msg->auth->dentry->d_inode; + + spin_lock(&inode->i_lock); + if (gss_msg->ctx) + gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx); + else + task->tk_status = gss_msg->msg.errno; + gss_cred->gc_upcall = NULL; + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); + spin_unlock(&inode->i_lock); + gss_release_msg(gss_msg); +} + +static inline struct gss_upcall_msg * +gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid) +{ + struct gss_upcall_msg *gss_msg; + + gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS); + if (gss_msg != NULL) { + INIT_LIST_HEAD(&gss_msg->list); + rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq"); + init_waitqueue_head(&gss_msg->waitqueue); + atomic_set(&gss_msg->count, 1); + gss_msg->msg.data = &gss_msg->uid; + gss_msg->msg.len = sizeof(gss_msg->uid); + gss_msg->uid = uid; + gss_msg->auth = gss_auth; + } + return gss_msg; +} + +static struct gss_upcall_msg * +gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_new, *gss_msg; + uid_t uid = cred->cr_uid; + + /* Special case: rpc.gssd assumes that uid == 0 implies machine creds */ + if (gss_cred->gc_machine_cred != 0) + uid = 0; + + gss_new = gss_alloc_msg(gss_auth, uid); + if (gss_new == NULL) + return ERR_PTR(-ENOMEM); + gss_msg = gss_add_msg(gss_auth, gss_new); + if (gss_msg == gss_new) { + int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg); + if (res) { + gss_unhash_msg(gss_new); + gss_msg = ERR_PTR(res); + } + } else + gss_release_msg(gss_new); + return gss_msg; +} + +static inline int +gss_refresh_upcall(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct gss_auth *gss_auth = container_of(cred->cr_auth, + struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred, + struct gss_cred, gc_base); + struct gss_upcall_msg *gss_msg; + struct inode *inode = gss_auth->dentry->d_inode; + int err = 0; + + dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid, + cred->cr_uid); + gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred); + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + spin_lock(&inode->i_lock); + if (gss_cred->gc_upcall != NULL) + rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL); + else if (gss_msg->ctx != NULL) { + gss_cred_set_ctx(task->tk_msg.rpc_cred, gss_msg->ctx); + gss_cred->gc_upcall = NULL; + rpc_wake_up_status(&gss_msg->rpc_waitqueue, gss_msg->msg.errno); + } else if (gss_msg->msg.errno >= 0) { + task->tk_timeout = 0; + gss_cred->gc_upcall = gss_msg; + /* gss_upcall_callback will release the reference to gss_upcall_msg */ + atomic_inc(&gss_msg->count); + rpc_sleep_on(&gss_msg->rpc_waitqueue, task, gss_upcall_callback); + } else + err = gss_msg->msg.errno; + spin_unlock(&inode->i_lock); + gss_release_msg(gss_msg); +out: + dprintk("RPC: %5u gss_refresh_upcall for uid %u result %d\n", + task->tk_pid, cred->cr_uid, err); + return err; +} + +static inline int +gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred) +{ + struct inode *inode = gss_auth->dentry->d_inode; + struct rpc_cred *cred = &gss_cred->gc_base; + struct gss_upcall_msg *gss_msg; + DEFINE_WAIT(wait); + int err = 0; + + dprintk("RPC: gss_upcall for uid %u\n", cred->cr_uid); + gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred); + if (IS_ERR(gss_msg)) { + err = PTR_ERR(gss_msg); + goto out; + } + for (;;) { + prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE); + spin_lock(&inode->i_lock); + if (gss_msg->ctx != NULL || gss_msg->msg.errno < 0) { + break; + } + spin_unlock(&inode->i_lock); + if (signalled()) { + err = -ERESTARTSYS; + goto out_intr; + } + schedule(); + } + if (gss_msg->ctx) + gss_cred_set_ctx(cred, gss_msg->ctx); + else + err = gss_msg->msg.errno; + spin_unlock(&inode->i_lock); +out_intr: + finish_wait(&gss_msg->waitqueue, &wait); + gss_release_msg(gss_msg); +out: + dprintk("RPC: gss_create_upcall for uid %u result %d\n", + cred->cr_uid, err); + return err; +} + +static ssize_t +gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg, + char __user *dst, size_t buflen) +{ + char *data = (char *)msg->data + msg->copied; + size_t mlen = min(msg->len, buflen); + unsigned long left; + + left = copy_to_user(dst, data, mlen); + if (left == mlen) { + msg->errno = -EFAULT; + return -EFAULT; + } + + mlen -= left; + msg->copied += mlen; + msg->errno = 0; + return mlen; +} + +#define MSG_BUF_MAXSIZE 1024 + +static ssize_t +gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen) +{ + const void *p, *end; + void *buf; + struct gss_upcall_msg *gss_msg; + struct inode *inode = filp->f_path.dentry->d_inode; + struct gss_cl_ctx *ctx; + uid_t uid; + ssize_t err = -EFBIG; + + if (mlen > MSG_BUF_MAXSIZE) + goto out; + err = -ENOMEM; + buf = kmalloc(mlen, GFP_NOFS); + if (!buf) + goto out; + + err = -EFAULT; + if (copy_from_user(buf, src, mlen)) + goto err; + + end = (const void *)((char *)buf + mlen); + p = simple_get_bytes(buf, end, &uid, sizeof(uid)); + if (IS_ERR(p)) { + err = PTR_ERR(p); + goto err; + } + + err = -ENOMEM; + ctx = gss_alloc_context(); + if (ctx == NULL) + goto err; + + err = -ENOENT; + /* Find a matching upcall */ + spin_lock(&inode->i_lock); + gss_msg = __gss_find_upcall(RPC_I(inode), uid); + if (gss_msg == NULL) { + spin_unlock(&inode->i_lock); + goto err_put_ctx; + } + list_del_init(&gss_msg->list); + spin_unlock(&inode->i_lock); + + p = gss_fill_context(p, end, ctx, gss_msg->auth->mech); + if (IS_ERR(p)) { + err = PTR_ERR(p); + gss_msg->msg.errno = (err == -EAGAIN) ? -EAGAIN : -EACCES; + goto err_release_msg; + } + gss_msg->ctx = gss_get_ctx(ctx); + err = mlen; + +err_release_msg: + spin_lock(&inode->i_lock); + __gss_unhash_msg(gss_msg); + spin_unlock(&inode->i_lock); + gss_release_msg(gss_msg); +err_put_ctx: + gss_put_ctx(ctx); +err: + kfree(buf); +out: + dprintk("RPC: gss_pipe_downcall returning %Zd\n", err); + return err; +} + +static void +gss_pipe_release(struct inode *inode) +{ + struct rpc_inode *rpci = RPC_I(inode); + struct gss_upcall_msg *gss_msg; + + spin_lock(&inode->i_lock); + while (!list_empty(&rpci->in_downcall)) { + + gss_msg = list_entry(rpci->in_downcall.next, + struct gss_upcall_msg, list); + gss_msg->msg.errno = -EPIPE; + atomic_inc(&gss_msg->count); + __gss_unhash_msg(gss_msg); + spin_unlock(&inode->i_lock); + gss_release_msg(gss_msg); + spin_lock(&inode->i_lock); + } + spin_unlock(&inode->i_lock); +} + +static void +gss_pipe_destroy_msg(struct rpc_pipe_msg *msg) +{ + struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg); + static unsigned long ratelimit; + + if (msg->errno < 0) { + dprintk("RPC: gss_pipe_destroy_msg releasing msg %p\n", + gss_msg); + atomic_inc(&gss_msg->count); + gss_unhash_msg(gss_msg); + if (msg->errno == -ETIMEDOUT) { + unsigned long now = jiffies; + if (time_after(now, ratelimit)) { + printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n" + "Please check user daemon is running!\n"); + ratelimit = now + 15*HZ; + } + } + gss_release_msg(gss_msg); + } +} + +/* + * NOTE: we have the opportunity to use different + * parameters based on the input flavor (which must be a pseudoflavor) + */ +static struct rpc_auth * +gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + struct gss_auth *gss_auth; + struct rpc_auth * auth; + int err = -ENOMEM; /* XXX? */ + + dprintk("RPC: creating GSS authenticator for client %p\n", clnt); + + if (!try_module_get(THIS_MODULE)) + return ERR_PTR(err); + if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL))) + goto out_dec; + gss_auth->client = clnt; + err = -EINVAL; + gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor); + if (!gss_auth->mech) { + printk(KERN_WARNING "%s: Pseudoflavor %d not found!\n", + __func__, flavor); + goto err_free; + } + gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor); + if (gss_auth->service == 0) + goto err_put_mech; + auth = &gss_auth->rpc_auth; + auth->au_cslack = GSS_CRED_SLACK >> 2; + auth->au_rslack = GSS_VERF_SLACK >> 2; + auth->au_ops = &authgss_ops; + auth->au_flavor = flavor; + atomic_set(&auth->au_count, 1); + kref_init(&gss_auth->kref); + + gss_auth->dentry = rpc_mkpipe(clnt->cl_dentry, gss_auth->mech->gm_name, + clnt, &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN); + if (IS_ERR(gss_auth->dentry)) { + err = PTR_ERR(gss_auth->dentry); + goto err_put_mech; + } + + err = rpcauth_init_credcache(auth); + if (err) + goto err_unlink_pipe; + + return auth; +err_unlink_pipe: + rpc_unlink(gss_auth->dentry); +err_put_mech: + gss_mech_put(gss_auth->mech); +err_free: + kfree(gss_auth); +out_dec: + module_put(THIS_MODULE); + return ERR_PTR(err); +} + +static void +gss_free(struct gss_auth *gss_auth) +{ + rpc_unlink(gss_auth->dentry); + gss_auth->dentry = NULL; + gss_mech_put(gss_auth->mech); + + kfree(gss_auth); + module_put(THIS_MODULE); +} + +static void +gss_free_callback(struct kref *kref) +{ + struct gss_auth *gss_auth = container_of(kref, struct gss_auth, kref); + + gss_free(gss_auth); +} + +static void +gss_destroy(struct rpc_auth *auth) +{ + struct gss_auth *gss_auth; + + dprintk("RPC: destroying GSS authenticator %p flavor %d\n", + auth, auth->au_flavor); + + rpcauth_destroy_credcache(auth); + + gss_auth = container_of(auth, struct gss_auth, rpc_auth); + kref_put(&gss_auth->kref, gss_free_callback); +} + +/* + * gss_destroying_context will cause the RPCSEC_GSS to send a NULL RPC call + * to the server with the GSS control procedure field set to + * RPC_GSS_PROC_DESTROY. This should normally cause the server to release + * all RPCSEC_GSS state associated with that context. + */ +static int +gss_destroying_context(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct rpc_task *task; + + if (gss_cred->gc_ctx == NULL || + test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0) + return 0; + + gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY; + cred->cr_ops = &gss_nullops; + + /* Take a reference to ensure the cred will be destroyed either + * by the RPC call or by the put_rpccred() below */ + get_rpccred(cred); + + task = rpc_call_null(gss_auth->client, cred, RPC_TASK_ASYNC|RPC_TASK_SOFT); + if (!IS_ERR(task)) + rpc_put_task(task); + + put_rpccred(cred); + return 1; +} + +/* gss_destroy_cred (and gss_free_ctx) are used to clean up after failure + * to create a new cred or context, so they check that things have been + * allocated before freeing them. */ +static void +gss_do_free_ctx(struct gss_cl_ctx *ctx) +{ + dprintk("RPC: gss_free_ctx\n"); + + kfree(ctx->gc_wire_ctx.data); + kfree(ctx); +} + +static void +gss_free_ctx_callback(struct rcu_head *head) +{ + struct gss_cl_ctx *ctx = container_of(head, struct gss_cl_ctx, gc_rcu); + gss_do_free_ctx(ctx); +} + +static void +gss_free_ctx(struct gss_cl_ctx *ctx) +{ + struct gss_ctx *gc_gss_ctx; + + gc_gss_ctx = rcu_dereference(ctx->gc_gss_ctx); + rcu_assign_pointer(ctx->gc_gss_ctx, NULL); + call_rcu(&ctx->gc_rcu, gss_free_ctx_callback); + if (gc_gss_ctx) + gss_delete_sec_context(&gc_gss_ctx); +} + +static void +gss_free_cred(struct gss_cred *gss_cred) +{ + dprintk("RPC: gss_free_cred %p\n", gss_cred); + kfree(gss_cred); +} + +static void +gss_free_cred_callback(struct rcu_head *head) +{ + struct gss_cred *gss_cred = container_of(head, struct gss_cred, gc_base.cr_rcu); + gss_free_cred(gss_cred); +} + +static void +gss_destroy_cred(struct rpc_cred *cred) +{ + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base); + struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth); + struct gss_cl_ctx *ctx = gss_cred->gc_ctx; + + if (gss_destroying_context(cred)) + return; + rcu_assign_pointer(gss_cred->gc_ctx, NULL); + call_rcu(&cred->cr_rcu, gss_free_cred_callback); + if (ctx) + gss_put_ctx(ctx); + kref_put(&gss_auth->kref, gss_free_callback); +} + +/* + * Lookup RPCSEC_GSS cred for the current process + */ +static struct rpc_cred * +gss_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(auth, acred, flags); +} + +static struct rpc_cred * +gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *cred = NULL; + int err = -ENOMEM; + + dprintk("RPC: gss_create_cred for uid %d, flavor %d\n", + acred->uid, auth->au_flavor); + + if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS))) + goto out_err; + + rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops); + /* + * Note: in order to force a call to call_refresh(), we deliberately + * fail to flag the credential as RPCAUTH_CRED_UPTODATE. + */ + cred->gc_base.cr_flags = 1UL << RPCAUTH_CRED_NEW; + cred->gc_service = gss_auth->service; + cred->gc_machine_cred = acred->machine_cred; + kref_get(&gss_auth->kref); + return &cred->gc_base; + +out_err: + dprintk("RPC: gss_create_cred failed with error %d\n", err); + return ERR_PTR(err); +} + +static int +gss_cred_init(struct rpc_auth *auth, struct rpc_cred *cred) +{ + struct gss_auth *gss_auth = container_of(auth, struct gss_auth, rpc_auth); + struct gss_cred *gss_cred = container_of(cred,struct gss_cred, gc_base); + int err; + + do { + err = gss_create_upcall(gss_auth, gss_cred); + } while (err == -EAGAIN); + return err; +} + +static int +gss_match(struct auth_cred *acred, struct rpc_cred *rc, int flags) +{ + struct gss_cred *gss_cred = container_of(rc, struct gss_cred, gc_base); + + if (test_bit(RPCAUTH_CRED_NEW, &rc->cr_flags)) + goto out; + /* Don't match with creds that have expired. */ + if (time_after(jiffies, gss_cred->gc_ctx->gc_expiry)) + return 0; + if (!test_bit(RPCAUTH_CRED_UPTODATE, &rc->cr_flags)) + return 0; +out: + if (acred->machine_cred != gss_cred->gc_machine_cred) + return 0; + return (rc->cr_uid == acred->uid); +} + +/* +* Marshal credentials. +* Maybe we should keep a cached credential for performance reasons. +*/ +static __be32 * +gss_marshal(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *cred_len; + struct rpc_rqst *req = task->tk_rqstp; + u32 maj_stat = 0; + struct xdr_netobj mic; + struct kvec iov; + struct xdr_buf verf_buf; + + dprintk("RPC: %5u gss_marshal\n", task->tk_pid); + + *p++ = htonl(RPC_AUTH_GSS); + cred_len = p++; + + spin_lock(&ctx->gc_seq_lock); + req->rq_seqno = ctx->gc_seq++; + spin_unlock(&ctx->gc_seq_lock); + + *p++ = htonl((u32) RPC_GSS_VERSION); + *p++ = htonl((u32) ctx->gc_proc); + *p++ = htonl((u32) req->rq_seqno); + *p++ = htonl((u32) gss_cred->gc_service); + p = xdr_encode_netobj(p, &ctx->gc_wire_ctx); + *cred_len = htonl((p - (cred_len + 1)) << 2); + + /* We compute the checksum for the verifier over the xdr-encoded bytes + * starting with the xid and ending at the end of the credential: */ + iov.iov_base = xprt_skip_transport_header(task->tk_xprt, + req->rq_snd_buf.head[0].iov_base); + iov.iov_len = (u8 *)p - (u8 *)iov.iov_base; + xdr_buf_from_iov(&iov, &verf_buf); + + /* set verifier flavor*/ + *p++ = htonl(RPC_AUTH_GSS); + + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) { + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + } else if (maj_stat != 0) { + printk("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat); + goto out_put_ctx; + } + p = xdr_encode_opaque(p, NULL, mic.len); + gss_put_ctx(ctx); + return p; +out_put_ctx: + gss_put_ctx(ctx); + return NULL; +} + +static int gss_renew_cred(struct rpc_task *task) +{ + struct rpc_cred *oldcred = task->tk_msg.rpc_cred; + struct gss_cred *gss_cred = container_of(oldcred, + struct gss_cred, + gc_base); + struct rpc_auth *auth = oldcred->cr_auth; + struct auth_cred acred = { + .uid = oldcred->cr_uid, + .machine_cred = gss_cred->gc_machine_cred, + }; + struct rpc_cred *new; + + new = gss_lookup_cred(auth, &acred, RPCAUTH_LOOKUP_NEW); + if (IS_ERR(new)) + return PTR_ERR(new); + task->tk_msg.rpc_cred = new; + put_rpccred(oldcred); + return 0; +} + +/* +* Refresh credentials. XXX - finish +*/ +static int +gss_refresh(struct rpc_task *task) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + int ret = 0; + + if (!test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) && + !test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags)) { + ret = gss_renew_cred(task); + if (ret < 0) + goto out; + cred = task->tk_msg.rpc_cred; + } + + if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags)) + ret = gss_refresh_upcall(task); +out: + return ret; +} + +/* Dummy refresh routine: used only when destroying the context */ +static int +gss_refresh_null(struct rpc_task *task) +{ + return -EACCES; +} + +static __be32 * +gss_validate(struct rpc_task *task, __be32 *p) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 seq; + struct kvec iov; + struct xdr_buf verf_buf; + struct xdr_netobj mic; + u32 flav,len; + u32 maj_stat; + + dprintk("RPC: %5u gss_validate\n", task->tk_pid); + + flav = ntohl(*p++); + if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE) + goto out_bad; + if (flav != RPC_AUTH_GSS) + goto out_bad; + seq = htonl(task->tk_rqstp->rq_seqno); + iov.iov_base = &seq; + iov.iov_len = sizeof(seq); + xdr_buf_from_iov(&iov, &verf_buf); + mic.data = (u8 *)p; + mic.len = len; + + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat) { + dprintk("RPC: %5u gss_validate: gss_verify_mic returned " + "error 0x%08x\n", task->tk_pid, maj_stat); + goto out_bad; + } + /* We leave it to unwrap to calculate au_rslack. For now we just + * calculate the length of the verifier: */ + cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2; + gss_put_ctx(ctx); + dprintk("RPC: %5u gss_validate: gss_verify_mic succeeded.\n", + task->tk_pid); + return p + XDR_QUADLEN(len); +out_bad: + gss_put_ctx(ctx); + dprintk("RPC: %5u gss_validate failed.\n", task->tk_pid); + return NULL; +} + +static inline int +gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + struct xdr_buf integ_buf; + __be32 *integ_len = NULL; + struct xdr_netobj mic; + u32 offset; + __be32 *q; + struct kvec *iov; + u32 maj_stat = 0; + int status = -EIO; + + integ_len = p++; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + *p++ = htonl(rqstp->rq_seqno); + + status = rpc_call_xdrproc(encode, rqstp, p, obj); + if (status) + return status; + + if (xdr_buf_subsegment(snd_buf, &integ_buf, + offset, snd_buf->len - offset)) + return status; + *integ_len = htonl(integ_buf.len); + + /* guess whether we're in the head or the tail: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) + iov = snd_buf->tail; + else + iov = snd_buf->head; + p = iov->iov_base + iov->iov_len; + mic.data = (u8 *)(p + 1); + + maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + status = -EIO; /* XXX? */ + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + return status; + q = xdr_encode_opaque(p, NULL, mic.len); + + offset = (u8 *)q - (u8 *)p; + iov->iov_len += offset; + snd_buf->len += offset; + return 0; +} + +static void +priv_release_snd_buf(struct rpc_rqst *rqstp) +{ + int i; + + for (i=0; i < rqstp->rq_enc_pages_num; i++) + __free_page(rqstp->rq_enc_pages[i]); + kfree(rqstp->rq_enc_pages); +} + +static int +alloc_enc_pages(struct rpc_rqst *rqstp) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + int first, last, i; + + if (snd_buf->page_len == 0) { + rqstp->rq_enc_pages_num = 0; + return 0; + } + + first = snd_buf->page_base >> PAGE_CACHE_SHIFT; + last = (snd_buf->page_base + snd_buf->page_len - 1) >> PAGE_CACHE_SHIFT; + rqstp->rq_enc_pages_num = last - first + 1 + 1; + rqstp->rq_enc_pages + = kmalloc(rqstp->rq_enc_pages_num * sizeof(struct page *), + GFP_NOFS); + if (!rqstp->rq_enc_pages) + goto out; + for (i=0; i < rqstp->rq_enc_pages_num; i++) { + rqstp->rq_enc_pages[i] = alloc_page(GFP_NOFS); + if (rqstp->rq_enc_pages[i] == NULL) + goto out_free; + } + rqstp->rq_release_snd_buf = priv_release_snd_buf; + return 0; +out_free: + for (i--; i >= 0; i--) { + __free_page(rqstp->rq_enc_pages[i]); + } +out: + return -EAGAIN; +} + +static inline int +gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + kxdrproc_t encode, struct rpc_rqst *rqstp, __be32 *p, void *obj) +{ + struct xdr_buf *snd_buf = &rqstp->rq_snd_buf; + u32 offset; + u32 maj_stat; + int status; + __be32 *opaque_len; + struct page **inpages; + int first; + int pad; + struct kvec *iov; + char *tmp; + + opaque_len = p++; + offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base; + *p++ = htonl(rqstp->rq_seqno); + + status = rpc_call_xdrproc(encode, rqstp, p, obj); + if (status) + return status; + + status = alloc_enc_pages(rqstp); + if (status) + return status; + first = snd_buf->page_base >> PAGE_CACHE_SHIFT; + inpages = snd_buf->pages + first; + snd_buf->pages = rqstp->rq_enc_pages; + snd_buf->page_base -= first << PAGE_CACHE_SHIFT; + /* Give the tail its own page, in case we need extra space in the + * head when wrapping: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) { + tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]); + memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len); + snd_buf->tail[0].iov_base = tmp; + } + maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages); + /* RPC_SLACK_SPACE should prevent this ever happening: */ + BUG_ON(snd_buf->len > snd_buf->buflen); + status = -EIO; + /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was + * done anyway, so it's safe to put the request on the wire: */ + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + else if (maj_stat) + return status; + + *opaque_len = htonl(snd_buf->len - offset); + /* guess whether we're in the head or the tail: */ + if (snd_buf->page_len || snd_buf->tail[0].iov_len) + iov = snd_buf->tail; + else + iov = snd_buf->head; + p = iov->iov_base + iov->iov_len; + pad = 3 - ((snd_buf->len - offset - 1) & 3); + memset(p, 0, pad); + iov->iov_len += pad; + snd_buf->len += pad; + + return 0; +} + +static int +gss_wrap_req(struct rpc_task *task, + kxdrproc_t encode, void *rqstp, __be32 *p, void *obj) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + int status = -EIO; + + dprintk("RPC: %5u gss_wrap_req\n", task->tk_pid); + if (ctx->gc_proc != RPC_GSS_PROC_DATA) { + /* The spec seems a little ambiguous here, but I think that not + * wrapping context destruction requests makes the most sense. + */ + status = rpc_call_xdrproc(encode, rqstp, p, obj); + goto out; + } + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + status = rpc_call_xdrproc(encode, rqstp, p, obj); + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_wrap_req_integ(cred, ctx, encode, + rqstp, p, obj); + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_wrap_req_priv(cred, ctx, encode, + rqstp, p, obj); + break; + } +out: + gss_put_ctx(ctx); + dprintk("RPC: %5u gss_wrap_req returning %d\n", task->tk_pid, status); + return status; +} + +static inline int +gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_rqst *rqstp, __be32 **p) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + struct xdr_buf integ_buf; + struct xdr_netobj mic; + u32 data_offset, mic_offset; + u32 integ_len; + u32 maj_stat; + int status = -EIO; + + integ_len = ntohl(*(*p)++); + if (integ_len & 3) + return status; + data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; + mic_offset = integ_len + data_offset; + if (mic_offset > rcv_buf->len) + return status; + if (ntohl(*(*p)++) != rqstp->rq_seqno) + return status; + + if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, + mic_offset - data_offset)) + return status; + + if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset)) + return status; + + maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + return status; + return 0; +} + +static inline int +gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx, + struct rpc_rqst *rqstp, __be32 **p) +{ + struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf; + u32 offset; + u32 opaque_len; + u32 maj_stat; + int status = -EIO; + + opaque_len = ntohl(*(*p)++); + offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base; + if (offset + opaque_len > rcv_buf->len) + return status; + /* remove padding: */ + rcv_buf->len = offset + opaque_len; + + maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) + clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags); + if (maj_stat != GSS_S_COMPLETE) + return status; + if (ntohl(*(*p)++) != rqstp->rq_seqno) + return status; + + return 0; +} + + +static int +gss_unwrap_resp(struct rpc_task *task, + kxdrproc_t decode, void *rqstp, __be32 *p, void *obj) +{ + struct rpc_cred *cred = task->tk_msg.rpc_cred; + struct gss_cred *gss_cred = container_of(cred, struct gss_cred, + gc_base); + struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred); + __be32 *savedp = p; + struct kvec *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head; + int savedlen = head->iov_len; + int status = -EIO; + + if (ctx->gc_proc != RPC_GSS_PROC_DATA) + goto out_decode; + switch (gss_cred->gc_service) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p); + if (status) + goto out; + break; + case RPC_GSS_SVC_PRIVACY: + status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p); + if (status) + goto out; + break; + } + /* take into account extra slack for integrity and privacy cases: */ + cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp) + + (savedlen - head->iov_len); +out_decode: + status = rpc_call_xdrproc(decode, rqstp, p, obj); +out: + gss_put_ctx(ctx); + dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid, + status); + return status; +} + +static const struct rpc_authops authgss_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_GSS, + .au_name = "RPCSEC_GSS", + .create = gss_create, + .destroy = gss_destroy, + .lookup_cred = gss_lookup_cred, + .crcreate = gss_create_cred +}; + +static const struct rpc_credops gss_credops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_cred, + .cr_init = gss_cred_init, + .crbind = rpcauth_generic_bind_cred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, +}; + +static const struct rpc_credops gss_nullops = { + .cr_name = "AUTH_GSS", + .crdestroy = gss_destroy_cred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = gss_match, + .crmarshal = gss_marshal, + .crrefresh = gss_refresh_null, + .crvalidate = gss_validate, + .crwrap_req = gss_wrap_req, + .crunwrap_resp = gss_unwrap_resp, +}; + +static struct rpc_pipe_ops gss_upcall_ops = { + .upcall = gss_pipe_upcall, + .downcall = gss_pipe_downcall, + .destroy_msg = gss_pipe_destroy_msg, + .release_pipe = gss_pipe_release, +}; + +/* + * Initialize RPCSEC_GSS module + */ +static int __init init_rpcsec_gss(void) +{ + int err = 0; + + err = rpcauth_register(&authgss_ops); + if (err) + goto out; + err = gss_svc_init(); + if (err) + goto out_unregister; + return 0; +out_unregister: + rpcauth_unregister(&authgss_ops); +out: + return err; +} + +static void __exit exit_rpcsec_gss(void) +{ + gss_svc_shutdown(); + rpcauth_unregister(&authgss_ops); +} + +MODULE_LICENSE("GPL"); +module_init(init_rpcsec_gss) +module_exit(exit_rpcsec_gss) diff --git a/net/sunrpc/auth_gss/gss_generic_token.c b/net/sunrpc/auth_gss/gss_generic_token.c new file mode 100644 index 0000000..d83b881 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_generic_token.c @@ -0,0 +1,235 @@ +/* + * linux/net/sunrpc/gss_generic_token.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/generic/util_token.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/string.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/gss_asn1.h> + + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* TWRITE_STR from gssapiP_generic.h */ +#define TWRITE_STR(ptr, str, len) \ + memcpy((ptr), (char *) (str), (len)); \ + (ptr) += (len); + +/* XXXX this code currently makes the assumption that a mech oid will + never be longer than 127 bytes. This assumption is not inherent in + the interfaces, so the code can be fixed if the OSI namespace + balloons unexpectedly. */ + +/* Each token looks like this: + +0x60 tag for APPLICATION 0, SEQUENCE + (constructed, definite-length) + <length> possible multiple bytes, need to parse/generate + 0x06 tag for OBJECT IDENTIFIER + <moid_length> compile-time constant string (assume 1 byte) + <moid_bytes> compile-time constant string + <inner_bytes> the ANY containing the application token + bytes 0,1 are the token type + bytes 2,n are the token data + +For the purposes of this abstraction, the token "header" consists of +the sequence tag and length octets, the mech OID DER encoding, and the +first two inner bytes, which indicate the token type. The token +"body" consists of everything else. + +*/ + +static int +der_length_size( int length) +{ + if (length < (1<<7)) + return(1); + else if (length < (1<<8)) + return(2); +#if (SIZEOF_INT == 2) + else + return(3); +#else + else if (length < (1<<16)) + return(3); + else if (length < (1<<24)) + return(4); + else + return(5); +#endif +} + +static void +der_write_length(unsigned char **buf, int length) +{ + if (length < (1<<7)) { + *(*buf)++ = (unsigned char) length; + } else { + *(*buf)++ = (unsigned char) (der_length_size(length)+127); +#if (SIZEOF_INT > 2) + if (length >= (1<<24)) + *(*buf)++ = (unsigned char) (length>>24); + if (length >= (1<<16)) + *(*buf)++ = (unsigned char) ((length>>16)&0xff); +#endif + if (length >= (1<<8)) + *(*buf)++ = (unsigned char) ((length>>8)&0xff); + *(*buf)++ = (unsigned char) (length&0xff); + } +} + +/* returns decoded length, or < 0 on failure. Advances buf and + decrements bufsize */ + +static int +der_read_length(unsigned char **buf, int *bufsize) +{ + unsigned char sf; + int ret; + + if (*bufsize < 1) + return(-1); + sf = *(*buf)++; + (*bufsize)--; + if (sf & 0x80) { + if ((sf &= 0x7f) > ((*bufsize)-1)) + return(-1); + if (sf > SIZEOF_INT) + return (-1); + ret = 0; + for (; sf; sf--) { + ret = (ret<<8) + (*(*buf)++); + (*bufsize)--; + } + } else { + ret = sf; + } + + return(ret); +} + +/* returns the length of a token, given the mech oid and the body size */ + +int +g_token_size(struct xdr_netobj *mech, unsigned int body_size) +{ + /* set body_size to sequence contents size */ + body_size += 2 + (int) mech->len; /* NEED overflow check */ + return(1 + der_length_size(body_size) + body_size); +} + +EXPORT_SYMBOL(g_token_size); + +/* fills in a buffer with the token header. The buffer is assumed to + be the right size. buf is advanced past the token header */ + +void +g_make_token_header(struct xdr_netobj *mech, int body_size, unsigned char **buf) +{ + *(*buf)++ = 0x60; + der_write_length(buf, 2 + mech->len + body_size); + *(*buf)++ = 0x06; + *(*buf)++ = (unsigned char) mech->len; + TWRITE_STR(*buf, mech->data, ((int) mech->len)); +} + +EXPORT_SYMBOL(g_make_token_header); + +/* + * Given a buffer containing a token, reads and verifies the token, + * leaving buf advanced past the token header, and setting body_size + * to the number of remaining bytes. Returns 0 on success, + * G_BAD_TOK_HEADER for a variety of errors, and G_WRONG_MECH if the + * mechanism in the token does not match the mech argument. buf and + * *body_size are left unmodified on error. + */ +u32 +g_verify_token_header(struct xdr_netobj *mech, int *body_size, + unsigned char **buf_in, int toksize) +{ + unsigned char *buf = *buf_in; + int seqsize; + struct xdr_netobj toid; + int ret = 0; + + if ((toksize-=1) < 0) + return(G_BAD_TOK_HEADER); + if (*buf++ != 0x60) + return(G_BAD_TOK_HEADER); + + if ((seqsize = der_read_length(&buf, &toksize)) < 0) + return(G_BAD_TOK_HEADER); + + if (seqsize != toksize) + return(G_BAD_TOK_HEADER); + + if ((toksize-=1) < 0) + return(G_BAD_TOK_HEADER); + if (*buf++ != 0x06) + return(G_BAD_TOK_HEADER); + + if ((toksize-=1) < 0) + return(G_BAD_TOK_HEADER); + toid.len = *buf++; + + if ((toksize-=toid.len) < 0) + return(G_BAD_TOK_HEADER); + toid.data = buf; + buf+=toid.len; + + if (! g_OID_equal(&toid, mech)) + ret = G_WRONG_MECH; + + /* G_WRONG_MECH is not returned immediately because it's more important + to return G_BAD_TOK_HEADER if the token header is in fact bad */ + + if ((toksize-=2) < 0) + return(G_BAD_TOK_HEADER); + + if (ret) + return(ret); + + if (!ret) { + *buf_in = buf; + *body_size = toksize; + } + + return(ret); +} + +EXPORT_SYMBOL(g_verify_token_header); + diff --git a/net/sunrpc/auth_gss/gss_krb5_crypto.c b/net/sunrpc/auth_gss/gss_krb5_crypto.c new file mode 100644 index 0000000..c93fca2 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_crypto.c @@ -0,0 +1,328 @@ +/* + * linux/net/sunrpc/gss_krb5_crypto.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * Bruce Fields <bfields@umich.edu> + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/err.h> +#include <linux/types.h> +#include <linux/mm.h> +#include <linux/slab.h> +#include <linux/scatterlist.h> +#include <linux/crypto.h> +#include <linux/highmem.h> +#include <linux/pagemap.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +u32 +krb5_encrypt( + struct crypto_blkcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[16] = {0}; + struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; + + if (length % crypto_blkcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_blkcipher_ivsize(tfm) > 16) { + dprintk("RPC: gss_k5encrypt: tfm iv size too large %d\n", + crypto_blkcipher_ivsize(tfm)); + goto out; + } + + if (iv) + memcpy(local_iv, iv, crypto_blkcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + ret = crypto_blkcipher_encrypt_iv(&desc, sg, sg, length); +out: + dprintk("RPC: krb5_encrypt returns %d\n", ret); + return ret; +} + +u32 +krb5_decrypt( + struct crypto_blkcipher *tfm, + void * iv, + void * in, + void * out, + int length) +{ + u32 ret = -EINVAL; + struct scatterlist sg[1]; + u8 local_iv[16] = {0}; + struct blkcipher_desc desc = { .tfm = tfm, .info = local_iv }; + + if (length % crypto_blkcipher_blocksize(tfm) != 0) + goto out; + + if (crypto_blkcipher_ivsize(tfm) > 16) { + dprintk("RPC: gss_k5decrypt: tfm iv size too large %d\n", + crypto_blkcipher_ivsize(tfm)); + goto out; + } + if (iv) + memcpy(local_iv,iv, crypto_blkcipher_ivsize(tfm)); + + memcpy(out, in, length); + sg_init_one(sg, out, length); + + ret = crypto_blkcipher_decrypt_iv(&desc, sg, sg, length); +out: + dprintk("RPC: gss_k5decrypt returns %d\n",ret); + return ret; +} + +static int +checksummer(struct scatterlist *sg, void *data) +{ + struct hash_desc *desc = data; + + return crypto_hash_update(desc, sg, sg->length); +} + +/* checksum the plaintext data and hdrlen bytes of the token header */ +s32 +make_checksum(char *cksumname, char *header, int hdrlen, struct xdr_buf *body, + int body_offset, struct xdr_netobj *cksum) +{ + struct hash_desc desc; /* XXX add to ctx? */ + struct scatterlist sg[1]; + int err; + + desc.tfm = crypto_alloc_hash(cksumname, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(desc.tfm)) + return GSS_S_FAILURE; + cksum->len = crypto_hash_digestsize(desc.tfm); + desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + + err = crypto_hash_init(&desc); + if (err) + goto out; + sg_init_one(sg, header, hdrlen); + err = crypto_hash_update(&desc, sg, hdrlen); + if (err) + goto out; + err = xdr_process_buf(body, body_offset, body->len - body_offset, + checksummer, &desc); + if (err) + goto out; + err = crypto_hash_final(&desc, cksum->data); + +out: + crypto_free_hash(desc.tfm); + return err ? GSS_S_FAILURE : 0; +} + +struct encryptor_desc { + u8 iv[8]; /* XXX hard-coded blocksize */ + struct blkcipher_desc desc; + int pos; + struct xdr_buf *outbuf; + struct page **pages; + struct scatterlist infrags[4]; + struct scatterlist outfrags[4]; + int fragno; + int fraglen; +}; + +static int +encryptor(struct scatterlist *sg, void *data) +{ + struct encryptor_desc *desc = data; + struct xdr_buf *outbuf = desc->outbuf; + struct page *in_page; + int thislen = desc->fraglen + sg->length; + int fraglen, ret; + int page_pos; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. */ + BUG_ON(desc->fragno > 3); + + page_pos = desc->pos - outbuf->head[0].iov_len; + if (page_pos >= 0 && page_pos < outbuf->page_len) { + /* pages are not in place: */ + int i = (page_pos + outbuf->page_base) >> PAGE_CACHE_SHIFT; + in_page = desc->pages[i]; + } else { + in_page = sg_page(sg); + } + sg_set_page(&desc->infrags[desc->fragno], in_page, sg->length, + sg->offset); + sg_set_page(&desc->outfrags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + desc->pos += sg->length; + + fraglen = thislen & 7; /* XXX hardcoded blocksize */ + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->infrags[desc->fragno - 1]); + sg_mark_end(&desc->outfrags[desc->fragno - 1]); + + ret = crypto_blkcipher_encrypt_iv(&desc->desc, desc->outfrags, + desc->infrags, thislen); + if (ret) + return ret; + + sg_init_table(desc->infrags, 4); + sg_init_table(desc->outfrags, 4); + + if (fraglen) { + sg_set_page(&desc->outfrags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->infrags[0] = desc->outfrags[0]; + sg_assign_page(&desc->infrags[0], in_page); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_encrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, + int offset, struct page **pages) +{ + int ret; + struct encryptor_desc desc; + + BUG_ON((buf->len - offset) % crypto_blkcipher_blocksize(tfm) != 0); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.desc.tfm = tfm; + desc.desc.info = desc.iv; + desc.desc.flags = 0; + desc.pos = offset; + desc.outbuf = buf; + desc.pages = pages; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.infrags, 4); + sg_init_table(desc.outfrags, 4); + + ret = xdr_process_buf(buf, offset, buf->len - offset, encryptor, &desc); + return ret; +} + +struct decryptor_desc { + u8 iv[8]; /* XXX hard-coded blocksize */ + struct blkcipher_desc desc; + struct scatterlist frags[4]; + int fragno; + int fraglen; +}; + +static int +decryptor(struct scatterlist *sg, void *data) +{ + struct decryptor_desc *desc = data; + int thislen = desc->fraglen + sg->length; + int fraglen, ret; + + /* Worst case is 4 fragments: head, end of page 1, start + * of page 2, tail. Anything more is a bug. */ + BUG_ON(desc->fragno > 3); + sg_set_page(&desc->frags[desc->fragno], sg_page(sg), sg->length, + sg->offset); + desc->fragno++; + desc->fraglen += sg->length; + + fraglen = thislen & 7; /* XXX hardcoded blocksize */ + thislen -= fraglen; + + if (thislen == 0) + return 0; + + sg_mark_end(&desc->frags[desc->fragno - 1]); + + ret = crypto_blkcipher_decrypt_iv(&desc->desc, desc->frags, + desc->frags, thislen); + if (ret) + return ret; + + sg_init_table(desc->frags, 4); + + if (fraglen) { + sg_set_page(&desc->frags[0], sg_page(sg), fraglen, + sg->offset + sg->length - fraglen); + desc->fragno = 1; + desc->fraglen = fraglen; + } else { + desc->fragno = 0; + desc->fraglen = 0; + } + return 0; +} + +int +gss_decrypt_xdr_buf(struct crypto_blkcipher *tfm, struct xdr_buf *buf, + int offset) +{ + struct decryptor_desc desc; + + /* XXXJBF: */ + BUG_ON((buf->len - offset) % crypto_blkcipher_blocksize(tfm) != 0); + + memset(desc.iv, 0, sizeof(desc.iv)); + desc.desc.tfm = tfm; + desc.desc.info = desc.iv; + desc.desc.flags = 0; + desc.fragno = 0; + desc.fraglen = 0; + + sg_init_table(desc.frags, 4); + + return xdr_process_buf(buf, offset, buf->len - offset, decryptor, &desc); +} diff --git a/net/sunrpc/auth_gss/gss_krb5_mech.c b/net/sunrpc/auth_gss/gss_krb5_mech.c new file mode 100644 index 0000000..ef45eba --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_mech.c @@ -0,0 +1,261 @@ +/* + * linux/net/sunrpc/gss_krb5_mech.c + * + * Copyright (c) 2001 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * J. Bruce Fields <bfields@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/err.h> +#include <linux/module.h> +#include <linux/init.h> +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sunrpc/auth.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/sunrpc/xdr.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static const void * +simple_get_bytes(const void *p, const void *end, void *res, int len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) +{ + const void *q; + unsigned int len; + + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + res->data = kmemdup(p, len, GFP_NOFS); + if (unlikely(res->data == NULL)) + return ERR_PTR(-ENOMEM); + res->len = len; + return q; +} + +static inline const void * +get_key(const void *p, const void *end, struct crypto_blkcipher **res) +{ + struct xdr_netobj key; + int alg; + char *alg_name; + + p = simple_get_bytes(p, end, &alg, sizeof(alg)); + if (IS_ERR(p)) + goto out_err; + p = simple_get_netobj(p, end, &key); + if (IS_ERR(p)) + goto out_err; + + switch (alg) { + case ENCTYPE_DES_CBC_RAW: + alg_name = "cbc(des)"; + break; + default: + printk("gss_kerberos_mech: unsupported algorithm %d\n", alg); + goto out_err_free_key; + } + *res = crypto_alloc_blkcipher(alg_name, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(*res)) { + printk("gss_kerberos_mech: unable to initialize crypto algorithm %s\n", alg_name); + *res = NULL; + goto out_err_free_key; + } + if (crypto_blkcipher_setkey(*res, key.data, key.len)) { + printk("gss_kerberos_mech: error setting key for crypto algorithm %s\n", alg_name); + goto out_err_free_tfm; + } + + kfree(key.data); + return p; + +out_err_free_tfm: + crypto_free_blkcipher(*res); +out_err_free_key: + kfree(key.data); + p = ERR_PTR(-EINVAL); +out_err: + return p; +} + +static int +gss_import_sec_context_kerberos(const void *p, + size_t len, + struct gss_ctx *ctx_id) +{ + const void *end = (const void *)((const char *)p + len); + struct krb5_ctx *ctx; + int tmp; + + if (!(ctx = kzalloc(sizeof(*ctx), GFP_NOFS))) + goto out_err; + + p = simple_get_bytes(p, end, &ctx->initiate, sizeof(ctx->initiate)); + if (IS_ERR(p)) + goto out_err_free_ctx; + /* The downcall format was designed before we completely understood + * the uses of the context fields; so it includes some stuff we + * just give some minimal sanity-checking, and some we ignore + * completely (like the next twenty bytes): */ + if (unlikely(p + 20 > end || p + 20 < p)) + goto out_err_free_ctx; + p += 20; + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err_free_ctx; + if (tmp != SGN_ALG_DES_MAC_MD5) { + p = ERR_PTR(-ENOSYS); + goto out_err_free_ctx; + } + p = simple_get_bytes(p, end, &tmp, sizeof(tmp)); + if (IS_ERR(p)) + goto out_err_free_ctx; + if (tmp != SEAL_ALG_DES) { + p = ERR_PTR(-ENOSYS); + goto out_err_free_ctx; + } + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) + goto out_err_free_ctx; + p = simple_get_bytes(p, end, &ctx->seq_send, sizeof(ctx->seq_send)); + if (IS_ERR(p)) + goto out_err_free_ctx; + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) + goto out_err_free_ctx; + p = get_key(p, end, &ctx->enc); + if (IS_ERR(p)) + goto out_err_free_mech; + p = get_key(p, end, &ctx->seq); + if (IS_ERR(p)) + goto out_err_free_key1; + if (p != end) { + p = ERR_PTR(-EFAULT); + goto out_err_free_key2; + } + + ctx_id->internal_ctx_id = ctx; + + dprintk("RPC: Successfully imported new context.\n"); + return 0; + +out_err_free_key2: + crypto_free_blkcipher(ctx->seq); +out_err_free_key1: + crypto_free_blkcipher(ctx->enc); +out_err_free_mech: + kfree(ctx->mech_used.data); +out_err_free_ctx: + kfree(ctx); +out_err: + return PTR_ERR(p); +} + +static void +gss_delete_sec_context_kerberos(void *internal_ctx) { + struct krb5_ctx *kctx = internal_ctx; + + crypto_free_blkcipher(kctx->seq); + crypto_free_blkcipher(kctx->enc); + kfree(kctx->mech_used.data); + kfree(kctx); +} + +static const struct gss_api_ops gss_kerberos_ops = { + .gss_import_sec_context = gss_import_sec_context_kerberos, + .gss_get_mic = gss_get_mic_kerberos, + .gss_verify_mic = gss_verify_mic_kerberos, + .gss_wrap = gss_wrap_kerberos, + .gss_unwrap = gss_unwrap_kerberos, + .gss_delete_sec_context = gss_delete_sec_context_kerberos, +}; + +static struct pf_desc gss_kerberos_pfs[] = { + [0] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5, + .service = RPC_GSS_SVC_NONE, + .name = "krb5", + }, + [1] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5I, + .service = RPC_GSS_SVC_INTEGRITY, + .name = "krb5i", + }, + [2] = { + .pseudoflavor = RPC_AUTH_GSS_KRB5P, + .service = RPC_GSS_SVC_PRIVACY, + .name = "krb5p", + }, +}; + +static struct gss_api_mech gss_kerberos_mech = { + .gm_name = "krb5", + .gm_owner = THIS_MODULE, + .gm_oid = {9, (void *)"\x2a\x86\x48\x86\xf7\x12\x01\x02\x02"}, + .gm_ops = &gss_kerberos_ops, + .gm_pf_num = ARRAY_SIZE(gss_kerberos_pfs), + .gm_pfs = gss_kerberos_pfs, +}; + +static int __init init_kerberos_module(void) +{ + int status; + + status = gss_mech_register(&gss_kerberos_mech); + if (status) + printk("Failed to register kerberos gss mechanism!\n"); + return status; +} + +static void __exit cleanup_kerberos_module(void) +{ + gss_mech_unregister(&gss_kerberos_mech); +} + +MODULE_LICENSE("GPL"); +module_init(init_kerberos_module); +module_exit(cleanup_kerberos_module); diff --git a/net/sunrpc/auth_gss/gss_krb5_seal.c b/net/sunrpc/auth_gss/gss_krb5_seal.c new file mode 100644 index 0000000..b8f42ef --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seal.c @@ -0,0 +1,123 @@ +/* + * linux/net/sunrpc/gss_krb5_seal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5seal.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * J. Bruce Fields <bfields@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/random.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +DEFINE_SPINLOCK(krb5_seq_lock); + +u32 +gss_get_mic_kerberos(struct gss_ctx *gss_ctx, struct xdr_buf *text, + struct xdr_netobj *token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + unsigned char *ptr, *msg_start; + s32 now; + u32 seq_send; + + dprintk("RPC: gss_krb5_seal\n"); + BUG_ON(ctx == NULL); + + now = get_seconds(); + + token->len = g_token_size(&ctx->mech_used, GSS_KRB5_TOK_HDR_LEN + 8); + + ptr = token->data; + g_make_token_header(&ctx->mech_used, GSS_KRB5_TOK_HDR_LEN + 8, &ptr); + + /* ptr now at header described in rfc 1964, section 1.2.1: */ + ptr[0] = (unsigned char) ((KG_TOK_MIC_MSG >> 8) & 0xff); + ptr[1] = (unsigned char) (KG_TOK_MIC_MSG & 0xff); + + msg_start = ptr + GSS_KRB5_TOK_HDR_LEN + 8; + + *(__be16 *)(ptr + 2) = htons(SGN_ALG_DES_MAC_MD5); + memset(ptr + 4, 0xff, 4); + + if (make_checksum("md5", ptr, 8, text, 0, &md5cksum)) + return GSS_S_FAILURE; + + if (krb5_encrypt(ctx->seq, NULL, md5cksum.data, + md5cksum.data, md5cksum.len)) + return GSS_S_FAILURE; + + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data + md5cksum.len - 8, 8); + + spin_lock(&krb5_seq_lock); + seq_send = ctx->seq_send++; + spin_unlock(&krb5_seq_lock); + + if (krb5_make_seq_num(ctx->seq, ctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, + ptr + 8)) + return GSS_S_FAILURE; + + return (ctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_seqnum.c b/net/sunrpc/auth_gss/gss_krb5_seqnum.c new file mode 100644 index 0000000..f160be6 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_seqnum.c @@ -0,0 +1,88 @@ +/* + * linux/net/sunrpc/gss_krb5_seqnum.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/util_seqnum.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +s32 +krb5_make_seq_num(struct crypto_blkcipher *key, + int direction, + u32 seqnum, + unsigned char *cksum, unsigned char *buf) +{ + unsigned char plain[8]; + + plain[0] = (unsigned char) (seqnum & 0xff); + plain[1] = (unsigned char) ((seqnum >> 8) & 0xff); + plain[2] = (unsigned char) ((seqnum >> 16) & 0xff); + plain[3] = (unsigned char) ((seqnum >> 24) & 0xff); + + plain[4] = direction; + plain[5] = direction; + plain[6] = direction; + plain[7] = direction; + + return krb5_encrypt(key, cksum, plain, buf, 8); +} + +s32 +krb5_get_seq_num(struct crypto_blkcipher *key, + unsigned char *cksum, + unsigned char *buf, + int *direction, u32 *seqnum) +{ + s32 code; + unsigned char plain[8]; + + dprintk("RPC: krb5_get_seq_num:\n"); + + if ((code = krb5_decrypt(key, cksum, buf, plain, 8))) + return code; + + if ((plain[4] != plain[5]) || (plain[4] != plain[6]) + || (plain[4] != plain[7])) + return (s32)KG_BAD_SEQ; + + *direction = plain[4]; + + *seqnum = ((plain[0]) | + (plain[1] << 8) | (plain[2] << 16) | (plain[3] << 24)); + + return (0); +} diff --git a/net/sunrpc/auth_gss/gss_krb5_unseal.c b/net/sunrpc/auth_gss/gss_krb5_unseal.c new file mode 100644 index 0000000..066ec73 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_unseal.c @@ -0,0 +1,138 @@ +/* + * linux/net/sunrpc/gss_krb5_unseal.c + * + * Adapted from MIT Kerberos 5-1.2.1 lib/gssapi/krb5/k5unseal.c + * + * Copyright (c) 2000 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + */ + +/* + * Copyright 1993 by OpenVision Technologies, Inc. + * + * Permission to use, copy, modify, distribute, and sell this software + * and its documentation for any purpose is hereby granted without fee, + * provided that the above copyright notice appears in all copies and + * that both that copyright notice and this permission notice appear in + * supporting documentation, and that the name of OpenVision not be used + * in advertising or publicity pertaining to distribution of the software + * without specific, written prior permission. OpenVision makes no + * representations about the suitability of this software for any + * purpose. It is provided "as is" without express or implied warranty. + * + * OPENVISION DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, + * INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO + * EVENT SHALL OPENVISION BE LIABLE FOR ANY SPECIAL, INDIRECT OR + * CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF + * USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + * OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + * PERFORMANCE OF THIS SOFTWARE. + */ + +/* + * Copyright (C) 1998 by the FundsXpress, INC. + * + * All rights reserved. + * + * Export of this software from the United States of America may require + * a specific license from the United States Government. It is the + * responsibility of any person or organization contemplating export to + * obtain such a license before exporting. + * + * WITHIN THAT CONSTRAINT, permission to use, copy, modify, and + * distribute this software and its documentation for any purpose and + * without fee is hereby granted, provided that the above copyright + * notice appear in all copies and that both that copyright notice and + * this permission notice appear in supporting documentation, and that + * the name of FundsXpress. not be used in advertising or publicity pertaining + * to distribution of the software without specific, written prior + * permission. FundsXpress makes no representations about the suitability of + * this software for any purpose. It is provided "as is" without express + * or implied warranty. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND WITHOUT ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, WITHOUT LIMITATION, THE IMPLIED + * WARRANTIES OF MERCHANTIBILITY AND FITNESS FOR A PARTICULAR PURPOSE. + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + + +/* read_token is a mic token, and message_buffer is the data that the mic was + * supposedly taken over. */ + +u32 +gss_verify_mic_kerberos(struct gss_ctx *gss_ctx, + struct xdr_buf *message_buffer, struct xdr_netobj *read_token) +{ + struct krb5_ctx *ctx = gss_ctx->internal_ctx_id; + int signalg; + int sealalg; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + s32 now; + int direction; + u32 seqnum; + unsigned char *ptr = (unsigned char *)read_token->data; + int bodysize; + + dprintk("RPC: krb5_read_token\n"); + + if (g_verify_token_header(&ctx->mech_used, &bodysize, &ptr, + read_token->len)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_MIC_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_MIC_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != SGN_ALG_DES_MAC_MD5) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != SEAL_ALG_NONE) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + if (make_checksum("md5", ptr, 8, message_buffer, 0, &md5cksum)) + return GSS_S_FAILURE; + + if (krb5_encrypt(ctx->seq, NULL, md5cksum.data, md5cksum.data, 16)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data + 8, ptr + GSS_KRB5_TOK_HDR_LEN, 8)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + + now = get_seconds(); + + if (now > ctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + if (krb5_get_seq_num(ctx->seq, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, &direction, &seqnum)) + return GSS_S_FAILURE; + + if ((ctx->initiate && direction != 0xff) || + (!ctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + return GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_krb5_wrap.c b/net/sunrpc/auth_gss/gss_krb5_wrap.c new file mode 100644 index 0000000..ae8e69b --- /dev/null +++ b/net/sunrpc/auth_gss/gss_krb5_wrap.c @@ -0,0 +1,302 @@ +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_krb5.h> +#include <linux/random.h> +#include <linux/pagemap.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static inline int +gss_krb5_padding(int blocksize, int length) +{ + /* Most of the code is block-size independent but currently we + * use only 8: */ + BUG_ON(blocksize != 8); + return 8 - (length & 7); +} + +static inline void +gss_krb5_add_padding(struct xdr_buf *buf, int offset, int blocksize) +{ + int padding = gss_krb5_padding(blocksize, buf->len - offset); + char *p; + struct kvec *iov; + + if (buf->page_len || buf->tail[0].iov_len) + iov = &buf->tail[0]; + else + iov = &buf->head[0]; + p = iov->iov_base + iov->iov_len; + iov->iov_len += padding; + buf->len += padding; + memset(p, padding, padding); +} + +static inline int +gss_krb5_remove_padding(struct xdr_buf *buf, int blocksize) +{ + u8 *ptr; + u8 pad; + size_t len = buf->len; + + if (len <= buf->head[0].iov_len) { + pad = *(u8 *)(buf->head[0].iov_base + len - 1); + if (pad > buf->head[0].iov_len) + return -EINVAL; + buf->head[0].iov_len -= pad; + goto out; + } else + len -= buf->head[0].iov_len; + if (len <= buf->page_len) { + unsigned int last = (buf->page_base + len - 1) + >>PAGE_CACHE_SHIFT; + unsigned int offset = (buf->page_base + len - 1) + & (PAGE_CACHE_SIZE - 1); + ptr = kmap_atomic(buf->pages[last], KM_USER0); + pad = *(ptr + offset); + kunmap_atomic(ptr, KM_USER0); + goto out; + } else + len -= buf->page_len; + BUG_ON(len > buf->tail[0].iov_len); + pad = *(u8 *)(buf->tail[0].iov_base + len - 1); +out: + /* XXX: NOTE: we do not adjust the page lengths--they represent + * a range of data in the real filesystem page cache, and we need + * to know that range so the xdr code can properly place read data. + * However adjusting the head length, as we do above, is harmless. + * In the case of a request that fits into a single page, the server + * also uses length and head length together to determine the original + * start of the request to copy the request for deferal; so it's + * easier on the server if we adjust head and tail length in tandem. + * It's not really a problem that we don't fool with the page and + * tail lengths, though--at worst badly formed xdr might lead the + * server to attempt to parse the padding. + * XXX: Document all these weird requirements for gss mechanism + * wrap/unwrap functions. */ + if (pad > blocksize) + return -EINVAL; + if (buf->len > pad) + buf->len -= pad; + else + return -EINVAL; + return 0; +} + +static void +make_confounder(char *p, u32 conflen) +{ + static u64 i = 0; + u64 *q = (u64 *)p; + + /* rfc1964 claims this should be "random". But all that's really + * necessary is that it be unique. And not even that is necessary in + * our case since our "gssapi" implementation exists only to support + * rpcsec_gss, so we know that the only buffers we will ever encrypt + * already begin with a unique sequence number. Just to hedge my bets + * I'll make a half-hearted attempt at something unique, but ensuring + * uniqueness would mean worrying about atomicity and rollover, and I + * don't care enough. */ + + /* initialize to random value */ + if (i == 0) { + i = random32(); + i = (i << 32) | random32(); + } + + switch (conflen) { + case 16: + *q++ = i++; + /* fall through */ + case 8: + *q++ = i++; + break; + default: + BUG(); + } +} + +/* Assumptions: the head and tail of inbuf are ours to play with. + * The pages, however, may be real pages in the page cache and we replace + * them with scratch pages from **pages before writing to them. */ +/* XXX: obviously the above should be documentation of wrap interface, + * and shouldn't be in this kerberos-specific file. */ + +/* XXX factor out common code with seal/unseal. */ + +u32 +gss_wrap_kerberos(struct gss_ctx *ctx, int offset, + struct xdr_buf *buf, struct page **pages) +{ + struct krb5_ctx *kctx = ctx->internal_ctx_id; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + int blocksize = 0, plainlen; + unsigned char *ptr, *msg_start; + s32 now; + int headlen; + struct page **tmp_pages; + u32 seq_send; + + dprintk("RPC: gss_wrap_kerberos\n"); + + now = get_seconds(); + + blocksize = crypto_blkcipher_blocksize(kctx->enc); + gss_krb5_add_padding(buf, offset, blocksize); + BUG_ON((buf->len - offset) % blocksize); + plainlen = blocksize + buf->len - offset; + + headlen = g_token_size(&kctx->mech_used, 24 + plainlen) - + (buf->len - offset); + + ptr = buf->head[0].iov_base + offset; + /* shift data to make room for header. */ + /* XXX Would be cleverer to encrypt while copying. */ + /* XXX bounds checking, slack, etc. */ + memmove(ptr + headlen, ptr, buf->head[0].iov_len - offset); + buf->head[0].iov_len += headlen; + buf->len += headlen; + BUG_ON((buf->len - offset - headlen) % blocksize); + + g_make_token_header(&kctx->mech_used, + GSS_KRB5_TOK_HDR_LEN + 8 + plainlen, &ptr); + + + /* ptr now at header described in rfc 1964, section 1.2.1: */ + ptr[0] = (unsigned char) ((KG_TOK_WRAP_MSG >> 8) & 0xff); + ptr[1] = (unsigned char) (KG_TOK_WRAP_MSG & 0xff); + + msg_start = ptr + 24; + + *(__be16 *)(ptr + 2) = htons(SGN_ALG_DES_MAC_MD5); + memset(ptr + 4, 0xff, 4); + *(__be16 *)(ptr + 4) = htons(SEAL_ALG_DES); + + make_confounder(msg_start, blocksize); + + /* XXXJBF: UGH!: */ + tmp_pages = buf->pages; + buf->pages = pages; + if (make_checksum("md5", ptr, 8, buf, + offset + headlen - blocksize, &md5cksum)) + return GSS_S_FAILURE; + buf->pages = tmp_pages; + + if (krb5_encrypt(kctx->seq, NULL, md5cksum.data, + md5cksum.data, md5cksum.len)) + return GSS_S_FAILURE; + memcpy(ptr + GSS_KRB5_TOK_HDR_LEN, md5cksum.data + md5cksum.len - 8, 8); + + spin_lock(&krb5_seq_lock); + seq_send = kctx->seq_send++; + spin_unlock(&krb5_seq_lock); + + /* XXX would probably be more efficient to compute checksum + * and encrypt at the same time: */ + if ((krb5_make_seq_num(kctx->seq, kctx->initiate ? 0 : 0xff, + seq_send, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8))) + return GSS_S_FAILURE; + + if (gss_encrypt_xdr_buf(kctx->enc, buf, offset + headlen - blocksize, + pages)) + return GSS_S_FAILURE; + + return (kctx->endtime < now) ? GSS_S_CONTEXT_EXPIRED : GSS_S_COMPLETE; +} + +u32 +gss_unwrap_kerberos(struct gss_ctx *ctx, int offset, struct xdr_buf *buf) +{ + struct krb5_ctx *kctx = ctx->internal_ctx_id; + int signalg; + int sealalg; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + s32 now; + int direction; + s32 seqnum; + unsigned char *ptr; + int bodysize; + void *data_start, *orig_start; + int data_len; + int blocksize; + + dprintk("RPC: gss_unwrap_kerberos\n"); + + ptr = (u8 *)buf->head[0].iov_base + offset; + if (g_verify_token_header(&kctx->mech_used, &bodysize, &ptr, + buf->len - offset)) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[0] != ((KG_TOK_WRAP_MSG >> 8) & 0xff)) || + (ptr[1] != (KG_TOK_WRAP_MSG & 0xff))) + return GSS_S_DEFECTIVE_TOKEN; + + /* XXX sanity-check bodysize?? */ + + /* get the sign and seal algorithms */ + + signalg = ptr[2] + (ptr[3] << 8); + if (signalg != SGN_ALG_DES_MAC_MD5) + return GSS_S_DEFECTIVE_TOKEN; + + sealalg = ptr[4] + (ptr[5] << 8); + if (sealalg != SEAL_ALG_DES) + return GSS_S_DEFECTIVE_TOKEN; + + if ((ptr[6] != 0xff) || (ptr[7] != 0xff)) + return GSS_S_DEFECTIVE_TOKEN; + + if (gss_decrypt_xdr_buf(kctx->enc, buf, + ptr + GSS_KRB5_TOK_HDR_LEN + 8 - (unsigned char *)buf->head[0].iov_base)) + return GSS_S_DEFECTIVE_TOKEN; + + if (make_checksum("md5", ptr, 8, buf, + ptr + GSS_KRB5_TOK_HDR_LEN + 8 - (unsigned char *)buf->head[0].iov_base, &md5cksum)) + return GSS_S_FAILURE; + + if (krb5_encrypt(kctx->seq, NULL, md5cksum.data, + md5cksum.data, md5cksum.len)) + return GSS_S_FAILURE; + + if (memcmp(md5cksum.data + 8, ptr + GSS_KRB5_TOK_HDR_LEN, 8)) + return GSS_S_BAD_SIG; + + /* it got through unscathed. Make sure the context is unexpired */ + + now = get_seconds(); + + if (now > kctx->endtime) + return GSS_S_CONTEXT_EXPIRED; + + /* do sequencing checks */ + + if (krb5_get_seq_num(kctx->seq, ptr + GSS_KRB5_TOK_HDR_LEN, ptr + 8, + &direction, &seqnum)) + return GSS_S_BAD_SIG; + + if ((kctx->initiate && direction != 0xff) || + (!kctx->initiate && direction != 0)) + return GSS_S_BAD_SIG; + + /* Copy the data back to the right position. XXX: Would probably be + * better to copy and encrypt at the same time. */ + + blocksize = crypto_blkcipher_blocksize(kctx->enc); + data_start = ptr + GSS_KRB5_TOK_HDR_LEN + 8 + blocksize; + orig_start = buf->head[0].iov_base + offset; + data_len = (buf->head[0].iov_base + buf->head[0].iov_len) - data_start; + memmove(orig_start, data_start, data_len); + buf->head[0].iov_len -= (data_start - orig_start); + buf->len -= (data_start - orig_start); + + if (gss_krb5_remove_padding(buf, blocksize)) + return GSS_S_DEFECTIVE_TOKEN; + + return GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_mech_switch.c b/net/sunrpc/auth_gss/gss_mech_switch.c new file mode 100644 index 0000000..bce9d52 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_mech_switch.c @@ -0,0 +1,328 @@ +/* + * linux/net/sunrpc/gss_mech_switch.c + * + * Copyright (c) 2001 The Regents of the University of Michigan. + * All rights reserved. + * + * J. Bruce Fields <bfields@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/module.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/sunrpc/gss_asn1.h> +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/sunrpc/clnt.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static LIST_HEAD(registered_mechs); +static DEFINE_SPINLOCK(registered_mechs_lock); + +static void +gss_mech_free(struct gss_api_mech *gm) +{ + struct pf_desc *pf; + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + kfree(pf->auth_domain_name); + pf->auth_domain_name = NULL; + } +} + +static inline char * +make_auth_domain_name(char *name) +{ + static char *prefix = "gss/"; + char *new; + + new = kmalloc(strlen(name) + strlen(prefix) + 1, GFP_KERNEL); + if (new) { + strcpy(new, prefix); + strcat(new, name); + } + return new; +} + +static int +gss_mech_svc_setup(struct gss_api_mech *gm) +{ + struct pf_desc *pf; + int i, status; + + for (i = 0; i < gm->gm_pf_num; i++) { + pf = &gm->gm_pfs[i]; + pf->auth_domain_name = make_auth_domain_name(pf->name); + status = -ENOMEM; + if (pf->auth_domain_name == NULL) + goto out; + status = svcauth_gss_register_pseudoflavor(pf->pseudoflavor, + pf->auth_domain_name); + if (status) + goto out; + } + return 0; +out: + gss_mech_free(gm); + return status; +} + +int +gss_mech_register(struct gss_api_mech *gm) +{ + int status; + + status = gss_mech_svc_setup(gm); + if (status) + return status; + spin_lock(®istered_mechs_lock); + list_add(&gm->gm_list, ®istered_mechs); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: registered gss mechanism %s\n", gm->gm_name); + return 0; +} + +EXPORT_SYMBOL(gss_mech_register); + +void +gss_mech_unregister(struct gss_api_mech *gm) +{ + spin_lock(®istered_mechs_lock); + list_del(&gm->gm_list); + spin_unlock(®istered_mechs_lock); + dprintk("RPC: unregistered gss mechanism %s\n", gm->gm_name); + gss_mech_free(gm); +} + +EXPORT_SYMBOL(gss_mech_unregister); + +struct gss_api_mech * +gss_mech_get(struct gss_api_mech *gm) +{ + __module_get(gm->gm_owner); + return gm; +} + +EXPORT_SYMBOL(gss_mech_get); + +struct gss_api_mech * +gss_mech_get_by_name(const char *name) +{ + struct gss_api_mech *pos, *gm = NULL; + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + if (0 == strcmp(name, pos->gm_name)) { + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + } + spin_unlock(®istered_mechs_lock); + return gm; + +} + +EXPORT_SYMBOL(gss_mech_get_by_name); + +static inline int +mech_supports_pseudoflavor(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return 1; + } + return 0; +} + +struct gss_api_mech * +gss_mech_get_by_pseudoflavor(u32 pseudoflavor) +{ + struct gss_api_mech *pos, *gm = NULL; + + spin_lock(®istered_mechs_lock); + list_for_each_entry(pos, ®istered_mechs, gm_list) { + if (!mech_supports_pseudoflavor(pos, pseudoflavor)) { + module_put(pos->gm_owner); + continue; + } + if (try_module_get(pos->gm_owner)) + gm = pos; + break; + } + spin_unlock(®istered_mechs_lock); + return gm; +} + +EXPORT_SYMBOL(gss_mech_get_by_pseudoflavor); + +u32 +gss_svc_to_pseudoflavor(struct gss_api_mech *gm, u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].service == service) { + return gm->gm_pfs[i].pseudoflavor; + } + } + return RPC_AUTH_MAXFLAVOR; /* illegal value */ +} +EXPORT_SYMBOL(gss_svc_to_pseudoflavor); + +u32 +gss_pseudoflavor_to_service(struct gss_api_mech *gm, u32 pseudoflavor) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].pseudoflavor == pseudoflavor) + return gm->gm_pfs[i].service; + } + return 0; +} + +EXPORT_SYMBOL(gss_pseudoflavor_to_service); + +char * +gss_service_to_auth_domain_name(struct gss_api_mech *gm, u32 service) +{ + int i; + + for (i = 0; i < gm->gm_pf_num; i++) { + if (gm->gm_pfs[i].service == service) + return gm->gm_pfs[i].auth_domain_name; + } + return NULL; +} + +EXPORT_SYMBOL(gss_service_to_auth_domain_name); + +void +gss_mech_put(struct gss_api_mech * gm) +{ + if (gm) + module_put(gm->gm_owner); +} + +EXPORT_SYMBOL(gss_mech_put); + +/* The mech could probably be determined from the token instead, but it's just + * as easy for now to pass it in. */ +int +gss_import_sec_context(const void *input_token, size_t bufsize, + struct gss_api_mech *mech, + struct gss_ctx **ctx_id) +{ + if (!(*ctx_id = kzalloc(sizeof(**ctx_id), GFP_KERNEL))) + return GSS_S_FAILURE; + (*ctx_id)->mech_type = gss_mech_get(mech); + + return mech->gm_ops + ->gss_import_sec_context(input_token, bufsize, *ctx_id); +} + +/* gss_get_mic: compute a mic over message and return mic_token. */ + +u32 +gss_get_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_get_mic(context_handle, + message, + mic_token); +} + +/* gss_verify_mic: check whether the provided mic_token verifies message. */ + +u32 +gss_verify_mic(struct gss_ctx *context_handle, + struct xdr_buf *message, + struct xdr_netobj *mic_token) +{ + return context_handle->mech_type->gm_ops + ->gss_verify_mic(context_handle, + message, + mic_token); +} + +u32 +gss_wrap(struct gss_ctx *ctx_id, + int offset, + struct xdr_buf *buf, + struct page **inpages) +{ + return ctx_id->mech_type->gm_ops + ->gss_wrap(ctx_id, offset, buf, inpages); +} + +u32 +gss_unwrap(struct gss_ctx *ctx_id, + int offset, + struct xdr_buf *buf) +{ + return ctx_id->mech_type->gm_ops + ->gss_unwrap(ctx_id, offset, buf); +} + + +/* gss_delete_sec_context: free all resources associated with context_handle. + * Note this differs from the RFC 2744-specified prototype in that we don't + * bother returning an output token, since it would never be used anyway. */ + +u32 +gss_delete_sec_context(struct gss_ctx **context_handle) +{ + dprintk("RPC: gss_delete_sec_context deleting %p\n", + *context_handle); + + if (!*context_handle) + return(GSS_S_NO_CONTEXT); + if ((*context_handle)->internal_ctx_id) + (*context_handle)->mech_type->gm_ops + ->gss_delete_sec_context((*context_handle) + ->internal_ctx_id); + gss_mech_put((*context_handle)->mech_type); + kfree(*context_handle); + *context_handle=NULL; + return GSS_S_COMPLETE; +} diff --git a/net/sunrpc/auth_gss/gss_spkm3_mech.c b/net/sunrpc/auth_gss/gss_spkm3_mech.c new file mode 100644 index 0000000..035e1dd --- /dev/null +++ b/net/sunrpc/auth_gss/gss_spkm3_mech.c @@ -0,0 +1,243 @@ +/* + * linux/net/sunrpc/gss_spkm3_mech.c + * + * Copyright (c) 2003 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * J. Bruce Fields <bfields@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/err.h> +#include <linux/module.h> +#include <linux/init.h> +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/sunrpc/auth.h> +#include <linux/in.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/gss_spkm3.h> +#include <linux/sunrpc/xdr.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static const void * +simple_get_bytes(const void *p, const void *end, void *res, int len) +{ + const void *q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + memcpy(res, p, len); + return q; +} + +static const void * +simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res) +{ + const void *q; + unsigned int len; + p = simple_get_bytes(p, end, &len, sizeof(len)); + if (IS_ERR(p)) + return p; + res->len = len; + if (len == 0) { + res->data = NULL; + return p; + } + q = (const void *)((const char *)p + len); + if (unlikely(q > end || q < p)) + return ERR_PTR(-EFAULT); + res->data = kmemdup(p, len, GFP_NOFS); + if (unlikely(res->data == NULL)) + return ERR_PTR(-ENOMEM); + return q; +} + +static int +gss_import_sec_context_spkm3(const void *p, size_t len, + struct gss_ctx *ctx_id) +{ + const void *end = (const void *)((const char *)p + len); + struct spkm3_ctx *ctx; + int version; + + if (!(ctx = kzalloc(sizeof(*ctx), GFP_NOFS))) + goto out_err; + + p = simple_get_bytes(p, end, &version, sizeof(version)); + if (IS_ERR(p)) + goto out_err_free_ctx; + if (version != 1) { + dprintk("RPC: unknown spkm3 token format: " + "obsolete nfs-utils?\n"); + goto out_err_free_ctx; + } + + p = simple_get_netobj(p, end, &ctx->ctx_id); + if (IS_ERR(p)) + goto out_err_free_ctx; + + p = simple_get_bytes(p, end, &ctx->endtime, sizeof(ctx->endtime)); + if (IS_ERR(p)) + goto out_err_free_ctx_id; + + p = simple_get_netobj(p, end, &ctx->mech_used); + if (IS_ERR(p)) + goto out_err_free_ctx_id; + + p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags)); + if (IS_ERR(p)) + goto out_err_free_mech; + + p = simple_get_netobj(p, end, &ctx->conf_alg); + if (IS_ERR(p)) + goto out_err_free_mech; + + p = simple_get_netobj(p, end, &ctx->derived_conf_key); + if (IS_ERR(p)) + goto out_err_free_conf_alg; + + p = simple_get_netobj(p, end, &ctx->intg_alg); + if (IS_ERR(p)) + goto out_err_free_conf_key; + + p = simple_get_netobj(p, end, &ctx->derived_integ_key); + if (IS_ERR(p)) + goto out_err_free_intg_alg; + + if (p != end) + goto out_err_free_intg_key; + + ctx_id->internal_ctx_id = ctx; + + dprintk("RPC: Successfully imported new spkm context.\n"); + return 0; + +out_err_free_intg_key: + kfree(ctx->derived_integ_key.data); +out_err_free_intg_alg: + kfree(ctx->intg_alg.data); +out_err_free_conf_key: + kfree(ctx->derived_conf_key.data); +out_err_free_conf_alg: + kfree(ctx->conf_alg.data); +out_err_free_mech: + kfree(ctx->mech_used.data); +out_err_free_ctx_id: + kfree(ctx->ctx_id.data); +out_err_free_ctx: + kfree(ctx); +out_err: + return PTR_ERR(p); +} + +static void +gss_delete_sec_context_spkm3(void *internal_ctx) +{ + struct spkm3_ctx *sctx = internal_ctx; + + kfree(sctx->derived_integ_key.data); + kfree(sctx->intg_alg.data); + kfree(sctx->derived_conf_key.data); + kfree(sctx->conf_alg.data); + kfree(sctx->mech_used.data); + kfree(sctx->ctx_id.data); + kfree(sctx); +} + +static u32 +gss_verify_mic_spkm3(struct gss_ctx *ctx, + struct xdr_buf *signbuf, + struct xdr_netobj *checksum) +{ + u32 maj_stat = 0; + struct spkm3_ctx *sctx = ctx->internal_ctx_id; + + maj_stat = spkm3_read_token(sctx, checksum, signbuf, SPKM_MIC_TOK); + + dprintk("RPC: gss_verify_mic_spkm3 returning %d\n", maj_stat); + return maj_stat; +} + +static u32 +gss_get_mic_spkm3(struct gss_ctx *ctx, + struct xdr_buf *message_buffer, + struct xdr_netobj *message_token) +{ + u32 err = 0; + struct spkm3_ctx *sctx = ctx->internal_ctx_id; + + err = spkm3_make_token(sctx, message_buffer, + message_token, SPKM_MIC_TOK); + dprintk("RPC: gss_get_mic_spkm3 returning %d\n", err); + return err; +} + +static const struct gss_api_ops gss_spkm3_ops = { + .gss_import_sec_context = gss_import_sec_context_spkm3, + .gss_get_mic = gss_get_mic_spkm3, + .gss_verify_mic = gss_verify_mic_spkm3, + .gss_delete_sec_context = gss_delete_sec_context_spkm3, +}; + +static struct pf_desc gss_spkm3_pfs[] = { + {RPC_AUTH_GSS_SPKM, RPC_GSS_SVC_NONE, "spkm3"}, + {RPC_AUTH_GSS_SPKMI, RPC_GSS_SVC_INTEGRITY, "spkm3i"}, +}; + +static struct gss_api_mech gss_spkm3_mech = { + .gm_name = "spkm3", + .gm_owner = THIS_MODULE, + .gm_oid = {7, "\053\006\001\005\005\001\003"}, + .gm_ops = &gss_spkm3_ops, + .gm_pf_num = ARRAY_SIZE(gss_spkm3_pfs), + .gm_pfs = gss_spkm3_pfs, +}; + +static int __init init_spkm3_module(void) +{ + int status; + + status = gss_mech_register(&gss_spkm3_mech); + if (status) + printk("Failed to register spkm3 gss mechanism!\n"); + return status; +} + +static void __exit cleanup_spkm3_module(void) +{ + gss_mech_unregister(&gss_spkm3_mech); +} + +MODULE_LICENSE("GPL"); +module_init(init_spkm3_module); +module_exit(cleanup_spkm3_module); diff --git a/net/sunrpc/auth_gss/gss_spkm3_seal.c b/net/sunrpc/auth_gss/gss_spkm3_seal.c new file mode 100644 index 0000000..c832712 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_spkm3_seal.c @@ -0,0 +1,187 @@ +/* + * linux/net/sunrpc/gss_spkm3_seal.c + * + * Copyright (c) 2003 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_spkm3.h> +#include <linux/random.h> +#include <linux/crypto.h> +#include <linux/pagemap.h> +#include <linux/scatterlist.h> +#include <linux/sunrpc/xdr.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +const struct xdr_netobj hmac_md5_oid = { 8, "\x2B\x06\x01\x05\x05\x08\x01\x01"}; +const struct xdr_netobj cast5_cbc_oid = {9, "\x2A\x86\x48\x86\xF6\x7D\x07\x42\x0A"}; + +/* + * spkm3_make_token() + * + * Only SPKM_MIC_TOK with md5 intg-alg is supported + */ + +u32 +spkm3_make_token(struct spkm3_ctx *ctx, + struct xdr_buf * text, struct xdr_netobj * token, + int toktype) +{ + s32 checksum_type; + char tokhdrbuf[25]; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + struct xdr_netobj mic_hdr = {.len = 0, .data = tokhdrbuf}; + int tokenlen = 0; + unsigned char *ptr; + s32 now; + int ctxelen = 0, ctxzbit = 0; + int md5elen = 0, md5zbit = 0; + + now = jiffies; + + if (ctx->ctx_id.len != 16) { + dprintk("RPC: spkm3_make_token BAD ctx_id.len %d\n", + ctx->ctx_id.len); + goto out_err; + } + + if (!g_OID_equal(&ctx->intg_alg, &hmac_md5_oid)) { + dprintk("RPC: gss_spkm3_seal: unsupported I-ALG " + "algorithm. only support hmac-md5 I-ALG.\n"); + goto out_err; + } else + checksum_type = CKSUMTYPE_HMAC_MD5; + + if (!g_OID_equal(&ctx->conf_alg, &cast5_cbc_oid)) { + dprintk("RPC: gss_spkm3_seal: unsupported C-ALG " + "algorithm\n"); + goto out_err; + } + + if (toktype == SPKM_MIC_TOK) { + /* Calculate checksum over the mic-header */ + asn1_bitstring_len(&ctx->ctx_id, &ctxelen, &ctxzbit); + spkm3_mic_header(&mic_hdr.data, &mic_hdr.len, ctx->ctx_id.data, + ctxelen, ctxzbit); + if (make_spkm3_checksum(checksum_type, &ctx->derived_integ_key, + (char *)mic_hdr.data, mic_hdr.len, + text, 0, &md5cksum)) + goto out_err; + + asn1_bitstring_len(&md5cksum, &md5elen, &md5zbit); + tokenlen = 10 + ctxelen + 1 + md5elen + 1; + + /* Create token header using generic routines */ + token->len = g_token_size(&ctx->mech_used, tokenlen + 2); + + ptr = token->data; + g_make_token_header(&ctx->mech_used, tokenlen + 2, &ptr); + + spkm3_make_mic_token(&ptr, tokenlen, &mic_hdr, &md5cksum, md5elen, md5zbit); + } else if (toktype == SPKM_WRAP_TOK) { /* Not Supported */ + dprintk("RPC: gss_spkm3_seal: SPKM_WRAP_TOK " + "not supported\n"); + goto out_err; + } + + /* XXX need to implement sequence numbers, and ctx->expired */ + + return GSS_S_COMPLETE; +out_err: + token->data = NULL; + token->len = 0; + return GSS_S_FAILURE; +} + +static int +spkm3_checksummer(struct scatterlist *sg, void *data) +{ + struct hash_desc *desc = data; + + return crypto_hash_update(desc, sg, sg->length); +} + +/* checksum the plaintext data and hdrlen bytes of the token header */ +s32 +make_spkm3_checksum(s32 cksumtype, struct xdr_netobj *key, char *header, + unsigned int hdrlen, struct xdr_buf *body, + unsigned int body_offset, struct xdr_netobj *cksum) +{ + char *cksumname; + struct hash_desc desc; /* XXX add to ctx? */ + struct scatterlist sg[1]; + int err; + + switch (cksumtype) { + case CKSUMTYPE_HMAC_MD5: + cksumname = "hmac(md5)"; + break; + default: + dprintk("RPC: spkm3_make_checksum:" + " unsupported checksum %d", cksumtype); + return GSS_S_FAILURE; + } + + if (key->data == NULL || key->len <= 0) return GSS_S_FAILURE; + + desc.tfm = crypto_alloc_hash(cksumname, 0, CRYPTO_ALG_ASYNC); + if (IS_ERR(desc.tfm)) + return GSS_S_FAILURE; + cksum->len = crypto_hash_digestsize(desc.tfm); + desc.flags = CRYPTO_TFM_REQ_MAY_SLEEP; + + err = crypto_hash_setkey(desc.tfm, key->data, key->len); + if (err) + goto out; + + err = crypto_hash_init(&desc); + if (err) + goto out; + + sg_init_one(sg, header, hdrlen); + crypto_hash_update(&desc, sg, sg->length); + + xdr_process_buf(body, body_offset, body->len - body_offset, + spkm3_checksummer, &desc); + crypto_hash_final(&desc, cksum->data); + +out: + crypto_free_hash(desc.tfm); + + return err ? GSS_S_FAILURE : 0; +} diff --git a/net/sunrpc/auth_gss/gss_spkm3_token.c b/net/sunrpc/auth_gss/gss_spkm3_token.c new file mode 100644 index 0000000..3308157 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_spkm3_token.c @@ -0,0 +1,267 @@ +/* + * linux/net/sunrpc/gss_spkm3_token.c + * + * Copyright (c) 2003 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_spkm3.h> +#include <linux/random.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* + * asn1_bitstring_len() + * + * calculate the asn1 bitstring length of the xdr_netobject + */ +void +asn1_bitstring_len(struct xdr_netobj *in, int *enclen, int *zerobits) +{ + int i, zbit = 0,elen = in->len; + char *ptr; + + ptr = &in->data[in->len -1]; + + /* count trailing 0's */ + for(i = in->len; i > 0; i--) { + if (*ptr == 0) { + ptr--; + elen--; + } else + break; + } + + /* count number of 0 bits in final octet */ + ptr = &in->data[elen - 1]; + for(i = 0; i < 8; i++) { + short mask = 0x01; + + if (!((mask << i) & *ptr)) + zbit++; + else + break; + } + *enclen = elen; + *zerobits = zbit; +} + +/* + * decode_asn1_bitstring() + * + * decode a bitstring into a buffer of the expected length. + * enclen = bit string length + * explen = expected length (define in rfc) + */ +int +decode_asn1_bitstring(struct xdr_netobj *out, char *in, int enclen, int explen) +{ + if (!(out->data = kzalloc(explen,GFP_NOFS))) + return 0; + out->len = explen; + memcpy(out->data, in, enclen); + return 1; +} + +/* + * SPKMInnerContextToken choice SPKM_MIC asn1 token layout + * + * contextid is always 16 bytes plain data. max asn1 bitstring len = 17. + * + * tokenlen = pos[0] to end of token (max pos[45] with MD5 cksum) + * + * pos value + * ---------- + * [0] a4 SPKM-MIC tag + * [1] ?? innertoken length (max 44) + * + * + * tok_hdr piece of checksum data starts here + * + * the maximum mic-header len = 9 + 17 = 26 + * mic-header + * ---------- + * [2] 30 SEQUENCE tag + * [3] ?? mic-header length: (max 23) = TokenID + ContextID + * + * TokenID - all fields constant and can be hardcoded + * ------- + * [4] 02 Type 2 + * [5] 02 Length 2 + * [6][7] 01 01 TokenID (SPKM_MIC_TOK) + * + * ContextID - encoded length not constant, calculated + * --------- + * [8] 03 Type 3 + * [9] ?? encoded length + * [10] ?? ctxzbit + * [11] contextid + * + * mic_header piece of checksum data ends here. + * + * int-cksum - encoded length not constant, calculated + * --------- + * [??] 03 Type 3 + * [??] ?? encoded length + * [??] ?? md5zbit + * [??] int-cksum (NID_md5 = 16) + * + * maximum SPKM-MIC innercontext token length = + * 10 + encoded contextid_size(17 max) + 2 + encoded + * cksum_size (17 maxfor NID_md5) = 46 + */ + +/* + * spkm3_mic_header() + * + * Prepare the SPKM_MIC_TOK mic-header for check-sum calculation + * elen: 16 byte context id asn1 bitstring encoded length + */ +void +spkm3_mic_header(unsigned char **hdrbuf, unsigned int *hdrlen, unsigned char *ctxdata, int elen, int zbit) +{ + char *hptr = *hdrbuf; + char *top = *hdrbuf; + + *(u8 *)hptr++ = 0x30; + *(u8 *)hptr++ = elen + 7; /* on the wire header length */ + + /* tokenid */ + *(u8 *)hptr++ = 0x02; + *(u8 *)hptr++ = 0x02; + *(u8 *)hptr++ = 0x01; + *(u8 *)hptr++ = 0x01; + + /* coniextid */ + *(u8 *)hptr++ = 0x03; + *(u8 *)hptr++ = elen + 1; /* add 1 to include zbit */ + *(u8 *)hptr++ = zbit; + memcpy(hptr, ctxdata, elen); + hptr += elen; + *hdrlen = hptr - top; +} + +/* + * spkm3_mic_innercontext_token() + * + * *tokp points to the beginning of the SPKM_MIC token described + * in rfc 2025, section 3.2.1: + * + * toklen is the inner token length + */ +void +spkm3_make_mic_token(unsigned char **tokp, int toklen, struct xdr_netobj *mic_hdr, struct xdr_netobj *md5cksum, int md5elen, int md5zbit) +{ + unsigned char *ict = *tokp; + + *(u8 *)ict++ = 0xa4; + *(u8 *)ict++ = toklen; + memcpy(ict, mic_hdr->data, mic_hdr->len); + ict += mic_hdr->len; + + *(u8 *)ict++ = 0x03; + *(u8 *)ict++ = md5elen + 1; /* add 1 to include zbit */ + *(u8 *)ict++ = md5zbit; + memcpy(ict, md5cksum->data, md5elen); +} + +u32 +spkm3_verify_mic_token(unsigned char **tokp, int *mic_hdrlen, unsigned char **cksum) +{ + struct xdr_netobj spkm3_ctx_id = {.len =0, .data = NULL}; + unsigned char *ptr = *tokp; + int ctxelen; + u32 ret = GSS_S_DEFECTIVE_TOKEN; + + /* spkm3 innercontext token preamble */ + if ((ptr[0] != 0xa4) || (ptr[2] != 0x30)) { + dprintk("RPC: BAD SPKM ictoken preamble\n"); + goto out; + } + + *mic_hdrlen = ptr[3]; + + /* token type */ + if ((ptr[4] != 0x02) || (ptr[5] != 0x02)) { + dprintk("RPC: BAD asn1 SPKM3 token type\n"); + goto out; + } + + /* only support SPKM_MIC_TOK */ + if((ptr[6] != 0x01) || (ptr[7] != 0x01)) { + dprintk("RPC: ERROR unsupported SPKM3 token \n"); + goto out; + } + + /* contextid */ + if (ptr[8] != 0x03) { + dprintk("RPC: BAD SPKM3 asn1 context-id type\n"); + goto out; + } + + ctxelen = ptr[9]; + if (ctxelen > 17) { /* length includes asn1 zbit octet */ + dprintk("RPC: BAD SPKM3 contextid len %d\n", ctxelen); + goto out; + } + + /* ignore ptr[10] */ + + if(!decode_asn1_bitstring(&spkm3_ctx_id, &ptr[11], ctxelen - 1, 16)) + goto out; + + /* + * in the current implementation: the optional int-alg is not present + * so the default int-alg (md5) is used the optional snd-seq field is + * also not present + */ + + if (*mic_hdrlen != 6 + ctxelen) { + dprintk("RPC: BAD SPKM_ MIC_TOK header len %d: we only " + "support default int-alg (should be absent) " + "and do not support snd-seq\n", *mic_hdrlen); + goto out; + } + /* checksum */ + *cksum = (&ptr[10] + ctxelen); /* ctxelen includes ptr[10] */ + + ret = GSS_S_COMPLETE; +out: + kfree(spkm3_ctx_id.data); + return ret; +} + diff --git a/net/sunrpc/auth_gss/gss_spkm3_unseal.c b/net/sunrpc/auth_gss/gss_spkm3_unseal.c new file mode 100644 index 0000000..cc21ee8 --- /dev/null +++ b/net/sunrpc/auth_gss/gss_spkm3_unseal.c @@ -0,0 +1,127 @@ +/* + * linux/net/sunrpc/gss_spkm3_unseal.c + * + * Copyright (c) 2003 The Regents of the University of Michigan. + * All rights reserved. + * + * Andy Adamson <andros@umich.edu> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * 3. Neither the name of the University nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/jiffies.h> +#include <linux/sunrpc/gss_spkm3.h> +#include <linux/crypto.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* + * spkm3_read_token() + * + * only SPKM_MIC_TOK with md5 intg-alg is supported + */ +u32 +spkm3_read_token(struct spkm3_ctx *ctx, + struct xdr_netobj *read_token, /* checksum */ + struct xdr_buf *message_buffer, /* signbuf */ + int toktype) +{ + s32 checksum_type; + s32 code; + struct xdr_netobj wire_cksum = {.len =0, .data = NULL}; + char cksumdata[16]; + struct xdr_netobj md5cksum = {.len = 0, .data = cksumdata}; + unsigned char *ptr = (unsigned char *)read_token->data; + unsigned char *cksum; + int bodysize, md5elen; + int mic_hdrlen; + u32 ret = GSS_S_DEFECTIVE_TOKEN; + + if (g_verify_token_header((struct xdr_netobj *) &ctx->mech_used, + &bodysize, &ptr, read_token->len)) + goto out; + + /* decode the token */ + + if (toktype != SPKM_MIC_TOK) { + dprintk("RPC: BAD SPKM3 token type: %d\n", toktype); + goto out; + } + + if ((ret = spkm3_verify_mic_token(&ptr, &mic_hdrlen, &cksum))) + goto out; + + if (*cksum++ != 0x03) { + dprintk("RPC: spkm3_read_token BAD checksum type\n"); + goto out; + } + md5elen = *cksum++; + cksum++; /* move past the zbit */ + + if (!decode_asn1_bitstring(&wire_cksum, cksum, md5elen - 1, 16)) + goto out; + + /* HARD CODED FOR MD5 */ + + /* compute the checksum of the message. + * ptr + 2 = start of header piece of checksum + * mic_hdrlen + 2 = length of header piece of checksum + */ + ret = GSS_S_DEFECTIVE_TOKEN; + if (!g_OID_equal(&ctx->intg_alg, &hmac_md5_oid)) { + dprintk("RPC: gss_spkm3_seal: unsupported I-ALG " + "algorithm\n"); + goto out; + } + + checksum_type = CKSUMTYPE_HMAC_MD5; + + code = make_spkm3_checksum(checksum_type, + &ctx->derived_integ_key, ptr + 2, mic_hdrlen + 2, + message_buffer, 0, &md5cksum); + + if (code) + goto out; + + ret = GSS_S_BAD_SIG; + code = memcmp(md5cksum.data, wire_cksum.data, wire_cksum.len); + if (code) { + dprintk("RPC: bad MIC checksum\n"); + goto out; + } + + + /* XXX: need to add expiration and sequencing */ + ret = GSS_S_COMPLETE; +out: + kfree(wire_cksum.data); + return ret; +} diff --git a/net/sunrpc/auth_gss/svcauth_gss.c b/net/sunrpc/auth_gss/svcauth_gss.c new file mode 100644 index 0000000..81ae3d6 --- /dev/null +++ b/net/sunrpc/auth_gss/svcauth_gss.c @@ -0,0 +1,1416 @@ +/* + * Neil Brown <neilb@cse.unsw.edu.au> + * J. Bruce Fields <bfields@umich.edu> + * Andy Adamson <andros@umich.edu> + * Dug Song <dugsong@monkey.org> + * + * RPCSEC_GSS server authentication. + * This implements RPCSEC_GSS as defined in rfc2203 (rpcsec_gss) and rfc2078 + * (gssapi) + * + * The RPCSEC_GSS involves three stages: + * 1/ context creation + * 2/ data exchange + * 3/ context destruction + * + * Context creation is handled largely by upcalls to user-space. + * In particular, GSS_Accept_sec_context is handled by an upcall + * Data exchange is handled entirely within the kernel + * In particular, GSS_GetMIC, GSS_VerifyMIC, GSS_Seal, GSS_Unseal are in-kernel. + * Context destruction is handled in-kernel + * GSS_Delete_sec_context is in-kernel + * + * Context creation is initiated by a RPCSEC_GSS_INIT request arriving. + * The context handle and gss_token are used as a key into the rpcsec_init cache. + * The content of this cache includes some of the outputs of GSS_Accept_sec_context, + * being major_status, minor_status, context_handle, reply_token. + * These are sent back to the client. + * Sequence window management is handled by the kernel. The window size if currently + * a compile time constant. + * + * When user-space is happy that a context is established, it places an entry + * in the rpcsec_context cache. The key for this cache is the context_handle. + * The content includes: + * uid/gidlist - for determining access rights + * mechanism type + * mechanism specific information, such as a key + * + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/pagemap.h> + +#include <linux/sunrpc/auth_gss.h> +#include <linux/sunrpc/gss_err.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/sunrpc/svcauth_gss.h> +#include <linux/sunrpc/cache.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +/* The rpcsec_init cache is used for mapping RPCSEC_GSS_{,CONT_}INIT requests + * into replies. + * + * Key is context handle (\x if empty) and gss_token. + * Content is major_status minor_status (integers) context_handle, reply_token. + * + */ + +static int netobj_equal(struct xdr_netobj *a, struct xdr_netobj *b) +{ + return a->len == b->len && 0 == memcmp(a->data, b->data, a->len); +} + +#define RSI_HASHBITS 6 +#define RSI_HASHMAX (1<<RSI_HASHBITS) +#define RSI_HASHMASK (RSI_HASHMAX-1) + +struct rsi { + struct cache_head h; + struct xdr_netobj in_handle, in_token; + struct xdr_netobj out_handle, out_token; + int major_status, minor_status; +}; + +static struct cache_head *rsi_table[RSI_HASHMAX]; +static struct cache_detail rsi_cache; +static struct rsi *rsi_update(struct rsi *new, struct rsi *old); +static struct rsi *rsi_lookup(struct rsi *item); + +static void rsi_free(struct rsi *rsii) +{ + kfree(rsii->in_handle.data); + kfree(rsii->in_token.data); + kfree(rsii->out_handle.data); + kfree(rsii->out_token.data); +} + +static void rsi_put(struct kref *ref) +{ + struct rsi *rsii = container_of(ref, struct rsi, h.ref); + rsi_free(rsii); + kfree(rsii); +} + +static inline int rsi_hash(struct rsi *item) +{ + return hash_mem(item->in_handle.data, item->in_handle.len, RSI_HASHBITS) + ^ hash_mem(item->in_token.data, item->in_token.len, RSI_HASHBITS); +} + +static int rsi_match(struct cache_head *a, struct cache_head *b) +{ + struct rsi *item = container_of(a, struct rsi, h); + struct rsi *tmp = container_of(b, struct rsi, h); + return netobj_equal(&item->in_handle, &tmp->in_handle) + && netobj_equal(&item->in_token, &tmp->in_token); +} + +static int dup_to_netobj(struct xdr_netobj *dst, char *src, int len) +{ + dst->len = len; + dst->data = (len ? kmemdup(src, len, GFP_KERNEL) : NULL); + if (len && !dst->data) + return -ENOMEM; + return 0; +} + +static inline int dup_netobj(struct xdr_netobj *dst, struct xdr_netobj *src) +{ + return dup_to_netobj(dst, src->data, src->len); +} + +static void rsi_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + new->out_handle.data = NULL; + new->out_handle.len = 0; + new->out_token.data = NULL; + new->out_token.len = 0; + new->in_handle.len = item->in_handle.len; + item->in_handle.len = 0; + new->in_token.len = item->in_token.len; + item->in_token.len = 0; + new->in_handle.data = item->in_handle.data; + item->in_handle.data = NULL; + new->in_token.data = item->in_token.data; + item->in_token.data = NULL; +} + +static void update_rsi(struct cache_head *cnew, struct cache_head *citem) +{ + struct rsi *new = container_of(cnew, struct rsi, h); + struct rsi *item = container_of(citem, struct rsi, h); + + BUG_ON(new->out_handle.data || new->out_token.data); + new->out_handle.len = item->out_handle.len; + item->out_handle.len = 0; + new->out_token.len = item->out_token.len; + item->out_token.len = 0; + new->out_handle.data = item->out_handle.data; + item->out_handle.data = NULL; + new->out_token.data = item->out_token.data; + item->out_token.data = NULL; + + new->major_status = item->major_status; + new->minor_status = item->minor_status; +} + +static struct cache_head *rsi_alloc(void) +{ + struct rsi *rsii = kmalloc(sizeof(*rsii), GFP_KERNEL); + if (rsii) + return &rsii->h; + else + return NULL; +} + +static void rsi_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + struct rsi *rsii = container_of(h, struct rsi, h); + + qword_addhex(bpp, blen, rsii->in_handle.data, rsii->in_handle.len); + qword_addhex(bpp, blen, rsii->in_token.data, rsii->in_token.len); + (*bpp)[-1] = '\n'; +} + + +static int rsi_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* context token expiry major minor context token */ + char *buf = mesg; + char *ep; + int len; + struct rsi rsii, *rsip = NULL; + time_t expiry; + int status = -EINVAL; + + memset(&rsii, 0, sizeof(rsii)); + /* handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_handle, buf, len)) + goto out; + + /* token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.in_token, buf, len)) + goto out; + + rsip = rsi_lookup(&rsii); + if (!rsip) + goto out; + + rsii.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + /* major/minor */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.major_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + len = qword_get(&mesg, buf, mlen); + if (len <= 0) + goto out; + rsii.minor_status = simple_strtoul(buf, &ep, 10); + if (*ep) + goto out; + + /* out_handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_handle, buf, len)) + goto out; + + /* out_token */ + len = qword_get(&mesg, buf, mlen); + status = -EINVAL; + if (len < 0) + goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsii.out_token, buf, len)) + goto out; + rsii.h.expiry_time = expiry; + rsip = rsi_update(&rsii, rsip); + status = 0; +out: + rsi_free(&rsii); + if (rsip) + cache_put(&rsip->h, &rsi_cache); + else + status = -ENOMEM; + return status; +} + +static struct cache_detail rsi_cache = { + .owner = THIS_MODULE, + .hash_size = RSI_HASHMAX, + .hash_table = rsi_table, + .name = "auth.rpcsec.init", + .cache_put = rsi_put, + .cache_request = rsi_request, + .cache_parse = rsi_parse, + .match = rsi_match, + .init = rsi_init, + .update = update_rsi, + .alloc = rsi_alloc, +}; + +static struct rsi *rsi_lookup(struct rsi *item) +{ + struct cache_head *ch; + int hash = rsi_hash(item); + + ch = sunrpc_cache_lookup(&rsi_cache, &item->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + +static struct rsi *rsi_update(struct rsi *new, struct rsi *old) +{ + struct cache_head *ch; + int hash = rsi_hash(new); + + ch = sunrpc_cache_update(&rsi_cache, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsi, h); + else + return NULL; +} + + +/* + * The rpcsec_context cache is used to store a context that is + * used in data exchange. + * The key is a context handle. The content is: + * uid, gidlist, mechanism, service-set, mech-specific-data + */ + +#define RSC_HASHBITS 10 +#define RSC_HASHMAX (1<<RSC_HASHBITS) +#define RSC_HASHMASK (RSC_HASHMAX-1) + +#define GSS_SEQ_WIN 128 + +struct gss_svc_seq_data { + /* highest seq number seen so far: */ + int sd_max; + /* for i such that sd_max-GSS_SEQ_WIN < i <= sd_max, the i-th bit of + * sd_win is nonzero iff sequence number i has been seen already: */ + unsigned long sd_win[GSS_SEQ_WIN/BITS_PER_LONG]; + spinlock_t sd_lock; +}; + +struct rsc { + struct cache_head h; + struct xdr_netobj handle; + struct svc_cred cred; + struct gss_svc_seq_data seqdata; + struct gss_ctx *mechctx; +}; + +static struct cache_head *rsc_table[RSC_HASHMAX]; +static struct cache_detail rsc_cache; +static struct rsc *rsc_update(struct rsc *new, struct rsc *old); +static struct rsc *rsc_lookup(struct rsc *item); + +static void rsc_free(struct rsc *rsci) +{ + kfree(rsci->handle.data); + if (rsci->mechctx) + gss_delete_sec_context(&rsci->mechctx); + if (rsci->cred.cr_group_info) + put_group_info(rsci->cred.cr_group_info); +} + +static void rsc_put(struct kref *ref) +{ + struct rsc *rsci = container_of(ref, struct rsc, h.ref); + + rsc_free(rsci); + kfree(rsci); +} + +static inline int +rsc_hash(struct rsc *rsci) +{ + return hash_mem(rsci->handle.data, rsci->handle.len, RSC_HASHBITS); +} + +static int +rsc_match(struct cache_head *a, struct cache_head *b) +{ + struct rsc *new = container_of(a, struct rsc, h); + struct rsc *tmp = container_of(b, struct rsc, h); + + return netobj_equal(&new->handle, &tmp->handle); +} + +static void +rsc_init(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->handle.len = tmp->handle.len; + tmp->handle.len = 0; + new->handle.data = tmp->handle.data; + tmp->handle.data = NULL; + new->mechctx = NULL; + new->cred.cr_group_info = NULL; +} + +static void +update_rsc(struct cache_head *cnew, struct cache_head *ctmp) +{ + struct rsc *new = container_of(cnew, struct rsc, h); + struct rsc *tmp = container_of(ctmp, struct rsc, h); + + new->mechctx = tmp->mechctx; + tmp->mechctx = NULL; + memset(&new->seqdata, 0, sizeof(new->seqdata)); + spin_lock_init(&new->seqdata.sd_lock); + new->cred = tmp->cred; + tmp->cred.cr_group_info = NULL; +} + +static struct cache_head * +rsc_alloc(void) +{ + struct rsc *rsci = kmalloc(sizeof(*rsci), GFP_KERNEL); + if (rsci) + return &rsci->h; + else + return NULL; +} + +static int rsc_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* contexthandle expiry [ uid gid N <n gids> mechname ...mechdata... ] */ + char *buf = mesg; + int len, rv; + struct rsc rsci, *rscp = NULL; + time_t expiry; + int status = -EINVAL; + struct gss_api_mech *gm = NULL; + + memset(&rsci, 0, sizeof(rsci)); + /* context handle */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) goto out; + status = -ENOMEM; + if (dup_to_netobj(&rsci.handle, buf, len)) + goto out; + + rsci.h.flags = 0; + /* expiry */ + expiry = get_expiry(&mesg); + status = -EINVAL; + if (expiry == 0) + goto out; + + rscp = rsc_lookup(&rsci); + if (!rscp) + goto out; + + /* uid, or NEGATIVE */ + rv = get_int(&mesg, &rsci.cred.cr_uid); + if (rv == -EINVAL) + goto out; + if (rv == -ENOENT) + set_bit(CACHE_NEGATIVE, &rsci.h.flags); + else { + int N, i; + + /* gid */ + if (get_int(&mesg, &rsci.cred.cr_gid)) + goto out; + + /* number of additional gid's */ + if (get_int(&mesg, &N)) + goto out; + status = -ENOMEM; + rsci.cred.cr_group_info = groups_alloc(N); + if (rsci.cred.cr_group_info == NULL) + goto out; + + /* gid's */ + status = -EINVAL; + for (i=0; i<N; i++) { + gid_t gid; + if (get_int(&mesg, &gid)) + goto out; + GROUP_AT(rsci.cred.cr_group_info, i) = gid; + } + + /* mech name */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + gm = gss_mech_get_by_name(buf); + status = -EOPNOTSUPP; + if (!gm) + goto out; + + status = -EINVAL; + /* mech-specific data: */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) + goto out; + status = gss_import_sec_context(buf, len, gm, &rsci.mechctx); + if (status) + goto out; + } + rsci.h.expiry_time = expiry; + rscp = rsc_update(&rsci, rscp); + status = 0; +out: + gss_mech_put(gm); + rsc_free(&rsci); + if (rscp) + cache_put(&rscp->h, &rsc_cache); + else + status = -ENOMEM; + return status; +} + +static struct cache_detail rsc_cache = { + .owner = THIS_MODULE, + .hash_size = RSC_HASHMAX, + .hash_table = rsc_table, + .name = "auth.rpcsec.context", + .cache_put = rsc_put, + .cache_parse = rsc_parse, + .match = rsc_match, + .init = rsc_init, + .update = update_rsc, + .alloc = rsc_alloc, +}; + +static struct rsc *rsc_lookup(struct rsc *item) +{ + struct cache_head *ch; + int hash = rsc_hash(item); + + ch = sunrpc_cache_lookup(&rsc_cache, &item->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + +static struct rsc *rsc_update(struct rsc *new, struct rsc *old) +{ + struct cache_head *ch; + int hash = rsc_hash(new); + + ch = sunrpc_cache_update(&rsc_cache, &new->h, + &old->h, hash); + if (ch) + return container_of(ch, struct rsc, h); + else + return NULL; +} + + +static struct rsc * +gss_svc_searchbyctx(struct xdr_netobj *handle) +{ + struct rsc rsci; + struct rsc *found; + + memset(&rsci, 0, sizeof(rsci)); + if (dup_to_netobj(&rsci.handle, handle->data, handle->len)) + return NULL; + found = rsc_lookup(&rsci); + rsc_free(&rsci); + if (!found) + return NULL; + if (cache_check(&rsc_cache, &found->h, NULL)) + return NULL; + return found; +} + +/* Implements sequence number algorithm as specified in RFC 2203. */ +static int +gss_check_seq_num(struct rsc *rsci, int seq_num) +{ + struct gss_svc_seq_data *sd = &rsci->seqdata; + + spin_lock(&sd->sd_lock); + if (seq_num > sd->sd_max) { + if (seq_num >= sd->sd_max + GSS_SEQ_WIN) { + memset(sd->sd_win,0,sizeof(sd->sd_win)); + sd->sd_max = seq_num; + } else while (sd->sd_max < seq_num) { + sd->sd_max++; + __clear_bit(sd->sd_max % GSS_SEQ_WIN, sd->sd_win); + } + __set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win); + goto ok; + } else if (seq_num <= sd->sd_max - GSS_SEQ_WIN) { + goto drop; + } + /* sd_max - GSS_SEQ_WIN < seq_num <= sd_max */ + if (__test_and_set_bit(seq_num % GSS_SEQ_WIN, sd->sd_win)) + goto drop; +ok: + spin_unlock(&sd->sd_lock); + return 1; +drop: + spin_unlock(&sd->sd_lock); + return 0; +} + +static inline u32 round_up_to_quad(u32 i) +{ + return (i + 3 ) & ~3; +} + +static inline int +svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o) +{ + int l; + + if (argv->iov_len < 4) + return -1; + o->len = svc_getnl(argv); + l = round_up_to_quad(o->len); + if (argv->iov_len < l) + return -1; + o->data = argv->iov_base; + argv->iov_base += l; + argv->iov_len -= l; + return 0; +} + +static inline int +svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o) +{ + u8 *p; + + if (resv->iov_len + 4 > PAGE_SIZE) + return -1; + svc_putnl(resv, o->len); + p = resv->iov_base + resv->iov_len; + resv->iov_len += round_up_to_quad(o->len); + if (resv->iov_len > PAGE_SIZE) + return -1; + memcpy(p, o->data, o->len); + memset(p + o->len, 0, round_up_to_quad(o->len) - o->len); + return 0; +} + +/* + * Verify the checksum on the header and return SVC_OK on success. + * Otherwise, return SVC_DROP (in the case of a bad sequence number) + * or return SVC_DENIED and indicate error in authp. + */ +static int +gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci, + __be32 *rpcstart, struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct gss_ctx *ctx_id = rsci->mechctx; + struct xdr_buf rpchdr; + struct xdr_netobj checksum; + u32 flavor = 0; + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec iov; + + /* data to compute the checksum over: */ + iov.iov_base = rpcstart; + iov.iov_len = (u8 *)argv->iov_base - (u8 *)rpcstart; + xdr_buf_from_iov(&iov, &rpchdr); + + *authp = rpc_autherr_badverf; + if (argv->iov_len < 4) + return SVC_DENIED; + flavor = svc_getnl(argv); + if (flavor != RPC_AUTH_GSS) + return SVC_DENIED; + if (svc_safe_getnetobj(argv, &checksum)) + return SVC_DENIED; + + if (rqstp->rq_deferred) /* skip verification of revisited request */ + return SVC_OK; + if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) { + *authp = rpcsec_gsserr_credproblem; + return SVC_DENIED; + } + + if (gc->gc_seq > MAXSEQ) { + dprintk("RPC: svcauth_gss: discarding request with " + "large sequence number %d\n", gc->gc_seq); + *authp = rpcsec_gsserr_ctxproblem; + return SVC_DENIED; + } + if (!gss_check_seq_num(rsci, gc->gc_seq)) { + dprintk("RPC: svcauth_gss: discarding request with " + "old sequence number %d\n", gc->gc_seq); + return SVC_DROP; + } + return SVC_OK; +} + +static int +gss_write_null_verf(struct svc_rqst *rqstp) +{ + __be32 *p; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_NULL); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + /* don't really need to check if head->iov_len > PAGE_SIZE ... */ + *p++ = 0; + if (!xdr_ressize_check(rqstp, p)) + return -1; + return 0; +} + +static int +gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq) +{ + __be32 xdr_seq; + u32 maj_stat; + struct xdr_buf verf_data; + struct xdr_netobj mic; + __be32 *p; + struct kvec iov; + + svc_putnl(rqstp->rq_res.head, RPC_AUTH_GSS); + xdr_seq = htonl(seq); + + iov.iov_base = &xdr_seq; + iov.iov_len = sizeof(xdr_seq); + xdr_buf_from_iov(&iov, &verf_data); + p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len; + mic.data = (u8 *)(p + 1); + maj_stat = gss_get_mic(ctx_id, &verf_data, &mic); + if (maj_stat != GSS_S_COMPLETE) + return -1; + *p++ = htonl(mic.len); + memset((u8 *)p + mic.len, 0, round_up_to_quad(mic.len) - mic.len); + p += XDR_QUADLEN(mic.len); + if (!xdr_ressize_check(rqstp, p)) + return -1; + return 0; +} + +struct gss_domain { + struct auth_domain h; + u32 pseudoflavor; +}; + +static struct auth_domain * +find_gss_auth_domain(struct gss_ctx *ctx, u32 svc) +{ + char *name; + + name = gss_service_to_auth_domain_name(ctx->mech_type, svc); + if (!name) + return NULL; + return auth_domain_find(name); +} + +static struct auth_ops svcauthops_gss; + +u32 svcauth_gss_flavor(struct auth_domain *dom) +{ + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + return gd->pseudoflavor; +} + +EXPORT_SYMBOL(svcauth_gss_flavor); + +int +svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name) +{ + struct gss_domain *new; + struct auth_domain *test; + int stat = -ENOMEM; + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + goto out; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (!new->h.name) + goto out_free_dom; + new->h.flavour = &svcauthops_gss; + new->pseudoflavor = pseudoflavor; + + stat = 0; + test = auth_domain_lookup(name, &new->h); + if (test != &new->h) { /* Duplicate registration */ + auth_domain_put(test); + kfree(new->h.name); + goto out_free_dom; + } + return 0; + +out_free_dom: + kfree(new); +out: + return stat; +} + +EXPORT_SYMBOL(svcauth_gss_register_pseudoflavor); + +static inline int +read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = ntohl(raw); + return 0; +} + +/* It would be nice if this bit of code could be shared with the client. + * Obstacles: + * The client shouldn't malloc(), would have to pass in own memory. + * The server uses base of head iovec as read pointer, while the + * client uses separate pointer. */ +static int +unwrap_integ_data(struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + int stat = -EINVAL; + u32 integ_len, maj_stat; + struct xdr_netobj mic; + struct xdr_buf integ_buf; + + integ_len = svc_getnl(&buf->head[0]); + if (integ_len & 3) + return stat; + if (integ_len > buf->len) + return stat; + if (xdr_buf_subsegment(buf, &integ_buf, 0, integ_len)) + BUG(); + /* copy out mic... */ + if (read_u32_from_xdr_buf(buf, integ_len, &mic.len)) + BUG(); + if (mic.len > RPC_MAX_AUTH_SIZE) + return stat; + mic.data = kmalloc(mic.len, GFP_KERNEL); + if (!mic.data) + return stat; + if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len)) + goto out; + maj_stat = gss_verify_mic(ctx, &integ_buf, &mic); + if (maj_stat != GSS_S_COMPLETE) + goto out; + if (svc_getnl(&buf->head[0]) != seq) + goto out; + stat = 0; +out: + kfree(mic.data); + return stat; +} + +static inline int +total_buf_len(struct xdr_buf *buf) +{ + return buf->head[0].iov_len + buf->page_len + buf->tail[0].iov_len; +} + +static void +fix_priv_head(struct xdr_buf *buf, int pad) +{ + if (buf->page_len == 0) { + /* We need to adjust head and buf->len in tandem in this + * case to make svc_defer() work--it finds the original + * buffer start using buf->len - buf->head[0].iov_len. */ + buf->head[0].iov_len -= pad; + } +} + +static int +unwrap_priv_data(struct svc_rqst *rqstp, struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx) +{ + u32 priv_len, maj_stat; + int pad, saved_len, remaining_len, offset; + + rqstp->rq_splice_ok = 0; + + priv_len = svc_getnl(&buf->head[0]); + if (rqstp->rq_deferred) { + /* Already decrypted last time through! The sequence number + * check at out_seq is unnecessary but harmless: */ + goto out_seq; + } + /* buf->len is the number of bytes from the original start of the + * request to the end, where head[0].iov_len is just the bytes + * not yet read from the head, so these two values are different: */ + remaining_len = total_buf_len(buf); + if (priv_len > remaining_len) + return -EINVAL; + pad = remaining_len - priv_len; + buf->len -= pad; + fix_priv_head(buf, pad); + + /* Maybe it would be better to give gss_unwrap a length parameter: */ + saved_len = buf->len; + buf->len = priv_len; + maj_stat = gss_unwrap(ctx, 0, buf); + pad = priv_len - buf->len; + buf->len = saved_len; + buf->len -= pad; + /* The upper layers assume the buffer is aligned on 4-byte boundaries. + * In the krb5p case, at least, the data ends up offset, so we need to + * move it around. */ + /* XXX: This is very inefficient. It would be better to either do + * this while we encrypt, or maybe in the receive code, if we can peak + * ahead and work out the service and mechanism there. */ + offset = buf->head[0].iov_len % 4; + if (offset) { + buf->buflen = RPCSVC_MAXPAYLOAD; + xdr_shift_buf(buf, offset); + fix_priv_head(buf, pad); + } + if (maj_stat != GSS_S_COMPLETE) + return -EINVAL; +out_seq: + if (svc_getnl(&buf->head[0]) != seq) + return -EINVAL; + return 0; +} + +struct gss_svc_data { + /* decoded gss client cred: */ + struct rpc_gss_wire_cred clcred; + /* save a pointer to the beginning of the encoded verifier, + * for use in encryption/checksumming in svcauth_gss_release: */ + __be32 *verf_start; + struct rsc *rsci; +}; + +static int +svcauth_gss_set_client(struct svc_rqst *rqstp) +{ + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rsc *rsci = svcdata->rsci; + struct rpc_gss_wire_cred *gc = &svcdata->clcred; + int stat; + + /* + * A gss export can be specified either by: + * export *(sec=krb5,rw) + * or by + * export gss/krb5(rw) + * The latter is deprecated; but for backwards compatibility reasons + * the nfsd code will still fall back on trying it if the former + * doesn't work; so we try to make both available to nfsd, below. + */ + rqstp->rq_gssclient = find_gss_auth_domain(rsci->mechctx, gc->gc_svc); + if (rqstp->rq_gssclient == NULL) + return SVC_DENIED; + stat = svcauth_unix_set_client(rqstp); + if (stat == SVC_DROP) + return stat; + return SVC_OK; +} + +static inline int +gss_write_init_verf(struct svc_rqst *rqstp, struct rsi *rsip) +{ + struct rsc *rsci; + int rc; + + if (rsip->major_status != GSS_S_COMPLETE) + return gss_write_null_verf(rqstp); + rsci = gss_svc_searchbyctx(&rsip->out_handle); + if (rsci == NULL) { + rsip->major_status = GSS_S_NO_CONTEXT; + return gss_write_null_verf(rqstp); + } + rc = gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN); + cache_put(&rsci->h, &rsc_cache); + return rc; +} + +/* + * Having read the cred already and found we're in the context + * initiation case, read the verifier and initiate (or check the results + * of) upcalls to userspace for help with context initiation. If + * the upcall results are available, write the verifier and result. + * Otherwise, drop the request pending an answer to the upcall. + */ +static int svcauth_gss_handle_init(struct svc_rqst *rqstp, + struct rpc_gss_wire_cred *gc, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct xdr_netobj tmpobj; + struct rsi *rsip, rsikey; + int ret; + + /* Read the verifier; should be NULL: */ + *authp = rpc_autherr_badverf; + if (argv->iov_len < 2 * 4) + return SVC_DENIED; + if (svc_getnl(argv) != RPC_AUTH_NULL) + return SVC_DENIED; + if (svc_getnl(argv) != 0) + return SVC_DENIED; + + /* Martial context handle and token for upcall: */ + *authp = rpc_autherr_badcred; + if (gc->gc_proc == RPC_GSS_PROC_INIT && gc->gc_ctx.len != 0) + return SVC_DENIED; + memset(&rsikey, 0, sizeof(rsikey)); + if (dup_netobj(&rsikey.in_handle, &gc->gc_ctx)) + return SVC_DROP; + *authp = rpc_autherr_badverf; + if (svc_safe_getnetobj(argv, &tmpobj)) { + kfree(rsikey.in_handle.data); + return SVC_DENIED; + } + if (dup_netobj(&rsikey.in_token, &tmpobj)) { + kfree(rsikey.in_handle.data); + return SVC_DROP; + } + + /* Perform upcall, or find upcall result: */ + rsip = rsi_lookup(&rsikey); + rsi_free(&rsikey); + if (!rsip) + return SVC_DROP; + switch (cache_check(&rsi_cache, &rsip->h, &rqstp->rq_chandle)) { + case -EAGAIN: + case -ETIMEDOUT: + case -ENOENT: + /* No upcall result: */ + return SVC_DROP; + case 0: + ret = SVC_DROP; + /* Got an answer to the upcall; use it: */ + if (gss_write_init_verf(rqstp, rsip)) + goto out; + if (resv->iov_len + 4 > PAGE_SIZE) + goto out; + svc_putnl(resv, RPC_SUCCESS); + if (svc_safe_putnetobj(resv, &rsip->out_handle)) + goto out; + if (resv->iov_len + 3 * 4 > PAGE_SIZE) + goto out; + svc_putnl(resv, rsip->major_status); + svc_putnl(resv, rsip->minor_status); + svc_putnl(resv, GSS_SEQ_WIN); + if (svc_safe_putnetobj(resv, &rsip->out_token)) + goto out; + } + ret = SVC_COMPLETE; +out: + cache_put(&rsip->h, &rsi_cache); + return ret; +} + +/* + * Accept an rpcsec packet. + * If context establishment, punt to user space + * If data exchange, verify/decrypt + * If context destruction, handle here + * In the context establishment and destruction case we encode + * response here and return SVC_COMPLETE. + */ +static int +svcauth_gss_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + u32 crlen; + struct gss_svc_data *svcdata = rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc; + struct rsc *rsci = NULL; + __be32 *rpcstart; + __be32 *reject_stat = resv->iov_base + resv->iov_len; + int ret; + + dprintk("RPC: svcauth_gss: argv->iov_len = %zd\n", + argv->iov_len); + + *authp = rpc_autherr_badcred; + if (!svcdata) + svcdata = kmalloc(sizeof(*svcdata), GFP_KERNEL); + if (!svcdata) + goto auth_err; + rqstp->rq_auth_data = svcdata; + svcdata->verf_start = NULL; + svcdata->rsci = NULL; + gc = &svcdata->clcred; + + /* start of rpc packet is 7 u32's back from here: + * xid direction rpcversion prog vers proc flavour + */ + rpcstart = argv->iov_base; + rpcstart -= 7; + + /* credential is: + * version(==1), proc(0,1,2,3), seq, service (1,2,3), handle + * at least 5 u32s, and is preceeded by length, so that makes 6. + */ + + if (argv->iov_len < 5 * 4) + goto auth_err; + crlen = svc_getnl(argv); + if (svc_getnl(argv) != RPC_GSS_VERSION) + goto auth_err; + gc->gc_proc = svc_getnl(argv); + gc->gc_seq = svc_getnl(argv); + gc->gc_svc = svc_getnl(argv); + if (svc_safe_getnetobj(argv, &gc->gc_ctx)) + goto auth_err; + if (crlen != round_up_to_quad(gc->gc_ctx.len) + 5 * 4) + goto auth_err; + + if ((gc->gc_proc != RPC_GSS_PROC_DATA) && (rqstp->rq_proc != 0)) + goto auth_err; + + *authp = rpc_autherr_badverf; + switch (gc->gc_proc) { + case RPC_GSS_PROC_INIT: + case RPC_GSS_PROC_CONTINUE_INIT: + return svcauth_gss_handle_init(rqstp, gc, authp); + case RPC_GSS_PROC_DATA: + case RPC_GSS_PROC_DESTROY: + /* Look up the context, and check the verifier: */ + *authp = rpcsec_gsserr_credproblem; + rsci = gss_svc_searchbyctx(&gc->gc_ctx); + if (!rsci) + goto auth_err; + switch (gss_verify_header(rqstp, rsci, rpcstart, gc, authp)) { + case SVC_OK: + break; + case SVC_DENIED: + goto auth_err; + case SVC_DROP: + goto drop; + } + break; + default: + *authp = rpc_autherr_rejectedcred; + goto auth_err; + } + + /* now act upon the command: */ + switch (gc->gc_proc) { + case RPC_GSS_PROC_DESTROY: + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + rsci->h.expiry_time = get_seconds(); + set_bit(CACHE_NEGATIVE, &rsci->h.flags); + if (resv->iov_len + 4 > PAGE_SIZE) + goto drop; + svc_putnl(resv, RPC_SUCCESS); + goto complete; + case RPC_GSS_PROC_DATA: + *authp = rpcsec_gsserr_ctxproblem; + svcdata->verf_start = resv->iov_base + resv->iov_len; + if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq)) + goto auth_err; + rqstp->rq_cred = rsci->cred; + get_group_info(rsci->cred.cr_group_info); + *authp = rpc_autherr_badcred; + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + /* placeholders for length and seq. number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_integ_data(&rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + break; + case RPC_GSS_SVC_PRIVACY: + /* placeholders for length and seq. number: */ + svc_putnl(resv, 0); + svc_putnl(resv, 0); + if (unwrap_priv_data(rqstp, &rqstp->rq_arg, + gc->gc_seq, rsci->mechctx)) + goto garbage_args; + break; + default: + goto auth_err; + } + svcdata->rsci = rsci; + cache_get(&rsci->h); + rqstp->rq_flavor = gss_svc_to_pseudoflavor( + rsci->mechctx->mech_type, gc->gc_svc); + ret = SVC_OK; + goto out; + } +garbage_args: + ret = SVC_GARBAGE; + goto out; +auth_err: + /* Restore write pointer to its original value: */ + xdr_ressize_check(rqstp, reject_stat); + ret = SVC_DENIED; + goto out; +complete: + ret = SVC_COMPLETE; + goto out; +drop: + ret = SVC_DROP; +out: + if (rsci) + cache_put(&rsci->h, &rsc_cache); + return ret; +} + +static __be32 * +svcauth_gss_prepare_to_wrap(struct xdr_buf *resbuf, struct gss_svc_data *gsd) +{ + __be32 *p; + u32 verf_len; + + p = gsd->verf_start; + gsd->verf_start = NULL; + + /* If the reply stat is nonzero, don't wrap: */ + if (*(p-1) != rpc_success) + return NULL; + /* Skip the verifier: */ + p += 1; + verf_len = ntohl(*p++); + p += XDR_QUADLEN(verf_len); + /* move accept_stat to right place: */ + memcpy(p, p + 2, 4); + /* Also don't wrap if the accept stat is nonzero: */ + if (*p != rpc_success) { + resbuf->head[0].iov_len -= 2 * 4; + return NULL; + } + p++; + return p; +} + +static inline int +svcauth_gss_wrap_resp_integ(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct xdr_buf integ_buf; + struct xdr_netobj mic; + struct kvec *resv; + __be32 *p; + int integ_offset, integ_len; + int stat = -EINVAL; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + goto out; + integ_offset = (u8 *)(p + 1) - (u8 *)resbuf->head[0].iov_base; + integ_len = resbuf->len - integ_offset; + BUG_ON(integ_len % 4); + *p++ = htonl(integ_len); + *p++ = htonl(gc->gc_seq); + if (xdr_buf_subsegment(resbuf, &integ_buf, integ_offset, + integ_len)) + BUG(); + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE > PAGE_SIZE) + goto out_err; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len; + resbuf->tail[0].iov_len = 0; + resv = &resbuf->tail[0]; + } else { + resv = &resbuf->tail[0]; + } + mic.data = (u8 *)resv->iov_base + resv->iov_len + 4; + if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic)) + goto out_err; + svc_putnl(resv, mic.len); + memset(mic.data + mic.len, 0, + round_up_to_quad(mic.len) - mic.len); + resv->iov_len += XDR_QUADLEN(mic.len) << 2; + /* not strictly required: */ + resbuf->len += XDR_QUADLEN(mic.len) << 2; + BUG_ON(resv->iov_len > PAGE_SIZE); +out: + stat = 0; +out_err: + return stat; +} + +static inline int +svcauth_gss_wrap_resp_priv(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + struct page **inpages = NULL; + __be32 *p, *len; + int offset; + int pad; + + p = svcauth_gss_prepare_to_wrap(resbuf, gsd); + if (p == NULL) + return 0; + len = p++; + offset = (u8 *)p - (u8 *)resbuf->head[0].iov_base; + *p++ = htonl(gc->gc_seq); + inpages = resbuf->pages; + /* XXX: Would be better to write some xdr helper functions for + * nfs{2,3,4}xdr.c that place the data right, instead of copying: */ + if (resbuf->tail[0].iov_base) { + BUG_ON(resbuf->tail[0].iov_base >= resbuf->head[0].iov_base + + PAGE_SIZE); + BUG_ON(resbuf->tail[0].iov_base < resbuf->head[0].iov_base); + if (resbuf->tail[0].iov_len + resbuf->head[0].iov_len + + 2 * RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + memmove(resbuf->tail[0].iov_base + RPC_MAX_AUTH_SIZE, + resbuf->tail[0].iov_base, + resbuf->tail[0].iov_len); + resbuf->tail[0].iov_base += RPC_MAX_AUTH_SIZE; + } + if (resbuf->tail[0].iov_base == NULL) { + if (resbuf->head[0].iov_len + 2*RPC_MAX_AUTH_SIZE > PAGE_SIZE) + return -ENOMEM; + resbuf->tail[0].iov_base = resbuf->head[0].iov_base + + resbuf->head[0].iov_len + RPC_MAX_AUTH_SIZE; + resbuf->tail[0].iov_len = 0; + } + if (gss_wrap(gsd->rsci->mechctx, offset, resbuf, inpages)) + return -ENOMEM; + *len = htonl(resbuf->len - offset); + pad = 3 - ((resbuf->len - offset - 1)&3); + p = (__be32 *)(resbuf->tail[0].iov_base + resbuf->tail[0].iov_len); + memset(p, 0, pad); + resbuf->tail[0].iov_len += pad; + resbuf->len += pad; + return 0; +} + +static int +svcauth_gss_release(struct svc_rqst *rqstp) +{ + struct gss_svc_data *gsd = (struct gss_svc_data *)rqstp->rq_auth_data; + struct rpc_gss_wire_cred *gc = &gsd->clcred; + struct xdr_buf *resbuf = &rqstp->rq_res; + int stat = -EINVAL; + + if (gc->gc_proc != RPC_GSS_PROC_DATA) + goto out; + /* Release can be called twice, but we only wrap once. */ + if (gsd->verf_start == NULL) + goto out; + /* normally not set till svc_send, but we need it here: */ + /* XXX: what for? Do we mess it up the moment we call svc_putu32 + * or whatever? */ + resbuf->len = total_buf_len(resbuf); + switch (gc->gc_svc) { + case RPC_GSS_SVC_NONE: + break; + case RPC_GSS_SVC_INTEGRITY: + stat = svcauth_gss_wrap_resp_integ(rqstp); + if (stat) + goto out_err; + break; + case RPC_GSS_SVC_PRIVACY: + stat = svcauth_gss_wrap_resp_priv(rqstp); + if (stat) + goto out_err; + break; + default: + goto out_err; + } + +out: + stat = 0; +out_err: + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_gssclient) + auth_domain_put(rqstp->rq_gssclient); + rqstp->rq_gssclient = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + if (gsd->rsci) + cache_put(&gsd->rsci->h, &rsc_cache); + gsd->rsci = NULL; + + return stat; +} + +static void +svcauth_gss_domain_release(struct auth_domain *dom) +{ + struct gss_domain *gd = container_of(dom, struct gss_domain, h); + + kfree(dom->name); + kfree(gd); +} + +static struct auth_ops svcauthops_gss = { + .name = "rpcsec_gss", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_GSS, + .accept = svcauth_gss_accept, + .release = svcauth_gss_release, + .domain_release = svcauth_gss_domain_release, + .set_client = svcauth_gss_set_client, +}; + +int +gss_svc_init(void) +{ + int rv = svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss); + if (rv) + return rv; + rv = cache_register(&rsc_cache); + if (rv) + goto out1; + rv = cache_register(&rsi_cache); + if (rv) + goto out2; + return 0; +out2: + cache_unregister(&rsc_cache); +out1: + svc_auth_unregister(RPC_AUTH_GSS); + return rv; +} + +void +gss_svc_shutdown(void) +{ + cache_unregister(&rsc_cache); + cache_unregister(&rsi_cache); + svc_auth_unregister(RPC_AUTH_GSS); +} diff --git a/net/sunrpc/auth_null.c b/net/sunrpc/auth_null.c new file mode 100644 index 0000000..c70dd7f --- /dev/null +++ b/net/sunrpc/auth_null.c @@ -0,0 +1,143 @@ +/* + * linux/net/sunrpc/auth_null.c + * + * AUTH_NULL authentication. Really :-) + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/utsname.h> +#include <linux/sunrpc/clnt.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth null_auth; +static struct rpc_cred null_cred; + +static struct rpc_auth * +nul_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + atomic_inc(&null_auth.au_count); + return &null_auth; +} + +static void +nul_destroy(struct rpc_auth *auth) +{ +} + +/* + * Lookup NULL creds for current process + */ +static struct rpc_cred * +nul_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return get_rpccred(&null_cred); +} + +/* + * Destroy cred handle. + */ +static void +nul_destroy_cred(struct rpc_cred *cred) +{ +} + +/* + * Match cred handle against current process + */ +static int +nul_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags) +{ + return 1; +} + +/* + * Marshal credential. + */ +static __be32 * +nul_marshal(struct rpc_task *task, __be32 *p) +{ + *p++ = htonl(RPC_AUTH_NULL); + *p++ = 0; + *p++ = htonl(RPC_AUTH_NULL); + *p++ = 0; + + return p; +} + +/* + * Refresh credential. This is a no-op for AUTH_NULL + */ +static int +nul_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags); + return 0; +} + +static __be32 * +nul_validate(struct rpc_task *task, __be32 *p) +{ + rpc_authflavor_t flavor; + u32 size; + + flavor = ntohl(*p++); + if (flavor != RPC_AUTH_NULL) { + printk("RPC: bad verf flavor: %u\n", flavor); + return NULL; + } + + size = ntohl(*p++); + if (size != 0) { + printk("RPC: bad verf size: %u\n", size); + return NULL; + } + + return p; +} + +const struct rpc_authops authnull_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_NULL, + .au_name = "NULL", + .create = nul_create, + .destroy = nul_destroy, + .lookup_cred = nul_lookup_cred, +}; + +static +struct rpc_auth null_auth = { + .au_cslack = 4, + .au_rslack = 2, + .au_ops = &authnull_ops, + .au_flavor = RPC_AUTH_NULL, + .au_count = ATOMIC_INIT(0), +}; + +static +const struct rpc_credops null_credops = { + .cr_name = "AUTH_NULL", + .crdestroy = nul_destroy_cred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = nul_match, + .crmarshal = nul_marshal, + .crrefresh = nul_refresh, + .crvalidate = nul_validate, +}; + +static +struct rpc_cred null_cred = { + .cr_lru = LIST_HEAD_INIT(null_cred.cr_lru), + .cr_auth = &null_auth, + .cr_ops = &null_credops, + .cr_count = ATOMIC_INIT(1), + .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE, +#ifdef RPC_DEBUG + .cr_magic = RPCAUTH_CRED_MAGIC, +#endif +}; diff --git a/net/sunrpc/auth_unix.c b/net/sunrpc/auth_unix.c new file mode 100644 index 0000000..46b2647 --- /dev/null +++ b/net/sunrpc/auth_unix.c @@ -0,0 +1,243 @@ +/* + * linux/net/sunrpc/auth_unix.c + * + * UNIX-style authentication; no AUTH_SHORT support + * + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/auth.h> + +#define NFS_NGROUPS 16 + +struct unx_cred { + struct rpc_cred uc_base; + gid_t uc_gid; + gid_t uc_gids[NFS_NGROUPS]; +}; +#define uc_uid uc_base.cr_uid + +#define UNX_WRITESLACK (21 + (UNX_MAXNODENAME >> 2)) + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_AUTH +#endif + +static struct rpc_auth unix_auth; +static struct rpc_cred_cache unix_cred_cache; +static const struct rpc_credops unix_credops; + +static struct rpc_auth * +unx_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor) +{ + dprintk("RPC: creating UNIX authenticator for client %p\n", + clnt); + atomic_inc(&unix_auth.au_count); + return &unix_auth; +} + +static void +unx_destroy(struct rpc_auth *auth) +{ + dprintk("RPC: destroying UNIX authenticator %p\n", auth); + rpcauth_clear_credcache(auth->au_credcache); +} + +/* + * Lookup AUTH_UNIX creds for current process + */ +static struct rpc_cred * +unx_lookup_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + return rpcauth_lookup_credcache(auth, acred, flags); +} + +static struct rpc_cred * +unx_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags) +{ + struct unx_cred *cred; + unsigned int groups = 0; + unsigned int i; + + dprintk("RPC: allocating UNIX cred for uid %d gid %d\n", + acred->uid, acred->gid); + + if (!(cred = kmalloc(sizeof(*cred), GFP_NOFS))) + return ERR_PTR(-ENOMEM); + + rpcauth_init_cred(&cred->uc_base, acred, auth, &unix_credops); + cred->uc_base.cr_flags = 1UL << RPCAUTH_CRED_UPTODATE; + + if (acred->group_info != NULL) + groups = acred->group_info->ngroups; + if (groups > NFS_NGROUPS) + groups = NFS_NGROUPS; + + cred->uc_gid = acred->gid; + for (i = 0; i < groups; i++) + cred->uc_gids[i] = GROUP_AT(acred->group_info, i); + if (i < NFS_NGROUPS) + cred->uc_gids[i] = NOGROUP; + + return &cred->uc_base; +} + +static void +unx_free_cred(struct unx_cred *unx_cred) +{ + dprintk("RPC: unx_free_cred %p\n", unx_cred); + kfree(unx_cred); +} + +static void +unx_free_cred_callback(struct rcu_head *head) +{ + struct unx_cred *unx_cred = container_of(head, struct unx_cred, uc_base.cr_rcu); + unx_free_cred(unx_cred); +} + +static void +unx_destroy_cred(struct rpc_cred *cred) +{ + call_rcu(&cred->cr_rcu, unx_free_cred_callback); +} + +/* + * Match credentials against current process creds. + * The root_override argument takes care of cases where the caller may + * request root creds (e.g. for NFS swapping). + */ +static int +unx_match(struct auth_cred *acred, struct rpc_cred *rcred, int flags) +{ + struct unx_cred *cred = container_of(rcred, struct unx_cred, uc_base); + unsigned int groups = 0; + unsigned int i; + + + if (cred->uc_uid != acred->uid || cred->uc_gid != acred->gid) + return 0; + + if (acred->group_info != NULL) + groups = acred->group_info->ngroups; + if (groups > NFS_NGROUPS) + groups = NFS_NGROUPS; + for (i = 0; i < groups ; i++) + if (cred->uc_gids[i] != GROUP_AT(acred->group_info, i)) + return 0; + return 1; +} + +/* + * Marshal credentials. + * Maybe we should keep a cached credential for performance reasons. + */ +static __be32 * +unx_marshal(struct rpc_task *task, __be32 *p) +{ + struct rpc_clnt *clnt = task->tk_client; + struct unx_cred *cred = container_of(task->tk_msg.rpc_cred, struct unx_cred, uc_base); + __be32 *base, *hold; + int i; + + *p++ = htonl(RPC_AUTH_UNIX); + base = p++; + *p++ = htonl(jiffies/HZ); + + /* + * Copy the UTS nodename captured when the client was created. + */ + p = xdr_encode_array(p, clnt->cl_nodename, clnt->cl_nodelen); + + *p++ = htonl((u32) cred->uc_uid); + *p++ = htonl((u32) cred->uc_gid); + hold = p++; + for (i = 0; i < 16 && cred->uc_gids[i] != (gid_t) NOGROUP; i++) + *p++ = htonl((u32) cred->uc_gids[i]); + *hold = htonl(p - hold - 1); /* gid array length */ + *base = htonl((p - base - 1) << 2); /* cred length */ + + *p++ = htonl(RPC_AUTH_NULL); + *p++ = htonl(0); + + return p; +} + +/* + * Refresh credentials. This is a no-op for AUTH_UNIX + */ +static int +unx_refresh(struct rpc_task *task) +{ + set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_msg.rpc_cred->cr_flags); + return 0; +} + +static __be32 * +unx_validate(struct rpc_task *task, __be32 *p) +{ + rpc_authflavor_t flavor; + u32 size; + + flavor = ntohl(*p++); + if (flavor != RPC_AUTH_NULL && + flavor != RPC_AUTH_UNIX && + flavor != RPC_AUTH_SHORT) { + printk("RPC: bad verf flavor: %u\n", flavor); + return NULL; + } + + size = ntohl(*p++); + if (size > RPC_MAX_AUTH_SIZE) { + printk("RPC: giant verf size: %u\n", size); + return NULL; + } + task->tk_msg.rpc_cred->cr_auth->au_rslack = (size >> 2) + 2; + p += (size >> 2); + + return p; +} + +void __init rpc_init_authunix(void) +{ + spin_lock_init(&unix_cred_cache.lock); +} + +const struct rpc_authops authunix_ops = { + .owner = THIS_MODULE, + .au_flavor = RPC_AUTH_UNIX, + .au_name = "UNIX", + .create = unx_create, + .destroy = unx_destroy, + .lookup_cred = unx_lookup_cred, + .crcreate = unx_create_cred, +}; + +static +struct rpc_cred_cache unix_cred_cache = { +}; + +static +struct rpc_auth unix_auth = { + .au_cslack = UNX_WRITESLACK, + .au_rslack = 2, /* assume AUTH_NULL verf */ + .au_ops = &authunix_ops, + .au_flavor = RPC_AUTH_UNIX, + .au_count = ATOMIC_INIT(0), + .au_credcache = &unix_cred_cache, +}; + +static +const struct rpc_credops unix_credops = { + .cr_name = "AUTH_UNIX", + .crdestroy = unx_destroy_cred, + .crbind = rpcauth_generic_bind_cred, + .crmatch = unx_match, + .crmarshal = unx_marshal, + .crrefresh = unx_refresh, + .crvalidate = unx_validate, +}; diff --git a/net/sunrpc/cache.c b/net/sunrpc/cache.c new file mode 100644 index 0000000..c996671 --- /dev/null +++ b/net/sunrpc/cache.c @@ -0,0 +1,1318 @@ +/* + * net/sunrpc/cache.c + * + * Generic code for various authentication-related caches + * used by sunrpc clients and servers. + * + * Copyright (C) 2002 Neil Brown <neilb@cse.unsw.edu.au> + * + * Released under terms in GPL version 2. See COPYING. + * + */ + +#include <linux/types.h> +#include <linux/fs.h> +#include <linux/file.h> +#include <linux/slab.h> +#include <linux/signal.h> +#include <linux/sched.h> +#include <linux/kmod.h> +#include <linux/list.h> +#include <linux/module.h> +#include <linux/ctype.h> +#include <asm/uaccess.h> +#include <linux/poll.h> +#include <linux/seq_file.h> +#include <linux/proc_fs.h> +#include <linux/net.h> +#include <linux/workqueue.h> +#include <linux/mutex.h> +#include <asm/ioctls.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/cache.h> +#include <linux/sunrpc/stats.h> + +#define RPCDBG_FACILITY RPCDBG_CACHE + +static int cache_defer_req(struct cache_req *req, struct cache_head *item); +static void cache_revisit_request(struct cache_head *item); + +static void cache_init(struct cache_head *h) +{ + time_t now = get_seconds(); + h->next = NULL; + h->flags = 0; + kref_init(&h->ref); + h->expiry_time = now + CACHE_NEW_EXPIRY; + h->last_refresh = now; +} + +struct cache_head *sunrpc_cache_lookup(struct cache_detail *detail, + struct cache_head *key, int hash) +{ + struct cache_head **head, **hp; + struct cache_head *new = NULL; + + head = &detail->hash_table[hash]; + + read_lock(&detail->hash_lock); + + for (hp=head; *hp != NULL ; hp = &(*hp)->next) { + struct cache_head *tmp = *hp; + if (detail->match(tmp, key)) { + cache_get(tmp); + read_unlock(&detail->hash_lock); + return tmp; + } + } + read_unlock(&detail->hash_lock); + /* Didn't find anything, insert an empty entry */ + + new = detail->alloc(); + if (!new) + return NULL; + /* must fully initialise 'new', else + * we might get lose if we need to + * cache_put it soon. + */ + cache_init(new); + detail->init(new, key); + + write_lock(&detail->hash_lock); + + /* check if entry appeared while we slept */ + for (hp=head; *hp != NULL ; hp = &(*hp)->next) { + struct cache_head *tmp = *hp; + if (detail->match(tmp, key)) { + cache_get(tmp); + write_unlock(&detail->hash_lock); + cache_put(new, detail); + return tmp; + } + } + new->next = *head; + *head = new; + detail->entries++; + cache_get(new); + write_unlock(&detail->hash_lock); + + return new; +} +EXPORT_SYMBOL(sunrpc_cache_lookup); + + +static void queue_loose(struct cache_detail *detail, struct cache_head *ch); + +static int cache_fresh_locked(struct cache_head *head, time_t expiry) +{ + head->expiry_time = expiry; + head->last_refresh = get_seconds(); + return !test_and_set_bit(CACHE_VALID, &head->flags); +} + +static void cache_fresh_unlocked(struct cache_head *head, + struct cache_detail *detail, int new) +{ + if (new) + cache_revisit_request(head); + if (test_and_clear_bit(CACHE_PENDING, &head->flags)) { + cache_revisit_request(head); + queue_loose(detail, head); + } +} + +struct cache_head *sunrpc_cache_update(struct cache_detail *detail, + struct cache_head *new, struct cache_head *old, int hash) +{ + /* The 'old' entry is to be replaced by 'new'. + * If 'old' is not VALID, we update it directly, + * otherwise we need to replace it + */ + struct cache_head **head; + struct cache_head *tmp; + int is_new; + + if (!test_bit(CACHE_VALID, &old->flags)) { + write_lock(&detail->hash_lock); + if (!test_bit(CACHE_VALID, &old->flags)) { + if (test_bit(CACHE_NEGATIVE, &new->flags)) + set_bit(CACHE_NEGATIVE, &old->flags); + else + detail->update(old, new); + is_new = cache_fresh_locked(old, new->expiry_time); + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(old, detail, is_new); + return old; + } + write_unlock(&detail->hash_lock); + } + /* We need to insert a new entry */ + tmp = detail->alloc(); + if (!tmp) { + cache_put(old, detail); + return NULL; + } + cache_init(tmp); + detail->init(tmp, old); + head = &detail->hash_table[hash]; + + write_lock(&detail->hash_lock); + if (test_bit(CACHE_NEGATIVE, &new->flags)) + set_bit(CACHE_NEGATIVE, &tmp->flags); + else + detail->update(tmp, new); + tmp->next = *head; + *head = tmp; + detail->entries++; + cache_get(tmp); + is_new = cache_fresh_locked(tmp, new->expiry_time); + cache_fresh_locked(old, 0); + write_unlock(&detail->hash_lock); + cache_fresh_unlocked(tmp, detail, is_new); + cache_fresh_unlocked(old, detail, 0); + cache_put(old, detail); + return tmp; +} +EXPORT_SYMBOL(sunrpc_cache_update); + +static int cache_make_upcall(struct cache_detail *detail, struct cache_head *h); +/* + * This is the generic cache management routine for all + * the authentication caches. + * It checks the currency of a cache item and will (later) + * initiate an upcall to fill it if needed. + * + * + * Returns 0 if the cache_head can be used, or cache_puts it and returns + * -EAGAIN if upcall is pending, + * -ETIMEDOUT if upcall failed and should be retried, + * -ENOENT if cache entry was negative + */ +int cache_check(struct cache_detail *detail, + struct cache_head *h, struct cache_req *rqstp) +{ + int rv; + long refresh_age, age; + + /* First decide return status as best we can */ + if (!test_bit(CACHE_VALID, &h->flags) || + h->expiry_time < get_seconds()) + rv = -EAGAIN; + else if (detail->flush_time > h->last_refresh) + rv = -EAGAIN; + else { + /* entry is valid */ + if (test_bit(CACHE_NEGATIVE, &h->flags)) + rv = -ENOENT; + else rv = 0; + } + + /* now see if we want to start an upcall */ + refresh_age = (h->expiry_time - h->last_refresh); + age = get_seconds() - h->last_refresh; + + if (rqstp == NULL) { + if (rv == -EAGAIN) + rv = -ENOENT; + } else if (rv == -EAGAIN || age > refresh_age/2) { + dprintk("RPC: Want update, refage=%ld, age=%ld\n", + refresh_age, age); + if (!test_and_set_bit(CACHE_PENDING, &h->flags)) { + switch (cache_make_upcall(detail, h)) { + case -EINVAL: + clear_bit(CACHE_PENDING, &h->flags); + if (rv == -EAGAIN) { + set_bit(CACHE_NEGATIVE, &h->flags); + cache_fresh_unlocked(h, detail, + cache_fresh_locked(h, get_seconds()+CACHE_NEW_EXPIRY)); + rv = -ENOENT; + } + break; + + case -EAGAIN: + clear_bit(CACHE_PENDING, &h->flags); + cache_revisit_request(h); + break; + } + } + } + + if (rv == -EAGAIN) + if (cache_defer_req(rqstp, h) != 0) + rv = -ETIMEDOUT; + + if (rv) + cache_put(h, detail); + return rv; +} +EXPORT_SYMBOL(cache_check); + +/* + * caches need to be periodically cleaned. + * For this we maintain a list of cache_detail and + * a current pointer into that list and into the table + * for that entry. + * + * Each time clean_cache is called it finds the next non-empty entry + * in the current table and walks the list in that entry + * looking for entries that can be removed. + * + * An entry gets removed if: + * - The expiry is before current time + * - The last_refresh time is before the flush_time for that cache + * + * later we might drop old entries with non-NEVER expiry if that table + * is getting 'full' for some definition of 'full' + * + * The question of "how often to scan a table" is an interesting one + * and is answered in part by the use of the "nextcheck" field in the + * cache_detail. + * When a scan of a table begins, the nextcheck field is set to a time + * that is well into the future. + * While scanning, if an expiry time is found that is earlier than the + * current nextcheck time, nextcheck is set to that expiry time. + * If the flush_time is ever set to a time earlier than the nextcheck + * time, the nextcheck time is then set to that flush_time. + * + * A table is then only scanned if the current time is at least + * the nextcheck time. + * + */ + +static LIST_HEAD(cache_list); +static DEFINE_SPINLOCK(cache_list_lock); +static struct cache_detail *current_detail; +static int current_index; + +static const struct file_operations cache_file_operations; +static const struct file_operations content_file_operations; +static const struct file_operations cache_flush_operations; + +static void do_cache_clean(struct work_struct *work); +static DECLARE_DELAYED_WORK(cache_cleaner, do_cache_clean); + +static void remove_cache_proc_entries(struct cache_detail *cd) +{ + if (cd->proc_ent == NULL) + return; + if (cd->flush_ent) + remove_proc_entry("flush", cd->proc_ent); + if (cd->channel_ent) + remove_proc_entry("channel", cd->proc_ent); + if (cd->content_ent) + remove_proc_entry("content", cd->proc_ent); + cd->proc_ent = NULL; + remove_proc_entry(cd->name, proc_net_rpc); +} + +#ifdef CONFIG_PROC_FS +static int create_cache_proc_entries(struct cache_detail *cd) +{ + struct proc_dir_entry *p; + + cd->proc_ent = proc_mkdir(cd->name, proc_net_rpc); + if (cd->proc_ent == NULL) + goto out_nomem; + cd->proc_ent->owner = cd->owner; + cd->channel_ent = cd->content_ent = NULL; + + p = proc_create_data("flush", S_IFREG|S_IRUSR|S_IWUSR, + cd->proc_ent, &cache_flush_operations, cd); + cd->flush_ent = p; + if (p == NULL) + goto out_nomem; + p->owner = cd->owner; + + if (cd->cache_request || cd->cache_parse) { + p = proc_create_data("channel", S_IFREG|S_IRUSR|S_IWUSR, + cd->proc_ent, &cache_file_operations, cd); + cd->channel_ent = p; + if (p == NULL) + goto out_nomem; + p->owner = cd->owner; + } + if (cd->cache_show) { + p = proc_create_data("content", S_IFREG|S_IRUSR|S_IWUSR, + cd->proc_ent, &content_file_operations, cd); + cd->content_ent = p; + if (p == NULL) + goto out_nomem; + p->owner = cd->owner; + } + return 0; +out_nomem: + remove_cache_proc_entries(cd); + return -ENOMEM; +} +#else /* CONFIG_PROC_FS */ +static int create_cache_proc_entries(struct cache_detail *cd) +{ + return 0; +} +#endif + +int cache_register(struct cache_detail *cd) +{ + int ret; + + ret = create_cache_proc_entries(cd); + if (ret) + return ret; + rwlock_init(&cd->hash_lock); + INIT_LIST_HEAD(&cd->queue); + spin_lock(&cache_list_lock); + cd->nextcheck = 0; + cd->entries = 0; + atomic_set(&cd->readers, 0); + cd->last_close = 0; + cd->last_warn = -1; + list_add(&cd->others, &cache_list); + spin_unlock(&cache_list_lock); + + /* start the cleaning process */ + schedule_delayed_work(&cache_cleaner, 0); + return 0; +} +EXPORT_SYMBOL(cache_register); + +void cache_unregister(struct cache_detail *cd) +{ + cache_purge(cd); + spin_lock(&cache_list_lock); + write_lock(&cd->hash_lock); + if (cd->entries || atomic_read(&cd->inuse)) { + write_unlock(&cd->hash_lock); + spin_unlock(&cache_list_lock); + goto out; + } + if (current_detail == cd) + current_detail = NULL; + list_del_init(&cd->others); + write_unlock(&cd->hash_lock); + spin_unlock(&cache_list_lock); + remove_cache_proc_entries(cd); + if (list_empty(&cache_list)) { + /* module must be being unloaded so its safe to kill the worker */ + cancel_delayed_work_sync(&cache_cleaner); + } + return; +out: + printk(KERN_ERR "nfsd: failed to unregister %s cache\n", cd->name); +} +EXPORT_SYMBOL(cache_unregister); + +/* clean cache tries to find something to clean + * and cleans it. + * It returns 1 if it cleaned something, + * 0 if it didn't find anything this time + * -1 if it fell off the end of the list. + */ +static int cache_clean(void) +{ + int rv = 0; + struct list_head *next; + + spin_lock(&cache_list_lock); + + /* find a suitable table if we don't already have one */ + while (current_detail == NULL || + current_index >= current_detail->hash_size) { + if (current_detail) + next = current_detail->others.next; + else + next = cache_list.next; + if (next == &cache_list) { + current_detail = NULL; + spin_unlock(&cache_list_lock); + return -1; + } + current_detail = list_entry(next, struct cache_detail, others); + if (current_detail->nextcheck > get_seconds()) + current_index = current_detail->hash_size; + else { + current_index = 0; + current_detail->nextcheck = get_seconds()+30*60; + } + } + + /* find a non-empty bucket in the table */ + while (current_detail && + current_index < current_detail->hash_size && + current_detail->hash_table[current_index] == NULL) + current_index++; + + /* find a cleanable entry in the bucket and clean it, or set to next bucket */ + + if (current_detail && current_index < current_detail->hash_size) { + struct cache_head *ch, **cp; + struct cache_detail *d; + + write_lock(¤t_detail->hash_lock); + + /* Ok, now to clean this strand */ + + cp = & current_detail->hash_table[current_index]; + ch = *cp; + for (; ch; cp= & ch->next, ch= *cp) { + if (current_detail->nextcheck > ch->expiry_time) + current_detail->nextcheck = ch->expiry_time+1; + if (ch->expiry_time >= get_seconds() + && ch->last_refresh >= current_detail->flush_time + ) + continue; + if (test_and_clear_bit(CACHE_PENDING, &ch->flags)) + queue_loose(current_detail, ch); + + if (atomic_read(&ch->ref.refcount) == 1) + break; + } + if (ch) { + *cp = ch->next; + ch->next = NULL; + current_detail->entries--; + rv = 1; + } + write_unlock(¤t_detail->hash_lock); + d = current_detail; + if (!ch) + current_index ++; + spin_unlock(&cache_list_lock); + if (ch) + cache_put(ch, d); + } else + spin_unlock(&cache_list_lock); + + return rv; +} + +/* + * We want to regularly clean the cache, so we need to schedule some work ... + */ +static void do_cache_clean(struct work_struct *work) +{ + int delay = 5; + if (cache_clean() == -1) + delay = 30*HZ; + + if (list_empty(&cache_list)) + delay = 0; + + if (delay) + schedule_delayed_work(&cache_cleaner, delay); +} + + +/* + * Clean all caches promptly. This just calls cache_clean + * repeatedly until we are sure that every cache has had a chance to + * be fully cleaned + */ +void cache_flush(void) +{ + while (cache_clean() != -1) + cond_resched(); + while (cache_clean() != -1) + cond_resched(); +} +EXPORT_SYMBOL(cache_flush); + +void cache_purge(struct cache_detail *detail) +{ + detail->flush_time = LONG_MAX; + detail->nextcheck = get_seconds(); + cache_flush(); + detail->flush_time = 1; +} +EXPORT_SYMBOL(cache_purge); + + +/* + * Deferral and Revisiting of Requests. + * + * If a cache lookup finds a pending entry, we + * need to defer the request and revisit it later. + * All deferred requests are stored in a hash table, + * indexed by "struct cache_head *". + * As it may be wasteful to store a whole request + * structure, we allow the request to provide a + * deferred form, which must contain a + * 'struct cache_deferred_req' + * This cache_deferred_req contains a method to allow + * it to be revisited when cache info is available + */ + +#define DFR_HASHSIZE (PAGE_SIZE/sizeof(struct list_head)) +#define DFR_HASH(item) ((((long)item)>>4 ^ (((long)item)>>13)) % DFR_HASHSIZE) + +#define DFR_MAX 300 /* ??? */ + +static DEFINE_SPINLOCK(cache_defer_lock); +static LIST_HEAD(cache_defer_list); +static struct list_head cache_defer_hash[DFR_HASHSIZE]; +static int cache_defer_cnt; + +static int cache_defer_req(struct cache_req *req, struct cache_head *item) +{ + struct cache_deferred_req *dreq; + int hash = DFR_HASH(item); + + if (cache_defer_cnt >= DFR_MAX) { + /* too much in the cache, randomly drop this one, + * or continue and drop the oldest below + */ + if (net_random()&1) + return -ETIMEDOUT; + } + dreq = req->defer(req); + if (dreq == NULL) + return -ETIMEDOUT; + + dreq->item = item; + + spin_lock(&cache_defer_lock); + + list_add(&dreq->recent, &cache_defer_list); + + if (cache_defer_hash[hash].next == NULL) + INIT_LIST_HEAD(&cache_defer_hash[hash]); + list_add(&dreq->hash, &cache_defer_hash[hash]); + + /* it is in, now maybe clean up */ + dreq = NULL; + if (++cache_defer_cnt > DFR_MAX) { + dreq = list_entry(cache_defer_list.prev, + struct cache_deferred_req, recent); + list_del(&dreq->recent); + list_del(&dreq->hash); + cache_defer_cnt--; + } + spin_unlock(&cache_defer_lock); + + if (dreq) { + /* there was one too many */ + dreq->revisit(dreq, 1); + } + if (!test_bit(CACHE_PENDING, &item->flags)) { + /* must have just been validated... */ + cache_revisit_request(item); + } + return 0; +} + +static void cache_revisit_request(struct cache_head *item) +{ + struct cache_deferred_req *dreq; + struct list_head pending; + + struct list_head *lp; + int hash = DFR_HASH(item); + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + lp = cache_defer_hash[hash].next; + if (lp) { + while (lp != &cache_defer_hash[hash]) { + dreq = list_entry(lp, struct cache_deferred_req, hash); + lp = lp->next; + if (dreq->item == item) { + list_del(&dreq->hash); + list_move(&dreq->recent, &pending); + cache_defer_cnt--; + } + } + } + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 0); + } +} + +void cache_clean_deferred(void *owner) +{ + struct cache_deferred_req *dreq, *tmp; + struct list_head pending; + + + INIT_LIST_HEAD(&pending); + spin_lock(&cache_defer_lock); + + list_for_each_entry_safe(dreq, tmp, &cache_defer_list, recent) { + if (dreq->owner == owner) { + list_del(&dreq->hash); + list_move(&dreq->recent, &pending); + cache_defer_cnt--; + } + } + spin_unlock(&cache_defer_lock); + + while (!list_empty(&pending)) { + dreq = list_entry(pending.next, struct cache_deferred_req, recent); + list_del_init(&dreq->recent); + dreq->revisit(dreq, 1); + } +} + +/* + * communicate with user-space + * + * We have a magic /proc file - /proc/sunrpc/<cachename>/channel. + * On read, you get a full request, or block. + * On write, an update request is processed. + * Poll works if anything to read, and always allows write. + * + * Implemented by linked list of requests. Each open file has + * a ->private that also exists in this list. New requests are added + * to the end and may wakeup and preceding readers. + * New readers are added to the head. If, on read, an item is found with + * CACHE_UPCALLING clear, we free it from the list. + * + */ + +static DEFINE_SPINLOCK(queue_lock); +static DEFINE_MUTEX(queue_io_mutex); + +struct cache_queue { + struct list_head list; + int reader; /* if 0, then request */ +}; +struct cache_request { + struct cache_queue q; + struct cache_head *item; + char * buf; + int len; + int readers; +}; +struct cache_reader { + struct cache_queue q; + int offset; /* if non-0, we have a refcnt on next request */ +}; + +static ssize_t +cache_read(struct file *filp, char __user *buf, size_t count, loff_t *ppos) +{ + struct cache_reader *rp = filp->private_data; + struct cache_request *rq; + struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + int err; + + if (count == 0) + return 0; + + mutex_lock(&queue_io_mutex); /* protect against multiple concurrent + * readers on this file */ + again: + spin_lock(&queue_lock); + /* need to find next request */ + while (rp->q.list.next != &cd->queue && + list_entry(rp->q.list.next, struct cache_queue, list) + ->reader) { + struct list_head *next = rp->q.list.next; + list_move(&rp->q.list, next); + } + if (rp->q.list.next == &cd->queue) { + spin_unlock(&queue_lock); + mutex_unlock(&queue_io_mutex); + BUG_ON(rp->offset); + return 0; + } + rq = container_of(rp->q.list.next, struct cache_request, q.list); + BUG_ON(rq->q.reader); + if (rp->offset == 0) + rq->readers++; + spin_unlock(&queue_lock); + + if (rp->offset == 0 && !test_bit(CACHE_PENDING, &rq->item->flags)) { + err = -EAGAIN; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } else { + if (rp->offset + count > rq->len) + count = rq->len - rp->offset; + err = -EFAULT; + if (copy_to_user(buf, rq->buf + rp->offset, count)) + goto out; + rp->offset += count; + if (rp->offset >= rq->len) { + rp->offset = 0; + spin_lock(&queue_lock); + list_move(&rp->q.list, &rq->q.list); + spin_unlock(&queue_lock); + } + err = 0; + } + out: + if (rp->offset == 0) { + /* need to release rq */ + spin_lock(&queue_lock); + rq->readers--; + if (rq->readers == 0 && + !test_bit(CACHE_PENDING, &rq->item->flags)) { + list_del(&rq->q.list); + spin_unlock(&queue_lock); + cache_put(rq->item, cd); + kfree(rq->buf); + kfree(rq); + } else + spin_unlock(&queue_lock); + } + if (err == -EAGAIN) + goto again; + mutex_unlock(&queue_io_mutex); + return err ? err : count; +} + +static char write_buf[8192]; /* protected by queue_io_mutex */ + +static ssize_t +cache_write(struct file *filp, const char __user *buf, size_t count, + loff_t *ppos) +{ + int err; + struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + + if (count == 0) + return 0; + if (count >= sizeof(write_buf)) + return -EINVAL; + + mutex_lock(&queue_io_mutex); + + if (copy_from_user(write_buf, buf, count)) { + mutex_unlock(&queue_io_mutex); + return -EFAULT; + } + write_buf[count] = '\0'; + if (cd->cache_parse) + err = cd->cache_parse(cd, write_buf, count); + else + err = -EINVAL; + + mutex_unlock(&queue_io_mutex); + return err ? err : count; +} + +static DECLARE_WAIT_QUEUE_HEAD(queue_wait); + +static unsigned int +cache_poll(struct file *filp, poll_table *wait) +{ + unsigned int mask; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + struct cache_detail *cd = PDE(filp->f_path.dentry->d_inode)->data; + + poll_wait(filp, &queue_wait, wait); + + /* alway allow write */ + mask = POLL_OUT | POLLWRNORM; + + if (!rp) + return mask; + + spin_lock(&queue_lock); + + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + mask |= POLLIN | POLLRDNORM; + break; + } + spin_unlock(&queue_lock); + return mask; +} + +static int +cache_ioctl(struct inode *ino, struct file *filp, + unsigned int cmd, unsigned long arg) +{ + int len = 0; + struct cache_reader *rp = filp->private_data; + struct cache_queue *cq; + struct cache_detail *cd = PDE(ino)->data; + + if (cmd != FIONREAD || !rp) + return -EINVAL; + + spin_lock(&queue_lock); + + /* only find the length remaining in current request, + * or the length of the next request + */ + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + struct cache_request *cr = + container_of(cq, struct cache_request, q); + len = cr->len - rp->offset; + break; + } + spin_unlock(&queue_lock); + + return put_user(len, (int __user *)arg); +} + +static int +cache_open(struct inode *inode, struct file *filp) +{ + struct cache_reader *rp = NULL; + + nonseekable_open(inode, filp); + if (filp->f_mode & FMODE_READ) { + struct cache_detail *cd = PDE(inode)->data; + + rp = kmalloc(sizeof(*rp), GFP_KERNEL); + if (!rp) + return -ENOMEM; + rp->offset = 0; + rp->q.reader = 1; + atomic_inc(&cd->readers); + spin_lock(&queue_lock); + list_add(&rp->q.list, &cd->queue); + spin_unlock(&queue_lock); + } + filp->private_data = rp; + return 0; +} + +static int +cache_release(struct inode *inode, struct file *filp) +{ + struct cache_reader *rp = filp->private_data; + struct cache_detail *cd = PDE(inode)->data; + + if (rp) { + spin_lock(&queue_lock); + if (rp->offset) { + struct cache_queue *cq; + for (cq= &rp->q; &cq->list != &cd->queue; + cq = list_entry(cq->list.next, struct cache_queue, list)) + if (!cq->reader) { + container_of(cq, struct cache_request, q) + ->readers--; + break; + } + rp->offset = 0; + } + list_del(&rp->q.list); + spin_unlock(&queue_lock); + + filp->private_data = NULL; + kfree(rp); + + cd->last_close = get_seconds(); + atomic_dec(&cd->readers); + } + return 0; +} + + + +static const struct file_operations cache_file_operations = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = cache_read, + .write = cache_write, + .poll = cache_poll, + .ioctl = cache_ioctl, /* for FIONREAD */ + .open = cache_open, + .release = cache_release, +}; + + +static void queue_loose(struct cache_detail *detail, struct cache_head *ch) +{ + struct cache_queue *cq; + spin_lock(&queue_lock); + list_for_each_entry(cq, &detail->queue, list) + if (!cq->reader) { + struct cache_request *cr = container_of(cq, struct cache_request, q); + if (cr->item != ch) + continue; + if (cr->readers != 0) + continue; + list_del(&cr->q.list); + spin_unlock(&queue_lock); + cache_put(cr->item, detail); + kfree(cr->buf); + kfree(cr); + return; + } + spin_unlock(&queue_lock); +} + +/* + * Support routines for text-based upcalls. + * Fields are separated by spaces. + * Fields are either mangled to quote space tab newline slosh with slosh + * or a hexified with a leading \x + * Record is terminated with newline. + * + */ + +void qword_add(char **bpp, int *lp, char *str) +{ + char *bp = *bpp; + int len = *lp; + char c; + + if (len < 0) return; + + while ((c=*str++) && len) + switch(c) { + case ' ': + case '\t': + case '\n': + case '\\': + if (len >= 4) { + *bp++ = '\\'; + *bp++ = '0' + ((c & 0300)>>6); + *bp++ = '0' + ((c & 0070)>>3); + *bp++ = '0' + ((c & 0007)>>0); + } + len -= 4; + break; + default: + *bp++ = c; + len--; + } + if (c || len <1) len = -1; + else { + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL(qword_add); + +void qword_addhex(char **bpp, int *lp, char *buf, int blen) +{ + char *bp = *bpp; + int len = *lp; + + if (len < 0) return; + + if (len > 2) { + *bp++ = '\\'; + *bp++ = 'x'; + len -= 2; + while (blen && len >= 2) { + unsigned char c = *buf++; + *bp++ = '0' + ((c&0xf0)>>4) + (c>=0xa0)*('a'-'9'-1); + *bp++ = '0' + (c&0x0f) + ((c&0x0f)>=0x0a)*('a'-'9'-1); + len -= 2; + blen--; + } + } + if (blen || len<1) len = -1; + else { + *bp++ = ' '; + len--; + } + *bpp = bp; + *lp = len; +} +EXPORT_SYMBOL(qword_addhex); + +static void warn_no_listener(struct cache_detail *detail) +{ + if (detail->last_warn != detail->last_close) { + detail->last_warn = detail->last_close; + if (detail->warn_no_listener) + detail->warn_no_listener(detail); + } +} + +/* + * register an upcall request to user-space. + * Each request is at most one page long. + */ +static int cache_make_upcall(struct cache_detail *detail, struct cache_head *h) +{ + + char *buf; + struct cache_request *crq; + char *bp; + int len; + + if (detail->cache_request == NULL) + return -EINVAL; + + if (atomic_read(&detail->readers) == 0 && + detail->last_close < get_seconds() - 30) { + warn_no_listener(detail); + return -EINVAL; + } + + buf = kmalloc(PAGE_SIZE, GFP_KERNEL); + if (!buf) + return -EAGAIN; + + crq = kmalloc(sizeof (*crq), GFP_KERNEL); + if (!crq) { + kfree(buf); + return -EAGAIN; + } + + bp = buf; len = PAGE_SIZE; + + detail->cache_request(detail, h, &bp, &len); + + if (len < 0) { + kfree(buf); + kfree(crq); + return -EAGAIN; + } + crq->q.reader = 0; + crq->item = cache_get(h); + crq->buf = buf; + crq->len = PAGE_SIZE - len; + crq->readers = 0; + spin_lock(&queue_lock); + list_add_tail(&crq->q.list, &detail->queue); + spin_unlock(&queue_lock); + wake_up(&queue_wait); + return 0; +} + +/* + * parse a message from user-space and pass it + * to an appropriate cache + * Messages are, like requests, separated into fields by + * spaces and dequotes as \xHEXSTRING or embedded \nnn octal + * + * Message is + * reply cachename expiry key ... content.... + * + * key and content are both parsed by cache + */ + +#define isodigit(c) (isdigit(c) && c <= '7') +int qword_get(char **bpp, char *dest, int bufsize) +{ + /* return bytes copied, or -1 on error */ + char *bp = *bpp; + int len = 0; + + while (*bp == ' ') bp++; + + if (bp[0] == '\\' && bp[1] == 'x') { + /* HEX STRING */ + bp += 2; + while (isxdigit(bp[0]) && isxdigit(bp[1]) && len < bufsize) { + int byte = isdigit(*bp) ? *bp-'0' : toupper(*bp)-'A'+10; + bp++; + byte <<= 4; + byte |= isdigit(*bp) ? *bp-'0' : toupper(*bp)-'A'+10; + *dest++ = byte; + bp++; + len++; + } + } else { + /* text with \nnn octal quoting */ + while (*bp != ' ' && *bp != '\n' && *bp && len < bufsize-1) { + if (*bp == '\\' && + isodigit(bp[1]) && (bp[1] <= '3') && + isodigit(bp[2]) && + isodigit(bp[3])) { + int byte = (*++bp -'0'); + bp++; + byte = (byte << 3) | (*bp++ - '0'); + byte = (byte << 3) | (*bp++ - '0'); + *dest++ = byte; + len++; + } else { + *dest++ = *bp++; + len++; + } + } + } + + if (*bp != ' ' && *bp != '\n' && *bp != '\0') + return -1; + while (*bp == ' ') bp++; + *bpp = bp; + *dest = '\0'; + return len; +} +EXPORT_SYMBOL(qword_get); + + +/* + * support /proc/sunrpc/cache/$CACHENAME/content + * as a seqfile. + * We call ->cache_show passing NULL for the item to + * get a header, then pass each real item in the cache + */ + +struct handle { + struct cache_detail *cd; +}; + +static void *c_start(struct seq_file *m, loff_t *pos) + __acquires(cd->hash_lock) +{ + loff_t n = *pos; + unsigned hash, entry; + struct cache_head *ch; + struct cache_detail *cd = ((struct handle*)m->private)->cd; + + + read_lock(&cd->hash_lock); + if (!n--) + return SEQ_START_TOKEN; + hash = n >> 32; + entry = n & ((1LL<<32) - 1); + + for (ch=cd->hash_table[hash]; ch; ch=ch->next) + if (!entry--) + return ch; + n &= ~((1LL<<32) - 1); + do { + hash++; + n += 1LL<<32; + } while(hash < cd->hash_size && + cd->hash_table[hash]==NULL); + if (hash >= cd->hash_size) + return NULL; + *pos = n+1; + return cd->hash_table[hash]; +} + +static void *c_next(struct seq_file *m, void *p, loff_t *pos) +{ + struct cache_head *ch = p; + int hash = (*pos >> 32); + struct cache_detail *cd = ((struct handle*)m->private)->cd; + + if (p == SEQ_START_TOKEN) + hash = 0; + else if (ch->next == NULL) { + hash++; + *pos += 1LL<<32; + } else { + ++*pos; + return ch->next; + } + *pos &= ~((1LL<<32) - 1); + while (hash < cd->hash_size && + cd->hash_table[hash] == NULL) { + hash++; + *pos += 1LL<<32; + } + if (hash >= cd->hash_size) + return NULL; + ++*pos; + return cd->hash_table[hash]; +} + +static void c_stop(struct seq_file *m, void *p) + __releases(cd->hash_lock) +{ + struct cache_detail *cd = ((struct handle*)m->private)->cd; + read_unlock(&cd->hash_lock); +} + +static int c_show(struct seq_file *m, void *p) +{ + struct cache_head *cp = p; + struct cache_detail *cd = ((struct handle*)m->private)->cd; + + if (p == SEQ_START_TOKEN) + return cd->cache_show(m, cd, NULL); + + ifdebug(CACHE) + seq_printf(m, "# expiry=%ld refcnt=%d flags=%lx\n", + cp->expiry_time, atomic_read(&cp->ref.refcount), cp->flags); + cache_get(cp); + if (cache_check(cd, cp, NULL)) + /* cache_check does a cache_put on failure */ + seq_printf(m, "# "); + else + cache_put(cp, cd); + + return cd->cache_show(m, cd, cp); +} + +static const struct seq_operations cache_content_op = { + .start = c_start, + .next = c_next, + .stop = c_stop, + .show = c_show, +}; + +static int content_open(struct inode *inode, struct file *file) +{ + struct handle *han; + struct cache_detail *cd = PDE(inode)->data; + + han = __seq_open_private(file, &cache_content_op, sizeof(*han)); + if (han == NULL) + return -ENOMEM; + + han->cd = cd; + return 0; +} + +static const struct file_operations content_file_operations = { + .open = content_open, + .read = seq_read, + .llseek = seq_lseek, + .release = seq_release_private, +}; + +static ssize_t read_flush(struct file *file, char __user *buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE(file->f_path.dentry->d_inode)->data; + char tbuf[20]; + unsigned long p = *ppos; + size_t len; + + sprintf(tbuf, "%lu\n", cd->flush_time); + len = strlen(tbuf); + if (p >= len) + return 0; + len -= p; + if (len > count) + len = count; + if (copy_to_user(buf, (void*)(tbuf+p), len)) + return -EFAULT; + *ppos += len; + return len; +} + +static ssize_t write_flush(struct file * file, const char __user * buf, + size_t count, loff_t *ppos) +{ + struct cache_detail *cd = PDE(file->f_path.dentry->d_inode)->data; + char tbuf[20]; + char *ep; + long flushtime; + if (*ppos || count > sizeof(tbuf)-1) + return -EINVAL; + if (copy_from_user(tbuf, buf, count)) + return -EFAULT; + tbuf[count] = 0; + flushtime = simple_strtoul(tbuf, &ep, 0); + if (*ep && *ep != '\n') + return -EINVAL; + + cd->flush_time = flushtime; + cd->nextcheck = get_seconds(); + cache_flush(); + + *ppos += count; + return count; +} + +static const struct file_operations cache_flush_operations = { + .open = nonseekable_open, + .read = read_flush, + .write = write_flush, +}; diff --git a/net/sunrpc/clnt.c b/net/sunrpc/clnt.c new file mode 100644 index 0000000..4895c34 --- /dev/null +++ b/net/sunrpc/clnt.c @@ -0,0 +1,1583 @@ +/* + * linux/net/sunrpc/clnt.c + * + * This file contains the high-level RPC interface. + * It is modeled as a finite state machine to support both synchronous + * and asynchronous requests. + * + * - RPC header generation and argument serialization. + * - Credential refresh. + * - TCP connect handling. + * - Retry of operation when it is suspected the operation failed because + * of uid squashing on the server, or when the credentials were stale + * and need to be refreshed, or when a packet was damaged in transit. + * This may be have to be moved to the VFS layer. + * + * NB: BSD uses a more intelligent approach to guessing when a request + * or reply has been lost by keeping the RTO estimate for each procedure. + * We currently make do with a constant timeout value. + * + * Copyright (C) 1992,1993 Rick Sladkey <jrs@world.std.com> + * Copyright (C) 1995,1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <asm/system.h> + +#include <linux/module.h> +#include <linux/types.h> +#include <linux/kallsyms.h> +#include <linux/mm.h> +#include <linux/slab.h> +#include <linux/smp_lock.h> +#include <linux/utsname.h> +#include <linux/workqueue.h> +#include <linux/in6.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/metrics.h> + + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_CALL +#endif + +#define dprint_status(t) \ + dprintk("RPC: %5u %s (status %d)\n", t->tk_pid, \ + __func__, t->tk_status) + +/* + * All RPC clients are linked into this list + */ +static LIST_HEAD(all_clients); +static DEFINE_SPINLOCK(rpc_client_lock); + +static DECLARE_WAIT_QUEUE_HEAD(destroy_wait); + + +static void call_start(struct rpc_task *task); +static void call_reserve(struct rpc_task *task); +static void call_reserveresult(struct rpc_task *task); +static void call_allocate(struct rpc_task *task); +static void call_decode(struct rpc_task *task); +static void call_bind(struct rpc_task *task); +static void call_bind_status(struct rpc_task *task); +static void call_transmit(struct rpc_task *task); +static void call_status(struct rpc_task *task); +static void call_transmit_status(struct rpc_task *task); +static void call_refresh(struct rpc_task *task); +static void call_refreshresult(struct rpc_task *task); +static void call_timeout(struct rpc_task *task); +static void call_connect(struct rpc_task *task); +static void call_connect_status(struct rpc_task *task); + +static __be32 *rpc_encode_header(struct rpc_task *task); +static __be32 *rpc_verify_header(struct rpc_task *task); +static int rpc_ping(struct rpc_clnt *clnt, int flags); + +static void rpc_register_client(struct rpc_clnt *clnt) +{ + spin_lock(&rpc_client_lock); + list_add(&clnt->cl_clients, &all_clients); + spin_unlock(&rpc_client_lock); +} + +static void rpc_unregister_client(struct rpc_clnt *clnt) +{ + spin_lock(&rpc_client_lock); + list_del(&clnt->cl_clients); + spin_unlock(&rpc_client_lock); +} + +static int +rpc_setup_pipedir(struct rpc_clnt *clnt, char *dir_name) +{ + static uint32_t clntid; + int error; + + clnt->cl_vfsmnt = ERR_PTR(-ENOENT); + clnt->cl_dentry = ERR_PTR(-ENOENT); + if (dir_name == NULL) + return 0; + + clnt->cl_vfsmnt = rpc_get_mount(); + if (IS_ERR(clnt->cl_vfsmnt)) + return PTR_ERR(clnt->cl_vfsmnt); + + for (;;) { + snprintf(clnt->cl_pathname, sizeof(clnt->cl_pathname), + "%s/clnt%x", dir_name, + (unsigned int)clntid++); + clnt->cl_pathname[sizeof(clnt->cl_pathname) - 1] = '\0'; + clnt->cl_dentry = rpc_mkdir(clnt->cl_pathname, clnt); + if (!IS_ERR(clnt->cl_dentry)) + return 0; + error = PTR_ERR(clnt->cl_dentry); + if (error != -EEXIST) { + printk(KERN_INFO "RPC: Couldn't create pipefs entry %s, error %d\n", + clnt->cl_pathname, error); + rpc_put_mount(); + return error; + } + } +} + +static struct rpc_clnt * rpc_new_client(const struct rpc_create_args *args, struct rpc_xprt *xprt) +{ + struct rpc_program *program = args->program; + struct rpc_version *version; + struct rpc_clnt *clnt = NULL; + struct rpc_auth *auth; + int err; + size_t len; + + /* sanity check the name before trying to print it */ + err = -EINVAL; + len = strlen(args->servername); + if (len > RPC_MAXNETNAMELEN) + goto out_no_rpciod; + len++; + + dprintk("RPC: creating %s client for %s (xprt %p)\n", + program->name, args->servername, xprt); + + err = rpciod_up(); + if (err) + goto out_no_rpciod; + err = -EINVAL; + if (!xprt) + goto out_no_xprt; + + if (args->version >= program->nrvers) + goto out_err; + version = program->version[args->version]; + if (version == NULL) + goto out_err; + + err = -ENOMEM; + clnt = kzalloc(sizeof(*clnt), GFP_KERNEL); + if (!clnt) + goto out_err; + clnt->cl_parent = clnt; + + clnt->cl_server = clnt->cl_inline_name; + if (len > sizeof(clnt->cl_inline_name)) { + char *buf = kmalloc(len, GFP_KERNEL); + if (buf != NULL) + clnt->cl_server = buf; + else + len = sizeof(clnt->cl_inline_name); + } + strlcpy(clnt->cl_server, args->servername, len); + + clnt->cl_xprt = xprt; + clnt->cl_procinfo = version->procs; + clnt->cl_maxproc = version->nrprocs; + clnt->cl_protname = program->name; + clnt->cl_prog = args->prognumber ? : program->number; + clnt->cl_vers = version->number; + clnt->cl_stats = program->stats; + clnt->cl_metrics = rpc_alloc_iostats(clnt); + err = -ENOMEM; + if (clnt->cl_metrics == NULL) + goto out_no_stats; + clnt->cl_program = program; + INIT_LIST_HEAD(&clnt->cl_tasks); + spin_lock_init(&clnt->cl_lock); + + if (!xprt_bound(clnt->cl_xprt)) + clnt->cl_autobind = 1; + + clnt->cl_timeout = xprt->timeout; + if (args->timeout != NULL) { + memcpy(&clnt->cl_timeout_default, args->timeout, + sizeof(clnt->cl_timeout_default)); + clnt->cl_timeout = &clnt->cl_timeout_default; + } + + clnt->cl_rtt = &clnt->cl_rtt_default; + rpc_init_rtt(&clnt->cl_rtt_default, clnt->cl_timeout->to_initval); + + kref_init(&clnt->cl_kref); + + err = rpc_setup_pipedir(clnt, program->pipe_dir_name); + if (err < 0) + goto out_no_path; + + auth = rpcauth_create(args->authflavor, clnt); + if (IS_ERR(auth)) { + printk(KERN_INFO "RPC: Couldn't create auth handle (flavor %u)\n", + args->authflavor); + err = PTR_ERR(auth); + goto out_no_auth; + } + + /* save the nodename */ + clnt->cl_nodelen = strlen(init_utsname()->nodename); + if (clnt->cl_nodelen > UNX_MAXNODENAME) + clnt->cl_nodelen = UNX_MAXNODENAME; + memcpy(clnt->cl_nodename, init_utsname()->nodename, clnt->cl_nodelen); + rpc_register_client(clnt); + return clnt; + +out_no_auth: + if (!IS_ERR(clnt->cl_dentry)) { + rpc_rmdir(clnt->cl_dentry); + rpc_put_mount(); + } +out_no_path: + rpc_free_iostats(clnt->cl_metrics); +out_no_stats: + if (clnt->cl_server != clnt->cl_inline_name) + kfree(clnt->cl_server); + kfree(clnt); +out_err: + xprt_put(xprt); +out_no_xprt: + rpciod_down(); +out_no_rpciod: + return ERR_PTR(err); +} + +/* + * rpc_create - create an RPC client and transport with one call + * @args: rpc_clnt create argument structure + * + * Creates and initializes an RPC transport and an RPC client. + * + * It can ping the server in order to determine if it is up, and to see if + * it supports this program and version. RPC_CLNT_CREATE_NOPING disables + * this behavior so asynchronous tasks can also use rpc_create. + */ +struct rpc_clnt *rpc_create(struct rpc_create_args *args) +{ + struct rpc_xprt *xprt; + struct rpc_clnt *clnt; + struct xprt_create xprtargs = { + .ident = args->protocol, + .srcaddr = args->saddress, + .dstaddr = args->address, + .addrlen = args->addrsize, + }; + char servername[48]; + + /* + * If the caller chooses not to specify a hostname, whip + * up a string representation of the passed-in address. + */ + if (args->servername == NULL) { + servername[0] = '\0'; + switch (args->address->sa_family) { + case AF_INET: { + struct sockaddr_in *sin = + (struct sockaddr_in *)args->address; + snprintf(servername, sizeof(servername), NIPQUAD_FMT, + NIPQUAD(sin->sin_addr.s_addr)); + break; + } + case AF_INET6: { + struct sockaddr_in6 *sin = + (struct sockaddr_in6 *)args->address; + snprintf(servername, sizeof(servername), NIP6_FMT, + NIP6(sin->sin6_addr)); + break; + } + default: + /* caller wants default server name, but + * address family isn't recognized. */ + return ERR_PTR(-EINVAL); + } + args->servername = servername; + } + + xprt = xprt_create_transport(&xprtargs); + if (IS_ERR(xprt)) + return (struct rpc_clnt *)xprt; + + /* + * By default, kernel RPC client connects from a reserved port. + * CAP_NET_BIND_SERVICE will not be set for unprivileged requesters, + * but it is always enabled for rpciod, which handles the connect + * operation. + */ + xprt->resvport = 1; + if (args->flags & RPC_CLNT_CREATE_NONPRIVPORT) + xprt->resvport = 0; + + clnt = rpc_new_client(args, xprt); + if (IS_ERR(clnt)) + return clnt; + + if (!(args->flags & RPC_CLNT_CREATE_NOPING)) { + int err = rpc_ping(clnt, RPC_TASK_SOFT); + if (err != 0) { + rpc_shutdown_client(clnt); + return ERR_PTR(err); + } + } + + clnt->cl_softrtry = 1; + if (args->flags & RPC_CLNT_CREATE_HARDRTRY) + clnt->cl_softrtry = 0; + + if (args->flags & RPC_CLNT_CREATE_AUTOBIND) + clnt->cl_autobind = 1; + if (args->flags & RPC_CLNT_CREATE_DISCRTRY) + clnt->cl_discrtry = 1; + if (!(args->flags & RPC_CLNT_CREATE_QUIET)) + clnt->cl_chatty = 1; + + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_create); + +/* + * This function clones the RPC client structure. It allows us to share the + * same transport while varying parameters such as the authentication + * flavour. + */ +struct rpc_clnt * +rpc_clone_client(struct rpc_clnt *clnt) +{ + struct rpc_clnt *new; + int err = -ENOMEM; + + new = kmemdup(clnt, sizeof(*new), GFP_KERNEL); + if (!new) + goto out_no_clnt; + new->cl_parent = clnt; + /* Turn off autobind on clones */ + new->cl_autobind = 0; + INIT_LIST_HEAD(&new->cl_tasks); + spin_lock_init(&new->cl_lock); + rpc_init_rtt(&new->cl_rtt_default, clnt->cl_timeout->to_initval); + new->cl_metrics = rpc_alloc_iostats(clnt); + if (new->cl_metrics == NULL) + goto out_no_stats; + kref_init(&new->cl_kref); + err = rpc_setup_pipedir(new, clnt->cl_program->pipe_dir_name); + if (err != 0) + goto out_no_path; + if (new->cl_auth) + atomic_inc(&new->cl_auth->au_count); + xprt_get(clnt->cl_xprt); + kref_get(&clnt->cl_kref); + rpc_register_client(new); + rpciod_up(); + return new; +out_no_path: + rpc_free_iostats(new->cl_metrics); +out_no_stats: + kfree(new); +out_no_clnt: + dprintk("RPC: %s: returned error %d\n", __func__, err); + return ERR_PTR(err); +} +EXPORT_SYMBOL_GPL(rpc_clone_client); + +/* + * Properly shut down an RPC client, terminating all outstanding + * requests. + */ +void rpc_shutdown_client(struct rpc_clnt *clnt) +{ + dprintk("RPC: shutting down %s client for %s\n", + clnt->cl_protname, clnt->cl_server); + + while (!list_empty(&clnt->cl_tasks)) { + rpc_killall_tasks(clnt); + wait_event_timeout(destroy_wait, + list_empty(&clnt->cl_tasks), 1*HZ); + } + + rpc_release_client(clnt); +} +EXPORT_SYMBOL_GPL(rpc_shutdown_client); + +/* + * Free an RPC client + */ +static void +rpc_free_client(struct kref *kref) +{ + struct rpc_clnt *clnt = container_of(kref, struct rpc_clnt, cl_kref); + + dprintk("RPC: destroying %s client for %s\n", + clnt->cl_protname, clnt->cl_server); + if (!IS_ERR(clnt->cl_dentry)) { + rpc_rmdir(clnt->cl_dentry); + rpc_put_mount(); + } + if (clnt->cl_parent != clnt) { + rpc_release_client(clnt->cl_parent); + goto out_free; + } + if (clnt->cl_server != clnt->cl_inline_name) + kfree(clnt->cl_server); +out_free: + rpc_unregister_client(clnt); + rpc_free_iostats(clnt->cl_metrics); + clnt->cl_metrics = NULL; + xprt_put(clnt->cl_xprt); + rpciod_down(); + kfree(clnt); +} + +/* + * Free an RPC client + */ +static void +rpc_free_auth(struct kref *kref) +{ + struct rpc_clnt *clnt = container_of(kref, struct rpc_clnt, cl_kref); + + if (clnt->cl_auth == NULL) { + rpc_free_client(kref); + return; + } + + /* + * Note: RPCSEC_GSS may need to send NULL RPC calls in order to + * release remaining GSS contexts. This mechanism ensures + * that it can do so safely. + */ + kref_init(kref); + rpcauth_release(clnt->cl_auth); + clnt->cl_auth = NULL; + kref_put(kref, rpc_free_client); +} + +/* + * Release reference to the RPC client + */ +void +rpc_release_client(struct rpc_clnt *clnt) +{ + dprintk("RPC: rpc_release_client(%p)\n", clnt); + + if (list_empty(&clnt->cl_tasks)) + wake_up(&destroy_wait); + kref_put(&clnt->cl_kref, rpc_free_auth); +} + +/** + * rpc_bind_new_program - bind a new RPC program to an existing client + * @old: old rpc_client + * @program: rpc program to set + * @vers: rpc program version + * + * Clones the rpc client and sets up a new RPC program. This is mainly + * of use for enabling different RPC programs to share the same transport. + * The Sun NFSv2/v3 ACL protocol can do this. + */ +struct rpc_clnt *rpc_bind_new_program(struct rpc_clnt *old, + struct rpc_program *program, + u32 vers) +{ + struct rpc_clnt *clnt; + struct rpc_version *version; + int err; + + BUG_ON(vers >= program->nrvers || !program->version[vers]); + version = program->version[vers]; + clnt = rpc_clone_client(old); + if (IS_ERR(clnt)) + goto out; + clnt->cl_procinfo = version->procs; + clnt->cl_maxproc = version->nrprocs; + clnt->cl_protname = program->name; + clnt->cl_prog = program->number; + clnt->cl_vers = version->number; + clnt->cl_stats = program->stats; + err = rpc_ping(clnt, RPC_TASK_SOFT); + if (err != 0) { + rpc_shutdown_client(clnt); + clnt = ERR_PTR(err); + } +out: + return clnt; +} +EXPORT_SYMBOL_GPL(rpc_bind_new_program); + +/* + * Default callback for async RPC calls + */ +static void +rpc_default_callback(struct rpc_task *task, void *data) +{ +} + +static const struct rpc_call_ops rpc_default_ops = { + .rpc_call_done = rpc_default_callback, +}; + +/** + * rpc_run_task - Allocate a new RPC task, then run rpc_execute against it + * @task_setup_data: pointer to task initialisation data + */ +struct rpc_task *rpc_run_task(const struct rpc_task_setup *task_setup_data) +{ + struct rpc_task *task, *ret; + + task = rpc_new_task(task_setup_data); + if (task == NULL) { + rpc_release_calldata(task_setup_data->callback_ops, + task_setup_data->callback_data); + ret = ERR_PTR(-ENOMEM); + goto out; + } + + if (task->tk_status != 0) { + ret = ERR_PTR(task->tk_status); + rpc_put_task(task); + goto out; + } + atomic_inc(&task->tk_count); + rpc_execute(task); + ret = task; +out: + return ret; +} +EXPORT_SYMBOL_GPL(rpc_run_task); + +/** + * rpc_call_sync - Perform a synchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + */ +int rpc_call_sync(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = &rpc_default_ops, + .flags = flags, + }; + int status; + + BUG_ON(flags & RPC_TASK_ASYNC); + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + status = task->tk_status; + rpc_put_task(task); + return status; +} +EXPORT_SYMBOL_GPL(rpc_call_sync); + +/** + * rpc_call_async - Perform an asynchronous RPC call + * @clnt: pointer to RPC client + * @msg: RPC call parameters + * @flags: RPC call flags + * @tk_ops: RPC call ops + * @data: user call data + */ +int +rpc_call_async(struct rpc_clnt *clnt, const struct rpc_message *msg, int flags, + const struct rpc_call_ops *tk_ops, void *data) +{ + struct rpc_task *task; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = msg, + .callback_ops = tk_ops, + .callback_data = data, + .flags = flags|RPC_TASK_ASYNC, + }; + + task = rpc_run_task(&task_setup_data); + if (IS_ERR(task)) + return PTR_ERR(task); + rpc_put_task(task); + return 0; +} +EXPORT_SYMBOL_GPL(rpc_call_async); + +void +rpc_call_start(struct rpc_task *task) +{ + task->tk_action = call_start; +} +EXPORT_SYMBOL_GPL(rpc_call_start); + +/** + * rpc_peeraddr - extract remote peer address from clnt's xprt + * @clnt: RPC client structure + * @buf: target buffer + * @bufsize: length of target buffer + * + * Returns the number of bytes that are actually in the stored address. + */ +size_t rpc_peeraddr(struct rpc_clnt *clnt, struct sockaddr *buf, size_t bufsize) +{ + size_t bytes; + struct rpc_xprt *xprt = clnt->cl_xprt; + + bytes = sizeof(xprt->addr); + if (bytes > bufsize) + bytes = bufsize; + memcpy(buf, &clnt->cl_xprt->addr, bytes); + return xprt->addrlen; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr); + +/** + * rpc_peeraddr2str - return remote peer address in printable format + * @clnt: RPC client structure + * @format: address format + * + */ +const char *rpc_peeraddr2str(struct rpc_clnt *clnt, + enum rpc_display_format_t format) +{ + struct rpc_xprt *xprt = clnt->cl_xprt; + + if (xprt->address_strings[format] != NULL) + return xprt->address_strings[format]; + else + return "unprintable"; +} +EXPORT_SYMBOL_GPL(rpc_peeraddr2str); + +void +rpc_setbufsize(struct rpc_clnt *clnt, unsigned int sndsize, unsigned int rcvsize) +{ + struct rpc_xprt *xprt = clnt->cl_xprt; + if (xprt->ops->set_buffer_size) + xprt->ops->set_buffer_size(xprt, sndsize, rcvsize); +} +EXPORT_SYMBOL_GPL(rpc_setbufsize); + +/* + * Return size of largest payload RPC client can support, in bytes + * + * For stream transports, this is one RPC record fragment (see RFC + * 1831), as we don't support multi-record requests yet. For datagram + * transports, this is the size of an IP packet minus the IP, UDP, and + * RPC header sizes. + */ +size_t rpc_max_payload(struct rpc_clnt *clnt) +{ + return clnt->cl_xprt->max_payload; +} +EXPORT_SYMBOL_GPL(rpc_max_payload); + +/** + * rpc_force_rebind - force transport to check that remote port is unchanged + * @clnt: client to rebind + * + */ +void rpc_force_rebind(struct rpc_clnt *clnt) +{ + if (clnt->cl_autobind) + xprt_clear_bound(clnt->cl_xprt); +} +EXPORT_SYMBOL_GPL(rpc_force_rebind); + +/* + * Restart an (async) RPC call. Usually called from within the + * exit handler. + */ +void +rpc_restart_call(struct rpc_task *task) +{ + if (RPC_ASSASSINATED(task)) + return; + + task->tk_action = call_start; +} +EXPORT_SYMBOL_GPL(rpc_restart_call); + +#ifdef RPC_DEBUG +static const char *rpc_proc_name(const struct rpc_task *task) +{ + const struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + + if (proc) { + if (proc->p_name) + return proc->p_name; + else + return "NULL"; + } else + return "no proc"; +} +#endif + +/* + * 0. Initial state + * + * Other FSM states can be visited zero or more times, but + * this state is visited exactly once for each RPC. + */ +static void +call_start(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + dprintk("RPC: %5u call_start %s%d proc %s (%s)\n", task->tk_pid, + clnt->cl_protname, clnt->cl_vers, + rpc_proc_name(task), + (RPC_IS_ASYNC(task) ? "async" : "sync")); + + /* Increment call count */ + task->tk_msg.rpc_proc->p_count++; + clnt->cl_stats->rpccnt++; + task->tk_action = call_reserve; +} + +/* + * 1. Reserve an RPC call slot + */ +static void +call_reserve(struct rpc_task *task) +{ + dprint_status(task); + + if (!rpcauth_uptodatecred(task)) { + task->tk_action = call_refresh; + return; + } + + task->tk_status = 0; + task->tk_action = call_reserveresult; + xprt_reserve(task); +} + +/* + * 1b. Grok the result of xprt_reserve() + */ +static void +call_reserveresult(struct rpc_task *task) +{ + int status = task->tk_status; + + dprint_status(task); + + /* + * After a call to xprt_reserve(), we must have either + * a request slot or else an error status. + */ + task->tk_status = 0; + if (status >= 0) { + if (task->tk_rqstp) { + task->tk_action = call_allocate; + return; + } + + printk(KERN_ERR "%s: status=%d, but no request slot, exiting\n", + __func__, status); + rpc_exit(task, -EIO); + return; + } + + /* + * Even though there was an error, we may have acquired + * a request slot somehow. Make sure not to leak it. + */ + if (task->tk_rqstp) { + printk(KERN_ERR "%s: status=%d, request allocated anyway\n", + __func__, status); + xprt_release(task); + } + + switch (status) { + case -EAGAIN: /* woken up; retry */ + task->tk_action = call_reserve; + return; + case -EIO: /* probably a shutdown */ + break; + default: + printk(KERN_ERR "%s: unrecognized error %d, exiting\n", + __func__, status); + break; + } + rpc_exit(task, status); +} + +/* + * 2. Allocate the buffer. For details, see sched.c:rpc_malloc. + * (Note: buffer memory is freed in xprt_release). + */ +static void +call_allocate(struct rpc_task *task) +{ + unsigned int slack = task->tk_msg.rpc_cred->cr_auth->au_cslack; + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_procinfo *proc = task->tk_msg.rpc_proc; + + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_bind; + + if (req->rq_buffer) + return; + + if (proc->p_proc != 0) { + BUG_ON(proc->p_arglen == 0); + if (proc->p_decode != NULL) + BUG_ON(proc->p_replen == 0); + } + + /* + * Calculate the size (in quads) of the RPC call + * and reply headers, and convert both values + * to byte sizes. + */ + req->rq_callsize = RPC_CALLHDRSIZE + (slack << 1) + proc->p_arglen; + req->rq_callsize <<= 2; + req->rq_rcvsize = RPC_REPHDRSIZE + slack + proc->p_replen; + req->rq_rcvsize <<= 2; + + req->rq_buffer = xprt->ops->buf_alloc(task, + req->rq_callsize + req->rq_rcvsize); + if (req->rq_buffer != NULL) + return; + + dprintk("RPC: %5u rpc_buffer allocation failed\n", task->tk_pid); + + if (RPC_IS_ASYNC(task) || !signalled()) { + task->tk_action = call_allocate; + rpc_delay(task, HZ>>4); + return; + } + + rpc_exit(task, -ERESTARTSYS); +} + +static inline int +rpc_task_need_encode(struct rpc_task *task) +{ + return task->tk_rqstp->rq_snd_buf.len == 0; +} + +static inline void +rpc_task_force_reencode(struct rpc_task *task) +{ + task->tk_rqstp->rq_snd_buf.len = 0; +} + +static inline void +rpc_xdr_buf_init(struct xdr_buf *buf, void *start, size_t len) +{ + buf->head[0].iov_base = start; + buf->head[0].iov_len = len; + buf->tail[0].iov_len = 0; + buf->page_len = 0; + buf->flags = 0; + buf->len = 0; + buf->buflen = len; +} + +/* + * 3. Encode arguments of an RPC call + */ +static void +rpc_xdr_encode(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + kxdrproc_t encode; + __be32 *p; + + dprint_status(task); + + rpc_xdr_buf_init(&req->rq_snd_buf, + req->rq_buffer, + req->rq_callsize); + rpc_xdr_buf_init(&req->rq_rcv_buf, + (char *)req->rq_buffer + req->rq_callsize, + req->rq_rcvsize); + + p = rpc_encode_header(task); + if (p == NULL) { + printk(KERN_INFO "RPC: couldn't encode RPC header, exit EIO\n"); + rpc_exit(task, -EIO); + return; + } + + encode = task->tk_msg.rpc_proc->p_encode; + if (encode == NULL) + return; + + task->tk_status = rpcauth_wrap_req(task, encode, req, p, + task->tk_msg.rpc_argp); +} + +/* + * 4. Get the server port number if not yet set + */ +static void +call_bind(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + dprint_status(task); + + task->tk_action = call_connect; + if (!xprt_bound(xprt)) { + task->tk_action = call_bind_status; + task->tk_timeout = xprt->bind_timeout; + xprt->ops->rpcbind(task); + } +} + +/* + * 4a. Sort out bind result + */ +static void +call_bind_status(struct rpc_task *task) +{ + int status = -EIO; + + if (task->tk_status >= 0) { + dprint_status(task); + task->tk_status = 0; + task->tk_action = call_connect; + return; + } + + switch (task->tk_status) { + case -ENOMEM: + dprintk("RPC: %5u rpcbind out of memory\n", task->tk_pid); + rpc_delay(task, HZ >> 2); + goto retry_timeout; + case -EACCES: + dprintk("RPC: %5u remote rpcbind: RPC program/version " + "unavailable\n", task->tk_pid); + /* fail immediately if this is an RPC ping */ + if (task->tk_msg.rpc_proc->p_proc == 0) { + status = -EOPNOTSUPP; + break; + } + rpc_delay(task, 3*HZ); + goto retry_timeout; + case -ETIMEDOUT: + dprintk("RPC: %5u rpcbind request timed out\n", + task->tk_pid); + goto retry_timeout; + case -EPFNOSUPPORT: + /* server doesn't support any rpcbind version we know of */ + dprintk("RPC: %5u remote rpcbind service unavailable\n", + task->tk_pid); + break; + case -EPROTONOSUPPORT: + dprintk("RPC: %5u remote rpcbind version unavailable, retrying\n", + task->tk_pid); + task->tk_status = 0; + task->tk_action = call_bind; + return; + default: + dprintk("RPC: %5u unrecognized rpcbind error (%d)\n", + task->tk_pid, -task->tk_status); + } + + rpc_exit(task, status); + return; + +retry_timeout: + task->tk_action = call_timeout; +} + +/* + * 4b. Connect to the RPC server + */ +static void +call_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + dprintk("RPC: %5u call_connect xprt %p %s connected\n", + task->tk_pid, xprt, + (xprt_connected(xprt) ? "is" : "is not")); + + task->tk_action = call_transmit; + if (!xprt_connected(xprt)) { + task->tk_action = call_connect_status; + if (task->tk_status < 0) + return; + xprt_connect(task); + } +} + +/* + * 4c. Sort out connect result + */ +static void +call_connect_status(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + int status = task->tk_status; + + dprint_status(task); + + task->tk_status = 0; + if (status >= 0) { + clnt->cl_stats->netreconn++; + task->tk_action = call_transmit; + return; + } + + /* Something failed: remote service port may have changed */ + rpc_force_rebind(clnt); + + switch (status) { + case -ENOTCONN: + case -EAGAIN: + task->tk_action = call_bind; + if (!RPC_IS_SOFT(task)) + return; + /* if soft mounted, test if we've timed out */ + case -ETIMEDOUT: + task->tk_action = call_timeout; + return; + } + rpc_exit(task, -EIO); +} + +/* + * 5. Transmit the RPC request, and wait for reply + */ +static void +call_transmit(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_action = call_status; + if (task->tk_status < 0) + return; + task->tk_status = xprt_prepare_transmit(task); + if (task->tk_status != 0) + return; + task->tk_action = call_transmit_status; + /* Encode here so that rpcsec_gss can use correct sequence number. */ + if (rpc_task_need_encode(task)) { + BUG_ON(task->tk_rqstp->rq_bytes_sent != 0); + rpc_xdr_encode(task); + /* Did the encode result in an error condition? */ + if (task->tk_status != 0) { + /* Was the error nonfatal? */ + if (task->tk_status == -EAGAIN) + rpc_delay(task, HZ >> 4); + else + rpc_exit(task, task->tk_status); + return; + } + } + xprt_transmit(task); + if (task->tk_status < 0) + return; + /* + * On success, ensure that we call xprt_end_transmit() before sleeping + * in order to allow access to the socket to other RPC requests. + */ + call_transmit_status(task); + if (task->tk_msg.rpc_proc->p_decode != NULL) + return; + task->tk_action = rpc_exit_task; + rpc_wake_up_queued_task(&task->tk_xprt->pending, task); +} + +/* + * 5a. Handle cleanup after a transmission + */ +static void +call_transmit_status(struct rpc_task *task) +{ + task->tk_action = call_status; + /* + * Special case: if we've been waiting on the socket's write_space() + * callback, then don't call xprt_end_transmit(). + */ + if (task->tk_status == -EAGAIN) + return; + xprt_end_transmit(task); + rpc_task_force_reencode(task); +} + +/* + * 6. Sort out the RPC call status + */ +static void +call_status(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + int status; + + if (req->rq_received > 0 && !req->rq_bytes_sent) + task->tk_status = req->rq_received; + + dprint_status(task); + + status = task->tk_status; + if (status >= 0) { + task->tk_action = call_decode; + return; + } + + task->tk_status = 0; + switch(status) { + case -EHOSTDOWN: + case -EHOSTUNREACH: + case -ENETUNREACH: + /* + * Delay any retries for 3 seconds, then handle as if it + * were a timeout. + */ + rpc_delay(task, 3*HZ); + case -ETIMEDOUT: + task->tk_action = call_timeout; + if (task->tk_client->cl_discrtry) + xprt_conditional_disconnect(task->tk_xprt, + req->rq_connect_cookie); + break; + case -ECONNREFUSED: + case -ENOTCONN: + rpc_force_rebind(clnt); + task->tk_action = call_bind; + break; + case -EAGAIN: + task->tk_action = call_transmit; + break; + case -EIO: + /* shutdown or soft timeout */ + rpc_exit(task, status); + break; + default: + if (clnt->cl_chatty) + printk("%s: RPC call returned error %d\n", + clnt->cl_protname, -status); + rpc_exit(task, status); + } +} + +/* + * 6a. Handle RPC timeout + * We do not release the request slot, so we keep using the + * same XID for all retransmits. + */ +static void +call_timeout(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + + if (xprt_adjust_timeout(task->tk_rqstp) == 0) { + dprintk("RPC: %5u call_timeout (minor)\n", task->tk_pid); + goto retry; + } + + dprintk("RPC: %5u call_timeout (major)\n", task->tk_pid); + task->tk_timeouts++; + + if (RPC_IS_SOFT(task)) { + if (clnt->cl_chatty) + printk(KERN_NOTICE "%s: server %s not responding, timed out\n", + clnt->cl_protname, clnt->cl_server); + rpc_exit(task, -EIO); + return; + } + + if (!(task->tk_flags & RPC_CALL_MAJORSEEN)) { + task->tk_flags |= RPC_CALL_MAJORSEEN; + if (clnt->cl_chatty) + printk(KERN_NOTICE "%s: server %s not responding, still trying\n", + clnt->cl_protname, clnt->cl_server); + } + rpc_force_rebind(clnt); + /* + * Did our request time out due to an RPCSEC_GSS out-of-sequence + * event? RFC2203 requires the server to drop all such requests. + */ + rpcauth_invalcred(task); + +retry: + clnt->cl_stats->rpcretrans++; + task->tk_action = call_bind; + task->tk_status = 0; +} + +/* + * 7. Decode the RPC reply + */ +static void +call_decode(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + kxdrproc_t decode = task->tk_msg.rpc_proc->p_decode; + __be32 *p; + + dprintk("RPC: %5u call_decode (status %d)\n", + task->tk_pid, task->tk_status); + + if (task->tk_flags & RPC_CALL_MAJORSEEN) { + if (clnt->cl_chatty) + printk(KERN_NOTICE "%s: server %s OK\n", + clnt->cl_protname, clnt->cl_server); + task->tk_flags &= ~RPC_CALL_MAJORSEEN; + } + + /* + * Ensure that we see all writes made by xprt_complete_rqst() + * before it changed req->rq_received. + */ + smp_rmb(); + req->rq_rcv_buf.len = req->rq_private_buf.len; + + /* Check that the softirq receive buffer is valid */ + WARN_ON(memcmp(&req->rq_rcv_buf, &req->rq_private_buf, + sizeof(req->rq_rcv_buf)) != 0); + + if (req->rq_rcv_buf.len < 12) { + if (!RPC_IS_SOFT(task)) { + task->tk_action = call_bind; + clnt->cl_stats->rpcretrans++; + goto out_retry; + } + dprintk("RPC: %s: too small RPC reply size (%d bytes)\n", + clnt->cl_protname, task->tk_status); + task->tk_action = call_timeout; + goto out_retry; + } + + p = rpc_verify_header(task); + if (IS_ERR(p)) { + if (p == ERR_PTR(-EAGAIN)) + goto out_retry; + return; + } + + task->tk_action = rpc_exit_task; + + if (decode) { + task->tk_status = rpcauth_unwrap_resp(task, decode, req, p, + task->tk_msg.rpc_resp); + } + dprintk("RPC: %5u call_decode result %d\n", task->tk_pid, + task->tk_status); + return; +out_retry: + task->tk_status = 0; + /* Note: rpc_verify_header() may have freed the RPC slot */ + if (task->tk_rqstp == req) { + req->rq_received = req->rq_rcv_buf.len = 0; + if (task->tk_client->cl_discrtry) + xprt_conditional_disconnect(task->tk_xprt, + req->rq_connect_cookie); + } +} + +/* + * 8. Refresh the credentials if rejected by the server + */ +static void +call_refresh(struct rpc_task *task) +{ + dprint_status(task); + + task->tk_action = call_refreshresult; + task->tk_status = 0; + task->tk_client->cl_stats->rpcauthrefresh++; + rpcauth_refreshcred(task); +} + +/* + * 8a. Process the results of a credential refresh + */ +static void +call_refreshresult(struct rpc_task *task) +{ + int status = task->tk_status; + + dprint_status(task); + + task->tk_status = 0; + task->tk_action = call_reserve; + if (status >= 0 && rpcauth_uptodatecred(task)) + return; + if (status == -EACCES) { + rpc_exit(task, -EACCES); + return; + } + task->tk_action = call_refresh; + if (status != -ETIMEDOUT) + rpc_delay(task, 3*HZ); + return; +} + +static __be32 * +rpc_encode_header(struct rpc_task *task) +{ + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rqst *req = task->tk_rqstp; + __be32 *p = req->rq_svec[0].iov_base; + + /* FIXME: check buffer size? */ + + p = xprt_skip_transport_header(task->tk_xprt, p); + *p++ = req->rq_xid; /* XID */ + *p++ = htonl(RPC_CALL); /* CALL */ + *p++ = htonl(RPC_VERSION); /* RPC version */ + *p++ = htonl(clnt->cl_prog); /* program number */ + *p++ = htonl(clnt->cl_vers); /* program version */ + *p++ = htonl(task->tk_msg.rpc_proc->p_proc); /* procedure */ + p = rpcauth_marshcred(task, p); + req->rq_slen = xdr_adjust_iovec(&req->rq_svec[0], p); + return p; +} + +static __be32 * +rpc_verify_header(struct rpc_task *task) +{ + struct kvec *iov = &task->tk_rqstp->rq_rcv_buf.head[0]; + int len = task->tk_rqstp->rq_rcv_buf.len >> 2; + __be32 *p = iov->iov_base; + u32 n; + int error = -EACCES; + + if ((task->tk_rqstp->rq_rcv_buf.len & 3) != 0) { + /* RFC-1014 says that the representation of XDR data must be a + * multiple of four bytes + * - if it isn't pointer subtraction in the NFS client may give + * undefined results + */ + dprintk("RPC: %5u %s: XDR representation not a multiple of" + " 4 bytes: 0x%x\n", task->tk_pid, __func__, + task->tk_rqstp->rq_rcv_buf.len); + goto out_eio; + } + if ((len -= 3) < 0) + goto out_overflow; + p += 1; /* skip XID */ + + if ((n = ntohl(*p++)) != RPC_REPLY) { + dprintk("RPC: %5u %s: not an RPC reply: %x\n", + task->tk_pid, __func__, n); + goto out_garbage; + } + if ((n = ntohl(*p++)) != RPC_MSG_ACCEPTED) { + if (--len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_AUTH_ERROR: + break; + case RPC_MISMATCH: + dprintk("RPC: %5u %s: RPC call version " + "mismatch!\n", + task->tk_pid, __func__); + error = -EPROTONOSUPPORT; + goto out_err; + default: + dprintk("RPC: %5u %s: RPC call rejected, " + "unknown error: %x\n", + task->tk_pid, __func__, n); + goto out_eio; + } + if (--len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_AUTH_REJECTEDCRED: + case RPC_AUTH_REJECTEDVERF: + case RPCSEC_GSS_CREDPROBLEM: + case RPCSEC_GSS_CTXPROBLEM: + if (!task->tk_cred_retry) + break; + task->tk_cred_retry--; + dprintk("RPC: %5u %s: retry stale creds\n", + task->tk_pid, __func__); + rpcauth_invalcred(task); + /* Ensure we obtain a new XID! */ + xprt_release(task); + task->tk_action = call_refresh; + goto out_retry; + case RPC_AUTH_BADCRED: + case RPC_AUTH_BADVERF: + /* possibly garbled cred/verf? */ + if (!task->tk_garb_retry) + break; + task->tk_garb_retry--; + dprintk("RPC: %5u %s: retry garbled creds\n", + task->tk_pid, __func__); + task->tk_action = call_bind; + goto out_retry; + case RPC_AUTH_TOOWEAK: + printk(KERN_NOTICE "RPC: server %s requires stronger " + "authentication.\n", task->tk_client->cl_server); + break; + default: + dprintk("RPC: %5u %s: unknown auth error: %x\n", + task->tk_pid, __func__, n); + error = -EIO; + } + dprintk("RPC: %5u %s: call rejected %d\n", + task->tk_pid, __func__, n); + goto out_err; + } + if (!(p = rpcauth_checkverf(task, p))) { + dprintk("RPC: %5u %s: auth check failed\n", + task->tk_pid, __func__); + goto out_garbage; /* bad verifier, retry */ + } + len = p - (__be32 *)iov->iov_base - 1; + if (len < 0) + goto out_overflow; + switch ((n = ntohl(*p++))) { + case RPC_SUCCESS: + return p; + case RPC_PROG_UNAVAIL: + dprintk("RPC: %5u %s: program %u is unsupported by server %s\n", + task->tk_pid, __func__, + (unsigned int)task->tk_client->cl_prog, + task->tk_client->cl_server); + error = -EPFNOSUPPORT; + goto out_err; + case RPC_PROG_MISMATCH: + dprintk("RPC: %5u %s: program %u, version %u unsupported by " + "server %s\n", task->tk_pid, __func__, + (unsigned int)task->tk_client->cl_prog, + (unsigned int)task->tk_client->cl_vers, + task->tk_client->cl_server); + error = -EPROTONOSUPPORT; + goto out_err; + case RPC_PROC_UNAVAIL: + dprintk("RPC: %5u %s: proc %s unsupported by program %u, " + "version %u on server %s\n", + task->tk_pid, __func__, + rpc_proc_name(task), + task->tk_client->cl_prog, + task->tk_client->cl_vers, + task->tk_client->cl_server); + error = -EOPNOTSUPP; + goto out_err; + case RPC_GARBAGE_ARGS: + dprintk("RPC: %5u %s: server saw garbage\n", + task->tk_pid, __func__); + break; /* retry */ + default: + dprintk("RPC: %5u %s: server accept status: %x\n", + task->tk_pid, __func__, n); + /* Also retry */ + } + +out_garbage: + task->tk_client->cl_stats->rpcgarbage++; + if (task->tk_garb_retry) { + task->tk_garb_retry--; + dprintk("RPC: %5u %s: retrying\n", + task->tk_pid, __func__); + task->tk_action = call_bind; +out_retry: + return ERR_PTR(-EAGAIN); + } +out_eio: + error = -EIO; +out_err: + rpc_exit(task, error); + dprintk("RPC: %5u %s: call failed with error %d\n", task->tk_pid, + __func__, error); + return ERR_PTR(error); +out_overflow: + dprintk("RPC: %5u %s: server reply was truncated.\n", task->tk_pid, + __func__); + goto out_garbage; +} + +static int rpcproc_encode_null(void *rqstp, __be32 *data, void *obj) +{ + return 0; +} + +static int rpcproc_decode_null(void *rqstp, __be32 *data, void *obj) +{ + return 0; +} + +static struct rpc_procinfo rpcproc_null = { + .p_encode = rpcproc_encode_null, + .p_decode = rpcproc_decode_null, +}; + +static int rpc_ping(struct rpc_clnt *clnt, int flags) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null, + }; + int err; + msg.rpc_cred = authnull_ops.lookup_cred(NULL, NULL, 0); + err = rpc_call_sync(clnt, &msg, flags); + put_rpccred(msg.rpc_cred); + return err; +} + +struct rpc_task *rpc_call_null(struct rpc_clnt *clnt, struct rpc_cred *cred, int flags) +{ + struct rpc_message msg = { + .rpc_proc = &rpcproc_null, + .rpc_cred = cred, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = clnt, + .rpc_message = &msg, + .callback_ops = &rpc_default_ops, + .flags = flags, + }; + return rpc_run_task(&task_setup_data); +} +EXPORT_SYMBOL_GPL(rpc_call_null); + +#ifdef RPC_DEBUG +static void rpc_show_header(void) +{ + printk(KERN_INFO "-pid- flgs status -client- --rqstp- " + "-timeout ---ops--\n"); +} + +static void rpc_show_task(const struct rpc_clnt *clnt, + const struct rpc_task *task) +{ + const char *rpc_waitq = "none"; + char *p, action[KSYM_SYMBOL_LEN]; + + if (RPC_IS_QUEUED(task)) + rpc_waitq = rpc_qname(task->tk_waitqueue); + + /* map tk_action pointer to a function name; then trim off + * the "+0x0 [sunrpc]" */ + sprint_symbol(action, (unsigned long)task->tk_action); + p = strchr(action, '+'); + if (p) + *p = '\0'; + + printk(KERN_INFO "%5u %04x %6d %8p %8p %8ld %8p %sv%u %s a:%s q:%s\n", + task->tk_pid, task->tk_flags, task->tk_status, + clnt, task->tk_rqstp, task->tk_timeout, task->tk_ops, + clnt->cl_protname, clnt->cl_vers, rpc_proc_name(task), + action, rpc_waitq); +} + +void rpc_show_tasks(void) +{ + struct rpc_clnt *clnt; + struct rpc_task *task; + int header = 0; + + spin_lock(&rpc_client_lock); + list_for_each_entry(clnt, &all_clients, cl_clients) { + spin_lock(&clnt->cl_lock); + list_for_each_entry(task, &clnt->cl_tasks, tk_task) { + if (!header) { + rpc_show_header(); + header++; + } + rpc_show_task(clnt, task); + } + spin_unlock(&clnt->cl_lock); + } + spin_unlock(&rpc_client_lock); +} +#endif diff --git a/net/sunrpc/rpc_pipe.c b/net/sunrpc/rpc_pipe.c new file mode 100644 index 0000000..23a2b8f --- /dev/null +++ b/net/sunrpc/rpc_pipe.c @@ -0,0 +1,942 @@ +/* + * net/sunrpc/rpc_pipe.c + * + * Userland/kernel interface for rpcauth_gss. + * Code shamelessly plagiarized from fs/nfsd/nfsctl.c + * and fs/sysfs/inode.c + * + * Copyright (c) 2002, Trond Myklebust <trond.myklebust@fys.uio.no> + * + */ +#include <linux/module.h> +#include <linux/slab.h> +#include <linux/string.h> +#include <linux/pagemap.h> +#include <linux/mount.h> +#include <linux/namei.h> +#include <linux/fsnotify.h> +#include <linux/kernel.h> + +#include <asm/ioctls.h> +#include <linux/fs.h> +#include <linux/poll.h> +#include <linux/wait.h> +#include <linux/seq_file.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> + +static struct vfsmount *rpc_mount __read_mostly; +static int rpc_mount_count; + +static struct file_system_type rpc_pipe_fs_type; + + +static struct kmem_cache *rpc_inode_cachep __read_mostly; + +#define RPC_UPCALL_TIMEOUT (30*HZ) + +static void rpc_purge_list(struct rpc_inode *rpci, struct list_head *head, + void (*destroy_msg)(struct rpc_pipe_msg *), int err) +{ + struct rpc_pipe_msg *msg; + + if (list_empty(head)) + return; + do { + msg = list_entry(head->next, struct rpc_pipe_msg, list); + list_del(&msg->list); + msg->errno = err; + destroy_msg(msg); + } while (!list_empty(head)); + wake_up(&rpci->waitq); +} + +static void +rpc_timeout_upcall_queue(struct work_struct *work) +{ + LIST_HEAD(free_list); + struct rpc_inode *rpci = + container_of(work, struct rpc_inode, queue_timeout.work); + struct inode *inode = &rpci->vfs_inode; + void (*destroy_msg)(struct rpc_pipe_msg *); + + spin_lock(&inode->i_lock); + if (rpci->ops == NULL) { + spin_unlock(&inode->i_lock); + return; + } + destroy_msg = rpci->ops->destroy_msg; + if (rpci->nreaders == 0) { + list_splice_init(&rpci->pipe, &free_list); + rpci->pipelen = 0; + } + spin_unlock(&inode->i_lock); + rpc_purge_list(rpci, &free_list, destroy_msg, -ETIMEDOUT); +} + +/** + * rpc_queue_upcall + * @inode: inode of upcall pipe on which to queue given message + * @msg: message to queue + * + * Call with an @inode created by rpc_mkpipe() to queue an upcall. + * A userspace process may then later read the upcall by performing a + * read on an open file for this inode. It is up to the caller to + * initialize the fields of @msg (other than @msg->list) appropriately. + */ +int +rpc_queue_upcall(struct inode *inode, struct rpc_pipe_msg *msg) +{ + struct rpc_inode *rpci = RPC_I(inode); + int res = -EPIPE; + + spin_lock(&inode->i_lock); + if (rpci->ops == NULL) + goto out; + if (rpci->nreaders) { + list_add_tail(&msg->list, &rpci->pipe); + rpci->pipelen += msg->len; + res = 0; + } else if (rpci->flags & RPC_PIPE_WAIT_FOR_OPEN) { + if (list_empty(&rpci->pipe)) + queue_delayed_work(rpciod_workqueue, + &rpci->queue_timeout, + RPC_UPCALL_TIMEOUT); + list_add_tail(&msg->list, &rpci->pipe); + rpci->pipelen += msg->len; + res = 0; + } +out: + spin_unlock(&inode->i_lock); + wake_up(&rpci->waitq); + return res; +} +EXPORT_SYMBOL(rpc_queue_upcall); + +static inline void +rpc_inode_setowner(struct inode *inode, void *private) +{ + RPC_I(inode)->private = private; +} + +static void +rpc_close_pipes(struct inode *inode) +{ + struct rpc_inode *rpci = RPC_I(inode); + struct rpc_pipe_ops *ops; + + mutex_lock(&inode->i_mutex); + ops = rpci->ops; + if (ops != NULL) { + LIST_HEAD(free_list); + + spin_lock(&inode->i_lock); + rpci->nreaders = 0; + list_splice_init(&rpci->in_upcall, &free_list); + list_splice_init(&rpci->pipe, &free_list); + rpci->pipelen = 0; + rpci->ops = NULL; + spin_unlock(&inode->i_lock); + rpc_purge_list(rpci, &free_list, ops->destroy_msg, -EPIPE); + rpci->nwriters = 0; + if (ops->release_pipe) + ops->release_pipe(inode); + cancel_delayed_work_sync(&rpci->queue_timeout); + } + rpc_inode_setowner(inode, NULL); + mutex_unlock(&inode->i_mutex); +} + +static struct inode * +rpc_alloc_inode(struct super_block *sb) +{ + struct rpc_inode *rpci; + rpci = (struct rpc_inode *)kmem_cache_alloc(rpc_inode_cachep, GFP_KERNEL); + if (!rpci) + return NULL; + return &rpci->vfs_inode; +} + +static void +rpc_destroy_inode(struct inode *inode) +{ + kmem_cache_free(rpc_inode_cachep, RPC_I(inode)); +} + +static int +rpc_pipe_open(struct inode *inode, struct file *filp) +{ + struct rpc_inode *rpci = RPC_I(inode); + int res = -ENXIO; + + mutex_lock(&inode->i_mutex); + if (rpci->ops != NULL) { + if (filp->f_mode & FMODE_READ) + rpci->nreaders ++; + if (filp->f_mode & FMODE_WRITE) + rpci->nwriters ++; + res = 0; + } + mutex_unlock(&inode->i_mutex); + return res; +} + +static int +rpc_pipe_release(struct inode *inode, struct file *filp) +{ + struct rpc_inode *rpci = RPC_I(inode); + struct rpc_pipe_msg *msg; + + mutex_lock(&inode->i_mutex); + if (rpci->ops == NULL) + goto out; + msg = (struct rpc_pipe_msg *)filp->private_data; + if (msg != NULL) { + spin_lock(&inode->i_lock); + msg->errno = -EAGAIN; + list_del(&msg->list); + spin_unlock(&inode->i_lock); + rpci->ops->destroy_msg(msg); + } + if (filp->f_mode & FMODE_WRITE) + rpci->nwriters --; + if (filp->f_mode & FMODE_READ) { + rpci->nreaders --; + if (rpci->nreaders == 0) { + LIST_HEAD(free_list); + spin_lock(&inode->i_lock); + list_splice_init(&rpci->pipe, &free_list); + rpci->pipelen = 0; + spin_unlock(&inode->i_lock); + rpc_purge_list(rpci, &free_list, + rpci->ops->destroy_msg, -EAGAIN); + } + } + if (rpci->ops->release_pipe) + rpci->ops->release_pipe(inode); +out: + mutex_unlock(&inode->i_mutex); + return 0; +} + +static ssize_t +rpc_pipe_read(struct file *filp, char __user *buf, size_t len, loff_t *offset) +{ + struct inode *inode = filp->f_path.dentry->d_inode; + struct rpc_inode *rpci = RPC_I(inode); + struct rpc_pipe_msg *msg; + int res = 0; + + mutex_lock(&inode->i_mutex); + if (rpci->ops == NULL) { + res = -EPIPE; + goto out_unlock; + } + msg = filp->private_data; + if (msg == NULL) { + spin_lock(&inode->i_lock); + if (!list_empty(&rpci->pipe)) { + msg = list_entry(rpci->pipe.next, + struct rpc_pipe_msg, + list); + list_move(&msg->list, &rpci->in_upcall); + rpci->pipelen -= msg->len; + filp->private_data = msg; + msg->copied = 0; + } + spin_unlock(&inode->i_lock); + if (msg == NULL) + goto out_unlock; + } + /* NOTE: it is up to the callback to update msg->copied */ + res = rpci->ops->upcall(filp, msg, buf, len); + if (res < 0 || msg->len == msg->copied) { + filp->private_data = NULL; + spin_lock(&inode->i_lock); + list_del(&msg->list); + spin_unlock(&inode->i_lock); + rpci->ops->destroy_msg(msg); + } +out_unlock: + mutex_unlock(&inode->i_mutex); + return res; +} + +static ssize_t +rpc_pipe_write(struct file *filp, const char __user *buf, size_t len, loff_t *offset) +{ + struct inode *inode = filp->f_path.dentry->d_inode; + struct rpc_inode *rpci = RPC_I(inode); + int res; + + mutex_lock(&inode->i_mutex); + res = -EPIPE; + if (rpci->ops != NULL) + res = rpci->ops->downcall(filp, buf, len); + mutex_unlock(&inode->i_mutex); + return res; +} + +static unsigned int +rpc_pipe_poll(struct file *filp, struct poll_table_struct *wait) +{ + struct rpc_inode *rpci; + unsigned int mask = 0; + + rpci = RPC_I(filp->f_path.dentry->d_inode); + poll_wait(filp, &rpci->waitq, wait); + + mask = POLLOUT | POLLWRNORM; + if (rpci->ops == NULL) + mask |= POLLERR | POLLHUP; + if (filp->private_data || !list_empty(&rpci->pipe)) + mask |= POLLIN | POLLRDNORM; + return mask; +} + +static int +rpc_pipe_ioctl(struct inode *ino, struct file *filp, + unsigned int cmd, unsigned long arg) +{ + struct rpc_inode *rpci = RPC_I(filp->f_path.dentry->d_inode); + int len; + + switch (cmd) { + case FIONREAD: + if (rpci->ops == NULL) + return -EPIPE; + len = rpci->pipelen; + if (filp->private_data) { + struct rpc_pipe_msg *msg; + msg = (struct rpc_pipe_msg *)filp->private_data; + len += msg->len - msg->copied; + } + return put_user(len, (int __user *)arg); + default: + return -EINVAL; + } +} + +static const struct file_operations rpc_pipe_fops = { + .owner = THIS_MODULE, + .llseek = no_llseek, + .read = rpc_pipe_read, + .write = rpc_pipe_write, + .poll = rpc_pipe_poll, + .ioctl = rpc_pipe_ioctl, + .open = rpc_pipe_open, + .release = rpc_pipe_release, +}; + +static int +rpc_show_info(struct seq_file *m, void *v) +{ + struct rpc_clnt *clnt = m->private; + + seq_printf(m, "RPC server: %s\n", clnt->cl_server); + seq_printf(m, "service: %s (%d) version %d\n", clnt->cl_protname, + clnt->cl_prog, clnt->cl_vers); + seq_printf(m, "address: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_ADDR)); + seq_printf(m, "protocol: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PROTO)); + seq_printf(m, "port: %s\n", rpc_peeraddr2str(clnt, RPC_DISPLAY_PORT)); + return 0; +} + +static int +rpc_info_open(struct inode *inode, struct file *file) +{ + struct rpc_clnt *clnt; + int ret = single_open(file, rpc_show_info, NULL); + + if (!ret) { + struct seq_file *m = file->private_data; + mutex_lock(&inode->i_mutex); + clnt = RPC_I(inode)->private; + if (clnt) { + kref_get(&clnt->cl_kref); + m->private = clnt; + } else { + single_release(inode, file); + ret = -EINVAL; + } + mutex_unlock(&inode->i_mutex); + } + return ret; +} + +static int +rpc_info_release(struct inode *inode, struct file *file) +{ + struct seq_file *m = file->private_data; + struct rpc_clnt *clnt = (struct rpc_clnt *)m->private; + + if (clnt) + rpc_release_client(clnt); + return single_release(inode, file); +} + +static const struct file_operations rpc_info_operations = { + .owner = THIS_MODULE, + .open = rpc_info_open, + .read = seq_read, + .llseek = seq_lseek, + .release = rpc_info_release, +}; + + +/* + * We have a single directory with 1 node in it. + */ +enum { + RPCAUTH_Root = 1, + RPCAUTH_lockd, + RPCAUTH_mount, + RPCAUTH_nfs, + RPCAUTH_portmap, + RPCAUTH_statd, + RPCAUTH_RootEOF +}; + +/* + * Description of fs contents. + */ +struct rpc_filelist { + char *name; + const struct file_operations *i_fop; + int mode; +}; + +static struct rpc_filelist files[] = { + [RPCAUTH_lockd] = { + .name = "lockd", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, + [RPCAUTH_mount] = { + .name = "mount", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, + [RPCAUTH_nfs] = { + .name = "nfs", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, + [RPCAUTH_portmap] = { + .name = "portmap", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, + [RPCAUTH_statd] = { + .name = "statd", + .mode = S_IFDIR | S_IRUGO | S_IXUGO, + }, +}; + +enum { + RPCAUTH_info = 2, + RPCAUTH_EOF +}; + +static struct rpc_filelist authfiles[] = { + [RPCAUTH_info] = { + .name = "info", + .i_fop = &rpc_info_operations, + .mode = S_IFREG | S_IRUSR, + }, +}; + +struct vfsmount *rpc_get_mount(void) +{ + int err; + + err = simple_pin_fs(&rpc_pipe_fs_type, &rpc_mount, &rpc_mount_count); + if (err != 0) + return ERR_PTR(err); + return rpc_mount; +} + +void rpc_put_mount(void) +{ + simple_release_fs(&rpc_mount, &rpc_mount_count); +} + +static int rpc_delete_dentry(struct dentry *dentry) +{ + return 1; +} + +static struct dentry_operations rpc_dentry_operations = { + .d_delete = rpc_delete_dentry, +}; + +static int +rpc_lookup_parent(char *path, struct nameidata *nd) +{ + struct vfsmount *mnt; + + if (path[0] == '\0') + return -ENOENT; + + mnt = rpc_get_mount(); + if (IS_ERR(mnt)) { + printk(KERN_WARNING "%s: %s failed to mount " + "pseudofilesystem \n", __FILE__, __func__); + return PTR_ERR(mnt); + } + + if (vfs_path_lookup(mnt->mnt_root, mnt, path, LOOKUP_PARENT, nd)) { + printk(KERN_WARNING "%s: %s failed to find path %s\n", + __FILE__, __func__, path); + rpc_put_mount(); + return -ENOENT; + } + return 0; +} + +static void +rpc_release_path(struct nameidata *nd) +{ + path_put(&nd->path); + rpc_put_mount(); +} + +static struct inode * +rpc_get_inode(struct super_block *sb, int mode) +{ + struct inode *inode = new_inode(sb); + if (!inode) + return NULL; + inode->i_mode = mode; + inode->i_uid = inode->i_gid = 0; + inode->i_blocks = 0; + inode->i_atime = inode->i_mtime = inode->i_ctime = CURRENT_TIME; + switch(mode & S_IFMT) { + case S_IFDIR: + inode->i_fop = &simple_dir_operations; + inode->i_op = &simple_dir_inode_operations; + inc_nlink(inode); + default: + break; + } + return inode; +} + +/* + * FIXME: This probably has races. + */ +static void rpc_depopulate(struct dentry *parent, + unsigned long start, unsigned long eof) +{ + struct inode *dir = parent->d_inode; + struct list_head *pos, *next; + struct dentry *dentry, *dvec[10]; + int n = 0; + + mutex_lock_nested(&dir->i_mutex, I_MUTEX_CHILD); +repeat: + spin_lock(&dcache_lock); + list_for_each_safe(pos, next, &parent->d_subdirs) { + dentry = list_entry(pos, struct dentry, d_u.d_child); + if (!dentry->d_inode || + dentry->d_inode->i_ino < start || + dentry->d_inode->i_ino >= eof) + continue; + spin_lock(&dentry->d_lock); + if (!d_unhashed(dentry)) { + dget_locked(dentry); + __d_drop(dentry); + spin_unlock(&dentry->d_lock); + dvec[n++] = dentry; + if (n == ARRAY_SIZE(dvec)) + break; + } else + spin_unlock(&dentry->d_lock); + } + spin_unlock(&dcache_lock); + if (n) { + do { + dentry = dvec[--n]; + if (S_ISREG(dentry->d_inode->i_mode)) + simple_unlink(dir, dentry); + else if (S_ISDIR(dentry->d_inode->i_mode)) + simple_rmdir(dir, dentry); + d_delete(dentry); + dput(dentry); + } while (n); + goto repeat; + } + mutex_unlock(&dir->i_mutex); +} + +static int +rpc_populate(struct dentry *parent, + struct rpc_filelist *files, + int start, int eof) +{ + struct inode *inode, *dir = parent->d_inode; + void *private = RPC_I(dir)->private; + struct dentry *dentry; + int mode, i; + + mutex_lock(&dir->i_mutex); + for (i = start; i < eof; i++) { + dentry = d_alloc_name(parent, files[i].name); + if (!dentry) + goto out_bad; + dentry->d_op = &rpc_dentry_operations; + mode = files[i].mode; + inode = rpc_get_inode(dir->i_sb, mode); + if (!inode) { + dput(dentry); + goto out_bad; + } + inode->i_ino = i; + if (files[i].i_fop) + inode->i_fop = files[i].i_fop; + if (private) + rpc_inode_setowner(inode, private); + if (S_ISDIR(mode)) + inc_nlink(dir); + d_add(dentry, inode); + fsnotify_create(dir, dentry); + } + mutex_unlock(&dir->i_mutex); + return 0; +out_bad: + mutex_unlock(&dir->i_mutex); + printk(KERN_WARNING "%s: %s failed to populate directory %s\n", + __FILE__, __func__, parent->d_name.name); + return -ENOMEM; +} + +static int +__rpc_mkdir(struct inode *dir, struct dentry *dentry) +{ + struct inode *inode; + + inode = rpc_get_inode(dir->i_sb, S_IFDIR | S_IRUGO | S_IXUGO); + if (!inode) + goto out_err; + inode->i_ino = iunique(dir->i_sb, 100); + d_instantiate(dentry, inode); + inc_nlink(dir); + fsnotify_mkdir(dir, dentry); + return 0; +out_err: + printk(KERN_WARNING "%s: %s failed to allocate inode for dentry %s\n", + __FILE__, __func__, dentry->d_name.name); + return -ENOMEM; +} + +static int +__rpc_rmdir(struct inode *dir, struct dentry *dentry) +{ + int error; + error = simple_rmdir(dir, dentry); + if (!error) + d_delete(dentry); + return error; +} + +static struct dentry * +rpc_lookup_create(struct dentry *parent, const char *name, int len, int exclusive) +{ + struct inode *dir = parent->d_inode; + struct dentry *dentry; + + mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); + dentry = lookup_one_len(name, parent, len); + if (IS_ERR(dentry)) + goto out_err; + if (!dentry->d_inode) + dentry->d_op = &rpc_dentry_operations; + else if (exclusive) { + dput(dentry); + dentry = ERR_PTR(-EEXIST); + goto out_err; + } + return dentry; +out_err: + mutex_unlock(&dir->i_mutex); + return dentry; +} + +static struct dentry * +rpc_lookup_negative(char *path, struct nameidata *nd) +{ + struct dentry *dentry; + int error; + + if ((error = rpc_lookup_parent(path, nd)) != 0) + return ERR_PTR(error); + dentry = rpc_lookup_create(nd->path.dentry, nd->last.name, nd->last.len, + 1); + if (IS_ERR(dentry)) + rpc_release_path(nd); + return dentry; +} + +/** + * rpc_mkdir - Create a new directory in rpc_pipefs + * @path: path from the rpc_pipefs root to the new directory + * @rpc_client: rpc client to associate with this directory + * + * This creates a directory at the given @path associated with + * @rpc_clnt, which will contain a file named "info" with some basic + * information about the client, together with any "pipes" that may + * later be created using rpc_mkpipe(). + */ +struct dentry * +rpc_mkdir(char *path, struct rpc_clnt *rpc_client) +{ + struct nameidata nd; + struct dentry *dentry; + struct inode *dir; + int error; + + dentry = rpc_lookup_negative(path, &nd); + if (IS_ERR(dentry)) + return dentry; + dir = nd.path.dentry->d_inode; + if ((error = __rpc_mkdir(dir, dentry)) != 0) + goto err_dput; + RPC_I(dentry->d_inode)->private = rpc_client; + error = rpc_populate(dentry, authfiles, + RPCAUTH_info, RPCAUTH_EOF); + if (error) + goto err_depopulate; + dget(dentry); +out: + mutex_unlock(&dir->i_mutex); + rpc_release_path(&nd); + return dentry; +err_depopulate: + rpc_depopulate(dentry, RPCAUTH_info, RPCAUTH_EOF); + __rpc_rmdir(dir, dentry); +err_dput: + dput(dentry); + printk(KERN_WARNING "%s: %s() failed to create directory %s (errno = %d)\n", + __FILE__, __func__, path, error); + dentry = ERR_PTR(error); + goto out; +} + +/** + * rpc_rmdir - Remove a directory created with rpc_mkdir() + * @dentry: directory to remove + */ +int +rpc_rmdir(struct dentry *dentry) +{ + struct dentry *parent; + struct inode *dir; + int error; + + parent = dget_parent(dentry); + dir = parent->d_inode; + mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); + rpc_depopulate(dentry, RPCAUTH_info, RPCAUTH_EOF); + error = __rpc_rmdir(dir, dentry); + dput(dentry); + mutex_unlock(&dir->i_mutex); + dput(parent); + return error; +} + +/** + * rpc_mkpipe - make an rpc_pipefs file for kernel<->userspace communication + * @parent: dentry of directory to create new "pipe" in + * @name: name of pipe + * @private: private data to associate with the pipe, for the caller's use + * @ops: operations defining the behavior of the pipe: upcall, downcall, + * release_pipe, and destroy_msg. + * @flags: rpc_inode flags + * + * Data is made available for userspace to read by calls to + * rpc_queue_upcall(). The actual reads will result in calls to + * @ops->upcall, which will be called with the file pointer, + * message, and userspace buffer to copy to. + * + * Writes can come at any time, and do not necessarily have to be + * responses to upcalls. They will result in calls to @msg->downcall. + * + * The @private argument passed here will be available to all these methods + * from the file pointer, via RPC_I(file->f_dentry->d_inode)->private. + */ +struct dentry * +rpc_mkpipe(struct dentry *parent, const char *name, void *private, struct rpc_pipe_ops *ops, int flags) +{ + struct dentry *dentry; + struct inode *dir, *inode; + struct rpc_inode *rpci; + + dentry = rpc_lookup_create(parent, name, strlen(name), 0); + if (IS_ERR(dentry)) + return dentry; + dir = parent->d_inode; + if (dentry->d_inode) { + rpci = RPC_I(dentry->d_inode); + if (rpci->private != private || + rpci->ops != ops || + rpci->flags != flags) { + dput (dentry); + dentry = ERR_PTR(-EBUSY); + } + rpci->nkern_readwriters++; + goto out; + } + inode = rpc_get_inode(dir->i_sb, S_IFIFO | S_IRUSR | S_IWUSR); + if (!inode) + goto err_dput; + inode->i_ino = iunique(dir->i_sb, 100); + inode->i_fop = &rpc_pipe_fops; + d_instantiate(dentry, inode); + rpci = RPC_I(inode); + rpci->private = private; + rpci->flags = flags; + rpci->ops = ops; + rpci->nkern_readwriters = 1; + fsnotify_create(dir, dentry); + dget(dentry); +out: + mutex_unlock(&dir->i_mutex); + return dentry; +err_dput: + dput(dentry); + dentry = ERR_PTR(-ENOMEM); + printk(KERN_WARNING "%s: %s() failed to create pipe %s/%s (errno = %d)\n", + __FILE__, __func__, parent->d_name.name, name, + -ENOMEM); + goto out; +} +EXPORT_SYMBOL(rpc_mkpipe); + +/** + * rpc_unlink - remove a pipe + * @dentry: dentry for the pipe, as returned from rpc_mkpipe + * + * After this call, lookups will no longer find the pipe, and any + * attempts to read or write using preexisting opens of the pipe will + * return -EPIPE. + */ +int +rpc_unlink(struct dentry *dentry) +{ + struct dentry *parent; + struct inode *dir; + int error = 0; + + parent = dget_parent(dentry); + dir = parent->d_inode; + mutex_lock_nested(&dir->i_mutex, I_MUTEX_PARENT); + if (--RPC_I(dentry->d_inode)->nkern_readwriters == 0) { + rpc_close_pipes(dentry->d_inode); + error = simple_unlink(dir, dentry); + if (!error) + d_delete(dentry); + } + dput(dentry); + mutex_unlock(&dir->i_mutex); + dput(parent); + return error; +} +EXPORT_SYMBOL(rpc_unlink); + +/* + * populate the filesystem + */ +static struct super_operations s_ops = { + .alloc_inode = rpc_alloc_inode, + .destroy_inode = rpc_destroy_inode, + .statfs = simple_statfs, +}; + +#define RPCAUTH_GSSMAGIC 0x67596969 + +static int +rpc_fill_super(struct super_block *sb, void *data, int silent) +{ + struct inode *inode; + struct dentry *root; + + sb->s_blocksize = PAGE_CACHE_SIZE; + sb->s_blocksize_bits = PAGE_CACHE_SHIFT; + sb->s_magic = RPCAUTH_GSSMAGIC; + sb->s_op = &s_ops; + sb->s_time_gran = 1; + + inode = rpc_get_inode(sb, S_IFDIR | 0755); + if (!inode) + return -ENOMEM; + root = d_alloc_root(inode); + if (!root) { + iput(inode); + return -ENOMEM; + } + if (rpc_populate(root, files, RPCAUTH_Root + 1, RPCAUTH_RootEOF)) + goto out; + sb->s_root = root; + return 0; +out: + d_genocide(root); + dput(root); + return -ENOMEM; +} + +static int +rpc_get_sb(struct file_system_type *fs_type, + int flags, const char *dev_name, void *data, struct vfsmount *mnt) +{ + return get_sb_single(fs_type, flags, data, rpc_fill_super, mnt); +} + +static struct file_system_type rpc_pipe_fs_type = { + .owner = THIS_MODULE, + .name = "rpc_pipefs", + .get_sb = rpc_get_sb, + .kill_sb = kill_litter_super, +}; + +static void +init_once(void *foo) +{ + struct rpc_inode *rpci = (struct rpc_inode *) foo; + + inode_init_once(&rpci->vfs_inode); + rpci->private = NULL; + rpci->nreaders = 0; + rpci->nwriters = 0; + INIT_LIST_HEAD(&rpci->in_upcall); + INIT_LIST_HEAD(&rpci->in_downcall); + INIT_LIST_HEAD(&rpci->pipe); + rpci->pipelen = 0; + init_waitqueue_head(&rpci->waitq); + INIT_DELAYED_WORK(&rpci->queue_timeout, + rpc_timeout_upcall_queue); + rpci->ops = NULL; +} + +int register_rpc_pipefs(void) +{ + int err; + + rpc_inode_cachep = kmem_cache_create("rpc_inode_cache", + sizeof(struct rpc_inode), + 0, (SLAB_HWCACHE_ALIGN|SLAB_RECLAIM_ACCOUNT| + SLAB_MEM_SPREAD), + init_once); + if (!rpc_inode_cachep) + return -ENOMEM; + err = register_filesystem(&rpc_pipe_fs_type); + if (err) { + kmem_cache_destroy(rpc_inode_cachep); + return err; + } + + return 0; +} + +void unregister_rpc_pipefs(void) +{ + kmem_cache_destroy(rpc_inode_cachep); + unregister_filesystem(&rpc_pipe_fs_type); +} diff --git a/net/sunrpc/rpcb_clnt.c b/net/sunrpc/rpcb_clnt.c new file mode 100644 index 0000000..41013dd --- /dev/null +++ b/net/sunrpc/rpcb_clnt.c @@ -0,0 +1,873 @@ +/* + * In-kernel rpcbind client supporting versions 2, 3, and 4 of the rpcbind + * protocol + * + * Based on RFC 1833: "Binding Protocols for ONC RPC Version 2" and + * RFC 3530: "Network File System (NFS) version 4 Protocol" + * + * Original: Gilles Quillard, Bull Open Source, 2005 <gilles.quillard@bull.net> + * Updated: Chuck Lever, Oracle Corporation, 2007 <chuck.lever@oracle.com> + * + * Descended from net/sunrpc/pmap_clnt.c, + * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/socket.h> +#include <linux/in.h> +#include <linux/in6.h> +#include <linux/kernel.h> +#include <linux/errno.h> +#include <net/ipv6.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/xprtsock.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_BIND +#endif + +#define RPCBIND_PROGRAM (100000u) +#define RPCBIND_PORT (111u) + +#define RPCBVERS_2 (2u) +#define RPCBVERS_3 (3u) +#define RPCBVERS_4 (4u) + +enum { + RPCBPROC_NULL, + RPCBPROC_SET, + RPCBPROC_UNSET, + RPCBPROC_GETPORT, + RPCBPROC_GETADDR = 3, /* alias for GETPORT */ + RPCBPROC_DUMP, + RPCBPROC_CALLIT, + RPCBPROC_BCAST = 5, /* alias for CALLIT */ + RPCBPROC_GETTIME, + RPCBPROC_UADDR2TADDR, + RPCBPROC_TADDR2UADDR, + RPCBPROC_GETVERSADDR, + RPCBPROC_INDIRECT, + RPCBPROC_GETADDRLIST, + RPCBPROC_GETSTAT, +}; + +#define RPCB_HIGHPROC_2 RPCBPROC_CALLIT +#define RPCB_HIGHPROC_3 RPCBPROC_TADDR2UADDR +#define RPCB_HIGHPROC_4 RPCBPROC_GETSTAT + +/* + * r_owner + * + * The "owner" is allowed to unset a service in the rpcbind database. + * We always use the following (arbitrary) fixed string. + */ +#define RPCB_OWNER_STRING "rpcb" +#define RPCB_MAXOWNERLEN sizeof(RPCB_OWNER_STRING) + +static void rpcb_getport_done(struct rpc_task *, void *); +static void rpcb_map_release(void *data); +static struct rpc_program rpcb_program; + +struct rpcbind_args { + struct rpc_xprt * r_xprt; + + u32 r_prog; + u32 r_vers; + u32 r_prot; + unsigned short r_port; + const char * r_netid; + const char * r_addr; + const char * r_owner; + + int r_status; +}; + +static struct rpc_procinfo rpcb_procedures2[]; +static struct rpc_procinfo rpcb_procedures3[]; +static struct rpc_procinfo rpcb_procedures4[]; + +struct rpcb_info { + u32 rpc_vers; + struct rpc_procinfo * rpc_proc; +}; + +static struct rpcb_info rpcb_next_version[]; +static struct rpcb_info rpcb_next_version6[]; + +static const struct rpc_call_ops rpcb_getport_ops = { + .rpc_call_done = rpcb_getport_done, + .rpc_release = rpcb_map_release, +}; + +static void rpcb_wake_rpcbind_waiters(struct rpc_xprt *xprt, int status) +{ + xprt_clear_binding(xprt); + rpc_wake_up_status(&xprt->binding, status); +} + +static void rpcb_map_release(void *data) +{ + struct rpcbind_args *map = data; + + rpcb_wake_rpcbind_waiters(map->r_xprt, map->r_status); + xprt_put(map->r_xprt); + kfree(map); +} + +static const struct sockaddr_in rpcb_inaddr_loopback = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_LOOPBACK), + .sin_port = htons(RPCBIND_PORT), +}; + +static const struct sockaddr_in6 rpcb_in6addr_loopback = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_LOOPBACK_INIT, + .sin6_port = htons(RPCBIND_PORT), +}; + +static struct rpc_clnt *rpcb_create_local(struct sockaddr *addr, + size_t addrlen, u32 version) +{ + struct rpc_create_args args = { + .protocol = XPRT_TRANSPORT_UDP, + .address = addr, + .addrsize = addrlen, + .servername = "localhost", + .program = &rpcb_program, + .version = version, + .authflavor = RPC_AUTH_UNIX, + .flags = RPC_CLNT_CREATE_NOPING, + }; + + return rpc_create(&args); +} + +static struct rpc_clnt *rpcb_create(char *hostname, struct sockaddr *srvaddr, + size_t salen, int proto, u32 version) +{ + struct rpc_create_args args = { + .protocol = proto, + .address = srvaddr, + .addrsize = salen, + .servername = hostname, + .program = &rpcb_program, + .version = version, + .authflavor = RPC_AUTH_UNIX, + .flags = (RPC_CLNT_CREATE_NOPING | + RPC_CLNT_CREATE_NONPRIVPORT), + }; + + switch (srvaddr->sa_family) { + case AF_INET: + ((struct sockaddr_in *)srvaddr)->sin_port = htons(RPCBIND_PORT); + break; + case AF_INET6: + ((struct sockaddr_in6 *)srvaddr)->sin6_port = htons(RPCBIND_PORT); + break; + default: + return NULL; + } + + return rpc_create(&args); +} + +static int rpcb_register_call(struct sockaddr *addr, size_t addrlen, + u32 version, struct rpc_message *msg) +{ + struct rpc_clnt *rpcb_clnt; + int result, error = 0; + + msg->rpc_resp = &result; + + rpcb_clnt = rpcb_create_local(addr, addrlen, version); + if (!IS_ERR(rpcb_clnt)) { + error = rpc_call_sync(rpcb_clnt, msg, 0); + rpc_shutdown_client(rpcb_clnt); + } else + error = PTR_ERR(rpcb_clnt); + + if (error < 0) { + printk(KERN_WARNING "RPC: failed to contact local rpcbind " + "server (errno %d).\n", -error); + return error; + } + + if (!result) + return -EACCES; + return 0; +} + +/** + * rpcb_register - set or unset a port registration with the local rpcbind svc + * @prog: RPC program number to bind + * @vers: RPC version number to bind + * @prot: transport protocol to register + * @port: port value to register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, transport] + * tuple they wish to advertise. + * + * Callers may also unregister RPC services that are no longer + * available by setting the passed-in port to zero. This removes + * all registered transports for [program, version] from the local + * rpcbind database. + * + * This function uses rpcbind protocol version 2 to contact the + * local rpcbind daemon. + * + * Registration works over both AF_INET and AF_INET6, and services + * registered via this function are advertised as available for any + * address. If the local rpcbind daemon is listening on AF_INET6, + * services registered via this function will be advertised on + * IN6ADDR_ANY (ie available for all AF_INET and AF_INET6 + * addresses). + */ +int rpcb_register(u32 prog, u32 vers, int prot, unsigned short port) +{ + struct rpcbind_args map = { + .r_prog = prog, + .r_vers = vers, + .r_prot = prot, + .r_port = port, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + + dprintk("RPC: %sregistering (%u, %u, %d, %u) with local " + "rpcbind\n", (port ? "" : "un"), + prog, vers, prot, port); + + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_UNSET]; + if (port) + msg.rpc_proc = &rpcb_procedures2[RPCBPROC_SET]; + + return rpcb_register_call((struct sockaddr *)&rpcb_inaddr_loopback, + sizeof(rpcb_inaddr_loopback), + RPCBVERS_2, &msg); +} + +/* + * Fill in AF_INET family-specific arguments to register + */ +static int rpcb_register_netid4(struct sockaddr_in *address_to_register, + struct rpc_message *msg) +{ + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(address_to_register->sin_port); + char buf[32]; + + /* Construct AF_INET universal address */ + snprintf(buf, sizeof(buf), + NIPQUAD_FMT".%u.%u", + NIPQUAD(address_to_register->sin_addr.s_addr), + port >> 8, port & 0xff); + map->r_addr = buf; + + dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " + "local rpcbind\n", (port ? "" : "un"), + map->r_prog, map->r_vers, + map->r_addr, map->r_netid); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port) + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + + return rpcb_register_call((struct sockaddr *)&rpcb_inaddr_loopback, + sizeof(rpcb_inaddr_loopback), + RPCBVERS_4, msg); +} + +/* + * Fill in AF_INET6 family-specific arguments to register + */ +static int rpcb_register_netid6(struct sockaddr_in6 *address_to_register, + struct rpc_message *msg) +{ + struct rpcbind_args *map = msg->rpc_argp; + unsigned short port = ntohs(address_to_register->sin6_port); + char buf[64]; + + /* Construct AF_INET6 universal address */ + if (ipv6_addr_any(&address_to_register->sin6_addr)) + snprintf(buf, sizeof(buf), "::.%u.%u", + port >> 8, port & 0xff); + else + snprintf(buf, sizeof(buf), NIP6_FMT".%u.%u", + NIP6(address_to_register->sin6_addr), + port >> 8, port & 0xff); + map->r_addr = buf; + + dprintk("RPC: %sregistering [%u, %u, %s, '%s'] with " + "local rpcbind\n", (port ? "" : "un"), + map->r_prog, map->r_vers, + map->r_addr, map->r_netid); + + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_UNSET]; + if (port) + msg->rpc_proc = &rpcb_procedures4[RPCBPROC_SET]; + + return rpcb_register_call((struct sockaddr *)&rpcb_in6addr_loopback, + sizeof(rpcb_in6addr_loopback), + RPCBVERS_4, msg); +} + +/** + * rpcb_v4_register - set or unset a port registration with the local rpcbind + * @program: RPC program number of service to (un)register + * @version: RPC version number of service to (un)register + * @address: address family, IP address, and port to (un)register + * @netid: netid of transport protocol to (un)register + * + * Returns zero if the registration request was dispatched successfully + * and the rpcbind daemon returned success. Otherwise, returns an errno + * value that reflects the nature of the error (request could not be + * dispatched, timed out, or rpcbind returned an error). + * + * RPC services invoke this function to advertise their contact + * information via the system's rpcbind daemon. RPC services + * invoke this function once for each [program, version, address, + * netid] tuple they wish to advertise. + * + * Callers may also unregister RPC services that are no longer + * available by setting the port number in the passed-in address + * to zero. Callers pass a netid of "" to unregister all + * transport netids associated with [program, version, address]. + * + * This function uses rpcbind protocol version 4 to contact the + * local rpcbind daemon. The local rpcbind daemon must support + * version 4 of the rpcbind protocol in order for these functions + * to register a service successfully. + * + * Supported netids include "udp" and "tcp" for UDP and TCP over + * IPv4, and "udp6" and "tcp6" for UDP and TCP over IPv6, + * respectively. + * + * The contents of @address determine the address family and the + * port to be registered. The usual practice is to pass INADDR_ANY + * as the raw address, but specifying a non-zero address is also + * supported by this API if the caller wishes to advertise an RPC + * service on a specific network interface. + * + * Note that passing in INADDR_ANY does not create the same service + * registration as IN6ADDR_ANY. The former advertises an RPC + * service on any IPv4 address, but not on IPv6. The latter + * advertises the service on all IPv4 and IPv6 addresses. + */ +int rpcb_v4_register(const u32 program, const u32 version, + const struct sockaddr *address, const char *netid) +{ + struct rpcbind_args map = { + .r_prog = program, + .r_vers = version, + .r_netid = netid, + .r_owner = RPCB_OWNER_STRING, + }; + struct rpc_message msg = { + .rpc_argp = &map, + }; + + switch (address->sa_family) { + case AF_INET: + return rpcb_register_netid4((struct sockaddr_in *)address, + &msg); + case AF_INET6: + return rpcb_register_netid6((struct sockaddr_in6 *)address, + &msg); + } + + return -EAFNOSUPPORT; +} + +/** + * rpcb_getport_sync - obtain the port for an RPC service on a given host + * @sin: address of remote peer + * @prog: RPC program number to bind + * @vers: RPC version number to bind + * @prot: transport protocol to use to make this request + * + * Return value is the requested advertised port number, + * or a negative errno value. + * + * Called from outside the RPC client in a synchronous task context. + * Uses default timeout parameters specified by underlying transport. + * + * XXX: Needs to support IPv6 + */ +int rpcb_getport_sync(struct sockaddr_in *sin, u32 prog, u32 vers, int prot) +{ + struct rpcbind_args map = { + .r_prog = prog, + .r_vers = vers, + .r_prot = prot, + .r_port = 0, + }; + struct rpc_message msg = { + .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], + .rpc_argp = &map, + .rpc_resp = &map.r_port, + }; + struct rpc_clnt *rpcb_clnt; + int status; + + dprintk("RPC: %s(" NIPQUAD_FMT ", %u, %u, %d)\n", + __func__, NIPQUAD(sin->sin_addr.s_addr), prog, vers, prot); + + rpcb_clnt = rpcb_create(NULL, (struct sockaddr *)sin, + sizeof(*sin), prot, RPCBVERS_2); + if (IS_ERR(rpcb_clnt)) + return PTR_ERR(rpcb_clnt); + + status = rpc_call_sync(rpcb_clnt, &msg, 0); + rpc_shutdown_client(rpcb_clnt); + + if (status >= 0) { + if (map.r_port != 0) + return map.r_port; + status = -EACCES; + } + return status; +} +EXPORT_SYMBOL_GPL(rpcb_getport_sync); + +static struct rpc_task *rpcb_call_async(struct rpc_clnt *rpcb_clnt, struct rpcbind_args *map, struct rpc_procinfo *proc) +{ + struct rpc_message msg = { + .rpc_proc = proc, + .rpc_argp = map, + .rpc_resp = &map->r_port, + }; + struct rpc_task_setup task_setup_data = { + .rpc_client = rpcb_clnt, + .rpc_message = &msg, + .callback_ops = &rpcb_getport_ops, + .callback_data = map, + .flags = RPC_TASK_ASYNC, + }; + + return rpc_run_task(&task_setup_data); +} + +/* + * In the case where rpc clients have been cloned, we want to make + * sure that we use the program number/version etc of the actual + * owner of the xprt. To do so, we walk back up the tree of parents + * to find whoever created the transport and/or whoever has the + * autobind flag set. + */ +static struct rpc_clnt *rpcb_find_transport_owner(struct rpc_clnt *clnt) +{ + struct rpc_clnt *parent = clnt->cl_parent; + + while (parent != clnt) { + if (parent->cl_xprt != clnt->cl_xprt) + break; + if (clnt->cl_autobind) + break; + clnt = parent; + parent = parent->cl_parent; + } + return clnt; +} + +/** + * rpcb_getport_async - obtain the port for a given RPC service on a given host + * @task: task that is waiting for portmapper request + * + * This one can be called for an ongoing RPC request, and can be used in + * an async (rpciod) context. + */ +void rpcb_getport_async(struct rpc_task *task) +{ + struct rpc_clnt *clnt; + struct rpc_procinfo *proc; + u32 bind_version; + struct rpc_xprt *xprt; + struct rpc_clnt *rpcb_clnt; + static struct rpcbind_args *map; + struct rpc_task *child; + struct sockaddr_storage addr; + struct sockaddr *sap = (struct sockaddr *)&addr; + size_t salen; + int status; + + clnt = rpcb_find_transport_owner(task->tk_client); + xprt = clnt->cl_xprt; + + dprintk("RPC: %5u %s(%s, %u, %u, %d)\n", + task->tk_pid, __func__, + clnt->cl_server, clnt->cl_prog, clnt->cl_vers, xprt->prot); + + /* Put self on the wait queue to ensure we get notified if + * some other task is already attempting to bind the port */ + rpc_sleep_on(&xprt->binding, task, NULL); + + if (xprt_test_and_set_binding(xprt)) { + dprintk("RPC: %5u %s: waiting for another binder\n", + task->tk_pid, __func__); + return; + } + + /* Someone else may have bound if we slept */ + if (xprt_bound(xprt)) { + status = 0; + dprintk("RPC: %5u %s: already bound\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + + salen = rpc_peeraddr(clnt, sap, sizeof(addr)); + + /* Don't ever use rpcbind v2 for AF_INET6 requests */ + switch (sap->sa_family) { + case AF_INET: + proc = rpcb_next_version[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version[xprt->bind_index].rpc_vers; + break; + case AF_INET6: + proc = rpcb_next_version6[xprt->bind_index].rpc_proc; + bind_version = rpcb_next_version6[xprt->bind_index].rpc_vers; + break; + default: + status = -EAFNOSUPPORT; + dprintk("RPC: %5u %s: bad address family\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + if (proc == NULL) { + xprt->bind_index = 0; + status = -EPFNOSUPPORT; + dprintk("RPC: %5u %s: no more getport versions available\n", + task->tk_pid, __func__); + goto bailout_nofree; + } + + dprintk("RPC: %5u %s: trying rpcbind version %u\n", + task->tk_pid, __func__, bind_version); + + rpcb_clnt = rpcb_create(clnt->cl_server, sap, salen, xprt->prot, + bind_version); + if (IS_ERR(rpcb_clnt)) { + status = PTR_ERR(rpcb_clnt); + dprintk("RPC: %5u %s: rpcb_create failed, error %ld\n", + task->tk_pid, __func__, PTR_ERR(rpcb_clnt)); + goto bailout_nofree; + } + + map = kzalloc(sizeof(struct rpcbind_args), GFP_ATOMIC); + if (!map) { + status = -ENOMEM; + dprintk("RPC: %5u %s: no memory available\n", + task->tk_pid, __func__); + goto bailout_release_client; + } + map->r_prog = clnt->cl_prog; + map->r_vers = clnt->cl_vers; + map->r_prot = xprt->prot; + map->r_port = 0; + map->r_xprt = xprt_get(xprt); + map->r_netid = rpc_peeraddr2str(clnt, RPC_DISPLAY_NETID); + map->r_addr = rpc_peeraddr2str(rpcb_clnt, RPC_DISPLAY_UNIVERSAL_ADDR); + map->r_owner = RPCB_OWNER_STRING; /* ignored for GETADDR */ + map->r_status = -EIO; + + child = rpcb_call_async(rpcb_clnt, map, proc); + rpc_release_client(rpcb_clnt); + if (IS_ERR(child)) { + /* rpcb_map_release() has freed the arguments */ + dprintk("RPC: %5u %s: rpc_run_task failed\n", + task->tk_pid, __func__); + return; + } + + xprt->stat.bind_count++; + rpc_put_task(child); + return; + +bailout_release_client: + rpc_release_client(rpcb_clnt); +bailout_nofree: + rpcb_wake_rpcbind_waiters(xprt, status); + task->tk_status = status; +} +EXPORT_SYMBOL_GPL(rpcb_getport_async); + +/* + * Rpcbind child task calls this callback via tk_exit. + */ +static void rpcb_getport_done(struct rpc_task *child, void *data) +{ + struct rpcbind_args *map = data; + struct rpc_xprt *xprt = map->r_xprt; + int status = child->tk_status; + + /* Garbage reply: retry with a lesser rpcbind version */ + if (status == -EIO) + status = -EPROTONOSUPPORT; + + /* rpcbind server doesn't support this rpcbind protocol version */ + if (status == -EPROTONOSUPPORT) + xprt->bind_index++; + + if (status < 0) { + /* rpcbind server not available on remote host? */ + xprt->ops->set_port(xprt, 0); + } else if (map->r_port == 0) { + /* Requested RPC service wasn't registered on remote host */ + xprt->ops->set_port(xprt, 0); + status = -EACCES; + } else { + /* Succeeded */ + xprt->ops->set_port(xprt, map->r_port); + xprt_set_bound(xprt); + status = 0; + } + + dprintk("RPC: %5u rpcb_getport_done(status %d, port %u)\n", + child->tk_pid, status, map->r_port); + + map->r_status = status; +} + +/* + * XDR functions for rpcbind + */ + +static int rpcb_encode_mapping(struct rpc_rqst *req, __be32 *p, + struct rpcbind_args *rpcb) +{ + dprintk("RPC: encoding rpcb request (%u, %u, %d, %u)\n", + rpcb->r_prog, rpcb->r_vers, rpcb->r_prot, rpcb->r_port); + *p++ = htonl(rpcb->r_prog); + *p++ = htonl(rpcb->r_vers); + *p++ = htonl(rpcb->r_prot); + *p++ = htonl(rpcb->r_port); + + req->rq_slen = xdr_adjust_iovec(req->rq_svec, p); + return 0; +} + +static int rpcb_decode_getport(struct rpc_rqst *req, __be32 *p, + unsigned short *portp) +{ + *portp = (unsigned short) ntohl(*p++); + dprintk("RPC: rpcb getport result: %u\n", + *portp); + return 0; +} + +static int rpcb_decode_set(struct rpc_rqst *req, __be32 *p, + unsigned int *boolp) +{ + *boolp = (unsigned int) ntohl(*p++); + dprintk("RPC: rpcb set/unset call %s\n", + (*boolp ? "succeeded" : "failed")); + return 0; +} + +static int rpcb_encode_getaddr(struct rpc_rqst *req, __be32 *p, + struct rpcbind_args *rpcb) +{ + dprintk("RPC: encoding rpcb request (%u, %u, %s)\n", + rpcb->r_prog, rpcb->r_vers, rpcb->r_addr); + *p++ = htonl(rpcb->r_prog); + *p++ = htonl(rpcb->r_vers); + + p = xdr_encode_string(p, rpcb->r_netid); + p = xdr_encode_string(p, rpcb->r_addr); + p = xdr_encode_string(p, rpcb->r_owner); + + req->rq_slen = xdr_adjust_iovec(req->rq_svec, p); + + return 0; +} + +static int rpcb_decode_getaddr(struct rpc_rqst *req, __be32 *p, + unsigned short *portp) +{ + char *addr; + u32 addr_len; + int c, i, f, first, val; + + *portp = 0; + addr_len = ntohl(*p++); + + /* + * Simple sanity check. The smallest possible universal + * address is an IPv4 address string containing 11 bytes. + */ + if (addr_len < 11 || addr_len > RPCBIND_MAXUADDRLEN) + goto out_err; + + /* + * Start at the end and walk backwards until the first dot + * is encountered. When the second dot is found, we have + * both parts of the port number. + */ + addr = (char *)p; + val = 0; + first = 1; + f = 1; + for (i = addr_len - 1; i > 0; i--) { + c = addr[i]; + if (c >= '0' && c <= '9') { + val += (c - '0') * f; + f *= 10; + } else if (c == '.') { + if (first) { + *portp = val; + val = first = 0; + f = 1; + } else { + *portp |= (val << 8); + break; + } + } + } + + /* + * Simple sanity check. If we never saw a dot in the reply, + * then this was probably just garbage. + */ + if (first) + goto out_err; + + dprintk("RPC: rpcb_decode_getaddr port=%u\n", *portp); + return 0; + +out_err: + dprintk("RPC: rpcbind server returned malformed reply\n"); + return -EIO; +} + +#define RPCB_program_sz (1u) +#define RPCB_version_sz (1u) +#define RPCB_protocol_sz (1u) +#define RPCB_port_sz (1u) +#define RPCB_boolean_sz (1u) + +#define RPCB_netid_sz (1+XDR_QUADLEN(RPCBIND_MAXNETIDLEN)) +#define RPCB_addr_sz (1+XDR_QUADLEN(RPCBIND_MAXUADDRLEN)) +#define RPCB_ownerstring_sz (1+XDR_QUADLEN(RPCB_MAXOWNERLEN)) + +#define RPCB_mappingargs_sz RPCB_program_sz+RPCB_version_sz+ \ + RPCB_protocol_sz+RPCB_port_sz +#define RPCB_getaddrargs_sz RPCB_program_sz+RPCB_version_sz+ \ + RPCB_netid_sz+RPCB_addr_sz+ \ + RPCB_ownerstring_sz + +#define RPCB_setres_sz RPCB_boolean_sz +#define RPCB_getportres_sz RPCB_port_sz + +/* + * Note that RFC 1833 does not put any size restrictions on the + * address string returned by the remote rpcbind database. + */ +#define RPCB_getaddrres_sz RPCB_addr_sz + +#define PROC(proc, argtype, restype) \ + [RPCBPROC_##proc] = { \ + .p_proc = RPCBPROC_##proc, \ + .p_encode = (kxdrproc_t) rpcb_encode_##argtype, \ + .p_decode = (kxdrproc_t) rpcb_decode_##restype, \ + .p_arglen = RPCB_##argtype##args_sz, \ + .p_replen = RPCB_##restype##res_sz, \ + .p_statidx = RPCBPROC_##proc, \ + .p_timer = 0, \ + .p_name = #proc, \ + } + +/* + * Not all rpcbind procedures described in RFC 1833 are implemented + * since the Linux kernel RPC code requires only these. + */ +static struct rpc_procinfo rpcb_procedures2[] = { + PROC(SET, mapping, set), + PROC(UNSET, mapping, set), + PROC(GETPORT, mapping, getport), +}; + +static struct rpc_procinfo rpcb_procedures3[] = { + PROC(SET, getaddr, set), + PROC(UNSET, getaddr, set), + PROC(GETADDR, getaddr, getaddr), +}; + +static struct rpc_procinfo rpcb_procedures4[] = { + PROC(SET, getaddr, set), + PROC(UNSET, getaddr, set), + PROC(GETADDR, getaddr, getaddr), + PROC(GETVERSADDR, getaddr, getaddr), +}; + +static struct rpcb_info rpcb_next_version[] = { + { + .rpc_vers = RPCBVERS_2, + .rpc_proc = &rpcb_procedures2[RPCBPROC_GETPORT], + }, + { + .rpc_proc = NULL, + }, +}; + +static struct rpcb_info rpcb_next_version6[] = { + { + .rpc_vers = RPCBVERS_4, + .rpc_proc = &rpcb_procedures4[RPCBPROC_GETADDR], + }, + { + .rpc_vers = RPCBVERS_3, + .rpc_proc = &rpcb_procedures3[RPCBPROC_GETADDR], + }, + { + .rpc_proc = NULL, + }, +}; + +static struct rpc_version rpcb_version2 = { + .number = RPCBVERS_2, + .nrprocs = RPCB_HIGHPROC_2, + .procs = rpcb_procedures2 +}; + +static struct rpc_version rpcb_version3 = { + .number = RPCBVERS_3, + .nrprocs = RPCB_HIGHPROC_3, + .procs = rpcb_procedures3 +}; + +static struct rpc_version rpcb_version4 = { + .number = RPCBVERS_4, + .nrprocs = RPCB_HIGHPROC_4, + .procs = rpcb_procedures4 +}; + +static struct rpc_version *rpcb_version[] = { + NULL, + NULL, + &rpcb_version2, + &rpcb_version3, + &rpcb_version4 +}; + +static struct rpc_stat rpcb_stats; + +static struct rpc_program rpcb_program = { + .name = "rpcbind", + .number = RPCBIND_PROGRAM, + .nrvers = ARRAY_SIZE(rpcb_version), + .version = rpcb_version, + .stats = &rpcb_stats, +}; diff --git a/net/sunrpc/sched.c b/net/sunrpc/sched.c new file mode 100644 index 0000000..385f427 --- /dev/null +++ b/net/sunrpc/sched.c @@ -0,0 +1,1035 @@ +/* + * linux/net/sunrpc/sched.c + * + * Scheduling for synchronous and asynchronous RPC requests. + * + * Copyright (C) 1996 Olaf Kirch, <okir@monad.swb.de> + * + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie> + */ + +#include <linux/module.h> + +#include <linux/sched.h> +#include <linux/interrupt.h> +#include <linux/slab.h> +#include <linux/mempool.h> +#include <linux/smp.h> +#include <linux/smp_lock.h> +#include <linux/spinlock.h> +#include <linux/mutex.h> + +#include <linux/sunrpc/clnt.h> + +#ifdef RPC_DEBUG +#define RPCDBG_FACILITY RPCDBG_SCHED +#define RPC_TASK_MAGIC_ID 0xf00baa +#endif + +/* + * RPC slabs and memory pools + */ +#define RPC_BUFFER_MAXSIZE (2048) +#define RPC_BUFFER_POOLSIZE (8) +#define RPC_TASK_POOLSIZE (8) +static struct kmem_cache *rpc_task_slabp __read_mostly; +static struct kmem_cache *rpc_buffer_slabp __read_mostly; +static mempool_t *rpc_task_mempool __read_mostly; +static mempool_t *rpc_buffer_mempool __read_mostly; + +static void rpc_async_schedule(struct work_struct *); +static void rpc_release_task(struct rpc_task *task); +static void __rpc_queue_timer_fn(unsigned long ptr); + +/* + * RPC tasks sit here while waiting for conditions to improve. + */ +static struct rpc_wait_queue delay_queue; + +/* + * rpciod-related stuff + */ +struct workqueue_struct *rpciod_workqueue; + +/* + * Disable the timer for a given RPC task. Should be called with + * queue->lock and bh_disabled in order to avoid races within + * rpc_run_timer(). + */ +static void +__rpc_disable_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (task->tk_timeout == 0) + return; + dprintk("RPC: %5u disabling timer\n", task->tk_pid); + task->tk_timeout = 0; + list_del(&task->u.tk_wait.timer_list); + if (list_empty(&queue->timer_list.list)) + del_timer(&queue->timer_list.timer); +} + +static void +rpc_set_queue_timer(struct rpc_wait_queue *queue, unsigned long expires) +{ + queue->timer_list.expires = expires; + mod_timer(&queue->timer_list.timer, expires); +} + +/* + * Set up a timer for the current task. + */ +static void +__rpc_add_timer(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (!task->tk_timeout) + return; + + dprintk("RPC: %5u setting alarm for %lu ms\n", + task->tk_pid, task->tk_timeout * 1000 / HZ); + + task->u.tk_wait.expires = jiffies + task->tk_timeout; + if (list_empty(&queue->timer_list.list) || time_before(task->u.tk_wait.expires, queue->timer_list.expires)) + rpc_set_queue_timer(queue, task->u.tk_wait.expires); + list_add(&task->u.tk_wait.timer_list, &queue->timer_list.list); +} + +/* + * Add new request to a priority queue. + */ +static void __rpc_add_wait_queue_priority(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + struct list_head *q; + struct rpc_task *t; + + INIT_LIST_HEAD(&task->u.tk_wait.links); + q = &queue->tasks[task->tk_priority]; + if (unlikely(task->tk_priority > queue->maxpriority)) + q = &queue->tasks[queue->maxpriority]; + list_for_each_entry(t, q, u.tk_wait.list) { + if (t->tk_owner == task->tk_owner) { + list_add_tail(&task->u.tk_wait.list, &t->u.tk_wait.links); + return; + } + } + list_add_tail(&task->u.tk_wait.list, q); +} + +/* + * Add new request to wait queue. + * + * Swapper tasks always get inserted at the head of the queue. + * This should avoid many nasty memory deadlocks and hopefully + * improve overall performance. + * Everyone else gets appended to the queue to ensure proper FIFO behavior. + */ +static void __rpc_add_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + BUG_ON (RPC_IS_QUEUED(task)); + + if (RPC_IS_PRIORITY(queue)) + __rpc_add_wait_queue_priority(queue, task); + else if (RPC_IS_SWAPPER(task)) + list_add(&task->u.tk_wait.list, &queue->tasks[0]); + else + list_add_tail(&task->u.tk_wait.list, &queue->tasks[0]); + task->tk_waitqueue = queue; + queue->qlen++; + rpc_set_queued(task); + + dprintk("RPC: %5u added to queue %p \"%s\"\n", + task->tk_pid, queue, rpc_qname(queue)); +} + +/* + * Remove request from a priority queue. + */ +static void __rpc_remove_wait_queue_priority(struct rpc_task *task) +{ + struct rpc_task *t; + + if (!list_empty(&task->u.tk_wait.links)) { + t = list_entry(task->u.tk_wait.links.next, struct rpc_task, u.tk_wait.list); + list_move(&t->u.tk_wait.list, &task->u.tk_wait.list); + list_splice_init(&task->u.tk_wait.links, &t->u.tk_wait.links); + } +} + +/* + * Remove request from queue. + * Note: must be called with spin lock held. + */ +static void __rpc_remove_wait_queue(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + __rpc_disable_timer(queue, task); + if (RPC_IS_PRIORITY(queue)) + __rpc_remove_wait_queue_priority(task); + list_del(&task->u.tk_wait.list); + queue->qlen--; + dprintk("RPC: %5u removed from queue %p \"%s\"\n", + task->tk_pid, queue, rpc_qname(queue)); +} + +static inline void rpc_set_waitqueue_priority(struct rpc_wait_queue *queue, int priority) +{ + queue->priority = priority; + queue->count = 1 << (priority * 2); +} + +static inline void rpc_set_waitqueue_owner(struct rpc_wait_queue *queue, pid_t pid) +{ + queue->owner = pid; + queue->nr = RPC_BATCH_COUNT; +} + +static inline void rpc_reset_waitqueue_priority(struct rpc_wait_queue *queue) +{ + rpc_set_waitqueue_priority(queue, queue->maxpriority); + rpc_set_waitqueue_owner(queue, 0); +} + +static void __rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname, unsigned char nr_queues) +{ + int i; + + spin_lock_init(&queue->lock); + for (i = 0; i < ARRAY_SIZE(queue->tasks); i++) + INIT_LIST_HEAD(&queue->tasks[i]); + queue->maxpriority = nr_queues - 1; + rpc_reset_waitqueue_priority(queue); + queue->qlen = 0; + setup_timer(&queue->timer_list.timer, __rpc_queue_timer_fn, (unsigned long)queue); + INIT_LIST_HEAD(&queue->timer_list.list); +#ifdef RPC_DEBUG + queue->name = qname; +#endif +} + +void rpc_init_priority_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, RPC_NR_PRIORITY); +} + +void rpc_init_wait_queue(struct rpc_wait_queue *queue, const char *qname) +{ + __rpc_init_priority_wait_queue(queue, qname, 1); +} +EXPORT_SYMBOL_GPL(rpc_init_wait_queue); + +void rpc_destroy_wait_queue(struct rpc_wait_queue *queue) +{ + del_timer_sync(&queue->timer_list.timer); +} +EXPORT_SYMBOL_GPL(rpc_destroy_wait_queue); + +static int rpc_wait_bit_killable(void *word) +{ + if (fatal_signal_pending(current)) + return -ERESTARTSYS; + schedule(); + return 0; +} + +#ifdef RPC_DEBUG +static void rpc_task_set_debuginfo(struct rpc_task *task) +{ + static atomic_t rpc_pid; + + task->tk_magic = RPC_TASK_MAGIC_ID; + task->tk_pid = atomic_inc_return(&rpc_pid); +} +#else +static inline void rpc_task_set_debuginfo(struct rpc_task *task) +{ +} +#endif + +static void rpc_set_active(struct rpc_task *task) +{ + struct rpc_clnt *clnt; + if (test_and_set_bit(RPC_TASK_ACTIVE, &task->tk_runstate) != 0) + return; + rpc_task_set_debuginfo(task); + /* Add to global list of all tasks */ + clnt = task->tk_client; + if (clnt != NULL) { + spin_lock(&clnt->cl_lock); + list_add_tail(&task->tk_task, &clnt->cl_tasks); + spin_unlock(&clnt->cl_lock); + } +} + +/* + * Mark an RPC call as having completed by clearing the 'active' bit + */ +static void rpc_mark_complete_task(struct rpc_task *task) +{ + smp_mb__before_clear_bit(); + clear_bit(RPC_TASK_ACTIVE, &task->tk_runstate); + smp_mb__after_clear_bit(); + wake_up_bit(&task->tk_runstate, RPC_TASK_ACTIVE); +} + +/* + * Allow callers to wait for completion of an RPC call + */ +int __rpc_wait_for_completion_task(struct rpc_task *task, int (*action)(void *)) +{ + if (action == NULL) + action = rpc_wait_bit_killable; + return wait_on_bit(&task->tk_runstate, RPC_TASK_ACTIVE, + action, TASK_KILLABLE); +} +EXPORT_SYMBOL_GPL(__rpc_wait_for_completion_task); + +/* + * Make an RPC task runnable. + * + * Note: If the task is ASYNC, this must be called with + * the spinlock held to protect the wait queue operation. + */ +static void rpc_make_runnable(struct rpc_task *task) +{ + rpc_clear_queued(task); + if (rpc_test_and_set_running(task)) + return; + /* We might have raced */ + if (RPC_IS_QUEUED(task)) { + rpc_clear_running(task); + return; + } + if (RPC_IS_ASYNC(task)) { + int status; + + INIT_WORK(&task->u.tk_work, rpc_async_schedule); + status = queue_work(rpciod_workqueue, &task->u.tk_work); + if (status < 0) { + printk(KERN_WARNING "RPC: failed to add task to queue: error: %d!\n", status); + task->tk_status = status; + return; + } + } else + wake_up_bit(&task->tk_runstate, RPC_TASK_QUEUED); +} + +/* + * Prepare for sleeping on a wait queue. + * By always appending tasks to the list we ensure FIFO behavior. + * NB: An RPC task will only receive interrupt-driven events as long + * as it's on a wait queue. + */ +static void __rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action) +{ + dprintk("RPC: %5u sleep_on(queue \"%s\" time %lu)\n", + task->tk_pid, rpc_qname(q), jiffies); + + if (!RPC_IS_ASYNC(task) && !RPC_IS_ACTIVATED(task)) { + printk(KERN_ERR "RPC: Inactive synchronous task put to sleep!\n"); + return; + } + + __rpc_add_wait_queue(q, task); + + BUG_ON(task->tk_callback != NULL); + task->tk_callback = action; + __rpc_add_timer(q, task); +} + +void rpc_sleep_on(struct rpc_wait_queue *q, struct rpc_task *task, + rpc_action action) +{ + /* Mark the task as being activated if so needed */ + rpc_set_active(task); + + /* + * Protect the queue operations. + */ + spin_lock_bh(&q->lock); + __rpc_sleep_on(q, task, action); + spin_unlock_bh(&q->lock); +} +EXPORT_SYMBOL_GPL(rpc_sleep_on); + +/** + * __rpc_do_wake_up_task - wake up a single rpc_task + * @queue: wait queue + * @task: task to be woken up + * + * Caller must hold queue->lock, and have cleared the task queued flag. + */ +static void __rpc_do_wake_up_task(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + dprintk("RPC: %5u __rpc_wake_up_task (now %lu)\n", + task->tk_pid, jiffies); + +#ifdef RPC_DEBUG + BUG_ON(task->tk_magic != RPC_TASK_MAGIC_ID); +#endif + /* Has the task been executed yet? If not, we cannot wake it up! */ + if (!RPC_IS_ACTIVATED(task)) { + printk(KERN_ERR "RPC: Inactive task (%p) being woken up!\n", task); + return; + } + + __rpc_remove_wait_queue(queue, task); + + rpc_make_runnable(task); + + dprintk("RPC: __rpc_wake_up_task done\n"); +} + +/* + * Wake up a queued task while the queue lock is being held + */ +static void rpc_wake_up_task_queue_locked(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + if (RPC_IS_QUEUED(task) && task->tk_waitqueue == queue) + __rpc_do_wake_up_task(queue, task); +} + +/* + * Wake up a task on a specific queue + */ +void rpc_wake_up_queued_task(struct rpc_wait_queue *queue, struct rpc_task *task) +{ + spin_lock_bh(&queue->lock); + rpc_wake_up_task_queue_locked(queue, task); + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_queued_task); + +/* + * Wake up the specified task + */ +static void rpc_wake_up_task(struct rpc_task *task) +{ + rpc_wake_up_queued_task(task->tk_waitqueue, task); +} + +/* + * Wake up the next task on a priority queue. + */ +static struct rpc_task * __rpc_wake_up_next_priority(struct rpc_wait_queue *queue) +{ + struct list_head *q; + struct rpc_task *task; + + /* + * Service a batch of tasks from a single owner. + */ + q = &queue->tasks[queue->priority]; + if (!list_empty(q)) { + task = list_entry(q->next, struct rpc_task, u.tk_wait.list); + if (queue->owner == task->tk_owner) { + if (--queue->nr) + goto out; + list_move_tail(&task->u.tk_wait.list, q); + } + /* + * Check if we need to switch queues. + */ + if (--queue->count) + goto new_owner; + } + + /* + * Service the next queue. + */ + do { + if (q == &queue->tasks[0]) + q = &queue->tasks[queue->maxpriority]; + else + q = q - 1; + if (!list_empty(q)) { + task = list_entry(q->next, struct rpc_task, u.tk_wait.list); + goto new_queue; + } + } while (q != &queue->tasks[queue->priority]); + + rpc_reset_waitqueue_priority(queue); + return NULL; + +new_queue: + rpc_set_waitqueue_priority(queue, (unsigned int)(q - &queue->tasks[0])); +new_owner: + rpc_set_waitqueue_owner(queue, task->tk_owner); +out: + rpc_wake_up_task_queue_locked(queue, task); + return task; +} + +/* + * Wake up the next task on the wait queue. + */ +struct rpc_task * rpc_wake_up_next(struct rpc_wait_queue *queue) +{ + struct rpc_task *task = NULL; + + dprintk("RPC: wake_up_next(%p \"%s\")\n", + queue, rpc_qname(queue)); + spin_lock_bh(&queue->lock); + if (RPC_IS_PRIORITY(queue)) + task = __rpc_wake_up_next_priority(queue); + else { + task_for_first(task, &queue->tasks[0]) + rpc_wake_up_task_queue_locked(queue, task); + } + spin_unlock_bh(&queue->lock); + + return task; +} +EXPORT_SYMBOL_GPL(rpc_wake_up_next); + +/** + * rpc_wake_up - wake up all rpc_tasks + * @queue: rpc_wait_queue on which the tasks are sleeping + * + * Grabs queue->lock + */ +void rpc_wake_up(struct rpc_wait_queue *queue) +{ + struct rpc_task *task, *next; + struct list_head *head; + + spin_lock_bh(&queue->lock); + head = &queue->tasks[queue->maxpriority]; + for (;;) { + list_for_each_entry_safe(task, next, head, u.tk_wait.list) + rpc_wake_up_task_queue_locked(queue, task); + if (head == &queue->tasks[0]) + break; + head--; + } + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up); + +/** + * rpc_wake_up_status - wake up all rpc_tasks and set their status value. + * @queue: rpc_wait_queue on which the tasks are sleeping + * @status: status value to set + * + * Grabs queue->lock + */ +void rpc_wake_up_status(struct rpc_wait_queue *queue, int status) +{ + struct rpc_task *task, *next; + struct list_head *head; + + spin_lock_bh(&queue->lock); + head = &queue->tasks[queue->maxpriority]; + for (;;) { + list_for_each_entry_safe(task, next, head, u.tk_wait.list) { + task->tk_status = status; + rpc_wake_up_task_queue_locked(queue, task); + } + if (head == &queue->tasks[0]) + break; + head--; + } + spin_unlock_bh(&queue->lock); +} +EXPORT_SYMBOL_GPL(rpc_wake_up_status); + +static void __rpc_queue_timer_fn(unsigned long ptr) +{ + struct rpc_wait_queue *queue = (struct rpc_wait_queue *)ptr; + struct rpc_task *task, *n; + unsigned long expires, now, timeo; + + spin_lock(&queue->lock); + expires = now = jiffies; + list_for_each_entry_safe(task, n, &queue->timer_list.list, u.tk_wait.timer_list) { + timeo = task->u.tk_wait.expires; + if (time_after_eq(now, timeo)) { + dprintk("RPC: %5u timeout\n", task->tk_pid); + task->tk_status = -ETIMEDOUT; + rpc_wake_up_task_queue_locked(queue, task); + continue; + } + if (expires == now || time_after(expires, timeo)) + expires = timeo; + } + if (!list_empty(&queue->timer_list.list)) + rpc_set_queue_timer(queue, expires); + spin_unlock(&queue->lock); +} + +static void __rpc_atrun(struct rpc_task *task) +{ + task->tk_status = 0; +} + +/* + * Run a task at a later time + */ +void rpc_delay(struct rpc_task *task, unsigned long delay) +{ + task->tk_timeout = delay; + rpc_sleep_on(&delay_queue, task, __rpc_atrun); +} +EXPORT_SYMBOL_GPL(rpc_delay); + +/* + * Helper to call task->tk_ops->rpc_call_prepare + */ +static void rpc_prepare_task(struct rpc_task *task) +{ + task->tk_ops->rpc_call_prepare(task, task->tk_calldata); +} + +/* + * Helper that calls task->tk_ops->rpc_call_done if it exists + */ +void rpc_exit_task(struct rpc_task *task) +{ + task->tk_action = NULL; + if (task->tk_ops->rpc_call_done != NULL) { + task->tk_ops->rpc_call_done(task, task->tk_calldata); + if (task->tk_action != NULL) { + WARN_ON(RPC_ASSASSINATED(task)); + /* Always release the RPC slot and buffer memory */ + xprt_release(task); + } + } +} +EXPORT_SYMBOL_GPL(rpc_exit_task); + +void rpc_release_calldata(const struct rpc_call_ops *ops, void *calldata) +{ + if (ops->rpc_release != NULL) + ops->rpc_release(calldata); +} + +/* + * This is the RPC `scheduler' (or rather, the finite state machine). + */ +static void __rpc_execute(struct rpc_task *task) +{ + int status = 0; + + dprintk("RPC: %5u __rpc_execute flags=0x%x\n", + task->tk_pid, task->tk_flags); + + BUG_ON(RPC_IS_QUEUED(task)); + + for (;;) { + + /* + * Execute any pending callback. + */ + if (task->tk_callback) { + void (*save_callback)(struct rpc_task *); + + /* + * We set tk_callback to NULL before calling it, + * in case it sets the tk_callback field itself: + */ + save_callback = task->tk_callback; + task->tk_callback = NULL; + save_callback(task); + } + + /* + * Perform the next FSM step. + * tk_action may be NULL when the task has been killed + * by someone else. + */ + if (!RPC_IS_QUEUED(task)) { + if (task->tk_action == NULL) + break; + task->tk_action(task); + } + + /* + * Lockless check for whether task is sleeping or not. + */ + if (!RPC_IS_QUEUED(task)) + continue; + rpc_clear_running(task); + if (RPC_IS_ASYNC(task)) { + /* Careful! we may have raced... */ + if (RPC_IS_QUEUED(task)) + return; + if (rpc_test_and_set_running(task)) + return; + continue; + } + + /* sync task: sleep here */ + dprintk("RPC: %5u sync task going to sleep\n", task->tk_pid); + status = out_of_line_wait_on_bit(&task->tk_runstate, + RPC_TASK_QUEUED, rpc_wait_bit_killable, + TASK_KILLABLE); + if (status == -ERESTARTSYS) { + /* + * When a sync task receives a signal, it exits with + * -ERESTARTSYS. In order to catch any callbacks that + * clean up after sleeping on some queue, we don't + * break the loop here, but go around once more. + */ + dprintk("RPC: %5u got signal\n", task->tk_pid); + task->tk_flags |= RPC_TASK_KILLED; + rpc_exit(task, -ERESTARTSYS); + rpc_wake_up_task(task); + } + rpc_set_running(task); + dprintk("RPC: %5u sync task resuming\n", task->tk_pid); + } + + dprintk("RPC: %5u return %d, status %d\n", task->tk_pid, status, + task->tk_status); + /* Release all resources associated with the task */ + rpc_release_task(task); +} + +/* + * User-visible entry point to the scheduler. + * + * This may be called recursively if e.g. an async NFS task updates + * the attributes and finds that dirty pages must be flushed. + * NOTE: Upon exit of this function the task is guaranteed to be + * released. In particular note that tk_release() will have + * been called, so your task memory may have been freed. + */ +void rpc_execute(struct rpc_task *task) +{ + rpc_set_active(task); + rpc_set_running(task); + __rpc_execute(task); +} + +static void rpc_async_schedule(struct work_struct *work) +{ + __rpc_execute(container_of(work, struct rpc_task, u.tk_work)); +} + +struct rpc_buffer { + size_t len; + char data[]; +}; + +/** + * rpc_malloc - allocate an RPC buffer + * @task: RPC task that will use this buffer + * @size: requested byte size + * + * To prevent rpciod from hanging, this allocator never sleeps, + * returning NULL if the request cannot be serviced immediately. + * The caller can arrange to sleep in a way that is safe for rpciod. + * + * Most requests are 'small' (under 2KiB) and can be serviced from a + * mempool, ensuring that NFS reads and writes can always proceed, + * and that there is good locality of reference for these buffers. + * + * In order to avoid memory starvation triggering more writebacks of + * NFS requests, we avoid using GFP_KERNEL. + */ +void *rpc_malloc(struct rpc_task *task, size_t size) +{ + struct rpc_buffer *buf; + gfp_t gfp = RPC_IS_SWAPPER(task) ? GFP_ATOMIC : GFP_NOWAIT; + + size += sizeof(struct rpc_buffer); + if (size <= RPC_BUFFER_MAXSIZE) + buf = mempool_alloc(rpc_buffer_mempool, gfp); + else + buf = kmalloc(size, gfp); + + if (!buf) + return NULL; + + buf->len = size; + dprintk("RPC: %5u allocated buffer of size %zu at %p\n", + task->tk_pid, size, buf); + return &buf->data; +} +EXPORT_SYMBOL_GPL(rpc_malloc); + +/** + * rpc_free - free buffer allocated via rpc_malloc + * @buffer: buffer to free + * + */ +void rpc_free(void *buffer) +{ + size_t size; + struct rpc_buffer *buf; + + if (!buffer) + return; + + buf = container_of(buffer, struct rpc_buffer, data); + size = buf->len; + + dprintk("RPC: freeing buffer of size %zu at %p\n", + size, buf); + + if (size <= RPC_BUFFER_MAXSIZE) + mempool_free(buf, rpc_buffer_mempool); + else + kfree(buf); +} +EXPORT_SYMBOL_GPL(rpc_free); + +/* + * Creation and deletion of RPC task structures + */ +static void rpc_init_task(struct rpc_task *task, const struct rpc_task_setup *task_setup_data) +{ + memset(task, 0, sizeof(*task)); + atomic_set(&task->tk_count, 1); + task->tk_flags = task_setup_data->flags; + task->tk_ops = task_setup_data->callback_ops; + task->tk_calldata = task_setup_data->callback_data; + INIT_LIST_HEAD(&task->tk_task); + + /* Initialize retry counters */ + task->tk_garb_retry = 2; + task->tk_cred_retry = 2; + + task->tk_priority = task_setup_data->priority - RPC_PRIORITY_LOW; + task->tk_owner = current->tgid; + + /* Initialize workqueue for async tasks */ + task->tk_workqueue = task_setup_data->workqueue; + + task->tk_client = task_setup_data->rpc_client; + if (task->tk_client != NULL) { + kref_get(&task->tk_client->cl_kref); + if (task->tk_client->cl_softrtry) + task->tk_flags |= RPC_TASK_SOFT; + } + + if (task->tk_ops->rpc_call_prepare != NULL) + task->tk_action = rpc_prepare_task; + + if (task_setup_data->rpc_message != NULL) { + task->tk_msg.rpc_proc = task_setup_data->rpc_message->rpc_proc; + task->tk_msg.rpc_argp = task_setup_data->rpc_message->rpc_argp; + task->tk_msg.rpc_resp = task_setup_data->rpc_message->rpc_resp; + /* Bind the user cred */ + rpcauth_bindcred(task, task_setup_data->rpc_message->rpc_cred, task_setup_data->flags); + if (task->tk_action == NULL) + rpc_call_start(task); + } + + /* starting timestamp */ + task->tk_start = jiffies; + + dprintk("RPC: new task initialized, procpid %u\n", + task_pid_nr(current)); +} + +static struct rpc_task * +rpc_alloc_task(void) +{ + return (struct rpc_task *)mempool_alloc(rpc_task_mempool, GFP_NOFS); +} + +/* + * Create a new task for the specified client. + */ +struct rpc_task *rpc_new_task(const struct rpc_task_setup *setup_data) +{ + struct rpc_task *task = setup_data->task; + unsigned short flags = 0; + + if (task == NULL) { + task = rpc_alloc_task(); + if (task == NULL) + goto out; + flags = RPC_TASK_DYNAMIC; + } + + rpc_init_task(task, setup_data); + + task->tk_flags |= flags; + dprintk("RPC: allocated task %p\n", task); +out: + return task; +} + +static void rpc_free_task(struct rpc_task *task) +{ + const struct rpc_call_ops *tk_ops = task->tk_ops; + void *calldata = task->tk_calldata; + + if (task->tk_flags & RPC_TASK_DYNAMIC) { + dprintk("RPC: %5u freeing task\n", task->tk_pid); + mempool_free(task, rpc_task_mempool); + } + rpc_release_calldata(tk_ops, calldata); +} + +static void rpc_async_release(struct work_struct *work) +{ + rpc_free_task(container_of(work, struct rpc_task, u.tk_work)); +} + +void rpc_put_task(struct rpc_task *task) +{ + if (!atomic_dec_and_test(&task->tk_count)) + return; + /* Release resources */ + if (task->tk_rqstp) + xprt_release(task); + if (task->tk_msg.rpc_cred) + rpcauth_unbindcred(task); + if (task->tk_client) { + rpc_release_client(task->tk_client); + task->tk_client = NULL; + } + if (task->tk_workqueue != NULL) { + INIT_WORK(&task->u.tk_work, rpc_async_release); + queue_work(task->tk_workqueue, &task->u.tk_work); + } else + rpc_free_task(task); +} +EXPORT_SYMBOL_GPL(rpc_put_task); + +static void rpc_release_task(struct rpc_task *task) +{ +#ifdef RPC_DEBUG + BUG_ON(task->tk_magic != RPC_TASK_MAGIC_ID); +#endif + dprintk("RPC: %5u release task\n", task->tk_pid); + + if (!list_empty(&task->tk_task)) { + struct rpc_clnt *clnt = task->tk_client; + /* Remove from client task list */ + spin_lock(&clnt->cl_lock); + list_del(&task->tk_task); + spin_unlock(&clnt->cl_lock); + } + BUG_ON (RPC_IS_QUEUED(task)); + +#ifdef RPC_DEBUG + task->tk_magic = 0; +#endif + /* Wake up anyone who is waiting for task completion */ + rpc_mark_complete_task(task); + + rpc_put_task(task); +} + +/* + * Kill all tasks for the given client. + * XXX: kill their descendants as well? + */ +void rpc_killall_tasks(struct rpc_clnt *clnt) +{ + struct rpc_task *rovr; + + + if (list_empty(&clnt->cl_tasks)) + return; + dprintk("RPC: killing all tasks for client %p\n", clnt); + /* + * Spin lock all_tasks to prevent changes... + */ + spin_lock(&clnt->cl_lock); + list_for_each_entry(rovr, &clnt->cl_tasks, tk_task) { + if (! RPC_IS_ACTIVATED(rovr)) + continue; + if (!(rovr->tk_flags & RPC_TASK_KILLED)) { + rovr->tk_flags |= RPC_TASK_KILLED; + rpc_exit(rovr, -EIO); + rpc_wake_up_task(rovr); + } + } + spin_unlock(&clnt->cl_lock); +} +EXPORT_SYMBOL_GPL(rpc_killall_tasks); + +int rpciod_up(void) +{ + return try_module_get(THIS_MODULE) ? 0 : -EINVAL; +} + +void rpciod_down(void) +{ + module_put(THIS_MODULE); +} + +/* + * Start up the rpciod workqueue. + */ +static int rpciod_start(void) +{ + struct workqueue_struct *wq; + + /* + * Create the rpciod thread and wait for it to start. + */ + dprintk("RPC: creating workqueue rpciod\n"); + wq = create_workqueue("rpciod"); + rpciod_workqueue = wq; + return rpciod_workqueue != NULL; +} + +static void rpciod_stop(void) +{ + struct workqueue_struct *wq = NULL; + + if (rpciod_workqueue == NULL) + return; + dprintk("RPC: destroying workqueue rpciod\n"); + + wq = rpciod_workqueue; + rpciod_workqueue = NULL; + destroy_workqueue(wq); +} + +void +rpc_destroy_mempool(void) +{ + rpciod_stop(); + if (rpc_buffer_mempool) + mempool_destroy(rpc_buffer_mempool); + if (rpc_task_mempool) + mempool_destroy(rpc_task_mempool); + if (rpc_task_slabp) + kmem_cache_destroy(rpc_task_slabp); + if (rpc_buffer_slabp) + kmem_cache_destroy(rpc_buffer_slabp); + rpc_destroy_wait_queue(&delay_queue); +} + +int +rpc_init_mempool(void) +{ + /* + * The following is not strictly a mempool initialisation, + * but there is no harm in doing it here + */ + rpc_init_wait_queue(&delay_queue, "delayq"); + if (!rpciod_start()) + goto err_nomem; + + rpc_task_slabp = kmem_cache_create("rpc_tasks", + sizeof(struct rpc_task), + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_task_slabp) + goto err_nomem; + rpc_buffer_slabp = kmem_cache_create("rpc_buffers", + RPC_BUFFER_MAXSIZE, + 0, SLAB_HWCACHE_ALIGN, + NULL); + if (!rpc_buffer_slabp) + goto err_nomem; + rpc_task_mempool = mempool_create_slab_pool(RPC_TASK_POOLSIZE, + rpc_task_slabp); + if (!rpc_task_mempool) + goto err_nomem; + rpc_buffer_mempool = mempool_create_slab_pool(RPC_BUFFER_POOLSIZE, + rpc_buffer_slabp); + if (!rpc_buffer_mempool) + goto err_nomem; + return 0; +err_nomem: + rpc_destroy_mempool(); + return -ENOMEM; +} diff --git a/net/sunrpc/socklib.c b/net/sunrpc/socklib.c new file mode 100644 index 0000000..a661a3a --- /dev/null +++ b/net/sunrpc/socklib.c @@ -0,0 +1,184 @@ +/* + * linux/net/sunrpc/socklib.c + * + * Common socket helper routines for RPC client and server + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/compiler.h> +#include <linux/netdevice.h> +#include <linux/skbuff.h> +#include <linux/types.h> +#include <linux/pagemap.h> +#include <linux/udp.h> +#include <linux/sunrpc/xdr.h> + + +/** + * xdr_skb_read_bits - copy some data bits from skb to internal buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Possibly called several times to iterate over an sk_buff and copy + * data out of it. + */ +size_t xdr_skb_read_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + if (len > desc->count) + len = desc->count; + if (unlikely(skb_copy_bits(desc->skb, desc->offset, to, len))) + return 0; + desc->count -= len; + desc->offset += len; + return len; +} +EXPORT_SYMBOL_GPL(xdr_skb_read_bits); + +/** + * xdr_skb_read_and_csum_bits - copy and checksum from skb to buffer + * @desc: sk_buff copy helper + * @to: copy destination + * @len: number of bytes to copy + * + * Same as skb_read_bits, but calculate a checksum at the same time. + */ +static size_t xdr_skb_read_and_csum_bits(struct xdr_skb_reader *desc, void *to, size_t len) +{ + unsigned int pos; + __wsum csum2; + + if (len > desc->count) + len = desc->count; + pos = desc->offset; + csum2 = skb_copy_and_csum_bits(desc->skb, pos, to, len, 0); + desc->csum = csum_block_add(desc->csum, csum2, pos); + desc->count -= len; + desc->offset += len; + return len; +} + +/** + * xdr_partial_copy_from_skb - copy data out of an skb + * @xdr: target XDR buffer + * @base: starting offset + * @desc: sk_buff copy helper + * @copy_actor: virtual method for copying data + * + */ +ssize_t xdr_partial_copy_from_skb(struct xdr_buf *xdr, unsigned int base, struct xdr_skb_reader *desc, xdr_skb_read_actor copy_actor) +{ + struct page **ppage = xdr->pages; + unsigned int len, pglen = xdr->page_len; + ssize_t copied = 0; + size_t ret; + + len = xdr->head[0].iov_len; + if (base < len) { + len -= base; + ret = copy_actor(desc, (char *)xdr->head[0].iov_base + base, len); + copied += ret; + if (ret != len || !desc->count) + goto out; + base = 0; + } else + base -= len; + + if (unlikely(pglen == 0)) + goto copy_tail; + if (unlikely(base >= pglen)) { + base -= pglen; + goto copy_tail; + } + if (base || xdr->page_base) { + pglen -= base; + base += xdr->page_base; + ppage += base >> PAGE_CACHE_SHIFT; + base &= ~PAGE_CACHE_MASK; + } + do { + char *kaddr; + + /* ACL likes to be lazy in allocating pages - ACLs + * are small by default but can get huge. */ + if (unlikely(*ppage == NULL)) { + *ppage = alloc_page(GFP_ATOMIC); + if (unlikely(*ppage == NULL)) { + if (copied == 0) + copied = -ENOMEM; + goto out; + } + } + + len = PAGE_CACHE_SIZE; + kaddr = kmap_atomic(*ppage, KM_SKB_SUNRPC_DATA); + if (base) { + len -= base; + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr + base, len); + base = 0; + } else { + if (pglen < len) + len = pglen; + ret = copy_actor(desc, kaddr, len); + } + flush_dcache_page(*ppage); + kunmap_atomic(kaddr, KM_SKB_SUNRPC_DATA); + copied += ret; + if (ret != len || !desc->count) + goto out; + ppage++; + } while ((pglen -= len) != 0); +copy_tail: + len = xdr->tail[0].iov_len; + if (base < len) + copied += copy_actor(desc, (char *)xdr->tail[0].iov_base + base, len - base); +out: + return copied; +} +EXPORT_SYMBOL_GPL(xdr_partial_copy_from_skb); + +/** + * csum_partial_copy_to_xdr - checksum and copy data + * @xdr: target XDR buffer + * @skb: source skb + * + * We have set things up such that we perform the checksum of the UDP + * packet in parallel with the copies into the RPC client iovec. -DaveM + */ +int csum_partial_copy_to_xdr(struct xdr_buf *xdr, struct sk_buff *skb) +{ + struct xdr_skb_reader desc; + + desc.skb = skb; + desc.offset = sizeof(struct udphdr); + desc.count = skb->len - desc.offset; + + if (skb_csum_unnecessary(skb)) + goto no_checksum; + + desc.csum = csum_partial(skb->data, desc.offset, skb->csum); + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_and_csum_bits) < 0) + return -1; + if (desc.offset != skb->len) { + __wsum csum2; + csum2 = skb_checksum(skb, desc.offset, skb->len - desc.offset, 0); + desc.csum = csum_block_add(desc.csum, csum2, desc.offset); + } + if (desc.count) + return -1; + if (csum_fold(desc.csum)) + return -1; + if (unlikely(skb->ip_summed == CHECKSUM_COMPLETE)) + netdev_rx_csum_fault(skb->dev); + return 0; +no_checksum: + if (xdr_partial_copy_from_skb(xdr, 0, &desc, xdr_skb_read_bits) < 0) + return -1; + if (desc.count) + return -1; + return 0; +} +EXPORT_SYMBOL_GPL(csum_partial_copy_to_xdr); diff --git a/net/sunrpc/stats.c b/net/sunrpc/stats.c new file mode 100644 index 0000000..50b049c --- /dev/null +++ b/net/sunrpc/stats.c @@ -0,0 +1,284 @@ +/* + * linux/net/sunrpc/stats.c + * + * procfs-based user access to generic RPC statistics. The stats files + * reside in /proc/net/rpc. + * + * The read routines assume that the buffer passed in is just big enough. + * If you implement an RPC service that has its own stats routine which + * appends the generic RPC stats, make sure you don't exceed the PAGE_SIZE + * limit. + * + * Copyright (C) 1995, 1996, 1997 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> + +#include <linux/init.h> +#include <linux/kernel.h> +#include <linux/proc_fs.h> +#include <linux/seq_file.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/metrics.h> +#include <net/net_namespace.h> + +#define RPCDBG_FACILITY RPCDBG_MISC + +struct proc_dir_entry *proc_net_rpc = NULL; + +/* + * Get RPC client stats + */ +static int rpc_proc_show(struct seq_file *seq, void *v) { + const struct rpc_stat *statp = seq->private; + const struct rpc_program *prog = statp->program; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u\n", + statp->rpccnt, + statp->rpcretrans, + statp->rpcauthrefresh); + + for (i = 0; i < prog->nrvers; i++) { + const struct rpc_version *vers = prog->version[i]; + if (!vers) + continue; + seq_printf(seq, "proc%u %u", + vers->number, vers->nrprocs); + for (j = 0; j < vers->nrprocs; j++) + seq_printf(seq, " %u", + vers->procs[j].p_count); + seq_putc(seq, '\n'); + } + return 0; +} + +static int rpc_proc_open(struct inode *inode, struct file *file) +{ + return single_open(file, rpc_proc_show, PDE(inode)->data); +} + +static const struct file_operations rpc_proc_fops = { + .owner = THIS_MODULE, + .open = rpc_proc_open, + .read = seq_read, + .llseek = seq_lseek, + .release = single_release, +}; + +/* + * Get RPC server stats + */ +void svc_seq_show(struct seq_file *seq, const struct svc_stat *statp) { + const struct svc_program *prog = statp->program; + const struct svc_procedure *proc; + const struct svc_version *vers; + unsigned int i, j; + + seq_printf(seq, + "net %u %u %u %u\n", + statp->netcnt, + statp->netudpcnt, + statp->nettcpcnt, + statp->nettcpconn); + seq_printf(seq, + "rpc %u %u %u %u %u\n", + statp->rpccnt, + statp->rpcbadfmt+statp->rpcbadauth+statp->rpcbadclnt, + statp->rpcbadfmt, + statp->rpcbadauth, + statp->rpcbadclnt); + + for (i = 0; i < prog->pg_nvers; i++) { + if (!(vers = prog->pg_vers[i]) || !(proc = vers->vs_proc)) + continue; + seq_printf(seq, "proc%d %u", i, vers->vs_nproc); + for (j = 0; j < vers->vs_nproc; j++, proc++) + seq_printf(seq, " %u", proc->pc_count); + seq_putc(seq, '\n'); + } +} +EXPORT_SYMBOL(svc_seq_show); + +/** + * rpc_alloc_iostats - allocate an rpc_iostats structure + * @clnt: RPC program, version, and xprt + * + */ +struct rpc_iostats *rpc_alloc_iostats(struct rpc_clnt *clnt) +{ + struct rpc_iostats *new; + new = kcalloc(clnt->cl_maxproc, sizeof(struct rpc_iostats), GFP_KERNEL); + return new; +} +EXPORT_SYMBOL_GPL(rpc_alloc_iostats); + +/** + * rpc_free_iostats - release an rpc_iostats structure + * @stats: doomed rpc_iostats structure + * + */ +void rpc_free_iostats(struct rpc_iostats *stats) +{ + kfree(stats); +} +EXPORT_SYMBOL_GPL(rpc_free_iostats); + +/** + * rpc_count_iostats - tally up per-task stats + * @task: completed rpc_task + * + * Relies on the caller for serialization. + */ +void rpc_count_iostats(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_iostats *stats = task->tk_client->cl_metrics; + struct rpc_iostats *op_metrics; + long rtt, execute, queue; + + if (!stats || !req) + return; + op_metrics = &stats[task->tk_msg.rpc_proc->p_statidx]; + + op_metrics->om_ops++; + op_metrics->om_ntrans += req->rq_ntrans; + op_metrics->om_timeouts += task->tk_timeouts; + + op_metrics->om_bytes_sent += task->tk_bytes_sent; + op_metrics->om_bytes_recv += req->rq_received; + + queue = (long)req->rq_xtime - task->tk_start; + if (queue < 0) + queue = -queue; + op_metrics->om_queue += queue; + + rtt = task->tk_rtt; + if (rtt < 0) + rtt = -rtt; + op_metrics->om_rtt += rtt; + + execute = (long)jiffies - task->tk_start; + if (execute < 0) + execute = -execute; + op_metrics->om_execute += execute; +} + +static void _print_name(struct seq_file *seq, unsigned int op, + struct rpc_procinfo *procs) +{ + if (procs[op].p_name) + seq_printf(seq, "\t%12s: ", procs[op].p_name); + else if (op == 0) + seq_printf(seq, "\t NULL: "); + else + seq_printf(seq, "\t%12u: ", op); +} + +#define MILLISECS_PER_JIFFY (1000 / HZ) + +void rpc_print_iostats(struct seq_file *seq, struct rpc_clnt *clnt) +{ + struct rpc_iostats *stats = clnt->cl_metrics; + struct rpc_xprt *xprt = clnt->cl_xprt; + unsigned int op, maxproc = clnt->cl_maxproc; + + if (!stats) + return; + + seq_printf(seq, "\tRPC iostats version: %s ", RPC_IOSTATS_VERS); + seq_printf(seq, "p/v: %u/%u (%s)\n", + clnt->cl_prog, clnt->cl_vers, clnt->cl_protname); + + if (xprt) + xprt->ops->print_stats(xprt, seq); + + seq_printf(seq, "\tper-op statistics\n"); + for (op = 0; op < maxproc; op++) { + struct rpc_iostats *metrics = &stats[op]; + _print_name(seq, op, clnt->cl_procinfo); + seq_printf(seq, "%lu %lu %lu %Lu %Lu %Lu %Lu %Lu\n", + metrics->om_ops, + metrics->om_ntrans, + metrics->om_timeouts, + metrics->om_bytes_sent, + metrics->om_bytes_recv, + metrics->om_queue * MILLISECS_PER_JIFFY, + metrics->om_rtt * MILLISECS_PER_JIFFY, + metrics->om_execute * MILLISECS_PER_JIFFY); + } +} +EXPORT_SYMBOL_GPL(rpc_print_iostats); + +/* + * Register/unregister RPC proc files + */ +static inline struct proc_dir_entry * +do_register(const char *name, void *data, const struct file_operations *fops) +{ + rpc_proc_init(); + dprintk("RPC: registering /proc/net/rpc/%s\n", name); + + return proc_create_data(name, 0, proc_net_rpc, fops, data); +} + +struct proc_dir_entry * +rpc_proc_register(struct rpc_stat *statp) +{ + return do_register(statp->program->name, statp, &rpc_proc_fops); +} +EXPORT_SYMBOL_GPL(rpc_proc_register); + +void +rpc_proc_unregister(const char *name) +{ + remove_proc_entry(name, proc_net_rpc); +} +EXPORT_SYMBOL_GPL(rpc_proc_unregister); + +struct proc_dir_entry * +svc_proc_register(struct svc_stat *statp, const struct file_operations *fops) +{ + return do_register(statp->program->pg_name, statp, fops); +} +EXPORT_SYMBOL(svc_proc_register); + +void +svc_proc_unregister(const char *name) +{ + remove_proc_entry(name, proc_net_rpc); +} +EXPORT_SYMBOL(svc_proc_unregister); + +void +rpc_proc_init(void) +{ + dprintk("RPC: registering /proc/net/rpc\n"); + if (!proc_net_rpc) { + struct proc_dir_entry *ent; + ent = proc_mkdir("rpc", init_net.proc_net); + if (ent) { + ent->owner = THIS_MODULE; + proc_net_rpc = ent; + } + } +} + +void +rpc_proc_exit(void) +{ + dprintk("RPC: unregistering /proc/net/rpc\n"); + if (proc_net_rpc) { + proc_net_rpc = NULL; + remove_proc_entry("rpc", init_net.proc_net); + } +} + diff --git a/net/sunrpc/sunrpc_syms.c b/net/sunrpc/sunrpc_syms.c new file mode 100644 index 0000000..843629f --- /dev/null +++ b/net/sunrpc/sunrpc_syms.c @@ -0,0 +1,72 @@ +/* + * linux/net/sunrpc/sunrpc_syms.c + * + * Symbols exported by the sunrpc module. + * + * Copyright (C) 1997 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/uio.h> +#include <linux/unistd.h> +#include <linux/init.h> + +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/svc.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/auth.h> +#include <linux/workqueue.h> +#include <linux/sunrpc/rpc_pipe_fs.h> +#include <linux/sunrpc/xprtsock.h> + +extern struct cache_detail ip_map_cache, unix_gid_cache; + +static int __init +init_sunrpc(void) +{ + int err = register_rpc_pipefs(); + if (err) + goto out; + err = rpc_init_mempool(); + if (err) { + unregister_rpc_pipefs(); + goto out; + } +#ifdef RPC_DEBUG + rpc_register_sysctl(); +#endif +#ifdef CONFIG_PROC_FS + rpc_proc_init(); +#endif + cache_register(&ip_map_cache); + cache_register(&unix_gid_cache); + svc_init_xprt_sock(); /* svc sock transport */ + init_socket_xprt(); /* clnt sock transport */ + rpcauth_init_module(); +out: + return err; +} + +static void __exit +cleanup_sunrpc(void) +{ + rpcauth_remove_module(); + cleanup_socket_xprt(); + svc_cleanup_xprt_sock(); + unregister_rpc_pipefs(); + rpc_destroy_mempool(); + cache_unregister(&ip_map_cache); + cache_unregister(&unix_gid_cache); +#ifdef RPC_DEBUG + rpc_unregister_sysctl(); +#endif +#ifdef CONFIG_PROC_FS + rpc_proc_exit(); +#endif +} +MODULE_LICENSE("GPL"); +module_init(init_sunrpc); +module_exit(cleanup_sunrpc); diff --git a/net/sunrpc/svc.c b/net/sunrpc/svc.c new file mode 100644 index 0000000..54c98d8 --- /dev/null +++ b/net/sunrpc/svc.c @@ -0,0 +1,1247 @@ +/* + * linux/net/sunrpc/svc.c + * + * High-level RPC service routines + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + * + * Multiple threads pools and NUMAisation + * Copyright (c) 2006 Silicon Graphics, Inc. + * by Greg Banks <gnb@melbourne.sgi.com> + */ + +#include <linux/linkage.h> +#include <linux/sched.h> +#include <linux/errno.h> +#include <linux/net.h> +#include <linux/in.h> +#include <linux/mm.h> +#include <linux/interrupt.h> +#include <linux/module.h> +#include <linux/kthread.h> + +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/clnt.h> + +#define RPCDBG_FACILITY RPCDBG_SVCDSP + +static void svc_unregister(const struct svc_serv *serv); + +#define svc_serv_is_pooled(serv) ((serv)->sv_function) + +/* + * Mode for mapping cpus to pools. + */ +enum { + SVC_POOL_AUTO = -1, /* choose one of the others */ + SVC_POOL_GLOBAL, /* no mapping, just a single global pool + * (legacy & UP mode) */ + SVC_POOL_PERCPU, /* one pool per cpu */ + SVC_POOL_PERNODE /* one pool per numa node */ +}; +#define SVC_POOL_DEFAULT SVC_POOL_GLOBAL + +/* + * Structure for mapping cpus to pools and vice versa. + * Setup once during sunrpc initialisation. + */ +static struct svc_pool_map { + int count; /* How many svc_servs use us */ + int mode; /* Note: int not enum to avoid + * warnings about "enumeration value + * not handled in switch" */ + unsigned int npools; + unsigned int *pool_to; /* maps pool id to cpu or node */ + unsigned int *to_pool; /* maps cpu or node to pool id */ +} svc_pool_map = { + .count = 0, + .mode = SVC_POOL_DEFAULT +}; +static DEFINE_MUTEX(svc_pool_map_mutex);/* protects svc_pool_map.count only */ + +static int +param_set_pool_mode(const char *val, struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + struct svc_pool_map *m = &svc_pool_map; + int err; + + mutex_lock(&svc_pool_map_mutex); + + err = -EBUSY; + if (m->count) + goto out; + + err = 0; + if (!strncmp(val, "auto", 4)) + *ip = SVC_POOL_AUTO; + else if (!strncmp(val, "global", 6)) + *ip = SVC_POOL_GLOBAL; + else if (!strncmp(val, "percpu", 6)) + *ip = SVC_POOL_PERCPU; + else if (!strncmp(val, "pernode", 7)) + *ip = SVC_POOL_PERNODE; + else + err = -EINVAL; + +out: + mutex_unlock(&svc_pool_map_mutex); + return err; +} + +static int +param_get_pool_mode(char *buf, struct kernel_param *kp) +{ + int *ip = (int *)kp->arg; + + switch (*ip) + { + case SVC_POOL_AUTO: + return strlcpy(buf, "auto", 20); + case SVC_POOL_GLOBAL: + return strlcpy(buf, "global", 20); + case SVC_POOL_PERCPU: + return strlcpy(buf, "percpu", 20); + case SVC_POOL_PERNODE: + return strlcpy(buf, "pernode", 20); + default: + return sprintf(buf, "%d", *ip); + } +} + +module_param_call(pool_mode, param_set_pool_mode, param_get_pool_mode, + &svc_pool_map.mode, 0644); + +/* + * Detect best pool mapping mode heuristically, + * according to the machine's topology. + */ +static int +svc_pool_map_choose_mode(void) +{ + unsigned int node; + + if (num_online_nodes() > 1) { + /* + * Actually have multiple NUMA nodes, + * so split pools on NUMA node boundaries + */ + return SVC_POOL_PERNODE; + } + + node = any_online_node(node_online_map); + if (nr_cpus_node(node) > 2) { + /* + * Non-trivial SMP, or CONFIG_NUMA on + * non-NUMA hardware, e.g. with a generic + * x86_64 kernel on Xeons. In this case we + * want to divide the pools on cpu boundaries. + */ + return SVC_POOL_PERCPU; + } + + /* default: one global pool */ + return SVC_POOL_GLOBAL; +} + +/* + * Allocate the to_pool[] and pool_to[] arrays. + * Returns 0 on success or an errno. + */ +static int +svc_pool_map_alloc_arrays(struct svc_pool_map *m, unsigned int maxpools) +{ + m->to_pool = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->to_pool) + goto fail; + m->pool_to = kcalloc(maxpools, sizeof(unsigned int), GFP_KERNEL); + if (!m->pool_to) + goto fail_free; + + return 0; + +fail_free: + kfree(m->to_pool); +fail: + return -ENOMEM; +} + +/* + * Initialise the pool map for SVC_POOL_PERCPU mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_percpu(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_cpu_ids; + unsigned int pidx = 0; + unsigned int cpu; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_online_cpu(cpu) { + BUG_ON(pidx > maxpools); + m->to_pool[cpu] = pidx; + m->pool_to[pidx] = cpu; + pidx++; + } + /* cpus brought online later all get mapped to pool0, sorry */ + + return pidx; +}; + + +/* + * Initialise the pool map for SVC_POOL_PERNODE mode. + * Returns number of pools or <0 on error. + */ +static int +svc_pool_map_init_pernode(struct svc_pool_map *m) +{ + unsigned int maxpools = nr_node_ids; + unsigned int pidx = 0; + unsigned int node; + int err; + + err = svc_pool_map_alloc_arrays(m, maxpools); + if (err) + return err; + + for_each_node_with_cpus(node) { + /* some architectures (e.g. SN2) have cpuless nodes */ + BUG_ON(pidx > maxpools); + m->to_pool[node] = pidx; + m->pool_to[pidx] = node; + pidx++; + } + /* nodes brought online later all get mapped to pool0, sorry */ + + return pidx; +} + + +/* + * Add a reference to the global map of cpus to pools (and + * vice versa). Initialise the map if we're the first user. + * Returns the number of pools. + */ +static unsigned int +svc_pool_map_get(void) +{ + struct svc_pool_map *m = &svc_pool_map; + int npools = -1; + + mutex_lock(&svc_pool_map_mutex); + + if (m->count++) { + mutex_unlock(&svc_pool_map_mutex); + return m->npools; + } + + if (m->mode == SVC_POOL_AUTO) + m->mode = svc_pool_map_choose_mode(); + + switch (m->mode) { + case SVC_POOL_PERCPU: + npools = svc_pool_map_init_percpu(m); + break; + case SVC_POOL_PERNODE: + npools = svc_pool_map_init_pernode(m); + break; + } + + if (npools < 0) { + /* default, or memory allocation failure */ + npools = 1; + m->mode = SVC_POOL_GLOBAL; + } + m->npools = npools; + + mutex_unlock(&svc_pool_map_mutex); + return m->npools; +} + + +/* + * Drop a reference to the global map of cpus to pools. + * When the last reference is dropped, the map data is + * freed; this allows the sysadmin to change the pool + * mode using the pool_mode module option without + * rebooting or re-loading sunrpc.ko. + */ +static void +svc_pool_map_put(void) +{ + struct svc_pool_map *m = &svc_pool_map; + + mutex_lock(&svc_pool_map_mutex); + + if (!--m->count) { + m->mode = SVC_POOL_DEFAULT; + kfree(m->to_pool); + kfree(m->pool_to); + m->npools = 0; + } + + mutex_unlock(&svc_pool_map_mutex); +} + + +/* + * Set the given thread's cpus_allowed mask so that it + * will only run on cpus in the given pool. + */ +static inline void +svc_pool_map_set_cpumask(struct task_struct *task, unsigned int pidx) +{ + struct svc_pool_map *m = &svc_pool_map; + unsigned int node = m->pool_to[pidx]; + + /* + * The caller checks for sv_nrpools > 1, which + * implies that we've been initialized. + */ + BUG_ON(m->count == 0); + + switch (m->mode) { + case SVC_POOL_PERCPU: + { + set_cpus_allowed_ptr(task, &cpumask_of_cpu(node)); + break; + } + case SVC_POOL_PERNODE: + { + node_to_cpumask_ptr(nodecpumask, node); + set_cpus_allowed_ptr(task, nodecpumask); + break; + } + } +} + +/* + * Use the mapping mode to choose a pool for a given CPU. + * Used when enqueueing an incoming RPC. Always returns + * a non-NULL pool pointer. + */ +struct svc_pool * +svc_pool_for_cpu(struct svc_serv *serv, int cpu) +{ + struct svc_pool_map *m = &svc_pool_map; + unsigned int pidx = 0; + + /* + * An uninitialised map happens in a pure client when + * lockd is brought up, so silently treat it the + * same as SVC_POOL_GLOBAL. + */ + if (svc_serv_is_pooled(serv)) { + switch (m->mode) { + case SVC_POOL_PERCPU: + pidx = m->to_pool[cpu]; + break; + case SVC_POOL_PERNODE: + pidx = m->to_pool[cpu_to_node(cpu)]; + break; + } + } + return &serv->sv_pools[pidx % serv->sv_nrpools]; +} + + +/* + * Create an RPC service + */ +static struct svc_serv * +__svc_create(struct svc_program *prog, unsigned int bufsize, int npools, + sa_family_t family, void (*shutdown)(struct svc_serv *serv)) +{ + struct svc_serv *serv; + unsigned int vers; + unsigned int xdrsize; + unsigned int i; + + if (!(serv = kzalloc(sizeof(*serv), GFP_KERNEL))) + return NULL; + serv->sv_family = family; + serv->sv_name = prog->pg_name; + serv->sv_program = prog; + serv->sv_nrthreads = 1; + serv->sv_stats = prog->pg_stats; + if (bufsize > RPCSVC_MAXPAYLOAD) + bufsize = RPCSVC_MAXPAYLOAD; + serv->sv_max_payload = bufsize? bufsize : 4096; + serv->sv_max_mesg = roundup(serv->sv_max_payload + PAGE_SIZE, PAGE_SIZE); + serv->sv_shutdown = shutdown; + xdrsize = 0; + while (prog) { + prog->pg_lovers = prog->pg_nvers-1; + for (vers=0; vers<prog->pg_nvers ; vers++) + if (prog->pg_vers[vers]) { + prog->pg_hivers = vers; + if (prog->pg_lovers > vers) + prog->pg_lovers = vers; + if (prog->pg_vers[vers]->vs_xdrsize > xdrsize) + xdrsize = prog->pg_vers[vers]->vs_xdrsize; + } + prog = prog->pg_next; + } + serv->sv_xdrsize = xdrsize; + INIT_LIST_HEAD(&serv->sv_tempsocks); + INIT_LIST_HEAD(&serv->sv_permsocks); + init_timer(&serv->sv_temptimer); + spin_lock_init(&serv->sv_lock); + + serv->sv_nrpools = npools; + serv->sv_pools = + kcalloc(serv->sv_nrpools, sizeof(struct svc_pool), + GFP_KERNEL); + if (!serv->sv_pools) { + kfree(serv); + return NULL; + } + + for (i = 0; i < serv->sv_nrpools; i++) { + struct svc_pool *pool = &serv->sv_pools[i]; + + dprintk("svc: initialising pool %u for %s\n", + i, serv->sv_name); + + pool->sp_id = i; + INIT_LIST_HEAD(&pool->sp_threads); + INIT_LIST_HEAD(&pool->sp_sockets); + INIT_LIST_HEAD(&pool->sp_all_threads); + spin_lock_init(&pool->sp_lock); + } + + /* Remove any stale portmap registrations */ + svc_unregister(serv); + + return serv; +} + +struct svc_serv * +svc_create(struct svc_program *prog, unsigned int bufsize, + sa_family_t family, void (*shutdown)(struct svc_serv *serv)) +{ + return __svc_create(prog, bufsize, /*npools*/1, family, shutdown); +} +EXPORT_SYMBOL(svc_create); + +struct svc_serv * +svc_create_pooled(struct svc_program *prog, unsigned int bufsize, + sa_family_t family, void (*shutdown)(struct svc_serv *serv), + svc_thread_fn func, struct module *mod) +{ + struct svc_serv *serv; + unsigned int npools = svc_pool_map_get(); + + serv = __svc_create(prog, bufsize, npools, family, shutdown); + + if (serv != NULL) { + serv->sv_function = func; + serv->sv_module = mod; + } + + return serv; +} +EXPORT_SYMBOL(svc_create_pooled); + +/* + * Destroy an RPC service. Should be called with appropriate locking to + * protect the sv_nrthreads, sv_permsocks and sv_tempsocks. + */ +void +svc_destroy(struct svc_serv *serv) +{ + dprintk("svc: svc_destroy(%s, %d)\n", + serv->sv_program->pg_name, + serv->sv_nrthreads); + + if (serv->sv_nrthreads) { + if (--(serv->sv_nrthreads) != 0) { + svc_sock_update_bufs(serv); + return; + } + } else + printk("svc_destroy: no threads for serv=%p!\n", serv); + + del_timer_sync(&serv->sv_temptimer); + + svc_close_all(&serv->sv_tempsocks); + + if (serv->sv_shutdown) + serv->sv_shutdown(serv); + + svc_close_all(&serv->sv_permsocks); + + BUG_ON(!list_empty(&serv->sv_permsocks)); + BUG_ON(!list_empty(&serv->sv_tempsocks)); + + cache_clean_deferred(serv); + + if (svc_serv_is_pooled(serv)) + svc_pool_map_put(); + + svc_unregister(serv); + kfree(serv->sv_pools); + kfree(serv); +} +EXPORT_SYMBOL(svc_destroy); + +/* + * Allocate an RPC server's buffer space. + * We allocate pages and place them in rq_argpages. + */ +static int +svc_init_buffer(struct svc_rqst *rqstp, unsigned int size) +{ + unsigned int pages, arghi; + + pages = size / PAGE_SIZE + 1; /* extra page as we hold both request and reply. + * We assume one is at most one page + */ + arghi = 0; + BUG_ON(pages > RPCSVC_MAXPAGES); + while (pages) { + struct page *p = alloc_page(GFP_KERNEL); + if (!p) + break; + rqstp->rq_pages[arghi++] = p; + pages--; + } + return pages == 0; +} + +/* + * Release an RPC server buffer + */ +static void +svc_release_buffer(struct svc_rqst *rqstp) +{ + unsigned int i; + + for (i = 0; i < ARRAY_SIZE(rqstp->rq_pages); i++) + if (rqstp->rq_pages[i]) + put_page(rqstp->rq_pages[i]); +} + +struct svc_rqst * +svc_prepare_thread(struct svc_serv *serv, struct svc_pool *pool) +{ + struct svc_rqst *rqstp; + + rqstp = kzalloc(sizeof(*rqstp), GFP_KERNEL); + if (!rqstp) + goto out_enomem; + + init_waitqueue_head(&rqstp->rq_wait); + + serv->sv_nrthreads++; + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads++; + list_add(&rqstp->rq_all, &pool->sp_all_threads); + spin_unlock_bh(&pool->sp_lock); + rqstp->rq_server = serv; + rqstp->rq_pool = pool; + + rqstp->rq_argp = kmalloc(serv->sv_xdrsize, GFP_KERNEL); + if (!rqstp->rq_argp) + goto out_thread; + + rqstp->rq_resp = kmalloc(serv->sv_xdrsize, GFP_KERNEL); + if (!rqstp->rq_resp) + goto out_thread; + + if (!svc_init_buffer(rqstp, serv->sv_max_mesg)) + goto out_thread; + + return rqstp; +out_thread: + svc_exit_thread(rqstp); +out_enomem: + return ERR_PTR(-ENOMEM); +} +EXPORT_SYMBOL(svc_prepare_thread); + +/* + * Choose a pool in which to create a new thread, for svc_set_num_threads + */ +static inline struct svc_pool * +choose_pool(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + if (pool != NULL) + return pool; + + return &serv->sv_pools[(*state)++ % serv->sv_nrpools]; +} + +/* + * Choose a thread to kill, for svc_set_num_threads + */ +static inline struct task_struct * +choose_victim(struct svc_serv *serv, struct svc_pool *pool, unsigned int *state) +{ + unsigned int i; + struct task_struct *task = NULL; + + if (pool != NULL) { + spin_lock_bh(&pool->sp_lock); + } else { + /* choose a pool in round-robin fashion */ + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[--(*state) % serv->sv_nrpools]; + spin_lock_bh(&pool->sp_lock); + if (!list_empty(&pool->sp_all_threads)) + goto found_pool; + spin_unlock_bh(&pool->sp_lock); + } + return NULL; + } + +found_pool: + if (!list_empty(&pool->sp_all_threads)) { + struct svc_rqst *rqstp; + + /* + * Remove from the pool->sp_all_threads list + * so we don't try to kill it again. + */ + rqstp = list_entry(pool->sp_all_threads.next, struct svc_rqst, rq_all); + list_del_init(&rqstp->rq_all); + task = rqstp->rq_task; + } + spin_unlock_bh(&pool->sp_lock); + + return task; +} + +/* + * Create or destroy enough new threads to make the number + * of threads the given number. If `pool' is non-NULL, applies + * only to threads in that pool, otherwise round-robins between + * all pools. Must be called with a svc_get() reference and + * the BKL or another lock to protect access to svc_serv fields. + * + * Destroying threads relies on the service threads filling in + * rqstp->rq_task, which only the nfs ones do. Assumes the serv + * has been created using svc_create_pooled(). + * + * Based on code that used to be in nfsd_svc() but tweaked + * to be pool-aware. + */ +int +svc_set_num_threads(struct svc_serv *serv, struct svc_pool *pool, int nrservs) +{ + struct svc_rqst *rqstp; + struct task_struct *task; + struct svc_pool *chosen_pool; + int error = 0; + unsigned int state = serv->sv_nrthreads-1; + + if (pool == NULL) { + /* The -1 assumes caller has done a svc_get() */ + nrservs -= (serv->sv_nrthreads-1); + } else { + spin_lock_bh(&pool->sp_lock); + nrservs -= pool->sp_nrthreads; + spin_unlock_bh(&pool->sp_lock); + } + + /* create new threads */ + while (nrservs > 0) { + nrservs--; + chosen_pool = choose_pool(serv, pool, &state); + + rqstp = svc_prepare_thread(serv, chosen_pool); + if (IS_ERR(rqstp)) { + error = PTR_ERR(rqstp); + break; + } + + __module_get(serv->sv_module); + task = kthread_create(serv->sv_function, rqstp, serv->sv_name); + if (IS_ERR(task)) { + error = PTR_ERR(task); + module_put(serv->sv_module); + svc_exit_thread(rqstp); + break; + } + + rqstp->rq_task = task; + if (serv->sv_nrpools > 1) + svc_pool_map_set_cpumask(task, chosen_pool->sp_id); + + svc_sock_update_bufs(serv); + wake_up_process(task); + } + /* destroy old threads */ + while (nrservs < 0 && + (task = choose_victim(serv, pool, &state)) != NULL) { + send_sig(SIGINT, task, 1); + nrservs++; + } + + return error; +} +EXPORT_SYMBOL(svc_set_num_threads); + +/* + * Called from a server thread as it's exiting. Caller must hold the BKL or + * the "service mutex", whichever is appropriate for the service. + */ +void +svc_exit_thread(struct svc_rqst *rqstp) +{ + struct svc_serv *serv = rqstp->rq_server; + struct svc_pool *pool = rqstp->rq_pool; + + svc_release_buffer(rqstp); + kfree(rqstp->rq_resp); + kfree(rqstp->rq_argp); + kfree(rqstp->rq_auth_data); + + spin_lock_bh(&pool->sp_lock); + pool->sp_nrthreads--; + list_del(&rqstp->rq_all); + spin_unlock_bh(&pool->sp_lock); + + kfree(rqstp); + + /* Release the server */ + if (serv) + svc_destroy(serv); +} +EXPORT_SYMBOL(svc_exit_thread); + +#ifdef CONFIG_SUNRPC_REGISTER_V4 + +/* + * Register an "inet" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register4(const u32 program, const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; + char *netid; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP; + break; + default: + return -EPROTONOSUPPORT; + } + + return rpcb_v4_register(program, version, + (struct sockaddr *)&sin, netid); +} + +/* + * Register an "inet6" protocol family netid with the local + * rpcbind daemon via an rpcbind v4 SET request. + * + * No netconfig infrastructure is available in the kernel, so + * we map IP_ protocol numbers to netids by hand. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_rpcb_register6(const u32 program, const u32 version, + const unsigned short protocol, + const unsigned short port) +{ + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; + char *netid; + + switch (protocol) { + case IPPROTO_UDP: + netid = RPCBIND_NETID_UDP6; + break; + case IPPROTO_TCP: + netid = RPCBIND_NETID_TCP6; + break; + default: + return -EPROTONOSUPPORT; + } + + return rpcb_v4_register(program, version, + (struct sockaddr *)&sin6, netid); +} + +/* + * Register a kernel RPC service via rpcbind version 4. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_register(const u32 program, const u32 version, + const sa_family_t family, + const unsigned short protocol, + const unsigned short port) +{ + int error; + + switch (family) { + case AF_INET: + return __svc_rpcb_register4(program, version, + protocol, port); + case AF_INET6: + error = __svc_rpcb_register6(program, version, + protocol, port); + if (error < 0) + return error; + + /* + * Work around bug in some versions of Linux rpcbind + * which don't allow registration of both inet and + * inet6 netids. + * + * Error return ignored for now. + */ + __svc_rpcb_register4(program, version, + protocol, port); + return 0; + } + + return -EAFNOSUPPORT; +} + +#else /* CONFIG_SUNRPC_REGISTER_V4 */ + +/* + * Register a kernel RPC service via rpcbind version 2. + * + * Returns zero on success; a negative errno value is returned + * if any error occurs. + */ +static int __svc_register(const u32 program, const u32 version, + sa_family_t family, + const unsigned short protocol, + const unsigned short port) +{ + if (family != AF_INET) + return -EAFNOSUPPORT; + + return rpcb_register(program, version, protocol, port); +} + +#endif /* CONFIG_SUNRPC_REGISTER_V4 */ + +/** + * svc_register - register an RPC service with the local portmapper + * @serv: svc_serv struct for the service to register + * @proto: transport protocol number to advertise + * @port: port to advertise + * + * Service is registered for any address in serv's address family + */ +int svc_register(const struct svc_serv *serv, const unsigned short proto, + const unsigned short port) +{ + struct svc_program *progp; + unsigned int i; + int error = 0; + + BUG_ON(proto == 0 && port == 0); + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + + dprintk("svc: svc_register(%sv%d, %s, %u, %u)%s\n", + progp->pg_name, + i, + proto == IPPROTO_UDP? "udp" : "tcp", + port, + serv->sv_family, + progp->pg_vers[i]->vs_hidden? + " (but not telling portmap)" : ""); + + if (progp->pg_vers[i]->vs_hidden) + continue; + + error = __svc_register(progp->pg_prog, i, + serv->sv_family, proto, port); + if (error < 0) + break; + } + } + + return error; +} + +#ifdef CONFIG_SUNRPC_REGISTER_V4 + +static void __svc_unregister(const u32 program, const u32 version, + const char *progname) +{ + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = 0, + }; + int error; + + error = rpcb_v4_register(program, version, + (struct sockaddr *)&sin6, ""); + dprintk("svc: %s(%sv%u), error %d\n", + __func__, progname, version, error); +} + +#else /* CONFIG_SUNRPC_REGISTER_V4 */ + +static void __svc_unregister(const u32 program, const u32 version, + const char *progname) +{ + int error; + + error = rpcb_register(program, version, 0, 0); + dprintk("svc: %s(%sv%u), error %d\n", + __func__, progname, version, error); +} + +#endif /* CONFIG_SUNRPC_REGISTER_V4 */ + +/* + * All netids, bind addresses and ports registered for [program, version] + * are removed from the local rpcbind database (if the service is not + * hidden) to make way for a new instance of the service. + * + * The result of unregistration is reported via dprintk for those who want + * verification of the result, but is otherwise not important. + */ +static void svc_unregister(const struct svc_serv *serv) +{ + struct svc_program *progp; + unsigned long flags; + unsigned int i; + + clear_thread_flag(TIF_SIGPENDING); + + for (progp = serv->sv_program; progp; progp = progp->pg_next) { + for (i = 0; i < progp->pg_nvers; i++) { + if (progp->pg_vers[i] == NULL) + continue; + if (progp->pg_vers[i]->vs_hidden) + continue; + + __svc_unregister(progp->pg_prog, i, progp->pg_name); + } + } + + spin_lock_irqsave(¤t->sighand->siglock, flags); + recalc_sigpending(); + spin_unlock_irqrestore(¤t->sighand->siglock, flags); +} + +/* + * Printk the given error with the address of the client that caused it. + */ +static int +__attribute__ ((format (printf, 2, 3))) +svc_printk(struct svc_rqst *rqstp, const char *fmt, ...) +{ + va_list args; + int r; + char buf[RPC_MAX_ADDRBUFLEN]; + + if (!net_ratelimit()) + return 0; + + printk(KERN_WARNING "svc: %s: ", + svc_print_addr(rqstp, buf, sizeof(buf))); + + va_start(args, fmt); + r = vprintk(fmt, args); + va_end(args); + + return r; +} + +/* + * Process the RPC request. + */ +int +svc_process(struct svc_rqst *rqstp) +{ + struct svc_program *progp; + struct svc_version *versp = NULL; /* compiler food */ + struct svc_procedure *procp = NULL; + struct kvec * argv = &rqstp->rq_arg.head[0]; + struct kvec * resv = &rqstp->rq_res.head[0]; + struct svc_serv *serv = rqstp->rq_server; + kxdrproc_t xdr; + __be32 *statp; + u32 dir, prog, vers, proc; + __be32 auth_stat, rpc_stat; + int auth_res; + __be32 *reply_statp; + + rpc_stat = rpc_success; + + if (argv->iov_len < 6*4) + goto err_short_len; + + /* setup response xdr_buf. + * Initially it has just one page + */ + rqstp->rq_resused = 1; + resv->iov_base = page_address(rqstp->rq_respages[0]); + resv->iov_len = 0; + rqstp->rq_res.pages = rqstp->rq_respages + 1; + rqstp->rq_res.len = 0; + rqstp->rq_res.page_base = 0; + rqstp->rq_res.page_len = 0; + rqstp->rq_res.buflen = PAGE_SIZE; + rqstp->rq_res.tail[0].iov_base = NULL; + rqstp->rq_res.tail[0].iov_len = 0; + /* Will be turned off only in gss privacy case: */ + rqstp->rq_splice_ok = 1; + + /* Setup reply header */ + rqstp->rq_xprt->xpt_ops->xpo_prep_reply_hdr(rqstp); + + rqstp->rq_xid = svc_getu32(argv); + svc_putu32(resv, rqstp->rq_xid); + + dir = svc_getnl(argv); + vers = svc_getnl(argv); + + /* First words of reply: */ + svc_putnl(resv, 1); /* REPLY */ + + if (dir != 0) /* direction != CALL */ + goto err_bad_dir; + if (vers != 2) /* RPC version number */ + goto err_bad_rpc; + + /* Save position in case we later decide to reject: */ + reply_statp = resv->iov_base + resv->iov_len; + + svc_putnl(resv, 0); /* ACCEPT */ + + rqstp->rq_prog = prog = svc_getnl(argv); /* program number */ + rqstp->rq_vers = vers = svc_getnl(argv); /* version number */ + rqstp->rq_proc = proc = svc_getnl(argv); /* procedure number */ + + progp = serv->sv_program; + + for (progp = serv->sv_program; progp; progp = progp->pg_next) + if (prog == progp->pg_prog) + break; + + /* + * Decode auth data, and add verifier to reply buffer. + * We do this before anything else in order to get a decent + * auth verifier. + */ + auth_res = svc_authenticate(rqstp, &auth_stat); + /* Also give the program a chance to reject this call: */ + if (auth_res == SVC_OK && progp) { + auth_stat = rpc_autherr_badcred; + auth_res = progp->pg_authenticate(rqstp); + } + switch (auth_res) { + case SVC_OK: + break; + case SVC_GARBAGE: + goto err_garbage; + case SVC_SYSERR: + rpc_stat = rpc_system_err; + goto err_bad; + case SVC_DENIED: + goto err_bad_auth; + case SVC_DROP: + goto dropit; + case SVC_COMPLETE: + goto sendit; + } + + if (progp == NULL) + goto err_bad_prog; + + if (vers >= progp->pg_nvers || + !(versp = progp->pg_vers[vers])) + goto err_bad_vers; + + procp = versp->vs_proc + proc; + if (proc >= versp->vs_nproc || !procp->pc_func) + goto err_bad_proc; + rqstp->rq_server = serv; + rqstp->rq_procinfo = procp; + + /* Syntactic check complete */ + serv->sv_stats->rpccnt++; + + /* Build the reply header. */ + statp = resv->iov_base +resv->iov_len; + svc_putnl(resv, RPC_SUCCESS); + + /* Bump per-procedure stats counter */ + procp->pc_count++; + + /* Initialize storage for argp and resp */ + memset(rqstp->rq_argp, 0, procp->pc_argsize); + memset(rqstp->rq_resp, 0, procp->pc_ressize); + + /* un-reserve some of the out-queue now that we have a + * better idea of reply size + */ + if (procp->pc_xdrressize) + svc_reserve_auth(rqstp, procp->pc_xdrressize<<2); + + /* Call the function that processes the request. */ + if (!versp->vs_dispatch) { + /* Decode arguments */ + xdr = procp->pc_decode; + if (xdr && !xdr(rqstp, argv->iov_base, rqstp->rq_argp)) + goto err_garbage; + + *statp = procp->pc_func(rqstp, rqstp->rq_argp, rqstp->rq_resp); + + /* Encode reply */ + if (*statp == rpc_drop_reply) { + if (procp->pc_release) + procp->pc_release(rqstp, NULL, rqstp->rq_resp); + goto dropit; + } + if (*statp == rpc_success && (xdr = procp->pc_encode) + && !xdr(rqstp, resv->iov_base+resv->iov_len, rqstp->rq_resp)) { + dprintk("svc: failed to encode reply\n"); + /* serv->sv_stats->rpcsystemerr++; */ + *statp = rpc_system_err; + } + } else { + dprintk("svc: calling dispatcher\n"); + if (!versp->vs_dispatch(rqstp, statp)) { + /* Release reply info */ + if (procp->pc_release) + procp->pc_release(rqstp, NULL, rqstp->rq_resp); + goto dropit; + } + } + + /* Check RPC status result */ + if (*statp != rpc_success) + resv->iov_len = ((void*)statp) - resv->iov_base + 4; + + /* Release reply info */ + if (procp->pc_release) + procp->pc_release(rqstp, NULL, rqstp->rq_resp); + + if (procp->pc_encode == NULL) + goto dropit; + + sendit: + if (svc_authorise(rqstp)) + goto dropit; + return svc_send(rqstp); + + dropit: + svc_authorise(rqstp); /* doesn't hurt to call this twice */ + dprintk("svc: svc_process dropit\n"); + svc_drop(rqstp); + return 0; + +err_short_len: + svc_printk(rqstp, "short len %Zd, dropping request\n", + argv->iov_len); + + goto dropit; /* drop request */ + +err_bad_dir: + svc_printk(rqstp, "bad direction %d, dropping request\n", dir); + + serv->sv_stats->rpcbadfmt++; + goto dropit; /* drop request */ + +err_bad_rpc: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 0); /* RPC_MISMATCH */ + svc_putnl(resv, 2); /* Only RPCv2 supported */ + svc_putnl(resv, 2); + goto sendit; + +err_bad_auth: + dprintk("svc: authentication failed (%d)\n", ntohl(auth_stat)); + serv->sv_stats->rpcbadauth++; + /* Restore write pointer to location of accept status: */ + xdr_ressize_check(rqstp, reply_statp); + svc_putnl(resv, 1); /* REJECT */ + svc_putnl(resv, 1); /* AUTH_ERROR */ + svc_putnl(resv, ntohl(auth_stat)); /* status */ + goto sendit; + +err_bad_prog: + dprintk("svc: unknown program %d\n", prog); + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_UNAVAIL); + goto sendit; + +err_bad_vers: + svc_printk(rqstp, "unknown version (%d for prog %d, %s)\n", + vers, prog, progp->pg_name); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROG_MISMATCH); + svc_putnl(resv, progp->pg_lovers); + svc_putnl(resv, progp->pg_hivers); + goto sendit; + +err_bad_proc: + svc_printk(rqstp, "unknown procedure (%d)\n", proc); + + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, RPC_PROC_UNAVAIL); + goto sendit; + +err_garbage: + svc_printk(rqstp, "failed to decode args\n"); + + rpc_stat = rpc_garbage_args; +err_bad: + serv->sv_stats->rpcbadfmt++; + svc_putnl(resv, ntohl(rpc_stat)); + goto sendit; +} +EXPORT_SYMBOL(svc_process); + +/* + * Return (transport-specific) limit on the rpc payload. + */ +u32 svc_max_payload(const struct svc_rqst *rqstp) +{ + u32 max = rqstp->rq_xprt->xpt_class->xcl_max_payload; + + if (rqstp->rq_server->sv_max_payload < max) + max = rqstp->rq_server->sv_max_payload; + return max; +} +EXPORT_SYMBOL_GPL(svc_max_payload); diff --git a/net/sunrpc/svc_xprt.c b/net/sunrpc/svc_xprt.c new file mode 100644 index 0000000..bf5b5cd --- /dev/null +++ b/net/sunrpc/svc_xprt.c @@ -0,0 +1,1081 @@ +/* + * linux/net/sunrpc/svc_xprt.c + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sched.h> +#include <linux/errno.h> +#include <linux/freezer.h> +#include <linux/kthread.h> +#include <net/sock.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svc_xprt.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt); +static int svc_deferred_recv(struct svc_rqst *rqstp); +static struct cache_deferred_req *svc_defer(struct cache_req *req); +static void svc_age_temp_xprts(unsigned long closure); + +/* apparently the "standard" is that clients close + * idle connections after 5 minutes, servers after + * 6 minutes + * http://www.connectathon.org/talks96/nfstcp.pdf + */ +static int svc_conn_age_period = 6*60; + +/* List of registered transport classes */ +static DEFINE_SPINLOCK(svc_xprt_class_lock); +static LIST_HEAD(svc_xprt_class_list); + +/* SMP locking strategy: + * + * svc_pool->sp_lock protects most of the fields of that pool. + * svc_serv->sv_lock protects sv_tempsocks, sv_permsocks, sv_tmpcnt. + * when both need to be taken (rare), svc_serv->sv_lock is first. + * BKL protects svc_serv->sv_nrthread. + * svc_sock->sk_lock protects the svc_sock->sk_deferred list + * and the ->sk_info_authunix cache. + * + * The XPT_BUSY bit in xprt->xpt_flags prevents a transport being + * enqueued multiply. During normal transport processing this bit + * is set by svc_xprt_enqueue and cleared by svc_xprt_received. + * Providers should not manipulate this bit directly. + * + * Some flags can be set to certain values at any time + * providing that certain rules are followed: + * + * XPT_CONN, XPT_DATA: + * - Can be set or cleared at any time. + * - After a set, svc_xprt_enqueue must be called to enqueue + * the transport for processing. + * - After a clear, the transport must be read/accepted. + * If this succeeds, it must be set again. + * XPT_CLOSE: + * - Can set at any time. It is never cleared. + * XPT_DEAD: + * - Can only be set while XPT_BUSY is held which ensures + * that no other thread will be using the transport or will + * try to set XPT_DEAD. + */ + +int svc_reg_xprt_class(struct svc_xprt_class *xcl) +{ + struct svc_xprt_class *cl; + int res = -EEXIST; + + dprintk("svc: Adding svc transport class '%s'\n", xcl->xcl_name); + + INIT_LIST_HEAD(&xcl->xcl_list); + spin_lock(&svc_xprt_class_lock); + /* Make sure there isn't already a class with the same name */ + list_for_each_entry(cl, &svc_xprt_class_list, xcl_list) { + if (strcmp(xcl->xcl_name, cl->xcl_name) == 0) + goto out; + } + list_add_tail(&xcl->xcl_list, &svc_xprt_class_list); + res = 0; +out: + spin_unlock(&svc_xprt_class_lock); + return res; +} +EXPORT_SYMBOL_GPL(svc_reg_xprt_class); + +void svc_unreg_xprt_class(struct svc_xprt_class *xcl) +{ + dprintk("svc: Removing svc transport class '%s'\n", xcl->xcl_name); + spin_lock(&svc_xprt_class_lock); + list_del_init(&xcl->xcl_list); + spin_unlock(&svc_xprt_class_lock); +} +EXPORT_SYMBOL_GPL(svc_unreg_xprt_class); + +/* + * Format the transport list for printing + */ +int svc_print_xprts(char *buf, int maxlen) +{ + struct list_head *le; + char tmpstr[80]; + int len = 0; + buf[0] = '\0'; + + spin_lock(&svc_xprt_class_lock); + list_for_each(le, &svc_xprt_class_list) { + int slen; + struct svc_xprt_class *xcl = + list_entry(le, struct svc_xprt_class, xcl_list); + + sprintf(tmpstr, "%s %d\n", xcl->xcl_name, xcl->xcl_max_payload); + slen = strlen(tmpstr); + if (len + slen > maxlen) + break; + len += slen; + strcat(buf, tmpstr); + } + spin_unlock(&svc_xprt_class_lock); + + return len; +} + +static void svc_xprt_free(struct kref *kref) +{ + struct svc_xprt *xprt = + container_of(kref, struct svc_xprt, xpt_ref); + struct module *owner = xprt->xpt_class->xcl_owner; + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags) + && xprt->xpt_auth_cache != NULL) + svcauth_unix_info_release(xprt->xpt_auth_cache); + xprt->xpt_ops->xpo_free(xprt); + module_put(owner); +} + +void svc_xprt_put(struct svc_xprt *xprt) +{ + kref_put(&xprt->xpt_ref, svc_xprt_free); +} +EXPORT_SYMBOL_GPL(svc_xprt_put); + +/* + * Called by transport drivers to initialize the transport independent + * portion of the transport instance. + */ +void svc_xprt_init(struct svc_xprt_class *xcl, struct svc_xprt *xprt, + struct svc_serv *serv) +{ + memset(xprt, 0, sizeof(*xprt)); + xprt->xpt_class = xcl; + xprt->xpt_ops = xcl->xcl_ops; + kref_init(&xprt->xpt_ref); + xprt->xpt_server = serv; + INIT_LIST_HEAD(&xprt->xpt_list); + INIT_LIST_HEAD(&xprt->xpt_ready); + INIT_LIST_HEAD(&xprt->xpt_deferred); + mutex_init(&xprt->xpt_mutex); + spin_lock_init(&xprt->xpt_lock); + set_bit(XPT_BUSY, &xprt->xpt_flags); +} +EXPORT_SYMBOL_GPL(svc_xprt_init); + +static struct svc_xprt *__svc_xpo_create(struct svc_xprt_class *xcl, + struct svc_serv *serv, + unsigned short port, int flags) +{ + struct sockaddr_in sin = { + .sin_family = AF_INET, + .sin_addr.s_addr = htonl(INADDR_ANY), + .sin_port = htons(port), + }; + struct sockaddr_in6 sin6 = { + .sin6_family = AF_INET6, + .sin6_addr = IN6ADDR_ANY_INIT, + .sin6_port = htons(port), + }; + struct sockaddr *sap; + size_t len; + + switch (serv->sv_family) { + case AF_INET: + sap = (struct sockaddr *)&sin; + len = sizeof(sin); + break; + case AF_INET6: + sap = (struct sockaddr *)&sin6; + len = sizeof(sin6); + break; + default: + return ERR_PTR(-EAFNOSUPPORT); + } + + return xcl->xcl_ops->xpo_create(serv, sap, len, flags); +} + +int svc_create_xprt(struct svc_serv *serv, char *xprt_name, unsigned short port, + int flags) +{ + struct svc_xprt_class *xcl; + + dprintk("svc: creating transport %s[%d]\n", xprt_name, port); + spin_lock(&svc_xprt_class_lock); + list_for_each_entry(xcl, &svc_xprt_class_list, xcl_list) { + struct svc_xprt *newxprt; + + if (strcmp(xprt_name, xcl->xcl_name)) + continue; + + if (!try_module_get(xcl->xcl_owner)) + goto err; + + spin_unlock(&svc_xprt_class_lock); + newxprt = __svc_xpo_create(xcl, serv, port, flags); + if (IS_ERR(newxprt)) { + module_put(xcl->xcl_owner); + return PTR_ERR(newxprt); + } + + clear_bit(XPT_TEMP, &newxprt->xpt_flags); + spin_lock_bh(&serv->sv_lock); + list_add(&newxprt->xpt_list, &serv->sv_permsocks); + spin_unlock_bh(&serv->sv_lock); + clear_bit(XPT_BUSY, &newxprt->xpt_flags); + return svc_xprt_local_port(newxprt); + } + err: + spin_unlock(&svc_xprt_class_lock); + dprintk("svc: transport %s not found\n", xprt_name); + return -ENOENT; +} +EXPORT_SYMBOL_GPL(svc_create_xprt); + +/* + * Copy the local and remote xprt addresses to the rqstp structure + */ +void svc_xprt_copy_addrs(struct svc_rqst *rqstp, struct svc_xprt *xprt) +{ + struct sockaddr *sin; + + memcpy(&rqstp->rq_addr, &xprt->xpt_remote, xprt->xpt_remotelen); + rqstp->rq_addrlen = xprt->xpt_remotelen; + + /* + * Destination address in request is needed for binding the + * source address in RPC replies/callbacks later. + */ + sin = (struct sockaddr *)&xprt->xpt_local; + switch (sin->sa_family) { + case AF_INET: + rqstp->rq_daddr.addr = ((struct sockaddr_in *)sin)->sin_addr; + break; + case AF_INET6: + rqstp->rq_daddr.addr6 = ((struct sockaddr_in6 *)sin)->sin6_addr; + break; + } +} +EXPORT_SYMBOL_GPL(svc_xprt_copy_addrs); + +/** + * svc_print_addr - Format rq_addr field for printing + * @rqstp: svc_rqst struct containing address to print + * @buf: target buffer for formatted address + * @len: length of target buffer + * + */ +char *svc_print_addr(struct svc_rqst *rqstp, char *buf, size_t len) +{ + return __svc_print_addr(svc_addr(rqstp), buf, len); +} +EXPORT_SYMBOL_GPL(svc_print_addr); + +/* + * Queue up an idle server thread. Must have pool->sp_lock held. + * Note: this is really a stack rather than a queue, so that we only + * use as many different threads as we need, and the rest don't pollute + * the cache. + */ +static void svc_thread_enqueue(struct svc_pool *pool, struct svc_rqst *rqstp) +{ + list_add(&rqstp->rq_list, &pool->sp_threads); +} + +/* + * Dequeue an nfsd thread. Must have pool->sp_lock held. + */ +static void svc_thread_dequeue(struct svc_pool *pool, struct svc_rqst *rqstp) +{ + list_del(&rqstp->rq_list); +} + +/* + * Queue up a transport with data pending. If there are idle nfsd + * processes, wake 'em up. + * + */ +void svc_xprt_enqueue(struct svc_xprt *xprt) +{ + struct svc_serv *serv = xprt->xpt_server; + struct svc_pool *pool; + struct svc_rqst *rqstp; + int cpu; + + if (!(xprt->xpt_flags & + ((1<<XPT_CONN)|(1<<XPT_DATA)|(1<<XPT_CLOSE)|(1<<XPT_DEFERRED)))) + return; + + cpu = get_cpu(); + pool = svc_pool_for_cpu(xprt->xpt_server, cpu); + put_cpu(); + + spin_lock_bh(&pool->sp_lock); + + if (!list_empty(&pool->sp_threads) && + !list_empty(&pool->sp_sockets)) + printk(KERN_ERR + "svc_xprt_enqueue: " + "threads and transports both waiting??\n"); + + if (test_bit(XPT_DEAD, &xprt->xpt_flags)) { + /* Don't enqueue dead transports */ + dprintk("svc: transport %p is dead, not enqueued\n", xprt); + goto out_unlock; + } + + /* Mark transport as busy. It will remain in this state until + * the provider calls svc_xprt_received. We update XPT_BUSY + * atomically because it also guards against trying to enqueue + * the transport twice. + */ + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) { + /* Don't enqueue transport while already enqueued */ + dprintk("svc: transport %p busy, not enqueued\n", xprt); + goto out_unlock; + } + BUG_ON(xprt->xpt_pool != NULL); + xprt->xpt_pool = pool; + + /* Handle pending connection */ + if (test_bit(XPT_CONN, &xprt->xpt_flags)) + goto process; + + /* Handle close in-progress */ + if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) + goto process; + + /* Check if we have space to reply to a request */ + if (!xprt->xpt_ops->xpo_has_wspace(xprt)) { + /* Don't enqueue while not enough space for reply */ + dprintk("svc: no write space, transport %p not enqueued\n", + xprt); + xprt->xpt_pool = NULL; + clear_bit(XPT_BUSY, &xprt->xpt_flags); + goto out_unlock; + } + + process: + if (!list_empty(&pool->sp_threads)) { + rqstp = list_entry(pool->sp_threads.next, + struct svc_rqst, + rq_list); + dprintk("svc: transport %p served by daemon %p\n", + xprt, rqstp); + svc_thread_dequeue(pool, rqstp); + if (rqstp->rq_xprt) + printk(KERN_ERR + "svc_xprt_enqueue: server %p, rq_xprt=%p!\n", + rqstp, rqstp->rq_xprt); + rqstp->rq_xprt = xprt; + svc_xprt_get(xprt); + rqstp->rq_reserved = serv->sv_max_mesg; + atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); + BUG_ON(xprt->xpt_pool != pool); + wake_up(&rqstp->rq_wait); + } else { + dprintk("svc: transport %p put into queue\n", xprt); + list_add_tail(&xprt->xpt_ready, &pool->sp_sockets); + BUG_ON(xprt->xpt_pool != pool); + } + +out_unlock: + spin_unlock_bh(&pool->sp_lock); +} +EXPORT_SYMBOL_GPL(svc_xprt_enqueue); + +/* + * Dequeue the first transport. Must be called with the pool->sp_lock held. + */ +static struct svc_xprt *svc_xprt_dequeue(struct svc_pool *pool) +{ + struct svc_xprt *xprt; + + if (list_empty(&pool->sp_sockets)) + return NULL; + + xprt = list_entry(pool->sp_sockets.next, + struct svc_xprt, xpt_ready); + list_del_init(&xprt->xpt_ready); + + dprintk("svc: transport %p dequeued, inuse=%d\n", + xprt, atomic_read(&xprt->xpt_ref.refcount)); + + return xprt; +} + +/* + * svc_xprt_received conditionally queues the transport for processing + * by another thread. The caller must hold the XPT_BUSY bit and must + * not thereafter touch transport data. + * + * Note: XPT_DATA only gets cleared when a read-attempt finds no (or + * insufficient) data. + */ +void svc_xprt_received(struct svc_xprt *xprt) +{ + BUG_ON(!test_bit(XPT_BUSY, &xprt->xpt_flags)); + xprt->xpt_pool = NULL; + clear_bit(XPT_BUSY, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); +} +EXPORT_SYMBOL_GPL(svc_xprt_received); + +/** + * svc_reserve - change the space reserved for the reply to a request. + * @rqstp: The request in question + * @space: new max space to reserve + * + * Each request reserves some space on the output queue of the transport + * to make sure the reply fits. This function reduces that reserved + * space to be the amount of space used already, plus @space. + * + */ +void svc_reserve(struct svc_rqst *rqstp, int space) +{ + space += rqstp->rq_res.head[0].iov_len; + + if (space < rqstp->rq_reserved) { + struct svc_xprt *xprt = rqstp->rq_xprt; + atomic_sub((rqstp->rq_reserved - space), &xprt->xpt_reserved); + rqstp->rq_reserved = space; + + svc_xprt_enqueue(xprt); + } +} +EXPORT_SYMBOL(svc_reserve); + +static void svc_xprt_release(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + rqstp->rq_xprt->xpt_ops->xpo_release_rqst(rqstp); + + svc_free_res_pages(rqstp); + rqstp->rq_res.page_len = 0; + rqstp->rq_res.page_base = 0; + + /* Reset response buffer and release + * the reservation. + * But first, check that enough space was reserved + * for the reply, otherwise we have a bug! + */ + if ((rqstp->rq_res.len) > rqstp->rq_reserved) + printk(KERN_ERR "RPC request reserved %d but used %d\n", + rqstp->rq_reserved, + rqstp->rq_res.len); + + rqstp->rq_res.head[0].iov_len = 0; + svc_reserve(rqstp, 0); + rqstp->rq_xprt = NULL; + + svc_xprt_put(xprt); +} + +/* + * External function to wake up a server waiting for data + * This really only makes sense for services like lockd + * which have exactly one thread anyway. + */ +void svc_wake_up(struct svc_serv *serv) +{ + struct svc_rqst *rqstp; + unsigned int i; + struct svc_pool *pool; + + for (i = 0; i < serv->sv_nrpools; i++) { + pool = &serv->sv_pools[i]; + + spin_lock_bh(&pool->sp_lock); + if (!list_empty(&pool->sp_threads)) { + rqstp = list_entry(pool->sp_threads.next, + struct svc_rqst, + rq_list); + dprintk("svc: daemon %p woken up.\n", rqstp); + /* + svc_thread_dequeue(pool, rqstp); + rqstp->rq_xprt = NULL; + */ + wake_up(&rqstp->rq_wait); + } + spin_unlock_bh(&pool->sp_lock); + } +} +EXPORT_SYMBOL(svc_wake_up); + +int svc_port_is_privileged(struct sockaddr *sin) +{ + switch (sin->sa_family) { + case AF_INET: + return ntohs(((struct sockaddr_in *)sin)->sin_port) + < PROT_SOCK; + case AF_INET6: + return ntohs(((struct sockaddr_in6 *)sin)->sin6_port) + < PROT_SOCK; + default: + return 0; + } +} + +/* + * Make sure that we don't have too many active connections. If we + * have, something must be dropped. + * + * There's no point in trying to do random drop here for DoS + * prevention. The NFS clients does 1 reconnect in 15 seconds. An + * attacker can easily beat that. + * + * The only somewhat efficient mechanism would be if drop old + * connections from the same IP first. But right now we don't even + * record the client IP in svc_sock. + */ +static void svc_check_conn_limits(struct svc_serv *serv) +{ + if (serv->sv_tmpcnt > (serv->sv_nrthreads+3)*20) { + struct svc_xprt *xprt = NULL; + spin_lock_bh(&serv->sv_lock); + if (!list_empty(&serv->sv_tempsocks)) { + if (net_ratelimit()) { + /* Try to help the admin */ + printk(KERN_NOTICE "%s: too many open " + "connections, consider increasing the " + "number of nfsd threads\n", + serv->sv_name); + } + /* + * Always select the oldest connection. It's not fair, + * but so is life + */ + xprt = list_entry(serv->sv_tempsocks.prev, + struct svc_xprt, + xpt_list); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_get(xprt); + } + spin_unlock_bh(&serv->sv_lock); + + if (xprt) { + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + } + } +} + +/* + * Receive the next request on any transport. This code is carefully + * organised not to touch any cachelines in the shared svc_serv + * structure, only cachelines in the local svc_pool. + */ +int svc_recv(struct svc_rqst *rqstp, long timeout) +{ + struct svc_xprt *xprt = NULL; + struct svc_serv *serv = rqstp->rq_server; + struct svc_pool *pool = rqstp->rq_pool; + int len, i; + int pages; + struct xdr_buf *arg; + DECLARE_WAITQUEUE(wait, current); + + dprintk("svc: server %p waiting for data (to = %ld)\n", + rqstp, timeout); + + if (rqstp->rq_xprt) + printk(KERN_ERR + "svc_recv: service %p, transport not NULL!\n", + rqstp); + if (waitqueue_active(&rqstp->rq_wait)) + printk(KERN_ERR + "svc_recv: service %p, wait queue active!\n", + rqstp); + + /* now allocate needed pages. If we get a failure, sleep briefly */ + pages = (serv->sv_max_mesg + PAGE_SIZE) / PAGE_SIZE; + for (i = 0; i < pages ; i++) + while (rqstp->rq_pages[i] == NULL) { + struct page *p = alloc_page(GFP_KERNEL); + if (!p) { + set_current_state(TASK_INTERRUPTIBLE); + if (signalled() || kthread_should_stop()) { + set_current_state(TASK_RUNNING); + return -EINTR; + } + schedule_timeout(msecs_to_jiffies(500)); + } + rqstp->rq_pages[i] = p; + } + rqstp->rq_pages[i++] = NULL; /* this might be seen in nfs_read_actor */ + BUG_ON(pages >= RPCSVC_MAXPAGES); + + /* Make arg->head point to first page and arg->pages point to rest */ + arg = &rqstp->rq_arg; + arg->head[0].iov_base = page_address(rqstp->rq_pages[0]); + arg->head[0].iov_len = PAGE_SIZE; + arg->pages = rqstp->rq_pages + 1; + arg->page_base = 0; + /* save at least one page for response */ + arg->page_len = (pages-2)*PAGE_SIZE; + arg->len = (pages-1)*PAGE_SIZE; + arg->tail[0].iov_len = 0; + + try_to_freeze(); + cond_resched(); + if (signalled() || kthread_should_stop()) + return -EINTR; + + spin_lock_bh(&pool->sp_lock); + xprt = svc_xprt_dequeue(pool); + if (xprt) { + rqstp->rq_xprt = xprt; + svc_xprt_get(xprt); + rqstp->rq_reserved = serv->sv_max_mesg; + atomic_add(rqstp->rq_reserved, &xprt->xpt_reserved); + } else { + /* No data pending. Go to sleep */ + svc_thread_enqueue(pool, rqstp); + + /* + * We have to be able to interrupt this wait + * to bring down the daemons ... + */ + set_current_state(TASK_INTERRUPTIBLE); + + /* + * checking kthread_should_stop() here allows us to avoid + * locking and signalling when stopping kthreads that call + * svc_recv. If the thread has already been woken up, then + * we can exit here without sleeping. If not, then it + * it'll be woken up quickly during the schedule_timeout + */ + if (kthread_should_stop()) { + set_current_state(TASK_RUNNING); + spin_unlock_bh(&pool->sp_lock); + return -EINTR; + } + + add_wait_queue(&rqstp->rq_wait, &wait); + spin_unlock_bh(&pool->sp_lock); + + schedule_timeout(timeout); + + try_to_freeze(); + + spin_lock_bh(&pool->sp_lock); + remove_wait_queue(&rqstp->rq_wait, &wait); + + xprt = rqstp->rq_xprt; + if (!xprt) { + svc_thread_dequeue(pool, rqstp); + spin_unlock_bh(&pool->sp_lock); + dprintk("svc: server %p, no data yet\n", rqstp); + if (signalled() || kthread_should_stop()) + return -EINTR; + else + return -EAGAIN; + } + } + spin_unlock_bh(&pool->sp_lock); + + len = 0; + if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) { + dprintk("svc_recv: found XPT_CLOSE\n"); + svc_delete_xprt(xprt); + } else if (test_bit(XPT_LISTENER, &xprt->xpt_flags)) { + struct svc_xprt *newxpt; + newxpt = xprt->xpt_ops->xpo_accept(xprt); + if (newxpt) { + /* + * We know this module_get will succeed because the + * listener holds a reference too + */ + __module_get(newxpt->xpt_class->xcl_owner); + svc_check_conn_limits(xprt->xpt_server); + spin_lock_bh(&serv->sv_lock); + set_bit(XPT_TEMP, &newxpt->xpt_flags); + list_add(&newxpt->xpt_list, &serv->sv_tempsocks); + serv->sv_tmpcnt++; + if (serv->sv_temptimer.function == NULL) { + /* setup timer to age temp transports */ + setup_timer(&serv->sv_temptimer, + svc_age_temp_xprts, + (unsigned long)serv); + mod_timer(&serv->sv_temptimer, + jiffies + svc_conn_age_period * HZ); + } + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(newxpt); + } + svc_xprt_received(xprt); + } else { + dprintk("svc: server %p, pool %u, transport %p, inuse=%d\n", + rqstp, pool->sp_id, xprt, + atomic_read(&xprt->xpt_ref.refcount)); + rqstp->rq_deferred = svc_deferred_dequeue(xprt); + if (rqstp->rq_deferred) { + svc_xprt_received(xprt); + len = svc_deferred_recv(rqstp); + } else + len = xprt->xpt_ops->xpo_recvfrom(rqstp); + dprintk("svc: got len=%d\n", len); + } + + /* No data, incomplete (TCP) read, or accept() */ + if (len == 0 || len == -EAGAIN) { + rqstp->rq_res.len = 0; + svc_xprt_release(rqstp); + return -EAGAIN; + } + clear_bit(XPT_OLD, &xprt->xpt_flags); + + rqstp->rq_secure = svc_port_is_privileged(svc_addr(rqstp)); + rqstp->rq_chandle.defer = svc_defer; + + if (serv->sv_stats) + serv->sv_stats->netcnt++; + return len; +} +EXPORT_SYMBOL(svc_recv); + +/* + * Drop request + */ +void svc_drop(struct svc_rqst *rqstp) +{ + dprintk("svc: xprt %p dropped request\n", rqstp->rq_xprt); + svc_xprt_release(rqstp); +} +EXPORT_SYMBOL(svc_drop); + +/* + * Return reply to client. + */ +int svc_send(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt; + int len; + struct xdr_buf *xb; + + xprt = rqstp->rq_xprt; + if (!xprt) + return -EFAULT; + + /* release the receive skb before sending the reply */ + rqstp->rq_xprt->xpt_ops->xpo_release_rqst(rqstp); + + /* calculate over-all length */ + xb = &rqstp->rq_res; + xb->len = xb->head[0].iov_len + + xb->page_len + + xb->tail[0].iov_len; + + /* Grab mutex to serialize outgoing data. */ + mutex_lock(&xprt->xpt_mutex); + if (test_bit(XPT_DEAD, &xprt->xpt_flags)) + len = -ENOTCONN; + else + len = xprt->xpt_ops->xpo_sendto(rqstp); + mutex_unlock(&xprt->xpt_mutex); + svc_xprt_release(rqstp); + + if (len == -ECONNREFUSED || len == -ENOTCONN || len == -EAGAIN) + return 0; + return len; +} + +/* + * Timer function to close old temporary transports, using + * a mark-and-sweep algorithm. + */ +static void svc_age_temp_xprts(unsigned long closure) +{ + struct svc_serv *serv = (struct svc_serv *)closure; + struct svc_xprt *xprt; + struct list_head *le, *next; + LIST_HEAD(to_be_aged); + + dprintk("svc_age_temp_xprts\n"); + + if (!spin_trylock_bh(&serv->sv_lock)) { + /* busy, try again 1 sec later */ + dprintk("svc_age_temp_xprts: busy\n"); + mod_timer(&serv->sv_temptimer, jiffies + HZ); + return; + } + + list_for_each_safe(le, next, &serv->sv_tempsocks) { + xprt = list_entry(le, struct svc_xprt, xpt_list); + + /* First time through, just mark it OLD. Second time + * through, close it. */ + if (!test_and_set_bit(XPT_OLD, &xprt->xpt_flags)) + continue; + if (atomic_read(&xprt->xpt_ref.refcount) > 1 + || test_bit(XPT_BUSY, &xprt->xpt_flags)) + continue; + svc_xprt_get(xprt); + list_move(le, &to_be_aged); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + set_bit(XPT_DETACHED, &xprt->xpt_flags); + } + spin_unlock_bh(&serv->sv_lock); + + while (!list_empty(&to_be_aged)) { + le = to_be_aged.next; + /* fiddling the xpt_list node is safe 'cos we're XPT_DETACHED */ + list_del_init(le); + xprt = list_entry(le, struct svc_xprt, xpt_list); + + dprintk("queuing xprt %p for closing\n", xprt); + + /* a thread will dequeue and close it soon */ + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + } + + mod_timer(&serv->sv_temptimer, jiffies + svc_conn_age_period * HZ); +} + +/* + * Remove a dead transport + */ +void svc_delete_xprt(struct svc_xprt *xprt) +{ + struct svc_serv *serv = xprt->xpt_server; + + dprintk("svc: svc_delete_xprt(%p)\n", xprt); + xprt->xpt_ops->xpo_detach(xprt); + + spin_lock_bh(&serv->sv_lock); + if (!test_and_set_bit(XPT_DETACHED, &xprt->xpt_flags)) + list_del_init(&xprt->xpt_list); + /* + * We used to delete the transport from whichever list + * it's sk_xprt.xpt_ready node was on, but we don't actually + * need to. This is because the only time we're called + * while still attached to a queue, the queue itself + * is about to be destroyed (in svc_destroy). + */ + if (!test_and_set_bit(XPT_DEAD, &xprt->xpt_flags)) { + BUG_ON(atomic_read(&xprt->xpt_ref.refcount) < 2); + if (test_bit(XPT_TEMP, &xprt->xpt_flags)) + serv->sv_tmpcnt--; + svc_xprt_put(xprt); + } + spin_unlock_bh(&serv->sv_lock); +} + +void svc_close_xprt(struct svc_xprt *xprt) +{ + set_bit(XPT_CLOSE, &xprt->xpt_flags); + if (test_and_set_bit(XPT_BUSY, &xprt->xpt_flags)) + /* someone else will have to effect the close */ + return; + + svc_xprt_get(xprt); + svc_delete_xprt(xprt); + clear_bit(XPT_BUSY, &xprt->xpt_flags); + svc_xprt_put(xprt); +} +EXPORT_SYMBOL_GPL(svc_close_xprt); + +void svc_close_all(struct list_head *xprt_list) +{ + struct svc_xprt *xprt; + struct svc_xprt *tmp; + + list_for_each_entry_safe(xprt, tmp, xprt_list, xpt_list) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + if (test_bit(XPT_BUSY, &xprt->xpt_flags)) { + /* Waiting to be processed, but no threads left, + * So just remove it from the waiting list + */ + list_del_init(&xprt->xpt_ready); + clear_bit(XPT_BUSY, &xprt->xpt_flags); + } + svc_close_xprt(xprt); + } +} + +/* + * Handle defer and revisit of requests + */ + +static void svc_revisit(struct cache_deferred_req *dreq, int too_many) +{ + struct svc_deferred_req *dr = + container_of(dreq, struct svc_deferred_req, handle); + struct svc_xprt *xprt = dr->xprt; + + if (too_many) { + svc_xprt_put(xprt); + kfree(dr); + return; + } + dprintk("revisit queued\n"); + dr->xprt = NULL; + spin_lock(&xprt->xpt_lock); + list_add(&dr->handle.recent, &xprt->xpt_deferred); + spin_unlock(&xprt->xpt_lock); + set_bit(XPT_DEFERRED, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); +} + +/* + * Save the request off for later processing. The request buffer looks + * like this: + * + * <xprt-header><rpc-header><rpc-pagelist><rpc-tail> + * + * This code can only handle requests that consist of an xprt-header + * and rpc-header. + */ +static struct cache_deferred_req *svc_defer(struct cache_req *req) +{ + struct svc_rqst *rqstp = container_of(req, struct svc_rqst, rq_chandle); + struct svc_deferred_req *dr; + + if (rqstp->rq_arg.page_len) + return NULL; /* if more than a page, give up FIXME */ + if (rqstp->rq_deferred) { + dr = rqstp->rq_deferred; + rqstp->rq_deferred = NULL; + } else { + size_t skip; + size_t size; + /* FIXME maybe discard if size too large */ + size = sizeof(struct svc_deferred_req) + rqstp->rq_arg.len; + dr = kmalloc(size, GFP_KERNEL); + if (dr == NULL) + return NULL; + + dr->handle.owner = rqstp->rq_server; + dr->prot = rqstp->rq_prot; + memcpy(&dr->addr, &rqstp->rq_addr, rqstp->rq_addrlen); + dr->addrlen = rqstp->rq_addrlen; + dr->daddr = rqstp->rq_daddr; + dr->argslen = rqstp->rq_arg.len >> 2; + dr->xprt_hlen = rqstp->rq_xprt_hlen; + + /* back up head to the start of the buffer and copy */ + skip = rqstp->rq_arg.len - rqstp->rq_arg.head[0].iov_len; + memcpy(dr->args, rqstp->rq_arg.head[0].iov_base - skip, + dr->argslen << 2); + } + svc_xprt_get(rqstp->rq_xprt); + dr->xprt = rqstp->rq_xprt; + + dr->handle.revisit = svc_revisit; + return &dr->handle; +} + +/* + * recv data from a deferred request into an active one + */ +static int svc_deferred_recv(struct svc_rqst *rqstp) +{ + struct svc_deferred_req *dr = rqstp->rq_deferred; + + /* setup iov_base past transport header */ + rqstp->rq_arg.head[0].iov_base = dr->args + (dr->xprt_hlen>>2); + /* The iov_len does not include the transport header bytes */ + rqstp->rq_arg.head[0].iov_len = (dr->argslen<<2) - dr->xprt_hlen; + rqstp->rq_arg.page_len = 0; + /* The rq_arg.len includes the transport header bytes */ + rqstp->rq_arg.len = dr->argslen<<2; + rqstp->rq_prot = dr->prot; + memcpy(&rqstp->rq_addr, &dr->addr, dr->addrlen); + rqstp->rq_addrlen = dr->addrlen; + /* Save off transport header len in case we get deferred again */ + rqstp->rq_xprt_hlen = dr->xprt_hlen; + rqstp->rq_daddr = dr->daddr; + rqstp->rq_respages = rqstp->rq_pages; + return (dr->argslen<<2) - dr->xprt_hlen; +} + + +static struct svc_deferred_req *svc_deferred_dequeue(struct svc_xprt *xprt) +{ + struct svc_deferred_req *dr = NULL; + + if (!test_bit(XPT_DEFERRED, &xprt->xpt_flags)) + return NULL; + spin_lock(&xprt->xpt_lock); + clear_bit(XPT_DEFERRED, &xprt->xpt_flags); + if (!list_empty(&xprt->xpt_deferred)) { + dr = list_entry(xprt->xpt_deferred.next, + struct svc_deferred_req, + handle.recent); + list_del_init(&dr->handle.recent); + set_bit(XPT_DEFERRED, &xprt->xpt_flags); + } + spin_unlock(&xprt->xpt_lock); + return dr; +} + +/* + * Return the transport instance pointer for the endpoint accepting + * connections/peer traffic from the specified transport class, + * address family and port. + * + * Specifying 0 for the address family or port is effectively a + * wild-card, and will result in matching the first transport in the + * service's list that has a matching class name. + */ +struct svc_xprt *svc_find_xprt(struct svc_serv *serv, char *xcl_name, + int af, int port) +{ + struct svc_xprt *xprt; + struct svc_xprt *found = NULL; + + /* Sanity check the args */ + if (!serv || !xcl_name) + return found; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + if (strcmp(xprt->xpt_class->xcl_name, xcl_name)) + continue; + if (af != AF_UNSPEC && af != xprt->xpt_local.ss_family) + continue; + if (port && port != svc_xprt_local_port(xprt)) + continue; + found = xprt; + svc_xprt_get(xprt); + break; + } + spin_unlock_bh(&serv->sv_lock); + return found; +} +EXPORT_SYMBOL_GPL(svc_find_xprt); + +/* + * Format a buffer with a list of the active transports. A zero for + * the buflen parameter disables target buffer overflow checking. + */ +int svc_xprt_names(struct svc_serv *serv, char *buf, int buflen) +{ + struct svc_xprt *xprt; + char xprt_str[64]; + int totlen = 0; + int len; + + /* Sanity check args */ + if (!serv) + return 0; + + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(xprt, &serv->sv_permsocks, xpt_list) { + len = snprintf(xprt_str, sizeof(xprt_str), + "%s %d\n", xprt->xpt_class->xcl_name, + svc_xprt_local_port(xprt)); + /* If the string was truncated, replace with error string */ + if (len >= sizeof(xprt_str)) + strcpy(xprt_str, "name-too-long\n"); + /* Don't overflow buffer */ + len = strlen(xprt_str); + if (buflen && (len + totlen >= buflen)) + break; + strcpy(buf+totlen, xprt_str); + totlen += len; + } + spin_unlock_bh(&serv->sv_lock); + return totlen; +} +EXPORT_SYMBOL_GPL(svc_xprt_names); diff --git a/net/sunrpc/svcauth.c b/net/sunrpc/svcauth.c new file mode 100644 index 0000000..8a73cbb --- /dev/null +++ b/net/sunrpc/svcauth.c @@ -0,0 +1,166 @@ +/* + * linux/net/sunrpc/svcauth.c + * + * The generic interface for RPC authentication on the server side. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + * + * CHANGES + * 19-Apr-2000 Chris Evans - Security fix + */ + +#include <linux/types.h> +#include <linux/module.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/err.h> +#include <linux/hash.h> + +#define RPCDBG_FACILITY RPCDBG_AUTH + + +/* + * Table of authenticators + */ +extern struct auth_ops svcauth_null; +extern struct auth_ops svcauth_unix; + +static DEFINE_SPINLOCK(authtab_lock); +static struct auth_ops *authtab[RPC_AUTH_MAXFLAVOR] = { + [0] = &svcauth_null, + [1] = &svcauth_unix, +}; + +int +svc_authenticate(struct svc_rqst *rqstp, __be32 *authp) +{ + rpc_authflavor_t flavor; + struct auth_ops *aops; + + *authp = rpc_auth_ok; + + flavor = svc_getnl(&rqstp->rq_arg.head[0]); + + dprintk("svc: svc_authenticate (%d)\n", flavor); + + spin_lock(&authtab_lock); + if (flavor >= RPC_AUTH_MAXFLAVOR || !(aops = authtab[flavor]) + || !try_module_get(aops->owner)) { + spin_unlock(&authtab_lock); + *authp = rpc_autherr_badcred; + return SVC_DENIED; + } + spin_unlock(&authtab_lock); + + rqstp->rq_authop = aops; + return aops->accept(rqstp, authp); +} +EXPORT_SYMBOL(svc_authenticate); + +int svc_set_client(struct svc_rqst *rqstp) +{ + return rqstp->rq_authop->set_client(rqstp); +} +EXPORT_SYMBOL(svc_set_client); + +/* A request, which was authenticated, has now executed. + * Time to finalise the credentials and verifier + * and release and resources + */ +int svc_authorise(struct svc_rqst *rqstp) +{ + struct auth_ops *aops = rqstp->rq_authop; + int rv = 0; + + rqstp->rq_authop = NULL; + + if (aops) { + rv = aops->release(rqstp); + module_put(aops->owner); + } + return rv; +} + +int +svc_auth_register(rpc_authflavor_t flavor, struct auth_ops *aops) +{ + int rv = -EINVAL; + spin_lock(&authtab_lock); + if (flavor < RPC_AUTH_MAXFLAVOR && authtab[flavor] == NULL) { + authtab[flavor] = aops; + rv = 0; + } + spin_unlock(&authtab_lock); + return rv; +} +EXPORT_SYMBOL(svc_auth_register); + +void +svc_auth_unregister(rpc_authflavor_t flavor) +{ + spin_lock(&authtab_lock); + if (flavor < RPC_AUTH_MAXFLAVOR) + authtab[flavor] = NULL; + spin_unlock(&authtab_lock); +} +EXPORT_SYMBOL(svc_auth_unregister); + +/************************************************** + * 'auth_domains' are stored in a hash table indexed by name. + * When the last reference to an 'auth_domain' is dropped, + * the object is unhashed and freed. + * If auth_domain_lookup fails to find an entry, it will return + * it's second argument 'new'. If this is non-null, it will + * have been atomically linked into the table. + */ + +#define DN_HASHBITS 6 +#define DN_HASHMAX (1<<DN_HASHBITS) +#define DN_HASHMASK (DN_HASHMAX-1) + +static struct hlist_head auth_domain_table[DN_HASHMAX]; +static spinlock_t auth_domain_lock = + __SPIN_LOCK_UNLOCKED(auth_domain_lock); + +void auth_domain_put(struct auth_domain *dom) +{ + if (atomic_dec_and_lock(&dom->ref.refcount, &auth_domain_lock)) { + hlist_del(&dom->hash); + dom->flavour->domain_release(dom); + spin_unlock(&auth_domain_lock); + } +} +EXPORT_SYMBOL(auth_domain_put); + +struct auth_domain * +auth_domain_lookup(char *name, struct auth_domain *new) +{ + struct auth_domain *hp; + struct hlist_head *head; + struct hlist_node *np; + + head = &auth_domain_table[hash_str(name, DN_HASHBITS)]; + + spin_lock(&auth_domain_lock); + + hlist_for_each_entry(hp, np, head, hash) { + if (strcmp(hp->name, name)==0) { + kref_get(&hp->ref); + spin_unlock(&auth_domain_lock); + return hp; + } + } + if (new) + hlist_add_head(&new->hash, head); + spin_unlock(&auth_domain_lock); + return new; +} +EXPORT_SYMBOL(auth_domain_lookup); + +struct auth_domain *auth_domain_find(char *name) +{ + return auth_domain_lookup(name, NULL); +} +EXPORT_SYMBOL(auth_domain_find); diff --git a/net/sunrpc/svcauth_unix.c b/net/sunrpc/svcauth_unix.c new file mode 100644 index 0000000..f24800f --- /dev/null +++ b/net/sunrpc/svcauth_unix.c @@ -0,0 +1,875 @@ +#include <linux/types.h> +#include <linux/sched.h> +#include <linux/module.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/svcauth.h> +#include <linux/sunrpc/gss_api.h> +#include <linux/err.h> +#include <linux/seq_file.h> +#include <linux/hash.h> +#include <linux/string.h> +#include <net/sock.h> +#include <net/ipv6.h> +#include <linux/kernel.h> +#define RPCDBG_FACILITY RPCDBG_AUTH + + +/* + * AUTHUNIX and AUTHNULL credentials are both handled here. + * AUTHNULL is treated just like AUTHUNIX except that the uid/gid + * are always nobody (-2). i.e. we do the same IP address checks for + * AUTHNULL as for AUTHUNIX, and that is done here. + */ + + +struct unix_domain { + struct auth_domain h; + int addr_changes; + /* other stuff later */ +}; + +extern struct auth_ops svcauth_unix; + +struct auth_domain *unix_domain_find(char *name) +{ + struct auth_domain *rv; + struct unix_domain *new = NULL; + + rv = auth_domain_lookup(name, NULL); + while(1) { + if (rv) { + if (new && rv != &new->h) + auth_domain_put(&new->h); + + if (rv->flavour != &svcauth_unix) { + auth_domain_put(rv); + return NULL; + } + return rv; + } + + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (new == NULL) + return NULL; + kref_init(&new->h.ref); + new->h.name = kstrdup(name, GFP_KERNEL); + if (new->h.name == NULL) { + kfree(new); + return NULL; + } + new->h.flavour = &svcauth_unix; + new->addr_changes = 0; + rv = auth_domain_lookup(name, &new->h); + } +} +EXPORT_SYMBOL(unix_domain_find); + +static void svcauth_unix_domain_release(struct auth_domain *dom) +{ + struct unix_domain *ud = container_of(dom, struct unix_domain, h); + + kfree(dom->name); + kfree(ud); +} + + +/************************************************** + * cache for IP address to unix_domain + * as needed by AUTH_UNIX + */ +#define IP_HASHBITS 8 +#define IP_HASHMAX (1<<IP_HASHBITS) +#define IP_HASHMASK (IP_HASHMAX-1) + +struct ip_map { + struct cache_head h; + char m_class[8]; /* e.g. "nfsd" */ + struct in6_addr m_addr; + struct unix_domain *m_client; + int m_add_change; +}; +static struct cache_head *ip_table[IP_HASHMAX]; + +static void ip_map_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct ip_map *im = container_of(item, struct ip_map,h); + + if (test_bit(CACHE_VALID, &item->flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + auth_domain_put(&im->m_client->h); + kfree(im); +} + +#if IP_HASHBITS == 8 +/* hash_long on a 64 bit machine is currently REALLY BAD for + * IP addresses in reverse-endian (i.e. on a little-endian machine). + * So use a trivial but reliable hash instead + */ +static inline int hash_ip(__be32 ip) +{ + int hash = (__force u32)ip ^ ((__force u32)ip>>16); + return (hash ^ (hash>>8)) & 0xff; +} +#endif +static inline int hash_ip6(struct in6_addr ip) +{ + return (hash_ip(ip.s6_addr32[0]) ^ + hash_ip(ip.s6_addr32[1]) ^ + hash_ip(ip.s6_addr32[2]) ^ + hash_ip(ip.s6_addr32[3])); +} +static int ip_map_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct ip_map *orig = container_of(corig, struct ip_map, h); + struct ip_map *new = container_of(cnew, struct ip_map, h); + return strcmp(orig->m_class, new->m_class) == 0 + && ipv6_addr_equal(&orig->m_addr, &new->m_addr); +} +static void ip_map_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + strcpy(new->m_class, item->m_class); + ipv6_addr_copy(&new->m_addr, &item->m_addr); +} +static void update(struct cache_head *cnew, struct cache_head *citem) +{ + struct ip_map *new = container_of(cnew, struct ip_map, h); + struct ip_map *item = container_of(citem, struct ip_map, h); + + kref_get(&item->m_client->h.ref); + new->m_client = item->m_client; + new->m_add_change = item->m_add_change; +} +static struct cache_head *ip_map_alloc(void) +{ + struct ip_map *i = kmalloc(sizeof(*i), GFP_KERNEL); + if (i) + return &i->h; + else + return NULL; +} + +static void ip_map_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char text_addr[40]; + struct ip_map *im = container_of(h, struct ip_map, h); + + if (ipv6_addr_v4mapped(&(im->m_addr))) { + snprintf(text_addr, 20, NIPQUAD_FMT, + ntohl(im->m_addr.s6_addr32[3]) >> 24 & 0xff, + ntohl(im->m_addr.s6_addr32[3]) >> 16 & 0xff, + ntohl(im->m_addr.s6_addr32[3]) >> 8 & 0xff, + ntohl(im->m_addr.s6_addr32[3]) >> 0 & 0xff); + } else { + snprintf(text_addr, 40, NIP6_FMT, NIP6(im->m_addr)); + } + qword_add(bpp, blen, im->m_class); + qword_add(bpp, blen, text_addr); + (*bpp)[-1] = '\n'; +} + +static struct ip_map *ip_map_lookup(char *class, struct in6_addr *addr); +static int ip_map_update(struct ip_map *ipm, struct unix_domain *udom, time_t expiry); + +static int ip_map_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* class ipaddress [domainname] */ + /* should be safe just to use the start of the input buffer + * for scratch: */ + char *buf = mesg; + int len; + int b1, b2, b3, b4, b5, b6, b7, b8; + char c; + char class[8]; + struct in6_addr addr; + int err; + + struct ip_map *ipmp; + struct auth_domain *dom; + time_t expiry; + + if (mesg[mlen-1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + /* class */ + len = qword_get(&mesg, class, sizeof(class)); + if (len <= 0) return -EINVAL; + + /* ip address */ + len = qword_get(&mesg, buf, mlen); + if (len <= 0) return -EINVAL; + + if (sscanf(buf, NIPQUAD_FMT "%c", &b1, &b2, &b3, &b4, &c) == 4) { + addr.s6_addr32[0] = 0; + addr.s6_addr32[1] = 0; + addr.s6_addr32[2] = htonl(0xffff); + addr.s6_addr32[3] = + htonl((((((b1<<8)|b2)<<8)|b3)<<8)|b4); + } else if (sscanf(buf, NIP6_FMT "%c", + &b1, &b2, &b3, &b4, &b5, &b6, &b7, &b8, &c) == 8) { + addr.s6_addr16[0] = htons(b1); + addr.s6_addr16[1] = htons(b2); + addr.s6_addr16[2] = htons(b3); + addr.s6_addr16[3] = htons(b4); + addr.s6_addr16[4] = htons(b5); + addr.s6_addr16[5] = htons(b6); + addr.s6_addr16[6] = htons(b7); + addr.s6_addr16[7] = htons(b8); + } else + return -EINVAL; + + expiry = get_expiry(&mesg); + if (expiry ==0) + return -EINVAL; + + /* domainname, or empty for NEGATIVE */ + len = qword_get(&mesg, buf, mlen); + if (len < 0) return -EINVAL; + + if (len) { + dom = unix_domain_find(buf); + if (dom == NULL) + return -ENOENT; + } else + dom = NULL; + + ipmp = ip_map_lookup(class, &addr); + if (ipmp) { + err = ip_map_update(ipmp, + container_of(dom, struct unix_domain, h), + expiry); + } else + err = -ENOMEM; + + if (dom) + auth_domain_put(dom); + + cache_flush(); + return err; +} + +static int ip_map_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct ip_map *im; + struct in6_addr addr; + char *dom = "-no-domain-"; + + if (h == NULL) { + seq_puts(m, "#class IP domain\n"); + return 0; + } + im = container_of(h, struct ip_map, h); + /* class addr domain */ + ipv6_addr_copy(&addr, &im->m_addr); + + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + dom = im->m_client->h.name; + + if (ipv6_addr_v4mapped(&addr)) { + seq_printf(m, "%s " NIPQUAD_FMT " %s\n", + im->m_class, + ntohl(addr.s6_addr32[3]) >> 24 & 0xff, + ntohl(addr.s6_addr32[3]) >> 16 & 0xff, + ntohl(addr.s6_addr32[3]) >> 8 & 0xff, + ntohl(addr.s6_addr32[3]) >> 0 & 0xff, + dom); + } else { + seq_printf(m, "%s " NIP6_FMT " %s\n", + im->m_class, NIP6(addr), dom); + } + return 0; +} + + +struct cache_detail ip_map_cache = { + .owner = THIS_MODULE, + .hash_size = IP_HASHMAX, + .hash_table = ip_table, + .name = "auth.unix.ip", + .cache_put = ip_map_put, + .cache_request = ip_map_request, + .cache_parse = ip_map_parse, + .cache_show = ip_map_show, + .match = ip_map_match, + .init = ip_map_init, + .update = update, + .alloc = ip_map_alloc, +}; + +static struct ip_map *ip_map_lookup(char *class, struct in6_addr *addr) +{ + struct ip_map ip; + struct cache_head *ch; + + strcpy(ip.m_class, class); + ipv6_addr_copy(&ip.m_addr, addr); + ch = sunrpc_cache_lookup(&ip_map_cache, &ip.h, + hash_str(class, IP_HASHBITS) ^ + hash_ip6(*addr)); + + if (ch) + return container_of(ch, struct ip_map, h); + else + return NULL; +} + +static int ip_map_update(struct ip_map *ipm, struct unix_domain *udom, time_t expiry) +{ + struct ip_map ip; + struct cache_head *ch; + + ip.m_client = udom; + ip.h.flags = 0; + if (!udom) + set_bit(CACHE_NEGATIVE, &ip.h.flags); + else { + ip.m_add_change = udom->addr_changes; + /* if this is from the legacy set_client system call, + * we need m_add_change to be one higher + */ + if (expiry == NEVER) + ip.m_add_change++; + } + ip.h.expiry_time = expiry; + ch = sunrpc_cache_update(&ip_map_cache, + &ip.h, &ipm->h, + hash_str(ipm->m_class, IP_HASHBITS) ^ + hash_ip6(ipm->m_addr)); + if (!ch) + return -ENOMEM; + cache_put(ch, &ip_map_cache); + return 0; +} + +int auth_unix_add_addr(struct in6_addr *addr, struct auth_domain *dom) +{ + struct unix_domain *udom; + struct ip_map *ipmp; + + if (dom->flavour != &svcauth_unix) + return -EINVAL; + udom = container_of(dom, struct unix_domain, h); + ipmp = ip_map_lookup("nfsd", addr); + + if (ipmp) + return ip_map_update(ipmp, udom, NEVER); + else + return -ENOMEM; +} +EXPORT_SYMBOL(auth_unix_add_addr); + +int auth_unix_forget_old(struct auth_domain *dom) +{ + struct unix_domain *udom; + + if (dom->flavour != &svcauth_unix) + return -EINVAL; + udom = container_of(dom, struct unix_domain, h); + udom->addr_changes++; + return 0; +} +EXPORT_SYMBOL(auth_unix_forget_old); + +struct auth_domain *auth_unix_lookup(struct in6_addr *addr) +{ + struct ip_map *ipm; + struct auth_domain *rv; + + ipm = ip_map_lookup("nfsd", addr); + + if (!ipm) + return NULL; + if (cache_check(&ip_map_cache, &ipm->h, NULL)) + return NULL; + + if ((ipm->m_client->addr_changes - ipm->m_add_change) >0) { + if (test_and_set_bit(CACHE_NEGATIVE, &ipm->h.flags) == 0) + auth_domain_put(&ipm->m_client->h); + rv = NULL; + } else { + rv = &ipm->m_client->h; + kref_get(&rv->ref); + } + cache_put(&ipm->h, &ip_map_cache); + return rv; +} +EXPORT_SYMBOL(auth_unix_lookup); + +void svcauth_unix_purge(void) +{ + cache_purge(&ip_map_cache); +} +EXPORT_SYMBOL(svcauth_unix_purge); + +static inline struct ip_map * +ip_map_cached_get(struct svc_rqst *rqstp) +{ + struct ip_map *ipm = NULL; + struct svc_xprt *xprt = rqstp->rq_xprt; + + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + spin_lock(&xprt->xpt_lock); + ipm = xprt->xpt_auth_cache; + if (ipm != NULL) { + if (!cache_valid(&ipm->h)) { + /* + * The entry has been invalidated since it was + * remembered, e.g. by a second mount from the + * same IP address. + */ + xprt->xpt_auth_cache = NULL; + spin_unlock(&xprt->xpt_lock); + cache_put(&ipm->h, &ip_map_cache); + return NULL; + } + cache_get(&ipm->h); + } + spin_unlock(&xprt->xpt_lock); + } + return ipm; +} + +static inline void +ip_map_cached_put(struct svc_rqst *rqstp, struct ip_map *ipm) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + + if (test_bit(XPT_CACHE_AUTH, &xprt->xpt_flags)) { + spin_lock(&xprt->xpt_lock); + if (xprt->xpt_auth_cache == NULL) { + /* newly cached, keep the reference */ + xprt->xpt_auth_cache = ipm; + ipm = NULL; + } + spin_unlock(&xprt->xpt_lock); + } + if (ipm) + cache_put(&ipm->h, &ip_map_cache); +} + +void +svcauth_unix_info_release(void *info) +{ + struct ip_map *ipm = info; + cache_put(&ipm->h, &ip_map_cache); +} + +/**************************************************************************** + * auth.unix.gid cache + * simple cache to map a UID to a list of GIDs + * because AUTH_UNIX aka AUTH_SYS has a max of 16 + */ +#define GID_HASHBITS 8 +#define GID_HASHMAX (1<<GID_HASHBITS) +#define GID_HASHMASK (GID_HASHMAX - 1) + +struct unix_gid { + struct cache_head h; + uid_t uid; + struct group_info *gi; +}; +static struct cache_head *gid_table[GID_HASHMAX]; + +static void unix_gid_put(struct kref *kref) +{ + struct cache_head *item = container_of(kref, struct cache_head, ref); + struct unix_gid *ug = container_of(item, struct unix_gid, h); + if (test_bit(CACHE_VALID, &item->flags) && + !test_bit(CACHE_NEGATIVE, &item->flags)) + put_group_info(ug->gi); + kfree(ug); +} + +static int unix_gid_match(struct cache_head *corig, struct cache_head *cnew) +{ + struct unix_gid *orig = container_of(corig, struct unix_gid, h); + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + return orig->uid == new->uid; +} +static void unix_gid_init(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + new->uid = item->uid; +} +static void unix_gid_update(struct cache_head *cnew, struct cache_head *citem) +{ + struct unix_gid *new = container_of(cnew, struct unix_gid, h); + struct unix_gid *item = container_of(citem, struct unix_gid, h); + + get_group_info(item->gi); + new->gi = item->gi; +} +static struct cache_head *unix_gid_alloc(void) +{ + struct unix_gid *g = kmalloc(sizeof(*g), GFP_KERNEL); + if (g) + return &g->h; + else + return NULL; +} + +static void unix_gid_request(struct cache_detail *cd, + struct cache_head *h, + char **bpp, int *blen) +{ + char tuid[20]; + struct unix_gid *ug = container_of(h, struct unix_gid, h); + + snprintf(tuid, 20, "%u", ug->uid); + qword_add(bpp, blen, tuid); + (*bpp)[-1] = '\n'; +} + +static struct unix_gid *unix_gid_lookup(uid_t uid); +extern struct cache_detail unix_gid_cache; + +static int unix_gid_parse(struct cache_detail *cd, + char *mesg, int mlen) +{ + /* uid expiry Ngid gid0 gid1 ... gidN-1 */ + int uid; + int gids; + int rv; + int i; + int err; + time_t expiry; + struct unix_gid ug, *ugp; + + if (mlen <= 0 || mesg[mlen-1] != '\n') + return -EINVAL; + mesg[mlen-1] = 0; + + rv = get_int(&mesg, &uid); + if (rv) + return -EINVAL; + ug.uid = uid; + + expiry = get_expiry(&mesg); + if (expiry == 0) + return -EINVAL; + + rv = get_int(&mesg, &gids); + if (rv || gids < 0 || gids > 8192) + return -EINVAL; + + ug.gi = groups_alloc(gids); + if (!ug.gi) + return -ENOMEM; + + for (i = 0 ; i < gids ; i++) { + int gid; + rv = get_int(&mesg, &gid); + err = -EINVAL; + if (rv) + goto out; + GROUP_AT(ug.gi, i) = gid; + } + + ugp = unix_gid_lookup(uid); + if (ugp) { + struct cache_head *ch; + ug.h.flags = 0; + ug.h.expiry_time = expiry; + ch = sunrpc_cache_update(&unix_gid_cache, + &ug.h, &ugp->h, + hash_long(uid, GID_HASHBITS)); + if (!ch) + err = -ENOMEM; + else { + err = 0; + cache_put(ch, &unix_gid_cache); + } + } else + err = -ENOMEM; + out: + if (ug.gi) + put_group_info(ug.gi); + return err; +} + +static int unix_gid_show(struct seq_file *m, + struct cache_detail *cd, + struct cache_head *h) +{ + struct unix_gid *ug; + int i; + int glen; + + if (h == NULL) { + seq_puts(m, "#uid cnt: gids...\n"); + return 0; + } + ug = container_of(h, struct unix_gid, h); + if (test_bit(CACHE_VALID, &h->flags) && + !test_bit(CACHE_NEGATIVE, &h->flags)) + glen = ug->gi->ngroups; + else + glen = 0; + + seq_printf(m, "%d %d:", ug->uid, glen); + for (i = 0; i < glen; i++) + seq_printf(m, " %d", GROUP_AT(ug->gi, i)); + seq_printf(m, "\n"); + return 0; +} + +struct cache_detail unix_gid_cache = { + .owner = THIS_MODULE, + .hash_size = GID_HASHMAX, + .hash_table = gid_table, + .name = "auth.unix.gid", + .cache_put = unix_gid_put, + .cache_request = unix_gid_request, + .cache_parse = unix_gid_parse, + .cache_show = unix_gid_show, + .match = unix_gid_match, + .init = unix_gid_init, + .update = unix_gid_update, + .alloc = unix_gid_alloc, +}; + +static struct unix_gid *unix_gid_lookup(uid_t uid) +{ + struct unix_gid ug; + struct cache_head *ch; + + ug.uid = uid; + ch = sunrpc_cache_lookup(&unix_gid_cache, &ug.h, + hash_long(uid, GID_HASHBITS)); + if (ch) + return container_of(ch, struct unix_gid, h); + else + return NULL; +} + +static int unix_gid_find(uid_t uid, struct group_info **gip, + struct svc_rqst *rqstp) +{ + struct unix_gid *ug = unix_gid_lookup(uid); + if (!ug) + return -EAGAIN; + switch (cache_check(&unix_gid_cache, &ug->h, &rqstp->rq_chandle)) { + case -ENOENT: + *gip = NULL; + return 0; + case 0: + *gip = ug->gi; + get_group_info(*gip); + return 0; + default: + return -EAGAIN; + } +} + +int +svcauth_unix_set_client(struct svc_rqst *rqstp) +{ + struct sockaddr_in *sin; + struct sockaddr_in6 *sin6, sin6_storage; + struct ip_map *ipm; + + switch (rqstp->rq_addr.ss_family) { + case AF_INET: + sin = svc_addr_in(rqstp); + sin6 = &sin6_storage; + ipv6_addr_set(&sin6->sin6_addr, 0, 0, + htonl(0x0000FFFF), sin->sin_addr.s_addr); + break; + case AF_INET6: + sin6 = svc_addr_in6(rqstp); + break; + default: + BUG(); + } + + rqstp->rq_client = NULL; + if (rqstp->rq_proc == 0) + return SVC_OK; + + ipm = ip_map_cached_get(rqstp); + if (ipm == NULL) + ipm = ip_map_lookup(rqstp->rq_server->sv_program->pg_class, + &sin6->sin6_addr); + + if (ipm == NULL) + return SVC_DENIED; + + switch (cache_check(&ip_map_cache, &ipm->h, &rqstp->rq_chandle)) { + default: + BUG(); + case -EAGAIN: + case -ETIMEDOUT: + return SVC_DROP; + case -ENOENT: + return SVC_DENIED; + case 0: + rqstp->rq_client = &ipm->m_client->h; + kref_get(&rqstp->rq_client->ref); + ip_map_cached_put(rqstp, ipm); + break; + } + return SVC_OK; +} + +EXPORT_SYMBOL(svcauth_unix_set_client); + +static int +svcauth_null_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + + cred->cr_group_info = NULL; + rqstp->rq_client = NULL; + + if (argv->iov_len < 3*4) + return SVC_GARBAGE; + + if (svc_getu32(argv) != 0) { + dprintk("svc: bad null cred\n"); + *authp = rpc_autherr_badcred; + return SVC_DENIED; + } + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + dprintk("svc: bad null verf\n"); + *authp = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Signal that mapping to nobody uid/gid is required */ + cred->cr_uid = (uid_t) -1; + cred->cr_gid = (gid_t) -1; + cred->cr_group_info = groups_alloc(0); + if (cred->cr_group_info == NULL) + return SVC_DROP; /* kmalloc failure - client must retry */ + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_flavor = RPC_AUTH_NULL; + return SVC_OK; +} + +static int +svcauth_null_release(struct svc_rqst *rqstp) +{ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; /* don't drop */ +} + + +struct auth_ops svcauth_null = { + .name = "null", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_NULL, + .accept = svcauth_null_accept, + .release = svcauth_null_release, + .set_client = svcauth_unix_set_client, +}; + + +static int +svcauth_unix_accept(struct svc_rqst *rqstp, __be32 *authp) +{ + struct kvec *argv = &rqstp->rq_arg.head[0]; + struct kvec *resv = &rqstp->rq_res.head[0]; + struct svc_cred *cred = &rqstp->rq_cred; + u32 slen, i; + int len = argv->iov_len; + + cred->cr_group_info = NULL; + rqstp->rq_client = NULL; + + if ((len -= 3*4) < 0) + return SVC_GARBAGE; + + svc_getu32(argv); /* length */ + svc_getu32(argv); /* time stamp */ + slen = XDR_QUADLEN(svc_getnl(argv)); /* machname length */ + if (slen > 64 || (len -= (slen + 3)*4) < 0) + goto badcred; + argv->iov_base = (void*)((__be32*)argv->iov_base + slen); /* skip machname */ + argv->iov_len -= slen*4; + + cred->cr_uid = svc_getnl(argv); /* uid */ + cred->cr_gid = svc_getnl(argv); /* gid */ + slen = svc_getnl(argv); /* gids length */ + if (slen > 16 || (len -= (slen + 2)*4) < 0) + goto badcred; + if (unix_gid_find(cred->cr_uid, &cred->cr_group_info, rqstp) + == -EAGAIN) + return SVC_DROP; + if (cred->cr_group_info == NULL) { + cred->cr_group_info = groups_alloc(slen); + if (cred->cr_group_info == NULL) + return SVC_DROP; + for (i = 0; i < slen; i++) + GROUP_AT(cred->cr_group_info, i) = svc_getnl(argv); + } else { + for (i = 0; i < slen ; i++) + svc_getnl(argv); + } + if (svc_getu32(argv) != htonl(RPC_AUTH_NULL) || svc_getu32(argv) != 0) { + *authp = rpc_autherr_badverf; + return SVC_DENIED; + } + + /* Put NULL verifier */ + svc_putnl(resv, RPC_AUTH_NULL); + svc_putnl(resv, 0); + + rqstp->rq_flavor = RPC_AUTH_UNIX; + return SVC_OK; + +badcred: + *authp = rpc_autherr_badcred; + return SVC_DENIED; +} + +static int +svcauth_unix_release(struct svc_rqst *rqstp) +{ + /* Verifier (such as it is) is already in place. + */ + if (rqstp->rq_client) + auth_domain_put(rqstp->rq_client); + rqstp->rq_client = NULL; + if (rqstp->rq_cred.cr_group_info) + put_group_info(rqstp->rq_cred.cr_group_info); + rqstp->rq_cred.cr_group_info = NULL; + + return 0; +} + + +struct auth_ops svcauth_unix = { + .name = "unix", + .owner = THIS_MODULE, + .flavour = RPC_AUTH_UNIX, + .accept = svcauth_unix_accept, + .release = svcauth_unix_release, + .domain_release = svcauth_unix_domain_release, + .set_client = svcauth_unix_set_client, +}; + diff --git a/net/sunrpc/svcsock.c b/net/sunrpc/svcsock.c new file mode 100644 index 0000000..a1951dc --- /dev/null +++ b/net/sunrpc/svcsock.c @@ -0,0 +1,1305 @@ +/* + * linux/net/sunrpc/svcsock.c + * + * These are the RPC server socket internals. + * + * The server scheduling algorithm does not always distribute the load + * evenly when servicing a single client. May need to modify the + * svc_xprt_enqueue procedure... + * + * TCP support is largely untested and may be a little slow. The problem + * is that we currently do two separate recvfrom's, one for the 4-byte + * record length, and the second for the actual record. This could possibly + * be improved by always reading a minimum size of around 100 bytes and + * tucking any superfluous bytes away in a temporary store. Still, that + * leaves write requests out in the rain. An alternative may be to peek at + * the first skb in the queue, and if it matches the next TCP sequence + * number, to extract the record marker. Yuck. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/kernel.h> +#include <linux/sched.h> +#include <linux/errno.h> +#include <linux/fcntl.h> +#include <linux/net.h> +#include <linux/in.h> +#include <linux/inet.h> +#include <linux/udp.h> +#include <linux/tcp.h> +#include <linux/unistd.h> +#include <linux/slab.h> +#include <linux/netdevice.h> +#include <linux/skbuff.h> +#include <linux/file.h> +#include <linux/freezer.h> +#include <net/sock.h> +#include <net/checksum.h> +#include <net/ip.h> +#include <net/ipv6.h> +#include <net/tcp.h> +#include <net/tcp_states.h> +#include <asm/uaccess.h> +#include <asm/ioctls.h> + +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/msg_prot.h> +#include <linux/sunrpc/svcsock.h> +#include <linux/sunrpc/stats.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + + +static struct svc_sock *svc_setup_socket(struct svc_serv *, struct socket *, + int *errp, int flags); +static void svc_udp_data_ready(struct sock *, int); +static int svc_udp_recvfrom(struct svc_rqst *); +static int svc_udp_sendto(struct svc_rqst *); +static void svc_sock_detach(struct svc_xprt *); +static void svc_sock_free(struct svc_xprt *); + +static struct svc_xprt *svc_create_socket(struct svc_serv *, int, + struct sockaddr *, int, int); +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key svc_key[2]; +static struct lock_class_key svc_slock_key[2]; + +static void svc_reclassify_socket(struct socket *sock) +{ + struct sock *sk = sock->sk; + BUG_ON(sock_owned_by_user(sk)); + switch (sk->sk_family) { + case AF_INET: + sock_lock_init_class_and_name(sk, "slock-AF_INET-NFSD", + &svc_slock_key[0], + "sk_xprt.xpt_lock-AF_INET-NFSD", + &svc_key[0]); + break; + + case AF_INET6: + sock_lock_init_class_and_name(sk, "slock-AF_INET6-NFSD", + &svc_slock_key[1], + "sk_xprt.xpt_lock-AF_INET6-NFSD", + &svc_key[1]); + break; + + default: + BUG(); + } +} +#else +static void svc_reclassify_socket(struct socket *sock) +{ +} +#endif + +/* + * Release an skbuff after use + */ +static void svc_release_skb(struct svc_rqst *rqstp) +{ + struct sk_buff *skb = rqstp->rq_xprt_ctxt; + struct svc_deferred_req *dr = rqstp->rq_deferred; + + if (skb) { + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + rqstp->rq_xprt_ctxt = NULL; + + dprintk("svc: service %p, releasing skb %p\n", rqstp, skb); + skb_free_datagram(svsk->sk_sk, skb); + } + if (dr) { + rqstp->rq_deferred = NULL; + kfree(dr); + } +} + +union svc_pktinfo_u { + struct in_pktinfo pkti; + struct in6_pktinfo pkti6; +}; +#define SVC_PKTINFO_SPACE \ + CMSG_SPACE(sizeof(union svc_pktinfo_u)) + +static void svc_set_cmsg_data(struct svc_rqst *rqstp, struct cmsghdr *cmh) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + switch (svsk->sk_sk->sk_family) { + case AF_INET: { + struct in_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IP; + cmh->cmsg_type = IP_PKTINFO; + pki->ipi_ifindex = 0; + pki->ipi_spec_dst.s_addr = rqstp->rq_daddr.addr.s_addr; + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + + case AF_INET6: { + struct in6_pktinfo *pki = CMSG_DATA(cmh); + + cmh->cmsg_level = SOL_IPV6; + cmh->cmsg_type = IPV6_PKTINFO; + pki->ipi6_ifindex = 0; + ipv6_addr_copy(&pki->ipi6_addr, + &rqstp->rq_daddr.addr6); + cmh->cmsg_len = CMSG_LEN(sizeof(*pki)); + } + break; + } + return; +} + +/* + * Generic sendto routine + */ +static int svc_sendto(struct svc_rqst *rqstp, struct xdr_buf *xdr) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct socket *sock = svsk->sk_sock; + int slen; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + int len = 0; + int result; + int size; + struct page **ppage = xdr->pages; + size_t base = xdr->page_base; + unsigned int pglen = xdr->page_len; + unsigned int flags = MSG_MORE; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + slen = xdr->len; + + if (rqstp->rq_prot == IPPROTO_UDP) { + struct msghdr msg = { + .msg_name = &rqstp->rq_addr, + .msg_namelen = rqstp->rq_addrlen, + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_MORE, + }; + + svc_set_cmsg_data(rqstp, cmh); + + if (sock_sendmsg(sock, &msg, 0) < 0) + goto out; + } + + /* send head */ + if (slen == xdr->head[0].iov_len) + flags = 0; + len = kernel_sendpage(sock, rqstp->rq_respages[0], 0, + xdr->head[0].iov_len, flags); + if (len != xdr->head[0].iov_len) + goto out; + slen -= xdr->head[0].iov_len; + if (slen == 0) + goto out; + + /* send page data */ + size = PAGE_SIZE - base < pglen ? PAGE_SIZE - base : pglen; + while (pglen > 0) { + if (slen == size) + flags = 0; + result = kernel_sendpage(sock, *ppage, base, size, flags); + if (result > 0) + len += result; + if (result != size) + goto out; + slen -= size; + pglen -= size; + size = PAGE_SIZE < pglen ? PAGE_SIZE : pglen; + base = 0; + ppage++; + } + /* send tail */ + if (xdr->tail[0].iov_len) { + result = kernel_sendpage(sock, rqstp->rq_respages[0], + ((unsigned long)xdr->tail[0].iov_base) + & (PAGE_SIZE-1), + xdr->tail[0].iov_len, 0); + + if (result > 0) + len += result; + } +out: + dprintk("svc: socket %p sendto([%p %Zu... ], %d) = %d (addr %s)\n", + svsk, xdr->head[0].iov_base, xdr->head[0].iov_len, + xdr->len, len, svc_print_addr(rqstp, buf, sizeof(buf))); + + return len; +} + +/* + * Report socket names for nfsdfs + */ +static int one_sock_name(char *buf, struct svc_sock *svsk) +{ + int len; + + switch(svsk->sk_sk->sk_family) { + case AF_INET: + len = sprintf(buf, "ipv4 %s %u.%u.%u.%u %d\n", + svsk->sk_sk->sk_protocol==IPPROTO_UDP? + "udp" : "tcp", + NIPQUAD(inet_sk(svsk->sk_sk)->rcv_saddr), + inet_sk(svsk->sk_sk)->num); + break; + default: + len = sprintf(buf, "*unknown-%d*\n", + svsk->sk_sk->sk_family); + } + return len; +} + +int +svc_sock_names(char *buf, struct svc_serv *serv, char *toclose) +{ + struct svc_sock *svsk, *closesk = NULL; + int len = 0; + + if (!serv) + return 0; + spin_lock_bh(&serv->sv_lock); + list_for_each_entry(svsk, &serv->sv_permsocks, sk_xprt.xpt_list) { + int onelen = one_sock_name(buf+len, svsk); + if (toclose && strcmp(toclose, buf+len) == 0) + closesk = svsk; + else + len += onelen; + } + spin_unlock_bh(&serv->sv_lock); + if (closesk) + /* Should unregister with portmap, but you cannot + * unregister just one protocol... + */ + svc_close_xprt(&closesk->sk_xprt); + else if (toclose) + return -ENOENT; + return len; +} +EXPORT_SYMBOL(svc_sock_names); + +/* + * Check input queue length + */ +static int svc_recv_available(struct svc_sock *svsk) +{ + struct socket *sock = svsk->sk_sock; + int avail, err; + + err = kernel_sock_ioctl(sock, TIOCINQ, (unsigned long) &avail); + + return (err >= 0)? avail : err; +} + +/* + * Generic recvfrom routine. + */ +static int svc_recvfrom(struct svc_rqst *rqstp, struct kvec *iov, int nr, + int buflen) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct msghdr msg = { + .msg_flags = MSG_DONTWAIT, + }; + int len; + + rqstp->rq_xprt_hlen = 0; + + len = kernel_recvmsg(svsk->sk_sock, &msg, iov, nr, buflen, + msg.msg_flags); + + dprintk("svc: socket %p recvfrom(%p, %Zu) = %d\n", + svsk, iov[0].iov_base, iov[0].iov_len, len); + return len; +} + +/* + * Set socket snd and rcv buffer lengths + */ +static void svc_sock_setbufsize(struct socket *sock, unsigned int snd, + unsigned int rcv) +{ +#if 0 + mm_segment_t oldfs; + oldfs = get_fs(); set_fs(KERNEL_DS); + sock_setsockopt(sock, SOL_SOCKET, SO_SNDBUF, + (char*)&snd, sizeof(snd)); + sock_setsockopt(sock, SOL_SOCKET, SO_RCVBUF, + (char*)&rcv, sizeof(rcv)); +#else + /* sock_setsockopt limits use to sysctl_?mem_max, + * which isn't acceptable. Until that is made conditional + * on not having CAP_SYS_RESOURCE or similar, we go direct... + * DaveM said I could! + */ + lock_sock(sock->sk); + sock->sk->sk_sndbuf = snd * 2; + sock->sk->sk_rcvbuf = rcv * 2; + sock->sk->sk_userlocks |= SOCK_SNDBUF_LOCK|SOCK_RCVBUF_LOCK; + release_sock(sock->sk); +#endif +} +/* + * INET callback when data has been received on the socket. + */ +static void svc_udp_data_ready(struct sock *sk, int count) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + if (svsk) { + dprintk("svc: socket %p(inet %p), count=%d, busy=%d\n", + svsk, sk, count, + test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } + if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) + wake_up_interruptible(sk->sk_sleep); +} + +/* + * INET callback when space is newly available on the socket. + */ +static void svc_write_space(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)(sk->sk_user_data); + + if (svsk) { + dprintk("svc: socket %p(inet %p), write_space busy=%d\n", + svsk, sk, test_bit(XPT_BUSY, &svsk->sk_xprt.xpt_flags)); + svc_xprt_enqueue(&svsk->sk_xprt); + } + + if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) { + dprintk("RPC svc_write_space: someone sleeping on %p\n", + svsk); + wake_up_interruptible(sk->sk_sleep); + } +} + +/* + * Copy the UDP datagram's destination address to the rqstp structure. + * The 'destination' address in this case is the address to which the + * peer sent the datagram, i.e. our local address. For multihomed + * hosts, this can change from msg to msg. Note that only the IP + * address changes, the port number should remain the same. + */ +static void svc_udp_get_dest_address(struct svc_rqst *rqstp, + struct cmsghdr *cmh) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + switch (svsk->sk_sk->sk_family) { + case AF_INET: { + struct in_pktinfo *pki = CMSG_DATA(cmh); + rqstp->rq_daddr.addr.s_addr = pki->ipi_spec_dst.s_addr; + break; + } + case AF_INET6: { + struct in6_pktinfo *pki = CMSG_DATA(cmh); + ipv6_addr_copy(&rqstp->rq_daddr.addr6, &pki->ipi6_addr); + break; + } + } +} + +/* + * Receive a datagram from a UDP socket. + */ +static int svc_udp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct sk_buff *skb; + union { + struct cmsghdr hdr; + long all[SVC_PKTINFO_SPACE / sizeof(long)]; + } buffer; + struct cmsghdr *cmh = &buffer.hdr; + int err, len; + struct msghdr msg = { + .msg_name = svc_addr(rqstp), + .msg_control = cmh, + .msg_controllen = sizeof(buffer), + .msg_flags = MSG_DONTWAIT, + }; + + if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags)) + /* udp sockets need large rcvbuf as all pending + * requests are still in that buffer. sndbuf must + * also be large enough that there is enough space + * for one reply per thread. We count all threads + * rather than threads in a particular pool, which + * provides an upper bound on the number of threads + * which will access the socket. + */ + svc_sock_setbufsize(svsk->sk_sock, + (serv->sv_nrthreads+3) * serv->sv_max_mesg, + (serv->sv_nrthreads+3) * serv->sv_max_mesg); + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + skb = NULL; + err = kernel_recvmsg(svsk->sk_sock, &msg, NULL, + 0, 0, MSG_PEEK | MSG_DONTWAIT); + if (err >= 0) + skb = skb_recv_datagram(svsk->sk_sk, 0, 1, &err); + + if (skb == NULL) { + if (err != -EAGAIN) { + /* possibly an icmp error */ + dprintk("svc: recvfrom returned error %d\n", -err); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + } + svc_xprt_received(&svsk->sk_xprt); + return -EAGAIN; + } + len = svc_addr_len(svc_addr(rqstp)); + if (len < 0) + return len; + rqstp->rq_addrlen = len; + if (skb->tstamp.tv64 == 0) { + skb->tstamp = ktime_get_real(); + /* Don't enable netstamp, sunrpc doesn't + need that much accuracy */ + } + svsk->sk_sk->sk_stamp = skb->tstamp; + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); /* there may be more data... */ + + /* + * Maybe more packets - kick another thread ASAP. + */ + svc_xprt_received(&svsk->sk_xprt); + + len = skb->len - sizeof(struct udphdr); + rqstp->rq_arg.len = len; + + rqstp->rq_prot = IPPROTO_UDP; + + if (cmh->cmsg_level != IPPROTO_IP || + cmh->cmsg_type != IP_PKTINFO) { + if (net_ratelimit()) + printk("rpcsvc: received unknown control message:" + "%d/%d\n", + cmh->cmsg_level, cmh->cmsg_type); + skb_free_datagram(svsk->sk_sk, skb); + return 0; + } + svc_udp_get_dest_address(rqstp, cmh); + + if (skb_is_nonlinear(skb)) { + /* we have to copy */ + local_bh_disable(); + if (csum_partial_copy_to_xdr(&rqstp->rq_arg, skb)) { + local_bh_enable(); + /* checksum error */ + skb_free_datagram(svsk->sk_sk, skb); + return 0; + } + local_bh_enable(); + skb_free_datagram(svsk->sk_sk, skb); + } else { + /* we can use it in-place */ + rqstp->rq_arg.head[0].iov_base = skb->data + + sizeof(struct udphdr); + rqstp->rq_arg.head[0].iov_len = len; + if (skb_checksum_complete(skb)) { + skb_free_datagram(svsk->sk_sk, skb); + return 0; + } + rqstp->rq_xprt_ctxt = skb; + } + + rqstp->rq_arg.page_base = 0; + if (len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = len; + rqstp->rq_arg.page_len = 0; + rqstp->rq_respages = rqstp->rq_pages+1; + } else { + rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len; + rqstp->rq_respages = rqstp->rq_pages + 1 + + DIV_ROUND_UP(rqstp->rq_arg.page_len, PAGE_SIZE); + } + + if (serv->sv_stats) + serv->sv_stats->netudpcnt++; + + return len; +} + +static int +svc_udp_sendto(struct svc_rqst *rqstp) +{ + int error; + + error = svc_sendto(rqstp, &rqstp->rq_res); + if (error == -ECONNREFUSED) + /* ICMP error on earlier request. */ + error = svc_sendto(rqstp, &rqstp->rq_res); + + return error; +} + +static void svc_udp_prep_reply_hdr(struct svc_rqst *rqstp) +{ +} + +static int svc_udp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = xprt->xpt_server; + unsigned long required; + + /* + * Set the SOCK_NOSPACE flag before checking the available + * sock space. + */ + set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + required = atomic_read(&svsk->sk_xprt.xpt_reserved) + serv->sv_max_mesg; + if (required*2 > sock_wspace(svsk->sk_sk)) + return 0; + clear_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + return 1; +} + +static struct svc_xprt *svc_udp_accept(struct svc_xprt *xprt) +{ + BUG(); + return NULL; +} + +static struct svc_xprt *svc_udp_create(struct svc_serv *serv, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_UDP, sa, salen, flags); +} + +static struct svc_xprt_ops svc_udp_ops = { + .xpo_create = svc_udp_create, + .xpo_recvfrom = svc_udp_recvfrom, + .xpo_sendto = svc_udp_sendto, + .xpo_release_rqst = svc_release_skb, + .xpo_detach = svc_sock_detach, + .xpo_free = svc_sock_free, + .xpo_prep_reply_hdr = svc_udp_prep_reply_hdr, + .xpo_has_wspace = svc_udp_has_wspace, + .xpo_accept = svc_udp_accept, +}; + +static struct svc_xprt_class svc_udp_class = { + .xcl_name = "udp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_udp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_UDP, +}; + +static void svc_udp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + int one = 1; + mm_segment_t oldfs; + + svc_xprt_init(&svc_udp_class, &svsk->sk_xprt, serv); + clear_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + svsk->sk_sk->sk_data_ready = svc_udp_data_ready; + svsk->sk_sk->sk_write_space = svc_write_space; + + /* initialise setting must have enough space to + * receive and respond to one request. + * svc_udp_recvfrom will re-adjust if necessary + */ + svc_sock_setbufsize(svsk->sk_sock, + 3 * svsk->sk_xprt.xpt_server->sv_max_mesg, + 3 * svsk->sk_xprt.xpt_server->sv_max_mesg); + + /* data might have come in before data_ready set up */ + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + + oldfs = get_fs(); + set_fs(KERNEL_DS); + /* make sure we get destination address info */ + svsk->sk_sock->ops->setsockopt(svsk->sk_sock, IPPROTO_IP, IP_PKTINFO, + (char __user *)&one, sizeof(one)); + set_fs(oldfs); +} + +/* + * A data_ready event on a listening socket means there's a connection + * pending. Do not use state_change as a substitute for it. + */ +static void svc_tcp_listen_data_ready(struct sock *sk, int count_unused) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + dprintk("svc: socket %p TCP (listen) state change %d\n", + sk, sk->sk_state); + + /* + * This callback may called twice when a new connection + * is established as a child socket inherits everything + * from a parent LISTEN socket. + * 1) data_ready method of the parent socket will be called + * when one of child sockets become ESTABLISHED. + * 2) data_ready method of the child socket may be called + * when it receives data before the socket is accepted. + * In case of 2, we should ignore it silently. + */ + if (sk->sk_state == TCP_LISTEN) { + if (svsk) { + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } else + printk("svc: socket %p: no user data\n", sk); + } + + if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) + wake_up_interruptible_all(sk->sk_sleep); +} + +/* + * A state change on a connected socket means it's dying or dead. + */ +static void svc_tcp_state_change(struct sock *sk) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + dprintk("svc: socket %p TCP (connected) state change %d (svsk %p)\n", + sk, sk->sk_state, sk->sk_user_data); + + if (!svsk) + printk("svc: socket %p: no user data\n", sk); + else { + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } + if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) + wake_up_interruptible_all(sk->sk_sleep); +} + +static void svc_tcp_data_ready(struct sock *sk, int count) +{ + struct svc_sock *svsk = (struct svc_sock *)sk->sk_user_data; + + dprintk("svc: socket %p TCP data ready (svsk %p)\n", + sk, sk->sk_user_data); + if (svsk) { + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + svc_xprt_enqueue(&svsk->sk_xprt); + } + if (sk->sk_sleep && waitqueue_active(sk->sk_sleep)) + wake_up_interruptible(sk->sk_sleep); +} + +/* + * Accept a TCP connection + */ +static struct svc_xprt *svc_tcp_accept(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *) &addr; + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + struct socket *sock = svsk->sk_sock; + struct socket *newsock; + struct svc_sock *newsvsk; + int err, slen; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + dprintk("svc: tcp_accept %p sock %p\n", svsk, sock); + if (!sock) + return NULL; + + clear_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + err = kernel_accept(sock, &newsock, O_NONBLOCK); + if (err < 0) { + if (err == -ENOMEM) + printk(KERN_WARNING "%s: no more sockets!\n", + serv->sv_name); + else if (err != -EAGAIN && net_ratelimit()) + printk(KERN_WARNING "%s: accept failed (err %d)!\n", + serv->sv_name, -err); + return NULL; + } + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + + err = kernel_getpeername(newsock, sin, &slen); + if (err < 0) { + if (net_ratelimit()) + printk(KERN_WARNING "%s: peername failed (err %d)!\n", + serv->sv_name, -err); + goto failed; /* aborted connection or whatever */ + } + + /* Ideally, we would want to reject connections from unauthorized + * hosts here, but when we get encryption, the IP of the host won't + * tell us anything. For now just warn about unpriv connections. + */ + if (!svc_port_is_privileged(sin)) { + dprintk(KERN_WARNING + "%s: connect from unprivileged port: %s\n", + serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); + } + dprintk("%s: connect from %s\n", serv->sv_name, + __svc_print_addr(sin, buf, sizeof(buf))); + + /* make sure that a write doesn't block forever when + * low on memory + */ + newsock->sk->sk_sndtimeo = HZ*30; + + if (!(newsvsk = svc_setup_socket(serv, newsock, &err, + (SVC_SOCK_ANONYMOUS | SVC_SOCK_TEMPORARY)))) + goto failed; + svc_xprt_set_remote(&newsvsk->sk_xprt, sin, slen); + err = kernel_getsockname(newsock, sin, &slen); + if (unlikely(err < 0)) { + dprintk("svc_tcp_accept: kernel_getsockname error %d\n", -err); + slen = offsetof(struct sockaddr, sa_data); + } + svc_xprt_set_local(&newsvsk->sk_xprt, sin, slen); + + if (serv->sv_stats) + serv->sv_stats->nettcpconn++; + + return &newsvsk->sk_xprt; + +failed: + sock_release(newsock); + return NULL; +} + +/* + * Receive data from a TCP socket. + */ +static int svc_tcp_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_sock *svsk = + container_of(rqstp->rq_xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + int len; + struct kvec *vec; + int pnum, vlen; + + dprintk("svc: tcp_recv %p data %d conn %d close %d\n", + svsk, test_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags), + test_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags), + test_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags)); + + if (test_and_clear_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags)) + /* sndbuf needs to have room for one request + * per thread, otherwise we can stall even when the + * network isn't a bottleneck. + * + * We count all threads rather than threads in a + * particular pool, which provides an upper bound + * on the number of threads which will access the socket. + * + * rcvbuf just needs to be able to hold a few requests. + * Normally they will be removed from the queue + * as soon a a complete request arrives. + */ + svc_sock_setbufsize(svsk->sk_sock, + (serv->sv_nrthreads+3) * serv->sv_max_mesg, + 3 * serv->sv_max_mesg); + + clear_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + /* Receive data. If we haven't got the record length yet, get + * the next four bytes. Otherwise try to gobble up as much as + * possible up to the complete record length. + */ + if (svsk->sk_tcplen < sizeof(rpc_fraghdr)) { + int want = sizeof(rpc_fraghdr) - svsk->sk_tcplen; + struct kvec iov; + + iov.iov_base = ((char *) &svsk->sk_reclen) + svsk->sk_tcplen; + iov.iov_len = want; + if ((len = svc_recvfrom(rqstp, &iov, 1, want)) < 0) + goto error; + svsk->sk_tcplen += len; + + if (len < want) { + dprintk("svc: short recvfrom while reading record " + "length (%d of %d)\n", len, want); + svc_xprt_received(&svsk->sk_xprt); + return -EAGAIN; /* record header not complete */ + } + + svsk->sk_reclen = ntohl(svsk->sk_reclen); + if (!(svsk->sk_reclen & RPC_LAST_STREAM_FRAGMENT)) { + /* FIXME: technically, a record can be fragmented, + * and non-terminal fragments will not have the top + * bit set in the fragment length header. + * But apparently no known nfs clients send fragmented + * records. */ + if (net_ratelimit()) + printk(KERN_NOTICE "RPC: multiple fragments " + "per record not supported\n"); + goto err_delete; + } + svsk->sk_reclen &= RPC_FRAGMENT_SIZE_MASK; + dprintk("svc: TCP record, %d bytes\n", svsk->sk_reclen); + if (svsk->sk_reclen > serv->sv_max_mesg) { + if (net_ratelimit()) + printk(KERN_NOTICE "RPC: " + "fragment too large: 0x%08lx\n", + (unsigned long)svsk->sk_reclen); + goto err_delete; + } + } + + /* Check whether enough data is available */ + len = svc_recv_available(svsk); + if (len < 0) + goto error; + + if (len < svsk->sk_reclen) { + dprintk("svc: incomplete TCP record (%d of %d)\n", + len, svsk->sk_reclen); + svc_xprt_received(&svsk->sk_xprt); + return -EAGAIN; /* record not complete */ + } + len = svsk->sk_reclen; + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + + vec = rqstp->rq_vec; + vec[0] = rqstp->rq_arg.head[0]; + vlen = PAGE_SIZE; + pnum = 1; + while (vlen < len) { + vec[pnum].iov_base = page_address(rqstp->rq_pages[pnum]); + vec[pnum].iov_len = PAGE_SIZE; + pnum++; + vlen += PAGE_SIZE; + } + rqstp->rq_respages = &rqstp->rq_pages[pnum]; + + /* Now receive data */ + len = svc_recvfrom(rqstp, vec, pnum, len); + if (len < 0) + goto error; + + dprintk("svc: TCP complete record (%d bytes)\n", len); + rqstp->rq_arg.len = len; + rqstp->rq_arg.page_base = 0; + if (len <= rqstp->rq_arg.head[0].iov_len) { + rqstp->rq_arg.head[0].iov_len = len; + rqstp->rq_arg.page_len = 0; + } else { + rqstp->rq_arg.page_len = len - rqstp->rq_arg.head[0].iov_len; + } + + rqstp->rq_xprt_ctxt = NULL; + rqstp->rq_prot = IPPROTO_TCP; + + /* Reset TCP read info */ + svsk->sk_reclen = 0; + svsk->sk_tcplen = 0; + + svc_xprt_copy_addrs(rqstp, &svsk->sk_xprt); + svc_xprt_received(&svsk->sk_xprt); + if (serv->sv_stats) + serv->sv_stats->nettcpcnt++; + + return len; + + err_delete: + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + return -EAGAIN; + + error: + if (len == -EAGAIN) { + dprintk("RPC: TCP recvfrom got EAGAIN\n"); + svc_xprt_received(&svsk->sk_xprt); + } else { + printk(KERN_NOTICE "%s: recvfrom returned errno %d\n", + svsk->sk_xprt.xpt_server->sv_name, -len); + goto err_delete; + } + + return len; +} + +/* + * Send out data on TCP socket. + */ +static int svc_tcp_sendto(struct svc_rqst *rqstp) +{ + struct xdr_buf *xbufp = &rqstp->rq_res; + int sent; + __be32 reclen; + + /* Set up the first element of the reply kvec. + * Any other kvecs that may be in use have been taken + * care of by the server implementation itself. + */ + reclen = htonl(0x80000000|((xbufp->len ) - 4)); + memcpy(xbufp->head[0].iov_base, &reclen, 4); + + if (test_bit(XPT_DEAD, &rqstp->rq_xprt->xpt_flags)) + return -ENOTCONN; + + sent = svc_sendto(rqstp, &rqstp->rq_res); + if (sent != xbufp->len) { + printk(KERN_NOTICE + "rpc-srv/tcp: %s: %s %d when sending %d bytes " + "- shutting down socket\n", + rqstp->rq_xprt->xpt_server->sv_name, + (sent<0)?"got error":"sent only", + sent, xbufp->len); + set_bit(XPT_CLOSE, &rqstp->rq_xprt->xpt_flags); + svc_xprt_enqueue(rqstp->rq_xprt); + sent = -EAGAIN; + } + return sent; +} + +/* + * Setup response header. TCP has a 4B record length field. + */ +static void svc_tcp_prep_reply_hdr(struct svc_rqst *rqstp) +{ + struct kvec *resv = &rqstp->rq_res.head[0]; + + /* tcp needs a space for the record length... */ + svc_putnl(resv, 0); +} + +static int svc_tcp_has_wspace(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct svc_serv *serv = svsk->sk_xprt.xpt_server; + int required; + int wspace; + + /* + * Set the SOCK_NOSPACE flag before checking the available + * sock space. + */ + set_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + required = atomic_read(&svsk->sk_xprt.xpt_reserved) + serv->sv_max_mesg; + wspace = sk_stream_wspace(svsk->sk_sk); + + if (wspace < sk_stream_min_wspace(svsk->sk_sk)) + return 0; + if (required * 2 > wspace) + return 0; + + clear_bit(SOCK_NOSPACE, &svsk->sk_sock->flags); + return 1; +} + +static struct svc_xprt *svc_tcp_create(struct svc_serv *serv, + struct sockaddr *sa, int salen, + int flags) +{ + return svc_create_socket(serv, IPPROTO_TCP, sa, salen, flags); +} + +static struct svc_xprt_ops svc_tcp_ops = { + .xpo_create = svc_tcp_create, + .xpo_recvfrom = svc_tcp_recvfrom, + .xpo_sendto = svc_tcp_sendto, + .xpo_release_rqst = svc_release_skb, + .xpo_detach = svc_sock_detach, + .xpo_free = svc_sock_free, + .xpo_prep_reply_hdr = svc_tcp_prep_reply_hdr, + .xpo_has_wspace = svc_tcp_has_wspace, + .xpo_accept = svc_tcp_accept, +}; + +static struct svc_xprt_class svc_tcp_class = { + .xcl_name = "tcp", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_tcp_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, +}; + +void svc_init_xprt_sock(void) +{ + svc_reg_xprt_class(&svc_tcp_class); + svc_reg_xprt_class(&svc_udp_class); +} + +void svc_cleanup_xprt_sock(void) +{ + svc_unreg_xprt_class(&svc_tcp_class); + svc_unreg_xprt_class(&svc_udp_class); +} + +static void svc_tcp_init(struct svc_sock *svsk, struct svc_serv *serv) +{ + struct sock *sk = svsk->sk_sk; + + svc_xprt_init(&svc_tcp_class, &svsk->sk_xprt, serv); + set_bit(XPT_CACHE_AUTH, &svsk->sk_xprt.xpt_flags); + if (sk->sk_state == TCP_LISTEN) { + dprintk("setting up TCP socket for listening\n"); + set_bit(XPT_LISTENER, &svsk->sk_xprt.xpt_flags); + sk->sk_data_ready = svc_tcp_listen_data_ready; + set_bit(XPT_CONN, &svsk->sk_xprt.xpt_flags); + } else { + dprintk("setting up TCP socket for reading\n"); + sk->sk_state_change = svc_tcp_state_change; + sk->sk_data_ready = svc_tcp_data_ready; + sk->sk_write_space = svc_write_space; + + svsk->sk_reclen = 0; + svsk->sk_tcplen = 0; + + tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; + + /* initialise setting must have enough space to + * receive and respond to one request. + * svc_tcp_recvfrom will re-adjust if necessary + */ + svc_sock_setbufsize(svsk->sk_sock, + 3 * svsk->sk_xprt.xpt_server->sv_max_mesg, + 3 * svsk->sk_xprt.xpt_server->sv_max_mesg); + + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + set_bit(XPT_DATA, &svsk->sk_xprt.xpt_flags); + if (sk->sk_state != TCP_ESTABLISHED) + set_bit(XPT_CLOSE, &svsk->sk_xprt.xpt_flags); + } +} + +void svc_sock_update_bufs(struct svc_serv *serv) +{ + /* + * The number of server threads has changed. Update + * rcvbuf and sndbuf accordingly on all sockets + */ + struct list_head *le; + + spin_lock_bh(&serv->sv_lock); + list_for_each(le, &serv->sv_permsocks) { + struct svc_sock *svsk = + list_entry(le, struct svc_sock, sk_xprt.xpt_list); + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + } + list_for_each(le, &serv->sv_tempsocks) { + struct svc_sock *svsk = + list_entry(le, struct svc_sock, sk_xprt.xpt_list); + set_bit(XPT_CHNGBUF, &svsk->sk_xprt.xpt_flags); + } + spin_unlock_bh(&serv->sv_lock); +} +EXPORT_SYMBOL(svc_sock_update_bufs); + +/* + * Initialize socket for RPC use and create svc_sock struct + * XXX: May want to setsockopt SO_SNDBUF and SO_RCVBUF. + */ +static struct svc_sock *svc_setup_socket(struct svc_serv *serv, + struct socket *sock, + int *errp, int flags) +{ + struct svc_sock *svsk; + struct sock *inet; + int pmap_register = !(flags & SVC_SOCK_ANONYMOUS); + int val; + + dprintk("svc: svc_setup_socket %p\n", sock); + if (!(svsk = kzalloc(sizeof(*svsk), GFP_KERNEL))) { + *errp = -ENOMEM; + return NULL; + } + + inet = sock->sk; + + /* Register socket with portmapper */ + if (*errp >= 0 && pmap_register) + *errp = svc_register(serv, inet->sk_protocol, + ntohs(inet_sk(inet)->sport)); + + if (*errp < 0) { + kfree(svsk); + return NULL; + } + + inet->sk_user_data = svsk; + svsk->sk_sock = sock; + svsk->sk_sk = inet; + svsk->sk_ostate = inet->sk_state_change; + svsk->sk_odata = inet->sk_data_ready; + svsk->sk_owspace = inet->sk_write_space; + + /* Initialize the socket */ + if (sock->type == SOCK_DGRAM) + svc_udp_init(svsk, serv); + else + svc_tcp_init(svsk, serv); + + /* + * We start one listener per sv_serv. We want AF_INET + * requests to be automatically shunted to our AF_INET6 + * listener using a mapped IPv4 address. Make sure + * no-one starts an equivalent IPv4 listener, which + * would steal our incoming connections. + */ + val = 0; + if (serv->sv_family == AF_INET6) + kernel_setsockopt(sock, SOL_IPV6, IPV6_V6ONLY, + (char *)&val, sizeof(val)); + + dprintk("svc: svc_setup_socket created %p (inet %p)\n", + svsk, svsk->sk_sk); + + return svsk; +} + +int svc_addsock(struct svc_serv *serv, + int fd, + char *name_return) +{ + int err = 0; + struct socket *so = sockfd_lookup(fd, &err); + struct svc_sock *svsk = NULL; + + if (!so) + return err; + if (so->sk->sk_family != AF_INET) + err = -EAFNOSUPPORT; + else if (so->sk->sk_protocol != IPPROTO_TCP && + so->sk->sk_protocol != IPPROTO_UDP) + err = -EPROTONOSUPPORT; + else if (so->state > SS_UNCONNECTED) + err = -EISCONN; + else { + if (!try_module_get(THIS_MODULE)) + err = -ENOENT; + else + svsk = svc_setup_socket(serv, so, &err, + SVC_SOCK_DEFAULTS); + if (svsk) { + struct sockaddr_storage addr; + struct sockaddr *sin = (struct sockaddr *)&addr; + int salen; + if (kernel_getsockname(svsk->sk_sock, sin, &salen) == 0) + svc_xprt_set_local(&svsk->sk_xprt, sin, salen); + clear_bit(XPT_TEMP, &svsk->sk_xprt.xpt_flags); + spin_lock_bh(&serv->sv_lock); + list_add(&svsk->sk_xprt.xpt_list, &serv->sv_permsocks); + spin_unlock_bh(&serv->sv_lock); + svc_xprt_received(&svsk->sk_xprt); + err = 0; + } else + module_put(THIS_MODULE); + } + if (err) { + sockfd_put(so); + return err; + } + return one_sock_name(name_return, svsk); +} +EXPORT_SYMBOL_GPL(svc_addsock); + +/* + * Create socket for RPC service. + */ +static struct svc_xprt *svc_create_socket(struct svc_serv *serv, + int protocol, + struct sockaddr *sin, int len, + int flags) +{ + struct svc_sock *svsk; + struct socket *sock; + int error; + int type; + struct sockaddr_storage addr; + struct sockaddr *newsin = (struct sockaddr *)&addr; + int newlen; + RPC_IFDEBUG(char buf[RPC_MAX_ADDRBUFLEN]); + + dprintk("svc: svc_create_socket(%s, %d, %s)\n", + serv->sv_program->pg_name, protocol, + __svc_print_addr(sin, buf, sizeof(buf))); + + if (protocol != IPPROTO_UDP && protocol != IPPROTO_TCP) { + printk(KERN_WARNING "svc: only UDP and TCP " + "sockets supported\n"); + return ERR_PTR(-EINVAL); + } + type = (protocol == IPPROTO_UDP)? SOCK_DGRAM : SOCK_STREAM; + + error = sock_create_kern(sin->sa_family, type, protocol, &sock); + if (error < 0) + return ERR_PTR(error); + + svc_reclassify_socket(sock); + + if (type == SOCK_STREAM) + sock->sk->sk_reuse = 1; /* allow address reuse */ + error = kernel_bind(sock, sin, len); + if (error < 0) + goto bummer; + + newlen = len; + error = kernel_getsockname(sock, newsin, &newlen); + if (error < 0) + goto bummer; + + if (protocol == IPPROTO_TCP) { + if ((error = kernel_listen(sock, 64)) < 0) + goto bummer; + } + + if ((svsk = svc_setup_socket(serv, sock, &error, flags)) != NULL) { + svc_xprt_set_local(&svsk->sk_xprt, newsin, newlen); + return (struct svc_xprt *)svsk; + } + +bummer: + dprintk("svc: svc_create_socket error = %d\n", -error); + sock_release(sock); + return ERR_PTR(error); +} + +/* + * Detach the svc_sock from the socket so that no + * more callbacks occur. + */ +static void svc_sock_detach(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + struct sock *sk = svsk->sk_sk; + + dprintk("svc: svc_sock_detach(%p)\n", svsk); + + /* put back the old socket callbacks */ + sk->sk_state_change = svsk->sk_ostate; + sk->sk_data_ready = svsk->sk_odata; + sk->sk_write_space = svsk->sk_owspace; +} + +/* + * Free the svc_sock's socket resources and the svc_sock itself. + */ +static void svc_sock_free(struct svc_xprt *xprt) +{ + struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); + dprintk("svc: svc_sock_free(%p)\n", svsk); + + if (svsk->sk_sock->file) + sockfd_put(svsk->sk_sock); + else + sock_release(svsk->sk_sock); + kfree(svsk); +} diff --git a/net/sunrpc/sysctl.c b/net/sunrpc/sysctl.c new file mode 100644 index 0000000..5231f7a --- /dev/null +++ b/net/sunrpc/sysctl.c @@ -0,0 +1,184 @@ +/* + * linux/net/sunrpc/sysctl.c + * + * Sysctl interface to sunrpc module. + * + * I would prefer to register the sunrpc table below sys/net, but that's + * impossible at the moment. + */ + +#include <linux/types.h> +#include <linux/linkage.h> +#include <linux/ctype.h> +#include <linux/fs.h> +#include <linux/sysctl.h> +#include <linux/module.h> + +#include <asm/uaccess.h> +#include <linux/sunrpc/types.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/stats.h> +#include <linux/sunrpc/svc_xprt.h> + +/* + * Declare the debug flags here + */ +unsigned int rpc_debug; +EXPORT_SYMBOL_GPL(rpc_debug); + +unsigned int nfs_debug; +EXPORT_SYMBOL_GPL(nfs_debug); + +unsigned int nfsd_debug; +EXPORT_SYMBOL_GPL(nfsd_debug); + +unsigned int nlm_debug; +EXPORT_SYMBOL_GPL(nlm_debug); + +#ifdef RPC_DEBUG + +static struct ctl_table_header *sunrpc_table_header; +static ctl_table sunrpc_table[]; + +void +rpc_register_sysctl(void) +{ + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +} + +void +rpc_unregister_sysctl(void) +{ + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +} + +static int proc_do_xprt(ctl_table *table, int write, struct file *file, + void __user *buffer, size_t *lenp, loff_t *ppos) +{ + char tmpbuf[256]; + size_t len; + + if ((*ppos && !write) || !*lenp) { + *lenp = 0; + return 0; + } + len = svc_print_xprts(tmpbuf, sizeof(tmpbuf)); + return simple_read_from_buffer(buffer, *lenp, ppos, tmpbuf, len); +} + +static int +proc_dodebug(ctl_table *table, int write, struct file *file, + void __user *buffer, size_t *lenp, loff_t *ppos) +{ + char tmpbuf[20], c, *s; + char __user *p; + unsigned int value; + size_t left, len; + + if ((*ppos && !write) || !*lenp) { + *lenp = 0; + return 0; + } + + left = *lenp; + + if (write) { + if (!access_ok(VERIFY_READ, buffer, left)) + return -EFAULT; + p = buffer; + while (left && __get_user(c, p) >= 0 && isspace(c)) + left--, p++; + if (!left) + goto done; + + if (left > sizeof(tmpbuf) - 1) + return -EINVAL; + if (copy_from_user(tmpbuf, p, left)) + return -EFAULT; + tmpbuf[left] = '\0'; + + for (s = tmpbuf, value = 0; '0' <= *s && *s <= '9'; s++, left--) + value = 10 * value + (*s - '0'); + if (*s && !isspace(*s)) + return -EINVAL; + while (left && isspace(*s)) + left--, s++; + *(unsigned int *) table->data = value; + /* Display the RPC tasks on writing to rpc_debug */ + if (strcmp(table->procname, "rpc_debug") == 0) + rpc_show_tasks(); + } else { + if (!access_ok(VERIFY_WRITE, buffer, left)) + return -EFAULT; + len = sprintf(tmpbuf, "%d", *(unsigned int *) table->data); + if (len > left) + len = left; + if (__copy_to_user(buffer, tmpbuf, len)) + return -EFAULT; + if ((left -= len) > 0) { + if (put_user('\n', (char __user *)buffer + len)) + return -EFAULT; + left--; + } + } + +done: + *lenp -= left; + *ppos += *lenp; + return 0; +} + + +static ctl_table debug_table[] = { + { + .procname = "rpc_debug", + .data = &rpc_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dodebug + }, + { + .procname = "nfs_debug", + .data = &nfs_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dodebug + }, + { + .procname = "nfsd_debug", + .data = &nfsd_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dodebug + }, + { + .procname = "nlm_debug", + .data = &nlm_debug, + .maxlen = sizeof(int), + .mode = 0644, + .proc_handler = &proc_dodebug + }, + { + .procname = "transports", + .maxlen = 256, + .mode = 0444, + .proc_handler = &proc_do_xprt, + }, + { .ctl_name = 0 } +}; + +static ctl_table sunrpc_table[] = { + { + .ctl_name = CTL_SUNRPC, + .procname = "sunrpc", + .mode = 0555, + .child = debug_table + }, + { .ctl_name = 0 } +}; + +#endif diff --git a/net/sunrpc/timer.c b/net/sunrpc/timer.c new file mode 100644 index 0000000..31becbf --- /dev/null +++ b/net/sunrpc/timer.c @@ -0,0 +1,109 @@ +/* + * linux/net/sunrpc/timer.c + * + * Estimate RPC request round trip time. + * + * Based on packet round-trip and variance estimator algorithms described + * in appendix A of "Congestion Avoidance and Control" by Van Jacobson + * and Michael J. Karels (ACM Computer Communication Review; Proceedings + * of the Sigcomm '88 Symposium in Stanford, CA, August, 1988). + * + * This RTT estimator is used only for RPC over datagram protocols. + * + * Copyright (C) 2002 Trond Myklebust <trond.myklebust@fys.uio.no> + */ + +#include <asm/param.h> + +#include <linux/types.h> +#include <linux/unistd.h> +#include <linux/module.h> + +#include <linux/sunrpc/clnt.h> + +#define RPC_RTO_MAX (60*HZ) +#define RPC_RTO_INIT (HZ/5) +#define RPC_RTO_MIN (HZ/10) + +void +rpc_init_rtt(struct rpc_rtt *rt, unsigned long timeo) +{ + unsigned long init = 0; + unsigned i; + + rt->timeo = timeo; + + if (timeo > RPC_RTO_INIT) + init = (timeo - RPC_RTO_INIT) << 3; + for (i = 0; i < 5; i++) { + rt->srtt[i] = init; + rt->sdrtt[i] = RPC_RTO_INIT; + rt->ntimeouts[i] = 0; + } +} +EXPORT_SYMBOL_GPL(rpc_init_rtt); + +/* + * NB: When computing the smoothed RTT and standard deviation, + * be careful not to produce negative intermediate results. + */ +void +rpc_update_rtt(struct rpc_rtt *rt, unsigned timer, long m) +{ + long *srtt, *sdrtt; + + if (timer-- == 0) + return; + + /* jiffies wrapped; ignore this one */ + if (m < 0) + return; + + if (m == 0) + m = 1L; + + srtt = (long *)&rt->srtt[timer]; + m -= *srtt >> 3; + *srtt += m; + + if (m < 0) + m = -m; + + sdrtt = (long *)&rt->sdrtt[timer]; + m -= *sdrtt >> 2; + *sdrtt += m; + + /* Set lower bound on the variance */ + if (*sdrtt < RPC_RTO_MIN) + *sdrtt = RPC_RTO_MIN; +} +EXPORT_SYMBOL_GPL(rpc_update_rtt); + +/* + * Estimate rto for an nfs rpc sent via. an unreliable datagram. + * Use the mean and mean deviation of rtt for the appropriate type of rpc + * for the frequent rpcs and a default for the others. + * The justification for doing "other" this way is that these rpcs + * happen so infrequently that timer est. would probably be stale. + * Also, since many of these rpcs are + * non-idempotent, a conservative timeout is desired. + * getattr, lookup, + * read, write, commit - A+4D + * other - timeo + */ + +unsigned long +rpc_calc_rto(struct rpc_rtt *rt, unsigned timer) +{ + unsigned long res; + + if (timer-- == 0) + return rt->timeo; + + res = ((rt->srtt[timer] + 7) >> 3) + rt->sdrtt[timer]; + if (res > RPC_RTO_MAX) + res = RPC_RTO_MAX; + + return res; +} +EXPORT_SYMBOL_GPL(rpc_calc_rto); diff --git a/net/sunrpc/xdr.c b/net/sunrpc/xdr.c new file mode 100644 index 0000000..79a55d5 --- /dev/null +++ b/net/sunrpc/xdr.c @@ -0,0 +1,1110 @@ +/* + * linux/net/sunrpc/xdr.c + * + * Generic XDR support. + * + * Copyright (C) 1995, 1996 Olaf Kirch <okir@monad.swb.de> + */ + +#include <linux/module.h> +#include <linux/types.h> +#include <linux/string.h> +#include <linux/kernel.h> +#include <linux/pagemap.h> +#include <linux/errno.h> +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/msg_prot.h> + +/* + * XDR functions for basic NFS types + */ +__be32 * +xdr_encode_netobj(__be32 *p, const struct xdr_netobj *obj) +{ + unsigned int quadlen = XDR_QUADLEN(obj->len); + + p[quadlen] = 0; /* zero trailing bytes */ + *p++ = htonl(obj->len); + memcpy(p, obj->data, obj->len); + return p + XDR_QUADLEN(obj->len); +} +EXPORT_SYMBOL(xdr_encode_netobj); + +__be32 * +xdr_decode_netobj(__be32 *p, struct xdr_netobj *obj) +{ + unsigned int len; + + if ((len = ntohl(*p++)) > XDR_MAX_NETOBJ) + return NULL; + obj->len = len; + obj->data = (u8 *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL(xdr_decode_netobj); + +/** + * xdr_encode_opaque_fixed - Encode fixed length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. + * + * Copy the array of data of length nbytes at ptr to the XDR buffer + * at position p, then align to the next 32-bit boundary by padding + * with zero bytes (see RFC1832). + * Note: if ptr is NULL, only the padding is performed. + * + * Returns the updated current XDR buffer position + * + */ +__be32 *xdr_encode_opaque_fixed(__be32 *p, const void *ptr, unsigned int nbytes) +{ + if (likely(nbytes != 0)) { + unsigned int quadlen = XDR_QUADLEN(nbytes); + unsigned int padding = (quadlen << 2) - nbytes; + + if (ptr != NULL) + memcpy(p, ptr, nbytes); + if (padding != 0) + memset((char *)p + nbytes, 0, padding); + p += quadlen; + } + return p; +} +EXPORT_SYMBOL(xdr_encode_opaque_fixed); + +/** + * xdr_encode_opaque - Encode variable length opaque data + * @p: pointer to current position in XDR buffer. + * @ptr: pointer to data to encode (or NULL) + * @nbytes: size of data. + * + * Returns the updated current XDR buffer position + */ +__be32 *xdr_encode_opaque(__be32 *p, const void *ptr, unsigned int nbytes) +{ + *p++ = htonl(nbytes); + return xdr_encode_opaque_fixed(p, ptr, nbytes); +} +EXPORT_SYMBOL(xdr_encode_opaque); + +__be32 * +xdr_encode_string(__be32 *p, const char *string) +{ + return xdr_encode_array(p, string, strlen(string)); +} +EXPORT_SYMBOL(xdr_encode_string); + +__be32 * +xdr_decode_string_inplace(__be32 *p, char **sp, + unsigned int *lenp, unsigned int maxlen) +{ + u32 len; + + len = ntohl(*p++); + if (len > maxlen) + return NULL; + *lenp = len; + *sp = (char *) p; + return p + XDR_QUADLEN(len); +} +EXPORT_SYMBOL(xdr_decode_string_inplace); + +void +xdr_encode_pages(struct xdr_buf *xdr, struct page **pages, unsigned int base, + unsigned int len) +{ + struct kvec *tail = xdr->tail; + u32 *p; + + xdr->pages = pages; + xdr->page_base = base; + xdr->page_len = len; + + p = (u32 *)xdr->head[0].iov_base + XDR_QUADLEN(xdr->head[0].iov_len); + tail->iov_base = p; + tail->iov_len = 0; + + if (len & 3) { + unsigned int pad = 4 - (len & 3); + + *p = 0; + tail->iov_base = (char *)p + (len & 3); + tail->iov_len = pad; + len += pad; + } + xdr->buflen += len; + xdr->len += len; +} +EXPORT_SYMBOL(xdr_encode_pages); + +void +xdr_inline_pages(struct xdr_buf *xdr, unsigned int offset, + struct page **pages, unsigned int base, unsigned int len) +{ + struct kvec *head = xdr->head; + struct kvec *tail = xdr->tail; + char *buf = (char *)head->iov_base; + unsigned int buflen = head->iov_len; + + head->iov_len = offset; + + xdr->pages = pages; + xdr->page_base = base; + xdr->page_len = len; + + tail->iov_base = buf + offset; + tail->iov_len = buflen - offset; + + xdr->buflen += len; +} +EXPORT_SYMBOL(xdr_inline_pages); + +/* + * Helper routines for doing 'memmove' like operations on a struct xdr_buf + * + * _shift_data_right_pages + * @pages: vector of pages containing both the source and dest memory area. + * @pgto_base: page vector address of destination + * @pgfrom_base: page vector address of source + * @len: number of bytes to copy + * + * Note: the addresses pgto_base and pgfrom_base are both calculated in + * the same way: + * if a memory area starts at byte 'base' in page 'pages[i]', + * then its address is given as (i << PAGE_CACHE_SHIFT) + base + * Also note: pgfrom_base must be < pgto_base, but the memory areas + * they point to may overlap. + */ +static void +_shift_data_right_pages(struct page **pages, size_t pgto_base, + size_t pgfrom_base, size_t len) +{ + struct page **pgfrom, **pgto; + char *vfrom, *vto; + size_t copy; + + BUG_ON(pgto_base <= pgfrom_base); + + pgto_base += len; + pgfrom_base += len; + + pgto = pages + (pgto_base >> PAGE_CACHE_SHIFT); + pgfrom = pages + (pgfrom_base >> PAGE_CACHE_SHIFT); + + pgto_base &= ~PAGE_CACHE_MASK; + pgfrom_base &= ~PAGE_CACHE_MASK; + + do { + /* Are any pointers crossing a page boundary? */ + if (pgto_base == 0) { + pgto_base = PAGE_CACHE_SIZE; + pgto--; + } + if (pgfrom_base == 0) { + pgfrom_base = PAGE_CACHE_SIZE; + pgfrom--; + } + + copy = len; + if (copy > pgto_base) + copy = pgto_base; + if (copy > pgfrom_base) + copy = pgfrom_base; + pgto_base -= copy; + pgfrom_base -= copy; + + vto = kmap_atomic(*pgto, KM_USER0); + vfrom = kmap_atomic(*pgfrom, KM_USER1); + memmove(vto + pgto_base, vfrom + pgfrom_base, copy); + flush_dcache_page(*pgto); + kunmap_atomic(vfrom, KM_USER1); + kunmap_atomic(vto, KM_USER0); + + } while ((len -= copy) != 0); +} + +/* + * _copy_to_pages + * @pages: array of pages + * @pgbase: page vector address of destination + * @p: pointer to source data + * @len: length + * + * Copies data from an arbitrary memory location into an array of pages + * The copy is assumed to be non-overlapping. + */ +static void +_copy_to_pages(struct page **pages, size_t pgbase, const char *p, size_t len) +{ + struct page **pgto; + char *vto; + size_t copy; + + pgto = pages + (pgbase >> PAGE_CACHE_SHIFT); + pgbase &= ~PAGE_CACHE_MASK; + + for (;;) { + copy = PAGE_CACHE_SIZE - pgbase; + if (copy > len) + copy = len; + + vto = kmap_atomic(*pgto, KM_USER0); + memcpy(vto + pgbase, p, copy); + kunmap_atomic(vto, KM_USER0); + + len -= copy; + if (len == 0) + break; + + pgbase += copy; + if (pgbase == PAGE_CACHE_SIZE) { + flush_dcache_page(*pgto); + pgbase = 0; + pgto++; + } + p += copy; + } + flush_dcache_page(*pgto); +} + +/* + * _copy_from_pages + * @p: pointer to destination + * @pages: array of pages + * @pgbase: offset of source data + * @len: length + * + * Copies data into an arbitrary memory location from an array of pages + * The copy is assumed to be non-overlapping. + */ +static void +_copy_from_pages(char *p, struct page **pages, size_t pgbase, size_t len) +{ + struct page **pgfrom; + char *vfrom; + size_t copy; + + pgfrom = pages + (pgbase >> PAGE_CACHE_SHIFT); + pgbase &= ~PAGE_CACHE_MASK; + + do { + copy = PAGE_CACHE_SIZE - pgbase; + if (copy > len) + copy = len; + + vfrom = kmap_atomic(*pgfrom, KM_USER0); + memcpy(p, vfrom + pgbase, copy); + kunmap_atomic(vfrom, KM_USER0); + + pgbase += copy; + if (pgbase == PAGE_CACHE_SIZE) { + pgbase = 0; + pgfrom++; + } + p += copy; + + } while ((len -= copy) != 0); +} + +/* + * xdr_shrink_bufhead + * @buf: xdr_buf + * @len: bytes to remove from buf->head[0] + * + * Shrinks XDR buffer's header kvec buf->head[0] by + * 'len' bytes. The extra data is not lost, but is instead + * moved into the inlined pages and/or the tail. + */ +static void +xdr_shrink_bufhead(struct xdr_buf *buf, size_t len) +{ + struct kvec *head, *tail; + size_t copy, offs; + unsigned int pglen = buf->page_len; + + tail = buf->tail; + head = buf->head; + BUG_ON (len > head->iov_len); + + /* Shift the tail first */ + if (tail->iov_len != 0) { + if (tail->iov_len > len) { + copy = tail->iov_len - len; + memmove((char *)tail->iov_base + len, + tail->iov_base, copy); + } + /* Copy from the inlined pages into the tail */ + copy = len; + if (copy > pglen) + copy = pglen; + offs = len - copy; + if (offs >= tail->iov_len) + copy = 0; + else if (copy > tail->iov_len - offs) + copy = tail->iov_len - offs; + if (copy != 0) + _copy_from_pages((char *)tail->iov_base + offs, + buf->pages, + buf->page_base + pglen + offs - len, + copy); + /* Do we also need to copy data from the head into the tail ? */ + if (len > pglen) { + offs = copy = len - pglen; + if (copy > tail->iov_len) + copy = tail->iov_len; + memcpy(tail->iov_base, + (char *)head->iov_base + + head->iov_len - offs, + copy); + } + } + /* Now handle pages */ + if (pglen != 0) { + if (pglen > len) + _shift_data_right_pages(buf->pages, + buf->page_base + len, + buf->page_base, + pglen - len); + copy = len; + if (len > pglen) + copy = pglen; + _copy_to_pages(buf->pages, buf->page_base, + (char *)head->iov_base + head->iov_len - len, + copy); + } + head->iov_len -= len; + buf->buflen -= len; + /* Have we truncated the message? */ + if (buf->len > buf->buflen) + buf->len = buf->buflen; +} + +/* + * xdr_shrink_pagelen + * @buf: xdr_buf + * @len: bytes to remove from buf->pages + * + * Shrinks XDR buffer's page array buf->pages by + * 'len' bytes. The extra data is not lost, but is instead + * moved into the tail. + */ +static void +xdr_shrink_pagelen(struct xdr_buf *buf, size_t len) +{ + struct kvec *tail; + size_t copy; + char *p; + unsigned int pglen = buf->page_len; + + tail = buf->tail; + BUG_ON (len > pglen); + + /* Shift the tail first */ + if (tail->iov_len != 0) { + p = (char *)tail->iov_base + len; + if (tail->iov_len > len) { + copy = tail->iov_len - len; + memmove(p, tail->iov_base, copy); + } else + buf->buflen -= len; + /* Copy from the inlined pages into the tail */ + copy = len; + if (copy > tail->iov_len) + copy = tail->iov_len; + _copy_from_pages((char *)tail->iov_base, + buf->pages, buf->page_base + pglen - len, + copy); + } + buf->page_len -= len; + buf->buflen -= len; + /* Have we truncated the message? */ + if (buf->len > buf->buflen) + buf->len = buf->buflen; +} + +void +xdr_shift_buf(struct xdr_buf *buf, size_t len) +{ + xdr_shrink_bufhead(buf, len); +} +EXPORT_SYMBOL(xdr_shift_buf); + +/** + * xdr_init_encode - Initialize a struct xdr_stream for sending data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer in which to encode data + * @p: current pointer inside XDR buffer + * + * Note: at the moment the RPC client only passes the length of our + * scratch buffer in the xdr_buf's header kvec. Previously this + * meant we needed to call xdr_adjust_iovec() after encoding the + * data. With the new scheme, the xdr_stream manages the details + * of the buffer length, and takes care of adjusting the kvec + * length for us. + */ +void xdr_init_encode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +{ + struct kvec *iov = buf->head; + int scratch_len = buf->buflen - buf->page_len - buf->tail[0].iov_len; + + BUG_ON(scratch_len < 0); + xdr->buf = buf; + xdr->iov = iov; + xdr->p = (__be32 *)((char *)iov->iov_base + iov->iov_len); + xdr->end = (__be32 *)((char *)iov->iov_base + scratch_len); + BUG_ON(iov->iov_len > scratch_len); + + if (p != xdr->p && p != NULL) { + size_t len; + + BUG_ON(p < xdr->p || p > xdr->end); + len = (char *)p - (char *)xdr->p; + xdr->p = p; + buf->len += len; + iov->iov_len += len; + } +} +EXPORT_SYMBOL(xdr_init_encode); + +/** + * xdr_reserve_space - Reserve buffer space for sending + * @xdr: pointer to xdr_stream + * @nbytes: number of bytes to reserve + * + * Checks that we have enough buffer space to encode 'nbytes' more + * bytes of data. If so, update the total xdr_buf length, and + * adjust the length of the current kvec. + */ +__be32 * xdr_reserve_space(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p = xdr->p; + __be32 *q; + + /* align nbytes on the next 32-bit boundary */ + nbytes += 3; + nbytes &= ~3; + q = p + (nbytes >> 2); + if (unlikely(q > xdr->end || q < p)) + return NULL; + xdr->p = q; + xdr->iov->iov_len += nbytes; + xdr->buf->len += nbytes; + return p; +} +EXPORT_SYMBOL(xdr_reserve_space); + +/** + * xdr_write_pages - Insert a list of pages into an XDR buffer for sending + * @xdr: pointer to xdr_stream + * @pages: list of pages + * @base: offset of first byte + * @len: length of data in bytes + * + */ +void xdr_write_pages(struct xdr_stream *xdr, struct page **pages, unsigned int base, + unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov = buf->tail; + buf->pages = pages; + buf->page_base = base; + buf->page_len = len; + + iov->iov_base = (char *)xdr->p; + iov->iov_len = 0; + xdr->iov = iov; + + if (len & 3) { + unsigned int pad = 4 - (len & 3); + + BUG_ON(xdr->p >= xdr->end); + iov->iov_base = (char *)xdr->p + (len & 3); + iov->iov_len += pad; + len += pad; + *xdr->p++ = 0; + } + buf->buflen += len; + buf->len += len; +} +EXPORT_SYMBOL(xdr_write_pages); + +/** + * xdr_init_decode - Initialize an xdr_stream for decoding data. + * @xdr: pointer to xdr_stream struct + * @buf: pointer to XDR buffer from which to decode data + * @p: current pointer inside XDR buffer + */ +void xdr_init_decode(struct xdr_stream *xdr, struct xdr_buf *buf, __be32 *p) +{ + struct kvec *iov = buf->head; + unsigned int len = iov->iov_len; + + if (len > buf->len) + len = buf->len; + xdr->buf = buf; + xdr->iov = iov; + xdr->p = p; + xdr->end = (__be32 *)((char *)iov->iov_base + len); +} +EXPORT_SYMBOL(xdr_init_decode); + +/** + * xdr_inline_decode - Retrieve non-page XDR data to decode + * @xdr: pointer to xdr_stream struct + * @nbytes: number of bytes of data to decode + * + * Check if the input buffer is long enough to enable us to decode + * 'nbytes' more bytes of data starting at the current position. + * If so return the current pointer, then update the current + * pointer position. + */ +__be32 * xdr_inline_decode(struct xdr_stream *xdr, size_t nbytes) +{ + __be32 *p = xdr->p; + __be32 *q = p + XDR_QUADLEN(nbytes); + + if (unlikely(q > xdr->end || q < p)) + return NULL; + xdr->p = q; + return p; +} +EXPORT_SYMBOL(xdr_inline_decode); + +/** + * xdr_read_pages - Ensure page-based XDR data to decode is aligned at current pointer position + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + "len" + * bytes is moved into the XDR tail[]. + */ +void xdr_read_pages(struct xdr_stream *xdr, unsigned int len) +{ + struct xdr_buf *buf = xdr->buf; + struct kvec *iov; + ssize_t shift; + unsigned int end; + int padding; + + /* Realign pages to current pointer position */ + iov = buf->head; + shift = iov->iov_len + (char *)iov->iov_base - (char *)xdr->p; + if (shift > 0) + xdr_shrink_bufhead(buf, shift); + + /* Truncate page data and move it into the tail */ + if (buf->page_len > len) + xdr_shrink_pagelen(buf, buf->page_len - len); + padding = (XDR_QUADLEN(len) << 2) - len; + xdr->iov = iov = buf->tail; + /* Compute remaining message length. */ + end = iov->iov_len; + shift = buf->buflen - buf->len; + if (shift < end) + end -= shift; + else if (shift > 0) + end = 0; + /* + * Position current pointer at beginning of tail, and + * set remaining message length. + */ + xdr->p = (__be32 *)((char *)iov->iov_base + padding); + xdr->end = (__be32 *)((char *)iov->iov_base + end); +} +EXPORT_SYMBOL(xdr_read_pages); + +/** + * xdr_enter_page - decode data from the XDR page + * @xdr: pointer to xdr_stream struct + * @len: number of bytes of page data + * + * Moves data beyond the current pointer position from the XDR head[] buffer + * into the page list. Any data that lies beyond current position + "len" + * bytes is moved into the XDR tail[]. The current pointer is then + * repositioned at the beginning of the first XDR page. + */ +void xdr_enter_page(struct xdr_stream *xdr, unsigned int len) +{ + char * kaddr = page_address(xdr->buf->pages[0]); + xdr_read_pages(xdr, len); + /* + * Position current pointer at beginning of tail, and + * set remaining message length. + */ + if (len > PAGE_CACHE_SIZE - xdr->buf->page_base) + len = PAGE_CACHE_SIZE - xdr->buf->page_base; + xdr->p = (__be32 *)(kaddr + xdr->buf->page_base); + xdr->end = (__be32 *)((char *)xdr->p + len); +} +EXPORT_SYMBOL(xdr_enter_page); + +static struct kvec empty_iov = {.iov_base = NULL, .iov_len = 0}; + +void +xdr_buf_from_iov(struct kvec *iov, struct xdr_buf *buf) +{ + buf->head[0] = *iov; + buf->tail[0] = empty_iov; + buf->page_len = 0; + buf->buflen = buf->len = iov->iov_len; +} +EXPORT_SYMBOL(xdr_buf_from_iov); + +/* Sets subbuf to the portion of buf of length len beginning base bytes + * from the start of buf. Returns -1 if base of length are out of bounds. */ +int +xdr_buf_subsegment(struct xdr_buf *buf, struct xdr_buf *subbuf, + unsigned int base, unsigned int len) +{ + subbuf->buflen = subbuf->len = len; + if (base < buf->head[0].iov_len) { + subbuf->head[0].iov_base = buf->head[0].iov_base + base; + subbuf->head[0].iov_len = min_t(unsigned int, len, + buf->head[0].iov_len - base); + len -= subbuf->head[0].iov_len; + base = 0; + } else { + subbuf->head[0].iov_base = NULL; + subbuf->head[0].iov_len = 0; + base -= buf->head[0].iov_len; + } + + if (base < buf->page_len) { + subbuf->page_len = min(buf->page_len - base, len); + base += buf->page_base; + subbuf->page_base = base & ~PAGE_CACHE_MASK; + subbuf->pages = &buf->pages[base >> PAGE_CACHE_SHIFT]; + len -= subbuf->page_len; + base = 0; + } else { + base -= buf->page_len; + subbuf->page_len = 0; + } + + if (base < buf->tail[0].iov_len) { + subbuf->tail[0].iov_base = buf->tail[0].iov_base + base; + subbuf->tail[0].iov_len = min_t(unsigned int, len, + buf->tail[0].iov_len - base); + len -= subbuf->tail[0].iov_len; + base = 0; + } else { + subbuf->tail[0].iov_base = NULL; + subbuf->tail[0].iov_len = 0; + base -= buf->tail[0].iov_len; + } + + if (base || len) + return -1; + return 0; +} +EXPORT_SYMBOL(xdr_buf_subsegment); + +static void __read_bytes_from_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(obj, subbuf->head[0].iov_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + if (this_len) + _copy_from_pages(obj, subbuf->pages, subbuf->page_base, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(obj, subbuf->tail[0].iov_base, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int read_bytes_from_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __read_bytes_from_xdr_buf(&subbuf, obj, len); + return 0; +} +EXPORT_SYMBOL(read_bytes_from_xdr_buf); + +static void __write_bytes_to_xdr_buf(struct xdr_buf *subbuf, void *obj, unsigned int len) +{ + unsigned int this_len; + + this_len = min_t(unsigned int, len, subbuf->head[0].iov_len); + memcpy(subbuf->head[0].iov_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->page_len); + if (this_len) + _copy_to_pages(subbuf->pages, subbuf->page_base, obj, this_len); + len -= this_len; + obj += this_len; + this_len = min_t(unsigned int, len, subbuf->tail[0].iov_len); + memcpy(subbuf->tail[0].iov_base, obj, this_len); +} + +/* obj is assumed to point to allocated memory of size at least len: */ +int write_bytes_to_xdr_buf(struct xdr_buf *buf, unsigned int base, void *obj, unsigned int len) +{ + struct xdr_buf subbuf; + int status; + + status = xdr_buf_subsegment(buf, &subbuf, base, len); + if (status != 0) + return status; + __write_bytes_to_xdr_buf(&subbuf, obj, len); + return 0; +} + +int +xdr_decode_word(struct xdr_buf *buf, unsigned int base, u32 *obj) +{ + __be32 raw; + int status; + + status = read_bytes_from_xdr_buf(buf, base, &raw, sizeof(*obj)); + if (status) + return status; + *obj = ntohl(raw); + return 0; +} +EXPORT_SYMBOL(xdr_decode_word); + +int +xdr_encode_word(struct xdr_buf *buf, unsigned int base, u32 obj) +{ + __be32 raw = htonl(obj); + + return write_bytes_to_xdr_buf(buf, base, &raw, sizeof(obj)); +} +EXPORT_SYMBOL(xdr_encode_word); + +/* If the netobj starting offset bytes from the start of xdr_buf is contained + * entirely in the head or the tail, set object to point to it; otherwise + * try to find space for it at the end of the tail, copy it there, and + * set obj to point to it. */ +int xdr_buf_read_netobj(struct xdr_buf *buf, struct xdr_netobj *obj, unsigned int offset) +{ + struct xdr_buf subbuf; + + if (xdr_decode_word(buf, offset, &obj->len)) + return -EFAULT; + if (xdr_buf_subsegment(buf, &subbuf, offset + 4, obj->len)) + return -EFAULT; + + /* Is the obj contained entirely in the head? */ + obj->data = subbuf.head[0].iov_base; + if (subbuf.head[0].iov_len == obj->len) + return 0; + /* ..or is the obj contained entirely in the tail? */ + obj->data = subbuf.tail[0].iov_base; + if (subbuf.tail[0].iov_len == obj->len) + return 0; + + /* use end of tail as storage for obj: + * (We don't copy to the beginning because then we'd have + * to worry about doing a potentially overlapping copy. + * This assumes the object is at most half the length of the + * tail.) */ + if (obj->len > buf->buflen - buf->len) + return -ENOMEM; + if (buf->tail[0].iov_len != 0) + obj->data = buf->tail[0].iov_base + buf->tail[0].iov_len; + else + obj->data = buf->head[0].iov_base + buf->head[0].iov_len; + __read_bytes_from_xdr_buf(&subbuf, obj->data, obj->len); + return 0; +} +EXPORT_SYMBOL(xdr_buf_read_netobj); + +/* Returns 0 on success, or else a negative error code. */ +static int +xdr_xcode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc, int encode) +{ + char *elem = NULL, *c; + unsigned int copied = 0, todo, avail_here; + struct page **ppages = NULL; + int err; + + if (encode) { + if (xdr_encode_word(buf, base, desc->array_len) != 0) + return -EINVAL; + } else { + if (xdr_decode_word(buf, base, &desc->array_len) != 0 || + desc->array_len > desc->array_maxlen || + (unsigned long) base + 4 + desc->array_len * + desc->elem_size > buf->len) + return -EINVAL; + } + base += 4; + + if (!desc->xcode) + return 0; + + todo = desc->array_len * desc->elem_size; + + /* process head */ + if (todo && base < buf->head->iov_len) { + c = buf->head->iov_base + base; + avail_here = min_t(unsigned int, todo, + buf->head->iov_len - base); + todo -= avail_here; + + while (avail_here >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_here -= desc->elem_size; + } + if (avail_here) { + if (!elem) { + elem = kmalloc(desc->elem_size, GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + err = desc->xcode(desc, elem); + if (err) + goto out; + memcpy(c, elem, avail_here); + } else + memcpy(elem, c, avail_here); + copied = avail_here; + } + base = buf->head->iov_len; /* align to start of pages */ + } + + /* process pages array */ + base -= buf->head->iov_len; + if (todo && base < buf->page_len) { + unsigned int avail_page; + + avail_here = min(todo, buf->page_len - base); + todo -= avail_here; + + base += buf->page_base; + ppages = buf->pages + (base >> PAGE_CACHE_SHIFT); + base &= ~PAGE_CACHE_MASK; + avail_page = min_t(unsigned int, PAGE_CACHE_SIZE - base, + avail_here); + c = kmap(*ppages) + base; + + while (avail_here) { + avail_here -= avail_page; + if (copied || avail_page < desc->elem_size) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + avail_page -= l; + c += l; + } + while (avail_page >= desc->elem_size) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + avail_page -= desc->elem_size; + } + if (avail_page) { + unsigned int l = min(avail_page, + desc->elem_size - copied); + if (!elem) { + elem = kmalloc(desc->elem_size, + GFP_KERNEL); + err = -ENOMEM; + if (!elem) + goto out; + } + if (encode) { + if (!copied) { + err = desc->xcode(desc, elem); + if (err) + goto out; + } + memcpy(c, elem + copied, l); + copied += l; + if (copied == desc->elem_size) + copied = 0; + } else { + memcpy(elem + copied, c, l); + copied += l; + if (copied == desc->elem_size) { + err = desc->xcode(desc, elem); + if (err) + goto out; + copied = 0; + } + } + } + if (avail_here) { + kunmap(*ppages); + ppages++; + c = kmap(*ppages); + } + + avail_page = min(avail_here, + (unsigned int) PAGE_CACHE_SIZE); + } + base = buf->page_len; /* align to start of tail */ + } + + /* process tail */ + base -= buf->page_len; + if (todo) { + c = buf->tail->iov_base + base; + if (copied) { + unsigned int l = desc->elem_size - copied; + + if (encode) + memcpy(c, elem + copied, l); + else { + memcpy(elem + copied, c, l); + err = desc->xcode(desc, elem); + if (err) + goto out; + } + todo -= l; + c += l; + } + while (todo) { + err = desc->xcode(desc, c); + if (err) + goto out; + c += desc->elem_size; + todo -= desc->elem_size; + } + } + err = 0; + +out: + kfree(elem); + if (ppages) + kunmap(*ppages); + return err; +} + +int +xdr_decode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if (base >= buf->len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 0); +} +EXPORT_SYMBOL(xdr_decode_array2); + +int +xdr_encode_array2(struct xdr_buf *buf, unsigned int base, + struct xdr_array2_desc *desc) +{ + if ((unsigned long) base + 4 + desc->array_len * desc->elem_size > + buf->head->iov_len + buf->page_len + buf->tail->iov_len) + return -EINVAL; + + return xdr_xcode_array2(buf, base, desc, 1); +} +EXPORT_SYMBOL(xdr_encode_array2); + +int +xdr_process_buf(struct xdr_buf *buf, unsigned int offset, unsigned int len, + int (*actor)(struct scatterlist *, void *), void *data) +{ + int i, ret = 0; + unsigned page_len, thislen, page_offset; + struct scatterlist sg[1]; + + sg_init_table(sg, 1); + + if (offset >= buf->head[0].iov_len) { + offset -= buf->head[0].iov_len; + } else { + thislen = buf->head[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->head[0].iov_base + offset, thislen); + ret = actor(sg, data); + if (ret) + goto out; + offset = 0; + len -= thislen; + } + if (len == 0) + goto out; + + if (offset >= buf->page_len) { + offset -= buf->page_len; + } else { + page_len = buf->page_len - offset; + if (page_len > len) + page_len = len; + len -= page_len; + page_offset = (offset + buf->page_base) & (PAGE_CACHE_SIZE - 1); + i = (offset + buf->page_base) >> PAGE_CACHE_SHIFT; + thislen = PAGE_CACHE_SIZE - page_offset; + do { + if (thislen > page_len) + thislen = page_len; + sg_set_page(sg, buf->pages[i], thislen, page_offset); + ret = actor(sg, data); + if (ret) + goto out; + page_len -= thislen; + i++; + page_offset = 0; + thislen = PAGE_CACHE_SIZE; + } while (page_len != 0); + offset = 0; + } + if (len == 0) + goto out; + if (offset < buf->tail[0].iov_len) { + thislen = buf->tail[0].iov_len - offset; + if (thislen > len) + thislen = len; + sg_set_buf(sg, buf->tail[0].iov_base + offset, thislen); + ret = actor(sg, data); + len -= thislen; + } + if (len != 0) + ret = -EINVAL; +out: + return ret; +} +EXPORT_SYMBOL(xdr_process_buf); + diff --git a/net/sunrpc/xprt.c b/net/sunrpc/xprt.c new file mode 100644 index 0000000..29e401b --- /dev/null +++ b/net/sunrpc/xprt.c @@ -0,0 +1,1105 @@ +/* + * linux/net/sunrpc/xprt.c + * + * This is a generic RPC call interface supporting congestion avoidance, + * and asynchronous calls. + * + * The interface works like this: + * + * - When a process places a call, it allocates a request slot if + * one is available. Otherwise, it sleeps on the backlog queue + * (xprt_reserve). + * - Next, the caller puts together the RPC message, stuffs it into + * the request struct, and calls xprt_transmit(). + * - xprt_transmit sends the message and installs the caller on the + * transport's wait list. At the same time, it installs a timer that + * is run after the packet's timeout has expired. + * - When a packet arrives, the data_ready handler walks the list of + * pending requests for that transport. If a matching XID is found, the + * caller is woken up, and the timer removed. + * - When no reply arrives within the timeout interval, the timer is + * fired by the kernel and runs xprt_timer(). It either adjusts the + * timeout values (minor timeout) or wakes up the caller with a status + * of -ETIMEDOUT. + * - When the caller receives a notification from RPC that a reply arrived, + * it should release the RPC slot, and process the reply. + * If the call timed out, it may choose to retry the operation by + * adjusting the initial timeout value, and simply calling rpc_call + * again. + * + * Support for async RPC is done through a set of RPC-specific scheduling + * primitives that `transparently' work for processes as well as async + * tasks that rely on callbacks. + * + * Copyright (C) 1995-1997, Olaf Kirch <okir@monad.swb.de> + * + * Transport switch API copyright (C) 2005, Chuck Lever <cel@netapp.com> + */ + +#include <linux/module.h> + +#include <linux/types.h> +#include <linux/interrupt.h> +#include <linux/workqueue.h> +#include <linux/net.h> + +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/metrics.h> + +/* + * Local variables + */ + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_XPRT +#endif + +/* + * Local functions + */ +static void xprt_request_init(struct rpc_task *, struct rpc_xprt *); +static inline void do_xprt_reserve(struct rpc_task *); +static void xprt_connect_status(struct rpc_task *task); +static int __xprt_get_cong(struct rpc_xprt *, struct rpc_task *); + +static DEFINE_SPINLOCK(xprt_list_lock); +static LIST_HEAD(xprt_list); + +/* + * The transport code maintains an estimate on the maximum number of out- + * standing RPC requests, using a smoothed version of the congestion + * avoidance implemented in 44BSD. This is basically the Van Jacobson + * congestion algorithm: If a retransmit occurs, the congestion window is + * halved; otherwise, it is incremented by 1/cwnd when + * + * - a reply is received and + * - a full number of requests are outstanding and + * - the congestion window hasn't been updated recently. + */ +#define RPC_CWNDSHIFT (8U) +#define RPC_CWNDSCALE (1U << RPC_CWNDSHIFT) +#define RPC_INITCWND RPC_CWNDSCALE +#define RPC_MAXCWND(xprt) ((xprt)->max_reqs << RPC_CWNDSHIFT) + +#define RPCXPRT_CONGESTED(xprt) ((xprt)->cong >= (xprt)->cwnd) + +/** + * xprt_register_transport - register a transport implementation + * @transport: transport to register + * + * If a transport implementation is loaded as a kernel module, it can + * call this interface to make itself known to the RPC client. + * + * Returns: + * 0: transport successfully registered + * -EEXIST: transport already registered + * -EINVAL: transport module being unloaded + */ +int xprt_register_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = -EEXIST; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + /* don't register the same transport class twice */ + if (t->ident == transport->ident) + goto out; + } + + list_add_tail(&transport->list, &xprt_list); + printk(KERN_INFO "RPC: Registered %s transport module.\n", + transport->name); + result = 0; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_register_transport); + +/** + * xprt_unregister_transport - unregister a transport implementation + * @transport: transport to unregister + * + * Returns: + * 0: transport successfully unregistered + * -ENOENT: transport never registered + */ +int xprt_unregister_transport(struct xprt_class *transport) +{ + struct xprt_class *t; + int result; + + result = 0; + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + if (t == transport) { + printk(KERN_INFO + "RPC: Unregistered %s transport module.\n", + transport->name); + list_del_init(&transport->list); + goto out; + } + } + result = -ENOENT; + +out: + spin_unlock(&xprt_list_lock); + return result; +} +EXPORT_SYMBOL_GPL(xprt_unregister_transport); + +/** + * xprt_reserve_xprt - serialize write access to transports + * @task: task that is requesting access to the transport + * + * This prevents mixing the payload of separate requests, and prevents + * transport connects from colliding with writes. No congestion control + * is provided. + */ +int xprt_reserve_xprt(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + return 1; + if (task == NULL) + return 0; + goto out_sleep; + } + xprt->snd_task = task; + if (req) { + req->rq_bytes_sent = 0; + req->rq_ntrans++; + } + return 1; + +out_sleep: + dprintk("RPC: %5u failed to lock transport %p\n", + task->tk_pid, xprt); + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + if (req && req->rq_ntrans) + rpc_sleep_on(&xprt->resend, task, NULL); + else + rpc_sleep_on(&xprt->sending, task, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt); + +static void xprt_clear_locked(struct rpc_xprt *xprt) +{ + xprt->snd_task = NULL; + if (!test_bit(XPRT_CLOSE_WAIT, &xprt->state) || xprt->shutdown) { + smp_mb__before_clear_bit(); + clear_bit(XPRT_LOCKED, &xprt->state); + smp_mb__after_clear_bit(); + } else + queue_work(rpciod_workqueue, &xprt->task_cleanup); +} + +/* + * xprt_reserve_xprt_cong - serialize write access to transports + * @task: task that is requesting access to the transport + * + * Same as xprt_reserve_xprt, but Van Jacobson congestion control is + * integrated into the decision of whether a request is allowed to be + * woken up and given access to the transport. + */ +int xprt_reserve_xprt_cong(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req = task->tk_rqstp; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) { + if (task == xprt->snd_task) + return 1; + goto out_sleep; + } + if (__xprt_get_cong(xprt, task)) { + xprt->snd_task = task; + if (req) { + req->rq_bytes_sent = 0; + req->rq_ntrans++; + } + return 1; + } + xprt_clear_locked(xprt); +out_sleep: + dprintk("RPC: %5u failed to lock transport %p\n", task->tk_pid, xprt); + task->tk_timeout = 0; + task->tk_status = -EAGAIN; + if (req && req->rq_ntrans) + rpc_sleep_on(&xprt->resend, task, NULL); + else + rpc_sleep_on(&xprt->sending, task, NULL); + return 0; +} +EXPORT_SYMBOL_GPL(xprt_reserve_xprt_cong); + +static inline int xprt_lock_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + int retval; + + spin_lock_bh(&xprt->transport_lock); + retval = xprt->ops->reserve_xprt(task); + spin_unlock_bh(&xprt->transport_lock); + return retval; +} + +static void __xprt_lock_write_next(struct rpc_xprt *xprt) +{ + struct rpc_task *task; + struct rpc_rqst *req; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + + task = rpc_wake_up_next(&xprt->resend); + if (!task) { + task = rpc_wake_up_next(&xprt->sending); + if (!task) + goto out_unlock; + } + + req = task->tk_rqstp; + xprt->snd_task = task; + if (req) { + req->rq_bytes_sent = 0; + req->rq_ntrans++; + } + return; + +out_unlock: + xprt_clear_locked(xprt); +} + +static void __xprt_lock_write_next_cong(struct rpc_xprt *xprt) +{ + struct rpc_task *task; + + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + return; + if (RPCXPRT_CONGESTED(xprt)) + goto out_unlock; + task = rpc_wake_up_next(&xprt->resend); + if (!task) { + task = rpc_wake_up_next(&xprt->sending); + if (!task) + goto out_unlock; + } + if (__xprt_get_cong(xprt, task)) { + struct rpc_rqst *req = task->tk_rqstp; + xprt->snd_task = task; + if (req) { + req->rq_bytes_sent = 0; + req->rq_ntrans++; + } + return; + } +out_unlock: + xprt_clear_locked(xprt); +} + +/** + * xprt_release_xprt - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. No congestion control is provided. + */ +void xprt_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_clear_locked(xprt); + __xprt_lock_write_next(xprt); + } +} +EXPORT_SYMBOL_GPL(xprt_release_xprt); + +/** + * xprt_release_xprt_cong - allow other requests to use a transport + * @xprt: transport with other tasks potentially waiting + * @task: task that is releasing access to the transport + * + * Note that "task" can be NULL. Another task is awoken to use the + * transport if the transport's congestion window allows it. + */ +void xprt_release_xprt_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + if (xprt->snd_task == task) { + xprt_clear_locked(xprt); + __xprt_lock_write_next_cong(xprt); + } +} +EXPORT_SYMBOL_GPL(xprt_release_xprt_cong); + +static inline void xprt_release_write(struct rpc_xprt *xprt, struct rpc_task *task) +{ + spin_lock_bh(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + spin_unlock_bh(&xprt->transport_lock); +} + +/* + * Van Jacobson congestion avoidance. Check if the congestion window + * overflowed. Put the task to sleep if this is the case. + */ +static int +__xprt_get_cong(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + + if (req->rq_cong) + return 1; + dprintk("RPC: %5u xprt_cwnd_limited cong = %lu cwnd = %lu\n", + task->tk_pid, xprt->cong, xprt->cwnd); + if (RPCXPRT_CONGESTED(xprt)) + return 0; + req->rq_cong = 1; + xprt->cong += RPC_CWNDSCALE; + return 1; +} + +/* + * Adjust the congestion window, and wake up the next task + * that has been sleeping due to congestion + */ +static void +__xprt_put_cong(struct rpc_xprt *xprt, struct rpc_rqst *req) +{ + if (!req->rq_cong) + return; + req->rq_cong = 0; + xprt->cong -= RPC_CWNDSCALE; + __xprt_lock_write_next_cong(xprt); +} + +/** + * xprt_release_rqst_cong - housekeeping when request is complete + * @task: RPC request that recently completed + * + * Useful for transports that require congestion control. + */ +void xprt_release_rqst_cong(struct rpc_task *task) +{ + __xprt_put_cong(task->tk_xprt, task->tk_rqstp); +} +EXPORT_SYMBOL_GPL(xprt_release_rqst_cong); + +/** + * xprt_adjust_cwnd - adjust transport congestion window + * @task: recently completed RPC request used to adjust window + * @result: result code of completed RPC request + * + * We use a time-smoothed congestion estimator to avoid heavy oscillation. + */ +void xprt_adjust_cwnd(struct rpc_task *task, int result) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = task->tk_xprt; + unsigned long cwnd = xprt->cwnd; + + if (result >= 0 && cwnd <= xprt->cong) { + /* The (cwnd >> 1) term makes sure + * the result gets rounded properly. */ + cwnd += (RPC_CWNDSCALE * RPC_CWNDSCALE + (cwnd >> 1)) / cwnd; + if (cwnd > RPC_MAXCWND(xprt)) + cwnd = RPC_MAXCWND(xprt); + __xprt_lock_write_next_cong(xprt); + } else if (result == -ETIMEDOUT) { + cwnd >>= 1; + if (cwnd < RPC_CWNDSCALE) + cwnd = RPC_CWNDSCALE; + } + dprintk("RPC: cong %ld, cwnd was %ld, now %ld\n", + xprt->cong, xprt->cwnd, cwnd); + xprt->cwnd = cwnd; + __xprt_put_cong(xprt, req); +} +EXPORT_SYMBOL_GPL(xprt_adjust_cwnd); + +/** + * xprt_wake_pending_tasks - wake all tasks on a transport's pending queue + * @xprt: transport with waiting tasks + * @status: result code to plant in each task before waking it + * + */ +void xprt_wake_pending_tasks(struct rpc_xprt *xprt, int status) +{ + if (status < 0) + rpc_wake_up_status(&xprt->pending, status); + else + rpc_wake_up(&xprt->pending); +} +EXPORT_SYMBOL_GPL(xprt_wake_pending_tasks); + +/** + * xprt_wait_for_buffer_space - wait for transport output buffer to clear + * @task: task to be put to sleep + * @action: function pointer to be executed after wait + */ +void xprt_wait_for_buffer_space(struct rpc_task *task, rpc_action action) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + task->tk_timeout = req->rq_timeout; + rpc_sleep_on(&xprt->pending, task, action); +} +EXPORT_SYMBOL_GPL(xprt_wait_for_buffer_space); + +/** + * xprt_write_space - wake the task waiting for transport output buffer space + * @xprt: transport with waiting tasks + * + * Can be called in a soft IRQ context, so xprt_write_space never sleeps. + */ +void xprt_write_space(struct rpc_xprt *xprt) +{ + if (unlikely(xprt->shutdown)) + return; + + spin_lock_bh(&xprt->transport_lock); + if (xprt->snd_task) { + dprintk("RPC: write space: waking waiting task on " + "xprt %p\n", xprt); + rpc_wake_up_queued_task(&xprt->pending, xprt->snd_task); + } + spin_unlock_bh(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_write_space); + +/** + * xprt_set_retrans_timeout_def - set a request's retransmit timeout + * @task: task whose timeout is to be set + * + * Set a request's retransmit timeout based on the transport's + * default timeout parameters. Used by transports that don't adjust + * the retransmit timeout based on round-trip time estimation. + */ +void xprt_set_retrans_timeout_def(struct rpc_task *task) +{ + task->tk_timeout = task->tk_rqstp->rq_timeout; +} +EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_def); + +/* + * xprt_set_retrans_timeout_rtt - set a request's retransmit timeout + * @task: task whose timeout is to be set + * + * Set a request's retransmit timeout using the RTT estimator. + */ +void xprt_set_retrans_timeout_rtt(struct rpc_task *task) +{ + int timer = task->tk_msg.rpc_proc->p_timer; + struct rpc_clnt *clnt = task->tk_client; + struct rpc_rtt *rtt = clnt->cl_rtt; + struct rpc_rqst *req = task->tk_rqstp; + unsigned long max_timeout = clnt->cl_timeout->to_maxval; + + task->tk_timeout = rpc_calc_rto(rtt, timer); + task->tk_timeout <<= rpc_ntimeo(rtt, timer) + req->rq_retries; + if (task->tk_timeout > max_timeout || task->tk_timeout == 0) + task->tk_timeout = max_timeout; +} +EXPORT_SYMBOL_GPL(xprt_set_retrans_timeout_rtt); + +static void xprt_reset_majortimeo(struct rpc_rqst *req) +{ + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + + req->rq_majortimeo = req->rq_timeout; + if (to->to_exponential) + req->rq_majortimeo <<= to->to_retries; + else + req->rq_majortimeo += to->to_increment * to->to_retries; + if (req->rq_majortimeo > to->to_maxval || req->rq_majortimeo == 0) + req->rq_majortimeo = to->to_maxval; + req->rq_majortimeo += jiffies; +} + +/** + * xprt_adjust_timeout - adjust timeout values for next retransmit + * @req: RPC request containing parameters to use for the adjustment + * + */ +int xprt_adjust_timeout(struct rpc_rqst *req) +{ + struct rpc_xprt *xprt = req->rq_xprt; + const struct rpc_timeout *to = req->rq_task->tk_client->cl_timeout; + int status = 0; + + if (time_before(jiffies, req->rq_majortimeo)) { + if (to->to_exponential) + req->rq_timeout <<= 1; + else + req->rq_timeout += to->to_increment; + if (to->to_maxval && req->rq_timeout >= to->to_maxval) + req->rq_timeout = to->to_maxval; + req->rq_retries++; + } else { + req->rq_timeout = to->to_initval; + req->rq_retries = 0; + xprt_reset_majortimeo(req); + /* Reset the RTT counters == "slow start" */ + spin_lock_bh(&xprt->transport_lock); + rpc_init_rtt(req->rq_task->tk_client->cl_rtt, to->to_initval); + spin_unlock_bh(&xprt->transport_lock); + status = -ETIMEDOUT; + } + + if (req->rq_timeout == 0) { + printk(KERN_WARNING "xprt_adjust_timeout: rq_timeout = 0!\n"); + req->rq_timeout = 5 * HZ; + } + return status; +} + +static void xprt_autoclose(struct work_struct *work) +{ + struct rpc_xprt *xprt = + container_of(work, struct rpc_xprt, task_cleanup); + + xprt->ops->close(xprt); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + xprt_release_write(xprt, NULL); +} + +/** + * xprt_disconnect_done - mark a transport as disconnected + * @xprt: transport to flag for disconnect + * + */ +void xprt_disconnect_done(struct rpc_xprt *xprt) +{ + dprintk("RPC: disconnected transport %p\n", xprt); + spin_lock_bh(&xprt->transport_lock); + xprt_clear_connected(xprt); + xprt_wake_pending_tasks(xprt, -ENOTCONN); + spin_unlock_bh(&xprt->transport_lock); +} +EXPORT_SYMBOL_GPL(xprt_disconnect_done); + +/** + * xprt_force_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * + */ +void xprt_force_disconnect(struct rpc_xprt *xprt) +{ + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock_bh(&xprt->transport_lock); + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + /* Try to schedule an autoclose RPC call */ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(rpciod_workqueue, &xprt->task_cleanup); + xprt_wake_pending_tasks(xprt, -ENOTCONN); + spin_unlock_bh(&xprt->transport_lock); +} + +/** + * xprt_conditional_disconnect - force a transport to disconnect + * @xprt: transport to disconnect + * @cookie: 'connection cookie' + * + * This attempts to break the connection if and only if 'cookie' matches + * the current transport 'connection cookie'. It ensures that we don't + * try to break the connection more than once when we need to retransmit + * a batch of RPC requests. + * + */ +void xprt_conditional_disconnect(struct rpc_xprt *xprt, unsigned int cookie) +{ + /* Don't race with the test_bit() in xprt_clear_locked() */ + spin_lock_bh(&xprt->transport_lock); + if (cookie != xprt->connect_cookie) + goto out; + if (test_bit(XPRT_CLOSING, &xprt->state) || !xprt_connected(xprt)) + goto out; + set_bit(XPRT_CLOSE_WAIT, &xprt->state); + /* Try to schedule an autoclose RPC call */ + if (test_and_set_bit(XPRT_LOCKED, &xprt->state) == 0) + queue_work(rpciod_workqueue, &xprt->task_cleanup); + xprt_wake_pending_tasks(xprt, -ENOTCONN); +out: + spin_unlock_bh(&xprt->transport_lock); +} + +static void +xprt_init_autodisconnect(unsigned long data) +{ + struct rpc_xprt *xprt = (struct rpc_xprt *)data; + + spin_lock(&xprt->transport_lock); + if (!list_empty(&xprt->recv) || xprt->shutdown) + goto out_abort; + if (test_and_set_bit(XPRT_LOCKED, &xprt->state)) + goto out_abort; + spin_unlock(&xprt->transport_lock); + if (xprt_connecting(xprt)) + xprt_release_write(xprt, NULL); + else + queue_work(rpciod_workqueue, &xprt->task_cleanup); + return; +out_abort: + spin_unlock(&xprt->transport_lock); +} + +/** + * xprt_connect - schedule a transport connect operation + * @task: RPC task that is requesting the connect + * + */ +void xprt_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + dprintk("RPC: %5u xprt_connect xprt %p %s connected\n", task->tk_pid, + xprt, (xprt_connected(xprt) ? "is" : "is not")); + + if (!xprt_bound(xprt)) { + task->tk_status = -EIO; + return; + } + if (!xprt_lock_write(xprt, task)) + return; + if (xprt_connected(xprt)) + xprt_release_write(xprt, task); + else { + if (task->tk_rqstp) + task->tk_rqstp->rq_bytes_sent = 0; + + task->tk_timeout = xprt->connect_timeout; + rpc_sleep_on(&xprt->pending, task, xprt_connect_status); + xprt->stat.connect_start = jiffies; + xprt->ops->connect(task); + } + return; +} + +static void xprt_connect_status(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + if (task->tk_status == 0) { + xprt->stat.connect_count++; + xprt->stat.connect_time += (long)jiffies - xprt->stat.connect_start; + dprintk("RPC: %5u xprt_connect_status: connection established\n", + task->tk_pid); + return; + } + + switch (task->tk_status) { + case -ENOTCONN: + dprintk("RPC: %5u xprt_connect_status: connection broken\n", + task->tk_pid); + break; + case -ETIMEDOUT: + dprintk("RPC: %5u xprt_connect_status: connect attempt timed " + "out\n", task->tk_pid); + break; + default: + dprintk("RPC: %5u xprt_connect_status: error %d connecting to " + "server %s\n", task->tk_pid, -task->tk_status, + task->tk_client->cl_server); + xprt_release_write(xprt, task); + task->tk_status = -EIO; + } +} + +/** + * xprt_lookup_rqst - find an RPC request corresponding to an XID + * @xprt: transport on which the original request was transmitted + * @xid: RPC XID of incoming reply + * + */ +struct rpc_rqst *xprt_lookup_rqst(struct rpc_xprt *xprt, __be32 xid) +{ + struct list_head *pos; + + list_for_each(pos, &xprt->recv) { + struct rpc_rqst *entry = list_entry(pos, struct rpc_rqst, rq_list); + if (entry->rq_xid == xid) + return entry; + } + + dprintk("RPC: xprt_lookup_rqst did not find xid %08x\n", + ntohl(xid)); + xprt->stat.bad_xids++; + return NULL; +} +EXPORT_SYMBOL_GPL(xprt_lookup_rqst); + +/** + * xprt_update_rtt - update an RPC client's RTT state after receiving a reply + * @task: RPC request that recently completed + * + */ +void xprt_update_rtt(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_rtt *rtt = task->tk_client->cl_rtt; + unsigned timer = task->tk_msg.rpc_proc->p_timer; + + if (timer) { + if (req->rq_ntrans == 1) + rpc_update_rtt(rtt, timer, + (long)jiffies - req->rq_xtime); + rpc_set_timeo(rtt, timer, req->rq_ntrans - 1); + } +} +EXPORT_SYMBOL_GPL(xprt_update_rtt); + +/** + * xprt_complete_rqst - called when reply processing is complete + * @task: RPC request that recently completed + * @copied: actual number of bytes received from the transport + * + * Caller holds transport lock. + */ +void xprt_complete_rqst(struct rpc_task *task, int copied) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + dprintk("RPC: %5u xid %08x complete (%d bytes received)\n", + task->tk_pid, ntohl(req->rq_xid), copied); + + xprt->stat.recvs++; + task->tk_rtt = (long)jiffies - req->rq_xtime; + + list_del_init(&req->rq_list); + req->rq_private_buf.len = copied; + /* Ensure all writes are done before we update req->rq_received */ + smp_wmb(); + req->rq_received = copied; + rpc_wake_up_queued_task(&xprt->pending, task); +} +EXPORT_SYMBOL_GPL(xprt_complete_rqst); + +static void xprt_timer(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + + if (task->tk_status != -ETIMEDOUT) + return; + dprintk("RPC: %5u xprt_timer\n", task->tk_pid); + + spin_lock_bh(&xprt->transport_lock); + if (!req->rq_received) { + if (xprt->ops->timer) + xprt->ops->timer(task); + } else + task->tk_status = 0; + spin_unlock_bh(&xprt->transport_lock); +} + +/** + * xprt_prepare_transmit - reserve the transport before sending a request + * @task: RPC task about to send a request + * + */ +int xprt_prepare_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int err = 0; + + dprintk("RPC: %5u xprt_prepare_transmit\n", task->tk_pid); + + spin_lock_bh(&xprt->transport_lock); + if (req->rq_received && !req->rq_bytes_sent) { + err = req->rq_received; + goto out_unlock; + } + if (!xprt->ops->reserve_xprt(task)) { + err = -EAGAIN; + goto out_unlock; + } + + if (!xprt_connected(xprt)) { + err = -ENOTCONN; + goto out_unlock; + } +out_unlock: + spin_unlock_bh(&xprt->transport_lock); + return err; +} + +void xprt_end_transmit(struct rpc_task *task) +{ + xprt_release_write(task->tk_xprt, task); +} + +/** + * xprt_transmit - send an RPC request on a transport + * @task: controlling RPC task + * + * We have to copy the iovec because sendmsg fiddles with its contents. + */ +void xprt_transmit(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + int status; + + dprintk("RPC: %5u xprt_transmit(%u)\n", task->tk_pid, req->rq_slen); + + if (!req->rq_received) { + if (list_empty(&req->rq_list)) { + spin_lock_bh(&xprt->transport_lock); + /* Update the softirq receive buffer */ + memcpy(&req->rq_private_buf, &req->rq_rcv_buf, + sizeof(req->rq_private_buf)); + /* Add request to the receive list */ + list_add_tail(&req->rq_list, &xprt->recv); + spin_unlock_bh(&xprt->transport_lock); + xprt_reset_majortimeo(req); + /* Turn off autodisconnect */ + del_singleshot_timer_sync(&xprt->timer); + } + } else if (!req->rq_bytes_sent) + return; + + req->rq_connect_cookie = xprt->connect_cookie; + req->rq_xtime = jiffies; + status = xprt->ops->send_request(task); + if (status == 0) { + dprintk("RPC: %5u xmit complete\n", task->tk_pid); + spin_lock_bh(&xprt->transport_lock); + + xprt->ops->set_retrans_timeout(task); + + xprt->stat.sends++; + xprt->stat.req_u += xprt->stat.sends - xprt->stat.recvs; + xprt->stat.bklog_u += xprt->backlog.qlen; + + /* Don't race with disconnect */ + if (!xprt_connected(xprt)) + task->tk_status = -ENOTCONN; + else if (!req->rq_received) + rpc_sleep_on(&xprt->pending, task, xprt_timer); + spin_unlock_bh(&xprt->transport_lock); + return; + } + + /* Note: at this point, task->tk_sleeping has not yet been set, + * hence there is no danger of the waking up task being put on + * schedq, and being picked up by a parallel run of rpciod(). + */ + task->tk_status = status; + if (status == -ECONNREFUSED) + rpc_sleep_on(&xprt->sending, task, NULL); +} + +static inline void do_xprt_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = 0; + if (task->tk_rqstp) + return; + if (!list_empty(&xprt->free)) { + struct rpc_rqst *req = list_entry(xprt->free.next, struct rpc_rqst, rq_list); + list_del_init(&req->rq_list); + task->tk_rqstp = req; + xprt_request_init(task, xprt); + return; + } + dprintk("RPC: waiting for request slot\n"); + task->tk_status = -EAGAIN; + task->tk_timeout = 0; + rpc_sleep_on(&xprt->backlog, task, NULL); +} + +/** + * xprt_reserve - allocate an RPC request slot + * @task: RPC task requesting a slot allocation + * + * If no more slots are available, place the task on the transport's + * backlog queue. + */ +void xprt_reserve(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + task->tk_status = -EIO; + spin_lock(&xprt->reserve_lock); + do_xprt_reserve(task); + spin_unlock(&xprt->reserve_lock); +} + +static inline __be32 xprt_alloc_xid(struct rpc_xprt *xprt) +{ + return xprt->xid++; +} + +static inline void xprt_init_xid(struct rpc_xprt *xprt) +{ + xprt->xid = net_random(); +} + +static void xprt_request_init(struct rpc_task *task, struct rpc_xprt *xprt) +{ + struct rpc_rqst *req = task->tk_rqstp; + + req->rq_timeout = task->tk_client->cl_timeout->to_initval; + req->rq_task = task; + req->rq_xprt = xprt; + req->rq_buffer = NULL; + req->rq_xid = xprt_alloc_xid(xprt); + req->rq_release_snd_buf = NULL; + xprt_reset_majortimeo(req); + dprintk("RPC: %5u reserved req %p xid %08x\n", task->tk_pid, + req, ntohl(req->rq_xid)); +} + +/** + * xprt_release - release an RPC request slot + * @task: task which is finished with the slot + * + */ +void xprt_release(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpc_rqst *req; + + if (!(req = task->tk_rqstp)) + return; + rpc_count_iostats(task); + spin_lock_bh(&xprt->transport_lock); + xprt->ops->release_xprt(xprt, task); + if (xprt->ops->release_request) + xprt->ops->release_request(task); + if (!list_empty(&req->rq_list)) + list_del(&req->rq_list); + xprt->last_used = jiffies; + if (list_empty(&xprt->recv)) + mod_timer(&xprt->timer, + xprt->last_used + xprt->idle_timeout); + spin_unlock_bh(&xprt->transport_lock); + xprt->ops->buf_free(req->rq_buffer); + task->tk_rqstp = NULL; + if (req->rq_release_snd_buf) + req->rq_release_snd_buf(req); + memset(req, 0, sizeof(*req)); /* mark unused */ + + dprintk("RPC: %5u release request %p\n", task->tk_pid, req); + + spin_lock(&xprt->reserve_lock); + list_add(&req->rq_list, &xprt->free); + rpc_wake_up_next(&xprt->backlog); + spin_unlock(&xprt->reserve_lock); +} + +/** + * xprt_create_transport - create an RPC transport + * @args: rpc transport creation arguments + * + */ +struct rpc_xprt *xprt_create_transport(struct xprt_create *args) +{ + struct rpc_xprt *xprt; + struct rpc_rqst *req; + struct xprt_class *t; + + spin_lock(&xprt_list_lock); + list_for_each_entry(t, &xprt_list, list) { + if (t->ident == args->ident) { + spin_unlock(&xprt_list_lock); + goto found; + } + } + spin_unlock(&xprt_list_lock); + printk(KERN_ERR "RPC: transport (%d) not supported\n", args->ident); + return ERR_PTR(-EIO); + +found: + xprt = t->setup(args); + if (IS_ERR(xprt)) { + dprintk("RPC: xprt_create_transport: failed, %ld\n", + -PTR_ERR(xprt)); + return xprt; + } + + kref_init(&xprt->kref); + spin_lock_init(&xprt->transport_lock); + spin_lock_init(&xprt->reserve_lock); + + INIT_LIST_HEAD(&xprt->free); + INIT_LIST_HEAD(&xprt->recv); + INIT_WORK(&xprt->task_cleanup, xprt_autoclose); + setup_timer(&xprt->timer, xprt_init_autodisconnect, + (unsigned long)xprt); + xprt->last_used = jiffies; + xprt->cwnd = RPC_INITCWND; + xprt->bind_index = 0; + + rpc_init_wait_queue(&xprt->binding, "xprt_binding"); + rpc_init_wait_queue(&xprt->pending, "xprt_pending"); + rpc_init_wait_queue(&xprt->sending, "xprt_sending"); + rpc_init_wait_queue(&xprt->resend, "xprt_resend"); + rpc_init_priority_wait_queue(&xprt->backlog, "xprt_backlog"); + + /* initialize free list */ + for (req = &xprt->slot[xprt->max_reqs-1]; req >= &xprt->slot[0]; req--) + list_add(&req->rq_list, &xprt->free); + + xprt_init_xid(xprt); + + dprintk("RPC: created transport %p with %u slots\n", xprt, + xprt->max_reqs); + + return xprt; +} + +/** + * xprt_destroy - destroy an RPC transport, killing off all requests. + * @kref: kref for the transport to destroy + * + */ +static void xprt_destroy(struct kref *kref) +{ + struct rpc_xprt *xprt = container_of(kref, struct rpc_xprt, kref); + + dprintk("RPC: destroying transport %p\n", xprt); + xprt->shutdown = 1; + del_timer_sync(&xprt->timer); + + rpc_destroy_wait_queue(&xprt->binding); + rpc_destroy_wait_queue(&xprt->pending); + rpc_destroy_wait_queue(&xprt->sending); + rpc_destroy_wait_queue(&xprt->resend); + rpc_destroy_wait_queue(&xprt->backlog); + /* + * Tear down transport state and free the rpc_xprt + */ + xprt->ops->destroy(xprt); +} + +/** + * xprt_put - release a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +void xprt_put(struct rpc_xprt *xprt) +{ + kref_put(&xprt->kref, xprt_destroy); +} + +/** + * xprt_get - return a reference to an RPC transport. + * @xprt: pointer to the transport + * + */ +struct rpc_xprt *xprt_get(struct rpc_xprt *xprt) +{ + kref_get(&xprt->kref); + return xprt; +} diff --git a/net/sunrpc/xprtrdma/Makefile b/net/sunrpc/xprtrdma/Makefile new file mode 100644 index 0000000..5a8f268 --- /dev/null +++ b/net/sunrpc/xprtrdma/Makefile @@ -0,0 +1,8 @@ +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += xprtrdma.o + +xprtrdma-y := transport.o rpc_rdma.o verbs.o + +obj-$(CONFIG_SUNRPC_XPRT_RDMA) += svcrdma.o + +svcrdma-y := svc_rdma.o svc_rdma_transport.o \ + svc_rdma_marshal.o svc_rdma_sendto.o svc_rdma_recvfrom.o diff --git a/net/sunrpc/xprtrdma/rpc_rdma.c b/net/sunrpc/xprtrdma/rpc_rdma.c new file mode 100644 index 0000000..14106d2 --- /dev/null +++ b/net/sunrpc/xprtrdma/rpc_rdma.c @@ -0,0 +1,880 @@ +/* + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * rpc_rdma.c + * + * This file contains the guts of the RPC RDMA protocol, and + * does marshaling/unmarshaling, etc. It is also where interfacing + * to the Linux RPC framework lives. + */ + +#include "xprt_rdma.h" + +#include <linux/highmem.h> + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +enum rpcrdma_chunktype { + rpcrdma_noch = 0, + rpcrdma_readch, + rpcrdma_areadch, + rpcrdma_writech, + rpcrdma_replych +}; + +#ifdef RPC_DEBUG +static const char transfertypes[][12] = { + "pure inline", /* no chunks */ + " read chunk", /* some argument via rdma read */ + "*read chunk", /* entire request via rdma read */ + "write chunk", /* some result via rdma write */ + "reply chunk" /* entire reply via rdma write */ +}; +#endif + +/* + * Chunk assembly from upper layer xdr_buf. + * + * Prepare the passed-in xdr_buf into representation as RPC/RDMA chunk + * elements. Segments are then coalesced when registered, if possible + * within the selected memreg mode. + * + * Note, this routine is never called if the connection's memory + * registration strategy is 0 (bounce buffers). + */ + +static int +rpcrdma_convert_iovs(struct xdr_buf *xdrbuf, unsigned int pos, + enum rpcrdma_chunktype type, struct rpcrdma_mr_seg *seg, int nsegs) +{ + int len, n = 0, p; + + if (pos == 0 && xdrbuf->head[0].iov_len) { + seg[n].mr_page = NULL; + seg[n].mr_offset = xdrbuf->head[0].iov_base; + seg[n].mr_len = xdrbuf->head[0].iov_len; + ++n; + } + + if (xdrbuf->page_len && (xdrbuf->pages[0] != NULL)) { + if (n == nsegs) + return 0; + seg[n].mr_page = xdrbuf->pages[0]; + seg[n].mr_offset = (void *)(unsigned long) xdrbuf->page_base; + seg[n].mr_len = min_t(u32, + PAGE_SIZE - xdrbuf->page_base, xdrbuf->page_len); + len = xdrbuf->page_len - seg[n].mr_len; + ++n; + p = 1; + while (len > 0) { + if (n == nsegs) + return 0; + seg[n].mr_page = xdrbuf->pages[p]; + seg[n].mr_offset = NULL; + seg[n].mr_len = min_t(u32, PAGE_SIZE, len); + len -= seg[n].mr_len; + ++n; + ++p; + } + } + + if (xdrbuf->tail[0].iov_len) { + /* the rpcrdma protocol allows us to omit any trailing + * xdr pad bytes, saving the server an RDMA operation. */ + if (xdrbuf->tail[0].iov_len < 4 && xprt_rdma_pad_optimize) + return n; + if (n == nsegs) + return 0; + seg[n].mr_page = NULL; + seg[n].mr_offset = xdrbuf->tail[0].iov_base; + seg[n].mr_len = xdrbuf->tail[0].iov_len; + ++n; + } + + return n; +} + +/* + * Create read/write chunk lists, and reply chunks, for RDMA + * + * Assume check against THRESHOLD has been done, and chunks are required. + * Assume only encoding one list entry for read|write chunks. The NFSv3 + * protocol is simple enough to allow this as it only has a single "bulk + * result" in each procedure - complicated NFSv4 COMPOUNDs are not. (The + * RDMA/Sessions NFSv4 proposal addresses this for future v4 revs.) + * + * When used for a single reply chunk (which is a special write + * chunk used for the entire reply, rather than just the data), it + * is used primarily for READDIR and READLINK which would otherwise + * be severely size-limited by a small rdma inline read max. The server + * response will come back as an RDMA Write, followed by a message + * of type RDMA_NOMSG carrying the xid and length. As a result, reply + * chunks do not provide data alignment, however they do not require + * "fixup" (moving the response to the upper layer buffer) either. + * + * Encoding key for single-list chunks (HLOO = Handle32 Length32 Offset64): + * + * Read chunklist (a linked list): + * N elements, position P (same P for all chunks of same arg!): + * 1 - PHLOO - 1 - PHLOO - ... - 1 - PHLOO - 0 + * + * Write chunklist (a list of (one) counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO - 0 + * + * Reply chunk (a counted array): + * N elements: + * 1 - N - HLOO - HLOO - ... - HLOO + */ + +static unsigned int +rpcrdma_create_chunks(struct rpc_rqst *rqst, struct xdr_buf *target, + struct rpcrdma_msg *headerp, enum rpcrdma_chunktype type) +{ + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_task->tk_xprt); + int nsegs, nchunks = 0; + unsigned int pos; + struct rpcrdma_mr_seg *seg = req->rl_segments; + struct rpcrdma_read_chunk *cur_rchunk = NULL; + struct rpcrdma_write_array *warray = NULL; + struct rpcrdma_write_chunk *cur_wchunk = NULL; + __be32 *iptr = headerp->rm_body.rm_chunks; + + if (type == rpcrdma_readch || type == rpcrdma_areadch) { + /* a read chunk - server will RDMA Read our memory */ + cur_rchunk = (struct rpcrdma_read_chunk *) iptr; + } else { + /* a write or reply chunk - server will RDMA Write our memory */ + *iptr++ = xdr_zero; /* encode a NULL read chunk list */ + if (type == rpcrdma_replych) + *iptr++ = xdr_zero; /* a NULL write chunk list */ + warray = (struct rpcrdma_write_array *) iptr; + cur_wchunk = (struct rpcrdma_write_chunk *) (warray + 1); + } + + if (type == rpcrdma_replych || type == rpcrdma_areadch) + pos = 0; + else + pos = target->head[0].iov_len; + + nsegs = rpcrdma_convert_iovs(target, pos, type, seg, RPCRDMA_MAX_SEGS); + if (nsegs == 0) + return 0; + + do { + /* bind/register the memory, then build chunk from result. */ + int n = rpcrdma_register_external(seg, nsegs, + cur_wchunk != NULL, r_xprt); + if (n <= 0) + goto out; + if (cur_rchunk) { /* read */ + cur_rchunk->rc_discrim = xdr_one; + /* all read chunks have the same "position" */ + cur_rchunk->rc_position = htonl(pos); + cur_rchunk->rc_target.rs_handle = htonl(seg->mr_rkey); + cur_rchunk->rc_target.rs_length = htonl(seg->mr_len); + xdr_encode_hyper( + (__be32 *)&cur_rchunk->rc_target.rs_offset, + seg->mr_base); + dprintk("RPC: %s: read chunk " + "elem %d@0x%llx:0x%x pos %u (%s)\n", __func__, + seg->mr_len, (unsigned long long)seg->mr_base, + seg->mr_rkey, pos, n < nsegs ? "more" : "last"); + cur_rchunk++; + r_xprt->rx_stats.read_chunk_count++; + } else { /* write/reply */ + cur_wchunk->wc_target.rs_handle = htonl(seg->mr_rkey); + cur_wchunk->wc_target.rs_length = htonl(seg->mr_len); + xdr_encode_hyper( + (__be32 *)&cur_wchunk->wc_target.rs_offset, + seg->mr_base); + dprintk("RPC: %s: %s chunk " + "elem %d@0x%llx:0x%x (%s)\n", __func__, + (type == rpcrdma_replych) ? "reply" : "write", + seg->mr_len, (unsigned long long)seg->mr_base, + seg->mr_rkey, n < nsegs ? "more" : "last"); + cur_wchunk++; + if (type == rpcrdma_replych) + r_xprt->rx_stats.reply_chunk_count++; + else + r_xprt->rx_stats.write_chunk_count++; + r_xprt->rx_stats.total_rdma_request += seg->mr_len; + } + nchunks++; + seg += n; + nsegs -= n; + } while (nsegs); + + /* success. all failures return above */ + req->rl_nchunks = nchunks; + + BUG_ON(nchunks == 0); + + /* + * finish off header. If write, marshal discrim and nchunks. + */ + if (cur_rchunk) { + iptr = (__be32 *) cur_rchunk; + *iptr++ = xdr_zero; /* finish the read chunk list */ + *iptr++ = xdr_zero; /* encode a NULL write chunk list */ + *iptr++ = xdr_zero; /* encode a NULL reply chunk */ + } else { + warray->wc_discrim = xdr_one; + warray->wc_nchunks = htonl(nchunks); + iptr = (__be32 *) cur_wchunk; + if (type == rpcrdma_writech) { + *iptr++ = xdr_zero; /* finish the write chunk list */ + *iptr++ = xdr_zero; /* encode a NULL reply chunk */ + } + } + + /* + * Return header size. + */ + return (unsigned char *)iptr - (unsigned char *)headerp; + +out: + for (pos = 0; nchunks--;) + pos += rpcrdma_deregister_external( + &req->rl_segments[pos], r_xprt, NULL); + return 0; +} + +/* + * Copy write data inline. + * This function is used for "small" requests. Data which is passed + * to RPC via iovecs (or page list) is copied directly into the + * pre-registered memory buffer for this request. For small amounts + * of data, this is efficient. The cutoff value is tunable. + */ +static int +rpcrdma_inline_pullup(struct rpc_rqst *rqst, int pad) +{ + int i, npages, curlen; + int copy_len; + unsigned char *srcp, *destp; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(rqst->rq_xprt); + + destp = rqst->rq_svec[0].iov_base; + curlen = rqst->rq_svec[0].iov_len; + destp += curlen; + /* + * Do optional padding where it makes sense. Alignment of write + * payload can help the server, if our setting is accurate. + */ + pad -= (curlen + 36/*sizeof(struct rpcrdma_msg_padded)*/); + if (pad < 0 || rqst->rq_slen - curlen < RPCRDMA_INLINE_PAD_THRESH) + pad = 0; /* don't pad this request */ + + dprintk("RPC: %s: pad %d destp 0x%p len %d hdrlen %d\n", + __func__, pad, destp, rqst->rq_slen, curlen); + + copy_len = rqst->rq_snd_buf.page_len; + r_xprt->rx_stats.pullup_copy_count += copy_len; + npages = PAGE_ALIGN(rqst->rq_snd_buf.page_base+copy_len) >> PAGE_SHIFT; + for (i = 0; copy_len && i < npages; i++) { + if (i == 0) + curlen = PAGE_SIZE - rqst->rq_snd_buf.page_base; + else + curlen = PAGE_SIZE; + if (curlen > copy_len) + curlen = copy_len; + dprintk("RPC: %s: page %d destp 0x%p len %d curlen %d\n", + __func__, i, destp, copy_len, curlen); + srcp = kmap_atomic(rqst->rq_snd_buf.pages[i], + KM_SKB_SUNRPC_DATA); + if (i == 0) + memcpy(destp, srcp+rqst->rq_snd_buf.page_base, curlen); + else + memcpy(destp, srcp, curlen); + kunmap_atomic(srcp, KM_SKB_SUNRPC_DATA); + rqst->rq_svec[0].iov_len += curlen; + destp += curlen; + copy_len -= curlen; + } + if (rqst->rq_snd_buf.tail[0].iov_len) { + curlen = rqst->rq_snd_buf.tail[0].iov_len; + if (destp != rqst->rq_snd_buf.tail[0].iov_base) { + memcpy(destp, + rqst->rq_snd_buf.tail[0].iov_base, curlen); + r_xprt->rx_stats.pullup_copy_count += curlen; + } + dprintk("RPC: %s: tail destp 0x%p len %d curlen %d\n", + __func__, destp, copy_len, curlen); + rqst->rq_svec[0].iov_len += curlen; + } + /* header now contains entire send message */ + return pad; +} + +/* + * Marshal a request: the primary job of this routine is to choose + * the transfer modes. See comments below. + * + * Uses multiple RDMA IOVs for a request: + * [0] -- RPC RDMA header, which uses memory from the *start* of the + * preregistered buffer that already holds the RPC data in + * its middle. + * [1] -- the RPC header/data, marshaled by RPC and the NFS protocol. + * [2] -- optional padding. + * [3] -- if padded, header only in [1] and data here. + */ + +int +rpcrdma_marshal_req(struct rpc_rqst *rqst) +{ + struct rpc_xprt *xprt = rqst->rq_task->tk_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + char *base; + size_t hdrlen, rpclen, padlen; + enum rpcrdma_chunktype rtype, wtype; + struct rpcrdma_msg *headerp; + + /* + * rpclen gets amount of data in first buffer, which is the + * pre-registered buffer. + */ + base = rqst->rq_svec[0].iov_base; + rpclen = rqst->rq_svec[0].iov_len; + + /* build RDMA header in private area at front */ + headerp = (struct rpcrdma_msg *) req->rl_base; + /* don't htonl XID, it's already done in request */ + headerp->rm_xid = rqst->rq_xid; + headerp->rm_vers = xdr_one; + headerp->rm_credit = htonl(r_xprt->rx_buf.rb_max_requests); + headerp->rm_type = htonl(RDMA_MSG); + + /* + * Chunks needed for results? + * + * o If the expected result is under the inline threshold, all ops + * return as inline (but see later). + * o Large non-read ops return as a single reply chunk. + * o Large read ops return data as write chunk(s), header as inline. + * + * Note: the NFS code sending down multiple result segments implies + * the op is one of read, readdir[plus], readlink or NFSv4 getacl. + */ + + /* + * This code can handle read chunks, write chunks OR reply + * chunks -- only one type. If the request is too big to fit + * inline, then we will choose read chunks. If the request is + * a READ, then use write chunks to separate the file data + * into pages; otherwise use reply chunks. + */ + if (rqst->rq_rcv_buf.buflen <= RPCRDMA_INLINE_READ_THRESHOLD(rqst)) + wtype = rpcrdma_noch; + else if (rqst->rq_rcv_buf.page_len == 0) + wtype = rpcrdma_replych; + else if (rqst->rq_rcv_buf.flags & XDRBUF_READ) + wtype = rpcrdma_writech; + else + wtype = rpcrdma_replych; + + /* + * Chunks needed for arguments? + * + * o If the total request is under the inline threshold, all ops + * are sent as inline. + * o Large non-write ops are sent with the entire message as a + * single read chunk (protocol 0-position special case). + * o Large write ops transmit data as read chunk(s), header as + * inline. + * + * Note: the NFS code sending down multiple argument segments + * implies the op is a write. + * TBD check NFSv4 setacl + */ + if (rqst->rq_snd_buf.len <= RPCRDMA_INLINE_WRITE_THRESHOLD(rqst)) + rtype = rpcrdma_noch; + else if (rqst->rq_snd_buf.page_len == 0) + rtype = rpcrdma_areadch; + else + rtype = rpcrdma_readch; + + /* The following simplification is not true forever */ + if (rtype != rpcrdma_noch && wtype == rpcrdma_replych) + wtype = rpcrdma_noch; + BUG_ON(rtype != rpcrdma_noch && wtype != rpcrdma_noch); + + if (r_xprt->rx_ia.ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS && + (rtype != rpcrdma_noch || wtype != rpcrdma_noch)) { + /* forced to "pure inline"? */ + dprintk("RPC: %s: too much data (%d/%d) for inline\n", + __func__, rqst->rq_rcv_buf.len, rqst->rq_snd_buf.len); + return -1; + } + + hdrlen = 28; /*sizeof *headerp;*/ + padlen = 0; + + /* + * Pull up any extra send data into the preregistered buffer. + * When padding is in use and applies to the transfer, insert + * it and change the message type. + */ + if (rtype == rpcrdma_noch) { + + padlen = rpcrdma_inline_pullup(rqst, + RPCRDMA_INLINE_PAD_VALUE(rqst)); + + if (padlen) { + headerp->rm_type = htonl(RDMA_MSGP); + headerp->rm_body.rm_padded.rm_align = + htonl(RPCRDMA_INLINE_PAD_VALUE(rqst)); + headerp->rm_body.rm_padded.rm_thresh = + htonl(RPCRDMA_INLINE_PAD_THRESH); + headerp->rm_body.rm_padded.rm_pempty[0] = xdr_zero; + headerp->rm_body.rm_padded.rm_pempty[1] = xdr_zero; + headerp->rm_body.rm_padded.rm_pempty[2] = xdr_zero; + hdrlen += 2 * sizeof(u32); /* extra words in padhdr */ + BUG_ON(wtype != rpcrdma_noch); + + } else { + headerp->rm_body.rm_nochunks.rm_empty[0] = xdr_zero; + headerp->rm_body.rm_nochunks.rm_empty[1] = xdr_zero; + headerp->rm_body.rm_nochunks.rm_empty[2] = xdr_zero; + /* new length after pullup */ + rpclen = rqst->rq_svec[0].iov_len; + /* + * Currently we try to not actually use read inline. + * Reply chunks have the desirable property that + * they land, packed, directly in the target buffers + * without headers, so they require no fixup. The + * additional RDMA Write op sends the same amount + * of data, streams on-the-wire and adds no overhead + * on receive. Therefore, we request a reply chunk + * for non-writes wherever feasible and efficient. + */ + if (wtype == rpcrdma_noch && + r_xprt->rx_ia.ri_memreg_strategy > RPCRDMA_REGISTER) + wtype = rpcrdma_replych; + } + } + + /* + * Marshal chunks. This routine will return the header length + * consumed by marshaling. + */ + if (rtype != rpcrdma_noch) { + hdrlen = rpcrdma_create_chunks(rqst, + &rqst->rq_snd_buf, headerp, rtype); + wtype = rtype; /* simplify dprintk */ + + } else if (wtype != rpcrdma_noch) { + hdrlen = rpcrdma_create_chunks(rqst, + &rqst->rq_rcv_buf, headerp, wtype); + } + + if (hdrlen == 0) + return -1; + + dprintk("RPC: %s: %s: hdrlen %zd rpclen %zd padlen %zd" + " headerp 0x%p base 0x%p lkey 0x%x\n", + __func__, transfertypes[wtype], hdrlen, rpclen, padlen, + headerp, base, req->rl_iov.lkey); + + /* + * initialize send_iov's - normally only two: rdma chunk header and + * single preregistered RPC header buffer, but if padding is present, + * then use a preregistered (and zeroed) pad buffer between the RPC + * header and any write data. In all non-rdma cases, any following + * data has been copied into the RPC header buffer. + */ + req->rl_send_iov[0].addr = req->rl_iov.addr; + req->rl_send_iov[0].length = hdrlen; + req->rl_send_iov[0].lkey = req->rl_iov.lkey; + + req->rl_send_iov[1].addr = req->rl_iov.addr + (base - req->rl_base); + req->rl_send_iov[1].length = rpclen; + req->rl_send_iov[1].lkey = req->rl_iov.lkey; + + req->rl_niovs = 2; + + if (padlen) { + struct rpcrdma_ep *ep = &r_xprt->rx_ep; + + req->rl_send_iov[2].addr = ep->rep_pad.addr; + req->rl_send_iov[2].length = padlen; + req->rl_send_iov[2].lkey = ep->rep_pad.lkey; + + req->rl_send_iov[3].addr = req->rl_send_iov[1].addr + rpclen; + req->rl_send_iov[3].length = rqst->rq_slen - rpclen; + req->rl_send_iov[3].lkey = req->rl_iov.lkey; + + req->rl_niovs = 4; + } + + return 0; +} + +/* + * Chase down a received write or reply chunklist to get length + * RDMA'd by server. See map at rpcrdma_create_chunks()! :-) + */ +static int +rpcrdma_count_chunks(struct rpcrdma_rep *rep, unsigned int max, int wrchunk, __be32 **iptrp) +{ + unsigned int i, total_len; + struct rpcrdma_write_chunk *cur_wchunk; + + i = ntohl(**iptrp); /* get array count */ + if (i > max) + return -1; + cur_wchunk = (struct rpcrdma_write_chunk *) (*iptrp + 1); + total_len = 0; + while (i--) { + struct rpcrdma_segment *seg = &cur_wchunk->wc_target; + ifdebug(FACILITY) { + u64 off; + xdr_decode_hyper((__be32 *)&seg->rs_offset, &off); + dprintk("RPC: %s: chunk %d@0x%llx:0x%x\n", + __func__, + ntohl(seg->rs_length), + (unsigned long long)off, + ntohl(seg->rs_handle)); + } + total_len += ntohl(seg->rs_length); + ++cur_wchunk; + } + /* check and adjust for properly terminated write chunk */ + if (wrchunk) { + __be32 *w = (__be32 *) cur_wchunk; + if (*w++ != xdr_zero) + return -1; + cur_wchunk = (struct rpcrdma_write_chunk *) w; + } + if ((char *) cur_wchunk > rep->rr_base + rep->rr_len) + return -1; + + *iptrp = (__be32 *) cur_wchunk; + return total_len; +} + +/* + * Scatter inline received data back into provided iov's. + */ +static void +rpcrdma_inline_fixup(struct rpc_rqst *rqst, char *srcp, int copy_len, int pad) +{ + int i, npages, curlen, olen; + char *destp; + + curlen = rqst->rq_rcv_buf.head[0].iov_len; + if (curlen > copy_len) { /* write chunk header fixup */ + curlen = copy_len; + rqst->rq_rcv_buf.head[0].iov_len = curlen; + } + + dprintk("RPC: %s: srcp 0x%p len %d hdrlen %d\n", + __func__, srcp, copy_len, curlen); + + /* Shift pointer for first receive segment only */ + rqst->rq_rcv_buf.head[0].iov_base = srcp; + srcp += curlen; + copy_len -= curlen; + + olen = copy_len; + i = 0; + rpcx_to_rdmax(rqst->rq_xprt)->rx_stats.fixup_copy_count += olen; + if (copy_len && rqst->rq_rcv_buf.page_len) { + npages = PAGE_ALIGN(rqst->rq_rcv_buf.page_base + + rqst->rq_rcv_buf.page_len) >> PAGE_SHIFT; + for (; i < npages; i++) { + if (i == 0) + curlen = PAGE_SIZE - rqst->rq_rcv_buf.page_base; + else + curlen = PAGE_SIZE; + if (curlen > copy_len) + curlen = copy_len; + dprintk("RPC: %s: page %d" + " srcp 0x%p len %d curlen %d\n", + __func__, i, srcp, copy_len, curlen); + destp = kmap_atomic(rqst->rq_rcv_buf.pages[i], + KM_SKB_SUNRPC_DATA); + if (i == 0) + memcpy(destp + rqst->rq_rcv_buf.page_base, + srcp, curlen); + else + memcpy(destp, srcp, curlen); + flush_dcache_page(rqst->rq_rcv_buf.pages[i]); + kunmap_atomic(destp, KM_SKB_SUNRPC_DATA); + srcp += curlen; + copy_len -= curlen; + if (copy_len == 0) + break; + } + rqst->rq_rcv_buf.page_len = olen - copy_len; + } else + rqst->rq_rcv_buf.page_len = 0; + + if (copy_len && rqst->rq_rcv_buf.tail[0].iov_len) { + curlen = copy_len; + if (curlen > rqst->rq_rcv_buf.tail[0].iov_len) + curlen = rqst->rq_rcv_buf.tail[0].iov_len; + if (rqst->rq_rcv_buf.tail[0].iov_base != srcp) + memcpy(rqst->rq_rcv_buf.tail[0].iov_base, srcp, curlen); + dprintk("RPC: %s: tail srcp 0x%p len %d curlen %d\n", + __func__, srcp, copy_len, curlen); + rqst->rq_rcv_buf.tail[0].iov_len = curlen; + copy_len -= curlen; ++i; + } else + rqst->rq_rcv_buf.tail[0].iov_len = 0; + + if (pad) { + /* implicit padding on terminal chunk */ + unsigned char *p = rqst->rq_rcv_buf.tail[0].iov_base; + while (pad--) + p[rqst->rq_rcv_buf.tail[0].iov_len++] = 0; + } + + if (copy_len) + dprintk("RPC: %s: %d bytes in" + " %d extra segments (%d lost)\n", + __func__, olen, i, copy_len); + + /* TBD avoid a warning from call_decode() */ + rqst->rq_private_buf = rqst->rq_rcv_buf; +} + +/* + * This function is called when an async event is posted to + * the connection which changes the connection state. All it + * does at this point is mark the connection up/down, the rpc + * timers do the rest. + */ +void +rpcrdma_conn_func(struct rpcrdma_ep *ep) +{ + struct rpc_xprt *xprt = ep->rep_xprt; + + spin_lock_bh(&xprt->transport_lock); + if (++xprt->connect_cookie == 0) /* maintain a reserved value */ + ++xprt->connect_cookie; + if (ep->rep_connected > 0) { + if (!xprt_test_and_set_connected(xprt)) + xprt_wake_pending_tasks(xprt, 0); + } else { + if (xprt_test_and_clear_connected(xprt)) + xprt_wake_pending_tasks(xprt, -ENOTCONN); + } + spin_unlock_bh(&xprt->transport_lock); +} + +/* + * This function is called when memory window unbind which we are waiting + * for completes. Just use rr_func (zeroed by upcall) to signal completion. + */ +static void +rpcrdma_unbind_func(struct rpcrdma_rep *rep) +{ + wake_up(&rep->rr_unbind); +} + +/* + * Called as a tasklet to do req/reply match and complete a request + * Errors must result in the RPC task either being awakened, or + * allowed to timeout, to discover the errors at that time. + */ +void +rpcrdma_reply_handler(struct rpcrdma_rep *rep) +{ + struct rpcrdma_msg *headerp; + struct rpcrdma_req *req; + struct rpc_rqst *rqst; + struct rpc_xprt *xprt = rep->rr_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + __be32 *iptr; + int i, rdmalen, status; + + /* Check status. If bad, signal disconnect and return rep to pool */ + if (rep->rr_len == ~0U) { + rpcrdma_recv_buffer_put(rep); + if (r_xprt->rx_ep.rep_connected == 1) { + r_xprt->rx_ep.rep_connected = -EIO; + rpcrdma_conn_func(&r_xprt->rx_ep); + } + return; + } + if (rep->rr_len < 28) { + dprintk("RPC: %s: short/invalid reply\n", __func__); + goto repost; + } + headerp = (struct rpcrdma_msg *) rep->rr_base; + if (headerp->rm_vers != xdr_one) { + dprintk("RPC: %s: invalid version %d\n", + __func__, ntohl(headerp->rm_vers)); + goto repost; + } + + /* Get XID and try for a match. */ + spin_lock(&xprt->transport_lock); + rqst = xprt_lookup_rqst(xprt, headerp->rm_xid); + if (rqst == NULL) { + spin_unlock(&xprt->transport_lock); + dprintk("RPC: %s: reply 0x%p failed " + "to match any request xid 0x%08x len %d\n", + __func__, rep, headerp->rm_xid, rep->rr_len); +repost: + r_xprt->rx_stats.bad_reply_count++; + rep->rr_func = rpcrdma_reply_handler; + if (rpcrdma_ep_post_recv(&r_xprt->rx_ia, &r_xprt->rx_ep, rep)) + rpcrdma_recv_buffer_put(rep); + + return; + } + + /* get request object */ + req = rpcr_to_rdmar(rqst); + + dprintk("RPC: %s: reply 0x%p completes request 0x%p\n" + " RPC request 0x%p xid 0x%08x\n", + __func__, rep, req, rqst, headerp->rm_xid); + + BUG_ON(!req || req->rl_reply); + + /* from here on, the reply is no longer an orphan */ + req->rl_reply = rep; + + /* check for expected message types */ + /* The order of some of these tests is important. */ + switch (headerp->rm_type) { + case htonl(RDMA_MSG): + /* never expect read chunks */ + /* never expect reply chunks (two ways to check) */ + /* never expect write chunks without having offered RDMA */ + if (headerp->rm_body.rm_chunks[0] != xdr_zero || + (headerp->rm_body.rm_chunks[1] == xdr_zero && + headerp->rm_body.rm_chunks[2] != xdr_zero) || + (headerp->rm_body.rm_chunks[1] != xdr_zero && + req->rl_nchunks == 0)) + goto badheader; + if (headerp->rm_body.rm_chunks[1] != xdr_zero) { + /* count any expected write chunks in read reply */ + /* start at write chunk array count */ + iptr = &headerp->rm_body.rm_chunks[2]; + rdmalen = rpcrdma_count_chunks(rep, + req->rl_nchunks, 1, &iptr); + /* check for validity, and no reply chunk after */ + if (rdmalen < 0 || *iptr++ != xdr_zero) + goto badheader; + rep->rr_len -= + ((unsigned char *)iptr - (unsigned char *)headerp); + status = rep->rr_len + rdmalen; + r_xprt->rx_stats.total_rdma_reply += rdmalen; + /* special case - last chunk may omit padding */ + if (rdmalen &= 3) { + rdmalen = 4 - rdmalen; + status += rdmalen; + } + } else { + /* else ordinary inline */ + rdmalen = 0; + iptr = (__be32 *)((unsigned char *)headerp + 28); + rep->rr_len -= 28; /*sizeof *headerp;*/ + status = rep->rr_len; + } + /* Fix up the rpc results for upper layer */ + rpcrdma_inline_fixup(rqst, (char *)iptr, rep->rr_len, rdmalen); + break; + + case htonl(RDMA_NOMSG): + /* never expect read or write chunks, always reply chunks */ + if (headerp->rm_body.rm_chunks[0] != xdr_zero || + headerp->rm_body.rm_chunks[1] != xdr_zero || + headerp->rm_body.rm_chunks[2] != xdr_one || + req->rl_nchunks == 0) + goto badheader; + iptr = (__be32 *)((unsigned char *)headerp + 28); + rdmalen = rpcrdma_count_chunks(rep, req->rl_nchunks, 0, &iptr); + if (rdmalen < 0) + goto badheader; + r_xprt->rx_stats.total_rdma_reply += rdmalen; + /* Reply chunk buffer already is the reply vector - no fixup. */ + status = rdmalen; + break; + +badheader: + default: + dprintk("%s: invalid rpcrdma reply header (type %d):" + " chunks[012] == %d %d %d" + " expected chunks <= %d\n", + __func__, ntohl(headerp->rm_type), + headerp->rm_body.rm_chunks[0], + headerp->rm_body.rm_chunks[1], + headerp->rm_body.rm_chunks[2], + req->rl_nchunks); + status = -EIO; + r_xprt->rx_stats.bad_reply_count++; + break; + } + + /* If using mw bind, start the deregister process now. */ + /* (Note: if mr_free(), cannot perform it here, in tasklet context) */ + if (req->rl_nchunks) switch (r_xprt->rx_ia.ri_memreg_strategy) { + case RPCRDMA_MEMWINDOWS: + for (i = 0; req->rl_nchunks-- > 1;) + i += rpcrdma_deregister_external( + &req->rl_segments[i], r_xprt, NULL); + /* Optionally wait (not here) for unbinds to complete */ + rep->rr_func = rpcrdma_unbind_func; + (void) rpcrdma_deregister_external(&req->rl_segments[i], + r_xprt, rep); + break; + case RPCRDMA_MEMWINDOWS_ASYNC: + for (i = 0; req->rl_nchunks--;) + i += rpcrdma_deregister_external(&req->rl_segments[i], + r_xprt, NULL); + break; + default: + break; + } + + dprintk("RPC: %s: xprt_complete_rqst(0x%p, 0x%p, %d)\n", + __func__, xprt, rqst, status); + xprt_complete_rqst(rqst->rq_task, status); + spin_unlock(&xprt->transport_lock); +} diff --git a/net/sunrpc/xprtrdma/svc_rdma.c b/net/sunrpc/xprtrdma/svc_rdma.c new file mode 100644 index 0000000..8710117 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma.c @@ -0,0 +1,301 @@ +/* + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ +#include <linux/module.h> +#include <linux/init.h> +#include <linux/fs.h> +#include <linux/sysctl.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* RPC/RDMA parameters */ +unsigned int svcrdma_ord = RPCRDMA_ORD; +static unsigned int min_ord = 1; +static unsigned int max_ord = 4096; +unsigned int svcrdma_max_requests = RPCRDMA_MAX_REQUESTS; +static unsigned int min_max_requests = 4; +static unsigned int max_max_requests = 16384; +unsigned int svcrdma_max_req_size = RPCRDMA_MAX_REQ_SIZE; +static unsigned int min_max_inline = 4096; +static unsigned int max_max_inline = 65536; + +atomic_t rdma_stat_recv; +atomic_t rdma_stat_read; +atomic_t rdma_stat_write; +atomic_t rdma_stat_sq_starve; +atomic_t rdma_stat_rq_starve; +atomic_t rdma_stat_rq_poll; +atomic_t rdma_stat_rq_prod; +atomic_t rdma_stat_sq_poll; +atomic_t rdma_stat_sq_prod; + +/* Temporary NFS request map and context caches */ +struct kmem_cache *svc_rdma_map_cachep; +struct kmem_cache *svc_rdma_ctxt_cachep; + +/* + * This function implements reading and resetting an atomic_t stat + * variable through read/write to a proc file. Any write to the file + * resets the associated statistic to zero. Any read returns it's + * current value. + */ +static int read_reset_stat(ctl_table *table, int write, + struct file *filp, void __user *buffer, size_t *lenp, + loff_t *ppos) +{ + atomic_t *stat = (atomic_t *)table->data; + + if (!stat) + return -EINVAL; + + if (write) + atomic_set(stat, 0); + else { + char str_buf[32]; + char *data; + int len = snprintf(str_buf, 32, "%d\n", atomic_read(stat)); + if (len >= 32) + return -EFAULT; + len = strlen(str_buf); + if (*ppos > len) { + *lenp = 0; + return 0; + } + data = &str_buf[*ppos]; + len -= *ppos; + if (len > *lenp) + len = *lenp; + if (len && copy_to_user(buffer, str_buf, len)) + return -EFAULT; + *lenp = len; + *ppos += len; + } + return 0; +} + +static struct ctl_table_header *svcrdma_table_header; +static ctl_table svcrdma_parm_table[] = { + { + .procname = "max_requests", + .data = &svcrdma_max_requests, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_max_requests, + .extra2 = &max_max_requests + }, + { + .procname = "max_req_size", + .data = &svcrdma_max_req_size, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_max_inline, + .extra2 = &max_max_inline + }, + { + .procname = "max_outbound_read_requests", + .data = &svcrdma_ord, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_ord, + .extra2 = &max_ord, + }, + + { + .procname = "rdma_stat_read", + .data = &rdma_stat_read, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_recv", + .data = &rdma_stat_recv, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_write", + .data = &rdma_stat_write, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_sq_starve", + .data = &rdma_stat_sq_starve, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_rq_starve", + .data = &rdma_stat_rq_starve, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_rq_poll", + .data = &rdma_stat_rq_poll, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_rq_prod", + .data = &rdma_stat_rq_prod, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_sq_poll", + .data = &rdma_stat_sq_poll, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .procname = "rdma_stat_sq_prod", + .data = &rdma_stat_sq_prod, + .maxlen = sizeof(atomic_t), + .mode = 0644, + .proc_handler = &read_reset_stat, + }, + { + .ctl_name = 0, + }, +}; + +static ctl_table svcrdma_table[] = { + { + .procname = "svc_rdma", + .mode = 0555, + .child = svcrdma_parm_table + }, + { + .ctl_name = 0, + }, +}; + +static ctl_table svcrdma_root_table[] = { + { + .ctl_name = CTL_SUNRPC, + .procname = "sunrpc", + .mode = 0555, + .child = svcrdma_table + }, + { + .ctl_name = 0, + }, +}; + +void svc_rdma_cleanup(void) +{ + dprintk("SVCRDMA Module Removed, deregister RPC RDMA transport\n"); + flush_scheduled_work(); + if (svcrdma_table_header) { + unregister_sysctl_table(svcrdma_table_header); + svcrdma_table_header = NULL; + } + svc_unreg_xprt_class(&svc_rdma_class); + kmem_cache_destroy(svc_rdma_map_cachep); + kmem_cache_destroy(svc_rdma_ctxt_cachep); +} + +int svc_rdma_init(void) +{ + dprintk("SVCRDMA Module Init, register RPC RDMA transport\n"); + dprintk("\tsvcrdma_ord : %d\n", svcrdma_ord); + dprintk("\tmax_requests : %d\n", svcrdma_max_requests); + dprintk("\tsq_depth : %d\n", + svcrdma_max_requests * RPCRDMA_SQ_DEPTH_MULT); + dprintk("\tmax_inline : %d\n", svcrdma_max_req_size); + if (!svcrdma_table_header) + svcrdma_table_header = + register_sysctl_table(svcrdma_root_table); + + /* Create the temporary map cache */ + svc_rdma_map_cachep = kmem_cache_create("svc_rdma_map_cache", + sizeof(struct svc_rdma_req_map), + 0, + SLAB_HWCACHE_ALIGN, + NULL); + if (!svc_rdma_map_cachep) { + printk(KERN_INFO "Could not allocate map cache.\n"); + goto err0; + } + + /* Create the temporary context cache */ + svc_rdma_ctxt_cachep = + kmem_cache_create("svc_rdma_ctxt_cache", + sizeof(struct svc_rdma_op_ctxt), + 0, + SLAB_HWCACHE_ALIGN, + NULL); + if (!svc_rdma_ctxt_cachep) { + printk(KERN_INFO "Could not allocate WR ctxt cache.\n"); + goto err1; + } + + /* Register RDMA with the SVC transport switch */ + svc_reg_xprt_class(&svc_rdma_class); + return 0; + err1: + kmem_cache_destroy(svc_rdma_map_cachep); + err0: + unregister_sysctl_table(svcrdma_table_header); + return -ENOMEM; +} +MODULE_AUTHOR("Tom Tucker <tom@opengridcomputing.com>"); +MODULE_DESCRIPTION("SVC RDMA Transport"); +MODULE_LICENSE("Dual BSD/GPL"); +module_init(svc_rdma_init); +module_exit(svc_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/svc_rdma_marshal.c b/net/sunrpc/xprtrdma/svc_rdma_marshal.c new file mode 100644 index 0000000..9530ef2 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_marshal.c @@ -0,0 +1,412 @@ +/* + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sunrpc/xdr.h> +#include <linux/sunrpc/debug.h> +#include <asm/unaligned.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* + * Decodes a read chunk list. The expected format is as follows: + * descrim : xdr_one + * position : u32 offset into XDR stream + * handle : u32 RKEY + * . . . + * end-of-list: xdr_zero + */ +static u32 *decode_read_list(u32 *va, u32 *vaend) +{ + struct rpcrdma_read_chunk *ch = (struct rpcrdma_read_chunk *)va; + + while (ch->rc_discrim != xdr_zero) { + u64 ch_offset; + + if (((unsigned long)ch + sizeof(struct rpcrdma_read_chunk)) > + (unsigned long)vaend) { + dprintk("svcrdma: vaend=%p, ch=%p\n", vaend, ch); + return NULL; + } + + ch->rc_discrim = ntohl(ch->rc_discrim); + ch->rc_position = ntohl(ch->rc_position); + ch->rc_target.rs_handle = ntohl(ch->rc_target.rs_handle); + ch->rc_target.rs_length = ntohl(ch->rc_target.rs_length); + va = (u32 *)&ch->rc_target.rs_offset; + xdr_decode_hyper(va, &ch_offset); + put_unaligned(ch_offset, (u64 *)va); + ch++; + } + return (u32 *)&ch->rc_position; +} + +/* + * Determine number of chunks and total bytes in chunk list. The chunk + * list has already been verified to fit within the RPCRDMA header. + */ +void svc_rdma_rcl_chunk_counts(struct rpcrdma_read_chunk *ch, + int *ch_count, int *byte_count) +{ + /* compute the number of bytes represented by read chunks */ + *byte_count = 0; + *ch_count = 0; + for (; ch->rc_discrim != 0; ch++) { + *byte_count = *byte_count + ch->rc_target.rs_length; + *ch_count = *ch_count + 1; + } +} + +/* + * Decodes a write chunk list. The expected format is as follows: + * descrim : xdr_one + * nchunks : <count> + * handle : u32 RKEY ---+ + * length : u32 <len of segment> | + * offset : remove va + <count> + * . . . | + * ---+ + */ +static u32 *decode_write_list(u32 *va, u32 *vaend) +{ + int ch_no; + struct rpcrdma_write_array *ary = + (struct rpcrdma_write_array *)va; + + /* Check for not write-array */ + if (ary->wc_discrim == xdr_zero) + return (u32 *)&ary->wc_nchunks; + + if ((unsigned long)ary + sizeof(struct rpcrdma_write_array) > + (unsigned long)vaend) { + dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); + return NULL; + } + ary->wc_discrim = ntohl(ary->wc_discrim); + ary->wc_nchunks = ntohl(ary->wc_nchunks); + if (((unsigned long)&ary->wc_array[0] + + (sizeof(struct rpcrdma_write_chunk) * ary->wc_nchunks)) > + (unsigned long)vaend) { + dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", + ary, ary->wc_nchunks, vaend); + return NULL; + } + for (ch_no = 0; ch_no < ary->wc_nchunks; ch_no++) { + u64 ch_offset; + + ary->wc_array[ch_no].wc_target.rs_handle = + ntohl(ary->wc_array[ch_no].wc_target.rs_handle); + ary->wc_array[ch_no].wc_target.rs_length = + ntohl(ary->wc_array[ch_no].wc_target.rs_length); + va = (u32 *)&ary->wc_array[ch_no].wc_target.rs_offset; + xdr_decode_hyper(va, &ch_offset); + put_unaligned(ch_offset, (u64 *)va); + } + + /* + * rs_length is the 2nd 4B field in wc_target and taking its + * address skips the list terminator + */ + return (u32 *)&ary->wc_array[ch_no].wc_target.rs_length; +} + +static u32 *decode_reply_array(u32 *va, u32 *vaend) +{ + int ch_no; + struct rpcrdma_write_array *ary = + (struct rpcrdma_write_array *)va; + + /* Check for no reply-array */ + if (ary->wc_discrim == xdr_zero) + return (u32 *)&ary->wc_nchunks; + + if ((unsigned long)ary + sizeof(struct rpcrdma_write_array) > + (unsigned long)vaend) { + dprintk("svcrdma: ary=%p, vaend=%p\n", ary, vaend); + return NULL; + } + ary->wc_discrim = ntohl(ary->wc_discrim); + ary->wc_nchunks = ntohl(ary->wc_nchunks); + if (((unsigned long)&ary->wc_array[0] + + (sizeof(struct rpcrdma_write_chunk) * ary->wc_nchunks)) > + (unsigned long)vaend) { + dprintk("svcrdma: ary=%p, wc_nchunks=%d, vaend=%p\n", + ary, ary->wc_nchunks, vaend); + return NULL; + } + for (ch_no = 0; ch_no < ary->wc_nchunks; ch_no++) { + u64 ch_offset; + + ary->wc_array[ch_no].wc_target.rs_handle = + ntohl(ary->wc_array[ch_no].wc_target.rs_handle); + ary->wc_array[ch_no].wc_target.rs_length = + ntohl(ary->wc_array[ch_no].wc_target.rs_length); + va = (u32 *)&ary->wc_array[ch_no].wc_target.rs_offset; + xdr_decode_hyper(va, &ch_offset); + put_unaligned(ch_offset, (u64 *)va); + } + + return (u32 *)&ary->wc_array[ch_no]; +} + +int svc_rdma_xdr_decode_req(struct rpcrdma_msg **rdma_req, + struct svc_rqst *rqstp) +{ + struct rpcrdma_msg *rmsgp = NULL; + u32 *va; + u32 *vaend; + u32 hdr_len; + + rmsgp = (struct rpcrdma_msg *)rqstp->rq_arg.head[0].iov_base; + + /* Verify that there's enough bytes for header + something */ + if (rqstp->rq_arg.len <= RPCRDMA_HDRLEN_MIN) { + dprintk("svcrdma: header too short = %d\n", + rqstp->rq_arg.len); + return -EINVAL; + } + + /* Decode the header */ + rmsgp->rm_xid = ntohl(rmsgp->rm_xid); + rmsgp->rm_vers = ntohl(rmsgp->rm_vers); + rmsgp->rm_credit = ntohl(rmsgp->rm_credit); + rmsgp->rm_type = ntohl(rmsgp->rm_type); + + if (rmsgp->rm_vers != RPCRDMA_VERSION) + return -ENOSYS; + + /* Pull in the extra for the padded case and bump our pointer */ + if (rmsgp->rm_type == RDMA_MSGP) { + int hdrlen; + rmsgp->rm_body.rm_padded.rm_align = + ntohl(rmsgp->rm_body.rm_padded.rm_align); + rmsgp->rm_body.rm_padded.rm_thresh = + ntohl(rmsgp->rm_body.rm_padded.rm_thresh); + + va = &rmsgp->rm_body.rm_padded.rm_pempty[4]; + rqstp->rq_arg.head[0].iov_base = va; + hdrlen = (u32)((unsigned long)va - (unsigned long)rmsgp); + rqstp->rq_arg.head[0].iov_len -= hdrlen; + if (hdrlen > rqstp->rq_arg.len) + return -EINVAL; + return hdrlen; + } + + /* The chunk list may contain either a read chunk list or a write + * chunk list and a reply chunk list. + */ + va = &rmsgp->rm_body.rm_chunks[0]; + vaend = (u32 *)((unsigned long)rmsgp + rqstp->rq_arg.len); + va = decode_read_list(va, vaend); + if (!va) + return -EINVAL; + va = decode_write_list(va, vaend); + if (!va) + return -EINVAL; + va = decode_reply_array(va, vaend); + if (!va) + return -EINVAL; + + rqstp->rq_arg.head[0].iov_base = va; + hdr_len = (unsigned long)va - (unsigned long)rmsgp; + rqstp->rq_arg.head[0].iov_len -= hdr_len; + + *rdma_req = rmsgp; + return hdr_len; +} + +int svc_rdma_xdr_decode_deferred_req(struct svc_rqst *rqstp) +{ + struct rpcrdma_msg *rmsgp = NULL; + struct rpcrdma_read_chunk *ch; + struct rpcrdma_write_array *ary; + u32 *va; + u32 hdrlen; + + dprintk("svcrdma: processing deferred RDMA header on rqstp=%p\n", + rqstp); + rmsgp = (struct rpcrdma_msg *)rqstp->rq_arg.head[0].iov_base; + + /* Pull in the extra for the padded case and bump our pointer */ + if (rmsgp->rm_type == RDMA_MSGP) { + va = &rmsgp->rm_body.rm_padded.rm_pempty[4]; + rqstp->rq_arg.head[0].iov_base = va; + hdrlen = (u32)((unsigned long)va - (unsigned long)rmsgp); + rqstp->rq_arg.head[0].iov_len -= hdrlen; + return hdrlen; + } + + /* + * Skip all chunks to find RPC msg. These were previously processed + */ + va = &rmsgp->rm_body.rm_chunks[0]; + + /* Skip read-list */ + for (ch = (struct rpcrdma_read_chunk *)va; + ch->rc_discrim != xdr_zero; ch++); + va = (u32 *)&ch->rc_position; + + /* Skip write-list */ + ary = (struct rpcrdma_write_array *)va; + if (ary->wc_discrim == xdr_zero) + va = (u32 *)&ary->wc_nchunks; + else + /* + * rs_length is the 2nd 4B field in wc_target and taking its + * address skips the list terminator + */ + va = (u32 *)&ary->wc_array[ary->wc_nchunks].wc_target.rs_length; + + /* Skip reply-array */ + ary = (struct rpcrdma_write_array *)va; + if (ary->wc_discrim == xdr_zero) + va = (u32 *)&ary->wc_nchunks; + else + va = (u32 *)&ary->wc_array[ary->wc_nchunks]; + + rqstp->rq_arg.head[0].iov_base = va; + hdrlen = (unsigned long)va - (unsigned long)rmsgp; + rqstp->rq_arg.head[0].iov_len -= hdrlen; + + return hdrlen; +} + +int svc_rdma_xdr_encode_error(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rmsgp, + enum rpcrdma_errcode err, u32 *va) +{ + u32 *startp = va; + + *va++ = htonl(rmsgp->rm_xid); + *va++ = htonl(rmsgp->rm_vers); + *va++ = htonl(xprt->sc_max_requests); + *va++ = htonl(RDMA_ERROR); + *va++ = htonl(err); + if (err == ERR_VERS) { + *va++ = htonl(RPCRDMA_VERSION); + *va++ = htonl(RPCRDMA_VERSION); + } + + return (int)((unsigned long)va - (unsigned long)startp); +} + +int svc_rdma_xdr_get_reply_hdr_len(struct rpcrdma_msg *rmsgp) +{ + struct rpcrdma_write_array *wr_ary; + + /* There is no read-list in a reply */ + + /* skip write list */ + wr_ary = (struct rpcrdma_write_array *) + &rmsgp->rm_body.rm_chunks[1]; + if (wr_ary->wc_discrim) + wr_ary = (struct rpcrdma_write_array *) + &wr_ary->wc_array[ntohl(wr_ary->wc_nchunks)]. + wc_target.rs_length; + else + wr_ary = (struct rpcrdma_write_array *) + &wr_ary->wc_nchunks; + + /* skip reply array */ + if (wr_ary->wc_discrim) + wr_ary = (struct rpcrdma_write_array *) + &wr_ary->wc_array[ntohl(wr_ary->wc_nchunks)]; + else + wr_ary = (struct rpcrdma_write_array *) + &wr_ary->wc_nchunks; + + return (unsigned long) wr_ary - (unsigned long) rmsgp; +} + +void svc_rdma_xdr_encode_write_list(struct rpcrdma_msg *rmsgp, int chunks) +{ + struct rpcrdma_write_array *ary; + + /* no read-list */ + rmsgp->rm_body.rm_chunks[0] = xdr_zero; + + /* write-array discrim */ + ary = (struct rpcrdma_write_array *) + &rmsgp->rm_body.rm_chunks[1]; + ary->wc_discrim = xdr_one; + ary->wc_nchunks = htonl(chunks); + + /* write-list terminator */ + ary->wc_array[chunks].wc_target.rs_handle = xdr_zero; + + /* reply-array discriminator */ + ary->wc_array[chunks].wc_target.rs_length = xdr_zero; +} + +void svc_rdma_xdr_encode_reply_array(struct rpcrdma_write_array *ary, + int chunks) +{ + ary->wc_discrim = xdr_one; + ary->wc_nchunks = htonl(chunks); +} + +void svc_rdma_xdr_encode_array_chunk(struct rpcrdma_write_array *ary, + int chunk_no, + u32 rs_handle, u64 rs_offset, + u32 write_len) +{ + struct rpcrdma_segment *seg = &ary->wc_array[chunk_no].wc_target; + seg->rs_handle = htonl(rs_handle); + seg->rs_length = htonl(write_len); + xdr_encode_hyper((u32 *) &seg->rs_offset, rs_offset); +} + +void svc_rdma_xdr_encode_reply_header(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rdma_argp, + struct rpcrdma_msg *rdma_resp, + enum rpcrdma_proc rdma_type) +{ + rdma_resp->rm_xid = htonl(rdma_argp->rm_xid); + rdma_resp->rm_vers = htonl(rdma_argp->rm_vers); + rdma_resp->rm_credit = htonl(xprt->sc_max_requests); + rdma_resp->rm_type = htonl(rdma_type); + + /* Encode <nul> chunks lists */ + rdma_resp->rm_body.rm_chunks[0] = xdr_zero; + rdma_resp->rm_body.rm_chunks[1] = xdr_zero; + rdma_resp->rm_body.rm_chunks[2] = xdr_zero; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c new file mode 100644 index 0000000..a475657 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_recvfrom.c @@ -0,0 +1,684 @@ +/* + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/spinlock.h> +#include <asm/unaligned.h> +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* + * Replace the pages in the rq_argpages array with the pages from the SGE in + * the RDMA_RECV completion. The SGL should contain full pages up until the + * last one. + */ +static void rdma_build_arg_xdr(struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *ctxt, + u32 byte_count) +{ + struct page *page; + u32 bc; + int sge_no; + + /* Swap the page in the SGE with the page in argpages */ + page = ctxt->pages[0]; + put_page(rqstp->rq_pages[0]); + rqstp->rq_pages[0] = page; + + /* Set up the XDR head */ + rqstp->rq_arg.head[0].iov_base = page_address(page); + rqstp->rq_arg.head[0].iov_len = min(byte_count, ctxt->sge[0].length); + rqstp->rq_arg.len = byte_count; + rqstp->rq_arg.buflen = byte_count; + + /* Compute bytes past head in the SGL */ + bc = byte_count - rqstp->rq_arg.head[0].iov_len; + + /* If data remains, store it in the pagelist */ + rqstp->rq_arg.page_len = bc; + rqstp->rq_arg.page_base = 0; + rqstp->rq_arg.pages = &rqstp->rq_pages[1]; + sge_no = 1; + while (bc && sge_no < ctxt->count) { + page = ctxt->pages[sge_no]; + put_page(rqstp->rq_pages[sge_no]); + rqstp->rq_pages[sge_no] = page; + bc -= min(bc, ctxt->sge[sge_no].length); + rqstp->rq_arg.buflen += ctxt->sge[sge_no].length; + sge_no++; + } + rqstp->rq_respages = &rqstp->rq_pages[sge_no]; + + /* We should never run out of SGE because the limit is defined to + * support the max allowed RPC data length + */ + BUG_ON(bc && (sge_no == ctxt->count)); + BUG_ON((rqstp->rq_arg.head[0].iov_len + rqstp->rq_arg.page_len) + != byte_count); + BUG_ON(rqstp->rq_arg.len != byte_count); + + /* If not all pages were used from the SGL, free the remaining ones */ + bc = sge_no; + while (sge_no < ctxt->count) { + page = ctxt->pages[sge_no++]; + put_page(page); + } + ctxt->count = bc; + + /* Set up tail */ + rqstp->rq_arg.tail[0].iov_base = NULL; + rqstp->rq_arg.tail[0].iov_len = 0; +} + +/* Encode a read-chunk-list as an array of IB SGE + * + * Assumptions: + * - chunk[0]->position points to pages[0] at an offset of 0 + * - pages[] is not physically or virtually contiguous and consists of + * PAGE_SIZE elements. + * + * Output: + * - sge array pointing into pages[] array. + * - chunk_sge array specifying sge index and count for each + * chunk in the read list + * + */ +static int map_read_chunks(struct svcxprt_rdma *xprt, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head, + struct rpcrdma_msg *rmsgp, + struct svc_rdma_req_map *rpl_map, + struct svc_rdma_req_map *chl_map, + int ch_count, + int byte_count) +{ + int sge_no; + int sge_bytes; + int page_off; + int page_no; + int ch_bytes; + int ch_no; + struct rpcrdma_read_chunk *ch; + + sge_no = 0; + page_no = 0; + page_off = 0; + ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; + ch_no = 0; + ch_bytes = ch->rc_target.rs_length; + head->arg.head[0] = rqstp->rq_arg.head[0]; + head->arg.tail[0] = rqstp->rq_arg.tail[0]; + head->arg.pages = &head->pages[head->count]; + head->hdr_count = head->count; /* save count of hdr pages */ + head->arg.page_base = 0; + head->arg.page_len = ch_bytes; + head->arg.len = rqstp->rq_arg.len + ch_bytes; + head->arg.buflen = rqstp->rq_arg.buflen + ch_bytes; + head->count++; + chl_map->ch[0].start = 0; + while (byte_count) { + rpl_map->sge[sge_no].iov_base = + page_address(rqstp->rq_arg.pages[page_no]) + page_off; + sge_bytes = min_t(int, PAGE_SIZE-page_off, ch_bytes); + rpl_map->sge[sge_no].iov_len = sge_bytes; + /* + * Don't bump head->count here because the same page + * may be used by multiple SGE. + */ + head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; + rqstp->rq_respages = &rqstp->rq_arg.pages[page_no+1]; + + byte_count -= sge_bytes; + ch_bytes -= sge_bytes; + sge_no++; + /* + * If all bytes for this chunk have been mapped to an + * SGE, move to the next SGE + */ + if (ch_bytes == 0) { + chl_map->ch[ch_no].count = + sge_no - chl_map->ch[ch_no].start; + ch_no++; + ch++; + chl_map->ch[ch_no].start = sge_no; + ch_bytes = ch->rc_target.rs_length; + /* If bytes remaining account for next chunk */ + if (byte_count) { + head->arg.page_len += ch_bytes; + head->arg.len += ch_bytes; + head->arg.buflen += ch_bytes; + } + } + /* + * If this SGE consumed all of the page, move to the + * next page + */ + if ((sge_bytes + page_off) == PAGE_SIZE) { + page_no++; + page_off = 0; + /* + * If there are still bytes left to map, bump + * the page count + */ + if (byte_count) + head->count++; + } else + page_off += sge_bytes; + } + BUG_ON(byte_count != 0); + return sge_no; +} + +/* Map a read-chunk-list to an XDR and fast register the page-list. + * + * Assumptions: + * - chunk[0] position points to pages[0] at an offset of 0 + * - pages[] will be made physically contiguous by creating a one-off memory + * region using the fastreg verb. + * - byte_count is # of bytes in read-chunk-list + * - ch_count is # of chunks in read-chunk-list + * + * Output: + * - sge array pointing into pages[] array. + * - chunk_sge array specifying sge index and count for each + * chunk in the read list + */ +static int fast_reg_read_chunks(struct svcxprt_rdma *xprt, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head, + struct rpcrdma_msg *rmsgp, + struct svc_rdma_req_map *rpl_map, + struct svc_rdma_req_map *chl_map, + int ch_count, + int byte_count) +{ + int page_no; + int ch_no; + u32 offset; + struct rpcrdma_read_chunk *ch; + struct svc_rdma_fastreg_mr *frmr; + int ret = 0; + + frmr = svc_rdma_get_frmr(xprt); + if (IS_ERR(frmr)) + return -ENOMEM; + + head->frmr = frmr; + head->arg.head[0] = rqstp->rq_arg.head[0]; + head->arg.tail[0] = rqstp->rq_arg.tail[0]; + head->arg.pages = &head->pages[head->count]; + head->hdr_count = head->count; /* save count of hdr pages */ + head->arg.page_base = 0; + head->arg.page_len = byte_count; + head->arg.len = rqstp->rq_arg.len + byte_count; + head->arg.buflen = rqstp->rq_arg.buflen + byte_count; + + /* Fast register the page list */ + frmr->kva = page_address(rqstp->rq_arg.pages[0]); + frmr->direction = DMA_FROM_DEVICE; + frmr->access_flags = (IB_ACCESS_LOCAL_WRITE|IB_ACCESS_REMOTE_WRITE); + frmr->map_len = byte_count; + frmr->page_list_len = PAGE_ALIGN(byte_count) >> PAGE_SHIFT; + for (page_no = 0; page_no < frmr->page_list_len; page_no++) { + frmr->page_list->page_list[page_no] = + ib_dma_map_single(xprt->sc_cm_id->device, + page_address(rqstp->rq_arg.pages[page_no]), + PAGE_SIZE, DMA_TO_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + frmr->page_list->page_list[page_no])) + goto fatal_err; + atomic_inc(&xprt->sc_dma_used); + head->arg.pages[page_no] = rqstp->rq_arg.pages[page_no]; + } + head->count += page_no; + + /* rq_respages points one past arg pages */ + rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; + + /* Create the reply and chunk maps */ + offset = 0; + ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; + for (ch_no = 0; ch_no < ch_count; ch_no++) { + rpl_map->sge[ch_no].iov_base = frmr->kva + offset; + rpl_map->sge[ch_no].iov_len = ch->rc_target.rs_length; + chl_map->ch[ch_no].count = 1; + chl_map->ch[ch_no].start = ch_no; + offset += ch->rc_target.rs_length; + ch++; + } + + ret = svc_rdma_fastreg(xprt, frmr); + if (ret) + goto fatal_err; + + return ch_no; + + fatal_err: + printk("svcrdma: error fast registering xdr for xprt %p", xprt); + svc_rdma_put_frmr(xprt, frmr); + return -EIO; +} + +static int rdma_set_ctxt_sge(struct svcxprt_rdma *xprt, + struct svc_rdma_op_ctxt *ctxt, + struct svc_rdma_fastreg_mr *frmr, + struct kvec *vec, + u64 *sgl_offset, + int count) +{ + int i; + + ctxt->count = count; + ctxt->direction = DMA_FROM_DEVICE; + for (i = 0; i < count; i++) { + ctxt->sge[i].length = 0; /* in case map fails */ + if (!frmr) { + ctxt->sge[i].addr = + ib_dma_map_single(xprt->sc_cm_id->device, + vec[i].iov_base, + vec[i].iov_len, + DMA_FROM_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + ctxt->sge[i].addr)) + return -EINVAL; + ctxt->sge[i].lkey = xprt->sc_dma_lkey; + atomic_inc(&xprt->sc_dma_used); + } else { + ctxt->sge[i].addr = (unsigned long)vec[i].iov_base; + ctxt->sge[i].lkey = frmr->mr->lkey; + } + ctxt->sge[i].length = vec[i].iov_len; + *sgl_offset = *sgl_offset + vec[i].iov_len; + } + return 0; +} + +static int rdma_read_max_sge(struct svcxprt_rdma *xprt, int sge_count) +{ + if ((RDMA_TRANSPORT_IWARP == + rdma_node_get_transport(xprt->sc_cm_id-> + device->node_type)) + && sge_count > 1) + return 1; + else + return min_t(int, sge_count, xprt->sc_max_sge); +} + +/* + * Use RDMA_READ to read data from the advertised client buffer into the + * XDR stream starting at rq_arg.head[0].iov_base. + * Each chunk in the array + * contains the following fields: + * discrim - '1', This isn't used for data placement + * position - The xdr stream offset (the same for every chunk) + * handle - RMR for client memory region + * length - data transfer length + * offset - 64 bit tagged offset in remote memory region + * + * On our side, we need to read into a pagelist. The first page immediately + * follows the RPC header. + * + * This function returns: + * 0 - No error and no read-list found. + * + * 1 - Successful read-list processing. The data is not yet in + * the pagelist and therefore the RPC request must be deferred. The + * I/O completion will enqueue the transport again and + * svc_rdma_recvfrom will complete the request. + * + * <0 - Error processing/posting read-list. + * + * NOTE: The ctxt must not be touched after the last WR has been posted + * because the I/O completion processing may occur on another + * processor and free / modify the context. Ne touche pas! + */ +static int rdma_read_xdr(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rmsgp, + struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *hdr_ctxt) +{ + struct ib_send_wr read_wr; + struct ib_send_wr inv_wr; + int err = 0; + int ch_no; + int ch_count; + int byte_count; + int sge_count; + u64 sgl_offset; + struct rpcrdma_read_chunk *ch; + struct svc_rdma_op_ctxt *ctxt = NULL; + struct svc_rdma_req_map *rpl_map; + struct svc_rdma_req_map *chl_map; + + /* If no read list is present, return 0 */ + ch = svc_rdma_get_read_chunk(rmsgp); + if (!ch) + return 0; + + /* Allocate temporary reply and chunk maps */ + rpl_map = svc_rdma_get_req_map(); + chl_map = svc_rdma_get_req_map(); + + svc_rdma_rcl_chunk_counts(ch, &ch_count, &byte_count); + if (ch_count > RPCSVC_MAXPAGES) + return -EINVAL; + + if (!xprt->sc_frmr_pg_list_len) + sge_count = map_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, + rpl_map, chl_map, ch_count, + byte_count); + else + sge_count = fast_reg_read_chunks(xprt, rqstp, hdr_ctxt, rmsgp, + rpl_map, chl_map, ch_count, + byte_count); + if (sge_count < 0) { + err = -EIO; + goto out; + } + + sgl_offset = 0; + ch_no = 0; + + for (ch = (struct rpcrdma_read_chunk *)&rmsgp->rm_body.rm_chunks[0]; + ch->rc_discrim != 0; ch++, ch_no++) { +next_sge: + ctxt = svc_rdma_get_context(xprt); + ctxt->direction = DMA_FROM_DEVICE; + ctxt->frmr = hdr_ctxt->frmr; + ctxt->read_hdr = NULL; + clear_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); + clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); + + /* Prepare READ WR */ + memset(&read_wr, 0, sizeof read_wr); + read_wr.wr_id = (unsigned long)ctxt; + read_wr.opcode = IB_WR_RDMA_READ; + ctxt->wr_op = read_wr.opcode; + read_wr.send_flags = IB_SEND_SIGNALED; + read_wr.wr.rdma.rkey = ch->rc_target.rs_handle; + read_wr.wr.rdma.remote_addr = + get_unaligned(&(ch->rc_target.rs_offset)) + + sgl_offset; + read_wr.sg_list = ctxt->sge; + read_wr.num_sge = + rdma_read_max_sge(xprt, chl_map->ch[ch_no].count); + err = rdma_set_ctxt_sge(xprt, ctxt, hdr_ctxt->frmr, + &rpl_map->sge[chl_map->ch[ch_no].start], + &sgl_offset, + read_wr.num_sge); + if (err) { + svc_rdma_unmap_dma(ctxt); + svc_rdma_put_context(ctxt, 0); + goto out; + } + if (((ch+1)->rc_discrim == 0) && + (read_wr.num_sge == chl_map->ch[ch_no].count)) { + /* + * Mark the last RDMA_READ with a bit to + * indicate all RPC data has been fetched from + * the client and the RPC needs to be enqueued. + */ + set_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags); + if (hdr_ctxt->frmr) { + set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); + /* + * Invalidate the local MR used to map the data + * sink. + */ + if (xprt->sc_dev_caps & + SVCRDMA_DEVCAP_READ_W_INV) { + read_wr.opcode = + IB_WR_RDMA_READ_WITH_INV; + ctxt->wr_op = read_wr.opcode; + read_wr.ex.invalidate_rkey = + ctxt->frmr->mr->lkey; + } else { + /* Prepare INVALIDATE WR */ + memset(&inv_wr, 0, sizeof inv_wr); + inv_wr.opcode = IB_WR_LOCAL_INV; + inv_wr.send_flags = IB_SEND_SIGNALED; + inv_wr.ex.invalidate_rkey = + hdr_ctxt->frmr->mr->lkey; + read_wr.next = &inv_wr; + } + } + ctxt->read_hdr = hdr_ctxt; + } + /* Post the read */ + err = svc_rdma_send(xprt, &read_wr); + if (err) { + printk(KERN_ERR "svcrdma: Error %d posting RDMA_READ\n", + err); + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + svc_rdma_put_context(ctxt, 0); + goto out; + } + atomic_inc(&rdma_stat_read); + + if (read_wr.num_sge < chl_map->ch[ch_no].count) { + chl_map->ch[ch_no].count -= read_wr.num_sge; + chl_map->ch[ch_no].start += read_wr.num_sge; + goto next_sge; + } + sgl_offset = 0; + err = 1; + } + + out: + svc_rdma_put_req_map(rpl_map); + svc_rdma_put_req_map(chl_map); + + /* Detach arg pages. svc_recv will replenish them */ + for (ch_no = 0; &rqstp->rq_pages[ch_no] < rqstp->rq_respages; ch_no++) + rqstp->rq_pages[ch_no] = NULL; + + /* + * Detach res pages. svc_release must see a resused count of + * zero or it will attempt to put them. + */ + while (rqstp->rq_resused) + rqstp->rq_respages[--rqstp->rq_resused] = NULL; + + return err; +} + +static int rdma_read_complete(struct svc_rqst *rqstp, + struct svc_rdma_op_ctxt *head) +{ + int page_no; + int ret; + + BUG_ON(!head); + + /* Copy RPC pages */ + for (page_no = 0; page_no < head->count; page_no++) { + put_page(rqstp->rq_pages[page_no]); + rqstp->rq_pages[page_no] = head->pages[page_no]; + } + /* Point rq_arg.pages past header */ + rqstp->rq_arg.pages = &rqstp->rq_pages[head->hdr_count]; + rqstp->rq_arg.page_len = head->arg.page_len; + rqstp->rq_arg.page_base = head->arg.page_base; + + /* rq_respages starts after the last arg page */ + rqstp->rq_respages = &rqstp->rq_arg.pages[page_no]; + rqstp->rq_resused = 0; + + /* Rebuild rq_arg head and tail. */ + rqstp->rq_arg.head[0] = head->arg.head[0]; + rqstp->rq_arg.tail[0] = head->arg.tail[0]; + rqstp->rq_arg.len = head->arg.len; + rqstp->rq_arg.buflen = head->arg.buflen; + + /* Free the context */ + svc_rdma_put_context(head, 0); + + /* XXX: What should this be? */ + rqstp->rq_prot = IPPROTO_MAX; + svc_xprt_copy_addrs(rqstp, rqstp->rq_xprt); + + ret = rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len + + rqstp->rq_arg.tail[0].iov_len; + dprintk("svcrdma: deferred read ret=%d, rq_arg.len =%d, " + "rq_arg.head[0].iov_base=%p, rq_arg.head[0].iov_len = %zd\n", + ret, rqstp->rq_arg.len, rqstp->rq_arg.head[0].iov_base, + rqstp->rq_arg.head[0].iov_len); + + svc_xprt_received(rqstp->rq_xprt); + return ret; +} + +/* + * Set up the rqstp thread context to point to the RQ buffer. If + * necessary, pull additional data from the client with an RDMA_READ + * request. + */ +int svc_rdma_recvfrom(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma_xprt = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct svc_rdma_op_ctxt *ctxt = NULL; + struct rpcrdma_msg *rmsgp; + int ret = 0; + int len; + + dprintk("svcrdma: rqstp=%p\n", rqstp); + + spin_lock_bh(&rdma_xprt->sc_rq_dto_lock); + if (!list_empty(&rdma_xprt->sc_read_complete_q)) { + ctxt = list_entry(rdma_xprt->sc_read_complete_q.next, + struct svc_rdma_op_ctxt, + dto_q); + list_del_init(&ctxt->dto_q); + } + if (ctxt) { + spin_unlock_bh(&rdma_xprt->sc_rq_dto_lock); + return rdma_read_complete(rqstp, ctxt); + } + + if (!list_empty(&rdma_xprt->sc_rq_dto_q)) { + ctxt = list_entry(rdma_xprt->sc_rq_dto_q.next, + struct svc_rdma_op_ctxt, + dto_q); + list_del_init(&ctxt->dto_q); + } else { + atomic_inc(&rdma_stat_rq_starve); + clear_bit(XPT_DATA, &xprt->xpt_flags); + ctxt = NULL; + } + spin_unlock_bh(&rdma_xprt->sc_rq_dto_lock); + if (!ctxt) { + /* This is the EAGAIN path. The svc_recv routine will + * return -EAGAIN, the nfsd thread will go to call into + * svc_recv again and we shouldn't be on the active + * transport list + */ + if (test_bit(XPT_CLOSE, &xprt->xpt_flags)) + goto close_out; + + BUG_ON(ret); + goto out; + } + dprintk("svcrdma: processing ctxt=%p on xprt=%p, rqstp=%p, status=%d\n", + ctxt, rdma_xprt, rqstp, ctxt->wc_status); + BUG_ON(ctxt->wc_status != IB_WC_SUCCESS); + atomic_inc(&rdma_stat_recv); + + /* Build up the XDR from the receive buffers. */ + rdma_build_arg_xdr(rqstp, ctxt, ctxt->byte_len); + + /* Decode the RDMA header. */ + len = svc_rdma_xdr_decode_req(&rmsgp, rqstp); + rqstp->rq_xprt_hlen = len; + + /* If the request is invalid, reply with an error */ + if (len < 0) { + if (len == -ENOSYS) + svc_rdma_send_error(rdma_xprt, rmsgp, ERR_VERS); + goto close_out; + } + + /* Read read-list data. */ + ret = rdma_read_xdr(rdma_xprt, rmsgp, rqstp, ctxt); + if (ret > 0) { + /* read-list posted, defer until data received from client. */ + svc_xprt_received(xprt); + return 0; + } + if (ret < 0) { + /* Post of read-list failed, free context. */ + svc_rdma_put_context(ctxt, 1); + return 0; + } + + ret = rqstp->rq_arg.head[0].iov_len + + rqstp->rq_arg.page_len + + rqstp->rq_arg.tail[0].iov_len; + svc_rdma_put_context(ctxt, 0); + out: + dprintk("svcrdma: ret = %d, rq_arg.len =%d, " + "rq_arg.head[0].iov_base=%p, rq_arg.head[0].iov_len = %zd\n", + ret, rqstp->rq_arg.len, + rqstp->rq_arg.head[0].iov_base, + rqstp->rq_arg.head[0].iov_len); + rqstp->rq_prot = IPPROTO_MAX; + svc_xprt_copy_addrs(rqstp, xprt); + svc_xprt_received(xprt); + return ret; + + close_out: + if (ctxt) + svc_rdma_put_context(ctxt, 1); + dprintk("svcrdma: transport %p is closing\n", xprt); + /* + * Set the close bit and enqueue it. svc_recv will see the + * close bit and call svc_xprt_delete + */ + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_received(xprt); + return 0; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_sendto.c b/net/sunrpc/xprtrdma/svc_rdma_sendto.c new file mode 100644 index 0000000..9a7a8e7 --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_sendto.c @@ -0,0 +1,700 @@ +/* + * Copyright (c) 2005-2006 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/spinlock.h> +#include <asm/unaligned.h> +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +/* Encode an XDR as an array of IB SGE + * + * Assumptions: + * - head[0] is physically contiguous. + * - tail[0] is physically contiguous. + * - pages[] is not physically or virtually contigous and consists of + * PAGE_SIZE elements. + * + * Output: + * SGE[0] reserved for RCPRDMA header + * SGE[1] data from xdr->head[] + * SGE[2..sge_count-2] data from xdr->pages[] + * SGE[sge_count-1] data from xdr->tail. + * + * The max SGE we need is the length of the XDR / pagesize + one for + * head + one for tail + one for RPCRDMA header. Since RPCSVC_MAXPAGES + * reserves a page for both the request and the reply header, and this + * array is only concerned with the reply we are assured that we have + * on extra page for the RPCRMDA header. + */ +int fast_reg_xdr(struct svcxprt_rdma *xprt, + struct xdr_buf *xdr, + struct svc_rdma_req_map *vec) +{ + int sge_no; + u32 sge_bytes; + u32 page_bytes; + u32 page_off; + int page_no = 0; + u8 *frva; + struct svc_rdma_fastreg_mr *frmr; + + frmr = svc_rdma_get_frmr(xprt); + if (IS_ERR(frmr)) + return -ENOMEM; + vec->frmr = frmr; + + /* Skip the RPCRDMA header */ + sge_no = 1; + + /* Map the head. */ + frva = (void *)((unsigned long)(xdr->head[0].iov_base) & PAGE_MASK); + vec->sge[sge_no].iov_base = xdr->head[0].iov_base; + vec->sge[sge_no].iov_len = xdr->head[0].iov_len; + vec->count = 2; + sge_no++; + + /* Build the FRMR */ + frmr->kva = frva; + frmr->direction = DMA_TO_DEVICE; + frmr->access_flags = 0; + frmr->map_len = PAGE_SIZE; + frmr->page_list_len = 1; + frmr->page_list->page_list[page_no] = + ib_dma_map_single(xprt->sc_cm_id->device, + (void *)xdr->head[0].iov_base, + PAGE_SIZE, DMA_TO_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + frmr->page_list->page_list[page_no])) + goto fatal_err; + atomic_inc(&xprt->sc_dma_used); + + page_off = xdr->page_base; + page_bytes = xdr->page_len + page_off; + if (!page_bytes) + goto encode_tail; + + /* Map the pages */ + vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; + vec->sge[sge_no].iov_len = page_bytes; + sge_no++; + while (page_bytes) { + struct page *page; + + page = xdr->pages[page_no++]; + sge_bytes = min_t(u32, page_bytes, (PAGE_SIZE - page_off)); + page_bytes -= sge_bytes; + + frmr->page_list->page_list[page_no] = + ib_dma_map_page(xprt->sc_cm_id->device, page, 0, + PAGE_SIZE, DMA_TO_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + frmr->page_list->page_list[page_no])) + goto fatal_err; + + atomic_inc(&xprt->sc_dma_used); + page_off = 0; /* reset for next time through loop */ + frmr->map_len += PAGE_SIZE; + frmr->page_list_len++; + } + vec->count++; + + encode_tail: + /* Map tail */ + if (0 == xdr->tail[0].iov_len) + goto done; + + vec->count++; + vec->sge[sge_no].iov_len = xdr->tail[0].iov_len; + + if (((unsigned long)xdr->tail[0].iov_base & PAGE_MASK) == + ((unsigned long)xdr->head[0].iov_base & PAGE_MASK)) { + /* + * If head and tail use the same page, we don't need + * to map it again. + */ + vec->sge[sge_no].iov_base = xdr->tail[0].iov_base; + } else { + void *va; + + /* Map another page for the tail */ + page_off = (unsigned long)xdr->tail[0].iov_base & ~PAGE_MASK; + va = (void *)((unsigned long)xdr->tail[0].iov_base & PAGE_MASK); + vec->sge[sge_no].iov_base = frva + frmr->map_len + page_off; + + frmr->page_list->page_list[page_no] = + ib_dma_map_single(xprt->sc_cm_id->device, va, PAGE_SIZE, + DMA_TO_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + frmr->page_list->page_list[page_no])) + goto fatal_err; + atomic_inc(&xprt->sc_dma_used); + frmr->map_len += PAGE_SIZE; + frmr->page_list_len++; + } + + done: + if (svc_rdma_fastreg(xprt, frmr)) + goto fatal_err; + + return 0; + + fatal_err: + printk("svcrdma: Error fast registering memory for xprt %p\n", xprt); + svc_rdma_put_frmr(xprt, frmr); + return -EIO; +} + +static int map_xdr(struct svcxprt_rdma *xprt, + struct xdr_buf *xdr, + struct svc_rdma_req_map *vec) +{ + int sge_max = (xdr->len+PAGE_SIZE-1) / PAGE_SIZE + 3; + int sge_no; + u32 sge_bytes; + u32 page_bytes; + u32 page_off; + int page_no; + + BUG_ON(xdr->len != + (xdr->head[0].iov_len + xdr->page_len + xdr->tail[0].iov_len)); + + if (xprt->sc_frmr_pg_list_len) + return fast_reg_xdr(xprt, xdr, vec); + + /* Skip the first sge, this is for the RPCRDMA header */ + sge_no = 1; + + /* Head SGE */ + vec->sge[sge_no].iov_base = xdr->head[0].iov_base; + vec->sge[sge_no].iov_len = xdr->head[0].iov_len; + sge_no++; + + /* pages SGE */ + page_no = 0; + page_bytes = xdr->page_len; + page_off = xdr->page_base; + while (page_bytes) { + vec->sge[sge_no].iov_base = + page_address(xdr->pages[page_no]) + page_off; + sge_bytes = min_t(u32, page_bytes, (PAGE_SIZE - page_off)); + page_bytes -= sge_bytes; + vec->sge[sge_no].iov_len = sge_bytes; + + sge_no++; + page_no++; + page_off = 0; /* reset for next time through loop */ + } + + /* Tail SGE */ + if (xdr->tail[0].iov_len) { + vec->sge[sge_no].iov_base = xdr->tail[0].iov_base; + vec->sge[sge_no].iov_len = xdr->tail[0].iov_len; + sge_no++; + } + + BUG_ON(sge_no > sge_max); + vec->count = sge_no; + return 0; +} + +/* Assumptions: + * - We are using FRMR + * - or - + * - The specified write_len can be represented in sc_max_sge * PAGE_SIZE + */ +static int send_write(struct svcxprt_rdma *xprt, struct svc_rqst *rqstp, + u32 rmr, u64 to, + u32 xdr_off, int write_len, + struct svc_rdma_req_map *vec) +{ + struct ib_send_wr write_wr; + struct ib_sge *sge; + int xdr_sge_no; + int sge_no; + int sge_bytes; + int sge_off; + int bc; + struct svc_rdma_op_ctxt *ctxt; + + BUG_ON(vec->count > RPCSVC_MAXPAGES); + dprintk("svcrdma: RDMA_WRITE rmr=%x, to=%llx, xdr_off=%d, " + "write_len=%d, vec->sge=%p, vec->count=%lu\n", + rmr, (unsigned long long)to, xdr_off, + write_len, vec->sge, vec->count); + + ctxt = svc_rdma_get_context(xprt); + ctxt->direction = DMA_TO_DEVICE; + sge = ctxt->sge; + + /* Find the SGE associated with xdr_off */ + for (bc = xdr_off, xdr_sge_no = 1; bc && xdr_sge_no < vec->count; + xdr_sge_no++) { + if (vec->sge[xdr_sge_no].iov_len > bc) + break; + bc -= vec->sge[xdr_sge_no].iov_len; + } + + sge_off = bc; + bc = write_len; + sge_no = 0; + + /* Copy the remaining SGE */ + while (bc != 0) { + sge_bytes = min_t(size_t, + bc, vec->sge[xdr_sge_no].iov_len-sge_off); + sge[sge_no].length = sge_bytes; + if (!vec->frmr) { + sge[sge_no].addr = + ib_dma_map_single(xprt->sc_cm_id->device, + (void *) + vec->sge[xdr_sge_no].iov_base + sge_off, + sge_bytes, DMA_TO_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, + sge[sge_no].addr)) + goto err; + atomic_inc(&xprt->sc_dma_used); + sge[sge_no].lkey = xprt->sc_dma_lkey; + } else { + sge[sge_no].addr = (unsigned long) + vec->sge[xdr_sge_no].iov_base + sge_off; + sge[sge_no].lkey = vec->frmr->mr->lkey; + } + ctxt->count++; + ctxt->frmr = vec->frmr; + sge_off = 0; + sge_no++; + xdr_sge_no++; + BUG_ON(xdr_sge_no > vec->count); + bc -= sge_bytes; + } + + /* Prepare WRITE WR */ + memset(&write_wr, 0, sizeof write_wr); + ctxt->wr_op = IB_WR_RDMA_WRITE; + write_wr.wr_id = (unsigned long)ctxt; + write_wr.sg_list = &sge[0]; + write_wr.num_sge = sge_no; + write_wr.opcode = IB_WR_RDMA_WRITE; + write_wr.send_flags = IB_SEND_SIGNALED; + write_wr.wr.rdma.rkey = rmr; + write_wr.wr.rdma.remote_addr = to; + + /* Post It */ + atomic_inc(&rdma_stat_write); + if (svc_rdma_send(xprt, &write_wr)) + goto err; + return 0; + err: + svc_rdma_put_context(ctxt, 0); + /* Fatal error, close transport */ + return -EIO; +} + +static int send_write_chunks(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rdma_argp, + struct rpcrdma_msg *rdma_resp, + struct svc_rqst *rqstp, + struct svc_rdma_req_map *vec) +{ + u32 xfer_len = rqstp->rq_res.page_len + rqstp->rq_res.tail[0].iov_len; + int write_len; + int max_write; + u32 xdr_off; + int chunk_off; + int chunk_no; + struct rpcrdma_write_array *arg_ary; + struct rpcrdma_write_array *res_ary; + int ret; + + arg_ary = svc_rdma_get_write_array(rdma_argp); + if (!arg_ary) + return 0; + res_ary = (struct rpcrdma_write_array *) + &rdma_resp->rm_body.rm_chunks[1]; + + if (vec->frmr) + max_write = vec->frmr->map_len; + else + max_write = xprt->sc_max_sge * PAGE_SIZE; + + /* Write chunks start at the pagelist */ + for (xdr_off = rqstp->rq_res.head[0].iov_len, chunk_no = 0; + xfer_len && chunk_no < arg_ary->wc_nchunks; + chunk_no++) { + struct rpcrdma_segment *arg_ch; + u64 rs_offset; + + arg_ch = &arg_ary->wc_array[chunk_no].wc_target; + write_len = min(xfer_len, arg_ch->rs_length); + + /* Prepare the response chunk given the length actually + * written */ + rs_offset = get_unaligned(&(arg_ch->rs_offset)); + svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, + arg_ch->rs_handle, + rs_offset, + write_len); + chunk_off = 0; + while (write_len) { + int this_write; + this_write = min(write_len, max_write); + ret = send_write(xprt, rqstp, + arg_ch->rs_handle, + rs_offset + chunk_off, + xdr_off, + this_write, + vec); + if (ret) { + dprintk("svcrdma: RDMA_WRITE failed, ret=%d\n", + ret); + return -EIO; + } + chunk_off += this_write; + xdr_off += this_write; + xfer_len -= this_write; + write_len -= this_write; + } + } + /* Update the req with the number of chunks actually used */ + svc_rdma_xdr_encode_write_list(rdma_resp, chunk_no); + + return rqstp->rq_res.page_len + rqstp->rq_res.tail[0].iov_len; +} + +static int send_reply_chunks(struct svcxprt_rdma *xprt, + struct rpcrdma_msg *rdma_argp, + struct rpcrdma_msg *rdma_resp, + struct svc_rqst *rqstp, + struct svc_rdma_req_map *vec) +{ + u32 xfer_len = rqstp->rq_res.len; + int write_len; + int max_write; + u32 xdr_off; + int chunk_no; + int chunk_off; + struct rpcrdma_segment *ch; + struct rpcrdma_write_array *arg_ary; + struct rpcrdma_write_array *res_ary; + int ret; + + arg_ary = svc_rdma_get_reply_array(rdma_argp); + if (!arg_ary) + return 0; + /* XXX: need to fix when reply lists occur with read-list and or + * write-list */ + res_ary = (struct rpcrdma_write_array *) + &rdma_resp->rm_body.rm_chunks[2]; + + if (vec->frmr) + max_write = vec->frmr->map_len; + else + max_write = xprt->sc_max_sge * PAGE_SIZE; + + /* xdr offset starts at RPC message */ + for (xdr_off = 0, chunk_no = 0; + xfer_len && chunk_no < arg_ary->wc_nchunks; + chunk_no++) { + u64 rs_offset; + ch = &arg_ary->wc_array[chunk_no].wc_target; + write_len = min(xfer_len, ch->rs_length); + + /* Prepare the reply chunk given the length actually + * written */ + rs_offset = get_unaligned(&(ch->rs_offset)); + svc_rdma_xdr_encode_array_chunk(res_ary, chunk_no, + ch->rs_handle, rs_offset, + write_len); + chunk_off = 0; + while (write_len) { + int this_write; + + this_write = min(write_len, max_write); + ret = send_write(xprt, rqstp, + ch->rs_handle, + rs_offset + chunk_off, + xdr_off, + this_write, + vec); + if (ret) { + dprintk("svcrdma: RDMA_WRITE failed, ret=%d\n", + ret); + return -EIO; + } + chunk_off += this_write; + xdr_off += this_write; + xfer_len -= this_write; + write_len -= this_write; + } + } + /* Update the req with the number of chunks actually used */ + svc_rdma_xdr_encode_reply_array(res_ary, chunk_no); + + return rqstp->rq_res.len; +} + +/* This function prepares the portion of the RPCRDMA message to be + * sent in the RDMA_SEND. This function is called after data sent via + * RDMA has already been transmitted. There are three cases: + * - The RPCRDMA header, RPC header, and payload are all sent in a + * single RDMA_SEND. This is the "inline" case. + * - The RPCRDMA header and some portion of the RPC header and data + * are sent via this RDMA_SEND and another portion of the data is + * sent via RDMA. + * - The RPCRDMA header [NOMSG] is sent in this RDMA_SEND and the RPC + * header and data are all transmitted via RDMA. + * In all three cases, this function prepares the RPCRDMA header in + * sge[0], the 'type' parameter indicates the type to place in the + * RPCRDMA header, and the 'byte_count' field indicates how much of + * the XDR to include in this RDMA_SEND. + */ +static int send_reply(struct svcxprt_rdma *rdma, + struct svc_rqst *rqstp, + struct page *page, + struct rpcrdma_msg *rdma_resp, + struct svc_rdma_op_ctxt *ctxt, + struct svc_rdma_req_map *vec, + int byte_count) +{ + struct ib_send_wr send_wr; + struct ib_send_wr inv_wr; + int sge_no; + int sge_bytes; + int page_no; + int ret; + + /* Post a recv buffer to handle another request. */ + ret = svc_rdma_post_recv(rdma); + if (ret) { + printk(KERN_INFO + "svcrdma: could not post a receive buffer, err=%d." + "Closing transport %p.\n", ret, rdma); + set_bit(XPT_CLOSE, &rdma->sc_xprt.xpt_flags); + svc_rdma_put_context(ctxt, 0); + return -ENOTCONN; + } + + /* Prepare the context */ + ctxt->pages[0] = page; + ctxt->count = 1; + ctxt->frmr = vec->frmr; + if (vec->frmr) + set_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); + else + clear_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags); + + /* Prepare the SGE for the RPCRDMA Header */ + ctxt->sge[0].addr = + ib_dma_map_page(rdma->sc_cm_id->device, + page, 0, PAGE_SIZE, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdma->sc_cm_id->device, ctxt->sge[0].addr)) + goto err; + atomic_inc(&rdma->sc_dma_used); + + ctxt->direction = DMA_TO_DEVICE; + + ctxt->sge[0].length = svc_rdma_xdr_get_reply_hdr_len(rdma_resp); + ctxt->sge[0].lkey = rdma->sc_dma_lkey; + + /* Determine how many of our SGE are to be transmitted */ + for (sge_no = 1; byte_count && sge_no < vec->count; sge_no++) { + sge_bytes = min_t(size_t, vec->sge[sge_no].iov_len, byte_count); + byte_count -= sge_bytes; + if (!vec->frmr) { + ctxt->sge[sge_no].addr = + ib_dma_map_single(rdma->sc_cm_id->device, + vec->sge[sge_no].iov_base, + sge_bytes, DMA_TO_DEVICE); + if (ib_dma_mapping_error(rdma->sc_cm_id->device, + ctxt->sge[sge_no].addr)) + goto err; + atomic_inc(&rdma->sc_dma_used); + ctxt->sge[sge_no].lkey = rdma->sc_dma_lkey; + } else { + ctxt->sge[sge_no].addr = (unsigned long) + vec->sge[sge_no].iov_base; + ctxt->sge[sge_no].lkey = vec->frmr->mr->lkey; + } + ctxt->sge[sge_no].length = sge_bytes; + } + BUG_ON(byte_count != 0); + + /* Save all respages in the ctxt and remove them from the + * respages array. They are our pages until the I/O + * completes. + */ + for (page_no = 0; page_no < rqstp->rq_resused; page_no++) { + ctxt->pages[page_no+1] = rqstp->rq_respages[page_no]; + ctxt->count++; + rqstp->rq_respages[page_no] = NULL; + /* + * If there are more pages than SGE, terminate SGE + * list so that svc_rdma_unmap_dma doesn't attempt to + * unmap garbage. + */ + if (page_no+1 >= sge_no) + ctxt->sge[page_no+1].length = 0; + } + BUG_ON(sge_no > rdma->sc_max_sge); + BUG_ON(sge_no > ctxt->count); + memset(&send_wr, 0, sizeof send_wr); + ctxt->wr_op = IB_WR_SEND; + send_wr.wr_id = (unsigned long)ctxt; + send_wr.sg_list = ctxt->sge; + send_wr.num_sge = sge_no; + send_wr.opcode = IB_WR_SEND; + send_wr.send_flags = IB_SEND_SIGNALED; + if (vec->frmr) { + /* Prepare INVALIDATE WR */ + memset(&inv_wr, 0, sizeof inv_wr); + inv_wr.opcode = IB_WR_LOCAL_INV; + inv_wr.send_flags = IB_SEND_SIGNALED; + inv_wr.ex.invalidate_rkey = + vec->frmr->mr->lkey; + send_wr.next = &inv_wr; + } + + ret = svc_rdma_send(rdma, &send_wr); + if (ret) + goto err; + + return 0; + + err: + svc_rdma_put_frmr(rdma, vec->frmr); + svc_rdma_put_context(ctxt, 1); + return -EIO; +} + +void svc_rdma_prep_reply_hdr(struct svc_rqst *rqstp) +{ +} + +/* + * Return the start of an xdr buffer. + */ +static void *xdr_start(struct xdr_buf *xdr) +{ + return xdr->head[0].iov_base - + (xdr->len - + xdr->page_len - + xdr->tail[0].iov_len - + xdr->head[0].iov_len); +} + +int svc_rdma_sendto(struct svc_rqst *rqstp) +{ + struct svc_xprt *xprt = rqstp->rq_xprt; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + struct rpcrdma_msg *rdma_argp; + struct rpcrdma_msg *rdma_resp; + struct rpcrdma_write_array *reply_ary; + enum rpcrdma_proc reply_type; + int ret; + int inline_bytes; + struct page *res_page; + struct svc_rdma_op_ctxt *ctxt; + struct svc_rdma_req_map *vec; + + dprintk("svcrdma: sending response for rqstp=%p\n", rqstp); + + /* Get the RDMA request header. */ + rdma_argp = xdr_start(&rqstp->rq_arg); + + /* Build an req vec for the XDR */ + ctxt = svc_rdma_get_context(rdma); + ctxt->direction = DMA_TO_DEVICE; + vec = svc_rdma_get_req_map(); + ret = map_xdr(rdma, &rqstp->rq_res, vec); + if (ret) + goto err0; + inline_bytes = rqstp->rq_res.len; + + /* Create the RDMA response header */ + res_page = svc_rdma_get_page(); + rdma_resp = page_address(res_page); + reply_ary = svc_rdma_get_reply_array(rdma_argp); + if (reply_ary) + reply_type = RDMA_NOMSG; + else + reply_type = RDMA_MSG; + svc_rdma_xdr_encode_reply_header(rdma, rdma_argp, + rdma_resp, reply_type); + + /* Send any write-chunk data and build resp write-list */ + ret = send_write_chunks(rdma, rdma_argp, rdma_resp, + rqstp, vec); + if (ret < 0) { + printk(KERN_ERR "svcrdma: failed to send write chunks, rc=%d\n", + ret); + goto err1; + } + inline_bytes -= ret; + + /* Send any reply-list data and update resp reply-list */ + ret = send_reply_chunks(rdma, rdma_argp, rdma_resp, + rqstp, vec); + if (ret < 0) { + printk(KERN_ERR "svcrdma: failed to send reply chunks, rc=%d\n", + ret); + goto err1; + } + inline_bytes -= ret; + + ret = send_reply(rdma, rqstp, res_page, rdma_resp, ctxt, vec, + inline_bytes); + svc_rdma_put_req_map(vec); + dprintk("svcrdma: send_reply returns %d\n", ret); + return ret; + + err1: + put_page(res_page); + err0: + svc_rdma_put_req_map(vec); + svc_rdma_put_context(ctxt, 0); + return ret; +} diff --git a/net/sunrpc/xprtrdma/svc_rdma_transport.c b/net/sunrpc/xprtrdma/svc_rdma_transport.c new file mode 100644 index 0000000..6fb493c --- /dev/null +++ b/net/sunrpc/xprtrdma/svc_rdma_transport.c @@ -0,0 +1,1350 @@ +/* + * Copyright (c) 2005-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * Author: Tom Tucker <tom@opengridcomputing.com> + */ + +#include <linux/sunrpc/svc_xprt.h> +#include <linux/sunrpc/debug.h> +#include <linux/sunrpc/rpc_rdma.h> +#include <linux/spinlock.h> +#include <rdma/ib_verbs.h> +#include <rdma/rdma_cm.h> +#include <linux/sunrpc/svc_rdma.h> + +#define RPCDBG_FACILITY RPCDBG_SVCXPRT + +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct sockaddr *sa, int salen, + int flags); +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt); +static void svc_rdma_release_rqst(struct svc_rqst *); +static void dto_tasklet_func(unsigned long data); +static void svc_rdma_detach(struct svc_xprt *xprt); +static void svc_rdma_free(struct svc_xprt *xprt); +static int svc_rdma_has_wspace(struct svc_xprt *xprt); +static void rq_cq_reap(struct svcxprt_rdma *xprt); +static void sq_cq_reap(struct svcxprt_rdma *xprt); + +DECLARE_TASKLET(dto_tasklet, dto_tasklet_func, 0UL); +static DEFINE_SPINLOCK(dto_lock); +static LIST_HEAD(dto_xprt_q); + +static struct svc_xprt_ops svc_rdma_ops = { + .xpo_create = svc_rdma_create, + .xpo_recvfrom = svc_rdma_recvfrom, + .xpo_sendto = svc_rdma_sendto, + .xpo_release_rqst = svc_rdma_release_rqst, + .xpo_detach = svc_rdma_detach, + .xpo_free = svc_rdma_free, + .xpo_prep_reply_hdr = svc_rdma_prep_reply_hdr, + .xpo_has_wspace = svc_rdma_has_wspace, + .xpo_accept = svc_rdma_accept, +}; + +struct svc_xprt_class svc_rdma_class = { + .xcl_name = "rdma", + .xcl_owner = THIS_MODULE, + .xcl_ops = &svc_rdma_ops, + .xcl_max_payload = RPCSVC_MAXPAYLOAD_TCP, +}; + +/* WR context cache. Created in svc_rdma.c */ +extern struct kmem_cache *svc_rdma_ctxt_cachep; + +struct svc_rdma_op_ctxt *svc_rdma_get_context(struct svcxprt_rdma *xprt) +{ + struct svc_rdma_op_ctxt *ctxt; + + while (1) { + ctxt = kmem_cache_alloc(svc_rdma_ctxt_cachep, GFP_KERNEL); + if (ctxt) + break; + schedule_timeout_uninterruptible(msecs_to_jiffies(500)); + } + ctxt->xprt = xprt; + INIT_LIST_HEAD(&ctxt->dto_q); + ctxt->count = 0; + ctxt->frmr = NULL; + atomic_inc(&xprt->sc_ctxt_used); + return ctxt; +} + +void svc_rdma_unmap_dma(struct svc_rdma_op_ctxt *ctxt) +{ + struct svcxprt_rdma *xprt = ctxt->xprt; + int i; + for (i = 0; i < ctxt->count && ctxt->sge[i].length; i++) { + /* + * Unmap the DMA addr in the SGE if the lkey matches + * the sc_dma_lkey, otherwise, ignore it since it is + * an FRMR lkey and will be unmapped later when the + * last WR that uses it completes. + */ + if (ctxt->sge[i].lkey == xprt->sc_dma_lkey) { + atomic_dec(&xprt->sc_dma_used); + ib_dma_unmap_single(xprt->sc_cm_id->device, + ctxt->sge[i].addr, + ctxt->sge[i].length, + ctxt->direction); + } + } +} + +void svc_rdma_put_context(struct svc_rdma_op_ctxt *ctxt, int free_pages) +{ + struct svcxprt_rdma *xprt; + int i; + + BUG_ON(!ctxt); + xprt = ctxt->xprt; + if (free_pages) + for (i = 0; i < ctxt->count; i++) + put_page(ctxt->pages[i]); + + kmem_cache_free(svc_rdma_ctxt_cachep, ctxt); + atomic_dec(&xprt->sc_ctxt_used); +} + +/* Temporary NFS request map cache. Created in svc_rdma.c */ +extern struct kmem_cache *svc_rdma_map_cachep; + +/* + * Temporary NFS req mappings are shared across all transport + * instances. These are short lived and should be bounded by the number + * of concurrent server threads * depth of the SQ. + */ +struct svc_rdma_req_map *svc_rdma_get_req_map(void) +{ + struct svc_rdma_req_map *map; + while (1) { + map = kmem_cache_alloc(svc_rdma_map_cachep, GFP_KERNEL); + if (map) + break; + schedule_timeout_uninterruptible(msecs_to_jiffies(500)); + } + map->count = 0; + map->frmr = NULL; + return map; +} + +void svc_rdma_put_req_map(struct svc_rdma_req_map *map) +{ + kmem_cache_free(svc_rdma_map_cachep, map); +} + +/* ib_cq event handler */ +static void cq_event_handler(struct ib_event *event, void *context) +{ + struct svc_xprt *xprt = context; + dprintk("svcrdma: received CQ event id=%d, context=%p\n", + event->event, context); + set_bit(XPT_CLOSE, &xprt->xpt_flags); +} + +/* QP event handler */ +static void qp_event_handler(struct ib_event *event, void *context) +{ + struct svc_xprt *xprt = context; + + switch (event->event) { + /* These are considered benign events */ + case IB_EVENT_PATH_MIG: + case IB_EVENT_COMM_EST: + case IB_EVENT_SQ_DRAINED: + case IB_EVENT_QP_LAST_WQE_REACHED: + dprintk("svcrdma: QP event %d received for QP=%p\n", + event->event, event->element.qp); + break; + /* These are considered fatal events */ + case IB_EVENT_PATH_MIG_ERR: + case IB_EVENT_QP_FATAL: + case IB_EVENT_QP_REQ_ERR: + case IB_EVENT_QP_ACCESS_ERR: + case IB_EVENT_DEVICE_FATAL: + default: + dprintk("svcrdma: QP ERROR event %d received for QP=%p, " + "closing transport\n", + event->event, event->element.qp); + set_bit(XPT_CLOSE, &xprt->xpt_flags); + break; + } +} + +/* + * Data Transfer Operation Tasklet + * + * Walks a list of transports with I/O pending, removing entries as + * they are added to the server's I/O pending list. Two bits indicate + * if SQ, RQ, or both have I/O pending. The dto_lock is an irqsave + * spinlock that serializes access to the transport list with the RQ + * and SQ interrupt handlers. + */ +static void dto_tasklet_func(unsigned long data) +{ + struct svcxprt_rdma *xprt; + unsigned long flags; + + spin_lock_irqsave(&dto_lock, flags); + while (!list_empty(&dto_xprt_q)) { + xprt = list_entry(dto_xprt_q.next, + struct svcxprt_rdma, sc_dto_q); + list_del_init(&xprt->sc_dto_q); + spin_unlock_irqrestore(&dto_lock, flags); + + rq_cq_reap(xprt); + sq_cq_reap(xprt); + + svc_xprt_put(&xprt->sc_xprt); + spin_lock_irqsave(&dto_lock, flags); + } + spin_unlock_irqrestore(&dto_lock, flags); +} + +/* + * Receive Queue Completion Handler + * + * Since an RQ completion handler is called on interrupt context, we + * need to defer the handling of the I/O to a tasklet + */ +static void rq_comp_handler(struct ib_cq *cq, void *cq_context) +{ + struct svcxprt_rdma *xprt = cq_context; + unsigned long flags; + + /* Guard against unconditional flush call for destroyed QP */ + if (atomic_read(&xprt->sc_xprt.xpt_ref.refcount)==0) + return; + + /* + * Set the bit regardless of whether or not it's on the list + * because it may be on the list already due to an SQ + * completion. + */ + set_bit(RDMAXPRT_RQ_PENDING, &xprt->sc_flags); + + /* + * If this transport is not already on the DTO transport queue, + * add it + */ + spin_lock_irqsave(&dto_lock, flags); + if (list_empty(&xprt->sc_dto_q)) { + svc_xprt_get(&xprt->sc_xprt); + list_add_tail(&xprt->sc_dto_q, &dto_xprt_q); + } + spin_unlock_irqrestore(&dto_lock, flags); + + /* Tasklet does all the work to avoid irqsave locks. */ + tasklet_schedule(&dto_tasklet); +} + +/* + * rq_cq_reap - Process the RQ CQ. + * + * Take all completing WC off the CQE and enqueue the associated DTO + * context on the dto_q for the transport. + * + * Note that caller must hold a transport reference. + */ +static void rq_cq_reap(struct svcxprt_rdma *xprt) +{ + int ret; + struct ib_wc wc; + struct svc_rdma_op_ctxt *ctxt = NULL; + + if (!test_and_clear_bit(RDMAXPRT_RQ_PENDING, &xprt->sc_flags)) + return; + + ib_req_notify_cq(xprt->sc_rq_cq, IB_CQ_NEXT_COMP); + atomic_inc(&rdma_stat_rq_poll); + + while ((ret = ib_poll_cq(xprt->sc_rq_cq, 1, &wc)) > 0) { + ctxt = (struct svc_rdma_op_ctxt *)(unsigned long)wc.wr_id; + ctxt->wc_status = wc.status; + ctxt->byte_len = wc.byte_len; + svc_rdma_unmap_dma(ctxt); + if (wc.status != IB_WC_SUCCESS) { + /* Close the transport */ + dprintk("svcrdma: transport closing putting ctxt %p\n", ctxt); + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + svc_rdma_put_context(ctxt, 1); + svc_xprt_put(&xprt->sc_xprt); + continue; + } + spin_lock_bh(&xprt->sc_rq_dto_lock); + list_add_tail(&ctxt->dto_q, &xprt->sc_rq_dto_q); + spin_unlock_bh(&xprt->sc_rq_dto_lock); + svc_xprt_put(&xprt->sc_xprt); + } + + if (ctxt) + atomic_inc(&rdma_stat_rq_prod); + + set_bit(XPT_DATA, &xprt->sc_xprt.xpt_flags); + /* + * If data arrived before established event, + * don't enqueue. This defers RPC I/O until the + * RDMA connection is complete. + */ + if (!test_bit(RDMAXPRT_CONN_PENDING, &xprt->sc_flags)) + svc_xprt_enqueue(&xprt->sc_xprt); +} + +/* + * Processs a completion context + */ +static void process_context(struct svcxprt_rdma *xprt, + struct svc_rdma_op_ctxt *ctxt) +{ + svc_rdma_unmap_dma(ctxt); + + switch (ctxt->wr_op) { + case IB_WR_SEND: + if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) + svc_rdma_put_frmr(xprt, ctxt->frmr); + svc_rdma_put_context(ctxt, 1); + break; + + case IB_WR_RDMA_WRITE: + svc_rdma_put_context(ctxt, 0); + break; + + case IB_WR_RDMA_READ: + case IB_WR_RDMA_READ_WITH_INV: + if (test_bit(RDMACTXT_F_LAST_CTXT, &ctxt->flags)) { + struct svc_rdma_op_ctxt *read_hdr = ctxt->read_hdr; + BUG_ON(!read_hdr); + if (test_bit(RDMACTXT_F_FAST_UNREG, &ctxt->flags)) + svc_rdma_put_frmr(xprt, ctxt->frmr); + spin_lock_bh(&xprt->sc_rq_dto_lock); + set_bit(XPT_DATA, &xprt->sc_xprt.xpt_flags); + list_add_tail(&read_hdr->dto_q, + &xprt->sc_read_complete_q); + spin_unlock_bh(&xprt->sc_rq_dto_lock); + svc_xprt_enqueue(&xprt->sc_xprt); + } + svc_rdma_put_context(ctxt, 0); + break; + + default: + printk(KERN_ERR "svcrdma: unexpected completion type, " + "opcode=%d\n", + ctxt->wr_op); + break; + } +} + +/* + * Send Queue Completion Handler - potentially called on interrupt context. + * + * Note that caller must hold a transport reference. + */ +static void sq_cq_reap(struct svcxprt_rdma *xprt) +{ + struct svc_rdma_op_ctxt *ctxt = NULL; + struct ib_wc wc; + struct ib_cq *cq = xprt->sc_sq_cq; + int ret; + + if (!test_and_clear_bit(RDMAXPRT_SQ_PENDING, &xprt->sc_flags)) + return; + + ib_req_notify_cq(xprt->sc_sq_cq, IB_CQ_NEXT_COMP); + atomic_inc(&rdma_stat_sq_poll); + while ((ret = ib_poll_cq(cq, 1, &wc)) > 0) { + if (wc.status != IB_WC_SUCCESS) + /* Close the transport */ + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + + /* Decrement used SQ WR count */ + atomic_dec(&xprt->sc_sq_count); + wake_up(&xprt->sc_send_wait); + + ctxt = (struct svc_rdma_op_ctxt *)(unsigned long)wc.wr_id; + if (ctxt) + process_context(xprt, ctxt); + + svc_xprt_put(&xprt->sc_xprt); + } + + if (ctxt) + atomic_inc(&rdma_stat_sq_prod); +} + +static void sq_comp_handler(struct ib_cq *cq, void *cq_context) +{ + struct svcxprt_rdma *xprt = cq_context; + unsigned long flags; + + /* Guard against unconditional flush call for destroyed QP */ + if (atomic_read(&xprt->sc_xprt.xpt_ref.refcount)==0) + return; + + /* + * Set the bit regardless of whether or not it's on the list + * because it may be on the list already due to an RQ + * completion. + */ + set_bit(RDMAXPRT_SQ_PENDING, &xprt->sc_flags); + + /* + * If this transport is not already on the DTO transport queue, + * add it + */ + spin_lock_irqsave(&dto_lock, flags); + if (list_empty(&xprt->sc_dto_q)) { + svc_xprt_get(&xprt->sc_xprt); + list_add_tail(&xprt->sc_dto_q, &dto_xprt_q); + } + spin_unlock_irqrestore(&dto_lock, flags); + + /* Tasklet does all the work to avoid irqsave locks. */ + tasklet_schedule(&dto_tasklet); +} + +static struct svcxprt_rdma *rdma_create_xprt(struct svc_serv *serv, + int listener) +{ + struct svcxprt_rdma *cma_xprt = kzalloc(sizeof *cma_xprt, GFP_KERNEL); + + if (!cma_xprt) + return NULL; + svc_xprt_init(&svc_rdma_class, &cma_xprt->sc_xprt, serv); + INIT_LIST_HEAD(&cma_xprt->sc_accept_q); + INIT_LIST_HEAD(&cma_xprt->sc_dto_q); + INIT_LIST_HEAD(&cma_xprt->sc_rq_dto_q); + INIT_LIST_HEAD(&cma_xprt->sc_read_complete_q); + INIT_LIST_HEAD(&cma_xprt->sc_frmr_q); + init_waitqueue_head(&cma_xprt->sc_send_wait); + + spin_lock_init(&cma_xprt->sc_lock); + spin_lock_init(&cma_xprt->sc_rq_dto_lock); + spin_lock_init(&cma_xprt->sc_frmr_q_lock); + + cma_xprt->sc_ord = svcrdma_ord; + + cma_xprt->sc_max_req_size = svcrdma_max_req_size; + cma_xprt->sc_max_requests = svcrdma_max_requests; + cma_xprt->sc_sq_depth = svcrdma_max_requests * RPCRDMA_SQ_DEPTH_MULT; + atomic_set(&cma_xprt->sc_sq_count, 0); + atomic_set(&cma_xprt->sc_ctxt_used, 0); + + if (listener) + set_bit(XPT_LISTENER, &cma_xprt->sc_xprt.xpt_flags); + + return cma_xprt; +} + +struct page *svc_rdma_get_page(void) +{ + struct page *page; + + while ((page = alloc_page(GFP_KERNEL)) == NULL) { + /* If we can't get memory, wait a bit and try again */ + printk(KERN_INFO "svcrdma: out of memory...retrying in 1000 " + "jiffies.\n"); + schedule_timeout_uninterruptible(msecs_to_jiffies(1000)); + } + return page; +} + +int svc_rdma_post_recv(struct svcxprt_rdma *xprt) +{ + struct ib_recv_wr recv_wr, *bad_recv_wr; + struct svc_rdma_op_ctxt *ctxt; + struct page *page; + dma_addr_t pa; + int sge_no; + int buflen; + int ret; + + ctxt = svc_rdma_get_context(xprt); + buflen = 0; + ctxt->direction = DMA_FROM_DEVICE; + for (sge_no = 0; buflen < xprt->sc_max_req_size; sge_no++) { + BUG_ON(sge_no >= xprt->sc_max_sge); + page = svc_rdma_get_page(); + ctxt->pages[sge_no] = page; + pa = ib_dma_map_page(xprt->sc_cm_id->device, + page, 0, PAGE_SIZE, + DMA_FROM_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, pa)) + goto err_put_ctxt; + atomic_inc(&xprt->sc_dma_used); + ctxt->sge[sge_no].addr = pa; + ctxt->sge[sge_no].length = PAGE_SIZE; + ctxt->sge[sge_no].lkey = xprt->sc_dma_lkey; + buflen += PAGE_SIZE; + } + ctxt->count = sge_no; + recv_wr.next = NULL; + recv_wr.sg_list = &ctxt->sge[0]; + recv_wr.num_sge = ctxt->count; + recv_wr.wr_id = (u64)(unsigned long)ctxt; + + svc_xprt_get(&xprt->sc_xprt); + ret = ib_post_recv(xprt->sc_qp, &recv_wr, &bad_recv_wr); + if (ret) { + svc_xprt_put(&xprt->sc_xprt); + svc_rdma_put_context(ctxt, 1); + } + return ret; + + err_put_ctxt: + svc_rdma_put_context(ctxt, 1); + return -ENOMEM; +} + +/* + * This function handles the CONNECT_REQUEST event on a listening + * endpoint. It is passed the cma_id for the _new_ connection. The context in + * this cma_id is inherited from the listening cma_id and is the svc_xprt + * structure for the listening endpoint. + * + * This function creates a new xprt for the new connection and enqueues it on + * the accept queue for the listent xprt. When the listen thread is kicked, it + * will call the recvfrom method on the listen xprt which will accept the new + * connection. + */ +static void handle_connect_req(struct rdma_cm_id *new_cma_id, size_t client_ird) +{ + struct svcxprt_rdma *listen_xprt = new_cma_id->context; + struct svcxprt_rdma *newxprt; + struct sockaddr *sa; + + /* Create a new transport */ + newxprt = rdma_create_xprt(listen_xprt->sc_xprt.xpt_server, 0); + if (!newxprt) { + dprintk("svcrdma: failed to create new transport\n"); + return; + } + newxprt->sc_cm_id = new_cma_id; + new_cma_id->context = newxprt; + dprintk("svcrdma: Creating newxprt=%p, cm_id=%p, listenxprt=%p\n", + newxprt, newxprt->sc_cm_id, listen_xprt); + + /* Save client advertised inbound read limit for use later in accept. */ + newxprt->sc_ord = client_ird; + + /* Set the local and remote addresses in the transport */ + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.dst_addr; + svc_xprt_set_remote(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + sa = (struct sockaddr *)&newxprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&newxprt->sc_xprt, sa, svc_addr_len(sa)); + + /* + * Enqueue the new transport on the accept queue of the listening + * transport + */ + spin_lock_bh(&listen_xprt->sc_lock); + list_add_tail(&newxprt->sc_accept_q, &listen_xprt->sc_accept_q); + spin_unlock_bh(&listen_xprt->sc_lock); + + /* + * Can't use svc_xprt_received here because we are not on a + * rqstp thread + */ + set_bit(XPT_CONN, &listen_xprt->sc_xprt.xpt_flags); + svc_xprt_enqueue(&listen_xprt->sc_xprt); +} + +/* + * Handles events generated on the listening endpoint. These events will be + * either be incoming connect requests or adapter removal events. + */ +static int rdma_listen_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + struct svcxprt_rdma *xprt = cma_id->context; + int ret = 0; + + switch (event->event) { + case RDMA_CM_EVENT_CONNECT_REQUEST: + dprintk("svcrdma: Connect request on cma_id=%p, xprt = %p, " + "event=%d\n", cma_id, cma_id->context, event->event); + handle_connect_req(cma_id, + event->param.conn.initiator_depth); + break; + + case RDMA_CM_EVENT_ESTABLISHED: + /* Accept complete */ + dprintk("svcrdma: Connection completed on LISTEN xprt=%p, " + "cm_id=%p\n", xprt, cma_id); + break; + + case RDMA_CM_EVENT_DEVICE_REMOVAL: + dprintk("svcrdma: Device removal xprt=%p, cm_id=%p\n", + xprt, cma_id); + if (xprt) + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + break; + + default: + dprintk("svcrdma: Unexpected event on listening endpoint %p, " + "event=%d\n", cma_id, event->event); + break; + } + + return ret; +} + +static int rdma_cma_handler(struct rdma_cm_id *cma_id, + struct rdma_cm_event *event) +{ + struct svc_xprt *xprt = cma_id->context; + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + switch (event->event) { + case RDMA_CM_EVENT_ESTABLISHED: + /* Accept complete */ + svc_xprt_get(xprt); + dprintk("svcrdma: Connection completed on DTO xprt=%p, " + "cm_id=%p\n", xprt, cma_id); + clear_bit(RDMAXPRT_CONN_PENDING, &rdma->sc_flags); + svc_xprt_enqueue(xprt); + break; + case RDMA_CM_EVENT_DISCONNECTED: + dprintk("svcrdma: Disconnect on DTO xprt=%p, cm_id=%p\n", + xprt, cma_id); + if (xprt) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + svc_xprt_put(xprt); + } + break; + case RDMA_CM_EVENT_DEVICE_REMOVAL: + dprintk("svcrdma: Device removal cma_id=%p, xprt = %p, " + "event=%d\n", cma_id, xprt, event->event); + if (xprt) { + set_bit(XPT_CLOSE, &xprt->xpt_flags); + svc_xprt_enqueue(xprt); + } + break; + default: + dprintk("svcrdma: Unexpected event on DTO endpoint %p, " + "event=%d\n", cma_id, event->event); + break; + } + return 0; +} + +/* + * Create a listening RDMA service endpoint. + */ +static struct svc_xprt *svc_rdma_create(struct svc_serv *serv, + struct sockaddr *sa, int salen, + int flags) +{ + struct rdma_cm_id *listen_id; + struct svcxprt_rdma *cma_xprt; + struct svc_xprt *xprt; + int ret; + + dprintk("svcrdma: Creating RDMA socket\n"); + + cma_xprt = rdma_create_xprt(serv, 1); + if (!cma_xprt) + return ERR_PTR(-ENOMEM); + xprt = &cma_xprt->sc_xprt; + + listen_id = rdma_create_id(rdma_listen_handler, cma_xprt, RDMA_PS_TCP); + if (IS_ERR(listen_id)) { + ret = PTR_ERR(listen_id); + dprintk("svcrdma: rdma_create_id failed = %d\n", ret); + goto err0; + } + + ret = rdma_bind_addr(listen_id, sa); + if (ret) { + dprintk("svcrdma: rdma_bind_addr failed = %d\n", ret); + goto err1; + } + cma_xprt->sc_cm_id = listen_id; + + ret = rdma_listen(listen_id, RPCRDMA_LISTEN_BACKLOG); + if (ret) { + dprintk("svcrdma: rdma_listen failed = %d\n", ret); + goto err1; + } + + /* + * We need to use the address from the cm_id in case the + * caller specified 0 for the port number. + */ + sa = (struct sockaddr *)&cma_xprt->sc_cm_id->route.addr.src_addr; + svc_xprt_set_local(&cma_xprt->sc_xprt, sa, salen); + + return &cma_xprt->sc_xprt; + + err1: + rdma_destroy_id(listen_id); + err0: + kfree(cma_xprt); + return ERR_PTR(ret); +} + +static struct svc_rdma_fastreg_mr *rdma_alloc_frmr(struct svcxprt_rdma *xprt) +{ + struct ib_mr *mr; + struct ib_fast_reg_page_list *pl; + struct svc_rdma_fastreg_mr *frmr; + + frmr = kmalloc(sizeof(*frmr), GFP_KERNEL); + if (!frmr) + goto err; + + mr = ib_alloc_fast_reg_mr(xprt->sc_pd, RPCSVC_MAXPAGES); + if (!mr) + goto err_free_frmr; + + pl = ib_alloc_fast_reg_page_list(xprt->sc_cm_id->device, + RPCSVC_MAXPAGES); + if (!pl) + goto err_free_mr; + + frmr->mr = mr; + frmr->page_list = pl; + INIT_LIST_HEAD(&frmr->frmr_list); + return frmr; + + err_free_mr: + ib_dereg_mr(mr); + err_free_frmr: + kfree(frmr); + err: + return ERR_PTR(-ENOMEM); +} + +static void rdma_dealloc_frmr_q(struct svcxprt_rdma *xprt) +{ + struct svc_rdma_fastreg_mr *frmr; + + while (!list_empty(&xprt->sc_frmr_q)) { + frmr = list_entry(xprt->sc_frmr_q.next, + struct svc_rdma_fastreg_mr, frmr_list); + list_del_init(&frmr->frmr_list); + ib_dereg_mr(frmr->mr); + ib_free_fast_reg_page_list(frmr->page_list); + kfree(frmr); + } +} + +struct svc_rdma_fastreg_mr *svc_rdma_get_frmr(struct svcxprt_rdma *rdma) +{ + struct svc_rdma_fastreg_mr *frmr = NULL; + + spin_lock_bh(&rdma->sc_frmr_q_lock); + if (!list_empty(&rdma->sc_frmr_q)) { + frmr = list_entry(rdma->sc_frmr_q.next, + struct svc_rdma_fastreg_mr, frmr_list); + list_del_init(&frmr->frmr_list); + frmr->map_len = 0; + frmr->page_list_len = 0; + } + spin_unlock_bh(&rdma->sc_frmr_q_lock); + if (frmr) + return frmr; + + return rdma_alloc_frmr(rdma); +} + +static void frmr_unmap_dma(struct svcxprt_rdma *xprt, + struct svc_rdma_fastreg_mr *frmr) +{ + int page_no; + for (page_no = 0; page_no < frmr->page_list_len; page_no++) { + dma_addr_t addr = frmr->page_list->page_list[page_no]; + if (ib_dma_mapping_error(frmr->mr->device, addr)) + continue; + atomic_dec(&xprt->sc_dma_used); + ib_dma_unmap_single(frmr->mr->device, addr, PAGE_SIZE, + frmr->direction); + } +} + +void svc_rdma_put_frmr(struct svcxprt_rdma *rdma, + struct svc_rdma_fastreg_mr *frmr) +{ + if (frmr) { + frmr_unmap_dma(rdma, frmr); + spin_lock_bh(&rdma->sc_frmr_q_lock); + BUG_ON(!list_empty(&frmr->frmr_list)); + list_add(&frmr->frmr_list, &rdma->sc_frmr_q); + spin_unlock_bh(&rdma->sc_frmr_q_lock); + } +} + +/* + * This is the xpo_recvfrom function for listening endpoints. Its + * purpose is to accept incoming connections. The CMA callback handler + * has already created a new transport and attached it to the new CMA + * ID. + * + * There is a queue of pending connections hung on the listening + * transport. This queue contains the new svc_xprt structure. This + * function takes svc_xprt structures off the accept_q and completes + * the connection. + */ +static struct svc_xprt *svc_rdma_accept(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *listen_rdma; + struct svcxprt_rdma *newxprt = NULL; + struct rdma_conn_param conn_param; + struct ib_qp_init_attr qp_attr; + struct ib_device_attr devattr; + int dma_mr_acc; + int need_dma_mr; + int ret; + int i; + + listen_rdma = container_of(xprt, struct svcxprt_rdma, sc_xprt); + clear_bit(XPT_CONN, &xprt->xpt_flags); + /* Get the next entry off the accept list */ + spin_lock_bh(&listen_rdma->sc_lock); + if (!list_empty(&listen_rdma->sc_accept_q)) { + newxprt = list_entry(listen_rdma->sc_accept_q.next, + struct svcxprt_rdma, sc_accept_q); + list_del_init(&newxprt->sc_accept_q); + } + if (!list_empty(&listen_rdma->sc_accept_q)) + set_bit(XPT_CONN, &listen_rdma->sc_xprt.xpt_flags); + spin_unlock_bh(&listen_rdma->sc_lock); + if (!newxprt) + return NULL; + + dprintk("svcrdma: newxprt from accept queue = %p, cm_id=%p\n", + newxprt, newxprt->sc_cm_id); + + ret = ib_query_device(newxprt->sc_cm_id->device, &devattr); + if (ret) { + dprintk("svcrdma: could not query device attributes on " + "device %p, rc=%d\n", newxprt->sc_cm_id->device, ret); + goto errout; + } + + /* Qualify the transport resource defaults with the + * capabilities of this particular device */ + newxprt->sc_max_sge = min((size_t)devattr.max_sge, + (size_t)RPCSVC_MAXPAGES); + newxprt->sc_max_requests = min((size_t)devattr.max_qp_wr, + (size_t)svcrdma_max_requests); + newxprt->sc_sq_depth = RPCRDMA_SQ_DEPTH_MULT * newxprt->sc_max_requests; + + /* + * Limit ORD based on client limit, local device limit, and + * configured svcrdma limit. + */ + newxprt->sc_ord = min_t(size_t, devattr.max_qp_rd_atom, newxprt->sc_ord); + newxprt->sc_ord = min_t(size_t, svcrdma_ord, newxprt->sc_ord); + + newxprt->sc_pd = ib_alloc_pd(newxprt->sc_cm_id->device); + if (IS_ERR(newxprt->sc_pd)) { + dprintk("svcrdma: error creating PD for connect request\n"); + goto errout; + } + newxprt->sc_sq_cq = ib_create_cq(newxprt->sc_cm_id->device, + sq_comp_handler, + cq_event_handler, + newxprt, + newxprt->sc_sq_depth, + 0); + if (IS_ERR(newxprt->sc_sq_cq)) { + dprintk("svcrdma: error creating SQ CQ for connect request\n"); + goto errout; + } + newxprt->sc_rq_cq = ib_create_cq(newxprt->sc_cm_id->device, + rq_comp_handler, + cq_event_handler, + newxprt, + newxprt->sc_max_requests, + 0); + if (IS_ERR(newxprt->sc_rq_cq)) { + dprintk("svcrdma: error creating RQ CQ for connect request\n"); + goto errout; + } + + memset(&qp_attr, 0, sizeof qp_attr); + qp_attr.event_handler = qp_event_handler; + qp_attr.qp_context = &newxprt->sc_xprt; + qp_attr.cap.max_send_wr = newxprt->sc_sq_depth; + qp_attr.cap.max_recv_wr = newxprt->sc_max_requests; + qp_attr.cap.max_send_sge = newxprt->sc_max_sge; + qp_attr.cap.max_recv_sge = newxprt->sc_max_sge; + qp_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + qp_attr.qp_type = IB_QPT_RC; + qp_attr.send_cq = newxprt->sc_sq_cq; + qp_attr.recv_cq = newxprt->sc_rq_cq; + dprintk("svcrdma: newxprt->sc_cm_id=%p, newxprt->sc_pd=%p\n" + " cm_id->device=%p, sc_pd->device=%p\n" + " cap.max_send_wr = %d\n" + " cap.max_recv_wr = %d\n" + " cap.max_send_sge = %d\n" + " cap.max_recv_sge = %d\n", + newxprt->sc_cm_id, newxprt->sc_pd, + newxprt->sc_cm_id->device, newxprt->sc_pd->device, + qp_attr.cap.max_send_wr, + qp_attr.cap.max_recv_wr, + qp_attr.cap.max_send_sge, + qp_attr.cap.max_recv_sge); + + ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, &qp_attr); + if (ret) { + /* + * XXX: This is a hack. We need a xx_request_qp interface + * that will adjust the qp_attr's with a best-effort + * number + */ + qp_attr.cap.max_send_sge -= 2; + qp_attr.cap.max_recv_sge -= 2; + ret = rdma_create_qp(newxprt->sc_cm_id, newxprt->sc_pd, + &qp_attr); + if (ret) { + dprintk("svcrdma: failed to create QP, ret=%d\n", ret); + goto errout; + } + newxprt->sc_max_sge = qp_attr.cap.max_send_sge; + newxprt->sc_max_sge = qp_attr.cap.max_recv_sge; + newxprt->sc_sq_depth = qp_attr.cap.max_send_wr; + newxprt->sc_max_requests = qp_attr.cap.max_recv_wr; + } + newxprt->sc_qp = newxprt->sc_cm_id->qp; + + /* + * Use the most secure set of MR resources based on the + * transport type and available memory management features in + * the device. Here's the table implemented below: + * + * Fast Global DMA Remote WR + * Reg LKEY MR Access + * Sup'd Sup'd Needed Needed + * + * IWARP N N Y Y + * N Y Y Y + * Y N Y N + * Y Y N - + * + * IB N N Y N + * N Y N - + * Y N Y N + * Y Y N - + * + * NB: iWARP requires remote write access for the data sink + * of an RDMA_READ. IB does not. + */ + if (devattr.device_cap_flags & IB_DEVICE_MEM_MGT_EXTENSIONS) { + newxprt->sc_frmr_pg_list_len = + devattr.max_fast_reg_page_list_len; + newxprt->sc_dev_caps |= SVCRDMA_DEVCAP_FAST_REG; + } + + /* + * Determine if a DMA MR is required and if so, what privs are required + */ + switch (rdma_node_get_transport(newxprt->sc_cm_id->device->node_type)) { + case RDMA_TRANSPORT_IWARP: + newxprt->sc_dev_caps |= SVCRDMA_DEVCAP_READ_W_INV; + if (!(newxprt->sc_dev_caps & SVCRDMA_DEVCAP_FAST_REG)) { + need_dma_mr = 1; + dma_mr_acc = + (IB_ACCESS_LOCAL_WRITE | + IB_ACCESS_REMOTE_WRITE); + } else if (!(devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY)) { + need_dma_mr = 1; + dma_mr_acc = IB_ACCESS_LOCAL_WRITE; + } else + need_dma_mr = 0; + break; + case RDMA_TRANSPORT_IB: + if (!(devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY)) { + need_dma_mr = 1; + dma_mr_acc = IB_ACCESS_LOCAL_WRITE; + } else + need_dma_mr = 0; + break; + default: + goto errout; + } + + /* Create the DMA MR if needed, otherwise, use the DMA LKEY */ + if (need_dma_mr) { + /* Register all of physical memory */ + newxprt->sc_phys_mr = + ib_get_dma_mr(newxprt->sc_pd, dma_mr_acc); + if (IS_ERR(newxprt->sc_phys_mr)) { + dprintk("svcrdma: Failed to create DMA MR ret=%d\n", + ret); + goto errout; + } + newxprt->sc_dma_lkey = newxprt->sc_phys_mr->lkey; + } else + newxprt->sc_dma_lkey = + newxprt->sc_cm_id->device->local_dma_lkey; + + /* Post receive buffers */ + for (i = 0; i < newxprt->sc_max_requests; i++) { + ret = svc_rdma_post_recv(newxprt); + if (ret) { + dprintk("svcrdma: failure posting receive buffers\n"); + goto errout; + } + } + + /* Swap out the handler */ + newxprt->sc_cm_id->event_handler = rdma_cma_handler; + + /* + * Arm the CQs for the SQ and RQ before accepting so we can't + * miss the first message + */ + ib_req_notify_cq(newxprt->sc_sq_cq, IB_CQ_NEXT_COMP); + ib_req_notify_cq(newxprt->sc_rq_cq, IB_CQ_NEXT_COMP); + + /* Accept Connection */ + set_bit(RDMAXPRT_CONN_PENDING, &newxprt->sc_flags); + memset(&conn_param, 0, sizeof conn_param); + conn_param.responder_resources = 0; + conn_param.initiator_depth = newxprt->sc_ord; + ret = rdma_accept(newxprt->sc_cm_id, &conn_param); + if (ret) { + dprintk("svcrdma: failed to accept new connection, ret=%d\n", + ret); + goto errout; + } + + dprintk("svcrdma: new connection %p accepted with the following " + "attributes:\n" + " local_ip : %d.%d.%d.%d\n" + " local_port : %d\n" + " remote_ip : %d.%d.%d.%d\n" + " remote_port : %d\n" + " max_sge : %d\n" + " sq_depth : %d\n" + " max_requests : %d\n" + " ord : %d\n", + newxprt, + NIPQUAD(((struct sockaddr_in *)&newxprt->sc_cm_id-> + route.addr.src_addr)->sin_addr.s_addr), + ntohs(((struct sockaddr_in *)&newxprt->sc_cm_id-> + route.addr.src_addr)->sin_port), + NIPQUAD(((struct sockaddr_in *)&newxprt->sc_cm_id-> + route.addr.dst_addr)->sin_addr.s_addr), + ntohs(((struct sockaddr_in *)&newxprt->sc_cm_id-> + route.addr.dst_addr)->sin_port), + newxprt->sc_max_sge, + newxprt->sc_sq_depth, + newxprt->sc_max_requests, + newxprt->sc_ord); + + return &newxprt->sc_xprt; + + errout: + dprintk("svcrdma: failure accepting new connection rc=%d.\n", ret); + /* Take a reference in case the DTO handler runs */ + svc_xprt_get(&newxprt->sc_xprt); + if (newxprt->sc_qp && !IS_ERR(newxprt->sc_qp)) + ib_destroy_qp(newxprt->sc_qp); + rdma_destroy_id(newxprt->sc_cm_id); + /* This call to put will destroy the transport */ + svc_xprt_put(&newxprt->sc_xprt); + return NULL; +} + +static void svc_rdma_release_rqst(struct svc_rqst *rqstp) +{ +} + +/* + * When connected, an svc_xprt has at least two references: + * + * - A reference held by the cm_id between the ESTABLISHED and + * DISCONNECTED events. If the remote peer disconnected first, this + * reference could be gone. + * + * - A reference held by the svc_recv code that called this function + * as part of close processing. + * + * At a minimum one references should still be held. + */ +static void svc_rdma_detach(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + dprintk("svc: svc_rdma_detach(%p)\n", xprt); + + /* Disconnect and flush posted WQE */ + rdma_disconnect(rdma->sc_cm_id); +} + +static void __svc_rdma_free(struct work_struct *work) +{ + struct svcxprt_rdma *rdma = + container_of(work, struct svcxprt_rdma, sc_work); + dprintk("svcrdma: svc_rdma_free(%p)\n", rdma); + + /* We should only be called from kref_put */ + BUG_ON(atomic_read(&rdma->sc_xprt.xpt_ref.refcount) != 0); + + /* + * Destroy queued, but not processed read completions. Note + * that this cleanup has to be done before destroying the + * cm_id because the device ptr is needed to unmap the dma in + * svc_rdma_put_context. + */ + while (!list_empty(&rdma->sc_read_complete_q)) { + struct svc_rdma_op_ctxt *ctxt; + ctxt = list_entry(rdma->sc_read_complete_q.next, + struct svc_rdma_op_ctxt, + dto_q); + list_del_init(&ctxt->dto_q); + svc_rdma_put_context(ctxt, 1); + } + + /* Destroy queued, but not processed recv completions */ + while (!list_empty(&rdma->sc_rq_dto_q)) { + struct svc_rdma_op_ctxt *ctxt; + ctxt = list_entry(rdma->sc_rq_dto_q.next, + struct svc_rdma_op_ctxt, + dto_q); + list_del_init(&ctxt->dto_q); + svc_rdma_put_context(ctxt, 1); + } + + /* Warn if we leaked a resource or under-referenced */ + WARN_ON(atomic_read(&rdma->sc_ctxt_used) != 0); + WARN_ON(atomic_read(&rdma->sc_dma_used) != 0); + + /* De-allocate fastreg mr */ + rdma_dealloc_frmr_q(rdma); + + /* Destroy the QP if present (not a listener) */ + if (rdma->sc_qp && !IS_ERR(rdma->sc_qp)) + ib_destroy_qp(rdma->sc_qp); + + if (rdma->sc_sq_cq && !IS_ERR(rdma->sc_sq_cq)) + ib_destroy_cq(rdma->sc_sq_cq); + + if (rdma->sc_rq_cq && !IS_ERR(rdma->sc_rq_cq)) + ib_destroy_cq(rdma->sc_rq_cq); + + if (rdma->sc_phys_mr && !IS_ERR(rdma->sc_phys_mr)) + ib_dereg_mr(rdma->sc_phys_mr); + + if (rdma->sc_pd && !IS_ERR(rdma->sc_pd)) + ib_dealloc_pd(rdma->sc_pd); + + /* Destroy the CM ID */ + rdma_destroy_id(rdma->sc_cm_id); + + kfree(rdma); +} + +static void svc_rdma_free(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + INIT_WORK(&rdma->sc_work, __svc_rdma_free); + schedule_work(&rdma->sc_work); +} + +static int svc_rdma_has_wspace(struct svc_xprt *xprt) +{ + struct svcxprt_rdma *rdma = + container_of(xprt, struct svcxprt_rdma, sc_xprt); + + /* + * If there are fewer SQ WR available than required to send a + * simple response, return false. + */ + if ((rdma->sc_sq_depth - atomic_read(&rdma->sc_sq_count) < 3)) + return 0; + + /* + * ...or there are already waiters on the SQ, + * return false. + */ + if (waitqueue_active(&rdma->sc_send_wait)) + return 0; + + /* Otherwise return true. */ + return 1; +} + +/* + * Attempt to register the kvec representing the RPC memory with the + * device. + * + * Returns: + * NULL : The device does not support fastreg or there were no more + * fastreg mr. + * frmr : The kvec register request was successfully posted. + * <0 : An error was encountered attempting to register the kvec. + */ +int svc_rdma_fastreg(struct svcxprt_rdma *xprt, + struct svc_rdma_fastreg_mr *frmr) +{ + struct ib_send_wr fastreg_wr; + u8 key; + + /* Bump the key */ + key = (u8)(frmr->mr->lkey & 0x000000FF); + ib_update_fast_reg_key(frmr->mr, ++key); + + /* Prepare FASTREG WR */ + memset(&fastreg_wr, 0, sizeof fastreg_wr); + fastreg_wr.opcode = IB_WR_FAST_REG_MR; + fastreg_wr.send_flags = IB_SEND_SIGNALED; + fastreg_wr.wr.fast_reg.iova_start = (unsigned long)frmr->kva; + fastreg_wr.wr.fast_reg.page_list = frmr->page_list; + fastreg_wr.wr.fast_reg.page_list_len = frmr->page_list_len; + fastreg_wr.wr.fast_reg.page_shift = PAGE_SHIFT; + fastreg_wr.wr.fast_reg.length = frmr->map_len; + fastreg_wr.wr.fast_reg.access_flags = frmr->access_flags; + fastreg_wr.wr.fast_reg.rkey = frmr->mr->lkey; + return svc_rdma_send(xprt, &fastreg_wr); +} + +int svc_rdma_send(struct svcxprt_rdma *xprt, struct ib_send_wr *wr) +{ + struct ib_send_wr *bad_wr, *n_wr; + int wr_count; + int i; + int ret; + + if (test_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags)) + return -ENOTCONN; + + BUG_ON(wr->send_flags != IB_SEND_SIGNALED); + wr_count = 1; + for (n_wr = wr->next; n_wr; n_wr = n_wr->next) + wr_count++; + + /* If the SQ is full, wait until an SQ entry is available */ + while (1) { + spin_lock_bh(&xprt->sc_lock); + if (xprt->sc_sq_depth < atomic_read(&xprt->sc_sq_count) + wr_count) { + spin_unlock_bh(&xprt->sc_lock); + atomic_inc(&rdma_stat_sq_starve); + + /* See if we can opportunistically reap SQ WR to make room */ + sq_cq_reap(xprt); + + /* Wait until SQ WR available if SQ still full */ + wait_event(xprt->sc_send_wait, + atomic_read(&xprt->sc_sq_count) < + xprt->sc_sq_depth); + if (test_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags)) + return 0; + continue; + } + /* Take a transport ref for each WR posted */ + for (i = 0; i < wr_count; i++) + svc_xprt_get(&xprt->sc_xprt); + + /* Bump used SQ WR count and post */ + atomic_add(wr_count, &xprt->sc_sq_count); + ret = ib_post_send(xprt->sc_qp, wr, &bad_wr); + if (ret) { + set_bit(XPT_CLOSE, &xprt->sc_xprt.xpt_flags); + atomic_sub(wr_count, &xprt->sc_sq_count); + for (i = 0; i < wr_count; i ++) + svc_xprt_put(&xprt->sc_xprt); + dprintk("svcrdma: failed to post SQ WR rc=%d, " + "sc_sq_count=%d, sc_sq_depth=%d\n", + ret, atomic_read(&xprt->sc_sq_count), + xprt->sc_sq_depth); + } + spin_unlock_bh(&xprt->sc_lock); + if (ret) + wake_up(&xprt->sc_send_wait); + break; + } + return ret; +} + +void svc_rdma_send_error(struct svcxprt_rdma *xprt, struct rpcrdma_msg *rmsgp, + enum rpcrdma_errcode err) +{ + struct ib_send_wr err_wr; + struct ib_sge sge; + struct page *p; + struct svc_rdma_op_ctxt *ctxt; + u32 *va; + int length; + int ret; + + p = svc_rdma_get_page(); + va = page_address(p); + + /* XDR encode error */ + length = svc_rdma_xdr_encode_error(xprt, rmsgp, err, va); + + /* Prepare SGE for local address */ + sge.addr = ib_dma_map_page(xprt->sc_cm_id->device, + p, 0, PAGE_SIZE, DMA_FROM_DEVICE); + if (ib_dma_mapping_error(xprt->sc_cm_id->device, sge.addr)) { + put_page(p); + return; + } + atomic_inc(&xprt->sc_dma_used); + sge.lkey = xprt->sc_dma_lkey; + sge.length = length; + + ctxt = svc_rdma_get_context(xprt); + ctxt->count = 1; + ctxt->pages[0] = p; + + /* Prepare SEND WR */ + memset(&err_wr, 0, sizeof err_wr); + ctxt->wr_op = IB_WR_SEND; + err_wr.wr_id = (unsigned long)ctxt; + err_wr.sg_list = &sge; + err_wr.num_sge = 1; + err_wr.opcode = IB_WR_SEND; + err_wr.send_flags = IB_SEND_SIGNALED; + + /* Post It */ + ret = svc_rdma_send(xprt, &err_wr); + if (ret) { + dprintk("svcrdma: Error %d posting send for protocol error\n", + ret); + ib_dma_unmap_page(xprt->sc_cm_id->device, + sge.addr, PAGE_SIZE, + DMA_FROM_DEVICE); + svc_rdma_put_context(ctxt, 1); + } +} diff --git a/net/sunrpc/xprtrdma/transport.c b/net/sunrpc/xprtrdma/transport.c new file mode 100644 index 0000000..9839c3d --- /dev/null +++ b/net/sunrpc/xprtrdma/transport.c @@ -0,0 +1,834 @@ +/* + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * transport.c + * + * This file contains the top-level implementation of an RPC RDMA + * transport. + * + * Naming convention: functions beginning with xprt_ are part of the + * transport switch. All others are RPC RDMA internal. + */ + +#include <linux/module.h> +#include <linux/init.h> +#include <linux/seq_file.h> + +#include "xprt_rdma.h" + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +MODULE_LICENSE("Dual BSD/GPL"); + +MODULE_DESCRIPTION("RPC/RDMA Transport for Linux kernel NFS"); +MODULE_AUTHOR("Network Appliance, Inc."); + +/* + * tunables + */ + +static unsigned int xprt_rdma_slot_table_entries = RPCRDMA_DEF_SLOT_TABLE; +static unsigned int xprt_rdma_max_inline_read = RPCRDMA_DEF_INLINE; +static unsigned int xprt_rdma_max_inline_write = RPCRDMA_DEF_INLINE; +static unsigned int xprt_rdma_inline_write_padding; +static unsigned int xprt_rdma_memreg_strategy = RPCRDMA_FRMR; + int xprt_rdma_pad_optimize = 0; + +#ifdef RPC_DEBUG + +static unsigned int min_slot_table_size = RPCRDMA_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPCRDMA_MAX_SLOT_TABLE; +static unsigned int zero; +static unsigned int max_padding = PAGE_SIZE; +static unsigned int min_memreg = RPCRDMA_BOUNCEBUFFERS; +static unsigned int max_memreg = RPCRDMA_LAST - 1; + +static struct ctl_table_header *sunrpc_table_header; + +static ctl_table xr_tunables_table[] = { + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_slot_table_entries", + .data = &xprt_rdma_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_max_inline_read", + .data = &xprt_rdma_max_inline_read, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec, + .strategy = &sysctl_intvec, + }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_max_inline_write", + .data = &xprt_rdma_max_inline_write, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec, + .strategy = &sysctl_intvec, + }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_inline_write_padding", + .data = &xprt_rdma_inline_write_padding, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &zero, + .extra2 = &max_padding, + }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_memreg_strategy", + .data = &xprt_rdma_memreg_strategy, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_memreg, + .extra2 = &max_memreg, + }, + { + .ctl_name = CTL_UNNUMBERED, + .procname = "rdma_pad_optimize", + .data = &xprt_rdma_pad_optimize, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec, + }, + { + .ctl_name = 0, + }, +}; + +static ctl_table sunrpc_table[] = { + { + .ctl_name = CTL_SUNRPC, + .procname = "sunrpc", + .mode = 0555, + .child = xr_tunables_table + }, + { + .ctl_name = 0, + }, +}; + +#endif + +static struct rpc_xprt_ops xprt_rdma_procs; /* forward reference */ + +static void +xprt_rdma_format_addresses(struct rpc_xprt *xprt) +{ + struct sockaddr_in *addr = (struct sockaddr_in *) + &rpcx_to_rdmad(xprt).addr; + char *buf; + + buf = kzalloc(20, GFP_KERNEL); + if (buf) + snprintf(buf, 20, NIPQUAD_FMT, NIPQUAD(addr->sin_addr.s_addr)); + xprt->address_strings[RPC_DISPLAY_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) + snprintf(buf, 8, "%u", ntohs(addr->sin_port)); + xprt->address_strings[RPC_DISPLAY_PORT] = buf; + + xprt->address_strings[RPC_DISPLAY_PROTO] = "rdma"; + + buf = kzalloc(48, GFP_KERNEL); + if (buf) + snprintf(buf, 48, "addr="NIPQUAD_FMT" port=%u proto=%s", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port), "rdma"); + xprt->address_strings[RPC_DISPLAY_ALL] = buf; + + buf = kzalloc(10, GFP_KERNEL); + if (buf) + snprintf(buf, 10, "%02x%02x%02x%02x", + NIPQUAD(addr->sin_addr.s_addr)); + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) + snprintf(buf, 8, "%4hx", ntohs(addr->sin_port)); + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = buf; + + buf = kzalloc(30, GFP_KERNEL); + if (buf) + snprintf(buf, 30, NIPQUAD_FMT".%u.%u", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port) >> 8, + ntohs(addr->sin_port) & 0xff); + xprt->address_strings[RPC_DISPLAY_UNIVERSAL_ADDR] = buf; + + /* netid */ + xprt->address_strings[RPC_DISPLAY_NETID] = "rdma"; +} + +static void +xprt_rdma_free_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +static void +xprt_rdma_connect_worker(struct work_struct *work) +{ + struct rpcrdma_xprt *r_xprt = + container_of(work, struct rpcrdma_xprt, rdma_connect.work); + struct rpc_xprt *xprt = &r_xprt->xprt; + int rc = 0; + + if (!xprt->shutdown) { + xprt_clear_connected(xprt); + + dprintk("RPC: %s: %sconnect\n", __func__, + r_xprt->rx_ep.rep_connected != 0 ? "re" : ""); + rc = rpcrdma_ep_connect(&r_xprt->rx_ep, &r_xprt->rx_ia); + if (rc) + goto out; + } + goto out_clear; + +out: + xprt_wake_pending_tasks(xprt, rc); + +out_clear: + dprintk("RPC: %s: exit\n", __func__); + xprt_clear_connecting(xprt); +} + +/* + * xprt_rdma_destroy + * + * Destroy the xprt. + * Free all memory associated with the object, including its own. + * NOTE: none of the *destroy methods free memory for their top-level + * objects, even though they may have allocated it (they do free + * private memory). It's up to the caller to handle it. In this + * case (RDMA transport), all structure memory is inlined with the + * struct rpcrdma_xprt. + */ +static void +xprt_rdma_destroy(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int rc; + + dprintk("RPC: %s: called\n", __func__); + + cancel_delayed_work(&r_xprt->rdma_connect); + flush_scheduled_work(); + + xprt_clear_connected(xprt); + + rpcrdma_buffer_destroy(&r_xprt->rx_buf); + rc = rpcrdma_ep_destroy(&r_xprt->rx_ep, &r_xprt->rx_ia); + if (rc) + dprintk("RPC: %s: rpcrdma_ep_destroy returned %i\n", + __func__, rc); + rpcrdma_ia_close(&r_xprt->rx_ia); + + xprt_rdma_free_addresses(xprt); + + kfree(xprt->slot); + xprt->slot = NULL; + kfree(xprt); + + dprintk("RPC: %s: returning\n", __func__); + + module_put(THIS_MODULE); +} + +static const struct rpc_timeout xprt_rdma_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, +}; + +/** + * xprt_setup_rdma - Set up transport to use RDMA + * + * @args: rpc transport arguments + */ +static struct rpc_xprt * +xprt_setup_rdma(struct xprt_create *args) +{ + struct rpcrdma_create_data_internal cdata; + struct rpc_xprt *xprt; + struct rpcrdma_xprt *new_xprt; + struct rpcrdma_ep *new_ep; + struct sockaddr_in *sin; + int rc; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: %s: address too large\n", __func__); + return ERR_PTR(-EBADF); + } + + xprt = kzalloc(sizeof(struct rpcrdma_xprt), GFP_KERNEL); + if (xprt == NULL) { + dprintk("RPC: %s: couldn't allocate rpcrdma_xprt\n", + __func__); + return ERR_PTR(-ENOMEM); + } + + xprt->max_reqs = xprt_rdma_slot_table_entries; + xprt->slot = kcalloc(xprt->max_reqs, + sizeof(struct rpc_rqst), GFP_KERNEL); + if (xprt->slot == NULL) { + dprintk("RPC: %s: couldn't allocate %d slots\n", + __func__, xprt->max_reqs); + kfree(xprt); + return ERR_PTR(-ENOMEM); + } + + /* 60 second timeout, no retries */ + xprt->timeout = &xprt_rdma_default_timeout; + xprt->bind_timeout = (60U * HZ); + xprt->connect_timeout = (60U * HZ); + xprt->reestablish_timeout = (5U * HZ); + xprt->idle_timeout = (5U * 60 * HZ); + + xprt->resvport = 0; /* privileged port not needed */ + xprt->tsh_size = 0; /* RPC-RDMA handles framing */ + xprt->max_payload = RPCRDMA_MAX_DATA_SEGS * PAGE_SIZE; + xprt->ops = &xprt_rdma_procs; + + /* + * Set up RDMA-specific connect data. + */ + + /* Put server RDMA address in local cdata */ + memcpy(&cdata.addr, args->dstaddr, args->addrlen); + + /* Ensure xprt->addr holds valid server TCP (not RDMA) + * address, for any side protocols which peek at it */ + xprt->prot = IPPROTO_TCP; + xprt->addrlen = args->addrlen; + memcpy(&xprt->addr, &cdata.addr, xprt->addrlen); + + sin = (struct sockaddr_in *)&cdata.addr; + if (ntohs(sin->sin_port) != 0) + xprt_set_bound(xprt); + + dprintk("RPC: %s: %u.%u.%u.%u:%u\n", __func__, + NIPQUAD(sin->sin_addr.s_addr), ntohs(sin->sin_port)); + + /* Set max requests */ + cdata.max_requests = xprt->max_reqs; + + /* Set some length limits */ + cdata.rsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA write max */ + cdata.wsize = RPCRDMA_MAX_SEGS * PAGE_SIZE; /* RDMA read max */ + + cdata.inline_wsize = xprt_rdma_max_inline_write; + if (cdata.inline_wsize > cdata.wsize) + cdata.inline_wsize = cdata.wsize; + + cdata.inline_rsize = xprt_rdma_max_inline_read; + if (cdata.inline_rsize > cdata.rsize) + cdata.inline_rsize = cdata.rsize; + + cdata.padding = xprt_rdma_inline_write_padding; + + /* + * Create new transport instance, which includes initialized + * o ia + * o endpoint + * o buffers + */ + + new_xprt = rpcx_to_rdmax(xprt); + + rc = rpcrdma_ia_open(new_xprt, (struct sockaddr *) &cdata.addr, + xprt_rdma_memreg_strategy); + if (rc) + goto out1; + + /* + * initialize and create ep + */ + new_xprt->rx_data = cdata; + new_ep = &new_xprt->rx_ep; + new_ep->rep_remote_addr = cdata.addr; + + rc = rpcrdma_ep_create(&new_xprt->rx_ep, + &new_xprt->rx_ia, &new_xprt->rx_data); + if (rc) + goto out2; + + /* + * Allocate pre-registered send and receive buffers for headers and + * any inline data. Also specify any padding which will be provided + * from a preregistered zero buffer. + */ + rc = rpcrdma_buffer_create(&new_xprt->rx_buf, new_ep, &new_xprt->rx_ia, + &new_xprt->rx_data); + if (rc) + goto out3; + + /* + * Register a callback for connection events. This is necessary because + * connection loss notification is async. We also catch connection loss + * when reaping receives. + */ + INIT_DELAYED_WORK(&new_xprt->rdma_connect, xprt_rdma_connect_worker); + new_ep->rep_func = rpcrdma_conn_func; + new_ep->rep_xprt = xprt; + + xprt_rdma_format_addresses(xprt); + + if (!try_module_get(THIS_MODULE)) + goto out4; + + return xprt; + +out4: + xprt_rdma_free_addresses(xprt); + rc = -EINVAL; +out3: + (void) rpcrdma_ep_destroy(new_ep, &new_xprt->rx_ia); +out2: + rpcrdma_ia_close(&new_xprt->rx_ia); +out1: + kfree(xprt->slot); + kfree(xprt); + return ERR_PTR(rc); +} + +/* + * Close a connection, during shutdown or timeout/reconnect + */ +static void +xprt_rdma_close(struct rpc_xprt *xprt) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + dprintk("RPC: %s: closing\n", __func__); + if (r_xprt->rx_ep.rep_connected > 0) + xprt->reestablish_timeout = 0; + xprt_disconnect_done(xprt); + (void) rpcrdma_ep_disconnect(&r_xprt->rx_ep, &r_xprt->rx_ia); +} + +static void +xprt_rdma_set_port(struct rpc_xprt *xprt, u16 port) +{ + struct sockaddr_in *sap; + + sap = (struct sockaddr_in *)&xprt->addr; + sap->sin_port = htons(port); + sap = (struct sockaddr_in *)&rpcx_to_rdmad(xprt).addr; + sap->sin_port = htons(port); + dprintk("RPC: %s: %u\n", __func__, port); +} + +static void +xprt_rdma_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = (struct rpc_xprt *)task->tk_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + if (!xprt_test_and_set_connecting(xprt)) { + if (r_xprt->rx_ep.rep_connected != 0) { + /* Reconnect */ + schedule_delayed_work(&r_xprt->rdma_connect, + xprt->reestablish_timeout); + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > (30 * HZ)) + xprt->reestablish_timeout = (30 * HZ); + else if (xprt->reestablish_timeout < (5 * HZ)) + xprt->reestablish_timeout = (5 * HZ); + } else { + schedule_delayed_work(&r_xprt->rdma_connect, 0); + if (!RPC_IS_ASYNC(task)) + flush_scheduled_work(); + } + } +} + +static int +xprt_rdma_reserve_xprt(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + int credits = atomic_read(&r_xprt->rx_buf.rb_credits); + + /* == RPC_CWNDSCALE @ init, but *after* setup */ + if (r_xprt->rx_buf.rb_cwndscale == 0UL) { + r_xprt->rx_buf.rb_cwndscale = xprt->cwnd; + dprintk("RPC: %s: cwndscale %lu\n", __func__, + r_xprt->rx_buf.rb_cwndscale); + BUG_ON(r_xprt->rx_buf.rb_cwndscale <= 0); + } + xprt->cwnd = credits * r_xprt->rx_buf.rb_cwndscale; + return xprt_reserve_xprt_cong(task); +} + +/* + * The RDMA allocate/free functions need the task structure as a place + * to hide the struct rpcrdma_req, which is necessary for the actual send/recv + * sequence. For this reason, the recv buffers are attached to send + * buffers for portions of the RPC. Note that the RPC layer allocates + * both send and receive buffers in the same call. We may register + * the receive buffer portion when using reply chunks. + */ +static void * +xprt_rdma_allocate(struct rpc_task *task, size_t size) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct rpcrdma_req *req, *nreq; + + req = rpcrdma_buffer_get(&rpcx_to_rdmax(xprt)->rx_buf); + BUG_ON(NULL == req); + + if (size > req->rl_size) { + dprintk("RPC: %s: size %zd too large for buffer[%zd]: " + "prog %d vers %d proc %d\n", + __func__, size, req->rl_size, + task->tk_client->cl_prog, task->tk_client->cl_vers, + task->tk_msg.rpc_proc->p_proc); + /* + * Outgoing length shortage. Our inline write max must have + * been configured to perform direct i/o. + * + * This is therefore a large metadata operation, and the + * allocate call was made on the maximum possible message, + * e.g. containing long filename(s) or symlink data. In + * fact, while these metadata operations *might* carry + * large outgoing payloads, they rarely *do*. However, we + * have to commit to the request here, so reallocate and + * register it now. The data path will never require this + * reallocation. + * + * If the allocation or registration fails, the RPC framework + * will (doggedly) retry. + */ + if (rpcx_to_rdmax(xprt)->rx_ia.ri_memreg_strategy == + RPCRDMA_BOUNCEBUFFERS) { + /* forced to "pure inline" */ + dprintk("RPC: %s: too much data (%zd) for inline " + "(r/w max %d/%d)\n", __func__, size, + rpcx_to_rdmad(xprt).inline_rsize, + rpcx_to_rdmad(xprt).inline_wsize); + size = req->rl_size; + rpc_exit(task, -EIO); /* fail the operation */ + rpcx_to_rdmax(xprt)->rx_stats.failed_marshal_count++; + goto out; + } + if (task->tk_flags & RPC_TASK_SWAPPER) + nreq = kmalloc(sizeof *req + size, GFP_ATOMIC); + else + nreq = kmalloc(sizeof *req + size, GFP_NOFS); + if (nreq == NULL) + goto outfail; + + if (rpcrdma_register_internal(&rpcx_to_rdmax(xprt)->rx_ia, + nreq->rl_base, size + sizeof(struct rpcrdma_req) + - offsetof(struct rpcrdma_req, rl_base), + &nreq->rl_handle, &nreq->rl_iov)) { + kfree(nreq); + goto outfail; + } + rpcx_to_rdmax(xprt)->rx_stats.hardway_register_count += size; + nreq->rl_size = size; + nreq->rl_niovs = 0; + nreq->rl_nchunks = 0; + nreq->rl_buffer = (struct rpcrdma_buffer *)req; + nreq->rl_reply = req->rl_reply; + memcpy(nreq->rl_segments, + req->rl_segments, sizeof nreq->rl_segments); + /* flag the swap with an unused field */ + nreq->rl_iov.length = 0; + req->rl_reply = NULL; + req = nreq; + } + dprintk("RPC: %s: size %zd, request 0x%p\n", __func__, size, req); +out: + req->rl_connect_cookie = 0; /* our reserved value */ + return req->rl_xdr_buf; + +outfail: + rpcrdma_buffer_put(req); + rpcx_to_rdmax(xprt)->rx_stats.failed_marshal_count++; + return NULL; +} + +/* + * This function returns all RDMA resources to the pool. + */ +static void +xprt_rdma_free(void *buffer) +{ + struct rpcrdma_req *req; + struct rpcrdma_xprt *r_xprt; + struct rpcrdma_rep *rep; + int i; + + if (buffer == NULL) + return; + + req = container_of(buffer, struct rpcrdma_req, rl_xdr_buf[0]); + if (req->rl_iov.length == 0) { /* see allocate above */ + r_xprt = container_of(((struct rpcrdma_req *) req->rl_buffer)->rl_buffer, + struct rpcrdma_xprt, rx_buf); + } else + r_xprt = container_of(req->rl_buffer, struct rpcrdma_xprt, rx_buf); + rep = req->rl_reply; + + dprintk("RPC: %s: called on 0x%p%s\n", + __func__, rep, (rep && rep->rr_func) ? " (with waiter)" : ""); + + /* + * Finish the deregistration. When using mw bind, this was + * begun in rpcrdma_reply_handler(). In all other modes, we + * do it here, in thread context. The process is considered + * complete when the rr_func vector becomes NULL - this + * was put in place during rpcrdma_reply_handler() - the wait + * call below will not block if the dereg is "done". If + * interrupted, our framework will clean up. + */ + for (i = 0; req->rl_nchunks;) { + --req->rl_nchunks; + i += rpcrdma_deregister_external( + &req->rl_segments[i], r_xprt, NULL); + } + + if (rep && wait_event_interruptible(rep->rr_unbind, !rep->rr_func)) { + rep->rr_func = NULL; /* abandon the callback */ + req->rl_reply = NULL; + } + + if (req->rl_iov.length == 0) { /* see allocate above */ + struct rpcrdma_req *oreq = (struct rpcrdma_req *)req->rl_buffer; + oreq->rl_reply = req->rl_reply; + (void) rpcrdma_deregister_internal(&r_xprt->rx_ia, + req->rl_handle, + &req->rl_iov); + kfree(req); + req = oreq; + } + + /* Put back request+reply buffers */ + rpcrdma_buffer_put(req); +} + +/* + * send_request invokes the meat of RPC RDMA. It must do the following: + * 1. Marshal the RPC request into an RPC RDMA request, which means + * putting a header in front of data, and creating IOVs for RDMA + * from those in the request. + * 2. In marshaling, detect opportunities for RDMA, and use them. + * 3. Post a recv message to set up asynch completion, then send + * the request (rpcrdma_ep_post). + * 4. No partial sends are possible in the RPC-RDMA protocol (as in UDP). + */ + +static int +xprt_rdma_send_request(struct rpc_task *task) +{ + struct rpc_rqst *rqst = task->tk_rqstp; + struct rpc_xprt *xprt = task->tk_xprt; + struct rpcrdma_req *req = rpcr_to_rdmar(rqst); + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + + /* marshal the send itself */ + if (req->rl_niovs == 0 && rpcrdma_marshal_req(rqst) != 0) { + r_xprt->rx_stats.failed_marshal_count++; + dprintk("RPC: %s: rpcrdma_marshal_req failed\n", + __func__); + return -EIO; + } + + if (req->rl_reply == NULL) /* e.g. reconnection */ + rpcrdma_recv_buffer_get(req); + + if (req->rl_reply) { + req->rl_reply->rr_func = rpcrdma_reply_handler; + /* this need only be done once, but... */ + req->rl_reply->rr_xprt = xprt; + } + + /* Must suppress retransmit to maintain credits */ + if (req->rl_connect_cookie == xprt->connect_cookie) + goto drop_connection; + req->rl_connect_cookie = xprt->connect_cookie; + + if (rpcrdma_ep_post(&r_xprt->rx_ia, &r_xprt->rx_ep, req)) + goto drop_connection; + + task->tk_bytes_sent += rqst->rq_snd_buf.len; + rqst->rq_bytes_sent = 0; + return 0; + +drop_connection: + xprt_disconnect_done(xprt); + return -ENOTCONN; /* implies disconnect */ +} + +static void xprt_rdma_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct rpcrdma_xprt *r_xprt = rpcx_to_rdmax(xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, + "\txprt:\trdma %u %lu %lu %lu %ld %lu %lu %lu %Lu %Lu " + "%lu %lu %lu %Lu %Lu %Lu %Lu %lu %lu %lu\n", + + 0, /* need a local port? */ + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u, + + r_xprt->rx_stats.read_chunk_count, + r_xprt->rx_stats.write_chunk_count, + r_xprt->rx_stats.reply_chunk_count, + r_xprt->rx_stats.total_rdma_request, + r_xprt->rx_stats.total_rdma_reply, + r_xprt->rx_stats.pullup_copy_count, + r_xprt->rx_stats.fixup_copy_count, + r_xprt->rx_stats.hardway_register_count, + r_xprt->rx_stats.failed_marshal_count, + r_xprt->rx_stats.bad_reply_count); +} + +/* + * Plumbing for rpc transport switch and kernel module + */ + +static struct rpc_xprt_ops xprt_rdma_procs = { + .reserve_xprt = xprt_rdma_reserve_xprt, + .release_xprt = xprt_release_xprt_cong, /* sunrpc/xprt.c */ + .release_request = xprt_release_rqst_cong, /* ditto */ + .set_retrans_timeout = xprt_set_retrans_timeout_def, /* ditto */ + .rpcbind = rpcb_getport_async, /* sunrpc/rpcb_clnt.c */ + .set_port = xprt_rdma_set_port, + .connect = xprt_rdma_connect, + .buf_alloc = xprt_rdma_allocate, + .buf_free = xprt_rdma_free, + .send_request = xprt_rdma_send_request, + .close = xprt_rdma_close, + .destroy = xprt_rdma_destroy, + .print_stats = xprt_rdma_print_stats +}; + +static struct xprt_class xprt_rdma = { + .list = LIST_HEAD_INIT(xprt_rdma.list), + .name = "rdma", + .owner = THIS_MODULE, + .ident = XPRT_TRANSPORT_RDMA, + .setup = xprt_setup_rdma, +}; + +static void __exit xprt_rdma_cleanup(void) +{ + int rc; + + dprintk(KERN_INFO "RPCRDMA Module Removed, deregister RPC RDMA transport\n"); +#ifdef RPC_DEBUG + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +#endif + rc = xprt_unregister_transport(&xprt_rdma); + if (rc) + dprintk("RPC: %s: xprt_unregister returned %i\n", + __func__, rc); +} + +static int __init xprt_rdma_init(void) +{ + int rc; + + rc = xprt_register_transport(&xprt_rdma); + + if (rc) + return rc; + + dprintk(KERN_INFO "RPCRDMA Module Init, register RPC RDMA transport\n"); + + dprintk(KERN_INFO "Defaults:\n"); + dprintk(KERN_INFO "\tSlots %d\n" + "\tMaxInlineRead %d\n\tMaxInlineWrite %d\n", + xprt_rdma_slot_table_entries, + xprt_rdma_max_inline_read, xprt_rdma_max_inline_write); + dprintk(KERN_INFO "\tPadding %d\n\tMemreg %d\n", + xprt_rdma_inline_write_padding, xprt_rdma_memreg_strategy); + +#ifdef RPC_DEBUG + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +#endif + return 0; +} + +module_init(xprt_rdma_init); +module_exit(xprt_rdma_cleanup); diff --git a/net/sunrpc/xprtrdma/verbs.c b/net/sunrpc/xprtrdma/verbs.c new file mode 100644 index 0000000..a5fef5e --- /dev/null +++ b/net/sunrpc/xprtrdma/verbs.c @@ -0,0 +1,1901 @@ +/* + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +/* + * verbs.c + * + * Encapsulates the major functions managing: + * o adapters + * o endpoints + * o connections + * o buffer memory + */ + +#include <linux/pci.h> /* for Tavor hack below */ + +#include "xprt_rdma.h" + +/* + * Globals/Macros + */ + +#ifdef RPC_DEBUG +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +/* + * internal functions + */ + +/* + * handle replies in tasklet context, using a single, global list + * rdma tasklet function -- just turn around and call the func + * for all replies on the list + */ + +static DEFINE_SPINLOCK(rpcrdma_tk_lock_g); +static LIST_HEAD(rpcrdma_tasklets_g); + +static void +rpcrdma_run_tasklet(unsigned long data) +{ + struct rpcrdma_rep *rep; + void (*func)(struct rpcrdma_rep *); + unsigned long flags; + + data = data; + spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); + while (!list_empty(&rpcrdma_tasklets_g)) { + rep = list_entry(rpcrdma_tasklets_g.next, + struct rpcrdma_rep, rr_list); + list_del(&rep->rr_list); + func = rep->rr_func; + rep->rr_func = NULL; + spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); + + if (func) + func(rep); + else + rpcrdma_recv_buffer_put(rep); + + spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); + } + spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); +} + +static DECLARE_TASKLET(rpcrdma_tasklet_g, rpcrdma_run_tasklet, 0UL); + +static inline void +rpcrdma_schedule_tasklet(struct rpcrdma_rep *rep) +{ + unsigned long flags; + + spin_lock_irqsave(&rpcrdma_tk_lock_g, flags); + list_add_tail(&rep->rr_list, &rpcrdma_tasklets_g); + spin_unlock_irqrestore(&rpcrdma_tk_lock_g, flags); + tasklet_schedule(&rpcrdma_tasklet_g); +} + +static void +rpcrdma_qp_async_error_upcall(struct ib_event *event, void *context) +{ + struct rpcrdma_ep *ep = context; + + dprintk("RPC: %s: QP error %X on device %s ep %p\n", + __func__, event->event, event->device->name, context); + if (ep->rep_connected == 1) { + ep->rep_connected = -EIO; + ep->rep_func(ep); + wake_up_all(&ep->rep_connect_wait); + } +} + +static void +rpcrdma_cq_async_error_upcall(struct ib_event *event, void *context) +{ + struct rpcrdma_ep *ep = context; + + dprintk("RPC: %s: CQ error %X on device %s ep %p\n", + __func__, event->event, event->device->name, context); + if (ep->rep_connected == 1) { + ep->rep_connected = -EIO; + ep->rep_func(ep); + wake_up_all(&ep->rep_connect_wait); + } +} + +static inline +void rpcrdma_event_process(struct ib_wc *wc) +{ + struct rpcrdma_rep *rep = + (struct rpcrdma_rep *)(unsigned long) wc->wr_id; + + dprintk("RPC: %s: event rep %p status %X opcode %X length %u\n", + __func__, rep, wc->status, wc->opcode, wc->byte_len); + + if (!rep) /* send or bind completion that we don't care about */ + return; + + if (IB_WC_SUCCESS != wc->status) { + dprintk("RPC: %s: %s WC status %X, connection lost\n", + __func__, (wc->opcode & IB_WC_RECV) ? "recv" : "send", + wc->status); + rep->rr_len = ~0U; + rpcrdma_schedule_tasklet(rep); + return; + } + + switch (wc->opcode) { + case IB_WC_RECV: + rep->rr_len = wc->byte_len; + ib_dma_sync_single_for_cpu( + rdmab_to_ia(rep->rr_buffer)->ri_id->device, + rep->rr_iov.addr, rep->rr_len, DMA_FROM_DEVICE); + /* Keep (only) the most recent credits, after check validity */ + if (rep->rr_len >= 16) { + struct rpcrdma_msg *p = + (struct rpcrdma_msg *) rep->rr_base; + unsigned int credits = ntohl(p->rm_credit); + if (credits == 0) { + dprintk("RPC: %s: server" + " dropped credits to 0!\n", __func__); + /* don't deadlock */ + credits = 1; + } else if (credits > rep->rr_buffer->rb_max_requests) { + dprintk("RPC: %s: server" + " over-crediting: %d (%d)\n", + __func__, credits, + rep->rr_buffer->rb_max_requests); + credits = rep->rr_buffer->rb_max_requests; + } + atomic_set(&rep->rr_buffer->rb_credits, credits); + } + /* fall through */ + case IB_WC_BIND_MW: + rpcrdma_schedule_tasklet(rep); + break; + default: + dprintk("RPC: %s: unexpected WC event %X\n", + __func__, wc->opcode); + break; + } +} + +static inline int +rpcrdma_cq_poll(struct ib_cq *cq) +{ + struct ib_wc wc; + int rc; + + for (;;) { + rc = ib_poll_cq(cq, 1, &wc); + if (rc < 0) { + dprintk("RPC: %s: ib_poll_cq failed %i\n", + __func__, rc); + return rc; + } + if (rc == 0) + break; + + rpcrdma_event_process(&wc); + } + + return 0; +} + +/* + * rpcrdma_cq_event_upcall + * + * This upcall handles recv, send, bind and unbind events. + * It is reentrant but processes single events in order to maintain + * ordering of receives to keep server credits. + * + * It is the responsibility of the scheduled tasklet to return + * recv buffers to the pool. NOTE: this affects synchronization of + * connection shutdown. That is, the structures required for + * the completion of the reply handler must remain intact until + * all memory has been reclaimed. + * + * Note that send events are suppressed and do not result in an upcall. + */ +static void +rpcrdma_cq_event_upcall(struct ib_cq *cq, void *context) +{ + int rc; + + rc = rpcrdma_cq_poll(cq); + if (rc) + return; + + rc = ib_req_notify_cq(cq, IB_CQ_NEXT_COMP); + if (rc) { + dprintk("RPC: %s: ib_req_notify_cq failed %i\n", + __func__, rc); + return; + } + + rpcrdma_cq_poll(cq); +} + +#ifdef RPC_DEBUG +static const char * const conn[] = { + "address resolved", + "address error", + "route resolved", + "route error", + "connect request", + "connect response", + "connect error", + "unreachable", + "rejected", + "established", + "disconnected", + "device removal" +}; +#endif + +static int +rpcrdma_conn_upcall(struct rdma_cm_id *id, struct rdma_cm_event *event) +{ + struct rpcrdma_xprt *xprt = id->context; + struct rpcrdma_ia *ia = &xprt->rx_ia; + struct rpcrdma_ep *ep = &xprt->rx_ep; + struct sockaddr_in *addr = (struct sockaddr_in *) &ep->rep_remote_addr; + struct ib_qp_attr attr; + struct ib_qp_init_attr iattr; + int connstate = 0; + + switch (event->event) { + case RDMA_CM_EVENT_ADDR_RESOLVED: + case RDMA_CM_EVENT_ROUTE_RESOLVED: + ia->ri_async_rc = 0; + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_ADDR_ERROR: + ia->ri_async_rc = -EHOSTUNREACH; + dprintk("RPC: %s: CM address resolution error, ep 0x%p\n", + __func__, ep); + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_ROUTE_ERROR: + ia->ri_async_rc = -ENETUNREACH; + dprintk("RPC: %s: CM route resolution error, ep 0x%p\n", + __func__, ep); + complete(&ia->ri_done); + break; + case RDMA_CM_EVENT_ESTABLISHED: + connstate = 1; + ib_query_qp(ia->ri_id->qp, &attr, + IB_QP_MAX_QP_RD_ATOMIC | IB_QP_MAX_DEST_RD_ATOMIC, + &iattr); + dprintk("RPC: %s: %d responder resources" + " (%d initiator)\n", + __func__, attr.max_dest_rd_atomic, attr.max_rd_atomic); + goto connected; + case RDMA_CM_EVENT_CONNECT_ERROR: + connstate = -ENOTCONN; + goto connected; + case RDMA_CM_EVENT_UNREACHABLE: + connstate = -ENETDOWN; + goto connected; + case RDMA_CM_EVENT_REJECTED: + connstate = -ECONNREFUSED; + goto connected; + case RDMA_CM_EVENT_DISCONNECTED: + connstate = -ECONNABORTED; + goto connected; + case RDMA_CM_EVENT_DEVICE_REMOVAL: + connstate = -ENODEV; +connected: + dprintk("RPC: %s: %s: %u.%u.%u.%u:%u" + " (ep 0x%p event 0x%x)\n", + __func__, + (event->event <= 11) ? conn[event->event] : + "unknown connection error", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port), + ep, event->event); + atomic_set(&rpcx_to_rdmax(ep->rep_xprt)->rx_buf.rb_credits, 1); + dprintk("RPC: %s: %sconnected\n", + __func__, connstate > 0 ? "" : "dis"); + ep->rep_connected = connstate; + ep->rep_func(ep); + wake_up_all(&ep->rep_connect_wait); + break; + default: + dprintk("RPC: %s: unexpected CM event %d\n", + __func__, event->event); + break; + } + +#ifdef RPC_DEBUG + if (connstate == 1) { + int ird = attr.max_dest_rd_atomic; + int tird = ep->rep_remote_cma.responder_resources; + printk(KERN_INFO "rpcrdma: connection to %u.%u.%u.%u:%u " + "on %s, memreg %d slots %d ird %d%s\n", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port), + ia->ri_id->device->name, + ia->ri_memreg_strategy, + xprt->rx_buf.rb_max_requests, + ird, ird < 4 && ird < tird / 2 ? " (low!)" : ""); + } else if (connstate < 0) { + printk(KERN_INFO "rpcrdma: connection to %u.%u.%u.%u:%u " + "closed (%d)\n", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port), + connstate); + } +#endif + + return 0; +} + +static struct rdma_cm_id * +rpcrdma_create_id(struct rpcrdma_xprt *xprt, + struct rpcrdma_ia *ia, struct sockaddr *addr) +{ + struct rdma_cm_id *id; + int rc; + + init_completion(&ia->ri_done); + + id = rdma_create_id(rpcrdma_conn_upcall, xprt, RDMA_PS_TCP); + if (IS_ERR(id)) { + rc = PTR_ERR(id); + dprintk("RPC: %s: rdma_create_id() failed %i\n", + __func__, rc); + return id; + } + + ia->ri_async_rc = -ETIMEDOUT; + rc = rdma_resolve_addr(id, NULL, addr, RDMA_RESOLVE_TIMEOUT); + if (rc) { + dprintk("RPC: %s: rdma_resolve_addr() failed %i\n", + __func__, rc); + goto out; + } + wait_for_completion_interruptible_timeout(&ia->ri_done, + msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1); + rc = ia->ri_async_rc; + if (rc) + goto out; + + ia->ri_async_rc = -ETIMEDOUT; + rc = rdma_resolve_route(id, RDMA_RESOLVE_TIMEOUT); + if (rc) { + dprintk("RPC: %s: rdma_resolve_route() failed %i\n", + __func__, rc); + goto out; + } + wait_for_completion_interruptible_timeout(&ia->ri_done, + msecs_to_jiffies(RDMA_RESOLVE_TIMEOUT) + 1); + rc = ia->ri_async_rc; + if (rc) + goto out; + + return id; + +out: + rdma_destroy_id(id); + return ERR_PTR(rc); +} + +/* + * Drain any cq, prior to teardown. + */ +static void +rpcrdma_clean_cq(struct ib_cq *cq) +{ + struct ib_wc wc; + int count = 0; + + while (1 == ib_poll_cq(cq, 1, &wc)) + ++count; + + if (count) + dprintk("RPC: %s: flushed %d events (last 0x%x)\n", + __func__, count, wc.opcode); +} + +/* + * Exported functions. + */ + +/* + * Open and initialize an Interface Adapter. + * o initializes fields of struct rpcrdma_ia, including + * interface and provider attributes and protection zone. + */ +int +rpcrdma_ia_open(struct rpcrdma_xprt *xprt, struct sockaddr *addr, int memreg) +{ + int rc, mem_priv; + struct ib_device_attr devattr; + struct rpcrdma_ia *ia = &xprt->rx_ia; + + ia->ri_id = rpcrdma_create_id(xprt, ia, addr); + if (IS_ERR(ia->ri_id)) { + rc = PTR_ERR(ia->ri_id); + goto out1; + } + + ia->ri_pd = ib_alloc_pd(ia->ri_id->device); + if (IS_ERR(ia->ri_pd)) { + rc = PTR_ERR(ia->ri_pd); + dprintk("RPC: %s: ib_alloc_pd() failed %i\n", + __func__, rc); + goto out2; + } + + /* + * Query the device to determine if the requested memory + * registration strategy is supported. If it isn't, set the + * strategy to a globally supported model. + */ + rc = ib_query_device(ia->ri_id->device, &devattr); + if (rc) { + dprintk("RPC: %s: ib_query_device failed %d\n", + __func__, rc); + goto out2; + } + + if (devattr.device_cap_flags & IB_DEVICE_LOCAL_DMA_LKEY) { + ia->ri_have_dma_lkey = 1; + ia->ri_dma_lkey = ia->ri_id->device->local_dma_lkey; + } + + switch (memreg) { + case RPCRDMA_MEMWINDOWS: + case RPCRDMA_MEMWINDOWS_ASYNC: + if (!(devattr.device_cap_flags & IB_DEVICE_MEM_WINDOW)) { + dprintk("RPC: %s: MEMWINDOWS registration " + "specified but not supported by adapter, " + "using slower RPCRDMA_REGISTER\n", + __func__); + memreg = RPCRDMA_REGISTER; + } + break; + case RPCRDMA_MTHCAFMR: + if (!ia->ri_id->device->alloc_fmr) { +#if RPCRDMA_PERSISTENT_REGISTRATION + dprintk("RPC: %s: MTHCAFMR registration " + "specified but not supported by adapter, " + "using riskier RPCRDMA_ALLPHYSICAL\n", + __func__); + memreg = RPCRDMA_ALLPHYSICAL; +#else + dprintk("RPC: %s: MTHCAFMR registration " + "specified but not supported by adapter, " + "using slower RPCRDMA_REGISTER\n", + __func__); + memreg = RPCRDMA_REGISTER; +#endif + } + break; + case RPCRDMA_FRMR: + /* Requires both frmr reg and local dma lkey */ + if ((devattr.device_cap_flags & + (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) != + (IB_DEVICE_MEM_MGT_EXTENSIONS|IB_DEVICE_LOCAL_DMA_LKEY)) { +#if RPCRDMA_PERSISTENT_REGISTRATION + dprintk("RPC: %s: FRMR registration " + "specified but not supported by adapter, " + "using riskier RPCRDMA_ALLPHYSICAL\n", + __func__); + memreg = RPCRDMA_ALLPHYSICAL; +#else + dprintk("RPC: %s: FRMR registration " + "specified but not supported by adapter, " + "using slower RPCRDMA_REGISTER\n", + __func__); + memreg = RPCRDMA_REGISTER; +#endif + } + break; + } + + /* + * Optionally obtain an underlying physical identity mapping in + * order to do a memory window-based bind. This base registration + * is protected from remote access - that is enabled only by binding + * for the specific bytes targeted during each RPC operation, and + * revoked after the corresponding completion similar to a storage + * adapter. + */ + switch (memreg) { + case RPCRDMA_BOUNCEBUFFERS: + case RPCRDMA_REGISTER: + case RPCRDMA_FRMR: + break; +#if RPCRDMA_PERSISTENT_REGISTRATION + case RPCRDMA_ALLPHYSICAL: + mem_priv = IB_ACCESS_LOCAL_WRITE | + IB_ACCESS_REMOTE_WRITE | + IB_ACCESS_REMOTE_READ; + goto register_setup; +#endif + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + mem_priv = IB_ACCESS_LOCAL_WRITE | + IB_ACCESS_MW_BIND; + goto register_setup; + case RPCRDMA_MTHCAFMR: + if (ia->ri_have_dma_lkey) + break; + mem_priv = IB_ACCESS_LOCAL_WRITE; + register_setup: + ia->ri_bind_mem = ib_get_dma_mr(ia->ri_pd, mem_priv); + if (IS_ERR(ia->ri_bind_mem)) { + printk(KERN_ALERT "%s: ib_get_dma_mr for " + "phys register failed with %lX\n\t" + "Will continue with degraded performance\n", + __func__, PTR_ERR(ia->ri_bind_mem)); + memreg = RPCRDMA_REGISTER; + ia->ri_bind_mem = NULL; + } + break; + default: + printk(KERN_ERR "%s: invalid memory registration mode %d\n", + __func__, memreg); + rc = -EINVAL; + goto out2; + } + dprintk("RPC: %s: memory registration strategy is %d\n", + __func__, memreg); + + /* Else will do memory reg/dereg for each chunk */ + ia->ri_memreg_strategy = memreg; + + return 0; +out2: + rdma_destroy_id(ia->ri_id); + ia->ri_id = NULL; +out1: + return rc; +} + +/* + * Clean up/close an IA. + * o if event handles and PD have been initialized, free them. + * o close the IA + */ +void +rpcrdma_ia_close(struct rpcrdma_ia *ia) +{ + int rc; + + dprintk("RPC: %s: entering\n", __func__); + if (ia->ri_bind_mem != NULL) { + rc = ib_dereg_mr(ia->ri_bind_mem); + dprintk("RPC: %s: ib_dereg_mr returned %i\n", + __func__, rc); + } + if (ia->ri_id != NULL && !IS_ERR(ia->ri_id)) { + if (ia->ri_id->qp) + rdma_destroy_qp(ia->ri_id); + rdma_destroy_id(ia->ri_id); + ia->ri_id = NULL; + } + if (ia->ri_pd != NULL && !IS_ERR(ia->ri_pd)) { + rc = ib_dealloc_pd(ia->ri_pd); + dprintk("RPC: %s: ib_dealloc_pd returned %i\n", + __func__, rc); + } +} + +/* + * Create unconnected endpoint. + */ +int +rpcrdma_ep_create(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia, + struct rpcrdma_create_data_internal *cdata) +{ + struct ib_device_attr devattr; + int rc, err; + + rc = ib_query_device(ia->ri_id->device, &devattr); + if (rc) { + dprintk("RPC: %s: ib_query_device failed %d\n", + __func__, rc); + return rc; + } + + /* check provider's send/recv wr limits */ + if (cdata->max_requests > devattr.max_qp_wr) + cdata->max_requests = devattr.max_qp_wr; + + ep->rep_attr.event_handler = rpcrdma_qp_async_error_upcall; + ep->rep_attr.qp_context = ep; + /* send_cq and recv_cq initialized below */ + ep->rep_attr.srq = NULL; + ep->rep_attr.cap.max_send_wr = cdata->max_requests; + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + /* Add room for frmr register and invalidate WRs */ + ep->rep_attr.cap.max_send_wr *= 3; + if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) + return -EINVAL; + break; + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + /* Add room for mw_binds+unbinds - overkill! */ + ep->rep_attr.cap.max_send_wr++; + ep->rep_attr.cap.max_send_wr *= (2 * RPCRDMA_MAX_SEGS); + if (ep->rep_attr.cap.max_send_wr > devattr.max_qp_wr) + return -EINVAL; + break; + default: + break; + } + ep->rep_attr.cap.max_recv_wr = cdata->max_requests; + ep->rep_attr.cap.max_send_sge = (cdata->padding ? 4 : 2); + ep->rep_attr.cap.max_recv_sge = 1; + ep->rep_attr.cap.max_inline_data = 0; + ep->rep_attr.sq_sig_type = IB_SIGNAL_REQ_WR; + ep->rep_attr.qp_type = IB_QPT_RC; + ep->rep_attr.port_num = ~0; + + dprintk("RPC: %s: requested max: dtos: send %d recv %d; " + "iovs: send %d recv %d\n", + __func__, + ep->rep_attr.cap.max_send_wr, + ep->rep_attr.cap.max_recv_wr, + ep->rep_attr.cap.max_send_sge, + ep->rep_attr.cap.max_recv_sge); + + /* set trigger for requesting send completion */ + ep->rep_cqinit = ep->rep_attr.cap.max_send_wr/2 /* - 1*/; + switch (ia->ri_memreg_strategy) { + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + ep->rep_cqinit -= RPCRDMA_MAX_SEGS; + break; + default: + break; + } + if (ep->rep_cqinit <= 2) + ep->rep_cqinit = 0; + INIT_CQCOUNT(ep); + ep->rep_ia = ia; + init_waitqueue_head(&ep->rep_connect_wait); + + /* + * Create a single cq for receive dto and mw_bind (only ever + * care about unbind, really). Send completions are suppressed. + * Use single threaded tasklet upcalls to maintain ordering. + */ + ep->rep_cq = ib_create_cq(ia->ri_id->device, rpcrdma_cq_event_upcall, + rpcrdma_cq_async_error_upcall, NULL, + ep->rep_attr.cap.max_recv_wr + + ep->rep_attr.cap.max_send_wr + 1, 0); + if (IS_ERR(ep->rep_cq)) { + rc = PTR_ERR(ep->rep_cq); + dprintk("RPC: %s: ib_create_cq failed: %i\n", + __func__, rc); + goto out1; + } + + rc = ib_req_notify_cq(ep->rep_cq, IB_CQ_NEXT_COMP); + if (rc) { + dprintk("RPC: %s: ib_req_notify_cq failed: %i\n", + __func__, rc); + goto out2; + } + + ep->rep_attr.send_cq = ep->rep_cq; + ep->rep_attr.recv_cq = ep->rep_cq; + + /* Initialize cma parameters */ + + /* RPC/RDMA does not use private data */ + ep->rep_remote_cma.private_data = NULL; + ep->rep_remote_cma.private_data_len = 0; + + /* Client offers RDMA Read but does not initiate */ + ep->rep_remote_cma.initiator_depth = 0; + if (ia->ri_memreg_strategy == RPCRDMA_BOUNCEBUFFERS) + ep->rep_remote_cma.responder_resources = 0; + else if (devattr.max_qp_rd_atom > 32) /* arbitrary but <= 255 */ + ep->rep_remote_cma.responder_resources = 32; + else + ep->rep_remote_cma.responder_resources = devattr.max_qp_rd_atom; + + ep->rep_remote_cma.retry_count = 7; + ep->rep_remote_cma.flow_control = 0; + ep->rep_remote_cma.rnr_retry_count = 0; + + return 0; + +out2: + err = ib_destroy_cq(ep->rep_cq); + if (err) + dprintk("RPC: %s: ib_destroy_cq returned %i\n", + __func__, err); +out1: + return rc; +} + +/* + * rpcrdma_ep_destroy + * + * Disconnect and destroy endpoint. After this, the only + * valid operations on the ep are to free it (if dynamically + * allocated) or re-create it. + * + * The caller's error handling must be sure to not leak the endpoint + * if this function fails. + */ +int +rpcrdma_ep_destroy(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + int rc; + + dprintk("RPC: %s: entering, connected is %d\n", + __func__, ep->rep_connected); + + if (ia->ri_id->qp) { + rc = rpcrdma_ep_disconnect(ep, ia); + if (rc) + dprintk("RPC: %s: rpcrdma_ep_disconnect" + " returned %i\n", __func__, rc); + rdma_destroy_qp(ia->ri_id); + ia->ri_id->qp = NULL; + } + + /* padding - could be done in rpcrdma_buffer_destroy... */ + if (ep->rep_pad_mr) { + rpcrdma_deregister_internal(ia, ep->rep_pad_mr, &ep->rep_pad); + ep->rep_pad_mr = NULL; + } + + rpcrdma_clean_cq(ep->rep_cq); + rc = ib_destroy_cq(ep->rep_cq); + if (rc) + dprintk("RPC: %s: ib_destroy_cq returned %i\n", + __func__, rc); + + return rc; +} + +/* + * Connect unconnected endpoint. + */ +int +rpcrdma_ep_connect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + struct rdma_cm_id *id; + int rc = 0; + int retry_count = 0; + + if (ep->rep_connected != 0) { + struct rpcrdma_xprt *xprt; +retry: + rc = rpcrdma_ep_disconnect(ep, ia); + if (rc && rc != -ENOTCONN) + dprintk("RPC: %s: rpcrdma_ep_disconnect" + " status %i\n", __func__, rc); + rpcrdma_clean_cq(ep->rep_cq); + + xprt = container_of(ia, struct rpcrdma_xprt, rx_ia); + id = rpcrdma_create_id(xprt, ia, + (struct sockaddr *)&xprt->rx_data.addr); + if (IS_ERR(id)) { + rc = PTR_ERR(id); + goto out; + } + /* TEMP TEMP TEMP - fail if new device: + * Deregister/remarshal *all* requests! + * Close and recreate adapter, pd, etc! + * Re-determine all attributes still sane! + * More stuff I haven't thought of! + * Rrrgh! + */ + if (ia->ri_id->device != id->device) { + printk("RPC: %s: can't reconnect on " + "different device!\n", __func__); + rdma_destroy_id(id); + rc = -ENETDOWN; + goto out; + } + /* END TEMP */ + rdma_destroy_qp(ia->ri_id); + rdma_destroy_id(ia->ri_id); + ia->ri_id = id; + } + + rc = rdma_create_qp(ia->ri_id, ia->ri_pd, &ep->rep_attr); + if (rc) { + dprintk("RPC: %s: rdma_create_qp failed %i\n", + __func__, rc); + goto out; + } + +/* XXX Tavor device performs badly with 2K MTU! */ +if (strnicmp(ia->ri_id->device->dma_device->bus->name, "pci", 3) == 0) { + struct pci_dev *pcid = to_pci_dev(ia->ri_id->device->dma_device); + if (pcid->device == PCI_DEVICE_ID_MELLANOX_TAVOR && + (pcid->vendor == PCI_VENDOR_ID_MELLANOX || + pcid->vendor == PCI_VENDOR_ID_TOPSPIN)) { + struct ib_qp_attr attr = { + .path_mtu = IB_MTU_1024 + }; + rc = ib_modify_qp(ia->ri_id->qp, &attr, IB_QP_PATH_MTU); + } +} + + ep->rep_connected = 0; + + rc = rdma_connect(ia->ri_id, &ep->rep_remote_cma); + if (rc) { + dprintk("RPC: %s: rdma_connect() failed with %i\n", + __func__, rc); + goto out; + } + + wait_event_interruptible(ep->rep_connect_wait, ep->rep_connected != 0); + + /* + * Check state. A non-peer reject indicates no listener + * (ECONNREFUSED), which may be a transient state. All + * others indicate a transport condition which has already + * undergone a best-effort. + */ + if (ep->rep_connected == -ECONNREFUSED + && ++retry_count <= RDMA_CONNECT_RETRY_MAX) { + dprintk("RPC: %s: non-peer_reject, retry\n", __func__); + goto retry; + } + if (ep->rep_connected <= 0) { + /* Sometimes, the only way to reliably connect to remote + * CMs is to use same nonzero values for ORD and IRD. */ + if (retry_count++ <= RDMA_CONNECT_RETRY_MAX + 1 && + (ep->rep_remote_cma.responder_resources == 0 || + ep->rep_remote_cma.initiator_depth != + ep->rep_remote_cma.responder_resources)) { + if (ep->rep_remote_cma.responder_resources == 0) + ep->rep_remote_cma.responder_resources = 1; + ep->rep_remote_cma.initiator_depth = + ep->rep_remote_cma.responder_resources; + goto retry; + } + rc = ep->rep_connected; + } else { + dprintk("RPC: %s: connected\n", __func__); + } + +out: + if (rc) + ep->rep_connected = rc; + return rc; +} + +/* + * rpcrdma_ep_disconnect + * + * This is separate from destroy to facilitate the ability + * to reconnect without recreating the endpoint. + * + * This call is not reentrant, and must not be made in parallel + * on the same endpoint. + */ +int +rpcrdma_ep_disconnect(struct rpcrdma_ep *ep, struct rpcrdma_ia *ia) +{ + int rc; + + rpcrdma_clean_cq(ep->rep_cq); + rc = rdma_disconnect(ia->ri_id); + if (!rc) { + /* returns without wait if not connected */ + wait_event_interruptible(ep->rep_connect_wait, + ep->rep_connected != 1); + dprintk("RPC: %s: after wait, %sconnected\n", __func__, + (ep->rep_connected == 1) ? "still " : "dis"); + } else { + dprintk("RPC: %s: rdma_disconnect %i\n", __func__, rc); + ep->rep_connected = rc; + } + return rc; +} + +/* + * Initialize buffer memory + */ +int +rpcrdma_buffer_create(struct rpcrdma_buffer *buf, struct rpcrdma_ep *ep, + struct rpcrdma_ia *ia, struct rpcrdma_create_data_internal *cdata) +{ + char *p; + size_t len; + int i, rc; + struct rpcrdma_mw *r; + + buf->rb_max_requests = cdata->max_requests; + spin_lock_init(&buf->rb_lock); + atomic_set(&buf->rb_credits, 1); + + /* Need to allocate: + * 1. arrays for send and recv pointers + * 2. arrays of struct rpcrdma_req to fill in pointers + * 3. array of struct rpcrdma_rep for replies + * 4. padding, if any + * 5. mw's, fmr's or frmr's, if any + * Send/recv buffers in req/rep need to be registered + */ + + len = buf->rb_max_requests * + (sizeof(struct rpcrdma_req *) + sizeof(struct rpcrdma_rep *)); + len += cdata->padding; + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + len += buf->rb_max_requests * RPCRDMA_MAX_SEGS * + sizeof(struct rpcrdma_mw); + break; + case RPCRDMA_MTHCAFMR: + /* TBD we are perhaps overallocating here */ + len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * + sizeof(struct rpcrdma_mw); + break; + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + len += (buf->rb_max_requests + 1) * RPCRDMA_MAX_SEGS * + sizeof(struct rpcrdma_mw); + break; + default: + break; + } + + /* allocate 1, 4 and 5 in one shot */ + p = kzalloc(len, GFP_KERNEL); + if (p == NULL) { + dprintk("RPC: %s: req_t/rep_t/pad kzalloc(%zd) failed\n", + __func__, len); + rc = -ENOMEM; + goto out; + } + buf->rb_pool = p; /* for freeing it later */ + + buf->rb_send_bufs = (struct rpcrdma_req **) p; + p = (char *) &buf->rb_send_bufs[buf->rb_max_requests]; + buf->rb_recv_bufs = (struct rpcrdma_rep **) p; + p = (char *) &buf->rb_recv_bufs[buf->rb_max_requests]; + + /* + * Register the zeroed pad buffer, if any. + */ + if (cdata->padding) { + rc = rpcrdma_register_internal(ia, p, cdata->padding, + &ep->rep_pad_mr, &ep->rep_pad); + if (rc) + goto out; + } + p += cdata->padding; + + /* + * Allocate the fmr's, or mw's for mw_bind chunk registration. + * We "cycle" the mw's in order to minimize rkey reuse, + * and also reduce unbind-to-bind collision. + */ + INIT_LIST_HEAD(&buf->rb_mws); + r = (struct rpcrdma_mw *)p; + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + for (i = buf->rb_max_requests * RPCRDMA_MAX_SEGS; i; i--) { + r->r.frmr.fr_mr = ib_alloc_fast_reg_mr(ia->ri_pd, + RPCRDMA_MAX_SEGS); + if (IS_ERR(r->r.frmr.fr_mr)) { + rc = PTR_ERR(r->r.frmr.fr_mr); + dprintk("RPC: %s: ib_alloc_fast_reg_mr" + " failed %i\n", __func__, rc); + goto out; + } + r->r.frmr.fr_pgl = + ib_alloc_fast_reg_page_list(ia->ri_id->device, + RPCRDMA_MAX_SEGS); + if (IS_ERR(r->r.frmr.fr_pgl)) { + rc = PTR_ERR(r->r.frmr.fr_pgl); + dprintk("RPC: %s: " + "ib_alloc_fast_reg_page_list " + "failed %i\n", __func__, rc); + goto out; + } + list_add(&r->mw_list, &buf->rb_mws); + ++r; + } + break; + case RPCRDMA_MTHCAFMR: + /* TBD we are perhaps overallocating here */ + for (i = (buf->rb_max_requests+1) * RPCRDMA_MAX_SEGS; i; i--) { + static struct ib_fmr_attr fa = + { RPCRDMA_MAX_DATA_SEGS, 1, PAGE_SHIFT }; + r->r.fmr = ib_alloc_fmr(ia->ri_pd, + IB_ACCESS_REMOTE_WRITE | IB_ACCESS_REMOTE_READ, + &fa); + if (IS_ERR(r->r.fmr)) { + rc = PTR_ERR(r->r.fmr); + dprintk("RPC: %s: ib_alloc_fmr" + " failed %i\n", __func__, rc); + goto out; + } + list_add(&r->mw_list, &buf->rb_mws); + ++r; + } + break; + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + /* Allocate one extra request's worth, for full cycling */ + for (i = (buf->rb_max_requests+1) * RPCRDMA_MAX_SEGS; i; i--) { + r->r.mw = ib_alloc_mw(ia->ri_pd); + if (IS_ERR(r->r.mw)) { + rc = PTR_ERR(r->r.mw); + dprintk("RPC: %s: ib_alloc_mw" + " failed %i\n", __func__, rc); + goto out; + } + list_add(&r->mw_list, &buf->rb_mws); + ++r; + } + break; + default: + break; + } + + /* + * Allocate/init the request/reply buffers. Doing this + * using kmalloc for now -- one for each buf. + */ + for (i = 0; i < buf->rb_max_requests; i++) { + struct rpcrdma_req *req; + struct rpcrdma_rep *rep; + + len = cdata->inline_wsize + sizeof(struct rpcrdma_req); + /* RPC layer requests *double* size + 1K RPC_SLACK_SPACE! */ + /* Typical ~2400b, so rounding up saves work later */ + if (len < 4096) + len = 4096; + req = kmalloc(len, GFP_KERNEL); + if (req == NULL) { + dprintk("RPC: %s: request buffer %d alloc" + " failed\n", __func__, i); + rc = -ENOMEM; + goto out; + } + memset(req, 0, sizeof(struct rpcrdma_req)); + buf->rb_send_bufs[i] = req; + buf->rb_send_bufs[i]->rl_buffer = buf; + + rc = rpcrdma_register_internal(ia, req->rl_base, + len - offsetof(struct rpcrdma_req, rl_base), + &buf->rb_send_bufs[i]->rl_handle, + &buf->rb_send_bufs[i]->rl_iov); + if (rc) + goto out; + + buf->rb_send_bufs[i]->rl_size = len-sizeof(struct rpcrdma_req); + + len = cdata->inline_rsize + sizeof(struct rpcrdma_rep); + rep = kmalloc(len, GFP_KERNEL); + if (rep == NULL) { + dprintk("RPC: %s: reply buffer %d alloc failed\n", + __func__, i); + rc = -ENOMEM; + goto out; + } + memset(rep, 0, sizeof(struct rpcrdma_rep)); + buf->rb_recv_bufs[i] = rep; + buf->rb_recv_bufs[i]->rr_buffer = buf; + init_waitqueue_head(&rep->rr_unbind); + + rc = rpcrdma_register_internal(ia, rep->rr_base, + len - offsetof(struct rpcrdma_rep, rr_base), + &buf->rb_recv_bufs[i]->rr_handle, + &buf->rb_recv_bufs[i]->rr_iov); + if (rc) + goto out; + + } + dprintk("RPC: %s: max_requests %d\n", + __func__, buf->rb_max_requests); + /* done */ + return 0; +out: + rpcrdma_buffer_destroy(buf); + return rc; +} + +/* + * Unregister and destroy buffer memory. Need to deal with + * partial initialization, so it's callable from failed create. + * Must be called before destroying endpoint, as registrations + * reference it. + */ +void +rpcrdma_buffer_destroy(struct rpcrdma_buffer *buf) +{ + int rc, i; + struct rpcrdma_ia *ia = rdmab_to_ia(buf); + struct rpcrdma_mw *r; + + /* clean up in reverse order from create + * 1. recv mr memory (mr free, then kfree) + * 1a. bind mw memory + * 2. send mr memory (mr free, then kfree) + * 3. padding (if any) [moved to rpcrdma_ep_destroy] + * 4. arrays + */ + dprintk("RPC: %s: entering\n", __func__); + + for (i = 0; i < buf->rb_max_requests; i++) { + if (buf->rb_recv_bufs && buf->rb_recv_bufs[i]) { + rpcrdma_deregister_internal(ia, + buf->rb_recv_bufs[i]->rr_handle, + &buf->rb_recv_bufs[i]->rr_iov); + kfree(buf->rb_recv_bufs[i]); + } + if (buf->rb_send_bufs && buf->rb_send_bufs[i]) { + while (!list_empty(&buf->rb_mws)) { + r = list_entry(buf->rb_mws.next, + struct rpcrdma_mw, mw_list); + list_del(&r->mw_list); + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + rc = ib_dereg_mr(r->r.frmr.fr_mr); + if (rc) + dprintk("RPC: %s:" + " ib_dereg_mr" + " failed %i\n", + __func__, rc); + ib_free_fast_reg_page_list(r->r.frmr.fr_pgl); + break; + case RPCRDMA_MTHCAFMR: + rc = ib_dealloc_fmr(r->r.fmr); + if (rc) + dprintk("RPC: %s:" + " ib_dealloc_fmr" + " failed %i\n", + __func__, rc); + break; + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + rc = ib_dealloc_mw(r->r.mw); + if (rc) + dprintk("RPC: %s:" + " ib_dealloc_mw" + " failed %i\n", + __func__, rc); + break; + default: + break; + } + } + rpcrdma_deregister_internal(ia, + buf->rb_send_bufs[i]->rl_handle, + &buf->rb_send_bufs[i]->rl_iov); + kfree(buf->rb_send_bufs[i]); + } + } + + kfree(buf->rb_pool); +} + +/* + * Get a set of request/reply buffers. + * + * Reply buffer (if needed) is attached to send buffer upon return. + * Rule: + * rb_send_index and rb_recv_index MUST always be pointing to the + * *next* available buffer (non-NULL). They are incremented after + * removing buffers, and decremented *before* returning them. + */ +struct rpcrdma_req * +rpcrdma_buffer_get(struct rpcrdma_buffer *buffers) +{ + struct rpcrdma_req *req; + unsigned long flags; + int i; + struct rpcrdma_mw *r; + + spin_lock_irqsave(&buffers->rb_lock, flags); + if (buffers->rb_send_index == buffers->rb_max_requests) { + spin_unlock_irqrestore(&buffers->rb_lock, flags); + dprintk("RPC: %s: out of request buffers\n", __func__); + return ((struct rpcrdma_req *)NULL); + } + + req = buffers->rb_send_bufs[buffers->rb_send_index]; + if (buffers->rb_send_index < buffers->rb_recv_index) { + dprintk("RPC: %s: %d extra receives outstanding (ok)\n", + __func__, + buffers->rb_recv_index - buffers->rb_send_index); + req->rl_reply = NULL; + } else { + req->rl_reply = buffers->rb_recv_bufs[buffers->rb_recv_index]; + buffers->rb_recv_bufs[buffers->rb_recv_index++] = NULL; + } + buffers->rb_send_bufs[buffers->rb_send_index++] = NULL; + if (!list_empty(&buffers->rb_mws)) { + i = RPCRDMA_MAX_SEGS - 1; + do { + r = list_entry(buffers->rb_mws.next, + struct rpcrdma_mw, mw_list); + list_del(&r->mw_list); + req->rl_segments[i].mr_chunk.rl_mw = r; + } while (--i >= 0); + } + spin_unlock_irqrestore(&buffers->rb_lock, flags); + return req; +} + +/* + * Put request/reply buffers back into pool. + * Pre-decrement counter/array index. + */ +void +rpcrdma_buffer_put(struct rpcrdma_req *req) +{ + struct rpcrdma_buffer *buffers = req->rl_buffer; + struct rpcrdma_ia *ia = rdmab_to_ia(buffers); + int i; + unsigned long flags; + + BUG_ON(req->rl_nchunks != 0); + spin_lock_irqsave(&buffers->rb_lock, flags); + buffers->rb_send_bufs[--buffers->rb_send_index] = req; + req->rl_niovs = 0; + if (req->rl_reply) { + buffers->rb_recv_bufs[--buffers->rb_recv_index] = req->rl_reply; + init_waitqueue_head(&req->rl_reply->rr_unbind); + req->rl_reply->rr_func = NULL; + req->rl_reply = NULL; + } + switch (ia->ri_memreg_strategy) { + case RPCRDMA_FRMR: + case RPCRDMA_MTHCAFMR: + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + /* + * Cycle mw's back in reverse order, and "spin" them. + * This delays and scrambles reuse as much as possible. + */ + i = 1; + do { + struct rpcrdma_mw **mw; + mw = &req->rl_segments[i].mr_chunk.rl_mw; + list_add_tail(&(*mw)->mw_list, &buffers->rb_mws); + *mw = NULL; + } while (++i < RPCRDMA_MAX_SEGS); + list_add_tail(&req->rl_segments[0].mr_chunk.rl_mw->mw_list, + &buffers->rb_mws); + req->rl_segments[0].mr_chunk.rl_mw = NULL; + break; + default: + break; + } + spin_unlock_irqrestore(&buffers->rb_lock, flags); +} + +/* + * Recover reply buffers from pool. + * This happens when recovering from error conditions. + * Post-increment counter/array index. + */ +void +rpcrdma_recv_buffer_get(struct rpcrdma_req *req) +{ + struct rpcrdma_buffer *buffers = req->rl_buffer; + unsigned long flags; + + if (req->rl_iov.length == 0) /* special case xprt_rdma_allocate() */ + buffers = ((struct rpcrdma_req *) buffers)->rl_buffer; + spin_lock_irqsave(&buffers->rb_lock, flags); + if (buffers->rb_recv_index < buffers->rb_max_requests) { + req->rl_reply = buffers->rb_recv_bufs[buffers->rb_recv_index]; + buffers->rb_recv_bufs[buffers->rb_recv_index++] = NULL; + } + spin_unlock_irqrestore(&buffers->rb_lock, flags); +} + +/* + * Put reply buffers back into pool when not attached to + * request. This happens in error conditions, and when + * aborting unbinds. Pre-decrement counter/array index. + */ +void +rpcrdma_recv_buffer_put(struct rpcrdma_rep *rep) +{ + struct rpcrdma_buffer *buffers = rep->rr_buffer; + unsigned long flags; + + rep->rr_func = NULL; + spin_lock_irqsave(&buffers->rb_lock, flags); + buffers->rb_recv_bufs[--buffers->rb_recv_index] = rep; + spin_unlock_irqrestore(&buffers->rb_lock, flags); +} + +/* + * Wrappers for internal-use kmalloc memory registration, used by buffer code. + */ + +int +rpcrdma_register_internal(struct rpcrdma_ia *ia, void *va, int len, + struct ib_mr **mrp, struct ib_sge *iov) +{ + struct ib_phys_buf ipb; + struct ib_mr *mr; + int rc; + + /* + * All memory passed here was kmalloc'ed, therefore phys-contiguous. + */ + iov->addr = ib_dma_map_single(ia->ri_id->device, + va, len, DMA_BIDIRECTIONAL); + iov->length = len; + + if (ia->ri_have_dma_lkey) { + *mrp = NULL; + iov->lkey = ia->ri_dma_lkey; + return 0; + } else if (ia->ri_bind_mem != NULL) { + *mrp = NULL; + iov->lkey = ia->ri_bind_mem->lkey; + return 0; + } + + ipb.addr = iov->addr; + ipb.size = iov->length; + mr = ib_reg_phys_mr(ia->ri_pd, &ipb, 1, + IB_ACCESS_LOCAL_WRITE, &iov->addr); + + dprintk("RPC: %s: phys convert: 0x%llx " + "registered 0x%llx length %d\n", + __func__, (unsigned long long)ipb.addr, + (unsigned long long)iov->addr, len); + + if (IS_ERR(mr)) { + *mrp = NULL; + rc = PTR_ERR(mr); + dprintk("RPC: %s: failed with %i\n", __func__, rc); + } else { + *mrp = mr; + iov->lkey = mr->lkey; + rc = 0; + } + + return rc; +} + +int +rpcrdma_deregister_internal(struct rpcrdma_ia *ia, + struct ib_mr *mr, struct ib_sge *iov) +{ + int rc; + + ib_dma_unmap_single(ia->ri_id->device, + iov->addr, iov->length, DMA_BIDIRECTIONAL); + + if (NULL == mr) + return 0; + + rc = ib_dereg_mr(mr); + if (rc) + dprintk("RPC: %s: ib_dereg_mr failed %i\n", __func__, rc); + return rc; +} + +/* + * Wrappers for chunk registration, shared by read/write chunk code. + */ + +static void +rpcrdma_map_one(struct rpcrdma_ia *ia, struct rpcrdma_mr_seg *seg, int writing) +{ + seg->mr_dir = writing ? DMA_FROM_DEVICE : DMA_TO_DEVICE; + seg->mr_dmalen = seg->mr_len; + if (seg->mr_page) + seg->mr_dma = ib_dma_map_page(ia->ri_id->device, + seg->mr_page, offset_in_page(seg->mr_offset), + seg->mr_dmalen, seg->mr_dir); + else + seg->mr_dma = ib_dma_map_single(ia->ri_id->device, + seg->mr_offset, + seg->mr_dmalen, seg->mr_dir); +} + +static void +rpcrdma_unmap_one(struct rpcrdma_ia *ia, struct rpcrdma_mr_seg *seg) +{ + if (seg->mr_page) + ib_dma_unmap_page(ia->ri_id->device, + seg->mr_dma, seg->mr_dmalen, seg->mr_dir); + else + ib_dma_unmap_single(ia->ri_id->device, + seg->mr_dma, seg->mr_dmalen, seg->mr_dir); +} + +static int +rpcrdma_register_frmr_external(struct rpcrdma_mr_seg *seg, + int *nsegs, int writing, struct rpcrdma_ia *ia, + struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_mr_seg *seg1 = seg; + struct ib_send_wr frmr_wr, *bad_wr; + u8 key; + int len, pageoff; + int i, rc; + + pageoff = offset_in_page(seg1->mr_offset); + seg1->mr_offset -= pageoff; /* start of page */ + seg1->mr_len += pageoff; + len = -pageoff; + if (*nsegs > RPCRDMA_MAX_DATA_SEGS) + *nsegs = RPCRDMA_MAX_DATA_SEGS; + for (i = 0; i < *nsegs;) { + rpcrdma_map_one(ia, seg, writing); + seg1->mr_chunk.rl_mw->r.frmr.fr_pgl->page_list[i] = seg->mr_dma; + len += seg->mr_len; + ++seg; + ++i; + /* Check for holes */ + if ((i < *nsegs && offset_in_page(seg->mr_offset)) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + dprintk("RPC: %s: Using frmr %p to map %d segments\n", + __func__, seg1->mr_chunk.rl_mw, i); + + /* Bump the key */ + key = (u8)(seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey & 0x000000FF); + ib_update_fast_reg_key(seg1->mr_chunk.rl_mw->r.frmr.fr_mr, ++key); + + /* Prepare FRMR WR */ + memset(&frmr_wr, 0, sizeof frmr_wr); + frmr_wr.opcode = IB_WR_FAST_REG_MR; + frmr_wr.send_flags = 0; /* unsignaled */ + frmr_wr.wr.fast_reg.iova_start = (unsigned long)seg1->mr_dma; + frmr_wr.wr.fast_reg.page_list = seg1->mr_chunk.rl_mw->r.frmr.fr_pgl; + frmr_wr.wr.fast_reg.page_list_len = i; + frmr_wr.wr.fast_reg.page_shift = PAGE_SHIFT; + frmr_wr.wr.fast_reg.length = i << PAGE_SHIFT; + frmr_wr.wr.fast_reg.access_flags = (writing ? + IB_ACCESS_REMOTE_WRITE : IB_ACCESS_REMOTE_READ); + frmr_wr.wr.fast_reg.rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; + DECR_CQCOUNT(&r_xprt->rx_ep); + + rc = ib_post_send(ia->ri_id->qp, &frmr_wr, &bad_wr); + + if (rc) { + dprintk("RPC: %s: failed ib_post_send for register," + " status %i\n", __func__, rc); + while (i--) + rpcrdma_unmap_one(ia, --seg); + } else { + seg1->mr_rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; + seg1->mr_base = seg1->mr_dma + pageoff; + seg1->mr_nsegs = i; + seg1->mr_len = len; + } + *nsegs = i; + return rc; +} + +static int +rpcrdma_deregister_frmr_external(struct rpcrdma_mr_seg *seg, + struct rpcrdma_ia *ia, struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_mr_seg *seg1 = seg; + struct ib_send_wr invalidate_wr, *bad_wr; + int rc; + + while (seg1->mr_nsegs--) + rpcrdma_unmap_one(ia, seg++); + + memset(&invalidate_wr, 0, sizeof invalidate_wr); + invalidate_wr.opcode = IB_WR_LOCAL_INV; + invalidate_wr.send_flags = 0; /* unsignaled */ + invalidate_wr.ex.invalidate_rkey = seg1->mr_chunk.rl_mw->r.frmr.fr_mr->rkey; + DECR_CQCOUNT(&r_xprt->rx_ep); + + rc = ib_post_send(ia->ri_id->qp, &invalidate_wr, &bad_wr); + if (rc) + dprintk("RPC: %s: failed ib_post_send for invalidate," + " status %i\n", __func__, rc); + return rc; +} + +static int +rpcrdma_register_fmr_external(struct rpcrdma_mr_seg *seg, + int *nsegs, int writing, struct rpcrdma_ia *ia) +{ + struct rpcrdma_mr_seg *seg1 = seg; + u64 physaddrs[RPCRDMA_MAX_DATA_SEGS]; + int len, pageoff, i, rc; + + pageoff = offset_in_page(seg1->mr_offset); + seg1->mr_offset -= pageoff; /* start of page */ + seg1->mr_len += pageoff; + len = -pageoff; + if (*nsegs > RPCRDMA_MAX_DATA_SEGS) + *nsegs = RPCRDMA_MAX_DATA_SEGS; + for (i = 0; i < *nsegs;) { + rpcrdma_map_one(ia, seg, writing); + physaddrs[i] = seg->mr_dma; + len += seg->mr_len; + ++seg; + ++i; + /* Check for holes */ + if ((i < *nsegs && offset_in_page(seg->mr_offset)) || + offset_in_page((seg-1)->mr_offset + (seg-1)->mr_len)) + break; + } + rc = ib_map_phys_fmr(seg1->mr_chunk.rl_mw->r.fmr, + physaddrs, i, seg1->mr_dma); + if (rc) { + dprintk("RPC: %s: failed ib_map_phys_fmr " + "%u@0x%llx+%i (%d)... status %i\n", __func__, + len, (unsigned long long)seg1->mr_dma, + pageoff, i, rc); + while (i--) + rpcrdma_unmap_one(ia, --seg); + } else { + seg1->mr_rkey = seg1->mr_chunk.rl_mw->r.fmr->rkey; + seg1->mr_base = seg1->mr_dma + pageoff; + seg1->mr_nsegs = i; + seg1->mr_len = len; + } + *nsegs = i; + return rc; +} + +static int +rpcrdma_deregister_fmr_external(struct rpcrdma_mr_seg *seg, + struct rpcrdma_ia *ia) +{ + struct rpcrdma_mr_seg *seg1 = seg; + LIST_HEAD(l); + int rc; + + list_add(&seg1->mr_chunk.rl_mw->r.fmr->list, &l); + rc = ib_unmap_fmr(&l); + while (seg1->mr_nsegs--) + rpcrdma_unmap_one(ia, seg++); + if (rc) + dprintk("RPC: %s: failed ib_unmap_fmr," + " status %i\n", __func__, rc); + return rc; +} + +static int +rpcrdma_register_memwin_external(struct rpcrdma_mr_seg *seg, + int *nsegs, int writing, struct rpcrdma_ia *ia, + struct rpcrdma_xprt *r_xprt) +{ + int mem_priv = (writing ? IB_ACCESS_REMOTE_WRITE : + IB_ACCESS_REMOTE_READ); + struct ib_mw_bind param; + int rc; + + *nsegs = 1; + rpcrdma_map_one(ia, seg, writing); + param.mr = ia->ri_bind_mem; + param.wr_id = 0ULL; /* no send cookie */ + param.addr = seg->mr_dma; + param.length = seg->mr_len; + param.send_flags = 0; + param.mw_access_flags = mem_priv; + + DECR_CQCOUNT(&r_xprt->rx_ep); + rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); + if (rc) { + dprintk("RPC: %s: failed ib_bind_mw " + "%u@0x%llx status %i\n", + __func__, seg->mr_len, + (unsigned long long)seg->mr_dma, rc); + rpcrdma_unmap_one(ia, seg); + } else { + seg->mr_rkey = seg->mr_chunk.rl_mw->r.mw->rkey; + seg->mr_base = param.addr; + seg->mr_nsegs = 1; + } + return rc; +} + +static int +rpcrdma_deregister_memwin_external(struct rpcrdma_mr_seg *seg, + struct rpcrdma_ia *ia, + struct rpcrdma_xprt *r_xprt, void **r) +{ + struct ib_mw_bind param; + LIST_HEAD(l); + int rc; + + BUG_ON(seg->mr_nsegs != 1); + param.mr = ia->ri_bind_mem; + param.addr = 0ULL; /* unbind */ + param.length = 0; + param.mw_access_flags = 0; + if (*r) { + param.wr_id = (u64) (unsigned long) *r; + param.send_flags = IB_SEND_SIGNALED; + INIT_CQCOUNT(&r_xprt->rx_ep); + } else { + param.wr_id = 0ULL; + param.send_flags = 0; + DECR_CQCOUNT(&r_xprt->rx_ep); + } + rc = ib_bind_mw(ia->ri_id->qp, seg->mr_chunk.rl_mw->r.mw, ¶m); + rpcrdma_unmap_one(ia, seg); + if (rc) + dprintk("RPC: %s: failed ib_(un)bind_mw," + " status %i\n", __func__, rc); + else + *r = NULL; /* will upcall on completion */ + return rc; +} + +static int +rpcrdma_register_default_external(struct rpcrdma_mr_seg *seg, + int *nsegs, int writing, struct rpcrdma_ia *ia) +{ + int mem_priv = (writing ? IB_ACCESS_REMOTE_WRITE : + IB_ACCESS_REMOTE_READ); + struct rpcrdma_mr_seg *seg1 = seg; + struct ib_phys_buf ipb[RPCRDMA_MAX_DATA_SEGS]; + int len, i, rc = 0; + + if (*nsegs > RPCRDMA_MAX_DATA_SEGS) + *nsegs = RPCRDMA_MAX_DATA_SEGS; + for (len = 0, i = 0; i < *nsegs;) { + rpcrdma_map_one(ia, seg, writing); + ipb[i].addr = seg->mr_dma; + ipb[i].size = seg->mr_len; + len += seg->mr_len; + ++seg; + ++i; + /* Check for holes */ + if ((i < *nsegs && offset_in_page(seg->mr_offset)) || + offset_in_page((seg-1)->mr_offset+(seg-1)->mr_len)) + break; + } + seg1->mr_base = seg1->mr_dma; + seg1->mr_chunk.rl_mr = ib_reg_phys_mr(ia->ri_pd, + ipb, i, mem_priv, &seg1->mr_base); + if (IS_ERR(seg1->mr_chunk.rl_mr)) { + rc = PTR_ERR(seg1->mr_chunk.rl_mr); + dprintk("RPC: %s: failed ib_reg_phys_mr " + "%u@0x%llx (%d)... status %i\n", + __func__, len, + (unsigned long long)seg1->mr_dma, i, rc); + while (i--) + rpcrdma_unmap_one(ia, --seg); + } else { + seg1->mr_rkey = seg1->mr_chunk.rl_mr->rkey; + seg1->mr_nsegs = i; + seg1->mr_len = len; + } + *nsegs = i; + return rc; +} + +static int +rpcrdma_deregister_default_external(struct rpcrdma_mr_seg *seg, + struct rpcrdma_ia *ia) +{ + struct rpcrdma_mr_seg *seg1 = seg; + int rc; + + rc = ib_dereg_mr(seg1->mr_chunk.rl_mr); + seg1->mr_chunk.rl_mr = NULL; + while (seg1->mr_nsegs--) + rpcrdma_unmap_one(ia, seg++); + if (rc) + dprintk("RPC: %s: failed ib_dereg_mr," + " status %i\n", __func__, rc); + return rc; +} + +int +rpcrdma_register_external(struct rpcrdma_mr_seg *seg, + int nsegs, int writing, struct rpcrdma_xprt *r_xprt) +{ + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + int rc = 0; + + switch (ia->ri_memreg_strategy) { + +#if RPCRDMA_PERSISTENT_REGISTRATION + case RPCRDMA_ALLPHYSICAL: + rpcrdma_map_one(ia, seg, writing); + seg->mr_rkey = ia->ri_bind_mem->rkey; + seg->mr_base = seg->mr_dma; + seg->mr_nsegs = 1; + nsegs = 1; + break; +#endif + + /* Registration using frmr registration */ + case RPCRDMA_FRMR: + rc = rpcrdma_register_frmr_external(seg, &nsegs, writing, ia, r_xprt); + break; + + /* Registration using fmr memory registration */ + case RPCRDMA_MTHCAFMR: + rc = rpcrdma_register_fmr_external(seg, &nsegs, writing, ia); + break; + + /* Registration using memory windows */ + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + rc = rpcrdma_register_memwin_external(seg, &nsegs, writing, ia, r_xprt); + break; + + /* Default registration each time */ + default: + rc = rpcrdma_register_default_external(seg, &nsegs, writing, ia); + break; + } + if (rc) + return -1; + + return nsegs; +} + +int +rpcrdma_deregister_external(struct rpcrdma_mr_seg *seg, + struct rpcrdma_xprt *r_xprt, void *r) +{ + struct rpcrdma_ia *ia = &r_xprt->rx_ia; + int nsegs = seg->mr_nsegs, rc; + + switch (ia->ri_memreg_strategy) { + +#if RPCRDMA_PERSISTENT_REGISTRATION + case RPCRDMA_ALLPHYSICAL: + BUG_ON(nsegs != 1); + rpcrdma_unmap_one(ia, seg); + rc = 0; + break; +#endif + + case RPCRDMA_FRMR: + rc = rpcrdma_deregister_frmr_external(seg, ia, r_xprt); + break; + + case RPCRDMA_MTHCAFMR: + rc = rpcrdma_deregister_fmr_external(seg, ia); + break; + + case RPCRDMA_MEMWINDOWS_ASYNC: + case RPCRDMA_MEMWINDOWS: + rc = rpcrdma_deregister_memwin_external(seg, ia, r_xprt, &r); + break; + + default: + rc = rpcrdma_deregister_default_external(seg, ia); + break; + } + if (r) { + struct rpcrdma_rep *rep = r; + void (*func)(struct rpcrdma_rep *) = rep->rr_func; + rep->rr_func = NULL; + func(rep); /* dereg done, callback now */ + } + return nsegs; +} + +/* + * Prepost any receive buffer, then post send. + * + * Receive buffer is donated to hardware, reclaimed upon recv completion. + */ +int +rpcrdma_ep_post(struct rpcrdma_ia *ia, + struct rpcrdma_ep *ep, + struct rpcrdma_req *req) +{ + struct ib_send_wr send_wr, *send_wr_fail; + struct rpcrdma_rep *rep = req->rl_reply; + int rc; + + if (rep) { + rc = rpcrdma_ep_post_recv(ia, ep, rep); + if (rc) + goto out; + req->rl_reply = NULL; + } + + send_wr.next = NULL; + send_wr.wr_id = 0ULL; /* no send cookie */ + send_wr.sg_list = req->rl_send_iov; + send_wr.num_sge = req->rl_niovs; + send_wr.opcode = IB_WR_SEND; + if (send_wr.num_sge == 4) /* no need to sync any pad (constant) */ + ib_dma_sync_single_for_device(ia->ri_id->device, + req->rl_send_iov[3].addr, req->rl_send_iov[3].length, + DMA_TO_DEVICE); + ib_dma_sync_single_for_device(ia->ri_id->device, + req->rl_send_iov[1].addr, req->rl_send_iov[1].length, + DMA_TO_DEVICE); + ib_dma_sync_single_for_device(ia->ri_id->device, + req->rl_send_iov[0].addr, req->rl_send_iov[0].length, + DMA_TO_DEVICE); + + if (DECR_CQCOUNT(ep) > 0) + send_wr.send_flags = 0; + else { /* Provider must take a send completion every now and then */ + INIT_CQCOUNT(ep); + send_wr.send_flags = IB_SEND_SIGNALED; + } + + rc = ib_post_send(ia->ri_id->qp, &send_wr, &send_wr_fail); + if (rc) + dprintk("RPC: %s: ib_post_send returned %i\n", __func__, + rc); +out: + return rc; +} + +/* + * (Re)post a receive buffer. + */ +int +rpcrdma_ep_post_recv(struct rpcrdma_ia *ia, + struct rpcrdma_ep *ep, + struct rpcrdma_rep *rep) +{ + struct ib_recv_wr recv_wr, *recv_wr_fail; + int rc; + + recv_wr.next = NULL; + recv_wr.wr_id = (u64) (unsigned long) rep; + recv_wr.sg_list = &rep->rr_iov; + recv_wr.num_sge = 1; + + ib_dma_sync_single_for_cpu(ia->ri_id->device, + rep->rr_iov.addr, rep->rr_iov.length, DMA_BIDIRECTIONAL); + + DECR_CQCOUNT(ep); + rc = ib_post_recv(ia->ri_id->qp, &recv_wr, &recv_wr_fail); + + if (rc) + dprintk("RPC: %s: ib_post_recv returned %i\n", __func__, + rc); + return rc; +} diff --git a/net/sunrpc/xprtrdma/xprt_rdma.h b/net/sunrpc/xprtrdma/xprt_rdma.h new file mode 100644 index 0000000..c7a7eba --- /dev/null +++ b/net/sunrpc/xprtrdma/xprt_rdma.h @@ -0,0 +1,345 @@ +/* + * Copyright (c) 2003-2007 Network Appliance, Inc. All rights reserved. + * + * This software is available to you under a choice of one of two + * licenses. You may choose to be licensed under the terms of the GNU + * General Public License (GPL) Version 2, available from the file + * COPYING in the main directory of this source tree, or the BSD-type + * license below: + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following + * disclaimer in the documentation and/or other materials provided + * with the distribution. + * + * Neither the name of the Network Appliance, Inc. nor the names of + * its contributors may be used to endorse or promote products + * derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef _LINUX_SUNRPC_XPRT_RDMA_H +#define _LINUX_SUNRPC_XPRT_RDMA_H + +#include <linux/wait.h> /* wait_queue_head_t, etc */ +#include <linux/spinlock.h> /* spinlock_t, etc */ +#include <asm/atomic.h> /* atomic_t, etc */ + +#include <rdma/rdma_cm.h> /* RDMA connection api */ +#include <rdma/ib_verbs.h> /* RDMA verbs api */ + +#include <linux/sunrpc/clnt.h> /* rpc_xprt */ +#include <linux/sunrpc/rpc_rdma.h> /* RPC/RDMA protocol */ +#include <linux/sunrpc/xprtrdma.h> /* xprt parameters */ + +#define RDMA_RESOLVE_TIMEOUT (5000) /* 5 seconds */ +#define RDMA_CONNECT_RETRY_MAX (2) /* retries if no listener backlog */ + +/* + * Interface Adapter -- one per transport instance + */ +struct rpcrdma_ia { + struct rdma_cm_id *ri_id; + struct ib_pd *ri_pd; + struct ib_mr *ri_bind_mem; + u32 ri_dma_lkey; + int ri_have_dma_lkey; + struct completion ri_done; + int ri_async_rc; + enum rpcrdma_memreg ri_memreg_strategy; +}; + +/* + * RDMA Endpoint -- one per transport instance + */ + +struct rpcrdma_ep { + atomic_t rep_cqcount; + int rep_cqinit; + int rep_connected; + struct rpcrdma_ia *rep_ia; + struct ib_cq *rep_cq; + struct ib_qp_init_attr rep_attr; + wait_queue_head_t rep_connect_wait; + struct ib_sge rep_pad; /* holds zeroed pad */ + struct ib_mr *rep_pad_mr; /* holds zeroed pad */ + void (*rep_func)(struct rpcrdma_ep *); + struct rpc_xprt *rep_xprt; /* for rep_func */ + struct rdma_conn_param rep_remote_cma; + struct sockaddr_storage rep_remote_addr; +}; + +#define INIT_CQCOUNT(ep) atomic_set(&(ep)->rep_cqcount, (ep)->rep_cqinit) +#define DECR_CQCOUNT(ep) atomic_sub_return(1, &(ep)->rep_cqcount) + +/* + * struct rpcrdma_rep -- this structure encapsulates state required to recv + * and complete a reply, asychronously. It needs several pieces of + * state: + * o recv buffer (posted to provider) + * o ib_sge (also donated to provider) + * o status of reply (length, success or not) + * o bookkeeping state to get run by tasklet (list, etc) + * + * These are allocated during initialization, per-transport instance; + * however, the tasklet execution list itself is global, as it should + * always be pretty short. + * + * N of these are associated with a transport instance, and stored in + * struct rpcrdma_buffer. N is the max number of outstanding requests. + */ + +/* temporary static scatter/gather max */ +#define RPCRDMA_MAX_DATA_SEGS (8) /* max scatter/gather */ +#define RPCRDMA_MAX_SEGS (RPCRDMA_MAX_DATA_SEGS + 2) /* head+tail = 2 */ +#define MAX_RPCRDMAHDR (\ + /* max supported RPC/RDMA header */ \ + sizeof(struct rpcrdma_msg) + (2 * sizeof(u32)) + \ + (sizeof(struct rpcrdma_read_chunk) * RPCRDMA_MAX_SEGS) + sizeof(u32)) + +struct rpcrdma_buffer; + +struct rpcrdma_rep { + unsigned int rr_len; /* actual received reply length */ + struct rpcrdma_buffer *rr_buffer; /* home base for this structure */ + struct rpc_xprt *rr_xprt; /* needed for request/reply matching */ + void (*rr_func)(struct rpcrdma_rep *);/* called by tasklet in softint */ + struct list_head rr_list; /* tasklet list */ + wait_queue_head_t rr_unbind; /* optional unbind wait */ + struct ib_sge rr_iov; /* for posting */ + struct ib_mr *rr_handle; /* handle for mem in rr_iov */ + char rr_base[MAX_RPCRDMAHDR]; /* minimal inline receive buffer */ +}; + +/* + * struct rpcrdma_req -- structure central to the request/reply sequence. + * + * N of these are associated with a transport instance, and stored in + * struct rpcrdma_buffer. N is the max number of outstanding requests. + * + * It includes pre-registered buffer memory for send AND recv. + * The recv buffer, however, is not owned by this structure, and + * is "donated" to the hardware when a recv is posted. When a + * reply is handled, the recv buffer used is given back to the + * struct rpcrdma_req associated with the request. + * + * In addition to the basic memory, this structure includes an array + * of iovs for send operations. The reason is that the iovs passed to + * ib_post_{send,recv} must not be modified until the work request + * completes. + * + * NOTES: + * o RPCRDMA_MAX_SEGS is the max number of addressible chunk elements we + * marshal. The number needed varies depending on the iov lists that + * are passed to us, the memory registration mode we are in, and if + * physical addressing is used, the layout. + */ + +struct rpcrdma_mr_seg { /* chunk descriptors */ + union { /* chunk memory handles */ + struct ib_mr *rl_mr; /* if registered directly */ + struct rpcrdma_mw { /* if registered from region */ + union { + struct ib_mw *mw; + struct ib_fmr *fmr; + struct { + struct ib_fast_reg_page_list *fr_pgl; + struct ib_mr *fr_mr; + } frmr; + } r; + struct list_head mw_list; + } *rl_mw; + } mr_chunk; + u64 mr_base; /* registration result */ + u32 mr_rkey; /* registration result */ + u32 mr_len; /* length of chunk or segment */ + int mr_nsegs; /* number of segments in chunk or 0 */ + enum dma_data_direction mr_dir; /* segment mapping direction */ + dma_addr_t mr_dma; /* segment mapping address */ + size_t mr_dmalen; /* segment mapping length */ + struct page *mr_page; /* owning page, if any */ + char *mr_offset; /* kva if no page, else offset */ +}; + +struct rpcrdma_req { + size_t rl_size; /* actual length of buffer */ + unsigned int rl_niovs; /* 0, 2 or 4 */ + unsigned int rl_nchunks; /* non-zero if chunks */ + unsigned int rl_connect_cookie; /* retry detection */ + struct rpcrdma_buffer *rl_buffer; /* home base for this structure */ + struct rpcrdma_rep *rl_reply;/* holder for reply buffer */ + struct rpcrdma_mr_seg rl_segments[RPCRDMA_MAX_SEGS];/* chunk segments */ + struct ib_sge rl_send_iov[4]; /* for active requests */ + struct ib_sge rl_iov; /* for posting */ + struct ib_mr *rl_handle; /* handle for mem in rl_iov */ + char rl_base[MAX_RPCRDMAHDR]; /* start of actual buffer */ + __u32 rl_xdr_buf[0]; /* start of returned rpc rq_buffer */ +}; +#define rpcr_to_rdmar(r) \ + container_of((r)->rq_buffer, struct rpcrdma_req, rl_xdr_buf[0]) + +/* + * struct rpcrdma_buffer -- holds list/queue of pre-registered memory for + * inline requests/replies, and client/server credits. + * + * One of these is associated with a transport instance + */ +struct rpcrdma_buffer { + spinlock_t rb_lock; /* protects indexes */ + atomic_t rb_credits; /* most recent server credits */ + unsigned long rb_cwndscale; /* cached framework rpc_cwndscale */ + int rb_max_requests;/* client max requests */ + struct list_head rb_mws; /* optional memory windows/fmrs/frmrs */ + int rb_send_index; + struct rpcrdma_req **rb_send_bufs; + int rb_recv_index; + struct rpcrdma_rep **rb_recv_bufs; + char *rb_pool; +}; +#define rdmab_to_ia(b) (&container_of((b), struct rpcrdma_xprt, rx_buf)->rx_ia) + +/* + * Internal structure for transport instance creation. This + * exists primarily for modularity. + * + * This data should be set with mount options + */ +struct rpcrdma_create_data_internal { + struct sockaddr_storage addr; /* RDMA server address */ + unsigned int max_requests; /* max requests (slots) in flight */ + unsigned int rsize; /* mount rsize - max read hdr+data */ + unsigned int wsize; /* mount wsize - max write hdr+data */ + unsigned int inline_rsize; /* max non-rdma read data payload */ + unsigned int inline_wsize; /* max non-rdma write data payload */ + unsigned int padding; /* non-rdma write header padding */ +}; + +#define RPCRDMA_INLINE_READ_THRESHOLD(rq) \ + (rpcx_to_rdmad(rq->rq_task->tk_xprt).inline_rsize) + +#define RPCRDMA_INLINE_WRITE_THRESHOLD(rq)\ + (rpcx_to_rdmad(rq->rq_task->tk_xprt).inline_wsize) + +#define RPCRDMA_INLINE_PAD_VALUE(rq)\ + rpcx_to_rdmad(rq->rq_task->tk_xprt).padding + +/* + * Statistics for RPCRDMA + */ +struct rpcrdma_stats { + unsigned long read_chunk_count; + unsigned long write_chunk_count; + unsigned long reply_chunk_count; + + unsigned long long total_rdma_request; + unsigned long long total_rdma_reply; + + unsigned long long pullup_copy_count; + unsigned long long fixup_copy_count; + unsigned long hardway_register_count; + unsigned long failed_marshal_count; + unsigned long bad_reply_count; +}; + +/* + * RPCRDMA transport -- encapsulates the structures above for + * integration with RPC. + * + * The contained structures are embedded, not pointers, + * for convenience. This structure need not be visible externally. + * + * It is allocated and initialized during mount, and released + * during unmount. + */ +struct rpcrdma_xprt { + struct rpc_xprt xprt; + struct rpcrdma_ia rx_ia; + struct rpcrdma_ep rx_ep; + struct rpcrdma_buffer rx_buf; + struct rpcrdma_create_data_internal rx_data; + struct delayed_work rdma_connect; + struct rpcrdma_stats rx_stats; +}; + +#define rpcx_to_rdmax(x) container_of(x, struct rpcrdma_xprt, xprt) +#define rpcx_to_rdmad(x) (rpcx_to_rdmax(x)->rx_data) + +/* Setting this to 0 ensures interoperability with early servers. + * Setting this to 1 enhances certain unaligned read/write performance. + * Default is 0, see sysctl entry and rpc_rdma.c rpcrdma_convert_iovs() */ +extern int xprt_rdma_pad_optimize; + +/* + * Interface Adapter calls - xprtrdma/verbs.c + */ +int rpcrdma_ia_open(struct rpcrdma_xprt *, struct sockaddr *, int); +void rpcrdma_ia_close(struct rpcrdma_ia *); + +/* + * Endpoint calls - xprtrdma/verbs.c + */ +int rpcrdma_ep_create(struct rpcrdma_ep *, struct rpcrdma_ia *, + struct rpcrdma_create_data_internal *); +int rpcrdma_ep_destroy(struct rpcrdma_ep *, struct rpcrdma_ia *); +int rpcrdma_ep_connect(struct rpcrdma_ep *, struct rpcrdma_ia *); +int rpcrdma_ep_disconnect(struct rpcrdma_ep *, struct rpcrdma_ia *); + +int rpcrdma_ep_post(struct rpcrdma_ia *, struct rpcrdma_ep *, + struct rpcrdma_req *); +int rpcrdma_ep_post_recv(struct rpcrdma_ia *, struct rpcrdma_ep *, + struct rpcrdma_rep *); + +/* + * Buffer calls - xprtrdma/verbs.c + */ +int rpcrdma_buffer_create(struct rpcrdma_buffer *, struct rpcrdma_ep *, + struct rpcrdma_ia *, + struct rpcrdma_create_data_internal *); +void rpcrdma_buffer_destroy(struct rpcrdma_buffer *); + +struct rpcrdma_req *rpcrdma_buffer_get(struct rpcrdma_buffer *); +void rpcrdma_buffer_put(struct rpcrdma_req *); +void rpcrdma_recv_buffer_get(struct rpcrdma_req *); +void rpcrdma_recv_buffer_put(struct rpcrdma_rep *); + +int rpcrdma_register_internal(struct rpcrdma_ia *, void *, int, + struct ib_mr **, struct ib_sge *); +int rpcrdma_deregister_internal(struct rpcrdma_ia *, + struct ib_mr *, struct ib_sge *); + +int rpcrdma_register_external(struct rpcrdma_mr_seg *, + int, int, struct rpcrdma_xprt *); +int rpcrdma_deregister_external(struct rpcrdma_mr_seg *, + struct rpcrdma_xprt *, void *); + +/* + * RPC/RDMA connection management calls - xprtrdma/rpc_rdma.c + */ +void rpcrdma_conn_func(struct rpcrdma_ep *); +void rpcrdma_reply_handler(struct rpcrdma_rep *); + +/* + * RPC/RDMA protocol calls - xprtrdma/rpc_rdma.c + */ +int rpcrdma_marshal_req(struct rpc_rqst *); + +#endif /* _LINUX_SUNRPC_XPRT_RDMA_H */ diff --git a/net/sunrpc/xprtsock.c b/net/sunrpc/xprtsock.c new file mode 100644 index 0000000..0a50361 --- /dev/null +++ b/net/sunrpc/xprtsock.c @@ -0,0 +1,2129 @@ +/* + * linux/net/sunrpc/xprtsock.c + * + * Client-side transport implementation for sockets. + * + * TCP callback races fixes (C) 1998 Red Hat + * TCP send fixes (C) 1998 Red Hat + * TCP NFS related read + write fixes + * (C) 1999 Dave Airlie, University of Limerick, Ireland <airlied@linux.ie> + * + * Rewrite of larges part of the code in order to stabilize TCP stuff. + * Fix behaviour when socket buffer is full. + * (C) 1999 Trond Myklebust <trond.myklebust@fys.uio.no> + * + * IP socket transport implementation, (C) 2005 Chuck Lever <cel@netapp.com> + * + * IPv6 support contributed by Gilles Quillard, Bull Open Source, 2005. + * <gilles.quillard@bull.net> + */ + +#include <linux/types.h> +#include <linux/slab.h> +#include <linux/module.h> +#include <linux/capability.h> +#include <linux/pagemap.h> +#include <linux/errno.h> +#include <linux/socket.h> +#include <linux/in.h> +#include <linux/net.h> +#include <linux/mm.h> +#include <linux/udp.h> +#include <linux/tcp.h> +#include <linux/sunrpc/clnt.h> +#include <linux/sunrpc/sched.h> +#include <linux/sunrpc/xprtsock.h> +#include <linux/file.h> + +#include <net/sock.h> +#include <net/checksum.h> +#include <net/udp.h> +#include <net/tcp.h> + +/* + * xprtsock tunables + */ +unsigned int xprt_udp_slot_table_entries = RPC_DEF_SLOT_TABLE; +unsigned int xprt_tcp_slot_table_entries = RPC_DEF_SLOT_TABLE; + +unsigned int xprt_min_resvport = RPC_DEF_MIN_RESVPORT; +unsigned int xprt_max_resvport = RPC_DEF_MAX_RESVPORT; + +/* + * We can register our own files under /proc/sys/sunrpc by + * calling register_sysctl_table() again. The files in that + * directory become the union of all files registered there. + * + * We simply need to make sure that we don't collide with + * someone else's file names! + */ + +#ifdef RPC_DEBUG + +static unsigned int min_slot_table_size = RPC_MIN_SLOT_TABLE; +static unsigned int max_slot_table_size = RPC_MAX_SLOT_TABLE; +static unsigned int xprt_min_resvport_limit = RPC_MIN_RESVPORT; +static unsigned int xprt_max_resvport_limit = RPC_MAX_RESVPORT; + +static struct ctl_table_header *sunrpc_table_header; + +/* + * FIXME: changing the UDP slot table size should also resize the UDP + * socket buffers for existing UDP transports + */ +static ctl_table xs_tunables_table[] = { + { + .ctl_name = CTL_SLOTTABLE_UDP, + .procname = "udp_slot_table_entries", + .data = &xprt_udp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .ctl_name = CTL_SLOTTABLE_TCP, + .procname = "tcp_slot_table_entries", + .data = &xprt_tcp_slot_table_entries, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &min_slot_table_size, + .extra2 = &max_slot_table_size + }, + { + .ctl_name = CTL_MIN_RESVPORT, + .procname = "min_resvport", + .data = &xprt_min_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .ctl_name = CTL_MAX_RESVPORT, + .procname = "max_resvport", + .data = &xprt_max_resvport, + .maxlen = sizeof(unsigned int), + .mode = 0644, + .proc_handler = &proc_dointvec_minmax, + .strategy = &sysctl_intvec, + .extra1 = &xprt_min_resvport_limit, + .extra2 = &xprt_max_resvport_limit + }, + { + .ctl_name = 0, + }, +}; + +static ctl_table sunrpc_table[] = { + { + .ctl_name = CTL_SUNRPC, + .procname = "sunrpc", + .mode = 0555, + .child = xs_tunables_table + }, + { + .ctl_name = 0, + }, +}; + +#endif + +/* + * Time out for an RPC UDP socket connect. UDP socket connects are + * synchronous, but we set a timeout anyway in case of resource + * exhaustion on the local host. + */ +#define XS_UDP_CONN_TO (5U * HZ) + +/* + * Wait duration for an RPC TCP connection to be established. Solaris + * NFS over TCP uses 60 seconds, for example, which is in line with how + * long a server takes to reboot. + */ +#define XS_TCP_CONN_TO (60U * HZ) + +/* + * Wait duration for a reply from the RPC portmapper. + */ +#define XS_BIND_TO (60U * HZ) + +/* + * Delay if a UDP socket connect error occurs. This is most likely some + * kind of resource problem on the local host. + */ +#define XS_UDP_REEST_TO (2U * HZ) + +/* + * The reestablish timeout allows clients to delay for a bit before attempting + * to reconnect to a server that just dropped our connection. + * + * We implement an exponential backoff when trying to reestablish a TCP + * transport connection with the server. Some servers like to drop a TCP + * connection when they are overworked, so we start with a short timeout and + * increase over time if the server is down or not responding. + */ +#define XS_TCP_INIT_REEST_TO (3U * HZ) +#define XS_TCP_MAX_REEST_TO (5U * 60 * HZ) + +/* + * TCP idle timeout; client drops the transport socket if it is idle + * for this long. Note that we also timeout UDP sockets to prevent + * holding port numbers when there is no RPC traffic. + */ +#define XS_IDLE_DISC_TO (5U * 60 * HZ) + +#ifdef RPC_DEBUG +# undef RPC_DEBUG_DATA +# define RPCDBG_FACILITY RPCDBG_TRANS +#endif + +#ifdef RPC_DEBUG_DATA +static void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + u8 *buf = (u8 *) packet; + int j; + + dprintk("RPC: %s\n", msg); + for (j = 0; j < count && j < 128; j += 4) { + if (!(j & 31)) { + if (j) + dprintk("\n"); + dprintk("0x%04x ", j); + } + dprintk("%02x%02x%02x%02x ", + buf[j], buf[j+1], buf[j+2], buf[j+3]); + } + dprintk("\n"); +} +#else +static inline void xs_pktdump(char *msg, u32 *packet, unsigned int count) +{ + /* NOP */ +} +#endif + +struct sock_xprt { + struct rpc_xprt xprt; + + /* + * Network layer + */ + struct socket * sock; + struct sock * inet; + + /* + * State of TCP reply receive + */ + __be32 tcp_fraghdr, + tcp_xid; + + u32 tcp_offset, + tcp_reclen; + + unsigned long tcp_copied, + tcp_flags; + + /* + * Connection of transports + */ + struct delayed_work connect_worker; + struct sockaddr_storage addr; + unsigned short port; + + /* + * UDP socket buffer size parameters + */ + size_t rcvsize, + sndsize; + + /* + * Saved socket callback addresses + */ + void (*old_data_ready)(struct sock *, int); + void (*old_state_change)(struct sock *); + void (*old_write_space)(struct sock *); + void (*old_error_report)(struct sock *); +}; + +/* + * TCP receive state flags + */ +#define TCP_RCV_LAST_FRAG (1UL << 0) +#define TCP_RCV_COPY_FRAGHDR (1UL << 1) +#define TCP_RCV_COPY_XID (1UL << 2) +#define TCP_RCV_COPY_DATA (1UL << 3) + +static inline struct sockaddr *xs_addr(struct rpc_xprt *xprt) +{ + return (struct sockaddr *) &xprt->addr; +} + +static inline struct sockaddr_in *xs_addr_in(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in *) &xprt->addr; +} + +static inline struct sockaddr_in6 *xs_addr_in6(struct rpc_xprt *xprt) +{ + return (struct sockaddr_in6 *) &xprt->addr; +} + +static void xs_format_ipv4_peer_addresses(struct rpc_xprt *xprt, + const char *protocol, + const char *netid) +{ + struct sockaddr_in *addr = xs_addr_in(xprt); + char *buf; + + buf = kzalloc(20, GFP_KERNEL); + if (buf) { + snprintf(buf, 20, NIPQUAD_FMT, + NIPQUAD(addr->sin_addr.s_addr)); + } + xprt->address_strings[RPC_DISPLAY_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) { + snprintf(buf, 8, "%u", + ntohs(addr->sin_port)); + } + xprt->address_strings[RPC_DISPLAY_PORT] = buf; + + xprt->address_strings[RPC_DISPLAY_PROTO] = protocol; + + buf = kzalloc(48, GFP_KERNEL); + if (buf) { + snprintf(buf, 48, "addr="NIPQUAD_FMT" port=%u proto=%s", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port), + protocol); + } + xprt->address_strings[RPC_DISPLAY_ALL] = buf; + + buf = kzalloc(10, GFP_KERNEL); + if (buf) { + snprintf(buf, 10, "%02x%02x%02x%02x", + NIPQUAD(addr->sin_addr.s_addr)); + } + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) { + snprintf(buf, 8, "%4hx", + ntohs(addr->sin_port)); + } + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = buf; + + buf = kzalloc(30, GFP_KERNEL); + if (buf) { + snprintf(buf, 30, NIPQUAD_FMT".%u.%u", + NIPQUAD(addr->sin_addr.s_addr), + ntohs(addr->sin_port) >> 8, + ntohs(addr->sin_port) & 0xff); + } + xprt->address_strings[RPC_DISPLAY_UNIVERSAL_ADDR] = buf; + + xprt->address_strings[RPC_DISPLAY_NETID] = netid; +} + +static void xs_format_ipv6_peer_addresses(struct rpc_xprt *xprt, + const char *protocol, + const char *netid) +{ + struct sockaddr_in6 *addr = xs_addr_in6(xprt); + char *buf; + + buf = kzalloc(40, GFP_KERNEL); + if (buf) { + snprintf(buf, 40, NIP6_FMT, + NIP6(addr->sin6_addr)); + } + xprt->address_strings[RPC_DISPLAY_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) { + snprintf(buf, 8, "%u", + ntohs(addr->sin6_port)); + } + xprt->address_strings[RPC_DISPLAY_PORT] = buf; + + xprt->address_strings[RPC_DISPLAY_PROTO] = protocol; + + buf = kzalloc(64, GFP_KERNEL); + if (buf) { + snprintf(buf, 64, "addr="NIP6_FMT" port=%u proto=%s", + NIP6(addr->sin6_addr), + ntohs(addr->sin6_port), + protocol); + } + xprt->address_strings[RPC_DISPLAY_ALL] = buf; + + buf = kzalloc(36, GFP_KERNEL); + if (buf) { + snprintf(buf, 36, NIP6_SEQFMT, + NIP6(addr->sin6_addr)); + } + xprt->address_strings[RPC_DISPLAY_HEX_ADDR] = buf; + + buf = kzalloc(8, GFP_KERNEL); + if (buf) { + snprintf(buf, 8, "%4hx", + ntohs(addr->sin6_port)); + } + xprt->address_strings[RPC_DISPLAY_HEX_PORT] = buf; + + buf = kzalloc(50, GFP_KERNEL); + if (buf) { + snprintf(buf, 50, NIP6_FMT".%u.%u", + NIP6(addr->sin6_addr), + ntohs(addr->sin6_port) >> 8, + ntohs(addr->sin6_port) & 0xff); + } + xprt->address_strings[RPC_DISPLAY_UNIVERSAL_ADDR] = buf; + + xprt->address_strings[RPC_DISPLAY_NETID] = netid; +} + +static void xs_free_peer_addresses(struct rpc_xprt *xprt) +{ + unsigned int i; + + for (i = 0; i < RPC_DISPLAY_MAX; i++) + switch (i) { + case RPC_DISPLAY_PROTO: + case RPC_DISPLAY_NETID: + continue; + default: + kfree(xprt->address_strings[i]); + } +} + +#define XS_SENDMSG_FLAGS (MSG_DONTWAIT | MSG_NOSIGNAL) + +static int xs_send_kvec(struct socket *sock, struct sockaddr *addr, int addrlen, struct kvec *vec, unsigned int base, int more) +{ + struct msghdr msg = { + .msg_name = addr, + .msg_namelen = addrlen, + .msg_flags = XS_SENDMSG_FLAGS | (more ? MSG_MORE : 0), + }; + struct kvec iov = { + .iov_base = vec->iov_base + base, + .iov_len = vec->iov_len - base, + }; + + if (iov.iov_len != 0) + return kernel_sendmsg(sock, &msg, &iov, 1, iov.iov_len); + return kernel_sendmsg(sock, &msg, NULL, 0, 0); +} + +static int xs_send_pagedata(struct socket *sock, struct xdr_buf *xdr, unsigned int base, int more) +{ + struct page **ppage; + unsigned int remainder; + int err, sent = 0; + + remainder = xdr->page_len - base; + base += xdr->page_base; + ppage = xdr->pages + (base >> PAGE_SHIFT); + base &= ~PAGE_MASK; + for(;;) { + unsigned int len = min_t(unsigned int, PAGE_SIZE - base, remainder); + int flags = XS_SENDMSG_FLAGS; + + remainder -= len; + if (remainder != 0 || more) + flags |= MSG_MORE; + err = sock->ops->sendpage(sock, *ppage, base, len, flags); + if (remainder == 0 || err != len) + break; + sent += err; + ppage++; + base = 0; + } + if (sent == 0) + return err; + if (err > 0) + sent += err; + return sent; +} + +/** + * xs_sendpages - write pages directly to a socket + * @sock: socket to send on + * @addr: UDP only -- address of destination + * @addrlen: UDP only -- length of destination address + * @xdr: buffer containing this request + * @base: starting position in the buffer + * + */ +static int xs_sendpages(struct socket *sock, struct sockaddr *addr, int addrlen, struct xdr_buf *xdr, unsigned int base) +{ + unsigned int remainder = xdr->len - base; + int err, sent = 0; + + if (unlikely(!sock)) + return -ENOTCONN; + + clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags); + if (base != 0) { + addr = NULL; + addrlen = 0; + } + + if (base < xdr->head[0].iov_len || addr != NULL) { + unsigned int len = xdr->head[0].iov_len - base; + remainder -= len; + err = xs_send_kvec(sock, addr, addrlen, &xdr->head[0], base, remainder != 0); + if (remainder == 0 || err != len) + goto out; + sent += err; + base = 0; + } else + base -= xdr->head[0].iov_len; + + if (base < xdr->page_len) { + unsigned int len = xdr->page_len - base; + remainder -= len; + err = xs_send_pagedata(sock, xdr, base, remainder != 0); + if (remainder == 0 || err != len) + goto out; + sent += err; + base = 0; + } else + base -= xdr->page_len; + + if (base >= xdr->tail[0].iov_len) + return sent; + err = xs_send_kvec(sock, NULL, 0, &xdr->tail[0], base, 0); +out: + if (sent == 0) + return err; + if (err > 0) + sent += err; + return sent; +} + +static void xs_nospace_callback(struct rpc_task *task) +{ + struct sock_xprt *transport = container_of(task->tk_rqstp->rq_xprt, struct sock_xprt, xprt); + + transport->inet->sk_write_pending--; + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); +} + +/** + * xs_nospace - place task on wait queue if transmit was incomplete + * @task: task to put to sleep + * + */ +static void xs_nospace(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + dprintk("RPC: %5u xmit incomplete (%u left of %u)\n", + task->tk_pid, req->rq_slen - req->rq_bytes_sent, + req->rq_slen); + + /* Protect against races with write_space */ + spin_lock_bh(&xprt->transport_lock); + + /* Don't race with disconnect */ + if (xprt_connected(xprt)) { + if (test_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags)) { + /* + * Notify TCP that we're limited by the application + * window size + */ + set_bit(SOCK_NOSPACE, &transport->sock->flags); + transport->inet->sk_write_pending++; + /* ...and wait for more buffer space */ + xprt_wait_for_buffer_space(task, xs_nospace_callback); + } + } else { + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + task->tk_status = -ENOTCONN; + } + + spin_unlock_bh(&xprt->transport_lock); +} + +/** + * xs_udp_send_request - write an RPC request to a UDP socket + * @task: address of RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occured, the request was not sent + */ +static int xs_udp_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + int status; + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + + status = xs_sendpages(transport->sock, + xs_addr(xprt), + xprt->addrlen, xdr, + req->rq_bytes_sent); + + dprintk("RPC: xs_udp_send_request(%u) = %d\n", + xdr->len - req->rq_bytes_sent, status); + + if (status >= 0) { + task->tk_bytes_sent += status; + if (status >= req->rq_slen) + return 0; + /* Still some bytes left; set up for a retry later. */ + status = -EAGAIN; + } + + switch (status) { + case -EAGAIN: + xs_nospace(task); + break; + case -ENETUNREACH: + case -EPIPE: + case -ECONNREFUSED: + /* When the server has died, an ICMP port unreachable message + * prompts ECONNREFUSED. */ + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + break; + default: + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + } + + return status; +} + +/** + * xs_tcp_shutdown - gracefully shut down a TCP socket + * @xprt: transport + * + * Initiates a graceful shutdown of the TCP socket by calling the + * equivalent of shutdown(SHUT_WR); + */ +static void xs_tcp_shutdown(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + + if (sock != NULL) + kernel_sock_shutdown(sock, SHUT_WR); +} + +static inline void xs_encode_tcp_record_marker(struct xdr_buf *buf) +{ + u32 reclen = buf->len - sizeof(rpc_fraghdr); + rpc_fraghdr *base = buf->head[0].iov_base; + *base = htonl(RPC_LAST_STREAM_FRAGMENT | reclen); +} + +/** + * xs_tcp_send_request - write an RPC request to a TCP socket + * @task: address of RPC task that manages the state of an RPC request + * + * Return values: + * 0: The request has been sent + * EAGAIN: The socket was blocked, please call again later to + * complete the request + * ENOTCONN: Caller needs to invoke connect logic then call again + * other: Some other error occured, the request was not sent + * + * XXX: In the case of soft timeouts, should we eventually give up + * if sendmsg is not able to make progress? + */ +static int xs_tcp_send_request(struct rpc_task *task) +{ + struct rpc_rqst *req = task->tk_rqstp; + struct rpc_xprt *xprt = req->rq_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_buf *xdr = &req->rq_snd_buf; + int status; + + xs_encode_tcp_record_marker(&req->rq_snd_buf); + + xs_pktdump("packet data:", + req->rq_svec->iov_base, + req->rq_svec->iov_len); + + /* Continue transmitting the packet/record. We must be careful + * to cope with writespace callbacks arriving _after_ we have + * called sendmsg(). */ + while (1) { + status = xs_sendpages(transport->sock, + NULL, 0, xdr, req->rq_bytes_sent); + + dprintk("RPC: xs_tcp_send_request(%u) = %d\n", + xdr->len - req->rq_bytes_sent, status); + + if (unlikely(status < 0)) + break; + + /* If we've sent the entire packet, immediately + * reset the count of bytes sent. */ + req->rq_bytes_sent += status; + task->tk_bytes_sent += status; + if (likely(req->rq_bytes_sent >= req->rq_slen)) { + req->rq_bytes_sent = 0; + return 0; + } + + if (status != 0) + continue; + status = -EAGAIN; + break; + } + + switch (status) { + case -EAGAIN: + xs_nospace(task); + break; + case -ECONNRESET: + xs_tcp_shutdown(xprt); + case -ECONNREFUSED: + case -ENOTCONN: + case -EPIPE: + status = -ENOTCONN; + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + break; + default: + dprintk("RPC: sendmsg returned unrecognized error %d\n", + -status); + clear_bit(SOCK_ASYNC_NOSPACE, &transport->sock->flags); + xs_tcp_shutdown(xprt); + } + + return status; +} + +/** + * xs_tcp_release_xprt - clean up after a tcp transmission + * @xprt: transport + * @task: rpc task + * + * This cleans up if an error causes us to abort the transmission of a request. + * In this case, the socket may need to be reset in order to avoid confusing + * the server. + */ +static void xs_tcp_release_xprt(struct rpc_xprt *xprt, struct rpc_task *task) +{ + struct rpc_rqst *req; + + if (task != xprt->snd_task) + return; + if (task == NULL) + goto out_release; + req = task->tk_rqstp; + if (req->rq_bytes_sent == 0) + goto out_release; + if (req->rq_bytes_sent == req->rq_snd_buf.len) + goto out_release; + set_bit(XPRT_CLOSE_WAIT, &task->tk_xprt->state); +out_release: + xprt_release_xprt(xprt, task); +} + +static void xs_save_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + transport->old_data_ready = sk->sk_data_ready; + transport->old_state_change = sk->sk_state_change; + transport->old_write_space = sk->sk_write_space; + transport->old_error_report = sk->sk_error_report; +} + +static void xs_restore_old_callbacks(struct sock_xprt *transport, struct sock *sk) +{ + sk->sk_data_ready = transport->old_data_ready; + sk->sk_state_change = transport->old_state_change; + sk->sk_write_space = transport->old_write_space; + sk->sk_error_report = transport->old_error_report; +} + +/** + * xs_close - close a socket + * @xprt: transport + * + * This is used when all requests are complete; ie, no DRC state remains + * on the server we want to save. + */ +static void xs_close(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct socket *sock = transport->sock; + struct sock *sk = transport->inet; + + if (!sk) + goto clear_close_wait; + + dprintk("RPC: xs_close xprt %p\n", xprt); + + write_lock_bh(&sk->sk_callback_lock); + transport->inet = NULL; + transport->sock = NULL; + + sk->sk_user_data = NULL; + + xs_restore_old_callbacks(transport, sk); + write_unlock_bh(&sk->sk_callback_lock); + + sk->sk_no_check = 0; + + sock_release(sock); +clear_close_wait: + smp_mb__before_clear_bit(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + clear_bit(XPRT_CLOSING, &xprt->state); + smp_mb__after_clear_bit(); + xprt_disconnect_done(xprt); +} + +/** + * xs_destroy - prepare to shutdown a transport + * @xprt: doomed transport + * + */ +static void xs_destroy(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + dprintk("RPC: xs_destroy xprt %p\n", xprt); + + cancel_rearming_delayed_work(&transport->connect_worker); + + xs_close(xprt); + xs_free_peer_addresses(xprt); + kfree(xprt->slot); + kfree(xprt); + module_put(THIS_MODULE); +} + +static inline struct rpc_xprt *xprt_from_sock(struct sock *sk) +{ + return (struct rpc_xprt *) sk->sk_user_data; +} + +/** + * xs_udp_data_ready - "data ready" callback for UDP sockets + * @sk: socket with data to read + * @len: how much data to read + * + */ +static void xs_udp_data_ready(struct sock *sk, int len) +{ + struct rpc_task *task; + struct rpc_xprt *xprt; + struct rpc_rqst *rovr; + struct sk_buff *skb; + int err, repsize, copied; + u32 _xid; + __be32 *xp; + + read_lock(&sk->sk_callback_lock); + dprintk("RPC: xs_udp_data_ready...\n"); + if (!(xprt = xprt_from_sock(sk))) + goto out; + + if ((skb = skb_recv_datagram(sk, 0, 1, &err)) == NULL) + goto out; + + if (xprt->shutdown) + goto dropit; + + repsize = skb->len - sizeof(struct udphdr); + if (repsize < 4) { + dprintk("RPC: impossible RPC reply size %d!\n", repsize); + goto dropit; + } + + /* Copy the XID from the skb... */ + xp = skb_header_pointer(skb, sizeof(struct udphdr), + sizeof(_xid), &_xid); + if (xp == NULL) + goto dropit; + + /* Look up and lock the request corresponding to the given XID */ + spin_lock(&xprt->transport_lock); + rovr = xprt_lookup_rqst(xprt, *xp); + if (!rovr) + goto out_unlock; + task = rovr->rq_task; + + if ((copied = rovr->rq_private_buf.buflen) > repsize) + copied = repsize; + + /* Suck it into the iovec, verify checksum if not done by hw. */ + if (csum_partial_copy_to_xdr(&rovr->rq_private_buf, skb)) { + UDPX_INC_STATS_BH(sk, UDP_MIB_INERRORS); + goto out_unlock; + } + + UDPX_INC_STATS_BH(sk, UDP_MIB_INDATAGRAMS); + + /* Something worked... */ + dst_confirm(skb->dst); + + xprt_adjust_cwnd(task, copied); + xprt_update_rtt(task); + xprt_complete_rqst(task, copied); + + out_unlock: + spin_unlock(&xprt->transport_lock); + dropit: + skb_free_datagram(sk, skb); + out: + read_unlock(&sk->sk_callback_lock); +} + +static inline void xs_tcp_read_fraghdr(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + size_t len, used; + char *p; + + p = ((char *) &transport->tcp_fraghdr) + transport->tcp_offset; + len = sizeof(transport->tcp_fraghdr) - transport->tcp_offset; + used = xdr_skb_read_bits(desc, p, len); + transport->tcp_offset += used; + if (used != len) + return; + + transport->tcp_reclen = ntohl(transport->tcp_fraghdr); + if (transport->tcp_reclen & RPC_LAST_STREAM_FRAGMENT) + transport->tcp_flags |= TCP_RCV_LAST_FRAG; + else + transport->tcp_flags &= ~TCP_RCV_LAST_FRAG; + transport->tcp_reclen &= RPC_FRAGMENT_SIZE_MASK; + + transport->tcp_flags &= ~TCP_RCV_COPY_FRAGHDR; + transport->tcp_offset = 0; + + /* Sanity check of the record length */ + if (unlikely(transport->tcp_reclen < 4)) { + dprintk("RPC: invalid TCP record fragment length\n"); + xprt_force_disconnect(xprt); + return; + } + dprintk("RPC: reading TCP record fragment of length %d\n", + transport->tcp_reclen); +} + +static void xs_tcp_check_fraghdr(struct sock_xprt *transport) +{ + if (transport->tcp_offset == transport->tcp_reclen) { + transport->tcp_flags |= TCP_RCV_COPY_FRAGHDR; + transport->tcp_offset = 0; + if (transport->tcp_flags & TCP_RCV_LAST_FRAG) { + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + transport->tcp_flags |= TCP_RCV_COPY_XID; + transport->tcp_copied = 0; + } + } +} + +static inline void xs_tcp_read_xid(struct sock_xprt *transport, struct xdr_skb_reader *desc) +{ + size_t len, used; + char *p; + + len = sizeof(transport->tcp_xid) - transport->tcp_offset; + dprintk("RPC: reading XID (%Zu bytes)\n", len); + p = ((char *) &transport->tcp_xid) + transport->tcp_offset; + used = xdr_skb_read_bits(desc, p, len); + transport->tcp_offset += used; + if (used != len) + return; + transport->tcp_flags &= ~TCP_RCV_COPY_XID; + transport->tcp_flags |= TCP_RCV_COPY_DATA; + transport->tcp_copied = 4; + dprintk("RPC: reading reply for XID %08x\n", + ntohl(transport->tcp_xid)); + xs_tcp_check_fraghdr(transport); +} + +static inline void xs_tcp_read_request(struct rpc_xprt *xprt, struct xdr_skb_reader *desc) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct rpc_rqst *req; + struct xdr_buf *rcvbuf; + size_t len; + ssize_t r; + + /* Find and lock the request corresponding to this xid */ + spin_lock(&xprt->transport_lock); + req = xprt_lookup_rqst(xprt, transport->tcp_xid); + if (!req) { + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + dprintk("RPC: XID %08x request not found!\n", + ntohl(transport->tcp_xid)); + spin_unlock(&xprt->transport_lock); + return; + } + + rcvbuf = &req->rq_private_buf; + len = desc->count; + if (len > transport->tcp_reclen - transport->tcp_offset) { + struct xdr_skb_reader my_desc; + + len = transport->tcp_reclen - transport->tcp_offset; + memcpy(&my_desc, desc, sizeof(my_desc)); + my_desc.count = len; + r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, + &my_desc, xdr_skb_read_bits); + desc->count -= r; + desc->offset += r; + } else + r = xdr_partial_copy_from_skb(rcvbuf, transport->tcp_copied, + desc, xdr_skb_read_bits); + + if (r > 0) { + transport->tcp_copied += r; + transport->tcp_offset += r; + } + if (r != len) { + /* Error when copying to the receive buffer, + * usually because we weren't able to allocate + * additional buffer pages. All we can do now + * is turn off TCP_RCV_COPY_DATA, so the request + * will not receive any additional updates, + * and time out. + * Any remaining data from this record will + * be discarded. + */ + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + dprintk("RPC: XID %08x truncated request\n", + ntohl(transport->tcp_xid)); + dprintk("RPC: xprt = %p, tcp_copied = %lu, " + "tcp_offset = %u, tcp_reclen = %u\n", + xprt, transport->tcp_copied, + transport->tcp_offset, transport->tcp_reclen); + goto out; + } + + dprintk("RPC: XID %08x read %Zd bytes\n", + ntohl(transport->tcp_xid), r); + dprintk("RPC: xprt = %p, tcp_copied = %lu, tcp_offset = %u, " + "tcp_reclen = %u\n", xprt, transport->tcp_copied, + transport->tcp_offset, transport->tcp_reclen); + + if (transport->tcp_copied == req->rq_private_buf.buflen) + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + else if (transport->tcp_offset == transport->tcp_reclen) { + if (transport->tcp_flags & TCP_RCV_LAST_FRAG) + transport->tcp_flags &= ~TCP_RCV_COPY_DATA; + } + +out: + if (!(transport->tcp_flags & TCP_RCV_COPY_DATA)) + xprt_complete_rqst(req->rq_task, transport->tcp_copied); + spin_unlock(&xprt->transport_lock); + xs_tcp_check_fraghdr(transport); +} + +static inline void xs_tcp_read_discard(struct sock_xprt *transport, struct xdr_skb_reader *desc) +{ + size_t len; + + len = transport->tcp_reclen - transport->tcp_offset; + if (len > desc->count) + len = desc->count; + desc->count -= len; + desc->offset += len; + transport->tcp_offset += len; + dprintk("RPC: discarded %Zu bytes\n", len); + xs_tcp_check_fraghdr(transport); +} + +static int xs_tcp_data_recv(read_descriptor_t *rd_desc, struct sk_buff *skb, unsigned int offset, size_t len) +{ + struct rpc_xprt *xprt = rd_desc->arg.data; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct xdr_skb_reader desc = { + .skb = skb, + .offset = offset, + .count = len, + }; + + dprintk("RPC: xs_tcp_data_recv started\n"); + do { + /* Read in a new fragment marker if necessary */ + /* Can we ever really expect to get completely empty fragments? */ + if (transport->tcp_flags & TCP_RCV_COPY_FRAGHDR) { + xs_tcp_read_fraghdr(xprt, &desc); + continue; + } + /* Read in the xid if necessary */ + if (transport->tcp_flags & TCP_RCV_COPY_XID) { + xs_tcp_read_xid(transport, &desc); + continue; + } + /* Read in the request data */ + if (transport->tcp_flags & TCP_RCV_COPY_DATA) { + xs_tcp_read_request(xprt, &desc); + continue; + } + /* Skip over any trailing bytes on short reads */ + xs_tcp_read_discard(transport, &desc); + } while (desc.count); + dprintk("RPC: xs_tcp_data_recv done\n"); + return len - desc.count; +} + +/** + * xs_tcp_data_ready - "data ready" callback for TCP sockets + * @sk: socket with data to read + * @bytes: how much data to read + * + */ +static void xs_tcp_data_ready(struct sock *sk, int bytes) +{ + struct rpc_xprt *xprt; + read_descriptor_t rd_desc; + int read; + + dprintk("RPC: xs_tcp_data_ready...\n"); + + read_lock(&sk->sk_callback_lock); + if (!(xprt = xprt_from_sock(sk))) + goto out; + if (xprt->shutdown) + goto out; + + /* We use rd_desc to pass struct xprt to xs_tcp_data_recv */ + rd_desc.arg.data = xprt; + do { + rd_desc.count = 65536; + read = tcp_read_sock(sk, &rd_desc, xs_tcp_data_recv); + } while (read > 0); +out: + read_unlock(&sk->sk_callback_lock); +} + +/** + * xs_tcp_state_change - callback to handle TCP socket state changes + * @sk: socket whose state has changed + * + */ +static void xs_tcp_state_change(struct sock *sk) +{ + struct rpc_xprt *xprt; + + read_lock(&sk->sk_callback_lock); + if (!(xprt = xprt_from_sock(sk))) + goto out; + dprintk("RPC: xs_tcp_state_change client %p...\n", xprt); + dprintk("RPC: state %x conn %d dead %d zapped %d\n", + sk->sk_state, xprt_connected(xprt), + sock_flag(sk, SOCK_DEAD), + sock_flag(sk, SOCK_ZAPPED)); + + switch (sk->sk_state) { + case TCP_ESTABLISHED: + spin_lock_bh(&xprt->transport_lock); + if (!xprt_test_and_set_connected(xprt)) { + struct sock_xprt *transport = container_of(xprt, + struct sock_xprt, xprt); + + /* Reset TCP record info */ + transport->tcp_offset = 0; + transport->tcp_reclen = 0; + transport->tcp_copied = 0; + transport->tcp_flags = + TCP_RCV_COPY_FRAGHDR | TCP_RCV_COPY_XID; + + xprt_wake_pending_tasks(xprt, 0); + } + spin_unlock_bh(&xprt->transport_lock); + break; + case TCP_FIN_WAIT1: + /* The client initiated a shutdown of the socket */ + xprt->connect_cookie++; + xprt->reestablish_timeout = 0; + set_bit(XPRT_CLOSING, &xprt->state); + smp_mb__before_clear_bit(); + clear_bit(XPRT_CONNECTED, &xprt->state); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + smp_mb__after_clear_bit(); + break; + case TCP_CLOSE_WAIT: + /* The server initiated a shutdown of the socket */ + set_bit(XPRT_CLOSING, &xprt->state); + xprt_force_disconnect(xprt); + case TCP_SYN_SENT: + xprt->connect_cookie++; + case TCP_CLOSING: + /* + * If the server closed down the connection, make sure that + * we back off before reconnecting + */ + if (xprt->reestablish_timeout < XS_TCP_INIT_REEST_TO) + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + break; + case TCP_LAST_ACK: + smp_mb__before_clear_bit(); + clear_bit(XPRT_CONNECTED, &xprt->state); + smp_mb__after_clear_bit(); + break; + case TCP_CLOSE: + smp_mb__before_clear_bit(); + clear_bit(XPRT_CLOSE_WAIT, &xprt->state); + clear_bit(XPRT_CLOSING, &xprt->state); + smp_mb__after_clear_bit(); + /* Mark transport as closed and wake up all pending tasks */ + xprt_disconnect_done(xprt); + } + out: + read_unlock(&sk->sk_callback_lock); +} + +/** + * xs_tcp_error_report - callback mainly for catching RST events + * @sk: socket + */ +static void xs_tcp_error_report(struct sock *sk) +{ + struct rpc_xprt *xprt; + + read_lock(&sk->sk_callback_lock); + if (sk->sk_err != ECONNRESET || sk->sk_state != TCP_ESTABLISHED) + goto out; + if (!(xprt = xprt_from_sock(sk))) + goto out; + dprintk("RPC: %s client %p...\n" + "RPC: error %d\n", + __func__, xprt, sk->sk_err); + + xprt_force_disconnect(xprt); +out: + read_unlock(&sk->sk_callback_lock); +} + +/** + * xs_udp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. + */ +static void xs_udp_write_space(struct sock *sk) +{ + read_lock(&sk->sk_callback_lock); + + /* from net/core/sock.c:sock_def_write_space */ + if (sock_writeable(sk)) { + struct socket *sock; + struct rpc_xprt *xprt; + + if (unlikely(!(sock = sk->sk_socket))) + goto out; + clear_bit(SOCK_NOSPACE, &sock->flags); + + if (unlikely(!(xprt = xprt_from_sock(sk)))) + goto out; + if (test_and_clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags) == 0) + goto out; + + xprt_write_space(xprt); + } + + out: + read_unlock(&sk->sk_callback_lock); +} + +/** + * xs_tcp_write_space - callback invoked when socket buffer space + * becomes available + * @sk: socket whose state has changed + * + * Called when more output buffer space is available for this socket. + * We try not to wake our writers until they can make "significant" + * progress, otherwise we'll waste resources thrashing kernel_sendmsg + * with a bunch of small requests. + */ +static void xs_tcp_write_space(struct sock *sk) +{ + read_lock(&sk->sk_callback_lock); + + /* from net/core/stream.c:sk_stream_write_space */ + if (sk_stream_wspace(sk) >= sk_stream_min_wspace(sk)) { + struct socket *sock; + struct rpc_xprt *xprt; + + if (unlikely(!(sock = sk->sk_socket))) + goto out; + clear_bit(SOCK_NOSPACE, &sock->flags); + + if (unlikely(!(xprt = xprt_from_sock(sk)))) + goto out; + if (test_and_clear_bit(SOCK_ASYNC_NOSPACE, &sock->flags) == 0) + goto out; + + xprt_write_space(xprt); + } + + out: + read_unlock(&sk->sk_callback_lock); +} + +static void xs_udp_do_set_buffer_size(struct rpc_xprt *xprt) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct sock *sk = transport->inet; + + if (transport->rcvsize) { + sk->sk_userlocks |= SOCK_RCVBUF_LOCK; + sk->sk_rcvbuf = transport->rcvsize * xprt->max_reqs * 2; + } + if (transport->sndsize) { + sk->sk_userlocks |= SOCK_SNDBUF_LOCK; + sk->sk_sndbuf = transport->sndsize * xprt->max_reqs * 2; + sk->sk_write_space(sk); + } +} + +/** + * xs_udp_set_buffer_size - set send and receive limits + * @xprt: generic transport + * @sndsize: requested size of send buffer, in bytes + * @rcvsize: requested size of receive buffer, in bytes + * + * Set socket send and receive buffer size limits. + */ +static void xs_udp_set_buffer_size(struct rpc_xprt *xprt, size_t sndsize, size_t rcvsize) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + transport->sndsize = 0; + if (sndsize) + transport->sndsize = sndsize + 1024; + transport->rcvsize = 0; + if (rcvsize) + transport->rcvsize = rcvsize + 1024; + + xs_udp_do_set_buffer_size(xprt); +} + +/** + * xs_udp_timer - called when a retransmit timeout occurs on a UDP transport + * @task: task that timed out + * + * Adjust the congestion window after a retransmit timeout has occurred. + */ +static void xs_udp_timer(struct rpc_task *task) +{ + xprt_adjust_cwnd(task, -ETIMEDOUT); +} + +static unsigned short xs_get_random_port(void) +{ + unsigned short range = xprt_max_resvport - xprt_min_resvport; + unsigned short rand = (unsigned short) net_random() % range; + return rand + xprt_min_resvport; +} + +/** + * xs_set_port - reset the port number in the remote endpoint address + * @xprt: generic transport + * @port: new port number + * + */ +static void xs_set_port(struct rpc_xprt *xprt, unsigned short port) +{ + struct sockaddr *addr = xs_addr(xprt); + + dprintk("RPC: setting port for xprt %p to %u\n", xprt, port); + + switch (addr->sa_family) { + case AF_INET: + ((struct sockaddr_in *)addr)->sin_port = htons(port); + break; + case AF_INET6: + ((struct sockaddr_in6 *)addr)->sin6_port = htons(port); + break; + default: + BUG(); + } +} + +static unsigned short xs_get_srcport(struct sock_xprt *transport, struct socket *sock) +{ + unsigned short port = transport->port; + + if (port == 0 && transport->xprt.resvport) + port = xs_get_random_port(); + return port; +} + +static unsigned short xs_next_srcport(struct sock_xprt *transport, struct socket *sock, unsigned short port) +{ + if (transport->port != 0) + transport->port = 0; + if (!transport->xprt.resvport) + return 0; + if (port <= xprt_min_resvport || port > xprt_max_resvport) + return xprt_max_resvport; + return --port; +} + +static int xs_bind4(struct sock_xprt *transport, struct socket *sock) +{ + struct sockaddr_in myaddr = { + .sin_family = AF_INET, + }; + struct sockaddr_in *sa; + int err, nloop = 0; + unsigned short port = xs_get_srcport(transport, sock); + unsigned short last; + + sa = (struct sockaddr_in *)&transport->addr; + myaddr.sin_addr = sa->sin_addr; + do { + myaddr.sin_port = htons(port); + err = kernel_bind(sock, (struct sockaddr *) &myaddr, + sizeof(myaddr)); + if (port == 0) + break; + if (err == 0) { + transport->port = port; + break; + } + last = port; + port = xs_next_srcport(transport, sock, port); + if (port > last) + nloop++; + } while (err == -EADDRINUSE && nloop != 2); + dprintk("RPC: %s "NIPQUAD_FMT":%u: %s (%d)\n", + __func__, NIPQUAD(myaddr.sin_addr), + port, err ? "failed" : "ok", err); + return err; +} + +static int xs_bind6(struct sock_xprt *transport, struct socket *sock) +{ + struct sockaddr_in6 myaddr = { + .sin6_family = AF_INET6, + }; + struct sockaddr_in6 *sa; + int err, nloop = 0; + unsigned short port = xs_get_srcport(transport, sock); + unsigned short last; + + sa = (struct sockaddr_in6 *)&transport->addr; + myaddr.sin6_addr = sa->sin6_addr; + do { + myaddr.sin6_port = htons(port); + err = kernel_bind(sock, (struct sockaddr *) &myaddr, + sizeof(myaddr)); + if (port == 0) + break; + if (err == 0) { + transport->port = port; + break; + } + last = port; + port = xs_next_srcport(transport, sock, port); + if (port > last) + nloop++; + } while (err == -EADDRINUSE && nloop != 2); + dprintk("RPC: xs_bind6 "NIP6_FMT":%u: %s (%d)\n", + NIP6(myaddr.sin6_addr), port, err ? "failed" : "ok", err); + return err; +} + +#ifdef CONFIG_DEBUG_LOCK_ALLOC +static struct lock_class_key xs_key[2]; +static struct lock_class_key xs_slock_key[2]; + +static inline void xs_reclassify_socket4(struct socket *sock) +{ + struct sock *sk = sock->sk; + + BUG_ON(sock_owned_by_user(sk)); + sock_lock_init_class_and_name(sk, "slock-AF_INET-RPC", + &xs_slock_key[0], "sk_lock-AF_INET-RPC", &xs_key[0]); +} + +static inline void xs_reclassify_socket6(struct socket *sock) +{ + struct sock *sk = sock->sk; + + BUG_ON(sock_owned_by_user(sk)); + sock_lock_init_class_and_name(sk, "slock-AF_INET6-RPC", + &xs_slock_key[1], "sk_lock-AF_INET6-RPC", &xs_key[1]); +} +#else +static inline void xs_reclassify_socket4(struct socket *sock) +{ +} + +static inline void xs_reclassify_socket6(struct socket *sock) +{ +} +#endif + +static void xs_udp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_udp_data_ready; + sk->sk_write_space = xs_udp_write_space; + sk->sk_no_check = UDP_CSUM_NORCV; + sk->sk_allocation = GFP_ATOMIC; + + xprt_set_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + write_unlock_bh(&sk->sk_callback_lock); + } + xs_udp_do_set_buffer_size(xprt); +} + +/** + * xs_udp_connect_worker4 - set up a UDP socket + * @work: RPC transport to connect + * + * Invoked by a work queue tasklet. + */ +static void xs_udp_connect_worker4(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock = transport->sock; + int err, status = -EIO; + + if (xprt->shutdown || !xprt_bound(xprt)) + goto out; + + /* Start by resetting any existing state */ + xs_close(xprt); + + if ((err = sock_create_kern(PF_INET, SOCK_DGRAM, IPPROTO_UDP, &sock)) < 0) { + dprintk("RPC: can't create UDP transport socket (%d).\n", -err); + goto out; + } + xs_reclassify_socket4(sock); + + if (xs_bind4(transport, sock)) { + sock_release(sock); + goto out; + } + + dprintk("RPC: worker connecting xprt %p to address: %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ALL]); + + xs_udp_finish_connecting(xprt, sock); + status = 0; +out: + xprt_wake_pending_tasks(xprt, status); + xprt_clear_connecting(xprt); +} + +/** + * xs_udp_connect_worker6 - set up a UDP socket + * @work: RPC transport to connect + * + * Invoked by a work queue tasklet. + */ +static void xs_udp_connect_worker6(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock = transport->sock; + int err, status = -EIO; + + if (xprt->shutdown || !xprt_bound(xprt)) + goto out; + + /* Start by resetting any existing state */ + xs_close(xprt); + + if ((err = sock_create_kern(PF_INET6, SOCK_DGRAM, IPPROTO_UDP, &sock)) < 0) { + dprintk("RPC: can't create UDP transport socket (%d).\n", -err); + goto out; + } + xs_reclassify_socket6(sock); + + if (xs_bind6(transport, sock) < 0) { + sock_release(sock); + goto out; + } + + dprintk("RPC: worker connecting xprt %p to address: %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ALL]); + + xs_udp_finish_connecting(xprt, sock); + status = 0; +out: + xprt_wake_pending_tasks(xprt, status); + xprt_clear_connecting(xprt); +} + +/* + * We need to preserve the port number so the reply cache on the server can + * find our cached RPC replies when we get around to reconnecting. + */ +static void xs_tcp_reuse_connection(struct rpc_xprt *xprt) +{ + int result; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + struct sockaddr any; + + dprintk("RPC: disconnecting xprt %p to reuse port\n", xprt); + + /* + * Disconnect the transport socket by doing a connect operation + * with AF_UNSPEC. This should return immediately... + */ + memset(&any, 0, sizeof(any)); + any.sa_family = AF_UNSPEC; + result = kernel_connect(transport->sock, &any, sizeof(any), 0); + if (result) + dprintk("RPC: AF_UNSPEC connect return code %d\n", + result); +} + +static int xs_tcp_finish_connecting(struct rpc_xprt *xprt, struct socket *sock) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (!transport->inet) { + struct sock *sk = sock->sk; + + write_lock_bh(&sk->sk_callback_lock); + + xs_save_old_callbacks(transport, sk); + + sk->sk_user_data = xprt; + sk->sk_data_ready = xs_tcp_data_ready; + sk->sk_state_change = xs_tcp_state_change; + sk->sk_write_space = xs_tcp_write_space; + sk->sk_error_report = xs_tcp_error_report; + sk->sk_allocation = GFP_ATOMIC; + + /* socket options */ + sk->sk_userlocks |= SOCK_BINDPORT_LOCK; + sock_reset_flag(sk, SOCK_LINGER); + tcp_sk(sk)->linger2 = 0; + tcp_sk(sk)->nonagle |= TCP_NAGLE_OFF; + + xprt_clear_connected(xprt); + + /* Reset to new socket */ + transport->sock = sock; + transport->inet = sk; + + write_unlock_bh(&sk->sk_callback_lock); + } + + /* Tell the socket layer to start connecting... */ + xprt->stat.connect_count++; + xprt->stat.connect_start = jiffies; + return kernel_connect(sock, xs_addr(xprt), xprt->addrlen, O_NONBLOCK); +} + +/** + * xs_tcp_connect_worker4 - connect a TCP socket to a remote endpoint + * @work: RPC transport to connect + * + * Invoked by a work queue tasklet. + */ +static void xs_tcp_connect_worker4(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock = transport->sock; + int err, status = -EIO; + + if (xprt->shutdown || !xprt_bound(xprt)) + goto out; + + if (!sock) { + /* start from scratch */ + if ((err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &sock)) < 0) { + dprintk("RPC: can't create TCP transport socket (%d).\n", -err); + goto out; + } + xs_reclassify_socket4(sock); + + if (xs_bind4(transport, sock) < 0) { + sock_release(sock); + goto out; + } + } else + /* "close" the socket, preserving the local port */ + xs_tcp_reuse_connection(xprt); + + dprintk("RPC: worker connecting xprt %p to address: %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ALL]); + + status = xs_tcp_finish_connecting(xprt, sock); + dprintk("RPC: %p connect status %d connected %d sock state %d\n", + xprt, -status, xprt_connected(xprt), + sock->sk->sk_state); + if (status < 0) { + switch (status) { + case -EINPROGRESS: + case -EALREADY: + goto out_clear; + case -ECONNREFUSED: + case -ECONNRESET: + /* retry with existing socket, after a delay */ + break; + default: + /* get rid of existing socket, and retry */ + xs_tcp_shutdown(xprt); + } + } +out: + xprt_wake_pending_tasks(xprt, status); +out_clear: + xprt_clear_connecting(xprt); +} + +/** + * xs_tcp_connect_worker6 - connect a TCP socket to a remote endpoint + * @work: RPC transport to connect + * + * Invoked by a work queue tasklet. + */ +static void xs_tcp_connect_worker6(struct work_struct *work) +{ + struct sock_xprt *transport = + container_of(work, struct sock_xprt, connect_worker.work); + struct rpc_xprt *xprt = &transport->xprt; + struct socket *sock = transport->sock; + int err, status = -EIO; + + if (xprt->shutdown || !xprt_bound(xprt)) + goto out; + + if (!sock) { + /* start from scratch */ + if ((err = sock_create_kern(PF_INET6, SOCK_STREAM, IPPROTO_TCP, &sock)) < 0) { + dprintk("RPC: can't create TCP transport socket (%d).\n", -err); + goto out; + } + xs_reclassify_socket6(sock); + + if (xs_bind6(transport, sock) < 0) { + sock_release(sock); + goto out; + } + } else + /* "close" the socket, preserving the local port */ + xs_tcp_reuse_connection(xprt); + + dprintk("RPC: worker connecting xprt %p to address: %s\n", + xprt, xprt->address_strings[RPC_DISPLAY_ALL]); + + status = xs_tcp_finish_connecting(xprt, sock); + dprintk("RPC: %p connect status %d connected %d sock state %d\n", + xprt, -status, xprt_connected(xprt), sock->sk->sk_state); + if (status < 0) { + switch (status) { + case -EINPROGRESS: + case -EALREADY: + goto out_clear; + case -ECONNREFUSED: + case -ECONNRESET: + /* retry with existing socket, after a delay */ + break; + default: + /* get rid of existing socket, and retry */ + xs_tcp_shutdown(xprt); + } + } +out: + xprt_wake_pending_tasks(xprt, status); +out_clear: + xprt_clear_connecting(xprt); +} + +/** + * xs_connect - connect a socket to a remote endpoint + * @task: address of RPC task that manages state of connect request + * + * TCP: If the remote end dropped the connection, delay reconnecting. + * + * UDP socket connects are synchronous, but we use a work queue anyway + * to guarantee that even unprivileged user processes can set up a + * socket on a privileged port. + * + * If a UDP socket connect fails, the delay behavior here prevents + * retry floods (hard mounts). + */ +static void xs_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + if (xprt_test_and_set_connecting(xprt)) + return; + + if (transport->sock != NULL) { + dprintk("RPC: xs_connect delayed xprt %p for %lu " + "seconds\n", + xprt, xprt->reestablish_timeout / HZ); + queue_delayed_work(rpciod_workqueue, + &transport->connect_worker, + xprt->reestablish_timeout); + xprt->reestablish_timeout <<= 1; + if (xprt->reestablish_timeout > XS_TCP_MAX_REEST_TO) + xprt->reestablish_timeout = XS_TCP_MAX_REEST_TO; + } else { + dprintk("RPC: xs_connect scheduled xprt %p\n", xprt); + queue_delayed_work(rpciod_workqueue, + &transport->connect_worker, 0); + } +} + +static void xs_tcp_connect(struct rpc_task *task) +{ + struct rpc_xprt *xprt = task->tk_xprt; + + /* Initiate graceful shutdown of the socket if not already done */ + if (test_bit(XPRT_CONNECTED, &xprt->state)) + xs_tcp_shutdown(xprt); + /* Exit if we need to wait for socket shutdown to complete */ + if (test_bit(XPRT_CLOSING, &xprt->state)) + return; + xs_connect(task); +} + +/** + * xs_udp_print_stats - display UDP socket-specifc stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_udp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + + seq_printf(seq, "\txprt:\tudp %u %lu %lu %lu %lu %Lu %Lu\n", + transport->port, + xprt->stat.bind_count, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u); +} + +/** + * xs_tcp_print_stats - display TCP socket-specifc stats + * @xprt: rpc_xprt struct containing statistics + * @seq: output file + * + */ +static void xs_tcp_print_stats(struct rpc_xprt *xprt, struct seq_file *seq) +{ + struct sock_xprt *transport = container_of(xprt, struct sock_xprt, xprt); + long idle_time = 0; + + if (xprt_connected(xprt)) + idle_time = (long)(jiffies - xprt->last_used) / HZ; + + seq_printf(seq, "\txprt:\ttcp %u %lu %lu %lu %ld %lu %lu %lu %Lu %Lu\n", + transport->port, + xprt->stat.bind_count, + xprt->stat.connect_count, + xprt->stat.connect_time, + idle_time, + xprt->stat.sends, + xprt->stat.recvs, + xprt->stat.bad_xids, + xprt->stat.req_u, + xprt->stat.bklog_u); +} + +static struct rpc_xprt_ops xs_udp_ops = { + .set_buffer_size = xs_udp_set_buffer_size, + .reserve_xprt = xprt_reserve_xprt_cong, + .release_xprt = xprt_release_xprt_cong, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_udp_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_rtt, + .timer = xs_udp_timer, + .release_request = xprt_release_rqst_cong, + .close = xs_close, + .destroy = xs_destroy, + .print_stats = xs_udp_print_stats, +}; + +static struct rpc_xprt_ops xs_tcp_ops = { + .reserve_xprt = xprt_reserve_xprt, + .release_xprt = xs_tcp_release_xprt, + .rpcbind = rpcb_getport_async, + .set_port = xs_set_port, + .connect = xs_tcp_connect, + .buf_alloc = rpc_malloc, + .buf_free = rpc_free, + .send_request = xs_tcp_send_request, + .set_retrans_timeout = xprt_set_retrans_timeout_def, + .close = xs_tcp_shutdown, + .destroy = xs_destroy, + .print_stats = xs_tcp_print_stats, +}; + +static struct rpc_xprt *xs_setup_xprt(struct xprt_create *args, + unsigned int slot_table_size) +{ + struct rpc_xprt *xprt; + struct sock_xprt *new; + + if (args->addrlen > sizeof(xprt->addr)) { + dprintk("RPC: xs_setup_xprt: address too large\n"); + return ERR_PTR(-EBADF); + } + + new = kzalloc(sizeof(*new), GFP_KERNEL); + if (new == NULL) { + dprintk("RPC: xs_setup_xprt: couldn't allocate " + "rpc_xprt\n"); + return ERR_PTR(-ENOMEM); + } + xprt = &new->xprt; + + xprt->max_reqs = slot_table_size; + xprt->slot = kcalloc(xprt->max_reqs, sizeof(struct rpc_rqst), GFP_KERNEL); + if (xprt->slot == NULL) { + kfree(xprt); + dprintk("RPC: xs_setup_xprt: couldn't allocate slot " + "table\n"); + return ERR_PTR(-ENOMEM); + } + + memcpy(&xprt->addr, args->dstaddr, args->addrlen); + xprt->addrlen = args->addrlen; + if (args->srcaddr) + memcpy(&new->addr, args->srcaddr, args->addrlen); + + return xprt; +} + +static const struct rpc_timeout xs_udp_default_timeout = { + .to_initval = 5 * HZ, + .to_maxval = 30 * HZ, + .to_increment = 5 * HZ, + .to_retries = 5, +}; + +/** + * xs_setup_udp - Set up transport to use a UDP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_udp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + + xprt = xs_setup_xprt(args, xprt_udp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_UDP; + xprt->tsh_size = 0; + /* XXX: header size can vary due to auth type, IPv6, etc. */ + xprt->max_payload = (1U << 16) - (MAX_HEADER << 3); + + xprt->bind_timeout = XS_BIND_TO; + xprt->connect_timeout = XS_UDP_CONN_TO; + xprt->reestablish_timeout = XS_UDP_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_udp_ops; + + xprt->timeout = &xs_udp_default_timeout; + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + INIT_DELAYED_WORK(&transport->connect_worker, + xs_udp_connect_worker4); + xs_format_ipv4_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + INIT_DELAYED_WORK(&transport->connect_worker, + xs_udp_connect_worker6); + xs_format_ipv6_peer_addresses(xprt, "udp", RPCBIND_NETID_UDP6); + break; + default: + kfree(xprt); + return ERR_PTR(-EAFNOSUPPORT); + } + + dprintk("RPC: set up transport to address %s\n", + xprt->address_strings[RPC_DISPLAY_ALL]); + + if (try_module_get(THIS_MODULE)) + return xprt; + + kfree(xprt->slot); + kfree(xprt); + return ERR_PTR(-EINVAL); +} + +static const struct rpc_timeout xs_tcp_default_timeout = { + .to_initval = 60 * HZ, + .to_maxval = 60 * HZ, + .to_retries = 2, +}; + +/** + * xs_setup_tcp - Set up transport to use a TCP socket + * @args: rpc transport creation arguments + * + */ +static struct rpc_xprt *xs_setup_tcp(struct xprt_create *args) +{ + struct sockaddr *addr = args->dstaddr; + struct rpc_xprt *xprt; + struct sock_xprt *transport; + + xprt = xs_setup_xprt(args, xprt_tcp_slot_table_entries); + if (IS_ERR(xprt)) + return xprt; + transport = container_of(xprt, struct sock_xprt, xprt); + + xprt->prot = IPPROTO_TCP; + xprt->tsh_size = sizeof(rpc_fraghdr) / sizeof(u32); + xprt->max_payload = RPC_MAX_FRAGMENT_SIZE; + + xprt->bind_timeout = XS_BIND_TO; + xprt->connect_timeout = XS_TCP_CONN_TO; + xprt->reestablish_timeout = XS_TCP_INIT_REEST_TO; + xprt->idle_timeout = XS_IDLE_DISC_TO; + + xprt->ops = &xs_tcp_ops; + xprt->timeout = &xs_tcp_default_timeout; + + switch (addr->sa_family) { + case AF_INET: + if (((struct sockaddr_in *)addr)->sin_port != htons(0)) + xprt_set_bound(xprt); + + INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_connect_worker4); + xs_format_ipv4_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP); + break; + case AF_INET6: + if (((struct sockaddr_in6 *)addr)->sin6_port != htons(0)) + xprt_set_bound(xprt); + + INIT_DELAYED_WORK(&transport->connect_worker, xs_tcp_connect_worker6); + xs_format_ipv6_peer_addresses(xprt, "tcp", RPCBIND_NETID_TCP6); + break; + default: + kfree(xprt); + return ERR_PTR(-EAFNOSUPPORT); + } + + dprintk("RPC: set up transport to address %s\n", + xprt->address_strings[RPC_DISPLAY_ALL]); + + if (try_module_get(THIS_MODULE)) + return xprt; + + kfree(xprt->slot); + kfree(xprt); + return ERR_PTR(-EINVAL); +} + +static struct xprt_class xs_udp_transport = { + .list = LIST_HEAD_INIT(xs_udp_transport.list), + .name = "udp", + .owner = THIS_MODULE, + .ident = IPPROTO_UDP, + .setup = xs_setup_udp, +}; + +static struct xprt_class xs_tcp_transport = { + .list = LIST_HEAD_INIT(xs_tcp_transport.list), + .name = "tcp", + .owner = THIS_MODULE, + .ident = IPPROTO_TCP, + .setup = xs_setup_tcp, +}; + +/** + * init_socket_xprt - set up xprtsock's sysctls, register with RPC client + * + */ +int init_socket_xprt(void) +{ +#ifdef RPC_DEBUG + if (!sunrpc_table_header) + sunrpc_table_header = register_sysctl_table(sunrpc_table); +#endif + + xprt_register_transport(&xs_udp_transport); + xprt_register_transport(&xs_tcp_transport); + + return 0; +} + +/** + * cleanup_socket_xprt - remove xprtsock's sysctls, unregister + * + */ +void cleanup_socket_xprt(void) +{ +#ifdef RPC_DEBUG + if (sunrpc_table_header) { + unregister_sysctl_table(sunrpc_table_header); + sunrpc_table_header = NULL; + } +#endif + + xprt_unregister_transport(&xs_udp_transport); + xprt_unregister_transport(&xs_tcp_transport); +} |