本项目基于 PyTorch 与 EfficientNet-B0,实现压力性损伤(Stage I, II, III, IV)图像分期分类。 当前训练脚本包含:
- 数据增强与标准化
- 分层划分训练/验证/测试集
- 类别不平衡加权损失
- 早停(Early Stopping)
- 学习率自适应衰减
- 测试阶段 TTA(水平翻转)
- 混淆矩阵与分类报告输出
Pressure-Ulcers-Stages-Classifier/
- train.py:模型训练、验证、测试与可视化主脚本
- cut.py:交互式 ROI 裁剪脚本(OpenCV 手动框选)
- dataset/:原始数据目录(可选)
- dataset_crop/:裁剪后的四分类数据目录(训练实际使用)
- saved_models/:历史模型保存目录
- best_efficientnet_b0_stage4.pth:训练后最佳权重(默认保存名)
训练脚本默认读取 dataset_crop 目录,且类别文件夹名需与下列完全一致:
- Stage_I
- Stage_II
- Stage_III
- Stage_IV
示例:
- dataset_crop/Stage_I/*.jpg
- dataset_crop/Stage_II/*.jpg
- dataset_crop/Stage_III/*.jpg
- dataset_crop/Stage_IV/*.jpg
当前数据统计(dataset_crop):
- Stage_I: 50
- Stage_II: 78
- Stage_III: 69
- Stage_IV: 59
- 合计: 256
建议 Python 版本:3.9+(推荐 3.10)
核心依赖:
- torch
- torchvision
- numpy
- pillow
- matplotlib
- seaborn
- scikit-learn
- tqdm
- opencv-python
可使用如下命令安装:
pip install torch torchvision numpy pillow matplotlib seaborn scikit-learn tqdm opencv-python
如果你需要从原始图片手动框选病灶区域,可使用 cut.py。 该脚本会逐张弹出窗口,使用鼠标框选 ROI 并保存到目标目录。
注意:
- cut.py 中 src_root 和 dst_root 当前是本机绝对路径,请先按你的机器路径修改。
- 每张图框选后按 Enter/Space 确认,按 Esc 可跳过当前图像。
运行:
python cut.py
运行主脚本:
python train.py
train.py 关键默认配置:
- 模型:EfficientNet-B0(ImageNet 预训练)
- 输入尺寸:224 x 224
- batch size:8
- epochs:50
- 学习率:1e-4(AdamW)
- 早停 patience:10
- 损失函数:带类别权重 + label smoothing 的交叉熵
- 学习率策略:ReduceLROnPlateau(监控验证集准确率)
- 数据划分:70% 训练,15% 验证,15% 测试(分层抽样)
训练完成后会输出:
- 最优模型权重(best_efficientnet_b0_stage4.pth)
- 训练/验证 Loss 曲线
- 训练/验证 Accuracy 曲线
- 学习率变化曲线
- 测试集 TTA 准确率
- 混淆矩阵
- classification report(Precision/Recall/F1)
为提升可复现性,脚本已固定随机种子(2024)。 建议进一步确保:
- 固定同一版本的依赖库
- 保持数据目录结构一致
- 使用同一硬件与 CUDA 环境(若需要严格对齐)
本项目用于科研与学习目的,不可替代专业医生诊断。 在真实医疗场景中应结合多学科评估与临床流程进行决策。