# TRLRF algorithm: function definitions below. (下面是TRLRF算法的函数部分)
from matplotlib.image import imread
import numpy as np
from numpy import int8, linalg
from numpy import array, random
import matplotlib.pyplot as plt
def RSE_fun(X, X_hat, W):
    """Print and return the relative error (RSE) on held-out entries.

    Only entries that are unobserved (``W == 0``) and nonzero in the
    reference ``X`` are compared:

        RSE = ||X_hat[P] - X[P]||_2 / ||X[P]||_2,   P = (X != 0) & (W == 0)

    Parameters
    ----------
    X : ndarray
        Reference (ground-truth) tensor.
    X_hat : ndarray
        Estimate, same shape as ``X``.
    W : ndarray
        Observation mask, same shape as ``X`` (0 marks a held-out entry).

    Returns
    -------
    float
        The RSE, which is also printed. It is ``nan`` when no entry qualifies.
    """
    pos_test = np.where((X != 0) & (W == 0))
    rse = (np.linalg.norm(X_hat[pos_test] - X[pos_test], 2)
           / np.linalg.norm(X[pos_test], 2))
    print(rse)
    return rse
def coreten2tr(Z):
    """Contract a list of tensor-ring cores into the full tensor.

    Parameters
    ----------
    Z : list of ndarray
        Cores, each indexed as (left rank, mode size, right rank). The right
        rank of the last core must equal the left rank of the first core,
        because the ring is closed by a trace.

    Returns
    -------
    ndarray
        Tensor whose shape is ``[core.shape[1] for core in Z]``. The modes are
        laid out in Fortran (column-major) order.
    """
    dims = [core.shape[1] for core in Z]
    total = np.prod(dims)
    r_first = Z[0].shape[0]

    # Chain-multiply neighbouring cores along their shared rank index.
    acc = Z[0]
    for prev, core in zip(Z[:-1], Z[1:]):
        left = acc.reshape(-1, prev.shape[2], order='F')
        right = core.reshape(core.shape[0], core.shape[1] * core.shape[2], order='F')
        acc = left @ right

    # acc is now (r_first, all modes, r_last). Put both boundary ranks last,
    # then multiply by vec(I): this keeps only the terms where the two
    # boundary indices coincide, i.e. the trace that closes the ring.
    acc = acc.reshape(r_first, total, Z[-1].shape[2], order='F')
    acc = np.moveaxis(acc, 0, -1).reshape(total, r_first * r_first, order='F')
    ring = np.eye(r_first, r_first)
    acc = acc @ ring.reshape(ring.size)
    return acc.reshape(dims, order='F')
def Pro2TraceNorm(Z, tau):
    """Singular value thresholding of a matrix.

    Returns ``U @ diag(max(s - tau, 0)) @ Vh``, where ``Z = U @ diag(s) @ Vh``.
    This is the proximal operator of ``tau * ||Z||_*`` (nuclear norm).

    Parameters
    ----------
    Z : ndarray, shape (m, n)
        Matrix to shrink.
    tau : float
        Threshold subtracted from every singular value.

    Returns
    -------
    ndarray, shape (m, n)
        The thresholded matrix.
    """
    m, n_cols = Z.shape
    if 2 * m < n_cols:
        # Very wide matrix: decompose the small m x m Gram matrix Z Z^T
        # instead. Its singular values are the squares of those of Z.
        U, s2, _ = linalg.svd(Z @ Z.T)
        s = np.sqrt(s2)
        # Drop numerically-zero singular values so we never divide by ~0.
        tol = max(Z.shape) * np.spacing(max(s))
        k = int(np.sum(s > max(tol, tau)))
        scale = np.maximum(s[:k] - tau, 0) / s[:k]
        # U_k diag(scale) U_k^T Z == U_k diag(s_k - tau) V_k^T
        return U[:, :k] @ np.diag(scale) @ U[:, :k].T @ Z
    if m > 2 * n_cols:
        # Very tall matrix: threshold the (wide) transpose instead.
        return Pro2TraceNorm(Z.T, tau).T
    # Thin SVD is enough: only the leading k singular triplets are used.
    U, s, Vh = linalg.svd(Z, full_matrices=False)
    k = int(np.sum(s > tau))
    # Shrink the singular-value vector directly. The previous form subtracted
    # tau from the whole diagonal matrix, which is wrong off the diagonal
    # whenever tau < 0.
    return (U[:, :k] * (s[:k] - tau)) @ Vh[:k]
def Gunfold(GT, mode):
    """Unfold a tensor along axis ``mode``.

    Axis ``mode`` becomes the rows. The remaining axes are flattened into
    columns in Fortran (column-major) order.

    Returns
    -------
    ndarray, shape (GT.shape[mode], GT.size // GT.shape[mode])
    """
    front = np.moveaxis(GT, mode, 0)
    n_rows = GT.shape[mode]
    return front.reshape((n_rows, -1), order='F')
