Skip to content

Commit b28da61

Browse files
committed
fix: UT & code format
1 parent 8cc5c91 commit b28da61

File tree

2 files changed

+33
-90
lines changed

2 files changed

+33
-90
lines changed

openchatbi/utils.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -532,75 +532,71 @@ def max_marginal_relevance_search(
532532

533533
# Get initial candidates using BM25 similarity search
534534
candidates = self.similarity_search_with_score(query, k=fetch_k, **kwargs)
535-
535+
536536
if not candidates:
537537
return []
538-
538+
539539
if len(candidates) <= k:
540540
return [doc for doc, _ in candidates]
541-
541+
542542
# Normalize BM25 scores to [0, 1] for proper MMR calculation
543543
scores = [score for _, score in candidates]
544544
min_score = min(scores) if scores else 0
545545
max_score = max(scores) if scores else 1
546546
score_range = max_score - min_score if max_score > min_score else 1
547-
548-
normalized_candidates = [
549-
(doc, (score - min_score) / score_range)
550-
for doc, score in candidates
551-
]
552-
547+
548+
normalized_candidates = [(doc, (score - min_score) / score_range) for doc, score in candidates]
549+
553550
# MMR implementation following standard algorithm
554551
selected = []
555552
remaining = list(range(len(normalized_candidates)))
556-
553+
557554
# Select documents iteratively using MMR formula
558555
while len(selected) < k and remaining:
559-
best_mmr_score = float('-inf')
556+
best_mmr_score = float("-inf")
560557
best_idx = -1
561558
best_remaining_idx = -1
562-
559+
563560
for i, doc_idx in enumerate(remaining):
564561
candidate_doc, relevance_score = normalized_candidates[doc_idx]
565-
562+
566563
# Calculate maximum similarity to already selected documents
567564
max_similarity = 0.0
568565
if selected:
569566
max_similarity = max(
570567
self._calculate_similarity(candidate_doc, normalized_candidates[sel_idx][0])
571568
for sel_idx in selected
572569
)
573-
570+
574571
# Standard MMR formula: λ * Sim(q, d) - (1-λ) * max(Sim(d, s)) for s in selected
575572
mmr_score = lambda_mult * relevance_score - (1 - lambda_mult) * max_similarity
576-
573+
577574
if mmr_score > best_mmr_score:
578575
best_mmr_score = mmr_score
579576
best_idx = doc_idx
580577
best_remaining_idx = i
581-
578+
582579
if best_idx != -1:
583580
selected.append(best_idx)
584581
remaining.pop(best_remaining_idx)
585-
582+
586583
return [normalized_candidates[idx][0] for idx in selected]
587584

588585
def _calculate_similarity(self, doc1: Document, doc2: Document) -> float:
589586
"""Calculate similarity between two documents using Jaccard similarity.
590-
587+
591588
Args:
592589
doc1: First document.
593590
doc2: Second document.
594-
591+
595592
Returns:
596593
Similarity score between 0 and 1 (higher means more similar).
597594
"""
598595
tokens1 = set(self._tokenize(doc1.page_content))
599596
tokens2 = set(self._tokenize(doc2.page_content))
600-
597+
601598
# Calculate Jaccard similarity
602599
intersection = len(tokens1 & tokens2)
603600
union = len(tokens1 | tokens2)
604-
605-
return intersection / union if union > 0 else 0.0
606601

602+
return intersection / union if union > 0 else 0.0

tests/test_simple_store.py

Lines changed: 15 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -208,89 +208,35 @@ def test_chinese_and_mixed_language(self):
208208
"数据科学" in doc.page_content for doc in cn_results
209209
)
210210

