CRNN模型Python实现笔记二

文章目录

    • 一、函数讲解
      • 1. `ImageEnhance.Sharpness(image)`类
      • 2. `cv2.imencode()`函数
      • 3. `random.randint()` 函数
      • 4. `Image.new()`函数
      • 5. `ImageDraw.Draw()`函数
      • 6. `np.clip()`函数
    • 二、疑难代码段讲解
      • 1. 字典推导(dictionary comprehension)
      • 2. `torch.load(model_path, map_location='cpu')`
      • 3. `_, preds = preds.max(2)`
      • 4. `h,w = img.shape[:2]`
      • 5. `super(TransBase, self).__init__()`
      • 6. `class RandomContrast(TransBase):`
      • 7. `param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]`
      • 8. `class Exposure(TransBase):`函数
    • 三、附录`trans.py`文件

一、函数讲解

1. ImageEnhance.Sharpness(image)

ImageEnhance.Sharpness(image) 是 PIL 中的一个类, 用于实现图像的锐化。

该类的构造函数接受一个PIL格式的图像作为参数,创建一个锐化器对象。

它提供了一个enhance()方法来锐化图像,该方法接受一个参数factor,表示锐化因子,值域在0~1之间,取值越大图像锐化越高,取值越小图像锐化越低。

如果你想使用这个类来增强图像的锐度,你需要先创建一个ImageEnhance.Sharpness的对象,然后调用enhance()方法来增强图像锐度。

例如:

image = Image.open("image.jpg")
sharp_image = ImageEnhance.Sharpness(image)
image = sharp_image.enhance(1.5)
image.save("sharp_image.jpg")

这样就可以将原图像增强锐度1.5倍后保存成"sharp_image.jpg"

简单来说,ImageEnhance.Sharpness() 这个类是PIL中用来增强图像锐度的类, 可以用来实现锐度变换。

2. cv2.imencode()函数

cv2.imencode() 函数接受三个参数,分别是:

  • ext: 图像格式,字符串类型,如 ‘.jpg’, ‘.png’, ‘.bmp’ 等。
  • img: 图像数据,numpy 数组类型。
  • params: 编码参数,可以是一个整数或一个整数列表,具体取决于所使用的编码方式。

例如:

param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
img_encode = cv2.imencode('.jpeg', img, param)
img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)

在上面的例子中,第一个参数 ext 传入的是 '.jpeg' 代表使用JPEG格式进行编码。
第二个参数 img 传入的是 cv2 图像数据。
第三个参数 params 传入的是列表, [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)] 代表使用JPEG编码格式并且设置质量值。

其它编码格式也可以使用这个函数,只需要在params中添加不同的参数即可,比如使用 PNG 格式编码时,params 可以是 [int(cv2.IMWRITE_PNG_COMPRESSION), 9] 代表设置压缩级别。

3. random.randint() 函数

random.randint() 是 Python 标准库 random 模块中的一个函数,它可以生成一个在给定范围内的随机整数。

语法:random.randint(a, b)

其中 a 是随机整数的最小值,b 是随机整数的最大值,返回一个 a 到 b 之间的随机整数(包括 a 和 b)。

在附录的代码中,random.randint(0, w) 用于随机生成一个宽度范围内的值,random.randint(0, h) 用于随机生成一个高度范围内的值,random.randint(x0, w) 用于随机生成一个宽度范围内的值,random.randint(y0, h) 用于随机生成一个高度范围内的值。

4. Image.new()函数

Image.new() 是 Python Imaging Library (PIL) 中的一个函数,它可以创建一个新图像。

用法:Image.new(mode, size, color),其中 mode 是图像模式('L' 代表灰度图像,'RGB' 代表彩色图像,等等),size 是图像大小(用元组 (width, height) 表示),color 是填充颜色。

例如:

mask=Image.new('L', (w, h), color=255)

