Improve sso auth flow (#6205)

Co-authored-by: Timshel <timshel@users.noreply.github.com>
This commit is contained in:
Timshel
2025-12-06 22:20:04 +01:00
committed by GitHub
parent 2d91a9460b
commit 8f689d8795
17 changed files with 449 additions and 295 deletions

View File

@@ -337,6 +337,46 @@ macro_rules! db_run {
};
}
// Write all ToSql<Text, DB> and FromSql<Text, DB> given a serializable/deserializable type.
#[macro_export]
macro_rules! impl_FromToSqlText {
($name:ty) => {
#[cfg(mysql)]
impl ToSql<Text, diesel::mysql::Mysql> for $name {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::mysql::Mysql>) -> diesel::serialize::Result {
serde_json::to_writer(out, self).map(|_| diesel::serialize::IsNull::No).map_err(Into::into)
}
}
#[cfg(postgresql)]
impl ToSql<Text, diesel::pg::Pg> for $name {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::pg::Pg>) -> diesel::serialize::Result {
serde_json::to_writer(out, self).map(|_| diesel::serialize::IsNull::No).map_err(Into::into)
}
}
#[cfg(sqlite)]
impl ToSql<Text, diesel::sqlite::Sqlite> for $name {
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::sqlite::Sqlite>) -> diesel::serialize::Result {
serde_json::to_string(self).map_err(Into::into).map(|str| {
out.set_value(str);
diesel::serialize::IsNull::No
})
}
}
impl<DB: diesel::backend::Backend> FromSql<Text, DB> for $name
where
String: FromSql<Text, DB>,
{
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
<String as FromSql<Text, DB>>::from_sql(bytes)
.and_then(|str| serde_json::from_str(&str).map_err(Into::into))
}
}
};
}
pub mod schema;
// Reexport the models, needs to be after the macros are defined so it can access them

View File

@@ -11,7 +11,7 @@ mod group;
mod org_policy;
mod organization;
mod send;
mod sso_nonce;
mod sso_auth;
mod two_factor;
mod two_factor_duo_context;
mod two_factor_incomplete;
@@ -36,7 +36,7 @@ pub use self::send::{
id::{SendFileId, SendId},
Send, SendType,
};
pub use self::sso_nonce::SsoNonce;
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth};
pub use self::two_factor::{TwoFactor, TwoFactorType};
pub use self::two_factor_duo_context::TwoFactorDuoContext;
pub use self::two_factor_incomplete::TwoFactorIncomplete;

134
src/db/models/sso_auth.rs Normal file
View File

@@ -0,0 +1,134 @@
use chrono::{NaiveDateTime, Utc};
use std::time::Duration;
use crate::api::EmptyResult;
use crate::db::schema::sso_auth;
use crate::db::{DbConn, DbPool};
use crate::error::MapResult;
use crate::sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION};
use diesel::deserialize::FromSql;
use diesel::expression::AsExpression;
use diesel::prelude::*;
use diesel::serialize::{Output, ToSql};
use diesel::sql_types::Text;
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
pub enum OIDCCodeWrapper {
Ok {
code: OIDCCode,
},
Error {
error: String,
error_description: Option<String>,
},
}
impl_FromToSqlText!(OIDCCodeWrapper);
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
#[diesel(sql_type = Text)]
pub struct OIDCAuthenticatedUser {
pub refresh_token: Option<String>,
pub access_token: String,
pub expires_in: Option<Duration>,
pub identifier: OIDCIdentifier,
pub email: String,
pub email_verified: Option<bool>,
pub user_name: Option<String>,
}
impl_FromToSqlText!(OIDCAuthenticatedUser);
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
#[diesel(table_name = sso_auth)]
#[diesel(treat_none_as_null = true)]
#[diesel(primary_key(state))]
pub struct SsoAuth {
pub state: OIDCState,
pub client_challenge: OIDCCodeChallenge,
pub nonce: String,
pub redirect_uri: String,
pub code_response: Option<OIDCCodeWrapper>,
pub auth_response: Option<OIDCAuthenticatedUser>,
pub created_at: NaiveDateTime,
pub updated_at: NaiveDateTime,
}
/// Local methods
impl SsoAuth {
pub fn new(state: OIDCState, client_challenge: OIDCCodeChallenge, nonce: String, redirect_uri: String) -> Self {
let now = Utc::now().naive_utc();
SsoAuth {
state,
client_challenge,
nonce,
redirect_uri,
created_at: now,
updated_at: now,
code_response: None,
auth_response: None,
}
}
}
/// Database methods
impl SsoAuth {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
mysql {
diesel::insert_into(sso_auth::table)
.values(self)
.on_conflict(diesel::dsl::DuplicatedKeys)
.do_update()
.set(self)
.execute(conn)
.map_res("Error saving SSO auth")
}
postgresql, sqlite {
diesel::insert_into(sso_auth::table)
.values(self)
.on_conflict(sso_auth::state)
.do_update()
.set(self)
.execute(conn)
.map_res("Error saving SSO auth")
}
}
}
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
sso_auth::table
.filter(sso_auth::state.eq(state))
.filter(sso_auth::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
}}
}
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
db_run! {conn: {
diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state)))
.execute(conn)
.map_res("Error deleting sso_auth")
}}
}
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
debug!("Purging expired sso_auth");
if let Ok(conn) = pool.get().await {
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
db_run! { conn: {
diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest)))
.execute(conn)
.map_res("Error deleting expired SSO nonce")
}}
} else {
err!("Failed to get DB connection while purging expired sso_auth")
}
}
}

