diff --git a/backfill.go b/backfill.go index 20e254d..74f1256 100644 --- a/backfill.go +++ b/backfill.go @@ -29,6 +29,7 @@ func (portal *Portal) forwardBackfillInitial(source *User, thread *Thread) { limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.DM if thread != nil { limit = portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread + thread.initialBackfillAttempted = true } } if limit == 0 { @@ -225,16 +226,9 @@ func (portal *Portal) forwardBatchSend(log zerolog.Logger, source *User, message for i, evtID := range resp.EventIDs { dbMessages[i].MXID = evtID if metas[i] != nil && metas[i].Flags == discordgo.MessageFlagsHasThread { - thread = portal.bridge.GetThreadByID(metas[i].ID, &dbMessages[i]) - log.Debug(). - Str("message_id", metas[i].ID). - Str("event_id", evtID.String()). - Msg("Marked backfilled message as thread root") - if thread.CreationNoticeMXID == "" { - // TODO proper context - ctx := log.WithContext(context.Background()) - portal.sendThreadCreationNotice(ctx, thread) - } + // TODO proper context + ctx := log.WithContext(context.Background()) + portal.bridge.threadFound(ctx, source, &dbMessages[i], metas[i].ID, metas[i].Thread) } } portal.bridge.DB.Message.MassInsert(portal.Key, dbMessages) diff --git a/go.mod b/go.mod index 436a0b8..5ccf1ea 100644 --- a/go.mod +++ b/go.mod @@ -38,4 +38,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0 +replace github.com/bwmarrin/discordgo => github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596 diff --git a/go.sum b/go.sum index 8458c70..61addee 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,6 @@ github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= -github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0 h1:ECBEbC4ruaXzcVJJ4UurkGpT/Xlm9ZnwsHiHn9gjPZw= -github.com/beeper/discordgo v0.0.0-20230512133900-5b12693331c0/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA= +github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596 h1:PxtbetWbVi2OlACDNtx6YJahhXt/rhiEsGqtOOLSx4o= +github.com/beeper/discordgo v0.0.0-20230618183737-3c7afd8d8596/go.mod h1:59+AOzzjmL6onAh62nuLXmn7dJCaC/owDLWbGtjTcFA= github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/portal.go b/portal.go index f3dd341..6dcce43 100644 --- a/portal.go +++ b/portal.go @@ -683,11 +683,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess } else { firstDBMessage := portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts) if msg.Flags == discordgo.MessageFlagsHasThread { - thread = portal.bridge.GetThreadByID(msg.ID, firstDBMessage) - log.Debug().Msg("Marked message as thread root") - if thread.CreationNoticeMXID == "" { - portal.sendThreadCreationNotice(ctx, thread) - } + portal.bridge.threadFound(ctx, user, firstDBMessage, msg.ID, msg.Thread) } } } @@ -817,11 +813,7 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess } if msg.Flags == discordgo.MessageFlagsHasThread { - thread := portal.bridge.GetThreadByID(msg.ID, existing[0]) - log.Debug().Msg("Marked message as thread root") - if thread.CreationNoticeMXID == "" { - portal.sendThreadCreationNotice(ctx, thread) - } + portal.bridge.threadFound(ctx, user, existing[0], msg.ID, msg.Thread) } if msg.Author == nil { @@ -1476,6 +1468,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { existingThread := portal.bridge.GetThreadByRootMXID(threadRoot) if existingThread != nil { threadID = existingThread.ID + existingThread.initialBackfillAttempted = true } else { if isWebhookSend { // TODO start thread with bot? diff --git a/thread.go b/thread.go index 843beb5..5de2410 100644 --- a/thread.go +++ b/thread.go @@ -1,10 +1,13 @@ package main import ( + "context" "sync" "time" "github.com/bwmarrin/discordgo" + "github.com/rs/zerolog" + "golang.org/x/exp/slices" "maunium.net/go/mautrix/id" "go.mau.fi/mautrix-discord/database" @@ -14,7 +17,8 @@ type Thread struct { *database.Thread Parent *Portal - creationNoticeLock sync.Mutex + creationNoticeLock sync.Mutex + initialBackfillAttempted bool } func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread { @@ -74,12 +78,63 @@ func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root * return thread } +func (br *DiscordBridge) threadFound(ctx context.Context, source *User, rootMessage *database.Message, id string, metadata *discordgo.Channel) { + thread := br.GetThreadByID(id, rootMessage) + log := zerolog.Ctx(ctx) + log.Debug().Msg("Marked message as thread root") + if thread.CreationNoticeMXID == "" { + thread.Parent.sendThreadCreationNotice(ctx, thread) + } + // TODO member_ids_preview is probably not guaranteed to contain the source user + if source != nil && metadata != nil && slices.Contains(metadata.MemberIDsPreview, source.DiscordID) && !source.IsInPortal(thread.ID) { + source.MarkInPortal(database.UserPortal{ + DiscordID: thread.ID, + Type: database.UserPortalTypeThread, + Timestamp: time.Now(), + }) + if metadata.MessageCount > 0 { + go thread.maybeInitialBackfill(source) + } else { + thread.initialBackfillAttempted = true + } + } +} + +func (thread *Thread) maybeInitialBackfill(source *User) { + if thread.initialBackfillAttempted || thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread == 0 { + return + } + thread.Parent.forwardBackfillLock.Lock() + if thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) != nil { + thread.Parent.forwardBackfillLock.Unlock() + return + } + thread.Parent.forwardBackfillInitial(source, thread) +} + func (thread *Thread) Join(user *User) { if user.IsInPortal(thread.ID) { return } log := user.log.With().Str("thread_id", thread.ID).Str("channel_id", thread.ParentID).Logger() log.Debug().Msg("Joining thread") + + var doBackfill, backfillStarted bool + if !thread.initialBackfillAttempted && thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread > 0 { + thread.Parent.forwardBackfillLock.Lock() + lastMessage := thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID) + if lastMessage != nil { + thread.Parent.forwardBackfillLock.Unlock() + } else { + doBackfill = true + defer func() { + if !backfillStarted { + thread.Parent.forwardBackfillLock.Unlock() + } + }() + } + } + var err error if user.Session.IsUser { err = user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu) @@ -94,5 +149,9 @@ func (thread *Thread) Join(user *User) { Type: database.UserPortalTypeThread, Timestamp: time.Now(), }) + if doBackfill { + go thread.Parent.forwardBackfillInitial(user, thread) + backfillStarted = true + } } } diff --git a/user.go b/user.go index dd75978..7e68b10 100644 --- a/user.go +++ b/user.go @@ -1,6 +1,7 @@ package main import ( + "context" "errors" "fmt" "math/rand" @@ -650,6 +651,8 @@ func (user *User) eventHandler(rawEvt any) { user.typingStartHandler(evt) case *discordgo.InteractionSuccess: user.interactionSuccessHandler(evt) + case *discordgo.ThreadListSync: + user.threadListSyncHandler(evt) case *discordgo.Event: // Ignore default: @@ -1038,6 +1041,30 @@ func (user *User) guildUpdateHandler(g *discordgo.GuildUpdate) { user.handleGuild(g.Guild, time.Now(), user.IsInSpace(g.ID)) } +func (user *User) threadListSyncHandler(t *discordgo.ThreadListSync) { + for _, meta := range t.Threads { + log := user.log.With(). + Str("action", "thread list sync"). + Str("guild_id", t.GuildID). + Str("parent_id", meta.ParentID). + Str("thread_id", meta.ID). + Logger() + ctx := log.WithContext(context.Background()) + thread := user.bridge.GetThreadByID(meta.ID, nil) + if thread == nil { + msg := user.bridge.DB.Message.GetByDiscordID(database.NewPortalKey(meta.ParentID, ""), meta.ID) + if len(msg) == 0 { + log.Debug().Msg("Found unknown thread in thread list sync and don't have message") + } else { + log.Debug().Msg("Found unknown thread in thread list sync for existing message, creating thread") + user.bridge.threadFound(ctx, user, msg[0], meta.ID, meta) + } + } else { + thread.Parent.ForwardBackfillMissed(user, meta.LastMessageID, thread) + } + } +} + func (user *User) channelCreateHandler(c *discordgo.ChannelCreate) { if user.getGuildBridgingMode(c.GuildID) < database.GuildBridgeEverything { user.log.Debug().