/*
 * Illustration of the idea: a structure that appears to wrap an existing
 * extension structure and append variable-length data after it.
 * NOTE(review): as written, "struct orig_ext;" declares no member (it only
 * mentions the struct tag); a named member such as "struct orig_ext orig;"
 * was presumably intended. "info[0]" is the GNU zero-length-array idiom; the
 * C99 spelling is the flexible array member "info[]".
 */
struct my_ext {
struct orig_ext;
char info[0];
};
...
/*
 * Logical ("index of index") identifiers: each names one kind of data kept
 * in the extension area. IDX_IDX_NUM is the number of identifiers.
 */
enum idx_idx {
	ROUTE = 0,
	SOCKET = 1,
	AND_SO_ON = 2,
	IDX_IDX_NUM = 3
};
然后定义一个数组来标识真正的索引:
/* Records, for each logical identifier of enum idx_idx, the real slot index
 * in the extension area (the "index of indexes" array). */
int idx[IDX_IDX_NUM];
再定义一个bitmap来表示各slot的使用情况即可,具体做法见下面的代码,一目了然。
修改include/net/netfilter/nf_conntrack_extend.h:
--- nf_conntrack_extend.h.orig 2014-03-29 12:55:26.000000000 +0800
+++ nf_conntrack_extend.h 2015-01-15 17:28:39.000000000 +0800
@@ -3,13 +3,17 @@
#include
+#define NFCT_EXT_EXT
+
enum nf_ct_ext_id
{
NF_CT_EXT_HELPER,
NF_CT_EXT_NAT,
NF_CT_EXT_ACCT,
NF_CT_EXT_ECACHE,
- NF_CT_EXT_NEW,
+#ifdef NFCT_EXT_EXT
+ NF_CT_EXT_EXT,
+#endif
NF_CT_EXT_NUM,
};
@@ -17,13 +21,21 @@
#define NF_CT_EXT_NAT_TYPE struct nf_conn_nat
#define NF_CT_EXT_ACCT_TYPE struct nf_conn_counter
#define NF_CT_EXT_ECACHE_TYPE struct nf_conntrack_ecache
-#define NF_CT_EXT_NEW_TYPE struct nf_conntrack_new
+#ifdef NFCT_EXT_EXT
+#define NF_CT_EXT_EXT_TYPE struct nf_conntrack_ext
+#endif
/* Extensions: optional stuff which isn't permanently in struct. */
struct nf_ct_ext {
struct rcu_head rcu;
+#ifdef NFCT_EXT_EXT
+ /* 内存不再是个事儿 */
+ u16 offset[NF_CT_EXT_NUM];
+ u16 len;
+#else
u8 offset[NF_CT_EXT_NUM];
u8 len;
+#endif
char data[0];
};
@@ -80,10 +92,18 @@
unsigned int flags;
/* Length and min alignment. */
+#ifdef NFCT_EXT_EXT
+ /* 内存不再是个事儿 */
+ u16 len;
+ u16 align;
+ /* initial size of nf_ct_ext. */
+ u16 alloc_size;
+#else
u8 len;
u8 align;
/* initial size of nf_ct_ext. */
u8 alloc_size;
+#endif
};
int nf_ct_extend_register(struct nf_ct_ext_type *type);
增加include/net/netfilter/nf_conntrack_ext.h:
/*
* (C) 2015 marywangran
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation.
*/
#ifndef _NF_CONNTRACK_EXT_H
#define _NF_CONNTRACK_EXT_H
#include
#include
#include
#include
#include
/* Number of pointer slots in each conntrack's NF_CT_EXT_EXT area. */
#define MAX_EXT_SLOTS 8
/*
 * Number of unsigned long words backing the slot-occupancy bitmap;
 * BITINT * BITS_PER_LONG must be >= MAX_EXT_SLOTS.
 */
#define BITINT 1

