Add support for MSC3916 endpoints for direct media

This commit is contained in:
Tulir Asokan 2024-05-29 11:40:57 +03:00
parent 8d01c30014
commit a6d9e62b49

View file

@ -25,8 +25,10 @@ import (
"fmt" "fmt"
"io" "io"
"mime" "mime"
"mime/multipart"
"net" "net"
"net/http" "net/http"
"net/textproto"
"net/url" "net/url"
"os" "os"
"strconv" "strconv"
@ -110,9 +112,11 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI {
if dma.ks.WellKnownTarget == "" { if dma.ks.WellKnownTarget == "" {
dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName) dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName)
} }
federationRouter := r.PathPrefix("/_matrix/federation").Subrouter()
mediaRouter := r.PathPrefix("/_matrix/media").Subrouter() mediaRouter := r.PathPrefix("/_matrix/media").Subrouter()
clientMediaRouter := r.PathPrefix("/_matrix/client/v1/media").Subrouter()
var reqIDCounter atomic.Uint64 var reqIDCounter atomic.Uint64
mediaRouter.Use(func(next http.Handler) http.Handler { middleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
@ -124,21 +128,37 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI {
Logger() Logger()
next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context()))) next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context())))
}) })
}) }
mediaRouter.Use(middleware)
federationRouter.Use(middleware)
clientMediaRouter.Use(middleware)
addRoutes := func(version string) { addRoutes := func(version string) {
mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet)
mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut)
mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost) mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost)
mediaRouter.HandleFunc("/"+version+"/create", dma.UploadNotSupported).Methods(http.MethodPost)
mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet)
mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet)
} }
clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet)
clientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
clientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut)
clientMediaRouter.HandleFunc("/upload", dma.UploadNotSupported).Methods(http.MethodPost)
clientMediaRouter.HandleFunc("/create", dma.UploadNotSupported).Methods(http.MethodPost)
clientMediaRouter.HandleFunc("/config", dma.UploadNotSupported).Methods(http.MethodGet)
clientMediaRouter.HandleFunc("/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet)
addRoutes("v3") addRoutes("v3")
addRoutes("r0") addRoutes("r0")
addRoutes("v1") addRoutes("v1")
federationRouter.HandleFunc("/v1/media/download/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet)
federationRouter.HandleFunc("/v1/version", dma.ks.GetServerVersion).Methods(http.MethodGet)
mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint)
mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod)
federationRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint)
federationRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod)
dma.ks.Register(r) dma.ks.Register(r)
return dma return dma
@ -532,14 +552,17 @@ func (dma *DirectMediaAPI) proxyDownload(ctx context.Context, w http.ResponseWri
func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) { func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
log := zerolog.Ctx(ctx) log := zerolog.Ctx(ctx)
isNewFederation := strings.HasPrefix(r.URL.Path, "/_matrix/federation/v1/media/download/")
vars := mux.Vars(r) vars := mux.Vars(r)
if vars["serverName"] != dma.cfg.ServerName { if !isNewFederation && vars["serverName"] != dma.cfg.ServerName {
jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ jsonResponse(w, http.StatusNotFound, &mautrix.RespError{
ErrCode: mautrix.MNotFound.ErrCode, ErrCode: mautrix.MNotFound.ErrCode,
Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName), Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName),
}) })
return return
} }
// TODO check destination header in X-Matrix auth when isNewFederation
url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"]) url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"])
if err != nil { if err != nil {
var respError *RespError var respError *RespError
@ -556,7 +579,36 @@ func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request)
}) })
} }
return return
}
if isNewFederation {
mp := multipart.NewWriter(w)
w.Header().Set("Content-Type", strings.Replace(mp.FormDataContentType(), "form-data", "mixed", 1))
var metaPart io.Writer
metaPart, err = mp.CreatePart(textproto.MIMEHeader{
"Content-Type": {"application/json"},
})
if err != nil {
log.Err(err).Msg("Failed to create multipart metadata field")
return
}
_, err = metaPart.Write([]byte(`{}`))
if err != nil {
log.Err(err).Msg("Failed to write multipart metadata field")
return
}
_, err = mp.CreatePart(textproto.MIMEHeader{
"Location": {url},
})
if err != nil {
log.Err(err).Msg("Failed to create multipart redirect field")
return
}
err = mp.Close()
if err != nil {
log.Err(err).Msg("Failed to close multipart writer")
return
}
return
} }
// Proxy if the config allows proxying and the request doesn't allow redirects. // Proxy if the config allows proxying and the request doesn't allow redirects.
// In any other case, redirect to the Discord CDN. // In any other case, redirect to the Discord CDN.