From c8f5bb5b575f1119b9aea297a5a8821b4e3eab10 Mon Sep 17 00:00:00 2001
From: Tristan Tarrant
Date: Tue, 21 Feb 2023 15:17:33 +0100
Subject: [PATCH] SASL authentication fixes
* Reuse SaslClient across all steps of an authentication
* Allow passing a server name through AuthDescriptor
* Allow passing SASL client properties through AuthDescriptor
---
.../net/spy/memcached/OperationFactory.java | 8 ++--
.../spy/memcached/auth/AuthDescriptor.java | 30 +++++++++++++-
.../net/spy/memcached/auth/AuthThread.java | 39 +++++++++++++------
.../protocol/ascii/AsciiOperationFactory.java | 11 +++---
.../binary/BinaryOperationFactory.java | 20 +++++-----
.../binary/SASLAuthOperationImpl.java | 17 +++-----
.../binary/SASLBaseOperationImpl.java | 36 ++++++-----------
.../binary/SASLStepOperationImpl.java | 16 +++-----
.../protocol/binary/BinaryToStringTest.java | 4 +-
9 files changed, 99 insertions(+), 82 deletions(-)
diff --git a/src/main/java/net/spy/memcached/OperationFactory.java b/src/main/java/net/spy/memcached/OperationFactory.java
index 7428b8858..12abf40ef 100644
--- a/src/main/java/net/spy/memcached/OperationFactory.java
+++ b/src/main/java/net/spy/memcached/OperationFactory.java
@@ -64,6 +64,7 @@
import net.spy.memcached.tapmessage.TapOpcode;
import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.SaslClient;
import java.util.Collection;
import java.util.Map;
@@ -334,15 +335,12 @@ CASOperation cas(StoreType t, String key, long casId, int flags, int exp,
/**
* Create a new sasl auth operation.
*/
- SASLAuthOperation saslAuth(String[] mech, String serverName,
- Map props, CallbackHandler cbh, OperationCallback cb);
+ SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb);
/**
* Create a new sasl step operation.
*/
- SASLStepOperation saslStep(String[] mech, byte[] challenge,
- String serverName, Map props, CallbackHandler cbh,
- OperationCallback cb);
+ SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb);
/**
* Clone an operation.
diff --git a/src/main/java/net/spy/memcached/auth/AuthDescriptor.java b/src/main/java/net/spy/memcached/auth/AuthDescriptor.java
index 07ebc8645..bb91b649e 100644
--- a/src/main/java/net/spy/memcached/auth/AuthDescriptor.java
+++ b/src/main/java/net/spy/memcached/auth/AuthDescriptor.java
@@ -24,6 +24,8 @@
package net.spy.memcached.auth;
import javax.security.auth.callback.CallbackHandler;
+import java.util.Map;
+import java.util.Properties;
/**
* Information required to specify authentication mechanisms and callbacks.
@@ -32,6 +34,8 @@ public class AuthDescriptor {
private final String[] mechs;
private final CallbackHandler cbh;
+ private final Map props;
+ private final String serverName;
private int authAttempts;
private int allowedAuthAttempts;
@@ -49,12 +53,19 @@ public class AuthDescriptor {
* should be used instead, passing in new String[] {"PLAIN"} will force
* the client to use PLAIN.
*
+ * It is possible to specify an optional server name to be used
+ * in certain digest mechanisms for validation. If unspecified, the server's
+ * socket address is used.
+ *
* @param m list of mechanisms
* @param h the callback handler for grabbing credentials and stuff
+ * @param s the server name. Can be null
*/
- public AuthDescriptor(String[] m, CallbackHandler h) {
+ public AuthDescriptor(String[] m, CallbackHandler h, String s, Map p) {
mechs = m;
cbh = h;
+ props = p;
+ serverName = s;
authAttempts = 0;
String authThreshhold =
System.getProperty("net.spy.memcached.auth.AuthThreshold");
@@ -65,6 +76,15 @@ public AuthDescriptor(String[] m, CallbackHandler h) {
}
}
+ /**
+ *
+ * @param m list of mechanisms
+ * @param h the callback handler for grabbing credentials and stuff
+ */
+ public AuthDescriptor(String[] m, CallbackHandler h) {
+ this(m, h, null, null);
+ }
+
/**
* Get a typical auth descriptor for CRAM-MD5 or PLAIN auth with the given
* username and password.
@@ -97,4 +117,12 @@ public String[] getMechs() {
public CallbackHandler getCallback() {
return cbh;
}
+
+ public Map getProperties() {
+ return props;
+ }
+
+ public String getServerName() {
+ return serverName;
+ }
}
diff --git a/src/main/java/net/spy/memcached/auth/AuthThread.java b/src/main/java/net/spy/memcached/auth/AuthThread.java
index 2a635b64a..7e7cbee0a 100644
--- a/src/main/java/net/spy/memcached/auth/AuthThread.java
+++ b/src/main/java/net/spy/memcached/auth/AuthThread.java
@@ -38,6 +38,10 @@
import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.OperationStatus;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+
/**
* A thread that does SASL authentication.
*/
@@ -149,6 +153,17 @@ public void run() {
throw new IllegalStateException("Got empty SASL auth mech list.");
}
+ String serverName = authDescriptor.getServerName();
+ if (serverName == null) {
+ serverName = node.getSocketAddress().toString();
+ }
+ SaslClient saslClient;
+ try {
+ saslClient = Sasl.createSaslClient(supportedMechs, null, "memcached", serverName, authDescriptor.getProperties(), authDescriptor.getCallback());
+ } catch (SaslException e) {
+ throw new RuntimeException("Error initializing SASL client", e);
+ }
+
OperationStatus priorStatus = null;
while (!done.get()) {
long stepStart = System.nanoTime();
@@ -177,7 +192,7 @@ public void complete() {
};
// Get the prior status to create the correct operation.
- final Operation op = buildOperation(priorStatus, cb, supportedMechs);
+ final Operation op = buildOperation(priorStatus, cb, saslClient);
conn.insertOperation(node, op);
try {
@@ -200,7 +215,7 @@ public void complete() {
long stepDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime()
- stepStart);
msg = String.format("SASL Step took %dms on %s",
- stepDiff, node.toString());
+ stepDiff, node);
level = mechsDiff
>= AUTH_ROUNDTRIP_THRESHOLD ? Level.WARN : Level.DEBUG;
getLogger().log(level, msg);
@@ -216,24 +231,24 @@ public void complete() {
}
}
+ try {
+ saslClient.dispose();
+ } catch (SaslException e) {
+ throw new RuntimeException("Error while disposing SASL", e);
+ }
+
long totalDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime()
- totalStart);
- msg = String.format("SASL Auth took %dms on %s",
- totalDiff, node.toString());
+ msg = String.format("SASL Auth took %dms on %s", totalDiff, node);
level = mechsDiff >= AUTH_TOTAL_THRESHOLD ? Level.WARN : Level.DEBUG;
getLogger().log(level, msg);
}
- private Operation buildOperation(OperationStatus st, OperationCallback cb,
- final String [] supportedMechs) {
+ private Operation buildOperation(OperationStatus st, OperationCallback cb, SaslClient sc) {
if (st == null) {
- return opFact.saslAuth(supportedMechs,
- node.getSocketAddress().toString(), null,
- authDescriptor.getCallback(), cb);
+ return opFact.saslAuth(sc, cb);
} else {
- return opFact.saslStep(supportedMechs, KeyUtil.getKeyBytes(
- st.getMessage()), node.getSocketAddress().toString(), null,
- authDescriptor.getCallback(), cb);
+ return opFact.saslStep(sc, KeyUtil.getKeyBytes(st.getMessage()), cb);
}
}
}
diff --git a/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java b/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java
index a89f75d4e..5f078e0b1 100644
--- a/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java
+++ b/src/main/java/net/spy/memcached/protocol/ascii/AsciiOperationFactory.java
@@ -67,6 +67,7 @@
import net.spy.memcached.tapmessage.TapOpcode;
import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.SaslClient;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
@@ -192,20 +193,20 @@ protected Collection extends Operation> cloneGet(KeyedOperation op) {
return rv;
}
+ @Override
public SASLMechsOperation saslMechs(OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}
- public SASLStepOperation saslStep(String[] mech, byte[] challenge,
- String serverName, Map props, CallbackHandler cbh,
- OperationCallback cb) {
+ @Override
+ public SASLStepOperation saslStep(SaslClient sc, byte[] challenge, OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}
- public SASLAuthOperation saslAuth(String[] mech, String serverName,
- Map props, CallbackHandler cbh, OperationCallback cb) {
+ @Override
+ public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}
diff --git a/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java b/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java
index 8f6f67d45..72d90f45a 100644
--- a/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java
+++ b/src/main/java/net/spy/memcached/protocol/binary/BinaryOperationFactory.java
@@ -33,13 +33,13 @@
import net.spy.memcached.ops.ConcatenationOperation;
import net.spy.memcached.ops.ConcatenationType;
import net.spy.memcached.ops.ConfigurationType;
+import net.spy.memcached.ops.DeleteConfigOperation;
import net.spy.memcached.ops.DeleteOperation;
import net.spy.memcached.ops.FlushOperation;
import net.spy.memcached.ops.GetAndTouchOperation;
import net.spy.memcached.ops.GetConfigOperation;
import net.spy.memcached.ops.GetOperation;
import net.spy.memcached.ops.GetOperation.Callback;
-import net.spy.memcached.ops.DeleteConfigOperation;
import net.spy.memcached.ops.GetlOperation;
import net.spy.memcached.ops.GetsOperation;
import net.spy.memcached.ops.KeyedOperation;
@@ -68,10 +68,9 @@
import net.spy.memcached.tapmessage.RequestMessage;
import net.spy.memcached.tapmessage.TapOpcode;
-import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.SaslClient;
import java.util.ArrayList;
import java.util.Collection;
-import java.util.Map;
/**
* Factory for binary operations.
@@ -217,20 +216,19 @@ protected Collection extends Operation> cloneGet(KeyedOperation op) {
return rv;
}
- public SASLAuthOperation saslAuth(String[] mech, String serverName,
- Map props, CallbackHandler cbh, OperationCallback cb) {
- return new SASLAuthOperationImpl(mech, serverName, props, cbh, cb);
+ @Override
+ public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) {
+ return new SASLAuthOperationImpl(sc, cb);
}
+ @Override
public SASLMechsOperation saslMechs(OperationCallback cb) {
return new SASLMechsOperationImpl(cb);
}
- public SASLStepOperation saslStep(String[] mech, byte[] challenge,
- String serverName, Map props, CallbackHandler cbh,
- OperationCallback cb) {
- return new SASLStepOperationImpl(mech, challenge, serverName, props, cbh,
- cb);
+ @Override
+ public SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb) {
+ return new SASLStepOperationImpl(sc, ch, cb);
}
public TapOperation tapBackfill(String id, long date, OperationCallback cb) {
diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java
index c2801110a..e9823e437 100644
--- a/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java
+++ b/src/main/java/net/spy/memcached/protocol/binary/SASLAuthOperationImpl.java
@@ -23,15 +23,12 @@
package net.spy.memcached.protocol.binary;
-import java.util.Map;
+import net.spy.memcached.ops.OperationCallback;
+import net.spy.memcached.ops.SASLAuthOperation;
-import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
-import net.spy.memcached.ops.OperationCallback;
-import net.spy.memcached.ops.SASLAuthOperation;
-
/**
* SASL authenticator.
*/
@@ -40,15 +37,13 @@ public class SASLAuthOperationImpl extends SASLBaseOperationImpl implements
private static final byte CMD = 0x21;
- public SASLAuthOperationImpl(String[] m, String s, Map p,
- CallbackHandler h, OperationCallback c) {
- super(CMD, m, EMPTY_BYTES, s, p, h, c);
+ public SASLAuthOperationImpl(SaslClient sc, OperationCallback c) {
+ super(CMD, sc, EMPTY_BYTES, c);
}
@Override
- protected byte[] buildResponse(SaslClient sc) throws SaslException {
- return sc.hasInitialResponse() ? sc.evaluateChallenge(challenge)
- : EMPTY_BYTES;
+ protected byte[] buildResponse() throws SaslException {
+ return sc.hasInitialResponse() ? sc.evaluateChallenge(ch) : ch;
}
@Override
diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java
index 4b7fa8c39..c5d633160 100644
--- a/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java
+++ b/src/main/java/net/spy/memcached/protocol/binary/SASLBaseOperationImpl.java
@@ -23,19 +23,15 @@
package net.spy.memcached.protocol.binary;
-import java.io.IOException;
-import java.util.Map;
-
-import javax.security.auth.callback.CallbackHandler;
-import javax.security.sasl.Sasl;
-import javax.security.sasl.SaslClient;
-import javax.security.sasl.SaslException;
-
import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.OperationState;
import net.spy.memcached.ops.OperationStatus;
import net.spy.memcached.ops.StatusCode;
+import javax.security.sasl.SaslClient;
+import javax.security.sasl.SaslException;
+import java.io.IOException;
+
/**
* SASL authenticator.
*/
@@ -43,29 +39,19 @@ public abstract class SASLBaseOperationImpl extends OperationImpl {
private static final byte SASL_CONTINUE = 0x21;
- protected final String[] mech;
- protected final byte[] challenge;
- protected final String serverName;
- protected final Map props;
- protected final CallbackHandler cbh;
+ protected final SaslClient sc;
+ protected final byte[] ch;
- public SASLBaseOperationImpl(byte c, String[] m, byte[] ch, String s,
- Map p, CallbackHandler h, OperationCallback cb) {
+ public SASLBaseOperationImpl(byte c, SaslClient sasl, byte[] challenge, OperationCallback cb) {
super(c, generateOpaque(), cb);
- mech = m;
- challenge = ch;
- serverName = s;
- props = p;
- cbh = h;
+ sc = sasl;
+ ch = challenge;
}
@Override
public void initialize() {
try {
- SaslClient sc = Sasl.createSaslClient(mech, null, "memcached",
- serverName, props, cbh);
-
- byte[] response = buildResponse(sc);
+ byte[] response = buildResponse();
String mechanism = sc.getMechanismName();
getLogger().debug("Using SASL auth mechanism: " + mechanism);
@@ -77,7 +63,7 @@ public void initialize() {
}
}
- protected abstract byte[] buildResponse(SaslClient sc) throws SaslException;
+ protected abstract byte[] buildResponse() throws SaslException;
@Override
protected void decodePayload(byte[] pl) {
diff --git a/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java b/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java
index 9e36c5f39..ff111fd37 100644
--- a/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java
+++ b/src/main/java/net/spy/memcached/protocol/binary/SASLStepOperationImpl.java
@@ -23,15 +23,12 @@
package net.spy.memcached.protocol.binary;
-import java.util.Map;
+import net.spy.memcached.ops.OperationCallback;
+import net.spy.memcached.ops.SASLStepOperation;
-import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
-import net.spy.memcached.ops.OperationCallback;
-import net.spy.memcached.ops.SASLStepOperation;
-
/**
* A SASLStepOperationImpl.
*/
@@ -40,14 +37,13 @@ public class SASLStepOperationImpl extends SASLBaseOperationImpl implements
private static final byte CMD = 0x22;
- public SASLStepOperationImpl(String[] m, byte[] ch, String s,
- Map p, CallbackHandler h, OperationCallback c) {
- super(CMD, m, ch, s, p, h, c);
+ public SASLStepOperationImpl(SaslClient sc, byte[] challenge, OperationCallback c) {
+ super(CMD, sc, challenge, c);
}
@Override
- protected byte[] buildResponse(SaslClient sc) throws SaslException {
- return sc.evaluateChallenge(challenge);
+ protected byte[] buildResponse() throws SaslException {
+ return sc.evaluateChallenge(ch);
}
@Override
diff --git a/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java b/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java
index 9a64a9f54..3102b605d 100644
--- a/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java
+++ b/src/test/java/net/spy/memcached/protocol/binary/BinaryToStringTest.java
@@ -88,7 +88,7 @@ public void testOptimiedSet() {
}
public void testSASLAuth() {
- (new SASLAuthOperationImpl(null, null, null, null, null)).toString();
+ (new SASLAuthOperationImpl(null, null)).toString();
}
public void testSASLMechs() {
@@ -96,7 +96,7 @@ public void testSASLMechs() {
}
public void testSASLStep() {
- (new SASLStepOperationImpl(null, null, null, null, null, null)).toString();
+ (new SASLStepOperationImpl(null, null, null)).toString();
}
public void testStats() {