python常见代码段

args (command-line arguments via argparse)

import argparse

# Command-line configuration for training / evaluating the model.
# Arguments are registered in a fixed order; vars(args) (and therefore any
# dump of the configuration) follows that order.
parser = argparse.ArgumentParser()

parser.add_argument('--lr', default=0.001, help='learning rate', type=float)
parser.add_argument('--batch_size', default=2048, help='batch size', type=int)
parser.add_argument('--test_batch_size', default=1, help='test batch size', type=int)
parser.add_argument('--number_sample', default=1000, help='negative sampling number', type=int)
# Project root directory (the default ends with '/').
parser.add_argument('--top_folder', default="/home/chao/haoyu/dien-taobao/", help='top folder')
parser.add_argument('--model_type', default="DIN", help='model name')
parser.add_argument('--seed', default=3, help='seed', type=int)
parser.add_argument('--train_rounds', default=4, help='number of training rounds', type=int)
parser.add_argument('--embed_size', default=18, help='embed size', type=int)
parser.add_argument('--test_iter', default=50, help='test iterations', type=int)
parser.add_argument('--save_iter', default=50, help='save iterations', type=int)
# Boolean switches: False unless the flag is given on the command line.
parser.add_argument('--should_train', action='store_true', help='train model')
parser.add_argument('--should_test', action='store_true', help='eval model')

parser.add_argument('--dataset', default="taobao", help='dataset')

# Parses sys.argv; exits with a usage message on unknown/invalid arguments.
args = parser.parse_args()

log file (checkpoint directories and run log)

import time, os
# All artifacts of a run live under one directory keyed by dataset, model
# type, embedding size and learning rate.
# NOTE(review): DATASET, model_type and seed are not defined in this snippet;
# they must be set beforehand (presumably from args.dataset / args.model_type /
# args.seed — confirm).
# os.path.join keeps the paths correct whether or not args.top_folder ends
# with a '/' (plain concatenation silently produced ".../foosave/..." without it).
run_dir = os.path.join(
    args.top_folder,
    "save",
    DATASET,
    model_type + "_model" + "_H_" + str(args.embed_size) + "_lr" + str(args.lr),
)
model_path = os.path.join(run_dir, "ckpt_noshuff_" + model_type + str(seed))
best_model_path = os.path.join(run_dir, "best_model", "ckpt_noshuff_" + model_type + str(seed))
log_path = os.path.join(run_dir, "train_log.txt")

# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
# Creating model_path also creates run_dir, the directory holding train_log.txt.
os.makedirs(model_path, exist_ok=True)
os.makedirs(best_model_path, exist_ok=True)

# Opened in append mode and intentionally left open so later code can keep
# writing to log_file.
log_file = open(log_path, "a")
log_file.write("\n")
log_file.write("=======================")
# time.asctime() with no argument formats the current local time.
log_file.write(time.asctime())
log_file.write("\n")
# Record every command-line option of this run, one "name value" per line.
for arg in vars(args):
    print(arg, getattr(args, arg), file=log_file)
log_file.write("\n")
# Make sure the run header reaches disk even if the run crashes early.
log_file.flush()

warp sampler (WarpSampler: multiprocess batch generator with negative sampling)

import numpy as np
from multiprocessing import Process, Queue


