Add OAuth refresh token support
This commit is contained in:
parent
3fb36d6119
commit
ed214cf0e7
4 changed files with 94 additions and 17 deletions
|
@ -0,0 +1,36 @@
|
||||||
|
"""Add OAuth refresh token support
|
||||||
|
|
||||||
|
Revision ID: a209f0333f5a
|
||||||
|
Revises: 4ab54becec04
|
||||||
|
Create Date: 2022-12-18 11:26:31.976348+00:00
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'a209f0333f5a'
|
||||||
|
down_revision = '4ab54becec04'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('refresh_token', sa.String(), nullable=True))
|
||||||
|
batch_op.add_column(sa.Column('was_refreshed', sa.Boolean(), server_default='0', nullable=False))
|
||||||
|
batch_op.create_index(batch_op.f('ix_indieauth_access_token_refresh_token'), ['refresh_token'], unique=True)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('indieauth_access_token', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('ix_indieauth_access_token_refresh_token'))
|
||||||
|
batch_op.drop_column('was_refreshed')
|
||||||
|
batch_op.drop_column('refresh_token')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
|
@ -270,17 +270,17 @@ async def indieauth_token_endpoint(
|
||||||
form_data = await request.form()
|
form_data = await request.form()
|
||||||
logger.info(f"{form_data=}")
|
logger.info(f"{form_data=}")
|
||||||
grant_type = form_data.get("grant_type", "authorization_code")
|
grant_type = form_data.get("grant_type", "authorization_code")
|
||||||
if grant_type != "authorization_code":
|
if grant_type not in ["authorization_code", "refresh_token"]:
|
||||||
raise ValueError(f"Invalid grant_type {grant_type}")
|
raise ValueError(f"Invalid grant_type {grant_type}")
|
||||||
|
|
||||||
code = form_data["code"]
|
|
||||||
|
|
||||||
# These must match the params from the first request
|
# These must match the params from the first request
|
||||||
client_id = form_data["client_id"]
|
client_id = form_data["client_id"]
|
||||||
redirect_uri = form_data["redirect_uri"]
|
|
||||||
# code_verifier is optional for backward compat
|
|
||||||
code_verifier = form_data.get("code_verifier")
|
code_verifier = form_data.get("code_verifier")
|
||||||
|
|
||||||
|
if grant_type == "authorization_code":
|
||||||
|
code = form_data["code"]
|
||||||
|
redirect_uri = form_data["redirect_uri"]
|
||||||
|
# code_verifier is optional for backward compat
|
||||||
is_code_valid, auth_code_request = await _check_auth_code(
|
is_code_valid, auth_code_request = await _check_auth_code(
|
||||||
db_session,
|
db_session,
|
||||||
code=code,
|
code=code,
|
||||||
|
@ -294,12 +294,38 @@ async def indieauth_token_endpoint(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif grant_type == "refresh_token":
|
||||||
|
refresh_token = form_data["refresh_token"]
|
||||||
|
access_token = (
|
||||||
|
await db_session.scalars(
|
||||||
|
select(models.IndieAuthAccessToken)
|
||||||
|
.where(
|
||||||
|
models.IndieAuthAccessToken.refresh_token == refresh_token,
|
||||||
|
models.IndieAuthAccessToken.was_refreshed.is_(False),
|
||||||
|
)
|
||||||
|
.options(
|
||||||
|
joinedload(
|
||||||
|
models.IndieAuthAccessToken.indieauth_authorization_request
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).one_or_none()
|
||||||
|
if not access_token:
|
||||||
|
raise ValueError("invalid refresh token")
|
||||||
|
|
||||||
|
if access_token.indieauth_authorization_request.client_id != client_id:
|
||||||
|
raise ValueError("invalid client ID")
|
||||||
|
|
||||||
|
auth_code_request = access_token.indieauth_authorization_request
|
||||||
|
access_token.was_refreshed = True
|
||||||
|
|
||||||
if not auth_code_request:
|
if not auth_code_request:
|
||||||
raise ValueError("Should never happen")
|
raise ValueError("Should never happen")
|
||||||
|
|
||||||
access_token = models.IndieAuthAccessToken(
|
access_token = models.IndieAuthAccessToken(
|
||||||
indieauth_authorization_request_id=auth_code_request.id,
|
indieauth_authorization_request_id=auth_code_request.id,
|
||||||
access_token=secrets.token_urlsafe(32),
|
access_token=secrets.token_urlsafe(32),
|
||||||
|
refresh_token=secrets.token_urlsafe(32),
|
||||||
expires_in=3600,
|
expires_in=3600,
|
||||||
scope=auth_code_request.scope,
|
scope=auth_code_request.scope,
|
||||||
)
|
)
|
||||||
|
@ -309,6 +335,7 @@ async def indieauth_token_endpoint(
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
"access_token": access_token.access_token,
|
"access_token": access_token.access_token,
|
||||||
|
"refresh_token": access_token.refresh_token,
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
"scope": auth_code_request.scope,
|
"scope": auth_code_request.scope,
|
||||||
"me": config.ID + "/",
|
"me": config.ID + "/",
|
||||||
|
|
14
app/main.py
14
app/main.py
|
@ -631,6 +631,19 @@ async def outbox(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/outbox")
|
||||||
|
async def post_inbox(
|
||||||
|
request: Request,
|
||||||
|
db_session: AsyncSession = Depends(get_db_session),
|
||||||
|
access_token_info: indieauth.AccessTokenInfo = Depends(
|
||||||
|
indieauth.enforce_access_token
|
||||||
|
),
|
||||||
|
) -> ActivityPubResponse:
|
||||||
|
payload = await request.json()
|
||||||
|
logger.info(f"{payload=}")
|
||||||
|
raise ValueError("TODO")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/featured")
|
@app.get("/featured")
|
||||||
async def featured(
|
async def featured(
|
||||||
db_session: AsyncSession = Depends(get_db_session),
|
db_session: AsyncSession = Depends(get_db_session),
|
||||||
|
@ -1055,7 +1068,6 @@ async def get_inbox(
|
||||||
page: bool | None = None,
|
page: bool | None = None,
|
||||||
next_cursor: str | None = None,
|
next_cursor: str | None = None,
|
||||||
) -> ActivityPubResponse:
|
) -> ActivityPubResponse:
|
||||||
logger.info(f"{page=}/{next_cursor=}")
|
|
||||||
where = [
|
where = [
|
||||||
models.InboxObject.ap_type.in_(
|
models.InboxObject.ap_type.in_(
|
||||||
["Create", "Follow", "Like", "Announce", "Undo", "Update"]
|
["Create", "Follow", "Like", "Announce", "Undo", "Update"]
|
||||||
|
|
|
@ -471,9 +471,11 @@ class IndieAuthAccessToken(Base):
|
||||||
)
|
)
|
||||||
|
|
||||||
access_token = Column(String, nullable=False, unique=True, index=True)
|
access_token = Column(String, nullable=False, unique=True, index=True)
|
||||||
|
refresh_token = Column(String, nullable=True, unique=True, index=True)
|
||||||
expires_in = Column(Integer, nullable=False)
|
expires_in = Column(Integer, nullable=False)
|
||||||
scope = Column(String, nullable=False)
|
scope = Column(String, nullable=False)
|
||||||
is_revoked = Column(Boolean, nullable=False, default=False)
|
is_revoked = Column(Boolean, nullable=False, default=False)
|
||||||
|
was_refreshed = Column(Boolean, nullable=False, default=False, server_default="0")
|
||||||
|
|
||||||
|
|
||||||
class OAuthClient(Base):
|
class OAuthClient(Base):
|
||||||
|
|
Loading…
Reference in a new issue