pytorch之数据处理

pytorch之数据处理

对于深度学习来说,要用大量的训练数据训练深度神经网络。因此,学习深度学习的过程中,学习数据处理是不可避免的。本文简单记录了一下我最近学习到的数据处理的基础知识。本文将数据处理分为以下四个步骤来介绍:

读取数据集

这部分介绍了读取数据的基本流程,和几种常用格式的数据读取方法。

读取步骤

(1) 拼接路径
os.path.join()函数:连接两个或更多的路径名组件

1.如果各组件名首字母不包含’/’,则函数会自动加上

2.第一个以”/”开头的参数开始拼接,之前的参数全部丢弃,当有多个时,从最后一个开始

3.如果最后一个组件为空,则生成的路径以一个’/’分隔符结尾
(2) 打开文件
with open(filepath,‘r’) as f:
(3) 读取文件
这一部分介绍了一些常用格式(如txt、csv、xlsx文件)文件的读取方式。

  • 读取.txt文件
    使用python即可读取txt文件,有四种读取方法。
  1. 直接读取
with open(filepath,'r') as f:
	for line in f:
        print(type(line))
        print(line)
 
"""

一次一行
"""
  1. f.read()
    一次性把所有文本都读出来。读取的结果是str类型的一堆字符串。
with open(filepath,'r') as f:
    ff = f.read()
    print(type(ff))
    print(ff)
    
"""

一堆字符串,原文本什么样,读到的就是什么样。
"""
  1. f.readlines()
    一次性读出所有内容,但是得到的是一个list列表。
with open(filepath,'r') as f:
    lines = f.readlines()
    print(type(lines))
    print(lines[1])
  
"""

列表中的第二个元素
"""
  1. f.readline()
    一次读取一行,得到的是字符串。
with open(filepath,'r') as f:
	line = f.readline()
    while line:
        print(line)
        line = f.readline()
     
"""
一次一行
"""
  • 读取csv文件
    csv文件是以纯文本文件存储表格数据,即原表格文件中的数字类型数据、日期类型数据等在csv文件中一律都转化为字符型数据。csv文件中的一行代表一个数据记录,记录之间由换行符分割。记录由多个字段组成,字段之间由逗号或‘\t’分割。
  1. 使用csv包
import csv

with open(filepath,'r') as f:
    csv_f = csv.reader(f)
    for line in csv_f:
        print(line)
  1. 使用pandas包
import pandas as pd

with open(filepath,'r') as f:
    df = f.read_csv(f)
    print(df.shape)
	print(df['col1'])  # 读取一列
    print(df.iloc[3])  # 读取第三行
  • 读取xlsx文件
    xlsx文件是excel2007以后版本多出来的文件类型,可以存储多于65535行的大量数据。
  1. 使用openpyxl包
import os
from openpyxl import load_workbook

wb = load_workbook(filepath)
ws = wb['Sheet1']

# 输出第5行B列的值。
print(ws['B5'].value)

# 输出一个列表,这个列表里存储了A列的所有表格对象。注意是对象,儿不是一个数值。
# 例如(, , , ...)
print(ws['A'])

# 输出A列到C列的所有表格类,包括C列
print(ws['A:C'])

# 输出第5行所有单元格的值。
for cell in ws['5']:
    print(cell.value)
  1. 使用pandas包
import pandas as pd
import os

df = pd.read_excel(filepath)
print(df.shape)

参考:
pandas常用函数
pandas教程
openpyxl教程

处理数据

根据任务自定义数据处理方式

构建数据集类

在pytorch中,所有数据集类都要集成Dataset类,并且重写__len__()和__getitem__()方法。
DataLoader类将数据集对象与不同的取样器连用,对数据进行取样,批处理等操作。

from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import os

class ClimateClass(Dataset):
    def __init__(self, rootdir, filename):
        self.filepath = os.path.join(rootdir, filename)
        self.file = pd.read_csv(self.filepath)
    
    def __len__(self):
        return len(self.file)
    
    def __getitem__(self, index):
        return self.file.iat[index, 0], self.file.iat[index, 2]
   
rootdir = 'E:\datasets\jena_climate_2009_2016'
filename = 'jena_climate_2009_2016.csv'
mydata = ClimateClass(rootdir, filename)
for i in range(len(mydata)):
    time, temperature = mydata[i]
    print("{}:{}".formate(time, temperature))

   
dataloader = DataLoader(dataset=mydata, batch_size=5, shuffle=True)

# ibatch = __len__ / batch_size
# batch_data的值取决于__getitem__的返回值
for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)
    print(batch_data)

参考:
DataLoader教程

