diff options
Diffstat (limited to 'src/lib.rs')
| -rw-r--r-- | src/lib.rs | 95 | 
1 files changed, 64 insertions, 31 deletions
| @@ -1,6 +1,5 @@ - +use hyper::{Body, Method, Request, Uri};  use serde::Deserialize; -use hyper::{Request, Body, Method, Uri};  use serde_json::json;  pub mod parse; @@ -9,10 +8,10 @@ use parse::UrlQueryParams;  #[derive(Default, Deserialize, Debug, Clone)]  pub struct ProxyConfig { -    pub bind_addr: Option<String>,      // 127.0.0.1:9292 -    pub upstream_addr: Option<String>,  // 127.0.0.1:9200 +    pub bind_addr: Option<String>,     // 127.0.0.1:9292 +    pub upstream_addr: Option<String>, // 127.0.0.1:9200      pub allow_all_indices: Option<bool>, -    pub index: Vec<IndexConfig> +    pub index: Vec<IndexConfig>,  }  #[derive(Deserialize, Debug, Clone)] @@ -21,14 +20,13 @@ pub struct IndexConfig {  }  impl ProxyConfig { -      pub fn allow_index(&self, name: &str) -> bool {          if self.allow_all_indices == Some(true) { -            return true +            return true;          }          for index in &self.index {              if index.name == name { -                return true +                return true;              }          }          false @@ -45,7 +43,6 @@ pub enum ProxyError {  }  impl ProxyError { -      pub fn to_json(&self) -> serde_json::Value {          json!({              "error": { @@ -57,7 +54,10 @@ impl ProxyError {      }  } -pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<Request<Body>, ProxyError> { +pub async fn filter_request( +    req: Request<Body>, +    config: &ProxyConfig, +) -> Result<Request<Body>, ProxyError> {      let (parts, body) = req.into_parts();      // split path into at most 3 chunks @@ -67,7 +67,9 @@ pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<      }      let path_chunks: Vec<&str> = req_path.split("/").collect();      if path_chunks.len() > 3 { -        return Err(ProxyError::NotSupported("only request paths with up to three segments allowed".to_string())) +        return Err(ProxyError::NotSupported( +            "only request paths with up to three segments allowed".to_string(), +        ));      }      let params: UrlQueryParams = serde_urlencoded::from_str(parts.uri.query().unwrap_or("")) @@ -75,30 +77,28 @@ pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<      // this is sort of like a router      let body = match (&parts.method, path_chunks.as_slice()) { -        (&Method::GET, [""]) | (&Method::HEAD, [""]) => { -            Body::empty() -        }, +        (&Method::GET, [""]) | (&Method::HEAD, [""]) => Body::empty(),          (&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => {              let whole_body = hyper::body::to_bytes(body)                  .await                  .map_err(|e| ProxyError::Malformed(e.to_string()))?;              filter_scroll_request(¶ms, &whole_body, config)? -        }, +        }          (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => {              let whole_body = hyper::body::to_bytes(body)                  .await                  .map_err(|e| ProxyError::Malformed(e.to_string()))?;              filter_search_request(index, ¶ms, &whole_body, config)? -        }, +        }          (&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => {              let whole_body = hyper::body::to_bytes(body)                  .await                  .map_err(|e| ProxyError::Malformed(e.to_string()))?;              filter_search_request(index, ¶ms, &whole_body, config)? -        }, +        }          (&Method::GET, [index, "_doc", key]) | (&Method::GET, [index, "_source", key]) => {              filter_read_request(index, path_chunks[1], key, ¶ms, config)? -        }, +        }          _ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?,      }; @@ -110,7 +110,13 @@ pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<      };      let upstream_uri = Uri::builder()          .scheme("http") -        .authority(config.upstream_addr.as_ref().unwrap_or(&"localhost:9200".to_string()).as_str()) +        .authority( +            config +                .upstream_addr +                .as_ref() +                .unwrap_or(&"localhost:9200".to_string()) +                .as_str(), +        )          .path_and_query(upstream_query_and_params.as_str())          .build()          .expect("constructing upstream request URI"); @@ -123,21 +129,31 @@ pub async fn filter_request(req: Request<Body>, config: &ProxyConfig) -> Result<      Ok(upstream_req)  } -pub fn filter_scroll_request(_params: &UrlQueryParams, body: &[u8], _config: &ProxyConfig) -> Result<Body, ProxyError> { +pub fn filter_scroll_request( +    _params: &UrlQueryParams, +    body: &[u8], +    _config: &ProxyConfig, +) -> Result<Body, ProxyError> {      if body.len() > 0 { -        let parsed: parse::ScrollBody = serde_json::from_slice(body) -            .map_err(|e| ProxyError::ParseError(e.to_string()))?; +        let parsed: parse::ScrollBody = +            serde_json::from_slice(body).map_err(|e| ProxyError::ParseError(e.to_string()))?;          // check that scroll_id is not "_all" or too short          match &parsed.scroll_id {              parse::StringOrArray::String(single) => {                  if single == "_all" || single.len() < 8 { -                    return Err(ProxyError::NotSupported(format!("short scroll_id: {}", single))); +                    return Err(ProxyError::NotSupported(format!( +                        "short scroll_id: {}", +                        single +                    )));                  } -            }, +            }              parse::StringOrArray::Array(array) => {                  for single in array {                      if single == "_all" || single.len() < 8 { -                        return Err(ProxyError::NotSupported(format!("short scroll_id: {}", single))); +                        return Err(ProxyError::NotSupported(format!( +                            "short scroll_id: {}", +                            single +                        )));                      }                  }              } @@ -148,21 +164,38 @@ pub fn filter_scroll_request(_params: &UrlQueryParams, body: &[u8], _config: &Pr      }  } -pub fn filter_read_request(index: &str, _endpoint: &str, _key: &str, _params: &UrlQueryParams, config: &ProxyConfig) -> Result<Body, ProxyError>{ +pub fn filter_read_request( +    index: &str, +    _endpoint: &str, +    _key: &str, +    _params: &UrlQueryParams, +    config: &ProxyConfig, +) -> Result<Body, ProxyError> {      if !config.allow_index(index) { -        return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); +        return Err(ProxyError::NotAllowed(format!( +            "index doesn't exist or isn't proxied: {}", +            index +        )));      }      Ok(Body::empty())  } -pub fn filter_search_request(index: &str, _params: &UrlQueryParams, body: &[u8], config: &ProxyConfig) -> Result<Body, ProxyError> { +pub fn filter_search_request( +    index: &str, +    _params: &UrlQueryParams, +    body: &[u8], +    config: &ProxyConfig, +) -> Result<Body, ProxyError> {      if !config.allow_index(index) { -        return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); +        return Err(ProxyError::NotAllowed(format!( +            "index doesn't exist or isn't proxied: {}", +            index +        )));      }      // XXX: more checks      if body.len() > 0 { -        let parsed: parse::SearchBody = serde_json::from_slice(body) -            .map_err(|e| ProxyError::ParseError(e.to_string()))?; +        let parsed: parse::SearchBody = +            serde_json::from_slice(body).map_err(|e| ProxyError::ParseError(e.to_string()))?;          Ok(Body::from(serde_json::to_string(&parsed).unwrap()))      } else {          Ok(Body::empty()) | 
