Merge git://git.kernel.org/pub/scm/linux/kernel/git/pablo/nf-next
[muen/linux.git] / net / netfilter / nf_conntrack_proto_tcp.c
index 247b89784a6fb41141bb20cc4d4e5987b33e17a7..1bcf9984d45e8601646cb2b99dc5f3113a5c8b0a 100644 (file)
@@ -717,35 +717,26 @@ static const u8 tcp_valid_flags[(TCPHDR_FIN|TCPHDR_SYN|TCPHDR_RST|TCPHDR_ACK|
        [TCPHDR_ACK|TCPHDR_URG]                 = 1,
 };
 
-static void tcp_error_log(const struct sk_buff *skb, struct net *net,
-                         u8 pf, const char *msg)
+static void tcp_error_log(const struct sk_buff *skb,
+                         const struct nf_hook_state *state,
+                         const char *msg)
 {
-       nf_l4proto_log_invalid(skb, net, pf, IPPROTO_TCP, "%s", msg);
+       nf_l4proto_log_invalid(skb, state->net, state->pf, IPPROTO_TCP, "%s", msg);
 }
 
 /* Protect conntrack agaist broken packets. Code taken from ipt_unclean.c.  */
-static int tcp_error(struct net *net, struct nf_conn *tmpl,
-                    struct sk_buff *skb,
-                    unsigned int dataoff,
-                    u_int8_t pf,
-                    unsigned int hooknum)
+static bool tcp_error(const struct tcphdr *th,
+                     struct sk_buff *skb,
+                     unsigned int dataoff,
+                     const struct nf_hook_state *state)
 {
-       const struct tcphdr *th;
-       struct tcphdr _tcph;
        unsigned int tcplen = skb->len - dataoff;
-       u_int8_t tcpflags;
-
-       /* Smaller that minimal TCP header? */
-       th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph);
-       if (th == NULL) {
-               tcp_error_log(skb, net, pf, "short packet");
-               return -NF_ACCEPT;
-       }
+       u8 tcpflags;
 
        /* Not whole TCP header or malformed packet */
        if (th->doff*4 < sizeof(struct tcphdr) || tcplen < th->doff*4) {
-               tcp_error_log(skb, net, pf, "truncated packet");
-               return -NF_ACCEPT;
+               tcp_error_log(skb, state, "truncated packet");
+               return true;
        }
 
        /* Checksum invalid? Ignore.
@@ -753,27 +744,101 @@ static int tcp_error(struct net *net, struct nf_conn *tmpl,
         * because the checksum is assumed to be correct.
         */
        /* FIXME: Source route IP option packets --RR */
-       if (net->ct.sysctl_checksum && hooknum == NF_INET_PRE_ROUTING &&
-           nf_checksum(skb, hooknum, dataoff, IPPROTO_TCP, pf)) {
-               tcp_error_log(skb, net, pf, "bad checksum");
-               return -NF_ACCEPT;
+       if (state->net->ct.sysctl_checksum &&
+           state->hook == NF_INET_PRE_ROUTING &&
+           nf_checksum(skb, state->hook, dataoff, IPPROTO_TCP, state->pf)) {
+               tcp_error_log(skb, state, "bad checksum");
+               return true;
        }
 
        /* Check TCP flags. */
        tcpflags = (tcp_flag_byte(th) & ~(TCPHDR_ECE|TCPHDR_CWR|TCPHDR_PSH));
        if (!tcp_valid_flags[tcpflags]) {
-               tcp_error_log(skb, net, pf, "invalid tcp flag combination");
-               return -NF_ACCEPT;
+               tcp_error_log(skb, state, "invalid tcp flag combination");
+               return true;
        }
 
-       return NF_ACCEPT;
+       return false;
+}
+
+static noinline bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb,
+                            unsigned int dataoff,
+                            const struct tcphdr *th)
+{
+       enum tcp_conntrack new_state;
+       struct net *net = nf_ct_net(ct);
+       const struct nf_tcp_net *tn = tcp_pernet(net);
+       const struct ip_ct_tcp_state *sender = &ct->proto.tcp.seen[0];
+       const struct ip_ct_tcp_state *receiver = &ct->proto.tcp.seen[1];
+
+       /* Don't need lock here: this conntrack not in circulation yet */
+       new_state = tcp_conntracks[0][get_conntrack_index(th)][TCP_CONNTRACK_NONE];
+
+       /* Invalid: delete conntrack */
+       if (new_state >= TCP_CONNTRACK_MAX) {
+               pr_debug("nf_ct_tcp: invalid new deleting.\n");
+               return false;
+       }
+
+       if (new_state == TCP_CONNTRACK_SYN_SENT) {
+               memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp));
+               /* SYN packet */
+               ct->proto.tcp.seen[0].td_end =
+                       segment_seq_plus_len(ntohl(th->seq), skb->len,
+                                            dataoff, th);
+               ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window);
+               if (ct->proto.tcp.seen[0].td_maxwin == 0)
+                       ct->proto.tcp.seen[0].td_maxwin = 1;
+               ct->proto.tcp.seen[0].td_maxend =
+                       ct->proto.tcp.seen[0].td_end;
+
+               tcp_options(skb, dataoff, th, &ct->proto.tcp.seen[0]);
+       } else if (tn->tcp_loose == 0) {
+               /* Don't try to pick up connections. */
+               return false;
+       } else {
+               memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp));
+               /*
+                * We are in the middle of a connection,
+                * its history is lost for us.
+                * Let's try to use the data from the packet.
+                */
+               ct->proto.tcp.seen[0].td_end =
+                       segment_seq_plus_len(ntohl(th->seq), skb->len,
+                                            dataoff, th);
+               ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window);
+               if (ct->proto.tcp.seen[0].td_maxwin == 0)
+                       ct->proto.tcp.seen[0].td_maxwin = 1;
+               ct->proto.tcp.seen[0].td_maxend =
+                       ct->proto.tcp.seen[0].td_end +
+                       ct->proto.tcp.seen[0].td_maxwin;
+
+               /* We assume SACK and liberal window checking to handle
+                * window scaling */
+               ct->proto.tcp.seen[0].flags =
+               ct->proto.tcp.seen[1].flags = IP_CT_TCP_FLAG_SACK_PERM |
+                                             IP_CT_TCP_FLAG_BE_LIBERAL;
+       }
+
+       /* tcp_packet will set them */
+       ct->proto.tcp.last_index = TCP_NONE_SET;
+
+       pr_debug("%s: sender end=%u maxend=%u maxwin=%u scale=%i "
+                "receiver end=%u maxend=%u maxwin=%u scale=%i\n",
+                __func__,
+                sender->td_end, sender->td_maxend, sender->td_maxwin,
+                sender->td_scale,
+                receiver->td_end, receiver->td_maxend, receiver->td_maxwin,
+                receiver->td_scale);
+       return true;
 }
 
 /* Returns verdict for packet, or -1 for invalid. */
 static int tcp_packet(struct nf_conn *ct,
-                     const struct sk_buff *skb,
+                     struct sk_buff *skb,
                      unsigned int dataoff,
-                     enum ip_conntrack_info ctinfo)
+                     enum ip_conntrack_info ctinfo,
+                     const struct nf_hook_state *state)
 {
        struct net *net = nf_ct_net(ct);
        struct nf_tcp_net *tn = tcp_pernet(net);
@@ -786,7 +851,14 @@ static int tcp_packet(struct nf_conn *ct,
        unsigned long timeout;
 
        th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph);
-       BUG_ON(th == NULL);
+       if (th == NULL)
+               return -NF_ACCEPT;
+
+       if (tcp_error(th, skb, dataoff, state))
+               return -NF_ACCEPT;
+
+       if (!nf_ct_is_confirmed(ct) && !tcp_new(ct, skb, dataoff, th))
+               return -NF_ACCEPT;
 
        spin_lock_bh(&ct->lock);
        old_state = ct->proto.tcp.state;
@@ -1067,82 +1139,6 @@ static int tcp_packet(struct nf_conn *ct,
        return NF_ACCEPT;
 }
 
-/* Called when a new connection for this protocol found. */
-static bool tcp_new(struct nf_conn *ct, const struct sk_buff *skb,
-                   unsigned int dataoff)
-{
-       enum tcp_conntrack new_state;
-       const struct tcphdr *th;
-       struct tcphdr _tcph;
-       struct net *net = nf_ct_net(ct);
-       struct nf_tcp_net *tn = tcp_pernet(net);
-       const struct ip_ct_tcp_state *sender = &ct->proto.tcp.seen[0];
-       const struct ip_ct_tcp_state *receiver = &ct->proto.tcp.seen[1];
-
-       th = skb_header_pointer(skb, dataoff, sizeof(_tcph), &_tcph);
-       BUG_ON(th == NULL);
-
-       /* Don't need lock here: this conntrack not in circulation yet */
-       new_state = tcp_conntracks[0][get_conntrack_index(th)][TCP_CONNTRACK_NONE];
-
-       /* Invalid: delete conntrack */
-       if (new_state >= TCP_CONNTRACK_MAX) {
-               pr_debug("nf_ct_tcp: invalid new deleting.\n");
-               return false;
-       }
-
-       if (new_state == TCP_CONNTRACK_SYN_SENT) {
-               memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp));
-               /* SYN packet */
-               ct->proto.tcp.seen[0].td_end =
-                       segment_seq_plus_len(ntohl(th->seq), skb->len,
-                                            dataoff, th);
-               ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window);
-               if (ct->proto.tcp.seen[0].td_maxwin == 0)
-                       ct->proto.tcp.seen[0].td_maxwin = 1;
-               ct->proto.tcp.seen[0].td_maxend =
-                       ct->proto.tcp.seen[0].td_end;
-
-               tcp_options(skb, dataoff, th, &ct->proto.tcp.seen[0]);
-       } else if (tn->tcp_loose == 0) {
-               /* Don't try to pick up connections. */
-               return false;
-       } else {
-               memset(&ct->proto.tcp, 0, sizeof(ct->proto.tcp));
-               /*
-                * We are in the middle of a connection,
-                * its history is lost for us.
-                * Let's try to use the data from the packet.
-                */
-               ct->proto.tcp.seen[0].td_end =
-                       segment_seq_plus_len(ntohl(th->seq), skb->len,
-                                            dataoff, th);
-               ct->proto.tcp.seen[0].td_maxwin = ntohs(th->window);
-               if (ct->proto.tcp.seen[0].td_maxwin == 0)
-                       ct->proto.tcp.seen[0].td_maxwin = 1;
-               ct->proto.tcp.seen[0].td_maxend =
-                       ct->proto.tcp.seen[0].td_end +
-                       ct->proto.tcp.seen[0].td_maxwin;
-
-               /* We assume SACK and liberal window checking to handle
-                * window scaling */
-               ct->proto.tcp.seen[0].flags =
-               ct->proto.tcp.seen[1].flags = IP_CT_TCP_FLAG_SACK_PERM |
-                                             IP_CT_TCP_FLAG_BE_LIBERAL;
-       }
-
-       /* tcp_packet will set them */
-       ct->proto.tcp.last_index = TCP_NONE_SET;
-
-       pr_debug("tcp_new: sender end=%u maxend=%u maxwin=%u scale=%i "
-                "receiver end=%u maxend=%u maxwin=%u scale=%i\n",
-                sender->td_end, sender->td_maxend, sender->td_maxwin,
-                sender->td_scale,
-                receiver->td_end, receiver->td_maxend, receiver->td_maxwin,
-                receiver->td_scale);
-       return true;
-}
-
 static bool tcp_can_early_drop(const struct nf_conn *ct)
 {
        switch (ct->proto.tcp.state) {
@@ -1510,7 +1506,7 @@ static int tcp_kmemdup_sysctl_table(struct nf_proto_net *pn,
        return 0;
 }
 
-static int tcp_init_net(struct net *net, u_int16_t proto)
+static int tcp_init_net(struct net *net)
 {
        struct nf_tcp_net *tn = tcp_pernet(net);
        struct nf_proto_net *pn = &tn->pn;
@@ -1538,16 +1534,13 @@ static struct nf_proto_net *tcp_get_net_proto(struct net *net)
        return &net->ct.nf_ct_proto.tcp.pn;
 }
 
-const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp4 =
+const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp =
 {
-       .l3proto                = PF_INET,
        .l4proto                = IPPROTO_TCP,
 #ifdef CONFIG_NF_CONNTRACK_PROCFS
        .print_conntrack        = tcp_print_conntrack,
 #endif
        .packet                 = tcp_packet,
-       .new                    = tcp_new,
-       .error                  = tcp_error,
        .can_early_drop         = tcp_can_early_drop,
 #if IS_ENABLED(CONFIG_NF_CT_NETLINK)
        .to_nlattr              = tcp_to_nlattr,
@@ -1571,39 +1564,3 @@ const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp4 =
        .init_net               = tcp_init_net,
        .get_net_proto          = tcp_get_net_proto,
 };
-EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_tcp4);
-
-const struct nf_conntrack_l4proto nf_conntrack_l4proto_tcp6 =
-{
-       .l3proto                = PF_INET6,
-       .l4proto                = IPPROTO_TCP,
-#ifdef CONFIG_NF_CONNTRACK_PROCFS
-       .print_conntrack        = tcp_print_conntrack,
-#endif
-       .packet                 = tcp_packet,
-       .new                    = tcp_new,
-       .error                  = tcp_error,
-       .can_early_drop         = tcp_can_early_drop,
-#if IS_ENABLED(CONFIG_NF_CT_NETLINK)
-       .nlattr_size            = TCP_NLATTR_SIZE,
-       .to_nlattr              = tcp_to_nlattr,
-       .from_nlattr            = nlattr_to_tcp,
-       .tuple_to_nlattr        = nf_ct_port_tuple_to_nlattr,
-       .nlattr_to_tuple        = nf_ct_port_nlattr_to_tuple,
-       .nlattr_tuple_size      = tcp_nlattr_tuple_size,
-       .nla_policy             = nf_ct_port_nla_policy,
-#endif
-#ifdef CONFIG_NF_CONNTRACK_TIMEOUT
-       .ctnl_timeout           = {
-               .nlattr_to_obj  = tcp_timeout_nlattr_to_obj,
-               .obj_to_nlattr  = tcp_timeout_obj_to_nlattr,
-               .nlattr_max     = CTA_TIMEOUT_TCP_MAX,
-               .obj_size       = sizeof(unsigned int) *
-                                       TCP_CONNTRACK_TIMEOUT_MAX,
-               .nla_policy     = tcp_timeout_nla_policy,
-       },
-#endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
-       .init_net               = tcp_init_net,
-       .get_net_proto          = tcp_get_net_proto,
-};
-EXPORT_SYMBOL_GPL(nf_conntrack_l4proto_tcp6);