diff --git a/src/rest_framework_api_key/crypto.py b/src/rest_framework_api_key/crypto.py index 6a1696e..8e21457 100644 --- a/src/rest_framework_api_key/crypto.py +++ b/src/rest_framework_api_key/crypto.py @@ -1,7 +1,9 @@ +import hashlib import typing +from django.conf import settings from django.contrib.auth.hashers import check_password, make_password -from django.utils.crypto import get_random_string +from django.utils.crypto import constant_time_compare, get_random_string def concatenate(left: str, right: str) -> str: @@ -13,6 +15,21 @@ def split(concatenated: str) -> typing.Tuple[str, str]: return left, right +def hash_key(algo: str, key: str, salt: str) -> str: + hasher = getattr(hashlib, algo) + hash_value = hasher(key.encode() + salt.encode()).hexdigest() + return f"plain_{algo}$${hash_value}" + + +def check_hash(key: str, hashed_key: str, salt: str) -> bool: + algo, _, hash_value = hashed_key.partition("$$") + algo = algo.replace("plain_", "") + hasher = getattr(hashlib, algo) + return constant_time_compare( + hasher(key.encode() + salt.encode()).hexdigest(), hash_value + ) + + class KeyGenerator: def __init__(self, prefix_length: int = 8, secret_key_length: int = 32): self.prefix_length = prefix_length @@ -24,15 +41,24 @@ def get_prefix(self) -> str: def get_secret_key(self) -> str: return get_random_string(self.secret_key_length) - def hash(self, value: str) -> str: + def hash(self, value: str, salt: str) -> str: + hash_algo = getattr(settings, "DRF_API_KEY_HASHING_ALGORITHM", None) + if hash_algo: + # the hash is salted with the prefix to prevent rainbow table attacks + # (even though the key should be random enough to prevent that) + return hash_key(hash_algo, value, salt) return make_password(value) def generate(self) -> typing.Tuple[str, str, str]: prefix = self.get_prefix() secret_key = self.get_secret_key() key = concatenate(prefix, secret_key) - hashed_key = self.hash(key) + hashed_key = self.hash(key, prefix) return key, prefix, hashed_key - def verify(self, key: str, hashed_key: str) -> bool: + def verify(self, key: str, hashed_key: str, prefix: str) -> bool: + if hashed_key.startswith("plain_"): + # this is a plain key + return check_hash(key, hashed_key, prefix) + return check_password(key, hashed_key) diff --git a/src/rest_framework_api_key/models.py b/src/rest_framework_api_key/models.py index 3854f68..5b128a1 100644 --- a/src/rest_framework_api_key/models.py +++ b/src/rest_framework_api_key/models.py @@ -1,5 +1,6 @@ import typing +from django.conf import settings from django.core.exceptions import ValidationError from django.db import models from django.utils import timezone @@ -128,7 +129,17 @@ def _has_expired(self) -> bool: has_expired = property(_has_expired) def is_valid(self, key: str) -> bool: - return type(self).objects.key_generator.verify(key, self.hashed_key) + ok = type(self).objects.key_generator.verify(key, self.hashed_key, self.prefix) + if ok and getattr(settings, "DRF_API_KEY_HASH_AUTOUPDATE", False): + # by generating a new hash and comparing it with the stored one, + # we can detect not only if the hash algorithm has changed, but also + # if some internal parameters have changed (e.g. the number of iterations) + # at the cost of one more hash generation, which is negligible + new_hash = type(self).objects.key_generator.hash(key, self.prefix) + if new_hash != self.hashed_key: + self.hashed_key = new_hash + self.save() + return ok def clean(self) -> None: self._validate_revoked() diff --git a/tests/test_plain_key_hashing.py b/tests/test_plain_key_hashing.py new file mode 100644 index 0000000..d293f09 --- /dev/null +++ b/tests/test_plain_key_hashing.py @@ -0,0 +1,44 @@ +from typing import Callable + +import pytest +from django.conf import LazySettings + +from rest_framework_api_key.crypto import KeyGenerator +from rest_framework_api_key.models import APIKey + + +@pytest.mark.parametrize("algorithm", ["sha256", "sha512", "blake2b"]) +def test_hashing_algorithm_honors_setting( + settings: LazySettings, algorithm: str +) -> None: + settings.DRF_API_KEY_HASHING_ALGORITHM = algorithm + _key, _prefix, hashed_key = KeyGenerator().generate() + assert hashed_key.startswith(f"plain_{algorithm}$$") + + +@pytest.mark.parametrize("algorithm", ["sha256", "sha512", "blake2b"]) +def test_hash_verify(settings: LazySettings, algorithm: str) -> None: + settings.DRF_API_KEY_HASHING_ALGORITHM = algorithm + key, prefix, hashed_key = KeyGenerator().generate() + assert KeyGenerator().verify(key, hashed_key, prefix) is True + + +@pytest.mark.parametrize("update_algo", [True, False]) +@pytest.mark.django_db +def test_hash_verify_with_update( + settings: LazySettings, update_algo: bool, django_assert_num_queries: Callable +) -> None: + api_key, generated_key = APIKey.objects.create_key(name="test") + assert not api_key.hashed_key.startswith("plain_") + assert api_key.is_valid(generated_key) is True + + settings.DRF_API_KEY_HASHING_ALGORITHM = "blake2b" + settings.DRF_API_KEY_HASH_AUTOUPDATE = update_algo + + assert api_key.is_valid(generated_key) is True + assert api_key.hashed_key.startswith("plain_blake2b$$") is update_algo + with django_assert_num_queries(0): + # no queries should be made to update the key if it is already updated + assert ( + api_key.is_valid(generated_key) is True + ), "check still works after potential update"