mautrix-discord/thread.go

158 lines
4.4 KiB
Go
Raw Normal View History

2022-05-28 20:03:24 +00:00
package main
import (
"context"
"sync"
"time"
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
2022-05-28 20:03:24 +00:00
"maunium.net/go/mautrix/id"
"go.mau.fi/mautrix-discord/database"
)
type Thread struct {
*database.Thread
Parent *Portal
creationNoticeLock sync.Mutex
initialBackfillAttempted bool
2022-05-28 20:03:24 +00:00
}
// GetThreadByID returns the thread with the given Discord ID. The in-memory
// cache is consulted first; on a miss the database is queried and the result
// is handed to loadThread, which creates and stores a new thread from root
// when nothing was found and root is non-nil. The result is nil when the
// thread is unknown and root is nil.
func (br *DiscordBridge) GetThreadByID(id string, root *database.Message) *Thread {
	br.threadsLock.Lock()
	defer br.threadsLock.Unlock()

	if cached, found := br.threadsByID[id]; found {
		return cached
	}
	stored := br.DB.Thread.GetByDiscordID(id)
	return br.loadThread(stored, id, root)
}
// GetThreadByRootMXID returns the thread whose root message has the given
// Matrix event ID. The in-memory cache is checked before the database. This
// lookup never creates a thread: loadThread is called with a nil root, so an
// unknown event ID yields nil.
func (br *DiscordBridge) GetThreadByRootMXID(mxid id.EventID) *Thread {
	br.threadsLock.Lock()
	defer br.threadsLock.Unlock()

	if cached, found := br.threadsByRootMXID[mxid]; found {
		return cached
	}
	stored := br.DB.Thread.GetByMatrixRootMsg(mxid)
	return br.loadThread(stored, "", nil)
}
// GetThreadByRootOrCreationNoticeMXID returns the thread for which the given
// Matrix event ID is either the root message or the creation notice. Both
// in-memory caches are tried (root first, then creation notice) before the
// database. This lookup never creates a thread: loadThread is called with a
// nil root, so an unknown event ID yields nil.
func (br *DiscordBridge) GetThreadByRootOrCreationNoticeMXID(mxid id.EventID) *Thread {
	br.threadsLock.Lock()
	defer br.threadsLock.Unlock()

	if byRoot, found := br.threadsByRootMXID[mxid]; found {
		return byRoot
	}
	if byNotice, found := br.threadsByCreationNoticeMXID[mxid]; found {
		return byNotice
	}
	stored := br.DB.Thread.GetByMatrixRootOrCreationNoticeMsg(mxid)
	return br.loadThread(stored, "", nil)
}
2022-05-28 20:03:24 +00:00
func (br *DiscordBridge) loadThread(dbThread *database.Thread, id string, root *database.Message) *Thread {
if dbThread == nil {
if root == nil {
return nil
}
dbThread = br.DB.Thread.New()
dbThread.ID = id
dbThread.RootDiscordID = root.DiscordID
dbThread.RootMXID = root.MXID
dbThread.ParentID = root.Channel.ChannelID
dbThread.Insert()
}
thread := &Thread{
Thread: dbThread,
}
thread.Parent = br.GetExistingPortalByID(database.NewPortalKey(thread.ParentID, ""))
br.threadsByID[thread.ID] = thread
br.threadsByRootMXID[thread.RootMXID] = thread
if thread.CreationNoticeMXID != "" {
br.threadsByCreationNoticeMXID[thread.CreationNoticeMXID] = thread
}
2022-05-28 20:03:24 +00:00
return thread
}
// threadFound records that the thread with the given Discord ID exists,
// rooted at rootMessage, and performs the follow-up work:
//
//   - sends the thread creation notice through the parent portal if the
//     thread has none yet;
//   - if source is listed in metadata's member ID preview and is not yet
//     marked as being in the thread, marks it and either starts the initial
//     backfill in a new goroutine (when the thread has messages) or flags the
//     backfill as already attempted (when it has none).
//
// source and metadata may be nil, in which case the membership handling is
// skipped. If the thread is unknown and rootMessage is nil, the thread cannot
// be created; a warning is logged and nothing else happens.
func (br *DiscordBridge) threadFound(ctx context.Context, source *User, rootMessage *database.Message, id string, metadata *discordgo.Channel) {
	log := zerolog.Ctx(ctx)
	thread := br.GetThreadByID(id, rootMessage)
	if thread == nil {
		// GetThreadByID returns nil when the thread is not stored and there
		// is no root message to create it from; bail out instead of
		// dereferencing a nil thread below.
		log.Warn().Str("thread_id", id).Msg("Thread not found and no root message available to create it")
		return
	}
	log.Debug().Msg("Marked message as thread root")

	// NOTE(review): thread.Parent is dereferenced here and during backfill;
	// this presumably relies on the parent portal always existing — confirm.
	if thread.CreationNoticeMXID == "" {
		thread.Parent.sendThreadCreationNotice(ctx, thread)
	}

	if source == nil || metadata == nil {
		return
	}
	// TODO member_ids_preview is probably not guaranteed to contain the source user
	if !slices.Contains(metadata.MemberIDsPreview, source.DiscordID) || source.IsInPortal(thread.ID) {
		return
	}

	source.MarkInPortal(database.UserPortal{
		DiscordID: thread.ID,
		Type:      database.UserPortalTypeThread,
		Timestamp: time.Now(),
	})
	if metadata.MessageCount > 0 {
		go thread.maybeInitialBackfill(source)
	} else {
		// Nothing to backfill in an empty thread; don't try again later.
		thread.initialBackfillAttempted = true
	}
}
// maybeInitialBackfill runs the parent portal's initial forward backfill for
// this thread on behalf of source, unless a backfill was already attempted,
// the configured initial thread backfill limit is zero, or the database
// already holds a message for the thread.
//
// The parent portal's forwardBackfillLock is taken before the database check.
// It is released here only when the backfill is skipped; otherwise it is
// still held when forwardBackfillInitial is called.
// NOTE(review): forwardBackfillInitial is presumably responsible for
// releasing the lock — confirm in the portal code.
func (thread *Thread) maybeInitialBackfill(source *User) {
	if thread.initialBackfillAttempted {
		return
	}
	portal := thread.Parent
	if portal.bridge.Config.Bridge.Backfill.Limits.Initial.Thread == 0 {
		return
	}

	portal.forwardBackfillLock.Lock()
	alreadyHasMessages := portal.bridge.DB.Message.GetLastInThread(portal.Key, thread.ID) != nil
	if alreadyHasMessages {
		portal.forwardBackfillLock.Unlock()
		return
	}
	portal.forwardBackfillInitial(source, thread)
}
func (thread *Thread) Join(user *User) {
if user.IsInPortal(thread.ID) {
return
}
log := user.log.With().Str("thread_id", thread.ID).Str("channel_id", thread.ParentID).Logger()
log.Debug().Msg("Joining thread")
var doBackfill, backfillStarted bool
if !thread.initialBackfillAttempted && thread.Parent.bridge.Config.Bridge.Backfill.Limits.Initial.Thread > 0 {
thread.Parent.forwardBackfillLock.Lock()
lastMessage := thread.Parent.bridge.DB.Message.GetLastInThread(thread.Parent.Key, thread.ID)
if lastMessage != nil {
thread.Parent.forwardBackfillLock.Unlock()
} else {
doBackfill = true
defer func() {
if !backfillStarted {
thread.Parent.forwardBackfillLock.Unlock()
}
}()
}
}
2023-02-26 23:12:06 +00:00
var err error
if user.Session.IsUser {
err = user.Session.ThreadJoinWithLocation(thread.ID, discordgo.ThreadJoinLocationContextMenu)
} else {
err = user.Session.ThreadJoin(thread.ID)
}
if err != nil {
log.Error().Err(err).Msg("Error joining thread")
} else {
user.MarkInPortal(database.UserPortal{
DiscordID: thread.ID,
Type: database.UserPortalTypeThread,
Timestamp: time.Now(),
})
if doBackfill {
go thread.Parent.forwardBackfillInitial(user, thread)
backfillStarted = true
}
}
}