diff --git a/docs/source/conf.py b/docs/source/conf.py index c5f81bdd5..508bf4426 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -214,7 +214,7 @@ def _get_commit_sha() -> str: 'torch': ('https://pytorch.org/docs/stable/', None), 'torchmetrics': ('https://torchmetrics.readthedocs.io/en/latest/', None), 'torchvision': ('https://pytorch.org/vision/stable/', None), - 'transformers': ('https://huggingface.co/docs/transformers/master/en/', None), + 'transformers': ('https://huggingface.co/docs/transformers/main/en/', None), } nitpicky = False # warn on broken links @@ -423,7 +423,7 @@ def _generate_rst_files_for_modules() -> None: # avoid duplicate entries in docs. We add torch's _LRScheduler to # types, so we get a ``WARNING: duplicate object description`` if we # don't exclude it - exclude_members = [torch.optim.lr_scheduler._LRScheduler] + exclude_members = [torch.optim.lr_scheduler.LRScheduler] if module is not streaming: exclude_members += streaming_imported_types diff --git a/setup.py b/setup.py index ff8420800..07d01b937 100644 --- a/setup.py +++ b/setup.py @@ -49,10 +49,10 @@ 'numpy>=1.21.5,<2.2.0', 'paramiko>=2.11.0,<5', 'python-snappy>=0.6.1,<1', - 'torch>=1.10,<3', - 'torchvision>=0.10', + 'torch>=2.1,<3', + 'torchvision>=0.16', 'tqdm>=4.64.0,<5', - 'transformers>=4.21.3,<5', + 'transformers>=5.0.0,<6', 'xxhash>=3.0.0,<4', 'zstd>=1.5.2.5,<2', 'oci>=2.88,<3', diff --git a/streaming/base/dataloader.py b/streaming/base/dataloader.py index 9487c5e2b..2465c1e5d 100644 --- a/streaming/base/dataloader.py +++ b/streaming/base/dataloader.py @@ -7,8 +7,7 @@ from torch import Tensor from torch.utils.data import DataLoader -from transformers.feature_extraction_utils import BatchFeature -from transformers.tokenization_utils_base import BatchEncoding +from transformers import BatchEncoding, BatchFeature from streaming.base.dataset import StreamingDataset from streaming.base.world import World diff --git a/streaming/text/c4.py b/streaming/text/c4.py index bfde0c787..5ca94e8ae 100644 --- a/streaming/text/c4.py +++ b/streaming/text/c4.py @@ -9,7 +9,7 @@ from typing import Any, Optional -from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers import AutoTokenizer from streaming.base import StreamingDataset diff --git a/streaming/text/pile.py b/streaming/text/pile.py index 59379b602..1d27c520b 100644 --- a/streaming/text/pile.py +++ b/streaming/text/pile.py @@ -9,7 +9,7 @@ from typing import Any, Optional -from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers import AutoTokenizer from streaming.base import StreamingDataset