基于连接的每IP限速实现

在《 修改netfilter的limit模块实现基于单个ip的流量监控》中,介绍了一种方式实现针对一个网段每个IP地址的流量控制,如果细化到流,那个就叫做针对每个流的流量控制,我们知道,一个IP地址可以和很多流相关联,针对流的流控限制的不是主机,而是主机上的一个连接,它的约束要比针对IP地址的流控更加小。
然而如何来实现这个呢?实际上在Linux中,几乎所有的流控都可以用TC工具配置出来,然而还有一种方式,那就是使用Netfilter来实现,然后用iptables来配置,这正是体现了Netfilter框架的灵活和强大,当然使用TC也未尝不可,只是TC虽强大,然则功能比较单一,不像Netfilter一样可以扩展到几乎无限制的应用场合。
实现很简单,还是修改limit模块,这次连iptables模块都不用写了,这是改变了iptables对应模块的语义,源代码如下所示:

/* (C) 1999 Jérôme de Vivie 
 * (C) 1999 Hervé Eychenne 
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include 
#include 
#include 
#include 

#include 
#include 

#include 
#include 

//针对每一个方向给一个流量约束
struct xt_limit_priv {
    unsigned long prev[2];
    uint32_t credit[2];
};

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Herve Eychenne ");
MODULE_DESCRIPTION("Xtables: rate-limit match");
MODULE_ALIAS("ipt_limit");
MODULE_ALIAS("ip6t_limit");

/* The algorithm used is the Simple Token Bucket Filter (TBF)
 * see net/sched/sch_tbf.c in the linux source tree
 */

static DEFINE_SPINLOCK(limit_lock);

/* Rusty: This is my (non-mathematically-inclined) understanding of
   this algorithm.  The `average rate' in jiffies becomes your initial
   amount of credit `credit' and the most credit you can ever have
   `credit_cap'.  The `peak rate' becomes the cost of passing the
   test, `cost'.

   `prev' tracks the last packet hit: you gain one credit per jiffy.
   If you get credit balance more than this, the extra credit is
   discarded.  Every time the match passes, you lose `cost' credits;
   if you don't have that many, the test fails.

   See Alexey's formal explanation in net/sched/sch_tbf.c.

   To get the maxmum range, we multiply by this factor (ie. you get N
   credits per jiffy).  We want to allow a rate as low as 1 per day
   (slowest userspace tool allows), which means
   CREDITS_PER_JIFFY*HZ*60*60*24 < 2^32. ie. */
#define MAX_CPJ (0xFFFFFFFF / (HZ*60*60*24))

/* Repeated shift and or gives us all 1s, final shift and add 1 gives
 * us the power of 2 below the theoretical max, so GCC simply does a
 * shift. */
#define _POW2_BELOW2(x) ((x)|((x)>>1))
#define _POW2_BELOW4(x) (_POW2_BELOW2(x)|_POW2_BELOW2((x)>>2))
#define _POW2_BELOW8(x) (_POW2_BELOW4(x)|_POW2_BELOW4((x)>>4))
#define _POW2_BELOW16(x) (_POW2_BELOW8(x)|_POW2_BELOW8((x)>>8))
#define _POW2_BELOW32(x) (_POW2_BELOW16(x)|_POW2_BELOW16((x)>>16))
#define POW2_BELOW32(x) ((_POW2_BELOW32(x)>>1) + 1)

#define CREDITS_PER_JIFFY POW2_BELOW32(MAX_CPJ)

//定义一个conntrack的extension
static struct nf_ct_ext_type limit_extend __read_mostly = {
    .len    = sizeof(struct xt_limit_priv),
    .align    = __alignof__(struct xt_limit_priv),
    .id    = NF_CT_EXT_LIMIT,
};
static u_int32_t user2credits(u_int32_t user);

static bool
limit_mt(const struct sk_buff *skb, const struct xt_match_param *par)
{
    const struct xt_rateinfo *r = par->matchinfo;
    struct nf_conn *ct;
    enum ip_conntrack_info ctinfo;
    struct xt_limit_priv *priv;
    unsigned long now = jiffies;
    int dir = 1;
    

    ct = nf_ct_get(skb, &ctinfo);
    priv = nf_ct_ext_find(ct, NF_CT_EXT_LIMIT);
    if(priv == NULL) {
        if (nf_ct_is_confirmed(ct))
            return false;
        priv = nf_ct_ext_add(ct, NF_CT_EXT_LIMIT, GFP_ATOMIC);
        if (priv == NULL) {
            printk("failed to add LIMIT extension\n");
            return false;
        }
        priv->prev[0] = priv->prev[1] = jiffies;
        priv->credit[0] = priv->credit[1] = user2credits(r->avg * r->burst); /* Credits full. */
    }
    //和DIR相关的元素保存在skb中而不是conntrack中,这样可以最小化锁的开销,因为一个流的数据包的方向是双向的,何时到来并不清楚,如果在conntrack中保存方向,将无法实现两个方向的并行处理。
    dir = CTINFO2DIR(ctinfo) == IP_CT_DIR_ORIGINAL ? 1 : 0;
    
    spin_lock_bh(&limit_lock);
    priv->credit[dir] += (now - xchg(&priv->prev[dir], now)) * CREDITS_PER_JIFFY;
    if (priv->credit[dir] > r->credit_cap)
        priv->credit[dir] = r->credit_cap;

    if (priv->credit[dir] >= r->cost) {

        /* We're not limited. */
        priv->credit[dir] -= skb->len;
        spin_unlock_bh(&limit_lock);
        return true;
    }

    spin_unlock_bh(&limit_lock);
    return false;

}

