@@ -64,42 +64,55 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector,
6464 if c := C .faiss_SearchParameters_new (& rv .sp , sel ); c != 0 {
6565 return nil , fmt .Errorf ("failed to create faiss search params" )
6666 }
67+
68+ if len (params ) == 0 && sel == nil {
69+ return rv , nil
70+ }
71+
72+ var nlist , nprobe , nvecs , maxCodes int
73+ var ivfParams searchParamsIVF
74+
75+ rv .sp = C .faiss_SearchParametersIVF_cast (rv .sp )
76+
6777 // check if the index is IVF and set the search params
6878 if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtr ()); ivfIdx != nil {
69- rv .sp = C .faiss_SearchParametersIVF_cast (rv .sp )
70- if len (params ) == 0 && sel == nil {
71- return rv , nil
72- }
73- var nlist , nprobe , nvecs , maxCodes int
7479 nlist = int (C .faiss_IndexIVF_nlist (ivfIdx ))
7580 nprobe = int (C .faiss_IndexIVF_nprobe (ivfIdx ))
7681 nvecs = int (C .faiss_Index_ntotal (idx .cPtr ()))
77- if defaultParams != nil {
78- if defaultParams .Nlist > 0 {
79- nlist = defaultParams .Nlist
80- }
81- if defaultParams .Nprobe > 0 {
82- nprobe = defaultParams .Nprobe
83- }
82+ } else if bivfIdx := C .faiss_IndexBinaryIVF_cast (idx .cPtrBinary ()); bivfIdx != nil {
83+ nlist = int (C .faiss_IndexBinaryIVF_nlist (bivfIdx ))
84+ nprobe = int (C .faiss_IndexBinaryIVF_nprobe (bivfIdx ))
85+ nvecs = int (C .faiss_IndexBinary_ntotal (idx .cPtrBinary ()))
86+ }
87+
88+ if defaultParams != nil {
89+ if defaultParams .Nlist > 0 {
90+ nlist = defaultParams .Nlist
8491 }
85- var ivfParams searchParamsIVF
86- if len (params ) > 0 {
87- if err := json .Unmarshal (params , & ivfParams ); err != nil {
88- rv .Delete ()
89- return nil , fmt .Errorf ("failed to unmarshal IVF search params, " +
90- "err:%v" , err )
91- }
92- if err := ivfParams .Validate (); err != nil {
93- rv .Delete ()
94- return nil , err
95- }
92+ if defaultParams .Nprobe > 0 {
93+ nprobe = defaultParams .Nprobe
9694 }
97- if ivfParams .NprobePct > 0 {
98- nprobe = max (int (float32 (nlist )* (ivfParams .NprobePct / 100 )), 1 )
95+ }
96+
97+ if len (params ) > 0 {
98+ if err := json .Unmarshal (params , & ivfParams ); err != nil {
99+ rv .Delete ()
100+ return nil , fmt .Errorf ("failed to unmarshal IVF search params, " +
101+ "err:%v" , err )
99102 }
100- if ivfParams .MaxCodesPct > 0 {
101- maxCodes = int (float32 (nvecs ) * (ivfParams .MaxCodesPct / 100 ))
102- } // else, maxCodes will be set to the default value of 0, which means no limit
103+ if err := ivfParams .Validate (); err != nil {
104+ rv .Delete ()
105+ return nil , err
106+ }
107+ }
108+ if ivfParams .NprobePct > 0 {
109+ nprobe = max (int (float32 (nlist )* (ivfParams .NprobePct / 100 )), 1 )
110+ }
111+ if ivfParams .MaxCodesPct > 0 {
112+ maxCodes = int (float32 (nvecs ) * (ivfParams .MaxCodesPct / 100 ))
113+ } // else, maxCodes will be set to the default value of 0, which means no limit
114+
115+ if ivfIdx := C .faiss_IndexIVF_cast (idx .cPtr ()); ivfIdx != nil {
103116 if c := C .faiss_SearchParametersIVF_new_with (
104117 & rv .sp ,
105118 sel ,
@@ -110,48 +123,14 @@ func NewSearchParams(idx Index, params json.RawMessage, sel *C.FaissIDSelector,
110123 return nil , fmt .Errorf ("failed to create faiss IVF search params" )
111124 }
112125 } else if bivfIdx := C .faiss_IndexBinaryIVF_cast (idx .cPtrBinary ()); bivfIdx != nil {
113- rv .sp = C .faiss_SearchParametersIVF_cast (rv .sp )
114- if len (params ) == 0 && sel == nil {
115- return rv , nil
116- }
117- var nlist , nprobe , nvecs , maxCodes int
118- nlist = int (C .faiss_IndexBinaryIVF_nlist (bivfIdx ))
119- nprobe = int (C .faiss_IndexBinaryIVF_nprobe (bivfIdx ))
120- nvecs = int (C .faiss_IndexBinaryIVF_ntotal (bivfIdx ))
121- if defaultParams != nil {
122- if defaultParams .Nlist > 0 {
123- nlist = defaultParams .Nlist
124- }
125- if defaultParams .Nprobe > 0 {
126- nprobe = defaultParams .Nprobe
127- }
128- }
129- var ivfParams searchParamsIVF
130- if len (params ) > 0 {
131- if err := json .Unmarshal (params , & ivfParams ); err != nil {
132- rv .Delete ()
133- return nil , fmt .Errorf ("failed to unmarshal IVF search params, " +
134- "err:%v" , err )
135- }
136- if err := ivfParams .Validate (); err != nil {
137- rv .Delete ()
138- return nil , err
139- }
140- }
141- if ivfParams .NprobePct > 0 {
142- nprobe = max (int (float32 (nlist )* (ivfParams .NprobePct / 100 )), 1 )
143- }
144- if ivfParams .MaxCodesPct > 0 {
145- maxCodes = int (float32 (nvecs ) * (ivfParams .MaxCodesPct / 100 ))
146- } // else, maxCodes will be set to the default value of 0, which means no limit
147126 if c := C .faiss_SearchParametersIVF_new_with (
148127 & rv .sp ,
149128 sel ,
150129 C .size_t (nprobe ),
151130 C .size_t (maxCodes ),
152131 ); c != 0 {
153132 rv .Delete ()
154- return nil , fmt .Errorf ("failed to create faiss IVF search params" )
133+ return nil , fmt .Errorf ("failed to create faiss BIVF search params" )
155134 }
156135 }
157136 return rv , nil
0 commit comments