Python图像逐像素点取邻域数据

Python图像逐像素点取邻域数据

图像比较大的话,在MATLAB上跑起来比较慢,用Python跑就会快很多,贴此备用吧!

#coding=utf-8

import pandas as pd
import numpy as np
from pandas import DataFrame
from matplotlib import pyplot as plt
from matplotlib import image
import scipy
import cv2
import scipy.io as sio

#原始数据四周补0
def pad_data(data,nei_size):
    m,n = data.shape
    t1 = np.zeros([nei_size//2,n])
    data = np.concatenate((t1,data,t1))
    m,n = data.shape
    t2 = np.zeros([m,nei_size//2])
    data = np.concatenate((t2,data,t2),axis=1)  
    return data

#逐像素取大小为nei_size*nei_size的邻域数据
def gen_dataX(data,nei_size):
    x,y = data.shape
    m = x-nei_size//2*2;n = y-nei_size//2*2
    res = np.zeros([m*n,nei_size**2])
    print m,n
    k = 0
    for i in range(nei_size//2,m+nei_size//2):
        for j in range(nei_size//2,n+nei_size//2):
            res[k,:] = np.reshape(data[i-nei_size//2:i+nei_size//2+1,j-nei_size//2:j+nei_size//2+1].T,(1,-1))
            k += 1
    print k
    return res

im = sio.loadmat('data/im1.mat');
im1 = im1['im1']
nei_size=5
#邻域取训练数据
im1= pad_data(im1,nei_size)
data = gen_dataX(im1,nei_size)
sio.savemat("results/"+str(kk)+"/dataX.mat", {'dataX':dataX}) 

你可能感兴趣的:(python,机器学习)