Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/main/java/net/spy/memcached/OperationFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import net.spy.memcached.tapmessage.TapOpcode;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import java.util.Collection;
import java.util.Map;

Expand Down Expand Up @@ -334,15 +335,12 @@ CASOperation cas(StoreType t, String key, long casId, int flags, int exp,
/**
* Create a new sasl auth operation.
*/
SASLAuthOperation saslAuth(String[] mech, String serverName,
Map<String, ?> props, CallbackHandler cbh, OperationCallback cb);
SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb);

/**
* Create a new sasl step operation.
*/
SASLStepOperation saslStep(String[] mech, byte[] challenge,
String serverName, Map<String, ?> props, CallbackHandler cbh,
OperationCallback cb);
SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb);

/**
* Clone an operation.
Expand Down
30 changes: 29 additions & 1 deletion src/main/java/net/spy/memcached/auth/AuthDescriptor.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
package net.spy.memcached.auth;

import javax.security.auth.callback.CallbackHandler;
import java.util.Map;
import java.util.Properties;

/**
* Information required to specify authentication mechanisms and callbacks.
Expand All @@ -32,6 +34,8 @@ public class AuthDescriptor {

private final String[] mechs;
private final CallbackHandler cbh;
private final Map<String, ?> props;
private final String serverName;
private int authAttempts;
private int allowedAuthAttempts;

Expand All @@ -49,12 +53,19 @@ public class AuthDescriptor {
* should be used instead, passing in new String[] {"PLAIN"} will force
* the client to use PLAIN.</p>
*
* <p>It is possible to specify an optional server name to be used
* in certain digest mechanisms for validation. If unspecified, the server's
* socket address is used.</p>
*
* @param m list of mechanisms
* @param h the callback handler for grabbing credentials and stuff
* @param s the server name. Can be null
*/
public AuthDescriptor(String[] m, CallbackHandler h) {
public AuthDescriptor(String[] m, CallbackHandler h, String s, Map<String, ?> p) {
mechs = m;
cbh = h;
props = p;
serverName = s;
authAttempts = 0;
String authThreshhold =
System.getProperty("net.spy.memcached.auth.AuthThreshold");
Expand All @@ -65,6 +76,15 @@ public AuthDescriptor(String[] m, CallbackHandler h) {
}
}

/**
*
* @param m list of mechanisms
* @param h the callback handler for grabbing credentials and stuff
*/
public AuthDescriptor(String[] m, CallbackHandler h) {
this(m, h, null, null);
}

/**
* Get a typical auth descriptor for CRAM-MD5 or PLAIN auth with the given
* username and password.
Expand Down Expand Up @@ -97,4 +117,12 @@ public String[] getMechs() {
public CallbackHandler getCallback() {
return cbh;
}

public Map<String, ?> getProperties() {
return props;
}

public String getServerName() {
return serverName;
}
}
39 changes: 27 additions & 12 deletions src/main/java/net/spy/memcached/auth/AuthThread.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@
import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.OperationStatus;

import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

/**
* A thread that does SASL authentication.
*/
Expand Down Expand Up @@ -149,6 +153,17 @@ public void run() {
throw new IllegalStateException("Got empty SASL auth mech list.");
}

String serverName = authDescriptor.getServerName();
if (serverName == null) {
serverName = node.getSocketAddress().toString();
}
SaslClient saslClient;
try {
saslClient = Sasl.createSaslClient(supportedMechs, null, "memcached", serverName, authDescriptor.getProperties(), authDescriptor.getCallback());
} catch (SaslException e) {
throw new RuntimeException("Error initializing SASL client", e);
}

OperationStatus priorStatus = null;
while (!done.get()) {
long stepStart = System.nanoTime();
Expand Down Expand Up @@ -177,7 +192,7 @@ public void complete() {
};

// Get the prior status to create the correct operation.
final Operation op = buildOperation(priorStatus, cb, supportedMechs);
final Operation op = buildOperation(priorStatus, cb, saslClient);
conn.insertOperation(node, op);

try {
Expand All @@ -200,7 +215,7 @@ public void complete() {
long stepDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime()
- stepStart);
msg = String.format("SASL Step took %dms on %s",
stepDiff, node.toString());
stepDiff, node);
level = mechsDiff
>= AUTH_ROUNDTRIP_THRESHOLD ? Level.WARN : Level.DEBUG;
getLogger().log(level, msg);
Expand All @@ -216,24 +231,24 @@ public void complete() {
}
}

try {
saslClient.dispose();
} catch (SaslException e) {
throw new RuntimeException("Error while disposing SASL", e);
}

long totalDiff = TimeUnit.NANOSECONDS.toMillis(System.nanoTime()
- totalStart);
msg = String.format("SASL Auth took %dms on %s",
totalDiff, node.toString());
msg = String.format("SASL Auth took %dms on %s", totalDiff, node);
level = mechsDiff >= AUTH_TOTAL_THRESHOLD ? Level.WARN : Level.DEBUG;
getLogger().log(level, msg);
}

