Compare commits
28 commits
master
...
refactor-c
Author | SHA1 | Date | |
---|---|---|---|
|
9d853d34a4 | ||
|
6f7a510046 | ||
|
dc28536690 | ||
|
ec2ef37619 | ||
|
419da573d9 | ||
|
342c9bff4f | ||
|
cc1e4be53d | ||
|
62cbb57ebc | ||
|
f33dbb855a | ||
|
769c1ecdd4 | ||
|
0f618de7fd | ||
|
9b08525a18 | ||
|
79e7115b62 | ||
|
d6b32b6ffe | ||
|
c8f921c532 | ||
|
3a09c8e180 | ||
|
f107f8d160 | ||
|
d1138204a6 | ||
|
5108ce52c2 | ||
|
ada1251a52 | ||
|
a10e7e7263 | ||
|
8f4f3314c6 | ||
|
173f8b1b4d | ||
|
b5305ba217 | ||
|
7c166aa468 | ||
|
03b2867a84 | ||
|
3fd7f6efc2 | ||
|
9c71a2cd5e |
149 changed files with 15981 additions and 12862 deletions
|
@ -26,7 +26,7 @@ variables:
|
|||
- if: "$CI_COMMIT_TAG"
|
||||
- if: '($CI_MERGE_REQUEST_APPROVED == "true") || $BUILD_EVERYTHING' # Once MR is approved, test all builds. Or if BUILD_EVERYTHING is set.
|
||||
interruptible: true
|
||||
image: "registry.gitlab.com/jfowl/conduit-containers/rust-with-tools@sha256:69ab327974aef4cc0daf4273579253bf7ae5e379a6c52729b83137e4caa9d093"
|
||||
image: "registry.gitlab.com/jfowl/conduit-containers/rust-with-tools:commit-aa74ef11"
|
||||
tags: ["docker"]
|
||||
services: ["docker:dind"]
|
||||
variables:
|
||||
|
@ -42,8 +42,6 @@ variables:
|
|||
- "cp -r $CARGO_HOME/bin $SHARED_PATH/cargo"
|
||||
- "cp -r $RUSTUP_HOME $SHARED_PATH"
|
||||
- "export CARGO_HOME=$SHARED_PATH/cargo RUSTUP_HOME=$SHARED_PATH/rustup"
|
||||
# If provided, bring in caching through sccache, which uses an external S3 endpoint to store compilation results.
|
||||
- if [ -n "${SCCACHE_ENDPOINT}" ]; then export RUSTC_WRAPPER=/sccache; fi
|
||||
script:
|
||||
# cross-compile conduit for target
|
||||
- 'time cross build --target="$TARGET" --locked --release'
|
||||
|
@ -251,8 +249,6 @@ test:cargo:
|
|||
extends: .test-shared-settings
|
||||
before_script:
|
||||
- rustup component add clippy
|
||||
# If provided, bring in caching through sccache, which uses an external S3 endpoint to store compilation results:
|
||||
- if [ -n "${SCCACHE_ENDPOINT}" ]; then export RUSTC_WRAPPER=/usr/local/cargo/bin/sccache; fi
|
||||
script:
|
||||
- rustc --version && cargo --version # Print version info for debugging
|
||||
- "cargo test --color always --workspace --verbose --locked --no-fail-fast -- -Z unstable-options --format json | gitlab-report -p test > $CI_PROJECT_DIR/report.xml"
|
||||
|
|
6
Cargo.lock
generated
6
Cargo.lock
generated
|
@ -98,9 +98,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.56"
|
||||
version = "0.1.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716"
|
||||
checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
@ -408,6 +408,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
|
|||
name = "conduit"
|
||||
version = "0.3.0-next"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum-server",
|
||||
"base64 0.13.0",
|
||||
|
@ -422,6 +423,7 @@ dependencies = [
|
|||
"http",
|
||||
"image",
|
||||
"jsonwebtoken",
|
||||
"lazy_static",
|
||||
"lru-cache",
|
||||
"num_cpus",
|
||||
"opentelemetry",
|
||||
|
|
|
@ -7,7 +7,7 @@ homepage = "https://conduit.rs"
|
|||
repository = "https://gitlab.com/famedly/conduit"
|
||||
readme = "README.md"
|
||||
version = "0.3.0-next"
|
||||
rust-version = "1.56"
|
||||
rust-version = "1.63"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
@ -90,6 +90,8 @@ figment = { version = "0.10.6", features = ["env", "toml"] }
|
|||
|
||||
tikv-jemalloc-ctl = { version = "0.4.2", features = ["use_std"], optional = true }
|
||||
tikv-jemallocator = { version = "0.4.1", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }
|
||||
lazy_static = "1.4.0"
|
||||
async-trait = "0.1.57"
|
||||
|
||||
[features]
|
||||
default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc"]
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use crate::{services, utils, Error, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
|
||||
use std::{fmt::Debug, mem, time::Duration};
|
||||
use tracing::warn;
|
||||
|
||||
#[tracing::instrument(skip(globals, request))]
|
||||
#[tracing::instrument(skip(request))]
|
||||
pub(crate) async fn send_request<T: OutgoingRequest>(
|
||||
globals: &crate::database::globals::Globals,
|
||||
registration: serde_yaml::Value,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
|
@ -46,7 +45,11 @@ where
|
|||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = globals.default_client().execute(reqwest_request).await?;
|
||||
let mut response = services()
|
||||
.globals
|
||||
.default_client()
|
||||
.execute(reqwest_request)
|
||||
.await?;
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
|
@ -1,11 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{
|
||||
database::{admin::make_user_admin, DatabaseGuard},
|
||||
pdu::PduBuilder,
|
||||
utils, Database, Error, Result, Ruma,
|
||||
};
|
||||
use crate::{api::client_server, services, utils, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
account::{
|
||||
|
@ -15,16 +9,10 @@ use ruma::{
|
|||
error::ErrorKind,
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
events::{
|
||||
room::{
|
||||
member::{MembershipState, RoomMemberEventContent},
|
||||
message::RoomMessageEventContent,
|
||||
},
|
||||
GlobalAccountDataEventType, RoomEventType,
|
||||
},
|
||||
events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType},
|
||||
push, UserId,
|
||||
};
|
||||
use serde_json::value::to_raw_value;
|
||||
|
||||
use tracing::{info, warn};
|
||||
|
||||
use register::RegistrationKind;
|
||||
|
@ -42,23 +30,24 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
|
|||
///
|
||||
/// Note: This will not reserve the username, so the username might become invalid when trying to register
|
||||
pub async fn get_register_available_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_username_availability::v3::IncomingRequest>,
|
||||
) -> Result<get_username_availability::v3::Response> {
|
||||
// Validate user id
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(body.username.to_lowercase(), db.globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical() && user_id.server_name() == db.globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
let user_id = UserId::parse_with_server_name(
|
||||
body.username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
|
||||
// Check if username is creative enough
|
||||
if db.users.exists(&user_id)? {
|
||||
if services().users.exists(&user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
|
@ -85,10 +74,9 @@ pub async fn get_register_available_route(
|
|||
/// - Creates a new account and populates it with default account data
|
||||
/// - If `inhibit_login` is false: Creates a device and returns device id and access_token
|
||||
pub async fn register_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<register::v3::IncomingRequest>,
|
||||
) -> Result<register::v3::Response> {
|
||||
if !db.globals.allow_registration() && !body.from_appservice {
|
||||
if !services().globals.allow_registration() && !body.from_appservice {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Registration has been disabled.",
|
||||
|
@ -99,18 +87,20 @@ pub async fn register_route(
|
|||
|
||||
let user_id = match (&body.username, is_guest) {
|
||||
(Some(username), false) => {
|
||||
let proposed_user_id =
|
||||
UserId::parse_with_server_name(username.to_lowercase(), db.globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical()
|
||||
&& user_id.server_name() == db.globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
if db.users.exists(&proposed_user_id)? {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical()
|
||||
&& user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
if services().users.exists(&proposed_user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
|
@ -121,10 +111,10 @@ pub async fn register_route(
|
|||
_ => loop {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
db.globals.server_name(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap();
|
||||
if !db.users.exists(&proposed_user_id)? {
|
||||
if !services().users.exists(&proposed_user_id)? {
|
||||
break proposed_user_id;
|
||||
}
|
||||
},
|
||||
|
@ -143,14 +133,12 @@ pub async fn register_route(
|
|||
|
||||
if !body.from_appservice {
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
&UserId::parse_with_server_name("", db.globals.server_name())
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
"".into(),
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
|
@ -158,8 +146,8 @@ pub async fn register_route(
|
|||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa.create(
|
||||
&UserId::parse_with_server_name("", db.globals.server_name())
|
||||
services().uiaa.create(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
"".into(),
|
||||
&uiaainfo,
|
||||
|
@ -178,24 +166,25 @@ pub async fn register_route(
|
|||
};
|
||||
|
||||
// Create user
|
||||
db.users.create(&user_id, password)?;
|
||||
services().users.create(&user_id, password)?;
|
||||
|
||||
// Default to pretty displayname
|
||||
let displayname = format!("{} ⚡️", user_id.localpart());
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&user_id, Some(displayname.clone()))?;
|
||||
|
||||
// Initial account data
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
&user_id,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&ruma::events::push_rules::PushRulesEvent {
|
||||
&serde_json::to_value(ruma::events::push_rules::PushRulesEvent {
|
||||
content: ruma::events::push_rules::PushRulesEventContent {
|
||||
global: push::Ruleset::server_default(&user_id),
|
||||
},
|
||||
},
|
||||
&db.globals,
|
||||
})
|
||||
.expect("to json always works"),
|
||||
)?;
|
||||
|
||||
// Inhibit login does not work for guests
|
||||
|
@ -219,7 +208,7 @@ pub async fn register_route(
|
|||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
// Create device for this account
|
||||
db.users.create_device(
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
|
@ -227,7 +216,8 @@ pub async fn register_route(
|
|||
)?;
|
||||
|
||||
info!("New user {} registered on this server.", user_id);
|
||||
db.admin
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"New user {} registered on this server.",
|
||||
user_id
|
||||
|
@ -235,14 +225,15 @@ pub async fn register_route(
|
|||
|
||||
// If this is the first real user, grant them admin privileges
|
||||
// Note: the server user, @conduit:servername, is generated first
|
||||
if db.users.count()? == 2 {
|
||||
make_user_admin(&db, &user_id, displayname).await?;
|
||||
if services().users.count()? == 2 {
|
||||
services()
|
||||
.admin
|
||||
.make_user_admin(&user_id, displayname)
|
||||
.await?;
|
||||
|
||||
warn!("Granting {} admin privileges as the first user", user_id);
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(register::v3::Response {
|
||||
access_token: Some(token),
|
||||
user_id,
|
||||
|
@ -265,7 +256,6 @@ pub async fn register_route(
|
|||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn change_password_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<change_password::v3::IncomingRequest>,
|
||||
) -> Result<change_password::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -282,46 +272,43 @@ pub async fn change_password_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
sender_user,
|
||||
sender_device,
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.set_password(sender_user, Some(&body.new_password))?;
|
||||
|
||||
if body.logout_devices {
|
||||
// Logout all devices except the current one
|
||||
for id in db
|
||||
for id in services()
|
||||
.users
|
||||
.all_device_ids(sender_user)
|
||||
.filter_map(|id| id.ok())
|
||||
.filter(|id| id != sender_device)
|
||||
{
|
||||
db.users.remove_device(sender_user, &id)?;
|
||||
services().users.remove_device(sender_user, &id)?;
|
||||
}
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
info!("User {} changed their password.", sender_user);
|
||||
db.admin
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {} changed their password.",
|
||||
sender_user
|
||||
|
@ -335,17 +322,14 @@ pub async fn change_password_route(
|
|||
/// Get user_id of the sender user.
|
||||
///
|
||||
/// Note: Also works for Application Services
|
||||
pub async fn whoami_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<whoami::v3::Request>,
|
||||
) -> Result<whoami::v3::Response> {
|
||||
pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let device_id = body.sender_device.as_ref().cloned();
|
||||
|
||||
Ok(whoami::v3::Response {
|
||||
user_id: sender_user.clone(),
|
||||
device_id,
|
||||
is_guest: db.users.is_deactivated(&sender_user)?,
|
||||
is_guest: services().users.is_deactivated(&sender_user)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -360,7 +344,6 @@ pub async fn whoami_route(
|
|||
/// - Triggers device list updates
|
||||
/// - Removes ability to log in again
|
||||
pub async fn deactivate_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<deactivate::v3::IncomingRequest>,
|
||||
) -> Result<deactivate::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -377,21 +360,18 @@ pub async fn deactivate_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
sender_user,
|
||||
sender_device,
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
|
@ -399,20 +379,19 @@ pub async fn deactivate_route(
|
|||
}
|
||||
|
||||
// Make the user leave all rooms before deactivation
|
||||
db.rooms.leave_all_rooms(&sender_user, &db).await?;
|
||||
client_server::leave_all_rooms(&sender_user).await?;
|
||||
|
||||
// Remove devices and mark account as deactivated
|
||||
db.users.deactivate_account(sender_user)?;
|
||||
services().users.deactivate_account(sender_user)?;
|
||||
|
||||
info!("User {} deactivated their account.", sender_user);
|
||||
db.admin
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {} deactivated their account.",
|
||||
sender_user
|
||||
)));
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(deactivate::v3::Response {
|
||||
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
|
||||
})
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use regex::Regex;
|
||||
use ruma::{
|
||||
api::{
|
||||
|
@ -16,24 +16,28 @@ use ruma::{
|
|||
///
|
||||
/// Creates a new room alias on this server.
|
||||
pub async fn create_alias_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_alias::v3::IncomingRequest>,
|
||||
) -> Result<create_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != db.globals.server_name() {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
}
|
||||
|
||||
if db.rooms.id_from_alias(&body.room_alias)?.is_some() {
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&body.room_alias)?
|
||||
.is_some()
|
||||
{
|
||||
return Err(Error::Conflict("Alias already exists."));
|
||||
}
|
||||
|
||||
db.rooms
|
||||
.set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?;
|
||||
|
||||
db.flush()?;
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.set_alias(&body.room_alias, &body.room_id)?;
|
||||
|
||||
Ok(create_alias::v3::Response::new())
|
||||
}
|
||||
|
@ -45,22 +49,19 @@ pub async fn create_alias_route(
|
|||
/// - TODO: additional access control checks
|
||||
/// - TODO: Update canonical alias event
|
||||
pub async fn delete_alias_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_alias::v3::IncomingRequest>,
|
||||
) -> Result<delete_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != db.globals.server_name() {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
}
|
||||
|
||||
db.rooms.set_alias(&body.room_alias, None, &db.globals)?;
|
||||
services().rooms.alias.remove_alias(&body.room_alias)?;
|
||||
|
||||
// TODO: update alt_aliases?
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(delete_alias::v3::Response::new())
|
||||
}
|
||||
|
||||
|
@ -70,21 +71,16 @@ pub async fn delete_alias_route(
|
|||
///
|
||||
/// - TODO: Suggest more servers to join via
|
||||
pub async fn get_alias_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_alias::v3::IncomingRequest>,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
get_alias_helper(&db, &body.room_alias).await
|
||||
get_alias_helper(&body.room_alias).await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_alias_helper(
|
||||
db: &Database,
|
||||
room_alias: &RoomAliasId,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
if room_alias.server_name() != db.globals.server_name() {
|
||||
let response = db
|
||||
pub(crate) async fn get_alias_helper(room_alias: &RoomAliasId) -> Result<get_alias::v3::Response> {
|
||||
if room_alias.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
room_alias.server_name(),
|
||||
federation::query::get_room_information::v1::Request { room_alias },
|
||||
)
|
||||
|
@ -97,10 +93,10 @@ pub(crate) async fn get_alias_helper(
|
|||
}
|
||||
|
||||
let mut room_id = None;
|
||||
match db.rooms.id_from_alias(room_alias)? {
|
||||
match services().rooms.alias.resolve_local_alias(room_alias)? {
|
||||
Some(r) => room_id = Some(r),
|
||||
None => {
|
||||
for (_id, registration) in db.appservice.all()? {
|
||||
for (_id, registration) in services().appservice.all()? {
|
||||
let aliases = registration
|
||||
.get("namespaces")
|
||||
.and_then(|ns| ns.get("aliases"))
|
||||
|
@ -115,19 +111,24 @@ pub(crate) async fn get_alias_helper(
|
|||
if aliases
|
||||
.iter()
|
||||
.any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||
&& db
|
||||
&& services()
|
||||
.sending
|
||||
.send_appservice_request(
|
||||
&db.globals,
|
||||
registration,
|
||||
appservice::query::query_room_alias::v1::Request { room_alias },
|
||||
)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
room_id = Some(db.rooms.id_from_alias(room_alias)?.ok_or_else(|| {
|
||||
Error::bad_config("Appservice lied to us. Room does not exist.")
|
||||
})?);
|
||||
room_id = Some(
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(room_alias)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_config("Appservice lied to us. Room does not exist.")
|
||||
})?,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
@ -146,6 +147,6 @@ pub(crate) async fn get_alias_helper(
|
|||
|
||||
Ok(get_alias::v3::Response::new(
|
||||
room_id,
|
||||
vec![db.globals.server_name().to_owned()],
|
||||
vec![services().globals.server_name().to_owned()],
|
||||
))
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
backup::{
|
||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
|
||||
|
@ -14,15 +14,12 @@ use ruma::api::client::{
|
|||
///
|
||||
/// Creates a new backup.
|
||||
pub async fn create_backup_version_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_backup_version::v3::Request>,
|
||||
) -> Result<create_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let version = db
|
||||
let version = services()
|
||||
.key_backups
|
||||
.create_backup(sender_user, &body.algorithm, &db.globals)?;
|
||||
|
||||
db.flush()?;
|
||||
.create_backup(sender_user, &body.algorithm)?;
|
||||
|
||||
Ok(create_backup_version::v3::Response { version })
|
||||
}
|
||||
|
@ -31,14 +28,12 @@ pub async fn create_backup_version_route(
|
|||
///
|
||||
/// Update information about an existing backup. Only `auth_data` can be modified.
|
||||
pub async fn update_backup_version_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<update_backup_version::v3::IncomingRequest>,
|
||||
) -> Result<update_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
db.key_backups
|
||||
.update_backup(sender_user, &body.version, &body.algorithm, &db.globals)?;
|
||||
|
||||
db.flush()?;
|
||||
services()
|
||||
.key_backups
|
||||
.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||
|
||||
Ok(update_backup_version::v3::Response {})
|
||||
}
|
||||
|
@ -47,23 +42,22 @@ pub async fn update_backup_version_route(
|
|||
///
|
||||
/// Get information about the latest backup version.
|
||||
pub async fn get_latest_backup_info_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_latest_backup_info::v3::Request>,
|
||||
) -> Result<get_latest_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let (version, algorithm) =
|
||||
db.key_backups
|
||||
.get_latest_backup(sender_user)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
let (version, algorithm) = services()
|
||||
.key_backups
|
||||
.get_latest_backup(sender_user)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
|
||||
Ok(get_latest_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (db.key_backups.count_keys(sender_user, &version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &version)?,
|
||||
version,
|
||||
})
|
||||
}
|
||||
|
@ -72,11 +66,10 @@ pub async fn get_latest_backup_info_route(
|
|||
///
|
||||
/// Get information about an existing backup.
|
||||
pub async fn get_backup_info_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_backup_info::v3::IncomingRequest>,
|
||||
) -> Result<get_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let algorithm = db
|
||||
let algorithm = services()
|
||||
.key_backups
|
||||
.get_backup(sender_user, &body.version)?
|
||||
.ok_or(Error::BadRequest(
|
||||
|
@ -86,8 +79,13 @@ pub async fn get_backup_info_route(
|
|||
|
||||
Ok(get_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
version: body.version.to_owned(),
|
||||
})
|
||||
}
|
||||
|
@ -98,14 +96,13 @@ pub async fn get_backup_info_route(
|
|||
///
|
||||
/// - Deletes both information about the backup, as well as all key data related to the backup
|
||||
pub async fn delete_backup_version_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_backup_version::v3::IncomingRequest>,
|
||||
) -> Result<delete_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.key_backups.delete_backup(sender_user, &body.version)?;
|
||||
|
||||
db.flush()?;
|
||||
services()
|
||||
.key_backups
|
||||
.delete_backup(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_version::v3::Response {})
|
||||
}
|
||||
|
@ -118,13 +115,12 @@ pub async fn delete_backup_version_route(
|
|||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<add_backup_keys::v3::IncomingRequest>,
|
||||
) -> Result<add_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= db
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
|
@ -137,22 +133,24 @@ pub async fn add_backup_keys_route(
|
|||
|
||||
for (room_id, room) in &body.rooms {
|
||||
for (session_id, key_data) in &room.sessions {
|
||||
db.key_backups.add_key(
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
&db.globals,
|
||||
)?
|
||||
}
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(add_backup_keys::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -164,13 +162,12 @@ pub async fn add_backup_keys_route(
|
|||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_room_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<add_backup_keys_for_room::v3::IncomingRequest>,
|
||||
) -> Result<add_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= db
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
|
@ -182,21 +179,23 @@ pub async fn add_backup_keys_for_room_route(
|
|||
}
|
||||
|
||||
for (session_id, key_data) in &body.sessions {
|
||||
db.key_backups.add_key(
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
&db.globals,
|
||||
)?
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(add_backup_keys_for_room::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -208,13 +207,12 @@ pub async fn add_backup_keys_for_room_route(
|
|||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_session_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<add_backup_keys_for_session::v3::IncomingRequest>,
|
||||
) -> Result<add_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= db
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
|
@ -225,20 +223,22 @@ pub async fn add_backup_keys_for_session_route(
|
|||
));
|
||||
}
|
||||
|
||||
db.key_backups.add_key(
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
&body.session_data,
|
||||
&db.globals,
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(add_backup_keys_for_session::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -246,12 +246,11 @@ pub async fn add_backup_keys_for_session_route(
|
|||
///
|
||||
/// Retrieves all keys from the backup.
|
||||
pub async fn get_backup_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_backup_keys::v3::IncomingRequest>,
|
||||
) -> Result<get_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let rooms = db.key_backups.get_all(sender_user, &body.version)?;
|
||||
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
||||
|
||||
Ok(get_backup_keys::v3::Response { rooms })
|
||||
}
|
||||
|
@ -260,12 +259,11 @@ pub async fn get_backup_keys_route(
|
|||
///
|
||||
/// Retrieves all keys from the backup for a given room.
|
||||
pub async fn get_backup_keys_for_room_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_backup_keys_for_room::v3::IncomingRequest>,
|
||||
) -> Result<get_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let sessions = db
|
||||
let sessions = services()
|
||||
.key_backups
|
||||
.get_room(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
|
@ -276,12 +274,11 @@ pub async fn get_backup_keys_for_room_route(
|
|||
///
|
||||
/// Retrieves a key from the backup.
|
||||
pub async fn get_backup_keys_for_session_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_backup_keys_for_session::v3::IncomingRequest>,
|
||||
) -> Result<get_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let key_data = db
|
||||
let key_data = services()
|
||||
.key_backups
|
||||
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
|
@ -296,18 +293,22 @@ pub async fn get_backup_keys_for_session_route(
|
|||
///
|
||||
/// Delete the keys from the backup.
|
||||
pub async fn delete_backup_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_backup_keys::v3::IncomingRequest>,
|
||||
) -> Result<delete_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.key_backups.delete_all_keys(sender_user, &body.version)?;
|
||||
|
||||
db.flush()?;
|
||||
services()
|
||||
.key_backups
|
||||
.delete_all_keys(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_keys::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -315,19 +316,22 @@ pub async fn delete_backup_keys_route(
|
|||
///
|
||||
/// Delete the keys from the backup for a given room.
|
||||
pub async fn delete_backup_keys_for_room_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_backup_keys_for_room::v3::IncomingRequest>,
|
||||
) -> Result<delete_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.key_backups
|
||||
services()
|
||||
.key_backups
|
||||
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(delete_backup_keys_for_room::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -335,18 +339,24 @@ pub async fn delete_backup_keys_for_room_route(
|
|||
///
|
||||
/// Delete a key from the backup.
|
||||
pub async fn delete_backup_keys_for_session_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_backup_keys_for_session::v3::IncomingRequest>,
|
||||
) -> Result<delete_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.key_backups
|
||||
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
|
||||
|
||||
db.flush()?;
|
||||
services().key_backups.delete_room_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
)?;
|
||||
|
||||
Ok(delete_backup_keys_for_session::v3::Response {
|
||||
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: db.key_backups.get_etag(sender_user, &body.version)?,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Result, Ruma};
|
||||
use crate::{services, Result, Ruma};
|
||||
use ruma::api::client::discovery::get_capabilities::{
|
||||
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
||||
};
|
||||
|
@ -8,26 +8,25 @@ use std::collections::BTreeMap;
|
|||
///
|
||||
/// Get information on the supported feature set and other relevent capabilities of this server.
|
||||
pub async fn get_capabilities_route(
|
||||
db: DatabaseGuard,
|
||||
_body: Ruma<get_capabilities::v3::IncomingRequest>,
|
||||
) -> Result<get_capabilities::v3::Response> {
|
||||
let mut available = BTreeMap::new();
|
||||
if db.globals.allow_unstable_room_versions() {
|
||||
for room_version in &db.globals.unstable_room_versions {
|
||||
if services().globals.allow_unstable_room_versions() {
|
||||
for room_version in &services().globals.unstable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Stable);
|
||||
}
|
||||
} else {
|
||||
for room_version in &db.globals.unstable_room_versions {
|
||||
for room_version in &services().globals.unstable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Unstable);
|
||||
}
|
||||
}
|
||||
for room_version in &db.globals.stable_room_versions {
|
||||
for room_version in &services().globals.stable_room_versions {
|
||||
available.insert(room_version.clone(), RoomVersionStability::Stable);
|
||||
}
|
||||
|
||||
let mut capabilities = Capabilities::new();
|
||||
capabilities.room_versions = RoomVersionsCapability {
|
||||
default: db.globals.default_room_version(),
|
||||
default: services().globals.default_room_version(),
|
||||
available,
|
||||
};
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
config::{
|
||||
|
@ -17,7 +17,6 @@ use serde_json::{json, value::RawValue as RawJsonValue};
|
|||
///
|
||||
/// Sets some account data for the sender user.
|
||||
pub async fn set_global_account_data_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_global_account_data::v3::IncomingRequest>,
|
||||
) -> Result<set_global_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -27,7 +26,7 @@ pub async fn set_global_account_data_route(
|
|||
|
||||
let event_type = body.event_type.to_string();
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
|
@ -35,11 +34,8 @@ pub async fn set_global_account_data_route(
|
|||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
&db.globals,
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_global_account_data::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -47,7 +43,6 @@ pub async fn set_global_account_data_route(
|
|||
///
|
||||
/// Sets some room account data for the sender user.
|
||||
pub async fn set_room_account_data_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_room_account_data::v3::IncomingRequest>,
|
||||
) -> Result<set_room_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -57,7 +52,7 @@ pub async fn set_room_account_data_route(
|
|||
|
||||
let event_type = body.event_type.to_string();
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
event_type.clone().into(),
|
||||
|
@ -65,11 +60,8 @@ pub async fn set_room_account_data_route(
|
|||
"type": event_type,
|
||||
"content": data,
|
||||
}),
|
||||
&db.globals,
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_room_account_data::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -77,12 +69,11 @@ pub async fn set_room_account_data_route(
|
|||
///
|
||||
/// Gets some account data for the sender user.
|
||||
pub async fn get_global_account_data_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_global_account_data::v3::IncomingRequest>,
|
||||
) -> Result<get_global_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: Box<RawJsonValue> = db
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(None, sender_user, body.event_type.clone().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
|
||||
|
@ -98,12 +89,11 @@ pub async fn get_global_account_data_route(
|
|||
///
|
||||
/// Gets some room account data for the sender user.
|
||||
pub async fn get_room_account_data_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_room_account_data::v3::IncomingRequest>,
|
||||
) -> Result<get_room_account_data::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: Box<RawJsonValue> = db
|
||||
let event: Box<RawJsonValue> = services()
|
||||
.account_data
|
||||
.get(
|
||||
Some(&body.room_id),
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
||||
events::StateEventType,
|
||||
|
@ -13,7 +13,6 @@ use tracing::error;
|
|||
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_context_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_context::v3::IncomingRequest>,
|
||||
) -> Result<get_context::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -28,18 +27,20 @@ pub async fn get_context_route(
|
|||
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
|
||||
let base_pdu_id = db
|
||||
let base_pdu_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_id(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Base event id not found.",
|
||||
))?;
|
||||
|
||||
let base_token = db.rooms.pdu_count(&base_pdu_id)?;
|
||||
let base_token = services().rooms.timeline.pdu_count(&base_pdu_id)?;
|
||||
|
||||
let base_event = db
|
||||
let base_event = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&base_pdu_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
|
@ -48,14 +49,18 @@ pub async fn get_context_route(
|
|||
|
||||
let room_id = base_event.room_id.clone();
|
||||
|
||||
if !db.rooms.is_joined(sender_user, &room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
));
|
||||
}
|
||||
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
|
@ -67,8 +72,9 @@ pub async fn get_context_route(
|
|||
|
||||
let base_event = base_event.to_room_event();
|
||||
|
||||
let events_before: Vec<_> = db
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &room_id, base_token)?
|
||||
.take(
|
||||
u32::try_from(body.limit).map_err(|_| {
|
||||
|
@ -80,7 +86,7 @@ pub async fn get_context_route(
|
|||
.collect();
|
||||
|
||||
for (_, event) in &events_before {
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
|
@ -93,7 +99,7 @@ pub async fn get_context_route(
|
|||
|
||||
let start_token = events_before
|
||||
.last()
|
||||
.and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok())
|
||||
.and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok())
|
||||
.map(|count| count.to_string());
|
||||
|
||||
let events_before: Vec<_> = events_before
|
||||
|
@ -101,8 +107,9 @@ pub async fn get_context_route(
|
|||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
|
||||
let events_after: Vec<_> = db
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &room_id, base_token)?
|
||||
.take(
|
||||
u32::try_from(body.limit).map_err(|_| {
|
||||
|
@ -114,7 +121,7 @@ pub async fn get_context_route(
|
|||
.collect();
|
||||
|
||||
for (_, event) in &events_after {
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&room_id,
|
||||
|
@ -125,23 +132,28 @@ pub async fn get_context_route(
|
|||
}
|
||||
}
|
||||
|
||||
let shortstatehash = match db.rooms.pdu_shortstatehash(
|
||||
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
|
||||
events_after
|
||||
.last()
|
||||
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
|
||||
)? {
|
||||
Some(s) => s,
|
||||
None => db
|
||||
None => services()
|
||||
.rooms
|
||||
.current_shortstatehash(&room_id)?
|
||||
.state
|
||||
.get_room_shortstatehash(&room_id)?
|
||||
.expect("All rooms have state"),
|
||||
};
|
||||
|
||||
let state_ids = db.rooms.state_full_ids(shortstatehash).await?;
|
||||
let state_ids = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(shortstatehash)
|
||||
.await?;
|
||||
|
||||
let end_token = events_after
|
||||
.last()
|
||||
.and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok())
|
||||
.and_then(|(pdu_id, _)| services().rooms.timeline.pdu_count(pdu_id).ok())
|
||||
.map(|count| count.to_string());
|
||||
|
||||
let events_after: Vec<_> = events_after
|
||||
|
@ -152,10 +164,13 @@ pub async fn get_context_route(
|
|||
let mut state = Vec::new();
|
||||
|
||||
for (shortstatekey, id) in state_ids {
|
||||
let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?;
|
||||
let (event_type, state_key) = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_statekey_from_short(shortstatekey)?;
|
||||
|
||||
if event_type != StateEventType::RoomMember {
|
||||
let pdu = match db.rooms.get_pdu(&id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
|
@ -164,7 +179,7 @@ pub async fn get_context_route(
|
|||
};
|
||||
state.push(pdu.to_state_event());
|
||||
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
||||
let pdu = match db.rooms.get_pdu(&id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
|
||||
error::ErrorKind,
|
||||
|
@ -11,12 +11,11 @@ use super::SESSION_ID_LENGTH;
|
|||
///
|
||||
/// Get metadata on all devices of the sender user.
|
||||
pub async fn get_devices_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_devices::v3::Request>,
|
||||
) -> Result<get_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let devices: Vec<device::Device> = db
|
||||
let devices: Vec<device::Device> = services()
|
||||
.users
|
||||
.all_devices_metadata(sender_user)
|
||||
.filter_map(|r| r.ok()) // Filter out buggy devices
|
||||
|
@ -29,12 +28,11 @@ pub async fn get_devices_route(
|
|||
///
|
||||
/// Get metadata on a single device of the sender user.
|
||||
pub async fn get_device_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_device::v3::IncomingRequest>,
|
||||
) -> Result<get_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let device = db
|
||||
let device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
|
@ -46,23 +44,21 @@ pub async fn get_device_route(
|
|||
///
|
||||
/// Updates the metadata on a given device of the sender user.
|
||||
pub async fn update_device_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<update_device::v3::IncomingRequest>,
|
||||
) -> Result<update_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut device = db
|
||||
let mut device = services()
|
||||
.users
|
||||
.get_device_metadata(sender_user, &body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
|
||||
device.display_name = body.display_name.clone();
|
||||
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(update_device::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -76,7 +72,6 @@ pub async fn update_device_route(
|
|||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_device_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_device::v3::IncomingRequest>,
|
||||
) -> Result<delete_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -94,30 +89,27 @@ pub async fn delete_device_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
sender_user,
|
||||
sender_device,
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
db.users.remove_device(sender_user, &body.device_id)?;
|
||||
|
||||
db.flush()?;
|
||||
services()
|
||||
.users
|
||||
.remove_device(sender_user, &body.device_id)?;
|
||||
|
||||
Ok(delete_device::v3::Response {})
|
||||
}
|
||||
|
@ -134,7 +126,6 @@ pub async fn delete_device_route(
|
|||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_devices_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_devices::v3::IncomingRequest>,
|
||||
) -> Result<delete_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -152,21 +143,18 @@ pub async fn delete_devices_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
sender_user,
|
||||
sender_device,
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
|
@ -174,10 +162,8 @@ pub async fn delete_devices_route(
|
|||
}
|
||||
|
||||
for device_id in &body.devices {
|
||||
db.users.remove_device(sender_user, device_id)?
|
||||
services().users.remove_device(sender_user, device_id)?
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(delete_devices::v3::Response {})
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
|
@ -29,7 +29,7 @@ use ruma::{
|
|||
},
|
||||
ServerName, UInt,
|
||||
};
|
||||
use tracing::{info, warn};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
/// # `POST /_matrix/client/r0/publicRooms`
|
||||
///
|
||||
|
@ -37,11 +37,9 @@ use tracing::{info, warn};
|
|||
///
|
||||
/// - Rooms are ordered by the number of joined members
|
||||
pub async fn get_public_rooms_filtered_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_public_rooms_filtered::v3::IncomingRequest>,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
get_public_rooms_filtered_helper(
|
||||
&db,
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
|
@ -57,11 +55,9 @@ pub async fn get_public_rooms_filtered_route(
|
|||
///
|
||||
/// - Rooms are ordered by the number of joined members
|
||||
pub async fn get_public_rooms_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_public_rooms::v3::IncomingRequest>,
|
||||
) -> Result<get_public_rooms::v3::Response> {
|
||||
let response = get_public_rooms_filtered_helper(
|
||||
&db,
|
||||
body.server.as_deref(),
|
||||
body.limit,
|
||||
body.since.as_deref(),
|
||||
|
@ -84,17 +80,16 @@ pub async fn get_public_rooms_route(
|
|||
///
|
||||
/// - TODO: Access control checks
|
||||
pub async fn set_room_visibility_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_room_visibility::v3::IncomingRequest>,
|
||||
) -> Result<set_room_visibility::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
match &body.visibility {
|
||||
room::Visibility::Public => {
|
||||
db.rooms.set_public(&body.room_id, true)?;
|
||||
services().rooms.directory.set_public(&body.room_id)?;
|
||||
info!("{} made {} public", sender_user, body.room_id);
|
||||
}
|
||||
room::Visibility::Private => db.rooms.set_public(&body.room_id, false)?,
|
||||
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
|
@ -103,8 +98,6 @@ pub async fn set_room_visibility_route(
|
|||
}
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_room_visibility::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -112,11 +105,10 @@ pub async fn set_room_visibility_route(
|
|||
///
|
||||
/// Gets the visibility of a given room in the room directory.
|
||||
pub async fn get_room_visibility_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_room_visibility::v3::IncomingRequest>,
|
||||
) -> Result<get_room_visibility::v3::Response> {
|
||||
Ok(get_room_visibility::v3::Response {
|
||||
visibility: if db.rooms.is_public_room(&body.room_id)? {
|
||||
visibility: if services().rooms.directory.is_public_room(&body.room_id)? {
|
||||
room::Visibility::Public
|
||||
} else {
|
||||
room::Visibility::Private
|
||||
|
@ -125,19 +117,18 @@ pub async fn get_room_visibility_route(
|
|||
}
|
||||
|
||||
pub(crate) async fn get_public_rooms_filtered_helper(
|
||||
db: &Database,
|
||||
server: Option<&ServerName>,
|
||||
limit: Option<UInt>,
|
||||
since: Option<&str>,
|
||||
filter: &IncomingFilter,
|
||||
_network: &IncomingRoomNetwork,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
if let Some(other_server) = server.filter(|server| *server != db.globals.server_name().as_str())
|
||||
if let Some(other_server) =
|
||||
server.filter(|server| *server != services().globals.server_name().as_str())
|
||||
{
|
||||
let response = db
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
other_server,
|
||||
federation::directory::get_public_rooms_filtered::v1::Request {
|
||||
limit,
|
||||
|
@ -184,15 +175,17 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
}
|
||||
}
|
||||
|
||||
let mut all_rooms: Vec<_> = db
|
||||
let mut all_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.directory
|
||||
.public_rooms()
|
||||
.map(|room_id| {
|
||||
let room_id = room_id?;
|
||||
|
||||
let chunk = PublicRoomsChunk {
|
||||
canonical_alias: db
|
||||
canonical_alias: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -201,8 +194,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
Error::bad_database("Invalid canonical alias event in database.")
|
||||
})
|
||||
})?,
|
||||
name: db
|
||||
name: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomName, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -211,8 +205,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
Error::bad_database("Invalid room name event in database.")
|
||||
})
|
||||
})?,
|
||||
num_joined_members: db
|
||||
num_joined_members: services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(&room_id)?
|
||||
.unwrap_or_else(|| {
|
||||
warn!("Room {} has no member count", room_id);
|
||||
|
@ -220,8 +215,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
})
|
||||
.try_into()
|
||||
.expect("user count should not be that big"),
|
||||
topic: db
|
||||
topic: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomTopic, "")?
|
||||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -230,8 +226,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
Error::bad_database("Invalid room topic event in database.")
|
||||
})
|
||||
})?,
|
||||
world_readable: db
|
||||
world_readable: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -244,8 +241,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
)
|
||||
})
|
||||
})?,
|
||||
guest_can_join: db
|
||||
guest_can_join: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -256,8 +254,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
Error::bad_database("Invalid room guest access event in database.")
|
||||
})
|
||||
})?,
|
||||
avatar_url: db
|
||||
avatar_url: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -269,8 +268,9 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.transpose()?
|
||||
// url is now an Option<String> so we must flatten
|
||||
.flatten(),
|
||||
join_rule: db
|
||||
join_rule: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
|
@ -279,15 +279,14 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
JoinRule::Knock => Some(PublicRoomJoinRule::Knock),
|
||||
_ => None,
|
||||
})
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room join rule event in database.")
|
||||
.map_err(|e| {
|
||||
error!("Invalid room join rule event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room join rule event in database.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.flatten()
|
||||
.ok_or(Error::bad_database(
|
||||
"Invalid room join rule event in database.",
|
||||
))?,
|
||||
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
|
||||
room_id,
|
||||
};
|
||||
Ok(chunk)
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
filter::{create_filter, get_filter},
|
||||
|
@ -10,11 +10,10 @@ use ruma::api::client::{
|
|||
///
|
||||
/// - A user can only access their own filters
|
||||
pub async fn get_filter_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_filter::v3::IncomingRequest>,
|
||||
) -> Result<get_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let filter = match db.users.get_filter(sender_user, &body.filter_id)? {
|
||||
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
|
||||
Some(filter) => filter,
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
|
||||
};
|
||||
|
@ -26,11 +25,10 @@ pub async fn get_filter_route(
|
|||
///
|
||||
/// Creates a new filter to be used by other endpoints.
|
||||
pub async fn create_filter_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_filter::v3::IncomingRequest>,
|
||||
) -> Result<create_filter::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
Ok(create_filter::v3::Response::new(
|
||||
db.users.create_filter(sender_user, &body.filter)?,
|
||||
services().users.create_filter(sender_user, &body.filter)?,
|
||||
))
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{database::DatabaseGuard, utils, Database, Error, Result, Ruma};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
api::{
|
||||
|
@ -26,39 +26,35 @@ use std::collections::{BTreeMap, HashMap, HashSet};
|
|||
/// - Adds one time keys
|
||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
|
||||
pub async fn upload_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<upload_keys::v3::Request>,
|
||||
) -> Result<upload_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
for (key_key, key_value) in &body.one_time_keys {
|
||||
db.users
|
||||
.add_one_time_key(sender_user, sender_device, key_key, key_value, &db.globals)?;
|
||||
services()
|
||||
.users
|
||||
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||
}
|
||||
|
||||
if let Some(device_keys) = &body.device_keys {
|
||||
// TODO: merge this and the existing event?
|
||||
// This check is needed to assure that signatures are kept
|
||||
if db
|
||||
if services()
|
||||
.users
|
||||
.get_device_keys(sender_user, sender_device)?
|
||||
.is_none()
|
||||
{
|
||||
db.users.add_device_keys(
|
||||
sender_user,
|
||||
sender_device,
|
||||
device_keys,
|
||||
&db.rooms,
|
||||
&db.globals,
|
||||
)?;
|
||||
services()
|
||||
.users
|
||||
.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||
}
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(upload_keys::v3::Response {
|
||||
one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?,
|
||||
one_time_key_counts: services()
|
||||
.users
|
||||
.count_one_time_keys(sender_user, sender_device)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -70,18 +66,12 @@ pub async fn upload_keys_route(
|
|||
/// - Gets master keys, self-signing keys, user signing keys and device keys.
|
||||
/// - The master and self-signing keys contain signatures that the user is allowed to see
|
||||
pub async fn get_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_keys::v3::IncomingRequest>,
|
||||
) -> Result<get_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let response = get_keys_helper(
|
||||
Some(sender_user),
|
||||
&body.device_keys,
|
||||
|u| u == sender_user,
|
||||
&db,
|
||||
)
|
||||
.await?;
|
||||
let response =
|
||||
get_keys_helper(Some(sender_user), &body.device_keys, |u| u == sender_user).await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
@ -90,12 +80,9 @@ pub async fn get_keys_route(
|
|||
///
|
||||
/// Claims one-time keys
|
||||
pub async fn claim_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<claim_keys::v3::Request>,
|
||||
) -> Result<claim_keys::v3::Response> {
|
||||
let response = claim_keys_helper(&body.one_time_keys, &db).await?;
|
||||
|
||||
db.flush()?;
|
||||
let response = claim_keys_helper(&body.one_time_keys).await?;
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
|
@ -106,7 +93,6 @@ pub async fn claim_keys_route(
|
|||
///
|
||||
/// - Requires UIAA to verify password
|
||||
pub async fn upload_signing_keys_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<upload_signing_keys::v3::IncomingRequest>,
|
||||
) -> Result<upload_signing_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -124,21 +110,18 @@ pub async fn upload_signing_keys_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = db.uiaa.try_auth(
|
||||
sender_user,
|
||||
sender_device,
|
||||
auth,
|
||||
&uiaainfo,
|
||||
&db.users,
|
||||
&db.globals,
|
||||
)?;
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
db.uiaa
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
|
@ -146,18 +129,14 @@ pub async fn upload_signing_keys_route(
|
|||
}
|
||||
|
||||
if let Some(master_key) = &body.master_key {
|
||||
db.users.add_cross_signing_keys(
|
||||
services().users.add_cross_signing_keys(
|
||||
sender_user,
|
||||
master_key,
|
||||
&body.self_signing_key,
|
||||
&body.user_signing_key,
|
||||
&db.rooms,
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(upload_signing_keys::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -165,7 +144,6 @@ pub async fn upload_signing_keys_route(
|
|||
///
|
||||
/// Uploads end-to-end key signatures from the sender user.
|
||||
pub async fn upload_signatures_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<upload_signatures::v3::Request>,
|
||||
) -> Result<upload_signatures::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -205,20 +183,13 @@ pub async fn upload_signatures_route(
|
|||
))?
|
||||
.to_owned(),
|
||||
);
|
||||
db.users.sign_key(
|
||||
user_id,
|
||||
key_id,
|
||||
signature,
|
||||
sender_user,
|
||||
&db.rooms,
|
||||
&db.globals,
|
||||
)?;
|
||||
services()
|
||||
.users
|
||||
.sign_key(user_id, key_id, signature, sender_user)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(upload_signatures::v3::Response {
|
||||
failures: BTreeMap::new(), // TODO: integrate
|
||||
})
|
||||
|
@ -230,7 +201,6 @@ pub async fn upload_signatures_route(
|
|||
///
|
||||
/// - TODO: left users
|
||||
pub async fn get_key_changes_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_key_changes::v3::IncomingRequest>,
|
||||
) -> Result<get_key_changes::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -238,7 +208,8 @@ pub async fn get_key_changes_route(
|
|||
let mut device_list_updates = HashSet::new();
|
||||
|
||||
device_list_updates.extend(
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
sender_user.as_str(),
|
||||
body.from
|
||||
|
@ -253,9 +224,15 @@ pub async fn get_key_changes_route(
|
|||
.filter_map(|r| r.ok()),
|
||||
);
|
||||
|
||||
for room_id in db.rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) {
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
device_list_updates.extend(
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
&room_id.to_string(),
|
||||
body.from.parse().map_err(|_| {
|
||||
|
@ -278,7 +255,6 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
sender_user: Option<&UserId>,
|
||||
device_keys_input: &BTreeMap<Box<UserId>, Vec<Box<DeviceId>>>,
|
||||
allowed_signatures: F,
|
||||
db: &Database,
|
||||
) -> Result<get_keys::v3::Response> {
|
||||
let mut master_keys = BTreeMap::new();
|
||||
let mut self_signing_keys = BTreeMap::new();
|
||||
|
@ -290,7 +266,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
for (user_id, device_ids) in device_keys_input {
|
||||
let user_id: &UserId = &**user_id;
|
||||
|
||||
if user_id.server_name() != db.globals.server_name() {
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
|
@ -300,10 +276,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
|
||||
if device_ids.is_empty() {
|
||||
let mut container = BTreeMap::new();
|
||||
for device_id in db.users.all_device_ids(user_id) {
|
||||
for device_id in services().users.all_device_ids(user_id) {
|
||||
let device_id = device_id?;
|
||||
if let Some(mut keys) = db.users.get_device_keys(user_id, &device_id)? {
|
||||
let metadata = db
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, &device_id)?
|
||||
.ok_or_else(|| {
|
||||
|
@ -319,13 +295,14 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
} else {
|
||||
for device_id in device_ids {
|
||||
let mut container = BTreeMap::new();
|
||||
if let Some(mut keys) = db.users.get_device_keys(user_id, device_id)? {
|
||||
let metadata = db.users.get_device_metadata(user_id, device_id)?.ok_or(
|
||||
Error::BadRequest(
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, device_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to get keys for nonexistent device.",
|
||||
),
|
||||
)?;
|
||||
))?;
|
||||
|
||||
add_unsigned_device_display_name(&mut keys, metadata)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
|
@ -335,17 +312,20 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(master_key) = db.users.get_master_key(user_id, &allowed_signatures)? {
|
||||
if let Some(master_key) = services()
|
||||
.users
|
||||
.get_master_key(user_id, &allowed_signatures)?
|
||||
{
|
||||
master_keys.insert(user_id.to_owned(), master_key);
|
||||
}
|
||||
if let Some(self_signing_key) = db
|
||||
if let Some(self_signing_key) = services()
|
||||
.users
|
||||
.get_self_signing_key(user_id, &allowed_signatures)?
|
||||
{
|
||||
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
||||
}
|
||||
if Some(user_id) == sender_user {
|
||||
if let Some(user_signing_key) = db.users.get_user_signing_key(user_id)? {
|
||||
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
|
||||
user_signing_keys.insert(user_id.to_owned(), user_signing_key);
|
||||
}
|
||||
}
|
||||
|
@ -362,9 +342,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
}
|
||||
(
|
||||
server,
|
||||
db.sending
|
||||
services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
server,
|
||||
federation::keys::get_keys::v1::Request {
|
||||
device_keys: device_keys_input_fed,
|
||||
|
@ -417,14 +397,13 @@ fn add_unsigned_device_display_name(
|
|||
|
||||
pub(crate) async fn claim_keys_helper(
|
||||
one_time_keys_input: &BTreeMap<Box<UserId>, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>,
|
||||
db: &Database,
|
||||
) -> Result<claim_keys::v3::Response> {
|
||||
let mut one_time_keys = BTreeMap::new();
|
||||
|
||||
let mut get_over_federation = BTreeMap::new();
|
||||
|
||||
for (user_id, map) in one_time_keys_input {
|
||||
if user_id.server_name() != db.globals.server_name() {
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
|
@ -434,8 +413,9 @@ pub(crate) async fn claim_keys_helper(
|
|||
let mut container = BTreeMap::new();
|
||||
for (device_id, key_algorithm) in map {
|
||||
if let Some(one_time_keys) =
|
||||
db.users
|
||||
.take_one_time_key(user_id, device_id, key_algorithm, &db.globals)?
|
||||
services()
|
||||
.users
|
||||
.take_one_time_key(user_id, device_id, key_algorithm)?
|
||||
{
|
||||
let mut c = BTreeMap::new();
|
||||
c.insert(one_time_keys.0, one_time_keys.1);
|
||||
|
@ -453,10 +433,9 @@ pub(crate) async fn claim_keys_helper(
|
|||
one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
|
||||
}
|
||||
// Ignore failures
|
||||
if let Ok(keys) = db
|
||||
if let Ok(keys) = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
server,
|
||||
federation::keys::claim_keys::v1::Request {
|
||||
one_time_keys: one_time_keys_input_fed,
|
|
@ -1,7 +1,4 @@
|
|||
use crate::{
|
||||
database::{media::FileMeta, DatabaseGuard},
|
||||
utils, Error, Result, Ruma,
|
||||
};
|
||||
use crate::{service::media::FileMeta, services, utils, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
media::{
|
||||
|
@ -16,11 +13,10 @@ const MXC_LENGTH: usize = 32;
|
|||
///
|
||||
/// Returns max upload size.
|
||||
pub async fn get_media_config_route(
|
||||
db: DatabaseGuard,
|
||||
_body: Ruma<get_media_config::v3::Request>,
|
||||
) -> Result<get_media_config::v3::Response> {
|
||||
Ok(get_media_config::v3::Response {
|
||||
upload_size: db.globals.max_request_size().into(),
|
||||
upload_size: services().globals.max_request_size().into(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -31,31 +27,27 @@ pub async fn get_media_config_route(
|
|||
/// - Some metadata will be saved in the database
|
||||
/// - Media will be saved in the media/ directory
|
||||
pub async fn create_content_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_content::v3::IncomingRequest>,
|
||||
) -> Result<create_content::v3::Response> {
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
db.globals.server_name(),
|
||||
services().globals.server_name(),
|
||||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
|
||||
db.media
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.clone(),
|
||||
&db.globals,
|
||||
&body
|
||||
.filename
|
||||
body.filename
|
||||
.as_ref()
|
||||
.map(|filename| "inline; filename=".to_owned() + filename)
|
||||
.as_deref(),
|
||||
&body.content_type.as_deref(),
|
||||
body.content_type.as_deref(),
|
||||
&body.file,
|
||||
)
|
||||
.await?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(create_content::v3::Response {
|
||||
content_uri: mxc.try_into().expect("Invalid mxc:// URI"),
|
||||
blurhash: None,
|
||||
|
@ -63,15 +55,13 @@ pub async fn create_content_route(
|
|||
}
|
||||
|
||||
pub async fn get_remote_content(
|
||||
db: &DatabaseGuard,
|
||||
mxc: &str,
|
||||
server_name: &ruma::ServerName,
|
||||
media_id: &str,
|
||||
) -> Result<get_content::v3::Response, Error> {
|
||||
let content_response = db
|
||||
let content_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
server_name,
|
||||
get_content::v3::Request {
|
||||
allow_remote: false,
|
||||
|
@ -81,12 +71,12 @@ pub async fn get_remote_content(
|
|||
)
|
||||
.await?;
|
||||
|
||||
db.media
|
||||
services()
|
||||
.media
|
||||
.create(
|
||||
mxc.to_string(),
|
||||
&db.globals,
|
||||
&content_response.content_disposition.as_deref(),
|
||||
&content_response.content_type.as_deref(),
|
||||
content_response.content_disposition.as_deref(),
|
||||
content_response.content_type.as_deref(),
|
||||
&content_response.file,
|
||||
)
|
||||
.await?;
|
||||
|
@ -100,7 +90,6 @@ pub async fn get_remote_content(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
pub async fn get_content_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_content::v3::IncomingRequest>,
|
||||
) -> Result<get_content::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
@ -109,16 +98,16 @@ pub async fn get_content_route(
|
|||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}) = db.media.get(&db.globals, &mxc).await?
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition,
|
||||
})
|
||||
} else if &*body.server_name != db.globals.server_name() && body.allow_remote {
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response =
|
||||
get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?;
|
||||
get_remote_content(&mxc, &body.server_name, &body.media_id).await?;
|
||||
Ok(remote_content_response)
|
||||
} else {
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
|
||||
|
@ -131,7 +120,6 @@ pub async fn get_content_route(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
pub async fn get_content_as_filename_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_content_as_filename::v3::IncomingRequest>,
|
||||
) -> Result<get_content_as_filename::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
@ -140,16 +128,16 @@ pub async fn get_content_as_filename_route(
|
|||
content_disposition: _,
|
||||
content_type,
|
||||
file,
|
||||
}) = db.media.get(&db.globals, &mxc).await?
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
file,
|
||||
content_type,
|
||||
content_disposition: Some(format!("inline; filename={}", body.filename)),
|
||||
})
|
||||
} else if &*body.server_name != db.globals.server_name() && body.allow_remote {
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let remote_content_response =
|
||||
get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?;
|
||||
get_remote_content(&mxc, &body.server_name, &body.media_id).await?;
|
||||
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
content_disposition: Some(format!("inline: filename={}", body.filename)),
|
||||
|
@ -167,18 +155,16 @@ pub async fn get_content_as_filename_route(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
pub async fn get_content_thumbnail_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_content_thumbnail::v3::IncomingRequest>,
|
||||
) -> Result<get_content_thumbnail::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_type, file, ..
|
||||
}) = db
|
||||
}) = services()
|
||||
.media
|
||||
.get_thumbnail(
|
||||
&mxc,
|
||||
&db.globals,
|
||||
mxc.clone(),
|
||||
body.width
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
|
@ -189,11 +175,10 @@ pub async fn get_content_thumbnail_route(
|
|||
.await?
|
||||
{
|
||||
Ok(get_content_thumbnail::v3::Response { file, content_type })
|
||||
} else if &*body.server_name != db.globals.server_name() && body.allow_remote {
|
||||
let get_thumbnail_response = db
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
let get_thumbnail_response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
&body.server_name,
|
||||
get_content_thumbnail::v3::Request {
|
||||
allow_remote: false,
|
||||
|
@ -206,12 +191,12 @@ pub async fn get_content_thumbnail_route(
|
|||
)
|
||||
.await?;
|
||||
|
||||
db.media
|
||||
services()
|
||||
.media
|
||||
.upload_thumbnail(
|
||||
mxc,
|
||||
&db.globals,
|
||||
&None,
|
||||
&get_thumbnail_response.content_type,
|
||||
None,
|
||||
get_thumbnail_response.content_type.as_deref(),
|
||||
body.width.try_into().expect("all UInts are valid u32s"),
|
||||
body.height.try_into().expect("all UInts are valid u32s"),
|
||||
&get_thumbnail_response.file,
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma};
|
||||
use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -19,14 +19,14 @@ use std::{
|
|||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
pub async fn send_message_event_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<send_message_event::v3::IncomingRequest>,
|
||||
) -> Result<send_message_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -37,7 +37,7 @@ pub async fn send_message_event_route(
|
|||
|
||||
// Forbid m.room.encrypted if encryption is disabled
|
||||
if RoomEventType::RoomEncrypted == body.event_type.to_string().into()
|
||||
&& !db.globals.allow_encryption()
|
||||
&& !services().globals.allow_encryption()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
|
@ -47,7 +47,8 @@ pub async fn send_message_event_route(
|
|||
|
||||
// Check if this is a new transaction id
|
||||
if let Some(response) =
|
||||
db.transaction_ids
|
||||
services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
{
|
||||
// The client might have sent a txnid of the /sendToDevice endpoint
|
||||
|
@ -69,7 +70,7 @@ pub async fn send_message_event_route(
|
|||
let mut unsigned = BTreeMap::new();
|
||||
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
|
||||
|
||||
let event_id = db.rooms.build_and_append_pdu(
|
||||
let event_id = services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: body.event_type.to_string().into(),
|
||||
content: serde_json::from_str(body.body.body.json().get())
|
||||
|
@ -80,11 +81,10 @@ pub async fn send_message_event_route(
|
|||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
db.transaction_ids.add_txnid(
|
||||
services().transaction_ids.add_txnid(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.txn_id,
|
||||
|
@ -93,8 +93,6 @@ pub async fn send_message_event_route(
|
|||
|
||||
drop(state_lock);
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(send_message_event::v3::Response::new(
|
||||
(*event_id).to_owned(),
|
||||
))
|
||||
|
@ -107,13 +105,16 @@ pub async fn send_message_event_route(
|
|||
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_message_events_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_message_events::v3::IncomingRequest>,
|
||||
) -> Result<get_message_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
|
@ -133,8 +134,12 @@ pub async fn get_message_events_route(
|
|||
|
||||
let to = body.to.as_ref().map(|t| t.parse());
|
||||
|
||||
db.rooms
|
||||
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
from,
|
||||
)?;
|
||||
|
||||
// Use limit or else 10
|
||||
let limit = body.limit.try_into().map_or(10_usize, |l: u32| l as usize);
|
||||
|
@ -147,13 +152,16 @@ pub async fn get_message_events_route(
|
|||
|
||||
match body.dir {
|
||||
get_message_events::v3::Direction::Forward => {
|
||||
let events_after: Vec<_> = db
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_after(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(|r| r.ok()) // Filter out buggy events
|
||||
.filter_map(|(pdu_id, pdu)| {
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdu_count(&pdu_id)
|
||||
.map(|pdu_count| (pdu_count, pdu))
|
||||
.ok()
|
||||
|
@ -162,7 +170,7 @@ pub async fn get_message_events_route(
|
|||
.collect();
|
||||
|
||||
for (_, event) in &events_after {
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
|
@ -184,13 +192,16 @@ pub async fn get_message_events_route(
|
|||
resp.chunk = events_after;
|
||||
}
|
||||
get_message_events::v3::Direction::Backward => {
|
||||
let events_before: Vec<_> = db
|
||||
let events_before: Vec<_> = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(sender_user, &body.room_id, from)?
|
||||
.take(limit)
|
||||
.filter_map(|r| r.ok()) // Filter out buggy events
|
||||
.filter_map(|(pdu_id, pdu)| {
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdu_count(&pdu_id)
|
||||
.map(|pdu_count| (pdu_count, pdu))
|
||||
.ok()
|
||||
|
@ -199,7 +210,7 @@ pub async fn get_message_events_route(
|
|||
.collect();
|
||||
|
||||
for (_, event) in &events_before {
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
|
@ -224,16 +235,17 @@ pub async fn get_message_events_route(
|
|||
|
||||
resp.state = Vec::new();
|
||||
for ll_id in &lazy_loaded {
|
||||
if let Some(member_event) =
|
||||
db.rooms
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
|
||||
{
|
||||
if let Some(member_event) = services().rooms.state_accessor.room_state_get(
|
||||
&body.room_id,
|
||||
&StateEventType::RoomMember,
|
||||
ll_id.as_str(),
|
||||
)? {
|
||||
resp.state.push(member_event.to_state_event());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(next_token) = next_token {
|
||||
db.rooms.lazy_load_mark_sent(
|
||||
services().rooms.lazy_loading.lazy_load_mark_sent(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, utils, Result, Ruma};
|
||||
use crate::{services, utils, Result, Ruma};
|
||||
use ruma::api::client::presence::{get_presence, set_presence};
|
||||
use std::time::Duration;
|
||||
|
||||
|
@ -6,22 +6,21 @@ use std::time::Duration;
|
|||
///
|
||||
/// Sets the presence state of the sender user.
|
||||
pub async fn set_presence_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_presence::v3::IncomingRequest>,
|
||||
) -> Result<set_presence::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
for room_id in db.rooms.rooms_joined(sender_user) {
|
||||
for room_id in services().rooms.state_cache.rooms_joined(sender_user) {
|
||||
let room_id = room_id?;
|
||||
|
||||
db.rooms.edus.update_presence(
|
||||
services().rooms.edus.presence.update_presence(
|
||||
sender_user,
|
||||
&room_id,
|
||||
ruma::events::presence::PresenceEvent {
|
||||
content: ruma::events::presence::PresenceEventContent {
|
||||
avatar_url: db.users.avatar_url(sender_user)?,
|
||||
avatar_url: services().users.avatar_url(sender_user)?,
|
||||
currently_active: None,
|
||||
displayname: db.users.displayname(sender_user)?,
|
||||
displayname: services().users.displayname(sender_user)?,
|
||||
last_active_ago: Some(
|
||||
utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
|
@ -32,12 +31,9 @@ pub async fn set_presence_route(
|
|||
},
|
||||
sender: sender_user.clone(),
|
||||
},
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_presence::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -47,22 +43,23 @@ pub async fn set_presence_route(
|
|||
///
|
||||
/// - Only works if you share a room with the user
|
||||
pub async fn get_presence_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_presence::v3::IncomingRequest>,
|
||||
) -> Result<get_presence::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut presence_event = None;
|
||||
|
||||
for room_id in db
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
|
||||
{
|
||||
let room_id = room_id?;
|
||||
|
||||
if let Some(presence) = db
|
||||
if let Some(presence) = services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.get_last_presence_event(sender_user, &room_id)?
|
||||
{
|
||||
presence_event = Some(presence);
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma};
|
||||
use crate::{service::pdu::PduBuilder, services, utils, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
|
@ -20,17 +20,18 @@ use std::sync::Arc;
|
|||
///
|
||||
/// - Also makes sure other users receive the update using presence EDUs
|
||||
pub async fn set_displayname_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_display_name::v3::IncomingRequest>,
|
||||
) -> Result<set_display_name::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.set_displayname(sender_user, body.displayname.clone())?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_rooms_joined: Vec<_> = db
|
||||
let all_rooms_joined: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(|r| r.ok())
|
||||
.map(|room_id| {
|
||||
|
@ -40,7 +41,9 @@ pub async fn set_displayname_route(
|
|||
content: to_raw_value(&RoomMemberEventContent {
|
||||
displayname: body.displayname.clone(),
|
||||
..serde_json::from_str(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
|
@ -70,7 +73,8 @@ pub async fn set_displayname_route(
|
|||
|
||||
for (pdu_builder, room_id) in all_rooms_joined {
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -79,19 +83,22 @@ pub async fn set_displayname_route(
|
|||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = db
|
||||
.rooms
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock);
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(
|
||||
pdu_builder,
|
||||
sender_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
);
|
||||
|
||||
// Presence update
|
||||
db.rooms.edus.update_presence(
|
||||
services().rooms.edus.presence.update_presence(
|
||||
sender_user,
|
||||
&room_id,
|
||||
ruma::events::presence::PresenceEvent {
|
||||
content: ruma::events::presence::PresenceEventContent {
|
||||
avatar_url: db.users.avatar_url(sender_user)?,
|
||||
avatar_url: services().users.avatar_url(sender_user)?,
|
||||
currently_active: None,
|
||||
displayname: db.users.displayname(sender_user)?,
|
||||
displayname: services().users.displayname(sender_user)?,
|
||||
last_active_ago: Some(
|
||||
utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
|
@ -102,12 +109,9 @@ pub async fn set_displayname_route(
|
|||
},
|
||||
sender: sender_user.clone(),
|
||||
},
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_display_name::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -117,14 +121,12 @@ pub async fn set_displayname_route(
|
|||
///
|
||||
/// - If user is on another server: Fetches displayname over federation
|
||||
pub async fn get_displayname_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_display_name::v3::IncomingRequest>,
|
||||
) -> Result<get_display_name::v3::Response> {
|
||||
if body.user_id.server_name() != db.globals.server_name() {
|
||||
let response = db
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: &body.user_id,
|
||||
|
@ -139,7 +141,7 @@ pub async fn get_displayname_route(
|
|||
}
|
||||
|
||||
Ok(get_display_name::v3::Response {
|
||||
displayname: db.users.displayname(&body.user_id)?,
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -149,19 +151,22 @@ pub async fn get_displayname_route(
|
|||
///
|
||||
/// - Also makes sure other users receive the update using presence EDUs
|
||||
pub async fn set_avatar_url_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_avatar_url::v3::IncomingRequest>,
|
||||
) -> Result<set_avatar_url::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(sender_user, body.avatar_url.clone())?;
|
||||
|
||||
db.users.set_blurhash(sender_user, body.blurhash.clone())?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(sender_user, body.blurhash.clone())?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_joined_rooms: Vec<_> = db
|
||||
let all_joined_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(|r| r.ok())
|
||||
.map(|room_id| {
|
||||
|
@ -171,7 +176,9 @@ pub async fn set_avatar_url_route(
|
|||
content: to_raw_value(&RoomMemberEventContent {
|
||||
avatar_url: body.avatar_url.clone(),
|
||||
..serde_json::from_str(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
|
@ -201,7 +208,8 @@ pub async fn set_avatar_url_route(
|
|||
|
||||
for (pdu_builder, room_id) in all_joined_rooms {
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -210,19 +218,22 @@ pub async fn set_avatar_url_route(
|
|||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = db
|
||||
.rooms
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock);
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(
|
||||
pdu_builder,
|
||||
sender_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
);
|
||||
|
||||
// Presence update
|
||||
db.rooms.edus.update_presence(
|
||||
services().rooms.edus.presence.update_presence(
|
||||
sender_user,
|
||||
&room_id,
|
||||
ruma::events::presence::PresenceEvent {
|
||||
content: ruma::events::presence::PresenceEventContent {
|
||||
avatar_url: db.users.avatar_url(sender_user)?,
|
||||
avatar_url: services().users.avatar_url(sender_user)?,
|
||||
currently_active: None,
|
||||
displayname: db.users.displayname(sender_user)?,
|
||||
displayname: services().users.displayname(sender_user)?,
|
||||
last_active_ago: Some(
|
||||
utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
|
@ -233,12 +244,9 @@ pub async fn set_avatar_url_route(
|
|||
},
|
||||
sender: sender_user.clone(),
|
||||
},
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_avatar_url::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -248,14 +256,12 @@ pub async fn set_avatar_url_route(
|
|||
///
|
||||
/// - If user is on another server: Fetches avatar_url and blurhash over federation
|
||||
pub async fn get_avatar_url_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_avatar_url::v3::IncomingRequest>,
|
||||
) -> Result<get_avatar_url::v3::Response> {
|
||||
if body.user_id.server_name() != db.globals.server_name() {
|
||||
let response = db
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: &body.user_id,
|
||||
|
@ -271,8 +277,8 @@ pub async fn get_avatar_url_route(
|
|||
}
|
||||
|
||||
Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: db.users.avatar_url(&body.user_id)?,
|
||||
blurhash: db.users.blurhash(&body.user_id)?,
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -282,14 +288,12 @@ pub async fn get_avatar_url_route(
|
|||
///
|
||||
/// - If user is on another server: Fetches profile over federation
|
||||
pub async fn get_profile_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_profile::v3::IncomingRequest>,
|
||||
) -> Result<get_profile::v3::Response> {
|
||||
if body.user_id.server_name() != db.globals.server_name() {
|
||||
let response = db
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
&db.globals,
|
||||
body.user_id.server_name(),
|
||||
federation::query::get_profile_information::v1::Request {
|
||||
user_id: &body.user_id,
|
||||
|
@ -305,7 +309,7 @@ pub async fn get_profile_route(
|
|||
});
|
||||
}
|
||||
|
||||
if !db.users.exists(&body.user_id)? {
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
|
@ -314,8 +318,8 @@ pub async fn get_profile_route(
|
|||
}
|
||||
|
||||
Ok(get_profile::v3::Response {
|
||||
avatar_url: db.users.avatar_url(&body.user_id)?,
|
||||
blurhash: db.users.blurhash(&body.user_id)?,
|
||||
displayname: db.users.displayname(&body.user_id)?,
|
||||
avatar_url: services().users.avatar_url(&body.user_id)?,
|
||||
blurhash: services().users.blurhash(&body.user_id)?,
|
||||
displayname: services().users.displayname(&body.user_id)?,
|
||||
})
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -16,12 +16,11 @@ use ruma::{
|
|||
///
|
||||
/// Retrieves the push rules event for this user.
|
||||
pub async fn get_pushrules_all_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_pushrules_all::v3::Request>,
|
||||
) -> Result<get_pushrules_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -33,8 +32,12 @@ pub async fn get_pushrules_all_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_pushrules_all::v3::Response {
|
||||
global: event.content.global,
|
||||
global: account_data.global,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -42,12 +45,11 @@ pub async fn get_pushrules_all_route(
|
|||
///
|
||||
/// Retrieves a single specified push rule for this user.
|
||||
pub async fn get_pushrule_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_pushrule::v3::IncomingRequest>,
|
||||
) -> Result<get_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -59,7 +61,11 @@ pub async fn get_pushrule_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = event.content.global;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
let global = account_data.global;
|
||||
let rule = match body.kind {
|
||||
RuleKind::Override => global
|
||||
.override_
|
||||
|
@ -98,7 +104,6 @@ pub async fn get_pushrule_route(
|
|||
///
|
||||
/// Creates a single specified push rule for this user.
|
||||
pub async fn set_pushrule_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_pushrule::v3::IncomingRequest>,
|
||||
) -> Result<set_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -111,7 +116,7 @@ pub async fn set_pushrule_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -123,7 +128,10 @@ pub async fn set_pushrule_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = &mut account_data.content.global;
|
||||
match body.kind {
|
||||
RuleKind::Override => {
|
||||
global.override_.replace(
|
||||
|
@ -186,16 +194,13 @@ pub async fn set_pushrule_route(
|
|||
_ => {}
|
||||
}
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_pushrule::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -203,7 +208,6 @@ pub async fn set_pushrule_route(
|
|||
///
|
||||
/// Gets the actions of a single specified push rule for this user.
|
||||
pub async fn get_pushrule_actions_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_pushrule_actions::v3::IncomingRequest>,
|
||||
) -> Result<get_pushrule_actions::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -215,7 +219,7 @@ pub async fn get_pushrule_actions_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -227,7 +231,11 @@ pub async fn get_pushrule_actions_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
let global = account_data.global;
|
||||
let actions = match body.kind {
|
||||
RuleKind::Override => global
|
||||
.override_
|
||||
|
@ -252,8 +260,6 @@ pub async fn get_pushrule_actions_route(
|
|||
_ => None,
|
||||
};
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(get_pushrule_actions::v3::Response {
|
||||
actions: actions.unwrap_or_default(),
|
||||
})
|
||||
|
@ -263,7 +269,6 @@ pub async fn get_pushrule_actions_route(
|
|||
///
|
||||
/// Sets the actions of a single specified push rule for this user.
|
||||
pub async fn set_pushrule_actions_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_pushrule_actions::v3::IncomingRequest>,
|
||||
) -> Result<set_pushrule_actions::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -275,7 +280,7 @@ pub async fn set_pushrule_actions_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -287,7 +292,10 @@ pub async fn set_pushrule_actions_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = &mut account_data.content.global;
|
||||
match body.kind {
|
||||
RuleKind::Override => {
|
||||
if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() {
|
||||
|
@ -322,16 +330,13 @@ pub async fn set_pushrule_actions_route(
|
|||
_ => {}
|
||||
};
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_pushrule_actions::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -339,7 +344,6 @@ pub async fn set_pushrule_actions_route(
|
|||
///
|
||||
/// Gets the enabled status of a single specified push rule for this user.
|
||||
pub async fn get_pushrule_enabled_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_pushrule_enabled::v3::IncomingRequest>,
|
||||
) -> Result<get_pushrule_enabled::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -351,7 +355,7 @@ pub async fn get_pushrule_enabled_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -363,7 +367,10 @@ pub async fn get_pushrule_enabled_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = account_data.content.global;
|
||||
let enabled = match body.kind {
|
||||
RuleKind::Override => global
|
||||
.override_
|
||||
|
@ -393,8 +400,6 @@ pub async fn get_pushrule_enabled_route(
|
|||
_ => false,
|
||||
};
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(get_pushrule_enabled::v3::Response { enabled })
|
||||
}
|
||||
|
||||
|
@ -402,7 +407,6 @@ pub async fn get_pushrule_enabled_route(
|
|||
///
|
||||
/// Sets the enabled status of a single specified push rule for this user.
|
||||
pub async fn set_pushrule_enabled_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_pushrule_enabled::v3::IncomingRequest>,
|
||||
) -> Result<set_pushrule_enabled::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -414,7 +418,7 @@ pub async fn set_pushrule_enabled_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -426,7 +430,10 @@ pub async fn set_pushrule_enabled_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = &mut account_data.content.global;
|
||||
match body.kind {
|
||||
RuleKind::Override => {
|
||||
if let Some(mut rule) = global.override_.get(body.rule_id.as_str()).cloned() {
|
||||
|
@ -466,16 +473,13 @@ pub async fn set_pushrule_enabled_route(
|
|||
_ => {}
|
||||
}
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_pushrule_enabled::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -483,7 +487,6 @@ pub async fn set_pushrule_enabled_route(
|
|||
///
|
||||
/// Deletes a single specified push rule for this user.
|
||||
pub async fn delete_pushrule_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_pushrule::v3::IncomingRequest>,
|
||||
) -> Result<delete_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -495,7 +498,7 @@ pub async fn delete_pushrule_route(
|
|||
));
|
||||
}
|
||||
|
||||
let mut event: PushRulesEvent = db
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
|
@ -507,7 +510,10 @@ pub async fn delete_pushrule_route(
|
|||
"PushRules event not found.",
|
||||
))?;
|
||||
|
||||
let global = &mut event.content.global;
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
let global = &mut account_data.content.global;
|
||||
match body.kind {
|
||||
RuleKind::Override => {
|
||||
if let Some(rule) = global.override_.get(body.rule_id.as_str()).cloned() {
|
||||
|
@ -537,16 +543,13 @@ pub async fn delete_pushrule_route(
|
|||
_ => {}
|
||||
}
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(account_data).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(delete_pushrule::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -554,13 +557,12 @@ pub async fn delete_pushrule_route(
|
|||
///
|
||||
/// Gets all currently active pushers for the sender user.
|
||||
pub async fn get_pushers_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_pushers::v3::Request>,
|
||||
) -> Result<get_pushers::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_pushers::v3::Response {
|
||||
pushers: db.pusher.get_pushers(sender_user)?,
|
||||
pushers: services().pusher.get_pushers(sender_user)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -570,15 +572,12 @@ pub async fn get_pushers_route(
|
|||
///
|
||||
/// - TODO: Handle `append`
|
||||
pub async fn set_pushers_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_pusher::v3::Request>,
|
||||
) -> Result<set_pusher::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let pusher = body.pusher.clone();
|
||||
|
||||
db.pusher.set_pusher(sender_user, pusher)?;
|
||||
|
||||
db.flush()?;
|
||||
services().pusher.set_pusher(sender_user, pusher)?;
|
||||
|
||||
Ok(set_pusher::v3::Response::default())
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
||||
events::RoomAccountDataEventType,
|
||||
|
@ -14,7 +14,6 @@ use std::collections::BTreeMap;
|
|||
/// - Updates fully-read account data event to `fully_read`
|
||||
/// - If `read_receipt` is set: Update private marker and public read receipt EDU
|
||||
pub async fn set_read_marker_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<set_read_marker::v3::IncomingRequest>,
|
||||
) -> Result<set_read_marker::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -24,25 +23,29 @@ pub async fn set_read_marker_route(
|
|||
event_id: body.fully_read.clone(),
|
||||
},
|
||||
};
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::FullyRead,
|
||||
&fully_read_event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
if let Some(event) = &body.read_receipt {
|
||||
db.rooms.edus.private_read_set(
|
||||
services().rooms.edus.read_receipt.private_read_set(
|
||||
&body.room_id,
|
||||
sender_user,
|
||||
db.rooms.get_pdu_count(event)?.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?,
|
||||
&db.globals,
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(event)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?,
|
||||
)?;
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
|
@ -59,19 +62,16 @@ pub async fn set_read_marker_route(
|
|||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(event.to_owned(), receipts);
|
||||
|
||||
db.rooms.edus.readreceipt_update(
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(set_read_marker::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -79,23 +79,25 @@ pub async fn set_read_marker_route(
|
|||
///
|
||||
/// Sets private read marker and public read receipt EDU.
|
||||
pub async fn create_receipt_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_receipt::v3::IncomingRequest>,
|
||||
) -> Result<create_receipt::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
db.rooms.edus.private_read_set(
|
||||
services().rooms.edus.read_receipt.private_read_set(
|
||||
&body.room_id,
|
||||
sender_user,
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?,
|
||||
&db.globals,
|
||||
)?;
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
|
@ -111,17 +113,14 @@ pub async fn create_receipt_route(
|
|||
let mut receipt_content = BTreeMap::new();
|
||||
receipt_content.insert(body.event_id.to_owned(), receipts);
|
||||
|
||||
db.rooms.edus.readreceipt_update(
|
||||
services().rooms.edus.read_receipt.readreceipt_update(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
ruma::events::receipt::ReceiptEvent {
|
||||
content: ruma::events::receipt::ReceiptEventContent(receipt_content),
|
||||
room_id: body.room_id.clone(),
|
||||
},
|
||||
&db.globals,
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(create_receipt::v3::Response {})
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{database::DatabaseGuard, pdu::PduBuilder, Result, Ruma};
|
||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::redact::redact_event,
|
||||
events::{room::redaction::RoomRedactionEventContent, RoomEventType},
|
||||
|
@ -14,14 +14,14 @@ use serde_json::value::to_raw_value;
|
|||
///
|
||||
/// - TODO: Handle txn id
|
||||
pub async fn redact_event_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<redact_event::v3::IncomingRequest>,
|
||||
) -> Result<redact_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -30,7 +30,7 @@ pub async fn redact_event_route(
|
|||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = db.rooms.build_and_append_pdu(
|
||||
let event_id = services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomRedaction,
|
||||
content: to_raw_value(&RoomRedactionEventContent {
|
||||
|
@ -43,14 +43,11 @@ pub async fn redact_event_route(
|
|||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
drop(state_lock);
|
||||
|
||||
db.flush()?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(redact_event::v3::Response { event_id })
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, utils::HtmlEscape, Error, Result, Ruma};
|
||||
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, room::report_content},
|
||||
events::room::message,
|
||||
|
@ -10,12 +10,11 @@ use ruma::{
|
|||
/// Reports an inappropriate event to homeserver admins
|
||||
///
|
||||
pub async fn report_event_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<report_content::v3::IncomingRequest>,
|
||||
) -> Result<report_content::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let pdu = match db.rooms.get_pdu(&body.event_id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? {
|
||||
Some(pdu) => pdu,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -39,7 +38,7 @@ pub async fn report_event_route(
|
|||
));
|
||||
};
|
||||
|
||||
db.admin
|
||||
services().admin
|
||||
.send_message(message::RoomMessageEventContent::text_html(
|
||||
format!(
|
||||
"Report received from: {}\n\n\
|
||||
|
@ -66,7 +65,5 @@ pub async fn report_event_route(
|
|||
),
|
||||
));
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(report_content::v3::Response {})
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{
|
||||
client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, Error, Result, Ruma,
|
||||
api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma,
|
||||
};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
|
@ -46,19 +46,19 @@ use tracing::{info, warn};
|
|||
/// - Send events implied by `name` and `topic`
|
||||
/// - Send invite events
|
||||
pub async fn create_room_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_room::v3::IncomingRequest>,
|
||||
) -> Result<create_room::v3::Response> {
|
||||
use create_room::v3::RoomPreset;
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let room_id = RoomId::new(db.globals.server_name());
|
||||
let room_id = RoomId::new(services().globals.server_name());
|
||||
|
||||
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?;
|
||||
services().rooms.short.get_or_create_shortroomid(&room_id)?;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -67,9 +67,9 @@ pub async fn create_room_route(
|
|||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
if !db.globals.allow_room_creation()
|
||||
if !services().globals.allow_room_creation()
|
||||
&& !body.from_appservice
|
||||
&& !db.users.is_admin(sender_user, &db.rooms, &db.globals)?
|
||||
&& !services().users.is_admin(sender_user)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
|
@ -82,13 +82,19 @@ pub async fn create_room_route(
|
|||
.as_ref()
|
||||
.map_or(Ok(None), |localpart| {
|
||||
// TODO: Check for invalid characters and maximum length
|
||||
let alias =
|
||||
RoomAliasId::parse(format!("#{}:{}", localpart, db.globals.server_name()))
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.")
|
||||
})?;
|
||||
let alias = RoomAliasId::parse(format!(
|
||||
"#{}:{}",
|
||||
localpart,
|
||||
services().globals.server_name()
|
||||
))
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
|
||||
|
||||
if db.rooms.id_from_alias(&alias)?.is_some() {
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&alias)?
|
||||
.is_some()
|
||||
{
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::RoomInUse,
|
||||
"Room alias already exists.",
|
||||
|
@ -100,7 +106,11 @@ pub async fn create_room_route(
|
|||
|
||||
let room_version = match body.room_version.clone() {
|
||||
Some(room_version) => {
|
||||
if db.rooms.is_supported_version(&db, &room_version) {
|
||||
if services()
|
||||
.globals
|
||||
.supported_room_versions()
|
||||
.contains(&room_version)
|
||||
{
|
||||
room_version
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -109,7 +119,7 @@ pub async fn create_room_route(
|
|||
));
|
||||
}
|
||||
}
|
||||
None => db.globals.default_room_version(),
|
||||
None => services().globals.default_room_version(),
|
||||
};
|
||||
|
||||
let content = match &body.creation_content {
|
||||
|
@ -163,7 +173,7 @@ pub async fn create_room_route(
|
|||
}
|
||||
|
||||
// 1. The room create event
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomCreate,
|
||||
content: to_raw_value(&content).expect("event is valid, we just created it"),
|
||||
|
@ -173,21 +183,20 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// 2. Let the room creator join
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
membership: MembershipState::Join,
|
||||
displayname: db.users.displayname(sender_user)?,
|
||||
avatar_url: db.users.avatar_url(sender_user)?,
|
||||
displayname: services().users.displayname(sender_user)?,
|
||||
avatar_url: services().users.avatar_url(sender_user)?,
|
||||
is_direct: Some(body.is_direct),
|
||||
third_party_invite: None,
|
||||
blurhash: db.users.blurhash(sender_user)?,
|
||||
blurhash: services().users.blurhash(sender_user)?,
|
||||
reason: None,
|
||||
join_authorized_via_users_server: None,
|
||||
})
|
||||
|
@ -198,7 +207,6 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
|
@ -240,7 +248,7 @@ pub async fn create_room_route(
|
|||
}
|
||||
}
|
||||
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomPowerLevels,
|
||||
content: to_raw_value(&power_levels_content)
|
||||
|
@ -251,13 +259,12 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// 4. Canonical room alias
|
||||
if let Some(room_alias_id) = &alias {
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomCanonicalAlias,
|
||||
content: to_raw_value(&RoomCanonicalAliasEventContent {
|
||||
|
@ -271,7 +278,6 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
}
|
||||
|
@ -279,7 +285,7 @@ pub async fn create_room_route(
|
|||
// 5. Events set by preset
|
||||
|
||||
// 5.1 Join Rules
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomJoinRules,
|
||||
content: to_raw_value(&RoomJoinRulesEventContent::new(match preset {
|
||||
|
@ -294,12 +300,11 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// 5.2 History Visibility
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomHistoryVisibility,
|
||||
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
|
||||
|
@ -312,12 +317,11 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// 5.3 Guest Access
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomGuestAccess,
|
||||
content: to_raw_value(&RoomGuestAccessEventContent::new(match preset {
|
||||
|
@ -331,7 +335,6 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
|
@ -346,18 +349,23 @@ pub async fn create_room_route(
|
|||
pdu_builder.state_key.get_or_insert_with(|| "".to_owned());
|
||||
|
||||
// Silently skip encryption events if they are not allowed
|
||||
if pdu_builder.event_type == RoomEventType::RoomEncryption && !db.globals.allow_encryption()
|
||||
if pdu_builder.event_type == RoomEventType::RoomEncryption
|
||||
&& !services().globals.allow_encryption()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
db.rooms
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock)?;
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
pdu_builder,
|
||||
sender_user,
|
||||
&room_id,
|
||||
&state_lock,
|
||||
)?;
|
||||
}
|
||||
|
||||
// 7. Events implied by name and topic
|
||||
if let Some(name) = &body.name {
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomName,
|
||||
content: to_raw_value(&RoomNameEventContent::new(Some(name.clone())))
|
||||
|
@ -368,13 +376,12 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(topic) = &body.topic {
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomTopic,
|
||||
content: to_raw_value(&RoomTopicEventContent {
|
||||
|
@ -387,7 +394,6 @@ pub async fn create_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
}
|
||||
|
@ -395,22 +401,20 @@ pub async fn create_room_route(
|
|||
// 8. Events implied by invite (and TODO: invite_3pid)
|
||||
drop(state_lock);
|
||||
for user_id in &body.invite {
|
||||
let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await;
|
||||
let _ = invite_helper(sender_user, user_id, &room_id, body.is_direct).await;
|
||||
}
|
||||
|
||||
// Homeserver specific stuff
|
||||
if let Some(alias) = alias {
|
||||
db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?;
|
||||
services().rooms.alias.set_alias(&alias, &room_id)?;
|
||||
}
|
||||
|
||||
if body.visibility == room::Visibility::Public {
|
||||
db.rooms.set_public(&room_id, true)?;
|
||||
services().rooms.directory.set_public(&room_id)?;
|
||||
}
|
||||
|
||||
info!("{} created a room", sender_user);
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(create_room::v3::Response::new(room_id))
|
||||
}
|
||||
|
||||
|
@ -420,12 +424,15 @@ pub async fn create_room_route(
|
|||
///
|
||||
/// - You have to currently be joined to the room (TODO: Respect history visibility)
|
||||
pub async fn get_room_event_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_room_event::v3::IncomingRequest>,
|
||||
) -> Result<get_room_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
|
@ -433,8 +440,9 @@ pub async fn get_room_event_route(
|
|||
}
|
||||
|
||||
Ok(get_room_event::v3::Response {
|
||||
event: db
|
||||
event: services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?
|
||||
.to_room_event(),
|
||||
|
@ -447,12 +455,15 @@ pub async fn get_room_event_route(
|
|||
///
|
||||
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable
|
||||
pub async fn get_room_aliases_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<aliases::v3::IncomingRequest>,
|
||||
) -> Result<aliases::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
|
@ -460,9 +471,10 @@ pub async fn get_room_aliases_route(
|
|||
}
|
||||
|
||||
Ok(aliases::v3::Response {
|
||||
aliases: db
|
||||
aliases: services()
|
||||
.rooms
|
||||
.room_aliases(&body.room_id)
|
||||
.alias
|
||||
.local_aliases_for_room(&body.room_id)
|
||||
.filter_map(|a| a.ok())
|
||||
.collect(),
|
||||
})
|
||||
|
@ -479,12 +491,15 @@ pub async fn get_room_aliases_route(
|
|||
/// - Moves local aliases
|
||||
/// - Modifies old room power levels to prevent users from speaking
|
||||
pub async fn upgrade_room_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<upgrade_room::v3::IncomingRequest>,
|
||||
) -> Result<upgrade_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !db.rooms.is_supported_version(&db, &body.new_version) {
|
||||
if !services()
|
||||
.globals
|
||||
.supported_room_versions()
|
||||
.contains(&body.new_version)
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnsupportedRoomVersion,
|
||||
"This server does not support that room version.",
|
||||
|
@ -492,12 +507,15 @@ pub async fn upgrade_room_route(
|
|||
}
|
||||
|
||||
// Create a replacement room
|
||||
let replacement_room = RoomId::new(db.globals.server_name());
|
||||
db.rooms
|
||||
.get_or_create_shortroomid(&replacement_room, &db.globals)?;
|
||||
let replacement_room = RoomId::new(services().globals.server_name());
|
||||
services()
|
||||
.rooms
|
||||
.short
|
||||
.get_or_create_shortroomid(&replacement_room)?;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -508,7 +526,7 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further
|
||||
// Fail if the sender does not have the required permissions
|
||||
let tombstone_event_id = db.rooms.build_and_append_pdu(
|
||||
let tombstone_event_id = services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomTombstone,
|
||||
content: to_raw_value(&RoomTombstoneEventContent {
|
||||
|
@ -522,14 +540,14 @@ pub async fn upgrade_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// Change lock to replacement room
|
||||
drop(state_lock);
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -540,7 +558,9 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Get the old room creation event
|
||||
let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
|
||||
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
|
||||
.content
|
||||
|
@ -588,7 +608,7 @@ pub async fn upgrade_room_route(
|
|||
));
|
||||
}
|
||||
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomCreate,
|
||||
content: to_raw_value(&create_event_content)
|
||||
|
@ -599,21 +619,20 @@ pub async fn upgrade_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&replacement_room,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
// Join the new room
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomMember,
|
||||
content: to_raw_value(&RoomMemberEventContent {
|
||||
membership: MembershipState::Join,
|
||||
displayname: db.users.displayname(sender_user)?,
|
||||
avatar_url: db.users.avatar_url(sender_user)?,
|
||||
displayname: services().users.displayname(sender_user)?,
|
||||
avatar_url: services().users.avatar_url(sender_user)?,
|
||||
is_direct: None,
|
||||
third_party_invite: None,
|
||||
blurhash: db.users.blurhash(sender_user)?,
|
||||
blurhash: services().users.blurhash(sender_user)?,
|
||||
reason: None,
|
||||
join_authorized_via_users_server: None,
|
||||
})
|
||||
|
@ -624,7 +643,6 @@ pub async fn upgrade_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&replacement_room,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
|
@ -643,12 +661,17 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Replicate transferable state events to the new room
|
||||
for event_type in transferable_state_events {
|
||||
let event_content = match db.rooms.room_state_get(&body.room_id, &event_type, "")? {
|
||||
Some(v) => v.content.clone(),
|
||||
None => continue, // Skipping missing events.
|
||||
};
|
||||
let event_content =
|
||||
match services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &event_type, "")?
|
||||
{
|
||||
Some(v) => v.content.clone(),
|
||||
None => continue, // Skipping missing events.
|
||||
};
|
||||
|
||||
db.rooms.build_and_append_pdu(
|
||||
services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: event_type.to_string().into(),
|
||||
content: event_content,
|
||||
|
@ -658,20 +681,28 @@ pub async fn upgrade_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&replacement_room,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
}
|
||||
|
||||
// Moves any local aliases to the new room
|
||||
for alias in db.rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) {
|
||||
db.rooms
|
||||
.set_alias(&alias, Some(&replacement_room), &db.globals)?;
|
||||
for alias in services()
|
||||
.rooms
|
||||
.alias
|
||||
.local_aliases_for_room(&body.room_id)
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.set_alias(&alias, &replacement_room)?;
|
||||
}
|
||||
|
||||
// Get the old room power levels
|
||||
let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
|
||||
.content
|
||||
|
@ -685,7 +716,7 @@ pub async fn upgrade_room_route(
|
|||
power_levels_event_content.invite = new_level;
|
||||
|
||||
// Modify the power levels in the old room to prevent sending of events and inviting new users
|
||||
let _ = db.rooms.build_and_append_pdu(
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: RoomEventType::RoomPowerLevels,
|
||||
content: to_raw_value(&power_levels_event_content)
|
||||
|
@ -696,14 +727,11 @@ pub async fn upgrade_room_route(
|
|||
},
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
||||
drop(state_lock);
|
||||
|
||||
db.flush()?;
|
||||
|
||||
// Return the replacement room id
|
||||
Ok(upgrade_room::v3::Response { replacement_room })
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
search::search_events::{
|
||||
|
@ -15,7 +15,6 @@ use std::collections::BTreeMap;
|
|||
///
|
||||
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
|
||||
pub async fn search_events_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<search_events::v3::IncomingRequest>,
|
||||
) -> Result<search_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -24,7 +23,9 @@ pub async fn search_events_route(
|
|||
let filter = &search_criteria.filter;
|
||||
|
||||
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(|r| r.ok())
|
||||
.collect()
|
||||
|
@ -35,15 +36,20 @@ pub async fn search_events_route(
|
|||
let mut searches = Vec::new();
|
||||
|
||||
for room_id in room_ids {
|
||||
if !db.rooms.is_joined(sender_user, &room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(search) = db
|
||||
if let Some(search) = services()
|
||||
.rooms
|
||||
.search
|
||||
.search_pdus(&room_id, &search_criteria.search_term)?
|
||||
{
|
||||
searches.push(search.0.peekable());
|
||||
|
@ -85,8 +91,9 @@ pub async fn search_events_route(
|
|||
start: None,
|
||||
},
|
||||
rank: None,
|
||||
result: db
|
||||
result: services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(result)?
|
||||
.map(|pdu| pdu.to_room_event()),
|
||||
})
|
|
@ -1,5 +1,5 @@
|
|||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -40,10 +40,7 @@ pub async fn get_login_types_route(
|
|||
///
|
||||
/// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
|
||||
/// supported login types.
|
||||
pub async fn login_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<login::v3::IncomingRequest>,
|
||||
) -> Result<login::v3::Response> {
|
||||
pub async fn login_route(body: Ruma<login::v3::IncomingRequest>) -> Result<login::v3::Response> {
|
||||
// Validate login method
|
||||
// TODO: Other login methods
|
||||
let user_id = match &body.login_info {
|
||||
|
@ -56,15 +53,18 @@ pub async fn login_route(
|
|||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username.to_owned(), db.globals.server_name())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?;
|
||||
let hash = db.users.password_hash(&user_id)?.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
))?;
|
||||
let user_id = UserId::parse_with_server_name(
|
||||
username.to_owned(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||
let hash = services()
|
||||
.users
|
||||
.password_hash(&user_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
))?;
|
||||
|
||||
if hash.is_empty() {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -85,7 +85,7 @@ pub async fn login_route(
|
|||
user_id
|
||||
}
|
||||
login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => {
|
||||
if let Some(jwt_decoding_key) = db.globals.jwt_decoding_key() {
|
||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||
let token = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
jwt_decoding_key,
|
||||
|
@ -93,7 +93,7 @@ pub async fn login_route(
|
|||
)
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?;
|
||||
let username = token.claims.sub;
|
||||
UserId::parse_with_server_name(username, db.globals.server_name()).map_err(
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
||||
|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."),
|
||||
)?
|
||||
} else {
|
||||
|
@ -122,15 +122,16 @@ pub async fn login_route(
|
|||
|
||||
// Determine if device_id was provided and exists in the db for this user
|
||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.all_device_ids(&user_id)
|
||||
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||
});
|
||||
|
||||
if device_exists {
|
||||
db.users.set_token(&user_id, &device_id, &token)?;
|
||||
services().users.set_token(&user_id, &device_id, &token)?;
|
||||
} else {
|
||||
db.users.create_device(
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
|
@ -140,12 +141,10 @@ pub async fn login_route(
|
|||
|
||||
info!("{} logged in", user_id);
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(login::v3::Response {
|
||||
user_id,
|
||||
access_token: token,
|
||||
home_server: Some(db.globals.server_name().to_owned()),
|
||||
home_server: Some(services().globals.server_name().to_owned()),
|
||||
device_id,
|
||||
well_known: None,
|
||||
})
|
||||
|
@ -159,16 +158,11 @@ pub async fn login_route(
|
|||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn logout_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<logout::v3::Request>,
|
||||
) -> Result<logout::v3::Response> {
|
||||
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
db.users.remove_device(sender_user, sender_device)?;
|
||||
|
||||
db.flush()?;
|
||||
services().users.remove_device(sender_user, sender_device)?;
|
||||
|
||||
Ok(logout::v3::Response::new())
|
||||
}
|
||||
|
@ -185,16 +179,13 @@ pub async fn logout_route(
|
|||
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html)
|
||||
/// from each device of this user.
|
||||
pub async fn logout_all_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<logout_all::v3::Request>,
|
||||
) -> Result<logout_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
for device_id in db.users.all_device_ids(sender_user).flatten() {
|
||||
db.users.remove_device(sender_user, &device_id)?;
|
||||
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
||||
services().users.remove_device(sender_user, &device_id)?;
|
||||
}
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(logout_all::v3::Response::new())
|
||||
}
|
|
@ -1,8 +1,6 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
database::DatabaseGuard, pdu::PduBuilder, Database, Error, Result, Ruma, RumaResponse,
|
||||
};
|
||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -27,13 +25,11 @@ use ruma::{
|
|||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_key_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<send_state_event::v3::IncomingRequest>,
|
||||
) -> Result<send_state_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
&db,
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type,
|
||||
|
@ -42,8 +38,6 @@ pub async fn send_state_event_for_key_route(
|
|||
)
|
||||
.await?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id })
|
||||
}
|
||||
|
@ -56,13 +50,12 @@ pub async fn send_state_event_for_key_route(
|
|||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_empty_key_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<send_state_event::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<send_state_event::v3::Response>> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
// Forbid m.room.encryption if encryption is disabled
|
||||
if body.event_type == StateEventType::RoomEncryption && !db.globals.allow_encryption() {
|
||||
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Encryption has been disabled",
|
||||
|
@ -70,7 +63,6 @@ pub async fn send_state_event_for_empty_key_route(
|
|||
}
|
||||
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
&db,
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_type.to_string().into(),
|
||||
|
@ -79,8 +71,6 @@ pub async fn send_state_event_for_empty_key_route(
|
|||
)
|
||||
.await?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id }.into())
|
||||
}
|
||||
|
@ -91,7 +81,6 @@ pub async fn send_state_event_for_empty_key_route(
|
|||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
pub async fn get_state_events_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_state_events::v3::IncomingRequest>,
|
||||
) -> Result<get_state_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -99,9 +88,14 @@ pub async fn get_state_events_route(
|
|||
#[allow(clippy::blocks_in_if_conditions)]
|
||||
// Users not in the room should not be able to access the state unless history_visibility is
|
||||
// WorldReadable
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)?
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
&& !matches!(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map(|event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
|
@ -122,8 +116,9 @@ pub async fn get_state_events_route(
|
|||
}
|
||||
|
||||
Ok(get_state_events::v3::Response {
|
||||
room_state: db
|
||||
room_state: services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_full(&body.room_id)
|
||||
.await?
|
||||
.values()
|
||||
|
@ -138,7 +133,6 @@ pub async fn get_state_events_route(
|
|||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
pub async fn get_state_events_for_key_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_state_events_for_key::v3::IncomingRequest>,
|
||||
) -> Result<get_state_events_for_key::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -146,9 +140,14 @@ pub async fn get_state_events_for_key_route(
|
|||
#[allow(clippy::blocks_in_if_conditions)]
|
||||
// Users not in the room should not be able to access the state unless history_visibility is
|
||||
// WorldReadable
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)?
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
&& !matches!(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map(|event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
|
@ -168,8 +167,9 @@ pub async fn get_state_events_for_key_route(
|
|||
));
|
||||
}
|
||||
|
||||
let event = db
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
|
@ -188,7 +188,6 @@ pub async fn get_state_events_for_key_route(
|
|||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
pub async fn get_state_events_for_empty_key_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_state_events_for_key::v3::IncomingRequest>,
|
||||
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
@ -196,9 +195,14 @@ pub async fn get_state_events_for_empty_key_route(
|
|||
#[allow(clippy::blocks_in_if_conditions)]
|
||||
// Users not in the room should not be able to access the state unless history_visibility is
|
||||
// WorldReadable
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)?
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
&& !matches!(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
|
||||
.map(|event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
|
@ -218,8 +222,9 @@ pub async fn get_state_events_for_empty_key_route(
|
|||
));
|
||||
}
|
||||
|
||||
let event = db
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, "")?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
|
@ -234,7 +239,6 @@ pub async fn get_state_events_for_empty_key_route(
|
|||
}
|
||||
|
||||
async fn send_state_event_for_key_helper(
|
||||
db: &Database,
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
|
@ -255,10 +259,11 @@ async fn send_state_event_for_key_helper(
|
|||
}
|
||||
|
||||
for alias in aliases {
|
||||
if alias.server_name() != db.globals.server_name()
|
||||
|| db
|
||||
if alias.server_name() != services().globals.server_name()
|
||||
|| services()
|
||||
.rooms
|
||||
.id_from_alias(&alias)?
|
||||
.alias
|
||||
.resolve_local_alias(&alias)?
|
||||
.filter(|room| room == room_id) // Make sure it's the right room
|
||||
.is_none()
|
||||
{
|
||||
|
@ -272,7 +277,8 @@ async fn send_state_event_for_key_helper(
|
|||
}
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -281,7 +287,7 @@ async fn send_state_event_for_key_helper(
|
|||
);
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = db.rooms.build_and_append_pdu(
|
||||
let event_id = services().rooms.timeline.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: event_type.to_string().into(),
|
||||
content: serde_json::from_str(json.json().get()).expect("content is valid json"),
|
||||
|
@ -291,7 +297,6 @@ async fn send_state_event_for_key_helper(
|
|||
},
|
||||
sender_user,
|
||||
room_id,
|
||||
db,
|
||||
&state_lock,
|
||||
)?;
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma, RumaResponse};
|
||||
use crate::{services, Error, Result, Ruma, RumaResponse};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
filter::{IncomingFilterDefinition, LazyLoadOptions},
|
||||
|
@ -55,16 +55,13 @@ use tracing::error;
|
|||
/// - Sync is handled in an async task, multiple requests from the same device with the same
|
||||
/// `since` will be cached
|
||||
pub async fn sync_events_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<sync_events::v3::IncomingRequest>,
|
||||
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
|
||||
let sender_user = body.sender_user.expect("user is authenticated");
|
||||
let sender_device = body.sender_device.expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
let arc_db = Arc::new(db);
|
||||
|
||||
let mut rx = match arc_db
|
||||
let mut rx = match services()
|
||||
.globals
|
||||
.sync_receivers
|
||||
.write()
|
||||
|
@ -77,7 +74,6 @@ pub async fn sync_events_route(
|
|||
v.insert((body.since.to_owned(), rx.clone()));
|
||||
|
||||
tokio::spawn(sync_helper_wrapper(
|
||||
Arc::clone(&arc_db),
|
||||
sender_user.clone(),
|
||||
sender_device.clone(),
|
||||
body,
|
||||
|
@ -93,7 +89,6 @@ pub async fn sync_events_route(
|
|||
o.insert((body.since.clone(), rx.clone()));
|
||||
|
||||
tokio::spawn(sync_helper_wrapper(
|
||||
Arc::clone(&arc_db),
|
||||
sender_user.clone(),
|
||||
sender_device.clone(),
|
||||
body,
|
||||
|
@ -127,7 +122,6 @@ pub async fn sync_events_route(
|
|||
}
|
||||
|
||||
async fn sync_helper_wrapper(
|
||||
db: Arc<DatabaseGuard>,
|
||||
sender_user: Box<UserId>,
|
||||
sender_device: Box<DeviceId>,
|
||||
body: sync_events::v3::IncomingRequest,
|
||||
|
@ -135,17 +129,11 @@ async fn sync_helper_wrapper(
|
|||
) {
|
||||
let since = body.since.clone();
|
||||
|
||||
let r = sync_helper(
|
||||
Arc::clone(&db),
|
||||
sender_user.clone(),
|
||||
sender_device.clone(),
|
||||
body,
|
||||
)
|
||||
.await;
|
||||
let r = sync_helper(sender_user.clone(), sender_device.clone(), body).await;
|
||||
|
||||
if let Ok((_, caching_allowed)) = r {
|
||||
if !caching_allowed {
|
||||
match db
|
||||
match services()
|
||||
.globals
|
||||
.sync_receivers
|
||||
.write()
|
||||
|
@ -163,13 +151,10 @@ async fn sync_helper_wrapper(
|
|||
}
|
||||
}
|
||||
|
||||
drop(db);
|
||||
|
||||
let _ = tx.send(Some(r.map(|(r, _)| r)));
|
||||
}
|
||||
|
||||
async fn sync_helper(
|
||||
db: Arc<DatabaseGuard>,
|
||||
sender_user: Box<UserId>,
|
||||
sender_device: Box<DeviceId>,
|
||||
body: sync_events::v3::IncomingRequest,
|
||||
|
@ -182,19 +167,19 @@ async fn sync_helper(
|
|||
};
|
||||
|
||||
// TODO: match body.set_presence {
|
||||
db.rooms.edus.ping_presence(&sender_user)?;
|
||||
services().rooms.edus.presence.ping_presence(&sender_user)?;
|
||||
|
||||
// Setup watchers, so if there's no response, we can wait for them
|
||||
let watcher = db.watch(&sender_user, &sender_device);
|
||||
let watcher = services().globals.watch(&sender_user, &sender_device);
|
||||
|
||||
let next_batch = db.globals.current_count()?;
|
||||
let next_batch = services().globals.current_count()?;
|
||||
let next_batch_string = next_batch.to_string();
|
||||
|
||||
// Load filter
|
||||
let filter = match body.filter {
|
||||
None => IncomingFilterDefinition::default(),
|
||||
Some(IncomingFilter::FilterDefinition(filter)) => filter,
|
||||
Some(IncomingFilter::FilterId(filter_id)) => db
|
||||
Some(IncomingFilter::FilterId(filter_id)) => services()
|
||||
.users
|
||||
.get_filter(&sender_user, &filter_id)?
|
||||
.unwrap_or_default(),
|
||||
|
@ -221,12 +206,17 @@ async fn sync_helper(
|
|||
|
||||
// Look for device list updates of this account
|
||||
device_list_updates.extend(
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.keys_changed(&sender_user.to_string(), since, None)
|
||||
.filter_map(|r| r.ok()),
|
||||
);
|
||||
|
||||
let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>();
|
||||
let all_joined_rooms = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(&sender_user)
|
||||
.collect::<Vec<_>>();
|
||||
for room_id in all_joined_rooms {
|
||||
let room_id = room_id?;
|
||||
|
||||
|
@ -234,7 +224,8 @@ async fn sync_helper(
|
|||
// Get and drop the lock to wait for remaining operations to finish
|
||||
// This will make sure the we have all events until next_batch
|
||||
let mutex_insert = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_insert
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -247,9 +238,15 @@ async fn sync_helper(
|
|||
|
||||
let timeline_pdus;
|
||||
let limited;
|
||||
if db.rooms.last_timeline_count(&sender_user, &room_id)? > since {
|
||||
let mut non_timeline_pdus = db
|
||||
if services()
|
||||
.rooms
|
||||
.timeline
|
||||
.last_timeline_count(&sender_user, &room_id)?
|
||||
> since
|
||||
{
|
||||
let mut non_timeline_pdus = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdus_until(&sender_user, &room_id, u64::MAX)?
|
||||
.filter_map(|r| {
|
||||
// Filter out buggy events
|
||||
|
@ -259,7 +256,9 @@ async fn sync_helper(
|
|||
r.ok()
|
||||
})
|
||||
.take_while(|(pduid, _)| {
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.pdu_count(pduid)
|
||||
.map_or(false, |count| count > since)
|
||||
});
|
||||
|
@ -282,9 +281,10 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
let send_notification_counts = !timeline_pdus.is_empty()
|
||||
|| db
|
||||
|| services()
|
||||
.rooms
|
||||
.edus
|
||||
.read_receipt
|
||||
.last_privateread_update(&sender_user, &room_id)?
|
||||
> since;
|
||||
|
||||
|
@ -293,24 +293,40 @@ async fn sync_helper(
|
|||
timeline_users.insert(event.sender.as_str().to_owned());
|
||||
}
|
||||
|
||||
db.rooms
|
||||
.lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?;
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(
|
||||
&sender_user,
|
||||
&sender_device,
|
||||
&room_id,
|
||||
since,
|
||||
)?;
|
||||
|
||||
// Database queries:
|
||||
|
||||
let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? {
|
||||
s
|
||||
} else {
|
||||
error!("Room {} has no state", room_id);
|
||||
continue;
|
||||
};
|
||||
let current_shortstatehash =
|
||||
if let Some(s) = services().rooms.state.get_room_shortstatehash(&room_id)? {
|
||||
s
|
||||
} else {
|
||||
error!("Room {} has no state", room_id);
|
||||
continue;
|
||||
};
|
||||
|
||||
let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?;
|
||||
let since_shortstatehash = services()
|
||||
.rooms
|
||||
.user
|
||||
.get_token_shortstatehash(&room_id, since)?;
|
||||
|
||||
// Calculates joined_member_count, invited_member_count and heroes
|
||||
let calculate_counts = || {
|
||||
let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0);
|
||||
let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0);
|
||||
let joined_member_count = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_joined_count(&room_id)?
|
||||
.unwrap_or(0);
|
||||
let invited_member_count = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_invited_count(&room_id)?
|
||||
.unwrap_or(0);
|
||||
|
||||
// Recalculate heroes (first 5 members)
|
||||
let mut heroes = Vec::new();
|
||||
|
@ -319,8 +335,9 @@ async fn sync_helper(
|
|||
// Go through all PDUs and for each member event, check if the user is still joined or
|
||||
// invited until we have 5 or we reach the end
|
||||
|
||||
for hero in db
|
||||
for hero in services()
|
||||
.rooms
|
||||
.timeline
|
||||
.all_pdus(&sender_user, &room_id)?
|
||||
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
|
||||
.filter(|(_, pdu)| pdu.kind == RoomEventType::RoomMember)
|
||||
|
@ -339,8 +356,11 @@ async fn sync_helper(
|
|||
if matches!(
|
||||
content.membership,
|
||||
MembershipState::Join | MembershipState::Invite
|
||||
) && (db.rooms.is_joined(&user_id, &room_id)?
|
||||
|| db.rooms.is_invited(&user_id, &room_id)?)
|
||||
) && (services().rooms.state_cache.is_joined(&user_id, &room_id)?
|
||||
|| services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_invited(&user_id, &room_id)?)
|
||||
{
|
||||
Ok::<_, Error>(Some(state_key.clone()))
|
||||
} else {
|
||||
|
@ -381,17 +401,24 @@ async fn sync_helper(
|
|||
|
||||
let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
|
||||
|
||||
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?;
|
||||
let current_state_ids = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(current_shortstatehash)
|
||||
.await?;
|
||||
|
||||
let mut state_events = Vec::new();
|
||||
let mut lazy_loaded = HashSet::new();
|
||||
|
||||
let mut i = 0;
|
||||
for (shortstatekey, id) in current_state_ids {
|
||||
let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?;
|
||||
let (event_type, state_key) = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_statekey_from_short(shortstatekey)?;
|
||||
|
||||
if event_type != StateEventType::RoomMember {
|
||||
let pdu = match db.rooms.get_pdu(&id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
|
@ -408,7 +435,7 @@ async fn sync_helper(
|
|||
|| body.full_state
|
||||
|| timeline_users.contains(&state_key)
|
||||
{
|
||||
let pdu = match db.rooms.get_pdu(&id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
|
@ -430,12 +457,15 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
// Reset lazy loading because this is an initial sync
|
||||
db.rooms
|
||||
.lazy_load_reset(&sender_user, &sender_device, &room_id)?;
|
||||
services().rooms.lazy_loading.lazy_load_reset(
|
||||
&sender_user,
|
||||
&sender_device,
|
||||
&room_id,
|
||||
)?;
|
||||
|
||||
// The state_events above should contain all timeline_users, let's mark them as lazy
|
||||
// loaded.
|
||||
db.rooms.lazy_load_mark_sent(
|
||||
services().rooms.lazy_loading.lazy_load_mark_sent(
|
||||
&sender_user,
|
||||
&sender_device,
|
||||
&room_id,
|
||||
|
@ -457,8 +487,9 @@ async fn sync_helper(
|
|||
// Incremental /sync
|
||||
let since_shortstatehash = since_shortstatehash.unwrap();
|
||||
|
||||
let since_sender_member: Option<RoomMemberEventContent> = db
|
||||
let since_sender_member: Option<RoomMemberEventContent> = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_get(
|
||||
since_shortstatehash,
|
||||
&StateEventType::RoomMember,
|
||||
|
@ -477,12 +508,20 @@ async fn sync_helper(
|
|||
let mut lazy_loaded = HashSet::new();
|
||||
|
||||
if since_shortstatehash != current_shortstatehash {
|
||||
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?;
|
||||
let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?;
|
||||
let current_state_ids = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(current_shortstatehash)
|
||||
.await?;
|
||||
let since_state_ids = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(since_shortstatehash)
|
||||
.await?;
|
||||
|
||||
for (key, id) in current_state_ids {
|
||||
if body.full_state || since_state_ids.get(&key) != Some(&id) {
|
||||
let pdu = match db.rooms.get_pdu(&id)? {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
Some(pdu) => pdu,
|
||||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
|
@ -515,14 +554,14 @@ async fn sync_helper(
|
|||
continue;
|
||||
}
|
||||
|
||||
if !db.rooms.lazy_load_was_sent_before(
|
||||
if !services().rooms.lazy_loading.lazy_load_was_sent_before(
|
||||
&sender_user,
|
||||
&sender_device,
|
||||
&room_id,
|
||||
&event.sender,
|
||||
)? || lazy_load_send_redundant
|
||||
{
|
||||
if let Some(member_event) = db.rooms.room_state_get(
|
||||
if let Some(member_event) = services().rooms.state_accessor.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
event.sender.as_str(),
|
||||
|
@ -533,7 +572,7 @@ async fn sync_helper(
|
|||
}
|
||||
}
|
||||
|
||||
db.rooms.lazy_load_mark_sent(
|
||||
services().rooms.lazy_loading.lazy_load_mark_sent(
|
||||
&sender_user,
|
||||
&sender_device,
|
||||
&room_id,
|
||||
|
@ -541,14 +580,17 @@ async fn sync_helper(
|
|||
next_batch,
|
||||
);
|
||||
|
||||
let encrypted_room = db
|
||||
let encrypted_room = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")?
|
||||
.is_some();
|
||||
|
||||
let since_encryption =
|
||||
db.rooms
|
||||
.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?;
|
||||
let since_encryption = services().rooms.state_accessor.state_get(
|
||||
since_shortstatehash,
|
||||
&StateEventType::RoomEncryption,
|
||||
"",
|
||||
)?;
|
||||
|
||||
// Calculations:
|
||||
let new_encrypted_room = encrypted_room && since_encryption.is_none();
|
||||
|
@ -580,7 +622,7 @@ async fn sync_helper(
|
|||
match new_membership {
|
||||
MembershipState::Join => {
|
||||
// A new user joined an encrypted room
|
||||
if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? {
|
||||
if !share_encrypted_room(&sender_user, &user_id, &room_id)? {
|
||||
device_list_updates.insert(user_id);
|
||||
}
|
||||
}
|
||||
|
@ -597,7 +639,9 @@ async fn sync_helper(
|
|||
if joined_since_last_sync && encrypted_room || new_encrypted_room {
|
||||
// If the user is in a new encrypted room, give them all joined users
|
||||
device_list_updates.extend(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_members(&room_id)
|
||||
.flatten()
|
||||
.filter(|user_id| {
|
||||
|
@ -606,8 +650,7 @@ async fn sync_helper(
|
|||
})
|
||||
.filter(|user_id| {
|
||||
// Only send keys if the sender doesn't share an encrypted room with the target already
|
||||
!share_encrypted_room(&db, &sender_user, user_id, &room_id)
|
||||
.unwrap_or(false)
|
||||
!share_encrypted_room(&sender_user, user_id, &room_id).unwrap_or(false)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
@ -629,14 +672,17 @@ async fn sync_helper(
|
|||
|
||||
// Look for device list updates in this room
|
||||
device_list_updates.extend(
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.keys_changed(&room_id.to_string(), since, None)
|
||||
.filter_map(|r| r.ok()),
|
||||
);
|
||||
|
||||
let notification_count = if send_notification_counts {
|
||||
Some(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.notification_count(&sender_user, &room_id)?
|
||||
.try_into()
|
||||
.expect("notification count can't go that high"),
|
||||
|
@ -647,7 +693,9 @@ async fn sync_helper(
|
|||
|
||||
let highlight_count = if send_notification_counts {
|
||||
Some(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.highlight_count(&sender_user, &room_id)?
|
||||
.try_into()
|
||||
.expect("highlight count can't go that high"),
|
||||
|
@ -659,7 +707,9 @@ async fn sync_helper(
|
|||
let prev_batch = timeline_pdus
|
||||
.first()
|
||||
.map_or(Ok::<_, Error>(None), |(pdu_id, _)| {
|
||||
Ok(Some(db.rooms.pdu_count(pdu_id)?.to_string()))
|
||||
Ok(Some(
|
||||
services().rooms.timeline.pdu_count(pdu_id)?.to_string(),
|
||||
))
|
||||
})?;
|
||||
|
||||
let room_events: Vec<_> = timeline_pdus
|
||||
|
@ -667,18 +717,19 @@ async fn sync_helper(
|
|||
.map(|(_, pdu)| pdu.to_sync_room_event())
|
||||
.collect();
|
||||
|
||||
let mut edus: Vec<_> = db
|
||||
let mut edus: Vec<_> = services()
|
||||
.rooms
|
||||
.edus
|
||||
.read_receipt
|
||||
.readreceipts_since(&room_id, since)
|
||||
.filter_map(|r| r.ok()) // Filter out buggy events
|
||||
.map(|(_, _, v)| v)
|
||||
.collect();
|
||||
|
||||
if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since {
|
||||
if services().rooms.edus.typing.last_typing_update(&room_id)? > since {
|
||||
edus.push(
|
||||
serde_json::from_str(
|
||||
&serde_json::to_string(&db.rooms.edus.typings_all(&room_id)?)
|
||||
&serde_json::to_string(&services().rooms.edus.typing.typings_all(&room_id)?)
|
||||
.expect("event is valid, we just created it"),
|
||||
)
|
||||
.expect("event is valid, we just created it"),
|
||||
|
@ -686,12 +737,15 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
// Save the state after this sync so we can send the correct state diff next sync
|
||||
db.rooms
|
||||
.associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?;
|
||||
services().rooms.user.associate_token_shortstatehash(
|
||||
&room_id,
|
||||
next_batch,
|
||||
current_shortstatehash,
|
||||
)?;
|
||||
|
||||
let joined_room = JoinedRoom {
|
||||
account_data: RoomAccountData {
|
||||
events: db
|
||||
events: services()
|
||||
.account_data
|
||||
.changes_since(Some(&room_id), &sender_user, since)?
|
||||
.into_iter()
|
||||
|
@ -730,10 +784,11 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
// Take presence updates from this room
|
||||
for (user_id, presence) in
|
||||
db.rooms
|
||||
.edus
|
||||
.presence_since(&room_id, since, &db.rooms, &db.globals)?
|
||||
for (user_id, presence) in services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.presence_since(&room_id, since)?
|
||||
{
|
||||
match presence_updates.entry(user_id) {
|
||||
Entry::Vacant(v) => {
|
||||
|
@ -765,14 +820,19 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
let mut left_rooms = BTreeMap::new();
|
||||
let all_left_rooms: Vec<_> = db.rooms.rooms_left(&sender_user).collect();
|
||||
let all_left_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_left(&sender_user)
|
||||
.collect();
|
||||
for result in all_left_rooms {
|
||||
let (room_id, left_state_events) = result?;
|
||||
|
||||
{
|
||||
// Get and drop the lock to wait for remaining operations to finish
|
||||
let mutex_insert = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_insert
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -783,7 +843,10 @@ async fn sync_helper(
|
|||
drop(insert_lock);
|
||||
}
|
||||
|
||||
let left_count = db.rooms.get_left_count(&room_id, &sender_user)?;
|
||||
let left_count = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.get_left_count(&room_id, &sender_user)?;
|
||||
|
||||
// Left before last sync
|
||||
if Some(since) >= left_count {
|
||||
|
@ -807,14 +870,19 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
let mut invited_rooms = BTreeMap::new();
|
||||
let all_invited_rooms: Vec<_> = db.rooms.rooms_invited(&sender_user).collect();
|
||||
let all_invited_rooms: Vec<_> = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_invited(&sender_user)
|
||||
.collect();
|
||||
for result in all_invited_rooms {
|
||||
let (room_id, invite_state_events) = result?;
|
||||
|
||||
{
|
||||
// Get and drop the lock to wait for remaining operations to finish
|
||||
let mutex_insert = Arc::clone(
|
||||
db.globals
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_insert
|
||||
.write()
|
||||
.unwrap()
|
||||
|
@ -825,7 +893,10 @@ async fn sync_helper(
|
|||
drop(insert_lock);
|
||||
}
|
||||
|
||||
let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?;
|
||||
let invite_count = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.get_invite_count(&room_id, &sender_user)?;
|
||||
|
||||
// Invited before last sync
|
||||
if Some(since) >= invite_count {
|
||||
|
@ -843,13 +914,16 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
for user_id in left_encrypted_users {
|
||||
let still_share_encrypted_room = db
|
||||
let still_share_encrypted_room = services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
|
||||
.filter_map(|r| r.ok())
|
||||
.filter_map(|other_room_id| {
|
||||
Some(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
|
||||
.ok()?
|
||||
.is_some(),
|
||||
|
@ -864,7 +938,8 @@ async fn sync_helper(
|
|||
}
|
||||
|
||||
// Remove all to-device events the device received *last time*
|
||||
db.users
|
||||
services()
|
||||
.users
|
||||
.remove_to_device_events(&sender_user, &sender_device, since)?;
|
||||
|
||||
let response = sync_events::v3::Response {
|
||||
|
@ -882,7 +957,7 @@ async fn sync_helper(
|
|||
.collect(),
|
||||
},
|
||||
account_data: GlobalAccountData {
|
||||
events: db
|
||||
events: services()
|
||||
.account_data
|
||||
.changes_since(None, &sender_user, since)?
|
||||
.into_iter()
|
||||
|
@ -897,9 +972,11 @@ async fn sync_helper(
|
|||
changed: device_list_updates.into_iter().collect(),
|
||||
left: device_list_left.into_iter().collect(),
|
||||
},
|
||||
device_one_time_keys_count: db.users.count_one_time_keys(&sender_user, &sender_device)?,
|
||||
device_one_time_keys_count: services()
|
||||
.users
|
||||
.count_one_time_keys(&sender_user, &sender_device)?,
|
||||
to_device: ToDevice {
|
||||
events: db
|
||||
events: services()
|
||||
.users
|
||||
.get_to_device_events(&sender_user, &sender_device)?,
|
||||
},
|
||||
|
@ -928,21 +1005,22 @@ async fn sync_helper(
|
|||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(db))]
|
||||
fn share_encrypted_room(
|
||||
db: &Database,
|
||||
sender_user: &UserId,
|
||||
user_id: &UserId,
|
||||
ignore_room: &RoomId,
|
||||
) -> Result<bool> {
|
||||
Ok(db
|
||||
Ok(services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])?
|
||||
.filter_map(|r| r.ok())
|
||||
.filter(|room_id| room_id != ignore_room)
|
||||
.filter_map(|other_room_id| {
|
||||
Some(
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
|
||||
.ok()?
|
||||
.is_some(),
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::tag::{create_tag, delete_tag, get_tags},
|
||||
events::{
|
||||
|
@ -14,38 +14,41 @@ use std::collections::BTreeMap;
|
|||
///
|
||||
/// - Inserts the tag into the tag event of the room account data.
|
||||
pub async fn update_tag_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_tag::v3::IncomingRequest>,
|
||||
) -> Result<create_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut tags_event = db
|
||||
.account_data
|
||||
.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?
|
||||
.unwrap_or_else(|| TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
});
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
|
||||
tags_event
|
||||
.content
|
||||
.tags
|
||||
.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||
|
||||
db.account_data.update(
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&tags_event,
|
||||
&db.globals,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(create_tag::v3::Response {})
|
||||
}
|
||||
|
||||
|
@ -55,34 +58,37 @@ pub async fn update_tag_route(
|
|||
///
|
||||
/// - Removes the tag from the tag event of the room account data.
|
||||
pub async fn delete_tag_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<delete_tag::v3::IncomingRequest>,
|
||||
) -> Result<delete_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut tags_event = db
|
||||
.account_data
|
||||
.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?
|
||||
.unwrap_or_else(|| TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
});
|
||||
tags_event.content.tags.remove(&body.tag.clone().into());
|
||||
|
||||
db.account_data.update(
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&tags_event,
|
||||
&db.globals,
|
||||
)?;
|
||||
|
||||
db.flush()?;
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
})?;
|
||||
|
||||
tags_event.content.tags.remove(&body.tag.clone().into());
|
||||
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
&serde_json::to_value(tags_event).expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
Ok(delete_tag::v3::Response {})
|
||||
}
|
||||
|
@ -93,25 +99,30 @@ pub async fn delete_tag_route(
|
|||
///
|
||||
/// - Gets the tag event of the room account data.
|
||||
pub async fn get_tags_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_tags::v3::IncomingRequest>,
|
||||
) -> Result<get_tags::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_tags::v3::Response {
|
||||
tags: db
|
||||
.account_data
|
||||
.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?
|
||||
.unwrap_or_else(|| TagEvent {
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
|
||||
let tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
tags: BTreeMap::new(),
|
||||
},
|
||||
})
|
||||
.content
|
||||
.tags,
|
||||
})?;
|
||||
|
||||
Ok(get_tags::v3::Response {
|
||||
tags: tags_event.content.tags,
|
||||
})
|
||||
}
|
|
@ -1,7 +1,7 @@
|
|||
use ruma::events::ToDeviceEventType;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{database::DatabaseGuard, Error, Result, Ruma};
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{error::ErrorKind, to_device::send_event_to_device},
|
||||
|
@ -14,14 +14,13 @@ use ruma::{
|
|||
///
|
||||
/// Send a to-device event to a set of client devices.
|
||||
pub async fn send_event_to_device_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<send_event_to_device::v3::IncomingRequest>,
|
||||
) -> Result<send_event_to_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
// Check if this is a new transaction id
|
||||
if db
|
||||
if services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
.is_some()
|
||||
|
@ -31,13 +30,13 @@ pub async fn send_event_to_device_route(
|
|||
|
||||
for (target_user_id, map) in &body.messages {
|
||||
for (target_device_id_maybe, event) in map {
|
||||
if target_user_id.server_name() != db.globals.server_name() {
|
||||
if target_user_id.server_name() != services().globals.server_name() {
|
||||
let mut map = BTreeMap::new();
|
||||
map.insert(target_device_id_maybe.clone(), event.clone());
|
||||
let mut messages = BTreeMap::new();
|
||||
messages.insert(target_user_id.clone(), map);
|
||||
|
||||
db.sending.send_reliable_edu(
|
||||
services().sending.send_reliable_edu(
|
||||
target_user_id.server_name(),
|
||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
|
||||
DirectDeviceContent {
|
||||
|
@ -48,27 +47,28 @@ pub async fn send_event_to_device_route(
|
|||
},
|
||||
))
|
||||
.expect("DirectToDevice EDU can be serialized"),
|
||||
db.globals.next_count()?,
|
||||
services().globals.next_count()?,
|
||||
)?;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
match target_device_id_maybe {
|
||||
DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
&target_device_id,
|
||||
&body.event_type,
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
&db.globals,
|
||||
)?,
|
||||
DeviceIdOrAllDevices::DeviceId(target_device_id) => {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
&target_device_id,
|
||||
&body.event_type,
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
)?
|
||||
}
|
||||
|
||||
DeviceIdOrAllDevices::AllDevices => {
|
||||
for target_device_id in db.users.all_device_ids(target_user_id) {
|
||||
db.users.add_to_device_event(
|
||||
for target_device_id in services().users.all_device_ids(target_user_id) {
|
||||
services().users.add_to_device_event(
|
||||
sender_user,
|
||||
target_user_id,
|
||||
&target_device_id?,
|
||||
|
@ -76,7 +76,6 @@ pub async fn send_event_to_device_route(
|
|||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
&db.globals,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
@ -85,10 +84,9 @@ pub async fn send_event_to_device_route(
|
|||
}
|
||||
|
||||
// Save transaction id with empty data
|
||||
db.transaction_ids
|
||||
services()
|
||||
.transaction_ids
|
||||
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||
|
||||
db.flush()?;
|
||||
|
||||
Ok(send_event_to_device::v3::Response {})
|
||||
}
|
|
@ -1,18 +1,21 @@
|
|||
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
|
||||
///
|
||||
/// Sets the typing state of the sender user.
|
||||
pub async fn create_typing_event_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<create_typing_event::v3::IncomingRequest>,
|
||||
) -> Result<create_typing_event::v3::Response> {
|
||||
use create_typing_event::v3::Typing;
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !db.rooms.is_joined(sender_user, &body.room_id)? {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You are not in this room.",
|
||||
|
@ -20,16 +23,17 @@ pub async fn create_typing_event_route(
|
|||
}
|
||||
|
||||
if let Typing::Yes(duration) = body.state {
|
||||
db.rooms.edus.typing_add(
|
||||
services().rooms.edus.typing.typing_add(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
|
||||
&db.globals,
|
||||
)?;
|
||||
} else {
|
||||
db.rooms
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.typing_remove(sender_user, &body.room_id, &db.globals)?;
|
||||
.typing
|
||||
.typing_remove(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
Ok(create_typing_event::v3::Response {})
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Result, Ruma};
|
||||
use crate::{services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::user_directory::search_users,
|
||||
events::{
|
||||
|
@ -14,20 +14,19 @@ use ruma::{
|
|||
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
|
||||
/// and don't share a room with the sender
|
||||
pub async fn search_users_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<search_users::v3::IncomingRequest>,
|
||||
) -> Result<search_users::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let limit = u64::from(body.limit) as usize;
|
||||
|
||||
let mut users = db.users.iter().filter_map(|user_id| {
|
||||
let mut users = services().users.iter().filter_map(|user_id| {
|
||||
// Filter out buggy users (they should not exist, but you never know...)
|
||||
let user_id = user_id.ok()?;
|
||||
|
||||
let user = search_users::v3::User {
|
||||
user_id: user_id.clone(),
|
||||
display_name: db.users.displayname(&user_id).ok()?,
|
||||
avatar_url: db.users.avatar_url(&user_id).ok()?,
|
||||
display_name: services().users.displayname(&user_id).ok()?,
|
||||
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
||||
};
|
||||
|
||||
let user_id_matches = user
|
||||
|
@ -49,29 +48,33 @@ pub async fn search_users_route(
|
|||
return None;
|
||||
}
|
||||
|
||||
let user_is_in_public_rooms =
|
||||
db.rooms
|
||||
.rooms_joined(&user_id)
|
||||
.filter_map(|r| r.ok())
|
||||
.any(|room| {
|
||||
db.rooms
|
||||
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
|
||||
.map_or(false, |event| {
|
||||
event.map_or(false, |event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| {
|
||||
r.join_rule == JoinRule::Public
|
||||
})
|
||||
})
|
||||
let user_is_in_public_rooms = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(&user_id)
|
||||
.filter_map(|r| r.ok())
|
||||
.any(|room| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
|
||||
.map_or(false, |event| {
|
||||
event.map_or(false, |event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| {
|
||||
r.join_rule == JoinRule::Public
|
||||
})
|
||||
})
|
||||
});
|
||||
})
|
||||
});
|
||||
|
||||
if user_is_in_public_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
|
||||
let user_is_in_shared_rooms = db
|
||||
let user_is_in_shared_rooms = services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])
|
||||
.ok()?
|
||||
.next()
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{database::DatabaseGuard, Result, Ruma};
|
||||
use crate::{services, Result, Ruma};
|
||||
use hmac::{Hmac, Mac, NewMac};
|
||||
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
|
||||
use sha1::Sha1;
|
||||
|
@ -10,16 +10,15 @@ type HmacSha1 = Hmac<Sha1>;
|
|||
///
|
||||
/// TODO: Returns information about the recommended turn server.
|
||||
pub async fn turn_server_route(
|
||||
db: DatabaseGuard,
|
||||
body: Ruma<get_turn_server_info::v3::IncomingRequest>,
|
||||
) -> Result<get_turn_server_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let turn_secret = db.globals.turn_secret();
|
||||
let turn_secret = services().globals.turn_secret().clone();
|
||||
|
||||
let (username, password) = if !turn_secret.is_empty() {
|
||||
let expiry = SecondsSinceUnixEpoch::from_system_time(
|
||||
SystemTime::now() + Duration::from_secs(db.globals.turn_ttl()),
|
||||
SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
|
||||
)
|
||||
.expect("time is valid");
|
||||
|
||||
|
@ -34,15 +33,15 @@ pub async fn turn_server_route(
|
|||
(username, password)
|
||||
} else {
|
||||
(
|
||||
db.globals.turn_username().clone(),
|
||||
db.globals.turn_password().clone(),
|
||||
services().globals.turn_username().clone(),
|
||||
services().globals.turn_password().clone(),
|
||||
)
|
||||
};
|
||||
|
||||
Ok(get_turn_server_info::v3::Response {
|
||||
username,
|
||||
password,
|
||||
uris: db.globals.turn_uris().to_vec(),
|
||||
ttl: Duration::from_secs(db.globals.turn_ttl()),
|
||||
uris: services().globals.turn_uris().to_vec(),
|
||||
ttl: Duration::from_secs(services().globals.turn_ttl()),
|
||||
})
|
||||
}
|
4
src/api/mod.rs
Normal file
4
src/api/mod.rs
Normal file
|
@ -0,0 +1,4 @@
|
|||
pub mod appservice_server;
|
||||
pub mod client_server;
|
||||
pub mod ruma_wrapper;
|
||||
pub mod server_server;
|
|
@ -24,7 +24,7 @@ use serde::Deserialize;
|
|||
use tracing::{debug, error, warn};
|
||||
|
||||
use super::{Ruma, RumaResponse};
|
||||
use crate::{database::DatabaseGuard, server_server, Error, Result};
|
||||
use crate::{services, Error, Result};
|
||||
|
||||
#[async_trait]
|
||||
impl<T, B> FromRequest<B> for Ruma<T>
|
||||
|
@ -44,7 +44,6 @@ where
|
|||
}
|
||||
|
||||
let metadata = T::METADATA;
|
||||
let db = DatabaseGuard::from_request(req).await?;
|
||||
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
|
||||
let path_params = Path::<Vec<String>>::from_request(req).await?;
|
||||
|
||||
|
@ -71,7 +70,7 @@ where
|
|||
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let appservices = db.appservice.all().unwrap();
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration = appservices.iter().find(|(_id, registration)| {
|
||||
registration
|
||||
.get("as_token")
|
||||
|
@ -91,14 +90,14 @@ where
|
|||
.unwrap()
|
||||
.as_str()
|
||||
.unwrap(),
|
||||
db.globals.server_name(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
},
|
||||
|s| UserId::parse(s).unwrap(),
|
||||
);
|
||||
|
||||
if !db.users.exists(&user_id).unwrap() {
|
||||
if !services().users.exists(&user_id).unwrap() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"User does not exist.",
|
||||
|
@ -124,7 +123,7 @@ where
|
|||
}
|
||||
};
|
||||
|
||||
match db.users.find_from_token(token).unwrap() {
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
|
@ -185,7 +184,7 @@ where
|
|||
(
|
||||
"destination".to_owned(),
|
||||
CanonicalJsonValue::String(
|
||||
db.globals.server_name().as_str().to_owned(),
|
||||
services().globals.server_name().as_str().to_owned(),
|
||||
),
|
||||
),
|
||||
(
|
||||
|
@ -198,12 +197,11 @@ where
|
|||
request_map.insert("content".to_owned(), json_body.clone());
|
||||
};
|
||||
|
||||
let keys_result = server_server::fetch_signing_keys(
|
||||
&db,
|
||||
&x_matrix.origin,
|
||||
vec![x_matrix.key.to_owned()],
|
||||
)
|
||||
.await;
|
||||
let keys_result = services()
|
||||
.rooms
|
||||
.event_handler
|
||||
.fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_owned()])
|
||||
.await;
|
||||
|
||||
let keys = match keys_result {
|
||||
Ok(b) => b,
|
||||
|
@ -251,7 +249,7 @@ where
|
|||
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
UserId::parse_with_server_name("", db.globals.server_name())
|
||||
UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid")
|
||||
});
|
||||
|
||||
|
@ -261,7 +259,7 @@ where
|
|||
.and_then(|auth| auth.get("session"))
|
||||
.and_then(|session| session.as_str())
|
||||
.and_then(|session| {
|
||||
db.uiaa.get_uiaa_request(
|
||||
services().uiaa.get_uiaa_request(
|
||||
&user_id,
|
||||
&sender_device.clone().unwrap_or_else(|| "".into()),
|
||||
session,
|
1815
src/api/server_server.rs
Normal file
1815
src/api/server_server.rs
Normal file
File diff suppressed because it is too large
Load diff
1017
src/database.rs
1017
src/database.rs
File diff suppressed because it is too large
Load diff
|
@ -26,11 +26,11 @@ pub mod persy;
|
|||
))]
|
||||
pub mod watchers;
|
||||
|
||||
pub trait DatabaseEngine: Send + Sync {
|
||||
pub trait KeyValueDatabaseEngine: Send + Sync {
|
||||
fn open(config: &Config) -> Result<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>>;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
|
@ -40,7 +40,7 @@ pub trait DatabaseEngine: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
pub trait Tree: Send + Sync {
|
||||
pub trait KvTree: Send + Sync {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{super::Config, watchers::Watchers, DatabaseEngine, Tree};
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{utils, Result};
|
||||
use std::{
|
||||
future::Future,
|
||||
|
@ -51,7 +51,7 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O
|
|||
db_opts
|
||||
}
|
||||
|
||||
impl DatabaseEngine for Arc<Engine> {
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
|
||||
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes).unwrap();
|
||||
|
@ -83,7 +83,7 @@ impl DatabaseEngine for Arc<Engine> {
|
|||
}))
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>> {
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
|
||||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
let _ = self
|
||||
|
@ -129,7 +129,7 @@ impl RocksDbEngineTree<'_> {
|
|||
}
|
||||
}
|
||||
|
||||
impl Tree for RocksDbEngineTree<'_> {
|
||||
impl KvTree for RocksDbEngineTree<'_> {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{watchers::Watchers, DatabaseEngine, Tree};
|
||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||
|
@ -80,7 +80,7 @@ impl Engine {
|
|||
}
|
||||
}
|
||||
|
||||
impl DatabaseEngine for Arc<Engine> {
|
||||
impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||
fn open(config: &Config) -> Result<Self> {
|
||||
let path = Path::new(&config.database_path).join("conduit.db");
|
||||
|
||||
|
@ -105,7 +105,7 @@ impl DatabaseEngine for Arc<Engine> {
|
|||
Ok(arc)
|
||||
}
|
||||
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn Tree>> {
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?;
|
||||
|
||||
Ok(Arc::new(SqliteTable {
|
||||
|
@ -189,7 +189,7 @@ impl SqliteTable {
|
|||
}
|
||||
}
|
||||
|
||||
impl Tree for SqliteTable {
|
||||
impl KvTree for SqliteTable {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
self.get_with_guard(self.engine.read_lock(), key)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,30 +1,23 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind,
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
pub struct AccountData {
|
||||
pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
|
||||
pub(super) roomusertype_roomuserdataid: Arc<dyn Tree>, // RoomUserType = Room + User + Type
|
||||
}
|
||||
|
||||
impl AccountData {
|
||||
impl service::account_data::Data for KeyValueDatabase {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data, globals))]
|
||||
pub fn update<T: Serialize>(
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &T,
|
||||
globals: &super::globals::Globals,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id
|
||||
.map(|r| r.to_string())
|
||||
|
@ -36,15 +29,14 @@ impl AccountData {
|
|||
prefix.push(0xff);
|
||||
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xff);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let mut key = prefix;
|
||||
key.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling
|
||||
if json.get("type").is_none() || json.get("content").is_none() {
|
||||
if data.get("type").is_none() || data.get("content").is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Account data doesn't have all required fields.",
|
||||
|
@ -53,7 +45,7 @@ impl AccountData {
|
|||
|
||||
self.roomuserdataid_accountdata.insert(
|
||||
&roomuserdataid,
|
||||
&serde_json::to_vec(&json).expect("to_vec always works on json values"),
|
||||
&serde_json::to_vec(&data).expect("to_vec always works on json values"),
|
||||
)?;
|
||||
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
|
@ -71,12 +63,12 @@ impl AccountData {
|
|||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
pub fn get<T: DeserializeOwned>(
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<T>> {
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id
|
||||
.map(|r| r.to_string())
|
||||
.unwrap_or_default()
|
||||
|
@ -104,7 +96,7 @@ impl AccountData {
|
|||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
|
@ -1,20 +1,8 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, RwLock},
|
||||
};
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct Appservice {
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
|
||||
pub(super) id_appserviceregistrations: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl Appservice {
|
||||
impl service::appservice::Data for KeyValueDatabase {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
///
|
||||
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
|
||||
fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
|
||||
// TODO: Rumaify
|
||||
let id = yaml.get("id").unwrap().as_str().unwrap();
|
||||
self.id_appserviceregistrations.insert(
|
||||
|
@ -34,7 +22,7 @@ impl Appservice {
|
|||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations
|
||||
.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations
|
||||
|
@ -44,7 +32,7 @@ impl Appservice {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> {
|
||||
fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> {
|
||||
self.cached_registrations
|
||||
.read()
|
||||
.unwrap()
|
||||
|
@ -66,14 +54,17 @@ impl Appservice {
|
|||
)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
||||
Ok(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
utils::string_from_bytes(&id)
|
||||
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||
}))
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|
||||
|(id, _)| {
|
||||
utils::string_from_bytes(&id).map_err(|_| {
|
||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
||||
})
|
||||
},
|
||||
)))
|
||||
}
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> {
|
||||
fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> {
|
||||
self.iter_ids()?
|
||||
.filter_map(|id| id.ok())
|
||||
.map(move |id| {
|
235
src/database/key_value/globals.rs
Normal file
235
src/database/key_value/globals.rs
Normal file
|
@ -0,0 +1,235 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, ServerName, ServerSigningKeyId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
pub const COUNTER: &[u8] = b"c";
|
||||
|
||||
#[async_trait]
|
||||
impl service::globals::Data for KeyValueDatabase {
|
||||
fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
}
|
||||
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xff);
|
||||
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xff);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
// Return when *any* user changed his key
|
||||
// TODO: only send for user they share a room with
|
||||
futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix));
|
||||
|
||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(
|
||||
self.userroomid_notificationcount
|
||||
.watch_prefix(&userid_prefix),
|
||||
);
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
|
||||
// Events for rooms we are in
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(&room_id)
|
||||
.ok()
|
||||
.flatten()
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xff);
|
||||
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
|
||||
// EDUs
|
||||
futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes));
|
||||
|
||||
futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Key changes
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix));
|
||||
|
||||
// Room account data
|
||||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&roomuser_prefix),
|
||||
);
|
||||
}
|
||||
|
||||
let mut globaluserdata_prefix = vec![0xff];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&globaluserdata_prefix),
|
||||
);
|
||||
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
|
||||
// One time keys
|
||||
futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes));
|
||||
|
||||
futures.push(Box::pin(services().globals.rotate.watch()));
|
||||
|
||||
// Wait until one of them finds something
|
||||
futures.next().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self._db.cleanup()
|
||||
}
|
||||
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
self._db.memory_usage()
|
||||
}
|
||||
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair> {
|
||||
let keypair_bytes = self.global.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
self.global.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
|s| Ok(s.to_vec()),
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let keypair = utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
});
|
||||
|
||||
keypair
|
||||
}
|
||||
fn remove_keypair(&self) -> Result<()> {
|
||||
self.global.remove(b"keypair")
|
||||
}
|
||||
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
keys.verify_keys.extend(verify_keys.into_iter());
|
||||
keys.old_verify_keys.extend(old_verify_keys.into_iter());
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
||||
Ok(signingkeys)
|
||||
}
|
||||
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.global.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
|
@ -7,24 +8,16 @@ use ruma::{
|
|||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use std::{collections::BTreeMap, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
pub struct KeyBackups {
|
||||
pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
}
|
||||
|
||||
impl KeyBackups {
|
||||
pub fn create_backup(
|
||||
impl service::key_backups::Data for KeyValueDatabase {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<String> {
|
||||
let version = globals.next_count()?.to_string();
|
||||
let version = services().globals.next_count()?.to_string();
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -35,11 +28,11 @@ impl KeyBackups {
|
|||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||
)?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version)
|
||||
}
|
||||
|
||||
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
@ -56,12 +49,11 @@ impl KeyBackups {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn update_backup(
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -77,11 +69,11 @@ impl KeyBackups {
|
|||
self.backupid_algorithm
|
||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
|
||||
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
|
@ -102,7 +94,7 @@ impl KeyBackups {
|
|||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_latest_backup(
|
||||
fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
|
@ -133,11 +125,7 @@ impl KeyBackups {
|
|||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
@ -150,14 +138,13 @@ impl KeyBackups {
|
|||
})
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -171,7 +158,7 @@ impl KeyBackups {
|
|||
}
|
||||
|
||||
self.backupid_etag
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
@ -184,7 +171,7 @@ impl KeyBackups {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
@ -192,7 +179,7 @@ impl KeyBackups {
|
|||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
}
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
@ -207,7 +194,7 @@ impl KeyBackups {
|
|||
.to_string())
|
||||
}
|
||||
|
||||
pub fn get_all(
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
|
@ -263,7 +250,7 @@ impl KeyBackups {
|
|||
Ok(rooms)
|
||||
}
|
||||
|
||||
pub fn get_room(
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
|
@ -300,7 +287,7 @@ impl KeyBackups {
|
|||
.collect())
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
|
@ -325,7 +312,7 @@ impl KeyBackups {
|
|||
.transpose()
|
||||
}
|
||||
|
||||
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
@ -338,12 +325,7 @@ impl KeyBackups {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_room_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
@ -358,7 +340,7 @@ impl KeyBackups {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn delete_room_key(
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
82
src/database/key_value/media.rs
Normal file
82
src/database/key_value/media.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use ruma::api::client::error::ErrorKind;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::media::Data for KeyValueDatabase {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
Ok(key)
|
||||
}
|
||||
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
.scan_prefix(prefix)
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
}
|
13
src/database/key_value/mod.rs
Normal file
13
src/database/key_value/mod.rs
Normal file
|
@ -0,0 +1,13 @@
|
|||
mod account_data;
|
||||
//mod admin;
|
||||
mod appservice;
|
||||
mod globals;
|
||||
mod key_backups;
|
||||
mod media;
|
||||
//mod pdu;
|
||||
mod pusher;
|
||||
mod rooms;
|
||||
mod sending;
|
||||
mod transaction_ids;
|
||||
mod uiaa;
|
||||
mod users;
|
81
src/database/key_value/pusher.rs
Normal file
81
src/database/key_value/pusher.rs
Normal file
|
@ -0,0 +1,81 @@
|
|||
use ruma::{
|
||||
api::client::push::{get_pushers, set_pusher},
|
||||
UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::pusher::Data for KeyValueDatabase {
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pusher.pushkey.as_bytes());
|
||||
|
||||
// There are 2 kinds of pushers but the spec says: null deletes the pusher.
|
||||
if pusher.kind.is_none() {
|
||||
return self
|
||||
.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into);
|
||||
}
|
||||
|
||||
self.senderkey_pusher.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_pusher(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
pushkey: &str,
|
||||
) -> Result<Option<get_pushers::v3::Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xff);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_pushkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xff);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
|
||||
Ok(push_key_string)
|
||||
}))
|
||||
}
|
||||
}
|
60
src/database/key_value/rooms/alias.rs
Normal file
60
src/database/key_value/rooms/alias.rs
Normal file
|
@ -0,0 +1,60 @@
|
|||
use ruma::{api::client::error::ErrorKind, RoomAliasId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::alias::Data for KeyValueDatabase {
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid
|
||||
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xff);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id.to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<Box<RoomId>>> {
|
||||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<RoomAliasId>>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
|
||||
}))
|
||||
}
|
||||
}
|
61
src/database/key_value/rooms/auth_chain.rs
Normal file
61
src/database/key_value/rooms/auth_chain.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
use std::{collections::HashSet, mem::size_of, sync::Arc};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Result};
|
||||
|
||||
impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
|
||||
// Check RAM cache
|
||||
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
|
||||
return Ok(Some(Arc::clone(result)));
|
||||
}
|
||||
|
||||
// We only save auth chains for single events in the db
|
||||
if key.len() == 1 {
|
||||
// Check DB cache
|
||||
let chain = self
|
||||
.shorteventid_authchain
|
||||
.get(&key[0].to_be_bytes())?
|
||||
.map(|chain| {
|
||||
chain
|
||||
.chunks_exact(size_of::<u64>())
|
||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||
.collect()
|
||||
});
|
||||
|
||||
if let Some(chain) = chain {
|
||||
let chain = Arc::new(chain);
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
|
||||
// Only persist single events in db
|
||||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(key, auth_chain);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
28
src/database/key_value/rooms/directory.rs
Normal file
28
src/database/key_value/rooms/directory.rs
Normal file
|
@ -0,0 +1,28 @@
|
|||
use ruma::RoomId;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::directory::Data for KeyValueDatabase {
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.insert(room_id.as_bytes(), &[])
|
||||
}
|
||||
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.remove(room_id.as_bytes())
|
||||
}
|
||||
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
}
|
7
src/database/key_value/rooms/edus/mod.rs
Normal file
7
src/database/key_value/rooms/edus/mod.rs
Normal file
|
@ -0,0 +1,7 @@
|
|||
mod presence;
|
||||
mod read_receipt;
|
||||
mod typing;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service};
|
||||
|
||||
impl service::rooms::edus::Data for KeyValueDatabase {}
|
150
src/database/key_value/rooms/edus/presence.rs
Normal file
150
src/database/key_value/rooms/edus/presence.rs
Normal file
|
@ -0,0 +1,150 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use ruma::{events::presence::PresenceEvent, presence::PresenceState, RoomId, UInt, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||
fn update_presence(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
presence: PresenceEvent,
|
||||
) -> Result<()> {
|
||||
// TODO: Remove old entry? Or maybe just wipe completely from time to time?
|
||||
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count);
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(presence.sender.as_bytes());
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ping_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
self.userid_lastpresenceupdate
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_presence_event(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
count: u64,
|
||||
) -> Result<Option<PresenceEvent>> {
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count.to_be_bytes());
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.presenceid_presence
|
||||
.get(&presence_id)?
|
||||
.map(|value| parse_presence_event(&value))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn presence_since(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<Box<UserId>, PresenceEvent>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
let mut hashmap = HashMap::new();
|
||||
|
||||
for (key, value) in self
|
||||
.presenceid_presence
|
||||
.iter_from(&*first_possible_edu, false)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
{
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?;
|
||||
|
||||
let presence = parse_presence_event(&value)?;
|
||||
|
||||
hashmap.insert(user_id, presence);
|
||||
}
|
||||
|
||||
Ok(hashmap)
|
||||
}
|
||||
|
||||
/*
|
||||
fn presence_maintain(&self, db: Arc<TokioRwLock<Database>>) {
|
||||
// TODO @M0dEx: move this to a timed tasks module
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
select! {
|
||||
Some(user_id) = self.presence_timers.next() {
|
||||
// TODO @M0dEx: would it be better to acquire the lock outside the loop?
|
||||
let guard = db.read().await;
|
||||
|
||||
// TODO @M0dEx: add self.presence_timers
|
||||
// TODO @M0dEx: maintain presence
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
*/
|
||||
}
|
||||
|
||||
fn parse_presence_event(bytes: &[u8]) -> Result<PresenceEvent> {
|
||||
let mut presence: PresenceEvent = serde_json::from_slice(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
|
||||
let current_timestamp: UInt = utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
.expect("time is valid");
|
||||
|
||||
if presence.content.presence == PresenceState::Online {
|
||||
// Don't set last_active_ago when the user is online
|
||||
presence.content.last_active_ago = None;
|
||||
} else {
|
||||
// Convert from timestamp to duration
|
||||
presence.content.last_active_ago = presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|timestamp| current_timestamp - timestamp);
|
||||
}
|
||||
|
||||
Ok(presence)
|
||||
}
|
150
src/database/key_value/rooms/edus/read_receipt.rs
Normal file
150
src/database/key_value/rooms/edus/read_receipt.rs
Normal file
|
@ -0,0 +1,150 @@
|
|||
use std::mem;
|
||||
|
||||
use ruma::{
|
||||
events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||
fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: ReceiptEvent,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<
|
||||
dyn Iterator<
|
||||
Item = Result<(
|
||||
Box<UserId>,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a,
|
||||
> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
Box::new(
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(
|
||||
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json =
|
||||
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Read receipt in roomlatestid_roomlatest is invalid json.",
|
||||
)
|
||||
})?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json)
|
||||
.expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
}
|
86
src/database/key_value/rooms/edus/typing.rs
Normal file
86
src/database/key_value/rooms/edus/typing.rs
Normal file
|
@ -0,0 +1,86 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use ruma::{RoomId, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, &*user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<Box<UserId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
|
||||
Ok(user_ids)
|
||||
}
|
||||
}
|
65
src/database/key_value/rooms/lazy_load.rs
Normal file
65
src/database/key_value/rooms/lazy_load.rs
Normal file
|
@ -0,0 +1,65 @@
|
|||
use ruma::{DeviceId, RoomId, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||
fn lazy_load_was_sent_before(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
key.extend_from_slice(ll_id.as_bytes());
|
||||
self.lazyloadedids.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn lazy_load_reset(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
45
src/database/key_value/rooms/metadata.rs
Normal file
45
src/database/key_value/rooms/metadata.rs
Normal file
|
@ -0,0 +1,45 @@
|
|||
use ruma::RoomId;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::metadata::Data for KeyValueDatabase {
|
||||
fn exists(&self, room_id: &RoomId) -> Result<bool> {
|
||||
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
|
||||
Some(b) => b.to_be_bytes().to_vec(),
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Look for PDUs in that room.
|
||||
Ok(self
|
||||
.pduid_pdu
|
||||
.iter_from(&prefix, false)
|
||||
.next()
|
||||
.filter(|(k, _)| k.starts_with(&prefix))
|
||||
.is_some())
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
|
||||
if disabled {
|
||||
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
|
||||
} else {
|
||||
self.disabledroomids.remove(room_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
20
src/database/key_value/rooms/mod.rs
Normal file
20
src/database/key_value/rooms/mod.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
mod alias;
|
||||
mod auth_chain;
|
||||
mod directory;
|
||||
mod edus;
|
||||
mod lazy_load;
|
||||
mod metadata;
|
||||
mod outlier;
|
||||
mod pdu_metadata;
|
||||
mod search;
|
||||
mod short;
|
||||
mod state;
|
||||
mod state_accessor;
|
||||
mod state_cache;
|
||||
mod state_compressor;
|
||||
mod timeline;
|
||||
mod user;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service};
|
||||
|
||||
impl service::rooms::Data for KeyValueDatabase {}
|
28
src/database/key_value/rooms/outlier.rs
Normal file
28
src/database/key_value/rooms/outlier.rs
Normal file
|
@ -0,0 +1,28 @@
|
|||
use ruma::{signatures::CanonicalJsonObject, EventId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
|
||||
|
||||
impl service::rooms::outlier::Data for KeyValueDatabase {
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
|
||||
self.eventid_outlierpdu.insert(
|
||||
event_id.as_bytes(),
|
||||
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
|
||||
)
|
||||
}
|
||||
}
|
33
src/database/key_value/rooms/pdu_metadata.rs
Normal file
33
src/database/key_value/rooms/pdu_metadata.rs
Normal file
|
@ -0,0 +1,33 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use ruma::{EventId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
|
||||
for prev in event_ids {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(prev.as_bytes());
|
||||
self.referencedevents.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
Ok(self.referencedevents.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
|
||||
self.softfailedeventids.insert(event_id.as_bytes(), &[])
|
||||
}
|
||||
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids
|
||||
.get(event_id.as_bytes())
|
||||
.map(|o| o.is_some())
|
||||
}
|
||||
}
|
75
src/database/key_value/rooms/search.rs
Normal file
75
src/database/key_value/rooms/search.rs
Normal file
|
@ -0,0 +1,75 @@
|
|||
use std::mem::size_of;
|
||||
|
||||
use ruma::RoomId;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Result};
|
||||
|
||||
impl service::rooms::search::Data for KeyValueDatabase {
|
||||
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
|
||||
let mut batch = message_body
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= 50)
|
||||
.map(str::to_lowercase)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&pdu_id);
|
||||
(key, Vec::new())
|
||||
});
|
||||
|
||||
self.tokenids.insert_batch(&mut batch)
|
||||
}
|
||||
|
||||
fn search_pdus<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
search_string: &str,
|
||||
) -> Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let prefix_clone = prefix.clone();
|
||||
|
||||
let words: Vec<_> = search_string
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
|
||||
let iterators = words.clone().into_iter().map(move |word| {
|
||||
let mut prefix2 = prefix.clone();
|
||||
prefix2.extend_from_slice(word.as_bytes());
|
||||
prefix2.push(0xff);
|
||||
|
||||
let mut last_possible_id = prefix2.clone();
|
||||
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
self.tokenids
|
||||
.iter_from(&last_possible_id, true) // Newest pdus first
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(|(key, _)| key[key.len() - size_of::<u64>()..].to_vec())
|
||||
});
|
||||
|
||||
let common_elements = match utils::common_elements(iterators, |a, b| {
|
||||
// We compare b with a because we reversed the iterator earlier
|
||||
b.cmp(a)
|
||||
}) {
|
||||
Some(it) => it,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mapped = common_elements.map(move |id| {
|
||||
let mut pduid = prefix_clone.clone();
|
||||
pduid.extend_from_slice(&id);
|
||||
pduid
|
||||
});
|
||||
|
||||
Ok(Some((Box::new(mapped), words)))
|
||||
}
|
||||
}
|
218
src/database/key_value/rooms/short.rs
Normal file
218
src/database/key_value/rooms/short.rs
Normal file
|
@ -0,0 +1,218 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use ruma::{events::StateEventType, EventId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::short::Data for KeyValueDatabase {
|
||||
fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
|
||||
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
|
||||
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid
|
||||
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid
|
||||
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
}
|
||||
};
|
||||
|
||||
self.eventidshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<u64>> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
|
||||
let mut statekey = event_type.to_string().as_bytes().to_vec();
|
||||
statekey.push(0xff);
|
||||
statekey.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<u64> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let mut statekey = event_type.to_string().as_bytes().to_vec();
|
||||
statekey.push(0xff);
|
||||
statekey.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = match self.statekey_shortstatekey.get(&statekey)? {
|
||||
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey
|
||||
.insert(&statekey, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey
|
||||
.insert(&shortstatekey.to_be_bytes(), &statekey)?;
|
||||
shortstatekey
|
||||
}
|
||||
};
|
||||
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self
|
||||
.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shorteventid)
|
||||
{
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shorteventid_eventid
|
||||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
|
||||
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
|
||||
self.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shorteventid, Arc::clone(&event_id));
|
||||
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self
|
||||
.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shortstatekey)
|
||||
{
|
||||
return Ok(id.clone());
|
||||
}
|
||||
|
||||
let bytes = self
|
||||
.shortstatekey_statekey
|
||||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
|
||||
let event_type =
|
||||
StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
|
||||
|
||||
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
|
||||
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
|
||||
})?;
|
||||
|
||||
let result = (event_type, state_key);
|
||||
|
||||
self.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shortstatekey, result.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns (shortstatehash, already_existed)
|
||||
fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
|
||||
Ok(match self.statehash_shortstatehash.get(state_hash)? {
|
||||
Some(shortstatehash) => (
|
||||
utils::u64_from_bytes(&shortstatehash)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
|
||||
true,
|
||||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash
|
||||
.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => utils::u64_from_bytes(&short)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid
|
||||
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
73
src/database/key_value/rooms/state.rs
Normal file
73
src/database/key_value/rooms/state.rs
Normal file
|
@ -0,0 +1,73 @@
|
|||
use ruma::{EventId, RoomId};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::MutexGuard;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::state::Data for KeyValueDatabase {
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash
|
||||
.get(room_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
fn set_room_state(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash
|
||||
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash
|
||||
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn set_forward_extremities<'a>(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_ids: Vec<Box<EventId>>,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
}
|
||||
|
||||
for event_id in event_ids {
|
||||
let mut key = prefix.to_owned();
|
||||
key.extend_from_slice(event_id.as_bytes());
|
||||
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
189
src/database/key_value/rooms/state_accessor.rs
Normal file
189
src/database/key_value/rooms/state_accessor.rs
Normal file
|
@ -0,0 +1,189 @@
|
|||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
use async_trait::async_trait;
|
||||
use ruma::{events::StateEventType, EventId, RoomId};
|
||||
|
||||
#[async_trait]
|
||||
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
let mut result = BTreeMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.into_iter() {
|
||||
let parsed = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
|
||||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state {
|
||||
let (_, eventid) = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
pdu.kind.to_string().into(),
|
||||
pdu.state_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::bad_database("State event has no state key."))?
|
||||
.clone(),
|
||||
),
|
||||
pdu,
|
||||
);
|
||||
}
|
||||
|
||||
i += 1;
|
||||
if i % 100 == 0 {
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get_id(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)?
|
||||
{
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(shortstatehash)?
|
||||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(full_state
|
||||
.into_iter()
|
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
|
||||
.and_then(|compressed| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)
|
||||
.ok()
|
||||
.map(|(_, id)| id)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn state_get(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| {
|
||||
services().rooms.timeline.get_pdu(&event_id)
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
fn room_state_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
593
src/database/key_value/rooms/state_cache.rs
Normal file
593
src/database/key_value/rooms/state_cache.rs
Normal file
|
@ -0,0 +1,593 @@
|
|||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use regex::Regex;
|
||||
use ruma::{
|
||||
events::{AnyStrippedStateEvent, AnySyncStateEvent},
|
||||
serde::Raw,
|
||||
RoomId, ServerName, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
self.roomuseroncejoinedids.insert(&userroom_id, &[])
|
||||
}
|
||||
|
||||
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_joined.insert(&userroom_id, &[])?;
|
||||
self.roomuserid_joined.insert(&roomuser_id, &[])?;
|
||||
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||
self.roomuserid_invitecount.remove(&roomuser_id)?;
|
||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mark_as_invited(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_invitestate.insert(
|
||||
&userroom_id,
|
||||
&serde_json::to_vec(&last_state.unwrap_or_default())
|
||||
.expect("state to bytes always works"),
|
||||
)?;
|
||||
self.roomuserid_invitecount.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.userroomid_joined.remove(&userroom_id)?;
|
||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate.insert(
|
||||
&userroom_id,
|
||||
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
|
||||
)?; // TODO
|
||||
self.roomuserid_leftcount.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.userroomid_joined.remove(&userroom_id)?;
|
||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||
self.roomuserid_invitecount.remove(&roomuser_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut joinedcount = 0_u64;
|
||||
let mut invitedcount = 0_u64;
|
||||
let mut joined_servers = HashSet::new();
|
||||
let mut real_users = HashSet::new();
|
||||
|
||||
for joined in self.room_members(room_id).filter_map(|r| r.ok()) {
|
||||
joined_servers.insert(joined.server_name().to_owned());
|
||||
if joined.server_name() == services().globals.server_name()
|
||||
&& !services().users.is_deactivated(&joined).unwrap_or(true)
|
||||
{
|
||||
real_users.insert(joined);
|
||||
}
|
||||
joinedcount += 1;
|
||||
}
|
||||
|
||||
for invited in self.room_members_invited(room_id).filter_map(|r| r.ok()) {
|
||||
joined_servers.insert(invited.server_name().to_owned());
|
||||
invitedcount += 1;
|
||||
}
|
||||
|
||||
self.roomid_joinedcount
|
||||
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
|
||||
|
||||
self.roomid_invitedcount
|
||||
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
|
||||
|
||||
self.our_real_users_cache
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(room_id.to_owned(), Arc::new(real_users));
|
||||
|
||||
self.appservice_in_room_cache
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(room_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id))]
|
||||
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<Box<UserId>>>> {
|
||||
let maybe = self
|
||||
.our_real_users_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.cloned();
|
||||
if let Some(users) = maybe {
|
||||
Ok(users)
|
||||
} else {
|
||||
self.update_joined_count(room_id)?;
|
||||
Ok(Arc::clone(
|
||||
self.our_real_users_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id, appservice))]
|
||||
fn appservice_in_room(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
appservice: &(String, serde_yaml::Value),
|
||||
) -> Result<bool> {
|
||||
let maybe = self
|
||||
.appservice_in_room_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.and_then(|map| map.get(&appservice.0))
|
||||
.copied();
|
||||
|
||||
if let Some(b) = maybe {
|
||||
Ok(b)
|
||||
} else if let Some(namespaces) = appservice.1.get("namespaces") {
|
||||
let users = namespaces
|
||||
.get("users")
|
||||
.and_then(|users| users.as_sequence())
|
||||
.map_or_else(Vec::new, |users| {
|
||||
users
|
||||
.iter()
|
||||
.filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok())
|
||||
.collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let bridge_user_id = appservice
|
||||
.1
|
||||
.get("sender_localpart")
|
||||
.and_then(|string| string.as_str())
|
||||
.and_then(|string| {
|
||||
UserId::parse_with_server_name(string, services().globals.server_name()).ok()
|
||||
});
|
||||
|
||||
let in_room = bridge_user_id
|
||||
.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|
||||
|| self.room_members(room_id).any(|userid| {
|
||||
userid.map_or(false, |userid| {
|
||||
users.iter().any(|r| r.is_match(userid.as_str()))
|
||||
})
|
||||
});
|
||||
|
||||
self.appservice_in_room_cache
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
.or_default()
|
||||
.insert(appservice.0.clone(), in_room);
|
||||
|
||||
Ok(in_room)
|
||||
} else {
|
||||
Ok(false)
|
||||
}
|
||||
}
|
||||
|
||||
/// Makes a user forget a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||
self.roomuserid_leftcount.remove(&roomuser_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator of all servers participating in this room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_servers<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<ServerName>>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
|
||||
ServerName::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Server name in roomserverids is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn server_in_room<'a>(&'a self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
|
||||
let mut key = server.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.serverroomids.get(&key).map(|o| o.is_some())
|
||||
}
|
||||
|
||||
/// Returns an iterator of all rooms a server participates in (as far as we know).
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn server_rooms<'a>(
|
||||
&'a self,
|
||||
server: &ServerName,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
|
||||
let mut prefix = server.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns an iterator over all joined members of a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_members<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_joinedcount
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|b| {
|
||||
utils::u64_from_bytes(&b)
|
||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_invitedcount
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|b| {
|
||||
utils::u64_from_bytes(&b)
|
||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all User IDs who ever joined a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_useroncejoined<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(
|
||||
self.roomuseroncejoinedids
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database(
|
||||
"User ID in room_useroncejoined is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns an iterator over all invited members of a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_members_invited<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(
|
||||
self.roomuserid_invitecount
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_invitecount
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid invitecount in db.")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_leftcount
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid leftcount in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms this user joined.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_joined<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
|
||||
Box::new(
|
||||
self.userroomid_joined
|
||||
.scan_prefix(user_id.as_bytes().to_vec())
|
||||
.map(|(key, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms a user was invited to.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_invited<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(
|
||||
self.userroomid_invitestate
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, state)| {
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
||||
})?;
|
||||
|
||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
||||
Error::bad_database("Invalid state in userroomid_invitestate.")
|
||||
})?;
|
||||
|
||||
Ok((room_id, state))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn invite_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_invitestate
|
||||
.get(&key)?
|
||||
.map(|state| {
|
||||
let state = serde_json::from_slice(&state)
|
||||
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
|
||||
|
||||
Ok(state)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn left_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate
|
||||
.get(&key)?
|
||||
.map(|state| {
|
||||
let state = serde_json::from_slice(&state)
|
||||
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
|
||||
|
||||
Ok(state)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms a user left.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_left<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnySyncStateEvent>>)>> + 'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
Box::new(
|
||||
self.userroomid_leftstate
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, state)| {
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
||||
})?;
|
||||
|
||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
||||
Error::bad_database("Invalid state in userroomid_leftstate.")
|
||||
})?;
|
||||
|
||||
Ok((room_id, state))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
||||
}
|
||||
}
|
61
src/database/key_value/rooms/state_compressor.rs
Normal file
61
src/database/key_value/rooms/state_compressor.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
use std::{collections::HashSet, mem::size_of};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, rooms::state_compressor::data::StateDiff},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::rooms::state_compressor::Data for KeyValueDatabase {
|
||||
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
|
||||
let value = self
|
||||
.shortstatehash_statediff
|
||||
.get(&shortstatehash.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||
let parent =
|
||||
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||
let parent = if parent != 0 { Some(parent) } else { None };
|
||||
|
||||
let mut add_mode = true;
|
||||
let mut added = HashSet::new();
|
||||
let mut removed = HashSet::new();
|
||||
|
||||
let mut i = size_of::<u64>();
|
||||
while let Some(v) = value.get(i..i + 2 * size_of::<u64>()) {
|
||||
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
|
||||
add_mode = false;
|
||||
i += size_of::<u64>();
|
||||
continue;
|
||||
}
|
||||
if add_mode {
|
||||
added.insert(v.try_into().expect("we checked the size above"));
|
||||
} else {
|
||||
removed.insert(v.try_into().expect("we checked the size above"));
|
||||
}
|
||||
i += 2 * size_of::<u64>();
|
||||
}
|
||||
|
||||
Ok(StateDiff {
|
||||
parent,
|
||||
added,
|
||||
removed,
|
||||
})
|
||||
}
|
||||
|
||||
fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
|
||||
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
|
||||
for new in &diff.added {
|
||||
value.extend_from_slice(&new[..]);
|
||||
}
|
||||
|
||||
if !diff.removed.is_empty() {
|
||||
value.extend_from_slice(&0_u64.to_be_bytes());
|
||||
for removed in &diff.removed {
|
||||
value.extend_from_slice(&removed[..]);
|
||||
}
|
||||
}
|
||||
|
||||
self.shortstatehash_statediff
|
||||
.insert(&shortstatehash.to_be_bytes(), &value)
|
||||
}
|
||||
}
|
371
src/database/key_value/rooms/timeline.rs
Normal file
371
src/database/key_value/rooms/timeline.rs
Normal file
|
@ -0,0 +1,371 @@
|
|||
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind, signatures::CanonicalJsonObject, EventId, RoomId, UserId,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||
fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
// Look for PDUs in that room.
|
||||
self.pduid_pdu
|
||||
.iter_from(&prefix, false)
|
||||
.filter(|(k, _)| k.starts_with(&prefix))
|
||||
.map(|(_, pdu)| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid first PDU in db."))
|
||||
.map(Arc::new)
|
||||
})
|
||||
.next()
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
match self
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
{
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self
|
||||
.pdus_until(&sender_user, &room_id, u64::MAX)?
|
||||
.filter_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
.map(|(pduid, _)| self.pdu_count(&pduid))
|
||||
.next()
|
||||
{
|
||||
Ok(*v.insert(last_count?))
|
||||
} else {
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu_id| self.pdu_count(&pdu_id))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or_else(
|
||||
|| self.eventid_outlierpdu.get(event_id.as_bytes()),
|
||||
|pduid| {
|
||||
Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
|
||||
Error::bad_database("Invalid pduid in eventid_pduid.")
|
||||
})?))
|
||||
},
|
||||
)?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
|
||||
fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
|
||||
return Ok(Some(Arc::clone(p)));
|
||||
}
|
||||
|
||||
if let Some(pdu) = self
|
||||
.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or_else(
|
||||
|| self.eventid_outlierpdu.get(event_id.as_bytes()),
|
||||
|pduid| {
|
||||
Ok(Some(self.pduid_pdu.get(&pduid)?.ok_or_else(|| {
|
||||
Error::bad_database("Invalid pduid in eventid_pduid.")
|
||||
})?))
|
||||
},
|
||||
)?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
.map(Arc::new)
|
||||
})
|
||||
.transpose()?
|
||||
{
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the pdu.
|
||||
///
|
||||
/// This does __NOT__ check the outliers `Tree`.
|
||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
|
||||
utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))
|
||||
}
|
||||
|
||||
fn append_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu: &PduEvent,
|
||||
json: &CanonicalJsonObject,
|
||||
count: u64,
|
||||
) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(pdu.room_id.clone(), count);
|
||||
|
||||
self.eventid_pduid
|
||||
.insert(pdu.event_id.as_bytes(), &pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PDU does not exist.",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all events in a room that happened after the event with id `since`
|
||||
/// in chronological order.
|
||||
fn pdus_since<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
// Skip the first pdu if it's exactly at since, because we sent that last time
|
||||
let mut first_pdu_id = prefix.clone();
|
||||
first_pdu_id.extend_from_slice(&(since + 1).to_be_bytes());
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(&first_pdu_id, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((pdu_id, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
/// Returns an iterator over all events and their tokens in a room that happened before the
|
||||
/// event with id `until` in reverse-chronological order.
|
||||
fn pdus_until<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
until: u64,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
|
||||
// Create the first part of the full pdu id
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until.saturating_sub(1)).to_be_bytes()); // -1 because we don't want event at `until`
|
||||
|
||||
let current: &[u8] = ¤t;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(current, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((pdu_id, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
fn pdus_after<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
from: u64,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(Vec<u8>, PduEvent)>> + 'a>> {
|
||||
// Create the first part of the full pdu id
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(from + 1).to_be_bytes()); // +1 so we don't send the base event
|
||||
|
||||
let current: &[u8] = ¤t;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(current, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((pdu_id, pdu))
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
fn increment_notification_counts(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
notifies: Vec<Box<UserId>>,
|
||||
highlights: Vec<Box<UserId>>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount
|
||||
.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
124
src/database/key_value/rooms/user.rs
Normal file
124
src/database/key_value/rooms/user.rs
Normal file
|
@ -0,0 +1,124 @@
|
|||
use ruma::{RoomId, UserId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::rooms::user::Data for KeyValueDatabase {
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
shortstatehash: u64,
|
||||
) -> Result<()> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self,
|
||||
users: Vec<Box<UserId>>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xff)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0
|
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
|
||||
Ok::<_, Error>(room_id)
|
||||
})
|
||||
.filter_map(|r| r.ok())
|
||||
});
|
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
Ok(Box::new(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp)
|
||||
.expect("users is not empty")
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
)))
|
||||
}
|
||||
}
|
203
src/database/key_value/sending.rs
Normal file
203
src/database/key_value/sending.rs
Normal file
|
@ -0,0 +1,203 @@
|
|||
use ruma::{ServerName, UserId};
|
||||
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{
|
||||
self,
|
||||
sending::{OutgoingKind, SendingEventType},
|
||||
},
|
||||
utils, Error, Result,
|
||||
};
|
||||
|
||||
impl service::sending::Data for KeyValueDatabase {
|
||||
fn active_requests<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, OutgoingKind, SendingEventType)>> + 'a> {
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.iter()
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn active_requests_for<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
self.servercurrentevent_data
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))),
|
||||
)
|
||||
}
|
||||
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
|
||||
self.servercurrentevent_data.remove(&key)
|
||||
}
|
||||
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servercurrentevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
for (key, _) in self.servernameevent_data.scan_prefix(prefix.clone()) {
|
||||
self.servernameevent_data.remove(&key).unwrap();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn queue_requests(
|
||||
&self,
|
||||
requests: &[(&OutgoingKind, SendingEventType)],
|
||||
) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
let mut key = outgoing_kind.get_prefix();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(if let SendingEventType::Pdu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
});
|
||||
let value = if let SendingEventType::Edu(value) = &event {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data
|
||||
.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn queued_requests<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
self.servernameevent_data
|
||||
.scan_prefix(prefix.clone())
|
||||
.map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))),
|
||||
);
|
||||
}
|
||||
|
||||
fn mark_as_active(&self, events: &[(SendingEventType, Vec<u8>)]) -> Result<()> {
|
||||
for (e, key) in events {
|
||||
let value = if let SendingEventType::Edu(value) = &e {
|
||||
&**value
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
self.servercurrentevent_data.insert(key, value)?;
|
||||
self.servernameevent_data.remove(key)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount
|
||||
.get(server_name.as_bytes())?
|
||||
.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(key))]
|
||||
fn parse_servercurrentevent(
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
||||
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(&user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id = UserId::parse(user_string)
|
||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
|
||||
let pushkey = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
// I'm pretty sure this should never be called
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
|
||||
(
|
||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
SendingEventType::Edu(value)
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
|
@ -1,16 +1,9 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{DeviceId, TransactionId, UserId};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{database::KeyValueDatabase, service, Result};
|
||||
|
||||
pub struct TransactionIds {
|
||||
pub(super) userdevicetxnid_response: Arc<dyn Tree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
}
|
||||
|
||||
impl TransactionIds {
|
||||
pub fn add_txnid(
|
||||
impl service::transaction_ids::Data for KeyValueDatabase {
|
||||
fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
|
@ -28,7 +21,7 @@ impl TransactionIds {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn existing_txnid(
|
||||
fn existing_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
90
src/database/key_value/uiaa.rs
Normal file
90
src/database/key_value/uiaa.rs
Normal file
|
@ -0,0 +1,90 @@
|
|||
use ruma::{
|
||||
api::client::{error::ErrorKind, uiaa::UiaaInfo},
|
||||
signatures::CanonicalJsonValue,
|
||||
DeviceId, UserId,
|
||||
};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, Error, Result};
|
||||
|
||||
impl service::uiaa::Data for KeyValueDatabase {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned()))
|
||||
.map(|j| j.to_owned())
|
||||
}
|
||||
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
self.userdevicesessionid_uiaainfo.insert(
|
||||
&userdevicesessionid,
|
||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
.remove(&userdevicesessionid)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"UIAA session does not exist.",
|
||||
))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
}
|
|
@ -1,50 +1,28 @@
|
|||
use crate::{utils, Error, Result};
|
||||
use std::{collections::BTreeMap, mem::size_of};
|
||||
|
||||
use ruma::{
|
||||
api::client::{device::Device, error::ErrorKind, filter::IncomingFilterDefinition},
|
||||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
||||
events::{AnyToDeviceEvent, StateEventType},
|
||||
serde::Raw,
|
||||
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, RoomAliasId,
|
||||
UInt, UserId,
|
||||
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, MxcUri, UInt, UserId,
|
||||
};
|
||||
use std::{collections::BTreeMap, mem, sync::Arc};
|
||||
use tracing::warn;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{
|
||||
database::KeyValueDatabase,
|
||||
service::{self, users::clean_signatures},
|
||||
services, utils, Error, Result,
|
||||
};
|
||||
|
||||
pub struct Users {
|
||||
pub(super) userid_password: Arc<dyn Tree>,
|
||||
pub(super) userid_displayname: Arc<dyn Tree>,
|
||||
pub(super) userid_avatarurl: Arc<dyn Tree>,
|
||||
pub(super) userid_blurhash: Arc<dyn Tree>,
|
||||
pub(super) userdeviceid_token: Arc<dyn Tree>,
|
||||
pub(super) userdeviceid_metadata: Arc<dyn Tree>, // This is also used to check if a device exists
|
||||
pub(super) userid_devicelistversion: Arc<dyn Tree>, // DevicelistVersion = u64
|
||||
pub(super) token_userdeviceid: Arc<dyn Tree>,
|
||||
|
||||
pub(super) onetimekeyid_onetimekeys: Arc<dyn Tree>, // OneTimeKeyId = UserId + DeviceKeyId
|
||||
pub(super) userid_lastonetimekeyupdate: Arc<dyn Tree>, // LastOneTimeKeyUpdate = Count
|
||||
pub(super) keychangeid_userid: Arc<dyn Tree>, // KeyChangeId = UserId/RoomId + Count
|
||||
pub(super) keyid_key: Arc<dyn Tree>, // KeyId = UserId + KeyId (depends on key type)
|
||||
pub(super) userid_masterkeyid: Arc<dyn Tree>,
|
||||
pub(super) userid_selfsigningkeyid: Arc<dyn Tree>,
|
||||
pub(super) userid_usersigningkeyid: Arc<dyn Tree>,
|
||||
|
||||
pub(super) userfilterid_filter: Arc<dyn Tree>, // UserFilterId = UserId + FilterId
|
||||
|
||||
pub(super) todeviceid_events: Arc<dyn Tree>, // ToDeviceId = UserId + DeviceId + Count
|
||||
}
|
||||
|
||||
impl Users {
|
||||
impl service::users::Data for KeyValueDatabase {
|
||||
/// Check if a user has an account on this homeserver.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn exists(&self, user_id: &UserId) -> Result<bool> {
|
||||
fn exists(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
|
||||
}
|
||||
|
||||
/// Check if account is deactivated
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self
|
||||
.userid_password
|
||||
.get(user_id.as_bytes())?
|
||||
|
@ -55,37 +33,13 @@ impl Users {
|
|||
.is_empty())
|
||||
}
|
||||
|
||||
/// Check if a user is an admin
|
||||
#[tracing::instrument(skip(self, user_id, rooms, globals))]
|
||||
pub fn is_admin(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
rooms: &super::rooms::Rooms,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<bool> {
|
||||
let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name()))
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
|
||||
let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap();
|
||||
|
||||
rooms.is_joined(user_id, &admin_room_id)
|
||||
}
|
||||
|
||||
/// Create a new user account on this homeserver.
|
||||
#[tracing::instrument(skip(self, user_id, password))]
|
||||
pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
self.set_password(user_id, password)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the number of users registered on this server.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn count(&self) -> Result<usize> {
|
||||
fn count(&self) -> Result<usize> {
|
||||
Ok(self.userid_password.iter().count())
|
||||
}
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
#[tracing::instrument(skip(self, token))]
|
||||
pub fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> {
|
||||
fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> {
|
||||
self.token_userdeviceid
|
||||
.get(token.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
|
@ -112,55 +66,29 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all users on this homeserver.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ {
|
||||
self.userid_password.iter().map(|(bytes, _)| {
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
|
||||
Box::new(self.userid_password.iter().map(|(bytes, _)| {
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in userid_password is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns a list of local users as list of usernames.
|
||||
///
|
||||
/// A user account is considered `local` if the length of it's password is greater then zero.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn list_local_users(&self) -> Result<Vec<String>> {
|
||||
fn list_local_users(&self) -> Result<Vec<String>> {
|
||||
let users: Vec<String> = self
|
||||
.userid_password
|
||||
.iter()
|
||||
.filter_map(|(username, pw)| self.get_username_with_valid_password(&username, &pw))
|
||||
.filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw))
|
||||
.collect();
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
/// Will only return with Some(username) if the password was not empty and the
|
||||
/// username could be successfully parsed.
|
||||
/// If utils::string_from_bytes(...) returns an error that username will be skipped
|
||||
/// and the error will be logged.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn get_username_with_valid_password(&self, username: &[u8], password: &[u8]) -> Option<String> {
|
||||
// A valid password is not empty
|
||||
if password.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match utils::string_from_bytes(username) {
|
||||
Ok(u) => Some(u),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to parse username while calling get_local_users(): {}",
|
||||
e.to_string()
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the password hash for the given user.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_password
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
|
@ -171,10 +99,9 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Hash and set the user's password to the Argon2 hash
|
||||
#[tracing::instrument(skip(self, user_id, password))]
|
||||
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
if let Some(password) = password {
|
||||
if let Ok(hash) = utils::calculate_hash(password) {
|
||||
if let Ok(hash) = utils::calculate_password_hash(password) {
|
||||
self.userid_password
|
||||
.insert(user_id.as_bytes(), hash.as_bytes())?;
|
||||
Ok(())
|
||||
|
@ -191,8 +118,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns the displayname of a user on this homeserver.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_displayname
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
|
@ -203,8 +129,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
|
||||
#[tracing::instrument(skip(self, user_id, displayname))]
|
||||
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
||||
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
||||
if let Some(displayname) = displayname {
|
||||
self.userid_displayname
|
||||
.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
||||
|
@ -216,8 +141,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Get the avatar_url of a user.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> {
|
||||
fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> {
|
||||
self.userid_avatarurl
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
|
@ -230,8 +154,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||
#[tracing::instrument(skip(self, user_id, avatar_url))]
|
||||
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> {
|
||||
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> {
|
||||
if let Some(avatar_url) = avatar_url {
|
||||
self.userid_avatarurl
|
||||
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
||||
|
@ -243,8 +166,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Get the blurhash of a user.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_blurhash
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
|
@ -257,8 +179,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||
#[tracing::instrument(skip(self, user_id, blurhash))]
|
||||
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
|
||||
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
|
||||
if let Some(blurhash) = blurhash {
|
||||
self.userid_blurhash
|
||||
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
|
||||
|
@ -270,8 +191,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Adds a new device to a user.
|
||||
#[tracing::instrument(skip(self, user_id, device_id, token, initial_device_display_name))]
|
||||
pub fn create_device(
|
||||
fn create_device(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -305,8 +225,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Removes a device from a user.
|
||||
#[tracing::instrument(skip(self, user_id, device_id))]
|
||||
pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
@ -336,31 +255,32 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn all_device_ids<'a>(
|
||||
fn all_device_ids<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a {
|
||||
) -> Box<dyn Iterator<Item = Result<Box<DeviceId>>> + 'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
// All devices have metadata
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(prefix)
|
||||
.map(|(bytes, _)| {
|
||||
Ok(utils::string_from_bytes(
|
||||
bytes
|
||||
.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
|
||||
.into())
|
||||
})
|
||||
Box::new(
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(prefix)
|
||||
.map(|(bytes, _)| {
|
||||
Ok(utils::string_from_bytes(
|
||||
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("UserDevice ID in db is invalid.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
|
||||
})?
|
||||
.into())
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
#[tracing::instrument(skip(self, user_id, device_id, token))]
|
||||
pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
@ -383,21 +303,12 @@ impl Users {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(
|
||||
self,
|
||||
user_id,
|
||||
device_id,
|
||||
one_time_key_key,
|
||||
one_time_key_value,
|
||||
globals
|
||||
))]
|
||||
pub fn add_one_time_key(
|
||||
fn add_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
one_time_key_key: &DeviceKeyId,
|
||||
one_time_key_value: &Raw<OneTimeKey>,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -421,14 +332,15 @@ impl Users {
|
|||
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_lastonetimekeyupdate
|
||||
.insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
self.userid_lastonetimekeyupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
|
||||
fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
|
||||
self.userid_lastonetimekeyupdate
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
|
@ -439,13 +351,11 @@ impl Users {
|
|||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id, key_algorithm, globals))]
|
||||
pub fn take_one_time_key(
|
||||
fn take_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
key_algorithm: &DeviceKeyAlgorithm,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<Option<(Box<DeviceKeyId>, Raw<OneTimeKey>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
@ -455,8 +365,10 @@ impl Users {
|
|||
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
||||
prefix.push(b':');
|
||||
|
||||
self.userid_lastonetimekeyupdate
|
||||
.insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
self.userid_lastonetimekeyupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
|
||||
self.onetimekeyid_onetimekeys
|
||||
.scan_prefix(prefix)
|
||||
|
@ -479,8 +391,7 @@ impl Users {
|
|||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id))]
|
||||
pub fn count_one_time_keys(
|
||||
fn count_one_time_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -512,14 +423,11 @@ impl Users {
|
|||
Ok(counts)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id, device_keys, rooms, globals))]
|
||||
pub fn add_device_keys(
|
||||
fn add_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device_keys: &Raw<DeviceKeys>,
|
||||
rooms: &super::rooms::Rooms,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
|
@ -530,27 +438,17 @@ impl Users {
|
|||
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.mark_device_key_update(user_id, rooms, globals)?;
|
||||
self.mark_device_key_update(user_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(
|
||||
self,
|
||||
master_key,
|
||||
self_signing_key,
|
||||
user_signing_key,
|
||||
rooms,
|
||||
globals
|
||||
))]
|
||||
pub fn add_cross_signing_keys(
|
||||
fn add_cross_signing_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
rooms: &super::rooms::Rooms,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
// TODO: Check signatures
|
||||
|
||||
|
@ -653,20 +551,17 @@ impl Users {
|
|||
.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
||||
}
|
||||
|
||||
self.mark_device_key_update(user_id, rooms, globals)?;
|
||||
self.mark_device_key_update(user_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, target_id, key_id, signature, sender_id, rooms, globals))]
|
||||
pub fn sign_key(
|
||||
fn sign_key(
|
||||
&self,
|
||||
target_id: &UserId,
|
||||
key_id: &str,
|
||||
signature: (String, String),
|
||||
sender_id: &UserId,
|
||||
rooms: &super::rooms::Rooms,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = target_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
@ -698,18 +593,17 @@ impl Users {
|
|||
)?;
|
||||
|
||||
// TODO: Should we notify about this change?
|
||||
self.mark_device_key_update(target_id, rooms, globals)?;
|
||||
self.mark_device_key_update(target_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_or_room_id, from, to))]
|
||||
pub fn keys_changed<'a>(
|
||||
fn keys_changed<'a>(
|
||||
&'a self,
|
||||
user_or_room_id: &str,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
) -> impl Iterator<Item = Result<Box<UserId>>> + 'a {
|
||||
) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
|
||||
let mut prefix = user_or_room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
|
@ -718,41 +612,48 @@ impl Users {
|
|||
|
||||
let to = to.unwrap_or(u64::MAX);
|
||||
|
||||
self.keychangeid_userid
|
||||
.iter_from(&start, false)
|
||||
.take_while(move |(k, _)| {
|
||||
k.starts_with(&prefix)
|
||||
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) {
|
||||
if let Ok(c) = utils::u64_from_bytes(current) {
|
||||
c <= to
|
||||
Box::new(
|
||||
self.keychangeid_userid
|
||||
.iter_from(&start, false)
|
||||
.take_while(move |(k, _)| {
|
||||
k.starts_with(&prefix)
|
||||
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) {
|
||||
if let Ok(c) = utils::u64_from_bytes(current) {
|
||||
c <= to
|
||||
} else {
|
||||
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
|
||||
false
|
||||
}
|
||||
} else {
|
||||
warn!("BadDatabase: Could not parse keychangeid_userid bytes");
|
||||
warn!("BadDatabase: Could not parse keychangeid_userid");
|
||||
false
|
||||
}
|
||||
} else {
|
||||
warn!("BadDatabase: Could not parse keychangeid_userid");
|
||||
false
|
||||
}
|
||||
})
|
||||
.map(|(_, bytes)| {
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
|
||||
})
|
||||
})
|
||||
.map(|(_, bytes)| {
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"User ID in devicekeychangeid_userid is invalid unicode.",
|
||||
)
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
|
||||
})
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, rooms, globals))]
|
||||
pub fn mark_device_key_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
rooms: &super::rooms::Rooms,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) {
|
||||
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(|r| r.ok())
|
||||
{
|
||||
// Don't send key updates to unencrypted rooms
|
||||
if rooms
|
||||
if services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
|
||||
.is_none()
|
||||
{
|
||||
|
@ -774,8 +675,7 @@ impl Users {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id))]
|
||||
pub fn get_device_keys(
|
||||
fn get_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -791,11 +691,10 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, allowed_signatures))]
|
||||
pub fn get_master_key<F: Fn(&UserId) -> bool>(
|
||||
fn get_master_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: F,
|
||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_masterkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
|
@ -813,11 +712,10 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, allowed_signatures))]
|
||||
pub fn get_self_signing_key<F: Fn(&UserId) -> bool>(
|
||||
fn get_self_signing_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: F,
|
||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_selfsigningkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
|
@ -835,8 +733,7 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_usersigningkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
|
@ -848,29 +745,19 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(
|
||||
self,
|
||||
sender,
|
||||
target_user_id,
|
||||
target_device_id,
|
||||
event_type,
|
||||
content,
|
||||
globals
|
||||
))]
|
||||
pub fn add_to_device_event(
|
||||
fn add_to_device_event(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
target_user_id: &UserId,
|
||||
target_device_id: &DeviceId,
|
||||
event_type: &str,
|
||||
content: serde_json::Value,
|
||||
globals: &super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = target_user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(target_device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&globals.next_count()?.to_be_bytes());
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
|
||||
let mut json = serde_json::Map::new();
|
||||
json.insert("type".to_owned(), event_type.to_owned().into());
|
||||
|
@ -884,8 +771,7 @@ impl Users {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id))]
|
||||
pub fn get_to_device_events(
|
||||
fn get_to_device_events(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -907,8 +793,7 @@ impl Users {
|
|||
Ok(events)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id, until))]
|
||||
pub fn remove_to_device_events(
|
||||
fn remove_to_device_events(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -929,7 +814,7 @@ impl Users {
|
|||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()])
|
||||
utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
|
||||
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
|
||||
))
|
||||
})
|
||||
|
@ -942,8 +827,7 @@ impl Users {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id, device_id, device))]
|
||||
pub fn update_device_metadata(
|
||||
fn update_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -968,8 +852,7 @@ impl Users {
|
|||
}
|
||||
|
||||
/// Get device metadata.
|
||||
#[tracing::instrument(skip(self, user_id, device_id))]
|
||||
pub fn get_device_metadata(
|
||||
fn get_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
|
@ -987,8 +870,7 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
self.userid_devicelistversion
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
|
@ -998,46 +880,26 @@ impl Users {
|
|||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn all_devices_metadata<'a>(
|
||||
fn all_devices_metadata<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> impl Iterator<Item = Result<Device>> + 'a {
|
||||
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(key)
|
||||
.map(|(_, bytes)| {
|
||||
serde_json::from_slice::<Device>(&bytes)
|
||||
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
/// Deactivate account
|
||||
#[tracing::instrument(skip(self, user_id))]
|
||||
pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> {
|
||||
// Remove all associated devices
|
||||
for device_id in self.all_device_ids(user_id) {
|
||||
self.remove_device(user_id, &device_id?)?;
|
||||
}
|
||||
|
||||
// Set the password to "" to indicate a deactivated account. Hashes will never result in an
|
||||
// empty string, so the user will not be able to log in again. Systems like changing the
|
||||
// password without logging in should check if the account is deactivated.
|
||||
self.userid_password.insert(user_id.as_bytes(), &[])?;
|
||||
|
||||
// TODO: Unhook 3PID
|
||||
Ok(())
|
||||
Box::new(
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(key)
|
||||
.map(|(_, bytes)| {
|
||||
serde_json::from_slice::<Device>(&bytes).map_err(|_| {
|
||||
Error::bad_database("Device in userdeviceid_metadata is invalid.")
|
||||
})
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Creates a new sync filter. Returns the filter id.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn create_filter(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
filter: &IncomingFilterDefinition,
|
||||
) -> Result<String> {
|
||||
fn create_filter(&self, user_id: &UserId, filter: &IncomingFilterDefinition) -> Result<String> {
|
||||
let filter_id = utils::random_string(4);
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
|
@ -1052,8 +914,7 @@ impl Users {
|
|||
Ok(filter_id)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn get_filter(
|
||||
fn get_filter(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
filter_id: &str,
|
||||
|
@ -1073,29 +934,24 @@ impl Users {
|
|||
}
|
||||
}
|
||||
|
||||
/// Ensure that a user only sees signatures from themselves and the target user
|
||||
fn clean_signatures<F: Fn(&UserId) -> bool>(
|
||||
cross_signing_key: &mut serde_json::Value,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: F,
|
||||
) -> Result<(), Error> {
|
||||
if let Some(signatures) = cross_signing_key
|
||||
.get_mut("signatures")
|
||||
.and_then(|v| v.as_object_mut())
|
||||
{
|
||||
// Don't allocate for the full size of the current signatures, but require
|
||||
// at most one resize if nothing is dropped
|
||||
let new_capacity = signatures.len() / 2;
|
||||
for (user, signature) in
|
||||
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
|
||||
{
|
||||
let id = <&UserId>::try_from(user.as_str())
|
||||
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
|
||||
if id == user_id || allowed_signatures(id) {
|
||||
signatures.insert(user, signature);
|
||||
/// Will only return with Some(username) if the password was not empty and the
|
||||
/// username could be successfully parsed.
|
||||
/// If utils::string_from_bytes(...) returns an error that username will be skipped
|
||||
/// and the error will be logged.
|
||||
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
|
||||
// A valid password is not empty
|
||||
if password.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match utils::string_from_bytes(username) {
|
||||
Ok(u) => Some(u),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to parse username while calling get_local_users(): {}",
|
||||
e.to_string()
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -1,358 +0,0 @@
|
|||
use crate::database::globals::Globals;
|
||||
use image::{imageops::FilterType, GenericImageView};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
use crate::{utils, Error, Result};
|
||||
use std::{mem, sync::Arc};
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
};
|
||||
|
||||
pub struct FileMeta {
|
||||
pub content_disposition: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub file: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct Media {
|
||||
pub(super) mediaid_file: Arc<dyn Tree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
|
||||
}
|
||||
|
||||
impl Media {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
mxc: String,
|
||||
globals: &Globals,
|
||||
content_disposition: &Option<&str>,
|
||||
content_type: &Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
key.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn upload_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
globals: &Globals,
|
||||
content_disposition: &Option<String>,
|
||||
content_type: &Option<String>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = mxc.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(&width.to_be_bytes());
|
||||
key.extend_from_slice(&height.to_be_bytes());
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, globals: &Globals, mxc: &str) -> Result<Option<FileMeta>> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
prefix.push(0xff);
|
||||
|
||||
let first = self.mediaid_file.scan_prefix(prefix).next();
|
||||
if let Some((key, _)) = first {
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Content Disposition in mediaid_file is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
||||
/// the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
(0..=96, 0..=96) => Some((96, 96, true)),
|
||||
(0..=320, 0..=240) => Some((320, 240, false)),
|
||||
(0..=640, 0..=480) => Some((640, 480, false)),
|
||||
(0..=800, 0..=600) => Some((800, 600, false)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a file's thumbnail.
|
||||
///
|
||||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
||||
pub async fn get_thumbnail(
|
||||
&self,
|
||||
mxc: &str,
|
||||
globals: &Globals,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self
|
||||
.thumbnail_properties(width, height)
|
||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
|
||||
let mut main_prefix = mxc.as_bytes().to_vec();
|
||||
main_prefix.push(0xff);
|
||||
|
||||
let mut thumbnail_prefix = main_prefix.clone();
|
||||
thumbnail_prefix.extend_from_slice(&width.to_be_bytes());
|
||||
thumbnail_prefix.extend_from_slice(&height.to_be_bytes());
|
||||
thumbnail_prefix.push(0xff);
|
||||
|
||||
let mut original_prefix = main_prefix;
|
||||
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Width = 0 if it's not a thumbnail
|
||||
original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail
|
||||
original_prefix.push(0xff);
|
||||
|
||||
let first_thumbnailprefix = self.mediaid_file.scan_prefix(thumbnail_prefix).next();
|
||||
let first_originalprefix = self.mediaid_file.scan_prefix(original_prefix).next();
|
||||
if let Some((key, _)) = first_thumbnailprefix {
|
||||
// Using saved thumbnail
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in db is invalid.")
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
} else if let Some((key, _)) = first_originalprefix {
|
||||
// Generate a thumbnail
|
||||
let path = globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Content Disposition in mediaid_file is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
)
|
||||
};
|
||||
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let original_width = image.width();
|
||||
let original_height = image.height();
|
||||
if width > original_width || height > original_height {
|
||||
return Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}));
|
||||
}
|
||||
|
||||
let thumbnail = if crop {
|
||||
image.resize_to_fill(width, height, FilterType::CatmullRom)
|
||||
} else {
|
||||
let (exact_width, exact_height) = {
|
||||
// Copied from image::dynimage::resize_dimensions
|
||||
let ratio = u64::from(original_width) * u64::from(height);
|
||||
let nratio = u64::from(width) * u64::from(original_height);
|
||||
|
||||
let use_width = nratio <= ratio;
|
||||
let intermediate = if use_width {
|
||||
u64::from(original_height) * u64::from(width)
|
||||
/ u64::from(original_width)
|
||||
} else {
|
||||
u64::from(original_width) * u64::from(height)
|
||||
/ u64::from(original_height)
|
||||
};
|
||||
if use_width {
|
||||
if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(width, intermediate as u32)
|
||||
} else {
|
||||
(
|
||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
::std::u32::MAX,
|
||||
)
|
||||
}
|
||||
} else if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(intermediate as u32, height)
|
||||
} else {
|
||||
(
|
||||
::std::u32::MAX,
|
||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
image.thumbnail_exact(exact_width, exact_height)
|
||||
};
|
||||
|
||||
let mut thumbnail_bytes = Vec::new();
|
||||
thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?;
|
||||
|
||||
// Save thumbnail in database so we don't have to generate it again next time
|
||||
let mut thumbnail_key = key.to_vec();
|
||||
let width_index = thumbnail_key
|
||||
.iter()
|
||||
.position(|&b| b == 0xff)
|
||||
.ok_or_else(|| Error::bad_database("Media in db is invalid."))?
|
||||
+ 1;
|
||||
let mut widthheight = width.to_be_bytes().to_vec();
|
||||
widthheight.extend_from_slice(&height.to_be_bytes());
|
||||
|
||||
thumbnail_key.splice(
|
||||
width_index..width_index + 2 * mem::size_of::<u32>(),
|
||||
widthheight,
|
||||
);
|
||||
|
||||
let path = globals.get_media_file(&thumbnail_key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(&thumbnail_bytes).await?;
|
||||
|
||||
self.mediaid_file.insert(&thumbnail_key, &[])?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: thumbnail_bytes.to_vec(),
|
||||
}))
|
||||
} else {
|
||||
// Couldn't parse file to generate thumbnail, send original
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
937
src/database/mod.rs
Normal file
937
src/database/mod.rs
Normal file
|
@ -0,0 +1,937 @@
|
|||
pub mod abstraction;
|
||||
pub mod key_value;
|
||||
|
||||
use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES};
|
||||
use abstraction::KeyValueDatabaseEngine;
|
||||
use abstraction::KvTree;
|
||||
use directories::ProjectDirs;
|
||||
use lru_cache::LruCache;
|
||||
use ruma::{
|
||||
events::{
|
||||
push_rules::PushRulesEventContent, room::message::RoomMessageEventContent,
|
||||
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
||||
},
|
||||
push::Ruleset,
|
||||
signatures::CanonicalJsonValue,
|
||||
DeviceId, EventId, RoomId, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fs::{self, remove_dir_all},
|
||||
io::Write,
|
||||
mem::size_of,
|
||||
path::Path,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
};
|
||||
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
pub struct KeyValueDatabase {
|
||||
_db: Arc<dyn KeyValueDatabaseEngine>,
|
||||
|
||||
//pub globals: globals::Globals,
|
||||
pub(super) global: Arc<dyn KvTree>,
|
||||
pub(super) server_signingkeys: Arc<dyn KvTree>,
|
||||
|
||||
//pub users: users::Users,
|
||||
pub(super) userid_password: Arc<dyn KvTree>,
|
||||
pub(super) userid_displayname: Arc<dyn KvTree>,
|
||||
pub(super) userid_avatarurl: Arc<dyn KvTree>,
|
||||
pub(super) userid_blurhash: Arc<dyn KvTree>,
|
||||
pub(super) userdeviceid_token: Arc<dyn KvTree>,
|
||||
pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
|
||||
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
|
||||
pub(super) token_userdeviceid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
|
||||
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
|
||||
pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
|
||||
pub(super) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
|
||||
pub(super) userid_masterkeyid: Arc<dyn KvTree>,
|
||||
pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>,
|
||||
pub(super) userid_usersigningkeyid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
|
||||
|
||||
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
|
||||
|
||||
//pub uiaa: uiaa::Uiaa,
|
||||
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
|
||||
pub(super) userdevicesessionid_uiaarequest:
|
||||
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
|
||||
|
||||
//pub edus: RoomEdus,
|
||||
pub(super) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
|
||||
pub(super) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
|
||||
pub(super) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
|
||||
pub(super) typingid_userid: Arc<dyn KvTree>, // TypingId = RoomId + TimeoutTime + Count
|
||||
pub(super) roomid_lasttypingupdate: Arc<dyn KvTree>, // LastRoomTypingUpdate = Count
|
||||
pub(super) presenceid_presence: Arc<dyn KvTree>, // PresenceId = RoomId + Count + UserId
|
||||
pub(super) userid_lastpresenceupdate: Arc<dyn KvTree>, // LastPresenceUpdate = Count
|
||||
|
||||
//pub rooms: rooms::Rooms,
|
||||
pub(super) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
|
||||
pub(super) eventid_pduid: Arc<dyn KvTree>,
|
||||
pub(super) roomid_pduleaves: Arc<dyn KvTree>,
|
||||
pub(super) alias_roomid: Arc<dyn KvTree>,
|
||||
pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
|
||||
pub(super) publicroomids: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
|
||||
|
||||
/// Participating servers in a room.
|
||||
pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
|
||||
pub(super) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
|
||||
|
||||
pub(super) userroomid_joined: Arc<dyn KvTree>,
|
||||
pub(super) roomuserid_joined: Arc<dyn KvTree>,
|
||||
pub(super) roomid_joinedcount: Arc<dyn KvTree>,
|
||||
pub(super) roomid_invitedcount: Arc<dyn KvTree>,
|
||||
pub(super) roomuseroncejoinedids: Arc<dyn KvTree>,
|
||||
pub(super) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
|
||||
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
|
||||
pub(super) userroomid_leftstate: Arc<dyn KvTree>,
|
||||
pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
|
||||
|
||||
pub(super) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
|
||||
|
||||
pub(super) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
|
||||
pub(super) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
|
||||
|
||||
/// Remember the current state hash of a room.
|
||||
pub(super) roomid_shortstatehash: Arc<dyn KvTree>,
|
||||
pub(super) roomsynctoken_shortstatehash: Arc<dyn KvTree>,
|
||||
/// Remember the state hash at events in the past.
|
||||
pub(super) shorteventid_shortstatehash: Arc<dyn KvTree>,
|
||||
/// StateKey = EventType + StateKey, ShortStateKey = Count
|
||||
pub(super) statekey_shortstatekey: Arc<dyn KvTree>,
|
||||
pub(super) shortstatekey_statekey: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) roomid_shortroomid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) shorteventid_eventid: Arc<dyn KvTree>,
|
||||
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
|
||||
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
|
||||
|
||||
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
|
||||
|
||||
/// RoomId + EventId -> outlier PDU.
|
||||
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
|
||||
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
|
||||
pub(super) softfailedeventids: Arc<dyn KvTree>,
|
||||
|
||||
/// RoomId + EventId -> Parent PDU EventId.
|
||||
pub(super) referencedevents: Arc<dyn KvTree>,
|
||||
|
||||
//pub account_data: account_data::AccountData,
|
||||
pub(super) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
|
||||
pub(super) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
|
||||
|
||||
//pub media: media::Media,
|
||||
pub(super) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
|
||||
//pub key_backups: key_backups::KeyBackups,
|
||||
pub(super) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
|
||||
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
|
||||
//pub transaction_ids: transaction_ids::TransactionIds,
|
||||
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
//pub sending: sending::Sending,
|
||||
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
|
||||
pub(super) servernameevent_data: Arc<dyn KvTree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
pub(super) servercurrentevent_data: Arc<dyn KvTree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
|
||||
//pub appservice: appservice::Appservice,
|
||||
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
|
||||
|
||||
//pub pusher: pusher::PushData,
|
||||
pub(super) senderkey_pusher: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, serde_yaml::Value>>>,
|
||||
pub(super) pdu_cache: Mutex<LruCache<Box<EventId>, Arc<PduEvent>>>,
|
||||
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
|
||||
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
|
||||
pub(super) eventidshort_cache: Mutex<LruCache<Box<EventId>, u64>>,
|
||||
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>,
|
||||
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
|
||||
pub(super) our_real_users_cache: RwLock<HashMap<Box<RoomId>, Arc<HashSet<Box<UserId>>>>>,
|
||||
pub(super) appservice_in_room_cache: RwLock<HashMap<Box<RoomId>, HashMap<String, bool>>>,
|
||||
pub(super) lasttimelinecount_cache: Mutex<HashMap<Box<RoomId>, u64>>,
|
||||
}
|
||||
|
||||
impl KeyValueDatabase {
|
||||
/// Tries to remove the old database but ignores all errors.
|
||||
pub fn try_remove(server_name: &str) -> Result<()> {
|
||||
let mut path = ProjectDirs::from("xyz", "koesters", "conduit")
|
||||
.ok_or_else(|| Error::bad_config("The OS didn't return a valid home directory path."))?
|
||||
.data_dir()
|
||||
.to_path_buf();
|
||||
path.push(server_name);
|
||||
let _ = remove_dir_all(path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_db_setup(config: &Config) -> Result<()> {
|
||||
let path = Path::new(&config.database_path);
|
||||
|
||||
let sled_exists = path.join("db").exists();
|
||||
let sqlite_exists = path.join("conduit.db").exists();
|
||||
let rocksdb_exists = path.join("IDENTITY").exists();
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
if sled_exists {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if sqlite_exists {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if rocksdb_exists {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
if count > 1 {
|
||||
warn!("Multiple databases at database_path detected");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if sled_exists && config.database_backend != "sled" {
|
||||
return Err(Error::bad_config(
|
||||
"Found sled at database_path, but is not specified in config.",
|
||||
));
|
||||
}
|
||||
|
||||
if sqlite_exists && config.database_backend != "sqlite" {
|
||||
return Err(Error::bad_config(
|
||||
"Found sqlite at database_path, but is not specified in config.",
|
||||
));
|
||||
}
|
||||
|
||||
if rocksdb_exists && config.database_backend != "rocksdb" {
|
||||
return Err(Error::bad_config(
|
||||
"Found rocksdb at database_path, but is not specified in config.",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load an existing database or create a new one.
|
||||
pub async fn load_or_create(config: Config) -> Result<()> {
|
||||
Self::check_db_setup(&config)?;
|
||||
|
||||
if !Path::new(&config.database_path).exists() {
|
||||
std::fs::create_dir_all(&config.database_path)
|
||||
.map_err(|_| Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself."))?;
|
||||
}
|
||||
|
||||
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
|
||||
"sqlite" => {
|
||||
#[cfg(not(feature = "sqlite"))]
|
||||
return Err(Error::BadConfig("Database backend not found."));
|
||||
#[cfg(feature = "sqlite")]
|
||||
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
|
||||
}
|
||||
"rocksdb" => {
|
||||
#[cfg(not(feature = "rocksdb"))]
|
||||
return Err(Error::BadConfig("Database backend not found."));
|
||||
#[cfg(feature = "rocksdb")]
|
||||
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
|
||||
}
|
||||
"persy" => {
|
||||
#[cfg(not(feature = "persy"))]
|
||||
return Err(Error::BadConfig("Database backend not found."));
|
||||
#[cfg(feature = "persy")]
|
||||
Arc::new(Arc::<abstraction::persy::Engine>::open(&config)?)
|
||||
}
|
||||
_ => {
|
||||
return Err(Error::BadConfig("Database backend not found."));
|
||||
}
|
||||
};
|
||||
|
||||
if config.max_request_size < 1024 {
|
||||
eprintln!("ERROR: Max request size is less than 1KB. Please increase it.");
|
||||
}
|
||||
|
||||
let db_raw = Box::new(Self {
|
||||
_db: builder.clone(),
|
||||
userid_password: builder.open_tree("userid_password")?,
|
||||
userid_displayname: builder.open_tree("userid_displayname")?,
|
||||
userid_avatarurl: builder.open_tree("userid_avatarurl")?,
|
||||
userid_blurhash: builder.open_tree("userid_blurhash")?,
|
||||
userdeviceid_token: builder.open_tree("userdeviceid_token")?,
|
||||
userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?,
|
||||
userid_devicelistversion: builder.open_tree("userid_devicelistversion")?,
|
||||
token_userdeviceid: builder.open_tree("token_userdeviceid")?,
|
||||
onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?,
|
||||
userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?,
|
||||
keychangeid_userid: builder.open_tree("keychangeid_userid")?,
|
||||
keyid_key: builder.open_tree("keyid_key")?,
|
||||
userid_masterkeyid: builder.open_tree("userid_masterkeyid")?,
|
||||
userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?,
|
||||
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
|
||||
userfilterid_filter: builder.open_tree("userfilterid_filter")?,
|
||||
todeviceid_events: builder.open_tree("todeviceid_events")?,
|
||||
|
||||
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
|
||||
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
|
||||
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
|
||||
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
|
||||
roomuserid_lastprivatereadupdate: builder
|
||||
.open_tree("roomuserid_lastprivatereadupdate")?,
|
||||
typingid_userid: builder.open_tree("typingid_userid")?,
|
||||
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
|
||||
presenceid_presence: builder.open_tree("presenceid_presence")?,
|
||||
userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?,
|
||||
pduid_pdu: builder.open_tree("pduid_pdu")?,
|
||||
eventid_pduid: builder.open_tree("eventid_pduid")?,
|
||||
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
|
||||
|
||||
alias_roomid: builder.open_tree("alias_roomid")?,
|
||||
aliasid_alias: builder.open_tree("aliasid_alias")?,
|
||||
publicroomids: builder.open_tree("publicroomids")?,
|
||||
|
||||
tokenids: builder.open_tree("tokenids")?,
|
||||
|
||||
roomserverids: builder.open_tree("roomserverids")?,
|
||||
serverroomids: builder.open_tree("serverroomids")?,
|
||||
userroomid_joined: builder.open_tree("userroomid_joined")?,
|
||||
roomuserid_joined: builder.open_tree("roomuserid_joined")?,
|
||||
roomid_joinedcount: builder.open_tree("roomid_joinedcount")?,
|
||||
roomid_invitedcount: builder.open_tree("roomid_invitedcount")?,
|
||||
roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?,
|
||||
userroomid_invitestate: builder.open_tree("userroomid_invitestate")?,
|
||||
roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?,
|
||||
userroomid_leftstate: builder.open_tree("userroomid_leftstate")?,
|
||||
roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?,
|
||||
|
||||
disabledroomids: builder.open_tree("disabledroomids")?,
|
||||
|
||||
lazyloadedids: builder.open_tree("lazyloadedids")?,
|
||||
|
||||
userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?,
|
||||
userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?,
|
||||
|
||||
statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?,
|
||||
shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?,
|
||||
|
||||
shorteventid_authchain: builder.open_tree("shorteventid_authchain")?,
|
||||
|
||||
roomid_shortroomid: builder.open_tree("roomid_shortroomid")?,
|
||||
|
||||
shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?,
|
||||
eventid_shorteventid: builder.open_tree("eventid_shorteventid")?,
|
||||
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?,
|
||||
shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?,
|
||||
roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?,
|
||||
roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?,
|
||||
statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?,
|
||||
|
||||
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
|
||||
softfailedeventids: builder.open_tree("softfailedeventids")?,
|
||||
|
||||
referencedevents: builder.open_tree("referencedevents")?,
|
||||
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
|
||||
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
|
||||
mediaid_file: builder.open_tree("mediaid_file")?,
|
||||
backupid_algorithm: builder.open_tree("backupid_algorithm")?,
|
||||
backupid_etag: builder.open_tree("backupid_etag")?,
|
||||
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
|
||||
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
|
||||
servername_educount: builder.open_tree("servername_educount")?,
|
||||
servernameevent_data: builder.open_tree("servernameevent_data")?,
|
||||
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
|
||||
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
|
||||
senderkey_pusher: builder.open_tree("senderkey_pusher")?,
|
||||
global: builder.open_tree("global")?,
|
||||
server_signingkeys: builder.open_tree("server_signingkeys")?,
|
||||
|
||||
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
||||
pdu_cache: Mutex::new(LruCache::new(
|
||||
config
|
||||
.pdu_cache_capacity
|
||||
.try_into()
|
||||
.expect("pdu cache capacity fits into usize"),
|
||||
)),
|
||||
auth_chain_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
shorteventid_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
eventidshort_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
shortstatekey_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
statekeyshort_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
our_real_users_cache: RwLock::new(HashMap::new()),
|
||||
appservice_in_room_cache: RwLock::new(HashMap::new()),
|
||||
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
||||
});
|
||||
|
||||
let db = Box::leak(db_raw);
|
||||
|
||||
let services_raw = Box::new(Services::build(db, config)?);
|
||||
|
||||
// This is the first and only time we initialize the SERVICE static
|
||||
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
|
||||
|
||||
// Matrix resource ownership is based on the server name; changing it
|
||||
// requires recreating the database from scratch.
|
||||
if services().users.count()? > 0 {
|
||||
let conduit_user =
|
||||
UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||
.expect("@conduit:server_name is valid");
|
||||
|
||||
if !services().users.exists(&conduit_user)? {
|
||||
error!(
|
||||
"The {} server user does not exist, and the database is not new.",
|
||||
conduit_user
|
||||
);
|
||||
return Err(Error::bad_database(
|
||||
"Cannot reuse an existing database after changing the server name, please delete the old one first."
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// If the database has any data, perform data migrations before starting
|
||||
let latest_database_version = 11;
|
||||
|
||||
if services().users.count()? > 0 {
|
||||
// MIGRATIONS
|
||||
if services().globals.database_version()? < 1 {
|
||||
for (roomserverid, _) in db.roomserverids.iter() {
|
||||
let mut parts = roomserverid.split(|&b| b == 0xff);
|
||||
let room_id = parts.next().expect("split always returns one element");
|
||||
let servername = match parts.next() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
error!("Migration: Invalid roomserverid in db.");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let mut serverroomid = servername.to_vec();
|
||||
serverroomid.push(0xff);
|
||||
serverroomid.extend_from_slice(room_id);
|
||||
|
||||
db.serverroomids.insert(&serverroomid, &[])?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(1)?;
|
||||
|
||||
warn!("Migration: 0 -> 1 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 2 {
|
||||
// We accidentally inserted hashed versions of "" into the db instead of just ""
|
||||
for (userid, password) in db.userid_password.iter() {
|
||||
let password = utils::string_from_bytes(&password);
|
||||
|
||||
let empty_hashed_password = password.map_or(false, |password| {
|
||||
argon2::verify_encoded(&password, b"").unwrap_or(false)
|
||||
});
|
||||
|
||||
if empty_hashed_password {
|
||||
db.userid_password.insert(&userid, b"")?;
|
||||
}
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(2)?;
|
||||
|
||||
warn!("Migration: 1 -> 2 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 3 {
|
||||
// Move media to filesystem
|
||||
for (key, content) in db.mediaid_file.iter() {
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut file = fs::File::create(path)?;
|
||||
file.write_all(&content)?;
|
||||
db.mediaid_file.insert(&key, &[])?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(3)?;
|
||||
|
||||
warn!("Migration: 2 -> 3 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 4 {
|
||||
// Add federated users to services() as deactivated
|
||||
for our_user in services().users.iter() {
|
||||
let our_user = our_user?;
|
||||
if services().users.is_deactivated(&our_user)? {
|
||||
continue;
|
||||
}
|
||||
for room in services().rooms.state_cache.rooms_joined(&our_user) {
|
||||
for user in services().rooms.state_cache.room_members(&room?) {
|
||||
let user = user?;
|
||||
if user.server_name() != services().globals.server_name() {
|
||||
println!("Migration: Creating user {}", user);
|
||||
services().users.create(&user, None)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(4)?;
|
||||
|
||||
warn!("Migration: 3 -> 4 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 5 {
|
||||
// Upgrade user data store
|
||||
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
|
||||
let mut parts = roomuserdataid.split(|&b| b == 0xff);
|
||||
let room_id = parts.next().unwrap();
|
||||
let user_id = parts.next().unwrap();
|
||||
let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap();
|
||||
|
||||
let mut key = room_id.to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(event_type);
|
||||
|
||||
db.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(5)?;
|
||||
|
||||
warn!("Migration: 4 -> 5 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 6 {
|
||||
// Set room member count
|
||||
for (roomid, _) in db.roomid_shortstatehash.iter() {
|
||||
let string = utils::string_from_bytes(&roomid).unwrap();
|
||||
let room_id = <&RoomId>::try_from(string.as_str()).unwrap();
|
||||
services().rooms.state_cache.update_joined_count(room_id)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(6)?;
|
||||
|
||||
warn!("Migration: 5 -> 6 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 7 {
|
||||
// Upgrade state store
|
||||
let mut last_roomstates: HashMap<Box<RoomId>, u64> = HashMap::new();
|
||||
let mut current_sstatehash: Option<u64> = None;
|
||||
let mut current_room = None;
|
||||
let mut current_state = HashSet::new();
|
||||
let mut counter = 0;
|
||||
|
||||
let mut handle_state =
|
||||
|current_sstatehash: u64,
|
||||
current_room: &RoomId,
|
||||
current_state: HashSet<_>,
|
||||
last_roomstates: &mut HashMap<_, _>| {
|
||||
counter += 1;
|
||||
println!("counter: {}", counter);
|
||||
let last_roomsstatehash = last_roomstates.get(current_room);
|
||||
|
||||
let states_parents = last_roomsstatehash.map_or_else(
|
||||
|| Ok(Vec::new()),
|
||||
|&last_roomsstatehash| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(dbg!(last_roomsstatehash))
|
||||
},
|
||||
)?;
|
||||
|
||||
let (statediffnew, statediffremoved) =
|
||||
if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew = current_state
|
||||
.difference(&parent_stateinfo.1)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
let statediffremoved = parent_stateinfo
|
||||
.1
|
||||
.difference(¤t_state)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
|
||||
(statediffnew, statediffremoved)
|
||||
} else {
|
||||
(current_state, HashSet::new())
|
||||
};
|
||||
|
||||
services().rooms.state_compressor.save_state_from_diff(
|
||||
dbg!(current_sstatehash),
|
||||
statediffnew,
|
||||
statediffremoved,
|
||||
2, // every state change is 2 event changes on average
|
||||
states_parents,
|
||||
)?;
|
||||
|
||||
/*
|
||||
let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?;
|
||||
let state = tmp.pop().unwrap();
|
||||
println!(
|
||||
"{}\t{}{:?}: {:?} + {:?} - {:?}",
|
||||
current_room,
|
||||
" ".repeat(tmp.len()),
|
||||
utils::u64_from_bytes(¤t_sstatehash).unwrap(),
|
||||
tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()),
|
||||
state
|
||||
.2
|
||||
.iter()
|
||||
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
|
||||
.collect::<Vec<_>>(),
|
||||
state
|
||||
.3
|
||||
.iter()
|
||||
.map(|b| utils::u64_from_bytes(&b[size_of::<u64>()..]).unwrap())
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
*/
|
||||
|
||||
Ok::<_, Error>(())
|
||||
};
|
||||
|
||||
for (k, seventid) in db._db.open_tree("stateid_shorteventid")?.iter() {
|
||||
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
|
||||
.expect("number of bytes is correct");
|
||||
let sstatekey = k[size_of::<u64>()..].to_vec();
|
||||
if Some(sstatehash) != current_sstatehash {
|
||||
if let Some(current_sstatehash) = current_sstatehash {
|
||||
handle_state(
|
||||
current_sstatehash,
|
||||
current_room.as_deref().unwrap(),
|
||||
current_state,
|
||||
&mut last_roomstates,
|
||||
)?;
|
||||
last_roomstates
|
||||
.insert(current_room.clone().unwrap(), current_sstatehash);
|
||||
}
|
||||
current_state = HashSet::new();
|
||||
current_sstatehash = Some(sstatehash);
|
||||
|
||||
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
|
||||
let string = utils::string_from_bytes(&event_id).unwrap();
|
||||
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
|
||||
let pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(event_id)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
|
||||
if Some(&pdu.room_id) != current_room.as_ref() {
|
||||
current_room = Some(pdu.room_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut val = sstatekey;
|
||||
val.extend_from_slice(&seventid);
|
||||
current_state.insert(val.try_into().expect("size is correct"));
|
||||
}
|
||||
|
||||
if let Some(current_sstatehash) = current_sstatehash {
|
||||
handle_state(
|
||||
current_sstatehash,
|
||||
current_room.as_deref().unwrap(),
|
||||
current_state,
|
||||
&mut last_roomstates,
|
||||
)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(7)?;
|
||||
|
||||
warn!("Migration: 6 -> 7 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 8 {
|
||||
// Generate short room ids for all rooms
|
||||
for (room_id, _) in db.roomid_shortstatehash.iter() {
|
||||
let shortroomid = services().globals.next_count()?.to_be_bytes();
|
||||
db.roomid_shortroomid.insert(&room_id, &shortroomid)?;
|
||||
info!("Migration: 8");
|
||||
}
|
||||
// Update pduids db layout
|
||||
let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| {
|
||||
if !key.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
let room_id = parts.next().unwrap();
|
||||
let count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
|
||||
let mut new_key = short_room_id;
|
||||
new_key.extend_from_slice(count);
|
||||
|
||||
Some((new_key, v))
|
||||
});
|
||||
|
||||
db.pduid_pdu.insert_batch(&mut batch)?;
|
||||
|
||||
let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| {
|
||||
if !value.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = value.splitn(2, |&b| b == 0xff);
|
||||
let room_id = parts.next().unwrap();
|
||||
let count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
|
||||
let mut new_value = short_room_id;
|
||||
new_value.extend_from_slice(count);
|
||||
|
||||
Some((k, new_value))
|
||||
});
|
||||
|
||||
db.eventid_pduid.insert_batch(&mut batch2)?;
|
||||
|
||||
services().globals.bump_database_version(8)?;
|
||||
|
||||
warn!("Migration: 7 -> 8 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 9 {
|
||||
// Update tokenids db layout
|
||||
let mut iter = db
|
||||
.tokenids
|
||||
.iter()
|
||||
.filter_map(|(key, _)| {
|
||||
if !key.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = key.splitn(4, |&b| b == 0xff);
|
||||
let room_id = parts.next().unwrap();
|
||||
let word = parts.next().unwrap();
|
||||
let _pdu_id_room = parts.next().unwrap();
|
||||
let pdu_id_count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
let mut new_key = short_room_id;
|
||||
new_key.extend_from_slice(word);
|
||||
new_key.push(0xff);
|
||||
new_key.extend_from_slice(pdu_id_count);
|
||||
println!("old {:?}", key);
|
||||
println!("new {:?}", new_key);
|
||||
Some((new_key, Vec::new()))
|
||||
})
|
||||
.peekable();
|
||||
|
||||
while iter.peek().is_some() {
|
||||
db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?;
|
||||
println!("smaller batch done");
|
||||
}
|
||||
|
||||
info!("Deleting starts");
|
||||
|
||||
let batch2: Vec<_> = db
|
||||
.tokenids
|
||||
.iter()
|
||||
.filter_map(|(key, _)| {
|
||||
if key.starts_with(b"!") {
|
||||
println!("del {:?}", key);
|
||||
Some(key)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
for key in batch2 {
|
||||
println!("del");
|
||||
db.tokenids.remove(&key)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(9)?;
|
||||
|
||||
warn!("Migration: 8 -> 9 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 10 {
|
||||
// Add other direction for shortstatekeys
|
||||
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
|
||||
db.shortstatekey_statekey
|
||||
.insert(&shortstatekey, &statekey)?;
|
||||
}
|
||||
|
||||
// Force E2EE device list updates so we can send them over federation
|
||||
for user_id in services().users.iter().filter_map(|r| r.ok()) {
|
||||
services().users.mark_device_key_update(&user_id)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(10)?;
|
||||
|
||||
warn!("Migration: 9 -> 10 finished");
|
||||
}
|
||||
|
||||
if services().globals.database_version()? < 11 {
|
||||
db._db
|
||||
.open_tree("userdevicesessionid_uiaarequest")?
|
||||
.clear()?;
|
||||
services().globals.bump_database_version(11)?;
|
||||
|
||||
warn!("Migration: 10 -> 11 finished");
|
||||
}
|
||||
|
||||
assert_eq!(11, latest_database_version);
|
||||
|
||||
info!(
|
||||
"Loaded {} database with version {}",
|
||||
services().globals.config.database_backend,
|
||||
latest_database_version
|
||||
);
|
||||
} else {
|
||||
services()
|
||||
.globals
|
||||
.bump_database_version(latest_database_version)?;
|
||||
|
||||
// Create the admin room and server user on first run
|
||||
services().admin.create_admin_room().await?;
|
||||
|
||||
warn!(
|
||||
"Created new {} database with version {}",
|
||||
services().globals.config.database_backend,
|
||||
latest_database_version
|
||||
);
|
||||
}
|
||||
|
||||
// This data is probably outdated
|
||||
db.presenceid_presence.clear()?;
|
||||
|
||||
services().admin.start_handler();
|
||||
|
||||
// Set emergency access for the conduit user
|
||||
match set_emergency_access() {
|
||||
Ok(pwd_set) => {
|
||||
if pwd_set {
|
||||
warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!");
|
||||
services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Could not set the configured emergency password for the conduit user: {}",
|
||||
e
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
services().sending.start_handler();
|
||||
|
||||
Self::start_cleanup_task().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn flush(&self) -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let res = self._db.flush();
|
||||
|
||||
debug!("flush: took {:?}", start.elapsed());
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[tracing::instrument]
|
||||
pub async fn start_cleanup_task() {
|
||||
use tokio::time::interval;
|
||||
|
||||
#[cfg(unix)]
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
use tracing::info;
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
let timer_interval =
|
||||
Duration::from_secs(services().globals.config.cleanup_second_interval as u64);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut i = interval(timer_interval);
|
||||
#[cfg(unix)]
|
||||
let mut s = signal(SignalKind::hangup()).unwrap();
|
||||
|
||||
loop {
|
||||
#[cfg(unix)]
|
||||
tokio::select! {
|
||||
_ = i.tick() => {
|
||||
info!("cleanup: Timer ticked");
|
||||
}
|
||||
_ = s.recv() => {
|
||||
info!("cleanup: Received SIGHUP");
|
||||
}
|
||||
};
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
i.tick().await;
|
||||
info!("cleanup: Timer ticked")
|
||||
}
|
||||
|
||||
let start = Instant::now();
|
||||
if let Err(e) = services().globals.cleanup() {
|
||||
error!("cleanup: Errored: {}", e);
|
||||
} else {
|
||||
info!("cleanup: Finished in {:?}", start.elapsed());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set
|
||||
fn set_emergency_access() -> Result<bool> {
|
||||
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||
.expect("@conduit:server_name is a valid UserId");
|
||||
|
||||
services().users.set_password(
|
||||
&conduit_user,
|
||||
services().globals.emergency_password().as_deref(),
|
||||
)?;
|
||||
|
||||
let (ruleset, res) = match services().globals.emergency_password() {
|
||||
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
|
||||
None => (Ruleset::new(), Ok(false)),
|
||||
};
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
&conduit_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(&GlobalAccountDataEvent {
|
||||
content: PushRulesEventContent { global: ruleset },
|
||||
})
|
||||
.expect("to json value always works"),
|
||||
)?;
|
||||
|
||||
res
|
||||
}
|
|
@ -1,348 +0,0 @@
|
|||
use crate::{Database, Error, PduEvent, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::push::{get_pushers, set_pusher, PusherKind},
|
||||
push_gateway::send_event_notification::{
|
||||
self,
|
||||
v1::{Device, Notification, NotificationCounts, NotificationPriority},
|
||||
},
|
||||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{
|
||||
room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent},
|
||||
AnySyncRoomEvent, RoomEventType, StateEventType,
|
||||
},
|
||||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use std::{fmt::Debug, mem, sync::Arc};
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub struct PushData {
|
||||
/// UserId + pushkey -> Pusher
|
||||
pub(super) senderkey_pusher: Arc<dyn Tree>,
|
||||
}
|
||||
|
||||
impl PushData {
|
||||
#[tracing::instrument(skip(self, sender, pusher))]
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(pusher.pushkey.as_bytes());
|
||||
|
||||
// There are 2 kinds of pushers but the spec says: null deletes the pusher.
|
||||
if pusher.kind.is_none() {
|
||||
return self
|
||||
.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into);
|
||||
}
|
||||
|
||||
self.senderkey_pusher.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, senderkey))]
|
||||
pub fn get_pusher(&self, senderkey: &[u8]) -> Result<Option<get_pushers::v3::Pusher>> {
|
||||
self.senderkey_pusher
|
||||
.get(senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sender))]
|
||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&*push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, sender))]
|
||||
pub fn get_pusher_senderkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> impl Iterator<Item = Vec<u8>> + 'a {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| k)
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(globals, destination, request))]
|
||||
pub async fn send_request<T: OutgoingRequest>(
|
||||
globals: &crate::database::globals::Globals,
|
||||
destination: &str,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
let destination = destination.replace("/_matrix/push/v1/notify", "");
|
||||
|
||||
let http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(""),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})?
|
||||
.map(|body| body.freeze());
|
||||
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)
|
||||
.expect("all http requests are valid reqwest requests");
|
||||
|
||||
// TODO: we could keep this very short and let expo backoff do it's thing...
|
||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let response = globals.default_client().execute(reqwest_request).await;
|
||||
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder()
|
||||
.status(status)
|
||||
.version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder
|
||||
.headers_mut()
|
||||
.expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if status != 200 {
|
||||
info!(
|
||||
"Push gateway returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
crate::utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|_| {
|
||||
info!(
|
||||
"Push gateway returned invalid response bytes {}\n{}",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Push gateway returned bad response.")
|
||||
})
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))]
|
||||
pub async fn send_push_notice(
|
||||
user: &UserId,
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
ruleset: Ruleset,
|
||||
pdu: &PduEvent,
|
||||
db: &Database,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
|
||||
let power_levels: RoomPowerLevelsEventContent = db
|
||||
.rooms
|
||||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.map(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
for action in get_actions(
|
||||
user,
|
||||
&ruleset,
|
||||
&power_levels,
|
||||
&pdu.to_sync_room_event(),
|
||||
&pdu.room_id,
|
||||
db,
|
||||
)? {
|
||||
let n = match action {
|
||||
Action::DontNotify => false,
|
||||
// TODO: Implement proper support for coalesce
|
||||
Action::Notify | Action::Coalesce => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if notify.is_some() {
|
||||
return Err(Error::bad_database(
|
||||
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
|
||||
));
|
||||
}
|
||||
|
||||
notify = Some(n);
|
||||
}
|
||||
|
||||
if notify == Some(true) {
|
||||
send_notice(unread, pusher, tweaks, pdu, db).await?;
|
||||
}
|
||||
// Else the event triggered no actions
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(user, ruleset, pdu, db))]
|
||||
pub fn get_actions<'a>(
|
||||
user: &UserId,
|
||||
ruleset: &'a Ruleset,
|
||||
power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncRoomEvent>,
|
||||
room_id: &RoomId,
|
||||
db: &Database,
|
||||
) -> Result<&'a [Action]> {
|
||||
let ctx = PushConditionRoomCtx {
|
||||
room_id: room_id.to_owned(),
|
||||
member_count: 10_u32.into(), // TODO: get member count efficiently
|
||||
user_display_name: db
|
||||
.users
|
||||
.displayname(user)?
|
||||
.unwrap_or_else(|| user.localpart().to_owned()),
|
||||
users_power_levels: power_levels.users.clone(),
|
||||
default_power_level: power_levels.users_default,
|
||||
notification_power_levels: power_levels.notifications.clone(),
|
||||
};
|
||||
|
||||
Ok(ruleset.get_actions(pdu, &ctx))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(unread, pusher, tweaks, event, db))]
|
||||
async fn send_notice(
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
tweaks: Vec<Tweak>,
|
||||
event: &PduEvent,
|
||||
db: &Database,
|
||||
) -> Result<()> {
|
||||
// TODO: email
|
||||
if pusher.kind == PusherKind::Email {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Two problems with this
|
||||
// 1. if "event_id_only" is the only format kind it seems we should never add more info
|
||||
// 2. can pusher/devices have conflicting formats
|
||||
let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly);
|
||||
let url = if let Some(url) = &pusher.data.url {
|
||||
url
|
||||
} else {
|
||||
error!("Http Pusher must have URL specified.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone());
|
||||
let mut data_minus_url = pusher.data.clone();
|
||||
// The url must be stripped off according to spec
|
||||
data_minus_url.url = None;
|
||||
device.data = data_minus_url;
|
||||
|
||||
// Tweaks are only added if the format is NOT event_id_only
|
||||
if !event_id_only {
|
||||
device.tweaks = tweaks.clone();
|
||||
}
|
||||
|
||||
let d = &[device];
|
||||
let mut notifi = Notification::new(d);
|
||||
|
||||
notifi.prio = NotificationPriority::Low;
|
||||
notifi.event_id = Some(&event.event_id);
|
||||
notifi.room_id = Some(&event.room_id);
|
||||
// TODO: missed calls
|
||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
|
||||
if event.kind == RoomEventType::RoomEncrypted
|
||||
|| tweaks
|
||||
.iter()
|
||||
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||
{
|
||||
notifi.prio = NotificationPriority::High
|
||||
}
|
||||
|
||||
if event_id_only {
|
||||
send_request(
|
||||
&db.globals,
|
||||
url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
notifi.sender = Some(&event.sender);
|
||||
notifi.event_type = Some(&event.kind);
|
||||
let content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
notifi.content = content.as_deref();
|
||||
|
||||
if event.kind == RoomEventType::RoomMember {
|
||||
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
|
||||
let user_name = db.users.displayname(&event.sender)?;
|
||||
notifi.sender_display_name = user_name.as_deref();
|
||||
|
||||
let room_name = if let Some(room_name_pdu) =
|
||||
db.rooms
|
||||
.room_state_get(&event.room_id, &StateEventType::RoomName, "")?
|
||||
{
|
||||
serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get())
|
||||
.map_err(|_| Error::bad_database("Invalid room name event in database."))?
|
||||
.name
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
notifi.room_name = room_name.as_deref();
|
||||
|
||||
send_request(
|
||||
&db.globals,
|
||||
url,
|
||||
send_event_notification::v1::Request::new(notifi),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// TODO: email
|
||||
|
||||
Ok(())
|
||||
}
|
File diff suppressed because it is too large
Load diff
|
@ -1,550 +0,0 @@
|
|||
use crate::{database::abstraction::Tree, utils, Error, Result};
|
||||
use ruma::{
|
||||
events::{
|
||||
presence::{PresenceEvent, PresenceEventContent},
|
||||
receipt::ReceiptEvent,
|
||||
SyncEphemeralRoomEvent,
|
||||
},
|
||||
presence::PresenceState,
|
||||
serde::Raw,
|
||||
signatures::CanonicalJsonObject,
|
||||
RoomId, UInt, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
mem,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
pub struct RoomEdus {
|
||||
pub(in super::super) readreceiptid_readreceipt: Arc<dyn Tree>, // ReadReceiptId = RoomId + Count + UserId
|
||||
pub(in super::super) roomuserid_privateread: Arc<dyn Tree>, // RoomUserId = Room + User, PrivateRead = Count
|
||||
pub(in super::super) roomuserid_lastprivatereadupdate: Arc<dyn Tree>, // LastPrivateReadUpdate = Count
|
||||
pub(in super::super) typingid_userid: Arc<dyn Tree>, // TypingId = RoomId + TimeoutTime + Count
|
||||
pub(in super::super) roomid_lasttypingupdate: Arc<dyn Tree>, // LastRoomTypingUpdate = Count
|
||||
pub(in super::super) presenceid_presence: Arc<dyn Tree>, // PresenceId = RoomId + Count + UserId
|
||||
pub(in super::super) userid_lastpresenceupdate: Arc<dyn Tree>, // LastPresenceUpdate = Count
|
||||
}
|
||||
|
||||
impl RoomEdus {
|
||||
/// Adds an event which will be saved until a new event replaces it (e.g. read receipt).
|
||||
pub fn readreceipt_update(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
event: ReceiptEvent,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
// Remove old entry
|
||||
if let Some((old, _)) = self
|
||||
.readreceiptid_readreceipt
|
||||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
&room_latest_id,
|
||||
&serde_json::to_vec(&event).expect("EduEvent::to_string always works"),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> impl Iterator<
|
||||
Item = Result<(
|
||||
Box<UserId>,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
|
||||
self.readreceiptid_readreceipt
|
||||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count =
|
||||
utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json.")
|
||||
})?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json).expect("json is valid raw value"),
|
||||
),
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets a private read marker at `count`.
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn private_read_set(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
count: u64,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the private read marker.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the count of the last typing update in this room.
|
||||
pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
/// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is
|
||||
/// called.
|
||||
pub fn typing_add(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
timeout: u64,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, &*user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a user from typing before the timeout is reached.
|
||||
pub fn typing_remove(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Makes sure that typing events with old timestamps get removed.
|
||||
fn typings_maintain(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Find all outdated edus before inserting a new one
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
||||
})?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
})
|
||||
.filter_map(|r| r.ok())
|
||||
.take_while(|&(_, timestamp)| timestamp < current_timestamp)
|
||||
{
|
||||
// This is an outdated edu (time > timestamp)
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the count of the last typing update in this room.
|
||||
#[tracing::instrument(skip(self, globals))]
|
||||
pub fn last_typing_update(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<u64> {
|
||||
self.typings_maintain(room_id, globals)?;
|
||||
|
||||
Ok(self
|
||||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
pub fn typings_all(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<SyncEphemeralRoomEvent<ruma::events::typing::TypingEventContent>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
}
|
||||
|
||||
Ok(SyncEphemeralRoomEvent {
|
||||
content: ruma::events::typing::TypingEventContent {
|
||||
user_ids: user_ids.into_iter().collect(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
/// Adds a presence event which will be saved until a new event replaces it.
|
||||
///
|
||||
/// Note: This method takes a RoomId because presence updates are always bound to rooms to
|
||||
/// make sure users outside these rooms can't see them.
|
||||
pub fn update_presence(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
presence: PresenceEvent,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
// TODO: Remove old entry? Or maybe just wipe completely from time to time?
|
||||
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count);
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(presence.sender.as_bytes());
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&presence).expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Resets the presence timeout, so the user will stay in their current presence state.
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn ping_presence(&self, user_id: &UserId) -> Result<()> {
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the timestamp of the last presence update of this user in millis since the unix epoch.
|
||||
pub fn last_presence_update(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
self.userid_lastpresenceupdate
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
pub fn get_last_presence_event(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<PresenceEvent>> {
|
||||
let last_update = match self.last_presence_update(user_id)? {
|
||||
Some(last) => last,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&last_update.to_be_bytes());
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.presenceid_presence
|
||||
.get(&presence_id)?
|
||||
.map(|value| {
|
||||
let mut presence: PresenceEvent = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
let current_timestamp: UInt = utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
.expect("time is valid");
|
||||
|
||||
if presence.content.presence == PresenceState::Online {
|
||||
// Don't set last_active_ago when the user is online
|
||||
presence.content.last_active_ago = None;
|
||||
} else {
|
||||
// Convert from timestamp to duration
|
||||
presence.content.last_active_ago = presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|timestamp| current_timestamp - timestamp);
|
||||
}
|
||||
|
||||
Ok(presence)
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Sets all users to offline who have been quiet for too long.
|
||||
fn _presence_maintain(
|
||||
&self,
|
||||
rooms: &super::Rooms,
|
||||
globals: &super::super::globals::Globals,
|
||||
) -> Result<()> {
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
for (user_id_bytes, last_timestamp) in self
|
||||
.userid_lastpresenceupdate
|
||||
.iter()
|
||||
.filter_map(|(k, bytes)| {
|
||||
Some((
|
||||
k,
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid timestamp in userid_lastpresenceupdate.")
|
||||
})
|
||||
.ok()?,
|
||||
))
|
||||
})
|
||||
.take_while(|(_, timestamp)| current_timestamp.saturating_sub(*timestamp) > 5 * 60_000)
|
||||
// 5 Minutes
|
||||
{
|
||||
// Send new presence events to set the user offline
|
||||
let count = globals.next_count()?.to_be_bytes();
|
||||
let user_id: Box<_> = utils::string_from_bytes(&user_id_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid UserId bytes in userid_lastpresenceupdate.")
|
||||
})?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in userid_lastpresenceupdate."))?;
|
||||
for room_id in rooms.rooms_joined(&user_id).filter_map(|r| r.ok()) {
|
||||
let mut presence_id = room_id.as_bytes().to_vec();
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&count);
|
||||
presence_id.push(0xff);
|
||||
presence_id.extend_from_slice(&user_id_bytes);
|
||||
|
||||
self.presenceid_presence.insert(
|
||||
&presence_id,
|
||||
&serde_json::to_vec(&PresenceEvent {
|
||||
content: PresenceEventContent {
|
||||
avatar_url: None,
|
||||
currently_active: None,
|
||||
displayname: None,
|
||||
last_active_ago: Some(
|
||||
last_timestamp.try_into().expect("time is valid"),
|
||||
),
|
||||
presence: PresenceState::Offline,
|
||||
status_msg: None,
|
||||
},
|
||||
sender: user_id.to_owned(),
|
||||
})
|
||||
.expect("PresenceEvent can be serialized"),
|
||||
)?;
|
||||
}
|
||||
|
||||
self.userid_lastpresenceupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&utils::millis_since_unix_epoch().to_be_bytes(),
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the most recent presence updates that happened after the event with id `since`.
|
||||
#[tracing::instrument(skip(self, since, _rooms, _globals))]
|
||||
pub fn presence_since(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
_rooms: &super::Rooms,
|
||||
_globals: &super::super::globals::Globals,
|
||||
) -> Result<HashMap<Box<UserId>, PresenceEvent>> {
|
||||
//self.presence_maintain(rooms, globals)?;
|
||||
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since
|
||||
let mut hashmap = HashMap::new();
|
||||
|
||||
for (key, value) in self
|
||||
.presenceid_presence
|
||||
.iter_from(&*first_possible_edu, false)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
{
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in presenceid_presence."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in presenceid_presence."))?;
|
||||
|
||||
let mut presence: PresenceEvent = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Invalid presence event in db."))?;
|
||||
|
||||
let current_timestamp: UInt = utils::millis_since_unix_epoch()
|
||||
.try_into()
|
||||
.expect("time is valid");
|
||||
|
||||
if presence.content.presence == PresenceState::Online {
|
||||
// Don't set last_active_ago when the user is online
|
||||
presence.content.last_active_ago = None;
|
||||
} else {
|
||||
// Convert from timestamp to duration
|
||||
presence.content.last_active_ago = presence
|
||||
.content
|
||||
.last_active_ago
|
||||
.map(|timestamp| current_timestamp - timestamp);
|
||||
}
|
||||
|
||||
hashmap.insert(user_id, presence);
|
||||
}
|
||||
|
||||
Ok(hashmap)
|
||||
}
|
||||
}
|
26
src/lib.rs
26
src/lib.rs
|
@ -7,19 +7,25 @@
|
|||
#![allow(clippy::suspicious_else_formatting)]
|
||||
#![deny(clippy::dbg_macro)]
|
||||
|
||||
pub mod api;
|
||||
mod config;
|
||||
mod database;
|
||||
mod error;
|
||||
mod pdu;
|
||||
mod ruma_wrapper;
|
||||
mod service;
|
||||
mod utils;
|
||||
|
||||
pub mod appservice_server;
|
||||
pub mod client_server;
|
||||
pub mod server_server;
|
||||
use std::sync::RwLock;
|
||||
|
||||
pub use api::ruma_wrapper::{Ruma, RumaResponse};
|
||||
pub use config::Config;
|
||||
pub use database::Database;
|
||||
pub use error::{Error, Result};
|
||||
pub use pdu::PduEvent;
|
||||
pub use ruma_wrapper::{Ruma, RumaResponse};
|
||||
pub use database::KeyValueDatabase;
|
||||
pub use service::{pdu::PduEvent, Services};
|
||||
pub use utils::error::{Error, Result};
|
||||
|
||||
pub static SERVICES: RwLock<Option<&'static Services>> = RwLock::new(None);
|
||||
|
||||
pub fn services<'a>() -> &'static Services {
|
||||
&SERVICES
|
||||
.read()
|
||||
.unwrap()
|
||||
.expect("SERVICES should be initialized when this is called")
|
||||
}
|
||||
|
|
46
src/main.rs
46
src/main.rs
|
@ -7,7 +7,7 @@
|
|||
#![allow(clippy::suspicious_else_formatting)]
|
||||
#![deny(clippy::dbg_macro)]
|
||||
|
||||
use std::{future::Future, io, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use std::{future::Future, io, net::SocketAddr, time::Duration};
|
||||
|
||||
use axum::{
|
||||
extract::{FromRequest, MatchedPath},
|
||||
|
@ -17,6 +17,7 @@ use axum::{
|
|||
Router,
|
||||
};
|
||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||
use conduit::api::{client_server, server_server};
|
||||
use figment::{
|
||||
providers::{Env, Format, Toml},
|
||||
Figment,
|
||||
|
@ -27,14 +28,14 @@ use http::{
|
|||
};
|
||||
use opentelemetry::trace::{FutureExt, Tracer};
|
||||
use ruma::api::{client::error::ErrorKind, IncomingRequest};
|
||||
use tokio::{signal, sync::RwLock};
|
||||
use tokio::signal;
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
cors::{self, CorsLayer},
|
||||
trace::TraceLayer,
|
||||
ServiceBuilderExt as _,
|
||||
};
|
||||
use tracing::warn;
|
||||
use tracing::{info, warn};
|
||||
use tracing_subscriber::{prelude::*, EnvFilter};
|
||||
|
||||
pub use conduit::*; // Re-export everything from the library crate
|
||||
|
@ -48,6 +49,7 @@ static GLOBAL: Jemalloc = Jemalloc;
|
|||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Initialize DB
|
||||
let raw_config =
|
||||
Figment::new()
|
||||
.merge(
|
||||
|
@ -66,21 +68,20 @@ async fn main() {
|
|||
}
|
||||
};
|
||||
|
||||
config.warn_deprecated();
|
||||
|
||||
if let Err(e) = KeyValueDatabase::load_or_create(config).await {
|
||||
eprintln!(
|
||||
"The database couldn't be loaded or created. The following error occured: {}",
|
||||
e
|
||||
);
|
||||
std::process::exit(1);
|
||||
};
|
||||
|
||||
let config = &services().globals.config;
|
||||
|
||||
let start = async {
|
||||
config.warn_deprecated();
|
||||
|
||||
let db = match Database::load_or_create(&config).await {
|
||||
Ok(db) => db,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"The database couldn't be loaded or created. The following error occured: {}",
|
||||
e
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
|
||||
run_server(&config, db).await.unwrap();
|
||||
run_server().await.unwrap();
|
||||
};
|
||||
|
||||
if config.allow_jaeger {
|
||||
|
@ -120,7 +121,8 @@ async fn main() {
|
|||
}
|
||||
}
|
||||
|
||||
async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> {
|
||||
async fn run_server() -> io::Result<()> {
|
||||
let config = &services().globals.config;
|
||||
let addr = SocketAddr::from((config.address, config.port));
|
||||
|
||||
let x_requested_with = HeaderName::from_static("x-requested-with");
|
||||
|
@ -157,8 +159,7 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()
|
|||
header::AUTHORIZATION,
|
||||
])
|
||||
.max_age(Duration::from_secs(86400)),
|
||||
)
|
||||
.add_extension(db.clone());
|
||||
);
|
||||
|
||||
let app = routes().layer(middlewares).into_make_service();
|
||||
let handle = ServerHandle::new();
|
||||
|
@ -175,8 +176,9 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()
|
|||
}
|
||||
}
|
||||
|
||||
// After serve exits and before exiting, shutdown the DB
|
||||
Database::on_shutdown(db).await;
|
||||
// On shutdown
|
||||
info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers...");
|
||||
services().globals.rotate.fire();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
3644
src/server_server.rs
3644
src/server_server.rs
File diff suppressed because it is too large
Load diff
35
src/service/account_data/data.rs
Normal file
35
src/service/account_data/data.rs
Normal file
|
@ -0,0 +1,35 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
||||
}
|
53
src/service/account_data/mod.rs
Normal file
53
src/service/account_data/mod.rs
Normal file
|
@ -0,0 +1,53 @@
|
|||
mod data;
|
||||
|
||||
pub use data::Data;
|
||||
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
pub fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
self.db.update(room_id, user_id, event_type, data)
|
||||
}
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
||||
pub fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
self.db.get(room_id, user_id, event_type)
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
self.db.changes_since(room_id, user_id, since)
|
||||
}
|
||||
}
|
1161
src/service/admin/mod.rs
Normal file
1161
src/service/admin/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
19
src/service/appservice/data.rs
Normal file
19
src/service/appservice/data.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String>;
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()>;
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>>;
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>>;
|
||||
|
||||
fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>>;
|
||||
}
|
37
src/service/appservice/mod.rs
Normal file
37
src/service/appservice/mod.rs
Normal file
|
@ -0,0 +1,37 @@
|
|||
mod data;
|
||||
|
||||
pub use data::Data;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
|
||||
self.db.register_appservice(yaml)
|
||||
}
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
pub fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>> {
|
||||
self.db.get_registration(id)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
||||
self.db.iter_ids()
|
||||
}
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>> {
|
||||
self.db.all()
|
||||
}
|
||||
}
|
34
src/service/globals/data.rs
Normal file
34
src/service/globals/data.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ruma::{
|
||||
api::federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
signatures::Ed25519KeyPair,
|
||||
DeviceId, ServerName, ServerSigningKeyId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
#[async_trait]
|
||||
pub trait Data: Send + Sync {
|
||||
fn next_count(&self) -> Result<u64>;
|
||||
fn current_count(&self) -> Result<u64>;
|
||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()>;
|
||||
fn memory_usage(&self) -> Result<String>;
|
||||
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
||||
fn remove_keypair(&self) -> Result<()>;
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>>;
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>>;
|
||||
fn database_version(&self) -> Result<u64>;
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
||||
}
|
|
@ -1,11 +1,15 @@
|
|||
use crate::{database::Config, server_server::FedDest, utils, Error, Result};
|
||||
mod data;
|
||||
pub use data::Data;
|
||||
|
||||
use crate::api::server_server::FedDest;
|
||||
|
||||
use crate::{Config, Error, Result};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::sync::sync_events,
|
||||
federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
},
|
||||
DeviceId, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName,
|
||||
ServerSigningKeyId, UserId,
|
||||
DeviceId, EventId, RoomId, RoomVersionId, ServerName, ServerSigningKeyId, UserId,
|
||||
};
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap},
|
||||
|
@ -20,10 +24,6 @@ use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
|
|||
use tracing::error;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
use super::abstraction::Tree;
|
||||
|
||||
pub const COUNTER: &[u8] = b"c";
|
||||
|
||||
type WellKnownMap = HashMap<Box<ServerName>, (FedDest, String)>;
|
||||
type TlsNameMap = HashMap<String, (Vec<IpAddr>, u16)>;
|
||||
type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries
|
||||
|
@ -32,10 +32,11 @@ type SyncHandle = (
|
|||
Receiver<Option<Result<sync_events::v3::Response>>>, // rx
|
||||
);
|
||||
|
||||
pub struct Globals {
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
|
||||
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
|
||||
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
|
||||
pub(super) globals: Arc<dyn Tree>,
|
||||
pub config: Config,
|
||||
keypair: Arc<ruma::signatures::Ed25519KeyPair>,
|
||||
dns_resolver: TokioAsyncResolver,
|
||||
|
@ -44,7 +45,6 @@ pub struct Globals {
|
|||
default_client: reqwest::Client,
|
||||
pub stable_room_versions: Vec<RoomVersionId>,
|
||||
pub unstable_room_versions: Vec<RoomVersionId>,
|
||||
pub(super) server_signingkeys: Arc<dyn Tree>,
|
||||
pub bad_event_ratelimiter: Arc<RwLock<HashMap<Box<EventId>, RateLimitState>>>,
|
||||
pub bad_signature_ratelimiter: Arc<RwLock<HashMap<Vec<String>, RateLimitState>>>,
|
||||
pub servername_ratelimiter: Arc<RwLock<HashMap<Box<ServerName>, Arc<Semaphore>>>>,
|
||||
|
@ -87,47 +87,15 @@ impl Default for RotationHandler {
|
|||
}
|
||||
}
|
||||
|
||||
impl Globals {
|
||||
pub fn load(
|
||||
globals: Arc<dyn Tree>,
|
||||
server_signingkeys: Arc<dyn Tree>,
|
||||
config: Config,
|
||||
) -> Result<Self> {
|
||||
let keypair_bytes = globals.get(b"keypair")?.map_or_else(
|
||||
|| {
|
||||
let keypair = utils::generate_keypair();
|
||||
globals.insert(b"keypair", &keypair)?;
|
||||
Ok::<_, Error>(keypair)
|
||||
},
|
||||
|s| Ok(s.to_vec()),
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
|
||||
let keypair = utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
// 2. key
|
||||
parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid keypair format in database."))
|
||||
.map(|key| (version, key))
|
||||
})
|
||||
.and_then(|(version, key)| {
|
||||
ruma::signatures::Ed25519KeyPair::from_der(key, version)
|
||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
});
|
||||
impl Service {
|
||||
pub fn load(db: &'static dyn Data, config: Config) -> Result<Self> {
|
||||
let keypair = db.load_keypair();
|
||||
|
||||
let keypair = match keypair {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
error!("Keypair invalid. Deleting...");
|
||||
globals.remove(b"keypair")?;
|
||||
db.remove_keypair()?;
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
@ -161,7 +129,7 @@ impl Globals {
|
|||
let unstable_room_versions = vec![RoomVersionId::V3, RoomVersionId::V4, RoomVersionId::V5];
|
||||
|
||||
let mut s = Self {
|
||||
globals,
|
||||
db,
|
||||
config,
|
||||
keypair: Arc::new(keypair),
|
||||
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
|
||||
|
@ -175,7 +143,6 @@ impl Globals {
|
|||
tls_name_override,
|
||||
federation_client,
|
||||
default_client,
|
||||
server_signingkeys,
|
||||
jwt_decoding_key,
|
||||
stable_room_versions,
|
||||
unstable_room_versions,
|
||||
|
@ -223,16 +190,24 @@ impl Globals {
|
|||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
utils::u64_from_bytes(&self.globals.increment(COUNTER)?)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
self.db.next_count()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn current_count(&self) -> Result<u64> {
|
||||
self.globals.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
self.db.current_count()
|
||||
}
|
||||
|
||||
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
self.db.watch(user_id, device_id).await
|
||||
}
|
||||
|
||||
pub fn cleanup(&self) -> Result<()> {
|
||||
self.db.cleanup()
|
||||
}
|
||||
|
||||
pub fn memory_usage(&self) -> Result<String> {
|
||||
self.db.memory_usage()
|
||||
}
|
||||
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
|
@ -321,38 +296,7 @@ impl Globals {
|
|||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
||||
let ServerSigningKeys {
|
||||
verify_keys,
|
||||
old_verify_keys,
|
||||
..
|
||||
} = new_keys;
|
||||
|
||||
keys.verify_keys.extend(verify_keys.into_iter());
|
||||
keys.old_verify_keys.extend(old_verify_keys.into_iter());
|
||||
|
||||
self.server_signingkeys.insert(
|
||||
origin.as_bytes(),
|
||||
&serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"),
|
||||
)?;
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
|
||||
Ok(tree)
|
||||
self.db.add_signing_key(origin, new_keys)
|
||||
}
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
|
@ -360,35 +304,15 @@ impl Globals {
|
|||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
||||
Ok(signingkeys)
|
||||
self.db.signing_keys_for(origin)
|
||||
}
|
||||
|
||||
pub fn database_version(&self) -> Result<u64> {
|
||||
self.globals.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
self.db.database_version()
|
||||
}
|
||||
|
||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.globals
|
||||
.insert(b"version", &new_version.to_be_bytes())?;
|
||||
Ok(())
|
||||
self.db.bump_database_version(new_version)
|
||||
}
|
||||
|
||||
pub fn get_media_folder(&self) -> PathBuf {
|
78
src/service/key_backups/data.rs
Normal file
78
src/service/key_backups/data.rs
Normal file
|
@ -0,0 +1,78 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
||||
|
||||
fn get_latest_backup(&self, user_id: &UserId)
|
||||
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()>;
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
||||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<Box<RoomId>, RoomKeyBackup>>;
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>>;
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()>;
|
||||
}
|
127
src/service/key_backups/mod.rs
Normal file
127
src/service/key_backups/mod.rs
Normal file
|
@ -0,0 +1,127 @@
|
|||
mod data;
|
||||
pub use data::Data;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.create_backup(user_id, backup_metadata)
|
||||
}
|
||||
|
||||
pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_backup(user_id, version)
|
||||
}
|
||||
|
||||
pub fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.update_backup(user_id, version, backup_metadata)
|
||||
}
|
||||
|
||||
pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.db.get_latest_backup_version(user_id)
|
||||
}
|
||||
|
||||
pub fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
self.db.get_latest_backup(user_id)
|
||||
}
|
||||
|
||||
pub fn get_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
self.db.get_backup(user_id, version)
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.add_key(user_id, version, room_id, session_id, key_data)
|
||||
}
|
||||
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
self.db.count_keys(user_id, version)
|
||||
}
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
self.db.get_etag(user_id, version)
|
||||
}
|
||||
|
||||
pub fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<Box<RoomId>, RoomKeyBackup>> {
|
||||
self.db.get_all(user_id, version)
|
||||
}
|
||||
|
||||
pub fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
self.db.get_room(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
self.db.get_session(user_id, version, room_id, session_id)
|
||||
}
|
||||
|
||||
pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
self.db.delete_all_keys(user_id, version)
|
||||
}
|
||||
|
||||
pub fn delete_room_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
self.db.delete_room_keys(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.delete_room_key(user_id, version, room_id, session_id)
|
||||
}
|
||||
}
|
20
src/service/media/data.rs
Normal file
20
src/service/media/data.rs
Normal file
|
@ -0,0 +1,20 @@
|
|||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>>;
|
||||
|
||||
/// Returns content_disposition, content_type and the metadata key.
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
||||
}
|
221
src/service/media/mod.rs
Normal file
221
src/service/media/mod.rs
Normal file
|
@ -0,0 +1,221 @@
|
|||
mod data;
|
||||
pub use data::Data;
|
||||
|
||||
use crate::{services, Result};
|
||||
use image::{imageops::FilterType, GenericImageView};
|
||||
|
||||
use tokio::{
|
||||
fs::File,
|
||||
io::{AsyncReadExt, AsyncWriteExt},
|
||||
};
|
||||
|
||||
pub struct FileMeta {
|
||||
pub content_disposition: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub file: Vec<u8>,
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
// Width, Height = 0 if it's not a thumbnail
|
||||
let key = self
|
||||
.db
|
||||
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn upload_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let key =
|
||||
self.db
|
||||
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(file).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc, 0, 0)
|
||||
{
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file,
|
||||
}))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
||||
/// the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
(0..=96, 0..=96) => Some((96, 96, true)),
|
||||
(0..=320, 0..=240) => Some((320, 240, false)),
|
||||
(0..=640, 0..=480) => Some((640, 480, false)),
|
||||
(0..=800, 0..=600) => Some((800, 600, false)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Downloads a file's thumbnail.
|
||||
///
|
||||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
||||
pub async fn get_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self
|
||||
.thumbnail_properties(width, height)
|
||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), width, height)
|
||||
{
|
||||
// Using saved thumbnail
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
} else if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), 0, 0)
|
||||
{
|
||||
// Generate a thumbnail
|
||||
let path = services().globals.get_media_file(&key);
|
||||
let mut file = Vec::new();
|
||||
File::open(path).await?.read_to_end(&mut file).await?;
|
||||
|
||||
if let Ok(image) = image::load_from_memory(&file) {
|
||||
let original_width = image.width();
|
||||
let original_height = image.height();
|
||||
if width > original_width || height > original_height {
|
||||
return Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}));
|
||||
}
|
||||
|
||||
let thumbnail = if crop {
|
||||
image.resize_to_fill(width, height, FilterType::CatmullRom)
|
||||
} else {
|
||||
let (exact_width, exact_height) = {
|
||||
// Copied from image::dynimage::resize_dimensions
|
||||
let ratio = u64::from(original_width) * u64::from(height);
|
||||
let nratio = u64::from(width) * u64::from(original_height);
|
||||
|
||||
let use_width = nratio <= ratio;
|
||||
let intermediate = if use_width {
|
||||
u64::from(original_height) * u64::from(width)
|
||||
/ u64::from(original_width)
|
||||
} else {
|
||||
u64::from(original_width) * u64::from(height)
|
||||
/ u64::from(original_height)
|
||||
};
|
||||
if use_width {
|
||||
if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(width, intermediate as u32)
|
||||
} else {
|
||||
(
|
||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
::std::u32::MAX,
|
||||
)
|
||||
}
|
||||
} else if intermediate <= u64::from(::std::u32::MAX) {
|
||||
(intermediate as u32, height)
|
||||
} else {
|
||||
(
|
||||
::std::u32::MAX,
|
||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
|
||||
as u32,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
image.thumbnail_exact(exact_width, exact_height)
|
||||
};
|
||||
|
||||
let mut thumbnail_bytes = Vec::new();
|
||||
thumbnail.write_to(&mut thumbnail_bytes, image::ImageOutputFormat::Png)?;
|
||||
|
||||
// Save thumbnail in database so we don't have to generate it again next time
|
||||
let thumbnail_key = self.db.create_file_metadata(
|
||||
mxc,
|
||||
width,
|
||||
height,
|
||||
content_disposition.as_deref(),
|
||||
content_type.as_deref(),
|
||||
)?;
|
||||
|
||||
let path = services().globals.get_media_file(&thumbnail_key);
|
||||
let mut f = File::create(path).await?;
|
||||
f.write_all(&thumbnail_bytes).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: thumbnail_bytes.to_vec(),
|
||||
}))
|
||||
} else {
|
||||
// Couldn't parse file to generate thumbnail, send original
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
content_type,
|
||||
file: file.to_vec(),
|
||||
}))
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
106
src/service/mod.rs
Normal file
106
src/service/mod.rs
Normal file
|
@ -0,0 +1,106 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
|
||||
use lru_cache::LruCache;
|
||||
|
||||
use crate::{Config, Result};
|
||||
|
||||
pub mod account_data;
|
||||
pub mod admin;
|
||||
pub mod appservice;
|
||||
pub mod globals;
|
||||
pub mod key_backups;
|
||||
pub mod media;
|
||||
pub mod pdu;
|
||||
pub mod pusher;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
pub mod transaction_ids;
|
||||
pub mod uiaa;
|
||||
pub mod users;
|
||||
|
||||
pub struct Services {
|
||||
pub appservice: appservice::Service,
|
||||
pub pusher: pusher::Service,
|
||||
pub rooms: rooms::Service,
|
||||
pub transaction_ids: transaction_ids::Service,
|
||||
pub uiaa: uiaa::Service,
|
||||
pub users: users::Service,
|
||||
pub account_data: account_data::Service,
|
||||
pub admin: Arc<admin::Service>,
|
||||
pub globals: globals::Service,
|
||||
pub key_backups: key_backups::Service,
|
||||
pub media: media::Service,
|
||||
pub sending: Arc<sending::Service>,
|
||||
}
|
||||
|
||||
impl Services {
|
||||
pub fn build<
|
||||
D: appservice::Data
|
||||
+ pusher::Data
|
||||
+ rooms::Data
|
||||
+ transaction_ids::Data
|
||||
+ uiaa::Data
|
||||
+ users::Data
|
||||
+ account_data::Data
|
||||
+ globals::Data
|
||||
+ key_backups::Data
|
||||
+ media::Data
|
||||
+ sending::Data
|
||||
+ 'static,
|
||||
>(
|
||||
db: &'static D,
|
||||
config: Config,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
appservice: appservice::Service { db },
|
||||
pusher: pusher::Service { db },
|
||||
rooms: rooms::Service {
|
||||
alias: rooms::alias::Service { db },
|
||||
auth_chain: rooms::auth_chain::Service { db },
|
||||
directory: rooms::directory::Service { db },
|
||||
edus: rooms::edus::Service {
|
||||
presence: rooms::edus::presence::Service { db },
|
||||
read_receipt: rooms::edus::read_receipt::Service { db },
|
||||
typing: rooms::edus::typing::Service { db },
|
||||
},
|
||||
event_handler: rooms::event_handler::Service,
|
||||
lazy_loading: rooms::lazy_loading::Service {
|
||||
db,
|
||||
lazy_load_waiting: Mutex::new(HashMap::new()),
|
||||
},
|
||||
metadata: rooms::metadata::Service { db },
|
||||
outlier: rooms::outlier::Service { db },
|
||||
pdu_metadata: rooms::pdu_metadata::Service { db },
|
||||
search: rooms::search::Service { db },
|
||||
short: rooms::short::Service { db },
|
||||
state: rooms::state::Service { db },
|
||||
state_accessor: rooms::state_accessor::Service { db },
|
||||
state_cache: rooms::state_cache::Service { db },
|
||||
state_compressor: rooms::state_compressor::Service {
|
||||
db,
|
||||
stateinfo_cache: Mutex::new(LruCache::new(
|
||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
},
|
||||
timeline: rooms::timeline::Service {
|
||||
db,
|
||||
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
||||
},
|
||||
user: rooms::user::Service { db },
|
||||
},
|
||||
transaction_ids: transaction_ids::Service { db },
|
||||
uiaa: uiaa::Service { db },
|
||||
users: users::Service { db },
|
||||
account_data: account_data::Service { db },
|
||||
admin: admin::Service::build(),
|
||||
key_backups: key_backups::Service { db },
|
||||
media: media::Service { db },
|
||||
sending: sending::Service::build(db, &config),
|
||||
|
||||
globals: globals::Service::load(db, config)?,
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
use crate::{Database, Error};
|
||||
use crate::{services, Error};
|
||||
use ruma::{
|
||||
events::{
|
||||
room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent,
|
||||
|
@ -332,7 +332,6 @@ impl Ord for PduEvent {
|
|||
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
|
||||
pub(crate) fn gen_event_id_canonical_json(
|
||||
pdu: &RawJsonValue,
|
||||
db: &Database,
|
||||
) -> crate::Result<(Box<EventId>, CanonicalJsonObject)> {
|
||||
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
||||
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
|
||||
|
@ -344,7 +343,7 @@ pub(crate) fn gen_event_id_canonical_json(
|
|||
.and_then(|id| RoomId::parse(id.as_str()?).ok())
|
||||
.ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?;
|
||||
|
||||
let room_version_id = db.rooms.get_room_version(&room_id);
|
||||
let room_version_id = services().rooms.state.get_room_version(&room_id);
|
||||
|
||||
let event_id = format!(
|
||||
"${}",
|
||||
|
@ -358,7 +357,7 @@ pub(crate) fn gen_event_id_canonical_json(
|
|||
Ok((event_id, value))
|
||||
}
|
||||
|
||||
/// Build the start of a PDU in order to add it to the `Database`.
|
||||
/// Build the start of a PDU in order to add it to the Database.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct PduBuilder {
|
||||
#[serde(rename = "type")]
|
17
src/service/pusher/data.rs
Normal file
17
src/service/pusher/data.rs
Normal file
|
@ -0,0 +1,17 @@
|
|||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::push::{get_pushers, set_pusher},
|
||||
UserId,
|
||||
};
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>;
|
||||
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str)
|
||||
-> Result<Option<get_pushers::v3::Pusher>>;
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>>;
|
||||
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId)
|
||||
-> Box<dyn Iterator<Item = Result<String>> + 'a>;
|
||||
}
|
307
src/service/pusher/mod.rs
Normal file
307
src/service/pusher/mod.rs
Normal file
|
@ -0,0 +1,307 @@
|
|||
mod data;
|
||||
pub use data::Data;
|
||||
|
||||
use crate::{services, Error, PduEvent, Result};
|
||||
use bytes::BytesMut;
|
||||
use ruma::api::IncomingResponse;
|
||||
use ruma::{
|
||||
api::{
|
||||
client::push::{get_pushers, set_pusher, PusherKind},
|
||||
push_gateway::send_event_notification::{
|
||||
self,
|
||||
v1::{Device, Notification, NotificationCounts, NotificationPriority},
|
||||
},
|
||||
MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||
},
|
||||
events::{
|
||||
room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent},
|
||||
AnySyncRoomEvent, RoomEventType, StateEventType,
|
||||
},
|
||||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||
serde::Raw,
|
||||
uint, RoomId, UInt, UserId,
|
||||
};
|
||||
|
||||
use std::{fmt::Debug, mem};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
|
||||
self.db.set_pusher(sender, pusher)
|
||||
}
|
||||
|
||||
pub fn get_pusher(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
pushkey: &str,
|
||||
) -> Result<Option<get_pushers::v3::Pusher>> {
|
||||
self.db.get_pusher(sender, pushkey)
|
||||
}
|
||||
|
||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>> {
|
||||
self.db.get_pushers(sender)
|
||||
}
|
||||
|
||||
pub fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
|
||||
self.db.get_pushkeys(sender)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, destination, request))]
|
||||
pub async fn send_request<T: OutgoingRequest>(
|
||||
&self,
|
||||
destination: &str,
|
||||
request: T,
|
||||
) -> Result<T::IncomingResponse>
|
||||
where
|
||||
T: Debug,
|
||||
{
|
||||
let destination = destination.replace("/_matrix/push/v1/notify", "");
|
||||
|
||||
let http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(""),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})?
|
||||
.map(|body| body.freeze());
|
||||
|
||||
let reqwest_request = reqwest::Request::try_from(http_request)
|
||||
.expect("all http requests are valid reqwest requests");
|
||||
|
||||
// TODO: we could keep this very short and let expo backoff do it's thing...
|
||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let response = services()
|
||||
.globals
|
||||
.default_client()
|
||||
.execute(reqwest_request)
|
||||
.await;
|
||||
|
||||
match response {
|
||||
Ok(mut response) => {
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder()
|
||||
.status(status)
|
||||
.version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder
|
||||
.headers_mut()
|
||||
.expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if status != 200 {
|
||||
info!(
|
||||
"Push gateway returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
crate::utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder
|
||||
.body(body)
|
||||
.expect("reqwest body is valid http body"),
|
||||
);
|
||||
response.map_err(|_| {
|
||||
info!(
|
||||
"Push gateway returned invalid response bytes {}\n{}",
|
||||
destination, url
|
||||
);
|
||||
Error::BadServerResponse("Push gateway returned bad response.")
|
||||
})
|
||||
}
|
||||
Err(e) => Err(e.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
||||
pub async fn send_push_notice(
|
||||
&self,
|
||||
user: &UserId,
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
ruleset: Ruleset,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
let mut notify = None;
|
||||
let mut tweaks = Vec::new();
|
||||
|
||||
let power_levels: RoomPowerLevelsEventContent = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
|
||||
.map(|ev| {
|
||||
serde_json::from_str(ev.content.get())
|
||||
.map_err(|_| Error::bad_database("invalid m.room.power_levels event"))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
|
||||
for action in self.get_actions(
|
||||
user,
|
||||
&ruleset,
|
||||
&power_levels,
|
||||
&pdu.to_sync_room_event(),
|
||||
&pdu.room_id,
|
||||
)? {
|
||||
let n = match action {
|
||||
Action::DontNotify => false,
|
||||
// TODO: Implement proper support for coalesce
|
||||
Action::Notify | Action::Coalesce => true,
|
||||
Action::SetTweak(tweak) => {
|
||||
tweaks.push(tweak.clone());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if notify.is_some() {
|
||||
return Err(Error::bad_database(
|
||||
r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#,
|
||||
));
|
||||
}
|
||||
|
||||
notify = Some(n);
|
||||
}
|
||||
|
||||
if notify == Some(true) {
|
||||
self.send_notice(unread, pusher, tweaks, pdu).await?;
|
||||
}
|
||||
// Else the event triggered no actions
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, user, ruleset, pdu))]
|
||||
pub fn get_actions<'a>(
|
||||
&self,
|
||||
user: &UserId,
|
||||
ruleset: &'a Ruleset,
|
||||
power_levels: &RoomPowerLevelsEventContent,
|
||||
pdu: &Raw<AnySyncRoomEvent>,
|
||||
room_id: &RoomId,
|
||||
) -> Result<&'a [Action]> {
|
||||
let ctx = PushConditionRoomCtx {
|
||||
room_id: room_id.to_owned(),
|
||||
member_count: 10_u32.into(), // TODO: get member count efficiently
|
||||
user_display_name: services()
|
||||
.users
|
||||
.displayname(user)?
|
||||
.unwrap_or_else(|| user.localpart().to_owned()),
|
||||
users_power_levels: power_levels.users.clone(),
|
||||
default_power_level: power_levels.users_default,
|
||||
notification_power_levels: power_levels.notifications.clone(),
|
||||
};
|
||||
|
||||
Ok(ruleset.get_actions(pdu, &ctx))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
||||
async fn send_notice(
|
||||
&self,
|
||||
unread: UInt,
|
||||
pusher: &get_pushers::v3::Pusher,
|
||||
tweaks: Vec<Tweak>,
|
||||
event: &PduEvent,
|
||||
) -> Result<()> {
|
||||
// TODO: email
|
||||
if pusher.kind == PusherKind::Email {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// TODO:
|
||||
// Two problems with this
|
||||
// 1. if "event_id_only" is the only format kind it seems we should never add more info
|
||||
// 2. can pusher/devices have conflicting formats
|
||||
let event_id_only = pusher.data.format == Some(PushFormat::EventIdOnly);
|
||||
let url = if let Some(url) = &pusher.data.url {
|
||||
url
|
||||
} else {
|
||||
error!("Http Pusher must have URL specified.");
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let mut device = Device::new(pusher.app_id.clone(), pusher.pushkey.clone());
|
||||
let mut data_minus_url = pusher.data.clone();
|
||||
// The url must be stripped off according to spec
|
||||
data_minus_url.url = None;
|
||||
device.data = data_minus_url;
|
||||
|
||||
// Tweaks are only added if the format is NOT event_id_only
|
||||
if !event_id_only {
|
||||
device.tweaks = tweaks.clone();
|
||||
}
|
||||
|
||||
let d = &[device];
|
||||
let mut notifi = Notification::new(d);
|
||||
|
||||
notifi.prio = NotificationPriority::Low;
|
||||
notifi.event_id = Some(&event.event_id);
|
||||
notifi.room_id = Some(&event.room_id);
|
||||
// TODO: missed calls
|
||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||
|
||||
if event.kind == RoomEventType::RoomEncrypted
|
||||
|| tweaks
|
||||
.iter()
|
||||
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||
{
|
||||
notifi.prio = NotificationPriority::High
|
||||
}
|
||||
|
||||
if event_id_only {
|
||||
self.send_request(url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
} else {
|
||||
notifi.sender = Some(&event.sender);
|
||||
notifi.event_type = Some(&event.kind);
|
||||
let content = serde_json::value::to_raw_value(&event.content).ok();
|
||||
notifi.content = content.as_deref();
|
||||
|
||||
if event.kind == RoomEventType::RoomMember {
|
||||
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||
}
|
||||
|
||||
let user_name = services().users.displayname(&event.sender)?;
|
||||
notifi.sender_display_name = user_name.as_deref();
|
||||
|
||||
let room_name = if let Some(room_name_pdu) = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&event.room_id, &StateEventType::RoomName, "")?
|
||||
{
|
||||
serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get())
|
||||
.map_err(|_| Error::bad_database("Invalid room name event in database."))?
|
||||
.name
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
notifi.room_name = room_name.as_deref();
|
||||
|
||||
self.send_request(url, send_event_notification::v1::Request::new(notifi))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// TODO: email
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue