WebSocketMiddleware.cs
using Microsoft.AspNetCore.Http;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Middleware that hands WebSocket upgrade requests to <see cref="WebSocketHandler"/>
/// and forwards every other request to the next component in the pipeline.
/// </summary>
public class WebSocketMiddleware
{
    private readonly RequestDelegate _next;

    /// <summary>Creates the middleware.</summary>
    /// <param name="next">The next delegate in the request pipeline.</param>
    public WebSocketMiddleware(RequestDelegate next)
    {
        _next = next;
    }

    /// <summary>
    /// Accepts a WebSocket connection and stays in this request until
    /// <see cref="WebSocketHandler.Accept"/> finishes; non-WebSocket requests pass through.
    /// </summary>
    /// <param name="context">The current HTTP context.</param>
    public async Task Invoke(HttpContext context)
    {
        if (context.WebSockets.IsWebSocketRequest)
        {
            // Capture the abort token before upgrading, then accept the client connection.
            CancellationToken aborted = context.RequestAborted;
            WebSocket socket = await context.WebSockets.AcceptWebSocketAsync();
            await WebSocketHandler.Accept(context, socket, aborted);
            return;
        }

        await _next(context);
    }
}
WebSocketHandler
using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
/// <summary>
/// Tracks connected WebSocket clients, dispatches the messages they send
/// (<see cref="SubscribeMsgType"/> / <see cref="UnsubscribeMsgType"/>), and pushes
/// server-side messages to every client subscribed to the message's type.
/// </summary>
public static class WebSocketHandler
{
    /// <summary>Size in bytes of the buffer used for each receive call.</summary>
    private const int BlockSize = 1024;

    /// <summary>Upper bound on how long the final close handshake may take.</summary>
    private static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

    /// <summary>Client message type that subscribes to one or more server message types.</summary>
    public const string SubscribeMsgType = "Subscribe";

    /// <summary>Client message type that unsubscribes from one or more server message types.</summary>
    public const string UnsubscribeMsgType = "Unsubscribe";

    /// <summary>
    /// Handlers for client messages, keyed by <see cref="ReceiveMessage.MessageType"/>.
    /// Messages of any other type are logged and otherwise ignored.
    /// </summary>
    public static readonly Dictionary<string, Action<WebSocketClient, ReceiveMessage>> KnownMessages =
        new Dictionary<string, Action<WebSocketClient, ReceiveMessage>>
        {
            { SubscribeMsgType, Subscribe },
            { UnsubscribeMsgType, Unsubscribe }
        };

    /// <summary>All currently connected clients, keyed by <see cref="WebSocketClient.ClientID"/>.</summary>
    private static readonly ConcurrentDictionary<string, WebSocketClient> Clients =
        new ConcurrentDictionary<string, WebSocketClient>();

    /// <summary>
    /// Registers an accepted WebSocket connection and services it until it closes.
    /// </summary>
    /// <param name="context">The HTTP context of the upgrade request.</param>
    /// <param name="currentSocket">The socket returned by <c>AcceptWebSocketAsync</c>.</param>
    /// <param name="ct">Stops the receive loop when cancelled (the middleware passes RequestAborted).</param>
    /// <returns>A task that completes once the client has been unregistered and closed.</returns>
    public static async Task Accept(HttpContext context, WebSocket currentSocket, CancellationToken ct)
    {
        // NOTE(review): Request.Host is the Host header of the upgrade request (the server address
        // as the client typed it), not the client's ip:port; every client will log the same value.
        // Consider context.Connection.RemoteIpAddress/RemotePort if per-client identification is intended.
        var client = new WebSocketClient(context.Request.Host.ToString(), currentSocket);
        Program.Logger.Debug($"Accept Client:{client.Host}");
        Clients.TryAdd(client.ClientID, client);
        try
        {
            await ReceiveLoopAsync(client, ct);
        }
        finally
        {
            // Unregister first so PublishMessage stops targeting this socket. A separate out
            // variable keeps `client` intact even if the entry was already gone.
            WebSocketClient removed;
            Clients.TryRemove(client.ClientID, out removed);
            await CloseQuietlyAsync(currentSocket);
            Program.Logger.Debug($"Close Client:{client.Host}");
        }
    }

    /// <summary>
    /// Reads complete messages from the client until the socket leaves the Open state,
    /// the peer starts the close handshake, a receive fails, or <paramref name="ct"/> is cancelled.
    /// </summary>
    private static async Task ReceiveLoopAsync(WebSocketClient client, CancellationToken ct)
    {
        WebSocket socket = client.Socket;
        byte[] buffer = new byte[BlockSize];
        while (socket.State == WebSocketState.Open)
        {
            // NOTE(review): text messages are buffered without a size limit; consider capping
            // this if clients are untrusted.
            using (var ms = new MemoryStream())
            {
                WebSocketReceiveResult result;
                do
                {
                    result = await TryReceiveAsync(socket, buffer, ct);
                    if (result == null || result.MessageType == WebSocketMessageType.Close)
                    {
                        // Receive failed or was cancelled, or the peer is closing: stop and discard
                        // any partially assembled message rather than dispatching it.
                        return;
                    }
                    if (result.MessageType == WebSocketMessageType.Text)
                    {
                        ms.Write(buffer, 0, result.Count);
                    }
                }
                while (!result.EndOfMessage);

                if (result.MessageType != WebSocketMessageType.Text)
                {
                    // Binary messages are only logged; all of their fragments were drained above.
                    Program.Logger.Debug($"{client.Host} receive binary");
                }
                else if (ms.Length > 0)
                {
                    HandleTextMessage(client, ms.ToArray());
                }
            }
        }
    }

