TLS双向认证,三级证书链

  1. 服务器
/**
 *  Filename: server.c
 *   Created: 2019-09-19 11:15:30
 *      Desc: epoll-based TLS echo server demonstrating mutual (two-way) certificate authentication
 *    Author: hair-man 
 *   Company: owner 
 */

#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 


#include 
#include 

#include 
#include 

/* Build-time switch: when defined, accepted connections are wrapped in TLS (OpenSSL). */
#define USETLS


/* Capacity of the epoll_event array allocated in main() and the size hint passed to epoll_create(). */
#define MAX_EPOLL_EVENTS 1000000
/* Size in bytes of the receive/reply buffer shared by all connections. */
#define MAX_BUFFER_SIZE  65536 

/*
 * Per-descriptor context stored in epoll_event.data.ptr.
 * The tag no longer starts with an underscore: file-scope identifiers beginning
 * with '_' are reserved for the implementation (C11 7.1.3).
 */
typedef struct ssl_ctx_context
{
    int fd;     /* listening or accepted socket descriptor */
    SSL* ssl;   /* TLS session of an accepted client; NULL for the listening socket */
}scontext_t;


/*
 * Print the subject and issuer names of the peer's certificate (labels are in
 * Chinese), or a notice when the peer presented no certificate.
 */
void ShowCerts(SSL *ssl)
{
    X509 *cert = SSL_get_peer_certificate(ssl);   /* owned reference, released below */
    char *line = NULL;

    if (cert == NULL)
    {
        printf("无证书信息!\n");
        return;
    }

    printf("数字证书信息:\n");

    /* With a NULL buffer X509_NAME_oneline() allocates via OPENSSL_malloc(),
     * so the result must be released with OPENSSL_free(), not free(). */
    line = X509_NAME_oneline(X509_get_subject_name(cert), NULL, 0);
    printf("证书: %s\n", line != NULL ? line : "(null)");
    OPENSSL_free(line);

    line = X509_NAME_oneline(X509_get_issuer_name(cert), NULL, 0);
    printf("颁发者: %s\n", line != NULL ? line : "(null)");
    OPENSSL_free(line);

    X509_free(cert);
}

/*
 * Certificate-verification callback installed by init_ssl_ctx().
 *
 * ok  - OpenSSL's verdict for the certificate currently being checked.
 * ctx - store context used to report why verification failed.
 *
 * Returns ok unchanged. The previous version forced 1, which made the
 * handshake accept any client certificate (expired, self-signed, wrong CA)
 * and defeated mutual authentication. Failures are now logged and rejected.
 */
int verify_callback(int ok, X509_STORE_CTX *ctx)
{
    printf("\r\n ok = %d\r\n", ok);
    if (!ok) 
    {
        int err = X509_STORE_CTX_get_error(ctx);
        int depth = X509_STORE_CTX_get_error_depth(ctx);

        fprintf(stderr, "certificate verify failed at depth %d: %d (%s)\n",
                depth, err, X509_verify_cert_error_string(err));
    }
    return ok;
}

SSL_CTX* init_ssl_ctx(int verify_client)
{
    SSL_CTX *ctx = NULL;
    int ret_err = 0;

    ctx = SSL_CTX_new(SSLv23_server_method());
    if (ctx == NULL)
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_CTX_new failed!\n");
        return NULL;
    }

#define SERVER_CERT     "third_cert/server.crt"
#define SERVER_KEY      "third_cert/ca.key"
#define SERVER_CA_CERT  "third_cert/ca.crt"

    /* 加载服务器证书 */
    ret_err = SSL_CTX_use_certificate_file(ctx, SERVER_CA_CERT, SSL_FILETYPE_PEM);
    if (ret_err <= 0)
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_CTX_use_certificate_file failed!\n");
        return NULL;
    }

    /* 加载证书链 */
    ret_err = SSL_CTX_use_certificate_chain_file(ctx, SERVER_CA_CERT);
    if (ret_err <= 0)
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_CTX_use_certificate_chain_file failed!\n");
        return NULL;
    }

    /* 加载服务器密钥 */
    ret_err = SSL_CTX_use_PrivateKey_file(ctx, SERVER_KEY, SSL_FILETYPE_PEM);
    if (ret_err <= 0)
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_CTX_use_PrivateKey_file failed!\n");
        return NULL;
    }

    /* 检查私钥与证书是否匹配 */
    if (!SSL_CTX_check_private_key(ctx))
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_CTX_check_private_key failed!\n");
        return NULL;
    }

    if(1 == verify_client)   /* 双向认证 */
    {
        /* 加载根证书 */
        ret_err = SSL_CTX_load_verify_locations(ctx, SERVER_CA_CERT, NULL);
        if (ret_err <= 0)
        {
            ERR_print_errors_fp(stderr);
            fprintf(stderr, "SSL_CTX_load_verify_locations failed!\n");
            return NULL;
        } 

        SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, verify_callback);
        SSL_CTX_set_verify_depth(ctx, 2);
    }
    else    /* 单向认证 */
    {
        SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
    }

    //设置单双向认证
    //SSL_VERIFY_NONE
    //作为服务器:服务器不会向客户端询问客户端证书
    //作为客户端:服务器会向客户端发送一个证书,不关心校验结果
    //SSL_VERIFY_PEER
    //作为服务器:服务器向客户端询问客户端证书,并检查,验证失败则终止
    //作为客户端:检查服务求发来的证书,验证失败则终止
    //
    //
    //
    //下面两个标志必须与SSL_VERIFY_PEER联合使用
    //SSL_VERIFY_FAIL_IF_NO_PEER_CERT
    //作为服务器有效:客户端如果不发送证书则表示验证失败,终止
    //SSL_VERIFY_CLIENT_ONCE
    //作为服务器有效:尽在初始TLS握手时请求客户端证书,重新协商不需要客户端证书
    
    //双向 - 让客户端发送客户端证书并进行验证
    //SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, NULL);

    //单向 - 不需要客户端发送证书
    //SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);




    return ctx;
}