struct nf_conntrack_ext {
	/*
	 * Caller-maintained map from a caller-defined logical index to the slot
	 * index returned by nf_ct_exts_add(); -1 means "not assigned yet".
	 * Some such array is required so callers can find their data again
	 * (introspection).
	 */
	int bits_idx[MAX_EXT_SLOTS];
	/*
	 * Slot-occupancy bitmap. Declared as unsigned long because the kernel
	 * bitops applied to it (find_first_zero_bit/set_bit/test_bit/clear_bit)
	 * operate on unsigned long words; an int array is the wrong type and
	 * is too small for a full word on 64-bit targets.
	 */
	unsigned long bits[BITINT];
	/* Stored pointers; slot[i] is meaningful only while bit i is set. */
	char *slot[MAX_EXT_SLOTS];
};

int nf_ct_exts_add(const struct nf_conn *ct, void *ext);
void *nf_ct_exts_get(const struct nf_conn *ct, int idx);
void nf_ct_exts_remove(const struct nf_conn *ct, int idx);
struct nf_conntrack_ext *nf_conn_exts_find(const struct nf_conn *ct);
struct nf_conntrack_ext *nf_conn_exts_add(struct nf_conn *ct, gfp_t gfp);
extern int nf_conntrack_exts_init(void);
extern void nf_conntrack_exts_fini(void);
#endif /* _NF_CONNTRACK_EXT_H */
修改net/netfilter/nf_conntrack_core.c:
--- nf_conntrack_core.c.orig 2014-03-29 13:00:17.000000000 +0800
+++ nf_conntrack_core.c 2015-01-15 17:01:28.000000000 +0800
@@ -42,6 +42,10 @@
#include
#include
#include
+#ifdef NFCT_EXT_EXT
+/* 引入extend的extend头文件 */
+#include
+#endif
#include
#include
@@ -644,8 +648,11 @@
}
nf_ct_acct_ext_add(ct, GFP_ATOMIC);
-
nf_ct_ecache_ext_add(ct, GFP_ATOMIC);
+#ifdef NFCT_EXT_EXT
+ /* 在创建conntrack的时候初始化extend的extend */
+ nf_conn_exts_add(ct, GFP_ATOMIC);
+#endif
spin_lock_bh(&nf_conntrack_lock);
exp = nf_ct_find_expectation(net, tuple);
@@ -1130,6 +1137,10 @@
nf_ct_free_hashtable(net->ct.hash, net->ct.hash_vmalloc,
net->ct.htable_size);
+#ifdef NFCT_EXT_EXT
+ /* 析构extend的extend */
+ nf_conntrack_exts_fini();
+#endif
nf_conntrack_ecache_fini(net);
nf_conntrack_acct_fini(net);
nf_conntrack_expect_fini(net);
@@ -1344,9 +1355,19 @@
ret = nf_conntrack_ecache_init(net);
if (ret < 0)
goto err_ecache;
+#ifdef NFCT_EXT_EXT
+ /* 注册extend的extend */
+ ret = nf_conntrack_exts_init();
+ if (ret < 0)
+ goto err_exts;
+#endif
return 0;
+#ifdef NFCT_EXT_EXT
+err_exts:
+ nf_conntrack_ecache_fini(net);
+#endif
err_ecache:
nf_conntrack_acct_fini(net);
err_acct:
增加net/netfilter/nf_conntrack_ext.c:
/* conntrack扩展的扩展实现文件. */
/*
* conntrack扩展的扩展实现文件.
* 技术核心:
* 1.位图
* 2.索引的索引数组(外部维护的一个‘蓝图’)
* (C) 2015 marywangran
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License version 2 as
* published by the Free Software Foundation.
*/
#include
#include
#include
/*
 * Single global lock serializing access to the bitmap and slot array of
 * every conntrack's NF_CT_EXT_EXT area.
 * TODO: this lock should be bound to each ext instead of being global!
 */
