diff options
Diffstat (limited to 'sys/rpc')
-rw-r--r-- | sys/rpc/auth.h | 45 | ||||
-rw-r--r-- | sys/rpc/auth_none.c | 20 | ||||
-rw-r--r-- | sys/rpc/auth_unix.c | 34 | ||||
-rw-r--r-- | sys/rpc/clnt.h | 70 | ||||
-rw-r--r-- | sys/rpc/clnt_dg.c | 310 | ||||
-rw-r--r-- | sys/rpc/clnt_rc.c | 126 | ||||
-rw-r--r-- | sys/rpc/clnt_vc.c | 239 | ||||
-rw-r--r-- | sys/rpc/replay.c | 248 | ||||
-rw-r--r-- | sys/rpc/replay.h | 85 | ||||
-rw-r--r-- | sys/rpc/rpc_com.h | 1 | ||||
-rw-r--r-- | sys/rpc/rpc_generic.c | 134 | ||||
-rw-r--r-- | sys/rpc/rpc_msg.h | 2 | ||||
-rw-r--r-- | sys/rpc/rpc_prot.c | 81 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss.h | 189 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/rpcsec_gss.c | 1064 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/rpcsec_gss_conf.c | 163 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/rpcsec_gss_int.h | 94 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/rpcsec_gss_misc.c | 53 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/rpcsec_gss_prot.c | 359 | ||||
-rw-r--r-- | sys/rpc/rpcsec_gss/svc_rpcsec_gss.c | 1485 | ||||
-rw-r--r-- | sys/rpc/svc.c | 1048 | ||||
-rw-r--r-- | sys/rpc/svc.h | 244 | ||||
-rw-r--r-- | sys/rpc/svc_auth.c | 75 | ||||
-rw-r--r-- | sys/rpc/svc_auth.h | 24 | ||||
-rw-r--r-- | sys/rpc/svc_auth_unix.c | 3 | ||||
-rw-r--r-- | sys/rpc/svc_dg.c | 157 | ||||
-rw-r--r-- | sys/rpc/svc_generic.c | 93 | ||||
-rw-r--r-- | sys/rpc/svc_vc.c | 247 | ||||
-rw-r--r-- | sys/rpc/xdr.h | 2 |
29 files changed, 5957 insertions, 738 deletions
diff --git a/sys/rpc/auth.h b/sys/rpc/auth.h index b919559..6be08b6 100644 --- a/sys/rpc/auth.h +++ b/sys/rpc/auth.h @@ -132,7 +132,7 @@ enum auth_stat { * failed locally */ AUTH_INVALIDRESP=6, /* bogus response verifier */ - AUTH_FAILED=7 /* some unknown reason */ + AUTH_FAILED=7, /* some unknown reason */ #ifdef KERBEROS /* * kerberos errors @@ -142,8 +142,14 @@ enum auth_stat { AUTH_TIMEEXPIRE = 9, /* time of credential expired */ AUTH_TKT_FILE = 10, /* something wrong with ticket file */ AUTH_DECODE = 11, /* can't decode authenticator */ - AUTH_NET_ADDR = 12 /* wrong net address in ticket */ + AUTH_NET_ADDR = 12, /* wrong net address in ticket */ #endif /* KERBEROS */ + /* + * RPCSEC_GSS errors + */ + RPCSEC_GSS_CREDPROBLEM = 13, + RPCSEC_GSS_CTXPROBLEM = 14, + RPCSEC_GSS_NODISPATCH = 0x8000000 }; union des_block { @@ -171,6 +177,7 @@ struct opaque_auth { /* * Auth handle, interface to client side authenticators. */ +struct rpc_err; typedef struct __auth { struct opaque_auth ah_cred; struct opaque_auth ah_verf; @@ -178,10 +185,11 @@ typedef struct __auth { struct auth_ops { void (*ah_nextverf) (struct __auth *); /* nextverf & serialize */ - int (*ah_marshal) (struct __auth *, XDR *); + int (*ah_marshal) (struct __auth *, uint32_t, XDR *, + struct mbuf *); /* validate verifier */ - int (*ah_validate) (struct __auth *, - struct opaque_auth *); + int (*ah_validate) (struct __auth *, uint32_t, + struct opaque_auth *, struct mbuf **); /* refresh credentials */ int (*ah_refresh) (struct __auth *, void *); /* destroy this structure */ @@ -201,29 +209,18 @@ typedef struct __auth { */ #define AUTH_NEXTVERF(auth) \ ((*((auth)->ah_ops->ah_nextverf))(auth)) -#define auth_nextverf(auth) \ - ((*((auth)->ah_ops->ah_nextverf))(auth)) -#define AUTH_MARSHALL(auth, xdrs) \ - ((*((auth)->ah_ops->ah_marshal))(auth, xdrs)) -#define auth_marshall(auth, xdrs) \ - ((*((auth)->ah_ops->ah_marshal))(auth, xdrs)) +#define AUTH_MARSHALL(auth, xid, xdrs, args) \ + 
((*((auth)->ah_ops->ah_marshal))(auth, xid, xdrs, args)) -#define AUTH_VALIDATE(auth, verfp) \ - ((*((auth)->ah_ops->ah_validate))((auth), verfp)) -#define auth_validate(auth, verfp) \ - ((*((auth)->ah_ops->ah_validate))((auth), verfp)) +#define AUTH_VALIDATE(auth, xid, verfp, resultsp) \ + ((*((auth)->ah_ops->ah_validate))((auth), xid, verfp, resultsp)) #define AUTH_REFRESH(auth, msg) \ ((*((auth)->ah_ops->ah_refresh))(auth, msg)) -#define auth_refresh(auth, msg) \ - ((*((auth)->ah_ops->ah_refresh))(auth, msg)) #define AUTH_DESTROY(auth) \ ((*((auth)->ah_ops->ah_destroy))(auth)) -#define auth_destroy(auth) \ - ((*((auth)->ah_ops->ah_destroy))(auth)) - __BEGIN_DECLS extern struct opaque_auth _null_auth; @@ -357,5 +354,13 @@ __END_DECLS #define AUTH_DH 3 /* for Diffie-Hellman mechanism */ #define AUTH_DES AUTH_DH /* for backward compatibility */ #define AUTH_KERB 4 /* kerberos style */ +#define RPCSEC_GSS 6 /* RPCSEC_GSS */ + +/* + * Pseudo auth flavors for RPCSEC_GSS. + */ +#define RPCSEC_GSS_KRB5 390003 +#define RPCSEC_GSS_KRB5I 390004 +#define RPCSEC_GSS_KRB5P 390005 #endif /* !_RPC_AUTH_H */ diff --git a/sys/rpc/auth_none.c b/sys/rpc/auth_none.c index 8530437..a256b83 100644 --- a/sys/rpc/auth_none.c +++ b/sys/rpc/auth_none.c @@ -54,6 +54,7 @@ __FBSDID("$FreeBSD$"); #include <rpc/types.h> #include <rpc/xdr.h> #include <rpc/auth.h> +#include <rpc/clnt.h> #define MAX_MARSHAL_SIZE 20 @@ -61,9 +62,10 @@ __FBSDID("$FreeBSD$"); * Authenticator operations routines */ -static bool_t authnone_marshal (AUTH *, XDR *); +static bool_t authnone_marshal (AUTH *, uint32_t, XDR *, struct mbuf *); static void authnone_verf (AUTH *); -static bool_t authnone_validate (AUTH *, struct opaque_auth *); +static bool_t authnone_validate (AUTH *, uint32_t, struct opaque_auth *, + struct mbuf **); static bool_t authnone_refresh (AUTH *, void *); static void authnone_destroy (AUTH *); @@ -72,7 +74,7 @@ static struct auth_ops authnone_ops = { .ah_marshal = authnone_marshal, .ah_validate = 
authnone_validate, .ah_refresh = authnone_refresh, - .ah_destroy = authnone_destroy + .ah_destroy = authnone_destroy, }; struct authnone_private { @@ -109,13 +111,18 @@ authnone_create() /*ARGSUSED*/ static bool_t -authnone_marshal(AUTH *client, XDR *xdrs) +authnone_marshal(AUTH *client, uint32_t xid, XDR *xdrs, struct mbuf *args) { struct authnone_private *ap = &authnone_private; KASSERT(xdrs != NULL, ("authnone_marshal: xdrs is null")); - return (xdrs->x_ops->x_putbytes(xdrs, ap->mclient, ap->mcnt)); + if (!XDR_PUTBYTES(xdrs, ap->mclient, ap->mcnt)) + return (FALSE); + + xdrmbuf_append(xdrs, args); + + return (TRUE); } /* All these unused parameters are required to keep ANSI-C from grumbling */ @@ -127,7 +134,8 @@ authnone_verf(AUTH *client) /*ARGSUSED*/ static bool_t -authnone_validate(AUTH *client, struct opaque_auth *opaque) +authnone_validate(AUTH *client, uint32_t xid, struct opaque_auth *opaque, + struct mbuf **mrepp) { return (TRUE); diff --git a/sys/rpc/auth_unix.c b/sys/rpc/auth_unix.c index e30e59e..bd4be34 100644 --- a/sys/rpc/auth_unix.c +++ b/sys/rpc/auth_unix.c @@ -62,13 +62,15 @@ __FBSDID("$FreeBSD$"); #include <rpc/types.h> #include <rpc/xdr.h> #include <rpc/auth.h> +#include <rpc/clnt.h> #include <rpc/rpc_com.h> /* auth_unix.c */ static void authunix_nextverf (AUTH *); -static bool_t authunix_marshal (AUTH *, XDR *); -static bool_t authunix_validate (AUTH *, struct opaque_auth *); +static bool_t authunix_marshal (AUTH *, uint32_t, XDR *, struct mbuf *); +static bool_t authunix_validate (AUTH *, uint32_t, struct opaque_auth *, + struct mbuf **); static bool_t authunix_refresh (AUTH *, void *); static void authunix_destroy (AUTH *); static void marshal_new_auth (AUTH *); @@ -78,7 +80,7 @@ static struct auth_ops authunix_ops = { .ah_marshal = authunix_marshal, .ah_validate = authunix_validate, .ah_refresh = authunix_refresh, - .ah_destroy = authunix_destroy + .ah_destroy = authunix_destroy, }; /* @@ -246,23 +248,32 @@ authunix_nextverf(AUTH *auth) } 
static bool_t -authunix_marshal(AUTH *auth, XDR *xdrs) +authunix_marshal(AUTH *auth, uint32_t xid, XDR *xdrs, struct mbuf *args) { struct audata *au; au = AUTH_PRIVATE(auth); - return (XDR_PUTBYTES(xdrs, au->au_marshed, au->au_mpos)); + if (!XDR_PUTBYTES(xdrs, au->au_marshed, au->au_mpos)) + return (FALSE); + + xdrmbuf_append(xdrs, args); + + return (TRUE); } static bool_t -authunix_validate(AUTH *auth, struct opaque_auth *verf) +authunix_validate(AUTH *auth, uint32_t xid, struct opaque_auth *verf, + struct mbuf **mrepp) { struct audata *au; - XDR xdrs; + XDR txdrs; + + if (!verf) + return (TRUE); if (verf->oa_flavor == AUTH_SHORT) { au = AUTH_PRIVATE(auth); - xdrmem_create(&xdrs, verf->oa_base, verf->oa_length, + xdrmem_create(&txdrs, verf->oa_base, verf->oa_length, XDR_DECODE); if (au->au_shcred.oa_base != NULL) { @@ -270,16 +281,17 @@ authunix_validate(AUTH *auth, struct opaque_auth *verf) au->au_shcred.oa_length); au->au_shcred.oa_base = NULL; } - if (xdr_opaque_auth(&xdrs, &au->au_shcred)) { + if (xdr_opaque_auth(&txdrs, &au->au_shcred)) { auth->ah_cred = au->au_shcred; } else { - xdrs.x_op = XDR_FREE; - (void)xdr_opaque_auth(&xdrs, &au->au_shcred); + txdrs.x_op = XDR_FREE; + (void)xdr_opaque_auth(&txdrs, &au->au_shcred); au->au_shcred.oa_base = NULL; auth->ah_cred = au->au_origcred; } marshal_new_auth(auth); } + return (TRUE); } diff --git a/sys/rpc/clnt.h b/sys/rpc/clnt.h index 03e3112..74d5813 100644 --- a/sys/rpc/clnt.h +++ b/sys/rpc/clnt.h @@ -118,6 +118,15 @@ struct rpc_err { typedef void rpc_feedback(int cmd, int procnum, void *); /* + * Timers used for the pseudo-transport protocol when using datagrams + */ +struct rpc_timers { + u_short rt_srtt; /* smoothed round-trip time */ + u_short rt_deviate; /* estimated deviation */ + u_long rt_rtxcur; /* current (backed-off) rto */ +}; + +/* * A structure used with CLNT_CALL_EXT to pass extra information used * while processing an RPC call. 
*/ @@ -125,6 +134,8 @@ struct rpc_callextra { AUTH *rc_auth; /* auth handle to use for this call */ rpc_feedback *rc_feedback; /* callback for retransmits etc. */ void *rc_feedback_arg; /* argument for callback */ + struct rpc_timers *rc_timers; /* optional RTT timers */ + struct rpc_err rc_err; /* detailed call status */ }; #endif @@ -140,8 +151,8 @@ typedef struct __rpc_client { struct clnt_ops { /* call remote procedure */ enum clnt_stat (*cl_call)(struct __rpc_client *, - struct rpc_callextra *, rpcproc_t, xdrproc_t, void *, - xdrproc_t, void *, struct timeval); + struct rpc_callextra *, rpcproc_t, + struct mbuf *, struct mbuf **, struct timeval); /* abort a call */ void (*cl_abort)(struct __rpc_client *); /* get specific error code */ @@ -150,6 +161,8 @@ typedef struct __rpc_client { /* frees results */ bool_t (*cl_freeres)(struct __rpc_client *, xdrproc_t, void *); + /* close the connection and terminate pending RPCs */ + void (*cl_close)(struct __rpc_client *); /* destroy this structure */ void (*cl_destroy)(struct __rpc_client *); /* the ioctl() of rpc */ @@ -183,15 +196,6 @@ typedef struct __rpc_client { char *cl_tp; /* device name */ } CLIENT; -/* - * Timers used for the pseudo-transport protocol when using datagrams - */ -struct rpc_timers { - u_short rt_srtt; /* smoothed round-trip time */ - u_short rt_deviate; /* estimated deviation */ - u_long rt_rtxcur; /* current (backed-off) rto */ -}; - /* * Feedback values used for possible congestion and rate control */ @@ -222,6 +226,32 @@ struct rpc_timers { CLNT_DESTROY(rh) /* + * void + * CLNT_CLOSE(rh); + * CLIENT *rh; + */ +#define CLNT_CLOSE(rh) ((*(rh)->cl_ops->cl_close)(rh)) + +enum clnt_stat clnt_call_private(CLIENT *, struct rpc_callextra *, rpcproc_t, + xdrproc_t, void *, xdrproc_t, void *, struct timeval); + +/* + * enum clnt_stat + * CLNT_CALL_MBUF(rh, ext, proc, mreq, mrepp, timeout) + * CLIENT *rh; + * struct rpc_callextra *ext; + * rpcproc_t proc; + * struct mbuf *mreq; + * struct mbuf **mrepp; 
+ * struct timeval timeout; + * + * Call arguments in mreq which is consumed by the call (even if there + * is an error). Results returned in *mrepp. + */ +#define CLNT_CALL_MBUF(rh, ext, proc, mreq, mrepp, secs) \ + ((*(rh)->cl_ops->cl_call)(rh, ext, proc, mreq, mrepp, secs)) + +/* * enum clnt_stat * CLNT_CALL_EXT(rh, ext, proc, xargs, argsp, xres, resp, timeout) * CLIENT *rh; @@ -234,8 +264,8 @@ struct rpc_timers { * struct timeval timeout; */ #define CLNT_CALL_EXT(rh, ext, proc, xargs, argsp, xres, resp, secs) \ - ((*(rh)->cl_ops->cl_call)(rh, ext, proc, xargs, \ - argsp, xres, resp, secs)) + clnt_call_private(rh, ext, proc, xargs, \ + argsp, xres, resp, secs) #endif /* @@ -250,12 +280,12 @@ struct rpc_timers { * struct timeval timeout; */ #ifdef _KERNEL -#define CLNT_CALL(rh, proc, xargs, argsp, xres, resp, secs) \ - ((*(rh)->cl_ops->cl_call)(rh, NULL, proc, xargs, \ - argsp, xres, resp, secs)) -#define clnt_call(rh, proc, xargs, argsp, xres, resp, secs) \ - ((*(rh)->cl_ops->cl_call)(rh, NULL, proc, xargs, \ - argsp, xres, resp, secs)) +#define CLNT_CALL(rh, proc, xargs, argsp, xres, resp, secs) \ + clnt_call_private(rh, NULL, proc, xargs, \ + argsp, xres, resp, secs) +#define clnt_call(rh, proc, xargs, argsp, xres, resp, secs) \ + clnt_call_private(rh, NULL, proc, xargs, \ + argsp, xres, resp, secs) #else #define CLNT_CALL(rh, proc, xargs, argsp, xres, resp, secs) \ ((*(rh)->cl_ops->cl_call)(rh, proc, xargs, \ @@ -340,6 +370,8 @@ struct rpc_timers { #define CLGET_INTERRUPTIBLE 24 /* set interruptible flag */ #define CLSET_RETRIES 25 /* set retry count for reconnect */ #define CLGET_RETRIES 26 /* get retry count for reconnect */ +#define CLSET_PRIVPORT 27 /* set privileged source port flag */ +#define CLGET_PRIVPORT 28 /* get privileged source port flag */ #endif diff --git a/sys/rpc/clnt_dg.c b/sys/rpc/clnt_dg.c index f14e1d6..e6d101d 100644 --- a/sys/rpc/clnt_dg.c +++ b/sys/rpc/clnt_dg.c @@ -72,11 +72,12 @@ __FBSDID("$FreeBSD$"); static bool_t 
time_not_ok(struct timeval *); static enum clnt_stat clnt_dg_call(CLIENT *, struct rpc_callextra *, - rpcproc_t, xdrproc_t, void *, xdrproc_t, void *, struct timeval); + rpcproc_t, struct mbuf *, struct mbuf **, struct timeval); static void clnt_dg_geterr(CLIENT *, struct rpc_err *); static bool_t clnt_dg_freeres(CLIENT *, xdrproc_t, void *); static void clnt_dg_abort(CLIENT *); static bool_t clnt_dg_control(CLIENT *, u_int, void *); +static void clnt_dg_close(CLIENT *); static void clnt_dg_destroy(CLIENT *); static void clnt_dg_soupcall(struct socket *so, void *arg, int waitflag); @@ -85,6 +86,7 @@ static struct clnt_ops clnt_dg_ops = { .cl_abort = clnt_dg_abort, .cl_geterr = clnt_dg_geterr, .cl_freeres = clnt_dg_freeres, + .cl_close = clnt_dg_close, .cl_destroy = clnt_dg_destroy, .cl_control = clnt_dg_control }; @@ -102,6 +104,7 @@ struct cu_request { uint32_t cr_xid; /* XID of request */ struct mbuf *cr_mrep; /* reply received by upcall */ int cr_error; /* any error from upcall */ + char cr_verf[MAX_AUTH_BYTES]; /* reply verf */ }; TAILQ_HEAD(cu_request_list, cu_request); @@ -120,7 +123,6 @@ struct cu_socket { struct mtx cs_lock; int cs_refs; /* Count of clients */ struct cu_request_list cs_pending; /* Requests awaiting replies */ - }; /* @@ -128,7 +130,8 @@ struct cu_socket { */ struct cu_data { int cu_threads; /* # threads in clnt_vc_call */ - bool_t cu_closing; /* TRUE if we are destroying */ + bool_t cu_closing; /* TRUE if we are closing */ + bool_t cu_closed; /* TRUE if we are closed */ struct socket *cu_socket; /* connection socket */ bool_t cu_closeit; /* opened by library */ struct sockaddr_storage cu_raddr; /* remote address */ @@ -146,8 +149,14 @@ struct cu_data { int cu_connected; /* Have done connect(). 
*/ const char *cu_waitchan; int cu_waitflag; + int cu_cwnd; /* congestion window */ + int cu_sent; /* number of in-flight RPCs */ + bool_t cu_cwnd_wait; }; +#define CWNDSCALE 256 +#define MAXCWND (32 * CWNDSCALE) + /* * Connection less client creation returns with client handle parameters. * Default options are set, which the user can change using clnt_control(). @@ -211,6 +220,7 @@ clnt_dg_create( cu = mem_alloc(sizeof (*cu)); cu->cu_threads = 0; cu->cu_closing = FALSE; + cu->cu_closed = FALSE; (void) memcpy(&cu->cu_raddr, svcaddr, (size_t)svcaddr->sa_len); cu->cu_rlen = svcaddr->sa_len; /* Other values can also be set through clnt_control() */ @@ -225,6 +235,9 @@ clnt_dg_create( cu->cu_connected = FALSE; cu->cu_waitchan = "rpcrecv"; cu->cu_waitflag = 0; + cu->cu_cwnd = MAXCWND / 2; + cu->cu_sent = 0; + cu->cu_cwnd_wait = FALSE; (void) getmicrotime(&now); cu->cu_xid = __RPC_GETXID(&now); call_msg.rm_xid = cu->cu_xid; @@ -304,15 +317,16 @@ clnt_dg_call( CLIENT *cl, /* client handle */ struct rpc_callextra *ext, /* call metadata */ rpcproc_t proc, /* procedure number */ - xdrproc_t xargs, /* xdr routine for args */ - void *argsp, /* pointer to args */ - xdrproc_t xresults, /* xdr routine for results */ - void *resultsp, /* pointer to results */ + struct mbuf *args, /* pointer to args */ + struct mbuf **resultsp, /* pointer to results */ struct timeval utimeout) /* seconds to wait before giving up */ { struct cu_data *cu = (struct cu_data *)cl->cl_private; struct cu_socket *cs = (struct cu_socket *) cu->cu_socket->so_upcallarg; + struct rpc_timers *rt; AUTH *auth; + struct rpc_err *errp; + enum clnt_stat stat; XDR xdrs; struct rpc_msg reply_msg; bool_t ok; @@ -321,11 +335,11 @@ clnt_dg_call( struct timeval *tvp; int timeout; int retransmit_time; - int next_sendtime, starttime, time_waited, tv; + int next_sendtime, starttime, rtt, time_waited, tv = 0; struct sockaddr *sa; socklen_t salen; - uint32_t xid; - struct mbuf *mreq = NULL; + uint32_t xid = 0; + struct mbuf 
*mreq = NULL, *results; struct cu_request *cr; int error; @@ -333,17 +347,20 @@ clnt_dg_call( mtx_lock(&cs->cs_lock); - if (cu->cu_closing) { + if (cu->cu_closing || cu->cu_closed) { mtx_unlock(&cs->cs_lock); free(cr, M_RPC); return (RPC_CANTSEND); } cu->cu_threads++; - if (ext) + if (ext) { auth = ext->rc_auth; - else + errp = &ext->rc_err; + } else { auth = cl->cl_auth; + errp = &cu->cu_error; + } cr->cr_client = cl; cr->cr_mrep = NULL; @@ -365,8 +382,8 @@ clnt_dg_call( (struct sockaddr *)&cu->cu_raddr, curthread); mtx_lock(&cs->cs_lock); if (error) { - cu->cu_error.re_errno = error; - cu->cu_error.re_status = RPC_CANTSEND; + errp->re_errno = error; + errp->re_status = stat = RPC_CANTSEND; goto out; } cu->cu_connected = 1; @@ -380,7 +397,15 @@ clnt_dg_call( } time_waited = 0; retrans = 0; - retransmit_time = next_sendtime = tvtohz(&cu->cu_wait); + if (ext && ext->rc_timers) { + rt = ext->rc_timers; + if (!rt->rt_rtxcur) + rt->rt_rtxcur = tvtohz(&cu->cu_wait); + retransmit_time = next_sendtime = rt->rt_rtxcur; + } else { + rt = NULL; + retransmit_time = next_sendtime = tvtohz(&cu->cu_wait); + } starttime = ticks; @@ -394,9 +419,9 @@ send_again: mtx_unlock(&cs->cs_lock); MGETHDR(mreq, M_WAIT, MT_DATA); - MCLGET(mreq, M_WAIT); - mreq->m_len = 0; - m_append(mreq, cu->cu_mcalllen, cu->cu_mcallc); + KASSERT(cu->cu_mcalllen <= MHLEN, ("RPC header too big")); + bcopy(cu->cu_mcallc, mreq->m_data, cu->cu_mcalllen); + mreq->m_len = cu->cu_mcalllen; /* * The XID is the first thing in the request. @@ -405,20 +430,36 @@ send_again: xdrmbuf_create(&xdrs, mreq, XDR_ENCODE); - if (cu->cu_async == TRUE && xargs == NULL) + if (cu->cu_async == TRUE && args == NULL) goto get_reply; if ((! XDR_PUTINT32(&xdrs, &proc)) || - (! AUTH_MARSHALL(auth, &xdrs)) || - (! (*xargs)(&xdrs, argsp))) { - cu->cu_error.re_status = RPC_CANTENCODEARGS; + (! 
AUTH_MARSHALL(auth, xid, &xdrs, + m_copym(args, 0, M_COPYALL, M_WAITOK)))) { + errp->re_status = stat = RPC_CANTENCODEARGS; mtx_lock(&cs->cs_lock); goto out; } - m_fixhdr(mreq); + mreq->m_pkthdr.len = m_length(mreq, NULL); cr->cr_xid = xid; mtx_lock(&cs->cs_lock); + + /* + * Try to get a place in the congestion window. + */ + while (cu->cu_sent >= cu->cu_cwnd) { + cu->cu_cwnd_wait = TRUE; + error = msleep(&cu->cu_cwnd_wait, &cs->cs_lock, + cu->cu_waitflag, "rpccwnd", 0); + if (error) { + errp->re_errno = error; + errp->re_status = stat = RPC_CANTSEND; + goto out; + } + } + cu->cu_sent += CWNDSCALE; + TAILQ_INSERT_TAIL(&cs->cs_pending, cr, cr_link); mtx_unlock(&cs->cs_lock); @@ -433,15 +474,22 @@ send_again: * some clock time to spare while the packets are in flight. * (We assume that this is actually only executed once.) */ - reply_msg.acpted_rply.ar_verf = _null_auth; - reply_msg.acpted_rply.ar_results.where = resultsp; - reply_msg.acpted_rply.ar_results.proc = xresults; + reply_msg.acpted_rply.ar_verf.oa_flavor = AUTH_NULL; + reply_msg.acpted_rply.ar_verf.oa_base = cr->cr_verf; + reply_msg.acpted_rply.ar_verf.oa_length = 0; + reply_msg.acpted_rply.ar_results.where = NULL; + reply_msg.acpted_rply.ar_results.proc = (xdrproc_t)xdr_void; mtx_lock(&cs->cs_lock); if (error) { TAILQ_REMOVE(&cs->cs_pending, cr, cr_link); - cu->cu_error.re_errno = error; - cu->cu_error.re_status = RPC_CANTSEND; + errp->re_errno = error; + errp->re_status = stat = RPC_CANTSEND; + cu->cu_sent -= CWNDSCALE; + if (cu->cu_cwnd_wait) { + cu->cu_cwnd_wait = FALSE; + wakeup(&cu->cu_cwnd_wait); + } goto out; } @@ -451,12 +499,22 @@ send_again: */ if (cr->cr_error) { TAILQ_REMOVE(&cs->cs_pending, cr, cr_link); - cu->cu_error.re_errno = cr->cr_error; - cu->cu_error.re_status = RPC_CANTRECV; + errp->re_errno = cr->cr_error; + errp->re_status = stat = RPC_CANTRECV; + cu->cu_sent -= CWNDSCALE; + if (cu->cu_cwnd_wait) { + cu->cu_cwnd_wait = FALSE; + wakeup(&cu->cu_cwnd_wait); + } goto out; } if 
(cr->cr_mrep) { TAILQ_REMOVE(&cs->cs_pending, cr, cr_link); + cu->cu_sent -= CWNDSCALE; + if (cu->cu_cwnd_wait) { + cu->cu_cwnd_wait = FALSE; + wakeup(&cu->cu_cwnd_wait); + } goto got_reply; } @@ -465,7 +523,12 @@ send_again: */ if (timeout == 0) { TAILQ_REMOVE(&cs->cs_pending, cr, cr_link); - cu->cu_error.re_status = RPC_TIMEDOUT; + errp->re_status = stat = RPC_TIMEDOUT; + cu->cu_sent -= CWNDSCALE; + if (cu->cu_cwnd_wait) { + cu->cu_cwnd_wait = FALSE; + wakeup(&cu->cu_cwnd_wait); + } goto out; } @@ -479,7 +542,7 @@ get_reply: tv -= time_waited; if (tv > 0) { - if (cu->cu_closing) + if (cu->cu_closing || cu->cu_closed) error = 0; else error = msleep(cr, &cs->cs_lock, @@ -489,6 +552,11 @@ get_reply: } TAILQ_REMOVE(&cs->cs_pending, cr, cr_link); + cu->cu_sent -= CWNDSCALE; + if (cu->cu_cwnd_wait) { + cu->cu_cwnd_wait = FALSE; + wakeup(&cu->cu_cwnd_wait); + } if (!error) { /* @@ -497,10 +565,52 @@ get_reply: * otherwise we have a reply. */ if (cr->cr_error) { - cu->cu_error.re_errno = cr->cr_error; - cu->cu_error.re_status = RPC_CANTRECV; + errp->re_errno = cr->cr_error; + errp->re_status = stat = RPC_CANTRECV; goto out; } + + cu->cu_cwnd += (CWNDSCALE * CWNDSCALE + + cu->cu_cwnd / 2) / cu->cu_cwnd; + if (cu->cu_cwnd > MAXCWND) + cu->cu_cwnd = MAXCWND; + + if (rt) { + /* + * Add one to the time since a tick + * count of N means that the actual + * time taken was somewhere between N + * and N+1. + */ + rtt = ticks - starttime + 1; + + /* + * Update our estimate of the round + * trip time using roughly the + * algorithm described in RFC + * 2988. Given an RTT sample R: + * + * RTTVAR = (1-beta) * RTTVAR + beta * |SRTT-R| + * SRTT = (1-alpha) * SRTT + alpha * R + * + * where alpha = 0.125 and beta = 0.25. + * + * The initial retransmit timeout is + * SRTT + 4*RTTVAR and doubles on each + * retransmision. 
+ */ + if (rt->rt_srtt == 0) { + rt->rt_srtt = rtt; + rt->rt_deviate = rtt / 2; + } else { + int32_t error = rtt - rt->rt_srtt; + rt->rt_srtt += error / 8; + error = abs(error) - rt->rt_deviate; + rt->rt_deviate += error / 4; + } + rt->rt_rtxcur = rt->rt_srtt + 4*rt->rt_deviate; + } + break; } @@ -510,11 +620,11 @@ get_reply: * re-send the request. */ if (error != EWOULDBLOCK) { - cu->cu_error.re_errno = error; + errp->re_errno = error; if (error == EINTR) - cu->cu_error.re_status = RPC_INTR; + errp->re_status = stat = RPC_INTR; else - cu->cu_error.re_status = RPC_CANTRECV; + errp->re_status = stat = RPC_CANTRECV; goto out; } @@ -522,13 +632,16 @@ get_reply: /* Check for timeout. */ if (time_waited > timeout) { - cu->cu_error.re_errno = EWOULDBLOCK; - cu->cu_error.re_status = RPC_TIMEDOUT; + errp->re_errno = EWOULDBLOCK; + errp->re_status = stat = RPC_TIMEDOUT; goto out; } /* Retransmit if necessary. */ if (time_waited >= next_sendtime) { + cu->cu_cwnd /= 2; + if (cu->cu_cwnd < CWNDSCALE) + cu->cu_cwnd = CWNDSCALE; if (ext && ext->rc_feedback) { mtx_unlock(&cs->cs_lock); if (retrans == 0) @@ -539,9 +652,9 @@ get_reply: proc, ext->rc_feedback_arg); mtx_lock(&cs->cs_lock); } - if (cu->cu_closing) { - cu->cu_error.re_errno = ESHUTDOWN; - cu->cu_error.re_status = RPC_CANTRECV; + if (cu->cu_closing || cu->cu_closed) { + errp->re_errno = ESHUTDOWN; + errp->re_status = stat = RPC_CANTRECV; goto out; } retrans++; @@ -566,47 +679,72 @@ got_reply: xdrmbuf_create(&xdrs, cr->cr_mrep, XDR_DECODE); ok = xdr_replymsg(&xdrs, &reply_msg); - XDR_DESTROY(&xdrs); cr->cr_mrep = NULL; - mtx_lock(&cs->cs_lock); - if (ok) { if ((reply_msg.rm_reply.rp_stat == MSG_ACCEPTED) && - (reply_msg.acpted_rply.ar_stat == SUCCESS)) - cu->cu_error.re_status = RPC_SUCCESS; + (reply_msg.acpted_rply.ar_stat == SUCCESS)) + errp->re_status = stat = RPC_SUCCESS; else - _seterr_reply(&reply_msg, &(cu->cu_error)); - - if (cu->cu_error.re_status == RPC_SUCCESS) { - if (! 
AUTH_VALIDATE(cl->cl_auth, - &reply_msg.acpted_rply.ar_verf)) { - cu->cu_error.re_status = RPC_AUTHERROR; - cu->cu_error.re_why = AUTH_INVALIDRESP; - } - if (reply_msg.acpted_rply.ar_verf.oa_base != NULL) { - xdrs.x_op = XDR_FREE; - (void) xdr_opaque_auth(&xdrs, - &(reply_msg.acpted_rply.ar_verf)); + stat = _seterr_reply(&reply_msg, &(cu->cu_error)); + + if (errp->re_status == RPC_SUCCESS) { + results = xdrmbuf_getall(&xdrs); + if (! AUTH_VALIDATE(auth, xid, + &reply_msg.acpted_rply.ar_verf, + &results)) { + errp->re_status = stat = RPC_AUTHERROR; + errp->re_why = AUTH_INVALIDRESP; + if (retrans && + auth->ah_cred.oa_flavor == RPCSEC_GSS) { + /* + * If we retransmitted, its + * possible that we will + * receive a reply for one of + * the earlier transmissions + * (which will use an older + * RPCSEC_GSS sequence + * number). In this case, just + * go back and listen for a + * new reply. We could keep a + * record of all the seq + * numbers we have transmitted + * so far so that we could + * accept a reply for any of + * them here. + */ + XDR_DESTROY(&xdrs); + mtx_lock(&cs->cs_lock); + TAILQ_INSERT_TAIL(&cs->cs_pending, + cr, cr_link); + cr->cr_mrep = NULL; + goto get_reply; + } + } else { + *resultsp = results; } } /* end successful completion */ /* * If unsuccesful AND error is an authentication error * then refresh credentials and try again, else break */ - else if (cu->cu_error.re_status == RPC_AUTHERROR) + else if (stat == RPC_AUTHERROR) /* maybe our credentials need to be refreshed ... 
*/ if (nrefreshes > 0 && - AUTH_REFRESH(cl->cl_auth, &reply_msg)) { + AUTH_REFRESH(auth, &reply_msg)) { nrefreshes--; + XDR_DESTROY(&xdrs); + mtx_lock(&cs->cs_lock); goto call_again; } /* end of unsuccessful completion */ } /* end of valid reply message */ else { - cu->cu_error.re_status = RPC_CANTDECODERES; + errp->re_status = stat = RPC_CANTDECODERES; } + XDR_DESTROY(&xdrs); + mtx_lock(&cs->cs_lock); out: mtx_assert(&cs->cs_lock, MA_OWNED); @@ -621,9 +759,12 @@ out: mtx_unlock(&cs->cs_lock); + if (auth && stat != RPC_SUCCESS) + AUTH_VALIDATE(auth, xid, NULL, NULL); + free(cr, M_RPC); - return (cu->cu_error.re_status); + return (stat); } static void @@ -759,7 +900,7 @@ clnt_dg_control(CLIENT *cl, u_int request, void *info) cu->cu_connect = *(int *)info; break; case CLSET_WAITCHAN: - cu->cu_waitchan = *(const char **)info; + cu->cu_waitchan = (const char *)info; break; case CLGET_WAITCHAN: *(const char **) info = cu->cu_waitchan; @@ -785,16 +926,27 @@ clnt_dg_control(CLIENT *cl, u_int request, void *info) } static void -clnt_dg_destroy(CLIENT *cl) +clnt_dg_close(CLIENT *cl) { struct cu_data *cu = (struct cu_data *)cl->cl_private; struct cu_socket *cs = (struct cu_socket *) cu->cu_socket->so_upcallarg; struct cu_request *cr; - struct socket *so = NULL; - bool_t lastsocketref; mtx_lock(&cs->cs_lock); + if (cu->cu_closed) { + mtx_unlock(&cs->cs_lock); + return; + } + + if (cu->cu_closing) { + while (cu->cu_closing) + msleep(cu, &cs->cs_lock, 0, "rpcclose", 0); + KASSERT(cu->cu_closed, ("client should be closed")); + mtx_unlock(&cs->cs_lock); + return; + } + /* * Abort any pending requests and wait until everyone * has finished with clnt_vc_call. 
@@ -811,6 +963,25 @@ clnt_dg_destroy(CLIENT *cl) while (cu->cu_threads) msleep(cu, &cs->cs_lock, 0, "rpcclose", 0); + cu->cu_closing = FALSE; + cu->cu_closed = TRUE; + + mtx_unlock(&cs->cs_lock); + wakeup(cu); +} + +static void +clnt_dg_destroy(CLIENT *cl) +{ + struct cu_data *cu = (struct cu_data *)cl->cl_private; + struct cu_socket *cs = (struct cu_socket *) cu->cu_socket->so_upcallarg; + struct socket *so = NULL; + bool_t lastsocketref; + + clnt_dg_close(cl); + + mtx_lock(&cs->cs_lock); + cs->cs_refs--; if (cs->cs_refs == 0) { mtx_destroy(&cs->cs_lock); @@ -894,7 +1065,8 @@ clnt_dg_soupcall(struct socket *so, void *arg, int waitflag) /* * The XID is in the first uint32_t of the reply. */ - m = m_pullup(m, sizeof(xid)); + if (m->m_len < sizeof(xid)) + m = m_pullup(m, sizeof(xid)); if (!m) /* * Should never happen. diff --git a/sys/rpc/clnt_rc.c b/sys/rpc/clnt_rc.c index f0ad673..8d7bfd6 100644 --- a/sys/rpc/clnt_rc.c +++ b/sys/rpc/clnt_rc.c @@ -30,6 +30,7 @@ __FBSDID("$FreeBSD$"); #include <sys/param.h> #include <sys/systm.h> +#include <sys/kernel.h> #include <sys/limits.h> #include <sys/lock.h> #include <sys/malloc.h> @@ -46,11 +47,12 @@ __FBSDID("$FreeBSD$"); #include <rpc/rpc_com.h> static enum clnt_stat clnt_reconnect_call(CLIENT *, struct rpc_callextra *, - rpcproc_t, xdrproc_t, void *, xdrproc_t, void *, struct timeval); + rpcproc_t, struct mbuf *, struct mbuf **, struct timeval); static void clnt_reconnect_geterr(CLIENT *, struct rpc_err *); static bool_t clnt_reconnect_freeres(CLIENT *, xdrproc_t, void *); static void clnt_reconnect_abort(CLIENT *); static bool_t clnt_reconnect_control(CLIENT *, u_int, void *); +static void clnt_reconnect_close(CLIENT *); static void clnt_reconnect_destroy(CLIENT *); static struct clnt_ops clnt_reconnect_ops = { @@ -58,10 +60,13 @@ static struct clnt_ops clnt_reconnect_ops = { .cl_abort = clnt_reconnect_abort, .cl_geterr = clnt_reconnect_geterr, .cl_freeres = clnt_reconnect_freeres, + .cl_close = clnt_reconnect_close, 
.cl_destroy = clnt_reconnect_destroy, .cl_control = clnt_reconnect_control }; +static int fake_wchan; + struct rc_data { struct mtx rc_lock; struct sockaddr_storage rc_addr; /* server address */ @@ -73,10 +78,14 @@ struct rc_data { struct timeval rc_timeout; struct timeval rc_retry; int rc_retries; - const char *rc_waitchan; + int rc_privport; + char *rc_waitchan; int rc_intr; int rc_connecting; + int rc_closed; + struct ucred *rc_ucred; CLIENT* rc_client; /* underlying RPC client */ + struct rpc_err rc_err; }; CLIENT * @@ -110,9 +119,12 @@ clnt_reconnect_create( rc->rc_retry.tv_sec = 3; rc->rc_retry.tv_usec = 0; rc->rc_retries = INT_MAX; + rc->rc_privport = FALSE; rc->rc_waitchan = "rpcrecv"; rc->rc_intr = 0; rc->rc_connecting = FALSE; + rc->rc_closed = FALSE; + rc->rc_ucred = crdup(curthread->td_ucred); rc->rc_client = NULL; cl->cl_refs = 1; @@ -127,16 +139,22 @@ clnt_reconnect_create( static enum clnt_stat clnt_reconnect_connect(CLIENT *cl) { + struct thread *td = curthread; struct rc_data *rc = (struct rc_data *)cl->cl_private; struct socket *so; enum clnt_stat stat; int error; int one = 1; + struct ucred *oldcred; mtx_lock(&rc->rc_lock); again: + if (rc->rc_closed) { + mtx_unlock(&rc->rc_lock); + return (RPC_CANTSEND); + } if (rc->rc_connecting) { - while (!rc->rc_client) { + while (!rc->rc_closed && !rc->rc_client) { error = msleep(rc, &rc->rc_lock, rc->rc_intr ? 
PCATCH : 0, "rpcrecon", 0); if (error) { @@ -163,7 +181,11 @@ again: rpc_createerr.cf_error.re_errno = 0; goto out; } + if (rc->rc_privport) + bindresvport(so, NULL); + oldcred = td->td_ucred; + td->td_ucred = rc->rc_ucred; if (rc->rc_nconf->nc_semantics == NC_TPI_CLTS) rc->rc_client = clnt_dg_create(so, (struct sockaddr *) &rc->rc_addr, rc->rc_prog, rc->rc_vers, @@ -172,8 +194,11 @@ again: rc->rc_client = clnt_vc_create(so, (struct sockaddr *) &rc->rc_addr, rc->rc_prog, rc->rc_vers, rc->rc_sendsz, rc->rc_recvsz); + td->td_ucred = oldcred; if (!rc->rc_client) { + soclose(so); + rc->rc_err = rpc_createerr.cf_error; stat = rpc_createerr.cf_stat; goto out; } @@ -182,12 +207,19 @@ again: CLNT_CONTROL(rc->rc_client, CLSET_CONNECT, &one); CLNT_CONTROL(rc->rc_client, CLSET_TIMEOUT, &rc->rc_timeout); CLNT_CONTROL(rc->rc_client, CLSET_RETRY_TIMEOUT, &rc->rc_retry); - CLNT_CONTROL(rc->rc_client, CLSET_WAITCHAN, &rc->rc_waitchan); + CLNT_CONTROL(rc->rc_client, CLSET_WAITCHAN, rc->rc_waitchan); CLNT_CONTROL(rc->rc_client, CLSET_INTERRUPTIBLE, &rc->rc_intr); stat = RPC_SUCCESS; out: mtx_lock(&rc->rc_lock); + if (rc->rc_closed) { + if (rc->rc_client) { + CLNT_CLOSE(rc->rc_client); + CLNT_RELEASE(rc->rc_client); + rc->rc_client = NULL; + } + } rc->rc_connecting = FALSE; wakeup(rc); mtx_unlock(&rc->rc_lock); @@ -200,11 +232,9 @@ clnt_reconnect_call( CLIENT *cl, /* client handle */ struct rpc_callextra *ext, /* call metadata */ rpcproc_t proc, /* procedure number */ - xdrproc_t xargs, /* xdr routine for args */ - void *argsp, /* pointer to args */ - xdrproc_t xresults, /* xdr routine for results */ - void *resultsp, /* pointer to results */ - struct timeval utimeout) /* seconds to wait before giving up */ + struct mbuf *args, /* pointer to args */ + struct mbuf **resultsp, /* pointer to results */ + struct timeval utimeout) { struct rc_data *rc = (struct rc_data *)cl->cl_private; CLIENT *client; @@ -213,18 +243,40 @@ clnt_reconnect_call( tries = 0; do { + if (rc->rc_closed) { + 
return (RPC_CANTSEND); + } + if (!rc->rc_client) { stat = clnt_reconnect_connect(cl); + if (stat == RPC_SYSTEMERROR) { + (void) tsleep(&fake_wchan, 0, + "rpccon", hz); + tries++; + if (tries >= rc->rc_retries) + return (stat); + continue; + } if (stat != RPC_SUCCESS) return (stat); } mtx_lock(&rc->rc_lock); + if (!rc->rc_client) { + mtx_unlock(&rc->rc_lock); + stat = RPC_FAILED; + continue; + } CLNT_ACQUIRE(rc->rc_client); client = rc->rc_client; mtx_unlock(&rc->rc_lock); - stat = CLNT_CALL_EXT(client, ext, proc, xargs, argsp, - xresults, resultsp, utimeout); + stat = CLNT_CALL_MBUF(client, ext, proc, args, + resultsp, utimeout); + + if (stat != RPC_SUCCESS) { + if (!ext) + CLNT_GETERR(client, &rc->rc_err); + } CLNT_RELEASE(client); if (stat == RPC_TIMEDOUT) { @@ -241,10 +293,8 @@ clnt_reconnect_call( } } - if (stat == RPC_INTR) - break; - - if (stat != RPC_SUCCESS) { + if (stat == RPC_TIMEDOUT || stat == RPC_CANTSEND + || stat == RPC_CANTRECV) { tries++; if (tries >= rc->rc_retries) break; @@ -263,9 +313,14 @@ clnt_reconnect_call( rc->rc_client = NULL; } mtx_unlock(&rc->rc_lock); + } else { + break; } } while (stat != RPC_SUCCESS); + KASSERT(stat != RPC_SUCCESS || *resultsp, + ("RPC_SUCCESS without reply")); + return (stat); } @@ -274,10 +329,7 @@ clnt_reconnect_geterr(CLIENT *cl, struct rpc_err *errp) { struct rc_data *rc = (struct rc_data *)cl->cl_private; - if (rc->rc_client) - CLNT_GETERR(rc->rc_client, errp); - else - memset(errp, 0, sizeof(*errp)); + *errp = rc->rc_err; } static bool_t @@ -344,7 +396,7 @@ clnt_reconnect_control(CLIENT *cl, u_int request, void *info) break; case CLSET_WAITCHAN: - rc->rc_waitchan = *(const char **)info; + rc->rc_waitchan = (char *)info; if (rc->rc_client) CLNT_CONTROL(rc->rc_client, request, info); break; @@ -371,6 +423,14 @@ clnt_reconnect_control(CLIENT *cl, u_int request, void *info) *(int *) info = rc->rc_retries; break; + case CLSET_PRIVPORT: + rc->rc_privport = *(int *) info; + break; + + case CLGET_PRIVPORT: + *(int *) 
info = rc->rc_privport; + break; + default: return (FALSE); } @@ -379,12 +439,38 @@ clnt_reconnect_control(CLIENT *cl, u_int request, void *info) } static void +clnt_reconnect_close(CLIENT *cl) +{ + struct rc_data *rc = (struct rc_data *)cl->cl_private; + CLIENT *client; + + mtx_lock(&rc->rc_lock); + + if (rc->rc_closed) { + mtx_unlock(&rc->rc_lock); + return; + } + + rc->rc_closed = TRUE; + client = rc->rc_client; + rc->rc_client = NULL; + + mtx_unlock(&rc->rc_lock); + + if (client) { + CLNT_CLOSE(client); + CLNT_RELEASE(client); + } +} + +static void clnt_reconnect_destroy(CLIENT *cl) { struct rc_data *rc = (struct rc_data *)cl->cl_private; if (rc->rc_client) CLNT_DESTROY(rc->rc_client); + crfree(rc->rc_ucred); mtx_destroy(&rc->rc_lock); mem_free(rc, sizeof(*rc)); mem_free(cl, sizeof (CLIENT)); diff --git a/sys/rpc/clnt_vc.c b/sys/rpc/clnt_vc.c index cb09352..11fc201 100644 --- a/sys/rpc/clnt_vc.c +++ b/sys/rpc/clnt_vc.c @@ -64,11 +64,13 @@ __FBSDID("$FreeBSD$"); #include <sys/mutex.h> #include <sys/pcpu.h> #include <sys/proc.h> +#include <sys/protosw.h> #include <sys/socket.h> #include <sys/socketvar.h> #include <sys/syslog.h> #include <sys/time.h> #include <sys/uio.h> +#include <netinet/tcp.h> #include <rpc/rpc.h> #include <rpc/rpc_com.h> @@ -81,11 +83,12 @@ struct cmessage { }; static enum clnt_stat clnt_vc_call(CLIENT *, struct rpc_callextra *, - rpcproc_t, xdrproc_t, void *, xdrproc_t, void *, struct timeval); + rpcproc_t, struct mbuf *, struct mbuf **, struct timeval); static void clnt_vc_geterr(CLIENT *, struct rpc_err *); static bool_t clnt_vc_freeres(CLIENT *, xdrproc_t, void *); static void clnt_vc_abort(CLIENT *); static bool_t clnt_vc_control(CLIENT *, u_int, void *); +static void clnt_vc_close(CLIENT *); static void clnt_vc_destroy(CLIENT *); static bool_t time_not_ok(struct timeval *); static void clnt_vc_soupcall(struct socket *so, void *arg, int waitflag); @@ -95,6 +98,7 @@ static struct clnt_ops clnt_vc_ops = { .cl_abort = clnt_vc_abort, 
.cl_geterr = clnt_vc_geterr, .cl_freeres = clnt_vc_freeres, + .cl_close = clnt_vc_close, .cl_destroy = clnt_vc_destroy, .cl_control = clnt_vc_control }; @@ -109,6 +113,7 @@ struct ct_request { uint32_t cr_xid; /* XID of request */ struct mbuf *cr_mrep; /* reply received by upcall */ int cr_error; /* any error from upcall */ + char cr_verf[MAX_AUTH_BYTES]; /* reply verf */ }; TAILQ_HEAD(ct_request_list, ct_request); @@ -116,7 +121,8 @@ TAILQ_HEAD(ct_request_list, ct_request); struct ct_data { struct mtx ct_lock; int ct_threads; /* number of threads in clnt_vc_call */ - bool_t ct_closing; /* TRUE if we are destroying client */ + bool_t ct_closing; /* TRUE if we are closing */ + bool_t ct_closed; /* TRUE if we are closed */ struct socket *ct_socket; /* connection socket */ bool_t ct_closeit; /* close it on destroy */ struct timeval ct_wait; /* wait interval in milliseconds */ @@ -165,7 +171,8 @@ clnt_vc_create( static uint32_t disrupt; struct __rpc_sockinfo si; XDR xdrs; - int error, interrupted; + int error, interrupted, one = 1; + struct sockopt sopt; if (disrupt == 0) disrupt = (uint32_t)(long)raddr; @@ -176,6 +183,7 @@ clnt_vc_create( mtx_init(&ct->ct_lock, "ct->ct_lock", NULL, MTX_DEF); ct->ct_threads = 0; ct->ct_closing = FALSE; + ct->ct_closed = FALSE; if ((so->so_state & (SS_ISCONNECTED|SS_ISCONFIRMING)) == 0) { error = soconnect(so, raddr, curthread); @@ -208,6 +216,26 @@ clnt_vc_create( if (!__rpc_socket2sockinfo(so, &si)) goto err; + if (so->so_proto->pr_flags & PR_CONNREQUIRED) { + bzero(&sopt, sizeof(sopt)); + sopt.sopt_dir = SOPT_SET; + sopt.sopt_level = SOL_SOCKET; + sopt.sopt_name = SO_KEEPALIVE; + sopt.sopt_val = &one; + sopt.sopt_valsize = sizeof(one); + sosetopt(so, &sopt); + } + + if (so->so_proto->pr_protocol == IPPROTO_TCP) { + bzero(&sopt, sizeof(sopt)); + sopt.sopt_dir = SOPT_SET; + sopt.sopt_level = IPPROTO_TCP; + sopt.sopt_name = TCP_NODELAY; + sopt.sopt_val = &one; + sopt.sopt_valsize = sizeof(one); + sosetopt(so, &sopt); + } + 
ct->ct_closeit = FALSE; /* @@ -255,6 +283,7 @@ clnt_vc_create( cl->cl_auth = authnone_create(); sendsz = __rpc_get_t_size(si.si_af, si.si_proto, (int)sendsz); recvsz = __rpc_get_t_size(si.si_af, si.si_proto, (int)recvsz); + soreserve(ct->ct_socket, sendsz, recvsz); SOCKBUF_LOCK(&ct->ct_socket->so_rcv); ct->ct_socket->so_upcallarg = ct; @@ -280,24 +309,24 @@ err: static enum clnt_stat clnt_vc_call( - CLIENT *cl, - struct rpc_callextra *ext, - rpcproc_t proc, - xdrproc_t xdr_args, - void *args_ptr, - xdrproc_t xdr_results, - void *results_ptr, - struct timeval utimeout) + CLIENT *cl, /* client handle */ + struct rpc_callextra *ext, /* call metadata */ + rpcproc_t proc, /* procedure number */ + struct mbuf *args, /* pointer to args */ + struct mbuf **resultsp, /* pointer to results */ + struct timeval utimeout) { struct ct_data *ct = (struct ct_data *) cl->cl_private; AUTH *auth; + struct rpc_err *errp; + enum clnt_stat stat; XDR xdrs; struct rpc_msg reply_msg; bool_t ok; int nrefreshes = 2; /* number of times to refresh cred */ struct timeval timeout; uint32_t xid; - struct mbuf *mreq = NULL; + struct mbuf *mreq = NULL, *results; struct ct_request *cr; int error; @@ -305,17 +334,20 @@ clnt_vc_call( mtx_lock(&ct->ct_lock); - if (ct->ct_closing) { + if (ct->ct_closing || ct->ct_closed) { mtx_unlock(&ct->ct_lock); free(cr, M_RPC); return (RPC_CANTSEND); } ct->ct_threads++; - if (ext) + if (ext) { auth = ext->rc_auth; - else + errp = &ext->rc_err; + } else { auth = cl->cl_auth; + errp = &ct->ct_error; + } cr->cr_mrep = NULL; cr->cr_error = 0; @@ -338,10 +370,11 @@ call_again: * Leave space to pre-pend the record mark. 
*/ MGETHDR(mreq, M_WAIT, MT_DATA); - MCLGET(mreq, M_WAIT); - mreq->m_len = 0; mreq->m_data += sizeof(uint32_t); - m_append(mreq, ct->ct_mpos, ct->ct_mcallc); + KASSERT(ct->ct_mpos + sizeof(uint32_t) <= MHLEN, + ("RPC header too big")); + bcopy(ct->ct_mcallc, mreq->m_data, ct->ct_mpos); + mreq->m_len = ct->ct_mpos; /* * The XID is the first thing in the request. @@ -350,17 +383,16 @@ call_again: xdrmbuf_create(&xdrs, mreq, XDR_ENCODE); - ct->ct_error.re_status = RPC_SUCCESS; + errp->re_status = stat = RPC_SUCCESS; if ((! XDR_PUTINT32(&xdrs, &proc)) || - (! AUTH_MARSHALL(auth, &xdrs)) || - (! (*xdr_args)(&xdrs, args_ptr))) { - if (ct->ct_error.re_status == RPC_SUCCESS) - ct->ct_error.re_status = RPC_CANTENCODEARGS; + (! AUTH_MARSHALL(auth, xid, &xdrs, + m_copym(args, 0, M_COPYALL, M_WAITOK)))) { + errp->re_status = stat = RPC_CANTENCODEARGS; mtx_lock(&ct->ct_lock); goto out; } - m_fixhdr(mreq); + mreq->m_pkthdr.len = m_length(mreq, NULL); /* * Prepend a record marker containing the packet length. 
@@ -379,16 +411,27 @@ call_again: */ error = sosend(ct->ct_socket, NULL, NULL, mreq, NULL, 0, curthread); mreq = NULL; + if (error == EMSGSIZE) { + SOCKBUF_LOCK(&ct->ct_socket->so_snd); + sbwait(&ct->ct_socket->so_snd); + SOCKBUF_UNLOCK(&ct->ct_socket->so_snd); + AUTH_VALIDATE(auth, xid, NULL, NULL); + mtx_lock(&ct->ct_lock); + TAILQ_REMOVE(&ct->ct_pending, cr, cr_link); + goto call_again; + } - reply_msg.acpted_rply.ar_verf = _null_auth; - reply_msg.acpted_rply.ar_results.where = results_ptr; - reply_msg.acpted_rply.ar_results.proc = xdr_results; + reply_msg.acpted_rply.ar_verf.oa_flavor = AUTH_NULL; + reply_msg.acpted_rply.ar_verf.oa_base = cr->cr_verf; + reply_msg.acpted_rply.ar_verf.oa_length = 0; + reply_msg.acpted_rply.ar_results.where = NULL; + reply_msg.acpted_rply.ar_results.proc = (xdrproc_t)xdr_void; mtx_lock(&ct->ct_lock); if (error) { TAILQ_REMOVE(&ct->ct_pending, cr, cr_link); - ct->ct_error.re_errno = error; - ct->ct_error.re_status = RPC_CANTSEND; + errp->re_errno = error; + errp->re_status = stat = RPC_CANTSEND; goto out; } @@ -399,8 +442,8 @@ call_again: */ if (cr->cr_error) { TAILQ_REMOVE(&ct->ct_pending, cr, cr_link); - ct->ct_error.re_errno = cr->cr_error; - ct->ct_error.re_status = RPC_CANTRECV; + errp->re_errno = cr->cr_error; + errp->re_status = stat = RPC_CANTRECV; goto out; } if (cr->cr_mrep) { @@ -413,7 +456,7 @@ call_again: */ if (timeout.tv_sec == 0 && timeout.tv_usec == 0) { TAILQ_REMOVE(&ct->ct_pending, cr, cr_link); - ct->ct_error.re_status = RPC_TIMEDOUT; + errp->re_status = stat = RPC_TIMEDOUT; goto out; } @@ -428,17 +471,18 @@ call_again: * on the list. Turn the error code into an * appropriate client status. 
*/ - ct->ct_error.re_errno = error; + errp->re_errno = error; switch (error) { case EINTR: - ct->ct_error.re_status = RPC_INTR; + stat = RPC_INTR; break; case EWOULDBLOCK: - ct->ct_error.re_status = RPC_TIMEDOUT; + stat = RPC_TIMEDOUT; break; default: - ct->ct_error.re_status = RPC_CANTRECV; + stat = RPC_CANTRECV; } + errp->re_status = stat; goto out; } else { /* @@ -447,8 +491,8 @@ call_again: * otherwise we have a reply. */ if (cr->cr_error) { - ct->ct_error.re_errno = cr->cr_error; - ct->ct_error.re_status = RPC_CANTRECV; + errp->re_errno = cr->cr_error; + errp->re_status = stat = RPC_CANTRECV; goto out; } } @@ -460,51 +504,59 @@ got_reply: */ mtx_unlock(&ct->ct_lock); + if (ext && ext->rc_feedback) + ext->rc_feedback(FEEDBACK_OK, proc, ext->rc_feedback_arg); + xdrmbuf_create(&xdrs, cr->cr_mrep, XDR_DECODE); ok = xdr_replymsg(&xdrs, &reply_msg); - XDR_DESTROY(&xdrs); cr->cr_mrep = NULL; - mtx_lock(&ct->ct_lock); - if (ok) { if ((reply_msg.rm_reply.rp_stat == MSG_ACCEPTED) && - (reply_msg.acpted_rply.ar_stat == SUCCESS)) - ct->ct_error.re_status = RPC_SUCCESS; + (reply_msg.acpted_rply.ar_stat == SUCCESS)) + errp->re_status = stat = RPC_SUCCESS; else - _seterr_reply(&reply_msg, &(ct->ct_error)); - - if (ct->ct_error.re_status == RPC_SUCCESS) { - if (! 
AUTH_VALIDATE(cl->cl_auth, - &reply_msg.acpted_rply.ar_verf)) { - ct->ct_error.re_status = RPC_AUTHERROR; - ct->ct_error.re_why = AUTH_INVALIDRESP; - } - if (reply_msg.acpted_rply.ar_verf.oa_base != NULL) { - xdrs.x_op = XDR_FREE; - (void) xdr_opaque_auth(&xdrs, - &(reply_msg.acpted_rply.ar_verf)); + stat = _seterr_reply(&reply_msg, errp); + + if (stat == RPC_SUCCESS) { + results = xdrmbuf_getall(&xdrs); + if (!AUTH_VALIDATE(auth, xid, + &reply_msg.acpted_rply.ar_verf, + &results)) { + errp->re_status = stat = RPC_AUTHERROR; + errp->re_why = AUTH_INVALIDRESP; + } else { + KASSERT(results, + ("auth validated but no result")); + *resultsp = results; } } /* end successful completion */ /* * If unsuccesful AND error is an authentication error * then refresh credentials and try again, else break */ - else if (ct->ct_error.re_status == RPC_AUTHERROR) + else if (stat == RPC_AUTHERROR) /* maybe our credentials need to be refreshed ... */ if (nrefreshes > 0 && - AUTH_REFRESH(cl->cl_auth, &reply_msg)) { + AUTH_REFRESH(auth, &reply_msg)) { nrefreshes--; + XDR_DESTROY(&xdrs); + mtx_lock(&ct->ct_lock); goto call_again; } /* end of unsuccessful completion */ } /* end of valid reply message */ else { - ct->ct_error.re_status = RPC_CANTDECODERES; + errp->re_status = stat = RPC_CANTDECODERES; } + XDR_DESTROY(&xdrs); + mtx_lock(&ct->ct_lock); out: mtx_assert(&ct->ct_lock, MA_OWNED); + KASSERT(stat != RPC_SUCCESS || *resultsp, + ("RPC_SUCCESS without reply")); + if (mreq) m_freem(mreq); if (cr->cr_mrep) @@ -516,9 +568,12 @@ out: mtx_unlock(&ct->ct_lock); + if (auth && stat != RPC_SUCCESS) + AUTH_VALIDATE(auth, xid, NULL, NULL); + free(cr, M_RPC); - return (ct->ct_error.re_status); + return (stat); } static void @@ -642,7 +697,7 @@ clnt_vc_control(CLIENT *cl, u_int request, void *info) break; case CLSET_WAITCHAN: - ct->ct_waitchan = *(const char **)info; + ct->ct_waitchan = (const char *)info; break; case CLGET_WAITCHAN: @@ -673,14 +728,26 @@ clnt_vc_control(CLIENT *cl, u_int request, 
void *info) } static void -clnt_vc_destroy(CLIENT *cl) +clnt_vc_close(CLIENT *cl) { struct ct_data *ct = (struct ct_data *) cl->cl_private; struct ct_request *cr; - struct socket *so = NULL; mtx_lock(&ct->ct_lock); + if (ct->ct_closed) { + mtx_unlock(&ct->ct_lock); + return; + } + + if (ct->ct_closing) { + while (ct->ct_closing) + msleep(ct, &ct->ct_lock, 0, "rpcclose", 0); + KASSERT(ct->ct_closed, ("client should be closed")); + mtx_unlock(&ct->ct_lock); + return; + } + if (ct->ct_socket) { SOCKBUF_LOCK(&ct->ct_socket->so_rcv); ct->ct_socket->so_upcallarg = NULL; @@ -701,7 +768,25 @@ clnt_vc_destroy(CLIENT *cl) while (ct->ct_threads) msleep(ct, &ct->ct_lock, 0, "rpcclose", 0); + } + + ct->ct_closing = FALSE; + ct->ct_closed = TRUE; + mtx_unlock(&ct->ct_lock); + wakeup(ct); +} +static void +clnt_vc_destroy(CLIENT *cl) +{ + struct ct_data *ct = (struct ct_data *) cl->cl_private; + struct socket *so = NULL; + + clnt_vc_close(cl); + + mtx_lock(&ct->ct_lock); + + if (ct->ct_socket) { if (ct->ct_closeit) { so = ct->ct_socket; } @@ -738,6 +823,7 @@ clnt_vc_soupcall(struct socket *so, void *arg, int waitflag) struct ct_request *cr; int error, rcvflag, foundreq; uint32_t xid, header; + bool_t do_read; uio.uio_td = curthread; do { @@ -746,7 +832,6 @@ clnt_vc_soupcall(struct socket *so, void *arg, int waitflag) * record mark. */ if (ct->ct_record_resid == 0) { - bool_t do_read; /* * Make sure there is either a whole record @@ -795,7 +880,7 @@ clnt_vc_soupcall(struct socket *so, void *arg, int waitflag) mtx_unlock(&ct->ct_lock); break; } - memcpy(&header, mtod(m, uint32_t *), sizeof(uint32_t)); + bcopy(mtod(m, uint32_t *), &header, sizeof(uint32_t)); header = ntohl(header); ct->ct_record = NULL; ct->ct_record_resid = header & 0x7fffffff; @@ -803,6 +888,21 @@ clnt_vc_soupcall(struct socket *so, void *arg, int waitflag) m_freem(m); } else { /* + * Wait until the socket has the whole record + * buffered. 
+ */ + do_read = FALSE; + SOCKBUF_LOCK(&so->so_rcv); + if (so->so_rcv.sb_cc >= ct->ct_record_resid + || (so->so_rcv.sb_state & SBS_CANTRCVMORE) + || so->so_error) + do_read = TRUE; + SOCKBUF_UNLOCK(&so->so_rcv); + + if (!do_read) + return; + + /* * We have the record mark. Read as much as * the socket has buffered up to the end of * this record. @@ -839,13 +939,14 @@ clnt_vc_soupcall(struct socket *so, void *arg, int waitflag) * The XID is in the first uint32_t of * the reply. */ - ct->ct_record = - m_pullup(ct->ct_record, sizeof(xid)); + if (ct->ct_record->m_len < sizeof(xid)) + ct->ct_record = + m_pullup(ct->ct_record, + sizeof(xid)); if (!ct->ct_record) break; - memcpy(&xid, - mtod(ct->ct_record, uint32_t *), - sizeof(uint32_t)); + bcopy(mtod(ct->ct_record, uint32_t *), + &xid, sizeof(uint32_t)); xid = ntohl(xid); mtx_lock(&ct->ct_lock); diff --git a/sys/rpc/replay.c b/sys/rpc/replay.c new file mode 100644 index 0000000..d82fc20 --- /dev/null +++ b/sys/rpc/replay.c @@ -0,0 +1,248 @@ +/*- + * Copyright (c) 2008 Isilon Inc http://www.isilon.com/ + * Authors: Doug Rabson <dfr@rabson.org> + * Developed with Red Inc: Alfred Perlstein <alfred@freebsd.org> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/hash.h> +#include <sys/kernel.h> +#include <sys/lock.h> +#include <sys/mbuf.h> +#include <sys/mutex.h> +#include <sys/queue.h> + +#include <rpc/rpc.h> +#include <rpc/replay.h> + +struct replay_cache_entry { + int rce_hash; + struct rpc_msg rce_msg; + struct sockaddr_storage rce_addr; + struct rpc_msg rce_repmsg; + struct mbuf *rce_repbody; + + TAILQ_ENTRY(replay_cache_entry) rce_link; + TAILQ_ENTRY(replay_cache_entry) rce_alllink; +}; +TAILQ_HEAD(replay_cache_list, replay_cache_entry); + +static struct replay_cache_entry * + replay_alloc(struct replay_cache *rc, struct rpc_msg *msg, + struct sockaddr *addr, int h); +static void replay_free(struct replay_cache *rc, + struct replay_cache_entry *rce); +static void replay_prune(struct replay_cache *rc); + +#define REPLAY_HASH_SIZE 256 +#define REPLAY_MAX 1024 + +struct replay_cache { + struct replay_cache_list rc_cache[REPLAY_HASH_SIZE]; + struct replay_cache_list rc_all; + struct mtx rc_lock; + int rc_count; + size_t rc_size; + size_t rc_maxsize; +}; + +struct replay_cache * +replay_newcache(size_t maxsize) +{ + struct replay_cache *rc; + int i; + + rc = malloc(sizeof(*rc), M_RPC, M_WAITOK|M_ZERO); + for (i = 0; i < REPLAY_HASH_SIZE; i++) + TAILQ_INIT(&rc->rc_cache[i]); + TAILQ_INIT(&rc->rc_all); + mtx_init(&rc->rc_lock, "rc_lock", NULL, MTX_DEF); + rc->rc_maxsize = maxsize; + + return (rc); +} + 
+void +replay_setsize(struct replay_cache *rc, size_t newmaxsize) +{ + + rc->rc_maxsize = newmaxsize; + replay_prune(rc); +} + +void +replay_freecache(struct replay_cache *rc) +{ + + mtx_lock(&rc->rc_lock); + while (TAILQ_FIRST(&rc->rc_all)) + replay_free(rc, TAILQ_FIRST(&rc->rc_all)); + mtx_destroy(&rc->rc_lock); + free(rc, M_RPC); +} + +static struct replay_cache_entry * +replay_alloc(struct replay_cache *rc, + struct rpc_msg *msg, struct sockaddr *addr, int h) +{ + struct replay_cache_entry *rce; + + rc->rc_count++; + rce = malloc(sizeof(*rce), M_RPC, M_NOWAIT|M_ZERO); + rce->rce_hash = h; + rce->rce_msg = *msg; + bcopy(addr, &rce->rce_addr, addr->sa_len); + + TAILQ_INSERT_HEAD(&rc->rc_cache[h], rce, rce_link); + TAILQ_INSERT_HEAD(&rc->rc_all, rce, rce_alllink); + + return (rce); +} + +static void +replay_free(struct replay_cache *rc, struct replay_cache_entry *rce) +{ + + rc->rc_count--; + TAILQ_REMOVE(&rc->rc_cache[rce->rce_hash], rce, rce_link); + TAILQ_REMOVE(&rc->rc_all, rce, rce_alllink); + if (rce->rce_repbody) { + rc->rc_size -= m_length(rce->rce_repbody, NULL); + m_freem(rce->rce_repbody); + } + free(rce, M_RPC); +} + +static void +replay_prune(struct replay_cache *rc) +{ + struct replay_cache_entry *rce; + bool_t freed_one; + + if (rc->rc_count >= REPLAY_MAX || rc->rc_size > rc->rc_maxsize) { + freed_one = FALSE; + do { + /* + * Try to free an entry. 
Don't free in-progress entries + */ + TAILQ_FOREACH_REVERSE(rce, &rc->rc_all, + replay_cache_list, rce_alllink) { + if (rce->rce_repmsg.rm_xid) { + replay_free(rc, rce); + freed_one = TRUE; + break; + } + } + } while (freed_one + && (rc->rc_count >= REPLAY_MAX + || rc->rc_size > rc->rc_maxsize)); + } +} + +enum replay_state +replay_find(struct replay_cache *rc, struct rpc_msg *msg, + struct sockaddr *addr, struct rpc_msg *repmsg, struct mbuf **mp) +{ + int h = HASHSTEP(HASHINIT, msg->rm_xid) % REPLAY_HASH_SIZE; + struct replay_cache_entry *rce; + + mtx_lock(&rc->rc_lock); + TAILQ_FOREACH(rce, &rc->rc_cache[h], rce_link) { + if (rce->rce_msg.rm_xid == msg->rm_xid + && rce->rce_msg.rm_call.cb_prog == msg->rm_call.cb_prog + && rce->rce_msg.rm_call.cb_vers == msg->rm_call.cb_vers + && rce->rce_msg.rm_call.cb_proc == msg->rm_call.cb_proc + && rce->rce_addr.ss_len == addr->sa_len + && bcmp(&rce->rce_addr, addr, addr->sa_len) == 0) { + if (rce->rce_repmsg.rm_xid) { + /* + * We have a reply for this + * message. Copy it and return. Keep + * replay_all LRU sorted + */ + TAILQ_REMOVE(&rc->rc_all, rce, rce_alllink); + TAILQ_INSERT_HEAD(&rc->rc_all, rce, + rce_alllink); + *repmsg = rce->rce_repmsg; + if (rce->rce_repbody) { + *mp = m_copym(rce->rce_repbody, + 0, M_COPYALL, M_NOWAIT); + mtx_unlock(&rc->rc_lock); + if (!*mp) + return (RS_ERROR); + } else { + mtx_unlock(&rc->rc_lock); + } + return (RS_DONE); + } else { + mtx_unlock(&rc->rc_lock); + return (RS_INPROGRESS); + } + } + } + + replay_prune(rc); + + rce = replay_alloc(rc, msg, addr, h); + + mtx_unlock(&rc->rc_lock); + + if (!rce) + return (RS_ERROR); + else + return (RS_NEW); +} + +void +replay_setreply(struct replay_cache *rc, + struct rpc_msg *repmsg, struct sockaddr *addr, struct mbuf *m) +{ + int h = HASHSTEP(HASHINIT, repmsg->rm_xid) % REPLAY_HASH_SIZE; + struct replay_cache_entry *rce; + + /* + * Copy the reply before the lock so we can sleep. 
+ */ + if (m) + m = m_copym(m, 0, M_COPYALL, M_WAITOK); + + mtx_lock(&rc->rc_lock); + TAILQ_FOREACH(rce, &rc->rc_cache[h], rce_link) { + if (rce->rce_msg.rm_xid == repmsg->rm_xid + && rce->rce_addr.ss_len == addr->sa_len + && bcmp(&rce->rce_addr, addr, addr->sa_len) == 0) { + break; + } + } + if (rce) { + rce->rce_repmsg = *repmsg; + rce->rce_repbody = m; + if (m) + rc->rc_size += m_length(m, NULL); + } + mtx_unlock(&rc->rc_lock); +} diff --git a/sys/rpc/replay.h b/sys/rpc/replay.h new file mode 100644 index 0000000..0ef7bf3 --- /dev/null +++ b/sys/rpc/replay.h @@ -0,0 +1,85 @@ +/*- + * Copyright (c) 2008 Isilon Inc http://www.isilon.com/ + * Authors: Doug Rabson <dfr@rabson.org> + * Developed with Red Inc: Alfred Perlstein <alfred@freebsd.org> + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * $FreeBSD$ + */ + +#ifndef _RPC_REPLAY_H +#define _RPC_REPLAY_H + +enum replay_state { + RS_NEW, /* new request - caller should execute */ + RS_DONE, /* request was executed and reply sent */ + RS_INPROGRESS, /* request is being executed now */ + RS_ERROR /* allocation or other failure */ +}; + +struct replay_cache; + +/* + * Create a new replay cache. + */ +struct replay_cache *replay_newcache(size_t); + +/* + * Set the replay cache size. + */ +void replay_setsize(struct replay_cache *, size_t); + +/* + * Free a replay cache. Caller must ensure that no cache entries are + * in-progress. + */ +void replay_freecache(struct replay_cache *rc); + +/* + * Check a replay cache for a message from a given address. + * + * If this is a new request, RS_NEW is returned. Caller should call + * replay_setreply with the results of the request. + * + * If this is a request which is currently executing + * (i.e. replay_setreply hasn't been called for it yet), RS_INPROGRESS + * is returned. The caller should silently drop the request. + * + * If a reply to this message already exists, *repmsg and *mp are set + * to point at the reply and, RS_DONE is returned. The caller should + * re-send this reply. + * + * If the attempt to update the replay cache or copy a replay failed + * for some reason (typically memory shortage), RS_ERROR is returned. 
+ */ +enum replay_state replay_find(struct replay_cache *rc, + struct rpc_msg *msg, struct sockaddr *addr, + struct rpc_msg *repmsg, struct mbuf **mp); + +/* + * Call this after executing a request to record the reply. + */ +void replay_setreply(struct replay_cache *rc, + struct rpc_msg *repmsg, struct sockaddr *addr, struct mbuf *m); + +#endif /* !_RPC_REPLAY_H */ diff --git a/sys/rpc/rpc_com.h b/sys/rpc/rpc_com.h index ad9cc68..e50e513 100644 --- a/sys/rpc/rpc_com.h +++ b/sys/rpc/rpc_com.h @@ -115,6 +115,7 @@ extern const char *__rpc_inet_ntop(int af, const void * __restrict src, char * __restrict dst, socklen_t size); extern int __rpc_inet_pton(int af, const char * __restrict src, void * __restrict dst); +extern int bindresvport(struct socket *so, struct sockaddr *sa); struct xucred; struct __rpc_xdr; diff --git a/sys/rpc/rpc_generic.c b/sys/rpc/rpc_generic.c index ee8ee8a..d9100b3 100644 --- a/sys/rpc/rpc_generic.c +++ b/sys/rpc/rpc_generic.c @@ -46,6 +46,7 @@ __FBSDID("$FreeBSD$"); #include <sys/param.h> #include <sys/kernel.h> #include <sys/malloc.h> +#include <sys/mbuf.h> #include <sys/module.h> #include <sys/proc.h> #include <sys/protosw.h> @@ -722,6 +723,139 @@ __rpc_sockisbound(struct socket *so) } /* + * Implement XDR-style API for RPC call. 
+ */ +enum clnt_stat +clnt_call_private( + CLIENT *cl, /* client handle */ + struct rpc_callextra *ext, /* call metadata */ + rpcproc_t proc, /* procedure number */ + xdrproc_t xargs, /* xdr routine for args */ + void *argsp, /* pointer to args */ + xdrproc_t xresults, /* xdr routine for results */ + void *resultsp, /* pointer to results */ + struct timeval utimeout) /* seconds to wait before giving up */ +{ + XDR xdrs; + struct mbuf *mreq; + struct mbuf *mrep; + enum clnt_stat stat; + + MGET(mreq, M_WAIT, MT_DATA); + MCLGET(mreq, M_WAIT); + mreq->m_len = 0; + + xdrmbuf_create(&xdrs, mreq, XDR_ENCODE); + if (!xargs(&xdrs, argsp)) { + m_freem(mreq); + return (RPC_CANTENCODEARGS); + } + XDR_DESTROY(&xdrs); + + stat = CLNT_CALL_MBUF(cl, ext, proc, mreq, &mrep, utimeout); + m_freem(mreq); + + if (stat == RPC_SUCCESS) { + xdrmbuf_create(&xdrs, mrep, XDR_DECODE); + if (!xresults(&xdrs, resultsp)) { + XDR_DESTROY(&xdrs); + return (RPC_CANTDECODERES); + } + XDR_DESTROY(&xdrs); + } + + return (stat); +} + +/* + * Bind a socket to a privileged IP port + */ +int +bindresvport(struct socket *so, struct sockaddr *sa) +{ + int old, error, af; + bool_t freesa = FALSE; + struct sockaddr_in *sin; +#ifdef INET6 + struct sockaddr_in6 *sin6; +#endif + struct sockopt opt; + int proto, portrange, portlow; + u_int16_t *portp; + socklen_t salen; + + if (sa == NULL) { + error = so->so_proto->pr_usrreqs->pru_sockaddr(so, &sa); + if (error) + return (error); + freesa = TRUE; + af = sa->sa_family; + salen = sa->sa_len; + memset(sa, 0, sa->sa_len); + } else { + af = sa->sa_family; + salen = sa->sa_len; + } + + switch (af) { + case AF_INET: + proto = IPPROTO_IP; + portrange = IP_PORTRANGE; + portlow = IP_PORTRANGE_LOW; + sin = (struct sockaddr_in *)sa; + portp = &sin->sin_port; + break; +#ifdef INET6 + case AF_INET6: + proto = IPPROTO_IPV6; + portrange = IPV6_PORTRANGE; + portlow = IPV6_PORTRANGE_LOW; + sin6 = (struct sockaddr_in6 *)sa; + portp = &sin6->sin6_port; + break; +#endif + default: + 
return (EPFNOSUPPORT); + } + + sa->sa_family = af; + sa->sa_len = salen; + + if (*portp == 0) { + bzero(&opt, sizeof(opt)); + opt.sopt_dir = SOPT_GET; + opt.sopt_level = proto; + opt.sopt_name = portrange; + opt.sopt_val = &old; + opt.sopt_valsize = sizeof(old); + error = sogetopt(so, &opt); + if (error) + goto out; + + opt.sopt_dir = SOPT_SET; + opt.sopt_val = &portlow; + error = sosetopt(so, &opt); + if (error) + goto out; + } + + error = sobind(so, sa, curthread); + + if (*portp == 0) { + if (error) { + opt.sopt_dir = SOPT_SET; + opt.sopt_val = &old; + sosetopt(so, &opt); + } + } +out: + if (freesa) + free(sa, M_SONAME); + + return (error); +} + +/* * Kernel module glue */ static int diff --git a/sys/rpc/rpc_msg.h b/sys/rpc/rpc_msg.h index 707250a..ff2a6d8 100644 --- a/sys/rpc/rpc_msg.h +++ b/sys/rpc/rpc_msg.h @@ -208,7 +208,7 @@ extern bool_t xdr_rejected_reply(XDR *, struct rejected_reply *); * struct rpc_msg *msg; * struct rpc_err *error; */ -extern void _seterr_reply(struct rpc_msg *, struct rpc_err *); +extern enum clnt_stat _seterr_reply(struct rpc_msg *, struct rpc_err *); __END_DECLS #endif /* !_RPC_RPC_MSG_H */ diff --git a/sys/rpc/rpc_prot.c b/sys/rpc/rpc_prot.c index 16f602f..294c4e3 100644 --- a/sys/rpc/rpc_prot.c +++ b/sys/rpc/rpc_prot.c @@ -64,8 +64,8 @@ MALLOC_DEFINE(M_RPC, "rpc", "Remote Procedure Call"); #define assert(exp) KASSERT(exp, ("bad arguments")) -static void accepted(enum accept_stat, struct rpc_err *); -static void rejected(enum reject_stat, struct rpc_err *); +static enum clnt_stat accepted(enum accept_stat, struct rpc_err *); +static enum clnt_stat rejected(enum reject_stat, struct rpc_err *); /* * * * * * * * * * * * * * XDR Authentication * * * * * * * * * * * */ @@ -111,7 +111,11 @@ xdr_accepted_reply(XDR *xdrs, struct accepted_reply *ar) switch (ar->ar_stat) { case SUCCESS: - return ((*(ar->ar_results.proc))(xdrs, ar->ar_results.where)); + if (ar->ar_results.proc != (xdrproc_t) xdr_void) + return ((*(ar->ar_results.proc))(xdrs, 
+ ar->ar_results.where)); + else + return (TRUE); case PROG_MISMATCH: if (! xdr_uint32_t(xdrs, &(ar->ar_vers.low))) @@ -171,12 +175,34 @@ static const struct xdr_discrim reply_dscrm[3] = { bool_t xdr_replymsg(XDR *xdrs, struct rpc_msg *rmsg) { + int32_t *buf; enum msg_type *prm_direction; enum reply_stat *prp_stat; assert(xdrs != NULL); assert(rmsg != NULL); + if (xdrs->x_op == XDR_DECODE) { + buf = XDR_INLINE(xdrs, 3 * BYTES_PER_XDR_UNIT); + if (buf != NULL) { + rmsg->rm_xid = IXDR_GET_UINT32(buf); + rmsg->rm_direction = IXDR_GET_ENUM(buf, enum msg_type); + if (rmsg->rm_direction != REPLY) { + return (FALSE); + } + rmsg->rm_reply.rp_stat = + IXDR_GET_ENUM(buf, enum reply_stat); + if (rmsg->rm_reply.rp_stat == MSG_ACCEPTED) + return (xdr_accepted_reply(xdrs, + &rmsg->acpted_rply)); + else if (rmsg->rm_reply.rp_stat == MSG_DENIED) + return (xdr_rejected_reply(xdrs, + &rmsg->rjcted_rply)); + else + return (FALSE); + } + } + prm_direction = &rmsg->rm_direction; prp_stat = &rmsg->rm_reply.rp_stat; @@ -220,7 +246,7 @@ xdr_callhdr(XDR *xdrs, struct rpc_msg *cmsg) /* ************************** Client utility routine ************* */ -static void +static enum clnt_stat accepted(enum accept_stat acpt_stat, struct rpc_err *error) { @@ -230,36 +256,32 @@ accepted(enum accept_stat acpt_stat, struct rpc_err *error) case PROG_UNAVAIL: error->re_status = RPC_PROGUNAVAIL; - return; + return (RPC_PROGUNAVAIL); case PROG_MISMATCH: error->re_status = RPC_PROGVERSMISMATCH; - return; + return (RPC_PROGVERSMISMATCH); case PROC_UNAVAIL: - error->re_status = RPC_PROCUNAVAIL; - return; + return (RPC_PROCUNAVAIL); case GARBAGE_ARGS: - error->re_status = RPC_CANTDECODEARGS; - return; + return (RPC_CANTDECODEARGS); case SYSTEM_ERR: - error->re_status = RPC_SYSTEMERROR; - return; + return (RPC_SYSTEMERROR); case SUCCESS: - error->re_status = RPC_SUCCESS; - return; + return (RPC_SUCCESS); } /* NOTREACHED */ /* something's wrong, but we don't know what ... 
*/ - error->re_status = RPC_FAILED; error->re_lb.s1 = (int32_t)MSG_ACCEPTED; error->re_lb.s2 = (int32_t)acpt_stat; + return (RPC_FAILED); } -static void +static enum clnt_stat rejected(enum reject_stat rjct_stat, struct rpc_err *error) { @@ -267,26 +289,25 @@ rejected(enum reject_stat rjct_stat, struct rpc_err *error) switch (rjct_stat) { case RPC_MISMATCH: - error->re_status = RPC_VERSMISMATCH; - return; + return (RPC_VERSMISMATCH); case AUTH_ERROR: - error->re_status = RPC_AUTHERROR; - return; + return (RPC_AUTHERROR); } /* something's wrong, but we don't know what ... */ /* NOTREACHED */ - error->re_status = RPC_FAILED; error->re_lb.s1 = (int32_t)MSG_DENIED; error->re_lb.s2 = (int32_t)rjct_stat; + return (RPC_FAILED); } /* * given a reply message, fills in the error */ -void +enum clnt_stat _seterr_reply(struct rpc_msg *msg, struct rpc_err *error) { + enum clnt_stat stat; assert(msg != NULL); assert(error != NULL); @@ -296,22 +317,24 @@ _seterr_reply(struct rpc_msg *msg, struct rpc_err *error) case MSG_ACCEPTED: if (msg->acpted_rply.ar_stat == SUCCESS) { - error->re_status = RPC_SUCCESS; - return; + stat = RPC_SUCCESS; + return (stat); } - accepted(msg->acpted_rply.ar_stat, error); + stat = accepted(msg->acpted_rply.ar_stat, error); break; case MSG_DENIED: - rejected(msg->rjcted_rply.rj_stat, error); + stat = rejected(msg->rjcted_rply.rj_stat, error); break; default: - error->re_status = RPC_FAILED; + stat = RPC_FAILED; error->re_lb.s1 = (int32_t)(msg->rm_reply.rp_stat); break; } - switch (error->re_status) { + error->re_status = stat; + + switch (stat) { case RPC_VERSMISMATCH: error->re_vers.low = msg->rjcted_rply.rj_vers.low; @@ -345,4 +368,6 @@ _seterr_reply(struct rpc_msg *msg, struct rpc_err *error) default: break; } + + return (stat); } diff --git a/sys/rpc/rpcsec_gss.h b/sys/rpc/rpcsec_gss.h new file mode 100644 index 0000000..563205c --- /dev/null +++ b/sys/rpc/rpcsec_gss.h @@ -0,0 +1,189 @@ +/*- + * Copyright (c) 2008 Doug Rabson + * All rights 
reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + * + * $FreeBSD$ + */ + +#ifndef _RPCSEC_GSS_H +#define _RPCSEC_GSS_H + +#include <kgssapi/gssapi.h> + +#ifndef MAX_GSS_MECH +#define MAX_GSS_MECH 64 +#endif + +/* + * Define the types of security service required for rpc_gss_seccreate(). + */ +typedef enum { + rpc_gss_svc_default = 0, + rpc_gss_svc_none = 1, + rpc_gss_svc_integrity = 2, + rpc_gss_svc_privacy = 3 +} rpc_gss_service_t; + +/* + * Structure containing options for rpc_gss_seccreate(). 
+ */ +typedef struct { + int req_flags; /* GSS request bits */ + int time_req; /* requested credential lifetime */ + gss_cred_id_t my_cred; /* GSS credential */ + gss_channel_bindings_t input_channel_bindings; +} rpc_gss_options_req_t; + +/* + * Structure containing options returned by rpc_gss_seccreate(). + */ +typedef struct { + int major_status; + int minor_status; + u_int rpcsec_version; + int ret_flags; + int time_req; + gss_ctx_id_t gss_context; + char actual_mechanism[MAX_GSS_MECH]; +} rpc_gss_options_ret_t; + +/* + * Client principal type. Used as an argument to + * rpc_gss_get_principal_name(). Also referenced by the + * rpc_gss_rawcred_t structure. + */ +typedef struct { + int len; + char name[1]; +} *rpc_gss_principal_t; + +/* + * Structure for raw credentials used by rpc_gss_getcred() and + * rpc_gss_set_callback(). + */ +typedef struct { + u_int version; /* RPC version number */ + const char *mechanism; /* security mechanism */ + const char *qop; /* quality of protection */ + rpc_gss_principal_t client_principal; /* client name */ + const char *svc_principal; /* server name */ + rpc_gss_service_t service; /* service type */ +} rpc_gss_rawcred_t; + +/* + * Unix credentials derived from raw credentials. Returned by + * rpc_gss_getcred(). + */ +typedef struct { + uid_t uid; /* user ID */ + gid_t gid; /* group ID */ + short gidlen; + gid_t *gidlist; /* list of groups */ +} rpc_gss_ucred_t; + +/* + * Structure used to enforce a particular QOP and service. + */ +typedef struct { + bool_t locked; + rpc_gss_rawcred_t *raw_cred; +} rpc_gss_lock_t; + +/* + * Callback structure used by rpc_gss_set_callback(). 
+ */ +typedef struct { + u_int program; /* RPC program number */ + u_int version; /* RPC version number */ + /* user defined callback */ + bool_t (*callback)(struct svc_req *req, + gss_cred_id_t deleg, + gss_ctx_id_t gss_context, + rpc_gss_lock_t *lock, + void **cookie); +} rpc_gss_callback_t; + +/* + * Structure used to return error information by rpc_gss_get_error() + */ +typedef struct { + int rpc_gss_error; + int system_error; /* same as errno */ +} rpc_gss_error_t; + +/* + * Values for rpc_gss_error + */ +#define RPC_GSS_ER_SUCCESS 0 /* no error */ +#define RPC_GSS_ER_SYSTEMERROR 1 /* system error */ + +__BEGIN_DECLS + +#ifdef _KERNEL +AUTH *rpc_gss_secfind(CLIENT *clnt, struct ucred *cred, + const char *principal, gss_OID mech_oid, rpc_gss_service_t service); +void rpc_gss_secpurge(CLIENT *clnt); +#endif +AUTH *rpc_gss_seccreate(CLIENT *clnt, struct ucred *cred, + const char *principal, const char *mechanism, rpc_gss_service_t service, + const char *qop, rpc_gss_options_req_t *options_req, + rpc_gss_options_ret_t *options_ret); +bool_t rpc_gss_set_defaults(AUTH *auth, rpc_gss_service_t service, + const char *qop); +int rpc_gss_max_data_length(AUTH *handle, int max_tp_unit_len); +void rpc_gss_get_error(rpc_gss_error_t *error); + +bool_t rpc_gss_mech_to_oid(const char *mech, gss_OID *oid_ret); +bool_t rpc_gss_oid_to_mech(gss_OID oid, const char **mech_ret); +bool_t rpc_gss_qop_to_num(const char *qop, const char *mech, u_int *num_ret); +const char **rpc_gss_get_mechanisms(void); +const char **rpc_gss_get_mech_info(const char *mech, rpc_gss_service_t *service); +bool_t rpc_gss_get_versions(u_int *vers_hi, u_int *vers_lo); +bool_t rpc_gss_is_installed(const char *mech); + +bool_t rpc_gss_set_svc_name(const char *principal, const char *mechanism, + u_int req_time, u_int program, u_int version); +void rpc_gss_clear_svc_name(u_int program, u_int version); +bool_t rpc_gss_getcred(struct svc_req *req, rpc_gss_rawcred_t **rcred, + rpc_gss_ucred_t **ucred, void 
**cookie); +bool_t rpc_gss_set_callback(rpc_gss_callback_t *cb); +void rpc_gss_clear_callback(rpc_gss_callback_t *cb); +bool_t rpc_gss_get_principal_name(rpc_gss_principal_t *principal, + const char *mech, const char *name, const char *node, const char *domain); +int rpc_gss_svc_max_data_length(struct svc_req *req, int max_tp_unit_len); + +/* + * Internal interface from the RPC implementation. + */ +#ifndef _KERNEL +bool_t __rpc_gss_wrap(AUTH *auth, void *header, size_t headerlen, + XDR* xdrs, xdrproc_t xdr_args, void *args_ptr); +bool_t __rpc_gss_unwrap(AUTH *auth, XDR* xdrs, xdrproc_t xdr_args, + void *args_ptr); +#endif +bool_t __rpc_gss_set_error(int rpc_gss_error, int system_error); + +__END_DECLS + +#endif /* !_RPCSEC_GSS_H */ diff --git a/sys/rpc/rpcsec_gss/rpcsec_gss.c b/sys/rpc/rpcsec_gss/rpcsec_gss.c new file mode 100644 index 0000000..790804d --- /dev/null +++ b/sys/rpc/rpcsec_gss/rpcsec_gss.c @@ -0,0 +1,1064 @@ +/*- + * Copyright (c) 2008 Doug Rabson + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ +/* + auth_gss.c + + RPCSEC_GSS client routines. + + Copyright (c) 2000 The Regents of the University of Michigan. + All rights reserved. + + Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>. + All rights reserved, all wrongs reversed. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of the University nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + $Id: auth_gss.c,v 1.32 2002/01/15 15:43:00 andros Exp $ +*/ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/hash.h> +#include <sys/kernel.h> +#include <sys/kobj.h> +#include <sys/lock.h> +#include <sys/malloc.h> +#include <sys/mbuf.h> +#include <sys/mutex.h> +#include <sys/proc.h> +#include <sys/refcount.h> +#include <sys/sx.h> +#include <sys/ucred.h> + +#include <rpc/rpc.h> +#include <rpc/rpcsec_gss.h> + +#include "rpcsec_gss_int.h" + +static void rpc_gss_nextverf(AUTH*); +static bool_t rpc_gss_marshal(AUTH *, uint32_t, XDR *, struct mbuf *); +static bool_t rpc_gss_init(AUTH *auth, rpc_gss_options_ret_t *options_ret); +static bool_t rpc_gss_refresh(AUTH *, void *); +static bool_t rpc_gss_validate(AUTH *, uint32_t, struct opaque_auth *, + struct mbuf **); +static void rpc_gss_destroy(AUTH *); +static void rpc_gss_destroy_context(AUTH *, bool_t); + +static struct auth_ops rpc_gss_ops = { + rpc_gss_nextverf, + rpc_gss_marshal, + rpc_gss_validate, + rpc_gss_refresh, + rpc_gss_destroy, +}; + +enum rpcsec_gss_state { + RPCSEC_GSS_START, + RPCSEC_GSS_CONTEXT, + RPCSEC_GSS_ESTABLISHED, + RPCSEC_GSS_DESTROYING +}; + +struct rpc_pending_request { + uint32_t pr_xid; /* XID of rpc */ + uint32_t pr_seq; /* matching GSS seq */ + LIST_ENTRY(rpc_pending_request) pr_link; +}; +LIST_HEAD(rpc_pending_request_list, rpc_pending_request); + +struct rpc_gss_data { + volatile u_int gd_refs; /* number of 
current users */ + struct mtx gd_lock; + uint32_t gd_hash; + AUTH *gd_auth; /* link back to AUTH */ + struct ucred *gd_ucred; /* matching local cred */ + char *gd_principal; /* server principal name */ + rpc_gss_options_req_t gd_options; /* GSS context options */ + enum rpcsec_gss_state gd_state; /* connection state */ + gss_buffer_desc gd_verf; /* save GSS_S_COMPLETE + * NULL RPC verfier to + * process at end of + * context negotiation */ + CLIENT *gd_clnt; /* client handle */ + gss_OID gd_mech; /* mechanism to use */ + gss_qop_t gd_qop; /* quality of protection */ + gss_ctx_id_t gd_ctx; /* context id */ + struct rpc_gss_cred gd_cred; /* client credentials */ + uint32_t gd_seq; /* next sequence number */ + u_int gd_win; /* sequence window */ + struct rpc_pending_request_list gd_reqs; + TAILQ_ENTRY(rpc_gss_data) gd_link; + TAILQ_ENTRY(rpc_gss_data) gd_alllink; +}; +TAILQ_HEAD(rpc_gss_data_list, rpc_gss_data); + +#define AUTH_PRIVATE(auth) ((struct rpc_gss_data *)auth->ah_private) + +static struct timeval AUTH_TIMEOUT = { 25, 0 }; + +#define RPC_GSS_HASH_SIZE 11 +#define RPC_GSS_MAX 256 +static struct rpc_gss_data_list rpc_gss_cache[RPC_GSS_HASH_SIZE]; +static struct rpc_gss_data_list rpc_gss_all; +static struct sx rpc_gss_lock; +static int rpc_gss_count; + +static AUTH *rpc_gss_seccreate_int(CLIENT *, struct ucred *, const char *, + gss_OID, rpc_gss_service_t, u_int, rpc_gss_options_req_t *, + rpc_gss_options_ret_t *); + +static void +rpc_gss_hashinit(void *dummy) +{ + int i; + + for (i = 0; i < RPC_GSS_HASH_SIZE; i++) + TAILQ_INIT(&rpc_gss_cache[i]); + TAILQ_INIT(&rpc_gss_all); + sx_init(&rpc_gss_lock, "rpc_gss_lock"); +} +SYSINIT(rpc_gss_hashinit, SI_SUB_KMEM, SI_ORDER_ANY, rpc_gss_hashinit, NULL); + +static uint32_t +rpc_gss_hash(const char *principal, gss_OID mech, + struct ucred *cred, rpc_gss_service_t service) +{ + uint32_t h; + + h = HASHSTEP(HASHINIT, cred->cr_uid); + h = hash32_str(principal, h); + h = hash32_buf(mech->elements, mech->length, h); + h = 
HASHSTEP(h, (int) service); + + return (h % RPC_GSS_HASH_SIZE); +} + +/* + * Simplified interface to create a security association for the + * current thread's * ucred. + */ +AUTH * +rpc_gss_secfind(CLIENT *clnt, struct ucred *cred, const char *principal, + gss_OID mech_oid, rpc_gss_service_t service) +{ + uint32_t h, th; + AUTH *auth; + struct rpc_gss_data *gd, *tgd; + + if (rpc_gss_count > RPC_GSS_MAX) { + while (rpc_gss_count > RPC_GSS_MAX) { + sx_xlock(&rpc_gss_lock); + tgd = TAILQ_FIRST(&rpc_gss_all); + th = tgd->gd_hash; + TAILQ_REMOVE(&rpc_gss_cache[th], tgd, gd_link); + TAILQ_REMOVE(&rpc_gss_all, tgd, gd_alllink); + rpc_gss_count--; + sx_xunlock(&rpc_gss_lock); + AUTH_DESTROY(tgd->gd_auth); + } + } + + /* + * See if we already have an AUTH which matches. + */ + h = rpc_gss_hash(principal, mech_oid, cred, service); + +again: + sx_slock(&rpc_gss_lock); + TAILQ_FOREACH(gd, &rpc_gss_cache[h], gd_link) { + if (gd->gd_ucred->cr_uid == cred->cr_uid + && !strcmp(gd->gd_principal, principal) + && gd->gd_mech == mech_oid + && gd->gd_cred.gc_svc == service) { + refcount_acquire(&gd->gd_refs); + if (sx_try_upgrade(&rpc_gss_lock)) { + /* + * Keep rpc_gss_all LRU sorted. + */ + TAILQ_REMOVE(&rpc_gss_all, gd, gd_alllink); + TAILQ_INSERT_TAIL(&rpc_gss_all, gd, + gd_alllink); + sx_xunlock(&rpc_gss_lock); + } else { + sx_sunlock(&rpc_gss_lock); + } + return (gd->gd_auth); + } + } + sx_sunlock(&rpc_gss_lock); + + /* + * We missed in the cache - create a new association. + */ + auth = rpc_gss_seccreate_int(clnt, cred, principal, mech_oid, service, + GSS_C_QOP_DEFAULT, NULL, NULL); + if (!auth) + return (NULL); + + gd = AUTH_PRIVATE(auth); + gd->gd_hash = h; + + sx_xlock(&rpc_gss_lock); + TAILQ_FOREACH(tgd, &rpc_gss_cache[h], gd_link) { + if (tgd->gd_ucred->cr_uid == cred->cr_uid + && !strcmp(tgd->gd_principal, principal) + && tgd->gd_mech == mech_oid + && tgd->gd_cred.gc_svc == service) { + /* + * We lost a race to create the AUTH that + * matches this cred. 
+ */ + sx_xunlock(&rpc_gss_lock); + AUTH_DESTROY(auth); + goto again; + } + } + + rpc_gss_count++; + TAILQ_INSERT_TAIL(&rpc_gss_cache[h], gd, gd_link); + TAILQ_INSERT_TAIL(&rpc_gss_all, gd, gd_alllink); + refcount_acquire(&gd->gd_refs); /* one for the cache, one for user */ + sx_xunlock(&rpc_gss_lock); + + return (auth); +} + +void +rpc_gss_secpurge(CLIENT *clnt) +{ + uint32_t h; + struct rpc_gss_data *gd, *tgd; + + TAILQ_FOREACH_SAFE(gd, &rpc_gss_all, gd_alllink, tgd) { + if (gd->gd_clnt == clnt) { + sx_xlock(&rpc_gss_lock); + h = gd->gd_hash; + TAILQ_REMOVE(&rpc_gss_cache[h], gd, gd_link); + TAILQ_REMOVE(&rpc_gss_all, gd, gd_alllink); + rpc_gss_count--; + sx_xunlock(&rpc_gss_lock); + AUTH_DESTROY(gd->gd_auth); + } + } +} + +AUTH * +rpc_gss_seccreate(CLIENT *clnt, struct ucred *cred, const char *principal, + const char *mechanism, rpc_gss_service_t service, const char *qop, + rpc_gss_options_req_t *options_req, rpc_gss_options_ret_t *options_ret) +{ + gss_OID oid; + u_int qop_num; + + /* + * Bail out now if we don't know this mechanism. + */ + if (!rpc_gss_mech_to_oid(mechanism, &oid)) + return (NULL); + + if (qop) { + if (!rpc_gss_qop_to_num(qop, mechanism, &qop_num)) + return (NULL); + } else { + qop_num = GSS_C_QOP_DEFAULT; + } + + return (rpc_gss_seccreate_int(clnt, cred, principal, oid, service, + qop_num, options_req, options_ret)); +} + +static AUTH * +rpc_gss_seccreate_int(CLIENT *clnt, struct ucred *cred, const char *principal, + gss_OID mech_oid, rpc_gss_service_t service, u_int qop_num, + rpc_gss_options_req_t *options_req, rpc_gss_options_ret_t *options_ret) +{ + AUTH *auth; + rpc_gss_options_ret_t options; + struct rpc_gss_data *gd; + + /* + * If the caller doesn't want the options, point at local + * storage to simplify the code below. + */ + if (!options_ret) + options_ret = &options; + + /* + * Default service is integrity. 
+ */ + if (service == rpc_gss_svc_default) + service = rpc_gss_svc_integrity; + + memset(options_ret, 0, sizeof(*options_ret)); + + rpc_gss_log_debug("in rpc_gss_seccreate()"); + + memset(&rpc_createerr, 0, sizeof(rpc_createerr)); + + auth = mem_alloc(sizeof(*auth)); + if (auth == NULL) { + rpc_createerr.cf_stat = RPC_SYSTEMERROR; + rpc_createerr.cf_error.re_errno = ENOMEM; + return (NULL); + } + gd = mem_alloc(sizeof(*gd)); + if (gd == NULL) { + rpc_createerr.cf_stat = RPC_SYSTEMERROR; + rpc_createerr.cf_error.re_errno = ENOMEM; + mem_free(auth, sizeof(*auth)); + return (NULL); + } + + auth->ah_ops = &rpc_gss_ops; + auth->ah_private = (caddr_t) gd; + auth->ah_cred.oa_flavor = RPCSEC_GSS; + + refcount_init(&gd->gd_refs, 1); + mtx_init(&gd->gd_lock, "gd->gd_lock", NULL, MTX_DEF); + gd->gd_auth = auth; + gd->gd_ucred = crdup(cred); + gd->gd_principal = strdup(principal, M_RPC); + + + if (options_req) { + gd->gd_options = *options_req; + } else { + gd->gd_options.req_flags = GSS_C_MUTUAL_FLAG; + gd->gd_options.time_req = 0; + gd->gd_options.my_cred = GSS_C_NO_CREDENTIAL; + gd->gd_options.input_channel_bindings = NULL; + } + CLNT_ACQUIRE(clnt); + gd->gd_clnt = clnt; + gd->gd_ctx = GSS_C_NO_CONTEXT; + gd->gd_mech = mech_oid; + gd->gd_qop = qop_num; + + gd->gd_cred.gc_version = RPCSEC_GSS_VERSION; + gd->gd_cred.gc_proc = RPCSEC_GSS_INIT; + gd->gd_cred.gc_seq = 0; + gd->gd_cred.gc_svc = service; + LIST_INIT(&gd->gd_reqs); + + if (!rpc_gss_init(auth, options_ret)) { + goto bad; + } + + return (auth); + + bad: + AUTH_DESTROY(auth); + return (NULL); +} + +bool_t +rpc_gss_set_defaults(AUTH *auth, rpc_gss_service_t service, const char *qop) +{ + struct rpc_gss_data *gd; + u_int qop_num; + const char *mechanism; + + gd = AUTH_PRIVATE(auth); + if (!rpc_gss_oid_to_mech(gd->gd_mech, &mechanism)) { + return (FALSE); + } + + if (qop) { + if (!rpc_gss_qop_to_num(qop, mechanism, &qop_num)) { + return (FALSE); + } + } else { + qop_num = GSS_C_QOP_DEFAULT; + } + + gd->gd_cred.gc_svc = 
service; + gd->gd_qop = qop_num; + return (TRUE); +} + +static void +rpc_gss_purge_xid(struct rpc_gss_data *gd, uint32_t xid) +{ + struct rpc_pending_request *pr, *npr; + struct rpc_pending_request_list reqs; + + LIST_INIT(&reqs); + mtx_lock(&gd->gd_lock); + LIST_FOREACH_SAFE(pr, &gd->gd_reqs, pr_link, npr) { + if (pr->pr_xid == xid) { + LIST_REMOVE(pr, pr_link); + LIST_INSERT_HEAD(&reqs, pr, pr_link); + } + } + + mtx_unlock(&gd->gd_lock); + + LIST_FOREACH_SAFE(pr, &reqs, pr_link, npr) { + mem_free(pr, sizeof(*pr)); + } +} + +static uint32_t +rpc_gss_alloc_seq(struct rpc_gss_data *gd) +{ + uint32_t seq; + + mtx_lock(&gd->gd_lock); + seq = gd->gd_seq; + gd->gd_seq++; + mtx_unlock(&gd->gd_lock); + + return (seq); +} + +static void +rpc_gss_nextverf(__unused AUTH *auth) +{ + + /* not used */ +} + +static bool_t +rpc_gss_marshal(AUTH *auth, uint32_t xid, XDR *xdrs, struct mbuf *args) +{ + struct rpc_gss_data *gd; + struct rpc_pending_request *pr; + uint32_t seq; + XDR tmpxdrs; + struct rpc_gss_cred gsscred; + char credbuf[MAX_AUTH_BYTES]; + struct opaque_auth creds, verf; + gss_buffer_desc rpcbuf, checksum; + OM_uint32 maj_stat, min_stat; + bool_t xdr_stat; + + rpc_gss_log_debug("in rpc_gss_marshal()"); + + gd = AUTH_PRIVATE(auth); + + gsscred = gd->gd_cred; + seq = rpc_gss_alloc_seq(gd); + gsscred.gc_seq = seq; + + xdrmem_create(&tmpxdrs, credbuf, sizeof(credbuf), XDR_ENCODE); + if (!xdr_rpc_gss_cred(&tmpxdrs, &gsscred)) { + XDR_DESTROY(&tmpxdrs); + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOMEM); + return (FALSE); + } + creds.oa_flavor = RPCSEC_GSS; + creds.oa_base = credbuf; + creds.oa_length = XDR_GETPOS(&tmpxdrs); + XDR_DESTROY(&tmpxdrs); + + xdr_opaque_auth(xdrs, &creds); + + if (gd->gd_cred.gc_proc == RPCSEC_GSS_INIT || + gd->gd_cred.gc_proc == RPCSEC_GSS_CONTINUE_INIT) { + if (!xdr_opaque_auth(xdrs, &_null_auth)) { + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOMEM); + return (FALSE); + } + xdrmbuf_append(xdrs, args); + return (TRUE); + } else { + /* + * 
Keep track of this XID + seq pair so that we can do + * the matching gss_verify_mic in AUTH_VALIDATE. + */ + pr = mem_alloc(sizeof(struct rpc_pending_request)); + mtx_lock(&gd->gd_lock); + pr->pr_xid = xid; + pr->pr_seq = seq; + LIST_INSERT_HEAD(&gd->gd_reqs, pr, pr_link); + mtx_unlock(&gd->gd_lock); + + /* + * Checksum serialized RPC header, up to and including + * credential. For the in-kernel environment, we + * assume that our XDR stream is on a contiguous + * memory buffer (e.g. an mbuf). + */ + rpcbuf.length = XDR_GETPOS(xdrs); + XDR_SETPOS(xdrs, 0); + rpcbuf.value = XDR_INLINE(xdrs, rpcbuf.length); + + maj_stat = gss_get_mic(&min_stat, gd->gd_ctx, gd->gd_qop, + &rpcbuf, &checksum); + + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_get_mic", gd->gd_mech, + maj_stat, min_stat); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) { + rpc_gss_destroy_context(auth, TRUE); + } + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, EPERM); + return (FALSE); + } + + verf.oa_flavor = RPCSEC_GSS; + verf.oa_base = checksum.value; + verf.oa_length = checksum.length; + + xdr_stat = xdr_opaque_auth(xdrs, &verf); + gss_release_buffer(&min_stat, &checksum); + if (!xdr_stat) { + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOMEM); + return (FALSE); + } + if (gd->gd_state != RPCSEC_GSS_ESTABLISHED || + gd->gd_cred.gc_svc == rpc_gss_svc_none) { + xdrmbuf_append(xdrs, args); + return (TRUE); + } else { + if (!xdr_rpc_gss_wrap_data(&args, + gd->gd_ctx, gd->gd_qop, gd->gd_cred.gc_svc, + seq)) + return (FALSE); + xdrmbuf_append(xdrs, args); + return (TRUE); + } + } + + return (TRUE); +} + +static bool_t +rpc_gss_validate(AUTH *auth, uint32_t xid, struct opaque_auth *verf, + struct mbuf **resultsp) +{ + struct rpc_gss_data *gd; + struct rpc_pending_request *pr, *npr; + struct rpc_pending_request_list reqs; + gss_qop_t qop_state; + uint32_t num, seq; + gss_buffer_desc signbuf, checksum; + OM_uint32 maj_stat, min_stat; + + rpc_gss_log_debug("in rpc_gss_validate()"); + + gd = 
AUTH_PRIVATE(auth); + + /* + * The client will call us with a NULL verf when it gives up + * on an XID. + */ + if (!verf) { + rpc_gss_purge_xid(gd, xid); + return (TRUE); + } + + if (gd->gd_state == RPCSEC_GSS_CONTEXT) { + /* + * Save the on the wire verifier to validate last INIT + * phase packet after decode if the major status is + * GSS_S_COMPLETE. + */ + if (gd->gd_verf.value) + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &gd->gd_verf); + gd->gd_verf.value = mem_alloc(verf->oa_length); + if (gd->gd_verf.value == NULL) { + printf("gss_validate: out of memory\n"); + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOMEM); + m_freem(*resultsp); + *resultsp = NULL; + return (FALSE); + } + memcpy(gd->gd_verf.value, verf->oa_base, verf->oa_length); + gd->gd_verf.length = verf->oa_length; + + return (TRUE); + } + + /* + * We need to check the verifier against all the requests + * we've send for this XID - for unreliable protocols, we + * retransmit with the same XID but different sequence + * number. We temporarily take this set of requests out of the + * list so that we can work through the list without having to + * hold the lock. 
+ */ + mtx_lock(&gd->gd_lock); + LIST_INIT(&reqs); + LIST_FOREACH_SAFE(pr, &gd->gd_reqs, pr_link, npr) { + if (pr->pr_xid == xid) { + LIST_REMOVE(pr, pr_link); + LIST_INSERT_HEAD(&reqs, pr, pr_link); + } + } + mtx_unlock(&gd->gd_lock); + LIST_FOREACH(pr, &reqs, pr_link) { + if (pr->pr_xid == xid) { + seq = pr->pr_seq; + num = htonl(seq); + signbuf.value = # + signbuf.length = sizeof(num); + + checksum.value = verf->oa_base; + checksum.length = verf->oa_length; + + maj_stat = gss_verify_mic(&min_stat, gd->gd_ctx, + &signbuf, &checksum, &qop_state); + if (maj_stat != GSS_S_COMPLETE + || qop_state != gd->gd_qop) { + continue; + } + if (maj_stat == GSS_S_CONTEXT_EXPIRED) { + rpc_gss_destroy_context(auth, TRUE); + break; + } + //rpc_gss_purge_reqs(gd, seq); + LIST_FOREACH_SAFE(pr, &reqs, pr_link, npr) + mem_free(pr, sizeof(*pr)); + + if (gd->gd_cred.gc_svc == rpc_gss_svc_none) { + return (TRUE); + } else { + if (!xdr_rpc_gss_unwrap_data(resultsp, + gd->gd_ctx, gd->gd_qop, + gd->gd_cred.gc_svc, seq)) { + return (FALSE); + } + } + return (TRUE); + } + } + + /* + * We didn't match - put back any entries for this XID so that + * a future call to validate can retry. + */ + mtx_lock(&gd->gd_lock); + LIST_FOREACH_SAFE(pr, &reqs, pr_link, npr) { + LIST_REMOVE(pr, pr_link); + LIST_INSERT_HEAD(&gd->gd_reqs, pr, pr_link); + } + mtx_unlock(&gd->gd_lock); + + /* + * Nothing matches - give up. 
+ */ + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, EPERM); + m_freem(*resultsp); + *resultsp = NULL; + return (FALSE); +} + +static bool_t +rpc_gss_init(AUTH *auth, rpc_gss_options_ret_t *options_ret) +{ + struct thread *td = curthread; + struct ucred *crsave; + struct rpc_gss_data *gd; + struct rpc_gss_init_res gr; + gss_buffer_desc principal_desc; + gss_buffer_desc *recv_tokenp, recv_token, send_token; + gss_name_t name; + OM_uint32 maj_stat, min_stat, call_stat; + const char *mech; + struct rpc_callextra ext; + + rpc_gss_log_debug("in rpc_gss_refresh()"); + + gd = AUTH_PRIVATE(auth); + + mtx_lock(&gd->gd_lock); + /* + * If the context isn't in START state, someone else is + * refreshing - we wait till they are done. If they fail, they + * will put the state back to START and we can try (most + * likely to also fail). + */ + while (gd->gd_state != RPCSEC_GSS_START + && gd->gd_state != RPCSEC_GSS_ESTABLISHED) { + msleep(gd, &gd->gd_lock, 0, "gssstate", 0); + } + if (gd->gd_state == RPCSEC_GSS_ESTABLISHED) { + mtx_unlock(&gd->gd_lock); + return (TRUE); + } + gd->gd_state = RPCSEC_GSS_CONTEXT; + mtx_unlock(&gd->gd_lock); + + principal_desc.value = (void *)gd->gd_principal; + principal_desc.length = strlen(gd->gd_principal); + maj_stat = gss_import_name(&min_stat, &principal_desc, + GSS_C_NT_HOSTBASED_SERVICE, &name); + if (maj_stat != GSS_S_COMPLETE) { + options_ret->major_status = maj_stat; + options_ret->minor_status = min_stat; + goto out; + } + + /* GSS context establishment loop. 
*/ + gd->gd_cred.gc_proc = RPCSEC_GSS_INIT; + gd->gd_cred.gc_seq = 0; + + memset(&recv_token, 0, sizeof(recv_token)); + memset(&gr, 0, sizeof(gr)); + memset(options_ret, 0, sizeof(*options_ret)); + options_ret->major_status = GSS_S_FAILURE; + recv_tokenp = GSS_C_NO_BUFFER; + + for (;;) { + crsave = td->td_ucred; + td->td_ucred = gd->gd_ucred; + maj_stat = gss_init_sec_context(&min_stat, + gd->gd_options.my_cred, + &gd->gd_ctx, + name, + gd->gd_mech, + gd->gd_options.req_flags, + gd->gd_options.time_req, + gd->gd_options.input_channel_bindings, + recv_tokenp, + &gd->gd_mech, /* used mech */ + &send_token, + &options_ret->ret_flags, + &options_ret->time_req); + td->td_ucred = crsave; + + /* + * Free the token which we got from the server (if + * any). Remember that this was allocated by XDR, not + * GSS-API. + */ + if (recv_tokenp != GSS_C_NO_BUFFER) { + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &recv_token); + recv_tokenp = GSS_C_NO_BUFFER; + } + if (gd->gd_mech && rpc_gss_oid_to_mech(gd->gd_mech, &mech)) { + strlcpy(options_ret->actual_mechanism, + mech, + sizeof(options_ret->actual_mechanism)); + } + if (maj_stat != GSS_S_COMPLETE && + maj_stat != GSS_S_CONTINUE_NEEDED) { + rpc_gss_log_status("gss_init_sec_context", gd->gd_mech, + maj_stat, min_stat); + options_ret->major_status = maj_stat; + options_ret->minor_status = min_stat; + break; + } + if (send_token.length != 0) { + memset(&gr, 0, sizeof(gr)); + + bzero(&ext, sizeof(ext)); + ext.rc_auth = auth; + call_stat = CLNT_CALL_EXT(gd->gd_clnt, &ext, NULLPROC, + (xdrproc_t)xdr_gss_buffer_desc, + &send_token, + (xdrproc_t)xdr_rpc_gss_init_res, + (caddr_t)&gr, AUTH_TIMEOUT); + + gss_release_buffer(&min_stat, &send_token); + + if (call_stat != RPC_SUCCESS) + break; + + if (gr.gr_major != GSS_S_COMPLETE && + gr.gr_major != GSS_S_CONTINUE_NEEDED) { + rpc_gss_log_status("server reply", gd->gd_mech, + gr.gr_major, gr.gr_minor); + options_ret->major_status = gr.gr_major; + options_ret->minor_status = 
gr.gr_minor; + break; + } + + /* + * Save the server's gr_handle value, freeing + * what we have already (remember that this + * was allocated by XDR, not GSS-API). + */ + if (gr.gr_handle.length != 0) { + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &gd->gd_cred.gc_handle); + gd->gd_cred.gc_handle = gr.gr_handle; + } + + /* + * Save the server's token as well. + */ + if (gr.gr_token.length != 0) { + recv_token = gr.gr_token; + recv_tokenp = &recv_token; + } + + /* + * Since we have copied out all the bits of gr + * which XDR allocated for us, we don't need + * to free it. + */ + gd->gd_cred.gc_proc = RPCSEC_GSS_CONTINUE_INIT; + } + + if (maj_stat == GSS_S_COMPLETE) { + gss_buffer_desc bufin; + u_int seq, qop_state = 0; + + /* + * gss header verifier, + * usually checked in gss_validate + */ + seq = htonl(gr.gr_win); + bufin.value = (unsigned char *)&seq; + bufin.length = sizeof(seq); + + maj_stat = gss_verify_mic(&min_stat, gd->gd_ctx, + &bufin, &gd->gd_verf, &qop_state); + + if (maj_stat != GSS_S_COMPLETE || + qop_state != gd->gd_qop) { + rpc_gss_log_status("gss_verify_mic", gd->gd_mech, + maj_stat, min_stat); + if (maj_stat == GSS_S_CONTEXT_EXPIRED) { + rpc_gss_destroy_context(auth, TRUE); + } + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, + EPERM); + options_ret->major_status = maj_stat; + options_ret->minor_status = min_stat; + break; + } + + options_ret->major_status = GSS_S_COMPLETE; + options_ret->minor_status = 0; + options_ret->rpcsec_version = gd->gd_cred.gc_version; + options_ret->gss_context = gd->gd_ctx; + + gd->gd_cred.gc_proc = RPCSEC_GSS_DATA; + gd->gd_seq = 1; + gd->gd_win = gr.gr_win; + break; + } + } + + gss_release_name(&min_stat, &name); + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &gd->gd_verf); + +out: + /* End context negotiation loop. 
*/ + if (gd->gd_cred.gc_proc != RPCSEC_GSS_DATA) { + rpc_createerr.cf_stat = RPC_AUTHERROR; + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, EPERM); + if (gd->gd_ctx) { + gss_delete_sec_context(&min_stat, &gd->gd_ctx, + GSS_C_NO_BUFFER); + } + mtx_lock(&gd->gd_lock); + gd->gd_state = RPCSEC_GSS_START; + wakeup(gd); + mtx_unlock(&gd->gd_lock); + return (FALSE); + } + + mtx_lock(&gd->gd_lock); + gd->gd_state = RPCSEC_GSS_ESTABLISHED; + wakeup(gd); + mtx_unlock(&gd->gd_lock); + + return (TRUE); +} + +static bool_t +rpc_gss_refresh(AUTH *auth, void *msg) +{ + struct rpc_msg *reply = (struct rpc_msg *) msg; + rpc_gss_options_ret_t options; + + /* + * If the error was RPCSEC_GSS_CREDPROBLEM or + * RPCSEC_GSS_CTXPROBLEM we start again from scratch. All + * other errors are fatal. + */ + if (reply->rm_reply.rp_stat == MSG_DENIED + && reply->rm_reply.rp_rjct.rj_stat == AUTH_ERROR + && (reply->rm_reply.rp_rjct.rj_why == RPCSEC_GSS_CREDPROBLEM + || reply->rm_reply.rp_rjct.rj_why == RPCSEC_GSS_CTXPROBLEM)) { + rpc_gss_destroy_context(auth, FALSE); + memset(&options, 0, sizeof(options)); + return (rpc_gss_init(auth, &options)); + } + + return (FALSE); +} + +static void +rpc_gss_destroy_context(AUTH *auth, bool_t send_destroy) +{ + struct rpc_gss_data *gd; + struct rpc_pending_request *pr; + OM_uint32 min_stat; + struct rpc_callextra ext; + + rpc_gss_log_debug("in rpc_gss_destroy_context()"); + + gd = AUTH_PRIVATE(auth); + + mtx_lock(&gd->gd_lock); + /* + * If the context isn't in ESTABLISHED state, someone else is + * destroying/refreshing - we wait till they are done. 
+ */ + if (gd->gd_state != RPCSEC_GSS_ESTABLISHED) { + while (gd->gd_state != RPCSEC_GSS_START + && gd->gd_state != RPCSEC_GSS_ESTABLISHED) + msleep(gd, &gd->gd_lock, 0, "gssstate", 0); + mtx_unlock(&gd->gd_lock); + return; + } + gd->gd_state = RPCSEC_GSS_DESTROYING; + mtx_unlock(&gd->gd_lock); + + if (send_destroy) { + gd->gd_cred.gc_proc = RPCSEC_GSS_DESTROY; + bzero(&ext, sizeof(ext)); + ext.rc_auth = auth; + CLNT_CALL_EXT(gd->gd_clnt, &ext, NULLPROC, + (xdrproc_t)xdr_void, NULL, + (xdrproc_t)xdr_void, NULL, AUTH_TIMEOUT); + } + + while ((pr = LIST_FIRST(&gd->gd_reqs)) != NULL) { + LIST_REMOVE(pr, pr_link); + mem_free(pr, sizeof(*pr)); + } + + /* + * Free the context token. Remember that this was + * allocated by XDR, not GSS-API. + */ + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &gd->gd_cred.gc_handle); + gd->gd_cred.gc_handle.length = 0; + + if (gd->gd_ctx != GSS_C_NO_CONTEXT) + gss_delete_sec_context(&min_stat, &gd->gd_ctx, NULL); + + mtx_lock(&gd->gd_lock); + gd->gd_state = RPCSEC_GSS_START; + wakeup(gd); + mtx_unlock(&gd->gd_lock); +} + +static void +rpc_gss_destroy(AUTH *auth) +{ + struct rpc_gss_data *gd; + + rpc_gss_log_debug("in rpc_gss_destroy()"); + + gd = AUTH_PRIVATE(auth); + + if (!refcount_release(&gd->gd_refs)) + return; + + rpc_gss_destroy_context(auth, TRUE); + + CLNT_RELEASE(gd->gd_clnt); + crfree(gd->gd_ucred); + free(gd->gd_principal, M_RPC); + if (gd->gd_verf.value) + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &gd->gd_verf); + mtx_destroy(&gd->gd_lock); + + mem_free(gd, sizeof(*gd)); + mem_free(auth, sizeof(*auth)); +} + +int +rpc_gss_max_data_length(AUTH *auth, int max_tp_unit_len) +{ + struct rpc_gss_data *gd; + int want_conf; + OM_uint32 max; + OM_uint32 maj_stat, min_stat; + int result; + + gd = AUTH_PRIVATE(auth); + + switch (gd->gd_cred.gc_svc) { + case rpc_gss_svc_none: + return (max_tp_unit_len); + break; + + case rpc_gss_svc_default: + case rpc_gss_svc_integrity: + want_conf = FALSE; + break; + + case 
rpc_gss_svc_privacy: + want_conf = TRUE; + break; + + default: + return (0); + } + + maj_stat = gss_wrap_size_limit(&min_stat, gd->gd_ctx, want_conf, + gd->gd_qop, max_tp_unit_len, &max); + + if (maj_stat == GSS_S_COMPLETE) { + result = (int) max; + if (result < 0) + result = 0; + return (result); + } else { + rpc_gss_log_status("gss_wrap_size_limit", gd->gd_mech, + maj_stat, min_stat); + return (0); + } +} diff --git a/sys/rpc/rpcsec_gss/rpcsec_gss_conf.c b/sys/rpc/rpcsec_gss/rpcsec_gss_conf.c new file mode 100644 index 0000000..b5e99d4 --- /dev/null +++ b/sys/rpc/rpcsec_gss/rpcsec_gss_conf.c @@ -0,0 +1,163 @@ +/*- + * Copyright (c) 2008 Doug Rabson + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. 
+ */ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/kobj.h> +#include <sys/lock.h> +#include <sys/malloc.h> +#include <sys/mutex.h> + +#include <rpc/rpc.h> +#include <rpc/rpcsec_gss.h> + +#include "rpcsec_gss_int.h" + +bool_t +rpc_gss_mech_to_oid(const char *mech, gss_OID *oid_ret) +{ + gss_OID oid = kgss_find_mech_by_name(mech); + + if (oid) { + *oid_ret = oid; + return (TRUE); + } + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOENT); + return (FALSE); +} + +bool_t +rpc_gss_oid_to_mech(gss_OID oid, const char **mech_ret) +{ + const char *name = kgss_find_mech_by_oid(oid); + + if (name) { + *mech_ret = name; + return (TRUE); + } + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOENT); + return (FALSE); +} + +bool_t +rpc_gss_qop_to_num(const char *qop, const char *mech, u_int *num_ret) +{ + + if (!strcmp(qop, "default")) { + *num_ret = GSS_C_QOP_DEFAULT; + return (TRUE); + } + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOENT); + return (FALSE); +} + +const char * +_rpc_gss_num_to_qop(const char *mech, u_int num) +{ + + if (num == GSS_C_QOP_DEFAULT) + return "default"; + + return (NULL); +} + +const char ** +rpc_gss_get_mechanisms(void) +{ + static const char **mech_names = NULL; + struct kgss_mech *km; + int count; + + if (mech_names) + return (mech_names); + + count = 0; + LIST_FOREACH(km, &kgss_mechs, km_link) { + count++; + } + count++; + + mech_names = malloc(count * sizeof(const char *), M_RPC, M_WAITOK); + count = 0; + LIST_FOREACH(km, &kgss_mechs, km_link) { + mech_names[count++] = km->km_mech_name; + } + mech_names[count++] = NULL; + + return (mech_names); +} + +#if 0 +const char ** +rpc_gss_get_mech_info(const char *mech, rpc_gss_service_t *service) +{ + struct mech_info *info; + + _rpc_gss_load_mech(); + _rpc_gss_load_qop(); + SLIST_FOREACH(info, &mechs, link) { + if (!strcmp(mech, info->name)) { + /* + * I'm not sure what to do with service + * here. 
The Solaris manpages are not clear on + * the subject and the OpenSolaris code just + * sets it to rpc_gss_svc_privacy + * unconditionally with a comment noting that + * it is bogus. + */ + *service = rpc_gss_svc_privacy; + return info->qops; + } + } + + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOENT); + return (NULL); +} +#endif + +bool_t +rpc_gss_get_versions(u_int *vers_hi, u_int *vers_lo) +{ + + *vers_hi = 1; + *vers_lo = 1; + return (TRUE); +} + +bool_t +rpc_gss_is_installed(const char *mech) +{ + gss_OID oid = kgss_find_mech_by_name(mech); + + if (oid) + return (TRUE); + else + return (FALSE); +} + diff --git a/sys/rpc/rpcsec_gss/rpcsec_gss_int.h b/sys/rpc/rpcsec_gss/rpcsec_gss_int.h new file mode 100644 index 0000000..4f38828 --- /dev/null +++ b/sys/rpc/rpcsec_gss/rpcsec_gss_int.h @@ -0,0 +1,94 @@ +/* + rpcsec_gss.h + + Copyright (c) 2000 The Regents of the University of Michigan. + All rights reserved. + + Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>. + All rights reserved, all wrongs reversed. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of the University nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + $Id: auth_gss.h,v 1.12 2001/04/30 19:44:47 andros Exp $ +*/ +/* $FreeBSD$ */ + +#ifndef _RPCSEC_GSS_INT_H +#define _RPCSEC_GSS_INT_H + +#include <kgssapi/gssapi_impl.h> + +/* RPCSEC_GSS control procedures. */ +typedef enum { + RPCSEC_GSS_DATA = 0, + RPCSEC_GSS_INIT = 1, + RPCSEC_GSS_CONTINUE_INIT = 2, + RPCSEC_GSS_DESTROY = 3 +} rpc_gss_proc_t; + +#define RPCSEC_GSS_VERSION 1 + +/* Credentials. */ +struct rpc_gss_cred { + u_int gc_version; /* version */ + rpc_gss_proc_t gc_proc; /* control procedure */ + u_int gc_seq; /* sequence number */ + rpc_gss_service_t gc_svc; /* service */ + gss_buffer_desc gc_handle; /* handle to server-side context */ +}; + +/* Context creation response. */ +struct rpc_gss_init_res { + gss_buffer_desc gr_handle; /* handle to server-side context */ + u_int gr_major; /* major status */ + u_int gr_minor; /* minor status */ + u_int gr_win; /* sequence window */ + gss_buffer_desc gr_token; /* token */ +}; + +/* Maximum sequence number value. */ +#define MAXSEQ 0x80000000 + +/* Prototypes. 
*/ +__BEGIN_DECLS + +bool_t xdr_rpc_gss_cred(XDR *xdrs, struct rpc_gss_cred *p); +bool_t xdr_rpc_gss_init_res(XDR *xdrs, struct rpc_gss_init_res *p); +bool_t xdr_rpc_gss_wrap_data(struct mbuf **argsp, + gss_ctx_id_t ctx, gss_qop_t qop, rpc_gss_service_t svc, + u_int seq); +bool_t xdr_rpc_gss_unwrap_data(struct mbuf **resultsp, + gss_ctx_id_t ctx, gss_qop_t qop, rpc_gss_service_t svc, u_int seq); +const char *_rpc_gss_num_to_qop(const char *mech, u_int num); +void _rpc_gss_set_error(int rpc_gss_error, int system_error); + +void rpc_gss_log_debug(const char *fmt, ...); +void rpc_gss_log_status(const char *m, gss_OID mech, OM_uint32 major, + OM_uint32 minor); + +__END_DECLS + +#endif /* !_RPCSEC_GSS_INT_H */ diff --git a/sys/rpc/rpcsec_gss/rpcsec_gss_misc.c b/sys/rpc/rpcsec_gss/rpcsec_gss_misc.c new file mode 100644 index 0000000..5c8bf91 --- /dev/null +++ b/sys/rpc/rpcsec_gss/rpcsec_gss_misc.c @@ -0,0 +1,53 @@ +/*- + * Copyright (c) 2008 Doug Rabson + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. 
IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/kobj.h> +#include <sys/malloc.h> +#include <rpc/rpc.h> +#include <rpc/rpcsec_gss.h> + +#include "rpcsec_gss_int.h" + +static rpc_gss_error_t _rpc_gss_error; + +void +_rpc_gss_set_error(int rpc_gss_error, int system_error) +{ + + _rpc_gss_error.rpc_gss_error = rpc_gss_error; + _rpc_gss_error.system_error = system_error; +} + +void +rpc_gss_get_error(rpc_gss_error_t *error) +{ + + *error = _rpc_gss_error; +} diff --git a/sys/rpc/rpcsec_gss/rpcsec_gss_prot.c b/sys/rpc/rpcsec_gss/rpcsec_gss_prot.c new file mode 100644 index 0000000..0654a6e --- /dev/null +++ b/sys/rpc/rpcsec_gss/rpcsec_gss_prot.c @@ -0,0 +1,359 @@ +/* + rpcsec_gss_prot.c + + Copyright (c) 2000 The Regents of the University of Michigan. + All rights reserved. + + Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>. + All rights reserved, all wrongs reversed. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. 
Neither the name of the University nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + $Id: authgss_prot.c,v 1.18 2000/09/01 04:14:03 dugsong Exp $ +*/ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/kobj.h> +#include <sys/lock.h> +#include <sys/malloc.h> +#include <sys/mbuf.h> +#include <sys/mutex.h> + +#include <rpc/rpc.h> +#include <rpc/rpcsec_gss.h> + +#include "rpcsec_gss_int.h" + +#define MAX_GSS_SIZE 10240 /* XXX */ + +#if 0 /* use the one from kgssapi */ +bool_t +xdr_gss_buffer_desc(XDR *xdrs, gss_buffer_desc *p) +{ + char *val; + u_int len; + bool_t ret; + + val = p->value; + len = p->length; + ret = xdr_bytes(xdrs, &val, &len, MAX_GSS_SIZE); + p->value = val; + p->length = len; + + return (ret); +} +#endif + +bool_t +xdr_rpc_gss_cred(XDR *xdrs, struct rpc_gss_cred *p) +{ + enum_t proc, svc; + bool_t ret; + + proc = p->gc_proc; + svc = p->gc_svc; + ret = (xdr_u_int(xdrs, &p->gc_version) && + xdr_enum(xdrs, &proc) && + xdr_u_int(xdrs, &p->gc_seq) && + xdr_enum(xdrs, &svc) && + xdr_gss_buffer_desc(xdrs, &p->gc_handle)); + p->gc_proc = proc; + p->gc_svc = svc; + + return 
(ret); +} + +bool_t +xdr_rpc_gss_init_res(XDR *xdrs, struct rpc_gss_init_res *p) +{ + + return (xdr_gss_buffer_desc(xdrs, &p->gr_handle) && + xdr_u_int(xdrs, &p->gr_major) && + xdr_u_int(xdrs, &p->gr_minor) && + xdr_u_int(xdrs, &p->gr_win) && + xdr_gss_buffer_desc(xdrs, &p->gr_token)); +} + +static void +put_uint32(struct mbuf **mp, uint32_t v) +{ + struct mbuf *m = *mp; + uint32_t n; + + M_PREPEND(m, sizeof(uint32_t), M_WAIT); + n = htonl(v); + bcopy(&n, mtod(m, uint32_t *), sizeof(uint32_t)); + *mp = m; +} + +bool_t +xdr_rpc_gss_wrap_data(struct mbuf **argsp, + gss_ctx_id_t ctx, gss_qop_t qop, + rpc_gss_service_t svc, u_int seq) +{ + struct mbuf *args, *mic; + OM_uint32 maj_stat, min_stat; + int conf_state; + u_int len; + static char zpad[4]; + + args = *argsp; + + /* + * Prepend the sequence number before calling gss_get_mic or gss_wrap. + */ + put_uint32(&args, seq); + len = m_length(args, NULL); + + if (svc == rpc_gss_svc_integrity) { + /* Checksum rpc_gss_data_t. */ + maj_stat = gss_get_mic_mbuf(&min_stat, ctx, qop, args, &mic); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_debug("gss_get_mic failed"); + m_freem(args); + return (FALSE); + } + + /* + * Marshal databody_integ. Note that since args is + * already RPC encoded, there will be no padding. + */ + put_uint32(&args, len); + + /* + * Marshal checksum. This is likely to need padding. + */ + len = m_length(mic, NULL); + put_uint32(&mic, len); + if (len != RNDUP(len)) { + m_append(mic, RNDUP(len) - len, zpad); + } + + /* + * Concatenate databody_integ with checksum. + */ + m_cat(args, mic); + } else if (svc == rpc_gss_svc_privacy) { + /* Encrypt rpc_gss_data_t. */ + maj_stat = gss_wrap_mbuf(&min_stat, ctx, TRUE, qop, + &args, &conf_state); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_wrap", NULL, + maj_stat, min_stat); + return (FALSE); + } + + /* + * Marshal databody_priv and deal with RPC padding. 
+ */ + len = m_length(args, NULL); + put_uint32(&args, len); + if (len != RNDUP(len)) { + m_append(args, RNDUP(len) - len, zpad); + } + } + *argsp = args; + return (TRUE); +} + +static uint32_t +get_uint32(struct mbuf **mp) +{ + struct mbuf *m = *mp; + uint32_t n; + + if (m->m_len < sizeof(uint32_t)) { + m = m_pullup(m, sizeof(uint32_t)); + if (!m) { + *mp = NULL; + return (0); + } + } + bcopy(mtod(m, uint32_t *), &n, sizeof(uint32_t)); + m_adj(m, sizeof(uint32_t)); + *mp = m; + return (ntohl(n)); +} + +static void +m_trim(struct mbuf *m, int len) +{ + struct mbuf *n; + int off; + + n = m_getptr(m, len, &off); + if (n) { + n->m_len = off; + if (n->m_next) { + m_freem(n->m_next); + n->m_next = NULL; + } + } +} + +bool_t +xdr_rpc_gss_unwrap_data(struct mbuf **resultsp, + gss_ctx_id_t ctx, gss_qop_t qop, + rpc_gss_service_t svc, u_int seq) +{ + struct mbuf *results, *message, *mic; + uint32_t len, cklen; + OM_uint32 maj_stat, min_stat; + u_int seq_num, conf_state, qop_state; + + results = *resultsp; + *resultsp = NULL; + + message = NULL; + if (svc == rpc_gss_svc_integrity) { + /* + * Extract the seq+message part. Remember that there + * may be extra RPC padding in the checksum. The + * message part is RPC encoded already so no + * padding. + */ + len = get_uint32(&results); + message = results; + results = m_split(results, len, M_WAIT); + if (!results) { + m_freem(message); + return (FALSE); + } + + /* + * Extract the MIC and make it contiguous. + */ + cklen = get_uint32(&results); + KASSERT(cklen <= MHLEN, ("unexpected large GSS-API checksum")); + mic = results; + if (cklen > mic->m_len) + mic = m_pullup(mic, cklen); + if (cklen != RNDUP(cklen)) + m_trim(mic, cklen); + + /* Verify checksum and QOP. 
*/ + maj_stat = gss_verify_mic_mbuf(&min_stat, ctx, + message, mic, &qop_state); + m_freem(mic); + + if (maj_stat != GSS_S_COMPLETE || qop_state != qop) { + m_freem(message); + rpc_gss_log_status("gss_verify_mic", NULL, + maj_stat, min_stat); + return (FALSE); + } + } else if (svc == rpc_gss_svc_privacy) { + /* Decode databody_priv. */ + len = get_uint32(&results); + + /* Decrypt databody. */ + message = results; + if (len != RNDUP(len)) + m_trim(message, len); + maj_stat = gss_unwrap_mbuf(&min_stat, ctx, &message, + &conf_state, &qop_state); + + /* Verify encryption and QOP. */ + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_unwrap", NULL, + maj_stat, min_stat); + return (FALSE); + } + if (qop_state != qop || conf_state != TRUE) { + m_freem(results); + return (FALSE); + } + } + + /* Decode rpc_gss_data_t (sequence number + arguments). */ + seq_num = get_uint32(&message); + + /* Verify sequence number. */ + if (seq_num != seq) { + rpc_gss_log_debug("wrong sequence number in databody"); + m_freem(message); + return (FALSE); + } + + *resultsp = message; + return (TRUE); +} + +#ifdef DEBUG +#include <ctype.h> + +void +rpc_gss_log_debug(const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + fprintf(stderr, "rpcsec_gss: "); + vfprintf(stderr, fmt, ap); + fprintf(stderr, "\n"); + va_end(ap); +} + +void +rpc_gss_log_status(const char *m, gss_OID mech, OM_uint32 maj_stat, OM_uint32 min_stat) +{ + OM_uint32 min; + gss_buffer_desc msg; + int msg_ctx = 0; + + fprintf(stderr, "rpcsec_gss: %s: ", m); + + gss_display_status(&min, maj_stat, GSS_C_GSS_CODE, GSS_C_NULL_OID, + &msg_ctx, &msg); + printf("%s - ", (char *)msg.value); + gss_release_buffer(&min, &msg); + + gss_display_status(&min, min_stat, GSS_C_MECH_CODE, mech, + &msg_ctx, &msg); + printf("%s\n", (char *)msg.value); + gss_release_buffer(&min, &msg); +} + +#else + +void +rpc_gss_log_debug(__unused const char *fmt, ...) 
+{ +} + +void +rpc_gss_log_status(__unused const char *m, __unused gss_OID mech, + __unused OM_uint32 maj_stat, __unused OM_uint32 min_stat) +{ +} + +#endif + + diff --git a/sys/rpc/rpcsec_gss/svc_rpcsec_gss.c b/sys/rpc/rpcsec_gss/svc_rpcsec_gss.c new file mode 100644 index 0000000..e2469fd --- /dev/null +++ b/sys/rpc/rpcsec_gss/svc_rpcsec_gss.c @@ -0,0 +1,1485 @@ +/*- + * Copyright (c) 2008 Doug Rabson + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ +/* + svc_rpcsec_gss.c + + Copyright (c) 2000 The Regents of the University of Michigan. + All rights reserved. + + Copyright (c) 2000 Dug Song <dugsong@UMICH.EDU>. + All rights reserved, all wrongs reversed. 
+ + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions + are met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + 3. Neither the name of the University nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED + WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR + BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF + LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING + NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + $Id: svc_auth_gss.c,v 1.27 2002/01/15 15:43:00 andros Exp $ + */ + +#include <sys/cdefs.h> +__FBSDID("$FreeBSD$"); + +#include <sys/param.h> +#include <sys/systm.h> +#include <sys/kernel.h> +#include <sys/kobj.h> +#include <sys/lock.h> +#include <sys/malloc.h> +#include <sys/mbuf.h> +#include <sys/mutex.h> +#include <sys/sx.h> +#include <sys/ucred.h> + +#include <rpc/rpc.h> +#include <rpc/rpcsec_gss.h> + +#include "rpcsec_gss_int.h" + +static bool_t svc_rpc_gss_wrap(SVCAUTH *, struct mbuf **); +static bool_t svc_rpc_gss_unwrap(SVCAUTH *, struct mbuf **); +static void svc_rpc_gss_release(SVCAUTH *); +static enum auth_stat svc_rpc_gss(struct svc_req *, struct rpc_msg *); +static int rpc_gss_svc_getcred(struct svc_req *, struct ucred **, int *); + +static struct svc_auth_ops svc_auth_gss_ops = { + svc_rpc_gss_wrap, + svc_rpc_gss_unwrap, + svc_rpc_gss_release, +}; + +struct sx svc_rpc_gss_lock; + +struct svc_rpc_gss_callback { + SLIST_ENTRY(svc_rpc_gss_callback) cb_link; + rpc_gss_callback_t cb_callback; +}; +static SLIST_HEAD(svc_rpc_gss_callback_list, svc_rpc_gss_callback) + svc_rpc_gss_callbacks = SLIST_HEAD_INITIALIZER(&svc_rpc_gss_callbacks); + +struct svc_rpc_gss_svc_name { + SLIST_ENTRY(svc_rpc_gss_svc_name) sn_link; + char *sn_principal; + gss_OID sn_mech; + u_int sn_req_time; + gss_cred_id_t sn_cred; + u_int sn_program; + u_int sn_version; +}; +static SLIST_HEAD(svc_rpc_gss_svc_name_list, svc_rpc_gss_svc_name) + svc_rpc_gss_svc_names = SLIST_HEAD_INITIALIZER(&svc_rpc_gss_svc_names); + +enum svc_rpc_gss_client_state { + CLIENT_NEW, /* still authenticating */ + CLIENT_ESTABLISHED, /* context established */ + CLIENT_STALE /* garbage to collect */ +}; + +#define SVC_RPC_GSS_SEQWINDOW 128 + +struct svc_rpc_gss_clientid { + uint32_t ci_hostid; + uint32_t ci_boottime; + uint32_t ci_id; +}; + +struct svc_rpc_gss_client { + TAILQ_ENTRY(svc_rpc_gss_client) cl_link; + TAILQ_ENTRY(svc_rpc_gss_client) cl_alllink; + volatile u_int cl_refs; + struct sx cl_lock; + struct 
svc_rpc_gss_clientid cl_id; + time_t cl_expiration; /* when to gc */ + enum svc_rpc_gss_client_state cl_state; /* client state */ + bool_t cl_locked; /* fixed service+qop */ + gss_ctx_id_t cl_ctx; /* context id */ + gss_cred_id_t cl_creds; /* delegated creds */ + gss_name_t cl_cname; /* client name */ + struct svc_rpc_gss_svc_name *cl_sname; /* server name used */ + rpc_gss_rawcred_t cl_rawcred; /* raw credentials */ + rpc_gss_ucred_t cl_ucred; /* unix-style credentials */ + struct ucred *cl_cred; /* kernel-style credentials */ + int cl_rpcflavor; /* RPC pseudo sec flavor */ + bool_t cl_done_callback; /* TRUE after call */ + void *cl_cookie; /* user cookie from callback */ + gid_t cl_gid_storage[NGROUPS]; + gss_OID cl_mech; /* mechanism */ + gss_qop_t cl_qop; /* quality of protection */ + uint32_t cl_seqlast; /* sequence window origin */ + uint32_t cl_seqmask[SVC_RPC_GSS_SEQWINDOW/32]; /* bitmask of seqnums */ +}; +TAILQ_HEAD(svc_rpc_gss_client_list, svc_rpc_gss_client); + +/* + * This structure holds enough information to unwrap arguments or wrap + * results for a given request. We use the rq_clntcred area for this + * (which is a per-request buffer). 
+ */ +struct svc_rpc_gss_cookedcred { + struct svc_rpc_gss_client *cc_client; + rpc_gss_service_t cc_service; + uint32_t cc_seq; +}; + +#define CLIENT_HASH_SIZE 256 +#define CLIENT_MAX 128 +struct svc_rpc_gss_client_list svc_rpc_gss_client_hash[CLIENT_HASH_SIZE]; +struct svc_rpc_gss_client_list svc_rpc_gss_clients; +static size_t svc_rpc_gss_client_count; +static uint32_t svc_rpc_gss_next_clientid = 1; + +static void +svc_rpc_gss_init(void *arg) +{ + int i; + + for (i = 0; i < CLIENT_HASH_SIZE; i++) + TAILQ_INIT(&svc_rpc_gss_client_hash[i]); + TAILQ_INIT(&svc_rpc_gss_clients); + svc_auth_reg(RPCSEC_GSS, svc_rpc_gss, rpc_gss_svc_getcred); + sx_init(&svc_rpc_gss_lock, "gsslock"); +} +SYSINIT(svc_rpc_gss_init, SI_SUB_KMEM, SI_ORDER_ANY, svc_rpc_gss_init, NULL); + +bool_t +rpc_gss_set_callback(rpc_gss_callback_t *cb) +{ + struct svc_rpc_gss_callback *scb; + + scb = mem_alloc(sizeof(struct svc_rpc_gss_callback)); + if (!scb) { + _rpc_gss_set_error(RPC_GSS_ER_SYSTEMERROR, ENOMEM); + return (FALSE); + } + scb->cb_callback = *cb; + sx_xlock(&svc_rpc_gss_lock); + SLIST_INSERT_HEAD(&svc_rpc_gss_callbacks, scb, cb_link); + sx_xunlock(&svc_rpc_gss_lock); + + return (TRUE); +} + +void +rpc_gss_clear_callback(rpc_gss_callback_t *cb) +{ + struct svc_rpc_gss_callback *scb; + + sx_xlock(&svc_rpc_gss_lock); + SLIST_FOREACH(scb, &svc_rpc_gss_callbacks, cb_link) { + if (scb->cb_callback.program == cb->program + && scb->cb_callback.version == cb->version + && scb->cb_callback.callback == cb->callback) { + SLIST_REMOVE(&svc_rpc_gss_callbacks, scb, + svc_rpc_gss_callback, cb_link); + sx_xunlock(&svc_rpc_gss_lock); + mem_free(scb, sizeof(*scb)); + return; + } + } + sx_xunlock(&svc_rpc_gss_lock); +} + +static bool_t +rpc_gss_acquire_svc_cred(struct svc_rpc_gss_svc_name *sname) +{ + OM_uint32 maj_stat, min_stat; + gss_buffer_desc namebuf; + gss_name_t name; + gss_OID_set_desc oid_set; + + oid_set.count = 1; + oid_set.elements = sname->sn_mech; + + namebuf.value = (void *) 
sname->sn_principal; + namebuf.length = strlen(sname->sn_principal); + + maj_stat = gss_import_name(&min_stat, &namebuf, + GSS_C_NT_HOSTBASED_SERVICE, &name); + if (maj_stat != GSS_S_COMPLETE) + return (FALSE); + + if (sname->sn_cred != GSS_C_NO_CREDENTIAL) + gss_release_cred(&min_stat, &sname->sn_cred); + + maj_stat = gss_acquire_cred(&min_stat, name, + sname->sn_req_time, &oid_set, GSS_C_ACCEPT, &sname->sn_cred, + NULL, NULL); + if (maj_stat != GSS_S_COMPLETE) { + gss_release_name(&min_stat, &name); + return (FALSE); + } + gss_release_name(&min_stat, &name); + + return (TRUE); +} + +bool_t +rpc_gss_set_svc_name(const char *principal, const char *mechanism, + u_int req_time, u_int program, u_int version) +{ + struct svc_rpc_gss_svc_name *sname; + gss_OID mech_oid; + + if (!rpc_gss_mech_to_oid(mechanism, &mech_oid)) + return (FALSE); + + sname = mem_alloc(sizeof(*sname)); + if (!sname) + return (FALSE); + sname->sn_principal = strdup(principal, M_RPC); + sname->sn_mech = mech_oid; + sname->sn_req_time = req_time; + sname->sn_cred = GSS_C_NO_CREDENTIAL; + sname->sn_program = program; + sname->sn_version = version; + + if (!rpc_gss_acquire_svc_cred(sname)) { + free(sname->sn_principal, M_RPC); + mem_free(sname, sizeof(*sname)); + return (FALSE); + } + + sx_xlock(&svc_rpc_gss_lock); + SLIST_INSERT_HEAD(&svc_rpc_gss_svc_names, sname, sn_link); + sx_xunlock(&svc_rpc_gss_lock); + + return (TRUE); +} + +void +rpc_gss_clear_svc_name(u_int program, u_int version) +{ + OM_uint32 min_stat; + struct svc_rpc_gss_svc_name *sname; + + sx_xlock(&svc_rpc_gss_lock); + SLIST_FOREACH(sname, &svc_rpc_gss_svc_names, sn_link) { + if (sname->sn_program == program + && sname->sn_version == version) { + SLIST_REMOVE(&svc_rpc_gss_svc_names, sname, + svc_rpc_gss_svc_name, sn_link); + sx_xunlock(&svc_rpc_gss_lock); + gss_release_cred(&min_stat, &sname->sn_cred); + free(sname->sn_principal, M_RPC); + mem_free(sname, sizeof(*sname)); + return; + } + } + sx_xunlock(&svc_rpc_gss_lock); +} + 
+bool_t +rpc_gss_get_principal_name(rpc_gss_principal_t *principal, + const char *mech, const char *name, const char *node, const char *domain) +{ + OM_uint32 maj_stat, min_stat; + gss_OID mech_oid; + size_t namelen; + gss_buffer_desc buf; + gss_name_t gss_name, gss_mech_name; + rpc_gss_principal_t result; + + if (!rpc_gss_mech_to_oid(mech, &mech_oid)) + return (FALSE); + + /* + * Construct a gss_buffer containing the full name formatted + * as "name/node@domain" where node and domain are optional. + */ + namelen = strlen(name); + if (node) { + namelen += strlen(node) + 1; + } + if (domain) { + namelen += strlen(domain) + 1; + } + + buf.value = mem_alloc(namelen); + buf.length = namelen; + strcpy((char *) buf.value, name); + if (node) { + strcat((char *) buf.value, "/"); + strcat((char *) buf.value, node); + } + if (domain) { + strcat((char *) buf.value, "@"); + strcat((char *) buf.value, domain); + } + + /* + * Convert that to a gss_name_t and then convert that to a + * mechanism name in the selected mechanism. + */ + maj_stat = gss_import_name(&min_stat, &buf, + GSS_C_NT_USER_NAME, &gss_name); + mem_free(buf.value, buf.length); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_import_name", mech_oid, maj_stat, min_stat); + return (FALSE); + } + maj_stat = gss_canonicalize_name(&min_stat, gss_name, mech_oid, + &gss_mech_name); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_canonicalize_name", mech_oid, maj_stat, + min_stat); + gss_release_name(&min_stat, &gss_name); + return (FALSE); + } + gss_release_name(&min_stat, &gss_name); + + /* + * Export the mechanism name and use that to construct the + * rpc_gss_principal_t result. 
+ */ + maj_stat = gss_export_name(&min_stat, gss_mech_name, &buf); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_export_name", mech_oid, maj_stat, min_stat); + gss_release_name(&min_stat, &gss_mech_name); + return (FALSE); + } + gss_release_name(&min_stat, &gss_mech_name); + + result = mem_alloc(sizeof(int) + buf.length); + if (!result) { + gss_release_buffer(&min_stat, &buf); + return (FALSE); + } + result->len = buf.length; + memcpy(result->name, buf.value, buf.length); + gss_release_buffer(&min_stat, &buf); + + *principal = result; + return (TRUE); +} + +bool_t +rpc_gss_getcred(struct svc_req *req, rpc_gss_rawcred_t **rcred, + rpc_gss_ucred_t **ucred, void **cookie) +{ + struct svc_rpc_gss_cookedcred *cc; + struct svc_rpc_gss_client *client; + + if (req->rq_cred.oa_flavor != RPCSEC_GSS) + return (FALSE); + + cc = req->rq_clntcred; + client = cc->cc_client; + if (rcred) + *rcred = &client->cl_rawcred; + if (ucred) + *ucred = &client->cl_ucred; + if (cookie) + *cookie = client->cl_cookie; + return (TRUE); +} + +/* + * This simpler interface is used by svc_getcred to copy the cred data + * into a kernel cred structure. 
+ */ +static int +rpc_gss_svc_getcred(struct svc_req *req, struct ucred **crp, int *flavorp) +{ + struct ucred *cr; + struct svc_rpc_gss_cookedcred *cc; + struct svc_rpc_gss_client *client; + rpc_gss_ucred_t *uc; + int i; + + if (req->rq_cred.oa_flavor != RPCSEC_GSS) + return (FALSE); + + cc = req->rq_clntcred; + client = cc->cc_client; + + if (flavorp) + *flavorp = client->cl_rpcflavor; + + if (client->cl_cred) { + *crp = crhold(client->cl_cred); + return (TRUE); + } + + uc = &client->cl_ucred; + cr = client->cl_cred = crget(); + cr->cr_uid = cr->cr_ruid = cr->cr_svuid = uc->uid; + cr->cr_rgid = cr->cr_svgid = uc->gid; + cr->cr_ngroups = uc->gidlen; + if (cr->cr_ngroups > NGROUPS) + cr->cr_ngroups = NGROUPS; + for (i = 0; i < cr->cr_ngroups; i++) + cr->cr_groups[i] = uc->gidlist[i]; + *crp = crhold(cr); + + return (TRUE); +} + +int +rpc_gss_svc_max_data_length(struct svc_req *req, int max_tp_unit_len) +{ + struct svc_rpc_gss_cookedcred *cc = req->rq_clntcred; + struct svc_rpc_gss_client *client = cc->cc_client; + int want_conf; + OM_uint32 max; + OM_uint32 maj_stat, min_stat; + int result; + + switch (client->cl_rawcred.service) { + case rpc_gss_svc_none: + return (max_tp_unit_len); + break; + + case rpc_gss_svc_default: + case rpc_gss_svc_integrity: + want_conf = FALSE; + break; + + case rpc_gss_svc_privacy: + want_conf = TRUE; + break; + + default: + return (0); + } + + maj_stat = gss_wrap_size_limit(&min_stat, client->cl_ctx, want_conf, + client->cl_qop, max_tp_unit_len, &max); + + if (maj_stat == GSS_S_COMPLETE) { + result = (int) max; + if (result < 0) + result = 0; + return (result); + } else { + rpc_gss_log_status("gss_wrap_size_limit", client->cl_mech, + maj_stat, min_stat); + return (0); + } +} + +static struct svc_rpc_gss_client * +svc_rpc_gss_find_client(struct svc_rpc_gss_clientid *id) +{ + struct svc_rpc_gss_client *client; + struct svc_rpc_gss_client_list *list; + + rpc_gss_log_debug("in svc_rpc_gss_find_client(%d)", id->ci_id); + + if (id->ci_hostid 
!= hostid || id->ci_boottime != boottime.tv_sec) + return (NULL); + + list = &svc_rpc_gss_client_hash[id->ci_id % CLIENT_HASH_SIZE]; + sx_xlock(&svc_rpc_gss_lock); + TAILQ_FOREACH(client, list, cl_link) { + if (client->cl_id.ci_id == id->ci_id) { + /* + * Move this client to the front of the LRU + * list. + */ + TAILQ_REMOVE(&svc_rpc_gss_clients, client, cl_alllink); + TAILQ_INSERT_HEAD(&svc_rpc_gss_clients, client, + cl_alllink); + refcount_acquire(&client->cl_refs); + break; + } + } + sx_xunlock(&svc_rpc_gss_lock); + + return (client); +} + +static struct svc_rpc_gss_client * +svc_rpc_gss_create_client(void) +{ + struct svc_rpc_gss_client *client; + struct svc_rpc_gss_client_list *list; + + rpc_gss_log_debug("in svc_rpc_gss_create_client()"); + + client = mem_alloc(sizeof(struct svc_rpc_gss_client)); + memset(client, 0, sizeof(struct svc_rpc_gss_client)); + refcount_init(&client->cl_refs, 1); + sx_init(&client->cl_lock, "GSS-client"); + client->cl_id.ci_hostid = hostid; + client->cl_id.ci_boottime = boottime.tv_sec; + client->cl_id.ci_id = svc_rpc_gss_next_clientid++; + list = &svc_rpc_gss_client_hash[client->cl_id.ci_id % CLIENT_HASH_SIZE]; + sx_xlock(&svc_rpc_gss_lock); + TAILQ_INSERT_HEAD(list, client, cl_link); + TAILQ_INSERT_HEAD(&svc_rpc_gss_clients, client, cl_alllink); + svc_rpc_gss_client_count++; + sx_xunlock(&svc_rpc_gss_lock); + + /* + * Start the client off with a short expiration time. We will + * try to get a saner value from the client creds later. 
+ */ + client->cl_state = CLIENT_NEW; + client->cl_locked = FALSE; + client->cl_expiration = time_uptime + 5*60; + + return (client); +} + +static void +svc_rpc_gss_destroy_client(struct svc_rpc_gss_client *client) +{ + OM_uint32 min_stat; + + rpc_gss_log_debug("in svc_rpc_gss_destroy_client()"); + + if (client->cl_ctx) + gss_delete_sec_context(&min_stat, + &client->cl_ctx, GSS_C_NO_BUFFER); + + if (client->cl_cname) + gss_release_name(&min_stat, &client->cl_cname); + + if (client->cl_rawcred.client_principal) + mem_free(client->cl_rawcred.client_principal, + sizeof(*client->cl_rawcred.client_principal) + + client->cl_rawcred.client_principal->len); + + if (client->cl_cred) + crfree(client->cl_cred); + + sx_destroy(&client->cl_lock); + mem_free(client, sizeof(*client)); +} + +/* + * Drop a reference to a client and free it if that was the last reference. + */ +static void +svc_rpc_gss_release_client(struct svc_rpc_gss_client *client) +{ + + if (!refcount_release(&client->cl_refs)) + return; + svc_rpc_gss_destroy_client(client); +} + +/* + * Remove a client from our global lists and free it if we can. + */ +static void +svc_rpc_gss_forget_client(struct svc_rpc_gss_client *client) +{ + struct svc_rpc_gss_client_list *list; + + list = &svc_rpc_gss_client_hash[client->cl_id.ci_id % CLIENT_HASH_SIZE]; + sx_xlock(&svc_rpc_gss_lock); + TAILQ_REMOVE(list, client, cl_link); + TAILQ_REMOVE(&svc_rpc_gss_clients, client, cl_alllink); + svc_rpc_gss_client_count--; + sx_xunlock(&svc_rpc_gss_lock); + svc_rpc_gss_release_client(client); +} + +static void +svc_rpc_gss_timeout_clients(void) +{ + struct svc_rpc_gss_client *client; + struct svc_rpc_gss_client *nclient; + time_t now = time_uptime; + + rpc_gss_log_debug("in svc_rpc_gss_timeout_clients()"); + + /* + * First enforce the max client limit. We keep + * svc_rpc_gss_clients in LRU order. 
+ */ + while (svc_rpc_gss_client_count > CLIENT_MAX) + svc_rpc_gss_forget_client(TAILQ_LAST(&svc_rpc_gss_clients, + svc_rpc_gss_client_list)); + TAILQ_FOREACH_SAFE(client, &svc_rpc_gss_clients, cl_alllink, nclient) { + if (client->cl_state == CLIENT_STALE + || now > client->cl_expiration) { + rpc_gss_log_debug("expiring client %p", client); + svc_rpc_gss_forget_client(client); + } + } +} + +#ifdef DEBUG +/* + * OID<->string routines. These are uuuuugly. + */ +static OM_uint32 +gss_oid_to_str(OM_uint32 *minor_status, gss_OID oid, gss_buffer_t oid_str) +{ + char numstr[128]; + unsigned long number; + int numshift; + size_t string_length; + size_t i; + unsigned char *cp; + char *bp; + + /* Decoded according to krb5/gssapi_krb5.c */ + + /* First determine the size of the string */ + string_length = 0; + number = 0; + numshift = 0; + cp = (unsigned char *) oid->elements; + number = (unsigned long) cp[0]; + sprintf(numstr, "%ld ", number/40); + string_length += strlen(numstr); + sprintf(numstr, "%ld ", number%40); + string_length += strlen(numstr); + for (i=1; i<oid->length; i++) { + if ( (size_t) (numshift+7) < (sizeof(unsigned long)*8)) { + number = (number << 7) | (cp[i] & 0x7f); + numshift += 7; + } + else { + *minor_status = 0; + return(GSS_S_FAILURE); + } + if ((cp[i] & 0x80) == 0) { + sprintf(numstr, "%ld ", number); + string_length += strlen(numstr); + number = 0; + numshift = 0; + } + } + /* + * If we get here, we've calculated the length of "n n n ... n ". Add 4 + * here for "{ " and "}\0". 
+ */ + string_length += 4; + if ((bp = (char *) mem_alloc(string_length))) { + strcpy(bp, "{ "); + number = (unsigned long) cp[0]; + sprintf(numstr, "%ld ", number/40); + strcat(bp, numstr); + sprintf(numstr, "%ld ", number%40); + strcat(bp, numstr); + number = 0; + cp = (unsigned char *) oid->elements; + for (i=1; i<oid->length; i++) { + number = (number << 7) | (cp[i] & 0x7f); + if ((cp[i] & 0x80) == 0) { + sprintf(numstr, "%ld ", number); + strcat(bp, numstr); + number = 0; + } + } + strcat(bp, "}"); + oid_str->length = strlen(bp)+1; + oid_str->value = (void *) bp; + *minor_status = 0; + return(GSS_S_COMPLETE); + } + *minor_status = 0; + return(GSS_S_FAILURE); +} +#endif + +static void +svc_rpc_gss_build_ucred(struct svc_rpc_gss_client *client, + const gss_name_t name) +{ + OM_uint32 maj_stat, min_stat; + rpc_gss_ucred_t *uc = &client->cl_ucred; + int numgroups; + + uc->uid = 65534; + uc->gid = 65534; + uc->gidlist = client->cl_gid_storage; + + numgroups = NGROUPS; + maj_stat = gss_pname_to_unix_cred(&min_stat, name, client->cl_mech, + &uc->uid, &uc->gid, &numgroups, &uc->gidlist[0]); + if (GSS_ERROR(maj_stat)) + uc->gidlen = 0; + else + uc->gidlen = numgroups; +} + +static void +svc_rpc_gss_set_flavor(struct svc_rpc_gss_client *client) +{ + static gss_OID_desc krb5_mech_oid = + {9, (void *) "\x2a\x86\x48\x86\xf7\x12\x01\x02\x02" }; + + /* + * Attempt to translate mech type and service into a + * 'pseudo flavor'. Hardwire in krb5 support for now. 
+ */ + if (kgss_oid_equal(client->cl_mech, &krb5_mech_oid)) { + switch (client->cl_rawcred.service) { + case rpc_gss_svc_default: + case rpc_gss_svc_none: + client->cl_rpcflavor = RPCSEC_GSS_KRB5; + break; + case rpc_gss_svc_integrity: + client->cl_rpcflavor = RPCSEC_GSS_KRB5I; + break; + case rpc_gss_svc_privacy: + client->cl_rpcflavor = RPCSEC_GSS_KRB5P; + break; + } + } else { + client->cl_rpcflavor = RPCSEC_GSS; + } +} + +static bool_t +svc_rpc_gss_accept_sec_context(struct svc_rpc_gss_client *client, + struct svc_req *rqst, + struct rpc_gss_init_res *gr, + struct rpc_gss_cred *gc) +{ + gss_buffer_desc recv_tok; + gss_OID mech; + OM_uint32 maj_stat = 0, min_stat = 0, ret_flags; + OM_uint32 cred_lifetime; + struct svc_rpc_gss_svc_name *sname; + + rpc_gss_log_debug("in svc_rpc_gss_accept_context()"); + + /* Deserialize arguments. */ + memset(&recv_tok, 0, sizeof(recv_tok)); + + if (!svc_getargs(rqst, + (xdrproc_t) xdr_gss_buffer_desc, + (caddr_t) &recv_tok)) { + client->cl_state = CLIENT_STALE; + return (FALSE); + } + + /* + * First time round, try all the server names we have until + * one matches. Afterwards, stick with that one. + */ + sx_xlock(&svc_rpc_gss_lock); + if (!client->cl_sname) { + SLIST_FOREACH(sname, &svc_rpc_gss_svc_names, sn_link) { + if (sname->sn_program == rqst->rq_prog + && sname->sn_version == rqst->rq_vers) { + retry: + gr->gr_major = gss_accept_sec_context( + &gr->gr_minor, + &client->cl_ctx, + sname->sn_cred, + &recv_tok, + GSS_C_NO_CHANNEL_BINDINGS, + &client->cl_cname, + &mech, + &gr->gr_token, + &ret_flags, + &cred_lifetime, + &client->cl_creds); + if (gr->gr_major == + GSS_S_CREDENTIALS_EXPIRED) { + /* + * Either our creds really did + * expire or gssd was + * restarted. 
+ */ + if (rpc_gss_acquire_svc_cred(sname)) + goto retry; + } + client->cl_sname = sname; + break; + } + } + if (!sname) { + xdr_free((xdrproc_t) xdr_gss_buffer_desc, + (char *) &recv_tok); + sx_xunlock(&svc_rpc_gss_lock); + return (FALSE); + } + } else { + gr->gr_major = gss_accept_sec_context( + &gr->gr_minor, + &client->cl_ctx, + client->cl_sname->sn_cred, + &recv_tok, + GSS_C_NO_CHANNEL_BINDINGS, + &client->cl_cname, + &mech, + &gr->gr_token, + &ret_flags, + &cred_lifetime, + NULL); + } + sx_xunlock(&svc_rpc_gss_lock); + + xdr_free((xdrproc_t) xdr_gss_buffer_desc, (char *) &recv_tok); + + /* + * If we get an error from gss_accept_sec_context, send the + * reply anyway so that the client gets a chance to see what + * is wrong. + */ + if (gr->gr_major != GSS_S_COMPLETE && + gr->gr_major != GSS_S_CONTINUE_NEEDED) { + rpc_gss_log_status("accept_sec_context", client->cl_mech, + gr->gr_major, gr->gr_minor); + client->cl_state = CLIENT_STALE; + return (TRUE); + } + + gr->gr_handle.value = &client->cl_id; + gr->gr_handle.length = sizeof(client->cl_id); + gr->gr_win = SVC_RPC_GSS_SEQWINDOW; + + /* Save client info. */ + client->cl_mech = mech; + client->cl_qop = GSS_C_QOP_DEFAULT; + client->cl_done_callback = FALSE; + + if (gr->gr_major == GSS_S_COMPLETE) { + gss_buffer_desc export_name; + + /* + * Change client expiration time to be near when the + * client creds expire (or 24 hours if we can't figure + * that out). + */ + if (cred_lifetime == GSS_C_INDEFINITE) + cred_lifetime = time_uptime + 24*60*60; + + client->cl_expiration = time_uptime + cred_lifetime; + + /* + * Fill in cred details in the rawcred structure. 
+ */ + client->cl_rawcred.version = RPCSEC_GSS_VERSION; + rpc_gss_oid_to_mech(mech, &client->cl_rawcred.mechanism); + maj_stat = gss_export_name(&min_stat, client->cl_cname, + &export_name); + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_export_name", client->cl_mech, + maj_stat, min_stat); + return (FALSE); + } + client->cl_rawcred.client_principal = + mem_alloc(sizeof(*client->cl_rawcred.client_principal) + + export_name.length); + client->cl_rawcred.client_principal->len = export_name.length; + memcpy(client->cl_rawcred.client_principal->name, + export_name.value, export_name.length); + gss_release_buffer(&min_stat, &export_name); + client->cl_rawcred.svc_principal = + client->cl_sname->sn_principal; + client->cl_rawcred.service = gc->gc_svc; + + /* + * Use gss_pname_to_uid to map to unix creds. For + * kerberos5, this uses krb5_aname_to_localname. + */ + svc_rpc_gss_build_ucred(client, client->cl_cname); + svc_rpc_gss_set_flavor(client); + gss_release_name(&min_stat, &client->cl_cname); + +#ifdef DEBUG + { + gss_buffer_desc mechname; + + gss_oid_to_str(&min_stat, mech, &mechname); + + rpc_gss_log_debug("accepted context for %s with " + "<mech %.*s, qop %d, svc %d>", + client->cl_rawcred.client_principal->name, + mechname.length, (char *)mechname.value, + client->cl_qop, client->rawcred.service); + + gss_release_buffer(&min_stat, &mechname); + } +#endif /* DEBUG */ + } + return (TRUE); +} + +static bool_t +svc_rpc_gss_validate(struct svc_rpc_gss_client *client, struct rpc_msg *msg, + gss_qop_t *qop) +{ + struct opaque_auth *oa; + gss_buffer_desc rpcbuf, checksum; + OM_uint32 maj_stat, min_stat; + gss_qop_t qop_state; + int32_t rpchdr[128 / sizeof(int32_t)]; + int32_t *buf; + + rpc_gss_log_debug("in svc_rpc_gss_validate()"); + + memset(rpchdr, 0, sizeof(rpchdr)); + + /* Reconstruct RPC header for signing (from xdr_callmsg). 
*/ + buf = rpchdr; + IXDR_PUT_LONG(buf, msg->rm_xid); + IXDR_PUT_ENUM(buf, msg->rm_direction); + IXDR_PUT_LONG(buf, msg->rm_call.cb_rpcvers); + IXDR_PUT_LONG(buf, msg->rm_call.cb_prog); + IXDR_PUT_LONG(buf, msg->rm_call.cb_vers); + IXDR_PUT_LONG(buf, msg->rm_call.cb_proc); + oa = &msg->rm_call.cb_cred; + IXDR_PUT_ENUM(buf, oa->oa_flavor); + IXDR_PUT_LONG(buf, oa->oa_length); + if (oa->oa_length) { + memcpy((caddr_t)buf, oa->oa_base, oa->oa_length); + buf += RNDUP(oa->oa_length) / sizeof(int32_t); + } + rpcbuf.value = rpchdr; + rpcbuf.length = (u_char *)buf - (u_char *)rpchdr; + + checksum.value = msg->rm_call.cb_verf.oa_base; + checksum.length = msg->rm_call.cb_verf.oa_length; + + maj_stat = gss_verify_mic(&min_stat, client->cl_ctx, &rpcbuf, &checksum, + &qop_state); + + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_verify_mic", client->cl_mech, + maj_stat, min_stat); + client->cl_state = CLIENT_STALE; + return (FALSE); + } + + *qop = qop_state; + return (TRUE); +} + +static bool_t +svc_rpc_gss_nextverf(struct svc_rpc_gss_client *client, + struct svc_req *rqst, u_int seq) +{ + gss_buffer_desc signbuf; + gss_buffer_desc mic; + OM_uint32 maj_stat, min_stat; + uint32_t nseq; + + rpc_gss_log_debug("in svc_rpc_gss_nextverf()"); + + nseq = htonl(seq); + signbuf.value = &nseq; + signbuf.length = sizeof(nseq); + + maj_stat = gss_get_mic(&min_stat, client->cl_ctx, client->cl_qop, + &signbuf, &mic); + + if (maj_stat != GSS_S_COMPLETE) { + rpc_gss_log_status("gss_get_mic", client->cl_mech, maj_stat, min_stat); + client->cl_state = CLIENT_STALE; + return (FALSE); + } + + KASSERT(mic.length <= MAX_AUTH_BYTES, + ("MIC too large for RPCSEC_GSS")); + + rqst->rq_verf.oa_flavor = RPCSEC_GSS; + rqst->rq_verf.oa_length = mic.length; + bcopy(mic.value, rqst->rq_verf.oa_base, mic.length); + + gss_release_buffer(&min_stat, &mic); + + return (TRUE); +} + +static bool_t +svc_rpc_gss_callback(struct svc_rpc_gss_client *client, struct svc_req *rqst) +{ + struct 
svc_rpc_gss_callback *scb; + rpc_gss_lock_t lock; + void *cookie; + bool_t cb_res; + bool_t result; + + /* + * See if we have a callback for this guy. + */ + result = TRUE; + SLIST_FOREACH(scb, &svc_rpc_gss_callbacks, cb_link) { + if (scb->cb_callback.program == rqst->rq_prog + && scb->cb_callback.version == rqst->rq_vers) { + /* + * This one matches. Call the callback and see + * if it wants to veto or something. + */ + lock.locked = FALSE; + lock.raw_cred = &client->cl_rawcred; + cb_res = scb->cb_callback.callback(rqst, + client->cl_creds, + client->cl_ctx, + &lock, + &cookie); + + if (!cb_res) { + client->cl_state = CLIENT_STALE; + result = FALSE; + break; + } + + /* + * The callback accepted the connection - it + * is responsible for freeing client->cl_creds + * now. + */ + client->cl_creds = GSS_C_NO_CREDENTIAL; + client->cl_locked = lock.locked; + client->cl_cookie = cookie; + return (TRUE); + } + } + + /* + * Either no callback exists for this program/version or one + * of the callbacks rejected the connection. We just need to + * clean up the delegated client creds, if any. + */ + if (client->cl_creds) { + OM_uint32 min_ver; + gss_release_cred(&min_ver, &client->cl_creds); + } + return (result); +} + +static bool_t +svc_rpc_gss_check_replay(struct svc_rpc_gss_client *client, uint32_t seq) +{ + u_int32_t offset; + int word, bit; + bool_t result; + + sx_xlock(&client->cl_lock); + if (seq <= client->cl_seqlast) { + /* + * The request sequence number is less than + * the largest we have seen so far. If it is + * outside the window or if we have seen a + * request with this sequence before, silently + * discard it. 
+ */ + offset = client->cl_seqlast - seq; + if (offset >= SVC_RPC_GSS_SEQWINDOW) { + result = FALSE; + goto out; + } + word = offset / 32; + bit = offset % 32; + if (client->cl_seqmask[word] & (1 << bit)) { + result = FALSE; + goto out; + } + } + + result = TRUE; +out: + sx_xunlock(&client->cl_lock); + return (result); +} + +static void +svc_rpc_gss_update_seq(struct svc_rpc_gss_client *client, uint32_t seq) +{ + int offset, i, word, bit; + uint32_t carry, newcarry; + + sx_xlock(&client->cl_lock); + if (seq > client->cl_seqlast) { + /* + * This request has a sequence number greater + * than any we have seen so far. Advance the + * seq window and set bit zero of the window + * (which corresponds to the new sequence + * number) + */ + offset = seq - client->cl_seqlast; + while (offset > 32) { + for (i = (SVC_RPC_GSS_SEQWINDOW / 32) - 1; + i > 0; i--) { + client->cl_seqmask[i] = client->cl_seqmask[i-1]; + } + client->cl_seqmask[0] = 0; + offset -= 32; + } + carry = 0; + for (i = 0; i < SVC_RPC_GSS_SEQWINDOW / 32; i++) { + newcarry = client->cl_seqmask[i] >> (32 - offset); + client->cl_seqmask[i] = + (client->cl_seqmask[i] << offset) | carry; + carry = newcarry; + } + client->cl_seqmask[0] |= 1; + client->cl_seqlast = seq; + } else { + offset = client->cl_seqlast - seq; + word = offset / 32; + bit = offset % 32; + client->cl_seqmask[word] |= (1 << bit); + } + sx_xunlock(&client->cl_lock); +} + +enum auth_stat +svc_rpc_gss(struct svc_req *rqst, struct rpc_msg *msg) + +{ + OM_uint32 min_stat; + XDR xdrs; + struct svc_rpc_gss_cookedcred *cc; + struct svc_rpc_gss_client *client; + struct rpc_gss_cred gc; + struct rpc_gss_init_res gr; + gss_qop_t qop; + int call_stat; + enum auth_stat result; + + rpc_gss_log_debug("in svc_rpc_gss()"); + + /* Garbage collect old clients. */ + svc_rpc_gss_timeout_clients(); + + /* Initialize reply. */ + rqst->rq_verf = _null_auth; + + /* Deserialize client credentials. 
*/ + if (rqst->rq_cred.oa_length <= 0) + return (AUTH_BADCRED); + + memset(&gc, 0, sizeof(gc)); + + xdrmem_create(&xdrs, rqst->rq_cred.oa_base, + rqst->rq_cred.oa_length, XDR_DECODE); + + if (!xdr_rpc_gss_cred(&xdrs, &gc)) { + XDR_DESTROY(&xdrs); + return (AUTH_BADCRED); + } + XDR_DESTROY(&xdrs); + + client = NULL; + + /* Check version. */ + if (gc.gc_version != RPCSEC_GSS_VERSION) { + result = AUTH_BADCRED; + goto out; + } + + /* Check the proc and find the client (or create it) */ + if (gc.gc_proc == RPCSEC_GSS_INIT) { + if (gc.gc_handle.length != 0) { + result = AUTH_BADCRED; + goto out; + } + client = svc_rpc_gss_create_client(); + refcount_acquire(&client->cl_refs); + } else { + struct svc_rpc_gss_clientid *p; + if (gc.gc_handle.length != sizeof(*p)) { + result = AUTH_BADCRED; + goto out; + } + p = gc.gc_handle.value; + client = svc_rpc_gss_find_client(p); + if (!client) { + /* + * Can't find the client - we may have + * destroyed it - tell the other side to + * re-authenticate. + */ + result = RPCSEC_GSS_CREDPROBLEM; + goto out; + } + } + cc = rqst->rq_clntcred; + cc->cc_client = client; + cc->cc_service = gc.gc_svc; + cc->cc_seq = gc.gc_seq; + + /* + * The service and sequence number must be ignored for + * RPCSEC_GSS_INIT and RPCSEC_GSS_CONTINUE_INIT. + */ + if (gc.gc_proc != RPCSEC_GSS_INIT + && gc.gc_proc != RPCSEC_GSS_CONTINUE_INIT) { + /* + * Check for sequence number overflow. + */ + if (gc.gc_seq >= MAXSEQ) { + result = RPCSEC_GSS_CTXPROBLEM; + goto out; + } + + /* + * Check for valid service. + */ + if (gc.gc_svc != rpc_gss_svc_none && + gc.gc_svc != rpc_gss_svc_integrity && + gc.gc_svc != rpc_gss_svc_privacy) { + result = AUTH_BADCRED; + goto out; + } + } + + /* Handle RPCSEC_GSS control procedure. 
*/ + switch (gc.gc_proc) { + + case RPCSEC_GSS_INIT: + case RPCSEC_GSS_CONTINUE_INIT: + if (rqst->rq_proc != NULLPROC) { + result = AUTH_REJECTEDCRED; + break; + } + + memset(&gr, 0, sizeof(gr)); + if (!svc_rpc_gss_accept_sec_context(client, rqst, &gr, &gc)) { + result = AUTH_REJECTEDCRED; + break; + } + + if (gr.gr_major == GSS_S_COMPLETE) { + /* + * We borrow the space for the call verf to + * pack our reply verf. + */ + rqst->rq_verf = msg->rm_call.cb_verf; + if (!svc_rpc_gss_nextverf(client, rqst, gr.gr_win)) { + result = AUTH_REJECTEDCRED; + break; + } + } else { + rqst->rq_verf = _null_auth; + } + + call_stat = svc_sendreply(rqst, + (xdrproc_t) xdr_rpc_gss_init_res, + (caddr_t) &gr); + + gss_release_buffer(&min_stat, &gr.gr_token); + + if (!call_stat) { + result = AUTH_FAILED; + break; + } + + if (gr.gr_major == GSS_S_COMPLETE) + client->cl_state = CLIENT_ESTABLISHED; + + result = RPCSEC_GSS_NODISPATCH; + break; + + case RPCSEC_GSS_DATA: + case RPCSEC_GSS_DESTROY: + if (!svc_rpc_gss_check_replay(client, gc.gc_seq)) { + result = RPCSEC_GSS_NODISPATCH; + break; + } + + if (!svc_rpc_gss_validate(client, msg, &qop)) { + result = RPCSEC_GSS_CREDPROBLEM; + break; + } + + /* + * We borrow the space for the call verf to pack our + * reply verf. + */ + rqst->rq_verf = msg->rm_call.cb_verf; + if (!svc_rpc_gss_nextverf(client, rqst, gc.gc_seq)) { + result = RPCSEC_GSS_CTXPROBLEM; + break; + } + + svc_rpc_gss_update_seq(client, gc.gc_seq); + + /* + * Change the SVCAUTH ops on the request to point at + * our own code so that we can unwrap the arguments + * and wrap the result. The caller will re-set this on + * every request to point to a set of null wrap/unwrap + * methods. Acquire an extra reference to the client + * which will be released by svc_rpc_gss_release() + * after the request has finished processing. 
+ */ + refcount_acquire(&client->cl_refs); + rqst->rq_auth.svc_ah_ops = &svc_auth_gss_ops; + rqst->rq_auth.svc_ah_private = cc; + + if (gc.gc_proc == RPCSEC_GSS_DATA) { + /* + * We might be ready to do a callback to the server to + * see if it wants to accept/reject the connection. + */ + sx_xlock(&client->cl_lock); + if (!client->cl_done_callback) { + client->cl_done_callback = TRUE; + client->cl_qop = qop; + client->cl_rawcred.qop = _rpc_gss_num_to_qop( + client->cl_rawcred.mechanism, qop); + if (!svc_rpc_gss_callback(client, rqst)) { + result = AUTH_REJECTEDCRED; + sx_xunlock(&client->cl_lock); + break; + } + } + sx_xunlock(&client->cl_lock); + + /* + * If the server has locked this client to a + * particular service+qop pair, enforce that + * restriction now. + */ + if (client->cl_locked) { + if (client->cl_rawcred.service != gc.gc_svc) { + result = AUTH_FAILED; + break; + } else if (client->cl_qop != qop) { + result = AUTH_BADVERF; + break; + } + } + + /* + * If the qop changed, look up the new qop + * name for rawcred. + */ + if (client->cl_qop != qop) { + client->cl_qop = qop; + client->cl_rawcred.qop = _rpc_gss_num_to_qop( + client->cl_rawcred.mechanism, qop); + } + + /* + * Make sure we use the right service value + * for unwrap/wrap. 
	 */
		if (client->cl_rawcred.service != gc.gc_svc) {
			client->cl_rawcred.service = gc.gc_svc;
			svc_rpc_gss_set_flavor(client);
		}

		result = AUTH_OK;
	} else {
		/* RPCSEC_GSS_DESTROY is only valid on the NULL proc. */
		if (rqst->rq_proc != NULLPROC) {
			result = AUTH_REJECTEDCRED;
			break;
		}

		call_stat = svc_sendreply(rqst,
		    (xdrproc_t) xdr_void, (caddr_t) NULL);

		if (!call_stat) {
			result = AUTH_FAILED;
			break;
		}

		/* Drop the context; the client record is unlinked here. */
		svc_rpc_gss_forget_client(client);

		result = RPCSEC_GSS_NODISPATCH;
		break;
	}
	break;

	default:
		result = AUTH_BADCRED;
		break;
	}
out:
	/* Release the reference taken by find/create above. */
	if (client)
		svc_rpc_gss_release_client(client);

	xdr_free((xdrproc_t) xdr_rpc_gss_cred, (char *) &gc);
	return (result);
}

/*
 * SVCAUTH wrap method: protect reply data in *mp according to the
 * service level recorded for this request (no-op for rpc_gss_svc_none
 * or while the context is still being established).
 */
static bool_t
svc_rpc_gss_wrap(SVCAUTH *auth, struct mbuf **mp)
{
	struct svc_rpc_gss_cookedcred *cc;
	struct svc_rpc_gss_client *client;

	rpc_gss_log_debug("in svc_rpc_gss_wrap()");

	cc = (struct svc_rpc_gss_cookedcred *) auth->svc_ah_private;
	client = cc->cc_client;
	if (client->cl_state != CLIENT_ESTABLISHED
	    || cc->cc_service == rpc_gss_svc_none) {
		return (TRUE);
	}

	return (xdr_rpc_gss_wrap_data(mp,
		client->cl_ctx, client->cl_qop,
		cc->cc_service, cc->cc_seq));
}

/*
 * SVCAUTH unwrap method: verify/decrypt call arguments in *mp using
 * the same per-request service and sequence number (no-op for
 * rpc_gss_svc_none or an unestablished context).
 */
static bool_t
svc_rpc_gss_unwrap(SVCAUTH *auth, struct mbuf **mp)
{
	struct svc_rpc_gss_cookedcred *cc;
	struct svc_rpc_gss_client *client;

	rpc_gss_log_debug("in svc_rpc_gss_unwrap()");

	cc = (struct svc_rpc_gss_cookedcred *) auth->svc_ah_private;
	client = cc->cc_client;
	if (client->cl_state != CLIENT_ESTABLISHED
	    || cc->cc_service == rpc_gss_svc_none) {
		return (TRUE);
	}

	return (xdr_rpc_gss_unwrap_data(mp,
		client->cl_ctx, client->cl_qop,
		cc->cc_service, cc->cc_seq));
}

/*
 * SVCAUTH release method: drop the extra client reference taken in
 * svc_rpc_gss() when request processing finishes.
 */
static void
svc_rpc_gss_release(SVCAUTH *auth)
{
	struct svc_rpc_gss_cookedcred *cc;
	struct svc_rpc_gss_client *client;

	rpc_gss_log_debug("in svc_rpc_gss_release()");

	cc = (struct svc_rpc_gss_cookedcred *) auth->svc_ah_private;
	client = cc->cc_client;
	svc_rpc_gss_release_client(client);
}
diff --git 
a/sys/rpc/svc.c b/sys/rpc/svc.c index d6d6d78..8af9e80 100644 --- a/sys/rpc/svc.c +++ b/sys/rpc/svc.c @@ -49,37 +49,105 @@ __FBSDID("$FreeBSD$"); #include <sys/param.h> #include <sys/lock.h> #include <sys/kernel.h> +#include <sys/kthread.h> #include <sys/malloc.h> +#include <sys/mbuf.h> #include <sys/mutex.h> +#include <sys/proc.h> #include <sys/queue.h> +#include <sys/socketvar.h> #include <sys/systm.h> #include <sys/ucred.h> #include <rpc/rpc.h> #include <rpc/rpcb_clnt.h> +#include <rpc/replay.h> #include <rpc/rpc_com.h> #define SVC_VERSQUIET 0x0001 /* keep quiet about vers mismatch */ -#define version_keepquiet(xp) ((u_long)(xp)->xp_p3 & SVC_VERSQUIET) +#define version_keepquiet(xp) (SVC_EXT(xp)->xp_flags & SVC_VERSQUIET) static struct svc_callout *svc_find(SVCPOOL *pool, rpcprog_t, rpcvers_t, char *); -static void __xprt_do_unregister (SVCXPRT *xprt, bool_t dolock); +static void svc_new_thread(SVCPOOL *pool); +static void xprt_unregister_locked(SVCXPRT *xprt); /* *************** SVCXPRT related stuff **************** */ +static int svcpool_minthread_sysctl(SYSCTL_HANDLER_ARGS); +static int svcpool_maxthread_sysctl(SYSCTL_HANDLER_ARGS); + SVCPOOL* -svcpool_create(void) +svcpool_create(const char *name, struct sysctl_oid_list *sysctl_base) { SVCPOOL *pool; pool = malloc(sizeof(SVCPOOL), M_RPC, M_WAITOK|M_ZERO); mtx_init(&pool->sp_lock, "sp_lock", NULL, MTX_DEF); + pool->sp_name = name; + pool->sp_state = SVCPOOL_INIT; + pool->sp_proc = NULL; TAILQ_INIT(&pool->sp_xlist); TAILQ_INIT(&pool->sp_active); TAILQ_INIT(&pool->sp_callouts); + LIST_INIT(&pool->sp_threads); + LIST_INIT(&pool->sp_idlethreads); + pool->sp_minthreads = 1; + pool->sp_maxthreads = 1; + pool->sp_threadcount = 0; + + /* + * Don't use more than a quarter of mbuf clusters or more than + * 45Mb buffering requests. 
+ */ + pool->sp_space_high = nmbclusters * MCLBYTES / 4; + if (pool->sp_space_high > 45 << 20) + pool->sp_space_high = 45 << 20; + pool->sp_space_low = 2 * pool->sp_space_high / 3; + + sysctl_ctx_init(&pool->sp_sysctl); + if (sysctl_base) { + SYSCTL_ADD_PROC(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "minthreads", CTLTYPE_INT | CTLFLAG_RW, + pool, 0, svcpool_minthread_sysctl, "I", ""); + SYSCTL_ADD_PROC(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "maxthreads", CTLTYPE_INT | CTLFLAG_RW, + pool, 0, svcpool_maxthread_sysctl, "I", ""); + SYSCTL_ADD_INT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "threads", CTLFLAG_RD, &pool->sp_threadcount, 0, ""); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_used", CTLFLAG_RD, + &pool->sp_space_used, 0, + "Space in parsed but not handled requests."); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_used_highest", CTLFLAG_RD, + &pool->sp_space_used_highest, 0, + "Highest space used since reboot."); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_high", CTLFLAG_RW, + &pool->sp_space_high, 0, + "Maximum space in parsed but not handled requests."); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_low", CTLFLAG_RW, + &pool->sp_space_low, 0, + "Low water mark for request space."); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_throttled", CTLFLAG_RD, + &pool->sp_space_throttled, 0, + "Whether nfs requests are currently throttled"); + + SYSCTL_ADD_UINT(&pool->sp_sysctl, sysctl_base, OID_AUTO, + "request_space_throttle_count", CTLFLAG_RD, + &pool->sp_space_throttle_count, 0, + "Count of times throttling based on request space has occurred"); + } return pool; } @@ -87,16 +155,17 @@ svcpool_create(void) void svcpool_destroy(SVCPOOL *pool) { - SVCXPRT *xprt; + SVCXPRT *xprt, *nxprt; struct svc_callout *s; + struct svcxprt_list cleanup; + TAILQ_INIT(&cleanup); mtx_lock(&pool->sp_lock); while 
(TAILQ_FIRST(&pool->sp_xlist)) { xprt = TAILQ_FIRST(&pool->sp_xlist); - mtx_unlock(&pool->sp_lock); - SVC_DESTROY(xprt); - mtx_lock(&pool->sp_lock); + xprt_unregister_locked(xprt); + TAILQ_INSERT_TAIL(&cleanup, xprt, xp_link); } while (TAILQ_FIRST(&pool->sp_callouts)) { @@ -107,9 +176,97 @@ svcpool_destroy(SVCPOOL *pool) } mtx_destroy(&pool->sp_lock); + + TAILQ_FOREACH_SAFE(xprt, &cleanup, xp_link, nxprt) { + SVC_RELEASE(xprt); + } + + if (pool->sp_rcache) + replay_freecache(pool->sp_rcache); + + sysctl_ctx_free(&pool->sp_sysctl); free(pool, M_RPC); } +static bool_t +svcpool_active(SVCPOOL *pool) +{ + enum svcpool_state state = pool->sp_state; + + if (state == SVCPOOL_INIT || state == SVCPOOL_CLOSING) + return (FALSE); + return (TRUE); +} + +/* + * Sysctl handler to set the minimum thread count on a pool + */ +static int +svcpool_minthread_sysctl(SYSCTL_HANDLER_ARGS) +{ + SVCPOOL *pool; + int newminthreads, error, n; + + pool = oidp->oid_arg1; + newminthreads = pool->sp_minthreads; + error = sysctl_handle_int(oidp, &newminthreads, 0, req); + if (error == 0 && newminthreads != pool->sp_minthreads) { + if (newminthreads > pool->sp_maxthreads) + return (EINVAL); + mtx_lock(&pool->sp_lock); + if (newminthreads > pool->sp_minthreads + && svcpool_active(pool)) { + /* + * If the pool is running and we are + * increasing, create some more threads now. 
+ */ + n = newminthreads - pool->sp_threadcount; + if (n > 0) { + mtx_unlock(&pool->sp_lock); + while (n--) + svc_new_thread(pool); + mtx_lock(&pool->sp_lock); + } + } + pool->sp_minthreads = newminthreads; + mtx_unlock(&pool->sp_lock); + } + return (error); +} + +/* + * Sysctl handler to set the maximum thread count on a pool + */ +static int +svcpool_maxthread_sysctl(SYSCTL_HANDLER_ARGS) +{ + SVCPOOL *pool; + SVCTHREAD *st; + int newmaxthreads, error; + + pool = oidp->oid_arg1; + newmaxthreads = pool->sp_maxthreads; + error = sysctl_handle_int(oidp, &newmaxthreads, 0, req); + if (error == 0 && newmaxthreads != pool->sp_maxthreads) { + if (newmaxthreads < pool->sp_minthreads) + return (EINVAL); + mtx_lock(&pool->sp_lock); + if (newmaxthreads < pool->sp_maxthreads + && svcpool_active(pool)) { + /* + * If the pool is running and we are + * decreasing, wake up some idle threads to + * encourage them to exit. + */ + LIST_FOREACH(st, &pool->sp_idlethreads, st_ilink) + cv_signal(&st->st_cond); + } + pool->sp_maxthreads = newmaxthreads; + mtx_unlock(&pool->sp_lock); + } + return (error); +} + /* * Activate a transport handle. */ @@ -125,40 +282,70 @@ xprt_register(SVCXPRT *xprt) mtx_unlock(&pool->sp_lock); } -void -xprt_unregister(SVCXPRT *xprt) -{ - __xprt_do_unregister(xprt, TRUE); -} - -void -__xprt_unregister_unlocked(SVCXPRT *xprt) -{ - __xprt_do_unregister(xprt, FALSE); -} - /* - * De-activate a transport handle. + * De-activate a transport handle. Note: the locked version doesn't + * release the transport - caller must do that after dropping the pool + * lock. 
*/ static void -__xprt_do_unregister(SVCXPRT *xprt, bool_t dolock) +xprt_unregister_locked(SVCXPRT *xprt) { SVCPOOL *pool = xprt->xp_pool; - //__svc_generic_cleanup(xprt); - - if (dolock) - mtx_lock(&pool->sp_lock); - if (xprt->xp_active) { TAILQ_REMOVE(&pool->sp_active, xprt, xp_alink); xprt->xp_active = FALSE; } TAILQ_REMOVE(&pool->sp_xlist, xprt, xp_link); xprt->xp_registered = FALSE; +} - if (dolock) - mtx_unlock(&pool->sp_lock); +void +xprt_unregister(SVCXPRT *xprt) +{ + SVCPOOL *pool = xprt->xp_pool; + + mtx_lock(&pool->sp_lock); + xprt_unregister_locked(xprt); + mtx_unlock(&pool->sp_lock); + + SVC_RELEASE(xprt); +} + +static void +xprt_assignthread(SVCXPRT *xprt) +{ + SVCPOOL *pool = xprt->xp_pool; + SVCTHREAD *st; + + /* + * Attempt to assign a service thread to this + * transport. + */ + LIST_FOREACH(st, &pool->sp_idlethreads, st_ilink) { + if (st->st_xprt == NULL && STAILQ_EMPTY(&st->st_reqs)) + break; + } + if (st) { + SVC_ACQUIRE(xprt); + xprt->xp_thread = st; + st->st_xprt = xprt; + cv_signal(&st->st_cond); + } else { + /* + * See if we can create a new thread. The + * actual thread creation happens in + * svc_run_internal because our locking state + * is poorly defined (we are typically called + * from a socket upcall). Don't create more + * than one thread per second. + */ + if (pool->sp_state == SVCPOOL_ACTIVE + && pool->sp_lastcreatetime < time_uptime + && pool->sp_threadcount < pool->sp_maxthreads) { + pool->sp_state = SVCPOOL_THREADWANTED; + } + } } void @@ -166,30 +353,42 @@ xprt_active(SVCXPRT *xprt) { SVCPOOL *pool = xprt->xp_pool; + if (!xprt->xp_registered) { + /* + * Race with xprt_unregister - we lose. 
+ */ + return; + } + mtx_lock(&pool->sp_lock); if (!xprt->xp_active) { TAILQ_INSERT_TAIL(&pool->sp_active, xprt, xp_alink); xprt->xp_active = TRUE; + xprt_assignthread(xprt); } - wakeup(&pool->sp_active); mtx_unlock(&pool->sp_lock); } void -xprt_inactive(SVCXPRT *xprt) +xprt_inactive_locked(SVCXPRT *xprt) { SVCPOOL *pool = xprt->xp_pool; - mtx_lock(&pool->sp_lock); - if (xprt->xp_active) { TAILQ_REMOVE(&pool->sp_active, xprt, xp_alink); xprt->xp_active = FALSE; } - wakeup(&pool->sp_active); +} + +void +xprt_inactive(SVCXPRT *xprt) +{ + SVCPOOL *pool = xprt->xp_pool; + mtx_lock(&pool->sp_lock); + xprt_inactive_locked(xprt); mtx_unlock(&pool->sp_lock); } @@ -253,9 +452,11 @@ rpcb_it: if (nconf) { bool_t dummy; struct netconfig tnc; + struct netbuf nb; tnc = *nconf; - dummy = rpcb_set(prog, vers, &tnc, - &((SVCXPRT *) xprt)->xp_ltaddr); + nb.buf = &xprt->xp_ltaddr; + nb.len = xprt->xp_ltaddr.ss_len; + dummy = rpcb_set(prog, vers, &tnc, &nb); return (dummy); } return (TRUE); @@ -305,270 +506,809 @@ svc_find(SVCPOOL *pool, rpcprog_t prog, rpcvers_t vers, char *netid) /* ******************* REPLY GENERATION ROUTINES ************ */ +static bool_t +svc_sendreply_common(struct svc_req *rqstp, struct rpc_msg *rply, + struct mbuf *body) +{ + SVCXPRT *xprt = rqstp->rq_xprt; + bool_t ok; + + if (rqstp->rq_args) { + m_freem(rqstp->rq_args); + rqstp->rq_args = NULL; + } + + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + rply, svc_getrpccaller(rqstp), body); + + if (!SVCAUTH_WRAP(&rqstp->rq_auth, &body)) + return (FALSE); + + ok = SVC_REPLY(xprt, rply, rqstp->rq_addr, body); + if (rqstp->rq_addr) { + free(rqstp->rq_addr, M_SONAME); + rqstp->rq_addr = NULL; + } + + return (ok); +} + /* * Send a reply to an rpc request */ bool_t -svc_sendreply(SVCXPRT *xprt, xdrproc_t xdr_results, void * xdr_location) +svc_sendreply(struct svc_req *rqstp, xdrproc_t xdr_results, void * xdr_location) { struct rpc_msg rply; + struct mbuf *m; + XDR xdrs; + bool_t ok; + 
rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = SUCCESS; - rply.acpted_rply.ar_results.where = xdr_location; - rply.acpted_rply.ar_results.proc = xdr_results; + rply.acpted_rply.ar_results.where = NULL; + rply.acpted_rply.ar_results.proc = (xdrproc_t) xdr_void; + + MGET(m, M_WAIT, MT_DATA); + MCLGET(m, M_WAIT); + m->m_len = 0; + xdrmbuf_create(&xdrs, m, XDR_ENCODE); + ok = xdr_results(&xdrs, xdr_location); + XDR_DESTROY(&xdrs); + + if (ok) { + return (svc_sendreply_common(rqstp, &rply, m)); + } else { + m_freem(m); + return (FALSE); + } +} - return (SVC_REPLY(xprt, &rply)); +bool_t +svc_sendreply_mbuf(struct svc_req *rqstp, struct mbuf *m) +{ + struct rpc_msg rply; + + rply.rm_xid = rqstp->rq_xid; + rply.rm_direction = REPLY; + rply.rm_reply.rp_stat = MSG_ACCEPTED; + rply.acpted_rply.ar_verf = rqstp->rq_verf; + rply.acpted_rply.ar_stat = SUCCESS; + rply.acpted_rply.ar_results.where = NULL; + rply.acpted_rply.ar_results.proc = (xdrproc_t) xdr_void; + + return (svc_sendreply_common(rqstp, &rply, m)); } /* * No procedure error reply */ void -svcerr_noproc(SVCXPRT *xprt) +svcerr_noproc(struct svc_req *rqstp) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = PROC_UNAVAIL; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, svc_getrpccaller(rqstp), NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } /* * Can't decode args error reply */ void -svcerr_decode(SVCXPRT *xprt) +svcerr_decode(struct svc_req *rqstp) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; 
rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = GARBAGE_ARGS; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, (struct sockaddr *) &xprt->xp_rtaddr, NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } /* * Some system error */ void -svcerr_systemerr(SVCXPRT *xprt) +svcerr_systemerr(struct svc_req *rqstp) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = SYSTEM_ERR; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, svc_getrpccaller(rqstp), NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } /* * Authentication error reply */ void -svcerr_auth(SVCXPRT *xprt, enum auth_stat why) +svcerr_auth(struct svc_req *rqstp, enum auth_stat why) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_DENIED; rply.rjcted_rply.rj_stat = AUTH_ERROR; rply.rjcted_rply.rj_why = why; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, svc_getrpccaller(rqstp), NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } /* * Auth too weak error reply */ void -svcerr_weakauth(SVCXPRT *xprt) +svcerr_weakauth(struct svc_req *rqstp) { - svcerr_auth(xprt, AUTH_TOOWEAK); + svcerr_auth(rqstp, AUTH_TOOWEAK); } /* * Program unavailable error reply */ void -svcerr_noprog(SVCXPRT *xprt) +svcerr_noprog(struct svc_req *rqstp) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = 
xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = PROG_UNAVAIL; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, svc_getrpccaller(rqstp), NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } /* * Program version mismatch error reply */ void -svcerr_progvers(SVCXPRT *xprt, rpcvers_t low_vers, rpcvers_t high_vers) +svcerr_progvers(struct svc_req *rqstp, rpcvers_t low_vers, rpcvers_t high_vers) { + SVCXPRT *xprt = rqstp->rq_xprt; struct rpc_msg rply; + rply.rm_xid = rqstp->rq_xid; rply.rm_direction = REPLY; rply.rm_reply.rp_stat = MSG_ACCEPTED; - rply.acpted_rply.ar_verf = xprt->xp_verf; + rply.acpted_rply.ar_verf = rqstp->rq_verf; rply.acpted_rply.ar_stat = PROG_MISMATCH; rply.acpted_rply.ar_vers.low = (uint32_t)low_vers; rply.acpted_rply.ar_vers.high = (uint32_t)high_vers; - SVC_REPLY(xprt, &rply); + if (xprt->xp_pool->sp_rcache) + replay_setreply(xprt->xp_pool->sp_rcache, + &rply, svc_getrpccaller(rqstp), NULL); + + svc_sendreply_common(rqstp, &rply, NULL); } -/* ******************* SERVER INPUT STUFF ******************* */ +/* + * Allocate a new server transport structure. All fields are + * initialized to zero and xp_p3 is initialized to point at an + * extension structure to hold various flags and authentication + * parameters. + */ +SVCXPRT * +svc_xprt_alloc() +{ + SVCXPRT *xprt; + SVCXPRT_EXT *ext; + + xprt = mem_alloc(sizeof(SVCXPRT)); + memset(xprt, 0, sizeof(SVCXPRT)); + ext = mem_alloc(sizeof(SVCXPRT_EXT)); + memset(ext, 0, sizeof(SVCXPRT_EXT)); + xprt->xp_p3 = ext; + refcount_init(&xprt->xp_refs, 1); + + return (xprt); +} /* - * Get server side input from some transport. - * - * Statement of authentication parameters management: - * This function owns and manages all authentication parameters, specifically - * the "raw" parameters (msg.rm_call.cb_cred and msg.rm_call.cb_verf) and - * the "cooked" credentials (rqst->rq_clntcred). 
- * In-kernel, we represent non-trivial cooked creds with struct ucred. - * In all events, all three parameters are freed upon exit from this routine. - * The storage is trivially management on the call stack in user land, but - * is mallocated in kernel land. + * Free a server transport structure. */ +void +svc_xprt_free(xprt) + SVCXPRT *xprt; +{ -static void -svc_getreq(SVCXPRT *xprt) + mem_free(xprt->xp_p3, sizeof(SVCXPRT_EXT)); + mem_free(xprt, sizeof(SVCXPRT)); +} + +/* ******************* SERVER INPUT STUFF ******************* */ + +/* + * Read RPC requests from a transport and queue them to be + * executed. We handle authentication and replay cache replies here. + * Actually dispatching the RPC is deferred till svc_executereq. + */ +static enum xprt_stat +svc_getreq(SVCXPRT *xprt, struct svc_req **rqstp_ret) { SVCPOOL *pool = xprt->xp_pool; - struct svc_req r; + struct svc_req *r; struct rpc_msg msg; - int prog_found; - rpcvers_t low_vers; - rpcvers_t high_vers; + struct mbuf *args; enum xprt_stat stat; - char cred_area[2*MAX_AUTH_BYTES + sizeof(struct xucred)]; - - msg.rm_call.cb_cred.oa_base = cred_area; - msg.rm_call.cb_verf.oa_base = &cred_area[MAX_AUTH_BYTES]; - r.rq_clntcred = &cred_area[2*MAX_AUTH_BYTES]; /* now receive msgs from xprtprt (support batch calls) */ - do { - if (SVC_RECV(xprt, &msg)) { - - /* now find the exported program and call it */ - struct svc_callout *s; - enum auth_stat why; - - r.rq_xprt = xprt; - r.rq_prog = msg.rm_call.cb_prog; - r.rq_vers = msg.rm_call.cb_vers; - r.rq_proc = msg.rm_call.cb_proc; - r.rq_cred = msg.rm_call.cb_cred; - /* first authenticate the message */ - if ((why = _authenticate(&r, &msg)) != AUTH_OK) { - svcerr_auth(xprt, why); + r = malloc(sizeof(*r), M_RPC, M_WAITOK|M_ZERO); + + msg.rm_call.cb_cred.oa_base = r->rq_credarea; + msg.rm_call.cb_verf.oa_base = &r->rq_credarea[MAX_AUTH_BYTES]; + r->rq_clntcred = &r->rq_credarea[2*MAX_AUTH_BYTES]; + if (SVC_RECV(xprt, &msg, &r->rq_addr, &args)) { + enum auth_stat 
why; + + /* + * Handle replays and authenticate before queuing the + * request to be executed. + */ + SVC_ACQUIRE(xprt); + r->rq_xprt = xprt; + if (pool->sp_rcache) { + struct rpc_msg repmsg; + struct mbuf *repbody; + enum replay_state rs; + rs = replay_find(pool->sp_rcache, &msg, + svc_getrpccaller(r), &repmsg, &repbody); + switch (rs) { + case RS_NEW: + break; + case RS_DONE: + SVC_REPLY(xprt, &repmsg, r->rq_addr, + repbody); + if (r->rq_addr) { + free(r->rq_addr, M_SONAME); + r->rq_addr = NULL; + } + goto call_done; + + default: goto call_done; } - /* now match message with a registered service*/ - prog_found = FALSE; - low_vers = (rpcvers_t) -1L; - high_vers = (rpcvers_t) 0L; - TAILQ_FOREACH(s, &pool->sp_callouts, sc_link) { - if (s->sc_prog == r.rq_prog) { - if (s->sc_vers == r.rq_vers) { - (*s->sc_dispatch)(&r, xprt); - goto call_done; - } /* found correct version */ - prog_found = TRUE; - if (s->sc_vers < low_vers) - low_vers = s->sc_vers; - if (s->sc_vers > high_vers) - high_vers = s->sc_vers; - } /* found correct program */ - } + } + + r->rq_xid = msg.rm_xid; + r->rq_prog = msg.rm_call.cb_prog; + r->rq_vers = msg.rm_call.cb_vers; + r->rq_proc = msg.rm_call.cb_proc; + r->rq_size = sizeof(*r) + m_length(args, NULL); + r->rq_args = args; + if ((why = _authenticate(r, &msg)) != AUTH_OK) { /* - * if we got here, the program or version - * is not served ... + * RPCSEC_GSS uses this return code + * for requests that form part of its + * context establishment protocol and + * should not be dispatched to the + * application. */ - if (prog_found) - svcerr_progvers(xprt, low_vers, high_vers); - else - svcerr_noprog(xprt); - /* Fall through to ... */ + if (why != RPCSEC_GSS_NODISPATCH) + svcerr_auth(r, why); + goto call_done; } + + if (!SVCAUTH_UNWRAP(&r->rq_auth, &r->rq_args)) { + svcerr_decode(r); + goto call_done; + } + /* - * Check if the xprt has been disconnected in a - * recursive call in the service dispatch routine. - * If so, then break. 
+ * Everything checks out, return request to caller. */ - mtx_lock(&pool->sp_lock); - if (!xprt->xp_registered) { - mtx_unlock(&pool->sp_lock); - break; - } - mtx_unlock(&pool->sp_lock); + *rqstp_ret = r; + r = NULL; + } call_done: - if ((stat = SVC_STAT(xprt)) == XPRT_DIED) { - SVC_DESTROY(xprt); - break; + if (r) { + svc_freereq(r); + r = NULL; + } + if ((stat = SVC_STAT(xprt)) == XPRT_DIED) { + xprt_unregister(xprt); + } + + return (stat); +} + +static void +svc_executereq(struct svc_req *rqstp) +{ + SVCXPRT *xprt = rqstp->rq_xprt; + SVCPOOL *pool = xprt->xp_pool; + int prog_found; + rpcvers_t low_vers; + rpcvers_t high_vers; + struct svc_callout *s; + + /* now match message with a registered service*/ + prog_found = FALSE; + low_vers = (rpcvers_t) -1L; + high_vers = (rpcvers_t) 0L; + TAILQ_FOREACH(s, &pool->sp_callouts, sc_link) { + if (s->sc_prog == rqstp->rq_prog) { + if (s->sc_vers == rqstp->rq_vers) { + /* + * We hand ownership of r to the + * dispatch method - they must call + * svc_freereq. + */ + (*s->sc_dispatch)(rqstp, xprt); + return; + } /* found correct version */ + prog_found = TRUE; + if (s->sc_vers < low_vers) + low_vers = s->sc_vers; + if (s->sc_vers > high_vers) + high_vers = s->sc_vers; + } /* found correct program */ + } + + /* + * if we got here, the program or version + * is not served ... + */ + if (prog_found) + svcerr_progvers(rqstp, low_vers, high_vers); + else + svcerr_noprog(rqstp); + + svc_freereq(rqstp); +} + +static void +svc_checkidle(SVCPOOL *pool) +{ + SVCXPRT *xprt, *nxprt; + time_t timo; + struct svcxprt_list cleanup; + + TAILQ_INIT(&cleanup); + TAILQ_FOREACH_SAFE(xprt, &pool->sp_xlist, xp_link, nxprt) { + /* + * Only some transports have idle timers. Don't time + * something out which is just waking up. 
+ */ + if (!xprt->xp_idletimeout || xprt->xp_thread) + continue; + + timo = xprt->xp_lastactive + xprt->xp_idletimeout; + if (time_uptime > timo) { + xprt_unregister_locked(xprt); + TAILQ_INSERT_TAIL(&cleanup, xprt, xp_link); } - } while (stat == XPRT_MOREREQS); + } + + mtx_unlock(&pool->sp_lock); + TAILQ_FOREACH_SAFE(xprt, &cleanup, xp_link, nxprt) { + SVC_RELEASE(xprt); + } + mtx_lock(&pool->sp_lock); + } -void -svc_run(SVCPOOL *pool) +static void +svc_assign_waiting_sockets(SVCPOOL *pool) +{ + SVCXPRT *xprt; + + TAILQ_FOREACH(xprt, &pool->sp_active, xp_alink) { + if (!xprt->xp_thread) { + xprt_assignthread(xprt); + } + } +} + +static bool_t +svc_request_space_available(SVCPOOL *pool) +{ + + mtx_assert(&pool->sp_lock, MA_OWNED); + + if (pool->sp_space_throttled) { + /* + * Below the low-water yet? If so, assign any waiting sockets. + */ + if (pool->sp_space_used < pool->sp_space_low) { + pool->sp_space_throttled = FALSE; + svc_assign_waiting_sockets(pool); + return TRUE; + } + + return FALSE; + } else { + if (pool->sp_space_used + >= pool->sp_space_high) { + pool->sp_space_throttled = TRUE; + pool->sp_space_throttle_count++; + return FALSE; + } + + return TRUE; + } +} + +static void +svc_run_internal(SVCPOOL *pool, bool_t ismaster) { + SVCTHREAD *st, *stpref; SVCXPRT *xprt; + enum xprt_stat stat; + struct svc_req *rqstp; int error; + st = mem_alloc(sizeof(*st)); + st->st_xprt = NULL; + STAILQ_INIT(&st->st_reqs); + cv_init(&st->st_cond, "rpcsvc"); + mtx_lock(&pool->sp_lock); + LIST_INSERT_HEAD(&pool->sp_threads, st, st_link); - pool->sp_exited = FALSE; + /* + * If we are a new thread which was spawned to cope with + * increased load, set the state back to SVCPOOL_ACTIVE. 
+ */ + if (pool->sp_state == SVCPOOL_THREADSTARTING) + pool->sp_state = SVCPOOL_ACTIVE; - while (!pool->sp_exited) { - xprt = TAILQ_FIRST(&pool->sp_active); - if (!xprt) { - error = msleep(&pool->sp_active, &pool->sp_lock, PCATCH, - "rpcsvc", 0); - if (error) + while (pool->sp_state != SVCPOOL_CLOSING) { + /* + * Check for idle transports once per second. + */ + if (time_uptime > pool->sp_lastidlecheck) { + pool->sp_lastidlecheck = time_uptime; + svc_checkidle(pool); + } + + xprt = st->st_xprt; + if (!xprt && STAILQ_EMPTY(&st->st_reqs)) { + /* + * Enforce maxthreads count. + */ + if (pool->sp_threadcount > pool->sp_maxthreads) + break; + + /* + * Before sleeping, see if we can find an + * active transport which isn't being serviced + * by a thread. + */ + if (svc_request_space_available(pool)) { + TAILQ_FOREACH(xprt, &pool->sp_active, + xp_alink) { + if (!xprt->xp_thread) { + SVC_ACQUIRE(xprt); + xprt->xp_thread = st; + st->st_xprt = xprt; + break; + } + } + } + if (st->st_xprt) + continue; + + LIST_INSERT_HEAD(&pool->sp_idlethreads, st, st_ilink); + error = cv_timedwait_sig(&st->st_cond, &pool->sp_lock, + 5 * hz); + LIST_REMOVE(st, st_ilink); + + /* + * Reduce worker thread count when idle. + */ + if (error == EWOULDBLOCK) { + if (!ismaster + && (pool->sp_threadcount + > pool->sp_minthreads) + && !st->st_xprt + && STAILQ_EMPTY(&st->st_reqs)) + break; + } + if (error == EWOULDBLOCK) + continue; + if (error) { + if (pool->sp_state != SVCPOOL_CLOSING) { + mtx_unlock(&pool->sp_lock); + svc_exit(pool); + mtx_lock(&pool->sp_lock); + } break; + } + + if (pool->sp_state == SVCPOOL_THREADWANTED) { + pool->sp_state = SVCPOOL_THREADSTARTING; + pool->sp_lastcreatetime = time_uptime; + mtx_unlock(&pool->sp_lock); + svc_new_thread(pool); + mtx_lock(&pool->sp_lock); + } continue; } + if (xprt) { + /* + * Drain the transport socket and queue up any + * RPCs. 
+ */ + xprt->xp_lastactive = time_uptime; + stat = XPRT_IDLE; + do { + if (!svc_request_space_available(pool)) + break; + rqstp = NULL; + mtx_unlock(&pool->sp_lock); + stat = svc_getreq(xprt, &rqstp); + mtx_lock(&pool->sp_lock); + if (rqstp) { + /* + * See if the application has + * a preference for some other + * thread. + */ + stpref = st; + if (pool->sp_assign) + stpref = pool->sp_assign(st, + rqstp); + + pool->sp_space_used += + rqstp->rq_size; + if (pool->sp_space_used + > pool->sp_space_used_highest) + pool->sp_space_used_highest = + pool->sp_space_used; + rqstp->rq_thread = stpref; + STAILQ_INSERT_TAIL(&stpref->st_reqs, + rqstp, rq_link); + stpref->st_reqcount++; + + /* + * If we assigned the request + * to another thread, make + * sure its awake and continue + * reading from the + * socket. Otherwise, try to + * find some other thread to + * read from the socket and + * execute the request + * immediately. + */ + if (stpref != st) { + cv_signal(&stpref->st_cond); + continue; + } else { + break; + } + } + } while (stat == XPRT_MOREREQS + && pool->sp_state != SVCPOOL_CLOSING); + + /* + * Move this transport to the end of the + * active list to ensure fairness when + * multiple transports are active. If this was + * the last queued request, svc_getreq will + * end up calling xprt_inactive to remove from + * the active list. + */ + xprt->xp_thread = NULL; + st->st_xprt = NULL; + if (xprt->xp_active) { + xprt_assignthread(xprt); + TAILQ_REMOVE(&pool->sp_active, xprt, xp_alink); + TAILQ_INSERT_TAIL(&pool->sp_active, xprt, + xp_alink); + } + mtx_unlock(&pool->sp_lock); + SVC_RELEASE(xprt); + mtx_lock(&pool->sp_lock); + } + /* - * Move this transport to the end to ensure fairness - * when multiple transports are active. If this was - * the last queued request, svc_getreq will end up - * calling xprt_inactive to remove from the active - * list. + * Execute what we have queued. 
*/ - TAILQ_REMOVE(&pool->sp_active, xprt, xp_alink); - TAILQ_INSERT_TAIL(&pool->sp_active, xprt, xp_alink); + while ((rqstp = STAILQ_FIRST(&st->st_reqs)) != NULL) { + size_t sz = rqstp->rq_size; + mtx_unlock(&pool->sp_lock); + svc_executereq(rqstp); + mtx_lock(&pool->sp_lock); + pool->sp_space_used -= sz; + } + } - mtx_unlock(&pool->sp_lock); - svc_getreq(xprt); - mtx_lock(&pool->sp_lock); + if (st->st_xprt) { + xprt = st->st_xprt; + st->st_xprt = NULL; + SVC_RELEASE(xprt); + } + + KASSERT(STAILQ_EMPTY(&st->st_reqs), ("stray reqs on exit")); + LIST_REMOVE(st, st_link); + pool->sp_threadcount--; + + mtx_unlock(&pool->sp_lock); + + cv_destroy(&st->st_cond); + mem_free(st, sizeof(*st)); + + if (!ismaster) + wakeup(pool); +} + +static void +svc_thread_start(void *arg) +{ + + svc_run_internal((SVCPOOL *) arg, FALSE); + kthread_exit(); +} + +static void +svc_new_thread(SVCPOOL *pool) +{ + struct thread *td; + + pool->sp_threadcount++; + kthread_add(svc_thread_start, pool, + pool->sp_proc, &td, 0, 0, + "%s: service", pool->sp_name); +} + +void +svc_run(SVCPOOL *pool) +{ + int i; + struct proc *p; + struct thread *td; + + p = curproc; + td = curthread; + snprintf(td->td_name, sizeof(td->td_name), + "%s: master", pool->sp_name); + pool->sp_state = SVCPOOL_ACTIVE; + pool->sp_proc = p; + pool->sp_lastcreatetime = time_uptime; + pool->sp_threadcount = 1; + + for (i = 1; i < pool->sp_minthreads; i++) { + svc_new_thread(pool); } + svc_run_internal(pool, TRUE); + + mtx_lock(&pool->sp_lock); + while (pool->sp_threadcount > 0) + msleep(pool, &pool->sp_lock, 0, "svcexit", 0); mtx_unlock(&pool->sp_lock); } void svc_exit(SVCPOOL *pool) { + SVCTHREAD *st; + mtx_lock(&pool->sp_lock); - pool->sp_exited = TRUE; - wakeup(&pool->sp_active); + + pool->sp_state = SVCPOOL_CLOSING; + LIST_FOREACH(st, &pool->sp_idlethreads, st_ilink) + cv_signal(&st->st_cond); + mtx_unlock(&pool->sp_lock); } + +bool_t +svc_getargs(struct svc_req *rqstp, xdrproc_t xargs, void *args) +{ + struct mbuf *m; + XDR 
xdrs; + bool_t stat; + + m = rqstp->rq_args; + rqstp->rq_args = NULL; + + xdrmbuf_create(&xdrs, m, XDR_DECODE); + stat = xargs(&xdrs, args); + XDR_DESTROY(&xdrs); + + return (stat); +} + +bool_t +svc_freeargs(struct svc_req *rqstp, xdrproc_t xargs, void *args) +{ + XDR xdrs; + + if (rqstp->rq_addr) { + free(rqstp->rq_addr, M_SONAME); + rqstp->rq_addr = NULL; + } + + xdrs.x_op = XDR_FREE; + return (xargs(&xdrs, args)); +} + +void +svc_freereq(struct svc_req *rqstp) +{ + SVCTHREAD *st; + SVCXPRT *xprt; + SVCPOOL *pool; + + st = rqstp->rq_thread; + xprt = rqstp->rq_xprt; + if (xprt) + pool = xprt->xp_pool; + else + pool = NULL; + if (st) { + mtx_lock(&pool->sp_lock); + KASSERT(rqstp == STAILQ_FIRST(&st->st_reqs), + ("Freeing request out of order")); + STAILQ_REMOVE_HEAD(&st->st_reqs, rq_link); + st->st_reqcount--; + if (pool->sp_done) + pool->sp_done(st, rqstp); + mtx_unlock(&pool->sp_lock); + } + + if (rqstp->rq_auth.svc_ah_ops) + SVCAUTH_RELEASE(&rqstp->rq_auth); + + if (rqstp->rq_xprt) { + SVC_RELEASE(rqstp->rq_xprt); + } + + if (rqstp->rq_addr) + free(rqstp->rq_addr, M_SONAME); + + if (rqstp->rq_args) + m_freem(rqstp->rq_args); + + free(rqstp, M_RPC); +} diff --git a/sys/rpc/svc.h b/sys/rpc/svc.h index 21c7491..eac9bc0 100644 --- a/sys/rpc/svc.h +++ b/sys/rpc/svc.h @@ -47,6 +47,9 @@ #include <sys/queue.h> #include <sys/_lock.h> #include <sys/_mutex.h> +#include <sys/_sx.h> +#include <sys/condvar.h> +#include <sys/sysctl.h> #endif /* @@ -92,8 +95,23 @@ enum xprt_stat { }; struct __rpc_svcxprt; +struct mbuf; struct xp_ops { +#ifdef _KERNEL + /* receive incoming requests */ + bool_t (*xp_recv)(struct __rpc_svcxprt *, struct rpc_msg *, + struct sockaddr **, struct mbuf **); + /* get transport status */ + enum xprt_stat (*xp_stat)(struct __rpc_svcxprt *); + /* send reply */ + bool_t (*xp_reply)(struct __rpc_svcxprt *, struct rpc_msg *, + struct sockaddr *, struct mbuf *); + /* destroy this struct */ + void (*xp_destroy)(struct __rpc_svcxprt *); + /* catch-all function 
*/ + bool_t (*xp_control)(struct __rpc_svcxprt *, const u_int, void *); +#else /* receive incoming requests */ bool_t (*xp_recv)(struct __rpc_svcxprt *, struct rpc_msg *); /* get transport status */ @@ -106,9 +124,6 @@ struct xp_ops { bool_t (*xp_freeargs)(struct __rpc_svcxprt *, xdrproc_t, void *); /* destroy this struct */ void (*xp_destroy)(struct __rpc_svcxprt *); -#ifdef _KERNEL - /* catch-all function */ - bool_t (*xp_control)(struct __rpc_svcxprt *, const u_int, void *); #endif }; @@ -121,32 +136,35 @@ struct xp_ops2 { #ifdef _KERNEL struct __rpc_svcpool; +struct __rpc_svcthread; #endif /* - * Server side transport handle + * Server side transport handle. In the kernel, transports have a + * reference count which tracks the number of currently assigned + * worker threads plus one for the service pool's reference. */ typedef struct __rpc_svcxprt { #ifdef _KERNEL - struct mtx xp_lock; + volatile u_int xp_refs; + struct sx xp_lock; struct __rpc_svcpool *xp_pool; /* owning pool (see below) */ TAILQ_ENTRY(__rpc_svcxprt) xp_link; TAILQ_ENTRY(__rpc_svcxprt) xp_alink; bool_t xp_registered; /* xprt_register has been called */ bool_t xp_active; /* xprt_active has been called */ + struct __rpc_svcthread *xp_thread; /* assigned service thread */ struct socket* xp_socket; const struct xp_ops *xp_ops; char *xp_netid; /* network token */ - struct netbuf xp_ltaddr; /* local transport address */ - struct netbuf xp_rtaddr; /* remote transport address */ - struct opaque_auth xp_verf; /* raw response verifier */ - uint32_t xp_xid; /* current transaction ID */ - XDR xp_xdrreq; /* xdr stream for decoding request */ - XDR xp_xdrrep; /* xdr stream for encoding reply */ + struct sockaddr_storage xp_ltaddr; /* local transport address */ + struct sockaddr_storage xp_rtaddr; /* remote transport address */ void *xp_p1; /* private: for use by svc ops */ void *xp_p2; /* private: for use by svc ops */ void *xp_p3; /* private: for use by svc lib */ int xp_type; /* transport type */ + int 
xp_idletimeout; /* idle time before closing */ + time_t xp_lastactive; /* time of last RPC */ #else int xp_fd; u_short xp_port; /* associated port number */ @@ -167,6 +185,33 @@ typedef struct __rpc_svcxprt { #endif } SVCXPRT; +/* + * Interface to server-side authentication flavors. + */ +typedef struct __rpc_svcauth { + struct svc_auth_ops { +#ifdef _KERNEL + int (*svc_ah_wrap)(struct __rpc_svcauth *, struct mbuf **); + int (*svc_ah_unwrap)(struct __rpc_svcauth *, struct mbuf **); + void (*svc_ah_release)(struct __rpc_svcauth *); +#else + int (*svc_ah_wrap)(struct __rpc_svcauth *, XDR *, + xdrproc_t, caddr_t); + int (*svc_ah_unwrap)(struct __rpc_svcauth *, XDR *, + xdrproc_t, caddr_t); +#endif + } *svc_ah_ops; + void *svc_ah_private; +} SVCAUTH; + +/* + * Server transport extensions (accessed via xp_p3). + */ +typedef struct __rpc_svcxprt_ext { + int xp_flags; /* versquiet */ + SVCAUTH xp_auth; /* interface to auth methods */ +} SVCXPRT_EXT; + #ifdef _KERNEL /* @@ -184,6 +229,61 @@ struct svc_callout { }; TAILQ_HEAD(svc_callout_list, svc_callout); +struct __rpc_svcthread; + +/* + * Service request + */ +struct svc_req { + STAILQ_ENTRY(svc_req) rq_link; /* list of requests for a thread */ + struct __rpc_svcthread *rq_thread; /* thread which is to execute this */ + uint32_t rq_xid; /* RPC transaction ID */ + uint32_t rq_prog; /* service program number */ + uint32_t rq_vers; /* service protocol version */ + uint32_t rq_proc; /* the desired procedure */ + size_t rq_size; /* space used by request */ + struct mbuf *rq_args; /* XDR-encoded procedure arguments */ + struct opaque_auth rq_cred; /* raw creds from the wire */ + struct opaque_auth rq_verf; /* verifier for the reply */ + void *rq_clntcred; /* read only cooked cred */ + SVCAUTH rq_auth; /* interface to auth methods */ + SVCXPRT *rq_xprt; /* associated transport */ + struct sockaddr *rq_addr; /* reply address or NULL if connected */ + void *rq_p1; /* application workspace */ + int rq_p2; /* application workspace 
*/ + uint64_t rq_p3; /* application workspace */ + char rq_credarea[3*MAX_AUTH_BYTES]; +}; +STAILQ_HEAD(svc_reqlist, svc_req); + +#define svc_getrpccaller(rq) \ + ((rq)->rq_addr ? (rq)->rq_addr : \ + (struct sockaddr *) &(rq)->rq_xprt->xp_rtaddr) + +/* + * This structure is used to manage a thread which is executing + * requests from a service pool. A service thread is in one of three + * states: + * + * SVCTHREAD_SLEEPING waiting for a request to process + * SVCTHREAD_ACTIVE processing a request + * SVCTHREAD_EXITING exiting after finishing current request + * + * Threads which have no work to process sleep on the pool's sp_active + * list. When a transport becomes active, it is assigned a service + * thread to read and execute pending RPCs. + */ +typedef struct __rpc_svcthread { + SVCXPRT *st_xprt; /* transport we are processing */ + struct svc_reqlist st_reqs; /* RPC requests to execute */ + int st_reqcount; /* number of queued reqs */ + struct cv st_cond; /* sleeping for work */ + LIST_ENTRY(__rpc_svcthread) st_link; /* all threads list */ + LIST_ENTRY(__rpc_svcthread) st_ilink; /* idle threads list */ + LIST_ENTRY(__rpc_svcthread) st_alink; /* application thread list */ +} SVCTHREAD; +LIST_HEAD(svcthread_list, __rpc_svcthread); + /* * In the kernel, we can't use global variables to store lists of * transports etc. since otherwise we could not have two unrelated RPC @@ -197,15 +297,55 @@ TAILQ_HEAD(svc_callout_list, svc_callout); * server. 
*/ TAILQ_HEAD(svcxprt_list, __rpc_svcxprt); +enum svcpool_state { + SVCPOOL_INIT, /* svc_run not called yet */ + SVCPOOL_ACTIVE, /* normal running state */ + SVCPOOL_THREADWANTED, /* new service thread requested */ + SVCPOOL_THREADSTARTING, /* new service thread started */ + SVCPOOL_CLOSING /* svc_exit called */ +}; +typedef SVCTHREAD *pool_assign_fn(SVCTHREAD *, struct svc_req *); +typedef void pool_done_fn(SVCTHREAD *, struct svc_req *); typedef struct __rpc_svcpool { struct mtx sp_lock; /* protect the transport lists */ + const char *sp_name; /* pool name (e.g. "nfsd", "NLM" */ + enum svcpool_state sp_state; /* current pool state */ + struct proc *sp_proc; /* process which is in svc_run */ struct svcxprt_list sp_xlist; /* all transports in the pool */ struct svcxprt_list sp_active; /* transports needing service */ struct svc_callout_list sp_callouts; /* (prog,vers)->dispatch list */ - bool_t sp_exited; /* true if shutting down */ + struct svcthread_list sp_threads; /* service threads */ + struct svcthread_list sp_idlethreads; /* idle service threads */ + int sp_minthreads; /* minimum service thread count */ + int sp_maxthreads; /* maximum service thread count */ + int sp_threadcount; /* current service thread count */ + time_t sp_lastcreatetime; /* when we last started a thread */ + time_t sp_lastidlecheck; /* when we last checked idle transports */ + + /* + * Hooks to allow an application to control request to thread + * placement. + */ + pool_assign_fn *sp_assign; + pool_done_fn *sp_done; + + /* + * These variables are used to put an upper bound on the + * amount of memory used by RPC requests which are queued + * waiting for execution. 
+ */ + unsigned int sp_space_low; + unsigned int sp_space_high; + unsigned int sp_space_used; + unsigned int sp_space_used_highest; + bool_t sp_space_throttled; + int sp_space_throttle_count; + + struct replay_cache *sp_rcache; /* optional replay cache */ + struct sysctl_ctx_list sp_sysctl; } SVCPOOL; -#endif +#else /* * Service request @@ -224,6 +364,8 @@ struct svc_req { */ #define svc_getrpccaller(x) (&(x)->xp_rtaddr) +#endif + /* * Operations defined on an SVCXPRT handle * @@ -232,6 +374,32 @@ struct svc_req { * xdrproc_t xargs; * void * argsp; */ +#ifdef _KERNEL + +#define SVC_ACQUIRE(xprt) \ + refcount_acquire(&(xprt)->xp_refs) + +#define SVC_RELEASE(xprt) \ + if (refcount_release(&(xprt)->xp_refs)) \ + SVC_DESTROY(xprt) + +#define SVC_RECV(xprt, msg, addr, args) \ + (*(xprt)->xp_ops->xp_recv)((xprt), (msg), (addr), (args)) + +#define SVC_STAT(xprt) \ + (*(xprt)->xp_ops->xp_stat)(xprt) + +#define SVC_REPLY(xprt, msg, addr, m) \ + (*(xprt)->xp_ops->xp_reply) ((xprt), (msg), (addr), (m)) + +#define SVC_DESTROY(xprt) \ + (*(xprt)->xp_ops->xp_destroy)(xprt) + +#define SVC_CONTROL(xprt, rq, in) \ + (*(xprt)->xp_ops->xp_control)((xprt), (rq), (in)) + +#else + #define SVC_RECV(xprt, msg) \ (*(xprt)->xp_ops->xp_recv)((xprt), (msg)) #define svc_recv(xprt, msg) \ @@ -262,12 +430,32 @@ struct svc_req { #define svc_destroy(xprt) \ (*(xprt)->xp_ops->xp_destroy)(xprt) -#ifdef _KERNEL -#define SVC_CONTROL(xprt, rq, in) \ - (*(xprt)->xp_ops->xp_control)((xprt), (rq), (in)) -#else #define SVC_CONTROL(xprt, rq, in) \ (*(xprt)->xp_ops2->xp_control)((xprt), (rq), (in)) + +#endif + +#define SVC_EXT(xprt) \ + ((SVCXPRT_EXT *) xprt->xp_p3) + +#define SVC_AUTH(xprt) \ + (SVC_EXT(xprt)->xp_auth) + +/* + * Operations defined on an SVCAUTH handle + */ +#ifdef _KERNEL +#define SVCAUTH_WRAP(auth, mp) \ + ((auth)->svc_ah_ops->svc_ah_wrap(auth, mp)) +#define SVCAUTH_UNWRAP(auth, mp) \ + ((auth)->svc_ah_ops->svc_ah_unwrap(auth, mp)) +#define SVCAUTH_RELEASE(auth) \ + 
((auth)->svc_ah_ops->svc_ah_release(auth)) +#else +#define SVCAUTH_WRAP(auth, xdrs, xfunc, xwhere) \ + ((auth)->svc_ah_ops->svc_ah_wrap(auth, xdrs, xfunc, xwhere)) +#define SVCAUTH_UNWRAP(auth, xdrs, xfunc, xwhere) \ + ((auth)->svc_ah_ops->svc_ah_unwrap(auth, xdrs, xfunc, xwhere)) #endif /* @@ -332,6 +520,7 @@ __END_DECLS __BEGIN_DECLS extern void xprt_active(SVCXPRT *); extern void xprt_inactive(SVCXPRT *); +extern void xprt_inactive_locked(SVCXPRT *); __END_DECLS #endif @@ -363,6 +552,17 @@ __END_DECLS */ __BEGIN_DECLS +#ifdef _KERNEL +extern bool_t svc_sendreply(struct svc_req *, xdrproc_t, void *); +extern bool_t svc_sendreply_mbuf(struct svc_req *, struct mbuf *); +extern void svcerr_decode(struct svc_req *); +extern void svcerr_weakauth(struct svc_req *); +extern void svcerr_noproc(struct svc_req *); +extern void svcerr_progvers(struct svc_req *, rpcvers_t, rpcvers_t); +extern void svcerr_auth(struct svc_req *, enum auth_stat); +extern void svcerr_noprog(struct svc_req *); +extern void svcerr_systemerr(struct svc_req *); +#else extern bool_t svc_sendreply(SVCXPRT *, xdrproc_t, void *); extern void svcerr_decode(SVCXPRT *); extern void svcerr_weakauth(SVCXPRT *); @@ -371,6 +571,7 @@ extern void svcerr_progvers(SVCXPRT *, rpcvers_t, rpcvers_t); extern void svcerr_auth(SVCXPRT *, enum auth_stat); extern void svcerr_noprog(SVCXPRT *); extern void svcerr_systemerr(SVCXPRT *); +#endif extern int rpc_reg(rpcprog_t, rpcvers_t, rpcproc_t, char *(*)(char *), xdrproc_t, xdrproc_t, char *); @@ -410,6 +611,8 @@ extern void rpctest_service(void); __END_DECLS __BEGIN_DECLS +extern SVCXPRT *svc_xprt_alloc(void); +extern void svc_xprt_free(SVCXPRT *); #ifndef _KERNEL extern void svc_getreq(int); extern void svc_getreqset(fd_set *); @@ -421,6 +624,10 @@ extern void svc_exit(void); #else extern void svc_run(SVCPOOL *); extern void svc_exit(SVCPOOL *); +extern bool_t svc_getargs(struct svc_req *, xdrproc_t, void *); +extern bool_t svc_freeargs(struct svc_req *, xdrproc_t, void 
*); +extern void svc_freereq(struct svc_req *); + #endif __END_DECLS @@ -441,7 +648,8 @@ __BEGIN_DECLS /* * Create a new service pool. */ -extern SVCPOOL* svcpool_create(void); +extern SVCPOOL* svcpool_create(const char *name, + struct sysctl_oid_list *sysctl_base); /* * Destroy a service pool, including all registered transports. diff --git a/sys/rpc/svc_auth.c b/sys/rpc/svc_auth.c index 22d4e61..6d5a79b 100644 --- a/sys/rpc/svc_auth.c +++ b/sys/rpc/svc_auth.c @@ -52,6 +52,13 @@ __FBSDID("$FreeBSD$"); #include <rpc/rpc.h> +static enum auth_stat (*_svcauth_rpcsec_gss)(struct svc_req *, + struct rpc_msg *) = NULL; +static int (*_svcauth_rpcsec_gss_getcred)(struct svc_req *, + struct ucred **, int *); + +static struct svc_auth_ops svc_auth_null_ops; + /* * The call rpc message, msg has been obtained from the wire. The msg contains * the raw form of credentials and verifiers. authenticate returns AUTH_OK @@ -77,8 +84,8 @@ _authenticate(struct svc_req *rqst, struct rpc_msg *msg) enum auth_stat dummy; rqst->rq_cred = msg->rm_call.cb_cred; - rqst->rq_xprt->xp_verf.oa_flavor = _null_auth.oa_flavor; - rqst->rq_xprt->xp_verf.oa_length = 0; + rqst->rq_auth.svc_ah_ops = &svc_auth_null_ops; + rqst->rq_auth.svc_ah_private = NULL; cred_flavor = rqst->rq_cred.oa_flavor; switch (cred_flavor) { case AUTH_NULL: @@ -90,6 +97,11 @@ _authenticate(struct svc_req *rqst, struct rpc_msg *msg) case AUTH_SHORT: dummy = _svcauth_short(rqst, msg); return (dummy); + case RPCSEC_GSS: + if (!_svcauth_rpcsec_gss) + return (AUTH_REJECTEDCRED); + dummy = _svcauth_rpcsec_gss(rqst, msg); + return (dummy); default: break; } @@ -97,21 +109,65 @@ _authenticate(struct svc_req *rqst, struct rpc_msg *msg) return (AUTH_REJECTEDCRED); } +/* + * A set of null auth methods used by any authentication protocols + * that don't need to inspect or modify the message body. 
+ */ +static bool_t +svcauth_null_wrap(SVCAUTH *auth, struct mbuf **mp) +{ + + return (TRUE); +} + +static bool_t +svcauth_null_unwrap(SVCAUTH *auth, struct mbuf **mp) +{ + + return (TRUE); +} + +static void +svcauth_null_release(SVCAUTH *auth) +{ + +} + +static struct svc_auth_ops svc_auth_null_ops = { + svcauth_null_wrap, + svcauth_null_unwrap, + svcauth_null_release, +}; + /*ARGSUSED*/ enum auth_stat _svcauth_null(struct svc_req *rqst, struct rpc_msg *msg) { + + rqst->rq_verf = _null_auth; return (AUTH_OK); } int -svc_getcred(struct svc_req *rqst, struct ucred *cr, int *flavorp) +svc_auth_reg(int flavor, + enum auth_stat (*svcauth)(struct svc_req *, struct rpc_msg *), + int (*getcred)(struct svc_req *, struct ucred **, int *)) { + + if (flavor == RPCSEC_GSS) { + _svcauth_rpcsec_gss = svcauth; + _svcauth_rpcsec_gss_getcred = getcred; + } + return (TRUE); +} + +int +svc_getcred(struct svc_req *rqst, struct ucred **crp, int *flavorp) +{ + struct ucred *cr = NULL; int flavor, i; struct xucred *xcr; - KASSERT(!crshared(cr), ("svc_getcred with shared cred")); - flavor = rqst->rq_cred.oa_flavor; if (flavorp) *flavorp = flavor; @@ -119,13 +175,20 @@ svc_getcred(struct svc_req *rqst, struct ucred *cr, int *flavorp) switch (flavor) { case AUTH_UNIX: xcr = (struct xucred *) rqst->rq_clntcred; + cr = crget(); cr->cr_uid = cr->cr_ruid = cr->cr_svuid = xcr->cr_uid; cr->cr_ngroups = xcr->cr_ngroups; for (i = 0; i < xcr->cr_ngroups; i++) cr->cr_groups[i] = xcr->cr_groups[i]; - cr->cr_rgid = cr->cr_groups[0]; + cr->cr_rgid = cr->cr_svgid = cr->cr_groups[0]; + *crp = cr; return (TRUE); + case RPCSEC_GSS: + if (!_svcauth_rpcsec_gss_getcred) + return (FALSE); + return (_svcauth_rpcsec_gss_getcred(rqst, crp, flavorp)); + default: return (FALSE); } diff --git a/sys/rpc/svc_auth.h b/sys/rpc/svc_auth.h index 26c191a..9e23876 100644 --- a/sys/rpc/svc_auth.h +++ b/sys/rpc/svc_auth.h @@ -47,19 +47,31 @@ */ __BEGIN_DECLS extern enum auth_stat _authenticate(struct svc_req *, struct rpc_msg 
*); +#ifdef _KERNEL +extern int svc_auth_reg(int, + enum auth_stat (*)(struct svc_req *, struct rpc_msg *), + int (*)(struct svc_req *, struct ucred **, int *)); +#else +extern int svc_auth_reg(int, enum auth_stat (*)(struct svc_req *, + struct rpc_msg *)); +#endif -extern int svc_getcred(struct svc_req *, struct ucred *, int *); + +extern int svc_getcred(struct svc_req *, struct ucred **, int *); /* * struct svc_req *req; -- RPC request - * struct ucred *cr -- Kernel cred to modify + * struct ucred **crp -- Kernel cred to modify * int *flavorp -- Return RPC auth flavor * * Retrieve unix creds corresponding to an RPC request, if * possible. The auth flavor (AUTH_NONE or AUTH_UNIX) is returned in - * *flavorp. If the flavor is AUTH_UNIX the caller's ucred structure - * will be modified to reflect the values from the request. Return's - * non-zero if credentials were retrieved form the request, otherwise - * zero. + * *flavorp. If the flavor is AUTH_UNIX the caller's ucred pointer + * will be modified to point at a ucred structure which reflects the + * values from the request. The caller should call crfree on this + * pointer. + * + * Return's non-zero if credentials were retrieved from the request, + * otherwise zero. 
*/ __END_DECLS diff --git a/sys/rpc/svc_auth_unix.c b/sys/rpc/svc_auth_unix.c index 9c6cdd7..0c11a4a 100644 --- a/sys/rpc/svc_auth_unix.c +++ b/sys/rpc/svc_auth_unix.c @@ -120,8 +120,7 @@ _svcauth_unix(struct svc_req *rqst, struct rpc_msg *msg) goto done; } - rqst->rq_xprt->xp_verf.oa_flavor = AUTH_NULL; - rqst->rq_xprt->xp_verf.oa_length = 0; + rqst->rq_verf = _null_auth; stat = AUTH_OK; done: XDR_DESTROY(&xdrs); diff --git a/sys/rpc/svc_dg.c b/sys/rpc/svc_dg.c index 666b952..72721b0 100644 --- a/sys/rpc/svc_dg.c +++ b/sys/rpc/svc_dg.c @@ -53,6 +53,7 @@ __FBSDID("$FreeBSD$"); #include <sys/queue.h> #include <sys/socket.h> #include <sys/socketvar.h> +#include <sys/sx.h> #include <sys/systm.h> #include <sys/uio.h> @@ -61,10 +62,10 @@ __FBSDID("$FreeBSD$"); #include <rpc/rpc_com.h> static enum xprt_stat svc_dg_stat(SVCXPRT *); -static bool_t svc_dg_recv(SVCXPRT *, struct rpc_msg *); -static bool_t svc_dg_reply(SVCXPRT *, struct rpc_msg *); -static bool_t svc_dg_getargs(SVCXPRT *, xdrproc_t, void *); -static bool_t svc_dg_freeargs(SVCXPRT *, xdrproc_t, void *); +static bool_t svc_dg_recv(SVCXPRT *, struct rpc_msg *, + struct sockaddr **, struct mbuf **); +static bool_t svc_dg_reply(SVCXPRT *, struct rpc_msg *, + struct sockaddr *, struct mbuf *); static void svc_dg_destroy(SVCXPRT *); static bool_t svc_dg_control(SVCXPRT *, const u_int, void *); static void svc_dg_soupcall(struct socket *so, void *arg, int waitflag); @@ -72,9 +73,7 @@ static void svc_dg_soupcall(struct socket *so, void *arg, int waitflag); static struct xp_ops svc_dg_ops = { .xp_recv = svc_dg_recv, .xp_stat = svc_dg_stat, - .xp_getargs = svc_dg_getargs, .xp_reply = svc_dg_reply, - .xp_freeargs = svc_dg_freeargs, .xp_destroy = svc_dg_destroy, .xp_control = svc_dg_control, }; @@ -116,9 +115,8 @@ svc_dg_create(SVCPOOL *pool, struct socket *so, size_t sendsize, return (NULL); } - xprt = mem_alloc(sizeof (SVCXPRT)); - memset(xprt, 0, sizeof (SVCXPRT)); - mtx_init(&xprt->xp_lock, "xprt->xp_lock", NULL, 
MTX_DEF); + xprt = svc_xprt_alloc(); + sx_init(&xprt->xp_lock, "xprt->xp_lock"); xprt->xp_pool = pool; xprt->xp_socket = so; xprt->xp_p1 = NULL; @@ -129,16 +127,9 @@ svc_dg_create(SVCPOOL *pool, struct socket *so, size_t sendsize, if (error) goto freedata; - xprt->xp_ltaddr.buf = mem_alloc(sizeof (struct sockaddr_storage)); - xprt->xp_ltaddr.maxlen = sizeof (struct sockaddr_storage); - xprt->xp_ltaddr.len = sa->sa_len; - memcpy(xprt->xp_ltaddr.buf, sa, sa->sa_len); + memcpy(&xprt->xp_ltaddr, sa, sa->sa_len); free(sa, M_SONAME); - xprt->xp_rtaddr.buf = mem_alloc(sizeof (struct sockaddr_storage)); - xprt->xp_rtaddr.maxlen = sizeof (struct sockaddr_storage); - xprt->xp_rtaddr.len = 0; - xprt_register(xprt); SOCKBUF_LOCK(&so->so_rcv); @@ -151,7 +142,7 @@ svc_dg_create(SVCPOOL *pool, struct socket *so, size_t sendsize, freedata: (void) printf(svc_dg_str, __no_mem_str); if (xprt) { - (void) mem_free(xprt, sizeof (SVCXPRT)); + svc_xprt_free(xprt); } return (NULL); } @@ -161,34 +152,34 @@ static enum xprt_stat svc_dg_stat(SVCXPRT *xprt) { + if (soreadable(xprt->xp_socket)) + return (XPRT_MOREREQS); + return (XPRT_IDLE); } static bool_t -svc_dg_recv(SVCXPRT *xprt, struct rpc_msg *msg) +svc_dg_recv(SVCXPRT *xprt, struct rpc_msg *msg, + struct sockaddr **addrp, struct mbuf **mp) { struct uio uio; struct sockaddr *raddr; struct mbuf *mreq; + XDR xdrs; int error, rcvflag; /* + * Serialise access to the socket. + */ + sx_xlock(&xprt->xp_lock); + + /* * The socket upcall calls xprt_active() which will eventually * cause the server to call us here. We attempt to read a * packet from the socket and process it. If the read fails, * we have drained all pending requests so we call * xprt_inactive(). - * - * The lock protects us in the case where a new packet arrives - * on the socket after our call to soreceive fails with - * EWOULDBLOCK - the call to xprt_active() in the upcall will - * happen only after our call to xprt_inactive() which ensures - * that we will remain active. 
It might be possible to use - * SOCKBUF_LOCK for this - its not clear to me what locks are - * held during the upcall. */ - mtx_lock(&xprt->xp_lock); - uio.uio_resid = 1000000000; uio.uio_td = curthread; mreq = NULL; @@ -196,8 +187,19 @@ svc_dg_recv(SVCXPRT *xprt, struct rpc_msg *msg) error = soreceive(xprt->xp_socket, &raddr, &uio, &mreq, NULL, &rcvflag); if (error == EWOULDBLOCK) { - xprt_inactive(xprt); - mtx_unlock(&xprt->xp_lock); + /* + * We must re-test for readability after taking the + * lock to protect us in the case where a new packet + * arrives on the socket after our call to soreceive + * fails with EWOULDBLOCK. The pool lock protects us + * from racing the upcall after our soreadable() call + * returns false. + */ + mtx_lock(&xprt->xp_pool->sp_lock); + if (!soreadable(xprt->xp_socket)) + xprt_inactive_locked(xprt); + mtx_unlock(&xprt->xp_pool->sp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } @@ -208,45 +210,52 @@ svc_dg_recv(SVCXPRT *xprt, struct rpc_msg *msg) xprt->xp_socket->so_rcv.sb_flags &= ~SB_UPCALL; SOCKBUF_UNLOCK(&xprt->xp_socket->so_rcv); xprt_inactive(xprt); - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } - mtx_unlock(&xprt->xp_lock); - - KASSERT(raddr->sa_len < xprt->xp_rtaddr.maxlen, - ("Unexpected remote address length")); - memcpy(xprt->xp_rtaddr.buf, raddr, raddr->sa_len); - xprt->xp_rtaddr.len = raddr->sa_len; - free(raddr, M_SONAME); + sx_xunlock(&xprt->xp_lock); - xdrmbuf_create(&xprt->xp_xdrreq, mreq, XDR_DECODE); - if (! xdr_callmsg(&xprt->xp_xdrreq, msg)) { - XDR_DESTROY(&xprt->xp_xdrreq); + xdrmbuf_create(&xdrs, mreq, XDR_DECODE); + if (! 
xdr_callmsg(&xdrs, msg)) { + XDR_DESTROY(&xdrs); return (FALSE); } - xprt->xp_xid = msg->rm_xid; + + *addrp = raddr; + *mp = xdrmbuf_getall(&xdrs); + XDR_DESTROY(&xdrs); return (TRUE); } static bool_t -svc_dg_reply(SVCXPRT *xprt, struct rpc_msg *msg) +svc_dg_reply(SVCXPRT *xprt, struct rpc_msg *msg, + struct sockaddr *addr, struct mbuf *m) { + XDR xdrs; struct mbuf *mrep; - bool_t stat = FALSE; + bool_t stat = TRUE; int error; MGETHDR(mrep, M_WAIT, MT_DATA); - MCLGET(mrep, M_WAIT); mrep->m_len = 0; - xdrmbuf_create(&xprt->xp_xdrrep, mrep, XDR_ENCODE); - msg->rm_xid = xprt->xp_xid; - if (xdr_replymsg(&xprt->xp_xdrrep, msg)) { + xdrmbuf_create(&xdrs, mrep, XDR_ENCODE); + + if (msg->rm_reply.rp_stat == MSG_ACCEPTED && + msg->rm_reply.rp_acpt.ar_stat == SUCCESS) { + if (!xdr_replymsg(&xdrs, msg)) + stat = FALSE; + else + xdrmbuf_append(&xdrs, m); + } else { + stat = xdr_replymsg(&xdrs, msg); + } + + if (stat) { m_fixhdr(mrep); - error = sosend(xprt->xp_socket, - (struct sockaddr *) xprt->xp_rtaddr.buf, NULL, mrep, NULL, + error = sosend(xprt->xp_socket, addr, NULL, mrep, NULL, 0, curthread); if (!error) { stat = TRUE; @@ -255,61 +264,29 @@ svc_dg_reply(SVCXPRT *xprt, struct rpc_msg *msg) m_freem(mrep); } - /* - * This frees the request mbuf chain as well. The reply mbuf - * chain was consumed by sosend. - */ - XDR_DESTROY(&xprt->xp_xdrreq); - XDR_DESTROY(&xprt->xp_xdrrep); + XDR_DESTROY(&xdrs); xprt->xp_p2 = NULL; return (stat); } -static bool_t -svc_dg_getargs(SVCXPRT *xprt, xdrproc_t xdr_args, void *args_ptr) -{ - - return (xdr_args(&xprt->xp_xdrreq, args_ptr)); -} - -static bool_t -svc_dg_freeargs(SVCXPRT *xprt, xdrproc_t xdr_args, void *args_ptr) -{ - XDR xdrs; - - /* - * Free the request mbuf here - this allows us to handle - * protocols where not all requests have replies - * (i.e. NLM). Note that xdrmbuf_destroy handles being called - * twice correctly - the mbuf will only be freed once. 
- */ - XDR_DESTROY(&xprt->xp_xdrreq); - - xdrs.x_op = XDR_FREE; - return (xdr_args(&xdrs, args_ptr)); -} - static void svc_dg_destroy(SVCXPRT *xprt) { + SOCKBUF_LOCK(&xprt->xp_socket->so_rcv); xprt->xp_socket->so_upcallarg = NULL; xprt->xp_socket->so_upcall = NULL; xprt->xp_socket->so_rcv.sb_flags &= ~SB_UPCALL; SOCKBUF_UNLOCK(&xprt->xp_socket->so_rcv); - xprt_unregister(xprt); - - mtx_destroy(&xprt->xp_lock); + sx_destroy(&xprt->xp_lock); if (xprt->xp_socket) (void)soclose(xprt->xp_socket); - if (xprt->xp_rtaddr.buf) - (void) mem_free(xprt->xp_rtaddr.buf, xprt->xp_rtaddr.maxlen); - if (xprt->xp_ltaddr.buf) - (void) mem_free(xprt->xp_ltaddr.buf, xprt->xp_ltaddr.maxlen); - (void) mem_free(xprt, sizeof (SVCXPRT)); + if (xprt->xp_netid) + (void) mem_free(xprt->xp_netid, strlen(xprt->xp_netid) + 1); + svc_xprt_free(xprt); } static bool_t @@ -328,7 +305,5 @@ svc_dg_soupcall(struct socket *so, void *arg, int waitflag) { SVCXPRT *xprt = (SVCXPRT *) arg; - mtx_lock(&xprt->xp_lock); xprt_active(xprt); - mtx_unlock(&xprt->xp_lock); } diff --git a/sys/rpc/svc_generic.c b/sys/rpc/svc_generic.c index 1f9b2e2..790b4ba 100644 --- a/sys/rpc/svc_generic.c +++ b/sys/rpc/svc_generic.c @@ -178,102 +178,13 @@ svc_tp_create( "svc_tp_create: Could not register prog %u vers %u on %s\n", (unsigned)prognum, (unsigned)versnum, nconf->nc_netid); - SVC_DESTROY(xprt); + xprt_unregister(xprt); return (NULL); } return (xprt); } /* - * Bind a socket to a privileged IP port - */ -int bindresvport(struct socket *so, struct sockaddr *sa); -int -bindresvport(struct socket *so, struct sockaddr *sa) -{ - int old, error, af; - bool_t freesa = FALSE; - struct sockaddr_in *sin; -#ifdef INET6 - struct sockaddr_in6 *sin6; -#endif - struct sockopt opt; - int proto, portrange, portlow; - u_int16_t *portp; - socklen_t salen; - - if (sa == NULL) { - error = so->so_proto->pr_usrreqs->pru_sockaddr(so, &sa); - if (error) - return (error); - freesa = TRUE; - af = sa->sa_family; - salen = sa->sa_len; - memset(sa, 0, 
sa->sa_len); - } else { - af = sa->sa_family; - salen = sa->sa_len; - } - - switch (af) { - case AF_INET: - proto = IPPROTO_IP; - portrange = IP_PORTRANGE; - portlow = IP_PORTRANGE_LOW; - sin = (struct sockaddr_in *)sa; - portp = &sin->sin_port; - break; -#ifdef INET6 - case AF_INET6: - proto = IPPROTO_IPV6; - portrange = IPV6_PORTRANGE; - portlow = IPV6_PORTRANGE_LOW; - sin6 = (struct sockaddr_in6 *)sa; - portp = &sin6->sin6_port; - break; -#endif - default: - return (EPFNOSUPPORT); - } - - sa->sa_family = af; - sa->sa_len = salen; - - if (*portp == 0) { - bzero(&opt, sizeof(opt)); - opt.sopt_dir = SOPT_GET; - opt.sopt_level = proto; - opt.sopt_name = portrange; - opt.sopt_val = &old; - opt.sopt_valsize = sizeof(old); - error = sogetopt(so, &opt); - if (error) - goto out; - - opt.sopt_dir = SOPT_SET; - opt.sopt_val = &portlow; - error = sosetopt(so, &opt); - if (error) - goto out; - } - - error = sobind(so, sa, curthread); - - if (*portp == 0) { - if (error) { - opt.sopt_dir = SOPT_SET; - opt.sopt_val = &old; - sosetopt(so, &opt); - } - } -out: - if (freesa) - free(sa, M_SONAME); - - return (error); -} - -/* * If so is NULL, then it opens a socket for the given transport * provider (nconf cannot be NULL then). If the t_state is T_UNBND and * bindaddr is NON-NULL, it performs a t_bind using the bindaddr. 
For @@ -401,7 +312,7 @@ freedata: if (xprt) { if (!madeso) /* so that svc_destroy doesnt close fd */ xprt->xp_socket = NULL; - SVC_DESTROY(xprt); + xprt_unregister(xprt); } return (NULL); } diff --git a/sys/rpc/svc_vc.c b/sys/rpc/svc_vc.c index 47530da..e3f0350 100644 --- a/sys/rpc/svc_vc.c +++ b/sys/rpc/svc_vc.c @@ -54,6 +54,7 @@ __FBSDID("$FreeBSD$"); #include <sys/queue.h> #include <sys/socket.h> #include <sys/socketvar.h> +#include <sys/sx.h> #include <sys/systm.h> #include <sys/uio.h> #include <netinet/tcp.h> @@ -62,16 +63,17 @@ __FBSDID("$FreeBSD$"); #include <rpc/rpc_com.h> -static bool_t svc_vc_rendezvous_recv(SVCXPRT *, struct rpc_msg *); +static bool_t svc_vc_rendezvous_recv(SVCXPRT *, struct rpc_msg *, + struct sockaddr **, struct mbuf **); static enum xprt_stat svc_vc_rendezvous_stat(SVCXPRT *); static void svc_vc_rendezvous_destroy(SVCXPRT *); static bool_t svc_vc_null(void); static void svc_vc_destroy(SVCXPRT *); static enum xprt_stat svc_vc_stat(SVCXPRT *); -static bool_t svc_vc_recv(SVCXPRT *, struct rpc_msg *); -static bool_t svc_vc_getargs(SVCXPRT *, xdrproc_t, void *); -static bool_t svc_vc_freeargs(SVCXPRT *, xdrproc_t, void *); -static bool_t svc_vc_reply(SVCXPRT *, struct rpc_msg *); +static bool_t svc_vc_recv(SVCXPRT *, struct rpc_msg *, + struct sockaddr **, struct mbuf **); +static bool_t svc_vc_reply(SVCXPRT *, struct rpc_msg *, + struct sockaddr *, struct mbuf *); static bool_t svc_vc_control(SVCXPRT *xprt, const u_int rq, void *in); static bool_t svc_vc_rendezvous_control (SVCXPRT *xprt, const u_int rq, void *in); @@ -83,9 +85,8 @@ static void svc_vc_soupcall(struct socket *so, void *arg, int waitflag); static struct xp_ops svc_vc_rendezvous_ops = { .xp_recv = svc_vc_rendezvous_recv, .xp_stat = svc_vc_rendezvous_stat, - .xp_getargs = (bool_t (*)(SVCXPRT *, xdrproc_t, void *))svc_vc_null, - .xp_reply = (bool_t (*)(SVCXPRT *, struct rpc_msg *))svc_vc_null, - .xp_freeargs = (bool_t (*)(SVCXPRT *, xdrproc_t, void *))svc_vc_null, + .xp_reply 
= (bool_t (*)(SVCXPRT *, struct rpc_msg *, + struct sockaddr *, struct mbuf *))svc_vc_null, .xp_destroy = svc_vc_rendezvous_destroy, .xp_control = svc_vc_rendezvous_control }; @@ -93,9 +94,7 @@ static struct xp_ops svc_vc_rendezvous_ops = { static struct xp_ops svc_vc_ops = { .xp_recv = svc_vc_recv, .xp_stat = svc_vc_stat, - .xp_getargs = svc_vc_getargs, .xp_reply = svc_vc_reply, - .xp_freeargs = svc_vc_freeargs, .xp_destroy = svc_vc_destroy, .xp_control = svc_vc_control }; @@ -141,28 +140,21 @@ svc_vc_create(SVCPOOL *pool, struct socket *so, size_t sendsize, return (xprt); } - xprt = mem_alloc(sizeof(SVCXPRT)); - mtx_init(&xprt->xp_lock, "xprt->xp_lock", NULL, MTX_DEF); + xprt = svc_xprt_alloc(); + sx_init(&xprt->xp_lock, "xprt->xp_lock"); xprt->xp_pool = pool; xprt->xp_socket = so; xprt->xp_p1 = NULL; xprt->xp_p2 = NULL; - xprt->xp_p3 = NULL; - xprt->xp_verf = _null_auth; xprt->xp_ops = &svc_vc_rendezvous_ops; error = so->so_proto->pr_usrreqs->pru_sockaddr(so, &sa); if (error) goto cleanup_svc_vc_create; - xprt->xp_ltaddr.buf = mem_alloc(sizeof (struct sockaddr_storage)); - xprt->xp_ltaddr.maxlen = sizeof (struct sockaddr_storage); - xprt->xp_ltaddr.len = sa->sa_len; - memcpy(xprt->xp_ltaddr.buf, sa, sa->sa_len); + memcpy(&xprt->xp_ltaddr, sa, sa->sa_len); free(sa, M_SONAME); - xprt->xp_rtaddr.maxlen = 0; - xprt_register(xprt); solisten(so, SOMAXCONN, curthread); @@ -176,7 +168,7 @@ svc_vc_create(SVCPOOL *pool, struct socket *so, size_t sendsize, return (xprt); cleanup_svc_vc_create: if (xprt) - mem_free(xprt, sizeof(*xprt)); + svc_xprt_free(xprt); return (NULL); } @@ -218,29 +210,27 @@ svc_vc_create_conn(SVCPOOL *pool, struct socket *so, struct sockaddr *raddr) cd = mem_alloc(sizeof(*cd)); cd->strm_stat = XPRT_IDLE; - xprt = mem_alloc(sizeof(SVCXPRT)); - mtx_init(&xprt->xp_lock, "xprt->xp_lock", NULL, MTX_DEF); + xprt = svc_xprt_alloc(); + sx_init(&xprt->xp_lock, "xprt->xp_lock"); xprt->xp_pool = pool; xprt->xp_socket = so; xprt->xp_p1 = cd; xprt->xp_p2 = NULL; 
- xprt->xp_p3 = NULL; - xprt->xp_verf = _null_auth; xprt->xp_ops = &svc_vc_ops; - xprt->xp_rtaddr.buf = mem_alloc(sizeof (struct sockaddr_storage)); - xprt->xp_rtaddr.maxlen = sizeof (struct sockaddr_storage); - xprt->xp_rtaddr.len = raddr->sa_len; - memcpy(xprt->xp_rtaddr.buf, raddr, raddr->sa_len); + /* + * See http://www.connectathon.org/talks96/nfstcp.pdf - client + * has a 5 minute timer, server has a 6 minute timer. + */ + xprt->xp_idletimeout = 6 * 60; + + memcpy(&xprt->xp_rtaddr, raddr, raddr->sa_len); error = so->so_proto->pr_usrreqs->pru_sockaddr(so, &sa); if (error) goto cleanup_svc_vc_create; - xprt->xp_ltaddr.buf = mem_alloc(sizeof (struct sockaddr_storage)); - xprt->xp_ltaddr.maxlen = sizeof (struct sockaddr_storage); - xprt->xp_ltaddr.len = sa->sa_len; - memcpy(xprt->xp_ltaddr.buf, sa, sa->sa_len); + memcpy(&xprt->xp_ltaddr, sa, sa->sa_len); free(sa, M_SONAME); xprt_register(xprt); @@ -255,19 +245,13 @@ svc_vc_create_conn(SVCPOOL *pool, struct socket *so, struct sockaddr *raddr) * Throw the transport into the active list in case it already * has some data buffered. */ - mtx_lock(&xprt->xp_lock); + sx_xlock(&xprt->xp_lock); xprt_active(xprt); - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); return (xprt); cleanup_svc_vc_create: if (xprt) { - if (xprt->xp_ltaddr.buf) - mem_free(xprt->xp_ltaddr.buf, - sizeof(struct sockaddr_storage)); - if (xprt->xp_rtaddr.buf) - mem_free(xprt->xp_rtaddr.buf, - sizeof(struct sockaddr_storage)); mem_free(xprt, sizeof(*xprt)); } if (cd) @@ -335,7 +319,8 @@ done: /*ARGSUSED*/ static bool_t -svc_vc_rendezvous_recv(SVCXPRT *xprt, struct rpc_msg *msg) +svc_vc_rendezvous_recv(SVCXPRT *xprt, struct rpc_msg *msg, + struct sockaddr **addrp, struct mbuf **mp) { struct socket *so = NULL; struct sockaddr *sa = NULL; @@ -347,22 +332,27 @@ svc_vc_rendezvous_recv(SVCXPRT *xprt, struct rpc_msg *msg) * connection from the socket and turn it into a new * transport. 
If the accept fails, we have drained all pending * connections so we call xprt_inactive(). - * - * The lock protects us in the case where a new connection arrives - * on the socket after our call to accept fails with - * EWOULDBLOCK - the call to xprt_active() in the upcall will - * happen only after our call to xprt_inactive() which ensures - * that we will remain active. It might be possible to use - * SOCKBUF_LOCK for this - its not clear to me what locks are - * held during the upcall. */ - mtx_lock(&xprt->xp_lock); + sx_xlock(&xprt->xp_lock); error = svc_vc_accept(xprt->xp_socket, &so); if (error == EWOULDBLOCK) { - xprt_inactive(xprt); - mtx_unlock(&xprt->xp_lock); + /* + * We must re-test for new connections after taking + * the lock to protect us in the case where a new + * connection arrives after our call to accept fails + * with EWOULDBLOCK. The pool lock protects us from + * racing the upcall after our TAILQ_EMPTY() call + * returns false. + */ + ACCEPT_LOCK(); + mtx_lock(&xprt->xp_pool->sp_lock); + if (TAILQ_EMPTY(&xprt->xp_socket->so_comp)) + xprt_inactive_locked(xprt); + mtx_unlock(&xprt->xp_pool->sp_lock); + ACCEPT_UNLOCK(); + sx_xunlock(&xprt->xp_lock); return (FALSE); } @@ -373,11 +363,11 @@ svc_vc_rendezvous_recv(SVCXPRT *xprt, struct rpc_msg *msg) xprt->xp_socket->so_rcv.sb_flags &= ~SB_UPCALL; SOCKBUF_UNLOCK(&xprt->xp_socket->so_rcv); xprt_inactive(xprt); - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); sa = 0; error = soaccept(so, &sa); @@ -420,18 +410,13 @@ svc_vc_destroy_common(SVCXPRT *xprt) xprt->xp_socket->so_rcv.sb_flags &= ~SB_UPCALL; SOCKBUF_UNLOCK(&xprt->xp_socket->so_rcv); - xprt_unregister(xprt); - - mtx_destroy(&xprt->xp_lock); + sx_destroy(&xprt->xp_lock); if (xprt->xp_socket) (void)soclose(xprt->xp_socket); - if (xprt->xp_rtaddr.buf) - (void) mem_free(xprt->xp_rtaddr.buf, xprt->xp_rtaddr.maxlen); - if (xprt->xp_ltaddr.buf) - (void) 
mem_free(xprt->xp_ltaddr.buf, xprt->xp_ltaddr.maxlen); - (void) mem_free(xprt, sizeof (SVCXPRT)); - + if (xprt->xp_netid) + (void) mem_free(xprt->xp_netid, strlen(xprt->xp_netid) + 1); + svc_xprt_free(xprt); } static void @@ -483,32 +468,48 @@ svc_vc_stat(SVCXPRT *xprt) /* * Return XPRT_MOREREQS if we have buffered data and we are - * mid-record or if we have enough data for a record marker. + * mid-record or if we have enough data for a record + * marker. Since this is only a hint, we read mpending and + * resid outside the lock. We do need to take the lock if we + * have to traverse the mbuf chain. */ if (cd->mpending) { if (cd->resid) return (XPRT_MOREREQS); n = 0; + sx_xlock(&xprt->xp_lock); m = cd->mpending; while (m && n < sizeof(uint32_t)) { n += m->m_len; m = m->m_next; } + sx_xunlock(&xprt->xp_lock); if (n >= sizeof(uint32_t)) return (XPRT_MOREREQS); } + if (soreadable(xprt->xp_socket)) + return (XPRT_MOREREQS); + return (XPRT_IDLE); } static bool_t -svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) +svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg, + struct sockaddr **addrp, struct mbuf **mp) { struct cf_conn *cd = (struct cf_conn *) xprt->xp_p1; struct uio uio; struct mbuf *m; + XDR xdrs; int error, rcvflag; + /* + * Serialise access to the socket and our own record parsing + * state. 
+ */ + sx_xlock(&xprt->xp_lock); + for (;;) { /* * If we have an mbuf chain in cd->mpending, try to parse a @@ -539,7 +540,9 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) } if (n < sizeof(uint32_t)) goto readmore; - cd->mpending = m_pullup(cd->mpending, sizeof(uint32_t)); + if (cd->mpending->m_len < sizeof(uint32_t)) + cd->mpending = m_pullup(cd->mpending, + sizeof(uint32_t)); memcpy(&header, mtod(cd->mpending, uint32_t *), sizeof(header)); header = ntohl(header); @@ -557,8 +560,12 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) */ while (cd->mpending && cd->resid) { m = cd->mpending; - cd->mpending = m_split(cd->mpending, cd->resid, - M_WAIT); + if (cd->mpending->m_next + || cd->mpending->m_len > cd->resid) + cd->mpending = m_split(cd->mpending, + cd->resid, M_WAIT); + else + cd->mpending = NULL; if (cd->mreq) m_last(cd->mreq)->m_next = m; else @@ -582,13 +589,18 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) * Success - we have a complete record in * cd->mreq. */ - xdrmbuf_create(&xprt->xp_xdrreq, cd->mreq, XDR_DECODE); + xdrmbuf_create(&xdrs, cd->mreq, XDR_DECODE); cd->mreq = NULL; - if (! xdr_callmsg(&xprt->xp_xdrreq, msg)) { - XDR_DESTROY(&xprt->xp_xdrreq); + sx_xunlock(&xprt->xp_lock); + + if (! xdr_callmsg(&xdrs, msg)) { + XDR_DESTROY(&xdrs); return (FALSE); } - xprt->xp_xid = msg->rm_xid; + + *addrp = NULL; + *mp = xdrmbuf_getall(&xdrs); + XDR_DESTROY(&xdrs); return (TRUE); } @@ -602,17 +614,7 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) * the result in cd->mpending. If the read fails, * we have drained both cd->mpending and the socket so * we can call xprt_inactive(). - * - * The lock protects us in the case where a new packet arrives - * on the socket after our call to soreceive fails with - * EWOULDBLOCK - the call to xprt_active() in the upcall will - * happen only after our call to xprt_inactive() which ensures - * that we will remain active. 
It might be possible to use - * SOCKBUF_LOCK for this - its not clear to me what locks are - * held during the upcall. */ - mtx_lock(&xprt->xp_lock); - uio.uio_resid = 1000000000; uio.uio_td = curthread; m = NULL; @@ -621,8 +623,20 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) &rcvflag); if (error == EWOULDBLOCK) { - xprt_inactive(xprt); - mtx_unlock(&xprt->xp_lock); + /* + * We must re-test for readability after + * taking the lock to protect us in the case + * where a new packet arrives on the socket + * after our call to soreceive fails with + * EWOULDBLOCK. The pool lock protects us from + * racing the upcall after our soreadable() + * call returns false. + */ + mtx_lock(&xprt->xp_pool->sp_lock); + if (!soreadable(xprt->xp_socket)) + xprt_inactive_locked(xprt); + mtx_unlock(&xprt->xp_pool->sp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } @@ -634,7 +648,7 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) SOCKBUF_UNLOCK(&xprt->xp_socket->so_rcv); xprt_inactive(xprt); cd->strm_stat = XPRT_DIED; - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } @@ -642,8 +656,9 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) /* * EOF - the other end has closed the socket. */ + xprt_inactive(xprt); cd->strm_stat = XPRT_DIED; - mtx_unlock(&xprt->xp_lock); + sx_xunlock(&xprt->xp_lock); return (FALSE); } @@ -651,53 +666,38 @@ svc_vc_recv(SVCXPRT *xprt, struct rpc_msg *msg) m_last(cd->mpending)->m_next = m; else cd->mpending = m; - - mtx_unlock(&xprt->xp_lock); } } static bool_t -svc_vc_getargs(SVCXPRT *xprt, xdrproc_t xdr_args, void *args_ptr) -{ - - return (xdr_args(&xprt->xp_xdrreq, args_ptr)); -} - -static bool_t -svc_vc_freeargs(SVCXPRT *xprt, xdrproc_t xdr_args, void *args_ptr) +svc_vc_reply(SVCXPRT *xprt, struct rpc_msg *msg, + struct sockaddr *addr, struct mbuf *m) { XDR xdrs; - - /* - * Free the request mbuf here - this allows us to handle - * protocols where not all requests have replies - * (i.e. NLM). 
Note that xdrmbuf_destroy handles being called - * twice correctly - the mbuf will only be freed once. - */ - XDR_DESTROY(&xprt->xp_xdrreq); - - xdrs.x_op = XDR_FREE; - return (xdr_args(&xdrs, args_ptr)); -} - -static bool_t -svc_vc_reply(SVCXPRT *xprt, struct rpc_msg *msg) -{ struct mbuf *mrep; - bool_t stat = FALSE; + bool_t stat = TRUE; int error; /* * Leave space for record mark. */ MGETHDR(mrep, M_WAIT, MT_DATA); - MCLGET(mrep, M_WAIT); mrep->m_len = 0; mrep->m_data += sizeof(uint32_t); - xdrmbuf_create(&xprt->xp_xdrrep, mrep, XDR_ENCODE); - msg->rm_xid = xprt->xp_xid; - if (xdr_replymsg(&xprt->xp_xdrrep, msg)) { + xdrmbuf_create(&xdrs, mrep, XDR_ENCODE); + + if (msg->rm_reply.rp_stat == MSG_ACCEPTED && + msg->rm_reply.rp_acpt.ar_stat == SUCCESS) { + if (!xdr_replymsg(&xdrs, msg)) + stat = FALSE; + else + xdrmbuf_append(&xdrs, m); + } else { + stat = xdr_replymsg(&xdrs, msg); + } + + if (stat) { m_fixhdr(mrep); /* @@ -716,12 +716,7 @@ svc_vc_reply(SVCXPRT *xprt, struct rpc_msg *msg) m_freem(mrep); } - /* - * This frees the request mbuf chain as well. The reply mbuf - * chain was consumed by sosend. 
- */ - XDR_DESTROY(&xprt->xp_xdrreq); - XDR_DESTROY(&xprt->xp_xdrrep); + XDR_DESTROY(&xdrs); xprt->xp_p2 = NULL; return (stat); @@ -739,9 +734,7 @@ svc_vc_soupcall(struct socket *so, void *arg, int waitflag) { SVCXPRT *xprt = (SVCXPRT *) arg; - mtx_lock(&xprt->xp_lock); xprt_active(xprt); - mtx_unlock(&xprt->xp_lock); } #if 0 @@ -757,7 +750,7 @@ __rpc_get_local_uid(SVCXPRT *transp, uid_t *uid) { struct sockaddr *sa; sock = transp->xp_fd; - sa = (struct sockaddr *)transp->xp_rtaddr.buf; + sa = (struct sockaddr *)transp->xp_rtaddr; if (sa->sa_family == AF_LOCAL) { ret = getpeereid(sock, &euid, &egid); if (ret == 0) diff --git a/sys/rpc/xdr.h b/sys/rpc/xdr.h index bebd448..947bf4f 100644 --- a/sys/rpc/xdr.h +++ b/sys/rpc/xdr.h @@ -348,6 +348,8 @@ extern void xdrmem_create(XDR *, char *, u_int, enum xdr_op); /* XDR using mbufs */ struct mbuf; extern void xdrmbuf_create(XDR *, struct mbuf *, enum xdr_op); +extern void xdrmbuf_append(XDR *, struct mbuf *); +extern struct mbuf * xdrmbuf_getall(XDR *); /* XDR pseudo records for tcp */ extern void xdrrec_create(XDR *, u_int, u_int, void *, |