Skip to content
Open
1 change: 1 addition & 0 deletions testproject/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def mfa_method_creator(
is_primary=is_primary,
name=method_name,
is_active=method_args.pop("is_active", True),
code_generated_at=None,
**method_args,
)

Expand Down
6 changes: 3 additions & 3 deletions testproject/tests/test_add_mfa.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from rest_framework.status import HTTP_200_OK, HTTP_400_BAD_REQUEST

from tests.utils import TrenchAPIClient
from trench.command.create_otp import create_otp_command
from trench.command.create_otp import create_totp_command
from trench.command.create_secret import create_secret_command


Expand All @@ -24,7 +24,7 @@ def test_add_user_mfa(active_user):
path="/auth/email/activate/",
data={
"secret": secret,
"code": create_otp_command(secret=secret, interval=60).now(),
"code": create_totp_command(secret=secret, interval=60).now(),
"user": getattr(active_user, active_user.USERNAME_FIELD),
},
format="json",
Expand All @@ -43,7 +43,7 @@ def test_should_fail_on_add_user_mfa_with_invalid_source_field(active_user: User
path="/auth/email/activate/",
data={
"secret": secret,
"code": create_otp_command(secret=secret, interval=60).now(),
"code": create_totp_command(secret=secret, interval=60).now(),
"user": getattr(active_user, active_user.USERNAME_FIELD),
},
format="json",
Expand Down
57 changes: 57 additions & 0 deletions testproject/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from trench.backends.application import ApplicationMessageDispatcher
from trench.backends.base import AbstractMessageDispatcher
from trench.backends.basic_mail import SendMailHotpMessageDispatcher
from trench.backends.provider import get_mfa_handler
from trench.models import MFAMethod
from trench.query.get_mfa_config_by_name import get_mfa_config_by_name_query
from trench.utils import UserTokenGenerator


Expand Down Expand Up @@ -46,6 +48,61 @@ def test_validate_code(active_user_with_email_otp):
assert handler.validate_code(code=valid_code) is True


@pytest.mark.django_db
def test_create_code_hotp(active_user_with_email_otp):
email_method = active_user_with_email_otp.mfa_methods.get()
conf = get_mfa_config_by_name_query(name=email_method.name)
handler = SendMailHotpMessageDispatcher(email_method, conf)

email_method.counter = 0
email_method.code_generated_at = None
email_method.save()

handler.create_code()

email_method.refresh_from_db()
assert email_method.counter == 1
assert email_method.code_generated_at is not None

previous_code_genererated_at = email_method.code_generated_at

handler.create_code()

email_method.refresh_from_db()
assert email_method.counter == 2
assert email_method.code_generated_at > previous_code_genererated_at


@pytest.mark.django_db
def test_validate_code_hotp(active_user_with_email_otp):
email_method = active_user_with_email_otp.mfa_methods.get()
conf = get_mfa_config_by_name_query(name=email_method.name)
handler = SendMailHotpMessageDispatcher(email_method, conf)

email_method.counter = 0
email_method.code_generated_at = None
email_method.save()

valid_code = handler.create_code()

assert handler.validate_code(code="123456") is False
email_method.refresh_from_db()
assert email_method.code_generated_at is not None

assert handler.validate_code(code=valid_code) is True
email_method.refresh_from_db()
assert email_method.code_generated_at is None

assert handler.validate_code(code=valid_code) is False

valid_code = handler.create_code()
new_valid_code = handler.create_code()

assert new_valid_code != valid_code
assert handler.validate_code(code=valid_code) is False
assert handler.validate_code(code=new_valid_code) is True


@pytest.mark.django_db
def test_validate_code_yubikey(active_user_with_many_otp_methods):
active_user, _ = active_user_with_many_otp_methods
Expand Down
10 changes: 9 additions & 1 deletion trench/backends/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import boto3
import botocore.exceptions

from trench.backends.base import AbstractMessageDispatcher
from trench.backends.base import (
AbstractMessageDispatcher,
AbstractHotpMessageDispatcher,
)
from trench.responses import (
DispatchResponse,
FailedDispatchResponse,
Expand All @@ -13,6 +16,7 @@
from trench.settings import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION
from botocore.exceptions import ClientError, EndpointConnectionError


class AWSMessageDispatcher(AbstractMessageDispatcher):
_SMS_BODY = _("Your verification code is: ")
_SUCCESS_DETAILS = _("SMS message with MFA code has been sent.")
Expand All @@ -36,3 +40,7 @@ def dispatch_message(self) -> DispatchResponse:
except EndpointConnectionError as cause:
logging.error(cause, exc_info=True)
return FailedDispatchResponse(details=str(cause))


class AWSHotpMessageDispatcher(AbstractHotpMessageDispatcher, AWSMessageDispatcher):
pass
39 changes: 36 additions & 3 deletions trench/backends/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from django.db.models import Model
from django.utils import timezone

from abc import ABC, abstractmethod
from pyotp import TOTP
from datetime import timedelta
from pyotp import TOTP, HOTP
from typing import Any, Dict, Optional, Tuple

from trench.command.create_otp import create_otp_command
from trench.command.create_otp import create_totp_command, create_hotp_command
from trench.exceptions import MissingConfigurationError
from trench.models import MFAMethod
from trench.responses import DispatchResponse
Expand Down Expand Up @@ -76,11 +78,42 @@ def validate_code(self, code: str) -> bool:
return self._get_otp().verify(otp=code)

def _get_otp(self) -> TOTP:
return create_otp_command(
return create_totp_command(
secret=self._mfa_method.secret, interval=self._get_valid_window()
)

def _get_valid_window(self) -> int:
return self._config.get(
VALIDITY_PERIOD, trench_settings.DEFAULT_VALIDITY_PERIOD
)


class AbstractHotpMessageDispatcher(AbstractMessageDispatcher):
def create_code(self) -> str:
self._mfa_method.counter += 1
self._mfa_method.code_generated_at = timezone.now()
self._mfa_method.save()
return self._get_otp().at(self._mfa_method.counter)

def validate_code(self, code: str) -> bool:
if not self._mfa_method.code_generated_at:
return False

is_valid = self._get_otp().verify(otp=code, counter=self._mfa_method.counter)
if not is_valid:
return False

min_time = self._mfa_method.code_generated_at
max_time = self._mfa_method.code_generated_at + timedelta(
seconds=self._get_valid_window()
)
now = timezone.now()
if now < min_time or now > max_time:
return False

self._mfa_method.code_generated_at = None
self._mfa_method.save()
return True

def _get_otp(self) -> HOTP:
return create_hotp_command(secret=self._mfa_method.secret)
11 changes: 10 additions & 1 deletion trench/backends/basic_mail.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
import logging
from smtplib import SMTPException

from trench.backends.base import AbstractMessageDispatcher
from trench.backends.base import (
AbstractMessageDispatcher,
AbstractHotpMessageDispatcher,
)
from trench.responses import (
DispatchResponse,
FailedDispatchResponse,
Expand Down Expand Up @@ -39,3 +42,9 @@ def dispatch_message(self) -> DispatchResponse:
except ConnectionRefusedError as cause: # pragma: nocover
logging.error(cause, exc_info=True) # pragma: nocover
return FailedDispatchResponse(details=str(cause)) # pragma: nocover


class SendMailHotpMessageDispatcher(
AbstractHotpMessageDispatcher, SendMailMessageDispatcher
):
pass
11 changes: 10 additions & 1 deletion trench/backends/sms_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from smsapi.client import SmsApiPlClient
from smsapi.exception import SmsApiException

from trench.backends.base import AbstractMessageDispatcher
from trench.backends.base import (
AbstractMessageDispatcher,
AbstractHotpMessageDispatcher,
)
from trench.responses import (
DispatchResponse,
FailedDispatchResponse,
Expand All @@ -31,3 +34,9 @@ def dispatch_message(self) -> DispatchResponse:
except SmsApiException as cause:
logging.error(cause, exc_info=True)
return FailedDispatchResponse(details=cause.message)


class SMSAPIHotpMessageDispatcher(
AbstractHotpMessageDispatcher, SMSAPIMessageDispatcher
):
pass
11 changes: 10 additions & 1 deletion trench/backends/twilio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
from twilio.base.exceptions import TwilioRestException
from twilio.rest import Client

from trench.backends.base import AbstractMessageDispatcher
from trench.backends.base import (
AbstractMessageDispatcher,
AbstractHotpMessageDispatcher,
)
from trench.responses import (
DispatchResponse,
FailedDispatchResponse,
Expand All @@ -29,3 +32,9 @@ def dispatch_message(self) -> DispatchResponse:
except TwilioRestException as cause:
logging.error(cause, exc_info=True)
return FailedDispatchResponse(details=cause.msg)


class TwilioHotpMessageDispatcher(
AbstractHotpMessageDispatcher, TwilioMessageDispatcher
):
pass
15 changes: 12 additions & 3 deletions trench/command/create_otp.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from pyotp import TOTP
from pyotp import TOTP, HOTP


class CreateOTPCommand:
class CreateTOTPCommand:
@staticmethod
def execute(secret: str, interval: int) -> TOTP:
return TOTP(secret, interval=interval)


create_otp_command = CreateOTPCommand.execute
create_totp_command = CreateTOTPCommand.execute


class CreateHOTPCommand:
@staticmethod
def execute(secret: str) -> HOTP:
return HOTP(secret)


create_hotp_command = CreateHOTPCommand.execute
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 4.1.5 on 2023-03-17 22:56

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
("trench", "0005_remove_mfamethod_primary_is_active_and_more"),
]

operations = [
migrations.AddField(
model_name="mfamethod",
name="code_generated_at",
field=models.DateTimeField(
blank=True, null=True, verbose_name="code generated at"
),
),
migrations.AddField(
model_name="mfamethod",
name="counter",
field=models.PositiveIntegerField(default=0, verbose_name="counter"),
),
]
6 changes: 5 additions & 1 deletion trench/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from django.conf import settings
from django.db.models import (
CASCADE,
BooleanField,
CASCADE,
CharField,
CheckConstraint,
DateTimeField,
ForeignKey,
Manager,
Model,
PositiveIntegerField,
Q,
QuerySet,
TextField,
Expand Down Expand Up @@ -70,6 +72,8 @@ class MFAMethod(Model):
)
name = CharField(_("name"), max_length=255)
secret = CharField(_("secret"), max_length=255)
counter = PositiveIntegerField(_("counter"), default=0)
code_generated_at = DateTimeField(_("code generated at"), blank=True, null=True)
is_primary = BooleanField(_("is primary"), default=False)
is_active = BooleanField(_("is active"), default=False)
_backup_codes = TextField(_("backup codes"), blank=True)
Expand Down