static DEFINE_SPINLOCK(nfct_ext_lock);
/*
 * Extension type descriptor for the "extension of extensions" area: one
 * struct nf_conntrack_ext per conntrack. Registered by
 * nf_conntrack_exts_init() and attached to a conntrack by nf_conn_exts_add().
 * NOTE(review): no destroy callback is set, so nothing releases objects
 * still referenced from the slots when a conntrack goes away.
 */
static struct nf_ct_ext_type ext_extend __read_mostly = {
.len = sizeof(struct nf_conntrack_ext),
.align = __alignof__(struct nf_conntrack_ext),
.id = NF_CT_EXT_EXT,
.flags = NF_CT_EXT_F_PREALLOC,
};
/*
 * nf_ct_exts_add - store an opaque pointer in the first free slot of the
 * NF_CT_EXT_EXT area of @ct.
 * @ct:  conntrack whose extension area is used
 * @ext: pointer to store; NULL is rejected
 *
 * Returns the slot index (0 .. MAX_EXT_SLOTS - 1) on success, or -1 if @ext
 * is NULL, @ct has no NF_CT_EXT_EXT area, or every slot is in use.
 * The caller must remember the returned index itself (e.g. in
 * exts->bits_idx[]) to fetch the pointer later. The pointer is stored as-is;
 * no reference is taken on the object it points to.
 */
int nf_ct_exts_add(const struct nf_conn *ct, void *ext)
{
	int ret_idx = -1;
	struct nf_conntrack_ext *exts;

	if (!ext)
		goto out;
	exts = nf_conn_exts_find(ct);
	if (!exts)
		goto out;

	/*
	 * _bh variant: netfilter hooks call these helpers both from process
	 * context (locally generated traffic) and from softirq (received
	 * traffic); with a plain spin_lock a softirq interrupting a lock holder
	 * on the same CPU would deadlock.
	 */
	spin_lock_bh(&nfct_ext_lock);
	ret_idx = find_first_zero_bit(exts->bits, MAX_EXT_SLOTS);
	/*
	 * find_first_zero_bit() returns MAX_EXT_SLOTS when no bit is free, so
	 * that value must be rejected before it is used to index slot[].
	 */
	if (ret_idx >= MAX_EXT_SLOTS || exts->slot[ret_idx]) {
		ret_idx = -1;
		goto unlock;
	}
	set_bit(ret_idx, exts->bits);
	exts->slot[ret_idx] = (char *)ext;
unlock:
	spin_unlock_bh(&nfct_ext_lock);
out:
	return ret_idx;
}
EXPORT_SYMBOL(nf_ct_exts_add);
/*
 * nf_ct_exts_get - fetch the pointer stored in slot @idx of @ct.
 * @ct:  conntrack owning the NF_CT_EXT_EXT area
 * @idx: slot index previously returned by nf_ct_exts_add()
 *
 * Returns the stored pointer, or NULL if @idx is outside
 * 0 .. MAX_EXT_SLOTS - 1, @ct has no NF_CT_EXT_EXT area, or the slot is free.
 */
void *nf_ct_exts_get(const struct nf_conn *ct, int idx)
{
	char *ret = NULL;
	struct nf_conntrack_ext *exts;

	/* Valid slots are 0 .. MAX_EXT_SLOTS - 1; MAX_EXT_SLOTS itself is out of range. */
	if (idx < 0 || idx >= MAX_EXT_SLOTS)
		goto out;
	exts = nf_conn_exts_find(ct);
	if (!exts)
		goto out;

	/* _bh: callers may run in process context as well as in softirq. */
	spin_lock_bh(&nfct_ext_lock);
	if (test_bit(idx, exts->bits))
		ret = exts->slot[idx];
	spin_unlock_bh(&nfct_ext_lock);
out:
	return ret;
}
EXPORT_SYMBOL(nf_ct_exts_get);
/*
 * nf_ct_exts_remove - free slot @idx of @ct.
 * @ct:  conntrack owning the NF_CT_EXT_EXT area
 * @idx: slot index previously returned by nf_ct_exts_add()
 *
 * Clears the occupancy bit and the stored pointer. Out-of-range indexes, a
 * missing area and already-free slots are silently ignored. Only the slot is
 * released; the object the pointer referred to is not touched, so any
 * reference the caller took on it must be dropped by the caller.
 */
