Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@
import org.springframework.util.Assert;

/**
*
* Parse a textual, vector-store agnostic, filter expression language into
* {@link Filter.Expression}.
*
* <p>
* The vector-store agnostic, filter expression language is defined by a formal ANTLR4
* grammar (Filters.g4). The language looks and feels like a subset of the well-known SQL
* WHERE filter expressions. For example, you can use the parser like this:
Expand Down Expand Up @@ -161,7 +160,9 @@ public void clearCache() {
this.cache.clear();
}

/**
 * For testing only. Exposes the parser's expression cache; the returned map is the
 * {@code cache} field itself, not a copy.
 */
Map<String, Filter.Expression> getCache() {
return this.cache;
}
Expand Down Expand Up @@ -202,7 +203,13 @@ private String removeOuterQuotes(String in) {

/**
 * Converts an integer literal into a {@link Filter.Value}, keeping the narrowest
 * integral type that can hold it: an {@code Integer} when the literal fits the int
 * range, otherwise a {@code Long}.
 * @param ctx the parse-tree node whose text is the integer literal
 * @return a value wrapping an {@code Integer} or a {@code Long}
 * @throws NumberFormatException if the literal cannot be parsed as a {@code long}
 */
@Override
public Filter.Operand visitIntegerConstant(FiltersParser.IntegerConstantContext ctx) {
	// Parse once at the widest supported width instead of attempting int first and
	// using a caught NumberFormatException as the fallback path.
	long parsed = Long.parseLong(ctx.getText());
	if (parsed >= Integer.MIN_VALUE && parsed <= Integer.MAX_VALUE) {
		// An explicit if (not a ternary) so the value is boxed as Integer rather
		// than being promoted back to long.
		return new Filter.Value((int) parsed);
	}
	return new Filter.Value(parsed);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
package org.springframework.ai.vectorstore.filter;

import java.util.List;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.Group;
Expand All @@ -39,54 +43,55 @@
/**
* @author Christian Tzolov
* @author Sun Yuhan
* @author lance
*/
public class FilterExpressionTextParserTests {
class FilterExpressionTextParserTests {

FilterExpressionTextParser parser = new FilterExpressionTextParser();

@Test
public void testEQ() {
void testEQ() {
// country == "BG"
Expression exp = this.parser.parse("country == 'BG'");
assertThat(exp).isEqualTo(new Expression(EQ, new Key("country"), new Value("BG")));

assertThat(this.parser.getCache().get("WHERE " + "country == 'BG'")).isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry("WHERE " + "country == 'BG'", exp);
}

@Test
public void tesEqAndGte() {
void tesEqAndGte() {
// genre == "drama" AND year >= 2020
Expression exp = this.parser.parse("genre == 'drama' && year >= 2020");
assertThat(exp).isEqualTo(new Expression(AND, new Expression(EQ, new Key("genre"), new Value("drama")),
new Expression(GTE, new Key("year"), new Value(2020))));

assertThat(this.parser.getCache().get("WHERE " + "genre == 'drama' && year >= 2020")).isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry("WHERE " + "genre == 'drama' && year >= 2020", exp);
}

@Test
public void tesIn() {
void tesIn() {
// genre in ["comedy", "documentary", "drama"]
Expression exp = this.parser.parse("genre in ['comedy', 'documentary', 'drama']");
assertThat(exp)
.isEqualTo(new Expression(IN, new Key("genre"), new Value(List.of("comedy", "documentary", "drama"))));

assertThat(this.parser.getCache().get("WHERE " + "genre in ['comedy', 'documentary', 'drama']")).isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry("WHERE " + "genre in ['comedy', 'documentary', 'drama']", exp);
}

@Test
public void testNe() {
void testNe() {
// year >= 2020 OR country == "BG" AND city != "Sofia"
Expression exp = this.parser.parse("year >= 2020 OR country == \"BG\" AND city != \"Sofia\"");
assertThat(exp).isEqualTo(new Expression(OR, new Expression(GTE, new Key("year"), new Value(2020)),
new Expression(AND, new Expression(EQ, new Key("country"), new Value("BG")),
new Expression(NE, new Key("city"), new Value("Sofia")))));

assertThat(this.parser.getCache().get("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\""))
.isEqualTo(exp);
assertThat(this.parser.getCache())
.containsEntry("WHERE " + "year >= 2020 OR country == \"BG\" AND city != \"Sofia\"", exp);
}

@Test
public void testGroup() {
void testGroup() {
// (year >= 2020 OR country == "BG") AND city NIN ["Sofia", "Plovdiv"]
Expression exp = this.parser.parse("(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]");

Expand All @@ -95,26 +100,26 @@ public void testGroup() {
new Expression(EQ, new Key("country"), new Value("BG")))),
new Expression(NIN, new Key("city"), new Value(List.of("Sofia", "Plovdiv")))));

assertThat(this.parser.getCache()
.get("WHERE " + "(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]"))
.isEqualTo(exp);
assertThat(this.parser.getCache())
.containsEntry("WHERE " + "(year >= 2020 OR country == \"BG\") AND city NIN [\"Sofia\", \"Plovdiv\"]", exp);
}

@Test
public void tesBoolean() {
void tesBoolean() {
// isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"]
Expression exp = this.parser.parse("isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]");

assertThat(exp).isEqualTo(new Expression(AND,
new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)),
new Expression(GTE, new Key("year"), new Value(2020))),
new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))));
assertThat(this.parser.getCache()
.get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp);

assertThat(this.parser.getCache())
.containsEntry("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]", exp);
}

@Test
public void tesNot() {
void tesNot() {
// NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"])
Expression exp = this.parser
.parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])");
Expand All @@ -126,13 +131,12 @@ public void tesNot() {
new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))),
null));

