Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.model.chat.memory.repository.mongo.autoconfigure;

import java.lang.reflect.Method;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -27,11 +29,13 @@
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.index.Index;
import org.springframework.data.mongodb.core.index.IndexDefinition;
import org.springframework.data.mongodb.core.index.IndexOperations;

/**
* Class responsible for creating proper MongoDB indices for the ChatMemory. Creates a
* main index on the conversationId and timestamp fields, and a TTL index on the timestamp
* field if the TTL is set in properties.
* main index on the conversationId and timestamp fields, and a TTL index on the
* timestamp field if the TTL is set in properties.
*
* @author Łukasz Jernaś
* @see MongoChatMemoryProperties
Expand All @@ -41,42 +45,98 @@
@ConditionalOnProperty(value = "spring.ai.chat.memory.repository.mongo.create-indices", havingValue = "true")
public class MongoChatMemoryIndexCreatorAutoConfiguration {

private static final Logger logger = LoggerFactory.getLogger(MongoChatMemoryIndexCreatorAutoConfiguration.class);
private static final Logger logger = LoggerFactory
.getLogger(MongoChatMemoryIndexCreatorAutoConfiguration.class);

private final MongoTemplate mongoTemplate;

private final MongoChatMemoryProperties mongoChatMemoryProperties;

public MongoChatMemoryIndexCreatorAutoConfiguration(MongoTemplate mongoTemplate,
MongoChatMemoryProperties mongoChatMemoryProperties) {
public MongoChatMemoryIndexCreatorAutoConfiguration(final MongoTemplate mongoTemplate,
final MongoChatMemoryProperties mongoChatMemoryProperties) {
this.mongoTemplate = mongoTemplate;
this.mongoChatMemoryProperties = mongoChatMemoryProperties;
}

/**
* Initializes MongoDB indices after application context refresh.
*/
@EventListener(ContextRefreshedEvent.class)
public void initIndicesAfterStartup() {
logger.info("Creating MongoDB indices for ChatMemory");
// Create a main index
this.mongoTemplate.indexOps(Conversation.class)
.createIndex(new Index().on("conversationId", Sort.Direction.ASC).on("timestamp", Sort.Direction.DESC));

createMainIndex();
createOrUpdateTtlIndex();
}

private void createMainIndex() {
var indexOps = this.mongoTemplate.indexOps(Conversation.class);
var index = new Index().on("conversationId", Sort.Direction.ASC)
.on("timestamp", Sort.Direction.DESC);

// Use reflection to handle API differences across Spring Data MongoDB versions
createIndexSafely(indexOps, index);
}

private void createOrUpdateTtlIndex() {
if (!this.mongoChatMemoryProperties.getTtl().isZero()) {
var indexOps = this.mongoTemplate.indexOps(Conversation.class);
// Check for existing TTL index
this.mongoTemplate.indexOps(Conversation.class).getIndexInfo().forEach(idx -> {
indexOps.getIndexInfo().forEach(idx -> {
if (idx.getExpireAfter().isPresent()
&& !idx.getExpireAfter().get().equals(this.mongoChatMemoryProperties.getTtl())) {
&& !idx.getExpireAfter().get()
.equals(this.mongoChatMemoryProperties.getTtl())) {
logger.warn("Dropping existing TTL index, because TTL is different");
this.mongoTemplate.indexOps(Conversation.class).dropIndex(idx.getName());
indexOps.dropIndex(idx.getName());
}
});
this.mongoTemplate.indexOps(Conversation.class)
.createIndex(new Index().on("timestamp", Sort.Direction.ASC)
// Use reflection to handle API differences across Spring Data MongoDB
// versions
createIndexSafely(indexOps, new Index().on("timestamp", Sort.Direction.ASC)
.expire(this.mongoChatMemoryProperties.getTtl()));
}
}

/**
* Creates an index using reflection to handle API changes across different Spring
* Data MongoDB versions:
* <ul>
* <li>Spring Data MongoDB 4.2.x - 4.4.x: only {@code ensureIndex(IndexDefinition)}
* is available.</li>
* <li>Spring Data MongoDB 4.5.x+: {@code createIndex(IndexDefinition)} is the new
* API, {@code ensureIndex} is deprecated.</li>
* </ul>
* @param indexOps the IndexOperations instance
* @param index the index definition
* @throws IllegalStateException if neither method is available or invocation fails
*/
private void createIndexSafely(final IndexOperations indexOps, final IndexDefinition index) {
try {
// Try new API (Spring Data MongoDB 4.5.x+)
Method method = IndexOperations.class.getMethod("createIndex", IndexDefinition.class);
method.invoke(indexOps, index);
logger.debug("Created index using createIndex() method");
}
catch (NoSuchMethodException createIndexNotFound) {
// Fall back to old API (Spring Data MongoDB 4.2.x - 4.4.x)
try {
Method method = IndexOperations.class.getMethod("ensureIndex", IndexDefinition.class);
method.invoke(indexOps, index);
logger.debug("Created index using ensureIndex() method");
}
catch (NoSuchMethodException ensureIndexNotFound) {
throw new IllegalStateException(
"Neither createIndex() nor ensureIndex() method found on IndexOperations. "
+ "This may indicate an unsupported Spring Data MongoDB version.",
ensureIndexNotFound);
}
catch (ReflectiveOperationException ex) {
throw new IllegalStateException("Failed to invoke ensureIndex() method", ex);
}
}
catch (ReflectiveOperationException ex) {
throw new IllegalStateException("Failed to invoke createIndex() method", ex);
}
}

}
Loading