Skip to content

Commit 90db4d2

Browse files
committed
Fix the stripe charge
1 parent c23e3b7 commit 90db4d2

File tree

7 files changed

+52
-20
lines changed

7 files changed

+52
-20
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""add billable column to usage_tracker
2+
3+
Revision ID: 683fc811a969
4+
Revises: 40e4b59f754d
5+
Create Date: 2025-09-05 10:48:09.623668
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '683fc811a969'
14+
down_revision = '40e4b59f754d'
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade() -> None:
20+
op.add_column('usage_tracker', sa.Column('billable', sa.Boolean(), nullable=False, server_default='FALSE'))
21+
22+
23+
def downgrade() -> None:
24+
op.drop_column('usage_tracker', 'billable')

app/api/routes/wallet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def get_wallet_balance(
3434
await WalletService.ensure_wallet(db, user.id)
3535
return WalletResponse(balance=Decimal("0"), blocked=False, currency="USD", total_spent=Decimal("0"), total_earned=Decimal("0"))
3636

37-
result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None)))
37+
result = await db.execute(select(func.sum(UsageTracker.cost)).where(UsageTracker.user_id == user.id, UsageTracker.updated_at.is_not(None), UsageTracker.billable))
3838
total_spent = result.scalar_one_or_none() or "0"
3939
result = await db.execute(select(func.sum(StripePayment.amount)).where(StripePayment.user_id == user.id, StripePayment.status == "completed"))
4040
total_earned = result.scalar_one_or_none() or "0"

app/api/schemas/stripe.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pydantic import BaseModel
1+
from pydantic import BaseModel, field_validator
22
from typing import List, Literal
33

44
# https://docs.stripe.com/api/checkout/sessions/create
@@ -21,7 +21,16 @@ class CreateCheckoutSessionRequest(BaseModel):
2121
line_items: List[StripeCheckoutSessionLineItem]
2222
# Only allow payment mode for now
2323
mode: Literal["payment"] = "payment"
24+
# Attach the session_id to the success_url
25+
# https://docs.stripe.com/payments/checkout/custom-success-page?payment-ui=stripe-hosted&utm_source=chatgpt.com#success-url
2426
success_url: str | None = None
2527
return_url: str | None = None
2628
cancel_url: str | None = None
27-
ui_mode: str = "hosted"
29+
ui_mode: str = "hosted"
30+
31+
@field_validator("success_url")
32+
@classmethod
33+
def append_session_id_to_success_url(cls, value: str):
34+
if value is None:
35+
return None
36+
return value.rstrip("/") + "?session_id={CHECKOUT_SESSION_ID}"

app/models/usage_tracker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from datetime import UTC
33
import uuid
44

5-
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL
5+
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, DECIMAL, Boolean
66
from sqlalchemy.orm import relationship
77
from sqlalchemy.dialects.postgresql import UUID
88
from .base import Base
@@ -25,5 +25,6 @@ class UsageTracker(Base):
2525
cost = Column(DECIMAL(12, 8), nullable=True)
2626
currency = Column(String(3), nullable=True)
2727
pricing_source = Column(String(255), nullable=True)
28+
billable = Column(Boolean, nullable=False, default=False)
2829

2930
provider_key = relationship("ProviderKey", back_populates="usage_tracker")

app/services/provider_service.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
InvalidForgeKeyException,
1919
)
2020
from app.models.user import User
21+
from app.models.provider_key import ProviderKey
2122
from app.core.database import get_db_session
2223
from app.services.wallet_service import WalletService
2324

@@ -193,9 +194,6 @@ async def _load_provider_keys(self) -> dict[str, dict[str, Any]]:
193194
f"Loading provider keys from database for user {self.user_id} (sync)"
194195
)
195196

