以下内容可直接写入一个 .ipynb 文件，并放在服务器上任一已准备好数据集的文件夹下运行：
import os
import json
import re
from collections import defaultdict
def draw(dic):
    """Plot a bar chart of a sample-length distribution.

    Args:
        dic: Mapping of sample length -> number of samples with that length
            (the script below passes the merged length-distribution dict).
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    # Sort the keys so the plotted data is in ascending-length order.
    index = sorted(dic)
    data = [dic[key] for key in index]
    bar_width = 0.2  # width of each bar
    ax.bar(index, data, bar_width, label='length')
    ax.set_xlabel("sample length")
    ax.set_ylabel("count")
    ax.legend()
    fig.tight_layout()
    plt.show()
# Merge dictionaries
def merge(dic1, dic2):
    """Add the values of ``dic2`` into ``dic1`` in place and return ``dic1``.

    Keys present in both dicts get their values combined with ``+=``; keys
    found only in ``dic2`` are copied into ``dic1``. ``dic2`` is not modified.
    """
    for key, value in dic2.items():
        if key in dic1:
            dic1[key] += value
        else:
            dic1[key] = value
    return dic1
# Sample statistics: sample count, total sample length, and length distribution.

# English contractions split off as separate tokens when clean_english=True.
_CONTRACTIONS = ("'s", "'ve", "n't", "'re", "'d", "'ll")


def _clean_english_text(text):
    """Split English contractions off and replace every non-letter with a space."""
    for suffix in _CONTRACTIONS:
        text = text.replace(suffix, " " + suffix)
    return re.sub(r"[^A-Za-z]", " ", text)


def fun(dataset, path, clean_english=False):
    """Compute token-length statistics for one split of a JSONL dataset.

    Reads ``<dataset>/<path>.jsonl``. Every non-blank line must be a JSON
    object with a ``"text"`` field. A sample's length is the number of
    whitespace-separated tokens in that text.

    Args:
        dataset: Dataset directory.
        path: Split file name without the ``.jsonl`` extension (e.g. "train").
        clean_english: If True, split English contractions and replace every
            character outside A-Z/a-z with a space before counting tokens.

    Returns:
        (samples_num, samples_lens, samples_len_dic): the number of samples,
        the total token count, and a ``defaultdict(int)`` mapping
        length -> number of samples with that length.
    """
    samples_num = 0
    samples_lens = 0
    samples_len_dic = defaultdict(int)
    file_path = os.path.join(dataset, path + ".jsonl")
    with open(file_path, "r", encoding="utf-8") as f:
        # Stream line by line instead of readlines() to keep memory flat.
        for line in f:
            # Blank lines (e.g. a trailing newline) are not records; json.loads
            # would raise on them.
            if not line.strip():
                continue
            text = json.loads(line)["text"]
            if clean_english:
                text = _clean_english_text(text)
            length = len(text.split())
            samples_num += 1
            samples_len_dic[length] += 1
            samples_lens += length
    return samples_num, samples_lens, samples_len_dic
# For each dataset: aggregate statistics over its train/dev/test splits,
# then print them and plot the length distribution.
for dataset in ["dataset_name1", "dataset_name2"]:
    print(dataset)
    samples_num, samples_lens, samples_len_dic = 0, 0, defaultdict(int)
    for path in ["train", "dev", "test"]:
        samples_num1, samples_lens1, samples_len_dic1 = fun(dataset, path)
        samples_num += samples_num1
        samples_lens += samples_lens1
        samples_len_dic = merge(samples_len_dic, samples_len_dic1)
    # Sanity check: the length distribution must account for every sample.
    count = sum(samples_len_dic.values())
    if count != samples_num:
        print("False: 长度分布总计", count, "!= 样本数量", samples_num)
    print("样本数量:", samples_num)
    if samples_num == 0:
        # Nothing to average or plot; avoids ZeroDivisionError below.
        continue
    print("样本长度分布数值:", sorted(samples_len_dic.items(), key=lambda x: x[0]))
    print("样本长度平均值:", samples_lens / samples_num)
    print("样本长度分布图表:")
    draw(samples_len_dic)
from PIL import Image
import matplotlib.pyplot as plt
def print_out(dataset, sample, scale=0.5):
    """Print a sample's text and display its image, resized by ``scale``.

    Args:
        dataset: Dataset directory; images are read from ``<dataset>/img/``.
        sample: One sample as a dict with ``"text"`` and ``"image"`` (image
            file name) keys.
        scale: Resize factor applied to both image dimensions (default 0.5,
            i.e. shrink by 50%).
    """
    print("原文本:", sample["text"])
    print("原图片:")
    img_path = os.path.join(dataset, "img", sample["image"])
    # Context manager closes the underlying file; resize() returns an
    # independent image, so it stays valid after the file is closed.
    with Image.open(img_path) as img:
        width, height = img.size
        # Clamp to at least 1 px: resize() rejects zero-sized targets.
        new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized = img.resize(new_size, Image.LANCZOS)
    # NOTE(review): `display` is the IPython/Jupyter builtin; outside a
    # notebook it must be imported from IPython.display.
    display(resized)