在上面的代码中,mask=Image.new('L', (w, h), color=255) 用于创建一个大小为 (w, h) 且灰度图像,初始化颜色为255的图像。

5. ImageDraw.Draw()函数

ImageDraw.Draw() 是 Python Imaging Library (PIL) 中的一个函数,它可以创建一个画图对象。

用法:ImageDraw.Draw(image),其中 image 是一个 PIL 图像对象。

例如:

draw=ImageDraw.Draw(mask)
draw.rectangle(transparent_area, fill=random.randint(150,255))

在上面的代码中,draw=ImageDraw.Draw(mask) 用于创建一个画图对象,它可以对 mask 这个 PIL 图像对象进行绘图操作。

draw.rectangle() 是 PIL 库中的函数,可以在一个图像上画矩形。

用法:draw.rectangle(box, options),其中 box 是矩形的坐标(用元组 (x0, y0, x1, y1) 表示),options 是矩形的其它选项,如填充颜色(fill)等。

在这里,draw.rectangle(transparent_area, fill=random.randint(150,255)) 这一行代码用于在mask这个图像上画一个矩形,坐标是transparent_area,填充颜色是随机生成的一个150~255之间的颜色值。




6. np.clip()函数

np.clip() 是 numpy 中的一个函数,可以将数组中的元素限制在一个给定的最小值和最大值之间。

用法:numpy.clip(a, a_min, a_max, out=None),其中 a 是输入数组,a_min 是最小值,a_max 是最大值。

reflection_result = np.clip(reflection_result, 0, 255)

在这里,reflection_result = np.clip(reflection_result, 0, 255) 这一行代码用于将 reflection_result 数组中的所有元素限制在 0 和 255 之间。这样做的目的是为了确保图像的像素值在合理的范围内,避免出现像素值过大或过小的情况。




二、疑难代码段讲解

1. 字典推导(dictionary comprehension)

self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()})

“:” 在这里是用来分隔字典中的键和值的,类似于 “key : value”。在这段代码中,它用来将原有的键和值对应关系映射到新的键和值对应关系中。

具体来说,这个字典推导(dictionary comprehension)会对原有字典中的每一个键值对进行操作,通过将原有的键中的 ‘module.’ 替换为空字符串,来得到新的键,而值则不变,最后组成一个新的字典。

简单来说,就是将原有字典中的键更改为新的键,值不变,构建出新的字典。

这段代码是在加载一个预训练模型的状态字典,特别是PyTorch模型的状态字典,从指定的文件路径(model_path)中加载。使用torch.load()函数从文件中加载状态字典,然后使用字典推导修改结果字典,以删除键的“module.”前缀。这样做是因为在加载在多个GPU上训练的模型时,状态字典的键将有“module.”前缀。因此,这一行代码将状态字典键转换为与模型架构匹配。

items() 是 python 字典 (dictionary) 的一个内置函数,它会返回一个由字典中的键值对组成的元组 (tuple) 列表。 比如如果有个字典

    my_dict = {'a': 1, 'b': 2, 'c': 3}

使用 items() 函数之后,会得到

    my_dict.items()
    #[('a', 1), ('b', 2), ('c', 3)]

在这段代码中,使用 items() 函数获取了字典中的所有键值对,并通过字典推导来更改键值对中的键,构建新的字典。

2. torch.load(model_path, map_location='cpu')

torch.load() 是 PyTorch 中的一个函数,用于从文件中加载已保存的 PyTorch 模型。传递给该函数的第一个参数是已保存模型的文件路径,在这种情况下是 model_path。

第二个参数 map_location='cpu' 是一个可选参数,用于指定应在其上加载模型的设备。默认情况下,模型在保存时的设备上加载。但是在这种情况下,模型在 CPU 上加载,而不是 GPU。这样可以在没有 GPU 的机器上加载在 GPU 上训练的模型,或在 GPU 和 CPU 之间切换。

通过传递 map_location=‘cpu’ 参数,它将张量转换为 CPU 张量,并将存储映射到 CPU。当你在 GPU 上训练了一个模型并希望在 CPU 上加载它时非常有用。