/* Precision saver. */
static u_int32_t
user2credits(u_int32_t user)
{
    /* If multiplying would overflow... */
    if (user > 0xFFFFFFFF / (HZ*CREDITS_PER_JIFFY))
        /* Divide first. */
        return (user / XT_LIMIT_SCALE) * HZ * CREDITS_PER_JIFFY;

    return (user * HZ * CREDITS_PER_JIFFY) / XT_LIMIT_SCALE;
}

static bool limit_mt_check(const struct xt_mtchk_param *par)
{
    struct xt_rateinfo *r = par->matchinfo;
    struct xt_limit_priv *priv;

    /* Check for overflow. */
    if (r->burst == 0
        || user2credits(r->avg * r->burst) < user2credits(r->avg)) {
        printk("Overflow in xt_limit, try lower: %u/%u\n",
               r->avg, r->burst);
        return false;
    }

    priv = kmalloc(sizeof(*priv), GFP_KERNEL);
    if (priv == NULL)
        return false;

    /* For SMP, we only want to use one set of state. */
    r->master = priv;
    if (r->cost == 0) {
        /* User avg in seconds * XT_LIMIT_SCALE: convert to jiffies *
           128. */
        priv->prev[0] = priv->prev[1] = jiffies;
        priv->credit[0] = priv->credit[1] =  user2credits(r->avg * r->burst); /* Credits full. */
        r->credit_cap = user2credits(r->avg * r->burst); /* Credits full. */
        r->cost = user2credits(r->avg);
    }
    return true;
}

static void limit_mt_destroy(const struct xt_mtdtor_param *par)
{
    const struct xt_rateinfo *info = par->matchinfo;

    kfree(info->master);
}

#ifdef CONFIG_COMPAT
struct compat_xt_rateinfo {
    u_int32_t avg;
    u_int32_t burst;

    compat_ulong_t prev;
    u_int32_t credit;
    u_int32_t credit_cap, cost;

    u_int32_t master;
};

/* To keep the full "prev" timestamp, the upper 32 bits are stored in the
 * master pointer, which does not need to be preserved. */
static void limit_mt_compat_from_user(void *dst, void *src)
{
    const struct compat_xt_rateinfo *cm = src;
    struct xt_rateinfo m = {
        .avg        = cm->avg,
        .burst        = cm->burst,
        .prev        = cm->prev | (unsigned long)cm->master << 32,
        .credit        = cm->credit,
        .credit_cap    = cm->credit_cap,
        .cost        = cm->cost,
    };
    memcpy(dst, &m, sizeof(m));
}

static int limit_mt_compat_to_user(void __user *dst, void *src)
{
    const struct xt_rateinfo *m = src;
    struct compat_xt_rateinfo cm = {
        .avg        = m->avg,
        .burst        = m->burst,
        .prev        = m->prev,
        .credit        = m->credit,
        .credit_cap    = m->credit_cap,
        .cost        = m->cost,
        .master        = m->prev >> 32,
    };
    return copy_to_user(dst, &cm, sizeof(cm)) ? -EFAULT : 0;
}
#endif /* CONFIG_COMPAT */

static struct xt_match limit_mt_reg __read_mostly = {
    .name             = "limit",
    .revision         = 0,
    .family           = NFPROTO_UNSPEC,
    .match            = limit_mt,
    .checkentry       = limit_mt_check,
    .destroy          = limit_mt_destroy,
    .matchsize        = sizeof(struct xt_rateinfo),
#ifdef CONFIG_COMPAT
    .compatsize       = sizeof(struct compat_xt_rateinfo),
    .compat_from_user = limit_mt_compat_from_user,
    .compat_to_user   = limit_mt_compat_to_user,
#endif
    .me               = THIS_MODULE,
};

static int __init limit_mt_init(void)
{
    int rv = xt_register_match(&limit_mt_reg);
    if (rv < 0) {
        return rv;
    }
    return nf_ct_extend_register(&limit_extend);
}

static void __exit limit_mt_exit(void)
{
    xt_unregister_match(&limit_mt_reg);
}

module_init(limit_mt_init);
module_exit(limit_mt_exit);

在该实现中,重要的是对ip_conntrack的extension的应用,如果每次都在结构体里面增加字段,那种实现太蹩脚了,ip_conntrack在设计之初就有可扩展性,那就是最后由一个ext字段可以供你来增加你自己的数据,类似其它结构体的private字段,类似的也有一个0元素的数组,直接调用add/find接口,无需对核心结构体动手术。
最后,在使用的时候,别忘了单位。原本的limit模块使用包计数,修改后的使用字节计数。

你可能感兴趣的:(基于连接的每IP限速实现)