/*
 * Read up to n bytes from fd into buffer, retrying short reads and EINTR.
 *
 * Returns the number of bytes stored (less than n only when EOF is reached
 * first), or -1 on a read error other than EINTR (errno set by read()).
 * Each chunk is stored right after the bytes already read. The previous
 * version used only the last chunk's size as the offset, which overwrote
 * data after more than two partial reads.
 */
int readn(int fd, uint8_t* buffer, int n)
{
    int nleft = n;

    while (nleft > 0)
    {
        ssize_t nread = read(fd, buffer + (n - nleft), (size_t)nleft);

        if (nread < 0)
        {
            if (errno == EINTR)
                continue;       /* interrupted by a signal: retry */
            return -1;
        }
        if (nread == 0)
            break;              /* EOF */

        nleft -= (int)nread;
    }

    return n - nleft;
}

/*
 * Write exactly n bytes from buffer to fd, retrying short writes and EINTR.
 *
 * Returns n on success, or -1 on error (errno set by write()). Each retry
 * continues right after the bytes already written; the previous version used
 * only the last chunk's size as the offset. A zero-byte write is treated as an
 * error instead of consulting a stale errno and possibly spinning.
 */
int writen(int fd, uint8_t* buffer, int n)
{
    int nleft = n;

    while (nleft > 0)
    {
        ssize_t nwrite = write(fd, buffer + (n - nleft), (size_t)nleft);

        if (nwrite < 0)
        {
            if (errno == EINTR)
                continue;       /* interrupted by a signal: retry */
            return -1;
        }
        if (nwrite == 0)
            return -1;          /* no progress possible */

        nleft -= (int)nwrite;
    }

    return n;
}

/*
 * Switch fd to non-blocking I/O by adding O_NONBLOCK to its current
 * file-status flags. Returns 0 on success, or -1 after logging errno.
 */
int set_nonblock(int fd)
{
    int flags = fcntl(fd, F_GETFL);

    if (flags < 0)
    {
        fprintf(stderr, "get block mode failed!Error:%d ErrMsg:%s\n", errno, strerror(errno));
        return -1;
    }

    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) >= 0)
        return 0;

    fprintf(stderr, "set block mode failed!Error:%d ErrMsg:%s\n", errno, strerror(errno));
    return -1;
}


/*
 * Create, configure and start a non-blocking IPv4 TCP listening socket.
 *
 * sockfd - out: the listening descriptor on success; -1 on failure (every
 *          failure path closes the socket it created).
 * ip     - local address in network byte order (e.g. from inet_addr()).
 * port   - local port in host byte order.
 *
 * Returns 0 on success, -1 on failure (a diagnostic is printed).
 */
int init_socket(int *sockfd, uint32_t ip, uint16_t port)
{
    const int reuse = 1;
    const int sndmem = 1024 * 1024;
    const int rcvmem = 1024 * 1024;
    struct sockaddr_in my_addr;
    struct in_addr log_addr;

    /* Copy into a real struct in_addr for inet_ntoa() instead of casting &ip. */
    log_addr.s_addr = ip;

    *sockfd = socket(AF_INET, SOCK_STREAM, 0);
    if (*sockfd < 0)
    {
        fprintf(stderr, "socket failed. Error:%d ErrMsg:%s\n", errno, strerror(errno));
        return -1;
    }

    /* Port reuse: SO_REUSEADDR must be set before bind(). */
    if (setsockopt(*sockfd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse)) != 0)
    {
        fprintf(stderr, "set reuse addr failed.\n");
        goto fail;
    }

    memset(&my_addr, 0, sizeof(my_addr));
    my_addr.sin_family = AF_INET;
    my_addr.sin_port = htons(port);
    my_addr.sin_addr.s_addr = ip;
    if (bind(*sockfd, (struct sockaddr *)&my_addr, sizeof(my_addr)) != 0)
    {
        fprintf(stderr, "bind failed. ip: %s port: %d Error:%d ErrMsg:%s\n",
                inet_ntoa(log_addr), port, errno, strerror(errno));
        goto fail;
    }
    fprintf(stderr, "bind success. ip: %s port: %d\n", inet_ntoa(log_addr), port);

    /* Buffer sizes are best effort: failures are logged but not fatal. */
    if (setsockopt(*sockfd, SOL_SOCKET, SO_RCVBUF, &rcvmem, sizeof(rcvmem)) != 0)
    {
        fprintf(stderr, "setsockopt SO_RCVBUF failed. ip: %s port: %d\n", inet_ntoa(log_addr), port);
    }
    if (setsockopt(*sockfd, SOL_SOCKET, SO_SNDBUF, &sndmem, sizeof(sndmem)) != 0)
    {
        fprintf(stderr, "setsockopt SO_SNDBUF failed. ip: %s port: %d\n", inet_ntoa(log_addr), port);
    }

    if (listen(*sockfd, 32) != 0)
    {
        fprintf(stderr, "listen socket tcp fd failed! Error:%d ErrMsg:%s\n", errno, strerror(errno));
        goto fail;
    }

    /* Non-blocking so accept() never stalls the epoll loop. */
    if (0 != set_nonblock(*sockfd))
    {
        printf("set nonblock failed!\n");
        goto fail;
    }

    return 0;