def Gfold(mat, tensor_size, mode):
    """Fold an axis-``mode`` unfolding back into a tensor.

    This inverts ``Gunfold``. The rows of ``mat`` index axis ``mode``, and the
    columns enumerate the other axes in increasing order, in Fortran order.

    Parameters
    ----------
    mat : array_like, 2-D
        The unfolded matrix.
    tensor_size : sequence of int
        Shape of the tensor to rebuild.
    mode : int
        Axis that was unfolded.

    Returns
    -------
    ndarray of shape ``tensor_size``
    """
    axes = [mode] + [ax for ax in range(len(tensor_size)) if ax != mode]
    permuted_shape = [int(tensor_size[ax]) for ax in axes]
    stacked = np.reshape(mat, permuted_shape, order='F')
    return np.moveaxis(stacked, 0, mode)
def Z_neq(Y, n):
    """Contract every tensor-ring core except core ``n`` into one 3-way tensor.

    The cores are visited in ring order starting just after ``n``
    (n+1, ..., N-1, 0, ..., n-1) and chain-multiplied along their shared ranks.

    Parameters
    ----------
    Y : list of ndarray
        Cores, each indexed as (left rank, mode size, right rank).
    n : int
        Index of the core to leave out.

    Returns
    -------
    ndarray
        Shape is (left rank of the first visited core, product of the
        visited mode sizes, right rank of the last visited core). Mode
        indices are combined in Fortran order.
    """
    order = list(range(n + 1, len(Y))) + list(range(n + 1))
    ring = [Y[k] for k in order]
    count = len(ring)
    merged = ring[0]
    # ring[count - 1] is core n itself; stop before it.
    for k in range(count - 2):
        left = merged.reshape(-1, ring[k].shape[2], order='F')
        right = ring[k + 1].reshape(ring[k + 1].shape[0], -1, order='F')
        merged = left @ right
    return merged.reshape(ring[0].shape[0], -1, ring[count - 2].shape[2], order='F')