void nf_ct_exts_remove(const struct nf_conn *ct, int idx)
{
	struct nf_conntrack_ext *exts;

	/* Valid slots are 0 .. MAX_EXT_SLOTS - 1. */
	if (idx < 0 || idx >= MAX_EXT_SLOTS)
		return;
	exts = nf_conn_exts_find(ct);
	if (!exts)
		return;

	/* _bh: callers may run in process context as well as in softirq. */
	spin_lock_bh(&nfct_ext_lock);
	if (test_bit(idx, exts->bits)) {
		clear_bit(idx, exts->bits);
		exts->slot[idx] = NULL;
	}
	spin_unlock_bh(&nfct_ext_lock);
}
EXPORT_SYMBOL(nf_ct_exts_remove);
/*
 * nf_conn_exts_find - look up the NF_CT_EXT_EXT area attached to @ct.
 * Callers treat a NULL result as "no extension area on this conntrack".
 */
struct nf_conntrack_ext *nf_conn_exts_find(const struct nf_conn *ct)
{
	struct nf_conntrack_ext *area;

	area = nf_ct_ext_find(ct, NF_CT_EXT_EXT);
	return area;
}
EXPORT_SYMBOL(nf_conn_exts_find);
/*
 * nf_conn_exts_add - attach an empty NF_CT_EXT_EXT area to @ct.
 * @ct:  conntrack being created
 * @gfp: allocation flags passed through to nf_ct_ext_add()
 *
 * On success every slot starts free: no bit set in the occupancy bitmap,
 * every slot pointer NULL and every bits_idx[] entry -1 (the "not assigned"
 * marker callers test for). Returns the new area, or NULL if nf_ct_ext_add()
 * failed.
 */
struct nf_conntrack_ext *nf_conn_exts_add(struct nf_conn *ct, gfp_t gfp)
{
	struct nf_conntrack_ext *exts;
	int i;

	exts = nf_ct_ext_add(ct, NF_CT_EXT_EXT, gfp);
	if (!exts) {
		printk(KERN_WARNING "failed to add extensions area\n");
		return NULL;
	}

	/*
	 * Clear the occupancy bitmap explicitly rather than relying on the
	 * extension allocator having zeroed the area, so the bitmap always
	 * agrees with the NULL slot pointers set below.
	 */
	for (i = 0; i < BITINT; i++)
		exts->bits[i] = 0;
	for (i = 0; i < MAX_EXT_SLOTS; i++) {
		exts->bits_idx[i] = -1;
		exts->slot[i] = NULL;
	}
	return exts;
}
EXPORT_SYMBOL(nf_conn_exts_add);
int nf_conntrack_exts_init()
{
int ret;
ret = nf_ct_extend_register(&ext_extend);
if (ret < 0) {
printk("nf_conntrack_ext: Unable to register extension\n");
goto out;
}
printk("nf_conntrack_ext: register extension OK\n");
return 0;
out:
return ret;
}
/*
 * nf_conntrack_exts_fini - unregister the NF_CT_EXT_EXT extension type
 * registered by nf_conntrack_exts_init().
 */
void nf_conntrack_exts_fini(void)
{
	nf_ct_extend_unregister(&ext_extend);
}
测试程序nf_conntrack_private_data_auto_save_restore.c:
#include
#include
#include
#include
MODULE_AUTHOR("marywangran");
MODULE_LICENSE("GPL");
/*
 * Logical ("index of index") positions in exts->bits_idx[]. Each bits_idx[]
 * entry records the slot index that nf_ct_exts_add() returned for that piece
 * of data, so it can later be fetched with nf_ct_exts_get().
 *
 * Some fixed index array for introspection is required; otherwise we fall
 * into an endless self-reference of "data - metadata - meta-metadata - ...".
 * (The author compares this to the problem of self-awareness: a being knows
 * something, knows that it knows it, knows that it knows that it knows it...)
 *
 * NUM must not exceed MAX_EXT_SLOTS, the length of bits_idx[].
 */