fail:
    close(*sockfd);
    *sockfd = -1;
    return -1;
}

/*
 * Register fd with the epoll instance efd.
 *
 * events    - epoll event mask (e.g. EPOLLIN).
 * user_data - pointer returned in epoll_event.data.ptr when fd becomes ready.
 *
 * Returns 0 on success, -1 if epoll_ctl(EPOLL_CTL_ADD) fails (logged with errno).
 */
int add_epoll_fd(int efd, int fd, uint32_t events, void* user_data)
{
    struct epoll_event ev;

    memset(&ev, 0, sizeof(ev));   /* no uninitialised bytes are handed to the kernel */
    ev.events = events;
    ev.data.ptr = user_data;
    if (epoll_ctl(efd, EPOLL_CTL_ADD, fd, &ev) == -1)
    {
        /* %d: efd/fd are signed; %u would print -1 as 4294967295. */
        fprintf(stderr, "epoll_ctl ADD failed! efd:%d, connfd:%d Error:%d ErrMsg:%s\n", efd, fd, errno, strerror(errno));
        return -1;
    }

    return 0;
}

/*
 * Remove fd from the epoll instance efd.
 * Returns 0 on success, -1 if epoll_ctl(EPOLL_CTL_DEL) fails (logged with errno).
 */
int del_epoll_fd(int efd, int fd)
{
    /* Kernels before 2.6.9 require a non-NULL event pointer for EPOLL_CTL_DEL;
     * zero it so no uninitialised memory is passed. */
    struct epoll_event ev;

    memset(&ev, 0, sizeof(ev));
    if (epoll_ctl(efd, EPOLL_CTL_DEL, fd, &ev) == -1)
    {
        fprintf(stderr, "epoll_ctl DEL failed! efd:%d, connfd:%d Error:%d ErrMsg:%s\n", efd, fd, errno, strerror(errno));
        return -1;
    }

    return 0;
}

/* Initialise the OpenSSL library; main() calls this before creating the SSL_CTX. */
void init_ssl()
{
    /* Register libssl's TLS ciphers and digests. */
    SSL_library_init();
    /* Register every libcrypto cipher and digest algorithm. */
    OpenSSL_add_all_algorithms();
    /* Load readable error strings for ERR_print_errors_fp() diagnostics. */
    SSL_load_error_strings();
}

/* Report the argument count, print the command-line synopsis and exit(0); never returns. */
void usage(int argc, char** argv)
{
    const char *prog_name = argv[0];

    fprintf(stderr, "argc:%d", argc);
    fprintf(stdout, "\neg \n\t%s --ip [x.x.x.x] --port [0~65535]\n\n", prog_name);
    exit(0);
}



/*
 * Parse "--ip <addr> --port <port>" from argv.
 *
 * ip   - destination for the --ip argument. main() passes char[32]. Arguments
 *        longer than the longest dotted-quad IPv4 string ("255.255.255.255")
 *        are rejected, so the copy can never overflow that buffer (the old
 *        strcpy could).
 * port - receives --port, which must be a decimal integer in 0..65535.
 *
 * Returns 0 on success, -1 on an unknown option or an invalid value.
 */
int check_option(char*ip, uint16_t* port, int argc, char** argv)
{
    const size_t ip_max = sizeof("255.255.255.255");   /* includes the NUL */
    int opt = 0;
    struct option opts[] = 
    {
        {"ip",   required_argument, NULL, 1},
        {"port", required_argument, NULL, 2},
        {0, 0, 0, 0}
    };

    while ((opt = getopt_long(argc, argv, "", opts, NULL)) != -1)
    {
        switch (opt)
        {
            case 1:
            {
                size_t len = strlen(optarg);

                if (len >= ip_max)
                {
                    fprintf(stderr, "ip argument too long: %s\n", optarg);
                    return -1;
                }
                memcpy(ip, optarg, len + 1);
                fprintf(stdout, "ip:%s\n", optarg);
                break;
            }
            case 2:
            {
                char *end = NULL;
                long value = 0;

                /* strtol (unlike atoi) reports trailing junk and overflow. */
                errno = 0;
                value = strtol(optarg, &end, 10);
                if (end == optarg || *end != '\0' || errno == ERANGE || value < 0 || value > 65535)
                {
                    fprintf(stderr, "invalid port: %s\n", optarg);
                    return -1;
                }
                *port = (uint16_t)value;
                fprintf(stdout, "port:%s\n", optarg);
                break;
            }
            default:
                fprintf(stdout, "get opt fialed!\n");
                return -1;
        }
    }

    return 0;
}


