Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions src/main/java/net/staticstudios/audit/StaticAudit.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@

import java.sql.*;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;
Expand Down Expand Up @@ -242,13 +239,13 @@ public StaticAudit log(@NotNull UUID userId, @Nullable UUID sessionId, @NotNull
* @param sessionId The session ID to filter by, or null for any session.
* @param from The start timestamp for filtering, or null for no lower bound.
* @param to The end timestamp for filtering, or null for no upper bound.
* @param actionId The action ID to filter by, or null for any action.
* @param actionIds The action IDs to filter by, or null for any action.
* @param limit The maximum number of entries to retrieve.
* @return A CompletableFuture containing the list of matching audit log entries.
*/
public CompletableFuture<List<AuditLogEntry<?>>> retrieveAsync(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, @Nullable String actionId, int limit) {
public CompletableFuture<List<AuditLogEntry<?>>> retrieveAsync(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, int limit, String... actionIds) {
CompletableFuture<List<AuditLogEntry<?>>> future = new CompletableFuture<>();
async(() -> future.complete(retrieve(userId, sessionId, from, to, actionId, limit)));
async(() -> future.complete(retrieve(userId, sessionId, from, to, limit, actionIds)));
return future;
}

Expand All @@ -259,11 +256,11 @@ public CompletableFuture<List<AuditLogEntry<?>>> retrieveAsync(@NotNull UUID use
* @param sessionId The session ID to filter by, or null for any session.
* @param from The start timestamp for filtering, or null for no lower bound.
* @param to The end timestamp for filtering, or null for no upper bound.
* @param actionId The action ID to filter by, or null for any action.
* @param actionIds The action IDs to filter by, or null for any action.
* @param limit The maximum number of entries to retrieve.
* @return The list of matching audit log entries.
*/
public List<AuditLogEntry<?>> retrieve(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, @Nullable String actionId, int limit) {
public List<AuditLogEntry<?>> retrieve(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, int limit, String... actionIds) {
Preconditions.checkNotNull(userId, "User ID cannot be null");
Preconditions.checkArgument(limit > 0, "Limit must be greater than 0");
List<AuditLogEntry<?>> entries = new ArrayList<>();
Expand All @@ -279,9 +276,16 @@ public List<AuditLogEntry<?>> retrieve(@NotNull UUID userId, @Nullable UUID sess
if (to != null) {
sqlBuilder.append(" AND timestamp <= ?");
}
if (actionId != null) {
sqlBuilder.append(" AND action_id = ?");
if (actionIds.length > 0) {
sqlBuilder.append(" AND action_id IN (?");

if (actionIds.length > 1 ) {
sqlBuilder.repeat(", ?", actionIds.length - 1);
}

sqlBuilder.append(")");
}

sqlBuilder.append(" ORDER BY timestamp DESC LIMIT ?");
String sql = sqlBuilder.toString().formatted(schemaName, tableName);
try (PreparedStatement statement = connection.prepareStatement(sql)) {
Expand All @@ -296,7 +300,7 @@ public List<AuditLogEntry<?>> retrieve(@NotNull UUID userId, @Nullable UUID sess
if (to != null) {
statement.setObject(index++, Timestamp.from(to));
}
if (actionId != null) {
for (String actionId : actionIds) {
statement.setString(index++, actionId);
}
statement.setInt(index, limit);
Expand Down
50 changes: 40 additions & 10 deletions src/test/java/net/staticstudios/audit/LoggingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
import java.util.List;
import java.util.UUID;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.*;

public class LoggingTest extends AuditTest {
private static final Instant NOW = Instant.ofEpochMilli(0);
private StaticAudit audit;
private UUID userId;
private UUID sessionId;
private Action.SimpleAction<SimpleActionData> action;
private Action.SimpleAction<SimpleActionData> action1;
private Action.SimpleAction<SimpleActionData> action2;
private Action.SimpleAction<SimpleActionData> action3;

@BeforeEach
public void setUp() {
Expand All @@ -35,8 +36,12 @@ public void setUp() {

userId = UUID.randomUUID();
sessionId = UUID.randomUUID();
action = (Action.SimpleAction<SimpleActionData>) Action.simple("test_action", SimpleActionData.class);
audit.registerAction(action);
action1 = (Action.SimpleAction<SimpleActionData>) Action.simple("test_action", SimpleActionData.class);
action2 = (Action.SimpleAction<SimpleActionData>) Action.simple("test_action_2", SimpleActionData.class);
action3 = (Action.SimpleAction<SimpleActionData>) Action.simple("test_action_3", SimpleActionData.class);
audit.registerAction(action1);
audit.registerAction(action2);
audit.registerAction(action3);
}

@AfterEach
Expand All @@ -49,7 +54,7 @@ public void tearDown() throws SQLException {
@Test
public void testLogging() throws SQLException {
SimpleActionData data = new SimpleActionData("test");
audit.log(userId, sessionId, action, data);
audit.log(userId, sessionId, action1, data);

Connection connection = getConnection();
@Language("SQL") String sql = "SELECT * FROM %s.%s WHERE user_id = ?";
Expand All @@ -59,25 +64,50 @@ public void testLogging() throws SQLException {
assertTrue(rs.next());
assertEquals(userId, rs.getObject("user_id"));
assertEquals(sessionId, rs.getObject("session_id"));
assertEquals(action.getActionId(), rs.getString("action_id"));
assertEquals(data, action.fromJson(rs.getString("action_data")));
assertEquals(action1.getActionId(), rs.getString("action_id"));
assertEquals(data, action1.fromJson(rs.getString("action_data")));
}

@Test
public void testRetrieving() {
logMultiple(50);

List<AuditLogEntry<?>> entries;
entries = audit.retrieve(userId, null, null, null, null, 100);
entries = audit.retrieve(userId, null, null, null, 100);
assertEquals(50, entries.size());
entries = audit.retrieve(userId, null, null, null, null, 10);
entries = audit.retrieve(userId, null, null, null, 10);
assertEquals(10, entries.size());

for (int i = 0; i < 10; i++) {
assertEquals("test" + (49 - i), ((SimpleActionData) entries.get(i).getData()).data());
}
}

@Test
public void testRetrievingWithFilter() {
logMultiple(action1, 50);
logMultiple(action2, 30);
logMultiple(action3, 60);

List<AuditLogEntry<?>> entries;

entries = audit.retrieve(userId, null, null, null, 500);
assertEquals(140, entries.size());
entries = audit.retrieve(userId, null, null, null, 100, action1.getActionId(), action3.getActionId());
assertEquals(100, entries.size());
assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action2.getActionId())));
entries = audit.retrieve(userId, null, null, null, 10, action1.getActionId());
assertEquals(10, entries.size());
assertTrue(entries.stream().allMatch(entry -> entry.getAction().getActionId().equals(action1.getActionId())));
assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action2.getActionId())));
assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action3.getActionId())));
}

private void logMultiple(int count) {
logMultiple(action1, count);
}

private void logMultiple(Action action, int count) {
for (int i = 0; i < count; i++) {
SimpleActionData data = new SimpleActionData("test" + i);
Instant timestamp = NOW.plusSeconds(i);
Expand Down