ImageEnhance.Sharpness(image)
类ImageEnhance.Sharpness(image)
是 PIL 中的一个类, 用于实现图像的锐化。
该类的构造函数接受一个PIL格式的图像作为参数,创建一个锐化器对象。
它提供了一个enhance()
方法来锐化图像,该方法接受一个参数factor
,表示锐化因子,值域在0~1之间,取值越大图像锐化越高,取值越小图像锐化越低。
如果你想使用这个类来增强图像的锐度,你需要先创建一个ImageEnhance.Sharpness
的对象,然后调用enhance()方法来增强图像锐度。
例如:
image = Image.open("image.jpg")
sharp_image = ImageEnhance.Sharpness(image)
image = sharp_image.enhance(1.5)
image.save("sharp_image.jpg")
这样就可以将原图像增强锐度1.5倍后保存成"sharp_image.jpg"
简单来说,ImageEnhance.Sharpness()
这个类是PIL中用来增强图像锐度的类, 可以用来实现锐度变换。
cv2.imencode()
函数cv2.imencode()
函数接受三个参数,分别是:
例如:
param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
img_encode = cv2.imencode('.jpeg', img, param)
img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)
在上面的例子中,第一个参数 ext 传入的是 '.jpeg'
代表使用JPEG格式进行编码。
第二个参数 img 传入的是 cv2 图像数据。
第三个参数 params 传入的是列表, [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
代表使用JPEG编码格式并且设置质量值。
其它编码格式也可以使用这个函数,只需要在params中添加不同的参数即可,比如使用 PNG 格式编码时,params 可以是 [int(cv2.IMWRITE_PNG_COMPRESSION), 9]
代表设置压缩级别。
random.randint()
函数random.randint()
是 Python 标准库 random 模块中的一个函数,它可以生成一个在给定范围内的随机整数。
语法:random.randint(a, b)
其中 a 是随机整数的最小值,b 是随机整数的最大值,返回一个 a 到 b 之间的随机整数(包括 a 和 b)。
在附录的代码中,random.randint(0, w)
用于随机生成一个宽度范围内的值,random.randint(0, h)
用于随机生成一个高度范围内的值,random.randint(x0, w)
用于随机生成一个宽度范围内的值,random.randint(y0, h)
用于随机生成一个高度范围内的值。
Image.new()
函数Image.new()
是 Python Imaging Library (PIL) 中的一个函数,它可以创建一个新图像。
用法:Image.new(mode, size, color)
,其中 mode
是图像模式('L'
代表灰度图像,'RGB'
代表彩色图像,等等),size 是图像大小(用元组 (width, height) 表示),color 是填充颜色。
例如:
mask=Image.new('L', (w, h), color=255)
在上面的代码中,mask=Image.new('L', (w, h), color=255)
用于创建一个大小为 (w, h)
且灰度图像,初始化颜色为255的图像。
ImageDraw.Draw()
函数ImageDraw.Draw()
是 Python Imaging Library (PIL) 中的一个函数,它可以创建一个画图对象。
用法:ImageDraw.Draw(image)
,其中 image 是一个 PIL 图像对象。
例如:
draw=ImageDraw.Draw(mask)
draw.rectangle(transparent_area, fill=random.randint(150,255))
在上面的代码中,draw=ImageDraw.Draw(mask)
用于创建一个画图对象,它可以对 mask 这个 PIL 图像对象进行绘图操作。
draw.rectangle()
是 PIL 库中的函数,可以在一个图像上画矩形。
用法:draw.rectangle(box, options)
,其中 box 是矩形的坐标(用元组 (x0, y0, x1, y1)
表示),options 是矩形的其它选项,如填充颜色(fill
)等。
在这里,draw.rectangle(transparent_area, fill=random.randint(150,255))
这一行代码用于在mask这个图像上画一个矩形,坐标是transparent_area
,填充颜色是随机生成的一个150~255之间的颜色值。
np.clip()
函数np.clip()
是 numpy 中的一个函数,可以将数组中的元素限制在一个给定的最小值和最大值之间。
用法:numpy.clip(a, a_min, a_max, out=None)
,其中 a 是输入数组,a_min 是最小值,a_max 是最大值。
reflection_result = np.clip(reflection_result, 0, 255)
在这里,reflection_result = np.clip(reflection_result, 0, 255)
这一行代码用于将 reflection_result
数组中的所有元素限制在 0 和 255 之间。这样做的目的是为了确保图像的像素值在合理的范围内,避免出现像素值过大或过小的情况。
self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})
“:” 在这里是用来分隔字典中的键和值的,类似于 “key : value”。在这段代码中,它用来将原有的键和值对应关系映射到新的键和值对应关系中。
具体来说,这个字典推导(dictionary comprehension)会对原有字典中的每一个键值对进行操作,通过将原有的键中的 ‘module.’ 替换为空字符串,来得到新的键,而值则不变,最后组成一个新的字典。
简单来说,就是将原有字典中的键更改为新的键,值不变,构建出新的字典。
这段代码是在加载一个预训练模型的状态字典,特别是PyTorch模型的状态字典,从指定的文件路径(model_path)中加载。使用torch.load()函数从文件中加载状态字典,然后使用字典推导修改结果字典,以删除键的“module.”前缀。这样做是因为在加载在多个GPU上训练的模型时,状态字典的键将有“module.”前缀。因此,这一行代码将状态字典键转换为与模型架构匹配。
items()
是 python 字典 (dictionary) 的一个内置函数,它会返回一个由字典中的键值对组成的元组 (tuple) 列表。 比如如果有个字典
my_dict = {'a': 1, 'b': 2, 'c': 3}
使用 items()
函数之后,会得到
my_dict.items()
#[('a', 1), ('b', 2), ('c', 3)]
在这段代码中,使用 items() 函数获取了字典中的所有键值对,并通过字典推导来更改键值对中的键,构建新的字典。
torch.load(model_path, map_location='cpu')
torch.load()
是 PyTorch 中的一个函数,用于从文件中加载已保存的 PyTorch 模型。传递给该函数的第一个参数是已保存模型的文件路径,在这种情况下是 model_path。
第二个参数 map_location='cpu'
是一个可选参数,用于指定应在其上加载模型的设备。默认情况下,模型在保存时的设备上加载。但是在这种情况下,模型在 CPU 上加载,而不是 GPU。这样可以在没有 GPU 的机器上加载在 GPU 上训练的模型,或在 GPU 和 CPU 之间切换。
通过传递 map_location=‘cpu’ 参数,它将张量转换为 CPU 张量,并将存储映射到 CPU。当你在 GPU 上训练了一个模型并希望在 CPU 上加载它时非常有用。
_, preds = preds.max(2)
这行代码是在选取preds张量第二维度上的最大值并返回最大值及其索引。
.max(dim)
是 PyTorch 中的一个函数,它会返回一个元组,包含给定维度上的最大值和最大值的索引,由 dim参数指定。在这种情况下,指定的维度是2。
_,
在preds之前是一种解包元组的方法,将最大值分配给变量preds,将索引值分配给变量 _
。变量_ 是Python中的一个特殊变量,常用来丢弃不需要的值。因此,这行代码实际上是在第二维度上取最大值并丢弃索引值。
h,w = img.shape[:2]
这行代码是用来解包一个图像的形状的,图像用变量img
表示,它把这个形状分别赋值给变量h
和w
。
img.shape
是返回一个元组,表示图像的维度,其中第一个元素表示高度,第二个元素表示宽度。
[:2]
是用来切片元组,取前两个元素,它们分别代表图像的高和宽。
super(TransBase, self).__init__()
class TransBase(object):
def __init__(self, probability = 1.):
super(TransBase, self).__init__()
self.probability = probability
@abc.abstractmethod
def tranfun(self, inputimage):
pass
# @utils.zlog
def process(self,inputimage):
if np.random.random() < self.probability:
return self.tranfun(inputimage)
else:
return inputimage
super()
是python中用来调用父类方法和属性的关键字。在这个例子中,super(TransBase, self)
调用了TransBase
类的父类的构造函数,用来初始化父类的一些属性和方法。这个调用是可选的,如果TransBase类的父类没有定义构造函数或者构造函数中没有需要初始化的属性和方法可以不用调用。
super()
函数有两个参数,第一个参数是当前类的名称,第二个参数是当前类的实例对象。通过调用父类的构造函数来初始化父类的一些属性和方法。
@abc.abstractmethod
@abc.abstractmethod
是python中用来标记一个方法为抽象方法的装饰器。这个装饰器是来自Python标准库中的abc(abstract base classes)模块。
使用@abc.abstractmethod
装饰器标记的方法是抽象方法,它不需要实现任何具体的功能,而是由子类来实现。如果在一个类中有一个或多个抽象方法,那么这个类就应该被标记为抽象类。
抽象类不能被实例化,只能用来定义接口。抽象类的子类必须实现父类的所有抽象方法。
通过使用@abc.abstractmethod
装饰器,可以确保子类实现了父类中抽象方法,这是一种面向对象编程的重要思想。
class RandomContrast(TransBase):
class RandomContrast(TransBase):
def setparam(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
enh_con = ImageEnhance.Brightness(image)
return enh_con.enhance(random.uniform(self.lower, self.upper))
assert self.upper >= self.lower, "upper must be >= lower."
这是一个assert语句, 用来确保传入的参数满足特定条件。
assert语句的语法是:
assert expression [, message]
其中 expression
是要检验的表达式,如果表达式为假,那么将会触发AssertionError
异常,并且 message
会作为异常的参数输出。
在这里, “assert self.upper >= self.lower, "upper must be >= lower.
” 这个语句会检查 self.upper >= self.lower
是否为真,如果不为真,那么将会触发 AssertionError
异常,并且"upper must be >= lower
" 会作为异常的参数输出。
这个语句的作用是确保upper >= lower, 也就是说如果 upper < lower 会抛出 AssertionError 异常,并输出 “upper must be >= lower”.
enh_con = ImageEnhance.Brightness(image)
ImageEnhance.Brightness()
是Python Imaging Library (PIL) 中的一个类,用于实现图像的亮度增强。
该类的构造函数接受一个PIL格式的图像作为参数, 创建一个亮度增强器对象.
它提供了一个enhance()
方法来增强图像的亮度,该方法接受一个参数factor
,表示增强因子,值域在0~1之间,取值越大图像亮度越高,取值越小图像亮度越低。
在这个例子中, 使用了enhance()
方法来增加图像的对比度, 随机生成对比度值可以使用 random.uniform(self.lower, self.upper)
从给定的区间中随机生成一个值。
综上, ImageEnhance.Brightness()这个类是PIL中用来增强图像亮度的类, 可以用来实现对比度变换。
param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
param 变量是一个列表, 其中包含两个元素:
第一个元素 int(cv2.IMWRITE_JPEG_QUALITY)
是一个整数, 表示使用的编码方式是IMWRITE_JPEG_QUALITY
,这是 OpenCV 中支持的一种图像编码方式。
第二个元素 random.randint(self.lower, self.upper)
是一个随机生成的整数,表示质量值。取值在self.lower
和self.upper
之间, 初始化时通过setparam设置.
当使用 cv2.imencode()
函数进行图像编码时,参数列表param就会被用来设置图像的编码方式和质量值。
这样就可以实现对图像质量的随机变换,并且在保证质量值在合理范围内的同时随机化质量值,避免过度改变图像质量导致图像变得过于模糊或过于清晰。
class Exposure(TransBase):
函数class Exposure(TransBase):
def setparam(self, lower=5, upper=10):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = trans_utils.getcvimage(image)
h,w = image.shape[:2]
x0 = random.randint(0, w)
y0 = random.randint(0, h)
x1 = random.randint(x0, w)
y1 = random.randint(y0, h)
transparent_area = (x0, y0, x1, y1)
mask=Image.new('L', (w, h), color=255)
draw=ImageDraw.Draw(mask)
mask = np.array(mask)
if len(image.shape)==3:
mask = mask[:,:,np.newaxis]
mask = np.concatenate([mask,mask,mask],axis=2)
draw.rectangle(transparent_area, fill=random.randint(150,255))
reflection_result = image + (255 - mask)
reflection_result = np.clip(reflection_result, 0, 255)
return trans_utils.cv2pil(reflection_result)
这是一段 Python 代码,它定义了一个名为 Exposure 的类,该类继承自 TransBase 类。Exposure 类具有两个方法 setparam 和 tranfun。
setparam
方法用于设置 lower 和 upper 两个参数,并对它们进行断言检查,确保 upper >= lower 且 lower >= 0。
tranfun
方法用于对输入的 image 进行随机曝光效果的处理。它首先使用 trans_utils.getcvimage
函数将 image 转换为 OpenCV 格式,然后随机生成一个区域 ( x 0 , y 0 , x 1 , y 1 ) (x_0,y_0,x_1,y_1) (x0,y0,x1,y1),在这个区域内进行曝光处理,最后返回结果。
mask=Image.new('L', (w, h), color=255)
draw=ImageDraw.Draw(mask)
mask = np.array(mask)
mask=Image.new('L', (w, h), color=255)
使用 PIL 创建一个大小为 (w, h) 且灰度图像,初始化颜色为255。draw=ImageDraw.Draw(mask)
创建一个画图对象,用来对刚创建的 mask
进行绘图操作。综上,这三行代码创建了一个大小为 (w, h),初始化颜色为255的灰度图像,并将其转换为numpy数组格式。
if len(image.shape)==3:
mask = mask[:,:,np.newaxis]
mask = np.concatenate([mask,mask,mask],axis=2)
这两行代码检查图像的通道数,如果是3通道的话,进行如下操作。
首先,if len(image.shape)==3:
判断 image 的通道数,如果是 3 通道,则进入 if
语句块。
接着, mask = mask[:,:,np.newaxis]
是将数组增加一个维度,从而使得mask数组与image数组维度相同。
最后,np.concatenate([mask,mask,mask],axis=2)
是将mask复制三遍,并在第三维度上合并成一个新数组。这样做的目的是为了使得mask数组的通道数与image数组的通道数相同。
综上,这两行代码的作用是将mask数组转换成与image数组维度相同,通道数也相同。
trans.py
文件#!/usr/bin/env python
#coding:utf-8
import sys
# reload(sys)
# sys.setdefaultencoding("utf-8")
import os, sys, shutil, math, random, json, multiprocessing, threading
from PIL import Image, ImageDraw, ImageFont, ImageChops
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
# import
import abc
import trans_utils
from trans_utils import getpilimage
global colormap
index = 0
class TransBase(object):
def __init__(self, probability = 1.):
super(TransBase, self).__init__()
self.probability = probability
@abc.abstractmethod
def tranfun(self, inputimage):
pass
# @utils.zlog
def process(self,inputimage):
if np.random.random() < self.probability:
return self.tranfun(inputimage)
else:
return inputimage
class RandomContrast(TransBase):
def setparam(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
enh_con = ImageEnhance.Brightness(image)
return enh_con.enhance(random.uniform(self.lower, self.upper))
class RandomBrightness(TransBase):
def setparam(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
bri = ImageEnhance.Brightness(image)
return bri.enhance(random.uniform(self.lower, self.upper))
class RandomColor(TransBase):
def setparam(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
col = ImageEnhance.Color(image)
return col.enhance(random.uniform(self.lower, self.upper))
class RandomSharpness(TransBase):
def setparam(self, lower=0.5, upper=1.5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
sha = ImageEnhance.Sharpness(image)
return sha.enhance(random.uniform(self.lower, self.upper))
class Compress(TransBase):
def setparam(self, lower=5, upper=85):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
img = trans_utils.getcvimage(image)
param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
img_encode = cv2.imencode('.jpeg', img, param)
img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)
pil_img = trans_utils.cv2pil(img_decode)
if len(image.split())==1:
pil_img = pil_img.convert('L')
return pil_img
class Exposure(TransBase):
def setparam(self, lower=5, upper=10):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = trans_utils.getcvimage(image)
h,w = image.shape[:2]
x0 = random.randint(0, w)
y0 = random.randint(0, h)
x1 = random.randint(x0, w)
y1 = random.randint(y0, h)
transparent_area = (x0, y0, x1, y1)
mask=Image.new('L', (w, h), color=255)
draw=ImageDraw.Draw(mask)
mask = np.array(mask)
if len(image.shape)==3:
mask = mask[:,:,np.newaxis]
mask = np.concatenate([mask,mask,mask],axis=2)
draw.rectangle(transparent_area, fill=random.randint(150,255))
reflection_result = image + (255 - mask)
reflection_result = np.clip(reflection_result, 0, 255)
return trans_utils.cv2pil(reflection_result)
class Rotate(TransBase):
def setparam(self, lower=-5, upper=5):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
# assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
rot = random.uniform(self.lower, self.upper)
trans_img = image.rotate(rot, expand=True)
# trans_img.show()
return trans_img
class Blur(TransBase):
def setparam(self, lower=0, upper=1):
self.lower = lower
self.upper = upper
assert self.upper >= self.lower, "upper must be >= lower."
assert self.lower >= 0, "lower must be non-negative."
def tranfun(self, image):
image = getpilimage(image)
image = image.filter(ImageFilter.GaussianBlur(radius=1))
# blurred_image = image.filter(ImageFilter.Kernel((3,3), (1,1,1,0,0,0,2,0,2)))
# Kernel
return image
class Salt(TransBase):
def setparam(self, rate=0.02):
self.rate = rate
def tranfun(self, image):
image = getpilimage(image)
num_noise = int(image.size[1] * image.size[0] * self.rate)
# assert len(image.split()) == 1
for k in range(num_noise):
i = int(np.random.random() * image.size[1])
j = int(np.random.random() * image.size[0])
image.putpixel((j, i), int(np.random.random() * 255))
return image
class AdjustResolution(TransBase):
def setparam(self, max_rate=0.95,min_rate = 0.5):
self.max_rate = max_rate
self.min_rate = min_rate
def tranfun(self, image):
image = getpilimage(image)
w, h = image.size
rate = np.random.random()*(self.max_rate-self.min_rate)+self.min_rate
w2 = int(w*rate)
h2 = int(h*rate)
image = image.resize((w2, h2))
image = image.resize((w, h))
return image
class Crop(TransBase):
def setparam(self, maxv=2):
self.maxv = maxv
def tranfun(self, image):
img = trans_utils.getcvimage(image)
h,w = img.shape[:2]
org = np.array([[0,np.random.randint(0,self.maxv)],
[w,np.random.randint(0,self.maxv)],
[0,h-np.random.randint(0,self.maxv)],
[w,h-np.random.randint(0,self.maxv)]],np.float32)
dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
M = cv2.getPerspectiveTransform(org,dst)
res = cv2.warpPerspective(img,M,(w,h))
return getpilimage(res)
class Crop2(TransBase):
def setparam(self, maxv_h=4, maxv_w=4):
self.maxv_h = maxv_h
self.maxv_w = maxv_w
def tranfun(self, image_and_loc):
image, left, top, right, bottom = image_and_loc
w, h = image.size
left = np.clip(left,0,w-1)
right = np.clip(right,0,w-1)
top = np.clip(top, 0, h-1)
bottom = np.clip(bottom, 0, h-1)
img = trans_utils.getcvimage(image)
try:
# global index
res = getpilimage(img[top:bottom,left:right])
# res.save('test_imgs/crop-debug-{}.jpg'.format(index))
# index+=1
return res
except AttributeError as e:
print('error')
image.save('test_imgs/t.png')
print( left, top, right, bottom)
h = bottom - top
w = right - left
org = np.array([[left - np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h//2)],
[right + np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h//2)],
[left - np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h//2)],
[right + np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h//2)]], np.float32)
dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
M = cv2.getPerspectiveTransform(org,dst)
res = cv2.warpPerspective(img,M,(w,h))
return getpilimage(res)
class Stretch(TransBase):
def setparam(self, max_rate = 1.2,min_rate = 0.8):
self.max_rate = max_rate
self.min_rate = min_rate
def tranfun(self, image):
image = getpilimage(image)
w, h = image.size
rate = np.random.random()*(self.max_rate-self.min_rate)+self.min_rate
w2 = int(w*rate)
image = image.resize((w2, h))
return image
if __name__ == "__main__":
# img_name = 'test_files/NID 1468666480 (1) Front.jpg'
# img = Image.open(img_name)
# w,h = img.size
#
# img.show()
# rc = Crop2()
# rc.setparam()
# img = rc.process([img,362,418,581,463])
# # img = ImageOps.invert(img)
# img.show()
img_name = 'data_set/images_0701_EC_3/0.png'
img = Image.open(img_name)
print(img.size[1])
w, h = img.size
img_cv = trans_utils.pil2cv(img)
print(img_cv.shape)
# print(len(img.split()))
img.show()
# img = cv2.imread(img_name)
rc = Compress()
rc.setparam()
img = rc.process(img)
# img = ImageOps.invert(img)
img.show()