aboutsummaryrefslogtreecommitdiffstats
path: root/adenosine/src/mst.rs
blob: 0d66d5a64a5583dba7379ce5ce27065f44038749 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
/// This is a simple immutable implementation of a Merkle Search Tree (MST) for atproto.
///
/// Mutations on the data structure are not implemented; instead the entire tree is read into a
/// BTreeMap, that is mutated, and then the entire tree is regenerated. This makes implementation
/// much simpler, at the obvious cost of performance in some situations.
///
/// The MST is basically a sorted key/value store where the key is a string and the value is a CID.
/// Tree nodes are stored as DAG-CBOR IPLD blocks, with references as CIDs.
///
/// In the atproto MST implementation, SHA-256 is the hashing algorithm, and "leading zeros" are
/// counted in blocks of 4 bits (so a leading zero byte counts as two zeros). This happens to match
/// simple hex encoding of the SHA-256 hash.
use anyhow::{anyhow, Context, Result};
use ipfs_sqlite_block_store::BlockStore;
use libipld::cbor::DagCborCodec;
use libipld::multihash::Code;
use libipld::prelude::Codec;
use libipld::store::DefaultParams;
use libipld::Block;
use libipld::{Cid, DagCbor};
use log::{debug, error, info};
use std::collections::BTreeMap;
use std::path::PathBuf;

/// Commit object for a repository, decoded from the block that an alias points at.
///
/// `root` is followed to a [`RootNode`]; `sig` is carried along but not checked anywhere in this
/// file.
#[derive(Debug, DagCbor, PartialEq, Eq)]
pub struct CommitNode {
    // CID of the repository's `RootNode` block
    pub root: Cid,
    // presumably a signature over the commit — NOTE(review): never verified in this file
    pub sig: Box<[u8]>,
}

/// Repository root object, referenced by [`CommitNode::root`].
///
/// Links to the repository metadata (`meta`, decoded as a [`MetadataNode`]) and to the root node
/// of the MST (`data`).
#[derive(Debug, DagCbor, PartialEq, Eq)]
pub struct RootNode {
    // not read by this file; only decoded and logged
    pub auth_token: Option<String>,
    // presumably the previous root/commit in the repo history — NOTE(review): not used here
    pub prev: Option<Cid>,
    // TODO: not 'metadata'?
    // CID of the `MetadataNode` block
    pub meta: Cid,
    // CID of the root `MstNode` of the key/value tree
    pub data: Cid,
}

/// Repository metadata object, referenced by [`RootNode::meta`].
///
/// None of these fields are validated in this file; the value notes below are the expected values.
#[derive(Debug, DagCbor, PartialEq, Eq)]
pub struct MetadataNode {
    pub datastore: String, // "mst"
    // presumably the DID of the repository owner — NOTE(review): only decoded and logged here
    pub did: String,
    pub version: u8, // 1
}

/// One key/value entry of a serialized MST node (DAG-CBOR field names are the single letters).
///
/// Keys are prefix-compressed against the previous entry in the same node: the full key is the
/// first `p` bytes of the previous entry's full key followed by `k` (the first entry of a node is
/// compressed against the empty string).
#[derive(Debug, DagCbor, PartialEq, Eq)]
struct MstEntry {
    // number of leading bytes shared with the previous entry's full key in this node
    p: u32,
    // remainder of the key after the shared prefix
    k: String,
    // value CID stored under this key
    v: Cid,
    // optional subtree traversed immediately after this entry (keys between this and the next)
    t: Option<Cid>,
}

/// A serialized MST node as stored in the blockstore.
///
/// In-order traversal visits the `l` subtree first, then each entry of `e` in turn, each followed
/// by that entry's own right subtree (`MstEntry::t`).
#[derive(Debug, DagCbor, PartialEq)]
struct MstNode {
    // optional subtree holding keys that sort before every entry in `e`
    l: Option<Cid>,
    // entries in ascending key order
    e: Vec<MstEntry>,
}

