【Python深度学习系列】十几行代码教你使用CTGAN模拟生成表格数据

一、问题

    在机器学习中,我们经常会遇到数据集数量不足的情况。CTGAN是一个生成对抗网络(GAN)的实现,它可以学习原始数据的分布并生成具有相似特征的合成数据。生成的数据将尽量保持与原始数据的统计特性一致。

二、实现过程

2.1 安装CTGAN库

pip install ctgan

使用命令在终端或命令提示符中安装CTGAN库

2.2 导入CTGAN类

from ctgan import CTGAN

在Python代码中导入CTGAN类

2.3 准备原始数据

hci_data=pd.read_csv(r'HCI_Datasheet.csv')
hci_data=hci_data.drop(columns=['S. No','Decision Date','Application Date'])
hci_data['University_Program']=hci_data.University+' '+hci_data.Programme
hci_data.University=hci_data.University_Program
hci_data=hci_data.drop(columns=['Programme','University_Program','Year of Entry'])
#Label Encoding
le = preprocessing.LabelEncoder()
hci_data['Decision'] = le.fit_transform(hci_data['Decision']) ### 0-Accepted, 1-Rejected
hci_data['Research Experience'] = le.fit_transform(hci_data['Research Experience']) ###1-Yes, 0-No
hci_data['Submitted Portfolio'] = le.fit_transform(hci_data['Submitted Portfolio']) ###1-Yes, 0-No
hci_data['Student Status'] = le.fit_transform(hci_data['Student Status']) ### 0-Domestic, 1-International, 2-Undergrad domestic
hci_data['GRE']=hci_data['GRE'].fillna(330)
hci_data['TOEFL']=hci_data['TOEFL'].fillna(120)

原始数据准备为一个Pandas DataFrame对象,数据中包含连续型和离散型特征,以及一个标签列(如果适用)。实例数据如下:

【Python深度学习系列】十几行代码教你使用CTGAN模拟生成表格数据_第1张图片

2.4 创建CTGAN对象并拟合数据

columns=list(hci_data.columns)
ctgan = CTGAN(epochs=100)
ctgan.fit(hci_data, columns)

使用CTGAN类创建一个对象,并使用原始数据拟合模型

2.5 生成模拟数据

synthetic_data = ctgan.sample(len(data))
print(synthetic_data)

使用拟合好的模型生成模拟数据,模拟数据效果如下:

【Python深度学习系列】十几行代码教你使用CTGAN模拟生成表格数据_第2张图片

2.6 效果评价

table_evaluator = TableEvaluator(hci_data, synthetic_data)
try:
    table_evaluator.visual_evaluation()
except:
    print()

效果如下:

【Python深度学习系列】十几行代码教你使用CTGAN模拟生成表格数据_第3张图片

【Python深度学习系列】十几行代码教你使用CTGAN模拟生成表格数据_第4张图片

好了,本篇内容就总结分享到这里,需要源码的小伙伴可以关注底部公众号添加作者微信

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。

你可能感兴趣的:(深度学习,python,深度学习)