diff options
Diffstat (limited to 'net/ipv4/inet_diag.c')
-rw-r--r-- | net/ipv4/inet_diag.c | 146 |
1 files changed, 78 insertions, 68 deletions
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 46d1e71..570e61f 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -46,9 +46,6 @@ struct inet_diag_entry { u16 userlocks; }; -#define INET_DIAG_PUT(skb, attrtype, attrlen) \ - RTA_DATA(__RTA_PUT(skb, attrtype, attrlen)) - static DEFINE_MUTEX(inet_diag_table_mutex); static const struct inet_diag_handler *inet_diag_lock_handler(int proto) @@ -78,24 +75,22 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, const struct inet_sock *inet = inet_sk(sk); struct inet_diag_msg *r; struct nlmsghdr *nlh; + struct nlattr *attr; void *info = NULL; - struct inet_diag_meminfo *minfo = NULL; - unsigned char *b = skb_tail_pointer(skb); const struct inet_diag_handler *handler; int ext = req->idiag_ext; handler = inet_diag_table[req->sdiag_protocol]; BUG_ON(handler == NULL); - nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); - nlh->nlmsg_flags = nlmsg_flags; + nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), + nlmsg_flags); + if (!nlh) + return -EMSGSIZE; - r = NLMSG_DATA(nlh); + r = nlmsg_data(nlh); BUG_ON(sk->sk_state == TCP_TIME_WAIT); - if (ext & (1 << (INET_DIAG_MEMINFO - 1))) - minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, sizeof(*minfo)); - r->idiag_family = sk->sk_family; r->idiag_state = sk->sk_state; r->idiag_timer = 0; @@ -113,7 +108,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, * hence this needs to be included regardless of socket family. */ if (ext & (1 << (INET_DIAG_TOS - 1))) - RTA_PUT_U8(skb, INET_DIAG_TOS, inet->tos); + if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0) + goto errout; #if IS_ENABLED(CONFIG_IPV6) if (r->idiag_family == AF_INET6) { @@ -121,24 +117,31 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, *(struct in6_addr *)r->id.idiag_src = np->rcv_saddr; *(struct in6_addr *)r->id.idiag_dst = np->daddr; + if (ext & (1 << (INET_DIAG_TCLASS - 1))) - RTA_PUT_U8(skb, INET_DIAG_TCLASS, np->tclass); + if (nla_put_u8(skb, INET_DIAG_TCLASS, np->tclass) < 0) + goto errout; } #endif r->idiag_uid = sock_i_uid(sk); r->idiag_inode = sock_i_ino(sk); - if (minfo) { - minfo->idiag_rmem = sk_rmem_alloc_get(sk); - minfo->idiag_wmem = sk->sk_wmem_queued; - minfo->idiag_fmem = sk->sk_forward_alloc; - minfo->idiag_tmem = sk_wmem_alloc_get(sk); + if (ext & (1 << (INET_DIAG_MEMINFO - 1))) { + struct inet_diag_meminfo minfo = { + .idiag_rmem = sk_rmem_alloc_get(sk), + .idiag_wmem = sk->sk_wmem_queued, + .idiag_fmem = sk->sk_forward_alloc, + .idiag_tmem = sk_wmem_alloc_get(sk), + }; + + if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0) + goto errout; } if (ext & (1 << (INET_DIAG_SKMEMINFO - 1))) if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO)) - goto rtattr_failure; + goto errout; if (icsk == NULL) { handler->idiag_get_info(sk, r, NULL); @@ -165,16 +168,20 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, } #undef EXPIRES_IN_MS - if (ext & (1 << (INET_DIAG_INFO - 1))) - info = INET_DIAG_PUT(skb, INET_DIAG_INFO, sizeof(struct tcp_info)); - - if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) { - const size_t len = strlen(icsk->icsk_ca_ops->name); + if (ext & (1 << (INET_DIAG_INFO - 1))) { + attr = nla_reserve(skb, INET_DIAG_INFO, + sizeof(struct tcp_info)); + if (!attr) + goto errout; - strcpy(INET_DIAG_PUT(skb, INET_DIAG_CONG, len + 1), - icsk->icsk_ca_ops->name); + info = nla_data(attr); } + if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) + if (nla_put_string(skb, INET_DIAG_CONG, + icsk->icsk_ca_ops->name) < 0) + goto errout; + handler->idiag_get_info(sk, r, info); if (sk->sk_state < TCP_TIME_WAIT && @@ -182,12 +189,10 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk, icsk->icsk_ca_ops->get_info(sk, ext, skb); out: - nlh->nlmsg_len = skb_tail_pointer(skb) - b; - return skb->len; + return nlmsg_end(skb, nlh); -rtattr_failure: -nlmsg_failure: - nlmsg_trim(skb, b); +errout: + nlmsg_cancel(skb, nlh); return -EMSGSIZE; } EXPORT_SYMBOL_GPL(inet_sk_diag_fill); @@ -208,14 +213,15 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw, { long tmo; struct inet_diag_msg *r; - const unsigned char *previous_tail = skb_tail_pointer(skb); - struct nlmsghdr *nlh = NLMSG_PUT(skb, pid, seq, - unlh->nlmsg_type, sizeof(*r)); + struct nlmsghdr *nlh; - r = NLMSG_DATA(nlh); - BUG_ON(tw->tw_state != TCP_TIME_WAIT); + nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), + nlmsg_flags); + if (!nlh) + return -EMSGSIZE; - nlh->nlmsg_flags = nlmsg_flags; + r = nlmsg_data(nlh); + BUG_ON(tw->tw_state != TCP_TIME_WAIT); tmo = tw->tw_ttd - jiffies; if (tmo < 0) @@ -245,11 +251,8 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw, *(struct in6_addr *)r->id.idiag_dst = tw6->tw_v6_daddr; } #endif - nlh->nlmsg_len = skb_tail_pointer(skb) - previous_tail; - return skb->len; -nlmsg_failure: - nlmsg_trim(skb, previous_tail); - return -EMSGSIZE; + + return nlmsg_end(skb, nlh); } static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, @@ -269,16 +272,17 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s int err; struct sock *sk; struct sk_buff *rep; + struct net *net = sock_net(in_skb->sk); err = -EINVAL; if (req->sdiag_family == AF_INET) { - sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0], + sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0], req->id.idiag_dport, req->id.idiag_src[0], req->id.idiag_sport, req->id.idiag_if); } #if IS_ENABLED(CONFIG_IPV6) else if (req->sdiag_family == AF_INET6) { - sk = inet6_lookup(&init_net, hashinfo, + sk = inet6_lookup(net, hashinfo, (struct in6_addr *)req->id.idiag_dst, req->id.idiag_dport, (struct in6_addr *)req->id.idiag_src, @@ -298,23 +302,23 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s if (err) goto out; - err = -ENOMEM; - rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) + - sizeof(struct inet_diag_meminfo) + - sizeof(struct tcp_info) + 64)), - GFP_KERNEL); - if (!rep) + rep = nlmsg_new(sizeof(struct inet_diag_msg) + + sizeof(struct inet_diag_meminfo) + + sizeof(struct tcp_info) + 64, GFP_KERNEL); + if (!rep) { + err = -ENOMEM; goto out; + } err = sk_diag_fill(sk, rep, req, NETLINK_CB(in_skb).pid, nlh->nlmsg_seq, 0, nlh); if (err < 0) { WARN_ON(err == -EMSGSIZE); - kfree_skb(rep); + nlmsg_free(rep); goto out; } - err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid, + err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).pid, MSG_DONTWAIT); if (err > 0) err = 0; @@ -592,15 +596,16 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk, { const struct inet_request_sock *ireq = inet_rsk(req); struct inet_sock *inet = inet_sk(sk); - unsigned char *b = skb_tail_pointer(skb); struct inet_diag_msg *r; struct nlmsghdr *nlh; long tmo; - nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); - nlh->nlmsg_flags = NLM_F_MULTI; - r = NLMSG_DATA(nlh); + nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), + NLM_F_MULTI); + if (!nlh) + return -EMSGSIZE; + r = nlmsg_data(nlh); r->idiag_family = sk->sk_family; r->idiag_state = TCP_SYN_RECV; r->idiag_timer = 1; @@ -628,13 +633,8 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk, *(struct in6_addr *)r->id.idiag_dst = inet6_rsk(req)->rmt_addr; } #endif - nlh->nlmsg_len = skb_tail_pointer(skb) - b; - - return skb->len; -nlmsg_failure: - nlmsg_trim(skb, b); - return -1; + return nlmsg_end(skb, nlh); } static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk, @@ -725,6 +725,7 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, { int i, num; int s_i, s_num; + struct net *net = sock_net(skb->sk); s_i = cb->args[1]; s_num = num = cb->args[2]; @@ -744,6 +745,9 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb, sk_nulls_for_each(sk, node, &ilb->head) { struct inet_sock *inet = inet_sk(sk); + if (!net_eq(sock_net(sk), net)) + continue; + if (num < s_num) { num++; continue; @@ -814,6 +818,8 @@ skip_listen_ht: sk_nulls_for_each(sk, node, &head->chain) { struct inet_sock *inet = inet_sk(sk); + if (!net_eq(sock_net(sk), net)) + continue; if (num < s_num) goto next_normal; if (!(r->idiag_states & (1 << sk->sk_state))) @@ -840,6 +846,8 @@ next_normal: inet_twsk_for_each(tw, node, &head->twchain) { + if (!net_eq(twsk_net(tw), net)) + continue; if (num < s_num) goto next_dying; @@ -892,7 +900,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb) if (nlmsg_attrlen(cb->nlh, hdrlen)) bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE); - return __inet_diag_dump(skb, cb, (struct inet_diag_req_v2 *)NLMSG_DATA(cb->nlh), bc); + return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc); } static inline int inet_diag_type2proto(int type) @@ -909,7 +917,7 @@ static inline int inet_diag_type2proto(int type) static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb) { - struct inet_diag_req *rc = NLMSG_DATA(cb->nlh); + struct inet_diag_req *rc = nlmsg_data(cb->nlh); struct inet_diag_req_v2 req; struct nlattr *bc = NULL; int hdrlen = sizeof(struct inet_diag_req); @@ -929,7 +937,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *c static int inet_diag_get_exact_compat(struct sk_buff *in_skb, const struct nlmsghdr *nlh) { - struct inet_diag_req *rc = NLMSG_DATA(nlh); + struct inet_diag_req *rc = nlmsg_data(nlh); struct inet_diag_req_v2 req; req.sdiag_family = rc->idiag_family; @@ -944,6 +952,7 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb, static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) { int hdrlen = sizeof(struct inet_diag_req); + struct net *net = sock_net(skb->sk); if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX || nlmsg_len(nlh) < hdrlen) @@ -964,7 +973,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) struct netlink_dump_control c = { .dump = inet_diag_dump_compat, }; - return netlink_dump_start(sock_diag_nlsk, skb, nlh, &c); + return netlink_dump_start(net->diag_nlsk, skb, nlh, &c); } } @@ -974,6 +983,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh) static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) { int hdrlen = sizeof(struct inet_diag_req_v2); + struct net *net = sock_net(skb->sk); if (nlmsg_len(h) < hdrlen) return -EINVAL; @@ -992,11 +1002,11 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h) struct netlink_dump_control c = { .dump = inet_diag_dump, }; - return netlink_dump_start(sock_diag_nlsk, skb, h, &c); + return netlink_dump_start(net->diag_nlsk, skb, h, &c); } } - return inet_diag_get_exact(skb, h, (struct inet_diag_req_v2 *)NLMSG_DATA(h)); + return inet_diag_get_exact(skb, h, nlmsg_data(h)); } static const struct sock_diag_handler inet_diag_handler = { |