199 lines
6.6 KiB
TypeScript
199 lines
6.6 KiB
TypeScript
import { TextDelimiterStream } from "https://deno.land/std@0.197.0/streams/mod.ts";
|
|
import { Client } from "https://git.celery.eu.org/tezlm/discount/raw/branch/master/src/index.ts";
|
|
import { DB } from "https://deno.land/x/sqlite/mod.ts";
|
|
import marked from "./markdown.ts";
|
|
|
|
// Configuration is read from the process environment, captured once at startup.
const env = Deno.env.toObject();

// SQLite database opened from the DATABASE setting.
const db = new DB(env.DATABASE);

// Matrix client built from the MATRIX_* settings.
const matrixOptions = {
  baseUrl: env.MATRIX_BASE_URL,
  userId: env.MATRIX_USER_ID,
  token: env.MATRIX_TOKEN,
};
const client = new Client(matrixOptions);
|
|
|
|
// Conversation store. Each row is one chat message; `parent` references the
// row it follows, so a conversation is the chain reached by following `parent`.
// `event_id` holds the Matrix event id a row was stored with (NULL when none
// was given). Rows are looked up by `event_id` elsewhere in this file, so the
// column gets an index instead of forcing a full table scan per lookup.
db.execute(`
  CREATE TABLE IF NOT EXISTS messages (
    id INTEGER PRIMARY KEY,
    event_id TEXT,
    parent INTEGER,
    role TEXT NOT NULL,
    content TEXT NOT NULL,
    FOREIGN KEY (parent) REFERENCES messages(id)
  );
  CREATE INDEX IF NOT EXISTS messages_event_id ON messages (event_id);
`);
|
|
|
|
/** Who authored a chat message; stored verbatim in the `role` column. */
type MessageRole = "system" | "user" | "assistant";

/** One chat turn: its author and its text. */
type Message = {
  role: MessageRole;
  content: string;
};
|
|
|
|
/**
 * Stores one chat message in the `messages` table and returns it unchanged,
 * so calls can be used inline while building a message list.
 *
 * @param message  Role and text to store.
 * @param eventId  Matrix event id to store with the row; NULL when omitted.
 * @param parentId Row id of the message this one follows; NULL when omitted.
 *                 The `parent` column is an INTEGER row id, so numbers are
 *                 accepted (strings remain accepted for existing callers).
 * @returns The same `message` object that was passed in.
 */
function createMessage(message: Message, eventId?: string, parentId?: number | string): Message {
  // No RETURNING clause: the generated id was never read by this function.
  db.query(
    "INSERT INTO messages (event_id, parent, role, content) VALUES (?, ?, ?, ?)",
    [eventId ?? null, parentId ?? null, message.role, message.content],
  );
  return message;
}
|
|
|
|
/**
 * Loads the conversation that ends at the message stored for a Matrix event.
 *
 * Starts at the row whose `event_id` equals `headId` and follows `parent`
 * links upwards, then returns the collected messages ordered by row id.
 *
 * @param headId Matrix event id of the newest message of the conversation.
 * @returns The messages of the chain, or an empty array when no row is
 *          stored for `headId`.
 */
function getHistory(headId: string): Array<Message> {
  // Columns are listed explicitly: `SELECT *` would stop matching the CTE's
  // column list as soon as the table gains a column.
  const rows = db.queryEntries<Message>(`
    WITH RECURSIVE tree(id, parent, role, content) AS (
      SELECT id, parent, role, content FROM messages WHERE event_id = ?
      UNION
      SELECT messages.id, messages.parent, messages.role, messages.content
      FROM messages JOIN tree ON tree.parent = messages.id
    ) SELECT role, content FROM tree ORDER BY id`, [headId]);
  return rows.map(({ role, content }) => ({ role, content }));
}
|
|
|
|
// Ids of events that were already handled (conduit sends duplicate events
// sometimes). Only the most recent ids are kept so the set stays bounded.
const processed = new Set<string>();
const MAX_PROCESSED_EVENTS = 1000;

/**
 * Removes the quoted fallback ("> <sender> ...") that precedes the text of a
 * reply. Bodies that do not start with such a quote are returned untouched.
 */
function stripReplyFallback(body: string): string {
  if (!body.startsWith("> <")) return body;

  // Usual shape: quoted lines, a blank line, then the actual reply text.
  const separator = body.indexOf("\n\n");
  if (separator !== -1) return body.slice(separator).trim();

  // No blank line: slicing at -1 would keep only the last character, so drop
  // the leading quoted lines instead.
  const lines = body.split("\n");
  const firstUnquoted = lines.findIndex((line) => !line.startsWith(">"));
  return firstUnquoted === -1 ? "" : lines.slice(firstUnquoted).join("\n").trim();
}

