diff --git a/tests/test_capability_tokens.py b/tests/test_capability_tokens.py new file mode 100644 index 0000000..efb65bf --- /dev/null +++ b/tests/test_capability_tokens.py @@ -0,0 +1,433 @@ +"""Tests for capability tokens.""" + +from datetime import datetime, timedelta + +import pytest + +from lexecon.capability.tokens import CapabilityToken, CapabilityTokenStore + + +class TestCapabilityToken: + """Tests for CapabilityToken class.""" + + def test_create_token(self): + """Test creating a capability token.""" + token = CapabilityToken.create( + action="search", tool="web_search", policy_version_hash="abc123" + ) + + assert token.token_id.startswith("tok_") + assert len(token.token_id) == 20 # "tok_" + 16 hex chars + assert token.scope["action"] == "search" + assert token.scope["tool"] == "web_search" + assert token.policy_version_hash == "abc123" + assert token.granted_at is not None + assert token.expiry > token.granted_at + + def test_create_token_with_custom_ttl(self): + """Test creating token with custom TTL.""" + token = CapabilityToken.create( + action="write", tool="database", policy_version_hash="xyz789", ttl_minutes=10 + ) + + # Token should expire in 10 minutes + expected_expiry = datetime.utcnow() + timedelta(minutes=10) + time_diff = abs((token.expiry - expected_expiry).total_seconds()) + assert time_diff < 2 # Allow 2 seconds tolerance + + def test_token_is_valid_when_not_expired(self): + """Test that non-expired token is valid.""" + token = CapabilityToken.create( + action="read", tool="file_system", policy_version_hash="hash1", ttl_minutes=5 + ) + + assert token.is_valid() is True + + def test_token_is_invalid_when_expired(self): + """Test that expired token is invalid.""" + # Create token that expires immediately + now = datetime.utcnow() + token = CapabilityToken( + token_id="tok_expired123", + scope={"action": "read", "tool": "file_system"}, + expiry=now - timedelta(minutes=1), # Already expired + policy_version_hash="hash1", + granted_at=now - timedelta(minutes=10), + ) + + assert token.is_valid() is False + + def test_token_is_authorized_for_matching_action_and_tool(self): + """Test authorization check with matching action and tool.""" + token = CapabilityToken.create( + action="search", tool="web_search", policy_version_hash="hash1" + ) + + assert token.is_authorized_for("search", "web_search") is True + + def test_token_not_authorized_for_different_action(self): + """Test authorization check fails with different action.""" + token = CapabilityToken.create( + action="search", tool="web_search", policy_version_hash="hash1" + ) + + assert token.is_authorized_for("write", "web_search") is False + + def test_token_not_authorized_for_different_tool(self): + """Test authorization check fails with different tool.""" + token = CapabilityToken.create( + action="search", tool="web_search", policy_version_hash="hash1" + ) + + assert token.is_authorized_for("search", "database") is False + + def test_expired_token_not_authorized(self): + """Test that expired token is not authorized even with correct scope.""" + now = datetime.utcnow() + token = CapabilityToken( + token_id="tok_expired456", + scope={"action": "read", "tool": "file_system"}, + expiry=now - timedelta(minutes=1), # Expired + policy_version_hash="hash1", + granted_at=now - timedelta(minutes=10), + ) + + assert token.is_authorized_for("read", "file_system") is False + + def test_token_serialization(self): + """Test token serialization to dict.""" + token = CapabilityToken.create( + action="delete", tool="admin_panel", policy_version_hash="hash123" + ) + token.signature = "test_signature" + + data = token.to_dict() + + assert data["token_id"] == token.token_id + assert data["scope"]["action"] == "delete" + assert data["scope"]["tool"] == "admin_panel" + assert data["policy_version_hash"] == "hash123" + assert data["signature"] == "test_signature" + assert "expiry" in data + assert "granted_at" in data + + def test_token_deserialization(self): + """Test token deserialization from dict.""" + now = datetime.utcnow() + expiry = now + timedelta(minutes=5) + + data = { + "token_id": "tok_test123456789", + "scope": {"action": "update", "tool": "config"}, + "expiry": expiry.isoformat(), + "policy_version_hash": "hash999", + "granted_at": now.isoformat(), + "signature": "sig_abc", + } + + token = CapabilityToken.from_dict(data) + + assert token.token_id == "tok_test123456789" + assert token.scope["action"] == "update" + assert token.scope["tool"] == "config" + assert token.policy_version_hash == "hash999" + assert token.signature == "sig_abc" + assert isinstance(token.expiry, datetime) + assert isinstance(token.granted_at, datetime) + + def test_token_serialization_roundtrip(self): + """Test that serialization and deserialization preserve token data.""" + original = CapabilityToken.create( + action="execute", tool="script_runner", policy_version_hash="hashABC" + ) + original.signature = "test_sig_123" + + # Serialize and deserialize + data = original.to_dict() + restored = CapabilityToken.from_dict(data) + + assert restored.token_id == original.token_id + assert restored.scope == original.scope + assert restored.policy_version_hash == original.policy_version_hash + assert restored.signature == original.signature + # Time comparison with tolerance + assert abs((restored.expiry - original.expiry).total_seconds()) < 1 + assert abs((restored.granted_at - original.granted_at).total_seconds()) < 1 + + def test_token_deserialization_without_signature(self): + """Test deserialization when signature is not present.""" + now = datetime.utcnow() + data = { + "token_id": "tok_nosig", + "scope": {"action": "read", "tool": "api"}, + "expiry": (now + timedelta(minutes=5)).isoformat(), + "policy_version_hash": "hash1", + "granted_at": now.isoformat(), + } + + token = CapabilityToken.from_dict(data) + assert token.signature is None + + def test_different_tokens_have_unique_ids(self): + """Test that multiple tokens get unique IDs.""" + tokens = [ + CapabilityToken.create("action", "tool", "hash") for _ in range(100) + ] + + token_ids = [t.token_id for t in tokens] + assert len(token_ids) == len(set(token_ids)) # All unique + + +class TestCapabilityTokenStore: + """Tests for CapabilityTokenStore class.""" + + def test_store_initialization(self): + """Test token store initialization.""" + store = CapabilityTokenStore() + assert len(store.tokens) == 0 + + def test_store_token(self): + """Test storing a token.""" + store = CapabilityTokenStore() + token = CapabilityToken.create("read", "api", "hash1") + + store.store(token) + + assert len(store.tokens) == 1 + assert token.token_id in store.tokens + + def test_get_existing_token(self): + """Test retrieving an existing token.""" + store = CapabilityTokenStore() + token = CapabilityToken.create("write", "database", "hash2") + store.store(token) + + retrieved = store.get(token.token_id) + + assert retrieved is not None + assert retrieved.token_id == token.token_id + assert retrieved.scope == token.scope + + def test_get_nonexistent_token(self): + """Test retrieving a token that doesn't exist.""" + store = CapabilityTokenStore() + + result = store.get("tok_nonexistent") + + assert result is None + + def test_verify_valid_token(self): + """Test verifying a valid token with correct scope.""" + store = CapabilityTokenStore() + token = CapabilityToken.create("search", "web_search", "hash3") + store.store(token) + + is_valid = store.verify(token.token_id, "search", "web_search") + + assert is_valid is True + + def test_verify_token_wrong_action(self): + """Test verification fails with wrong action.""" + store = CapabilityTokenStore() + token = CapabilityToken.create("read", "file", "hash4") + store.store(token) + + is_valid = store.verify(token.token_id, "write", "file") + + assert is_valid is False + + def test_verify_token_wrong_tool(self): + """Test verification fails with wrong tool.""" + store = CapabilityTokenStore() + token = CapabilityToken.create("execute", "script", "hash5") + store.store(token) + + is_valid = store.verify(token.token_id, "execute", "command") + + assert is_valid is False + + def test_verify_nonexistent_token(self): + """Test verification fails for non-existent token.""" + store = CapabilityTokenStore() + + is_valid = store.verify("tok_fake123", "any", "any") + + assert is_valid is False + + def test_verify_expired_token(self): + """Test verification fails for expired token.""" + store = CapabilityTokenStore() + now = datetime.utcnow() + expired_token = CapabilityToken( + token_id="tok_expired999", + scope={"action": "read", "tool": "api"}, + expiry=now - timedelta(minutes=1), + policy_version_hash="hash6", + granted_at=now - timedelta(minutes=10), + ) + store.store(expired_token) + + is_valid = store.verify(expired_token.token_id, "read", "api") + + assert is_valid is False + + def test_cleanup_expired_tokens(self): + """Test cleanup of expired tokens.""" + store = CapabilityTokenStore() + now = datetime.utcnow() + + # Create valid token + valid_token = CapabilityToken.create("read", "api", "hash7") + store.store(valid_token) + + # Create expired tokens + for i in range(3): + expired = CapabilityToken( + token_id=f"tok_exp{i}", + scope={"action": "write", "tool": "db"}, + expiry=now - timedelta(minutes=i + 1), + policy_version_hash="hash8", + granted_at=now - timedelta(minutes=10), + ) + store.store(expired) + + # Should have 4 tokens total + assert len(store.tokens) == 4 + + # Cleanup expired + removed_count = store.cleanup_expired() + + # Should remove 3 expired tokens + assert removed_count == 3 + assert len(store.tokens) == 1 + assert valid_token.token_id in store.tokens + + def test_cleanup_with_no_expired_tokens(self): + """Test cleanup when there are no expired tokens.""" + store = CapabilityTokenStore() + + # Create only valid tokens + for i in range(5): + token = CapabilityToken.create(f"action{i}", "tool", f"hash{i}") + store.store(token) + + removed_count = store.cleanup_expired() + + assert removed_count == 0 + assert len(store.tokens) == 5 + + def test_cleanup_empty_store(self): + """Test cleanup on empty store.""" + store = CapabilityTokenStore() + + removed_count = store.cleanup_expired() + + assert removed_count == 0 + + def test_store_multiple_tokens(self): + """Test storing multiple tokens.""" + store = CapabilityTokenStore() + tokens = [] + + for i in range(10): + token = CapabilityToken.create(f"action{i}", f"tool{i}", f"hash{i}") + tokens.append(token) + store.store(token) + + assert len(store.tokens) == 10 + + # Verify all can be retrieved + for token in tokens: + retrieved = store.get(token.token_id) + assert retrieved is not None + assert retrieved.token_id == token.token_id + + def test_store_overwrites_existing_token(self): + """Test that storing same token ID overwrites.""" + store = CapabilityTokenStore() + token1 = CapabilityToken.create("read", "api", "hash1") + store.store(token1) + + # Create new token with same ID + now = datetime.utcnow() + token2 = CapabilityToken( + token_id=token1.token_id, + scope={"action": "write", "tool": "db"}, + expiry=now + timedelta(minutes=10), + policy_version_hash="hash2", + granted_at=now, + ) + store.store(token2) + + # Should only have 1 token + assert len(store.tokens) == 1 + + # Should have the new token's scope + retrieved = store.get(token1.token_id) + assert retrieved.scope["action"] == "write" + assert retrieved.policy_version_hash == "hash2" + + +class TestTokenEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_token_with_zero_ttl(self): + """Test token creation with zero TTL.""" + token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=0) + + # Should be expired immediately + assert token.is_valid() is False + + def test_token_with_negative_ttl(self): + """Test token creation with negative TTL.""" + token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=-5) + + # Should be expired + assert token.is_valid() is False + + def test_token_scope_with_additional_fields(self): + """Test token scope can contain additional fields.""" + now = datetime.utcnow() + token = CapabilityToken( + token_id="tok_extra", + scope={ + "action": "read", + "tool": "api", + "resource": "/users/123", + "method": "GET", + }, + expiry=now + timedelta(minutes=5), + policy_version_hash="hash1", + granted_at=now, + ) + + # Basic authorization should still work + assert token.is_authorized_for("read", "api") is True + # Extra fields preserved + assert token.scope["resource"] == "/users/123" + assert token.scope["method"] == "GET" + + def test_token_with_empty_scope(self): + """Test token with minimal scope.""" + now = datetime.utcnow() + token = CapabilityToken( + token_id="tok_empty", + scope={}, + expiry=now + timedelta(minutes=5), + policy_version_hash="hash1", + granted_at=now, + ) + + # Should not authorize anything + assert token.is_authorized_for("action", "tool") is False + + def test_very_long_ttl(self): + """Test token with very long TTL.""" + token = CapabilityToken.create("action", "tool", "hash", ttl_minutes=525600) # 1 year + + assert token.is_valid() is True + + # Expiry should be about 1 year from now + expected = datetime.utcnow() + timedelta(days=365) + time_diff = abs((token.expiry - expected).total_seconds()) + assert time_diff < 60 # Within 1 minute tolerance diff --git a/tests/test_decision_service.py b/tests/test_decision_service.py new file mode 100644 index 0000000..11e0049 --- /dev/null +++ b/tests/test_decision_service.py @@ -0,0 +1,685 @@ +"""Tests for decision service.""" + +import uuid +from datetime import datetime + +import pytest + +from lexecon.decision.service import DecisionRequest, DecisionResponse, DecisionService +from lexecon.identity.signing import NodeIdentity +from lexecon.ledger.chain import LedgerChain +from lexecon.policy.engine import PolicyEngine, PolicyMode +from lexecon.policy.relations import PolicyRelation +from lexecon.policy.terms import PolicyTerm + + +class TestDecisionRequest: + """Tests for DecisionRequest class.""" + + def test_create_decision_request(self): + """Test creating a decision request.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web_search", + user_intent="Research AI governance", + ) + + assert request.actor == "model" + assert request.proposed_action == "search" + assert request.tool == "web_search" + assert request.user_intent == "Research AI governance" + assert request.risk_level == 1 # Default + assert request.policy_mode == "strict" # Default + assert isinstance(request.request_id, str) + assert len(request.request_id) > 0 + + def test_request_with_custom_id(self): + """Test creating request with custom ID.""" + custom_id = "req_123456" + request = DecisionRequest( + request_id=custom_id, + actor="user", + proposed_action="read", + tool="file_system", + user_intent="Read file", + ) + + assert request.request_id == custom_id + + def test_request_generates_uuid_if_no_id(self): + """Test that request generates UUID if no ID provided.""" + request1 = DecisionRequest( + actor="model", proposed_action="action1", tool="tool1", user_intent="intent1" + ) + request2 = DecisionRequest( + actor="model", proposed_action="action2", tool="tool2", user_intent="intent2" + ) + + # Should be different UUIDs + assert request1.request_id != request2.request_id + + # Should be valid UUIDs + uuid.UUID(request1.request_id) + uuid.UUID(request2.request_id) + + def test_request_with_data_classes(self): + """Test request with data classes.""" + request = DecisionRequest( + actor="model", + proposed_action="process", + tool="analytics", + user_intent="Analyze data", + data_classes=["pii", "financial"], + ) + + assert request.data_classes == ["pii", "financial"] + + def test_request_with_high_risk_level(self): + """Test request with high risk level.""" + request = DecisionRequest( + actor="model", + proposed_action="delete", + tool="database", + user_intent="Clean up", + risk_level=5, + ) + + assert request.risk_level == 5 + + def test_request_with_context(self): + """Test request with additional context.""" + context = {"session_id": "abc123", "ip": "192.168.1.1"} + request = DecisionRequest( + actor="user", + proposed_action="login", + tool="auth", + user_intent="Login", + context=context, + ) + + assert request.context == context + + def test_request_serialization(self): + """Test request serialization to dict.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent="Search", + data_classes=["public"], + risk_level=2, + context={"query": "test"}, + ) + + data = request.to_dict() + + assert data["actor"] == "model" + assert data["proposed_action"] == "search" + assert data["tool"] == "web" + assert data["user_intent"] == "Search" + assert data["data_classes"] == ["public"] + assert data["risk_level"] == 2 + assert data["context"]["query"] == "test" + assert "timestamp" in data + assert "request_id" in data + + def test_request_has_timestamp(self): + """Test that request includes timestamp.""" + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="intent" + ) + + assert hasattr(request, "timestamp") + # Should be ISO format + datetime.fromisoformat(request.timestamp) + + +class TestDecisionResponse: + """Tests for DecisionResponse class.""" + + def test_create_decision_response(self): + """Test creating a decision response.""" + response = DecisionResponse( + request_id="req_123", + decision="permit", + reasoning="Action is allowed by policy", + policy_version_hash="abc123", + ) + + assert response.request_id == "req_123" + assert response.decision == "permit" + assert response.reasoning == "Action is allowed by policy" + assert response.policy_version_hash == "abc123" + + def test_response_with_capability_token(self): + """Test response including capability token.""" + token = { + "token_id": "tok_abc123", + "scope": {"action": "read", "tool": "api"}, + "expiry": "2025-01-02T12:00:00", + } + + response = DecisionResponse( + request_id="req_456", + decision="permit", + reasoning="Permitted", + policy_version_hash="hash1", + capability_token=token, + ) + + assert response.capability_token == token + + def test_response_with_ledger_entry(self): + """Test response including ledger entry hash.""" + response = DecisionResponse( + request_id="req_789", + decision="deny", + reasoning="Forbidden by policy", + policy_version_hash="hash2", + ledger_entry_hash="ledger_hash_abc", + ) + + assert response.ledger_entry_hash == "ledger_hash_abc" + + def test_response_with_signature(self): + """Test response including signature.""" + response = DecisionResponse( + request_id="req_sig", + decision="permit", + reasoning="OK", + policy_version_hash="hash3", + signature="sig_base64_encoded", + ) + + assert response.signature == "sig_base64_encoded" + + def test_decision_hash_generation(self): + """Test decision hash generation.""" + response = DecisionResponse( + request_id="req_hash", + decision="permit", + reasoning="Test", + policy_version_hash="hash4", + ) + + decision_hash = response.decision_hash + + assert isinstance(decision_hash, str) + assert len(decision_hash) == 64 # SHA256 hex + + def test_decision_hash_is_deterministic(self): + """Test that decision hash is deterministic.""" + timestamp = "2025-01-01T00:00:00" + response1 = DecisionResponse( + request_id="req_det", + decision="permit", + reasoning="Test", + policy_version_hash="hash5", + timestamp=timestamp, + ) + response2 = DecisionResponse( + request_id="req_det", + decision="permit", + reasoning="Different reasoning", # Hash doesn't include reasoning + policy_version_hash="hash5", + timestamp=timestamp, + ) + + # Same request_id, decision, policy hash, timestamp -> same hash + assert response1.decision_hash == response2.decision_hash + + def test_response_serialization(self): + """Test response serialization to dict.""" + response = DecisionResponse( + request_id="req_ser", + decision="deny", + reasoning="Not authorized", + policy_version_hash="hash6", + ledger_entry_hash="ledger123", + signature="sig123", + ) + + data = response.to_dict() + + assert data["request_id"] == "req_ser" + assert data["decision"] == "deny" + assert data["reasoning"] == "Not authorized" + assert data["reason"] == "Not authorized" # Backwards compatibility + assert data["allowed"] is False # deny -> allowed=False + assert data["policy_version_hash"] == "hash6" + assert data["ledger_entry_hash"] == "ledger123" + assert data["signature"] == "sig123" + assert "timestamp" in data + + def test_response_allowed_field_for_permit(self): + """Test that 'allowed' field is True for permit.""" + response = DecisionResponse( + request_id="req_permit", + decision="permit", + reasoning="OK", + policy_version_hash="hash", + ) + + data = response.to_dict() + assert data["allowed"] is True + + def test_response_allowed_field_for_deny(self): + """Test that 'allowed' field is False for deny.""" + response = DecisionResponse( + request_id="req_deny", decision="deny", reasoning="No", policy_version_hash="hash" + ) + + data = response.to_dict() + assert data["allowed"] is False + + def test_response_signature_defaults_to_empty(self): + """Test that signature defaults to empty string in serialization.""" + response = DecisionResponse( + request_id="req_nosig", decision="permit", reasoning="OK", policy_version_hash="hash" + ) + + data = response.to_dict() + assert data["signature"] == "" + + +class TestDecisionService: + """Tests for DecisionService class.""" + + @pytest.fixture + def policy_engine(self): + """Create a policy engine with basic rules.""" + engine = PolicyEngine(mode=PolicyMode.STRICT) + engine.add_term(PolicyTerm.create_actor("model", "AI Model")) + engine.add_term(PolicyTerm.create_action("search", "Search")) + engine.add_relation(PolicyRelation.permits("actor:model", "action:search")) + return engine + + @pytest.fixture + def service(self, policy_engine): + """Create basic decision service.""" + return DecisionService(policy_engine) + + @pytest.fixture + def full_service(self, policy_engine): + """Create decision service with ledger and identity.""" + ledger = LedgerChain() + identity = NodeIdentity("test-node") + return DecisionService(policy_engine, ledger, identity) + + def test_service_initialization(self, policy_engine): + """Test creating decision service.""" + service = DecisionService(policy_engine) + + assert service.policy_engine is not None + assert service.ledger is None + assert service.identity is None + + def test_service_with_ledger_and_identity(self, full_service): + """Test service with all components.""" + assert full_service.policy_engine is not None + assert full_service.ledger is not None + assert full_service.identity is not None + + def test_evaluate_request_permit(self, service): + """Test evaluating a request that should be permitted.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web_search", + user_intent="Search for information", + ) + + response = service.evaluate_request(request) + + assert response.decision == "permit" + assert response.request_id == request.request_id + assert response.capability_token is not None + assert "token_id" in response.capability_token + + def test_evaluate_request_deny(self, service): + """Test evaluating a request that should be denied.""" + request = DecisionRequest( + actor="model", + proposed_action="delete", # Not permitted + tool="database", + user_intent="Delete records", + ) + + response = service.evaluate_request(request) + + assert response.decision == "deny" + assert response.capability_token is None + + def test_capability_token_generation(self, service): + """Test that capability token is generated for permitted actions.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web_search", user_intent="Search" + ) + + response = service.evaluate_request(request) + + token = response.capability_token + assert token is not None + assert token["scope"]["action"] == "search" + assert token["scope"]["tool"] == "web_search" + assert "expiry" in token + assert "granted_at" in token + + def test_ledger_entry_creation(self, full_service): + """Test that ledger entry is created.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = full_service.evaluate_request(request) + + assert response.ledger_entry_hash is not None + assert len(response.ledger_entry_hash) == 64 # SHA256 hex + + # Verify entry exists in ledger by hash + entry = full_service.ledger.get_entry(response.ledger_entry_hash) + assert entry is not None + assert entry.event_type == "decision" + + def test_signature_generation(self, full_service): + """Test that decision is signed.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = full_service.evaluate_request(request) + + assert response.signature is not None + assert len(response.signature) > 0 + + # Verify it's valid base64 + import base64 + + decoded = base64.b64decode(response.signature) + assert len(decoded) > 0 + + def test_evaluate_convenience_method_simple(self, policy_engine): + """Test simple evaluate method.""" + service = DecisionService(policy_engine) + + # Simple policy evaluation + result = service.evaluate(actor="model", action="search") + + # Should return PolicyDecision object + assert hasattr(result, "allowed") + assert hasattr(result, "reason") + + def test_evaluate_convenience_method_full(self, full_service): + """Test evaluate method with full parameters.""" + # Full decision request + result = full_service.evaluate( + actor="model", + proposed_action="search", + tool="web_search", + user_intent="Testing", + risk_level=1, + ) + + # Should return DecisionResponse + assert hasattr(result, "decision") + assert hasattr(result, "capability_token") + assert hasattr(result, "ledger_entry_hash") + + def test_multiple_decisions_logged(self, full_service): + """Test that multiple decisions are logged to ledger.""" + initial_count = len(full_service.ledger.entries) + + for i in range(5): + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent=f"Search {i}", + ) + full_service.evaluate_request(request) + + final_count = len(full_service.ledger.entries) + assert final_count == initial_count + 5 + + def test_decision_with_data_classes(self, policy_engine): + """Test decision evaluation with data classes.""" + service = DecisionService(policy_engine) + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent="Search", + data_classes=["public"], + ) + + response = service.evaluate_request(request) + + # Should still work + assert response.decision in ["permit", "deny"] + + def test_decision_with_high_risk_level(self, service): + """Test decision with high risk level.""" + request = DecisionRequest( + actor="model", + proposed_action="search", + tool="web", + user_intent="Search", + risk_level=5, + ) + + response = service.evaluate_request(request) + + # Should still evaluate + assert response.decision in ["permit", "deny"] + + def test_policy_version_hash_in_response(self, service): + """Test that response includes policy version hash.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.policy_version_hash is not None + assert len(response.policy_version_hash) == 64 # SHA256 hex + + def test_reasoning_in_response(self, service): + """Test that response includes reasoning.""" + request = DecisionRequest( + actor="model", proposed_action="search", tool="web", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.reasoning is not None + assert len(response.reasoning) > 0 + + +class TestDecisionWorkflow: + """Integration tests for complete decision workflow.""" + + def test_complete_workflow(self): + """Test complete decision workflow from request to signed response.""" + # Setup + engine = PolicyEngine(mode=PolicyMode.STRICT) + engine.add_term(PolicyTerm.create_actor("user", "User")) + engine.add_term(PolicyTerm.create_action("read", "Read")) + engine.add_relation(PolicyRelation.permits("actor:user", "action:read")) + + ledger = LedgerChain() + identity = NodeIdentity("governance-node") + service = DecisionService(engine, ledger, identity) + + # Make decision + request = DecisionRequest( + actor="user", proposed_action="read", tool="file_system", user_intent="Read file" + ) + + response = service.evaluate_request(request) + + # Verify all components + assert response.decision == "permit" + assert response.capability_token is not None + assert response.ledger_entry_hash is not None + assert response.signature is not None + + # Verify ledger integrity + assert ledger.verify_integrity()["valid"] is True + + # Verify signature + is_valid = identity.verify_signature(response.decision_hash, response.signature) + assert is_valid is True + + def test_deny_workflow(self): + """Test workflow for denied decision.""" + # Strict mode - deny by default + engine = PolicyEngine(mode=PolicyMode.STRICT) + ledger = LedgerChain() + identity = NodeIdentity("test-node") + service = DecisionService(engine, ledger, identity) + + request = DecisionRequest( + actor="unknown", proposed_action="forbidden", tool="admin", user_intent="Test" + ) + + response = service.evaluate_request(request) + + # Should be denied + assert response.decision == "deny" + assert response.capability_token is None # No token for deny + assert response.ledger_entry_hash is not None # Still logged + assert response.signature is not None # Still signed + + def test_decision_audit_trail(self): + """Test that decisions create proper audit trail.""" + engine = PolicyEngine(mode=PolicyMode.STRICT) + engine.add_term(PolicyTerm.create_actor("bot", "Bot")) + engine.add_term(PolicyTerm.create_action("execute", "Execute")) + engine.add_relation(PolicyRelation.permits("actor:bot", "action:execute")) + + ledger = LedgerChain() + identity = NodeIdentity("audit-node") + service = DecisionService(engine, ledger, identity) + + # Make multiple decisions + requests = [ + DecisionRequest( + actor="bot", proposed_action="execute", tool=f"tool{i}", user_intent=f"Task {i}" + ) + for i in range(3) + ] + + responses = [service.evaluate_request(req) for req in requests] + + # Check audit report + report = ledger.generate_audit_report() + + assert report["total_entries"] >= 4 # Genesis + 3 decisions + assert "decision" in report["event_type_counts"] + assert report["event_type_counts"]["decision"] == 3 + + # Verify all entries + decision_entries = ledger.get_entries_by_type("decision") + assert len(decision_entries) == 3 + + def test_concurrent_decisions(self): + """Test handling multiple concurrent decisions.""" + engine = PolicyEngine(mode=PolicyMode.PERMISSIVE) + ledger = LedgerChain() + identity = NodeIdentity("concurrent-node") + service = DecisionService(engine, ledger, identity) + + # Simulate concurrent requests + requests = [ + DecisionRequest( + actor=f"actor{i}", proposed_action="action", tool="tool", user_intent="Intent" + ) + for i in range(10) + ] + + responses = [service.evaluate_request(req) for req in requests] + + # All should succeed + assert len(responses) == 10 + assert all(r.ledger_entry_hash is not None for r in responses) + + # All hashes should be unique + hashes = [r.ledger_entry_hash for r in responses] + assert len(hashes) == len(set(hashes)) + + # Ledger should still be valid + assert ledger.verify_integrity()["valid"] is True + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_empty_user_intent(self): + """Test request with empty user intent.""" + engine = PolicyEngine() + service = DecisionService(engine) + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="" + ) + + response = service.evaluate_request(request) + assert response is not None + + def test_very_long_user_intent(self): + """Test request with very long user intent.""" + engine = PolicyEngine() + service = DecisionService(engine) + + long_intent = "A" * 10000 + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent=long_intent + ) + + response = service.evaluate_request(request) + assert response is not None + + def test_special_characters_in_fields(self): + """Test request with special characters.""" + engine = PolicyEngine() + service = DecisionService(engine) + + request = DecisionRequest( + actor="model-v2.0", + proposed_action="read/write", + tool="tool_@#$", + user_intent="Test with 特殊文字 and émojis 🔒", + ) + + response = service.evaluate_request(request) + assert response is not None + + def test_decision_without_ledger(self): + """Test that service works without ledger.""" + engine = PolicyEngine() + service = DecisionService(engine) # No ledger + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.ledger_entry_hash is None + assert response.decision in ["permit", "deny"] + + def test_decision_without_identity(self): + """Test that service works without identity.""" + engine = PolicyEngine() + ledger = LedgerChain() + service = DecisionService(engine, ledger) # No identity + + request = DecisionRequest( + actor="model", proposed_action="action", tool="tool", user_intent="Test" + ) + + response = service.evaluate_request(request) + + assert response.signature is None + assert response.ledger_entry_hash is not None # Ledger still works diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..0b02954 --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,513 @@ +"""Tests for health check and observability functionality.""" + +import time + +import pytest + +from lexecon.observability.health import ( + HealthCheck, + HealthStatus, + check_identity, + check_ledger, + check_policy_engine, +) + + +class TestHealthStatus: + """Tests for HealthStatus enum.""" + + def test_health_status_values(self): + """Test that health status enum has expected values.""" + assert HealthStatus.HEALTHY == "healthy" + assert HealthStatus.DEGRADED == "degraded" + assert HealthStatus.UNHEALTHY == "unhealthy" + + def test_health_status_is_string(self): + """Test that health status values are strings.""" + assert isinstance(HealthStatus.HEALTHY, str) + assert isinstance(HealthStatus.DEGRADED, str) + assert isinstance(HealthStatus.UNHEALTHY, str) + + +class TestHealthCheck: + """Tests for HealthCheck class.""" + + def test_initialization(self): + """Test health check initialization.""" + hc = HealthCheck() + + assert hc.checks is not None + assert hc.start_time > 0 + assert isinstance(hc.checks, dict) + + def test_liveness_probe(self): + """Test liveness probe returns healthy status.""" + hc = HealthCheck() + + result = hc.liveness() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "uptime_seconds" in result + assert result["uptime_seconds"] >= 0 + + def test_liveness_uptime_increases(self): + """Test that uptime increases over time.""" + hc = HealthCheck() + + result1 = hc.liveness() + time.sleep(0.1) + result2 = hc.liveness() + + assert result2["uptime_seconds"] > result1["uptime_seconds"] + + def test_readiness_probe_no_checks(self): + """Test readiness probe with no health checks registered.""" + hc = HealthCheck() + hc.checks = {} # Clear default checks + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "checks" in result + assert len(result["checks"]) == 0 + + def test_readiness_probe_all_healthy(self): + """Test readiness probe when all checks are healthy.""" + hc = HealthCheck() + hc.checks = {} # Clear default checks + + # Add healthy checks + hc.add_check("check1", lambda: (HealthStatus.HEALTHY, {"detail": "ok"})) + hc.add_check("check2", lambda: (HealthStatus.HEALTHY, {"detail": "ok"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + assert len(result["checks"]) == 2 + assert all(c["status"] == HealthStatus.HEALTHY for c in result["checks"]) + + def test_readiness_probe_one_degraded(self): + """Test readiness probe with one degraded check.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("degraded", lambda: (HealthStatus.DEGRADED, {"reason": "slow"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.DEGRADED + assert len(result["checks"]) == 2 + + def test_readiness_probe_one_unhealthy(self): + """Test readiness probe with one unhealthy check.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("unhealthy", lambda: (HealthStatus.UNHEALTHY, {"error": "failed"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + assert len(result["checks"]) == 2 + + # Find unhealthy check + unhealthy_check = next(c for c in result["checks"] if c["name"] == "unhealthy") + assert unhealthy_check["details"]["error"] == "failed" + + def test_readiness_unhealthy_takes_precedence(self): + """Test that unhealthy status takes precedence over degraded.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("degraded", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("unhealthy", lambda: (HealthStatus.UNHEALTHY, {})) + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + + def test_readiness_handles_check_exception(self): + """Test readiness probe handles exceptions in health checks.""" + hc = HealthCheck() + hc.checks = {} + + def failing_check(): + raise RuntimeError("Check failed") + + hc.add_check("healthy", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("failing", failing_check) + + result = hc.readiness() + + # Overall status should be unhealthy + assert result["status"] == HealthStatus.UNHEALTHY + + # Failing check should show error + failing_result = next(c for c in result["checks"] if c["name"] == "failing") + assert failing_result["status"] == HealthStatus.UNHEALTHY + assert "error" in failing_result["details"] + assert "Check failed" in failing_result["details"]["error"] + + def test_add_check(self): + """Test adding custom health check.""" + hc = HealthCheck() + hc.checks = {} + + def custom_check(): + return HealthStatus.HEALTHY, {"custom": "data"} + + hc.add_check("custom", custom_check) + + assert "custom" in hc.checks + assert hc.checks["custom"] == custom_check + + def test_add_multiple_checks(self): + """Test adding multiple health checks.""" + hc = HealthCheck() + hc.checks = {} + + for i in range(5): + hc.add_check(f"check{i}", lambda: (HealthStatus.HEALTHY, {})) + + assert len(hc.checks) == 5 + + def test_startup_probe(self): + """Test startup probe.""" + hc = HealthCheck() + + result = hc.startup() + + assert result["status"] == HealthStatus.HEALTHY + assert "timestamp" in result + assert "message" in result + assert result["message"] == "Service initialized" + + def test_readiness_check_details(self): + """Test that readiness probe includes check details.""" + hc = HealthCheck() + hc.checks = {} + + details = {"version": "1.0", "connections": 5} + hc.add_check("detailed", lambda: (HealthStatus.HEALTHY, details)) + + result = hc.readiness() + + check_result = result["checks"][0] + assert check_result["name"] == "detailed" + assert check_result["details"] == details + + def test_multiple_unhealthy_checks(self): + """Test handling multiple unhealthy checks.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("fail1", lambda: (HealthStatus.UNHEALTHY, {"error": "error1"})) + hc.add_check("fail2", lambda: (HealthStatus.UNHEALTHY, {"error": "error2"})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.UNHEALTHY + assert len(result["checks"]) == 2 + assert all(c["status"] == HealthStatus.UNHEALTHY for c in result["checks"]) + + +class TestDefaultHealthChecks: + """Tests for default health check functions.""" + + def test_check_policy_engine(self): + """Test policy engine health check.""" + status, details = check_policy_engine() + + assert status == HealthStatus.HEALTHY + assert "policies_loaded" in details + assert isinstance(details["policies_loaded"], int) + + def test_check_ledger(self): + """Test ledger health check.""" + status, details = check_ledger() + + assert status == HealthStatus.HEALTHY + assert "entries" in details + assert "last_verified" in details + assert isinstance(details["entries"], int) + assert isinstance(details["last_verified"], float) + + def test_check_identity(self): + """Test identity health check.""" + status, details = check_identity() + + assert status == HealthStatus.HEALTHY + assert "key_loaded" in details + assert isinstance(details["key_loaded"], bool) + + +class TestHealthCheckIntegration: + """Integration tests for health check system.""" + + def test_default_health_check_instance(self): + """Test that default health check instance has checks registered.""" + from lexecon.observability.health import health_check + + # Should have default checks registered + assert len(health_check.checks) > 0 + assert "policy_engine" in health_check.checks + assert "ledger" in health_check.checks + assert "identity" in health_check.checks + + def test_full_readiness_check(self): + """Test full readiness check with default checks.""" + from lexecon.observability.health import health_check + + result = health_check.readiness() + + assert "status" in result + assert "checks" in result + assert len(result["checks"]) >= 3 # At least the 3 default checks + + def test_liveness_timestamp_is_recent(self): + """Test that liveness timestamp is recent.""" + hc = HealthCheck() + + result = hc.liveness() + + # Timestamp should be within last second + now = time.time() + assert abs(now - result["timestamp"]) < 1.0 + + def test_readiness_timestamp_is_recent(self): + """Test that readiness timestamp is recent.""" + hc = HealthCheck() + + result = hc.readiness() + + now = time.time() + assert abs(now - result["timestamp"]) < 1.0 + + +class TestHealthCheckEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_check_returning_none(self): + """Test handling check that returns None.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("none_check", lambda: None) + + # Should handle by catching exception in readiness check + result = hc.readiness() + # Should mark as unhealthy due to exception + assert result["status"] == HealthStatus.UNHEALTHY + + def test_check_returning_invalid_status(self): + """Test handling check with invalid status.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("invalid", lambda: ("invalid_status", {})) + + result = hc.readiness() + + # Should still run, but status might not be recognized + assert result is not None + + def test_check_with_empty_details(self): + """Test check with empty details.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("empty", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.HEALTHY + check = result["checks"][0] + assert check["details"] == {} + + def test_check_with_complex_details(self): + """Test check with complex nested details.""" + hc = HealthCheck() + hc.checks = {} + + complex_details = { + "metrics": {"cpu": 45.2, "memory": 1024}, + "connections": [{"id": 1, "status": "active"}, {"id": 2, "status": "idle"}], + "metadata": {"version": "1.0", "uptime": 3600}, + } + + hc.add_check("complex", lambda: (HealthStatus.HEALTHY, complex_details)) + + result = hc.readiness() + + check = result["checks"][0] + assert check["details"] == complex_details + + def test_overwrite_check(self): + """Test that adding check with same name overwrites.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("check", lambda: (HealthStatus.HEALTHY, {"version": 1})) + hc.add_check("check", lambda: (HealthStatus.DEGRADED, {"version": 2})) + + result = hc.readiness() + + # Should only have 1 check + assert len(result["checks"]) == 1 + + # Should be the second one + assert result["checks"][0]["status"] == HealthStatus.DEGRADED + assert result["checks"][0]["details"]["version"] == 2 + + def test_check_with_very_long_execution_time(self): + """Test check that takes long to execute.""" + hc = HealthCheck() + hc.checks = {} + + def slow_check(): + time.sleep(0.2) + return HealthStatus.HEALTHY, {"slow": True} + + hc.add_check("slow", slow_check) + + start = time.time() + result = hc.readiness() + duration = time.time() - start + + # Should still complete + assert result["status"] == HealthStatus.HEALTHY + assert duration >= 0.2 + + def test_many_checks(self): + """Test with many health checks.""" + hc = HealthCheck() + hc.checks = {} + + # Add 100 checks + for i in range(100): + hc.add_check(f"check{i}", lambda: (HealthStatus.HEALTHY, {})) + + result = hc.readiness() + + assert len(result["checks"]) == 100 + assert result["status"] == HealthStatus.HEALTHY + + def test_mixed_status_priority(self): + """Test status priority with all three statuses.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("healthy1", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("healthy2", lambda: (HealthStatus.HEALTHY, {})) + hc.add_check("degraded1", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("unhealthy1", lambda: (HealthStatus.UNHEALTHY, {})) + + result = hc.readiness() + + # Unhealthy should take precedence + assert result["status"] == HealthStatus.UNHEALTHY + + def test_only_degraded_checks(self): + """Test with only degraded checks.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("deg1", lambda: (HealthStatus.DEGRADED, {})) + hc.add_check("deg2", lambda: (HealthStatus.DEGRADED, {})) + + result = hc.readiness() + + assert result["status"] == HealthStatus.DEGRADED + + +class TestHealthCheckConcurrency: + """Tests for concurrent health check scenarios.""" + + def test_multiple_liveness_calls(self): + """Test multiple simultaneous liveness calls.""" + hc = HealthCheck() + + results = [hc.liveness() for _ in range(10)] + + # All should succeed + assert len(results) == 10 + assert all(r["status"] == HealthStatus.HEALTHY for r in results) + + def test_multiple_readiness_calls(self): + """Test multiple simultaneous readiness calls.""" + hc = HealthCheck() + hc.checks = {} + hc.add_check("test", lambda: (HealthStatus.HEALTHY, {})) + + results = [hc.readiness() for _ in range(10)] + + # All should succeed + assert len(results) == 10 + assert all(r["status"] == HealthStatus.HEALTHY for r in results) + + def test_concurrent_check_modifications(self): + """Test that checks can be added during operation.""" + hc = HealthCheck() + hc.checks = {} + + hc.add_check("initial", lambda: (HealthStatus.HEALTHY, {})) + + result1 = hc.readiness() + assert len(result1["checks"]) == 1 + + hc.add_check("added", lambda: (HealthStatus.HEALTHY, {})) + + result2 = hc.readiness() + assert len(result2["checks"]) == 2 + + +class TestHealthCheckSerialization: + """Tests for health check result serialization.""" + + def test_liveness_result_is_dict(self): + """Test that liveness result is JSON-serializable.""" + hc = HealthCheck() + result = hc.liveness() + + import json + + # Should be serializable to JSON + json_str = json.dumps(result) + assert len(json_str) > 0 + + # Should be deserializable + parsed = json.loads(json_str) + assert parsed["status"] == HealthStatus.HEALTHY + + def test_readiness_result_is_dict(self): + """Test that readiness result is JSON-serializable.""" + hc = HealthCheck() + hc.checks = {} + hc.add_check("test", lambda: (HealthStatus.HEALTHY, {"count": 5})) + + result = hc.readiness() + + import json + + json_str = json.dumps(result) + parsed = json.loads(json_str) + + assert parsed["status"] == HealthStatus.HEALTHY + assert len(parsed["checks"]) == 1 + + def test_startup_result_is_dict(self): + """Test that startup result is JSON-serializable.""" + hc = HealthCheck() + result = hc.startup() + + import json + + json_str = json.dumps(result) + parsed = json.loads(json_str) + + assert parsed["status"] == HealthStatus.HEALTHY diff --git a/tests/test_identity.py b/tests/test_identity.py new file mode 100644 index 0000000..8af89d9 --- /dev/null +++ b/tests/test_identity.py @@ -0,0 +1,548 @@ +"""Tests for identity and signing functionality.""" + +import json +import tempfile +from pathlib import Path + +import pytest + +from lexecon.identity.signing import KeyManager, NodeIdentity + + +class TestKeyManager: + """Tests for KeyManager class.""" + + def test_generate_key_pair(self): + """Test generating a new Ed25519 key pair.""" + km = KeyManager.generate() + + assert km.private_key is not None + assert km.public_key is not None + + def test_different_key_pairs_are_unique(self): + """Test that multiple generated key pairs are different.""" + km1 = KeyManager.generate() + km2 = KeyManager.generate() + + # Get fingerprints to compare + fp1 = km1.get_public_key_fingerprint() + fp2 = km2.get_public_key_fingerprint() + + assert fp1 != fp2 + + def test_save_keys_to_disk(self): + """Test saving keys to disk in PEM format.""" + km = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + km.save_keys(private_path, public_path) + + # Check files exist + assert private_path.exists() + assert public_path.exists() + + # Check files have content + assert len(private_path.read_bytes()) > 0 + assert len(public_path.read_bytes()) > 0 + + # Check PEM format + private_content = private_path.read_text() + public_content = public_path.read_text() + + assert "BEGIN PRIVATE KEY" in private_content + assert "END PRIVATE KEY" in private_content + assert "BEGIN PUBLIC KEY" in public_content + assert "END PUBLIC KEY" in public_content + + def test_load_keys_from_disk(self): + """Test loading private key from disk.""" + km_original = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + # Save keys + km_original.save_keys(private_path, public_path) + + # Load keys + km_loaded = KeyManager.load_keys(private_path) + + assert km_loaded.private_key is not None + assert km_loaded.public_key is not None + + # Fingerprints should match + assert ( + km_loaded.get_public_key_fingerprint() + == km_original.get_public_key_fingerprint() + ) + + def test_load_public_key_from_disk(self): + """Test loading public key separately.""" + km = KeyManager.generate() + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + km.save_keys(private_path, public_path) + + # Load public key + public_key = KeyManager.load_public_key(public_path) + + assert public_key is not None + + def test_sign_data(self): + """Test signing data with private key.""" + km = KeyManager.generate() + data = {"message": "test", "value": 42} + + signature = km.sign(data) + + # Signature should be base64 string + assert isinstance(signature, str) + assert len(signature) > 0 + + # Should be valid base64 + import base64 + + decoded = base64.b64decode(signature) + assert len(decoded) > 0 + + def test_sign_creates_deterministic_signature(self): + """Test that signing same data twice gives same signature.""" + km = KeyManager.generate() + data = {"key": "value", "number": 123} + + sig1 = km.sign(data) + sig2 = km.sign(data) + + assert sig1 == sig2 + + def test_sign_different_data_gives_different_signatures(self): + """Test that different data produces different signatures.""" + km = KeyManager.generate() + data1 = {"message": "first"} + data2 = {"message": "second"} + + sig1 = km.sign(data1) + sig2 = km.sign(data2) + + assert sig1 != sig2 + + def test_verify_valid_signature(self): + """Test verifying a valid signature.""" + km = KeyManager.generate() + data = {"test": "data", "count": 5} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_verify_invalid_signature(self): + """Test that invalid signature fails verification.""" + km = KeyManager.generate() + data = {"test": "data"} + + signature = km.sign(data) + + # Tamper with signature + import base64 + + sig_bytes = base64.b64decode(signature) + tampered = sig_bytes[:-1] + bytes([sig_bytes[-1] ^ 0xFF]) + tampered_sig = base64.b64encode(tampered).decode() + + is_valid = KeyManager.verify(data, tampered_sig, km.public_key) + + assert is_valid is False + + def test_verify_signature_with_wrong_key(self): + """Test that signature verification fails with wrong public key.""" + km1 = KeyManager.generate() + km2 = KeyManager.generate() + data = {"test": "data"} + + # Sign with km1 + signature = km1.sign(data) + + # Verify with km2's public key + is_valid = KeyManager.verify(data, signature, km2.public_key) + + assert is_valid is False + + def test_verify_signature_with_tampered_data(self): + """Test that verification fails when data is tampered.""" + km = KeyManager.generate() + data = {"value": 100} + + signature = km.sign(data) + + # Tamper with data + tampered_data = {"value": 999} + + is_valid = KeyManager.verify(tampered_data, signature, km.public_key) + + assert is_valid is False + + def test_sign_without_private_key_raises_error(self): + """Test that signing without private key raises error.""" + km = KeyManager() # No key + + with pytest.raises(ValueError, match="No private key"): + km.sign({"test": "data"}) + + def test_save_keys_without_private_key_raises_error(self): + """Test that saving without private key raises error.""" + km = KeyManager() # No key + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "test.key" + public_path = Path(tmpdir) / "test.pub" + + with pytest.raises(ValueError, match="No private key"): + km.save_keys(private_path, public_path) + + def test_get_public_key_fingerprint(self): + """Test getting public key fingerprint.""" + km = KeyManager.generate() + + fingerprint = km.get_public_key_fingerprint() + + # Should be 16 character hex string (first 16 chars of SHA256) + assert isinstance(fingerprint, str) + assert len(fingerprint) == 16 + # Should be valid hex + int(fingerprint, 16) + + def test_fingerprint_is_deterministic(self): + """Test that fingerprint is deterministic for same key.""" + km = KeyManager.generate() + + fp1 = km.get_public_key_fingerprint() + fp2 = km.get_public_key_fingerprint() + + assert fp1 == fp2 + + def test_get_fingerprint_without_public_key_raises_error(self): + """Test that getting fingerprint without key raises error.""" + km = KeyManager() + + with pytest.raises(ValueError, match="No public key"): + km.get_public_key_fingerprint() + + def test_sign_canonical_json(self): + """Test that signing uses canonical JSON representation.""" + km = KeyManager.generate() + + # These should produce the same signature due to canonical JSON + data1 = {"b": 2, "a": 1} + data2 = {"a": 1, "b": 2} + + sig1 = km.sign(data1) + sig2 = km.sign(data2) + + assert sig1 == sig2 + + def test_sign_handles_nested_data(self): + """Test signing complex nested data structures.""" + km = KeyManager.generate() + data = { + "user": {"name": "Alice", "id": 123}, + "permissions": ["read", "write"], + "metadata": {"created": "2025-01-01", "version": 2}, + } + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_key_persistence_roundtrip(self): + """Test full roundtrip: generate, save, load, verify.""" + km_original = KeyManager.generate() + data = {"test": "roundtrip"} + original_signature = km_original.sign(data) + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "key.pem" + public_path = Path(tmpdir) / "key.pub" + + # Save + km_original.save_keys(private_path, public_path) + + # Load + km_loaded = KeyManager.load_keys(private_path) + + # Should be able to sign with loaded key + loaded_signature = km_loaded.sign(data) + + # Signatures should match + assert loaded_signature == original_signature + + # Verification should work + assert KeyManager.verify(data, loaded_signature, km_loaded.public_key) is True + + +class TestNodeIdentity: + """Tests for NodeIdentity class.""" + + def test_create_node_identity(self): + """Test creating a node identity.""" + node = NodeIdentity("test-node-1") + + assert node.node_id == "test-node-1" + assert node.key_manager is not None + assert node.key_manager.private_key is not None + + def test_create_with_existing_key_manager(self): + """Test creating node with existing key manager.""" + km = KeyManager.generate() + node = NodeIdentity("test-node-2", key_manager=km) + + assert node.node_id == "test-node-2" + assert node.key_manager is km + + def test_get_node_id(self): + """Test getting node ID.""" + node = NodeIdentity("my-node") + + assert node.get_node_id() == "my-node" + + def test_sign_data(self): + """Test signing data through node identity.""" + node = NodeIdentity("signer-node") + data = {"message": "test", "timestamp": "2025-01-01"} + + signature = node.sign(data) + + assert isinstance(signature, str) + assert len(signature) > 0 + + def test_get_public_key_fingerprint(self): + """Test getting public key fingerprint.""" + node = NodeIdentity("fp-node") + + fingerprint = node.get_public_key_fingerprint() + + assert isinstance(fingerprint, str) + assert len(fingerprint) == 16 + + def test_verify_signature_with_string_data(self): + """Test verifying signature on string data (like hashes).""" + node = NodeIdentity("verify-node") + + # Simulate signing a hash string + hash_string = "abc123def456" + + # Sign it manually through key manager + import base64 + message = hash_string.encode() + signature_bytes = node.key_manager.private_key.sign(message) + signature = base64.b64encode(signature_bytes).decode() + + # Verify using node identity + is_valid = node.verify_signature(hash_string, signature) + + assert is_valid is True + + def test_verify_signature_fails_with_wrong_data(self): + """Test that verification fails with different data.""" + node = NodeIdentity("verify-node") + + original_data = "original_hash" + tampered_data = "tampered_hash" + + # Sign original + import base64 + message = original_data.encode() + signature_bytes = node.key_manager.private_key.sign(message) + signature = base64.b64encode(signature_bytes).decode() + + # Try to verify with tampered data + is_valid = node.verify_signature(tampered_data, signature) + + assert is_valid is False + + def test_verify_signature_fails_with_wrong_signature(self): + """Test that verification fails with invalid signature.""" + node = NodeIdentity("verify-node") + data = "test_data" + + # Create invalid signature + fake_signature = "aW52YWxpZF9zaWduYXR1cmU=" # base64 of "invalid_signature" + + is_valid = node.verify_signature(data, fake_signature) + + assert is_valid is False + + def test_verify_signature_without_public_key(self): + """Test verification fails without public key.""" + node = NodeIdentity("no-key-node") + # Remove public key + node.key_manager.public_key = None + + is_valid = node.verify_signature("data", "signature") + + assert is_valid is False + + def test_different_nodes_have_different_fingerprints(self): + """Test that different nodes have unique fingerprints.""" + node1 = NodeIdentity("node-1") + node2 = NodeIdentity("node-2") + + fp1 = node1.get_public_key_fingerprint() + fp2 = node2.get_public_key_fingerprint() + + assert fp1 != fp2 + + def test_node_can_verify_own_signature(self): + """Test that node can verify its own signatures.""" + node = NodeIdentity("self-verify-node") + data = {"action": "test", "value": 42} + + # Sign with node + signature = node.sign(data) + + # Convert dict to canonical JSON for verification + import json + canonical = json.dumps(data, sort_keys=True, separators=(",", ":")) + + # This should work with the node's verify_signature method + # but it expects string data, so we need to verify differently + # Let's use the key manager's verify method + is_valid = KeyManager.verify(data, signature, node.key_manager.public_key) + + assert is_valid is True + + +class TestCrossNodeVerification: + """Tests for cross-node signature verification.""" + + def test_node_cannot_verify_other_node_signature(self): + """Test that one node cannot forge another's signature.""" + node1 = NodeIdentity("node-1") + node2 = NodeIdentity("node-2") + + data = {"message": "test"} + + # Node 1 signs + signature = node1.sign(data) + + # Node 2 tries to verify with its own key + is_valid = KeyManager.verify(data, signature, node2.key_manager.public_key) + + assert is_valid is False + + def test_public_key_distribution(self): + """Test that public keys can be shared for verification.""" + node1 = NodeIdentity("alice") + node2 = NodeIdentity("bob") + + data = {"transfer": "100", "to": "bob"} + + # Alice signs + signature = node1.sign(data) + + # Bob can verify using Alice's public key + is_valid = KeyManager.verify(data, signature, node1.key_manager.public_key) + + assert is_valid is True + + def test_signature_persistence_across_nodes(self): + """Test signature verification works after key export/import.""" + # Node 1 creates and signs + node1 = NodeIdentity("node-1") + data = {"test": "data"} + signature = node1.sign(data) + + with tempfile.TemporaryDirectory() as tmpdir: + private_path = Path(tmpdir) / "node1.key" + public_path = Path(tmpdir) / "node1.pub" + + # Export keys + node1.key_manager.save_keys(private_path, public_path) + + # Load into new key manager (simulating different node) + loaded_km = KeyManager.load_keys(private_path) + + # Should be able to verify with loaded keys + is_valid = KeyManager.verify(data, signature, loaded_km.public_key) + assert is_valid is True + + +class TestEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_sign_empty_dict(self): + """Test signing empty dictionary.""" + km = KeyManager.generate() + data = {} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_sign_large_data(self): + """Test signing large data structure.""" + km = KeyManager.generate() + data = {"items": [{"id": i, "value": f"item_{i}"} for i in range(1000)]} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_sign_with_unicode(self): + """Test signing data with unicode characters.""" + km = KeyManager.generate() + data = {"message": "Hello 世界 🌍", "emoji": "🔐"} + + signature = km.sign(data) + is_valid = KeyManager.verify(data, signature, km.public_key) + + assert is_valid is True + + def test_node_id_with_special_characters(self): + """Test node identity with special characters in ID.""" + node = NodeIdentity("node-123_test.example.com") + + assert node.get_node_id() == "node-123_test.example.com" + + def test_verify_with_malformed_signature(self): + """Test verification with malformed base64 signature.""" + node = NodeIdentity("test-node") + + # Invalid base64 + invalid_sig = "not-valid-base64!!!" + + is_valid = node.verify_signature("data", invalid_sig) + + assert is_valid is False + + def test_load_nonexistent_key_file(self): + """Test loading from non-existent file.""" + with pytest.raises(FileNotFoundError): + KeyManager.load_keys(Path("/nonexistent/key.pem")) + + def test_save_to_readonly_directory(self): + """Test error handling when saving to readonly location.""" + import os + + # Skip if running as root (has permission to write everywhere) + if os.getuid() == 0: + pytest.skip("Running as root, cannot test readonly directory") + + km = KeyManager.generate() + + # Try to save to root (should fail on Unix systems) + readonly_path = Path("/readonly.key") + readonly_pub = Path("/readonly.pub") + + with pytest.raises((PermissionError, OSError)): + km.save_keys(readonly_path, readonly_pub)