训练集和测试集数据分布不同_当您的训练和测试数据来自不同的分布时该怎么办...

训练集和测试集数据分布不同

by Nezar Assawiel

由Nezar Assawiel

当您的训练和测试数据来自不同的分布时该怎么办 (What to do when your training and testing data come from different distributions)

To build a well-performing machine learning (ML) model, it is essential to train the model on and test it against data that come from the same target distribution.

要构建性能良好的机器学习(ML)模型,必须对模型进行训练并针对来自相同目标分布的数据进行测试。

However, sometimes only a limited amount of data from the target distribution can be collected. It may not be sufficient to build the needed train/dev/test sets.

但是,有时只能从目标分发中收集有限数量的数据。 构建所需的训练/开发/测试集可能还不够。

Yet similar data from other data distributions might be readily available. What to do in such a case? Let us discuss some ideas!

然而,其他数据分布中的相似数据可能随时可用。 在这种情况下该怎么办? 让我们讨论一些想法!

一些背景知识 (Some background knowledge)

To better follow the discussion here, you can read up on the following basic ML concepts, if you are not familiar with them already:

为了更好地遵循此处的讨论,如果您还不熟悉以下基本ML概念,则可以阅读它们:

  • Train, dev (development), and test sets: Note that the dev set is also called the validation or the hold-on set. This post is a good short introduction to the topic.

    训练,开发(开发)和测试集:请注意,开发集也称为验证集或保留集。 这篇文章是对该主题的很好的简短介绍。

  • Bias (underfitting) and variance (overfitting) errors: This is a great simple explanation of these errors.

    偏差(拟合不足)和方差(拟合过度)错误: 这是对这些错误的简单解释。

  • How the train/dev/test split is correctly made: You may refer to this post that I have written before for a short background on this topic.

    如何正确进行训练/开发/测试拆分:您可以参考我的这篇文章 之前写过关于该主题的简短背景。

情境 (Scenario)

Say you are building a dog-image classifier application that determines if an image is of a dog or not.

假设您正在构建一个狗图像分类器应用程序,该应用程序将确定图像是否属于狗。

The application is intended for users in rural areas who can take pictures of animals by their mobile devices for the application to classify the animals for them.

该应用程序适用于农村地区的用户,他们可以通过其移动设备为动物拍照,以便为他们分类动物。

Studying the target data distribution — you found that the images are mostly blurry, low resolution, and similar to the following:

在研究目标数据分布时,您发现图像大部分是模糊的,低分辨率的,并且类似于以下内容:

You were only able to collect 8,000 such images, which is not enough to build the train/dev/test sets. Let us assume you have determined you’ll need at least 100,000 images.

您仅能收集8,000个此类图像,不足以构建训练/开发/测试集。 让我们假设您已确定您至少需要100,000张图像。

You wondered if you could use images from another dataset — in addition to the 8,000 images you collected — to build the train/dev/test sets.

您想知道是否可以使用其他数据集的图像(除了您收集的8,000张图像之外)来构建训练/开发/测试集。

You realized you can easily scrape the web to build a dataset of 100,000 images or more, with similar dog-image vs. non-dog-image frequencies to those frequencies required.

您意识到,您可以轻松地在网上刮刮以构建包含100,000张或更多图像的数据集,其狗图像频率与非狗图像频率与所需频率相似。

But, clearly this web dataset comes from a different distribution, with high resolution and clear images such as the following:

但是,很明显,此网络数据集来自不同的分布,具有高分辨率和清晰的图像,例如:

How would you build the train/dev/test sets?

您将如何构建训练/开发/测试集?

You can’t only use the original 8,000 images you collected to build the train/dev/test sets as they are not enough to make a well-performing classifier. Generally, computer vision as other natural perception problems — speech recognition or natural language processing — need a lot of data.

您不能仅使用收集的原始8,000张图像来构建训练/开发/测试集,因为它们不足以构成性能良好的分类器。 通常,计算机视觉和其他自然感知问题(语音识别或自然语言处理)需要大量数据。

