基于 WebSocket + MongoDB + MySQL + Redis
MySQL
用来存储用户基本信息MongoDB
用来存放用户聊天信息Redis
用来存储处理过期信息
WebSocket
是应用层第七层上的一个应用层协议,它必须依赖 HTTP 协议进行一次握手。
握手成功后,数据就直接从TCP
通道传输,与HTTP
无关了。即:WebSocket
分为握手和数据传输阶段。
即进行了 HTTP 握手 + 双工的 TCP 连接。
全双工通信
的协议。https://goproxy.cn,direct
创建main.go
文件
创建管理依赖包文件
go mod init IM
创建文件夹
ini
驱动go get gopkg.in/ini.v1
redis
驱动go get github.com/go-redis/redis
go get github.com/jinzhu/gorm/dialects/mysql
gorm
go get github.com/jinzhu/gorm
gin
框架go get github.com/gin-gonic/gin
MongoDB
驱动go get go.mongodb.org/mongo-driver/mongo
go get go.mongodb.org/mongo-driver/mongo/options
go get github.com/sirupsen/logrus
导入websocket
go get github.com/gorilla/websocket
创建 conf.go
文件
导入MongoDB
驱动
go get go.mongodb.org/mongo-driver/mongo
go get go.mongodb.org/mongo-driver/mongo/options
导入ini
驱动
go get gopkg.in/ini.v1
conf.go
文件内容:
package conf
import (
"IM/model"
"context"
"fmt"
logging "github.com/sirupsen/logrus"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"gopkg.in/ini.v1"
"strings"
)
var (
MongoDBClient *mongo.Client
AppMode string
HttpPort string
Db string
DbHost string
DbPort string
DbUser string
DbPassword string
DbName string
MongoDBPort string
MongoDBHost string
MongoDBName string
MongoDBPassword string
)
func Init() {
//从本地读取环境
file, err := ini.Load("./conf/config.ini")
if err != nil {
fmt.Println("加载ini文件失败", err)
}
LoadServer(file)
LoadMySQL(file)
LoadMongoDB(file)
MongoDB() //MongoDB连接
path := strings.Join([]string{DbUser, ":", DbPassword, "@tcp(", DbHost, ":", DbPort, ")/", DbName, "?charset=utf8mb4&parseTime=true"}, "")
model.Database(path) //数据库连接
}
// MongoDB连接
func MongoDB() {
clientOptions := options.Client().ApplyURI("mongodb://" + MongoDBHost + ":" + MongoDBPort)
var err error
MongoDBClient, err = mongo.Connect(context.TODO(), clientOptions)
if err != nil {
logging.Info(err)
panic(err)
}
logging.Info("MongoDB 连接成功")
}
func LoadServer(file *ini.File) {
AppMode = file.Section("service").Key("AppMode").String()
HttpPort = file.Section("service").Key("HttpPort").String()
}
func LoadMySQL(file *ini.File) {
Db = file.Section("mysql").Key("Db").String()
DbHost = file.Section("mysql").Key("DbHost").String()
DbPort = file.Section("mysql").Key("DbPort").String()
DbUser = file.Section("mysql").Key("DbUser").String()
DbPassword = file.Section("mysql").Key("DbPassword").String()
DbName = file.Section("mysql").Key("DbName").String()
}
func LoadMongoDB(file *ini.File) {
MongoDBPort = file.Section("MongoDB").Key("MongoDBPort").String()
MongoDBHost = file.Section("MongoDB").Key("MongoDBHost").String()
MongoDBName = file.Section("MongoDB").Key("MongoDBName").String()
MongoDBPassword = file.Section("MongoDB").Key("MongoDBPassword").String()
创建 config.ini
文件
#debug开发模式, release生产模式
[service]
AppMode=debug
HttpPort=:3000
[mysql]
Db=mysql
DbHost=127.0.0.1
DbPort=3306
DbUser=root
DbPassword=123456
DbName=IM
[redis]
RedisDb=redis
RedisHost=127.0.0.1
RedisPort=6379
RedisPassword=123456
RedisDbName=2
[MongoDB]
MongoDBPort=27017
MongoDBHost=localhost
MongoDBName=userV1
MongoDBPassword=root
创建 common.go
文件
导入redis
驱动
go get github.com/go-redis/redis
导入日志包
go get github.com/sirupsen/logrus
common.go
文件内容:
package cache
import (
"fmt"
"github.com/go-redis/redis"
logging "github.com/sirupsen/logrus"
"gopkg.in/ini.v1"
"strconv"
)
var (
RedisClient *redis.Client
RedisDb string
RedisHost string
RedisPort string
RedisPassword string
RedisDbName string
)
func init() {
file, err := ini.Load("./conf/config.ini") //加载配置信息文件
if err != nil {
fmt.Println("加载redis ini文件失败", err)
}
LoadRedis(file) //读取配置信息文件内容
Redis() //连接redis
}
// redis加载
func LoadRedis(file *ini.File) {
RedisDb = file.Section("redis").Key("RedisDb").String()
RedisHost = file.Section("redis").Key("RedisHost").String()
RedisPort = file.Section("redis").Key("RedisPort").String()
RedisPassword = file.Section("redis").Key("RedisPassword").String()
RedisDbName = file.Section("redis").Key("RedisDbName").String()
}
// redis连接
func Redis() {
db, _ := strconv.ParseUint(RedisDbName, 10, 64)
client := redis.NewClient(&redis.Options{
Addr: fmt.Sprintf("%s:%s", RedisHost, RedisPort),
DB: int(db),
Password: RedisPassword,
})
_, err := client.Ping().Result()
if err != nil {
logging.Info(err)
panic(err)
}
RedisClient = client
}
创建 init.go
文件
导入数据库驱动
go get github.com/jinzhu/gorm/dialects/mysql
导入gorm
go get github.com/jinzhu/gorm
导入gin
框架
go get github.com/gin-gonic/gin
init.go
文件内容:
package model
import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
"time"
)
var DB *gorm.DB
func Database(connstring string) {
db, err := gorm.Open("mysql", connstring)
if err != nil {
panic("mysql数据库连接错误")
}
db.LogMode(true)
//如果是发行版就不用输出日志
if gin.Mode() == "release" {
db.LogMode(false)
}
db.SingularTable(true) //表名不加s,user
db.DB().SetMaxIdleConns(20) //设置连接池
db.DB().SetMaxOpenConns(100) //最大连接数
db.DB().SetConnMaxLifetime(time.Second * 30) //连接时间
DB = db
}
创建 common.go
文件
common.go
文件内容:
package serializer
/*
错误信息序列化
*/
// Response 基础序列化器
type Response struct {
Status int `json:"status"`
Data interface{} `json:"data"`
Msg string `json:"msg"`
Error string `json:"error"`
}
创建 user.go
文件
user.go
文件内容:
package model
import (
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
type User struct {
gorm.Model
UserName string `gorm:"unique"`
PassWord string
}
const (
PassWordCost = 12 //密码加密难度
)
// SetPassWord 设置密码
func (user *User) SetPassWord(password string) error {
bytes, err := bcrypt.GenerateFromPassword([]byte(password), PassWordCost)
if err != nil {
return err
}
user.PassWord = string(bytes)
return nil
}
// CheckPassWord 校验密码
func (user *User) CheckPassWord(password string) bool {
err := bcrypt.CompareHashAndPassword([]byte(user.PassWord), []byte(password))
return err == nil
}
创建 migration.go
文件
migration.go
文件内容:
package model
// 迁移
func migration() {
DB.Set("gorm:table_options", "charset=utf8mb4").AutoMigrate(&User{})
}
在model层init.go
最后加migration()
package model
import (
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql"
"time"
)
var DB *gorm.DB
func Database(connstring string) {
db, err := gorm.Open("mysql", connstring)
if err != nil {
panic("mysql数据库连接错误")
}
db.LogMode(true)
//如果是发行版就不用输出日志
if gin.Mode() == "release" {
db.LogMode(false)
}
db.SingularTable(true) //表名不加s,user
db.DB().SetMaxIdleConns(20) //设置连接池
db.DB().SetMaxOpenConns(100) //最大连接数
db.DB().SetConnMaxLifetime(time.Second * 30) //连接时间
DB = db
//迁移
migration()
}
创建 user.go
文件
user.go
文件内容:
package service
import (
"IM/model"
"IM/serializer"
)
type UserRegisterService struct {
UserName string `json:"user_name" form:"user_name"`
PassWord string `json:"password" form:"password"`
}
// 用户注册
func (service *UserRegisterService) Register() serializer.Response {
var user model.User
count := 0
model.DB.Model(&model.User{}).Where("user_name=?", service.UserName).First(&user).Count(&count)
if count != 0 {
return serializer.Response{
Status: 400,
Msg: "用户名已经存在了",
}
}
user = model.User{
UserName: service.UserName,
}
//密码加密
if err := user.SetPassWord(service.PassWord); err != nil {
return serializer.Response{
Status: 500,
Msg: "加密出错了",
}
}
model.DB.Create(&user)
return serializer.Response{
Status: 200,
Msg: "创建成功",
}
}
创建 common.go
文件
common.go
文件内容:
package api
/*
返回错误信息
*/
import (
"IM/serializer"
"encoding/json"
"fmt"
"github.com/go-playground/validator/v10"
)
// 返回错误信息 ErrorResponse
func ErrorResponse(err error) serializer.Response {
if _, ok := err.(validator.ValidationErrors); ok {
return serializer.Response{
Status: 400,
Msg: "错误参数",
Error: fmt.Sprint(err),
}
}
if _, ok := err.(*json.UnmarshalTypeError); ok {
return serializer.Response{
Status: 400,
Msg: "JSON类型不匹配",
Error: fmt.Sprint(err),
}
}
return serializer.Response{
Status: 400,
Msg: "参数错误",
Error: fmt.Sprint(err),
}
}
创建 user.go
文件
user.go
文件内容:
package api
import (
"IM/service"
"github.com/gin-gonic/gin"
logging "github.com/sirupsen/logrus"
)
// 用户注册
func UserRegister(c *gin.Context) {
var userRegisterService service.UserRegisterService
if err := c.ShouldBind(&userRegisterService); err == nil {
res := userRegisterService.Register()
c.JSON(200, res)
} else {
c.JSON(400, ErrorResponse(err))
logging.Info(err)
}
}
创建 router.go
文件
router.go
文件内容:
package router
import (
"IM/api"
"github.com/gin-gonic/gin"
)
func NewRouter() *gin.Engine {
r := gin.Default()
//Recovery 中间件会恢复(recovers) 任何恐慌(panics)
//如果存在恐慌中间件将会写入500
//因为当你程序里有些异常情况你没考虑到的时候,程序就退出了,服务就停止了
//Logger日志
r.Use(gin.Recovery(), gin.Logger())
v1 := r.Group("/")
{
//测试是否成功
v1.GET("ping", func(c *gin.Context) {
c.JSON(200, "成功")
})
//用户注册
v1.POST("user/register", api.UserRegister)
}
return r
}
package main
import (
"IM/conf"
"IM/router"
)
func main() {
//测试初始化
conf.Init()
//启动路由
r := router.NewRouter()
_ = r.Run(conf.HttpPort)
}
导入websocket
go get github.com/gorilla/websocket
创建 ws.go
文件
ws.go
文件内容:
package service
import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
)
const month = 60 * 60 * 24 * 30 //一个月30天
// 发送消息的结构体
type SendMsg struct {
Type int `json:"type"`
Content string `json:"content"`
}
// 回复消息的结构体
type ReplyMsg struct {
From string `json:"from"`
Code int `json:"code"`
Content string `json:"content"`
}
// 用户结构体
type Client struct {
ID string //用户id
SendID string //接收id
Socket *websocket.Conn //Socket连接
Send chan []byte //发送的信息
}
// 广播类(包括广播内容和源用户)
type Broadcast struct {
Client *Client
Message []byte
Type int
}
// 用户管理
type ClientManager struct {
Clients map[string]*Client
Broadcast chan *Broadcast //广播
Reply chan *Client
Register chan *Client //已注册
Unregister chan *Client //未注册
}
// 信息转JSON(包括:发送者、接收者、内容)
type Message struct {
Sender string `json:"sender,omitempty"` //发送者
Recipient string `json:"recipient,omitempty"` //接收者
Content string `json:"content,omitempty"` //内容
}
// 初始化一个全局管理Manager
var Manager = ClientManager{
Clients: make(map[string]*Client), // 参与连接的用户,出于性能的考虑,需要设置最大连接数
Broadcast: make(chan *Broadcast),
Register: make(chan *Client),
Reply: make(chan *Client),
Unregister: make(chan *Client),
}
func CreateID(uid, toUid string) string {
return uid + "->" + toUid //1->2
}
func Handler(c *gin.Context) {
uid := c.Query("uid")
toUid := c.Query("toUid")
conn, err := (&websocket.Upgrader{
//跨域
CheckOrigin: func(r *http.Request) bool {
return true
}}).Upgrade(c.Writer, c.Request, nil) //升级ws协议
if err != nil {
http.NotFound(c.Writer, c.Request)
return
}
//创建一个用户实例
client := &Client{
ID: CreateID(uid, toUid), //发送方 1发送给2
SendID: CreateID(toUid, uid), //接收方 2接收到1
Socket: conn, //Socket连接
Send: make(chan []byte), //发送的信息
}
//用户注册到用户管理上
Manager.Register <- client
go client.Read()
go client.Write()
}
// 读操作
func (c *Client) Read() {
}
// 写操作
func (c *Client) Write() {
}
在router.go
文件中添加:
//升级WebSocket协议
v1.GET("ws", service.Handler)
完整内容:
package router
import (
"IM/api"
"IM/service"
"github.com/gin-gonic/gin"
)
func NewRouter() *gin.Engine {
r := gin.Default()
//Recovery 中间件会恢复(recovers) 任何恐慌(panics)
//如果存在恐慌中间件将会写入500
//因为当你程序里有些异常情况你没考虑到的时候,程序就退出了,服务就停止了
//Logger日志
r.Use(gin.Recovery(), gin.Logger())
v1 := r.Group("/")
{
//测试是否成功
v1.GET("ping", func(c *gin.Context) {
c.JSON(200, "成功")
})
//用户注册
v1.POST("user/register", api.UserRegister)
//升级WebSocket协议
v1.GET("ws", service.Handler)
}
return r
}
创建 code.go
文件
code.go
文件内容:
package e
const (
SUCCESS = 200
UpdatePasswordSuccess = 201 //密码成功
NotExistInentifier = 202 //未绑定
ERROR = 500 //失败
InvalidParams = 400 //请求参数错误
ErrorDatabase = 40001 //数据库操作错误
WebsocketSuccessMessage = 50001 //解析content内容信息
WebsocketSuccess = 50002 //请求历史纪录操作成功
WebsocketEnd = 50003 //请求没有更多历史记录
WebsocketOnlineReply = 50004 //在线应答
WebsocketOfflineReply = 50005 //离线回答
WebsocketLimit = 50006 //请求受到限制
)
创建 msg.go
文件
msg.go
文件内容:
package e
var MsgFlags = map[int]string{
SUCCESS: "ok",
UpdatePasswordSuccess: "修改密码成功",
NotExistInentifier: "该第三方账号未绑定",
ERROR: "失败",
InvalidParams: "请求参数错误",
ErrorDatabase: "数据库操作出错,请重试",
WebsocketSuccessMessage: "解析content内容信息",
WebsocketSuccess: "发送信息,请求历史纪录操作成功",
WebsocketEnd: "请求历史纪录,但没有更多记录了",
WebsocketOnlineReply: "针对回复信息在线应答成功",
WebsocketOfflineReply: "针对回复信息离线回答成功",
WebsocketLimit: "请求受到限制",
}
// GetMsg 获取状态码对应信息
func GetMsg(code int) string {
msg, ok := MsgFlags[code]
if ok {
return msg
}
return MsgFlags[ERROR]
}
ws.go
中的Read()
操作:
// 读操作
func (c *Client) Read() {
//结束时关闭Socket
defer func() {
//用户结构体变成未注册状态
Manager.Unregister <- c
//关闭Socket
_ = c.Socket.Close()
}()
for {
c.Socket.PongHandler()
sendMsg := new(SendMsg)
//序列化
//如果传过来是String类型,用这个接收: c.Socket.ReadMessage()
err := c.Socket.ReadJSON(&sendMsg)
if err != nil {
fmt.Println("数据格式不正确", err)
Manager.Unregister <- c
_ = c.Socket.Close()
break
}
if sendMsg.Type == 1 { // 设置1为发送消息
r1, _ := cache.RedisClient.Get(c.ID).Result() //1->2 查看缓存里有没有发送方id
r2, _ := cache.RedisClient.Get(c.SendID).Result() //2->1 查看缓存里有没有接收方id
if r1 > "3" && r2 == "" { //1给2发消息,发了三条,但是2没有回,或者没有看到,就停止1发送。防止骚扰
replyMsg := ReplyMsg{
Code: e.WebsocketLimit,
Content: e.GetMsg(e.WebsocketLimit),
}
msg, _ := json.Marshal(replyMsg) //序列化
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
continue
} else {
//存储到redis中
cache.RedisClient.Incr(c.ID)
_, _ = cache.RedisClient.Expire(c.ID, time.Hour*24*30*3).Result() //防止过快“分手”,建立连接三个月过期
}
log.Println(c.ID, "发送消息", sendMsg.Content)
//广播出去
Manager.Broadcast <- &Broadcast{
Client: c,
Message: []byte(sendMsg.Content), //发送过来的消息
}
}
}
}
ws.go
中的Write()
操作:
// 写操作
func (c *Client) Write() {
defer func() {
_ = c.Socket.Close()
}()
for {
select {
case message, ok := <-c.Send:
if !ok {
_ = c.Socket.WriteMessage(websocket.CloseMessage, []byte{})
return
}
log.Println(c.ID, "接受消息:", string(message))
replyMsg := ReplyMsg{
Code: e.WebsocketSuccessMessage,
Content: fmt.Sprintf("%s", string(message)),
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
}
全部:
package service
import (
"IM/cache"
"IM/pkg/e"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"log"
"net/http"
"time"
)
const month = 60 * 60 * 24 * 30 //一个月30天
// 发送消息的结构体
type SendMsg struct {
Type int `json:"type"`
Content string `json:"content"`
}
// 回复消息的结构体
type ReplyMsg struct {
From string `json:"from"`
Code int `json:"code"`
Content string `json:"content"`
}
// 用户结构体
type Client struct {
ID string //用户id
SendID string //接收id
Socket *websocket.Conn //Socket连接
Send chan []byte //发送的信息
}
// 广播类(包括广播内容和源用户)
type Broadcast struct {
Client *Client
Message []byte
Type int
}
// 用户管理
type ClientManager struct {
Clients map[string]*Client
Broadcast chan *Broadcast //广播
Reply chan *Client
Register chan *Client //已注册
Unregister chan *Client //未注册
}
// 信息转JSON(包括:发送者、接收者、内容)
type Message struct {
Sender string `json:"sender,omitempty"` //发送者
Recipient string `json:"recipient,omitempty"` //接收者
Content string `json:"content,omitempty"` //内容
}
// 初始化一个全局管理Manager
var Manager = ClientManager{
Clients: make(map[string]*Client), // 参与连接的用户,出于性能的考虑,需要设置最大连接数
Broadcast: make(chan *Broadcast),
Register: make(chan *Client),
Reply: make(chan *Client),
Unregister: make(chan *Client),
}
func CreateID(uid, toUid string) string {
return uid + "->" + toUid //1->2
}
func Handler(c *gin.Context) {
uid := c.Query("uid")
toUid := c.Query("toUid")
conn, err := (&websocket.Upgrader{
//跨域
CheckOrigin: func(r *http.Request) bool {
return true
}}).Upgrade(c.Writer, c.Request, nil) //升级ws协议
if err != nil {
http.NotFound(c.Writer, c.Request)
return
}
//创建一个用户实例
client := &Client{
ID: CreateID(uid, toUid), //发送方 1发送给2
SendID: CreateID(toUid, uid), //接收方 2接收到1
Socket: conn, //Socket连接
Send: make(chan []byte), //发送的信息
}
//用户注册到用户管理上
Manager.Register <- client
go client.Read()
go client.Write()
}
// 读操作
func (c *Client) Read() {
//结束时关闭Socket
defer func() {
//用户结构体变成未注册状态
Manager.Unregister <- c
//关闭Socket
_ = c.Socket.Close()
}()
for {
c.Socket.PongHandler()
sendMsg := new(SendMsg)
//序列化
//如果传过来是String类型,用这个接收: c.Socket.ReadMessage()
err := c.Socket.ReadJSON(&sendMsg)
if err != nil {
fmt.Println("数据格式不正确", err)
Manager.Unregister <- c
_ = c.Socket.Close()
break
}
if sendMsg.Type == 1 { // 设置1为发送消息
r1, _ := cache.RedisClient.Get(c.ID).Result() //1->2 查看缓存里有没有发送方id
r2, _ := cache.RedisClient.Get(c.SendID).Result() //2->1 查看缓存里有没有接收方id
if r1 > "3" && r2 == "" { //1给2发消息,发了三条,但是2没有回,或者没有看到,就停止1发送。防止骚扰
replyMsg := ReplyMsg{
Code: e.WebsocketLimit,
Content: e.GetMsg(e.WebsocketLimit),
}
msg, _ := json.Marshal(replyMsg) //序列化
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
continue
} else {
//存储到redis中
cache.RedisClient.Incr(c.ID)
_, _ = cache.RedisClient.Expire(c.ID, time.Hour*24*30*3).Result() //防止过快“分手”,建立连接三个月过期
}
log.Println(c.ID, "发送消息", sendMsg.Content)
//广播出去
Manager.Broadcast <- &Broadcast{
Client: c,
Message: []byte(sendMsg.Content), //发送过来的消息
}
}
}
}
// 写操作
func (c *Client) Write() {
defer func() {
_ = c.Socket.Close()
}()
for {
select {
case message, ok := <-c.Send:
if !ok {
_ = c.Socket.WriteMessage(websocket.CloseMessage, []byte{})
return
}
log.Println(c.ID, "接受消息:", string(message))
replyMsg := ReplyMsg{
Code: e.WebsocketSuccessMessage,
Content: fmt.Sprintf("%s", string(message)),
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
}
创建 start.go
文件
start.go
文件内容:
package service
import (
"IM/pkg/e"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
func (manager *ClientManager) Start() {
for {
fmt.Println("<---监听管道通信--->")
select {
case conn := <-Manager.Register: // 建立连接
fmt.Printf("建立新连接: %v", conn.ID)
Manager.Clients[conn.ID] = conn //把连接放到用户管理上
replyMsg := ReplyMsg{
Code: e.WebsocketSuccess,
Content: "已连接至服务器",
}
msg, _ := json.Marshal(replyMsg)
_ = conn.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
}
添加go service.Manager.Start()
package main
import (
"IM/conf"
"IM/router"
"IM/service"
)
func main() {
//测试初始化
conf.Init()
//监听管道
go service.Manager.Start()
//启动路由
r := router.NewRouter()
_ = r.Run(conf.HttpPort)
}
创建ws文件夹
创建 trainer.go
文件
trainer.go
文件内容:
package ws
// 插入进MongoDB的数据类型
type Trainer struct {
Content string `bson:"content"` // 内容
StartTime int64 `bson:"startTime"` // 创建时间
EndTime int64 `bson:"endTime"` // 过期时间
Read uint `bson:"read"` // 已读
}
创建 find.go
文件
find.go
文件内容:
package service
import (
"IM/conf"
"IM/model/ws"
"context"
"time"
)
func InsertMsg(database, id string, content string, read uint, expire int64) error {
//插入到mongoDB中
collection := conf.MongoDBClient.Database(database).Collection(id) //没有这个id集合的话,创建这个id集合
comment := ws.Trainer{
Content: content,
StartTime: time.Now().Unix(),
EndTime: time.Now().Unix() + expire,
Read: read,
}
_, err := collection.InsertOne(context.TODO(), comment)
return err
}
添加断开连接和广播功能
package service
import (
"IM/conf"
"IM/pkg/e"
"encoding/json"
"fmt"
"github.com/gorilla/websocket"
)
func (manager *ClientManager) Start() {
for {
fmt.Println("<---监听管道通信--->")
select {
case conn := <-Manager.Register: // 建立连接
fmt.Printf("建立新连接: %v", conn.ID)
Manager.Clients[conn.ID] = conn //把连接放到用户管理上
replyMsg := &ReplyMsg{
Code: e.WebsocketSuccess,
Content: "已连接至服务器",
}
msg, _ := json.Marshal(replyMsg)
_ = conn.Socket.WriteMessage(websocket.TextMessage, msg)
case conn := <-Manager.Unregister: //断开连接
fmt.Printf("连接失败%s", conn.ID)
if _, ok := Manager.Clients[conn.ID]; ok {
replyMsg := &ReplyMsg{
Code: e.WebsocketEnd,
Content: "连接中断",
}
msg, _ := json.Marshal(replyMsg)
_ = conn.Socket.WriteMessage(websocket.TextMessage, msg)
close(conn.Send)
delete(Manager.Clients, conn.ID)
}
case broadcast := <-Manager.Broadcast: //1->2
message := broadcast.Message
sendId := broadcast.Client.SendID //2->1
flag := false //默认对方是不在线的
for id, conn := range Manager.Clients {
if id != sendId {
continue
}
select {
case conn.Send <- message:
flag = true
default:
close(conn.Send)
delete(Manager.Clients, conn.ID)
}
}
id := broadcast.Client.ID //1->2
if flag {
fmt.Println("对方在线")
replyMsg := &ReplyMsg{
Code: e.WebsocketOnlineReply,
Content: "对方在线应答",
}
msg, _ := json.Marshal(replyMsg)
_ = broadcast.Client.Socket.WriteMessage(websocket.TextMessage, msg)
/*
把消息插入到MongoDB中:
1代表已读(只要用户在线就判断已读)
int64(3*month):过期时间
*/
err := InsertMsg(conf.MongoDBName, id, string(message), 1, int64(3*month))
if err != nil {
fmt.Println("插入一条消息失败", err)
}
} else {
fmt.Println("对方不在线")
replyMsg := &ReplyMsg{
Code: e.WebsocketOfflineReply,
Content: "对方不在线应答",
}
msg, err := json.Marshal(replyMsg)
_ = broadcast.Client.Socket.WriteMessage(websocket.TextMessage, msg)
err = InsertMsg(conf.MongoDBName, id, string(message), 0, int64(3*month))
if err != nil {
fmt.Println("插入一条消息失败", err)
}
}
}
}
}
trainer.go
文件内容:
package ws
// 插入进MongoDB的数据类型
type Trainer struct {
Content string `bson:"content"` // 内容
StartTime int64 `bson:"startTime"` // 创建时间
EndTime int64 `bson:"endTime"` // 过期时间
Read uint `bson:"read"` // 已读
}
type Result struct {
StartTime int64
Msg string
Content interface{}
From string
}
find.go
文件内容:
package service
import (
"IM/conf"
"IM/model/ws"
"context"
"fmt"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/options"
"sort"
"time"
)
// 排序用的结构体
type SendSortMsg struct {
Content string `json:"content"`
Read uint `json:"read"`
CreateAt int64 `json:"create_at"`
}
// 插入数据到mongoDB中
func InsertMsg(database, id string, content string, read uint, expire int64) error {
//插入到mongoDB中
collection := conf.MongoDBClient.Database(database).Collection(id) //没有这个id集合的话,创建这个id集合
comment := ws.Trainer{
Content: content,
StartTime: time.Now().Unix(),
EndTime: time.Now().Unix() + expire,
Read: read,
}
_, err := collection.InsertOne(context.TODO(), comment)
return err
}
// 获取历史消息
func FindMany(database string, sendId string, id string, time int64, pageSize int) (results []ws.Result, err error) {
var resultsMe []ws.Trainer //id
var resultsYou []ws.Trainer //sendId
sendIdCollection := conf.MongoDBClient.Database(database).Collection(sendId)
idCollection := conf.MongoDBClient.Database(database).Collection(id)
sendIdTimeCursor, err := sendIdCollection.Find(context.TODO(),
//顺序执行
bson.D{},
//限制大小
options.Find().SetLimit(int64(pageSize)))
idTimeCursor, err := idCollection.Find(context.TODO(),
//顺序执行
bson.D{},
//限制大小
options.Find().SetLimit(int64(pageSize)))
//sort.Slice(results, func(i, j int) bool { return results[i].StartTime < results[j].StartTime })
err = idTimeCursor.All(context.TODO(), &resultsMe) // Id 发给对面的
err = sendIdTimeCursor.All(context.TODO(), &resultsYou) // sendId 对面发过来的
results, _ = AppendAndSort(resultsMe, resultsYou)
return
}
func AppendAndSort(resultsMe, resultsYou []ws.Trainer) (results []ws.Result, err error) {
for _, r := range resultsMe {
sendSort := SendSortMsg{ //构造返回的msg
Content: r.Content,
Read: r.Read,
CreateAt: r.StartTime,
}
result := ws.Result{ //构造返回所有的内容,包括传送者
StartTime: r.StartTime,
Msg: fmt.Sprintf("%v", sendSort),
From: "me",
}
results = append(results, result)
}
for _, r := range resultsYou {
sendSort := SendSortMsg{
Content: r.Content,
Read: r.Read,
CreateAt: r.StartTime,
}
result := ws.Result{
StartTime: r.StartTime,
Msg: fmt.Sprintf("%v", sendSort),
From: "you",
}
results = append(results, result)
}
// 进行排序
sort.Slice(results, func(i, j int) bool { return results[i].StartTime < results[j].StartTime })
return results, nil
}
在读操作里面写历史消息
ws.go
文件增加内容:
else if sendMsg.Type == 2 { //拉取历史消息
timeT, err := strconv.Atoi(sendMsg.Content) // string转int64
if err != nil {
timeT = 999999999
}
results, _ := FindMany(conf.MongoDBName, c.SendID, c.ID, int64(timeT), 10) //获取10条历史消息
//大于10条消息
if len(results) > 10 {
results = results[:10]
} else if len(results) == 0 { //0条信息
replyMsg := ReplyMsg{
Code: e.WebsocketEnd,
Content: "到底了",
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
continue
}
//如果是1到10条消息时
for _, result := range results {
replyMsg := ReplyMsg{
From: result.From,
Content: result.Msg,
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
ws.go
文件完整内容:
package service
import (
"IM/cache"
"IM/conf"
"IM/pkg/e"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"log"
"net/http"
"strconv"
"time"
)
const month = 60 * 60 * 24 * 30 //一个月30天
// 发送消息的结构体
type SendMsg struct {
Type int `json:"type"`
Content string `json:"content"`
}
// 回复消息的结构体
type ReplyMsg struct {
From string `json:"from"`
Code int `json:"code"`
Content string `json:"content"`
}
// 用户结构体
type Client struct {
ID string //用户id
SendID string //接收id
Socket *websocket.Conn //Socket连接
Send chan []byte //发送的信息
}
// 广播类(包括广播内容和源用户)
type Broadcast struct {
Client *Client
Message []byte
Type int
}
// 用户管理
type ClientManager struct {
Clients map[string]*Client
Broadcast chan *Broadcast //广播
Reply chan *Client
Register chan *Client //已注册
Unregister chan *Client //未注册
}
// 信息转JSON(包括:发送者、接收者、内容)
type Message struct {
Sender string `json:"sender,omitempty"` //发送者
Recipient string `json:"recipient,omitempty"` //接收者
Content string `json:"content,omitempty"` //内容
}
// 初始化一个全局管理Manager
var Manager = ClientManager{
Clients: make(map[string]*Client), // 参与连接的用户,出于性能的考虑,需要设置最大连接数
Broadcast: make(chan *Broadcast),
Register: make(chan *Client),
Reply: make(chan *Client),
Unregister: make(chan *Client),
}
func CreateID(uid, toUid string) string {
return uid + "->" + toUid //1->2
}
func Handler(c *gin.Context) {
uid := c.Query("uid")
toUid := c.Query("toUid")
conn, err := (&websocket.Upgrader{
//跨域
CheckOrigin: func(r *http.Request) bool {
return true
}}).Upgrade(c.Writer, c.Request, nil) //升级ws协议
if err != nil {
http.NotFound(c.Writer, c.Request)
return
}
//创建一个用户实例
client := &Client{
ID: CreateID(uid, toUid), //发送方 1发送给2
SendID: CreateID(toUid, uid), //接收方 2接收到1
Socket: conn, //Socket连接
Send: make(chan []byte), //发送的信息
}
//用户注册到用户管理上
Manager.Register <- client
go client.Read()
go client.Write()
}
// 读操作
func (c *Client) Read() {
//结束时关闭Socket
defer func() {
//用户结构体变成未注册状态
Manager.Unregister <- c
//关闭Socket
_ = c.Socket.Close()
}()
for {
c.Socket.PongHandler()
sendMsg := new(SendMsg)
//序列化
//如果传过来是String类型,用这个接收: c.Socket.ReadMessage()
err := c.Socket.ReadJSON(&sendMsg)
if err != nil {
fmt.Println("数据格式不正确", err)
Manager.Unregister <- c
_ = c.Socket.Close()
break
}
if sendMsg.Type == 1 { // 设置1为发送消息
r1, _ := cache.RedisClient.Get(c.ID).Result() //1->2 查看缓存里有没有发送方id
r2, _ := cache.RedisClient.Get(c.SendID).Result() //2->1 查看缓存里有没有接收方id
if r1 > "3" && r2 == "" { //1给2发消息,发了三条,但是2没有回,或者没有看到,就停止1发送。防止骚扰
replyMsg := ReplyMsg{
Code: e.WebsocketLimit,
Content: e.GetMsg(e.WebsocketLimit),
}
msg, _ := json.Marshal(replyMsg) //序列化
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
continue
} else {
//存储到redis中
cache.RedisClient.Incr(c.ID)
_, _ = cache.RedisClient.Expire(c.ID, time.Hour*24*30*3).Result() //防止过快“分手”,建立连接三个月过期
}
log.Println(c.ID, "发送消息", sendMsg.Content)
//广播出去
Manager.Broadcast <- &Broadcast{
Client: c,
Message: []byte(sendMsg.Content), //发送过来的消息
}
} else if sendMsg.Type == 2 { //拉取历史消息
timeT, err := strconv.Atoi(sendMsg.Content) // string转int64
if err != nil {
timeT = 999999999
}
results, _ := FindMany(conf.MongoDBName, c.SendID, c.ID, int64(timeT), 10) //获取10条历史消息
//大于10条消息
if len(results) > 10 {
results = results[:10]
} else if len(results) == 0 { //0条信息
replyMsg := ReplyMsg{
Code: e.WebsocketEnd,
Content: "到底了",
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
continue
}
//如果是1到10条消息时
for _, result := range results {
replyMsg := ReplyMsg{
From: result.From,
Content: result.Msg,
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
}
}
// 写操作
func (c *Client) Write() {
defer func() {
_ = c.Socket.Close()
}()
for {
select {
case message, ok := <-c.Send:
if !ok {
_ = c.Socket.WriteMessage(websocket.CloseMessage, []byte{})
return
}
log.Println(c.ID, "接受消息:", string(message))
replyMsg := ReplyMsg{
Code: e.WebsocketSuccessMessage,
Content: fmt.Sprintf("%s", string(message)),
}
msg, _ := json.Marshal(replyMsg)
_ = c.Socket.WriteMessage(websocket.TextMessage, msg)
}
}
}