3. _, preds = preds.max(2)

这行代码是在选取preds张量第二维度上的最大值并返回最大值及其索引。
.max(dim) 是 PyTorch 中的一个函数,它会返回一个元组,包含给定维度上的最大值和最大值的索引,由 dim参数指定。在这种情况下,指定的维度是2。

_, 在preds之前是一种解包元组的方法,将最大值分配给变量preds,将索引值分配给变量 _。变量_ 是Python中的一个特殊变量,常用来丢弃不需要的值。因此,这行代码实际上是在第二维度上取最大值并丢弃索引值。




4. h,w = img.shape[:2]

这行代码是用来解包一个图像的形状的,图像用变量img表示,它把这个形状分别赋值给变量hw

img.shape是返回一个元组,表示图像的维度,其中第一个元素表示高度,第二个元素表示宽度。

[:2] 是用来切片元组,取前两个元素,它们分别代表图像的高和宽。




5. super(TransBase, self).__init__()

class TransBase(object):
    def __init__(self, probability = 1.):
        super(TransBase, self).__init__()
        self.probability = probability
    @abc.abstractmethod
    def tranfun(self, inputimage):
        pass
    # @utils.zlog
    def process(self,inputimage):
        if np.random.random() < self.probability:
            return self.tranfun(inputimage)
        else:
            return inputimage

super()是python中用来调用父类方法和属性的关键字。在这个例子中,super(TransBase, self)调用了TransBase类的父类的构造函数,用来初始化父类的一些属性和方法。这个调用是可选的,如果TransBase类的父类没有定义构造函数或者构造函数中没有需要初始化的属性和方法可以不用调用。

super()函数有两个参数,第一个参数是当前类的名称,第二个参数是当前类的实例对象。通过调用父类的构造函数来初始化父类的一些属性和方法。

@abc.abstractmethod

@abc.abstractmethod是python中用来标记一个方法为抽象方法的装饰器。这个装饰器是来自Python标准库中的abc(abstract base classes)模块

使用@abc.abstractmethod装饰器标记的方法是抽象方法,它不需要实现任何具体的功能,而是由子类来实现。如果在一个类中有一个或多个抽象方法,那么这个类就应该被标记为抽象类。

抽象类不能被实例化,只能用来定义接口。抽象类的子类必须实现父类的所有抽象方法。

通过使用@abc.abstractmethod装饰器,可以确保子类实现了父类中抽象方法,这是一种面向对象编程的重要思想。




6. class RandomContrast(TransBase):

class RandomContrast(TransBase):
    def setparam(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        enh_con = ImageEnhance.Brightness(image)
        return enh_con.enhance(random.uniform(self.lower, self.upper))



assert self.upper >= self.lower, "upper must be >= lower."

这是一个assert语句, 用来确保传入的参数满足特定条件。

assert语句的语法是:
assert expression [, message]

其中 expression 是要检验的表达式,如果表达式为假,那么将会触发AssertionError异常,并且 message 会作为异常的参数输出。

在这里, “assert self.upper >= self.lower, "upper must be >= lower.” 这个语句会检查 self.upper >= self.lower 是否为真,如果不为真,那么将会触发 AssertionError 异常,并且"upper must be >= lower" 会作为异常的参数输出。

这个语句的作用是确保upper >= lower, 也就是说如果 upper < lower 会抛出 AssertionError 异常,并输出 “upper must be >= lower”.

enh_con = ImageEnhance.Brightness(image)

ImageEnhance.Brightness()是Python Imaging Library (PIL) 中的一个类,用于实现图像的亮度增强

该类的构造函数接受一个PIL格式的图像作为参数, 创建一个亮度增强器对象.

它提供了一个enhance()方法来增强图像的亮度,该方法接受一个参数factor,表示增强因子,值域在0~1之间,取值越大图像亮度越高,取值越小图像亮度越低

在这个例子中, 使用了enhance()方法来增加图像的对比度, 随机生成对比度值可以使用 random.uniform(self.lower, self.upper) 从给定的区间中随机生成一个值。

综上, ImageEnhance.Brightness()这个类是PIL中用来增强图像亮度的类, 可以用来实现对比度变换。

7. param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]

