From 13220ca46bbc9fd0001c1c942c3b7238e0f596ee Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Wed, 26 Aug 2020 18:21:52 -0700 Subject: refactor errors; fix header names; fmt --- src/lib.rs | 72 ++++++++++++++++++++++++++++++++++++++++--------------------- src/main.rs | 24 ++++++++++----------- 2 files changed, 60 insertions(+), 36 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 652e6be..9093b24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -use hyper::{Body, Method, Request, Uri}; +use hyper::{Body, Method, Request, StatusCode, Uri}; use serde::Deserialize; use serde_json::json; @@ -36,21 +36,42 @@ impl ProxyConfig { #[derive(Debug)] pub enum ProxyError { - Malformed(String), + HttpError(String), ParseError(String), - NotAllowed(String), + UnknownIndex(String), NotSupported(String), - NotFound(String), } impl ProxyError { - pub fn to_json(&self) -> serde_json::Value { + pub fn http_status_code(&self) -> StatusCode { + match self { + ProxyError::HttpError(_) => StatusCode::BAD_REQUEST, + ProxyError::ParseError(_) => StatusCode::BAD_REQUEST, + ProxyError::UnknownIndex(_) => StatusCode::NOT_FOUND, + ProxyError::NotSupported(_) => StatusCode::FORBIDDEN, + } + } + + pub fn to_json_value(&self) -> serde_json::Value { + let (type_slug, reason) = match self { + ProxyError::HttpError(s) => ("http-error", s.clone()), + ProxyError::ParseError(s) => ("parse-error", s.clone()), + ProxyError::UnknownIndex(index) => ( + "unknown-index", + format!( + "index does not exists, or public access not allowed: {}", + index + ), + ), + ProxyError::NotSupported(s) => ("not-supported", s.clone()), + }; + json!({ "error": { - "reason": format!("{:?}", self), - "type": "unknown", + "reason": reason, + "type": type_slug, }, - "status": 500, + "status": self.http_status_code().as_u16(), }) } } @@ -79,11 +100,13 @@ pub async fn filter_request( // this is sort of like a router let body = match (&parts.method, path_chunks.as_slice()) { (&Method::GET, [""]) | (&Method::HEAD, [""]) | (&Method::OPTIONS, [""]) => Body::empty(), - (&Method::HEAD, ["_search", "scroll"]) | (&Method::OPTIONS, ["_search", "scroll"]) => Body::empty(), + (&Method::HEAD, ["_search", "scroll"]) | (&Method::OPTIONS, ["_search", "scroll"]) => { + Body::empty() + } (&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => { let whole_body = hyper::body::to_bytes(body) .await - .map_err(|e| ProxyError::Malformed(e.to_string()))?; + .map_err(|e| ProxyError::HttpError(e.to_string()))?; filter_scroll_request(¶ms, &whole_body, config)? } (&Method::HEAD, [index, "_search"]) | (&Method::OPTIONS, [index, "_search"]) => { @@ -92,7 +115,7 @@ pub async fn filter_request( (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => { let whole_body = hyper::body::to_bytes(body) .await - .map_err(|e| ProxyError::Malformed(e.to_string()))?; + .map_err(|e| ProxyError::HttpError(e.to_string()))?; filter_search_request(index, ¶ms, &whole_body, config)? } (&Method::HEAD, [index, "_count"]) | (&Method::OPTIONS, [index, "_count"]) => { @@ -101,19 +124,26 @@ pub async fn filter_request( (&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => { let whole_body = hyper::body::to_bytes(body) .await - .map_err(|e| ProxyError::Malformed(e.to_string()))?; + .map_err(|e| ProxyError::HttpError(e.to_string()))?; filter_search_request(index, ¶ms, &whole_body, config)? } - (&Method::GET, [index, "_doc", _key]) | (&Method::GET, [index, "_source", _key]) | (&Method::HEAD, [index, "_doc", _key]) | (&Method::OPTIONS, [index, "_source", _key]) => { + (&Method::GET, [index, "_doc", _key]) + | (&Method::GET, [index, "_source", _key]) + | (&Method::HEAD, [index, "_doc", _key]) + | (&Method::OPTIONS, [index, "_source", _key]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } - (&Method::GET, [index, ""]) | (&Method::HEAD, [index, ""]) | (&Method::OPTIONS, [index, ""]) => { + (&Method::GET, [index, ""]) + | (&Method::HEAD, [index, ""]) + | (&Method::OPTIONS, [index, ""]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } - (&Method::GET, [index, "_mapping"]) | (&Method::HEAD, [index, "_mapping"]) | (&Method::OPTIONS, [index, "_mapping"]) => { + (&Method::GET, [index, "_mapping"]) + | (&Method::HEAD, [index, "_mapping"]) + | (&Method::OPTIONS, [index, "_mapping"]) => { filter_read_request(index, path_chunks[1], ¶ms, config)? } - _ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?, + _ => Err(ProxyError::NotSupported("unknown elasticsearch API endpoint".to_string()))?, }; let upstream_query = serde_urlencoded::to_string(params).expect("re-encoding URL parameters"); @@ -186,10 +216,7 @@ pub fn filter_read_request( config: &ProxyConfig, ) -> Result { if !config.allow_index(index) { - return Err(ProxyError::NotAllowed(format!( - "index doesn't exist or isn't proxied: {}", - index - ))); + return Err(ProxyError::UnknownIndex(index.to_string())); } Ok(Body::empty()) } @@ -201,10 +228,7 @@ pub fn filter_search_request( config: &ProxyConfig, ) -> Result { if !config.allow_index(index) { - return Err(ProxyError::NotAllowed(format!( - "index doesn't exist or isn't proxied: {}", - index - ))); + return Err(ProxyError::UnknownIndex(index.to_string())); } // XXX: more checks if body.len() > 0 { diff --git a/src/main.rs b/src/main.rs index 5f6b574..b55999b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Request, Response, Server, StatusCode, header::HeaderValue}; +use hyper::{header::HeaderValue, Body, Client, Request, Response, Server}; use std::env; use std::net::SocketAddr; use toml; @@ -23,22 +23,22 @@ async fn upstream_req( Client::new().request(upstream_req).await? } Err(other) => Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) + .status(other.http_status_code()) .header("Content-Type", "application/json; charset=UTF-8") - .body(serde_json::to_string(&other.to_json()).unwrap().into()) + .body( + serde_json::to_string(&other.to_json_value()) + .unwrap() + .into(), + ) .unwrap(), }; - resp.headers_mut().insert( - "Via", - HeaderValue::from_static("es-public-proxy"), - ); + resp.headers_mut() + .insert("Via", HeaderValue::from_static("1.1 es-public-proxy")); if config.enable_cors == Some(true) { + resp.headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); resp.headers_mut().insert( - "Access-Control-Allow-Origin", - HeaderValue::from_static("*"), - ); - resp.headers_mut().insert( - "Access-Control-Allow-Origin", + "Access-Control-Allow-Methods", HeaderValue::from_static("GET, POST, DELETE, HEAD, OPTIONS"), ); resp.headers_mut().insert( -- cgit v1.2.3