bookworm-smart-assistant/scripts/route-feedback.js

586 lines
20 KiB
JavaScript
Raw Normal View History

#!/usr/bin/env node
/**
* 路由反馈权重学习 (v4.2)
*
* 记录路由纠正 → 分析模式 → 生成权重调整
* 配合 route-analyzer.js 实现闭环学习
*
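* 用法:
* node route-feedback.js --correct "<query>" <correctSkill> 记录纠正
* node route-feedback.js --split 拆分训练集/holdout 集 (7:3)
* node route-feedback.js --validate 在 holdout 集上验证准确率
* node route-feedback.js --learn 生成权重调整 (仅使用训练集)
* node route-feedback.js --report 反馈统计
* node route-feedback.js --json 以 JSON 输出 (可与上述模式组合)
*
* 文件:
* debug/route-feedback.jsonl 纠正记录
* debug/route-weights.json 学习后的权重增量
* debug/holdout-set.json holdout 查询集 (--split 生成)
* debug/weights-history/ 权重文件快照 (--learn 覆盖前备份)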
*/
const fs = require('fs');
const path = require('path');
// Resolve the project root through the shared paths config; every path below is derived from it.
const detectClaudeRoot = () => require('./paths.config.js').PATHS.root;
const ROOT = detectClaudeRoot();
// Skill index; read here for skill names and their keyword lists.
const INDEX_FILE = path.join(ROOT, 'skills-index.json');
const DEBUG_DIR = path.join(ROOT, 'debug');
// Append-only JSONL log of corrections (one JSON object per line).
const FEEDBACK_FILE = path.join(DEBUG_DIR, 'route-feedback.jsonl');
// Learned per-skill keyword deltas, written by learnWeights().
const WEIGHTS_FILE = path.join(DEBUG_DIR, 'route-weights.json');
// Timestamped backups of WEIGHTS_FILE, taken before each overwrite.
const WEIGHTS_HISTORY_DIR = path.join(DEBUG_DIR, 'weights-history');
// Queries reserved for validation and excluded from training.
const HOLDOUT_FILE = path.join(DEBUG_DIR, 'holdout-set.json');
// Number of snapshots kept in WEIGHTS_HISTORY_DIR.
const MAX_WEIGHT_SNAPSHOTS = 20;
const HOLDOUT_RATIO = 0.3; // fraction of corrections held out (30%)
// === 参数解析 ===
const args = process.argv.slice(2);
// Mode switches: each flag must appear as an exact argument token.
const argSet = new Set(args);
const correctMode = argSet.has('--correct');
const learnMode = argSet.has('--learn');
const reportMode = argSet.has('--report');
const jsonMode = argSet.has('--json');
const validateMode = argSet.has('--validate');
const splitMode = argSet.has('--split');
// === 工具函数 ===
/** Create the debug output directory (with any missing parents) unless it already exists. */
function ensureDebugDir() {
  if (fs.existsSync(DEBUG_DIR)) return;
  fs.mkdirSync(DEBUG_DIR, { recursive: true });
}
/**
 * Build the whitelist of valid skill names from skills-index.json.
 * @returns {Set<string>|null} skill names, or null when the index or its `skills` list is unavailable
 */
function loadSkillWhitelist() {
  const index = loadIndex();
  const skills = index && index.skills;
  if (!skills) return null;
  return new Set(skills.map(({ name }) => name));
}
/**
 * Read and parse skills-index.json.
 * @returns {object|null} the parsed index, or null when the file does not exist (malformed JSON throws)
 */
function loadIndex() {
  return fs.existsSync(INDEX_FILE)
    ? JSON.parse(fs.readFileSync(INDEX_FILE, 'utf8'))
    : null;
}
/**
 * Parse a JSONL file into an array of values, skipping empty lines, lines that
 * are not valid JSON, and lines whose parsed value is falsy.
 * @param {string} file - path of the .jsonl file
 * @returns {Array} parsed entries in file order ([] when the file is absent)
 */
function readJsonlEntries(file) {
  if (!fs.existsSync(file)) return [];
  const entries = [];
  for (const line of fs.readFileSync(file, 'utf8').split('\n')) {
    if (!line) continue;
    let parsed;
    try {
      parsed = JSON.parse(line);
    } catch {
      continue; // tolerate a corrupt or partially written line
    }
    if (parsed) entries.push(parsed);
  }
  return entries;
}
/** Load every recorded correction from route-feedback.jsonl. */
function loadFeedback() {
  return readJsonlEntries(FEEDBACK_FILE);
}
/** Load one day's router log, `debug/route-<dateStr>.jsonl`. */
function loadRouteLog(dateStr) {
  return readJsonlEntries(path.join(DEBUG_DIR, `route-${dateStr}.jsonl`));
}
/**
 * P0-3: deterministic string hash used for stable holdout bucketing.
 * Folds each UTF-16 code unit as h = h * 31 + code with 32-bit wrap-around,
 * then returns the absolute value, so the result is a non-negative integer.
 * @param {string} str
 * @returns {number}
 */
function simpleHash(str) {
  let acc = 0;
  let idx = 0;
  while (idx < str.length) {
    // Equivalent to ((acc << 5) - acc + code) | 0: both reduce 31*acc + code modulo 2^32.
    acc = (acc * 31 + str.charCodeAt(idx)) | 0;
    idx += 1;
  }
  return Math.abs(acc);
}
/**
 * P0-3: split recorded corrections into train / holdout sets using a
 * deterministic hash of each query, then write the holdout queries to
 * holdout-set.json. Exits the process with code 1 when fewer than 5
 * corrections exist.
 * @returns {object} the holdout descriptor that was written
 */
function splitHoldout() {
  const corrections = loadFeedback()
    .filter(f => f.routedTo !== f.correctedTo && f.routedTo !== 'unknown');
  const total = corrections.length;
  if (total < 5) {
    console.error(`纠正数据不足 (${total} 条),至少需要 5 条才能拆分`);
    process.exit(1);
  }
  // Each query hashes into a bucket 0-99; buckets below the cutoff go to the holdout set.
  const cutoff = HOLDOUT_RATIO * 100;
  const holdoutQueries = [];
  const trainQueries = [];
  corrections.forEach(fb => {
    const target = simpleHash(fb.query) % 100 < cutoff ? holdoutQueries : trainQueries;
    target.push(fb.query);
  });
  const holdoutData = {
    generated: new Date().toISOString(),
    ratio: HOLDOUT_RATIO,
    totalCorrections: total,
    holdoutCount: holdoutQueries.length,
    trainCount: trainQueries.length,
    holdoutQueries,
  };
  ensureDebugDir();
  fs.writeFileSync(HOLDOUT_FILE, JSON.stringify(holdoutData, null, 2) + '\n');
  const pct = n => Math.round(n / total * 100);
  console.log(`Holdout 拆分完成:`);
  console.log(` 纠正总数: ${total}`);
  console.log(` 训练集: ${trainQueries.length} (${pct(trainQueries.length)}%)`);
  console.log(` Holdout: ${holdoutQueries.length} (${pct(holdoutQueries.length)}%)`);
  console.log(` 输出: ${HOLDOUT_FILE}`);
  return holdoutData;
}
/**
 * P0-3: read holdout-set.json.
 * @returns {*|null} the parsed file contents, or null when the file is missing, unreadable or not valid JSON
 */
function loadHoldoutSet() {
  if (!fs.existsSync(HOLDOUT_FILE)) return null;
  let parsed = null;
  try {
    parsed = JSON.parse(fs.readFileSync(HOLDOUT_FILE, 'utf8'));
  } catch {
    // a broken file is treated exactly like a missing one
  }
  return parsed;
}
/**
 * P0-3: drop feedback entries whose query belongs to the holdout set.
 * @param {object[]} feedback - entries as returned by loadFeedback()
 * @returns {object[]} training entries; the input array itself when no holdout file exists
 */
function filterTrainSet(feedback) {
  const holdout = loadHoldoutSet();
  // Without a holdout file nothing is reserved, so all feedback counts as training data.
  if (!holdout) return feedback;
  const reserved = new Set(holdout.holdoutQueries);
  return feedback.filter(({ query }) => !reserved.has(query));
}
/**
 * Rank all skills for `query` with the current engine and return the top entry.
 * Scores every skill, rounds to two decimals, and sorts descending. It then applies
 * disambiguation and score normalization when the analyzer provides them.
 * @returns {{name: string, score: number}|null} top-ranked entry, or null when nothing is ranked
 */
function rankTopSkill(analyzer, index, bm25Params, query) {
  const queryTokens = analyzer.tokenize(query);
  const results = index.skills.map(skill => {
    const { totalScore } = analyzer.scoreSkill(skill, queryTokens, bm25Params);
    return { name: skill.name, score: Math.round(totalScore * 100) / 100 };
  }).sort((a, b) => b.score - a.score);
  // v5.9: applyDisambiguation returns { results, firedRules }; a bare array is also accepted.
  let disambiguated = results;
  if (analyzer.applyDisambiguation) {
    const dr = analyzer.applyDisambiguation(results, query, index);
    disambiguated = Array.isArray(dr) ? dr : (dr?.results || results);
  }
  if (!disambiguated || disambiguated.length === 0) return null;
  const normalized = analyzer.normalizeScores ? analyzer.normalizeScores(disambiguated) : disambiguated;
  return normalized[0] || null;
}

/**
 * P0-3: measure how often the current routing engine picks the corrected
 * skill for the holdout queries (ab-backtest style evaluation).
 *
 * Only holdout feedback entries that are real corrections are evaluated.
 * A real correction has a known `routedTo` that differs from `correctedTo`.
 * A text or JSON (--json) summary is printed.
 *
 * @returns {{holdoutSize: number, fixed: number, missed: number, accuracy: number}|null}
 *   the summary, or null when there is no usable holdout set, or when
 *   route-analyzer.js or skills-index.json is unavailable
 */
function validateHoldout() {
  const holdout = loadHoldoutSet();
  // loadHoldoutSet() returns whatever JSON the file holds; only an array of queries is usable.
  // (Previously a file without `holdoutQueries` crashed with a TypeError here.)
  const holdoutQueries = Array.isArray(holdout?.holdoutQueries) ? holdout.holdoutQueries : [];
  if (holdoutQueries.length === 0) {
    console.log('无 holdout 集。运行 --split 先生成。');
    return null;
  }
  const holdoutSet = new Set(holdoutQueries);
  const holdoutEntries = loadFeedback().filter(fb =>
    holdoutSet.has(fb.query) && fb.routedTo !== 'unknown' && fb.routedTo !== fb.correctedTo
  );
  let analyzer;
  try { analyzer = require('./route-analyzer.js'); } catch { console.error('route-analyzer.js 不可用'); return null; }
  const index = loadIndex();
  if (!index) { console.error('skills-index.json 不可用'); return null; }
  const bm25Params = analyzer.buildBM25Params ? analyzer.buildBM25Params(index) : null;
  let fixed = 0;
  let missed = 0;
  for (const fb of holdoutEntries) {
    const top = rankTopSkill(analyzer, index, bm25Params, fb.query);
    if (top && top.name === fb.correctedTo) {
      fixed++;
    } else {
      missed++;
    }
  }
  const result = {
    holdoutSize: holdoutEntries.length,
    fixed,
    missed,
    // percentage with one decimal place
    accuracy: holdoutEntries.length > 0 ? Math.round(fixed / holdoutEntries.length * 100 * 10) / 10 : 0,
  };
  if (jsonMode) {
    console.log(JSON.stringify(result, null, 2));
  } else {
    console.log(`Holdout 验证结果:`);
    console.log(` Holdout 纠正数: ${holdoutEntries.length}`);
    console.log(` 修复: ${fixed}, 未修复: ${missed}`);
    console.log(` 准确率: ${result.accuracy}%`);
  }
  return result;
}
/**
 * P0-2: back up the current route-weights.json into weights-history/ under a
 * timestamped name. Old snapshots are then pruned so that only the newest
 * MAX_WEIGHT_SNAPSHOTS remain.
 * @returns {string|null} path of the new snapshot, or null when there is no weights file to back up
 */
function snapshotWeights() {
  if (!fs.existsSync(WEIGHTS_FILE)) return null;
  if (!fs.existsSync(WEIGHTS_HISTORY_DIR)) fs.mkdirSync(WEIGHTS_HISTORY_DIR, { recursive: true });
  // ISO timestamps with ':' and '.' replaced are filename-safe and sort chronologically as strings.
  const stamp = new Date().toISOString().replace(/[:.]/g, '-');
  const snapshotPath = path.join(WEIGHTS_HISTORY_DIR, `route-weights-${stamp}.json`);
  fs.copyFileSync(WEIGHTS_FILE, snapshotPath);
  const snapshots = fs.readdirSync(WEIGHTS_HISTORY_DIR)
    .filter(name => name.startsWith('route-weights-') && name.endsWith('.json'))
    .sort();
  // Oldest names sort first; delete everything ahead of the newest MAX_WEIGHT_SNAPSHOTS.
  const excess = Math.max(0, snapshots.length - MAX_WEIGHT_SNAPSHOTS);
  for (const stale of snapshots.slice(0, excess)) {
    fs.unlinkSync(path.join(WEIGHTS_HISTORY_DIR, stale));
  }
  return snapshotPath;
}
/**
 * Lightweight tokenizer, meant to stay consistent with route-analyzer.js.
 * - Each run of CJK characters (U+4E00..U+9FFF) contributes every substring
 *   of length 2..6.
 * - Each Latin phrase of up to three words contributes the whole phrase, plus
 *   every piece (split on whitespace, '.', '-') that is at least 2 characters long.
 * All tokens are lower-cased. v4.9: the set is then passed through
 * expandSynonyms() from ./synonym-expander.js. If that module cannot be loaded
 * or the call throws, the raw token set is returned instead.
 * @param {string} text
 * @returns {Set<string>} the token set (or expandSynonyms' result — presumably also a Set; confirm in synonym-expander.js)
 */
function tokenize(text) {
  const tokens = new Set();
  for (const [run] of text.matchAll(/[\u4e00-\u9fff]+/g)) {
    const maxLen = Math.min(6, run.length);
    for (let size = 2; size <= maxLen; size++) {
      for (let start = 0; start + size <= run.length; start++) {
        tokens.add(run.slice(start, start + size).toLowerCase());
      }
    }
  }
  for (const [phrase] of text.matchAll(/[A-Za-z][\w.-]*(?:\s+[A-Za-z][\w.-]*){0,2}/g)) {
    tokens.add(phrase.toLowerCase().trim());
    phrase.split(/[\s.-]+/)
      .filter(part => part.length >= 2)
      .forEach(part => tokens.add(part.toLowerCase()));
  }
  try {
    const { expandSynonyms: expand } = require('./synonym-expander.js');
    return expand(tokens);
  } catch {
    // synonym expansion is optional
    return tokens;
  }
}
// === 模式 1: 记录纠正 ===
/**
 * Search recent daily route logs for the routing decision made for `query`.
 * Days are scanned newest first (today back to `days - 1` days ago). Within
 * each log, lines are scanned last to first. The first entry whose query
 * contains the first 50 characters of `query` (case-insensitive) wins.
 * NOTE(review): dates come from toISOString(), i.e. UTC — assumes route logs are named by UTC date; confirm.
 * @param {string} query
 * @param {number} [days=7] - how many daily logs to scan
 * @returns {{routedTo: string, topConfidence: number}} 'unknown' / 0 when nothing matches or the entry lacks the fields
 */
function findRecentRoute(query, days = 7) {
  const needle = query.toLowerCase().slice(0, 50);
  for (let d = 0; d < days; d++) {
    const date = new Date(Date.now() - d * 86400000).toISOString().slice(0, 10);
    const match = loadRouteLog(date).reverse().find(l =>
      typeof l.query === 'string' && l.query.toLowerCase().includes(needle)
    );
    if (match) {
      // Default missing fields: an undefined routedTo would later be mistaken for a real
      // correction, and an undefined confidence prints as "NaN%".
      return {
        routedTo: match.topResult || 'unknown',
        topConfidence: Number.isFinite(match.topConfidence) ? match.topConfidence : 0,
      };
    }
  }
  return { routedTo: 'unknown', topConfidence: 0 };
}

/**
 * Mode 1 (--correct "<query>" <skill>): append a correction to route-feedback.jsonl.
 *
 * The last non-flag argument after --correct is the correct skill. The remaining
 * non-flag arguments are joined into the query. The skill is validated against
 * skills-index.json when it is available. The original routing decision is
 * looked up in the last 7 days of route logs. Exits with code 1 on bad usage or
 * an unknown skill name.
 */
function recordCorrection() {
  const usage = 'Usage: --correct "<query>" <correctSkill>';
  const flagIdx = args.indexOf('--correct');
  const remaining = args.filter((a, i) => i > flagIdx && !a.startsWith('--'));
  if (remaining.length < 2) {
    console.error(usage);
    process.exit(1);
  }
  const correctSkill = remaining.pop();
  const query = remaining.join(' ');
  // An empty needle matches every log line (`includes('')` is always true)
  // and would record a meaningless correction.
  if (query.trim() === '') {
    console.error(usage);
    process.exit(1);
  }
  // P0-1: skill-name whitelist check
  const whitelist = loadSkillWhitelist();
  if (whitelist && !whitelist.has(correctSkill)) {
    console.error(`Invalid skill name: "${correctSkill}"`);
    console.error(`Valid skills: ${Array.from(whitelist).sort().join(', ')}`);
    process.exit(1);
  }
  const { routedTo, topConfidence } = findRecentRoute(query);
  const entry = {
    ts: new Date().toISOString(),
    query: query.slice(0, 200),
    routedTo,
    correctedTo: correctSkill,
    topConfidence,
    queryTokens: Array.from(tokenize(query)),
  };
  ensureDebugDir();
  fs.appendFileSync(FEEDBACK_FILE, JSON.stringify(entry) + '\n');
  if (jsonMode) {
    console.log(JSON.stringify(entry, null, 2));
  } else {
    console.log(`Correction recorded:`);
    console.log(` Query: "${query}"`);
    console.log(` Routed: ${routedTo} (${(topConfidence * 100).toFixed(0)}%)`);
    console.log(` Correct: ${correctSkill}`);
  }
}
// === 模式 2: 学习权重调整 ===
/**
 * Add `amount` to deltas[skill][token] for every query token that is one of
 * `keywords`. Per-skill maps are created prototype-less. As a result, tokens
 * such as "constructor" start from 0 instead of an inherited Object property.
 * @param {Object} deltas - skill → (keyword → delta) accumulator, mutated in place
 * @param {string} skill
 * @param {Set<string>|undefined} keywords - lower-cased keywords of `skill`; no-op when undefined
 * @param {Iterable<string>} queryTokens
 * @param {number} amount - signed adjustment to add
 */
function accumulateKeywordDelta(deltas, skill, keywords, queryTokens, amount) {
  if (!keywords) return;
  for (const token of queryTokens) {
    if (!keywords.has(token)) continue;
    if (!deltas[skill]) deltas[skill] = Object.create(null);
    deltas[skill][token] = (deltas[skill][token] || 0) + amount;
  }
}

/**
 * Drop negligible (|delta| < 0.01) and non-finite deltas. Round the rest to two
 * decimals and clamp them to [-0.5, +0.5]. Skills left with no keywords are
 * omitted.
 * @param {Object} rawDeltas - skill → (keyword → delta)
 * @returns {Object<string, Object<string, number>>} plain objects with own keys only
 */
function finalizeDeltas(rawDeltas) {
  const skillEntries = [];
  for (const skill of Object.keys(rawDeltas)) {
    const kept = [];
    for (const [keyword, value] of Object.entries(rawDeltas[skill])) {
      if (!Number.isFinite(value) || Math.abs(value) < 0.01) continue;
      kept.push([keyword, Math.max(-0.5, Math.min(0.5, Math.round(value * 100) / 100))]);
    }
    if (kept.length > 0) skillEntries.push([skill, Object.fromEntries(kept)]);
  }
  // Object.fromEntries defines own data properties, so keys like "__proto__" stay plain data.
  return Object.fromEntries(skillEntries);
}

/**
 * Mode 2 (--learn): derive keyword weight deltas from recorded corrections.
 *
 * Only training-set feedback is used (holdout queries are excluded). For each
 * correction with a known original route:
 * - query tokens that are keywords of the wrongly chosen skill are demoted;
 * - query tokens that are keywords of the correct skill are promoted.
 * The step is 0.1 × age decay (5-day half-life) × feedback-type factor × timeout factor.
 *
 * The current weights file is snapshotted first. The result is then written
 * through weight-store.js when available, falling back to a direct write.
 * Finally a summary is printed (text or --json).
 * Exits with code 1 when skills-index.json is missing.
 */
function learnWeights() {
  const rawFeedback = loadFeedback();
  // P0-3: learn only on the training split; holdout entries are excluded.
  const feedback = filterTrainSet(rawFeedback);
  if (feedback.length === 0) {
    console.log('No feedback data. Use --correct to record corrections first.');
    return;
  }
  const index = loadIndex();
  if (!index) {
    console.error('skills-index.json not found');
    process.exit(1);
  }
  // skill name → Set of lower-cased keywords. A Map avoids inherited keys
  // (e.g. a routedTo of "constructor" resolving to Object.prototype.constructor).
  const skillKeywords = new Map();
  for (const skill of index.skills || []) {
    const keywords = (skill.keywords || [])
      .filter(k => k && typeof k.keyword === 'string')
      .map(k => k.keyword.toLowerCase());
    skillKeywords.set(skill.name, new Set(keywords));
  }
  const rawDeltas = Object.create(null);
  // Exponential decay: newer feedback weighs more.
  const now = Date.now();
  const DECAY_HALF_LIFE = 5 * 86400000; // 5-day half-life, in ms
  for (const fb of feedback) {
    if (fb.routedTo === fb.correctedTo) continue; // not a correction
    if (fb.routedTo === 'unknown') continue; // original route unknown
    const ageMs = now - new Date(fb.ts).getTime();
    // A missing/unparseable timestamp would make the decay NaN and poison every
    // delta it touches (serialized as null), so such entries are skipped.
    if (!Number.isFinite(ageMs)) continue;
    // Future timestamps (clock skew) count as age 0 instead of amplifying the step above 0.1.
    const decay = Math.pow(0.5, Math.max(0, ageMs) / DECAY_HALF_LIFE);
    // v5.9 tiering: observed 0.8, implicit 0.5, anything else (e.g. manual --correct) 1.0.
    // NOTE(review): the original comment ranked observed above manual corrections,
    // but these factors rank manual highest — confirm which is intended.
    const typeFactor = fb.type === 'observed' ? 0.8 : (fb.type === 'implicit' ? 0.5 : 1.0);
    // `??` so an explicit weight of 0 is honoured rather than replaced by the 0.1 default.
    const timeoutFactor = fb.implicit === 'timeout-confirm' ? (fb.weight ?? 0.1) : 1.0;
    const delta = 0.1 * decay * typeFactor * timeoutFactor;
    // Implicit feedback may lack queryTokens; tokenize the raw query in that case.
    let queryTokens;
    if (Array.isArray(fb.queryTokens) && fb.queryTokens.length > 0) {
      queryTokens = new Set(fb.queryTokens);
    } else if (fb.query) {
      queryTokens = tokenize(fb.query);
    } else {
      queryTokens = new Set();
    }
    // Demote matching keywords of the wrongly chosen skill, promote those of the correct one.
    accumulateKeywordDelta(rawDeltas, fb.routedTo, skillKeywords.get(fb.routedTo), queryTokens, -delta);
    accumulateKeywordDelta(rawDeltas, fb.correctedTo, skillKeywords.get(fb.correctedTo), queryTokens, delta);
  }
  const deltas = finalizeDeltas(rawDeltas);
  // P0-2: snapshot the current weights before overwriting them.
  const snapshotPath = snapshotWeights();
  const output = {
    generated: new Date().toISOString(),
    feedbackCount: feedback.length,
    correctionCount: feedback.filter(f => f.routedTo !== f.correctedTo && f.routedTo !== 'unknown').length,
    implicitCount: feedback.filter(f => f.type === 'implicit').length,
    deltas,
  };
  ensureDebugDir();
  // MEDIUM-1: prefer weight-store's concurrency-safe writer over a hand-rolled tmp+rename.
  try {
    const weightStore = require('./weight-store.js');
    // writeWeights is async while learnWeights is synchronous: the write is not awaited
    // (it may still be pending when the summary below prints). A rejected write falls
    // back to a direct synchronous write.
    weightStore.writeWeights(output).catch(() => {
      fs.writeFileSync(WEIGHTS_FILE, JSON.stringify(output, null, 2) + '\n');
    });
  } catch {
    // weight-store unavailable (or writeWeights threw synchronously): write directly.
    fs.writeFileSync(WEIGHTS_FILE, JSON.stringify(output, null, 2) + '\n');
  }
  if (jsonMode) {
    console.log(JSON.stringify(output, null, 2));
  } else {
    const skillCount = Object.keys(deltas).length;
    const totalAdj = Object.values(deltas).reduce((n, obj) => n + Object.keys(obj).length, 0);
    console.log(`Weight learning complete:`);
    console.log(` Feedback entries: ${feedback.length} (${output.implicitCount} implicit)`);
    console.log(` Corrections: ${output.correctionCount}`);
    console.log(` Skills adjusted: ${skillCount}`);
    console.log(` Total adjustments: ${totalAdj}`);
    console.log(` Output: ${WEIGHTS_FILE}`);
    if (snapshotPath) console.log(` Snapshot: ${snapshotPath}`);
  }
}
// === v4.9: 自动学习触发 ===
/**
 * v4.9: automatic learning trigger. Runs learnWeights() once at least 10
 * feedback entries have been recorded. The count includes holdout entries,
 * because it uses loadFeedback() without filtering.
 * @returns {boolean} true when learnWeights() was invoked
 */
function autoLearn() {
  const enoughFeedback = loadFeedback().length >= 10;
  if (enoughFeedback) learnWeights();
  return enoughFeedback;
}
// === 模式 3: 反馈统计报告 ===
/**
 * Mode 3 (--report): print feedback statistics. The report covers:
 * - total entries and the correction count;
 * - accuracy (share of entries with a known route that were not corrected);
 * - the most frequent wrong and correct targets;
 * - the most common "wrong → right" pairs;
 * - whether route-weights.json exists.
 * Honors --json.
 */
function showReport() {
  const feedback = loadFeedback();
  if (feedback.length === 0) {
    console.log('No feedback data yet.');
    return;
  }
  const corrections = feedback.filter(f => f.routedTo !== f.correctedTo && f.routedTo !== 'unknown');
  // Maps, so skill names such as "constructor" cannot collide with Object.prototype members.
  const routedToCount = new Map();
  const correctedToCount = new Map();
  const pairCount = new Map(); // "wrong → right" frequency
  const bump = (counter, key) => counter.set(key, (counter.get(key) || 0) + 1);
  for (const fb of corrections) {
    bump(routedToCount, String(fb.routedTo));
    bump(correctedToCount, String(fb.correctedTo));
    // Explicit separator: plain concatenation made "a" + "bc" and "ab" + "c" indistinguishable.
    bump(pairCount, `${fb.routedTo} → ${fb.correctedTo}`);
  }
  // The n highest-count [key, count] pairs, most frequent first.
  const topN = (counter, n) => [...counter].sort((a, b) => b[1] - a[1]).slice(0, n);
  // Accuracy = uncorrected / entries whose original route is known. The '%' sign is
  // attached only to a real number, so "N/A" is no longer printed as "N/A%".
  const totalWithRoute = feedback.filter(f => f.routedTo !== 'unknown').length;
  const accuracy = totalWithRoute > 0
    ? `${((totalWithRoute - corrections.length) / totalWithRoute * 100).toFixed(1)}%`
    : 'N/A';
  // Learning status: an unreadable weights file is reported as not generated.
  let weights = null;
  if (fs.existsSync(WEIGHTS_FILE)) {
    try { weights = JSON.parse(fs.readFileSync(WEIGHTS_FILE, 'utf8')); } catch { /* treated as absent */ }
  }
  if (jsonMode) {
    console.log(JSON.stringify({
      total: feedback.length,
      corrections: corrections.length,
      accuracy,
      topMisroutes: topN(routedToCount, 5),
      topCorrections: topN(correctedToCount, 5),
      topPairs: topN(pairCount, 10),
      weightsApplied: !!weights,
    }, null, 2));
    return;
  }
  console.log('# 路由反馈统计报告\n');
  console.log(`总反馈: ${feedback.length}`);
  console.log(`纠正次数: ${corrections.length}`);
  console.log(`准确率: ${accuracy}`);
  if (corrections.length > 0) {
    console.log('\n## 误路由频率 TOP 5 (被纠正的目标)');
    topN(routedToCount, 5).forEach(([name, cnt]) => console.log(` ${name.padEnd(30)} ${cnt}`));
    console.log('\n## 纠正到 TOP 5 (正确目标)');
    topN(correctedToCount, 5).forEach(([name, cnt]) => console.log(` ${name.padEnd(30)} ${cnt}`));
    console.log('\n## 常见纠正路径');
    topN(pairCount, 10).forEach(([pair, cnt]) => console.log(` ${pair.padEnd(50)} ${cnt}`));
  }
  console.log(`\n权重文件: ${weights ? `已生成 (${weights.generated})` : '未生成 (运行 --learn)'}`);
}
// === 模块导出 (供测试使用) ===
// Exported for tests. The typeof guard assigns exports only when a CommonJS
// `module` object exists.
if (typeof module !== 'undefined') {
  module.exports = {
    tokenize,
    loadFeedback,
    loadRouteLog,
    loadIndex,
    loadSkillWhitelist,
    snapshotWeights,
    simpleHash,
    splitHoldout,
    loadHoldoutSet,
    filterTrainSet,
    validateHoldout,
    recordCorrection,
    learnWeights,
    autoLearn,
    showReport,
  };
}
// === 主入口 ===
if (require.main === module) {
  // The first enabled mode wins, in this priority order; with none enabled, print usage.
  const dispatch = [
    [correctMode, recordCorrection],
    [splitMode, splitHoldout],
    [validateMode, validateHoldout],
    [learnMode, learnWeights],
    [reportMode, showReport],
  ];
  const chosen = dispatch.find(([enabled]) => enabled);
  if (chosen) {
    const [, run] = chosen;
    run();
  } else {
    [
      'Usage:',
      ' --correct "<query>" <skill> 记录路由纠正',
      ' --split 拆分训练集/holdout 集 (7:3)',
      ' --validate 在 holdout 集上验证准确率',
      ' --learn 生成权重调整 (仅训练集)',
      ' --report 查看反馈统计',
      ' --json JSON 输出',
    ].forEach(line => console.log(line));
    process.exit(0);
  }
}