use anyhow::Context as _;
use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use mas_storage::{
job::{DeactivateUserJob, JobRepositoryExt, ProvisionUserJob},
user::UserRepository,
};
use tracing::{info, warn};
use zeroize::Zeroizing;
use crate::graphql::{
model::{NodeType, User},
state::ContextExt,
Requester, UserId,
};
/// Mutations related to users, exposed through the GraphQL `#[Object]`
/// implementation of this type.
#[derive(Default)]
pub struct UserMutations {
    // Private unit field: outside this module the struct can only be obtained
    // through its `Default` implementation.
    _private: (),
}
/// The input for the `addUser` mutation.
#[derive(InputObject)]
struct AddUserInput {
    /// The username (localpart) of the user to add.
    username: String,

    /// Create the user even if the homeserver reports the username as
    /// unavailable; a warning is logged in that case. Defaults to `false`.
    skip_homeserver_check: Option<bool>,
}
/// The status of the `addUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum AddUserStatus {
    /// The user was added.
    Added,

    /// A user with this username already exists.
    Exists,

    /// The homeserver reported the username as unavailable.
    Reserved,

    /// The username did not pass local validation.
    Invalid,
}
/// The payload for the `addUser` mutation.
#[derive(Description)]
enum AddUserPayload {
    /// The user was created; carries the new user.
    Added(mas_data_model::User),
    /// The username was already taken; carries the existing user.
    Exists(mas_data_model::User),
    /// The homeserver reported the username as unavailable.
    Reserved,
    /// The username did not pass local validation.
    Invalid,
}
#[Object(use_type_description)]
impl AddUserPayload {
async fn status(&self) -> AddUserStatus {
match self {
Self::Added(_) => AddUserStatus::Added,
Self::Exists(_) => AddUserStatus::Exists,
Self::Reserved => AddUserStatus::Reserved,
Self::Invalid => AddUserStatus::Invalid,
}
}
async fn user(&self) -> Option<User> {
match self {
Self::Added(user) | Self::Exists(user) => Some(User(user.clone())),
Self::Invalid | Self::Reserved => None,
}
}
}
/// The input for the `lockUser` mutation.
#[derive(InputObject)]
struct LockUserInput {
    /// The ID of the user to lock.
    user_id: ID,

