Reading the RML2016.10a Dataset

First, an introduction to the dataset.

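RML2016.10a
The dataset contains 220,000 samples. Each sample has shape (2, 128): two channels, I (in-phase) and Q (quadrature), each 128 points long.
The SNR ranges from -20 dB to 18 dB in 2 dB steps, which gives 20 different SNR conditions.
There are 11 modulation types: 8PSK, AM-DSB, AM-SSB, BPSK, CPFSK, GFSK, 4-PAM, 16-QAM, 64-QAM, QPSK, and WBFM.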
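Because each sample stores the I and Q components as two rows, it can also be viewed as 128 complex baseband values. The following is a minimal NumPy sketch, assuming row 0 holds I and row 1 holds Q (the usual convention for this dataset) and using a synthetic random array in place of a real sample:

```python
import numpy as np

# Synthetic stand-in for one RML2016.10a sample: shape (2, 128), float32
rng = np.random.default_rng(0)
sample = rng.standard_normal((2, 128)).astype(np.float32)

# Assumption: row 0 is the in-phase (I) channel, row 1 the quadrature (Q) channel
iq = sample[0] + 1j * sample[1]   # complex baseband signal, shape (128,)
amplitude = np.abs(iq)            # instantaneous amplitude
phase = np.angle(iq)              # instantaneous phase in radians, in [-pi, pi]

print(iq.shape, amplitude.shape, phase.shape)
```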
Everyone likes tables, so here is a summary:

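| Property | Value |
|---|---|
| Total samples | 220,000 |
| SNR range | -20 dB to 18 dB (2 dB steps) |
| Shape of a single sample | (2, 128) |
| Modulation types (11) | 8PSK, AM-DSB, AM-SSB, BPSK, CPFSK, GFSK, 4-PAM, 16-QAM, 64-QAM, QPSK, WBFM |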
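As a quick consistency check on these numbers: 2 dB steps from -20 dB to 18 dB give 20 SNR levels, so 11 modulations times 20 SNRs add up to 220,000 samples only if every (modulation, SNR) pair holds 1,000 samples. The 1,000 below is derived from that arithmetic, and the names in `mods` are the labels used in this article; the keys stored inside the pickle file may be spelled differently.

```python
# SNR levels: -20 dB up to and including 18 dB, in 2 dB steps
snrs = list(range(-20, 20, 2))

# Modulation labels as listed in this article
mods = ["8PSK", "AM-DSB", "AM-SSB", "BPSK", "CPFSK", "GFSK",
        "4-PAM", "16-QAM", "64-QAM", "QPSK", "WBFM"]

samples_per_pair = 1000  # 220000 / (11 * 20)

print(len(snrs))                                  # 20
print(len(mods) * len(snrs) * samples_per_pair)   # 220000
```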

The project directory structure is as follows:

[Figure: project directory structure]
The file in the red box is the one containing the program being run.
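The loader in this article opens `../data/RML2016.10a_dict.pkl`. Python resolves such a relative path against the current working directory, not the script's folder, so it only works when the program is launched from the directory that contains the script. A small sketch that anchors the path to the script's own location instead, assuming the `data` folder sits one level above the script's folder; the helper name `dataset_path` is hypothetical:

```python
from pathlib import Path


def dataset_path(script_file: str) -> Path:
    """Hypothetical helper: locate the pickle relative to the script, not the CWD."""
    script_dir = Path(script_file).resolve().parent
    # Assumed layout: <project>/data/RML2016.10a_dict.pkl next to <project>/<script folder>/
    return script_dir.parent / "data" / "RML2016.10a_dict.pkl"
```

Inside the loader one would then write `open(dataset_path(__file__), 'rb')`, and the script can be started from any working directory.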

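Reading the dataset and converting it to tensors
PS: This code is adapted from the original author's code. The original is written in TensorFlow; this column reimplements it in PyTorch.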

import pickle

import numpy as np
import torch


def load_data():
    # The pickle was created under Python 2, so latin1 decoding is needed
    with open(r"../data/RML2016.10a_dict.pkl", 'rb') as f:
        data = pickle.load(f, encoding='latin1')

    # data is a dict keyed by (modulation, snr) tuples, e.g. data[('8PSK', -10)];
    # each value is a float32 array of shape (1000, 2, 128)
    mods = sorted({key[0] for key in data.keys()})
    snrs = sorted({key[1] for key in data.keys()})

    X = []
    lbl = []
    for mod in mods:
        for snr in snrs:
            X.append(data[(mod, snr)])
            lbl.extend([(mod, snr)] * data[(mod, snr)].shape[0])
    X = np.vstack(X)                    # [220000, 2, 128]

    # Random train/test split, n_train : n_test = 8 : 2, fixed seed for reproducibility
    np.random.seed(2022)
    n_examples = X.shape[0]
    n_train = int(n_examples * 0.8)
    train_idx = np.random.choice(n_examples, size=n_train, replace=False)
    test_idx = np.setdiff1d(np.arange(n_examples), train_idx)
    X_train = X[train_idx]
    X_test = X[test_idx]

    def to_onehot(yy):
        # Width is len(mods), not max(yy)+1, so both splits always get 11 columns
        yy1 = np.zeros([len(yy), len(mods)])
        yy1[np.arange(len(yy)), yy] = 1
        return yy1

    Y_train = to_onehot([mods.index(lbl[i][0]) for i in train_idx])
    Y_test = to_onehot([mods.index(lbl[i][0]) for i in test_idx])

    # Add a channel dimension so each sample has shape [1, 2, 128] (channel, height, width)
    X_train = torch.from_numpy(X_train).unsqueeze(1)    # [176000, 1, 2, 128]
    Y_train = torch.from_numpy(Y_train)                 # [176000, 11]
    X_test = torch.from_numpy(X_test).unsqueeze(1)      # [44000, 1, 2, 128]
    Y_test = torch.from_numpy(Y_test)                   # [44000, 11]

    return X_train, Y_train, X_test, Y_test


if __name__ == '__main__':
    X_train, Y_train, X_test, Y_test = load_data()
    print(X_train.shape, Y_train.shape, X_test.shape, Y_test.shape)
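To see what the stacking loop, the one-hot step and the 8:2 split produce without the real pickle file, the same steps can be run on a tiny synthetic dictionary with the same `(modulation, snr) -> array(N, 2, 128)` layout. This is only a sketch: the toy dict (2 modulations x 2 SNRs x 3 samples) and the helper name `build_arrays` are made up for illustration and are not part of the original code.

```python
import numpy as np


def build_arrays(data):
    """Hypothetical helper: stack a {(mod, snr): array(N, 2, 128)} dict into X and integer labels."""
    mods = sorted({k[0] for k in data})
    snrs = sorted({k[1] for k in data})
    X, y = [], []
    for mod in mods:
        for snr in snrs:
            block = data[(mod, snr)]
            X.append(block)
            y.extend([mods.index(mod)] * block.shape[0])
    return np.vstack(X), np.array(y), mods


# Toy dataset: 2 modulations x 2 SNRs x 3 samples each (random values)
rng = np.random.default_rng(2022)
toy = {(m, s): rng.standard_normal((3, 2, 128)).astype(np.float32)
       for m in ("QPSK", "BPSK") for s in (-20, 18)}

X, y, mods = build_arrays(toy)
Y = np.eye(len(mods))[y]          # one-hot rows; width fixed by the class count

# 8:2 split without overlap, mirroring the loader's approach
np.random.seed(2022)
n_train = int(len(X) * 0.8)       # int(9.6) -> 9
train_idx = np.random.choice(len(X), size=n_train, replace=False)
test_idx = np.setdiff1d(np.arange(len(X)), train_idx)

print(X.shape, Y.shape, mods)     # (12, 2, 128) (12, 2) ['BPSK', 'QPSK']
```

`np.eye(len(mods))[y]` is a compact alternative to `to_onehot`; since its width comes from the number of classes rather than from the largest label present, a split that happens to miss the last class still gets every column.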
