原型网络Prototypical Network的python代码逐行解释,新手小白也可学会!!-----系列6 (承接系列5)

文章目录

  • 一、原始代码---随机采样和评估模型
  • 二、详细解释分析每一行代码


一、原始代码—随机采样和评估模型

def randomSample(self,D_set): #从D_set随机取支持集和查询集(20个类中的其中一个类,shape为[20,105,105])
		index_list = list(range(D_set.shape[0]))#20个图片中选5个
		random.shuffle(index_list)
		support_data_index = index_list[:self.Ns]
		query_data_index = index_list[self.Ns:self.Ns + self.Nq]
		support_set = []
		query_set = []
		for i in support_data_index:
			support_set.append(D_set[i])
		for i in query_data_index:
			query_set.append(D_set[i])
		return support_set,query_set
	
	def evaluation_model(self,labels_data,class_number):
		test_accury = []
		center_for_test={}
		class_index = list(range(class_number))#600多类
		random.shuffle(class_index)
		choss_class_index = class_index[:self.Nc]#选20个类
		sample = {'xc':[],'xq':[]}
		for label in choss_class_index:
			D_set = labels_data[label]
			#从D_set随机取支持集和查询集
			support_set,query_set = self.randomSample(D_set)
			#计算中心点
			center_for_test[label] = self.compute_center(support_set)
			#将中心和查询集存储在list中
			sample['xc'].append(center_for_test[label])	#list
			sample['xq'].append(query_set)

二、详细解释分析每一行代码

def randomSample(self,D_set)::定义一个名为randomSample的方法,该方法属于某个类的实例方法。它接受一个名为D_set的参数,这个参数应该是一个三维数组(20个类别,每个类别有105*105个数据)。

index_list = list(range(D_set.shape[0])):生成一个包含所有索引的列表。这里使用range(D_set.shape[0])来生成从0到D_set长度(即类别数)的整数序列。

random.shuffle(index_list):使用random.shuffle函数将index_list中的元素随机打乱顺序,以便在下面的代码中选择随机的索引。

support_data_index = index_list[:self.Ns]:选取前self.Ns个索引作为支持集的索引。

query_data_index = index_list[self.Ns:self.Ns + self.Nq]:选取从第self.Ns个索引到第self.Ns + self.Nq个索引作为查询集的索引。

support_set = [] 和 query_set = []:初始化两个空列表,用于存储从D_set中提取的支持集和查询集。

在接下来的两个循环中,对每个支持集索引和查询集索引,从D_set中提取对应的样本并添加到对应的集合中。

return support_set,query_set:返回支持集和查询集。

def evaluation_model(self,labels_data,class_number)::定义一个名为evaluation_model的方法,该方法属于某个类的实例方法。它接受两个参数:labels_data(包含所有类别数据的数组)和class_number(类别数)。

test_accury = []:初始化一个空列表,用于存储模型的测试准确度。

class_index = list(range(class_number)):生成一个包含所有类别索引的列表。

random.shuffle(class_index):使用random.shuffle函数将class_index中的元素随机打乱顺序,以便在下面的代码中选择随机的类别。

choss_class_index = class_index[:self.Nc]:选取前self.Nc个类别作为选择的类别。

初始化一个字典sample,包含两个键值对:'xc’对应一个空列表,'xq’对应一个空列表。

在接下来的循环中,对于选择的每个类别,执行以下操作:

a. 从该类别的数据中随机选择支持集和查询集(使用之前定义的randomSample方法)。

b. 计算支持集的中心点(使用之前定义的compute_center方法)。

c. 将中心点和查询集的元素添加到字典的对应列表中。

return sample:返回包含中心点和查询集的字典。


你可能感兴趣的:(Python程序代码,python,开发语言)