.NET Core WebSocket

WebSocketMiddleware.cs

using Microsoft.AspNetCore.Http;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// ASP.NET Core middleware that takes over WebSocket upgrade requests and lets
/// every other request continue down the pipeline untouched.
/// </summary>
public class WebSocketMiddleware
{
	private readonly RequestDelegate _next;

	/// <summary>
	/// Creates the middleware.
	/// </summary>
	/// <param name="next">The next delegate in the request pipeline.</param>
	public WebSocketMiddleware(RequestDelegate next)
	{
		_next = next;
	}

	/// <summary>
	/// Accepts WebSocket upgrade requests and hands the connection to
	/// <see cref="WebSocketHandler"/>; non-WebSocket requests go to the next middleware.
	/// </summary>
	/// <param name="context">The current HTTP context.</param>
	public async Task Invoke(HttpContext context)
	{
		if (context.WebSockets.IsWebSocketRequest)
		{
			// The connection-abort token is captured before accepting, then flows into the handler.
			CancellationToken requestAborted = context.RequestAborted;

			// Accept the client connection; the request stays open until the handler completes.
			WebSocket socket = await context.WebSockets.AcceptWebSocketAsync();
			await WebSocketHandler.Accept(context, socket, requestAborted);
			return;
		}

		await _next(context);
	}
}

WebSocketHandler.cs

using Microsoft.AspNetCore.Http;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

/// <summary>
/// Tracks connected WebSocket clients, processes their Subscribe/Unsubscribe requests and
/// pushes server messages to every client subscribed to the message's type.
/// </summary>
public static class WebSocketHandler
{
	/// <summary>
	/// Size in bytes of the buffer used for each ReceiveAsync call; longer messages arrive
	/// in several chunks and are reassembled before they are processed.
	/// </summary>
	const int BLOCK_SIZE = 1024;

	/// <summary>
	/// Upper bound on how long we wait for the close handshake when shutting a socket down,
	/// so an unresponsive peer cannot keep the request alive indefinitely.
	/// </summary>
	static readonly TimeSpan CloseTimeout = TimeSpan.FromSeconds(5);

	/// <summary>
	/// Client message type requesting a subscription to one or more server message types.
	/// </summary>
	public const string SubscribeMsgType = "Subscribe";

	/// <summary>
	/// Client message type cancelling a subscription to one or more server message types.
	/// </summary>
	public const string UnsubscribeMsgType = "Unsubscribe";

	/// <summary>
	/// Client message types this handler understands, mapped to the method that handles them.
	/// </summary>
	public readonly static Dictionary<string, Action<WebSocketClient, ReceiveMessage>> KnownMessages =
		new Dictionary<string, Action<WebSocketClient, ReceiveMessage>>
		{
			{ SubscribeMsgType, Subscribe },
			{ UnsubscribeMsgType, Unsubscribe }
		};

	// All currently connected clients, keyed by WebSocketClient.ClientID.
	private readonly static ConcurrentDictionary<string, WebSocketClient> Clients =
		new ConcurrentDictionary<string, WebSocketClient>();

	/// <summary>
	/// Registers a newly accepted WebSocket connection and runs its receive loop until the
	/// client closes the connection, the socket fails, or <paramref name="ct"/> is cancelled.
	/// The client is then unregistered and the socket closed.
	/// </summary>
	/// <param name="context">HTTP context of the upgrade request; used to identify the remote peer.</param>
	/// <param name="currentSocket">The accepted WebSocket.</param>
	/// <param name="ct">Cancels the receive loop (the middleware passes RequestAborted).</param>
	/// <returns>A task that completes once the connection has been torn down.</returns>
	public static async Task Accept(HttpContext context, WebSocket currentSocket, CancellationToken ct)
	{
		// Request.Host is the Host header (the server's own name) and is the same for every
		// client; the remote endpoint is what actually identifies the peer as "ip:port".
		var connection = context.Connection;
		WebSocketClient client = new WebSocketClient(
			$"{connection.RemoteIpAddress}:{connection.RemotePort}", currentSocket);

		Program.Logger.Debug($"Accept Client:{client.Host}");

		Clients.TryAdd(client.ClientID, client);

		try
		{
			await ReceiveLoopAsync(client, ct);
		}
		finally
		{
			// Unregister first so PublishMessage stops targeting this client while it closes.
			Clients.TryRemove(client.ClientID, out _);
			await CloseQuietlyAsync(client);

			Program.Logger.Debug($"Close Client:{client.Host}");
		}
	}

