From 01e5348c1c0ca9fbf2826e4e35d71a743ba28741 Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Mon, 24 Aug 2020 23:08:03 -0700 Subject: more progress on parsing/validating --- src/lib.rs | 87 +++++++++++++++++++++++++++++++++++++++++-------------------- src/main.rs | 18 ++++++++----- 2 files changed, 71 insertions(+), 34 deletions(-) (limited to 'src') diff --git a/src/lib.rs b/src/lib.rs index 5fd54c9..8c60068 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,9 @@ +use std::collections::HashMap; use serde::{Serialize, Deserialize}; use hyper::{Request, Body, Method, Uri}; use http::request; +use url; pub mod parse; @@ -9,6 +11,7 @@ pub mod parse; pub struct ProxyConfig { pub bind_addr: Option, // 127.0.0.1:9292 pub upstream_addr: Option, // 127.0.0.1:9200 + pub allow_all_indices: Option, pub index: Vec } @@ -20,6 +23,9 @@ pub struct IndexConfig { impl ProxyConfig { pub fn allow_index(&self, name: &str) -> bool { + if self.allow_all_indices == Some(true) { + return true + } for index in &self.index { if index.name == name { return true @@ -30,16 +36,15 @@ impl ProxyConfig { } #[derive(Debug)] -pub enum ParsedRequest { +pub enum ProxyError { Malformed(String), ParseError(String), NotAllowed(String), NotSupported(String), NotFound(String), - Allowed(Request), } -pub fn parse_request(req: Request, config: &ProxyConfig) -> ParsedRequest { +pub async fn parse_request(req: Request, config: &ProxyConfig) -> Result, ProxyError> { let (parts, body) = req.into_parts(); // split path into at most 3 chunks @@ -49,60 +54,86 @@ pub fn parse_request(req: Request, config: &ProxyConfig) -> ParsedRequest } let path_chunks: Vec<&str> = req_path.split("/").collect(); if path_chunks.len() > 3 { - return ParsedRequest::NotSupported("only request paths with up to three segments allowed".to_string()) + return Err(ProxyError::NotSupported("only request paths with up to three segments allowed".to_string())) } - println!("{:?}", path_chunks); + let raw_params: HashMap = parts.uri.query() + .map(|v| { + url::form_urlencoded::parse(v.as_bytes()) + .into_owned() + .collect() + }) + .unwrap_or_else(HashMap::new); // this is sort of like a router - match (&parts.method, path_chunks.as_slice()) { + let body = match (&parts.method, path_chunks.as_slice()) { (&Method::GET, [""]) | (&Method::HEAD, [""]) => { - parse_request_basic("", &parts, config) + Body::empty() }, (&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => { - parse_request_scroll(None, &parts, body, config) + let whole_body = hyper::body::to_bytes(body).await.unwrap(); + parse_request_scroll(None, &parts, &whole_body, config)? }, (&Method::POST, ["_search", "scroll", key]) | (&Method::DELETE, ["_search", "scroll", key]) => { - parse_request_scroll(Some(key), &parts, body, config) + let whole_body = hyper::body::to_bytes(body).await.unwrap(); + parse_request_scroll(Some(key), &parts, &whole_body, config)? }, (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => { - parse_request_search(index, &parts, body, config) + let whole_body = hyper::body::to_bytes(body).await.unwrap(); + parse_request_search(index, &parts, &whole_body, config)? + }, + (&Method::GET, [index, "_count"]) | (&Method::POST, [index, "_count"]) => { + let whole_body = hyper::body::to_bytes(body).await.unwrap(); + parse_request_search(index, &parts, &whole_body, config)? }, //(Method::GET, [index, "_count"]) => { - // parse_request_count(index, "_count", None, &parts, body, config) + // parse_request_count(index, "_count", None, &parts, body, config)? //}, (&Method::GET, [index, "_doc", key]) | (&Method::GET, [index, "_source", key]) => { - parse_request_read(index, path_chunks[1], key, &parts, body, config) + parse_request_read(index, path_chunks[1], key, &parts, config)? }, - _ => ParsedRequest::NotSupported("unknown endpoint".to_string()), - } -} + _ => Err(ProxyError::NotSupported("unknown endpoint".to_string()))?, + }; -pub fn parse_request_basic(endpoint: &str, parts: &request::Parts, config: &ProxyConfig) -> ParsedRequest { - // XXX: partial + // TODO: pass-through query parameters let upstream_uri = Uri::builder() .scheme("http") .authority(config.upstream_addr.as_ref().unwrap_or(&"localhost:9200".to_string()).as_str()) - .path_and_query(format!("/{}", endpoint).as_str()) + .path_and_query(format!("{}", req_path).as_str()) .build() .unwrap(); - println!("{:?}", upstream_uri); + let upstream_req = Request::builder() .uri(upstream_uri) .method(&parts.method) - .body(Body::empty()) + .body(body) .unwrap(); - ParsedRequest::Allowed(upstream_req) -} -pub fn parse_request_scroll(key: Option<&str>, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { - ParsedRequest::NotSupported("not yet implemented".to_string()) + Ok(upstream_req) +} +pub fn parse_request_scroll(key: Option<&str>, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result { + // XXX + //let _parsed: ScrollBody = serde_json::from_str(&body).unwrap(); + Err(ProxyError::NotSupported("not yet implemented".to_string())) } -pub fn parse_request_search(index: &str, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { - ParsedRequest::NotSupported("not yet implemented".to_string()) +pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &request::Parts, config: &ProxyConfig) -> Result{ + if !config.allow_index(index) { + return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); + } + // XXX: no body needed? + Ok(Body::empty()) } -pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { - ParsedRequest::NotSupported("not yet implemented".to_string()) +pub fn parse_request_search(index: &str, parts: &request::Parts, body: &[u8], config: &ProxyConfig) -> Result { + if !config.allow_index(index) { + return Err(ProxyError::NotAllowed(format!("index doesn't exist or isn't proxied: {}", index))); + } + // XXX: more checks + if body.len() > 0 { + let parsed: parse::ScrollBody = serde_json::from_slice(body).unwrap(); + Ok(Body::from(serde_json::to_string(&parsed).unwrap())) + } else { + Ok(Body::empty()) + } } diff --git a/src/main.rs b/src/main.rs index 017b8c8..632c159 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,21 @@ use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Request, Response, Server, Uri, StatusCode}; +use hyper::{Body, Client, Request, Response, Server, StatusCode}; use std::net::SocketAddr; use std::env; use toml; -use es_public_proxy::{ProxyConfig, ParsedRequest, parse_request}; +use es_public_proxy::{ProxyConfig, ProxyError, parse_request}; async fn upstream_req(req: Request, config: ProxyConfig) -> Result, hyper::Error> { println!("hit: {}", req.uri()); - let parsed = parse_request(req, &config); + let parsed = parse_request(req, &config).await; let resp = match parsed { - ParsedRequest::Allowed(upstream_req) => { + Ok(upstream_req) => { println!("sending request..."); Client::new().request(upstream_req).await? } - other => { + Err(other) => { Response::builder() .status(StatusCode::NOT_FOUND) .body(format!("oh noooo! {:?}", other).into()) @@ -68,6 +68,7 @@ fn load_config() -> ProxyConfig { let args: Vec = env::args().collect(); let args: Vec<&str> = args.iter().map(|x| x.as_str()).collect(); let mut config_path: Option = None; + let mut allow_all_indices = false; // first parse CLI arg match args.as_slice() { @@ -77,6 +78,7 @@ fn load_config() -> ProxyConfig { std::process::exit(0); }, [_, "--config", p] => { config_path = Some(p.to_string()) }, + [_, "--allow-all-indices"] => { allow_all_indices = true }, _ => { eprintln!("{}", usage()); eprintln!("couldn't parse arguments"); @@ -90,13 +92,17 @@ fn load_config() -> ProxyConfig { } // then either load config file (TOML), or use default config - if let Some(config_path) = config_path { + let mut config = if let Some(config_path) = config_path { let config_toml = std::fs::read_to_string(config_path).unwrap(); let config: ProxyConfig = toml::from_str(&config_toml).unwrap(); config } else { ProxyConfig::default() + }; + if allow_all_indices { + config.allow_all_indices = Some(true); } + config } #[tokio::main] -- cgit v1.2.3