/// In-memory ("work in progress") entry used while building a new MST in `generate_mst`.
struct WipEntry {
    // layer of this key in the tree, as computed by `leading_zeros(key)`
    height: u8,
    // full (uncompressed) key
    key: String,
    val: Cid,
    // subtree holding keys after this entry (and before the next one at this height)
    right: Option<Box<WipNode>>,
}

/// In-memory ("work in progress") MST node; serialized into `MstNode` blocks by
/// `serialize_wip_tree`.
struct WipNode {
    // layer of this node; all of its entries share this height
    height: u8,
    // subtree holding keys before the first entry (or filler "down" link when `entries` is empty)
    left: Option<Box<WipNode>>,
    entries: Vec<WipEntry>,
}

/// Loads the block for `cid` from `db` and decodes it as a DAG-CBOR [`MstNode`].
///
/// # Errors
///
/// Returns an error if the blockstore lookup fails, if no block is stored under `cid`, or if the
/// block does not decode as an `MstNode`.
fn get_mst_node(db: &mut BlockStore<libipld::DefaultParams>, cid: &Cid) -> Result<MstNode> {
    // build the "missing block" error lazily, and say which CID was missing
    let block = db
        .get_block(cid)?
        .ok_or_else(|| anyhow!("MST node block not found in blockstore: {}", cid))?;
    let mst_node: MstNode = DagCborCodec
        .decode(&block)
        .context("parsing MST DAG-CBOR IPLD node from blockstore")?;
    Ok(mst_node)
}

pub fn print_mst_keys(db: &mut BlockStore<libipld::DefaultParams>, cid: &Cid) -> Result<()> {
    let node = get_mst_node(db, cid)?;
    if let Some(ref left) = node.l {
        print_mst_keys(db, left)?;
    }
    let mut key: String = "".to_string();
    for entry in node.e.iter() {
        key = format!("{}{}", &key[0..entry.p as usize], entry.k);
        println!("\t{}\t-> {}", key, entry.v);
        if let Some(ref right) = entry.t {
            print_mst_keys(db, right)?;
        }
    }
    Ok(())
}

/// Opens the block store at `db_path`, prints every alias in it (as `{did} -> {commit_cid}`),
/// then walks the repository of the *first* alias (commit -> root -> metadata / MST root) and
/// prints all of its MST keys via [`print_mst_keys`].
///
/// Each intermediate object is decoded and logged at `debug` level.
///
/// # Errors
///
/// Returns an error if the store cannot be opened, contains no aliases, or if any expected block
/// is missing or fails to decode.
pub fn dump_mst_keys(db_path: &PathBuf) -> Result<()> {
    let mut db: BlockStore<libipld::DefaultParams> = BlockStore::open(db_path, Default::default())?;

    let all_aliases: Vec<(Vec<u8>, Cid)> = db.aliases()?;
    if all_aliases.is_empty() {
        // report to the caller instead of terminating the whole process from inside a function
        error!("expected at least one alias in block store");
        return Err(anyhow!("expected at least one alias in block store"));
    }

    // print all the aliases
    for (alias, commit_cid) in all_aliases.iter() {
        let did = String::from_utf8_lossy(alias);
        println!("{did} -> {commit_cid}");
    }

    // only the first alias is walked
    let (did, commit_cid) = &all_aliases[0];
    let did = String::from_utf8_lossy(did);
    info!("starting from {} [{}]", commit_cid, did);

    // NOTE: the faster way to develop would have been to decode to libipld::ipld::Ipld first? meh

    // fetch the commit block once; it is both logged raw and decoded
    let commit_block = db
        .get_block(commit_cid)?
        .ok_or_else(|| anyhow!("expected commit block in store"))?;
    debug!("raw commit: {:?}", &commit_block);
    let commit: CommitNode = DagCborCodec
        .decode(&commit_block)
        .context("parsing commit node")?;
    debug!("Commit: {:?}", commit);

    let root_block = db
        .get_block(&commit.root)?
        .ok_or_else(|| anyhow!("expected root block in store"))?;
    let root: RootNode = DagCborCodec
        .decode(&root_block)
        .context("parsing root node")?;
    debug!("Root: {:?}", root);

    let metadata_block = db
        .get_block(&root.meta)?
        .ok_or_else(|| anyhow!("expected metadata block in store"))?;
    let metadata: MetadataNode = DagCborCodec
        .decode(&metadata_block)
        .context("parsing metadata node")?;
    debug!("Metadata: {:?}", metadata);

    let mst_block = db
        .get_block(&root.data)?
        .ok_or_else(|| anyhow!("expected block in store"))?;
    let mst_node: MstNode = DagCborCodec
        .decode(&mst_block)
        .context("parsing MST root node")?;
    debug!("MST root node: {:?}", mst_node);
    debug!("============");

    println!("{did}");
    print_mst_keys(&mut db, &root.data)?;
    Ok(())
}

