# args: command-line arguments
import argparse
# Command-line options for training / evaluating the model.
# NOTE(review): the --top_folder default is a machine-specific absolute path;
# override it on other hosts. Other roots used previously:
#   "/home/chao/haoyu/dien-nsma/", "/cluster/home/it_stu110/rec/dien-nsma"
parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.001, help='learning rate', type=float)
parser.add_argument('--batch_size', default=2048, help='batch size', type=int)
parser.add_argument('--test_batch_size', default=1, help='test batch size', type=int)
parser.add_argument('--number_sample', default=1000, help='negative sampling number', type=int)
parser.add_argument('--top_folder', default="/home/chao/haoyu/dien-taobao/", help='top folder')
parser.add_argument('--model_type', default="DIN", help='model name')
parser.add_argument('--seed', default=3, help='seed', type=int)
parser.add_argument('--train_rounds', default=4, help='number of training rounds', type=int)
parser.add_argument('--embed_size', default=18, help='embed size', type=int)
parser.add_argument('--test_iter', default=50, help='test iterations', type=int)
parser.add_argument('--save_iter', default=50, help='save iterations', type=int)
# Boolean switches: absent -> False, present -> True.
parser.add_argument('--should_train', action='store_true', help='train model')
parser.add_argument('--should_test', action='store_true', help='eval model')
parser.add_argument('--dataset', default="taobao", help='dataset')
# Parses sys.argv at import/run time.
args = parser.parse_args()
# log file: checkpoint/log paths and run header
import time, os
# Build the checkpoint / best-model / log locations for this run, create the
# checkpoint directories, and append a run header (timestamp followed by every
# parsed argument, one per line) to the training log.
# NOTE(review): DATASET, model_type and seed are not defined in this snippet;
# presumably they mirror args.dataset / args.model_type / args.seed — confirm.
run_dir = os.path.join(
    args.top_folder,
    "save",
    DATASET,
    model_type + "_model" + "_H_" + str(args.embed_size) + "_lr" + str(args.lr),
)
ckpt_name = "ckpt_noshuff_" + model_type + str(seed)
model_path = os.path.join(run_dir, ckpt_name)
best_model_path = os.path.join(run_dir, "best_model", ckpt_name)
log_path = os.path.join(run_dir, "train_log.txt")
# exist_ok avoids the check-then-create race between concurrent runs.
# Creating model_path also creates run_dir, the directory log_path lives in.
os.makedirs(model_path, exist_ok=True)
os.makedirs(best_model_path, exist_ok=True)
# Not closed here — presumably later code keeps appending to it; confirm.
log_file = open(log_path, "a")
log_file.write("\n")
log_file.write("=======================")
log_file.write(str(time.asctime(time.localtime(time.time()))))
log_file.write("\n")
for arg in vars(args):
    print(arg, getattr(args, arg), file=log_file)
log_file.write("\n")
# Make sure the header reaches disk even if the run dies early.
log_file.flush()
# warp sampler: multiprocess training-batch sampler
import numpy as np
from multiprocessing import Process, Queue
def random_neq(l, r, s):
    """Return a random integer from ``[l, r)`` that is not contained in ``s``.

    Draws from numpy's global RNG and rejects any value found in ``s``,
    retrying until one is accepted. Never terminates if every value in
    ``[l, r)`` is a member of ``s``.
    """
    while True:
        candidate = np.random.randint(l, r)
        if candidate not in s:
            return candidate
