// OpenSSL CSP Engine

#ifndef CSPEngineH
#define CSPEngineH

#include <openssl/rsa.h>
#include <openssl/evp.h>
#include <openssl/x509.h>
#include <openssl/x509_vfy.h>
#include <windows.h>
#include <wincrypt.h>
//===========================================================================
#if defined(__cplusplus)
extern "C" {
#endif

// Process-wide certificate / private-key pair; defined (initialised to NULL)
// in the engine implementation file.
extern X509*        x509;
extern EVP_PKEY*    key;

// Builds a new X509_STORE filled from the Windows "ROOT" and "CA" system stores.
X509_STORE* X509_STORE_load_MSCryptoAPI(void);
// Copies every certificate and CRL of an open Windows store into `store`;
// returns 0 if either argument is NULL, otherwise 1.
int         X509_STORE_load_CERT_STORE(X509_STORE* store, HCERTSTORE hCertStore);
// Wraps the certificate's RSA public key in an EVP_PKEY whose private-key
// operations are routed to CryptoAPI via CSP_rsa_method; optionally returns the
// decoded certificate through `cert`. Returns NULL for non-RSA keys.
EVP_PKEY*   EVP_PKEY_new_CERT_CONTEXT(PCCERT_CONTEXT pCertContext, X509** cert);
// DER-decodes a CryptoAPI certificate context into a new OpenSSL X509.
X509*       X509_new_CERT_CONTEXT(PCCERT_CONTEXT pCertContext);
// DER-decodes a CryptoAPI CRL context into a new OpenSSL X509_CRL.
X509_CRL*   X509_CRL_new_CRL_CONTEXT(PCCRL_CONTEXT pCrlContext);

int CSP_rsa_init(
    RSA*                        rsa
);
int CSP_rsa_finish(
    RSA*                        rsa
);
int CSP_rsa_pub_enc(
    int                         flen
    , const unsigned char*      from
    , unsigned char*            to
    , RSA*                      rsa
    , int                       padding
);
int CSP_rsa_pub_dec(
    int                         flen
    , const unsigned char*      from
    , unsigned char*            to
    , RSA*                      rsa
    , int                       padding
);
int CSP_rsa_priv_enc(
    int                         flen
    , const unsigned char*      from
    , unsigned char*            to
    , RSA*                      rsa
    , int                       padding
);
int CSP_rsa_priv_dec(
    int                         flen
    , const unsigned char*      from
    , unsigned char*            to
    , RSA*                      rsa
    , int                       padding
);
int CSP_rsa_sign(
    int                         type
    , const unsigned char*      m
    , unsigned int              m_length
    , unsigned char*            sigret
    , unsigned int*             siglen
    , const RSA*                rsa
);
int CSP_rsa_verify(
    int                         type
    , const unsigned char*      m
    , unsigned int              m_length
    , unsigned char*            sigbuf
    , unsigned int              siglen
    , const RSA*                rsa
);
//---------------------------------------------------------------------------
//! RSA method table that routes private-key work to CryptoAPI.
//! The initializer is POSITIONAL: its order must match the RSA_METHOD struct of
//! the OpenSSL version being built against. Any trailing fields not listed here
//! are zero-initialized.
//! Because this is a `static` definition in a header, every translation unit
//! that includes the header gets its own copy of the table.
static const RSA_METHOD CSP_rsa_method =
{
    "Cryptographic RSA method"  //! name
    , CSP_rsa_pub_enc           //! rsa_pub_enc
    , CSP_rsa_pub_dec           //! rsa_pub_dec
    , CSP_rsa_priv_enc          //! rsa_priv_enc
    , CSP_rsa_priv_dec          //! rsa_priv_dec
    , NULL                      //! rsa_mod_exp
    , BN_mod_exp_mont           //! bn_mod_exp
    , CSP_rsa_init              //! init
    , CSP_rsa_finish            //! finish
    , RSA_FLAG_SIGN_VER         //! flags (EVP_PKEY_new_CERT_CONTEXT also ORs these into rsa->flags)
    , NULL                      //! app_data
    , CSP_rsa_sign              //! rsa_sign
    , CSP_rsa_verify            //! rsa_verify
};

#if defined(__cplusplus)
};//extern "C"
#endif
//===========================================================================
#endif//CSPEngineH