int main(int argc, char** argv)
{
    int ret = 0;
    int serverfd = 0;

    int clientfd = 0;
    struct sockaddr_in client_addr;
    socklen_t client_len = sizeof(struct sockaddr);

    int efd = 0;
    int num = 0;
    int fd_counts = 0;

#ifdef USETLS
    SSL* ssl = NULL;
    SSL_CTX* ctx = NULL;
#endif

    uint8_t* buffer = (uint8_t*)malloc(MAX_BUFFER_SIZE);

    char ip[32] = {0};
    uint16_t port = 0;
    uint8_t* ca = (uint8_t*)malloc(MAX_BUFFER_SIZE);
    uint8_t* pri = (uint8_t*)malloc(MAX_BUFFER_SIZE);
    uint8_t* cer = (uint8_t*)malloc(MAX_BUFFER_SIZE);

    struct epoll_event* events = (struct epoll_event*)malloc(sizeof(struct epoll_event) * MAX_EPOLL_EVENTS);
    memset(events, 0, sizeof(struct epoll_event) * MAX_EPOLL_EVENTS);

    if(argc != 5)
        usage(argc, argv);

    if(-1==  check_option(ip, &port, argc, argv))
    {
        fprintf(stderr, "check option failed!\n");
        exit(0);
    }

#ifdef USETLS
    init_ssl();
#endif

    ret = init_socket(&serverfd, inet_addr(ip), port);
    if(ret != 0)
    {
        exit(0);
    }


    efd = epoll_create(MAX_EPOLL_EVENTS);
    if(efd <= 0)
    {
        fprintf(stderr, "epoll_create failed! Error:%d ErrMsg:%s\n", errno, strerror(errno));
        return -1;
    }

    scontext_t* scontext = (scontext_t*)malloc(sizeof(scontext_t));
    if(scontext)
        memset(scontext, 0, sizeof(scontext_t));
    else
    {
        fprintf(stderr, "ssl ctx context create failed!\n");
        exit(0);
    }

    scontext->fd = serverfd;

#ifdef USETLS
    ctx = init_ssl_ctx(1);
#endif

    if(0 != add_epoll_fd(efd, serverfd, EPOLLIN, scontext))
    {
        printf("add listen socket tcp fd to epoll handle failed!");
        return -1;
    }

    fd_counts ++;


    do
    {
        num = epoll_wait(efd, events, fd_counts, -1);

        while(num--)
        {
            scontext = (scontext_t*)events[num].data.ptr;
            
            if(events[num].events & EPOLLERR || events[num].events & EPOLLHUP)
            {
                fprintf(stderr, "events ERR or HUP!\n");
                close(scontext->fd);
            }
            else if(scontext->fd == serverfd)
            {
                fprintf(stdout, "new connect is comming\n");

                memset(&client_addr, 0, sizeof(struct sockaddr));
                clientfd = accept(scontext->fd, (struct sockaddr*)&client_addr, &client_len);
                if(clientfd < 0)
                {
                    fprintf(stderr, "new client accept failed!\n");
                    continue;
                }

                fprintf(stdout, "new client [%s:%d] is accepted!\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

                scontext = (scontext_t*)malloc(sizeof(scontext_t));
                scontext->ssl = NULL;
#ifdef USETLS
                ssl = SSL_new(ctx);
                SSL_set_fd(ssl, clientfd);
                if ((ret = SSL_accept(ssl)) != 1)
                {
                    ERR_print_errors_fp(stderr);
                    ERR_print_errors_fp(stdout);
                    fprintf(stderr, "ssl accept failed! ret %d errcode:%d\n", ret, SSL_get_error(ssl, ret));
                    close(clientfd);
                    continue;
                }

                /* 显示对端证书 */
                ShowCerts(ssl);

                scontext->ssl = ssl;
#endif

                scontext->fd = clientfd;
#if 1
                if(0 != set_nonblock(clientfd))
                {
                    printf("client fd set nonblock failed!\n");
                    continue;
                }
#endif

                if(0 != add_epoll_fd(efd, clientfd, EPOLLIN | EPOLLOUT, scontext))
                {
                    fprintf(stderr, "add epoll fd failed!\n");
                    continue;
                }

                continue;
            }


            if(events[num].events & EPOLLIN)
            {
                clientfd = scontext->fd;
                if(clientfd < 0)
                {
                    fprintf(stderr, "read fd < 0\n");
                    continue;
                }

                printf("epollin socket [%d]\n", clientfd);

                memset(buffer, 0, MAX_BUFFER_SIZE);
#ifndef USETLS
                ret = recv(clientfd, buffer, MAX_BUFFER_SIZE - 1, 0);
                if(ret <= 0)
                {
                    printf("cliet close!\n");
                    if(errno == ECONNRESET || ret == 0)    
                    {

                        if(0 != del_epoll_fd(efd, clientfd))
                        {
                            printf("del epoll fd failed!\n");
                        }

                        close(clientfd);
                    }

                    continue;
                }
#else
                ssl = scontext->ssl;
                ret = SSL_read(ssl, buffer, MAX_BUFFER_SIZE);
                if(ret <= 0)
                {
                    ret = SSL_get_error(ssl, ret);
                    fprintf(stderr, "SSL_read has error! ssl get errcode:%d\n", ret);

                    if(ret == SSL_ERROR_WANT_READ)
                    {
                        fprintf(stderr, "SSL_ERROR_WANT_READ\n");
                    }
                    else if(ret == SSL_ERROR_WANT_WRITE)
                    {
                        fprintf(stderr, "SSL_ERROR_WANT_WRITE");
                    }
                    
                }
#endif

                fprintf(stdout, "recv buffer: %s\n", buffer);

                memset(buffer, 0, MAX_BUFFER_SIZE);

                if(!getpeername(clientfd, (struct sockaddr*)&client_addr, &client_len))
                {
                    sprintf((char*)buffer, "HTTP/1.1 200 OK >>>> ------------- %s:%u", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
                }

#ifndef USETLS
                if(-1 == writen(clientfd, (uint8_t*)buffer, (int)strlen((char*)buffer)))
                {
                    fprintf(stderr, "write failde!\n");
                }
#else
                ret = SSL_write(ssl, buffer, (int)strlen((char*)buffer));
                if(ret <= 0)
                {
                    ret = SSL_get_error(ssl, ret);
                    fprintf(stderr, "SSL_write has error! ssl get errcode:%d\n", ret);

                    if(ret == SSL_ERROR_WANT_READ)
                    {
                        fprintf(stderr, "-> SSL_ERROR_WANT_READ\n");
                    }
                    else if(ret == SSL_ERROR_WANT_WRITE)
                    {
                        fprintf(stderr, "-> SSL_ERROR_WANT_WRITE");
                    }
                    
                }
#endif

            }

#if 0
            if(events[num].events & EPOLLOUT)
            {
                clientfd = scontext->fd;
                if(clientfd < 0)
                {
                    fprintf(stderr, "read fd < 0\n");
                    continue;
                }

                printf("epollout socket [%d]\n", clientfd);

                memset(buffer, 0, MAX_BUFFER_SIZE);
                sprintf((char*)buffer, "HTTP/1.1 200 OK\r\nContent-Length: %d\r\n\r\n%s", (int)strlen("hello tls world"), "hello tls world");

                if(-1 == writen(clientfd, (uint8_t*)buffer, (int)strlen((char*)buffer)))
                {
                    fprintf(stderr, "write failde!\n");
                }
            }
#endif

        }
    }while(1);
        
    return 0;
}

2. 客户端

/**
 *  Filename: client.c
 *   Created: 2019-09-19 17:38:25
 *      Desc: TODO (some description)
 *    Author: hair-man
 *   Company: owner 
 */


#include 
#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 


#include 
#include 

#include 

/* Build-time switch: when defined, the connection is wrapped in TLS (OpenSSL). */
#define USETLS
/* Size in bytes of the message buffer used for each send/receive. */
#define MAX_BUFFER_SIZE 65536


/*
 * Read up to n bytes from fd into buffer, retrying short reads and EINTR.
 *
 * Returns the number of bytes stored (less than n only when EOF is reached
 * first), or -1 on a read error other than EINTR (errno set by read()).
 * Each chunk is stored right after the bytes already read. The previous
 * version used only the last chunk's size as the offset, which overwrote
 * data after more than two partial reads.
 */
int readn(int fd, uint8_t* buffer, int n)
{
    int nleft = n;

    while (nleft > 0)
    {
        ssize_t nread = read(fd, buffer + (n - nleft), (size_t)nleft);

        if (nread < 0)
        {
            if (errno == EINTR)
                continue;       /* interrupted by a signal: retry */
            return -1;
        }
        if (nread == 0)
            break;              /* EOF */

        nleft -= (int)nread;
    }

    return n - nleft;
}

/*
 * Write exactly n bytes from buffer to fd, retrying short writes and EINTR.
 *
 * Returns n on success, or -1 on error (errno set by write()). Each retry
 * continues right after the bytes already written; the previous version used
 * only the last chunk's size as the offset. A zero-byte write is treated as an
 * error instead of consulting a stale errno and possibly spinning.
 */
int writen(int fd, uint8_t* buffer, int n)
{
    int nleft = n;

    while (nleft > 0)
    {
        ssize_t nwrite = write(fd, buffer + (n - nleft), (size_t)nleft);

        if (nwrite < 0)
        {
            if (errno == EINTR)
                continue;       /* interrupted by a signal: retry */
            return -1;
        }
        if (nwrite == 0)
            return -1;          /* no progress possible */

        nleft -= (int)nwrite;
    }

    return n;
}


/* Report the argument count, print the command-line synopsis and exit(0); never returns. */
void usage(int argc, char** argv)
{
    const char *prog_name = argv[0];

    fprintf(stderr, "argc:%d", argc);
    fprintf(stdout, "\neg \n\t%s --ip [xxx.xxx.xxx.xxx] --port [1~65535]\n\n", prog_name);
    exit(0);
}


/*
 * Parse "--ip <addr> --port <port>" from argv.
 *
 * ip   - destination for the --ip argument. main() passes char[32]. Arguments
 *        longer than the longest dotted-quad IPv4 string ("255.255.255.255")
 *        are rejected, so the copy can never overflow that buffer (the old
 *        strcpy could).
 * port - receives --port, which must be a decimal integer in 1..65535
 *        (the range advertised by usage()).
 *
 * Returns 0 on success, -1 on an unknown option or an invalid value.
 */
int check_option(char* ip, uint16_t* port, int argc, char** argv)
{
    const size_t ip_max = sizeof("255.255.255.255");   /* includes the NUL */
    int opt = 0;
    struct option opts[] = 
    {
        {"ip",   required_argument, NULL, 1},
        {"port", required_argument, NULL, 2},
        {0, 0, 0, 0}
    };

    while ((opt = getopt_long(argc, argv, "", opts, NULL)) != -1)
    {
        switch (opt)
        {
            case 1:
            {
                size_t len = strlen(optarg);

                if (len >= ip_max)
                {
                    fprintf(stderr, "ip argument too long: %s\n", optarg);
                    return -1;
                }
                memcpy(ip, optarg, len + 1);
                fprintf(stdout, "ip:%s\n", optarg);
                break;
            }
            case 2:
            {
                char *end = NULL;
                long value = 0;

                /* strtol (unlike atoi) reports trailing junk and overflow. */
                errno = 0;
                value = strtol(optarg, &end, 10);
                if (end == optarg || *end != '\0' || errno == ERANGE || value < 1 || value > 65535)
                {
                    fprintf(stderr, "invalid port: %s\n", optarg);
                    return -1;
                }
                *port = (uint16_t)value;
                fprintf(stdout, "port:%s\n", optarg);
                break;
            }
            default:
                fprintf(stdout, "get opt fialed!\n");
                return -1;
        }
    }

    return 0;
}

/*
 * Switch fd to non-blocking I/O by adding O_NONBLOCK to its current
 * file-status flags. Returns 0 on success, or -1 after logging errno.
 */
int set_nonblock(int fd)
{
    int flags = fcntl(fd, F_GETFL);

    if (flags < 0)
    {
        fprintf(stderr, "get block mode failed!Error:%d ErrMsg:%s\n", errno, strerror(errno));
        return -1;
    }

    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) >= 0)
        return 0;

    fprintf(stderr, "set block mode failed!Error:%d ErrMsg:%s\n", errno, strerror(errno));
    return -1;
}