数据可视化

深度学习常用的可视化库为matplotlib库。

绘图步骤

绘制画板

有四种绘制方式:

  1. 先通过plt.figure()定义一个fig,再使用fig.add_subplot()绘制坐标系。

    import matplotlib as plt
    
    fig1 = plt.figure(num=1,figsize=(6,6),dpi=100)    # num唯一标识该画板
    ax1 = fig1.add_subplot(221) 
    ax1.set(xlim=[-2,2],ylim=[-2,4],title='example1',xlabel='x',ylabel='y')
    ax2 = fig1.add_subplot(224)
    ax1.set(xlim=[-2,2],ylim=[-2,4],title='example2',xlabel='x',ylabel='y')
    fig.show()
    
  2. 使用Figure和Axes类实例化。多用于绘制自定义排版的多个坐标图。

    import matplotlib as plt
    
    # 生成多个窗口
    fig1 = plt.figure(num=1,figsize=(6,6),dpi=100)
    fig2 = plt.figure(num=2,figsize=(8,8),dpi=100)
    
    # 激活fig1
    plt.figure(num=1)
    ax1 = plt.axes((0.1,0.2,0.4,0.4)) # 四元组参数代表坐标在画板上的位置,数值均为0~1的小数,代表百分比
    ax1.set(xlim=[-2,2],ylim=[-4,4],title='example1',xlabel='x',ylabel='y')
    ax2 = plt.axes((0.5,0.2,0.3,0.8))
    ax2.set(...)
    plt.show()
    
    
  3. 使用plt.subplot,省略了显示定义figure步骤。

    import matplotlib as plt
    
    ax1 = plt.subplot(221)
    ax2 = plt.subplot(223)
    plt.show()
    
  4. 使用plt.subplots,返回一个figure对象和一个坐标系对象数组。

import matplotlib as plt

fig,axes = plt.subplots(2,2) # 返回一个画板对象和四个坐标系对象,其中多个坐标系对象以numpy.nparray的形势存在
axes[0][1].set(...)
plt.show()
绘制图表

以下是一些绘制特定图的函数。
pytorch之数据处理_第1张图片

import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0,10,0.2)
y1 = np.sin(x)
y2 = np.cos(x)

fig = plt.figure(num=1, figsize=(6,6), dpi=100)
ax1 = fig.add_subplot(111)
plt.xlabel('x')
plt.ylabel('y')
plt.title('example1')
plt.axis([0,10,-1,1])
plt.xticks([0,3.14,6.28,9.52],['0','pi','2pi','3pi'])
plt.yticks(-1,0,1)

plt.plot(x,y1,'b-',label='y=sinx')
plt.plot(x,y2,'g-',label='y=cosx')
plt.legend(loc=1,fontsize='medium')

plt.grid(True)
plt.show()

坐标轴配置

  1. 刻度范围设置

    plt.xlim(s, e)

    plt.ylim(s, e)

表示刻度从s开始,到e结束。

  1. 自定义显示的刻度值

    plt.xticks([0, 3.14, 6.28, 9.52], [‘0’, ‘pi’, ‘2pi’, ‘3pi’], rotation=90)

    plt.yticks(…)

表示x轴显示出来的刻度为0,3.14,6.28,9.52,但由于给每一个刻度值又赋予了一个别名pi,2pi等,所以实际显示的是别名。rotation表示文本旋转90度显示出来。

  1. 自定义刻度之间的间隔

    from matplotlib.pyplot import MultipleLocator

    ax.xaxis.set_major_locator(MultipleLocator(2)),也就是x轴显示出来的刻度是一个差值为2的等差数列。

    ax.yaxis.set_major_locator(…)

特别地,对于时间格式的刻度有特定的包来用。

​ import matplotlib.dates as mdates

​ mdates.YearLocator() 刻度显示年份,月、日同理。

​ mdates.WeekdayLocator() 星期

​ mddates.WeekdayLocator(Wednesday) 显示每周三的数据

​ mdates.HourLocator() 小时

​ mdates.MinuteLocator() 分钟

  1. 设置刻度显示格式

    ax.xaxis.set_major_formatter(mdates.DateFormatter(‘%y-%m-%d’)) ,按日期的格式显示。

    注意要把从文件中读出来的日期数据转换成日期类型后才能正确显示日期格式。

    转换方法为:pandas.to_datetime(…)

参考:
[matplotlib教程]: https://blog.csdn.net/lemonbit/article/details/107096392
[matplotlib函数]: https://blog.csdn.net/weixin_45568391/article/details/111346371

你可能感兴趣的:(pytorch,python,深度学习)