1
0
Fork 0
forked from mirror/grapevine

Compare commits

..

1 commit

Author SHA1 Message Date
Charles Hall
b0ab736da5
add admin command to unset a room alias 2024-06-12 21:18:23 -07:00
101 changed files with 2668 additions and 4820 deletions

3
.gitignore vendored
View file

@ -12,6 +12,3 @@ result*
# GitLab CI cache
/.gitlab-ci.d
# mdbook artifacts
/public

View file

@ -1,7 +1,6 @@
stages:
- ci
- artifacts
- deploy
variables:
# Makes some things print in color
@ -11,9 +10,6 @@ before_script:
# Enable nix-command and flakes
- if command -v nix > /dev/null; then echo "experimental-features = nix-command flakes" >> /etc/nix/nix.conf; fi
# Disable IFD, to ensure we are able to build without it
- if command -v nix > /dev/null; then echo "allow-import-from-derivation = false" >> /etc/nix/nix.conf; fi
# Add our own binary cache
- if command -v nix > /dev/null && [ -n "$ATTIC_ENDPOINT" ] && [ -n "$ATTIC_CACHE" ]; then echo "extra-substituters = $ATTIC_ENDPOINT/$ATTIC_CACHE" >> /etc/nix/nix.conf; fi
- if command -v nix > /dev/null && [ -n "$ATTIC_PUBLIC_KEY" ]; then echo "extra-trusted-public-keys = $ATTIC_PUBLIC_KEY" >> /etc/nix/nix.conf; fi
@ -52,14 +48,3 @@ artifacts:
image: nixos/nix:2.18.2
script:
- ./bin/nix-build-and-cache packages
pages:
stage: deploy
image: nixos/nix:2.18.2
script:
- direnv exec . mdbook build
artifacts:
paths:
- public
only:
- main

View file

@ -1 +0,0 @@
.gitignore

View file

@ -1 +0,0 @@
.gitignore

69
Cargo.lock generated
View file

@ -1045,15 +1045,6 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-auth"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "643c9bbf6a4ea8a656d6b4cd53d34f79e3f841ad5203c1a55fb7d761923bc255"
dependencies = [
"memchr",
]
[[package]]
name = "http-body"
version = "0.4.6"
@ -1090,9 +1081,9 @@ dependencies = [
[[package]]
name = "httparse"
version = "1.9.4"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fcc0b4a115bf80b728eb8ea024ad5bd707b615bfed49e0665b6e0f86fd082d9"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
@ -2120,7 +2111,7 @@ dependencies = [
[[package]]
name = "ruma"
version = "0.10.1"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"assign",
"js_int",
@ -2141,7 +2132,7 @@ dependencies = [
[[package]]
name = "ruma-appservice-api"
version = "0.10.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"js_int",
"ruma-common",
@ -2153,7 +2144,7 @@ dependencies = [
[[package]]
name = "ruma-client-api"
version = "0.18.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"as_variant",
"assign",
@ -2176,7 +2167,7 @@ dependencies = [
[[package]]
name = "ruma-common"
version = "0.13.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"as_variant",
"base64 0.22.1",
@ -2206,7 +2197,7 @@ dependencies = [
[[package]]
name = "ruma-events"
version = "0.28.1"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"as_variant",
"indexmap 2.2.6",
@ -2222,22 +2213,15 @@ dependencies = [
"thiserror",
"tracing",
"url",
"web-time",
"wildmatch",
]
[[package]]
name = "ruma-federation-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"bytes",
"http 1.1.0",
"httparse",
"js_int",
"memchr",
"mime",
"rand",
"ruma-common",
"ruma-events",
"serde",
@ -2247,7 +2231,7 @@ dependencies = [
[[package]]
name = "ruma-identifiers-validation"
version = "0.9.5"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"js_int",
"thiserror",
@ -2256,7 +2240,7 @@ dependencies = [
[[package]]
name = "ruma-identity-service-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"js_int",
"ruma-common",
@ -2266,7 +2250,7 @@ dependencies = [
[[package]]
name = "ruma-macros"
version = "0.13.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"once_cell",
"proc-macro-crate",
@ -2281,7 +2265,7 @@ dependencies = [
[[package]]
name = "ruma-push-gateway-api"
version = "0.9.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"js_int",
"ruma-common",
@ -2293,20 +2277,18 @@ dependencies = [
[[package]]
name = "ruma-server-util"
version = "0.3.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"headers",
"http 1.1.0",
"http-auth",
"ruma-common",
"thiserror",
"tracing",
"yap",
]
[[package]]
name = "ruma-signatures"
version = "0.15.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"base64 0.22.1",
"ed25519-dalek",
@ -2322,7 +2304,7 @@ dependencies = [
[[package]]
name = "ruma-state-res"
version = "0.11.0"
source = "git+https://github.com/ruma/ruma?branch=main#14d7415f0d80aadf425c2384c0f348d1c03527c8"
source = "git+https://github.com/ruma/ruma?branch=main#ba9a492fdee6ad89b179e2b3ab689c3114107012"
dependencies = [
"itertools",
"js_int",
@ -3232,16 +3214,6 @@ dependencies = [
"web-time",
]
[[package]]
name = "tracing-serde"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6b213177105856957181934e4920de57730fc69bf42c37ee5bb664d406d9e1"
dependencies = [
"serde",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.18"
@ -3252,15 +3224,12 @@ dependencies = [
"nu-ansi-term",
"once_cell",
"regex",
"serde",
"serde_json",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
"tracing-serde",
]
[[package]]
@ -3738,6 +3707,12 @@ version = "2.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213b7324336b53d2414b2db8537e56544d981803139155afa84f76eeebb7a546"
[[package]]
name = "yap"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe269e7b803a5e8e20cbd97860e136529cd83bf2c9c6d37b142467e7e1f051f"
[[package]]
name = "zerocopy"
version = "0.7.34"

View file

@ -5,6 +5,7 @@ explicit_outlives_requirements = "warn"
macro_use_extern_crate = "warn"
missing_abi = "warn"
noop_method_call = "warn"
pointer_structural_match = "warn"
single_use_lifetimes = "warn"
unreachable_pub = "warn"
unsafe_op_in_unsafe_fn = "warn"
@ -16,7 +17,7 @@ unused_qualifications = "warn"
[workspace.lints.clippy]
# Groups. Keep alphabetically sorted
pedantic = { level = "warn", priority = -1 }
pedantic = "warn"
# Lints. Keep alphabetically sorted
as_conversions = "warn"
@ -79,7 +80,7 @@ version = "0.1.0"
edition = "2021"
# See also `rust-toolchain.toml`
rust-version = "1.81.0"
rust-version = "1.78.0"
[lints]
workspace = true
@ -138,7 +139,7 @@ tower-http = { version = "0.5.2", features = ["add-extension", "cors", "sensitiv
tracing = { version = "0.1.40", features = [] }
tracing-flame = "0.2.0"
tracing-opentelemetry = "0.24.0"
tracing-subscriber = { version = "0.3.18", features = ["env-filter", "json"] }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
trust-dns-resolver = "0.23.2"
xdg = "2.5.2"

View file

@ -1,9 +0,0 @@
# Grapevine
A Matrix homeserver.
## Read the book
[Click here to read the latest version.][0]
[0]: https://matrix.pages.gitlab.computer.surgery/grapevine-fork/

View file

@ -1,12 +0,0 @@
[book]
title = "Grapevine"
language = "en"
multilingual = false
src = "book"
[build]
build-dir = "public"
[output.html]
git-repository-icon = "fa-git-square"
git-repository-url = "https://gitlab.computer.surgery/matrix/grapevine-fork"

View file

@ -1,6 +0,0 @@
# Summary
* [Introduction](./introduction.md)
* [Code of conduct](./code-of-conduct.md)
* [Contributing](./contributing.md)
* [Changelog](./changelog.md)

View file

@ -1,224 +0,0 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog][keep-a-changelog], and this project
adheres to [Semantic Versioning][semver].
[keep-a-changelog]: https://keepachangelog.com/en/1.0.0/
[semver]: https://semver.org/spec/v2.0.0.html
<!--
Changelog sections must appear in the following order if they appear for a
particular version so that attention can be drawn to the important parts:
1. Security
2. Removed
3. Deprecated
4. Changed
5. Fixed
6. Added
Entries within each section should be sorted by merge order. If multiple changes
result in a single entry, choose the merge order of the first or last change.
-->
## Unreleased
<!-- TODO: Change "will be" to "is" on release -->
This will be the first release of Grapevine since it was forked from Conduit
0.7.0.
### Security
1. Prevent XSS via user-uploaded media.
([!8](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/8))
2. Switch from incorrect, hand-rolled `X-Matrix` `Authorization` parser to the
much better implementation provided by Ruma.
([!31](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/31))
* This is not practically exploitable to our knowledge, but this change does
reduce risk.
3. Switch to a more trustworthy password hashing library.
([!29](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/29))
* This is not practically exploitable to our knowledge, but this change does
reduce risk.
4. Don't return redacted events from the search endpoint.
([!41 (f74043d)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/41/diffs?commit_id=f74043df9aa59b406b5086c2e9fa2791a31aa41b),
[!41 (83cdc9c)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/41/diffs?commit_id=83cdc9c708cd7b50fe1ab40ea6a68dcf252c190b))
5. Prevent impersonation in EDUs.
([!41 (da99b07)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/41/diffs?commit_id=da99b0706e683a2d347768efe5b50676abdf7b44))
* `m.signing_key_update` was not affected by this bug.
6. Verify PDUs and transactions against the temporally-correct signing keys.
([!41 (9087da9)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/41/diffs?commit_id=9087da91db8585f34d026a48ba8fdf64865ba14d))
7. Only allow the admin bot to change the room ID that the admin room alias
points to.
([!42](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/42))
### Removed
1. Remove update checker.
([17a0b34](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/17a0b3430934fbb8370066ee9dc3506102c5b3f6))
2. Remove optional automatic display name emoji for newly registered users.
([cddf699](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/cddf6991f280008b5af5acfab6a9719bb0cfb7f1))
3. Remove admin room welcome message on first startup.
([c9945f6](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/c9945f6bbac6e22af6cf955cfa99826d4b04fe8c))
4. Remove incomplete presence implementation.
([f27941d](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/f27941d5108acda250921c6a58499a46568fd030))
5. Remove Debian packaging.
([d41f0fb](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/d41f0fbf72dae6562358173f425d23bb0e174ca2))
6. Remove Docker packaging.
([!48](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/48))
7. **BREAKING:** Remove unstable room versions.
([!59](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/59))
### Changed
1. **BREAKING:** Rename `conduit_cache_capacity_modifier` configuration option
to `cache_capacity_modifier`.
([5619d7e](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/5619d7e3180661731800e253b558b88b407d2ae7))
* If you are explicitly setting this configuration option, make sure to
change its name before updating.
2. **BREAKING:** Rename Conduit to Grapevine.
([360e020](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/360e020b644bd012ed438708b661a25fbd124f68))
* The `CONDUIT_VERSION_EXTRA` build-time environment variable has been
renamed to `GRAPEVINE_VERSION_EXTRA`. This change only affects distribution
packagers or non-Nix users who are building from source. If you fall into
one of those categories *and* were explicitly setting this environment
variable, make sure to change its name before building Grapevine.
3. **BREAKING:** Change the default port from 8000 to 6167.
([f205280](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/f2052805201f0685d850592b1c96f4861c58fb22))
* If you relied on the default port being 8000, either update your other
configuration to use the new port, or explicitly configure Grapevine's port
to 8000.
4. Improve tracing spans and events.
([!11 (a275db3)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/11/diffs?commit_id=a275db3847b8d5aaa0c651a686c19cfbf9fdb8b5)
(merged as [5172f66](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/5172f66c1a90e0e97b67be2897ae59fbc00208a4)),
[!11 (a275db3)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/11/diffs?commit_id=a275db3847b8d5aaa0c651a686c19cfbf9fdb8b5)
(merged as [5172f66](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/5172f66c1a90e0e97b67be2897ae59fbc00208a4)),
[!11 (f556fce)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/11/diffs?commit_id=f556fce73eb7beec2ed7b1781df0acdf47920d9c)
(merged as [ac42e0b](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/ac42e0bfff6af8677636a3dc1a56701a3255071d)),
[!18](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/18),
[!26](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/26),
[!50](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/50),
[!52](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/52),
[!54](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/54),
[!56](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/56),
[!69](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/69))
5. Stop returning unnecessary member counts from `/_matrix/client/{r0,v3}/sync`.
([!12](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/12))
6. **BREAKING:** Allow federation by default.
([!24](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/24))
* If you relied on federation being disabled by default, make sure to
explicitly disable it before upgrading.
7. **BREAKING:** Remove the `[global]` section from the configuration file.
([!38](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/38))
* Details on how to migrate can be found in the merge request's description.
8. **BREAKING:** Allow specifying multiple transport listeners in the
configuration file.
([!39](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/39))
* Details on how to migrate can be found in the merge request's description.
9. Increase default log level so that span information is included.
([!50](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/50))
10. **BREAKING:** Reorganize config into sections.
([!49](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/49))
* Details on how to migrate can be found in the merge request's description.
11. Try to generate thumbnails for remote media ourselves if the federation
thumbnail request fails.
([!58](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/58))
### Fixed
1. Fix questionable numeric conversions.
([71c48f6](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/71c48f66c4922813c2dc30b7b875200e06ce4b75))
2. Stop sending no-longer-valid cached responses from the
`/_matrix/client/{r0,v3}/sync` endpoints.
([!7](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/7))
3. Stop returning extra E2EE device updates from `/_matrix/client/{r0,v3}/sync`
as that violates the specification.
([!12](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/12))
4. Make certain membership state transitions work correctly again.
([!16](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/16))
* For example, it was previously impossible to unban users from rooms.
5. Ensure that `tracing-flame` flushes all its data before the process exits.
([!20 (263edcc)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/20/diffs?commit_id=263edcc8a127ad2a541a3bb6ad35a8a459ea5616))
6. Reduce the likelihood of locking up the async runtime.
([!19](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/19))
7. Fix dynamically linked jemalloc builds.
([!23](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/23))
8. Fix search results not including subsequent pages in certain situations.
([!35 (0cdf032)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/35/diffs?commit_id=0cdf03288ab8fa363c313bd929c8b5183d14ab77))
9. Fix search results missing events in subsequent pages in certain situations.
([!35 (3551a6e)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/35/diffs?commit_id=3551a6ef7a29219b9b30f50a7e8c92b92debcdcf))
10. Only process admin commands if the admin bot is in the admin room.
([!43](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/43))
11. Fix bug where invalid account data from a client could prevent a user from
joining any upgraded rooms and brick rooms that affected users attempted to
upgrade.
([!53](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/53))
12. Fix bug where unexpected keys were deleted from `m.direct` account data
events when joining an upgraded room.
([!53](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/53))
13. Fixed appservice users not receiving federated invites if the local server
isn't already resident in the room
([!80](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/80))
14. Fix bug where, if a server has multiple public keys, only one would be fetched.
([!78](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/78))
15. Fix bug where expired keys may not be re-fetched in some scenarios.
([!78](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/78))
16. Fix bug where signing keys would not be fetched when joining a room if we
hadn't previously seen any signing keys from that server.
([!87](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/87))
### Added
1. Add various conveniences for users of the Nix package.
([51f9650](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/51f9650ca7bc9378690d331192c85fea3c151b58),
[bbb1a6f](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/bbb1a6fea45b16e8d4f94c1afbf7fa22c9281f37))
2. Add a NixOS module.
([33e7a46](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/33e7a46b5385ea9035c9d13c6775d63e5626a4c7))
3. Add a Conduit compat mode.
([a25f2ec](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/a25f2ec95045c5620c98eead88197a0bf13e6bb3))
* **BREAKING:** If you're migrating from Conduit, this option must be enabled
or else your homeserver will refuse to start.
4. Include `GRAPEVINE_VERSION_EXTRA` information in the
`/_matrix/federation/v1/version` endpoint.
([509b70b](https://gitlab.computer.surgery/matrix/grapevine-fork/-/commit/509b70bd827fec23b88e223b57e0df3b42cede34))
5. Allow multiple tracing subscribers to be active at once.
([!20 (7a154f74)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/20/diffs?commit_id=7a154f74166c1309ca5752149e02bbe44cd91431))
6. Allow configuring the filter for `tracing-flame`.
([!20 (507de06)](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/20/diffs?commit_id=507de063f53f52e0cf8e2c1a67215a5ad87bb35a))
7. Collect HTTP response time metrics via OpenTelemetry and optionally expose
them as Prometheus metrics. This functionality is disabled by default.
([!22](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/22))
8. Collect metrics for lookup results (e.g. cache hits/misses).
([!15](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/15),
[!36](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/36))
9. Add configuration options for controlling the log format and colors.
([!46](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/46))
10. Recognize the `!admin` prefix to invoke admin commands.
([!45](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/45))
11. Add the `set-tracing-filter` admin command to change log/metrics/flame
filters dynamically at runtime.
([!49](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/49))
12. Add more configuration options.
([!49](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/49))
* `observability.traces.filter`: The `tracing` filter to use for
OpenTelemetry traces.
* `observability.traces.endpoint`: Where OpenTelemetry should send traces.
* `observability.flame.filter`: The `tracing` filter for `tracing-flame`.
* `observability.flame.filename`: Where `tracing-flame` will write its
output.
* `observability.logs.timestamp`: Whether timestamps should be included in
the logs.
13. Support building nix packages without IFD
([!73](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/73))
14. Report local users getting banned in the server logs and admin room.
([!65](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/65),
[!84](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/84))
15. Added support for Authenticated Media ([MSC3916](https://github.com/matrix-org/matrix-spec-proposals/pull/3916)).
([!58](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/58))
16. Added support for configuring and serving `/.well-known/matrix/...` data.
([!90](https://gitlab.computer.surgery/matrix/grapevine-fork/-/merge_requests/90))

View file

@ -1,12 +0,0 @@
# Code of conduct
We follow the [Rust Code of Conduct][rust-coc] with some extra points:
* In the absence of evidence to suggest otherwise, assume good faith when
engaging with others
* Moderation actions may be taken for behavior observed outside of
project-specific spaces
* We have limited patience, so violations may skip the warning and directly
result in a ban
[rust-coc]: https://www.rust-lang.org/policies/code-of-conduct

View file

@ -1,47 +0,0 @@
# Contributing
## On Matrix
Currently, the Matrix room at [#grapevine:computer.surgery][room] serves
multiple purposes:
* General discussion about the project, such as answering questions about it
* Reporting issues with Grapevine, if getting GitLab access is too much trouble
for you
* Providing support to users running Grapevine
* Discussing the development of Grapevine
If you'd like to engage in or observe any of those things, please join!
[room]: https://matrix.to/#/#grapevine:computer.surgery
## On GitLab
Instructions for getting GitLab access can be found on the [sign-in][sign-in]
page.
GitLab access is primarily useful if you'd like to open issues, engage in
discussions on issues or merge requests, or submit your own merge requests.
Note that if the sign-up process is too much trouble and you'd just
like to report an issue, feel free to report it in the Matrix room at
[#grapevine:computer.surgery][room]; someone with GitLab access can open an
issue on your behalf.
[sign-in]: https://gitlab.computer.surgery/users/sign_in
## Information about a vulnerability
If you find a security vulnerability in Grapevine, please privately report it to
the Grapevine maintainers in one of the following ways:
* Open a GitLab issue that's marked as confidential
* Create a private, invite-only, E2EE Matrix room and invite the following
users:
* `@benjamin:computer.surgery`
* `@charles:computer.surgery`
* `@xiretza:xiretza.xyz`
If the maintainers determine that the vulnerability is shared with Conduit or
other forks, we'll work with their teams to ensure that all affected projects
can release a fix at the same time.

View file

@ -1,80 +0,0 @@
# Introduction
Grapevine is a [Matrix][matrix] homeserver that was originally forked from
[Conduit 0.7.0][conduit].
[matrix]: https://matrix.org/
[conduit]: https://gitlab.com/famedly/conduit/-/tree/v0.7.0?ref_type=tags
## Goals
Our goal is to provide a robust and reliable Matrix homeserver implementation.
In order to accomplish this, we aim to do the following:
* Optimize for maintainability
* Implement automated testing to ensure correctness
* Improve instrumentation to provide real-world data to aid decision-making
## Non-goals
We also have some things we specifically want to avoid as we feel they inhibit
our ability to accomplish our goals:
* macOS or Windows support
* These operating systems are very uncommon in the hobbyist server space, and
we feel our effort is better spent elsewhere.
* Docker support
* Docker tends to generate a high volume of support requests that are solely
due to Docker itself or how users are using Docker. In attempt to mitigate
this, we will not provide first-party Docker images. Instead, we'd recommend
avoiding Docker and either using our pre-built statically-linked binaries
or building from source. However, if your deployment mechanism *requires*
Docker, it should be straightforward to build your own Docker image.
* Configuration via environment variables
* Environment variables restrict the options for structuring configuration and
support for them would increase the maintenance burden. If your deployment
mechanism requires this, consider using an external tool like
[`envsubst`][envsubst].
* Configuration compatibility with Conduit
* To provide a secure and ergonomic configuration experience, breaking changes
are required. However, we do intend to provide a migration tool to ease
migration; this feature is tracked [here][migration-tool].
* Perfect database compatibility with Conduit
* The current database compatibility status can be tracked [here][db-compat].
In the long run, it's inevitable that changes will be made to Conduit that
we won't want to pull in, or that we need to make changes that Conduit won't
want to pull in.
[envsubst]: https://github.com/a8m/envsubst
[migration-tool]: https://gitlab.computer.surgery/matrix/grapevine-fork/-/issues/38
[db-compat]: https://gitlab.computer.surgery/matrix/grapevine-fork/-/issues/17
## Project management
The project's current maintainers[^1] are:
| Matrix username | GitLab username |
|-|-|
| `@charles:computer.surgery` | `charles` |
| `@benjamin:computer.surgery` | `benjamin` |
| `@xiretza:xiretza.xyz` | `Lambda` |
We would like to expand this list in the future as social trust is built and
technical competence is demonstrated by other contributors.
We require at least 1 approving code review from a maintainer[^2] before changes
can be merged. This number may increase in the future as the list of maintainers
grows.
## Expectations management
This project is run and maintained entirely by volunteers who are doing their
best. Additionally, due to our goals, the development of new features may be
slower than alternatives. We find this to be an acceptable tradeoff considering
the importance of the reliability of a project like this.
---
[^1]: A "maintainer" is someone who has the ability to close issues opened by
someone else and merge changes.
[^2]: A maintainer approving their own change doesn't count.

View file

@ -30,26 +30,6 @@ name = "cargo-clippy"
group = "versions"
script = "cargo clippy -- --version"
[[task]]
name = "lychee"
group = "versions"
script = "lychee --version"
[[task]]
name = "markdownlint"
group = "versions"
script = "markdownlint --version"
[[task]]
name = "lychee"
group = "lints"
script = "lychee --offline ."
[[task]]
name = "markdownlint"
group = "lints"
script = "markdownlint ."
[[task]]
name = "cargo-fmt"
group = "lints"

View file

@ -226,8 +226,7 @@
"flake-compat": "flake-compat_2",
"flake-utils": "flake-utils_2",
"nix-filter": "nix-filter",
"nixpkgs": "nixpkgs_2",
"rust-manifest": "rust-manifest"
"nixpkgs": "nixpkgs_2"
}
},
"rust-analyzer-src": {
@ -247,18 +246,6 @@
"type": "github"
}
},
"rust-manifest": {
"flake": false,
"locked": {
"narHash": "sha256-tB9BZB6nRHDk5ELIVlGYlIjViLKBjQl52nC1avhcCwA=",
"type": "file",
"url": "https://static.rust-lang.org/dist/channel-rust-1.81.0.toml"
},
"original": {
"type": "file",
"url": "https://static.rust-lang.org/dist/channel-rust-1.81.0.toml"
}
},
"systems": {
"locked": {
"lastModified": 1681028828,

View file

@ -8,12 +8,6 @@
flake-utils.url = "github:numtide/flake-utils?ref=main";
nix-filter.url = "github:numtide/nix-filter?ref=main";
nixpkgs.url = "github:NixOS/nixpkgs?ref=nixos-unstable";
rust-manifest = {
# Keep version in sync with rust-toolchain.toml
url = "https://static.rust-lang.org/dist/channel-rust-1.81.0.toml";
flake = false;
};
};
outputs = inputs:
@ -27,6 +21,8 @@
inherit inputs;
oci-image = self.callPackage ./nix/pkgs/oci-image {};
# Return a new scope with overrides applied to the 'default' package
overrideDefaultPackage = args: self.overrideScope (final: prev: {
default = prev.default.override args;
@ -35,59 +31,26 @@
shell = self.callPackage ./nix/shell.nix {};
# The Rust toolchain to use
# Using fromManifestFile and parsing the toolchain file with importTOML
# instead of fromToolchainFile to avoid IFD
toolchain = let
toolchainFile = pkgs.lib.importTOML ./rust-toolchain.toml;
defaultProfileComponents = [
"rustc"
"cargo"
"rust-docs"
"rustfmt"
"clippy"
];
components = defaultProfileComponents ++
toolchainFile.toolchain.components;
targets = toolchainFile.toolchain.targets;
fenix = inputs.fenix.packages.${pkgs.stdenv.buildPlatform.system};
in
fenix.combine (builtins.map
(target:
(fenix.targets.${target}.fromManifestFile inputs.rust-manifest)
.withComponents components)
targets);
toolchain = inputs
.fenix
.packages
.${pkgs.pkgsBuildHost.system}
.fromToolchainFile {
file = ./rust-toolchain.toml;
# See also `rust-toolchain.toml`
sha256 = "sha256-opUgs6ckUQCyDxcB9Wy51pqhd0MPGHUVbwRKKPGiwZU=";
};
});
in
inputs.flake-utils.lib.eachDefaultSystem (system:
let
pkgs = import inputs.nixpkgs {
inherit system;
# Some users find it useful to set this on their Nixpkgs instance and
# we want to support that use case, so we set it here too to help us
# test/ensure that this works.
config.allowAliases = false;
};
pkgs = inputs.nixpkgs.legacyPackages.${system};
in
{
packages = rec {
packages = {
default = (mkScope pkgs).default;
oci-image = pkgs.dockerTools.buildImage {
name = default.pname;
tag = "next";
copyToRoot = [
pkgs.dockerTools.caCertificates
];
config = {
# Use the `tini` init system so that signals (e.g. ctrl+c/SIGINT)
# are handled as expected
Entrypoint = [
"${pkgs.lib.getExe' pkgs.tini "tini"}"
"--"
"${pkgs.lib.getExe default}"
];
};
};
oci-image = (mkScope pkgs).oci-image;
}
//
builtins.listToAttrs
@ -102,11 +65,6 @@
crossSystem = {
config = crossSystem;
};
# Some users find it useful to set this on their Nixpkgs
# instance and we want to support that use case, so we set
# it here too to help us test/ensure that this works.
config.allowAliases = false;
}).pkgsStatic;
in
[
@ -115,6 +73,12 @@
name = binaryName;
value = (mkScope pkgsCrossStatic).default;
}
# An output for an OCI image based on that binary
{
name = "oci-image-${crossSystem}";
value = (mkScope pkgsCrossStatic).oci-image;
}
]
)
[

View file

@ -18,17 +18,24 @@ in
options.services.grapevine = {
enable = lib.mkEnableOption "grapevine";
package = lib.mkPackageOption
inputs.self.packages.${pkgs.hostPlatform.system}
inputs.self.packages.${pkgs.system}
"grapevine"
{
default = "default";
pkgsText = "inputs.grapevine.packages.\${pkgs.hostPlatform.system}";
pkgsText = "inputs.grapevine.packages.\${pkgs.system}";
};
settings = lib.mkOption {
type = types.submodule {
freeformType = format.type;
options = {
address = lib.mkOption {
type = types.nonEmptyStr;
description = ''
The local IP address to bind to.
'';
default = "::1";
};
conduit_compat = lib.mkOption {
type = types.bool;
description = ''
@ -36,7 +43,7 @@ in
'';
default = false;
};
database.path = lib.mkOption {
database_path = lib.mkOption {
type = types.nonEmptyStr;
readOnly = true;
description = ''
@ -49,18 +56,12 @@ in
then "/var/lib/matrix-conduit"
else "/var/lib/grapevine";
};
listen = lib.mkOption {
type = types.listOf format.type;
port = lib.mkOption {
type = types.port;
description = ''
List of places to listen for incoming connections.
The local port to bind to.
'';
default = [
{
type = "tcp";
address = "::1";
port = 6167;
}
];
default = 6167;
};
};
};

View file

@ -71,10 +71,12 @@ let
} // buildDepsOnlyEnv;
commonAttrs = {
# Reading from cargoManifest directly instead of using
# createNameFromCargoToml to avoid IFD
pname = cargoManifest.package.name;
version = cargoManifest.package.version;
inherit
(craneLib.crateNameFromCargoToml {
cargoToml = "${inputs.self}/Cargo.toml";
})
pname
version;
src = let filter = inputs.nix-filter.lib; in filter {
root = inputs.self;

View file

@ -0,0 +1,25 @@
# Keep sorted
{ default
, dockerTools
, lib
, tini
}:
dockerTools.buildImage {
name = default.pname;
tag = "next";
copyToRoot = [
dockerTools.caCertificates
];
config = {
# Use the `tini` init system so that signals (e.g. ctrl+c/SIGINT)
# are handled as expected
Entrypoint = [
"${lib.getExe' tini "tini"}"
"--"
];
Cmd = [
"${lib.getExe default}"
];
};
}

View file

@ -1,13 +1,10 @@
# Keep sorted
{ buildPlatform
, default
{ default
, engage
, inputs
, jq
, lychee
, markdownlint-cli
, mdbook
, mkShell
, system
, toolchain
}:
@ -25,14 +22,11 @@ mkShell {
#
# This needs to come before `toolchain` in this list, otherwise
# `$PATH` will have stable rustfmt instead.
inputs.fenix.packages.${buildPlatform.system}.latest.rustfmt
inputs.fenix.packages.${system}.latest.rustfmt
# Keep sorted
engage
jq
lychee
markdownlint-cli
mdbook
toolchain
]
++

View file

@ -9,7 +9,7 @@
# If you're having trouble making the relevant changes, bug a maintainer.
[toolchain]
channel = "1.81.0"
channel = "1.78.0"
components = [
# For rust-analyzer
"rust-src",

View file

@ -2,4 +2,3 @@ pub(crate) mod appservice_server;
pub(crate) mod client_server;
pub(crate) mod ruma_wrapper;
pub(crate) mod server_server;
pub(crate) mod well_known;

View file

@ -57,19 +57,21 @@ where
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
let url = reqwest_request.url().clone();
let mut response = services()
let mut response = match services()
.globals
.default_client()
.execute(reqwest_request)
.await
.inspect_err(|error| {
{
Ok(r) => r,
Err(e) => {
warn!(
%error,
appservice = registration.id,
%destination,
"Could not send request to appservice",
"Could not send request to appservice {:?} at {}: {}",
registration.id, destination, e
);
})?;
return Err(e.into());
}
};
// reqwest::Response -> http::Response conversion
let status = response.status();
@ -83,21 +85,18 @@ where
);
// TODO: handle timeout
let body = response.bytes().await.unwrap_or_else(|error| {
warn!(%error, "Server error");
let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error: {}", e);
Vec::new().into()
});
if status != 200 {
warn!(
appservice = %destination,
%status,
%url,
body = %utils::dbg_truncate_str(
String::from_utf8_lossy(&body).as_ref(),
100,
),
"Appservice returned bad response",
"Appservice returned bad response {} {}\n{}\n{:?}",
destination,
status,
url,
utils::string_from_bytes(&body)
);
}
@ -107,12 +106,10 @@ where
.expect("reqwest body is valid http body"),
);
response.map(Some).map_err(|error| {
response.map(Some).map_err(|_| {
warn!(
%error,
appservice = %destination,
%url,
"Appservice returned invalid response bytes",
"Appservice returned invalid response bytes {}\n{}",
destination, url
);
Error::BadServerResponse("Server returned bad response.")
})

View file

@ -276,7 +276,7 @@ pub(crate) async fn register_route(
body.initial_device_display_name.clone(),
)?;
info!(%user_id, "New user registered on this server");
info!("New user {} registered on this server.", user_id);
if body.appservice_info.is_none() && !is_guest {
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("New user {user_id} registered on this server."),
@ -293,8 +293,8 @@ pub(crate) async fn register_route(
services().admin.make_user_admin(&user_id, displayname).await?;
warn!(
%user_id,
"Granting admin privileges to the first user",
"Granting {} admin privileges as the first user",
user_id
);
}
}
@ -316,7 +316,8 @@ pub(crate) async fn register_route(
/// - Requires UIAA to verify user password
/// - Changes the password of the sender user
/// - The password hash is calculated using argon2 with 32 character salt, the
/// plain password is not saved
/// plain password is
/// not saved
///
/// If `logout_devices` is true it does the following for each device except the
/// sender device:
@ -375,7 +376,7 @@ pub(crate) async fn change_password_route(
}
}
info!(user_id = %sender_user, "User changed their password");
info!("User {} changed their password.", sender_user);
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("User {sender_user} changed their password."),
));
@ -455,7 +456,7 @@ pub(crate) async fn deactivate_route(
// Remove devices and mark account as deactivated
services().users.deactivate_account(sender_user)?;
info!(user_id = %sender_user, "User deactivated their account");
info!("User {} deactivated their account.", sender_user);
services().admin.send_message(RoomMessageEventContent::notice_plain(
format!("User {sender_user} deactivated their account."),
));

View file

@ -14,6 +14,9 @@ pub(crate) async fn get_capabilities_route(
_body: Ar<get_capabilities::v3::Request>,
) -> Result<Ra<get_capabilities::v3::Response>> {
let mut available = BTreeMap::new();
for room_version in &services().globals.unstable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Unstable);
}
for room_version in &services().globals.stable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Stable);
}

View file

@ -16,7 +16,8 @@ use crate::{services, Ar, Error, Ra, Result};
/// Allows loading room history around an event.
///
/// - Only works if the user is joined (TODO: always allow, but only show events
/// if the user was joined, depending on `history_visibility`)
/// if the user was
/// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_context_route(
body: Ar<get_context::v3::Request>,
@ -158,21 +159,19 @@ pub(crate) async fn get_context_route(
let mut state = Vec::new();
for (shortstatekey, event_id) in state_ids {
for (shortstatekey, id) in state_ids {
let (event_type, state_key) =
services().rooms.short.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&event_id)?
else {
error!(%event_id, "Event in state not found");
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id);
continue;
};
state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let Some(pdu) = services().rooms.timeline.get_pdu(&event_id)?
else {
error!(%event_id, "Event in state not found");
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id);
continue;
};
state.push(pdu.to_state_event());

View file

@ -93,11 +93,7 @@ pub(crate) async fn set_room_visibility_route(
match &body.visibility {
room::Visibility::Public => {
services().rooms.directory.set_public(&body.room_id)?;
info!(
user_id = %sender_user,
room_id = %body.room_id,
"User made room public",
);
info!("{} made {} public", sender_user, body.room_id);
}
room::Visibility::Private => {
services().rooms.directory.set_not_public(&body.room_id)?;
@ -201,9 +197,177 @@ pub(crate) async fn get_public_rooms_filtered_helper(
.rooms
.directory
.public_rooms()
.filter_map(Result::ok)
.map(room_id_to_chunk)
.filter_map(Result::ok)
.map(|room_id| {
let room_id = room_id?;
let chunk = PublicRoomsChunk {
canonical_alias: services()
.rooms
.state_accessor
.room_state_get(
&room_id,
&StateEventType::RoomCanonicalAlias,
"",
)?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| {
Error::bad_database(
"Invalid canonical alias event in \
database.",
)
})
})?,
name: services().rooms.state_accessor.get_name(&room_id)?,
num_joined_members: services()
.rooms
.state_cache
.room_joined_count(&room_id)?
.unwrap_or_else(|| {
warn!("Room {} has no member count", room_id);
0
})
.try_into()
.expect("user count should not be that big"),
topic: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|_| {
error!(
"Invalid room topic event in database for \
room {}",
room_id
);
Error::bad_database(
"Invalid room topic event in database.",
)
})
})?,
world_readable: services()
.rooms
.state_accessor
.room_state_get(
&room_id,
&StateEventType::RoomHistoryVisibility,
"",
)?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility
== HistoryVisibility::WorldReadable
})
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in \
database.",
)
})
})?,
guest_can_join: services()
.rooms
.state_accessor
.room_state_get(
&room_id,
&StateEventType::RoomGuestAccess,
"",
)?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| {
c.guest_access == GuestAccess::CanJoin
})
.map_err(|_| {
Error::bad_database(
"Invalid room guest access event in \
database.",
)
})
})?,
avatar_url: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomAvatarEventContent| c.url)
.map_err(|_| {
Error::bad_database(
"Invalid room avatar event in database.",
)
})
})
.transpose()?
.flatten(),
join_rule: services()
.rooms
.state_accessor
.room_state_get(
&room_id,
&StateEventType::RoomJoinRules,
"",
)?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| {
match c.join_rule {
JoinRule::Public => {
Some(PublicRoomJoinRule::Public)
}
JoinRule::Knock => {
Some(PublicRoomJoinRule::Knock)
}
_ => None,
}
})
.map_err(|e| {
error!(
"Invalid room join rule event in \
database: {}",
e
);
Error::BadDatabase(
"Invalid room join rule event in database.",
)
})
})
.transpose()?
.flatten()
.ok_or_else(|| {
Error::bad_database(
"Missing room join rule event for room.",
)
})?,
room_type: services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.map(|s| {
serde_json::from_str::<RoomCreateEventContent>(
s.content.get(),
)
.map_err(|e| {
error!(
"Invalid room create event in database: {}",
e
);
Error::BadDatabase(
"Invalid room create event in database.",
)
})
})
.transpose()?
.and_then(|e| e.room_type),
room_id,
};
Ok(chunk)
})
.filter_map(Result::<_>::ok)
.filter(|chunk| {
if let Some(query) =
filter.generic_search_term.as_ref().map(|q| q.to_lowercase())
@ -266,146 +430,3 @@ pub(crate) async fn get_public_rooms_filtered_helper(
total_room_count_estimate: Some(total_room_count_estimate),
})
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument]
fn room_id_to_chunk(room_id: ruma::OwnedRoomId) -> Result<PublicRoomsChunk> {
let canonical_alias = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| {
Error::bad_database(
"Invalid canonical alias event in database.",
)
})
})?;
let name = services().rooms.state_accessor.get_name(&room_id)?;
let num_joined_members = services()
.rooms
.state_cache
.room_joined_count(&room_id)?
.unwrap_or_else(|| {
warn!("Room has no member count");
0
})
.try_into()
.expect("user count should not be that big");
let topic = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|_| {
error!("Invalid room topic event in database for room",);
Error::bad_database("Invalid room topic event in database.")
})
})?;
let world_readable = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
})
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in database.",
)
})
})?;
let guest_can_join = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| {
c.guest_access == GuestAccess::CanJoin
})
.map_err(|_| {
Error::bad_database(
"Invalid room guest access event in database.",
)
})
})?;
let avatar_url = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomAvatarEventContent| c.url)
.map_err(|_| {
Error::bad_database(
"Invalid room avatar event in database.",
)
})
})
.transpose()?
.flatten();
let join_rule = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| match c.join_rule {
JoinRule::Public => Some(PublicRoomJoinRule::Public),
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
_ => None,
})
.map_err(|error| {
error!(%error, "Invalid room join rule event in database");
Error::BadDatabase(
"Invalid room join rule event in database.",
)
})
})
.transpose()?
.flatten()
.ok_or_else(|| {
Error::bad_database("Missing room join rule event for room.")
})?;
let room_type = services()
.rooms
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
.map(|s| {
serde_json::from_str::<RoomCreateEventContent>(s.content.get())
.map_err(|error| {
error!(%error, "Invalid room create event in database");
Error::BadDatabase("Invalid room create event in database.")
})
})
.transpose()?
.and_then(|e| e.room_type);
Ok(PublicRoomsChunk {
canonical_alias,
name,
num_joined_members,
room_id,
topic,
world_readable,
guest_can_join,
avatar_url,
join_rule,
room_type,
})
}