	// Reads whole messages from the client until it sends Close, the socket fails, or ct is cancelled.
	private static async Task ReceiveLoopAsync(WebSocketClient client, CancellationToken ct)
	{
		WebSocket socket = client.Socket;
		byte[] buffer = new byte[BLOCK_SIZE];

		while (socket.State == WebSocketState.Open)
		{
			using (MemoryStream ms = new MemoryStream())
			{
				WebSocketReceiveResult result;
				try
				{
					// Accumulate frames until EndOfMessage so fragmented messages are handled whole.
					// NOTE(review): there is no upper bound on the reassembled size; the payload comes
					// from untrusted clients, so consider enforcing a maximum message length.
					do
					{
						result = await socket.ReceiveAsync(new ArraySegment<byte>(buffer), ct);
						if (result.MessageType == WebSocketMessageType.Close)
						{
							return;
						}
						ms.Write(buffer, 0, result.Count);
					}
					while (!result.EndOfMessage);
				}
				catch (OperationCanceledException)
				{
					// ct was cancelled (request aborted); any partial message is discarded.
					return;
				}
				catch (Exception ex)
				{
					// Connection dropped or protocol error; any partial message is discarded.
					Program.Logger.Debug($"{client.Host} receive failed: {ex.Message}");
					return;
				}

				if (result.MessageType == WebSocketMessageType.Text)
				{
					HandleTextMessage(client, ms.ToArray());
				}
				else
				{
					Program.Logger.Debug($"{client.Host} receive binary");
				}
			}
		}
	}

	// Decodes a complete UTF-8 text message as JSON and dispatches it; malformed input is logged and ignored.
	private static void HandleTextMessage(WebSocketClient client, byte[] payload)
	{
		if (payload.Length == 0) return;

		string msg = null;
		try
		{
			msg = Encoding.UTF8.GetString(payload);
			ProcessClientMessage(client, JsonConvert.DeserializeObject<ReceiveMessage>(msg));
		}
		catch (Exception ex)
		{
			// A bad message from one client must not tear down its connection.
			Program.Logger.Debug($"{client.Host} Receive:{msg} ({ex.Message})");
		}
	}

	/// <summary>
	/// Dispatches a client message to its handler in <see cref="KnownMessages"/>.
	/// Messages with an unknown or empty type, or arriving on a socket that is no longer open, are ignored.
	/// </summary>
	/// <param name="client">The sending client.</param>
	/// <param name="msg">The deserialized message; may be null.</param>
	private static void ProcessClientMessage(WebSocketClient client, ReceiveMessage msg)
	{
		if (client == null || client.Socket.State != WebSocketState.Open
			|| msg == null || string.IsNullOrEmpty(msg.MessageType)) return;

		Program.Logger.Debug($"{client.Host} {msg.MessageType}");

		Action<WebSocketClient, ReceiveMessage> handler;
		if (KnownMessages.TryGetValue(msg.MessageType, out handler))
		{
			handler(client, msg);
		}
	}

	/// <summary>
	/// Publishes a server message to every open client subscribed to <c>message.MessageType</c>.
	/// Sends are started asynchronously; per-client failures are logged, not thrown.
	/// </summary>
	/// <param name="message">The message to publish; ignored when null or lacking a MessageType.</param>
	public static void PublishMessage(SendMessage message)
	{
		if (message == null
			|| string.IsNullOrEmpty(message.MessageType)) return;

		string body = Program.Json(message);
		ArraySegment<byte> msg = new ArraySegment<byte>(Encoding.UTF8.GetBytes(body));

		// Values is a snapshot, so clients connecting/disconnecting meanwhile are harmless.
		foreach (WebSocketClient client in Clients.Values)
		{
			if (client.Socket.State == WebSocketState.Open
				&& client.IsSubscribed(message.MessageType))
			{
				// Fire-and-forget: SendToClientAsync observes and logs its own failures.
				_ = SendToClientAsync(client, msg, body);
			}
		}
	}

	// Sends one complete text message. WebSocket permits only one outstanding send per socket,
	// so concurrent publishes to the same client are serialized through its SendLock.
	private static async Task SendToClientAsync(WebSocketClient client, ArraySegment<byte> msg, string body)
	{
		await client.SendLock.WaitAsync();
		try
		{
			if (client.Socket.State != WebSocketState.Open) return;

			await client.Socket.SendAsync(msg, WebSocketMessageType.Text, true, CancellationToken.None);
		}
		catch (Exception ex)
		{
			Program.Logger.Error($"SendTo{client.Host} Message:{body} 失败 ({ex.Message})");
		}
		finally
		{
			client.SendLock.Release();
		}
	}

	// Performs the close handshake when the socket still allows it (Open or CloseReceived).
	// Holds the send lock because a close frame must not overlap an in-flight send; failures
	// (e.g. the peer already vanished or the timeout elapsed) are logged and ignored.
	private static async Task CloseQuietlyAsync(WebSocketClient client)
	{
		WebSocket socket = client.Socket;

		await client.SendLock.WaitAsync();
		try
		{
			if (socket.State != WebSocketState.Open && socket.State != WebSocketState.CloseReceived) return;

			using (CancellationTokenSource timeout = new CancellationTokenSource(CloseTimeout))
			{
				await socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing", timeout.Token);
			}
		}
		catch (Exception ex)
		{
			Program.Logger.Debug($"{client.Host} close failed: {ex.Message}");
		}
		finally
		{
			client.SendLock.Release();
		}
	}

