ValueError: too many dimensions 'str'

I ran into this error while using torchtext to read a preprocessed text CSV file. Here is my code:

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchtext import data

if __name__ == "__main__":
    mytokenize = lambda x : x.split()   # str -> list[str]

    # Create the preprocessing objects (Fields)
    TEXT = data.Field(
        sequential=True,        # the text is a token sequence
        tokenize=mytokenize,    # tokenizer function
        include_lengths=True,
        use_vocab=True,         # build a vocabulary
        batch_first=True,       # batch dimension first
        fix_length=200          # pad/truncate to a fixed length
    )
    )

    LABEL = data.Field(
        sequential=True,
        use_vocab=False,
        pad_token=None,
        unk_token=None
    )

    # Map CSV columns to Fields
    train_test_fields = [
        ("text", TEXT),
        ("label", LABEL)
    ]

    # Read the data
    traindata, testdata = data.TabularDataset.splits(
        path="./data", format="csv",    # directory containing the data files
        train="imdb_train.csv", test="imdb_test.csv",
        fields=train_test_fields,
        skip_header=True
    )

    ex0 = traindata.examples[0]
    print(ex0.label)

    train_data, val_data = traindata.split(split_ratio=0.7)

    # Build the vocabularies from the training set
    TEXT.build_vocab(train_data, max_size=20000)
    LABEL.build_vocab(train_data)
    # Top 10 most frequent words in the training set
    print(TEXT.vocab.freqs.most_common(n=10))
    print(TEXT.vocab.itos[:10])

    BATCH_SIZE = 32
    # Define the iterators (data loaders)
    train_iter = data.BucketIterator(train_data, batch_size=BATCH_SIZE)
    val_iter = data.BucketIterator(val_data, batch_size=BATCH_SIZE)
    test_iter = data.BucketIterator(testdata, batch_size=BATCH_SIZE)

    for batch in train_iter:
        break

    print("data's label", batch.label)
    print("shape of data", batch.text[0].shape)
    print("number of data", len(batch.text[1]))

The error is ValueError: too many dimensions 'str', and it can be traced back to tensor construction. When the CSV is read, each label comes in as a list of strings such as ["1"], but because use_vocab=False the field is handed directly to torch.tensor during numericalization, which expects a numeric list such as [1].
Combined with what other blog posts report, a common cause of this error is data that is conceptually numeric being passed to torch.tensor as str objects; remember to convert it first.
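
A minimal sketch of the underlying problem, independent of torchtext (only torch is assumed):

import torch

# A list of numbers converts cleanly
print(torch.tensor([1]))                        # tensor([1])

# A list of strings has no numeric dtype, so the conversion fails
# with the same "too many dimensions 'str'" error
try:
    torch.tensor(["1"])
except ValueError as e:
    print(e)

# Converting to int first avoids the error
print(torch.tensor([int(s) for s in ["1"]]))    # tensor([1])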

Solution

Add an extra tokenize function when defining the LABEL preprocessing object, so that the string label becomes a list containing a single int (an alternative, non-sequential LABEL setup is sketched after the output below). Here is the modified code:

import torch
import torch.nn as nn
import torch.utils.data as Data
from torchtext import data

if __name__ == "__main__":
    mytokenize = lambda x : x.split()   # str -> list[str]
    labeltokenize = lambda x : [int(x[0])]  # [str] -> [int]
    # Create the preprocessing objects (Fields)
    TEXT = data.Field(
        sequential=True,        # the text is a token sequence
        tokenize=mytokenize,    # tokenizer function
        include_lengths=True,
        use_vocab=True,         # build a vocabulary
        batch_first=True,       # batch dimension first
        fix_length=200          # pad/truncate to a fixed length
    )

    LABEL = data.Field(
        sequential=True,
        use_vocab=False,
        dtype=torch.int64,
        tokenize=labeltokenize,
        pad_token=None,
        unk_token=None
    )

    # Map CSV columns to Fields
    train_test_fields = [
        ("text", TEXT),
        ("label", LABEL)
    ]

    # Read the data
    traindata, testdata = data.TabularDataset.splits(
        path="./data", format="csv",    # directory containing the data files
        train="imdb_train.csv", test="imdb_test.csv",
        fields=train_test_fields,
        skip_header=True
    )

    ex0 = traindata.examples[0]
    print(ex0.label)

    train_data, val_data = traindata.split(split_ratio=0.7)

    # Build the vocabularies from the training set
    TEXT.build_vocab(train_data, max_size=20000)
    LABEL.build_vocab(train_data)
    # Top 10 most frequent words in the training set
    print(TEXT.vocab.freqs.most_common(n=10))
    print(TEXT.vocab.itos[:10])

    BATCH_SIZE = 32
    # Define the iterators (data loaders)
    train_iter = data.BucketIterator(train_data, batch_size=BATCH_SIZE)
    val_iter = data.BucketIterator(val_data, batch_size=BATCH_SIZE)
    test_iter = data.BucketIterator(testdata, batch_size=BATCH_SIZE)

    for batch in train_iter:
        break

    print("data's label", batch.label)
    print("shape of data", batch.text[0].shape)
    print("number of data", len(batch.text[1]))

Output:

[1]
[('movie', 29836), ('film', 26802), ('one', 17978), ('like', 13816), ('good', 10270), ('even', 8722), ('would', 8564), ('time', 8433), ('really', 8220), ('story', 8185)]
['<unk>', '<pad>', 'movie', 'film', 'one', 'like', 'good', 'even', 'would', 'time']
data's label tensor([[1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0,
         0, 0, 1, 0, 0, 0, 0, 1]])
shape of data torch.Size([32, 200])
number of data 32
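
As mentioned above, the label column can also be handled without tokenizing it into a sequence. This is only a sketch of an alternative, assuming the same legacy Field API: sequential=False plus a preprocessing callable that casts the string to int before numericalization. With this variant batch.label comes out with shape [batch_size] instead of [1, batch_size]:

# Alternative LABEL definition: one non-sequential value per example,
# converted from str to int before it reaches torch.tensor.
LABEL = data.Field(
    sequential=False,                 # single value, no tokenization
    use_vocab=False,                  # values go straight to torch.tensor
    dtype=torch.int64,
    preprocessing=lambda x: int(x)    # "1" -> 1
)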

done(`・ω・´)
