summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/lib.rs84
-rw-r--r--src/main.rs4
-rw-r--r--src/parse.rs8
3 files changed, 66 insertions, 30 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 8c60068..ffc06bd 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,12 +1,11 @@
-use std::collections::HashMap;
-use serde::{Serialize, Deserialize};
+use serde::Deserialize;
use hyper::{Request, Body, Method, Uri};
-use http::request;
-use url;
pub mod parse;
+use parse::UrlQueryParams;
+
#[derive(Default, Deserialize, Debug, Clone)]
pub struct ProxyConfig {
pub bind_addr: Option<String>, // 127.0.0.1:9292
@@ -44,7 +43,7 @@ pub enum ProxyError {
NotFound(String),
}
-pub async fn parse_request(req: Request<Body>, config: &ProxyConfig) -> Result<Request<Body>, ProxyError> {
+pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<Request<Body>, ProxyError> {
let (parts, body) = req.into_parts();
// split path into at most 3 chunks
@@ -57,13 +56,7 @@ pub async fn parse_request(req: Request<Body>, config: &ProxyConfig) -> Result<R
return Err(ProxyError::NotSupported("only request paths with up to three segments allowed".to_string()))
}
- let raw_params: HashMap<String, String> = parts.uri.query()
- .map(|v| {
- url::form_urlencoded::parse(v.as_bytes())
- .into_owned()
- .collect()
- })
- .unwrap_or_else(HashMap::new);
+ let params = parse_params(parts.uri.query())?;
// this is sort of like a router
let body = match (&parts.method, path_chunks.as_slice()) {
@@ -72,34 +65,32 @@ pub async fn parse_request(req: Request<Body>, config: &ProxyConfig) -> Result<R
},
(&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => {
let whole_body = hyper::body::to_bytes(body).await.unwrap();
- parse_request_scroll(None, &parts, &whole_body, config)?
- },
- (&Method::POST, ["_search", "scroll", key]) | (&Method::DELETE, ["_search", "scroll", key]) => {
- let whole_body = hyper::body::to_bytes(body).await.unwrap();
- parse_request_scroll(Some(key), &parts, &whole_body, config)?
+ filter_scroll_request(&params, &whole_body, config)?
},
(&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => {
let whole_body = hyper::body::to_bytes(body).await.unwrap();
- parse_request_search(index, &parts, &whole_body, config)?
+ filter_search_request(index, &params, &whole_body, config)?
},
(&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => {
let whole_body = hyper::body::to_bytes(body).await.unwrap();
- parse_request_search(index, &parts, &whole_body, config)?
+ filter_search_request(index, &params, &whole_body, config)?
},
- //(Method::GET, [index, "_count"]) => {
- // parse_request_count(index, "_count", None, &parts, body, config)?
- //},
(&Method::GET, [index, "_doc", key]) | (&Method::GET, [index, "_source", key]) => {
- parse_request_read(index, path_chunks[1], key, &parts, config)?
+ filter_read_request(index, path_chunks[1], key, &params, config)?
},
_ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?,
};
- // TODO: pass-through query parameters
+ let upstream_query = serialize_params(&params);
+ let upstream_query_and_params = if upstream_query.len() > 0 {
+ format!("{}?{}", req_path, upstream_query)
+ } else {
+ req_path.to_string()
+ };
let upstream_uri = Uri::builder()
.scheme("http")
.authority(config.upstream_addr.as_ref().unwrap_or(&"localhost:9200".to_string()).as_str())
- .path_and_query(format!("{}", req_path).as_str())
+ .path_and_query(upstream_query_and_params.as_str())
.build()
.unwrap();
@@ -111,13 +102,14 @@ pub async fn parse_request(req: Request<Body>, config: &ProxyConfig) -> Result<R
Ok(upstream_req)
}
-pub fn parse_request_scroll(key: Option<&str>, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result<Body, ProxyError> {
+pub fn filter_scroll_request(_params: &UrlQueryParams, _body: &[u8], _config: &ProxyConfig) -> Result<Body, ProxyError> {
// XXX
+ // TODO: check that scroll_id is not "_all"
//let _parsed: ScrollBody = serde_json::from_str(&body).unwrap();
Err(ProxyError::NotSupported("not yet implemented".to_string()))
}
-pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &request::Parts, config: &ProxyConfig) -> Result<Body, ProxyError>{
+pub fn filter_read_request(index: &str, _endpoint: &str, _key: &str, _params: &UrlQueryParams, config: &ProxyConfig) -> Result<Body, ProxyError>{
if !config.allow_index(index) {
return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index)));
}
@@ -125,7 +117,7 @@ pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &reques
Ok(Body::empty())
}
-pub fn parse_request_search(index: &str, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result<Body, ProxyError> {
+pub fn filter_search_request(index: &str, _params: &UrlQueryParams, body: &[u8], config: &ProxyConfig) -> Result<Body, ProxyError> {
if !config.allow_index(index) {
return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index)));
}
@@ -137,3 +129,39 @@ pub fn parse_request_search(index: &str, parts: &request::Parts, body: &[u8], co
Ok(Body::empty())
}
}
+
+pub fn parse_params(query: Option<&str>) -> Result<UrlQueryParams, ProxyError> {
+ println!("params: {:?}", query);
+ let raw_params: serde_json::map::Map<String, serde_json::Value> = query
+ .map(|q| {
+ url::form_urlencoded::parse(q.as_bytes())
+ .into_owned()
+ .map(|(k,v)| (k, serde_json::from_str(&v).unwrap()))
+ .collect()
+ })
+ .unwrap_or_else(serde_json::map::Map::new);
+ let parsed: UrlQueryParams = serde_json::from_value(serde_json::Value::Object(raw_params)).unwrap();
+ Ok(parsed)
+}
+
+pub fn serialize_params(params: &UrlQueryParams) -> String {
+
+ let json_value = serde_json::to_value(params).unwrap();
+ let value_map: serde_json::map::Map<String, serde_json::Value> = match json_value {
+ serde_json::Value::Object(val) => val,
+ _ => panic!("expected an object"),
+ };
+
+ let mut builder = url::form_urlencoded::Serializer::new(String::new());
+ // XXX: array and object types should raise an error?
+ for (k, v) in value_map.iter() {
+ match v {
+ serde_json::Value::Null | serde_json::Value::Object(_) | serde_json::Value::Array(_) => (),
+ serde_json::Value::Bool(_) | serde_json::Value::Number(_) | serde_json::Value::String(_) => {
+ let string_val = serde_json::to_string(&v).unwrap();
+ builder.append_pair(k, &string_val);
+ }
+ }
+ }
+ builder.finish()
+}
diff --git a/src/main.rs b/src/main.rs
index 632c159..5e6c20c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -5,11 +5,11 @@ use std::net::SocketAddr;
use std::env;
use toml;
-use es_public_proxy::{ProxyConfig, ProxyError, parse_request};
+use es_public_proxy::{ProxyConfig, filter_request};
async fn upstream_req(req: Request<Body>, config: ProxyConfig) -> Result<Response<Body>, hyper::Error> {
println!("hit: {}", req.uri());
- let parsed = parse_request(req, &config).await;
+ let parsed = filter_request(req, &config).await;
let resp = match parsed {
Ok(upstream_req) => {
println!("sending request...");
diff --git a/src/parse.rs b/src/parse.rs
index a4cdc3c..0bd1eeb 100644
--- a/src/parse.rs
+++ b/src/parse.rs
@@ -9,6 +9,8 @@ pub struct ApiRequest {
pub body: Option<SearchBody>,
}
+#[derive(Serialize, Deserialize, Debug, Default)]
+#[serde(deny_unknown_fields)]
pub struct UrlQueryParams {
pub allow_no_indices: Option<bool>,
pub allow_partial_search_results: Option<bool>,
@@ -45,6 +47,12 @@ pub struct UrlQueryParams {
pub track_total_hits: Option<bool>, // XXX: bool or integer
pub typed_keys: Option<bool>,
pub version: Option<bool>,
+
+ // additional generic params
+ pub human: Option<bool>,
+ pub pretty: Option<bool>,
+ pub filter_path: Option<String>,
+ pub error_trace: Option<bool>,
}
// https://www.elastic.co/guide/en/elasticsearch/reference/current/search-search.html