此部分为原型网络的两个函数,分别为保存原型点函数和加载原型点函数,与之前的系列相承接。
def save_center(self,path):
datas = []
for label in self.center.keys():
datas.append([label] + list(self.center[label].cpu().detach().numpy()))
with open(path,"w", newline="") as datacsv:
csvwriter = csv.writer(datacsv,dialect = ("excel"))
csvwriter.writerows(datas)
def load_center(self,path):
csvReader = csv.reader(open(path))
for line in csvReader:
label = int(line[0])
center = [ float(line[i]) for i in range(1,len(line))]
center = np.array(center)
center = Variable(torch.from_numpy(center))
self.center[label] = center
save_center(self, path):
datas = []
: 初始化一个空列表,用于存储要写入文件的数据。
for label in self.center.keys()
: 遍历 self.center 字典的键(可能代表中心的不同标签或名称)。
datas.append([label] + list(self.center[label].cpu().detach().numpy())):
将每个键以及与该键关联的值添加到 datas 列表中。
这里,.cpu().detach().numpy()
是将 PyTorch 张量转换为 numpy 数组的过程。
with open(path, "w", newline="") as datacsv:
使用 “w” 模式(写模式)打开文件,如果文件已存在,则覆盖它。
csvwriter = csv.writer(datacsv, dialect = ("excel")):
使用 csv 模块创建一个写入器,指定使用的语法为 Excel 语法。
csvwriter.writerows(datas)
: 将 datas 列表中的所有行写入到文件中。
load_center(self, path):
csvReader = csv.reader(open(path)):
使用 csv 模块创建一个读取器,以读取打开的文件。
for line in csvReader:
遍历文件中的每一行。
label = int(line[0]):
从第一列读取一个整数,并将其赋值给变量 label。
center = [ float(line[i]) for i in range(1,len(line))]:
从第二列到最后一列读取一系列浮点数,并将它们放入一个列表中。
center = np.array(center):
将上述列表转换为 numpy 数组。
center = Variable(torch.from_numpy(center)):
将 numpy 数组转换回 PyTorch 张量,并使用 torch.from_numpy() 方法。
self.center[label] = center:
将新加载的中心存储到 self.center 字典中,使用从文件中读取的标签作为键。