aboutsummaryrefslogtreecommitdiffstats
path: root/adenosine-pds/src/db.rs
blob: 03f6c6851ca37bec82b6afdb6eccd5c157f9f3e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use crate::AtpSession;
/// ATP database (as distinct from blockstore)
use anyhow::{anyhow, Result};
use lazy_static::lazy_static;
use log::debug;
use rusqlite::{params, Connection};
use rusqlite_migration::{Migrations, M};
use serde_json::Value;
use std::path::PathBuf;
use std::str::FromStr;

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the embedded migration list passes `rusqlite_migration` validation.
    #[test]
    fn migrations_test() {
        // `expect` (rather than `assert!(.. .is_ok())`) prints the validation error on failure.
        MIGRATIONS
            .validate()
            .expect("migrations failed to validate");
    }
}

lazy_static! {
    /// Ordered schema migrations for the ATP database; `AtpDatabase::open` and
    /// `AtpDatabase::open_ephemeral` bring a connection up to date with `to_latest`.
    static ref MIGRATIONS: Migrations<'static> = {
        let base_schema = M::up(include_str!("atp_db.sql"));
        Migrations::new(vec![base_schema])
    };
}

/// Handle to the ATP SQLite database (account, session, record and DID document tables).
///
/// Wraps a single connection; use [`AtpDatabase::new_connection`] to get another connection
/// to the same database file.
#[derive(Debug)]
pub struct AtpDatabase {
    conn: Connection,
}

impl AtpDatabase {
    /// Opens (or creates) the database file at `path` and applies any pending migrations.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or a migration fails.
    pub fn open(path: &PathBuf) -> Result<Self> {
        Self::migrated(Connection::open(path)?)
    }

    /// Opens a temporary in-memory database with all migrations applied, eg for tests.
    ///
    /// TODO: should create a temporary file (ideally on a ramdisk) instead of opening an
    /// in-memory database. An in-memory database can't be used with multiple connections, so
    /// [`AtpDatabase::new_connection`] returns an error for the handle returned here.
    pub fn open_ephemeral() -> Result<Self> {
        Self::migrated(Connection::open_in_memory()?)
    }

    /// Brings `conn` up to the latest schema and wraps it.
    ///
    /// Shared by both constructors; any connection-level pragma would happen here.
    fn migrated(mut conn: Connection) -> Result<Self> {
        MIGRATIONS.to_latest(&mut conn)?;
        Ok(AtpDatabase { conn })
    }

    /// Creates an entirely new connection to the same database file.
    ///
    /// Skips re-running migrations.
    ///
    /// # Errors
    ///
    /// Fails for ephemeral (in-memory) databases, which have no file to reopen, and if the
    /// file cannot be opened.
    pub fn new_connection(&self) -> Result<Self> {
        // For in-memory databases `path()` may be `None` or an empty string (depending on the
        // rusqlite/SQLite version). Opening "" would silently create a *different*, private
        // temporary database, so treat both cases as "no real file".
        let path = self
            .conn
            .path()
            .map(PathBuf::from)
            .filter(|p| !p.as_os_str().is_empty())
            .ok_or_else(|| anyhow!("expected real database"))?;
        let conn = Connection::open(path)?;
        Ok(AtpDatabase { conn })
    }

    /// Fetches the record identified by `(did, collection, tid)` and parses its stored JSON.
    ///
    /// # Errors
    ///
    /// Returns an error if no matching row exists, the query fails, or the stored
    /// `record_json` is not valid JSON.
    pub fn get_record(&mut self, did: &str, collection: &str, tid: &str) -> Result<Value> {
        let mut stmt = self.conn.prepare_cached(
            "SELECT record_json FROM record WHERE did = ?1 AND collection = ?2 AND tid = ?3",
        )?;
        let record_json: String =
            stmt.query_row(params!(did, collection, tid), |row| row.get(0))?;
        Ok(Value::from_str(&record_json)?)
    }

    /// Lists the TIDs of all records in `collection` belonging to `did`.
    ///
    /// The query has no `ORDER BY`, so the order is whatever SQLite returns.
    pub fn get_record_list(&mut self, did: &str, collection: &str) -> Result<Vec<String>> {
        let mut stmt = self
            .conn
            .prepare_cached("SELECT tid FROM record WHERE did = ?1 AND collection = ?2")?;
        let tids = stmt
            .query_map(params!(did, collection), |row| row.get(0))?
            .collect::<Result<Vec<String>, _>>()?;
        Ok(tids)
    }