def tenmat_sb(X, k):
    """Tensor-ring mode-``k`` unfolding.

    Rows index mode ``k``. Columns enumerate the other modes in cyclic order
    k+1, ..., N-1, 0, ..., k-1, with the first of them varying fastest
    (Fortran order).

    Returns
    -------
    ndarray, shape (X.shape[k], X.size // X.shape[k])
    """
    shape = X.shape
    last = len(shape) - 1
    if k == 0:
        return X.reshape(int(shape[0]), int(X.size // shape[0]), order='F')
    if k == last:
        tall = X.reshape(int(X.size // shape[last]), int(shape[last]), order='F')
        return tall.T
    # Split into (modes before k) x (modes k..N-1), swap the two groups so
    # the leading modes come last, then peel mode k off as the rows.
    lead = int(np.prod(shape[:k]))
    swapped = X.reshape(lead, int(X.size // lead), order='F').T
    return swapped.reshape(int(shape[k]), int(swapped.size // shape[k]), order='F')
def Msum_fun(M):
    """Add up the first three arrays stored for each core.

    Parameters
    ----------
    M : list
        ``M[n]`` is an indexable holding at least three same-shaped arrays.
        Any entries beyond index 2 are ignored.

    Returns
    -------
    list
        ``M[n][0] + M[n][1] + M[n][2]`` for every ``n``.
    """
    return [(group[0] + group[1]) + group[2] for group in M]
def gen_W(S, mr):
    """Draw a random 0/1 observation mask of shape ``S``.

    Each entry is ``round(u + 0.5 - mr)`` for a uniform draw ``u`` in [0, 1).
    For 0 <= mr <= 1 this is 1 (observed) roughly when ``u > mr``, so ``mr``
    is the expected fraction of missing (0) entries.

    Parameters
    ----------
    S : sequence of int
        Mask shape. Any number of dimensions is accepted.
    mr : float
        Missing rate.

    Returns
    -------
    ndarray of float
        Mask of shape ``S``. Zero entries may be stored as -0.0, which still
        compares equal to 0.
    """
    # Unpacking the whole shape handles any rank. The original indexed
    # S[0..2] only, so 2-D shapes crashed and extra axes were silently dropped.
    random_tensor = np.random.rand(*S)
    return np.round(random_tensor + 0.5 - mr)
def TR_initcoreten(S, r):
    """Draw random (standard normal) tensor-ring cores for a tensor of shape ``S``.

    Parameters
    ----------
    S : sequence of int
        Tensor shape (N mode sizes).
    r : int or array_like
        Tensor-ring ranks. The original convention is an array whose first
        row holds the N ranks (indexed ``r[0][k]``). A flat length-N sequence
        also works, and so does a single int (the same rank everywhere).

    Returns
    -------
    list of ndarray
        N cores. Core ``k`` has shape ``(r_k, S[k], r_{k+1})``, and the last
        core wraps around to ``r_0`` so the ring closes.
    """
    N = len(S)
    # ravel() reads the first row of a 2-D rank array exactly like r[0][k].
    ranks = np.ravel(r)
    if ranks.size == 1:
        ranks = np.repeat(ranks, N)
    return [np.random.randn(int(ranks[k]), S[k], int(ranks[(k + 1) % N]))
            for k in range(N)]
def TRLRF(data, W, r, maxiter, mu, ro, Lamda, tol, mu_max=10 ** 2):
    """Tensor-ring low-rank-factor (TRLRF) completion, solved ADMM-style.

    Parameters
    ----------
    data : ndarray
        Full tensor. Only entries with ``W == 1`` are used for fitting. The
        unobserved, nonzero entries are used only to print the RSE each
        iteration.
    W : ndarray
        Observation mask with the same shape as ``data`` (1 = observed).
    r : array_like
        Tensor-ring ranks, as passed to ``TR_initcoreten``
        (e.g. a (1, N) integer array).
    maxiter : int
        Maximum number of iterations.
    mu : float
        Initial penalty parameter. It is multiplied by ``ro`` every iteration
        and capped at ``mu_max``.
    ro : float
        Growth factor for ``mu``.
    Lamda : float
        Weight of the data-fitting term.
    tol : float
        Stop once ``||X - X_prev||_F / ||X||_F < tol``.
    mu_max : float, optional
        Upper bound for ``mu`` (default 100).

    Returns
    -------
    ndarray
        Completed tensor. Observed entries are copied from ``data``; the rest
        come from the tensor-ring model.
    """
    Truth_data = data.copy()
    T = data * W
    S = T.shape
    N = len(S)
    X = np.random.rand(*S)
    G = TR_initcoreten(S, r)
    # Every core is a 3-way array. Each core therefore gets exactly three
    # auxiliary variables M[n][j] and multipliers Y[n][j] (one per core mode),
    # independent of N; Msum_fun and the j-loops below rely on this.
    M = [[np.zeros(G[n].shape) for _ in range(3)] for n in range(N)]
    Y = [[np.sign(G[n]) for _ in range(3)] for n in range(N)]

    it = 0
    while it < maxiter:
        it += 1
        for n in range(N):
            Msum = Msum_fun(M)
            Ysum = Msum_fun(Y)
            # Update core n. Its mode-1 unfolding G1 solves
            #   G1 @ (Lamda Q Q^T + 3 mu I) = Lamda X_<n> Q^T + mu sum_j M_j + sum_j Y_j
            Q = tenmat_sb(Z_neq(G, n), 1).T
            rhs = (Lamda * tenmat_sb(X, n) @ Q.T
                   + mu * Gunfold(Msum[n], 1)
                   + Gunfold(Ysum[n], 1))
            gram = Lamda * (Q @ Q.T) + 3 * mu * np.eye(Q.shape[0], Q.shape[0])
            G[n] = Gfold(rhs @ linalg.pinv(gram), G[n].shape, 1)
            # Update the auxiliaries: singular value thresholding of each
            # mode-j unfolding of G[n] - Y[n][j] / mu.
            for j in range(3):
                Df = Gunfold(G[n] - Y[n][j] / mu, j)
                M[n][j] = Gfold(Pro2TraceNorm(Df, 1 / mu), G[n].shape, j)

        lastX = X
        X_hat = coreten2tr(G)
        # X and X_hat are the same array. Overwriting observed entries does not
        # affect the RSE below, which only looks at W == 0.
        X = X_hat
        X[W == 1] = T[W == 1]

        for n in range(N):
            for j in range(3):
                Y[n][j] = Y[n][j] + mu * (M[n][j] - G[n])
        mu = min(mu * ro, mu_max)

        print('Iter: {}'.format(it))
        RSE_fun(Truth_data, X_hat, W)
        print()

        # Stop once the recovered tensor has (relatively) stopped changing.
        if np.linalg.norm(X - lastX) / np.linalg.norm(X) < tol:
            break
    return X
# Demo: recover an image with 80% of its pixels missing at random. (下面是对于随机缺失80%图片的恢复代码)
if __name__ == "__main__":
    # Demo: hide ~80% of the pixels of an image at random, then recover them.
    # NOTE(review): machine-specific absolute path; adjust before running.
    image = imread(r'C:\Users\wh\Desktop\TRLRF\TRLRF-Python\lena.bmp')
    X = np.array(image / 255)
    mr = 0.8                               # missing rate
    W = gen_W(X.shape, mr)
    r = 8 * np.ones((1, 3), dtype=int)     # tensor-ring ranks
    maxiter = 1000
    tol = 1e-8
    Lambda = 5
    ro = 1.1
    mu = 1e0
    X_hat = TRLRF(X, W, r, maxiter, mu, ro, Lambda, tol)

    fig = plt.figure(figsize=(15, 8))
    panels = (
        ("original", image),
        ("observed", np.uint8(W * image)),
        # Clip first: values outside [0, 1] would wrap around when cast to uint8.
        ("recovered", np.uint8(255 * np.clip(X_hat, 0, 1))),
    )
    for idx, (title, img) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, idx)
        ax.imshow(img)
        ax.set_title(title)
    plt.show()