From cf25ac948a3352373a15cddd9b5e1d194f8f6481 Mon Sep 17 00:00:00 2001 From: Steve Russo <64294847+sjrusso8@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:02:53 -0400 Subject: [PATCH] feat(session): add RunTimeConfig & Session tags (#45) --- README.md | 34 +++---- core/src/client/mod.rs | 126 +++++++++++++++++++++--- core/src/column.rs | 2 +- core/src/conf.rs | 139 +++++++++++++++++++++++++++ core/src/lib.rs | 1 + core/src/session.rs | 213 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 482 insertions(+), 33 deletions(-) create mode 100644 core/src/conf.rs diff --git a/README.md b/README.md index ee6a368..1a67ddd 100644 --- a/README.md +++ b/README.md @@ -85,36 +85,36 @@ The following section outlines some of the larger functionality that are not yet |------------------|----------|---------------------------------------| |active |![open] | | |addArtifact(s) |![open] | | -|addTag |![open] | | -|clearTags |![open] | | +|addTag |![done] | | +|clearTags |![done] | | |copyFromLocalToFs |![open] | | |createDataFrame |![partial]|Partial. Only works for `RecordBatch` | |getActiveSessions |![open] | | -|getTags |![open] | | -|interruptAll |![open] | | -|interruptOperation|![open] | | -|interruptTag |![open] | | +|getTags |![done] | | +|interruptAll |![done] | | +|interruptOperation|![done] | | +|interruptTag |![done] | | |newSession |![open] | | |range |![done] | | |removeTag |![done] | | |sql |![done] | | |stop |![open] | | |table |![done] | | -|catalog |![done] |[Catalog](#catalog) | -|client |X |unstable developer api for testing only| -|conf |![open] |[Conf](#runtimeconfig) | -|read |![done] |[DataFrameReader](#dataframereader) | -|readStream |![done] |[DataStreamReader](#datastreamreader) | -|streams |![open] |[Streams](#streamingquerymanager) | -|udf |![open] |[Udf](#udfregistration) - may not be possible | -|udtf |![open] |[Udtf](#udtfregistration) - may not be possible | -|version |![open] | | +|catalog |![done] |[Catalog](#catalog) | +|client |![done] |unstable developer api for testing only | +|conf |![done] |[Conf](#runtimeconfig) | +|read |![done] |[DataFrameReader](#dataframereader) | +|readStream |![done] |[DataStreamReader](#datastreamreader) | +|streams |![open] |[Streams](#streamingquerymanager) | +|udf |![open] |[Udf](#udfregistration) - may not be possible | +|udtf |![open] |[Udtf](#udtfregistration) - may not be possible | +|version |![done] | | ### SparkSessionBuilder |SparkSessionBuilder|API |Comment | |-------------------|----------|---------------------------------------| -|appName |![open] | | -|config |![open] | | +|appName |![done] | | +|config |![done] | | |master |![open] | | |remote |![partial]|Validate using [spark connection string](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)| diff --git a/core/src/client/mod.rs b/core/src/client/mod.rs index 114d728..ea349a3 100644 --- a/core/src/client/mod.rs +++ b/core/src/client/mod.rs @@ -357,6 +357,8 @@ pub struct SparkConnectClient { builder: ChannelBuilder, pub handler: ResponseHandler, pub analyzer: AnalyzeHandler, + pub user_context: Option, + pub tags: Vec, } impl SparkConnectClient @@ -367,11 +369,19 @@ where ::Error: Into + Send, { pub fn new(stub: Arc>>, builder: ChannelBuilder) -> Self { + let user_ref = builder.user_id.clone().unwrap_or("".to_string()); + SparkConnectClient { stub, builder, handler: ResponseHandler::new(), analyzer: AnalyzeHandler::new(), + user_context: Some(spark::UserContext { + user_id: user_ref.clone(), + user_name: user_ref, + 
extensions: vec![], + }), + tags: vec![], } } @@ -382,27 +392,19 @@ where fn execute_plan_request_with_metadata(&self) -> spark::ExecutePlanRequest { spark::ExecutePlanRequest { session_id: self.session_id(), - user_context: Some(spark::UserContext { - user_id: self.builder.user_id.clone().unwrap_or("n/a".to_string()), - user_name: self.builder.user_id.clone().unwrap_or("n/a".to_string()), - extensions: vec![], - }), + user_context: self.user_context.clone(), operation_id: None, plan: None, client_type: self.builder.user_agent.clone(), request_options: vec![], - tags: vec![], + tags: self.tags.clone(), } } fn analyze_plan_request_with_metadata(&self) -> spark::AnalyzePlanRequest { spark::AnalyzePlanRequest { session_id: self.session_id(), - user_context: Some(spark::UserContext { - user_id: self.builder.user_id.clone().unwrap_or("n/a".to_string()), - user_name: self.builder.user_id.clone().unwrap_or("n/a".to_string()), - extensions: vec![], - }), + user_context: self.user_context.clone(), client_type: self.builder.user_agent.clone(), analyze: None, } } @@ -455,6 +457,107 @@ where self.handle_analyze(resp) } + #[allow(clippy::await_holding_lock)] + pub async fn interrupt_request( + &mut self, + interrupt_type: spark::interrupt_request::InterruptType, + id_or_tag: Option, + ) -> Result { + let mut req = spark::InterruptRequest { + session_id: self.session_id(), + user_context: self.user_context.clone(), + client_type: self.builder.user_agent.clone(), + interrupt_type: 0, + interrupt: None, + }; + + match interrupt_type { + spark::interrupt_request::InterruptType::All => { + req.interrupt_type = interrupt_type.into(); + } + spark::interrupt_request::InterruptType::Tag => { + let tag = id_or_tag.expect("Tag can not be empty"); + let interrupt = spark::interrupt_request::Interrupt::OperationTag(tag); + req.interrupt_type = interrupt_type.into(); + req.interrupt = Some(interrupt); + } + spark::interrupt_request::InterruptType::OperationId => { + let op_id = id_or_tag.expect("Operation ID can not be empty"); + let interrupt = spark::interrupt_request::Interrupt::OperationId(op_id); + req.interrupt_type = interrupt_type.into(); + req.interrupt = Some(interrupt); + } + spark::interrupt_request::InterruptType::Unspecified => { + return Err(SparkError::AnalysisException( + "Interrupt Type was not specified".to_string(), + )) + } + }; + + let mut client = self.stub.write(); + + let resp = client.interrupt(req).await?.into_inner(); + drop(client); + + Ok(resp) + } + + fn validate_tag(&self, tag: &str) -> Result<(), SparkError> { + if tag.contains(',') { + return Err(SparkError::AnalysisException( + "Spark Connect tag can not contain ','".to_string(), + )); + }; + + if tag.is_empty() { + return Err(SparkError::AnalysisException( + "Spark Connect tag can not be an empty string".to_string(), + )); + }; + + Ok(()) + } + + pub fn add_tag(&mut self, tag: &str) -> Result<(), SparkError> { + self.validate_tag(tag)?; + self.tags.push(tag.to_string()); + Ok(()) + } + + pub fn remove_tag(&mut self, tag: &str) -> Result<(), SparkError> { + self.validate_tag(tag)?; + self.tags.retain(|t| t != tag); + Ok(()) + } + + pub fn get_tags(&self) -> &Vec { + &self.tags + } + + pub fn clear_tags(&mut self) { + self.tags = vec![]; + } + + #[allow(clippy::await_holding_lock)] + pub async fn config_request( + &mut self, + operation: spark::config_request::Operation, + ) -> Result { + let operation = spark::ConfigRequest { + session_id: self.session_id(), + user_context: self.user_context.clone(), + client_type: 
self.builder.user_agent.clone(), + operation: Some(operation), + }; + + let mut client = self.stub.write(); + + let resp = client.config(operation).await?.into_inner(); + drop(client); + + Ok(resp) + } + fn handle_response(&mut self, resp: spark::ExecutePlanResponse) -> Result<(), SparkError> { self.validate_session(&resp.session_id)?; @@ -469,7 +572,6 @@ ResponseType::ArrowBatch(res) => { self.deserialize(res.data.as_slice(), res.row_count)? } - // TODO! this shouldn't be clones but okay for now ResponseType::SqlCommandResult(sql_cmd) => { self.handler.sql_command_result = Some(sql_cmd.clone()) } diff --git a/core/src/column.rs b/core/src/column.rs index 7736b08..0efca12 100644 --- a/core/src/column.rs +++ b/core/src/column.rs @@ -41,7 +41,7 @@ pub struct Column { pub expression: spark::Expression, } -/// Trait used to cast columns to specific [DataTypes] +/// Trait used to cast columns to a specific [DataType] /// /// Either with a String or a [DataType] pub trait CastToDataType { diff --git a/core/src/conf.rs b/core/src/conf.rs new file mode 100644 index 0000000..ce6cd42 --- /dev/null +++ b/core/src/conf.rs @@ -0,0 +1,139 @@ +//! Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. + +use std::collections::HashMap; + +use crate::spark; + +use crate::client::{MetadataInterceptor, SparkConnectClient}; +use crate::errors::SparkError; + +use tonic::service::interceptor::InterceptedService; + +#[cfg(not(feature = "wasm"))] +use tonic::transport::Channel; + +#[cfg(feature = "wasm")] +use tonic_web_wasm_client::Client; + +pub struct RunTimeConfig { + #[cfg(not(feature = "wasm"))] + pub(crate) client: SparkConnectClient>, + + #[cfg(feature = "wasm")] + pub(crate) client: SparkConnectClient>, +} + +/// User-facing configuration API, accessible through SparkSession.conf. +/// +/// Options set here are automatically propagated to the Hadoop configuration during I/O. +/// +/// # Example +/// ```rust +/// spark +/// .conf() +/// .set("spark.sql.shuffle.partitions", "42") +/// .await?; +/// ``` +impl RunTimeConfig { + #[allow(dead_code)] + pub(crate) async fn set_configs( + &mut self, + map: &HashMap, + ) -> Result<(), SparkError> { + for (key, value) in map { + self.set(key.as_str(), value.as_str()).await? + } + Ok(()) + } + + /// Sets the given Spark runtime configuration property. + pub async fn set(&mut self, key: &str, value: &str) -> Result<(), SparkError> { + let op_type = spark::config_request::operation::OpType::Set(spark::config_request::Set { + pairs: vec![spark::KeyValue { + key: key.into(), + value: Some(value.into()), + }], + }); + let operation = spark::config_request::Operation { + op_type: Some(op_type), + }; + + let _ = self.client.config_request(operation).await?; + + Ok(()) + } + + /// Resets the configuration property for the given key. + pub async fn unset(&mut self, key: &str) -> Result<(), SparkError> { + let op_type = + spark::config_request::operation::OpType::Unset(spark::config_request::Unset { + keys: vec![key.to_string()], + }); + let operation = spark::config_request::Operation { + op_type: Some(op_type), + }; + + let _ = self.client.config_request(operation).await?; + + Ok(()) + } + + /// Returns the value of the Spark runtime configuration property for the given key, or the provided default if the key is not set. 
+ pub async fn get(&mut self, key: &str, default: Option<&str>) -> Result { + let operation = match default { + Some(default) => { + let op_type = spark::config_request::operation::OpType::GetWithDefault( + spark::config_request::GetWithDefault { + pairs: vec![spark::KeyValue { + key: key.into(), + value: Some(default.into()), + }], + }, + ); + spark::config_request::Operation { + op_type: Some(op_type), + } + } + None => { + let op_type = + spark::config_request::operation::OpType::Get(spark::config_request::Get { + keys: vec![key.to_string()], + }); + spark::config_request::Operation { + op_type: Some(op_type), + } + } + }; + + let resp = self.client.config_request(operation).await?; + + let val = resp.pairs.first().unwrap().value().to_string(); + + Ok(val) + } + + /// Indicates whether the configuration property with the given key is modifiable in the current session. + #[allow(non_snake_case)] + pub async fn isModifable(&mut self, key: &str) -> Result { + let op_type = spark::config_request::operation::OpType::IsModifiable( + spark::config_request::IsModifiable { + keys: vec![key.to_string()], + }, + ); + let operation = spark::config_request::Operation { + op_type: Some(op_type), + }; + + let resp = self.client.config_request(operation).await?; + + let val = resp.pairs.first().unwrap().value(); + + match val { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(SparkError::AnalysisException( + "Unexpected response value for boolean".to_string(), + )), + } + } +} diff --git a/core/src/lib.rs b/core/src/lib.rs index f466082..cd9915d 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -118,6 +118,7 @@ pub mod spark { pub mod catalog; pub mod client; pub mod column; +pub mod conf; pub mod dataframe; pub mod errors; pub mod expressions; diff --git a/core/src/session.rs b/core/src/session.rs index 19ae656..9bc0518 100644 --- a/core/src/session.rs +++ b/core/src/session.rs @@ -4,6 +4,7 @@ use std::collections::HashMap; use std::sync::Arc; use crate::catalog::Catalog; +use crate::conf::RunTimeConfig; use crate::dataframe::{DataFrame, DataFrameReader}; use crate::plan::LogicalPlanBuilder; use crate::spark; @@ -31,6 +32,7 @@ use tonic::service::interceptor::InterceptedService; #[derive(Clone, Debug)] pub struct SparkSessionBuilder { pub channel_builder: ChannelBuilder, + configs: HashMap, } /// Default connects a Spark cluster running at `sc://127.0.0.1:15002/` @@ -38,7 +40,10 @@ impl Default for SparkSessionBuilder { fn default() -> Self { let channel_builder = ChannelBuilder::default(); - Self { channel_builder } + Self { + channel_builder, + configs: HashMap::new(), + } } } @@ -46,7 +51,10 @@ impl SparkSessionBuilder { fn new(connection: &str) -> Self { let channel_builder = ChannelBuilder::create(connection).unwrap(); - Self { channel_builder } + Self { + channel_builder, + configs: HashMap::new(), + } } /// Validate a connect string for a remote Spark Session @@ -56,6 +64,20 @@ impl SparkSessionBuilder { Self::new(connection) } + /// Sets a config option. + pub fn config(mut self, key: &str, value: &str) -> Self { + self.configs.insert(key.into(), value.into()); + self + } + + /// Sets a name for the application, which will be shown in the Spark web UI. 
+ #[allow(non_snake_case)] + pub fn appName(mut self, name: &str) -> Self { + self.configs + .insert("spark.app.name".to_string(), name.into()); + self + } + #[cfg(not(feature = "wasm"))] async fn create_client(&self) -> Result { let channel = Endpoint::from_shared(self.channel_builder.endpoint()) @@ -77,6 +99,12 @@ impl SparkSessionBuilder { let spark_connnect_client = SparkConnectClient::new(client.clone(), self.channel_builder.clone()); + let mut rt_config = RunTimeConfig { + client: spark_connnect_client.clone(), + }; + + rt_config.set_configs(&self.configs).await?; + Ok(SparkSession::new(spark_connnect_client)) } @@ -228,6 +256,87 @@ impl SparkSession { pub fn client(self) -> SparkConnectClient> { self.client } + + /// Interrupt all operations of this session currently running on the connected server. + #[allow(non_snake_case)] + pub async fn interruptAll(self) -> Result, SparkError> { + let resp = self + .client() + .interrupt_request(spark::interrupt_request::InterruptType::All, None) + .await?; + + Ok(resp.interrupted_ids) + } + + /// Interrupt all operations of this session with the given operation tag. + #[allow(non_snake_case)] + pub async fn interruptTag(self, tag: &str) -> Result, SparkError> { + let resp = self + .client() + .interrupt_request( + spark::interrupt_request::InterruptType::Tag, + Some(tag.to_string()), + ) + .await?; + + Ok(resp.interrupted_ids) + } + + /// Interrupt an operation of this session with the given operationId. + #[allow(non_snake_case)] + pub async fn interruptOperation(self, op_id: &str) -> Result, SparkError> { + let resp = self + .client() + .interrupt_request( + spark::interrupt_request::InterruptType::OperationId, + Some(op_id.to_string()), + ) + .await?; + + Ok(resp.interrupted_ids) + } + + /// Add a tag to be assigned to all the operations started by this thread in this session. + #[allow(non_snake_case)] + pub fn addTag(&mut self, tag: &str) -> Result<(), SparkError> { + self.client.add_tag(tag) + } + + /// Remove a tag previously added to be assigned to all the operations started by this thread in this session. + #[allow(non_snake_case)] + pub fn removeTag(&mut self, tag: &str) -> Result<(), SparkError> { + self.client.remove_tag(tag) + } + + /// Get the tags that are currently set to be assigned to all the operations started by this thread. + #[allow(non_snake_case)] + pub fn getTags(&mut self) -> &Vec { + self.client.get_tags() + } + + /// Clear the current thread’s operation tags. + #[allow(non_snake_case)] + pub fn clearTags(&mut self) { + self.client.clear_tags() + } + + /// The version of Spark on which this application is running. + pub async fn version(self) -> Result { + let version = spark::analyze_plan_request::Analyze::SparkVersion( + spark::analyze_plan_request::SparkVersion {}, + ); + + let mut client = self.client; + + client.analyze(version).await?.spark_version() + } + + /// [RunTimeConfig] configuration interface for Spark. 
+ pub fn conf(&self) -> RunTimeConfig { + RunTimeConfig { + client: self.client.clone(), + } + } } #[cfg(test)] @@ -235,7 +344,7 @@ mod tests { use super::*; #[test] - fn test_spark_session_builder() { + fn test_session_builder() { let connection = "sc://myhost.com:443/;token=ABCDEFG;user_agent=some_agent;user_id=user123"; let ssbuilder = SparkSessionBuilder::remote(connection); @@ -259,4 +368,102 @@ mod tests { assert!(spark.is_ok()); } + + #[tokio::test] + async fn test_session_tags() -> Result<(), SparkError> { + let mut spark = SparkSessionBuilder::default().build().await?; + + spark.addTag("hello-tag")?; + + spark.addTag("hello-tag-2")?; + + let expected = vec!["hello-tag".to_string(), "hello-tag-2".to_string()]; + + let res = spark.getTags(); + + assert_eq!(&expected, res); + + spark.clearTags(); + let res = spark.getTags(); + + let expected: Vec = vec![]; + + assert_eq!(&expected, res); + + Ok(()) + } + + #[tokio::test] + async fn test_session_tags_panic() -> Result<(), SparkError> { + let mut spark = SparkSessionBuilder::default().build().await?; + + assert!(spark.addTag("bad,tag").is_err()); + assert!(spark.addTag("").is_err()); + + assert!(spark.removeTag("bad,tag").is_err()); + assert!(spark.removeTag("").is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_session_version() -> Result<(), SparkError> { + let spark = SparkSessionBuilder::default().build().await?; + + let version = spark.version().await?; + + assert_eq!("3.5.1".to_string(), version); + Ok(()) + } + + #[tokio::test] + async fn test_session_config() -> Result<(), SparkError> { + let value = "rust-test-app"; + + let spark = SparkSessionBuilder::default() + .appName("rust-test-app") + .build() + .await?; + + let name = spark.conf().get("spark.app.name", None).await?; + + assert_eq!(value, &name); + + // validate set + spark + .conf() + .set("spark.sql.shuffle.partitions", "42") + .await?; + + // validate get + let val = spark + .conf() + .get("spark.sql.shuffle.partitions", None) + .await?; + + assert_eq!("42", &val); + + // validate unset + spark.conf().unset("spark.sql.shuffle.partitions").await?; + + let val = spark + .conf() + .get("spark.sql.shuffle.partitions", None) + .await?; + + assert_eq!("200", &val); + + // not a modifable setting + let val = spark.conf().isModifable("spark.executor.instances").await?; + assert!(!val); + + // a modifable setting + let val = spark + .conf() + .isModifable("spark.sql.shuffle.partitions") + .await?; + assert!(val); + + Ok(()) + } }
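
Usage sketch (not part of the patch): the snippet below shows how the new builder options, runtime config, session tag, and interrupt APIs added in this commit might fit together. The crate path `spark_connect_rs`, the root re-exports, the tokio `main` setup, the connection string, and the example config keys and tag names are assumptions for illustration; the method names and signatures follow the code above.

```rust
// Illustrative sketch only; assumes a Spark Connect server at sc://127.0.0.1:15002
// and the crate consumed as `spark_connect_rs` with a tokio runtime available.
use spark_connect_rs::errors::SparkError;
use spark_connect_rs::SparkSessionBuilder;

#[tokio::main]
async fn main() -> Result<(), SparkError> {
    // Builder-level options are collected into a HashMap and applied through
    // RunTimeConfig::set_configs once the session is built.
    let mut spark = SparkSessionBuilder::remote("sc://127.0.0.1:15002/")
        .appName("runtime-config-demo")
        .config("spark.sql.shuffle.partitions", "4")
        .build()
        .await?;

    // Runtime configuration: set a property, read it back (optionally with a
    // default), then reset it to its original value.
    spark.conf().set("spark.sql.ansi.enabled", "true").await?;
    let ansi = spark
        .conf()
        .get("spark.sql.ansi.enabled", Some("false"))
        .await?;
    println!("spark.sql.ansi.enabled = {}", ansi);
    spark.conf().unset("spark.sql.ansi.enabled").await?;

    // Session tags: attached to every operation started by this client
    // until they are removed or cleared.
    spark.addTag("nightly-etl")?;
    println!("current tags: {:?}", spark.getTags());
    spark.removeTag("nightly-etl")?;
    spark.clearTags();

    // The interrupt methods take `self` in this patch, so call one of them last;
    // here we cancel anything still running in the session.
    let interrupted = spark.interruptAll().await?;
    println!("interrupted operation ids: {:?}", interrupted);

    Ok(())
}
```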