DeepSeek源码解析(1)

下载github的 DeepSeek-V3-main源码,目录如下
DeepSeek源码解析(1)_第1张图片
文章适合入门小白学习,因为我也是小白,本来作为一名前端开发,因为行业不好混所以跑来学ai的。

初步看它的代码并不多,主要是inference目录,

convert.py

# 1. 导入标准库
import os # os 是Python的标准库之一,提供了与操作系统交互的功能,比如文件路径操作、环境变量管理等。
import shutil # shutil 也是Python的标准库,提供了高级的文件操作功能,比如复制、移动、删除文件或目录等。
from argparse import ArgumentParser # argparse 是Python的标准库,用于解析命令行参数。ArgumentParser 是 argparse 中的一个类,用于定义和解析命令行参数。
from glob import glob # glob 是Python的标准库,用于查找符合特定模式的文件路径。glob 函数可以根据通配符(如 *.txt)来匹配文件名。
from tqdm import tqdm, trange # tqdm 是一个用于显示进度条的第三方库。tqdm 函数可以用于迭代时显示进度条,trange 是 tqdm 的一个简化版本,专门用于 range 循环。

#2. 导入深度学习相关库
import torch #torch 是 PyTorch 库,是一个广泛使用的深度学习框架,提供了张量操作、自动求导、神经网络模型等功能。
from safetensors.torch import safe_open, save_file #safetensors 是一个用于安全地加载和保存张量的库
# safe_open 用于安全地打开张量文件,save_file 用于安全地保存张量文件。

mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
}