//---------------------------------------------------------------------------
#include "stdafx.h"

#include <vector>
#include <algorithm>
#include <openssl/bio.h>
#include <openssl/objects.h>
#include "CSPEngine.h"
//---------------------------------------------------------------------------
// Definitions of the process-wide pair declared extern in CSPEngine.h.
// Nothing in this file assigns them; presumably the application does — TODO confirm.
EVP_PKEY*                       key  = NULL;
X509*                           x509 = NULL;
//---------------------------------------------------------------------------
X509_STORE* X509_STORE_load_MSCryptoAPI(void)
{
    static X509_STORE*          store = NULL;
    HCERTSTORE                  hCertStore;
    LPCTSTR                     lpStoreNames[] = {
        TEXT("ROOT"),
        TEXT("CA")
    };

/*  if(NULL != store) {
        ++store->references;
        return store;
    }*/
   
    store = X509_STORE_new();

    for(int i=0; i<sizeof(lpStoreNames)/sizeof(lpStoreNames[0]); ++i) {
        hCertStore = ::CertOpenSystemStore(NULL, lpStoreNames[i]);
        if(NULL == hCertStore)  return 0;
       
        X509_STORE_load_CERT_STORE(store, hCertStore);
    }

    return store;
}
//---------------------------------------------------------------------------
int X509_STORE_load_CERT_STORE(X509_STORE* store, HCERTSTORE hCertStore)
{
    if(NULL == store)           return 0;
    if(NULL == hCertStore)      return 0;

    PCCERT_CONTEXT              pCertContext = NULL;
    while(pCertContext = ::CertEnumCertificatesInStore(hCertStore, pCertContext)) {
        X509*                   x509;

        x509 = X509_new_CERT_CONTEXT(pCertContext);
        X509_STORE_add_cert(store, x509);
    }

    PCCRL_CONTEXT               pCrlContext = NULL;
    while(pCrlContext = ::CertEnumCRLsInStore(hCertStore, pCrlContext)) {
        X509_CRL*               crl;

        crl = X509_CRL_new_CRL_CONTEXT(pCrlContext);
        X509_STORE_add_crl(store, crl);
    }

    return 1;
}
//---------------------------------------------------------------------------
EVP_PKEY* EVP_PKEY_new_CERT_CONTEXT(PCCERT_CONTEXT pCertContext, X509** cert)
{
    X509*                       x509;
    EVP_PKEY*                   evp;

    if(NULL == pCertContext)    return NULL;

    x509 = X509_new_CERT_CONTEXT(pCertContext);
    if(NULL == x509)            return NULL;

    evp = X509_get_pubkey(x509);
    if(NULL == evp)             return NULL;

    if(RSA* rsa = EVP_PKEY_get1_RSA(evp)) {
        RSA_set_ex_data(rsa, 0, (void*)pCertContext);
        RSA_set_method(rsa, &CSP_rsa_method);
        rsa->flags |= CSP_rsa_method.flags;
        RSA_free(rsa);
    }   else {
        EVP_PKEY_free(evp);
        evp = NULL;
    }

    if(evp) {
        if(cert) {
            *cert = x509;
        }   else {
            X509_free(x509);
        }
    }
   
    return evp;
}
//---------------------------------------------------------------------------
/**
 * Decodes the DER certificate held by a CryptoAPI context into a new X509.
 *
 * @return new X509 owned by the caller, or NULL if the context is NULL, the
 *         memory BIO cannot be created, or the DER does not decode.
 */
X509* X509_new_CERT_CONTEXT(PCCERT_CONTEXT pCertContext)
{
    if(pCertContext == NULL) {
        return NULL;
    }

    // Read-only memory BIO over the encoded bytes; no copy is made.
    BIO* const                  mem = BIO_new_mem_buf(pCertContext->pbCertEncoded
        , pCertContext->cbCertEncoded);
    if(mem == NULL) {
        return NULL;
    }

    X509* const                 decoded = d2i_X509_bio(mem, NULL);
    BIO_free(mem);
    return decoded;
}
//---------------------------------------------------------------------------
/**
 * Decodes the DER CRL held by a CryptoAPI context into a new X509_CRL.
 *
 * @return new X509_CRL owned by the caller, or NULL if the context is NULL,
 *         the memory BIO cannot be created, or the DER does not decode.
 */
