diff --git a/crates/policy/src/policy/principal.rs b/crates/policy/src/policy/principal.rs index d530c143..8c12ef9f 100644 --- a/crates/policy/src/policy/principal.rs +++ b/crates/policy/src/policy/principal.rs @@ -13,17 +13,67 @@ // limitations under the License. use super::{Validator, utils::wildcard}; -use crate::error::{Error, Result}; -use serde::{Deserialize, Serialize}; +use crate::error::Error; +use serde::Serialize; use std::collections::HashSet; -#[derive(Debug, Clone, Deserialize, Serialize, Default, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Default, PartialEq, Eq)] #[serde(rename_all = "PascalCase", default)] pub struct Principal { #[serde(rename = "AWS")] aws: HashSet, } +#[derive(serde::Deserialize)] +#[serde(untagged)] +enum PrincipalFormat { + Wildcard(String), + AwsObject(PrincipalAwsObject), +} + +#[derive(serde::Deserialize)] +struct PrincipalAwsObject { + #[serde(rename = "AWS")] + aws: AwsValues, +} + +#[derive(serde::Deserialize)] +#[serde(untagged)] +enum AwsValues { + Single(String), + Multiple(HashSet), +} + +impl<'de> serde::Deserialize<'de> for Principal { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let format = PrincipalFormat::deserialize(deserializer)?; + match format { + PrincipalFormat::Wildcard(s) => { + if s == "*" { + Ok(Principal { + aws: HashSet::from(["*".to_string()]), + }) + } else { + Err(serde::de::Error::custom(format!( + "invalid wildcard principal value: expected \"*\", got \"{}\"", + s + ))) + } + } + PrincipalFormat::AwsObject(obj) => { + let aws = match obj.aws { + AwsValues::Single(s) => HashSet::from([s]), + AwsValues::Multiple(set) => set, + }; + Ok(Principal { aws }) + } + } + } +} + impl Principal { pub fn is_match(&self, parincipal: &str) -> bool { for pattern in self.aws.iter() { @@ -37,10 +87,35 @@ impl Principal { impl Validator for Principal { type Error = Error; - fn is_valid(&self) -> Result<()> { + fn is_valid(&self) -> Result<(), Error> { if self.aws.is_empty() { return Err(Error::other("Principal is empty")); } Ok(()) } } + +#[cfg(test)] +mod test { + use super::*; + use serde_json; + use test_case::test_case; + + #[test_case(r#""*""#, true ; "wildcard_string")] + #[test_case(r#"{"AWS": "*"}"#, true ; "aws_object_single_string")] + #[test_case(r#"{"AWS": ["*"]}"#, true ; "aws_object_array")] + #[test_case(r#""invalid""#, false ; "invalid_string")] + #[test_case(r#""""#, false ; "empty_string")] + #[test_case(r#"{"Other": "*"}"#, false ; "wrong_field")] + #[test_case(r#"{}"#, false ; "empty_object")] + fn test_principal_parsing(json: &str, should_succeed: bool) { + let result = match serde_json::from_str::(json) { + Ok(principal) => { + assert!(principal.aws.contains("*")); + should_succeed + } + Err(_) => !should_succeed, + }; + assert!(result); + } +}