#定义名为 main 的函数,用于将模型检查点文件转换为指定格式并保存
def main(hf_ckpt_path, save_path, n_experts, mp):

    # hf_ckpt_path: 字符串类型,表示输入检查点文件所在的目录路径。
    # save_path: 字符串类型,表示转换后的检查点文件将要保存的目录路径。
    # n_experts: 整数类型,表示模型中专家(experts)的总数。
    # mp: 整数类型,表示模型并行(Model Parallelism)的因子。
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.
        
    Returns:
        None
    """

    # 这行代码设置了 PyTorch 使用的线程数为 8。这可以控制 PyTorch 在计算时使用的 CPU 线程数量,以优化性能。
    torch.set_num_threads(8)
    #  Python 的整数除法运算符 //
    # n_local_experts 表示每个模型并行分区中的专家数量。通过将总专家数 n_experts 除以模型并行因子 mp 得到。
    n_local_experts = n_experts // mp
    # state_dicts 是一个包含 mp 个空字典的列表。每个字典将用于存储一个模型并行分区中的模型状态。
    state_dicts = [{} for _ in range(mp)]

    # 这段代码的作用是遍历指定目录下所有以 .safetensors 结尾的文件,并逐个打开这些文件进行处理。
    # os.path.join(hf_ckpt_path, "*.safetensors"):将 hf_ckpt_path 路径与 *.safetensors 进行拼接,生成一个匹配模式,表示在 hf_ckpt_path 目录下查找所有以 .safetensors 结尾的文件。
    # glob(...):使用 glob 模块根据上述匹配模式查找所有符合条件的文件路径,返回一个文件路径列表。
    # tqdm 是一个用于显示进度条的库。这里用 tqdm 包裹 glob 的结果,表示在遍历文件路径时显示一个进度条,方便用户了解处理进度。
    # for file_path in ...:遍历 glob 返回的文件路径列表,每次循环中 file_path 变量会指向当前处理的文件路径。
    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        # with ... as f:使用 with 语句打开文件,确保文件在使用完毕后自动关闭。f 是文件对象,可以通过它访问文件内容。
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                # 这段代码是用于处理神经网络模型参数的Python代码
                if "model.layers.61" in name:  #  如果变量name中包含字符串"model.layers.61"
                    continue
                param: torch.Tensor = f.get_tensor(name) #  从某个函数或对象f中获取名为name的张量(Tensor),并将其赋值给变量param
                if name.startswith("model."):
                    name = name[len("model."):] # 则去掉这个前缀,只保留后面的部分
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping" # 断言检查键值(key)是否存在于字典(mapping)中,如果不存在,则抛出异常,并显示错误信息。
                new_key, dim = mapping[key]
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name: # 检查变量name是否包含字符串"experts",并且不包含字符串"shared_experts"
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True) # 用于递归创建目录 save_path为创建的目录路径
    # exist_ok 设置为 True 时,如果目标目录已经存在,则不会抛出异常。如果设置为 False(默认值),当目标目录已经存在时会抛出一个异常。

    for i in trange(mp):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)

# 这是Python脚本的标准入口点
if __name__ == "__main__":
    parser = ArgumentParser() # 创建了一个ArgumentParser对象,用于解析命令行参数。
    # 添加参数
    # 这个参数通常用于指定Hugging Face模型检查点的路径。
    parser.add_argument("--hf-ckpt-path", type=str, required=True) 
    # 这个参数通常用于指定保存模型或结果的路径
    parser.add_argument("--save-path", type=str, required=True)
    # 这个参数通常用于指定专家的数量。
    parser.add_argument("--n-experts", type=int, required=True)
    # 这个参数通常用于指定模型并行的数量。
    parser.add_argument("--model-parallel", type=int, required=True)
    # 解析命令行参数,并将结果存储在args对象中。args对象包含了所有通过命令行传递的参数值。
    args = parser.parse_args()
    # 断言语句 用于确保n_experts(专家数量)能够被model_parallel(模型并行数量)整除。如果不能整除,程序将抛出AssertionError,并显示错误信
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    # 调用上面定义的main函数,并将解析后的命令行参数传递给它。main函数通常是脚本的主要逻辑部分,负责执行实际的任务。
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)

Python 知识补充

一、在 Python 中,导入库的常见格式,它们的使用场景和目的有所不同。以下是它们的区别和适用场景:
1、 import 库名

import math

result = math.sqrt(16)
print(result)  # 输出: 4.0

优点:明确知道函数或类的来源,避免命名冲突。代码可读性高,因为可以通过模块名清楚地知道函数或类的来源。
适用场景:当你需要导入整个模块,并且模块中的多个函数或类都会被使用时。
当你希望避免命名冲突时。

  1. from 模块名 import 库名
    这种格式用于从模块中导入特定的函数、类或变量。导入后,你可以直接使用这些函数、类或变量,而不需要通过模块名。
from math import sqrt

result = sqrt(16)
print(result)  # 输出: 4.0

优点:代码更简洁,因为可以直接使用函数或类,而不需要写模块名。
可以减少代码的冗余。
适用场景:当你只需要使用模块中的少数几个函数或类时。
当你希望代码更简洁时。

  1. from 模块名 import *
    这种格式用于导入模块中的所有内容。导入后,你可以直接使用模块中的所有函数、类或变量,而不需要通过模块名。
from math import *

result = sqrt(16)
print(result)  # 输出: 4.0

注意:
这种方式虽然方便,但容易引起命名冲突,尤其是当导入的模块中有与当前命名空间中相同名称的函数或变量时。
通常不推荐使用这种方式,除非你非常清楚自己在做什么。
总结
使用 import name 时,你需要通过模块名来访问其中的内容,适合导入整个模块。
使用 from name import name 时,你可以直接使用导入的内容,适合导入模块中的特定部分。
使用 from name import * 时,你可以直接使用模块中的所有内容,但容易引起命名冲突,需谨慎使用。
选择哪种导入方式取决于你的具体需求和代码的可读性。

二、 整数除法 //

result = 7 // 3  # result 的值为 2
result = 10 // 2  # result 的值为 5
result = 5 // 2  # result 的值为 2

三. 列表推导式(List Comprehension)
列表推导式是 Python 中一种简洁的创建列表的方式。 是 Python 中一种简洁且强大的语法,用于快速创建列表。它允许你在一行代码中生成列表,通常比使用传统的 for 循环更加简洁和易读。

基本语法是 [expression for item in iterable if condition]

expression: 对 item 进行操作的表达式,结果将作为新列表的元素。
item: 从 iterable 中取出的每个元素。
iterable: 可迭代对象,如列表、元组、字符串等。
condition (可选): 过滤条件,只有满足条件的 item 才会被处理。

示例

# 生成一个包含 0 到 9 的平方的列表:
squares = [x**2 for x in range(10)]
print(squares)
# 输出: [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

# 生成一个包含 0 到 9 中偶数的平方的列表:
even_squares = [x**2 for x in range(10) if x % 2 == 0]
print(even_squares)
# 输出: [0, 4, 16, 36, 64]

# 生成一个 3x3 的矩阵: 
matrix = [[i + j for j in range(3)] for i in range(3)]
print(matrix)
# 输出: [[0, 1, 2], [1, 2, 3], [2, 3, 4]]

四. range() 是 Python 中一个非常常用的内置函数
用于生成一个整数序列。它通常用于循环中,特别是在 for 循环中,用于迭代一定范围内的数字。以下是关于 range() 函数的一些基本用法和注意事项。
基本语法

range(stop)
range(start, stop)
range(start, stop, step)

start: 序列的起始值(包含),默认为 0。
stop: 序列的结束值(不包含)。
step: 序列的步长,默认为 1。

# 生成从 0 到 4 的整数序列
for i in range(5):
    print(i)

输出
0
1
2
3
4

生成从 1 到 10,步长为 2 的整数序列

for i in range(1, 11, 2):
    print(i)

输出
1
3
5
7
9

你可能感兴趣的:(deepseek,ai)