    /// Also schedule a deactivation job for the user. Defaults to `false`.
    deactivate: Option<bool>,
}
/// The status of the `lockUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum LockUserStatus {
    /// The user was locked.
    Locked,

    /// No user matched the given ID.
    NotFound,
}
/// The payload for the `lockUser` mutation.
#[derive(Description)]
enum LockUserPayload {
    /// The user was locked; carries the updated user.
    Locked(mas_data_model::User),
    /// No user matched the given ID.
    NotFound,
}
#[Object(use_type_description)]
impl LockUserPayload {
async fn status(&self) -> LockUserStatus {
match self {
Self::Locked(_) => LockUserStatus::Locked,
Self::NotFound => LockUserStatus::NotFound,
}
}
async fn user(&self) -> Option<User> {
match self {
Self::Locked(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `unlockUser` mutation.
#[derive(InputObject)]
struct UnlockUserInput {
    /// The ID of the user to unlock.
    user_id: ID,
}
/// The status of the `unlockUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum UnlockUserStatus {
    /// The user was unlocked.
    Unlocked,

    /// No user matched the given ID.
    NotFound,
}
/// The payload for the `unlockUser` mutation.
#[derive(Description)]
enum UnlockUserPayload {
    /// The user was unlocked; carries the updated user.
    Unlocked(mas_data_model::User),
    /// No user matched the given ID.
    NotFound,
}
#[Object(use_type_description)]
impl UnlockUserPayload {
async fn status(&self) -> UnlockUserStatus {
match self {
Self::Unlocked(_) => UnlockUserStatus::Unlocked,
Self::NotFound => UnlockUserStatus::NotFound,
}
}
async fn user(&self) -> Option<User> {
match self {
Self::Unlocked(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `setCanRequestAdmin` mutation.
#[derive(InputObject)]
struct SetCanRequestAdminInput {
    /// The ID of the user to update.
    user_id: ID,

    /// Whether the user should be able to request admin.
    can_request_admin: bool,
}
/// The payload for the `setCanRequestAdmin` mutation.
#[derive(Description)]
enum SetCanRequestAdminPayload {
    /// The flag was updated; carries the updated user.
    Updated(mas_data_model::User),
    /// No user matched the given ID.
    NotFound,
}
#[Object(use_type_description)]
impl SetCanRequestAdminPayload {
async fn user(&self) -> Option<User> {
match self {
Self::Updated(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `allowUserCrossSigningReset` mutation.
#[derive(InputObject)]
struct AllowUserCrossSigningResetInput {
    /// The ID of the user for whom the cross-signing reset is allowed.
    user_id: ID,
}
/// The payload for the `allowUserCrossSigningReset` mutation.
#[derive(Description)]
enum AllowUserCrossSigningResetPayload {
    /// The homeserver accepted the request; carries the user.
    Allowed(mas_data_model::User),
    /// No user matched the given ID.
    NotFound,
}
#[Object(use_type_description)]
impl AllowUserCrossSigningResetPayload {
async fn user(&self) -> Option<User> {
match self {
Self::Allowed(user) => Some(User(user.clone())),
Self::NotFound => None,
}
}
}
/// The input for the `setPassword` mutation.
#[derive(InputObject)]
struct SetPasswordInput {
    /// The ID of the user whose password is being set.
    user_id: ID,

    /// The user's current password. Required when the requester is not an
    /// administrator; not checked for administrators.
    current_password: Option<String>,

    /// The new password to set.
    new_password: String,
}
/// The input for the `setPasswordByRecovery` mutation.
#[derive(InputObject)]
struct SetPasswordByRecoveryInput {
    /// The recovery ticket, used to find the recovery session and email.
    ticket: String,

    /// The new password to set.
    new_password: String,
}
/// The payload for the `setPassword` and `setPasswordByRecovery` mutations.
#[derive(Description)]
struct SetPasswordPayload {
    // Outcome of the mutation, exposed through the `status` field.
    status: SetPasswordStatus,
}
/// The status of the `setPassword` and `setPasswordByRecovery` mutations.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
enum SetPasswordStatus {
    /// The password was updated.
    Allowed,

    /// No user matched the given ID.
    NotFound,

    /// The user has no current password, so it cannot be verified.
    NoCurrentPassword,

    /// The supplied current password was wrong.
    WrongPassword,

    /// The new password was empty or not complex enough.
    InvalidNewPassword,

    // NOTE(review): not constructed by any mutation visible in this file;
    // presumably "the requester may not change this password" — confirm.
    NotAllowed,

    /// Password changes (or account recovery) are disabled on this server.
    PasswordChangesDisabled,

    /// The recovery ticket does not exist.
    NoSuchRecoveryTicket,

    /// The recovery session for this ticket was already consumed.
    RecoveryTicketAlreadyUsed,

    /// The recovery ticket is no longer active.
    ExpiredRecoveryTicket,

    /// The account reached through the recovery ticket is locked or
    /// otherwise not valid.
    AccountLocked,
}
#[Object(use_type_description)]
impl SetPasswordPayload {
    /// Status of the operation
    async fn status(&self) -> SetPasswordStatus {
        let Self { status } = self;
        *status
    }
}
/// Returns whether `c` may appear in a username accepted by `username_valid`.
///
/// Accepted characters are ASCII lowercase letters, ASCII digits and the
/// symbols `=`, `_`, `-`, `.`, `/` and `+`.
fn valid_username_character(c: char) -> bool {
    const EXTRA_SYMBOLS: &[char] = &['=', '_', '-', '.', '/', '+'];

    c.is_ascii_lowercase() || c.is_ascii_digit() || EXTRA_SYMBOLS.contains(&c)
}
/// Performs basic local validation of a username (localpart) before a user
/// is created.
///
/// A username is accepted when it:
/// - is between 1 and 255 bytes long,
/// - does not start with an underscore,
/// - only contains characters accepted by `valid_username_character`.
fn username_valid(username: &str) -> bool {
    let length_in_range = (1..=255).contains(&username.len());
    let allowed_prefix = !username.starts_with('_');

    // `&&` short-circuits, so the per-character scan only runs when the
    // cheaper checks already passed.
    length_in_range && allowed_prefix && username.chars().all(valid_username_character)
}
#[Object]
impl UserMutations {
    /// Add a user. This is only available to administrators.
    async fn add_user(
        &self,
        ctx: &Context<'_>,
        input: AddUserInput,
    ) -> Result<AddUserPayload, async_graphql::Error> {
        let state = ctx.state();
        let requester = ctx.requester();
        let clock = state.clock();
        let mut rng = state.rng();

        if !requester.is_admin() {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        let mut repo = state.repository().await?;

        // An existing user is reported before validation, so a taken username
        // yields `Exists` even if it would not pass the checks below.
        if let Some(user) = repo.user().find_by_username(&input.username).await? {
            return Ok(AddUserPayload::Exists(user));
        }

        // Basic local checks on the username
        if !username_valid(&input.username) {
            return Ok(AddUserPayload::Invalid);
        }

        // The homeserver is always asked; `skip_homeserver_check` only decides
        // whether a negative answer aborts the creation.
        let homeserver_available = state
            .homeserver_connection()
            .is_localpart_available(&input.username)
            .await?;

        if !homeserver_available {
            if !input.skip_homeserver_check.unwrap_or(false) {
                return Ok(AddUserPayload::Reserved);
            }

            // Proceeding anyway: leave a trace that the check was overridden.
            warn!("Skipped homeserver check for username {}", input.username);
        }

        let user = repo.user().add(&mut rng, &clock, input.username).await?;

        // Provisioning is deferred to a scheduled job, saved together with
        // the new user below.
        repo.job()
            .schedule_job(ProvisionUserJob::new(&user))
            .await?;

        repo.save().await?;

        Ok(AddUserPayload::Added(user))
    }

    /// Lock a user. This is only available to administrators.
    async fn lock_user(
        &self,
        ctx: &Context<'_>,
        input: LockUserInput,
    ) -> Result<LockUserPayload, async_graphql::Error> {
        let state = ctx.state();
        let requester = ctx.requester();

        if !requester.is_admin() {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        let mut repo = state.repository().await?;

        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
        let user = repo.user().lookup(user_id).await?;

        let Some(user) = user else {
            return Ok(LockUserPayload::NotFound);
        };

        let deactivate = input.deactivate.unwrap_or(false);

        let user = repo.user().lock(&state.clock(), user).await?;

        if deactivate {
            info!("Scheduling deactivation of user {}", user.id);
            // The job is persisted by the same `repo.save()` as the lock.
            repo.job()
                .schedule_job(DeactivateUserJob::new(&user, deactivate))
                .await?;
        }

        repo.save().await?;

        Ok(LockUserPayload::Locked(user))
    }

    /// Unlock a user. This is only available to administrators.
    async fn unlock_user(
        &self,
        ctx: &Context<'_>,
        input: UnlockUserInput,
    ) -> Result<UnlockUserPayload, async_graphql::Error> {
        let state = ctx.state();
        let requester = ctx.requester();
        let matrix = state.homeserver_connection();

        if !requester.is_admin() {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        let mut repo = state.repository().await?;
        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
        let user = repo.user().lookup(user_id).await?;

        let Some(user) = user else {
            return Ok(UnlockUserPayload::NotFound);
        };

        // Reactivate on the homeserver first: if that call fails we return
        // before touching the database, so the user stays locked here.
        let mxid = matrix.mxid(&user.username);
        matrix.reactivate_user(&mxid).await?;

        let user = repo.user().unlock(user).await?;
        repo.save().await?;

        Ok(UnlockUserPayload::Unlocked(user))
    }

    /// Set whether a user can request admin. This is only available to
    /// administrators.
    async fn set_can_request_admin(
        &self,
        ctx: &Context<'_>,
        input: SetCanRequestAdminInput,
    ) -> Result<SetCanRequestAdminPayload, async_graphql::Error> {
        let state = ctx.state();
        let requester = ctx.requester();

        if !requester.is_admin() {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        let mut repo = state.repository().await?;
        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
        let user = repo.user().lookup(user_id).await?;

        let Some(user) = user else {
            return Ok(SetCanRequestAdminPayload::NotFound);
        };

        let user = repo
            .user()
            .set_can_request_admin(user, input.can_request_admin)
            .await?;

        repo.save().await?;

        Ok(SetCanRequestAdminPayload::Updated(user))
    }

    /// Allow a user to reset their cross-signing keys. Available to the user
    /// themselves or to an administrator.
    async fn allow_user_cross_signing_reset(
        &self,
        ctx: &Context<'_>,
        input: AllowUserCrossSigningResetInput,
    ) -> Result<AllowUserCrossSigningResetPayload, async_graphql::Error> {
        let state = ctx.state();
        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
        let requester = ctx.requester();

        if !requester.is_owner_or_admin(&UserId(user_id)) {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        // Nothing is written locally: cancel the repository right after the
        // lookup so it is released before the homeserver call.
        let mut repo = state.repository().await?;
        let user = repo.user().lookup(user_id).await?;
        repo.cancel().await?;

        let Some(user) = user else {
            return Ok(AllowUserCrossSigningResetPayload::NotFound);
        };

        let conn = state.homeserver_connection();
        let mxid = conn.mxid(&user.username);

        conn.allow_cross_signing_reset(&mxid)
            .await
            .context("Failed to allow cross-signing reset")?;

        Ok(AllowUserCrossSigningResetPayload::Allowed(user))
    }

    /// Set the password for a user.
    ///
    /// Administrators can set any user's password. Other users can only change
    /// their own password, when password changes are allowed, and must supply
    /// their current password.
    async fn set_password(
        &self,
        ctx: &Context<'_>,
        input: SetPasswordInput,
    ) -> Result<SetPasswordPayload, async_graphql::Error> {
        let state = ctx.state();
        let user_id = NodeType::User.extract_ulid(&input.user_id)?;
        let requester = ctx.requester();

        if !requester.is_owner_or_admin(&UserId(user_id)) {
            return Err(async_graphql::Error::new("Unauthorized"));
        }

        if input.new_password.is_empty() {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::InvalidNewPassword,
            });
        }

        let password_manager = state.password_manager();

        if !password_manager.is_enabled() {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::PasswordChangesDisabled,
            });
        }

        if !password_manager.is_password_complex_enough(&input.new_password)? {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::InvalidNewPassword,
            });
        }

        let mut repo = state.repository().await?;
        let Some(user) = repo.user().lookup(user_id).await? else {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::NotFound,
            });
        };

        if !requester.is_admin() {
            // Non-administrators additionally need password changes to be
            // allowed by the site config and must prove knowledge of their
            // current password.
            if !state.site_config().password_change_allowed {
                return Ok(SetPasswordPayload {
                    status: SetPasswordStatus::PasswordChangesDisabled,
                });
            }

            let Some(active_password) = repo.user_password().active(&user).await? else {
                // Without a current password there is nothing to verify against.
                return Ok(SetPasswordPayload {
                    status: SetPasswordStatus::NoCurrentPassword,
                });
            };

            let Some(current_password_attempt) = input.current_password else {
                return Err(async_graphql::Error::new(
                    "You must supply `currentPassword` to change your own password if you are not an administrator"
                ));
            };

            if let Err(_err) = password_manager
                .verify(
                    active_password.version,
                    Zeroizing::new(current_password_attempt.into_bytes()),
                    active_password.hashed_password,
                )
                .await
            {
                return Ok(SetPasswordPayload {
                    status: SetPasswordStatus::WrongPassword,
                });
            }
        }

        let (new_password_version, new_password_hash) = password_manager
            .hash(state.rng(), Zeroizing::new(input.new_password.into_bytes()))
            .await?;

        repo.user_password()
            .add(
                &mut state.rng(),
                &state.clock(),
                &user,
                new_password_version,
                new_password_hash,
                None,
            )
            .await?;

        repo.save().await?;

        Ok(SetPasswordPayload {
            status: SetPasswordStatus::Allowed,
        })
    }

    /// Set the password for yourself, using a recovery ticket sent by email.
    /// Only available to anonymous (not logged-in) requesters.
    async fn set_password_by_recovery(
        &self,
        ctx: &Context<'_>,
        input: SetPasswordByRecoveryInput,
    ) -> Result<SetPasswordPayload, async_graphql::Error> {
        let state = ctx.state();
        let requester = ctx.requester();
        let clock = state.clock();
        if !matches!(requester, Requester::Anonymous) {
            return Err(async_graphql::Error::new(
                "Account recovery is only for anonymous users.",
            ));
        }

        let password_manager = state.password_manager();

        if !password_manager.is_enabled() || !state.site_config().account_recovery_allowed {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::PasswordChangesDisabled,
            });
        }

        // Reject empty passwords explicitly, exactly like `set_password` does,
        // instead of relying on the complexity policy alone: a permissive
        // policy must never let account recovery store an empty password.
        if input.new_password.is_empty() {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::InvalidNewPassword,
            });
        }

        if !password_manager.is_password_complex_enough(&input.new_password)? {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::InvalidNewPassword,
            });
        }

        let mut repo = state.repository().await?;

        let Some(ticket) = repo.user_recovery().find_ticket(&input.ticket).await? else {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::NoSuchRecoveryTicket,
            });
        };

        let session = repo
            .user_recovery()
            .lookup_session(ticket.user_recovery_session_id)
            .await?
            .context("Unknown session")?;

        if session.consumed_at.is_some() {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::RecoveryTicketAlreadyUsed,
            });
        }

        if !ticket.active(clock.now()) {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::ExpiredRecoveryTicket,
            });
        }

        let user_email = repo
            .user_email()
            .lookup(ticket.user_email_id)
            .await?
            // Only confirmed email addresses may be used for recovery
            .filter(|email| email.confirmed_at.is_some())
            .context("Unknown email address")?;

        let user = repo
            .user()
            .lookup(user_email.user_id)
            .await?
            .context("Invalid user")?;

        if !user.is_valid() {
            return Ok(SetPasswordPayload {
                status: SetPasswordStatus::AccountLocked,
            });
        }

        let (new_password_version, new_password_hash) = password_manager
            .hash(state.rng(), Zeroizing::new(input.new_password.into_bytes()))
            .await?;

        repo.user_password()
            .add(
                &mut state.rng(),
                &state.clock(),
                &user,
                new_password_version,
                new_password_hash,
                None,
            )
            .await?;

        // Consuming the ticket in the same save as the password change means
        // the ticket cannot be reused once the password is updated.
        repo.user_recovery()
            .consume_ticket(&clock, ticket, session)
            .await?;

        repo.save().await?;

        Ok(SetPasswordPayload {
            status: SetPasswordStatus::Allowed,
        })
    }
}