Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions tests/test_fastapi_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def t1(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)

def test_single_decorator_with_headers(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True)
Expand All @@ -33,7 +33,7 @@ async def t1(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
assert (
response.headers.get("X-RateLimit-Limit") is not None if i < 5 else True
)
Expand All @@ -50,7 +50,7 @@ async def t1(request: Request, response: Response):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)

def test_single_decorator_not_response_with_headers(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr, headers_enabled=True)
Expand All @@ -63,7 +63,7 @@ async def t1(request: Request, response: Response):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
assert (
response.headers.get("X-RateLimit-Limit") is not None if i < 5 else True
)
Expand All @@ -84,7 +84,7 @@ async def t1(request: Request):
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 50 else 429
assert response.status_code == (200 if i < 50 else 429)
for i in range(50):
assert cli.get("/t1").status_code == 200

Expand All @@ -109,7 +109,7 @@ async def t1(request: Request, response: Response):
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 50 else 429
assert response.status_code == (200 if i < 50 else 429)
for i in range(50):
assert cli.get("/t1").status_code == 200

Expand All @@ -134,7 +134,7 @@ async def t1(request: Request, response: Response):
cli = TestClient(app)
for i in range(0, 100):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 50 else 429
assert response.status_code == (200 if i < 50 else 429)
for i in range(50):
assert cli.get("/t1").status_code == 200

Expand Down Expand Up @@ -253,11 +253,11 @@ async def t1(request: Request, response: Response):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)

for i in range(0, 20):
response = client.get("/t1", headers={"TOKEN": "secret"})
assert response.status_code == 200 if i < 10 else 429
assert response.status_code == (200 if i < 10 else 429)

def test_disabled_limiter(self, build_fastapi_app):
"""
Expand Down Expand Up @@ -308,10 +308,10 @@ async def t2(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)

response = client.get("/t2")
assert response.status_code == 200 if i < 3 else 429
assert response.status_code == (200 if i < 3 else 429)

def test_callable_cost(self, build_fastapi_app):
app, limiter = build_fastapi_app(key_func=get_ipaddr)
Expand All @@ -331,10 +331,10 @@ async def t2(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1", headers={"foo": "10"})
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)

response = client.get("/t2", headers={"foo": "5"})
assert response.status_code == 200 if i < 6 else 429
assert response.status_code == (200 if i < 6 else 429)

@pytest.mark.parametrize(
"key_style",
Expand Down
24 changes: 12 additions & 12 deletions tests/test_starlette_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def t1(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
if i < 5:
assert response.text == "test"

Expand All @@ -39,7 +39,7 @@ def t1(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
if i < 5:
assert response.text == "test"

Expand Down Expand Up @@ -83,7 +83,7 @@ def always_dynamic(request: Request):
# Test always false hitting the limit after one hit
for i in range(0, 2):
response = client.get("/false")
assert response.status_code == 200 if i < 1 else 429
assert response.status_code == (200 if i < 1 else 429)
if i < 1:
assert response.text == "test"
# Test dynamic not exempting with the correct header
Expand All @@ -94,7 +94,7 @@ def always_dynamic(request: Request):
# Test dynamic exempting with the incorrect header
for i in range(0, 2):
response = client.get("/dynamic")
assert response.status_code == 200 if i < 1 else 429
assert response.status_code == (200 if i < 1 else 429)
if i < 1:
assert response.text == "test"

Expand All @@ -117,7 +117,7 @@ def t2(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
# the shared limit has already been hit via t1
assert client.get("/t2").status_code == 429

Expand All @@ -135,7 +135,7 @@ async def t1(request: Request):
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
for i in range(5):
assert cli.get("/t1").status_code == 200

Expand All @@ -159,7 +159,7 @@ async def t1(request: Request):
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
assert response.headers.get("Retry-After") if i < 5 else True
for i in range(5):
assert cli.get("/t1").status_code == 200
Expand Down Expand Up @@ -304,7 +304,7 @@ async def t1(request: Request):
cli = TestClient(app)
for i in range(0, 10):
response = cli.get("/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"})
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
for i in range(5):
assert cli.get("/t1").status_code == 200

Expand Down Expand Up @@ -332,14 +332,14 @@ async def t2(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1")
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
if i < 5:
assert response.text == "test"
else:
assert "error" in response.json()

response = client.get("/t2")
assert response.status_code == 200 if i < 3 else 429
assert response.status_code == (200 if i < 3 else 429)
if i < 3:
assert response.text == "test"
else:
Expand All @@ -365,14 +365,14 @@ async def t2(request: Request):
client = TestClient(app)
for i in range(0, 10):
response = client.get("/t1", headers={"foo": "10"})
assert response.status_code == 200 if i < 5 else 429
assert response.status_code == (200 if i < 5 else 429)
if i < 5:
assert response.text == "test"
else:
assert "error" in response.json()

response = client.get("/t2", headers={"foo": "5"})
assert response.status_code == 200 if i < 6 else 429
assert response.status_code == (200 if i < 6 else 429)
if i < 6:
assert response.text == "test"
else:
Expand Down