assertThat(this.parser.getCache()
.get("WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"))
.isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry(
"WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])", exp);
}

@Test
public void tesNotNin() {
void tesNotNin() {
// NOT(country NOT IN ["BG", "NL", "US"])
Expression exp = this.parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])");

Expand All @@ -141,7 +145,7 @@ public void tesNotNin() {
}

@Test
public void tesNotNin2() {
void tesNotNin2() {
// NOT country NOT IN ["BG", "NL", "US"]
Expression exp = this.parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]");

Expand All @@ -150,7 +154,7 @@ public void tesNotNin2() {
}

@Test
public void tesNestedNot() {
void tesNestedNot() {
// NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"]))
Expression exp = this.parser
.parse("not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))");
Expand All @@ -164,25 +168,24 @@ public void tesNestedNot() {
null))),
null));

assertThat(this.parser.getCache()
.get("WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))"))
.isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry(
"WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))", exp);
}

@Test
public void testDecimal() {
void testDecimal() {
// temperature >= -15.6 && temperature <= +20.13
String expText = "temperature >= -15.6 && temperature <= +20.13";
Expression exp = this.parser.parse(expText);

assertThat(exp).isEqualTo(new Expression(AND, new Expression(GTE, new Key("temperature"), new Value(-15.6)),
new Expression(LTE, new Key("temperature"), new Value(20.13))));

assertThat(this.parser.getCache().get("WHERE " + expText)).isEqualTo(exp);
assertThat(this.parser.getCache()).containsEntry("WHERE " + expText, exp);
}

@Test
public void testLong() {
void testLong() {
Expression exp2 = this.parser.parse("biz_id == 3L");
Expression exp3 = this.parser.parse("biz_id == -5L");

Expand All @@ -191,7 +194,7 @@ public void testLong() {
}

@Test
public void testIdentifiers() {
void testIdentifiers() {
Expression exp = this.parser.parse("'country.1' == 'BG'");
assertThat(exp).isEqualTo(new Expression(EQ, new Key("'country.1'"), new Value("BG")));

Expand All @@ -203,9 +206,24 @@ public void testIdentifiers() {
}

@Test
public void testUnescapedIdentifierWithUnderscores() {
void testUnescapedIdentifierWithUnderscores() {
Expression exp = this.parser.parse("file_name == 'medicaid-wa-faqs.pdf'");
assertThat(exp).isEqualTo(new Expression(EQ, new Key("file_name"), new Value("medicaid-wa-faqs.pdf")));
}

/**
 * Parses {@code expr} and checks that it yields an EQ expression on key {@code id}
 * whose value equals {@code expectedValue}.
 */
@MethodSource("constantConstantProvider")
@ParameterizedTest(name = "{index} => [{0}, expected={1}]")
void testConstants(String expr, Object expectedValue) {
	Expression expected = new Expression(EQ, new Key("id"), new Value(expectedValue));
	assertThat(this.parser.parse(expr)).isEqualTo(expected);
}

/**
 * Supplies (expression text, expected value) pairs for {@code testConstants}. Each
 * expression is built by concatenating {@code "id=="} with the constant, so the text
 * is the constant's decimal string form.
 */
static Stream<Arguments> constantConstantProvider() {
	return Stream.of(
			// Limits of the int range are expected as Integer.
			Arguments.of("id==" + Integer.MAX_VALUE, Integer.MAX_VALUE),
			Arguments.of("id==" + Integer.MIN_VALUE, Integer.MIN_VALUE),
			// First values just outside the int range are expected as Long.
			Arguments.of("id==" + (Integer.MAX_VALUE + 1L), Integer.MAX_VALUE + 1L),
			Arguments.of("id==" + (Integer.MIN_VALUE - 1L), Integer.MIN_VALUE - 1L),
			// Limits of the long range.
			Arguments.of("id==" + Long.MAX_VALUE, Long.MAX_VALUE),
			Arguments.of("id==" + Long.MIN_VALUE, Long.MIN_VALUE),
			// Concatenation renders 0x100 as the decimal text "256".
			Arguments.of("id==" + 0x100, 0x100),
			Arguments.of("id==" + 1000000000000L, 1000000000000L),
			Arguments.of("id==" + Math.PI, Math.PI));
}

}