diff --git a/src/main/java/net/spy/memcached/MemcachedClient.java b/src/main/java/net/spy/memcached/MemcachedClient.java index 367e5e094..f8907c2b9 100644 --- a/src/main/java/net/spy/memcached/MemcachedClient.java +++ b/src/main/java/net/spy/memcached/MemcachedClient.java @@ -42,8 +42,6 @@ import java.util.concurrent.TimeoutException; import java.util.stream.Collectors; -import net.spy.memcached.auth.AuthDescriptor; -import net.spy.memcached.auth.AuthThreadMonitor; import net.spy.memcached.compat.SpyThread; import net.spy.memcached.internal.BroadcastFuture; import net.spy.memcached.internal.BulkFuture; @@ -120,7 +118,7 @@ * } */ public class MemcachedClient extends SpyThread - implements MemcachedClientIF, ConnectionObserver { + implements MemcachedClientIF { private volatile boolean running = true; private volatile boolean shuttingDown = false; @@ -132,16 +130,12 @@ public class MemcachedClient extends SpyThread protected final Transcoder transcoder; - private final AuthDescriptor authDescriptor; - private final byte delimiter; private static final String DEFAULT_MEMCACHED_CLIENT_NAME = "MemcachedClient"; private static final int GET_BULK_CHUNK_SIZE = 200; - private final AuthThreadMonitor authMonitor = new AuthThreadMonitor(); - /** * Get a memcached client operating on the specified memcached locations. * @@ -219,10 +213,6 @@ public MemcachedClient(ConnectionFactory cf, String name, List connObservers = new ConcurrentLinkedQueue<>(); private final Set nodesNeedVersionOp = new HashSet<>(); @@ -137,6 +142,7 @@ public MemcachedConnection(String name, ConnectionFactory f, FailureMode fm, OperationFactory opfactory) throws IOException { this.connFactory = f; + authDescriptor = f.getAuthDescriptor(); connName = name; connObservers.addAll(obs); addedQueue = new ConcurrentLinkedQueue<>(); @@ -636,6 +642,69 @@ private MemcachedNode makeMemcachedNode(String name, return qa; } + private void prepareAuthentication(final MemcachedNode node) { + if (authDescriptor == null) { + return; + } + + final SaslClient sc; + try { + sc = Sasl.createSaslClient(authDescriptor.getMechs(), null, + "memcached", node.getSocketAddress().toString(), null, authDescriptor.getCallback()); + } catch (Exception e) { + throw new IllegalStateException("Can't create SaslClient", e); + } + if (sc == null) { + throw new IllegalStateException("SaslClient is null"); + } + + final OperationCallback cb = new OperationCallback() { + private boolean authDone = false; + private boolean mechDone = false; + private OperationStatus priorStatus = null; + + @Override + public void receivedStatus(OperationStatus val) { + String msg = val.getMessage(); + // If the status we found was SASL_OK or NOT_SUPPORTED, we're authDone. + if ("SASL_OK".equals(msg) || "NOT_SUPPORTED".equals(msg)) { + authDone = true; + node.authComplete(true); + getLogger().info("Authenticated to " + node.getSocketAddress()); + } else if (!val.isSuccess()) { + authDone = true; + node.authComplete(false); + getLogger().error("Authentication failed to " + node.getSocketAddress() + ": " + msg); + } else if (!mechDone) { + mechDone = true; + } else { + // Get the prior status to create the correct operation. + priorStatus = val; + } + } + + @Override + public void complete() { + if (authDone) { + return; + } + + // NOTE: `this` keyword below is the OperationCallback object itself. + final Operation op; + if (priorStatus == null) { + op = opFactory.saslAuth(sc, this); + } else { + op = opFactory.saslStep(sc, KeyUtil.getKeyBytes(priorStatus.getMessage()), this); + } + + insertOperation(node, op); + } + }; + + final Operation mechOp = opFactory.saslMechs(true, cb); + insertOperation(node, mechOp); + } + private void prepareVersionInfo(final MemcachedNode node) { Operation op = opFactory.version(new OperationCallback() { @Override @@ -867,6 +936,7 @@ private void connected(MemcachedNode qa) { for (ConnectionObserver observer : connObservers) { observer.connectionEstablished(qa, rt); } + prepareAuthentication(qa); prepareVersionInfo(qa); } @@ -1422,7 +1492,7 @@ public void addOperation(final String key, final Operation o) { addOperation(findNodeByKey(key, o), o); } - public void insertOperation(final MemcachedNode node, final Operation o) { + private void insertOperation(final MemcachedNode node, final Operation o) { if (!node.isConnected() && failureMode == FailureMode.Cancel) { o.setHandlingNode(node); o.cancel("inactive node"); diff --git a/src/main/java/net/spy/memcached/auth/AuthThread.java b/src/main/java/net/spy/memcached/auth/AuthThread.java deleted file mode 100644 index 5dfc09bb9..000000000 --- a/src/main/java/net/spy/memcached/auth/AuthThread.java +++ /dev/null @@ -1,109 +0,0 @@ -package net.spy.memcached.auth; - -import java.util.concurrent.CountDownLatch; - -import javax.security.sasl.Sasl; -import javax.security.sasl.SaslClient; - -import net.spy.memcached.KeyUtil; -import net.spy.memcached.MemcachedConnection; -import net.spy.memcached.MemcachedNode; -import net.spy.memcached.OperationFactory; -import net.spy.memcached.compat.SpyThread; -import net.spy.memcached.ops.Operation; -import net.spy.memcached.ops.OperationCallback; -import net.spy.memcached.ops.OperationStatus; -import net.spy.memcached.ops.StatusCode; - -public class AuthThread extends SpyThread { - - private final MemcachedConnection conn; - private final AuthDescriptor authDescriptor; - private final OperationFactory opFact; - private final MemcachedNode node; - private final SaslClient sc; - - private boolean mechDone = false; - private boolean authDone = false; - private OperationStatus priorStatus = null; - - public AuthThread(MemcachedConnection c, OperationFactory o, - AuthDescriptor a, MemcachedNode n) { - conn = c; - opFact = o; - authDescriptor = a; - node = n; - try { - sc = Sasl.createSaslClient(authDescriptor.getMechs(), null, - "memcached", node.getSocketAddress().toString(), null, authDescriptor.getCallback()); - } catch (Exception e) { - throw new IllegalStateException("Can't create SaslClient", e); - } - - if (sc == null) { - throw new IllegalStateException("SaslClient is null"); - } - } - - @Override - public void run() { - while (!authDone) { - final CountDownLatch latch = new CountDownLatch(1); - final OperationCallback cb = new OperationCallback() { - @Override - public void receivedStatus(OperationStatus val) { - String msg = val.getMessage(); - // If the status we found was SASL_OK or NOT_SUPPORTED, we're authDone. - if ("SASL_OK".equals(msg) || "NOT_SUPPORTED".equals(msg)) { - authDone = true; - node.authComplete(true); - getLogger().info("Authenticated to " + node.getSocketAddress()); - } else if (val.getStatusCode() == StatusCode.CANCELLED && - AuthThread.this == Thread.currentThread()) { - // Don't call authComplete() if this callback is called by auth thread - // through calling op.cancel() after the InterruptedException . - authDone = true; - getLogger().error("Authentication canceled to " + node.getSocketAddress() + ": " + msg); - } else if (!val.isSuccess()) { - authDone = true; - node.authComplete(false); - getLogger().error("Authentication failed to " + node.getSocketAddress() + ": " + msg); - } else if (!mechDone) { - mechDone = true; - } else { - // Get the prior status to create the correct operation. - priorStatus = val; - } - } - - @Override - public void complete() { - latch.countDown(); - } - }; - - final Operation op; - if (!mechDone) { - op = opFact.saslMechs(true, cb); - } else if (priorStatus == null) { - op = opFact.saslAuth(sc, cb); - } else { - op = opFact.saslStep(sc, KeyUtil.getKeyBytes(priorStatus.getMessage()), cb); - } - conn.insertOperation(node, op); - - try { - latch.await(); - } catch (InterruptedException e) { - // we can be interrupted if we were in the - // process of auth'ing and the connection is - // lost or dropped due to bad auth - Thread.currentThread().interrupt(); - if (op != null) { - op.cancel("interruption to authentication: " + e); - } - authDone = true; // If we were interrupted, tear down - } - } - } -} diff --git a/src/main/java/net/spy/memcached/auth/AuthThreadMonitor.java b/src/main/java/net/spy/memcached/auth/AuthThreadMonitor.java deleted file mode 100644 index b2df2cea6..000000000 --- a/src/main/java/net/spy/memcached/auth/AuthThreadMonitor.java +++ /dev/null @@ -1,58 +0,0 @@ -package net.spy.memcached.auth; - -import java.util.HashMap; -import java.util.Map; - -import net.spy.memcached.MemcachedConnection; -import net.spy.memcached.MemcachedNode; -import net.spy.memcached.OperationFactory; -import net.spy.memcached.compat.SpyObject; - -/** - * This will ensure no more than one AuthThread will exist for a given - * MemcachedNode. - */ -public class AuthThreadMonitor extends SpyObject { - - private Map nodeMap; - - public AuthThreadMonitor() { - nodeMap = new HashMap<>(); - } - - /** - * Authenticate a new connection. This is typically used by a - * MemcachedNode in order to authenticate a connection right after it - * has been established. - * - * If an old, but not yet completed authentication exists this will - * stop it in order to create a new authentication attempt. - * - * @param conn - * @param opFact - * @param authDescriptor - * @param node - */ - public synchronized void authConnection(MemcachedConnection conn, - OperationFactory opFact, AuthDescriptor authDescriptor, - MemcachedNode node) { - interruptOldAuth(node); - AuthThread newSASLAuthenticator = new AuthThread(conn, opFact, - authDescriptor, node); - newSASLAuthenticator.start(); - nodeMap.put(node, newSASLAuthenticator); - } - - private void interruptOldAuth(MemcachedNode nodeToStop) { - AuthThread toStop = nodeMap.get(nodeToStop); - if (toStop != null) { - if (toStop.isAlive()) { - getLogger().warn("Incomplete authentication interrupted for node " + - nodeToStop); - toStop.interrupt(); - } - - nodeMap.remove(nodeToStop); - } - } -} diff --git a/src/main/java/net/spy/memcached/protocol/TCPMemcachedNodeImpl.java b/src/main/java/net/spy/memcached/protocol/TCPMemcachedNodeImpl.java index ba4432ca1..64c03b3de 100644 --- a/src/main/java/net/spy/memcached/protocol/TCPMemcachedNodeImpl.java +++ b/src/main/java/net/spy/memcached/protocol/TCPMemcachedNodeImpl.java @@ -336,11 +336,10 @@ public final boolean addOpToWriteQ(Operation op) { public final void insertOp(Operation op) { op.setHandlingNode(this); op.initialize(); - ArrayList tmp = new ArrayList<>( - inputQueue.size() + 1); - tmp.add(op); - inputQueue.drainTo(tmp); - inputQueue.addAll(tmp); + if (!writeQ.offer(op)) { + op.cancel("write queue overflow"); + return; + } addOpCount.incrementAndGet(); }