/**
 * Filter that limits (shapes) the read and write throughput of every session it is attached to.
 * <p>
 * For each session the filter keeps a {@link State}, stored as a session attribute. The state accumulates
 * the estimated number of bytes read/written since the start of the current measuring window. When the
 * average throughput of that window reaches the configured maximum (bytes per second), the session's reads
 * or writes are suspended. A task scheduled on {@link #getScheduledExecutor()} resumes them once enough
 * time has passed for the transferred bytes to fit under the limit. A maximum of {@code 0} disables shaping
 * in that direction.
 */
public class TrafficShapingFilter extends IoFilterAdapter {

    protected static final Logger log = LoggerFactory.getLogger(TrafficShapingFilter.class);

    /** Session attribute key under which the per-session {@link State} is stored. */
    private final AttributeKey STATE = new AttributeKey(getClass(), "state");

    /** Schedules the delayed resumption of suspended reads and writes. */
    private final ScheduledExecutorService scheduledExecutor;

    /**
     * {@code true} when {@link #scheduledExecutor} was created by this filter because the caller did not
     * supply one. Only then does {@link #setPoolSize(int)} resize it.
     */
    private final boolean ownsExecutor;

    /** Estimates the size of a message; the estimate is what is counted against the limits. */
    private final MessageSizeEstimator messageSizeEstimator;

    /** Maximum read throughput in bytes per second; 0 disables read shaping. */
    private volatile int maxReadThroughput;

    /** Maximum write throughput in bytes per second; 0 disables write shaping. */
    private volatile int maxWriteThroughput;

    /** Core pool size of the internally created executor (always at least 1). */
    private volatile int poolSize = 1;

    /**
     * Creates a filter with an internal single-threaded scheduler and the default size estimator.
     *
     * @param maxReadThroughput  maximum bytes read per second (values below 0 are treated as 0 = unlimited)
     * @param maxWriteThroughput maximum bytes written per second (values below 0 are treated as 0 = unlimited)
     */
    public TrafficShapingFilter(int maxReadThroughput, int maxWriteThroughput) {
        this(null, null, maxReadThroughput, maxWriteThroughput);
    }

    /**
     * Creates a filter with the given scheduler and the default size estimator.
     *
     * @param scheduledExecutor  executor used to resume suspended sessions; {@code null} creates an internal one
     * @param maxReadThroughput  maximum bytes read per second (values below 0 are treated as 0 = unlimited)
     * @param maxWriteThroughput maximum bytes written per second (values below 0 are treated as 0 = unlimited)
     */
    public TrafficShapingFilter(ScheduledExecutorService scheduledExecutor, int maxReadThroughput,
            int maxWriteThroughput) {
        this(scheduledExecutor, null, maxReadThroughput, maxWriteThroughput);
    }

    /**
     * Creates a filter.
     *
     * @param scheduledExecutor    executor used to resume suspended sessions. If {@code null}, the filter creates
     *                             an internal {@link ScheduledThreadPoolExecutor} with daemon threads, whose size
     *                             can be changed with {@link #setPoolSize(int)}.
     * @param messageSizeEstimator estimator for message sizes. If {@code null}, a default is used that counts
     *                             {@code IoBuffer.remaining()} for buffers and delegates to
     *                             {@code DefaultMessageSizeEstimator} otherwise.
     * @param maxReadThroughput    maximum bytes read per second (values below 0 are treated as 0 = unlimited)
     * @param maxWriteThroughput   maximum bytes written per second (values below 0 are treated as 0 = unlimited)
     */
    public TrafficShapingFilter(ScheduledExecutorService scheduledExecutor, MessageSizeEstimator messageSizeEstimator,
            int maxReadThroughput, int maxWriteThroughput) {
        log.debug("ctor - executor: {} estimator: {} max read: {} max write: {}", new Object[] { scheduledExecutor,
                messageSizeEstimator, maxReadThroughput, maxWriteThroughput });
        if (scheduledExecutor == null) {
            this.scheduledExecutor = createDefaultExecutor(poolSize);
            this.ownsExecutor = true;
        } else {
            this.scheduledExecutor = scheduledExecutor;
            this.ownsExecutor = false;
        }
        this.messageSizeEstimator = messageSizeEstimator != null ? messageSizeEstimator : createDefaultEstimator();
        setMaxReadThroughput(maxReadThroughput);
        setMaxWriteThroughput(maxWriteThroughput);
    }

    /**
     * Builds the internal scheduler. Its threads are daemon threads because nothing ever shuts this
     * executor down; non-daemon threads would otherwise keep the JVM alive.
     */
    private static ScheduledThreadPoolExecutor createDefaultExecutor(int corePoolSize) {
        return new ScheduledThreadPoolExecutor(corePoolSize, new java.util.concurrent.ThreadFactory() {
            @Override
            public Thread newThread(Runnable task) {
                Thread thread = new Thread(task, "TrafficShapingFilter-scheduler");
                thread.setDaemon(true);
                return thread;
            }
        });
    }

