Datawhale AI春训营 -- RNA结构预测(AI+创新药)代码记录

# 模型训练与验证
class RNADataset(torch.utils.data.Dataset):
    """In-memory dataset of RNA graphs built from coordinate/sequence pairs.

    Every ``.npy`` file in ``coords_dir`` is paired with the FASTA file of the
    same base name in ``seqs_dir``. Each pair is turned into a graph by
    ``RNAGraphBuilder.build_graph`` while the dataset is being constructed, so
    indexing afterwards is a plain list lookup.
    """

    def __init__(self, coords_dir, seqs_dir):
        """Eagerly load all samples and convert them to graphs.

        Args:
            coords_dir: Directory holding one ``.npy`` coordinate array per
                sample.
            seqs_dir: Directory holding ``<sample id>.fasta`` for every
                coordinate file.

        Raises:
            ValueError: If a FASTA file contains no record.
        """
        self.samples = []

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # sample indices stable so index-based splits are reproducible.
        for fname in sorted(os.listdir(coords_dir)):
            coord_path = os.path.join(coords_dir, fname)

            # Skip stray entries (e.g. ".ipynb_checkpoints" directories or
            # ".DS_Store") that would otherwise make np.load fail.
            if not fname.lower().endswith(".npy") or not os.path.isfile(coord_path):
                continue

            # Coordinates; expected shape [L, 7, 3] per the original note.
            coord = np.load(coord_path)
            # Replace NaN entries with 0 before the graph is built.
            coord = np.nan_to_num(coord, nan=0.0)

            # The matching sequence file shares the coordinate file's base name.
            seq_id = os.path.splitext(fname)[0]
            fasta_path = os.path.join(seqs_dir, f"{seq_id}.fasta")
            # Open the handle explicitly so it is closed right after the first
            # record is read instead of lingering until garbage collection.
            with open(fasta_path) as handle:
                record = next(SeqIO.parse(handle, "fasta"), None)
            if record is None:
                raise ValueError(f"No FASTA record found in {fasta_path}")

            # Convert the coordinate/sequence pair into a graph sample.
            graph = RNAGraphBuilder.build_graph(coord, str(record.seq))
            self.samples.append(graph)

    def __len__(self):
        """Return the number of loaded graphs."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        return self.samples[idx]

# 简单GNN模型
class GNNModel(nn.Module):
    """Two-layer GCN that emits one score per vocabulary entry for each node."""

    def __init__(self):
        """Build the encoder, the two graph convolutions and the classifier."""
        super().__init__()
        hidden = Config.hidden_dim
        # presumably 7 atoms x 3 coordinates flattened per node -- confirm
        # against the graph builder that produces data.x
        in_features = 7 * 3

        # Per-node feature encoder.
        self.encoder = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())

        # Graph convolution layers (width is kept at `hidden`).
        self.conv1 = GCNConv(hidden, hidden)
        self.conv2 = GCNConv(hidden, hidden)

        # Classification head: one logit per entry of Config.seq_vocab.
        self.cls_head = nn.Sequential(nn.Linear(hidden, len(Config.seq_vocab)))

    def forward(self, data):
        """Compute per-node logits for a graph batch.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Tensor of logits with ``len(Config.seq_vocab)`` columns.
        """
        # Encode raw node features to the hidden width.
        h = self.encoder(data.x)

        # Apply each graph convolution followed by a ReLU.
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h, data.edge_index))

        # Project node embeddings to class logits.
        return self.cls_head(h)

# 训练函数
def train(model, loader, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(batch)`` to produce logits.
        loader: Iterable of batches; each batch must support ``.to(device)``
            and expose the targets as ``batch.y``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(logits, batch.y)``.

    Returns:
        Sum of the batch losses divided by the number of batches, or ``0.0``
        when the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(loader): this also
    # works for iterables without __len__ and lets us detect an empty epoch.
    num_batches = 0
    for batch in loader:
        batch = batch.to(Config.device)
        optimizer.zero_grad()

        # Forward pass.
        logits = model(batch)

        # Loss against the targets stored on the batch.
        loss = criterion(logits, batch.y)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Avoid ZeroDivisionError when there was nothing to train on.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

# 评估函数
def evaluate(model, loader):
    """Return the fraction of labels in ``loader`` that the model predicts.

    Args:
        model: Module called as ``model(batch)`` to produce logits; the class
            with the highest logit along dim 1 is taken as the prediction.
        loader: Iterable of batches; each batch must support ``.to(device)``
            and expose the targets as ``batch.y``.

    Returns:
        ``correct / total`` over all labels seen, or ``0.0`` when the loader
        yields no labels (e.g. an empty validation split).
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(Config.device)
            logits = model(batch)
            preds = logits.argmax(dim=1)
            correct += (preds == batch.y).sum().item()
            total += batch.y.size(0)

    # Avoid ZeroDivisionError on an empty split.
    if total == 0:
        return 0.0
    return correct / total

# 主流程
if __name__ == "__main__":
    # Seed torch's global RNG for reproducibility.
    torch.manual_seed(Config.seed)

    # Load the full dataset.
    full_dataset = RNADataset("./RNAdesignv1/train/coords", "./RNAdesignv1/train/seqs")

    # 80% train; the remainder is halved between validation and test, with
    # any odd leftover sample going to the test split.
    train_size = int(0.8 * len(full_dataset))
    val_size = (len(full_dataset) - train_size) // 2
    test_size = len(full_dataset) - train_size - val_size
    train_set, val_set, test_set = torch.utils.data.random_split(
        full_dataset, [train_size, val_size, test_size])

    # Only the training loader is shuffled.
    train_loader = torch_geometric.loader.DataLoader(
        train_set, batch_size=Config.batch_size, shuffle=True)
    val_loader = torch_geometric.loader.DataLoader(val_set, batch_size=Config.batch_size)
    test_loader = torch_geometric.loader.DataLoader(test_set, batch_size=Config.batch_size)

    # Model, optimizer and loss.
    model = GNNModel().to(Config.device)
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even when validation accuracy is exactly 0.
    best_acc = float("-inf")
    checkpoint_saved = False
    for epoch in range(Config.epochs):
        train_loss = train(model, train_loader, optimizer, criterion)
        val_acc = evaluate(model, val_loader)

        print(f"Epoch {epoch+1}/{Config.epochs}")
        print(f"Train Loss: {train_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Keep the weights with the best validation accuracy seen so far.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "best_gnn_model.pth")
            checkpoint_saved = True

    # Reload the best weights only if this run actually wrote them; otherwise
    # (zero epochs) a stale file from an earlier run could be picked up or the
    # load would fail, so the current weights are evaluated as they are.
    if checkpoint_saved:
        model.load_state_dict(torch.load("best_gnn_model.pth",weights_only=True))
    test_acc = evaluate(model, test_loader)
    print(f"\nTest Accuracy: {test_acc:.4f}")

你可能感兴趣的:(人工智能,深度学习,机器学习)