196-
# Query ProviderKey directly by user_id
197-
from app.models.provider_key import ProviderKey
198-
199197
result = await self.db.execute(
200198
select(ProviderKey).filter(
201199
ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
@@ -253,9 +251,6 @@ async def _load_provider_keys_async(self) -> dict[str, dict[str, Any]]:
253251
f"Loading provider keys from database for user {self.user_id} (async)"
254252
)
255253

256-
# Query ProviderKey directly by user_id
257-
from app.models.provider_key import ProviderKey
258-
259254
result = await self.db.execute(
260255
select(ProviderKey).filter(
261256
ProviderKey.user_id == self.user_id, ProviderKey.deleted_at == None
@@ -597,14 +592,18 @@ async def process_request(
597592
# Process the request through the adapter
598593
usage_tracker_id = None
599594
if self.api_key_id is not None and provider_key_id is not None:
600-
await WalletService.wallet_precheck(self.user_id, self.db, provider_key_id)
595+
result = await self.db.execute(select(ProviderKey.billable).where(ProviderKey.id == provider_key_id))
596+
billable = result.scalar_one_or_none() or False
597+
if billable:
598+
await WalletService.wallet_precheck(self.user_id, self.db)
601599
usage_tracker_id = await UsageTrackerService.start_tracking_usage(
602600
db=self.db,
603601
user_id=self.user_id,
604602
provider_key_id=provider_key_id,
605603
forge_key_id=self.api_key_id,
606604
model=actual_model,
607605
endpoint=endpoint,
606+
billable=billable,
608607
)
609608
else:
610609
# For api like list models, we don't have usage tracking

app/services/providers/usage_tracker_service.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ async def start_tracking_usage(
2424
forge_key_id: int,
2525
model: str,
2626
endpoint: str,
27+
billable: bool = False,
2728
) -> int:
2829
try:
2930
usage_tracker = UsageTracker(
@@ -33,6 +34,7 @@ async def start_tracking_usage(
3334
model=model,
3435
endpoint=endpoint,
3536
created_at=datetime.now(UTC),
37+
billable=billable,
3638
)
3739
db.add(usage_tracker)
3840
await db.commit()
@@ -83,7 +85,7 @@ async def update_usage_tracker(
8385
usage_tracker.pricing_source = price_info['pricing_source']
8486

8587
# Deduct from wallet balance if the provider is not free
86-
if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.provider_key.billable:
88+
if price_info['total_cost'] and price_info['total_cost'] > 0 and usage_tracker.billable:
8789
try:
8890
result = await WalletService.adjust(
8991
db,

app/services/wallet_service.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from app.core.logger import get_logger
1010
from app.models.wallet import Wallet
11-
from app.models.provider_key import ProviderKey
1211

1312
logger = get_logger(name="wallet_service")
1413

@@ -62,6 +61,10 @@ async def adjust(
6261
currency: str = "USD"
6362
) -> Dict[str, any]:
6463
"""Adjust wallet balance with optimistic locking and retry"""
64+
65+
# enforce delta to be Decimal
66+
delta = Decimal(delta)
67+
6568
for attempt in range(MAX_RETRIES):
6669
try:
6770
# Read current wallet state including version
@@ -191,14 +194,8 @@ async def get(db: AsyncSession, account_id: int) -> Optional[Dict[str, any]]:
191194
# Helper: perform wallet precheck
192195
# -------------------------------------------------------------
193196
@staticmethod
194-
async def wallet_precheck(user_id: int, db: AsyncSession, provider_key_id: int) -> None:
197+
async def wallet_precheck(user_id: int, db: AsyncSession) -> None:
195198
"""Check wallet balance and ensure user can make requests"""
196-
provider_key = await db.execute(select(ProviderKey).filter(ProviderKey.id == provider_key_id, ProviderKey.billable))
197-
provider_key = provider_key.scalar_one_or_none()
198-
# If the provider key is not billable, we don't need to check the wallet
199-
if not provider_key:
200-
return
201-
202199
await WalletService.ensure_wallet(db, user_id)
203200
check_result = await WalletService.precheck(db, user_id)
204201

0 commit comments

Comments
 (0)