参考文献:
pytorch以图搜图作业__-周-_的博客-CSDN博客_以图搜图算法pytorch
pytorch加载自己的图片数据集的两种方法__-周-_的博客-CSDN博客_pytorch读取图片数据集
pytorch对网络层的增,删, 改, 修改预训练模型结构__-周-_的博客-CSDN博客_pytorch修改网络结构
1.网络的修改
1)保留 VGG16 的特征提取网络,去除全连接层和 avgpool 层,并将最后一个卷积层的输出通道数改为 1
# Pretrained VGG16 used purely as a feature extractor.
net = models.vgg16(pretrained=True)
# Remove the fully connected head and the average-pool stage so the model's
# output is the (flattened) feature map produced by `features`.
net.classifier = nn.Identity()
net.avgpool = nn.Identity()
# Replace the last conv layer (features[28]) with a 512 -> 1 conv so the
# embedding has a single channel.
# NOTE(review): this new layer is randomly initialized, not pretrained.
net.features[28] = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)
2.数据集的加载
1)定义一个函数,把图片目录中的文件名逐行写入该目录下的 f.txt
def mak_txt(root, file_name):
    """Write the names of the files in ``root/file_name`` to ``f.txt`` there.

    One file name is written per line. The list file itself and any
    sub-directories are skipped. Names are sorted so that the order, and
    therefore the dataset indices derived from ``f.txt``, is reproducible.

    Args:
        root: parent directory.
        file_name: name of the image folder inside ``root``.
    """
    list_name = 'f.txt'
    path = os.path.join(root, file_name)
    names = sorted(
        name for name in os.listdir(path)
        # Skip the list file from a previous run and anything that is not a
        # regular file (e.g. sub-folders), which could not be opened as images.
        if name != list_name and os.path.isfile(os.path.join(path, name))
    )
    # os.path.join instead of a hard-coded '\\' keeps this portable.
    with open(os.path.join(path, list_name), 'w') as f:
        f.writelines(name + '\n' for name in names)
2)调用 mak_txt 函数,分别为图库目录和查询图片目录生成 f.txt
# Project root; both image folders live directly under it.
path = r'D:\AI\images_retreve'
# Derive the folder paths from `path` so they cannot drift out of sync with it
# (a hand-typed copy of the root is easy to misspell).
image_packages = os.path.join(path, 'image_packages')  # gallery images
inputs_images = os.path.join(path, 'inputs_images')    # query image(s)
mak_txt(path, 'image_packages')
mak_txt(path, 'inputs_images')
3)图片预处理
# 进行图片预处理
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()
])
4)定义mydataset
class MyDataset(Dataset):
    """Image dataset whose file list is read from ``<img_path>/f.txt``.

    Each non-empty line of ``f.txt`` is a file name relative to ``img_path``
    (the format ``mak_txt`` writes). Items keep the file's line order, so a
    dataset index can be mapped back to a line of ``f.txt``.
    """

    def __init__(self, img_path, transform=None):
        """
        Args:
            img_path: directory holding the images and their ``f.txt`` list.
            transform: optional callable applied to each loaded PIL image.
        """
        super().__init__()
        self.img_path = img_path
        # os.path.join instead of a hard-coded '\' keeps this portable.
        self.txt_root = os.path.join(img_path, 'f.txt')
        with open(self.txt_root, 'r') as f:
            # Keep the whole stripped line so file names containing spaces
            # survive, and skip blank lines instead of crashing on them.
            names = [line.strip() for line in f]
        self.img = [os.path.join(self.img_path, name) for name in names if name]
        self.transform = transform

    def __len__(self):
        """Return the number of listed images."""
        return len(self.img)

    def __getitem__(self, item):
        """Load image ``item`` as RGB and apply ``transform`` if one was given."""
        img = Image.open(self.img[item]).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
5)加载数据集
# Wrap both folders as datasets that share the same preprocessing.
dataset_inputs = MyDataset(inputs_images, transform=transform)
dataset_packages = MyDataset(image_packages, transform=transform)
# Queries are embedded one at a time; the gallery in batches of 100.
# shuffle must stay False: gallery indices are later mapped back to f.txt lines.
data_loader_inputs = DataLoader(dataset_inputs, batch_size=1, shuffle=False)
data_loader_packages = DataLoader(dataset_packages, batch_size=100, shuffle=False)
3.计算相似性
1)将图片输入网络,提取特征
# Extract embeddings. Retrieval needs no gradients, so run under no_grad: this
# saves memory and yields plain tensors that later code can modify in place.
# eval() puts the network in inference mode.
net.eval()
with torch.no_grad():
    # The query folder is expected to hold a single image (batch_size=1); if it
    # holds several, the last one loaded is used as the query.
    for data in data_loader_inputs:
        output_inputs = net(data)
    # Gallery: concatenate every batch. Keeping only the last batch would drop
    # all but the final <=100 images and desynchronize indices from f.txt.
    output_packages = torch.cat([net(batch) for batch in data_loader_packages])
print(output_inputs.shape)
print(output_packages.shape)
2)调用 torch.nn.functional(F)中的 pairwise_distance 计算欧氏距离
# L2 (Euclidean) distance from the query embedding to each gallery embedding.
# Assumes a single query row, which broadcasts against the gallery rows and
# yields one distance per gallery image.
dist2 = F.pairwise_distance(output_inputs, output_packages, p=2.0)
print(dist2.shape)
4.输出最相似的三张图片
1)输出最相似的三张图片的索引
# Indices of the (up to) 3 gallery images closest to the query, i.e. with the
# smallest distances, ordered nearest first. `max_list` keeps its name for the
# code below, although it holds the *minimum*-distance indices.
# topk(largest=False) replaces repeated argmin + overwriting with a sentinel:
# dist2 is left intact, and no duplicate indices appear for tiny galleries.
k = min(3, dist2.numel())
max_list = torch.topk(dist2, k, largest=False).indices.tolist()
print(max_list)
2)根据索引找到原图片
# Map the selected gallery indices back to image paths. The gallery loader is
# not shuffled, so index i refers to the i-th image listed in f.txt.
path_dir = os.path.join(image_packages, 'f.txt')
with open(path_dir, 'r') as f:
    # Strip the trailing newline (otherwise it ends up inside the path) and
    # skip blank lines, which do not name an image.
    data = [name for name in (line.strip() for line in f) if name]
data_img = [os.path.join(image_packages, data[idx]) for idx in max_list]
3)创建画布,将图片放在画布上展示出来
# Show the retrieved images stacked vertically, one subplot row per result.
# The row count follows len(data_img), so fewer than three results no longer
# raise IndexError; with three results this is the original 3x1 grid.
n_show = len(data_img)
fig = plt.figure(figsize=(10, 10))
for pos, img_file in enumerate(data_img, start=1):  # subplot positions are 1-based
    ax = fig.add_subplot(n_show, 1, pos)
    img = Image.open(img_file.strip())  # strip guards against a trailing newline
    ax.imshow(img)
plt.show()