websocket 应用

websocket-golang 应用

目录结构

websocket 应用_第1张图片

client:客户端

service:服务端

memory:服务端管理客户端长连接信息

domain:通用结构体定义

websocket_test:测试用例

service

package websocket

import (
	"net/http"
	"time"

	"github.com/gorilla/websocket"
)

/*
 * webSocket 服务端
 */

// upgrade turns an incoming HTTP request into a WebSocket connection.
//
// NOTE(review): CheckOrigin unconditionally returns true, so the handshake is
// accepted regardless of the request's Origin; confirm this is intended before
// exposing the endpoint to browsers.
var upgrade = &websocket.Upgrader{
	CheckOrigin: func(_ *http.Request) bool { return true },
}

// Serve is the server-side state kept for one accepted WebSocket connection.
type Serve struct {
	conn       *websocket.Conn // underlying WebSocket connection
	remoteAddr string          // peer address, copied from the HTTP request's RemoteAddr
	identity   string          // client identifier, read from the "identify" request header
	updateTime int64           // Unix time in seconds of the most recent UpdateTime call

	handleMessage func(message *Message, serve *Serve) // callback invoked by readLoop for every decoded message
}

// NewService upgrades the HTTP request to a WebSocket connection, registers
// the connection in the clients manager under the value of the "identify"
// request header, and starts a goroutine that reads messages and hands each
// one to handleMessage.
//
// If the upgrade fails, NewService returns without registering anything. A
// connection already registered under the same identity is closed and removed
// before the new one is added.
func NewService(w http.ResponseWriter, r *http.Request, handleMessage func(message *Message, serve *Serve)) {
	wsConn, err := upgrade.Upgrade(w, r, nil)
	if err != nil {
		return
	}

	// NOTE: the identity is not checked for emptiness here.
	srv := &Serve{
		conn:          wsConn,
		remoteAddr:    r.RemoteAddr,
		identity:      r.Header.Get("identify"),
		handleMessage: handleMessage,
	}
	// Stamp the connection as active right now.
	srv.UpdateTime()

	manager := GetClientsManager(nil)
	// Evict an earlier connection that used the same identity.
	if previous, exists := manager.GetClient(srv.identity); exists {
		previous.CloseAndDelMemory()
	}
	// Hand the new connection to the manager.
	manager.AddClient(srv.identity, srv)

	go srv.readLoop()
}

// readLoop decodes JSON messages from the connection until a read fails.
// Every successfully decoded message refreshes the activity timestamp and is
// then passed to handleMessage. Any read error, including a normal close,
// ends the loop.
func (s *Serve) readLoop() {
	// Stop immediately if there is no connection to read from.
	for s.conn != nil {
		incoming := new(Message)
		if err := s.conn.ReadJSON(incoming); err != nil {
			return
		}

		s.UpdateTime()
		s.handleMessage(incoming, s)
	}
}

// Send JSON-encodes message and writes it to the connection.
// A write error is discarded, so the caller gets no signal when sending fails.
func (s *Serve) Send(message interface{}) {
	_ = s.conn.WriteJSON(message)
}

// CloseAndDelMemory closes the connection and then removes the entry stored
// under this connection's identity from the clients manager.
func (s *Serve) CloseAndDelMemory() {
	s.Close()

	manager := GetClientsManager(nil)
	manager.DelClient(s.identity)
}

// Close closes the underlying connection. The error returned by the close
// call is discarded.
func (s *Serve) Close() {
	_ = s.conn.Close()
}

// UpdateTime records the current Unix time, in seconds, as the moment of the
// connection's latest activity.
//
// NOTE(review): the field is written without synchronisation while
// checkClients reads it through GetUpdateTime; confirm whether this needs
// atomic access.
func (s *Serve) UpdateTime() {
	now := time.Now()
	s.updateTime = now.Unix()
}

// GetUpdateTime returns the Unix time, in seconds, stored by the most recent
// UpdateTime call.
//
// NOTE(review): the read is not synchronised with UpdateTime; confirm whether
// this needs atomic access.
func (s *Serve) GetUpdateTime() int64 {
	last := s.updateTime
	return last
}

client

package websocket

import (
	"net/http"
	"net/url"

	"github.com/gorilla/websocket"
)

/*
 * websocket 客户端
 */

// Client is the client-side state for one dialed WebSocket connection.
type Client struct {
	conn       *websocket.Conn // underlying WebSocket connection
	identify   string          // identifier sent to the server in the "identify" handshake header
	remoteAddr string          // host (and port) of the server that was dialed

	handleMessage func(*Message, *Client) // callback invoked by readLoop for every decoded message
}

