from dataset import GazeCaptureDataset
from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig
# Root directory of the GazeCapture dataset.
root_path = r"D:\datasets\GazeCapture_new"

# 1. Build the test split of the dataset.
test_dataset = GazeCaptureDataset(root_path, data_type='test')

# 2. Load the fine-tuned DeiT image model.
# Head settings this experiment expects: two outputs, treated as regression.
# NOTE(review): the original code never applied this object to the loaded
# model either; it is kept so the module-level name stays available.
configuration = DeiTConfig(num_labels=2, problem_type="regression")
# `from_pretrained` is a classmethod, so it is called on the class directly.
# The original instantiated a randomly initialised model from `configuration`
# first and then discarded it, which only cost time and memory.
model = DeiTForImageClassification.from_pretrained('gaze_trainer/checkpoint-500')

# 3. Testing
## 3.1 Arguments for the test run (results are written under "pred_trainer").
testing_args = TrainingArguments(output_dir="pred_trainer")
## 3.2 Custom Trainer
class CustomTester(Trainer):
    """Trainer whose loss is the Smooth L1 loss between logits and labels."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the Smooth L1 loss for one batch.

        Args:
            model: The model to run; called with the batch's pixel values.
            inputs: Batch mapping providing "labels" and "pixel_values".
            return_outputs: When true, return ``(loss, outputs)`` instead of
                only the loss.
            num_items_in_batch: Accepted so the override stays call-compatible
                with Trainer versions that pass this keyword; it is not used
                here.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple when ``return_outputs``
            is true.
        """
        # Regression targets for this batch.
        labels = inputs.get("labels")
        # Image tensor fed to the model.
        pixel_values = inputs.get("pixel_values")
        # Forward pass; labels are deliberately not passed to the model so the
        # loss is computed here rather than by the model itself.
        outputs = model(pixel_values)
        logits = outputs.get('logits')
        # Smooth L1 loss between predictions and targets.
        loss_fct = nn.SmoothL1Loss()
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss
## 3.3 Create the tester from the loaded model and the test arguments.
tester = CustomTester(model=model, args=testing_args)

## 3.4 Run prediction over the test dataset.
output = tester.predict(test_dataset=test_dataset)

# 4. Print the prediction output.
print(output)
dataset.py
代码如下:
import os.path
from torch.utils.data import Dataset
from transform import transform
import numpy as np
def get_label_list(label_path):
    """Collect the data lines of every label file in a directory.

    Each file's first line is treated as a header and dropped. The remaining
    lines are stripped of surrounding whitespace and blank lines are skipped.
    No shuffling is performed.

    Args:
        label_path: Directory containing the label files,
            e.g. ``D:\\datasets\\GazeCapture_new\\Label\\train``.

    Returns:
        A flat list of stripped, non-empty label lines from all files,
        with files visited in sorted name order.
    """
    full_lines = []
    # os.listdir returns names in arbitrary order; sort so the sample order
    # is reproducible across runs and platforms.
    for label_name in sorted(os.listdir(label_path)):
        # Full path of the label file, e.g. ...\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        with open(label_abs_path) as flist:
            # Drop the header line
            # ('Face Left Right Grid Xcam, Ycam Xdot, Ydot Device').
            # The default keeps an empty file from raising.
            next(flist, None)
            stripped_lines = (line.strip() for line in flist)
            # Blank lines carry no sample and would break field parsing later.
            full_lines.extend(line for line in stripped_lines if line)
    return full_lines
class GazeCaptureDataset(Dataset):
    """Dataset yielding a transformed face image and its label vector.

    Every item is a dict of the form
    ``{"labels": ..., "pixel_values": ...}``.
    """

    def __init__(self, root_path, data_type):
        """Load all label lines for one split.

        Args:
            root_path: Dataset root directory.
            data_type: Sub-directory of ``Label`` to read, e.g. ``'test'``.
        """
        self.data_dir = root_path
        # Directory holding this split's label files,
        # e.g. D:\datasets\GazeCapture_new\Label\train
        label_dir = os.path.join(root_path + '/Label', data_type)
        # One entry per sample, gathered from every label file.
        self.full_lines = get_label_list(label_dir)
        # Separator between the fields of a label line.
        self.delimiter = ' '
        # Total number of samples (one image per label line).
        self.num_samples = len(self.full_lines)

    def __len__(self):
        """Return the number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a labels / pixel_values dict."""
        # A label line must split into exactly seven fields; only the face
        # image path and the XYcam pair are used below.
        fields = self.full_lines[idx].split(self.delimiter)
        face_rel, _left, _right, _grid, xy_cam, _xy_dot, _device = fields

        # Read the raw bytes of the face image.
        image_path = os.path.join(self.data_dir + '/Image/', face_rel)
        with open(image_path, 'rb') as image_file:
            raw_image = image_file.read()

        # Hand the raw bytes to the project's `transform` helper.
        pixel_values = transform(raw_image)

        # Comma-separated XYcam values become a float32 label vector.
        labels = np.array(xy_cam.split(","), np.float32)

        # The result must have the shape {"labels": ..., "pixel_values": ...}.
        return {"labels": labels, "pixel_values": pixel_values}
输出结果如下:
***** Running Prediction *****
Num examples = 1716
Batch size = 8
100%|██████████| 215/215 [01:52<00:00, 1.90it/s]
PredictionOutput(predictions=array([[-2.309026 , -2.752627 ],
[-2.0178156, -3.0546618],
[-1.8222798, -3.309564 ],
...,
[-2.6463585, -2.3462727],
[-2.2149038, -2.7406967],
[-1.7267275, -3.3450181]], dtype=float32), label_ids=array([[ 0.969375, -7.525975],
[ 0.969375, -7.525975],
[ 0.969375, -7.525975],
...,
[ 5.5845 , 1.93875 ],
[ 5.5845 , 1.93875 ],
[ 5.5845 , 1.93875 ]], dtype=float32), metrics={'test_loss': 2.8067691326141357, 'test_runtime': 118.2811, 'test_samples_per_second': 14.508, 'test_steps_per_second': 1.818})
可以看到,该模型在测试集上的损失值是 2.8067691326141357。