动手学深度学习笔记--LR_scratch

  • 作图

def use_svg_display():

    # 用矢量图显示

    display.set_matplotlib_formats('svg')

 

def set_figsize(figsize=(4.5, 2.5)):

    use_svg_display()

    # 设置图的尺寸

    plt.rcParams['figure.figsize'] = figsize

 

set_figsize()#设置fig大小

plt.scatter(features[:, 1].asnumpy(), labels.asnumpy(), 1);

 

 

plt作图函数,use_svg_display函数和set_figsize函数定义在d2lzh中了,作图只要调用d2lzh.plt就行。只要调用d2lzh.set_figsize()就可以打印矢量图并设置图尺寸。

  • Class range(object)
              range(stop) -->range object
              range(start,stop[,step])  -->range object
  • 自定义函数data_iter()解读
    def data_iter(batch_size,features,labels):
            num_features = len(features)#知道features的数目
            indices = list(range(num_features))#建立索引
            random.shuffle(indices)#将索引打乱
          for I in range(0,num_features,batch):#以batch_size为步长建立batch
               j = nd.array(indices[i:min(i+batch,num_features)])#找到这一batch的样本索引
               yield features.take(j),labels.take(j)
    '''
    yield-->相当于一个generator.挺难理解的样子,下一篇专门来写吧

 

Numpy.take(a,indices,axis=None,out=None,mode='raise') https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html
             其中a-->array源

                 indices-->要提取a中元素的索引
                 axis-->take元素的方式,axis = none 则按元素扫描,先行后列,如果axis = 0,                      

                 就按行扫描。如果axis = 1就按列扫描
ndarray.take与numpy.take有些许不一样
https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.take.html
ndarray.take(indices,axis=None,out=None,mode='raise')
               默认有两种形式,array.take(indices,axis)或者nd.take(source_array,indices,axis)
               其中,ndarray不按元素扫描,axis =none和axis=0效果一样,按行扫描,其中如果    

           indices超过了范围,就会默认取最后一行。同理axis=1按列取。
 

前面的步骤是建立一个数据集和batch训练集
下面是正式训练过程
第一步:因为是希望训练得到参数w,b。所以对w,b求导,故申请w,b的求导梯度内存
              w.attach_grad()

                   b.attach_grad()

第二步:定义线性函数 nd.dot(X,w)+b
第三步:定义损失函数,用于判别训练效果,这里用的均方根损失函数
第四步:定义我们迭代学习的方法,这里用的是SGD,即小批量随机梯度下降算法。
第五步:训练,建立收敛后得到的w,b
 

【记录几个讨论区问题】

1.关于attach_grad,autograd一定要遵从原变量及其地址

动手学深度学习笔记--LR_scratch_第1张图片

2.autograd不支持

你可能感兴趣的:(动手学深度学习笔记--LR_scratch)