    /// <summary>
    /// Performs one receive; returns <c>null</c> instead of throwing when the receive is
    /// cancelled or the connection fails, which ends the receive loop.
    /// </summary>
    private static async Task<WebSocketReceiveResult> TryReceiveAsync(WebSocket socket, byte[] buffer, CancellationToken ct)
    {
        try
        {
            return await socket.ReceiveAsync(new ArraySegment<byte>(buffer), ct);
        }
        catch (Exception)
        {
            // Any receive failure (cancellation, reset, protocol error) means the connection is done.
            return null;
        }
    }

    /// <summary>
    /// Decodes a UTF-8 JSON text message and dispatches it. Malformed input or a failing
    /// handler is logged at debug level and does not drop the connection.
    /// </summary>
    private static void HandleTextMessage(WebSocketClient client, byte[] payload)
    {
        string msg = null;
        try
        {
            msg = Encoding.UTF8.GetString(payload);
            ProcessClientMessage(client, JsonConvert.DeserializeObject<ReceiveMessage>(msg));
        }
        catch (Exception)
        {
            Program.Logger.Debug($"{client.Host} Receive:{msg}");
        }
    }

    /// <summary>
    /// Sends a normal close frame if the socket is still in a state that allows it, bounded by
    /// <see cref="CloseTimeout"/>. Failures are ignored because the connection is being torn down.
    /// </summary>
    private static async Task CloseQuietlyAsync(WebSocket socket)
    {
        if (socket.State != WebSocketState.Open && socket.State != WebSocketState.CloseReceived)
        {
            return;
        }
        using (var timeout = new CancellationTokenSource(CloseTimeout))
        {
            try
            {
                await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", timeout.Token);
            }
            catch (Exception)
            {
                // Peer already gone or did not acknowledge in time; nothing left to do.
            }
        }
    }

    /// <summary>
    /// Dispatches a client message to its handler in <see cref="KnownMessages"/>, if any.
    /// </summary>
    /// <param name="client">The sending client.</param>
    /// <param name="msg">The deserialized message; ignored when null or when it has no type.</param>
    private static void ProcessClientMessage(WebSocketClient client, ReceiveMessage msg)
    {
        if (client == null || msg == null || string.IsNullOrEmpty(msg.MessageType)
            || client.Socket.State != WebSocketState.Open)
        {
            return;
        }
        Program.Logger.Debug($"{client.Host} {msg.MessageType}");
        Action<WebSocketClient, ReceiveMessage> handler;
        if (KnownMessages.TryGetValue(msg.MessageType, out handler))
        {
            handler(client, msg);
        }
    }

    /// <summary>
    /// Publishes a server message to every open client subscribed to its
    /// <see cref="SendMessage.MessageType"/>. Sends run in the background; failures are logged.
    /// </summary>
    /// <param name="message">The message to serialize and send; ignored when null or untyped.</param>
    public static void PublishMessage(SendMessage message)
    {
        if (message == null || string.IsNullOrEmpty(message.MessageType))
        {
            return;
        }
        string body = Program.Json(message);
        var payload = new ArraySegment<byte>(Encoding.UTF8.GetBytes(body));
        foreach (WebSocketClient client in Clients.Values)
        {
            if (client.Socket.State == WebSocketState.Open && client.IsSubscribedTo(message.MessageType))
            {
                // Intentionally not awaited so this method stays synchronous;
                // SendToClientAsync observes and logs its own failures.
                SendToClientAsync(client, payload, body);
            }
        }
    }

    /// <summary>
    /// Sends one text message to <paramref name="client"/>, serialized with any other send to the
    /// same socket because a WebSocket supports only one outstanding send at a time.
    /// Exceptions are logged, never thrown.
    /// </summary>
    private static async Task SendToClientAsync(WebSocketClient client, ArraySegment<byte> payload, string body)
    {
        await client.SendLock.WaitAsync();
        try
        {
            // Re-check: the socket may have closed while this send was queued.
            if (client.Socket.State == WebSocketState.Open)
            {
                await client.Socket.SendAsync(payload, WebSocketMessageType.Text, true, CancellationToken.None);
            }
        }
        catch (Exception)
        {
            Program.Logger.Error($"SendTo{client.Host} Message:{body} 失败");
        }
        finally
        {
            client.SendLock.Release();
        }
    }

