TextRNN预测下一个句子

import gc
import torch
import numpy as np
from torch import nn
from torch import optim
from torch.utils.data import DataLoader,TensorDataset
dtypes =torch.FloatTensor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

构造数据

sentences =[ "i like you", "i love coffee", "i hate milk","i think you"]
word_list = " ".join(sentences).split()
print("*"*80)
print("word_list:",word_list)
vocab =list(set(word_list))
print("*"*80)
print("vocab:",vocab)
word2idx ={n:i for i,n in enumerate(vocab)}
print("*"*80)
print("word2idx:",word2idx)
idx2word = {i:n for i,n in enumerate(vocab)}
print("*"*80)
print("idx2word:",idx2word)
n_class = len(vocab)
print("*"*80)
print("n_class:",n_class)
print("*"*80)
********************************************************************************
word_list: ['i', 'like', 'you', 'i', 'love', 'coffee', 'i', 'hate', 'milk', 'i', 'think', 'you']
********************************************************************************
vocab: ['you', 'like', 'think', 'milk', 'i', 'love', 'coffee', 'hate']
********************************************************************************
word2idx: {'you': 0, 'like': 1, 'think': 2, 'milk': 3, 'i': 4, 'love': 5, 'coffee': 6, 'hate': 7}
********************************************************************************
idx2word: {0: 'you', 1: 'like', 2: 'think', 3: 'milk', 4: 'i', 5: 'love', 6: 'coffee', 7: 'hate'}
********************************************************************************
n_class: 8
********************************************************************************

构建Dataset

batch_size = 2
time_step = 2
n_hidden = 5

def make_data(sentences):
    inputs_ = []
    targets_ = []
    
    for sen in sentences:
        word = sen.split()
        inputs = [word2idx[n] for n in word[:-1]]
        target = word2idx[word[-1]]
        inputs_.append(np.eye(n_class)[inputs])
        targets_.append(target)
    return inputs_,targets_

inputs,targets= make_data(sentences)
inputs = torch.Tensor(inputs)
targets= torch.LongTensor(targets)

dataset = TensorDataset(inputs,targets)
train_loader = DataLoader(dataset,batch_size,shuffle=True)
:18: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ..\torch\csrc\utils\tensor_new.cpp:201.)
  inputs = torch.Tensor(inputs)
for x,y in train_loader:
    print("*"*40)
    print(x.shape,y.shape)
    print("*"*40)
    print(x,y)
    print("="*40)
# x:[batch_size,seq_len,features]
****************************************
torch.Size([2, 2, 8]) torch.Size([2])
****************************************
tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]]]) tensor([0, 3])
========================================
****************************************
torch.Size([2, 2, 8]) torch.Size([2])
****************************************
tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.]]]) tensor([0, 6])
========================================

定义网络

class text_rnn(nn.Module):
    def __init__(self):
        super(text_rnn,self).__init__()
        self.rnn = nn.RNN(n_class,n_hidden,batch_first = True)
        self.fc= nn.Linear(n_hidden,n_class)
    def forward(self,x):
        # x:[batch_size,seq_len,features]
#         h_0 = torch.zeros(2,2,8)
        out,hidden = self.rnn(x)
        #out:[batch_size,seq_len,n_hiddens*num_directions]
        #hidden:[num_layers*num_directions,batch_Size,n_hiddens]
        output = self.fc( out[-1])
        return output
#实例化
model = text_rnn().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.001)

loss_all = []
num_epochs = 500
for epoch in range(num_epochs):
    train_loss =0
    train_num = 0
    for step,(x,y) in enumerate(train_loader):
        x= x.to(device)
        y =y.to(device)
        z_hat=model.forward(x)
        loss = criterion(z_hat,y)
        loss.backward()
        optimizer.zero_grad()
        optimizer.step()
        train_loss+=loss.item()*len(y)
        train_num+=len(y)
    loss_all.append(train_loss/train_num)
    print(f"Epoch:{epoch+1} Loss:{loss_all[-1]:0.8f}")
    del x,y,loss,train_loss,train_num
    gc.collect()
    torch.cuda.empty_cache()
