1.安装labelme,用于标注,对机器学习来说就是选取像素样本,这一步对监督分类来说是没法避免的,你想想深度学习还要画样本呢,只不过这里选取样本比较随意,一点都不麻烦。
pip install labelme
安装好后直接在命令行输入labelme按enter,接着工具就会弹出来了
2.获取样本
放大看:
取任意形状都可以接受,每个类别的范围不包含其它类就行,这里我分了两个类别,请注意在弹出来填写类别的框中务必以0,1,2,3…这样从0开始按顺序给类别,主要是为了方便后面的处理,如果不照做会报错。
样本选取好后,点击那个保存按钮,图像目录下会自动生成一个.json文件。
3.标签转换
参考链接:https://zhuanlan.zhihu.com/p/116023772
import os
import subprocess

# Folder that contains the images and their labelme .json annotation files
# (the directory shown in the screenshot above).
json_folder = r"C:\Users\Administrator\Desktop\data\test"

# NOTE: run this script from a shell in which the labelme environment is
# already activated.  The former os.system("activate labelme") call only
# changed the environment of its own short-lived subshell, so it never
# influenced the conversion commands issued below.
for file_name in os.listdir(json_folder):
    # Only labelme annotation files are converted.
    if os.path.splitext(file_name)[1] != ".json":
        continue
    json_file = os.path.join(json_folder, file_name)
    # An argument list (no shell) keeps paths that contain spaces intact.
    subprocess.run(["labelme_json_to_dataset", json_file])
