feat(session): add RunTimeConfig & Session tags (#45)
sjrusso8 authored Jun 8, 2024
1 parent 87cb35d commit cf25ac9
Showing 6 changed files with 482 additions and 33 deletions.
34 changes: 17 additions & 17 deletions README.md
@@ -85,36 +85,36 @@ The following section outlines some of the larger functionality that are not yet
|------------------|----------|---------------------------------------|
|active |![open] | |
|addArtifact(s) |![open] | |
|addTag |![open] | |
|clearTags |![open] | |
|addTag |![done] | |
|clearTags |![done] | |
|copyFromLocalToFs |![open] | |
|createDataFrame |![partial]|Partial. Only works for `RecordBatch` |
|getActiveSessions |![open] | |
|getTags |![open] | |
|interruptAll |![open] | |
|interruptOperation|![open] | |
|interruptTag |![open] | |
|getTags |![done] | |
|interruptAll |![done] | |
|interruptOperation|![done] | |
|interruptTag |![done] | |
|newSession |![open] | |
|range |![done] | |
|removeTag |![done] | |
|sql |![done] | |
|stop |![open] | |
|table |![done] | |
|catalog |![done] |[Catalog](#catalog) |
|client |X |unstable developer api for testing only|
|conf |![open] |[Conf](#runtimeconfig) |
|read |![done] |[DataFrameReader](#dataframereader) |
|readStream |![done] |[DataStreamReader](#datastreamreader) |
|streams |![open] |[Streams](#streamingquerymanager) |
|udf |![open] |[Udf](#udfregistration) - may not be possible |
|udtf |![open] |[Udtf](#udtfregistration) - may not be possible |
|version |![open] | |
|catalog |![done] |[Catalog](#catalog) |
|client |![done] |unstable developer api for testing only |
|conf |![done] |[Conf](#runtimeconfig) |
|read |![done] |[DataFrameReader](#dataframereader) |
|readStream |![done] |[DataStreamReader](#datastreamreader) |
|streams |![open] |[Streams](#streamingquerymanager) |
|udf |![open] |[Udf](#udfregistration) - may not be possible |
|udtf |![open] |[Udtf](#udtfregistration) - may not be possible |
|version |![done] | |

### SparkSessionBuilder
|SparkSessionBuilder|API |Comment |
|-------------------|----------|---------------------------------------|
|appName |![open] | |
|config |![open] | |
|appName |![done] | |
|config |![done] | |
|master |![open] | |
|remote |![partial]|Validate using [spark connection string](https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md)|
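As a rough sketch of the builder flow above (crate and method names assumed from this repo's README, so exact signatures may differ; the connection string follows the Spark Connect client-connection-string format linked above):

```rust
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

// Assumes a tokio runtime; `remote` takes a Spark Connect connection string and
// `user_id` feeds the UserContext attached to every request (see client/mod.rs below).
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let spark: SparkSession =
        SparkSessionBuilder::remote("sc://127.0.0.1:15002/;user_id=example_rs")
            .build()
            .await?;

    // `sql` is marked done in the table above; the `show` signature is assumed
    // from the project README and may differ.
    let df = spark.sql("SELECT 'hello' AS greeting").await?;
    df.show(Some(5), None, None).await?;

    Ok(())
}
```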

126 changes: 114 additions & 12 deletions core/src/client/mod.rs
@@ -357,6 +357,8 @@ pub struct SparkConnectClient<T> {
builder: ChannelBuilder,
pub handler: ResponseHandler,
pub analyzer: AnalyzeHandler,
pub user_context: Option<spark::UserContext>,
pub tags: Vec<String>,
}

impl<T> SparkConnectClient<T>
@@ -367,11 +369,19 @@ where
<T::ResponseBody as Body>::Error: Into<StdError> + Send,
{
pub fn new(stub: Arc<RwLock<SparkConnectServiceClient<T>>>, builder: ChannelBuilder) -> Self {
let user_ref = builder.user_id.clone().unwrap_or("".to_string());

SparkConnectClient {
stub,
builder,
handler: ResponseHandler::new(),
analyzer: AnalyzeHandler::new(),
user_context: Some(spark::UserContext {
user_id: user_ref.clone(),
user_name: user_ref,
extensions: vec![],
}),
tags: vec![],
}
}

@@ -382,27 +392,19 @@ where
fn execute_plan_request_with_metadata(&self) -> spark::ExecutePlanRequest {
spark::ExecutePlanRequest {
session_id: self.session_id(),
user_context: Some(spark::UserContext {
user_id: self.builder.user_id.clone().unwrap_or("n/a".to_string()),
user_name: self.builder.user_id.clone().unwrap_or("n/a".to_string()),
extensions: vec![],
}),
user_context: self.user_context.clone(),
operation_id: None,
plan: None,
client_type: self.builder.user_agent.clone(),
request_options: vec![],
tags: vec![],
tags: self.tags.clone(),
}
}

fn analyze_plan_request_with_metadata(&self) -> spark::AnalyzePlanRequest {
spark::AnalyzePlanRequest {
session_id: self.session_id(),
user_context: Some(spark::UserContext {
user_id: self.builder.user_id.clone().unwrap_or("n/a".to_string()),
user_name: self.builder.user_id.clone().unwrap_or("n/a".to_string()),
extensions: vec![],
}),
user_context: self.user_context.clone(),
client_type: self.builder.user_agent.clone(),
analyze: None,
}
@@ -455,6 +457,107 @@ where
self.handle_analyze(resp)
}

#[allow(clippy::await_holding_lock)]
pub async fn interrupt_request(
&mut self,
interrupt_type: spark::interrupt_request::InterruptType,
id_or_tag: Option<String>,
) -> Result<spark::InterruptResponse, SparkError> {
let mut req = spark::InterruptRequest {
session_id: self.session_id(),
user_context: self.user_context.clone(),
client_type: self.builder.user_agent.clone(),
interrupt_type: 0,
interrupt: None,
};

match interrupt_type {
spark::interrupt_request::InterruptType::All => {
req.interrupt_type = interrupt_type.into();
}
spark::interrupt_request::InterruptType::Tag => {
let tag = id_or_tag.expect("Tag can not be empty");
let interrupt = spark::interrupt_request::Interrupt::OperationTag(tag);
req.interrupt_type = interrupt_type.into();
req.interrupt = Some(interrupt);
}
spark::interrupt_request::InterruptType::OperationId => {
let op_id = id_or_tag.expect("Operation ID can not be empty");
let interrupt = spark::interrupt_request::Interrupt::OperationId(op_id);
req.interrupt_type = interrupt_type.into();
req.interrupt = Some(interrupt);
}
spark::interrupt_request::InterruptType::Unspecified => {
return Err(SparkError::AnalysisException(
"Interrupt Type was not specified".to_string(),
))
}
};

let mut client = self.stub.write();

let resp = client.interrupt(req).await?.into_inner();
drop(client);

Ok(resp)
}
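
A rough sketch of how `interrupt_request` might be driven from calling code (the session-level wrappers are not part of this diff, so the client is used directly; assume an async context with a connected, mutable `client`):

```rust
use crate::spark::interrupt_request::InterruptType;

// Cancel everything currently running in this session.
let resp = client.interrupt_request(InterruptType::All, None).await?;
// `interrupted_ids` is the repeated field on InterruptResponse in the Spark Connect proto.
println!("interrupted: {:?}", resp.interrupted_ids);

// Cancel only operations started while the "nightly-etl" tag was attached.
client
    .interrupt_request(InterruptType::Tag, Some("nightly-etl".to_string()))
    .await?;

// Passing `Unspecified` is rejected with an AnalysisException, as handled above.
```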

fn validate_tag(&self, tag: &str) -> Result<(), SparkError> {
if tag.contains(',') {
return Err(SparkError::AnalysisException(
"Spark Connect tag can not contain ',' ".to_string(),
));
};

if tag.is_empty() {
return Err(SparkError::AnalysisException(
"Spark Connect tag can not an empty string ".to_string(),
));
};

Ok(())
}

pub fn add_tag(&mut self, tag: &str) -> Result<(), SparkError> {
self.validate_tag(tag)?;
self.tags.push(tag.to_string());
Ok(())
}

pub fn remove_tag(&mut self, tag: &str) -> Result<(), SparkError> {
self.validate_tag(tag)?;
self.tags.retain(|t| t != tag);
Ok(())
}

pub fn get_tags(&self) -> &Vec<String> {
&self.tags
}

pub fn clear_tags(&mut self) {
self.tags = vec![];
}
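
A quick usage sketch of the tag helpers above (assume a mutable `client`); note that `execute_plan_request_with_metadata` snapshots `self.tags` into every outgoing request:

```rust
client.add_tag("nightly-etl")?;
client.add_tag("step-1")?;

// validate_tag rejects commas and empty strings.
assert!(client.add_tag("a,b").is_err());
assert!(client.add_tag("").is_err());

assert_eq!(
    client.get_tags(),
    &vec!["nightly-etl".to_string(), "step-1".to_string()]
);

// Requests built after this point no longer carry "step-1" ...
client.remove_tag("step-1")?;
// ... and after this, no tags at all.
client.clear_tags();
assert!(client.get_tags().is_empty());
```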

#[allow(clippy::await_holding_lock)]
pub async fn config_request(
&mut self,
operation: spark::config_request::Operation,
) -> Result<spark::ConfigResponse, SparkError> {
let operation = spark::ConfigRequest {
session_id: self.session_id(),
user_context: self.user_context.clone(),
client_type: self.builder.user_agent.clone(),
operation: Some(operation),
};

let mut client = self.stub.write();

let resp = client.config(operation).await?.into_inner();
drop(client);

Ok(resp)
}
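
For reference, a sketch of driving `config_request` directly; the `RunTimeConfig` struct added in core/src/conf.rs below wraps exactly this pattern:

```rust
// Assume an async context with a connected, mutable `client`.
let op_type = spark::config_request::operation::OpType::Get(spark::config_request::Get {
    keys: vec!["spark.sql.shuffle.partitions".to_string()],
});
let operation = spark::config_request::Operation {
    op_type: Some(op_type),
};

let resp = client.config_request(operation).await?;
for pair in resp.pairs {
    // `value` is optional in the proto, hence the Option here.
    println!("{} = {:?}", pair.key, pair.value);
}
```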

fn handle_response(&mut self, resp: spark::ExecutePlanResponse) -> Result<(), SparkError> {
self.validate_session(&resp.session_id)?;

@@ -469,7 +572,6 @@
ResponseType::ArrowBatch(res) => {
self.deserialize(res.data.as_slice(), res.row_count)?
}
// TODO! this shouldn't be clones but okay for now
ResponseType::SqlCommandResult(sql_cmd) => {
self.handler.sql_command_result = Some(sql_cmd.clone())
}
2 changes: 1 addition & 1 deletion core/src/column.rs
@@ -41,7 +41,7 @@ pub struct Column {
pub expression: spark::Expression,
}

/// Trait used to cast columns to specific [DataTypes]
/// Trait used to cast columns to a specific [DataType]
///
/// Either with a String or a [DataType]
pub trait CastToDataType {
139 changes: 139 additions & 0 deletions core/src/conf.rs
@@ -0,0 +1,139 @@
//! Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.

use std::collections::HashMap;

use crate::spark;

use crate::client::{MetadataInterceptor, SparkConnectClient};
use crate::errors::SparkError;

use tonic::service::interceptor::InterceptedService;

#[cfg(not(feature = "wasm"))]
use tonic::transport::Channel;

#[cfg(feature = "wasm")]
use tonic_web_wasm_client::Client;

pub struct RunTimeConfig {
#[cfg(not(feature = "wasm"))]
pub(crate) client: SparkConnectClient<InterceptedService<Channel, MetadataInterceptor>>,

#[cfg(feature = "wasm")]
pub(crate) client: SparkConnectClient<InterceptedService<Client, MetadataInterceptor>>,
}

/// User-facing configuration API, accessible through SparkSession.conf.
///
/// Options set here are automatically propagated to the Hadoop configuration during I/O.
///
/// # Example
/// ```rust
/// spark
/// .conf()
/// .set("spark.sql.shuffle.partitions", "42")
/// .await?;
/// ```
impl RunTimeConfig {
#[allow(dead_code)]
pub(crate) async fn set_configs(
&mut self,
map: &HashMap<String, String>,
) -> Result<(), SparkError> {
for (key, value) in map {
self.set(key.as_str(), value.as_str()).await?
}
Ok(())
}

/// Sets the given Spark runtime configuration property.
pub async fn set(&mut self, key: &str, value: &str) -> Result<(), SparkError> {
let op_type = spark::config_request::operation::OpType::Set(spark::config_request::Set {
pairs: vec![spark::KeyValue {
key: key.into(),
value: Some(value.into()),
}],
});
let operation = spark::config_request::Operation {
op_type: Some(op_type),
};

let _ = self.client.config_request(operation).await?;

Ok(())
}

/// Resets the configuration property for the given key.
pub async fn unset(&mut self, key: &str) -> Result<(), SparkError> {
let op_type =
spark::config_request::operation::OpType::Unset(spark::config_request::Unset {
keys: vec![key.to_string()],
});
let operation = spark::config_request::Operation {
op_type: Some(op_type),
};

let _ = self.client.config_request(operation).await?;

Ok(())
}

/// Returns the value of the Spark runtime configuration property for the given key, or the provided default when the key is not set.
pub async fn get(&mut self, key: &str, default: Option<&str>) -> Result<String, SparkError> {
let operation = match default {
Some(default) => {
let op_type = spark::config_request::operation::OpType::GetWithDefault(
spark::config_request::GetWithDefault {
pairs: vec![spark::KeyValue {
key: key.into(),
value: Some(default.into()),
}],
},
);
spark::config_request::Operation {
op_type: Some(op_type),
}
}
None => {
let op_type =
spark::config_request::operation::OpType::Get(spark::config_request::Get {
keys: vec![key.to_string()],
});
spark::config_request::Operation {
op_type: Some(op_type),
}
}
};

let resp = self.client.config_request(operation).await?;

let val = resp.pairs.first().unwrap().value().to_string();

Ok(val)
}

/// Indicates whether the configuration property with the given key is modifiable in the current session.
#[allow(non_snake_case)]
pub async fn isModifiable(&mut self, key: &str) -> Result<bool, SparkError> {
let op_type = spark::config_request::operation::OpType::IsModifiable(
spark::config_request::IsModifiable {
keys: vec![key.to_string()],
},
);
let operation = spark::config_request::Operation {
op_type: Some(op_type),
};

let resp = self.client.config_request(operation).await?;

let val = resp.pairs.first().unwrap().value();

match val {
"true" => Ok(true),
"false" => Ok(false),
_ => Err(SparkError::AnalysisException(
"Unexpected response value for boolean".to_string(),
)),
}
}
}
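
Putting the pieces together, a rough end-to-end sketch of the runtime-config surface (assuming, as the doc example above suggests, that the session hands back a `RunTimeConfig` via `spark.conf()`; the exact wiring is not part of this diff):

```rust
// Sketch only: `spark` is an already-built SparkSession.
let mut conf = spark.conf();

// Set and read back a property.
conf.set("spark.sql.shuffle.partitions", "42").await?;
let partitions = conf.get("spark.sql.shuffle.partitions", None).await?;
assert_eq!(partitions, "42");

// GetWithDefault falls back to the supplied default when the key is unset.
let missing = conf.get("spark.this.key.does.not.exist", Some("fallback")).await?;
assert_eq!(missing, "fallback");

// Static configs such as the warehouse dir are typically not modifiable at runtime.
assert!(!conf.isModifiable("spark.sql.warehouse.dir").await?);

// Drop the override again.
conf.unset("spark.sql.shuffle.partitions").await?;
```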
1 change: 1 addition & 0 deletions core/src/lib.rs
@@ -118,6 +118,7 @@ pub mod spark {
pub mod catalog;
pub mod client;
pub mod column;
pub mod conf;
pub mod dataframe;
pub mod errors;
pub mod expressions;