From b92c7b38013a94f8310900e8e04b8b201d69de43 Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Tue, 25 Aug 2020 12:22:17 -0700 Subject: progress: query param parsing, small renamings --- src/lib.rs | 84 ++++++++++++++++++++++++++++++++++++++++-------------------- src/main.rs | 4 +-- src/parse.rs | 8 ++++++ 3 files changed, 66 insertions(+), 30 deletions(-) (limited to 'src') diff --git a/src/lib.rs b/src/lib.rs index 8c60068..ffc06bd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,11 @@ -use std::collections::HashMap; -use serde::{Serialize, Deserialize}; +use serde::Deserialize; use hyper::{Request, Body, Method, Uri}; -use http::request; -use url; pub mod parse; +use parse::UrlQueryParams; + #[derive(Default, Deserialize, Debug, Clone)] pub struct ProxyConfig { pub bind_addr: Option, // 127.0.0.1:9292 @@ -44,7 +43,7 @@ pub enum ProxyError { NotFound(String), } -pub async fn parse_request(req: Request, config: &ProxyConfig) -> Result, ProxyError> { +pub async fn filter_request(req: Request, config: &ProxyConfig) -> Result, ProxyError> { let (parts, body) = req.into_parts(); // split path into at most 3 chunks @@ -57,13 +56,7 @@ pub async fn parse_request(req: Request, config: &ProxyConfig) -> Result = parts.uri.query() - .map(|v| { - url::form_urlencoded::parse(v.as_bytes()) - .into_owned() - .collect() - }) - .unwrap_or_else(HashMap::new); + let params = parse_params(parts.uri.query())?; // this is sort of like a router let body = match (&parts.method, path_chunks.as_slice()) { @@ -72,34 +65,32 @@ pub async fn parse_request(req: Request, config: &ProxyConfig) -> Result { let whole_body = hyper::body::to_bytes(body).await.unwrap(); - parse_request_scroll(None, &parts, &whole_body, config)? - }, - (&Method::POST, ["_search", "scroll", key]) | (&Method::DELETE, ["_search", "scroll", key]) => { - let whole_body = hyper::body::to_bytes(body).await.unwrap(); - parse_request_scroll(Some(key), &parts, &whole_body, config)? + filter_scroll_request(¶ms, &whole_body, config)? }, (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => { let whole_body = hyper::body::to_bytes(body).await.unwrap(); - parse_request_search(index, &parts, &whole_body, config)? + filter_search_request(index, ¶ms, &whole_body, config)? }, (&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => { let whole_body = hyper::body::to_bytes(body).await.unwrap(); - parse_request_search(index, &parts, &whole_body, config)? + filter_search_request(index, ¶ms, &whole_body, config)? }, - //(Method::GET, [index, "_count"]) => { - // parse_request_count(index, "_count", None, &parts, body, config)? - //}, (&Method::GET, [index, "_doc", key]) | (&Method::GET, [index, "_source", key]) => { - parse_request_read(index, path_chunks[1], key, &parts, config)? + filter_read_request(index, path_chunks[1], key, ¶ms, config)? }, _ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?, }; - // TODO: pass-through query parameters + let upstream_query = serialize_params(¶ms); + let upstream_query_and_params = if upstream_query.len() > 0 { + format!("{}?{}", req_path, upstream_query) + } else { + req_path.to_string() + }; let upstream_uri = Uri::builder() .scheme("http") .authority(config.upstream_addr.as_ref().unwrap_or(&"localhost:9200".to_string()).as_str()) - .path_and_query(format!("{}", req_path).as_str()) + .path_and_query(upstream_query_and_params.as_str()) .build() .unwrap(); @@ -111,13 +102,14 @@ pub async fn parse_request(req: Request, config: &ProxyConfig) -> Result, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result { +pub fn filter_scroll_request(_params: &UrlQueryParams, _body: &[u8], _config: &ProxyConfig) -> Result { // XXX + // TODO: check that scroll_id is not "_all" //let _parsed: ScrollBody = serde_json::from_str(&body).unwrap(); Err(ProxyError::NotSupported("not yet implemented".to_string())) } -pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &request::Parts, config: &ProxyConfig) -> Result{ +pub fn filter_read_request(index: &str, _endpoint: &str, _key: &str, _params: &UrlQueryParams, config: &ProxyConfig) -> Result{ if !config.allow_index(index) { return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); } @@ -125,7 +117,7 @@ pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &reques Ok(Body::empty()) } -pub fn parse_request_search(index: &str, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result { +pub fn filter_search_request(index: &str, _params: &UrlQueryParams, body: &[u8], config: &ProxyConfig) -> Result { if !config.allow_index(index) { return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); } @@ -137,3 +129,39 @@ pub fn parse_request_search(index: &str, parts: &request::Parts, body: &[u8], co Ok(Body::empty()) } } + +pub fn parse_params(query: Option<&str>) -> Result { + println!("params: {:?}", query); + let raw_params: serde_json::map::Map = query + .map(|q| { + url::form_urlencoded::parse(q.as_bytes()) + .into_owned() + .map(|(k,v)| (k, serde_json::from_str(&v).unwrap())) + .collect() + }) + .unwrap_or_else(serde_json::map::Map::new); + let parsed: UrlQueryParams = serde_json::from_value(serde_json::Value::Object(raw_params)).unwrap(); + Ok(parsed) +} + +pub fn serialize_params(params: &UrlQueryParams) -> String { + + let json_value = serde_json::to_value(params).unwrap(); + let value_map: serde_json::map::Map = match json_value { + serde_json::Value::Object(val) => val, + _ => panic!("expected an object"), + }; + + let mut builder = url::form_urlencoded::Serializer::new(String::new()); + // XXX: array and object types should raise an error? + for (k, v) in value_map.iter() { + match v { + serde_json::Value::Null | serde_json::Value::Object(_) | serde_json::Value::Array(_) => (), + serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::String(_) => { + let string_val = serde_json::to_string(&v).unwrap(); + builder.append_pair(k, &string_val); + } + } + } + builder.finish() +} diff --git a/src/main.rs b/src/main.rs index 632c159..5e6c20c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,11 +5,11 @@ use std::net::SocketAddr; use std::env; use toml; -use es_public_proxy::{ProxyConfig, ProxyError, parse_request}; +use es_public_proxy::{ProxyConfig, filter_request}; async fn upstream_req(req: Request, config: ProxyConfig) -> Result, hyper::Error> { println!("hit: {}", req.uri()); - let parsed = parse_request(req, &config).await; + let parsed = filter_request(req, &config).await; let resp = match parsed { Ok(upstream_req) => { println!("sending request..."); diff --git a/src/parse.rs b/src/parse.rs index a4cdc3c..0bd1eeb 100644 --- a/src/parse.rs +++ b/src/parse.rs @@ -9,6 +9,8 @@ pub struct ApiRequest { pub body: Option, } +#[derive(Serialize, Deserialize, Debug, Default)] +#[serde(deny_unknown_fields)] pub struct UrlQueryParams { pub allow_no_indices: Option, pub allow_partial_search_results: Option, @@ -45,6 +47,12 @@ pub struct UrlQueryParams { pub track_total_hits: Option, // XXX: bool or integer pub typed_keys: Option, pub version: Option, + + // additional generic params + pub human: Option, + pub pretty: Option, + pub filter_path: Option, + pub error_trace: Option, } // https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html -- cgit v1.2.3