Add support for intentional mentions

This commit is contained in:
Tulir Asokan 2023-05-24 13:18:23 +03:00
parent 75181741da
commit 434f27c8b4
9 changed files with 91 additions and 33 deletions

View file

@ -217,7 +217,8 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(source, msg.Author)
intent := puppet.IntentFor(portal)
replyTo := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
replyTo, replySenderMXID := portal.getReplyTarget(source, "", msg.MessageReference, msg.Embeds, true)
mentions := portal.convertDiscordMentions(msg, replySenderMXID, false)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
log := log.With().
@ -232,6 +233,11 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
// Only set reply for first event
replyTo = nil
}
part.Content.Mentions = mentions
// Only set mentions for first event, but keep empty object for rest
mentions = &event.Mentions{}
partName := part.AttachmentID
// Always use blank part name for first part so that replies and other things
// can reference it without knowing about attachments.
@ -262,6 +268,7 @@ func (portal *Portal) convertMessageBatch(log zerolog.Logger, source *User, mess
SenderID: msg.Author.ID,
Timestamp: ts,
AttachmentID: part.AttachmentID,
SenderMXID: intent.UserID,
})
}
}

View file

@ -85,6 +85,7 @@ func DoUpgrade(helper *up.Helper) {
helper.Copy(up.Bool, "bridge", "encryption", "require")
helper.Copy(up.Bool, "bridge", "encryption", "appservice")
helper.Copy(up.Bool, "bridge", "encryption", "allow_key_sharing")
helper.Copy(up.Bool, "bridge", "encryption", "plaintext_mentions")
helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "delete_outbound_on_ack")
helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "dont_store_outbound")
helper.Copy(up.Bool, "bridge", "encryption", "delete_keys", "ratchet_on_decrypt")

View file

