导入必要的软件包
投掷骰子
import torch
from torch.distributions import multinomial
from d2l import torch as d2l
# Probability vector for a fair six-sided die: six equal entries of 1/6.
fair_probs = torch.ones([6]) / 6
# Draw one sample with total_count=1 (a single roll); the result is a
# length-6 vector of per-face counts containing a single 1.
multinomial.Multinomial(1, fair_probs).sample()
Out[5]: tensor([0., 1., 0., 0., 0., 0.])
multinomial.Multinomial(1, fair_probs).sample()
Out[6]: tensor([1., 0., 0., 0., 0., 0.])
multinomial.Multinomial(1, fair_probs).sample()
Out[7]: tensor([0., 1., 0., 0., 0., 0.])
multinomial.Multinomial(1, fair_probs).sample()
Out[8]: tensor([0., 0., 0., 1., 0., 0.])
如果用Python的for循环来完成这个任务,速度会慢得令人难以忍受,因此我们使用的函数支持同时抽取多个样本,返回我们想要的任意形状的独立样本数组。
# total_count=10 simulates ten rolls in one call; the returned entries are
# per-face counts.
multinomial.Multinomial(10, fair_probs).sample()
Out[9]: tensor([3., 0., 1., 4., 2., 0.])
multinomial.Multinomial(100, fair_probs).sample()
Out[10]: tensor([15., 20., 18., 12., 18., 17.])
multinomial.Multinomial(10000, fair_probs).sample()
Out[11]: tensor([1617., 1671., 1702., 1700., 1646., 1664.])
计算相对频率作为真实概率的估计
# Store the results as 32-bit floats so that they can be divided
counts = multinomial.Multinomial(1000, fair_probs).sample()
counts / 1000 # relative frequencies used as the estimates
Out[12]: tensor([0.1540, 0.1820, 0.1590, 0.1770, 0.1660, 0.1620])
因为我们是从一个公平的骰子中生成的数据,我们知道每个结果都有真实的概率$\frac{1}{6}$,大约是$0.167$,所以上面输出的估计值看起来不错。
我们也可以看到这些概率如何随着时间的推移收敛到真实概率。让我们进行500组实验,每组抽取10个样本。
# sample((500,)) draws 500 independent experiments of 10 rolls each,
# giving one row of per-face counts per experiment.
counts = multinomial.Multinomial(10, fair_probs).sample((500,))
counts
Out[14]:
tensor([[2., 3., 0., 1., 2., 2.],
[2., 1., 1., 3., 1., 2.],
[3., 4., 0., 1., 1., 1.],
...,
[2., 0., 1., 1., 2., 4.],
[1., 1., 0., 2., 1., 5.],
[2., 0., 2., 3., 0., 3.]])
# Running totals down the rows (dim=0): row k holds the per-face counts
# accumulated over the first k + 1 experiments.
cum_counts = counts.cumsum(dim=0)
cum_counts
Out[16]:
tensor([[ 2., 3., 0., 1., 2., 2.],
[ 4., 4., 1., 4., 3., 4.],
[ 7., 8., 1., 5., 4., 5.],
...,
[804., 780., 847., 835., 842., 872.],
[805., 781., 847., 837., 843., 877.],
[807., 781., 849., 840., 843., 880.]])
# Divide each row by its own total (kept as a column by keepdims=True so it
# broadcasts across the six faces), turning the cumulative counts into
# relative-frequency estimates.
estimates = cum_counts / cum_counts.sum(dim=1, keepdims=True)
estimates
Out[18]:
tensor([[0.2000, 0.3000, 0.0000, 0.1000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.0500, 0.2000, 0.1500, 0.2000],
[0.2333, 0.2667, 0.0333, 0.1667, 0.1333, 0.1667],
...,
[0.1614, 0.1566, 0.1701, 0.1677, 0.1691, 0.1751],
[0.1613, 0.1565, 0.1697, 0.1677, 0.1689, 0.1758],
[0.1614, 0.1562, 0.1698, 0.1680, 0.1686, 0.1760]])
# Plot how the estimated probability of each face evolves across experiments.
d2l.set_figsize((6, 4.5))
for i in range(6):
    # Column i of `estimates` is the running estimate for face i + 1.
    d2l.plt.plot(estimates[:, i].numpy(),
                 label=f"P(die={i + 1})")
# Dashed reference line at the true probability 1/6 (rounded to 0.167).
d2l.plt.axhline(y=0.167, color='black', linestyle='dashed')
# NOTE: these CJK axis labels need a font that contains the glyphs; with the
# default font matplotlib emits "Glyph ... missing from current font" warnings.
d2l.plt.gca().set_xlabel('实验次数')
d2l.plt.gca().set_ylabel('估算概率')
d2l.plt.legend();
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 23454 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 39564 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 27425 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 25968 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 23454 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 39564 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 27425 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 25968 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 20272 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 31639 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 27010 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 29575 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 20272 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 31639 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 27010 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 29575 missing from current font.
font.set_text(s, 0, flags=flags)
# Plot how the estimated probability of each face evolves across experiments.
d2l.set_figsize((6, 4.5))
for i in range(6):
    # Column i of `estimates` is the running estimate for face i + 1.
    d2l.plt.plot(estimates[:, i].numpy(),
                 label=f"P(die={i + 1})")
# Dashed reference line at the true probability 1/6 (rounded to 0.167).
d2l.plt.axhline(y=0.167, color='black', linestyle='dashed')
# NOTE: these CJK axis labels need a font that contains the glyphs; with the
# default font matplotlib emits "Glyph ... missing from current font" warnings.
d2l.plt.gca().set_xlabel('实验次数')
d2l.plt.gca().set_ylabel('估算概率')
d2l.plt.legend();
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 23454 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 39564 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 27425 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 25968 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 23454 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 39564 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 27425 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 25968 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 20272 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 31639 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 27010 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:238: RuntimeWarning: Glyph 29575 missing from current font.
font.set_text(s, 0.0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 20272 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 31639 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 27010 missing from current font.
font.set_text(s, 0, flags=flags)
C:\ProgramData\Anaconda3\lib\site-packages\matplotlib\backends\backend_agg.py:201: RuntimeWarning: Glyph 29575 missing from current font.
font.set_text(s, 0, flags=flags)
# Plot how the estimated probability of each face evolves across experiments.
d2l.set_figsize((6, 4.5))
for i in range(6):
    # Column i of `estimates` is the running estimate for face i + 1.
    d2l.plt.plot(estimates[:, i].numpy(),
                 label=f"P(die={i + 1})")
# Dashed reference line at the true probability 1/6 (rounded to 0.167).
d2l.plt.axhline(y=0.167, color='black', linestyle='dashed')
# ASCII-only axis labels; presumably pinyin chosen so the plot renders
# without a CJK-capable font (TODO confirm intent).
d2l.plt.gca().set_xlabel('shiyancishu')
d2l.plt.gca().set_ylabel('gusuangailv')
d2l.plt.legend();