From ae6aafc3d936e7e5211cf117ee581298ed74c8de Mon Sep 17 00:00:00 2001 From: Bryan Newbold Date: Mon, 24 Aug 2020 21:31:58 -0700 Subject: validation progress --- src/lib.rs | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- src/main.rs | 30 ++++++++++--------- 2 files changed, 112 insertions(+), 14 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 8ea77d4..9f5b8a2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,10 +1,14 @@ +use hyper::{Request, Body, Method, Uri}; +use http::request; use serde::{Serialize, Deserialize}; use std::collections::HashMap; + #[derive(Default, Deserialize, Debug, Clone)] pub struct ProxyConfig { - pub bind_addr: Option, + pub bind_addr: Option, // 127.0.0.1:9292 + pub upstream_addr: Option, // 127.0.0.1:9200 pub index: Vec } @@ -13,6 +17,18 @@ pub struct IndexConfig { pub name: String, } +impl ProxyConfig { + + pub fn allow_index(&self, name: &str) -> bool { + for index in &self.index { + if index.name == name { + return true + } + } + false + } +} + #[derive(Serialize, Deserialize, Debug)] pub struct ApiRequest { pub method: String, @@ -251,3 +267,81 @@ pub struct InnerHits { sort: Option>, name: Option, } + +#[derive(Debug)] +pub enum ParsedRequest { + Malformed(String), + ParseError(String), + NotAllowed(String), + NotSupported(String), + NotFound(String), + Allowed(Request), +} + +pub fn parse_request(req: Request, config: &ProxyConfig) -> ParsedRequest { + let (parts, body) = req.into_parts(); + + // split path into at most 3 chunks + let mut req_path = parts.uri.path(); + if req_path.starts_with("/") { + req_path = &req_path[1..]; + } + let path_chunks: Vec<&str> = req_path.split("/").collect(); + if path_chunks.len() > 3 { + return ParsedRequest::NotSupported("only request paths with up to three segments allowed".to_string()) + } + + println!("{:?}", path_chunks); + + // this is sort of like a router + match (&parts.method, path_chunks.as_slice()) { + (&Method::GET, [""]) | (&Method::HEAD, [""]) => { + parse_request_basic("", &parts, config) + }, + (&Method::POST, ["_search", "scroll"]) | (&Method::DELETE, ["_search", "scroll"]) => { + parse_request_scroll(None, &parts, body, config) + }, + (&Method::POST, ["_search", "scroll", key]) | (&Method::DELETE, ["_search", "scroll", key]) => { + parse_request_scroll(Some(key), &parts, body, config) + }, + (&Method::GET, [index, "_search"]) | (&Method::POST, [index, "_search"]) => { + parse_request_search(index, &parts, body, config) + }, + //(Method::GET, [index, "_count"]) => { + // parse_request_count(index, "_count", None, &parts, body, config) + //}, + (&Method::GET, [index, "_doc", key]) | (&Method::GET, [index, "_source", key]) => { + parse_request_read(index, path_chunks[1], key, &parts, body, config) + }, + _ => ParsedRequest::NotSupported("unknown endpoint".to_string()), + } +} + +pub fn parse_request_basic(endpoint: &str, parts: &request::Parts, config: &ProxyConfig) -> ParsedRequest { + // XXX: partial + let upstream_uri = Uri::builder() + .scheme("http") + .authority(config.upstream_addr.as_ref().unwrap_or(&"localhost:9200".to_string()).as_str()) + .path_and_query(format!("/{}", endpoint).as_str()) + .build() + .unwrap(); + println!("{:?}", upstream_uri); + let upstream_req = Request::builder() + .uri(upstream_uri) + .method(&parts.method) + .body(Body::empty()) + .unwrap(); + ParsedRequest::Allowed(upstream_req) +} + +pub fn parse_request_scroll(key: Option<&str>, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { + ParsedRequest::NotSupported("not yet implemented".to_string()) +} + +pub fn parse_request_search(index: &str, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { + ParsedRequest::NotSupported("not yet implemented".to_string()) +} + +pub fn parse_request_read(index: &str, endpoint: &str, key: &str, parts: &request::Parts, body: Body, config: &ProxyConfig) -> ParsedRequest { + ParsedRequest::NotSupported("not yet implemented".to_string()) +} diff --git a/src/main.rs b/src/main.rs index 62cfd37..017b8c8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,25 +1,29 @@ use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Request, Response, Server, Uri}; +use hyper::{Body, Client, Request, Response, Server, Uri, StatusCode}; use std::net::SocketAddr; use std::env; use toml; -use es_public_proxy::ProxyConfig; +use es_public_proxy::{ProxyConfig, ParsedRequest, parse_request}; -async fn upstream_req(req: Request, _config: ProxyConfig) -> Result, hyper::Error> { +async fn upstream_req(req: Request, config: ProxyConfig) -> Result, hyper::Error> { println!("hit: {}", req.uri()); - let req_uri = req.uri(); - let upstream_uri = Uri::builder() - .scheme("http") - .authority("localhost:9200") - .path_and_query(req_uri.path_and_query().unwrap().as_str()) - .build() - .unwrap(); - - let res = Client::new().get(upstream_uri).await?; + let parsed = parse_request(req, &config); + let resp = match parsed { + ParsedRequest::Allowed(upstream_req) => { + println!("sending request..."); + Client::new().request(upstream_req).await? + } + other => { + Response::builder() + .status(StatusCode::NOT_FOUND) + .body(format!("oh noooo! {:?}", other).into()) + .unwrap() + }, + }; println!("resp!"); - Ok(res) + Ok(resp) } async fn shutdown_signal() { -- cgit v1.2.3