// mautrix-discord/backfill.go
// 2023-06-29 15:19:52 +03:00
//
// 380 lines
// 12 KiB
// Go

package main
import (
"context"
"crypto/sha256"
"encoding/base64"
"fmt"
"sort"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database"
)
// forwardBackfillInitial backfills the most recent Discord messages into the
// portal room, or into the given thread when thread is non-nil.
//
// The caller must already hold portal.forwardBackfillLock; this method
// releases it on return and panics if the lock turns out not to be held.
//
// The message limit comes from the "initial" backfill limits in the bridge
// config: Thread when a thread is given, otherwise DM for portals without a
// guild ID and Channel for everything else. A limit of 0 disables the
// backfill. Negative limits are skipped as well, because the limited backfill
// path cannot handle them.
func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) {
	log := portal.log
	// The closure reads log when it runs, so it uses the enriched logger if
	// one was assigned further down.
	defer func() {
		log.Debug().Msg("Forward backfill finished, unlocking lock")
		portal.forwardBackfillLock.Unlock()
	}()
	// This should only be called from CreateMatrixRoom which locks forwardBackfillLock before creating the room.
	if portal.forwardBackfillLock.TryLock() {
		panic("forwardBackfillInitial() called without locking forwardBackfillLock")
	}

	limits := portal.bridge.Config.Bridge.Backfill.Limits.Initial
	limit := limits.Channel
	if portal.GuildID == "" {
		limit = limits.DM
	}
	// The thread limit must not depend on the portal being a DM: it applies
	// whenever a thread is being backfilled.
	if thread != nil {
		limit = limits.Thread
		thread.initialBackfillAttempted = true
	}
	if limit == 0 {
		return
	} else if limit < 0 {
		log.Warn().Int("limit", limit).Msg("Negative limits are not supported for initial backfill, skipping")
		return
	}

	with := log.With().
		Str("action", "initial backfill").
		Str("room_id", portal.MXID.String()).
		Int("limit", limit)
	if thread != nil {
		with = with.Str("thread_id", thread.ID)
	}
	log = with.Logger()

	portal.backfillLimited(log, source, limit, "", thread)
}
// ForwardBackfillMissed backfills messages that exist on Discord but are newer
// than the last message stored in the bridge database for this portal (or for
// the given thread when thread is non-nil).
//
// serverLastMessageID is the ID of the newest message according to Discord.
// Nothing is backfilled when the portal has no Matrix room, when the
// configured limit is 0, when the database has no previous message, when
// serverLastMessageID is empty, or when the last stored message is not older
// than serverLastMessageID.
//
// The limit comes from the "missed" backfill limits in the bridge config:
// Thread when a thread is given, otherwise DM for portals without a guild ID
// and Channel for everything else. A negative limit selects the unlimited,
// chunk-by-chunk backfill. portal.forwardBackfillLock is held for the duration
// of the database lookup and the backfill.
func (portal *Portal) ForwardBackfillMissed(source *User, serverLastMessageID string, thread *Thread) {
	if portal.MXID == "" {
		return
	}

	limits := portal.bridge.Config.Bridge.Backfill.Limits.Missed
	limit := limits.Channel
	if portal.GuildID == "" {
		limit = limits.DM
	}
	// The thread limit must not depend on the portal being a DM: it applies
	// whenever a thread is being backfilled.
	if thread != nil {
		limit = limits.Thread
	}
	if limit == 0 {
		return
	}

	with := portal.log.With().
		Str("action", "missed event backfill").
		Str("room_id", portal.MXID.String()).
		Int("limit", limit)
	if thread != nil {
		with = with.Str("thread_id", thread.ID)
	}
	log := with.Logger()

	portal.forwardBackfillLock.Lock()
	defer portal.forwardBackfillLock.Unlock()

	var lastMessage *database.Message
	if thread != nil {
		lastMessage = portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
	} else {
		lastMessage = portal.bridge.DB.Message.GetLast(portal.Key)
	}
	if lastMessage == nil || serverLastMessageID == "" {
		log.Debug().Msg("Not backfilling, no last message in database or no last message in metadata")
		return
	} else if !shouldBackfill(lastMessage.DiscordID, serverLastMessageID) {
		log.Debug().
			Str("last_bridged_message", lastMessage.DiscordID).
			Str("last_server_message", serverLastMessageID).
			Msg("Not backfilling, last message in database is newer than last message in metadata")
		return
	}
	log.Debug().
		Str("last_bridged_message", lastMessage.DiscordID).
		Str("last_server_message", serverLastMessageID).
		Msg("Backfilling missed messages")
	if limit < 0 {
		portal.backfillUnlimitedMissed(log, source, lastMessage.DiscordID, thread)
	} else {
		portal.backfillLimited(log, source, limit, lastMessage.DiscordID, thread)
	}
}
// messageFetchChunkSize is the number of messages requested from Discord per
// ChannelMessages call while backfilling.
const messageFetchChunkSize = 50

