Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request argument to TemplateResponse #2191

Merged
merged 11 commits into from
Jul 13, 2023
2 changes: 1 addition & 1 deletion docs/templates.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ from starlette.staticfiles import StaticFiles
templates = Jinja2Templates(directory='templates')

async def homepage(request):
return templates.TemplateResponse('index.html', {'request': request})
return templates.TemplateResponse(request, 'index.html')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the only thing in the documentation that uses TemplateResponse?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


routes = [
Route('/', endpoint=homepage),
Expand Down
76 changes: 72 additions & 4 deletions starlette/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,87 @@ def url_for(context: dict, __name: str, **path_params: typing.Any) -> URL:
def get_template(self, name: str) -> "jinja2.Template":
return self.env.get_template(name)

@typing.overload
def TemplateResponse(
self,
request: Request,
name: str,
context: dict,
context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
...

@typing.overload
def TemplateResponse(
self,
name: str,
context: typing.Optional[dict] = None,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
) -> _TemplateResponse:
if "request" not in context:
raise ValueError('context must include a "request" key')
# Deprecated usage
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, we can't use @typing_extensions.deprecated without making typing_extensions mandatory for Starlette.

Also, the @deprecated PEP was still not accepted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, so we have a comment here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah yeah, just explaining for future reference 👀

...

request = typing.cast(Request, context["request"])
def TemplateResponse(
self, *args: typing.Any, **kwargs: typing.Any
) -> _TemplateResponse:
if args:
if isinstance(
args[0], str
): # the first argument is template name (old style)
warnings.warn(
"The `name` is not the first parameter anymore. "
"The first parameter should be the `Request` instance.\n"
'Replace `TemplateResponse(name, {"request": request})` by `TemplateResponse(request, name)`.', # noqa: E501
DeprecationWarning,
)

name = args[0]
context = args[1] if len(args) > 1 else kwargs.get("context", {})
status_code = (
args[2] if len(args) > 2 else kwargs.get("status_code", 200)
)
headers = args[2] if len(args) > 2 else kwargs.get("headers")
media_type = args[3] if len(args) > 3 else kwargs.get("media_type")
background = args[4] if len(args) > 4 else kwargs.get("background")

if "request" not in context:
raise ValueError('context must include a "request" key')
request = context["request"]
else: # the first argument is a request instance (new style)
request = args[0]
name = args[1] if len(args) > 1 else kwargs["name"]
context = args[2] if len(args) > 2 else kwargs.get("context", {})
status_code = (
args[3] if len(args) > 3 else kwargs.get("status_code", 200)
)
headers = args[4] if len(args) > 4 else kwargs.get("headers")
media_type = args[5] if len(args) > 5 else kwargs.get("media_type")
background = args[6] if len(args) > 6 else kwargs.get("background")
else: # all arguments are kwargs
if "request" not in kwargs:
warnings.warn(
"The `TemplateResponse` now requires the `request` argument.\n"
'Replace `TemplateResponse(name, {"context": context})` by `TemplateResponse(request, name)`.', # noqa: E501
DeprecationWarning,
)
if "request" not in kwargs.get("context", {}):
raise ValueError('context must include a "request" key')
Comment on lines +212 to +213
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This error is being raised with the warning above, can we make sure we only have the warning if this doesn't raise?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This preserves current behavior.


context = kwargs.get("context", {})
request = kwargs.get("request", context.get("request"))
name = typing.cast(str, kwargs["name"])
status_code = kwargs.get("status_code", 200)
headers = kwargs.get("headers")
media_type = kwargs.get("media_type")
background = kwargs.get("background")

context.setdefault("request", request)
for context_processor in self.context_processors:
context.update(context_processor(request))

Expand Down
157 changes: 146 additions & 11 deletions tests/test_templates.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
from pathlib import Path
from unittest import mock

import jinja2
import pytest

from starlette.applications import Starlette
from starlette.background import BackgroundTask
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
Expand All @@ -17,7 +19,7 @@ def test_templates(tmpdir, test_client_factory):
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

app = Starlette(
debug=True,
Expand All @@ -32,18 +34,12 @@ async def homepage(request):
assert set(response.context.keys()) == {"request"}


def test_template_response_requires_request(tmpdir):
templates = Jinja2Templates(str(tmpdir))
with pytest.raises(ValueError):
templates.TemplateResponse("", {})


def test_calls_context_processors(tmp_path, test_client_factory):
path = tmp_path / "index.html"
path.write_text("<html>Hello {{ username }}</html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

def hello_world_processor(request):
return {"username": "World"}
Expand Down Expand Up @@ -72,7 +68,7 @@ def test_template_with_middleware(tmpdir, test_client_factory):
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})
return templates.TemplateResponse(request, "index.html")

class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
Expand All @@ -99,15 +95,15 @@ def test_templates_with_directories(tmp_path: Path, test_client_factory):
template_a.write_text("<html><a href='{{ url_for('page_a') }}'></a> a</html>")

async def page_a(request):
return templates.TemplateResponse("template_a.html", {"request": request})
return templates.TemplateResponse(request, "template_a.html")

dir_b = tmp_path.resolve() / "b"
dir_b.mkdir()
template_b = dir_b / "template_b.html"
template_b.write_text("<html><a href='{{ url_for('page_b') }}'></a> b</html>")

async def page_b(request):
return templates.TemplateResponse("template_b.html", {"request": request})
return templates.TemplateResponse(request, "template_b.html")

app = Starlette(
debug=True,
Expand Down Expand Up @@ -158,3 +154,142 @@ def test_templates_with_environment(tmpdir):
def test_templates_with_environment_options_emit_warning(tmpdir):
with pytest.warns(DeprecationWarning):
Jinja2Templates(str(tmpdir), autoescape=True)


def test_templates_with_kwargs_only(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
request=request,
name="index.html",
context={"a": "b"},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()


def test_templates_with_kwargs_only_requires_request_in_context(tmpdir):
# MAINTAINERS: remove after 1.0

templates = Jinja2Templates(directory=str(tmpdir))
with pytest.warns(
DeprecationWarning,
match="requires the `request` argument",
):
with pytest.raises(ValueError):
templates.TemplateResponse(name="index.html", context={"a": "b"})


def test_templates_with_kwargs_only_warns_when_no_request_keyword(
tmpdir, test_client_factory
):
# MAINTAINERS: remove after 1.0

path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("Hello")

templates = Jinja2Templates(directory=str(tmpdir))

def page(request):
return templates.TemplateResponse(
name="index.html", context={"request": request}
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)

with pytest.warns(
DeprecationWarning,
match="requires the `request` argument",
):
client.get("/")


def test_templates_with_requires_request_in_context(tmpdir):
# MAINTAINERS: remove after 1.0
templates = Jinja2Templates(directory=str(tmpdir))
with pytest.warns(DeprecationWarning):
with pytest.raises(ValueError):
templates.TemplateResponse("index.html", context={})


def test_templates_warns_when_first_argument_isnot_request(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
"index.html",
{"a": "b", "request": request},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
with pytest.warns(DeprecationWarning):
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()


def test_templates_when_first_argument_is_request(tmpdir, test_client_factory):
# MAINTAINERS: remove after 1.0
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("value: {{ a }}")
templates = Jinja2Templates(directory=str(tmpdir))

spy = mock.MagicMock()

def page(request):
return templates.TemplateResponse(
request,
"index.html",
{"a": "b"},
status_code=201,
headers={"x-key": "value"},
media_type="text/plain",
background=BackgroundTask(func=spy),
)

app = Starlette(routes=[Route("/", page)])
client = test_client_factory(app)
response = client.get("/")

assert response.text == "value: b" # context was rendered
assert response.status_code == 201
assert response.headers["x-key"] == "value"
assert response.headers["content-type"] == "text/plain; charset=utf-8"
spy.assert_called()