diff options
-rw-r--r-- | fuzzycat/matching.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/fuzzycat/matching.py b/fuzzycat/matching.py index 2b5f53a..3391e12 100644 --- a/fuzzycat/matching.py +++ b/fuzzycat/matching.py @@ -1,7 +1,7 @@ import os import re import sys -from typing import List, Type, Union +from typing import List, Type, Union, Optional import elasticsearch import elasticsearch_dsl @@ -12,7 +12,12 @@ from fatcat_openapi_client import ContainerEntity, DefaultApi, ReleaseEntity from fuzzycat.entities import entity_from_dict, entity_from_json -def match_release_fuzzy(release: ReleaseEntity, size=5, es=None) -> List[ReleaseEntity]: +def match_release_fuzzy( + release: ReleaseEntity, + size: int = 5, + es: Optional[Union[str, Type[elasticsearch.client.Elasticsearch]]] = None, + api: DefaultApi = None, +) -> List[ReleaseEntity]: """ Given a release entity, return a number similar release entities from fatcat using Elasticsearch. @@ -49,10 +54,9 @@ def match_release_fuzzy(release: ReleaseEntity, size=5, es=None) -> List[Release index="fatcat_release").query("term", **{ es_field: value }).extra(size=size)) - print(s) resp = s.execute() if len(resp) > 0: - return response_to_entity_list(resp, entity_type=ReleaseEntity) + return response_to_entity_list(resp, entity_type=ReleaseEntity, api=api) body = { "query": { @@ -67,7 +71,7 @@ def match_release_fuzzy(release: ReleaseEntity, size=5, es=None) -> List[Release } resp = es.search(body=body, index="fatcat_release") if resp["hits"]["total"] > 0: - return response_to_entity_list(resp, entity_type=ReleaseEntity) + return response_to_entity_list(resp, entity_type=ReleaseEntity, api=api) # Get fuzzy. # https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#fuzziness @@ -85,7 +89,7 @@ def match_release_fuzzy(release: ReleaseEntity, size=5, es=None) -> List[Release } resp = es.search(body=body, index="fatcat_release") if resp["hits"]["total"] > 0: - return response_to_entity_list(resp, entity_type=ReleaseEntity) + return response_to_entity_list(resp, entity_type=ReleaseEntity, api=api) # TODO: perform more queries on other fields. return [] @@ -147,7 +151,7 @@ def retrieve_entity_list( return result -def response_to_entity_list(response, size=5, entity_type=ReleaseEntity): +def response_to_entity_list(response, size=5, entity_type=ReleaseEntity, api: DefaultApi = None): """ Convert an elasticsearch result to a list of entities. Accepts both a dictionary and an elasticsearch_dsl.response.Response. @@ -156,10 +160,10 @@ def response_to_entity_list(response, size=5, entity_type=ReleaseEntity): """ if isinstance(response, dict): ids = [hit["_source"]["ident"] for hit in response["hits"]["hits"]][:size] - return retrieve_entity_list(ids, entity_type=entity_type) + return retrieve_entity_list(ids, entity_type=entity_type, api=api) elif isinstance(response, elasticsearch_dsl.response.Response): ids = [hit.to_dict().get("ident") for hit in response] - return retrieve_entity_list(ids, entity_type=entity_type) + return retrieve_entity_list(ids, entity_type=entity_type, api=api) else: raise ValueError("cannot convert {}".format(response)) |