llm-bot/index.ts
2024-02-09 20:28:27 -08:00

199 lines
6.6 KiB
TypeScript

import { TextDelimiterStream } from "https://deno.land/std@0.197.0/streams/mod.ts";
import { Client } from "https://git.celery.eu.org/tezlm/discount/raw/branch/master/src/index.ts";
import { DB } from "https://deno.land/x/sqlite/mod.ts";
import marked from "./markdown.ts";
// Snapshot of the process environment. Configuration is read from it:
// DATABASE, MATRIX_BASE_URL, MATRIX_USER_ID, MATRIX_TOKEN and (further down)
// LLM_BASE_URL.
const env = Deno.env.toObject();
// SQLite handle used to persist the conversation history.
const db = new DB(env.DATABASE);
// Matrix client the bot uses to receive events and send messages.
const client = new Client({
baseUrl: env.MATRIX_BASE_URL,
userId: env.MATRIX_USER_ID,
token: env.MATRIX_TOKEN,
});
// One row per chat message:
// - event_id: id of the Matrix event the message belongs to (NULL when the
//   message was never sent to a room, e.g. the system prompt)
// - parent:   id of the row this message follows; getHistory walks these links
//   backwards to rebuild a conversation
// - role / content: the chat message itself
db.execute(`
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY,
event_id TEXT,
parent INTEGER,
role TEXT NOT NULL,
content TEXT NOT NULL,
FOREIGN KEY (parent) REFERENCES messages(id)
);
`);
/** A single chat turn as stored in the database and rendered into the prompt. */
type Message = {
  /** Text of the turn. */
  content: string;
  /** Who produced the turn. */
  role: "assistant" | "system" | "user";
};
/**
 * Persist one chat message and hand it back unchanged, so calls can be used
 * inline while building a prompt.
 *
 * @param message  The chat turn to store.
 * @param eventId  Matrix event id the turn belongs to; stored as NULL when omitted.
 * @param parentId Row id of the message this one follows; stored as NULL when
 *                 omitted. The `parent` column is an INTEGER row id, so numbers
 *                 are accepted alongside strings.
 * @returns The same `message` object that was passed in.
 */
function createMessage(message: Message, eventId?: string, parentId?: number | string): Message {
  db.query(
    "INSERT INTO messages (event_id, parent, role, content) VALUES (?, ?, ?, ?) RETURNING id",
    [eventId ?? null, parentId ?? null, message.role, message.content],
  );
  return message;
}
/**
 * Rebuild the conversation that ends at the message stored for `headId`.
 *
 * Starts at the row whose `event_id` matches and follows `parent` links
 * upwards; the rows are returned ordered by id.
 *
 * @param headId Matrix event id of the last message of the conversation.
 * @returns The `role`/`content` pairs of the chain, or an empty array when no
 *          row is stored for `headId`.
 */
function getHistory(headId: string): Array<Message> {
  const rows = db.queryEntries(`
WITH RECURSIVE tree(id, event_id, parent, role, content) AS (
SELECT * FROM messages WHERE event_id = ?
UNION
SELECT messages.* FROM messages JOIN tree ON tree.parent = messages.id
) SELECT * FROM tree ORDER BY id`, [headId]);
  const history: Array<Message> = [];
  for (const row of rows) {
    // keep only the prompt-relevant columns
    history.push({ role: row.role, content: row.content });
  }
  return history;
}
// Ids of events that were already handled (see the duplicate check below).
const processed = new Set();
/**
 * Entry point for every Matrix event.
 *
 * A message that starts with the bot's id (or its localpart) opens a new
 * conversation; a reply to a message that is stored in the database continues
 * that conversation. Everything else is ignored.
 */
