20250301-代码笔记-test_n100

文章目录

  • 前言
  • 一、参数解析
    • 1.1具体参数含义
      • 1. 环境参数 (`env_params`)
      • 2. 模型参数 (`model_params`)
      • 3. 测试参数 (tester_params_regret)
      • 4. 日志参数 (logger_params)
    • 1.2分析
      • 1.2.1如何在脚本运行时输入model路径
    • 1.3代码
  • 二、函数 main()
    • 函数解析
    • 函数代码
  • 三、函数def _set_debug_mode()
    • 函数解析
    • 函数代码
  • 四、函数def _print_config()
    • 函数解析
    • 函数代码
  • 附录
    • 代码(全)


前言

讲解脚本test_n100.py中的代码。

/home/tang/RL_exa/NCO_code-main/single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/test_n100.py


一、参数解析

1.1具体参数含义

1. 环境参数 (env_params)

  • problem_size: 这个参数指定了问题的规模。
  • pomo_size:智能体数量

2. 模型参数 (model_params)

  • embedding_dim: 该参数指定嵌入层的维度,通常在神经网络中用来将输入数据映射到一个低维度的向量空间。
  • sqrt_embedding_dim: 这是 embedding_dim 的平方根,通常在某些模型计算中作为特征。
  • encoder_layer_num: 该参数表示模型中的编码器层数。
  • qkv_dim: 该参数用于定义查询(Query)、键(Key)和值(Value)在自注意力机制中的维度。
  • head_num: 在多头注意力机制中,头的数量决定了模型能够并行计算多少组查询、键和值。
  • logit_clipping: 该参数通常用于数值稳定性,特别是在生成模型输出(如Logits)时进行裁剪。设置值为 10,可能是为了防止输出值过大或过小,确保训练时的数值稳定。
  • ff_hidden_dim: 该参数表示前馈神经网络(Feed-Forward Network)中的隐藏层维度。
  • eval_type: 这个参数定义了模型评估时使用的策略。这里的 argmax 表示评估时选择最大概率的动作或预测,通常用于选择最优解或决策。

3. 测试参数 (tester_params_regret)

  • use_cuda: 这个参数指示是否使用 CUDA(即是否使用 GPU)进行加速运算。 它的值是 USE_CUDA,这取决于上文定义的 USE_CUDA,它会在 DEBUG_MODEFalse 时启用 GPU 加速,反之则禁用。
  • cuda_device_num: 这个参数指定使用的 CUDA 设备编号。
  • model_load
    • path: 模型预训练文件的路径。
    • epoch: 指定加载的模型训练到哪一轮。
  • test_episodes: 这个参数设置了测试时的总回合数。10000 表示在测试过程中进行 10000 次回合的测试。
  • test_batch_size: 这个参数指定了每个测试批次的大小,意味着每个批次中包含 1000 个样本。
  • augmentation_enable: 这是一个布尔值,表示是否启用数据增强。
  • aug_factor: 这个参数指定了数据增强的因子。8 表示在数据增强时,数据集会被扩展为原来的 8 倍,增加了训练数据的多样性。
  • aug_batch_size: 如果启用了数据增强,aug_batch_size 参数指定了增强后的每个批次的大小。这里设置为 400,意味着每个增强的批次将包含 400 个样本。
  • test_data_load: 这是一个字典,用来配置测试数据加载的参数:
    • enable: 这个参数指示是否启用从文件加载测试数据。
    • filename: 指定包含测试数据的文件路径。
  • if tester_params_regret['augmentation_enable']:: 这一段检查是否启用了数据增强(augmentation_enable)。如果启用了数据增强,测试批次大小(test_batch_size)会被重新设置为 aug_batch_size(即 400)。

4. 日志参数 (logger_params)

  • log_file: 这个字典包含了日志文件的配置:
    • desc: 这是日志的描述信息,通常用于标识当前测试的目标或描述。这些描述性信息有助于辨别不同的日志文件。这里设置为 'test__CVRP100',表明日志与 CVRP100 测试相关。
    • filename: 这是日志文件的名称,设置为 'log.txt',意味着所有日志信息将被记录到名为 log.txt 的文件中。

1.2分析

1.2.1如何在脚本运行时输入model路径


tester_params_regret = {

    'model_load': {
        'path': '../../pretrained/vrp100',
        
        'epoch': 8100,
    },
   # 其他参数
    },
}

设置初始参数model_loadpathepoch
pathepoch是生成存储模型地址的必要参数,将参数给到CVRPTester.py中的class CVRPTester:__init__(self,env_params,model_params,tester_params)

single_objective/LCH-Regret/Regret-POMO/CVRP/POMO/CVRPTester.py

        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)

