From ff988d433b54cf88b51f78292f08430628aab391 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 15:43:05 +0200 Subject: [PATCH 1/3] feat: Add support for DiracX tokens, and possibly tokens in proxies --- .github/workflows/basic.yml | 6 +- Pilot/dirac-pilot.py | 14 ++ Pilot/pilotCommands.py | 16 +- Pilot/pilotTools.py | 99 +++++++++- Pilot/proxyTools.py | 379 ++++++++++++++++++++++++++++++++++-- 5 files changed, 497 insertions(+), 17 deletions(-) diff --git a/.github/workflows/basic.yml b/.github/workflows/basic.yml index d42df061..22e89b04 100644 --- a/.github/workflows/basic.yml +++ b/.github/workflows/basic.yml @@ -23,7 +23,7 @@ jobs: strategy: matrix: python: - - 2.7.18 + # - 2.7.18 - 3.6.15 - 3.9.17 @@ -53,7 +53,7 @@ jobs: strategy: matrix: python: - - 2.7.18 + # - 2.7.18 - 3.6.15 - 3.9.17 @@ -77,7 +77,7 @@ jobs: strategy: matrix: python: - - 2.7.18 + # - 2.7.18 - 3.6.15 - 3.9.17 diff --git a/Pilot/dirac-pilot.py b/Pilot/dirac-pilot.py index 9c434c97..58c35640 100644 --- a/Pilot/dirac-pilot.py +++ b/Pilot/dirac-pilot.py @@ -41,6 +41,7 @@ getCommand, pythonPathCheck, ) + from Pilot.proxyTools import revokePilotToken except ImportError: from pilotTools import ( Logger, @@ -49,6 +50,7 @@ getCommand, pythonPathCheck, ) + from proxyTools import revokePilotToken ############################ if __name__ == "__main__": @@ -124,3 +126,15 @@ if remote: log.buffer.flush() sys.exit(-1) + + log.info("Pilot tasks finished.") + + if pilotParams.jwt: + if not pilotParams.isLegacyPilot: + log.info("Revoking pilot token.") + revokePilotToken( + pilotParams.diracXServer, + pilotParams.pilotUUID, + pilotParams.jwt, + pilotParams.clientID + ) diff --git a/Pilot/pilotCommands.py b/Pilot/pilotCommands.py index 945a6b78..72acce29 100644 --- a/Pilot/pilotCommands.py +++ b/Pilot/pilotCommands.py @@ -549,7 +549,20 @@ def __init__(self, pilotParams): @logFinalizer def execute(self): - """Calls dirac-admin-add-pilot""" + """Calls dirac-admin-add-pilot + + Deprecated 
in DIRAC V8, new mechanism in V9 and DiracX.""" + + if self.pp.jwt: + if not self.pp.isLegacyPilot: + self.log.warn("Skipping module, normally it is already done via DiracX secret-exchange.") + return + + # If we're here, this is a legacy pilot with a DiracX token embedded in it. + # TODO: See if we do a dirac-admin-add-pilot in DiracX for legacy pilots + else: + # If we're here, this is a DIRAC only pilot without diracX token embedded in it. + pass if not self.pp.pilotReference: self.log.warn("Skipping module, no pilot reference found") @@ -1232,3 +1245,4 @@ def execute(self): """Standard entry point to a pilot command""" self._setNagiosOptions() self._runNagiosProbes() + diff --git a/Pilot/pilotTools.py b/Pilot/pilotTools.py index 8afe0f62..e5c1c12c 100644 --- a/Pilot/pilotTools.py +++ b/Pilot/pilotTools.py @@ -69,9 +69,23 @@ def load_module_from_path(module_name, path_to_module): basestring = str try: - from Pilot.proxyTools import getVO + from Pilot.proxyTools import ( + getVO, + BaseRequest, + TokenBasedRequest, + extract_diracx_payload, + refreshPilotToken, + refreshUserToken + ) except ImportError: - from proxyTools import getVO + from proxyTools import ( + getVO, + BaseRequest, + TokenBasedRequest, + extract_diracx_payload, + refreshPilotToken, + refreshUserToken + ) try: FileNotFoundError # pylint: disable=used-before-assignment @@ -908,10 +922,14 @@ def __init__(self): self.site = "" self.setup = "" self.configServer = "" + self.diracXServer = "" self.ceName = "" self.ceType = "" self.queueName = "" self.gridCEType = "" + self.pilotSecret = "" + self.clientID = "" + self.jwt = {} # maxNumberOfProcessors: the number of # processors allocated to the pilot which the pilot can allocate to one payload # used to set payloadProcessors unless other limits are reached (like the number of processors on the WN) @@ -946,6 +964,7 @@ def __init__(self): self.pilotCFGFile = "pilot.json" self.pilotLogging = False self.loggerURL = None + self.isLegacyPilot = False 
self.loggerTimerInterval = 0 self.loggerBufsize = 1000 self.pilotUUID = "unknown" @@ -996,6 +1015,7 @@ def __init__(self): ("y:", "CEType=", "CE Type (normally InProcess)"), ("z", "pilotLogging", "Activate pilot logging system"), ("C:", "configurationServer=", "Configuration servers to use"), + ("", "diracx_URL=", "DiracX Server URL to use"), ("D:", "disk=", "Require at least MB available"), ("E:", "commandExtensions=", "Python modules with extra commands"), ("F:", "pilotCFGFile=", "Specify pilot CFG file"), @@ -1021,6 +1041,8 @@ def __init__(self): ("", "preinstalledEnvPrefix=", "preinstalled pilot environment area prefix"), ("", "architectureScript=", "architecture script to use"), ("", "CVMFS_locations=", "comma-separated list of CVMS locations"), + ("", "pilotSecret=", "secret that the pilot uses with DiracX"), + ("", "clientID=", "client id used by DiracX to revoke a token"), ) # Possibly get Setup and JSON URL/filename from command line @@ -1047,6 +1069,73 @@ def __init__(self): self.installEnv["X509_USER_PROXY"] = self.certsLocation os.environ["X509_USER_PROXY"] = self.certsLocation + try: + self.__get_diracx_jwt() + except Exception as e: + self.log.error("Error setting DiracX: %s" % e) + # Remove all settings to prevent using it. + self.diracXServer = None + self.pilotSecret = None + self.loggerURL = None + self.jwt = {} + self.log.error("Won't use DiracX.") + + def __get_diracx_jwt(self): + # Pilot auth: two cases + # 1. Has a secret (DiracX Pilot), exchange for a token + # 2. 
Legacy Pilot, has a proxy with a DiracX section in it (extract the jwt from it) + if self.pilotUUID and self.pilotSecret and self.diracXServer: + self.log.info("Fetching JWT in DiracX (URL: %s)" % self.diracXServer) + + config = BaseRequest( + "%s/api/auth/secret-exchange" % ( + self.diracXServer + ), + os.getenv("X509_CERT_DIR"), + self.pilotUUID + ) + + try: + self.jwt = config.executeRequest({ + "pilot_stamp": self.pilotUUID, + "pilot_secret": self.pilotSecret + }) + except HTTPError as e: + self.log.error("Request failed: %s" % str(e)) + self.log.error("Could not fetch pilot tokens.") + if e.code == 401: + # First test if the error occurred because of "bad pilot_stamp" + # If so, this pilot is in the vacuum case + # So we redo auth, but this time with the right data for vacuum cases + self.log.error("Retrying with vacuum case data...") + self.jwt = config.executeRequest({ + "pilot_stamp": self.pilotUUID, + "pilot_secret": self.pilotSecret, + "vo": self.wnVO, + "grid_type": self.gridCEType, + "grid_site": self.site, + "status": "Running" + }) + else: + raise RuntimeError("Can't be a vacuum case.") + + self.log.info("Fetched the pilot token with the pilot secret.") + self.isLegacyPilot = False + elif self.pilotUUID and self.diracXServer: + # Try to extract a token for proxy + self.log.info("Trying to extract diracx token from proxy.") + + cert = os.getenv("X509_USER_PROXY") + if cert: + with open(cert, "rb") as fp: + self.jwt = extract_diracx_payload(fp.read()) + self.isLegacyPilot = True + self.log.info("Successfully extracted token from proxy.") + else: + raise RuntimeError("Could not locate a proxy via X509_USER_PROXY") + else: + self.log.info("PilotUUID, pilotSecret, and diracXServer are needed to support DiracX.") + def __setSecurityDir(self, envName, dirLocation): """Set the environment variable of the `envName`, and add it also to the Pilot Parameters @@ -1151,6 +1240,8 @@ def __initCommandLine2(self): self.keepPythonPath = True elif o in ("-C", 
"--configurationServer"): self.configServer = v + elif o == "--diracx_URL": + self.diracXServer = v elif o in ("-G", "--Group"): self.userGroup = v elif o in ("-x", "--execute"): @@ -1224,6 +1315,10 @@ def __initCommandLine2(self): self.architectureScript = v elif o == "--CVMFS_locations": self.CVMFS_locations = v.split(",") + elif o == "--pilotSecret": + self.pilotSecret = v + elif o == "--clientID": + self.clientID = v def __loadJSON(self): """ diff --git a/Pilot/proxyTools.py b/Pilot/proxyTools.py index a5fa652e..6d22f67f 100644 --- a/Pilot/proxyTools.py +++ b/Pilot/proxyTools.py @@ -1,16 +1,37 @@ -"""few functions for dealing with proxies""" +"""few functions for dealing with proxies and authentication""" from __future__ import absolute_import, division, print_function +import json +import os +import time import re -from base64 import b16decode +import ssl +import sys +from base64 import b16decode, b64decode from subprocess import PIPE, Popen +from random import randint + +try: + IsADirectoryError # pylint: disable=used-before-assignment +except NameError: + IsADirectoryError = IOError + +try: + from urllib.parse import urlencode + from urllib.error import HTTPError + from urllib.request import Request, urlopen +except ImportError: + from urllib import urlencode + from urllib2 import HTTPError, Request, urlopen VOMS_FQANS_OID = b"1.3.6.1.4.1.8005.100.100.4" VOMS_EXTENSION_OID = b"1.3.6.1.4.1.8005.100.100.5" RE_OPENSSL_ANS1_FORMAT = re.compile(br"^\s*\d+:d=(\d+)\s+hl=") +MAX_REQUEST_RETRIES = 10 # If a request failed (503 error), we retry +MAX_TIME_BETWEEN_TRIES = 20 # 20 seconds max between each request def parseASN1(data): cmd = ["openssl", "asn1parse", "-inform", "der"] @@ -30,15 +51,10 @@ def findExtension(oid, lines): def getVO(proxy_data): """Fetches the VO in a chain certificate - Args: - proxy_data (bytes): Bytes for the proxy chain - - Raises: - Exception: Any error related to openssl - NotImplementedError: Not documented error - - Returns: - str: A VO 
+ :param proxy_data: Bytes for the proxy chain + :type proxy_data: bytes + :return: A VO + :rtype: str """ chain = re.findall(br"-----BEGIN CERTIFICATE-----\n.+?\n-----END CERTIFICATE-----", proxy_data, flags=re.DOTALL) @@ -65,3 +81,344 @@ def getVO(proxy_data): if match: return match.groups()[0].decode() raise NotImplementedError("Something went very wrong") + +def extract_diracx_payload(proxy_data): + """Extracts and decodes the DIRACX section from proxy data + + :param proxy_data: The full proxy content (str or bytes) + :return: Parsed DIRACX payload as dict + :rtype: dict + """ + if isinstance(proxy_data, bytes): + proxy_data = proxy_data.decode('utf-8') + + # 1. Extract the DIRACX block + match = re.search(r"-----BEGIN DIRACX-----(.*?)-----END DIRACX-----", proxy_data, re.DOTALL) + if not match: + raise ValueError("DIRACX section not found") + + # 2. Remove whitespaces/newlines and base64-decode the inner content + b64_data = ''.join(match.group(1).strip().splitlines()) + + # 3. Base64 decode + try: + decoded = b64decode(b64_data) + except Exception as e: + raise ValueError("Base64 decoding failed: %s" % str(e)) + + # 4. JSON decode + try: + payload = json.loads(decoded) + except Exception as e: + raise ValueError("JSON decoding failed: %s" % str(e)) + + return payload + +class BaseRequest(object): + """This class helps supporting multiple kinds of requests that require connections""" + + def __init__(self, url, caPath, pilotUUID, name="unknown"): + self.name = name + self.url = url + self.caPath = caPath + self.headers = { + "User-Agent": "Dirac Pilot [Unknown ID]" + } + self.pilotUUID = pilotUUID + # We assume we have only one context, so this variable could be shared to avoid opening n times a cert. 
+ # On the contrary, to avoid race conditions, we do avoid using "self.data" and "self.headers" + self._context = None + + self._prepareRequest() + + def generateUserAgent(self): + """To analyse the traffic, we can send a tailor-made User-Agent""" + self.addHeader("User-Agent", "Dirac Pilot [%s]" % self.pilotUUID) + + def _prepareRequest(self): + """As previously, loads the SSL certificates of the server (to avoid "unknown issuer")""" + # Load the SSL context + self._context = ssl.create_default_context() + self._context.load_verify_locations(capath=self.caPath) + + def addHeader(self, key, value): + """Add a header (key, value) into the request header""" + self.headers[key] = value + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True): + + tries_left = MAX_REQUEST_RETRIES + + while (tries_left > 0): + try: + return self.__execute_raw_request( + raw_data=raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output + ) + except HTTPError as e: + if e.code >= 500 and e.code < 600: + # If we have a 5XX error (server overloaded), we retry + # To avoid DOS-ing the server, we retry a few seconds later + time.sleep(randint(1, MAX_TIME_BETWEEN_TRIES)) + else: + raise e + + tries_left -= 1 + + raise RuntimeError("Too much tries. 
Server down.") + + def __execute_raw_request(self, raw_data, insecure=False, content_type="json", json_output=True): + """Execute a HTTP request with the data, headers, and the pre-defined data (SSL + auth) + + :param raw_data: Data to send + :type raw_data: dict + :param insecure: Deactivate proxy verification WARNING Debug ONLY + :type insecure: bool + :param content_type: Data format to send, either "json" or "x-www-form-urlencoded" or "query" + :type content_type: str + :param json_output: If we have an output + :type json_output: bool + :return: Parsed JSON response + :rtype: dict + """ + if content_type == "json": + data = json.dumps(raw_data).encode("utf-8") + self.addHeader("Content-Type", "application/json") + self.addHeader("Content-Length", str(len(data))) + else: + + data = urlencode(raw_data) + + if content_type == "x-www-form-urlencoded": + if sys.version_info.major == 3: + data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3 + + self.addHeader("Content-Type", "application/x-www-form-urlencoded") + self.addHeader("Content-Length", str(len(data))) + elif content_type == "query": + self.url = self.url + "?" + data + data = None # No body + else: + raise ValueError("Invalid content_type. 
Use 'json' or 'x-www-form-urlencoded'.") + + + request = Request(self.url, data=data, headers=self.headers, method="POST") + + ctx = self._context # Save in case of an insecure request + + if insecure: + # DEBUG ONLY + # Overrides context + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + + if sys.version_info.major == 3: + # Python 3 code + with urlopen(request, context=ctx) as res: + response_data = res.read().decode("utf-8") # Decode response bytes + else: + # Python 2 code + res = urlopen(request, context=ctx) + try: + response_data = res.read() + finally: + res.close() + + if json_output: + try: + return json.loads(response_data) # Parse JSON response + except ValueError: # In Python 2, json.JSONDecodeError is a subclass of ValueError + raise ValueError("Invalid JSON response: %s" % response_data) + + +class TokenBasedRequest(BaseRequest): + """Connected Request with JWT support""" + + def __init__(self, diracx_URL, endpoint_path, caPath, jwtData, pilotUUID): + + url = diracx_URL + endpoint_path + + super(TokenBasedRequest, self).__init__(url, caPath, pilotUUID, "TokenBasedConnection") + self.jwtData = jwtData + self.diracx_URL = diracx_URL + self.endpoint_path = endpoint_path + self.addJwtToHeader() + + def addJwtToHeader(self): + # Adds the JWT in the HTTP request (in the Bearer field) + self.headers["Authorization"] = "Bearer %s" % self.jwtData["access_token"] + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True, tries_left=1, refresh_callback=None): + + while (tries_left >= 0): + + try: + return super(TokenBasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output + ) + except HTTPError as e: + if e.code != 401: + raise e + + # If we have an unauthorized error, then refresh and retry + if refresh_callback: + refresh_callback() + + self.addJwtToHeader() + + tries_left -= 1 + + raise RuntimeError("Too 
much tries. Can't refresh my token.") + +class X509BasedRequest(BaseRequest): + """Connected Request with X509 support""" + + def __init__(self, url, caPath, certEnv, pilotUUID): + super(X509BasedRequest, self).__init__(url, caPath, pilotUUID, "X509BasedConnection") + + self.certEnv = certEnv + self._hasExtraCredentials = False + + # Load X509 once + try: + self._context.load_cert_chain(self.certEnv) + except IsADirectoryError: # assuming it's a dir containing cert and key + self._context.load_cert_chain( + os.path.join(self.certEnv, "hostcert.pem"), os.path.join(self.certEnv, "hostkey.pem") + ) + self._hasExtraCredentials = True + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True): + # Adds a flag if the passed cert is a Directory + if self._hasExtraCredentials: + raw_data["extraCredentials"] = '"hosts"' + return super(X509BasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output + ) + +def refreshUserToken(url, pilotUUID, jwt, clientID): + """ + Refresh the JWT token (as a user). 
+ + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param dict jwt: Shared dict with current JWT; updated in-place + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + # Create request object with required configuration + config = BaseRequest( + url=url + "api/auth/token", + caPath=caPath, + pilotUUID=pilotUUID, + ) + + # Perform the request to refresh the token + response = config.executeRequest( + raw_data={ + "refresh_token": jwt["refresh_token"], + "grant_type": "refresh_token", + "client_id": clientID + }, + content_type="x-www-form-urlencoded", + ) + + # Do NOT assign directly, because jwt is a reference, not a copy + jwt["access_token"] = response["access_token"] + jwt["refresh_token"] = response["refresh_token"] + +def refreshPilotToken(url, pilotUUID, jwt, _=None): + """ + Refresh the JWT token (as a pilot). 
+ + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param dict jwt: Shared dict with current JWT; updated in-place + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + # Create request object with required configuration + config = BaseRequest( + url=url + "api/auth/pilot-token", + caPath=caPath, + pilotUUID=pilotUUID, + ) + + # Perform the request to refresh the token + response = config.executeRequest( + raw_data={ + "refresh_token": jwt["refresh_token"], + "pilot_stamp": pilotUUID + }, + insecure=True, + ) + + # Do NOT assign directly, because jwt is a reference, not a copy + jwt["access_token"] = response["access_token"] + jwt["refresh_token"] = response["refresh_token"] + +def revokePilotToken(url, pilotUUID, jwt, clientID): + """ + Revoke the pilot's refresh token (request sent to api/auth/revoke). 
+ + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param str clientID: ClientID used to revoke tokens + :param dict jwt: Shared dict with current JWT + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + if not url.endswith("/"): + url = url + "/" + + # Create request object with required configuration + config = BaseRequest( + url="%sapi/auth/revoke" % url, + caPath=caPath, + pilotUUID=pilotUUID + ) + + # Prepare refresh token payload + payload = { + "refresh_token": jwt["refresh_token"], + "client_id": clientID + } + + # Perform the request to revoke the token + _response = config.executeRequest( + raw_data=payload, + insecure=True, + content_type="query", + json_output=False + ) From db4a1bfc7a1156e73495b73813b895abe1be8a8b Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Thu, 31 Jul 2025 10:10:53 +0200 Subject: [PATCH 2/3] feat: Add pilot logging (legacy and DiracX) --- Pilot/dirac-pilot.py | 32 +++++--- Pilot/pilotCommands.py | 13 ++-- Pilot/pilotTools.py | 163 +++++++++++++++++++++++++++-------------- 3 files changed, 134 insertions(+), 74 deletions(-) diff --git a/Pilot/dirac-pilot.py b/Pilot/dirac-pilot.py index 58c35640..9b4e6a70 100644 --- a/Pilot/dirac-pilot.py +++ b/Pilot/dirac-pilot.py @@ -65,25 +65,34 @@ # print the buffer, so we have a "classic' logger back in sync. sys.stdout.write(bufContent) # now the remote logger. 
- remote = pilotParams.pilotLogging and (pilotParams.loggerURL is not None) - if remote: + remote = pilotParams.pilotLogging and pilotParams.diracXServer + if remote and pilotParams.jwt != {}: # In a remote logger enabled Dirac version we would have some classic logger content from a wrapper, # which we passed in: receivedContent = "" if not sys.stdin.isatty(): receivedContent = sys.stdin.read() + log = RemoteLogger( - pilotParams.loggerURL, + pilotParams.diracXServer, "Pilot", bufsize=pilotParams.loggerBufsize, pilotUUID=pilotParams.pilotUUID, debugFlag=pilotParams.debugFlag, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt, + legacy_logging=pilotParams.isLegacyLogging, + clientID=pilotParams.clientID ) log.info("Remote logger activated") - log.buffer.write(receivedContent) + log.buffer.write(log.format_to_json( + "INFO", + receivedContent, + )) log.buffer.flush() - log.buffer.write(bufContent) + log.buffer.write(log.format_to_json( + "INFO", + bufContent, + )) else: log = Logger("Pilot", debugFlag=pilotParams.debugFlag) @@ -106,7 +115,7 @@ log.info("Executing commands: %s" % str(pilotParams.commands)) - if remote: + if remote and pilotParams.jwt: # It's safer to cancel the timer here. Each command has got its own logger object with a timer cancelled by the # finaliser. No need for a timer in the "else" code segment below. try: @@ -124,13 +133,16 @@ log.error("Command %s could not be instantiated" % commandName) # send the last message and abandon ship. 
if remote: - log.buffer.flush() + log.buffer.flush(force=True) sys.exit(-1) - + log.info("Pilot tasks finished.") if pilotParams.jwt: - if not pilotParams.isLegacyPilot: + if remote: + log.buffer.flush(force=True) + + if not pilotParams.isLegacyLogging: log.info("Revoking pilot token.") revokePilotToken( pilotParams.diracXServer, diff --git a/Pilot/pilotCommands.py b/Pilot/pilotCommands.py index 72acce29..4db83f41 100644 --- a/Pilot/pilotCommands.py +++ b/Pilot/pilotCommands.py @@ -28,7 +28,6 @@ def __init__(self, pilotParams): import sys import time import traceback -import subprocess from collections import Counter ############################ @@ -44,7 +43,6 @@ def __init__(self, pilotParams): from shlex import quote except ImportError: from pipes import quote - try: from Pilot.pilotTools import ( CommandBase, @@ -92,16 +90,20 @@ def wrapper(self): self.log.info( "Flushing the remote logger buffer for pilot on sys.exit(): %s (exit code:%s)" % (pRef, str(exCode)) ) - self.log.buffer.flush() # flush the buffer unconditionally (on sys.exit()). + try: - sendMessage(self.log.url, self.log.pilotUUID, self.log.wnVO, "finaliseLogs", {"retCode": str(exCode)}) + self.log.error(str(exCode)) + self.log.error(traceback.format_exc()) + self.log.buffer.flush(force=True) except Exception as exc: self.log.error("Remote logger couldn't be finalised %s " % str(exc)) + raise except Exception as exc: # unexpected exit: document it and bail out. self.log.error(str(exc)) self.log.error(traceback.format_exc()) + self.log.buffer.flush(force=True) raise finally: self.log.buffer.cancelTimer() @@ -132,7 +134,6 @@ def __init__(self, pilotParams): @logFinalizer def execute(self): """Get host and local user info, and other basic checks, e.g. 
space available""" - self.log.info("Uname = %s" % " ".join(os.uname())) self.log.info("Host Name = %s" % socket.gethostname()) self.log.info("Host FQDN = %s" % socket.getfqdn()) @@ -1126,8 +1127,6 @@ def execute(self): self.__setInnerCEOpts() self.__startJobAgent() - sys.exit(0) - class NagiosProbes(CommandBase): """Run one or more Nagios probe scripts that follow the Nagios Plugin API: diff --git a/Pilot/pilotTools.py b/Pilot/pilotTools.py index e5c1c12c..84e379e4 100644 --- a/Pilot/pilotTools.py +++ b/Pilot/pilotTools.py @@ -540,7 +540,9 @@ def __init__( pilotUUID="unknown", flushInterval=10, bufsize=1000, - wnVO="unknown", + jwt = {}, + legacy_logging = False, + clientID = "" ): """ c'tor @@ -550,10 +552,27 @@ def __init__( super(RemoteLogger, self).__init__(name, debugFlag, pilotOutput) self.url = url self.pilotUUID = pilotUUID - self.wnVO = wnVO self.isPilotLoggerOn = isPilotLoggerOn - sendToURL = partial(sendMessage, url, pilotUUID, wnVO, "sendMessage") - self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval) + sendToURL = partial(sendMessage, url, pilotUUID, legacy_logging, clientID) + self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval, jwt=jwt) + + def format_to_json(self, level, message): + + escaped = json.dumps(message)[1:-1] # remove outer quotes + + # Split on escaped newlines + splitted_message = escaped.split("\\n") + + output = [] + for mess in splitted_message: + if mess: + output.append({ + "timestamp": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "severity": level, + "message": mess, + "scope": self.name + }) + return output def debug(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? 
@@ -561,25 +580,25 @@ def debug(self, msg, header=True, _sendPilotLog=False): if ( self.isPilotLoggerOn and self.debugFlag ): # the -d flag activates this debug flag in CommandBase via PilotParams - self.sendMessage(self.messageTemplate.format(level="DEBUG", message=msg)) + self.sendMessage(self.format_to_json(level="DEBUG", message=msg)) def error(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? super(RemoteLogger, self).error(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="ERROR", message=msg)) + self.sendMessage(self.format_to_json(level="ERROR", message=msg)) def warn(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? super(RemoteLogger, self).warn(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="WARNING", message=msg)) + self.sendMessage(self.format_to_json(level="WARNING", message=msg)) def info(self, msg, header=True, _sendPilotLog=False): # TODO: Send pilot log remotely? super(RemoteLogger, self).info(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="INFO", message=msg)) + self.sendMessage(self.format_to_json(level="INFO", message=msg)) def sendMessage(self, msg): """ @@ -591,7 +610,7 @@ def sendMessage(self, msg): :rtype: None """ try: - self.buffer.write(msg + "\n") + self.buffer.write(msg) except Exception as err: super(RemoteLogger, self).error("Message not sent") super(RemoteLogger, self).error(str(err)) @@ -618,7 +637,7 @@ class FixedSizeBuffer(object): Once it's full, a message is sent to a remote server and the buffer is renewed. """ - def __init__(self, senderFunc, bufsize=1000, autoflush=10): + def __init__(self, senderFunc, bufsize=250, autoflush=10, jwt={}): """ Constructor. 
@@ -636,34 +655,34 @@ def __init__(self, senderFunc, bufsize=1000, autoflush=10): self._timer.start() else: self._timer = None - self.output = StringIO() + self.output = [] self.bufsize = bufsize self._nlines = 0 self.senderFunc = senderFunc + self.jwt = jwt + # A fixed buffer used by a remote buffer can be deactivated: + # If there's a 403/401 error, instead of crashing the pilot, + # we will deactivate the log sending, and prefer just running the pilot. + self.activated = True @synchronized - def write(self, text): + def write(self, content_json): """ Write text to a string buffer. Newline characters are counted and number of lines in the buffer is increased accordingly. - :param text: text string to write - :type text: str + :param content_json: Json to send, format following format_to_json + :type content_json: list[dict] :return: None :rtype: None """ - # reopen the buffer in a case we had to flush a partially filled buffer - if self.output.closed: - self.output = StringIO() - self.output.write(text) - self._nlines += max(1, text.count("\n")) + if not self.activated: + pass + + self.output.extend(content_json) + self._nlines += max(1, len(content_json)) self.sendFullBuffer() - @synchronized - def getValue(self): - content = self.output.getvalue() - return content - @synchronized def sendFullBuffer(self): """ @@ -673,22 +692,26 @@ def sendFullBuffer(self): if self._nlines >= self.bufsize: self.flush() - self.output = StringIO() + self.output = [] @synchronized - def flush(self): + def flush(self, force=False): """ Flush the buffer and send log records to a remote server. The buffer is closed as well. 
:return: None :rtype: None """ - if not self.output.closed and self._nlines > 0: - self.output.flush() - buf = self.getValue() - self.senderFunc(buf) + if not self.activated: + pass + + if force or (self.output and self._nlines > 0): + try: + self.senderFunc(self.jwt, self.output) + except Exception as e: + print("Deactivating fixed size buffer due to", str(e)) + self.activated = False self._nlines = 0 - self.output.close() def cancelTimer(self): """ @@ -701,40 +724,60 @@ def cancelTimer(self): self._timer.cancel() -def sendMessage(url, pilotUUID, wnVO, method, rawMessage): +def sendMessage(diracx_URL, pilotUUID, legacy=False, clientID="", jwt={}, rawMessage = []): """ Invoke a remote method on a Tornado server and pass a JSON message to it. :param str url: Server URL :param str pilotUUID: pilot unique ID - :param str wnVO: VO name, relevant only if not contained in a proxy :param str method: a method to be invoked :param str rawMessage: a message to be sent, in JSON format + :param dict jwt: JWT for the requests :return: None. """ + caPath = os.getenv("X509_CERT_DIR") - cert = os.getenv("X509_USER_PROXY") - - context = ssl.create_default_context() - context.load_verify_locations(capath=caPath) - message = json.dumps((json.dumps(rawMessage), pilotUUID, wnVO)) + raw_data = { + "pilot_stamp": pilotUUID, + "lines": rawMessage + } - try: - context.load_cert_chain(cert) # this is a proxy - raw_data = {"method": method, "args": message} - except IsADirectoryError: # assuming it'a dir containing cert and key - context.load_cert_chain(os.path.join(cert, "hostcert.pem"), os.path.join(cert, "hostkey.pem")) - raw_data = {"method": method, "args": message, "extraCredentials": '"hosts"'} - - if sys.version_info.major == 3: - data = urlencode(raw_data).encode("utf-8") # encode to bytes ! 
for python3 + if not diracx_URL.endswith("/"): + diracx_URL += "/" + + if legacy: + endpoint_path = "api/pilots/legacy/message" + refresh_callback = partial( + refreshUserToken, + diracx_URL, + pilotUUID, + jwt, + clientID + ) else: - # Python2 - data = urlencode(raw_data) + endpoint_path = "api/pilots/internal/message" + refresh_callback = partial( + refreshPilotToken, + diracx_URL, + pilotUUID, + jwt + ) + + config = TokenBasedRequest( + diracx_URL=diracx_URL, + endpoint_path=endpoint_path, + caPath=caPath, + jwtData=jwt, + pilotUUID=pilotUUID + ) - res = urlopen(url, data, context=context) - res.close() + # Do the request + _res = config.executeRequest( + raw_data=raw_data, + json_output=False, + refresh_callback=refresh_callback + ) class CommandBase(object): @@ -754,7 +797,7 @@ def __init__(self, pilotParams): self.debugFlag = pilotParams.debugFlag loggerURL = pilotParams.loggerURL # URL present and the flag is set: - isPilotLoggerOn = pilotParams.pilotLogging and (loggerURL is not None) + isPilotLoggerOn = pilotParams.pilotLogging and pilotParams.diracXServer interval = pilotParams.loggerTimerInterval bufsize = pilotParams.loggerBufsize @@ -763,13 +806,15 @@ def __init__(self, pilotParams): else: # remote logger self.log = RemoteLogger( - loggerURL, + self.pp.diracXServer, self.__class__.__name__, pilotUUID=pilotParams.pilotUUID, debugFlag=self.debugFlag, flushInterval=interval, bufsize=bufsize, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt, + legacy_logging=pilotParams.isLegacyLogging, + clientID=pilotParams.clientID ) self.log.isPilotLoggerOn = isPilotLoggerOn @@ -816,11 +861,15 @@ def executeAndGetOutput(self, cmd, environDict=None): if stream == _p.stderr: sys.stderr.write(outChunk) sys.stderr.flush() + # TODO: See if wee need also to log here else: sys.stdout.write(outChunk) sys.stdout.flush() if hasattr(self.log, "buffer") and self.log.isPilotLoggerOn: - self.log.buffer.write(outChunk) + self.log.buffer.write(self.log.format_to_json( + "COMMAND", + 
outChunk + )) outData += outChunk # If no data was read on any of the pipes then the process has finished if not dataWasRead: @@ -846,7 +895,7 @@ def exitWithError(self, errorCode): self.log.info("List of child processes of current PID:") retCode, _outData = self.executeAndGetOutput( - "ps --forest -o pid,%%cpu,%%mem,tty,stat,time,cmd -g %d" % os.getpid() + "ps --forest -o pid,%%cpu,%%mem,tty,stat,time,cmd --ppid %d" % os.getpid() ) if retCode: self.log.error("Failed to issue ps [ERROR %d] " % retCode) @@ -966,7 +1015,7 @@ def __init__(self): self.loggerURL = None self.isLegacyPilot = False self.loggerTimerInterval = 0 - self.loggerBufsize = 1000 + self.loggerBufsize = 250 self.pilotUUID = "unknown" self.modules = "" self.userEnvVariables = "" From c7101cd5b8d4d30dc4a5e22d1dd40d2cb3022c16 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Thu, 7 Aug 2025 14:56:04 +0200 Subject: [PATCH 3/3] fix: Small fix to the legacy pilot --- Pilot/dirac-pilot.py | 4 ++-- Pilot/pilotTools.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Pilot/dirac-pilot.py b/Pilot/dirac-pilot.py index 9b4e6a70..04a44e98 100644 --- a/Pilot/dirac-pilot.py +++ b/Pilot/dirac-pilot.py @@ -80,7 +80,7 @@ pilotUUID=pilotParams.pilotUUID, debugFlag=pilotParams.debugFlag, jwt=pilotParams.jwt, - legacy_logging=pilotParams.isLegacyLogging, + legacy_logging=pilotParams.isLegacyPilot, clientID=pilotParams.clientID ) log.info("Remote logger activated") @@ -142,7 +142,7 @@ if remote: log.buffer.flush(force=True) - if not pilotParams.isLegacyLogging: + if not pilotParams.isLegacyPilot: log.info("Revoking pilot token.") revokePilotToken( pilotParams.diracXServer, diff --git a/Pilot/pilotTools.py b/Pilot/pilotTools.py index 84e379e4..6f8d4f9a 100644 --- a/Pilot/pilotTools.py +++ b/Pilot/pilotTools.py @@ -813,7 +813,7 @@ def __init__(self, pilotParams): flushInterval=interval, bufsize=bufsize, jwt=pilotParams.jwt, - legacy_logging=pilotParams.isLegacyLogging, + 
legacy_logging=pilotParams.isLegacyPilot, clientID=pilotParams.clientID )