diff --git a/src/main/java/net/staticstudios/audit/StaticAudit.java b/src/main/java/net/staticstudios/audit/StaticAudit.java index 60328ad..12be777 100644 --- a/src/main/java/net/staticstudios/audit/StaticAudit.java +++ b/src/main/java/net/staticstudios/audit/StaticAudit.java @@ -10,10 +10,7 @@ import java.sql.*; import java.time.Instant; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.UUID; +import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.function.Supplier; @@ -242,13 +239,13 @@ public StaticAudit log(@NotNull UUID userId, @Nullable UUID sessionId, @NotNull * @param sessionId The session ID to filter by, or null for any session. * @param from The start timestamp for filtering, or null for no lower bound. * @param to The end timestamp for filtering, or null for no upper bound. - * @param actionId The action ID to filter by, or null for any action. + * @param actionIds The action IDs to filter by, or null for any action. * @param limit The maximum number of entries to retrieve. * @return A CompletableFuture containing the list of matching audit log entries. */ - public CompletableFuture>> retrieveAsync(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, @Nullable String actionId, int limit) { + public CompletableFuture>> retrieveAsync(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, int limit, String... actionIds) { CompletableFuture>> future = new CompletableFuture<>(); - async(() -> future.complete(retrieve(userId, sessionId, from, to, actionId, limit))); + async(() -> future.complete(retrieve(userId, sessionId, from, to, limit, actionIds))); return future; } @@ -259,11 +256,11 @@ public CompletableFuture>> retrieveAsync(@NotNull UUID use * @param sessionId The session ID to filter by, or null for any session. * @param from The start timestamp for filtering, or null for no lower bound. * @param to The end timestamp for filtering, or null for no upper bound. - * @param actionId The action ID to filter by, or null for any action. + * @param actionIds The action IDs to filter by, or null for any action. * @param limit The maximum number of entries to retrieve. * @return The list of matching audit log entries. */ - public List> retrieve(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, @Nullable String actionId, int limit) { + public List> retrieve(@NotNull UUID userId, @Nullable UUID sessionId, @Nullable Instant from, @Nullable Instant to, int limit, String... actionIds) { Preconditions.checkNotNull(userId, "User ID cannot be null"); Preconditions.checkArgument(limit > 0, "Limit must be greater than 0"); List> entries = new ArrayList<>(); @@ -279,9 +276,16 @@ public List> retrieve(@NotNull UUID userId, @Nullable UUID sess if (to != null) { sqlBuilder.append(" AND timestamp <= ?"); } - if (actionId != null) { - sqlBuilder.append(" AND action_id = ?"); + if (actionIds.length > 0) { + sqlBuilder.append(" AND action_id IN (?"); + + if (actionIds.length > 1 ) { + sqlBuilder.repeat(", ?", actionIds.length - 1); + } + + sqlBuilder.append(")"); } + sqlBuilder.append(" ORDER BY timestamp DESC LIMIT ?"); String sql = sqlBuilder.toString().formatted(schemaName, tableName); try (PreparedStatement statement = connection.prepareStatement(sql)) { @@ -296,7 +300,7 @@ public List> retrieve(@NotNull UUID userId, @Nullable UUID sess if (to != null) { statement.setObject(index++, Timestamp.from(to)); } - if (actionId != null) { + for (String actionId : actionIds) { statement.setString(index++, actionId); } statement.setInt(index, limit); diff --git a/src/test/java/net/staticstudios/audit/LoggingTest.java b/src/test/java/net/staticstudios/audit/LoggingTest.java index f7dfa6d..ca49e7e 100644 --- a/src/test/java/net/staticstudios/audit/LoggingTest.java +++ b/src/test/java/net/staticstudios/audit/LoggingTest.java @@ -13,15 +13,16 @@ import java.util.List; import java.util.UUID; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class LoggingTest extends AuditTest { private static final Instant NOW = Instant.ofEpochMilli(0); private StaticAudit audit; private UUID userId; private UUID sessionId; - private Action.SimpleAction action; + private Action.SimpleAction action1; + private Action.SimpleAction action2; + private Action.SimpleAction action3; @BeforeEach public void setUp() { @@ -35,8 +36,12 @@ public void setUp() { userId = UUID.randomUUID(); sessionId = UUID.randomUUID(); - action = (Action.SimpleAction) Action.simple("test_action", SimpleActionData.class); - audit.registerAction(action); + action1 = (Action.SimpleAction) Action.simple("test_action", SimpleActionData.class); + action2 = (Action.SimpleAction) Action.simple("test_action_2", SimpleActionData.class); + action3 = (Action.SimpleAction) Action.simple("test_action_3", SimpleActionData.class); + audit.registerAction(action1); + audit.registerAction(action2); + audit.registerAction(action3); } @AfterEach @@ -49,7 +54,7 @@ public void tearDown() throws SQLException { @Test public void testLogging() throws SQLException { SimpleActionData data = new SimpleActionData("test"); - audit.log(userId, sessionId, action, data); + audit.log(userId, sessionId, action1, data); Connection connection = getConnection(); @Language("SQL") String sql = "SELECT * FROM %s.%s WHERE user_id = ?"; @@ -59,8 +64,8 @@ public void testLogging() throws SQLException { assertTrue(rs.next()); assertEquals(userId, rs.getObject("user_id")); assertEquals(sessionId, rs.getObject("session_id")); - assertEquals(action.getActionId(), rs.getString("action_id")); - assertEquals(data, action.fromJson(rs.getString("action_data"))); + assertEquals(action1.getActionId(), rs.getString("action_id")); + assertEquals(data, action1.fromJson(rs.getString("action_data"))); } @Test @@ -68,16 +73,41 @@ public void testRetrieving() { logMultiple(50); List> entries; - entries = audit.retrieve(userId, null, null, null, null, 100); + entries = audit.retrieve(userId, null, null, null, 100); assertEquals(50, entries.size()); - entries = audit.retrieve(userId, null, null, null, null, 10); + entries = audit.retrieve(userId, null, null, null, 10); assertEquals(10, entries.size()); + for (int i = 0; i < 10; i++) { assertEquals("test" + (49 - i), ((SimpleActionData) entries.get(i).getData()).data()); } } + @Test + public void testRetrievingWithFilter() { + logMultiple(action1, 50); + logMultiple(action2, 30); + logMultiple(action3, 60); + + List> entries; + + entries = audit.retrieve(userId, null, null, null, 500); + assertEquals(140, entries.size()); + entries = audit.retrieve(userId, null, null, null, 100, action1.getActionId(), action3.getActionId()); + assertEquals(100, entries.size()); + assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action2.getActionId()))); + entries = audit.retrieve(userId, null, null, null, 10, action1.getActionId()); + assertEquals(10, entries.size()); + assertTrue(entries.stream().allMatch(entry -> entry.getAction().getActionId().equals(action1.getActionId()))); + assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action2.getActionId()))); + assertFalse(entries.stream().anyMatch(entry -> entry.getAction().getActionId().equals(action3.getActionId()))); + } + private void logMultiple(int count) { + logMultiple(action1, count); + } + + private void logMultiple(Action action, int count) { for (int i = 0; i < count; i++) { SimpleActionData data = new SimpleActionData("test" + i); Instant timestamp = NOW.plusSeconds(i);