/*
 * Print the subject and issuer names of the peer's (server's) certificate
 * (labels are in Chinese), or a notice when no certificate is available.
 */
void ShowCerts(SSL *ssl)
{
    X509 *cert = SSL_get_peer_certificate(ssl);   /* owned reference, released below */
    char *line = NULL;

    if (cert == NULL)
    {
        printf("无证书信息!\n");
        return;
    }

    printf("数字证书信息:\n");

    /* With a NULL buffer X509_NAME_oneline() allocates via OPENSSL_malloc(),
     * so the result must be released with OPENSSL_free(), not free(). */
    line = X509_NAME_oneline(X509_get_subject_name(cert), NULL, 0);
    printf("证书: %s\n", line != NULL ? line : "(null)");
    OPENSSL_free(line);

    line = X509_NAME_oneline(X509_get_issuer_name(cert), NULL, 0);
    printf("颁发者: %s\n", line != NULL ? line : "(null)");
    OPENSSL_free(line);

    X509_free(cert);
}

/* OpenSSL initialisation; main() calls this before creating the SSL_CTX. */
void init_ssl()
{
    /* Register libssl's TLS ciphers and digests. */
    SSL_library_init();
    /* Register every libcrypto cipher and digest (main() undoes this with EVP_cleanup()). */
    OpenSSL_add_all_algorithms();
    /* Load readable error strings for ERR_print_errors_fp() diagnostics. */
    SSL_load_error_strings();
}

