diff --git a/Cargo.lock b/Cargo.lock index a2c7f42a02..395f903570 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7630,6 +7630,7 @@ dependencies = [ name = "spin-factor-variables" version = "2.7.0-pre0" dependencies = [ + "dotenvy", "serde 1.0.197", "spin-expressions", "spin-factors", @@ -7637,6 +7638,7 @@ dependencies = [ "spin-world", "tokio", "toml 0.8.14", + "tracing", ] [[package]] diff --git a/crates/factor-variables/Cargo.toml b/crates/factor-variables/Cargo.toml index 30465c2096..ac7370d709 100644 --- a/crates/factor-variables/Cargo.toml +++ b/crates/factor-variables/Cargo.toml @@ -5,11 +5,14 @@ authors = { workspace = true } edition = { workspace = true } [dependencies] +dotenvy = "0.15" serde = { version = "1.0", features = ["rc"] } spin-expressions = { path = "../expressions" } spin-factors = { path = "../factors" } spin-world = { path = "../world" } +tokio = { version = "1", features = ["rt-multi-thread"] } toml = "0.8" +tracing = { workspace = true } [dev-dependencies] spin-factors-test = { path = "../factors-test" } diff --git a/crates/factor-variables/src/lib.rs b/crates/factor-variables/src/lib.rs index 596dc6e747..22c26022b0 100644 --- a/crates/factor-variables/src/lib.rs +++ b/crates/factor-variables/src/lib.rs @@ -1,8 +1,7 @@ -mod provider; +pub mod provider; use std::{collections::HashMap, sync::Arc}; -use provider::{provider_from_toml_fn, ProviderFromToml}; use serde::Deserialize; use spin_expressions::ProviderResolver; use spin_factors::{ @@ -16,7 +15,7 @@ pub use provider::{MakeVariablesProvider, StaticVariables}; #[derive(Default)] pub struct VariablesFactor { - provider_types: HashMap<&'static str, ProviderFromToml>, + provider_types: HashMap<&'static str, provider::ProviderFromToml>, } impl VariablesFactor { @@ -26,7 +25,10 @@ impl VariablesFactor { ) -> anyhow::Result<()> { if self .provider_types - .insert(T::RUNTIME_CONFIG_TYPE, provider_from_toml_fn(provider_type)) + .insert( + T::RUNTIME_CONFIG_TYPE, + provider::provider_from_toml_fn(provider_type), + ) .is_some() { bail!("duplicate provider type {:?}", T::RUNTIME_CONFIG_TYPE); diff --git a/crates/factor-variables/src/provider.rs b/crates/factor-variables/src/provider.rs index f3ee948118..f34945f1cc 100644 --- a/crates/factor-variables/src/provider.rs +++ b/crates/factor-variables/src/provider.rs @@ -1,7 +1,11 @@ -use std::{collections::HashMap, sync::Arc}; +mod env; +mod statik; -use serde::{de::DeserializeOwned, Deserialize}; -use spin_expressions::{async_trait::async_trait, Key, Provider}; +pub use env::EnvVariables; +pub use statik::StaticVariables; + +use serde::de::DeserializeOwned; +use spin_expressions::Provider; use spin_factors::anyhow; pub trait MakeVariablesProvider: 'static { @@ -24,28 +28,3 @@ pub(crate) fn provider_from_toml_fn( Ok(Box::new(provider)) }) } - -pub struct StaticVariables; - -impl MakeVariablesProvider for StaticVariables { - const RUNTIME_CONFIG_TYPE: &'static str = "static"; - - type RuntimeConfig = StaticVariablesProvider; - type Provider = StaticVariablesProvider; - - fn make_provider(&self, runtime_config: Self::RuntimeConfig) -> anyhow::Result { - Ok(runtime_config) - } -} - -#[derive(Debug, Deserialize)] -pub struct StaticVariablesProvider { - values: Arc>, -} - -#[async_trait] -impl Provider for StaticVariablesProvider { - async fn get(&self, key: &Key) -> anyhow::Result> { - Ok(self.values.get(key.as_str()).cloned()) - } -} diff --git a/crates/factor-variables/src/provider/env.rs b/crates/factor-variables/src/provider/env.rs new file mode 100644 index 0000000000..23b87c022c --- /dev/null +++ b/crates/factor-variables/src/provider/env.rs @@ -0,0 +1,215 @@ +use std::{ + collections::HashMap, + env::VarError, + path::{Path, PathBuf}, + sync::OnceLock, +}; + +use serde::Deserialize; +use spin_expressions::{Key, Provider}; +use spin_factors::anyhow::{self, Context as _}; +use spin_world::async_trait; +use tracing::{instrument, Level}; + +use crate::MakeVariablesProvider; + +/// Creator of a environment variables provider. +pub struct EnvVariables; + +impl MakeVariablesProvider for EnvVariables { + const RUNTIME_CONFIG_TYPE: &'static str = "static"; + + type RuntimeConfig = EnvVariablesConfig; + type Provider = EnvVariablesProvider; + + fn make_provider(&self, runtime_config: Self::RuntimeConfig) -> anyhow::Result { + Ok(EnvVariablesProvider::new( + runtime_config.prefix, + |key| std::env::var(key), + runtime_config.dotenv_path, + )) + } +} + +/// Configuration for the environment variables provider. +#[derive(Debug, Default, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct EnvVariablesConfig { + /// A prefix to add to variable names when resolving from the environment. + /// + /// Unless empty, joined to the variable name with an underscore. + #[serde(default)] + pub prefix: Option, + /// Optional path to a 'dotenv' file which will be merged into the environment. + #[serde(default)] + pub dotenv_path: Option, +} + +const DEFAULT_ENV_PREFIX: &str = "SPIN_VARIABLE"; + +/// A config Provider that uses environment variables. +pub struct EnvVariablesProvider { + prefix: Option, + env_fetcher: Box Result + Send + Sync>, + dotenv_path: Option, + dotenv_cache: OnceLock>, +} + +impl EnvVariablesProvider { + /// Creates a new EnvProvider. + /// + /// * `prefix` - The string prefix to use to distinguish an environment variable that should be used. + /// If not set, the default prefix is used. + /// * `env_fetcher` - The function to use to fetch an environment variable. + /// * `dotenv_path` - The path to the .env file to load environment variables from. If not set, + /// no .env file is loaded. + pub fn new( + prefix: Option>, + env_fetcher: impl Fn(&str) -> Result + Send + Sync + 'static, + dotenv_path: Option, + ) -> Self { + Self { + prefix: prefix.map(Into::into), + dotenv_path, + env_fetcher: Box::new(env_fetcher), + dotenv_cache: Default::default(), + } + } + + /// Gets the value of a variable from the environment. + fn get_sync(&self, key: &Key) -> anyhow::Result> { + let prefix = self + .prefix + .clone() + .unwrap_or(DEFAULT_ENV_PREFIX.to_string()); + + let upper_key = key.as_ref().to_ascii_uppercase(); + let env_key = format!("{prefix}_{upper_key}"); + + self.query_env(&env_key) + } + + /// Queries the environment for a variable defaulting to dotenv. + fn query_env(&self, env_key: &str) -> anyhow::Result> { + match (self.env_fetcher)(env_key) { + Err(std::env::VarError::NotPresent) => self.get_dotenv(env_key), + other => other + .map(Some) + .with_context(|| format!("failed to resolve env var {env_key}")), + } + } + + fn get_dotenv(&self, key: &str) -> anyhow::Result> { + let Some(dotenv_path) = self.dotenv_path.as_deref() else { + return Ok(None); + }; + let cache = match self.dotenv_cache.get() { + Some(cache) => cache, + None => { + let cache = load_dotenv(dotenv_path)?; + let _ = self.dotenv_cache.set(cache); + // Safe to unwrap because we just set the cache. + // Ensures we always get the first value set. + self.dotenv_cache.get().unwrap() + } + }; + Ok(cache.get(key).cloned()) + } +} + +impl std::fmt::Debug for EnvVariablesProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EnvProvider") + .field("prefix", &self.prefix) + .field("dotenv_path", &self.dotenv_path) + .finish() + } +} + +fn load_dotenv(dotenv_path: &Path) -> anyhow::Result> { + Ok(dotenvy::from_path_iter(dotenv_path) + .into_iter() + .flatten() + .collect::, _>>()?) +} + +#[async_trait] +impl Provider for EnvVariablesProvider { + #[instrument(name = "spin_variables.get_from_env", skip(self), err(level = Level::INFO))] + async fn get(&self, key: &Key) -> anyhow::Result> { + tokio::task::block_in_place(|| self.get_sync(key)) + } +} + +#[cfg(test)] +mod test { + use std::env::temp_dir; + + use super::*; + + struct TestEnv { + map: HashMap, + } + + impl TestEnv { + fn new() -> Self { + Self { + map: Default::default(), + } + } + + fn insert(&mut self, key: &str, value: &str) { + self.map.insert(key.to_string(), value.to_string()); + } + + fn get(&self, key: &str) -> Result { + self.map.get(key).cloned().ok_or(VarError::NotPresent) + } + } + + #[test] + fn provider_get() { + let mut env = TestEnv::new(); + env.insert("TESTING_SPIN_ENV_KEY1", "val"); + let key1 = Key::new("env_key1").unwrap(); + assert_eq!( + EnvVariablesProvider::new(Some("TESTING_SPIN"), move |key| env.get(key), None) + .get_sync(&key1) + .unwrap(), + Some("val".to_string()) + ); + } + + #[test] + fn provider_get_dotenv() { + let dotenv_path = temp_dir().join("spin-env-provider-test"); + std::fs::write(&dotenv_path, b"TESTING_SPIN_ENV_KEY2=dotenv_val").unwrap(); + + let key = Key::new("env_key2").unwrap(); + assert_eq!( + EnvVariablesProvider::new( + Some("TESTING_SPIN"), + |_| Err(VarError::NotPresent), + Some(dotenv_path) + ) + .get_sync(&key) + .unwrap(), + Some("dotenv_val".to_string()) + ); + } + + #[test] + fn provider_get_missing() { + let key = Key::new("definitely_not_set").unwrap(); + assert_eq!( + EnvVariablesProvider::new( + Some("TESTING_SPIN"), + |_| Err(VarError::NotPresent), + Default::default() + ) + .get_sync(&key) + .unwrap(), + None + ); + } +} diff --git a/crates/factor-variables/src/provider/statik.rs b/crates/factor-variables/src/provider/statik.rs new file mode 100644 index 0000000000..222c7168e1 --- /dev/null +++ b/crates/factor-variables/src/provider/statik.rs @@ -0,0 +1,34 @@ +use std::{collections::HashMap, sync::Arc}; + +use serde::Deserialize; +use spin_expressions::{async_trait::async_trait, Key, Provider}; +use spin_factors::anyhow; + +use crate::MakeVariablesProvider; + +/// Creator of a static variables provider. +pub struct StaticVariables; + +impl MakeVariablesProvider for StaticVariables { + const RUNTIME_CONFIG_TYPE: &'static str = "static"; + + type RuntimeConfig = StaticVariablesProvider; + type Provider = StaticVariablesProvider; + + fn make_provider(&self, runtime_config: Self::RuntimeConfig) -> anyhow::Result { + Ok(runtime_config) + } +} + +/// A variables provider that reads variables from an static map. +#[derive(Debug, Deserialize)] +pub struct StaticVariablesProvider { + values: Arc>, +} + +#[async_trait] +impl Provider for StaticVariablesProvider { + async fn get(&self, key: &Key) -> anyhow::Result> { + Ok(self.values.get(key.as_str()).cloned()) + } +}