diff --git a/testproject/tests/conftest.py b/testproject/tests/conftest.py index 81d16ce3..a16e94b0 100644 --- a/testproject/tests/conftest.py +++ b/testproject/tests/conftest.py @@ -43,6 +43,7 @@ def mfa_method_creator( is_primary=is_primary, name=method_name, is_active=method_args.pop("is_active", True), + code_generated_at=None, **method_args, ) diff --git a/testproject/tests/test_add_mfa.py b/testproject/tests/test_add_mfa.py index d8ecb574..1dc554cb 100644 --- a/testproject/tests/test_add_mfa.py +++ b/testproject/tests/test_add_mfa.py @@ -8,7 +8,7 @@ from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST from tests.utils import TrenchAPIClient -from trench.command.create_otp import create_otp_command +from trench.command.create_otp import create_totp_command from trench.command.create_secret import create_secret_command @@ -24,7 +24,7 @@ def test_add_user_mfa(active_user): path="/auth/email/activate/", data={ "secret": secret, - "code": create_otp_command(secret=secret, interval=60).now(), + "code": create_totp_command(secret=secret, interval=60).now(), "user": getattr(active_user, active_user.USERNAME_FIELD), }, format="json", @@ -43,7 +43,7 @@ def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User path="/auth/email/activate/", data={ "secret": secret, - "code": create_otp_command(secret=secret, interval=60).now(), + "code": create_totp_command(secret=secret, interval=60).now(), "user": getattr(active_user, active_user.USERNAME_FIELD), }, format="json", diff --git a/testproject/tests/test_utils.py b/testproject/tests/test_utils.py index f5807a5d..6c1a2d77 100644 --- a/testproject/tests/test_utils.py +++ b/testproject/tests/test_utils.py @@ -2,8 +2,10 @@ from trench.backends.application import ApplicationMessageDispatcher from trench.backends.base import AbstractMessageDispatcher +from trench.backends.basic_mail import SendMailHotpMessageDispatcher from trench.backends.provider import get_mfa_handler from trench.models import MFAMethod +from trench.query.get_mfa_config_by_name import get_mfa_config_by_name_query from trench.utils import UserTokenGenerator @@ -46,6 +48,61 @@ def test_validate_code(active_user_with_email_otp): assert handler.validate_code(code=valid_code) is True +@pytest.mark.django_db +def test_create_code_hotp(active_user_with_email_otp): + email_method = active_user_with_email_otp.mfa_methods.get() + conf = get_mfa_config_by_name_query(name=email_method.name) + handler = SendMailHotpMessageDispatcher(email_method, conf) + + email_method.counter = 0 + email_method.code_generated_at = None + email_method.save() + + handler.create_code() + + email_method.refresh_from_db() + assert email_method.counter == 1 + assert email_method.code_generated_at is not None + + previous_code_genererated_at = email_method.code_generated_at + + handler.create_code() + + email_method.refresh_from_db() + assert email_method.counter == 2 + assert email_method.code_generated_at > previous_code_genererated_at + + +@pytest.mark.django_db +def test_validate_code_hotp(active_user_with_email_otp): + email_method = active_user_with_email_otp.mfa_methods.get() + conf = get_mfa_config_by_name_query(name=email_method.name) + handler = SendMailHotpMessageDispatcher(email_method, conf) + + email_method.counter = 0 + email_method.code_generated_at = None + email_method.save() + + valid_code = handler.create_code() + + assert handler.validate_code(code="123456") is False + email_method.refresh_from_db() + assert email_method.code_generated_at is not None + + assert handler.validate_code(code=valid_code) is True + email_method.refresh_from_db() + assert email_method.code_generated_at is None + + assert handler.validate_code(code=valid_code) is False + + valid_code = handler.create_code() + new_valid_code = handler.create_code() + + assert new_valid_code != valid_code + assert handler.validate_code(code=valid_code) is False + assert handler.validate_code(code=new_valid_code) is True + + @pytest.mark.django_db def test_validate_code_yubikey(active_user_with_many_otp_methods): active_user, _ = active_user_with_many_otp_methods diff --git a/trench/backends/aws.py b/trench/backends/aws.py index 7d8cfaed..68a9a43c 100644 --- a/trench/backends/aws.py +++ b/trench/backends/aws.py @@ -4,7 +4,10 @@ import boto3 import botocore.exceptions -from trench.backends.base import AbstractMessageDispatcher +from trench.backends.base import ( + AbstractMessageDispatcher, + AbstractHotpMessageDispatcher, +) from trench.responses import ( DispatchResponse, FailedDispatchResponse, @@ -13,6 +16,7 @@ from trench.settings import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION from botocore.exceptions import ClientError, EndpointConnectionError + class AWSMessageDispatcher(AbstractMessageDispatcher): _SMS_BODY = _("Your verification code is: ") _SUCCESS_DETAILS = _("SMS message with MFA code has been sent.") @@ -36,3 +40,7 @@ def dispatch_message(self) -> DispatchResponse: except EndpointConnectionError as cause: logging.error(cause, exc_info=True) return FailedDispatchResponse(details=str(cause)) + + +class AWSHotpMessageDispatcher(AbstractHotpMessageDispatcher, AWSMessageDispatcher): + pass diff --git a/trench/backends/base.py b/trench/backends/base.py index c32fe1a7..640b63f9 100644 --- a/trench/backends/base.py +++ b/trench/backends/base.py @@ -1,10 +1,12 @@ from django.db.models import Model +from django.utils import timezone from abc import ABC, abstractmethod -from pyotp import TOTP +from datetime import timedelta +from pyotp import TOTP, HOTP from typing import Any, Dict, Optional, Tuple -from trench.command.create_otp import create_otp_command +from trench.command.create_otp import create_totp_command, create_hotp_command from trench.exceptions import MissingConfigurationError from trench.models import MFAMethod from trench.responses import DispatchResponse @@ -76,7 +78,7 @@ def validate_code(self, code: str) -> bool: return self._get_otp().verify(otp=code) def _get_otp(self) -> TOTP: - return create_otp_command( + return create_totp_command( secret=self._mfa_method.secret, interval=self._get_valid_window() ) @@ -84,3 +86,34 @@ def _get_valid_window(self) -> int: return self._config.get( VALIDITY_PERIOD, trench_settings.DEFAULT_VALIDITY_PERIOD ) + + +class AbstractHotpMessageDispatcher(AbstractMessageDispatcher): + def create_code(self) -> str: + self._mfa_method.counter += 1 + self._mfa_method.code_generated_at = timezone.now() + self._mfa_method.save() + return self._get_otp().at(self._mfa_method.counter) + + def validate_code(self, code: str) -> bool: + if not self._mfa_method.code_generated_at: + return False + + is_valid = self._get_otp().verify(otp=code, counter=self._mfa_method.counter) + if not is_valid: + return False + + min_time = self._mfa_method.code_generated_at + max_time = self._mfa_method.code_generated_at + timedelta( + seconds=self._get_valid_window() + ) + now = timezone.now() + if now < min_time or now > max_time: + return False + + self._mfa_method.code_generated_at = None + self._mfa_method.save() + return True + + def _get_otp(self) -> HOTP: + return create_hotp_command(secret=self._mfa_method.secret) diff --git a/trench/backends/basic_mail.py b/trench/backends/basic_mail.py index 1f02e3eb..8dcc87ef 100644 --- a/trench/backends/basic_mail.py +++ b/trench/backends/basic_mail.py @@ -6,7 +6,10 @@ import logging from smtplib import SMTPException -from trench.backends.base import AbstractMessageDispatcher +from trench.backends.base import ( + AbstractMessageDispatcher, + AbstractHotpMessageDispatcher, +) from trench.responses import ( DispatchResponse, FailedDispatchResponse, @@ -39,3 +42,9 @@ def dispatch_message(self) -> DispatchResponse: except ConnectionRefusedError as cause: # pragma: nocover logging.error(cause, exc_info=True) # pragma: nocover return FailedDispatchResponse(details=str(cause)) # pragma: nocover + + +class SendMailHotpMessageDispatcher( + AbstractHotpMessageDispatcher, SendMailMessageDispatcher +): + pass diff --git a/trench/backends/sms_api.py b/trench/backends/sms_api.py index eba393f5..64382c2d 100644 --- a/trench/backends/sms_api.py +++ b/trench/backends/sms_api.py @@ -4,7 +4,10 @@ from smsapi.client import SmsApiPlClient from smsapi.exception import SmsApiException -from trench.backends.base import AbstractMessageDispatcher +from trench.backends.base import ( + AbstractMessageDispatcher, + AbstractHotpMessageDispatcher, +) from trench.responses import ( DispatchResponse, FailedDispatchResponse, @@ -31,3 +34,9 @@ def dispatch_message(self) -> DispatchResponse: except SmsApiException as cause: logging.error(cause, exc_info=True) return FailedDispatchResponse(details=cause.message) + + +class SMSAPIHotpMessageDispatcher( + AbstractHotpMessageDispatcher, SMSAPIMessageDispatcher +): + pass diff --git a/trench/backends/twilio.py b/trench/backends/twilio.py index 897c4332..5b1f98ef 100644 --- a/trench/backends/twilio.py +++ b/trench/backends/twilio.py @@ -4,7 +4,10 @@ from twilio.base.exceptions import TwilioRestException from twilio.rest import Client -from trench.backends.base import AbstractMessageDispatcher +from trench.backends.base import ( + AbstractMessageDispatcher, + AbstractHotpMessageDispatcher, +) from trench.responses import ( DispatchResponse, FailedDispatchResponse, @@ -29,3 +32,9 @@ def dispatch_message(self) -> DispatchResponse: except TwilioRestException as cause: logging.error(cause, exc_info=True) return FailedDispatchResponse(details=cause.msg) + + +class TwilioHotpMessageDispatcher( + AbstractHotpMessageDispatcher, TwilioMessageDispatcher +): + pass diff --git a/trench/command/create_otp.py b/trench/command/create_otp.py index 93ca4b7e..a3a1017d 100644 --- a/trench/command/create_otp.py +++ b/trench/command/create_otp.py @@ -1,10 +1,19 @@ -from pyotp import TOTP +from pyotp import TOTP, HOTP -class CreateOTPCommand: +class CreateTOTPCommand: @staticmethod def execute(secret: str, interval: int) -> TOTP: return TOTP(secret, interval=interval) -create_otp_command = CreateOTPCommand.execute +create_totp_command = CreateTOTPCommand.execute + + +class CreateHOTPCommand: + @staticmethod + def execute(secret: str) -> HOTP: + return HOTP(secret) + + +create_hotp_command = CreateHOTPCommand.execute diff --git a/trench/migrations/0006_mfamethod_code_generated_at_mfamethod_counter.py b/trench/migrations/0006_mfamethod_code_generated_at_mfamethod_counter.py new file mode 100644 index 00000000..cf23c721 --- /dev/null +++ b/trench/migrations/0006_mfamethod_code_generated_at_mfamethod_counter.py @@ -0,0 +1,24 @@ +# Generated by Django 4.1.5 on 2023-03-17 22:56 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("trench", "0005_remove_mfamethod_primary_is_active_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="mfamethod", + name="code_generated_at", + field=models.DateTimeField( + blank=True, null=True, verbose_name="code generated at" + ), + ), + migrations.AddField( + model_name="mfamethod", + name="counter", + field=models.PositiveIntegerField(default=0, verbose_name="counter"), + ), + ] diff --git a/trench/models.py b/trench/models.py index 6b1e19ca..9c7e30f7 100644 --- a/trench/models.py +++ b/trench/models.py @@ -1,12 +1,14 @@ from django.conf import settings from django.db.models import ( - CASCADE, BooleanField, + CASCADE, CharField, CheckConstraint, + DateTimeField, ForeignKey, Manager, Model, + PositiveIntegerField, Q, QuerySet, TextField, @@ -70,6 +72,8 @@ class MFAMethod(Model): ) name = CharField(_("name"), max_length=255) secret = CharField(_("secret"), max_length=255) + counter = PositiveIntegerField(_("counter"), default=0) + code_generated_at = DateTimeField(_("code generated at"), blank=True, null=True) is_primary = BooleanField(_("is primary"), default=False) is_active = BooleanField(_("is active"), default=False) _backup_codes = TextField(_("backup codes"), blank=True)