Fix proxy client

This commit is contained in:
Thomas Sileo 2022-12-02 19:28:59 +01:00
parent 34c7cdb5fb
commit 73dceee0f5

View file

@ -1180,33 +1180,31 @@ async def nodeinfo(
async def _proxy_get( async def _proxy_get(
request: starlette.requests.Request, url: str, stream: bool proxy_client: httpx.AsyncClient,
request: starlette.requests.Request,
url: str,
stream: bool,
) -> httpx.Response: ) -> httpx.Response:
async with httpx.AsyncClient( # Request the URL (and filter request headers)
follow_redirects=True, proxy_req = proxy_client.build_request(
timeout=httpx.Timeout(timeout=10.0), request.method,
transport=httpx.AsyncHTTPTransport(retries=1), url,
) as proxy_client: headers=[
# Request the URL (and filter request headers) (k, v)
proxy_req = proxy_client.build_request( for (k, v) in request.headers.raw
request.method, if k.lower()
url, not in [
headers=[ b"host",
(k, v) b"cookie",
for (k, v) in request.headers.raw b"x-forwarded-for",
if k.lower() b"x-forwarded-proto",
not in [ b"x-real-ip",
b"host", b"user-agent",
b"cookie",
b"x-forwarded-for",
b"x-forwarded-proto",
b"x-real-ip",
b"user-agent",
]
] ]
+ [(b"user-agent", USER_AGENT.encode())], ]
) + [(b"user-agent", USER_AGENT.encode())],
return await proxy_client.send(proxy_req, stream=stream) )
return await proxy_client.send(proxy_req, stream=stream)
def _filter_proxy_resp_headers( def _filter_proxy_resp_headers(
@ -1232,18 +1230,29 @@ async def serve_proxy_media(
exp: int, exp: int,
sig: str, sig: str,
encoded_url: str, encoded_url: str,
background_tasks: fastapi.BackgroundTasks,
) -> StreamingResponse | PlainTextResponse: ) -> StreamingResponse | PlainTextResponse:
# Decode the base64-encoded URL # Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode() url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url) check_url(url)
media.verify_proxied_media_sig(exp, url, sig) media.verify_proxied_media_sig(exp, url, sig)
proxy_resp = await _proxy_get(request, url, stream=True) proxy_client = httpx.AsyncClient(
follow_redirects=True,
timeout=httpx.Timeout(timeout=10.0),
transport=httpx.AsyncHTTPTransport(retries=1),
)
async def _close_proxy_client():
await proxy_client.aclose()
background_tasks.add_task(_close_proxy_client)
proxy_resp = await _proxy_get(proxy_client, request, url, stream=True)
if proxy_resp.status_code >= 300: if proxy_resp.status_code >= 300:
logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}") logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}")
await proxy_resp.aclose()
return PlainTextResponse( return PlainTextResponse(
"proxy error",
status_code=proxy_resp.status_code, status_code=proxy_resp.status_code,
) )
@ -1276,6 +1285,7 @@ async def serve_proxy_media_resized(
sig: str, sig: str,
encoded_url: str, encoded_url: str,
size: int, size: int,
background_tasks: fastapi.BackgroundTasks,
) -> PlainTextResponse: ) -> PlainTextResponse:
if size not in {50, 740}: if size not in {50, 740}:
raise ValueError("Unsupported size") raise ValueError("Unsupported size")
@ -1293,9 +1303,20 @@ async def serve_proxy_media_resized(
headers=resp_headers, headers=resp_headers,
) )
proxy_resp = await _proxy_get(request, url, stream=False) proxy_client = httpx.AsyncClient(
follow_redirects=True,
timeout=httpx.Timeout(timeout=10.0),
transport=httpx.AsyncHTTPTransport(retries=1),
)
async def _close_proxy_client():
await proxy_client.aclose()
background_tasks.add_task(_close_proxy_client)
proxy_resp = await _proxy_get(proxy_client, request, url, stream=False)
if proxy_resp.status_code >= 300: if proxy_resp.status_code >= 300:
logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}") logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}")
await proxy_resp.aclose()
return PlainTextResponse( return PlainTextResponse(
"proxy error", "proxy error",
status_code=proxy_resp.status_code, status_code=proxy_resp.status_code,