190 lines
4.9 KiB
TypeScript
190 lines
4.9 KiB
TypeScript
import { readdirSync, readFileSync } from "node:fs"
|
|
import { join } from "node:path"
|
|
import type { PluginInput } from "@opencode-ai/plugin"
|
|
import type { PruningState, ToolCallSignature } from "./pruning-types"
|
|
import { estimateTokens } from "./pruning-types"
|
|
import { log } from "../../shared/logger"
|
|
import { getMessageDir } from "../../shared/opencode-message-dir"
|
|
import { isSqliteBackend } from "../../shared/opencode-storage-detection"
|
|
import { normalizeSDKResponse } from "../../shared"
|
|
|
|
// Alias for the type of the `client` property on `PluginInput`; used in this
// module as the type of the optional SDK client that message loading accepts.
type OpencodeClient = PluginInput["client"]
|
|
|
|
/**
 * Settings for the tool-call deduplication strategy.
 */
export interface DeduplicationConfig {
  /** When false, `executeDeduplication` returns 0 without reading any messages. */
  enabled: boolean
  /** Extra tool names whose calls are never considered for deduplication. */
  protectedTools?: Array<string>
}
|
|
|
|
/**
 * One part of a stored message, as far as this module inspects it.
 *
 * Only parts whose `type` is `"tool"` or `"step-start"` are acted upon here;
 * every field other than `type` is optional.
 */
interface ToolPart {
  /** Part discriminator, e.g. `"tool"` or `"step-start"`. */
  type: string
  /** Identifier of the tool call this part belongs to. */
  callID?: string
  /** Name of the tool that was called. */
  tool?: string
  /** Call payload: `input` feeds the signature, `output` feeds the token estimate. */
  state?: { input?: unknown; output?: string }
}
|
|
|
|
/**
 * A loaded message as consumed by the deduplication pass.
 *
 * Only `parts` is read by this module; messages without parts are skipped.
 */
interface MessagePart {
  /** Message discriminator; not inspected by the code in this module. */
  type: string
  /** Parts carried by the message, when present. */
  parts?: Array<ToolPart>
}
|
|
|
|
/**
 * Builds the key used to recognise identical tool calls.
 *
 * The key has the form `<toolName>::<json>`, where `<json>` is the JSON
 * serialization of `input` after its object keys have been put into sorted
 * order by `sortObject`, so key order in the input does not affect the result.
 *
 * @param toolName - Name of the tool that was invoked.
 * @param input - Arguments passed to the tool; may be any value, including undefined.
 * @returns The signature string.
 */
export function createToolSignature(toolName: string, input: unknown): string {
  // JSON.stringify yields undefined for an undefined input; concatenation then
  // renders it as the text "undefined", exactly like a template literal would.
  const canonicalInput = JSON.stringify(sortObject(input))
  return toolName + "::" + canonicalInput
}
|
|
|
|
/**
 * Recursively copies a JSON-like value so that every plain object lists its
 * keys in sorted order, giving a canonical form for serialization.
 *
 * Primitives, `null` and `undefined` are returned unchanged. Arrays keep their
 * element order; each element is canonicalized in turn.
 *
 * @param obj - Value to canonicalize.
 * @returns A key-sorted copy for objects and arrays, otherwise the input itself.
 */
function sortObject(obj: unknown): unknown {
  if (obj === null || typeof obj !== "object") return obj
  if (Array.isArray(obj)) return obj.map((item) => sortObject(item))

  const source = obj as Record<string, unknown>
  // Object.fromEntries defines own data properties, so a key literally named
  // "__proto__" (possible after JSON.parse) is kept as data. Plain assignment
  // (`result[key] = value`) would instead replace the prototype and silently
  // drop that key from the serialized output.
  return Object.fromEntries(
    Object.keys(source)
      .sort()
      .map((key): [string, unknown] => [key, sortObject(source[key])]),
  )
}
|
|
|
|
/**
 * Loads a session's messages from the directory returned by `getMessageDir`.
 *
 * Every `*.json` file in that directory is parsed; a file contributes a
 * message only when its top-level value has an array `parts` property.
 * Loading is best-effort: a file that cannot be read or parsed is logged and
 * skipped so that one bad file does not discard the remaining messages.
 *
 * @param sessionID - Session whose message directory is read.
 * @returns The messages that carry parts, or an empty array when the session
 *          has no message directory or the directory cannot be listed.
 */
function readMessages(sessionID: string): MessagePart[] {
  const messageDir = getMessageDir(sessionID)
  if (!messageDir) return []

  let files: string[]
  try {
    files = readdirSync(messageDir).filter((name: string) => name.endsWith(".json"))
  } catch (error) {
    log("[pruning-deduplication] failed to list message directory", {
      sessionID,
      error: String(error),
    })
    return []
  }

  const messages: MessagePart[] = []
  for (const file of files) {
    try {
      const parsed: unknown = JSON.parse(readFileSync(join(messageDir, file), "utf-8"))
      // Optional chaining keeps `null` and primitive JSON documents from throwing.
      const parts = (parsed as { parts?: unknown } | null)?.parts
      // Require an array so later `for…of` loops over `parts` cannot fail.
      if (Array.isArray(parts)) {
        messages.push(parsed as MessagePart)
      }
    } catch (error) {
      log("[pruning-deduplication] skipped unreadable message file", {
        sessionID,
        file,
        error: String(error),
      })
    }
  }

  return messages
}
|
|
|
|
/**
 * Loads a session's messages through the SDK client instead of the filesystem.
 *
 * The raw response is passed through `normalizeSDKResponse` (with an empty
 * array as fallback) and only entries that have `parts` are returned.
 * Loading is best-effort: any failure is logged and yields an empty list.
 *
 * @param client - SDK client used to query the session's messages.
 * @param sessionID - Session whose messages are requested.
 * @returns Messages that carry parts, or an empty array on failure.
 */