Also, you can’t only use the web dataset. The classifier will not perform well on the users’ blurry images, which are different from the clear, high definition web images used to train the model.

另外,您不能只使用网络数据集。 分类器在用户的模糊图像上表现不佳,这与用于训练模型的清晰,高清晰度的网络图像不同。

So, what do you do? Let us consider some possibilities.

所以你会怎么做? 让我们考虑一些可能性。

可能的选择-整理数据 (A possible option — shuffling the data)

Something you can do is to combine the two datasets and randomly shuffle them. Then, split the resulting dataset into train/dev/test sets.

您可以做的就是合并两个数据集并随机对其进行洗牌。 然后,将结果数据集拆分为训练/开发/测试集。

Assuming you decided to go with a 96:2:2% split for the train/dev/test sets, this process will be something like this:

假设您决定对训练/开发/测试集进行96:2:2%的分配,则此过程将如下所示:

With this set up, the train/dev/test sets all come from the same distribution, as illustrated by the colors in the graph above, which is desired.

通过此设置,训练/开发/测试集都来自相同的分布,如上图中的颜色所示,这是理想的。

However, there a big drawback here!

但是,这里有一个很大的缺点!

If you look at the dev set, out of 2,000 images, on average only 148 images come from the target distribution.

如果查看开发集,则在2,000张图像中,平均只有148张图像来自目标分布。

This means that for the most part you are optimizing the classifier for the web images distribution (1,852 images out of 2,000 images) — which is not what you want!

这意味着,在大多数情况下,您正在优化用于Web图像分发的分类器(2,000幅图像中的1,852幅图像)-这不是您想要的!

The same thing can be said about the test set when assessing the performance of the classifier against it. So, this is not a good way to make the train/dev/test split.

在评估分类器针对测试集的性能时,可以说相同的话。 因此,这不是使train / dev / test分开的好方法。

更好的选择 (A better option)

An alternative is to make the dev/test sets come from the target distribution dataset, and the training set from the web dataset.

另一种选择是使开发/测试集来自目标分布数据集,而训练集来自网络数据集。

Say you’re still using 96:2:2% split for the train/dev/test sets as before. The dev/test sets will be 2,000 images each — coming from the target distribution — and the rest will go to the train set, as illustrated below:

假设您仍然像以前一样将96:2:2%的比例用于训练/开发/测试集。 开发/测试集将各有2,000张图像(来自目标发行版),其余的将进入训练集,如下图所示:

Using this split, you will be optimizing the classifier to perform well on the target distribution, which is what you care about. This is because the images of the dev set come solely from the target distribution.

使用此拆分,您将优化分类器以使其在目标分布上表现良好,这正是您所关心的。 这是因为开发人员集的图像仅来自目标分布。

However, the training distribution is now different from the dev/test distribution. This means that for the most part, you are training the classifier on web images. Thus, it will take longer and more effort to optimize the model.

但是,培训分发现在不同于开发/测试分发。 这意味着在大多数情况下,您正在训练Web图像上的分类器。 因此,将需要花费更长的时间和更多的精力来优化模型。

More importantly, you will not be able to easily tell if the classifier error on the dev set relative to the error on the train set is a variance error, a data mismatch error, or a combination of both.

更重要的是,相对于训练集上的错误,您将无法轻松判断开发集上的分类器错误是方差错误,数据失配错误还是两者的组合。

Let us consider this in more detail, and see what we can do about it.

让我们更详细地考虑这一点,看看我们能对此做些什么。

方差与数据不匹配 (Variance vs data mismatch)

Consider the train/dev/test split from the second option above. Assume the human error is zero, for simplicity.

考虑上面第二个选项中的训练/开发/测试拆分。 为简单起见,假设人为错误为零。

