Skip to content

Commit a4fe144

Browse files
Support function edit sdk
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
1 parent f01620e commit a4fe144

File tree

6 files changed

+388
-10
lines changed

6 files changed

+388
-10
lines changed

examples/function_edit.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from pymilvus import (
2+
MilvusClient,
3+
Function, DataType, FunctionType,
4+
)
5+
6+
collection_name = "text_embedding"
7+
8+
milvus_client = MilvusClient("http://localhost:19530")
9+
10+
has_collection = milvus_client.has_collection(collection_name, timeout=5)
11+
if has_collection:
12+
milvus_client.drop_collection(collection_name)
13+
14+
schema = milvus_client.create_schema()
15+
schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
16+
schema.add_field("document", DataType.VARCHAR, max_length=9000)
17+
schema.add_field("dense", DataType.FLOAT_VECTOR, dim=1536)
18+
19+
text_embedding_function = Function(
20+
name="openai",
21+
function_type=FunctionType.TEXTEMBEDDING,
22+
input_field_names=["document"],
23+
output_field_names="dense",
24+
params={
25+
"provider": "openai",
26+
"model_name": "text-embedding-3-small",
27+
}
28+
)
29+
30+
schema.add_function(text_embedding_function)
31+
32+
index_params = milvus_client.prepare_index_params()
33+
index_params.add_index(
34+
field_name="dense",
35+
index_name="dense_index",
36+
index_type="AUTOINDEX",
37+
metric_type="IP",
38+
)
39+
40+
ret = milvus_client.create_collection(collection_name, schema=schema, index_params=index_params, consistency_level="Strong")
41+
42+
ret = milvus_client.describe_collection(collection_name)
43+
print(ret["functions"][0])
44+
45+
text_embedding_function.params["user"] = "user123"
46+
47+
milvus_client.alter_collection_function(collection_name, "openai", text_embedding_function)
48+
49+
ret = milvus_client.describe_collection(collection_name)
50+
print(ret["functions"][0])
51+
52+
milvus_client.drop_collection_function(collection_name, "openai")
53+
54+
ret = milvus_client.describe_collection(collection_name)
55+
print(ret["functions"])
56+
57+
text_embedding_function.params["user"] = "user1234"
58+
59+
milvus_client.add_collection_function(collection_name, text_embedding_function)
60+
61+
ret = milvus_client.describe_collection(collection_name)
62+
print(ret["functions"][0])

pymilvus/client/async_grpc_handler.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,60 @@ async def add_collection_field(
13061306
)
13071307
check_status(status)
13081308

1309+
@retry_on_rpc_failure()
1310+
async def drop_collection_function(
1311+
self,
1312+
collection_name: str,
1313+
function_name: str,
1314+
timeout: Optional[float] = None,
1315+
**kwargs,
1316+
):
1317+
await self.ensure_channel_ready()
1318+
check_pass_param(collection_name=collection_name, timeout=timeout)
1319+
request = Prepare.drop_collection_function_request(collection_name, function_name)
1320+
1321+
status = await self._async_stub.DropCollectionFunction(
1322+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1323+
)
1324+
check_status(status)
1325+
1326+
@retry_on_rpc_failure()
1327+
async def add_collection_function(
1328+
self,
1329+
collection_name: str,
1330+
function: Function,
1331+
timeout: Optional[float] = None,
1332+
**kwargs,
1333+
):
1334+
await self.ensure_channel_ready()
1335+
check_pass_param(collection_name=collection_name, timeout=timeout)
1336+
request = Prepare.add_collection_function_request(collection_name, function)
1337+
1338+
status = await self._async_stub.AddCollectionFunction(
1339+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1340+
)
1341+
check_status(status)
1342+
1343+
@retry_on_rpc_failure()
1344+
async def alter_collection_function(
1345+
self,
1346+
collection_name: str,
1347+
function_name: str,
1348+
function: Function,
1349+
timeout: Optional[float] = None,
1350+
**kwargs,
1351+
):
1352+
await self.ensure_channel_ready()
1353+
check_pass_param(collection_name=collection_name, timeout=timeout)
1354+
request = Prepare.alter_collection_function_request(
1355+
collection_name, function_name, function
1356+
)
1357+
1358+
status = await self._async_stub.AlterCollectionFunction(
1359+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
1360+
)
1361+
check_status(status)
1362+
13091363
@retry_on_rpc_failure()
13101364
async def list_indexes(self, collection_name: str, timeout: Optional[float] = None, **kwargs):
13111365
await self.ensure_channel_ready()