param 变量是一个列表, 其中包含两个元素:

第一个元素 int(cv2.IMWRITE_JPEG_QUALITY) 是一个整数, 表示使用的编码方式是IMWRITE_JPEG_QUALITY,这是 OpenCV 中支持的一种图像编码方式。
第二个元素 random.randint(self.lower, self.upper) 是一个随机生成的整数,表示质量值。取值在self.lowerself.upper之间, 初始化时通过setparam设置.
当使用 cv2.imencode() 函数进行图像编码时,参数列表param就会被用来设置图像的编码方式和质量值。

这样就可以实现对图像质量的随机变换,并且在保证质量值在合理范围内的同时随机化质量值,避免过度改变图像质量导致图像变得过于模糊或过于清晰。

8. class Exposure(TransBase):函数

class Exposure(TransBase):
    def setparam(self, lower=5, upper=10):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = trans_utils.getcvimage(image)
        h,w = image.shape[:2]
        x0 = random.randint(0, w)
        y0 = random.randint(0, h)
        x1 = random.randint(x0, w)
        y1 = random.randint(y0, h)
        transparent_area = (x0, y0, x1, y1)
        mask=Image.new('L', (w, h), color=255)
        draw=ImageDraw.Draw(mask)
        mask = np.array(mask)
        if len(image.shape)==3:
            mask = mask[:,:,np.newaxis]
            mask = np.concatenate([mask,mask,mask],axis=2)
        draw.rectangle(transparent_area, fill=random.randint(150,255))
        reflection_result = image + (255 - mask)
        reflection_result = np.clip(reflection_result, 0, 255)
        return trans_utils.cv2pil(reflection_result)

这是一段 Python 代码,它定义了一个名为 Exposure 的类,该类继承自 TransBase 类。Exposure 类具有两个方法 setparam 和 tranfun。

setparam 方法用于设置 lower 和 upper 两个参数,并对它们进行断言检查,确保 upper >= lower 且 lower >= 0。

tranfun 方法用于对输入的 image 进行随机曝光效果的处理。它首先使用 trans_utils.getcvimage 函数将 image 转换为 OpenCV 格式,然后随机生成一个区域 ( x 0 , y 0 , x 1 , y 1 ) (x_0,y_0,x_1,y_1) (x0,y0,x1,y1),在这个区域内进行曝光处理,最后返回结果。


mask=Image.new('L', (w, h), color=255)
draw=ImageDraw.Draw(mask)
mask = np.array(mask)
  • 第一行代码 mask=Image.new('L', (w, h), color=255) 使用 PIL 创建一个大小为 (w, h) 且灰度图像,初始化颜色为255。
  • 第二行代码 draw=ImageDraw.Draw(mask) 创建一个画图对象,用来对刚创建的 mask 进行绘图操作。
  • 第三行代码 mask = np.array(mask) 将 PIL 图像对象转换为 numpy 数组,方便后续的图像处理。

综上,这三行代码创建了一个大小为 (w, h),初始化颜色为255的灰度图像,并将其转换为numpy数组格式。


if len(image.shape)==3:
mask = mask[:,:,np.newaxis]
mask = np.concatenate([mask,mask,mask],axis=2)

这两行代码检查图像的通道数,如果是3通道的话,进行如下操作。

首先,if len(image.shape)==3: 判断 image 的通道数,如果是 3 通道,则进入 if语句块。

接着, mask = mask[:,:,np.newaxis] 是将数组增加一个维度,从而使得mask数组与image数组维度相同

