# Source: https://www.cnblogs.com/Terrypython/p/9577657.html
import numpy as np
def FindImageBBox(img):
    """Segment glyphs in a binarized image by scanning its column sums.

    Walks the vertical projection (per-column pixel sum) left to right.
    A column whose sum exceeds ``min_val`` counts as ink and opens (or
    extends) the current glyph; once a glyph is at least ``min_span``
    columns wide, a background run longer than 2 columns closes it, and
    reaching ``max_span`` columns closes it unconditionally.  Each span
    is padded: up to 5 columns on the left, 2 on the right.

    Returns a list of ``(start_col, end_col)`` tuples.  A glyph still
    open when the last column is reached is not emitted.
    """
    col_sums = np.sum(img, axis=0)
    min_span = 10   # minimum glyph width before a gap may close it
    max_span = 20   # force-close a glyph once it reaches this width
    min_val = 10    # column sums at or below this count as background
    spans = []
    left = None     # start column of the glyph currently being scanned
    gap = 0         # length of the current run of background columns
    for col, total in enumerate(col_sums):
        if total > min_val:
            # Ink column: open a glyph if none is active, reset the gap.
            if left is None:
                left = col
            gap = 0
        else:
            gap += 1
            if left is not None and (
                (col - left >= min_span and gap > 2) or col - left >= max_span
            ):
                # Pad the left edge by 5 columns when there is room.
                start = left - 5 if left > 5 else left
                spans.append((start, col + 2))
                left = None
    return spans
# Script section: read the source image, crop a fixed region, binarize it
# inverted (glyphs become high-valued on a dark background, which is what
# FindImageBBox's column-sum scan expects), segment it into characters and
# write each character out as its own file.
# NOTE(review): `now_images`, `path1` and `name_1` must be defined earlier
# in the full script; they are not visible in this chunk.  `cv2` and `os`
# are also not imported here — confirm against the full file.
image = cv2.imread(now_images, cv2.IMREAD_GRAYSCALE)
cropped1 = image[345:384, 0:115]  # fixed region of interest: rows 345-383, cols 0-114
ret, image1 = cv2.threshold(cropped1, 127, 255, cv2.THRESH_BINARY_INV)
box = FindImageBBox(image1)
for l,i in enumerate(box):
    # Crop coordinates are [y0:y1, x0:x1]; cut from the un-thresholded crop.
    cropped2 = cropped1[0:39, i[0]:i[1]]
    cv2.imwrite(os.path.join(path1,f"{name_1}_small{l}.jpg"), cropped2)
class SiameseNetworkDataset(Dataset):
    """Dataset yielding (left-half, right-half, label) image pairs.

    Every file under ``imageFolderDataset`` is assumed to be a 120x60
    side-by-side pair image whose label is encoded after the final
    underscore of the filename (e.g. ``xxx_3.jpg`` -> label 3).
    """

    def __init__(self ,imageFolderDataset ,transform=None ,should_invert=True):
        folder = imageFolderDataset
        self.imageFolderDataset = [
            os.path.join(folder, name) for name in os.listdir(folder)
        ]
        self.transform = transform
        self.should_invert = should_invert

    def __getitem__(self ,index):
        # NOTE(review): the index is ignored; a file is drawn at random.
        path = random.choice(self.imageFolderDataset)
        pair = Image.open(path).resize((120, 60))
        # Split the 120x60 pair image into two 60x60 grayscale halves.
        left = pair.crop((0, 0, 60, 60)).convert("L")
        right = pair.crop((60, 0, 120, 60)).convert("L")
        label = int(path.split('_')[-1].replace('.jpg', ''))
        if self.should_invert:
            left = PIL.ImageOps.invert(left)
            right = PIL.ImageOps.invert(right)
        if self.transform is not None:
            # NOTE(review): random transforms are applied independently to
            # each half, so a random flip can desynchronise the pair.
            left = self.transform(left)
            right = self.transform(right)
        return left, right, torch.from_numpy(np.array([label], dtype=np.float32))

    def __len__(self):
        return len(self.imageFolderDataset)
# Preprocessing / augmentation pipeline applied to each image by the
# dataset: random vertical + horizontal flips, resize to 100x100 (the
# input size SiameseNetwork's fc1 layer assumes), then tensor conversion.
transform = transforms.Compose(
    [
        # transforms.RandomCrop((40,40)),
        # transforms.ColorJitter,
        transforms.RandomVerticalFlip(),
        # transforms.RandomCrop(),
        transforms.RandomHorizontalFlip(),
        transforms.Resize((100 ,100)),
        transforms.ToTensor()])
# Build the model
class SiameseNetwork(nn.Module):
    """Twin-branch CNN: both inputs share one set of weights.

    Each branch maps a 1x100x100 input to a 5-dim embedding: three
    pad/conv/ReLU/BatchNorm stages (spatial size preserved by the
    reflection padding) followed by a three-layer MLP head on the
    flattened 8x100x100 feature map.
    """

    # (in_channels, out_channels) for the three convolutional stages.
    _CHANNELS = ((1, 4), (4, 8), (8, 8))

    def __init__(self):
        super().__init__()
        stages = []
        for cin, cout in self._CHANNELS:
            stages += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(cin, cout, kernel_size=3),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(cout),
            ]
        self.cnn1 = nn.Sequential(*stages)
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 5))

    def forward_once(self, x):
        """Embed one batch of images with the shared branch."""
        features = self.cnn1(x)
        flat = features.view(features.size()[0], -1)
        return self.fc1(flat)

    def forward(self, input1, input2):
        """Embed both inputs with the same weights; return both embeddings."""
        return self.forward_once(input1), self.forward_once(input2)
# With the network built we need a way to judge its output — in essence,
# computing a similarity score between the two embeddings.
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Pulls similar pairs (label 0) together by penalising their Euclidean
    distance squared, and pushes dissimilar pairs (label 1) apart when
    they fall inside ``margin``.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = F.pairwise_distance(output1, output2, keepdim=True)
        # label == 0: penalise distance; label == 1: penalise margin shortfall.
        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return torch.mean(similar_term + dissimilar_term)
# Final step: wire everything together, train, and wait for the results.
def train():
    """Training loop: contrastive loss over pair batches, on the GPU.

    Relies on module-level globals defined elsewhere in the script:
    ``train_number_epochs``, ``train_dataloader``, ``net``, ``optimizer``,
    ``criterion``, ``counter``, ``loss_history``, ``iteration_number``
    and ``show_plot`` — none of them are visible in this chunk.
    """
    global iteration_number
    for epoch in range(0, train_number_epochs):
        for i,data in enumerate(train_dataloader,0):
            img0, img1, label = data
            # img0 is torch.Size([32, 1, 100, 100]) (32 = batch size); label is torch.Size([32, 1])
            img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()  # move the batch to the GPU
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            # Sample the loss every 10 batches for the plot.
            if i % 10 == 0:
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
        # Checkpoint (overwriting the same file) and log every 20 epochs.
        if epoch%20 ==0 and epoch!=0:
            torch.save(net.state_dict(), f'siamese.pth')
            print("Epoch number: {} , Current loss: {:.4f}".format(epoch, loss_contrastive.item()))
    show_plot(counter, loss_history)
# That's all — pretty simple, right?