Fix OAuth introspection endpoint

This commit is contained in:
Thomas Sileo 2023-02-03 08:32:50 +01:00
parent 2bd6c98538
commit 625f399309

View file

@ -10,6 +10,8 @@ from fastapi import Form
from fastapi import HTTPException
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic
from fastapi.security import HTTPBasicCredentials
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import select
@ -26,6 +28,8 @@ from app.redirect import redirect
from app.utils import indieauth
from app.utils.datetime import now
basic_auth = HTTPBasic()
router = APIRouter()
@ -496,19 +500,49 @@ async def indieauth_revocation_endpoint(
@router.post("/token_introspection")
async def oauth_introspection_endpoint(
request: Request,
access_token_info: AccessTokenInfo = Depends(enforce_access_token),
credentials: HTTPBasicCredentials = Depends(basic_auth),
db_session: AsyncSession = Depends(get_db_session),
token: str = Form(),
) -> JSONResponse:
# Ensure the requested token is the same as bearer token
if token != access_token_info.access_token:
raise HTTPException(status_code=401, detail="access token required")
registered_client = (
await db_session.scalars(
select(models.OAuthClient).where(
models.OAuthClient.client_id == credentials.username,
models.OAuthClient.client_secret == credentials.password,
)
)
).one_or_none()
if not registered_client:
raise HTTPException(status_code=401, detail="unauthenticated")
access_token = (
await db_session.scalars(
select(models.IndieAuthAccessToken)
.where(models.IndieAuthAccessToken.access_token == token)
.join(
models.IndieAuthAuthorizationRequest,
models.IndieAuthAccessToken.indieauth_authorization_request_id
== models.IndieAuthAuthorizationRequest.id,
)
.where(
models.IndieAuthAuthorizationRequest.client_id == credentials.username
)
)
).one_or_none()
if not access_token:
return JSONResponse(content={"active": False})
return JSONResponse(
content={
"active": True,
"client_id": access_token_info.client_id,
"scope": " ".join(access_token_info.scopes),
"exp": access_token_info.exp,
"client_id": credentials.username,
"scope": access_token.scope,
"exp": int(
(
access_token.created_at.replace(tzinfo=timezone.utc)
+ timedelta(seconds=access_token.expires_in)
).timestamp()
),
},
status_code=200,
)