paper: Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
code: official
tf2
torch
I went through the PyTorch version and added a few comments along the way.
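For reference, the core trick from the paper: instead of feeding the raw 2-D pixel coordinate v to the MLP, feed the random Fourier feature mapping γ(v) = [sin(2πBv), cos(2πBv)], where B is a fixed random matrix with entries drawn from N(0, σ²). The scale σ (10 in the code below) controls how much high-frequency detail the network can fit.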
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm as tqdm
import os, imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Download image, take a square crop from the center
image_url = 'https://live.staticflickr.com/7492/15677707699_d9d67acf9d_b.jpg'
img = imageio.imread(image_url)[..., :3] / 255.
c = [img.shape[0]//2, img.shape[1]//2]  # center pixel of the image
r = 256
img = img[c[0]-r:c[0]+r, c[1]-r:c[1]+r]  # crop to a 512x512x3 square
plt.imshow(img)
plt.show()
# Create input pixel coordinates in the unit square
coords = np.linspace(0, 1, img.shape[0], endpoint=False)
x_test = np.stack(np.meshgrid(coords, coords), -1)  # (512, 512, 2); [..., 0] is the x coordinate, [..., 1] is the y coordinate
test_data = [x_test, img]
train_data = [x_test[::2,::2], img[::2,::2]]  # subsample with stride 2 along the first two dims: 256x256x2 coords and the matching 256x256x3 image
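A quick shape check (illustrative, not in the original):

print(x_test.shape, img.shape)                   # (512, 512, 2) (512, 512, 3)
print(train_data[0].shape, train_data[1].shape)  # (256, 256, 2) (256, 256, 3)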
class MLP(nn.Module):
    def __init__(self, depth=4, mapping_size=512, hidden_size=256):
        super().__init__()
        layers = []
        layers.append(nn.Linear(mapping_size, hidden_size))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(hidden_size, 3))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return torch.sigmoid(self.layers(x))
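A quick sanity check of the network (a sketch, not in the original): an untrained MLP should turn a batch of mapping_size-dim features into RGB values squashed into (0, 1) by the sigmoid.

m = MLP()                     # defaults: depth=4, mapping_size=512, hidden_size=256
dummy = torch.randn(8, 512)   # 8 fake feature vectors
out = m(dummy)
print(out.shape)              # torch.Size([8, 3])
print(out.min().item() > 0, out.max().item() < 1)  # True True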
xb,yb = torch.tensor(train_data[0]).reshape(-1,2),torch.tensor(train_data[1]).reshape(-1,3)
x_test,y_test = torch.tensor(test_data[0]).reshape(-1,2),torch.tensor(test_data[1]).reshape(-1,3)
xb, yb, x_test, y_test = xb.float().cuda(), yb.float().cuda(), x_test.float(), y_test.float()  # training tensors go to the GPU; the test set stays on the CPU for the eval below
def map_x(x, B):  # x: pixel coordinates, B: the fixed random Gaussian projection matrix
    xp = torch.matmul(2*math.pi*x, B)
    return torch.cat([torch.sin(xp), torch.cos(xp)], dim=-1)  # full Fourier features: sin and cos
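Dimension check (illustrative): a (2, 256) matrix B turns (N, 2) coordinates into (N, 512) features, since sin and cos each contribute 256 channels; that is why the first model below keeps the default mapping_size=512.

B_demo = torch.randn(2, 256)
feats = map_x(torch.rand(4, 2), B_demo)
print(feats.shape)  # torch.Size([4, 512])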
model = MLP().cuda()# fc-relu-fc-relu-fc-relu-fc
opt = torch.optim.Adam(model.parameters(),lr=1e-4)
loss = nn.MSELoss()
B = torch.randn(2, 256).cuda() * 10  # B ~ N(0, 10^2): 256 random frequencies with sigma=10
xt = map_x(xb,B)
for i in tqdm(range(2000)):
    ypred = model(xt)
    l = loss(ypred, yb)
    opt.zero_grad()
    l.backward()
    opt.step()
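The original only renders the later variants; an equivalent prediction pass for this sin+cos model, mirroring the "Preds" blocks further down (the ypreds_sincos name is mine), would be:

model.cpu().eval()
with torch.no_grad():
    ypreds_sincos = model(map_x(x_test, B.cpu())).reshape(512, 512, 3)
plt.imshow(ypreds_sincos)
plt.show()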
def map_x(x, B):  # variant 2: sin only, dropping the cos term
    xp = torch.matmul(2*math.pi*x, B)
    return torch.sin(xp)
model = MLP(mapping_size=256).cuda()
opt = torch.optim.Adam(model.parameters(),lr=1e-4)
loss = nn.MSELoss()
B = torch.randn(2,256).cuda() * 10
xt = map_x(xb,B)
for i in tqdm(range(2000)):
    ypred = model(xt)
    l = loss(ypred, yb)
    opt.zero_grad()
    l.backward()
    opt.step()
# Preds
model.cpu().eval()
with torch.no_grad():
    ypreds = model(map_x(x_test, B.cpu()))
ypreds = ypreds.reshape(512, 512, 3)
plt.imshow(ypreds)
plt.show()
plt.hist(xt[:, 0].cpu())  # distribution of the first feature channel
plt.show()
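The paper reports reconstruction quality as PSNR; a minimal sketch of that metric for these predictions (not in the original; assumes the standard 10·log10(1/MSE) definition for images in [0, 1]):

with torch.no_grad():
    mse = F.mse_loss(ypreds.reshape(-1, 3), y_test)
print(f'PSNR: {(10 * torch.log10(1.0 / mse)).item():.2f} dB')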
def map_x(x, B):  # variant 3: no Fourier mapping at all, just a linear projection of the raw coords
    xp = torch.matmul(x, B)
    return xp
model = MLP(mapping_size=512).cuda()
opt = torch.optim.Adam(model.parameters(),lr=1e-4)
loss = nn.MSELoss()
B = torch.randn(2,512).cuda() * 10
xt = map_x(xb,B)
for i in tqdm(range(2000)):
    ypred = model(xt)
    l = loss(ypred, yb)
    opt.zero_grad()
    l.backward()
    opt.step()
# Preds
model.cpu().eval()
with torch.no_grad():
    ypreds = model(map_x(x_test, B.cpu()))
ypreds = ypreds.reshape(512, 512, 3)
plt.imshow(ypreds)
plt.show()
plt.hist(xt[:, 0].cpu())
plt.show()