.net Core 中间件实现WebSocket通讯

文章目录

  • 前言
  • 一、WebSocket是什么?
  • 二、.net Core 中使用WebSocket
    • 1.创建保存 WebSocket 的类
    • 2.创建管理和操作 WebSocket 的基类
    • 3.创建 WebSocket 的中间件
    • 4.创建 WebSocket 管理子类
    • 5.创建注入扩展
    • 6.配置 Startup.cs
    • 7.测试


前言

最近用到了 WebSocket 通讯功能,在此记录一下。


一、WebSocket是什么?

WebSocket 是一种网络通信协议,是在两个端点(通常是浏览器和服务器)之间通过一条长连接进行全双工数据传输的约定和规范。

HTTP 协议存在一个缺陷:通信只能由客户端发起,服务器无法主动向客户端推送消息。
因此,如果服务器的状态在持续变化,客户端想及时获取这些信息就非常麻烦,只能采用“轮询”:每隔一段时间就发出一次询问,了解服务器有没有新的信息。这样做效率低下,也浪费资源(因为每查看一次服务器是否有新消息,都要发起一次完整的 HTTP 请求,甚至重新建立 TCP 连接)。
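WebSocket 是一个独立的协议,它只是借用 HTTP 协议完成最初的握手:客户端发送带 Upgrade: websocket 头的请求,服务器返回 101 Switching Protocols。此后双方在同一条 TCP 连接上都可以主动发送数据,服务器也就能直接推送消息了。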
注意:浏览器不会对 WebSocket 连接施加同源策略,客户端可以与任意服务器通信,因此不存在 CORS 意义上的跨域问题。但这也意味着服务器需要自行校验握手请求中的 Origin 请求头,以防止跨站 WebSocket 劫持。

二、.net Core 中使用WebSocket

1.创建保存 WebSocket 的类

该类用于保存所有的 WebSocket 连接

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

namespace WebAPI.Socket
{
    /// <summary>
    /// WebSocket connection registry: maps connection ids to sockets and keeps optional named groups of ids.
    /// </summary>
    public class WebSocketConnectionManager
    {
        /// <summary>
        /// Connection pool: connection id -> socket.
        /// </summary>
        private readonly ConcurrentDictionary<string, WebSocket> _sockets = new ConcurrentDictionary<string, WebSocket>();

        /// <summary>
        /// Group id -> member connection ids. <see cref="List{T}"/> is not thread-safe, so every read or write
        /// of a member list happens under a lock on that list, and the list itself is never handed out.
        /// </summary>
        private readonly ConcurrentDictionary<string, List<string>> _groups = new ConcurrentDictionary<string, List<string>>();

        /// <summary>
        /// Gets the socket registered under <paramref name="id"/>.
        /// </summary>
        /// <param name="id">Connection id; may be null.</param>
        /// <returns>The socket, or null when <paramref name="id"/> is null or not registered.</returns>
        public WebSocket GetSocketById(string id)
        {
            // ConcurrentDictionary rejects null keys with ArgumentNullException, and ids can originate from
            // client-supplied JSON, so treat null as "not found".
            if (id == null)
                return null;

            return _sockets.TryGetValue(id, out WebSocket socket) ? socket : null;
        }

        /// <summary>
        /// Returns the live connection dictionary (not a copy).
        /// </summary>
        /// <returns>All registered connections keyed by id.</returns>
        public ConcurrentDictionary<string, WebSocket> GetAll()
        {
            return _sockets;
        }

        /// <summary>
        /// Finds the id under which <paramref name="socket"/> is registered (linear scan over all connections).
        /// </summary>
        /// <param name="socket">The socket to look up.</param>
        /// <returns>The id, or null if the socket is not registered.</returns>
        public string GetId(WebSocket socket)
        {
            return _sockets.FirstOrDefault(p => p.Value == socket).Key;
        }

        /// <summary>
        /// Registers <paramref name="socket"/> under a newly generated GUID id.
        /// </summary>
        /// <param name="socket">The socket to register.</param>
        public void AddSocket(WebSocket socket)
        {
            _sockets.TryAdd(CreateConnectionId(), socket);
        }

        /// <summary>
        /// Registers <paramref name="socket"/> under <paramref name="socketID"/>. If the id is already in use,
        /// the call is ignored and the existing registration is kept.
        /// </summary>
        /// <param name="socketID">Connection id to register under.</param>
        /// <param name="socket">The socket to register.</param>
        public void AddSocket(string socketID, WebSocket socket)
        {
            _sockets.TryAdd(socketID, socket);
        }

        /// <summary>
        /// Unregisters the socket with id <paramref name="id"/> and closes it with a normal-closure status.
        /// </summary>
        /// <param name="id">Connection id; null is ignored.</param>
        public async Task RemoveSocket(string id)
        {
            if (id == null)
                return;

            if (!_sockets.TryRemove(id, out WebSocket socket))
                return;

            // Open: we start the close handshake. CloseReceived: the peer already sent its close frame and is
            // waiting for ours; skipping CloseAsync in that state would leave the handshake unfinished.
            if (socket.State != WebSocketState.Open && socket.State != WebSocketState.CloseReceived)
                return;

            await socket.CloseAsync(closeStatus: WebSocketCloseStatus.NormalClosure,
                                    statusDescription: "Closed by the WebSocketManager",
                                    cancellationToken: CancellationToken.None).ConfigureAwait(false);
        }

        /// <summary>
        /// Creates a new connection id (a GUID string).
        /// </summary>
        private string CreateConnectionId()
        {
            return Guid.NewGuid().ToString();
        }

        /// <summary>
        /// Gets the number of registered connections.
        /// </summary>
        public int GetSocketClientCount()
        {
            return _sockets.Count;
        }

        /// <summary>
        /// Returns a snapshot of the connection ids in group <paramref name="GroupID"/>.
        /// </summary>
        /// <param name="GroupID">Group id.</param>
        /// <returns>A copy of the member ids, or null if the group does not exist.</returns>
        public List<string> GetAllFromGroup(string GroupID)
        {
            if (!_groups.TryGetValue(GroupID, out List<string> members))
                return null;

            // Copy under the lock so callers can enumerate while other connections join or leave the group.
            lock (members)
            {
                return new List<string>(members);
            }
        }

        /// <summary>
        /// Adds <paramref name="socketID"/> to group <paramref name="groupID"/>, creating the group if needed.
        /// </summary>
        /// <param name="socketID">Connection id to add.</param>
        /// <param name="groupID">Group id.</param>
        public void AddToGroup(string socketID, string groupID)
        {
            // GetOrAdd makes "create the group if missing" atomic. A separate ContainsKey + TryAdd could drop
            // an id when two connections create the same group at the same time.
            List<string> members = _groups.GetOrAdd(groupID, _ => new List<string>());
            lock (members)
            {
                members.Add(socketID);
            }
        }

        /// <summary>
        /// Removes <paramref name="socketID"/> from group <paramref name="groupID"/> if both exist.
        /// </summary>
        /// <param name="socketID">Connection id to remove.</param>
        /// <param name="groupID">Group id.</param>
        public void RemoveFromGroup(string socketID, string groupID)
        {
            if (_groups.TryGetValue(groupID, out List<string> members))
            {
                lock (members)
                {
                    members.Remove(socketID);
                }
            }
        }
    }
}

2.创建管理和操作 WebSocket 的基类

该类旨在处理 socket 的连接和断连,以及接收和发送消息,属于基类。

using Newtonsoft.Json;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Net.WebSockets;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace WebAPI.Socket
{
    /// <summary>
    /// Abstract base class for WebSocket handlers. It registers and unregisters connections in a
    /// <see cref="WebSocketConnectionManager"/> and sends <c>Message</c> objects to sockets as JSON text frames.
    /// </summary>
    public abstract class WebSocketHandler
    {
        /// <summary>
        /// One send lock per socket. A WebSocket supports only one outstanding send at a time, but messages for
        /// a given socket can be produced concurrently by the receive loops of other connections (broadcasts,
        /// group sends, notifications). The table holds its keys weakly, so entries go away with the socket.
        /// </summary>
        private static readonly System.Runtime.CompilerServices.ConditionalWeakTable<WebSocket, SemaphoreSlim> SendLocks =
            new System.Runtime.CompilerServices.ConditionalWeakTable<WebSocket, SemaphoreSlim>();

        /// <summary>
        /// Connection registry used by this handler.
        /// </summary>
        protected WebSocketConnectionManager WebSocketConnectionManager { get; set; }

        /// <summary>
        /// Creates the handler.
        /// </summary>
        /// <param name="webSocketConnectionManager">Registry that stores this handler's connections.</param>
        public WebSocketHandler(WebSocketConnectionManager webSocketConnectionManager)
        {
            WebSocketConnectionManager = webSocketConnectionManager;
        }

        /// <summary>
        /// Gets the socket registered under <paramref name="socketId"/>, or null.
        /// The misspelled name is kept for compatibility with existing callers and overrides.
        /// </summary>
        /// <param name="socketId">Connection id.</param>
        public virtual async Task<WebSocket> GetWebStocket(string socketId)
        {
            return WebSocketConnectionManager.GetSocketById(socketId);
        }

        /// <summary>
        /// Registers <paramref name="socket"/> under a generated id.
        /// </summary>
        /// <param name="socket">The newly accepted socket.</param>
        public virtual async Task OnConnected(WebSocket socket)
        {
            WebSocketConnectionManager.AddSocket(socket);
        }

        /// <summary>
        /// Registers <paramref name="socket"/> under <paramref name="socketId"/>.
        /// </summary>
        /// <param name="socketId">Connection id to use.</param>
        /// <param name="socket">The newly accepted socket.</param>
        public virtual async Task OnConnected(string socketId, WebSocket socket)
        {
            WebSocketConnectionManager.AddSocket(socketId, socket);
        }

        /// <summary>
        /// Unregisters <paramref name="socket"/> (if it is registered) and lets the manager close it.
        /// </summary>
        /// <param name="socket">The socket that went away.</param>
        public virtual async Task OnDisconnected(WebSocket socket)
        {
            var socketId = WebSocketConnectionManager.GetId(socket);
            if (!string.IsNullOrWhiteSpace(socketId))
                await WebSocketConnectionManager.RemoveSocket(socketId).ConfigureAwait(false);
        }

        /// <summary>
        /// Serializes <paramref name="message"/> to JSON and sends it to <paramref name="socket"/> as one text frame.
        /// Does nothing if the socket is not open. If the peer vanished mid-send, the socket is unregistered.
        /// Other <see cref="WebSocketException"/>s are ignored.
        /// </summary>
        /// <param name="socket">Target socket.</param>
        /// <param name="message">Message to send.</param>
        public async Task SendMessageAsync(WebSocket socket, Message message)
        {
            if (socket.State != WebSocketState.Open)
                return;

            var serializedMessage = JsonConvert.SerializeObject(message);
            var encodedMessage = Encoding.UTF8.GetBytes(serializedMessage);
            var sendLock = SendLocks.GetValue(socket, _ => new SemaphoreSlim(1, 1));
            var closedPrematurely = false;

            await sendLock.WaitAsync().ConfigureAwait(false);
            try
            {
                // Re-check: the socket may have been closed while we were waiting for the lock.
                if (socket.State != WebSocketState.Open)
                    return;

                await socket.SendAsync(buffer: new ArraySegment<byte>(array: encodedMessage,
                                                                      offset: 0,
                                                                      count: encodedMessage.Length),
                                       messageType: WebSocketMessageType.Text,
                                       endOfMessage: true,
                                       cancellationToken: CancellationToken.None).ConfigureAwait(false);
            }
            catch (WebSocketException e)
            {
                closedPrematurely = e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely;
            }
            finally
            {
                sendLock.Release();
            }

            // Done outside the lock: OnDisconnected may close the socket, which sends a close frame itself.
            if (closedPrematurely)
            {
                await OnDisconnected(socket).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Sends <paramref name="message"/> to the socket registered under <paramref name="socketId"/>, if any.
        /// </summary>
        /// <param name="socketId">Connection id.</param>
        /// <param name="message">Message to send.</param>
        public async Task SendMessageAsync(string socketId, Message message)
        {
            var socket = WebSocketConnectionManager.GetSocketById(socketId);
            if (socket != null)
                await SendMessageAsync(socket, message).ConfigureAwait(false);
        }

        /// <summary>
        /// Sends <paramref name="message"/> to each connection id in <paramref name="sockets"/>, one after another.
        /// </summary>
        /// <param name="sockets">Connection ids; null is treated as empty.</param>
        /// <param name="message">Message to send.</param>
        public async Task SendMessageAsync(List<string> sockets, Message message)
        {
            if (sockets == null)
                return;

            foreach (var socket in sockets)
            {
                await SendMessageAsync(socket, message).ConfigureAwait(false);
            }
        }

        /// <summary>
        /// Returns all registered connections (the manager's live dictionary).
        /// </summary>
        public async Task<ConcurrentDictionary<string, WebSocket>> GetAll()
        {
            return WebSocketConnectionManager.GetAll();
        }

        /// <summary>
        /// Sends <paramref name="message"/> to every open registered socket.
        /// </summary>
        /// <param name="message">Message to send.</param>
        public async Task SendMessageToAllAsync(Message message)
        {
            foreach (var pair in WebSocketConnectionManager.GetAll())
            {
                try
                {
                    if (pair.Value.State == WebSocketState.Open)
                        await SendMessageAsync(pair.Value, message).ConfigureAwait(false);
                }
                catch (WebSocketException e)
                {
                    // Send errors are handled inside SendMessageAsync; this still guards against a
                    // WebSocketException escaping from OnDisconnected/CloseAsync so one bad socket
                    // does not stop the broadcast.
                    if (e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
                    {
                        await OnDisconnected(pair.Value);
                    }
                }
            }
        }

        /// <summary>
        /// Sends <paramref name="message"/> to every member of group <paramref name="groupID"/>.
        /// </summary>
        /// <param name="groupID">Group id.</param>
        /// <param name="message">Message to send.</param>
        public async Task SendMessageToGroupAsync(string groupID, Message message)
        {
            var sockets = WebSocketConnectionManager.GetAllFromGroup(groupID);
            if (sockets != null)
            {
                foreach (var socket in sockets)
                {
                    await SendMessageAsync(socket, message);
                }
            }
        }

        /// <summary>
        /// Sends <paramref name="message"/> to every member of group <paramref name="groupID"/> except <paramref name="except"/>.
        /// </summary>
        /// <param name="groupID">Group id.</param>
        /// <param name="message">Message to send.</param>
        /// <param name="except">Connection id to skip (typically the sender).</param>
        public async Task SendMessageToGroupAsync(string groupID, Message message, string except)
        {
            var sockets = WebSocketConnectionManager.GetAllFromGroup(groupID);
            if (sockets != null)
            {
                foreach (var id in sockets)
                {
                    if (id != except)
                        await SendMessageAsync(id, message);
                }
            }
        }

        /// <summary>
        /// Default handling of an incoming message: echoes <paramref name="receivedMessage"/> back to the sender.
        /// Subclasses may override this to implement their own logic.
        /// </summary>
        /// <param name="socket">The sender's socket.</param>
        /// <param name="result">Receive result of the frame.</param>
        /// <param name="receivedMessage">The deserialized message.</param>
        public virtual async Task ReceiveAsync(WebSocket socket, WebSocketReceiveResult result, Message receivedMessage)
        {
            try
            {
                await SendMessageAsync(socket, receivedMessage).ConfigureAwait(false);
            }
            // NOTE(review): nothing visible in the try block throws TargetParameterCountException; these catches
            // look like leftovers from a reflection-based dispatcher. Confirm before removing.
            catch (TargetParameterCountException)
            {
                await SendMessageAsync(socket, new Message() { }).ConfigureAwait(false);
            }
            catch (ArgumentException)
            {
                await SendMessageAsync(socket, new Message() { }).ConfigureAwait(false);
            }
        }
    }
}

3.创建 WebSocket 的中间件

using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using System;
using System.IO;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Linq;
using System.Collections.Generic;

namespace WebAPI.Socket
{
    /// <summary>
    /// Custom WebSocket middleware. Non-WebSocket requests are passed to the next middleware. WebSocket requests
    /// are accepted, registered under the id "{userName}_{editId}" (from the query string), and served by a
    /// receive loop until the connection closes.
    /// </summary>
    public class WebSocketManagerMiddleware
    {
        private readonly RequestDelegate _next;
        private readonly WebSocketHandler _webSocketHandler;

        /// <summary>
        /// Creates the middleware.
        /// </summary>
        /// <param name="next">Next middleware in the pipeline.</param>
        /// <param name="webSocketHandler">Handler that owns the connections and performs the sends.</param>
        public WebSocketManagerMiddleware(RequestDelegate next,
                                          WebSocketHandler webSocketHandler)
        {
            _next = next;
            _webSocketHandler = webSocketHandler;
        }

        /// <summary>
        /// Handles one HTTP request. For WebSocket requests this only returns once the connection has ended.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        public async Task Invoke(HttpContext context)
        {
            if (!context.WebSockets.IsWebSocketRequest)
            {
                await _next.Invoke(context);
                return;
            }

            // Accept the client: upgrades the current HTTP connection to a WebSocket connection.
            var socket = await context.WebSockets.AcceptWebSocketAsync().ConfigureAwait(false);

            // Read the connection parameters from the query string.
            Message message = new Message();
            message.editId = context.Request.Query["editId"].ToString();
            message.userName = context.Request.Query["userName"].ToString();
            message.methodName = context.Request.Query["methodName"].ToString();
            string socketId = string.Format("{0}_{1}", message.userName, message.editId);

            await _webSocketHandler.OnConnected(socketId, socket).ConfigureAwait(false);
            // Check whether other connections already use the same editId (MessageHandle may notify this client).
            await MessageHandle(socket, socketId, message).ConfigureAwait(false);

            // The callback returns a Task that Receive awaits. Frames are therefore handled one at a time and
            // exceptions are observed. An async lambda passed as an Action would have been async void.
            await Receive(socket, (result, serializedMessage) => HandleFrameAsync(socket, socketId, result, serializedMessage))
                .ConfigureAwait(false);
        }

        /// <summary>
        /// Processes one complete frame received from <paramref name="socket"/>.
        /// Text: the JSON Message goes through MessageHandle, then is either forwarded to the affected
        /// connections (apply/agree/disagree notifications) or handed to the handler's ReceiveAsync.
        /// Close: the connection is unregistered and the close handshake is completed.
        /// </summary>
        private async Task HandleFrameAsync(WebSocket socket, string socketId, WebSocketReceiveResult result, string serializedMessage)
        {
            if (result.MessageType == WebSocketMessageType.Text)
            {
                try
                {
                    Message received = JsonConvert.DeserializeObject<Message>(serializedMessage);
                    var socketIds = await MessageHandle(socket, socketId, received).ConfigureAwait(false);
                    if (received.messageState == MessageState.BeApplied || received.messageState == MessageState.BeDisAgreed || received.messageState == MessageState.BeAgreed)
                    {
                        await _webSocketHandler.SendMessageAsync(socketIds, received).ConfigureAwait(false);
                    }
                    else
                    {
                        await _webSocketHandler.ReceiveAsync(socket, result, received).ConfigureAwait(false);
                    }
                }
                catch (Exception)
                {
                    // Malformed JSON or a failure while dispatching: drop this connection.
                    socket.Abort();
                }
            }
            else if (result.MessageType == WebSocketMessageType.Close)
            {
                await _webSocketHandler.OnDisconnected(socket).ConfigureAwait(false);

                // The client is waiting for our close frame. If nothing above closed the socket yet, reply now
                // so the handshake completes cleanly.
                if (socket.State == WebSocketState.CloseReceived)
                {
                    await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None).ConfigureAwait(false);
                }
            }
        }

        /// <summary>
        /// Receive loop: reassembles each (possibly fragmented) message, decodes it as UTF-8, and awaits
        /// <paramref name="handleMessage"/> for it. Runs while the socket is open, then unregisters the socket.
        /// </summary>
        /// <param name="socket">The connection to read from.</param>
        /// <param name="handleMessage">Callback for each complete message.</param>
        private async Task Receive(WebSocket socket, Func<WebSocketReceiveResult, string, Task> handleMessage)
        {
            // One buffer per connection is enough: each fragment is copied into the MemoryStream immediately.
            var buffer = new ArraySegment<byte>(new byte[1024 * 4]);

            while (socket.State == WebSocketState.Open)
            {
                try
                {
                    WebSocketReceiveResult result;
                    string message;
                    using (var ms = new MemoryStream())
                    {
                        do
                        {
                            result = await socket.ReceiveAsync(buffer, CancellationToken.None).ConfigureAwait(false);
                            ms.Write(buffer.Array, buffer.Offset, result.Count);
                        }
                        while (!result.EndOfMessage);

                        ms.Seek(0, SeekOrigin.Begin);

                        using (var reader = new StreamReader(ms, Encoding.UTF8))
                        {
                            message = await reader.ReadToEndAsync().ConfigureAwait(false);
                        }
                    }

                    await handleMessage(result, message).ConfigureAwait(false);
                }
                catch (WebSocketException e)
                {
                    if (e.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
                    {
                        socket.Abort();
                    }
                }
            }

            await _webSocketHandler.OnDisconnected(socket).ConfigureAwait(false);
        }

        /// <summary>
        /// Edit-conflict logic. Finds the other connections whose id ends in "_{editId}" (same record) and
        /// updates <paramref name="message"/> in place:
        /// no other editor -> Enable;
        /// Apply -> BeApplied, with socketId set to the applicant (this connection);
        /// DisAgree/Agree -> BeDisAgreed/BeAgreed, and the returned list is just message.socketId (the applicant);
        /// anything else -> Editing, with userName set to the other editors' names, sent back to this socket.
        /// </summary>
        /// <param name="socket">The current connection.</param>
        /// <param name="socketId">Id of the current connection ("{userName}_{editId}").</param>
        /// <param name="message">The message to evaluate; mutated in place.</param>
        /// <returns>The connection ids the message concerns.</returns>
        private async Task<List<string>> MessageHandle(WebSocket socket, string socketId, Message message)
        {
            // Ids have the form "{userName}_{editId}". Match the editId exactly via the suffix; a substring
            // test would let editId "1" match "x_123", and an empty editId would match everyone.
            string suffix = "_" + message.editId;

            var socketKeyValues = await _webSocketHandler.GetAll().ConfigureAwait(false);
            var socketIds = socketKeyValues.Keys
                .Where(key => key.EndsWith(suffix, StringComparison.Ordinal) && !key.Equals(socketId))
                .ToList();

            if (socketIds.Count == 0)
            {
                message.messageState = MessageState.Enable;
                return socketIds;
            }

            if (message.messageState == MessageState.Apply)
            {
                message.messageState = MessageState.BeApplied;
                message.socketId = socketId; // the applicant's connection id
            }
            else if (message.messageState == MessageState.DisAgree || message.messageState == MessageState.Agree)
            {
                message.messageState = message.messageState == MessageState.DisAgree ? MessageState.BeDisAgreed : MessageState.BeAgreed;
                socketIds = new List<string>() { message.socketId };
            }
            else
            {
                // The record is already being edited: tell this client who is editing it. Stripping the known
                // suffix keeps user names that themselves contain '_' intact.
                message.userName = string.Join(",", socketIds.Select(id => id.Substring(0, id.Length - suffix.Length)));
                message.messageState = MessageState.Editing;
                await _webSocketHandler.SendMessageAsync(socket, message).ConfigureAwait(false);
            }

            return socketIds;
        }
    }
}

4.创建 WebSocket 管理子类

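可以创建多个子类,用于个性化处理。基类 WebSocketHandler 被声明为 abstract,不能直接实例化,所以至少需要一个子类;其中 ReceiveAsync 等方法是 virtual 的,子类可以按需重写,并非必须重写(下面的 BusMessageHandler 没有重写任何方法,直接使用基类的默认实现)。另外,注入扩展只会注册直接继承 WebSocketHandler 的公开子类。如果不需要多个处理类,也可以去掉基类的 abstract,直接使用基类。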

// Same namespace as WebSocketHandler / WebSocketConnectionManager. The former
// ACH_Sampleonline.WebAPI.Socket namespace could not resolve those base types without an extra using.
namespace WebAPI.Socket
{
    /// <summary>
    /// Business-logic WebSocket handler. It adds no behaviour of its own and uses the virtual defaults of
    /// <see cref="WebSocketHandler"/>; override members such as ReceiveAsync to customise it.
    /// </summary>
    public class BusMessageHandler : WebSocketHandler
    {
        /// <summary>
        /// Creates the handler.
        /// </summary>
        /// <param name="webSocketConnectionManager">Connection registry passed on to the base class.</param>
        public BusMessageHandler(WebSocketConnectionManager webSocketConnectionManager) : base(webSocketConnectionManager)
        {
        }
    }
}

5.创建注入扩展

直接写在 Startup.cs 中也可以,但把每类注入逻辑单独写到一个扩展类文件中是个好习惯。

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using System.Reflection;

namespace WebAPI.Socket
{
    /// <summary>
    /// Extension methods that register and expose the WebSocket middleware.
    /// </summary>
    public static class WebSocketManagerExtensions
    {
        /// <summary>
        /// Registers <see cref="WebSocketConnectionManager"/> as transient and every public, non-abstract, direct
        /// subclass of <see cref="WebSocketHandler"/> in <paramref name="assembly"/> as a singleton.
        /// </summary>
        /// <param name="services">The service collection to add to.</param>
        /// <param name="assembly">Assembly to scan; the entry assembly is used when null.</param>
        /// <returns>The same <paramref name="services"/> instance, for chaining.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="services"/> is null.</exception>
        /// <exception cref="System.InvalidOperationException">
        /// <paramref name="assembly"/> is null and <see cref="Assembly.GetEntryAssembly"/> returned null.
        /// </exception>
        public static IServiceCollection AddWebSocketManager(this IServiceCollection services, Assembly assembly = null)
        {
            if (services == null)
                throw new System.ArgumentNullException(nameof(services));

            // Transient + singleton handlers: each handler receives its own manager instance once, at creation.
            services.AddTransient<WebSocketConnectionManager>();

            Assembly ass = assembly ?? Assembly.GetEntryAssembly()
                ?? throw new System.InvalidOperationException(
                    "No entry assembly is available; pass the assembly containing the WebSocketHandler subclasses explicitly.");

            foreach (var type in ass.ExportedTypes)
            {
                TypeInfo info = type.GetTypeInfo();

                // Only direct subclasses are picked up; abstract ones cannot be instantiated by the container.
                if (info.BaseType == typeof(WebSocketHandler) && !info.IsAbstract)
                {
                    services.AddSingleton(type);
                }
            }

            return services;
        }

        /// <summary>
        /// Maps <paramref name="path"/> to a branch that runs <see cref="WebSocketManagerMiddleware"/> with
        /// <paramref name="handler"/>.
        /// </summary>
        /// <param name="app">The application builder.</param>
        /// <param name="path">Request path to handle (e.g. "/ws").</param>
        /// <param name="handler">Handler instance passed to the middleware.</param>
        /// <returns>The application builder returned by <c>Map</c>.</returns>
        /// <exception cref="System.ArgumentNullException"><paramref name="app"/> or <paramref name="handler"/> is null.</exception>
        public static IApplicationBuilder MapWebSocketManager(this IApplicationBuilder app,
                                                              PathString path,
                                                              WebSocketHandler handler)
        {
            if (app == null)
                throw new System.ArgumentNullException(nameof(app));
            // GetService returns null for an unregistered handler; reject it here, at startup.
            if (handler == null)
                throw new System.ArgumentNullException(nameof(handler));

            return app.Map(path, (_app) => _app.UseMiddleware<WebSocketManagerMiddleware>(handler));
        }
    }
}

6.配置 Startup.cs

将上面的内容注入到启动项中即可。
在 ConfigureServices 中添加:

// Registers WebSocketConnectionManager and the public direct subclasses of WebSocketHandler found in the
// entry assembly (no assembly argument is passed), so the handler can be resolved in Configure.
services.AddWebSocketManager();

在 Configure 中添加:

// Socket communication
app.UseWebSockets();
// GetRequiredService fails fast at startup if BusMessageHandler was not registered;
// GetService would return null and pass it on to the middleware.
app.MapWebSocketManager("/ws", app.ApplicationServices.GetRequiredService<BusMessageHandler>());

即可。上面的写法通过 app.ApplicationServices 获取服务;如果想改写成 serviceProvider.GetService<BusMessageHandler>(),则需要在 Configure 方法的参数中添加:

IServiceProvider serviceProvider

7.测试

前端代码样例:

<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8" />
    <title>测试</title>
</head>
<body>
    <div id="message" style="border: solid 1px #333; padding: 4px; width: 550px; overflow: auto; background-color: #404040; height: 300px; margin-bottom: 8px; font-size: 14px;">
    </div>
    <input id="text" type="text" onkeypress="enter(event);" style="width: 340px" />
    &nbsp;
    <button id="send" onclick="send();">发送</button>
    <button onclick="quit();">停止</button>

    <!-- The script lives inside <body>, after the elements it looks up, so getElementById finds them. -->
    <script type="text/javascript">
        var ws;
        var msgContainer = document.getElementById('message');
        var text = document.getElementById('text');

        // Appends one line to the log panel and scrolls to it. textContent (not innerHTML) is used because
        // the text may come from the server or other clients and must not be interpreted as HTML.
        function appendLine(content, color) {
            var msg = document.createElement('div');
            msg.style.color = color;
            msg.textContent = content;
            msgContainer.appendChild(msg);
            msgContainer.scrollTop = msgContainer.scrollHeight;
        }

        window.onload = function () {
            // Query values are URL-encoded explicitly because the user name is non-ASCII.
            ws = new WebSocket("ws://localhost:5056/ws?userName=" + encodeURIComponent("张三") +
                               "&editId=" + encodeURIComponent("12345678"));
            ws.onopen = function () {
                appendLine("server > connection open.", '#0f0');
            };
            ws.onmessage = function (e) {
                console.log(e, "onmessage");
                appendLine(e.data, '#0f0');
            };
            ws.onerror = function (e) {
                console.log(e, "onerror");
                // Error events have no data property, so print a fixed text instead of e.data.
                appendLine('Server > connection error.', '#0f0');
            };
            ws.onclose = function () {
                appendLine('Server > connection closed.', '#0f0');
                ws = null;
            };
        };

        function quit() {
            if (ws) {
                ws.close();
                appendLine('Server >start closed.', '#0f0');
                ws = null;
            }
        }

        // Sends the input box content. NOTE: the server deserializes every text frame as a JSON Message and
        // aborts the connection when that fails, so enter a JSON object such as
        // {"editId":"12345678","userName":"张三"} rather than plain text.
        function send() {
            if (!ws || ws.readyState !== WebSocket.OPEN) {
                appendLine('Server > connection is not open.', '#f00');
                return;
            }
            ws.send(text.value);
            appendLine("客户端: " + text.value + "  " + getNowTime(), '#ffff00');
            text.value = "";
        }

        // Sends on Enter.
        function enter(event) {
            if (event.keyCode == 13) {
                send();
            }
        }

        // Current local time formatted as yyyy/MM/dd HH:mm:ss.
        function getNowTime() {
            var date = new Date();
            var year = date.getFullYear();      // four-digit year
            var month = date.getMonth() + 1;    // getMonth() is 0-11
            var day = date.getDate();           // 1-31
            var hour = date.getHours();         // 0-23
            var minute = date.getMinutes();     // 0-59
            var second = date.getSeconds();     // 0-59
            return year + '/' + addZero(month) + '/' + addZero(day) + ' ' +
                   addZero(hour) + ':' + addZero(minute) + ':' + addZero(second);
        }

        // Left-pads values below 10 with a '0'.
        function addZero(s) {
            return s < 10 ? ('0' + s) : s;
        }
    </script>
</body>
</html>

参考链接:https://blog.csdn.net/wulex/article/details/115548474


你可能感兴趣的:(C#知识整理,websocket,.netcore,中间件)