diff --git a/Cargo.lock b/Cargo.lock index 6cfeaff9..6968c6e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1206,6 +1206,7 @@ dependencies = [ "metrics-runtime", "rand", "regex", + "reqwest", "rust_decimal", "rustc-hex", "secp256k1", diff --git a/crates/erc20_payment_lib_common/Cargo.toml b/crates/erc20_payment_lib_common/Cargo.toml index ecda14c9..0ef0fe30 100644 --- a/crates/erc20_payment_lib_common/Cargo.toml +++ b/crates/erc20_payment_lib_common/Cargo.toml @@ -29,6 +29,7 @@ rustc-hex = { workspace = true } secp256k1 = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } +reqwest = { workspace = true, features = ["json"] } sha3 = { workspace = true } sqlx = { workspace = true } structopt = { workspace = true } diff --git a/crates/erc20_payment_lib_common/src/dns_over_https_resolver.rs b/crates/erc20_payment_lib_common/src/dns_over_https_resolver.rs new file mode 100644 index 00000000..8e867efd --- /dev/null +++ b/crates/erc20_payment_lib_common/src/dns_over_https_resolver.rs @@ -0,0 +1,73 @@ +use reqwest; +use serde; +use std::env; +use std::io::ErrorKind as IoErrorKind; + +#[derive(serde::Deserialize)] +struct DnsOverHttpsResponse { + #[serde(rename(deserialize = "Answer"))] + pub answer: Vec, +} + +#[derive(serde::Deserialize)] +struct DnsOverHttpsAnswer { + pub data: String, +} + +fn strip_quotes(data: &String) -> String { + let v1 = match data.strip_prefix('"') { + None => data.as_str(), + Some(x) => x, + }; + String::from(match v1.strip_suffix('"') { + None => v1, + Some(x) => x, + }) +} + +pub enum DnsOverHttpsServer { + Google, + Cloudflare, +} + +impl DnsOverHttpsServer { + pub fn get_dns_url(self: &DnsOverHttpsServer) -> &str { + match self { + DnsOverHttpsServer::Google => "https://dns.google/resolve", + DnsOverHttpsServer::Cloudflare => "https://cloudflare-dns.com/dns-query", + } + } +} + +pub async fn resolve_dns_record_https( + record: &str, + record_type: &str, + dns_server: DnsOverHttpsServer, +) -> std::io::Result> { + let result = reqwest::Client::new() + .get(dns_server.get_dns_url()) + .query(&[("name", record), ("type", record_type)]) + .header(reqwest::header::ACCEPT, "application/dns-json") + .send() + .await + .map_err(|_| std::io::Error::new(IoErrorKind::Other, "Couldn't fetch DNS record."))? + .json::() + .await + .map_err(|_| std::io::Error::new(IoErrorKind::Other, "Couldn't fetch DNS record."))? + .answer + .iter() + .map(|a| strip_quotes(&a.data)) + .collect(); + Ok(result) +} + +pub async fn resolve_txt_record_to_string_array_https( + record: &str, + dns_server: DnsOverHttpsServer, +) -> std::io::Result> { + resolve_dns_record_https(record, "TXT", dns_server).await +} + +pub fn should_use_dns_over_https() -> bool { + matches!(env::var("YA_USE_HTTPS_DNS_RESOLVER"), Ok(value) if value == "1") +} diff --git a/crates/erc20_payment_lib_common/src/lib.rs b/crates/erc20_payment_lib_common/src/lib.rs index 206f6739..5340003f 100644 --- a/crates/erc20_payment_lib_common/src/lib.rs +++ b/crates/erc20_payment_lib_common/src/lib.rs @@ -1,4 +1,5 @@ mod db; +pub mod dns_over_https_resolver; pub mod error; mod events; mod metrics; diff --git a/crates/erc20_rpc_pool/src/rpc_pool/pool.rs b/crates/erc20_rpc_pool/src/rpc_pool/pool.rs index eeabc890..cd74ec93 100644 --- a/crates/erc20_rpc_pool/src/rpc_pool/pool.rs +++ b/crates/erc20_rpc_pool/src/rpc_pool/pool.rs @@ -20,6 +20,10 @@ use uuid::Uuid; use web3::transports::Http; use web3::Web3; +use erc20_payment_lib_common::dns_over_https_resolver::{ + resolve_txt_record_to_string_array_https, should_use_dns_over_https, DnsOverHttpsServer, +}; + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Web3ExternalEndpointList { @@ -122,7 +126,7 @@ pub struct Web3RpcPool { pub endpoint_verifier: Arc, } -pub async fn resolve_txt_record_to_string_array(record: &str) -> std::io::Result> { +pub async fn resolve_txt_record_to_string_array_dns(record: &str) -> std::io::Result> { let resolver: TokioAsyncResolver = TokioAsyncResolver::tokio(ResolverConfig::google(), ResolverOpts::default()); @@ -136,6 +140,14 @@ pub async fn resolve_txt_record_to_string_array(record: &str) -> std::io::Result .collect::>()) } +pub async fn resolve_txt_record_to_string_array(record: &str) -> std::io::Result> { + if should_use_dns_over_https() { + resolve_txt_record_to_string_array_https(record, DnsOverHttpsServer::Google).await + } else { + resolve_txt_record_to_string_array_dns(record).await + } +} + pub struct ChooseBestEndpointsResult { pub allowed_endpoints: Vec, pub is_resolving: bool,