    /** Builds the default estimator: the remaining bytes of an IoBuffer, otherwise the generic estimate. */
    private static MessageSizeEstimator createDefaultEstimator() {
        return new DefaultMessageSizeEstimator() {
            @Override
            public int estimateSize(Object message) {
                if (message instanceof IoBuffer) {
                    return ((IoBuffer) message).remaining();
                }
                return super.estimateSize(message);
            }
        };
    }

    public ScheduledExecutorService getScheduledExecutor() {
        return scheduledExecutor;
    }

    public MessageSizeEstimator getMessageSizeEstimator() {
        return messageSizeEstimator;
    }

    /** @return maximum read throughput in bytes per second; 0 means unlimited */
    public int getMaxReadThroughput() {
        return maxReadThroughput;
    }

    /**
     * Sets the maximum read throughput.
     *
     * @param maxReadThroughput bytes per second; negative values are clamped to 0 (unlimited)
     */
    public void setMaxReadThroughput(int maxReadThroughput) {
        this.maxReadThroughput = Math.max(0, maxReadThroughput);
    }

    /** @return maximum write throughput in bytes per second; 0 means unlimited */
    public int getMaxWriteThroughput() {
        return maxWriteThroughput;
    }

    /**
     * Sets the maximum write throughput.
     *
     * @param maxWriteThroughput bytes per second; negative values are clamped to 0 (unlimited)
     */
    public void setMaxWriteThroughput(int maxWriteThroughput) {
        this.maxWriteThroughput = Math.max(0, maxWriteThroughput);
    }

    public int getPoolSize() {
        return poolSize;
    }

    /**
     * Sets the core pool size of the internally created scheduler. A caller-supplied executor is never
     * modified; in that case only the stored value changes.
     *
     * @param poolSize desired core pool size; values below 1 are clamped to 1
     */
    public void setPoolSize(int poolSize) {
        int size = Math.max(1, poolSize);
        this.poolSize = size;
        if (ownsExecutor) {
            ((ScheduledThreadPoolExecutor) scheduledExecutor).setCorePoolSize(size);
        }
    }

    /**
     * Attaches a fresh {@link State} to the session and caps its read buffer sizes at the read limit.
     *
     * @throws IllegalArgumentException if this instance is already in the chain
     */
    @Override
    public void onPreAdd(IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
        if (parent.contains(this)) {
            throw new IllegalArgumentException(
                    "You can't add the same filter instance more than once. Create another instance and add it.");
        }
        parent.getSession().setAttribute(STATE, new State());
        adjustReadBufferSize(parent.getSession());
    }

    /** Removes the per-session {@link State} once the filter has been removed from the chain. */
    @Override
    public void onPostRemove(IoFilterChain parent, String name, NextFilter nextFilter) throws Exception {
        parent.getSession().removeAttribute(STATE);
    }

    /**
     * Accounts the received message against the read limit and, if the limit is reached, suspends reads on
     * the session for the computed time. The message is always forwarded to the next filter.
     */
    @Override
    public void messageReceived(NextFilter nextFilter, final IoSession session, Object message) throws Exception {
        // Read the volatile once so the whole computation uses one consistent limit.
        int maxReadThroughput = this.maxReadThroughput;
        if (maxReadThroughput > 0) {
            final State state = (State) session.getAttribute(STATE);
            // Without per-session state there is nothing to account against; pass the message through.
            if (state != null) {
                long currentTime = System.currentTimeMillis();
                long suspendTime = 0;
                synchronized (state) {
                    state.readBytes += messageSizeEstimator.estimateSize(message);
                    if (!state.suspendedRead) {
                        boolean firstRead = false;
                        // Open a new window when none is open, or when the wall clock moved backwards past
                        // the window start (the elapsed time would otherwise be <= 0).
                        if (state.readStartTime == 0 || state.readStartTime >= currentTime) {
                            firstRead = true;
                            // Back-date the window by one second so the first measurement is bytes/second.
                            state.readStartTime = currentTime - 1000;
                        }
                        long elapsed = currentTime - state.readStartTime;
                        long throughput = state.readBytes * 1000 / elapsed;
                        if (throughput >= maxReadThroughput) {
                            // Time the bytes should have taken at the maximum rate, minus the time that actually
                            // elapsed (treated as 0 for a freshly opened window).
                            suspendTime = Math.max(0, state.readBytes * 1000 / maxReadThroughput
                                    - (firstRead ? 0 : elapsed));
                            state.readBytes = 0;
                            state.readStartTime = 0;
                            state.suspendedRead = suspendTime != 0;
                            adjustReadBufferSize(session);
                        }
                    }
                }
                if (suspendTime != 0) {
                    session.suspendRead();
                    scheduleResume(new Runnable() {
                        @Override
                        public void run() {
                            synchronized (state) {
                                state.suspendedRead = false;
                            }
                            session.resumeRead();
                        }
                    }, suspendTime);
                }
            }
        }
        nextFilter.messageReceived(session, message);
    }

