diff options
-rw-r--r-- | fatcat/api.py | 5 | ||||
-rw-r--r-- | tests/test_backend.py | 32 |
2 files changed, 32 insertions, 5 deletions
diff --git a/fatcat/api.py b/fatcat/api.py index e5b473bb..9f2ed29a 100644 --- a/fatcat/api.py +++ b/fatcat/api.py @@ -71,6 +71,8 @@ def api_release_create(): edit_group = get_or_create_edit_group(params.get('editgroup')) creators = params.get('creators', []) creators = [CreatorIdent.query.filter(CreatorIdent.id==c).first_or_404() for c in creators] + targets = [ref['target'] for ref in params.get('refs', []) if ref.get('target') != None] + targets = [ReleaseIdent.query.filter(ReleaseIdent.id==t).first_or_404() for t in targets] work = params.get('work') if work: work = WorkIdent.query.filter(WorkIdent.id==work).first_or_404() @@ -87,6 +89,9 @@ def api_release_create(): contribs = [ReleaseContrib(release=rev, creator=c) for c in creators] rev.creators = contribs db.session.add_all(contribs) + refs = [ReleaseRef(release=rev, target=t) for t in targets] + rev.refs = refs + db.session.add_all(refs) ident = ReleaseIdent(is_live=False, rev=rev) edit = ReleaseEdit(edit_group=edit_group, ident=ident, rev=rev) if params.get('extra', None): diff --git a/tests/test_backend.py b/tests/test_backend.py index 25421560..e6e15f3a 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -236,6 +236,24 @@ class APITestCase(FatcatTestCase): obj = json.loads(rv.data.decode('utf-8')) work_id = obj['id'] + # this stub work will be referenced + rv = self.app.post('/v0/release', + data=json.dumps(dict( + title="derivative work", + work_type="journal-article", + work=work_id, + creators=[creator_id], + doi="10.1234/58", + editgroup=editgroup_id, + refs=[ + dict(stub="some other journal article"), + ], + extra=dict(f=7, b="zing"))), + headers={"content-type": "application/json"}) + assert rv.status_code == 200 + obj = json.loads(rv.data.decode('utf-8')) + stub_release_id = obj['id'] + rv = self.app.post('/v0/release', data=json.dumps(dict( title="dummy work", @@ -246,9 +264,9 @@ class APITestCase(FatcatTestCase): doi="10.1234/5678", editgroup=editgroup_id, refs=[ - dict(stub="some other journal article"), + dict(stub="some book", target=stub_release_id), ], - extra=dict(f=7, b="zing"))), + extra=dict(f=7, b="loopy"))), headers={"content-type": "application/json"}) assert rv.status_code == 200 obj = json.loads(rv.data.decode('utf-8')) @@ -269,9 +287,10 @@ class APITestCase(FatcatTestCase): for cls in (WorkIdent, WorkRev, WorkEdit, ContainerIdent, ContainerRev, ContainerEdit, CreatorIdent, CreatorRev, CreatorEdit, - ReleaseIdent, ReleaseRev, ReleaseEdit, FileIdent, FileRev, FileEdit): assert cls.query.count() == 1 + for cls in (ReleaseIdent, ReleaseRev, ReleaseEdit): + assert cls.query.count() == 2 for cls in (WorkIdent, ContainerIdent, @@ -280,6 +299,7 @@ class APITestCase(FatcatTestCase): FileIdent): assert cls.query.filter(cls.is_live==True).count() == 0 + assert ChangelogEntry.query.count() == 0 rv = self.app.post('/v0/editgroup/{}/accept'.format(editgroup_id), headers={"content-type": "application/json"}) assert rv.status_code == 200 @@ -288,16 +308,17 @@ class APITestCase(FatcatTestCase): for cls in (WorkIdent, WorkRev, WorkEdit, ContainerIdent, ContainerRev, ContainerEdit, CreatorIdent, CreatorRev, CreatorEdit, - ReleaseIdent, ReleaseRev, ReleaseEdit, FileIdent, FileRev, FileEdit): assert cls.query.count() == 1 + for cls in (ReleaseIdent, ReleaseRev, ReleaseEdit): + assert cls.query.count() == 2 for cls in (WorkIdent, ContainerIdent, CreatorIdent, - ReleaseIdent, FileIdent): assert cls.query.filter(cls.is_live==True).count() == 1 + assert ReleaseIdent.query.filter(ReleaseIdent.is_live==True).count() == 2 # Test that foreign key relations worked release_rv = json.loads(self.app.get('/v0/release/{}'.format(release_id)).data.decode('utf-8')) @@ -305,6 +326,7 @@ class APITestCase(FatcatTestCase): assert(release_rv['creators'][0]['creator'] == creator_id) assert(release_rv['container']['id'] == container_id) assert(release_rv['work']['id'] == work_id) + assert(release_rv['refs'][0]['target'] == stub_release_id) file_rv = json.loads(self.app.get('/v0/file/{}'.format(file_id)).data.decode('utf-8')) print(file_rv) |