/*
 * Create the client's IPv4 TCP socket and enlarge its send/receive buffers.
 *
 * ip, port - target address, used only in diagnostics (ip is a dotted string).
 *
 * Returns the socket descriptor on success:
 *   -1 if socket() fails,
 *   -2 if SO_SNDBUF cannot be set,
 *   -3 if SO_RCVBUF cannot be set.
 * The socket is closed on the -2/-3 paths, so no descriptor leaks.
 *
 * The previous diagnostics passed inet_ntoa(*(struct in_addr *)&ip), which
 * reinterprets the bytes of the char* pointer itself as an address, and
 * printed a char* with %d. ip is now printed directly as a string.
 */
int init_socket(char* ip, uint16_t port)
{
    const int sndmem = 1024 * 1024;
    const int rcvmem = 1024 * 1024;
    int clientfd = socket(PF_INET, SOCK_STREAM, 0);

    if (clientfd < 0)
    {
        fprintf(stderr, "client fd create failed\n");
        return -1;
    }

    /* Set send/receive buffer sizes. */
    if (0 != setsockopt(clientfd, SOL_SOCKET, SO_SNDBUF, &sndmem, sizeof(sndmem)))
    {
        fprintf(stderr, "setsockopt SO_SNDBUF FAILED. ip: %s,  port: %d\n", ip, port);
        close(clientfd);
        return -2;
    }

    if (0 != setsockopt(clientfd, SOL_SOCKET, SO_RCVBUF, &rcvmem, sizeof(rcvmem)))
    {
        fprintf(stderr, "setsockopt SO_RCVBUF FAILED. ip : %s, port : %d\n", ip, port);
        close(clientfd);
        return -3;
    }

    /* SO_RCVTIMEO / SO_SNDTIMEO could be set here for blocking-I/O timeouts
     * (previously present but compiled out with #if 0). */

    return clientfd;
}