Also, let us assume you found that the training error to be 2% and the dev error 10%. How much of the 8% error between these two is due to the data mismatch between the two sets — given they are coming from different distributions? And how much is due to the variance of the model (overfitting)? We can’t tell.

另外,我们假设您发现训练误差为2%,开发误差为10%。 如果这两组数据来自不同的分布,那么两者之间8%的误差中有多少是由于两组数据之间的不匹配所致? 多少是由于模型的差异(过度拟合)引起的? 我们不能说。

Let us modify the train/dev/test split. Take out a small portion of the train set and call it the “bridge” set. The bridge set will not be used to train the classifier. It is instead an independent set. The split now has four sets belonging to two data distributions — as follows:

让我们修改训练/开发/测试拆分。 取出一小部分火车,称为“桥梁”火车。 桥集将不会用于训练分类器。 相反,它是一个独立的集合。 拆分现在有四个集合,分别属于两个数据分布-如下:

方差误差 (Variance error)

With this split, let us assume you found training and dev errors to be 2% and 10%, respectively. You found the bridge error to be 9%, as shown below:

通过这种拆分,让我们假设您发现训练和开发错误分别为2%和10%。 您发现桥接错误为9%,如下所示:

Now, how much of the 8% error between the train and dev set errors is a variance error, and how much of it is a data mismatch error?

现在,训练和开发人员设置错误之间的8%错误中有多少是方差错误,而有多少是数据不匹配错误?

Easy! The answer is 7% variance error and 1% data mismatch error. But why?

简单! 答案是7%方差误差和1%数据失配误差。 但为什么?

It’s because the bridge set comes from the same distribution as the train set, and the error difference between them is 7%. This means the classifier is overfitted to the train set. This tells us we have a high variance problem at hand.

这是因为桥组与火车组具有相同的分布,并且它们之间的误差差异为7%。 这意味着分类器过度适合火车组。 这说明我们手头有一个高方差问题

数据不匹配错误 (Data mismatch error)

Now, let us assume you found the error on the bridge set to be 3% and the rest as before as shown below:

现在,让我们假设您发现桥接器上的错误设置为3%,其余错误如前所示,如下所示:

How much of the 8% error between the train and dev sets is a variance error and how much of it is a data mismatch error?

训练集和开发集之间8%的误差中有多少是方差误差,而有多少是数据失配误差?

The answer is 1% variance error and 7% data mismatch error. Why so?

答案是1%方差误差和7%数据失配误差。 为什么这样?

This time, it is because the classifier performs well on a dataset it hasn’t seen before if it comes from the same distribution, such as the bridge set. It performs poorly if it comes from a different distribution, like the dev set. Thus, we have a data mismatch problem.

这次是因为分类器是否来自相同的分布(例如桥集),因此在以前从未见过的数据集中表现良好。 如果它来自不同的发行版(如开发集),则其性能会很差。 因此,我们有一个数据不匹配的问题

Reducing the variance error is a common task in ML. For example, you can use regularization methods, or allocate a larger train set.

减少方差误差是ML中的常见任务。 例如,您可以使用正则化方法,或分配更大的训练集。

However, mitigating the data mismatch error is a more interesting problem. So, let us talk bout it.

但是,减轻数据不匹配错误是一个更有趣的问题。 所以,让我们谈一谈。

缓解数据不匹配 (Mitigating data mismatch)

To reduce the data mismatch error, you would need to somehow incorporate the characteristics of the dev/test datasets — the target distribution — into the train set.

为了减少数据不匹配错误,您需要以某种方式将开发/测试数据集的特征(目标分布)纳入训练集中。

Collecting more data from the target distribution to add to the train set is always the best option. But, if that is not possible (as we assumed at the beginning of our discussion), you can try the following approaches:

从目标分布收集更多数据以添加到火车集中始终是最佳选择。 但是,如果不可能(如我们在讨论开始时所假设的那样),则可以尝试以下方法:

错误分析 (Error analysis)

