Skip to content

Commit 0f746c8

Browse files
feat: Add legacy and pilot token refresh (choice between them with partial)
1 parent 71b16f0 commit 0f746c8

File tree

3 files changed

+105
-28
lines changed

3 files changed

+105
-28
lines changed

Pilot/dirac-pilot.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@
8484
pilotUUID=pilotParams.pilotUUID,
8585
debugFlag=pilotParams.debugFlag,
8686
jwt=pilotParams.jwt,
87-
legacy_logging=pilotParams.isLegacyLogging
87+
legacy_logging=pilotParams.isLegacyLogging,
88+
clientID=pilotParams.clientID
8889
)
8990
log.info("Remote logger activated")
9091
log.buffer.write(log.format_to_json(

Pilot/pilotTools.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,23 @@ def load_module_from_path(module_name, path_to_module):
6969
basestring = str
7070

7171
try:
72-
from Pilot.proxyTools import getVO, BaseRequest, TokenBasedRequest, extract_diracx_payload
72+
from Pilot.proxyTools import (
73+
getVO,
74+
BaseRequest,
75+
TokenBasedRequest,
76+
extract_diracx_payload,
77+
refreshPilotToken,
78+
refreshUserToken
79+
)
7380
except ImportError:
74-
from proxyTools import getVO, BaseRequest, TokenBasedRequest, extract_diracx_payload
81+
from proxyTools import (
82+
getVO,
83+
BaseRequest,
84+
TokenBasedRequest,
85+
extract_diracx_payload,
86+
refreshPilotToken,
87+
refreshUserToken
88+
)
7589

7690
try:
7791
FileNotFoundError # pylint: disable=used-before-assignment
@@ -525,9 +539,10 @@ def __init__(
525539
isPilotLoggerOn=True,
526540
pilotUUID="unknown",
527541
flushInterval=10,
528-
bufsize=1000,
542+
bufsize=250,
529543
jwt = {},
530-
legacy_logging = False
544+
legacy_logging = False,
545+
clientID = ""
531546
):
532547
"""
533548
c'tor
@@ -538,7 +553,7 @@ def __init__(
538553
self.url = url
539554
self.pilotUUID = pilotUUID
540555
self.isPilotLoggerOn = isPilotLoggerOn
541-
sendToURL = partial(sendMessage, url, pilotUUID, legacy_logging)
556+
sendToURL = partial(sendMessage, url, pilotUUID, legacy_logging, clientID)
542557
self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval, jwt=jwt)
543558

544559
def format_to_json(self, level, message):
@@ -622,7 +637,7 @@ class FixedSizeBuffer(object):
622637
Once it's full, a message is sent to a remote server and the buffer is renewed.
623638
"""
624639

625-
def __init__(self, senderFunc, bufsize=1000, autoflush=10, jwt={}):
640+
def __init__(self, senderFunc, bufsize=250, autoflush=10, jwt={}):
626641
"""
627642
Constructor.
628643
@@ -645,6 +660,10 @@ def __init__(self, senderFunc, bufsize=1000, autoflush=10, jwt={}):
645660
self._nlines = 0
646661
self.senderFunc = senderFunc
647662
self.jwt = jwt
663+
# A fixed buffer used by a remote buffer can be deactivated:
664+
# If there's a 403/401 error, instead of crashing the pilot,
665+
# we will deactivate the log sending, and prefer just running the pilot.
666+
self.activated = True
648667

649668
@synchronized
650669
def write(self, content_json):
@@ -657,13 +676,11 @@ def write(self, content_json):
657676
:return: None
658677
:rtype: None
659678
"""
679+
if not self.activated:
680+
pass
660681

661682
self.output.extend(content_json)
662-
663-
try:
664-
self._nlines += max(1, len(content_json))
665-
except Exception:
666-
raise ValueError(content_json)
683+
self._nlines += max(1, len(content_json))
667684
self.sendFullBuffer()
668685

669686
@synchronized
@@ -674,7 +691,11 @@ def sendFullBuffer(self):
674691
"""
675692

676693
if self._nlines >= self.bufsize:
677-
self.flush()
694+
try:
695+
self.flush()
696+
except Exception as e:
697+
print("Deactivating fixed size buffer due to", str(e))
698+
self.activated = False
678699
self.output = []
679700

680701
@synchronized
@@ -685,6 +706,9 @@ def flush(self, force=False):
685706
:return: None
686707
:rtype: None
687708
"""
709+
if not self.activated:
710+
pass
711+
688712
if force or (self.output and self._nlines > 0):
689713
self.senderFunc(self.jwt, self.output)
690714
self._nlines = 0
@@ -700,7 +724,7 @@ def cancelTimer(self):
700724
self._timer.cancel()
701725

702726

703-
def sendMessage(diracx_URL, pilotUUID, legacy=False, jwt={}, rawMessage = []):
727+
def sendMessage(diracx_URL, pilotUUID, legacy=False, clientID="", jwt={}, rawMessage = []):
704728
"""
705729
Invoke a remote method on a Tornado server and pass a JSON message to it.
706730
@@ -724,8 +748,21 @@ def sendMessage(diracx_URL, pilotUUID, legacy=False, jwt={}, rawMessage = []):
724748

725749
if legacy:
726750
endpoint_path = "api/pilots/legacy/message"
751+
refresh_callback = partial(
752+
refreshUserToken,
753+
diracx_URL,
754+
pilotUUID,
755+
jwt,
756+
clientID
757+
)
727758
else:
728759
endpoint_path = "api/pilots/internal/message"
760+
refresh_callback = partial(
761+
refreshPilotToken,
762+
diracx_URL,
763+
pilotUUID,
764+
jwt
765+
)
729766

730767
config = TokenBasedRequest(
731768
diracx_URL=diracx_URL,
@@ -738,7 +775,8 @@ def sendMessage(diracx_URL, pilotUUID, legacy=False, jwt={}, rawMessage = []):
738775
# Do the request
739776
_res = config.executeRequest(
740777
raw_data=raw_data,
741-
json_output=False
778+
json_output=False,
779+
refresh_callback=refresh_callback
742780
)
743781

744782

@@ -775,7 +813,8 @@ def __init__(self, pilotParams):
775813
flushInterval=interval,
776814
bufsize=bufsize,
777815
jwt=pilotParams.jwt,
778-
legacy_logging=pilotParams.isLegacyLogging
816+
legacy_logging=pilotParams.isLegacyLogging,
817+
clientID=pilotParams.clientID
779818
)
780819

781820
self.log.isPilotLoggerOn = isPilotLoggerOn
@@ -976,7 +1015,7 @@ def __init__(self):
9761015
self.loggerURL = None
9771016
self.isLegacyLogging = False
9781017
self.loggerTimerInterval = 0
979-
self.loggerBufsize = 1000
1018+
self.loggerBufsize = 250
9801019
self.pilotUUID = "unknown"
9811020
self.modules = ""
9821021
self.userEnvVariables = ""

Pilot/proxyTools.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -252,9 +252,9 @@ def addJwtToHeader(self):
252252
# Adds the JWT in the HTTP request (in the Bearer field)
253253
self.headers["Authorization"] = "Bearer %s" % self.jwtData["access_token"]
254254

255-
def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True, tries_left=1):
255+
def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True, tries_left=1, refresh_callback=None):
256256

257-
while (tries_left > 0):
257+
while (tries_left >= 0):
258258

259259
try:
260260
return super(TokenBasedRequest, self).executeRequest(
@@ -268,11 +268,8 @@ def executeRequest(self, raw_data, insecure=False, content_type="json", json_out
268268
raise e
269269

270270
# If we have an unauthorized error, then refresh and retry
271-
refreshPilotToken(
272-
self.diracx_URL,
273-
self.pilotUUID,
274-
self.jwtData
275-
)
271+
if refresh_callback:
272+
refresh_callback()
276273

277274
self.addJwtToHeader()
278275

@@ -309,10 +306,47 @@ def executeRequest(self, raw_data, insecure=False, content_type="json", json_out
309306
json_output=json_output
310307
)
311308

309+
def refreshUserToken(url, pilotUUID, jwt, clientID):
310+
"""
311+
Refresh the JWT token (as a user).
312312
313-
def refreshPilotToken(url, pilotUUID, jwt):
313+
:param str url: Server URL
314+
:param str pilotUUID: Pilot unique ID
315+
:param dict jwt: Shared dict with current JWT; updated in-place
316+
:return: None
314317
"""
315-
Refresh the JWT token in a separate thread.
318+
319+
# PRECONDITION: jwt must contain "refresh_token"
320+
if not jwt or "refresh_token" not in jwt:
321+
raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token")
322+
323+
# Get CA path from environment
324+
caPath = os.getenv("X509_CERT_DIR")
325+
326+
# Create request object with required configuration
327+
config = BaseRequest(
328+
url=url + "api/auth/token",
329+
caPath=caPath,
330+
pilotUUID=pilotUUID,
331+
)
332+
333+
# Perform the request to refresh the token
334+
response = config.executeRequest(
335+
raw_data={
336+
"refresh_token": jwt["refresh_token"],
337+
"grant_type": "refresh_token",
338+
"client_id": clientID
339+
},
340+
content_type="x-www-form-urlencoded",
341+
)
342+
343+
# Do NOT assign directly, because jwt is a reference, not a copy
344+
jwt["access_token"] = response["access_token"]
345+
jwt["refresh_token"] = response["refresh_token"]
346+
347+
def refreshPilotToken(url, pilotUUID, jwt, _=None):
348+
"""
349+
Refresh the JWT token (as a pilot).
316350
317351
:param str url: Server URL
318352
:param str pilotUUID: Pilot unique ID
@@ -329,7 +363,7 @@ def refreshPilotToken(url, pilotUUID, jwt):
329363

330364
# Create request object with required configuration
331365
config = BaseRequest(
332-
url=url + "/api/auth/pilot-token",
366+
url=url + "api/auth/pilot-token",
333367
caPath=caPath,
334368
pilotUUID=pilotUUID,
335369
)
@@ -365,9 +399,12 @@ def revokePilotToken(url, pilotUUID, jwt, clientID):
365399
# Get CA path from environment
366400
caPath = os.getenv("X509_CERT_DIR")
367401

402+
if not url.endswith("/"):
403+
url = url + "/"
404+
368405
# Create request object with required configuration
369406
config = BaseRequest(
370-
url="%s/api/auth/revoke" % url,
407+
url="%sapi/auth/revoke" % url,
371408
caPath=caPath,
372409
pilotUUID=pilotUUID
373410
)

0 commit comments

Comments
 (0)