127 errors left

Timo Kösters 2022-10-05 18:36:12 +02:00
parent 0f618de7fd
commit 769c1ecdd4
65 changed files with 809 additions and 556 deletions

View file

@ -654,7 +654,7 @@ async fn join_room_by_id_helper(
// We set the room state after inserting the pdu, so that we never have a moment in time
// where events in the current room state do not exist
services().rooms.state.set_room_state(room_id, shortstatehash)?;
services().rooms.state.set_room_state(room_id, shortstatehash, &state_lock)?;
let statehashid = services().rooms.state.append_to_state(&parsed_pdu)?;
} else {

View file

@ -857,131 +857,6 @@ pub async fn send_transaction_message_route(
Ok(send_transaction_message::v1::Response { pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.to_string()))).collect() })
}
#[tracing::instrument(skip(starting_events))]
pub(crate) async fn get_auth_chain<'a>(
room_id: &RoomId,
starting_events: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
const NUM_BUCKETS: usize = 50;
let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
let mut i = 0;
for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?;
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
buckets[bucket_id].insert((short, id.clone()));
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
let mut full_auth_chain = HashSet::new();
let mut hits = 0;
let mut misses = 0;
for chunk in buckets {
if chunk.is_empty() {
continue;
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
hits += 1;
full_auth_chain.extend(cached.iter().copied());
continue;
}
misses += 1;
let mut chunk_cache = HashSet::new();
let mut hits2 = 0;
let mut misses2 = 0;
let mut i = 0;
for (sevent_id, event_id) in chunk {
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
hits2 += 1;
chunk_cache.extend(cached.iter().copied());
} else {
misses2 += 1;
let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?);
services().rooms
.auth_chain
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
println!(
"cache missed event {} with auth chain len {}",
event_id,
auth_chain.len()
);
chunk_cache.extend(auth_chain.iter());
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
};
}
println!(
"chunk missed with len {}, event hits2: {}, misses2: {}",
chunk_cache.len(),
hits2,
misses2
);
let chunk_cache = Arc::new(chunk_cache);
services().rooms
.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
full_auth_chain.extend(chunk_cache.iter());
}
println!(
"total: {}, chunk hits: {}, misses: {}",
full_auth_chain.len(),
hits,
misses
);
Ok(full_auth_chain
.into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
}
#[tracing::instrument(skip(event_id))]
fn get_auth_chain_inner(
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
while let Some(event_id) = todo.pop() {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
}
for auth_event in &pdu.auth_events {
let sauthevent = services()
.rooms.short
.get_or_create_shorteventid(auth_event)?;
if !found.contains(&sauthevent) {
found.insert(sauthevent);
todo.push(auth_event.clone());
}
}
}
Ok(None) => {
warn!("Could not find pdu mentioned in auth events: {}", event_id);
}
Err(e) => {
warn!("Could not load event in auth chain: {} {}", event_id, e);
}
}
}
Ok(found)
}
/// # `GET /_matrix/federation/v1/event/{eventId}`
///
/// Retrieves a single event from the server.
@ -1135,7 +1010,7 @@ pub async fn get_event_authorization_route(
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?;
let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
@ -1190,7 +1065,7 @@ pub async fn get_room_state_route(
.collect();
let auth_chain_ids =
get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids
@ -1246,7 +1121,7 @@ pub async fn get_room_state_ids_route(
.collect();
let auth_chain_ids =
get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),
@ -1449,7 +1324,7 @@ async fn create_join_event(
drop(mutex_lock);
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
let auth_chain_ids = get_auth_chain(
let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(
room_id,
state_ids.iter().map(|(_, id)| id.clone()).collect(),
)

View file

@ -1,11 +1,11 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId};
use serde::{Serialize, de::DeserializeOwned};
use crate::{Result, database::KeyValueDatabase, service, Error, utils, services};
impl service::account_data::Data for Arc<KeyValueDatabase> {
impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(

View file

@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::appservice::Data for KeyValueDatabase {

View file

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, sync::Arc};
use std::collections::BTreeMap;
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
@ -9,7 +9,7 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils, services}
pub const COUNTER: &[u8] = b"c";
#[async_trait]
impl service::globals::Data for Arc<KeyValueDatabase> {
impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))

View file

@ -1,10 +1,10 @@
use std::{collections::BTreeMap, sync::Arc};
use std::collections::BTreeMap;
use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId};
use crate::{Result, service, database::KeyValueDatabase, services, Error, utils};
impl service::key_backups::Data for Arc<KeyValueDatabase> {
impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(
&self,
user_id: &UserId,

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::api::client::error::ErrorKind;
use crate::{database::KeyValueDatabase, service, Error, utils, Result};
impl service::media::Data for Arc<KeyValueDatabase> {
impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, api::client::push::{set_pusher, get_pushers}};
use crate::{service, database::KeyValueDatabase, Error, Result};
impl service::pusher::Data for Arc<KeyValueDatabase> {
impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
impl service::rooms::alias::Data for Arc<KeyValueDatabase> {
impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(
&self,
alias: &RoomAliasId,

View file

@ -2,7 +2,7 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{service, database::KeyValueDatabase, Result, utils};
impl service::rooms::auth_chain::Data for Arc<KeyValueDatabase> {
impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
impl service::rooms::directory::Data for Arc<KeyValueDatabase> {
impl service::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}

View file

@ -2,8 +2,6 @@ mod presence;
mod typing;
mod read_receipt;
use std::sync::Arc;
use crate::{service, database::KeyValueDatabase};
impl service::rooms::edus::Data for Arc<KeyValueDatabase> {}
impl service::rooms::edus::Data for KeyValueDatabase {}

View file

@ -1,10 +1,10 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
impl service::rooms::edus::presence::Data for Arc<KeyValueDatabase> {
impl service::rooms::edus::presence::Data for KeyValueDatabase {
fn update_presence(
&self,
user_id: &UserId,

View file

@ -1,10 +1,10 @@
use std::{mem, sync::Arc};
use std::mem;
use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
impl service::rooms::edus::read_receipt::Data for Arc<KeyValueDatabase> {
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(
&self,
user_id: &UserId,

View file

@ -1,10 +1,10 @@
use std::{collections::HashSet, sync::Arc};
use std::collections::HashSet;
use ruma::{UserId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
impl service::rooms::edus::typing::Data for Arc<KeyValueDatabase> {
impl service::rooms::edus::typing::Data for KeyValueDatabase {
fn typing_add(
&self,
user_id: &UserId,

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, DeviceId, RoomId};
use crate::{service, database::KeyValueDatabase, Result};
impl service::rooms::lazy_loading::Data for Arc<KeyValueDatabase> {
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self,
user_id: &UserId,

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, Result, services};
impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
@ -19,4 +17,18 @@ impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
}

View file

@ -15,8 +15,6 @@ mod state_compressor;
mod timeline;
mod user;
use std::sync::Arc;
use crate::{database::KeyValueDatabase, service};
impl service::rooms::Data for Arc<KeyValueDatabase> {}
impl service::rooms::Data for KeyValueDatabase {}

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result};
impl service::rooms::outlier::Data for Arc<KeyValueDatabase> {
impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?

View file

@ -4,7 +4,7 @@ use ruma::{RoomId, EventId};
use crate::{service, database::KeyValueDatabase, Result};
impl service::rooms::pdu_metadata::Data for Arc<KeyValueDatabase> {
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();

View file

@ -1,10 +1,10 @@
use std::{mem::size_of, sync::Arc};
use std::mem::size_of;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Result, services};
impl service::rooms::search::Data for Arc<KeyValueDatabase> {
impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())

View file

@ -1,6 +1,227 @@
use std::sync::Arc;
use crate::{database::KeyValueDatabase, service};
use ruma::{EventId, events::StateEventType, RoomId};
impl service::rooms::short::Data for Arc<KeyValueDatabase> {
use crate::{Result, database::KeyValueDatabase, service, utils, Error, services};
impl service::rooms::short::Data for KeyValueDatabase {
fn get_or_create_shorteventid(
&self,
event_id: &EventId,
) -> Result<u64> {
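// Fast path: check the in-memory eventidshort_cache first, then the
// eventid_shorteventid tree; otherwise allocate a fresh id from the global
// counter and record the mapping in both directions.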
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
return Ok(*short);
}
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
None => {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
}
};
self.eventidshort_cache
.lock()
.unwrap()
.insert(event_id.to_owned(), short);
Ok(short)
}
fn get_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<u64>> {
if let Some(short) = self
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(Some(*short));
}
let mut statekey = event_type.to_string().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
if let Some(s) = short {
self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), s);
}
Ok(short)
}
fn get_or_create_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<u64> {
if let Some(short) = self
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(*short);
}
let mut statekey = event_type.to_string().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(state_key.as_bytes());
let short = match self.statekey_shortstatekey.get(&statekey)? {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey)?;
shortstatekey
}
};
self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), short);
Ok(short)
}
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
if let Some(id) = self
.shorteventid_cache
.lock()
.unwrap()
.get_mut(&shorteventid)
{
return Ok(Arc::clone(id));
}
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache
.lock()
.unwrap()
.insert(shorteventid, Arc::clone(&event_id));
Ok(event_id)
}
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone());
}
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type =
StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
})?;
let result = (event_type, state_key);
self.shortstatekey_cache
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(state_hash)? {
Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
),
None => {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
}
})
}
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
})
.transpose()
}
fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
None => {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
}
})
}
}

View file

@ -6,7 +6,7 @@ use std::fmt::Debug;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
impl service::rooms::state::Data for Arc<KeyValueDatabase> {
impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?

View file

@ -5,7 +5,7 @@ use async_trait::async_trait;
use ruma::{EventId, events::StateEventType, RoomId};
#[async_trait]
impl service::rooms::state_accessor::Data for Arc<KeyValueDatabase> {
impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw};
use crate::{service, database::KeyValueDatabase, services, Result};
impl service::rooms::state_cache::Data for Arc<KeyValueDatabase> {
impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);

View file

@ -1,8 +1,8 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use std::{collections::HashSet, mem::size_of};
use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result};
impl service::rooms::state_compressor::Data for Arc<KeyValueDatabase> {
impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff

View file

@ -5,7 +5,27 @@ use tracing::error;
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services};
impl service::rooms::timeline::Data for Arc<KeyValueDatabase> {
impl service::rooms::timeline::Data for KeyValueDatabase {
fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
let prefix = services().rooms.short
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
// Look for PDUs in that room.
self.pduid_pdu
.iter_from(&prefix, false)
.filter(|(k, _)| k.starts_with(&prefix))
.map(|(_, pdu)| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid first PDU in db."))
.map(Arc::new)
})
.next()
.transpose()
}
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
match self
.lasttimelinecount_cache

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, RoomId};
use crate::{service, database::KeyValueDatabase, utils, Error, Result, services};
impl service::rooms::user::Data for Arc<KeyValueDatabase> {
impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
@ -104,13 +102,13 @@ impl service::rooms::user::Data for Arc<KeyValueDatabase> {
});
// We use the default compare function because keys are sorted correctly (not reversed)
Ok(utils::common_elements(iterators, Ord::cmp)
Ok(Box::new(Box::new(utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}))
}))))
}
}

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, DeviceId, TransactionId};
use crate::{service, database::KeyValueDatabase, Result};
impl service::transaction_ids::Data for Arc<KeyValueDatabase> {
impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self,
user_id: &UserId,

View file

@ -1,10 +1,8 @@
use std::sync::Arc;
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}};
use crate::{database::KeyValueDatabase, service, Error, Result};
impl service::uiaa::Data for Arc<KeyValueDatabase> {
impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self,
user_id: &UserId,

View file

@ -1,11 +1,11 @@
use std::{mem::size_of, collections::BTreeMap, sync::Arc};
use std::{mem::size_of, collections::BTreeMap};
use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt};
use tracing::warn;
use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result};
impl service::users::Data for Arc<KeyValueDatabase> {
impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
@ -113,7 +113,7 @@ impl service::users::Data for Arc<KeyValueDatabase> {
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
if let Ok(hash) = utils::calculate_hash(password) {
if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())

View file

@ -238,8 +238,8 @@ impl KeyValueDatabase {
}
/// Load an existing database or create a new one.
pub async fn load_or_create(config: &Config) -> Result<()> {
Self::check_db_setup(config)?;
pub async fn load_or_create(config: Config) -> Result<()> {
Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
std::fs::create_dir_all(&config.database_path)
@ -251,19 +251,19 @@ impl KeyValueDatabase {
#[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "sqlite")]
Arc::new(Arc::<abstraction::sqlite::Engine>::open(config)?)
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
"rocksdb" => {
#[cfg(not(feature = "rocksdb"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "rocksdb")]
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(config)?)
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
"persy" => {
#[cfg(not(feature = "persy"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "persy")]
Arc::new(Arc::<abstraction::persy::Engine>::open(config)?)
Arc::new(Arc::<abstraction::persy::Engine>::open(&config)?)
}
_ => {
return Err(Error::BadConfig("Database backend not found."));
@ -402,7 +402,7 @@ impl KeyValueDatabase {
});
let services_raw = Box::new(Services::build(Arc::clone(&db)));
let services_raw = Box::new(Services::build(Arc::clone(&db), config)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
@ -825,7 +825,7 @@ impl KeyValueDatabase {
info!(
"Loaded {} database with version {}",
config.database_backend, latest_database_version
services().globals.config.database_backend, latest_database_version
);
} else {
services()
@ -837,7 +837,7 @@ impl KeyValueDatabase {
warn!(
"Created new {} database with version {}",
config.database_backend, latest_database_version
services().globals.config.database_backend, latest_database_version
);
}
@ -866,7 +866,7 @@ impl KeyValueDatabase {
.sending
.start_handler(sending_receiver);
Self::start_cleanup_task(config).await;
Self::start_cleanup_task().await;
Ok(())
}
@ -888,8 +888,8 @@ impl KeyValueDatabase {
res
}
#[tracing::instrument(skip(config))]
pub async fn start_cleanup_task(config: &Config) {
#[tracing::instrument]
pub async fn start_cleanup_task() {
use tokio::time::interval;
#[cfg(unix)]
@ -898,7 +898,7 @@ impl KeyValueDatabase {
use std::time::{Duration, Instant};
let timer_interval = Duration::from_secs(config.cleanup_second_interval as u64);
let timer_interval = Duration::from_secs(services().globals.config.cleanup_second_interval as u64);
tokio::spawn(async move {
let mut i = interval(timer_interval);

View file

@ -18,7 +18,7 @@ use tracing::error;
use crate::{service::*, services, utils, Error, Result};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -426,7 +426,7 @@ impl Service {
Error::bad_database("Invalid room id field in event in database")
})?;
let start = Instant::now();
let count = server_server::get_auth_chain(room_id, vec![event_id])
let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id])
.await?
.count();
let elapsed = start.elapsed();
@ -615,14 +615,12 @@ impl Service {
))
}
AdminCommand::DisableRoom { room_id } => {
todo!();
//services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?;
//RoomMessageEventContent::text_plain("Room disabled.")
services().rooms.metadata.disable_room(&room_id, true);
RoomMessageEventContent::text_plain("Room disabled.")
}
AdminCommand::EnableRoom { room_id } => {
todo!();
//services().rooms.disabledroomids.remove(room_id.as_bytes())?;
//RoomMessageEventContent::text_plain("Room enabled.")
services().rooms.metadata.disable_room(&room_id, false);
RoomMessageEventContent::text_plain("Room enabled.")
}
AdminCommand::DeactivateUser {
leave_rooms,

View file

@ -35,7 +35,7 @@ type SyncHandle = (
);
pub struct Service {
pub db: Box<dyn Data>,
pub db: Arc<dyn Data>,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
@ -92,7 +92,7 @@ impl Default for RotationHandler {
impl Service {
pub fn load(
db: Box<dyn Data>,
db: Arc<dyn Data>,
config: Config,
) -> Result<Self> {
let keypair = db.load_keypair();

View file

@ -13,7 +13,7 @@ use ruma::{
use std::{collections::BTreeMap, sync::Arc};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -16,7 +16,7 @@ pub struct FileMeta {
}
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,4 +1,9 @@
use std::sync::Arc;
use std::{
collections::{BTreeMap, HashMap},
sync::{Arc, Mutex},
};
use crate::{Result, Config};
pub mod account_data;
pub mod admin;
@ -30,20 +35,73 @@ pub struct Services {
}
impl Services {
pub fn build<D: appservice::Data + pusher::Data + rooms::Data + transaction_ids::Data + uiaa::Data + users::Data + account_data::Data + globals::Data + key_backups::Data + media::Data>(db: Arc<D>) -> Self {
Self {
pub fn build<
D: appservice::Data
+ pusher::Data
+ rooms::Data
+ transaction_ids::Data
+ uiaa::Data
+ users::Data
+ account_data::Data
+ globals::Data
+ key_backups::Data
+ media::Data,
>(
db: Arc<D>, config: Config
) -> Result<Self> {
Ok(Self {
appservice: appservice::Service { db: db.clone() },
pusher: pusher::Service { db: db.clone() },
rooms: rooms::Service { db: Arc::clone(&db) },
transaction_ids: transaction_ids::Service { db: Arc::clone(&db) },
uiaa: uiaa::Service { db: Arc::clone(&db) },
users: users::Service { db: Arc::clone(&db) },
account_data: account_data::Service { db: Arc::clone(&db) },
admin: admin::Service { db: Arc::clone(&db) },
globals: globals::Service { db: Arc::clone(&db) },
key_backups: key_backups::Service { db: Arc::clone(&db) },
media: media::Service { db: Arc::clone(&db) },
sending: sending::Service { db: Arc::clone(&db) },
}
rooms: rooms::Service {
alias: rooms::alias::Service { db: db.clone() },
auth_chain: rooms::auth_chain::Service { db: db.clone() },
directory: rooms::directory::Service { db: db.clone() },
edus: rooms::edus::Service {
presence: rooms::edus::presence::Service { db: db.clone() },
read_receipt: rooms::edus::read_receipt::Service { db: db.clone() },
typing: rooms::edus::typing::Service { db: db.clone() },
},
event_handler: rooms::event_handler::Service,
lazy_loading: rooms::lazy_loading::Service {
db: db.clone(),
lazy_load_waiting: Mutex::new(HashMap::new()),
},
metadata: rooms::metadata::Service { db: db.clone() },
outlier: rooms::outlier::Service { db: db.clone() },
pdu_metadata: rooms::pdu_metadata::Service { db: db.clone() },
search: rooms::search::Service { db: db.clone() },
short: rooms::short::Service { db: db.clone() },
state: rooms::state::Service { db: db.clone() },
state_accessor: rooms::state_accessor::Service { db: db.clone() },
state_cache: rooms::state_cache::Service { db: db.clone() },
state_compressor: rooms::state_compressor::Service { db: db.clone() },
timeline: rooms::timeline::Service { db: db.clone() },
user: rooms::user::Service { db: db.clone() },
},
transaction_ids: transaction_ids::Service {
db: db.clone()
},
uiaa: uiaa::Service {
db: db.clone()
},
users: users::Service {
db: db.clone()
},
account_data: account_data::Service {
db: db.clone()
},
admin: admin::Service { sender: todo!() },
globals: globals::Service::load(db.clone(), config)?,
key_backups: key_backups::Service {
db: db.clone()
},
media: media::Service {
db: db.clone()
},
sending: sending::Service {
maximum_requests: todo!(),
sender: todo!(),
},
})
}
}
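For context, a minimal sketch of how this constructor is now invoked from the database layer; it simply mirrors the `load_or_create` hunk earlier in this commit, with `db` and the `SERVICES` static assumed from that surrounding code:

// Hand the shared KeyValueDatabase and the owned Config to the service layer
// (assumed context: `db: Arc<KeyValueDatabase>` opened earlier in load_or_create).
let services_raw = Box::new(Services::build(Arc::clone(&db), config)?);
// This is the first and only time the SERVICES static is initialized.
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));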

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{RoomAliasId, RoomId};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,12 +1,14 @@
mod data;
use std::{sync::Arc, collections::HashSet};
use std::{sync::Arc, collections::{HashSet, BTreeSet}};
pub use data::Data;
use ruma::{RoomId, EventId, api::client::error::ErrorKind};
use tracing::log::warn;
use crate::Result;
use crate::{Result, services, Error};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -22,4 +24,131 @@ impl Service {
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
self.db.cache_auth_chain(key, auth_chain)
}
#[tracing::instrument(skip(self, starting_events))]
pub async fn get_auth_chain<'a>(
&self,
room_id: &RoomId,
starting_events: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
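// Bucket the starting events by shorteventid % NUM_BUCKETS so that each
// bucket's combined auth chain can be cached and looked up as one chunk
// (the chunk_key built below).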
const NUM_BUCKETS: usize = 50;
let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
let mut i = 0;
for id in starting_events {
let short = services().rooms.short.get_or_create_shorteventid(&id)?;
let bucket_id = (short % NUM_BUCKETS as u64) as usize;
buckets[bucket_id].insert((short, id.clone()));
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
}
let mut full_auth_chain = HashSet::new();
let mut hits = 0;
let mut misses = 0;
for chunk in buckets {
if chunk.is_empty() {
continue;
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
hits += 1;
full_auth_chain.extend(cached.iter().copied());
continue;
}
misses += 1;
let mut chunk_cache = HashSet::new();
let mut hits2 = 0;
let mut misses2 = 0;
let mut i = 0;
for (sevent_id, event_id) in chunk {
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
hits2 += 1;
chunk_cache.extend(cached.iter().copied());
} else {
misses2 += 1;
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
services().rooms
.auth_chain
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
println!(
"cache missed event {} with auth chain len {}",
event_id,
auth_chain.len()
);
chunk_cache.extend(auth_chain.iter());
i += 1;
if i % 100 == 0 {
tokio::task::yield_now().await;
}
};
}
println!(
"chunk missed with len {}, event hits2: {}, misses2: {}",
chunk_cache.len(),
hits2,
misses2
);
let chunk_cache = Arc::new(chunk_cache);
services().rooms
.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
full_auth_chain.extend(chunk_cache.iter());
}
println!(
"total: {}, chunk hits: {}, misses: {}",
full_auth_chain.len(),
hits,
misses
);
Ok(full_auth_chain
.into_iter()
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
}
#[tracing::instrument(skip(self, event_id))]
fn get_auth_chain_inner(
&self,
room_id: &RoomId,
event_id: &EventId,
) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
while let Some(event_id) = todo.pop() {
match services().rooms.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
if pdu.room_id != room_id {
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
}
for auth_event in &pdu.auth_events {
let sauthevent = services()
.rooms.short
.get_or_create_shorteventid(auth_event)?;
if !found.contains(&sauthevent) {
found.insert(sauthevent);
todo.push(auth_event.clone());
}
}
}
Ok(None) => {
warn!("Could not find pdu mentioned in auth events: {}", event_id);
}
Err(e) => {
warn!("Could not load event in auth chain: {} {}", event_id, e);
}
}
}
Ok(found)
}
}
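A short usage sketch of the relocated method, matching the `create_join_event` call site earlier in this commit (`room_id` and `state_ids` are assumed from that hunk):

// Collect the auth chain for a full state snapshot and return owned event ids.
let auth_chain_ids = services()
    .rooms
    .auth_chain
    .get_auth_chain(room_id, state_ids.iter().map(|(_, id)| id.clone()).collect())
    .await?
    .map(|id| (*id).to_owned())
    .collect::<Vec<_>>();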

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::RoomId;
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,5 +1,5 @@
mod data;
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
pub use data::Data;
use ruma::{RoomId, UserId, events::presence::PresenceEvent};
@ -7,7 +7,7 @@ use ruma::{RoomId, UserId, events::presence::PresenceEvent};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -72,13 +72,15 @@ impl Service {
));
}
services()
if services()
.rooms
.is_disabled(room_id)?
.ok_or(Error::BadRequest(
.metadata
.is_disabled(room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Federation of this room is currently disabled on this server.",
))?;
));
}
// 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? {
@ -111,7 +113,7 @@ impl Service {
}
// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
let (sorted_prev_events, eventid_info) = self.fetch_unknown_prev_events(
let (sorted_prev_events, mut eventid_info) = self.fetch_unknown_prev_events(
origin,
&create_event,
room_id,
@ -122,14 +124,15 @@ impl Service {
let mut errors = 0;
for prev_id in dbg!(sorted_prev_events) {
// Check for disabled again because it might have changed
services()
if services()
.rooms
.is_disabled(room_id)?
.ok_or(Error::BadRequest(
.metadata
.is_disabled(room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Federation of
this room is currently disabled on this server.",
))?;
"Federation of this room is currently disabled on this server.",
));
}
if let Some((time, tries)) = services()
.globals
@ -279,14 +282,14 @@ impl Service {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e);
return Err("Signature verification failed".to_owned());
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed"));
}
Ok(ruma::signatures::Verified::Signatures) => {
// Redact
warn!("Calculated hash does not match: {}", event_id);
match ruma::signatures::redact(&value, room_version_id) {
Ok(obj) => obj,
Err(_) => return Err("Redaction failed".to_owned()),
Err(_) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")),
}
}
Ok(ruma::signatures::Verified::All) => value,
@ -480,7 +483,7 @@ impl Service {
let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events {
let prev_event = if let Ok(Some(pdu)) = services().rooms.get_pdu(prev_eventid) {
let prev_event = if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) {
pdu
} else {
okay = false;
@ -488,7 +491,7 @@ impl Service {
};
let sstatehash =
if let Ok(Some(s)) = services().rooms.pdu_shortstatehash(prev_eventid) {
if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) {
s
} else {
okay = false;
@ -525,7 +528,7 @@ impl Service {
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state {
if let Ok((ty, st_key)) = services().rooms.get_statekey_from_short(k) {
if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) {
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone());
@ -539,7 +542,7 @@ impl Service {
services()
.rooms
.auth_chain
.get_auth_chain(room_id, starting_events, services())
.get_auth_chain(room_id, starting_events)
.await?
.collect(),
);
@ -551,7 +554,7 @@ impl Service {
let result =
state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
let res = services().rooms.get_pdu(id);
let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
@ -677,7 +680,7 @@ impl Service {
.and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten())
},
)
.map_err(|_e| "Auth check failed.".to_owned())?;
.map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?;
if !check_result {
return Err(Error::bad_database("Event has failed auth check with state at the event."));
@ -714,7 +717,7 @@ impl Service {
// Only keep those extremities that were not referenced yet
extremities
.retain(|id| !matches!(services().rooms.is_event_referenced(room_id, id), Ok(true)));
.retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true)));
info!("Compressing state at event");
let state_ids_compressed = state_at_incoming_event
@ -722,7 +725,8 @@ impl Service {
.map(|(shortstatekey, id)| {
services()
.rooms
.compress_state_event(*shortstatekey, id)?
.state_compressor
.compress_state_event(*shortstatekey, id)
})
.collect::<Result<_>>()?;
@ -731,6 +735,7 @@ impl Service {
let auth_events = services()
.rooms
.state
.get_auth_events(
room_id,
&incoming_pdu.kind,
@ -744,10 +749,10 @@ impl Service {
&incoming_pdu,
None::<PduEvent>,
|k, s| auth_events.get(&(k.clone(), s.to_owned())),
)?;
).map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?;
if soft_fail {
self.append_incoming_pdu(
services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(std::ops::Deref::deref),
@ -760,8 +765,9 @@ impl Service {
warn!("Event was soft failed: {:?}", incoming_pdu);
services()
.rooms
.pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id)?;
return Err("Event has been soft failed".into());
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed"));
}
if incoming_pdu.state_key.is_some() {
@ -798,14 +804,14 @@ impl Service {
"Found extremity pdu with no statehash in db: {:?}",
leaf_pdu
);
"Found pdu with no statehash in db.".to_owned()
Error::bad_database("Found pdu with no statehash in db.")
})?,
leaf_pdu,
);
}
_ => {
error!("Missing state snapshot for {:?}", id);
return Err("Missing state snapshot.".to_owned());
return Err(Error::BadDatabase("Missing state snapshot."));
}
}
}
@ -835,7 +841,7 @@ impl Service {
let mut update_state = false;
// 14. Use state resolution to find new room state
let new_room_state = if fork_states.is_empty() {
return Err("State is empty.".to_owned());
panic!("State is empty");
} else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) {
info!("State resolution trivial");
// There was only one state, so it has to be the room's current state (because that is
@ -845,7 +851,8 @@ impl Service {
.map(|(k, id)| {
services()
.rooms
.compress_state_event(*k, id)?
.state_compressor
.compress_state_event(*k, id)
})
.collect::<Result<_>>()?
} else {
@ -877,9 +884,8 @@ impl Service {
.filter_map(|(k, id)| {
services()
.rooms
.get_statekey_from_short(k)?
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
.short
.get_statekey_from_short(k)
.map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
.ok()
})
@ -895,7 +901,7 @@ impl Service {
&fork_states,
auth_chain_sets,
|id| {
let res = services().rooms.get_pdu(id);
let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
@ -904,7 +910,7 @@ impl Service {
) {
Ok(new_state) => new_state,
Err(_) => {
return Err("State resolution failed, either an event could not be found or deserialization".into());
return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization"));
}
};
@ -921,6 +927,7 @@ impl Service {
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
services()
.rooms
.state_compressor
.compress_state_event(shortstatekey, &event_id)
})
.collect::<Result<_>>()?
@ -929,9 +936,11 @@ impl Service {
// Set the new room state to the resolved state
if update_state {
info!("Forcing new room state");
let (sstatehash, _, _) = services().rooms.state_compressor.save_state(room_id, new_room_state)?;
services()
.rooms
.force_state(room_id, new_room_state)?;
.state
.set_room_state(room_id, sstatehash, &state_lock)?;
}
}
@ -942,7 +951,7 @@ impl Service {
// We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event.
let pdu_id = self
let pdu_id = services().rooms.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
@ -1017,7 +1026,7 @@ impl Service {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = services().rooms.get_pdu(id) {
if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
@ -1040,7 +1049,7 @@ impl Service {
tokio::task::yield_now().await;
}
if let Ok(Some(_)) = services().rooms.get_pdu(&next_id) {
if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", id);
continue;
}
@ -1140,6 +1149,7 @@ impl Service {
let first_pdu_in_room = services()
.rooms
.timeline
.first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;

View file

@ -1,5 +1,5 @@
mod data;
use std::{collections::{HashSet, HashMap}, sync::Mutex};
use std::{collections::{HashSet, HashMap}, sync::{Mutex, Arc}};
pub use data::Data;
use ruma::{DeviceId, UserId, RoomId};
@ -7,7 +7,7 @@ use ruma::{DeviceId, UserId, RoomId};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
lazy_load_waiting: Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>,
}

View file

@ -3,4 +3,6 @@ use crate::Result;
pub trait Data: Send + Sync {
fn exists(&self, room_id: &RoomId) -> Result<bool>;
fn is_disabled(&self, room_id: &RoomId) -> Result<bool>;
fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>;
}
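For reference, a sketch of how these new trait methods are exercised elsewhere in this commit, in the admin commands and the federation event handler (error handling abbreviated):

// Toggle the flag from the admin room ...
services().rooms.metadata.disable_room(&room_id, true)?;  // DisableRoom
services().rooms.metadata.disable_room(&room_id, false)?; // EnableRoom
// ... and gate incoming federation on it.
if services().rooms.metadata.is_disabled(room_id)? {
    return Err(Error::BadRequest(
        ErrorKind::Forbidden,
        "Federation of this room is currently disabled on this server.",
    ));
}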

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::RoomId;
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -14,4 +16,12 @@ impl Service {
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
self.db.exists(room_id)
}
pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
self.db.is_disabled(room_id)
}
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
self.db.disable_room(room_id, disabled)
}
}

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{Result, PduEvent};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -7,7 +7,7 @@ use ruma::{RoomId, EventId};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use crate::Result;
use ruma::RoomId;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {

View file

@ -1,2 +1,40 @@
use std::sync::Arc;
use ruma::{EventId, events::StateEventType, RoomId};
use crate::Result;
pub trait Data: Send + Sync {
fn get_or_create_shorteventid(
&self,
event_id: &EventId,
) -> Result<u64>;
fn get_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<u64>>;
fn get_or_create_shortstatekey(
&self,
event_type: &StateEventType,
state_key: &str,
) -> Result<u64>;
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>;
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>;
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)>;
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;
fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<u64>;
}

View file

@ -7,7 +7,7 @@ use ruma::{EventId, events::StateEventType, RoomId};
use crate::{Result, Error, utils, services};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -15,29 +15,7 @@ impl Service {
&self,
event_id: &EventId,
) -> Result<u64> {
if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
return Ok(*short);
}
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
None => {
let shorteventid = services().globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
}
};
self.eventidshort_cache
.lock()
.unwrap()
.insert(event_id.to_owned(), short);
Ok(short)
self.db.get_or_create_shorteventid(event_id)
}
pub fn get_shortstatekey(
@ -45,36 +23,7 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<u64>> {
if let Some(short) = self
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(Some(*short));
}
let mut statekey = event_type.to_string().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(state_key.as_bytes());
let short = self
.statekey_shortstatekey
.get(&statekey)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
})
.transpose()?;
if let Some(s) = short {
self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), s);
}
Ok(short)
self.db.get_shortstatekey(event_type, state_key)
}
pub fn get_or_create_shortstatekey(
@ -82,152 +31,33 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<u64> {
if let Some(short) = self
.statekeyshort_cache
.lock()
.unwrap()
.get_mut(&(event_type.clone(), state_key.to_owned()))
{
return Ok(*short);
}
let mut statekey = event_type.to_string().as_bytes().to_vec();
statekey.push(0xff);
statekey.extend_from_slice(state_key.as_bytes());
let short = match self.statekey_shortstatekey.get(&statekey)? {
Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
None => {
let shortstatekey = services().globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey)?;
shortstatekey
}
};
self.statekeyshort_cache
.lock()
.unwrap()
.insert((event_type.clone(), state_key.to_owned()), short);
Ok(short)
self.db.get_or_create_shortstatekey(event_type, state_key)
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
if let Some(id) = self
.shorteventid_cache
.lock()
.unwrap()
.get_mut(&shorteventid)
{
return Ok(Arc::clone(id));
}
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
self.shorteventid_cache
.lock()
.unwrap()
.insert(shorteventid, Arc::clone(&event_id));
Ok(event_id)
self.db.get_eventid_from_short(shorteventid)
}
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
if let Some(id) = self
.shortstatekey_cache
.lock()
.unwrap()
.get_mut(&shortstatekey)
{
return Ok(id.clone());
}
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xff);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type =
StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
})?)
.map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
})?;
let result = (event_type, state_key);
self.shortstatekey_cache
.lock()
.unwrap()
.insert(shortstatekey, result.clone());
Ok(result)
self.db.get_statekey_from_short(shortstatekey)
}
/// Returns (shortstatehash, already_existed)
fn get_or_create_shortstatehash(
pub fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(state_hash)? {
Some(shortstatehash) => (
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
),
None => {
let shortstatehash = services().globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
}
})
self.db.get_or_create_shortstatehash(state_hash)
}
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
})
.transpose()
self.db.get_shortroomid(room_id)
}
pub fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<u64> {
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
Some(short) => utils::u64_from_bytes(&short)
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
None => {
let short = services().globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
}
})
self.db.get_or_create_shortroomid(room_id)
}
}

View file

@ -1,9 +1,10 @@
mod data;
use std::{collections::HashSet, sync::Arc};
use std::{collections::{HashSet, HashMap}, sync::Arc};
pub use data::Data;
use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType}, UserId, EventId, serde::Raw, RoomVersionId};
use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, RoomEventType}, UserId, EventId, serde::Raw, RoomVersionId, state_res::{StateMap, self}};
use serde::Deserialize;
use tokio::sync::MutexGuard;
use tracing::warn;
use crate::{Result, services, PduEvent, Error, utils::calculate_hash};
@ -11,7 +12,7 @@ use crate::{Result, services, PduEvent, Error, utils::calculate_hash};
use super::state_compressor::CompressedStateEvent;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -97,7 +98,7 @@ impl Service {
room_id: &RoomId,
state_ids_compressed: HashSet<CompressedStateEvent>,
) -> Result<u64> {
let shorteventid = services().short.get_or_create_shorteventid(event_id)?;
let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
@ -109,11 +110,11 @@ impl Service {
);
let (shortstatehash, already_existed) =
services().short.get_or_create_shortstatehash(&state_hash)?;
services().rooms.short.get_or_create_shortstatehash(&state_hash)?;
if !already_existed {
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| services().room.state_compressor.load_shortstatehash_info(p))?;
.map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
@ -132,7 +133,7 @@ impl Service {
} else {
(state_ids_compressed, HashSet::new())
};
services().room.state_compressor.save_state_from_diff(
services().rooms.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@ -141,7 +142,7 @@ impl Service {
)?;
}
self.db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
self.db.set_event_state(shorteventid, shortstatehash)?;
Ok(shortstatehash)
}
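
The state-setting code above only persists the difference against the parent state layer: compressed events present in the new snapshot but not in the parent become statediffnew, and events present in the parent but missing from the new snapshot become statediffremoved. A small self-contained sketch of that split, with plain u64s standing in for CompressedStateEvent:

use std::collections::HashSet;

/// Split a new state snapshot into (added, removed) relative to its parent layer.
fn state_diff(parent: &HashSet<u64>, new_state: &HashSet<u64>) -> (HashSet<u64>, HashSet<u64>) {
    let added = new_state.difference(parent).copied().collect();
    let removed = parent.difference(new_state).copied().collect();
    (added, removed)
}

fn main() {
    let parent: HashSet<u64> = [1, 2, 3].into();
    let new_state: HashSet<u64> = [2, 3, 4].into();

    let (added, removed) = state_diff(&parent, &new_state);
    assert_eq!(added, HashSet::from([4]));   // only event 4 is new
    assert_eq!(removed, HashSet::from([1])); // event 1 dropped out of the state
}
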
@ -155,25 +156,24 @@ impl Service {
&self,
new_pdu: &PduEvent,
) -> Result<u64> {
let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id)?;
let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
if let Some(p) = previous_shortstatehash {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?;
self.db.set_event_state(shorteventid, p)?;
}
if let Some(state_key) = &new_pdu.state_key {
let states_parents = previous_shortstatehash
.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
.map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?;
let shortstatekey = self.get_or_create_shortstatekey(
let shortstatekey = services().rooms.short.get_or_create_shortstatekey(
&new_pdu.kind.to_string().into(),
state_key,
)?;
let new = self.compress_state_event(shortstatekey, &new_pdu.event_id)?;
let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?;
let replaces = states_parents
.last()
@ -199,7 +199,7 @@ impl Service {
statediffremoved.insert(*replaces);
}
self.save_state_from_diff(
services().rooms.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@ -221,16 +221,16 @@ impl Service {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.room_state_get(
if let Some(e) = services().rooms.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomCanonicalAlias,
"",
@ -238,16 +238,16 @@ impl Service {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.room_state_get(
if let Some(e) = services().rooms.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomMember,
invite_event.sender.as_str(),
@ -260,17 +260,16 @@ impl Service {
}
#[tracing::instrument(skip(self))]
pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64,
mutex_lock: &MutexGuard<'_, ()>, // Take a mutex guard to prove the caller holds the room state mutex
) -> Result<()> {
self.db.set_room_state(room_id, shortstatehash, mutex_lock)
}
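
set_room_state now takes a &MutexGuard<'_, ()> purely as evidence that the caller holds the per-room state mutex, so concurrent writers cannot interleave between computing a new state and storing it; the guard itself is never read. A minimal sketch of that pattern using std::sync (Conduit uses tokio's async mutex, and the types below are stand-ins):

use std::collections::HashMap;
use std::sync::{Mutex, MutexGuard};

struct State {
    // The store itself is internally synchronized; the outer room mutex exists
    // to serialize whole read-modify-write sequences, not single inserts.
    room_state: Mutex<HashMap<String, u64>>,
}

impl State {
    /// The unused `_lock` parameter forces callers to acquire the room mutex first.
    fn set_room_state(
        &self,
        room_id: &str,
        shortstatehash: u64,
        _lock: &MutexGuard<'_, ()>,
    ) -> Result<(), String> {
        self.room_state
            .lock()
            .map_err(|_| "poisoned".to_owned())?
            .insert(room_id.to_owned(), shortstatehash);
        Ok(())
    }
}

fn main() {
    let state = State { room_state: Mutex::new(HashMap::new()) };
    let room_mutex = Mutex::new(()); // one of these per room in the real service

    let guard = room_mutex.lock().unwrap();
    state.set_room_state("!room:example.org", 42, &guard).unwrap();
    drop(guard); // release the room lock once the state is in place
}
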
/// Returns the room's version.
#[tracing::instrument(skip(self))]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: Option<RoomCreateEventContent> = create_event
.as_ref()
@ -294,4 +293,50 @@ impl Service {
pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
self.db.get_forward_extremities(room_id)
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self))]
pub fn get_auth_events(
&self,
room_id: &RoomId,
kind: &RoomEventType,
sender: &UserId,
state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let shortstatehash =
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
current_shortstatehash
} else {
return Ok(HashMap::new());
};
let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)
.expect("content is a valid JSON object");
let mut sauthevents = auth_events
.into_iter()
.filter_map(|(event_type, state_key)| {
services().rooms.short.get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
})
.collect::<HashMap<_, _>>();
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok())
.filter_map(|(shortstatekey, event_id)| {
sauthevents.remove(&shortstatekey).map(|k| (k, event_id))
})
.filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu)))
.collect())
}
}
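
get_auth_events works backwards from the new event: state_res::auth_types_for_event lists the (type, state_key) pairs the event needs for authorization, those pairs are resolved to numeric shortstatekeys, and the flattened room state is then walked once, keeping only entries whose short key was requested. A self-contained sketch of that final intersection step, with plain integers standing in for shortstatekeys and event ids:

use std::collections::HashMap;

/// `requested` maps shortstatekey -> the (type, state_key) pair it came from;
/// `full_state` is the flattened current state as (shortstatekey, event_id) pairs.
fn select_auth_events(
    mut requested: HashMap<u64, (String, String)>,
    full_state: &[(u64, u64)],
) -> HashMap<(String, String), u64> {
    full_state
        .iter()
        .filter_map(|&(shortstatekey, event_id)| {
            // remove() ensures every requested key is matched at most once
            requested.remove(&shortstatekey).map(|pair| (pair, event_id))
        })
        .collect()
}

fn main() {
    let mut requested = HashMap::new();
    requested.insert(1, ("m.room.create".to_owned(), String::new()));
    requested.insert(7, ("m.room.power_levels".to_owned(), String::new()));

    // Current state: create (short key 1), join rules (3), power levels (7).
    let full_state = [(1, 100), (3, 103), (7, 107)];

    let auth = select_auth_events(requested, &full_state);
    assert_eq!(auth.len(), 2); // join rules were not requested, so they are skipped
}
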


@ -7,7 +7,7 @@ use ruma::{events::StateEventType, RoomId, EventId};
use crate::{Result, PduEvent};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -45,7 +45,7 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.db.pdu_state_get(shortstatehash, event_type, state_key)
self.db.state_get(shortstatehash, event_type, state_key)
}
/// Returns the state hash for this pdu.


@ -3,12 +3,23 @@ use std::{collections::HashSet, sync::Arc};
pub use data::Data;
use regex::Regex;
use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, tag::TagEvent, RoomAccountDataEventType, GlobalAccountDataEventType, direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, AnySyncStateEvent}, serde::Raw, ServerName};
use ruma::{
events::{
direct::{DirectEvent, DirectEventContent},
ignored_user_list::IgnoredUserListEvent,
room::{create::RoomCreateEventContent, member::MembershipState},
tag::{TagEvent, TagEventContent},
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, StateEventType, RoomAccountDataEvent, RoomAccountDataEventContent,
},
serde::Raw,
RoomId, ServerName, UserId,
};
use crate::{Result, services, utils, Error};
use crate::{services, utils, Error, Result};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
@ -45,7 +56,9 @@ impl Service {
self.db.mark_as_once_joined(user_id, room_id)?;
// Check if the room has a predecessor
if let Some(predecessor) = self
if let Some(predecessor) = services()
.rooms
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok())
.and_then(|content: RoomCreateEventContent| content.predecessor)
@ -76,27 +89,41 @@ impl Service {
// .ok();
// Copy old tags to new room
if let Some(tag_event) = services().account_data.get::<TagEvent>(
Some(&predecessor.room_id),
user_id,
RoomAccountDataEventType::Tag,
)? {
services().account_data
if let Some(tag_event) = services()
.account_data
.get(
Some(&predecessor.room_id),
user_id,
RoomAccountDataEventType::Tag,
)?
.map(|event| {
serde_json::from_str(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
})
{
services()
.account_data
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
&tag_event,
&tag_event?,
)
.ok();
};
// Copy direct chat flag
if let Some(mut direct_event) = services().account_data.get::<DirectEvent>(
if let Some(mut direct_event) = services().account_data.get(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
)? {
)?
.map(|event| {
serde_json::from_str::<DirectEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
})
{
let direct_event = direct_event?;
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut() {
@ -111,7 +138,7 @@ impl Service {
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&direct_event,
&serde_json::to_value(&direct_event).expect("to json always works"),
)?;
}
};
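
The account_data service no longer deserializes events itself: get() now hands back raw JSON for the caller to parse into whatever type it needs, and update() takes an already-serialized value. A hedged round-trip sketch using only serde_json; the m.direct shape matches the code above, but the user id and helper name are made up:

use serde_json::Value;

/// Parse raw account data, rewrite one room id, and serialize it back for update().
fn copy_direct_flag(raw: &str, old_room: &str, new_room: &str) -> Result<String, String> {
    let mut event: Value = serde_json::from_str(raw)
        .map_err(|_| "Invalid account data event in db.".to_owned())?;

    // m.direct content maps a user id to a list of room ids.
    if let Some(rooms) = event
        .pointer_mut("/content/@friend:example.org")
        .and_then(Value::as_array_mut)
    {
        for room in rooms.iter_mut() {
            if room.as_str() == Some(old_room) {
                *room = Value::String(new_room.to_owned());
            }
        }
    }

    serde_json::to_string(&event).map_err(|e| e.to_string())
}

fn main() {
    let raw = r#"{"content":{"@friend:example.org":["!old:example.org"]}}"#;
    let updated = copy_direct_flag(raw, "!old:example.org", "!new:example.org")
        .expect("valid account data");
    println!("{updated}");
}
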
@ -124,13 +151,17 @@ impl Service {
// We want to know if the sender is ignored by the receiver
let is_ignored = services()
.account_data
.get::<IgnoredUserListEvent>(
.get(
None, // Ignored users are in global account data
user_id, // Receiver
GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)?
.map(|event| {
serde_json::from_str::<IgnoredUserListEvent>(event.get())
.map_err(|_| Error::bad_database("Invalid account data event in db."))
}).transpose()?
.map_or(false, |ignored| {
ignored
.content
@ -200,10 +231,7 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
pub fn get_our_real_users(
&self,
room_id: &RoomId,
) -> Result<Arc<HashSet<Box<UserId>>>> {
pub fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<Box<UserId>>>> {
let maybe = self
.our_real_users_cache
.read()


@ -9,7 +9,7 @@ use crate::{Result, utils, services};
use self::data::StateDiff;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
@ -67,7 +67,7 @@ impl Service {
) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&self
&services().rooms.short
.get_or_create_shorteventid(event_id)?
.to_be_bytes(),
);
@ -218,7 +218,7 @@ impl Service {
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>)> // removed
{
let previous_shortstatehash = self.db.current_shortstatehash(room_id)?;
let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?;
let state_hash = utils::calculate_hash(
&new_state_ids_compressed


@ -5,6 +5,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId};
use crate::{Result, PduEvent};
pub trait Data: Send + Sync {
fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>>;
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64>;
/// Returns the `count` of this pdu's id.


@ -21,33 +21,14 @@ use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduE
use super::state_compressor::CompressedStateEvent;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {
/*
/// Returns the first PDU in the room, if any.
#[tracing::instrument(skip(self))]
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
let prefix = self
.get_shortroomid(room_id)?
.expect("room exists")
.to_be_bytes()
.to_vec();
// Look for PDUs in that room.
self.pduid_pdu
.iter_from(&prefix, false)
.filter(|(k, _)| k.starts_with(&prefix))
.map(|(_, pdu)| {
serde_json::from_slice(&pdu)
.map_err(|_| Error::bad_database("Invalid first PDU in db."))
.map(Arc::new)
})
.next()
.transpose()
self.db.first_pdu_in_room(room_id)
}
*/
#[tracing::instrument(skip(self))]
pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
@ -681,7 +662,8 @@ impl Service {
/// Append the incoming event setting the state snapshot to the state from the
/// server that sent the event.
#[tracing::instrument(skip_all)]
fn append_incoming_pdu<'a>(
pub fn append_incoming_pdu<'a>(
&self,
pdu: &PduEvent,
pdu_json: CanonicalJsonObject,
new_room_leaves: impl IntoIterator<Item = &'a EventId> + Clone + Debug,


@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{RoomId, UserId};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {


@ -448,14 +448,6 @@ impl Service {
Ok(())
}
#[tracing::instrument(skip(keys))]
fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
// We only hash the pdu's event ids, not the whole pdu
let bytes = keys.join(&0xff);
let hash = digest::digest(&digest::SHA256, &bytes);
hash.as_ref().to_owned()
}
/// Cleanup event data
/// Used for instance after we remove an appservice registration
///


@ -1,11 +1,13 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{UserId, DeviceId, TransactionId};
use crate::Result;
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {


@ -1,4 +1,6 @@
mod data;
use std::sync::Arc;
pub use data::Data;
use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType, IncomingUserIdentifier}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue};
@ -7,7 +9,7 @@ use tracing::error;
use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {


@ -1,5 +1,5 @@
mod data;
use std::{collections::BTreeMap, mem};
use std::{collections::BTreeMap, mem, sync::Arc};
pub use data::Data;
use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition, error::ErrorKind}, RoomAliasId};
@ -7,7 +7,7 @@ use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTi
use crate::{Result, Error, services};
pub struct Service {
db: Box<dyn Data>,
db: Arc<dyn Data>,
}
impl Service {


@ -3,6 +3,7 @@ pub mod error;
use argon2::{Config, Variant};
use cmp::Ordering;
use rand::prelude::*;
use ring::digest;
use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject};
use std::{
cmp, fmt,
@ -59,7 +60,7 @@ pub fn random_string(length: usize) -> String {
}
/// Calculate a new hash for the given password
pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
pub fn calculate_password_hash(password: &str) -> Result<String, argon2::Error> {
let hashing_config = Config {
variant: Variant::Argon2id,
..Default::default()
@ -69,6 +70,15 @@ pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config)
}
#[tracing::instrument(skip(keys))]
pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
// We only hash the pdu's event ids, not the whole pdu
let bytes = keys.join(&0xff);
let hash = digest::digest(&digest::SHA256, &bytes);
hash.as_ref().to_owned()
}
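
calculate_hash is moved into utils (its removal from a service appears earlier in this diff), and the old password helper is renamed to calculate_password_hash so the name is free. A short usage sketch, assuming the ring crate as in the function above; the event ids are made up:

use ring::digest;

/// Same shape as utils::calculate_hash: join the keys with a 0xff separator
/// and return the SHA-256 digest of the result.
fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
    let bytes = keys.join(&0xff);
    digest::digest(&digest::SHA256, &bytes).as_ref().to_owned()
}

fn main() {
    // Hash only the event ids of the PDUs, not the whole PDUs.
    let event_ids: Vec<&[u8]> = vec![
        b"$event_one:example.org".as_slice(),
        b"$event_two:example.org".as_slice(),
    ];
    let hash = calculate_hash(&event_ids);
    assert_eq!(hash.len(), 32); // SHA-256 digests are 32 bytes long
}
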
pub fn common_elements(
mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>,
check_order: impl Fn(&[u8], &[u8]) -> Ordering,