pub fn collect_mst_keys(
    db: &mut BlockStore<libipld::DefaultParams>,
    cid: &Cid,
    map: &mut BTreeMap<String, Cid>,
) -> Result<()> {
    let node = get_mst_node(db, cid)?;
    if let Some(ref left) = node.l {
        collect_mst_keys(db, left, map)?;
    }
    let mut key: String = "".to_string();
    for entry in node.e.iter() {
        key = format!("{}{}", &key[0..entry.p as usize], entry.k);
        map.insert(key.clone(), entry.v);
        if let Some(ref right) = entry.t {
            collect_mst_keys(db, right, map)?;
        }
    }
    Ok(())
}

/// Returns the MST layer ("height") of `key`: the number of leading `'0'` characters in the hex
/// SHA-256 digest of the key, i.e. leading zeros counted in 4-bit units.
fn leading_zeros(key: &str) -> u8 {
    let hex_digest = sha256::digest(key);
    let zero_digits = hex_digest
        .bytes()
        .take_while(|&digit| digit == b'0')
        .count();
    zero_digits as u8
}

// # python code to generate test cases
// import hashlib
// seed = b"asdf"
// while True:
//     out = hashlib.sha256(seed).hexdigest()
//     if out.startswith("00"):
//         print(f"{seed} -> {out}")
//     seed = b"app.bsky.feed.post/" + out.encode('utf8')[:12]

#[test]
fn test_leading_zeros() {
    // (key, expected number of leading zero hex digits in sha256(key))
    let cases: [(&str, u8); 8] = [
        ("", 0),
        ("asdf", 0),
        ("2653ae71", 0),
        ("88bfafc7", 1),
        ("2a92d355", 2),
        ("884976f5", 3),
        ("app.bsky.feed.post/454397e440ec", 2),
        ("app.bsky.feed.post/9adeb165882c", 4),
    ];
    for &(key, expected) in cases.iter() {
        assert_eq!(leading_zeros(key), expected, "key: {:?}", key);
    }
}

/// Builds a fresh MST from `map`, writes all of its nodes to `db`, and returns the root node CID.
///
/// Entries are fed to `insert_entry` in the `BTreeMap`'s ascending key order, which that routine
/// relies on. An empty map still produces (and stores) a single empty root node.
pub fn generate_mst(
    db: &mut BlockStore<libipld::DefaultParams>,
    map: &BTreeMap<String, Cid>,
) -> Result<Cid> {
    // construct the in-memory "WIP" tree one sorted entry at a time
    let wip_root = map.iter().fold(None::<WipNode>, |tree, (key, val)| {
        let entry = WipEntry {
            height: leading_zeros(key),
            key: key.clone(),
            val: *val,
            right: None,
        };
        let next = match tree {
            Some(existing) => insert_entry(existing, entry),
            // first entry: the tree is just a single node at the entry's height
            None => WipNode {
                height: entry.height,
                left: None,
                entries: vec![entry],
            },
        };
        Some(next)
    });
    let wip_root = wip_root.unwrap_or_else(|| WipNode {
        height: 0,
        left: None,
        entries: vec![],
    });
    serialize_wip_tree(db, wip_root)
}