这一步利用 labelme 自带的转换工具生成所需的数据(包括图像、标签等),后面的分类也是在转换结果的基础上进行的。如果担心转换后的图像有质量损失,可以自己把预测图换成原图,看完代码应该就能明白怎么做;不过我对比过,两者效果差不多,没必要另外去改代码。
4.训练预测代码
这里用到了配置文件的方式读取数据,下面会给出配置文件
# -*- coding: utf-8 -*-
from osgeo import ogr, osr
from osgeo import gdal
from gdalconst import *
import os, sys, time
import copy
from tqdm import tqdm
import numpy as np
import cv2
from PIL import Image
from collections import Counter
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from skimage import morphology, filters
import pickle
# import numba
# from numba import jit
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_bilateral, create_pairwise_gaussian
def getValues(img_path, mask_path):
    """Collect the RGB training pixels that lie under every non-zero label.

    Args:
        img_path: Path of the sample image, read with ``cv2.imread``.
        mask_path: Path of the label image.  It is opened with PIL and
            converted to palette mode; mask value 0 is skipped.

    Returns:
        dict: Maps each non-zero mask value to an ``(n, 3)`` array holding
        the RGB values of the pixels carrying that mask value.  Every pixel
        under the mask is kept, including pure black ones (the previous
        implementation dropped them because it used black as a sentinel).

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``img_path``.

    Side effects:
        Writes a preview image ``<mask value>.png`` into the current working
        directory for every class, as the original code did.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Cannot read image: %s" % img_path)
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    mask = np.array(Image.open(mask_path).convert('P'))

    class_list = {}
    for k in np.unique(mask):
        if k == 0:
            continue
        selected = mask == k
        # Preview of the selected pixels, everything else zeroed.
        masked = np.where(selected[:, :, np.newaxis], rgb, 0).astype(np.uint8)
        cv2.imwrite(str(k) + '.png', masked)
        # Transposing first keeps the column-by-column sample order of the
        # former nested getpixel loops.
        class_list[k] = rgb.transpose(1, 0, 2)[selected.T]
    return class_list
def svm_train(class_list, img_arr, model_path):
    """Train a random-forest pixel classifier unless a model file exists.

    Args:
        class_list: dict mapping a class key to an ``(n, 3)`` sample array.
        img_arr: Not used by this function.
        model_path: Where the pickled model is stored.  When this path
            already exists no training happens and the file is left as is.

    Returns:
        tuple: ``(number of classes, dict mapping each class key to its
        consecutive training label)``; labels follow the sorted key order.
    """
    ordered_keys = sorted(class_list)
    class_final = {key: index for index, key in enumerate(ordered_keys)}

    # The leading empty arrays reproduce the dtype (default int / float64)
    # and the empty-input result of the incremental concatenation.
    feature_parts = [np.zeros((0, 3), dtype=int)]
    label_parts = [np.array([])]
    for key in ordered_keys:
        samples = class_list[key]
        feature_parts.append(samples)
        label_parts.append(class_final[key] * np.ones(samples.shape[0]))
    features = np.concatenate(feature_parts, axis=0)
    labels = np.concatenate(label_parts)

    if not os.path.exists(model_path):
        rf = RandomForestClassifier(n_estimators=500, max_depth=10, n_jobs=14)
        rf.fit(features, labels)
        with open(model_path, 'wb') as model_file:
            pickle.dump(rf, model_file)
    return len(class_list), class_final
def get_model(model_path):
    """Return the object unpickled from the file at ``model_path``."""
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file, so only model files from a trusted source should be loaded.
    with open(model_path, 'rb') as stream:
        return pickle.load(stream)
def svm_predict(svc, img_arr, array_num, outPath):
    """Classify every pixel of ``img_arr`` and return the label map.

    Args:
        svc: Fitted classifier exposing ``predict`` on an
            ``(n_pixels, n_bands)`` array.
        img_arr: Array of shape ``(d0, d1, n_bands)``.  Any band count is
            accepted now (the previous code only handled 3 or 4 bands).
            The array is no longer modified in place.
        array_num: Number of classes.  Kept for interface compatibility;
            the prediction itself already contains the class indices.
        outPath: Unused, kept for interface compatibility.

    Returns:
        numpy.ndarray: Predicted labels with shape ``(d1, d0)`` (the
        transposed pixel grid, as before) and the dtype of ``img_arr``.
    """
    dim0, dim1, band_count = img_arr.shape
    predict = svc.predict(img_arr.reshape(dim0 * dim1, band_count))
    # The removed np.float alias (gone since NumPy 1.24) is not needed:
    # the prediction already holds one class index per pixel.
    label_map = np.asarray(predict).reshape(dim0, dim1).astype(img_arr.dtype)
    return label_map.T
def read_img(filename):
    """Open a raster with GDAL and return its metadata and pixel data.

    Returns:
        tuple: ``(projection, geotransform, width, height, data)`` where
        ``data`` is the result of ``ReadAsArray`` over the full raster size.
    """
    ds = gdal.Open(filename)
    width, height = ds.RasterXSize, ds.RasterYSize
    geotransform = ds.GetGeoTransform()
    projection = ds.GetProjection()
    pixels = ds.ReadAsArray(0, 0, width, height)
    # Drop the only reference to the dataset before returning.
    ds = None
    return projection, geotransform, width, height, pixels
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write ``im_data`` to a GeoTIFF whose band type follows the array dtype.

    Args:
        filename: Output path.
        im_proj: Projection passed to ``SetProjection``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_data: 2-D array ``(height, width)`` or 3-D array
            ``(bands, height, width)``.

    Raises:
        RuntimeError: If the GTiff driver cannot create ``filename``.
    """
    # Signed 16-bit data used to be written as UInt16, and 32-bit integers
    # and float64 were forced into Float32; each now gets its matching type.
    # Any dtype not listed still falls back to Float32 as before.
    gdal_types = {
        'int8': gdal.GDT_Byte,
        'uint8': gdal.GDT_Byte,
        'int16': gdal.GDT_Int16,
        'uint16': gdal.GDT_UInt16,
        'int32': gdal.GDT_Int32,
        'uint32': gdal.GDT_UInt32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError("Cannot create raster: %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Flush and release the dataset so the file is complete on return.
    dataset.FlushCache()
    dataset = None
def write_img_(filename, im_proj, im_geotrans, im_data):
    """Write ``im_data`` to a GeoTIFF whose bands are always of type Byte.

    Args:
        filename: Output path.
        im_proj: Projection passed to ``SetProjection``.
        im_geotrans: Geotransform passed to ``SetGeoTransform``.
        im_data: 2-D array ``(height, width)`` or 3-D array
            ``(bands, height, width)``.

    Raises:
        RuntimeError: If the GTiff driver cannot create ``filename``.
    """
    # The dtype inspection that used to live here was dead code: its result
    # was never used because the raster is always created as GDT_Byte.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, gdal.GDT_Byte)
    if dataset is None:
        raise RuntimeError("Cannot create raster: %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Flush and release the dataset so the file is complete on return.
    dataset.FlushCache()
    dataset = None
def array_change(inlist, outlist):
    """Append each column of ``inlist`` to ``outlist`` and return ``outlist``.

    The column count is taken from the first row of ``inlist``.
    """
    column_count = len(inlist[0])
    outlist.extend(
        [row[col] for row in inlist] for col in range(column_count)
    )
    return outlist
def array_change2(inlist, outlist):
    """Append every item of every element of ``inlist`` to ``outlist``.

    Flattens one nesting level and returns ``outlist`` itself.
    """
    for group in inlist:
        outlist.extend(group)
    return outlist
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` into the range ``[img_min, img_max]``.

    The values at the ``lower_percent`` and ``higher_percent`` percentiles of
    the whole array are mapped to ``img_min`` and ``img_max``; results outside
    that range are clipped.

    Args:
        bands: Array to stretch.
        img_min: Lower bound of the output range.
        img_max: Upper bound of the output range.
        lower_percent: Percentile mapped to ``img_min``.
        higher_percent: Percentile mapped to ``img_max``.

    Returns:
        numpy.ndarray: float32 array with the shape of ``bands``.  When both
        percentiles are equal (e.g. a constant band) the result is filled
        with ``img_min`` instead of the NaN values a division by zero would
        produce.
    """
    low = np.percentile(bands, lower_percent)
    high = np.percentile(bands, higher_percent)
    if high == low:
        # Degenerate range: nothing to stretch, avoid dividing by zero.
        return np.full(np.shape(bands), img_min, dtype=np.float32)
    scaled = img_min + (bands - low) * (img_max - img_min) / (high - low)
    return np.clip(scaled, img_min, img_max).astype(np.float32)
def getTifSize(tif):
    """Return ``(width, height, band count, geotransform, projection)``.

    The values are read from the raster that GDAL opens at path ``tif``.
    """
    ds = gdal.Open(tif)
    return (
        ds.RasterXSize,
        ds.RasterYSize,
        ds.RasterCount,
        ds.GetGeoTransform(),
        ds.GetProjection(),
    )
# @jit(nopython=True)
def _tile_count(total, tile):
    """Return how many tiles of size ``tile`` are needed to cover ``total``."""
    return (total + tile - 1) // tile


def _tile_path(tempPath, realName, row, col):
    """Return the path of the temporary tile at grid position (row, col).

    Row and column are separated by underscores so that, for example,
    (1, 11) and (11, 1) can no longer map to the same file name.
    """
    return os.path.join(tempPath, "%s_%d_%d.tif" % (realName, row, col))


def partDivisionForBoundary(model, array_num, tif1, divisionSize, tempPath):
    """Classify a raster tile by tile and store every tile in ``tempPath``.

    Args:
        model: Classifier handed to ``svm_predict``.
        array_num: Number of classes handed to ``svm_predict``.
        tif1: Path of the raster to classify.  The tile data is transposed
            with three axes, so a multi-band raster is assumed.
        divisionSize: Edge length of a tile in pixels; tiles on the right
            and bottom border are smaller.
        tempPath: Existing directory that receives the tile files.

    Returns:
        int: Always 1.

    Tiles whose file already exists are skipped, so an interrupted run can
    be resumed.  A tile file is only created after its prediction finished,
    which keeps a crash during prediction from leaving an empty tile behind.
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    realName = os.path.split(tif1)[1].split(".")[0]
    source = gdal.Open(tif1)
    driver = gdal.GetDriverByName("GTiff")
    widthNum = _tile_count(width, divisionSize)
    heightNum = _tile_count(height, divisionSize)

    for i in tqdm(range(heightNum), desc='Processing'):
        startY = divisionSize * i
        realPartHeight = min(divisionSize, height - startY)
        for j in range(widthNum):
            startX = divisionSize * j
            realPartWidth = min(divisionSize, width - startX)
            outPath = _tile_path(tempPath, realName, i, j)
            if os.path.exists(outPath):
                continue

            data1 = source.ReadAsArray(startX, startY, realPartWidth, realPartHeight)
            data1 = data1.transpose((2, 1, 0))
            svmData = svm_predict(model, data1, array_num, outPath)

            outTif = driver.Create(outPath, realPartWidth, realPartHeight, 1, gdal.GDT_Float32)
            if outTif is None:
                raise RuntimeError("Cannot create tile: %s" % outPath)
            # Shift the origin so the tile is georeferenced at its own
            # position instead of at the origin of the whole image.
            outTif.SetGeoTransform((
                geoTrans[0] + startX * geoTrans[1] + startY * geoTrans[2],
                geoTrans[1],
                geoTrans[2],
                geoTrans[3] + startX * geoTrans[4] + startY * geoTrans[5],
                geoTrans[4],
                geoTrans[5],
            ))
            outTif.SetProjection(proj)
            outTif.GetRasterBand(1).WriteArray(svmData)
            outTif.FlushCache()
            outTif = None
    return 1


def partStretch(tif1, divisionSize, outStratchPath, tempPath):
    """Mosaic the tiles written by ``partDivisionForBoundary`` into one raster.

    Args:
        tif1: Path of the raster that was classified; it provides the size,
            geotransform and projection of the output.
        divisionSize: Tile edge length used when the tiles were written.
        outStratchPath: Path of the single-band Byte GeoTIFF to create.
        tempPath: Directory that holds the tile files.

    Raises:
        RuntimeError: If the output raster cannot be created.
    """
    width, height, bands, geoTrans, proj = getTifSize(tif1)
    realName = os.path.split(tif1)[1].split(".")[0]
    driver = gdal.GetDriverByName("GTiff")
    outTif = driver.Create(outStratchPath, width, height, 1, gdal.GDT_Byte)
    if outTif is None:
        raise RuntimeError("Cannot create output raster: %s" % outStratchPath)
    outTif.SetGeoTransform(geoTrans)
    outTif.SetProjection(proj)
    outBand = outTif.GetRasterBand(1)

    for i in range(_tile_count(height, divisionSize)):
        startY = divisionSize * i
        realPartHeight = min(divisionSize, height - startY)
        for j in range(_tile_count(width, divisionSize)):
            startX = divisionSize * j
            realPartWidth = min(divisionSize, width - startX)
            divisionImg = gdal.Open(_tile_path(tempPath, realName, i, j))
            data1 = divisionImg.GetRasterBand(1).ReadAsArray(0, 0, realPartWidth, realPartHeight)
            outBand.WriteArray(data1, startX, startY)
            divisionImg = None

    # Flush and release the mosaic so the file is complete on return.
    outTif.FlushCache()
    outTif = None
def DoesDriverHandleExtension(drv, ext):
    """Tell whether ``ext`` occurs in the driver's extension metadata.

    The check is a case-insensitive substring test against the value of the
    ``DMD_EXTENSIONS`` metadata item; a driver without that item yields False.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if exts is None:
        return False
    return ext.lower() in exts.lower()
def GetExtension(filename):
    """Return the extension of ``filename`` without its leading dot.

    An empty string is returned when the name has no extension.
    """
    suffix = os.path.splitext(filename)[1]
    return suffix[1:] if suffix.startswith('.') else suffix
def GetOutputDriversFor(filename):
    """List the short names of vector drivers that could write ``filename``.

    A driver qualifies when it advertises create or create-copy capability
    together with vector capability, and either handles the file extension
    or has a connection prefix that ``filename`` starts with.
    """
    ext = GetExtension(filename)
    matches = []
    for index in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(index)
        can_write = (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or drv.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue
        if ext and DoesDriverHandleExtension(drv, ext):
            matches.append(drv.ShortName)
            continue
        # No extension match: fall back to the connection prefix.
        prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            matches.append(drv.ShortName)
    return matches
def GetOutputDriverFor(filename):
    """Pick the driver short name to use for writing ``filename``.

    Returns the first candidate from ``GetOutputDriversFor`` and prints a
    notice when there are several.  Without any candidate the result is
    'ESRI Shapefile' for a name that has no extension.

    Raises:
        Exception: If no driver matches and the name has an extension.
    """
    candidates = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext if ext else '', candidates[0]))
        return candidates[0]
    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'
def crf(inimage, img_anno, use_2d=False):
    """Refine a label map with a dense CRF (pydensecrf).

    Args:
        inimage: The original image; its first two axes give the pixel grid
            and axis 2 the channels.  The bilateral term is configured for
            three channels.
        img_anno: Predicted label map on the same pixel grid.  Every
            distinct value is treated as a class of its own; there is no
            background/unsure label.
        use_2d: Use ``DenseCRF2D`` with its built-in pairwise terms instead
            of the generic ``DenseCRF`` with explicit pairwise features.
            The default (False) is the variant the author recommended.

    Returns:
        numpy.ndarray: Refined label map shaped like the pixel grid of
        ``inimage``, holding the same label values that occur in
        ``img_anno``.  When ``img_anno`` has at most one distinct value it
        is returned unchanged.
    """
    img = inimage
    # values[labels] reconstructs img_anno; labels are 0..n_labels-1.
    values, labels = np.unique(img_anno, return_inverse=True)
    n_labels = len(values)
    if n_labels <= 1:
        return img_anno
    # Depending on the NumPy version the inverse may keep the input shape.
    labels = labels.ravel()

    # zero_unsure=False: label 0 is a real class, not an "unsure" marker.
    U = unary_from_labels(labels, n_labels, gt_prob=0.7, zero_unsure=False)
    if use_2d:
        d = dcrf.DenseCRF2D(img.shape[1], img.shape[0], n_labels)
        d.setUnaryEnergy(U)
        d.addPairwiseGaussian(sxy=(3, 3), compat=3, kernel=dcrf.DIAG_KERNEL,
                              normalization=dcrf.NORMALIZE_SYMMETRIC)
        # Same conversion as before: to int, then to a C-ordered uint8 copy.
        rgbim = np.copy(np.array(img.astype(int), dtype=np.uint8), order='C')
        d.addPairwiseBilateral(sxy=(80, 80), srgb=(13, 13, 13), rgbim=rgbim,
                               compat=10,
                               kernel=dcrf.CONST_KERNEL,
                               normalization=dcrf.NORMALIZE_SYMMETRIC)
    else:
        d = dcrf.DenseCRF(img.shape[1] * img.shape[0], n_labels)
        d.setUnaryEnergy(U)
        # Colour-independent (position only) pairwise features.
        feats = create_pairwise_gaussian(sdims=(3, 3), shape=img.shape[:2])
        d.addPairwiseEnergy(feats, compat=3,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)
        # Colour-dependent (bilateral) pairwise features.
        feats = create_pairwise_bilateral(sdims=(80, 80), schan=(13, 13, 13),
                                          img=img, chdim=2)
        d.addPairwiseEnergy(feats, compat=10,
                            kernel=dcrf.DIAG_KERNEL,
                            normalization=dcrf.NORMALIZE_SYMMETRIC)

    Q = d.inference(20)
    # Most probable class index per pixel, mapped back to the label values
    # (identical to the index when the labels are exactly 0..n_labels-1).
    MAP = np.argmax(Q, axis=0)
    return values[MAP].reshape(img.shape[:2])
def remove_and_deal(img_array, hole, obj):
    """Fill small holes and drop small objects in a binary mask.

    Args:
        img_array: Array interpreted as a boolean mask (non-zero is True).
        hole: Size threshold handed to ``morphology.remove_small_holes``.
        obj: Size threshold handed to ``morphology.remove_small_objects``.

    Returns:
        numpy.ndarray: The cleaned mask as uint8 with values 0 and 1.
    """
    binary = img_array.astype(bool)
    # The thresholds are passed positionally: remove_small_holes does not
    # accept the ``min_size`` keyword in current scikit-image releases.
    # connectivity=binary.ndim is the documented way to request full
    # connectivity, which is what the former out-of-range value 8 produced.
    binary = morphology.remove_small_holes(binary, hole, connectivity=binary.ndim)
    binary = morphology.remove_small_objects(binary, obj, connectivity=binary.ndim)
    return binary.astype(np.uint8)
def cls_deal(class_path, class_ids=(0, 1, 2, 3), out_dir=None, hole=2000, obj=500):
    """Write one cleaned binary mask raster per class of a class raster.

    For every id in ``class_ids`` the pixels equal to that id become 1 and
    all others 0, the mask is cleaned with ``remove_and_deal`` and written
    to ``<out_dir>/<id>.tif`` with ``write_img``.

    Args:
        class_path: Path of the classified raster.
        class_ids: Class values to extract; defaults to the four classes the
            original code handled.
        out_dir: Output directory.  Defaults to the module-level
            ``temp_path`` that the ``__main__`` section defines.
        hole: Hole-size threshold for ``remove_and_deal``.
        obj: Object-size threshold for ``remove_and_deal``.
    """
    if out_dir is None:
        out_dir = temp_path
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(class_path)
    for class_id in class_ids:
        # A direct comparison replaces the former "mark as 100, zero
        # everything below 10" trick, which let labels >= 10 leak into
        # every mask.
        binary = (im_data == class_id).astype(np.uint8)
        binary = remove_and_deal(binary, hole, obj)
        write_img(os.path.join(out_dir, '%d.tif' % class_id), im_proj, im_geotrans, binary)
if __name__ == '__main__':
    # config_order.txt holds one path per line, in this order: sample image,
    # sample label image, image to classify, result path (currently unused),
    # working directory.  Anything after a '#' on a line is a comment, and
    # blank lines are ignored, so paths may contain spaces.
    config_file = 'config_order.txt'
    dirs = []
    with open(config_file) as config:
        for line in config:
            entry = line.split('#', 1)[0].strip()
            if entry:
                dirs.append(entry.replace('\\', '/'))
    if len(dirs) < 5:
        raise ValueError(
            "%s must list 5 paths, found %d" % (config_file, len(dirs)))
    # These names stay at module level: cls_deal reads temp_path as a global.
    data_image, mask_path, task_image, result_path, temp_path = dirs[:5]

    time1 = time.time()
    print('Start ...')
    class_list = getValues(data_image, mask_path)

    print('Train model ...')
    model_path = os.path.join(temp_path, 'model.pickle')
    num, class_final = svm_train(class_list, data_image, model_path)
    svm = get_model(model_path)

    slice_path = os.path.join(temp_path, 'slice_temp')
    os.makedirs(slice_path, exist_ok=True)

    print('Predict task area ...')
    partDivisionForBoundary(svm, num, task_image, 1000, slice_path)
    raster_path = os.path.join(temp_path, 'class_raster.tif')
    partStretch(task_image, 1000, raster_path, slice_path)

    # Refine the mosaicked class raster with the dense CRF.  Both arrays are
    # transposed before the call and the result is transposed back.
    im_proj, im_geotrans, im_width, im_height, im_data = read_img(raster_path)
    im_proj, im_geotrans, im_width, im_height, im_data2 = read_img(task_image)
    im_data = im_data.transpose((1, 0))
    im_data2 = im_data2.transpose((2, 1, 0))
    crf_deal = crf(im_data2, im_data)
    crf_deal = crf_deal.transpose((1, 0))
    raster_path = os.path.join(temp_path, 'class_raster2.tif')
    write_img_(raster_path, im_proj, im_geotrans, crf_deal)

    time2 = time.time()
    # Elapsed time in hours.
    print((time2 - time1) / 3600)
配置文件名字config_order.txt,内容如下:
C:\Users\Administrator\Desktop\data\test2\fixed_json\img.png #样本图
C:\Users\Administrator\Desktop\data\test2\fixed_json\label.png #样本图上选取的标签,上面已经生成了
C:\Users\Administrator\Desktop\data\test2\fixed_json\img.png #需要预测的图,考虑到可能图很大,所以选取样本的图可以是从大图上裁剪下来的,而预测的图可以是别的,反正模型是根据样本图训练生成的,这个要注意
C:\Users\Administrator\Desktop\data\test2\temp/t.png # 这个路径本来是用来放结果的,现在没有用到,但是要有,代码自己改吧
C:\Users\Administrator\Desktop\data\test2\temp #新建一个中间文件,里面放结果
运行结果:
注:这里用的是随机森林的方法;开头的 import 中还导入了 SVM,需要的话自己替换就行了。