Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/baby_serverlist/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from django.db import models

from accounts.models import Account
from commons.cache import get_baby_server_heartbeat
from commons.cache import BABY_SERVER_HEARTBEAT_TTL_SECONDS, get_baby_server_heartbeat

SERVERLIST_TOKEN_SALT = "baby_serverlist.serverlist_token"
LIVE_HEARTBEAT_GRACE_SECONDS = 2


class BabyServer(models.Model):
Expand Down Expand Up @@ -40,7 +41,7 @@ def generate_serverlist_token(self) -> str:
return signing.dumps(payload, salt=SERVERLIST_TOKEN_SALT)

def is_live(self) -> bool:
"""Return True when the server has reported within the last 12 seconds."""
"""Return True when the server has reported within the heartbeat TTL window."""
heartbeat_iso = get_baby_server_heartbeat(str(self.id))
if not heartbeat_iso:
return False
Expand All @@ -50,4 +51,7 @@ def is_live(self) -> bool:
return False
if heartbeat_time.tzinfo is None:
heartbeat_time = heartbeat_time.replace(tzinfo=UTC)
return datetime.now(tz=UTC) - heartbeat_time <= timedelta(seconds=12)

# live if last heartbeat within the cache TTL plus a small grace buffer
ttl_with_grace = BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS
return datetime.now(tz=UTC) - heartbeat_time <= timedelta(seconds=ttl_with_grace)
7 changes: 5 additions & 2 deletions src/commons/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from django.core.cache import cache

BABY_SERVER_STATUS_TTL_SECONDS = 10
BABY_SERVER_HEARTBEAT_TTL_SECONDS = 10

SERVER_STATUS_KEY_PREFIX = "baby_server_status:"
SERVER_HEARTBEAT_KEY_PREFIX = "baby_server_heartbeat:"

Expand All @@ -19,7 +22,7 @@ def _heartbeat_key(server_id: str) -> str:

def set_baby_server_status(server_id: str, status: dict[str, Any]) -> None:
"""Persist the latest status payload for a server."""
cache.set(_status_key(server_id), status)
cache.set(_status_key(server_id), status, timeout=BABY_SERVER_STATUS_TTL_SECONDS)


def get_baby_server_status(server_id: str) -> dict[str, Any] | None:
Expand All @@ -39,7 +42,7 @@ def get_many_baby_server_statuses(server_ids: Iterable[str]) -> dict[str, dict[s

def set_baby_server_heartbeat(server_id: str, timestamp: str) -> None:
"""Persist the last-reported timestamp for a server."""
cache.set(_heartbeat_key(server_id), timestamp)
cache.set(_heartbeat_key(server_id), timestamp, timeout=BABY_SERVER_HEARTBEAT_TTL_SECONDS)


def get_baby_server_heartbeat(server_id: str) -> str | None:
Expand Down
24 changes: 22 additions & 2 deletions src/tests/baby_serverlist/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from rest_framework.test import APITestCase

from accounts.models import Account
from baby_serverlist.models import BabyServer
from baby_serverlist.models import LIVE_HEARTBEAT_GRACE_SECONDS, BabyServer
from commons.cache import (
BABY_SERVER_HEARTBEAT_TTL_SECONDS,
get_baby_server_heartbeat,
get_baby_server_status,
set_baby_server_heartbeat,
Expand Down Expand Up @@ -154,7 +155,9 @@ def test_list_owned_baby_servers_live_flag(self) -> None:
response = self.client.get(reverse("baby_serverlist:list-owned"))
self.assertTrue(response.json()[0]["live"])

stale_time = datetime.now(tz=UTC) - timedelta(seconds=13)
stale_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS + 1
)
set_baby_server_heartbeat(str(baby_server.id), stale_time.isoformat())

response = self.client.get(reverse("baby_serverlist:list-owned"))
Expand Down Expand Up @@ -184,3 +187,20 @@ def test_list_baby_servers_ignores_non_whitelisted(self) -> None:

self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(response.json(), {"servers": []})

def test_baby_server_is_live_respects_heartbeat_ttl(self) -> None:
baby_server = BabyServer.objects.create(owner=self.user)

fresh_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS - 1
)
set_baby_server_heartbeat(str(baby_server.id), fresh_time.isoformat())
self.assertTrue(baby_server.is_live())

stale_time = datetime.now(tz=UTC) - timedelta(
seconds=BABY_SERVER_HEARTBEAT_TTL_SECONDS + LIVE_HEARTBEAT_GRACE_SECONDS + 1
)
set_baby_server_heartbeat(str(baby_server.id), stale_time.isoformat())
stored = get_baby_server_heartbeat(str(baby_server.id))
self.assertEqual(stored, stale_time.isoformat())
self.assertFalse(baby_server.is_live())
1 change: 1 addition & 0 deletions src/tests/commons/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

33 changes: 33 additions & 0 deletions src/tests/commons/test_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from unittest.mock import patch

from django.test import SimpleTestCase

from commons import cache as cache_module


class CommonsCacheTests(SimpleTestCase):
def test_set_baby_server_status_uses_ephemeral_timeout(self) -> None:
payload = {"ServerName": "test"}
server_id = "server-123"

with patch.object(cache_module, "cache") as fake_cache:
cache_module.set_baby_server_status(server_id, payload)

fake_cache.set.assert_called_once_with(
f"{cache_module.SERVER_STATUS_KEY_PREFIX}{server_id}",
payload,
timeout=cache_module.BABY_SERVER_STATUS_TTL_SECONDS,
)

def test_set_baby_server_heartbeat_uses_ephemeral_timeout(self) -> None:
timestamp = "2024-01-01T00:00:00+00:00"
server_id = "server-456"

with patch.object(cache_module, "cache") as fake_cache:
cache_module.set_baby_server_heartbeat(server_id, timestamp)

fake_cache.set.assert_called_once_with(
f"{cache_module.SERVER_HEARTBEAT_KEY_PREFIX}{server_id}",
timestamp,
timeout=cache_module.BABY_SERVER_HEARTBEAT_TTL_SECONDS,
)
Loading