From cd25d4648fdd5f53f76f460b7f57015bdc89bb56 Mon Sep 17 00:00:00 2001 From: hsu Date: Mon, 10 Jun 2002 20:05:46 +0000 Subject: Lock up inpcb. Submitted by: Jennifer Yang --- sys/netinet/in.c | 4 +- sys/netinet/in_pcb.c | 60 ++++++++++---- sys/netinet/in_pcb.h | 24 +++++- sys/netinet/ip_divert.c | 25 +++++- sys/netinet/raw_ip.c | 25 +++++- sys/netinet/tcp_input.c | 49 ++++++++++- sys/netinet/tcp_reass.c | 49 ++++++++++- sys/netinet/tcp_subr.c | 64 ++++++++++++--- sys/netinet/tcp_syncache.c | 5 ++ sys/netinet/tcp_timer.c | 51 ++++++++++++ sys/netinet/tcp_timewait.c | 64 ++++++++++++--- sys/netinet/tcp_usrreq.c | 198 ++++++++++++++++++++++++++++++++++++--------- sys/netinet/tcp_var.h | 1 + sys/netinet/udp_usrreq.c | 163 +++++++++++++++++++++++++++++++------ sys/netinet6/in6_pcb.c | 5 +- 15 files changed, 668 insertions(+), 119 deletions(-) (limited to 'sys') diff --git a/sys/netinet/in.c b/sys/netinet/in.c index d27eec9..5f4179d 100644 --- a/sys/netinet/in.c +++ b/sys/netinet/in.c @@ -426,8 +426,8 @@ in_control(so, cmd, data, ifp, td) * from if_detach() */ if (ifaddr_byindex(ifp->if_index) != NULL) { - in_pcbpurgeif0(LIST_FIRST(ripcbinfo.listhead), ifp); - in_pcbpurgeif0(LIST_FIRST(udbinfo.listhead), ifp); + in_pcbpurgeif0(&ripcbinfo, ifp); + in_pcbpurgeif0(&udbinfo, ifp); } error = 0; break; diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c index b4de80a..f45be67 100644 --- a/sys/netinet/in_pcb.c +++ b/sys/netinet/in_pcb.c @@ -144,7 +144,7 @@ in_pcballoc(so, pcbinfo, td) int error; #endif - inp = uma_zalloc(pcbinfo->ipi_zone, M_WAITOK); + inp = uma_zalloc(pcbinfo->ipi_zone, M_NOWAIT); if (inp == NULL) return (ENOBUFS); bzero((caddr_t)inp, sizeof(*inp)); @@ -165,6 +165,7 @@ in_pcballoc(so, pcbinfo, td) LIST_INSERT_HEAD(pcbinfo->listhead, inp, inp_list); pcbinfo->ipi_count++; so->so_pcb = (caddr_t)inp; + INP_LOCK_INIT(inp, "inp"); #ifdef INET6 if (ip6_auto_flowlabel) inp->inp_flags |= IN6P_AUTOFLOWLABEL; @@ -572,23 +573,23 @@ in_pcbdetach(inp) 
rtfree(inp->inp_route.ro_rt); ip_freemoptions(inp->inp_moptions); inp->inp_vflag = 0; + INP_LOCK_DESTROY(inp); uma_zfree(ipi->ipi_zone, inp); } /* - * The calling convention of in_setsockaddr() and in_setpeeraddr() was - * modified to match the pru_sockaddr() and pru_peeraddr() entry points - * in struct pr_usrreqs, so that protocols can just reference then directly - * without the need for a wrapper function. The socket must have a valid + * The wrapper function will pass down the pcbinfo for this function to lock. + * The socket must have a valid * (i.e., non-nil) PCB, but it should be impossible to get an invalid one * except through a kernel programming error, so it is acceptable to panic * (or in this case trap) if the PCB is invalid. (Actually, we don't trap * because there actually /is/ a programming error somewhere... XXX) */ int -in_setsockaddr(so, nam) +in_setsockaddr(so, nam, pcbinfo) struct socket *so; struct sockaddr **nam; + struct inpcbinfo *pcbinfo; { int s; register struct inpcb *inp; @@ -603,27 +604,36 @@ in_setsockaddr(so, nam) sin->sin_len = sizeof(*sin); s = splnet(); + INP_INFO_RLOCK(pcbinfo); inp = sotoinpcb(so); if (!inp) { + INP_INFO_RUNLOCK(pcbinfo); splx(s); free(sin, M_SONAME); return ECONNRESET; } + INP_LOCK(inp); sin->sin_port = inp->inp_lport; sin->sin_addr = inp->inp_laddr; + INP_UNLOCK(inp); + INP_INFO_RUNLOCK(pcbinfo); splx(s); *nam = (struct sockaddr *)sin; return 0; } +/* + * The wrapper function will pass down the pcbinfo for this function to lock. 
+ */ int -in_setpeeraddr(so, nam) +in_setpeeraddr(so, nam, pcbinfo) struct socket *so; struct sockaddr **nam; + struct inpcbinfo *pcbinfo; { int s; - struct inpcb *inp; + register struct inpcb *inp; register struct sockaddr_in *sin; /* @@ -635,14 +645,19 @@ in_setpeeraddr(so, nam) sin->sin_len = sizeof(*sin); s = splnet(); + INP_INFO_RLOCK(pcbinfo); inp = sotoinpcb(so); if (!inp) { + INP_INFO_RUNLOCK(pcbinfo); splx(s); free(sin, M_SONAME); return ECONNRESET; } + INP_LOCK(inp); sin->sin_port = inp->inp_fport; sin->sin_addr = inp->inp_faddr; + INP_UNLOCK(inp); + INP_INFO_RUNLOCK(pcbinfo); splx(s); *nam = (struct sockaddr *)sin; @@ -650,40 +665,55 @@ in_setpeeraddr(so, nam) } void -in_pcbnotifyall(head, faddr, errno, notify) - struct inpcbhead *head; +in_pcbnotifyall(pcbinfo, faddr, errno, notify) + struct inpcbinfo *pcbinfo; struct in_addr faddr; int errno; void (*notify)(struct inpcb *, int); { struct inpcb *inp, *ninp; + struct inpcbhead *head; int s; s = splnet(); + INP_INFO_RLOCK(pcbinfo); + head = pcbinfo->listhead; for (inp = LIST_FIRST(head); inp != NULL; inp = ninp) { + INP_LOCK(inp); ninp = LIST_NEXT(inp, inp_list); #ifdef INET6 - if ((inp->inp_vflag & INP_IPV4) == 0) + if ((inp->inp_vflag & INP_IPV4) == 0) { + INP_UNLOCK(inp); continue; + } #endif if (inp->inp_faddr.s_addr != faddr.s_addr || - inp->inp_socket == NULL) + inp->inp_socket == NULL) { + INP_UNLOCK(inp); continue; + } (*notify)(inp, errno); + INP_UNLOCK(inp); } + INP_INFO_RUNLOCK(pcbinfo); splx(s); } void -in_pcbpurgeif0(head, ifp) - struct inpcb *head; +in_pcbpurgeif0(pcbinfo, ifp) + struct inpcbinfo *pcbinfo; struct ifnet *ifp; { + struct inpcb *head; struct inpcb *inp; struct ip_moptions *imo; int i, gap; + /* why no splnet here? 
XXX */ + INP_INFO_RLOCK(pcbinfo); + head = LIST_FIRST(pcbinfo->listhead); for (inp = head; inp != NULL; inp = LIST_NEXT(inp, inp_list)) { + INP_LOCK(inp); imo = inp->inp_moptions; if ((inp->inp_vflag & INP_IPV4) && imo != NULL) { @@ -709,7 +739,9 @@ in_pcbpurgeif0(head, ifp) } imo->imo_num_memberships -= gap; } + INP_UNLOCK(inp); } + INP_INFO_RUNLOCK(pcbinfo); } /* diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h index b1010d9..fa3f29e 100644 --- a/sys/netinet/in_pcb.h +++ b/sys/netinet/in_pcb.h @@ -174,6 +174,8 @@ struct inpcb { LIST_ENTRY(inpcb) inp_portlist; struct inpcbport *inp_phd; /* head of this list */ inp_gen_t inp_gencnt; /* generation count of this instance */ + struct mtx inp_mtx; + #define in6p_faddr inp_inc.inc6_faddr #define in6p_laddr inp_inc.inc6_laddr #define in6p_route inp_inc.inc6_route @@ -239,8 +241,22 @@ struct inpcbinfo { /* XXX documentation, prefixes */ uma_zone_t ipi_zone; /* zone to allocate pcbs from */ u_int ipi_count; /* number of pcbs in this list */ u_quad_t ipi_gencnt; /* current generation count */ + struct mtx ipi_mtx; }; +#define INP_LOCK_INIT(inp, d) \ + mtx_init(&(inp)->inp_mtx, (d), NULL, MTX_DEF | MTX_RECURSE) +#define INP_LOCK_DESTROY(inp) mtx_destroy(&(inp)->inp_mtx) +#define INP_LOCK(inp) mtx_lock(&(inp)->inp_mtx) +#define INP_UNLOCK(inp) mtx_unlock(&(inp)->inp_mtx) + +#define INP_INFO_LOCK_INIT(ipi, d) \ + mtx_init(&(ipi)->ipi_mtx, (d), NULL, MTX_DEF | MTX_RECURSE) +#define INP_INFO_RLOCK(ipi) mtx_lock(&(ipi)->ipi_mtx) +#define INP_INFO_WLOCK(ipi) mtx_lock(&(ipi)->ipi_mtx) +#define INP_INFO_RUNLOCK(ipi) mtx_unlock(&(ipi)->ipi_mtx) +#define INP_INFO_WUNLOCK(ipi) mtx_unlock(&(ipi)->ipi_mtx) + #define INP_PCBHASH(faddr, lport, fport, mask) \ (((faddr) ^ ((faddr) >> 16) ^ ntohs((lport) ^ (fport))) & (mask)) #define INP_PCBPORTHASH(lport, mask) \ @@ -306,7 +322,7 @@ extern int ipport_lastauto; extern int ipport_hifirstauto; extern int ipport_hilastauto; -void 
in_pcbpurgeif0(struct inpcbinfo *, struct ifnet *); void in_losing(struct inpcb *); void in_rtchange(struct inpcb *, int); int in_pcballoc(struct socket *, struct inpcbinfo *, struct thread *); @@ -323,11 +339,11 @@ struct inpcb * struct inpcb * in_pcblookup_hash(struct inpcbinfo *, struct in_addr, u_int, struct in_addr, u_int, int, struct ifnet *); -void in_pcbnotifyall(struct inpcbhead *, struct in_addr, +void in_pcbnotifyall(struct inpcbinfo *pcbinfo, struct in_addr, int, void (*)(struct inpcb *, int)); void in_pcbrehash(struct inpcb *); -int in_setpeeraddr(struct socket *so, struct sockaddr **nam); -int in_setsockaddr(struct socket *so, struct sockaddr **nam); +int in_setpeeraddr(struct socket *so, struct sockaddr **nam, struct inpcbinfo *pcbinfo); +int in_setsockaddr(struct socket *so, struct sockaddr **nam, struct inpcbinfo *pcbinfo); void in_pcbremlists(struct inpcb *inp); int prison_xinpcb(struct thread *td, struct inpcb *inp); #endif /* _KERNEL */ diff --git a/sys/netinet/ip_divert.c b/sys/netinet/ip_divert.c index 8f56d2c..7f8e60b 100644 --- a/sys/netinet/ip_divert.c +++ b/sys/netinet/ip_divert.c @@ -528,6 +528,27 @@ div_pcblist(SYSCTL_HANDLER_ARGS) return error; } +/* + * This is the wrapper function for in_setsockaddr. We just pass down + * the pcbinfo for in_setsockaddr to lock. + */ +static int +div_sockaddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setsockaddr(so, nam, &divcbinfo)); +} + +/* + * This is the wrapper function for in_setpeeraddr. We just pass down + * the pcbinfo for in_setpeeraddr to lock. 
+ */ +static int +div_peeraddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setpeeraddr(so, nam, &divcbinfo)); +} + + SYSCTL_DECL(_net_inet_divert); SYSCTL_PROC(_net_inet_divert, OID_AUTO, pcblist, CTLFLAG_RD, 0, 0, div_pcblist, "S,xinpcb", "List of active divert sockets"); @@ -535,7 +556,7 @@ SYSCTL_PROC(_net_inet_divert, OID_AUTO, pcblist, CTLFLAG_RD, 0, 0, struct pr_usrreqs div_usrreqs = { div_abort, pru_accept_notsupp, div_attach, div_bind, pru_connect_notsupp, pru_connect2_notsupp, in_control, div_detach, - div_disconnect, pru_listen_notsupp, in_setpeeraddr, pru_rcvd_notsupp, + div_disconnect, pru_listen_notsupp, div_peeraddr, pru_rcvd_notsupp, pru_rcvoob_notsupp, div_send, pru_sense_null, div_shutdown, - in_setsockaddr, sosend, soreceive, sopoll + div_sockaddr, sosend, soreceive, sopoll }; diff --git a/sys/netinet/raw_ip.c b/sys/netinet/raw_ip.c index 86915fc..13a84fd 100644 --- a/sys/netinet/raw_ip.c +++ b/sys/netinet/raw_ip.c @@ -673,13 +673,34 @@ rip_pcblist(SYSCTL_HANDLER_ARGS) return error; } +/* + * This is the wrapper function for in_setsockaddr. We just pass down + * the pcbinfo for in_setsockaddr to lock. + */ +static int +rip_sockaddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setsockaddr(so, nam, &ripcbinfo)); +} + +/* + * This is the wrapper function for in_setpeeraddr. We just pass down + * the pcbinfo for in_setpeeraddr to lock. 
+ */ +static int +rip_peeraddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setpeeraddr(so, nam, &ripcbinfo)); +} + + SYSCTL_PROC(_net_inet_raw, OID_AUTO/*XXX*/, pcblist, CTLFLAG_RD, 0, 0, rip_pcblist, "S,xinpcb", "List of active raw IP sockets"); struct pr_usrreqs rip_usrreqs = { rip_abort, pru_accept_notsupp, rip_attach, rip_bind, rip_connect, pru_connect2_notsupp, in_control, rip_detach, rip_disconnect, - pru_listen_notsupp, in_setpeeraddr, pru_rcvd_notsupp, + pru_listen_notsupp, rip_peeraddr, pru_rcvd_notsupp, pru_rcvoob_notsupp, rip_send, pru_sense_null, rip_shutdown, - in_setsockaddr, sosend, soreceive, sopoll + rip_sockaddr, sosend, soreceive, sopoll }; diff --git a/sys/netinet/tcp_input.c b/sys/netinet/tcp_input.c index b193327..68f9bc6 100644 --- a/sys/netinet/tcp_input.c +++ b/sys/netinet/tcp_input.c @@ -130,6 +130,7 @@ SYSCTL_INT(_net_inet_tcp, OID_AUTO, drop_synfin, CTLFLAG_RW, struct inpcbhead tcb; #define tcb6 tcb /* for KAME src sync over BSD*'s */ struct inpcbinfo tcbinfo; +struct mtx *tcbinfo_mtx; static void tcp_dooptions(struct tcpopt *, u_char *, int, int); static void tcp_pulloutofband(struct socket *, @@ -335,7 +336,7 @@ tcp_input(m, off0) register struct tcphdr *th; register struct ip *ip = NULL; register struct ipovly *ipov; - register struct inpcb *inp; + register struct inpcb *inp = NULL; u_char *optp = NULL; int optlen = 0; int len, tlen, off; @@ -348,6 +349,8 @@ tcp_input(m, off0) struct tcpopt to; /* options in this segment */ struct rmxp_tao *taop; /* pointer to our TAO cache entry */ struct rmxp_tao tao_noncached; /* in case there's no cached entry */ + int headlocked = 0; + #ifdef TCPDEBUG short ostate = 0; #endif @@ -506,6 +509,8 @@ tcp_input(m, off0) /* * Locate pcb for segment. 
*/ + INP_INFO_WLOCK(&tcbinfo); + headlocked = 1; findpcb: #ifdef IPFIREWALL_FORWARD if (ip_fw_fwd_addr != NULL @@ -623,8 +628,10 @@ findpcb: rstreason = BANDLIM_RST_CLOSEDPORT; goto dropwithreset; } + INP_LOCK(inp); tp = intotcpcb(inp); if (tp == 0) { + INP_UNLOCK(inp); rstreason = BANDLIM_RST_CLOSEDPORT; goto dropwithreset; } @@ -695,18 +702,23 @@ findpcb: rstreason = BANDLIM_RST_OPENPORT; goto dropwithreset; } - if (so == NULL) + if (so == NULL) { /* * Could not complete 3-way handshake, * connection is being closed down, and * syncache will free mbuf. */ + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); return; + } /* * Socket is created in state SYN_RECEIVED. * Continue processing segment. */ + INP_UNLOCK(inp); inp = sotoinpcb(so); + INP_LOCK(inp); tp = intotcpcb(inp); /* * This is what would have happened in @@ -777,6 +789,7 @@ findpcb: if ((ia6 = ip6_getdstifaddr(m)) && (ia6->ia6_flags & IN6_IFF_DEPRECATED)) { + INP_UNLOCK(inp); tp = NULL; rstreason = BANDLIM_RST_OPENPORT; goto dropwithreset; @@ -827,16 +840,22 @@ findpcb: tcp_dooptions(&to, optp, optlen, 1); if (!syncache_add(&inc, &to, th, &so, m)) goto drop; - if (so == NULL) + if (so == NULL) { /* * Entry added to syncache, mbuf used to * send SYN,ACK packet. */ + KASSERT(headlocked, ("headlocked")); + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); return; + } /* * Segment passed TAO tests. */ + INP_UNLOCK(inp); inp = sotoinpcb(so); + INP_LOCK(inp); tp = intotcpcb(inp); tp->snd_wnd = tiwin; tp->t_starttime = ticks; @@ -959,6 +978,9 @@ after_listen: SEQ_LEQ(th->th_ack, tp->snd_max) && tp->snd_cwnd >= tp->snd_wnd && tp->t_dupacks < tcprexmtthresh) { + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * this is a pure ack for outstanding data. 
*/ @@ -1007,11 +1029,15 @@ after_listen: sowwakeup(so); if (so->so_snd.sb_cc) (void) tcp_output(tp); + INP_UNLOCK(inp); return; } } else if (th->th_ack == tp->snd_una && LIST_EMPTY(&tp->t_segq) && tlen <= sbspace(&so->so_rcv)) { + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * this is a pure, in-sequence data packet * with nothing on the reassembly queue and @@ -1035,6 +1061,7 @@ after_listen: tp->t_flags |= TF_ACKNOW; tcp_output(tp); } + INP_UNLOCK(inp); return; } } @@ -1983,7 +2010,9 @@ step6: if (SEQ_GT(tp->rcv_nxt, tp->rcv_up)) tp->rcv_up = tp->rcv_nxt; dodata: /* XXX */ - + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * Process the segment text, merging it into the TCP sequencing queue, * and arranging for acknowledgment of receipt if necessary. @@ -2121,6 +2150,7 @@ dodata: /* XXX */ */ if (needoutput || (tp->t_flags & TF_ACKNOW)) (void) tcp_output(tp); + INP_UNLOCK(inp); return; dropafterack: @@ -2150,9 +2180,12 @@ dropafterack: tcp_trace(TA_DROP, ostate, tp, (void *)tcp_saveipgen, &tcp_savetcp, 0); #endif + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); m_freem(m); tp->t_flags |= TF_ACKNOW; (void) tcp_output(tp); + INP_UNLOCK(inp); return; dropwithreset: @@ -2177,6 +2210,8 @@ dropwithreset: goto drop; /* IPv6 anycast check is done at tcp6_input() */ + if (tp) + INP_UNLOCK(inp); /* * Perform bandwidth limiting. 
*/ @@ -2199,6 +2234,8 @@ dropwithreset: tcp_respond(tp, mtod(m, void *), th, m, th->th_seq+tlen, (tcp_seq)0, TH_RST|TH_ACK); } + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); return; drop: @@ -2210,7 +2247,11 @@ drop: tcp_trace(TA_DROP, ostate, tp, (void *)tcp_saveipgen, &tcp_savetcp, 0); #endif + if (tp) + INP_UNLOCK(inp); m_freem(m); + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); return; } diff --git a/sys/netinet/tcp_reass.c b/sys/netinet/tcp_reass.c index b193327..68f9bc6 100644 --- a/sys/netinet/tcp_reass.c +++ b/sys/netinet/tcp_reass.c @@ -130,6 +130,7 @@ SYSCTL_INT(_net_inet_tcp, OID_AUTO, drop_synfin, CTLFLAG_RW, struct inpcbhead tcb; #define tcb6 tcb /* for KAME src sync over BSD*'s */ struct inpcbinfo tcbinfo; +struct mtx *tcbinfo_mtx; static void tcp_dooptions(struct tcpopt *, u_char *, int, int); static void tcp_pulloutofband(struct socket *, @@ -335,7 +336,7 @@ tcp_input(m, off0) register struct tcphdr *th; register struct ip *ip = NULL; register struct ipovly *ipov; - register struct inpcb *inp; + register struct inpcb *inp = NULL; u_char *optp = NULL; int optlen = 0; int len, tlen, off; @@ -348,6 +349,8 @@ tcp_input(m, off0) struct tcpopt to; /* options in this segment */ struct rmxp_tao *taop; /* pointer to our TAO cache entry */ struct rmxp_tao tao_noncached; /* in case there's no cached entry */ + int headlocked = 0; + #ifdef TCPDEBUG short ostate = 0; #endif @@ -506,6 +509,8 @@ tcp_input(m, off0) /* * Locate pcb for segment. 
*/ + INP_INFO_WLOCK(&tcbinfo); + headlocked = 1; findpcb: #ifdef IPFIREWALL_FORWARD if (ip_fw_fwd_addr != NULL @@ -623,8 +628,10 @@ findpcb: rstreason = BANDLIM_RST_CLOSEDPORT; goto dropwithreset; } + INP_LOCK(inp); tp = intotcpcb(inp); if (tp == 0) { + INP_UNLOCK(inp); rstreason = BANDLIM_RST_CLOSEDPORT; goto dropwithreset; } @@ -695,18 +702,23 @@ findpcb: rstreason = BANDLIM_RST_OPENPORT; goto dropwithreset; } - if (so == NULL) + if (so == NULL) { /* * Could not complete 3-way handshake, * connection is being closed down, and * syncache will free mbuf. */ + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); return; + } /* * Socket is created in state SYN_RECEIVED. * Continue processing segment. */ + INP_UNLOCK(inp); inp = sotoinpcb(so); + INP_LOCK(inp); tp = intotcpcb(inp); /* * This is what would have happened in @@ -777,6 +789,7 @@ findpcb: if ((ia6 = ip6_getdstifaddr(m)) && (ia6->ia6_flags & IN6_IFF_DEPRECATED)) { + INP_UNLOCK(inp); tp = NULL; rstreason = BANDLIM_RST_OPENPORT; goto dropwithreset; @@ -827,16 +840,22 @@ findpcb: tcp_dooptions(&to, optp, optlen, 1); if (!syncache_add(&inc, &to, th, &so, m)) goto drop; - if (so == NULL) + if (so == NULL) { /* * Entry added to syncache, mbuf used to * send SYN,ACK packet. */ + KASSERT(headlocked, ("headlocked")); + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); return; + } /* * Segment passed TAO tests. */ + INP_UNLOCK(inp); inp = sotoinpcb(so); + INP_LOCK(inp); tp = intotcpcb(inp); tp->snd_wnd = tiwin; tp->t_starttime = ticks; @@ -959,6 +978,9 @@ after_listen: SEQ_LEQ(th->th_ack, tp->snd_max) && tp->snd_cwnd >= tp->snd_wnd && tp->t_dupacks < tcprexmtthresh) { + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * this is a pure ack for outstanding data. 
*/ @@ -1007,11 +1029,15 @@ after_listen: sowwakeup(so); if (so->so_snd.sb_cc) (void) tcp_output(tp); + INP_UNLOCK(inp); return; } } else if (th->th_ack == tp->snd_una && LIST_EMPTY(&tp->t_segq) && tlen <= sbspace(&so->so_rcv)) { + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * this is a pure, in-sequence data packet * with nothing on the reassembly queue and @@ -1035,6 +1061,7 @@ after_listen: tp->t_flags |= TF_ACKNOW; tcp_output(tp); } + INP_UNLOCK(inp); return; } } @@ -1983,7 +2010,9 @@ step6: if (SEQ_GT(tp->rcv_nxt, tp->rcv_up)) tp->rcv_up = tp->rcv_nxt; dodata: /* XXX */ - + KASSERT(headlocked, ("headlocked")); + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; /* * Process the segment text, merging it into the TCP sequencing queue, * and arranging for acknowledgment of receipt if necessary. @@ -2121,6 +2150,7 @@ dodata: /* XXX */ */ if (needoutput || (tp->t_flags & TF_ACKNOW)) (void) tcp_output(tp); + INP_UNLOCK(inp); return; dropafterack: @@ -2150,9 +2180,12 @@ dropafterack: tcp_trace(TA_DROP, ostate, tp, (void *)tcp_saveipgen, &tcp_savetcp, 0); #endif + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); m_freem(m); tp->t_flags |= TF_ACKNOW; (void) tcp_output(tp); + INP_UNLOCK(inp); return; dropwithreset: @@ -2177,6 +2210,8 @@ dropwithreset: goto drop; /* IPv6 anycast check is done at tcp6_input() */ + if (tp) + INP_UNLOCK(inp); /* * Perform bandwidth limiting. 
*/ @@ -2199,6 +2234,8 @@ dropwithreset: tcp_respond(tp, mtod(m, void *), th, m, th->th_seq+tlen, (tcp_seq)0, TH_RST|TH_ACK); } + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); return; drop: @@ -2210,7 +2247,11 @@ drop: tcp_trace(TA_DROP, ostate, tp, (void *)tcp_saveipgen, &tcp_savetcp, 0); #endif + if (tp) + INP_UNLOCK(inp); m_freem(m); + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); return; } diff --git a/sys/netinet/tcp_subr.c b/sys/netinet/tcp_subr.c index 13d8300..143dbff 100644 --- a/sys/netinet/tcp_subr.c +++ b/sys/netinet/tcp_subr.c @@ -197,6 +197,7 @@ tcp_init() tcp_maxpersistidle = TCPTV_KEEP_IDLE; tcp_msl = TCPTV_MSL; + INP_INFO_LOCK_INIT(&tcbinfo, "tcp"); LIST_INIT(&tcb); tcbinfo.listhead = &tcb; TUNABLE_INT_FETCH("net.inet.tcp.tcbhashsize", &hashsize); @@ -748,7 +749,9 @@ tcp_drain() * where we're really low on mbufs, this is potentially * usefull. */ + INP_INFO_RLOCK(&tcbinfo); LIST_FOREACH(inpb, tcbinfo.listhead, inp_list) { + INP_LOCK(inpb); if ((tcpb = intotcpcb(inpb))) { while ((te = LIST_FIRST(&tcpb->t_segq)) != NULL) { @@ -757,7 +760,9 @@ tcp_drain() FREE(te, M_TSEGQ); } } + INP_UNLOCK(inpb); } + INP_INFO_RUNLOCK(&tcbinfo); } } @@ -825,8 +830,10 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) * OK, now we're committed to doing something. 
*/ s = splnet(); + INP_INFO_RLOCK(&tcbinfo); gencnt = tcbinfo.ipi_gencnt; n = tcbinfo.ipi_count; + INP_INFO_RUNLOCK(&tcbinfo); splx(s); xig.xig_len = sizeof xig; @@ -842,21 +849,26 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) return ENOMEM; s = splnet(); + INP_INFO_RLOCK(&tcbinfo); for (inp = LIST_FIRST(tcbinfo.listhead), i = 0; inp && i < n; inp = LIST_NEXT(inp, inp_list)) { + INP_LOCK(inp); if (inp->inp_gencnt <= gencnt) { if (cr_canseesocket(req->td->td_ucred, inp->inp_socket)) continue; inp_list[i++] = inp; } + INP_UNLOCK(inp); } + INP_INFO_RUNLOCK(&tcbinfo); splx(s); n = i; error = 0; for (i = 0; i < n; i++) { inp = inp_list[i]; + INP_LOCK(inp); if (inp->inp_gencnt <= gencnt) { struct xtcpcb xt; caddr_t inp_ppcb; @@ -872,6 +884,7 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) sotoxsocket(inp->inp_socket, &xt.xt_socket); error = SYSCTL_OUT(req, &xt, sizeof xt); } + INP_UNLOCK(inp); } if (!error) { /* @@ -882,9 +895,11 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) * might be necessary to retry. */ s = splnet(); + INP_INFO_RLOCK(&tcbinfo); xig.xig_gen = tcbinfo.ipi_gencnt; xig.xig_sogen = so_gencnt; xig.xig_count = tcbinfo.ipi_count; + INP_INFO_RUNLOCK(&tcbinfo); splx(s); error = SYSCTL_OUT(req, &xig, sizeof xig); } @@ -910,18 +925,29 @@ tcp_getcred(SYSCTL_HANDLER_ARGS) if (error) return (error); s = splnet(); + INP_INFO_RLOCK(&tcbinfo); inp = in_pcblookup_hash(&tcbinfo, addrs[1].sin_addr, addrs[1].sin_port, addrs[0].sin_addr, addrs[0].sin_port, 0, NULL); - if (inp == NULL || inp->inp_socket == NULL) { + if (inp == NULL) { error = ENOENT; - goto out; + goto outunlocked; + } else { + INP_LOCK(inp); + if (inp->inp_socket == NULL) { + error = ENOENT; + goto out; + } } + error = cr_canseesocket(req->td->td_ucred, inp->inp_socket); if (error) goto out; cru2x(inp->inp_socket->so_cred, &xuc); error = SYSCTL_OUT(req, &xuc, sizeof(struct xucred)); out: + INP_UNLOCK(inp); +outunlocked: + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (error); } @@ -952,6 +978,7 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) return 
(EINVAL); } s = splnet(); + INP_INFO_RLOCK(&tcbinfo); if (mapped == 1) inp = in_pcblookup_hash(&tcbinfo, *(struct in_addr *)&addrs[1].sin6_addr.s6_addr[12], @@ -964,9 +991,15 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) addrs[1].sin6_port, &addrs[0].sin6_addr, addrs[0].sin6_port, 0, NULL); - if (inp == NULL || inp->inp_socket == NULL) { + if (inp == NULL) { error = ENOENT; - goto out; + goto outunlocked; + } else { + INP_LOCK(inp); + if (inp->inp_socket == NULL) { + error = ENOENT; + goto out; + } } error = cr_canseesocket(req->td->td_ucred, inp->inp_socket); if (error) @@ -974,6 +1007,9 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) cru2x(inp->inp_socket->so_cred, &xuc); error = SYSCTL_OUT(req, &xuc, sizeof(struct xucred)); out: + INP_UNLOCK(inp); +outunlocked: + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (error); } @@ -1021,14 +1057,19 @@ tcp_ctlinput(cmd, sa, vip) s = splnet(); th = (struct tcphdr *)((caddr_t)ip + (IP_VHL_HL(ip->ip_vhl) << 2)); + INP_INFO_RLOCK(&tcbinfo); inp = in_pcblookup_hash(&tcbinfo, faddr, th->th_dport, ip->ip_src, th->th_sport, 0, NULL); - if (inp != NULL && inp->inp_socket != NULL) { - icmp_seq = htonl(th->th_seq); - tp = intotcpcb(inp); - if (SEQ_GEQ(icmp_seq, tp->snd_una) && - SEQ_LT(icmp_seq, tp->snd_max)) - (*notify)(inp, inetctlerrmap[cmd]); + if (inp != NULL) { + INP_LOCK(inp); + if (inp->inp_socket != NULL) { + icmp_seq = htonl(th->th_seq); + tp = intotcpcb(inp); + if (SEQ_GEQ(icmp_seq, tp->snd_una) && + SEQ_LT(icmp_seq, tp->snd_max)) + (*notify)(inp, inetctlerrmap[cmd]); + } + INP_UNLOCK(inp); } else { struct in_conninfo inc; @@ -1041,9 +1082,10 @@ tcp_ctlinput(cmd, sa, vip) #endif syncache_unreach(&inc, th); } + INP_INFO_RUNLOCK(&tcbinfo); splx(s); } else - in_pcbnotifyall(&tcb, faddr, inetctlerrmap[cmd], notify); + in_pcbnotifyall(&tcbinfo, faddr, inetctlerrmap[cmd], notify); } #ifdef INET6 diff --git a/sys/netinet/tcp_syncache.c b/sys/netinet/tcp_syncache.c index 771fa82..30a3a93 100644 --- a/sys/netinet/tcp_syncache.c +++ 
b/sys/netinet/tcp_syncache.c @@ -367,24 +367,29 @@ syncache_timer(xslot) callout_deactivate(&tcp_syncache.tt_timerq[slot]); nsc = TAILQ_FIRST(&tcp_syncache.timerq[slot]); + INP_INFO_RLOCK(&tcbinfo); while (nsc != NULL) { if (ticks < nsc->sc_rxttime) break; sc = nsc; nsc = TAILQ_NEXT(sc, sc_timerq); inp = sc->sc_tp->t_inpcb; + INP_LOCK(inp); if (slot == SYNCACHE_MAXREXMTS || slot >= tcp_syncache.rexmt_limit || inp->inp_gencnt != sc->sc_inp_gencnt) { syncache_drop(sc, NULL); tcpstat.tcps_sc_stale++; + INP_UNLOCK(inp); continue; } (void) syncache_respond(sc, NULL); + INP_UNLOCK(inp); tcpstat.tcps_sc_retransmitted++; TAILQ_REMOVE(&tcp_syncache.timerq[slot], sc, sc_timerq); SYNCACHE_TIMEOUT(sc, slot + 1); } + INP_INFO_RUNLOCK(&tcbinfo); if (nsc != NULL) callout_reset(&tcp_syncache.tt_timerq[slot], nsc->sc_rxttime - ticks, syncache_timer, (void *)(slot)); diff --git a/sys/netinet/tcp_timer.c b/sys/netinet/tcp_timer.c index 087e243..82cf3c5 100644 --- a/sys/netinet/tcp_timer.c +++ b/sys/netinet/tcp_timer.c @@ -160,15 +160,22 @@ static int tcp_totbackoff = 511; /* sum of tcp_backoff[] */ /* * TCP timer processing. 
*/ + void tcp_timer_delack(xtp) void *xtp; { struct tcpcb *tp = xtp; int s; + struct inpcb *inp; s = splnet(); + INP_INFO_RLOCK(&tcbinfo); + inp = tp->t_inpcb; + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); if (callout_pending(tp->tt_delack) || !callout_active(tp->tt_delack)) { + INP_UNLOCK(inp); splx(s); return; } @@ -177,6 +184,7 @@ tcp_timer_delack(xtp) tp->t_flags |= TF_ACKNOW; tcpstat.tcps_delack++; (void) tcp_output(tp); + INP_UNLOCK(inp); splx(s); } @@ -186,13 +194,19 @@ tcp_timer_2msl(xtp) { struct tcpcb *tp = xtp; int s; + struct inpcb *inp; #ifdef TCPDEBUG int ostate; ostate = tp->t_state; #endif s = splnet(); + INP_INFO_WLOCK(&tcbinfo); + inp = tp->t_inpcb; + INP_LOCK(inp); if (callout_pending(tp->tt_2msl) || !callout_active(tp->tt_2msl)) { + INP_UNLOCK(tp->t_inpcb); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return; } @@ -215,6 +229,9 @@ tcp_timer_2msl(xtp) tcp_trace(TA_USER, ostate, tp, (void *)0, (struct tcphdr *)0, PRU_SLOWTIMO); #endif + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); } @@ -225,13 +242,19 @@ tcp_timer_keep(xtp) struct tcpcb *tp = xtp; struct tcptemp *t_template; int s; + struct inpcb *inp; #ifdef TCPDEBUG int ostate; ostate = tp->t_state; #endif s = splnet(); + INP_INFO_WLOCK(&tcbinfo); + inp = tp->t_inpcb; + INP_LOCK(inp); if (callout_pending(tp->tt_keep) || !callout_active(tp->tt_keep)) { + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return; } @@ -277,6 +300,8 @@ tcp_timer_keep(xtp) tcp_trace(TA_USER, ostate, tp, (void *)0, (struct tcphdr *)0, PRU_SLOWTIMO); #endif + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return; @@ -289,6 +314,9 @@ dropit: tcp_trace(TA_USER, ostate, tp, (void *)0, (struct tcphdr *)0, PRU_SLOWTIMO); #endif + if (tp) + INP_UNLOCK(tp->t_inpcb); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); } @@ -298,13 +326,19 @@ tcp_timer_persist(xtp) { struct tcpcb *tp = xtp; int s; + struct inpcb *inp; #ifdef TCPDEBUG int ostate; ostate = tp->t_state; #endif s = splnet(); + 
INP_INFO_WLOCK(&tcbinfo); + inp = tp->t_inpcb; + INP_LOCK(inp); if (callout_pending(tp->tt_persist) || !callout_active(tp->tt_persist)){ + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return; } @@ -339,6 +373,9 @@ out: tcp_trace(TA_USER, ostate, tp, (void *)0, (struct tcphdr *)0, PRU_SLOWTIMO); #endif + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); } @@ -349,13 +386,21 @@ tcp_timer_rexmt(xtp) struct tcpcb *tp = xtp; int s; int rexmt; + int headlocked; + struct inpcb *inp; #ifdef TCPDEBUG int ostate; ostate = tp->t_state; #endif s = splnet(); + INP_INFO_WLOCK(&tcbinfo); + headlocked = 1; + inp = tp->t_inpcb; + INP_LOCK(inp); if (callout_pending(tp->tt_rexmt) || !callout_active(tp->tt_rexmt)) { + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return; } @@ -372,6 +417,8 @@ tcp_timer_rexmt(xtp) tp->t_softerror : ETIMEDOUT); goto out; } + INP_INFO_WUNLOCK(&tcbinfo); + headlocked = 0; if (tp->t_rxtshift == 1) { /* * first retransmit; record ssthresh and cwnd so they can @@ -474,5 +521,9 @@ out: tcp_trace(TA_USER, ostate, tp, (void *)0, (struct tcphdr *)0, PRU_SLOWTIMO); #endif + if (tp) + INP_UNLOCK(inp); + if (headlocked) + INP_INFO_WUNLOCK(&tcbinfo); splx(s); } diff --git a/sys/netinet/tcp_timewait.c b/sys/netinet/tcp_timewait.c index 13d8300..143dbff 100644 --- a/sys/netinet/tcp_timewait.c +++ b/sys/netinet/tcp_timewait.c @@ -197,6 +197,7 @@ tcp_init() tcp_maxpersistidle = TCPTV_KEEP_IDLE; tcp_msl = TCPTV_MSL; + INP_INFO_LOCK_INIT(&tcbinfo, "tcp"); LIST_INIT(&tcb); tcbinfo.listhead = &tcb; TUNABLE_INT_FETCH("net.inet.tcp.tcbhashsize", &hashsize); @@ -748,7 +749,9 @@ tcp_drain() * where we're really low on mbufs, this is potentially * usefull. 
*/ + INP_INFO_RLOCK(&tcbinfo); LIST_FOREACH(inpb, tcbinfo.listhead, inp_list) { + INP_LOCK(inpb); if ((tcpb = intotcpcb(inpb))) { while ((te = LIST_FIRST(&tcpb->t_segq)) != NULL) { @@ -757,7 +760,9 @@ tcp_drain() FREE(te, M_TSEGQ); } } + INP_UNLOCK(inpb); } + INP_INFO_RUNLOCK(&tcbinfo); } } @@ -825,8 +830,10 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) * OK, now we're committed to doing something. */ s = splnet(); + INP_INFO_RLOCK(&tcbinfo); gencnt = tcbinfo.ipi_gencnt; n = tcbinfo.ipi_count; + INP_INFO_RUNLOCK(&tcbinfo); splx(s); xig.xig_len = sizeof xig; @@ -842,21 +849,26 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) return ENOMEM; s = splnet(); + INP_INFO_RLOCK(&tcbinfo); for (inp = LIST_FIRST(tcbinfo.listhead), i = 0; inp && i < n; inp = LIST_NEXT(inp, inp_list)) { + INP_LOCK(inp); if (inp->inp_gencnt <= gencnt) { if (cr_canseesocket(req->td->td_ucred, inp->inp_socket)) continue; inp_list[i++] = inp; } + INP_UNLOCK(inp); } + INP_INFO_RUNLOCK(&tcbinfo); splx(s); n = i; error = 0; for (i = 0; i < n; i++) { inp = inp_list[i]; + INP_LOCK(inp); if (inp->inp_gencnt <= gencnt) { struct xtcpcb xt; caddr_t inp_ppcb; @@ -872,6 +884,7 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) sotoxsocket(inp->inp_socket, &xt.xt_socket); error = SYSCTL_OUT(req, &xt, sizeof xt); } + INP_UNLOCK(inp); } if (!error) { /* @@ -882,9 +895,11 @@ tcp_pcblist(SYSCTL_HANDLER_ARGS) * might be necessary to retry. 
*/ s = splnet(); + INP_INFO_RLOCK(&tcbinfo); xig.xig_gen = tcbinfo.ipi_gencnt; xig.xig_sogen = so_gencnt; xig.xig_count = tcbinfo.ipi_count; + INP_INFO_RUNLOCK(&tcbinfo); splx(s); error = SYSCTL_OUT(req, &xig, sizeof xig); } @@ -910,18 +925,29 @@ tcp_getcred(SYSCTL_HANDLER_ARGS) if (error) return (error); s = splnet(); + INP_INFO_RLOCK(&tcbinfo); inp = in_pcblookup_hash(&tcbinfo, addrs[1].sin_addr, addrs[1].sin_port, addrs[0].sin_addr, addrs[0].sin_port, 0, NULL); - if (inp == NULL || inp->inp_socket == NULL) { + if (inp == NULL) { error = ENOENT; - goto out; + goto outunlocked; + } else { + INP_LOCK(inp); + if (inp->inp_socket == NULL) { + error = ENOENT; + goto out; + } } + error = cr_canseesocket(req->td->td_ucred, inp->inp_socket); if (error) goto out; cru2x(inp->inp_socket->so_cred, &xuc); error = SYSCTL_OUT(req, &xuc, sizeof(struct xucred)); out: + INP_UNLOCK(inp); +outunlocked: + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (error); } @@ -952,6 +978,7 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) return (EINVAL); } s = splnet(); + INP_INFO_RLOCK(&tcbinfo); if (mapped == 1) inp = in_pcblookup_hash(&tcbinfo, *(struct in_addr *)&addrs[1].sin6_addr.s6_addr[12], @@ -964,9 +991,15 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) addrs[1].sin6_port, &addrs[0].sin6_addr, addrs[0].sin6_port, 0, NULL); - if (inp == NULL || inp->inp_socket == NULL) { + if (inp == NULL) { error = ENOENT; - goto out; + goto outunlocked; + } else { + INP_LOCK(inp); + if (inp->inp_socket == NULL) { + error = ENOENT; + goto out; + } } error = cr_canseesocket(req->td->td_ucred, inp->inp_socket); if (error) @@ -974,6 +1007,9 @@ tcp6_getcred(SYSCTL_HANDLER_ARGS) cru2x(inp->inp_socket->so_cred, &xuc); error = SYSCTL_OUT(req, &xuc, sizeof(struct xucred)); out: + INP_UNLOCK(inp); +outunlocked: + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (error); } @@ -1021,14 +1057,19 @@ tcp_ctlinput(cmd, sa, vip) s = splnet(); th = (struct tcphdr *)((caddr_t)ip + (IP_VHL_HL(ip->ip_vhl) << 2)); + INP_INFO_RLOCK(&tcbinfo); inp = 
in_pcblookup_hash(&tcbinfo, faddr, th->th_dport, ip->ip_src, th->th_sport, 0, NULL); - if (inp != NULL && inp->inp_socket != NULL) { - icmp_seq = htonl(th->th_seq); - tp = intotcpcb(inp); - if (SEQ_GEQ(icmp_seq, tp->snd_una) && - SEQ_LT(icmp_seq, tp->snd_max)) - (*notify)(inp, inetctlerrmap[cmd]); + if (inp != NULL) { + INP_LOCK(inp); + if (inp->inp_socket != NULL) { + icmp_seq = htonl(th->th_seq); + tp = intotcpcb(inp); + if (SEQ_GEQ(icmp_seq, tp->snd_una) && + SEQ_LT(icmp_seq, tp->snd_max)) + (*notify)(inp, inetctlerrmap[cmd]); + } + INP_UNLOCK(inp); } else { struct in_conninfo inc; @@ -1041,9 +1082,10 @@ tcp_ctlinput(cmd, sa, vip) #endif syncache_unreach(&inc, th); } + INP_INFO_RUNLOCK(&tcbinfo); splx(s); } else - in_pcbnotifyall(&tcb, faddr, inetctlerrmap[cmd], notify); + in_pcbnotifyall(&tcbinfo, faddr, inetctlerrmap[cmd], notify); } #ifdef INET6 diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c index e1f4c1a..f5e75d1 100644 --- a/sys/netinet/tcp_usrreq.c +++ b/sys/netinet/tcp_usrreq.c @@ -40,6 +40,7 @@ #include #include +#include #include #include #include @@ -120,11 +121,13 @@ tcp_usr_attach(struct socket *so, int proto, struct thread *td) { int s = splnet(); int error; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp = 0; TCPDEBUG0; + INP_INFO_WLOCK(&tcbinfo); TCPDEBUG1(); + inp = sotoinpcb(so); if (inp) { error = EISCONN; goto out; @@ -136,9 +139,15 @@ tcp_usr_attach(struct socket *so, int proto, struct thread *td) if ((so->so_options & SO_LINGER) && so->so_linger == 0) so->so_linger = TCP_LINGERTIME; - tp = sototcpcb(so); + + inp = sotoinpcb(so); + INP_LOCK(inp); + tp = intotcpcb(inp); out: TCPDEBUG2(PRU_ATTACH); + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return error; } @@ -155,35 +164,68 @@ tcp_usr_detach(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; TCPDEBUG0; + INP_INFO_WLOCK(&tcbinfo); + inp = 
sotoinpcb(so); if (inp == 0) { + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return EINVAL; /* XXX */ } + INP_LOCK(inp); tp = intotcpcb(inp); TCPDEBUG1(); tp = tcp_disconnect(tp); TCPDEBUG2(PRU_DETACH); + if (tp) + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&tcbinfo); splx(s); return error; } -#define COMMON_START() TCPDEBUG0; \ - do { \ - if (inp == 0) { \ - splx(s); \ - return EINVAL; \ - } \ - tp = intotcpcb(inp); \ - TCPDEBUG1(); \ - } while(0) - -#define COMMON_END(req) out: TCPDEBUG2(req); splx(s); return error; goto out - +#define INI_NOLOCK 0 +#define INI_READ 1 +#define INI_WRITE 2 + +#define COMMON_START() \ + TCPDEBUG0; \ + do { \ + if (inirw == INI_READ) \ + INP_INFO_RLOCK(&tcbinfo); \ + else if (inirw == INI_WRITE) \ + INP_INFO_WLOCK(&tcbinfo); \ + inp = sotoinpcb(so); \ + if (inp == 0) { \ + if (inirw == INI_READ) \ + INP_INFO_RUNLOCK(&tcbinfo); \ + else if (inirw == INI_WRITE) \ + INP_INFO_WUNLOCK(&tcbinfo); \ + splx(s); \ + return EINVAL; \ + } \ + INP_LOCK(inp); \ + if (inirw == INI_READ) \ + INP_INFO_RUNLOCK(&tcbinfo); \ + tp = intotcpcb(inp); \ + TCPDEBUG1(); \ +} while(0) + +#define COMMON_END(req) \ +out: TCPDEBUG2(req); \ + do { \ + if (tp) \ + INP_UNLOCK(inp); \ + if (inirw == INI_WRITE) \ + INP_INFO_WUNLOCK(&tcbinfo); \ + splx(s); \ + return error; \ + goto out; \ +} while(0) /* * Give the socket an address. 
@@ -193,9 +235,10 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in *sinp; + const int inirw = INI_READ; COMMON_START(); @@ -213,7 +256,6 @@ tcp_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) if (error) goto out; COMMON_END(PRU_BIND); - } #ifdef INET6 @@ -222,9 +264,10 @@ tcp6_usr_bind(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in6 *sin6p; + const int inirw = INI_READ; COMMON_START(); @@ -268,8 +311,9 @@ tcp_usr_listen(struct socket *so, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if (inp->inp_lport == 0) @@ -285,8 +329,9 @@ tcp6_usr_listen(struct socket *so, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if (inp->inp_lport == 0) { @@ -314,9 +359,10 @@ tcp_usr_connect(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in *sinp; + const int inirw = INI_WRITE; COMMON_START(); @@ -345,9 +391,10 @@ tcp6_usr_connect(struct socket *so, struct sockaddr *nam, struct thread *td) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; struct sockaddr_in6 *sin6p; + const int inirw = INI_WRITE; COMMON_START(); @@ -402,8 +449,9 @@ tcp_usr_disconnect(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; 
COMMON_START(); tp = tcp_disconnect(tp); @@ -418,23 +466,49 @@ tcp_usr_disconnect(struct socket *so) static int tcp_usr_accept(struct socket *so, struct sockaddr **nam) { - int s = splnet(); + int s; int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp = NULL; struct tcpcb *tp = NULL; + struct sockaddr_in *sin; + const int inirw = INI_READ; TCPDEBUG0; if (so->so_state & SS_ISDISCONNECTED) { error = ECONNABORTED; goto out; } - if (inp == 0) { + + /* + * Do the malloc first in case it blocks. + */ + MALLOC(sin, struct sockaddr_in *, sizeof *sin, M_SONAME, + M_WAITOK | M_ZERO); + sin->sin_family = AF_INET; + sin->sin_len = sizeof(*sin); + + s = splnet(); + INP_INFO_RLOCK(&tcbinfo); + inp = sotoinpcb(so); + if (!inp) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); + free(sin, M_SONAME); return (EINVAL); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); tp = intotcpcb(inp); TCPDEBUG1(); - in_setpeeraddr(so, nam); + + /* + * We inline in_setpeeraddr here, because we have already done + * the locking and the malloc. + */ + sin->sin_port = inp->inp_fport; + sin->sin_addr = inp->inp_faddr; + *nam = (struct sockaddr *)sin; + COMMON_END(PRU_ACCEPT); } @@ -442,26 +516,56 @@ tcp_usr_accept(struct socket *so, struct sockaddr **nam) static int tcp6_usr_accept(struct socket *so, struct sockaddr **nam) { - int s = splnet(); + int s; + struct inpcb *inp = NULL; int error = 0; - struct inpcb *inp = sotoinpcb(so); struct tcpcb *tp = NULL; + const int inirw = INI_READ; TCPDEBUG0; if (so->so_state & SS_ISDISCONNECTED) { error = ECONNABORTED; goto out; } + + s = splnet(); + INP_INFO_RLOCK(&tcbinfo); + inp = sotoinpcb(so); if (inp == 0) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (EINVAL); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); tp = intotcpcb(inp); TCPDEBUG1(); in6_mapped_peeraddr(so, nam); COMMON_END(PRU_ACCEPT); } #endif /* INET6 */ + +/* + * This is the wrapper function for in_setsockaddr. We just pass down + * the pcbinfo for in_setsockaddr to lock. 
We don't want to do the locking + * here because in_setsockaddr will call malloc and can block. + */ +static int +tcp_sockaddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setsockaddr(so, nam, &tcbinfo)); +} + +/* + * This is the wrapper function for in_setpeeraddr. We just pass down + * the pcbinfo for in_setpeeraddr to lock. + */ +static int +tcp_peeraddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setpeeraddr(so, nam, &tcbinfo)); +} + /* * Mark the connection as being incapable of further output. */ @@ -470,8 +574,9 @@ tcp_usr_shutdown(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; COMMON_START(); socantsendmore(so); @@ -489,8 +594,9 @@ tcp_usr_rcvd(struct socket *so, int flags) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); tcp_output(tp); @@ -510,13 +616,22 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; #ifdef INET6 int isipv6; #endif TCPDEBUG0; + /* + * Need write lock here because this function might call + * tcp_connect or tcp_usrclosed. + * We really want to have to this function upgrade from read lock + * to write lock. XXX + */ + INP_INFO_WLOCK(&tcbinfo); + inp = sotoinpcb(so); if (inp == NULL) { /* * OOPS! 
we lost a race, the TCP session got reset after @@ -532,6 +647,7 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, TCPDEBUG1(); goto out; } + INP_LOCK(inp); #ifdef INET6 isipv6 = nam && nam->sa_family == AF_INET6; #endif /* INET6 */ @@ -548,7 +664,7 @@ tcp_usr_send(struct socket *so, int flags, struct mbuf *m, } m_freem(control); /* empty control, just free it */ } - if(!(flags & PRUS_OOB)) { + if (!(flags & PRUS_OOB)) { sbappend(&so->so_snd, m); if (nam && tp->t_state < TCPS_SYN_SENT) { /* @@ -634,8 +750,9 @@ tcp_usr_abort(struct socket *so) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_WRITE; COMMON_START(); tp = tcp_drop(tp, ECONNABORTED); @@ -650,8 +767,9 @@ tcp_usr_rcvoob(struct socket *so, struct mbuf *m, int flags) { int s = splnet(); int error = 0; - struct inpcb *inp = sotoinpcb(so); + struct inpcb *inp; struct tcpcb *tp; + const int inirw = INI_READ; COMMON_START(); if ((so->so_oobmark == 0 && @@ -676,9 +794,9 @@ tcp_usr_rcvoob(struct socket *so, struct mbuf *m, int flags) struct pr_usrreqs tcp_usrreqs = { tcp_usr_abort, tcp_usr_accept, tcp_usr_attach, tcp_usr_bind, tcp_usr_connect, pru_connect2_notsupp, in_control, tcp_usr_detach, - tcp_usr_disconnect, tcp_usr_listen, in_setpeeraddr, tcp_usr_rcvd, + tcp_usr_disconnect, tcp_usr_listen, tcp_peeraddr, tcp_usr_rcvd, tcp_usr_rcvoob, tcp_usr_send, pru_sense_null, tcp_usr_shutdown, - in_setsockaddr, sosend, soreceive, sopoll + tcp_sockaddr, sosend, soreceive, sopoll }; #ifdef INET6 @@ -888,11 +1006,15 @@ tcp_ctloutput(so, sopt) error = 0; s = splnet(); /* XXX */ + INP_INFO_RLOCK(&tcbinfo); inp = sotoinpcb(so); if (inp == NULL) { + INP_INFO_RUNLOCK(&tcbinfo); splx(s); return (ECONNRESET); } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&tcbinfo); if (sopt->sopt_level != IPPROTO_TCP) { #ifdef INET6 if (INP_CHECK_SOCKAF(so, AF_INET6)) @@ -900,6 +1022,7 @@ tcp_ctloutput(so, sopt) else #endif /* INET6 */ error = 
ip_ctloutput(so, sopt); + INP_UNLOCK(inp); splx(s); return (error); } @@ -987,6 +1110,7 @@ tcp_ctloutput(so, sopt) error = sooptcopyout(sopt, &optval, sizeof optval); break; } + INP_UNLOCK(inp); splx(s); return (error); } diff --git a/sys/netinet/tcp_var.h b/sys/netinet/tcp_var.h index e46a39a..e8eb361 100644 --- a/sys/netinet/tcp_var.h +++ b/sys/netinet/tcp_var.h @@ -38,6 +38,7 @@ #define _NETINET_TCP_VAR_H_ #include /* needed for in_conninfo, inp_gen_t */ +#include /* * Kernel variables for tcp. diff --git a/sys/netinet/udp_usrreq.c b/sys/netinet/udp_usrreq.c index af4769f..b23c60c 100644 --- a/sys/netinet/udp_usrreq.c +++ b/sys/netinet/udp_usrreq.c @@ -142,6 +142,7 @@ static int udp_output(struct inpcb *, struct mbuf *, struct sockaddr *, void udp_init() { + INP_INFO_LOCK_INIT(&udbinfo, "udp"); LIST_INIT(&udb); udbinfo.listhead = &udb; udbinfo.hashbase = hashinit(UDBHASHSIZE, M_PCB, &udbinfo.hashmask); @@ -194,7 +195,7 @@ udp_input(m, off) /* destination port of 0 is illegal, based on RFC768. */ if (uh->uh_dport == 0) - goto bad; + goto badunlocked; /* * Make mbuf data length reflect UDP length. 
@@ -204,7 +205,7 @@ udp_input(m, off) if (ip->ip_len != len) { if (len > ip->ip_len || len < sizeof(struct udphdr)) { udpstat.udps_badlen++; - goto bad; + goto badunlocked; } m_adj(m, len - ip->ip_len); /* ip->ip_len = len; */ @@ -244,6 +245,8 @@ udp_input(m, off) } else udpstat.udps_nosum++; + INP_INFO_RLOCK(&udbinfo); + if (IN_MULTICAST(ntohl(ip->ip_dst.s_addr)) || in_broadcast(ip->ip_dst, m->m_pkthdr.rcvif)) { struct inpcb *last; @@ -277,22 +280,25 @@ udp_input(m, off) udp_in6.uin6_init_done = udp_ip6.uip6_init_done = 0; #endif LIST_FOREACH(inp, &udb, inp_list) { + INP_LOCK(inp); + if (inp->inp_lport != uh->uh_dport) { + docontinue: + INP_UNLOCK(inp); + continue; + } #ifdef INET6 if ((inp->inp_vflag & INP_IPV4) == 0) - continue; + goto docontinue; #endif - if (inp->inp_lport != uh->uh_dport) - continue; if (inp->inp_laddr.s_addr != INADDR_ANY) { - if (inp->inp_laddr.s_addr != - ip->ip_dst.s_addr) - continue; + if (inp->inp_laddr.s_addr != ip->ip_dst.s_addr) + goto docontinue; } if (inp->inp_faddr.s_addr != INADDR_ANY) { if (inp->inp_faddr.s_addr != ip->ip_src.s_addr || inp->inp_fport != uh->uh_sport) - continue; + goto docontinue; } if (last != NULL) { @@ -309,6 +315,7 @@ udp_input(m, off) udp_append(last, ip, n, iphlen + sizeof(struct udphdr)); + INP_UNLOCK(last); } last = inp; /* @@ -330,15 +337,19 @@ udp_input(m, off) * for a broadcast or multicast datgram.) */ udpstat.udps_noportbcast++; + INP_INFO_RUNLOCK(&udbinfo); goto bad; } #ifdef IPSEC /* check AH/ESP integrity. 
*/ if (ipsec4_in_reject_so(m, last->inp_socket)) { ipsecstat.in_polvio++; + INP_INFO_RUNLOCK(&udbinfo); goto bad; } #endif /*IPSEC*/ + INP_UNLOCK(last); + INP_INFO_RUNLOCK(&udbinfo); udp_append(last, ip, m, iphlen + sizeof(struct udphdr)); return; } @@ -347,6 +358,7 @@ udp_input(m, off) */ inp = in_pcblookup_hash(&udbinfo, ip->ip_src, uh->uh_sport, ip->ip_dst, uh->uh_dport, 1, m->m_pkthdr.rcvif); + INP_INFO_RUNLOCK(&udbinfo); if (inp == NULL) { if (log_in_vain) { char buf[4*sizeof "123"]; @@ -369,15 +381,18 @@ udp_input(m, off) *ip = save_ip; ip->ip_len += iphlen; icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_PORT, 0, 0); + INP_INFO_RUNLOCK(&udbinfo); return; } #ifdef IPSEC if (ipsec4_in_reject_so(m, inp->inp_socket)) { ipsecstat.in_polvio++; + INP_INFO_RUNLOCK(&udbinfo); goto bad; } #endif /*IPSEC*/ + INP_LOCK(inp); /* * Construct sockaddr format source address. * Stuff source address and datagram in user buffer. @@ -412,8 +427,12 @@ udp_input(m, off) goto bad; } sorwakeup(inp->inp_socket); + INP_UNLOCK(inp); return; bad: + if (inp) + INP_UNLOCK(inp); +badunlocked: m_freem(m); if (opts) m_freem(opts); @@ -532,13 +551,20 @@ udp_ctlinput(cmd, sa, vip) if (ip) { s = splnet(); uh = (struct udphdr *)((caddr_t)ip + (ip->ip_hl << 2)); + INP_INFO_RLOCK(&udbinfo); inp = in_pcblookup_hash(&udbinfo, faddr, uh->uh_dport, ip->ip_src, uh->uh_sport, 0, NULL); - if (inp != NULL && inp->inp_socket != NULL) - (*notify)(inp, inetctlerrmap[cmd]); + if (inp != NULL) { + INP_LOCK(inp); + if(inp->inp_socket != NULL) { + (*notify)(inp, inetctlerrmap[cmd]); + } + INP_UNLOCK(inp); + } + INP_INFO_RUNLOCK(&udbinfo); splx(s); } else - in_pcbnotifyall(&udb, faddr, inetctlerrmap[cmd], notify); + in_pcbnotifyall(&udbinfo, faddr, inetctlerrmap[cmd], notify); } static int @@ -584,21 +610,26 @@ udp_pcblist(SYSCTL_HANDLER_ARGS) return ENOMEM; s = splnet(); + INP_INFO_RLOCK(&udbinfo); for (inp = LIST_FIRST(udbinfo.listhead), i = 0; inp && i < n; inp = LIST_NEXT(inp, inp_list)) { + INP_LOCK(inp); if 
(inp->inp_gencnt <= gencnt) { if (cr_canseesocket(req->td->td_ucred, inp->inp_socket)) continue; inp_list[i++] = inp; } + INP_UNLOCK(inp); } + INP_INFO_RUNLOCK(&udbinfo); splx(s); n = i; error = 0; for (i = 0; i < n; i++) { inp = inp_list[i]; + INP_LOCK(inp); if (inp->inp_gencnt <= gencnt) { struct xinpcb xi; xi.xi_len = sizeof xi; @@ -608,6 +639,7 @@ udp_pcblist(SYSCTL_HANDLER_ARGS) sotoxsocket(inp->inp_socket, &xi.xi_socket); error = SYSCTL_OUT(req, &xi, sizeof xi); } + INP_UNLOCK(inp); } if (!error) { /* @@ -618,9 +650,11 @@ udp_pcblist(SYSCTL_HANDLER_ARGS) * might be necessary to retry. */ s = splnet(); + INP_INFO_RLOCK(&udbinfo); xig.xig_gen = udbinfo.ipi_gencnt; xig.xig_sogen = so_gencnt; xig.xig_count = udbinfo.ipi_count; + INP_INFO_RUNLOCK(&udbinfo); splx(s); error = SYSCTL_OUT(req, &xig, sizeof xig); } @@ -646,6 +680,7 @@ udp_getcred(SYSCTL_HANDLER_ARGS) if (error) return (error); s = splnet(); + INP_INFO_RLOCK(&udbinfo); inp = in_pcblookup_hash(&udbinfo, addrs[1].sin_addr, addrs[1].sin_port, addrs[0].sin_addr, addrs[0].sin_port, 1, NULL); if (inp == NULL || inp->inp_socket == NULL) { @@ -658,6 +693,7 @@ udp_getcred(SYSCTL_HANDLER_ARGS) cru2x(inp->inp_socket->so_cred, &xuc); error = SYSCTL_OUT(req, &xuc, sizeof(struct xucred)); out: + INP_INFO_RUNLOCK(&udbinfo); splx(s); return (error); } @@ -796,12 +832,17 @@ udp_abort(struct socket *so) struct inpcb *inp; int s; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; /* ??? possible? panic instead? 
*/ + } + INP_LOCK(inp); soisdisconnected(so); s = splnet(); in_pcbdetach(inp); + INP_INFO_WUNLOCK(&udbinfo); splx(s); return 0; } @@ -812,13 +853,17 @@ udp_attach(struct socket *so, int proto, struct thread *td) struct inpcb *inp; int s, error; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp != 0) + if (inp != 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; - + } error = soreserve(so, udp_sendspace, udp_recvspace); - if (error) + if (error) { + INP_INFO_WUNLOCK(&udbinfo); return error; + } s = splnet(); error = in_pcballoc(so, &udbinfo, td); splx(s); @@ -826,8 +871,11 @@ udp_attach(struct socket *so, int proto, struct thread *td) return error; inp = (struct inpcb *)so->so_pcb; + INP_LOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); inp->inp_vflag |= INP_IPV4; inp->inp_ip_ttl = ip_defttl; + INP_UNLOCK(inp); return 0; } @@ -837,12 +885,18 @@ udp_bind(struct socket *so, struct sockaddr *nam, struct thread *td) struct inpcb *inp; int s, error; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; + } + INP_LOCK(inp); s = splnet(); error = in_pcbbind(inp, nam, td); splx(s); + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); return error; } @@ -853,11 +907,18 @@ udp_connect(struct socket *so, struct sockaddr *nam, struct thread *td) int s, error; struct sockaddr_in *sin; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; - if (inp->inp_faddr.s_addr != INADDR_ANY) + } + INP_LOCK(inp); + if (inp->inp_faddr.s_addr != INADDR_ANY) { + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); return EISCONN; + } s = splnet(); sin = (struct sockaddr_in *)nam; if (td && jailed(td->td_ucred)) @@ -866,6 +927,8 @@ udp_connect(struct socket *so, struct sockaddr *nam, struct thread *td) splx(s); if (error == 0) soisconnected(so); + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); return error; } @@ -875,11 +938,16 @@ udp_detach(struct socket *so) 
struct inpcb *inp; int s; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; + } + INP_LOCK(inp); s = splnet(); in_pcbdetach(inp); + INP_INFO_WUNLOCK(&udbinfo); splx(s); return 0; } @@ -890,15 +958,24 @@ udp_disconnect(struct socket *so) struct inpcb *inp; int s; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); return EINVAL; - if (inp->inp_faddr.s_addr == INADDR_ANY) + } + INP_LOCK(inp); + if (inp->inp_faddr.s_addr == INADDR_ANY) { + INP_INFO_WUNLOCK(&udbinfo); + INP_UNLOCK(inp); return ENOTCONN; + } s = splnet(); in_pcbdisconnect(inp); inp->inp_laddr.s_addr = INADDR_ANY; + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); splx(s); so->so_state &= ~SS_ISCONNECTED; /* XXX */ return 0; @@ -909,13 +986,20 @@ udp_send(struct socket *so, int flags, struct mbuf *m, struct sockaddr *addr, struct mbuf *control, struct thread *td) { struct inpcb *inp; + int ret; + INP_INFO_WLOCK(&udbinfo); inp = sotoinpcb(so); if (inp == 0) { + INP_INFO_WUNLOCK(&udbinfo); m_freem(m); return EINVAL; } - return udp_output(inp, m, addr, control, td); + INP_LOCK(inp); + ret = udp_output(inp, m, addr, control, td); + INP_UNLOCK(inp); + INP_INFO_WUNLOCK(&udbinfo); + return ret; } int @@ -923,17 +1007,44 @@ udp_shutdown(struct socket *so) { struct inpcb *inp; + INP_INFO_RLOCK(&udbinfo); inp = sotoinpcb(so); - if (inp == 0) + if (inp == 0) { + INP_INFO_RUNLOCK(&udbinfo); return EINVAL; + } + INP_LOCK(inp); + INP_INFO_RUNLOCK(&udbinfo); socantsendmore(so); + INP_UNLOCK(inp); return 0; } +/* + * This is the wrapper function for in_setsockaddr. We just pass down + * the pcbinfo for in_setsockaddr to lock. We don't want to do the locking + * here because in_setsockaddr will call malloc and might block. 
+ */ +static int +udp_sockaddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setsockaddr(so, nam, &udbinfo)); +} + +/* + * This is the wrapper function for in_setpeeraddr. We just pass down + * the pcbinfo for in_setpeeraddr to lock. + */ +static int +udp_peeraddr(struct socket *so, struct sockaddr **nam) +{ + return (in_setpeeraddr(so, nam, &udbinfo)); +} + struct pr_usrreqs udp_usrreqs = { udp_abort, pru_accept_notsupp, udp_attach, udp_bind, udp_connect, pru_connect2_notsupp, in_control, udp_detach, udp_disconnect, - pru_listen_notsupp, in_setpeeraddr, pru_rcvd_notsupp, + pru_listen_notsupp, udp_peeraddr, pru_rcvd_notsupp, pru_rcvoob_notsupp, udp_send, pru_sense_null, udp_shutdown, - in_setsockaddr, sosend, soreceive, sopoll + udp_sockaddr, sosend, soreceive, sopoll }; diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c index 09a5c29..0f18bd1 100644 --- a/sys/netinet6/in6_pcb.c +++ b/sys/netinet6/in6_pcb.c @@ -93,6 +93,7 @@ #include #include #include +#include #include #include #include @@ -719,7 +720,7 @@ in6_mapped_sockaddr(struct socket *so, struct sockaddr **nam) if (inp == NULL) return EINVAL; if (inp->inp_vflag & INP_IPV4) { - error = in_setsockaddr(so, nam); + error = in_setsockaddr(so, nam, &tcbinfo); if (error == 0) in6_sin_2_v4mapsin6_in_sock(nam); } else @@ -738,7 +739,7 @@ in6_mapped_peeraddr(struct socket *so, struct sockaddr **nam) if (inp == NULL) return EINVAL; if (inp->inp_vflag & INP_IPV4) { - error = in_setpeeraddr(so, nam); + error = in_setpeeraddr(so, nam, &tcbinfo); if (error == 0) in6_sin_2_v4mapsin6_in_sock(nam); } else -- cgit v1.1