11import time
2+ import toml
23from typing import List , Optional
34import psycopg2
45from engine .base_client import BaseUploader
@@ -13,28 +14,45 @@ class PGVectorUploader(BaseUploader):
1314 vector_count : int = None
1415
1516 @classmethod
16- def init_client (cls , host , distance , vector_count , connection_params , upload_params ,
17- extra_columns_name : list , extra_columns_type : list ):
18- database , host , port , user , password = process_connection_params (connection_params , host )
19- cls .conn = psycopg2 .connect (database = database , user = user , password = password , host = host , port = port )
17+ def init_client (
18+ cls ,
19+ host ,
20+ distance ,
21+ vector_count ,
22+ connection_params ,
23+ upload_params ,
24+ extra_columns_name : list ,
25+ extra_columns_type : list ,
26+ ):
27+ database , host , port , user , password = process_connection_params (
28+ connection_params , host
29+ )
30+ cls .conn = psycopg2 .connect (
31+ database = database , user = user , password = password , host = host , port = port
32+ )
2033 cls .host = host
2134 cls .upload_params = upload_params
2235 cls .engine_type = upload_params .get ("engine_type" , "c" )
23- cls .distance = DISTANCE_MAPPING_CREATE [distance ] if cls .engine_type == "c" else DISTANCE_MAPPING_CREATE_RUST [
24- distance ]
36+ cls .distance = (
37+ DISTANCE_MAPPING_CREATE [distance ]
38+ if cls .engine_type == "c"
39+ else DISTANCE_MAPPING_CREATE_RUST [distance ]
40+ )
2541 cls .vector_count = vector_count
2642
2743 @classmethod
28- def upload_batch (cls , ids : List [int ], vectors : List [list ], metadata : List [Optional [dict ]]):
44+ def upload_batch (
45+ cls , ids : List [int ], vectors : List [list ], metadata : List [Optional [dict ]]
46+ ):
2947 if len (ids ) != len (vectors ):
3048 raise RuntimeError ("PGVector batch upload unhealthy" )
3149 # Getting the names of structured data columns based on the first meta information.
32- col_name_tuple = ('id' , ' vector' )
33- col_type_tuple = ('%s' , ' %s::real[]' )
50+ col_name_tuple = ("id" , " vector" )
51+ col_type_tuple = ("%s" , " %s::real[]" )
3452 if metadata [0 ] is not None :
3553 for col_name in list (metadata [0 ].keys ()):
3654 col_name_tuple += (col_name ,)
37- col_type_tuple += ('%s' ,)
55+ col_type_tuple += ("%s" ,)
3856
3957 insert_data = []
4058 for i in range (0 , len (ids )):
@@ -43,7 +61,9 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
4361 for col_name in list (metadata [i ].keys ()):
4462 value = metadata [i ][col_name ]
4563 # Determining if the data is a dictionary type of latitude and longitude.
46- if isinstance (value , dict ) and ('lon' and 'lat' ) in list (value .keys ()):
64+ if isinstance (value , dict ) and ("lon" and "lat" ) in list (
65+ value .keys ()
66+ ):
4767 raise RuntimeError ("Postgres doesn't support geo datasets" )
4868 else :
4969 temp_tuple += (value ,)
@@ -63,21 +83,22 @@ def upload_batch(cls, ids: List[int], vectors: List[list], metadata: List[Option
6383
6484 @classmethod
6585 def post_upload (cls , distance ):
66- index_options_c = ""
67- index_options_rust = ""
68- for key in cls .upload_params .get ("index_params" , {}).keys ():
69- index_options_c += ( "{}={}" if index_options_c == "" else ", {}={}" ). format (
70- key , cls . upload_params . get ( 'index_params' , {})[ key ])
71- index_options_rust += ( "{}={}" if index_options_rust == "" else " \n {}={}" ). format (
72- key , cls .upload_params . get ( 'index_params' , {})[ key ])
73- create_index_command = f"CREATE INDEX ON { PGVECTOR_INDEX } USING hnsw (vector { cls .distance } ) WITH ( { index_options_c } );"
74- if cls . engine_type == "rust" :
86+ if cls . engine_type == "c" :
87+ index_options_c = ""
88+ for key in cls .upload_params .get ("index_params" , {}).keys ():
89+ index_options_c += (
90+ "{}={}" if index_options_c == "" else " , {}={}"
91+ ). format ( key , cls . upload_params . get ( "index_params" , {})[ key ])
92+ create_index_command = f"CREATE INDEX ON { PGVECTOR_INDEX } USING hnsw (vector { cls .distance } ) WITH ( { index_options_c } );"
93+ elif cls .engine_type == "rust" :
94+ index_options_rust = toml . dumps ( cls . upload_params . get ( "index_params" , {}))
7595 create_index_command = f"""
7696CREATE INDEX ON { PGVECTOR_INDEX } USING vectors (vector { cls .distance } ) WITH (options=$$
77- [indexing.hnsw]
7897{ index_options_rust }
7998$$);
8099"""
100+ else :
101+ raise ValueError ("PGVector engine type must be c or rust" )
81102
82103 # create index (blocking)
83104 with cls .conn .cursor () as cur :
@@ -86,5 +107,7 @@ def post_upload(cls, distance):
86107 cls .conn .commit ()
87108 # wait index finished
88109 with cls .conn .cursor () as cur :
89- cur .execute ("SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;" )
110+ cur .execute (
111+ "SELECT phase, tuples_done, tuples_total FROM pg_stat_progress_create_index;"
112+ )
90113 cls .conn .commit ()
0 commit comments