EMQX源码分析---esockd_connection_sup源码分析

该模块主要负责监督和管理已经建立的socket连接进程,所以这个模块主要包含一些针对连接的管理接口,该模块主要的API如下:

1、start_link(Opts, MFA) 该函数主要是启动esockd_connection_sup进程(一个基于gen_server实现的连接监督者),函数内部调用了OTP的
gen_server:start_link(?MODULE, [Opts, MFA], [])函数,然后回调该模块的init([Opts, MFA])方法。

2、count_connections(Sup) 计算该模块下的socket连接数。内部调用call(Sup, count_connections)发送同步消息,然后该消息被模块的handle_call方法处理。

3、get_max_connections(Sup) 获取最大的连接数,内部调用call(Sup,get_max_connections) 发送同步消息,然后该消息被模块的handle_call方法处理。

4、start_connection(Sup, Sock, UpgradeFuns) 该函数主要是启动一个socket连接进程,内部调用call(Sup, {start_connection, Sock})
   发送同步消息,然后该消息被模块的handle_call方法处理;启动成功后,再把socket的控制权转交给新建的连接进程,并通知它socket已经就绪。

5、set_max_connections(Sup, MaxConns)设置系统最大的连接数,内部调用call(Sup, {set_max_connections, MaxConns}) 发送同步消息,然后该消息被模块的handle_call方法处理。

6、get_shutdown_count(Sup) 获取按关闭原因统计的连接关闭次数(返回{Reason, Count}列表),内部调用call(Sup, get_shutdown_count)发送同步消息,然后被handle_call方法处理。

下面具体看源码注释。

-module(esockd_connection_sup).

-behaviour(gen_server).

-import(proplists, [get_value/3]).

-export([start_link/2, start_connection/3, count_connections/1]).
-export([get_max_connections/1, set_max_connections/2]).
-export([get_shutdown_count/1]).

%% Allow, Deny
-export([access_rules/1, allow/2, deny/2]).

%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2, terminate/2, code_change/3]).

%% How connection processes are stopped when this process terminates:
%% kill immediately, wait forever, or wait the given time before killing.
-type(shutdown() :: brutal_kill | infinity | pos_integer()).

%% Server state: tracked connection pids, connection limit, access rules,
%% shutdown strategy and the callback used to start connection processes.
-record(state, {curr_connections :: map(), max_connections :: pos_integer(), access_rules :: list(), shutdown :: shutdown(), mfargs :: mfa()}).
%% Default maximum number of concurrent connections.
-define(DEFAULT_MAX_CONNS, 1024).
%% Transport module used to operate on sockets.
-define(TRANSPORT, esockd_transport).
%% Log an error message prefixed with this module's name.
-define(ERROR_MSG(Format, Args), error_logger:error_msg("[~s] " ++ Format, [?MODULE | Args])).

%% Start the connection supervisor process.
%% Opts and MFA are passed unchanged to init/1 as the list [Opts, MFA].
-spec(start_link([esockd:option()], esockd:mfargs()) -> {ok, pid()} | ignore | {error, term()}).
start_link(Opts, MFA) ->
    InitArgs = [Opts, MFA],
    gen_server:start_link(?MODULE, InitArgs, []).

%%------------------------------------------------------------------------------
%% API
%%------------------------------------------------------------------------------

%% Ask the supervisor Sup to start a connection process for Sock.
%% On success the socket's controlling process is changed to the new
%% connection process, which is then told the socket is ready (together
%% with UpgradeFuns). `ignore' and `{error, Reason}' replies are returned
%% to the caller unchanged.
start_connection(Sup, Sock, UpgradeFuns) ->
    %% Synchronous request, served by handle_call/3 in the supervisor.
    case call(Sup, {start_connection, Sock}) of
        {ok, ConnPid} = Started ->
            %% Transfer controlling from acceptor to connection
            _ = ?TRANSPORT:controlling_process(Sock, ConnPid),
            %% Let the connection process know it may start using the socket.
            _ = ?TRANSPORT:ready(ConnPid, Sock, UpgradeFuns),
            Started;
        {error, _Reason} = Failed ->
            Failed;
        ignore ->
            ignore
    end.