X509_CRL* X509_CRL_new_CRL_CONTEXT(PCCRL_CONTEXT pCrlContext)
{
    if(pCrlContext == NULL) {
        return NULL;
    }

    // Read-only memory BIO over the encoded bytes; no copy is made.
    BIO* const                  mem = BIO_new_mem_buf(pCrlContext->pbCrlEncoded
        , pCrlContext->cbCrlEncoded);
    if(mem == NULL) {
        return NULL;
    }

    X509_CRL* const             decoded = d2i_X509_CRL_bio(mem, NULL);
    BIO_free(mem);
    return decoded;
}
//---------------------------------------------------------------------------
int CSP_rsa_init(RSA* rsa)
{
    BOOL                        ret;
    PCERT_CONTEXT               pCertContext;           // 证书内容
    HCRYPTPROV                  hCryptProv = NULL;      // 密钥位置
    HCRYPTKEY                   hCryptKey  = NULL;      // 私钥句柄
    DWORD                       dwKeySpec;
    BOOL                        fCallerFreeProv = FALSE;

    pCertContext = (PCERT_CONTEXT)RSA_get_ex_data(rsa, 0);
    if(NULL == pCertContext)    return 0;

    hCryptKey    = (HCRYPTKEY    )RSA_get_ex_data(rsa, 2);
    if(NULL != hCryptKey)       return 1;

    // 得到证书相关密钥位置
    ret = ::CryptAcquireCertificatePrivateKey(pCertContext
        , 0, NULL, &hCryptProv, &dwKeySpec, &fCallerFreeProv);
    if(!ret)                    goto err;

    // 获得私钥句柄
    ret = ::CryptGetUserKey(hCryptProv, dwKeySpec, &hCryptKey);
    if(!ret)                    goto err;

    RSA_set_ex_data(rsa, 1, (void*)hCryptProv);
    RSA_set_ex_data(rsa, 2, (void*)hCryptKey);
    RSA_set_ex_data(rsa, 3, (void*)dwKeySpec);
    RSA_set_ex_data(rsa, 4, (void*)fCallerFreeProv);

    return 1;

err:
    if(hCryptKey)               ::CryptDestroyKey(hCryptKey);
    if(hCryptProv && fCallerFreeProv) {
        ::CryptReleaseContext(hCryptProv, 0);
    }

    return 0;
}
//---------------------------------------------------------------------------
int CSP_rsa_finish(RSA* rsa)
{
    HCRYPTPROV                  hCryptProv   = NULL;
    HCRYPTKEY                   hCryptKey    = NULL;
    BOOL                        fCallerFreeProv = FALSE;

    hCryptKey       = (HCRYPTKEY )RSA_get_ex_data(rsa, 2);
    ::CryptDestroyKey(hCryptKey);
    RSA_set_ex_data(rsa, 2, NULL);

    hCryptProv      = (HCRYPTPROV)RSA_get_ex_data(rsa, 1);
    fCallerFreeProv = (BOOL      )RSA_get_ex_data(rsa, 4);
    if(hCryptProv && fCallerFreeProv) {
        ::CryptReleaseContext(hCryptProv, 0);
        RSA_set_ex_data(rsa, 1, NULL);
    }

    return 1;
}
//---------------------------------------------------------------------------
int CSP_rsa_pub_enc(int flen, const unsigned char* from
    , unsigned char* to, RSA* rsa, int padding)
{
    return RSA_PKCS1_SSLeay()->rsa_pub_enc(flen, from, to, rsa, padding);
}
//---------------------------------------------------------------------------
int CSP_rsa_pub_dec(int flen, const unsigned char* from
    , unsigned char* to, RSA* rsa, int padding)
{
    return RSA_PKCS1_SSLeay()->rsa_pub_dec(flen, from, to, rsa, padding);
}
//---------------------------------------------------------------------------
int CSP_rsa_priv_enc(int flen, const unsigned char* from
    , unsigned char* to, RSA* rsa, int padding)
{
    return -1;
}
//---------------------------------------------------------------------------
int CSP_rsa_priv_dec(int flen, const unsigned char* from
    , unsigned char* to, RSA* rsa, int padding)
{
    BOOL                        ret;
    HCRYPTKEY                   hCryptKey;
    DWORD                       cbData = flen;
    std::vector<BYTE>           pbData;

    hCryptKey = (HCRYPTKEY)RSA_get_ex_data(rsa, 2);

    pbData.resize(cbData);
    std::copy(from, from+flen, pbData.rbegin());
    ret = ::CryptDecrypt(hCryptKey, NULL, TRUE, 0, &*pbData.begin(), &cbData);
    if(!ret)                    return -1;
    std::copy(pbData.begin(), pbData.begin()+cbData, to);

    return cbData;
}
//---------------------------------------------------------------------------
ALG_ID nid2algid(int nid)
{
    ALG_ID                      algId;

    switch(nid) {
    case NID_md2:
        algId = CALG_MD2;       break;
    case NID_md4:
        algId = CALG_MD4;       break;
    case NID_md5:
        algId = CALG_MD5;       break;
    case NID_sha:
        algId = CALG_SHA;       break;
    case NID_sha1:
        algId = CALG_SHA1;      break;
    case NID_md5_sha1:
    default:
        algId = CALG_SSL3_SHAMD5;
        break;
    }

    return algId;
}
//---------------------------------------------------------------------------
int CSP_rsa_sign(int type, const unsigned char *m, unsigned int m_length
    , unsigned char *sigret, unsigned int *siglen, const RSA *rsa)
{
    BOOL                        ret = FALSE;
    HCRYPTPROV                  hCryptProv;
    HCRYPTKEY                   hCryptKey;
    DWORD                       dwKeySpec;
    ALG_ID                      algId;
    HCRYPTHASH                  hHash = NULL;
    DWORD                       cbHash, cbHashSize;
    DWORD                       cbData = 0;
    std::vector<BYTE>           pbData;

    hCryptProv = (HCRYPTPROV)RSA_get_ex_data(rsa, 1);
    hCryptKey  = (HCRYPTKEY )RSA_get_ex_data(rsa, 2);
    dwKeySpec  = (DWORD     )RSA_get_ex_data(rsa, 3);
    if(NULL == hCryptKey)       goto err;

    algId = nid2algid(type);
    if(-1 == algId)             goto err;

    ret = ::CryptCreateHash(hCryptProv, algId, 0, 0, &hHash);
    if(!ret)                    goto err;
    ret = ::CryptGetHashParam(hHash, HP_HASHSIZE, (LPBYTE)&cbHashSize, &cbHash, 0);
    if(!ret)                    goto err;
    if(m_length != cbHashSize)  goto err;

    ret = ::CryptSetHashParam(hHash, HP_HASHVAL, m, 0);
    if(!ret)                    goto err;

    ret = ::CryptSignHash(hHash, dwKeySpec, NULL, 0, NULL, &cbData);
    if(!ret)                    goto err;
    *siglen = cbData;

    pbData.resize(cbData);
    ret = ::CryptSignHash(hHash, dwKeySpec, NULL, 0, &*pbData.begin(), &cbData);
    if(!ret)                    goto err;
    std::copy(pbData.rbegin(), pbData.rend(), sigret);

err:
    ::CryptDestroyHash(hHash);

    return ret;
}
//---------------------------------------------------------------------------
int CSP_rsa_verify(int type, const unsigned char *m, unsigned int m_length
    , unsigned char *sigbuf, unsigned int siglen, const RSA *rsa)
{
    BOOL                        ret = FALSE;
    RSA*                        pubrsa;

    pubrsa = RSAPublicKey_dup(const_cast<RSA*>(rsa));
    if(NULL == pubrsa)          goto err;

    ret = RSA_verify(type, m, m_length, sigbuf, siglen, pubrsa);
   
err:
    RSA_free(pubrsa);
    return ret;
}
//---------------------------------------------------------------------------

// You may also be interested in: (OpenSSL CSP Engine)