ML.NET Cookbook:(20)我如何定义自己的数据转换?

ML.NET有很多内置的转换器,但是我们不可能涵盖所有内容。不可避免地,您将需要执行自定义的用户定义操作。为此,我们添加了MLContext.Transforms.CustomMapping就是为了这个目的:这是用户定义的数据的任意映射

假设我们有一个带有float数据的'Income'列的数据集,我们要计算'Label',如果收入超过50000,则等于true,否则等于false

这是我们如何通过自定义转换器执行此操作的方法:

// 为我们打算使用的所有输入列定义一个类。
class InputRow
{
    public float Income { get; set; }
}

// 为我们打算产生的所有输出列定义一个类。
class OutputRow
{
    public bool Label { get; set; }
}

public static IDataView PrepareData(MLContext mlContext, IDataView data)
{
    // 定义操作代码。
    Action mapping = (input, output) => output.Label = input.Income > 50000;
    // 创建一个定制的估计器并转换数据。
    var estimator = mlContext.Transforms.CustomMapping(mapping, null);
    return estimator.Fit(data).Transform(data);
}

您还可以在估计器管道中插入自定义映射:

public static ITransformer TrainModel(MLContext mlContext, IDataView trainData)
{
    // 使用自定义操作。
    Action mapping = (input, output) => output.Label = input.Income > 50000;
    // 构建学习管道。
    var estimator = mlContext.Transforms.CustomMapping(mapping, null)
        .AppendCacheCheckpoint(mlContext)
        .Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label"));

    return estimator.Fit(trainData);
}

请注意,您需要将mapping操作变成“纯函数”

  • 它应该是可重入的(我们将从多个线程同时调用它)

  • 它不应该有副作用(我们可以在任何时候任意调用,或忽略调用)

一个重要的警告是:如果希望自定义转换成为已保存模型的一部分,则需要为其提供contractName。在加载时,您需要向MLContext注册自定义转换器。

下面是一个完整的示例,用于保存和加载带有自定义映射的模型。

/// 
/// 一个类包含我们的模型所需的自定义映射功能。
/// 
/// It has a  on it and
/// derives from .
/// 
[CustomMappingFactoryAttribute(nameof(CustomMappings.IncomeMapping))]
public class CustomMappings : CustomMappingFactory
{
    // 这是自定义映射。我们现在将它分离为一个方法,以便在训练和加载中都可以使用它。
    public static void IncomeMapping(InputRow input, OutputRow output) => output.Label = input.Income > 50000;

    // 当加载模型以获取映射操作时,将调用此工厂方法。
    public override Action GetMapping()
    {
        return IncomeMapping;
    }
}
// 构建学习管道。请注意,我们现在为自定义映射提供了一个约定名称:否则我们将无法保存模型。
var estimator = mlContext.Transforms.CustomMapping(CustomMappings.IncomeMapping, nameof(CustomMappings.IncomeMapping))
    .Append(mlContext.BinaryClassification.Trainers.FastTree(labelColumnName: "Label"));

// 如果内存足够,我们可以将数据缓存在内存中,以避免在多次访问文件时从文件中加载数据。
var cachedTrainData = mlContext.Data.Cache(trainData);

// 训练模型
var model = estimator.Fit(cachedTrainData);

// 保存模型。
using (var fs = File.Create(modelPath))
    mlContext.Model.Save(model, fs);

// 现在假设我们在一个不同的过程中。

// 向ComponentCatalog注册包含“CustomMappings”的程序集,以便在加载模型时可以找到它。
newContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);

// 现在我们可以加载模型了。
ITransformer loadedModel = newContext.Model.Load(modelPath, out var schema);

欢迎关注我的个人公众号”My IO“

你可能感兴趣的:(python,java,tensorflow,人工智能,编程语言)