def random_neq(l, r, s):
    """Return a random integer from ``[l, r)`` that is not a member of ``s``.

    Draws from NumPy's global RNG and simply redraws whenever the value is in
    ``s`` (rejection sampling). It never terminates if every integer in
    ``[l, r)`` is contained in ``s``.
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate


def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    """Generate training batches forever and push them onto ``result_queue``.

    Intended as the target of a worker process (``WarpSampler`` uses it so).
    Never returns.

    Args:
        user_train: mapping user id -> interaction history, indexed as a 2-D
            array whose rows are ``(item, time)`` pairs; the last row is used
            as the final target, so rows are presumably in chronological
            order (TODO confirm against the data loader).
        usernum: users are drawn uniformly from ``[1, usernum)``.
            NOTE(review): the upper bound is exclusive, so user id
            ``usernum`` itself is never sampled — confirm this matches how
            user ids are numbered.
        itemnum: negative items are drawn from ``[1, itemnum)`` (same
            exclusive-bound caveat as ``usernum``).
        batch_size: number of sequences per batch.
        maxlen: every sequence is left-padded with 0 / truncated to the most
            recent ``maxlen`` positions.
        result_queue: receives ``(user_b, seq_b, seq_tb, pos_b, neg_b)``.
        SEED: seed for NumPy's global RNG in this process.

    Raises:
        ValueError: if no user in ``[1, usernum)`` has at least two
            interactions (sampling would otherwise loop forever).
    """
    def sample():
        # Pick a random user with at least two interactions (inputs + target).
        user = np.random.randint(1, usernum)
        while user not in user_train or len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum)

        seq = np.zeros([maxlen], dtype=np.int32)
        seq_t = np.zeros([maxlen], dtype=np.float32)

        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)

        history = user_train[user]
        nxt = history[-1]
        idx = maxlen - 1

        # Column 0 holds item ids, column 1 timestamps. Negatives must avoid
        # the items this user interacted with, so exclude column 0.
        trainset = set(history[:, 0])
        # Walk the history backwards, filling from the right; the target of
        # each position is the item that follows it.
        for (i, t) in reversed(history[:-1]):
            seq[idx] = i
            seq_t[idx] = t

            pos[idx] = nxt[0]
            # Padding targets (item 0) get no negative sample.
            if nxt[0] != 0:
                neg[idx] = random_neq(1, itemnum, trainset)

            nxt = (i, t)
            idx -= 1
            if idx == -1:
                break

        return user, seq, seq_t, pos, neg

    # Fail fast instead of letting sample() spin forever.
    if not any(u in user_train and len(user_train[u]) > 1 for u in range(1, usernum)):
        raise ValueError("no user in [1, usernum) has at least two interactions")

    np.random.seed(SEED)
    while True:
        user_b = np.zeros(batch_size, dtype=np.int32)
        seq_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        pos_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        neg_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        seq_tb = np.zeros((batch_size, maxlen), dtype=np.float32)

        for i in range(batch_size):
            user, seq, seq_t, pos, neg = sample()
            user_b[i] = user
            seq_b[i, :] = seq
            pos_b[i, :] = pos
            neg_b[i, :] = neg
            seq_tb[i, :] = seq_t
        # Blocks when the queue is full, throttling the worker.
        result_queue.put((user_b, seq_b, seq_tb, pos_b, neg_b))


class WarpSampler:
    """Batch producer backed by a pool of background worker processes.

    Every worker runs ``sample_function`` and pushes finished batches onto a
    shared bounded queue; ``next_batch`` pops one of them.
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        # Bounded queue: a worker's put() blocks once it is full, so workers
        # cannot run arbitrarily far ahead of the consumer.
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            # Each worker gets its own seed drawn from the parent's global RNG.
            worker_seed = np.random.randint(6789)
            worker = Process(
                target=sample_function,
                args=(User, usernum, itemnum, batch_size, maxlen,
                      self.result_queue, worker_seed),
            )
            self.processors.append(worker)
            # Daemonic workers are terminated when the parent process exits.
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Block until a worker has produced a batch, then return it."""
        batch = self.result_queue.get()
        return batch

    def close(self):
        """Terminate every worker process and wait for each one to exit."""
        for worker in self.processors:
            worker.terminate()
            worker.join()

zip the code (archive the project's .py files alongside the logs)

import scipy.misc as misc
import shutil
import zipfile

# NOTE(review): `opt` is not defined in this snippet (the argparse snippet
# names its namespace `args`); `log_dir` must also be defined beforehand.
top_folder = opt.top_folder

# Archive every .py file under top_folder into ./src.zip. Files are read from
# their real path and stored under a path relative to top_folder, so the
# archive no longer depends on the current working directory or on whether
# top_folder ends with '/' (str.replace produced names like ".x.py" then).
with zipfile.ZipFile('./src.zip', 'w') as srczip:
    for root, dirnames, filenames in os.walk(top_folder):
        # Progress output: the subdirectories of each visited directory.
        print(dirnames, end="\t")
        for filename in filenames:
            if filename.endswith('.py'):
                full_path = os.path.join(root, filename)
                srczip.write(full_path, arcname=os.path.relpath(full_path, top_folder))
# The `with` block has closed (and finalized) the archive before copying it.
shutil.copy('./src.zip', os.path.join(log_dir, 'src.zip'))

你可能感兴趣的:(python常见代码段)