使用pytorch和opencv根据颜色相似性提取图像

需求:将下图中的花朵提取出来。

使用pytorch和opencv根据颜色相似性提取图像_第1张图片

代码:

import cv2
import torch
import numpy as np
import time

def get_similar_colors(image, color_list, threshold):
    # 将图像和颜色列表转换为torch张量
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
    color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)

    # 计算每个像素与颜色列表中每个颜色的距离
    distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)

    # 找到最小距离及其索引
    min_distances, _ = torch.min(distances, dim=-1)

    # 创建掩码,标记接近目标颜色的像素
    mask = min_distances < threshold

    # 根据掩码提取接近颜色的部分
    result = torch.where(mask.unsqueeze(-1), image_tensor, torch.zeros_like(image_tensor))

    # 将结果转换回numpy数组
    result_np = result.cpu().numpy().astype(np.uint8)

    return result_np
# 读取图像s
image = cv2.imread('flower2.jpg')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(15, 220, 255),(30, 50, 220)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)

# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

 使用pytorch和opencv根据颜色相似性提取图像_第2张图片

进一步,输出掩码部分的黑白图像

import cv2
import torch
import numpy as np
import time

def get_similar_colors(image, color_list, threshold):
    # 将图像和颜色列表转换为torch张量
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    image_tensor = torch.from_numpy(image.astype(np.float32)).to(device)
    color_tensor = torch.tensor(color_list, dtype=torch.float32).to(device)

    # 计算每个像素与颜色列表中每个颜色的距离
    distances = torch.cdist(image_tensor.view(-1, 3), color_tensor, p=2).view(image_tensor.shape[0], image_tensor.shape[1], -1)

    # 找到最小距离及其索引
    min_distances, _ = torch.min(distances, dim=-1)

    # 创建掩码,标记接近目标颜色的像素
    mask = min_distances < threshold

    # 将符合条件的像素设置为黑色
    result = np.ones_like(image_tensor)
    result[mask] = [0, 0, 0]  # 设置为黑色

    return result
# 读取图像s
image = cv2.imread('your/image/path')
# 定义颜色列表,每个颜色用BGR格式表示
color_list = [(50, 15, 0), (45, 10, 0), (30, 10, 0)]
# 定义颜色接近度的阈值
threshold = 100
time_start = time.time()
# 提取接近颜色的部分
extracted_image = get_similar_colors(image, color_list, threshold)
time_end = time.time()
time = time_end - time_start
print("time: ", time)

# 显示原始图像和提取结果
cv2.imshow('Original Image', image)
cv2.imshow('Extracted Image', extracted_image)
cv2.waitKey(0)
cv2.destroyAllWindows()

你可能感兴趣的:(机器视觉和人工智能学习,opencv学习笔记,pytorch,opencv,人工智能)