diff options
Diffstat (limited to 'adenosine-pds/src/lib.rs')
-rw-r--r-- | adenosine-pds/src/lib.rs | 120 |
1 files changed, 101 insertions, 19 deletions
diff --git a/adenosine-pds/src/lib.rs b/adenosine-pds/src/lib.rs index 0ef081d..e52cb69 100644 --- a/adenosine-pds/src/lib.rs +++ b/adenosine-pds/src/lib.rs @@ -1,6 +1,8 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use log::{error, info}; -use rouille::{router, Request, Response}; +use rouille::{router, try_or_400, Request, Response}; +use serde_json::json; +use std::fmt; use std::path::PathBuf; use std::sync::Mutex; @@ -17,13 +19,61 @@ pub use db::AtpDatabase; pub use models::*; pub use repo::{RepoCommit, RepoStore}; +#[derive(Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq)] +struct AccountRequest { + email: String, + username: String, + password: String, +} + +struct AtpService { + pub repo: RepoStore, + pub atp_db: AtpDatabase, +} + +#[derive(Debug)] +enum XrpcError { + BadRequest(String), + NotFound(String), + Forbidden(String), +} + +impl std::error::Error for XrpcError {} + +impl fmt::Display for XrpcError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::BadRequest(msg) | Self::NotFound(msg) | Self::Forbidden(msg) => { + write!(f, "{}", msg) + } + } + } +} + +/// Helper to take an XRPC result (always a JSON object), and transform it to a rouille response +fn xrpc_wrap<S: serde::Serialize>(resp: Result<S>) -> Response { + match resp { + Ok(val) => Response::json(&val), + Err(e) => { + let msg = e.to_string(); + let code = match e.downcast_ref::<XrpcError>() { + Some(XrpcError::BadRequest(_)) => 400, + Some(XrpcError::NotFound(_)) => 404, + Some(XrpcError::Forbidden(_)) => 403, + None => 500, + }; + Response::json(&json!({ "message": msg })).with_status_code(code) + } + } +} + pub fn run_server(port: u16, blockstore_db_path: &PathBuf, atp_db_path: &PathBuf) -> Result<()> { // TODO: some static files? https://github.com/tomaka/rouille/blob/master/examples/static-files.rs - // TODO: could just open connection on every request? - let db = Mutex::new(AtpDatabase::open(atp_db_path)?); - let mut _blockstore: BlockStore<libipld::DefaultParams> = - BlockStore::open(blockstore_db_path, Default::default())?; + let srv = Mutex::new(AtpService { + repo: RepoStore::open(blockstore_db_path)?, + atp_db: AtpDatabase::open(atp_db_path)?, + }); let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| { info!("{} {} ({:?})", req.method(), req.raw_url(), elap); @@ -42,23 +92,55 @@ pub fn run_server(port: u16, blockstore_db_path: &PathBuf, atp_db_path: &PathBuf (GET) ["/"] => { Response::text("Not much to see here yet!") }, - (GET) ["/xrpc/some.method"] => { - Response::text("didn't get a thing") - // TODO: reply with query params as a JSON body + (POST) ["/xrpc/com.atproto.createAccount"] => { + let req: AccountRequest = try_or_400!(rouille::input::json_input(request)); + let mut srv = srv.lock().unwrap(); + xrpc_wrap(srv.atp_db.create_account(&req.username, &req.password, &req.email)) }, - (POST) ["/xrpc/other.method"] => { - Response::text("didn't get other thing") - // TODO: parse and echo back JSON body - }, - - (GET) ["/xrpc/com.atproto.getRecord"] => { - // TODO: JSON response - // TODO: handle error - let mut db = db.lock().unwrap().new_connection().unwrap(); - Response::text(db.get_record("asdf", "123", "blah").unwrap().to_string()) + (GET) ["/xrpc/com.atproto.{endpoint}", endpoint: String] => { + xrpc_wrap(xrpc_get_atproto(&srv, &endpoint, request)) }, _ => rouille::Response::empty_404() ) }) }); } + +fn xrpc_get_atproto( + srv: &Mutex<AtpService>, + method: &str, + request: &Request, +) -> Result<serde_json::Value> { + match method { + "getRecord" => { + let did = request.get_param("user").unwrap(); + let collection = request.get_param("collection").unwrap(); + let rkey = request.get_param("rkey").unwrap(); + let repo_key = format!("/{}/{}", collection, rkey); + let mut srv = srv.lock().expect("service mutex"); + let commit_cid = srv.repo.lookup_commit(&did)?.unwrap(); + let key = format!("/{}/{}", collection, rkey); + match srv.repo.get_record_by_key(&commit_cid, &key) { + // TODO: format as JSON, not text debug + Ok(Some(ipld)) => Ok(json!({ "thing": format!("{:?}", ipld) })), + Ok(None) => Err(anyhow!(XrpcError::NotFound(format!( + "could not find record: {}", + key + )))), + Err(e) => Err(e), + } + } + "syncGetRoot" => { + let did = request.get_param("did").unwrap(); + let mut srv = srv.lock().expect("service mutex"); + srv.repo + .lookup_commit(&did)? + .map(|v| json!({ "root": v })) + .ok_or(anyhow!("XXX: missing")) + } + _ => Err(anyhow!(XrpcError::NotFound(format!( + "XRPC endpoint handler not found: com.atproto.{}", + method + )))), + } +} |