From c8f5bb5b575f1119b9aea297a5a8821b4e3eab10 Mon Sep 17 00:00:00 2001 From: Tristan Tarrant Date: Tue, 21 Feb 2023 15:17:33 +0100 Subject: [PATCH] SASL authentication fixes * Reuse SaslClient across all steps of an authentication * Allow passing a server name through AuthDescriptor * Allow passing SASL client properties through AuthDescriptor --- .../net/spy/memcached/OperationFactory.java | 8 ++-- .../spy/memcached/auth/AuthDescriptor.java | 30 +++++++++++++- .../net/spy/memcached/auth/AuthThread.java | 39 +++++++++++++------ .../protocol/ascii/AsciiOperationFactory.java | 11 +++--- .../binary/BinaryOperationFactory.java | 20 +++++----- .../binary/SASLAuthOperationImpl.java | 17 +++----- .../binary/SASLBaseOperationImpl.java | 36 ++++++----------- .../binary/SASLStepOperationImpl.java | 16 +++----- .../protocol/binary/BinaryToStringTest.java | 4 +- 9 files changed, 99 insertions(+), 82 deletions(-) diff --git a/src/main/java/net/spy/memcached/OperationFactory.java b/src/main/java/net/spy/memcached/OperationFactory.java index 7428b8858..12abf40ef 100644 --- a/src/main/java/net/spy/memcached/OperationFactory.java +++ b/src/main/java/net/spy/memcached/OperationFactory.java @@ -64,6 +64,7 @@ import net.spy.memcached.tapmessage.TapOpcode; import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.SaslClient; import java.util.Collection; import java.util.Map; @@ -334,15 +335,12 @@ CASOperation cas(StoreType t, String key, long casId, int flags, int exp, /** * Create a new sasl auth operation. */ - SASLAuthOperation saslAuth(String[] mech, String serverName, - Map props, CallbackHandler cbh, OperationCallback cb); + SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb); /** * Create a new sasl step operation. */ - SASLStepOperation saslStep(String[] mech, byte[] challenge, - String serverName, Map props, CallbackHandler cbh, - OperationCallback cb); + SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb); /** * Clone an operation. diff --git a/src/main/java/net/spy/memcached/auth/AuthDescriptor.java b/src/main/java/net/spy/memcached/auth/AuthDescriptor.java index 07ebc8645..bb91b649e 100644 --- a/src/main/java/net/spy/memcached/auth/AuthDescriptor.java +++ b/src/main/java/net/spy/memcached/auth/AuthDescriptor.java @@ -24,6 +24,8 @@ package net.spy.memcached.auth; import javax.security.auth.callback.CallbackHandler; +import java.util.Map; +import java.util.Properties; /** * Information required to specify authentication mechanisms and callbacks. @@ -32,6 +34,8 @@ public class AuthDescriptor { private final String[] mechs; private final CallbackHandler cbh; + private final Map props; + private final String serverName; private int authAttempts; private int allowedAuthAttempts; @@ -49,12 +53,19 @@ public class AuthDescriptor { * should be used instead, passing in new String[] {"PLAIN"} will force * the client to use PLAIN.

* + *

It is possible to specify an optional server name to be used + * in certain digest mechanisms for validation. If unspecified, the server's + * socket address is used.

+ * * @param m list of mechanisms * @param h the callback handler for grabbing credentials and stuff + * @param s the server name. Can be null */ - public AuthDescriptor(String[] m, CallbackHandler h) { + public AuthDescriptor(String[] m, CallbackHandler h, String s, Map p) { mechs = m; cbh = h; + props = p; + serverName = s; authAttempts = 0; String authThreshhold = System.getProperty("net.spy.memcached.auth.AuthThreshold"); @@ -65,6 +76,15 @@ public AuthDescriptor(String[] m, CallbackHandler h) { } } + /** + * + * @param m list of mechanisms + * @param h the callback handler for grabbing credentials and stuff + */ + public AuthDescriptor(String[] m, CallbackHandler h) { + this(m, h, null, null); + } + /** * Get a typical auth descriptor for CRAM-MD5 or PLAIN auth with the given * username and password. @@ -97,4 +117,12 @@ public String[] getMechs() { public CallbackHandler getCallback() { return cbh; } + + public Map getProperties() { + return props; + } + + public String getServerName() { + return serverName; + } } diff --git a/src/main/java/net/spy/memcached/auth/AuthThread.java b/src/main/java/net/spy/memcached/auth/AuthThread.java index 2a635b64a..7e7cbee0a 100644 --- a/src/main/java/net/spy/memcached/auth/AuthThread.java +++ b/src/main/java/net/spy/memcached/auth/AuthThread.java @@ -38,6 +38,10 @@ import net.spy.memcached.ops.OperationCallback; import net.spy.memcached.ops.OperationStatus; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; + /** * A thread that does SASL authentication. */ @@ -149,6 +153,17 @@ public void run() { throw new IllegalStateException("Got empty SASL auth mech list."); } + String serverName = authDescriptor.getServerName(); + if (serverName == null) { + serverName = node.getSocketAddress().toString(); + } + SaslClient saslClient; + try { + saslClient = Sasl.createSaslClient(supportedMechs, null, "memcached", serverName, authDescriptor.getProperties(), authDescriptor.getCallback()); + } catch (SaslException e) { + throw new RuntimeException("Error initializing SASL client", e); + } + OperationStatus priorStatus = null; while (!done.get()) { long stepStart = System.nanoTime(); @@ -177,7 +192,7 @@ public void complete() { }; // Get the prior status to create the correct operation. - final Operation op = buildOperation(priorStatus, cb, supportedMechs); + final Operation op = buildOperation(priorStatus, cb, saslClient); conn.insertOperation(node, op); try { @@ -200,7 +215,7 @@ public void complete() { long stepDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - stepStart); msg = String.format("SASL Step took %dms on %s", - stepDiff, node.toString()); + stepDiff, node); level = mechsDiff >= AUTH_ROUNDTRIP_THRESHOLD ? Level.WARN : Level.DEBUG; getLogger().log(level, msg); @@ -216,24 +231,24 @@ public void complete() { } } + try { + saslClient.dispose(); + } catch (SaslException e) { + throw new RuntimeException("Error while disposing SASL", e); + } + long totalDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - totalStart); - msg = String.format("SASL Auth took %dms on %s", - totalDiff, node.toString()); + msg = String.format("SASL Auth took %dms on %s", totalDiff, node); level = mechsDiff >= AUTH_TOTAL_THRESHOLD ? Level.WARN : Level.DEBUG; getLogger().log(level, msg); } - private Operation buildOperation(OperationStatus st, OperationCallback cb, - final String [] supportedMechs) { + private Operation buildOperation(OperationStatus st, OperationCallback cb, SaslClient sc) { if (st == null) { - return opFact.saslAuth(supportedMechs, - node.getSocketAddress().toString(), null, - authDescriptor.getCallback(), cb); + return opFact.saslAuth(sc, cb); } else { - return opFact.saslStep(supportedMechs, KeyUtil.getKeyBytes( - st.getMessage()), node.getSocketAddress().toString(), null, - authDescriptor.getCallback(), cb); + return opFact.saslStep(sc, KeyUtil.getKeyBytes(st.getMessage()), cb); } } } diff --git a/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java b/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java index a89f75d4e..5f078e0b1 100644 --- a/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java +++ b/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java @@ -67,6 +67,7 @@ import net.spy.memcached.tapmessage.TapOpcode; import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.SaslClient; import java.util.ArrayList; import java.util.Collection; import java.util.Map; @@ -192,20 +193,20 @@ protected Collection cloneGet(KeyedOperation op) { return rv; } + @Override public SASLMechsOperation saslMechs(OperationCallback cb) { throw new UnsupportedOperationException("SASL is not supported for " + "ASCII protocol"); } - public SASLStepOperation saslStep(String[] mech, byte[] challenge, - String serverName, Map props, CallbackHandler cbh, - OperationCallback cb) { + @Override + public SASLStepOperation saslStep(SaslClient sc, byte[] challenge, OperationCallback cb) { throw new UnsupportedOperationException("SASL is not supported for " + "ASCII protocol"); } - public SASLAuthOperation saslAuth(String[] mech, String serverName, - Map props, CallbackHandler cbh, OperationCallback cb) { + @Override + public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) { throw new UnsupportedOperationException("SASL is not supported for " + "ASCII protocol"); } diff --git a/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java b/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java index 8f6f67d45..72d90f45a 100644 --- a/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java +++ b/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java @@ -33,13 +33,13 @@ import net.spy.memcached.ops.ConcatenationOperation; import net.spy.memcached.ops.ConcatenationType; import net.spy.memcached.ops.ConfigurationType; +import net.spy.memcached.ops.DeleteConfigOperation; import net.spy.memcached.ops.DeleteOperation; import net.spy.memcached.ops.FlushOperation; import net.spy.memcached.ops.GetAndTouchOperation; import net.spy.memcached.ops.GetConfigOperation; import net.spy.memcached.ops.GetOperation; import net.spy.memcached.ops.GetOperation.Callback; -import net.spy.memcached.ops.DeleteConfigOperation; import net.spy.memcached.ops.GetlOperation; import net.spy.memcached.ops.GetsOperation; import net.spy.memcached.ops.KeyedOperation; @@ -68,10 +68,9 @@ import net.spy.memcached.tapmessage.RequestMessage; import net.spy.memcached.tapmessage.TapOpcode; -import javax.security.auth.callback.CallbackHandler; +import javax.security.sasl.SaslClient; import java.util.ArrayList; import java.util.Collection; -import java.util.Map; /** * Factory for binary operations. @@ -217,20 +216,19 @@ protected Collection cloneGet(KeyedOperation op) { return rv; } - public SASLAuthOperation saslAuth(String[] mech, String serverName, - Map props, CallbackHandler cbh, OperationCallback cb) { - return new SASLAuthOperationImpl(mech, serverName, props, cbh, cb); + @Override + public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) { + return new SASLAuthOperationImpl(sc, cb); } + @Override public SASLMechsOperation saslMechs(OperationCallback cb) { return new SASLMechsOperationImpl(cb); } - public SASLStepOperation saslStep(String[] mech, byte[] challenge, - String serverName, Map props, CallbackHandler cbh, - OperationCallback cb) { - return new SASLStepOperationImpl(mech, challenge, serverName, props, cbh, - cb); + @Override + public SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb) { + return new SASLStepOperationImpl(sc, ch, cb); } public TapOperation tapBackfill(String id, long date, OperationCallback cb) { diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java index c2801110a..e9823e437 100644 --- a/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java +++ b/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java @@ -23,15 +23,12 @@ package net.spy.memcached.protocol.binary; -import java.util.Map; +import net.spy.memcached.ops.OperationCallback; +import net.spy.memcached.ops.SASLAuthOperation; -import javax.security.auth.callback.CallbackHandler; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import net.spy.memcached.ops.OperationCallback; -import net.spy.memcached.ops.SASLAuthOperation; - /** * SASL authenticator. */ @@ -40,15 +37,13 @@ public class SASLAuthOperationImpl extends SASLBaseOperationImpl implements private static final byte CMD = 0x21; - public SASLAuthOperationImpl(String[] m, String s, Map p, - CallbackHandler h, OperationCallback c) { - super(CMD, m, EMPTY_BYTES, s, p, h, c); + public SASLAuthOperationImpl(SaslClient sc, OperationCallback c) { + super(CMD, sc, EMPTY_BYTES, c); } @Override - protected byte[] buildResponse(SaslClient sc) throws SaslException { - return sc.hasInitialResponse() ? sc.evaluateChallenge(challenge) - : EMPTY_BYTES; + protected byte[] buildResponse() throws SaslException { + return sc.hasInitialResponse() ? sc.evaluateChallenge(ch) : ch; } @Override diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java index 4b7fa8c39..c5d633160 100644 --- a/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java +++ b/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java @@ -23,19 +23,15 @@ package net.spy.memcached.protocol.binary; -import java.io.IOException; -import java.util.Map; - -import javax.security.auth.callback.CallbackHandler; -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslClient; -import javax.security.sasl.SaslException; - import net.spy.memcached.ops.OperationCallback; import net.spy.memcached.ops.OperationState; import net.spy.memcached.ops.OperationStatus; import net.spy.memcached.ops.StatusCode; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + /** * SASL authenticator. */ @@ -43,29 +39,19 @@ public abstract class SASLBaseOperationImpl extends OperationImpl { private static final byte SASL_CONTINUE = 0x21; - protected final String[] mech; - protected final byte[] challenge; - protected final String serverName; - protected final Map props; - protected final CallbackHandler cbh; + protected final SaslClient sc; + protected final byte[] ch; - public SASLBaseOperationImpl(byte c, String[] m, byte[] ch, String s, - Map p, CallbackHandler h, OperationCallback cb) { + public SASLBaseOperationImpl(byte c, SaslClient sasl, byte[] challenge, OperationCallback cb) { super(c, generateOpaque(), cb); - mech = m; - challenge = ch; - serverName = s; - props = p; - cbh = h; + sc = sasl; + ch = challenge; } @Override public void initialize() { try { - SaslClient sc = Sasl.createSaslClient(mech, null, "memcached", - serverName, props, cbh); - - byte[] response = buildResponse(sc); + byte[] response = buildResponse(); String mechanism = sc.getMechanismName(); getLogger().debug("Using SASL auth mechanism: " + mechanism); @@ -77,7 +63,7 @@ public void initialize() { } } - protected abstract byte[] buildResponse(SaslClient sc) throws SaslException; + protected abstract byte[] buildResponse() throws SaslException; @Override protected void decodePayload(byte[] pl) { diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java index 9e36c5f39..ff111fd37 100644 --- a/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java +++ b/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java @@ -23,15 +23,12 @@ package net.spy.memcached.protocol.binary; -import java.util.Map; +import net.spy.memcached.ops.OperationCallback; +import net.spy.memcached.ops.SASLStepOperation; -import javax.security.auth.callback.CallbackHandler; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; -import net.spy.memcached.ops.OperationCallback; -import net.spy.memcached.ops.SASLStepOperation; - /** * A SASLStepOperationImpl. */ @@ -40,14 +37,13 @@ public class SASLStepOperationImpl extends SASLBaseOperationImpl implements private static final byte CMD = 0x22; - public SASLStepOperationImpl(String[] m, byte[] ch, String s, - Map p, CallbackHandler h, OperationCallback c) { - super(CMD, m, ch, s, p, h, c); + public SASLStepOperationImpl(SaslClient sc, byte[] challenge, OperationCallback c) { + super(CMD, sc, challenge, c); } @Override - protected byte[] buildResponse(SaslClient sc) throws SaslException { - return sc.evaluateChallenge(challenge); + protected byte[] buildResponse() throws SaslException { + return sc.evaluateChallenge(ch); } @Override diff --git a/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java b/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java index 9a64a9f54..3102b605d 100644 --- a/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java +++ b/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java @@ -88,7 +88,7 @@ public void testOptimiedSet() { } public void testSASLAuth() { - (new SASLAuthOperationImpl(null, null, null, null, null)).toString(); + (new SASLAuthOperationImpl(null, null)).toString(); } public void testSASLMechs() { @@ -96,7 +96,7 @@ public void testSASLMechs() { } public void testSASLStep() { - (new SASLStepOperationImpl(null, null, null, null, null, null)).toString(); + (new SASLStepOperationImpl(null, null, null)).toString(); } public void testStats() {