struct src_controler {
struct list_head src_list;
int curr; //当前一共有多少了entry
int max; //最多能有多少个entry
};
struct src_entry {
struct list_head list;
__u32 src_addr; //源地址
unsigned long prev; //上次的时间戳
unsigned long passed; //一秒内已经过去了多少数据
};
static int __init xt_limit_init(void)
{
int ret;
src_ctl = kmalloc(sizeof(struct src_controler), GFP_KERNEL); //初始化全局变量
memset(src_ctl, 0, sizeof(struct src_controler));
INIT_LIST_HEAD(&src_ctl->src_list); //初始化全局变量的链表
src_ctl->curr = 0;
src_ctl->max = 1000; //本应该通过模块参数传进来的,这里写死,毕竟是个测试版
ret = xt_register_match(&ipt_limit_reg);
if (ret)
return ret;
ret = xt_register_match(&limit6_reg);
if (ret)
xt_unregister_match(&ipt_limit_reg);
return ret;
}
static void __exit xt_limit_fini(void)
{
xt_unregister_match(&ipt_limit_reg);
xt_unregister_match(&limit6_reg);
//这里应该有一个清理链表的操作,测试版没有实现
}
static int
ipt_limit_match(const struct sk_buff *skb,
const struct net_device *in,
const struct net_device *out,
const struct xt_match *match,
const void *matchinfo,
int offset,
unsigned int protoff,
int *hotdrop)
{
struct xt_rateinfo *r = ((struct xt_rateinfo *)matchinfo)->master;
unsigned long now = jiffies, prev = 0;
struct list_head *lh;
struct src_entry *entry = NULL;
struct src_entry *find_entry;
unsigned long nowa;
struct iphdr *iph = skb->nh.iph;
__u32 this_addr = iph->saddr;
list_for_each(lh, &src_ctl->src_list) { //遍历链表,找到这个ip地址对应的entry
find_entry = list_entry(lh, struct src_entry, list);
if (this_addr == find_entry->src_addr) {
entry = find_entry;
break;
}
}
if (entry) { //如果找到,将其加在头,这样实现了一个简单的lru
prev = entry->prev;
list_del(&entry->list);
list_add(&entry->list, &src_ctl->src_list);
} else { //如果没有找到,看看能否添加
if (src_ctl->curr+1 < src_ctl->max) {
add_entry:
entry = kmalloc(sizeof(struct src_entry), GFP_KERNEL);
memset(entry, 0, sizeof(struct src_entry));
entry->src_addr = this_addr;
prev = entry->prev = now - 1000;
list_add(&entry->list, &src_ctl->src_list);
src_ctl->curr++; //正确做法是atomic_inc
} else { //如果已经满了,那么看看能否删除最后的那个不活动的entry
entry = list_entry(src_ctl->src_list.prev, struct src_entry, list);
if (now-entry->prev > 1000)
goto add_entry;
return 1;
}
}
nowa = entry->passed + skb->len;
if (now-prev < 1000) { //这里的1000其实应该是HZ变量的值,由于懒得引头文件了,直接写死了。如果距上次统计还没有到1秒,则累加数据,不匹配
entry->passed = nowa;
return 0;
} else {
entry->prev = now;
entry->passed = 0;
if (r->burst >= nowa) { //如果到达了1秒,则判断是否超限,如果超限,则匹配,没有超限则重置字段,不匹配
return 0;
} else {
return 1;
}
}
return -1; //不会到达这里
}
优化三:上述实现中,数据单位是字节,这样很不合理,应该是可以配置的才对,比如默认是字节,还可以是k,m,g等等。
优化四:应该实现一个机制,定期清理不活跃的entry,以防止内存占用率过高。
反思:为何在入口位置的流控不实现队列呢?我们还是要想想流控的目的是什么,其一就是避免拥塞-网络的拥塞以及主机上层缓冲区的拥塞,对于接收数据而言,无论如何,流量对到达此地之前的网络的影响已经发生了,对往后的网络的影响还没有发生,因此对于已经发生的影响,没有必要再去进行速率适配了,直接执行动作即可。
如果你真的还需要limit模块完成它本来的功能,那么就别改limit模块了,还是直接写一个为好,这样也更灵活,毕竟我们也就不需要再配置--limit 1/sec去迎合limit的语法了,具体方法参见《编写iptables模块实现不连续IP地址的DNAT-POOL》
修正:
如果同时下载多个局域网内的大文件,会发现上述的match回调函数工作的不是很好,速度并没有被限制住,这是因为我计时统计统计的粒度太粗,一秒统计一次,这一秒中,很多大包将溜过去,因此需要更细粒度的统计,那就是实时的统计,使用数据量/时间间隔这个除式来统计,代码如下:
static int
ipt_limit_match(const struct sk_buff *skb,
const struct net_device *in,
const struct net_device *out,
const struct xt_match *match,
const void *matchinfo,
int offset,
unsigned int protoff,
int *hotdrop)
{
struct xt_rateinfo *r = ((struct xt_rateinfo *)matchinfo)->master;
unsigned long now = jiffies, prev = 0;
struct list_head *lh;
struct src_entry *entry = NULL;
struct src_entry *find_entry;
unsigned long nowa;
unsigned long rate;
struct iphdr *iph = skb->nh.iph;
__u32 this_addr = iph->saddr;
list_for_each(lh, &src_ctl->src_list) {
find_entry = list_entry(lh, struct src_entry, list);
if (this_addr == find_entry->src_addr) {
entry = find_entry;
break;
}
}
if (entry) {
prev = entry->prev;
list_del(&entry->list);
list_add(&entry->list, &src_ctl->src_list);
} else {
if (src_ctl->curr+1 < src_ctl->max) {
add_entry:
entry = kmalloc(sizeof(struct src_entry), GFP_KERNEL);
memset(entry, 0, sizeof(struct src_entry));
entry->src_addr = this_addr;
prev = entry->prev = now - 1000;
list_add(&entry->list, &src_ctl->src_list);
src_ctl->curr++;
} else {
entry = list_entry(src_ctl->src_list.prev, struct src_entry, list);
if (now-entry->prev > 1000)
goto add_entry;
return 1;
}
}
nowa = entry->passed + skb->len;
entry->passed = nowa;
if (now-prev > 0) {
rate = entry->passed/(now-prev);
} else
rate = nowa;
entry->prev = now;
entry->passed = 0;
if (rate > r->burst) {
return 1;
}
return 0;
}