    /**
     * Lowers the session's read buffer size and max read buffer size to the read limit when they exceed it.
     * Does nothing when read shaping is disabled.
     */
    private void adjustReadBufferSize(IoSession session) {
        int maxReadThroughput = this.maxReadThroughput;
        if (maxReadThroughput == 0) {
            return;
        }
        IoSessionConfig config = session.getConfig();
        if (config.getReadBufferSize() > maxReadThroughput) {
            config.setReadBufferSize(maxReadThroughput);
        }
        if (config.getMaxReadBufferSize() > maxReadThroughput) {
            config.setMaxReadBufferSize(maxReadThroughput);
        }
    }

    /**
     * Accounts the sent message against the write limit and, if the limit is reached, suspends writes on the
     * session for the computed time. The event is always forwarded to the next filter.
     */
    @Override
    public void messageSent(NextFilter nextFilter, final IoSession session, WriteRequest writeRequest) throws Exception {
        // Read the volatile once so the whole computation uses one consistent limit.
        int maxWriteThroughput = this.maxWriteThroughput;
        if (maxWriteThroughput > 0) {
            final State state = (State) session.getAttribute(STATE);
            // Without per-session state there is nothing to account against; pass the event through.
            if (state != null) {
                long currentTime = System.currentTimeMillis();
                long suspendTime = 0;
                synchronized (state) {
                    state.writtenBytes += messageSizeEstimator.estimateSize(writeRequest.getMessage());
                    if (!state.suspendedWrite) {
                        boolean firstWrite = false;
                        // Open a new window when none is open, or when the wall clock moved backwards past
                        // the window start (the elapsed time would otherwise be <= 0).
                        if (state.writeStartTime == 0 || state.writeStartTime >= currentTime) {
                            firstWrite = true;
                            // Back-date the window by one second so the first measurement is bytes/second.
                            state.writeStartTime = currentTime - 1000;
                        }
                        long elapsed = currentTime - state.writeStartTime;
                        long throughput = state.writtenBytes * 1000 / elapsed;
                        if (throughput >= maxWriteThroughput) {
                            // Time the bytes should have taken at the maximum rate, minus the time that actually
                            // elapsed (treated as 0 for a freshly opened window).
                            suspendTime = Math.max(0, state.writtenBytes * 1000 / maxWriteThroughput
                                    - (firstWrite ? 0 : elapsed));
                            state.writtenBytes = 0;
                            state.writeStartTime = 0;
                            state.suspendedWrite = suspendTime != 0;
                        }
                    }
                }
                if (suspendTime != 0) {
                    log.trace("Suspending write");
                    session.suspendWrite();
                    scheduleResume(new Runnable() {
                        @Override
                        public void run() {
                            synchronized (state) {
                                state.suspendedWrite = false;
                            }
                            session.resumeWrite();
                            log.trace("Resuming write");
                        }
                    }, suspendTime);
                }
            }
        }
        nextFilter.messageSent(session, writeRequest);
    }

    /**
     * Schedules {@code resume} after {@code delayMillis}. If the executor refuses the task (for example
     * because it has been shut down), {@code resume} runs immediately instead. This keeps the session from
     * staying suspended forever and lets the current event still reach the next filter.
     */
    private void scheduleResume(Runnable resume, long delayMillis) {
        try {
            scheduledExecutor.schedule(resume, delayMillis, TimeUnit.MILLISECONDS);
        } catch (java.util.concurrent.RejectedExecutionException e) {
            log.warn("Could not schedule traffic shaping resume task; resuming immediately", e);
            resume.run();
        }
    }

    /**
     * Per-session shaping state. All fields are read and written only while holding the instance's monitor.
     */
    private static class State {
        /** Start (ms, wall clock) of the current read measuring window; 0 when no window is open. */
        private long readStartTime;

        /** Start (ms, wall clock) of the current write measuring window; 0 when no window is open. */
        private long writeStartTime;

        /** Whether reads are currently suspended by this filter. */
        private boolean suspendedRead;

        /** Whether writes are currently suspended by this filter. */
        private boolean suspendedWrite;

        /** Estimated bytes read since the read window opened (also counted while suspended). */
        private long readBytes;

        /** Estimated bytes written since the write window opened (also counted while suspended). */
        private long writtenBytes;
    }
}