async function readMessagesFromSDK(client: OpencodeClient, sessionID: string): Promise<MessagePart[]> {
  try {
    const response = await client.session.messages({ path: { id: sessionID } })
    const rawMessages = normalizeSDKResponse(response, [] as Array<{ parts?: ToolPart[] }>, { preferResponseOnMissingData: true })
    return rawMessages.filter((message) => message.parts) as MessagePart[]
  } catch (error) {
    // Surface the failure in the log instead of discarding it; callers still
    // receive an empty list and simply find nothing to deduplicate.
    log("[pruning-deduplication] failed to read messages from SDK", {
      sessionID,
      error: String(error),
    })
    return []
  }
}
|
|
|
|
/**
 * Marks redundant tool calls of a session for pruning.
 *
 * Two calls count as duplicates when `createToolSignature` yields the same
 * signature for them (same tool name and same canonicalized input). For each
 * signature seen more than once, every occurrence except the last one in
 * iteration order is added to `state.toolIdsToPrune`.
 *
 * Messages come from the SDK client when one is supplied and
 * `isSqliteBackend()` is true; otherwise they are read via `readMessages`.
 *
 * @param sessionID - Session whose messages are scanned.
 * @param state - Pruning state, updated in place (`toolIdsToPrune`, `toolSignatures`).
 * @param config - Strategy settings; when `enabled` is false nothing is read or changed.
 * @param protectedTools - Tool names whose calls are never deduplicated.
 * @param client - Optional SDK client used for message loading.
 * @returns The number of tool calls marked for pruning by this invocation.
 */
export async function executeDeduplication(
  sessionID: string,
  state: PruningState,
  config: DeduplicationConfig,
  protectedTools: Set<string>,
  client?: OpencodeClient,
): Promise<number> {
  if (!config.enabled) return 0

  const messages = (client && isSqliteBackend())
    ? await readMessagesFromSDK(client, sessionID)
    : readMessages(sessionID)

  const signatures = new Map<string, ToolCallSignature[]>()
  // Guards against the same call showing up twice in the loaded messages,
  // which would otherwise make a call look like a duplicate of itself.
  const seenCallIDs = new Set<string>()
  let currentTurn = 0

  for (const msg of messages) {
    if (!msg.parts) continue

    for (const part of msg.parts) {
      if (part.type === "step-start") {
        currentTurn++
        continue
      }

      if (part.type !== "tool" || !part.callID || !part.tool) continue

      const callID = part.callID
      const toolName = part.tool

      if (protectedTools.has(toolName)) continue
      if (config.protectedTools?.includes(toolName)) continue
      if (state.toolIdsToPrune.has(callID)) continue
      if (seenCallIDs.has(callID)) continue
      seenCallIDs.add(callID)

      const signature = createToolSignature(toolName, part.state?.input)
      const call: ToolCallSignature = { toolName, signature, callID, turn: currentTurn }

      let calls = signatures.get(signature)
      if (!calls) {
        calls = []
        signatures.set(signature, calls)
      }
      calls.push(call)

      let recorded = state.toolSignatures.get(signature)
      if (!recorded) {
        recorded = []
        state.toolSignatures.set(signature, recorded)
      }
      // The state may outlive a single run; record each call only once so
      // repeated runs do not pile up identical entries.
      if (!recorded.some((entry) => entry.callID === callID)) {
        recorded.push({ ...call })
      }
    }
  }

  let prunedCount = 0
  let tokensSaved = 0

  for (const [signature, calls] of signatures) {
    // Keep the last call of each signature; the slice is empty for a single call.
    for (const call of calls.slice(0, -1)) {
      state.toolIdsToPrune.add(call.callID)
      prunedCount++

      const output = findToolOutput(messages, call.callID)
      if (output) {
        tokensSaved += estimateTokens(output)
      }

      log("[pruning-deduplication] pruned duplicate", {
        tool: call.toolName,
        callID: call.callID,
        turn: call.turn,
        signature: signature.substring(0, 100),
      })
    }
  }

  log("[pruning-deduplication] complete", {
    prunedCount,
    tokensSaved,
    uniqueSignatures: signatures.size,
  })

  return prunedCount
}
|
|
|
|
/**
 * Looks up the recorded output of a tool call.
 *
 * Messages are scanned in order; the first `"tool"` part with a matching
 * `callID` and a non-empty `state.output` wins.
 *
 * @param messages - Messages to search.
 * @param callID - Identifier of the tool call.
 * @returns The output text, or `null` when no matching part has a non-empty output.
 */
function findToolOutput(messages: MessagePart[], callID: string): string | null {
  for (const { parts } of messages) {
    const hit = parts?.find(
      (candidate) => candidate.type === "tool" && candidate.callID === callID && Boolean(candidate.state?.output),
    )
    const text = hit?.state?.output
    if (text) return text
  }

  return null
}
|