有一段时间没写博客了,现在因为工作需要学习了go,做一个小Demo,简单的聊天室。
首先需要一个消息传送的protocol,因为是基于TCP协议,传输过程中会有粘包和拆包问题,因此定义一个协议来保证数据完整性。
package message
const (
LoginMesType = "LoginMes"
RegisterMesType = "RegisterMesType"
MesType = "MesType"
)
type Message struct {
Type string `json:"type"` // 消息类型
Data string `json:"data"` // 消息内容
}
type LoginMes struct {
UserId int `json:"userId"` // 用户id
UserPwd string `json:"userPwd"` // 用户密码
UserName string `json:"userName"` // 用户名
}
这个包是用来编码和解码的作用,当发送数据时需要分为2段编码,然后转为字节进行发送。接收数据时,通过反序列化直接就可以获得数据。
package coding
import (
"bufio"
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
)
func Encode(i interface{}) (b []byte, err error) {
//将message进行序列化
data, err := json.Marshal(i)
if err != nil {
fmt.Println("json Marshal error")
return
}
//这个时候data就是要发送的消息
var pkgLen uint32
var pkg = new(bytes.Buffer)
pkgLen = uint32(len(data))
// 写入头
err = binary.Write(pkg, binary.LittleEndian, pkgLen)
// 写入体
err = binary.Write(pkg, binary.LittleEndian, data)
b = pkg.Bytes()
return
}
func Decode(reader *bufio.Reader) (b []byte, err error) {
// 读取消息
preBytes, err := reader.Peek(4)
var length uint32
err = binary.Read(bytes.NewBuffer(preBytes), binary.LittleEndian, &length)
if err != nil {
fmt.Println("读取头4字节失败", err)
return
}
if uint32(reader.Buffered()) < length+4 {
return []byte{}, fmt.Errorf("读取到%d,应为%d", reader.Buffered(), length+4)
}
fmt.Println("body长度为", length)
p := make([]byte, length+4)
_, err = reader.Read(p)
if err != nil {
fmt.Println("读取头数据体失败")
return
}
b = p[4:]
return
}
服务器主要代码,因为是个简单demo,没有考虑很多。
package main
import (
"bufio"
"coding"
"encoding/json"
"fmt"
"message"
"net"
"redisCli"
"strconv"
)
var r *redisCli.RedisCli
var online map[net.Conn]message.LoginMes
func init() {
r, _ = redisCli.NewRedisCli("tcp", "127.0.0.1:6379")
online = make(map[net.Conn]message.LoginMes, 10)
}
func handleLogin(b []byte) (loginMes message.LoginMes) {
_ = json.Unmarshal(b, &loginMes)
//使用redis操作是否存在该账号
str, err := r.GetString(strconv.Itoa(loginMes.UserId) + ":" + loginMes.UserPwd)
if err != nil {
return
}
//if str == "" {
// fmt.Println("登录失败")
// return
//}
loginMes.UserName = str
return
}
func handleRegister(b []byte) {
}
func process(conn net.Conn) {
// 读取客户端发送的信息
reader := bufio.NewReader(conn)
mes := message.Message{}
b, err := coding.Decode(reader)
if err != nil {
fmt.Println("decode失败")
return
}
err = json.Unmarshal(b, &mes)
if err != nil {
fmt.Println("反序列化失败")
return
}
dataBytes := []byte(mes.Data)
switch mes.Type {
case message.LoginMesType: // 登录消息
loginMes := handleLogin(dataBytes)
online[conn] = loginMes
mes := message.Message{
Type: message.MesType,
Data: loginMes.UserName + "加入群聊啦",
}
b, err := coding.Encode(mes)
if err != nil {
fmt.Println("序列化失败")
return
}
for c, _ := range online {
_, _ = c.Write(b)
}
case message.RegisterMesType: // 注册消息
handleRegister(dataBytes)
}
}
func main() {
fmt.Println("服务器从8889端口监听")
listen, err := net.Listen("tcp", "0.0.0.0:8889")
if err != nil {
fmt.Println("net listen error")
return
}
//监听成功等待
for {
conn, err := listen.Accept()
if err != nil {
fmt.Println("net client error")
return
}
//连接成功就保持通讯
go process(conn)
}
}
客户端代码
package main
import (
"bufio"
"coding"
"context"
"fmt"
"message"
"sync"
)
var (
userId int
userPwd string
)
var wg sync.WaitGroup
func main() {
// 接收用户的选择
var key int
// 判断是否还继续显示菜单
var loop = true
for loop {
fmt.Println("----------欢迎登录多人聊天系统----------")
fmt.Println("\t\t\t 1 登录聊天系统")
fmt.Println("\t\t\t 2 注册用户")
fmt.Println("\t\t\t 3 退出系统")
fmt.Println("\t\t\t 请选择(1-3):")
_, _ = fmt.Scanf("%d\n", &key)
switch key {
case 1:
fmt.Println("登录聊天系统")
loop = false
case 2:
fmt.Println("注册用户")
case 3:
fmt.Println("退出系统")
loop = false
default:
fmt.Println("输入有误 请重新输入")
}
}
if key == 1 {
// 说明用户要登录
fmt.Println("请输入用户id")
_, _ = fmt.Scanf("%d\n", &userId)
fmt.Println("请输入用户密码")
_, _ = fmt.Scanf("%s\n", &userPwd)
// 先把登录的函数写到另一个文件中
conn, err := login(userId, userPwd)
if err != nil {
fmt.Println("登录失败了")
} else {
fmt.Println("登录成功了")
// 这里需要监听服务器发来的信息
ctx, cancel := context.WithCancel(context.Background())
wg.Add(2)
// 开启接收端进程
go func(ctx context.Context, cancel context.CancelFunc) {
fmt.Println("开启接收进程")
defer wg.Done()
defer cancel()
for {
select {
case <-ctx.Done():
fmt.Println("退出接收进程")
return
default:
b, err := coding.Decode(bufio.NewReader(conn))
if err != nil {
fmt.Println("退出接收进程, err:", err)
return
}
str := string(b)
fmt.Println("收到消息:", str)
if "exit" == str {
return
}
}
}
}(ctx, cancel)
// 开启发送端进程
go func(ctx context.Context) {
defer wg.Done()
fmt.Println("开启写入进程")
for {
select {
case <-ctx.Done():
fmt.Println("退出写入进程")
return
default:
var str string
_, _ = fmt.Scanf("%s\n", &str)
mes := message.Message{
Type: message.MesType,
Data: str,
}
b, err := coding.Encode(mes)
if err != nil {
fmt.Println("序列化失败")
continue
}
_, _ = conn.Write(b)
}
}
}(ctx)
wg.Wait()
}
} else if key == 2 {
fmt.Println("进行用户注册")
}
}
// 写一个函数,完成登录
func login(userId int, userPwd string) (conn net.Conn, err error) {
// 开始定协议
conn, err = net.Dial("tcp", "localhost:8889")
if err != nil {
fmt.Println("net Dial error")
return
}
// 准备conn发送消息
var mes message.Message
mes.Type = message.LoginMesType
// 创建一个LoginMes 结构体
var loginMes message.LoginMes
loginMes.UserId = userId
loginMes.UserPwd = userPwd
b, err := json.Marshal(loginMes)
if err != nil {
fmt.Println("json Marshal error")
return
}
mes.Data = string(b)
data, err := coding.Encode(mes)
if err != nil {
fmt.Println("data Marshal error")
return
}
_, _ = conn.Write(data)
return
}
效果: