diff options
-rw-r--r-- | include/net/inet_sock.h | 27 | ||||
-rw-r--r-- | net/xfrm/xfrm_policy.c | 2 |
2 files changed, 25 insertions, 4 deletions
diff --git a/include/net/inet_sock.h b/include/net/inet_sock.h index 2134e6d..625bdf9 100644 --- a/include/net/inet_sock.h +++ b/include/net/inet_sock.h @@ -210,18 +210,37 @@ struct inet_sock { #define IP_CMSG_ORIGDSTADDR BIT(6) #define IP_CMSG_CHECKSUM BIT(7) -/* SYNACK messages might be attached to request sockets. +/** + * sk_to_full_sk - Access to a full socket + * @sk: pointer to a socket + * + * SYNACK messages might be attached to request sockets. * Some places want to reach the listener in this case. */ -static inline struct sock *skb_to_full_sk(const struct sk_buff *skb) +static inline struct sock *sk_to_full_sk(struct sock *sk) { - struct sock *sk = skb->sk; - +#ifdef CONFIG_INET if (sk && sk->sk_state == TCP_NEW_SYN_RECV) sk = inet_reqsk(sk)->rsk_listener; +#endif + return sk; +} + +/* sk_to_full_sk() variant with a const argument */ +static inline const struct sock *sk_const_to_full_sk(const struct sock *sk) +{ +#ifdef CONFIG_INET + if (sk && sk->sk_state == TCP_NEW_SYN_RECV) + sk = ((const struct request_sock *)sk)->rsk_listener; +#endif return sk; } +static inline struct sock *skb_to_full_sk(const struct sk_buff *skb) +{ + return sk_to_full_sk(skb->sk); +} + static inline struct inet_sock *inet_sk(const struct sock *sk) { return (struct inet_sock *)sk; diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c index 09bfcba..18276f0 100644 --- a/net/xfrm/xfrm_policy.c +++ b/net/xfrm/xfrm_policy.c @@ -2198,6 +2198,7 @@ struct dst_entry *xfrm_lookup(struct net *net, struct dst_entry *dst_orig, xdst = NULL; route = NULL; + sk = sk_const_to_full_sk(sk); if (sk && sk->sk_policy[XFRM_POLICY_OUT]) { num_pols = 1; pols[0] = xfrm_sk_policy_lookup(sk, XFRM_POLICY_OUT, fl); @@ -2477,6 +2478,7 @@ int __xfrm_policy_check(struct sock *sk, int dir, struct sk_buff *skb, } pol = NULL; + sk = sk_to_full_sk(sk); if (sk && sk->sk_policy[dir]) { pol = xfrm_sk_policy_lookup(sk, dir, &fl); if (IS_ERR(pol)) { |