bookworm-smart-assistant/scripts/route-feedback.js

586 lines
20 KiB
JavaScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env node
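/**
 * Route-feedback weight learning (v4.2)
 *
 * Records routing corrections → analyzes patterns → generates weight adjustments.
 * Works together with route-analyzer.js to form a closed learning loop.
 *
 * Usage:
 *   node route-feedback.js --correct "<query>" <actualSkill>   record a correction
 *   node route-feedback.js --split                             split corrections into train / holdout sets
 *   node route-feedback.js --validate                          evaluate routing accuracy on the holdout set
 *   node route-feedback.js --learn                             generate weight adjustments
 *   node route-feedback.js --report                            feedback statistics
 *   node route-feedback.js --json                              JSON output
 *
 * Files:
 *   debug/route-feedback.jsonl   correction records
 *   debug/route-weights.json     learned weight deltas
 *   debug/holdout-set.json       holdout query list written by --split
 *   debug/weights-history/       snapshots of previous weight files
 */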
const fs = require('fs');
const path = require('path');
// Resolve the project root from PATHS.root exported by ./paths.config.js.
const detectClaudeRoot = () => require('./paths.config.js').PATHS.root;
const ROOT = detectClaudeRoot();
// Skill index read by loadIndex().
const INDEX_FILE = path.join(ROOT, 'skills-index.json');
// Directory holding every file this script reads or writes.
const DEBUG_DIR = path.join(ROOT, 'debug');
// Append-only JSONL log of recorded corrections.
const FEEDBACK_FILE = path.join(DEBUG_DIR, 'route-feedback.jsonl');
// Output of --learn: per-skill keyword weight deltas.
const WEIGHTS_FILE = path.join(DEBUG_DIR, 'route-weights.json');
// Directory where snapshotWeights() keeps copies of earlier weight files.
const WEIGHTS_HISTORY_DIR = path.join(DEBUG_DIR, 'weights-history');
// Output of --split: the list of held-out queries.
const HOLDOUT_FILE = path.join(DEBUG_DIR, 'holdout-set.json');
// Maximum number of weight snapshots retained by snapshotWeights().
const MAX_WEIGHT_SNAPSHOTS = 20;
const HOLDOUT_RATIO = 0.3; // 30% of corrections go to the holdout set
// === Argument parsing ===
// Everything after `node <script>`; each mode flag is a simple presence check.
const args = process.argv.slice(2);
const correctMode = args.includes('--correct');
const learnMode = args.includes('--learn');
const reportMode = args.includes('--report');
const jsonMode = args.includes('--json');
const validateMode = args.includes('--validate');
const splitMode = args.includes('--split');
// === Utility functions ===
/**
 * Make sure the debug directory exists, creating it (and any missing
 * parent directories) when it does not.
 */
function ensureDebugDir() {
  if (fs.existsSync(DEBUG_DIR)) {
    return;
  }
  fs.mkdirSync(DEBUG_DIR, { recursive: true });
}
/**
 * Build the whitelist of valid skill names from skills-index.json.
 *
 * @returns {Set<string>|null} The `name` of every indexed skill, or null
 *   when the index cannot be loaded or has no `skills` list.
 */
function loadSkillWhitelist() {
  const skills = loadIndex()?.skills;
  if (!skills) {
    return null;
  }
  const names = skills.map((skill) => skill.name);
  return new Set(names);
}
/**
 * Read and parse skills-index.json.
 *
 * @returns {*} The parsed JSON content, or null when the file does not exist.
 * @throws {SyntaxError} When the file exists but is not valid JSON.
 */
function loadIndex() {
  if (fs.existsSync(INDEX_FILE)) {
    const raw = fs.readFileSync(INDEX_FILE, 'utf8');
    return JSON.parse(raw);
  }
  return null;
}
/**
 * Read a JSONL file into an array of parsed values.
 *
 * Blank lines and lines that fail to parse are skipped, as are lines whose
 * parsed value is falsy (e.g. `null`, `0`, `""`).
 *
 * @param {string} filePath - Path of the JSONL file.
 * @returns {Array<*>} Parsed entries in file order; empty when the file is missing.
 */
function readJsonl(filePath) {
  if (!fs.existsSync(filePath)) return [];
  const entries = [];
  for (const line of fs.readFileSync(filePath, 'utf8').split('\n')) {
    if (!line) continue;
    let parsed;
    try {
      parsed = JSON.parse(line);
    } catch {
      // A corrupt line must not make the rest of the log unreadable.
      continue;
    }
    if (parsed) entries.push(parsed);
  }
  return entries;
}

/**
 * Load every recorded feedback entry from the feedback JSONL file.
 *
 * @returns {Array<object>} Parsed feedback entries; empty when the file is missing.
 */
function loadFeedback() {
  return readJsonl(FEEDBACK_FILE);
}

/**
 * Load the routing log for one day.
 *
 * @param {string} dateStr - Date portion of the log file name (`route-<dateStr>.jsonl`).
 * @returns {Array<object>} Parsed log entries; empty when the file is missing.
 */
function loadRouteLog(dateStr) {
  return readJsonl(path.join(DEBUG_DIR, `route-${dateStr}.jsonl`));
}
/**
 * P0-3: deterministic hash of a string.
 *
 * Folds each UTF-16 code unit into a 32-bit accumulator using
 * `acc * 31 + codeUnit` (wrapped to int32) and returns its absolute value.
 *
 * @param {string} str - Text to hash.
 * @returns {number} Non-negative integer hash; 0 for the empty string.
 */
function simpleHash(str) {
  const codeUnits = Array.from({ length: str.length }, (_, pos) => str.charCodeAt(pos));
  const folded = codeUnits.reduce((acc, unit) => ((acc << 5) - acc + unit) | 0, 0);
  return Math.abs(folded);
}
/**
 * P0-3: split recorded corrections into a training set and a holdout set,
 * then write the holdout query list to holdout-set.json.
 *
 * A feedback entry counts as a correction when its `routedTo` differs from
 * `correctedTo` and is not 'unknown'. Each correction is assigned by hashing
 * its query into one of 100 buckets, so the same query always lands in the
 * same set. Exits the process with code 1 when fewer than 5 corrections exist.
 *
 * @returns {object} The holdout descriptor that was written to disk.
 */
function splitHoldout() {
  const corrections = loadFeedback().filter(
    (entry) => entry.routedTo !== entry.correctedTo && entry.routedTo !== 'unknown'
  );
  const total = corrections.length;
  if (total < 5) {
    console.error(`纠正数据不足 (${total} 条),至少需要 5 条才能拆分`);
    process.exit(1);
  }

  // Buckets below this threshold belong to the holdout set.
  const threshold = HOLDOUT_RATIO * 100;
  const holdoutQueries = [];
  const trainQueries = [];
  corrections.forEach((entry) => {
    const bucket = simpleHash(entry.query) % 100;
    const target = bucket < threshold ? holdoutQueries : trainQueries;
    target.push(entry.query);
  });

  const holdoutData = {
    generated: new Date().toISOString(),
    ratio: HOLDOUT_RATIO,
    totalCorrections: total,
    holdoutCount: holdoutQueries.length,
    trainCount: trainQueries.length,
    holdoutQueries,
  };
  ensureDebugDir();
  fs.writeFileSync(HOLDOUT_FILE, JSON.stringify(holdoutData, null, 2) + '\n');

  const percentOfTotal = (count) => Math.round(count / total * 100);
  [
    `Holdout 拆分完成:`,
    ` 纠正总数: ${total}`,
    ` 训练集: ${trainQueries.length} (${percentOfTotal(trainQueries.length)}%)`,
    ` Holdout: ${holdoutQueries.length} (${percentOfTotal(holdoutQueries.length)}%)`,
    ` 输出: ${HOLDOUT_FILE}`,
  ].forEach((line) => console.log(line));
  return holdoutData;
}
/**
 * P0-3: load the holdout descriptor from holdout-set.json.
 *
 * @returns {*} The parsed file content, or null when the file is missing
 *   or cannot be read/parsed.
 */
function loadHoldoutSet() {
  if (!fs.existsSync(HOLDOUT_FILE)) {
    return null;
  }
  let parsed;
  try {
    parsed = JSON.parse(fs.readFileSync(HOLDOUT_FILE, 'utf8'));
  } catch {
    // An unreadable holdout file is treated the same as a missing one.
    parsed = null;
  }
  return parsed;
}
/**
 * P0-3: drop holdout entries so that only the training set remains.
 *
 * @param {Array<object>} feedback - Feedback entries to filter.
 * @returns {Array<object>} Entries whose `query` is not listed in the holdout
 *   set; the input array itself when no holdout descriptor is available.
 */
function filterTrainSet(feedback) {
  const holdout = loadHoldoutSet();
  if (holdout) {
    const heldOut = new Set(holdout.holdoutQueries);
    return feedback.filter((entry) => !heldOut.has(entry.query));
  }
  // Without a holdout file nothing is filtered out.
  return feedback;
}
/**
 * P0-3: measure how many held-out corrections the current routing engine
 * now resolves correctly.
 *
 * For each held-out correction the query is re-scored against every indexed
 * skill with route-analyzer.js; the entry counts as "fixed" when the
 * top-ranked skill equals the entry's `correctedTo`, otherwise as "missed".
 * The result is printed (JSON when --json is set).
 *
 * @returns {{holdoutSize: number, fixed: number, missed: number, accuracy: number}|null}
 *   The evaluation summary, or null when the holdout set, the analyzer
 *   module or the skill index is unavailable.
 */
function validateHoldout() {
  const holdout = loadHoldoutSet();
  if (!holdout || holdout.holdoutQueries.length === 0) {
    console.log('无 holdout 集。运行 --split 先生成。');
    return null;
  }

  const heldOut = new Set(holdout.holdoutQueries);
  const holdoutEntries = loadFeedback().filter((entry) => {
    if (!heldOut.has(entry.query)) return false;
    return entry.routedTo !== 'unknown' && entry.routedTo !== entry.correctedTo;
  });

  // Evaluate the current engine the same way ab-backtest does.
  let analyzer;
  try {
    analyzer = require('./route-analyzer.js');
  } catch {
    console.error('route-analyzer.js 不可用');
    return null;
  }
  const index = loadIndex();
  if (!index) {
    console.error('skills-index.json 不可用');
    return null;
  }
  const bm25Params = analyzer.buildBM25Params ? analyzer.buildBM25Params(index) : null;

  let fixed = 0;
  let missed = 0;
  for (const entry of holdoutEntries) {
    const queryTokens = analyzer.tokenize(entry.query);
    const ranked = index.skills
      .map((skill) => {
        const scored = analyzer.scoreSkill(skill, queryTokens, bm25Params);
        return { name: skill.name, score: Math.round(scored.totalScore * 100) / 100 };
      })
      .sort((left, right) => right.score - left.score);

    // v5.9: applyDisambiguation returns a { results, firedRules } object;
    // a bare array return is still accepted.
    let candidates = ranked;
    if (analyzer.applyDisambiguation) {
      const outcome = analyzer.applyDisambiguation(ranked, entry.query, index);
      candidates = Array.isArray(outcome) ? outcome : (outcome?.results || ranked);
    }
    if (!candidates || candidates.length === 0) {
      missed += 1;
      continue;
    }

    const normalized = analyzer.normalizeScores ? analyzer.normalizeScores(candidates) : candidates;
    const best = normalized[0];
    const isHit = best && best.name === entry.correctedTo;
    if (isHit) {
      fixed += 1;
    } else {
      missed += 1;
    }
  }

  const size = holdoutEntries.length;
  const result = {
    holdoutSize: size,
    fixed,
    missed,
    // Percentage rounded to one decimal place.
    accuracy: size > 0 ? Math.round(fixed / size * 100 * 10) / 10 : 0,
  };

  if (jsonMode) {
    console.log(JSON.stringify(result, null, 2));
    return result;
  }
  console.log(`Holdout 验证结果:`);
  console.log(` Holdout 纠正数: ${size}`);
  console.log(` 修复: ${fixed}, 未修复: ${missed}`);
  console.log(` 准确率: ${result.accuracy}%`);
  return result;
}
/**
 * P0-2: back up the current weights file into the history directory and
 * prune old snapshots so that at most MAX_WEIGHT_SNAPSHOTS remain.
 *
 * Snapshot names embed an ISO timestamp (with ':' and '.' replaced by '-'),
 * so sorting by name puts the oldest snapshots first.
 *
 * @returns {string|null} Path of the new snapshot, or null when there is no
 *   weights file to back up.
 */
function snapshotWeights() {
  if (!fs.existsSync(WEIGHTS_FILE)) {
    return null;
  }
  if (!fs.existsSync(WEIGHTS_HISTORY_DIR)) {
    fs.mkdirSync(WEIGHTS_HISTORY_DIR, { recursive: true });
  }

  const stamp = new Date().toISOString().replace(/[:.]/g, '-');
  const dest = path.join(WEIGHTS_HISTORY_DIR, `route-weights-${stamp}.json`);
  fs.copyFileSync(WEIGHTS_FILE, dest);

  // Remove the oldest snapshots beyond the retention limit.
  const snapshots = fs.readdirSync(WEIGHTS_HISTORY_DIR)
    .filter((name) => name.startsWith('route-weights-') && name.endsWith('.json'))
    .sort();
  const surplus = Math.max(0, snapshots.length - MAX_WEIGHT_SNAPSHOTS);
  for (const stale of snapshots.slice(0, surplus)) {
    fs.unlinkSync(path.join(WEIGHTS_HISTORY_DIR, stale));
  }
  return dest;
}
/**
 * Lightweight text tokenizer (intended to stay consistent with the one in
 * route-analyzer.js).
 *
 * - Each contiguous run of CJK characters (U+4E00..U+9FFF) contributes every
 *   substring of length 2 to 6.
 * - Each English phrase of up to three whitespace-separated words contributes
 *   the lower-cased phrase plus each component (split on whitespace, '.' and
 *   '-') that is at least two characters long.
 *
 * @param {string} text - Raw query text.
 * @returns {Set<string>} The token set, passed through expandSynonyms from
 *   ./synonym-expander.js when that module can be loaded.
 */
function tokenize(text) {
  const tokens = new Set();

  const cjkRuns = text.match(/[\u4e00-\u9fff]+/g) || [];
  cjkRuns.forEach((run) => {
    const longest = Math.min(6, run.length);
    for (let size = 2; size <= longest; size += 1) {
      const lastStart = run.length - size;
      for (let start = 0; start <= lastStart; start += 1) {
        tokens.add(run.slice(start, start + size).toLowerCase());
      }
    }
  });

  const phrases = text.match(/[A-Za-z][\w.-]*(?:\s+[A-Za-z][\w.-]*){0,2}/g) || [];
  phrases.forEach((phrase) => {
    tokens.add(phrase.toLowerCase().trim());
    phrase
      .split(/[\s.-]+/)
      .filter((part) => part.length >= 2)
      .forEach((part) => tokens.add(part.toLowerCase()));
  });

  // v4.9: synonym expansion; fall back to the raw tokens when it is unavailable.
  try {
    const { expandSynonyms } = require('./synonym-expander.js');
    return expandSynonyms(tokens);
  } catch {
    return tokens;
  }
}
// === Mode 1: record a correction ===
/**
 * Record that a query should have been routed to a different skill.
 *
 * Reads the query and the correct skill from the CLI arguments following
 * `--correct` (the last non-flag argument is the skill, everything before it
 * is joined into the query). It validates the skill against the index
 * whitelist when one is available, looks up the most recent matching entry
 * in the last 7 daily route logs, and appends the correction to the
 * feedback JSONL file.
 *
 * Exits the process with code 1 on missing arguments, an empty query, or an
 * unknown skill name.
 */
function recordCorrection() {
  const flagIdx = args.indexOf('--correct');
  const remaining = args.filter((a, i) => i > flagIdx && !a.startsWith('--'));
  if (remaining.length < 2) {
    console.error('Usage: --correct "<query>" <correctSkill>');
    process.exit(1);
  }
  const correctSkill = remaining.pop();
  const query = remaining.join(' ');
  // An empty needle would match every log line and attribute the correction
  // to an unrelated routing decision.
  if (query.trim() === '') {
    console.error('Usage: --correct "<query>" <correctSkill>');
    process.exit(1);
  }

  // P0-1: validate the skill name against the whitelist.
  const whitelist = loadSkillWhitelist();
  if (whitelist && !whitelist.has(correctSkill)) {
    console.error(`Invalid skill name: "${correctSkill}"`);
    console.error(`Valid skills: ${Array.from(whitelist).sort().join(', ')}`);
    process.exit(1);
  }

  // Search the recent route logs, newest first, for this query's routing result.
  // Only the first 50 characters of the query are used as the search needle.
  const needle = query.toLowerCase().slice(0, 50);
  let routedTo = 'unknown';
  let topConfidence = 0;
  for (let d = 0; d < 7; d++) {
    const date = new Date(Date.now() - d * 86400000).toISOString().slice(0, 10);
    const match = loadRouteLog(date)
      .reverse()
      .find((l) => typeof l.query === 'string' && l.query.toLowerCase().includes(needle));
    if (match) {
      // Keep the defaults when the log entry lacks these fields, so the
      // stored entry always has a routedTo and a numeric confidence.
      routedTo = match.topResult ?? 'unknown';
      topConfidence = Number.isFinite(match.topConfidence) ? match.topConfidence : 0;
      break;
    }
  }

  const entry = {
    ts: new Date().toISOString(),
    query: query.slice(0, 200),
    routedTo,
    correctedTo: correctSkill,
    topConfidence,
    queryTokens: Array.from(tokenize(query)),
  };
  ensureDebugDir();
  fs.appendFileSync(FEEDBACK_FILE, JSON.stringify(entry) + '\n');

  if (jsonMode) {
    console.log(JSON.stringify(entry, null, 2));
  } else {
    console.log(`Correction recorded:`);
    console.log(` Query: "${query}"`);
    console.log(` Routed: ${routedTo} (${(topConfidence * 100).toFixed(0)}%)`);
    console.log(` Correct: ${correctSkill}`);
  }
}
// === Mode 2: learn weight adjustments ===
/**
 * Turn recorded corrections into per-skill keyword weight deltas and write
 * them to the weights file.
 *
 * For every correction in the training set (holdout entries excluded), each
 * query token that is a keyword of the wrongly chosen skill is penalized and
 * each token that is a keyword of the correct skill is boosted. The step
 * size is 0.1 scaled by an exponential time decay (5-day half-life) and by
 * the feedback type. Accumulated deltas smaller than 0.01 in magnitude are
 * dropped; the rest are rounded to two decimals and clamped to [-0.5, +0.5].
 *
 * Entries whose `ts` cannot be parsed are skipped, because a NaN decay would
 * otherwise turn every delta they touch into NaN. Timestamps in the future
 * are treated as "now" so that they cannot amplify a delta.
 *
 * Exits the process with code 1 when skills-index.json is missing.
 */
function learnWeights() {
  const rawFeedback = loadFeedback();
  // P0-3: learn on the training set only; holdout entries are excluded.
  const feedback = filterTrainSet(rawFeedback);
  if (feedback.length === 0) {
    console.log('No feedback data. Use --correct to record corrections first.');
    return;
  }
  const index = loadIndex();
  if (!index) {
    console.error('skills-index.json not found');
    process.exit(1);
  }

  // skill name -> set of lower-cased keywords; skills without keywords get an empty set.
  const skillKeywords = {};
  for (const skill of index.skills || []) {
    skillKeywords[skill.name] = new Set(
      (skill.keywords || []).map(k => k.keyword.toLowerCase())
    );
  }

  // Weight deltas: { skillName: { keyword: delta } }
  const deltas = {};
  const addDelta = (skillName, keyword, amount) => {
    if (!deltas[skillName]) deltas[skillName] = {};
    deltas[skillName][keyword] = (deltas[skillName][keyword] || 0) + amount;
  };

  // Exponential decay: newer feedback carries more weight.
  const now = Date.now();
  const DECAY_HALF_LIFE = 5 * 86400000; // 5-day half-life
  let skippedInvalidTs = 0;
  for (const fb of feedback) {
    if (fb.routedTo === fb.correctedTo) continue; // not a correction
    if (fb.routedTo === 'unknown') continue; // no original routing result

    const recordedAt = new Date(fb.ts).getTime();
    if (!Number.isFinite(recordedAt)) {
      skippedInvalidTs++;
      continue;
    }
    const age = Math.max(0, now - recordedAt);
    const decay = Math.pow(0.5, age / DECAY_HALF_LIFE);

    // v5.9: tiered feedback weighting. As implemented: entries without a
    // special type use 1.0, 'observed' 0.8, 'implicit' 0.5; timeout
    // confirmations are further scaled by their own weight (default 0.1).
    const typeFactor = fb.type === 'observed' ? 0.8 : (fb.type === 'implicit' ? 0.5 : 1.0);
    const timeoutFactor = fb.implicit === 'timeout-confirm' ? (fb.weight || 0.1) : 1.0;
    const delta = 0.1 * decay * typeFactor * timeoutFactor;

    // Implicit feedback carries no queryTokens, so tokenize the query on the spot.
    let queryTokens;
    if (fb.queryTokens && fb.queryTokens.length > 0) {
      queryTokens = new Set(fb.queryTokens);
    } else if (fb.query) {
      queryTokens = tokenize(fb.query);
    } else {
      queryTokens = new Set();
    }

    // Penalize matching keywords of the wrongly chosen skill.
    const wrongKw = skillKeywords[fb.routedTo];
    if (wrongKw) {
      for (const token of queryTokens) {
        if (wrongKw.has(token)) addDelta(fb.routedTo, token, -delta);
      }
    }
    // Boost matching keywords of the correct skill.
    const rightKw = skillKeywords[fb.correctedTo];
    if (rightKw) {
      for (const token of queryTokens) {
        if (rightKw.has(token)) addDelta(fb.correctedTo, token, delta);
      }
    }
  }
  if (skippedInvalidTs > 0) {
    // stderr keeps --json output on stdout parseable.
    console.error(`Skipped ${skippedInvalidTs} feedback entries with an invalid timestamp`);
  }

  // Drop negligible deltas (|delta| < 0.01); round and clamp the rest.
  for (const skill of Object.keys(deltas)) {
    for (const kw of Object.keys(deltas[skill])) {
      if (Math.abs(deltas[skill][kw]) < 0.01) {
        delete deltas[skill][kw];
      } else {
        // Clamp to [-0.5, +0.5]
        deltas[skill][kw] = Math.max(-0.5, Math.min(0.5, Math.round(deltas[skill][kw] * 100) / 100));
      }
    }
    if (Object.keys(deltas[skill]).length === 0) delete deltas[skill];
  }

  // P0-2: snapshot the current weights before overwriting them.
  const snapshotPath = snapshotWeights();

  const output = {
    generated: new Date().toISOString(),
    feedbackCount: feedback.length,
    correctionCount: feedback.filter(f => f.routedTo !== f.correctedTo && f.routedTo !== 'unknown').length,
    implicitCount: feedback.filter(f => f.type === 'implicit').length,
    deltas,
  };
  ensureDebugDir();

  const writeDirect = () => {
    fs.writeFileSync(WEIGHTS_FILE, JSON.stringify(output, null, 2) + '\n');
  };
  // MEDIUM-1: prefer weight-store's concurrency-safe writer. writeWeights is
  // async while this function is synchronous, so the write is not awaited;
  // a rejected write falls back to a direct write.
  try {
    const weightStore = require('./weight-store.js');
    weightStore.writeWeights(output).catch(writeDirect);
  } catch {
    // weight-store is unavailable: write the file directly.
    writeDirect();
  }

  if (jsonMode) {
    console.log(JSON.stringify(output, null, 2));
  } else {
    const skillCount = Object.keys(deltas).length;
    const totalAdj = Object.values(deltas).reduce((n, obj) => n + Object.keys(obj).length, 0);
    console.log(`Weight learning complete:`);
    console.log(` Feedback entries: ${feedback.length} (${output.implicitCount} implicit)`);
    console.log(` Corrections: ${output.correctionCount}`);
    console.log(` Skills adjusted: ${skillCount}`);
    console.log(` Total adjustments: ${totalAdj}`);
    console.log(` Output: ${WEIGHTS_FILE}`);
    if (snapshotPath) console.log(` Snapshot: ${snapshotPath}`);
  }
}
// === v4.9: automatic learning trigger ===
/**
 * Run learnWeights() once at least 10 feedback entries have been recorded.
 * @returns {boolean} Whether learning was executed.
 */
function autoLearn() {
  const hasEnoughFeedback = loadFeedback().length >= 10;
  if (hasEnoughFeedback) {
    learnWeights();
  }
  return hasEnoughFeedback;
}
// === Mode 3: feedback statistics report ===
/**
 * Print a summary of the recorded feedback: totals, routing accuracy, the
 * most frequently misrouted skills, the most frequent correct targets and
 * the most common "wrong → right" correction pairs. Emits JSON when --json
 * is set, a text report otherwise.
 *
 * Accuracy is the share of entries with a known route that were not
 * corrections; it is reported as 'N/A' when no entry has a known route.
 */
function showReport() {
  const feedback = loadFeedback();
  if (feedback.length === 0) {
    console.log('No feedback data yet.');
    return;
  }

  const corrections = feedback.filter(f => f.routedTo !== f.correctedTo && f.routedTo !== 'unknown');
  const routedToCount = {};
  const correctedToCount = {};
  const pairCount = {}; // frequency of each "wrong → right" pair
  for (const fb of corrections) {
    routedToCount[fb.routedTo] = (routedToCount[fb.routedTo] || 0) + 1;
    correctedToCount[fb.correctedTo] = (correctedToCount[fb.correctedTo] || 0) + 1;
    // The separator keeps distinct pairs from collapsing into one key
    // (e.g. "ab"+"c" vs "a"+"bc") and makes the report readable.
    const pair = `${fb.routedTo} → ${fb.correctedTo}`;
    pairCount[pair] = (pairCount[pair] || 0) + 1;
  }

  // Accuracy = non-corrections / entries with a known route.
  const totalWithRoute = feedback.filter(f => f.routedTo !== 'unknown').length;
  const accuracyLabel = totalWithRoute > 0
    ? `${((totalWithRoute - corrections.length) / totalWithRoute * 100).toFixed(1)}%`
    : 'N/A';

  // Learning state: an unreadable weights file is reported as "not generated".
  let weights = null;
  if (fs.existsSync(WEIGHTS_FILE)) {
    try {
      weights = JSON.parse(fs.readFileSync(WEIGHTS_FILE, 'utf8'));
    } catch {
      weights = null;
    }
  }

  // Highest counts first, limited to the first `limit` entries.
  const topEntries = (counts, limit) =>
    Object.entries(counts).sort((a, b) => b[1] - a[1]).slice(0, limit);

  if (jsonMode) {
    console.log(JSON.stringify({
      total: feedback.length,
      corrections: corrections.length,
      accuracy: accuracyLabel,
      topMisroutes: topEntries(routedToCount, 5),
      topCorrections: topEntries(correctedToCount, 5),
      topPairs: topEntries(pairCount, 10),
      weightsApplied: !!weights,
    }, null, 2));
    return;
  }

  console.log('# 路由反馈统计报告\n');
  console.log(`总反馈: ${feedback.length}`);
  console.log(`纠正次数: ${corrections.length}`);
  console.log(`准确率: ${accuracyLabel}`);
  if (corrections.length > 0) {
    console.log('\n## 误路由频率 TOP 5 (被纠正的目标)');
    topEntries(routedToCount, 5)
      .forEach(([name, cnt]) => console.log(` ${name.padEnd(30)} ${cnt}`));
    console.log('\n## 纠正到 TOP 5 (正确目标)');
    topEntries(correctedToCount, 5)
      .forEach(([name, cnt]) => console.log(` ${name.padEnd(30)} ${cnt}`));
    console.log('\n## 常见纠正路径');
    topEntries(pairCount, 10)
      .forEach(([pair, cnt]) => console.log(` ${pair.padEnd(50)} ${cnt}`));
  }
  console.log(`\n权重文件: ${weights ? `已生成 (${weights.generated})` : '未生成 (运行 --learn)'}`);
}
// === Module exports (for use by tests) ===
// Guarded so that the file still evaluates where `module` is not defined.
if (typeof module !== 'undefined') {
module.exports = {
tokenize,
loadFeedback,
loadRouteLog,
loadIndex,
loadSkillWhitelist,
snapshotWeights,
simpleHash,
splitHoldout,
loadHoldoutSet,
filterTrainSet,
validateHoldout,
recordCorrection,
learnWeights,
autoLearn,
showReport,
};
}
// === Main entry point ===
// Runs only when the file is executed directly, not when it is required.
if (require.main === module) {
  // Ordered by priority: the first enabled mode wins.
  const modes = [
    [correctMode, recordCorrection],
    [splitMode, splitHoldout],
    [validateMode, validateHoldout],
    [learnMode, learnWeights],
    [reportMode, showReport],
  ];
  const selected = modes.find(([enabled]) => enabled);
  if (selected) {
    const [, run] = selected;
    run();
  } else {
    // No mode flag given: print usage and exit successfully.
    [
      'Usage:',
      ' --correct "<query>" <skill> 记录路由纠正',
      ' --split 拆分训练集/holdout 集 (7:3)',
      ' --validate 在 holdout 集上验证准确率',
      ' --learn 生成权重调整 (仅训练集)',
      ' --report 查看反馈统计',
      ' --json JSON 输出',
    ].forEach((line) => console.log(line));
    process.exit(0);
  }
}