本文简单介绍了os.path 模块以及其在深度学习数据处理的使用。
功能:跨平台拼接路径,自动处理路径分隔符(Windows 用 \,Linux/macOS 用 /)。
语法:os.path.join(path1, path2, …)
import os
# 拼接路径
data_dir = os.path.join("data", "images", "train") # 输出:data/images/train (Linux) 或 data\images\train (Windows)
model_path = os.path.join("models", "resnet", "version_1.pt")
功能:将相对路径转换为绝对路径。
abs_path = os.path.abspath("data/images") # 输出当前工作目录下的完整路径,如 /home/user/project/data/images
dirname():提取路径的目录部分。
basename():提取路径的文件名或末尾目录名。
path = "/home/user/data/train/image.jpg"
dir_part = os.path.dirname(path) # 输出:/home/user/data/train
file_part = os.path.basename(path) # 输出:image.jpg
检查路径是否存在。
if not os.path.exists("models"):
os.makedirs("models") # 创建目录
功能:将路径拆分为目录和文件名两部分。
dir_name, file_name = os.path.split("/data/images/cat.jpg") # 输出:('/data/images', 'cat.jpg')
import os
from torchvision import datasets
# 定义数据集根目录和子目录
data_root = os.path.join("data", "cifar10")
train_dir = os.path.join(data_root, "train")
test_dir = os.path.join(data_root, "test")
# 自动创建目录(如果不存在)
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)
# 加载数据集(示例:PyTorch)
train_dataset = datasets.CIFAR10(root=train_dir, train=True, download=True)
test_dataset = datasets.CIFAR10(root=test_dir, train=False, download=True)
import os
import torch
# 定义模型保存目录和文件名
model_dir = os.path.join("saved_models", "resnet")
os.makedirs(model_dir, exist_ok=True) # 确保目录存在
# 按时间或版本号生成唯一文件名
model_name = "resnet50_epoch10.pt"
model_path = os.path.join(model_dir, model_name)
# 保存模型
torch.save(model.state_dict(), model_path)
# 加载模型
if os.path.exists(model_path):
model.load_state_dict(torch.load(model_path))
import os
import json
# 读取配置文件(假设配置文件在项目根目录下的 configs 文件夹)
config_dir = os.path.join(os.path.dirname(__file__), "configs") # __file__ 是当前脚本路径
config_path = os.path.join(config_dir, "hyperparams.json")
# 加载配置
with open(config_path, "r") as f:
config = json.load(f)
# 使用配置中的路径(例如数据集路径)
dataset_path = os.path.join(config["data_root"], config["dataset_name"])
避免硬编码分隔符:使用 os.path.join() 代替手动拼接(如 data + “/” + “images”)。
统一大小写:Windows 路径不区分大小写,但 Linux 区分。
使用 os.path.normpath() 处理路径中的冗余符号(如 …/ 或 //):path os.path.normpath(“data//images/…/train”) # 输出:data/train
home_dir = os.path.expanduser("~") # 输出:/home/user (Linux) 或 C:\Users\user (Windows)
data_path = os.path.join(os.environ["DATA_ROOT"], "dataset") # 需预先定义 DATA_ROOT 环境变量
# 遍历目录下所有文件和子目录
for root, dirs, files in os.walk("data"):
print(f"当前目录:{root}")
print(f"子目录:{dirs}")
print(f"文件:{files}")
file_name = "image.jpg"
ext = os.path.splitext(file_name)[1] # 输出:.jpg
核心工具:os.path.join() 是跨平台路径操作的核心,结合 os.makedirs()、os.path.exists() 等函数,可确保路径安全和兼容性。
深度学习应用:在数据加载、模型保存、配置管理中,合理组织路径是提高代码可维护性的关键。
最佳实践:始终使用 os.path 处理路径,避免手动拼接,并在关键操作前检查路径是否存在。