	#region Client Messages
	/// <summary>
	/// Handles a Subscribe request: adds each requested message type to the client's
	/// subscriptions, skipping empty names and ones already subscribed.
	/// </summary>
	/// <param name="client">The requesting client.</param>
	/// <param name="msg">Message whose Content is a message-type name or an array of names.</param>
	private static void Subscribe(WebSocketClient client, ReceiveMessage msg)
	{
		if (client == null || client.Socket.State != WebSocketState.Open
			|| msg?.Content == null) return;

		string[] msgs = ParseMessageTypes(msg.Content);
		if (msgs == null) return;

		// Messages is read concurrently by PublishMessage, so mutate it under its lock.
		lock (client.Messages)
		{
			foreach (string c in msgs)
			{
				if (string.IsNullOrEmpty(c) || client.Messages.Contains(c)) continue;

				client.Messages.Add(c);
			}
		}
	}

	/// <summary>
	/// Handles an Unsubscribe request: removes each requested message type from the
	/// client's subscriptions.
	/// </summary>
	/// <param name="client">The requesting client.</param>
	/// <param name="msg">Message whose Content is a message-type name or an array of names.</param>
	private static void Unsubscribe(WebSocketClient client, ReceiveMessage msg)
	{
		if (client == null || client.Socket.State != WebSocketState.Open
			|| msg?.Content == null) return;

		string[] msgs = ParseMessageTypes(msg.Content);
		if (msgs == null) return;

		// Messages is read concurrently by PublishMessage, so mutate it under its lock.
		lock (client.Messages)
		{
			foreach (string c in msgs)
			{
				if (string.IsNullOrEmpty(c)) continue;

				client.Messages.Remove(c);
			}
		}
	}

	// Extracts message-type names from a Subscribe/Unsubscribe payload. When the message came from
	// Json.NET, Content is a JToken (a JSON string or an array of strings); when built in code it may be
	// a plain string or string[]. Returns null when the payload has none of these shapes.
	private static string[] ParseMessageTypes(object content)
	{
		JToken token = content as JToken;
		if (token == null)
		{
			// A single message type, otherwise (possibly) several.
			string single = content as string;
			return single != null ? new[] { single } : content as string[];
		}

		try
		{
			if (token.Type == JTokenType.String)
			{
				return new[] { token.ToString() };
			}
			if (token.Type == JTokenType.Array)
			{
				JArray array = (JArray)token;
				return array.Select(a => a.Value<string>()).ToArray();
			}
		}
		catch (Exception)
		{
			// e.g. an array element that is an object rather than a scalar.
			Program.Logger.Error("not valid message");
		}
		return null;
	}
	#endregion

	/// <summary>
	/// Message sent by a client.
	/// </summary>
	public class ReceiveMessage
	{
		/// <summary>Handler key, e.g. <see cref="SubscribeMsgType"/>.</summary>
		public string MessageType { get; set; }

		/// <summary>Message payload; shape depends on MessageType.</summary>
		public object Content { get; set; }
	}

	/// <summary>
	/// Message published by the server.
	/// </summary>
	public class SendMessage
	{
		/// <summary>Sender identifier; defaults to "System".</summary>
		public string SenderID { get; set; }

		/// <summary>Type clients subscribe to in order to receive this message.</summary>
		public string MessageType { get; set; }

		/// <summary>Message payload.</summary>
		public object Content { get; set; }

		public SendMessage()
		{
			SenderID = "System";
		}
	}

	/// <summary>
	/// A connected client and its subscriptions.
	/// </summary>
	public class WebSocketClient
	{
		/// <summary>
		/// Unique client id (a new GUID per connection).
		/// </summary>
		public string ClientID { get; }

		/// <summary>
		/// Remote endpoint as "ip:port".
		/// </summary>
		public string Host { get; set; }

		/// <summary>
		/// The client's WebSocket connection.
		/// </summary>
		public WebSocket Socket { get; }

		/// <summary>
		/// Message types this client is subscribed to. Mutated by the client's receive loop and
		/// read by PublishMessage from other threads, so access it under <c>lock (Messages)</c>.
		/// </summary>
		public List<string> Messages { get; }

		// Serializes sends (and the final close) on Socket; WebSocket allows one outstanding send at a time.
		internal SemaphoreSlim SendLock { get; } = new SemaphoreSlim(1, 1);

		public WebSocketClient(string host,
			WebSocket socket)
		{
			ClientID = Guid.NewGuid().ToString();
			this.Host = host;
			this.Socket = socket;
			this.Messages = new List<string>();
		}

		// Thread-safe membership check used by PublishMessage.
		internal bool IsSubscribed(string messageType)
		{
			lock (Messages)
			{
				return Messages.Contains(messageType);
			}
		}
	}
}

你可能感兴趣的:(C#,.NET,.NETCORE)