python实现grabcut算法进行物体分割

GrabCut算法进行物体分割

  • GrabCut算法原理
    • python实现GrabCut算法

GrabCut算法原理

(1)在图片中定义含有(一个或多个)物体的矩形。
(2)矩形外外的区域被自动认为是背景。
(3)对于用户定义的矩形区域,可用背景中的数据来区别它里面的前景和背景区域。
(4)用高斯混合模型(Gaussians Mixture Model,GMM)来对背景和前景建模,并将未定义的像素标记为可能的前景和背景。
(5)图像中的每一个像素都被看作通过虚拟边与周围像素的链接,而每条边都有一个属于前景或背景的概率,这基于它与周围像素颜色上的相似性。
(6)每一个像素(即算法中的节点)会与一个前景或背景节点链接。
(7)在节点完成链接后(可能与背景链接,也可能与前景链接),若节点之间的边属于不同终端(即一个节点属于前景,另一个节点属于背景),则会切断他们之间的边,这就能将图像各部分分割出来。

python实现GrabCut算法

(批量处理图片)

'''
grabcut算法进行物体分割
'''
import numpy as np
# from PIL import Image
import cv2
import os
from matplotlib import pyplot as plt

# img_prefix='.JPG'
# path="G:/Datasets/goldcion/train"
# def read_imgs(path):
#     files=os.listdir(path)
#     for file in files:
#         index = file.find('.')
#         prefix-file[index+1]
#         if prefix in img_prefix:
#             print(file)
#             return file
fp = open('G:/Datasets/goldcion/images_train.txt', 'r')
# img = Image.open('G:\\Datasets\\goldcion\\2-1-1.JPG')#返回一个Image对象
# print('宽:%d,高:%d'%(img.size[0],img.size[1]))
# img=img.resize((224,224))
for file in fp:
    file=file.strip("\n")
    file=file.split(' ')
    file=file[0]
    pth=file.split('/')
    pth1='G:/Datasets/goldcoin_seg/train/%s/'%pth[4]
    pth0=pth[4]+'/'+pth[5]
    img=cv2.imread(file)
    # img=cv2.imread('G:/Datasets/goldcion/2-1-1.JPG')
    img = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
    mask=np.zeros(img.shape[:2],np.uint8)

    bgdModel=np.zeros((1,65),np.float64)
    fgdModel=np.zeros((1,65),np.float64)

    rect=(5,5,224,224)
    cv2.grabCut(img,mask,rect,bgdModel,fgdModel,5,cv2.GC_INIT_WITH_RECT)

    mask2=np.where((mask==2)|(mask==0),0,1).astype('uint8')
    img=img*mask2[:,:,np.newaxis]
    if not os.path.lexists(pth1):
        os.makedirs(pth1)

    cv2.imwrite('G:/Datasets/goldcoin_seg/train/%s'%pth0,img)#第一个参数是目标存放的名字,第二个参数是目标
# plt.subplot(121),plt.imshow(img)
# plt.title("grabcut"),plt.xticks([]),plt.yticks([])
# plt.subplot(122),plt.imshow(cv2.cvtColor(cv2.imread('G:/Datasets/goldcion/2-1-1.JPG'),cv2.COLOR_BGR2GRAY))
# plt.title("original"),plt.xticks([]),plt.yticks([])
# plt.show()

你可能感兴趣的:(grabcut,python实现opencv,图像分割)