aboutsummaryrefslogtreecommitdiffstats
path: root/adenosine/src/mst.rs
diff options
context:
space:
mode:
Diffstat (limited to 'adenosine/src/mst.rs')
-rw-r--r--adenosine/src/mst.rs399
1 files changed, 399 insertions, 0 deletions
diff --git a/adenosine/src/mst.rs b/adenosine/src/mst.rs
new file mode 100644
index 0000000..1d75da1
--- /dev/null
+++ b/adenosine/src/mst.rs
@@ -0,0 +1,399 @@
+//! This is a simple immutable implementation of a Merkle Search Tree (MST) for atproto.
+//!
+//! Mutations on the data structure are not implemented; instead the entire tree is read into a
+//! BTreeMap, which is mutated, and then the entire tree is regenerated. This makes implementation
+//! much simpler, at the obvious cost of performance in some situations.
+//!
+//! The MST is basically a sorted key/value store where the key is a string and the value is a CID.
+//! Tree nodes are stored as DAG-CBOR IPLD blocks, with references as CIDs.
+//!
+//! In the atproto MST implementation, SHA-256 is the hashing algorithm, and "leading zeros" are
+//! counted in blocks of 4 bits (so a leading zero byte counts as two zeros). This happens to match
+//! simple hex encoding of the SHA-256 hash.
+use anyhow::{anyhow, Context, Result};
+use ipfs_sqlite_block_store::BlockStore;
+use libipld::cbor::DagCborCodec;
+use libipld::multihash::Code;
+use libipld::prelude::Codec;
+use libipld::store::DefaultParams;
+use libipld::Block;
+use libipld::{Cid, DagCbor};
+use log::{debug, error, info};
+use std::collections::BTreeMap;
+use std::path::PathBuf;
+
+/// Commit block as stored (DAG-CBOR encoded) in the blockstore.
+///
+/// `dump_mst_keys` decodes one of these from the CID that a blockstore alias points at.
+#[derive(Debug, DagCbor, PartialEq, Eq)]
+pub struct CommitNode {
+ /// CID of the block that `dump_mst_keys` decodes as a `RootNode`.
+ pub root: Cid,
+ /// Raw signature bytes.
+ /// NOTE(review): presumably a signature over `root`; nothing in this file creates or verifies
+ /// it -- confirm against the signing code.
+ pub sig: Box<[u8]>,
+}
+
+/// Root block referenced by `CommitNode::root`; it links to the metadata block and to the top
+/// node of the MST.
+#[derive(Debug, DagCbor, PartialEq, Eq)]
+pub struct RootNode {
+ /// Not read anywhere in this file.
+ /// NOTE(review): meaning of the token is not visible here -- confirm with the writer side.
+ pub auth_token: Option<String>,
+ /// Not read anywhere in this file.
+ /// NOTE(review): presumably links to the preceding commit/root -- confirm.
+ pub prev: Option<Cid>,
+ // TODO: not 'metadata'?
+ /// CID of the block that `dump_mst_keys` decodes as a `MetadataNode`.
+ pub meta: Cid,
+ /// CID of the top-level MST node; `dump_mst_keys` starts printing keys from here.
+ pub data: Cid,
+}
+
+/// Metadata block referenced by `RootNode::meta`.
+#[derive(Debug, DagCbor, PartialEq, Eq)]
+pub struct MetadataNode {
+ /// Name of the datastore format.
+ pub datastore: String, // "mst"
+ /// DID string. NOTE(review): presumably the DID of the repository owner -- confirm.
+ pub did: String,
+ /// Format version number.
+ pub version: u8, // 1
+}
+
+/// One key/value entry inside a serialized `MstNode`.
+///
+/// Keys are prefix-compressed against the previous entry of the same node: readers rebuild the
+/// full key as the first `p` bytes of the previous full key followed by `k` (the first entry of a
+/// node is compressed against the empty string).
+#[derive(Debug, DagCbor, PartialEq, Eq)]
+struct MstEntry {
+ /// Number of leading bytes shared with the previous entry's full key.
+ p: u32,
+ /// Remainder of the key after the shared prefix.
+ k: String,
+ /// CID stored as the value for this key.
+ v: Cid,
+ /// Optional CID of the child node that is traversed right after this entry.
+ t: Option<Cid>,
+}
+
+/// Serialized MST tree node, stored as a DAG-CBOR block and addressed by its CID.
+#[derive(Debug, DagCbor, PartialEq)]
+struct MstNode {
+ /// Optional CID of the child node that is traversed before any entry of this node.
+ l: Option<Cid>,
+ /// Entries of this node, in the order they are traversed.
+ e: Vec<MstEntry>,
+}
+
+/// In-memory ("work in progress") counterpart of `MstEntry`, used while `generate_mst` builds a
+/// tree; child links are owned boxes instead of CIDs and the key is stored uncompressed.
+struct WipEntry {
+ /// Tree layer of this entry; `generate_mst` sets it to `leading_zeros(key)`.
+ height: u8,
+ /// Full (uncompressed) key.
+ key: String,
+ /// CID stored as the value for this key.
+ val: Cid,
+ /// Optional subtree holding the entries inserted after this one on lower layers.
+ right: Option<Box<WipNode>>,
+}
+
+/// In-memory ("work in progress") counterpart of `MstNode`, turned into blocks by
+/// `serialize_wip_tree`.
+struct WipNode {
+ /// Tree layer of this node; every entry stored directly in `entries` has this height.
+ height: u8,
+ /// Optional subtree that is serialized before any entry of this node.
+ left: Option<Box<WipNode>>,
+ /// Entries of this node, in insertion (ascending key) order.
+ entries: Vec<WipEntry>,
+}
+
+/// Fetches the block stored under `cid` in `db` and decodes it as an [`MstNode`].
+///
+/// # Errors
+///
+/// Returns an error if the blockstore lookup fails, if no block is stored under `cid`, or if the
+/// block does not decode as a DAG-CBOR `MstNode`.
+fn get_mst_node(db: &mut BlockStore<libipld::DefaultParams>, cid: &Cid) -> Result<MstNode> {
+    // `ok_or_else` only constructs the error when the block is actually missing; this function is
+    // called once per tree node, so the success path should not pay for building an error value.
+    let block = db
+        .get_block(cid)?
+        .ok_or_else(|| anyhow!("reading MST node from blockstore"))?;
+    DagCborCodec
+        .decode(&block)
+        .context("parsing MST DAG-CBOR IPLD node from blockstore")
+}
+
+/// Prints every key/value pair of the MST rooted at `cid` to stdout, one
+/// `\t<key>\t-> <value CID>` line per entry, in in-order traversal order (left subtree, then each
+/// entry followed by its right subtree).
+///
+/// # Errors
+///
+/// Returns an error if a node cannot be read or decoded, or if an entry's prefix length `p` does
+/// not describe a valid prefix of the previous key in the same node.
+pub fn print_mst_keys(db: &mut BlockStore<libipld::DefaultParams>, cid: &Cid) -> Result<()> {
+    let node = get_mst_node(db, cid)?;
+    if let Some(ref left) = node.l {
+        print_mst_keys(db, left)?;
+    }
+    // Keys are prefix-compressed against the previous entry of the same node; the first entry is
+    // compressed against the empty string.
+    let mut key = String::new();
+    for entry in node.e.iter() {
+        // Checked slicing: `p` comes from stored data, so a value that is too large or that falls
+        // inside a multi-byte character must surface as an error rather than a panic.
+        let prefix = key.get(..entry.p as usize).ok_or_else(|| {
+            anyhow!(
+                "MST entry prefix length {} is not valid for previous key {:?}",
+                entry.p,
+                key
+            )
+        })?;
+        key = format!("{}{}", prefix, entry.k);
+        println!("\t{}\t-> {}", key, entry.v);
+        if let Some(ref right) = entry.t {
+            print_mst_keys(db, right)?;
+        }
+    }
+    Ok(())
+}
+
+/// Opens the blockstore at `db_path` and prints the MST keys reachable from its first alias.
+///
+/// Every alias in the store is first printed to stdout as `<alias> -> <commit CID>`. The first
+/// alias is then followed (commit -> root -> metadata and MST root), each decoded node is logged
+/// at debug level, and finally all MST keys are printed with `print_mst_keys`.
+///
+/// # Errors
+///
+/// Returns an error if the store cannot be opened, if it contains no aliases, or if any of the
+/// commit, root, metadata or MST blocks is missing or fails to decode.
+pub fn dump_mst_keys(db_path: &PathBuf) -> Result<()> {
+    let mut db: BlockStore<libipld::DefaultParams> = BlockStore::open(db_path, Default::default())?;
+
+    let all_aliases: Vec<(Vec<u8>, Cid)> = db.aliases()?;
+    if all_aliases.is_empty() {
+        // Report through the normal error path instead of terminating the whole process from
+        // inside a function that already returns `Result`.
+        error!("expected at least one alias in block store");
+        return Err(anyhow!("expected at least one alias in block store"));
+    }
+
+    // print all the aliases
+    for (alias, commit_cid) in all_aliases.iter() {
+        let did = String::from_utf8_lossy(alias);
+        println!("{} -> {}", did, commit_cid);
+    }
+
+    // Only the first alias is walked; indexing is safe because of the emptiness check above.
+    let (did, commit_cid) = &all_aliases[0];
+    let did = String::from_utf8_lossy(did);
+    info!("starting from {} [{}]", commit_cid, did);
+
+    // NOTE: the faster way to develop would have been to decode to libipld::ipld::Ipld first? meh
+
+    // Fetch the commit block once and use it for both the raw debug dump and the decode.
+    let commit_block = db
+        .get_block(commit_cid)?
+        .ok_or_else(|| anyhow!("expected commit block in store"))?;
+    debug!("raw commit: {:?}", &commit_block);
+    let commit: CommitNode = DagCborCodec.decode(&commit_block)?;
+    debug!("Commit: {:?}", commit);
+
+    let root: RootNode = DagCborCodec.decode(
+        &db.get_block(&commit.root)?
+            .ok_or_else(|| anyhow!("expected root block in store"))?,
+    )?;
+    debug!("Root: {:?}", root);
+
+    let metadata: MetadataNode = DagCborCodec.decode(
+        &db.get_block(&root.meta)?
+            .ok_or_else(|| anyhow!("expected metadata block in store"))?,
+    )?;
+    debug!("Metadata: {:?}", metadata);
+
+    // Decoded here only so the top-level node shows up in the debug log; `print_mst_keys` reads
+    // it again while walking the tree.
+    let mst_node: MstNode = DagCborCodec.decode(
+        &db.get_block(&root.data)?
+            .ok_or_else(|| anyhow!("expected block in store"))?,
+    )?;
+    debug!("MST root node: {:?}", mst_node);
+    debug!("============");
+
+    println!("{}", did);
+    print_mst_keys(&mut db, &root.data)?;
+    Ok(())
+}
+
+/// Walks the MST rooted at `cid` and inserts every `(full key, value CID)` pair into `map`.
+///
+/// Nodes are visited in-order (left subtree, then each entry followed by its right subtree).
+/// Entries already present in `map` under the same key are overwritten.
+///
+/// # Errors
+///
+/// Returns an error if a node cannot be read or decoded, or if an entry's prefix length `p` does
+/// not describe a valid prefix of the previous key in the same node.
+pub fn collect_mst_keys(
+    db: &mut BlockStore<libipld::DefaultParams>,
+    cid: &Cid,
+    map: &mut BTreeMap<String, Cid>,
+) -> Result<()> {
+    let node = get_mst_node(db, cid)?;
+    if let Some(ref left) = node.l {
+        collect_mst_keys(db, left, map)?;
+    }
+    // Keys are prefix-compressed against the previous entry of the same node; the first entry is
+    // compressed against the empty string.
+    let mut key = String::new();
+    for entry in node.e.iter() {
+        // Checked slicing: `p` comes from stored data, so a value that is too large or that falls
+        // inside a multi-byte character must surface as an error rather than a panic.
+        let prefix = key.get(..entry.p as usize).ok_or_else(|| {
+            anyhow!(
+                "MST entry prefix length {} is not valid for previous key {:?}",
+                entry.p,
+                key
+            )
+        })?;
+        key = format!("{}{}", prefix, entry.k);
+        map.insert(key.clone(), entry.v);
+        if let Some(ref right) = entry.t {
+            collect_mst_keys(db, right, map)?;
+        }
+    }
+    Ok(())
+}
+
+/// Returns the number of leading `'0'` characters in the hex-encoded SHA-256 digest of `key`.
+///
+/// Each hex digit covers 4 bits, so this counts leading zero nibbles of the hash. If every digit
+/// is `'0'`, the full digest length is returned.
+fn leading_zeros(key: &str) -> u8 {
+    let hex_digest = sha256::digest(key);
+    hex_digest.bytes().take_while(|&b| b == b'0').count() as u8
+}
+
+// # python code to generate test cases
+// import hashlib
+// seed = b"asdf"
+// while True:
+//     out = hashlib.sha256(seed).hexdigest()
+//     if out.startswith("00"):
+//         print(f"{seed} -> {out}")
+//     seed = b"app.bsky.feed.post/" + out.encode('utf8')[:12]
+
+#[test]
+fn test_leading_zeros() {
+    // (key, expected number of leading '0' hex digits in its SHA-256 digest)
+    let cases: [(&str, u8); 8] = [
+        ("", 0),
+        ("asdf", 0),
+        ("2653ae71", 0),
+        ("88bfafc7", 1),
+        ("2a92d355", 2),
+        ("884976f5", 3),
+        ("app.bsky.feed.post/454397e440ec", 2),
+        ("app.bsky.feed.post/9adeb165882c", 4),
+    ];
+    for &(key, expected) in cases.iter() {
+        assert_eq!(leading_zeros(key), expected, "key: {:?}", key);
+    }
+}
+
+/// Builds a complete MST from `map`, writes all of its nodes to `db`, and returns the CID of the
+/// top-level node.
+///
+/// Entries are fed to `insert_entry` in the map's ascending key order, each at the layer given by
+/// `leading_zeros` of its key. An empty map yields a single empty node.
+pub fn generate_mst(
+    db: &mut BlockStore<libipld::DefaultParams>,
+    map: &BTreeMap<String, Cid>,
+) -> Result<Cid> {
+    // construct a "WIP" tree by folding every entry into the tree built so far
+    let wip_root = map.iter().fold(None::<WipNode>, |tree, (key, val)| {
+        let entry = WipEntry {
+            height: leading_zeros(key),
+            key: key.clone(),
+            val: *val,
+            right: None,
+        };
+        Some(match tree {
+            Some(existing) => insert_entry(existing, entry),
+            // the very first entry becomes a one-entry node at its own layer
+            None => WipNode {
+                height: entry.height,
+                left: None,
+                entries: vec![entry],
+            },
+        })
+    });
+    let wip_root = wip_root.unwrap_or_else(|| WipNode {
+        height: 0,
+        left: None,
+        entries: vec![],
+    });
+    serialize_wip_tree(db, wip_root)
+}
+
+/// Inserts `entry` into the WIP tree rooted at `node` and returns the (possibly new) root.
+///
+/// This routine assumes that entries are added in sorted key order, i.e. the `entry` being added
+/// is "further right" in the tree than any existing entry.
+///
+/// # Panics
+///
+/// Panics if `entry.key` is not greater than the last key it is compared against, or if an
+/// existing node has neither entries nor a left child.
+fn insert_entry(mut node: WipNode, entry: WipEntry) -> WipNode {
+    // entry belongs above the current root: stack empty layers on top until the heights match
+    while node.height < entry.height {
+        node = WipNode {
+            height: node.height + 1,
+            left: Some(Box::new(node)),
+            entries: vec![],
+        }
+    }
+
+    // same layer: append at the end, after checking the ordering assumption
+    if entry.height == node.height {
+        if let Some(last) = node.entries.last() {
+            assert!(entry.key > last.key);
+        }
+        node.entries.push(entry);
+        return node;
+    }
+
+    // entry belongs on a lower layer: descend through the right-most child
+    match node.entries.pop() {
+        // no entries at this node, so the only way down is "left" (which is just "down", not
+        // "before" any entries)
+        None => match node.left.take() {
+            Some(below) => node.left = Some(Box::new(insert_entry(*below, entry))),
+            None => panic!("hit existing totally empty MST node"),
+        },
+        Some(mut last) => {
+            assert!(entry.key > last.key);
+            let subtree = match last.right.take() {
+                Some(existing) => insert_entry(*existing, entry),
+                None => {
+                    let mut fresh = WipNode {
+                        height: entry.height,
+                        left: None,
+                        entries: vec![entry],
+                    };
+                    // may need to insert multiple filler layers so the child sits exactly one
+                    // layer below `node`
+                    while fresh.height + 1 < node.height {
+                        fresh = WipNode {
+                            height: fresh.height + 1,
+                            left: Some(Box::new(fresh)),
+                            entries: vec![],
+                        }
+                    }
+                    fresh
+                }
+            };
+            last.right = Some(Box::new(subtree));
+            node.entries.push(last);
+        }
+    }
+    node
+}
+
+/// Returns the number of leading bytes that `a` and `b` have in common.
+///
+/// The comparison is done on the UTF-8 bytes, not on characters: for plain ASCII keys (which
+/// should hold for current ATP MST keys: collection plus TID) this equals the number of shared
+/// characters, but for other text the result can fall inside a multi-byte character.
+fn common_prefix_len(a: &str, b: &str) -> usize {
+    a.bytes()
+        .zip(b.bytes())
+        .take_while(|(left, right)| left == right)
+        .count()
+}
+
+#[test]
+fn test_common_prefix_len() {
+    // (a, b, expected shared prefix length)
+    let cases: [(&str, &str, usize); 12] = [
+        ("abc", "abc", 3),
+        ("", "abc", 0),
+        ("abc", "", 0),
+        ("ab", "abc", 2),
+        ("abc", "ab", 2),
+        ("abcde", "abc", 3),
+        ("abc", "abcde", 3),
+        ("abcde", "abc1", 3),
+        ("abcde", "abb", 2),
+        ("abcde", "qbb", 0),
+        ("abc", "abc\x00", 3),
+        ("abc\x00", "abc", 3),
+    ];
+    for &(a, b, expected) in cases.iter() {
+        assert_eq!(common_prefix_len(a, b), expected, "a: {:?}, b: {:?}", a, b);
+    }
+}
+
+// Pins the byte-based behaviour of `common_prefix_len` on non-ASCII input: all expected values
+// below are counts of UTF-8 bytes, not of characters.
+#[test]
+fn test_common_prefix_len_wide() {
+ // TODO: these are not cross-language consistent!
+ // `str::len` is the UTF-8 byte length
+ assert_eq!("jalapeño".len(), 9); // 8 in javascript
+ assert_eq!("💩".len(), 4); // 2 in javascript
+ assert_eq!("👩‍👧‍👧".len(), 18); // 8 in javascript
+
+ // many of the below are different in JS; in Rust we *must* cast down to bytes to count
+ // "jalape" (6 bytes) is shared; the next byte differs
+ assert_eq!(common_prefix_len("jalapeño", "jalapeno"), 6);
+ // the whole 9-byte "jalapeño" is shared
+ assert_eq!(common_prefix_len("jalapeñoA", "jalapeñoB"), 9);
+ // "co" plus the first UTF-8 byte shared by 'ö' and 'ü': the prefix ends mid-character
+ assert_eq!(common_prefix_len("coöperative", "coüperative"), 3);
+ assert_eq!(common_prefix_len("abc💩abc", "abcabc"), 3);
+ // the 4-byte emoji plus "ab"
+ assert_eq!(common_prefix_len("💩abc", "💩ab"), 6);
+ // the two emoji sequences first differ in the last byte of their second emoji
+ assert_eq!(common_prefix_len("abc👩‍👦‍👦de", "abc👩‍👧‍👧de"), 13);
+}
+
+/// Recursively converts the WIP tree rooted at `wip_node` into [`MstNode`] blocks, stores every
+/// block in `db`, and returns the CID of the block written for `wip_node` itself.
+///
+/// Children are written before their parent (left subtree first, then each entry's right
+/// subtree), because a parent refers to its children by CID. Entry keys are prefix-compressed
+/// against the previous entry of the same node.
+///
+/// # Errors
+///
+/// Returns an error if encoding a node or writing a block to `db` fails.
+fn serialize_wip_tree(
+    db: &mut BlockStore<libipld::DefaultParams>,
+    wip_node: WipNode,
+) -> Result<Cid> {
+    let left: Option<Cid> = wip_node
+        .left
+        .map(|child| serialize_wip_tree(db, *child))
+        .transpose()?;
+
+    let mut entries: Vec<MstEntry> = Vec::with_capacity(wip_node.entries.len());
+    let mut last_key = String::new();
+    for wip_entry in wip_node.entries {
+        let right: Option<Cid> = wip_entry
+            .right
+            .map(|child| serialize_wip_tree(db, *child))
+            .transpose()?;
+
+        // `common_prefix_len` counts bytes, so for non-ASCII keys the shared prefix can end in
+        // the middle of a multi-byte character. Slicing a `str` there panics, so back off to the
+        // previous character boundary (index 0 always is one). ASCII keys are unaffected.
+        let mut prefix_len = common_prefix_len(&last_key, &wip_entry.key);
+        while !wip_entry.key.is_char_boundary(prefix_len) {
+            prefix_len -= 1;
+        }
+
+        entries.push(MstEntry {
+            k: wip_entry.key[prefix_len..].to_string(),
+            p: prefix_len as u32,
+            v: wip_entry.val,
+            t: right,
+        });
+        last_key = wip_entry.key;
+    }
+
+    let mst_node = MstNode {
+        l: left,
+        e: entries,
+    };
+    let block = Block::<DefaultParams>::encode(DagCborCodec, Code::Sha2_256, &mst_node)?;
+    let cid = *block.cid();
+    db.put_block(block, None)?;
+    Ok(cid)
+}
+
+// Checks that a one-entry node encodes to a known, fixed CID.
+#[test]
+fn test_mst_node_cbor() {
+    use std::str::FromStr;
+
+    // fixed value CID, so the CID of the encoded node is reproducible
+    let value_cid =
+        Cid::from_str("bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454").unwrap();
+    let only_entry = MstEntry {
+        p: 0,
+        k: String::from("asdf"),
+        v: value_cid,
+        t: None,
+    };
+    let node = MstNode {
+        l: None,
+        e: vec![only_entry],
+    };
+
+    let block = Block::<DefaultParams>::encode(DagCborCodec, Code::Sha2_256, &node).unwrap();
+    println!("{:?}", block);
+
+    let node_cid = block.cid().to_string();
+    assert_eq!(
+        node_cid,
+        "bafyreidaftbr35xhh4lzmv5jcoeufqjh75ohzmz6u56v7n2ippbtxdgqqe"
+    );
+}