import cn.hutool.core.bean.BeanUtil;
import cn.hutool.json.JSONUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.util.Base64;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
/**
 * WebSocket endpoint that relays text messages between connected clients and receives
 * chunked file uploads addressed to {@code "server"}.
 *
 * <p>Per-connection state (session, bind id, in-progress upload) lives in instance fields,
 * which assumes the container creates one endpoint instance per connection (the JSR 356
 * default). The registry of connected clients is static and shared by all instances.
 *
 * @ClassName WebSocketServer
 * @date 2022/07/08 20:18:27
 */
@Component
@ServerEndpoint("/websocket/{bindIp}")
public class WebSocketServer {
    /** Value of {@code Message.to} that addresses upload frames to this server. */
    private static final String SERVER_TARGET = "server";

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Number of open connections. Only modified through the static synchronized helpers
     * below so concurrent open/close callbacks cannot lose updates; volatile so code that
     * reads this public field directly sees the latest value.
     */
    public static volatile int onlineNumber = 0;

    /** Connected clients keyed by their {@code bindIp} path parameter. */
    private static final Map<String, WebSocketServer> clients = new ConcurrentHashMap<>();

    /** Session of this connection. */
    private Session session;

    /** Client id taken from the {@code /websocket/{bindIp}} path. */
    private String bindIp;

    /** Name of the file currently being received, or null when no upload is active. */
    private String fileName = null;

    /** Not read or written anywhere in this class. */
    private String token;

    /** Client-supplied target directory of the current upload, or null when idle. */
    private String uploadPath = null;

    /** Size in bytes announced by the client for the current upload. */
    private long fileSize = 0L;

    /** Bytes of the current upload written so far. */
    private long passedlen = 0L;

    /** Highest progress percentage already logged for the current upload. */
    private long lastLoggedPercent = 0L;

    /** Output for the current upload, or null when no upload is active. */
    private FileOutputStream fileOutputStream = null;

    /**
     * Registers a new connection and sends it the list of online client ids.
     *
     * @param bindIp  client id from the endpoint path
     * @param session the new WebSocket session
     */
    @OnOpen
    public void onOpen(@PathParam("bindIp") String bindIp, Session session) {
        int online = incrementOnline();
        logger.info("现在来连接的客户端:" + bindIp + "当前在线人数" + online);
        this.bindIp = bindIp;
        this.session = session;
        try {
            // A reconnect with an existing bindIp replaces the older registration.
            clients.put(bindIp, this);
            // Tell the new client who is online; the list includes the client itself.
            Message m = new Message();
            m.setMsgType(ServerConstant.WEB_SOCKET_TYPE.ONLINEUSER);
            m.setMsg(clients.keySet());
            sendMessageTo(JSONUtil.toJsonStr(m), bindIp);
        } catch (Exception e) {
            logger.error(bindIp + "上线的时候通知所有人发生了错误", e);
        }
    }

    /**
     * Logs transport or handler errors reported by the container.
     *
     * @param session the affected session
     * @param error   the reported error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        logger.error("服务端发生了错误", error);
    }

    /**
     * Unregisters this connection, aborts any unfinished upload and broadcasts an OFFLINE
     * message carrying the remaining online ids.
     */
    @OnClose
    public void onClose() {
        int online = decrementOnline();
        // remove(key, value): if this bindIp has already reconnected, keep the newer connection.
        if (bindIp != null) {
            clients.remove(bindIp, this);
        }
        // Release a half-finished upload so its file handle is not leaked.
        resetUpload();
        try {
            Message message = new Message();
            message.setMsgType(ServerConstant.WEB_SOCKET_TYPE.OFFLINE);
            message.setMsg(clients.keySet());
            message.setFrom(bindIp);
            sendMessageAll(JSONUtil.toJsonStr(message));
        } catch (IOException e) {
            logger.error(bindIp + "下线的时候通知所有人发生了错误", e);
        }
        logger.info("有连接关闭! 当前在线人数" + online);
    }

    /**
     * Dispatches one JSON-encoded {@link Message} from the client: heartbeat, the three
     * upload phases (only when addressed to {@code "server"}), or a chat message relayed to
     * "All" or a single recipient. Any failure aborts the in-progress upload.
     *
     * @param message JSON text of a {@link Message}
     * @param session the sending session (unused; this instance already holds it)
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        try {
            Message m = JSONUtil.toBean(message, Message.class);
            Object textMessage = m.getMsg();
            String frombindIp = m.getFrom();
            String tobindIp = m.getTo();
            ServerConstant.WEB_SOCKET_TYPE msgType = m.getMsgType();
            logger.info("来自客户端" + bindIp + "消息,消息类型: " + msgType);
            if (msgType == ServerConstant.WEB_SOCKET_TYPE.PING) {
                // Heartbeat: answer on this connection instead of routing by the
                // client-supplied "from", which a client could set to anyone.
                Message pong = new Message();
                pong.setMsgType(ServerConstant.WEB_SOCKET_TYPE.PING);
                deliver(this, JSONUtil.toJsonStr(pong));
            } else if (msgType == ServerConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_START) {
                if (SERVER_TARGET.equals(tobindIp)) {
                    startUpload(textMessage);
                }
            } else if (msgType == ServerConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_UPLOADING) {
                if (SERVER_TARGET.equals(tobindIp)) {
                    writeChunk(textMessage);
                }
            } else if (msgType == ServerConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_COMPLETED) {
                if (SERVER_TARGET.equals(tobindIp)) {
                    completeUpload();
                }
            } else {
                forward(textMessage, frombindIp, tobindIp);
            }
        } catch (Exception e) {
            // Any failure aborts the current upload (if any) and releases its file handle.
            resetUpload();
            logger.error("发生了错误了", e);
        }
    }

    /** Opens the target file for a new upload described by the FILE_UPLOAD_START payload. */
    private void startUpload(Object payload) throws IOException {
        Map<String, Object> meta = BeanUtil.beanToMap(payload);
        // A new START without COMPLETED for the previous upload must not leak its stream.
        resetUpload();
        // Keep only the last path component so a name like "../x" cannot leave uploadPath.
        this.fileName = new File(meta.get("fileName").toString()).getName();
        // NOTE(review): uploadPath is chosen by the client, so any connected client can write
        // wherever this process has permission; consider restricting it to a configured root.
        this.uploadPath = meta.get("uploadPath").toString();
        this.fileSize = Long.parseLong(meta.get("fileSize").toString());
        this.fileOutputStream = new FileOutputStream(new File(uploadPath, fileName));
    }

    /** Appends {@code byte[off, off + len)} of a FILE_UPLOAD_UPLOADING payload to the file. */
    private void writeChunk(Object payload) throws IOException {
        if (fileOutputStream == null) {
            throw new IllegalStateException("upload chunk received without FILE_UPLOAD_START");
        }
        Map<String, Object> chunk = BeanUtil.beanToMap(payload);
        byte[] buf = toBytes(chunk.get("byte"));
        int off = Integer.parseInt(chunk.get("off").toString());
        int len = Integer.parseInt(chunk.get("len").toString());
        fileOutputStream.write(buf, off, len);
        passedlen += len;
        // fileSize comes from the client; skip progress logging rather than divide by zero.
        if (fileSize > 0) {
            long percent = passedlen * 100L / fileSize;
            if (lastLoggedPercent < percent) {
                lastLoggedPercent = percent;
                logger.info("文件已经接收: " + percent + "%");
            }
        }
    }

    /** Closes the finished upload and clears its state. */
    private void completeUpload() throws IOException {
        if (fileOutputStream != null) {
            try {
                fileOutputStream.close();
            } finally {
                fileOutputStream = null;
                resetUpload();
            }
        }
        logger.info("文件接收完成");
    }

    /** Closes any open upload stream (best effort) and clears all per-upload state. */
    private void resetUpload() {
        FileOutputStream out = fileOutputStream;
        fileOutputStream = null;
        passedlen = 0L;
        fileSize = 0L;
        lastLoggedPercent = 0L;
        fileName = null;
        uploadPath = null;
        if (out != null) {
            try {
                out.close();
            } catch (IOException e) {
                logger.error("发生了错误了", e);
            }
        }
    }

    /** Relays a chat message to every client ("All") or to the single client {@code to}. */
    private void forward(Object body, String from, String to) throws IOException {
        Message msg = new Message();
        msg.setMsg(body);
        msg.setFrom(from);
        msg.setMsgType(ServerConstant.WEB_SOCKET_TYPE.STRING);
        if ("All".equals(to)) {
            msg.setTo("All");
            sendMessageAll(JSONUtil.toJsonStr(msg));
        } else {
            msg.setTo(to);
            sendMessageTo(JSONUtil.toJsonStr(msg), to);
        }
    }

    /**
     * Converts the {@code "byte"} field of an upload chunk back into raw bytes.
     *
     * <p>Accepts a {@code byte[]} as-is, a JSON array of numbers (NOTE(review): this appears
     * to be how hutool's JSONUtil encodes the client's {@code byte[]} buffer; confirm), or a
     * Base64 string. Do not use {@code FormatUtil.toByteArray} here: it returns a Java
     * serialization stream of the parsed value, not the original bytes.
     *
     * @throws IllegalArgumentException for any other payload type
     */
    private static byte[] toBytes(Object value) {
        if (value instanceof byte[]) {
            return (byte[]) value;
        }
        if (value instanceof CharSequence) {
            return Base64.getDecoder().decode(value.toString());
        }
        if (value instanceof Collection) {
            Collection<?> items = (Collection<?>) value;
            byte[] out = new byte[items.size()];
            int idx = 0;
            for (Object item : items) {
                out[idx++] = ((Number) item).byteValue();
            }
            return out;
        }
        throw new IllegalArgumentException("unsupported chunk payload: "
                + (value == null ? "null" : value.getClass().getName()));
    }

    /**
     * Sends {@code message} to the client registered as {@code tobindIp}; unknown or null
     * recipients are ignored and send failures are logged.
     *
     * @param message  text to send
     * @param tobindIp recipient id
     */
    public void sendMessageTo(String message, String tobindIp) {
        // ConcurrentHashMap.get rejects null keys, so treat a null recipient as unknown.
        WebSocketServer target = tobindIp == null ? null : clients.get(tobindIp);
        if (target == null) {
            return;
        }
        try {
            deliver(target, message);
        } catch (IOException e) {
            logger.error("failed to send message to " + tobindIp, e);
        }
    }

    /**
     * Sends {@code message} to every connected client. A failure for one client is logged
     * and does not stop delivery to the others.
     *
     * @param message text to send
     * @throws IOException declared for backward compatibility; failures are logged instead
     */
    public void sendMessageAll(String message) throws IOException {
        for (WebSocketServer item : clients.values()) {
            try {
                deliver(item, message);
            } catch (IOException e) {
                logger.error("failed to send message to " + item.bindIp, e);
            }
        }
    }

    /**
     * Sends one text frame to {@code target}, serialising sends per session. Tomcat rejects
     * starting a send while a previous async send on the same session is still in flight
     * (IllegalStateException: TEXT_FULL_WRITING), which happens when several threads
     * message the same client concurrently.
     */
    private static void deliver(WebSocketServer target, String message) throws IOException {
        Session targetSession = target.session;
        if (targetSession == null || !targetSession.isOpen()) {
            return;
        }
        synchronized (targetSession) {
            targetSession.getBasicRemote().sendText(message);
        }
    }

    /** Atomically increments the online count and returns the new value. */
    private static synchronized int incrementOnline() {
        return ++onlineNumber;
    }

    /** Atomically decrements the online count and returns the new value. */
    private static synchronized int decrementOnline() {
        return --onlineNumber;
    }

    /**
     * Returns the number of open connections.
     *
     * @return current online count
     */
    public static synchronized int getOnlineCount() {
        return onlineNumber;
    }
}
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;
/**
 * Spring configuration that publishes a {@link ServerEndpointExporter}. The exporter registers
 * {@code @ServerEndpoint}-annotated beans with the container's standard WebSocket runtime.
 *
 * <p>NOTE(review): Spring documents that this exporter is not needed when deploying to an
 * external Servlet container that scans endpoints itself (the pom marks
 * spring-boot-starter-tomcat as provided); confirm the deployment mode.
 */
@Configuration
public class WebSocketConfig {

    /**
     * Creates the endpoint exporter bean.
     *
     * @return a new exporter
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }
}
/**
 * Shared enumerations used by the WebSocket server.
 *
 * <p>Types nested in an interface are implicitly {@code public static}. Do not reorder the
 * constants: declaration order determines each constant's {@link Enum#ordinal()}.
 */
public interface ServerConstant {

    /** State of a file transfer. */
    enum FILE_STATE {
        /** The file upload has completed. */
        FILE_UPLOAD_COMPLETED,
        /** The file upload is in progress. */
        FILE_UPLOAD_UPLOADING,
        /** The file upload failed. */
        FILE_UPLOAD_FAILED,
        /** The file was not found. */
        FILE_NOT_FIND,
        /** The file already exists. */
        FILE_EXSIT,
        /** The file transfer is about to start. */
        FILE_UPLOAD_START
    }

    /** Type tag carried by every WebSocket message. */
    enum WEB_SOCKET_TYPE {
        /** System message, sent by the server to clients. */
        SYSTEM,
        /** Client message, sent by a client to the server. */
        USER,
        /** Online notification. */
        ONLINE,
        /** Offline notification. */
        OFFLINE,
        /** List of online clients. */
        ONLINEUSER,
        /** Ordinary text message. */
        STRING,
        /** Heartbeat. */
        PING,
        /** The file transfer is about to start. */
        FILE_UPLOAD_START,
        /** The file upload has completed. */
        FILE_UPLOAD_COMPLETED,
        /** The file upload is in progress. */
        FILE_UPLOAD_UPLOADING,
        /** The file upload failed. */
        FILE_UPLOAD_FAILED,
        /** The file was not found. */
        FILE_NOT_FIND,
        /** The file already exists. */
        FILE_EXSIT
    }
}
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectOutputStream;
public class FormatUtil {
/**
* 对象转数组
* @param obj
* @return
*
*/
public static byte[] toByteArray(Object obj) {
byte[] bytes = null;
ByteArrayOutputStream bos = new ByteArrayOutputStream();
try {
ObjectOutputStream oos = new ObjectOutputStream(bos);
oos.writeObject(obj);
oos.flush();
bytes = bos.toByteArray();
oos.close();
bos.close();
} catch (IOException ex) {
ex.printStackTrace();
}
return bytes;
}
}
/**
 * Envelope for every frame exchanged over the WebSocket: a payload plus type and routing
 * metadata. Serialized to and from JSON by the server and client code.
 */
public class Message {
    /** Payload; its shape depends on {@code msgType} (for example a map for upload frames). */
    private Object msg;
    /** Kind of message. */
    private ServerConstant.WEB_SOCKET_TYPE msgType;
    /** Recipient id. */
    private String to;
    /** Sender id. */
    private String from;

    public Object getMsg() {
        return this.msg;
    }

    public void setMsg(Object msg) {
        this.msg = msg;
    }

    public ServerConstant.WEB_SOCKET_TYPE getMsgType() {
        return this.msgType;
    }

    public void setMsgType(ServerConstant.WEB_SOCKET_TYPE msgType) {
        this.msgType = msgType;
    }

    public String getTo() {
        return this.to;
    }

    public void setTo(String to) {
        this.to = to;
    }

    public String getFrom() {
        return this.from;
    }

    public void setFrom(String from) {
        this.from = from;
    }

    /** Debug representation; null fields render as {@code null}. */
    @Override
    public String toString() {
        return new StringBuilder("Message{")
                .append("msg=").append(msg)
                .append(", msgType=").append(msgType)
                .append(", to='").append(to).append('\'')
                .append(", from='").append(from).append('\'')
                .append('}')
                .toString();
    }
}
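<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web-services</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-tomcat</artifactId>
    <scope>provided</scope>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus-core</artifactId>
    <version>3.4.2</version>
</dependency>
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus-boot-starter</artifactId>
    <version>3.4.2</version>
</dependency>
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <scope>runtime</scope>
</dependency>
<!-- NOTE: the server code also imports cn.hutool (hutool-all), which this list does not declare. -->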
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
@Slf4j
public class MyWebSocketClient extends WebSocketClient {
/**
* description 客户端连接状态
*/
private boolean isConnect = false;
public void setConnectState(boolean isConnect) {
this.isConnect = isConnect;
}
public boolean getConnectState() {
return this.isConnect;
}
public MyWebSocketClient(URI serverUri) {
super(serverUri);
connect();
}
@Override
public void connect() {
super.connect();
}
@Override
public void onOpen(ServerHandshake serverHandshake) {
log.info("开始连接...");
setConnectState(true);
//for(Iterator it = serverHandshake.iterateHttpFields(); it.hasNext();) {
// String key = it.next();
// System.out.println(key+":"+serverHandshake.getFieldValue(key));
//}
}
@Override
public void onMessage(String s) {
log.info("接收到消息:" + s);
}
/***检测到连接关闭之后,会更新连接状态以及尝试重新连接***/
@Override
public void onClose(int i, String s, boolean b) {
log.info("连接关闭 {}", s);
setConnectState(false);
//rec();
}
/***检测到错误,更新连接状态***/
@Override
public void onError(Exception e) {
log.error("连接错误 ", e);
setConnectState(false);
//rec();
}
/**
* 重连
*/
@Override
public void reconnect() {
super.reconnect();
log.warn("重连中 .................");
}
//public void rec() {
// try{
// if(!getConnectState()){
// int i = 0;
// while(true){
// reconnect();
// if(getConnectState()){
// log.info("连接成功");
// break;
// }
// Thread.sleep(1000);
// }
// }
// }catch (InterruptedException e){
// log.error("",e);
// }
//}
public static void main(String[] args) throws InterruptedException, IOException {
String username = "aaa" + 11;
String uri = "ws://127.0.0.1:10010/websocket/" + username;
MyWebSocketClient myWebSocketClient = new MyWebSocketClient(URI.create(uri));
//必须延时否则可能会无法收到消息
//Thread.sleep(1000);
log.info(uri);
String s = "";
File file = new File("F:\\OS\\cn_windows_server_2019_essentials_x64_dvd_5b386b0b.iso");
//获得要发送文件的长度
long length = file.length();
Message message = new Message();
message.setMsgType(ClientConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_START);
message.setFrom(username);
message.setTo("server");
Map map = new HashMap();
map.put("fileName",file.getName());
map.put("fileSize",length);
map.put("uploadPath", "F:\\");
message.setMsg(map);
s = JSONUtil.toJsonStr(message);
log.info("发送开始: " + s);
myWebSocketClient.send(s);
Message m = new Message();
byte[] buf = new byte[1024];
FileInputStream fileInputStream = new FileInputStream(file.getPath());
int read = 0;
long i = 0;
int passedlen = 0;
while ((read = fileInputStream.read(buf)) != -1) {
m.setMsgType(ClientConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_UPLOADING);
m.setFrom(username);
m.setTo("server");
Map fileMap = new HashMap();
//byte b[], int off, int len
fileMap.put("byte",buf);
fileMap.put("off",0);
fileMap.put("len",read);
m.setMsg(fileMap);
s = JSONUtil.toJsonStr(m);
myWebSocketClient.send(s);
passedlen += read;
if (i < passedlen * 100L / length) {
i = (passedlen * 100L / length);
log.info("已经发送文件: " + i + "%");
}
//不加延时传送大文件会出异常原因未知 小文件可以注释
//Thread.sleep(30);
}
m.setMsgType(ClientConstant.WEB_SOCKET_TYPE.FILE_UPLOAD_COMPLETED);
m.setFrom(username);
m.setTo("server");
m.setMsg("");
s = JSONUtil.toJsonStr(m);
myWebSocketClient.send(s);
log.info("发送完成: " + file.getPath());
//while (true) {
// if(myWebSocketClient.getConnectState()){
// //Message message = new Message();
// //message.setMsgType(ClientConstant.WEB_SOCKET_TYPE.PING);
// //message.setFrom(username);
// //s = JSONUtil.toJsonStr(message);
// //myWebSocketClient.send(s);
//
// }
// log.info(s);
// Thread.sleep(3000);
//}
}
}
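<dependency>
    <groupId>org.java-websocket</groupId>
    <artifactId>Java-WebSocket</artifactId>
    <version>1.4.0</version>
</dependency>
<dependency>
    <groupId>cn.hutool</groupId>
    <artifactId>hutool-all</artifactId>
    <version>5.7.19</version>
</dependency>
<!-- NOTE: the client class also uses Lombok (@Slf4j), which this list does not declare. -->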
ClientConstant 和 Message 类与服务端相同，这里不再列出。