[PATCH 11/13] net: Track socket refcounts in skb_steal_sock()
Kleber Souza
kleber.souza at canonical.com
Thu Sep 3 10:57:48 UTC 2020
On 31.08.20 06:03, Khalid Elmously wrote:
> From: Joe Stringer <joe at wand.net.nz>
This patch is missing the BugLink tag:
BugLink: https://bugs.launchpad.net/bugs/1887740
>
> [ upstream commit 71489e21d720a09388b565d60ef87ae993c10528 ]
>
> Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
> the determination of whether the socket is reference counted in the case
> where it is prefetched by earlier logic such as early_demux.
>
> Signed-off-by: Joe Stringer <joe at wand.net.nz>
> Signed-off-by: Alexei Starovoitov <ast at kernel.org>
> Acked-by: Martin KaFai Lau <kafai at fb.com>
> Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz
> Signed-off-by: Daniel Borkmann <daniel at iogearbox.net>
> Signed-off-by: Khalid Elmously <khalid.elmously at canonical.com>
> ---
> include/net/inet6_hashtables.h | 3 +--
> include/net/inet_hashtables.h | 3 +--
> include/net/sock.h | 10 +++++++++-
> net/ipv4/udp.c | 6 ++++--
> net/ipv6/udp.c | 9 ++++++---
> 5 files changed, 21 insertions(+), 10 deletions(-)
>
> diff --git a/include/net/inet6_hashtables.h b/include/net/inet6_hashtables.h
> index fe96bf247aac..81b965953036 100644
> --- a/include/net/inet6_hashtables.h
> +++ b/include/net/inet6_hashtables.h
> @@ -85,9 +85,8 @@ static inline struct sock *__inet6_lookup_skb(struct inet_hashinfo *hashinfo,
> int iif, int sdif,
> bool *refcounted)
> {
> - struct sock *sk = skb_steal_sock(skb);
> + struct sock *sk = skb_steal_sock(skb, refcounted);
>
> - *refcounted = true;
> if (sk)
> return sk;
>
> diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
> index d0019d3395cf..ad64ba6a057f 100644
> --- a/include/net/inet_hashtables.h
> +++ b/include/net/inet_hashtables.h
> @@ -379,10 +379,9 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
> const int sdif,
> bool *refcounted)
> {
> - struct sock *sk = skb_steal_sock(skb);
> + struct sock *sk = skb_steal_sock(skb, refcounted);
> const struct iphdr *iph = ip_hdr(skb);
>
> - *refcounted = true;
> if (sk)
> return sk;
>
> diff --git a/include/net/sock.h b/include/net/sock.h
> index b754050401d8..6cb1f0efa01b 100644
> --- a/include/net/sock.h
> +++ b/include/net/sock.h
> @@ -2492,15 +2492,23 @@ skb_sk_is_prefetched(struct sk_buff *skb)
> #endif /* CONFIG_INET */
> }
>
> -static inline struct sock *skb_steal_sock(struct sk_buff *skb)
> +/**
> + * skb_steal_sock
> + * @skb to steal the socket from
> + * @refcounted is set to true if the socket is reference-counted
> + */
> +static inline struct sock *
> +skb_steal_sock(struct sk_buff *skb, bool *refcounted)
> {
> if (skb->sk) {
> struct sock *sk = skb->sk;
>
> + *refcounted = true;
> skb->destructor = NULL;
> skb->sk = NULL;
> return sk;
> }
> + *refcounted = false;
> return NULL;
> }
>
> diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
> index f3b7cb725c1b..b7b01f721310 100644
> --- a/net/ipv4/udp.c
> +++ b/net/ipv4/udp.c
> @@ -2286,6 +2286,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> struct rtable *rt = skb_rtable(skb);
> __be32 saddr, daddr;
> struct net *net = dev_net(skb->dev);
> + bool refcounted;
>
> /*
> * Validate the packet.
> @@ -2311,7 +2312,7 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> if (udp4_csum_init(skb, uh, proto))
> goto csum_error;
>
> - sk = skb_steal_sock(skb);
> + sk = skb_steal_sock(skb, &refcounted);
> if (sk) {
> struct dst_entry *dst = skb_dst(skb);
> int ret;
> @@ -2320,7 +2321,8 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> udp_sk_rx_dst_set(sk, dst);
>
> ret = udp_unicast_rcv_skb(sk, skb, uh);
> - sock_put(sk);
> + if (refcounted)
> + sock_put(sk);
> return ret;
> }
>
> diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
> index 9fec580c968e..3d34e00124ff 100644
> --- a/net/ipv6/udp.c
> +++ b/net/ipv6/udp.c
> @@ -844,6 +844,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> struct net *net = dev_net(skb->dev);
> struct udphdr *uh;
> struct sock *sk;
> + bool refcounted;
> u32 ulen = 0;
>
> if (!pskb_may_pull(skb, sizeof(struct udphdr)))
> @@ -880,7 +881,7 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> goto csum_error;
>
> /* Check if the socket is already available, e.g. due to early demux */
> - sk = skb_steal_sock(skb);
> + sk = skb_steal_sock(skb, &refcounted);
> if (sk) {
> struct dst_entry *dst = skb_dst(skb);
> int ret;
> @@ -889,12 +890,14 @@ int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
> udp6_sk_rx_dst_set(sk, dst);
>
> if (!uh->check && !udp_sk(sk)->no_check6_rx) {
> - sock_put(sk);
> + if (refcounted)
> + sock_put(sk);
> goto report_csum_error;
> }
>
> ret = udp6_unicast_rcv_skb(sk, skb, uh);
> - sock_put(sk);
> + if (refcounted)
> + sock_put(sk);
> return ret;
> }
>
>
More information about the kernel-team mailing list