244 lines
7.8 KiB
Python
244 lines
7.8 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
AI/ML 训练工具函数
|
|
Training Utility Functions
|
|
"""
|
|
|
|
import contextlib
import logging
import math
import os
import time
from typing import Dict, Optional, Callable

import numpy as np
import torch
import torch.nn as nn
|
|
|
|
# Module-level logger used by the checkpoint helpers and Trainer below.
logger = logging.getLogger(__name__)

# NOTE(review): calling basicConfig at import time configures the root logger
# for every program that imports this module; consider moving it to the
# script entry point.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
|
class EarlyStopping:
    """Early-stopping tracker: signal a stop after ``patience`` non-improving scores.

    Call the instance once per evaluation with the monitored score. A score
    counts as an improvement when it beats the best score seen so far by more
    than ``min_delta``: lower is better for ``mode='min'``, and any other mode
    is treated as higher-is-better.

    NaN scores never count as an improvement and are never stored as the best
    score. A single diverged evaluation therefore cannot permanently block
    every later improvement.

    Args:
        patience: Consecutive non-improving calls tolerated before
            ``early_stop`` becomes True.
        min_delta: Margin by which a score must beat the best score.
        mode: ``'min'`` if lower scores are better; otherwise higher is better.

    Attributes:
        counter: Consecutive non-improving calls since the last improvement.
        best_score: Best non-NaN score seen so far (None until the first one).
        early_stop: Sticky flag; once True it stays True.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0, mode: str = 'min'):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """Record ``score`` and return whether training should stop."""
        is_nan = math.isnan(score)

        if self.best_score is None and not is_nan:
            # The first usable score becomes the baseline; nothing to compare yet.
            self.best_score = score
            self.counter = 0
            return self.early_stop

        if is_nan:
            # NaN compares False against everything. Count it as a
            # non-improving call without touching best_score.
            improved = False
        elif self.mode == 'min':
            improved = score < self.best_score - self.min_delta
        else:
            improved = score > self.best_score + self.min_delta

        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop
|
|
|
|
|
|
class AverageMeter:
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val: float, n: int = 1):
        """Add ``val`` with weight ``n``, e.g. a per-batch mean and its batch size.

        ``avg`` is ``sum / count``. While ``count`` is still 0 (for example
        after ``update(val, n=0)`` on a fresh meter), ``avg`` stays 0 instead
        of raising ZeroDivisionError.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count else 0
|
|
|
|
|
|
def set_seed(seed: int = 42):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Seeds ``random``, ``numpy.random``, torch's CPU generator and all CUDA
    devices with ``seed``. It then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.
    """
    import random

    # Applied in this order: Python, NumPy, torch CPU, torch CUDA (all devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
def get_device() -> torch.device:
    """Return the preferred compute device: CUDA first, then Apple MPS, else CPU.

    The ``hasattr`` guard keeps this working on PyTorch builds that have no
    ``torch.backends.mps`` attribute.
    """
    # Probed in priority order; a later probe runs only if earlier ones fail.
    candidates = (
        ('cuda', torch.cuda.is_available),
        ('mps', lambda: hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()),
    )
    for kind, is_available in candidates:
        if is_available():
            return torch.device(kind)
    return torch.device('cpu')
|
|
|
|
|
|
def count_parameters(model: nn.Module, *, trainable_only: bool = True) -> int:
    """Return the number of scalar parameters in ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.
        trainable_only: When True (the default, matching the original
            behavior), only parameters with ``requires_grad=True`` are
            counted. Pass False to count frozen parameters too.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
|
|
|
|
|
|
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
                    epoch: int, loss: float, path: str):
    """Save model/optimizer state plus ``epoch`` and ``loss`` to ``path``.

    The checkpoint is a dict with keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss', which is the layout
    ``load_checkpoint`` reads.

    When ``path`` is a str or ``os.PathLike``, the data is first written to a
    sibling ``<path>.tmp`` file and then moved into place with ``os.replace``.
    An interruption mid-write therefore leaves any existing checkpoint at
    ``path`` intact. Other targets, such as file-like objects, are passed to
    ``torch.save`` directly.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }

    if isinstance(path, (str, os.PathLike)):
        tmp_path = os.fspath(path) + '.tmp'
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, path)
        except BaseException:
            # Best-effort removal of the partial temp file; the original error
            # is re-raised below.
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise
    else:
        torch.save(state, path)

    logger.info(f"Checkpoint saved to {path}")