最后,np.concatenate([mask,mask,mask],axis=2) 是将mask复制三遍,并在第三维度上合并成一个新数组。这样做的目的是为了使得mask数组的通道数与image数组的通道数相同。

综上,这两行代码的作用是将mask数组转换成与image数组维度相同,通道数也相同。




三、附录trans.py文件

#!/usr/bin/env python
#coding:utf-8
import sys
# reload(sys)
# sys.setdefaultencoding("utf-8")
import os, sys, shutil, math, random, json, multiprocessing, threading
from PIL import Image, ImageDraw, ImageFont, ImageChops
import cv2
import numpy as np
from PIL import Image, ImageEnhance, ImageFilter, ImageOps
# import 
import abc
import trans_utils
from trans_utils import getpilimage



global colormap
index = 0

class TransBase(object):
    def __init__(self, probability = 1.):
        super(TransBase, self).__init__()
        self.probability = probability
    @abc.abstractmethod
    def tranfun(self, inputimage):
        pass
    # @utils.zlog
    def process(self,inputimage):
        if np.random.random() < self.probability:
            return self.tranfun(inputimage)
        else:
            return inputimage

class RandomContrast(TransBase):
    def setparam(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        enh_con = ImageEnhance.Brightness(image)
        return enh_con.enhance(random.uniform(self.lower, self.upper))

class RandomBrightness(TransBase):
    def setparam(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        bri = ImageEnhance.Brightness(image)
        return bri.enhance(random.uniform(self.lower, self.upper))

class RandomColor(TransBase):
    def setparam(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        col = ImageEnhance.Color(image)
        return col.enhance(random.uniform(self.lower, self.upper))

class RandomSharpness(TransBase):
    def setparam(self, lower=0.5, upper=1.5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        sha = ImageEnhance.Sharpness(image)
        return sha.enhance(random.uniform(self.lower, self.upper))

class Compress(TransBase):
    def setparam(self, lower=5, upper=85):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        img = trans_utils.getcvimage(image)
        param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
        img_encode = cv2.imencode('.jpeg', img, param)
        img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)
        pil_img = trans_utils.cv2pil(img_decode)
        if len(image.split())==1:
            pil_img = pil_img.convert('L')
        return pil_img

class Exposure(TransBase):
    def setparam(self, lower=5, upper=10):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = trans_utils.getcvimage(image)
        h,w = image.shape[:2]
        x0 = random.randint(0, w)
        y0 = random.randint(0, h)
        x1 = random.randint(x0, w)
        y1 = random.randint(y0, h)
        transparent_area = (x0, y0, x1, y1)
        mask=Image.new('L', (w, h), color=255)
        draw=ImageDraw.Draw(mask)
        mask = np.array(mask)
        if len(image.shape)==3:
            mask = mask[:,:,np.newaxis]
            mask = np.concatenate([mask,mask,mask],axis=2)
        draw.rectangle(transparent_area, fill=random.randint(150,255))
        reflection_result = image + (255 - mask)
        reflection_result = np.clip(reflection_result, 0, 255)
        return trans_utils.cv2pil(reflection_result)

class Rotate(TransBase):
    def setparam(self, lower=-5, upper=5):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        # assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        rot = random.uniform(self.lower, self.upper)
        trans_img = image.rotate(rot, expand=True)
        # trans_img.show()
        return trans_img

class Blur(TransBase):
    def setparam(self, lower=0, upper=1):
        self.lower = lower
        self.upper = upper
        assert self.upper >= self.lower, "upper must be >= lower."
        assert self.lower >= 0, "lower must be non-negative."
    def tranfun(self, image):
        image = getpilimage(image)
        image = image.filter(ImageFilter.GaussianBlur(radius=1))
        # blurred_image = image.filter(ImageFilter.Kernel((3,3), (1,1,1,0,0,0,2,0,2)))
        # Kernel
        return image

class Salt(TransBase):
    def setparam(self, rate=0.02):
        self.rate = rate
    def tranfun(self, image):
        image = getpilimage(image)
        num_noise = int(image.size[1] * image.size[0] * self.rate)
        # assert len(image.split()) == 1
        for k in range(num_noise):
            i = int(np.random.random() * image.size[1])
            j = int(np.random.random() * image.size[0])
            image.putpixel((j, i), int(np.random.random() * 255))
        return image


class AdjustResolution(TransBase):
    def setparam(self, max_rate=0.95,min_rate = 0.5):
        self.max_rate = max_rate
        self.min_rate = min_rate

    def tranfun(self, image):
        image = getpilimage(image)
        w, h = image.size
        rate = np.random.random()*(self.max_rate-self.min_rate)+self.min_rate
        w2 = int(w*rate)
        h2 = int(h*rate)
        image = image.resize((w2, h2))
        image = image.resize((w, h))
        return image


class Crop(TransBase):
    def setparam(self, maxv=2):
        self.maxv = maxv
    def tranfun(self, image):
        img = trans_utils.getcvimage(image)
        h,w = img.shape[:2]
        org = np.array([[0,np.random.randint(0,self.maxv)],
                        [w,np.random.randint(0,self.maxv)],
                        [0,h-np.random.randint(0,self.maxv)],
                        [w,h-np.random.randint(0,self.maxv)]],np.float32)
        dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
        M = cv2.getPerspectiveTransform(org,dst)
        res = cv2.warpPerspective(img,M,(w,h))
        return getpilimage(res)

class Crop2(TransBase):
    def setparam(self, maxv_h=4, maxv_w=4):
        self.maxv_h = maxv_h
        self.maxv_w = maxv_w
    def tranfun(self, image_and_loc):
        image, left, top, right, bottom = image_and_loc
        w, h = image.size
        left = np.clip(left,0,w-1)
        right = np.clip(right,0,w-1)
        top = np.clip(top, 0, h-1)
        bottom = np.clip(bottom, 0, h-1)
        img = trans_utils.getcvimage(image)
        try:
            # global index
            res = getpilimage(img[top:bottom,left:right])
            # res.save('test_imgs/crop-debug-{}.jpg'.format(index))
            # index+=1
            return res
        except AttributeError as e:
            print('error')
            image.save('test_imgs/t.png')
            print( left, top, right, bottom)

        h = bottom - top
        w = right - left
        org = np.array([[left - np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h//2)],
                        [right + np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h//2)],
                        [left - np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h//2)],
                        [right + np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h//2)]], np.float32)
        dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
        M = cv2.getPerspectiveTransform(org,dst)
        res = cv2.warpPerspective(img,M,(w,h))
        return getpilimage(res)

class Stretch(TransBase):
    def setparam(self, max_rate = 1.2,min_rate = 0.8):
        self.max_rate = max_rate
        self.min_rate = min_rate

    def tranfun(self, image):
        image = getpilimage(image)
        w, h = image.size
        rate = np.random.random()*(self.max_rate-self.min_rate)+self.min_rate
        w2 = int(w*rate)
        image = image.resize((w2, h))
        return image


if __name__ == "__main__":
    # img_name = 'test_files/NID 1468666480 (1) Front.jpg'
    # img = Image.open(img_name)
    # w,h = img.size
    #
    # img.show()
    # rc = Crop2()
    # rc.setparam()
    # img = rc.process([img,362,418,581,463])
    # # img = ImageOps.invert(img)
    # img.show()

    img_name = 'data_set/images_0701_EC_3/0.png'
    img = Image.open(img_name)
    print(img.size[1])
    w, h = img.size
    img_cv = trans_utils.pil2cv(img)
    print(img_cv.shape)
    # print(len(img.split()))

    img.show()
    # img = cv2.imread(img_name)
    rc = Compress()
    rc.setparam()
    img = rc.process(img)
    # img = ImageOps.invert(img)
    img.show()


你可能感兴趣的:(机器学习,python,opencv,计算机视觉)