client.on("event", async (event) => {
  if (client.status === "starting") return; // ignore if the bot is starting
  if (event.type !== "m.room.message") return; // only handle messages
  if (event.sender.id === client.userId) return; // ignore messages from me
  // ignore edits: Matrix stores the replacement under "m.new_content"; the
  // bare "new_content" key is still checked for backward compatibility
  if (event.content["m.new_content"] || event.content.new_content) return;
  // redacted or malformed messages carry no text body
  if (typeof event.content.body !== "string") return;

  // conduit sends duplicate events sometimes
  if (processed.has(event.id)) return;
  processed.add(event.id);

  // strip the quoted reply fallback, but only when its terminating blank line
  // exists (indexOf returns -1 otherwise and slice(-1) would keep one character)
  let body = event.content.body.trim();
  if (body.startsWith("> <")) {
    const quoteEnd = body.indexOf("\n\n");
    if (quoteEnd !== -1) body = body.slice(quoteEnd).trim();
  }

  const replyId = event.content["m.relates_to"]?.["m.in_reply_to"]?.event_id;
  const MXID = env.MATRIX_USER_ID;
  const SYSTEM = "You are a helpful assistant. You may use markdown.";
  const localpart = MXID.split(":")[0];

  // localpart is a prefix of MXID, so this also covers a full-id mention
  if (body.startsWith(localpart)) {
    // handle mentions
    body = body.slice(body.startsWith(MXID) ? MXID.length : localpart.length).trim();
    const system = createMessage({ role: "system", content: SYSTEM });
    // Link the user message to the system prompt that was just inserted so the
    // prompt stays part of the history when the conversation is continued.
    // The id is passed as a string because that is createMessage's declared
    // parameter type; the INTEGER column converts it back on insert.
    const [[systemId]] = db.query("SELECT last_insert_rowid()");
    const user = createMessage({ role: "user", content: body }, event.id, String(systemId));
    createCompletion({ event, messages: [system, user] }).catch((err) => console.error(err));
  } else if (replyId) {
    // handle replies
    const history = getHistory(replyId);
    if (!history.length) return;
    // Chain the new message to the one being replied to; without the parent
    // link the history walk would stop at this message on the next reply.
    const [parentRow] = db.query("SELECT id FROM messages WHERE event_id = ?", [replyId]);
    const parentId = parentRow ? String(parentRow[0]) : undefined;
    const messages = [...history, createMessage({ role: "user", content: body }, event.id, parentId)];
    createCompletion({ event, messages }).catch((err) => console.error(err));
  }
});
/**
 * Stream a completion for a conversation from the LLM server.
 *
 * The messages are rendered as `<|im_start|>role ... <|im_end|>` turns,
 * followed by an opened assistant turn, and posted to
 * `${env.LLM_BASE_URL}/completion` with streaming enabled. The response is
 * split on blank lines and the JSON payload of every `data:` frame is yielded.
 *
 * @param messages Conversation to complete, oldest message first.
 * @yields The parsed JSON payload of each `data:` frame.
 * @throws Error when the server answers with a non-2xx status, sends no body,
 *         or sends a frame that is neither empty, a comment nor `data:`.
 */
async function *createGeneration(messages: Message[]) {
  const history = messages.map(({ role, content }) => `<|im_start|>${role}\n${content}<|im_end|>\n`).join("");
  const res = await fetch(`${env.LLM_BASE_URL}/completion`, {
    method: "POST",
    headers: { "content-type": "application/json" },
    body: JSON.stringify({
      prompt: history + "<|im_start|>assistant\n",
      stop: ["<|im_end"],
      stream: true,
    }),
  });
  if (!res.ok || !res.body) {
    // Reading the text also consumes the body, so the response is not left open.
    const detail = await res.text().catch(() => "");
    throw new Error(`LLM request failed (${res.status} ${res.statusText}) ${detail}`.trim());
  }
  const lines = res.body
    .pipeThrough(new TextDecoderStream())
    .pipeThrough(new TextDelimiterStream("\n\n"));
  for await (const line of lines) {
    const frame = line.trim();
    // skip blank frames and comment / keep-alive lines
    if (!frame || frame.startsWith(":")) continue;
    if (!frame.startsWith("data:")) {
      // e.g. an error frame: surface it instead of parsing it as a chunk
      throw new Error(`Unexpected frame from LLM server: ${frame}`);
    }
    const data = frame.slice("data:".length).trim();
    if (!data) continue;
    yield JSON.parse(data);
  }
}
/**
 * Send `content` to the room of `event`, either as a new reply or as an edit.
 *
 * @param event   The event being answered; its room and client are used for sending.
 * @param content Message content to send.
 * @param replyId When given, the message with this event id is edited
 *                (`m.replace`) instead of a new reply being sent.
 * @returns `replyId` when editing, otherwise the event id of the new message.
 */