|
|
|
|
|
|
def load_checkpoint(model: nn.Module, optimizer: Optional[torch.optim.Optimizer],
                    path: str, device: torch.device) -> int:
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Reads the dict layout written by ``save_checkpoint``: keys
    'model_state_dict', 'optimizer_state_dict' and 'epoch'.

    Args:
        model: Module that 'model_state_dict' is loaded into.
        optimizer: If not None, 'optimizer_state_dict' is loaded into it.
        path: Checkpoint file to read.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The stored 'epoch' value.
    """
    # SECURITY NOTE(review): depending on the installed PyTorch version,
    # torch.load may fully unpickle the file. Only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Explicit None check: "no optimizer" is the only case to skip.
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    logger.info(f"Checkpoint loaded from {path}")
    return checkpoint['epoch']
|
|
|
|
|
|
def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_type: str,
                     num_epochs: int, **kwargs):
    """Build a learning-rate scheduler by name.

    Supported ``scheduler_type`` values and the optional ``kwargs`` each one
    reads (defaults reproduce the previously hard-coded behavior):

    * ``'cosine'``  -> ``CosineAnnealingLR(T_max=num_epochs, eta_min=0)``
    * ``'step'``    -> ``StepLR(step_size=10, gamma=0.1)``
    * ``'plateau'`` -> ``ReduceLROnPlateau(mode='min', factor=0.1, patience=5)``;
      its ``step()`` must be given the monitored metric.
    * ``'warmup_cosine'`` -> ``transformers.get_cosine_schedule_with_warmup``
      with ``warmup_steps=0`` and ``total_steps=num_epochs``. This requires
      the optional ``transformers`` package.

    Any other ``scheduler_type`` returns None, meaning "no scheduler".
    """
    sched = torch.optim.lr_scheduler

    if scheduler_type == 'cosine':
        return sched.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs,
            eta_min=kwargs.get('eta_min', 0),
        )
    if scheduler_type == 'step':
        return sched.StepLR(
            optimizer,
            step_size=kwargs.get('step_size', 10),
            gamma=kwargs.get('gamma', 0.1),
        )
    if scheduler_type == 'plateau':
        return sched.ReduceLROnPlateau(
            optimizer,
            mode=kwargs.get('mode', 'min'),
            factor=kwargs.get('factor', 0.1),
            patience=kwargs.get('patience', 5),
        )
    if scheduler_type == 'warmup_cosine':
        # Optional third-party dependency, imported lazily on purpose.
        from transformers import get_cosine_schedule_with_warmup
        # NOTE(review): this schedule is usually stepped per batch, but
        # Trainer.fit steps once per epoch, hence total_steps defaults to
        # num_epochs. Confirm the intended granularity.
        return get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=kwargs.get('warmup_steps', 0),
            num_training_steps=kwargs.get('total_steps', num_epochs),
        )
    return None
|
|
|
|
|
|
class Trainer:
    """Generic epoch-based trainer with validation, checkpointing and early stopping.

    Dataloaders must yield mapping-like batches with an 'x' entry (model
    input) and a 'y' entry (targets). Accuracy is computed as
    ``outputs.argmax(dim=1) == y``, which presumes class scores along dim 1.
    TODO: confirm this holds for non-classification models.
    """

    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
                 criterion: nn.Module, device: torch.device,
                 scheduler: Optional = None,
                 max_grad_norm: Optional[float] = 1.0):
        """
        Args:
            model: Network to train; moved to ``device`` here.
            optimizer: Optimizer over the model's parameters.
            criterion: Loss, called as ``criterion(outputs, y)``.
            device: Device that the model and every batch are moved to.
            scheduler: Optional LR scheduler, stepped once per epoch in ``fit``.
            max_grad_norm: Total gradient-norm limit applied before each
                optimizer step. The default 1.0 matches the original fixed
                value; None disables clipping.
        """
        self.model = model.to(device)
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        self.max_grad_norm = max_grad_norm
        self.history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    def _to_device(self, batch):
        """Return the batch's 'x' and 'y' tensors moved to ``self.device``."""
        return batch['x'].to(self.device), batch['y'].to(self.device)

    @staticmethod
    def _record_batch(loss_meter, acc_meter, outputs, y, loss, n):
        """Accumulate one batch's loss and argmax accuracy with weight ``n``."""
        acc = (outputs.argmax(dim=1) == y).float().mean().item()
        loss_meter.update(loss.item(), n)
        acc_meter.update(acc, n)

    def _step_scheduler(self, val_loss: float) -> None:
        """Advance the LR scheduler once; ReduceLROnPlateau needs the monitored metric."""
        if self.scheduler is None:
            return
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            # Calling ReduceLROnPlateau.step() without a metric raises TypeError.
            self.scheduler.step(val_loss)
        else:
            self.scheduler.step()

    def train_epoch(self, dataloader) -> Dict[str, float]:
        """Run one optimization pass over ``dataloader``.

        Returns:
            Dict with 'loss' and 'acc', averaged over samples (per-batch
            values weighted by ``x.size(0)``).
        """
        self.model.train()
        loss_meter, acc_meter = AverageMeter(), AverageMeter()

        for batch in dataloader:
            x, y = self._to_device(batch)

            self.optimizer.zero_grad()
            outputs = self.model(x)
            loss = self.criterion(outputs, y)
            loss.backward()

            if self.max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               max_norm=self.max_grad_norm)
            self.optimizer.step()

            self._record_batch(loss_meter, acc_meter, outputs, y, loss, x.size(0))

        return {'loss': loss_meter.avg, 'acc': acc_meter.avg}

    @torch.no_grad()
    def evaluate(self, dataloader) -> Dict[str, float]:
        """Evaluate on ``dataloader`` in eval mode without gradient tracking.

        Returns:
            Dict with 'loss' and 'acc', averaged over samples.
        """
        self.model.eval()
        loss_meter, acc_meter = AverageMeter(), AverageMeter()

        for batch in dataloader:
            x, y = self._to_device(batch)
            outputs = self.model(x)
            loss = self.criterion(outputs, y)
            self._record_batch(loss_meter, acc_meter, outputs, y, loss, x.size(0))

        return {'loss': loss_meter.avg, 'acc': acc_meter.avg}

    def fit(self, train_loader, val_loader, epochs: int,
            early_stopping: Optional[EarlyStopping] = None,
            save_path: Optional[str] = None):
        """Train for up to ``epochs`` epochs, validating after each one.

        Each epoch does the following, in order:

        * trains on ``train_loader`` and evaluates on ``val_loader``;
        * steps the scheduler;
        * appends metrics to ``self.history`` and logs them;
        * saves a checkpoint to ``save_path`` (if given) whenever validation
          loss reaches a new minimum;
        * stops if ``early_stopping`` signals it.

        Returns:
            ``self.history``, which accumulates across repeated ``fit`` calls.
        """
        best_val_loss = float('inf')

        for epoch in range(epochs):
            start_time = time.perf_counter()

            train_metrics = self.train_epoch(train_loader)
            val_metrics = self.evaluate(val_loader)
            self._step_scheduler(val_metrics['loss'])

            self.history['train_loss'].append(train_metrics['loss'])
            self.history['val_loss'].append(val_metrics['loss'])
            self.history['train_acc'].append(train_metrics['acc'])
            self.history['val_acc'].append(val_metrics['acc'])

            elapsed = time.perf_counter() - start_time
            logger.info(
                f"Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) - "
                f"train_loss: {train_metrics['loss']:.4f}, train_acc: {train_metrics['acc']:.4f}, "
                f"val_loss: {val_metrics['loss']:.4f}, val_acc: {val_metrics['acc']:.4f}"
            )

            # Keep only the checkpoint with the lowest validation loss so far.
            if val_metrics['loss'] < best_val_loss:
                best_val_loss = val_metrics['loss']
                if save_path:
                    save_checkpoint(self.model, self.optimizer, epoch, val_metrics['loss'], save_path)

            if early_stopping is not None and early_stopping(val_metrics['loss']):
                logger.info(f"Early stopping at epoch {epoch+1}")
                break

        return self.history
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: seed every RNG, report the selected device, and count the
    # trainable parameters of a tiny linear layer.
    set_seed(42)
    selected_device = get_device()
    print(f"Using device: {selected_device}")

    demo_model = nn.Linear(10, 2)
    n_params = count_parameters(demo_model)
    print(f"Parameters: {n_params}")
|