import numpy as np
import cv2
import matplotlib.pyplot as plt
# # file_path = 'D:\\code_python\\KAIR\\visualization\\x0.png'
# file_path = 'D:\\dataset\\test\\classic5\\lena.bmp'
# img = cv2.imread(file_path)[:,:,0] #cv2默认是BGR通道顺序,这里调整到RGB
# # img = cv2.resize(img,(500,500))
# fre = np.fft.fft2(img,axes=(0,1)) #变换得到的频域图数据是复数组成的
# fre_m = np.abs(fre) #幅度谱,求模得到
# fre_p = np.angle(fre) #相位谱,求相角得到
# fre_m_real_log = np.log(fre_m)
# plt.figure()
# plt.imshow(fre_m_real_log.astype('uint8'))
# fre_shift = np.fft.fftshift(fre)
# fre_shift_real = np.abs(fre_shift)
# fre_shift_real_log = np.log(fre_shift_real)
# img2 = cv2.imread(file_path)[:,:,1] #cv2默认是BGR通道顺序,这里调整到RGB
# # img = cv2.resize(img,(500,500))
# fre2 = np.fft.fft2(img2,axes=(0,1)) #变换得到的频域图数据是复数组成的
# fre_shift2 = np.fft.fftshift(fre2)
# fre_shift_real2 = np.abs(fre_shift2)
# fre_shift_real_log2 = np.log(fre_shift_real2)
# #把振幅设为常数
# constant = fre_m.mean()
# fre_ = constant * np.e**(1j*fre_p) #把幅度谱和相位谱再合并为复数形式的频域图数据
# img_onlyphase = np.abs(np.fft.ifft2(fre_,axes=(0,1))) #还原为空间域图像
# #把相位设为常数
# constant = fre_p.mean()
# fre_ = fre_m * np.e**(1j*constant)
# img_onlymagnitude = np.abs(np.fft.ifft2(fre_,axes=(0,1)))
# plt.figure()
# plt.imshow(fre.astype('uint8'))
# plt.figure()
# plt.imshow(fre_shift_real_log.astype('uint8'))
# plt.figure()
# plt.imshow(fre_shift_real_log2.astype('uint8'))
# # plt.figure()
# # img_onlyphase_hist = plt.hist(img_onlyphase.astype('uint8').ravel())
# # plt.figure()
# # img_onlymagnitude_hist = plt.hist(img_onlymagnitude.astype('uint8').ravel())
# # plt.figure()
# # img_hist = plt.hist(img.astype('uint8').ravel())
# plt.figure()
# plt.imshow(img.astype('uint8'))
# # plt.figure()
# # plt.imshow(img_onlyphase.astype('uint8'))
# # plt.figure()
# # plt.imshow(img_onlymagnitude.astype('uint8'))
# # plt.figure()
# # plt.imshow(fre.astype('uint8'))
# # plt.figure()
# # plt.imshow(fre_m.astype('uint8'))
# # plt.figure()
# # plt.imshow(fre_p.astype('uint8'))
# # plt.figure()
# # plt.imshow(fre_.astype('uint8'))
# plt.show()
######### plot frequency curve
#### one image
import os
def _radial_log_spectrum(gray, eps=1e-12):
    """Return the mean log-magnitude of the 2-D FFT of *gray* per square ring.

    The FFT is not shifted, so the zero-frequency (DC) term sits at index
    (0, 0) and frequencies grow away from that corner. Only the top-left
    ``length x length`` corner is used, where ``length = min(H // 2, W // 2)``.
    For each radius ``i`` in ``range(length)`` the returned value is the mean
    of ``log|F|`` over the L-shaped shell ``{(r, c) : max(r, c) == i}``
    (row ``i`` left of the diagonal, column ``i`` above it, and ``(i, i)``),
    which holds exactly ``2 * i + 1`` coefficients.

    Args:
        gray: 2-D array (a single image channel).
        eps: floor applied to the FFT magnitude before the log, so exactly
            zero coefficients yield a finite value instead of ``-inf``.

    Returns:
        ``(x, y)`` where ``x`` is ``np.arange(length)`` and ``y`` holds the
        per-ring means (both empty when the image is smaller than 2x2).
    """
    magnitude = np.abs(np.fft.fft2(gray, axes=(0, 1)))
    log_mag = np.log(np.maximum(magnitude, eps))
    length = min(log_mag.shape[0] // 2, log_mag.shape[1] // 2)
    x = np.arange(length)
    # Chebyshev distance of every (row, col) in the corner from the DC term;
    # bincount then sums log_mag over each L-shaped ring in one C-level pass.
    ring = np.maximum.outer(x, x)
    ring_sums = np.bincount(
        ring.ravel(),
        weights=log_mag[:length, :length].ravel(),
        minlength=length,
    )
    y = ring_sums / (2 * x + 1)
    return x, y


def plot_fre(path, legend_label):
    """Plot the ring-averaged log-magnitude spectrum of an image's luma.

    The image at *path* is read with OpenCV (BGR channel order), converted
    to YCrCb, and only channel 0 (Y, the luma plane) is analysed. The curve
    is drawn on the current matplotlib axes with *legend_label*; the caller
    is responsible for ``plt.legend()`` / ``plt.show()``.

    Args:
        path: image file path readable by ``cv2.imread``.
        legend_label: label attached to the plotted line.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *path* (OpenCV
            returns ``None`` rather than raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    # cv2 loads BGR; channel 0 of YCrCb is the luma (Y) plane.
    img_y = cv2.cvtColor(img, cv2.COLOR_BGR2YCrCb)[:, :, 0]
    x, y = _radial_log_spectrum(img_y)
    plt.plot(x, y, label=legend_label)
# Compare the spectrum of one clean (GT) SIDD validation image with its noisy
# counterpart. Raw strings keep the Windows backslashes literal, so sequences
# such as '\n' in '\noisy' or '\d' in '\dataset' are not parsed as escapes.
path1 = r'D:\dataset\SIDD_Valid\img\gt'
path2 = r'D:\dataset\SIDD_Valid\img\noisy'
img_name = 'test_SIDD_702.png'
# Strip the extension robustly instead of assuming it is 4 characters long.
img_stem = os.path.splitext(img_name)[0]
plot_fre(os.path.join(path1, img_name), img_stem + '_gt')
plot_fre(os.path.join(path2, img_name), img_stem + '_noisy')
plt.legend()
plt.show()
#### dataset
# import os
# dataset_path = 'D:\dataset\SIDD_Valid\img\\noisy'
# dataset_path_gt = 'D:\dataset\SIDD_Valid\img\gt'
# img_names = os.listdir(dataset_path)
# for ii, img_name in enumerate(img_names):
# img_rgb = cv2.imread(os.path.join(dataset_path, img_name))
# img_YCrCb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2YCrCb)
# img_Y = img_YCrCb[:,:,0]
# fre = np.fft.fft2(img_Y,axes=(0,1))
# fre_m = np.abs(fre) #幅度谱,求模得到
# fre_p = np.angle(fre) #相位谱,求相角得到
# fre_m_real_log = np.log(fre_m)
# length = min(fre_m_real_log.shape[0]//2, fre_m_real_log.shape[1]//2)
# x = [_ for _ in range(length)]
# y = [(sum(fre_m_real_log[i,j] for j in range(i))+sum(fre_m_real_log[j,i] for j in range(i))+fre_m_real_log[i,i])/(2*i+1) for i in range(length)]
# if ii==0:
# y_sum=y
# else:
# y_sum = [y[i]+y_sum[i] for i in range(length)]
# img_rgb_gt = cv2.imread(os.path.join(dataset_path_gt, img_name))
# img_YCrCb_gt = cv2.cvtColor(img_rgb_gt, cv2.COLOR_BGR2YCrCb)
# img_Y_gt = img_YCrCb_gt[:,:,0]
# fre_gt = np.fft.fft2(img_Y_gt,axes=(0,1))
# fre_m_gt = np.abs(fre_gt) #幅度谱,求模得到
# fre_p_gt = np.angle(fre_gt) #相位谱,求相角得到
# fre_m_real_log_gt = np.log(fre_m_gt)
# length = min(fre_m_real_log_gt.shape[0]//2, fre_m_real_log_gt.shape[1]//2)
# x_gt = [_ for _ in range(length)]
# y_gt = [(sum(fre_m_real_log_gt[i,j] for j in range(i))+sum(fre_m_real_log_gt[j,i] for j in range(i))+fre_m_real_log_gt[i,i])/(2*i+1) for i in range(length)]
# if ii==0:
# y_sum_gt=y_gt
# else:
# y_sum_gt = [y_gt[i]+y_sum_gt[i] for i in range(length)]
# y_ave = [y_sum[i]/(ii+1) for i in range(length)]
# plt.plot(x,y_ave,label='SIDD_noisy')
# y_ave_gt = [y_sum_gt[i]/(ii+1) for i in range(length)]
# plt.plot(x,y_ave_gt,label='SIDD_GT')
# plt.legend()
# plt.show()