diff --git a/Plugins/RedisDriverPlugin/RedisCommandParser.swift b/Plugins/RedisDriverPlugin/RedisCommandParser.swift index 040106d8..ea595f53 100644 --- a/Plugins/RedisDriverPlugin/RedisCommandParser.swift +++ b/Plugins/RedisDriverPlugin/RedisCommandParser.swift @@ -44,8 +44,8 @@ enum RedisOperation { case scard(key: String) // Sorted set - case zrange(key: String, start: Int, stop: Int, withScores: Bool) - case zadd(key: String, scoreMembers: [(Double, String)]) + case zrange(key: String, start: String, stop: String, flags: [String]) + case zadd(key: String, flags: [String], scoreMembers: [(Double, String)]) case zrem(key: String, members: [String]) case zcard(key: String) @@ -73,6 +73,8 @@ enum RedisOperation { struct RedisSetOptions { var ex: Int? var px: Int? + var exat: Int? + var pxat: Int? var nx: Bool = false var xx: Bool = false } @@ -112,26 +114,41 @@ struct RedisCommandParser { switch command { case "GET", "SET", "DEL", "KEYS", "SCAN", "TYPE", "TTL", "PTTL", - "EXPIRE", "PERSIST", "RENAME", "EXISTS": - return try parseKeyCommand(command, args: args) - - case "HGET", "HSET", "HGETALL", "HDEL": - return try parseHashCommand(command, args: args) - - case "LRANGE", "LPUSH", "RPUSH", "LLEN": - return try parseListCommand(command, args: args) - - case "SMEMBERS", "SADD", "SREM", "SCARD": - return try parseSetCommand(command, args: args) - - case "ZRANGE", "ZADD", "ZREM", "ZCARD": - return try parseSortedSetCommand(command, args: args) - - case "XRANGE", "XLEN": - return try parseStreamCommand(command, args: args) - - case "PING", "INFO", "DBSIZE", "FLUSHDB", "SELECT", "CONFIG", - "MULTI", "EXEC", "DISCARD": + "EXPIRE", "PEXPIRE", "EXPIREAT", "PEXPIREAT", + "PERSIST", "RENAME", "EXISTS", + "GETSET", "GETDEL", "GETEX", + "MGET", "MSET", + "INCR", "DECR", "INCRBY", "DECRBY", "INCRBYFLOAT", + "APPEND": + return try parseKeyCommand(command, args: args, tokens: tokens) + + case "HGET", "HSET", "HGETALL", "HDEL", "HSCAN": + return try parseHashCommand(command, args: args, tokens: tokens) + + case "LRANGE", "LPUSH", "RPUSH", "LLEN", + "LPOP", "RPOP", "LSET", "LINSERT", "LREM", "LPOS", "LMOVE": + return try parseListCommand(command, args: args, tokens: tokens) + + case "SMEMBERS", "SADD", "SREM", "SCARD", + "SPOP", "SRANDMEMBER", "SMOVE", + "SUNION", "SINTER", "SDIFF", + "SUNIONSTORE", "SINTERSTORE", "SDIFFSTORE", + "SSCAN": + return try parseSetCommand(command, args: args, tokens: tokens) + + case "ZRANGE", "ZADD", "ZREM", "ZCARD", + "ZSCORE", "ZRANGEBYSCORE", "ZREVRANGE", "ZREVRANGEBYSCORE", + "ZINCRBY", "ZCOUNT", "ZRANK", "ZREVRANK", + "ZPOPMIN", "ZPOPMAX", + "ZSCAN": + return try parseSortedSetCommand(command, args: args, tokens: tokens) + + case "XRANGE", "XLEN", "XADD", "XREAD", "XREVRANGE", "XDEL", + "XTRIM", "XINFO", "XGROUP", "XACK": + return try parseStreamCommand(command, args: args, tokens: tokens) + + case "PING", "INFO", "DBSIZE", "FLUSHDB", "FLUSHALL", "SELECT", "CONFIG", + "MULTI", "EXEC", "DISCARD", "AUTH", "OBJECT": return try parseServerCommand(command, args: args, tokens: tokens) default: @@ -141,7 +158,9 @@ struct RedisCommandParser { // MARK: - Key Commands - private static func parseKeyCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseKeyCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "GET": guard args.count >= 1 else { throw RedisParseError.missingArgument("GET requires a key") } @@ -149,7 +168,7 @@ struct RedisCommandParser { case "SET": guard args.count >= 2 else { throw RedisParseError.missingArgument("SET requires key and value") } - let options = parseSetOptions(Array(args.dropFirst(2))) + let options = try parseSetOptions(Array(args.dropFirst(2))) return .set(key: args[0], value: args[1], options: options) case "DEL": @@ -164,7 +183,7 @@ struct RedisCommandParser { guard args.count >= 1, let cursor = Int(args[0]) else { throw RedisParseError.missingArgument("SCAN requires a cursor (integer)") } - let (pattern, count) = parseScanOptions(Array(args.dropFirst())) + let (pattern, count) = try parseScanOptions(Array(args.dropFirst())) return .scan(cursor: cursor, pattern: pattern, count: count) case "TYPE": @@ -184,8 +203,39 @@ struct RedisCommandParser { guard let seconds = Int(args[1]) else { throw RedisParseError.invalidArgument("EXPIRE seconds must be an integer") } + // Redis 7.0+ supports optional NX|XX|GT|LT flags; pass through as raw command + if args.count > 2 { + return .command(args: tokens) + } return .expire(key: args[0], seconds: seconds) + case "PEXPIRE": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("PEXPIRE requires key and milliseconds") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("PEXPIRE milliseconds must be an integer") + } + return .command(args: tokens) + + case "EXPIREAT": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("EXPIREAT requires key and timestamp") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("EXPIREAT timestamp must be an integer") + } + return .command(args: tokens) + + case "PEXPIREAT": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("PEXPIREAT requires key and milliseconds-timestamp") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("PEXPIREAT milliseconds-timestamp must be an integer") + } + return .command(args: tokens) + case "PERSIST": guard args.count >= 1 else { throw RedisParseError.missingArgument("PERSIST requires a key") } return .persist(key: args[0]) @@ -198,14 +248,73 @@ struct RedisCommandParser { guard !args.isEmpty else { throw RedisParseError.missingArgument("EXISTS requires at least one key") } return .exists(keys: args) + case "GETSET": + guard args.count >= 2 else { throw RedisParseError.missingArgument("GETSET requires key and value") } + return .command(args: tokens) + + case "GETDEL": + guard args.count >= 1 else { throw RedisParseError.missingArgument("GETDEL requires a key") } + return .command(args: tokens) + + case "GETEX": + guard args.count >= 1 else { throw RedisParseError.missingArgument("GETEX requires a key") } + return .command(args: tokens) + + case "MGET": + guard !args.isEmpty else { throw RedisParseError.missingArgument("MGET requires at least one key") } + return .command(args: tokens) + + case "MSET": + guard args.count >= 2, args.count % 2 == 0 else { + throw RedisParseError.missingArgument("MSET requires key value pairs") + } + return .command(args: tokens) + + case "INCR": + guard args.count >= 1 else { throw RedisParseError.missingArgument("INCR requires a key") } + return .command(args: tokens) + + case "DECR": + guard args.count >= 1 else { throw RedisParseError.missingArgument("DECR requires a key") } + return .command(args: tokens) + + case "INCRBY": + guard args.count >= 2 else { throw RedisParseError.missingArgument("INCRBY requires key and increment") } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("INCRBY increment must be an integer") + } + return .command(args: tokens) + + case "DECRBY": + guard args.count >= 2 else { throw RedisParseError.missingArgument("DECRBY requires key and decrement") } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("DECRBY decrement must be an integer") + } + return .command(args: tokens) + + case "INCRBYFLOAT": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("INCRBYFLOAT requires key and increment") + } + guard Double(args[1]) != nil else { + throw RedisParseError.invalidArgument("INCRBYFLOAT increment must be a number") + } + return .command(args: tokens) + + case "APPEND": + guard args.count >= 2 else { throw RedisParseError.missingArgument("APPEND requires key and value") } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown key command: \(command)") + return .command(args: tokens) } } // MARK: - Hash Commands - private static func parseHashCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseHashCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "HGET": guard args.count >= 2 else { throw RedisParseError.missingArgument("HGET requires key and field") } @@ -228,113 +337,383 @@ struct RedisCommandParser { return .hgetall(key: args[0]) case "HDEL": - guard args.count >= 2 else { throw RedisParseError.missingArgument("HDEL requires key and at least one field") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("HDEL requires key and at least one field") + } return .hdel(key: args[0], fields: Array(args.dropFirst())) + case "HSCAN": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("HSCAN requires key and cursor") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("HSCAN cursor must be an integer") + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown hash command: \(command)") + return .command(args: tokens) } } // MARK: - List Commands - private static func parseListCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseListCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "LRANGE": - guard args.count >= 3 else { throw RedisParseError.missingArgument("LRANGE requires key, start, and stop") } + guard args.count >= 3 else { + throw RedisParseError.missingArgument("LRANGE requires key, start, and stop") + } guard let start = Int(args[1]), let stop = Int(args[2]) else { throw RedisParseError.invalidArgument("LRANGE start and stop must be integers") } return .lrange(key: args[0], start: start, stop: stop) case "LPUSH": - guard args.count >= 2 else { throw RedisParseError.missingArgument("LPUSH requires key and at least one value") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("LPUSH requires key and at least one value") + } return .lpush(key: args[0], values: Array(args.dropFirst())) case "RPUSH": - guard args.count >= 2 else { throw RedisParseError.missingArgument("RPUSH requires key and at least one value") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("RPUSH requires key and at least one value") + } return .rpush(key: args[0], values: Array(args.dropFirst())) case "LLEN": guard args.count >= 1 else { throw RedisParseError.missingArgument("LLEN requires a key") } return .llen(key: args[0]) + case "LPOP": + guard args.count >= 1 else { throw RedisParseError.missingArgument("LPOP requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("LPOP count must be an integer") + } + } + return .command(args: tokens) + + case "RPOP": + guard args.count >= 1 else { throw RedisParseError.missingArgument("RPOP requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("RPOP count must be an integer") + } + } + return .command(args: tokens) + + case "LSET": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("LSET requires key, index, and element") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("LSET index must be an integer") + } + return .command(args: tokens) + + case "LINSERT": + guard args.count >= 4 else { + throw RedisParseError.missingArgument("LINSERT requires key, BEFORE|AFTER, pivot, and element") + } + let position = args[1].uppercased() + guard position == "BEFORE" || position == "AFTER" else { + throw RedisParseError.invalidArgument("LINSERT position must be BEFORE or AFTER") + } + return .command(args: tokens) + + case "LREM": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("LREM requires key, count, and element") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("LREM count must be an integer") + } + return .command(args: tokens) + + case "LPOS": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("LPOS requires key and element") + } + return .command(args: tokens) + + case "LMOVE": + guard args.count >= 4 else { + throw RedisParseError.missingArgument("LMOVE requires source, destination, LEFT|RIGHT, LEFT|RIGHT") + } + let dir1 = args[2].uppercased() + let dir2 = args[3].uppercased() + guard (dir1 == "LEFT" || dir1 == "RIGHT") && (dir2 == "LEFT" || dir2 == "RIGHT") else { + throw RedisParseError.invalidArgument("LMOVE directions must be LEFT or RIGHT") + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown list command: \(command)") + return .command(args: tokens) } } // MARK: - Set Commands - private static func parseSetCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseSetCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "SMEMBERS": guard args.count >= 1 else { throw RedisParseError.missingArgument("SMEMBERS requires a key") } return .smembers(key: args[0]) case "SADD": - guard args.count >= 2 else { throw RedisParseError.missingArgument("SADD requires key and at least one member") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SADD requires key and at least one member") + } return .sadd(key: args[0], members: Array(args.dropFirst())) case "SREM": - guard args.count >= 2 else { throw RedisParseError.missingArgument("SREM requires key and at least one member") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SREM requires key and at least one member") + } return .srem(key: args[0], members: Array(args.dropFirst())) case "SCARD": guard args.count >= 1 else { throw RedisParseError.missingArgument("SCARD requires a key") } return .scard(key: args[0]) + case "SPOP": + guard args.count >= 1 else { throw RedisParseError.missingArgument("SPOP requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("SPOP count must be an integer") + } + } + return .command(args: tokens) + + case "SRANDMEMBER": + guard args.count >= 1 else { throw RedisParseError.missingArgument("SRANDMEMBER requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("SRANDMEMBER count must be an integer") + } + } + return .command(args: tokens) + + case "SMOVE": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("SMOVE requires source, destination, and member") + } + return .command(args: tokens) + + case "SUNION": + guard !args.isEmpty else { throw RedisParseError.missingArgument("SUNION requires at least one key") } + return .command(args: tokens) + + case "SINTER": + guard !args.isEmpty else { throw RedisParseError.missingArgument("SINTER requires at least one key") } + return .command(args: tokens) + + case "SDIFF": + guard !args.isEmpty else { throw RedisParseError.missingArgument("SDIFF requires at least one key") } + return .command(args: tokens) + + case "SUNIONSTORE": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SUNIONSTORE requires destination and at least one key") + } + return .command(args: tokens) + + case "SINTERSTORE": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SINTERSTORE requires destination and at least one key") + } + return .command(args: tokens) + + case "SDIFFSTORE": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SDIFFSTORE requires destination and at least one key") + } + return .command(args: tokens) + + case "SSCAN": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("SSCAN requires key and cursor") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("SSCAN cursor must be an integer") + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown set command: \(command)") + return .command(args: tokens) } } // MARK: - Sorted Set Commands - private static func parseSortedSetCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseSortedSetCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "ZRANGE": guard args.count >= 3 else { throw RedisParseError.missingArgument("ZRANGE requires key, start, and stop") } - guard let start = Int(args[1]), let stop = Int(args[2]) else { - throw RedisParseError.invalidArgument("ZRANGE start and stop must be integers") + let start = args[1] + let stop = args[2] + // Parse optional trailing flags: BYSCORE, BYLEX, REV, WITHSCORES, LIMIT offset count + let knownFlags: Set = ["BYSCORE", "BYLEX", "REV", "WITHSCORES", "LIMIT"] + var flags: [String] = [] + var i = 3 + while i < args.count { + let upper = args[i].uppercased() + if knownFlags.contains(upper) { + flags.append(upper) + if upper == "LIMIT" { + // LIMIT requires offset and count + guard i + 2 < args.count else { + throw RedisParseError.missingArgument("LIMIT requires offset and count") + } + flags.append(args[i + 1]) + flags.append(args[i + 2]) + i += 2 + } + } + i += 1 } - let withScores = args.count > 3 && args[3].uppercased() == "WITHSCORES" - return .zrange(key: args[0], start: start, stop: stop, withScores: withScores) + return .zrange(key: args[0], start: start, stop: stop, flags: flags) case "ZADD": - guard args.count >= 3, (args.count - 1) % 2 == 0 else { + guard args.count >= 2 else { throw RedisParseError.missingArgument("ZADD requires key followed by score member pairs") } - var scoreMembers: [(Double, String)] = [] + // Skip known flags after key: NX, XX, GT, LT, CH, INCR (case-insensitive) + let zaddFlags: Set = ["NX", "XX", "GT", "LT", "CH", "INCR"] + var collectedFlags: [String] = [] var i = 1 - while i + 1 < args.count { - guard let score = Double(args[i]) else { - throw RedisParseError.invalidArgument("ZADD score must be a number: \(args[i])") + while i < args.count, zaddFlags.contains(args[i].uppercased()) { + collectedFlags.append(args[i].uppercased()) + i += 1 + } + let remaining = Array(args[i...]) + guard !remaining.isEmpty, remaining.count % 2 == 0 else { + throw RedisParseError.missingArgument("ZADD requires score member pairs after flags") + } + var scoreMembers: [(Double, String)] = [] + var j = 0 + while j + 1 < remaining.count { + guard let score = Double(remaining[j]) else { + throw RedisParseError.invalidArgument("ZADD score must be a number: \(remaining[j])") } - scoreMembers.append((score, args[i + 1])) - i += 2 + scoreMembers.append((score, remaining[j + 1])) + j += 2 } - return .zadd(key: args[0], scoreMembers: scoreMembers) + return .zadd(key: args[0], flags: collectedFlags, scoreMembers: scoreMembers) case "ZREM": - guard args.count >= 2 else { throw RedisParseError.missingArgument("ZREM requires key and at least one member") } + guard args.count >= 2 else { + throw RedisParseError.missingArgument("ZREM requires key and at least one member") + } return .zrem(key: args[0], members: Array(args.dropFirst())) case "ZCARD": guard args.count >= 1 else { throw RedisParseError.missingArgument("ZCARD requires a key") } return .zcard(key: args[0]) + case "ZSCORE": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("ZSCORE requires key and member") + } + return .command(args: tokens) + + case "ZRANGEBYSCORE": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("ZRANGEBYSCORE requires key, min, and max") + } + return .command(args: tokens) + + case "ZREVRANGE": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("ZREVRANGE requires key, start, and stop") + } + guard Int(args[1]) != nil, Int(args[2]) != nil else { + throw RedisParseError.invalidArgument("ZREVRANGE start and stop must be integers") + } + return .command(args: tokens) + + case "ZREVRANGEBYSCORE": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("ZREVRANGEBYSCORE requires key, max, and min") + } + return .command(args: tokens) + + case "ZINCRBY": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("ZINCRBY requires key, increment, and member") + } + guard Double(args[1]) != nil else { + throw RedisParseError.invalidArgument("ZINCRBY increment must be a number") + } + return .command(args: tokens) + + case "ZCOUNT": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("ZCOUNT requires key, min, and max") + } + return .command(args: tokens) + + case "ZRANK": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("ZRANK requires key and member") + } + return .command(args: tokens) + + case "ZREVRANK": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("ZREVRANK requires key and member") + } + return .command(args: tokens) + + case "ZPOPMIN": + guard args.count >= 1 else { throw RedisParseError.missingArgument("ZPOPMIN requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("ZPOPMIN count must be an integer") + } + } + return .command(args: tokens) + + case "ZPOPMAX": + guard args.count >= 1 else { throw RedisParseError.missingArgument("ZPOPMAX requires a key") } + if args.count >= 2 { + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("ZPOPMAX count must be an integer") + } + } + return .command(args: tokens) + + case "ZSCAN": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("ZSCAN requires key and cursor") + } + guard Int(args[1]) != nil else { + throw RedisParseError.invalidArgument("ZSCAN cursor must be an integer") + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown sorted set command: \(command)") + return .command(args: tokens) } } // MARK: - Stream Commands - private static func parseStreamCommand(_ command: String, args: [String]) throws -> RedisOperation { + private static func parseStreamCommand( + _ command: String, args: [String], tokens: [String] + ) throws -> RedisOperation { switch command { case "XRANGE": - guard args.count >= 3 else { throw RedisParseError.missingArgument("XRANGE requires key, start, and end") } + guard args.count >= 3 else { + throw RedisParseError.missingArgument("XRANGE requires key, start, and end") + } var count: Int? if args.count >= 5, args[3].uppercased() == "COUNT" { count = Int(args[4]) @@ -345,8 +724,74 @@ struct RedisCommandParser { guard args.count >= 1 else { throw RedisParseError.missingArgument("XLEN requires a key") } return .xlen(key: args[0]) + case "XADD": + // XADD key [NOMKSTREAM] [MAXLEN|MINID [=|~] threshold] *|ID field value [field value ...] + guard args.count >= 4 else { + throw RedisParseError.missingArgument("XADD requires key, ID, and at least one field-value pair") + } + return .command(args: tokens) + + case "XREAD": + // XREAD [COUNT count] [BLOCK ms] STREAMS key [key ...] ID [ID ...] + guard args.count >= 3 else { + throw RedisParseError.missingArgument("XREAD requires STREAMS keyword, at least one key, and an ID") + } + let hasStreams = args.contains { $0.uppercased() == "STREAMS" } + guard hasStreams else { + throw RedisParseError.missingArgument("XREAD requires the STREAMS keyword") + } + return .command(args: tokens) + + case "XREVRANGE": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("XREVRANGE requires key, end, and start") + } + return .command(args: tokens) + + case "XDEL": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("XDEL requires key and at least one ID") + } + return .command(args: tokens) + + case "XTRIM": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("XTRIM requires key, MAXLEN|MINID, and threshold") + } + return .command(args: tokens) + + case "XINFO": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("XINFO requires a subcommand and key") + } + let sub = args[0].uppercased() + guard sub == "STREAM" || sub == "GROUPS" || sub == "CONSUMERS" || sub == "HELP" else { + throw RedisParseError.invalidArgument( + "XINFO subcommand must be STREAM, GROUPS, CONSUMERS, or HELP" + ) + } + return .command(args: tokens) + + case "XGROUP": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("XGROUP requires a subcommand and key") + } + let sub = args[0].uppercased() + guard sub == "CREATE" || sub == "SETID" || sub == "DELCONSUMER" || sub == "DESTROY" else { + throw RedisParseError.invalidArgument( + "XGROUP subcommand must be CREATE, SETID, DELCONSUMER, or DESTROY" + ) + } + return .command(args: tokens) + + case "XACK": + guard args.count >= 3 else { + throw RedisParseError.missingArgument("XACK requires key, group, and at least one ID") + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown stream command: \(command)") + return .command(args: tokens) } } @@ -368,6 +813,15 @@ struct RedisCommandParser { case "FLUSHDB": return .flushdb + case "FLUSHALL": + // Optional ASYNC|SYNC flag + if let flag = args.first?.uppercased() { + guard flag == "ASYNC" || flag == "SYNC" else { + throw RedisParseError.invalidArgument("FLUSHALL flag must be ASYNC or SYNC") + } + } + return .command(args: tokens) + case "SELECT": guard args.count >= 1, let db = Int(args[0]) else { throw RedisParseError.missingArgument("SELECT requires a database index (integer)") @@ -400,30 +854,72 @@ struct RedisCommandParser { case "DISCARD": return .discard + case "AUTH": + guard !args.isEmpty else { + throw RedisParseError.missingArgument("AUTH requires a password (and optionally a username)") + } + return .command(args: tokens) + + case "OBJECT": + guard args.count >= 2 else { + throw RedisParseError.missingArgument("OBJECT requires a subcommand and key") + } + let sub = args[0].uppercased() + guard sub == "ENCODING" || sub == "REFCOUNT" || sub == "IDLETIME" + || sub == "HELP" || sub == "FREQ" else { + throw RedisParseError.invalidArgument( + "OBJECT subcommand must be ENCODING, REFCOUNT, IDLETIME, FREQ, or HELP" + ) + } + return .command(args: tokens) + default: - throw RedisParseError.invalidArgument("Unknown server command: \(command)") + return .command(args: tokens) } } // MARK: - Tokenizer - /// Split input by whitespace, respecting quoted strings (single and double quotes) + /// Split input by whitespace, respecting quoted strings (single and double quotes). + /// Escape sequences (\n, \t, \r, \\, \", \') are only decoded inside quoted strings. + /// Outside quotes, backslash is treated as a literal character (matching Redis CLI behavior). private static func tokenize(_ input: String) -> [String] { var tokens: [String] = [] var current = "" var inQuote = false var quoteChar: Character = "\"" var escapeNext = false + var escapedInsideQuote = false + var hadQuote = false for char in input { if escapeNext { - current.append(char) escapeNext = false + if escapedInsideQuote { + // Decode known escape sequences inside quoted strings + switch char { + case "n": current.append("\n") + case "t": current.append("\t") + case "r": current.append("\r") + case "\\": current.append("\\") + case "\"": current.append("\"") + case "'": current.append("'") + default: + // Unknown escape: preserve both characters + current.append("\\") + current.append(char) + } + } else { + // Outside quotes: backslash is literal + current.append("\\") + current.append(char) + } continue } if char == "\\" { escapeNext = true + escapedInsideQuote = inQuote continue } @@ -438,14 +934,16 @@ struct RedisCommandParser { if char == "\"" || char == "'" { inQuote = true + hadQuote = true quoteChar = char continue } if char.isWhitespace { - if !current.isEmpty { + if !current.isEmpty || hadQuote { tokens.append(current) current = "" + hadQuote = false } continue } @@ -453,7 +951,12 @@ struct RedisCommandParser { current.append(char) } - if !current.isEmpty { + // Handle trailing backslash + if escapeNext { + current.append("\\") + } + + if !current.isEmpty || hadQuote { tokens.append(current) } @@ -462,8 +965,8 @@ struct RedisCommandParser { // MARK: - Option Parsers - /// Parse SET command options: EX, PX, NX, XX - private static func parseSetOptions(_ args: [String]) -> RedisSetOptions? { + /// Parse SET command options: EX, PX, EXAT, PXAT, NX, XX + private static func parseSetOptions(_ args: [String]) throws -> RedisSetOptions? { guard !args.isEmpty else { return nil } var options = RedisSetOptions() @@ -474,17 +977,45 @@ struct RedisCommandParser { let arg = args[i].uppercased() switch arg { case "EX": - if i + 1 < args.count, let seconds = Int(args[i + 1]) { - options.ex = seconds - hasOption = true - i += 1 + guard i + 1 < args.count else { + throw RedisParseError.missingArgument("EX requires a value") } + guard let seconds = Int(args[i + 1]), seconds > 0 else { + throw RedisParseError.invalidArgument("EX value must be a positive integer") + } + options.ex = seconds + hasOption = true + i += 1 case "PX": - if i + 1 < args.count, let millis = Int(args[i + 1]) { - options.px = millis - hasOption = true - i += 1 + guard i + 1 < args.count else { + throw RedisParseError.missingArgument("PX requires a value") + } + guard let millis = Int(args[i + 1]), millis > 0 else { + throw RedisParseError.invalidArgument("PX value must be a positive integer") + } + options.px = millis + hasOption = true + i += 1 + case "EXAT": + guard i + 1 < args.count else { + throw RedisParseError.missingArgument("EXAT requires a value") } + guard let timestamp = Int(args[i + 1]) else { + throw RedisParseError.invalidArgument("EXAT value must be a positive integer") + } + options.exat = timestamp + hasOption = true + i += 1 + case "PXAT": + guard i + 1 < args.count else { + throw RedisParseError.missingArgument("PXAT requires a value") + } + guard let timestamp = Int(args[i + 1]) else { + throw RedisParseError.invalidArgument("PXAT value must be a positive integer") + } + options.pxat = timestamp + hasOption = true + i += 1 case "NX": options.nx = true hasOption = true @@ -501,7 +1032,7 @@ struct RedisCommandParser { } /// Parse SCAN options: MATCH pattern, COUNT count - private static func parseScanOptions(_ args: [String]) -> (pattern: String?, count: Int?) { + private static func parseScanOptions(_ args: [String]) throws -> (pattern: String?, count: Int?) { var pattern: String? var count: Int? var i = 0 @@ -515,10 +1046,14 @@ struct RedisCommandParser { i += 1 } case "COUNT": - if i + 1 < args.count { - count = Int(args[i + 1]) - i += 1 + guard i + 1 < args.count else { + throw RedisParseError.missingArgument("COUNT requires a value") + } + guard let countVal = Int(args[i + 1]) else { + throw RedisParseError.invalidArgument("COUNT must be a positive integer") } + count = countVal + i += 1 default: break } diff --git a/Plugins/RedisDriverPlugin/RedisPlugin.swift b/Plugins/RedisDriverPlugin/RedisPlugin.swift index 007dea01..b7cafa33 100644 --- a/Plugins/RedisDriverPlugin/RedisPlugin.swift +++ b/Plugins/RedisDriverPlugin/RedisPlugin.swift @@ -69,44 +69,112 @@ final class RedisPlugin: NSObject, TableProPlugin, DriverPlugin { static var statementCompletions: [CompletionEntry] { [ + // Key commands CompletionEntry(label: "GET", insertText: "GET"), CompletionEntry(label: "SET", insertText: "SET"), CompletionEntry(label: "DEL", insertText: "DEL"), CompletionEntry(label: "EXISTS", insertText: "EXISTS"), CompletionEntry(label: "KEYS", insertText: "KEYS"), + CompletionEntry(label: "GETSET", insertText: "GETSET"), + CompletionEntry(label: "GETDEL", insertText: "GETDEL"), + CompletionEntry(label: "GETEX", insertText: "GETEX"), + CompletionEntry(label: "MGET", insertText: "MGET"), + CompletionEntry(label: "MSET", insertText: "MSET"), + CompletionEntry(label: "INCR", insertText: "INCR"), + CompletionEntry(label: "DECR", insertText: "DECR"), + CompletionEntry(label: "INCRBY", insertText: "INCRBY"), + CompletionEntry(label: "DECRBY", insertText: "DECRBY"), + CompletionEntry(label: "INCRBYFLOAT", insertText: "INCRBYFLOAT"), + CompletionEntry(label: "APPEND", insertText: "APPEND"), + CompletionEntry(label: "EXPIRE", insertText: "EXPIRE"), + CompletionEntry(label: "PEXPIRE", insertText: "PEXPIRE"), + CompletionEntry(label: "EXPIREAT", insertText: "EXPIREAT"), + CompletionEntry(label: "PEXPIREAT", insertText: "PEXPIREAT"), + CompletionEntry(label: "TTL", insertText: "TTL"), + CompletionEntry(label: "PTTL", insertText: "PTTL"), + CompletionEntry(label: "PERSIST", insertText: "PERSIST"), + CompletionEntry(label: "TYPE", insertText: "TYPE"), + CompletionEntry(label: "RENAME", insertText: "RENAME"), + CompletionEntry(label: "SCAN", insertText: "SCAN"), + + // Hash commands CompletionEntry(label: "HGET", insertText: "HGET"), CompletionEntry(label: "HSET", insertText: "HSET"), CompletionEntry(label: "HGETALL", insertText: "HGETALL"), CompletionEntry(label: "HDEL", insertText: "HDEL"), + CompletionEntry(label: "HSCAN", insertText: "HSCAN"), + + // List commands CompletionEntry(label: "LPUSH", insertText: "LPUSH"), CompletionEntry(label: "RPUSH", insertText: "RPUSH"), CompletionEntry(label: "LRANGE", insertText: "LRANGE"), CompletionEntry(label: "LLEN", insertText: "LLEN"), + CompletionEntry(label: "LPOP", insertText: "LPOP"), + CompletionEntry(label: "RPOP", insertText: "RPOP"), + CompletionEntry(label: "LSET", insertText: "LSET"), + CompletionEntry(label: "LINSERT", insertText: "LINSERT"), + CompletionEntry(label: "LREM", insertText: "LREM"), + CompletionEntry(label: "LPOS", insertText: "LPOS"), + CompletionEntry(label: "LMOVE", insertText: "LMOVE"), + + // Set commands CompletionEntry(label: "SADD", insertText: "SADD"), CompletionEntry(label: "SMEMBERS", insertText: "SMEMBERS"), CompletionEntry(label: "SREM", insertText: "SREM"), CompletionEntry(label: "SCARD", insertText: "SCARD"), + CompletionEntry(label: "SPOP", insertText: "SPOP"), + CompletionEntry(label: "SRANDMEMBER", insertText: "SRANDMEMBER"), + CompletionEntry(label: "SMOVE", insertText: "SMOVE"), + CompletionEntry(label: "SUNION", insertText: "SUNION"), + CompletionEntry(label: "SINTER", insertText: "SINTER"), + CompletionEntry(label: "SDIFF", insertText: "SDIFF"), + CompletionEntry(label: "SUNIONSTORE", insertText: "SUNIONSTORE"), + CompletionEntry(label: "SINTERSTORE", insertText: "SINTERSTORE"), + CompletionEntry(label: "SDIFFSTORE", insertText: "SDIFFSTORE"), + CompletionEntry(label: "SSCAN", insertText: "SSCAN"), + + // Sorted set commands CompletionEntry(label: "ZADD", insertText: "ZADD"), CompletionEntry(label: "ZRANGE", insertText: "ZRANGE"), CompletionEntry(label: "ZREM", insertText: "ZREM"), + CompletionEntry(label: "ZCARD", insertText: "ZCARD"), CompletionEntry(label: "ZSCORE", insertText: "ZSCORE"), - CompletionEntry(label: "EXPIRE", insertText: "EXPIRE"), - CompletionEntry(label: "TTL", insertText: "TTL"), - CompletionEntry(label: "PERSIST", insertText: "PERSIST"), - CompletionEntry(label: "TYPE", insertText: "TYPE"), - CompletionEntry(label: "SCAN", insertText: "SCAN"), - CompletionEntry(label: "HSCAN", insertText: "HSCAN"), - CompletionEntry(label: "SSCAN", insertText: "SSCAN"), + CompletionEntry(label: "ZRANGEBYSCORE", insertText: "ZRANGEBYSCORE"), + CompletionEntry(label: "ZREVRANGE", insertText: "ZREVRANGE"), + CompletionEntry(label: "ZREVRANGEBYSCORE", insertText: "ZREVRANGEBYSCORE"), + CompletionEntry(label: "ZINCRBY", insertText: "ZINCRBY"), + CompletionEntry(label: "ZCOUNT", insertText: "ZCOUNT"), + CompletionEntry(label: "ZRANK", insertText: "ZRANK"), + CompletionEntry(label: "ZREVRANK", insertText: "ZREVRANK"), + CompletionEntry(label: "ZPOPMIN", insertText: "ZPOPMIN"), + CompletionEntry(label: "ZPOPMAX", insertText: "ZPOPMAX"), CompletionEntry(label: "ZSCAN", insertText: "ZSCAN"), + + // Stream commands + CompletionEntry(label: "XRANGE", insertText: "XRANGE"), + CompletionEntry(label: "XREVRANGE", insertText: "XREVRANGE"), + CompletionEntry(label: "XLEN", insertText: "XLEN"), + CompletionEntry(label: "XADD", insertText: "XADD"), + CompletionEntry(label: "XREAD", insertText: "XREAD"), + CompletionEntry(label: "XDEL", insertText: "XDEL"), + CompletionEntry(label: "XTRIM", insertText: "XTRIM"), + CompletionEntry(label: "XINFO", insertText: "XINFO"), + CompletionEntry(label: "XGROUP", insertText: "XGROUP"), + CompletionEntry(label: "XACK", insertText: "XACK"), + + // Server commands + CompletionEntry(label: "PING", insertText: "PING"), CompletionEntry(label: "INFO", insertText: "INFO"), CompletionEntry(label: "DBSIZE", insertText: "DBSIZE"), CompletionEntry(label: "FLUSHDB", insertText: "FLUSHDB"), + CompletionEntry(label: "FLUSHALL", insertText: "FLUSHALL"), CompletionEntry(label: "SELECT", insertText: "SELECT"), - CompletionEntry(label: "INCR", insertText: "INCR"), - CompletionEntry(label: "DECR", insertText: "DECR"), - CompletionEntry(label: "APPEND", insertText: "APPEND"), - CompletionEntry(label: "MGET", insertText: "MGET"), - CompletionEntry(label: "MSET", insertText: "MSET") + CompletionEntry(label: "CONFIG", insertText: "CONFIG"), + CompletionEntry(label: "AUTH", insertText: "AUTH"), + CompletionEntry(label: "OBJECT", insertText: "OBJECT"), + CompletionEntry(label: "MULTI", insertText: "MULTI"), + CompletionEntry(label: "EXEC", insertText: "EXEC"), + CompletionEntry(label: "DISCARD", insertText: "DISCARD"), ] } diff --git a/Plugins/RedisDriverPlugin/RedisPluginConnection.swift b/Plugins/RedisDriverPlugin/RedisPluginConnection.swift index 2d77a488..e36e5c1f 100644 --- a/Plugins/RedisDriverPlugin/RedisPluginConnection.swift +++ b/Plugins/RedisDriverPlugin/RedisPluginConnection.swift @@ -23,6 +23,7 @@ struct RedisSSLConfig { var caCertificatePath: String = "" var clientCertificatePath: String = "" var clientKeyPath: String = "" + var verifyPeer: Bool = true init() {} @@ -32,6 +33,7 @@ struct RedisSSLConfig { self.caCertificatePath = additionalFields["sslCaCertPath"] ?? "" self.clientCertificatePath = additionalFields["sslClientCertPath"] ?? "" self.clientKeyPath = additionalFields["sslClientKeyPath"] ?? "" + self.verifyPeer = (additionalFields["sslVerifyPeer"] ?? "true") == "true" } } @@ -99,7 +101,10 @@ final class RedisPluginConnection: @unchecked Sendable { #if canImport(CRedis) private static let initOnce: Void = { - redisInitOpenSSL() + let result = redisInitOpenSSL() + if result != REDIS_OK { + logger.warning("redisInitOpenSSL failed with code \(result)") + } }() private var context: UnsafeMutablePointer? @@ -109,6 +114,7 @@ final class RedisPluginConnection: @unchecked Sendable { private let queue = DispatchQueue(label: "com.TablePro.redis.plugin", qos: .userInitiated) private let host: String private let port: Int + private let username: String? private let password: String? private let database: Int private let sslConfig: RedisSSLConfig @@ -144,12 +150,14 @@ final class RedisPluginConnection: @unchecked Sendable { init( host: String, port: Int, + username: String? = nil, password: String?, database: Int = 0, sslConfig: RedisSSLConfig = RedisSSLConfig() ) { self.host = host self.port = port + self.username = username self.password = password self.database = database self.sslConfig = sslConfig @@ -165,15 +173,12 @@ final class RedisPluginConnection: @unchecked Sendable { sslContext = nil stateLock.unlock() - let cleanupQueue = queue + // Dispatch cleanup to the serial queue to ensure in-flight commands complete first if handle != nil || ssl != nil { + let cleanupQueue = queue cleanupQueue.async { - if let handle = handle { - redisFree(handle) - } - if let ssl = ssl { - redisFreeSSLContext(ssl) - } + if let handle { redisFree(handle) } + if let ssl { redisFreeSSLContext(ssl) } } } #endif @@ -187,7 +192,8 @@ final class RedisPluginConnection: @unchecked Sendable { try await pluginDispatchAsync(on: queue) { [self] in logger.debug("Connecting to Redis at \(self.host):\(self.port)") - guard let ctx = redisConnect(host, Int32(port)) else { + let connectTimeout = timeval(tv_sec: 10, tv_usec: 0) + guard let ctx = redisConnectWithTimeout(host, Int32(port), connectTimeout) else { logger.error("Failed to create Redis context") throw RedisPluginError.connectionFailed } @@ -202,6 +208,10 @@ final class RedisPluginConnection: @unchecked Sendable { throw RedisPluginError(code: errCode, message: errMsg) } + let commandTimeout = timeval(tv_sec: 30, tv_usec: 0) + redisSetTimeout(ctx, commandTimeout) + redisEnableKeepAliveWithInterval(ctx, 60) + self.context = ctx if sslConfig.isEnabled { @@ -216,7 +226,13 @@ final class RedisPluginConnection: @unchecked Sendable { if let password = password, !password.isEmpty { do { - let reply = try executeCommandSync(["AUTH", password]) + let authArgs: [String] + if let username = username, !username.isEmpty { + authArgs = ["AUTH", username, password] + } else { + authArgs = ["AUTH", password] + } + let reply = try executeCommandSync(authArgs) if case .error(let msg) = reply { redisFree(ctx) self.context = nil @@ -321,7 +337,7 @@ final class RedisPluginConnection: @unchecked Sendable { } } - private func resetCancellation() { + func resetCancellation() { stateLock.lock() _isCancelled = false stateLock.unlock() @@ -345,11 +361,16 @@ final class RedisPluginConnection: @unchecked Sendable { func executeCommand(_ args: [String]) async throws -> RedisReply { #if canImport(CRedis) - resetCancellation() return try await pluginDispatchAsync(on: queue) { [self] in - guard !isShuttingDown, context != nil else { + guard !isShuttingDown else { + throw RedisPluginError.notConnected + } + stateLock.lock() + guard context != nil else { + stateLock.unlock() throw RedisPluginError.notConnected } + stateLock.unlock() try checkCancelled() let result = try executeCommandSync(args) try checkCancelled() @@ -362,11 +383,16 @@ final class RedisPluginConnection: @unchecked Sendable { func executePipeline(_ commands: [[String]]) async throws -> [RedisReply] { #if canImport(CRedis) - resetCancellation() return try await pluginDispatchAsync(on: queue) { [self] in - guard !isShuttingDown, context != nil else { + guard !isShuttingDown else { throw RedisPluginError.notConnected } + stateLock.lock() + guard context != nil else { + stateLock.unlock() + throw RedisPluginError.notConnected + } + stateLock.unlock() try checkCancelled() let results = try executePipelineSync(commands) try checkCancelled() @@ -381,11 +407,16 @@ final class RedisPluginConnection: @unchecked Sendable { func selectDatabase(_ index: Int) async throws { #if canImport(CRedis) - resetCancellation() try await pluginDispatchAsync(on: queue) { [self] in - guard !isShuttingDown, context != nil else { + guard !isShuttingDown else { throw RedisPluginError.notConnected } + stateLock.lock() + guard context != nil else { + stateLock.unlock() + throw RedisPluginError.notConnected + } + stateLock.unlock() try checkCancelled() let reply = try executeCommandSync(["SELECT", String(index)]) if case .error(let msg) = reply { @@ -417,27 +448,44 @@ private extension RedisPluginConnection { let clientKey: UnsafePointer? = sslConfig.clientKeyPath.isEmpty ? nil : (sslConfig.clientKeyPath as NSString).utf8String + let sniHostname: UnsafePointer? = (host as NSString).utf8String - guard let ssl = redisCreateSSLContext(caCert, nil, clientCert, clientKey, nil, &sslError) else { + var options = redisSSLOptions() + options.cacert_filename = caCert + options.capath = nil + options.cert_filename = clientCert + options.private_key_filename = clientKey + options.server_name = sniHostname + options.verify_mode = sslConfig.verifyPeer ? REDIS_SSL_VERIFY_PEER : REDIS_SSL_VERIFY_NONE + + guard let ssl = redisCreateSSLContextWithOptions(&options, &sslError) else { let errCode = Int(sslError.rawValue) - throw RedisPluginError(code: errCode, message: "Failed to create SSL context (error \(errCode))") + throw RedisPluginError( + code: errCode, + message: "Failed to create SSL context (error \(errCode))" + ) } - self.sslContext = ssl - let result = redisInitiateSSLWithContext(ctx, ssl) if result != REDIS_OK { + redisFreeSSLContext(ssl) let errMsg = withUnsafePointer(to: &ctx.pointee.errstr) { ptr in ptr.withMemoryRebound(to: CChar.self, capacity: 128) { String(cString: $0) } } throw RedisPluginError(code: Int(result), message: "SSL handshake failed: \(errMsg)") } + self.sslContext = ssl logger.debug("SSL connection established") } func executeCommandSync(_ args: [String]) throws -> RedisReply { - guard let ctx = context else { throw RedisPluginError.notConnected } + stateLock.lock() + guard let ctx = context else { + stateLock.unlock() + throw RedisPluginError.notConnected + } + stateLock.unlock() let argc = Int32(args.count) let lengths = args.map { $0.utf8.count } @@ -461,7 +509,12 @@ private extension RedisPluginConnection { } func executePipelineSync(_ commands: [[String]]) throws -> [RedisReply] { - guard let ctx = context else { throw RedisPluginError.notConnected } + stateLock.lock() + guard let ctx = context else { + stateLock.unlock() + throw RedisPluginError.notConnected + } + stateLock.unlock() guard !commands.isEmpty else { return [] } var appendedCount = 0 @@ -479,6 +532,7 @@ private extension RedisPluginConnection { let errMsg = withUnsafePointer(to: &ctx.pointee.errstr) { ptr in ptr.withMemoryRebound(to: CChar.self, capacity: 128) { String(cString: $0) } } + markDisconnected() throw RedisPluginError(code: Int(ctx.pointee.err), message: errMsg) } } @@ -500,6 +554,7 @@ private extension RedisPluginConnection { freeReplyObject(d) } } + markDisconnected() throw RedisPluginError(code: Int(ctx.pointee.err), message: errMsg) } let replyPtr = reply.assumingMemoryBound(to: redisReply.self) @@ -510,6 +565,17 @@ private extension RedisPluginConnection { return replies } + func markDisconnected() { + stateLock.lock() + let handle = context + context = nil + _isConnected = false + stateLock.unlock() + #if canImport(CRedis) + if let handle { redisFree(handle) } + #endif + } + func withArgvPointers( args: [String], lengths: [Int], @@ -517,8 +583,18 @@ private extension RedisPluginConnection { ) rethrows -> T { let count = args.count - let cStrings = args.map { strdup($0) } - defer { cStrings.forEach { free($0) } } + let cStrings: [UnsafeMutablePointer] = args.map { arg in + let utf8 = Array(arg.utf8) + let ptr = UnsafeMutablePointer.allocate(capacity: utf8.count + 1) + if let base = utf8.withUnsafeBufferPointer({ $0.baseAddress }) { + base.withMemoryRebound(to: CChar.self, capacity: utf8.count) { src in + ptr.initialize(from: src, count: utf8.count) + } + } + ptr[utf8.count] = 0 + return ptr + } + defer { cStrings.forEach { $0.deallocate() } } let argv = UnsafeMutablePointer?>.allocate(capacity: count) let argvlen = UnsafeMutablePointer.allocate(capacity: count) @@ -651,7 +727,12 @@ private extension RedisPluginConnection { } func fetchServerVersionSync() -> String? { - guard context != nil else { return nil } + stateLock.lock() + guard context != nil else { + stateLock.unlock() + return nil + } + stateLock.unlock() do { let reply = try executeCommandSync(["INFO", "server"]) if case .string(let info) = reply { @@ -664,7 +745,7 @@ private extension RedisPluginConnection { } func parseVersionFromInfo(_ info: String) -> String? { - for line in info.components(separatedBy: "\r\n") { + for line in info.components(separatedBy: .newlines) { let trimmed = line.trimmingCharacters(in: .whitespaces) if trimmed.hasPrefix("redis_version:") { let value = trimmed.dropFirst("redis_version:".count) diff --git a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift index 5a2a9ab1..3c49a897 100644 --- a/Plugins/RedisDriverPlugin/RedisPluginDriver.swift +++ b/Plugins/RedisDriverPlugin/RedisPluginDriver.swift @@ -19,6 +19,9 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { private static let maxScanKeys = PluginRowLimits.defaultMax + private var cachedScanPattern: String? + private var cachedScanKeys: [String]? + var serverVersion: String? { redisConnection?.serverVersion() } @@ -42,6 +45,7 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { let conn = RedisPluginConnection( host: config.host, port: config.port, + username: config.username.isEmpty ? nil : config.username, password: config.password.isEmpty ? nil : config.password, database: redisDb, sslConfig: sslConfig @@ -54,6 +58,8 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { func disconnect() { redisConnection?.disconnect() redisConnection = nil + cachedScanPattern = nil + cachedScanKeys = nil } func ping() async throws { @@ -70,28 +76,15 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { func execute(query: String) async throws -> PluginQueryResult { let startTime = Date() + cachedScanPattern = nil + cachedScanKeys = nil + redisConnection?.resetCancellation() guard let conn = redisConnection else { throw RedisPluginError.notConnected } - var trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) - - if trimmed.caseInsensitiveCompare("SELECT") == .orderedSame { - trimmed = "SELECT 0" - } - - // Health monitor sends "SELECT 1" as a ping — intercept and remap to PING. - if trimmed.lowercased() == "select 1" { - _ = try await conn.executeCommand(["PING"]) - return PluginQueryResult( - columns: ["ok"], - columnTypeNames: ["Int32"], - rows: [["1"]], - rowsAffected: 0, - executionTime: Date().timeIntervalSince(startTime) - ) - } + let trimmed = query.trimmingCharacters(in: .whitespacesAndNewlines) let operation = try RedisCommandParser.parse(trimmed) return try await executeOperation(operation, connection: conn, startTime: startTime) @@ -129,6 +122,7 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { func fetchRows(query: String, offset: Int, limit: Int) async throws -> PluginQueryResult { let startTime = Date() + redisConnection?.resetCancellation() guard let conn = redisConnection else { throw RedisPluginError.notConnected @@ -139,7 +133,18 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { switch operation { case .scan(_, let pattern, _): - let allKeys = try await scanAllKeys(connection: conn, pattern: pattern, maxKeys: Self.maxScanKeys) + let dbIndex = conn.currentDatabase() + let cacheKey = "\(dbIndex):\(pattern ?? "*")" + let allKeys: [String] + if cachedScanPattern == cacheKey, let cached = cachedScanKeys { + allKeys = cached + } else { + allKeys = try await scanAllKeys( + connection: conn, pattern: pattern, maxKeys: Self.maxScanKeys + ) + cachedScanPattern = cacheKey + cachedScanKeys = allKeys + } let pageEnd = min(offset + limit, allKeys.count) guard offset < allKeys.count else { return buildEmptyKeyResult(startTime: startTime) @@ -165,37 +170,46 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { // MARK: - Schema Operations func fetchTables(schema: String?) async throws -> [PluginTableInfo] { + redisConnection?.resetCancellation() guard let conn = redisConnection else { throw RedisPluginError.notConnected } + // Parse key counts from INFO keyspace let result = try await conn.executeCommand(["INFO", "keyspace"]) - guard let info = result.stringValue else { return [] } + var keyCounts: [String: Int] = [:] + if let info = result.stringValue { + for line in info.components(separatedBy: .newlines) { + let trimmed = line.trimmingCharacters(in: .whitespacesAndNewlines) + guard trimmed.hasPrefix("db"), + let colonIndex = trimmed.firstIndex(of: ":") else { continue } - var databases: [PluginTableInfo] = [] - for line in info.components(separatedBy: .newlines) { - let trimmed = line.trimmingCharacters(in: .whitespacesAndNewlines) - guard trimmed.hasPrefix("db"), - let colonIndex = trimmed.firstIndex(of: ":") else { continue } - - let dbName = String(trimmed[trimmed.startIndex ..< colonIndex]) - let statsStr = String(trimmed[trimmed.index(after: colonIndex)...]) - - var keyCount = 0 - for stat in statsStr.components(separatedBy: ",") { - let parts = stat.components(separatedBy: "=") - if parts.count == 2, parts[0] == "keys", let count = Int(parts[1]) { - keyCount = count - break + let dbName = String(trimmed[trimmed.startIndex ..< colonIndex]) + let statsStr = String(trimmed[trimmed.index(after: colonIndex)...]) + + for stat in statsStr.components(separatedBy: ",") { + let parts = stat.components(separatedBy: "=") + if parts.count == 2, parts[0] == "keys", let count = Int(parts[1]) { + keyCounts[dbName] = count + break + } } } + } - if keyCount > 0 { - databases.append(PluginTableInfo(name: dbName, type: "TABLE", rowCount: keyCount)) - } + // Get total database count from CONFIG GET databases + let configResult = try await conn.executeCommand(["CONFIG", "GET", "databases"]) + var maxDatabases = 16 + if let array = configResult.stringArrayValue, array.count >= 2, let count = Int(array[1]) { + maxDatabases = count } - return databases + // Return all databases (including empty ones) so users can navigate to them + return (0 ..< maxDatabases).map { index in + let dbName = "db\(index)" + let keyCount = keyCounts[dbName] ?? 0 + return PluginTableInfo(name: dbName, type: "TABLE", rowCount: keyCount) + } } func fetchColumns(table: String, schema: String?) async throws -> [PluginColumnInfo] { @@ -291,7 +305,15 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { } func fetchDatabases() async throws -> [String] { - [] + guard let conn = redisConnection else { + throw RedisPluginError.notConnected + } + let result = try await conn.executeCommand(["CONFIG", "GET", "databases"]) + var maxDatabases = 16 + if let array = result.stringArrayValue, array.count >= 2, let count = Int(array[1]) { + maxDatabases = count + } + return (0 ..< maxDatabases).map { "db\($0)" } } func fetchDatabaseMetadata(_ database: String) async throws -> PluginDatabaseMetadata { @@ -358,13 +380,31 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { // MARK: - Database Switching func switchDatabase(to database: String) async throws { + redisConnection?.resetCancellation() guard let conn = redisConnection else { throw RedisPluginError.notConnected } - guard let dbIndex = Int(database) ?? Int(database.dropFirst(2)) else { + let dbIndex: Int + if let idx = Int(database) { + dbIndex = idx + } else if database.lowercased().hasPrefix("db"), let idx = Int(database.dropFirst(2)) { + dbIndex = idx + } else { throw RedisPluginError(code: 0, message: "Invalid database index: \(database)") } try await conn.selectDatabase(dbIndex) } + // MARK: - Table Operations + + func truncateTableStatements(table: String, schema: String?, cascade: Bool) -> [String]? { + ["FLUSHDB"] + } + + func dropObjectStatement(name: String, objectType: String, schema: String?, cascade: Bool) -> String? { + // Redis databases are pre-allocated and cannot be dropped. + // Return empty string to prevent adapter from synthesizing SQL DROP. + "" + } + // MARK: - EXPLAIN func buildExplainQuery(_ sql: String) -> String? { @@ -390,7 +430,7 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { return k case .sadd(let k, _), .srem(let k, _): return k - case .zadd(let k, _), .zrem(let k, _): + case .zadd(let k, _, _), .zrem(let k, _): return k case .del(let keys) where keys.count == 1: return keys[0] @@ -400,7 +440,8 @@ final class RedisPluginDriver: PluginDatabaseDriver, @unchecked Sendable { }() guard let key else { return nil } - return "DEBUG OBJECT \(key)" + let quoted = key.contains(" ") || key.contains("\"") ? "\"\(key.replacingOccurrences(of: "\"", with: "\\\""))\"" : key + return "DEBUG OBJECT \(quoted)" } // MARK: - View Templates @@ -564,6 +605,8 @@ private extension RedisPluginDriver { if let opts = options { if let ex = opts.ex { args += ["EX", String(ex)] } if let px = opts.px { args += ["PX", String(px)] } + if let exat = opts.exat { args += ["EXAT", String(exat)] } + if let pxat = opts.pxat { args += ["PXAT", String(pxat)] } if opts.nx { args.append("NX") } if opts.xx { args.append("XX") } } @@ -644,7 +687,10 @@ private extension RedisPluginDriver { return buildStatusResult(success ? "OK" : "Key not found or no TTL", startTime: startTime) case .rename(let key, let newKey): - _ = try await conn.executeCommand(["RENAME", key, newKey]) + let reply = try await conn.executeCommand(["RENAME", key, newKey]) + if case .error(let msg) = reply { + throw RedisPluginError(code: 0, message: "RENAME failed: \(msg)") + } return buildStatusResult("OK", startTime: startTime) case .exists(let keys): @@ -729,7 +775,7 @@ private extension RedisPluginDriver { switch operation { case .lrange(let key, let start, let stop): let result = try await conn.executeCommand(["LRANGE", key, String(start), String(stop)]) - return buildListResult(result, startTime: startTime) + return buildListResult(result, startOffset: start, startTime: startTime) case .lpush(let key, let values): let args = ["LPUSH", key] + values @@ -831,24 +877,38 @@ private extension RedisPluginDriver { startTime: Date ) async throws -> PluginQueryResult { switch operation { - case .zrange(let key, let start, let stop, let withScores): - var args = ["ZRANGE", key, String(start), String(stop)] - if withScores { args.append("WITHSCORES") } + case .zrange(let key, let start, let stop, let flags): + var args = ["ZRANGE", key, start, stop] + args += flags + let withScores = flags.contains("WITHSCORES") let result = try await conn.executeCommand(args) return buildSortedSetResult(result, withScores: withScores, startTime: startTime) - case .zadd(let key, let scoreMembers): + case .zadd(let key, let flags, let scoreMembers): var args = ["ZADD", key] + args += flags for (score, member) in scoreMembers { args += [String(score), member] } let result = try await conn.executeCommand(args) - let added = result.intValue ?? 0 + if flags.contains("INCR") { + // INCR mode returns the new score (or nil for NX miss) + let scoreStr = result.stringValue ?? "nil" + return PluginQueryResult( + columns: ["score"], + columnTypeNames: ["String"], + rows: [[scoreStr]], + rowsAffected: 0, + executionTime: Date().timeIntervalSince(startTime) + ) + } + let count = result.intValue ?? 0 + let columnName = flags.contains("CH") ? "changed" : "added" return PluginQueryResult( - columns: ["added"], + columns: [columnName], columnTypeNames: ["Int64"], - rows: [[String(added)]], - rowsAffected: added, + rows: [[String(count)]], + rowsAffected: count, executionTime: Date().timeIntervalSince(startTime) ) @@ -957,7 +1017,9 @@ private extension RedisPluginDriver { return buildStatusResult("OK", startTime: startTime) case .select(let database): - _ = try await conn.executeCommand(["SELECT", String(database)]) + try await conn.selectDatabase(database) + cachedScanPattern = nil + cachedScanKeys = nil return buildStatusResult("OK", startTime: startTime) case .configGet(let parameter): @@ -1085,22 +1147,59 @@ private extension RedisPluginDriver { return buildEmptyKeyResult(startTime: startTime) } - var commands: [[String]] = [] - commands.reserveCapacity(keys.count * 2) + var typeAndTtlCommands: [[String]] = [] + typeAndTtlCommands.reserveCapacity(keys.count * 2) for key in keys { - commands.append(["TYPE", key]) - commands.append(["TTL", key]) + typeAndTtlCommands.append(["TYPE", key]) + typeAndTtlCommands.append(["TTL", key]) + } + let typeAndTtlReplies = try await conn.executePipeline(typeAndTtlCommands) + + var typeNames: [String] = [] + typeNames.reserveCapacity(keys.count) + var ttlValues: [Int] = [] + ttlValues.reserveCapacity(keys.count) + for i in 0 ..< keys.count { + let typeName = (typeAndTtlReplies[i * 2].stringValue ?? "unknown").uppercased() + let ttl = typeAndTtlReplies[i * 2 + 1].intValue ?? -1 + typeNames.append(typeName) + ttlValues.append(ttl) } - let replies = try await conn.executePipeline(commands) - var rows: [[String?]] = [] + var previewCommands: [[String]] = [] + previewCommands.reserveCapacity(keys.count) + var previewCommandIndices: [Int] = [] + previewCommandIndices.reserveCapacity(keys.count) + for (i, key) in keys.enumerated() { - let typeName = (replies[i * 2].stringValue ?? "unknown").uppercased() - let ttl = replies[i * 2 + 1].intValue ?? -1 - let ttlStr = String(ttl) + let command: [String]? = previewCommandForType(typeNames[i], key: key) + if let command { + previewCommandIndices.append(previewCommands.count) + previewCommands.append(command) + } else { + previewCommandIndices.append(-1) + } + } - let value = try await fetchValuePreview(key: key, type: typeName, connection: conn) - rows.append([key, typeName, ttlStr, value]) + var previewReplies: [RedisReply] = [] + if !previewCommands.isEmpty { + previewReplies = try await conn.executePipeline(previewCommands) + } + + var rows: [[String?]] = [] + rows.reserveCapacity(keys.count) + for (i, key) in keys.enumerated() { + let ttlStr = String(ttlValues[i]) + let pipelineIndex = previewCommandIndices[i] + let preview: String? + if pipelineIndex >= 0, pipelineIndex < previewReplies.count { + preview = formatPreviewReply( + previewReplies[pipelineIndex], type: typeNames[i] + ) + } else { + preview = nil + } + rows.append([key, typeNames[i], ttlStr, preview]) } return PluginQueryResult( @@ -1113,47 +1212,64 @@ private extension RedisPluginDriver { ) } - func fetchValuePreview(key: String, type: String, connection conn: RedisPluginConnection) async throws -> String? { + func previewCommandForType(_ type: String, key: String) -> [String]? { switch type.lowercased() { case "string": - let result = try await conn.executeCommand(["GET", key]) - return truncatePreview(result.stringValue) + return ["GET", key] + case "hash": + return ["HSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "list": + return ["LRANGE", key, "0", String(Self.previewLimit - 1)] + case "set": + return ["SSCAN", key, "0", "COUNT", String(Self.previewLimit)] + case "zset": + return ["ZRANGE", key, "0", String(Self.previewLimit - 1), "WITHSCORES"] + case "stream": + return ["XREVRANGE", key, "+", "-", "COUNT", "5"] + default: + return nil + } + } + + func formatPreviewReply(_ reply: RedisReply, type: String) -> String? { + switch type.lowercased() { + case "string": + return truncatePreview(reply.stringValue) case "hash": - let result = try await conn.executeCommand(["HSCAN", key, "0", "COUNT", String(Self.previewLimit)]) let array: [String] - if case .array(let scanResult) = result, + if case .array(let scanResult) = reply, scanResult.count == 2, let items = scanResult[1].stringArrayValue { array = items - } else if let items = result.stringArrayValue, !items.isEmpty { + } else if let items = reply.stringArrayValue, !items.isEmpty { array = items } else { return "{}" } guard !array.isEmpty else { return "{}" } var pairs: [String] = [] - var i = 0 - while i + 1 < array.count { - pairs.append("\"\(escapeJsonString(array[i]))\":\"\(escapeJsonString(array[i + 1]))\"") - i += 2 + var idx = 0 + while idx + 1 < array.count { + pairs.append( + "\"\(escapeJsonString(array[idx]))\":\"\(escapeJsonString(array[idx + 1]))\"" + ) + idx += 2 } return truncatePreview("{\(pairs.joined(separator: ","))}") case "list": - let result = try await conn.executeCommand(["LRANGE", key, "0", String(Self.previewLimit - 1)]) - guard let items = result.stringArrayValue else { return "[]" } + guard let items = reply.stringArrayValue else { return "[]" } let quoted = items.map { "\"\(escapeJsonString($0))\"" } return truncatePreview("[\(quoted.joined(separator: ", "))]") case "set": - let result = try await conn.executeCommand(["SSCAN", key, "0", "COUNT", String(Self.previewLimit)]) let members: [String] - if case .array(let scanResult) = result, + if case .array(let scanResult) = reply, scanResult.count == 2, let items = scanResult[1].stringArrayValue { members = items - } else if let items = result.stringArrayValue { + } else if let items = reply.stringArrayValue { members = items } else { return "[]" @@ -1162,15 +1278,37 @@ private extension RedisPluginDriver { return truncatePreview("[\(quoted.joined(separator: ", "))]") case "zset": - let result = try await conn.executeCommand(["ZRANGE", key, "0", String(Self.previewLimit - 1)]) - guard let members = result.stringArrayValue else { return "[]" } - let quoted = members.map { "\"\(escapeJsonString($0))\"" } - return truncatePreview("[\(quoted.joined(separator: ", "))]") + // Parse WITHSCORES result: alternating member, score pairs + guard let items = reply.stringArrayValue, !items.isEmpty else { return "[]" } + var pairs: [String] = [] + var i = 0 + while i + 1 < items.count { + pairs.append("\(items[i]):\(items[i + 1])") + i += 2 + } + return truncatePreview(pairs.joined(separator: ", ")) case "stream": - let lenResult = try await conn.executeCommand(["XLEN", key]) - let len = lenResult.intValue ?? 0 - return "(\(len) entries)" + // Parse XREVRANGE result: array of [id, [field, value, ...]] entries + guard let entries = reply.arrayValue, !entries.isEmpty else { + return "(0 entries)" + } + var entryStrings: [String] = [] + for entry in entries { + guard let parts = entry.arrayValue, parts.count >= 2, + let entryId = parts[0].stringValue, + let fields = parts[1].stringArrayValue else { + continue + } + var fieldPairs: [String] = [] + var j = 0 + while j + 1 < fields.count { + fieldPairs.append("\(fields[j])=\(fields[j + 1])") + j += 2 + } + entryStrings.append("\(entryId): \(fieldPairs.joined(separator: ", "))") + } + return truncatePreview(entryStrings.joined(separator: "; ")) default: return nil @@ -1179,22 +1317,28 @@ private extension RedisPluginDriver { func truncatePreview(_ value: String?) -> String? { guard let value else { return nil } - if value.count > Self.previewMaxChars { - return String(value.prefix(Self.previewMaxChars)) + "..." + let nsValue = value as NSString + if nsValue.length > Self.previewMaxChars { + return nsValue.substring(to: Self.previewMaxChars) + "..." } return value } func escapeJsonString(_ str: String) -> String { var result = "" - for char in str { - switch char { + for scalar in str.unicodeScalars { + switch scalar { case "\\": result += "\\\\" case "\"": result += "\\\"" case "\n": result += "\\n" case "\r": result += "\\r" case "\t": result += "\\t" - default: result.append(char) + default: + if scalar.value < 0x20 { + result += String(format: "\\u%04X", scalar.value) + } else { + result += String(scalar) + } } } return result @@ -1317,7 +1461,7 @@ private extension RedisPluginDriver { ) } - func buildListResult(_ result: RedisReply, startTime: Date) -> PluginQueryResult { + func buildListResult(_ result: RedisReply, startOffset: Int = 0, startTime: Date) -> PluginQueryResult { guard let array = result.stringArrayValue else { return PluginQueryResult( columns: ["Index", "Value"], @@ -1329,7 +1473,7 @@ private extension RedisPluginDriver { } let rows = array.enumerated().map { index, value -> [String?] in - [String(index), value] + [String(startOffset + index), value] } return PluginQueryResult( diff --git a/Plugins/RedisDriverPlugin/RedisQueryBuilder.swift b/Plugins/RedisDriverPlugin/RedisQueryBuilder.swift index d336a955..78371993 100644 --- a/Plugins/RedisDriverPlugin/RedisQueryBuilder.swift +++ b/Plugins/RedisDriverPlugin/RedisQueryBuilder.swift @@ -59,12 +59,15 @@ struct RedisQueryBuilder { return "SCAN 0 MATCH \"\(pattern)\" COUNT \(limit)" } - /// Build a count command for a namespace + /// Build a count command for a namespace. + /// When a namespace filter is active, DBSIZE would overcount because it + /// returns the total key count for the entire database. We use a SCAN-based + /// approach instead; note the returned count is approximate since SCAN may + /// return duplicates across iterations and new keys may appear mid-scan. func buildCountQuery(namespace: String) -> String { if namespace.isEmpty { return "DBSIZE" } - // For a specific namespace, we use SCAN to count matching keys return "SCAN 0 MATCH \"\(namespace)*\" COUNT 10000" } diff --git a/Plugins/RedisDriverPlugin/RedisStatementGenerator.swift b/Plugins/RedisDriverPlugin/RedisStatementGenerator.swift index e3619193..21593b16 100644 --- a/Plugins/RedisDriverPlugin/RedisStatementGenerator.swift +++ b/Plugins/RedisDriverPlugin/RedisStatementGenerator.swift @@ -26,6 +26,11 @@ struct RedisStatementGenerator { columns.firstIndex(of: "Value") } + /// Index of the "Type" column + private var typeColumnIndex: Int? { + columns.firstIndex(of: "Type") + } + /// Index of the "TTL" column private var ttlColumnIndex: Int? { columns.firstIndex(of: "TTL") @@ -80,22 +85,27 @@ struct RedisStatementGenerator { var key: String? var value: String? + var type: String? var ttl: Int? if let values = insertedRowData[change.rowIndex] { if let ki = keyColumnIndex, ki < values.count { key = values[ki] } + if let ti = typeColumnIndex, ti < values.count { + type = values[ti] + } if let vi = valueColumnIndex, vi < values.count { value = values[vi] } - if let ti = ttlColumnIndex, ti < values.count, let ttlStr = values[ti] { + if let ttli = ttlColumnIndex, ttli < values.count, let ttlStr = values[ttli] { ttl = Int(ttlStr) } } else { for cellChange in change.cellChanges { switch cellChange.columnName { case "Key": key = cellChange.newValue + case "Type": type = cellChange.newValue case "Value": value = cellChange.newValue case "TTL": if let ttlStr = cellChange.newValue { ttl = Int(ttlStr) } @@ -110,7 +120,7 @@ struct RedisStatementGenerator { } let v = value ?? "" - let cmd = "SET \(escapeArgument(k)) \(escapeArgument(v))" + let cmd = generateInsertCommand(key: k, value: v, type: type?.lowercased()) statements.append((statement: cmd, parameters: [])) if let ttlSeconds = ttl, ttlSeconds > 0 { @@ -121,6 +131,31 @@ struct RedisStatementGenerator { return statements } + /// Generate the appropriate Redis command based on the data type + private func generateInsertCommand(key: String, value: String, type: String?) -> String { + switch type { + case "hash": + // Try to parse value as JSON object for HSET key field1 val1 ... + if let data = value.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] { + var args = "HSET \(escapeArgument(key))" + for (field, val) in json { + args += " \(escapeArgument(field)) \(escapeArgument(String(describing: val)))" + } + return args + } + return "HSET \(escapeArgument(key)) value \(escapeArgument(value))" + case "list": + return "RPUSH \(escapeArgument(key)) \(escapeArgument(value))" + case "set": + return "SADD \(escapeArgument(key)) \(escapeArgument(value))" + case "zset": + return "ZADD \(escapeArgument(key)) 0 \(escapeArgument(value))" + default: + return "SET \(escapeArgument(key)) \(escapeArgument(value))" + } + } + // MARK: - UPDATE private func generateUpdate(for change: PluginRowChange) -> [(statement: String, parameters: [String?])] { @@ -148,12 +183,30 @@ struct RedisStatementGenerator { return key }() + // Determine the Redis type from the original row data + let redisType: String? = { + guard let ti = typeColumnIndex, + let originalRow = change.originalRow, + ti < originalRow.count else { + return nil + } + return originalRow[ti] + }() + for cellChange in change.cellChanges { switch cellChange.columnName { case "Key": continue // Already handled above case "Value": if let newValue = cellChange.newValue { + let typeLower = redisType?.lowercased() ?? "string" + if typeLower != "string" { + // Non-string types show a preview; blindly SET would destroy the data structure + Self.logger.warning( + "Skipping Value update for \(typeLower) key '\(effectiveKey)' - use query editor" + ) + continue + } let cmd = "SET \(escapeArgument(effectiveKey)) \(escapeArgument(newValue))" statements.append((statement: cmd, parameters: [])) } @@ -187,14 +240,18 @@ struct RedisStatementGenerator { /// Escape a Redis argument for safe embedding in a command string. /// Wraps in double quotes if the value contains whitespace or special characters. + /// Ensures special characters round-trip correctly through the tokenizer. private func escapeArgument(_ value: String) -> String { - let needsQuoting = value.isEmpty || value.contains(where: { $0.isWhitespace || $0 == "\"" || $0 == "'" }) + let needsQuoting = value.isEmpty || value.contains(where: { + $0.isWhitespace || $0 == "\"" || $0 == "'" || $0 == "\\" || $0 == "\n" || $0 == "\r" || $0 == "\t" + }) if needsQuoting { let escaped = value .replacingOccurrences(of: "\\", with: "\\\\") .replacingOccurrences(of: "\"", with: "\\\"") .replacingOccurrences(of: "\n", with: "\\n") .replacingOccurrences(of: "\r", with: "\\r") + .replacingOccurrences(of: "\t", with: "\\t") return "\"\(escaped)\"" } return value diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift index adf5c18b..f89ce488 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+Navigation.swift @@ -431,18 +431,26 @@ extension MainContentCoordinator { /// Select a Redis database index and then run the query. /// Redis sidebar clicks go through openTableTab (sync), so we need a Task /// to call the async selectDatabase before executing the query. + /// Cancels any previous in-flight switch to prevent race conditions + /// from rapid sidebar clicks. private func selectRedisDatabaseAndQuery(_ dbIndex: Int) { + cancelRedisDatabaseSwitchTask() + let connId = connectionId let database = String(dbIndex) - Task { @MainActor in + redisDatabaseSwitchTask = Task { @MainActor [weak self] in + guard let self else { return } do { if let adapter = DatabaseManager.shared.driver(for: connId) as? PluginDriverAdapter { try await adapter.switchDatabase(to: String(dbIndex)) } } catch { - navigationLogger.error("Failed to SELECT Redis db\(dbIndex): \(error.localizedDescription, privacy: .public)") + if !Task.isCancelled { + navigationLogger.error("Failed to SELECT Redis db\(dbIndex): \(error.localizedDescription, privacy: .public)") + } return } + guard !Task.isCancelled else { return } DatabaseManager.shared.updateSession(connId) { session in session.currentDatabase = database } diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryAnalysis.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryAnalysis.swift index 9b230528..ab934f95 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryAnalysis.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryAnalysis.swift @@ -17,9 +17,36 @@ extension MainContentCoordinator { "RENAME ", "GRANT ", "REVOKE ", ] + /// Redis commands that modify data + private static let redisWriteCommands: Set = [ + "SET", "DEL", "HSET", "HDEL", "HMSET", "LPUSH", "RPUSH", "LPOP", "RPOP", + "SADD", "SREM", "ZADD", "ZREM", "EXPIRE", "PERSIST", "RENAME", + "FLUSHDB", "FLUSHALL", "MSET", "APPEND", "INCR", "DECR", "INCRBY", + "DECRBY", "SETEX", "PSETEX", "SETNX", "GETSET", "GETDEL", + "XADD", "XTRIM", "XDEL", + ] + + /// Redis commands that are destructive + private static let redisDangerousCommands: Set = [ + "FLUSHDB", "FLUSHALL", "DEBUG", "SHUTDOWN", + ] + /// Check if a SQL statement is a write operation (modifies data or schema) func isWriteQuery(_ sql: String) -> Bool { - let uppercased = sql.uppercased().trimmingCharacters(in: .whitespacesAndNewlines) + let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines) + + // Redis: check the first token against known write commands + if connection.type == .redis { + let firstToken = trimmed.prefix(while: { !$0.isWhitespace }).uppercased() + // CONFIG SET is a write; plain CONFIG GET is not + if firstToken == "CONFIG" { + let rest = trimmed.dropFirst(firstToken.count).trimmingCharacters(in: .whitespaces) + return rest.uppercased().hasPrefix("SET") + } + return Self.redisWriteCommands.contains(firstToken) + } + + let uppercased = trimmed.uppercased() return Self.writeQueryPrefixes.contains { uppercased.hasPrefix($0) } } @@ -30,7 +57,20 @@ extension MainContentCoordinator { /// Check if a query is potentially dangerous (DROP, TRUNCATE, DELETE without WHERE) func isDangerousQuery(_ sql: String) -> Bool { - let uppercased = sql.uppercased().trimmingCharacters(in: .whitespacesAndNewlines) + let trimmed = sql.trimmingCharacters(in: .whitespacesAndNewlines) + + // Redis: check for destructive commands + if connection.type == .redis { + let firstToken = trimmed.prefix(while: { !$0.isWhitespace }).uppercased() + // CONFIG SET is dangerous + if firstToken == "CONFIG" { + let rest = trimmed.dropFirst(firstToken.count).trimmingCharacters(in: .whitespaces) + return rest.uppercased().hasPrefix("SET") + } + return Self.redisDangerousCommands.contains(firstToken) + } + + let uppercased = trimmed.uppercased() // Check for DROP if uppercased.hasPrefix("DROP ") { diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryHelpers.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryHelpers.swift index 7c7b91d0..85c80263 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryHelpers.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+QueryHelpers.swift @@ -196,30 +196,46 @@ extension MainContentCoordinator { connectionType: DatabaseType, schemaResult: SchemaResult? ) { + let isNonSQL = PluginManager.shared.editorLanguage(for: connectionType) != .sql + + // Phase 2a: Exact row count + // Redis/non-SQL drivers don't support SELECT COUNT(*); use approximate count instead. Task { [weak self] in guard let self else { return } try? await Task.sleep(nanoseconds: 200_000_000) guard !self.isTearingDown else { return } guard let mainDriver = DatabaseManager.shared.driver(for: connectionId) else { return } - let quotedTable = mainDriver.quoteIdentifier(tableName) - let countResult = try? await mainDriver.execute( - query: "SELECT COUNT(*) FROM \(quotedTable)" - ) - if let firstRow = countResult?.rows.first, - let countStr = firstRow.first ?? nil, - let count = Int(countStr) { + + let count: Int? + if isNonSQL { + count = try? await mainDriver.fetchApproximateRowCount(table: tableName) + } else { + let quotedTable = mainDriver.quoteIdentifier(tableName) + let countResult = try? await mainDriver.execute( + query: "SELECT COUNT(*) FROM \(quotedTable)" + ) + if let firstRow = countResult?.rows.first, + let countStr = firstRow.first.flatMap({ $0 }) { + count = Int(countStr) + } else { + count = nil + } + } + + if let count { await MainActor.run { [weak self] in guard let self else { return } guard capturedGeneration == queryGeneration else { return } if let idx = tabManager.tabs.firstIndex(where: { $0.id == tabId }) { tabManager.tabs[idx].pagination.totalRowCount = count - tabManager.tabs[idx].pagination.isApproximateRowCount = false + tabManager.tabs[idx].pagination.isApproximateRowCount = isNonSQL } } } } - // Phase 2b: Fetch enum/set values + // Phase 2b: Fetch enum/set values (not applicable for non-SQL databases) + guard !isNonSQL else { return } guard let enumDriver = DatabaseManager.shared.driver(for: connectionId) else { return } Task { [weak self] in guard let self else { return } @@ -270,21 +286,34 @@ extension MainContentCoordinator { capturedGeneration: Int, connectionType: DatabaseType ) { + let isNonSQL = PluginManager.shared.editorLanguage(for: connectionType) != .sql + Task { [weak self] in guard let self else { return } guard let mainDriver = DatabaseManager.shared.driver(for: connectionId) else { return } - let quotedTable = mainDriver.quoteIdentifier(tableName) - let countResult = try? await mainDriver.execute( - query: "SELECT COUNT(*) FROM \(quotedTable)" - ) - if let firstRow = countResult?.rows.first, - let countStr = firstRow.first ?? nil, - let count = Int(countStr) { + + let count: Int? + if isNonSQL { + count = try? await mainDriver.fetchApproximateRowCount(table: tableName) + } else { + let quotedTable = mainDriver.quoteIdentifier(tableName) + let countResult = try? await mainDriver.execute( + query: "SELECT COUNT(*) FROM \(quotedTable)" + ) + if let firstRow = countResult?.rows.first, + let countStr = firstRow.first.flatMap({ $0 }) { + count = Int(countStr) + } else { + count = nil + } + } + + if let count { await MainActor.run { [weak self] in guard let self else { return } if let idx = tabManager.tabs.firstIndex(where: { $0.id == tabId }) { tabManager.tabs[idx].pagination.totalRowCount = count - tabManager.tabs[idx].pagination.isApproximateRowCount = false + tabManager.tabs[idx].pagination.isApproximateRowCount = isNonSQL } } } diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+Redis.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+Redis.swift index 82a820ff..782b5186 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+Redis.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+Redis.swift @@ -8,4 +8,10 @@ import Foundation extension MainContentCoordinator { + /// Cancel any in-flight Redis database switch task to prevent race conditions + /// from rapid sidebar clicks. + func cancelRedisDatabaseSwitchTask() { + redisDatabaseSwitchTask?.cancel() + redisDatabaseSwitchTask = nil + } } diff --git a/TablePro/Views/Main/Extensions/MainContentCoordinator+SaveChanges.swift b/TablePro/Views/Main/Extensions/MainContentCoordinator+SaveChanges.swift index 0d983571..d194aa90 100644 --- a/TablePro/Views/Main/Extensions/MainContentCoordinator+SaveChanges.swift +++ b/TablePro/Views/Main/Extensions/MainContentCoordinator+SaveChanges.swift @@ -187,7 +187,13 @@ extension MainContentCoordinator { throw DatabaseError.notConnected } - try await driver.beginTransaction() + // Redis MULTI/EXEC is not a true transaction (no rollback on failure), + // so execute statements individually without wrapping. + let useTransaction = dbType != .redis + + if useTransaction { + try await driver.beginTransaction() + } do { for statement in validStatements { @@ -212,9 +218,13 @@ extension MainContentCoordinator { ) } - try await driver.commitTransaction() + if useTransaction { + try await driver.commitTransaction() + } } catch { - try? await driver.rollbackTransaction() + if useTransaction { + try? await driver.rollbackTransaction() + } throw error } diff --git a/TablePro/Views/Main/MainContentCoordinator.swift b/TablePro/Views/Main/MainContentCoordinator.swift index af45d32f..34408a42 100644 --- a/TablePro/Views/Main/MainContentCoordinator.swift +++ b/TablePro/Views/Main/MainContentCoordinator.swift @@ -98,6 +98,7 @@ final class MainContentCoordinator { @ObservationIgnored internal var queryGeneration: Int = 0 @ObservationIgnored internal var currentQueryTask: Task? + @ObservationIgnored internal var redisDatabaseSwitchTask: Task? @ObservationIgnored private var changeManagerUpdateTask: Task? @ObservationIgnored private var activeSortTasks: [UUID: Task] = [:] @ObservationIgnored private var terminationObserver: NSObjectProtocol?