diff --git a/directmedia.go b/directmedia.go index 957217d..c6f8a2b 100644 --- a/directmedia.go +++ b/directmedia.go @@ -25,8 +25,10 @@ import ( "fmt" "io" "mime" + "mime/multipart" "net" "net/http" + "net/textproto" "net/url" "os" "strconv" @@ -110,9 +112,11 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { if dma.ks.WellKnownTarget == "" { dma.ks.WellKnownTarget = fmt.Sprintf("%s:443", dma.cfg.ServerName) } + federationRouter := r.PathPrefix("/_matrix/federation").Subrouter() mediaRouter := r.PathPrefix("/_matrix/media").Subrouter() + clientMediaRouter := r.PathPrefix("/_matrix/client/v1/media").Subrouter() var reqIDCounter atomic.Uint64 - mediaRouter.Use(func(next http.Handler) http.Handler { + middleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Access-Control-Allow-Origin", "*") w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") @@ -124,21 +128,37 @@ func newDirectMediaAPI(br *DiscordBridge) *DirectMediaAPI { Logger() next.ServeHTTP(w, r.WithContext(log.WithContext(r.Context()))) }) - }) + } + mediaRouter.Use(middleware) + federationRouter.Use(middleware) + clientMediaRouter.Use(middleware) addRoutes := func(version string) { mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) mediaRouter.HandleFunc("/"+version+"/upload", dma.UploadNotSupported).Methods(http.MethodPost) + mediaRouter.HandleFunc("/"+version+"/create", dma.UploadNotSupported).Methods(http.MethodPost) mediaRouter.HandleFunc("/"+version+"/config", dma.UploadNotSupported).Methods(http.MethodGet) mediaRouter.HandleFunc("/"+version+"/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) } + clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) + clientMediaRouter.HandleFunc("/download/{serverName}/{mediaID}/{fileName}", dma.DownloadMedia).Methods(http.MethodGet) + clientMediaRouter.HandleFunc("/thumbnail/{serverName}/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) + clientMediaRouter.HandleFunc("/upload/{serverName}/{mediaID}", dma.UploadNotSupported).Methods(http.MethodPut) + clientMediaRouter.HandleFunc("/upload", dma.UploadNotSupported).Methods(http.MethodPost) + clientMediaRouter.HandleFunc("/create", dma.UploadNotSupported).Methods(http.MethodPost) + clientMediaRouter.HandleFunc("/config", dma.UploadNotSupported).Methods(http.MethodGet) + clientMediaRouter.HandleFunc("/preview_url", dma.PreviewURLNotSupported).Methods(http.MethodGet) addRoutes("v3") addRoutes("r0") addRoutes("v1") + federationRouter.HandleFunc("/v1/media/download/{mediaID}", dma.DownloadMedia).Methods(http.MethodGet) + federationRouter.HandleFunc("/v1/version", dma.ks.GetServerVersion).Methods(http.MethodGet) mediaRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) mediaRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) + federationRouter.NotFoundHandler = http.HandlerFunc(dma.UnknownEndpoint) + federationRouter.MethodNotAllowedHandler = http.HandlerFunc(dma.UnsupportedMethod) dma.ks.Register(r) return dma @@ -532,14 +552,17 @@ func (dma *DirectMediaAPI) proxyDownload(ctx context.Context, w http.ResponseWri func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) { ctx := r.Context() log := zerolog.Ctx(ctx) + isNewFederation := strings.HasPrefix(r.URL.Path, "/_matrix/federation/v1/media/download/") vars := mux.Vars(r) - if vars["serverName"] != dma.cfg.ServerName { + if !isNewFederation && vars["serverName"] != dma.cfg.ServerName { jsonResponse(w, http.StatusNotFound, &mautrix.RespError{ ErrCode: mautrix.MNotFound.ErrCode, Err: fmt.Sprintf("This is a Discord media proxy for %q, other media downloads are not available here", dma.cfg.ServerName), }) return } + // TODO check destination header in X-Matrix auth when isNewFederation + url, expiresAt, err := dma.getMediaURL(ctx, vars["mediaID"]) if err != nil { var respError *RespError @@ -556,7 +579,36 @@ func (dma *DirectMediaAPI) DownloadMedia(w http.ResponseWriter, r *http.Request) }) } return - + } + if isNewFederation { + mp := multipart.NewWriter(w) + w.Header().Set("Content-Type", strings.Replace(mp.FormDataContentType(), "form-data", "mixed", 1)) + var metaPart io.Writer + metaPart, err = mp.CreatePart(textproto.MIMEHeader{ + "Content-Type": {"application/json"}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart metadata field") + return + } + _, err = metaPart.Write([]byte(`{}`)) + if err != nil { + log.Err(err).Msg("Failed to write multipart metadata field") + return + } + _, err = mp.CreatePart(textproto.MIMEHeader{ + "Location": {url}, + }) + if err != nil { + log.Err(err).Msg("Failed to create multipart redirect field") + return + } + err = mp.Close() + if err != nil { + log.Err(err).Msg("Failed to close multipart writer") + return + } + return } // Proxy if the config allows proxying and the request doesn't allow redirects. // In any other case, redirect to the Discord CDN.