    #region Client Messages

    /// <summary>
    /// Handles <see cref="SubscribeMsgType"/>: adds the message type(s) in
    /// <see cref="ReceiveMessage.Content"/> to the client's subscriptions.
    /// </summary>
    private static void Subscribe(WebSocketClient client, ReceiveMessage msg)
    {
        if (client == null || msg == null || client.Socket.State != WebSocketState.Open)
        {
            return;
        }
        string[] types = ParseMessageTypes(msg.Content);
        if (types != null)
        {
            client.AddSubscriptions(types);
        }
    }

    /// <summary>
    /// Handles <see cref="UnsubscribeMsgType"/>: removes the message type(s) in
    /// <see cref="ReceiveMessage.Content"/> from the client's subscriptions.
    /// </summary>
    private static void Unsubscribe(WebSocketClient client, ReceiveMessage msg)
    {
        if (client == null || msg == null || client.Socket.State != WebSocketState.Open)
        {
            return;
        }
        string[] types = ParseMessageTypes(msg.Content);
        if (types != null)
        {
            client.RemoveSubscriptions(types);
        }
    }

    /// <summary>
    /// Extracts message-type names from a message's content. Accepts a JSON string, a JSON
    /// array of strings, a plain <see cref="string"/>, or a <c>string[]</c>.
    /// </summary>
    /// <returns>The names, or <c>null</c> when the content has none of the accepted shapes.</returns>
    private static string[] ParseMessageTypes(object content)
    {
        if (content == null)
        {
            return null;
        }

        JToken token = content as JToken;
        if (token == null)
        {
            // Non-JSON-token content: a single name, or an array of names.
            string single = content as string;
            return single != null ? new[] { single } : content as string[];
        }

        try
        {
            if (token.Type == JTokenType.String)
            {
                return new[] { token.ToString() };
            }
            if (token.Type == JTokenType.Array)
            {
                return ((JArray)token).Select(a => a.Value<string>()).ToArray();
            }
        }
        catch (Exception)
        {
            Program.Logger.Error("not valid message");
        }
        return null;
    }

    #endregion

    /// <summary>Message sent by a client; <see cref="Content"/> is interpreted by the type's handler.</summary>
    public class ReceiveMessage
    {
        public string MessageType { get; set; }
        public object Content { get; set; }
    }

    /// <summary>Message published by the server to subscribed clients.</summary>
    public class SendMessage
    {
        public string SenderID { get; set; }
        public string MessageType { get; set; }
        public object Content { get; set; }

        /// <summary>Creates a message whose <see cref="SenderID"/> defaults to "System".</summary>
        public SendMessage()
        {
            SenderID = "System";
        }
    }

    /// <summary>A connected WebSocket client and the server message types it subscribes to.</summary>
    public class WebSocketClient
    {
        // Guards Messages: the receive loop mutates it while PublishMessage reads it from other threads.
        private readonly object _subscriptionGate = new object();

        /// <summary>Unique id generated for this connection (a GUID string).</summary>
        public string ClientID { get; }

        /// <summary>Host label used in log messages; Accept fills it from the upgrade request's Host.</summary>
        public string Host { get; set; }

        /// <summary>The underlying WebSocket.</summary>
        public WebSocket Socket { get; }

        /// <summary>
        /// Server message types this client is subscribed to. WebSocketHandler mutates it under an
        /// internal lock; direct access from other code is not synchronized.
        /// </summary>
        public List<string> Messages { get; }

        /// <summary>Serializes sends on <see cref="Socket"/> (one outstanding send at a time).</summary>
        internal SemaphoreSlim SendLock { get; }

        public WebSocketClient(string host, WebSocket socket)
        {
            ClientID = Guid.NewGuid().ToString();
            Host = host;
            Socket = socket;
            Messages = new List<string>();
            SendLock = new SemaphoreSlim(1, 1);
        }

        /// <summary>Returns whether this client subscribes to <paramref name="messageType"/>.</summary>
        internal bool IsSubscribedTo(string messageType)
        {
            lock (_subscriptionGate)
            {
                return Messages.Contains(messageType);
            }
        }

        /// <summary>Adds each non-empty name that is not already subscribed.</summary>
        internal void AddSubscriptions(IEnumerable<string> messageTypes)
        {
            lock (_subscriptionGate)
            {
                foreach (string type in messageTypes)
                {
                    if (!string.IsNullOrEmpty(type) && !Messages.Contains(type))
                    {
                        Messages.Add(type);
                    }
                }
            }
        }

        /// <summary>Removes every occurrence of each non-empty name.</summary>
        internal void RemoveSubscriptions(IEnumerable<string> messageTypes)
        {
            lock (_subscriptionGate)
            {
                foreach (string type in messageTypes)
                {
                    if (!string.IsNullOrEmpty(type))
                    {
                        Messages.RemoveAll(m => m == type);
                    }
                }
            }
        }
    }
}