bookworm-smart-assistant/skills/ai-ml-expert/scripts/train_utils.py

244 lines
7.8 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""
AI/ML 训练工具函数
Training Utility Functions
"""
import torch
import torch.nn as nn
import numpy as np
from typing import Dict, Optional, Callable
import time
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EarlyStopping:
"""早停机制"""
def __init__(self, patience: int = 5, min_delta: float = 0, mode: str = 'min'):
self.patience = patience
self.min_delta = min_delta
self.mode = mode
self.counter = 0
self.best_score = None
self.early_stop = False
def __call__(self, score: float) -> bool:
if self.best_score is None:
self.best_score = score
return False
if self.mode == 'min':
improved = score < self.best_score - self.min_delta
else:
improved = score > self.best_score + self.min_delta
if improved:
self.best_score = score
self.counter = 0
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
return self.early_stop
class AverageMeter:
"""计算和存储平均值"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val: float, n: int = 1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def set_seed(seed: int = 42):
"""设置随机种子"""
import random
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def get_device() -> torch.device:
"""获取可用设备"""
if torch.cuda.is_available():
return torch.device('cuda')
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return torch.device('mps')
return torch.device('cpu')
def count_parameters(model: nn.Module) -> int:
"""统计模型参数量"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer,
epoch: int, loss: float, path: str):
"""保存检查点"""
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, path)
logger.info(f"Checkpoint saved to {path}")
def load_checkpoint(model: nn.Module, optimizer: Optional[torch.optim.Optimizer],
path: str, device: torch.device) -> int:
"""加载检查点"""
checkpoint = torch.load(path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
if optimizer:
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
logger.info(f"Checkpoint loaded from {path}")
return checkpoint['epoch']
def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_type: str,
num_epochs: int, **kwargs):
"""获取学习率调度器"""
if scheduler_type == 'cosine':
return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
elif scheduler_type == 'step':
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=kwargs.get('step_size', 10))
elif scheduler_type == 'plateau':
return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
elif scheduler_type == 'warmup_cosine':
from transformers import get_cosine_schedule_with_warmup
return get_cosine_schedule_with_warmup(
optimizer,
num_warmup_steps=kwargs.get('warmup_steps', 0),
num_training_steps=kwargs.get('total_steps', num_epochs)
)
else:
return None
class Trainer:
"""通用训练器"""
def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
criterion: nn.Module, device: torch.device,
scheduler: Optional = None):
self.model = model.to(device)
self.optimizer = optimizer
self.criterion = criterion
self.device = device
self.scheduler = scheduler
self.history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
def train_epoch(self, dataloader) -> Dict[str, float]:
self.model.train()
loss_meter = AverageMeter()
acc_meter = AverageMeter()
for batch in dataloader:
x = batch['x'].to(self.device)
y = batch['y'].to(self.device)
self.optimizer.zero_grad()
outputs = self.model(x)
loss = self.criterion(outputs, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# 计算准确率
preds = outputs.argmax(dim=1)
acc = (preds == y).float().mean().item()
loss_meter.update(loss.item(), x.size(0))
acc_meter.update(acc, x.size(0))
return {'loss': loss_meter.avg, 'acc': acc_meter.avg}
@torch.no_grad()
def evaluate(self, dataloader) -> Dict[str, float]:
self.model.eval()
loss_meter = AverageMeter()
acc_meter = AverageMeter()
for batch in dataloader:
x = batch['x'].to(self.device)
y = batch['y'].to(self.device)
outputs = self.model(x)
loss = self.criterion(outputs, y)
preds = outputs.argmax(dim=1)
acc = (preds == y).float().mean().item()
loss_meter.update(loss.item(), x.size(0))
acc_meter.update(acc, x.size(0))
return {'loss': loss_meter.avg, 'acc': acc_meter.avg}
def fit(self, train_loader, val_loader, epochs: int,
early_stopping: Optional[EarlyStopping] = None,
save_path: Optional[str] = None):
best_val_loss = float('inf')
for epoch in range(epochs):
start_time = time.time()
train_metrics = self.train_epoch(train_loader)
val_metrics = self.evaluate(val_loader)
if self.scheduler:
self.scheduler.step()
# 记录历史
self.history['train_loss'].append(train_metrics['loss'])
self.history['val_loss'].append(val_metrics['loss'])
self.history['train_acc'].append(train_metrics['acc'])
self.history['val_acc'].append(val_metrics['acc'])
elapsed = time.time() - start_time
logger.info(
f"Epoch {epoch+1}/{epochs} ({elapsed:.1f}s) - "
f"train_loss: {train_metrics['loss']:.4f}, train_acc: {train_metrics['acc']:.4f}, "
f"val_loss: {val_metrics['loss']:.4f}, val_acc: {val_metrics['acc']:.4f}"
)
# 保存最优模型
if val_metrics['loss'] < best_val_loss:
best_val_loss = val_metrics['loss']
if save_path:
save_checkpoint(self.model, self.optimizer, epoch, val_metrics['loss'], save_path)
# 早停
if early_stopping and early_stopping(val_metrics['loss']):
logger.info(f"Early stopping at epoch {epoch+1}")
break
return self.history
if __name__ == '__main__':
# 测试
set_seed(42)
device = get_device()
print(f"Using device: {device}")
# 简单模型测试
model = nn.Linear(10, 2)
print(f"Parameters: {count_parameters(model)}")