mirror of
https://github.com/dani-garcia/vaultwarden.git
synced 2026-01-16 20:50:33 +00:00
Improve sso auth flow (#6205)
Co-authored-by: Timshel <timshel@users.noreply.github.com>
This commit is contained in:
@@ -337,6 +337,46 @@ macro_rules! db_run {
|
||||
};
|
||||
}
|
||||
|
||||
// Write all ToSql<Text, DB> and FromSql<Text, DB> given a serializable/deserializable type.
|
||||
#[macro_export]
|
||||
macro_rules! impl_FromToSqlText {
|
||||
($name:ty) => {
|
||||
#[cfg(mysql)]
|
||||
impl ToSql<Text, diesel::mysql::Mysql> for $name {
|
||||
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::mysql::Mysql>) -> diesel::serialize::Result {
|
||||
serde_json::to_writer(out, self).map(|_| diesel::serialize::IsNull::No).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(postgresql)]
|
||||
impl ToSql<Text, diesel::pg::Pg> for $name {
|
||||
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::pg::Pg>) -> diesel::serialize::Result {
|
||||
serde_json::to_writer(out, self).map(|_| diesel::serialize::IsNull::No).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(sqlite)]
|
||||
impl ToSql<Text, diesel::sqlite::Sqlite> for $name {
|
||||
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, diesel::sqlite::Sqlite>) -> diesel::serialize::Result {
|
||||
serde_json::to_string(self).map_err(Into::into).map(|str| {
|
||||
out.set_value(str);
|
||||
diesel::serialize::IsNull::No
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<DB: diesel::backend::Backend> FromSql<Text, DB> for $name
|
||||
where
|
||||
String: FromSql<Text, DB>,
|
||||
{
|
||||
fn from_sql(bytes: DB::RawValue<'_>) -> diesel::deserialize::Result<Self> {
|
||||
<String as FromSql<Text, DB>>::from_sql(bytes)
|
||||
.and_then(|str| serde_json::from_str(&str).map_err(Into::into))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub mod schema;
|
||||
|
||||
// Reexport the models, needs to be after the macros are defined so it can access them
|
||||
|
||||
@@ -11,7 +11,7 @@ mod group;
|
||||
mod org_policy;
|
||||
mod organization;
|
||||
mod send;
|
||||
mod sso_nonce;
|
||||
mod sso_auth;
|
||||
mod two_factor;
|
||||
mod two_factor_duo_context;
|
||||
mod two_factor_incomplete;
|
||||
@@ -36,7 +36,7 @@ pub use self::send::{
|
||||
id::{SendFileId, SendId},
|
||||
Send, SendType,
|
||||
};
|
||||
pub use self::sso_nonce::SsoNonce;
|
||||
pub use self::sso_auth::{OIDCAuthenticatedUser, OIDCCodeWrapper, SsoAuth};
|
||||
pub use self::two_factor::{TwoFactor, TwoFactorType};
|
||||
pub use self::two_factor_duo_context::TwoFactorDuoContext;
|
||||
pub use self::two_factor_incomplete::TwoFactorIncomplete;
|
||||
|
||||
134
src/db/models/sso_auth.rs
Normal file
134
src/db/models/sso_auth.rs
Normal file
@@ -0,0 +1,134 @@
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::api::EmptyResult;
|
||||
use crate::db::schema::sso_auth;
|
||||
use crate::db::{DbConn, DbPool};
|
||||
use crate::error::MapResult;
|
||||
use crate::sso::{OIDCCode, OIDCCodeChallenge, OIDCIdentifier, OIDCState, SSO_AUTH_EXPIRATION};
|
||||
|
||||
use diesel::deserialize::FromSql;
|
||||
use diesel::expression::AsExpression;
|
||||
use diesel::prelude::*;
|
||||
use diesel::serialize::{Output, ToSql};
|
||||
use diesel::sql_types::Text;
|
||||
|
||||
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
|
||||
#[diesel(sql_type = Text)]
|
||||
pub enum OIDCCodeWrapper {
|
||||
Ok {
|
||||
code: OIDCCode,
|
||||
},
|
||||
Error {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl_FromToSqlText!(OIDCCodeWrapper);
|
||||
|
||||
#[derive(AsExpression, Clone, Debug, Serialize, Deserialize, FromSqlRow)]
|
||||
#[diesel(sql_type = Text)]
|
||||
pub struct OIDCAuthenticatedUser {
|
||||
pub refresh_token: Option<String>,
|
||||
pub access_token: String,
|
||||
pub expires_in: Option<Duration>,
|
||||
pub identifier: OIDCIdentifier,
|
||||
pub email: String,
|
||||
pub email_verified: Option<bool>,
|
||||
pub user_name: Option<String>,
|
||||
}
|
||||
|
||||
impl_FromToSqlText!(OIDCAuthenticatedUser);
|
||||
|
||||
#[derive(Identifiable, Queryable, Insertable, AsChangeset, Selectable)]
|
||||
#[diesel(table_name = sso_auth)]
|
||||
#[diesel(treat_none_as_null = true)]
|
||||
#[diesel(primary_key(state))]
|
||||
pub struct SsoAuth {
|
||||
pub state: OIDCState,
|
||||
pub client_challenge: OIDCCodeChallenge,
|
||||
pub nonce: String,
|
||||
pub redirect_uri: String,
|
||||
pub code_response: Option<OIDCCodeWrapper>,
|
||||
pub auth_response: Option<OIDCAuthenticatedUser>,
|
||||
pub created_at: NaiveDateTime,
|
||||
pub updated_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
/// Local methods
|
||||
impl SsoAuth {
|
||||
pub fn new(state: OIDCState, client_challenge: OIDCCodeChallenge, nonce: String, redirect_uri: String) -> Self {
|
||||
let now = Utc::now().naive_utc();
|
||||
|
||||
SsoAuth {
|
||||
state,
|
||||
client_challenge,
|
||||
nonce,
|
||||
redirect_uri,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
code_response: None,
|
||||
auth_response: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Database methods
|
||||
impl SsoAuth {
|
||||
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
|
||||
db_run! { conn:
|
||||
mysql {
|
||||
diesel::insert_into(sso_auth::table)
|
||||
.values(self)
|
||||
.on_conflict(diesel::dsl::DuplicatedKeys)
|
||||
.do_update()
|
||||
.set(self)
|
||||
.execute(conn)
|
||||
.map_res("Error saving SSO auth")
|
||||
}
|
||||
postgresql, sqlite {
|
||||
diesel::insert_into(sso_auth::table)
|
||||
.values(self)
|
||||
.on_conflict(sso_auth::state)
|
||||
.do_update()
|
||||
.set(self)
|
||||
.execute(conn)
|
||||
.map_res("Error saving SSO auth")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
|
||||
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
|
||||
db_run! { conn: {
|
||||
sso_auth::table
|
||||
.filter(sso_auth::state.eq(state))
|
||||
.filter(sso_auth::created_at.ge(oldest))
|
||||
.first::<Self>(conn)
|
||||
.ok()
|
||||
}}
|
||||
}
|
||||
|
||||
pub async fn delete(self, conn: &DbConn) -> EmptyResult {
|
||||
db_run! {conn: {
|
||||
diesel::delete(sso_auth::table.filter(sso_auth::state.eq(self.state)))
|
||||
.execute(conn)
|
||||
.map_res("Error deleting sso_auth")
|
||||
}}
|
||||
}
|
||||
|
||||
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
|
||||
debug!("Purging expired sso_auth");
|
||||
if let Ok(conn) = pool.get().await {
|
||||
let oldest = Utc::now().naive_utc() - *SSO_AUTH_EXPIRATION;
|
||||
db_run! { conn: {
|
||||
diesel::delete(sso_auth::table.filter(sso_auth::created_at.lt(oldest)))
|
||||
.execute(conn)
|
||||
.map_res("Error deleting expired SSO nonce")
|
||||
}}
|
||||
} else {
|
||||
err!("Failed to get DB connection while purging expired sso_auth")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,87 +0,0 @@
|
||||
use chrono::{NaiveDateTime, Utc};
|
||||
|
||||
use crate::api::EmptyResult;
|
||||
use crate::db::schema::sso_nonce;
|
||||
use crate::db::{DbConn, DbPool};
|
||||
use crate::error::MapResult;
|
||||
use crate::sso::{OIDCState, NONCE_EXPIRATION};
|
||||
use diesel::prelude::*;
|
||||
|
||||
#[derive(Identifiable, Queryable, Insertable)]
|
||||
#[diesel(table_name = sso_nonce)]
|
||||
#[diesel(primary_key(state))]
|
||||
pub struct SsoNonce {
|
||||
pub state: OIDCState,
|
||||
pub nonce: String,
|
||||
pub verifier: Option<String>,
|
||||
pub redirect_uri: String,
|
||||
pub created_at: NaiveDateTime,
|
||||
}
|
||||
|
||||
/// Local methods
|
||||
impl SsoNonce {
|
||||
pub fn new(state: OIDCState, nonce: String, verifier: Option<String>, redirect_uri: String) -> Self {
|
||||
let now = Utc::now().naive_utc();
|
||||
|
||||
SsoNonce {
|
||||
state,
|
||||
nonce,
|
||||
verifier,
|
||||
redirect_uri,
|
||||
created_at: now,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Database methods
|
||||
impl SsoNonce {
|
||||
pub async fn save(&self, conn: &DbConn) -> EmptyResult {
|
||||
db_run! { conn:
|
||||
sqlite, mysql {
|
||||
diesel::replace_into(sso_nonce::table)
|
||||
.values(self)
|
||||
.execute(conn)
|
||||
.map_res("Error saving SSO nonce")
|
||||
}
|
||||
postgresql {
|
||||
diesel::insert_into(sso_nonce::table)
|
||||
.values(self)
|
||||
.execute(conn)
|
||||
.map_res("Error saving SSO nonce")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete(state: &OIDCState, conn: &DbConn) -> EmptyResult {
|
||||
db_run! { conn: {
|
||||
diesel::delete(sso_nonce::table.filter(sso_nonce::state.eq(state)))
|
||||
.execute(conn)
|
||||
.map_res("Error deleting SSO nonce")
|
||||
}}
|
||||
}
|
||||
|
||||
pub async fn find(state: &OIDCState, conn: &DbConn) -> Option<Self> {
|
||||
let oldest = Utc::now().naive_utc() - *NONCE_EXPIRATION;
|
||||
db_run! { conn: {
|
||||
sso_nonce::table
|
||||
.filter(sso_nonce::state.eq(state))
|
||||
.filter(sso_nonce::created_at.ge(oldest))
|
||||
.first::<Self>(conn)
|
||||
.ok()
|
||||
}}
|
||||
}
|
||||
|
||||
pub async fn delete_expired(pool: DbPool) -> EmptyResult {
|
||||
debug!("Purging expired sso_nonce");
|
||||
if let Ok(conn) = pool.get().await {
|
||||
let oldest = Utc::now().naive_utc() - *NONCE_EXPIRATION;
|
||||
db_run! { conn: {
|
||||
diesel::delete(sso_nonce::table.filter(sso_nonce::created_at.lt(oldest)))
|
||||
.execute(conn)
|
||||
.map_res("Error deleting expired SSO nonce")
|
||||
}}
|
||||
} else {
|
||||
err!("Failed to get DB connection while purging expired sso_nonce")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -256,12 +256,15 @@ table! {
|
||||
}
|
||||
|
||||
table! {
|
||||
sso_nonce (state) {
|
||||
sso_auth (state) {
|
||||
state -> Text,
|
||||
client_challenge -> Text,
|
||||
nonce -> Text,
|
||||
verifier -> Nullable<Text>,
|
||||
redirect_uri -> Text,
|
||||
code_response -> Nullable<Text>,
|
||||
auth_response -> Nullable<Text>,
|
||||
created_at -> Timestamp,
|
||||
updated_at -> Timestamp,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user