Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions testproject/tests/test_second_step_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
HTTP_204_NO_CONTENT,
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_422_UNPROCESSABLE_ENTITY,
)
from rest_framework.test import APIClient
from time import sleep
Expand Down Expand Up @@ -518,6 +519,22 @@ def test_request_code_for_not_inactive_mfa_method(active_user_with_email_otp):
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.data.get("error") == "Requested MFA method does not exist."

@flaky
@pytest.mark.django_db
def test_request_code_for_application_mfa_method(active_user_with_application_otp):
client = TrenchAPIClient()
mfa_method = active_user_with_application_otp.mfa_methods.first()
client.authenticate_multi_factor(
mfa_method=mfa_method, user=active_user_with_application_otp
)
response = client.post(
path="/auth/code/request/",
data={"method": "app"},
format="json",
)
assert response.status_code == HTTP_422_UNPROCESSABLE_ENTITY
assert response.data.get("details") == "Get code from OTP application."


@flaky
@pytest.mark.django_db
Expand Down
7 changes: 6 additions & 1 deletion trench/backends/application.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Optional

from django.contrib.auth import get_user_model
from django.contrib.auth.models import AbstractUser

import logging

from trench.backends.base import AbstractMessageDispatcher
from trench.exceptions import GetCodeFromApplicationException
from trench.responses import (
DispatchResponse,
FailedDispatchResponse,
Expand All @@ -16,8 +19,10 @@


class ApplicationMessageDispatcher(AbstractMessageDispatcher):
def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
try:
if url_name and url_name == "mfa-request-code":
raise GetCodeFromApplicationException()
qr_link = self._create_qr_link(self._mfa_method.user)
return SuccessfulDispatchResponse(details=qr_link)
except Exception as cause: # pragma: nocover
Expand Down
6 changes: 4 additions & 2 deletions trench/backends/aws.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Optional

from django.utils.translation import gettext_lazy as _

import logging
import boto3
import botocore.exceptions

from trench.backends.base import AbstractMessageDispatcher
from trench.responses import (
Expand All @@ -13,11 +14,12 @@
from trench.settings import AWS_ACCESS_KEY, AWS_SECRET_KEY, AWS_REGION
from botocore.exceptions import ClientError, EndpointConnectionError


class AWSMessageDispatcher(AbstractMessageDispatcher):
_SMS_BODY = _("Your verification code is: ")
_SUCCESS_DETAILS = _("SMS message with MFA code has been sent.")

def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't care about the args here maybe this would be better:

Suggested change
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
def dispatch_message(self, *args, **kwargs) -> DispatchResponse:

try:
client = boto3.client(
"sns",
Expand Down
2 changes: 1 addition & 1 deletion trench/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _get_innermost_object(obj: Model, dotted_path: Optional[str] = None) -> Mode
return obj # pragma: no cover

@abstractmethod
def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
raise NotImplementedError # pragma: no cover

def create_code(self) -> str:
Expand Down
4 changes: 3 additions & 1 deletion trench/backends/basic_mail.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from django.conf import settings
from django.core.mail import send_mail
from django.template.loader import get_template
Expand All @@ -19,7 +21,7 @@ class SendMailMessageDispatcher(AbstractMessageDispatcher):
_KEY_MESSAGE = "message"
_SUCCESS_DETAILS = _("Email message with MFA code has been sent.")

def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
context = {"code": self.create_code()}
email_plain_template = self._config[EMAIL_PLAIN_TEMPLATE]
email_html_template = self._config[EMAIL_HTML_TEMPLATE]
Expand Down
4 changes: 3 additions & 1 deletion trench/backends/sms_api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from django.utils.translation import gettext_lazy as _

import logging
Expand All @@ -17,7 +19,7 @@ class SMSAPIMessageDispatcher(AbstractMessageDispatcher):
_SMS_BODY = _("Your verification code is: ")
_SUCCESS_DETAILS = _("SMS message with MFA code has been sent.")

def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
try:
client = SmsApiPlClient(access_token=self._config.get(SMSAPI_ACCESS_TOKEN))
from_number = self._config.get(SMSAPI_FROM_NUMBER)
Expand Down
4 changes: 3 additions & 1 deletion trench/backends/twilio.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from django.utils.translation import gettext_lazy as _

import logging
Expand All @@ -17,7 +19,7 @@ class TwilioMessageDispatcher(AbstractMessageDispatcher):
_SMS_BODY = _("Your verification code is: ")
_SUCCESS_DETAILS = _("SMS message with MFA code has been sent.")

def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
try:
client = Client()
client.messages.create(
Expand Down
4 changes: 3 additions & 1 deletion trench/backends/yubikey.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from django.utils.translation import gettext_lazy as _

import logging
Expand All @@ -11,7 +13,7 @@


class YubiKeyMessageDispatcher(AbstractMessageDispatcher):
def dispatch_message(self) -> DispatchResponse:
def dispatch_message(self, url_name: Optional[str] = None) -> DispatchResponse:
return SuccessfulDispatchResponse(details=_("Generate code using YubiKey"))

def confirm_activation(self, code: str) -> None:
Expand Down
8 changes: 8 additions & 0 deletions trench/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ def __init__(self) -> None:
super().__init__(detail=_("OTP code not provided."), code="otp_code_missing")


class GetCodeFromApplicationException(MFAValidationError):
def __init__(self) -> None:
super().__init__(
detail=_("Get code from OTP application."),
code="get_code_from_application"
)


class MFAMethodDoesNotExistError(MFAValidationError):
def __init__(self) -> None:
super().__init__(
Expand Down
6 changes: 6 additions & 0 deletions trench/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def get_by_name(self, user_id: Any, name: str) -> "MFAMethod":
except self.model.DoesNotExist:
raise MFAMethodDoesNotExistError()

def get_active_by_name(self, user_id: Any, name: str) -> "MFAMethod":
try:
return self.get(user_id=user_id, name=name, is_active=True)
except self.model.DoesNotExist:
raise MFAMethodDoesNotExistError()

def get_primary_active(self, user_id: Any) -> "MFAMethod":
try:
return self.get(user_id=user_id, is_primary=True, is_active=True)
Expand Down
16 changes: 12 additions & 4 deletions trench/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from trench.settings import SOURCE_FIELD, trench_settings
from trench.utils import available_method_choices, get_mfa_model, user_token_generator


User: AbstractUser = get_user_model()


Expand All @@ -70,6 +69,7 @@ def post(self, request: Request) -> Response:
)
except MFAValidationError as cause:
return ErrorResponse(error=cause)

try:
mfa_model = get_mfa_model()
mfa_method = mfa_model.objects.get_primary_active(user_id=user.id)
Expand Down Expand Up @@ -110,7 +110,10 @@ def post(request: Request, method: str) -> Response:
user = request.user
try:
if source_field is not None and not hasattr(user, source_field):
raise MFASourceFieldDoesNotExistError(source_field, user.__class__.__name__)
raise MFASourceFieldDoesNotExistError(
source_field,
user.__class__.__name__
)

mfa = create_mfa_method_command(
user_id=user.id,
Expand Down Expand Up @@ -224,8 +227,13 @@ def post(request: Request) -> Response:
method = mfa_model.objects.get_primary_active_name(
user_id=request.user.id
)
mfa = mfa_model.objects.get_by_name(user_id=request.user.id, name=method)
return get_mfa_handler(mfa_method=mfa).dispatch_message()
mfa = mfa_model.objects.get_active_by_name(
user_id=request.user.id,
name=method
)
return get_mfa_handler(mfa_method=mfa).dispatch_message(
url_name=request.resolver_match.url_name
)
except MFAValidationError as cause:
return ErrorResponse(error=cause)

Expand Down