// NewClient dials ws://<remoteIp>/ws, sending identify in the "identify"
// handshake header, and starts a goroutine that reads messages and hands each
// one to handleMessage.
//
// It returns nil when the dial fails; callers must check for that.
func NewClient(remoteIp, identify string, handleMessage func(*Message, *Client)) *Client {
	endpoint := (&url.URL{Scheme: "ws", Host: remoteIp, Path: "/ws"}).String()
	handshakeHeader := http.Header{"identify": {identify}}

	wsConn, _, err := websocket.DefaultDialer.Dial(endpoint, handshakeHeader)
	if err != nil {
		return nil
	}

	cli := &Client{
		conn:          wsConn,
		identify:      identify,
		remoteAddr:    remoteIp,
		handleMessage: handleMessage,
	}

	// Start receiving in the background.
	go cli.readLoop()

	return cli
}

// readLoop decodes JSON messages from the connection and passes each one to
// handleMessage. Any read error, whether an expected close or not, ends the
// loop.
func (cli *Client) readLoop() {
	// Stop immediately if there is no connection to read from.
	for cli.conn != nil {
		incoming := new(Message)
		if err := cli.conn.ReadJSON(incoming); err != nil {
			return
		}

		cli.handleMessage(incoming, cli)
	}
}

// Close closes the underlying connection.
//
// It is a no-op on a nil *Client or a client without a connection: NewClient
// returns nil when dialing fails, and calling Close on that result must not
// panic. The error returned by the close call is discarded.
func (cli *Client) Close() {
	if cli == nil || cli.conn == nil {
		return
	}
	_ = cli.conn.Close()
}

// Send JSON-encodes msg and writes it to the connection.
//
// It is a no-op on a nil *Client or a client without a connection: NewClient
// returns nil when dialing fails, and calling Send on that result must not
// panic. A write error is discarded because the method has no way to report
// it to the caller.
func (cli *Client) Send(msg interface{}) {
	if cli == nil || cli.conn == nil {
		return
	}
	_ = cli.conn.WriteJSON(msg)
}

memory

package websocket

import (
	"sync"
	"time"
)

/*
 * 管理客户端长链接
 */

// cliManager is the package-wide clients manager returned by
// GetClientsManager. The mutex and once fields rely on their zero values.
var cliManager = ClientsManager{
	clientsMap: map[string]*Serve{},
	maxTime:    30, // default: drop a client after more than 30 seconds of silence
	checkTime:  5,  // default: scan the clients every 5 seconds
}

// ClientsManager keeps the registered server-side connections, keyed by
// client identity, and periodically drops the ones that have gone silent.
type ClientsManager struct {
	clientsMap map[string]*Serve // registered connections, keyed by identity
	rwMutex    sync.RWMutex      // guards clientsMap
	once       sync.Once         // makes GetClientsManager's setup run a single time

	maxTime   int64         // longest tolerated silence, in seconds
	checkTime time.Duration // scan interval as a count of seconds; isLive multiplies it by time.Second
}

// GetClientsManager returns the package-wide clients manager.
//
// The first call, and only the first, applies memoryCfg (when non-nil) and
// starts the background goroutine that scans for silent clients. Later calls
// ignore memoryCfg.
//
// Only positive configuration values override the defaults. A zero CheckTime
// from a partially filled MemoryConfig would otherwise reach time.NewTicker,
// which panics on a non-positive interval, and a zero MaxTime would drop every
// client idle for more than zero seconds.
func GetClientsManager(memoryCfg *MemoryConfig) *ClientsManager {
	cliManager.once.Do(func() {
		if memoryCfg != nil {
			// TODO: apply the configuration through functional options.
			if memoryCfg.MaxTime > 0 {
				cliManager.maxTime = memoryCfg.MaxTime
			}
			if memoryCfg.CheckTime > 0 {
				cliManager.checkTime = memoryCfg.CheckTime
			}
		}
		go cliManager.isLive()
	})
	return &cliManager
}

// AddClient stores client under identify, replacing any entry already present
// for that key. The write lock is held for the duration of the update.
func (m *ClientsManager) AddClient(identify string, client *Serve) {
	m.rwMutex.Lock()
	defer m.rwMutex.Unlock()

	m.clientsMap[identify] = client
}

// GetClient looks up the connection registered under identify. ok reports
// whether an entry exists. The lookup runs under the read lock.
func (m *ClientsManager) GetClient(identify string) (client *Serve, ok bool) {
	m.rwMutex.RLock()
	defer m.rwMutex.RUnlock()

	found, present := m.clientsMap[identify]
	return found, present
}

