本项目使用 C# 和 ML.NET 对美国成人人口普查数据进行分析和分类预测。目标是根据输入的数据特征(如年龄、职业、教育程度等)预测个人的收入是否超过 50,000 美元。
此示例演示如何通过使用 IEnumerable 将数据库用作 ML.NET 管道的数据源。由于数据库被视为任何其他数据源,因此可以查询数据库并将其结果用于训练和预测场景。
企业用户需要使用其公司数据库中的现有数据集来训练和预测 ML.NET 模型。尽管在大多数情况下,在训练机器学习模型之前都需要清理和准备数据,但许多企业对数据库非常熟悉,并且更喜欢将集中化和安全的数据保留在数据库服务器中,而不是处理导出的纯文本文件。
age, workclass, fnlwgt, education, education-num, marital-status, occupation, relationship, ethnicity, sex, capital-gain, capital-loss, hours-per-week, native-country, label(IsOver50K)
39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, Not-in-family, White, Male, 2174, 0, 40, United-States, 0
50, Self-emp-not-inc, 83311, Bachelors, 13, Married-civ-spouse, Exec-managerial, Husband, White, Male, 0, 0, 13, United-States, 0
38, Private, 215646, HS-grad, 9, Divorced, Handlers-cleaners, Not-in-family, White, Male, 0, 0, 40, United-States, 0
53, Private, 234721, 11th, 7, Married-civ-spouse, Handlers-cleaners, Husband, Black, Male, 0, 0, 40, United-States, 0
28, Private, 338409, Bachelors, 13, Married-civ-spouse, Prof-specialty, Wife, Black, Female, 0, 0, 40, Cuba, 0
37, Private, 284582, Masters, 14, Married-civ-spouse, Exec-managerial, Wife, White, Female, 0, 0, 40, United-States, 0
age:
表示个人的年龄。这是一个数值型字段。
workclass:
表示工作类型或职业类别,如“Private”(私营)、“Self-emp-not-inc”(自雇但无公司)、“Federal-gov”(联邦政府)等。
fnlwgt:
这是“final weight”的缩写,表示人口普查中的家庭权重。这个字段用于调整抽样数据的代表性,确保结果能够反映总体情况。
education:
表示教育程度,如“Bachelors”(学士学位)、 “Some college”(完成部分大学课程)、“HS-grad”(高中毕业)等。
education-num:
表示教育程度的编号,通常是对教育层次进行量化后的数值。例如,“HS-grad”可能被编码为9,“Bachelors”为13。
marital-status:
表示婚姻状况,如“Married-civ-spouse”(已婚且有合法配偶)、“Never-married”(未婚)、“Divorced”(离婚)等。
occupation:
表示职业类型,如“Tech-support”(技术支持)、 “Sales”(销售)、“Managerial”(管理职位)等。
relationship:
表示家庭关系,如“Husband”(丈夫)、 “Wife”(妻子)、 “Child”(子女)、 “Own-child”(自己孩子)等。
ethnicity:
表示种族或民族背景,常见的包括“White”(白人)、 “Black”(黑人)、 “Asian”(亚洲人)、 “Hispanic”(西班牙裔)等。
sex:
表示性别,通常分为“Male”(男性)和“Female”(女性)两类。
capital-gain:
表示资本收益,即来自投资、股票等的收入。
capital-loss:
表示资本损失,与资本收益相反,指投资上的亏损。
hours-per-week:
表示每周工作小时数,通常用于衡量工作强度或兼职/全职状态。
native-country:
表示原籍国,即个人的国籍或出生地,如“United-States”(美国)、 “Mexico”(墨西哥)、 “Germany”(德国)等。
label(IsOver50K):
这是目标字段,通常是一个二分类变量,表示该个体的年收入是否超过5万美元。例如,“>50K”表示收入超过5万美元,“<=50K”表示不超过5万美元。
你不能直接对事务表执行简单的联接查询? - 即使技术上可以从任何联接查询创建 IEnumerable,但在大多数实际情况下,这对于机器学习算法/训练器来说并不奏效。
数据准备之所以重要,是因为大多数机器学习训练器/算法需要数据以非常特定的方式格式化或输入特征列必须是特定的数据类型,因此数据集通常在训练模型之前需要进行一些准备。你还需要清理数据,有些数据源可能包含缺失值(空值、未定义),或者无效值(数据可能需要转换为不同的比例,你可能需要对特征中的数值进行上采样或归一化等),从而使训练过程要么失败,要么产生不准确的结果,甚至产生误导性的结果。因此,在几乎所有情况下都需要在训练 ML 模型之前进行数据准备。
public static void CreateDatabase(string url)
{
var dataset = ReadRemoteDataset(url);
// ... 数据清洗与存储逻辑 ...
}
var data = dataset
.Skip(1) // 跳过表头行
.Select(l => l.Split(','))
.Where(row => row.Length > 1)
.Select(row => new AdultCensus()
{
Age = int.Parse(row[0]),
Workclass = row[1],
Education = row[3],
MaritalStatus = row[5],
Occupation = row[6],
Relationship = row[7],
Race = row[8],
Sex = row[9],
CapitalGain = row[10],
CapitalLoss = row[11],
HoursPerWeek = int.Parse(row[12]),
NativeCountry = row[13],
Label = (int.Parse(row[14]) == 1) ? true : false
});
db.AdultCensus.AddRange(data);
var count = db.SaveChanges();
Console.WriteLine($"Total count of items saved to database: {count}");
var mlContext = new MLContext(seed: 1);
// 加载数据并划分训练集和测试集
var dataView = mlContext.Data.LoadFromEnumerable(QueryData());
var trainTestData = mlContext.Data.TrainTestSplit(dataView);
// 构建特征工程管道:对分类变量进行独热编码,并拼接特征向量
var pipeline = mlContext.Transforms.Categorical.OneHotEncoding(new[] {
new InputOutputColumnPair("MsOHE", "MaritalStatus"),
new InputOutputColumnPair("OccOHE", "Occupation"),
new InputOutputColumnPair("RelOHE", "Relationship"),
new InputOutputColumnPair("SOHE", "Sex"),
new InputOutputColumnPair("NatOHE", "NativeCountry")
}, OneHotEncodingEstimator.OutputKind.Binary)
.Append(mlContext.Transforms.Concatenate("Features", "MsOHE", "OccOHE", "RelOHE", "SOHE", "NatOHE"))
.Append(mlContext.BinaryClassification.Trainers.LightGbm());
// 训练模型
Console.WriteLine("Training model...");
var model = pipeline.Fit(trainTestData.TrainSet);
// 使用训练好的模型进行预测并计算性能指标
var prediction = model.Transform(trainTestData.TestSet);
var metrics = mlContext.BinaryClassification.Evaluate(prediction);
Console.WriteLine($"Accuracy: {metrics.Accuracy}");
Console.WriteLine($"Recall: {metrics.Recall}");
// ... 其他评估指标 ...
HttpClient
下载数据集,并将其内容转换为字符串流。1. 梯度提升框架
LightGBM 是一个基于梯度提升(Gradient Boosting)的框架。梯度提升是一种集成学习的方法,通过训练多个弱分类器(如决策树),然后将其组合起来形成一个强分类器。
基本思想: 每个新模型都试图拟合前一个模型的残差(即预测值与真实值之间的误差)。通过不断迭代,逐步优化模型的预测能力。
优势:
2. 基于直方图的算法
LightGBM 使用了一种基于直方图(Histogram-based)的优化方法来提升训练效率。这种方法将特征值分桶处理,减少了计算量。
实现步骤:
优势:
3. Leaf-wise 的生长策略
与传统的基于节点(Node-wise)的分裂不同,LightGBM 使用了Leaf-wise的策略来生成新的叶子节点。这种策略能够更好地控制树的深度,并且有助于防止过拟合。
工作原理:
优势:
4. 混合策略
LightGBM 结合了基于直方图的算法和Leaf-wise的生长策略,形成了高效的训练方法。
5. 分布式训练
LightGBM 支持分布式训练,能够在多台机器上并行处理数据,适用于大规模数据集。
实现机制:
优势:
6. 正则化与防止过拟合
LightGBM 提供了多种正则化机制来防止过拟合,确保模型的泛化能力。
L1 和 L2 正则化:
其他参数控制:
优势:
7. 参数调优
LightGBM 提供了许多参数来控制模型的行为,选择合适的参数组合对模型性能至关重要。
关键参数:
调优方法:
8. 特征工程
LightGBM 对特征数据有一定的要求,良好的特征工程可以提升模型性能。
9. 缺失值与类别变量处理
LightGBM 能够有效地处理缺失值和类别变量,增强了其适用性。
缺失值处理:
类别变量处理:
10. 与其他梯度提升框架的比较
XGBoost:
CatBoost:
LightGBM 的优势:
劣势:
适用场景:
LightGBM 是一个高效、强大的梯度提升框架,基于直方图和Leaf-wise策略,能够在保证高准确性的同时实现快速的训练。其分布式支持使其适用于处理大规模数据集。在实际应用中,合理调优参数和进行有效的特征工程能够进一步提升模型性能。理解其工作原理和优化机制,对于最大化利用LightGBM的优势、避免常见问题是非常重要的。
本项目展示了如何利用 ML.NET 进行从数据准备到模型构建与评估的完整机器学习流程。通过将成人人口普查数据存储于数据库,并使用 LightGBM 分类器进行收入预测,该项目为实际应用中类似的数据分析任务提供了一个参考实现。在后续开发中,可以进一步优化特征工程步骤,尝试其他分类算法,并对模型性能进行更全面的评估。