@ -19,7 +19,7 @@ type MessageQuery struct {
}
const (
messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid FROM message"
messageSelect = "SELECT dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid FROM message"
)
func (mq *MessageQuery) New() *Message {
@ -99,11 +99,11 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
if len(msgs) == 0 {
return
}
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d)"
valueStringFormat := "($%d, $%d, $1, $2, $%d, $%d, $%d, $%d, $%d, $%d)"
if mq.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
params := make([]interface{}, 2+len(msgs)*7)
params := make([]interface{}, 2+len(msgs)*8)
placeholders := make([]string, len(msgs))
params[0] = key.ChannelID
params[1] = key.Receiver
@ -116,7 +116,8 @@ func (mq *MessageQuery) MassInsert(key PortalKey, msgs []Message) {
params[baseIndex+4] = msg.editTimestampVal()
params[baseIndex+5] = msg.ThreadID
params[baseIndex+6] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7)
params[baseIndex+7] = msg.SenderMXID.String()
placeholders[i] = fmt.Sprintf(valueStringFormat, baseIndex+1, baseIndex+2, baseIndex+3, baseIndex+4, baseIndex+5, baseIndex+6, baseIndex+7, baseIndex+8)
}
_, err := mq.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
if err != nil {
@ -137,7 +138,8 @@ type Message struct {
EditTimestamp time.Time
ThreadID string
MXID id.EventID
MXID id.EventID
SenderMXID id.UserID
}
func (m *Message) DiscordProtoChannelID() string {
@ -151,7 +153,7 @@ func (m *Message) DiscordProtoChannelID() string {
func (m *Message) Scan(row dbutil.Scannable) *Message {
var ts, editTS int64
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID)
err := row.Scan(&m.DiscordID, &m.AttachmentID, &m.Channel.ChannelID, &m.Channel.Receiver, &m.SenderID, &ts, &editTS, &m.ThreadID, &m.MXID, &m.SenderMXID)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
m.log.Errorln("Database scan failed:", err)
@ -173,12 +175,12 @@ func (m *Message) Scan(row dbutil.Scannable) *Message {
const messageInsertQuery = `
INSERT INTO message (
dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid
dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver, dc_sender, timestamp, dc_edit_timestamp, dc_thread_id, mxid, sender_mxid
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
`
var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9)", "%s", 1)
var messageMassInsertTemplate = strings.Replace(messageInsertQuery, "($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", "%s", 1)
type MessagePart struct {
AttachmentID string
@ -196,11 +198,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
if len(msgs) == 0 {
return
}
valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d)"
valueStringFormat := "($1, $%d, $2, $3, $4, $5, $6, $7, $%d, $8)"
if m.db.Dialect == dbutil.SQLite {
valueStringFormat = strings.ReplaceAll(valueStringFormat, "$", "?")
}
params := make([]interface{}, 7+len(msgs)*2)
params := make([]interface{}, 8+len(msgs)*2)
placeholders := make([]string, len(msgs))
params[0] = m.DiscordID
params[1] = m.Channel.ChannelID
@ -209,10 +211,11 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
params[4] = m.Timestamp.UnixMilli()
params[5] = m.editTimestampVal()
params[6] = m.ThreadID
params[7] = m.SenderMXID.String()
for i, msg := range msgs {
params[7+i*2] = msg.AttachmentID
params[7+i*2+1] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, 7+i*2+1, 7+i*2+2)
params[8+i*2] = msg.AttachmentID
params[8+i*2+1] = msg.MXID
placeholders[i] = fmt.Sprintf(valueStringFormat, 8+i*2+1, 8+i*2+2)
}
_, err := m.db.Exec(fmt.Sprintf(messageMassInsertTemplate, strings.Join(placeholders, ", ")), params...)
if err != nil {
@ -224,7 +227,7 @@ func (m *Message) MassInsertParts(msgs []MessagePart) {
func (m *Message) Insert() {
_, err := m.db.Exec(messageInsertQuery,
m.DiscordID, m.AttachmentID, m.Channel.ChannelID, m.Channel.Receiver, m.SenderID,
m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID)
m.Timestamp.UnixMilli(), m.editTimestampVal(), m.ThreadID, m.MXID, m.SenderMXID.String())
if err != nil {
m.log.Warnfln("Failed to insert %s@%s: %v", m.DiscordID, m.Channel, err)

View file

@ -1,4 +1,4 @@
-- v0 -> v19: Latest revision
-- v0 -> v20 (compatible with v19+): Latest revision
CREATE TABLE guild (
dcid TEXT PRIMARY KEY,
@ -113,7 +113,8 @@ CREATE TABLE message (
dc_edit_timestamp BIGINT NOT NULL,
dc_thread_id TEXT NOT NULL,
mxid TEXT NOT NULL UNIQUE,
mxid TEXT NOT NULL UNIQUE,
sender_mxid TEXT NOT NULL DEFAULT '',
PRIMARY KEY (dcid, dc_attachment_id, dc_chan_id, dc_chan_receiver),
CONSTRAINT message_portal_fkey FOREIGN KEY (dc_chan_id, dc_chan_receiver) REFERENCES portal (dcid, receiver) ON DELETE CASCADE

View file

@ -0,0 +1,2 @@
-- v20 (compatible with v19+): Store message sender Matrix user ID
ALTER TABLE message ADD COLUMN sender_mxid TEXT NOT NULL DEFAULT '';

View file

@ -247,6 +247,8 @@ bridge:
# Enable key sharing? If enabled, key requests for rooms where users are in will be fulfilled.
# You must use a client that supports requesting keys from other users to use this feature.
allow_key_sharing: false
# Should users mentions be in the event wire content to enable the server to send push notifications?
plaintext_mentions: false
# Options for deleting megolm sessions from the bridge.
delete_keys:
# Beeper-specific: delete outbound sessions when hungryserv confirms

View file

@ -26,6 +26,7 @@ import (
"github.com/yuin/goldmark/extension"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/util"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
@ -93,6 +94,7 @@ func (portal *Portal) renderDiscordMarkdownOnlyHTML(text string, allowInlineLink
const formatterContextPortalKey = "fi.mau.discord.portal"
const formatterContextAllowedMentionsKey = "fi.mau.discord.allowed_mentions"
const formatterContextInputAllowedMentionsKey = "fi.mau.discord.input_allowed_mentions"
func appendIfNotContains(arr []string, newItem string) []string {
for _, item := range arr {
@ -135,6 +137,10 @@ func (br *DiscordBridge) pillConverter(displayname, mxid, eventID string, ctx fo
}
}
} else if mxid[0] == '@' {
allowedMentions, _ := ctx.ReturnData[formatterContextInputAllowedMentionsKey].([]id.UserID)
if allowedMentions != nil && !slices.Contains(allowedMentions, id.UserID(mxid)) {
return displayname
}
mentions := ctx.ReturnData[formatterContextAllowedMentionsKey].(*discordgo.MessageAllowedMentions)
parsedID, ok := br.ParsePuppetMXID(id.UserID(mxid))
if ok {
@ -219,6 +225,9 @@ func (portal *Portal) parseMatrixHTML(content *event.MessageEventContent) (strin
ctx := format.NewContext()
ctx.ReturnData[formatterContextPortalKey] = portal
ctx.ReturnData[formatterContextAllowedMentionsKey] = allowedMentions
if content.Mentions != nil {
ctx.ReturnData[formatterContextInputAllowedMentionsKey] = content.Mentions.UserIDs
}
return variationselector.FullyQualify(matrixHTMLParser.Parse(content.FormattedBody, ctx)), allowedMentions
} else {
return variationselector.FullyQualify(escapeDiscordMarkdown(content.Body)), allowedMentions

View file

@ -584,13 +584,14 @@ func (portal *Portal) ensureUserInvited(user *User, ignoreCache bool) bool {
return user.ensureInvited(portal.MainIntent(), portal.MXID, portal.IsPrivateChat(), ignoreCache)
}
func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, parts []database.MessagePart) {
func (portal *Portal) markMessageHandled(discordID string, authorID string, timestamp time.Time, threadID string, senderMXID id.UserID, parts []database.MessagePart) {
msg := portal.bridge.DB.Message.New()
msg.Channel = portal.Key
msg.DiscordID = discordID
msg.SenderID = authorID
msg.Timestamp = timestamp
msg.ThreadID = threadID
msg.SenderMXID = senderMXID
msg.MassInsertParts(parts)
}
@ -618,11 +619,6 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
}
log.Debug().Msg("Starting handling of Discord message")
for _, mention := range msg.Mentions {
puppet := portal.bridge.GetPuppetByID(mention.ID)
puppet.UpdateInfo(nil, mention)
}
puppet := portal.bridge.GetPuppetByID(msg.Author.ID)
puppet.UpdateInfo(user, msg.Author)
intent := puppet.IntentFor(portal)
@ -638,7 +634,8 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
lastThreadEvent = lastInThread.MXID
}
}
replyTo := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false)
replyTo, replySenderMXID := portal.getReplyTarget(user, discordThreadID, msg.MessageReference, msg.Embeds, false)
mentions := portal.convertDiscordMentions(msg, replySenderMXID, true)
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
parts := portal.convertDiscordMessage(ctx, intent, msg)
@ -658,6 +655,11 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
// Only set reply for first event
replyTo = nil
}
part.Content.Mentions = mentions
// Only set mentions for first event, but keep empty object for rest
mentions = &event.Mentions{}
resp, err := portal.sendMatrixMessage(intent, part.Type, part.Content, part.Extra, ts.UnixMilli())
if err != nil {
log.Err(err).
@ -674,7 +676,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
} else if len(dbParts) == 0 {
log.Warn().Msg("All parts of message failed to send to Matrix")
} else {
portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, dbParts)
portal.markMessageHandled(msg.ID, msg.Author.ID, ts, discordThreadID, intent.UserID, dbParts)
}
}
@ -684,7 +686,7 @@ func isReplyEmbed(embed *discordgo.MessageEmbed) bool {
return hackyReplyPattern.MatchString(embed.Description)
}
func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) *event.InReplyTo {
func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discordgo.MessageReference, embeds []*discordgo.MessageEmbed, allowNonExistent bool) (*event.InReplyTo, id.UserID) {
if ref == nil && len(embeds) > 0 {
match := hackyReplyPattern.FindStringSubmatch(embeds[0].Description)
if match != nil && match[1] == portal.GuildID && (match[2] == portal.Key.ChannelID || match[2] == threadID) {
@ -696,7 +698,7 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord
}
}
if ref == nil {
return nil
return nil, ""
}
isHungry := portal.bridge.Config.Homeserver.Software == bridgeconfig.SoftwareHungry
if !isHungry {
@ -709,25 +711,25 @@ func (portal *Portal) getReplyTarget(source *User, threadID string, ref *discord
if ref.ChannelID != portal.Key.ChannelID && ref.ChannelID != threadID && crossRoomReplies {
targetPortal = portal.bridge.GetExistingPortalByID(database.PortalKey{ChannelID: ref.ChannelID, Receiver: source.DiscordID})
if targetPortal == nil {
return nil
return nil, ""
}
}
replyToMsg := portal.bridge.DB.Message.GetByDiscordID(targetPortal.Key, ref.MessageID)
if len(replyToMsg) > 0 {
if !crossRoomReplies {
return &event.InReplyTo{EventID: replyToMsg[0].MXID}
return &event.InReplyTo{EventID: replyToMsg[0].MXID}, replyToMsg[0].SenderMXID
}
return &event.InReplyTo{
EventID: replyToMsg[0].MXID,
UnstableRoomID: targetPortal.MXID,
}
}, replyToMsg[0].SenderMXID
} else if allowNonExistent {
return &event.InReplyTo{
EventID: targetPortal.deterministicEventID(ref.MessageID, ""),
UnstableRoomID: targetPortal.MXID,
}
}, ""
}
return nil
return nil, ""
}
const JoinThreadReaction = "join thread"
@ -895,7 +897,10 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
Msg("Dropping non-text edit")
return
}
converted.Content.Mentions = portal.convertDiscordMentions(msg, "", false)
converted.Content.SetEdit(existing[0].MXID)
// Never actually mention new users of edits, only include mentions inside m.new_content
converted.Content.Mentions = &event.Mentions{}
if converted.Extra != nil {
converted.Extra = map[string]any{
"m.new_content": converted.Extra,
@ -1585,6 +1590,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) {
} else {
dbMsg.SenderID = portal.RelayWebhookID
}
dbMsg.SenderMXID = sender.MXID
dbMsg.Timestamp, _ = discordgo.SnowflakeTimestamp(msg.ID)
dbMsg.ThreadID = threadID
dbMsg.Insert()

View file

@ -26,6 +26,8 @@ import (
"github.com/bwmarrin/discordgo"
"github.com/rs/zerolog"
"golang.org/x/exp/slices"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice"
@ -518,6 +520,31 @@ func isPlainGifMessage(msg *discordgo.Message) bool {
return len(msg.Embeds) == 1 && msg.Embeds[0].Video != nil && msg.Embeds[0].URL == msg.Content && msg.Embeds[0].Type == discordgo.EmbedTypeGifv
}
func (portal *Portal) convertDiscordMentions(msg *discordgo.Message, replySender id.UserID, syncGhosts bool) *event.Mentions {
var matrixMentions event.Mentions
for _, mention := range msg.Mentions {
puppet := portal.bridge.GetPuppetByID(mention.ID)
if syncGhosts {
puppet.UpdateInfo(nil, mention)
}
user := portal.bridge.GetUserByID(mention.ID)
if user != nil {
matrixMentions.UserIDs = append(matrixMentions.UserIDs, user.MXID)
} else {
matrixMentions.UserIDs = append(matrixMentions.UserIDs, puppet.MXID)
}
}
if replySender != "" {
matrixMentions.UserIDs = append(matrixMentions.UserIDs, replySender)
}
slices.Sort(matrixMentions.UserIDs)
matrixMentions.UserIDs = slices.Compact(matrixMentions.UserIDs)
if msg.MentionEveryone {
matrixMentions.Room = true
}
return &matrixMentions
}
func (portal *Portal) convertDiscordTextMessage(ctx context.Context, intent *appservice.IntentAPI, msg *discordgo.Message) *ConvertedMessage {
log := zerolog.Ctx(ctx)
if msg.Type == discordgo.MessageTypeCall {