Analyzing the errors on the dev set and how they are different from the errors on the train set could give you ideas to address the data mismatch problem.

分析开发人员集上的错误以及它们与培训集上的错误有何不同,可以为您提供解决数据不匹配问题的想法。

For example, if you find many of the errors on the dev set occur when the background of the animal’s image is rocky, you can mitigate such errors by adding animal images with rocky background to the train set.

例如,如果您发现开发集上的许多错误是在动物图像的背景为岩石时发生的,则可以通过将具有岩石背景的动物图像添加到火车集中来减轻此类错误。

人工数据综合 (Artificial data synthesis)

Another way to incorporate the characteristics of the dev/test sets into the train set is to synthesize data with similar characteristics.

将开发/测试集的特征合并到训练集中的另一种方法是合成具有类似特征的数据。

For example, we mentioned before that the images in our dev/test sets are mostly blurry in contrast to the clear images from the web that make most of our train set. You can artificially add blurriness to the images of the train set to be more similar to the dev/test sets as in the following image:

例如,我们之前提到过,与构成大部分训练集的网络清晰图像相比,我们的开发/测试集中的图像大多模糊。 您可以人为地增加火车集图像的模糊度,使其更类似于开发/测试集,如下图所示:

However, there is an important point to notice here!

但是,这里有一个重要的注意事项!

You could end up overfitting your classifier to the artificial characteristics you made.

您最终可能会使分类器过度适应您的人工特征。

In our example, the blurriness you artificially made by some mathematical function might only be a small sub-set of the blurriness that exists in the images of the target distribution.

在我们的示例中,您通过某些数学函数人工创建的模糊可能只是目标分布图像中存在的模糊的一小部分。

In other words, the blurriness in the target distribution could be due to many reasons. For example, fog, low resolution camera, subject movement could all be causes. But your synthesized blurriness may not represent all of these causes.

换句话说,目标分布中的模糊性可能是由于许多原因造成的。 例如,雾气,低分辨率的相机,对象移动都是原因。 但是您综合的模糊性可能并不代表所有这些原因。

More generally, when synthesizing data for the training set for any type of problem (such as computer vision, or speech recognition), you could overfit your model to the synthesized dataset.

更一般而言,当针对任何类型的问题(例如计算机视觉或语音识别)为训练集综合数据时,您可能会将模型过度拟合至综合数据集。

This dataset may look representative enough of the target distribution to the human eye. But in fact, it is only a small set of the target distribution. So, just keep this in mind while using this powerful tool — data synthesis.

该数据集看起来足以代表人眼的目标分布。 但实际上,这只是目标分布的一小部分。 因此,在使用此功能强大的工具-数据综合时,请记住这一点。

综上所述 (In Summary)

When developing an ML model, ideally the trian/dev/test datasets should all come from the same data distribution — that of the data which the model will encounter when used by the userbase.

在开发ML模型时,理想情况下,trian / dev / test数据集应全部来自同一数据分布,即该模型在被用户群使用时将遇到的数据。

However, sometimes it is not possible to collect enough data from the target distribution to build the trian/dev/test sets, while similar data from other distributions is readily available.

但是,有时无法从目标发行版中收集足够的数据来构建trian / dev / test集,而其他发行版中的类似数据也很容易获得。

In such cases, the dev/test sets should come from the target distribution while the data from the other distributions can be used to build (most of) the train set. Data mismatch techniques can then be used to mitigate the the data distribution differences between the train set vs the dev/test sets.

在这种情况下,开发/测试集应来自目标分布,而其他分布中的数据可用于构建(大部分)训练集。 然后可以使用数据不匹配技术来减轻训练集与开发/测试集之间的数据分布差异。

翻译自: https://www.freecodecamp.org/news/what-to-do-when-your-training-and-testing-data-come-from-different-distributions-d89674c6ecd8/

训练集和测试集数据分布不同

你可能感兴趣的:(python,机器学习,人工智能,深度学习,java)