private Operation buildOperation(OperationStatus st, OperationCallback cb,
final String [] supportedMechs) {
private Operation buildOperation(OperationStatus st, OperationCallback cb, SaslClient sc) {
if (st == null) {
return opFact.saslAuth(supportedMechs,
node.getSocketAddress().toString(), null,
authDescriptor.getCallback(), cb);
return opFact.saslAuth(sc, cb);
} else {
return opFact.saslStep(supportedMechs, KeyUtil.getKeyBytes(
st.getMessage()), node.getSocketAddress().toString(), null,
authDescriptor.getCallback(), cb);
return opFact.saslStep(sc, KeyUtil.getKeyBytes(st.getMessage()), cb);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
import net.spy.memcached.tapmessage.TapOpcode;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;
Expand Down Expand Up @@ -192,20 +193,20 @@ protected Collection<? extends Operation> cloneGet(KeyedOperation op) {
return rv;
}

@Override
public SASLMechsOperation saslMechs(OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}

public SASLStepOperation saslStep(String[] mech, byte[] challenge,
String serverName, Map<String, ?> props, CallbackHandler cbh,
OperationCallback cb) {
@Override
public SASLStepOperation saslStep(SaslClient sc, byte[] challenge, OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}

public SASLAuthOperation saslAuth(String[] mech, String serverName,
Map<String, ?> props, CallbackHandler cbh, OperationCallback cb) {
@Override
public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) {
throw new UnsupportedOperationException("SASL is not supported for "
+ "ASCII protocol");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@
import net.spy.memcached.ops.ConcatenationOperation;
import net.spy.memcached.ops.ConcatenationType;
import net.spy.memcached.ops.ConfigurationType;
import net.spy.memcached.ops.DeleteConfigOperation;
import net.spy.memcached.ops.DeleteOperation;
import net.spy.memcached.ops.FlushOperation;
import net.spy.memcached.ops.GetAndTouchOperation;
import net.spy.memcached.ops.GetConfigOperation;
import net.spy.memcached.ops.GetOperation;
import net.spy.memcached.ops.GetOperation.Callback;
import net.spy.memcached.ops.DeleteConfigOperation;
import net.spy.memcached.ops.GetlOperation;
import net.spy.memcached.ops.GetsOperation;
import net.spy.memcached.ops.KeyedOperation;
Expand Down Expand Up @@ -68,10 +68,9 @@
import net.spy.memcached.tapmessage.RequestMessage;
import net.spy.memcached.tapmessage.TapOpcode;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Map;

/**
* Factory for binary operations.
Expand Down Expand Up @@ -217,20 +216,19 @@ protected Collection<? extends Operation> cloneGet(KeyedOperation op) {
return rv;
}

public SASLAuthOperation saslAuth(String[] mech, String serverName,
Map<String, ?> props, CallbackHandler cbh, OperationCallback cb) {
return new SASLAuthOperationImpl(mech, serverName, props, cbh, cb);
@Override
public SASLAuthOperation saslAuth(SaslClient sc, OperationCallback cb) {
return new SASLAuthOperationImpl(sc, cb);
}

@Override
public SASLMechsOperation saslMechs(OperationCallback cb) {
return new SASLMechsOperationImpl(cb);
}

public SASLStepOperation saslStep(String[] mech, byte[] challenge,
String serverName, Map<String, ?> props, CallbackHandler cbh,
OperationCallback cb) {
return new SASLStepOperationImpl(mech, challenge, serverName, props, cbh,
cb);
@Override
public SASLStepOperation saslStep(SaslClient sc, byte[] ch, OperationCallback cb) {
return new SASLStepOperationImpl(sc, ch, cb);
}

public TapOperation tapBackfill(String id, long date, OperationCallback cb) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,12 @@

package net.spy.memcached.protocol.binary;

import java.util.Map;
import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.SASLAuthOperation;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.SASLAuthOperation;

/**
* SASL authenticator.
*/
Expand All @@ -40,15 +37,13 @@ public class SASLAuthOperationImpl extends SASLBaseOperationImpl implements

private static final byte CMD = 0x21;

public SASLAuthOperationImpl(String[] m, String s, Map<String, ?> p,
CallbackHandler h, OperationCallback c) {
super(CMD, m, EMPTY_BYTES, s, p, h, c);
public SASLAuthOperationImpl(SaslClient sc, OperationCallback c) {
super(CMD, sc, EMPTY_BYTES, c);
}

@Override
protected byte[] buildResponse(SaslClient sc) throws SaslException {
return sc.hasInitialResponse() ? sc.evaluateChallenge(challenge)
: EMPTY_BYTES;
protected byte[] buildResponse() throws SaslException {
return sc.hasInitialResponse() ? sc.evaluateChallenge(ch) : ch;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,49 +23,35 @@

package net.spy.memcached.protocol.binary;

import java.io.IOException;
import java.util.Map;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;

import net.spy.memcached.ops.OperationCallback;
import net.spy.memcached.ops.OperationState;
import net.spy.memcached.ops.OperationStatus;
import net.spy.memcached.ops.StatusCode;

import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.IOException;

/**
* SASL authenticator.
*/
public abstract class SASLBaseOperationImpl extends OperationImpl {

private static final byte SASL_CONTINUE = 0x21;

protected final String[] mech;
protected final byte[] challenge;
protected final String serverName;
protected final Map<String, ?> props;
protected final CallbackHandler cbh;
protected final SaslClient sc;
protected final byte[] ch;

public SASLBaseOperationImpl(byte c, String[] m, byte[] ch, String s,
Map<String, ?> p, CallbackHandler h, OperationCallback cb) {
public SASLBaseOperationImpl(byte c, SaslClient sasl, byte[] challenge, OperationCallback cb) {
super(c, generateOpaque(), cb);
mech = m;
challenge = ch;
serverName = s;
props = p;
cbh = h;
sc = sasl;
ch = challenge;
}

@Override
public void initialize() {
try {
SaslClient sc = Sasl.createSaslClient(mech, null, "memcached",
serverName, props, cbh);

byte[] response = buildResponse(sc);
byte[] response = buildResponse();
String mechanism = sc.getMechanismName();

getLogger().debug("Using SASL auth mechanism: " + mechanism);
Expand All @@ -77,7 +63,7 @@ public void initialize() {
}
}

protected abstract byte[] buildResponse(SaslClient sc) throws SaslException;
protected abstract byte[] buildResponse() throws SaslException;

@Override
protected void decodePayload(byte[] pl) {
Expand Down
Loading