%% Start a connection process through the configured callback.
%% A bare module M is called as M:start_link/2, {M, F} as M:F/2, and
%% {M, F, Args} with Args appended after the transport module and the
%% socket, e.g. {echo_server, start_link, []}.
-spec(start_connection_proc(esockd:mfargs(), esockd_transport:sock()) -> {ok, pid()} | ignore | {error, term()}).
start_connection_proc(M, Sock) when is_atom(M) ->
    start_connection_proc({M, start_link, []}, Sock);
start_connection_proc({M, F}, Sock) when is_atom(M), is_atom(F) ->
    start_connection_proc({M, F, []}, Sock);
start_connection_proc({M, F, Args}, Sock) when is_atom(M), is_atom(F), is_list(Args) ->
    erlang:apply(M, F, [?TRANSPORT, Sock | Args]).


%% Return the number of connection processes currently tracked by Sup.
-spec(count_connections(pid()) -> integer()).
count_connections(Sup) ->
    call(Sup, count_connections).

%% Return the maximum number of connections Sup currently allows.
-spec(get_max_connections(pid()) -> integer()).
get_max_connections(Sup) when is_pid(Sup) ->
    call(Sup, get_max_connections).

%% Change the maximum number of connections Sup allows.
%% The new limit is stored as given; the server does not validate it.
-spec(set_max_connections(pid(), integer()) -> ok).
set_max_connections(Sup, MaxConns) when is_pid(Sup) ->
    call(Sup, {set_max_connections, MaxConns}).

%% Return how many connections have terminated for each counted reason,
%% as a list of {Reason, Count} pairs. Only atom reasons are counted, and
%% a reason appears in the list once it has been seen at least once.
-spec(get_shutdown_count(pid()) -> [{atom(), pos_integer()}]).
get_shutdown_count(Sup) ->
    call(Sup, get_shutdown_count).

%% Return the access rules held by Sup. Rules whose CIDR is a
%% {Start, End, Len} tuple are returned with the CIDR converted through
%% esockd_cidr:to_string/1; any other rule is returned as stored.
access_rules(Sup) ->
    call(Sup, access_rules).

%% Add an `allow' rule for CIDR to Sup.
%% Replies ok, {error, already_exists} or {error, bad_access_rule}.
allow(Sup, CIDR) ->
    call(Sup, {add_rule, {allow, CIDR}}).

%% Add a `deny' rule for CIDR to Sup.
%% Replies ok, {error, already_exists} or {error, bad_access_rule}.
deny(Sup, CIDR) ->
    call(Sup, {add_rule, {deny, CIDR}}).

