Skip to content

Commit 873e4e8

Browse files
committed
feat: Add type hinting throughout the project
1 parent e6cead9 commit 873e4e8

File tree

17 files changed

+287
-198
lines changed

17 files changed

+287
-198
lines changed

.github/workflows/tests_and_publish.yml

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ jobs:
8484
run: |
8585
pydocstyle --count django_opensearch_dsl/ tests/
8686
87+
mypy:
88+
runs-on: ubuntu-latest
89+
steps:
90+
- uses: actions/checkout@v4
91+
92+
- name: Setup Python
93+
uses: actions/setup-python@master
94+
with:
95+
python-version: '3.13'
96+
97+
- name: Install packages
98+
run: pip install -r requirements_dev.txt
99+
100+
- name: Mypy
101+
run: |
102+
mypy django_opensearch_dsl --disallow-untyped-def
103+
87104
bandit:
88105
runs-on: ubuntu-latest
89106
steps:
@@ -130,7 +147,6 @@ jobs:
130147
with:
131148
python-version: ${{ matrix.python-version }}
132149

133-
134150
- name: Run Opensearch in docker
135151
run: |
136152
docker compose up -d opensearch_test_${{ matrix.opensearch-version }}

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ __pycache__
44
coverage.xml
55
venv/
66

7+
# Mac
8+
.DS_Store
9+
710
# C extensions
811
*.so
912

bin/pre_commit.sh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ fi
5555
echo ""
5656

5757

58+
################################################################################
59+
# MYPY #
60+
################################################################################
61+
echo -n "${Cyan}Running mypy... $Color_Off"
62+
out=$(mypy django_opensearch_dsl --disallow-untyped-def)
63+
if [ "$?" -ne 0 ] ; then
64+
echo "${Red}Error !$Color_Off"
65+
echo -e "$out"
66+
EXIT_CODE=1
67+
else
68+
echo "${Green}Ok ✅ $Color_Off"
69+
fi
70+
echo ""
71+
72+
5873
################################################################################
5974
# BANDIT #
6075
################################################################################

django_opensearch_dsl/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@
66
__version__ = "0.7.0"
77

88

9-
def autodiscover():
9+
def autodiscover() -> None:
1010
"""Force the import of the `documents` modules of each `INSTALLED_APPS`."""
1111
autodiscover_modules("documents")

django_opensearch_dsl/apps.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1+
import json
2+
from typing import TYPE_CHECKING, Any
3+
14
from django.apps import AppConfig
25
from django.conf import settings
36
from django.utils.module_loading import import_string
47
from opensearchpy.connection.connections import connections
58

9+
if TYPE_CHECKING:
10+
from .signals import BaseSignalProcessor
11+
612

713
class DODConfig(AppConfig):
814
"""Django Opensearch DSL Appconfig."""
@@ -11,7 +17,7 @@ class DODConfig(AppConfig):
1117
verbose_name = "django-opensearch-dsl"
1218
signal_processor = None
1319

14-
def ready(self):
20+
def ready(self) -> None:
1521
"""Autodiscover documents and register signals."""
1622
self.module.autodiscover()
1723
connections.configure(**settings.OPENSEARCH_DSL)
@@ -21,40 +27,40 @@ def ready(self):
2127
self.signal_processor = self.signal_processor_class()(connections)
2228

2329
@classmethod
24-
def autosync_enabled(cls):
30+
def autosync_enabled(cls) -> bool:
2531
"""Return whether auto sync is enabled."""
2632
return getattr(settings, "OPENSEARCH_DSL_AUTOSYNC", True)
2733

2834
@classmethod
29-
def default_index_settings(cls):
35+
def default_index_settings(cls) -> dict[str, Any]:
3036
"""Return `OPENSEARCH_DSL_INDEX_SETTINGS`."""
3137
return getattr(settings, "OPENSEARCH_DSL_INDEX_SETTINGS", {})
3238

3339
@classmethod
34-
def auto_refresh_enabled(cls):
40+
def auto_refresh_enabled(cls) -> bool:
3541
"""Return whether auto refresh is enabled."""
3642
return getattr(settings, "OPENSEARCH_DSL_AUTO_REFRESH", False)
3743