// Inserts `entry` into the WIP (sub)tree rooted at `node` and returns the new root of that
// (sub)tree, which may be a freshly created, higher layer.
//
// This routine assumes that entries are added in sorted key order. AKA, the `entry` being added is
// "further right" in the tree than any existing entries, so it always ends up either appended
// after the last entry at its own height, or somewhere in the right-most subtree below that.
//
// Panics if keys are not strictly increasing, or if it must descend through a node that has
// neither entries nor a left child.
fn insert_entry(mut node: WipNode, entry: WipEntry) -> WipNode {
    // if we are higher on tree than existing node, replace it (recursively) with new layers first;
    // every existing key sorts before `entry`, so the whole old tree becomes the left-most child
    while entry.height > node.height {
        node = WipNode {
            height: node.height + 1,
            left: Some(Box::new(node)),
            entries: vec![],
        }
    }
    // if we are lower on tree, then need to descend first
    if entry.height < node.height {
        // if no entries at this node, then we should insert down "left" (which is just "down", not
        // "before" any entries)
        if node.entries.is_empty() {
            match node.left.take() {
                Some(left) => {
                    node.left = Some(Box::new(insert_entry(*left, entry)));
                    return node;
                }
                None => panic!("hit existing totally empty MST node"),
            }
        }
        // otherwise the entry belongs in the subtree to the right of the current last entry
        let mut last = node.entries.pop().expect("hit empty existing entry list");
        assert!(entry.key > last.key);
        let subtree = match last.right.take() {
            Some(right) => insert_entry(*right, entry),
            None => {
                // start a new subtree at the entry's own height...
                let mut new_node = WipNode {
                    height: entry.height,
                    left: None,
                    entries: vec![entry],
                };
                // ...then (possibly repeatedly) wrap it in empty filler layers until it sits
                // exactly one level below `node`
                while new_node.height + 1 < node.height {
                    new_node = WipNode {
                        height: new_node.height + 1,
                        left: Some(Box::new(new_node)),
                        entries: vec![],
                    }
                }
                new_node
            }
        };
        last.right = Some(Box::new(subtree));
        node.entries.push(last);
        return node;
    }
    // same height, simply append to end (but verify ordering first)
    assert!(node.height == entry.height);
    if let Some(last) = node.entries.last() {
        assert!(entry.key > last.key);
    }
    node.entries.push(entry);
    node
}

/// Returns how many leading *bytes* `a` and `b` have in common.
///
/// The comparison is byte-wise, so for non-ASCII input the result can fall in the middle of a
/// multi-byte UTF-8 character (e.g. `"coöperative"` vs `"coüperative"` gives 3). Current ATP MST
/// keys (collection plus TID) are expected to be simple ASCII, where bytes and characters coincide.
fn common_prefix_len(a: &str, b: &str) -> usize {
    a.bytes()
        .zip(b.bytes())
        .take_while(|(left, right)| left == right)
        .count()
}

#[test]
fn test_common_prefix_len() {
    // (a, b, expected shared prefix length in bytes); each pair is checked in the order given
    let cases: [(&str, &str, usize); 12] = [
        ("abc", "abc", 3),
        ("", "abc", 0),
        ("abc", "", 0),
        ("ab", "abc", 2),
        ("abc", "ab", 2),
        ("abcde", "abc", 3),
        ("abc", "abcde", 3),
        ("abcde", "abc1", 3),
        ("abcde", "abb", 2),
        ("abcde", "qbb", 0),
        ("abc", "abc\x00", 3),
        ("abc\x00", "abc", 3),
    ];
    for &(a, b, expected) in cases.iter() {
        assert_eq!(common_prefix_len(a, b), expected, "a: {:?}, b: {:?}", a, b);
    }
}

