具体可以参考我的GitHub,那是一个Keras版本的实现,Keras也是Coursera作业使用的框架,我稍稍改编了一下,
里面有实现的效果以及模型的结构,这里就不多说了,代码也很简单,容易理解。通过阅读源码,你会发现各种教程里都不太可能讲到的东西。成功运行官方教程里给出的MNIST程序,并不代表你就会了这个框架,甚至连入门都算不上!一定要通读文档和源码!!
def get_batchs(X, Y, batchsize=3, batchnum=0):
    """Return the ``batchnum``-th mini-batch of ``X`` and ``Y`` as NumPy arrays.

    Args:
        X: Sliceable sequence of samples (NumPy array, list, ...).
        Y: Sliceable sequence of labels, indexed in step with ``X``.
        batchsize: Number of items per batch.
        batchnum: Zero-based index of the batch to extract.

    Returns:
        A ``(bx, by)`` tuple of ``np.ndarray`` holding the items at positions
        ``[batchnum * batchsize, batchnum * batchsize + batchsize)``. The last
        batch is shorter when the data does not divide evenly, and both
        arrays are empty once ``batchnum`` runs past the end of the data.
    """
    start = batchnum * batchsize
    stop = start + batchsize
    # Slicing clamps at the end of the sequence, so the final (short) batch
    # needs no special case and ``X`` does not need a ``.shape`` attribute.
    return np.array(X[start:stop]), np.array(Y[start:stop])
class MNIST(data.Dataset):
    """MNIST dataset read from pre-processed ``training.pt`` / ``test.pt`` files.

    The files are looked up under ``<root>/processed/``; each one is loaded
    with ``torch.load`` and unpacked into a ``(data, labels)`` pair.

    Args:
        root: Dataset root directory (``~`` is expanded).
        train: If true, load the training split; otherwise the test split.
        transform: Optional callable applied to the PIL image of each sample.
        target_transform: Optional callable applied to each label.
        download: If true, call :meth:`download` before checking the files.

    Raises:
        RuntimeError: If the processed files are not present under ``root``.
    """

    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Only the requested split is loaded into memory.
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        The stored sample is converted to a PIL image before ``transform``
        is applied; ``target_transform`` is applied to the label.
        """
        # Pick the sample from whichever split was loaded in __init__.
        # Without this lookup ``img`` and ``target`` are read before they
        # are assigned and every access fails with UnboundLocalError.
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # Convert to a PIL image so the sample matches what image
        # transforms usually expect.
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        """Return True only if BOTH processed files exist under ``root``."""
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Intentionally a no-op: the download logic was removed from this copy."""
        pass
上面是主要代码,download的实现被我删了,这里也不需要。接下来我们直接进行修改。怎么改呢?首先看__init__方法:赋值部分都不怎么需要改,虽然本例中这几个属性都用不到;download不需要,所以上面可以直接改成pass,然后把下面的实现删去:
class emojiDataset(Data.Dataset):
    """Sentence/label dataset read from ``<root>/mytrain.csv``.

    Only the training split is implemented. With ``train=False`` the dataset
    is empty: ``len()`` is 0 and indexing raises ``IndexError``.

    Args:
        root: Directory that contains ``mytrain.csv``.
        train: If true, load the training split; otherwise load nothing.
        transform: Optional callable applied to each sample.
        target_transform: Optional callable applied to each label.
        download: Accepted for signature parity with ``MNIST``; ignored.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        if self.train:
            sentences, labels = read_csv(os.path.join(self.root, 'mytrain.csv'))
            # NOTE(review): 10 is presumably the maximum sentence length
            # expected by sentences_to_indices -- confirm against the helper.
            self.train_data = sentences_to_indices(sentences, word_to_index, 10)
            # Labels are kept exactly as read_csv returns them
            # (no one-hot conversion is applied here).
            self.train_labels = labels

    def __getitem__(self, index):
        """Return ``(sample, target)`` for the training item at ``index``.

        Raises:
            IndexError: If the dataset was created with ``train=False``,
                since no test split is loaded.
        """
        if not self.train:
            # Consistent with __len__() == 0: an empty dataset has no valid
            # index. Previously this path failed with UnboundLocalError.
            raise IndexError('emojiDataset has no test split; index out of range')

        sample, target = self.train_data[index], self.train_labels[index]

        # Apply the user-supplied transforms instead of silently ignoring them.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of loaded training samples (0 when ``train=False``)."""
        # Derive the size from the data rather than a hard-coded 180 so the
        # dataset stays correct for a CSV with any number of rows.
        return len(self.train_data) if self.train else 0