enum ext_idx_idx {
CONN_ORIG_ROUTE,
CONN_REPLY_ROUTE,
CONN_SOCK,
CONN_AND_SO_ON,
NUM
};
/*
 * Drop one reference on @sk. TCP sockets in TIME_WAIT are released with
 * inet_twsk_put(); every other socket with sock_put().
 */
static inline void
nf_ext_put_sock(struct sock *sk)
{
	if (sk->sk_protocol != IPPROTO_TCP || sk->sk_state != TCP_TIME_WAIT) {
		sock_put(sk);
		return;
	}
	inet_twsk_put(inet_twsk(sk));
}
/*
 * skb destructor installed by ipv4_conntrack_restore_sock(): detach the
 * socket from @skb, then drop the reference taken when it was attached.
 */
static void
nf_ext_destructor(struct sk_buff *skb)
{
	struct sock *owner = skb->sk;

	skb->sk = NULL;
	skb->destructor = NULL;
	if (owner == NULL)
		return;
	nf_ext_put_sock(owner);
}
/*
 * Hook (registered on LOCAL_OUT): cache the socket of an outgoing TCP/UDP
 * packet in its conntrack's extension area, once per conntrack. For TCP only
 * sockets in TCP_ESTABLISHED are cached. Always returns NF_ACCEPT.
 */
static unsigned int ipv4_conntrack_save_sock (unsigned int hooknum,
struct sk_buff *skb,
const struct net_device *in,
const struct net_device *out,
int (*okfn)(struct sk_buff *))
{
struct nf_conn *ct;
enum ip_conntrack_info ctinfo;
struct nf_conntrack_ext *exts;
ct = nf_ct_get(skb, &ctinfo);
if (!ct || ct == &nf_conntrack_untracked) {
goto out;
}
if ((ip_hdr(skb)->protocol != IPPROTO_UDP) &&
(ip_hdr(skb)->protocol != IPPROTO_TCP)) {
goto out;
}
exts = nf_conn_exts_find(ct);
if (exts) {
/* Cache the socket. Note: restoring the cached socket only really pays off on INPUT. */
if (exts->bits_idx[CONN_SOCK] == -1) {
if (skb->sk == NULL){
goto out;
}
if ((ip_hdr(skb)->protocol == IPPROTO_TCP) && skb->sk->sk_state != TCP_ESTABLISHED) {
goto out;
}
/*
 * NOTE(review): the socket pointer is stored without taking a reference,
 * so the socket can be freed while it is still cached;
 * ipv4_conntrack_restore_sock() later calls atomic_inc_not_zero() on it,
 * which is only safe if that memory is still a live struct sock. A proper
 * fix needs a reference here plus a release when the conntrack is
 * destroyed. Also, the -1 test above and this assignment are not atomic,
 * so two CPUs can both add and one slot is then lost.
 */
exts->bits_idx[CONN_SOCK] = nf_ct_exts_add(ct, skb->sk);
}
}
out:
return NF_ACCEPT;
}
/*
 * Hook (registered on POST_ROUTING and LOCAL_IN): cache the packet's route in
 * its conntrack's extension area. IP routing has no notion of direction, so
 * the route of each direction is cached in its own slot (CONN_ORIG_ROUTE /
 * CONN_REPLY_ROUTE). Always returns NF_ACCEPT.
 *
 * A reference is taken on the cached dst with dst_hold(); it is dropped
 * again if no slot could be allocated.
 * NOTE(review): nothing releases the reference of a successfully cached dst
 * when the conntrack is destroyed. The -1 test and the assignment are also
 * not atomic across CPUs.
 */
