这是一个julia的代码
using JSON
#引入 JSON 模块,用于处理 JSON 数据格式。
using Serialization
#引入序列化模块,它允许Julia中的对象可以被序列化(保存到文件)和反序列化(从文件加载)
using DelimitedFiles
#通常用于处理有分隔符的文件,例如 CSV 文件
using Distances
#引入了一个用于计算距离的模块,这个模块包含了多种计算距离的方法
include("utils.jl")
datapath = "/home/ruizhichai/t2vec/data"
param = JSON.parsefile("../hyper-parameters.json")
#从 hyper-parameters.json 文件中读取超参数,并保存到变量 param 中
regionps = param["region"]
cityname = regionps["cityname"]
cellsize = regionps["cellsize"]
region = SpatialRegion(cityname,
regionps["minlon"], regionps["minlat"],
regionps["maxlon"], regionps["maxlat"],
cellsize, cellsize,
regionps["minfreq"], # minfreq
40_000, # maxvocab_size
10, # k
4)
#创建了一个 SpatialRegion 对象
println("Building spatial region with:
cityname=$(region.name),
minlon=$(region.minlon),
minlat=$(region.minlat),
maxlon=$(region.maxlon),
maxlat=$(region.maxlat),
xstep=$(region.xstep),
ystep=$(region.ystep),
minfreq=$(region.minfreq)")
paramfile = "$datapath/$(region.name)-param-cell$(Int(cellsize))"
if isfile(paramfile)
println("Reading parameter file from $paramfile")
region = deserialize(paramfile)
println("Loaded $paramfile into region")
else
println("Cannot find $paramfile")
end
'''
Building spatial region with:
cityname=porto,
minlon=-8.735152,
minlat=40.953673,
maxlon=-8.156309,
maxlat=41.307945,
xstep=100.0,
ystep=100.0,
minfreq=100
#波尔图城区的一些基本信息
Reading parameter file from /home/ruizhichai/t2vec/data/porto-param-cell100
Loaded /home/ruizhichai/t2vec/data/porto-param-cell100 into region
#每一个单元格的最近K个热点单元格id以及其位置
'''
## create querydb
prefix = "exp1"
do_split = true
start = 1_000_000+20_000
num_query = 1000
num_db = 100_000
querydbfile = joinpath(datapath, "$prefix-querydb.h5")
tfile = joinpath(datapath, "$prefix-trj.t")
labelfile = joinpath(datapath, "$prefix-trj.label")
vecfile = joinpath(datapath, "$prefix-trj.h5")
createQueryDB(
"$datapath/$cityname.h5", # 轨迹文件路径
start, # 开始处理的轨迹的索引
num_query, # 查询数据集的大小
num_db, # 数据库的大小
(x, y)->(x, y), # 查询轨迹的噪声函数
(x, y)->(x, y), # 数据库轨迹的噪声函数
do_split=do_split, # 是否分割轨迹的标志
querydbfile=querydbfile # 输出的数据库文件名称
)
"""
createQueryDB("../preprocessing/porto.h5", 1_000_000+20_000, 1000, 100_000,
(x,y)->(x,y), (x,y)->(x,y))
"""
function createQueryDB(trjfile::String, start::Int,
querysize::Int, dbsize::Int,
querynoise::Function, dbnoise::Function;
do_split=true,
querydbfile="querydb.h5",
min_length=30,
max_length=100)
nquery, ndb = 0, 0
#函数首先初始化查询 (nquery) 和数据库 (ndb) 的轨迹计数为0
h5open(trjfile, "r") do f
querydbf = h5open(querydbfile, "w")
#使用 h5open 函数打开输入轨迹文件 trjfile 进行读取
#并在其内部打开或创建一个新的HDF5文件 querydbfile 用于写入
num = read(attributes(f)["num"])
for i = start:num
trip = read(f["/trips/$i"])
timestamp = read(f["/timestamps/$i"])
if nquery < querysize
if 2min_length <= size(trip, 2) <= 2max_length
#如果查询轨迹数量还没有达到 querysize 并且轨迹长度符合限制,则处理轨迹用于查询
if do_split
nquery += 1
ndb += 1
trip1, timestamp1, trip2, timestamp2 = uniformsplit(trip, timestamp)
#如果 do_split 为 true,则使用 uniformsplit 函数将轨迹分为两部分
querydbf["/query/trips/$nquery"], querydbf["/query/timestamps/$nquery"] = querynoise(trip1, timestamp1)
querydbf["/query/names/$nquery"] = i
querydbf["/db/trips/$ndb"], querydbf["/db/timestamps/$ndb"] = dbnoise(trip2, timestamp2)
querydbf["/db/names/$ndb"] = i
else
nquery += 1
querydbf["/query/trips/$nquery"], querydbf["/query/timestamps/$nquery"] = querynoise(trip, timestamp)
querydbf["/query/names/$nquery"] = i
end
'''
每个分割或未分割的轨迹将通过 querynoise 或 dbnoise 函数添加噪声
然后被写入到新的HDF5文件 querydbf 中
如果是分割的轨迹,那么分割的两个一个放在nquery中,另一个放在ndb中
如果不分割,那么全放在nquery中
'''
end
elseif ndb < dbsize
'''
如果数据库轨迹数量还没有达到 dbsize 并且轨迹长度符合限制,则处理轨迹用于数据库
进入这个条件语句,说明ndb已经比nquery多了,也即query的轨迹数量够了
'''
if 2min_length <= size(trip, 2) <= 2max_length
if do_split
ndb += 1
trip1, timestamp1, _, _ = uniformsplit(trip, timestamp)
querydbf["/db/trips/$ndb"], querydbf["/db/timestamps/$ndb"] = dbnoise(trip1, timestamp1)
#querydbf["/db/timestamps/$ndb"] = timestamp
querydbf["/db/names/$ndb"] = i
else
ndb += 1
querydbf["/db/trips/$ndb"], querydbf["/db/timestamps/$ndb"] = dbnoise(trip, timestamp)
querydbf["/db/names/$ndb"] = i
end
end
else
break
end
end
querydbf["/query/num"], querydbf["/db/num"] = nquery, ndb
close(querydbf)
#在循环结束后,将查询和数据库的轨迹数量写入到输出文件。
end
nquery, ndb
end
'''
原始的轨迹数据被均匀地分成两个部分,每个部分包含原始轨迹的一半点,而且这两个部分点是交错的
'''
function uniformsplit(trip::Matrix{Float64}, timestamp::Vector{Float64})
n = size(trip, 2)
# trip 矩阵中轨迹点的数量
idx1, idx2 = 1:2:n, 2:2:n
#idx1 包含从1开始到 n 结束的所有奇数索引,idx2 包含从2开始到 n 结束的所有偶数索引
copy(trip[:, idx1]), copy(timestamp[idx1]), copy(trip[:, idx2]), copy(timestamp[idx2])
end
createTLabel(region, querydbfile; tfile=tfile, labelfile=labelfile)
"""
Creating trj.t trj.label for t2vec()
"""
function createTLabel(region::SpatialRegion, querydbfile::String;
tfile="trj.t", labelfile="trj.label")
seq2str(seq) = join(map(string, seq), " ") * "\n"
querydbf = h5open(querydbfile, "r")
#打开轨迹数据库文件进行读取
label = Int[]
open(tfile, "w") do f
#打开 tfile 文件进行写入操作
num_query, num_db = read(querydbf["/query/num"]), read(querydbf["/db/num"])
#读取查询数据集和数据库的大小
for i = 1:num_query+num_db
location, idx = i <= num_query ? ("query", i) : ("db", i-num_query)
#遍历所有的轨迹,判断它们是查询数据集的还是数据库的,并对每个轨迹执行以下操作:
trip = read(querydbf["/$location/trips/$idx"])
name = read(querydbf["/$location/names/$idx"])
#从文件中读取轨迹 (trip) 和其名称 (name)。
seq = trip2seq(region, trip)
#调用 trip2seq 函数将轨迹转换成序列
#../preprocessing/SpatialRegionTools.jl里面设置的
#每一个点转化成最近的热门单元格的id
write(f, seq2str(seq))
#使用 seq2str 函数将序列转换为字符串,并写入 tfile 文件
push!(label, name)
#将轨迹的名称添加到 label 数组
end
end
writedlm(labelfile, label)
#将轨迹的名称添加到 label 数组
close(querydbf)
length(label)
end
python t2vec.py -vocab_size 20000 -criterion_name "KLDIV" -knearestvocabs "data/porto-vocab-dist-cell100.h5" -max_num_line 500 -batch=256 -mode=2 -prefix='exp1'
Namespace(batch=256, bidirectional=True, bucketsize=[(20, 30), (30, 30), (30, 50), (50, 50), (50, 70), (70, 70), (70, 100), (100, 100)], checkpoint='/home/ruizhichai/t2vec/data/checkpoint.pt', criterion_name='KLDIV', cuda=True, data='/home/ruizhichai/t2vec/data', discriminative_w=0.1, dist_decay_speed=0.8, dropout=0.2, embedding_size=256, epochs=15, generator_batch=32, hidden_size=256, knearestvocabs='data/porto-vocab-dist-cell100.h5', learning_rate=0.001, max_grad_norm=5.0, max_length=200, max_num_line=500, mode=2, num_layers=3, prefix='exp1', pretrained_embedding=None, print_freq=50, save_freq=100, start_iteration=0, t2vec_batch=256, use_discriminative=False, vocab_size=20000)
=> loading checkpoint '/home/ruizhichai/t2vec/data/checkpoint.pt'
0: Encoding 256 trjs...
/home/ruizhichai/t2vec/data_utils.py:84: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
src = list(np.array(src)[idx])
100: Encoding 256 trjs...
200: Encoding 256 trjs...
300: Encoding 256 trjs...
=> saving vectors into /home/ruizhichai/t2vec/data/exp1-trj.h5
## load vectors and labels
vecs = h5open(vecfile, "r") do f
read(f["layer3"])
end
#打开名为 vecfile 的HDF5文件
#读取最后一层hidden state的内容
label = readdlm(labelfile, Int)
#读取轨迹对应的标签
query, db = vecs[:, 1:num_query], vecs[:, num_query+1:end]
#查询向量 query 被设定为 vecs 的前 num_query 列。
#数据库向量 db 被设定为 vecs 的剩余列。
queryLabel, dbLabel = label[1:num_query], label[num_query+1:end]
#对应地,将 label 分为 queryLabel(查询标签)和 dbLabel(数据库标签)
query, db = [query[:, i] for i in 1:size(query, 2)], [db[:, i] for i in 1:size(db, 2)];
#将 query 和 db 各自转换为一个由列向量组成的列表,以便之后进行搜索
# without discriminative loss
dbsizes = [20_000, 40_000, 60_000, 80_000, 100_000]
#定义一个数组 dbsizes,包含了不同的数据库大小设置
for dbsize in dbsizes
ranks = ranksearch(query, queryLabel, db[1:dbsize], dbLabel[1:dbsize], euclidean)
#调用 ranksearch 函数(未在代码段中给出),这个函数执行搜索任务,计算每个查询向量在数据库中的排名
println("mean rank: $(mean(ranks)) with dbsize: $dbsize")
end
"""
计算查询集(query)中每个轨迹(trajectory)在数据库集(db)中的对应轨迹的排名
For each trj in query, computing the rank of its twin trj in db.
This function is applicable to vector representations.
query[:, i] (db[:, i]) represent a trajectory.
"""
function ranksearch(query::Matrix{T},
queryLabel::Vector{Int},
db::Matrix{T},
dbLabel::Vector{Int}) where T
@assert size(query, 2) == length(queryLabel) "unmatched query and label"
@assert size(db, 2) == length(dbLabel) "unmatched db and label"
#确保提供的查询向量的数量与查询标签的数量一致,数据库向量的数量与数据库标签的数量一致
println("Building KDTree...")
kdtree = KDTree(db)
#利用数据库向量 db 建立一个KD树(kdtree),以便快速进行k最近邻搜索
dbsize = length(dbLabel)
function rank(x::Vector{T}, xLabel::Int)
#内部定义了一个 rank 函数,它接受一个向量 x 和其对应的标签 xLabel。
k = 5
@assert k < dbsize "The database size is too small"
while k < dbsize
idxs, _ = knn(kdtree, x, k, true)
r = findfirst(t->t==xLabel, dbLabel[idxs])
r == nothing ? k = 2k : return r
end
dbsize
'''
设置变量 k 的初始值为5,k 表示在KD树中搜索最近邻的数量。
进行一个循环,循环条件是 k 小于数据库大小 dbsize。
使用 knn 函数从KD树中找到 x 的 k 个最近邻。
通过 dbLabel[idxs] 获取这些最近邻对应的标签,并查找 xLabel 是否在其中。
如果找到 xLabel,则返回其在最近邻中的位置 r(这就是排名)。
如果未找到,将 k 的值翻倍,再次搜索,直到找到位置或者 k 不小于 dbsize。
'''
end
ranks = Int[]
for i = 1:length(queryLabel)
push!(ranks, rank(query[:,i], queryLabel[i]))
end
'''
初始化一个空的整数数组 ranks 来存放每个查询向量的排名。
遍历查询标签 queryLabel,对于每个查询向量,调用 rank 函数计算其排名,并将排名添加到 ranks 数组。
'''
ranks
end