From cd25d4648fdd5f53f76f460b7f57015bdc89bb56 Mon Sep 17 00:00:00 2001 From: hsu Date: Mon, 10 Jun 2002 20:05:46 +0000 Subject: Lock up inpcb. Submitted by: Jennifer Yang --- sys/netinet/tcp_usrreq.c | 198 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 161 insertions(+), 37 deletions(-) (limited to 'sys/netinet/tcp_usrreq.c') diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c index e1f4c1a..f5e75d1 100644 --- a/sys/netinet/tcp_usrreq.c +++ b/sys/netinet/tcp_usrreq.c @@ -40,6 +40,7 @@ #include #include +#include #include #include #include @@ -120,11 +121,13 @@ tcp_usr_attach(struct socket *so, int proto, struct thread *td) { int s = splnet(); int error; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp = 0; TCPDEBUG0; + INP_INFO_WLOCK(&tcbinfo); TCPDEBUG1(); + inp = sotoinpcb(so); if (inp) { error = EISCONN; goto out; @@ -136,9 +139,15 @@ tcp_usr_attach(struct socket *so, int proto, struct thread *td) if ((so->so_options & SO_LINGER) && so->so_linger == 0) so->so_linger = TCP_LINGERTIME; - tp = sototcpcb(so); + + inp = sotoinpcb(so); + INP_LOCK(inp); + tp = intotcpcb(inp); out: TCPDEBUG2(PRU_ATTACH); + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return error; } @@ -155,35 +164,68 @@ tcp_usr_detach(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; TCPDEBUG0; + INP_INFO_WLOCK(&tcbinfo); + inp = sotoinpcb(so); if (inp == 0) { + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return EINVAL; /* XXX */ } + INP_LOCK(inp); tp = intotcpcb(inp); TCPDEBUG1(); tp = tcp_disconnect(tp); TCPDEBUG2(PRU_DETACH); + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return error; } -#define COMMON_START() TCPDEBUG0; \ - do { \ - if (inp == 0) { \ - splx(s); \ - return EINVAL; \ - } \ - tp = intotcpcb(inp); \ - TCPDEBUG1(); \ - } while(0) - -#define COMMON_END(req) out: TCPDEBUG2(req); splx(s); return error; goto out - +#define INI_NOLOCK 0 +#define INI_READ 1 +#define INI_WRITE 2 + +#define COMMON_START() \ + TCPDEBUG0; \ + do { \ + if (inirw == INI_READ) \ + INP_INFO_RLOCK(&tcbinfo); \ + else if (inirw == INI_WRITE) \ + INP_INFO_WLOCK(&tcbinfo); \ + inp = sotoinpcb(so); \ + if (inp == 0) { \ + if (inirw == INI_READ) \ + INP_INFO_RUNLOCK(&tcbinfo); \ + else if (inirw == INI_WRITE) \ + INP_INFO_WUNLOCK(&tcbinfo); \ + splx(s); \ + return EINVAL; \ + } \ + INP_LOCK(inp); \ + if (inirw == INI_READ) \ + INP_INFO_RUNLOCK(&tcbinfo); \ + tp = intotcpcb(inp); \ + TCPDEBUG1(); \ +} while(0) + +#define COMMON_END(req) \ +out: TCPDEBUG2(req); \ + do { \ + if (tp) \ + INP_UNLOCK(inp); \ + if (inirw == INI_WRITE) \ + INP_INFO_WUNLOCK(&tcbinfo); \ + splx(s); \ + return error; \ + goto out; \ +} while(0) /* * Give the socket an address. @@ -193,9 +235,10 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in *sinp; + const int inirw = INI_READ; COMMON_START(); @@ -213,7 +256,6 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) if (error) goto out; COMMON_END(PRU_BIND); - } #ifdef INET6 @@ -222,9 +264,10 @@ tcp6_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in6 *sin6p; + const int inirw = INI_READ; COMMON_START(); @@ -268,8 +311,9 @@ tcp_usr_listen(struct socket *so, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if (inp->inp_lport == 0) @@ -285,8 +329,9 @@ tcp6_usr_listen(struct socket *so, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if (inp->inp_lport == 0) { @@ -314,9 +359,10 @@ tcp_usr_connect(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in *sinp; + const int inirw = INI_WRITE; COMMON_START(); @@ -345,9 +391,10 @@ tcp6_usr_connect(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in6 *sin6p; + const int inirw = INI_WRITE; COMMON_START(); @@ -402,8 +449,9 @@ tcp_usr_disconnect(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; COMMON_START(); tp = tcp_disconnect(tp); @@ -418,23 +466,49 @@ tcp_usr_disconnect(struct socket *so) static int tcp_usr_accept(struct socket *so, struct sockaddr **nam) { - int s = splnet(); + int s; int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp = NULL; struct tcpcb *tp = NULL; + struct sockaddr_in *sin; + const int inirw = INI_READ; TCPDEBUG0; if (so->so_state & SS_ISDISCONNECTED) { error = ECONNABORTED; goto out; } - if (inp == 0) { + + /* + * Do the malloc first in case it blocks. + */ + MALLOC(sin, struct sockaddr_in *, sizeof *sin, M_SONAME, + M_WAITOK | M_ZERO); + sin->sin_family = AF_INET; + sin->sin_len = sizeof(*sin); + + s = splnet(); + INP_INFO_RLOCK(&tcbinfo); + inp = sotoinpcb(so); + if (!inp) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); + free(sin, M_SONAME); return (EINVAL); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); tp = intotcpcb(inp); TCPDEBUG1(); - in_setpeeraddr(so, nam); + + /* + * We inline in_setpeeraddr here, because we have already done + * the locking and the malloc. + */ + sin->sin_port = inp->inp_fport; + sin->sin_addr = inp->inp_faddr; + *nam = (struct sockaddr *)sin; + COMMON_END(PRU_ACCEPT); } @@ -442,26 +516,56 @@ tcp_usr_accept(struct socket *so, struct sockaddr **nam) static int tcp6_usr_accept(struct socket *so, struct sockaddr **nam) { - int s = splnet(); + int s; + struct inpcb *inp = NULL; int error = 0; - struct inpcb *inp = sotoinpcb(so); struct tcpcb *tp = NULL; + const int inirw = INI_READ; TCPDEBUG0; if (so->so_state & SS_ISDISCONNECTED) { error = ECONNABORTED; goto out; } + + s = splnet(); + INP_INFO_RLOCK(&tcbinfo); + inp = sotoinpcb(so); if (inp == 0) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (EINVAL); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); tp = intotcpcb(inp); TCPDEBUG1(); in6_mapped_peeraddr(so, nam); COMMON_END(PRU_ACCEPT); } #endif /* INET6 */ + +/* + * This is the wrapper function for in_setsockaddr. We just pass down + * the pcbinfo for in_setsockaddr to lock. We don't want to do the locking + * here because in_setsockaddr will call malloc and can block. + */ +static int +tcp_sockaddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setsockaddr(so, nam, &tcbinfo)); +} + +/* + * This is the wrapper function for in_setpeeraddr. We just pass down + * the pcbinfo for in_setpeeraddr to lock. + */ +static int +tcp_peeraddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setpeeraddr(so, nam, &tcbinfo)); +} + /* * Mark the connection as being incapable of further output. */ @@ -470,8 +574,9 @@ tcp_usr_shutdown(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; COMMON_START(); socantsendmore(so); @@ -489,8 +594,9 @@ tcp_usr_rcvd(struct socket *so, int flags) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); tcp_output(tp); @@ -510,13 +616,22 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; #ifdef INET6 int isipv6; #endif TCPDEBUG0; + /* + * Need write lock here because this function might call + * tcp_connect or tcp_usrclosed. + * We really want to have to this function upgrade from read lock + * to write lock. XXX + */ + INP_INFO_WLOCK(&tcbinfo); + inp = sotoinpcb(so); if (inp == NULL) { /* * OOPS! we lost a race, the TCP session got reset after @@ -532,6 +647,7 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, TCPDEBUG1(); goto out; } + INP_LOCK(inp); #ifdef INET6 isipv6 = nam && nam->sa_family == AF_INET6; #endif /* INET6 */ @@ -548,7 +664,7 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, } m_freem(control); /* empty control, just free it */ } - if(!(flags & PRUS_OOB)) { + if (!(flags & PRUS_OOB)) { sbappend(&so->so_snd, m); if (nam && tp->t_state < TCPS_SYN_SENT) { /* @@ -634,8 +750,9 @@ tcp_usr_abort(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; COMMON_START(); tp = tcp_drop(tp, ECONNABORTED); @@ -650,8 +767,9 @@ tcp_usr_rcvoob(struct socket *so, struct mbuf *m, int flags) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if ((so->so_oobmark == 0 && @@ -676,9 +794,9 @@ tcp_usr_rcvoob(struct socket *so, struct mbuf *m, int flags) struct pr_usrreqs tcp_usrreqs = { tcp_usr_abort, tcp_usr_accept, tcp_usr_attach, tcp_usr_bind, tcp_usr_connect, pru_connect2_notsupp, in_control, tcp_usr_detach, - tcp_usr_disconnect, tcp_usr_listen, in_setpeeraddr, tcp_usr_rcvd, + tcp_usr_disconnect, tcp_usr_listen, tcp_peeraddr, tcp_usr_rcvd, tcp_usr_rcvoob, tcp_usr_send, pru_sense_null, tcp_usr_shutdown, - in_setsockaddr, sosend, soreceive, sopoll + tcp_sockaddr, sosend, soreceive, sopoll }; #ifdef INET6 @@ -888,11 +1006,15 @@ tcp_ctloutput(so, sopt) error = 0; s = splnet(); /* XXX */ + INP_INFO_RLOCK(&tcbinfo); inp = sotoinpcb(so); if (inp == NULL) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (ECONNRESET); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); if (sopt->sopt_level != IPPROTO_TCP) { #ifdef INET6 if (INP_CHECK_SOCKAF(so, AF_INET6)) @@ -900,6 +1022,7 @@ tcp_ctloutput(so, sopt) else #endif /* INET6 */ error = ip_ctloutput(so, sopt); + INP_UNLOCK(inp); splx(s); return (error); } @@ -987,6 +1110,7 @@ tcp_ctloutput(so, sopt) error = sooptcopyout(sopt, &optval, sizeof optval); break; } + INP_UNLOCK(inp); splx(s); return (error); } -- cgit v1.1