22 changes: 22 additions & 0 deletions dinov2/layers/_utils.py
@@ -0,0 +1,22 @@
+import os
+import logging
+
+logger = logging.getLogger("dinov2")
+
+
+def _xformers_is_available(layer):
+
+    XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+    xformers = None
+    try:
+        if XFORMERS_ENABLED:
+            import xformers
+
+            logger.info(f"xFormers is available ({layer})")
+        else:
+            logger.warning(f"xFormers is disabled ({layer})")
+            raise ImportError
+    except ImportError:
+        logger.warning(f"xFormers is not available ({layer})")
+
+    return xformers is not None
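Two details make the helper work: the function-local `import xformers` rebinds the local name initialized to `None` above, so `xformers is not None` is True exactly when the import succeeded, and `XFORMERS_DISABLED` is read each time the helper is called. A minimal usage sketch; the layer label "MyLayer" is illustrative, while the environment variable and import path come from the diff above:

import os

# Opt out before the helper runs; XFORMERS_DISABLED is read at call time.
os.environ["XFORMERS_DISABLED"] = "1"

from dinov2.layers._utils import _xformers_is_available

# Logs "xFormers is disabled (MyLayer)" and returns False.
XFORMERS_AVAILABLE = _xformers_is_available("MyLayer")

if XFORMERS_AVAILABLE:
    from xformers.ops import memory_efficient_attention  # only reached when xformers imports cleanly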
20 changes: 5 additions & 15 deletions dinov2/layers/attention.py
@@ -8,29 +8,19 @@
 # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

 import logging
-import os
-import warnings

 from torch import Tensor
 from torch import nn

+from ._utils import _xformers_is_available

 logger = logging.getLogger("dinov2")

-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import memory_efficient_attention, unbind
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Attention)")
-    else:
-        warnings.warn("xFormers is disabled (Attention)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-    warnings.warn("xFormers is not available (Attention)")
+XFORMERS_AVAILABLE = _xformers_is_available("Attention")
+
+if XFORMERS_AVAILABLE:
+    from xformers.ops import memory_efficient_attention, unbind


 class Attention(nn.Module):
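Since the guarded import leaves `memory_efficient_attention` and `unbind` undefined when the flag is False, downstream code still has to branch on `XFORMERS_AVAILABLE`. A minimal sketch of that pattern, assuming the module-level names from the diff above; the class body is illustrative, not this file's actual attention implementation:

import torch
from torch import nn

class MemEffAttentionSketch(nn.Module):
    # Illustrative fallback: use plain PyTorch attention when xFormers is absent.
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        if XFORMERS_AVAILABLE:
            q, k, v = unbind(qkv, 2)  # each (B, N, heads, head_dim)
            x = memory_efficient_attention(q, k, v).reshape(B, N, C)
        else:
            q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B, heads, N, head_dim)
            x = nn.functional.scaled_dot_product_attention(q, k, v)
            x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)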
20 changes: 4 additions & 16 deletions dinov2/layers/block.py
@@ -8,9 +8,7 @@
 # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

 import logging
-import os
 from typing import Callable, List, Any, Tuple, Dict
-import warnings

 import torch
 from torch import nn, Tensor
@@ -19,25 +17,15 @@
 from .drop_path import DropPath
 from .layer_scale import LayerScale
 from .mlp import Mlp
+from ._utils import _xformers_is_available


 logger = logging.getLogger("dinov2")

-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import fmha, scaled_index_add, index_select_cat
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (Block)")
-    else:
-        warnings.warn("xFormers is disabled (Block)")
-        raise ImportError
-except ImportError:
-    XFORMERS_AVAILABLE = False
-
-    warnings.warn("xFormers is not available (Block)")
+XFORMERS_AVAILABLE = _xformers_is_available("Block")
+
+if XFORMERS_AVAILABLE:
+    from xformers.ops import fmha, scaled_index_add, index_select_cat


 class Block(nn.Module):
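Same pattern here: with the flag False, `fmha`, `scaled_index_add`, and `index_select_cat` are never bound, so the code paths that need them must check `XFORMERS_AVAILABLE` before touching them. A hedged sketch of such a guard; the function name and error message are illustrative:

from typing import List
from torch import Tensor

def forward_nested_sketch(blocks, x_list: List[Tensor]) -> List[Tensor]:
    # Illustrative guard: this path relies on the xFormers-only ops
    # imported above, so fail loudly when they were never bound.
    if not XFORMERS_AVAILABLE:
        raise AssertionError("xFormers is required for nested-tensor forward")
    for blk in blocks:
        x_list = blk(x_list)
    return x_list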
24 changes: 8 additions & 16 deletions dinov2/layers/swiglu_ffn.py
@@ -3,13 +3,16 @@
 # This source code is licensed under the Apache License, Version 2.0
 # found in the LICENSE file in the root directory of this source tree.

-import os
 from typing import Callable, Optional
-import warnings

 from torch import Tensor, nn
 import torch.nn.functional as F

+from ._utils import _xformers_is_available
+
+
+XFORMERS_AVAILABLE = _xformers_is_available("SwiGLU")


 class SwiGLUFFN(nn.Module):
     def __init__(
@@ -34,21 +37,10 @@ def forward(self, x: Tensor) -> Tensor:
         return self.w3(hidden)


-XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
-try:
-    if XFORMERS_ENABLED:
-        from xformers.ops import SwiGLU
-
-        XFORMERS_AVAILABLE = True
-        warnings.warn("xFormers is available (SwiGLU)")
-    else:
-        warnings.warn("xFormers is disabled (SwiGLU)")
-        raise ImportError
-except ImportError:
-    SwiGLU = SwiGLUFFN
-    XFORMERS_AVAILABLE = False
-
-    warnings.warn("xFormers is not available (SwiGLU)")
+if XFORMERS_AVAILABLE:
+    from xformers.ops import SwiGLU
+else:
+    SwiGLU = SwiGLUFFN


 class SwiGLUFFNFused(SwiGLU):
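Because the fallback rebinds `SwiGLU = SwiGLUFFN`, the `SwiGLUFFNFused(SwiGLU)` subclass above is well defined with or without xFormers, and callers see the same interface either way. A hypothetical smoke test, assuming the upstream constructor keywords (`in_features`, `hidden_features`) and that `out_features` defaults to `in_features`:

import torch

from dinov2.layers.swiglu_ffn import SwiGLUFFNFused

# Shape contract holds on both code paths: (batch, in_features) -> (batch, in_features).
ffn = SwiGLUFFNFused(in_features=384, hidden_features=1024)
out = ffn(torch.randn(2, 384))
assert out.shape == (2, 384)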