// Answers room messages that either start with a mention of the bot (starts a
// new conversation) or reply to a message stored in the database (continues
// that conversation).
client.on("event", async (event) => {
  if (client.status === "starting") return; // ignore if the bot is starting
  if (event.type !== "m.room.message") return; // only handle messages
  if (event.sender.id === client.userId) return; // ignore messages from me

  // ignore edits: they carry "m.new_content" and an "m.replace" relation
  const relation = event.content["m.relates_to"];
  if (event.content["m.new_content"] || relation?.rel_type === "m.replace") return;

  // nothing to answer when the event has no text body
  if (typeof event.content.body !== "string") return;

  // conduit sends duplicate events sometimes
  if (processed.has(event.id)) return;
  processed.add(event.id);
  if (processed.size > MAX_PROCESSED_EVENTS) {
    // Sets iterate in insertion order, so the first value is the oldest id.
    const oldest = processed.values().next().value;
    if (oldest !== undefined) processed.delete(oldest);
  }

  // strip reply
  let body = stripReplyFallback(event.content.body.trim());

  const replyId = relation?.["m.in_reply_to"]?.event_id;

  const MXID = env.MATRIX_USER_ID;
  const SYSTEM = "You are a helpful assistant. You may use markdown.";
  const localpart = MXID.split(":")[0];
  // A rejected completion is logged instead of becoming an unhandled rejection.
  const logFailure = (err: unknown) => console.error(err);

  if (body.startsWith(localpart) || body.startsWith(MXID)) {
    // handle mentions
    body = body.slice(body.startsWith(MXID) ? MXID.length : localpart.length).trim();
    const system = createMessage({ role: "system", content: SYSTEM });
    // Link the user message to the system message just inserted, so that
    // getHistory() reaches the system prompt on later replies.
    const systemRowId = db.lastInsertRowId;
    const user = createMessage({ role: "user", content: body }, event.id, systemRowId);
    createCompletion({ event, messages: [system, user] }).catch(logFailure);
  } else if (replyId) {
    // handle replies
    const history = getHistory(replyId);
    if (!history.length) return;
    // Link the new message to the one it replies to; without the parent the
    // stored chain would be cut after every exchange.
    const [parentRow] = db.query<[number]>("SELECT id FROM messages WHERE event_id = ?", [replyId]);
    const user = createMessage({ role: "user", content: body }, event.id, parentRow?.[0]);
    createCompletion({ event, messages: [...history, user] }).catch(logFailure);
  }
});
|
|
|
|
/**
 * Requests a streamed completion from `${LLM_BASE_URL}/completion` and yields
 * every streamed chunk as a parsed JSON value.
 *
 * The prompt is the given messages, each framed as
 * `<|im_start|>role\ncontent<|im_end|>\n`, followed by an opened assistant
 * turn. The response is split on blank lines and each piece is expected to
 * carry a `data: ` prefix in front of its JSON payload.
 *
 * @param messages Conversation so far, oldest first.
 * @throws Error when the server answers with a non-2xx status or no body;
 *         JSON.parse errors propagate for malformed payloads.
 */
async function *createGeneration(messages: Array<Message>) {
  const history = messages.map(({ role, content }) => `<|im_start|>${role}\n${content}<|im_end|>\n`).join("");
  const res = await fetch(`${env.LLM_BASE_URL}/completion`, {
    method: "POST",
    headers: { "content-type": "application/json" },
    body: JSON.stringify({
      prompt: history + "<|im_start|>assistant\n",
      stop: ["<|im_end"],
      stream: true,
    }),
  });

  // Fail with a readable message instead of trying to parse an error page.
  if (!res.ok || !res.body) {
    throw new Error(`completion request failed: ${res.status} ${res.statusText}`);
  }

  const lines = res.body
    .pipeThrough(new TextDecoderStream())
    .pipeThrough(new TextDelimiterStream("\n\n"));

  for await (const line of lines) {
    const data = line.slice("data: ".length);
    if (!data) continue;
    yield JSON.parse(data);
  }
}
|
|
|
|
/**
 * Sends `content` to the room of `event`, either as a new reply to `event`
 * or as an edit of a message that was sent earlier.
 *
 * @param event   Event being answered. Typed loosely because the client
 *                library's event type is not imported in this file; this
 *                function reads `id`, `room.id`, `content` and
 *                `client.fetcher.fetchClient` from it.
 * @param content Message content to send.
 * @param replyId Id of a previously sent message. When given, that message is
 *                edited (`m.replace`) instead of a new message being sent.
 * @returns `replyId` when editing, otherwise the id of the newly sent event.
 */
// deno-lint-ignore no-explicit-any
async function sendOrEditMessage(event: any, content: Record<string, unknown>, replyId?: string): Promise<string> {
  const body: Record<string, unknown> = { ...content };
  if (replyId) {
    body["m.new_content"] = content;
    body["m.relates_to"] = {
      "event_id": replyId,
      "rel_type": "m.replace",
    };
  } else {
    const relation = event.content["m.relates_to"];
    const thread = relation?.rel_type === "m.thread" ? relation : null;
    body["m.relates_to"] = {
      // Stay in the thread of the incoming event, if any ...
      ...thread,
      // ... but always reply to that event itself. This key comes last so the
      // incoming relation's own "m.in_reply_to" cannot overwrite it.
      "m.in_reply_to": { event_id: event.id },
    };
  }
  const path = `/rooms/${encodeURIComponent(event.room.id)}/send/m.room.message/${Math.random()}`;
  const { event_id } = await event.client.fetcher.fetchClient(path, { method: "PUT", body });
  return replyId ?? event_id;
}
|
|
|
|
/** Escapes the characters that would otherwise be read as HTML markup. */
function escapeHtml(text: string): string {
  return text.replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
}