%% Send a synchronous request to Sup. The timeout is `infinity', so the
%% caller blocks until the server replies.
call(Sup, Req) ->
    gen_server:call(Sup, Req, infinity).

%%------------------------------------------------------------------------------
%% gen_server callbacks
%%------------------------------------------------------------------------------

%% Build the initial server state from the option list.
%% Recognised options (with defaults): shutdown (brutal_kill),
%% max_connections (?DEFAULT_MAX_CONNS) and access_rules ([{allow, all}]).
init([Opts, MFA]) ->
    %% Trap exits so that exits of linked processes arrive as 'EXIT' messages.
    process_flag(trap_exit, true),
    RawRules = proplists:get_value(access_rules, Opts, [{allow, all}]),
    State = #state{
        %% No connections are tracked yet.
        curr_connections = #{},
        max_connections  = proplists:get_value(max_connections, Opts, ?DEFAULT_MAX_CONNS),
        %% Rules are stored in compiled form.
        access_rules     = [esockd_access:compile(Raw) || Raw <- RawRules],
        shutdown         = proplists:get_value(shutdown, Opts, brutal_kill),
        mfargs           = MFA
    },
    {ok, State}.

%% Refuse new connections once the number of tracked connections has
%% reached the configured maximum.
handle_call({start_connection, _Sock}, _From,
            State = #state{curr_connections = Conns, max_connections = Limit})
  when map_size(Conns) >= Limit ->
    {reply, {error, maxlimit}, State};

%% Start a connection process for Sock.
%%   Sock  - the accepted socket
%%   State - current server state; access_rules decides whether the peer
%%           address is admitted
%% Replies {ok, Pid}, ignore or {error, Reason}.
handle_call({start_connection, Sock}, _From, State = #state{access_rules = Rules}) ->
    %% Resolve the peer address of the socket first.
    case esockd_transport:peername(Sock) of
        {error, Reason} ->
            {reply, {error, Reason}, State};
        {ok, {Addr, _Port}} ->
            %% Only admitted addresses get a connection process.
            case allowed(Addr, Rules) of
                false -> {reply, {error, forbidden}, State};
                true  -> start_and_track(Sock, State)
            end
    end;

%% Number of tracked connections.
handle_call(count_connections, _From, State = #state{curr_connections = Conns}) ->
    {reply, map_size(Conns), State};
%% Current connection limit.
handle_call(get_max_connections, _From, State = #state{max_connections = Limit}) ->
    {reply, Limit, State};
%% Replace the connection limit.
handle_call({set_max_connections, MaxConns}, _From, State) ->
    {reply, ok, State#state{max_connections = MaxConns}};
%% Per-reason shutdown counters, read from the process dictionary.
handle_call(get_shutdown_count, _From, State) ->
    Counts = [{Why, N} || {{shutdown_count, Why}, N} <- get()],
    {reply, Counts, State};
%% Current access rules in their raw (printable) form.
handle_call(access_rules, _From, State = #state{access_rules = Rules}) ->
    {reply, [raw(Compiled) || Compiled <- Rules], State};
%% Compile and prepend a new access rule unless it is invalid or already present.
handle_call({add_rule, RawRule}, _From, State = #state{access_rules = Rules}) ->
    case catch esockd_access:compile(RawRule) of
        {'EXIT', _Error} ->
            {reply, {error, bad_access_rule}, State};
        Compiled ->
            Known = lists:member(Compiled, Rules),
            if
                Known -> {reply, {error, already_exists}, State};
                true  -> {reply, ok, State#state{access_rules = [Compiled | Rules]}}
            end
    end;

%% Anything else is logged and answered with `ignored'.
handle_call(Req, _From, State) ->
    ?ERROR_MSG("unexpected call: ~p", [Req]),
    {reply, ignored, State}.