/*
 * Create the client-side SSL_CTX.
 *
 * verify_option:
 *   2     - mutual (two-way) authentication. Load a certificate chain and
 *           private key to present to the server, and verify the server
 *           against CLIENT_CA_CERT.
 *   1     - one-way. Verify the server against CLIENT_CA_CERT; present nothing.
 *   other - do not verify the server certificate (SSL_VERIFY_NONE).
 *
 * Returns a new SSL_CTX owned by the caller (release with SSL_CTX_free), or
 * NULL on failure after printing the OpenSSL error queue to stderr.
 */
SSL_CTX *ssl_ctx_init(int verify_option)
{
#define CLIENT_S_CERT   "third_cert/client.crt"
#define CLIENT_S_KEY    "third_cert/ca.key"
#define CLIENT_CA_CERT  "third_cert/ca.crt"

    SSL_CTX *ctx = SSL_CTX_new(SSLv23_client_method());

    if (ctx == NULL)
    {
        ERR_print_errors_fp(stderr);
        return NULL;
    }

    if (2 == verify_option)     /* two-way authentication */
    {
        /*
         * The first certificate in the PEM file is presented to the server and
         * the rest are sent as its chain. A separate
         * SSL_CTX_use_certificate_file() on the same file is redundant.
         *
         * NOTE(review): CLIENT_S_CERT is defined but never loaded. The client
         * presents CLIENT_CA_CERT with CLIENT_S_KEY (ca.key). Confirm the
         * intended client certificate/key pair before switching files: they
         * must match, or SSL_CTX_check_private_key() fails.
         */
        if (SSL_CTX_use_certificate_chain_file(ctx, CLIENT_CA_CERT) <= 0)
            goto err_proc;

        if (SSL_CTX_use_PrivateKey_file(ctx, CLIENT_S_KEY, SSL_FILETYPE_PEM) <= 0)
            goto err_proc;

        /* Fail early if the key does not belong to the certificate. */
        if (!SSL_CTX_check_private_key(ctx))
            goto err_proc;

        /* Trust anchor used to verify the server certificate. */
        if (!SSL_CTX_load_verify_locations(ctx, CLIENT_CA_CERT, NULL))
            goto err_proc;

        SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
        SSL_CTX_set_verify_depth(ctx, 2);
    }
    else if (1 == verify_option) /* one-way authentication */
    {
        if (!SSL_CTX_load_verify_locations(ctx, CLIENT_CA_CERT, NULL))
            goto err_proc;

        SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, NULL);
        SSL_CTX_set_verify_depth(ctx, 2);
    }
    else                        /* no verification */
    {
        SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL);
    }

    return ctx;

err_proc:
    ERR_print_errors_fp(stderr);
    SSL_CTX_free(ctx);
    return NULL;
}


/*
 * SIGALRM handler installed by main() as a connect() timeout notifier.
 *
 * signo - the delivered signal number (unused). Declaring the int parameter
 *         gives the handler the void(int) type signal APIs expect.
 *
 * Only async-signal-safe functions may be called here, so the message is
 * emitted with write(2) instead of printf(). errno is preserved so the
 * interrupted code still sees its own error value.
 */
void u_alarm_handler(int signo)
{
    static const char msg[] = "connect time out \n";
    int saved_errno = errno;

    (void)signo;
    if (write(STDOUT_FILENO, msg, sizeof(msg) - 1) < 0)
    {
        /* Best effort: nothing useful can be done on failure inside a handler. */
    }
    errno = saved_errno;
}

/*
 * Entry point: client --ip <x.x.x.x> --port <port>
 *
 * Connects to the server (2-second connect timeout) and, when USETLS is
 * defined, performs a TLS handshake with mutual authentication. Then, once per
 * second, it sends "TLS -> <local ip>, <local port> Comming!" and prints the
 * reply.
 *
 * Exit status:
 *   0  - the server closed the TLS session cleanly;
 *   -1 - any setup failure, or the connection failed/closed with an error.
 * All resources are released before returning.
 */
