# multiprocessing.Pool python多进程最佳实践

#!/usr/bin/python3
import ember
from keras.models import load_model
import os
import argparse
from preprocess import preprocess
import lightgbm as lgb 
import time 
import multiprocessing
import numpy as np
from features import PEFeatureExtractor
import csv 
import pandas as pd
from inception_preprocess import Preprocess
import shutil
import lief
import logging 
import warnings
import struct
import lief
import sys
import pefile

# Reduce warning noise: silence Python warnings and (presumably) lower
# TensorFlow's native log verbosity via TF_CPP_MIN_LOG_LEVEL.
warnings.filterwarnings('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

# Directory containing this script.
# NOTE(review): not used in the visible code; model_path below is relative to
# the current working directory, not to this directory.
current_path = os.path.split(os.path.realpath(__file__))[0]
# Default sample directory (used as the --test_data default).
test_path = '/data/aidm/samples'
# Directory for the log file and the final result.csv.
result_base = '/data/aidm/logs'
log_path = os.path.join(result_base,'handle.log')
# Log everything from DEBUG upward into handle.log; filemode='w' truncates the
# log on every run. Nothing here creates result_base, so it must already exist.
logging.basicConfig(level=logging.DEBUG,
                format='%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s',
                datefmt='%a, %d %b %Y %H:%M:%S',
                filename=log_path,
                filemode='w')
# Model files are looked up relative to the current working directory.
model_path = './model'
# Directory for the intermediate per-model CSVs. It has no trailing separator,
# so build paths under it with os.path.join, not string concatenation.
csv_path = '/data/aidm/logs'
# MalConv Keras model (reported as "Word2vec-model" in the logs) and its CSV.
mal_model_path = os.path.join(model_path,'malconv.h5')
mal_result_path = os.path.join(csv_path,'mal_result.csv')
# LightGBM (GBDT) model file and its CSV.
ember_model_path = os.path.join(model_path,'model.txt')
ember_result_path = os.path.join(csv_path,'ember_result.csv')
# Inception image model and its CSV.
inception_model_path = os.path.join(model_path,'RGBImage_ft.h5')
inception_result_path = os.path.join(csv_path,'inception.csv')
# Final combined verdict CSV.
final_result_path = os.path.join(result_base,'result.csv')
# Command-line options. parse_args() runs at import time, so merely importing
# this module parses sys.argv (and exits on invalid arguments).
parser = argparse.ArgumentParser(description='Testing for performance on test data!')
#parser.add_argument('command',type=str,default='aidm-test',help='The exec command is aidm-test!Please check the cmmand!')
parser.add_argument('--test_data',type=str,default=test_path,help='The path of the test data!')
parser.add_argument('--max_len', type=int, default=500000,help='You should set it as the model can accept!')
parser.add_argument('--models',type=str,default='all',choices=['all','gbdt','word2vec','inception'],help="You can select a model you want to test!")
parser.add_argument('--policy',type=str,default='max',choices=['max','min','all_b','all_w'],help="You can select a policy you want to test!")
args = parser.parse_args()

# Parsed option values exposed as module-level globals.
test_data_path = args.test_data
max_len= args.max_len
model_type = args.models
policy = args.policy


def predict_sample(lgbm_model, file_data):
    """Score the raw bytes of one PE file with a LightGBM booster.

    ``file_data`` is turned into a feature vector by a fresh
    ``PEFeatureExtractor``, cast to float32 and scored as a one-row batch.
    Returns the single value produced by ``lgbm_model.predict``.
    """
    vector = PEFeatureExtractor().feature_vector(file_data)
    single_row_batch = [np.array(vector, dtype=np.float32)]
    scores = lgbm_model.predict(single_row_batch)
    return scores[0]

def get_test_data_list(test_data_path):
    """Return the path of every non-directory entry found under *test_data_path*.

    The tree is walked recursively with ``os.walk``. Paths are built with
    ``os.path.join``, so a trailing separator on *test_data_path* does not
    produce a doubled ``//`` in the results. A missing directory yields ``[]``.
    """
    return [
        os.path.join(root, name)
        for root, _subdirs, names in os.walk(test_data_path)
        for name in names
    ]

#run_type为1表示只跑单个模型,写入csv的内容不同
def malconv_test(mal_model_path,test_data_list,max_len,mal_result_path,process_num,run_type=0):
    """Score every file in *test_data_list* with the MalConv ("word2vec") Keras model.

    One CSV row is appended to *mal_result_path* per file:
      * run_type == 0: ``[path, score, seconds]`` where score is
        ``model.predict(...)[0][0]``;
      * run_type == 1 (single-model run): ``[path, label, seconds]`` where
        label is 2 when the score is >= 0.5 and 1 otherwise.

    Args:
        mal_model_path: Keras model file passed to ``load_model``.
        test_data_list: file paths to score, in order.
        max_len: forwarded to ``preprocess`` (presumably the input length the
            model accepts — confirm against ``preprocess``).
        mal_result_path: CSV file rows are appended to.
        process_num: worker index, used only in log/print output.
        run_type: 0 for raw scores, 1 for labels.

    Exceptions from preprocessing or prediction propagate and stop the worker;
    the ``with`` block still closes (and flushes) the result file.
    """
    start = time.time()
    process_str = str(process_num)
    logging.info("Word2vec-model start predicting ...["+process_str+"]")
    # newline='' lets the csv module control line endings itself.
    with open(mal_result_path, "a", newline="") as result_file:
        model = load_model(mal_model_path)
        mid = time.time()
        csvwriter = csv.writer(result_file, dialect='excel')
        count = 0
        for i, sample_path in enumerate(test_data_list):
            one_start = time.time()
            count += 1
            xx = preprocess([sample_path], max_len)[0]
            predict_arr = model.predict(xx)
            one_end = time.time()
            one_time_cost = round(one_end-one_start,3)
            predict_value = predict_arr[0][0]
            if run_type == 1:
                verdict = 2 if predict_value >= 0.5 else 1
            else:
                verdict = predict_value
            print('malconv[',process_str,']:',sample_path,'--',one_start,"cost time:",one_time_cost,"number:",i)
            csvwriter.writerow([sample_path, verdict, one_time_cost])
    logging.info("Word2vec-model finished! ["+process_str+"]")
    end = time.time()
    load_cost = round(mid-start,3)
    each_cost = round(count/(end-start),1)
    # Build one string: passing a comma-separated tuple to logging.info would
    # log the tuple's repr instead of a readable message.
    final_info = ("Total number of PE is:" + str(count)
                  + ";Loading model cost " + str(load_cost) + "s"
                  + " speed:" + str(each_cost)
                  + "file/s(word2vec[" + process_str + "])")
    logging.info(final_info)
    


def ember_test(ember_model_path,test_data_list,ember_result_path,process_num,run_type=0):
    """Score every file in *test_data_list* with the LightGBM (GBDT) model.

    For each file that is read and scored without error, one CSV row is
    appended to *ember_result_path*:
      * run_type == 0: ``[path, score, seconds]``;
      * run_type == 1 (single-model run): ``[path, label, seconds]`` where
        label is 2 when the score is >= 0.4 and 1 otherwise.
    Files that raise are skipped (no row), their error is logged, and the
    list of failed paths is logged once at the end.

    Args:
        ember_model_path: LightGBM model file for ``lgb.Booster``.
        test_data_list: file paths to score, in order.
        ember_result_path: CSV file rows are appended to.
        process_num: worker index, used only in log/print output.
        run_type: 0 for raw scores, 1 for labels.
    """
    process_str = str(process_num)
    logging.info("GBDT start predicting....["+process_str+"]")
    # newline='' lets the csv module control line endings itself.
    with open(ember_result_path, "a", newline="") as result_file:
        lgbm_model = lgb.Booster(model_file=ember_model_path)
        csvwriter = csv.writer(result_file, dialect='excel')
        start = time.time()
        # total: samples attempted; success: samples scored and written.
        total = 0
        success = 0
        unsuccess_list = []
        for i, sample_path in enumerate(test_data_list):
            total += 1
            try:
                one_start = time.time()
                # Close each sample file right away instead of leaking the handle.
                with open(sample_path, "rb") as sample_file:
                    putty_data = sample_file.read()
                predict_value = predict_sample(lgbm_model, putty_data)
                one_end = time.time()
                one_time_cost = round(one_end-one_start,3)
                if run_type == 1:
                    verdict = 2 if predict_value >= 0.4 else 1
                else:
                    verdict = predict_value
                csvwriter.writerow([sample_path, verdict, one_time_cost])
                print('ember[',process_str,']:',sample_path,'--',one_start," cost time:",one_time_cost,"number:",i)
                success += 1
            except Exception as e:
                # Best-effort per sample: record the failure and keep going.
                unsuccess_list.append(sample_path)
                logging.info(str(e))
        end = time.time()
    time_cost = round(end-start,3)
    if success == total:
        logging.info("GBDT["+process_str+"]:All samples have been predicted successful!(gbdt)")
    else:
        unsuccess = len(unsuccess_list)
        logging.info(str(unsuccess)+" of the files has failed in processing!!(gbdt)["+process_str+"]")
        # logging.warn is a deprecated alias of logging.warning.
        logging.warning("The paths are:"+str(unsuccess_list)+" (gbdt)["+process_str+"]")

    if success > 0:
        each_cost = round(success/(end-start),1)
        # One string, not a tuple, so the log line is readable.
        final_info = ("model:GBDT total:" + str(total)
                      + " success:" + str(success)
                      + " spend time:" + str(time_cost)
                      + " speed:" + str(each_cost)
                      + " file/s(gbdt)[" + process_str + "]")
        logging.info(final_info)
    logging.info("GBDT finished!(gbdt)["+process_str+"]")



def inception_test(model_path,test_data_list,result_path,process_num,run_type=0):
    """Score every file in *test_data_list* with the Inception image model.

    Each file is converted by ``Preprocess`` (with resize 224) and scored with
    ``model.predict``. One CSV row is appended to *result_path* per file:
      * run_type == 0: ``[path, score, seconds]`` where score is
        ``model.predict(...)[0][0]``;
      * run_type == 1 (single-model run): ``[path, label, seconds]`` where
        label is 2 when the score is >= 0.5 and 1 otherwise.

    Args:
        model_path: Keras model file passed to ``load_model``.
        test_data_list: file paths to score, in order.
        result_path: CSV file rows are appended to.
        process_num: worker index, used only in log/print output.
        run_type: 0 for raw scores, 1 for labels.

    Exceptions propagate and stop the worker; the ``with`` block still closes
    (and flushes) the result file.
    """
    resize = 224
    start = time.time()
    process_str = str(process_num)
    logging.info("Inception-model start predicting ...["+process_str+"]")
    # newline='' lets the csv module control line endings itself.
    with open(result_path, "a", newline="") as result_file:
        model = load_model(model_path)
        mid = time.time()
        csvwriter = csv.writer(result_file, dialect='excel')
        count = 0
        for i, sample_path in enumerate(test_data_list):
            one_start = time.time()
            count += 1
            xx = Preprocess([sample_path], resize)
            predict_arr = model.predict(xx)
            predict_value = predict_arr[0][0]
            one_end = time.time()
            one_time_cost = round(one_end-one_start,3)
            if run_type == 1:
                verdict = 2 if predict_value >= 0.5 else 1
            else:
                verdict = predict_value
            print('inception[',process_str,']:', sample_path,'--',one_start,"cost_time:",one_time_cost,"number:",i)
            csvwriter.writerow([sample_path, verdict, one_time_cost])
    logging.info("Inception-model finished!["+process_str+"]")
    end = time.time()
    load_cost = round(mid-start,3)
    each_cost = round(count/(end-start),1)
    # One string, not a tuple, so the log line is readable.
    final_info = ("Total number of PE is:" + str(count)
                  + ";Loading model cost " + str(load_cost) + "s"
                  + " speed:" + str(each_cost)
                  + "file/s(inception)[" + process_str + "]")
    logging.info(final_info)
    


def read_result_csv(result_path):
    """Load a header-less result CSV and return its first two columns.

    Returns:
        ``(data_dir, predict_value)``: arrays holding column 0 (file paths)
        and column 1 (scores or labels). An empty file — e.g. a part file
        from a worker whose every sample failed — yields two empty arrays
        instead of raising ``pandas.errors.EmptyDataError``.
    """
    # Open with the same default text encoding the writers use, and make
    # sure the handle is closed afterwards.
    with open(result_path, "r") as result_file:
        try:
            result = pd.read_csv(result_file, header=None)
        except pd.errors.EmptyDataError:
            return np.array([], dtype=object), np.array([])
    return result[0].values, result[1].values


def is_pe(data_path):
    """Return True if ``pefile`` can open and parse *data_path*, else False.

    ``fast_load=True`` is passed to ``pefile.PE`` (per pefile's docs this
    skips the full data-directory parse). Any ``Exception`` while opening or
    parsing — malformed PE, unreadable file, ... — means "not a PE".
    ``except Exception`` replaces a bare ``except`` so KeyboardInterrupt and
    SystemExit are no longer swallowed inside pool workers.
    """
    try:
        pe = pefile.PE(data_path, fast_load=True)
    except Exception:
        return False
    # Release pefile's hold on the file now; this runs once per sample.
    pe.close()
    return True
  
#for all model test:write the final result to the csv file
def write_csv_result(ember_result_path,mal_result_path,inception_result_path,final_result_path,time_cost,policy):
    """Combine the three per-model score CSVs into the final verdict CSV.

    For every path listed in the MalConv results, one row
    ``[path, label, time_cost]`` is written to *final_result_path*
    (overwritten). Labels use the same coding as the single-model runs: 2 for
    a high score, 1 for a low one.
      * If GBDT has a score for the file (and it is not 3), label is 2 when
        ``0.6*gbdt + 0.2*malconv + 0.2*inception >= 0.5``.
      * Otherwise label is 2 when the MalConv or the Inception score is > 0.5.

    *policy* is accepted for interface compatibility but is not used.
    Raises KeyError if a file needed for a verdict is missing from the
    Inception results.
    """
    ember_data_dir, ember_predict = read_result_csv(ember_result_path)
    mal_data_dir, mal_predict = read_result_csv(mal_result_path)
    inception_data_dir, inception_predict = read_result_csv(inception_result_path)
    # path -> score for each model (a later duplicate path wins).
    ember_dict = dict(zip(ember_data_dir, ember_predict))
    mal_dict = dict(zip(mal_data_dir, mal_predict))
    inception_dict = dict(zip(inception_data_dir, inception_predict))

    # newline='' lets the csv module control line endings; `with` closes the
    # file even if a lookup below raises.
    with open(final_result_path, "w", newline="") as final_result_csv:
        csv_writer = csv.writer(final_result_csv, dialect='excel')
        for file_name in mal_data_dir:
            if file_name in ember_dict and ember_dict[file_name] != 3:
                final_result = 0.6*ember_dict[file_name]+0.2*mal_dict[file_name]+0.2*inception_dict[file_name]
                label = 2 if final_result >= 0.5 else 1
            else:
                # No GBDT score: fall back to the other two models.
                # NOTE(review): the original comment said "if either says
                # benign, answer benign", but this `or` yields 2 when EITHER
                # score is high (and uses > 0.5, not >= 0.5). Confirm intent.
                if mal_dict[file_name] > 0.5 or inception_dict[file_name] > 0.5:
                    label = 2
                else:
                    label = 1
            csv_writer.writerow([file_name, label, time_cost])

def split_pe_file(file_path):
    """Partition the paths in *file_path* into PE and non-PE files.

    Prints one progress line per path (path, current timestamp, 1-based
    position) and returns ``(pe_list, not_pe_list)``. Both lists preserve
    the input order.
    """
    pe_files = []
    other_files = []
    for position, candidate in enumerate(file_path, start=1):
        print(candidate,"-----",time.time(),"-----",position)
        bucket = pe_files if is_pe(candidate) else other_files
        bucket.append(candidate)
    return pe_files, other_files

def get_result_together(result_path_list,result_path):
    """Merge per-process result CSVs into *result_path* and delete the parts.

    *result_path* is overwritten. Parts are read in list order with
    ``read_result_csv``, so only their first two columns (path, score) are
    copied. Each part file is removed once its rows are written.
    """
    # `with` guarantees the merged file is flushed and closed before the
    # caller reads it back.
    with open(result_path, "w", newline="") as final_result_csv:
        csv_writer = csv.writer(final_result_csv, dialect='excel')
        for path in result_path_list:
            data_dir, predict = read_result_csv(path)
            csv_writer.writerows(zip(data_dir, predict))
            if os.path.exists(path):
                os.remove(path)

def _split_evenly(items, parts):
    """Split *items* into *parts* contiguous chunks whose sizes differ by at most one.

    Concatenating the chunks gives back *items* in order. If *parts* exceeds
    ``len(items)`` the trailing chunks are empty, so callers that need
    non-empty chunks must cap *parts* first. *parts* below 1 is treated as 1.
    """
    parts = max(1, int(parts))
    size, extra = divmod(len(items), parts)
    chunks = []
    begin = 0
    for index in range(parts):
        end = begin + size + (1 if index < extra else 0)
        chunks.append(items[begin:end])
        begin = end
    return chunks


def _remove_if_exists(path):
    """Delete *path* if it exists (used to clear result files from earlier runs)."""
    if os.path.exists(path):
        os.remove(path)


def _run_processes(process_list):
    """Start every process in *process_list*, then wait for all of them to finish."""
    for process in process_list:
        process.start()
    for process in process_list:
        process.join()


def _concat_part_files(part_paths, dest_path):
    """Append the per-process part files to *dest_path* in order, then delete them.

    Each worker writes its own part file, so several processes never append
    to one shared file (buffered appends could interleave partial rows).
    Missing parts (e.g. a worker that died early) are logged and skipped.
    """
    with open(dest_path, "ab") as dest:
        for part_path in part_paths:
            if not os.path.exists(part_path):
                logging.warning("Result part file is missing: " + part_path)
                continue
            with open(part_path, "rb") as part:
                shutil.copyfileobj(part, dest)
            os.remove(part_path)


def _append_not_pe_rows(not_pe_list, dest_path):
    """Append one ``[path, 3, 0]`` row per non-PE file to *dest_path*."""
    with open(dest_path, "a", newline="") as dest:
        csv_writer = csv.writer(dest, dialect='excel')
        for item in not_pe_list:
            csv_writer.writerow([item, 3, 0])


if __name__ == '__main__':
    logging.info("------------------------------------The test start ---------------------------------------")
    # Remove result files left over from a previous run.
    for stale_path in (final_result_path, mal_result_path, ember_result_path, inception_result_path):
        _remove_if_exists(stale_path)

    # Collect every file under --test_data (the hard-coded default used to be
    # passed here, so the command-line option had no effect).
    test_data_list = get_test_data_list(test_data_path)
    logging.info("Finish read data path,test_data length:"+str(len(test_data_list)))

    # Number of CPU cores.
    cpu_num = multiprocessing.cpu_count()

    # split_d: number of processes that decide whether each file is a PE
    # file — one per core, never more than there are files, at least one.
    pe_start_time = time.time()
    split_d = max(1, min(cpu_num, len(test_data_list)))
    data_list_list = _split_evenly(test_data_list, split_d)
    pool = multiprocessing.Pool(processes=split_d)
    try:
        res = pool.map(split_pe_file, data_list_list)
    finally:
        pool.close()
        pool.join()
    pe_list = []
    not_pe_list = []
    for pe_part, not_pe_part in res:
        pe_list.extend(pe_part)
        not_pe_list.extend(not_pe_part)
    pe_end_time = time.time()
    pe_cost_time = round((pe_end_time-pe_start_time),3)
    print("Finsh split pe time,spent ",pe_cost_time,"s")

    if len(not_pe_list)>0:
        len_info = "The not_pe_list total length:"+str(len(not_pe_list))+"The list is:"+str(not_pe_list)
        logging.info(len_info)
    if len(pe_list)==0:
        logging.info("There is no pe file in test path!")
        # Still report the unsupported files, as the normal path does at the end.
        if not_pe_list:
            _append_not_pe_rows(not_pe_list, final_result_path)
        os._exit(0)

    start = time.time()
    # 'all' runs three model processes per chunk, so each model gets a third
    # of the cores. Clamp to [1, len(pe_list)]: fewer than 3 cores used to
    # give 0 chunks (ZeroDivisionError), and extra chunks would be empty.
    if model_type == 'all':
        split_n = cpu_num // 3
    else:
        split_n = cpu_num
    split_n = max(1, min(split_n, len(pe_list)))
    pe_list_list = _split_evenly(pe_list, split_n)

    # Each worker writes its own part file. os.path.join adds the separator
    # that plain string concatenation with csv_path was missing.
    mal_result_path_list = [os.path.join(csv_path, 'mal_result_'+str(i)+".csv") for i in range(split_n)]
    ember_result_path_list = [os.path.join(csv_path, 'ember_result_'+str(i)+".csv") for i in range(split_n)]
    inception_result_path_list = [os.path.join(csv_path, 'inception_result_'+str(i)+".csv") for i in range(split_n)]
    # Workers append, so stale parts from a crashed run must not survive.
    for stale_path in mal_result_path_list + ember_result_path_list + inception_result_path_list:
        _remove_if_exists(stale_path)

    if model_type == 'all':
        start_predict_time = time.time()
        logging.info("The model type is all.")
        process_list = []
        for i in range(split_n):
            process_list.append(multiprocessing.Process(target=malconv_test,args=(mal_model_path,pe_list_list[i],max_len,mal_result_path_list[i],i)))
            process_list.append(multiprocessing.Process(target=ember_test,args=(ember_model_path,pe_list_list[i],ember_result_path_list[i],i)))
            process_list.append(multiprocessing.Process(target=inception_test,args=(inception_model_path,pe_list_list[i],inception_result_path_list[i],i)))
        _run_processes(process_list)
        end_predict_time = time.time()
        result_start = time.time()
        time_cost = round((end_predict_time-start_predict_time)/len(pe_list),3)
        get_result_together(mal_result_path_list,mal_result_path)
        get_result_together(ember_result_path_list,ember_result_path)
        get_result_together(inception_result_path_list,inception_result_path)
        write_csv_result(ember_result_path,mal_result_path,inception_result_path,final_result_path,time_cost,policy)
        result_end = time.time()
        result_cost_time = round(result_end-result_start,3)
        predict_total_time = round(end_predict_time-start_predict_time,3)
        predict_info = "[All model]Total predict calculate time is:" + str(predict_total_time)+"s"
        parse_info = "[All model]parse and wirte result cost time is:"+str(result_cost_time)+"s"
        logging.info(predict_info)
        logging.info(parse_info)
    elif model_type == 'gbdt':
        start_predict_time = time.time()
        logging.info("The model type is gbdt.")
        _run_processes([
            multiprocessing.Process(target=ember_test,args=(ember_model_path,pe_list_list[i],ember_result_path_list[i],i,1))
            for i in range(split_n)
        ])
        end_predict_time = time.time()
        _concat_part_files(ember_result_path_list, final_result_path)
        predict_total_time = round(end_predict_time-start_predict_time,3)
        predict_info = "[GBDT model]Total predict calculate time is:" + str(predict_total_time)+"s"
        logging.info(predict_info)
    elif model_type == 'word2vec':
        start_predict_time = time.time()
        logging.info("The model type is word2vec.")
        _run_processes([
            multiprocessing.Process(target=malconv_test,args=(mal_model_path,pe_list_list[i],max_len,mal_result_path_list[i],i,1))
            for i in range(split_n)
        ])
        end_predict_time = time.time()
        _concat_part_files(mal_result_path_list, final_result_path)
        predict_total_time = round(end_predict_time-start_predict_time,3)
        predict_info = "[word2vec model]Total predict calculate time is:" + str(predict_total_time)+"s"
        logging.info(predict_info)
    elif model_type == 'inception':
        start_predict_time = time.time()
        logging.info("The model type is inception.")
        _run_processes([
            multiprocessing.Process(target=inception_test,args=(inception_model_path,pe_list_list[i],inception_result_path_list[i],i,1))
            for i in range(split_n)
        ])
        end_predict_time = time.time()
        _concat_part_files(inception_result_path_list, final_result_path)
        predict_total_time = round(end_predict_time-start_predict_time,3)
        predict_info = "[inception model]Total predict calculate time is:" + str(predict_total_time)+"s"
        logging.info(predict_info)

    end = time.time()
    time_cost = round(end-start,3)
    predict_info = "[Total]Total running time is:"+str(time_cost)+"s"
    logging.info(predict_info)
    # Append the unsupported (non-PE) files to the end of the result CSV.
    if not_pe_list:
        _append_not_pe_rows(not_pe_list, final_result_path)
    logging.info("------------------------------------The test end ---------------------------------------")

 

# 你可能感兴趣的:(multiprocessing.Pool python多进程最佳实践)