%% Start the connection process via the configured callback and, on
%% success, record its pid in curr_connections. Any unexpected result
%% (including a caught exception) is replied as {error, Result}.
start_and_track(Sock, State = #state{curr_connections = Conns, mfargs = MFA}) ->
    case catch start_connection_proc(MFA, Sock) of
        {ok, Pid} when is_pid(Pid) ->
            {reply, {ok, Pid}, State#state{curr_connections = Conns#{Pid => true}}};
        ignore ->
            {reply, ignore, State};
        {error, Reason} ->
            {reply, {error, Reason}, State};
        Other ->
            {reply, {error, Other}, State}
    end.

%% No casts are expected; log the message and keep the state unchanged.
handle_cast(Msg, State) ->
    ?ERROR_MSG("unexpected cast: ~p", [Msg]),
    {noreply, State}.

%% Handle the exit of a process.
%% A tracked connection is removed from curr_connections and its exit
%% reason is passed to connection_crashed/3; an 'EXIT' from an untracked
%% pid is only logged.
handle_info({'EXIT', Pid, Reason}, State = #state{curr_connections = Conns}) ->
    case maps:take(Pid, Conns) of
        error ->
            ?ERROR_MSG("unexpected 'EXIT': ~p, reason: ~p", [Pid, Reason]),
            {noreply, State};
        {true, Remaining} ->
            connection_crashed(Pid, Reason, State),
            {noreply, State#state{curr_connections = Remaining}}
    end;

%% Any other message is logged and dropped.
handle_info(Info, State) ->
    ?ERROR_MSG("unexpected info: ~p", [Info]),
    {noreply, State}.

%% On termination, stop all tracked connection processes according to the
%% configured shutdown strategy.
terminate(_Reason, State) ->
    terminate_children(State).

%% Code upgrade hook: the state is carried over unchanged.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%%------------------------------------------------------------------------------
%% Internal functions
%%------------------------------------------------------------------------------
%% Decide whether Addr may connect under Rules.
%% An address is admitted when no rule matches it or when the matching
%% rule is `allow'; it is rejected only by a matching `deny' rule.
allowed(Addr, Rules) ->
    Verdict = esockd_access:match(Addr, Rules),
    case Verdict of
        {matched, deny}  -> false;
        {matched, allow} -> true;
        nomatch          -> true
    end.
%% Convert a stored access rule back to its raw form.
%% allow/deny rules carrying a {Start, End, Len} CIDR tuple get the CIDR
%% rendered with esockd_cidr:to_string/1; every other rule is returned as is.
raw({Action, CIDR = {_Start, _End, _Len}}) when Action =:= allow; Action =:= deny ->
    {Action, esockd_cidr:to_string(CIDR)};
raw(Rule) ->
    Rule.

%% React to the exit reason of a tracked connection process.
%%   normal | shutdown | killed      - nothing to do
%%   any other atom                  - counted per reason
%%   {shutdown, Atom}                - counted per reason
%%   {shutdown, Term}                - reported as connection_shutdown
%%   anything else                   - reported as connection_crashed
connection_crashed(_Pid, Reason, _State)
  when Reason =:= normal; Reason =:= shutdown; Reason =:= killed ->
    ok;
connection_crashed(_Pid, Reason, _State) when is_atom(Reason) ->
    count_shutdown(Reason);
connection_crashed(Pid, {shutdown, Reason}, State) ->
    case is_atom(Reason) of
        true  -> count_shutdown(Reason);
        false -> report_error(connection_shutdown, Reason, Pid, State)
    end;
connection_crashed(Pid, Reason, State) ->
    report_error(connection_crashed, Reason, Pid, State).

%% Increment the per-reason shutdown counter kept in the process
%% dictionary under {shutdown_count, Reason}. Returns the previous
%% counter value, or `undefined' for the first occurrence.
count_shutdown(Reason) ->
    Key = {shutdown_count, Reason},
    Next = case get(Key) of
               undefined -> 1;
               Seen      -> Seen + 1
           end,
    put(Key, Next).

%% Stop every tracked connection process according to the shutdown strategy:
%%   brutal_kill - send `kill' and wait for the children to go down
%%   infinity    - send `shutdown' and wait without a time limit
%%   integer     - send `shutdown' and start a kill timer of that many
%%                 milliseconds before waiting
%% Children that went down with an unexpected reason are reported afterwards.
terminate_children(State = #state{curr_connections = Conns, shutdown = Shutdown}) ->
    %% Switch from links to monitors; children that already exited abnormally
    %% are collected in EStack0.
    {Pids, EStack0} = monitor_children(Conns),
    Sz = sets:size(Pids),
    %% Pick the exit signal for the configured strategy.
    Signal = case Shutdown of
                 brutal_kill             -> kill;
                 infinity                -> shutdown;
                 Ms when is_integer(Ms)  -> shutdown
             end,
    sets:fold(fun(P, _) -> exit(P, Signal) end, ok, Pids),
    %% Only a timed shutdown needs the kill timer.
    TRef = case Shutdown of
               Timeout when is_integer(Timeout) ->
                   erlang:start_timer(Timeout, self(), kill);
               _ ->
                   undefined
           end,
    EStack = wait_children(Shutdown, Pids, Sz, TRef, EStack0),
    %% Unroll stacked errors and report them; each entry holds the list of
    %% pids that went down with that reason.
    dict:fold(fun(Reason, Offenders, _) ->
                  report_error(connection_shutdown_error, Reason, Offenders, State)
              end, ok, EStack).

%% Monitor every tracked connection process.
%% Returns {Pids, EStack}: Pids is the set of children now being monitored,
%% EStack a dict mapping an exit reason to the children that had already
%% exited with it. Children that already exited with `normal' are dropped.
monitor_children(Conns) ->
    Collect = fun(Child, {Alive, Errors}) ->
                  case monitor_child(Child) of
                      {error, normal} ->
                          {Alive, Errors};
                      {error, Why} ->
                          {Alive, dict:append(Why, Child, Errors)};
                      ok ->
                          {sets:add_element(Child, Alive), Errors}
                  end
              end,
    lists:foldl(Collect, {sets:new(), dict:new()}, maps:keys(Conns)).

%% Helper for terminate_children/1 (via monitor_children/1): switches the
%% supervision of Pid from a link to a monitor.
%% Returns `ok' when no 'EXIT' message from Pid is pending, or
%% {error, Reason} when Pid had already exited with Reason.
monitor_child(Pid) ->
    %% Do the monitor operation first so that if the child dies
    %% before the monitoring is done causing a 'DOWN'-message with
    %% reason noproc, we will get the real reason in the 'EXIT'-message
    %% unless a naughty child has already done unlink...
    erlang:monitor(process, Pid),
    unlink(Pid),

    receive
	%% If the child dies before the unlink we must empty
	%% the mail-box of the 'EXIT'-message and the 'DOWN'-message.
	{'EXIT', Pid, Reason} ->
	    receive
		{'DOWN', _, process, Pid, _} ->
		    {error, Reason}
	    end
    after 0 ->
	    %% If a naughty child did unlink and the child dies before
	    %% monitor the result will be that the caller receives a
	    %% 'DOWN'-message with reason noproc.
	    %% If the child should die after the unlink there
	    %% will be a 'DOWN'-message with a correct reason
	    %% that will be handled in wait_children/5.
	    ok
    end.

%% Wait for the 'DOWN' messages of the monitored children.
%%   Shutdown - shutdown strategy (brutal_kill | infinity | timeout in ms)
%%   Pids     - set of children that have not reported 'DOWN' yet
%%   Sz       - number of children still outstanding
%%   TRef     - kill-timer reference, or `undefined' when no timer is active
%%   EStack   - dict mapping an unexpected exit reason to the pids that
%%              went down with it
%% Returns the final EStack once Sz reaches 0.
wait_children(_Shutdown, _Pids, 0, undefined, EStack) ->
    EStack;
wait_children(_Shutdown, _Pids, 0, TRef, EStack) ->
    %% If the timer has expired before its cancellation, we must empty the
    %% mail-box of the 'timeout'-message.
    erlang:cancel_timer(TRef),
    receive
        {timeout, TRef, kill} ->
            EStack
    after 0 ->
            EStack
    end;

%%TODO: Copied from supervisor.erl, rewrite it later.
%% brutal_kill: `killed' is the expected reason, anything else is recorded.
wait_children(brutal_kill, Pids, Sz, TRef, EStack) ->
    receive
        {'DOWN', _MRef, process, Pid, killed} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);

        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(brutal_kill, sets:del_element(Pid, Pids),
                          Sz-1, TRef, dict:append(Reason, Pid, EStack))
    end;

%% infinity / timed shutdown: `shutdown' and `normal' are expected reasons.
wait_children(Shutdown, Pids, Sz, TRef, EStack) ->
    receive
        {'DOWN', _MRef, process, Pid, shutdown} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, normal} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1, TRef, EStack);
        {'DOWN', _MRef, process, Pid, Reason} ->
            wait_children(Shutdown, sets:del_element(Pid, Pids), Sz-1,
                          TRef, dict:append(Reason, Pid, EStack));
        {timeout, TRef, kill} ->
            %% The grace period elapsed: kill every child still alive.
            %% No child has gone down because of this message, so the
            %% outstanding count must NOT be decremented; otherwise the wait
            %% would end with one child's 'DOWN' still unaccounted for.
            sets:fold(fun(P, _) -> exit(P, kill) end, ok, Pids),
            wait_children(Shutdown, Pids, Sz, undefined, EStack)
    end.

%% Emit a supervisor-style error report for a connection process.
%%   Error  - error context tag (e.g. connection_crashed)
%%   Reason - exit reason being reported
%%   Pid    - offending pid (or list of pids)
%%   State  - server state; only mfargs is included in the report
report_error(Error, Reason, Pid, #state{mfargs = MFA}) ->
    Offender = [{pid, Pid}, {name, connection}, {mfargs, MFA}],
    %% The supervisor name is derived from this process's own pid.
    SupName = list_to_atom("esockd_connection_sup - " ++ pid_to_list(self())),
    Report = [{supervisor, SupName},
              {errorContext, Error},
              {reason, Reason},
              {offender, Offender}],
    error_logger:error_report(supervisor_report, Report).

下一篇将介绍esockd_acceptor_sup 模块的基本功能。

你可能感兴趣的:(Erlang)