summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--fatcat_scholar/web.py22
-rw-r--r--tests/test_web.py12
2 files changed, 22 insertions, 12 deletions
diff --git a/fatcat_scholar/web.py b/fatcat_scholar/web.py
index 0b82df9..fe2ad90 100644
--- a/fatcat_scholar/web.py
+++ b/fatcat_scholar/web.py
@@ -9,7 +9,7 @@ from typing import Optional, Any, List
from pydantic import BaseModel
import babel.support
-from fastapi import FastAPI, APIRouter, Request, Depends, Response
+from fastapi import FastAPI, APIRouter, Request, Depends, Response, HTTPException
from fastapi.staticfiles import StaticFiles
from fastapi.responses import PlainTextResponse, JSONResponse, FileResponse
import sentry_sdk
@@ -93,16 +93,16 @@ class HitsModel(BaseModel):
@api.get("/search", operation_id="get_search", response_model=HitsModel)
async def search(query: FulltextQuery = Depends(FulltextQuery)) -> FulltextHits:
- if query.q is not None:
- try:
- hits: FulltextHits = do_fulltext_search(query)
- except ValueError as e:
- sentry_sdk.set_level("warning")
- sentry_sdk.capture_exception(e)
- raise HTTPException(status_code=400, detail=f"Query Error: {e}")
- except IOError as e:
- sentry_sdk.capture_exception(e)
- raise HTTPException(status_code=500, detail=f"Backend Error: {e}")
+ hits: Optional[FulltextHits] = None
+ try:
+ hits = do_fulltext_search(query)
+ except ValueError as e:
+ sentry_sdk.set_level("warning")
+ sentry_sdk.capture_exception(e)
+ raise HTTPException(status_code=400, detail=f"Query Error: {e}")
+ except IOError as e:
+ sentry_sdk.capture_exception(e)
+ raise HTTPException(status_code=500, detail=f"Backend Error: {e}")
# remove internal context from hit objects
for doc in hits.results:
diff --git a/tests/test_web.py b/tests/test_web.py
index a5629e0..810f8e3 100644
--- a/tests/test_web.py
+++ b/tests/test_web.py
@@ -30,7 +30,7 @@ def test_main_view(client: Any) -> None:
assert "我们是" in resp.content.decode("utf-8")
-def test_basic_api(client: Any) -> None:
+def test_basic_api(client: Any, mocker: Any) -> None:
"""
Simple check of GET routes with application/json support
"""
@@ -39,6 +39,16 @@ def test_basic_api(client: Any) -> None:
assert resp.status_code == 200
assert resp.json()
+ with open("tests/files/elastic_fulltext_search.json") as f:
+ elastic_resp = json.loads(f.read())
+
+ es_raw = mocker.patch(
+ "elasticsearch.connection.Urllib3HttpConnection.perform_request"
+ )
+ es_raw.side_effect = [
+ (200, {}, json.dumps(elastic_resp)),
+ ]
+
resp = client.get("/search", headers=headers)
assert resp.status_code == 200
assert resp.json()