211-
def test_similarity_search_by_vector(self, simple_store):
212-
"""Test similarity_search_by_vector method."""
213-
# Test with dummy embedding vector
214-
dummy_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
215-
results = simple_store.similarity_search_by_vector(dummy_embedding, k=2)
216-
217-
assert len(results) == 2
218-
assert all(hasattr(doc, "page_content") for doc in results)
219-
220-
# Test k parameter bounds
221-
results = simple_store.similarity_search_by_vector(dummy_embedding, k=10)
222-
assert len(results) == 4 # Should return all documents
223-
224-
# Test empty store
225-
empty_store = SimpleStore([])
226-
results = empty_store.similarity_search_by_vector(dummy_embedding, k=5)
227-
assert results == []
228-
229-
def test_max_marginal_relevance_search_by_vector(self, simple_store):
230-
"""Test max_marginal_relevance_search_by_vector method."""
231-
dummy_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]
232-
233-
# Test basic functionality
234-
results = simple_store.max_marginal_relevance_search_by_vector(
235-
dummy_embedding, k=2, fetch_k=4, lambda_mult=0.5
236-
)
237-
assert len(results) == 2
238-
assert all(hasattr(doc, "page_content") for doc in results)
239-
240-
# Test with k >= fetch_k
241-
results = simple_store.max_marginal_relevance_search_by_vector(
242-
dummy_embedding, k=4, fetch_k=3
243-
)
244-
assert len(results) == 3 # Should return fetch_k documents
245-
246-
# Test diversity (lambda_mult = 0 should prioritize diversity)
247-
results_diverse = simple_store.max_marginal_relevance_search_by_vector(
248-
dummy_embedding, k=2, fetch_k=4, lambda_mult=0.0
249-
)
250-
assert len(results_diverse) == 2
251-
252-
# Test empty store
253-
empty_store = SimpleStore([])
254-
results = empty_store.max_marginal_relevance_search_by_vector(dummy_embedding, k=2)
255-
assert results == []
256-
257211
def test_max_marginal_relevance_search(self, simple_store):
258212
"""Test max_marginal_relevance_search method."""
259213
query = "programming language"
260-
214+
261215
# Test basic MMR search
262-
results = simple_store.max_marginal_relevance_search(
263-
query, k=2, fetch_k=4, lambda_mult=0.5
264-
)
216+
results = simple_store.max_marginal_relevance_search(query, k=2, fetch_k=4, lambda_mult=0.5)
265217
assert len(results) == 2
266218
assert all(hasattr(doc, "page_content") for doc in results)
267-
219+
268220
# Test relevance-focused search (lambda_mult = 1.0)
269-
results_relevant = simple_store.max_marginal_relevance_search(
270-
query, k=3, fetch_k=4, lambda_mult=1.0
271-
)
221+
results_relevant = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=1.0)
272222
assert len(results_relevant) == 3
273-
223+
274224
# Test diversity-focused search (lambda_mult = 0.0)
275-
results_diverse = simple_store.max_marginal_relevance_search(
276-
query, k=3, fetch_k=4, lambda_mult=0.0
277-
)
225+
results_diverse = simple_store.max_marginal_relevance_search(query, k=3, fetch_k=4, lambda_mult=0.0)
278226
assert len(results_diverse) == 3
279-
227+
280228
# Verify different lambda values produce different results
281229
# (unless there are ties in scoring)
282230
assert len(results_relevant) == len(results_diverse)
283-
231+
284232
# Test with k >= fetch_k
285-
results = simple_store.max_marginal_relevance_search(
286-
query, k=5, fetch_k=3, lambda_mult=0.5
287-
)
233+
results = simple_store.max_marginal_relevance_search(query, k=5, fetch_k=3, lambda_mult=0.5)
288234
assert len(results) == 3 # Should return fetch_k documents
289-
235+
290236
# Test empty query
291237
results = simple_store.max_marginal_relevance_search("", k=2)
292238
assert len(results) <= 2
293-
239+
294240
# Test empty store
295241
empty_store = SimpleStore([])
296242
results = empty_store.max_marginal_relevance_search(query, k=2)
@@ -302,17 +248,18 @@ def test_calculate_similarity(self, simple_store):
302248
doc1 = simple_store.documents[0] # "Python is a programming language"
303249
doc2 = simple_store.documents[1] # "Machine learning is a subset of AI"
304250
doc3 = simple_store.documents[0] # Same as doc1
305-
251+
306252
# Test similarity between different documents
307253
similarity_diff = simple_store._calculate_similarity(doc1, doc2)
308254
assert 0.0 <= similarity_diff <= 1.0
309-
255+
310256
# Test similarity between identical documents
311257
similarity_same = simple_store._calculate_similarity(doc1, doc3)
312258
assert similarity_same == 1.0
313-
259+
314260
# Test with empty documents
315261
from langchain_core.documents import Document
262+
316263
empty_doc1 = Document(page_content="", metadata={})
317264
empty_doc2 = Document(page_content="", metadata={})
318265
similarity_empty = simple_store._calculate_similarity(empty_doc1, empty_doc2)

0 commit comments

Comments
 (0)