View File

@@ -1,87 +0,0 @@
use chrono::{NaiveDateTime, Utc};
use crate::api::EmptyResult;
use crate::db::schema::sso_nonce;
use crate::db::{DbConn, DbPool};
use crate::error::MapResult;
use crate::sso::{OIDCState, NONCE_EXPIRATION};
use diesel::prelude::*;
#[derive(Identifiable, Queryable, Insertable)]
#[diesel(table_name = sso_nonce)]
#[diesel(primary_key(state))]
pub struct SsoNonce {
pub state: OIDCState,
pub nonce: String,
pub verifier: Option<String>,
pub redirect_uri: String,
pub created_at: NaiveDateTime,
}
/// Local methods
impl SsoNonce {
pub fn new(state: OIDCState, nonce: String, verifier: Option<String>, redirect_uri: String) -> Self {
let now = Utc::now().naive_utc();
SsoNonce {
state,
nonce,
verifier,
redirect_uri,
created_at: now,
}
}
}
/// Database methods
impl SsoNonce {
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
db_run! { conn:
sqlite, mysql {
diesel::replace_into(sso_nonce::table)
.values(self)
.execute(conn)
.map_res("Error saving SSO nonce")
}
postgresql {
diesel::insert_into(sso_nonce::table)
.values(self)
.execute(conn)
.map_res("Error saving SSO nonce")
}
}
}
pub async fn delete(state: &OIDCState, conn: &DbConn) -> EmptyResult {
db_run! { conn: {
diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state)))
.execute(conn)
.map_res("Error deleting SSO nonce")
}}
}
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
let oldest = Utc::now().naive_utc() - *NONCE_EXPIRATION;
db_run! { conn: {
sso_nonce::table
.filter(sso_nonce::state.eq(state))
.filter(sso_nonce::created_at.ge(oldest))
.first::<Self>(conn)
.ok()
}}
}
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
debug!("Purging expired sso_nonce");
if let Ok(conn) = pool.get().await {
let oldest = Utc::now().naive_utc() - *NONCE_EXPIRATION;
db_run! { conn: {
diesel::delete(sso_nonce::table.filter(sso_nonce::created_at.lt(oldest)))
.execute(conn)
.map_res("Error deleting expired SSO nonce")
}}
} else {
err!("Failed to get DB connection while purging expired sso_nonce")
}
}
}

View File

@@ -256,12 +256,15 @@ table! {
}
table! {
sso_nonce (state) {
sso_auth (state) {
state -> Text,
client_challenge -> Text,
nonce -> Text,
verifier -> Nullable<Text>,
redirect_uri -> Text,
code_response -> Nullable<Text>,
auth_response -> Nullable<Text>,
created_at -> Timestamp,
updated_at -> Timestamp,
}
}