// collectBackfillMessages fetches messages from the portal's channel (or from
// the thread, when thread is non-nil), paging backwards with the "before"
// cursor until one of the following happens:
//   - a message whose ID is not newer than until is seen (only when until is
//     non-empty); that message and everything after it in the chunk is dropped,
//   - a chunk comes back with fewer than messageFetchChunkSize messages,
//   - at least limit messages have been collected.
//
// The result is truncated to at most limit messages, in the order the chunks
// were returned by ChannelMessages. The boolean result is true only when the
// until message was reached and no truncation to limit was necessary.
//
// A negative limit is rejected with an error: it would otherwise stop after
// the first chunk and then panic on the final truncation.
func (portal *Portal) collectBackfillMessages(log zerolog.Logger, source *User, limit int, until string, thread *Thread) ([]*discordgo.Message, bool, error) {
	if limit < 0 {
		return nil, false, fmt.Errorf("invalid backfill limit %d", limit)
	}

	var messages []*discordgo.Message
	var before string
	var foundAll bool
	protoChannelID := portal.Key.ChannelID
	if thread != nil {
		protoChannelID = thread.ID
	}
	for {
		log.Debug().Str("before_id", before).Msg("Fetching messages for backfill")
		newMessages, err := source.Session.ChannelMessages(protoChannelID, messageFetchChunkSize, before, "", "")
		if err != nil {
			return nil, false, err
		}
		if until != "" {
			for i, msg := range newMessages {
				if compareMessageIDs(msg.ID, until) <= 0 {
					log.Debug().
						Str("message_id", msg.ID).
						Str("until_id", until).
						Msg("Found message that was already bridged")
					// Keep only the messages that precede the already-bridged one in this chunk.
					newMessages = newMessages[:i]
					foundAll = true
					break
				}
			}
		}
		messages = append(messages, newMessages...)
		log.Debug().Int("count", len(newMessages)).Msg("Added messages to backfill collection")
		// A short chunk (including one cut short above) means there is
		// nothing more to fetch, so newMessages is a full, non-empty chunk
		// whenever the cursor below is updated.
		if len(newMessages) < messageFetchChunkSize || len(messages) >= limit {
			break
		}
		before = newMessages[len(newMessages)-1].ID
	}
	if len(messages) > limit {
		// Messages were dropped to honour the limit, so the collection is incomplete.
		foundAll = false
		messages = messages[:limit]
	}
	return messages, foundAll, nil
}
// backfillLimited collects up to limit messages newer than after (or simply
// the latest messages when after is empty), sorts them by ascending message
// ID and sends them to Matrix as one batch.
//
// When after is set but the collection did not reach it, a notice is posted
// to the room first to warn that some messages may be missing.
func (portal *Portal) backfillLimited(log zerolog.Logger, source *User, limit int, after string, thread *Thread) {
	collected, complete, err := portal.collectBackfillMessages(log, source, limit, after, thread)
	if err != nil {
		log.Err(err).Msg("Error collecting messages to forward backfill")
		return
	}
	log.Info().
		Int("count", len(collected)).
		Bool("found_all", complete).
		Msg("Collected messages to backfill")
	sort.Sort(MessageSlice(collected))

	if gapPossible := after != "" && !complete; gapPossible {
		notice := &event.MessageEventContent{
			MsgType: event.MsgNotice,
			Body:    "Some messages may have been missed here while the bridge was offline.",
		}
		if _, sendErr := portal.sendMatrixMessage(portal.MainIntent(), event.EventMessage, notice, nil, 0); sendErr != nil {
			log.Warn().Err(sendErr).Msg("Failed to send missed message warning")
		} else {
			log.Debug().Msg("Sent warning about possibly missed messages")
		}
	}

	portal.sendBackfillBatch(log, source, collected, thread)
}
// backfillUnlimitedMissed backfills every message newer than after, one chunk
// at a time. Each chunk is sorted by ascending message ID and sent to Matrix
// before the next one is fetched. The loop ends when a fetch fails or when a
// chunk contains fewer than messageFetchChunkSize messages.
func (portal *Portal) backfillUnlimitedMissed(log zerolog.Logger, source *User, after string, thread *Thread) {
	channelID := portal.Key.ChannelID
	if thread != nil {
		channelID = thread.ID
	}
	for cursor := after; ; {
		log.Debug().Str("after_id", cursor).Msg("Fetching chunk of messages to backfill")
		chunk, err := source.Session.ChannelMessages(channelID, messageFetchChunkSize, "", cursor, "")
		if err != nil {
			log.Err(err).Msg("Error fetching chunk of messages to forward backfill")
			return
		}
		log.Debug().Int("count", len(chunk)).Msg("Fetched chunk of messages to backfill")
		sort.Sort(MessageSlice(chunk))
		portal.sendBackfillBatch(log, source, chunk, thread)
		if len(chunk) >= messageFetchChunkSize {
			// Full chunk: continue from the newest message of the sorted chunk.
			cursor = chunk[len(chunk)-1].ID
			continue
		}
		// Assume that was all the missing messages
		log.Debug().Msg("Chunk had less than 50 messages, stopping backfill")
		return
	}
}
// sendBackfillBatch delivers the given messages to Matrix. When the homeserver
// advertises the Beeper batch sending feature, they go out in a single batch
// request; otherwise each message is passed through the regular
// message-create handler in slice order.
func (portal *Portal) sendBackfillBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
	if !portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureBatchSending) {
		log.Debug().Msg("Not using hungryserv, sending messages one by one")
		for idx := range messages {
			portal.handleDiscordMessageCreate(source, messages[idx], thread)
		}
		return
	}
	log.Debug().Msg("Using hungryserv, sending messages with batch send endpoint")
	portal.forwardBatchSend(log, source, messages, thread)
}
// forwardBatchSend converts the given Discord messages to Matrix events, sends
// them with the Beeper batch send endpoint as a forward batch, and stores the
// resulting message rows in the database.
//
// The event IDs in the response are matched by index to the converted
// database rows. For every message flagged as having a thread,
// bridge.threadFound is invoked with the corresponding row.
func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) {
	evts, metas, dbMessages := portal.convertMessageBatch(log, source, messages, thread)
	if len(evts) == 0 {
		log.Warn().Msg("Didn't get any events to backfill")
		return
	}
	log.Info().Int("events", len(evts)).Msg("Converted messages to backfill")
	resp, err := portal.MainIntent().BeeperBatchSend(portal.MXID, &mautrix.ReqBeeperBatchSend{
		Forward: true,
		Events:  evts,
	})
	if err != nil {
		log.Err(err).Msg("Error sending backfill batch")
		return
	}
	if len(resp.EventIDs) != len(dbMessages) {
		log.Warn().
			Int("expected", len(dbMessages)).
			Int("received", len(resp.EventIDs)).
			Msg("Batch send returned an unexpected number of event IDs")
	}
	for i, evtID := range resp.EventIDs {
		if i >= len(dbMessages) {
			// Never index past the rows that were actually sent.
			break
		}
		dbMessages[i].MXID = evtID
		// Flags is a bitfield, so test the bit instead of comparing for
		// equality: other flags may be set on the same message.
		if meta := metas[i]; meta != nil && meta.Flags&discordgo.MessageFlagsHasThread != 0 {
			// TODO proper context
			ctx := log.WithContext(context.Background())
			portal.bridge.threadFound(ctx, source, &dbMessages[i], meta.ID, meta.Thread)
		}
	}
	portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages)
}
// convertMessageBatch converts Discord messages into Matrix events suitable
// for a batch send, without sending anything.
//
// It returns three slices that are appended to together and therefore share
// indexes, with one entry per successfully converted message part:
//   - the Matrix events, with deterministic event IDs,
//   - the source Discord message for the first part of each message and nil
//     for any further parts,
//   - the database rows for the parts (the MXID field is not set here).
//
// Parts that fail to encrypt are logged and left out of all three slices.
// When thread is non-nil, every event is placed in that thread.
func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, messages []*discordgo.Message, thread *Thread) ([]*event.Event, []*discordgo.Message, []database.Message) {
var discordThreadID string
var threadRootEvent, lastThreadEvent id.EventID
if thread != nil {
discordThreadID = thread.ID
threadRootEvent = thread.RootMXID
// Start the thread chain at the root, or at the last message already
// stored for this thread if there is one.
lastThreadEvent = threadRootEvent
lastInThread := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID)
if lastInThread != nil {
lastThreadEvent = lastInThread.MXID
}
}
evts := make([]*event.Event, 0, len(messages))
dbMessages := make([]database.Message, 0, len(messages))
metas := make([]*discordgo.Message, 0, len(messages))
ctx := context.Background()
for _, msg := range messages {
// Update the puppets of all mentioned users before converting the message.
for _, mention := range msg.Mentions {
puppet := portal.bridge.GetPuppetByID(mention.ID)
puppet.UpdateInfo(nil, mention, nil)
}
// NOTE(review): msg.Author is dereferenced without a nil check — confirm it is always set for fetched messages.
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(source, msg.Author, msg)
intent := puppet.IntentFor(portal)
replyTo := portal.getReplyTarget(source, discordThreadID, msg.MessageReference, msg.Embeds, true)
mentions := portal.convertDiscordMentions(msg, false)
// NOTE(review): the error from SnowflakeTimestamp is discarded here.
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
// Per-message logger; shadows the function's log parameter for the rest of this iteration.
log := log.With().
Str("message_id", msg.ID).
Int("message_type", int(msg.Type)).
Str("author_id", msg.Author.ID).
Logger()
parts := portal.convertDiscordMessage(log.WithContext(ctx), puppet, intent, msg)
for i, part := range parts {
// Make sure RelatesTo exists before the thread/reply setters below use it.
if (replyTo != nil || threadRootEvent != "") && part.Content.RelatesTo == nil {
part.Content.RelatesTo = &event.RelatesTo{}
}
if threadRootEvent != "" {
part.Content.RelatesTo.SetThread(threadRootEvent, lastThreadEvent)
}
if replyTo != nil {
part.Content.RelatesTo.SetReplyTo(replyTo.EventID)
// Only set reply for first event
replyTo = nil
}
part.Content.Mentions = mentions
// Only set mentions for first event, but keep empty object for rest
mentions = &event.Mentions{}
partName := part.AttachmentID
// Always use blank part name for first part so that replies and other things
// can reference it without knowing about attachments.
if i == 0 {
partName = ""
}
evt := &event.Event{
ID: portal.deterministicEventID(msg.ID, partName),
Type: part.Type,
Sender: intent.UserID,
Timestamp: ts.UnixMilli(),
Content: event.Content{
Parsed: part.Content,
Raw: part.Extra,
},
}
var err error
evt.Type, err = portal.encrypt(intent, &evt.Content, evt.Type)
if err != nil {
// Skip this part entirely. If it is the first part (i == 0), no entry in
// metas will reference the Discord message for this message.
log.Err(err).Msg("Failed to encrypt event")
continue
}
intent.AddDoublePuppetValue(&evt.Content)
evts = append(evts, evt)
dbMessages = append(dbMessages, database.Message{
Channel: portal.Key,
DiscordID: msg.ID,
SenderID: msg.Author.ID,
Timestamp: ts,
AttachmentID: part.AttachmentID,
SenderMXID: intent.UserID,
})
// Record the Discord message only once per message, on its first part.
if i == 0 {
metas = append(metas, msg)
} else {
metas = append(metas, nil)
}
// The next event in the thread chains onto this one.
lastThreadEvent = evt.ID
}
}
return evts, metas, dbMessages
}
// deterministicEventID derives a stable Matrix event ID for one part of a
// Discord message. The ID is the unpadded URL-safe base64 encoding of the
// SHA-256 hash of "<room ID>/discord/<message ID>/<part name>", wrapped as
// "$<hash>:discord.com", so the same inputs always give the same event ID.
func (portal *Portal) deterministicEventID(messageID, partName string) id.EventID {
	seed := fmt.Sprintf("%s/discord/%s/%s", portal.MXID, messageID, partName)
	digest := sha256.Sum256([]byte(seed))
	encoded := base64.RawURLEncoding.EncodeToString(digest[:])
	return id.EventID("$" + encoded + ":discord.com")
}
// compareMessageIDs orders two Discord message IDs given as strings.
//
// It returns -1 when id1 sorts before id2, 1 when id2 sorts before id1 and 0
// when both are identical. A shorter ID always sorts before a longer one, and
// IDs of equal length are compared as plain strings. For decimal numbers
// without leading zeros this matches numeric order.
func compareMessageIDs(id1, id2 string) int {
	switch {
	case id1 == id2:
		return 0
	case len(id1) != len(id2):
		if len(id1) < len(id2) {
			return -1
		}
		return 1
	case id1 < id2:
		return -1
	default:
		return 1
	}
}
// shouldBackfill reports whether the newest bridged message ID sorts strictly
// before the newest message ID reported by the server, i.e. whether there is
// at least one message left to backfill.
func shouldBackfill(latestBridgedIDStr, latestIDFromServerStr string) bool {
	order := compareMessageIDs(latestBridgedIDStr, latestIDFromServerStr)
	return order < 0
}
// MessageSlice is a slice of Discord messages that implements sort.Interface,
// ordering the messages by ascending message ID as defined by
// compareMessageIDs.
type MessageSlice []*discordgo.Message

// Compile-time check that MessageSlice satisfies sort.Interface.
var _ sort.Interface = (MessageSlice)(nil)

// Len returns the number of messages in the slice.
func (ms MessageSlice) Len() int { return len(ms) }

// Swap exchanges the messages at indexes i and j.
func (ms MessageSlice) Swap(i, j int) {
	first, second := ms[i], ms[j]
	ms[i], ms[j] = second, first
}

// Less reports whether the message at index i has a lower ID than the one at
// index j.
func (ms MessageSlice) Less(i, j int) bool {
	return compareMessageIDs(ms[i].ID, ms[j].ID) < 0
}