// DelClient removes the entry registered under identify, if there is one.
// It does not close the connection. The write lock is held for the removal.
func (m *ClientsManager) DelClient(identify string) {
	m.rwMutex.Lock()
	defer m.rwMutex.Unlock()

	delete(m.clientsMap, identify)
}

// isLive runs checkClients on every tick, forever. The tick interval is
// checkTime multiplied by time.Second, so checkTime is a count of seconds.
func (m *ClientsManager) isLive() {
	interval := m.checkTime * time.Second

	for range time.NewTicker(interval).C {
		m.checkClients()
	}
}

// checkClients closes and unregisters every client whose last activity is
// more than maxTime seconds old. The write lock is held for the whole scan.
func (m *ClientsManager) checkClients() {
	m.rwMutex.Lock()
	defer m.rwMutex.Unlock()

	for identity, srv := range m.clientsMap {
		idleSeconds := time.Now().Unix() - srv.GetUpdateTime()
		if idleSeconds <= m.maxTime {
			continue
		}

		srv.Close()
		delete(m.clientsMap, identity)
	}
}

domain

package websocket

import "time"

// Message is the JSON envelope exchanged between client and server.
// Both fields are omitted from the encoded JSON when they hold zero values.
type Message struct {
	Type int         `json:"type,omitempty"` // event type
	Data interface{} `json:"data,omitempty"` // message payload
}

// MemoryConfig carries the settings GetClientsManager copies into the clients
// manager on its first call.
type MemoryConfig struct {
	// MaxTime is copied to ClientsManager.maxTime, which checkClients compares
	// against a difference of Unix-second timestamps.
	MaxTime   int64
	// CheckTime is copied to ClientsManager.checkTime, which isLive multiplies
	// by time.Second. Pass a plain count of seconds (5, not 5*time.Second).
	CheckTime time.Duration
}

websocket_test

package websocket

import (
	"log"
	"net/http"
	"testing"
	"time"

	"github.com/google/uuid"
)

// Test_Websocket runs the server side: it registers WebSocketService on /ws
// and serves on :8080. The test fails if the server returns an error.
func Test_Websocket(t *testing.T) {
	http.HandleFunc("/ws", WebSocketService)

	if err := http.ListenAndServe(":8080", nil); err != nil {
		t.Fatal(err)
	}

	t.Log("end")
}

// Test_WebsocketClient runs the client side against a server listening on
// 127.0.0.1:8080: it sends twenty type-1 messages, one every five seconds,
// followed by a single type-2 "end" message.
func Test_WebsocketClient(t *testing.T) {
	c := NewClient("127.0.0.1:8080", uuid.NewString(), handleMessageByClient)
	// NewClient returns nil when the dial fails; fail the test cleanly here
	// instead of panicking on the first Send.
	if c == nil {
		t.Fatal("NewClient returned nil: could not connect to 127.0.0.1:8080")
	}

	for i := 0; i < 20; i++ {
		msg := &Message{
			Type: 1,
			Data: i,
		}
		time.Sleep(time.Second * 5)
		c.Send(msg)
	}

	msg := &Message{
		Type: 2,
		Data: "end",
	}
	c.Send(msg)
}

// handleMessageByClient is the client-side message callback: it logs every
// message and closes the client when the message type is 2.
func handleMessageByClient(message *Message, cli *Client) {
	log.Println("handleMessage:", message)

	if message.Type != 2 {
		return
	}
	cli.Close()
}

// handleMessageByServe is the server-side message callback. A type-1 message
// is logged and acknowledged with a type-1 reply. Any other type is answered
// with a type-2 "end" reply, after which the connection is closed and
// unregistered.
func handleMessageByServe(message *Message, s *Serve) {
	if message.Type != 1 {
		s.Send(&Message{Type: 2, Data: "end"})
		s.CloseAndDelMemory()
		return
	}

	log.Println("handleMessage:", message)
	s.Send(&Message{Type: 1, Data: "收到"})
}

// WebSocketService is the HTTP handler for the WebSocket endpoint.
//
// Requests without an "identify" header are rejected with 400 Bad Request;
// returning without writing anything would give the client an implicit
// 200 OK. All other requests are passed to NewService with
// handleMessageByServe as the message callback.
func WebSocketService(w http.ResponseWriter, r *http.Request) {
	if r.Header.Get("identify") == "" {
		log.Printf("identity: 没有参数")
		http.Error(w, "missing identify header", http.StatusBadRequest)
		return
	}

	// Upgrade the request and start serving the connection.
	NewService(w, r, handleMessageByServe)
}

你可能感兴趣的:(websocket,网络协议,网络,golang)