Epoch:1 Loss:1.91730571
Epoch:2 Loss:2.00270271
Epoch:3 Loss:1.91730571
Epoch:4 Loss:1.76124740
Epoch:5 Loss:1.91730571
Epoch:6 Loss:1.84664440
Epoch:7 Loss:1.91730571
Epoch:8 Loss:1.87935627
Epoch:9 Loss:1.76124740
Epoch:10 Loss:1.87935627
Epoch:11 Loss:1.88459384
Epoch:12 Loss:2.00270271
Epoch:13 Loss:1.91730571
Epoch:14 Loss:1.84664440
Epoch:15 Loss:1.84664440
Epoch:16 Loss:1.76124740
Epoch:17 Loss:2.00270271
Epoch:18 Loss:1.84664440
Epoch:19 Loss:1.87935627
Epoch:20 Loss:1.88459384
Epoch:21 Loss:2.00270271
Epoch:22 Loss:1.87935627
Epoch:23 Loss:1.76124740
Epoch:24 Loss:2.00270271
Epoch:25 Loss:1.88459384
Epoch:26 Loss:2.00270271
Epoch:27 Loss:1.91730571
Epoch:28 Loss:1.87935627
Epoch:29 Loss:1.84664440
Epoch:30 Loss:2.00270271
Epoch:31 Loss:1.84664440
Epoch:32 Loss:1.76124740
Epoch:33 Loss:1.84664440
Epoch:34 Loss:1.76124740
Epoch:35 Loss:2.00270271
Epoch:36 Loss:1.76124740
Epoch:37 Loss:2.00270271
Epoch:38 Loss:1.84664440
Epoch:39 Loss:2.00270271
Epoch:40 Loss:2.00270271
Epoch:41 Loss:1.84664440
Epoch:42 Loss:1.91730571
Epoch:43 Loss:1.87935627
Epoch:44 Loss:1.84664440
Epoch:45 Loss:2.00270271
Epoch:46 Loss:1.91730571
Epoch:47 Loss:1.76124740
Epoch:48 Loss:2.00270271
Epoch:49 Loss:1.91730571
Epoch:50 Loss:2.00270271
Epoch:51 Loss:1.91730571
Epoch:52 Loss:1.84664440
Epoch:53 Loss:1.76124740
Epoch:54 Loss:1.91730571
Epoch:55 Loss:2.00270271
Epoch:56 Loss:2.00270271
Epoch:57 Loss:1.84664440
Epoch:58 Loss:1.87935627
Epoch:59 Loss:1.84664440
Epoch:60 Loss:1.88459384
Epoch:61 Loss:1.76124740
Epoch:62 Loss:1.88459384
Epoch:63 Loss:1.91730571
Epoch:64 Loss:1.84664440
Epoch:65 Loss:2.00270271
Epoch:66 Loss:2.00270271
Epoch:67 Loss:1.84664440
Epoch:68 Loss:1.84664440
Epoch:69 Loss:1.87935627
Epoch:70 Loss:1.88459384
Epoch:71 Loss:1.84664440
Epoch:72 Loss:1.76124740
Epoch:73 Loss:1.84664440
Epoch:74 Loss:1.87935627
Epoch:75 Loss:1.88459384
Epoch:76 Loss:2.00270271
Epoch:77 Loss:1.87935627
Epoch:78 Loss:1.88459384
Epoch:79 Loss:1.87935627
Epoch:80 Loss:1.88459384
Epoch:81 Loss:1.91730571
Epoch:82 Loss:2.00270271
Epoch:83 Loss:2.00270271
Epoch:84 Loss:1.84664440
Epoch:85 Loss:1.84664440
Epoch:86 Loss:1.88459384
Epoch:87 Loss:1.76124740
Epoch:88 Loss:1.88459384
Epoch:89 Loss:1.84664440
Epoch:90 Loss:1.87935627
Epoch:91 Loss:1.87935627
Epoch:92 Loss:1.91730571
Epoch:93 Loss:2.00270271
Epoch:94 Loss:1.87935627
Epoch:95 Loss:2.00270271
Epoch:96 Loss:1.87935627
Epoch:97 Loss:1.84664440
Epoch:98 Loss:1.76124740
Epoch:99 Loss:1.84664440
Epoch:100 Loss:2.00270271
Epoch:101 Loss:1.76124740
Epoch:102 Loss:1.84664440
Epoch:103 Loss:1.87935627
Epoch:104 Loss:1.91730571
Epoch:105 Loss:1.87935627
Epoch:106 Loss:1.84664440
Epoch:107 Loss:1.91730571
Epoch:108 Loss:1.87935627
Epoch:109 Loss:1.91730571
Epoch:110 Loss:1.91730571
Epoch:111 Loss:1.88459384
Epoch:112 Loss:1.87935627
Epoch:113 Loss:1.91730571
Epoch:114 Loss:2.00270271
Epoch:115 Loss:1.88459384
Epoch:116 Loss:1.84664440
Epoch:117 Loss:2.00270271
Epoch:118 Loss:1.91730571
Epoch:119 Loss:2.00270271
Epoch:120 Loss:2.00270271
Epoch:121 Loss:1.87935627
Epoch:122 Loss:1.87935627
Epoch:123 Loss:1.91730571
Epoch:124 Loss:1.91730571
Epoch:125 Loss:2.00270271
Epoch:126 Loss:1.84664440
Epoch:127 Loss:1.91730571
Epoch:128 Loss:1.88459384
Epoch:129 Loss:2.00270271
Epoch:130 Loss:1.91730571
Epoch:131 Loss:2.00270271
Epoch:132 Loss:1.88459384
Epoch:133 Loss:1.88459384
Epoch:134 Loss:1.91730571
Epoch:135 Loss:1.76124740
Epoch:136 Loss:1.87935627
Epoch:137 Loss:1.88459384
Epoch:138 Loss:2.00270271
Epoch:139 Loss:1.76124740
Epoch:140 Loss:1.87935627
Epoch:141 Loss:1.91730571
Epoch:142 Loss:1.91730571
Epoch:143 Loss:1.76124740
Epoch:144 Loss:1.88459384
Epoch:145 Loss:1.84664440
Epoch:146 Loss:1.84664440
Epoch:147 Loss:2.00270271
Epoch:148 Loss:1.84664440
Epoch:149 Loss:1.84664440
Epoch:150 Loss:1.88459384
Epoch:151 Loss:1.87935627
Epoch:152 Loss:1.84664440
Epoch:153 Loss:2.00270271
Epoch:154 Loss:1.76124740
Epoch:155 Loss:1.87935627
Epoch:156 Loss:1.91730571
Epoch:157 Loss:1.76124740
Epoch:158 Loss:1.91730571
Epoch:159 Loss:1.76124740
Epoch:160 Loss:1.84664440
Epoch:161 Loss:1.87935627
Epoch:162 Loss:1.87935627
Epoch:163 Loss:1.88459384
Epoch:164 Loss:1.76124740
Epoch:165 Loss:1.87935627
Epoch:166 Loss:1.84664440
Epoch:167 Loss:1.88459384
Epoch:168 Loss:1.91730571
Epoch:169 Loss:2.00270271
Epoch:170 Loss:1.76124740
Epoch:171 Loss:1.76124740
Epoch:172 Loss:1.87935627
Epoch:173 Loss:2.00270271
Epoch:174 Loss:1.88459384
Epoch:175 Loss:1.87935627
Epoch:176 Loss:1.91730571
Epoch:177 Loss:1.88459384
Epoch:178 Loss:1.84664440
Epoch:179 Loss:1.87935627
Epoch:180 Loss:2.00270271
Epoch:181 Loss:2.00270271
Epoch:182 Loss:1.84664440
Epoch:183 Loss:1.76124740
Epoch:184 Loss:1.87935627
Epoch:185 Loss:1.91730571
Epoch:186 Loss:1.91730571
Epoch:187 Loss:1.91730571
Epoch:188 Loss:1.88459384
Epoch:189 Loss:1.87935627
Epoch:190 Loss:1.87935627
Epoch:191 Loss:2.00270271
Epoch:192 Loss:2.00270271
Epoch:193 Loss:1.84664440
Epoch:194 Loss:1.76124740
Epoch:195 Loss:1.88459384
Epoch:196 Loss:2.00270271
Epoch:197 Loss:1.87935627
Epoch:198 Loss:1.88459384
Epoch:199 Loss:1.88459384
Epoch:200 Loss:1.76124740
Epoch:201 Loss:1.84664440
Epoch:202 Loss:2.00270271
Epoch:203 Loss:2.00270271
Epoch:204 Loss:1.91730571
Epoch:205 Loss:1.76124740
Epoch:206 Loss:1.76124740
Epoch:207 Loss:1.91730571
Epoch:208 Loss:1.76124740
Epoch:209 Loss:1.91730571
Epoch:210 Loss:1.88459384
Epoch:211 Loss:2.00270271
Epoch:212 Loss:1.91730571
Epoch:213 Loss:2.00270271
Epoch:214 Loss:2.00270271
Epoch:215 Loss:1.76124740
Epoch:216 Loss:1.84664440
Epoch:217 Loss:1.84664440
Epoch:218 Loss:1.88459384
Epoch:219 Loss:1.76124740
Epoch:220 Loss:1.88459384
Epoch:221 Loss:2.00270271
Epoch:222 Loss:1.88459384
Epoch:223 Loss:1.87935627
Epoch:224 Loss:1.87935627
Epoch:225 Loss:2.00270271
Epoch:226 Loss:2.00270271
Epoch:227 Loss:1.76124740
Epoch:228 Loss:1.76124740
Epoch:229 Loss:1.76124740
Epoch:230 Loss:1.91730571
Epoch:231 Loss:2.00270271
Epoch:232 Loss:2.00270271
Epoch:233 Loss:1.76124740
Epoch:234 Loss:1.76124740
Epoch:235 Loss:1.76124740
Epoch:236 Loss:1.87935627
Epoch:237 Loss:1.76124740
Epoch:238 Loss:1.84664440
Epoch:239 Loss:1.91730571
Epoch:240 Loss:1.87935627
Epoch:241 Loss:1.76124740
Epoch:242 Loss:2.00270271
Epoch:243 Loss:1.84664440
Epoch:244 Loss:1.76124740
Epoch:245 Loss:1.87935627
Epoch:246 Loss:2.00270271
Epoch:247 Loss:1.76124740
Epoch:248 Loss:1.87935627
Epoch:249 Loss:1.91730571
Epoch:250 Loss:1.91730571
Epoch:251 Loss:1.76124740
Epoch:252 Loss:1.87935627
Epoch:253 Loss:2.00270271
Epoch:254 Loss:1.76124740
Epoch:255 Loss:1.91730571
Epoch:256 Loss:1.84664440
Epoch:257 Loss:1.87935627
Epoch:258 Loss:1.87935627
Epoch:259 Loss:1.76124740
Epoch:260 Loss:1.87935627
Epoch:261 Loss:1.84664440
Epoch:262 Loss:1.91730571
Epoch:263 Loss:1.91730571
Epoch:264 Loss:1.87935627
Epoch:265 Loss:1.84664440
Epoch:266 Loss:1.88459384
Epoch:267 Loss:1.87935627
Epoch:268 Loss:1.91730571
Epoch:269 Loss:1.87935627
Epoch:270 Loss:1.88459384
Epoch:271 Loss:1.88459384
Epoch:272 Loss:1.87935627
Epoch:273 Loss:1.87935627
Epoch:274 Loss:1.88459384
Epoch:275 Loss:2.00270271
Epoch:276 Loss:2.00270271
Epoch:277 Loss:2.00270271
Epoch:278 Loss:1.88459384
Epoch:279 Loss:1.91730571
Epoch:280 Loss:2.00270271
Epoch:281 Loss:1.88459384
Epoch:282 Loss:2.00270271
Epoch:283 Loss:2.00270271
Epoch:284 Loss:2.00270271
Epoch:285 Loss:1.88459384
Epoch:286 Loss:2.00270271
Epoch:287 Loss:2.00270271
Epoch:288 Loss:1.87935627
Epoch:289 Loss:1.87935627
Epoch:290 Loss:1.87935627
Epoch:291 Loss:1.84664440
Epoch:292 Loss:1.76124740
Epoch:293 Loss:1.88459384
Epoch:294 Loss:2.00270271
Epoch:295 Loss:1.87935627
Epoch:296 Loss:1.84664440
Epoch:297 Loss:1.87935627
Epoch:298 Loss:1.88459384
Epoch:299 Loss:1.87935627
Epoch:300 Loss:1.87935627
Epoch:301 Loss:1.76124740
Epoch:302 Loss:1.87935627
Epoch:303 Loss:1.76124740
Epoch:304 Loss:1.88459384
Epoch:305 Loss:1.84664440
Epoch:306 Loss:1.91730571
Epoch:307 Loss:1.84664440
Epoch:308 Loss:1.87935627
Epoch:309 Loss:1.87935627
Epoch:310 Loss:2.00270271
Epoch:311 Loss:1.87935627
Epoch:312 Loss:1.76124740
Epoch:313 Loss:1.88459384
Epoch:314 Loss:1.87935627
Epoch:315 Loss:1.88459384
Epoch:316 Loss:1.76124740
Epoch:317 Loss:1.76124740
Epoch:318 Loss:1.76124740
Epoch:319 Loss:1.76124740
Epoch:320 Loss:1.76124740
Epoch:321 Loss:1.91730571
Epoch:322 Loss:1.76124740
Epoch:323 Loss:1.91730571
Epoch:324 Loss:2.00270271
Epoch:325 Loss:1.84664440
Epoch:326 Loss:1.88459384
Epoch:327 Loss:1.76124740
Epoch:328 Loss:1.84664440
Epoch:329 Loss:1.76124740
Epoch:330 Loss:1.88459384
Epoch:331 Loss:2.00270271
Epoch:332 Loss:1.91730571
Epoch:333 Loss:1.91730571
Epoch:334 Loss:2.00270271
Epoch:335 Loss:1.76124740
Epoch:336 Loss:1.91730571
Epoch:337 Loss:1.76124740
Epoch:338 Loss:2.00270271
Epoch:339 Loss:2.00270271
Epoch:340 Loss:2.00270271
Epoch:341 Loss:1.87935627
Epoch:342 Loss:1.91730571
Epoch:343 Loss:1.76124740
Epoch:344 Loss:1.76124740
Epoch:345 Loss:1.84664440
Epoch:346 Loss:1.91730571
Epoch:347 Loss:1.87935627
Epoch:348 Loss:1.84664440
Epoch:349 Loss:1.88459384
Epoch:350 Loss:2.00270271
Epoch:351 Loss:2.00270271
Epoch:352 Loss:1.91730571
Epoch:353 Loss:1.87935627
Epoch:354 Loss:1.76124740
Epoch:355 Loss:1.84664440
Epoch:356 Loss:2.00270271
Epoch:357 Loss:1.84664440
Epoch:358 Loss:1.76124740
Epoch:359 Loss:1.84664440
Epoch:360 Loss:1.84664440
Epoch:361 Loss:1.84664440
Epoch:362 Loss:1.88459384
Epoch:363 Loss:1.76124740
Epoch:364 Loss:1.84664440
Epoch:365 Loss:2.00270271
Epoch:366 Loss:1.84664440
Epoch:367 Loss:1.76124740
Epoch:368 Loss:1.91730571
Epoch:369 Loss:2.00270271
Epoch:370 Loss:1.84664440
Epoch:371 Loss:1.88459384
Epoch:372 Loss:1.84664440
Epoch:373 Loss:1.88459384
Epoch:374 Loss:2.00270271
Epoch:375 Loss:1.91730571
Epoch:376 Loss:1.76124740
Epoch:377 Loss:1.91730571
Epoch:378 Loss:1.84664440
Epoch:379 Loss:1.87935627
Epoch:380 Loss:1.76124740
Epoch:381 Loss:2.00270271
Epoch:382 Loss:1.87935627
Epoch:383 Loss:1.84664440
Epoch:384 Loss:1.88459384
Epoch:385 Loss:2.00270271
Epoch:386 Loss:2.00270271
Epoch:387 Loss:1.76124740
Epoch:388 Loss:1.88459384
Epoch:389 Loss:1.84664440
Epoch:390 Loss:1.91730571
Epoch:391 Loss:1.91730571
Epoch:392 Loss:1.91730571
Epoch:393 Loss:2.00270271
Epoch:394 Loss:1.84664440
Epoch:395 Loss:1.87935627
Epoch:396 Loss:1.76124740
Epoch:397 Loss:1.88459384
Epoch:398 Loss:1.88459384
Epoch:399 Loss:1.76124740
Epoch:400 Loss:1.88459384
Epoch:401 Loss:1.87935627
Epoch:402 Loss:1.87935627
Epoch:403 Loss:1.88459384
Epoch:404 Loss:1.91730571
Epoch:405 Loss:1.76124740
Epoch:406 Loss:1.91730571
Epoch:407 Loss:2.00270271
Epoch:408 Loss:2.00270271
Epoch:409 Loss:2.00270271
Epoch:410 Loss:1.87935627
Epoch:411 Loss:1.91730571
Epoch:412 Loss:1.88459384
Epoch:413 Loss:1.91730571
Epoch:414 Loss:1.84664440
Epoch:415 Loss:1.88459384
Epoch:416 Loss:1.91730571
Epoch:417 Loss:1.84664440
Epoch:418 Loss:2.00270271
Epoch:419 Loss:1.91730571
Epoch:420 Loss:1.76124740
Epoch:421 Loss:1.87935627
Epoch:422 Loss:1.88459384
Epoch:423 Loss:2.00270271
Epoch:424 Loss:1.84664440
Epoch:425 Loss:1.76124740
Epoch:426 Loss:1.76124740
Epoch:427 Loss:1.76124740
Epoch:428 Loss:1.88459384
Epoch:429 Loss:1.88459384
Epoch:430 Loss:1.76124740
Epoch:431 Loss:1.91730571
Epoch:432 Loss:1.87935627
Epoch:433 Loss:2.00270271
Epoch:434 Loss:1.84664440
Epoch:435 Loss:1.91730571
Epoch:436 Loss:1.88459384
Epoch:437 Loss:2.00270271
Epoch:438 Loss:1.91730571
Epoch:439 Loss:1.84664440
Epoch:440 Loss:1.84664440
Epoch:441 Loss:1.76124740
Epoch:442 Loss:1.76124740
Epoch:443 Loss:1.88459384
Epoch:444 Loss:1.76124740
Epoch:445 Loss:2.00270271
Epoch:446 Loss:1.76124740
Epoch:447 Loss:1.87935627
Epoch:448 Loss:1.84664440
Epoch:449 Loss:1.84664440
Epoch:450 Loss:1.88459384
Epoch:451 Loss:1.88459384
Epoch:452 Loss:1.76124740
Epoch:453 Loss:1.87935627
Epoch:454 Loss:2.00270271
Epoch:455 Loss:2.00270271
Epoch:456 Loss:1.91730571
Epoch:457 Loss:1.88459384
Epoch:458 Loss:1.87935627
Epoch:459 Loss:1.76124740
Epoch:460 Loss:1.87935627
Epoch:461 Loss:1.88459384
Epoch:462 Loss:1.84664440
Epoch:463 Loss:1.87935627
Epoch:464 Loss:1.76124740
Epoch:465 Loss:1.76124740
Epoch:466 Loss:2.00270271
Epoch:467 Loss:1.88459384
Epoch:468 Loss:2.00270271
Epoch:469 Loss:1.91730571
Epoch:470 Loss:1.84664440
Epoch:471 Loss:1.91730571
Epoch:472 Loss:1.76124740
Epoch:473 Loss:1.88459384
Epoch:474 Loss:1.88459384
Epoch:475 Loss:1.87935627
Epoch:476 Loss:1.91730571
Epoch:477 Loss:1.84664440
Epoch:478 Loss:1.91730571
Epoch:479 Loss:1.87935627
Epoch:480 Loss:1.84664440
Epoch:481 Loss:1.88459384
Epoch:482 Loss:2.00270271
Epoch:483 Loss:1.84664440
Epoch:484 Loss:1.88459384
Epoch:485 Loss:1.91730571
Epoch:486 Loss:1.88459384
Epoch:487 Loss:1.91730571
Epoch:488 Loss:1.76124740
Epoch:489 Loss:2.00270271
Epoch:490 Loss:2.00270271
Epoch:491 Loss:1.76124740
Epoch:492 Loss:1.88459384
Epoch:493 Loss:1.84664440
Epoch:494 Loss:1.88459384
Epoch:495 Loss:1.84664440
Epoch:496 Loss:2.00270271
Epoch:497 Loss:1.87935627
Epoch:498 Loss:1.76124740
Epoch:499 Loss:1.88459384
Epoch:500 Loss:1.84664440
test_text = [sen.split()[:2] for sen in sentences]
test_text
[['i', 'like'], ['i', 'love'], ['i', 'hate'], ['i', 'think']]
tests =[]
for word in test_text[:]:
    inputs = [word2idx[n] for n in word]
    
    tests.append(np.eye(n_class)[inputs])


tests = torch.tensor(tests,dtype = torch.float32).to(device)
test_dataset = TensorDataset(tests)

test_loader = DataLoader(test_dataset,batch_size,shuffle=True)
print("done!")
done!
for step,x in enumerate(test_loader):
    print(step,x)
0 [tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.]]], device='cuda:0')]
1 [tensor([[[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.]],

        [[0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.]]], device='cuda:0')]
model.eval()
for step,x in enumerate(test_loader):
    predict = model(x[0]).data.max(1, keepdim=True)[1]
    print( "->",[idx2word[n.item()] for n in predict.cpu().data.squeeze()])
-> ['you', 'like']
-> ['you', 'milk']

你可能感兴趣的:(深度学习,pytorch,深度学习,自然语言处理)