int main(int argc, char** argv)
{
    int ret = 0;
    int exit_code = -1;
    int clientfd = -1;
    int connect_errno = 0;
    struct sockaddr_in server_addr;
    struct sockaddr_in local_addr;
    socklen_t local_len = sizeof(local_addr);
    char *buffer = NULL;
    char ip[32] = {0};
    uint16_t port = 0;
#ifdef USETLS
    SSL *ssl = NULL;
    SSL_CTX *ctx = NULL;
#endif

    /* Input validation */
    if (argc != 5)
        usage(argc, argv);

    if (-1 == check_option(ip, &port, argc, argv))
    {
        fprintf(stderr, "check option failed!\n\n");
        usage(argc, argv);
        exit(-1);
    }

    buffer = malloc(MAX_BUFFER_SIZE);
    if (buffer == NULL)
    {
        fprintf(stderr, "buffer allocation failed!\n");
        return -1;
    }

#ifdef USETLS
    init_ssl();
#endif

    /* Create the socket. */
    clientfd = init_socket(ip, port);
    if (clientfd < 0)
    {
        fprintf(stderr, "create failed, result = %d\n", clientfd);
        clientfd = -1;
        goto out;
    }

    /* Server address. */
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = inet_addr(ip);
    server_addr.sin_port = htons(port);

    /* connect() timeout: arm a 2-second SIGALRM.
     * NOTE(review): this relies on sigset() installing the handler without
     * SA_RESTART, so the blocking connect() fails with EINTR. Confirm on the
     * target libc. */
    sigset(SIGALRM, u_alarm_handler);
    alarm(2);
    ret = connect(clientfd, (struct sockaddr *)&server_addr, sizeof(server_addr));
    connect_errno = errno;
    alarm(0);                   /* cancel the pending alarm on success and on failure */
    sigrelse(SIGALRM);
    if (ret < 0)
    {
        fprintf(stderr, "connect [%s:%d] failed! clientfd:%d errno:%d errmsg:%s\n",
                inet_ntoa(server_addr.sin_addr), ntohs(server_addr.sin_port), clientfd,
                connect_errno, strerror(connect_errno));
        goto out;
    }

#ifdef USETLS
    /* 2 = mutual authentication: present a client certificate and verify the
     * server against the locally loaded CA certificate. */
    ctx = ssl_ctx_init(2);
    if (ctx == NULL)
    {
        fprintf(stderr, "ssl_ctx_init failed!\n");
        goto out;
    }

    ssl = SSL_new(ctx);
    if (ssl == NULL || SSL_set_fd(ssl, clientfd) != 1)
    {
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL_new/SSL_set_fd failed!\n");
        goto out;
    }

    ret = SSL_connect(ssl);
    if (ret != 1)
    {
        /* Query SSL_get_error() before ERR_print_errors_fp() empties the queue it inspects. */
        int err = SSL_get_error(ssl, ret);

        ShowCerts(ssl);
        ERR_print_errors_fp(stderr);
        fprintf(stderr, "SSL connect failed! ret %d errcode:%d\n", ret, err);
        goto out;
    }

    ShowCerts(ssl);
    fprintf(stdout, "Connect with [%s] encryption\n", SSL_get_cipher(ssl));
#endif

    for (;;)
    {
        memset(buffer, 0, MAX_BUFFER_SIZE);
        memset(&local_addr, 0, sizeof(local_addr));
        local_len = sizeof(local_addr);

        /* getsockname() yields this client's own address/port, sent as the message. */
        if (!getsockname(clientfd, (struct sockaddr *)&local_addr, &local_len))
        {
            snprintf(buffer, MAX_BUFFER_SIZE, "TLS -> %s, %d Comming!",
                     inet_ntoa(local_addr.sin_addr), ntohs(local_addr.sin_port));
        }

        /* write */
#ifndef USETLS
        if (-1 == writen(clientfd, (uint8_t *)buffer, (int)strlen(buffer)))
        {
            fprintf(stderr, "write [%s] failed!\n", buffer);
            break;
        }
        fprintf(stdout, "write [%s] success!\n", buffer);
#else
        ret = SSL_write(ssl, buffer, (int)strlen(buffer));
        if (ret <= 0)
        {
            int err = SSL_get_error(ssl, ret);

            fprintf(stderr, "SSL_write has error! ssl get errcode:%d\n", err);
            if (err == SSL_ERROR_WANT_READ)
            {
                fprintf(stderr, "SSL_ERROR_WANT_READ\n");
            }
            else if (err == SSL_ERROR_WANT_WRITE)
            {
                fprintf(stderr, "SSL_ERROR_WANT_WRITE\n");
            }
            else
            {
                ERR_print_errors_fp(stderr);
                break;          /* fatal: the session cannot be used any more */
            }
        }
        else
        {
            fprintf(stdout, "write [%s] success!\n", buffer);
        }
#endif

        /* read: leave room for the terminating NUL printed with %s */
        memset(buffer, 0, MAX_BUFFER_SIZE);
#ifndef USETLS
        ret = (int)read(clientfd, buffer, MAX_BUFFER_SIZE - 1);
        if (ret < 0)
        {
            fprintf(stderr, "read failed! errno:%d, errmsg:%s\n", errno, strerror(errno));
            break;
        }
        if (ret == 0)
        {
            fprintf(stderr, "server closed the connection\n");
            break;
        }
        fprintf(stdout, "read [%s] success!\n", buffer);
#else
        ret = SSL_read(ssl, buffer, MAX_BUFFER_SIZE - 1);
        if (ret <= 0)
        {
            int err = SSL_get_error(ssl, ret);

            fprintf(stderr, "SSL_read has error! ssl get errcode:%d\n", err);
            if (err == SSL_ERROR_ZERO_RETURN)
            {
                fprintf(stdout, "server closed the TLS connection\n");
                exit_code = 0;
                break;
            }
            if (err == SSL_ERROR_WANT_READ)
            {
                fprintf(stderr, "SSL_ERROR_WANT_READ\n");
            }
            else if (err == SSL_ERROR_WANT_WRITE)
            {
                fprintf(stderr, "SSL_ERROR_WANT_WRITE\n");
            }
            else
            {
                ERR_print_errors_fp(stderr);
                break;          /* fatal: EOF/reset or protocol error */
            }
        }
        else
        {
            fprintf(stdout, "read [%s] success!\n", buffer);
        }
#endif

        sleep(1);
    }

out:
#ifdef USETLS
    SSL_free(ssl);              /* no-op for NULL */
    SSL_CTX_free(ctx);          /* no-op for NULL */
#endif
    if (clientfd >= 0)
        close(clientfd);
    free(buffer);

#ifdef USETLS
    /* Counterpart of OpenSSL_add_all_algorithms(). */
    EVP_cleanup();
#endif

    return exit_code;
}

你可能感兴趣的:(TLS双向认证,三级证书链)