Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/baby_serverlist/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from django.db import models

from accounts.models import Account
from commons.cache import BABY_SERVER_HEARTBEAT_TTL_SECONDS, get_baby_server_heartbeat
from central_command.settings import BABY_SERVER_STATUS_TTL_SECONDS
from commons.cache import get_baby_server_heartbeat

SERVERLIST_TOKEN_SALT = "baby_serverlist.serverlist_token"
LIVE_HEARTBEAT_GRACE_SECONDS = 2


class BabyServer(models.Model):
Expand Down Expand Up @@ -52,6 +52,4 @@ def is_live(self) -> bool:
if heartbeat_time.tzinfo is None:
heartbeat_time = heartbeat_time.replace(tzinfo=UTC)

# live if last heartbeat within the cache TTL plus a small grace buffer
ttl_with_grace = BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS
return datetime.now(tz=UTC) - heartbeat_time <= timedelta(seconds=ttl_with_grace)
return datetime.now(tz=UTC) - heartbeat_time <= timedelta(seconds=BABY_SERVER_STATUS_TTL_SECONDS)
1 change: 1 addition & 0 deletions src/central_command/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

_csrf_origins = os.environ.get("CSRF_TRUSTED_ORIGINS", "")
CSRF_TRUSTED_ORIGINS = [origin.strip() for origin in _csrf_origins.split(",") if origin.strip()]
BABY_SERVER_STATUS_TTL_SECONDS = 20

INSTALLED_APPS = [
"django.contrib.admin",
Expand Down
5 changes: 2 additions & 3 deletions src/commons/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from django.core.cache import cache

BABY_SERVER_STATUS_TTL_SECONDS = 10
BABY_SERVER_HEARTBEAT_TTL_SECONDS = 10
from central_command.settings import BABY_SERVER_STATUS_TTL_SECONDS

SERVER_STATUS_KEY_PREFIX = "baby_server_status:"
SERVER_HEARTBEAT_KEY_PREFIX = "baby_server_heartbeat:"
Expand Down Expand Up @@ -42,7 +41,7 @@ def get_many_baby_server_statuses(server_ids: Iterable[str]) -> dict[str, dict[s

def set_baby_server_heartbeat(server_id: str, timestamp: str) -> None:
"""Persist the last-reported timestamp for a server."""
cache.set(_heartbeat_key(server_id), timestamp, timeout=BABY_SERVER_HEARTBEAT_TTL_SECONDS)
cache.set(_heartbeat_key(server_id), timestamp, timeout=BABY_SERVER_STATUS_TTL_SECONDS)


def get_baby_server_heartbeat(server_id: str) -> str | None:
Expand Down
16 changes: 5 additions & 11 deletions src/tests/baby_serverlist/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from rest_framework.test import APITestCase

from accounts.models import Account
from baby_serverlist.models import LIVE_HEARTBEAT_GRACE_SECONDS, BabyServer
from baby_serverlist.models import BabyServer
from central_command.settings import BABY_SERVER_STATUS_TTL_SECONDS
from commons.cache import (
BABY_SERVER_HEARTBEAT_TTL_SECONDS,
get_baby_server_heartbeat,
get_baby_server_status,
set_baby_server_heartbeat,
Expand Down Expand Up @@ -155,9 +155,7 @@ def test_list_owned_baby_servers_live_flag(self) -> None:
response = self.client.get(reverse("baby_serverlist:list-owned"))
self.assertTrue(response.json()[0]["live"])

stale_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS + 1
)
stale_time = datetime.now(tz=UTC) - timedelta(seconds=BABY_SERVER_STATUS_TTL_SECONDS + 1)
set_baby_server_heartbeat(str(baby_server.id), stale_time.isoformat())

response = self.client.get(reverse("baby_serverlist:list-owned"))
Expand Down Expand Up @@ -191,15 +189,11 @@ def test_list_baby_servers_ignores_non_whitelisted(self) -> None:
def test_baby_server_is_live_respects_heartbeat_ttl(self) -> None:
baby_server = BabyServer.objects.create(owner=self.user)

fresh_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS - 1
)
fresh_time = datetime.now(tz=UTC) - timedelta(seconds=BABY_SERVER_STATUS_TTL_SECONDS - 1)
set_baby_server_heartbeat(str(baby_server.id), fresh_time.isoformat())
self.assertTrue(baby_server.is_live())

stale_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS + 1
)
stale_time = datetime.now(tz=UTC) - timedelta(seconds=BABY_SERVER_STATUS_TTL_SECONDS + 1)
set_baby_server_heartbeat(str(baby_server.id), stale_time.isoformat())
stored = get_baby_server_heartbeat(str(baby_server.id))
self.assertEqual(stored, stale_time.isoformat())
Expand Down
9 changes: 6 additions & 3 deletions src/tests/commons/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from django.test import SimpleTestCase

from central_command.settings import BABY_SERVER_STATUS_TTL_SECONDS
from commons import cache as cache_module


class CommonsCacheTests(SimpleTestCase):
def test_set_baby_server_status_uses_ephemeral_timeout(self) -> None:
@staticmethod
def test_set_baby_server_status_uses_ephemeral_timeout() -> None:
payload = {"ServerName": "test"}
server_id = "server-123"

Expand All @@ -19,7 +21,8 @@ def test_set_baby_server_status_uses_ephemeral_timeout(self) -> None:
timeout=cache_module.BABY_SERVER_STATUS_TTL_SECONDS,
)

def test_set_baby_server_heartbeat_uses_ephemeral_timeout(self) -> None:
@staticmethod
def test_set_baby_server_heartbeat_uses_ephemeral_timeout() -> None:
timestamp = "2024-01-01T00:00:00+00:00"
server_id = "server-456"

Expand All @@ -29,5 +32,5 @@ def test_set_baby_server_heartbeat_uses_ephemeral_timeout(self) -> None:
fake_cache.set.assert_called_once_with(
f"{cache_module.SERVER_HEARTBEAT_KEY_PREFIX}{server_id}",
timestamp,
timeout=cache_module.BABY_SERVER_HEARTBEAT_TTL_SECONDS,
timeout=BABY_SERVER_STATUS_TTL_SECONDS,
)
Loading