aboutsummaryrefslogtreecommitdiffstats
path: root/adenosine-pds/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'adenosine-pds/src/lib.rs')
-rw-r--r--adenosine-pds/src/lib.rs170
1 files changed, 154 insertions, 16 deletions
diff --git a/adenosine-pds/src/lib.rs b/adenosine-pds/src/lib.rs
index 0d73881..913b089 100644
--- a/adenosine-pds/src/lib.rs
+++ b/adenosine-pds/src/lib.rs
@@ -1,10 +1,15 @@
+use adenosine_cli::{AtUri, Did, Nsid, Tid, TidLord};
+use anyhow::Context;
use anyhow::{anyhow, Result};
+use libipld::Cid;
use libipld::Ipld;
use log::{debug, error, info, warn};
use rouille::{router, Request, Response};
use serde_json::{json, Value};
+use std::collections::BTreeMap;
use std::fmt;
use std::path::PathBuf;
+use std::str::FromStr;
use std::sync::Mutex;
mod car;
@@ -20,7 +25,7 @@ pub use car::{load_car_to_blockstore, load_car_to_sqlite};
pub use crypto::{KeyPair, PubKey};
pub use db::AtpDatabase;
pub use models::*;
-pub use repo::{RepoCommit, RepoStore};
+pub use repo::{Mutation, RepoCommit, RepoStore};
pub use ucan_p256::P256KeyMaterial;
struct AtpService {
@@ -28,6 +33,7 @@ struct AtpService {
pub atp_db: AtpDatabase,
pub pds_keypair: KeyPair,
pub pds_public_url: String,
+ pub tid_gen: TidLord,
}
#[derive(Debug)]
@@ -78,6 +84,7 @@ pub fn run_server(
atp_db: AtpDatabase::open(atp_db_path)?,
pds_keypair: keypair,
pds_public_url: format!("http://localhost:{}", port).to_string(),
+ tid_gen: TidLord::new(),
});
let log_ok = |req: &Request, _resp: &Response, elap: std::time::Duration| {
@@ -130,6 +137,31 @@ fn ipld_into_json_value(val: Ipld) -> Value {
}
}
+/// Crude reverse generation
+///
+/// Does not handle base64 to bytes, and the link generation is pretty simple (object elements with
+/// key "car"). Numbers always come through as f64 (float).
+fn json_value_into_ipld(val: Value) -> Ipld {
+ match val {
+ Value::Null => Ipld::Null,
+ Value::Bool(b) => Ipld::Bool(b),
+ Value::String(s) => Ipld::String(s),
+ // TODO: handle numbers better?
+ Value::Number(v) => Ipld::Float(v.as_f64().unwrap()),
+ Value::Array(l) => Ipld::List(l.into_iter().map(|v| json_value_into_ipld(v)).collect()),
+ Value::Object(m) => {
+ let map: BTreeMap<String, Ipld> = BTreeMap::from_iter(m.into_iter().map(|(k, v)| {
+ if k == "car" && v.is_string() {
+ (k, Ipld::Link(Cid::from_str(v.as_str().unwrap()).unwrap()))
+ } else {
+ (k, json_value_into_ipld(v))
+ }
+ }));
+ Ipld::Map(map)
+ }
+ }
+}
+
fn xrpc_required_param(request: &Request, key: &str) -> Result<String> {
Ok(request.get_param(key).ok_or(XrpcError::BadRequest(format!(
"require '{}' query parameter",
@@ -138,7 +170,11 @@ fn xrpc_required_param(request: &Request, key: &str) -> Result<String> {
}
/// Returns DID of validated user
-fn xrpc_check_auth_header(srv: &mut AtpService, request: &Request) -> Result<String> {
+fn xrpc_check_auth_header(
+ srv: &mut AtpService,
+ request: &Request,
+ req_did: Option<&Did>,
+) -> Result<Did> {
let header = request
.header("Authorization")
.ok_or(XrpcError::Forbidden(format!("require auth header")))?;
@@ -146,10 +182,17 @@ fn xrpc_check_auth_header(srv: &mut AtpService, request: &Request) -> Result<Str
Err(XrpcError::Forbidden(format!("require bearer token")))?;
}
let jwt = header.split(" ").nth(1).unwrap();
- match srv.atp_db.check_auth_token(&jwt)? {
- Some(did) => Ok(did),
+ let did = match srv.atp_db.check_auth_token(&jwt)? {
+ Some(did) => did,
None => Err(XrpcError::Forbidden(format!("session token not found")))?,
+ };
+ let did = Did::from_str(&did)?;
+ if req_did.is_some() && Some(&did) != req_did {
+ Err(XrpcError::Forbidden(format!(
+ "can only modify your own repo"
+ )))?;
}
+ Ok(did)
}
fn xrpc_get_handler(
@@ -161,8 +204,8 @@ fn xrpc_get_handler(
"com.atproto.getAccountsConfig" => {
Ok(json!({"availableUserDomains": ["test"], "inviteCodeRequired": false}))
}
- "com.atproto.getRecord" => {
- let did = xrpc_required_param(request, "did")?;
+ "com.atproto.repoGetRecord" => {
+ let did = Did::from_str(&xrpc_required_param(request, "user")?)?;
let collection = xrpc_required_param(request, "collection")?;
let rkey = xrpc_required_param(request, "rkey")?;
let mut srv = srv.lock().expect("service mutex");
@@ -178,7 +221,7 @@ fn xrpc_get_handler(
}
}
"com.atproto.syncGetRoot" => {
- let did = xrpc_required_param(request, "did")?;
+ let did = Did::from_str(&xrpc_required_param(request, "did")?)?;
let mut srv = srv.lock().expect("service mutex");
srv.repo
.lookup_commit(&did)?
@@ -188,15 +231,17 @@ fn xrpc_get_handler(
"com.atproto.repoListRecords" => {
// TODO: limit, before, after, tid, reverse
// TODO: handle non-DID 'user'
- // TODO: validate 'collection' as an NSID
// TODO: limit result set size
- let did = xrpc_required_param(request, "user")?;
- let collection = xrpc_required_param(request, "collection")?;
+ let did = Did::from_str(&xrpc_required_param(request, "user")?)?;
+ let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?;
let mut record_list: Vec<Value> = vec![];
let mut srv = srv.lock().expect("service mutex");
- let full_map = srv.repo.mst_to_map(&did)?;
+ let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap();
+ let last_commit = srv.repo.get_commit(&commit_cid)?;
+ let full_map = srv.repo.mst_to_map(&last_commit.mst_cid)?;
let prefix = format!("/{}/", collection);
for (mst_key, cid) in full_map.iter() {
+ debug!("{}", mst_key);
if mst_key.starts_with(&prefix) {
let record = srv.repo.get_ipld(cid)?;
record_list.push(json!({
@@ -209,15 +254,15 @@ fn xrpc_get_handler(
Ok(json!({ "records": record_list }))
}
"com.atproto.repoDescribe" => {
- let did = xrpc_required_param(request, "user")?;
+ let did = Did::from_str(&xrpc_required_param(request, "user")?)?;
// TODO: resolve username?
- let username = did.clone();
+ let username = did.to_string();
let mut srv = srv.lock().expect("service mutex");
let did_doc = srv.atp_db.get_did_doc(&did)?;
let collections: Vec<String> = srv.repo.collections(&did)?;
let desc = RepoDescribe {
name: username,
- did: did,
+ did: did.to_string(),
didDoc: did_doc,
collections: collections,
nameIsCorrect: true,
@@ -292,7 +337,7 @@ fn xrpc_post_handler(
Ok(json!(sess))
}
"com.atproto.createSession" => {
- let req: AccountRequest = rouille::input::json_input(request)
+ let req: SessionRequest = rouille::input::json_input(request)
.map_err(|e| XrpcError::BadRequest(format!("failed to parse JSON body: {}", e)))?;
let mut srv = srv.lock().unwrap();
let keypair = srv.pds_keypair.clone();
@@ -304,7 +349,7 @@ fn xrpc_post_handler(
}
"com.atproto.deleteSession" => {
let mut srv = srv.lock().unwrap();
- let _did = xrpc_check_auth_header(&mut srv, request)?;
+ let _did = xrpc_check_auth_header(&mut srv, request, None)?;
let header = request
.header("Authorization")
.ok_or(XrpcError::Forbidden(format!("require auth header")))?;
@@ -319,6 +364,99 @@ fn xrpc_post_handler(
};
Ok(json!({}))
}
+ "com.atproto.repoBatchWrite" => {
+ let batch: RepoBatchWriteBody = rouille::input::json_input(request)?;
+ // TODO: validate edits against schemas
+ let did = Did::from_str(&xrpc_required_param(request, "did")?)?;
+ let mut srv = srv.lock().unwrap();
+ let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?;
+ let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap();
+ let last_commit = srv.repo.get_commit(&commit_cid)?;
+ let mut mutations: Vec<Mutation> = Default::default();
+ for w in batch.writes.iter() {
+ let m = match w.op_type.as_str() {
+ "create" => Mutation::Create(
+ Nsid::from_str(&w.collection)?,
+ // TODO: user input unwrap here
+ w.rkey
+ .as_ref()
+ .map(|t| Tid::from_str(&t).unwrap())
+ .unwrap_or_else(|| srv.tid_gen.next()),
+ json_value_into_ipld(w.value.clone()),
+ ),
+ "update" => Mutation::Update(
+ Nsid::from_str(&w.collection)?,
+ Tid::from_str(w.rkey.as_ref().unwrap())?,
+ json_value_into_ipld(w.value.clone()),
+ ),
+ "delete" => Mutation::Delete(
+ Nsid::from_str(&w.collection)?,
+ Tid::from_str(w.rkey.as_ref().unwrap())?,
+ ),
+ _ => Err(anyhow!("unhandled operation type: {}", w.op_type))?,
+ };
+ mutations.push(m);
+ }
+ let new_mst_cid = srv.repo.update_mst(&last_commit.mst_cid, &mutations)?;
+ let new_root_cid = srv.repo.write_root(
+ &last_commit.meta_cid,
+ Some(&last_commit.commit_cid),
+ &new_mst_cid,
+ )?;
+ srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?;
+ // TODO: next handle updates to database
+ Ok(json!({}))
+ }
+ "com.atproto.repoCreateRecord" => {
+ // TODO: validate edits against schemas
+ let did = Did::from_str(&xrpc_required_param(request, "did")?)?;
+ let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?;
+ let record: Value = rouille::input::json_input(request)?;
+ let mut srv = srv.lock().unwrap();
+ let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?;
+ debug!("reading commit");
+ let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap();
+ let last_commit = srv.repo.get_commit(&commit_cid)?;
+ let mutations: Vec<Mutation> = vec![Mutation::Create(
+ collection,
+ srv.tid_gen.next(),
+ json_value_into_ipld(record),
+ )];
+ debug!("mutating tree");
+ let new_mst_cid = srv
+ .repo
+ .update_mst(&last_commit.mst_cid, &mutations)
+ .context("updating MST in repo")?;
+ debug!("writing new root");
+ let new_root_cid = srv.repo.write_root(
+ &last_commit.meta_cid,
+ Some(&last_commit.commit_cid),
+ &new_mst_cid,
+ )?;
+ debug!("writing new commit");
+ srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?;
+ // TODO: next handle updates to database
+ Ok(json!({}))
+ }
+ "com.atproto.repoDeleteRecord" => {
+ let did = Did::from_str(&xrpc_required_param(request, "did")?)?;
+ let collection = Nsid::from_str(&xrpc_required_param(request, "collection")?)?;
+ let tid = Tid::from_str(&xrpc_required_param(request, "rkey")?)?;
+ let mut srv = srv.lock().unwrap();
+ let _auth_did = &xrpc_check_auth_header(&mut srv, request, Some(&did))?;
+ let commit_cid = &srv.repo.lookup_commit(&did)?.unwrap();
+ let last_commit = srv.repo.get_commit(&commit_cid)?;
+ let mutations: Vec<Mutation> = vec![Mutation::Delete(collection, tid)];
+ let new_mst_cid = srv.repo.update_mst(&last_commit.mst_cid, &mutations)?;
+ let new_root_cid = srv.repo.write_root(
+ &last_commit.meta_cid,
+ Some(&last_commit.commit_cid),
+ &new_mst_cid,
+ )?;
+ srv.repo.write_commit(&did, &new_root_cid, "dummy-sig")?;
+ // TODO: next handle updates to database
+ Ok(json!({}))
+ }
_ => Err(anyhow!(XrpcError::NotFound(format!(
"XRPC endpoint handler not found: {}",
method