async function sendOrEditMessage(event, content: any, replyId?: string): Promise<string> {
  const body = { ...content };
  if (replyId) {
    Object.assign(body, {
      "m.new_content": content,
      "m.relates_to": {
        "event_id": replyId,
        "rel_type": "m.replace",
      },
    });
  } else {
    const relation = event.content["m.relates_to"];
    // stay in the thread the triggering event belongs to, if any
    const thread = relation?.rel_type === "m.thread" ? relation : null;
    Object.assign(body, {
      "m.relates_to": {
        ...thread,
        // Set after the spread: a threaded event carries its own
        // "m.in_reply_to", which must not replace the event being answered.
        "m.in_reply_to": { event_id: event.id },
      },
    });
  }
  const path = `/rooms/${encodeURIComponent(event.room.id)}/send/m.room.message/${Math.random()}`;
  const { event_id } = await event.client.fetcher.fetchClient(path, { method: "PUT", body });
  return replyId ?? event_id;
}
/**
 * Generate the assistant's answer for `messages` and post it to the room of
 * `event`.
 *
 * While the generation streams in, a "[WIP]" message is created and then
 * edited at most every five seconds. The stop chunk produces the final edit
 * and stores the answer as a child of the triggering message. Any failure is
 * logged and reported to the room as a notice; this function does not reject.
 *
 * @param event    The event that triggered the completion.
 * @param messages Prompt history, oldest message first.
 * @param replyId  Optional id of an existing bot message to edit instead of
 *                 sending a new one.
 */
async function createCompletion({ event, messages, replyId }) {
  try {
    // Inside the try block so the `finally` below always balances this increment.
    await updateTyping(event.room, 1);
    console.log("input", messages);
    const stream = createGeneration(messages);
    let lastUpdate = Date.now();
    let buffer = "";
    for await (const chunk of stream) {
      // a chunk without text must not append the string "undefined"
      buffer += chunk.content ?? "";
      if (chunk.stop || Date.now() - lastUpdate > 5 * 1000) {
        // create or update the result message
        const body = `${chunk.stop ? "" : "[WIP] "}${buffer.trim()}`;
        const content = {
          msgtype: "m.text",
          body,
          format: "org.matrix.custom.html",
          formatted_body: marked.parse(body, { breaks: true }),
          "_llm": {
            model: chunk.model,
            seed: chunk.seed,
            shell: "v1",
            // statistics are only attached to the final message
            timings: chunk.stop ? {
              count: chunk.tokens_predicted,
              tokensPerSecond: chunk.timings?.predicted_per_second?.toFixed(4),
            } : undefined,
          },
        };
        replyId = await sendOrEditMessage(event, content, replyId);
        lastUpdate = Date.now();
        // re-send the typing state after an intermediate update
        // (presumably because sending a message clears the indicator)
        if (!chunk.stop) await updateTyping(event.room, 0);
      }
      if (chunk.stop) {
        console.log("output", buffer);
        // update message history
        const [parentRow] = db.query("SELECT id FROM messages WHERE event_id = ?", [event.id]);
        createMessage({ role: "assistant", content: buffer }, replyId, parentRow?.[0]);
        break;
      }
    }
  } catch (err) {
    console.error(err);
    const reason = String(err);
    // the reason is plain text; escape it before embedding it in HTML
    const escaped = reason.replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
    const content = {
      msgtype: "m.notice",
      format: "org.matrix.custom.html",
      formatted_body: `<b>Error!</b> ${escaped}`,
      body: "Error! " + reason,
    };
    try {
      await sendOrEditMessage(event, content, replyId);
    } catch (sendErr) {
      // reporting failed as well; log it instead of rejecting
      console.error(sendErr);
    }
  } finally {
    try {
      await updateTyping(event.room, -1);
    } catch (typingErr) {
      console.error(typingErr);
    }
  }
}
// Number of completions currently running per room id.
const typingCounters = new Map();
/**
 * Adjust the running-completion counter of a room and publish the resulting
 * typing state.
 *
 * @param room  Room whose typing state is updated; its client is used for the request.
 * @param delta Change to apply to the counter: `1` when a completion starts,
 *              `-1` when it ends, `0` to re-send the current state.
 */
async function updateTyping(room, delta: number) {
  const { client } = room;
  const count = (typingCounters.get(room.id) ?? 0) + delta;
  typingCounters.set(room.id, count);
  console.log(typingCounters);
  // ids are percent-encoded, matching how the send endpoint path is built
  const path = `/rooms/${encodeURIComponent(room.id)}/typing/${encodeURIComponent(client.userId)}`;
  await client.fetcher.fetchClient(path, {
    method: "PUT",
    body: { timeout: 300000, typing: count > 0 },
  });
}
// Accept every room invitation the bot receives.
client.on("invite", (invite) => {
  return invite.join();
});
// Log once the client reports that it is ready.
client.on("ready", () => {
  console.log("ready!");
});
// Start the client; the handlers registered above take over from here.
await client.start();