diff options
Diffstat (limited to 'adenosine-pds/src/lib.rs')
| -rw-r--r-- | adenosine-pds/src/lib.rs | 120 | 
1 files changed, 101 insertions, 19 deletions
diff --git a/adenosine-pds/src/lib.rs b/adenosine-pds/src/lib.rs index 0ef081d..e52cb69 100644 --- a/adenosine-pds/src/lib.rs +++ b/adenosine-pds/src/lib.rs @@ -1,6 +1,8 @@ -use anyhow::Result; +use anyhow::{anyhow, Result};  use log::{error, info}; -use rouille::{router, Request, Response}; +use rouille::{router, try_or_400, Request, Response}; +use serde_json::json; +use std::fmt;  use std::path::PathBuf;  use std::sync::Mutex; @@ -17,13 +19,61 @@ pub use db::AtpDatabase;  pub use models::*;  pub use repo::{RepoCommit, RepoStore}; +#[derive(Debug, serde::Deserialize, serde::Serialize, PartialEq, Eq)] +struct AccountRequest { +    email: String, +    username: String, +    password: String, +} + +struct AtpService { +    pub repo: RepoStore, +    pub atp_db: AtpDatabase, +} + +#[derive(Debug)] +enum XrpcError { +    BadRequest(String), +    NotFound(String), +    Forbidden(String), +} + +impl std::error::Error for XrpcError {} + +impl fmt::Display for XrpcError { +    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +        match self { +            Self::BadRequest(msg) | Self::NotFound(msg) | Self::Forbidden(msg) => { +                write!(f, "{}", msg) +            } +        } +    } +} + +/// Helper to take an XRPC result (always a JSON object), and transform it to a rouille response +fn xrpc_wrap<S: serde::Serialize>(resp: Result<S>) -> Response { +    match resp { +        Ok(val) => Response::json(&val), +        Err(e) => { +            let msg = e.to_string(); +            let code = match e.downcast_ref::<XrpcError>() { +                Some(XrpcError::BadRequest(_)) => 400, +                Some(XrpcError::NotFound(_)) => 404, +                Some(XrpcError::Forbidden(_)) => 403, +                None => 500, +            }; +            Response::json(&json!({ "message": msg })).with_status_code(code) +        } +    } +} +  pub fn run_server(port: u16, blockstore_db_path: &PathBuf, atp_db_path: &PathBuf) -> Result<()> {      // TODO: some static files? https://github.com/tomaka/rouille/blob/master/examples/static-files.rs -    // TODO: could just open connection on every request? -    let db = Mutex::new(AtpDatabase::open(atp_db_path)?); -    let mut _blockstore: BlockStore<libipld::DefaultParams> = -        BlockStore::open(blockstore_db_path, Default::default())?; +    let srv = Mutex::new(AtpService { +        repo: RepoStore::open(blockstore_db_path)?, +        atp_db: AtpDatabase::open(atp_db_path)?, +    });      let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| {          info!("{} {} ({:?})", req.method(), req.raw_url(), elap); @@ -42,23 +92,55 @@ pub fn run_server(port: u16, blockstore_db_path: &PathBuf, atp_db_path: &PathBuf                  (GET) ["/"] => {                      Response::text("Not much to see here yet!")                  }, -                (GET) ["/xrpc/some.method"] => { -                    Response::text("didn't get a thing") -                    // TODO: reply with query params as a JSON body +                (POST) ["/xrpc/com.atproto.createAccount"] => { +                    let req: AccountRequest = try_or_400!(rouille::input::json_input(request)); +                    let mut srv = srv.lock().unwrap(); +                    xrpc_wrap(srv.atp_db.create_account(&req.username, &req.password, &req.email))                  }, -                (POST) ["/xrpc/other.method"] => { -                    Response::text("didn't get other thing") -                    // TODO: parse and echo back JSON body -                }, - -                (GET) ["/xrpc/com.atproto.getRecord"] => { -                    // TODO: JSON response -                    // TODO: handle error -                    let mut db = db.lock().unwrap().new_connection().unwrap(); -                    Response::text(db.get_record("asdf", "123", "blah").unwrap().to_string()) +                (GET) ["/xrpc/com.atproto.{endpoint}", endpoint: String] => { +                    xrpc_wrap(xrpc_get_atproto(&srv, &endpoint, request))                  },                  _ => rouille::Response::empty_404()              )          })      });  } + +fn xrpc_get_atproto( +    srv: &Mutex<AtpService>, +    method: &str, +    request: &Request, +) -> Result<serde_json::Value> { +    match method { +        "getRecord" => { +            let did = request.get_param("user").unwrap(); +            let collection = request.get_param("collection").unwrap(); +            let rkey = request.get_param("rkey").unwrap(); +            let repo_key = format!("/{}/{}", collection, rkey); +            let mut srv = srv.lock().expect("service mutex"); +            let commit_cid = srv.repo.lookup_commit(&did)?.unwrap(); +            let key = format!("/{}/{}", collection, rkey); +            match srv.repo.get_record_by_key(&commit_cid, &key) { +                // TODO: format as JSON, not text debug +                Ok(Some(ipld)) => Ok(json!({ "thing": format!("{:?}", ipld) })), +                Ok(None) => Err(anyhow!(XrpcError::NotFound(format!( +                    "could not find record: {}", +                    key +                )))), +                Err(e) => Err(e), +            } +        } +        "syncGetRoot" => { +            let did = request.get_param("did").unwrap(); +            let mut srv = srv.lock().expect("service mutex"); +            srv.repo +                .lookup_commit(&did)? +                .map(|v| json!({ "root": v })) +                .ok_or(anyhow!("XXX: missing")) +        } +        _ => Err(anyhow!(XrpcError::NotFound(format!( +            "XRPC endpoint handler not found: com.atproto.{}", +            method +        )))), +    } +}  | 
