业务处理类继承了 ChannelInboundHandlerAdapter 类,通过重写(override)userEventTriggered 方法,可以实现心跳超时的处理。代码如下:
/**
 * Business handler for decoded {@code Message} objects: enforces "log in first", dispatches
 * messages to the module handlers via {@code MyAnnotionUtil}, and closes connections that stay
 * read-idle for too long.
 */
public class ServerHandler extends ChannelInboundHandlerAdapter {
    private static final Logger log = LoggerFactory.getLogger(ServerHandler.class);
    private ChannelCache channelCache = SpringUtil.getBean(ChannelCache.class);

    /**
     * Number of consecutive READER_IDLE events per channel. It is shared by all handler
     * instances. An entry is removed when the channel reads data, times out, or becomes
     * inactive, so it never outlives its connection.
     */
    private static final ConcurrentHashMap<ChannelId, Integer> channelIdleTime =
            new ConcurrentHashMap<ChannelId, Integer>();

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof Message)) {
            // Not a decoded Message: hand it to the next handler instead of failing with a
            // ClassCastException.
            ctx.fireChannelRead(msg);
            return;
        }
        Message message = (Message) msg;
        // Any inbound message proves the peer is alive, so the idle streak starts over.
        // Without this, idle periods separated by activity add up and eventually disconnect
        // a healthy client.
        channelIdleTime.remove(ctx.channel().id());

        // Every module except 1 (the login module) requires the channel to be registered already.
        if (message.getModule() != 1 && channelCache.getChannel(ctx.channel()) == null) {
            ctx.writeAndFlush(new Result(0, "need auth"));
            return;
        }
        // NOTE(review): for module 1 the channel is registered here BEFORE the login handler
        // runs, so it stays registered even if that handler rejects the login. Later messages
        // then pass the check above. Consider registering only after a successful login (this
        // needs the Result API, which is not visible here).
        channelCache.addChannel(ctx.channel(), message.getUid());
        Result result = MyAnnotionUtil.process(ctx, message);
        log.info("result: {}", result);
        ctx.writeAndFlush(result);
    }

    /**
     * Counts consecutive READER_IDLE events (emitted by an IdleStateHandler configured
     * elsewhere in the pipeline). The connection is closed once the count reaches
     * {@code Const.TIME_OUT_NUM}. Other idle states are ignored; non-idle events are passed on.
     */
    @Override
    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
        if (!(evt instanceof IdleStateEvent)) {
            super.userEventTriggered(ctx, evt);
            return;
        }
        IdleStateEvent e = (IdleStateEvent) evt;
        if (e.state() != IdleState.READER_IDLE) {
            return;
        }
        // A local formatter: SimpleDateFormat is not thread-safe, so it must not be shared.
        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        log.warn("---READER_IDLE---" + dateFormat.format(new Date()));

        ChannelId channelId = ctx.channel().id();
        // merge() performs the read-increment-write atomically.
        int times = channelIdleTime.merge(channelId, 1, Integer::sum);
        if (times >= Const.TIME_OUT_NUM) {
            log.error("--- TIME OUT ---");
            channelIdleTime.remove(channelId);
            channelCache.removeChannel(ctx.channel());
            ctx.close();
        }
    }

    /**
     * Logs the failure with its stack trace and closes the connection. The pipeline state
     * after an unexpected exception is not trustworthy. Cache cleanup happens in
     * {@link #channelInactive}, which fires once the channel is closed.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        log.error("exceptionCaught:" + cause.getMessage(), cause);
        ctx.close();
    }

    /** Drops all per-channel bookkeeping when the connection goes away. */
    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        log.info("====== channelInactive ======");
        channelIdleTime.remove(ctx.channel().id());
        channelCache.removeChannel(ctx.channel());
        ctx.close();
        log.info("====== Channel close ======");
    }
}
由于 ServerHandler 类不是由 Spring 管理的,而是在 ServerChannelInitializer 中通过 new 的方式创建并配置的,所以自定义了一个 SpringUtil 工具类,用来获取由 Spring 管理的 bean:
/**
 * Static access to the Spring {@link ApplicationContext} for objects that Spring does not
 * create, such as Netty handlers instantiated with {@code new}.
 */
@Component
public class SpringUtil implements ApplicationContextAware {
    /**
     * volatile: this is written once by the Spring startup thread and read later from Netty I/O
     * threads, so it must be visible across threads.
     */
    private static volatile ApplicationContext applicationContext = null;

    /** Stores the context Spring hands over. Only the first context received is kept. */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        if (SpringUtil.applicationContext == null) {
            SpringUtil.applicationContext = applicationContext;
        }
    }

    /** @return the stored context, or {@code null} if none has been injected yet */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    /** Replaces the stored context manually; {@code null} arguments are ignored. */
    public static void setAppCtx(ApplicationContext webAppCtx) {
        if (webAppCtx != null) {
            applicationContext = webAppCtx;
        }
    }

    /**
     * Returns the stored context, failing with a descriptive message instead of a bare
     * NullPointerException when it is used before Spring has injected it.
     */
    private static ApplicationContext requireContext() {
        ApplicationContext ctx = applicationContext;
        if (ctx == null) {
            throw new IllegalStateException(
                    "SpringUtil: ApplicationContext has not been injected yet");
        }
        return ctx;
    }

    /** Looks up the unique bean of the given type. */
    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    /**
     * Looks up a bean by name and required type. The {@code throws} clause is never triggered;
     * it is kept only so existing callers that catch the exception still compile.
     */
    public static <T> T getBean(String name, Class<T> clazz) throws ClassNotFoundException {
        return requireContext().getBean(name, clazz);
    }

    /** Looks up a bean by name. */
    public static final Object getBean(String beanName) {
        return requireContext().getBean(beanName);
    }

    /**
     * Looks up a bean by name, requiring it to be an instance of the class named
     * {@code className}.
     *
     * @throws ClassNotFoundException if {@code className} cannot be loaded
     */
    public static final Object getBean(String beanName, String className) throws ClassNotFoundException {
        Class<?> clz = Class.forName(className);
        // Pass the loaded class itself. The previous clz.getClass() evaluated to Class.class,
        // which no bean ever matches.
        return requireContext().getBean(beanName, clz);
    }

    public static boolean containsBean(String name) {
        return requireContext().containsBean(name);
    }

    public static boolean isSingleton(String name) throws NoSuchBeanDefinitionException {
        return requireContext().isSingleton(name);
    }

    public static Class<?> getType(String name) throws NoSuchBeanDefinitionException {
        return requireContext().getType(name);
    }

    public static String[] getAliases(String name) throws NoSuchBeanDefinitionException {
        return requireContext().getAliases(name);
    }
}
MyAnnotionUtil 工具类通过注解的方式,将数据交给相应的处理类去处理:
/**
 * Dispatches incoming messages to the handler bean whose {@code @Module} number matches
 * {@code Message.getModule()}.
 */
@SuppressWarnings({ "rawtypes", "unchecked" })
@Component
@DependsOn("springUtil")
public class MyAnnotionUtil {
    /**
     * Module number -> handler bean. It is filled once in the static initializer and only read
     * afterwards; class initialization publishes it safely to other threads.
     */
    private static Map<Integer, Object> controllerClasses = new HashMap<>();

    static {
        System.out.println("MyAnnotionUtil static");
        // Load every class in the handler package, then keep the ones annotated with @Module.
        List<Class> clzes = ClassUtil.parseAllController("cn.ybt.netty.handler");
        for (Class clz : clzes) {
            if (!clz.isAnnotationPresent(Module.class)) {
                continue;
            }
            try {
                Module annotation = (Module) clz.getAnnotation(Module.class);
                int value = annotation.module();
                Object existing = controllerClasses.get(value);
                if (existing == null) {
                    // Use the Spring bean (not a new instance) so @Autowired fields are populated.
                    controllerClasses.put(value, SpringUtil.getBean(clz));
                } else {
                    // The first registration wins; make the conflict visible instead of silent.
                    System.err.println("MyAnnotionUtil: module " + value + " is declared by both "
                            + existing.getClass().getName() + " and " + clz.getName()
                            + "; keeping the first");
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }

    /**
     * Invokes {@code process(ChannelHandlerContext, Message)} on the handler registered for
     * {@code msg.getModule()}.
     *
     * @return the handler's result, or {@code new Result(0, "fail", msg.toString())} when no
     *         handler is registered or the call fails (the failure is printed, not swallowed)
     */
    public static Result process(ChannelHandlerContext ctx, Message msg) {
        try {
            Object obj = controllerClasses.get(msg.getModule());
            if (obj == null) {
                return new Result(0, "fail", msg.toString());
            }
            Method method = obj.getClass().getMethod("process", ChannelHandlerContext.class, Message.class);
            return (Result) method.invoke(obj, ctx, msg);
        } catch (java.lang.reflect.InvocationTargetException e) {
            // The handler itself threw: report its exception, not the reflection wrapper.
            e.getTargetException().printStackTrace();
            return new Result(0, "fail", msg.toString());
        } catch (Exception e) {
            e.printStackTrace();
            return new Result(0, "fail", msg.toString());
        }
    }
}
ClassUtil 用于获取 handler 包路径下的所有类:
/** Classpath-directory scanner used to discover handler classes. */
@SuppressWarnings("rawtypes")
public class ClassUtil {
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Loads (and initializes) every class found, recursively, under the classpath directory
     * that backs {@code basePackage}.
     *
     * <p>Only exploded directories ({@code file:} URLs) can be scanned. For any other location,
     * e.g. a package inside a jar, a warning is printed and an empty list is returned. Files
     * that cannot be loaded are reported and skipped.
     *
     * @param basePackage dotted package name, e.g. {@code "cn.ybt.netty.handler"}; {@code ""}
     *                    means the classpath root
     * @return the loaded classes, in directory-listing order
     * @throws IllegalArgumentException if the package is not on the classpath
     */
    public static List<Class> parseAllController(String basePackage) {
        List<Class> clzes = new ArrayList<>();
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            loader = ClassUtil.class.getClassLoader();
        }
        String path = basePackage.replace(".", "/");
        URL url = loader.getResource(path);
        if (url == null) {
            throw new IllegalArgumentException("Package not found on classpath: " + basePackage);
        }
        if (!"file".equals(url.getProtocol())) {
            System.err.println("ClassUtil: cannot scan non-directory location " + url + "; no classes loaded");
            return clzes;
        }
        File dir;
        try {
            // toURI() decodes %20 and non-ASCII escapes; url.getPath() leaves them encoded.
            dir = new File(url.toURI());
        } catch (java.net.URISyntaxException | IllegalArgumentException e) {
            dir = new File(url.getPath());
        }
        scanDirectory(dir, basePackage, loader, clzes);
        return clzes;
    }

    /** Recursively collects the {@code .class} files under {@code dir}. */
    private static void scanDirectory(File dir, String packageName, ClassLoader loader, List<Class> out) {
        File[] children = dir.listFiles();
        if (children == null) {
            return; // not a directory, missing, or an I/O error
        }
        for (File child : children) {
            String name = child.getName();
            if (child.isDirectory()) {
                scanDirectory(child, qualify(packageName, name), loader, out);
            } else if (child.isFile() && name.endsWith(CLASS_SUFFIX)) {
                // Strip only the trailing ".class"; a package segment such as "classifier"
                // must stay intact.
                String simpleName = name.substring(0, name.length() - CLASS_SUFFIX.length());
                loadInto(qualify(packageName, simpleName), loader, out);
            }
        }
    }

    /** Joins a package name and a simple name, without a leading dot for the root package. */
    private static String qualify(String packageName, String name) {
        return packageName.isEmpty() ? name : packageName + "." + name;
    }

    /** Loads the class with the same loader used for the resource lookup; failures are reported and skipped. */
    private static void loadInto(String className, ClassLoader loader, List<Class> out) {
        try {
            out.add(Class.forName(className, true, loader));
        } catch (ClassNotFoundException | LinkageError e) {
            e.printStackTrace();
        }
    }
}
Module 注解很简单,只定义了一个参数,表示模块号:
/**
 * Marks a handler class and assigns it the module number used to route incoming messages
 * to it.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Module {
    /** The module number this handler serves. */
    int module();
}
为了使所有 handler 都有 process 方法,且参数为 ChannelHandlerContext ctx 和 Message message,定义了一个抽象类:
/**
 * Common base for module handlers. Each subclass is a Spring bean that handles one module's
 * messages through {@link #process}.
 */
public abstract class BaseHandler {

    /**
     * Handles a single inbound message for this handler's module.
     *
     * @param ctx     the Netty context of the connection that sent the message
     * @param message the decoded message
     * @return the result that the caller writes back to the client
     */
    public abstract Result process(ChannelHandlerContext ctx, Message message);

    /** Redis service for user/channel data (field name kept for subclass compatibility). */
    @Autowired
    protected MyRedisService jxRedisService;

    /** Redis service for reading login tokens. */
    @Autowired
    protected LoginRedisService loginRedisService;
}
下面是两个实现类的例子:一个返回心跳成功,另一个原样返回客户端发送的数据:
/** Handles module 0 (heartbeat) by logging it and replying with a success result. */
@Module(module = 0)
@Component
public class HeartbeatHandler extends BaseHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(HeartbeatHandler.class);

    /** Logs the heartbeat and acknowledges it; the message contents are not inspected. */
    @Override
    public Result process(ChannelHandlerContext ctx, Message message) {
        LOGGER.info("heartbeat...");
        Result ack = new Result(1, "heartbeat success...");
        return ack;
    }
}
/** Handles module 99 by echoing the received message back as JSON. */
@Component
@Module(module = 99)
public class EchoHandler extends BaseHandler {

    /** Returns a success result whose payload is the JSON form of the incoming message. */
    @Override
    public Result process(ChannelHandlerContext ctx, Message message) {
        String echoed = JSON.toJSONString(message);
        return new Result(1, "success...", echoed);
    }
}
MyRedisService 用来处理用户和通道的关系,LoginRedisService 只负责读取用户登录用的 token:
/**
 * Convenience wrapper around the {@code myRedisTemplate} RedisTemplate.
 *
 * <p>Most write/read helpers are best-effort: Redis errors are printed and reported as
 * {@code false}, {@code 0} or {@code null} instead of being thrown. {@code getExpire},
 * {@code get}, {@code incr}/{@code decr}, the {@code h*} getters and {@code hdel} let
 * exceptions propagate.
 */
public class MyRedisService {
    @Resource(name = "myRedisTemplate")
    private RedisTemplate<String, Object> redisTemplate;

    @Value(value = "${redis.cache.key.prefix}")
    private String userKeyPrefix;

    @Value(value = "${redis.cache.expireSeconds}")
    private long expireSeconds;

    /** @return the key prefix configured by {@code redis.cache.key.prefix} */
    public String getUserKeyPrefix() {
        return userKeyPrefix;
    }

    /** @return the value configured by {@code redis.cache.expireSeconds} (presumably seconds, TODO confirm) */
    public long getExpireSeconds() {
        return expireSeconds;
    }

    /**
     * Runs Redis FLUSHDB on the database used by {@code myRedisTemplate}. This removes every key
     * in that database, not only keys under {@link #getUserKeyPrefix()}. Use with care.
     *
     * @return {@code "ok"} once the command has run
     */
    public String flushDb() {
        return redisTemplate.execute(new RedisCallback<String>() {
            @Override
            public String doInRedis(RedisConnection connection) throws DataAccessException {
                connection.flushDb();
                return "ok";
            }
        });
    }

    /**
     * Sets a time-to-live on a key.
     *
     * @param key  the key
     * @param time TTL in seconds; values {@code <= 0} leave the key untouched
     * @return {@code true} unless Redis threw (also {@code true} when {@code time <= 0})
     */
    public boolean expire(String key, long time) {
        try {
            if (time > 0) {
                redisTemplate.expire(key, time, TimeUnit.SECONDS);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Returns the remaining time-to-live of a key, in seconds.
     *
     * @param key the key; must not be {@code null}
     * @return the TTL in seconds; following Redis TTL semantics, {@code -1} means the key has no
     *         expiry and {@code -2} means the key does not exist
     */
    public long getExpire(String key) {
        return redisTemplate.getExpire(key, TimeUnit.SECONDS);
    }

    /**
     * @param key the key
     * @return {@code true} if the key exists; {@code false} if it does not or Redis failed
     */
    public boolean hasKey(String key) {
        try {
            return Boolean.TRUE.equals(redisTemplate.hasKey(key));
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Deletes one or more keys. A {@code null} or empty argument is a no-op.
     *
     * @param key the keys to delete
     */
    public void del(String... key) {
        if (key == null || key.length == 0) {
            return;
        }
        if (key.length == 1) {
            redisTemplate.delete(key[0]);
        } else {
            redisTemplate.delete(java.util.Arrays.asList(key));
        }
    }

    // ============================String=============================

    /**
     * Plain GET.
     *
     * @param key the key; {@code null} yields {@code null}
     * @return the stored value, or {@code null}
     */
    public Object get(String key) {
        return key == null ? null : redisTemplate.opsForValue().get(key);
    }

    /**
     * Plain SET without expiry.
     *
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean set(String key, Object value) {
        try {
            redisTemplate.opsForValue().set(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * SET with an expiry.
     *
     * @param time TTL in seconds; values {@code <= 0} store the value without expiry
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean set(String key, Object value, long time) {
        try {
            if (time > 0) {
                redisTemplate.opsForValue().set(key, value, time, TimeUnit.SECONDS);
            } else {
                set(key, value);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Increments a numeric value.
     *
     * @param key   the key
     * @param delta amount to add; must not be negative
     * @return the value after the increment
     * @throws IllegalArgumentException if {@code delta < 0}
     */
    public long incr(String key, long delta) {
        if (delta < 0) {
            throw new IllegalArgumentException("递增因子必须大于0");
        }
        return redisTemplate.opsForValue().increment(key, delta);
    }

    /**
     * Decrements a numeric value.
     *
     * @param key   the key
     * @param delta amount to subtract, given as a non-negative number
     * @return the value after the decrement
     * @throws IllegalArgumentException if {@code delta < 0}
     */
    public long decr(String key, long delta) {
        if (delta < 0) {
            throw new IllegalArgumentException("递减因子必须大于0");
        }
        return redisTemplate.opsForValue().increment(key, -delta);
    }

    // ================================Map=================================

    /**
     * HGET.
     *
     * @param key  the hash key; must not be {@code null}
     * @param item the field; must not be {@code null}
     * @return the field value, or {@code null}
     */
    public Object hget(String key, String item) {
        return redisTemplate.opsForHash().get(key, item);
    }

    /**
     * HGETALL.
     *
     * @param key the hash key
     * @return all field/value pairs of the hash
     */
    public Map<Object, Object> hmget(String key) {
        return redisTemplate.opsForHash().entries(key);
    }

    /**
     * HMSET without touching the expiry.
     *
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean hmset(String key, Map<String, Object> map) {
        try {
            redisTemplate.opsForHash().putAll(key, map);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HMSET and then (re)set the expiry of the whole hash.
     *
     * @param time TTL in seconds; values {@code <= 0} leave the expiry unchanged
     * @return {@code true} on success, {@code false} if writing the fields failed
     */
    public boolean hmset(String key, Map<String, Object> map, long time) {
        try {
            redisTemplate.opsForHash().putAll(key, map);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HSET; creates the hash if it does not exist.
     *
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean hset(String key, String item, Object value) {
        try {
            redisTemplate.opsForHash().put(key, item, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HSET and then (re)set the expiry of the whole hash. An existing TTL on the hash is
     * replaced.
     *
     * @param time TTL in seconds; values {@code <= 0} leave the expiry unchanged
     * @return {@code true} on success, {@code false} if writing the field failed
     */
    public boolean hset(String key, String item, Object value, long time) {
        try {
            redisTemplate.opsForHash().put(key, item, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * HDEL.
     *
     * @param key  the hash key; must not be {@code null}
     * @param item one or more fields; must not be {@code null}
     */
    public void hdel(String key, Object... item) {
        redisTemplate.opsForHash().delete(key, item);
    }

    /**
     * HEXISTS.
     *
     * @return {@code true} if the hash contains the field
     */
    public boolean hHasKey(String key, String item) {
        return Boolean.TRUE.equals(redisTemplate.opsForHash().hasKey(key, item));
    }

    /**
     * HINCRBYFLOAT; creates the field (and hash) if missing.
     *
     * @param by amount to add
     * @return the value after the increment
     */
    public double hincr(String key, String item, double by) {
        return redisTemplate.opsForHash().increment(key, item, by);
    }

    /**
     * Decrements a hash field.
     *
     * @param by amount to subtract, given as a positive number
     * @return the value after the decrement
     */
    public double hdecr(String key, String item, double by) {
        return redisTemplate.opsForHash().increment(key, item, -by);
    }

    // ============================set=============================

    /**
     * SMEMBERS.
     *
     * @return all members, or {@code null} if Redis failed
     */
    public Set<Object> sGet(String key) {
        try {
            return redisTemplate.opsForSet().members(key);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * SISMEMBER.
     *
     * @return {@code true} if {@code value} is a member; {@code false} otherwise or on failure
     */
    public boolean sHasKey(String key, Object value) {
        try {
            return Boolean.TRUE.equals(redisTemplate.opsForSet().isMember(key, value));
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * SADD.
     *
     * @return the number of members added, or {@code 0} on failure
     */
    public long sSet(String key, Object... values) {
        try {
            Long added = redisTemplate.opsForSet().add(key, values);
            return added == null ? 0 : added;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * SADD and then (re)set the expiry of the set.
     *
     * @param time TTL in seconds; values {@code <= 0} leave the expiry unchanged
     * @return the number of members added, or {@code 0} on failure
     */
    public long sSetAndTime(String key, long time, Object... values) {
        try {
            Long count = redisTemplate.opsForSet().add(key, values);
            if (time > 0) {
                expire(key, time);
            }
            return count == null ? 0 : count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * SCARD.
     *
     * @return the set size, or {@code 0} on failure
     */
    public long sGetSetSize(String key) {
        try {
            Long size = redisTemplate.opsForSet().size(key);
            return size == null ? 0 : size;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * SREM.
     *
     * @return the number of members removed, or {@code 0} on failure
     */
    public long setRemove(String key, Object... values) {
        try {
            Long count = redisTemplate.opsForSet().remove(key, values);
            return count == null ? 0 : count;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    // ===============================list=================================

    /**
     * LRANGE.
     *
     * @param start first index (inclusive)
     * @param end   last index (inclusive); {@code 0} to {@code -1} returns the whole list
     * @return the elements, or {@code null} on failure
     */
    public List<Object> lGet(String key, long start, long end) {
        try {
            return redisTemplate.opsForList().range(key, start, end);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * LLEN.
     *
     * @return the list length, or {@code 0} on failure
     */
    public long lGetListSize(String key) {
        try {
            Long size = redisTemplate.opsForList().size(key);
            return size == null ? 0 : size;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }

    /**
     * LINDEX.
     *
     * @param index {@code >= 0} counts from the head (0 = first); {@code < 0} counts from the
     *              tail (-1 = last)
     * @return the element, or {@code null} if absent or on failure
     */
    public Object lGetIndex(String key, long index) {
        try {
            return redisTemplate.opsForList().index(key, index);
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * RPUSH a single value.
     *
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean lSet(String key, Object value) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * RPUSH a single value and then (re)set the expiry of the list.
     *
     * @param time TTL in seconds; values {@code <= 0} leave the expiry unchanged
     * @return {@code true} on success, {@code false} if the push failed
     */
    public boolean lSet(String key, Object value, long time) {
        try {
            redisTemplate.opsForList().rightPush(key, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * RPUSH every element of {@code value}.
     *
     * @return {@code true} on success, {@code false} if Redis failed
     */
    public boolean lSet(String key, List<Object> value) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * RPUSH every element of {@code value} and then (re)set the expiry of the list.
     *
     * @param time TTL in seconds; values {@code <= 0} leave the expiry unchanged
     * @return {@code true} on success, {@code false} if the push failed
     */
    public boolean lSet(String key, List<Object> value, long time) {
        try {
            redisTemplate.opsForList().rightPushAll(key, value);
            if (time > 0) {
                expire(key, time);
            }
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * LSET: replaces the element at {@code index}.
     *
     * @return {@code true} on success, {@code false} if Redis failed (e.g. index out of range)
     */
    public boolean lUpdateIndex(String key, long index, Object value) {
        try {
            redisTemplate.opsForList().set(key, index, value);
            return true;
        } catch (Exception e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * LREM: removes up to {@code count} occurrences of {@code value}.
     *
     * @return the number of elements removed, or {@code 0} on failure
     */
    public long lRemove(String key, long count, Object value) {
        try {
            Long removed = redisTemplate.opsForList().remove(key, count, value);
            return removed == null ? 0 : removed;
        } catch (Exception e) {
            e.printStackTrace();
            return 0;
        }
    }
}
/**
 * Read-only access to user login tokens stored via the {@code loginRedisTemplate}.
 *
 * <p>NOTE(review): this is a plain service, but it is annotated {@code @Configuration}, which
 * makes Spring CGLIB-proxy it. {@code @Service}/{@code @Component} is probably intended;
 * verify before changing.
 */
@Configuration
public class LoginRedisService {
    @Resource(name = "loginRedisTemplate")
    private RedisTemplate<String, Object> redisTemplate;

    /** Key prefix for token entries, from {@code redis.login.prefix}. */
    @Value("${redis.login.prefix}")
    String tokenPre;

    // ============================String=============================

    /**
     * Reads the login token stored for a user.
     *
     * @param uid the user id; the Redis key is {@code tokenPre + uid}
     * @return the stored token, or {@code null} if there is none
     */
    public Object getToken(int uid) {
        // String concatenation never yields null, so no null-check on the key is needed.
        String key = tokenPre + uid;
        return redisTemplate.opsForValue().get(key);
    }
}
/**
 * Starts the Netty server when the servlet context starts and stops it when the context is
 * destroyed.
 */
@WebListener
public class NettyRunServletContextListener implements ServletContextListener {
    private static final Logger logger = LoggerFactory.getLogger(NettyRunServletContextListener.class);
    @Value("${netty.port}")
    private int port;
    @Value("${netty.url}")
    private String url;
    @Autowired
    private NettyConfig nettyConfig;

    /**
     * Shuts Netty down. This is reachable because {@link #contextInitialized} now returns
     * instead of blocking on the server channel's close future.
     */
    @Override
    public void contextDestroyed(ServletContextEvent sce) {
        logger.info("====== springboot netty destroy ======");
        if (nettyConfig != null) {
            nettyConfig.destroy();
        }
        logger.info("---test contextDestroyed method---");
    }

    /**
     * Injects this listener's Spring dependencies and binds the Netty server.
     *
     * <p>This waits only for the bind to complete. It must not wait on
     * {@code channel().closeFuture()}: that blocks until the server channel closes, so
     * servlet-context startup never finishes. The container then never completes startup
     * (which breaks {@code shutdown.sh}) and never calls {@link #contextDestroyed}. The
     * separate JVM shutdown hook is no longer needed; a normal container stop calls
     * {@code contextDestroyed}.
     */
    @Override
    public void contextInitialized(ServletContextEvent sce) {
        WebApplicationContextUtils.getRequiredWebApplicationContext(sce.getServletContext())
                .getAutowireCapableBeanFactory().autowireBean(this);
        try {
            InetSocketAddress address = new InetSocketAddress(url, port);
            ChannelFuture future = nettyConfig.run(address);
            // Rethrows the failure cause if binding failed; returns as soon as the bind completes.
            future.syncUninterruptibly();
            logger.info("====== springboot netty start ======");
        } catch (Exception e) {
            // Pass the exception itself: the old call had no {} placeholder, so SLF4J discarded the message.
            logger.error("---springboot netty server start error---", e);
        }
    }
}
遗留一个问题:目前这种方式下,使用 Tomcat 的 shutdown.sh 指令时,无法正常关闭 Tomcat,会提示:
Using CATALINA_BASE: /usr/local/tomcat
Using CATALINA_HOME: /usr/local/tomcat
Using CATALINA_TMPDIR: /usr/local/tomcat/temp
Using JRE_HOME: /usr/local/jdk1.8.0_66
Using CLASSPATH: /usr/local/tomcat/bin/bootstrap.jar:/usr/local/tomcat/bin/tomcat-juli.jar
Jul 12, 2019 6:51:22 PM org.apache.catalina.startup.Catalina stopServer
SEVERE: Could not contact [localhost:[8080]]. Tomcat may not be running.
Jul 12, 2019 6:51:22 PM org.apache.catalina.startup.Catalina stopServer
SEVERE: Catalina.stop:
java.net.ConnectException: Connection refused
at java.net.PlainSocketImpl.socketConnect(Native Method)
at java.net.AbstractPlainSocketImpl.doConnect(AbstractPlainSocketImpl.java:350)
at java.net.AbstractPlainSocketImpl.connectToAddress(AbstractPlainSocketImpl.java:206)
at java.net.AbstractPlainSocketImpl.connect(AbstractPlainSocketImpl.java:188)
at java.net.SocksSocketImpl.connect(SocksSocketImpl.java:392)
at java.net.Socket.connect(Socket.java:589)
at java.net.Socket.connect(Socket.java:538)
at java.net.Socket.(Socket.java:434)
at java.net.Socket.(Socket.java:211)
at org.apache.catalina.startup.Catalina.stopServer(Catalina.java:504)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:497)
at org.apache.catalina.startup.Bootstrap.stopServer(Bootstrap.java:406)
at org.apache.catalina.startup.Bootstrap.main(Bootstrap.java:498)
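表面上看可能是由于 netty 服务端口被占用而无法关闭,但更可能的原因是:contextInitialized 中调用了 future.channel().closeFuture().syncUninterruptibly(),该调用会一直阻塞,直到 Netty 的服务端通道关闭为止。这使得监听器无法返回、Tomcat 的启动流程无法完成,Tomcat 的关闭端口也就没有开始监听,所以 shutdown.sh 连接失败。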
使用kill指令时,查看日志:
[2019-07-12 18:58:34.615][INFO ][][ org.my.netty.NettyRunServletContextListener$1.run(NettyRunServletContextListener.java:51)
] ==> ---nettyConfig destroy---
[2019-07-12 18:58:34.617][INFO ][][ org.my.netty.config.NettyConfig.destroy(NettyConfig.java:60)
] ==> Shutdown Netty Server...
[2019-07-12 18:58:34.669][INFO ][][ org.springframework.scheduling.concurrent.ExecutorConfigurationSupport.shutdown(ExecutorConfigurationSupport.java:208)
] ==> Shutting down ExecutorService 'applicationTaskExecutor'
[2019-07-12 18:58:34.790][INFO ][][ io.lettuce.core.EpollProvider.(EpollProvider.java:64)
] ==> Starting with epoll library
[2019-07-12 18:58:34.794][INFO ][][ io.lettuce.core.KqueueProvider.(KqueueProvider.java:70)
] ==> Starting without optional kqueue library
[2019-07-12 18:58:34.951][INFO ][][ org.my.netty.config.NettyConfig.destroy(NettyConfig.java:67)
] ==> Shutdown Netty Server Success!
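可以看到 contextDestroyed 方法并没有执行,只执行了通过 Runtime.addShutdownHook 注册的钩子方法。推测原因是 contextInitialized 被 closeFuture().syncUninterruptibly() 阻塞、始终没有返回,Web 应用没有完成初始化,因此容器关闭时不会回调 contextDestroyed。需要研究下怎么正常关闭 netty:一个可行的思路是在 contextInitialized 中只等待端口绑定完成(例如对 nettyConfig.run 返回的 ChannelFuture 调用 syncUninterruptibly()),不要等待 closeFuture,让容器在关闭时通过 contextDestroyed 调用 nettyConfig.destroy()。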
github: https://github.com/gavinL93/springboot_netty