    /// Lists the distinct collections in which `did` has at least one record.
    pub fn get_collection_list(&mut self, did: &str) -> Result<Vec<String>> {
        let mut stmt = self
            .conn
            .prepare_cached("SELECT collection FROM record WHERE did = ?1 GROUP BY collection")?;
        let collections = stmt
            .query_map(params!(did), |row| row.get(0))?
            .collect::<Result<Vec<String>, _>>()?;
        Ok(collections)
    }

    /// Quick check if an account already exists for given username or email.
    pub fn account_exists(&mut self, username: &str, email: &str) -> Result<bool> {
        let mut stmt = self
            .conn
            .prepare_cached("SELECT COUNT(*) FROM account WHERE username = ?1 OR email = ?2")?;
        // SQLite integers are 64-bit; read COUNT(*) as i64.
        let count: i64 = stmt.query_row(params!(username, email), |row| row.get(0))?;
        Ok(count > 0)
    }

    /// Inserts a new account for `did`, storing a bcrypt hash of `password` (not the plaintext).
    ///
    /// # Errors
    ///
    /// Returns an error if hashing fails or the insert fails.
    pub fn create_account(
        &mut self,
        did: &str,
        username: &str,
        password: &str,
        email: &str,
    ) -> Result<()> {
        debug!("bcrypt hashing password (can be slow)...");
        let password_bcrypt = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
        let mut stmt = self.conn.prepare_cached(
            "INSERT INTO account (username, password_bcrypt, email, did) VALUES (?1, ?2, ?3, ?4)",
        )?;
        stmt.execute(params!(username, password_bcrypt, email, did))?;
        Ok(())
    }

    /// Verifies `username`/`password` against the stored bcrypt hash and returns a session.
    ///
    /// NOTE: JWT generation is not implemented yet. Both `accessJwt` and `refreshJwt` are the
    /// placeholder `"jwt:BOGUS"`, and this method does not insert anything into the `session`
    /// table.
    ///
    /// # Errors
    ///
    /// Returns an error if no account has that username, the stored hash cannot be checked,
    /// or the password does not match.
    pub fn create_session(&mut self, username: &str, password: &str) -> Result<AtpSession> {
        let mut stmt = self
            .conn
            .prepare_cached("SELECT did, password_bcrypt FROM account WHERE username = ?1")?;
        let (did, password_bcrypt): (String, String) =
            stmt.query_row(params!(username), |row| Ok((row.get(0)?, row.get(1)?)))?;
        if !bcrypt::verify(password, &password_bcrypt)? {
            return Err(anyhow!("password did not match"));
        }
        // TODO: generate JWT
        // TODO: insert session with JWT
        let jwt = "jwt:BOGUS";
        Ok(AtpSession {
            did,
            name: username.to_string(),
            accessJwt: jwt.to_string(),
            refreshJwt: jwt.to_string(),
        })
    }

    /// Returns the DID that a token is valid for, looked up in the `session` table.
    ///
    /// # Errors
    ///
    /// Returns an error if no session row has this token or the query fails.
    pub fn check_auth_token(&mut self, jwt: &str) -> Result<String> {
        let mut stmt = self
            .conn
            .prepare_cached("SELECT did FROM session WHERE jwt = ?1")?;
        let did = stmt.query_row(params!(jwt), |row| row.get(0))?;
        Ok(did)
    }

    /// Stores `did_doc`, serialized as JSON text, for `did`.
    ///
    /// This is a plain `INSERT`. Whether a second document for the same DID is rejected
    /// depends on the `did_doc` table's constraints (schema not visible here — confirm).
    pub fn put_did_doc(&mut self, did: &str, did_doc: &Value) -> Result<()> {
        let mut stmt = self
            .conn
            .prepare_cached("INSERT INTO did_doc (did, doc_json) VALUES (?1, ?2)")?;
        stmt.execute(params!(did, did_doc.to_string()))?;
        Ok(())
    }
}