summaryrefslogtreecommitdiffstats
path: root/fatcat_scholar/hacks.py
blob: fc1656493f0b214fbfd642a4aa764924e7ab6a5c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import typing
import jinja2

from starlette.background import BackgroundTask
from starlette.templating import _TemplateResponse


class Jinja2Templates:
    """
    This is a patched version of starlette.templating.Jinja2Templates that
    supports extensions (list of strings) passed to jinja2.Environment.

    Templates are loaded from ``directory`` with ``jinja2.FileSystemLoader``
    and HTML autoescaping enabled. A ``url_for`` global is registered that
    resolves route names via the ``request`` object in the template context.
    """

    def __init__(
        self, directory: str, extensions: typing.Optional[typing.List[str]] = None
    ) -> None:
        # None (instead of a shared mutable ``[]`` default) means "no extensions".
        self.env = self.get_env(directory, extensions)

    def get_env(
        self, directory: str, extensions: typing.Optional[typing.List[str]] = None
    ) -> "jinja2.Environment":
        """
        Build the jinja2 Environment used to load and render templates.

        directory: filesystem path templates are loaded from.
        extensions: jinja2 extension names (import-path strings); None or an
            empty list means no extensions are loaded.
        """
        # Jinja2 3.0 renamed ``contextfunction`` to ``pass_context`` and 3.1
        # removed the old name entirely; prefer the new decorator and only
        # fall back to the old one on Jinja2 < 3.0.
        pass_context = getattr(jinja2, "pass_context", None) or jinja2.contextfunction

        @pass_context
        def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
            # The template context must carry the current request (enforced
            # by TemplateResponse below).
            request = context["request"]
            return request.url_for(name, **path_params)

        loader = jinja2.FileSystemLoader(directory)
        env = jinja2.Environment(
            loader=loader, extensions=list(extensions or []), autoescape=True
        )
        env.globals["url_for"] = url_for
        return env

    def get_template(self, name: str) -> "jinja2.Template":
        """Load template ``name`` from the configured environment."""
        return self.env.get_template(name)

    def TemplateResponse(
        self,
        name: str,
        context: dict,
        status_code: int = 200,
        headers: typing.Optional[dict] = None,
        media_type: typing.Optional[str] = None,
        background: typing.Optional[BackgroundTask] = None,
    ) -> _TemplateResponse:
        """
        Render template ``name`` with ``context`` into a starlette response.

        Raises ValueError if ``context`` has no "request" key, which the
        ``url_for`` template global depends on.
        """
        if "request" not in context:
            raise ValueError('context must include a "request" key')
        template = self.get_template(name)
        return _TemplateResponse(
            template,
            context,
            status_code=status_code,
            headers=headers,
            media_type=media_type,
            background=background,
        )


def parse_accept_lang(header: str, options: typing.List[str]) -> typing.Optional[str]:
    """
    Crude HTTP Accept-Language content negotiation.

    Returns the entry of ``options`` matching the two-letter primary language
    subtag (e.g. "en" from "en-GB") of the highest-quality entry in
    ``header``, or None if nothing matches.

    - Whitespace around entries and parameters is ignored.
    - Tags are compared case-insensitively; the option is returned as the
      caller spelled it.
    - Quality values (";q=...") are honoured. q=0 means "not acceptable" and
      is never selected. Ties go to the earliest entry, so a header listed in
      priority order behaves as before.
    - A malformed or out-of-range q value is ignored (default 1.0 is used).
    - Wildcards ("*") and non-two-letter primary tags never match.
    """
    if not header:
        return None

    # Case-insensitive lookup that hands back the caller's own spelling.
    allowed: typing.Dict[str, str] = {}
    for option in options:
        allowed.setdefault(option.lower(), option)

    best: typing.Optional[str] = None
    best_q = 0.0
    for entry in header.split(","):
        tag, *params = entry.split(";")
        lang = tag.strip().split("-")[0].lower()
        if len(lang) != 2 or lang not in allowed:
            continue

        q = 1.0
        for param in params:
            key, _, value = param.partition("=")
            if key.strip().lower() != "q":
                continue
            try:
                parsed = float(value)
            except ValueError:
                continue  # malformed q: keep the default weight
            # Also rejects nan/inf, which float() would otherwise accept.
            if 0.0 <= parsed <= 1.0:
                q = parsed

        # Strictly greater: q=0 can never win, and ties keep the earliest entry.
        if q > best_q:
            best, best_q = allowed[lang], q
    return best


def test_parse_accept_lang() -> None:
    """Pin parse_accept_lang on empty input, no match, subtags and priority."""
    assert parse_accept_lang("", []) is None
    assert parse_accept_lang("en,de", []) is None
    assert parse_accept_lang("en,de", ["en"]) == "en"
    assert parse_accept_lang("en-GB,de", ["en"]) == "en"
    assert parse_accept_lang("en,de", ["de"]) == "de"
    assert (
        parse_accept_lang("en-ca,en;q=0.8,en-us;q=0.6,de-de;q=0.4,de;q=0.2", ["de"])
        == "de"
    )