static unsigned int ipv4_conntrack_save_dst(unsigned int hooknum,
					    struct sk_buff *skb,
					    const struct net_device *in,
					    const struct net_device *out,
					    int (*okfn)(struct sk_buff *))
{
	struct nf_conn *ct;
	enum ip_conntrack_info ctinfo;
	struct nf_conntrack_ext *exts;

	ct = nf_ct_get(skb, &ctinfo);
	if (!ct || ct == &nf_conntrack_untracked)
		goto out;

	exts = nf_conn_exts_find(ct);
	if (exts) {
		int dir = CTINFO2DIR(ctinfo);
		int idx = (dir == IP_CT_DIR_ORIGINAL) ?
			  CONN_ORIG_ROUTE : CONN_REPLY_ROUTE;

		if (exts->bits_idx[idx] == -1) {
			struct dst_entry *dst = skb_dst(skb);

			if (dst) {
				int slot;

				dst_hold(dst);
				slot = nf_ct_exts_add(ct, dst);
				/* No slot available: give back the reference
				 * taken above instead of leaking it. */
				if (slot < 0)
					dst_release(dst);
				exts->bits_idx[idx] = slot;
			}
		}
	}
out:
	return NF_ACCEPT;
}
/*
 * Hook (registered on LOCAL_IN): attach the socket cached by
 * ipv4_conntrack_save_sock() to an incoming TCP/UDP packet, taking a
 * reference on it and installing nf_ext_destructor to drop that reference.
 * For TCP the socket is only attached while in TCP_ESTABLISHED.
 * Always returns NF_ACCEPT.
 */
static unsigned int ipv4_conntrack_restore_sock (unsigned int hooknum,
struct sk_buff *skb,
const struct net_device *in,
const struct net_device *out,
int (*okfn)(struct sk_buff *))
{
struct nf_conn *ct;
enum ip_conntrack_info ctinfo;
struct nf_conntrack_ext *exts;
ct = nf_ct_get(skb, &ctinfo);
if (!ct || ct == &nf_conntrack_untracked){
goto out;
}
if ((ip_hdr(skb)->protocol != IPPROTO_UDP) &&
(ip_hdr(skb)->protocol != IPPROTO_TCP)) {
goto out;
}
exts = nf_conn_exts_find(ct);
if (exts) {
/* Fetch the cached socket. */
if (exts->bits_idx[CONN_SOCK] != -1) {
struct sock *sk = (struct sock *)nf_ct_exts_get(ct, exts->bits_idx[CONN_SOCK]);
if (sk) {
/*
 * NOTE(review): the cached pointer was stored without a reference, so
 * sk may already have been freed; reading sk_state and calling
 * atomic_inc_not_zero() below are only safe if the memory is still a
 * live struct sock.
 */
if ((ip_hdr(skb)->protocol == IPPROTO_TCP) && sk->sk_state != TCP_ESTABLISHED) {
goto out;
}
if (unlikely(!atomic_inc_not_zero(&sk->sk_refcnt))) {
goto out;
}
skb_orphan(skb);
skb->sk = sk;
/* A reference was taken with atomic_inc_not_zero() above; it must be put when
 * the skb is handed to its next owner (nf_ext_destructor does this). */
skb->destructor = nf_ext_destructor;
}
}
}
out:
return NF_ACCEPT;
}
/*
 * Hook (registered on PRE_ROUTING): attach the route cached for this
 * packet's direction by ipv4_conntrack_save_dst(), so the route no longer
 * has to be looked up. Always returns NF_ACCEPT.
 *
 * The cached dst is only attached when the skb has no route yet (e.g.
 * loopback packets already carry one): skb_dst_set() simply overwrites the
 * pointer, so replacing an existing dst would leak its reference.
 * NOTE(review): the cached dst is not revalidated here (e.g. after a routing
 * change); consider checking dst->obsolete for the target kernel version.
 */
