11import abc
22import dataclasses
3- import decimal
43import json
54import logging
65import operator
4544from .encoders import jsonable_encoder
4645from .render_tree import render_tree
4746from .token_escaper import TokenEscaper
47+ from .types import Coordinates , CoordinateType , GeoFilter
4848
4949
5050model_registry = {}
@@ -405,7 +405,6 @@ class RediSearchFieldTypes(Enum):
405405 GEO = "GEO"
406406
407407
408- # TODO: How to handle Geo fields?
409408DEFAULT_PAGE_SIZE = 1000
410409
411410
@@ -535,8 +534,12 @@ def validate_sort_fields(self, sort_fields: List[str]):
535534 def resolve_field_type (field : "FieldInfo" , op : Operators ) -> RediSearchFieldTypes :
536535 field_info : Union [FieldInfo , PydanticFieldInfo ] = field
537536
537+ typ = get_outer_type (field_info )
538+
538539 if getattr (field_info , "primary_key" , None ) is True :
539540 return RediSearchFieldTypes .TAG
541+ elif typ in [CoordinateType , Coordinates ]:
542+ return RediSearchFieldTypes .GEO
540543 elif op is Operators .LIKE :
541544 fts = getattr (field_info , "full_text_search" , None )
542545 if fts is not True : # Could be PydanticUndefined
@@ -552,7 +555,6 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
552555 if not isinstance (field_type , type ):
553556 field_type = field_type .__origin__
554557
555- # TODO: GEO fields
556558 container_type = get_origin (field_type )
557559
558560 if is_supported_container_type (container_type ):
@@ -726,6 +728,15 @@ def resolve_value(
726728 field_name = field_name , expanded_value = expanded_value
727729 )
728730
731+ elif field_type is RediSearchFieldTypes .GEO :
732+ if not isinstance (value , GeoFilter ):
733+ raise QuerySyntaxError (
734+ "You can only use a GeoFilter object with a GEO field."
735+ )
736+
737+ if op is Operators .EQ :
738+ result += f"@{ field_name } :[{ value } ]"
739+
729740 return result
730741
731742 def resolve_redisearch_pagination (self ):
@@ -1804,6 +1815,8 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
18041815 schema = cls .schema_for_type (name , embedded_cls , field_info )
18051816 elif typ is bool :
18061817 schema = f"{ name } TAG"
1818+ elif typ in [CoordinateType , Coordinates ]:
1819+ schema = f"{ name } GEO"
18071820 elif is_numeric_type (typ ):
18081821 vector_options : Optional [VectorFieldOptions ] = getattr (
18091822 field_info , "vector_options" , None
@@ -2107,7 +2120,6 @@ def schema_for_type(
21072120 else typ
21082121 )
21092122
2110- # TODO: GEO field
21112123 if is_vector and vector_options :
21122124 schema = f"{ path } AS { index_field_name } { vector_options .schema } "
21132125 elif parent_is_container_type or parent_is_model_in_container :
@@ -2128,6 +2140,8 @@ def schema_for_type(
21282140 schema += " CASESENSITIVE"
21292141 elif typ is bool :
21302142 schema = f"{ path } AS { index_field_name } TAG"
2143+ elif typ in [CoordinateType , Coordinates ]:
2144+ schema = f"{ path } AS { index_field_name } GEO"
21312145 elif is_numeric_type (typ ):
21322146 schema = f"{ path } AS { index_field_name } NUMERIC"
21332147 elif issubclass (typ , str ):
0 commit comments