diff options
Diffstat (limited to 'adenosine-pds/src/lib.rs')
-rw-r--r-- | adenosine-pds/src/lib.rs | 170 |
1 files changed, 154 insertions, 16 deletions
diff --git a/adenosine-pds/src/lib.rs b/adenosine-pds/src/lib.rs index 0d73881..913b089 100644 --- a/adenosine-pds/src/lib.rs +++ b/adenosine-pds/src/lib.rs @@ -1,10 +1,15 @@ +use adenosine_cli::{AtUri, Did, Nsid, Tid, TidLord}; +use anyhow::Context; use anyhow::{anyhow, Result}; +use libipld::Cid; use libipld::Ipld; use log::{debug, error, info, warn}; use rouille::{router, Request, Response}; use serde_json::{json, Value}; +use std::collections::BTreeMap; use std::fmt; use std::path::PathBuf; +use std::str::FromStr; use std::sync::Mutex; mod car; @@ -20,7 +25,7 @@ pub use car::{load_car_to_blockstore, load_car_to_sqlite}; pub use crypto::{KeyPair, PubKey}; pub use db::AtpDatabase; pub use models::*; -pub use repo::{RepoCommit, RepoStore}; +pub use repo::{Mutation, RepoCommit, RepoStore}; pub use ucan_p256::P256KeyMaterial; struct AtpService { @@ -28,6 +33,7 @@ struct AtpService { pub atp_db: AtpDatabase, pub pds_keypair: KeyPair, pub pds_public_url: String, + pub tid_gen: TidLord, } #[derive(Debug)] @@ -78,6 +84,7 @@ pub fn run_server( atp_db: AtpDatabase::open(atp_db_path)?, pds_keypair: keypair, pds_public_url: format!("http://localhost:{}", port).to_string(), + tid_gen: TidLord::new(), }); let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| { @@ -130,6 +137,31 @@ fn ipld_into_json_value(val: Ipld) -> Value { } } +/// Crude reverse generation +/// +/// Does not handle base64 to bytes, and the link generation is pretty simple (object elements with +/// key "car"). Numbers always come through as f64 (float). +fn json_value_into_ipld(val: Value) -> Ipld { + match val { + Value::Null => Ipld::Null, + Value::Bool(b) => Ipld::Bool(b), + Value::String(s) => Ipld::String(s), + // TODO: handle numbers better? + Value::Number(v) => Ipld::Float(v.as_f64().unwrap()), + Value::Array(l) => Ipld::List(l.into_iter().map(|v| json_value_into_ipld(v)).collect()), + Value::Object(m) => { + let map: BTreeMap<String, Ipld> = BTreeMap::from_iter(m.into_iter().map(|(k, v)| { + if k == "car" && v.is_string() { + (k, Ipld::Link(Cid::from_str(v.as_str().unwrap()).unwrap())) + } else { + (k, json_value_into_ipld(v)) + } + })); + Ipld::Map(map) + } + } +} + fn xrpc_required_param(request: &Request, key: &str) -> Result<String> { Ok(request.get_param(key).ok_or(XrpcError::BadRequest(format!( "require '{}' query parameter", @@ -138,7 +170,11 @@ fn xrpc_required_param(request: &Request, key: &str) -> Result<String> { } /// Returns DID of validated user -fn xrpc_check_auth_header(srv: &mut AtpService, request: &Request) -> Result<String> { +fn xrpc_check_auth_header( + srv: &mut AtpService, + request: &Request, + req_did: Option<&Did>, +) -> Result<Did> { let header = request .header("Authorization") .ok_or(XrpcError::Forbidden(format!("require auth header")))?; @@ -146,10 +182,17 @@ fn xrpc_check_auth_header(srv: &mut AtpService, request: &Request) -> Result<Str Err(XrpcError::Forbidden(format!("require bearer token")))?; } let jwt = header.split(" ").nth(1).unwrap(); - match srv.atp_db.check_auth_token(&jwt)? { - Some(did) => Ok(did), + let did = match srv.atp_db.check_auth_token(&jwt)? { + Some(did) => did, None => Err(XrpcError::Forbidden(format!("session token not found")))?, + }; + let did = Did::from_str(&did)?; + if req_did.is_some() && Some(&did) != req_did { + Err(XrpcError::Forbidden(format!( + "can only modify your own repo" + )))?; } + Ok(did) } fn xrpc_get_handler( @@ -161,8 +204,8 @@ fn xrpc_get_handler( "com.atproto.getAccountsConfig" => { Ok(json!({"availableUserDomains": ["test"], "inviteCodeRequired": false})) } - "com.atproto.getRecord" => { - let did = xrpc_required_param(request, "did")?; + "com.atproto.repoGetRecord" => { + let did = Did::from_str(&xrpc_required_param(request, "user")?)?; let collection = xrpc_required_param(request, "collection")?; let rkey = xrpc_required_param(request, "rkey")?; let mut srv = srv.lock().expect("service mutex"); @@ -178,7 +221,7 @@ fn xrpc_get_handler( } } "com.atproto.syncGetRoot" => { - let did = xrpc_required_param(request, "did")?; + let did = Did::from_str(&xrpc_required_param(request, "did")?)?; let mut srv = srv.lock().expect("service mutex"); srv.repo .lookup_commit(&did)? @@ -188,15 +231,17 @@ fn xrpc_get_handler( "com.atproto.repoListRecords" => { // TODO: limit, before, after, tid, reverse // TODO: handle non-DID 'user' - // TODO: validate 'collection' as an NSID // TODO: limit result set size - let did = xrpc_required_param(request, "user")?; - let collection = xrpc_required_param(request, "collection")?; + let did = Did::from_str(&xrpc_required_param(request, "user")?)?; + let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?; let mut record_list: Vec<Value> = vec![]; let mut srv = srv.lock().expect("service mutex"); - let full_map = srv.repo.mst_to_map(&did)?; + let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap(); + let last_commit = srv.repo.get_commit(&commit_cid)?; + let full_map = srv.repo.mst_to_map(&last_commit.mst_cid)?; let prefix = format!("/{}/", collection); for (mst_key, cid) in full_map.iter() { + debug!("{}", mst_key); if mst_key.starts_with(&prefix) { let record = srv.repo.get_ipld(cid)?; record_list.push(json!({ @@ -209,15 +254,15 @@ fn xrpc_get_handler( Ok(json!({ "records": record_list })) } "com.atproto.repoDescribe" => { - let did = xrpc_required_param(request, "user")?; + let did = Did::from_str(&xrpc_required_param(request, "user")?)?; // TODO: resolve username? - let username = did.clone(); + let username = did.to_string(); let mut srv = srv.lock().expect("service mutex"); let did_doc = srv.atp_db.get_did_doc(&did)?; let collections: Vec<String> = srv.repo.collections(&did)?; let desc = RepoDescribe { name: username, - did: did, + did: did.to_string(), didDoc: did_doc, collections: collections, nameIsCorrect: true, @@ -292,7 +337,7 @@ fn xrpc_post_handler( Ok(json!(sess)) } "com.atproto.createSession" => { - let req: AccountRequest = rouille::input::json_input(request) + let req: SessionRequest = rouille::input::json_input(request) .map_err(|e| XrpcError::BadRequest(format!("failed to parse JSON body: {}", e)))?; let mut srv = srv.lock().unwrap(); let keypair = srv.pds_keypair.clone(); @@ -304,7 +349,7 @@ fn xrpc_post_handler( } "com.atproto.deleteSession" => { let mut srv = srv.lock().unwrap(); - let _did = xrpc_check_auth_header(&mut srv, request)?; + let _did = xrpc_check_auth_header(&mut srv, request, None)?; let header = request .header("Authorization") .ok_or(XrpcError::Forbidden(format!("require auth header")))?; @@ -319,6 +364,99 @@ fn xrpc_post_handler( }; Ok(json!({})) } + "com.atproto.repoBatchWrite" => { + let batch: RepoBatchWriteBody = rouille::input::json_input(request)?; + // TODO: validate edits against schemas + let did = Did::from_str(&xrpc_required_param(request, "did")?)?; + let mut srv = srv.lock().unwrap(); + let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?; + let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap(); + let last_commit = srv.repo.get_commit(&commit_cid)?; + let mut mutations: Vec<Mutation> = Default::default(); + for w in batch.writes.iter() { + let m = match w.op_type.as_str() { + "create" => Mutation::Create( + Nsid::from_str(&w.collection)?, + // TODO: user input unwrap here + w.rkey + .as_ref() + .map(|t| Tid::from_str(&t).unwrap()) + .unwrap_or_else(|| srv.tid_gen.next()), + json_value_into_ipld(w.value.clone()), + ), + "update" => Mutation::Update( + Nsid::from_str(&w.collection)?, + Tid::from_str(w.rkey.as_ref().unwrap())?, + json_value_into_ipld(w.value.clone()), + ), + "delete" => Mutation::Delete( + Nsid::from_str(&w.collection)?, + Tid::from_str(w.rkey.as_ref().unwrap())?, + ), + _ => Err(anyhow!("unhandled operation type: {}", w.op_type))?, + }; + mutations.push(m); + } + let new_mst_cid = srv.repo.update_mst(&last_commit.mst_cid, &mutations)?; + let new_root_cid = srv.repo.write_root( + &last_commit.meta_cid, + Some(&last_commit.commit_cid), + &new_mst_cid, + )?; + srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?; + // TODO: next handle updates to database + Ok(json!({})) + } + "com.atproto.repoCreateRecord" => { + // TODO: validate edits against schemas + let did = Did::from_str(&xrpc_required_param(request, "did")?)?; + let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?; + let record: Value = rouille::input::json_input(request)?; + let mut srv = srv.lock().unwrap(); + let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?; + debug!("reading commit"); + let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap(); + let last_commit = srv.repo.get_commit(&commit_cid)?; + let mutations: Vec<Mutation> = vec![Mutation::Create( + collection, + srv.tid_gen.next(), + json_value_into_ipld(record), + )]; + debug!("mutating tree"); + let new_mst_cid = srv + .repo + .update_mst(&last_commit.mst_cid, &mutations) + .context("updating MST in repo")?; + debug!("writing new root"); + let new_root_cid = srv.repo.write_root( + &last_commit.meta_cid, + Some(&last_commit.commit_cid), + &new_mst_cid, + )?; + debug!("writing new commit"); + srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?; + // TODO: next handle updates to database + Ok(json!({})) + } + "com.atproto.repoDeleteRecord" => { + let did = Did::from_str(&xrpc_required_param(request, "did")?)?; + let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?; + let tid = Tid::from_str(&xrpc_required_param(request, "rkey")?)?; + let mut srv = srv.lock().unwrap(); + let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?; + let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap(); + let last_commit = srv.repo.get_commit(&commit_cid)?; + let mutations: Vec<Mutation> = vec![Mutation::Delete(collection, tid)]; + let new_mst_cid = srv.repo.update_mst(&last_commit.mst_cid, &mutations)?; + let new_root_cid = srv.repo.write_root( + &last_commit.meta_cid, + Some(&last_commit.commit_cid), + &new_mst_cid, + )?; + srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?; + // TODO: next handle updates to database + Ok(json!({})) + } _ => Err(anyhow!(XrpcError::NotFound(format!( "XRPC endpoint handler not found: {}", method |