2
2
3
3
import java .io .EOFException ;
4
4
import java .io .IOException ;
5
+ import java .util .Set ;
6
+ import java .util .concurrent .ConcurrentHashMap ;
7
+
8
+ import org .eclipse .jetty .websocket .api .Session ;
5
9
6
10
import info .unterrainer .oauthtokenmanager .OauthTokenManager ;
7
11
import io .javalin .websocket .WsCloseContext ;
14
18
public class WsOauthHandlerBase extends WsHandlerBase {
15
19
16
20
private OauthTokenManager tokenHandler ;
21
+ private Set <WsConnectContext > clientsConnected = ConcurrentHashMap .newKeySet ();
22
+ private Set <WsConnectContext > clientsQuarantined = ConcurrentHashMap .newKeySet ();
17
23
18
24
void setTokenHandler (OauthTokenManager tokenHandler ) {
19
25
this .tokenHandler = tokenHandler ;
20
26
}
21
27
28
+ public void removeClient (Session session ) {
29
+ log .debug ("Removing client: [{}]" , session .getRemoteAddress ());
30
+ clientsConnected .removeIf (client -> client .session .equals (session ));
31
+ clientsQuarantined .removeIf (client -> client .session .equals (session ));
32
+ }
33
+
34
+ public WsConnectContext getClient (Session session ) {
35
+ log .debug ("Getting client: [{}]" , session .getRemoteAddress ());
36
+ return clientsConnected .stream ().filter (client -> client .session .equals (session )).findFirst ().orElse (null );
37
+ }
38
+
39
+ public WsConnectContext getQuarantinedClient (Session session ) {
40
+ log .debug ("Getting quarantined client: [{}]" , session .getRemoteAddress ());
41
+ return clientsQuarantined .stream ().filter (client -> client .session .equals (session )).findFirst ().orElse (null );
42
+ }
43
+
44
+ public boolean isQuarantined (Session session ) {
45
+ log .debug ("Checking if client is quarantined: [{}]" , session .getRemoteAddress ());
46
+ return clientsQuarantined .stream ().anyMatch (client -> client .session .equals (session ));
47
+ }
48
+
49
+ public boolean isConnected (Session session ) {
50
+ log .debug ("Checking if client is connected: [{}]" , session .getRemoteAddress ());
51
+ return clientsConnected .stream ().anyMatch (client -> client .session .equals (session ));
52
+ }
53
+
22
54
@ Override
23
55
public void onConnect (WsConnectContext ctx ) throws Exception {
24
56
log .debug ("New client tries to connect: [{}]" , ctx .session .getRemoteAddress ());
25
57
String token = ctx .header ("Authorization" );
58
+ if (token == null || token .isEmpty ()) {
59
+ log .warn ("No token provided for client: [{}]\n Sending connection into quarantine." ,
60
+ ctx .session .getRemoteAddress ());
61
+ clientsQuarantined .add (ctx );
62
+ return ;
63
+ }
26
64
log .debug ("New client token: [{}]" , token );
27
65
try {
28
66
tokenHandler .checkAccess (token );
67
+ clientsConnected .add (ctx );
29
68
} catch (Exception e ) {
69
+ log .debug ("Token validation failed for client [{}]. Disconnecting." , ctx .session .getRemoteAddress (), e );
30
70
ctx .session .close ();
31
71
return ;
32
72
}
@@ -35,11 +75,33 @@ public void onConnect(WsConnectContext ctx) throws Exception {
35
75
@ Override
36
76
public void onMessage (WsMessageContext ctx ) throws Exception {
37
77
log .debug ("Received from [{}]: [{}]" , ctx .session .getRemoteAddress (), ctx .message ());
78
+ if (isQuarantined (ctx .session )) {
79
+ log .warn ("Client [{}] is quarantined, checking message for standard authorization-bearer-token." ,
80
+ ctx .session .getRemoteAddress ());
81
+ if (ctx .message () == null || !ctx .message ().startsWith ("Bearer " )) {
82
+ log .warn ("Invalid message from quarantined client [{}]. Disconnecting." ,
83
+ ctx .session .getRemoteAddress ());
84
+ removeClient (ctx .session );
85
+ ctx .session .close ();
86
+ return ;
87
+ }
88
+ try {
89
+ tokenHandler .checkAccess (ctx .message ());
90
+ WsConnectContext client = getQuarantinedClient (ctx .session );
91
+ clientsQuarantined .removeIf (c -> c .session .equals (ctx .session ));
92
+ clientsConnected .add (client );
93
+ } catch (Exception e ) {
94
+ ctx .session .close ();
95
+ log .debug ("Token validation failed for client [{}]. Disconnecting." , ctx .session .getRemoteAddress (), e );
96
+ return ;
97
+ }
98
+ }
38
99
}
39
100
40
101
@ Override
41
102
public void onClose (WsCloseContext ctx ) throws Exception {
42
103
log .debug ("Disconnected client: [{}]" , ctx .session .getRemoteAddress ());
104
+ removeClient (ctx .session );
43
105
}
44
106
45
107
@ Override
@@ -50,5 +112,6 @@ public void onError(WsErrorContext ctx) throws Exception {
50
112
} else {
51
113
log .error ("Unexpected error on [{}]." , ctx .session .getRemoteAddress (), t );
52
114
}
115
+ removeClient (ctx .session );
53
116
}
54
117
}
0 commit comments