static unsigned int ipv4_conntrack_restore_dst(unsigned int hooknum,
					       struct sk_buff *skb,
					       const struct net_device *in,
					       const struct net_device *out,
					       int (*okfn)(struct sk_buff *))
{
	struct nf_conn *ct;
	enum ip_conntrack_info ctinfo;
	struct nf_conntrack_ext *exts;

	ct = nf_ct_get(skb, &ctinfo);
	if (!ct || ct == &nf_conntrack_untracked)
		goto out;

	exts = nf_conn_exts_find(ct);
	if (exts) {
		int dir = CTINFO2DIR(ctinfo);
		int idx = (dir == IP_CT_DIR_ORIGINAL) ?
			  CONN_ORIG_ROUTE : CONN_REPLY_ROUTE;

		if (exts->bits_idx[idx] != -1) {
			struct dst_entry *dst = nf_ct_exts_get(ct, exts->bits_idx[idx]);

			if (dst && !skb_dst(skb)) {
				/* The skb owns its own reference. */
				dst_hold(dst);
				skb_dst_set(skb, dst);
			}
		}
	}
out:
	return NF_ACCEPT;
}
/*
 * Overall picture:
 *   OUTPUT (LOCAL_OUT):          cache the socket
 *   INPUT (LOCAL_IN):            restore the socket
 *
 *   POSTROUTING | INPUT:         cache the route
 *   PREROUTING:                  restore the route
 *
 * On LOCAL_IN, ipv4_conntrack_restore_sock uses priority
 * NF_IP_PRI_CONNTRACK + 2, so it runs after ipv4_conntrack_save_dst
 * (NF_IP_PRI_CONNTRACK + 1); lower priority values run first.
 */
static struct nf_hook_ops ipv4_conn_cache_ops[] __read_mostly = {
{
.hook = ipv4_conntrack_save_dst,
.owner = THIS_MODULE,
.pf = NFPROTO_IPV4,
.hooknum = NF_INET_POST_ROUTING,
.priority = NF_IP_PRI_CONNTRACK + 1,
},
{
.hook = ipv4_conntrack_save_sock,
.owner = THIS_MODULE,
.pf = NFPROTO_IPV4,
.hooknum = NF_INET_LOCAL_OUT,
.priority = NF_IP_PRI_CONNTRACK + 1,
},
{
.hook = ipv4_conntrack_save_dst,
.owner = THIS_MODULE,
.pf = NFPROTO_IPV4,
.hooknum = NF_INET_LOCAL_IN,
.priority = NF_IP_PRI_CONNTRACK + 1,
},
{
.hook = ipv4_conntrack_restore_sock,
.owner = THIS_MODULE,
.pf = NFPROTO_IPV4,
.hooknum = NF_INET_LOCAL_IN,
.priority = NF_IP_PRI_CONNTRACK + 2,
},
{
.hook = ipv4_conntrack_restore_dst,
.owner = THIS_MODULE,
.pf = NFPROTO_IPV4,
.hooknum = NF_INET_PRE_ROUTING,
.priority = NF_IP_PRI_CONNTRACK + 1,
},
};
/*
 * Module init: register every hook in ipv4_conn_cache_ops.
 * Returns 0 on success or the error reported by nf_register_hooks().
 */
static int __init cache_dst_and_sock_demo_init(void)
{
	return nf_register_hooks(ipv4_conn_cache_ops,
				 ARRAY_SIZE(ipv4_conn_cache_ops));
}
/* Module exit: unregister every hook registered by the init function. */
static void __exit cache_dst_and_sock_demo_fini(void)
{
	const unsigned int nr_ops = ARRAY_SIZE(ipv4_conn_cache_ops);

	nf_unregister_hooks(ipv4_conn_cache_ops, nr_ops);
}

module_init(cache_dst_and_sock_demo_init);
module_exit(cache_dst_and_sock_demo_fini);
在测试程序中,我缓存了路由项以及到达本机数据包的socket,这样仅仅查询到conntrack就可以直接将路由和socket取出来了,取值的过程由于存在索引数组和索引的索引数组,因此就是数组下标寻址,不再需要查询。