图片去重算法(图片数量十万级以下)

一. 场景说明

经常遇到一种情况,手机或者电脑里面的图片太多并且存在重复的图片。这些重复的图片浪费设备的内存,同时也提高了处理这些数据的成本。
博主是学AI的,因此基于神经网络开发了一个图片去重算法。

二. 基本思路

  1. 先用视觉模型提取图片的特征
  2. 轮流对比图片的特征,将相似度很好的图片过滤掉
    代码实现:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import glob
from tqdm import tqdm


class FeatureExtract(object):
    def __init__(self):
        # 加载预训练的ResNet18模型
        self.resnet = models.resnet18(pretrained=True)

        # 移除最后一层全连接层
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])

        # 设置模型为评估模式
        self.resnet.eval()

        self.preprocess = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def feature_extract(self, image_path):
        # 加载和预处理图像
        image = Image.open(image_path)
        input_tensor = self.preprocess(image)
        input_batch = input_tensor.unsqueeze(0)

        # 使用模型提取特征
        with torch.no_grad():
            features = self.resnet(input_batch)

        # 输出特征向量
        return features.squeeze()


def is_duplicate(features, feature, thres=0.99):
    if len(features) == 0:
        return False
    for feat in features:
        similarity = F.cosine_similarity(feat, feature, dim=0).item()
        if similarity > thres:
            return True
    return False


if __name__ == "__main__":
    extract = FeatureExtract()
    features = []
    images = glob.glob("/home/gp/workspace/images/*.jpg")
    num = 0
    for img_path in tqdm(images):
        feature = extract.feature_extract(img_path)
        flag = is_duplicate(features, feature)
        if flag:
            num += 1
            print("copy")
        else:
            features.append(feature)
    print("copy num:%s" % num)

三. 优化思路

上面的代码可能比较慢,特别是图片比较多的时候。

  1. 批处理计算:可以将向量列表分成小批次进行计算,而不是逐个遍历每个向量。这样可以利用矩阵运算的并行性,提高计算效率。可以使用PyTorch的torch.stack()函数将向量列表转换为一个大张量,然后使用矩阵乘法或批量计算余弦相似度;
  2. 利用生产者消费者模式,生产者读取图片并提取特征,放入队列;消费者从队列中取图片,并计算相似度和后处理。
    代码实现:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import glob
from tqdm import tqdm

import threading
import queue
import shutil
import uuid
import concurrent.futures


def is_duplicate_old(features, feature, thres=0.99):
    if len(features) == 0:
        return False
    for feat in features:
        similarity = F.cosine_similarity(feat, feature, dim=0).item()
        if similarity > thres:
            return True
    return False


def is_duplicate(features, feature, thres=0.99):
    if len(features) < 1000:
        return is_duplicate_old(features, feature, thres=0.99)

    num_vectors = len(features)
    batch_size = 1000  # 每批处理的向量数量
    max_similarity = -1
    with concurrent.futures.ThreadPoolExecutor() as executor:
        for i in range(0, num_vectors, batch_size):
            batch_vector = torch.stack(features[i:i+batch_size])
            similarity = F.cosine_similarity(feature.unsqueeze(0), batch_vector, dim=1)
            batch_max_similarity, max_batch_index = torch.max(similarity, dim=0)
            batch_max_similarity = batch_max_similarity.item()
            if batch_max_similarity > thres:
                return True
    return False



class ProducerConsumer:
    def __init__(self):
        self.queue = queue.Queue(maxsize=1000)  # 设置队列长度为1000
        self.extract = FeatureExtract()
        self.features = []

    def produce(self):
        images = glob.glob("imgs/*.jpg")
        for img_path in tqdm(images):
            feature = self.extract.feature_extract(img_path)
            self.queue.put([img_path, feature])
        self.queue.put(None)

    def consume(self):
        while True:
            img_info = self.queue.get()
            if img_info is None:
                break
            img_path, feature = img_info
            flag = is_duplicate(self.features, feature)
            if flag:
                continue
            else:
                self.features.append(feature)
                img_save_path = "images/%s.jpg" % uuid.uuid4()
                shutil.copy(img_path, img_save_path)

    def run(self):
        producer_thread = threading.Thread(target=self.produce)
        consumer_thread = threading.Thread(target=self.consume)

        producer_thread.start()
        consumer_thread.start()

        producer_thread.join()
        consumer_thread.join()



class FeatureExtract(object):
    def __init__(self):
        # 加载预训练的ResNet18模型
        self.resnet = models.resnet18(pretrained=True)

        # 移除最后一层全连接层
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])

        # 设置模型为评估模式
        self.resnet.eval()

        self.preprocess = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def feature_extract(self, image_path):
        # 加载和预处理图像
        image = Image.open(image_path)
        input_tensor = self.preprocess(image)
        input_batch = input_tensor.unsqueeze(0)

        # 使用模型提取特征
        with torch.no_grad():
            features = self.resnet(input_batch)

        # 输出特征向量
        return features.squeeze()



if __name__ == "__main__":
    import time
    start = time.time()
    # 创建对象并运行
    pc = ProducerConsumer()
    pc.run()
    end = time.time()
    print(end-start)

实验测试,速度提升50%,改善效果明显。

你可能感兴趣的:(人工智能,算法,python,深度学习)