sklearn.utils.shuffle解析

在进行机器学习时,经常需要打乱样本,这种时候Python中叒有第三方库提供了这个功能——sklearn.utils.shuffle。

Shuffle arrays or sparse matrices in a consistent way. This is a convenience alias to resample(*arrays, replace=False) to do random permutations of the collections.

函数参数

Parameters

参数 介绍
*array 带索引的序列,可以是arrays, lists, dataframes或scipy sparse matrices
random_state int,随机量,就是一个random seed。如果是int,该参数作为random seed的值;如果是None,随机生成器就是一个np.random实例
n_sample int,默认为None,输出的样本数目。如果是空,则样本数目会设置为array的第一维元素数

Returns

参数 介绍
shuffled_arrays 带索引的序列,是一个view(也就是说不会改变输入array)

Examples

sklearn.utils.shuffle解析_第1张图片

解释:例程中建立了3个带索引的序列:array, array和sparse matrix。然后将它们作为一个元组进行shuffle,其中random_state=0表示它们的打乱方式是方式0。这个打乱方式不理解的可以看一下np.random.seed的介绍或者是看我接下来对源码的解析。

源码

shuffle

def shuffle(*arrays, **options):
    options['replace'] = False
    return resample(*arrays, **options)

Are you kidding? 这是个“空壳函数”。唯一的作用就是将一个参数replace置为了False,好让shuffle过程中不影响输入array(不过要记住这个replace,这是sklearn.utils.shufflesklearn.utils.resample唯一的区别)。

那么下面来看resample函数。

resample

def resample(*arrays, **options):
    '''先是类型检测部分,可以跳过'''
    random_state = check_random_state(options.pop('random_state', None))  # 此处注意:返回类型变了,变成:np.random.mtrand._rand或np.random.RandomState(seed)或seed
    replace = option.pop('replace', True)  # 如果没有‘replace’则返回True
    max_n_samples = options.pop('n_samples', None)
    if options:
        raise ValueError("Unexpected kw arguments: %r" % options.keys())

    if len(arrays) == 0:
    return None

    first = arrays[0]
    n_samples = first.shape[0] if hasattr(first, 'shape') else len(first)

    if max_n_samples is None:
        max_n_samples = n_samples
    elif (max_n_samples > n_samples) and (not replace):
        raise ValueError("Cannot sample %d out of arrays with dim %d when replace is False" % (max_n_samples, n_samples))
    check_consistent_length(*array)
    '''开始正文'''
    '''重排索引'''
    if replace:
        indices = random_state.randint(0, n_samples, size=(max_n_samples,))  # 创建新的随机序列索引indices
    else:
        indices = np.arange(n_samples)
        random_state.shuffle(indices)
        indices = indices[:max_n_samples]

    # convert sparse matrices to CSR for row-based indexing
    arrays = [a.tocsr() if issparse(a) else a for a in arrays]
    '''根据indices对arrays进行采样'''
    resampled_arrays = [safe_indexing(a, indices) for a in arrays]
    '''分两种情况,一种是输入的*arrays参数只有一个序列,另一种是输入的*arrays参数是一元组的序列'''
    if len(resampled_arrays) == 1:
        # syntactic sugar for the unit argument case
        return resampled_arrays[0]
    else:
        return resampled_arrays

解释:代码分为三部分:

  1. 类型检测
  2. 构建重排后的索引
  3. 根据索引输出序列

random_state在函数中用于产生随机索引

你可能感兴趣的:(Python)