/**
 * Generates an answer for `messages` and posts it as a reply to `event`.
 *
 * While chunks stream in, the reply is sent once and then edited at most
 * every five seconds with a "[WIP] " prefix. The chunk flagged `stop`
 * produces the final edit, after which the answer is stored as an assistant
 * message whose parent is the row stored for `event`. Any error is logged
 * and reported to the room as a notice. The room's typing state is raised
 * for the duration of the call.
 *
 * @param event    Event being answered. Typed loosely because the client
 *                 library's event type is not imported in this file.
 * @param messages Prompt messages, oldest first.
 * @param replyId  Id of an already sent message to keep editing, if any.
 */
async function createCompletion(
  // deno-lint-ignore no-explicit-any
  { event, messages, replyId }: { event: any; messages: Array<Message>; replyId?: string },
): Promise<void> {
  const UPDATE_INTERVAL_MS = 5 * 1000;

  try {
    // Inside the try block so the decrement in `finally` always balances it,
    // even when this very request fails.
    await updateTyping(event.room, 1);

    console.log("input", messages);
    const stream = createGeneration(messages);

    let lastUpdate = Date.now();
    let buffer = "";

    for await (const chunk of stream) {
      buffer += chunk.content;
      const done = Boolean(chunk.stop);

      if (done || Date.now() - lastUpdate > UPDATE_INTERVAL_MS) {
        // create or update the result message
        const body = `${done ? "" : "[WIP] "}${buffer.trim()}`;
        const content = {
          msgtype: "m.text",
          body,
          format: "org.matrix.custom.html",
          formatted_body: marked.parse(body, { breaks: true }),
          "_llm": {
            model: chunk.model,
            seed: chunk.seed,
            shell: "v1",
            timings: !done ? undefined : {
              count: chunk.tokens_predicted,
              // Optional chaining: missing timings must not discard the answer.
              tokensPerSecond: chunk.timings?.predicted_per_second?.toFixed(4),
            },
          },
        };

        replyId = await sendOrEditMessage(event, content, replyId);
        lastUpdate = Date.now();
        // delta 0 re-sends the current typing state
        if (!done) await updateTyping(event.room, 0);
      }

      if (done) {
        console.log("output", buffer);

        // update message history
        const [parentRow] = db.query<[number]>("SELECT id FROM messages WHERE event_id = ?", [event.id]);
        // A missing parent row stores the answer without a parent rather than
        // throwing after the reply has already been delivered.
        createMessage({ role: "assistant", content: buffer }, replyId, parentRow?.[0]);
        break;
      }
    }
  } catch (err) {
    console.error(err);
    const reason = String(err);
    const content = {
      msgtype: "m.notice",
      format: "org.matrix.custom.html",
      // The error text is not trusted markup, so escape it for the HTML body.
      formatted_body: `<b>Error!</b> ${escapeHtml(reason)}`,
      body: "Error! " + reason,
    };
    await sendOrEditMessage(event, content, replyId);
  } finally {
    await updateTyping(event.room, -1);
  }
}
|
|
|
|
// Number of generations currently running per room id. Rooms whose count is
// zero are not kept in the map.
const typingCounters = new Map<string, number>();

/**
 * Adjusts the number of running generations for a room and pushes the
 * resulting typing state to the server.
 *
 * @param room  Room to update. Typed loosely because the client library's
 *              room type is not imported in this file; `id` and `client`
 *              are read from it.
 * @param delta +1 when a generation starts, -1 when it ends, 0 to re-send
 *              the current state.
 */
// deno-lint-ignore no-explicit-any
async function updateTyping(room: any, delta: number): Promise<void> {
  const { client } = room;

  // Never drop below zero: a negative count would swallow the next start.
  const count = Math.max(0, (typingCounters.get(room.id) ?? 0) + delta);
  if (count > 0) {
    typingCounters.set(room.id, count);
  } else {
    typingCounters.delete(room.id);
  }

  // Ids are URL-encoded, as sendOrEditMessage does for the room id.
  const path = `/rooms/${encodeURIComponent(room.id)}/typing/${encodeURIComponent(client.userId)}`;
  await client.fetcher.fetchClient(path, {
    method: "PUT",
    body: { timeout: 300000, typing: count > 0 },
  });
}
|
|
|
|
// Accept every invitation; the handler hands back whatever join() returns.
client.on("invite", (invite) => {
  return invite.join();
});

// Report on the console once the client signals that it is ready.
client.on("ready", () => {
  console.log("ready!");
});

// All handlers are registered, so start the client.
await client.start();
|