1.3代码

##########################################################################################
# parameters

env_params = {
    'problem_size': 100,
    'pomo_size': 100,
}

model_params = {
    'embedding_dim': 128,
    'sqrt_embedding_dim': 128**(1/2),
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'eval_type': 'argmax',
}

tester_params_regret = {
    'use_cuda': USE_CUDA,
    'cuda_device_num': CUDA_DEVICE_NUM,
    'model_load': {
        'path': '../../pretrained/vrp100',
        'epoch': 8100,
    },
    'test_episodes': 10000,
    'test_batch_size': 1000,
    'augmentation_enable': True,
    'aug_factor': 8,
    'aug_batch_size': 400,
    'test_data_load': {
        'enable': True,
        'filename': '../../../data/vrp100.pt'
    },
}

if tester_params_regret['augmentation_enable']:
    tester_params_regret['test_batch_size'] = tester_params_regret['aug_batch_size']

logger_params = {
    'log_file': {
        'desc': 'test__CVRP100',
        'filename': 'log.txt'
    }
}


二、函数 main()

函数解析

执行流程图链接

20250301-代码笔记-test_n100_第1张图片

函数代码

def main():

    if DEBUG_MODE:
        _set_debug_mode()

    create_logger(**logger_params)
    _print_config()


    tester_regret = Tester_regret(env_params=env_params,
                    model_params=model_params,
                    tester_params=tester_params_regret)

    copy_all_src(tester_regret.result_folder)

    tester_regret.run()


三、函数def _set_debug_mode()

函数解析

根据设置调整测试时的参数。

函数代码

def _set_debug_mode():
    global tester_params_regret
    tester_params_regret['test_episodes'] = 100



四、函数def _print_config()

函数解析

打印配置信息 (_print_config)。

函数代码

def _print_config():
    logger = logging.getLogger('root')
    logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE))
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))
    [logger.info(g_key + "{}".format(globals()[g_key])) for g_key in globals().keys() if g_key.endswith('params')]



附录

代码(全)

##########################################################################################
# Machine Environment Config

DEBUG_MODE = False
USE_CUDA = not DEBUG_MODE
CUDA_DEVICE_NUM = 0


##########################################################################################
# Path Config

import os
import sys

os.chdir(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, "..")  # for problem_def
sys.path.insert(0, "../..")  # for utils


##########################################################################################
# import
import torch
import logging
from utils.utils import create_logger, copy_all_src

from CVRPTester import CVRPTester as Tester_regret

##########################################################################################
# parameters

env_params = {
    'problem_size': 100,
    'pomo_size': 100,
}

model_params = {
    'embedding_dim': 128,
    'sqrt_embedding_dim': 128**(1/2),
    'encoder_layer_num': 6,
    'qkv_dim': 16,
    'head_num': 8,
    'logit_clipping': 10,
    'ff_hidden_dim': 512,
    'eval_type': 'argmax',
}

tester_params_regret = {
    'use_cuda': USE_CUDA,
    'cuda_device_num': CUDA_DEVICE_NUM,
    'model_load': {
        'path': '../../pretrained/vrp100',
        'epoch': 8100,
    },
    'test_episodes': 10000,
    'test_batch_size': 1000,
    'augmentation_enable': True,
    'aug_factor': 8,
    'aug_batch_size': 400,
    'test_data_load': {
        'enable': True,
        'filename': '../../../data/vrp100.pt'
    },
}

if tester_params_regret['augmentation_enable']:
    tester_params_regret['test_batch_size'] = tester_params_regret['aug_batch_size']

logger_params = {
    'log_file': {
        'desc': 'test__CVRP100',
        'filename': 'log.txt'
    }
}

##########################################################################################
# main

def main():

    if DEBUG_MODE:
        _set_debug_mode()

    create_logger(**logger_params)
    _print_config()


    tester_regret = Tester_regret(env_params=env_params,
                    model_params=model_params,
                    tester_params=tester_params_regret)

    copy_all_src(tester_regret.result_folder)

    tester_regret.run()


def _set_debug_mode():
    global tester_params_regret
    tester_params_regret['test_episodes'] = 100


def _print_config():
    logger = logging.getLogger('root')
    logger.info('DEBUG_MODE: {}'.format(DEBUG_MODE))
    logger.info('USE_CUDA: {}, CUDA_DEVICE_NUM: {}'.format(USE_CUDA, CUDA_DEVICE_NUM))
    [logger.info(g_key + "{}".format(globals()[g_key])) for g_key in globals().keys() if g_key.endswith('params')]



##########################################################################################

if __name__ == "__main__":
    main()

你可能感兴趣的:(代码学习笔记,笔记)