View file

@ -413,10 +413,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if let Some(remaining) =
min_elapsed_duration.checked_sub(time.elapsed())
{
debug!(%server, %tries, ?remaining, "Backing off from server");
if time.elapsed() < min_elapsed_duration {
debug!("Backing off query from {:?}", server);
return (
server,
Err(Error::BadServerResponse(
@ -430,8 +428,6 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
for (user_id, keys) in vec {
device_keys_input_fed.insert(user_id.to_owned(), keys.clone());
}
// TODO: switch .and_then(|result| result) to .flatten() when stable
// <https://github.com/rust-lang/rust/issues/70142>
(
server,
tokio::time::timeout(
@ -444,53 +440,48 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
),
)
.await
.map_err(|_e| Error::BadServerResponse("Query took too long"))
.and_then(|result| result),
.map_err(|_e| Error::BadServerResponse("Query took too long")),
)
})
.collect();
while let Some((server, response)) = futures.next().await {
let response = match response {
Ok(response) => response,
Err(error) => {
back_off(server.to_owned()).await;
debug!(%server, %error, "remote device key query failed");
failures.insert(server.to_string(), json!({}));
continue;
}
};
if let Ok(Ok(response)) = response {
for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) =
services().users.parse_master_key(&user, &masterkey)?;
for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) =
services().users.parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) = services().users.get_key(
&master_key_id,
sender_user,
&user,
&allowed_signatures,
)? {
let (_, our_master_key) = services()
.users
.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures);
if let Some(our_master_key) = services().users.get_key(
&master_key_id,
sender_user,
&user,
&allowed_signatures,
)? {
let (_, our_master_key) = services()
.users
.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures);
}
let json = serde_json::to_value(master_key)
.expect("to_value always works");
let raw = serde_json::from_value(json)
.expect("Raw::from_value always works");
services().users.add_cross_signing_keys(
&user, &raw, &None, &None,
// Dont notify. A notification would trigger another key
// request resulting in an endless loop
false,
)?;
master_keys.insert(user, raw);
}
let json = serde_json::to_value(master_key)
.expect("to_value always works");
let raw = serde_json::from_value(json)
.expect("Raw::from_value always works");
services().users.add_cross_signing_keys(
&user, &raw, &None, &None,
// Dont notify. A notification would trigger another key
// request resulting in an endless loop
false,
)?;
master_keys.insert(user, raw);
self_signing_keys.extend(response.self_signing_keys);
device_keys.extend(response.device_keys);
} else {
back_off(server.to_owned()).await;
failures.insert(server.to_string(), json!({}));
}
self_signing_keys.extend(response.self_signing_keys);
device_keys.extend(response.device_keys);
}
Ok(get_keys::v3::Response {

View file

@ -3,28 +3,19 @@ use std::time::Duration;
use axum::response::IntoResponse;
use http::{
header::{CONTENT_DISPOSITION, CONTENT_SECURITY_POLICY, CONTENT_TYPE},
HeaderName, HeaderValue, Method,
HeaderName, HeaderValue,
};
use phf::{phf_set, Set};
use ruma::{
api::{
client::{
authenticated_media as authenticated_media_client,
error::ErrorKind,
media::{self as legacy_media, create_content},
},
federation::authenticated_media as authenticated_media_fed,
use ruma::api::client::{
error::ErrorKind,
media::{
create_content, get_content, get_content_as_filename,
get_content_thumbnail, get_media_config,
},
http_headers::{ContentDisposition, ContentDispositionType},
};
use tracing::{debug, error, info, warn};
use tracing::error;
use crate::{
service::media::FileMeta,
services,
utils::{self, MxcData},
Ar, Error, Ra, Result,
};
use crate::{service::media::FileMeta, services, utils, Ar, Error, Ra, Result};
const MXC_LENGTH: usize = 32;
@ -86,17 +77,16 @@ fn content_security_policy() -> HeaderValue {
// Doing this correctly is tricky, so I'm skipping it for now.
fn content_disposition_for(
content_type: Option<&str>,
filename: Option<String>,
) -> ContentDisposition {
let disposition_type = match content_type {
Some(x) if INLINE_CONTENT_TYPES.contains(x) => {
ContentDispositionType::Inline
}
_ => ContentDispositionType::Attachment,
};
ContentDisposition {
disposition_type,
filename: Option<&str>,
) -> String {
match (
content_type.is_some_and(|x| INLINE_CONTENT_TYPES.contains(x)),
filename,
) {
(true, None) => "inline".to_owned(),
(true, Some(x)) => format!("inline; filename={x}"),
(false, None) => "attachment".to_owned(),
(false, Some(x)) => format!("attachment; filename={x}"),
}
}
@ -124,22 +114,10 @@ fn set_header_or_panic(
/// # `GET /_matrix/media/r0/config`
///
/// Returns max upload size.
#[allow(deprecated)] // unauthenticated media
pub(crate) async fn get_media_config_legacy_route(
_body: Ar<legacy_media::get_media_config::v3::Request>,
) -> Result<Ra<legacy_media::get_media_config::v3::Response>> {
Ok(Ra(legacy_media::get_media_config::v3::Response {
upload_size: services().globals.max_request_size().into(),
}))
}
/// # `GET /_matrix/client/v1/media/config`
///
/// Returns max upload size.
pub(crate) async fn get_media_config_route(
_body: Ar<authenticated_media_client::get_media_config::v1::Request>,
) -> Result<Ra<authenticated_media_client::get_media_config::v1::Response>> {
Ok(Ra(authenticated_media_client::get_media_config::v1::Response {
_body: Ar<get_media_config::v3::Request>,
) -> Result<Ra<get_media_config::v3::Response>> {
Ok(Ra(get_media_config::v3::Response {
upload_size: services().globals.max_request_size().into(),
}))
}
@ -153,21 +131,21 @@ pub(crate) async fn get_media_config_route(
pub(crate) async fn create_content_route(
body: Ar<create_content::v3::Request>,
) -> Result<Ra<create_content::v3::Response>> {
let media_id = utils::random_string(MXC_LENGTH);
let mxc = MxcData::new(services().globals.server_name(), &media_id)?;
let mxc = format!(
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(MXC_LENGTH)
);
services()
.media
.create(
mxc.to_string(),
mxc.clone(),
body.filename
.clone()
.map(|filename| ContentDisposition {
disposition_type: ContentDispositionType::Inline,
filename: Some(filename),
})
.as_ref(),
body.content_type.clone(),
.as_ref()
.map(|filename| format!("inline; filename={filename}"))
.as_deref(),
body.content_type.as_deref(),
&body.file,
)
.await?;
@ -178,186 +156,41 @@ pub(crate) async fn create_content_route(
}))
}
/// Whether or not to allow remote content to be loaded
#[derive(Clone, Copy, PartialEq, Eq)]
enum AllowRemote {
Yes,
No,
}
impl From<bool> for AllowRemote {
fn from(allow: bool) -> Self {
if allow {
Self::Yes
} else {
Self::No
}
}
}
struct RemoteResponse {
#[allow(unused)]
metadata: authenticated_media_fed::ContentMetadata,
content: authenticated_media_fed::Content,
}
/// Fetches remote media content from a URL specified in a
/// `/_matrix/federation/v1/media/*/{mediaId}` `Location` header
#[tracing::instrument]
async fn get_redirected_content(
location: String,
) -> Result<authenticated_media_fed::Content> {
let location = location.parse().map_err(|error| {
warn!(location, %error, "Invalid redirect location");
Error::BadServerResponse("Invalid redirect location")
})?;
let response = services()
.globals
.federation_client()
.execute(reqwest::Request::new(Method::GET, location))
.await?;
let content_type = response
.headers()
.get(CONTENT_TYPE)
.map(|value| {
value.to_str().map_err(|error| {
error!(
?value,
%error,
"Invalid Content-Type header"
);
Error::BadServerResponse("Invalid Content-Type header")
})
})
.transpose()?
.map(str::to_owned);
let content_disposition = response
.headers()
.get(CONTENT_DISPOSITION)
.map(|value| {
ContentDisposition::try_from(value.as_bytes()).map_err(|error| {
error!(
?value,
%error,
"Invalid Content-Disposition header"
);
Error::BadServerResponse("Invalid Content-Disposition header")
})
})
.transpose()?;
Ok(authenticated_media_fed::Content {
file: response.bytes().await?.to_vec(),
content_type,
content_disposition,
})
}
#[tracing::instrument(skip_all)]
async fn get_remote_content_via_federation_api(
mxc: &MxcData<'_>,
) -> Result<RemoteResponse, Error> {
let authenticated_media_fed::get_content::v1::Response {
metadata,
content,
} = services()
.sending
.send_federation_request(
mxc.server_name,
authenticated_media_fed::get_content::v1::Request {
media_id: mxc.media_id.to_owned(),
timeout_ms: Duration::from_secs(20),
},
)
.await?;
let content = match content {
authenticated_media_fed::FileOrLocation::File(content) => {
debug!("Got media from remote server");
content
}
authenticated_media_fed::FileOrLocation::Location(location) => {
debug!(location, "Following redirect");
get_redirected_content(location).await?
}
};
Ok(RemoteResponse {
metadata,
content,
})
}
#[allow(deprecated)] // unauthenticated media
#[tracing::instrument(skip_all)]
async fn get_remote_content_via_legacy_api(
mxc: &MxcData<'_>,
) -> Result<RemoteResponse, Error> {
pub(crate) async fn get_remote_content(
mxc: &str,
server_name: &ruma::ServerName,
media_id: String,
) -> Result<get_content::v3::Response, Error> {
let content_response = services()
.sending
.send_federation_request(
mxc.server_name,
legacy_media::get_content::v3::Request {
server_name,
get_content::v3::Request {
allow_remote: false,
server_name: mxc.server_name.to_owned(),
media_id: mxc.media_id.to_owned(),
server_name: server_name.to_owned(),
media_id,
timeout_ms: Duration::from_secs(20),
allow_redirect: false,
},
)
.await?;
Ok(RemoteResponse {
metadata: authenticated_media_fed::ContentMetadata {},
content: authenticated_media_fed::Content {
file: content_response.file,
content_disposition: content_response.content_disposition,
content_type: content_response.content_type,
},
})
}
#[tracing::instrument]
pub(crate) async fn get_remote_content(
mxc: &MxcData<'_>,
) -> Result<RemoteResponse, Error> {
let fed_result = get_remote_content_via_federation_api(mxc).await;
let response = match fed_result {
Ok(response) => {
debug!("Got remote content via authenticated media API");
response
}
Err(Error::Federation(_, error))
if error.error_kind() == Some(&ErrorKind::Unrecognized)
// https://github.com/t2bot/matrix-media-repo/issues/609
|| error.error_kind() == Some(&ErrorKind::Unauthorized) =>
{
info!(
"Remote server does not support authenticated media, falling \
back to deprecated API"
);
get_remote_content_via_legacy_api(mxc).await?
}
Err(e) => {
return Err(e);
}
};
services()
.media
.create(
mxc.to_string(),
response.content.content_disposition.as_ref(),
response.content.content_type.clone(),
&response.content.file,
mxc.to_owned(),
content_response.content_disposition.as_deref(),
content_response.content_type.as_deref(),
&content_response.file,
)
.await?;
Ok(response)
Ok(get_content::v3::Response {
file: content_response.file,
content_disposition: content_response.content_disposition,
content_type: content_response.content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
}
/// # `GET /_matrix/media/r0/download/{serverName}/{mediaId}`
@ -365,72 +198,10 @@ pub(crate) async fn get_remote_content(
/// Load media from our server or over federation.
///
/// - Only allows federation if `allow_remote` is true
#[allow(deprecated)] // unauthenticated media
pub(crate) async fn get_content_legacy_route(
body: Ar<legacy_media::get_content::v3::Request>,
) -> Result<axum::response::Response> {
use authenticated_media_client::get_content::v1::{
Request as AmRequest, Response as AmResponse,
};
use legacy_media::get_content::v3::{
Request as LegacyRequest, Response as LegacyResponse,
};
fn convert_request(
LegacyRequest {
server_name,
media_id,
timeout_ms,
..
}: LegacyRequest,
) -> AmRequest {
AmRequest {
server_name,
media_id,
timeout_ms,
}
}
fn convert_response(
AmResponse {
file,
content_type,
content_disposition,
}: AmResponse,
) -> LegacyResponse {
LegacyResponse {
file,
content_type,
content_disposition,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
}
}
let allow_remote = body.allow_remote.into();
get_content_route_ruma(body.map_body(convert_request), allow_remote)
.await
.map(|response| {
let response = convert_response(response);
let mut r = Ra(response).into_response();
set_header_or_panic(
&mut r,
CONTENT_SECURITY_POLICY,
content_security_policy(),
);
r
})
}
/// # `GET /_matrix/client/v1/media/download/{serverName}/{mediaId}`
///
/// Load media from our server or over federation.
pub(crate) async fn get_content_route(
body: Ar<authenticated_media_client::get_content::v1::Request>,
body: Ar<get_content::v3::Request>,
) -> Result<axum::response::Response> {
get_content_route_ruma(body, AllowRemote::Yes).await.map(|x| {
get_content_route_ruma(body).await.map(|x| {
let mut r = Ra(x).into_response();
set_header_or_panic(
@ -444,41 +215,42 @@ pub(crate) async fn get_content_route(
}
async fn get_content_route_ruma(
body: Ar<authenticated_media_client::get_content::v1::Request>,
allow_remote: AllowRemote,
) -> Result<authenticated_media_client::get_content::v1::Response> {
let mxc = MxcData::new(&body.server_name, &body.media_id)?;
body: Ar<get_content::v3::Request>,
) -> Result<get_content::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some((
FileMeta {
content_type,
..
},
if let Some(FileMeta {
content_type,
file,
)) = services().media.get(mxc.to_string()).await?
..
}) = services().media.get(mxc.clone()).await?
{
Ok(authenticated_media_client::get_content::v1::Response {
Ok(get_content::v3::Response {
file,
content_disposition: Some(content_disposition_for(
content_type.as_deref(),
None,
)),
content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name()
&& allow_remote == AllowRemote::Yes
&& body.allow_remote
{
let remote_response = get_remote_content(&mxc).await?;
Ok(authenticated_media_client::get_content::v1::Response {
file: remote_response.content.file,
let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(get_content::v3::Response {
file: remote_content_response.file,
content_disposition: Some(content_disposition_for(
remote_response.content.content_type.as_deref(),
remote_content_response.content_type.as_deref(),
None,
)),
content_type: remote_response.content.content_type,
content_type: remote_content_response.content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else {
Err(Error::BadRequest(ErrorKind::NotYetUploaded, "Media not found."))
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
}
}
@ -487,76 +259,10 @@ async fn get_content_route_ruma(
/// Load media from our server or over federation, permitting desired filename.
///
/// - Only allows federation if `allow_remote` is true
#[allow(deprecated)] // unauthenticated media
pub(crate) async fn get_content_as_filename_legacy_route(
body: Ar<legacy_media::get_content_as_filename::v3::Request>,
) -> Result<axum::response::Response> {
use authenticated_media_client::get_content_as_filename::v1::{
Request as AmRequest, Response as AmResponse,
};
use legacy_media::get_content_as_filename::v3::{
Request as LegacyRequest, Response as LegacyResponse,
};
fn convert_request(
LegacyRequest {
server_name,
media_id,
filename,
timeout_ms,
..
}: LegacyRequest,
) -> AmRequest {
AmRequest {
server_name,
media_id,
filename,
timeout_ms,
}
}
fn convert_response(
AmResponse {
file,
content_type,
content_disposition,
}: AmResponse,
) -> LegacyResponse {
LegacyResponse {
file,
content_type,
content_disposition,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
}
}
let allow_remote = body.allow_remote.into();
get_content_as_filename_route_ruma(
body.map_body(convert_request),
allow_remote,
)
.await
.map(|response| {
let response = convert_response(response);
let mut r = Ra(response).into_response();
set_header_or_panic(
&mut r,
CONTENT_SECURITY_POLICY,
content_security_policy(),
);
r
})
}
/// # `GET /_matrix/client/v1/media/download/{serverName}/{mediaId}/{fileName}`
///
/// Load media from our server or over federation, permitting desired filename.
pub(crate) async fn get_content_as_filename_route(
body: Ar<authenticated_media_client::get_content_as_filename::v1::Request>,
body: Ar<get_content_as_filename::v3::Request>,
) -> Result<axum::response::Response> {
get_content_as_filename_route_ruma(body, AllowRemote::Yes).await.map(|x| {
get_content_as_filename_route_ruma(body).await.map(|x| {
let mut r = Ra(x).into_response();
set_header_or_panic(
@ -569,348 +275,146 @@ pub(crate) async fn get_content_as_filename_route(
})
}
async fn get_content_as_filename_route_ruma(
body: Ar<authenticated_media_client::get_content_as_filename::v1::Request>,
allow_remote: AllowRemote,
) -> Result<authenticated_media_client::get_content_as_filename::v1::Response> {
let mxc = MxcData::new(&body.server_name, &body.media_id)?;
pub(crate) async fn get_content_as_filename_route_ruma(
body: Ar<get_content_as_filename::v3::Request>,
) -> Result<get_content_as_filename::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some((
FileMeta {
content_type,
..
},
if let Some(FileMeta {
content_type,
file,
)) = services().media.get(mxc.to_string()).await?
..
}) = services().media.get(mxc.clone()).await?
{
Ok(authenticated_media_client::get_content_as_filename::v1::Response {
Ok(get_content_as_filename::v3::Response {
file,
content_disposition: Some(content_disposition_for(
content_type.as_deref(),
Some(body.filename.clone()),
Some(body.filename.as_str()),
)),
content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name()
&& allow_remote == AllowRemote::Yes
&& body.allow_remote
{
let remote_response = get_remote_content(&mxc).await?;
let remote_content_response =
get_remote_content(&mxc, &body.server_name, body.media_id.clone())
.await?;
Ok(authenticated_media_client::get_content_as_filename::v1::Response {
Ok(get_content_as_filename::v3::Response {
content_disposition: Some(content_disposition_for(
remote_response.content.content_type.as_deref(),
Some(body.filename.clone()),
remote_content_response.content_type.as_deref(),
Some(body.filename.as_str()),
)),
content_type: remote_response.content.content_type,
file: remote_response.content.file,
content_type: remote_content_response.content_type,
file: remote_content_response.file,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
}
}
fn fix_thumbnail_headers(r: &mut axum::response::Response) {
let content_type = r
.headers()
.get(CONTENT_TYPE)
.and_then(|x| std::str::from_utf8(x.as_ref()).ok())
.map(ToOwned::to_owned);
set_header_or_panic(r, CONTENT_SECURITY_POLICY, content_security_policy());
set_header_or_panic(
r,
CONTENT_DISPOSITION,
content_disposition_for(content_type.as_deref(), None)
.to_string()
.try_into()
.expect("generated header value should be valid"),
);
}
/// # `GET /_matrix/media/r0/thumbnail/{serverName}/{mediaId}`
///
/// Load media thumbnail from our server or over federation.
///
/// - Only allows federation if `allow_remote` is true
#[allow(deprecated)] // unauthenticated media
pub(crate) async fn get_content_thumbnail_legacy_route(
body: Ar<legacy_media::get_content_thumbnail::v3::Request>,
) -> Result<axum::response::Response> {
use authenticated_media_client::get_content_thumbnail::v1::{
Request as AmRequest, Response as AmResponse,
};
use legacy_media::get_content_thumbnail::v3::{
Request as LegacyRequest, Response as LegacyResponse,
};
fn convert_request(
LegacyRequest {
server_name,
media_id,
method,
width,
height,
timeout_ms,
animated,
..
}: LegacyRequest,
) -> AmRequest {
AmRequest {
server_name,
media_id,
method,
width,
height,
timeout_ms,
animated,
}
}
fn convert_response(
AmResponse {
file,
content_type,
}: AmResponse,
) -> LegacyResponse {
LegacyResponse {
file,
content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
}
}
let allow_remote = body.allow_remote.into();
get_content_thumbnail_route_ruma(
body.map_body(convert_request),
allow_remote,
)
.await
.map(|response| {
let response = convert_response(response);
let mut r = Ra(response).into_response();
fix_thumbnail_headers(&mut r);
r
})
}
/// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}`
///
/// Load media thumbnail from our server or over federation.
pub(crate) async fn get_content_thumbnail_route(
body: Ar<authenticated_media_client::get_content_thumbnail::v1::Request>,
body: Ar<get_content_thumbnail::v3::Request>,
) -> Result<axum::response::Response> {
get_content_thumbnail_route_ruma(body, AllowRemote::Yes).await.map(|x| {
get_content_thumbnail_route_ruma(body).await.map(|x| {
let mut r = Ra(x).into_response();
fix_thumbnail_headers(&mut r);
let content_type = r
.headers()
.get(CONTENT_TYPE)
.and_then(|x| std::str::from_utf8(x.as_ref()).ok())
.map(ToOwned::to_owned);
set_header_or_panic(
&mut r,
CONTENT_SECURITY_POLICY,
content_security_policy(),
);
set_header_or_panic(
&mut r,
CONTENT_DISPOSITION,
content_disposition_for(content_type.as_deref(), None)
.try_into()
.expect("generated header value should be valid"),
);
r
})
}
#[tracing::instrument(skip_all)]
async fn get_remote_thumbnail_via_federation_api(
server_name: &ruma::ServerName,
request: authenticated_media_fed::get_content_thumbnail::v1::Request,
) -> Result<RemoteResponse, Error> {
let authenticated_media_fed::get_content_thumbnail::v1::Response {
metadata,
content,
} = services()
.sending
.send_federation_request(server_name, request)
.await?;
let content = match content {
authenticated_media_fed::FileOrLocation::File(content) => {
debug!("Got thumbnail from remote server");
content
}
authenticated_media_fed::FileOrLocation::Location(location) => {
debug!(location, "Following redirect");
get_redirected_content(location).await?
}
};
Ok(RemoteResponse {
metadata,
content,
})
}
#[allow(deprecated)] // unauthenticated media
#[tracing::instrument(skip_all)]
async fn get_remote_thumbnail_via_legacy_api(
server_name: &ruma::ServerName,
authenticated_media_fed::get_content_thumbnail::v1::Request {
media_id,
method,
width,
height,
timeout_ms,
animated,
}: authenticated_media_fed::get_content_thumbnail::v1::Request,
) -> Result<RemoteResponse, Error> {
let content_response = services()
.sending
.send_federation_request(
server_name,
legacy_media::get_content_thumbnail::v3::Request {
server_name: server_name.to_owned(),
allow_remote: false,
allow_redirect: false,
media_id,
method,
width,
height,
timeout_ms,
animated,
},
)
.await?;
Ok(RemoteResponse {
metadata: authenticated_media_fed::ContentMetadata {},
content: authenticated_media_fed::Content {
file: content_response.file,
content_disposition: None,
content_type: content_response.content_type,
},
})
}
#[tracing::instrument]
pub(crate) async fn get_remote_thumbnail(
server_name: &ruma::ServerName,
request: authenticated_media_fed::get_content_thumbnail::v1::Request,
) -> Result<RemoteResponse, Error> {
let fed_result =
get_remote_thumbnail_via_federation_api(server_name, request.clone())
.await;
let response = match fed_result {
Ok(response) => {
debug!("Got remote content via authenticated media API");
response
}
Err(Error::Federation(_, error))
if error.error_kind() == Some(&ErrorKind::Unrecognized)
// https://github.com/t2bot/matrix-media-repo/issues/609
|| error.error_kind() == Some(&ErrorKind::Unauthorized) =>
{
info!(
"Remote server does not support authenticated media, falling \
back to deprecated API"
);
get_remote_thumbnail_via_legacy_api(server_name, request.clone())
.await?
}
Err(e) => {
return Err(e);
}
};
Ok(response)
}
async fn get_content_thumbnail_route_ruma(
body: Ar<authenticated_media_client::get_content_thumbnail::v1::Request>,
allow_remote: AllowRemote,
) -> Result<authenticated_media_client::get_content_thumbnail::v1::Response> {
let mxc = MxcData::new(&body.server_name, &body.media_id)?;
let width = body.width.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?;
let height = body.height.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid.")
})?;
body: Ar<get_content_thumbnail::v3::Request>,
) -> Result<get_content_thumbnail::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
let make_response = |file, content_type| {
authenticated_media_client::get_content_thumbnail::v1::Response {
file,
content_type,
}
};
if let Some((
FileMeta {
content_type,
..
},
if let Some(FileMeta {
content_type,
file,
)) =
services().media.get_thumbnail(mxc.to_string(), width, height).await?
{
return Ok(make_response(file, content_type));
}
if &*body.server_name != services().globals.server_name()
&& allow_remote == AllowRemote::Yes
{
let get_thumbnail_response = get_remote_thumbnail(
&body.server_name,
authenticated_media_fed::get_content_thumbnail::v1::Request {
height: body.height,
width: body.width,
method: body.method.clone(),
media_id: body.media_id.clone(),
timeout_ms: Duration::from_secs(20),
// we don't support animated thumbnails, so don't try requesting
// one - we're allowed to ignore the client's request for an
// animated thumbnail
animated: Some(false),
},
..
}) = services()
.media
.get_thumbnail(
mxc.clone(),
body.width.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?,
body.height.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?,
)
.await;
match get_thumbnail_response {
Ok(resp) => {
services()
.media
.upload_thumbnail(
mxc.to_string(),
None,
resp.content.content_type.clone(),
width,
height,
&resp.content.file,
)
.await?;
return Ok(make_response(
resp.content.file,
resp.content.content_type,
));
}
Err(error) => warn!(
%error,
"Failed to fetch thumbnail via federation, trying to fetch \
original media and create thumbnail ourselves"
),
}
get_remote_content(&mxc).await?;
if let Some((
FileMeta {
content_type,
..
},
.await?
{
Ok(get_content_thumbnail::v3::Response {
file,
)) = services()
content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else if &*body.server_name != services().globals.server_name()
&& body.allow_remote
{
let get_thumbnail_response = services()
.sending
.send_federation_request(
&body.server_name,
get_content_thumbnail::v3::Request {
allow_remote: false,
height: body.height,
width: body.width,
method: body.method.clone(),
server_name: body.server_name.clone(),
media_id: body.media_id.clone(),
timeout_ms: Duration::from_secs(20),
allow_redirect: false,
},
)
.await?;
services()
.media
.get_thumbnail(mxc.to_string(), width, height)
.await?
{
return Ok(make_response(file, content_type));
}
.upload_thumbnail(
mxc,
None,
get_thumbnail_response.content_type.as_deref(),
body.width.try_into().expect("all UInts are valid u32s"),
body.height.try_into().expect("all UInts are valid u32s"),
&get_thumbnail_response.file,
)
.await?;
error!("Source media doesn't exist even after fetching it from remote");
Ok(get_content_thumbnail::v3::Response {
file: get_thumbnail_response.file,
content_type: get_thumbnail_response.content_type,
cross_origin_resource_policy: Some("cross-origin".to_owned()),
})
} else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
}
Err(Error::BadRequest(ErrorKind::NotYetUploaded, "Media not found."))
}

View file

@ -226,11 +226,16 @@ pub(crate) async fn kick_user_route(
event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason);
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services()
.rooms
@ -245,11 +250,12 @@ pub(crate) async fn kick_user_route(
redacts: None,
},
sender_user,
&room_token,
&body.room_id,
&state_lock,
)
.await?;
drop(room_token);
drop(state_lock);
Ok(Ra(kick_user::v3::Response::new()))
}
@ -296,11 +302,16 @@ pub(crate) async fn ban_user_route(
},
)?;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services()
.rooms
@ -315,11 +326,12 @@ pub(crate) async fn ban_user_route(
redacts: None,
},
sender_user,
&room_token,
&body.room_id,
&state_lock,
)
.await?;
drop(room_token);
drop(state_lock);
Ok(Ra(ban_user::v3::Response::new()))
}
@ -353,11 +365,16 @@ pub(crate) async fn unban_user_route(
event.membership = MembershipState::Leave;
event.reason.clone_from(&body.reason);
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services()
.rooms
@ -372,11 +389,12 @@ pub(crate) async fn unban_user_route(
redacts: None,
},
sender_user,
&room_token,
&body.room_id,
&state_lock,
)
.await?;
drop(room_token);
drop(state_lock);
Ok(Ra(unban_user::v3::Response::new()))
}
@ -500,7 +518,6 @@ pub(crate) async fn joined_members_route(
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip(reason, _third_party_signed))]
async fn join_room_by_id_helper(
sender_user: Option<&UserId>,
room_id: &RoomId,
@ -510,11 +527,16 @@ async fn join_room_by_id_helper(
) -> Result<join_room_by_id::v3::Response> {
let sender_user = sender_user.expect("user is authenticated");
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Ask a remote server if we are not participating in this room
if services()
@ -535,8 +557,8 @@ async fn join_room_by_id_helper(
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get())
.map_err(|error| {
warn!(%error, "Invalid join rules event");
.map_err(|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database(
"Invalid join rules event in db.",
)
@ -578,9 +600,10 @@ async fn join_room_by_id_helper(
{
if user.server_name() == services().globals.server_name()
&& services().rooms.state_accessor.user_can_invite(
&room_token,
room_id,
&user,
sender_user,
&state_lock,
)
{
auth_user = Some(user);
@ -617,7 +640,8 @@ async fn join_room_by_id_helper(
redacts: None,
},
sender_user,
&room_token,
room_id,
&state_lock,
)
.await
{
@ -772,7 +796,7 @@ async fn join_room_by_id_helper(
));
}
drop(room_token);
drop(state_lock);
let pub_key_map = RwLock::new(BTreeMap::new());
services()
.rooms
@ -787,7 +811,7 @@ async fn join_room_by_id_helper(
)
.await?;
} else {
info!("Joining over federation.");
info!("Joining {room_id} over federation.");
let (make_join_response, remote_server) =
make_join_request(sender_user, room_id, servers).await?;
@ -891,7 +915,7 @@ async fn join_room_by_id_helper(
// It has enough fields to be called a proper event now
let mut join_event = join_event_stub;
info!(server = %remote_server, "Asking other server for send_join");
info!("Asking {remote_server} for send_join");
let send_join_response = services()
.sending
.send_federation_request(
@ -951,13 +975,10 @@ async fn join_room_by_id_helper(
.expect("we created a valid pdu")
.insert(remote_server.to_string(), signature.clone());
}
Err(error) => {
Err(e) => {
warn!(
%error,
server = %remote_server,
event = ?signed_value,
"Other server sent invalid signature in sendjoin \
signatures for event",
"Server {remote_server} sent invalid signature in \
sendjoin signatures for event {signed_value:?}: {e:?}",
);
}
}
@ -994,11 +1015,10 @@ async fn join_room_by_id_helper(
};
let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(
|error| {
|e| {
warn!(
%error,
object = ?value,
"Invalid PDU in send_join response",
"Invalid PDU in send_join response: {} {:?}",
e, value
);
Error::BadServerResponse(
"Invalid PDU in send_join response.",
@ -1056,8 +1076,8 @@ async fn join_room_by_id_helper(
.ok()?
},
)
.map_err(|error| {
warn!(%error, "Auth check failed");
.map_err(|e| {
warn!("Auth check failed: {e}");
Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")
})?;
@ -1088,7 +1108,13 @@ async fn join_room_by_id_helper(
services()
.rooms
.state
.force_state(&room_token, statehash_before_join, new, removed)
.force_state(
room_id,
statehash_before_join,
new,
removed,
&state_lock,
)
.await?;
info!("Updating joined counts for new room");
@ -1108,7 +1134,7 @@ async fn join_room_by_id_helper(
&parsed_join_pdu,
join_event,
vec![(*parsed_join_pdu.event_id).to_owned()],
&room_token,
&state_lock,
)
.await?;
@ -1116,10 +1142,11 @@ async fn join_room_by_id_helper(
// We set the room state after inserting the pdu, so that we never have
// a moment in time where events in the current room state do
// not exist
services()
.rooms
.state
.set_room_state(&room_token, statehash_after_join)?;
services().rooms.state.set_room_state(
room_id,
statehash_after_join,
&state_lock,
)?;
}
Ok(join_room_by_id::v3::Response::new(room_id.to_owned()))
@ -1141,7 +1168,7 @@ async fn make_join_request(
if remote_server == services().globals.server_name() {
continue;
}
info!(server = %remote_server, "Asking other server for make_join");
info!("Asking {remote_server} for make_join");
let make_join_response = services()
.sending
.send_federation_request(
@ -1171,8 +1198,8 @@ async fn validate_and_add_event_id(
pub_key_map: &RwLock<BTreeMap<String, SigningKeys>>,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get())
.map_err(|error| {
error!(%error, object = ?pdu, "Invalid PDU in server response");
.map_err(|e| {
error!("Invalid PDU in server response: {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let event_id = EventId::parse(format!(
@ -1204,7 +1231,7 @@ async fn validate_and_add_event_id(
}
if time.elapsed() < min_elapsed_duration {
debug!(%event_id, "Backing off from event");
debug!("Backing off from {}", event_id);
return Err(Error::BadServerResponse(
"bad event, still backing off",
));
@ -1243,15 +1270,9 @@ async fn validate_and_add_event_id(
room_version,
);
if let Err(error) =
ruma::signatures::verify_event(&keys, &value, room_version)
if let Err(e) = ruma::signatures::verify_event(&keys, &value, room_version)
{
warn!(
%event_id,
%error,
?pdu,
"Event failed verification",
);
warn!("Event {} failed verification {:?} {}", event_id, pdu, e);
back_off(event_id).await;
return Err(Error::BadServerResponse("Event failed verification."));
}
@ -1274,11 +1295,16 @@ pub(crate) async fn invite_helper(
) -> Result<()> {
if user_id.server_name() != services().globals.server_name() {
let (pdu, pdu_json, invite_room_state) = {
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let content = to_raw_value(&RoomMemberEventContent {
avatar_url: None,
@ -1302,13 +1328,14 @@ pub(crate) async fn invite_helper(
redacts: None,
},
sender_user,
&room_token,
room_id,
&state_lock,
)?;
let invite_room_state =
services().rooms.state.calculate_invite_state(&pdu)?;
drop(room_token);
drop(state_lock);
(pdu, pdu_json, invite_room_state)
};
@ -1348,11 +1375,11 @@ pub(crate) async fn invite_helper(
if *pdu.event_id != *event_id {
warn!(
server = %user_id.server_name(),
our_object = ?pdu_json,
their_object = ?value,
"Other server changed invite event, that's not allowed in the \
spec",
"Server {} changed invite event, that's not allowed in the \
spec: ours: {:?}, theirs: {:?}",
user_id.server_name(),
pdu_json,
value
);
}
@ -1372,7 +1399,7 @@ pub(crate) async fn invite_helper(
)
})?;
let pdu_id = services()
let pdu_id: Vec<u8> = services()
.rooms
.event_handler
.handle_incoming_pdu(
@ -1409,11 +1436,16 @@ pub(crate) async fn invite_helper(
));
}
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services()
.rooms
@ -1437,11 +1469,12 @@ pub(crate) async fn invite_helper(
redacts: None,
},
sender_user,
&room_token,
room_id,
&state_lock,
)
.await?;
drop(room_token);
drop(state_lock);
Ok(())
}
@ -1467,14 +1500,13 @@ pub(crate) async fn leave_all_rooms(user_id: &UserId) -> Result<()> {
};
if let Err(error) = leave_room(user_id, &room_id, None).await {
warn!(%user_id, %room_id, %error, "Failed to leave room");
warn!(%user_id, %room_id, %error, "failed to leave room");
}
}
Ok(())
}
#[tracing::instrument(skip(reason))]
pub(crate) async fn leave_room(
user_id: &UserId,
room_id: &RoomId,
@ -1486,11 +1518,16 @@ pub(crate) async fn leave_room(
.state_cache
.server_in_room(services().globals.server_name(), room_id)?
{
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let member_event = services().rooms.state_accessor.room_state_get(
room_id,
@ -1538,12 +1575,13 @@ pub(crate) async fn leave_room(
redacts: None,
},
user_id,
&room_token,
room_id,
&state_lock,
)
.await?;
} else {
if let Err(error) = remote_leave_room(user_id, room_id).await {
warn!(%error, "Failed to leave room remotely");
if let Err(e) = remote_leave_room(user_id, room_id).await {
warn!("Failed to leave room {} remotely: {}", user_id, e);
// Don't tell the client about this error
}

View file

@ -1,4 +1,7 @@
use std::collections::{BTreeMap, HashSet};
use std::{
collections::{BTreeMap, HashSet},
sync::Arc,
};
use ruma::{
api::client::{
@ -29,11 +32,16 @@ pub(crate) async fn send_message_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref();
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Forbid m.room.encrypted if encryption is disabled
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into()
@ -96,7 +104,8 @@ pub(crate) async fn send_message_event_route(
redacts: None,
},
sender_user,
&room_token,
&body.room_id,
&state_lock,
)
.await?;
@ -107,7 +116,7 @@ pub(crate) async fn send_message_event_route(
event_id.as_bytes(),
)?;
drop(room_token);
drop(state_lock);
Ok(Ra(send_message_event::v3::Response::new((*event_id).to_owned())))
}
@ -117,7 +126,8 @@ pub(crate) async fn send_message_event_route(
/// Allows paginating through room history.
///
/// - Only works if the user is joined (TODO: always allow, but only show events
/// where the user was joined, depending on `history_visibility`)
/// where the user was
/// joined, depending on `history_visibility`)
#[allow(clippy::too_many_lines)]
pub(crate) async fn get_message_events_route(
body: Ar<get_message_events::v3::Request>,

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use ruma::{
api::{
client::{
@ -79,16 +81,26 @@ pub(crate) async fn set_displayname_route(
.collect();
for (pdu_builder, room_id) in all_rooms_joined {
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
if let Err(error) = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_token)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await
{
warn!(%error, "failed to add PDU");
@ -191,16 +203,26 @@ pub(crate) async fn set_avatar_url_route(
.collect();
for (pdu_builder, room_id) in all_joined_rooms {
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
if let Err(error) = services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_token)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await
{
warn!(%error, "failed to add PDU");

View file

@ -1,3 +1,5 @@
use std::sync::Arc;
use ruma::{
api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
@ -17,11 +19,16 @@ pub(crate) async fn redact_event_route(
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let event_id = services()
.rooms
@ -39,11 +46,12 @@ pub(crate) async fn redact_event_route(
redacts: Some(body.event_id.into()),
},
sender_user,
&room_token,
&body.room_id,
&state_lock,
)
.await?;
drop(room_token);
drop(state_lock);
let event_id = (*event_id).to_owned();
Ok(Ra(redact_event::v3::Response {

View file

@ -1,4 +1,4 @@
use std::{cmp::max, collections::BTreeMap};
use std::{cmp::max, collections::BTreeMap, sync::Arc};
use ruma::{
api::client::{
@ -63,8 +63,16 @@ pub(crate) async fn create_room_route(
services().rooms.short.get_or_create_shortroomid(&room_id)?;
let room_token =
services().globals.roomid_mutex_state.lock_key(room_id.clone()).await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
if !services().globals.allow_room_creation()
&& body.appservice_info.is_none()
@ -138,8 +146,17 @@ pub(crate) async fn create_room_route(
.deserialize_as::<CanonicalJsonObject>()
.expect("Invalid creation content");
match &room_version {
room_version if *room_version < RoomVersionId::V11 => {
match room_version {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => {
content.insert(
"creator".into(),
json!(&sender_user).try_into().map_err(|_| {
@ -152,11 +169,7 @@ pub(crate) async fn create_room_route(
}
// V11 removed the "creator" key
RoomVersionId::V11 => {}
_ => {
return Err(Error::BadServerResponse(
"Unsupported room version.",
))
}
_ => unreachable!("Validity of room version already checked"),
}
content.insert(
@ -171,16 +184,21 @@ pub(crate) async fn create_room_route(
content
}
None => {
let content = match &room_version {
room_version if *room_version < RoomVersionId::V11 => {
let content = match room_version {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => {
RoomCreateEventContent::new_v1(sender_user.to_owned())
}
RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => {
return Err(Error::BadServerResponse(
"Unsupported room version.",
))
}
_ => unreachable!("Validity of room version already checked"),
};
let mut content = serde_json::from_str::<CanonicalJsonObject>(
to_raw_value(&content)
@ -232,7 +250,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -259,7 +278,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -318,7 +338,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -340,7 +361,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
}
@ -367,7 +389,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -387,7 +410,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -410,15 +434,16 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
// 6. Events listed in initial_state
for event in &body.initial_state {
let mut pdu_builder =
event.deserialize_as::<PduBuilder>().map_err(|error| {
warn!(%error, "Invalid initial state event");
event.deserialize_as::<PduBuilder>().map_err(|e| {
warn!("Invalid initial state event: {:?}", e);
Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid initial state event.",
@ -438,7 +463,12 @@ pub(crate) async fn create_room_route(
services()
.rooms
.timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_token)
.build_and_append_pdu(
pdu_builder,
sender_user,
&room_id,
&state_lock,
)
.await?;
}
@ -459,7 +489,8 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
}
@ -480,19 +511,20 @@ pub(crate) async fn create_room_route(
redacts: None,
},
sender_user,
&room_token,
&room_id,
&state_lock,
)
.await?;
}
// 8. Events implied by invite (and TODO: invite_3pid)
drop(room_token);
drop(state_lock);
for user_id in &body.invite {
if let Err(error) =
invite_helper(sender_user, user_id, &room_id, None, body.is_direct)
.await
{
warn!(%error, "Invite helper failed");
warn!(%error, "invite helper failed");
};
}
@ -505,7 +537,7 @@ pub(crate) async fn create_room_route(
services().rooms.directory.set_public(&room_id)?;
}
info!(user_id = %sender_user, room_id = %room_id, "User created a room");
info!("{} created a room", sender_user);
Ok(Ra(create_room::v3::Response::new(room_id)))
}
@ -523,7 +555,7 @@ pub(crate) async fn get_room_event_route(
let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(
|| {
warn!(event_id = %body.event_id, "Event not found");
warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
},
)?;
@ -603,11 +635,16 @@ pub(crate) async fn upgrade_room_route(
let replacement_room = RoomId::new(services().globals.server_name());
services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
let original_room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Send a m.room.tombstone event to the old room to indicate that it is not
// intended to be used any further Fail if the sender does not have the
@ -628,16 +665,23 @@ pub(crate) async fn upgrade_room_route(
redacts: None,
},
sender_user,
&original_room_token,
&body.room_id,
&state_lock,
)
.await?;
// Change lock to replacement room
let replacement_room_token = services()
.globals
.roomid_mutex_state
.lock_key(replacement_room.clone())
.await;
drop(state_lock);
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(replacement_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Get the old room creation event
let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>(
@ -661,8 +705,17 @@ pub(crate) async fn upgrade_room_route(
// Send a m.room.create event containing a predecessor field and the
// applicable room_version
match &body.new_version {
room_version if *room_version < RoomVersionId::V11 => {
match body.new_version {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => {
create_event_content.insert(
"creator".into(),
json!(&sender_user).try_into().map_err(|_| {
@ -677,7 +730,7 @@ pub(crate) async fn upgrade_room_route(
// "creator" key no longer exists in V11 rooms
create_event_content.remove("creator");
}
_ => return Err(Error::BadServerResponse("Unsupported room version.")),
_ => unreachable!("Validity of room version already checked"),
}
create_event_content.insert(
"room_version".into(),
@ -725,7 +778,8 @@ pub(crate) async fn upgrade_room_route(
redacts: None,
},
sender_user,
&replacement_room_token,
&replacement_room,
&state_lock,
)
.await?;
@ -752,7 +806,8 @@ pub(crate) async fn upgrade_room_route(
redacts: None,
},
sender_user,
&replacement_room_token,
&replacement_room,
&state_lock,
)
.await?;
@ -793,7 +848,8 @@ pub(crate) async fn upgrade_room_route(
redacts: None,
},
sender_user,
&replacement_room_token,
&replacement_room,
&state_lock,
)
.await?;
}
@ -855,10 +911,13 @@ pub(crate) async fn upgrade_room_route(
redacts: None,
},
sender_user,
&original_room_token,
&body.room_id,
&state_lock,
)
.await?;
drop(state_lock);
// Return the replacement room id
Ok(Ra(upgrade_room::v3::Response {
replacement_room,

View file

@ -84,9 +84,7 @@ pub(crate) async fn search_events_route(
if let Some(s) = searches
.iter_mut()
.map(|s| (s.peek().cloned(), s))
.max_by_key(|(peek, _)| {
peek.as_ref().map(|id| id.as_bytes().to_vec())
})
.max_by_key(|(peek, _)| peek.clone())
.and_then(|(_, i)| i.next())
{
results.push(s);

View file

@ -79,7 +79,7 @@ pub(crate) async fn login_route(
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!(kind = ?body.login_info, "Bad login kind");
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Bad login type.",
@ -184,7 +184,7 @@ pub(crate) async fn login_route(
} else if let Some(user) = user {
UserId::parse(user)
} else {
warn!(kind = ?body.login_info, "Bad login kind");
warn!("Bad login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Bad login type.",
@ -214,7 +214,7 @@ pub(crate) async fn login_route(
user_id
}
_ => {
warn!(kind = ?body.login_info, "Unsupported or unknown login kind");
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Unsupported login type.",
@ -250,7 +250,7 @@ pub(crate) async fn login_route(
)?;
}
info!(%user_id, %device_id, "User logged in");
info!("{} logged in", user_id);
// Homeservers are still required to send the `home_server` field
#[allow(deprecated)]
@ -292,8 +292,6 @@ pub(crate) async fn logout_route(
services().users.remove_device(sender_user, sender_device)?;
info!(user_id = %sender_user, device_id = %sender_device, "User logged out");
Ok(Ra(logout::v3::Response::new()))
}

View file

@ -240,11 +240,16 @@ async fn send_state_event_for_key_helper(
}
}
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
let event_id = services()
.rooms
@ -259,7 +264,8 @@ async fn send_state_event_for_key_helper(
redacts: None,
},
sender_user,
&room_token,
room_id,
&state_lock,
)
.await?;

View file

@ -1,5 +1,6 @@
use std::{
collections::{BTreeMap, BTreeSet, HashMap, HashSet},
sync::Arc,
time::Duration,
};
@ -36,7 +37,8 @@ use crate::{
/// Synchronize the client's state with the latest state on the server.
///
/// - This endpoint takes a `since` parameter which should be the `next_batch`
/// value from a previous request for incremental syncs.
/// value from a
/// previous request for incremental syncs.
///
/// Calling this endpoint without a `since` parameter returns:
/// - Some of the most recent events of each timeline
@ -49,9 +51,11 @@ use crate::{
/// - Some of the most recent events of each timeline that happened after
/// `since`
/// - If user joined the room after `since`: All state events (unless lazy
/// loading is activated) and all device list updates in that room
/// loading is activated) and
/// all device list updates in that room
/// - If the user was already in the room: A list of all events that are in the
/// state now, but were not in the state at `since`
/// state now, but were
/// not in the state at `since`
/// - If the state we send contains a member event: Joined and invited member
/// counts, heroes
/// - Device list updates that happened after `since`
@ -157,12 +161,17 @@ pub(crate) async fn sync_events_route(
{
// Get and drop the lock to wait for remaining operations to finish
let room_token = services()
.globals
.roomid_mutex_insert
.lock_key(room_id.clone())
.await;
drop(room_token);
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
}
let left_count = services()
@ -267,8 +276,8 @@ pub(crate) async fn sync_events_route(
left_state_ids.insert(leave_shortstatekey, left_event_id);
let mut i = 0;
for (key, event_id) in left_state_ids {
if full_state || since_state_ids.get(&key) != Some(&event_id) {
for (key, id) in left_state_ids {
if full_state || since_state_ids.get(&key) != Some(&id) {
let (event_type, state_key) =
services().rooms.short.get_statekey_from_short(key)?;
@ -278,10 +287,9 @@ pub(crate) async fn sync_events_route(
// TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565
|| *sender_user == state_key
{
let Some(pdu) =
services().rooms.timeline.get_pdu(&event_id)?
let Some(pdu) = services().rooms.timeline.get_pdu(&id)?
else {
error!(%event_id, "Event in state not found");
error!("Pdu in state not found: {}", id);
continue;
};
@ -321,12 +329,17 @@ pub(crate) async fn sync_events_route(
{
// Get and drop the lock to wait for remaining operations to finish
let room_token = services()
.globals
.roomid_mutex_insert
.lock_key(room_id.clone())
.await;
drop(room_token);
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
}
let invite_count = services()
@ -442,7 +455,7 @@ pub(crate) async fn sync_events_route(
}
match tokio::time::timeout(duration, watcher).await {
Ok(x) => x.expect("watcher should succeed"),
Err(error) => debug!(%error, "Timed out"),
Err(error) => debug!(%error, "timed out"),
};
}
Ok(Ra(response))
@ -467,12 +480,17 @@ async fn load_joined_room(
{
// Get and drop the lock to wait for remaining operations to finish
// This will make sure the we have all events until next_batch
let room_token = services()
.globals
.roomid_mutex_insert
.lock_key(room_id.to_owned())
.await;
drop(room_token);
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let insert_lock = mutex_insert.lock().await;
drop(insert_lock);
}
let (timeline_pdus, limited) =
@ -506,7 +524,7 @@ async fn load_joined_room(
let Some(current_shortstatehash) =
services().rooms.state.get_room_shortstatehash(room_id)?
else {
error!("Room has no state");
error!("Room {} has no state", room_id);
return Err(Error::BadDatabase("Room has no state"));
};
@ -654,17 +672,16 @@ async fn load_joined_room(
let mut lazy_loaded = HashSet::new();
let mut i = 0;
for (shortstatekey, event_id) in current_state_ids {
for (shortstatekey, id) in current_state_ids {
let (event_type, state_key) = services()
.rooms
.short
.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember {
let Some(pdu) =
services().rooms.timeline.get_pdu(&event_id)?
let Some(pdu) = services().rooms.timeline.get_pdu(&id)?
else {
error!(%event_id, "Event in state not found");
error!("Pdu in state not found: {}", id);
continue;
};
state_events.push(pdu);
@ -679,10 +696,9 @@ async fn load_joined_room(
// TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565
|| *sender_user == state_key
{
let Some(pdu) =
services().rooms.timeline.get_pdu(&event_id)?
let Some(pdu) = services().rooms.timeline.get_pdu(&id)?
else {
error!(%event_id, "Event in state not found");
error!("Pdu in state not found: {}", id);
continue;
};
@ -746,14 +762,12 @@ async fn load_joined_room(
.state_full_ids(since_shortstatehash)
.await?;
for (key, event_id) in current_state_ids {
if full_state
|| since_state_ids.get(&key) != Some(&event_id)
{
for (key, id) in current_state_ids {
if full_state || since_state_ids.get(&key) != Some(&id) {
let Some(pdu) =
services().rooms.timeline.get_pdu(&event_id)?
services().rooms.timeline.get_pdu(&id)?
else {
error!(%event_id, "Event in state not found");
error!("Pdu in state not found: {}", id);
continue;
};
@ -881,12 +895,8 @@ async fn load_joined_room(
Ok(state_key_userid) => {
lazy_loaded.insert(state_key_userid);
}
Err(error) => {
error!(
event_id = %pdu.event_id,
%error,
"Invalid state key for member event",
);
Err(e) => {
error!("Invalid state key for member event: {}", e);
}
}
}
@ -964,7 +974,7 @@ async fn load_joined_room(
|(pdu_count, _)| {
Ok(Some(match pdu_count {
PduCount::Backfilled(_) => {
error!("Timeline in backfill state?!");
error!("timeline in backfill state?!");
"0".to_owned()
}
PduCount::Normal(c) => c.to_string(),
@ -1064,12 +1074,11 @@ fn load_timeline(
.rooms
.timeline
.pdus_until(sender_user, room_id, PduCount::MAX)?
.filter_map(|x| match x {
Ok(x) => Some(x),
Err(error) => {
error!(%error, "Bad PDU in pdus_since");
None
.filter_map(|r| {
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
})
.take_while(|(pducount, _)| pducount > &roomsincecount);
@ -1186,7 +1195,7 @@ pub(crate) async fn sync_events_v4_route(
let Some(current_shortstatehash) =
services().rooms.state.get_room_shortstatehash(room_id)?
else {
error!(%room_id, "Room has no state");
error!("Room {} has no state", room_id);
continue;
};
@ -1259,12 +1268,12 @@ pub(crate) async fn sync_events_v4_route(
.state_full_ids(since_shortstatehash)
.await?;
for (key, event_id) in current_state_ids {
if since_state_ids.get(&key) != Some(&event_id) {
for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) {
let Some(pdu) =
services().rooms.timeline.get_pdu(&event_id)?
services().rooms.timeline.get_pdu(&id)?
else {
error!(%event_id, "Event in state not found");
error!("Pdu in state not found: {}", id);
continue;
};
if pdu.kind == TimelineEventType::RoomMember {
@ -1543,7 +1552,7 @@ pub(crate) async fn sync_events_v4_route(
.map_or(Ok::<_, Error>(None), |(pdu_count, _)| {
Ok(Some(match pdu_count {
PduCount::Backfilled(_) => {
error!("Timeline in backfill state?!");
error!("timeline in backfill state?!");
"0".to_owned()
}
PduCount::Normal(c) => c.to_string(),
@ -1695,7 +1704,7 @@ pub(crate) async fn sync_events_v4_route(
}
match tokio::time::timeout(duration, watcher).await {
Ok(x) => x.expect("watcher should succeed"),
Err(error) => debug!(%error, "Timed out"),
Err(error) => debug!(%error, "timed out"),
};
}

View file

@ -5,7 +5,6 @@ use ruma::{
client::{error::ErrorKind, to_device::send_event_to_device},
federation::{self, transactions::edu::DirectDeviceContent},
},
serde::Raw,
to_device::DeviceIdOrAllDevices,
};
@ -41,7 +40,7 @@ pub(crate) async fn send_event_to_device_route(
services().sending.send_reliable_edu(
target_user_id.server_name(),
Raw::new(
serde_json::to_vec(
&federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent {
sender: sender_user.clone(),

View file

@ -29,10 +29,10 @@ pub(crate) async fn get_supported_versions_route(
"v1.4".to_owned(),
"v1.5".to_owned(),
],
unstable_features: BTreeMap::from_iter([
("org.matrix.e2e_cross_signing".to_owned(), true),
("org.matrix.msc3916.stable".to_owned(), true),
]),
unstable_features: BTreeMap::from_iter([(
"org.matrix.e2e_cross_signing".to_owned(),
true,
)]),
};
Ok(Ra(resp))

View file

@ -13,7 +13,8 @@ use crate::{services, Ar, Ra, Result};
/// Searches all known users for a match.
///
/// - Hides any local users that aren't in any public rooms (i.e. those that
/// have the join rule set to public) and don't share a room with the sender
/// have the join rule set to public)
/// and don't share a room with the sender
pub(crate) async fn search_users_route(
body: Ar<search_users::v3::Request>,
) -> Result<Ra<search_users::v3::Response>> {

View file

@ -24,31 +24,6 @@ pub(crate) struct Ar<T> {
pub(crate) appservice_info: Option<RegistrationInfo>,
}
impl<T> Ar<T> {
pub(crate) fn map_body<F, U>(self, f: F) -> Ar<U>
where
F: FnOnce(T) -> U,
{
let Ar {
body,
sender_user,
sender_device,
sender_servername,
json_body,
appservice_info,
} = self;
Ar {
body: f(body),
sender_user,
sender_device,
sender_servername,
json_body,
appservice_info,
}
}
}
impl<T> Deref for Ar<T> {
type Target = T;

View file

@ -25,7 +25,7 @@ use ruma::{
OwnedServerName, OwnedUserId, UserId,
};
use serde::Deserialize;
use tracing::{error, warn};
use tracing::{debug, error, warn};
use super::{Ar, Ra};
use crate::{service::appservice::RegistrationInfo, services, Error, Result};
@ -81,8 +81,8 @@ async fn ar_from_request_inner(
let query = parts.uri.query().unwrap_or_default();
let query_params: QueryParams = match serde_html_form::from_str(query) {
Ok(params) => params,
Err(error) => {
error!(%error, %query, "Failed to deserialize query parameters");
Err(e) => {
error!(%query, "Failed to deserialize query parameters: {}", e);
return Err(Error::BadRequest(
ErrorKind::Unknown,
"Failed to read query parameters",
@ -181,10 +181,10 @@ async fn ar_from_request_inner(
let TypedHeader(Authorization(x_matrix)) = parts
.extract::<TypedHeader<Authorization<XMatrix>>>()
.await
.map_err(|error| {
warn!(%error, "Missing or invalid Authorization header");
.map_err(|e| {
warn!("Missing or invalid Authorization header: {}", e);
let msg = match error.reason() {
let msg = match e.reason() {
TypedHeaderRejectionReason::Missing => {
"Missing Authorization header."
}
@ -212,7 +212,7 @@ async fn ar_from_request_inner(
let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.to_string(),
CanonicalJsonValue::String(x_matrix.sig.to_string()),
CanonicalJsonValue::String(x_matrix.sig),
)]);
let signatures = BTreeMap::from_iter([(
@ -267,8 +267,8 @@ async fn ar_from_request_inner(
let keys = match keys_result {
Ok(b) => b,
Err(error) => {
warn!(%error, "Failed to fetch signing keys");
Err(e) => {
warn!("Failed to fetch signing keys: {}", e);
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Failed to fetch signing keys.",
@ -293,12 +293,10 @@ async fn ar_from_request_inner(
match ruma::signatures::verify_json(&pub_key_map, &request_map)
{
Ok(()) => (None, None, Some(x_matrix.origin), None),
Err(error) => {
Err(e) => {
warn!(
%error,
origin = %x_matrix.origin,
object = ?request_map,
"Failed to verify JSON request"
"Failed to verify json request from {}: {}\n{:?}",
x_matrix.origin, e, request_map
);
if parts.uri.to_string().contains('@') {
@ -404,12 +402,9 @@ where
let body =
T::try_from_http_request(pieces.http_request, &pieces.path_params)
.map_err(|error| {
warn!(
%error,
body = ?pieces.json_body,
"Request body JSON structure is incorrect"
);
.map_err(|e| {
warn!("try_from_http_request failed: {:?}", e);
debug!("JSON body: {:?}", pieces.json_body);
Error::BadRequest(
ErrorKind::BadJson,
"Failed to deserialize request.",

View file

@ -1,3 +1,5 @@
#![allow(deprecated)]
use std::{
collections::BTreeMap,
fmt::Debug,
@ -9,13 +11,11 @@ use std::{
use axum::{response::IntoResponse, Json};
use axum_extra::headers::{Authorization, HeaderMapExt};
use base64::Engine as _;
use get_profile_information::v1::ProfileField;
use ruma::{
api::{
client::error::{Error as RumaError, ErrorKind},
federation::{
authenticated_media,
authorization::get_event_authorization,
backfill::get_backfill,
device::get_devices::{self, v1::UserDevice},
@ -55,7 +55,6 @@ use ruma::{
},
serde::{Base64, JsonObject, Raw},
server_util::authorization::XMatrix,
state_res::Event,
to_device::DeviceIdOrAllDevices,
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId,
MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, OwnedServerName,
@ -64,16 +63,13 @@ use ruma::{
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tracing::{debug, error, field, trace, trace_span, warn};
use tracing::{debug, error, field, warn};
use super::appservice_server;
use crate::{
api::client_server::{self, claim_keys_helper, get_keys_helper},
observability::{FoundIn, Lookup, METRICS},
service::pdu::{gen_event_id_canonical_json, PduBuilder},
services,
utils::{self, dbg_truncate_str, MxcData},
Ar, Error, PduEvent, Ra, Result,
services, utils, Ar, Error, PduEvent, Ra, Result,
};
/// Wraps either an literal IP address plus port, or a hostname plus complement
@ -132,11 +128,10 @@ impl FedDest {
}
}
#[tracing::instrument(skip(request, log_error), fields(url))]
#[tracing::instrument(skip(request), fields(url))]
pub(crate) async fn send_request<T>(
destination: &ServerName,
request: T,
log_error: bool,
) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
@ -180,13 +175,12 @@ where
.try_into_http_request::<Vec<u8>>(
&actual_destination_str,
SendAccessToken::IfRequired(""),
&[MatrixVersion::V1_11],
&[MatrixVersion::V1_4],
)
.map_err(|error| {
.map_err(|e| {
warn!(
%error,
actual_destination = actual_destination_str,
"Failed to find destination",
"Failed to find destination {}: {}",
actual_destination_str, e
);
Error::BadServerResponse("Invalid destination")
})?;
@ -243,24 +237,15 @@ where
.unwrap();
let key_id = OwnedSigningKeyId::try_from(key_id.clone()).unwrap();
let signature = Base64::parse(signature.as_str().unwrap())
.expect("generated signature should be valid base64");
let signature = signature.as_str().unwrap().to_owned();
http_request.headers_mut().typed_insert(Authorization(XMatrix::new(
services().globals.server_name().to_owned(),
destination.to_owned(),
Some(destination.to_owned()),
key_id,
signature,
)));
// can be enabled selectively using `filter =
// grapevine[outgoing_request_curl]=trace` in config
trace_span!("outgoing_request_curl").in_scope(|| {
trace!(
cmd = utils::curlify(&http_request),
"curl command line for outgoing request"
);
});
let reqwest_request = reqwest::Request::try_from(http_request)?;
let url = reqwest_request.url().clone();
@ -270,67 +255,85 @@ where
let response =
services().globals.federation_client().execute(reqwest_request).await;
let mut response = response.inspect_err(|error| {
if log_error {
warn!(%error, "Could not send request");
match response {
Ok(mut response) => {
// reqwest::Response -> http::Response conversion
let status = response.status();
debug!(status = u16::from(status), "Received response");
let mut http_response_builder = http::Response::builder()
.status(status)
.version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
debug!("Getting response bytes");
// TODO: handle timeout
let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error {}", e);
Vec::new().into()
});
debug!("Got response bytes");
if status != 200 {
warn!(
status = u16::from(status),
response = String::from_utf8_lossy(&body)
.lines()
.collect::<Vec<_>>()
.join(" "),
"Received error over federation",
);
}
let http_response = http_response_builder
.body(body)
.expect("reqwest body is valid http body");
if status == 200 {
debug!("Parsing response bytes");
let response =
T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && write_destination_to_cache {
METRICS.record_lookup(
Lookup::FederationDestination,
FoundIn::Remote,
);
services()
.globals
.actual_destination_cache
.write()
.await
.insert(
OwnedServerName::from(destination),
(actual_destination, host),
);
}
response.map_err(|e| {
warn!(error = %e, "Invalid 200 response",);
Error::BadServerResponse(
"Server returned bad 200 response.",
)
})
} else {
Err(Error::Federation(
destination.to_owned(),
RumaError::from_http_response(http_response),
))
}
}
Err(e) => {
warn!(
error = %e,
"Could not send request",
);
Err(e.into())
}
})?;
// reqwest::Response -> http::Response conversion
let status = response.status();
debug!(status = u16::from(status), "Received response");
let mut http_response_builder =
http::Response::builder().status(status).version(response.version());
mem::swap(
response.headers_mut(),
http_response_builder
.headers_mut()
.expect("http::response::Builder is usable"),
);
debug!("Getting response bytes");
// TODO: handle timeout
let body = response.bytes().await.unwrap_or_else(|error| {
warn!(%error, "Server error");
Vec::new().into()
});
debug!("Got response bytes");
if status != 200 {
warn!(
status = u16::from(status),
response =
dbg_truncate_str(String::from_utf8_lossy(&body).as_ref(), 100)
.into_owned(),
"Received error over federation",
);
}
let http_response = http_response_builder
.body(body)
.expect("reqwest body is valid http body");
if status != 200 {
return Err(Error::Federation(
destination.to_owned(),
RumaError::from_http_response(http_response),
));
}
debug!("Parsing response bytes");
let response = T::IncomingResponse::try_from_http_response(http_response);
if response.is_ok() && write_destination_to_cache {
METRICS.record_lookup(Lookup::FederationDestination, FoundIn::Remote);
services().globals.actual_destination_cache.write().await.insert(
OwnedServerName::from(destination),
(actual_destination, host),
);
}
response.map_err(|e| {
warn!(error = %e, "Invalid 200 response");
Error::BadServerResponse("Server returned bad 200 response.")
})
}
fn get_ip_with_port(destination_str: &str) -> Option<FedDest> {
@ -356,11 +359,11 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest {
/// Numbers in comments below refer to bullet points in linked section of
/// specification
#[allow(clippy::too_many_lines)]
#[tracing::instrument(ret(level = "debug"))]
#[tracing::instrument(skip(destination), ret(level = "debug"))]
async fn find_actual_destination(
destination: &'_ ServerName,
) -> (FedDest, FedDest) {
debug!("Finding actual destination");
debug!("Finding actual destination for {destination}");
let destination_str = destination.as_str().to_owned();
let mut hostname = destination_str.clone();
let actual_destination = match get_ip_with_port(&destination_str) {
@ -374,7 +377,7 @@ async fn find_actual_destination(
let (host, port) = destination_str.split_at(pos);
FedDest::Named(host.to_owned(), port.to_owned())
} else {
debug!(%destination, "Requesting well known");
debug!("Requesting well known for {destination}");
if let Some(delegated_hostname) =
request_well_known(destination.as_str()).await
{
@ -483,7 +486,7 @@ async fn find_actual_destination(
}
}
};
debug!(?actual_destination, "Resolved actual destination");
debug!("Actual destination: {actual_destination:?}");
// Can't use get_ip_with_port here because we don't want to add a port
// to an IP address if it wasn't specified
@ -540,12 +543,12 @@ async fn request_well_known(destination: &str) -> Option<String> {
let response = services()
.globals
.default_client()
.get(format!("https://{destination}/.well-known/matrix/server"))
.get(&format!("https://{destination}/.well-known/matrix/server"))
.send()
.await;
debug!("Got well known response");
if let Err(error) = &response {
debug!(%error, "Failed to request .well-known");
if let Err(e) = &response {
debug!("Well known error: {e:?}");
return None;
}
let text = response.ok()?.text().await;
@ -573,26 +576,23 @@ pub(crate) async fn get_server_version_route(
/// Gets the public signing keys of this server.
///
/// - Matrix does not support invalidating public keys, so the key returned by
/// this will be valid forever.
/// this will be valid
/// forever.
// Response type for this endpoint is Json because we need to calculate a
// signature for the response
pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> {
let keys: Vec<_> = [services().globals.keypair()]
.into_iter()
.chain(&services().globals.config.extra_key)
.collect();
let mut verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> =
BTreeMap::new();
for key in &keys {
verify_keys.insert(
format!("ed25519:{}", key.version())
.try_into()
.expect("found invalid server signing keys in DB"),
VerifyKey {
key: Base64::new(key.public_key().to_vec()),
},
);
}
verify_keys.insert(
format!("ed25519:{}", services().globals.keypair().version())
.try_into()
.expect("found invalid server signing keys in DB"),
VerifyKey {
key: Base64::new(
services().globals.keypair().public_key().to_vec(),
),
},
);
let mut response = serde_json::from_slice(
get_server_keys::v2::Response {
server_key: Raw::new(&ServerSigningKeys {
@ -613,14 +613,12 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> {
)
.unwrap();
for key in &keys {
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
*key,
&mut response,
)
.unwrap();
}
ruma::signatures::sign_json(
services().globals.server_name().as_str(),
services().globals.keypair(),
&mut response,
)
.unwrap();
Ok(Json(response))
}
@ -630,7 +628,8 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> {
/// Gets the public signing keys of this server.
///
/// - Matrix does not support invalidating public keys, so the key returned by
/// this will be valid forever.
/// this will be valid
/// forever.
pub(crate) async fn get_server_keys_deprecated_route() -> impl IntoResponse {
get_server_keys_route().await
}
@ -686,8 +685,8 @@ pub(crate) fn parse_incoming_pdu(
pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject =
serde_json::from_str(pdu.get()).map_err(|error| {
warn!(%error, object = ?pdu, "Error parsing incoming event");
serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
@ -729,8 +728,8 @@ pub(crate) async fn send_transaction_message_route(
for pdu in &body.pdus {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
.map_err(|error| {
warn!(%error, object = ?pdu, "Error parsing incoming event");
.map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
let room_id: OwnedRoomId = value
@ -742,26 +741,32 @@ pub(crate) async fn send_transaction_message_route(
))?;
if services().rooms.state.get_room_version(&room_id).is_err() {
debug!(%room_id, "This server is not in the room");
debug!("Server is not in room {room_id}");
continue;
}
let r = parse_incoming_pdu(pdu);
let (event_id, value, room_id) = match r {
Ok(t) => t,
Err(error) => {
warn!(%error, object = ?pdu, "Error parsing incoming event");
Err(e) => {
warn!("Could not parse PDU: {e}");
warn!("Full PDU: {:?}", &pdu);
continue;
}
};
// We do not add the event_id field to the pdu here because of signature
// and hashes checks
let federation_token = services()
.globals
.roomid_mutex_federation
.lock_key(room_id.clone())
.await;
let mutex = Arc::clone(
services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let mutex_lock = mutex.lock().await;
let start_time = Instant::now();
resolved_map.insert(
event_id.clone(),
@ -779,19 +784,21 @@ pub(crate) async fn send_transaction_message_route(
.await
.map(|_| ()),
);
drop(federation_token);
drop(mutex_lock);
let elapsed = start_time.elapsed();
debug!(
%event_id,
elapsed = ?start_time.elapsed(),
"Finished handling event",
"Handling transaction of event {} took {}m{}s",
event_id,
elapsed.as_secs() / 60,
elapsed.as_secs() % 60
);
}
for pdu in &resolved_map {
if let (event_id, Err(error)) = pdu {
if matches!(error, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!(%error, %event_id, "Incoming PDU failed");
if let Err(e) = pdu.1 {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {:?}", pdu);
}
}
}
@ -853,8 +860,8 @@ pub(crate) async fn send_transaction_message_route(
} else {
// TODO fetch missing events
debug!(
?user_updates,
"No known event ids in read receipt",
"No known event ids in read receipt: {:?}",
user_updates
);
}
}
@ -941,19 +948,16 @@ pub(crate) async fn send_transaction_message_route(
target_user_id,
target_device_id,
&ev_type.to_string(),
event.deserialize_as().map_err(
|error| {
warn!(
%error,
object = ?event.json(),
"To-Device event is invalid",
);
Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
},
)?,
event.deserialize_as().map_err(|e| {
warn!(
"To-Device event is invalid: \
{event:?} {e}"
);
Error::BadRequest(
ErrorKind::InvalidParam,
"Event is invalid",
)
})?,
)?,
DeviceIdOrAllDevices::AllDevices => {
@ -996,12 +1000,7 @@ pub(crate) async fn send_transaction_message_route(
self_signing_key,
}) => {
if user_id.server_name() != sender_servername {
warn!(
%user_id,
%sender_servername,
"Got signing key update from incorrect homeserver, \
ignoring",
);
warn!(%user_id, %sender_servername, "Got signing key update from incorrect homeserver, ignoring");
continue;
}
if let Some(master_key) = master_key {
@ -1041,7 +1040,7 @@ pub(crate) async fn get_event_route(
let event =
services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(
|| {
warn!(event_id = %body.event_id, "Event not found");
warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
},
)?;
@ -1094,7 +1093,7 @@ pub(crate) async fn get_backfill_route(
let sender_servername =
body.sender_servername.as_ref().expect("server is authenticated");
debug!(server = %sender_servername, "Got backfill request");
debug!("Got backfill request from: {}", sender_servername);
if !services()
.rooms
@ -1204,10 +1203,9 @@ pub(crate) async fn get_missing_events_route(
if event_room_id != body.room_id {
warn!(
event_id = %queued_events[i],
expected_room_id = %body.room_id,
actual_room_id = %event_room_id,
"Evil event detected"
"Evil event detected: Event {} found while searching in \
room {}",
queued_events[i], body.room_id
);
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
@ -1286,7 +1284,7 @@ pub(crate) async fn get_event_authorization_route(
let event =
services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(
|| {
warn!(event_id = %body.event_id, "Event not found");
warn!("Event not found, event ID: {:?}", &body.event_id);
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
},
)?;
@ -1371,13 +1369,13 @@ pub(crate) async fn get_room_state_route(
Ok(Ra(get_room_state::v1::Response {
auth_chain: auth_chain_ids
.filter_map(|event_id| {
.filter_map(|id| {
if let Some(json) =
services().rooms.timeline.get_pdu_json(&event_id).ok()?
services().rooms.timeline.get_pdu_json(&id).ok()?
{
Some(PduEvent::convert_to_outgoing_federation_event(json))
} else {
error!(%event_id, "Could not find event JSON for event");
error!("Could not find event json for {id} in db.");
None
}
})
@ -1462,11 +1460,16 @@ pub(crate) async fn create_join_event_template_route(
.event_handler
.acl_check(sender_servername, &body.room_id)?;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(body.room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(body.room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// TODO: Grapevine does not implement restricted join rules yet, we always
// reject
@ -1481,8 +1484,8 @@ pub(crate) async fn create_join_event_template_route(
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get()).map_err(
|error| {
warn!(%error, "Invalid join rules event");
|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
},
)
@ -1534,10 +1537,11 @@ pub(crate) async fn create_join_event_template_route(
redacts: None,
},
&body.user_id,
&room_token,
&body.room_id,
&state_lock,
)?;
drop(room_token);
drop(state_lock);
pdu_json.remove("event_id");
@ -1553,7 +1557,7 @@ async fn create_join_event(
sender_servername: &ServerName,
room_id: &RoomId,
pdu: &RawJsonValue,
) -> Result<create_join_event::v2::RoomState> {
) -> Result<create_join_event::v1::RoomState> {
if !services().rooms.metadata.exists(room_id)? {
return Err(Error::BadRequest(
ErrorKind::NotFound,
@ -1576,8 +1580,8 @@ async fn create_join_event(
.as_ref()
.map(|join_rules_event| {
serde_json::from_str(join_rules_event.content.get()).map_err(
|error| {
warn!(%error, "Invalid join rules event");
|e| {
warn!("Invalid join rules event: {}", e);
Error::bad_database("Invalid join rules event in db.")
},
)
@ -1629,12 +1633,17 @@ async fn create_join_event(
Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid.")
})?;
let federation_token = services()
.globals
.roomid_mutex_federation
.lock_key(room_id.to_owned())
.await;
let pdu_id = services()
let mutex = Arc::clone(
services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let mutex_lock = mutex.lock().await;
let pdu_id: Vec<u8> = services()
.rooms
.event_handler
.handle_incoming_pdu(
@ -1650,7 +1659,7 @@ async fn create_join_event(
ErrorKind::InvalidParam,
"Could not accept incoming PDU as timeline event.",
))?;
drop(federation_token);
drop(mutex_lock);
let state_ids =
services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
@ -1669,7 +1678,7 @@ async fn create_join_event(
services().sending.send_pdu(servers, &pdu_id)?;
Ok(create_join_event::v2::RoomState {
Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids
.filter_map(|id| {
services().rooms.timeline.get_pdu_json(&id).ok().flatten()
@ -1685,32 +1694,20 @@ async fn create_join_event(
.collect(),
// TODO: handle restricted joins
event: None,
members_omitted: false,
servers_in_room: None,
})
}
/// # `PUT /_matrix/federation/v1/send_join/{roomId}/{eventId}`
///
/// Submits a signed join event.
#[allow(deprecated)]
pub(crate) async fn create_join_event_v1_route(
body: Ar<create_join_event::v1::Request>,
) -> Result<Ra<create_join_event::v1::Response>> {
let sender_servername =
body.sender_servername.as_ref().expect("server is authenticated");
let create_join_event::v2::RoomState {
auth_chain,
state,
event,
..
} = create_join_event(sender_servername, &body.room_id, &body.pdu).await?;
let room_state = create_join_event::v1::RoomState {
auth_chain,
state,
event,
};
let room_state =
create_join_event(sender_servername, &body.room_id, &body.pdu).await?;
Ok(Ra(create_join_event::v1::Response {
room_state,
@ -1726,8 +1723,18 @@ pub(crate) async fn create_join_event_v2_route(
let sender_servername =
body.sender_servername.as_ref().expect("server is authenticated");
let room_state =
create_join_event(sender_servername, &body.room_id, &body.pdu).await?;
let create_join_event::v1::RoomState {
auth_chain,
state,
event,
} = create_join_event(sender_servername, &body.room_id, &body.pdu).await?;
let room_state = create_join_event::v2::RoomState {
members_omitted: false,
auth_chain,
state,
event,
servers_in_room: None,
};
Ok(Ra(create_join_event::v2::Response {
room_state,
@ -1837,19 +1844,15 @@ pub(crate) async fn create_invite_route(
event.insert("event_id".to_owned(), "$dummy".into());
let pdu: PduEvent =
serde_json::from_value(event.into()).map_err(|error| {
warn!(%error, "Invalid invite event");
Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.")
})?;
let pdu: PduEvent = serde_json::from_value(event.into()).map_err(|e| {
warn!("Invalid invite event: {}", e);
Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.")
})?;
invite_state.push(pdu.to_stripped_state_event());
// If we are active in the room, the remote server will notify us about the
// invite via m.room.member through /send. If we are not in the room, we
// need to manually record the invited state for clients' /sync through
// update_membership(), and send the invite pseudo-PDU to the affected
// appservices.
// join via /send
if !services()
.rooms
.state_cache
@ -1863,24 +1866,6 @@ pub(crate) async fn create_invite_route(
Some(invite_state),
true,
)?;
for appservice in services().appservice.read().await.values() {
if appservice.is_user_match(&invited_user) {
appservice_server::send_request(
appservice.registration.clone(),
ruma::api::appservice::event::push_events::v1::Request {
events: vec![pdu.to_room_event()],
txn_id:
base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(utils::calculate_hash([pdu
.event_id()
.as_bytes()]))
.into(),
},
)
.await?;
}
}
}
Ok(Ra(create_invite::v2::Response {
@ -2051,86 +2036,6 @@ pub(crate) async fn claim_keys_route(
}))
}
/// # `GET /_matrix/federation/v1/media/download/{mediaId}`
///
/// Downloads media owned by a remote homeserver.
pub(crate) async fn media_download_route(
body: Ar<authenticated_media::get_content::v1::Request>,
) -> Result<Ra<authenticated_media::get_content::v1::Response>> {
let mxc = MxcData::new(services().globals.server_name(), &body.media_id)?;
let Some((
crate::service::media::FileMeta {
content_disposition,
content_type,
},
file,
)) = services().media.get(mxc.to_string()).await?
else {
return Err(Error::BadRequest(
ErrorKind::NotYetUploaded,
"Media not found",
));
};
let content_disposition = content_disposition.and_then(|s| {
s.parse().inspect_err(
|error| warn!(%error, "Invalid Content-Disposition in database"),
)
.ok()
});
Ok(Ra(authenticated_media::get_content::v1::Response {
metadata: authenticated_media::ContentMetadata {},
content: authenticated_media::FileOrLocation::File(
authenticated_media::Content {
file,
content_type,
content_disposition,
},
),
}))
}
/// # `GET /_matrix/federation/v1/media/thumbnail/{mediaId}`
///
/// Downloads a thumbnail from a remote homeserver.
pub(crate) async fn media_thumbnail_route(
body: Ar<authenticated_media::get_content_thumbnail::v1::Request>,
) -> Result<Ra<authenticated_media::get_content_thumbnail::v1::Response>> {
let mxc = MxcData::new(services().globals.server_name(), &body.media_id)?;
let width = body.width.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid.")
})?;
let height = body.height.try_into().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid.")
})?;
let Some((
crate::service::media::FileMeta {
content_type,
..
},
file,
)) = services().media.get_thumbnail(mxc.to_string(), width, height).await?
else {
return Err(Error::BadRequest(
ErrorKind::NotYetUploaded,
"Media not found",
));
};
Ok(Ra(authenticated_media::get_content_thumbnail::v1::Response {
metadata: authenticated_media::ContentMetadata {},
content: authenticated_media::FileOrLocation::File(
authenticated_media::Content {
file,
content_type,
content_disposition: None,
},
),
}))
}
#[cfg(test)]
mod tests {
use super::{add_port_to_hostname, get_ip_with_port, FedDest};

View file

@ -1,64 +0,0 @@
#![warn(missing_docs, clippy::missing_docs_in_private_items)]
//! Handle requests for `/.well-known/matrix/...` files
use http::StatusCode;
use ruma::api::{
client::discovery::discover_homeserver as client,
federation::discovery::discover_homeserver as server,
};
use crate::{services, Ar, Ra};
/// Handler for `/.well-known/matrix/server`
pub(crate) async fn server(
_: Ar<server::Request>,
) -> Result<Ra<server::Response>, StatusCode> {
let Some(authority) =
services().globals.config.server_discovery.server.authority.clone()
else {
return Err(StatusCode::NOT_FOUND);
};
if authority == services().globals.config.server_name {
// Delegation isn't needed in this case
return Err(StatusCode::NOT_FOUND);
}
Ok(Ra(server::Response::new(authority)))
}
/// Handler for `/.well-known/matrix/client`
pub(crate) async fn client(_: Ar<client::Request>) -> Ra<client::Response> {
let authority = services()
.globals
.config
.server_discovery
.client
.authority
.clone()
.unwrap_or_else(|| services().globals.config.server_name.clone());
let scheme = if services().globals.config.server_discovery.client.insecure {
"http"
} else {
"https"
};
let base_url = format!("{scheme}://{authority}");
// I wish ruma used an actual URL type instead of `String`
Ra(client::Response {
homeserver: client::HomeserverInfo::new(base_url.clone()),
identity_server: None,
sliding_sync_proxy: services()
.globals
.config
.server_discovery
.client
.advertise_sliding_sync
.then_some(client::SlidingSyncProxyInfo {
url: base_url,
}),
})
}

View file

@ -1,20 +1,19 @@
use std::{
borrow::Cow,
fmt::{self, Display},
net::{IpAddr, Ipv4Addr},
path::{Path, PathBuf},
};
use once_cell::sync::Lazy;
use ruma::{serde::Base64, signatures::Ed25519KeyPair, OwnedServerName, RoomVersionId};
use serde::{Deserialize, Deserializer};
use ruma::{OwnedServerName, RoomVersionId};
use serde::Deserialize;
use crate::error;
mod env_filter_clone;
mod proxy;
pub(crate) use env_filter_clone::EnvFilterClone;
use env_filter_clone::EnvFilterClone;
use proxy::ProxyConfig;
/// The default configuration file path
@ -26,301 +25,79 @@ pub(crate) static DEFAULT_PATH: Lazy<PathBuf> =
pub(crate) struct Config {
#[serde(default = "false_fn")]
pub(crate) conduit_compat: bool,
#[serde(default = "default_listen")]
pub(crate) listen: Vec<ListenConfig>,
#[serde(default = "default_address")]
pub(crate) address: IpAddr,
#[serde(default = "default_port")]
pub(crate) port: u16,
pub(crate) tls: Option<TlsConfig>,
#[serde(default, deserialize_with = "deserialize_keys_config")]
pub(crate) extra_key: Vec<Ed25519KeyPair>,
/// The name of this homeserver
///
/// This is the value that will appear e.g. in user IDs and room aliases.
pub(crate) server_name: OwnedServerName,
#[serde(default)]
pub(crate) server_discovery: ServerDiscovery,
pub(crate) database: DatabaseConfig,
#[serde(default)]
pub(crate) federation: FederationConfig,
pub(crate) database_backend: String,
pub(crate) database_path: String,
#[cfg(feature = "rocksdb")]
#[serde(default = "default_db_cache_capacity_mb")]
pub(crate) db_cache_capacity_mb: f64,
#[serde(default = "default_cache_capacity_modifier")]
pub(crate) cache_capacity_modifier: f64,
#[cfg(feature = "rocksdb")]
#[serde(default = "default_rocksdb_max_open_files")]
pub(crate) rocksdb_max_open_files: i32,
#[serde(default = "default_pdu_cache_capacity")]
pub(crate) pdu_cache_capacity: u32,
#[serde(default = "default_cleanup_second_interval")]
pub(crate) cleanup_second_interval: u32,
#[serde(default = "default_max_request_size")]
pub(crate) max_request_size: u32,
#[serde(default = "default_max_concurrent_requests")]
pub(crate) max_concurrent_requests: u16,
#[serde(default = "default_max_fetch_prev_events")]
pub(crate) max_fetch_prev_events: u16,
#[serde(default = "false_fn")]
pub(crate) allow_registration: bool,
pub(crate) registration_token: Option<String>,
#[serde(default = "true_fn")]
pub(crate) allow_encryption: bool,
#[serde(default = "true_fn")]
pub(crate) allow_federation: bool,
#[serde(default = "true_fn")]
pub(crate) allow_room_creation: bool,
#[serde(default = "true_fn")]
pub(crate) allow_unstable_room_versions: bool,
#[serde(default = "default_default_room_version")]
pub(crate) default_room_version: RoomVersionId,
#[serde(default = "false_fn")]
pub(crate) allow_jaeger: bool,
#[serde(default = "false_fn")]
pub(crate) allow_prometheus: bool,
#[serde(default = "false_fn")]
pub(crate) tracing_flame: bool,
#[serde(default)]
pub(crate) proxy: ProxyConfig,
pub(crate) jwt_secret: Option<String>,
#[serde(default = "default_trusted_servers")]
pub(crate) trusted_servers: Vec<OwnedServerName>,
#[serde(default = "default_log")]
pub(crate) log: EnvFilterClone,
#[serde(default)]
pub(crate) observability: ObservabilityConfig,
pub(crate) turn_username: String,
#[serde(default)]
pub(crate) turn: TurnConfig,
pub(crate) turn_password: String,
#[serde(default = "Vec::new")]
pub(crate) turn_uris: Vec<String>,
#[serde(default)]
pub(crate) turn_secret: String,
#[serde(default = "default_turn_ttl")]
pub(crate) turn_ttl: u64,
pub(crate) emergency_password: Option<String>,
}
fn deserialize_keys_config<'de, D>(de: D) -> Result<Vec<Ed25519KeyPair>, D::Error> where D: Deserializer<'de> {
use serde::de::Error;
#[derive(Debug, Deserialize)]
struct RawConfig {
key: Base64,
version: String,
}
let raw: Vec<RawConfig> = Deserialize::deserialize(de)?;
raw
.into_iter()
.map(|r| Ed25519KeyPair::from_der(&r.key.into_inner(), r.version).map_err(D::Error::custom))
.collect()
}
#[derive(Debug, Default, Deserialize)]
pub(crate) struct ServerDiscovery {
/// Server-server discovery configuration
#[serde(default)]
pub(crate) server: ServerServerDiscovery,
/// Client-server discovery configuration
#[serde(default)]
pub(crate) client: ClientServerDiscovery,
}
/// Server-server discovery configuration
#[derive(Debug, Default, Deserialize)]
pub(crate) struct ServerServerDiscovery {
/// The alternative authority to make server-server API requests to
pub(crate) authority: Option<OwnedServerName>,
}
/// Client-server discovery configuration
#[derive(Debug, Default, Deserialize)]
pub(crate) struct ClientServerDiscovery {
/// The alternative authority to make client-server API requests to
pub(crate) authority: Option<OwnedServerName>,
/// Controls whether HTTPS is used
#[serde(default)]
pub(crate) insecure: bool,
#[serde(default, rename = "advertise_buggy_sliding_sync")]
pub(crate) advertise_sliding_sync: bool,
}
#[derive(Debug, Deserialize)]
pub(crate) struct TlsConfig {
pub(crate) certs: String,
pub(crate) key: String,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum ListenConfig {
Tcp {
#[serde(default = "default_address")]
address: IpAddr,
#[serde(default = "default_port")]
port: u16,
#[serde(default = "false_fn")]
tls: bool,
},
}
impl Display for ListenConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ListenConfig::Tcp {
address,
port,
tls: false,
} => write!(f, "http://{address}:{port}"),
ListenConfig::Tcp {
address,
port,
tls: true,
} => write!(f, "https://{address}:{port}"),
}
}
}
#[derive(Copy, Clone, Default, Debug, Deserialize)]
#[serde(rename_all = "snake_case")]
pub(crate) enum LogFormat {
/// Use the [`tracing_subscriber::fmt::format::Pretty`] formatter
Pretty,
/// Use the [`tracing_subscriber::fmt::format::Full`] formatter
#[default]
Full,
/// Use the [`tracing_subscriber::fmt::format::Compact`] formatter
Compact,
/// Use the [`tracing_subscriber::fmt::format::Json`] formatter
Json,
}
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub(crate) struct TurnConfig {
pub(crate) username: String,
pub(crate) password: String,
pub(crate) uris: Vec<String>,
pub(crate) secret: String,
pub(crate) ttl: u64,
}
impl Default for TurnConfig {
fn default() -> Self {
Self {
username: String::new(),
password: String::new(),
uris: Vec::new(),
secret: String::new(),
ttl: 60 * 60 * 24,
}
}
}
#[derive(Clone, Copy, Debug, Deserialize)]
#[serde(rename_all = "lowercase")]
pub(crate) enum DatabaseBackend {
#[cfg(feature = "rocksdb")]
Rocksdb,
#[cfg(feature = "sqlite")]
Sqlite,
}
impl Display for DatabaseBackend {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
#[cfg(feature = "rocksdb")]
DatabaseBackend::Rocksdb => write!(f, "RocksDB"),
#[cfg(feature = "sqlite")]
DatabaseBackend::Sqlite => write!(f, "SQLite"),
}
}
}
#[derive(Clone, Debug, Deserialize)]
pub(crate) struct DatabaseConfig {
pub(crate) backend: DatabaseBackend,
pub(crate) path: String,
#[serde(default = "default_db_cache_capacity_mb")]
pub(crate) cache_capacity_mb: f64,
#[cfg(feature = "rocksdb")]
#[serde(default = "default_rocksdb_max_open_files")]
pub(crate) rocksdb_max_open_files: i32,
}
#[derive(Clone, Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct MetricsConfig {
pub(crate) enable: bool,
}
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct OtelTraceConfig {
pub(crate) enable: bool,
pub(crate) filter: EnvFilterClone,
pub(crate) endpoint: Option<String>,
pub(crate) service_name: String,
}
impl Default for OtelTraceConfig {
fn default() -> Self {
Self {
enable: false,
filter: default_tracing_filter(),
endpoint: None,
service_name: env!("CARGO_PKG_NAME").to_owned(),
}
}
}
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct FlameConfig {
pub(crate) enable: bool,
pub(crate) filter: EnvFilterClone,
pub(crate) filename: String,
}
impl Default for FlameConfig {
fn default() -> Self {
Self {
enable: false,
filter: default_tracing_filter(),
filename: "./tracing.folded".to_owned(),
}
}
}
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct LogConfig {
pub(crate) filter: EnvFilterClone,
pub(crate) colors: bool,
pub(crate) format: LogFormat,
pub(crate) timestamp: bool,
}
impl Default for LogConfig {
fn default() -> Self {
Self {
filter: default_tracing_filter(),
colors: true,
format: LogFormat::default(),
timestamp: true,
}
}
}
#[derive(Debug, Default, Deserialize)]
#[serde(default)]
pub(crate) struct ObservabilityConfig {
/// Prometheus metrics
pub(crate) metrics: MetricsConfig,
/// OpenTelemetry traces
pub(crate) traces: OtelTraceConfig,
/// Folded inferno stack traces
pub(crate) flame: FlameConfig,
/// Logging to stdout
pub(crate) logs: LogConfig,
}
#[derive(Debug, Deserialize)]
#[serde(default)]
pub(crate) struct FederationConfig {
pub(crate) enable: bool,
pub(crate) trusted_servers: Vec<OwnedServerName>,
pub(crate) max_fetch_prev_events: u16,
pub(crate) max_concurrent_requests: u16,
}
impl Default for FederationConfig {
fn default() -> Self {
Self {
enable: true,
trusted_servers: vec![
OwnedServerName::try_from("matrix.org").unwrap()
],
max_fetch_prev_events: 100,
max_concurrent_requests: 100,
}
}
}
fn false_fn() -> bool {
false
}
@ -329,14 +106,6 @@ fn true_fn() -> bool {
true
}
fn default_listen() -> Vec<ListenConfig> {
vec![ListenConfig::Tcp {
address: default_address(),
port: default_port(),
tls: false,
}]
}
fn default_address() -> IpAddr {
Ipv4Addr::LOCALHOST.into()
}
@ -345,6 +114,7 @@ fn default_port() -> u16 {
6167
}
#[cfg(feature = "rocksdb")]
fn default_db_cache_capacity_mb() -> f64 {
300.0
}
@ -372,12 +142,28 @@ fn default_max_request_size() -> u32 {
20 * 1024 * 1024
}
fn default_tracing_filter() -> EnvFilterClone {
"info,ruma_state_res=warn"
fn default_max_concurrent_requests() -> u16 {
100
}
fn default_max_fetch_prev_events() -> u16 {
100_u16
}
fn default_trusted_servers() -> Vec<OwnedServerName> {
vec![OwnedServerName::try_from("matrix.org").unwrap()]
}
fn default_log() -> EnvFilterClone {
"warn,state_res=warn,_=off"
.parse()
.expect("hardcoded env filter should be valid")
}
fn default_turn_ttl() -> u64 {
60 * 60 * 24
}
// I know, it's a great name
pub(crate) fn default_default_room_version() -> RoomVersionId {
RoomVersionId::V10

View file

@ -25,17 +25,8 @@ use ruma::{
use tracing::{debug, error, info, info_span, warn, Instrument};
use crate::{
config::DatabaseBackend,
observability::FilterReloadHandles,
service::{
media::MediaFileKey,
rooms::{
short::{ShortEventId, ShortStateHash, ShortStateKey},
state_compressor::CompressedStateEvent,
timeline::PduCount,
},
},
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
service::rooms::timeline::PduCount, services, utils, Config, Error,
PduEvent, Result, Services, SERVICES,
};
pub(crate) struct KeyValueDatabase {
@ -243,14 +234,13 @@ pub(crate) struct KeyValueDatabase {
// Uncategorized trees
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<ShortEventId, Arc<EventId>>>,
pub(super) auth_chain_cache:
Mutex<LruCache<Vec<ShortEventId>, Arc<HashSet<ShortEventId>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, ShortEventId>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<OwnedEventId, u64>>,
pub(super) statekeyshort_cache:
Mutex<LruCache<(StateEventType, String), ShortStateKey>>,
Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache:
Mutex<LruCache<ShortStateKey, (StateEventType, String)>>,
Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache:
RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache:
@ -260,7 +250,7 @@ pub(crate) struct KeyValueDatabase {
impl KeyValueDatabase {
fn check_db_setup(config: &Config) -> Result<()> {
let path = Path::new(&config.database.path);
let path = Path::new(&config.database_path);
let sqlite_exists = path
.join(format!(
@ -289,22 +279,14 @@ impl KeyValueDatabase {
return Ok(());
}
let (backend_is_rocksdb, backend_is_sqlite): (bool, bool) =
match config.database.backend {
#[cfg(feature = "rocksdb")]
DatabaseBackend::Rocksdb => (true, false),
#[cfg(feature = "sqlite")]
DatabaseBackend::Sqlite => (false, true),
};
if sqlite_exists && !backend_is_sqlite {
if sqlite_exists && config.database_backend != "sqlite" {
return Err(Error::bad_config(
"Found sqlite at database_path, but is not specified in \
config.",
));
}
if rocksdb_exists && !backend_is_rocksdb {
if rocksdb_exists && config.database_backend != "rocksdb" {
return Err(Error::bad_config(
"Found rocksdb at database_path, but is not specified in \
config.",
@ -320,14 +302,11 @@ impl KeyValueDatabase {
allow(unreachable_code)
)]
#[allow(clippy::too_many_lines)]
pub(crate) async fn load_or_create(
config: Config,
reload_handles: FilterReloadHandles,
) -> Result<()> {
pub(crate) async fn load_or_create(config: Config) -> Result<()> {
Self::check_db_setup(&config)?;
if !Path::new(&config.database.path).exists() {
fs::create_dir_all(&config.database.path).map_err(|_| {
if !Path::new(&config.database_path).exists() {
fs::create_dir_all(&config.database_path).map_err(|_| {
Error::BadConfig(
"Database folder doesn't exists and couldn't be created \
(e.g. due to missing permissions). Please create the \
@ -340,18 +319,20 @@ impl KeyValueDatabase {
not(any(feature = "rocksdb", feature = "sqlite")),
allow(unused_variables)
)]
let builder: Arc<dyn KeyValueDatabaseEngine> = match config
.database
.backend
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config
.database_backend
{
#[cfg(feature = "sqlite")]
DatabaseBackend::Sqlite => {
"sqlite" => {
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
#[cfg(feature = "rocksdb")]
DatabaseBackend::Rocksdb => {
"rocksdb" => {
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
_ => {
return Err(Error::BadConfig("Database backend not found."));
}
};
if config.registration_token == Some(String::new()) {
@ -359,10 +340,7 @@ impl KeyValueDatabase {
}
if config.max_request_size < 1024 {
error!(
?config.max_request_size,
"Max request size is less than 1KB. Please increase it.",
);
error!(?config.max_request_size, "Max request size is less than 1KB. Please increase it.");
}
let db_raw = Box::new(Self {
@ -543,8 +521,7 @@ impl KeyValueDatabase {
let db = Box::leak(db_raw);
let services_raw =
Box::new(Services::build(db, config, reload_handles)?);
let services_raw = Box::new(Services::build(db, config)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
@ -552,11 +529,21 @@ impl KeyValueDatabase {
// Matrix resource ownership is based on the server name; changing it
// requires recreating the database from scratch.
if services().users.count()? > 0 {
let admin_bot = services().globals.admin_bot_user_id.as_ref();
if !services().users.exists(admin_bot)? {
let grapevine_user = UserId::parse_with_server_name(
if services().globals.config.conduit_compat {
"conduit"
} else {
"grapevine"
},
services().globals.server_name(),
)
.expect("admin bot username should be valid");
if !services().users.exists(&grapevine_user)? {
error!(
user_id = %admin_bot,
"The admin bot does not exist and the database is not new",
"The {} server user does not exist, and the database is \
not new.",
grapevine_user
);
return Err(Error::bad_database(
"Cannot reuse an existing database after changing the \
@ -615,7 +602,6 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 3 {
// Move media to filesystem
for (key, content) in db.mediaid_file.iter() {
let key = MediaFileKey::new(key);
if content.is_empty() {
continue;
}
@ -623,7 +609,7 @@ impl KeyValueDatabase {
let path = services().globals.get_media_file(&key);
let mut file = fs::File::create(path)?;
file.write_all(&content)?;
db.mediaid_file.insert(key.as_bytes(), &[])?;
db.mediaid_file.insert(&key, &[])?;
}
services().globals.bump_database_version(3)?;
@ -703,15 +689,15 @@ impl KeyValueDatabase {
if services().globals.database_version()? < 7 {
// Upgrade state store
let mut last_roomstates: HashMap<OwnedRoomId, ShortStateHash> =
let mut last_roomstates: HashMap<OwnedRoomId, u64> =
HashMap::new();
let mut current_sstatehash: Option<ShortStateHash> = None;
let mut current_sstatehash: Option<u64> = None;
let mut current_room = None;
let mut current_state = HashSet::new();
let mut counter = 0;
let mut handle_state =
|current_sstatehash: ShortStateHash,
|current_sstatehash: u64,
current_room: &RoomId,
current_state: HashSet<_>,
last_roomstates: &mut HashMap<_, _>| {
@ -770,14 +756,10 @@ impl KeyValueDatabase {
for (k, seventid) in
db.db.open_tree("stateid_shorteventid")?.iter()
{
let sstatehash = ShortStateHash::new(
let sstatehash =
utils::u64_from_bytes(&k[0..size_of::<u64>()])
.expect("number of bytes is correct"),
);
let sstatekey = ShortStateKey::new(
utils::u64_from_bytes(&k[size_of::<u64>()..])
.expect("number of bytes is correct"),
);
.expect("number of bytes is correct");
let sstatekey = k[size_of::<u64>()..].to_vec();
if Some(sstatehash) != current_sstatehash {
if let Some(current_sstatehash) = current_sstatehash {
handle_state(
@ -815,14 +797,10 @@ impl KeyValueDatabase {
}
}
let seventid = ShortEventId::new(
utils::u64_from_bytes(&seventid)
.expect("number of bytes is correct"),
);
current_state.insert(CompressedStateEvent {
state: sstatekey,
event: seventid,
});
let mut val = sstatekey;
val.extend_from_slice(&seventid);
current_state
.insert(val.try_into().expect("size is correct"));
}
if let Some(current_sstatehash) = current_sstatehash {
@ -982,12 +960,8 @@ impl KeyValueDatabase {
services().globals.server_name(),
) {
Ok(u) => u,
Err(error) => {
warn!(
%error,
user_localpart = %username,
"Invalid username",
);
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
}
};
@ -1087,12 +1061,8 @@ impl KeyValueDatabase {
services().globals.server_name(),
) {
Ok(u) => u,
Err(error) => {
warn!(
%error,
user_localpart = %username,
"Invalid username",
);
Err(e) => {
warn!("Invalid username {username}: {e}");
continue;
}
};
@ -1144,9 +1114,9 @@ impl KeyValueDatabase {
);
info!(
backend = %services().globals.config.database.backend,
version = latest_database_version,
"Loaded database",
"Loaded {} database with version {}",
services().globals.config.database_backend,
latest_database_version
);
} else {
services()
@ -1156,10 +1126,10 @@ impl KeyValueDatabase {
// Create the admin room and server user on first run
services().admin.create_admin_room().await?;
info!(
backend = %services().globals.config.database.backend,
version = latest_database_version,
"Created new database",
warn!(
"Created new {} database with version {}",
services().globals.config.database_backend,
latest_database_version
);
}
@ -1183,11 +1153,11 @@ impl KeyValueDatabase {
);
}
}
Err(error) => {
Err(e) => {
error!(
%error,
"Could not set the configured emergency password for the \
Grapevine user",
grapevine user: {}",
e
);
}
};
@ -1235,10 +1205,10 @@ impl KeyValueDatabase {
async {
msg();
let start = Instant::now();
if let Err(error) = services().globals.cleanup() {
error!(%error, "cleanup: Error");
if let Err(e) = services().globals.cleanup() {
error!("cleanup: Errored: {}", e);
} else {
debug!(elapsed = ?start.elapsed(), "cleanup: Finished");
debug!("cleanup: Finished in {:?}", start.elapsed());
}
}
.instrument(info_span!("database_cleanup"))
@ -1251,21 +1221,25 @@ impl KeyValueDatabase {
/// Sets the emergency password and push rules for the @grapevine account in
/// case emergency password is set
fn set_emergency_access() -> Result<bool> {
let admin_bot = services().globals.admin_bot_user_id.as_ref();
let grapevine_user = UserId::parse_with_server_name(
"grapevine",
services().globals.server_name(),
)
.expect("@grapevine:server_name is a valid UserId");
services().users.set_password(
admin_bot,
&grapevine_user,
services().globals.emergency_password().as_deref(),
)?;
let (ruleset, res) = match services().globals.emergency_password() {
Some(_) => (Ruleset::server_default(admin_bot), Ok(true)),
Some(_) => (Ruleset::server_default(&grapevine_user), Ok(true)),
None => (Ruleset::new(), Ok(false)),
};
services().account_data.update(
None,
admin_bot,
&grapevine_user,
GlobalAccountDataEventType::PushRules.to_string().into(),
&serde_json::to_value(&GlobalAccountDataEvent {
content: PushRulesEventContent {

View file

@ -11,7 +11,6 @@ use rocksdb::{
DBRecoveryMode, DBWithThreadMode, Direction, IteratorMode, MultiThreaded,
Options, ReadOptions, WriteOptions,
};
use tracing::Level;
use super::{
super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree,
@ -78,36 +77,32 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_possible_truncation
)]
let cache_capacity_bytes =
(config.database.cache_capacity_mb * 1024.0 * 1024.0) as usize;
(config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let rocksdb_cache = Cache::new_lru_cache(cache_capacity_bytes);
let db_opts =
db_options(config.database.rocksdb_max_open_files, &rocksdb_cache);
let db_opts = db_options(config.rocksdb_max_open_files, &rocksdb_cache);
let cfs = DBWithThreadMode::<MultiThreaded>::list_cf(
&db_opts,
&config.database.path,
&config.database_path,
)
.map(|x| x.into_iter().collect::<HashSet<_>>())
.unwrap_or_default();
let db = DBWithThreadMode::<MultiThreaded>::open_cf_descriptors(
&db_opts,
&config.database.path,
&config.database_path,
cfs.iter().map(|name| {
ColumnFamilyDescriptor::new(
name,
db_options(
config.database.rocksdb_max_open_files,
&rocksdb_cache,
),
db_options(config.rocksdb_max_open_files, &rocksdb_cache),
)
}),
)?;
Ok(Arc::new(Engine {
rocks: db,
max_open_files: config.database.rocksdb_max_open_files,
max_open_files: config.rocksdb_max_open_files,
cache: rocksdb_cache,
old_cfs: cfs,
new_cfs: Mutex::default(),
@ -170,14 +165,12 @@ impl RocksDbEngineTree<'_> {
}
impl KvTree for RocksDbEngineTree<'_> {
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
let readoptions = ReadOptions::default();
Ok(self.db.rocks.get_cf_opt(&self.cf(), key, &readoptions)?)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default();
let lock = self.write_lock.read().unwrap();
@ -189,7 +182,6 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(())
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn insert_batch(
&self,
iter: &mut dyn Iterator<Item = (Vec<u8>, Vec<u8>)>,
@ -202,13 +194,11 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(())
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn remove(&self, key: &[u8]) -> Result<()> {
let writeoptions = WriteOptions::default();
Ok(self.db.rocks.delete_cf_opt(&self.cf(), key, &writeoptions)?)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
let readoptions = ReadOptions::default();
@ -221,7 +211,6 @@ impl KvTree for RocksDbEngineTree<'_> {
)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn iter_from<'a>(
&'a self,
from: &[u8],
@ -249,7 +238,6 @@ impl KvTree for RocksDbEngineTree<'_> {
)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn increment(&self, key: &[u8]) -> Result<Vec<u8>> {
let readoptions = ReadOptions::default();
let writeoptions = WriteOptions::default();
@ -264,7 +252,6 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(new)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn increment_batch(
&self,
iter: &mut dyn Iterator<Item = Vec<u8>>,
@ -286,7 +273,6 @@ impl KvTree for RocksDbEngineTree<'_> {
Ok(())
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn scan_prefix<'a>(
&'a self,
prefix: Vec<u8>,
@ -307,7 +293,6 @@ impl KvTree for RocksDbEngineTree<'_> {
)
}
#[tracing::instrument(level = Level::TRACE, skip_all)]
fn watch_prefix<'a>(
&'a self,
prefix: &[u8],

View file

@ -110,7 +110,7 @@ impl Engine {
impl KeyValueDatabaseEngine for Arc<Engine> {
fn open(config: &Config) -> Result<Self> {
let path = Path::new(&config.database.path).join(format!(
let path = Path::new(&config.database_path).join(format!(
"{}.db",
if config.conduit_compat {
"conduit"
@ -130,9 +130,9 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
clippy::cast_precision_loss,
clippy::cast_sign_loss
)]
let cache_size_per_thread =
((config.database.cache_capacity_mb * 1024.0)
/ ((num_cpus::get() as f64 * 2.0) + 1.0)) as u32;
let cache_size_per_thread = ((config.db_cache_capacity_mb * 1024.0)
/ ((num_cpus::get() as f64 * 2.0) + 1.0))
as u32;
let writer =
Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);

View file

@ -74,7 +74,6 @@ impl service::globals::Data for KeyValueDatabase {
.ok()
.flatten()
.expect("room exists")
.get()
.to_be_bytes()
.to_vec();

View file

@ -1,13 +1,6 @@
use ruma::api::client::error::ErrorKind;
use crate::{
database::KeyValueDatabase,
service::{
self,
media::{FileMeta, MediaFileKey},
},
utils, Error, Result,
};
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(
@ -15,30 +8,26 @@ impl service::media::Data for KeyValueDatabase {
mxc: String,
width: u32,
height: u32,
meta: &FileMeta,
) -> Result<MediaFileKey> {
content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xFF);
key.extend_from_slice(
meta.content_disposition
content_disposition
.as_ref()
.map(String::as_bytes)
.map(|f| f.as_bytes())
.unwrap_or_default(),
);
key.push(0xFF);
key.extend_from_slice(
meta.content_type
.as_ref()
.map(String::as_bytes)
.unwrap_or_default(),
content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default(),
);
let key = MediaFileKey::new(key);
self.mediaid_file.insert(key.as_bytes(), &[])?;
self.mediaid_file.insert(&key, &[])?;
Ok(key)
}
@ -48,7 +37,7 @@ impl service::media::Data for KeyValueDatabase {
mxc: String,
width: u32,
height: u32,
) -> Result<(FileMeta, MediaFileKey)> {
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
let mut prefix = mxc.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(&width.to_be_bytes());
@ -60,9 +49,7 @@ impl service::media::Data for KeyValueDatabase {
Error::BadRequest(ErrorKind::NotFound, "Media not found"),
)?;
let key = MediaFileKey::new(key);
let mut parts = key.as_bytes().rsplit(|&b| b == 0xFF);
let mut parts = key.rsplit(|&b| b == 0xFF);
let content_type = parts
.next()
@ -91,12 +78,6 @@ impl service::media::Data for KeyValueDatabase {
},
)?)
};
Ok((
FileMeta {
content_disposition,
content_type,
},
key,
))
Ok((content_disposition, content_type, key))
}
}

View file

@ -3,16 +3,15 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{
database::KeyValueDatabase,
observability::{FoundIn, Lookup, METRICS},
service::{self, rooms::short::ShortEventId},
utils, Result,
service, utils, Result,
};
impl service::rooms::auth_chain::Data for KeyValueDatabase {
#[tracing::instrument(skip(self, key))]
fn get_cached_eventid_authchain(
&self,
key: &[ShortEventId],
) -> Result<Option<Arc<HashSet<ShortEventId>>>> {
key: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>> {
let lookup = Lookup::AuthChain;
// Check RAM cache
@ -27,15 +26,13 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].get().to_be_bytes())?
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| {
ShortEventId::new(
utils::u64_from_bytes(chunk)
.expect("byte length is correct"),
)
utils::u64_from_bytes(chunk)
.expect("byte length is correct")
})
.collect()
});
@ -60,16 +57,16 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn cache_auth_chain(
&self,
key: Vec<ShortEventId>,
auth_chain: Arc<HashSet<ShortEventId>>,
key: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()> {
// Only persist single events in db
if key.len() == 1 {
self.shorteventid_authchain.insert(
&key[0].get().to_be_bytes(),
&key[0].to_be_bytes(),
&auth_chain
.iter()
.flat_map(|s| s.get().to_be_bytes().to_vec())
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
}

View file

@ -1,3 +1,5 @@
use std::mem;
use ruma::{
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject,
OwnedUserId, RoomId, UserId,
@ -81,7 +83,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count = utils::u64_from_bytes(
&k[prefix.len()..prefix.len() + size_of::<u64>()],
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
)
.map_err(|_| {
Error::bad_database(
@ -90,7 +92,7 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
})?;
let user_id = UserId::parse(
utils::string_from_bytes(
&k[prefix.len() + size_of::<u64>() + 1..],
&k[prefix.len() + mem::size_of::<u64>() + 1..],
)
.map_err(|_| {
Error::bad_database(

View file

@ -8,7 +8,7 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.get().to_be_bytes().to_vec(),
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};

View file

@ -1,16 +1,10 @@
use std::sync::Arc;
use std::{mem, sync::Arc};
use ruma::{EventId, RoomId, UserId};
use crate::{
database::KeyValueDatabase,
service::{
self,
rooms::{
short::ShortRoomId,
timeline::{PduCount, PduId},
},
},
service::{self, rooms::timeline::PduCount},
services, utils, Error, PduEvent, Result,
};
@ -25,7 +19,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn relations_until<'a>(
&'a self,
user_id: &'a UserId,
shortroomid: ShortRoomId,
shortroomid: u64,
target: u64,
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>
@ -47,17 +41,15 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from =
utils::u64_from_bytes(&tofrom[(size_of::<u64>())..])
.map_err(|_| {
Error::bad_database(
"Invalid count in tofrom_relation.",
)
})?;
let from = utils::u64_from_bytes(
&tofrom[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database("Invalid count in tofrom_relation.")
})?;
let mut pduid = shortroomid.get().to_be_bytes().to_vec();
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let pduid = PduId::new(pduid);
let mut pdu = services()
.rooms

View file

@ -1,13 +1,6 @@
use ruma::RoomId;
use crate::{
database::KeyValueDatabase,
service::{
self,
rooms::{short::ShortRoomId, timeline::PduId},
},
services, utils, Result,
};
use crate::{database::KeyValueDatabase, service, services, utils, Result};
/// Splits a string into tokens used as keys in the search inverted index
///
@ -24,16 +17,16 @@ impl service::rooms::search::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn index_pdu(
&self,
shortroomid: ShortRoomId,
pdu_id: &PduId,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()> {
let mut batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.get().to_be_bytes().to_vec();
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
// TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id.as_bytes());
key.extend_from_slice(pdu_id);
(key, Vec::new())
});
@ -43,16 +36,16 @@ impl service::rooms::search::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn deindex_pdu(
&self,
shortroomid: ShortRoomId,
pdu_id: &PduId,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()> {
let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.get().to_be_bytes().to_vec();
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
// TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id.as_bytes());
key.extend_from_slice(pdu_id);
key
});
@ -69,14 +62,13 @@ impl service::rooms::search::Data for KeyValueDatabase {
&'a self,
room_id: &RoomId,
search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = PduId> + 'a>, Vec<String>)>>
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>
{
let prefix = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get()
.to_be_bytes()
.to_vec();
@ -95,14 +87,12 @@ impl service::rooms::search::Data for KeyValueDatabase {
// Newest pdus first
.iter_from(&last_possible_id, true)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| PduId::new(key[prefix3.len()..].to_vec()))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
// We compare b with a because we reversed the iterator earlier
let Some(common_elements) =
utils::common_elements(iterators, |a, b| {
b.as_bytes().cmp(a.as_bytes())
})
utils::common_elements(iterators, |a, b| b.cmp(a))
else {
return Ok(None);
};

View file

@ -5,21 +5,12 @@ use ruma::{events::StateEventType, EventId, RoomId};
use crate::{
database::KeyValueDatabase,
observability::{FoundIn, Lookup, METRICS},
service::{
self,
rooms::short::{
ShortEventId, ShortRoomId, ShortStateHash, ShortStateKey,
},
},
services, utils, Error, Result,
service, services, utils, Error, Result,
};
impl service::rooms::short::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn get_or_create_shorteventid(
&self,
event_id: &EventId,
) -> Result<ShortEventId> {
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let lookup = Lookup::CreateEventIdToShort;
if let Some(short) =
@ -48,8 +39,6 @@ impl service::rooms::short::Data for KeyValueDatabase {
shorteventid
};
let short = ShortEventId::new(short);
self.eventidshort_cache
.lock()
.unwrap()
@ -63,7 +52,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<ShortStateKey>> {
) -> Result<Option<u64>> {
let lookup = Lookup::StateKeyToShort;
if let Some(short) = self
@ -84,11 +73,9 @@ impl service::rooms::short::Data for KeyValueDatabase {
.statekey_shortstatekey
.get(&db_key)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| {
Error::bad_database("Invalid shortstatekey in db.")
})
.map(ShortStateKey::new)
utils::u64_from_bytes(&shortstatekey).map_err(|_| {
Error::bad_database("Invalid shortstatekey in db.")
})
})
.transpose()?;
@ -111,7 +98,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<ShortStateKey> {
) -> Result<u64> {
let lookup = Lookup::CreateStateKeyToShort;
if let Some(short) = self
@ -147,8 +134,6 @@ impl service::rooms::short::Data for KeyValueDatabase {
shortstatekey
};
let short = ShortStateKey::new(short);
self.statekeyshort_cache
.lock()
.unwrap()
@ -160,7 +145,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn get_eventid_from_short(
&self,
shorteventid: ShortEventId,
shorteventid: u64,
) -> Result<Arc<EventId>> {
let lookup = Lookup::ShortToEventId;
@ -173,7 +158,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.get().to_be_bytes())?
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| {
Error::bad_database("Shorteventid does not exist")
})?;
@ -202,7 +187,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
#[tracing::instrument(skip(self))]
fn get_statekey_from_short(
&self,
shortstatekey: ShortStateKey,
shortstatekey: u64,
) -> Result<(StateEventType, String)> {
let lookup = Lookup::ShortToStateKey;
@ -215,7 +200,7 @@ impl service::rooms::short::Data for KeyValueDatabase {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.get().to_be_bytes())?
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| {
Error::bad_database("Shortstatekey does not exist")
})?;
@ -259,56 +244,51 @@ impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(ShortStateHash, bool)> {
let (short, existed) = if let Some(shortstatehash) =
self.statehash_shortstatehash.get(state_hash)?
{
(
utils::u64_from_bytes(&shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in db.")
})?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
};
Ok((ShortStateHash::new(short), existed))
) -> Result<(u64, bool)> {
Ok(
if let Some(shortstatehash) =
self.statehash_shortstatehash.get(state_hash)?
{
(
utils::u64_from_bytes(&shortstatehash).map_err(|_| {
Error::bad_database("Invalid shortstatehash in db.")
})?,
true,
)
} else {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
},
)
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<ShortRoomId>> {
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})
.map(ShortRoomId::new)
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})
})
.transpose()
}
fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<ShortRoomId> {
let short = if let Some(short) =
self.roomid_shortroomid.get(room_id.as_bytes())?
{
utils::u64_from_bytes(&short).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
};
Ok(ShortRoomId::new(short))
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(
if let Some(short) =
self.roomid_shortroomid.get(room_id.as_bytes())?
{
utils::u64_from_bytes(&short).map_err(|_| {
Error::bad_database("Invalid shortroomid in db.")
})?
} else {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
},
)
}
}

View file

@ -1,57 +1,44 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{
database::KeyValueDatabase,
service::{
self,
globals::marker,
rooms::short::{ShortEventId, ShortStateHash},
},
utils::{self, on_demand_hashmap::KeyToken},
Error, Result,
};
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(
&self,
room_id: &RoomId,
) -> Result<Option<ShortStateHash>> {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(
Ok(None),
|bytes| {
Ok(Some(ShortStateHash::new(
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash in roomid_shortstatehash",
)
})?,
)))
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash in roomid_shortstatehash",
)
})?))
},
)
}
fn set_room_state(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
new_shortstatehash: ShortStateHash,
room_id: &RoomId,
new_shortstatehash: u64,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
self.roomid_shortstatehash.insert(
room_id.as_bytes(),
&new_shortstatehash.get().to_be_bytes(),
)?;
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
}
fn set_event_state(
&self,
shorteventid: ShortEventId,
shortstatehash: ShortStateHash,
shorteventid: u64,
shortstatehash: u64,
) -> Result<()> {
self.shorteventid_shortstatehash.insert(
&shorteventid.get().to_be_bytes(),
&shortstatehash.get().to_be_bytes(),
&shorteventid.to_be_bytes(),
&shortstatehash.to_be_bytes(),
)?;
Ok(())
}
@ -84,8 +71,10 @@ impl service::rooms::state::Data for KeyValueDatabase {
fn set_forward_extremities(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);

View file

@ -4,20 +4,16 @@ use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{
database::KeyValueDatabase,
service::{
self,
rooms::short::{ShortStateHash, ShortStateKey},
},
services, utils, Error, PduEvent, Result,
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
#[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Arc<EventId>>> {
shortstatehash: u64,
) -> Result<HashMap<u64, Arc<EventId>>> {
let full_state = services()
.rooms
.state_compressor
@ -44,7 +40,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
let full_state = services()
.rooms
@ -91,7 +87,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
/// `state_key`).
fn state_get_id(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>> {
@ -109,7 +105,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
.full_state;
Ok(full_state
.iter()
.find(|compressed| compressed.state == shortstatekey)
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
services()
.rooms
@ -124,7 +120,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
/// `state_key`).
fn state_get(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
@ -135,24 +131,19 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
}
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(
&self,
event_id: &EventId,
) -> Result<Option<ShortStateHash>> {
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(
Ok(None),
|shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| {
Error::bad_database(
"Invalid shortstatehash bytes in \
shorteventid_shortstatehash",
)
})
.map(ShortStateHash::new)
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash bytes in \
shorteventid_shortstatehash",
)
})
})
.transpose()
},

View file

@ -2,28 +2,19 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{
database::KeyValueDatabase,
service::{
self,
rooms::{
short::ShortStateHash,
state_compressor::{data::StateDiff, CompressedStateEvent},
},
},
service::{self, rooms::state_compressor::data::StateDiff},
utils, Error, Result,
};
impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(
&self,
shortstatehash: ShortStateHash,
) -> Result<StateDiff> {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.get().to_be_bytes())?
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()])
.expect("bytes have right length");
let parent = (parent != 0).then_some(ShortStateHash::new(parent));
let parent = (parent != 0).then_some(parent);
let mut add_mode = true;
let mut added = HashSet::new();
@ -37,13 +28,10 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
continue;
}
if add_mode {
added.insert(CompressedStateEvent::from_bytes(
v.try_into().expect("we checked the size above"),
));
added.insert(v.try_into().expect("we checked the size above"));
} else {
removed.insert(CompressedStateEvent::from_bytes(
v.try_into().expect("we checked the size above"),
));
removed
.insert(v.try_into().expect("we checked the size above"));
}
i += 2 * size_of::<u64>();
}
@ -57,23 +45,22 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn save_statediff(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
diff: StateDiff,
) -> Result<()> {
let mut value =
diff.parent.map_or(0, |h| h.get()).to_be_bytes().to_vec();
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new.as_bytes());
value.extend_from_slice(&new[..]);
}
if !diff.removed.is_empty() {
value.extend_from_slice(&0_u64.to_be_bytes());
for removed in diff.removed.iter() {
value.extend_from_slice(&removed.as_bytes());
value.extend_from_slice(&removed[..]);
}
}
self.shortstatehash_statediff
.insert(&shortstatehash.get().to_be_bytes(), &value)
.insert(&shortstatehash.to_be_bytes(), &value)
}
}

View file

@ -1,12 +1,13 @@
use std::mem;
use ruma::{
api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId,
UserId,
};
use crate::{
database::KeyValueDatabase,
service::{self, rooms::timeline::PduId},
services, utils, Error, PduEvent, Result,
database::KeyValueDatabase, service, services, utils, Error, PduEvent,
Result,
};
impl service::rooms::threads::Data for KeyValueDatabase {
@ -22,7 +23,6 @@ impl service::rooms::threads::Data for KeyValueDatabase {
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get()
.to_be_bytes()
.to_vec();
@ -34,16 +34,14 @@ impl service::rooms::threads::Data for KeyValueDatabase {
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count =
utils::u64_from_bytes(&pduid[(size_of::<u64>())..])
.map_err(|_| {
Error::bad_database(
"Invalid pduid in threadid_userids.",
)
})?;
let pduid = PduId::new(pduid);
let count = utils::u64_from_bytes(
&pduid[(mem::size_of::<u64>())..],
)
.map_err(|_| {
Error::bad_database(
"Invalid pduid in threadid_userids.",
)
})?;
let mut pdu = services()
.rooms
.timeline
@ -63,7 +61,7 @@ impl service::rooms::threads::Data for KeyValueDatabase {
fn update_participants(
&self,
root_id: &PduId,
root_id: &[u8],
participants: &[OwnedUserId],
) -> Result<()> {
let users = participants
@ -72,16 +70,16 @@ impl service::rooms::threads::Data for KeyValueDatabase {
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id.as_bytes(), &users)?;
self.threadid_userids.insert(root_id, &users)?;
Ok(())
}
fn get_participants(
&self,
root_id: &PduId,
root_id: &[u8],
) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id.as_bytes())? {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)

View file

@ -10,8 +10,7 @@ use tracing::error;
use crate::{
database::KeyValueDatabase,
observability::{FoundIn, Lookup, METRICS},
service::{self, rooms::timeline::PduId},
services, utils, Error, PduEvent, Result,
service, services, utils, Error, PduEvent, Result,
};
impl service::rooms::timeline::Data for KeyValueDatabase {
@ -32,12 +31,11 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::MAX)?
.find_map(|x| match x {
Ok(x) => Some(x),
Err(error) => {
error!(%error, "Bad pdu in pdus_since");
None
.find_map(|r| {
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
})
{
METRICS.record_lookup(lookup, FoundIn::Database);
@ -103,8 +101,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
}
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<PduId>> {
self.eventid_pduid.get(event_id.as_bytes()).map(|x| x.map(PduId::new))
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
self.eventid_pduid.get(event_id.as_bytes())
}
/// Returns the pdu.
@ -171,8 +169,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &PduId) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id.as_bytes())?.map_or(Ok(None), |pdu| {
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
@ -183,9 +181,9 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id.as_bytes())?.map_or(Ok(None), |pdu| {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
@ -195,13 +193,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
fn append_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
pdu: &PduEvent,
json: &CanonicalJsonObject,
count: u64,
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id.as_bytes(),
pdu_id,
&serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?;
@ -211,8 +209,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.unwrap()
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid
.insert(pdu.event_id.as_bytes(), pdu_id.as_bytes())?;
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
@ -220,17 +217,17 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
fn prepend_backfill_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
event_id: &EventId,
json: &CanonicalJsonObject,
) -> Result<()> {
self.pduid_pdu.insert(
pdu_id.as_bytes(),
pdu_id,
&serde_json::to_vec(json)
.expect("CanonicalJsonObject is always a valid"),
)?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id.as_bytes())?;
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
@ -239,13 +236,13 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
pdu_json: &CanonicalJsonObject,
pdu: &PduEvent,
) -> Result<()> {
if self.pduid_pdu.get(pdu_id.as_bytes())?.is_some() {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id.as_bytes(),
pdu_id,
&serde_json::to_vec(pdu_json)
.expect("CanonicalJsonObject is always a valid"),
)?;
@ -383,7 +380,6 @@ fn count_to_id(
.ok_or_else(|| {
Error::bad_database("Looked for bad shortroomid in timeline")
})?
.get()
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();

View file

@ -1,9 +1,7 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{
database::KeyValueDatabase,
service::{self, rooms::short::ShortStateHash},
services, utils, Error, Result,
database::KeyValueDatabase, service, services, utils, Error, Result,
};
impl service::rooms::user::Data for KeyValueDatabase {
@ -97,7 +95,7 @@ impl service::rooms::user::Data for KeyValueDatabase {
&self,
room_id: &RoomId,
token: u64,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<()> {
let shortroomid = services()
.rooms
@ -105,38 +103,36 @@ impl service::rooms::user::Data for KeyValueDatabase {
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.get().to_be_bytes().to_vec();
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.get().to_be_bytes())
.insert(&key, &shortstatehash.to_be_bytes())
}
fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<Option<ShortStateHash>> {
) -> Result<Option<u64>> {
let shortroomid = services()
.rooms
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.get().to_be_bytes().to_vec();
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| {
Error::bad_database(
"Invalid shortstatehash in \
roomsynctoken_shortstatehash",
)
})
.map(ShortStateHash::new)
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database(
"Invalid shortstatehash in \
roomsynctoken_shortstatehash",
)
})
})
.transpose()
}

View file

@ -1,10 +1,9 @@
use ruma::{serde::Raw, ServerName, UserId};
use ruma::{ServerName, UserId};
use crate::{
database::KeyValueDatabase,
service::{
self,
rooms::timeline::PduId,
sending::{Destination, RequestKey, SendingEventType},
},
services, utils, Error, Result,
@ -62,14 +61,14 @@ impl service::sending::Data for KeyValueDatabase {
for (destination, event) in requests {
let mut key = destination.get_prefix();
if let SendingEventType::Pdu(value) = &event {
key.extend_from_slice(value.as_bytes());
key.extend_from_slice(value);
} else {
key.extend_from_slice(
&services().globals.next_count()?.to_be_bytes(),
);
}
let value = if let SendingEventType::Edu(value) = &event {
value.json().get().as_bytes()
&**value
} else {
&[]
};
@ -100,7 +99,7 @@ impl service::sending::Data for KeyValueDatabase {
) -> Result<()> {
for (e, key) in events {
let value = if let SendingEventType::Edu(value) = &e {
value.json().get().as_bytes()
&**value
} else {
&[]
};
@ -203,15 +202,9 @@ fn parse_servercurrentevent(
Ok((
destination,
if value.is_empty() {
SendingEventType::Pdu(PduId::new(event.to_vec()))
SendingEventType::Pdu(event.to_vec())
} else {
SendingEventType::Edu(
Raw::from_json_string(
String::from_utf8(value)
.expect("EDU content in database should be a string"),
)
.expect("EDU content in database should be valid JSON"),
)
SendingEventType::Edu(value)
},
))
}

View file

@ -4,8 +4,6 @@ use std::{fmt, iter, path::PathBuf};
use thiserror::Error;
use crate::config::ListenConfig;
/// Formats an [`Error`][0] and its [`source`][1]s with a separator
///
/// [0]: std::error::Error
@ -50,7 +48,7 @@ pub(crate) enum Main {
DatabaseError(#[source] crate::utils::error::Error),
#[error("failed to serve requests")]
Serve(#[from] Serve),
Serve(#[source] std::io::Error),
}
/// Observability initialization errors
@ -99,30 +97,3 @@ pub(crate) enum ConfigSearch {
#[error("no relevant configuration files found in XDG Base Directories")]
NotFound,
}
/// Errors serving traffic
// Missing docs are allowed here since that kind of information should be
// encoded in the error messages themselves anyway.
#[allow(missing_docs)]
#[derive(Error, Debug)]
pub(crate) enum Serve {
#[error("no listeners were specified in the configuration file")]
NoListeners,
#[error(
"listener {0} requested TLS, but no TLS cert was specified in the \
configuration file. Please set 'tls.certs' and 'tls.key'"
)]
NoTlsCerts(ListenConfig),
#[error("failed to read TLS cert and key files at {certs:?} and {key:?}")]
LoadCerts {
certs: String,
key: String,
#[source]
err: std::io::Error,
},
#[error("failed to run request listener on {1}")]
Listen(#[source] std::io::Error, ListenConfig),
}

View file

@ -1,9 +1,6 @@
// Avoid spurious warnings with --no-default-features, which isn't expected to
// work anyway
#![cfg_attr(not(any(feature = "sqlite", feature = "rocksdb")), allow(unused))]
use std::{
future::Future,
io,
net::SocketAddr,
process::ExitCode,
sync::{atomic, RwLock},
@ -19,7 +16,6 @@ use axum::{
use axum_server::{
bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle,
};
use futures_util::FutureExt;
use http::{
header::{self, HeaderName},
Method, StatusCode, Uri,
@ -31,14 +27,14 @@ use ruma::api::{
},
IncomingRequest,
};
use tokio::{signal, task::JoinSet};
use tokio::signal;
use tower::ServiceBuilder;
use tower_http::{
cors::{self, CorsLayer},
trace::TraceLayer,
ServiceBuilderExt as _,
};
use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing::{debug, info, info_span, warn, Instrument};
mod api;
mod args;
@ -50,8 +46,8 @@ mod service;
mod utils;
pub(crate) use api::ruma_wrapper::{Ar, Ra};
use api::{client_server, server_server, well_known};
pub(crate) use config::{Config, ListenConfig};
use api::{client_server, server_server};
pub(crate) use config::Config;
pub(crate) use database::KeyValueDatabase;
pub(crate) use service::{pdu::PduEvent, Services};
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
@ -112,7 +108,7 @@ async fn try_main() -> Result<(), error::Main> {
let config = config::load(args.config.as_ref()).await?;
let (_guard, reload_handles) = observability::init(&config)?;
let _guard = observability::init(&config);
// This is needed for opening lots of file descriptors, which tends to
// happen more often when using RocksDB and making lots of federation
@ -126,63 +122,38 @@ async fn try_main() -> Result<(), error::Main> {
.expect("should be able to increase the soft limit to the hard limit");
info!("Loading database");
KeyValueDatabase::load_or_create(config, reload_handles)
KeyValueDatabase::load_or_create(config)
.await
.map_err(Error::DatabaseError)?;
info!("Starting server");
run_server().await?;
run_server().await.map_err(Error::Serve)?;
Ok(())
}
#[allow(clippy::too_many_lines)]
async fn run_server() -> Result<(), error::Serve> {
use error::Serve as Error;
async fn run_server() -> io::Result<()> {
let config = &services().globals.config;
let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");
let middlewares = ServiceBuilder::new()
.sensitive_headers([header::AUTHORIZATION])
.layer(axum::middleware::from_fn(spawn_task))
.layer(
TraceLayer::new_for_http()
.make_span_with(|request: &http::Request<_>| {
let endpoint = if let Some(endpoint) =
request.extensions().get::<MatchedPath>()
{
endpoint.as_str()
} else {
request.uri().path()
};
.layer(TraceLayer::new_for_http().make_span_with(
|request: &http::Request<_>| {
let path = if let Some(path) =
request.extensions().get::<MatchedPath>()
{
path.as_str()
} else {
request.uri().path()
};
let method = request.method();
tracing::info_span!(
"http_request",
otel.name = format!("{method} {endpoint}"),
%method,
%endpoint,
)
})
.on_request(
|request: &http::Request<_>, _span: &tracing::Span| {
// can be enabled selectively using `filter =
// grapevine[incoming_request_curl]=trace` in config
tracing::trace_span!("incoming_request_curl").in_scope(
|| {
tracing::trace!(
cmd = utils::curlify(request),
"curl command line for incoming request \
(guessed hostname)"
);
},
);
},
),
)
tracing::info_span!("http_request", otel.name = path, %path, method = %request.method())
},
))
.layer(axum::middleware::from_fn(unrecognized_method))
.layer(
CorsLayer::new()
@ -212,64 +183,31 @@ async fn run_server() -> Result<(), error::Serve> {
.layer(axum::middleware::from_fn(observability::http_metrics_layer));
let app = routes(config).layer(middlewares).into_make_service();
let mut handles = Vec::new();
let mut servers = JoinSet::new();
let handle = ServerHandle::new();
let tls_config = if let Some(tls) = &config.tls {
Some(RustlsConfig::from_pem_file(&tls.certs, &tls.key).await.map_err(
|err| Error::LoadCerts {
certs: tls.certs.clone(),
key: tls.key.clone(),
err,
},
)?)
} else {
None
};
tokio::spawn(shutdown_signal(handle.clone()));
if config.listen.is_empty() {
return Err(Error::NoListeners);
}
match &config.tls {
Some(tls) => {
let conf =
RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
let server = bind_rustls(addr, conf).handle(handle).serve(app);
for listen in &config.listen {
info!(listener = %listen, "Listening for incoming traffic");
match listen {
ListenConfig::Tcp {
address,
port,
tls,
} => {
let addr = SocketAddr::from((*address, *port));
let handle = ServerHandle::new();
handles.push(handle.clone());
let server = if *tls {
let tls_config = tls_config
.clone()
.ok_or_else(|| Error::NoTlsCerts(listen.clone()))?;
bind_rustls(addr, tls_config)
.handle(handle)
.serve(app.clone())
.left_future()
} else {
bind(addr).handle(handle).serve(app.clone()).right_future()
};
servers.spawn(
server.then(|result| async { (listen.clone(), result) }),
);
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
server.await?;
}
}
None => {
let server = bind(addr).handle(handle).serve(app);
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
.expect("should be able to notify systemd");
tokio::spawn(shutdown_signal(handles));
while let Some(result) = servers.join_next().await {
let (listen, result) =
result.expect("should be able to join server task");
result.map_err(|err| Error::Listen(err, listen))?;
server.await?;
}
}
Ok(())
@ -299,7 +237,7 @@ async fn unrecognized_method(
let uri = req.uri().clone();
let inner = next.run(req).await;
if inner.status() == StatusCode::METHOD_NOT_ALLOWED {
warn!(%method, %uri, "Method not allowed");
warn!("Method not allowed: {method} {uri}");
return Ok(Ra(UiaaResponse::MatrixError(RumaError {
body: ErrorBody::Standard {
kind: ErrorKind::Unrecognized,
@ -406,24 +344,12 @@ fn routes(config: &Config) -> Router {
.ruma_route(c2s::get_message_events_route)
.ruma_route(c2s::search_events_route)
.ruma_route(c2s::turn_server_route)
.ruma_route(c2s::send_event_to_device_route);
// unauthenticated (legacy) media
let router = router
.ruma_route(c2s::get_media_config_legacy_route)
.ruma_route(c2s::get_content_legacy_route)
.ruma_route(c2s::get_content_as_filename_legacy_route)
.ruma_route(c2s::get_content_thumbnail_legacy_route);
// authenticated media
let router = router
.ruma_route(c2s::send_event_to_device_route)
.ruma_route(c2s::get_media_config_route)
.ruma_route(c2s::create_content_route)
.ruma_route(c2s::get_content_route)
.ruma_route(c2s::get_content_as_filename_route)
.ruma_route(c2s::get_content_thumbnail_route);
let router = router
.ruma_route(c2s::get_content_thumbnail_route)
.ruma_route(c2s::get_devices_route)
.ruma_route(c2s::get_device_route)
.ruma_route(c2s::update_device_route)
@ -470,7 +396,7 @@ fn routes(config: &Config) -> Router {
.put(c2s::send_state_event_for_empty_key_route),
);
let router = if config.observability.metrics.enable {
let router = if config.allow_prometheus {
router.route(
"/metrics",
get(|| async { observability::METRICS.export() }),
@ -491,11 +417,7 @@ fn routes(config: &Config) -> Router {
.route("/", get(it_works))
.fallback(not_found);
let router = router
.route("/.well-known/matrix/client", get(well_known::client))
.route("/.well-known/matrix/server", get(well_known::server));
if config.federation.enable {
if config.allow_federation {
router
.ruma_route(s2s::get_server_version_route)
.route("/_matrix/key/v2/server", get(s2s::get_server_keys_route))
@ -521,8 +443,6 @@ fn routes(config: &Config) -> Router {
.ruma_route(s2s::get_profile_information_route)
.ruma_route(s2s::get_keys_route)
.ruma_route(s2s::claim_keys_route)
.ruma_route(s2s::media_download_route)
.ruma_route(s2s::media_thumbnail_route)
} else {
router
.route("/_matrix/federation/*path", any(federation_disabled))
@ -530,7 +450,7 @@ fn routes(config: &Config) -> Router {
}
}
async fn shutdown_signal(handles: Vec<ServerHandle>) {
async fn shutdown_signal(handle: ServerHandle) {
let ctrl_c = async {
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
};
@ -553,14 +473,11 @@ async fn shutdown_signal(handles: Vec<ServerHandle>) {
() = terminate => { sig = "SIGTERM"; },
}
warn!(signal = %sig, "Shutting down due to signal");
warn!("Received {}, shutting down...", sig);
handle.graceful_shutdown(Some(Duration::from_secs(30)));
services().globals.shutdown();
for handle in handles {
handle.graceful_shutdown(Some(Duration::from_secs(30)));
}
#[cfg(feature = "systemd")]
sd_notify::notify(true, &[sd_notify::NotifyState::Stopping])
.expect("should be able to notify systemd");
@ -571,7 +488,7 @@ async fn federation_disabled(_: Uri) -> impl IntoResponse {
}
async fn not_found(method: Method, uri: Uri) -> impl IntoResponse {
debug!(%method, %uri, "Unknown route");
debug!(%method, %uri, "unknown route");
Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request")
}
@ -691,11 +608,11 @@ fn maximize_fd_limit() -> Result<(), nix::errno::Errno> {
let (soft_limit, hard_limit) = getrlimit(res)?;
debug!(soft_limit, "Current nofile soft limit");
debug!("Current nofile soft limit: {soft_limit}");
setrlimit(res, hard_limit, hard_limit)?;
debug!(hard_limit, "Increased nofile soft limit to the hard limit");
debug!("Increased nofile soft limit to {hard_limit}");
Ok(())
}

View file

@ -1,7 +1,7 @@
//! Facilities for observing runtime behavior
#![warn(missing_docs, clippy::missing_docs_in_private_items)]
use std::{collections::HashSet, fs::File, io::BufWriter, sync::Arc};
use std::{collections::HashSet, fs::File, io::BufWriter};
use axum::{
extract::{MatchedPath, Request},
@ -14,7 +14,6 @@ use opentelemetry::{
metrics::{MeterProvider, Unit},
KeyValue,
};
use opentelemetry_otlp::WithExportConfig;
use opentelemetry_sdk::{
metrics::{new_view, Aggregation, Instrument, SdkMeterProvider, Stream},
Resource,
@ -22,15 +21,9 @@ use opentelemetry_sdk::{
use strum::{AsRefStr, IntoStaticStr};
use tokio::time::Instant;
use tracing_flame::{FlameLayer, FlushGuard};
use tracing_subscriber::{
layer::SubscriberExt, reload, EnvFilter, Layer, Registry,
};
use tracing_subscriber::{layer::SubscriberExt, EnvFilter, Layer, Registry};
use crate::{
config::{Config, EnvFilterClone, LogFormat},
error,
utils::error::Result,
};
use crate::{config::Config, error, utils::error::Result};
/// Globally accessible metrics state
pub(crate) static METRICS: Lazy<Metrics> = Lazy::new(Metrics::new);
@ -48,44 +41,6 @@ impl Drop for Guard {
}
}
/// We need to store a [`reload::Handle`] value, but can't name it's type
/// explicitly because the S type parameter depends on the subscriber's previous
/// layers. In our case, this includes unnameable 'impl Trait' types.
///
/// This is fixed[1] in the unreleased tracing-subscriber from the master
/// branch, which removes the S parameter. Unfortunately can't use it without
/// pulling in a version of tracing that's incompatible with the rest of our
/// deps.
///
/// To work around this, we define an trait without the S paramter that forwards
/// to the [`reload::Handle::reload`] method, and then store the handle as a
/// trait object.
///
/// [1]: https://github.com/tokio-rs/tracing/pull/1035/commits/8a87ea52425098d3ef8f56d92358c2f6c144a28f
pub(crate) trait ReloadHandle<L> {
/// Replace the layer with a new value. See [`reload::Handle::reload`].
fn reload(&self, new_value: L) -> Result<(), reload::Error>;
}
impl<L, S> ReloadHandle<L> for reload::Handle<L, S> {
fn reload(&self, new_value: L) -> Result<(), reload::Error> {
reload::Handle::reload(self, new_value)
}
}
/// A type-erased [reload handle][reload::Handle] for an [`EnvFilter`].
pub(crate) type FilterReloadHandle = Box<dyn ReloadHandle<EnvFilter> + Sync>;
/// Collection of [`FilterReloadHandle`]s, allowing the filters for tracing
/// backends to be changed dynamically. Handles may be [`None`] if the backend
/// is disabled in the config.
#[allow(clippy::missing_docs_in_private_items)]
pub(crate) struct FilterReloadHandles {
pub(crate) traces: Option<FilterReloadHandle>,
pub(crate) flame: Option<FilterReloadHandle>,
pub(crate) log: Option<FilterReloadHandle>,
}
/// A kind of data that gets looked up
///
/// See also [`Metrics::record_lookup`].
@ -125,135 +80,65 @@ pub(crate) enum FoundIn {
Nothing,
}
/// Wrapper for the creation of a `tracing` [`Layer`] and any associated opaque
/// data.
///
/// Returns a no-op `None` layer if `enable` is `false`, otherwise calls the
/// given closure to construct the layer and associated data, then applies the
/// filter to the layer.
fn make_backend<S, L, T>(
enable: bool,
filter: &EnvFilterClone,
init: impl FnOnce() -> Result<(L, T), error::Observability>,
) -> Result<
(impl Layer<S>, Option<FilterReloadHandle>, Option<T>),
error::Observability,
>
where
L: Layer<S>,
S: tracing::Subscriber
+ for<'span> tracing_subscriber::registry::LookupSpan<'span>,
{
if !enable {
return Ok((None, None, None));
}
let (filter, handle) = reload::Layer::new(EnvFilter::from(filter));
let (layer, data) = init()?;
Ok((Some(layer.with_filter(filter)), Some(Box::new(handle)), Some(data)))
}
/// Initialize observability
pub(crate) fn init(
config: &Config,
) -> Result<(Guard, FilterReloadHandles), error::Observability> {
let (traces_layer, traces_filter, _) = make_backend(
config.observability.traces.enable,
&config.observability.traces.filter,
|| {
pub(crate) fn init(config: &Config) -> Result<Guard, error::Observability> {
let jaeger_layer = config
.allow_jaeger
.then(|| {
opentelemetry::global::set_text_map_propagator(
opentelemetry_jaeger_propagator::Propagator::new(),
);
let mut exporter = opentelemetry_otlp::new_exporter().tonic();
if let Some(endpoint) = &config.observability.traces.endpoint {
exporter = exporter.with_endpoint(endpoint);
}
let tracer = opentelemetry_otlp::new_pipeline()
.tracing()
.with_trace_config(
opentelemetry_sdk::trace::config().with_resource(
standard_resource(
config.observability.traces.service_name.clone(),
),
),
opentelemetry_sdk::trace::config()
.with_resource(standard_resource()),
)
.with_exporter(exporter)
.with_exporter(opentelemetry_otlp::new_exporter().tonic())
.install_batch(opentelemetry_sdk::runtime::Tokio)?;
Ok((tracing_opentelemetry::layer().with_tracer(tracer), ()))
},
)?;
let telemetry = tracing_opentelemetry::layer().with_tracer(tracer);
let (flame_layer, flame_filter, flame_guard) = make_backend(
config.observability.flame.enable,
&config.observability.flame.filter,
|| {
Ok::<_, error::Observability>(
telemetry.with_filter(EnvFilter::from(&config.log)),
)
})
.transpose()?;
let (flame_layer, flame_guard) = config
.tracing_flame
.then(|| {
let (flame_layer, guard) =
FlameLayer::with_file(&config.observability.flame.filename)?;
Ok((flame_layer.with_empty_samples(false), guard))
},
)?;
FlameLayer::with_file("./tracing.folded")?;
let flame_layer = flame_layer.with_empty_samples(false);
let (log_layer, log_filter, _) =
make_backend(true, &config.observability.logs.filter, || {
/// Time format selection for `tracing_subscriber` at runtime
#[allow(clippy::missing_docs_in_private_items)]
enum TimeFormat {
SystemTime,
NoTime,
}
impl tracing_subscriber::fmt::time::FormatTime for TimeFormat {
fn format_time(
&self,
w: &mut tracing_subscriber::fmt::format::Writer<'_>,
) -> std::fmt::Result {
match self {
TimeFormat::SystemTime => {
tracing_subscriber::fmt::time::SystemTime
.format_time(w)
}
TimeFormat::NoTime => Ok(()),
}
}
}
Ok::<_, error::Observability>((
flame_layer.with_filter(EnvFilter::from(&config.log)),
guard,
))
})
.transpose()?
.unzip();
let fmt_layer = tracing_subscriber::fmt::Layer::new()
.with_ansi(config.observability.logs.colors)
.with_timer(if config.observability.logs.timestamp {
TimeFormat::SystemTime
} else {
TimeFormat::NoTime
});
let fmt_layer = match config.observability.logs.format {
LogFormat::Pretty => fmt_layer.pretty().boxed(),
LogFormat::Full => fmt_layer.boxed(),
LogFormat::Compact => fmt_layer.compact().boxed(),
LogFormat::Json => fmt_layer.json().boxed(),
};
Ok((fmt_layer, ()))
})?;
let fmt_layer = tracing_subscriber::fmt::Layer::new()
.with_filter(EnvFilter::from(&config.log));
let subscriber = Registry::default()
.with(traces_layer)
.with(jaeger_layer)
.with(flame_layer)
.with(log_layer);
.with(fmt_layer);
tracing::subscriber::set_global_default(subscriber)?;
Ok((
Guard {
flame_guard,
},
FilterReloadHandles {
traces: traces_filter,
flame: flame_filter,
log: log_filter,
},
))
Ok(Guard {
flame_guard,
})
}
/// Construct the standard [`Resource`] value to use for this service
fn standard_resource(service_name: String) -> Resource {
Resource::default()
.merge(&Resource::new([KeyValue::new("service.name", service_name)]))
fn standard_resource() -> Resource {
Resource::default().merge(&Resource::new([KeyValue::new(
"service.name",
env!("CARGO_PKG_NAME"),
)]))
}
/// Holds state relating to metrics
@ -270,10 +155,6 @@ pub(crate) struct Metrics {
/// Counts where data is found from
lookup: opentelemetry::metrics::Counter<u64>,
/// Number of entries in an
/// [`OnDemandHashMap`](crate::utils::on_demand_hashmap::OnDemandHashMap)
on_demand_hashmap_size: opentelemetry::metrics::Gauge<u64>,
}
impl Metrics {
@ -307,7 +188,7 @@ impl Metrics {
)
.expect("view should be valid"),
)
.with_resource(standard_resource(env!("CARGO_PKG_NAME").to_owned()))
.with_resource(standard_resource())
.build();
let meter = provider.meter(env!("CARGO_PKG_NAME"));
@ -324,16 +205,10 @@ impl Metrics {
.with_description("Counts where data is found from")
.init();
let on_demand_hashmap_size = meter
.u64_gauge("on_demand_hashmap_size")
.with_description("Number of entries in OnDemandHashMap")
.init();
Metrics {
otel_state: (registry, provider),
http_requests_histogram,
lookup,
on_demand_hashmap_size,
}
}
@ -354,20 +229,6 @@ impl Metrics {
],
);
}
/// Record size of [`OnDemandHashMap`]
///
/// [`OnDemandHashMap`]: crate::utils::on_demand_hashmap::OnDemandHashMap
pub(crate) fn record_on_demand_hashmap_size(
&self,
name: Arc<str>,
size: usize,
) {
self.on_demand_hashmap_size.record(
size.try_into().unwrap_or(u64::MAX),
&[KeyValue::new("name", name)],
);
}
}
/// Track HTTP metrics by converting this into an [`axum`] layer

View file

@ -6,7 +6,7 @@ use std::{
use lru_cache::LruCache;
use tokio::sync::{broadcast, Mutex, RwLock};
use crate::{observability::FilterReloadHandles, Config, Result};
use crate::{Config, Result};
pub(crate) mod account_data;
pub(crate) mod admin;
@ -54,7 +54,6 @@ impl Services {
>(
db: &'static D,
config: Config,
reload_handles: FilterReloadHandles,
) -> Result<Self> {
Ok(Self {
appservice: appservice::Service::build(db)?,
@ -150,7 +149,7 @@ impl Services {
},
sending: sending::Service::build(db, &config),
globals: globals::Service::load(db, config, reload_handles)?,
globals: globals::Service::load(db, config)?,
})
}

View file

@ -1,6 +1,6 @@
use std::{collections::BTreeMap, fmt::Write, sync::Arc, time::Instant};
use clap::{Parser, ValueEnum};
use clap::Parser;
use regex::Regex;
use ruma::{
api::appservice::Registration,
@ -23,8 +23,8 @@ use ruma::{
TimelineEventType,
},
signatures::verify_json,
EventId, MilliSecondsSinceUnixEpoch, OwnedRoomId, RoomId, RoomVersionId,
ServerName, UserId,
EventId, MilliSecondsSinceUnixEpoch, OwnedRoomAliasId, OwnedRoomId,
RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, Mutex, RwLock};
@ -34,7 +34,7 @@ use super::pdu::PduBuilder;
use crate::{
api::client_server::{leave_all_rooms, AUTO_GEN_PASSWORD_LENGTH},
services,
utils::{self, dbg_truncate_str},
utils::{self, truncate_str_for_debug},
Error, PduEvent, Result,
};
@ -179,6 +179,11 @@ enum AdminCommand {
room_id: Box<RoomId>,
},
/// Remove a room alias.
UnsetAlias {
room_alias_id: OwnedRoomAliasId,
},
/// Verify json signatures
/// [commandbody]()
/// # ```
@ -196,17 +201,11 @@ enum AdminCommand {
// Allowed because the doc comment gets parsed by our code later
#[allow(clippy::doc_markdown)]
VerifyJson,
/// Dynamically change a tracing backend's filter string
SetTracingFilter {
backend: TracingBackend,
filter: String,
},
}
#[derive(Debug)]
pub(crate) enum AdminRoomEvent {
ProcessMessage(String),
ProcessMessage(Box<PduEvent>, String),
SendMessage(RoomMessageEventContent),
}
@ -215,13 +214,6 @@ pub(crate) struct Service {
receiver: Mutex<mpsc::UnboundedReceiver<AdminRoomEvent>>,
}
#[derive(Debug, Clone, ValueEnum)]
enum TracingBackend {
Log,
Flame,
Traces,
}
impl Service {
pub(crate) fn build() -> Arc<Self> {
let (sender, receiver) = mpsc::unbounded_channel();
@ -236,7 +228,8 @@ impl Service {
tokio::spawn(async move {
let mut receiver = self2.receiver.lock().await;
let Ok(Some(grapevine_room)) = self2.get_admin_room() else {
let Ok(Some(grapevine_room)) = services().admin.get_admin_room()
else {
return;
};
@ -259,16 +252,22 @@ impl Service {
) {
let message_content = match event {
AdminRoomEvent::SendMessage(content) => content,
AdminRoomEvent::ProcessMessage(room_message) => {
self.process_admin_message(room_message).await
AdminRoomEvent::ProcessMessage(pdu, room_message) => {
self.process_admin_message(*pdu, room_message).await
}
};
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(grapevine_room.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(grapevine_room.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services()
.rooms
@ -283,7 +282,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
grapevine_room,
&state_lock,
)
.await
.unwrap();
@ -292,11 +292,13 @@ impl Service {
#[tracing::instrument(
skip(self, room_message),
fields(
room_message = dbg_truncate_str(&room_message, 50).as_ref(),
room_message = truncate_str_for_debug(&room_message, 50).as_ref(),
),
)]
pub(crate) fn process_message(&self, room_message: String) {
self.sender.send(AdminRoomEvent::ProcessMessage(room_message)).unwrap();
pub(crate) fn process_message(&self, pdu: PduEvent, room_message: String) {
self.sender
.send(AdminRoomEvent::ProcessMessage(Box::new(pdu), room_message))
.unwrap();
}
#[tracing::instrument(skip(self, message_content))]
@ -311,11 +313,12 @@ impl Service {
#[tracing::instrument(
skip(self, room_message),
fields(
room_message = dbg_truncate_str(&room_message, 50).as_ref(),
room_message = truncate_str_for_debug(&room_message, 50).as_ref(),
),
)]
async fn process_admin_message(
&self,
pdu: PduEvent,
room_message: String,
) -> RoomMessageEventContent {
let mut lines = room_message.lines().filter(|l| !l.trim().is_empty());
@ -338,7 +341,7 @@ impl Service {
}
};
match self.process_admin_command(admin_command, body).await {
match self.process_admin_command(pdu, admin_command, body).await {
Ok(reply_message) => reply_message,
Err(error) => {
let markdown_message = format!(
@ -362,7 +365,7 @@ impl Service {
#[tracing::instrument(
skip(command_line),
fields(
command_line = dbg_truncate_str(command_line, 50).as_ref(),
command_line = truncate_str_for_debug(command_line, 50).as_ref(),
),
)]
fn parse_admin_command(
@ -393,6 +396,7 @@ impl Service {
#[tracing::instrument(skip(self, body))]
async fn process_admin_command(
&self,
pdu: PduEvent,
command: AdminCommand,
body: Vec<&str>,
) -> Result<RoomMessageEventContent> {
@ -1086,39 +1090,9 @@ impl Service {
)
}
}
AdminCommand::SetTracingFilter {
backend,
filter,
} => {
let handles = &services().globals.reload_handles;
let handle = match backend {
TracingBackend::Log => &handles.log,
TracingBackend::Flame => &handles.flame,
TracingBackend::Traces => &handles.traces,
};
let Some(handle) = handle else {
return Ok(RoomMessageEventContent::text_plain(
"Backend is disabled",
));
};
let filter = match filter.parse() {
Ok(filter) => filter,
Err(e) => {
return Ok(RoomMessageEventContent::text_plain(
format!("Invalid filter string: {e}"),
));
}
};
if let Err(e) = handle.reload(filter) {
return Ok(RoomMessageEventContent::text_plain(format!(
"Failed to reload filter: {e}"
)));
};
return Ok(RoomMessageEventContent::text_plain(
"Filter reloaded",
));
}
AdminCommand::UnsetAlias {
room_alias_id,
} => cmd_unset_alias(&room_alias_id, &pdu.sender),
};
Ok(reply_message_content)
@ -1212,27 +1186,35 @@ impl Service {
services().rooms.short.get_or_create_shortroomid(&room_id)?;
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
services().users.create(&services().globals.admin_bot_user_id, None)?;
let room_version = services().globals.default_room_version();
let mut content = match &room_version {
room_version if *room_version < RoomVersionId::V11 => {
RoomCreateEventContent::new_v1(
services().globals.admin_bot_user_id.clone(),
)
}
let mut content = match room_version {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => RoomCreateEventContent::new_v1(
services().globals.admin_bot_user_id.clone(),
),
RoomVersionId::V11 => RoomCreateEventContent::new_v11(),
_ => {
return Err(Error::BadServerResponse(
"Unsupported room version.",
))
}
_ => unreachable!("Validity of room version already checked"),
};
content.federate = true;
content.predecessor = None;
@ -1252,7 +1234,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1281,7 +1264,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1305,7 +1289,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1325,7 +1310,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1347,7 +1333,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1367,7 +1354,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1389,7 +1377,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1411,7 +1400,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1434,7 +1424,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1469,12 +1460,17 @@ impl Service {
user_id: &UserId,
displayname: String,
) -> Result<()> {
if let Some(room_id) = self.get_admin_room()? {
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.clone())
.await;
if let Some(room_id) = services().admin.get_admin_room()? {
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Use the server user to grant the new admin's power level
// Invite and join the real user
@ -1500,7 +1496,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
services()
@ -1525,7 +1522,8 @@ impl Service {
redacts: None,
},
user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
@ -1553,7 +1551,8 @@ impl Service {
redacts: None,
},
&services().globals.admin_bot_user_id,
&room_token,
&room_id,
&state_lock,
)
.await?;
}
@ -1561,6 +1560,23 @@ impl Service {
}
}
/// Remove an alias from a room
///
/// This authenticates the command against the command's sender.
fn cmd_unset_alias(
room_alias_id: &RoomAliasId,
user_id: &UserId,
) -> RoomMessageEventContent {
let res = services().rooms.alias.remove_alias(room_alias_id, user_id);
let res = match res {
Ok(()) => "Successfully removed room alias.".to_owned(),
Err(e) => format!("Failed to remove room alias: {e}"),
};
RoomMessageEventContent::text_plain(res)
}
#[cfg(test)]
mod test {
use super::*;

View file

@ -11,7 +11,7 @@ use ruma::{
};
use tokio::sync::RwLock;
use crate::Result;
use crate::{services, Result};
/// Compiled regular expressions for a namespace.
#[derive(Clone, Debug)]
@ -160,9 +160,15 @@ impl Service {
&self,
service_name: &str,
) -> Result<()> {
self.registration_info.write().await.remove(service_name).ok_or_else(
|| crate::Error::AdminCommand("Appservice not found"),
)?;
services()
.appservice
.registration_info
.write()
.await
.remove(service_name)
.ok_or_else(|| {
crate::Error::AdminCommand("Appservice not found")
})?;
self.db.unregister_appservice(service_name)
}

View file

@ -29,34 +29,18 @@ use ruma::{
UserId,
};
use tokio::sync::{broadcast, Mutex, RwLock, Semaphore};
use tracing::{error, Instrument};
use tracing::{error, info, Instrument};
use trust_dns_resolver::TokioAsyncResolver;
use crate::{
api::server_server::FedDest,
observability::FilterReloadHandles,
service::media::MediaFileKey,
utils::on_demand_hashmap::{OnDemandHashMap, TokenSet},
Config, Error, Result,
};
use crate::{api::server_server::FedDest, services, Config, Error, Result};
type WellKnownMap = HashMap<OwnedServerName, (FedDest, String)>;
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
// Time if last failed try, number of failed tries
type RateLimitState = (Instant, u32);
// Markers for
// [`Service::roomid_mutex_state`]/[`Service::roomid_mutex_insert`]/
// [`Service::roomid_mutex_federation`]
pub(crate) mod marker {
pub(crate) enum State {}
pub(crate) enum Insert {}
pub(crate) enum Federation {}
}
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub(crate) reload_handles: FilterReloadHandles,
// actual_destination, host
pub(crate) actual_destination_cache: Arc<RwLock<WellKnownMap>>,
@ -68,6 +52,7 @@ pub(crate) struct Service {
federation_client: reqwest::Client,
default_client: reqwest::Client,
pub(crate) stable_room_versions: Vec<RoomVersionId>,
pub(crate) unstable_room_versions: Vec<RoomVersionId>,
pub(crate) admin_bot_user_id: OwnedUserId,
pub(crate) admin_bot_room_alias_id: OwnedRoomAliasId,
pub(crate) bad_event_ratelimiter:
@ -77,13 +62,14 @@ pub(crate) struct Service {
pub(crate) bad_query_ratelimiter:
Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
pub(crate) servername_ratelimiter:
OnDemandHashMap<OwnedServerName, Semaphore>,
pub(crate) roomid_mutex_insert: TokenSet<OwnedRoomId, marker::Insert>,
pub(crate) roomid_mutex_state: TokenSet<OwnedRoomId, marker::State>,
Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
pub(crate) roomid_mutex_insert:
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
// this lock will be held longer
pub(crate) roomid_mutex_federation:
TokenSet<OwnedRoomId, marker::Federation>,
RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
pub(crate) roomid_federationhandletime:
RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
pub(crate) stateres_mutex: Arc<Mutex<()>>,
@ -187,11 +173,7 @@ impl Resolve for Resolver {
impl Service {
#[tracing::instrument(skip_all)]
pub(crate) fn load(
db: &'static dyn Data,
config: Config,
reload_handles: FilterReloadHandles,
) -> Result<Self> {
pub(crate) fn load(db: &'static dyn Data, config: Config) -> Result<Self> {
let keypair = db.load_keypair();
let keypair = match keypair {
@ -223,6 +205,9 @@ impl Service {
RoomVersionId::V10,
RoomVersionId::V11,
];
// Experimental, partially supported room versions
let unstable_room_versions =
vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
let admin_bot_user_id = UserId::parse(format!(
"@{}:{}",
@ -242,7 +227,6 @@ impl Service {
let mut s = Self {
db,
config,
reload_handles,
keypair: Arc::new(keypair),
dns_resolver: TokioAsyncResolver::tokio_from_system_conf()
.map_err(|e| {
@ -264,21 +248,16 @@ impl Service {
default_client,
jwt_decoding_key,
stable_room_versions,
unstable_room_versions,
admin_bot_user_id,
admin_bot_room_alias_id,
bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
servername_ratelimiter: OnDemandHashMap::new(
"servername_ratelimiter".to_owned(),
),
roomid_mutex_state: TokenSet::new("roomid_mutex_state".to_owned()),
roomid_mutex_insert: TokenSet::new(
"roomid_mutex_insert".to_owned(),
),
roomid_mutex_federation: TokenSet::new(
"roomid_mutex_federation".to_owned(),
),
servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())),
roomid_mutex_state: RwLock::new(HashMap::new()),
roomid_mutex_insert: RwLock::new(HashMap::new()),
roomid_mutex_federation: RwLock::new(HashMap::new()),
roomid_federationhandletime: RwLock::new(HashMap::new()),
stateres_mutex: Arc::new(Mutex::new(())),
rotate: RotationHandler::new(),
@ -345,7 +324,7 @@ impl Service {
}
pub(crate) fn max_fetch_prev_events(&self) -> u16 {
self.config.federation.max_fetch_prev_events
self.config.max_fetch_prev_events
}
pub(crate) fn allow_registration(&self) -> bool {
@ -357,19 +336,23 @@ impl Service {
}
pub(crate) fn allow_federation(&self) -> bool {
self.config.federation.enable
self.config.allow_federation
}
pub(crate) fn allow_room_creation(&self) -> bool {
self.config.allow_room_creation
}
pub(crate) fn allow_unstable_room_versions(&self) -> bool {
self.config.allow_unstable_room_versions
}
pub(crate) fn default_room_version(&self) -> RoomVersionId {
self.config.default_room_version.clone()
}
pub(crate) fn trusted_servers(&self) -> &[OwnedServerName] {
&self.config.federation.trusted_servers
&self.config.trusted_servers
}
pub(crate) fn dns_resolver(&self) -> &TokioAsyncResolver {
@ -383,23 +366,23 @@ impl Service {
}
pub(crate) fn turn_password(&self) -> &String {
&self.config.turn.password
&self.config.turn_password
}
pub(crate) fn turn_ttl(&self) -> u64 {
self.config.turn.ttl
self.config.turn_ttl
}
pub(crate) fn turn_uris(&self) -> &[String] {
&self.config.turn.uris
&self.config.turn_uris
}
pub(crate) fn turn_username(&self) -> &String {
&self.config.turn.username
&self.config.turn_username
}
pub(crate) fn turn_secret(&self) -> &String {
&self.config.turn.secret
&self.config.turn_secret
}
pub(crate) fn emergency_password(&self) -> &Option<String> {
@ -407,7 +390,12 @@ impl Service {
}
pub(crate) fn supported_room_versions(&self) -> Vec<RoomVersionId> {
self.stable_room_versions.clone()
let mut room_versions: Vec<RoomVersionId> = vec![];
room_versions.extend(self.stable_room_versions.clone());
if self.allow_unstable_room_versions() {
room_versions.extend(self.unstable_room_versions.clone());
};
room_versions
}
/// This doesn't actually check that the keys provided are newer than the
@ -465,10 +453,15 @@ impl Service {
&self,
keys: SigningKeys,
timestamp: MilliSecondsSinceUnixEpoch,
_room_version_id: &RoomVersionId,
room_version_id: &RoomVersionId,
) -> Option<BTreeMap<String, Base64>> {
let all_valid = keys.valid_until_ts > timestamp;
let all_valid = keys.valid_until_ts > timestamp
// valid_until_ts MUST be ignored in room versions 1, 2, 3, and 4.
// https://spec.matrix.org/v1.10/server-server-api/#get_matrixkeyv2server
|| matches!(room_version_id, RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V4
| RoomVersionId::V3);
all_valid.then(|| {
// Given that either the room version allows stale keys, or the
// valid_until_ts is in the future, all verify_keys are
@ -501,22 +494,24 @@ impl Service {
pub(crate) fn get_media_folder(&self) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database.path.clone());
r.push(self.config.database_path.clone());
r.push("media");
r
}
pub(crate) fn get_media_file(&self, key: &MediaFileKey) -> PathBuf {
pub(crate) fn get_media_file(&self, key: &[u8]) -> PathBuf {
let mut r = PathBuf::new();
r.push(self.config.database.path.clone());
r.push(self.config.database_path.clone());
r.push("media");
r.push(general_purpose::URL_SAFE_NO_PAD.encode(key.as_bytes()));
r.push(general_purpose::URL_SAFE_NO_PAD.encode(key));
r
}
pub(crate) fn shutdown(&self) {
self.shutdown.store(true, atomic::Ordering::Relaxed);
self.rotate.fire();
// On shutdown
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
services().globals.rotate.fire();
}
}

View file

@ -1,10 +1,9 @@
use std::io::Cursor;
use image::imageops::FilterType;
use ruma::http_headers::ContentDisposition;
use tokio::{
fs::File,
io::{AsyncReadExt, AsyncWriteExt},
io::{AsyncReadExt, AsyncWriteExt, BufReader},
};
use tracing::{debug, warn};
@ -21,20 +20,9 @@ pub(crate) struct FileMeta {
// only the filename instead of the entire `Content-Disposition` header.
#[allow(dead_code)]
pub(crate) content_disposition: Option<String>,
pub(crate) content_type: Option<String>,
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct MediaFileKey(Vec<u8>);
impl MediaFileKey {
pub(crate) fn new(key: Vec<u8>) -> Self {
Self(key)
}
pub(crate) fn as_bytes(&self) -> &[u8] {
&self.0
}
pub(crate) file: Vec<u8>,
}
pub(crate) struct Service {
@ -47,64 +35,69 @@ impl Service {
pub(crate) async fn create(
&self,
mxc: String,
content_disposition: Option<&ContentDisposition>,
content_type: Option<String>,
content_disposition: Option<&str>,
content_type: Option<&str>,
file: &[u8],
) -> Result<FileMeta> {
let meta = FileMeta {
content_disposition: content_disposition
.map(ContentDisposition::to_string),
content_type,
};
) -> Result<()> {
// Width, Height = 0 if it's not a thumbnail
let key = self.db.create_file_metadata(mxc, 0, 0, &meta)?;
let key = self.db.create_file_metadata(
mxc,
0,
0,
content_disposition,
content_type,
)?;
let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?;
f.write_all(file).await?;
Ok(meta)
Ok(())
}
/// Uploads or replaces a file thumbnail.
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip(self, file))]
pub(crate) async fn upload_thumbnail(
&self,
mxc: String,
content_disposition: Option<String>,
content_type: Option<String>,
content_disposition: Option<&str>,
content_type: Option<&str>,
width: u32,
height: u32,
file: &[u8],
) -> Result<FileMeta> {
let meta = FileMeta {
) -> Result<()> {
let key = self.db.create_file_metadata(
mxc,
width,
height,
content_disposition,
content_type,
};
let key = self.db.create_file_metadata(mxc, width, height, &meta)?;
)?;
let path = services().globals.get_media_file(&key);
let mut f = File::create(path).await?;
f.write_all(file).await?;
Ok(meta)
Ok(())
}
/// Downloads a file.
#[tracing::instrument(skip(self))]
pub(crate) async fn get(
&self,
mxc: String,
) -> Result<Option<(FileMeta, Vec<u8>)>> {
if let Ok((meta, key)) = self.db.search_file_metadata(mxc, 0, 0) {
pub(crate) async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
if let Ok((content_disposition, content_type, key)) =
self.db.search_file_metadata(mxc, 0, 0)
{
let path = services().globals.get_media_file(&key);
let mut file_data = Vec::new();
let Ok(mut file) = File::open(path).await else {
return Ok(None);
};
let mut file = Vec::new();
BufReader::new(File::open(path).await?)
.read_to_end(&mut file)
.await?;
file.read_to_end(&mut file_data).await?;
Ok(Some((meta, file_data)))
Ok(Some(FileMeta {
content_disposition,
content_type,
file,
}))
} else {
Ok(None)
}
@ -171,23 +164,23 @@ impl Service {
/ u64::from(original_height)
};
if use_width {
if let Ok(intermediate) = u32::try_from(intermediate) {
(width, intermediate)
if intermediate <= u64::from(::std::u32::MAX) {
(width, intermediate.try_into().unwrap_or(u32::MAX))
} else {
(
(u64::from(width) * u64::from(u32::MAX)
(u64::from(width) * u64::from(::std::u32::MAX)
/ intermediate)
.try_into()
.unwrap_or(u32::MAX),
u32::MAX,
::std::u32::MAX,
)
}
} else if let Ok(intermediate) = u32::try_from(intermediate) {
(intermediate, height)
} else if intermediate <= u64::from(::std::u32::MAX) {
(intermediate.try_into().unwrap_or(u32::MAX), height)
} else {
(
u32::MAX,
(u64::from(height) * u64::from(u32::MAX)
::std::u32::MAX,
(u64::from(height) * u64::from(::std::u32::MAX)
/ intermediate)
.try_into()
.unwrap_or(u32::MAX),
@ -221,18 +214,19 @@ impl Service {
///
/// For width,height <= 96 the server uses another thumbnailing algorithm
/// which crops the image afterwards.
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip(self))]
pub(crate) async fn get_thumbnail(
&self,
mxc: String,
width: u32,
height: u32,
) -> Result<Option<(FileMeta, Vec<u8>)>> {
) -> Result<Option<FileMeta>> {
// 0, 0 because that's the original file
let (width, height, crop) =
Self::thumbnail_properties(width, height).unwrap_or((0, 0, false));
if let Ok((meta, key)) =
if let Ok((content_disposition, content_type, key)) =
self.db.search_file_metadata(mxc.clone(), width, height)
{
debug!("Using saved thumbnail");
@ -240,10 +234,15 @@ impl Service {
let mut file = Vec::new();
File::open(path).await?.read_to_end(&mut file).await?;
return Ok(Some((meta, file.clone())));
return Ok(Some(FileMeta {
content_disposition,
content_type,
file: file.clone(),
}));
}
let Ok((meta, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0)
let Ok((content_disposition, content_type, key)) =
self.db.search_file_metadata(mxc.clone(), 0, 0)
else {
debug!("Original image not found, can't generate thumbnail");
return Ok(None);
@ -269,20 +268,33 @@ impl Service {
let Some(thumbnail_bytes) = thumbnail_result? else {
debug!("Returning source image as-is");
return Ok(Some((meta, file)));
return Ok(Some(FileMeta {
content_disposition,
content_type,
file,
}));
};
debug!("Saving created thumbnail");
// Save thumbnail in database so we don't have to generate it
// again next time
let thumbnail_key =
self.db.create_file_metadata(mxc, width, height, &meta)?;
let thumbnail_key = self.db.create_file_metadata(
mxc,
width,
height,
content_disposition.as_deref(),
content_type.as_deref(),
)?;
let path = services().globals.get_media_file(&thumbnail_key);
let mut f = File::create(path).await?;
f.write_all(&thumbnail_bytes).await?;
Ok(Some((meta, thumbnail_bytes.clone())))
Ok(Some(FileMeta {
content_disposition,
content_type,
file: thumbnail_bytes.clone(),
}))
}
}

View file

@ -1,4 +1,3 @@
use super::{FileMeta, MediaFileKey};
use crate::Result;
pub(crate) trait Data: Send + Sync {
@ -7,13 +6,15 @@ pub(crate) trait Data: Send + Sync {
mxc: String,
width: u32,
height: u32,
meta: &FileMeta,
) -> Result<MediaFileKey>;
content_disposition: Option<&str>,
content_type: Option<&str>,
) -> Result<Vec<u8>>;
/// Returns `content_disposition`, `content_type` and the `metadata` key.
fn search_file_metadata(
&self,
mxc: String,
width: u32,
height: u32,
) -> Result<(FileMeta, MediaFileKey)>;
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
}

View file

@ -474,8 +474,8 @@ pub(crate) fn gen_event_id_canonical_json(
room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
let value: CanonicalJsonObject =
serde_json::from_str(pdu.get()).map_err(|error| {
warn!(%error, object = ?pdu, "Error parsing incoming event");
serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;

View file

@ -25,9 +25,9 @@ use ruma::{
serde::Raw,
uint, RoomId, UInt, UserId,
};
use tracing::warn;
use tracing::{info, warn};
use crate::{services, utils, Error, PduEvent, Result};
use crate::{services, Error, PduEvent, Result};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
@ -78,8 +78,8 @@ impl Service {
SendAccessToken::IfRequired(""),
&[MatrixVersion::V1_0],
)
.map_err(|error| {
warn!(%error, %destination, "Failed to find destination");
.map_err(|e| {
warn!("Failed to find destination {}: {}", destination, e);
Error::BadServerResponse("Invalid destination")
})?
.map(BytesMut::freeze);
@ -105,21 +105,18 @@ impl Service {
);
// TODO: handle timeout
let body = response.bytes().await.unwrap_or_else(|error| {
warn!(%error, "Server error");
let body = response.bytes().await.unwrap_or_else(|e| {
warn!("server error {}", e);
Vec::new().into()
});
if status != 200 {
warn!(
push_gateway = %destination,
%status,
%url,
body = %utils::dbg_truncate_str(
String::from_utf8_lossy(&body).as_ref(),
100,
),
"Push gateway returned bad response",
info!(
"Push gateway returned bad response {} {}\n{}\n{:?}",
destination,
status,
url,
crate::utils::string_from_bytes(&body)
);
}
@ -128,25 +125,22 @@ impl Service {
.body(body)
.expect("reqwest body is valid http body"),
);
response.map_err(|error| {
warn!(
%error,
appservice = %destination,
%url,
"Push gateway returned invalid response bytes",
response.map_err(|_| {
info!(
"Push gateway returned invalid response bytes {}\n{}",
destination, url
);
Error::BadServerResponse(
"Push gateway returned bad response.",
)
})
}
Err(error) => {
Err(e) => {
warn!(
%error,
%destination,
"Could not send request to push gateway",
"Could not send request to pusher {}: {}",
destination, e
);
Err(error.into())
Err(e.into())
}
}
}

View file

@ -8,7 +8,6 @@ pub(crate) use data::Data;
use ruma::{api::client::error::ErrorKind, EventId, RoomId};
use tracing::{debug, error, warn};
use super::short::ShortEventId;
use crate::{services, utils::debug_slice_truncated, Error, Result};
pub(crate) struct Service {
@ -18,16 +17,16 @@ pub(crate) struct Service {
impl Service {
pub(crate) fn get_cached_eventid_authchain(
&self,
key: &[ShortEventId],
) -> Result<Option<Arc<HashSet<ShortEventId>>>> {
key: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>> {
self.db.get_cached_eventid_authchain(key)
}
#[tracing::instrument(skip(self))]
pub(crate) fn cache_auth_chain(
&self,
key: Vec<ShortEventId>,
auth_chain: Arc<HashSet<ShortEventId>>,
key: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()> {
self.db.cache_auth_chain(key, auth_chain)
}
@ -52,7 +51,7 @@ impl Service {
// I'm afraid to change this in case there is accidental reliance on
// the truncation
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
let bucket_id = (short.get() % NUM_BUCKETS as u64) as usize;
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
buckets[bucket_id].insert((short, id.clone()));
i += 1;
if i % 100 == 0 {
@ -69,10 +68,12 @@ impl Service {
continue;
}
let chunk_key: Vec<_> =
let chunk_key: Vec<u64> =
chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) =
self.get_cached_eventid_authchain(&chunk_key)?
if let Some(cached) = services()
.rooms
.auth_chain
.get_cached_eventid_authchain(&chunk_key)?
{
hits += 1;
full_auth_chain.extend(cached.iter().copied());
@ -85,8 +86,10 @@ impl Service {
let mut misses2 = 0;
let mut i = 0;
for (sevent_id, event_id) in chunk {
if let Some(cached) =
self.get_cached_eventid_authchain(&[sevent_id])?
if let Some(cached) = services()
.rooms
.auth_chain
.get_cached_eventid_authchain(&[sevent_id])?
{
hits2 += 1;
chunk_cache.extend(cached.iter().copied());
@ -95,7 +98,7 @@ impl Service {
let auth_chain = Arc::new(
self.get_auth_chain_inner(room_id, &event_id)?,
);
self.cache_auth_chain(
services().rooms.auth_chain.cache_auth_chain(
vec![sevent_id],
Arc::clone(&auth_chain),
)?;
@ -119,7 +122,10 @@ impl Service {
"Chunk missed",
);
let chunk_cache = Arc::new(chunk_cache);
self.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
services()
.rooms
.auth_chain
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
full_auth_chain.extend(chunk_cache.iter());
}
@ -140,7 +146,7 @@ impl Service {
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<ShortEventId>> {
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
@ -148,10 +154,9 @@ impl Service {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
warn!(bad_room_id = %pdu.room_id, "Event referenced in auth chain has incorrect room id");
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Event has incorrect room id",
"Evil event in db",
));
}
for auth_event in &pdu.auth_events {

View file

@ -1,15 +1,15 @@
use std::{collections::HashSet, sync::Arc};
use crate::{service::rooms::short::ShortEventId, Result};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn get_cached_eventid_authchain(
&self,
shorteventid: &[ShortEventId],
) -> Result<Option<Arc<HashSet<ShortEventId>>>>;
shorteventid: &[u64],
) -> Result<Option<Arc<HashSet<u64>>>>;
fn cache_auth_chain(
&self,
shorteventid: Vec<ShortEventId>,
auth_chain: Arc<HashSet<ShortEventId>>,
shorteventid: Vec<u64>,
auth_chain: Arc<HashSet<u64>>,
) -> Result<()>;
}

View file

@ -1,10 +1,14 @@
/// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use std::{
collections::{hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::Arc,
time::{Duration, Instant, SystemTime},
};
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use futures_util::{stream::FuturesUnordered, Future, StreamExt};
use ruma::{
api::{
client::error::ErrorKind,
@ -32,15 +36,11 @@ use ruma::{
MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedServerSigningKeyId,
RoomId, RoomVersionId, ServerName,
};
use serde::Deserialize;
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
use super::{
short::ShortStateKey, state_compressor::CompressedStateEvent,
timeline::PduId,
};
use super::state_compressor::CompressedStateEvent;
use crate::{
service::{globals::SigningKeys, pdu},
services,
@ -48,11 +48,6 @@ use crate::{
Error, PduEvent, Result,
};
#[derive(Deserialize)]
struct ExtractOriginServerTs {
origin_server_ts: MilliSecondsSinceUnixEpoch,
}
pub(crate) struct Service;
impl Service {
@ -83,16 +78,18 @@ impl Service {
/// 13. Use state resolution to find new room state
/// 14. Check if the event passes auth based on the "current state" of the
/// room, if not soft fail it
// We use some AsyncRecursiveType hacks here so we can call this async
// funtion recursively
#[tracing::instrument(skip(self, value, is_timeline_event, pub_key_map))]
pub(crate) async fn handle_incoming_pdu<'a>(
&self,
origin: &'a ServerName,
event_id: &'a EventId,
room_id: &'a RoomId,
value: CanonicalJsonObject,
value: BTreeMap<String, CanonicalJsonValue>,
is_timeline_event: bool,
pub_key_map: &'a RwLock<BTreeMap<String, SigningKeys>>,
) -> Result<Option<PduId>> {
) -> Result<Option<Vec<u8>>> {
// 0. Check the server is in the room
if !services().rooms.metadata.exists(room_id)? {
return Err(Error::BadRequest(
@ -108,7 +105,7 @@ impl Service {
));
}
self.acl_check(origin, room_id)?;
services().rooms.event_handler.acl_check(origin, room_id)?;
// 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? {
@ -124,12 +121,10 @@ impl Service {
})?;
let create_event_content: RoomCreateEventContent =
serde_json::from_str(create_event.content.get()).map_err(
|error| {
error!(%error, "Invalid create event");
Error::BadDatabase("Invalid create event.")
},
)?;
serde_json::from_str(create_event.content.get()).map_err(|e| {
error!("Invalid create event: {}", e);
Error::BadDatabase("Invalid create event in db")
})?;
let room_version_id = &create_event_content.room_version;
let first_pdu_in_room =
@ -200,7 +195,7 @@ impl Service {
}
if time.elapsed() < min_elapsed_duration {
info!(event_id = %prev_id, "Backing off from prev event");
info!("Backing off from {}", prev_id);
continue;
}
}
@ -241,7 +236,7 @@ impl Service {
((*prev_id).to_owned(), start_time),
);
if let Err(error) = self
if let Err(e) = self
.upgrade_outlier_to_timeline_pdu(
pdu,
json,
@ -253,7 +248,7 @@ impl Service {
.await
{
errors += 1;
warn!(%error, event_id = %prev_id, "Prev event failed");
warn!("Prev event {} failed: {}", prev_id, e);
match services()
.globals
.bad_event_ratelimiter
@ -269,6 +264,7 @@ impl Service {
}
}
}
let elapsed = start_time.elapsed();
services()
.globals
.roomid_federationhandletime
@ -276,9 +272,10 @@ impl Service {
.await
.remove(&room_id.to_owned());
debug!(
elapsed = ?start_time.elapsed(),
event_id = %prev_id,
"Finished handling prev event",
"Handling prev event {} took {}m{}s",
prev_id,
elapsed.as_secs() / 60,
elapsed.as_secs() % 60
);
}
}
@ -292,7 +289,9 @@ impl Service {
.write()
.await
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = self
let r = services()
.rooms
.event_handler
.upgrade_outlier_to_timeline_pdu(
incoming_pdu,
val,
@ -312,7 +311,7 @@ impl Service {
r
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::type_complexity, clippy::too_many_arguments)]
#[tracing::instrument(skip(self, origin, room_id, value, pub_key_map))]
fn handle_outlier_pdu<'a>(
&'a self,
@ -320,10 +319,13 @@ impl Service {
create_event: &'a PduEvent,
event_id: &'a EventId,
room_id: &'a RoomId,
mut value: CanonicalJsonObject,
mut value: BTreeMap<String, CanonicalJsonValue>,
auth_events_known: bool,
pub_key_map: &'a RwLock<BTreeMap<String, SigningKeys>>,
) -> BoxFuture<'a, Result<(Arc<PduEvent>, CanonicalJsonObject)>> {
) -> AsyncRecursiveType<
'a,
Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
> {
Box::pin(async move {
// 1.1. Remove unsigned field
value.remove("unsigned");
@ -332,8 +334,8 @@ impl Service {
// 3. check content hash, redact if doesn't match
let create_event_content: RoomCreateEventContent =
serde_json::from_str(create_event.content.get()).map_err(
|error| {
error!(%error, "Invalid create event");
|e| {
error!("Invalid create event: {}", e);
Error::BadDatabase("Invalid create event in db")
},
)?;
@ -395,9 +397,9 @@ impl Service {
&value,
room_version_id,
) {
Err(error) => {
Err(e) => {
// Drop
warn!(%event_id, %error, "Dropping bad event");
warn!("Dropping bad event {}: {}", event_id, e,);
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Signature verification failed",
@ -405,7 +407,7 @@ impl Service {
}
Ok(ruma::signatures::Verified::Signatures) => {
// Redact
warn!(%event_id, "Calculated hash does not match");
warn!("Calculated hash does not match: {}", event_id);
let Ok(obj) = ruma::canonical_json::redact(
value,
room_version_id,
@ -479,17 +481,16 @@ impl Service {
// 6. Reject "due to auth events" if the event doesn't pass auth
// based on the auth events
debug!(
event_id = %incoming_pdu.event_id,
"Starting auth check for event based on auth events",
"Auth check for {} based on auth events",
incoming_pdu.event_id
);
// Build map of auth events
let mut auth_events = HashMap::new();
for event_id in &incoming_pdu.auth_events {
let Some(auth_event) =
services().rooms.timeline.get_pdu(event_id)?
for id in &incoming_pdu.auth_events {
let Some(auth_event) = services().rooms.timeline.get_pdu(id)?
else {
warn!(%event_id, "Could not find auth event");
warn!("Could not find auth event {}", id);
continue;
};
@ -542,7 +543,7 @@ impl Service {
));
}
debug!("Validation successful");
debug!("Validation successful.");
// 7. Persist the event as an outlier.
services()
@ -550,7 +551,7 @@ impl Service {
.outlier
.add_pdu_outlier(&incoming_pdu.event_id, &val)?;
debug!("Added pdu as outlier");
debug!("Added pdu as outlier.");
Ok((Arc::new(incoming_pdu), val))
})
@ -563,12 +564,12 @@ impl Service {
pub(crate) async fn upgrade_outlier_to_timeline_pdu(
&self,
incoming_pdu: Arc<PduEvent>,
val: CanonicalJsonObject,
val: BTreeMap<String, CanonicalJsonValue>,
create_event: &PduEvent,
origin: &ServerName,
room_id: &RoomId,
pub_key_map: &RwLock<BTreeMap<String, SigningKeys>>,
) -> Result<Option<PduId>> {
) -> Result<Option<Vec<u8>>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(Some(pduid)) =
services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id)
@ -587,18 +588,13 @@ impl Service {
));
}
debug!(
event_id = %incoming_pdu.event_id,
"Upgrading event to timeline pdu",
);
info!("Upgrading {} to timeline pdu", incoming_pdu.event_id);
let create_event_content: RoomCreateEventContent =
serde_json::from_str(create_event.content.get()).map_err(
|error| {
warn!(%error, "Invalid create event");
Error::BadDatabase("Invalid create event in db")
},
)?;
serde_json::from_str(create_event.content.get()).map_err(|e| {
warn!("Invalid create event: {}", e);
Error::BadDatabase("Invalid create event in db")
})?;
let room_version_id = &create_event_content.room_version;
let room_version = RoomVersion::new(room_version_id)
@ -729,7 +725,7 @@ impl Service {
id.clone(),
);
} else {
warn!("Failed to get_statekey_from_short");
warn!("Failed to get_statekey_from_short.");
}
starting_events.push(id);
}
@ -752,10 +748,10 @@ impl Service {
room_version_id,
&fork_states,
auth_chain_sets,
|event_id| {
let res = services().rooms.timeline.get_pdu(event_id);
if let Err(error) = &res {
error!(%error, %event_id, "Failed to fetch event");
|id| {
let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
res.ok().flatten()
},
@ -778,11 +774,12 @@ impl Service {
})
.collect::<Result<_>>()?,
),
Err(error) => {
Err(e) => {
warn!(
%error,
"State resolution on prev events failed, either \
an event could not be found or deserialization"
an event could not be found or deserialization: \
{}",
e
);
None
}
@ -807,7 +804,7 @@ impl Service {
.await
{
Ok(res) => {
debug!("Fetching state events at event");
debug!("Fetching state events at event.");
let collect = res
.pdu_ids
.iter()
@ -871,9 +868,9 @@ impl Service {
state_at_incoming_event = Some(state);
}
Err(error) => {
warn!(%error, "Fetching state for event failed");
return Err(error);
Err(e) => {
warn!("Fetching state for event failed: {}", e);
return Err(e);
}
};
}
@ -940,7 +937,16 @@ impl Service {
})? || incoming_pdu.kind
== TimelineEventType::RoomRedaction
&& match room_version_id {
room_version if *room_version < RoomVersionId::V11 => {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => {
if let Some(redact_id) = &incoming_pdu.redacts {
!services().rooms.state_accessor.user_can_redact(
redact_id,
@ -974,20 +980,23 @@ impl Service {
}
}
_ => {
return Err(Error::BadServerResponse(
"Unsupported room version.",
))
unreachable!("Validity of room version already checked")
}
};
// 13. Use state resolution to find new room state
// We start looking at current room state now, so lets lock the room
let room_token = services()
.globals
.roomid_mutex_state
.lock_key(room_id.to_owned())
.await;
let mutex_state = Arc::clone(
services()
.globals
.roomid_mutex_state
.write()
.await
.entry(room_id.to_owned())
.or_default(),
);
let state_lock = mutex_state.lock().await;
// Now we calculate the set of extremities this room has after the
// incoming event has been applied. We start with the previous
@ -1056,7 +1065,7 @@ impl Service {
services()
.rooms
.state
.force_state(&room_token, sstatehash, new, removed)
.force_state(room_id, sstatehash, new, removed, &state_lock)
.await?;
}
@ -1074,13 +1083,13 @@ impl Service {
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&room_token,
&state_lock,
)
.await?;
// Soft fail, we keep the event as an outlier but don't add it to
// the timeline
warn!("Event was soft failed");
warn!("Event was soft failed: {:?}", incoming_pdu);
services()
.rooms
.pdu_metadata
@ -1107,14 +1116,14 @@ impl Service {
extremities.iter().map(|e| (**e).to_owned()).collect(),
state_ids_compressed,
soft_fail,
&room_token,
&state_lock,
)
.await?;
debug!("Appended incoming pdu");
// Event has passed all auth/stateres checks
drop(room_token);
drop(state_lock);
Ok(pdu_id)
}
@ -1123,7 +1132,7 @@ impl Service {
&self,
room_id: &RoomId,
room_version_id: &RoomVersionId,
incoming_state: HashMap<ShortStateKey, Arc<EventId>>,
incoming_state: HashMap<u64, Arc<EventId>>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids");
let current_sstatehash = services()
@ -1179,8 +1188,8 @@ impl Service {
let fetch_event = |id: &_| {
let res = services().rooms.timeline.get_pdu(id);
if let Err(error) = &res {
error!(%error, "Failed to fetch event");
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
res.ok().flatten()
};
@ -1200,7 +1209,7 @@ impl Service {
drop(lock);
debug!("State resolution done; compressing state");
debug!("State resolution done. Compressing state");
let new_room_state = state
.into_iter()
@ -1229,6 +1238,7 @@ impl Service {
/// b. Look at outlier pdu tree
/// c. Ask origin server over federation
/// d. TODO: Ask other servers over federation?
#[allow(clippy::type_complexity)]
#[tracing::instrument(skip_all)]
pub(crate) fn fetch_and_handle_outliers<'a>(
&'a self,
@ -1238,7 +1248,10 @@ impl Service {
room_id: &'a RoomId,
room_version_id: &'a RoomVersionId,
pub_key_map: &'a RwLock<BTreeMap<String, SigningKeys>>,
) -> BoxFuture<'a, Vec<(Arc<PduEvent>, Option<CanonicalJsonObject>)>> {
) -> AsyncRecursiveType<
'a,
Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>,
> {
Box::pin(async move {
let back_off = |id| async move {
match services()
@ -1258,14 +1271,14 @@ impl Service {
};
let mut pdus = vec![];
for event_id in events {
for id in events {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) =
services().rooms.timeline.get_pdu(event_id)
services().rooms.timeline.get_pdu(id)
{
trace!(%event_id, "Found event locally");
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
}
@ -1273,7 +1286,7 @@ impl Service {
// c. Ask origin server over federation
// We also handle its auth chain here so we don't get a stack
// overflow in handle_outlier_pdu.
let mut todo_auth_events = vec![Arc::clone(event_id)];
let mut todo_auth_events = vec![Arc::clone(id)];
let mut events_in_reverse_order = Vec::new();
let mut events_all = HashSet::new();
let mut i = 0;
@ -1296,10 +1309,7 @@ impl Service {
}
if time.elapsed() < min_elapsed_duration {
info!(
event_id = %next_id,
"Backing off from event",
);
info!("Backing off from {}", next_id);
continue;
}
}
@ -1316,14 +1326,11 @@ impl Service {
if let Ok(Some(_)) =
services().rooms.timeline.get_pdu(&next_id)
{
trace!(event_id = %next_id, "Found event locally");
trace!("Found {} in db", next_id);
continue;
}
info!(
event_id = %next_id,
"Fetching event over federation",
);
info!("Fetching {} over federation.", next_id);
if let Ok(res) = services()
.sending
.send_federation_request(
@ -1334,7 +1341,7 @@ impl Service {
)
.await
{
info!(event_id = %next_id, "Got event over federation");
info!("Got {} over federation", next_id);
let Ok((calculated_event_id, value)) =
pdu::gen_event_id_canonical_json(
&res.pdu,
@ -1347,10 +1354,9 @@ impl Service {
if calculated_event_id != *next_id {
warn!(
expected_event_id = %next_id,
actual_event_id = %calculated_event_id,
"Server returned an event with a different ID \
than requested",
"Server didn't return event id we requested: \
requested: {}, we got {}. Event: {:?}",
next_id, calculated_event_id, &res.pdu
);
}
@ -1374,7 +1380,7 @@ impl Service {
events_in_reverse_order.push((next_id.clone(), value));
events_all.insert(next_id);
} else {
warn!(event_id = %next_id, "Failed to fetch event");
warn!("Failed to fetch event: {}", next_id);
back_off((*next_id).to_owned()).await;
}
}
@ -1398,10 +1404,7 @@ impl Service {
}
if time.elapsed() < min_elapsed_duration {
info!(
event_id = %next_id,
"Backing off from event",
);
info!("Backing off from {}", next_id);
continue;
}
}
@ -1419,15 +1422,14 @@ impl Service {
.await
{
Ok((pdu, json)) => {
if next_id == event_id {
if next_id == id {
pdus.push((pdu, Some(json)));
}
}
Err(error) => {
Err(e) => {
warn!(
event_id = %next_id,
%error,
"Event failed auth checks",
"Authentication of event {} failed: {:?}",
next_id, e
);
back_off((**next_id).to_owned()).await;
}
@ -1450,7 +1452,10 @@ impl Service {
initial_set: Vec<Arc<EventId>>,
) -> Result<(
Vec<Arc<EventId>>,
HashMap<Arc<EventId>, (Arc<PduEvent>, CanonicalJsonObject)>,
HashMap<
Arc<EventId>,
(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>),
>,
)> {
let mut graph: HashMap<Arc<EventId>, _> = HashMap::new();
let mut eventid_info = HashMap::new();
@ -1545,7 +1550,7 @@ impl Service {
#[tracing::instrument(skip_all)]
pub(crate) async fn fetch_required_signing_keys(
&self,
event: &CanonicalJsonObject,
event: &BTreeMap<String, CanonicalJsonValue>,
pub_key_map: &RwLock<BTreeMap<String, SigningKeys>>,
) -> Result<()> {
let signatures = event
@ -1583,7 +1588,10 @@ impl Service {
.await;
let Ok(keys) = fetch_res else {
warn!("Failed to fetch signing key");
warn!(
"Signature verification failed: Could not fetch signing \
key.",
);
continue;
};
@ -1608,8 +1616,8 @@ impl Service {
pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, SigningKeys>>,
) -> Result<()> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get())
.map_err(|error| {
error!(%error, ?pdu, "Invalid PDU in server response");
.map_err(|e| {
error!("Invalid PDU in server response: {:?}: {:?}", pdu, e);
Error::BadServerResponse("Invalid PDU in server response")
})?;
@ -1621,15 +1629,6 @@ impl Service {
let event_id = <&EventId>::try_from(event_id.as_str())
.expect("ruma's reference hashes are valid event ids");
let ExtractOriginServerTs {
origin_server_ts,
} = ExtractOriginServerTs::deserialize(pdu).map_err(|_| {
Error::BadServerResponse(
"Invalid PDU in server response, origin_server_ts field is \
missing or invalid",
)
})?;
if let Some((time, tries)) =
services().globals.bad_event_ratelimiter.read().await.get(event_id)
{
@ -1641,9 +1640,9 @@ impl Service {
}
if time.elapsed() < min_elapsed_duration {
debug!(%event_id, "Backing off from event");
debug!("Backing off from {}", event_id);
return Err(Error::BadServerResponse(
"Bad event, still backing off",
"bad event, still backing off",
));
}
}
@ -1669,12 +1668,15 @@ impl Service {
let contains_all_ids = |keys: &SigningKeys| {
signature_ids.iter().all(|id| {
keys.verify_keys.get(id).is_some_and(|_| {
keys.valid_until_ts >= origin_server_ts
}) || keys
.old_verify_keys
.get(id)
.is_some_and(|v| v.expired_ts >= origin_server_ts)
keys.verify_keys
.keys()
.map(ToString::to_string)
.any(|key_id| id == &key_id)
|| keys
.old_verify_keys
.keys()
.map(ToString::to_string)
.any(|key_id| id == &key_id)
})
};
@ -1686,28 +1688,20 @@ impl Service {
)
})?;
// check that we have the server in our list already, or
// all `signature_ids` are in pub_key_map
// if yes, we don't have to do anything
if servers.contains_key(origin)
|| pub_key_map
.get(origin.as_str())
.is_some_and(contains_all_ids)
|| pub_key_map.contains_key(origin.as_str())
{
continue;
}
trace!(server = %origin, "Loading signing keys for other server");
trace!("Loading signing keys for {}", origin);
if let Some(result) = services().globals.signing_keys_for(origin)? {
if !contains_all_ids(&result) {
trace!("Signing key not loaded for {}", origin);
servers.insert(origin.to_owned(), BTreeMap::new());
}
let result = services().globals.signing_keys_for(origin)?;
if !result.as_ref().is_some_and(contains_all_ids) {
trace!(
server = %origin,
"Signing key not loaded for server",
);
servers.insert(origin.to_owned(), BTreeMap::new());
}
if let Some(result) = result {
pub_key_map.insert(origin.to_string(), result);
}
}
@ -1747,7 +1741,7 @@ impl Service {
)
.await
{
debug!(%error, "Failed to get server keys from cache");
debug!(%error, "failed to get server keys from cache");
};
}
@ -1760,7 +1754,7 @@ impl Service {
}
for server in services().globals.trusted_servers() {
info!(%server, "Asking batch signing keys from trusted server");
info!("Asking batch signing keys from trusted server {}", server);
if let Ok(keys) = services()
.sending
.send_federation_request(
@ -1771,18 +1765,18 @@ impl Service {
)
.await
{
trace!(signing_keys = ?keys, "Got signing keys");
trace!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
Err(error) => {
Err(e) => {
warn!(
%error,
%server,
object = ?k.json(),
"Failed to fetch keys from trusted server",
"Received error {} while fetching keys from \
trusted server {}",
e, server
);
warn!("{}", k.into_json());
continue;
}
};
@ -1807,7 +1801,7 @@ impl Service {
}
}
info!(?servers, "Asking individual servers for signing keys");
info!("Asking individual servers for signing keys: {servers:?}");
let mut futures: FuturesUnordered<_> = servers
.into_keys()
.map(|server| async move {
@ -1825,8 +1819,9 @@ impl Service {
.collect();
while let Some(result) = futures.next().await {
info!("Received new result");
if let (Ok(get_keys_response), origin) = result {
info!(server = %origin, "Received new result from server");
info!("Result is from {origin}");
if let Ok(key) = get_keys_response.server_key.deserialize() {
let result = services()
.globals
@ -1863,15 +1858,11 @@ impl Service {
return Ok(());
};
let acl_event_content = match serde_json::from_str::<
let Ok(acl_event_content) = serde_json::from_str::<
RoomServerAclEventContent,
>(acl_event.content.get())
{
Ok(x) => x,
Err(error) => {
warn!(%error, "Invalid ACL event");
return Ok(());
}
>(acl_event.content.get()) else {
warn!("Invalid ACL event");
return Ok(());
};
if acl_event_content.allow.is_empty() {
@ -1883,9 +1874,8 @@ impl Service {
Ok(())
} else {
info!(
server = %server_name,
%room_id,
"Other server was denied by room ACL",
"Server {} was denied by room ACL in {}",
server_name, room_id
);
Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -1925,9 +1915,25 @@ impl Service {
let permit = services()
.globals
.servername_ratelimiter
.get_or_insert_with(origin.to_owned(), || Semaphore::new(1))
.await;
let permit = permit.acquire().await;
.read()
.await
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit = if let Some(p) = permit {
p
} else {
let mut write =
services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(
write
.entry(origin.to_owned())
.or_insert_with(|| Arc::new(Semaphore::new(1))),
);
s.acquire_owned()
}
.await;
let back_off = |id| async {
match services()
@ -1961,7 +1967,7 @@ impl Service {
}
if time.elapsed() < min_elapsed_duration {
debug!(?signature_ids, "Backing off from signatures");
debug!("Backing off from {:?}", signature_ids);
return Err(Error::BadServerResponse(
"bad signature, still backing off",
));
@ -1981,10 +1987,8 @@ impl Service {
.expect("Should be valid until year 500,000,000");
debug!(
server = %origin,
ts_threshold = %ts_threshold.get(),
ts_valid_until = %result.valid_until_ts.get(),
"Loaded signing keys for server",
"The threshhold is {:?}, found time is {:?} for server {}",
ts_threshold, result.valid_until_ts, origin
);
if contains_all_ids(&result) {
@ -1994,7 +1998,7 @@ impl Service {
debug!(
origin = %origin,
valid_until_ts = %result.valid_until_ts.get(),
"Keys are valid because they expire after threshold",
"Keys for are deemed as valid, as they expire after threshold",
);
return Ok(result);
}
@ -2174,12 +2178,7 @@ impl Service {
#[tracing::instrument(skip_all)]
fn check_room_id(room_id: &RoomId, pdu: &PduEvent) -> Result<()> {
if pdu.room_id != room_id {
warn!(
event_id = %pdu.event_id,
expected_room_id = %pdu.room_id,
actual_room_id = %room_id,
"Event has wrong room ID",
);
warn!("Found event from room {} in room {}", pdu.room_id, room_id);
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Event has wrong room id",

View file

@ -45,7 +45,12 @@ impl Service {
}
}
#[allow(clippy::too_many_arguments, clippy::too_many_lines)]
#[allow(
clippy::too_many_arguments,
clippy::too_many_lines,
// Allowed because this function uses `services()`
clippy::unused_self,
)]
#[tracing::instrument(skip(self))]
pub(crate) fn paginate_relations_with_filter(
&self,
@ -64,7 +69,9 @@ impl Service {
match ruma::api::Direction::Backward {
ruma::api::Direction::Forward => {
// TODO: should be relations_after
let events_after: Vec<_> = self
let events_after: Vec<_> = services()
.rooms
.pdu_metadata
.relations_until(sender_user, room_id, target, from)?
.filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| {
@ -119,7 +126,9 @@ impl Service {
})
}
ruma::api::Direction::Backward => {
let events_before: Vec<_> = self
let events_before: Vec<_> = services()
.rooms
.pdu_metadata
.relations_until(sender_user, room_id, target, from)?
.filter(|r| {
r.as_ref().map_or(true, |(_, pdu)| {

View file

@ -2,10 +2,7 @@ use std::sync::Arc;
use ruma::{EventId, RoomId, UserId};
use crate::{
service::rooms::{short::ShortRoomId, timeline::PduCount},
PduEvent, Result,
};
use crate::{service::rooms::timeline::PduCount, PduEvent, Result};
pub(crate) trait Data: Send + Sync {
fn add_relation(&self, from: u64, to: u64) -> Result<()>;
@ -13,7 +10,7 @@ pub(crate) trait Data: Send + Sync {
fn relations_until<'a>(
&'a self,
user_id: &'a UserId,
room_id: ShortRoomId,
room_id: u64,
target: u64,
until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>>;

View file

@ -1,22 +1,19 @@
use ruma::RoomId;
use crate::{
service::rooms::{short::ShortRoomId, timeline::PduId},
Result,
};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn index_pdu(
&self,
shortroomid: ShortRoomId,
pdu_id: &PduId,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()>;
fn deindex_pdu(
&self,
shortroomid: ShortRoomId,
pdu_id: &PduId,
shortroomid: u64,
pdu_id: &[u8],
message_body: &str,
) -> Result<()>;
@ -25,5 +22,5 @@ pub(crate) trait Data: Send + Sync {
&'a self,
room_id: &RoomId,
search_string: &str,
) -> Result<Option<(Box<dyn Iterator<Item = PduId> + 'a>, Vec<String>)>>;
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
}

View file

@ -1,27 +1,4 @@
mod data;
macro_rules! short_id_type {
($name:ident) => {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(transparent)]
pub(crate) struct $name(u64);
impl $name {
pub(crate) fn new(id: u64) -> Self {
Self(id)
}
pub(crate) fn get(&self) -> u64 {
self.0
}
}
};
}
short_id_type!(ShortRoomId);
short_id_type!(ShortEventId);
short_id_type!(ShortStateHash);
short_id_type!(ShortStateKey);
pub(crate) use data::Data;
pub(crate) type Service = &'static dyn Data;

View file

@ -2,47 +2,38 @@ use std::sync::Arc;
use ruma::{events::StateEventType, EventId, RoomId};
use super::{ShortEventId, ShortRoomId, ShortStateHash, ShortStateKey};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn get_or_create_shorteventid(
&self,
event_id: &EventId,
) -> Result<ShortEventId>;
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64>;
fn get_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<ShortStateKey>>;
) -> Result<Option<u64>>;
fn get_or_create_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<ShortStateKey>;
) -> Result<u64>;
fn get_eventid_from_short(
&self,
shorteventid: ShortEventId,
) -> Result<Arc<EventId>>;
fn get_eventid_from_short(&self, shorteventid: u64)
-> Result<Arc<EventId>>;
fn get_statekey_from_short(
&self,
shortstatekey: ShortStateKey,
shortstatekey: u64,
) -> Result<(StateEventType, String)>;
/// Returns `(shortstatehash, already_existed)`
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(ShortStateHash, bool)>;
) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<ShortRoomId>>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;
fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<ShortRoomId>;
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64>;
}

View file

@ -179,10 +179,11 @@ impl Service {
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| c.join_rule)
.map_err(|error| {
.map_err(|e| {
error!(
%error,
"Invalid room join rule event"
"Invalid room join rule event in \
database: {}",
e
);
Error::BadDatabase(
"Invalid room join rule event in \
@ -217,7 +218,7 @@ impl Service {
// Early return so the client can see some data already
break;
}
debug!(%server, "Asking other server for /hierarchy");
debug!("Asking {server} for /hierarchy");
if let Ok(response) = services()
.sending
.send_federation_request(
@ -230,9 +231,8 @@ impl Service {
.await
{
warn!(
%server,
?response,
"Got response from other server for /hierarchy",
"Got response from {server} for \
/hierarchy\n{response:?}"
);
let chunk = SpaceHierarchyRoomsChunk {
canonical_alias: response.room.canonical_alias,
@ -327,7 +327,7 @@ impl Service {
}
#[allow(clippy::too_many_lines)]
#[tracing::instrument(skip(self, children))]
#[tracing::instrument(skip(self, sender_user, children))]
fn get_room_chunk(
&self,
sender_user: &UserId,
@ -346,13 +346,8 @@ impl Service {
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|error| {
error!(
%error,
event_id = %s.event_id,
"Invalid room canonical alias event"
);
Error::BadDatabase(
.map_err(|_| {
Error::bad_database(
"Invalid canonical alias event in database.",
)
})
@ -363,7 +358,7 @@ impl Service {
.state_cache
.room_joined_count(room_id)?
.unwrap_or_else(|| {
warn!("Room has no member count");
warn!("Room {} has no member count", room_id);
0
})
.try_into()
@ -376,13 +371,13 @@ impl Service {
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|error| {
.map_err(|_| {
error!(
%error,
event_id = %s.event_id,
"Invalid room topic event"
"Invalid room topic event in database for \
room {}",
room_id
);
Error::BadDatabase(
Error::bad_database(
"Invalid room topic event in database.",
)
})
@ -401,13 +396,8 @@ impl Service {
c.history_visibility
== HistoryVisibility::WorldReadable
})
.map_err(|error| {
error!(
%error,
event_id = %s.event_id,
"Invalid room history visibility event"
);
Error::BadDatabase(
.map_err(|_| {
Error::bad_database(
"Invalid room history visibility event in \
database.",
)
@ -422,13 +412,8 @@ impl Service {
.map(|c: RoomGuestAccessEventContent| {
c.guest_access == GuestAccess::CanJoin
})
.map_err(|error| {
error!(
%error,
event_id = %s.event_id,
"Invalid room guest access event"
);
Error::BadDatabase(
.map_err(|_| {
Error::bad_database(
"Invalid room guest access event in database.",
)
})
@ -440,12 +425,7 @@ impl Service {
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomAvatarEventContent| c.url)
.map_err(|error| {
error!(
%error,
event_id = %s.event_id,
"Invalid room avatar event"
);
.map_err(|_| {
Error::bad_database(
"Invalid room avatar event in database.",
)
@ -465,11 +445,11 @@ impl Service {
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| c.join_rule)
.map_err(|error| {
.map_err(|e| {
error!(
%error,
event_id = %s.event_id,
"Invalid room join rule event",
"Invalid room join rule event in \
database: {}",
e
);
Error::BadDatabase(
"Invalid room join rule event in database.",
@ -480,7 +460,7 @@ impl Service {
.unwrap_or(JoinRule::Invite);
if !self.handle_join_rule(&join_rule, sender_user, room_id)? {
debug!("User is not allowed to see room");
debug!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(
ErrorKind::forbidden(),
@ -498,12 +478,8 @@ impl Service {
serde_json::from_str::<RoomCreateEventContent>(
s.content.get(),
)
.map_err(|error| {
error!(
%error,
event_id = %s.event_id,
"Invalid room create event",
);
.map_err(|e| {
error!("Invalid room create event in database: {}", e);
Error::BadDatabase(
"Invalid room create event in database.",
)

View file

@ -13,18 +13,16 @@ use ruma::{
},
serde::Raw,
state_res::{self, StateMap},
EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId,
EventId, OwnedEventId, RoomId, RoomVersionId, UserId,
};
use serde::Deserialize;
use tokio::sync::MutexGuard;
use tracing::warn;
use super::{short::ShortStateHash, state_compressor::CompressedStateEvent};
use super::state_compressor::CompressedStateEvent;
use crate::{
service::globals::marker,
services,
utils::{
calculate_hash, debug_slice_truncated, on_demand_hashmap::KeyToken,
},
utils::{calculate_hash, debug_slice_truncated},
Error, PduEvent, Result,
};
@ -34,13 +32,20 @@ pub(crate) struct Service {
impl Service {
/// Set the room to the given statehash and update caches.
#[tracing::instrument(skip(self, statediffnew, _statediffremoved))]
#[tracing::instrument(skip(
self,
statediffnew,
_statediffremoved,
state_lock
))]
pub(crate) async fn force_state(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
shortstatehash: ShortStateHash,
room_id: &RoomId,
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| {
services()
@ -111,7 +116,7 @@ impl Service {
services().rooms.state_cache.update_joined_count(room_id)?;
self.db.set_room_state(room_id, shortstatehash)?;
self.db.set_room_state(room_id, shortstatehash, state_lock)?;
Ok(())
}
@ -126,7 +131,7 @@ impl Service {
event_id: &EventId,
room_id: &RoomId,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<ShortStateHash> {
) -> Result<u64> {
let shorteventid =
services().rooms.short.get_or_create_shorteventid(event_id)?;
@ -134,7 +139,7 @@ impl Service {
self.db.get_room_shortstatehash(room_id)?;
let state_hash = calculate_hash(
state_ids_compressed.iter().map(CompressedStateEvent::as_bytes),
&state_ids_compressed.iter().map(|s| &s[..]).collect::<Vec<_>>(),
);
let (shortstatehash, already_existed) =
@ -188,10 +193,7 @@ impl Service {
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu))]
pub(crate) fn append_to_state(
&self,
new_pdu: &PduEvent,
) -> Result<ShortStateHash> {
pub(crate) fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = services()
.rooms
.short
@ -229,9 +231,9 @@ impl Service {
let replaces = states_parents
.last()
.map(|info| {
info.full_state
.iter()
.find(|compressed| compressed.state == shortstatekey)
info.full_state.iter().find(|bytes| {
bytes.starts_with(&shortstatekey.to_be_bytes())
})
})
.unwrap_or_default();
@ -240,8 +242,7 @@ impl Service {
}
// TODO: statehash with deterministic inputs
let shortstatehash =
ShortStateHash::new(services().globals.next_count()?);
let shortstatehash = services().globals.next_count()?;
let mut statediffnew = HashSet::new();
statediffnew.insert(new);
@ -324,10 +325,12 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn set_room_state(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
shortstatehash: ShortStateHash,
room_id: &RoomId,
shortstatehash: u64,
// Take mutex guard to make sure users get the room state mutex
mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
self.db.set_room_state(room_id, shortstatehash)
self.db.set_room_state(room_id, shortstatehash, mutex_lock)
}
/// Returns the room's version.
@ -345,12 +348,10 @@ impl Service {
let create_event_content: RoomCreateEventContent = create_event
.as_ref()
.map(|create_event| {
serde_json::from_str(create_event.content.get()).map_err(
|error| {
warn!(%error, "Invalid create event");
Error::BadDatabase("Invalid create event in db.")
},
)
serde_json::from_str(create_event.content.get()).map_err(|e| {
warn!("Invalid create event: {}", e);
Error::bad_database("Invalid create event in db.")
})
})
.transpose()?
.ok_or_else(|| {
@ -367,7 +368,7 @@ impl Service {
pub(crate) fn get_room_shortstatehash(
&self,
room_id: &RoomId,
) -> Result<Option<ShortStateHash>> {
) -> Result<Option<u64>> {
self.db.get_room_shortstatehash(room_id)
}
@ -380,15 +381,17 @@ impl Service {
}
#[tracing::instrument(
skip(self, event_ids),
skip(self, event_ids, state_lock),
fields(event_ids = debug_slice_truncated(&event_ids, 5)),
)]
pub(crate) fn set_forward_extremities(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<()> {
self.db.set_forward_extremities(room_id, event_ids)
self.db.set_forward_extremities(room_id, event_ids, state_lock)
}
/// This fetches auth events from the current state.
@ -401,7 +404,8 @@ impl Service {
state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let Some(shortstatehash) = self.get_room_shortstatehash(room_id)?
let Some(shortstatehash) =
services().rooms.state.get_room_shortstatehash(room_id)?
else {
return Ok(HashMap::new());
};

View file

@ -1,35 +1,28 @@
use std::{collections::HashSet, sync::Arc};
use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId};
use ruma::{EventId, OwnedEventId, RoomId};
use tokio::sync::MutexGuard;
use crate::{
service::{
globals::marker,
rooms::short::{ShortEventId, ShortStateHash},
},
utils::on_demand_hashmap::KeyToken,
Result,
};
use crate::Result;
pub(crate) trait Data: Send + Sync {
/// Returns the last state hash key added to the db for the given room.
fn get_room_shortstatehash(
&self,
room_id: &RoomId,
) -> Result<Option<ShortStateHash>>;
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>>;
/// Set the state hash to a new version, but does not update `state_cache`.
fn set_room_state(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
new_shortstatehash: ShortStateHash,
room_id: &RoomId,
new_shortstatehash: u64,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()>;
/// Associates a state with an event.
fn set_event_state(
&self,
shorteventid: ShortEventId,
shortstatehash: ShortStateHash,
shorteventid: u64,
shortstatehash: u64,
) -> Result<()>;
/// Returns all events we would send as the `prev_events` of the next event.
@ -41,7 +34,9 @@ pub(crate) trait Data: Send + Sync {
/// Replace the forward extremities of the room.
fn set_forward_extremities(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<()>;
}

View file

@ -20,27 +20,24 @@ use ruma::{
StateEventType,
},
state_res::Event,
EventId, JsOption, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId,
ServerName, UserId,
EventId, JsOption, OwnedServerName, OwnedUserId, RoomId, ServerName,
UserId,
};
use serde_json::value::to_raw_value;
use tokio::sync::MutexGuard;
use tracing::{error, warn};
use super::short::{ShortStateHash, ShortStateKey};
use crate::{
observability::{FoundIn, Lookup, METRICS},
service::{globals::marker, pdu::PduBuilder},
services,
utils::on_demand_hashmap::KeyToken,
Error, PduEvent, Result,
service::pdu::PduBuilder,
services, Error, PduEvent, Result,
};
pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
pub(crate) server_visibility_cache:
Mutex<LruCache<(OwnedServerName, ShortStateHash), bool>>,
pub(crate) user_visibility_cache:
Mutex<LruCache<(OwnedUserId, ShortStateHash), bool>>,
Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub(crate) user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
}
impl Service {
@ -49,15 +46,15 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) async fn state_full_ids(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Arc<EventId>>> {
shortstatehash: u64,
) -> Result<HashMap<u64, Arc<EventId>>> {
self.db.state_full_ids(shortstatehash).await
}
#[tracing::instrument(skip(self))]
pub(crate) async fn state_full(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
self.db.state_full(shortstatehash).await
}
@ -67,7 +64,7 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn state_get_id(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>> {
@ -79,7 +76,7 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn state_get(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
@ -90,7 +87,7 @@ impl Service {
#[tracing::instrument(skip(self))]
fn user_membership(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
user_id: &UserId,
) -> Result<MembershipState> {
self.state_get(
@ -111,11 +108,7 @@ impl Service {
/// The user was a joined member at this state (potentially in the past)
#[tracing::instrument(skip(self), ret(level = "trace"))]
fn user_was_joined(
&self,
shortstatehash: ShortStateHash,
user_id: &UserId,
) -> bool {
fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join)
}
@ -123,11 +116,7 @@ impl Service {
/// The user was an invited or joined room member at this state (potentially
/// in the past)
#[tracing::instrument(skip(self), ret(level = "trace"))]
fn user_was_invited(
&self,
shortstatehash: ShortStateHash,
user_id: &UserId,
) -> bool {
fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).is_ok_and(|s| {
s == MembershipState::Join || s == MembershipState::Invite
})
@ -199,8 +188,8 @@ impl Service {
current_server_members
.any(|member| self.user_was_joined(shortstatehash, &member))
}
other => {
error!(kind = %other, "Unknown history visibility");
_ => {
error!("Unknown history visibility {history_visibility}");
false
}
};
@ -272,8 +261,8 @@ impl Service {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id)
}
other => {
error!(kind = %other, "Unknown history visibility");
_ => {
error!("Unknown history visibility {history_visibility}");
false
}
};
@ -325,7 +314,7 @@ impl Service {
pub(crate) fn pdu_shortstatehash(
&self,
event_id: &EventId,
) -> Result<Option<ShortStateHash>> {
) -> Result<Option<u64>> {
self.db.pdu_shortstatehash(event_id)
}
@ -369,9 +358,13 @@ impl Service {
|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomNameEventContent| Some(c.name))
.map_err(|error| {
error!(%error, "Invalid room name event in database");
Error::BadDatabase(
.map_err(|e| {
error!(
"Invalid room name event in database for room {}. \
{}",
room_id, e
);
Error::bad_database(
"Invalid room name event in database.",
)
})
@ -401,9 +394,10 @@ impl Service {
#[tracing::instrument(skip(self), ret(level = "trace"))]
pub(crate) fn user_can_invite(
&self,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
sender: &UserId,
target_user: &UserId,
state_lock: &MutexGuard<'_, ()>,
) -> bool {
let content =
to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite))
@ -420,7 +414,7 @@ impl Service {
services()
.rooms
.timeline
.create_hash_and_sign_event(new_event, sender, room_id)
.create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.is_ok()
}

View file

@ -3,10 +3,7 @@ use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{
service::rooms::short::{ShortStateHash, ShortStateKey},
PduEvent, Result,
};
use crate::{PduEvent, Result};
#[async_trait]
pub(crate) trait Data: Send + Sync {
@ -14,19 +11,19 @@ pub(crate) trait Data: Send + Sync {
/// with state_hash, this gives the full state for the given state_hash.
async fn state_full_ids(
&self,
shortstatehash: ShortStateHash,
) -> Result<HashMap<ShortStateKey, Arc<EventId>>>;
shortstatehash: u64,
) -> Result<HashMap<u64, Arc<EventId>>>;
async fn state_full(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>>;
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
fn state_get_id(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<EventId>>>;
@ -35,16 +32,13 @@ pub(crate) trait Data: Send + Sync {
/// `state_key`).
fn state_get(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>>;
/// Returns the state hash for this pdu.
fn pdu_shortstatehash(
&self,
event_id: &EventId,
) -> Result<Option<ShortStateHash>>;
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>>;
/// Returns the full room state.
async fn room_state_full(

View file

@ -4,6 +4,7 @@ use std::{collections::HashSet, sync::Arc};
pub(crate) use data::Data;
use ruma::{
events::{
direct::DirectEvent,
ignored_user_list::IgnoredUserListEvent,
room::{create::RoomCreateEventContent, member::MembershipState},
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
@ -64,21 +65,118 @@ impl Service {
content.predecessor
})
{
self.copy_upgraded_account_data(
user_id,
&predecessor.room_id,
room_id,
)?;
// Copy user settings from predecessor to the current
// room:
// - Push rules
//
// TODO: finish this once push rules are implemented.
//
// let mut push_rules_event_content: PushRulesEvent =
// account_data .get(
// None,
// user_id,
// EventType::PushRules,
// )?;
//
// NOTE: find where `predecessor.room_id` match
// and update to `room_id`.
//
// account_data
// .update(
// None,
// user_id,
// EventType::PushRules,
// &push_rules_event_content,
// globals,
// )
// .ok();
// Copy old tags to new room
if let Some(tag_event) = services()
.account_data
.get(
Some(&predecessor.room_id),
user_id,
RoomAccountDataEventType::Tag,
)?
.map(|event| {
serde_json::from_str(event.get()).map_err(|e| {
warn!(
"Invalid account data event in db: \
{e:?}"
);
Error::BadDatabase(
"Invalid account data event in db.",
)
})
})
{
services()
.account_data
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&tag_event?,
)
.ok();
};
// Copy direct chat flag
if let Some(direct_event) = services()
.account_data
.get(
None,
user_id,
GlobalAccountDataEventType::Direct
.to_string()
.into(),
)?
.map(|event| {
serde_json::from_str::<DirectEvent>(event.get())
.map_err(|e| {
warn!(
"Invalid account data event in \
db: {e:?}"
);
Error::BadDatabase(
"Invalid account data event in db.",
)
})
})
{
let mut direct_event = direct_event?;
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut()
{
if room_ids
.iter()
.any(|r| r == &predecessor.room_id)
{
room_ids.push(room_id.to_owned());
room_ids_updated = true;
}
}
if room_ids_updated {
services().account_data.update(
None,
user_id,
GlobalAccountDataEventType::Direct
.to_string()
.into(),
&serde_json::to_value(&direct_event)
.expect("to json always works"),
)?;
}
};
}
}
self.db.mark_as_joined(user_id, room_id)?;
}
MembershipState::Invite => {
let event_kind = RoomAccountDataEventType::from(
GlobalAccountDataEventType::IgnoredUserList.to_string(),
);
// We want to know if the sender is ignored by the receiver
let is_ignored = services()
.account_data
@ -87,19 +185,19 @@ impl Service {
None,
// Receiver
user_id,
event_kind.clone(),
GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)?
.map(|event| {
serde_json::from_str::<IgnoredUserListEvent>(
event.get(),
)
.map_err(|error| {
warn!(
%error,
%event_kind,
"Invalid account data event",
);
Error::BadDatabase("Invalid account data event.")
.map_err(|e| {
warn!("Invalid account data event in db: {e:?}");
Error::BadDatabase(
"Invalid account data event in db.",
)
})
})
.transpose()?
@ -130,156 +228,6 @@ impl Service {
Ok(())
}
/// Copy all account data references from the predecessor to a new room when
/// joining an upgraded room.
///
/// References to the predecessor room are not removed.
#[tracing::instrument(skip(self))]
fn copy_upgraded_account_data(
&self,
user_id: &UserId,
from_room_id: &RoomId,
to_room_id: &RoomId,
) -> Result<()> {
// - Push rules
//
// TODO: finish this once push rules are implemented.
//
// let mut push_rules_event_content: PushRulesEvent =
// account_data .get(
// None,
// user_id,
// EventType::PushRules,
// )?;
//
// NOTE: find where `predecessor.room_id` match
// and update to `room_id`.
//
// account_data
// .update(
// None,
// user_id,
// EventType::PushRules,
// &push_rules_event_content,
// globals,
// )
// .ok();
self.copy_upgraded_account_data_tag(user_id, from_room_id, to_room_id)?;
self.copy_upgraded_account_data_direct(
user_id,
from_room_id,
to_room_id,
)?;
Ok(())
}
/// Copy `m.tag` account data to an upgraded room.
// Allowed because this function uses `services()`
#[allow(clippy::unused_self)]
fn copy_upgraded_account_data_tag(
&self,
user_id: &UserId,
from_room_id: &RoomId,
to_room_id: &RoomId,
) -> Result<()> {
let Some(event) = services().account_data.get(
Some(from_room_id),
user_id,
RoomAccountDataEventType::Tag,
)?
else {
return Ok(());
};
let event = serde_json::from_str::<serde_json::Value>(event.get())
.expect("RawValue -> Value should always succeed");
if let Err(error) = services().account_data.update(
Some(to_room_id),
user_id,
RoomAccountDataEventType::Tag,
&event,
) {
warn!(%error, "error writing m.tag account data to upgraded room");
}
Ok(())
}
/// Copy references in `m.direct` account data events to an upgraded room.
// Allowed because this function uses `services()`
#[allow(clippy::unused_self)]
fn copy_upgraded_account_data_direct(
&self,
user_id: &UserId,
from_room_id: &RoomId,
to_room_id: &RoomId,
) -> Result<()> {
let event_kind = RoomAccountDataEventType::from(
GlobalAccountDataEventType::Direct.to_string(),
);
let Some(event) =
services().account_data.get(None, user_id, event_kind.clone())?
else {
return Ok(());
};
let mut event = serde_json::from_str::<serde_json::Value>(event.get())
.expect("RawValue -> Value should always succeed");
// As a server, we should try not to assume anything about the schema
// of this event. Account data may be arbitrary JSON.
//
// In particular, there is an element bug[1] that causes it to store
// m.direct events that don't match the schema from the spec.
//
// [1]: https://github.com/element-hq/element-web/issues/27630
//
// A valid m.direct event looks like this:
//
// {
// "type": "m.account_data",
// "content": {
// "@userid1": [ "!roomid1", "!roomid2" ],
// "@userid2": [ "!roomid3" ],
// }
// }
//
// We want to find userid keys where the value contains from_room_id,
// and insert a new entry for to_room_id. This should work even if some
// of the userid keys do not conform to the spec. If parts of the object
// do not match the expected schema, we should prefer to skip just those
// parts.
let mut event_updated = false;
let Some(direct_user_ids) = event.get_mut("content") else {
return Ok(());
};
let Some(direct_user_ids) = direct_user_ids.as_object_mut() else {
return Ok(());
};
for room_ids in direct_user_ids.values_mut() {
let Some(room_ids) = room_ids.as_array_mut() else {
continue;
};
if room_ids.iter().any(|room_id| room_id == from_room_id.as_str()) {
room_ids.push(to_room_id.to_string().into());
event_updated = true;
}
}
if event_updated {
if let Err(error) = services().account_data.update(
None,
user_id,
event_kind.clone(),
&event,
) {
warn!(%event_kind, %error, "error writing account data event after upgrading room");
}
}
Ok(())
}
#[tracing::instrument(skip(self, room_id))]
pub(crate) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
self.db.update_joined_count(room_id)

View file

@ -1,5 +1,4 @@
use std::{
array,
collections::HashSet,
mem::size_of,
sync::{Arc, Mutex},
@ -18,11 +17,9 @@ pub(crate) mod data;
pub(crate) use data::Data;
use data::StateDiff;
use super::short::{ShortEventId, ShortStateHash, ShortStateKey};
#[derive(Clone)]
pub(crate) struct CompressedStateLayer {
pub(crate) shortstatehash: ShortStateHash,
pub(crate) shortstatehash: u64,
pub(crate) full_state: Arc<HashSet<CompressedStateEvent>>,
pub(crate) added: Arc<HashSet<CompressedStateEvent>>,
pub(crate) removed: Arc<HashSet<CompressedStateEvent>>,
@ -32,45 +29,10 @@ pub(crate) struct Service {
pub(crate) db: &'static dyn Data,
#[allow(clippy::type_complexity)]
pub(crate) stateinfo_cache:
Mutex<LruCache<ShortStateHash, Vec<CompressedStateLayer>>>,
pub(crate) stateinfo_cache: Mutex<LruCache<u64, Vec<CompressedStateLayer>>>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct CompressedStateEvent {
pub(crate) state: ShortStateKey,
pub(crate) event: ShortEventId,
}
impl CompressedStateEvent {
pub(crate) fn as_bytes(
&self,
) -> [u8; size_of::<ShortStateKey>() + size_of::<ShortEventId>()] {
let mut bytes = self
.state
.get()
.to_be_bytes()
.into_iter()
.chain(self.event.get().to_be_bytes());
array::from_fn(|_| bytes.next().unwrap())
}
pub(crate) fn from_bytes(
bytes: [u8; size_of::<ShortStateKey>() + size_of::<ShortEventId>()],
) -> Self {
let state = ShortStateKey::new(u64::from_be_bytes(
bytes[0..8].try_into().unwrap(),
));
let event = ShortEventId::new(u64::from_be_bytes(
bytes[8..16].try_into().unwrap(),
));
Self {
state,
event,
}
}
}
pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
@ -79,7 +41,7 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn load_shortstatehash_info(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<Vec<CompressedStateLayer>> {
let lookup = Lookup::StateInfo;
@ -134,16 +96,18 @@ impl Service {
#[allow(clippy::unused_self)]
pub(crate) fn compress_state_event(
&self,
shortstatekey: ShortStateKey,
shortstatekey: u64,
event_id: &EventId,
) -> Result<CompressedStateEvent> {
Ok(CompressedStateEvent {
state: shortstatekey,
event: services()
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&services()
.rooms
.short
.get_or_create_shorteventid(event_id)?,
})
.get_or_create_shorteventid(event_id)?
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above"))
}
/// Returns shortstatekey, event id
@ -152,13 +116,14 @@ impl Service {
pub(crate) fn parse_compressed_state_event(
&self,
compressed_event: &CompressedStateEvent,
) -> Result<(ShortStateKey, Arc<EventId>)> {
) -> Result<(u64, Arc<EventId>)> {
Ok((
compressed_event.state,
services()
.rooms
.short
.get_eventid_from_short(compressed_event.event)?,
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()])
.expect("bytes have right length"),
services().rooms.short.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..])
.expect("bytes have right length"),
)?,
))
}
@ -190,7 +155,7 @@ impl Service {
))]
pub(crate) fn save_state_from_diff(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>,
diff_to_sibling: usize,
@ -310,7 +275,7 @@ impl Service {
room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<(
ShortStateHash,
u64,
Arc<HashSet<CompressedStateEvent>>,
Arc<HashSet<CompressedStateEvent>>,
)> {
@ -318,7 +283,10 @@ impl Service {
services().rooms.state.get_room_shortstatehash(room_id)?;
let state_hash = utils::calculate_hash(
new_state_ids_compressed.iter().map(CompressedStateEvent::as_bytes),
&new_state_ids_compressed
.iter()
.map(|bytes| &bytes[..])
.collect::<Vec<_>>(),
);
let (new_shortstatehash, already_existed) =

View file

@ -1,22 +1,19 @@
use std::{collections::HashSet, sync::Arc};
use super::CompressedStateEvent;
use crate::{service::rooms::short::ShortStateHash, Result};
use crate::Result;
pub(crate) struct StateDiff {
pub(crate) parent: Option<ShortStateHash>,
pub(crate) parent: Option<u64>,
pub(crate) added: Arc<HashSet<CompressedStateEvent>>,
pub(crate) removed: Arc<HashSet<CompressedStateEvent>>,
}
pub(crate) trait Data: Send + Sync {
fn get_statediff(
&self,
shortstatehash: ShortStateHash,
) -> Result<StateDiff>;
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff>;
fn save_statediff(
&self,
shortstatehash: ShortStateHash,
shortstatehash: u64,
diff: StateDiff,
) -> Result<()>;
}

View file

@ -3,7 +3,7 @@ use ruma::{
UserId,
};
use crate::{service::rooms::timeline::PduId, PduEvent, Result};
use crate::{PduEvent, Result};
pub(crate) trait Data: Send + Sync {
#[allow(clippy::type_complexity)]
@ -17,11 +17,11 @@ pub(crate) trait Data: Send + Sync {
fn update_participants(
&self,
root_id: &PduId,
root_id: &[u8],
participants: &[OwnedUserId],
) -> Result<()>;
fn get_participants(
&self,
root_id: &PduId,
root_id: &[u8],
) -> Result<Option<Vec<OwnedUserId>>>;
}

View file

@ -14,8 +14,7 @@ use ruma::{
push_rules::PushRulesEvent,
room::{
create::RoomCreateEventContent, encrypted::Relation,
member::MembershipState, message::RoomMessageEventContent,
power_levels::RoomPowerLevelsEventContent,
member::MembershipState, power_levels::RoomPowerLevelsEventContent,
redaction::RoomRedactionEventContent,
},
GlobalAccountDataEventType, StateEventType, TimelineEventType,
@ -23,44 +22,24 @@ use ruma::{
push::{Action, Ruleset, Tweak},
state_res::{self, Event, RoomVersion},
uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId,
OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId,
ServerName, UserId,
OwnedEventId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tokio::sync::{MutexGuard, RwLock};
use tracing::{error, info, warn};
use super::{short::ShortRoomId, state_compressor::CompressedStateEvent};
use super::state_compressor::CompressedStateEvent;
use crate::{
api::server_server,
service::{
appservice::NamespaceRegex,
globals::{marker, SigningKeys},
globals::SigningKeys,
pdu::{EventHash, PduBuilder},
},
services,
utils::{self, on_demand_hashmap::KeyToken},
Error, PduEvent, Result,
services, utils, Error, PduEvent, Result,
};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) struct PduId {
inner: Vec<u8>,
}
impl PduId {
pub(crate) fn new(inner: Vec<u8>) -> Self {
Self {
inner,
}
}
pub(crate) fn as_bytes(&self) -> &[u8] {
&self.inner
}
}
#[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)]
pub(crate) enum PduCount {
Backfilled(u64),
@ -163,7 +142,7 @@ impl Service {
pub(crate) fn get_pdu_id(
&self,
event_id: &EventId,
) -> Result<Option<PduId>> {
) -> Result<Option<Vec<u8>>> {
self.db.get_pdu_id(event_id)
}
@ -182,7 +161,7 @@ impl Service {
/// This does __NOT__ check the outliers `Tree`.
pub(crate) fn get_pdu_from_id(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
) -> Result<Option<PduEvent>> {
self.db.get_pdu_from_id(pdu_id)
}
@ -190,7 +169,7 @@ impl Service {
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(crate) fn get_pdu_json_from_id(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
) -> Result<Option<CanonicalJsonObject>> {
self.db.get_pdu_json_from_id(pdu_id)
}
@ -199,7 +178,7 @@ impl Service {
#[tracing::instrument(skip(self))]
pub(crate) fn replace_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
pdu_json: &CanonicalJsonObject,
pdu: &PduEvent,
) -> Result<()> {
@ -218,10 +197,9 @@ impl Service {
pdu: &PduEvent,
mut pdu_json: CanonicalJsonObject,
leaves: Vec<OwnedEventId>,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<PduId> {
assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed");
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<Vec<u8>> {
let shortroomid = services()
.rooms
.short
@ -266,7 +244,7 @@ impl Service {
}
}
} else {
error!("Invalid unsigned type in pdu");
error!("Invalid unsigned type in pdu.");
}
}
@ -275,13 +253,22 @@ impl Service {
.rooms
.pdu_metadata
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services().rooms.state.set_forward_extremities(room_id, leaves)?;
services().rooms.state.set_forward_extremities(
&pdu.room_id,
leaves,
state_lock,
)?;
let insert_token = services()
.globals
.roomid_mutex_insert
.lock_key(pdu.room_id.clone())
.await;
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(pdu.room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await;
let count1 = services().globals.next_count()?;
// Mark as read first so the sending client doesn't get a notification
@ -297,14 +284,13 @@ impl Service {
.reset_notification_counts(&pdu.sender, &pdu.room_id)?;
let count2 = services().globals.next_count()?;
let mut pdu_id = shortroomid.get().to_be_bytes().to_vec();
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes());
let pdu_id = PduId::new(pdu_id);
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?;
drop(insert_token);
drop(insert_lock);
// See if the event matches any known pushers
let power_levels: RoomPowerLevelsEventContent = services()
@ -410,8 +396,17 @@ impl Service {
TimelineEventType::RoomRedaction => {
let room_version_id =
services().rooms.state.get_room_version(&pdu.room_id)?;
match &room_version_id {
room_version if *room_version < RoomVersionId::V11 => {
match room_version_id {
RoomVersionId::V1
| RoomVersionId::V2
| RoomVersionId::V3
| RoomVersionId::V4
| RoomVersionId::V5
| RoomVersionId::V6
| RoomVersionId::V7
| RoomVersionId::V8
| RoomVersionId::V9
| RoomVersionId::V10 => {
if let Some(redact_id) = &pdu.redacts {
if services().rooms.state_accessor.user_can_redact(
redact_id,
@ -445,9 +440,7 @@ impl Service {
}
}
_ => {
return Err(Error::BadServerResponse(
"Unsupported room version.",
));
unreachable!("Validity of room version already checked")
}
};
}
@ -467,21 +460,20 @@ impl Service {
#[derive(Deserialize)]
struct ExtractMembership {
membership: MembershipState,
reason: Option<String>,
}
// if the state_key fails
let target_user_id = UserId::parse(state_key.clone())
.expect("This state_key was previously validated");
let ExtractMembership {
membership,
reason,
} = serde_json::from_str(pdu.content.get()).map_err(
|_| Error::bad_database("Invalid content in pdu."),
)?;
let content = serde_json::from_str::<ExtractMembership>(
pdu.content.get(),
)
.map_err(|_| {
Error::bad_database("Invalid content in pdu.")
})?;
let invite_state = match membership {
let invite_state = match content.membership {
MembershipState::Invite => {
let state = services()
.rooms
@ -492,41 +484,13 @@ impl Service {
_ => None,
};
if membership == MembershipState::Ban {
let (room, user) = (&pdu.room_id, &target_user_id);
if user.server_name()
== services().globals.server_name()
{
info!(
%user,
%room,
reason,
"User has been banned from room"
);
let reason = match reason.filter(|s| !s.is_empty())
{
Some(s) => format!(": {s}"),
None => String::new(),
};
services().admin.send_message(
RoomMessageEventContent::notice_plain(format!(
"User {user} has been banned from room \
{room}{reason}",
)),
);
}
}
// Update our membership info, we do this here incase a user
// is invited and immediately leaves we
// need the DB to record the invite event for auth
services().rooms.state_cache.update_membership(
&pdu.room_id,
&target_user_id,
membership,
content.membership,
&pdu.sender,
invite_state,
true,
@ -552,35 +516,36 @@ impl Service {
&body,
)?;
let admin_bot = &services().globals.admin_bot_user_id;
let server_user = format!(
"@{}:{}",
if services().globals.config.conduit_compat {
"conduit"
} else {
"grapevine"
},
services().globals.server_name()
);
let to_admin_bot = body
.starts_with(&format!("{admin_bot}: "))
|| body.starts_with(&format!("{admin_bot} "))
|| body == format!("{admin_bot}:")
|| body == admin_bot.as_str()
|| body.starts_with("!admin ")
|| body == "!admin";
let to_grapevine = body
.starts_with(&format!("{server_user}: "))
|| body.starts_with(&format!("{server_user} "))
|| body == format!("{server_user}:")
|| body == server_user;
// This will evaluate to false if the emergency password
// is set up so that the administrator can execute commands
// as the admin bot
let from_admin_bot = &pdu.sender == admin_bot
// This will evaluate to false if the emergency password is
// set up so that the administrator can
// execute commands as grapevine
let from_grapevine = pdu.sender == server_user
&& services().globals.emergency_password().is_none();
if let Some(admin_room) =
services().admin.get_admin_room()?
{
if to_admin_bot
&& !from_admin_bot
if to_grapevine
&& !from_grapevine
&& admin_room == pdu.room_id
&& services()
.rooms
.state_cache
.is_joined(admin_bot, &admin_room)
.unwrap_or(false)
{
services().admin.process_message(body);
services().admin.process_message(pdu.clone(), body);
}
}
}
@ -593,8 +558,10 @@ impl Service {
if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
{
if let Some(related_pducount) =
self.get_pdu_count(&content.relates_to.event_id)?
if let Some(related_pducount) = services()
.rooms
.timeline
.get_pdu_count(&content.relates_to.event_id)?
{
services()
.rooms
@ -612,8 +579,10 @@ impl Service {
} => {
// We need to do it again here, because replies don't have
// event_id as a top level field
if let Some(related_pducount) =
self.get_pdu_count(&in_reply_to.event_id)?
if let Some(related_pducount) = services()
.rooms
.timeline
.get_pdu_count(&in_reply_to.event_id)?
{
services().rooms.pdu_metadata.add_relation(
PduCount::Normal(count2),
@ -701,7 +670,9 @@ impl Service {
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
// Take mutex guard to make sure users get the room state mutex
_mutex_lock: &MutexGuard<'_, ()>,
) -> Result<(PduEvent, CanonicalJsonObject)> {
let PduBuilder {
event_type,
@ -732,7 +703,7 @@ impl Service {
} else {
Err(Error::InconsistentRoomState(
"non-create event for room of unknown version",
(**room_id).clone(),
room_id.to_owned(),
))
}
})?;
@ -781,7 +752,7 @@ impl Service {
let mut pdu = PduEvent {
event_id: ruma::event_id!("$thiswillbefilledinlater").into(),
room_id: (**room_id).clone(),
room_id: room_id.to_owned(),
sender: sender.to_owned(),
origin_server_ts: utils::millis_since_unix_epoch()
.try_into()
@ -816,9 +787,9 @@ impl Service {
None::<PduEvent>,
|k, s| auth_events.get(&(k.clone(), s.to_owned())),
)
.map_err(|error| {
error!(%error, "Auth check failed");
Error::BadDatabase("Auth check failed.")
.map_err(|e| {
error!("{:?}", e);
Error::bad_database("Auth check failed.")
})?;
if !auth_check {
@ -885,18 +856,24 @@ impl Service {
/// Creates a new persisted data unit and adds it to a room. This function
/// takes a roomid_mutex_state, meaning that only this function is able
/// to mutate the room state.
#[tracing::instrument(skip(self))]
#[tracing::instrument(skip(self, state_lock))]
pub(crate) async fn build_and_append_pdu(
&self,
pdu_builder: PduBuilder,
sender: &UserId,
room_id: &KeyToken<OwnedRoomId, marker::State>,
room_id: &RoomId,
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<Arc<EventId>> {
let (pdu, pdu_json) =
self.create_hash_and_sign_event(pdu_builder, sender, room_id)?;
let (pdu, pdu_json) = self.create_hash_and_sign_event(
pdu_builder,
sender,
room_id,
state_lock,
)?;
if let Some(admin_room) = services().admin.get_admin_room()? {
if admin_room == **room_id {
if admin_room == room_id {
match pdu.event_type() {
TimelineEventType::RoomEncryption => {
warn!("Encryption is not allowed in the admins room");
@ -1076,14 +1053,18 @@ impl Service {
// Since this PDU references all pdu_leaves we can update the
// leaves of the room
vec![(*pdu.event_id).to_owned()],
room_id,
state_lock,
)
.await?;
// We set the room state after inserting the pdu, so that we never have
// a moment in time where events in the current room state do
// not exist
services().rooms.state.set_room_state(room_id, statehashid)?;
services().rooms.state.set_room_state(
room_id,
statehashid,
state_lock,
)?;
let mut servers: HashSet<OwnedServerName> = services()
.rooms
@ -1123,10 +1104,9 @@ impl Service {
new_room_leaves: Vec<OwnedEventId>,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
soft_fail: bool,
room_id: &KeyToken<OwnedRoomId, marker::State>,
) -> Result<Option<PduId>> {
assert_eq!(*pdu.room_id, **room_id, "Token for incorrect room passed");
// Take mutex guard to make sure users get the room state mutex
state_lock: &MutexGuard<'_, ()>,
) -> Result<Option<Vec<u8>>> {
// We append to state before appending the pdu, so we don't have a
// moment in time with the pdu without it's state. This is okay
// because append_pdu can't fail.
@ -1140,16 +1120,20 @@ impl Service {
services()
.rooms
.pdu_metadata
.mark_as_referenced(room_id, &pdu.prev_events)?;
services()
.rooms
.state
.set_forward_extremities(room_id, new_room_leaves)?;
.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?;
services().rooms.state.set_forward_extremities(
&pdu.room_id,
new_room_leaves,
state_lock,
)?;
return Ok(None);
}
let pdu_id =
self.append_pdu(pdu, pdu_json, new_room_leaves, room_id).await?;
let pdu_id = services()
.rooms
.timeline
.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
.await?;
Ok(Some(pdu_id))
}
@ -1194,7 +1178,7 @@ impl Service {
&self,
event_id: &EventId,
reason: &PduEvent,
shortroomid: ShortRoomId,
shortroomid: u64,
) -> Result<()> {
// TODO: Don't reserialize, keep original json
if let Some(pdu_id) = self.get_pdu_id(event_id)? {
@ -1268,7 +1252,7 @@ impl Service {
// Request backfill
for backfill_server in admin_servers {
info!(server = %backfill_server, "Asking server for backfill");
info!("Asking {backfill_server} for backfill");
let response = services()
.sending
.send_federation_request(
@ -1284,21 +1268,17 @@ impl Service {
Ok(response) => {
let pub_key_map = RwLock::new(BTreeMap::new());
for pdu in response.pdus {
if let Err(error) = self
if let Err(e) = self
.backfill_pdu(backfill_server, pdu, &pub_key_map)
.await
{
warn!(%error, "Failed to add backfilled pdu");
warn!("Failed to add backfilled pdu: {e}");
}
}
return Ok(());
}
Err(error) => {
warn!(
server = %backfill_server,
%error,
"Server could not provide backfill",
);
Err(e) => {
warn!("{backfill_server} could not provide backfill: {e}");
}
}
}
@ -1318,15 +1298,20 @@ impl Service {
server_server::parse_incoming_pdu(&pdu)?;
// Lock so we cannot backfill the same pdu twice at the same time
let federation_token = services()
.globals
.roomid_mutex_federation
.lock_key(room_id.clone())
.await;
let mutex = Arc::clone(
services()
.globals
.roomid_mutex_federation
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let mutex_lock = mutex.lock().await;
// Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = self.get_pdu_id(&event_id)? {
info!(%event_id, ?pdu_id, "We already know this event");
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(&event_id)? {
info!("We already know {event_id} at {pdu_id:?}");
return Ok(());
}
@ -1352,22 +1337,26 @@ impl Service {
.get_shortroomid(&room_id)?
.expect("room exists");
let insert_token = services()
.globals
.roomid_mutex_insert
.lock_key(room_id.clone())
.await;
let mutex_insert = Arc::clone(
services()
.globals
.roomid_mutex_insert
.write()
.await
.entry(room_id.clone())
.or_default(),
);
let insert_lock = mutex_insert.lock().await;
let count = services().globals.next_count()?;
let mut pdu_id = shortroomid.get().to_be_bytes().to_vec();
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes());
let pdu_id = PduId::new(pdu_id);
// Insert pdu
self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?;
drop(insert_token);
drop(insert_lock);
if pdu.kind == TimelineEventType::RoomMessage {
#[derive(Deserialize)]
@ -1389,7 +1378,7 @@ impl Service {
)?;
}
}
drop(federation_token);
drop(mutex_lock);
info!("Prepended backfill pdu");
Ok(())

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
use super::PduCount;
use crate::{service::rooms::timeline::PduId, PduEvent, Result};
use crate::{PduEvent, Result};
pub(crate) trait Data: Send + Sync {
fn last_timeline_count(
@ -28,7 +28,7 @@ pub(crate) trait Data: Send + Sync {
) -> Result<Option<CanonicalJsonObject>>;
/// Returns the pdu's id.
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<PduId>>;
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>>;
/// Returns the pdu.
///
@ -46,18 +46,18 @@ pub(crate) trait Data: Send + Sync {
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
fn get_pdu_from_id(&self, pdu_id: &PduId) -> Result<Option<PduEvent>>;
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>>;
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
fn get_pdu_json_from_id(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
) -> Result<Option<CanonicalJsonObject>>;
/// Adds a new pdu to the timeline
fn append_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
pdu: &PduEvent,
json: &CanonicalJsonObject,
count: u64,
@ -66,7 +66,7 @@ pub(crate) trait Data: Send + Sync {
// Adds a new pdu to the backfilled timeline
fn prepend_backfill_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
event_id: &EventId,
json: &CanonicalJsonObject,
) -> Result<()>;
@ -74,7 +74,7 @@ pub(crate) trait Data: Send + Sync {
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
pdu_json: &CanonicalJsonObject,
pdu: &PduEvent,
) -> Result<()>;

View file

@ -1,6 +1,6 @@
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use crate::{service::rooms::short::ShortStateHash, Result};
use crate::Result;
pub(crate) trait Data: Send + Sync {
fn reset_notification_counts(
@ -32,14 +32,14 @@ pub(crate) trait Data: Send + Sync {
&self,
room_id: &RoomId,
token: u64,
shortstatehash: ShortStateHash,
shortstatehash: u64,
) -> Result<()>;
fn get_token_shortstatehash(
&self,
room_id: &RoomId,
token: u64,
) -> Result<Option<ShortStateHash>>;
) -> Result<Option<u64>>;
fn get_shared_rooms<'a>(
&'a self,

View file

@ -28,10 +28,8 @@ use ruma::{
push_rules::PushRulesEvent, receipt::ReceiptType,
AnySyncEphemeralRoomEvent, GlobalAccountDataEventType,
},
push,
serde::Raw,
uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName,
UInt, UserId,
push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId,
ServerName, UInt, UserId,
};
use tokio::{
select,
@ -39,7 +37,6 @@ use tokio::{
};
use tracing::{debug, error, warn, Span};
use super::rooms::timeline::PduId;
use crate::{
api::{appservice_server, server_server},
services,
@ -83,12 +80,12 @@ impl Destination {
}
}
#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) enum SendingEventType {
// pduid
Pdu(PduId),
Pdu(Vec<u8>),
// pdu json
Edu(Raw<Edu>),
Edu(Vec<u8>),
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
@ -162,7 +159,7 @@ impl Service {
sender,
receiver: Mutex::new(receiver),
maximum_requests: Arc::new(Semaphore::new(
config.federation.max_concurrent_requests.into(),
config.max_concurrent_requests.into(),
)),
})
}
@ -193,10 +190,8 @@ impl Service {
if entry.len() > 30 {
warn!(
?key,
?destination,
?event,
"Dropping some current events",
"Dropping some current events: {:?} {:?} {:?}",
key, destination, event
);
self.db.delete_active_request(key)?;
continue;
@ -261,47 +256,57 @@ impl Service {
Span::current().record("error", e.to_string());
}
if let Err(error) = result {
warn!(%error, "Marking transaction as failed");
current_transaction_status.entry(destination).and_modify(|e| {
use TransactionStatus::{Failed, Retrying, Running};
match result {
Ok(()) => {
self.db.delete_all_active_requests_for(&destination)?;
*e = match e {
Running => Failed(1, Instant::now()),
Retrying(n) => Failed(*n + 1, Instant::now()),
Failed(..) => {
error!("Request that was not even running failed?!");
return;
}
// Find events that have been added since starting the
// last request
let new_events = self
.db
.queued_requests(&destination)
.filter_map(Result::ok)
.take(30)
.collect::<Vec<_>>();
if new_events.is_empty() {
current_transaction_status.remove(&destination);
Ok(None)
} else {
// Insert pdus we found
self.db.mark_as_active(&new_events)?;
Ok(Some(HandlerInputs {
destination: destination.clone(),
events: new_events
.into_iter()
.map(|(event, _)| event)
.collect(),
requester_span: None,
}))
}
});
return Ok(None);
}
Err(_err) => {
warn!("Marking transaction as failed");
current_transaction_status.entry(destination).and_modify(|e| {
*e = match e {
TransactionStatus::Running => {
TransactionStatus::Failed(1, Instant::now())
}
TransactionStatus::Retrying(n) => {
TransactionStatus::Failed(*n + 1, Instant::now())
}
TransactionStatus::Failed(..) => {
error!(
"Request that was not even running failed?!"
);
return;
}
}
});
Ok(None)
}
}
self.db.delete_all_active_requests_for(&destination)?;
// Find events that have been added since starting the
// last request
let new_events = self
.db
.queued_requests(&destination)
.filter_map(Result::ok)
.take(30)
.collect::<Vec<_>>();
if new_events.is_empty() {
current_transaction_status.remove(&destination);
return Ok(None);
}
// Insert pdus we found
self.db.mark_as_active(&new_events)?;
Ok(Some(HandlerInputs {
destination: destination.clone(),
events: new_events.into_iter().map(|(event, _)| event).collect(),
requester_span: None,
}))
}
#[tracing::instrument(
@ -329,7 +334,7 @@ impl Service {
current_transaction_status,
) {
Ok(SelectedEvents::Retries(events)) => {
debug!("Retrying old events");
debug!("retrying old events");
Some(HandlerInputs {
destination,
events,
@ -337,7 +342,7 @@ impl Service {
})
}
Ok(SelectedEvents::New(events)) => {
debug!("Sending new event");
debug!("sending new event");
Some(HandlerInputs {
destination,
events,
@ -345,7 +350,7 @@ impl Service {
})
}
Ok(SelectedEvents::None) => {
debug!("Holding off from sending any events");
debug!("holding off from sending any events");
None
}
Err(error) => {
@ -446,7 +451,7 @@ impl Service {
pub(crate) fn select_edus(
&self,
server_name: &ServerName,
) -> Result<(Vec<Raw<Edu>>, u64)> {
) -> Result<(Vec<Vec<u8>>, u64)> {
// u64: count of last edu
let since = self.db.get_latest_educount(server_name)?;
let mut events = Vec::new();
@ -534,7 +539,7 @@ impl Service {
};
events.push(
Raw::new(&federation_event)
serde_json::to_vec(&federation_event)
.expect("json can be serialized"),
);
@ -557,7 +562,9 @@ impl Service {
keys: None,
});
events.push(Raw::new(&edu).expect("json can be serialized"));
events.push(
serde_json::to_vec(&edu).expect("json can be serialized"),
);
}
Ok((events, max_edu_count))
@ -566,7 +573,7 @@ impl Service {
#[tracing::instrument(skip(self, pdu_id, user, pushkey))]
pub(crate) fn send_push_pdu(
&self,
pdu_id: &PduId,
pdu_id: &[u8],
user: &UserId,
pushkey: String,
) -> Result<()> {
@ -590,7 +597,7 @@ impl Service {
pub(crate) fn send_pdu<I: Iterator<Item = OwnedServerName>>(
&self,
servers: I,
pdu_id: &PduId,
pdu_id: &[u8],
) -> Result<()> {
let requests = servers
.into_iter()
@ -622,7 +629,7 @@ impl Service {
pub(crate) fn send_reliable_edu(
&self,
server: &ServerName,
serialized: Raw<Edu>,
serialized: Vec<u8>,
id: u64,
) -> Result<()> {
let destination = Destination::Normal(server.to_owned());
@ -645,7 +652,7 @@ impl Service {
pub(crate) fn send_pdu_appservice(
&self,
appservice_id: String,
pdu_id: PduId,
pdu_id: Vec<u8>,
) -> Result<()> {
let destination = Destination::Appservice(appservice_id);
let event_type = SendingEventType::Pdu(pdu_id);
@ -677,11 +684,11 @@ impl Service {
debug!("Got permit");
let response = tokio::time::timeout(
Duration::from_secs(2 * 60),
server_server::send_request(destination, request, true),
server_server::send_request(destination, request),
)
.await
.map_err(|_| {
warn!("Timeout waiting for server response");
warn!("Timeout waiting for server response of {destination}");
Error::BadServerResponse("Timeout waiting for server response")
})?;
drop(permit);
@ -755,10 +762,15 @@ async fn handle_appservice_event(
appservice::event::push_events::v1::Request {
events: pdu_jsons,
txn_id: general_purpose::URL_SAFE_NO_PAD
.encode(calculate_hash(events.iter().map(|e| match e {
SendingEventType::Edu(b) => b.json().get().as_bytes(),
SendingEventType::Pdu(b) => b.as_bytes(),
})))
.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEventType::Edu(b)
| SendingEventType::Pdu(b) => &**b,
})
.collect::<Vec<_>>(),
))
.into(),
},
)
@ -871,7 +883,7 @@ async fn handle_federation_event(
.timeline
.get_pdu_json_from_id(pdu_id)?
.ok_or_else(|| {
error!(pdu_id = ?pdu_id, "PDU not found");
error!("event not found: {server} {pdu_id:?}");
Error::bad_database(
"[Normal] Event in servernamevent_datas not \
found in db.",
@ -880,7 +892,9 @@ async fn handle_federation_event(
));
}
SendingEventType::Edu(edu) => {
edu_jsons.push(edu.clone());
if let Ok(raw) = serde_json::from_slice(edu) {
edu_jsons.push(raw);
}
}
}
}
@ -895,19 +909,23 @@ async fn handle_federation_event(
edus: edu_jsons,
origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
transaction_id: general_purpose::URL_SAFE_NO_PAD
.encode(calculate_hash(events.iter().map(|e| match e {
SendingEventType::Edu(b) => b.json().get().as_bytes(),
SendingEventType::Pdu(b) => b.as_bytes(),
})))
.encode(calculate_hash(
&events
.iter()
.map(|e| match e {
SendingEventType::Edu(b)
| SendingEventType::Pdu(b) => &**b,
})
.collect::<Vec<_>>(),
))
.into(),
},
false,
)
.await?;
for pdu in response.pdus {
if let (event_id, Err(error)) = pdu {
warn!(%server, %event_id, %error, "Failed to send event to server");
if pdu.1.is_err() {
warn!("Failed to send to {}: {:?}", server, pdu);
}
}

View file

@ -118,7 +118,7 @@ impl Service {
AuthData::Dummy(_) => {
uiaainfo.completed.push(AuthType::Dummy);
}
kind => error!(?kind, "Auth kind not supported"),
k => error!("type not supported: {:?}", k),
}
// Check if a flow now succeeds

View file

@ -9,6 +9,7 @@ pub(crate) use data::Data;
use ruma::{
api::client::{
device::Device,
error::ErrorKind,
filter::FilterDefinition,
sync::sync_events::{
self,
@ -19,7 +20,7 @@ use ruma::{
events::AnyToDeviceEvent,
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedRoomId, OwnedUserId, UInt, UserId,
OwnedMxcUri, OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId,
};
use crate::{services, Error, Result};
@ -258,9 +259,20 @@ impl Service {
// Allowed because this function uses `services()`
#[allow(clippy::unused_self)]
pub(crate) fn is_admin(&self, user_id: &UserId) -> Result<bool> {
services().admin.get_admin_room()?.map_or(Ok(false), |admin_room_id| {
services().rooms.state_cache.is_joined(user_id, &admin_room_id)
})
let admin_room_alias_id = RoomAliasId::parse(format!(
"#admins:{}",
services().globals.server_name()
))
.map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.")
})?;
let admin_room_id = services()
.rooms
.alias
.resolve_local_alias(&admin_room_alias_id)?
.unwrap();
services().rooms.state_cache.is_joined(user_id, &admin_room_id)
}
/// Create a new user account on this homeserver.

View file

@ -1,5 +1,4 @@
pub(crate) mod error;
pub(crate) mod on_demand_hashmap;
use std::{
borrow::Cow,
@ -13,12 +12,9 @@ use cmp::Ordering;
use rand::{prelude::*, rngs::OsRng};
use ring::digest;
use ruma::{
api::client::error::ErrorKind, canonical_json::try_from_json_map,
CanonicalJsonError, CanonicalJsonObject, MxcUri, MxcUriError, OwnedMxcUri,
canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject,
};
use crate::{Error, Result};
// Hopefully we have a better chat protocol in 530 years
#[allow(clippy::as_conversions, clippy::cast_possible_truncation)]
pub(crate) fn millis_since_unix_epoch() -> u64 {
@ -106,30 +102,21 @@ where
}
#[tracing::instrument(skip(keys))]
pub(crate) fn calculate_hash<'a, I, T>(keys: I) -> Vec<u8>
where
I: IntoIterator<Item = T>,
T: AsRef<[u8]>,
{
let mut bytes = Vec::new();
for (i, key) in keys.into_iter().enumerate() {
if i != 0 {
bytes.push(0xFF);
}
bytes.extend_from_slice(key.as_ref());
}
pub(crate) fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
// We only hash the pdu's event ids, not the whole pdu
let bytes = keys.join(&0xFF);
let hash = digest::digest(&digest::SHA256, &bytes);
hash.as_ref().to_owned()
}
pub(crate) fn common_elements<I, T, F>(
pub(crate) fn common_elements<I, F>(
mut iterators: I,
check_order: F,
) -> Option<impl Iterator<Item = T>>
) -> Option<impl Iterator<Item = Vec<u8>>>
where
I: Iterator,
I::Item: Iterator<Item = T>,
F: Fn(&T, &T) -> Ordering,
I::Item: Iterator<Item = Vec<u8>>,
F: Fn(&[u8], &[u8]) -> Ordering,
{
let first_iterator = iterators.next()?;
let mut other_iterators =
@ -241,8 +228,11 @@ pub(crate) fn debug_slice_truncated<T: fmt::Debug>(
/// Truncates a string to an approximate maximum length, replacing any extra
/// text with an ellipsis.
///
/// Only to be used for informational purposes, exact semantics are unspecified.
pub(crate) fn dbg_truncate_str(s: &str, mut max_len: usize) -> Cow<'_, str> {
/// Only to be used for debug logging, exact semantics are unspecified.
pub(crate) fn truncate_str_for_debug(
s: &str,
mut max_len: usize,
) -> Cow<'_, str> {
while max_len < s.len() && !s.is_char_boundary(max_len) {
max_len += 1;
}
@ -255,131 +245,23 @@ pub(crate) fn dbg_truncate_str(s: &str, mut max_len: usize) -> Cow<'_, str> {
}
}
/// Data that makes up an `mxc://` URL.
#[derive(Debug, Clone)]
pub(crate) struct MxcData<'a> {
pub(crate) server_name: &'a ruma::ServerName,
pub(crate) media_id: &'a str,
}
impl<'a> MxcData<'a> {
pub(crate) fn new(
server_name: &'a ruma::ServerName,
media_id: &'a str,
) -> Result<Self> {
if !media_id.bytes().all(|b| {
matches!(b,
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'-' | b'_'
)
}) {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Invalid MXC media id",
));
}
Ok(Self {
server_name,
media_id,
})
}
}
impl fmt::Display for MxcData<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "mxc://{}/{}", self.server_name, self.media_id)
}
}
impl From<MxcData<'_>> for OwnedMxcUri {
fn from(value: MxcData<'_>) -> Self {
value.to_string().into()
}
}
impl<'a> TryFrom<&'a MxcUri> for MxcData<'a> {
type Error = MxcUriError;
fn try_from(value: &'a MxcUri) -> Result<Self, Self::Error> {
Ok(Self::new(value.server_name()?, value.media_id()?)
.expect("validated MxcUri should always be valid MxcData"))
}
}
fn curlify_args<T>(req: &http::Request<T>) -> Option<Vec<String>> {
let mut args =
vec!["curl".to_owned(), "-X".to_owned(), req.method().to_string()];
for (name, val) in req.headers() {
args.extend([
"-H".to_owned(),
format!("{name}: {}", val.to_str().ok()?),
]);
}
let fix_uri = || {
if req.uri().scheme().is_some() {
return None;
}
if req.uri().authority().is_some() {
return None;
}
let mut parts = req.uri().clone().into_parts();
parts.scheme = Some(http::uri::Scheme::HTTPS);
let host =
req.headers().get(http::header::HOST)?.to_str().ok()?.to_owned();
parts.authority =
Some(http::uri::Authority::from_maybe_shared(host).ok()?);
http::uri::Uri::from_parts(parts).ok()
};
let uri = if let Some(new_uri) = fix_uri() {
Cow::Owned(new_uri)
} else {
Cow::Borrowed(req.uri())
};
args.push(uri.to_string());
Some(args)
}
pub(crate) fn curlify<T>(req: &http::Request<T>) -> Option<String> {
let args = curlify_args(req)?;
Some(
args.into_iter()
.map(|arg| {
if arg.chars().all(|c| {
c.is_alphanumeric() || ['-', '_', ':', '/'].contains(&c)
}) {
arg
} else {
format!("'{}'", arg.replace('\'', "\\'"))
}
})
.collect::<Vec<_>>()
.join(" "),
)
}
#[cfg(test)]
mod tests {
use crate::utils::dbg_truncate_str;
use crate::utils::truncate_str_for_debug;
#[test]
fn test_truncate_str() {
assert_eq!(dbg_truncate_str("short", 10), "short");
assert_eq!(dbg_truncate_str("very long string", 10), "very long ...");
assert_eq!(dbg_truncate_str("no info, only dots", 0), "...");
assert_eq!(dbg_truncate_str("", 0), "");
assert_eq!(dbg_truncate_str("unicöde", 5), "unicö...");
fn test_truncate_str_for_debug() {
assert_eq!(truncate_str_for_debug("short", 10), "short");
assert_eq!(
truncate_str_for_debug("very long string", 10),
"very long ..."
);
assert_eq!(truncate_str_for_debug("no info, only dots", 0), "...");
assert_eq!(truncate_str_for_debug("", 0), "");
assert_eq!(truncate_str_for_debug("unicöde", 5), "unicö...");
let ok_hand = "👌🏽";
assert_eq!(dbg_truncate_str(ok_hand, 1), "👌...");
assert_eq!(dbg_truncate_str(ok_hand, ok_hand.len() - 1), "👌🏽");
assert_eq!(dbg_truncate_str(ok_hand, ok_hand.len()), "👌🏽");
assert_eq!(truncate_str_for_debug(ok_hand, 1), "👌...");
assert_eq!(truncate_str_for_debug(ok_hand, ok_hand.len() - 1), "👌🏽");
assert_eq!(truncate_str_for_debug(ok_hand, ok_hand.len()), "👌🏽");
}
}

View file

@ -9,7 +9,7 @@ use ruma::{
OwnedServerName,
};
use thiserror::Error;
use tracing::{error, warn};
use tracing::{error, info};
use crate::Ra;
@ -86,21 +86,21 @@ pub(crate) enum Error {
impl Error {
pub(crate) fn bad_database(message: &'static str) -> Self {
error!(message, "Bad database");
error!("BadDatabase: {}", message);
Self::BadDatabase(message)
}
pub(crate) fn bad_config(message: &'static str) -> Self {
error!(message, "Bad config");
error!("BadConfig: {}", message);
Self::BadConfig(message)
}
pub(crate) fn to_response(&self) -> Ra<UiaaResponse> {
use ErrorKind::{
Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken,
NotFound, NotYetUploaded, ThreepidAuthFailed, ThreepidDenied,
TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized,
UserDeactivated, WrongRoomKeysVersion,
NotFound, ThreepidAuthFailed, ThreepidDenied, TooLarge,
Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated,
WrongRoomKeysVersion,
};
if let Self::Uiaa(uiaainfo) = self {
@ -142,7 +142,6 @@ impl Error {
..
} => StatusCode::TOO_MANY_REQUESTS,
TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
NotYetUploaded => StatusCode::GATEWAY_TIMEOUT,
_ => StatusCode::BAD_REQUEST,
},
),
@ -150,7 +149,7 @@ impl Error {
_ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR),
};
warn!(%status_code, error = %message, "Responding with an error");
info!("Returning an error: {}: {}", status_code, message);
Ra(UiaaResponse::MatrixError(RumaError {
body: ErrorBody::Standard {

Some files were not shown because too many files have changed in this diff Show more