Parse expiry from URL

This commit is contained in:
Tulir Asokan 2024-02-18 23:26:16 +02:00
parent 23ae2d314f
commit 2a7a2c3895

View file

@ -19,6 +19,7 @@ package main
import ( import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/binary"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@ -26,6 +27,7 @@ import (
"mime" "mime"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -148,19 +150,39 @@ func (dma *DirectMediaAPI) makeMXC(data MediaIDData) id.ContentURI {
} }
} }
func (dma *DirectMediaAPI) addAttachmentToCache(channelID uint64, att *discordgo.MessageAttachment) { func parseExpiryTS(addr string) time.Time {
parsedURL, err := url.Parse(addr)
if err != nil {
return time.Time{}
}
tsBytes, err := hex.DecodeString(parsedURL.Query().Get("ex"))
if err != nil || len(tsBytes) != 4 {
return time.Time{}
}
parsedTS := int64(binary.BigEndian.Uint32(tsBytes))
if parsedTS > time.Now().Unix() && parsedTS < time.Now().Add(365*24*time.Hour).Unix() {
return time.Unix(parsedTS, 0)
}
return time.Time{}
}
func (dma *DirectMediaAPI) addAttachmentToCache(channelID uint64, att *discordgo.MessageAttachment) time.Time {
attachmentID, err := strconv.ParseUint(att.ID, 10, 64) attachmentID, err := strconv.ParseUint(att.ID, 10, 64)
if err != nil { if err != nil {
return return time.Time{}
}
expiry := parseExpiryTS(att.URL)
if expiry.IsZero() {
expiry = time.Now().Add(24 * time.Hour)
} }
dma.attachmentCache[AttachmentCacheKey{ dma.attachmentCache[AttachmentCacheKey{
ChannelID: channelID, ChannelID: channelID,
AttachmentID: attachmentID, AttachmentID: attachmentID,
}] = AttachmentCacheValue{ }] = AttachmentCacheValue{
URL: att.URL, URL: att.URL,
// TODO find expiry somehow properly? Expiry: expiry,
Expiry: time.Now().Add(23 * time.Hour),
} }
return expiry
} }
func (dma *DirectMediaAPI) AttachmentMXC(channelID, messageID string, att *discordgo.MessageAttachment) (mxc id.ContentURI) { func (dma *DirectMediaAPI) AttachmentMXC(channelID, messageID string, att *discordgo.MessageAttachment) (mxc id.ContentURI) {
@ -281,7 +303,7 @@ func (re *RespError) Error() string {
var ErrNoUsersWithAccessFound = errors.New("no users found to fetch message") var ErrNoUsersWithAccessFound = errors.New("no users found to fetch message")
var ErrAttachmentNotFound = errors.New("attachment not found") var ErrAttachmentNotFound = errors.New("attachment not found")
func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *AttachmentMediaData) (string, error) { func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *AttachmentMediaData) (string, time.Time, error) {
var client *discordgo.Session var client *discordgo.Session
channelIDStr := strconv.FormatUint(meta.ChannelID, 10) channelIDStr := strconv.FormatUint(meta.ChannelID, 10)
users := dma.bridge.DB.GetUsersInPortal(channelIDStr) users := dma.bridge.DB.GetUsersInPortal(channelIDStr)
@ -295,9 +317,8 @@ func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *Atta
} }
} }
if client == nil { if client == nil {
return "", ErrNoUsersWithAccessFound return "", time.Time{}, ErrNoUsersWithAccessFound
} }
var url string
var msgs []*discordgo.Message var msgs []*discordgo.Message
var err error var err error
messageIDStr := strconv.FormatUint(meta.MessageID, 10) messageIDStr := strconv.FormatUint(meta.MessageID, 10)
@ -309,21 +330,24 @@ func (dma *DirectMediaAPI) fetchNewAttachmentURL(ctx context.Context, meta *Atta
msgs = []*discordgo.Message{msg} msgs = []*discordgo.Message{msg}
} }
if err != nil { if err != nil {
return "", fmt.Errorf("failed to fetch message: %w", err) return "", time.Time{}, fmt.Errorf("failed to fetch message: %w", err)
} }
attachmentIDStr := strconv.FormatUint(meta.AttachmentID, 10) attachmentIDStr := strconv.FormatUint(meta.AttachmentID, 10)
var url string
var expiry time.Time
for _, item := range msgs { for _, item := range msgs {
for _, att := range item.Attachments { for _, att := range item.Attachments {
dma.addAttachmentToCache(meta.ChannelID, att) thisExpiry := dma.addAttachmentToCache(meta.ChannelID, att)
if att.ID == attachmentIDStr { if att.ID == attachmentIDStr {
url = att.URL url = att.URL
expiry = thisExpiry
} }
} }
} }
if url == "" { if url == "" {
return "", ErrAttachmentNotFound return "", time.Time{}, ErrAttachmentNotFound
} }
return url, nil return url, expiry, nil
} }
func (dma *DirectMediaAPI) GetEmojiInfo(contentURI id.ContentURI) *EmojiMediaData { func (dma *DirectMediaAPI) GetEmojiInfo(contentURI id.ContentURI) *EmojiMediaData {
@ -366,7 +390,7 @@ func (dma *DirectMediaAPI) getMediaURL(ctx context.Context, encodedMediaID strin
Uint64("message_id", mediaData.MessageID). Uint64("message_id", mediaData.MessageID).
Uint64("attachment_id", mediaData.AttachmentID). Uint64("attachment_id", mediaData.AttachmentID).
Msg("Refreshing attachment URL") Msg("Refreshing attachment URL")
url, err = dma.fetchNewAttachmentURL(ctx, mediaData) url, expiry, err = dma.fetchNewAttachmentURL(ctx, mediaData)
if err != nil { if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to refresh attachment URL") zerolog.Ctx(ctx).Err(err).Msg("Failed to refresh attachment URL")
msg := "Failed to refresh attachment URL" msg := "Failed to refresh attachment URL"
@ -381,9 +405,7 @@ func (dma *DirectMediaAPI) getMediaURL(ctx context.Context, encodedMediaID strin
Status: http.StatusNotFound, Status: http.StatusNotFound,
} }
} else { } else {
zerolog.Ctx(ctx).Debug().Msg("Successfully refreshed attachment URL") zerolog.Ctx(ctx).Debug().Time("expiry", expiry).Msg("Successfully refreshed attachment URL")
// TODO find expiry somehow properly?
expiry = time.Now().Add(23 * time.Hour)
} }
case *EmojiMediaData: case *EmojiMediaData:
if mediaData.Animated { if mediaData.Animated {