Add support for tags in Discord -> Matrix formatter

This commit is contained in:
Tulir Asokan 2022-07-02 14:48:42 +03:00
parent 98ec4c6ed9
commit 152fb5c7ce
5 changed files with 270 additions and 41 deletions

View file

@ -10,6 +10,32 @@ import (
"maunium.net/go/mautrix/id"
)
func (portal *Portal) getEmojiMXCByDiscordID(emojiID, name string, animated bool) id.ContentURI {
dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(emojiID)
if dbEmoji == nil {
data, mimeType, err := portal.downloadDiscordEmoji(emojiID, animated)
if err != nil {
portal.log.Warnfln("Failed to download emoji %s from discord: %v", emojiID, err)
return id.ContentURI{}
}
uri, err := portal.uploadMatrixEmoji(portal.MainIntent(), data, mimeType)
if err != nil {
portal.log.Warnfln("Failed to upload discord emoji %s to homeserver: %v", emojiID, err)
return id.ContentURI{}
}
dbEmoji = portal.bridge.DB.Emoji.New()
dbEmoji.DiscordID = emojiID
dbEmoji.DiscordName = name
dbEmoji.MatrixURL = uri
dbEmoji.Insert()
}
return dbEmoji.MatrixURL
}
func (portal *Portal) downloadDiscordEmoji(id string, animated bool) ([]byte, string, error) {
var url string
var mimeType string

View file

@ -1,3 +1,19 @@
// mautrix-discord - A Matrix-Discord puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package main
import (
@ -13,14 +29,17 @@ import (
"maunium.net/go/mautrix/format/mdext"
)
var mdRenderer = goldmark.New(format.Extensions, format.HTMLOptions,
goldmark.WithExtensions(mdext.EscapeHTML, mdext.SimpleSpoiler, mdext.DiscordUnderline))
var discordExtensions = goldmark.WithExtensions(mdext.EscapeHTML, mdext.SimpleSpoiler, mdext.DiscordUnderline)
var escapeFixer = regexp.MustCompile(`\\(__[^_]|\*\*[^*])`)
func renderDiscordMarkdown(text string) event.MessageEventContent {
func (portal *Portal) renderDiscordMarkdown(text string) event.MessageEventContent {
text = escapeFixer.ReplaceAllStringFunc(text, func(s string) string {
return s[:2] + `\` + s[2:]
})
mdRenderer := goldmark.New(
format.Extensions, format.HTMLOptions, discordExtensions,
goldmark.WithExtensions(&DiscordTag{portal}),
)
return format.RenderMarkdownCustom(text, mdRenderer)
}

208
formatter_tag.go Normal file
View file

@ -0,0 +1,208 @@
// mautrix-discord - A Matrix-Discord puppeting bridge.
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package main
import (
"fmt"
"regexp"
"strconv"
"strings"
"github.com/yuin/goldmark"
"github.com/yuin/goldmark/ast"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/renderer"
"github.com/yuin/goldmark/text"
"github.com/yuin/goldmark/util"
"maunium.net/go/mautrix"
"go.mau.fi/mautrix-discord/database"
)
type astDiscordTag struct {
ast.BaseInline
id int64
}
var _ ast.Node = (*astDiscordTag)(nil)
var astKindDiscordTag = ast.NewNodeKind("DiscordTag")
func (n *astDiscordTag) Dump(source []byte, level int) {
ast.DumpHelper(n, source, level, nil, nil)
}
func (n *astDiscordTag) Kind() ast.NodeKind {
return astKindDiscordTag
}
type astDiscordUserMention struct {
astDiscordTag
hasNick bool
}
func (n *astDiscordUserMention) String() string {
if n.hasNick {
return fmt.Sprintf("<@!%d>", n.id)
}
return fmt.Sprintf("<@%d>", n.id)
}
type astDiscordRoleMention struct {
astDiscordTag
}
func (n *astDiscordRoleMention) String() string {
return fmt.Sprintf("<@&%d>", n.id)
}
type astDiscordChannelMention struct {
astDiscordTag
guildID int64
name string
}
func (n *astDiscordChannelMention) String() string {
if n.guildID != 0 {
return fmt.Sprintf("<#%d:%d:%s>", n.id, n.guildID, n.name)
}
return fmt.Sprintf("<#%d>", n.id)
}
type astDiscordCustomEmoji struct {
astDiscordTag
name string
animated bool
}
func (n *astDiscordCustomEmoji) String() string {
if n.animated {
return fmt.Sprintf("<a%s%s>", n.name, n.id)
}
return fmt.Sprintf("<%s%s>", n.name, n.id)
}
type discordTagParser struct{}
var discordTagRegex = regexp.MustCompile(`<(a?:\w+:|@[!&]?|#)(\d+)(?::(\d+):(.+?))?>`)
var defaultDiscordTagParser = &discordTagParser{}
func (s *discordTagParser) Trigger() []byte {
return []byte{'<'}
}
func (s *discordTagParser) Parse(parent ast.Node, block text.Reader, pc parser.Context) ast.Node {
//before := block.PrecendingCharacter()
line, _ := block.PeekLine()
match := discordTagRegex.FindSubmatch(line)
if match == nil {
return nil
}
//seg := segment.WithStop(segment.Start + len(match[0]))
block.Advance(len(match[0]))
id, err := strconv.ParseInt(string(match[2]), 10, 64)
if err != nil {
return nil
}
tag := astDiscordTag{id: id}
tagName := string(match[1])
switch {
case tagName == "@":
return &astDiscordUserMention{astDiscordTag: tag}
case tagName == "@!":
return &astDiscordUserMention{astDiscordTag: tag, hasNick: true}
case tagName == "@&":
return &astDiscordRoleMention{astDiscordTag: tag}
case tagName == "#":
var guildID int64
var channelName string
if len(match[3]) > 0 && len(match[4]) > 0 {
guildID, _ = strconv.ParseInt(string(match[3]), 10, 64)
channelName = string(match[4])
}
return &astDiscordChannelMention{astDiscordTag: tag, guildID: guildID, name: channelName}
case strings.HasPrefix(tagName, ":"):
return &astDiscordCustomEmoji{name: tagName, astDiscordTag: tag}
case strings.HasPrefix(tagName, "a:"):
return &astDiscordCustomEmoji{name: tagName[1:], astDiscordTag: tag}
default:
return nil
}
}
func (s *discordTagParser) CloseBlock(parent ast.Node, pc parser.Context) {
// nothing to do
}
type discordTagHTMLRenderer struct {
portal *Portal
}
func (r *discordTagHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) {
reg.Register(astKindDiscordTag, r.renderDiscordMention)
}
func (r *discordTagHTMLRenderer) renderDiscordMention(w util.BufWriter, source []byte, n ast.Node, entering bool) (status ast.WalkStatus, err error) {
status = ast.WalkContinue
if !entering {
return
}
switch node := n.(type) {
case *astDiscordUserMention:
puppet := r.portal.bridge.GetPuppetByID(strconv.FormatInt(node.id, 10))
_, _ = fmt.Fprintf(w, `<a href="https://matrix.to/#/%s">%s</a>`, puppet.MXID, puppet.Name)
return
case *astDiscordRoleMention:
// TODO
case *astDiscordChannelMention:
portal := r.portal.bridge.GetExistingPortalByID(database.PortalKey{
ChannelID: strconv.FormatInt(node.id, 10),
Receiver: "",
})
if portal != nil {
_, _ = fmt.Fprintf(w, `<a href="https://matrix.to/#/%s?via=%s">%s</a>`, portal.MXID, portal.bridge.AS.HomeserverDomain, portal.Name)
return
}
case *astDiscordCustomEmoji:
reactionMXC := r.portal.getEmojiMXCByDiscordID(strconv.FormatInt(node.id, 10), node.name, node.animated)
if !reactionMXC.IsEmpty() {
_, _ = fmt.Fprintf(w, `<img data-mx-emoticon src="%[1]s" alt="%[2]s" title="%[2]s" height="32"/>`, reactionMXC.String(), node.name)
return
}
}
stringifiable, ok := n.(mautrix.Stringifiable)
if ok {
_, _ = w.WriteString(stringifiable.String())
} else {
_, _ = w.Write(source)
}
return
}
type DiscordTag struct {
Portal *Portal
}
func (e *DiscordTag) Extend(m goldmark.Markdown) {
m.Parser().AddOptions(parser.WithInlineParsers(
util.Prioritized(defaultDiscordTagParser, 600),
))
m.Renderer().AddOptions(renderer.WithNodeRenderers(
util.Prioritized(&discordTagHTMLRenderer{e.Portal}, 600),
))
}

View file

@ -611,7 +611,7 @@ func (portal *Portal) handleDiscordMessageCreate(user *User, msg *discordgo.Mess
var parts []database.MessagePart
ts, _ := discordgo.SnowflakeTimestamp(msg.ID)
if msg.Content != "" {
content := renderDiscordMarkdown(msg.Content)
content := portal.renderDiscordMarkdown(msg.Content)
content.RelatesTo = threadRelation.Copy()
if msg.MessageReference != nil {
@ -697,24 +697,24 @@ func (portal *Portal) handleDiscordMessageUpdate(user *User, msg *discordgo.Mess
attachmentMap[existingPart.AttachmentID] = existingPart
}
}
for _, attachment := range msg.Attachments {
if _, found := attachmentMap[attachment.ID]; found {
delete(attachmentMap, attachment.ID)
for _, remainingAttachment := range msg.Attachments {
if _, found := attachmentMap[remainingAttachment.ID]; found {
delete(attachmentMap, remainingAttachment.ID)
}
}
for _, attachment := range attachmentMap {
_, err := intent.RedactEvent(portal.MXID, attachment.MXID)
for _, deletedAttachment := range attachmentMap {
_, err := intent.RedactEvent(portal.MXID, deletedAttachment.MXID)
if err != nil {
portal.log.Warnfln("Failed to remove attachment %s: %v", attachment.MXID, err)
portal.log.Warnfln("Failed to remove attachment %s: %v", deletedAttachment.MXID, err)
}
attachment.Delete()
deletedAttachment.Delete()
}
if msg.Content == "" || existing[0].AttachmentID != "" {
portal.log.Debugfln("Dropping non-text edit to %s (message on matrix: %t, text on discord: %t)", msg.ID, existing[0].AttachmentID == "", len(msg.Content) > 0)
return
}
content := renderDiscordMarkdown(msg.Content)
content := portal.renderDiscordMarkdown(msg.Content)
content.SetEdit(existing[0].MXID)
var editTS int64
@ -885,7 +885,6 @@ func (portal *Portal) startThreadFromMatrix(sender *User, threadRoot id.EventID)
return "", fmt.Errorf("error starting thread: %v", err)
}
portal.log.Debugfln("Created Discord thread from %s/%s", threadRoot, ch.ID)
fmt.Printf("Created thread %+v\n", ch)
portal.bridge.GetThreadByID(existingMsg.DiscordID, existingMsg)
return ch.ID, nil
}
@ -1295,32 +1294,12 @@ func (portal *Portal) handleDiscordReaction(user *User, reaction *discordgo.Mess
var matrixReaction string
if reaction.Emoji.ID != "" {
dbEmoji := portal.bridge.DB.Emoji.GetByDiscordID(reaction.Emoji.ID)
if dbEmoji == nil {
data, mimeType, err := portal.downloadDiscordEmoji(reaction.Emoji.ID, reaction.Emoji.Animated)
if err != nil {
portal.log.Warnfln("Failed to download emoji %s from discord: %v", reaction.Emoji.ID, err)
return
}
uri, err := portal.uploadMatrixEmoji(intent, data, mimeType)
if err != nil {
portal.log.Warnfln("Failed to upload discord emoji %s to homeserver: %v", reaction.Emoji.ID, err)
return
}
dbEmoji = portal.bridge.DB.Emoji.New()
dbEmoji.DiscordID = reaction.Emoji.ID
dbEmoji.DiscordName = reaction.Emoji.Name
dbEmoji.MatrixURL = uri
dbEmoji.Insert()
reactionMXC := portal.getEmojiMXCByDiscordID(reaction.Emoji.ID, reaction.Emoji.Name, reaction.Emoji.Animated)
if reactionMXC.IsEmpty() {
return
}
discordID = dbEmoji.DiscordID
matrixReaction = dbEmoji.MatrixURL.String()
matrixReaction = reactionMXC.String()
discordID = reaction.Emoji.ID
} else {
discordID = reaction.Emoji.Name
matrixReaction = variationselector.Add(reaction.Emoji.Name)

View file

@ -675,9 +675,6 @@ func (user *User) channelUpdateHandler(_ *discordgo.Session, c *discordgo.Channe
}
func (user *User) pushPortalMessage(msg interface{}, typeName, channelID, guildID string) {
if user.Session.LogLevel == discordgo.LogDebug {
fmt.Printf("%+v\n", msg)
}
if !user.bridgeMessage(guildID) {
return
}