#[test]
fn test_common_prefix_len_wide() {
    // TODO: these are not cross-language consistent!
    // Rust `str::len` counts UTF-8 bytes, while JavaScript string length counts UTF-16 units.
    // 👩‍👧‍👧 = WOMAN + ZWJ + GIRL + ZWJ + GIRL; 👩‍👦‍👦 = WOMAN + ZWJ + BOY + ZWJ + BOY
    let family_girls = "\u{1F469}\u{200D}\u{1F467}\u{200D}\u{1F467}";
    let family_boys = "\u{1F469}\u{200D}\u{1F466}\u{200D}\u{1F466}";

    assert_eq!("jalapeño".len(), 9); // 8 in javascript
    assert_eq!("\u{1F4A9}".len(), 4); // 💩; 2 in javascript
    assert_eq!(family_girls.len(), 18); // 8 in javascript

    // many of the below are different in JS; in Rust we *must* cast down to bytes to count
    let with_boys = format!("abc{}de", family_boys);
    let with_girls = format!("abc{}de", family_girls);
    let cases: [(&str, &str, usize); 6] = [
        ("jalapeño", "jalapeno", 6),
        ("jalapeñoA", "jalapeñoB", 9),
        ("coöperative", "coüperative", 3),
        ("abc\u{1F4A9}abc", "abcabc", 3),
        ("\u{1F4A9}abc", "\u{1F4A9}ab", 6),
        (with_boys.as_str(), with_girls.as_str(), 13),
    ];
    for &(a, b, expected) in cases.iter() {
        assert_eq!(common_prefix_len(a, b), expected, "a: {:?}, b: {:?}", a, b);
    }
}

/// Recursively converts a WIP (sub)tree into MST nodes, stores each node in `db` as a DAG-CBOR
/// block (SHA-256 CIDs), and returns the CID of the node built from `wip_node`.
///
/// Children are written before their parent, since the parent embeds their CIDs. Each entry key
/// is stored prefix-compressed: `p` bytes shared with the previous key in the same node, plus the
/// suffix `k`.
///
/// # Errors
///
/// Returns an error if encoding a node or writing a block to `db` fails.
fn serialize_wip_tree(
    db: &mut BlockStore<libipld::DefaultParams>,
    wip_node: WipNode,
) -> Result<Cid> {
    let left: Option<Cid> = match wip_node.left {
        Some(left) => Some(serialize_wip_tree(db, *left)?),
        None => None,
    };

    let mut entries: Vec<MstEntry> = Vec::with_capacity(wip_node.entries.len());
    let mut last_key = String::new();
    for wip_entry in wip_node.entries {
        let right: Option<Cid> = match wip_entry.right {
            Some(right) => Some(serialize_wip_tree(db, *right)?),
            None => None,
        };
        // `common_prefix_len` counts bytes, so for non-ASCII keys it can stop in the middle of a
        // multi-byte character; slicing there would panic and `k` would not be valid UTF-8. Back
        // off to the previous char boundary (always terminates: 0 is a boundary). No-op for ASCII.
        let mut prefix_len = common_prefix_len(&last_key, &wip_entry.key);
        while !wip_entry.key.is_char_boundary(prefix_len) {
            prefix_len -= 1;
        }
        entries.push(MstEntry {
            k: wip_entry.key[prefix_len..].to_string(),
            p: prefix_len as u32,
            v: wip_entry.val,
            t: right,
        });
        last_key = wip_entry.key;
    }
    let mst_node = MstNode {
        l: left,
        e: entries,
    };
    let block = Block::<DefaultParams>::encode(DagCborCodec, Code::Sha2_256, &mst_node)?;
    let cid = *block.cid();
    db.put_block(block, None)?;
    Ok(cid)
}

#[test]
fn test_mst_node_cbor() {
    use std::str::FromStr;

    // a single-entry node: no left subtree, no right subtree, no shared key prefix
    let value_cid =
        Cid::from_str("bafyreie5cvv4h45feadgeuwhbcutmh6t2ceseocckahdoe6uat64zmz454").unwrap();
    let only_entry = MstEntry {
        p: 0,
        k: "asdf".to_string(),
        v: value_cid,
        t: None,
    };
    let node = MstNode {
        l: None,
        e: vec![only_entry],
    };

    // the DAG-CBOR encoding must hash to the expected CID
    let block = Block::<DefaultParams>::encode(DagCborCodec, Code::Sha2_256, &node).unwrap();
    println!("{block:?}");
    let encoded_cid = block.cid().to_string();
    assert_eq!(
        encoded_cid,
        "bafyreidaftbr35xhh4lzmv5jcoeufqjh75ohzmz6u56v7n2ippbtxdgqqe"
    );
}