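# (略) — setup omitted in the original post; presumably the imports used below
# (os, numpy as np, PIL.Image, paddle, paddle.io Dataset/DataLoader,
# paddle.vision.transforms as T) — TODO confirm.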
!unzip -qo data/data115112/ChineseStyle.zip -d data
label = paddle.to_tensor([1,0], dtype='int64')
one_hot_label =paddle.nn.functional.one_hot(label,num_classes=2)
print(one_hot_label )
# Dataset of calligraphy images labelled by font style.
class ChineseStyleDataset(Dataset):
    """In-memory dataset of calligraphy images labelled by font style.

    Images are read eagerly from ``data/ChineseStyle/<train>/<font_name>/``
    for every font in ``FONT_STYLES``; label 0 = "lishu", 1 = "xingkai".

    Args:
        transforms: optional callable applied to each image (an HWC float32
            RGB numpy array) in ``__getitem__``.
        train: name of the split sub-directory, e.g. "train" or "test".
    """

    # (font directory name, integer class label)
    FONT_STYLES = (("lishu", 0), ("xingkai", 1))

    def __init__(self, transforms=None, train="train"):
        super().__init__()
        self.transforms = transforms
        self.datas = []   # decoded images, HWC float32 RGB arrays
        self.labels = []  # matching int64 class labels (0-d numpy arrays)
        for font_name, font_val in self.FONT_STYLES:
            font_path = os.path.join("data/ChineseStyle", train, font_name)
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for filename in sorted(os.listdir(font_path)):
                # Jupyter creates .ipynb_checkpoints entries inside data folders; skip them.
                if ".ipynb_checkpoints" in filename:
                    continue
                im = self._load_image(os.path.join(font_path, filename))
                if im.size == 0:  # guard against empty images
                    continue
                self.datas.append(im)
                self.labels.append(np.array(font_val, dtype="int64"))

    @staticmethod
    def _load_image(img_path):
        """Read an image file as an HWC float32 RGB array, closing the file afterwards.

        Converting to RGB guarantees 3 channels, which the 3-channel
        normalisation and the network's first Conv2D require.
        """
        with Image.open(img_path) as photo:
            return np.array(photo.convert("RGB"), dtype="float32")

    def __getitem__(self, index):
        """Return ``(image, label)`` for ``index``, applying ``transforms`` if set."""
        data = self.datas[index]
        if self.transforms is not None:
            data = self.transforms(data)
        label = self.labels[index]
        return data, label

    def __len__(self):
        """Number of loaded samples."""
        return len(self.labels)
# Preprocessing pipeline shared by the training and test sets:
#   1. resize every HWC image to 227x227 (the input size the network summary uses),
#   2. scale pixel values to [0, 1] via (x - 0) / 255 per channel (assumes 8-bit images),
#   3. convert the HWC numpy array to a CHW paddle Tensor.
transforms = T.Compose([
    T.Resize([227, 227]),
    T.Normalize(mean=[0, 0, 0], std=[255, 255, 255], data_format='HWC'),
    T.ToTensor(),
])

# Dataset instances: `train` selects the split sub-directory ("train" / "test").
train_dataset = ChineseStyleDataset(transforms=transforms)
test_dataset = ChineseStyleDataset(transforms=transforms, train="test")

# Only the training data needs shuffling; evaluation accuracy does not depend
# on order, and a fixed order keeps evaluation runs reproducible.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=200)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False)
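# Debug: inspect the first preprocessed training image.
# np.set_printoptions(threshold=np.inf)
# for id, data in enumerate(train_dataloader):
#     print(np.array(data[0][0]))
#     plt.imshow(np.array(data[0][0]).transpose(1, 2, 0))  # CHW -> HWC for imshow
#     break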
class AlexNet(paddle.nn.Layer):
    """AlexNet-style CNN that classifies a calligraphy image's font style.

    Input: float tensor of shape [N, 3, 227, 227]. The five convolution
    stages reduce it to a [N, 256, 6, 6] feature map; the fully connected
    head maps that to ``num_classes`` raw logits.

    Args:
        num_classes: size of the output layer. Defaults to 2
            (label 0 = lishu, 1 = xingkai).
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # NOTE: attribute names (conv_pool1..5, full_con) are kept even though
        # stages 3 and 4 do no pooling: they become the parameter names in
        # saved checkpoints (e.g. "mymodel/AlexNet"), so renaming breaks loading.

        # [3, 227, 227] -> conv 11x11 / 4 -> [96, 55, 55] -> pool 3x3 / 2 -> [96, 27, 27]
        self.conv_pool1 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=3, out_channels=96,
                             kernel_size=[11, 11], stride=4, padding="valid"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2),
        )
        # -> conv 5x5 ('same') -> [256, 27, 27] -> pool 3x3 / 2 -> [256, 13, 13]
        self.conv_pool2 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=96, out_channels=256,
                             kernel_size=[5, 5], stride=1, padding="same"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2),
        )
        # -> conv 3x3 ('same') -> [384, 13, 13]
        self.conv_pool3 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=256, out_channels=384,
                             kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU(),
        )
        # -> conv 3x3 ('same') -> [384, 13, 13]
        self.conv_pool4 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=384, out_channels=384,
                             kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU(),
        )
        # -> conv 3x3 ('same') -> [256, 13, 13] -> pool 3x3 / 2 -> [256, 6, 6]
        self.conv_pool5 = paddle.nn.Sequential(
            paddle.nn.Conv2D(in_channels=384, out_channels=256,
                             kernel_size=[3, 3], stride=1, padding="same"),
            paddle.nn.ReLU(),
            paddle.nn.MaxPool2D(kernel_size=[3, 3], stride=2),
        )
        # Classifier head: 256*6*6 features -> 4096 -> 4096 -> num_classes logits.
        self.full_con = paddle.nn.Sequential(
            paddle.nn.Linear(in_features=256 * 6 * 6, out_features=4096),
            paddle.nn.ReLU(),
            paddle.nn.Dropout(0.5),
            paddle.nn.Linear(in_features=4096, out_features=4096),
            paddle.nn.ReLU(),
            paddle.nn.Dropout(0.5),
            paddle.nn.Linear(in_features=4096, out_features=num_classes),
        )
        self.flatten = paddle.nn.Flatten()

    def forward(self, x):
        """Return unnormalised class logits of shape [N, num_classes].

        No sigmoid/softmax is applied here: paddle.nn.CrossEntropyLoss applies
        softmax itself. Squashing the logits into (0, 1) first would cap the
        predicted probability at about 0.73 for two classes and weaken the
        gradients. argmax over the logits gives the predicted class.
        """
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = self.conv_pool3(x)
        x = self.conv_pool4(x)
        x = self.conv_pool5(x)
        x = self.flatten(x)
        return self.full_con(x)
# Build the network and print a per-layer summary for one 3x227x227 input.
alexNet = AlexNet()
paddle.summary(alexNet, (1, 3, 227, 227))

# Wrap the network in the high-level paddle.Model API, then configure
# optimizer, loss and metric.
model = paddle.Model(alexNet)
model.prepare(
    optimizer=paddle.optimizer.Adam(
        learning_rate=0.001,
        parameters=model.parameters(),
    ),
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(),
)

# Train with VisualDL logging to "log_dir", evaluate on the test split,
# and save the weights under mymodel/AlexNet.
vsDL = paddle.callbacks.VisualDL("log_dir")
model.fit(
    train_data=train_dataloader,
    epochs=10,
    verbose=1,
    callbacks=[vsDL],
)
model.evaluate(eval_data=test_dataloader, verbose=1)
model.save("mymodel/AlexNet")
# Single-image inference with the saved weights.
# The image must go through the same preprocessing as the training data
# (resize to 227x227, scale to [0, 1], HWC -> CHW). Reshaping the HWC array
# straight to [1, 3, H, W] would scramble the pixels rather than move the
# channel axis, and would skip the resize and scaling.
with Image.open("data/ChineseStyle/train/xingkai/xingkai_1001.jpg") as photo:
    im = np.array(photo.convert("RGB"), dtype="float32")
im = np.expand_dims(transforms(im).numpy(), axis=0)  # add batch dim -> [1, 3, 227, 227]

alexNet = AlexNet()
model = paddle.Model(alexNet)
model.load("mymodel/AlexNet")
result = model.predict_batch(im)  # list with one array of shape [1, num_classes]
print(result)
# Class indices follow the dataset labels: 0 = lishu, 1 = xingkai.
print("predicted:", ["lishu", "xingkai"][int(np.argmax(result[0], axis=-1)[0])])
# 运行结果 (run output):