def sample_function(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
    """Build training batches forever and push each one onto ``result_queue``.

    Each batch is the tuple ``(user_b, seq_b, seq_tb, pos_b, neg_b)``.
    ``user_b`` has shape ``(batch_size,)``; the other four arrays have shape
    ``(batch_size, maxlen)``. Sequences are right-aligned and left-padded
    with zeros.

    Args:
        user_train: mapping ``user id -> 2-D array``. Rows are unpacked as
            ``(item, time)``: column 0 is the item id and column 1 its time.
        usernum: exclusive upper bound for sampled user ids; users are drawn
            from ``[1, usernum)`` and redrawn until one is present in
            ``user_train`` with at least two rows.
        itemnum: exclusive upper bound for negative items, drawn from
            ``[1, itemnum)``.
        batch_size: number of sampled users per batch.
        maxlen: length of every padded sequence.
        result_queue: object with a ``put`` method that receives each batch.
        SEED: seed applied to numpy's global RNG before sampling starts.

    Never returns. ``WarpSampler`` runs it as a ``multiprocessing.Process``
    target.

    NOTE(review): both upper bounds are exclusive. If ids run 1..usernum /
    1..itemnum inclusive, the last user and item are never sampled — confirm
    against how usernum/itemnum are computed by the caller.
    """

    def sample():
        # Rejection-sample a known user with at least two interactions: one
        # to serve as input and one as the next-item target.
        user = np.random.randint(1, usernum)
        while user not in user_train or len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum)
        history = user_train[user]

        seq = np.zeros([maxlen], dtype=np.int32)
        seq_t = np.zeros([maxlen], dtype=np.float32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)

        nxt = history[-1]
        idx = maxlen - 1
        # Negatives must avoid the items this user interacted with, which
        # live in column 0 (column 1 holds the times).
        trainset = set(history[:, 0])
        # Walk the history backwards, filling positions from the right; the
        # target at each position is the item that followed it.
        for (i, t) in reversed(history[:-1]):
            seq[idx] = i
            seq_t[idx] = t
            pos[idx] = nxt[0]
            if nxt[0] != 0:
                neg[idx] = random_neq(1, itemnum, trainset)
            nxt = (i, t)
            idx -= 1
            if idx == -1:
                break
        return user, seq, seq_t, pos, neg

    np.random.seed(SEED)
    while True:
        user_b = np.zeros(batch_size, dtype=np.int32)
        seq_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        pos_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        neg_b = np.zeros((batch_size, maxlen), dtype=np.int32)
        seq_tb = np.zeros((batch_size, maxlen), dtype=np.float32)
        for row in range(batch_size):
            user, seq, seq_t, pos, neg = sample()
            user_b[row] = user
            seq_b[row, :] = seq
            pos_b[row, :] = pos
            neg_b[row, :] = neg
            seq_tb[row, :] = seq_t
        # Blocks when the (bounded) queue is full.
        result_queue.put((user_b, seq_b, seq_tb, pos_b, neg_b))
class WarpSampler(object):
    """Keep a bounded queue of training batches filled by background processes.

    ``n_workers`` daemon processes each run ``sample_function`` with their own
    seed (drawn from numpy's global RNG) and push batches onto
    ``result_queue``; ``next_batch`` pops them in arrival order.
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        # Bounded, so producers block once 10 batches per worker are waiting.
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            worker_seed = np.random.randint(6789)
            worker = Process(
                target=sample_function,
                args=(
                    User,
                    usernum,
                    itemnum,
                    batch_size,
                    maxlen,
                    self.result_queue,
                    worker_seed,
                ),
            )
            self.processors.append(worker)
            # Daemonic children are terminated when the parent process exits.
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Block until a batch is available, then return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate each worker and wait for it to exit, one after another."""
        for worker in self.processors:
            worker.terminate()
            worker.join()
# zip the code: archive the .py sources into the log directory
import scipy.misc as misc
import shutil
import zipfile
# Archive every .py file under top_folder into ./src.zip (paths stored
# relative to top_folder), then copy the archive into log_dir.
# NOTE(review): `opt` and `log_dir` are not defined in this snippet; the
# argparse snippet in this file names its namespace `args` — confirm.
top_folder = opt.top_folder
with zipfile.ZipFile('./src.zip', 'w') as srczip:
    for root, dirnames, filenames in os.walk(top_folder):
        print(dirnames, end="\t")
        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            src_path = os.path.join(root, filename)
            # Read from the real path (independent of the current working
            # directory) and store it relative to top_folder, whether or not
            # top_folder ends with a path separator.
            srczip.write(src_path, arcname=os.path.relpath(src_path, top_folder))
shutil.copy('./src.zip', os.path.join(log_dir, 'src.zip'))