Skip to content

Commit e43c672

Browse files
committed
Retry authentication with all remaining auth methods after partial success
Signed-off-by: Jeroen van Erp <[email protected]>
1 parent d628c47 commit e43c672

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

src/main/java/net/schmizz/sshj/SSHClient.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import net.schmizz.sshj.transport.verification.FingerprintVerifier;
4141
import net.schmizz.sshj.transport.verification.HostKeyVerifier;
4242
import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts;
43+
import net.schmizz.sshj.userauth.AuthResult;
4344
import net.schmizz.sshj.userauth.UserAuth;
4445
import net.schmizz.sshj.userauth.UserAuthException;
4546
import net.schmizz.sshj.userauth.UserAuthImpl;
@@ -218,13 +219,30 @@ public void auth(String username, Iterable<AuthMethod> methods)
218219
throws UserAuthException, TransportException {
219220
checkConnected();
220221
final Deque<UserAuthException> savedEx = new LinkedList<UserAuthException>();
221-
for (AuthMethod method: methods) {
222+
final List<AuthMethod> tried = new LinkedList<AuthMethod>();
223+
224+
for (Iterator<AuthMethod> it = methods.iterator(); it.hasNext();) {
225+
AuthMethod method = it.next();
222226
method.setLoggerFactory(loggerFactory);
227+
223228
try {
224-
if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs()))
229+
AuthResult result = auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs());
230+
231+
if (result == AuthResult.SUCCESS) {
225232
return;
233+
} else if (result == AuthResult.PARTIAL) {
234+
// Put all remaining methods in the tried list, so that we can try them for the second round of authentication
235+
while (it.hasNext()) {
236+
tried.add(it.next());
237+
}
238+
239+
auth(username, tried);
240+
return;
241+
}
242+
tried.add(method);
226243
} catch (UserAuthException e) {
227244
savedEx.push(e);
245+
tried.add(method);
228246
}
229247
}
230248
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package net.schmizz.sshj.userauth;
2+
3+
public enum AuthResult {
4+
SUCCESS,
5+
FAILURE,
6+
PARTIAL
7+
}

src/main/java/net/schmizz/sshj/userauth/UserAuth.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ public interface UserAuth {
3737
* @param nextService the service to set on successful authentication
3838
* @param methods the {@link AuthMethod}'s to try
3939
*
40-
* @return whether authentication was successful
40+
* @return whether authentication was successful, failed, or partially successful
4141
*
4242
* @throws UserAuthException in case of authentication failure
4343
* @throws TransportException if there was a transport-layer error
4444
*/
45-
boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
45+
AuthResult authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
4646
throws UserAuthException, TransportException;
4747

4848
/**

src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public class UserAuthImpl
4040
extends AbstractService
4141
implements UserAuth {
4242

43-
private final Promise<Boolean, UserAuthException> authenticated;
43+
private final Promise<AuthResult, UserAuthException> authenticated;
4444

4545
// Externally available
4646
private volatile String banner = "";
@@ -53,13 +53,13 @@ public class UserAuthImpl
5353

5454
public UserAuthImpl(Transport trans) {
5555
super("ssh-userauth", trans);
56-
authenticated = new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory());
56+
authenticated = new Promise<AuthResult, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory());
5757
}
5858

5959
@Override
60-
public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
60+
public AuthResult authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
6161
throws UserAuthException, TransportException {
62-
final boolean outcome;
62+
final AuthResult outcome;
6363

6464
authenticated.lock();
6565
try {
@@ -73,8 +73,10 @@ public boolean authenticate(String username, Service nextService, AuthMethod met
7373
currentMethod.request();
7474
outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS);
7575

76-
if (outcome) {
76+
if (outcome == AuthResult.SUCCESS) {
7777
log.debug("`{}` auth successful", method.getName());
78+
} else if (outcome == AuthResult.PARTIAL) {
79+
log.debug("`{}` auth partially successful", method.getName());
7880
} else {
7981
log.debug("`{}` auth failed", method.getName());
8082
}
@@ -124,7 +126,7 @@ public void handle(Message msg, SSHPacket buf)
124126
// Should fix https://github.com/hierynomus/sshj/issues/237
125127
trans.setAuthenticated(); // So it can put delayed compression into force if applicable
126128
trans.setService(nextService); // We aren't in charge anymore, next service is
127-
authenticated.deliver(true);
129+
authenticated.deliver(AuthResult.SUCCESS);
128130
break;
129131

130132
case USERAUTH_FAILURE:
@@ -133,7 +135,7 @@ public void handle(Message msg, SSHPacket buf)
133135
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
134136
currentMethod.request();
135137
} else {
136-
authenticated.deliver(false);
138+
authenticated.deliver(partialSuccess ? AuthResult.PARTIAL : AuthResult.FAILURE);
137139
}
138140
break;
139141

0 commit comments

Comments
 (0)