3844
@classmethod
39-
def parallel_enabled(cls):
45+
def parallel_enabled(cls) -> bool:
4046
"""Return whether parallel operation is enabled."""
4147
return getattr(settings, "OPENSEARCH_DSL_PARALLEL", False)
4248

4349
@classmethod
44-
def default_queryset_pagination(cls):
50+
def default_queryset_pagination(cls) -> int:
4551
"""Return `OPENSEARCH_DSL_QUERYSET_PAGINATION`."""
4652
return getattr(settings, "OPENSEARCH_DSL_QUERYSET_PAGINATION", 4096)
4753

4854
@classmethod
49-
def signal_processor_class(cls):
55+
def signal_processor_class(cls) -> type["BaseSignalProcessor"]:
5056
"""Import and return the target of `OPENSEARCH_SIGNAL_PROCESSOR_CLASS`."""
5157
path = getattr(
5258
settings, "OPENSEARCH_DSL_SIGNAL_PROCESSOR", "django_opensearch_dsl.signals.RealTimeSignalProcessor"
5359
)
5460
return import_string(path)
5561

5662
@classmethod
57-
def signal_processor_serializer_class(cls):
63+
def signal_processor_serializer_class(cls) -> type[json.JSONEncoder]:
5864
"""Import and return the target of `OPENSEARCH_DSL_SIGNAL_PROCESSOR_SERIALIZER_CLASS`."""
5965
path = getattr(
6066
settings,
@@ -64,7 +70,7 @@ def signal_processor_serializer_class(cls):
6470
return import_string(path)
6571

6672
@classmethod
67-
def signal_processor_deserializer_class(cls):
73+
def signal_processor_deserializer_class(cls) -> type[json.JSONEncoder]:
6874
"""Import and return the target of `OPENSEARCH_DSL_SIGNAL_PROCESSOR_SERIALIZER_CLASS`."""
6975
path = getattr(
7076
settings,

django_opensearch_dsl/documents.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,23 @@
1-
import io
21
import sys
32
import time
43
from collections import deque
54
from functools import partial
6-
from typing import Iterable, Optional
5+
from typing import Any, Callable, Iterable, Optional, TextIO, Union
76

7+
import opensearchpy
88
from django.db import models
99
from django.db.models import Q, QuerySet
1010
from opensearchpy.helpers import bulk, parallel_bulk
1111
from opensearchpy.helpers.document import Document as DSLDocument
1212

1313
from . import fields
1414
from .apps import DODConfig
15+
from .enums import BulkAction, CommandAction
1516
from .exceptions import ModelFieldNotMappedError
16-
from .management.enums import OpensearchAction
1717
from .search import Search
1818
from .signals import post_index
1919

20-
model_field_class_to_field_class = {
20+
model_field_class_to_field_class: dict[type[models.Field], type[fields.DODField]] = {
2121
models.AutoField: fields.IntegerField,
2222
models.BigAutoField: fields.LongField,
2323
models.BigIntegerField: fields.LongField,
@@ -48,31 +48,26 @@
4848
class Document(DSLDocument):
4949
"""Allow the definition of Opensearch' index using Django `Model`."""
5050

51-
_prepared_fields = []
51+
_prepared_fields: list[tuple[str, fields.DODField, Callable[[models.Model], Any]]] = []
5252

53-
def __init__(self, related_instance_to_ignore=None, **kwargs):
53+
def __init__(self, related_instance_to_ignore: Any = None, **kwargs: Any) -> None:
5454
super(Document, self).__init__(**kwargs)
5555
# related instances to ignore is required to remove the instance
5656
# from related models on deletion.
5757
self._related_instance_to_ignore = related_instance_to_ignore
5858
self._prepared_fields = self.init_prepare()
5959

6060
@classmethod
61-
def search(cls, using=None, index=None):
62-
"""Return a `Search` object parametrized with the index' information."""
61+
def search(cls, using: str = None, index: str = None) -> opensearchpy.Search:
62+
"""Return a `Search` object parametrized with the index information."""
6363
return Search(
6464
using=cls._get_using(using),
6565
index=cls._default_index(index),
6666
doc_type=[cls],
6767
model=cls.django.model,
6868
)
6969

70-
def get_queryset(
71-
self,
72-
filter_: Optional[Q] = None,
73-
exclude: Optional[Q] = None,
74-
count: int = None,
75-
) -> QuerySet:
70+
def get_queryset(self, filter_: Optional[Q] = None, exclude: Optional[Q] = None, count: int = None) -> QuerySet:
7671
"""Return the queryset that should be indexed by this doc type."""
7772
qs = self.django.model.objects.all()
7873

@@ -85,7 +80,7 @@ def get_queryset(
8580

8681
return qs
8782

88-
def _eta(self, start, done, total): # pragma: no cover
83+
def _eta(self, start: float, done: int, total: int) -> str: # pragma: no cover
8984
if done == 0:
9085
return "~"
9186
eta = round((time.time() - start) / done * (total - done))
@@ -101,43 +96,44 @@ def get_indexing_queryset(
10196
filter_: Optional[Q] = None,
10297
exclude: Optional[Q] = None,
10398
count: int = None,
104-
action: OpensearchAction = OpensearchAction.INDEX,
105-
stdout: io.FileIO = sys.stdout,
99+
action: CommandAction = CommandAction.INDEX,
100+
stdout: TextIO = sys.stdout,
106101
) -> Iterable:
107102
"""Divide the queryset into chunks."""
108103
chunk_size = self.django.queryset_pagination
109104
qs = self.get_queryset(filter_=filter_, exclude=exclude, count=count)
110105
qs = qs.order_by("pk") if not qs.query.is_sliced else qs
111-
count = qs.count()
106+
total = qs.count()
112107
model = self.django.model.__name__
113108
action = action.present_participle.title()
114109

115110
i = 0
116111
done = 0
117112
start = time.time()
118113
if verbose:
119-
stdout.write(f"{action} {model}: 0% ({self._eta(start, done, count)})\r")
120-
while done < count:
114+
stdout.write(f"{action} {model}: 0% ({self._eta(start, done, total)})\r")
115+
while done < total:
121116
if verbose:
122-
stdout.write(f"{action} {model}: {round(i / count * 100)}% ({self._eta(start, done, count)})\r")
117+
stdout.write(f"{action} {model}: {round(i / total * 100)}% ({self._eta(start, done, total)})\r")
123118

124119
for obj in qs[i : i + chunk_size]:
125120
done += 1
126121
yield obj
127122

128-
i = min(i + chunk_size, count)
123+
i = min(i + chunk_size, total)
129124

130125
if verbose:
131-
stdout.write(f"{action} {count} {model}: OK \n")
126+
stdout.write(f"{action} {total} {model}: OK \n")
132127

133-
def init_prepare(self):
128+
def init_prepare(self) -> list[tuple[str, fields.DODField, Callable[[models.Model], Any]]]:
134129
"""Initialise the data model preparers once here.
135130
136131
Extracts the preparers from the model and generate a list of callables
137132
to avoid doing that work on every object instance over.
138133
"""
139-
index_fields = getattr(self, "_fields", {})
134+
index_fields: dict[str, fields.DODField] = getattr(self, "_fields", {})
140135
preparers = []
136+
fn: Callable[[models.Model], Any]
141137
for name, field in iter(index_fields.items()):
142138
if not isinstance(field, fields.DODField): # pragma: no cover
143139
continue
@@ -162,13 +158,13 @@ def init_prepare(self):
162158

163159
return preparers
164160

165-
def prepare(self, instance):
161+
def prepare(self, instance: models.Model) -> dict[str, Any]:
166162
"""Generate the opensearch's document from `instance` based on defined fields."""
167163
data = {name: prep_func(instance) for name, field, prep_func in self._prepared_fields}
168164
return data
169165

170166
@classmethod
171-
def to_field(cls, field_name, model_field):
167+
def to_field(cls, field_name: str, model_field: models.Field) -> fields.DODField:
172168
"""Return the opensearch field instance mapped to the model field class.
173169
174170
This is a good place to hook into if you have more complex
@@ -179,14 +175,16 @@ def to_field(cls, field_name, model_field):
179175
except KeyError: # pragma: no cover
180176
raise ModelFieldNotMappedError(f"Cannot convert model field {field_name} to an Opensearch field!")
181177

182-
def bulk(self, actions, using=None, **kwargs):
178+
def bulk(
179+
self, actions: Iterable[dict[str, Any]], using: str = None, **kwargs: Any
180+
) -> Union[tuple[int, int], tuple[int, list]]:
183181
"""Execute given actions in bulk."""
184182
response = bulk(client=self._get_connection(using), actions=actions, **kwargs)
185183
# send post index signal
186184
post_index.send(sender=self.__class__, instance=self, actions=actions, response=response)
187185
return response
188186

189-
def parallel_bulk(self, actions, using=None, **kwargs):
187+
def parallel_bulk(self, actions: Iterable[dict[str, Any]], using: str = None, **kwargs: Any) -> tuple[int, list]:
190188
"""Parallel version of `bulk`."""
191189
kwargs.setdefault("chunk_size", self.django.queryset_pagination)
192190
bulk_actions = parallel_bulk(client=self._get_connection(using), actions=actions, **kwargs)
@@ -199,7 +197,7 @@ def parallel_bulk(self, actions, using=None, **kwargs):
199197
return 1, []
200198

201199
@classmethod
202-
def generate_id(cls, object_instance):
200+
def generate_id(cls, object_instance: models.Model) -> Any:
203201
"""Generate the opensearch's _id from a Django `Model` instance.
204202
205203
The default behavior is to use the Django object's pk (id) as the
@@ -208,47 +206,52 @@ def generate_id(cls, object_instance):
208206
"""
209207
return object_instance.pk
210208

211-
def _prepare_action(self, object_instance, action):
209+
def _prepare_action(self, object_instance: models.Model, action: BulkAction) -> dict[str, Any]:
212210
return {
213211
"_op_type": action,
214212
"_index": self._index._name, # noqa
215213
"_id": self.generate_id(object_instance),
216214
"_source" if action != "update" else "doc": (self.prepare(object_instance) if action != "delete" else None),
217215
}
218216

219-
def _get_actions(self, object_list, action):
217+
def _get_actions(self, object_list: Iterable[models.Model], action: BulkAction) -> Iterable[dict[str, Any]]:
220218
for object_instance in object_list:
221219
if action == "delete" or self.should_index_object(object_instance):
222220
yield self._prepare_action(object_instance, action)
223221

224-
def _bulk(self, *args, parallel=False, using=None, **kwargs):
222+
def _bulk(
223+
self, actions: Iterable[dict[str, Any]], parallel: bool = False, using: str = None, **kwargs: Any
224+
) -> Union[tuple[int, int], tuple[int, list]]:
225225
"""Allow switching between normal and parallel bulk operation."""
226226
if parallel:
227-
return self.parallel_bulk(*args, using=using, **kwargs)
228-
return self.bulk(*args, using=using, **kwargs)
227+
return self.parallel_bulk(actions, using=using, **kwargs)
228+
return self.bulk(actions, using=using, **kwargs)
229229

230-
def should_index_object(self, obj):
230+
def should_index_object(self, object_instance: models.Model) -> bool:
231231
"""Whether given object should be indexed.
232232
233233
Overwriting this method and returning a boolean value should determine
234234
whether the object should be indexed.
235235
"""
236236
return True
237237

238-
def update(self, thing, action, *args, refresh=None, using=None, **kwargs): # noqa
238+
def update( # type: ignore[override] # noqa
239+
self,
240+
thing: Union[models.Model, Iterable[models.Model]],
241+
action: BulkAction,
242+
refresh: bool = None,
243+
parallel: bool = None,
244+
using: str = None,
245+
**kwargs: Any,
246+
) -> Union[tuple[int, int], tuple[int, list]]:
239247
"""Update document in OS for a model, iterable of models or queryset."""
240248
if refresh is None:
241249
refresh = getattr(self.Index, "auto_refresh", DODConfig.auto_refresh_enabled())
250+
if parallel is None:
251+
parallel = DODConfig.parallel_enabled()
242252

243-
if isinstance(thing, models.Model):
244-
object_list = [thing]
245-
else:
246-
object_list = thing
253+
object_list = [thing] if isinstance(thing, models.Model) else thing
247254

248255
return self._bulk(
249-
self._get_actions(object_list, action),
250-
*args,
251-
refresh=refresh,
252-
using=using,
253-
**kwargs,
256+
self._get_actions(object_list, action), parallel=parallel, refresh=refresh, using=using, **kwargs
254257
)

0 commit comments

Comments
 (0)