pymilvus/client/grpc_handler.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,57 @@ def add_collection_field(
408408
)
409409
check_status(status)
410410

411+
@retry_on_rpc_failure()
412+
def drop_collection_function(
413+
self,
414+
collection_name: str,
415+
function_name: str,
416+
timeout: Optional[float] = None,
417+
**kwargs,
418+
):
419+
check_pass_param(collection_name=collection_name, timeout=timeout)
420+
request = Prepare.drop_collection_function_request(collection_name, function_name)
421+
422+
status = self._stub.DropCollectionFunction(
423+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
424+
)
425+
check_status(status)
426+
427+
@retry_on_rpc_failure()
428+
def add_collection_function(
429+
self,
430+
collection_name: str,
431+
function: Function,
432+
timeout: Optional[float] = None,
433+
**kwargs,
434+
):
435+
check_pass_param(collection_name=collection_name, timeout=timeout)
436+
request = Prepare.add_collection_function_request(collection_name, function)
437+
438+
status = self._stub.AddCollectionFunction(
439+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
440+
)
441+
check_status(status)
442+
443+
@retry_on_rpc_failure()
444+
def alter_collection_function(
445+
self,
446+
collection_name: str,
447+
function_name: str,
448+
function: Function,
449+
timeout: Optional[float] = None,
450+
**kwargs,
451+
):
452+
check_pass_param(collection_name=collection_name, timeout=timeout)
453+
request = Prepare.alter_collection_function_request(
454+
collection_name, function_name, function
455+
)
456+
457+
status = self._stub.AlterCollectionFunction(
458+
request, timeout=timeout, metadata=_api_level_md(**kwargs)
459+
)
460+
check_status(status)
461+
411462
@retry_on_rpc_failure()
412463
def alter_collection_properties(
413464
self, collection_name: str, properties: List, timeout: Optional[float] = None, **kwargs

pymilvus/client/prepare.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,7 @@ def get_schema_from_collection_schema(
242242
schema.struct_array_fields.append(struct_schema)
243243

244244
for f in fields.functions:
245-
function_schema = schema_types.FunctionSchema(
246-
name=f.name,
247-
description=f.description,
248-
type=f.type,
249-
input_field_names=f.input_field_names,
250-
output_field_names=f.output_field_names,
251-
)
252-
for k, v in f.params.items():
253-
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
254-
function_schema.params.append(kv_pair)
245+
function_schema = cls.convert_function_to_function_schema(f)
255246
schema.functions.append(function_schema)
256247

257248
return schema
@@ -364,6 +355,34 @@ def get_schema(
364355
def drop_collection_request(cls, collection_name: str) -> milvus_types.DropCollectionRequest:
365356
return milvus_types.DropCollectionRequest(collection_name=collection_name)
366357

358+
@classmethod
359+
def drop_collection_function_request(
360+
cls, collection_name: str, function_name: str
361+
) -> milvus_types.DropCollectionFunctionRequest:
362+
return milvus_types.DropCollectionFunctionRequest(
363+
collection_name=collection_name, function_name=function_name
364+
)
365+
366+
@classmethod
367+
def add_collection_function_request(
368+
cls, collection_name: str, f: Function
369+
) -> milvus_types.AddCollectionFunctionRequest:
370+
function_schema = cls.convert_function_to_function_schema(f)
371+
return milvus_types.AddCollectionFunctionRequest(
372+
collection_name=collection_name, functionSchema=function_schema
373+
)
374+
375+
@classmethod
376+
def alter_collection_function_request(
377+
cls, collection_name: str, function_name: str, f: Function
378+
) -> milvus_types.AlterCollectionFunctionRequest:
379+
function_schema = cls.convert_function_to_function_schema(f)
380+
return milvus_types.AlterCollectionFunctionRequest(
381+
collection_name=collection_name,
382+
function_name=function_name,
383+
functionSchema=function_schema,
384+
)
385+
367386
@classmethod
368387
def add_collection_field_request(
369388
cls,
@@ -2450,3 +2469,17 @@ def update_replicate_configuration_request(
24502469
return milvus_types.UpdateReplicateConfigurationRequest(
24512470
replicate_configuration=replicate_configuration
24522471
)
2472+
2473+
@staticmethod
2474+
def convert_function_to_function_schema(f: Function) -> schema_types.FunctionSchema:
2475+
function_schema = schema_types.FunctionSchema(
2476+
name=f.name,
2477+
description=f.description,
2478+
type=f.type,
2479+
input_field_names=f.input_field_names,
2480+
output_field_names=f.output_field_names,
2481+
)
2482+
for k, v in f.params.items():
2483+
kv_pair = common_types.KeyValuePair(key=str(k), value=str(v))
2484+
function_schema.params.append(kv_pair)
2485+
return function_schema

pymilvus/milvus_client/async_milvus_client.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,105 @@ async def add_collection_field(
702702
**kwargs,
703703
)
704704

705+
706+
async def add_collection_function(
707+
self, collection_name: str, function: Function, timeout: Optional[float] = None, **kwargs
708+
):
709+
"""Add a new function to the collection.
710+
711+
Args:
712+
collection_name(``string``): The name of collection.
713+
function(``Function``): The function schema.
714+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
715+
If timeout is set to None, the client keeps waiting until the server
716+
responds or an error occurs.
717+
**kwargs (``dict``): Optional field params
718+
719+
Raises:
720+
MilvusException: If anything goes wrong
721+
"""
722+
conn = self._get_connection()
723+
await conn.add_collection_function(
724+
collection_name,
725+
function,
726+
timeout=timeout,
727+
**kwargs,
728+
)
729+
730+
async def alter_collection_function(
731+
self,
732+
collection_name: str,
733+
function_name: str,
734+
function: Function,
735+
timeout: Optional[float] = None,
736+
**kwargs,
737+
):
738+
"""Alter a function in the collection.
739+
740+
Args:
741+
collection_name(``string``): The name of collection.
742+
function_name(``string``): The function name that needs to be modified
743+
function(``Function``): The function schema.
744+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
745+
If timeout is set to None, the client keeps waiting until the server
746+
responds or an error occurs.
747+
**kwargs (``dict``): Optional field params
748+
749+
Raises:
750+
MilvusException: If anything goes wrong
751+
"""
752+
conn = self._get_connection()
753+
await conn.alter_collection_function(
754+
collection_name,
755+
function_name,
756+
function,
757+
timeout=timeout,
758+
**kwargs,
759+
)
760+
761+
async def drop_collection_function(
762+
self, collection_name: str, function_name: str, timeout: Optional[float] = None, **kwargs
763+
):
764+
"""Drop a function from the collection.
765+
766+
Args:
767+
collection_name(``string``): The name of collection.
768+
function_name(``string``): The function name that needs to be dropped
769+
timeout (``float``, optional): A duration of time in seconds to allow for the RPC.
770+
If timeout is set to None, the client keeps waiting until the server
771+
responds or an error occurs.
772+
**kwargs (``dict``): Optional field params
773+
774+
Raises:
775+
MilvusException: If anything goes wrong
776+
"""
777+
conn = self._get_connection()
778+
await conn.drop_collection_function(
779+
collection_name,
780+
function_name,
781+
timeout=timeout,
782+
**kwargs,
783+
)
784+
785+
@classmethod
786+
def create_schema(cls, **kwargs):
787+
kwargs["check_fields"] = False # do not check fields for now
788+
return CollectionSchema([], **kwargs)
789+
790+
@classmethod
791+
def create_field_schema(
792+
cls, name: str, data_type: DataType, desc: str = "", **kwargs
793+
) -> FieldSchema:
794+
return FieldSchema(name, data_type, desc, **kwargs)
795+
796+
@classmethod
797+
def prepare_index_params(cls, field_name: str = "", **kwargs) -> IndexParams:
798+
index_params = IndexParams()
799+
if field_name:
800+
validate_param("field_name", field_name, str)
801+
index_params.add_index(field_name, **kwargs)
802+
return index_params
803+
705804
async def close(self):
706805
await connections.async_remove_connection(self._using)
707806

0 commit comments

Comments
 (0)