use hyper::{Body, Method, Request, Uri}; use serde::Deserialize; use serde_json::json; pub mod parse; use parse::UrlQueryParams; #[derive(Default, Deserialize, Debug, Clone)] pub struct ProxyConfig { pub bind_addr: Option, // 127.0.0.1:9292 pub upstream_addr: Option, // 127.0.0.1:9200 pub unsafe_all_indices: Option, pub enable_cors: Option, pub index: Vec, } #[derive(Deserialize, Debug, Clone)] pub struct IndexConfig { pub name: String, } impl ProxyConfig { pub fn allow_index(&self, name: &str) -> bool { if self.unsafe_all_indices == Some(true) { return true; } for index in &self.index { if index.name == name { return true; } } false } } #[derive(Debug)] pub enum ProxyError { Malformed(String), ParseError(String), NotAllowed(String), NotSupported(String), NotFound(String), } impl ProxyError { pub fn to_json(&self) -> serde_json::Value { json!({ "error": { "reason": format!("{:?}", self), "type": "unknown", }, "status": 500, }) } } pub async fn filter_request( req: Request, config: &ProxyConfig, ) -> Result, ProxyError> { let (parts, body) = req.into_parts(); // split path into at most 3 chunks let mut req_path = parts.uri.path(); if req_path.starts_with("/") { req_path = &req_path[1..]; } let path_chunks: Vec<&str> = req_path.split("/").collect(); if path_chunks.len() > 3 { return Err(ProxyError::NotSupported( "only request paths with up to three segments allowed".to_string(), )); } let params: UrlQueryParams = serde_urlencoded::from_str(parts.uri.query().unwrap_or("")) .map_err(|e| ProxyError::ParseError(e.to_string()))?; // this is sort of like a router let body = match (&parts.method, path_chunks.as_slice()) { (&Method::GET, [""]) | (&Method::HEAD, [""]) | (&Method::OPTIONS, [""]) => Body::empty(), (&Method::HEAD, ["_search", "scroll"]) | (&Method::OPTIONS, ["_search", "scroll"]) => Body::empty(), (&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => { let whole_body = hyper::body::to_bytes(body) .await .map_err(|e| ProxyError::Malformed(e.to_string()))?; filter_scroll_request(¶ms, &whole_body, config)? } (&Method::HEAD, [index, "_search"]) | (&Method::OPTIONS, [index, "_search"]) => { filter_search_request(index, ¶ms, &[], config)? } (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => { let whole_body = hyper::body::to_bytes(body) .await .map_err(|e| ProxyError::Malformed(e.to_string()))?; filter_search_request(index, ¶ms, &whole_body, config)? } (&Method::HEAD, [index, "_count"]) | (&Method::OPTIONS, [index, "_count"]) => { filter_search_request(index, ¶ms, &[], config)? } (&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => { let whole_body = hyper::body::to_bytes(body) .await .map_err(|e| ProxyError::Malformed(e.to_string()))?; filter_search_request(index, ¶ms, &whole_body, config)? } (&Method::GET, [index, "_doc", _key]) | (&Method::GET, [index, "_source", _key]) | (&Method::HEAD, [index, "_doc", _key]) | (&Method::OPTIONS, [index, "_source", _key]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } (&Method::GET, [index, ""]) | (&Method::HEAD, [index, ""]) | (&Method::OPTIONS, [index, ""]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } (&Method::GET, [index, "_mapping"]) | (&Method::HEAD, [index, "_mapping"]) | (&Method::OPTIONS, [index, "_mapping"]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } _ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?, }; let upstream_query = serde_urlencoded::to_string(params).expect("re-encoding URL parameters"); let upstream_query_and_params = if upstream_query.len() > 0 { format!("{}?{}", req_path, upstream_query) } else { req_path.to_string() }; let upstream_uri = Uri::builder() .scheme("http") .authority( config .upstream_addr .as_ref() .unwrap_or(&"localhost:9200".to_string()) .as_str(), ) .path_and_query(upstream_query_and_params.as_str()) .build() .expect("constructing upstream request URI"); let upstream_req = Request::builder() .uri(upstream_uri) .method(&parts.method) .header("Content-Type", "application/json; charset=UTF-8") .body(body) .expect("constructing upstream request"); Ok(upstream_req) } pub fn filter_scroll_request( _params: &UrlQueryParams, body: &[u8], _config: &ProxyConfig, ) -> Result { if body.len() > 0 { let parsed: parse::ScrollBody = serde_json::from_slice(body).map_err(|e| ProxyError::ParseError(e.to_string()))?; // check that scroll_id is not "_all" or too short match &parsed.scroll_id { parse::StringOrArray::String(single) => { if single == "_all" || single.len() < 8 { return Err(ProxyError::NotSupported(format!( "short scroll_id: {}", single ))); } } parse::StringOrArray::Array(array) => { for single in array { if single == "_all" || single.len() < 8 { return Err(ProxyError::NotSupported(format!( "short scroll_id: {}", single ))); } } } } Ok(Body::from(serde_json::to_string(&parsed).unwrap())) } else { Ok(Body::empty()) } } pub fn filter_read_request( index: &str, _endpoint: &str, _params: &UrlQueryParams, config: &ProxyConfig, ) -> Result { if !config.allow_index(index) { return Err(ProxyError::NotAllowed(format!( "index doesn't exist or isn't proxied: {}", index ))); } Ok(Body::empty()) } pub fn filter_search_request( index: &str, _params: &UrlQueryParams, body: &[u8], config: &ProxyConfig, ) -> Result { if !config.allow_index(index) { return Err(ProxyError::NotAllowed(format!( "index doesn't exist or isn't proxied: {}", index ))); } // XXX: more checks if body.len() > 0 { let parsed: parse::SearchBody = serde_json::from_slice(body).map_err(|e| ProxyError::ParseError(e.to_string()))?; Ok(Body::from(serde_json::to_string(&parsed).unwrap())) } else { Ok(Body::empty()) } }