@@ -69,9 +69,23 @@ def load_module_from_path(module_name, path_to_module):
69
69
basestring = str
70
70
71
71
try :
72
- from Pilot .proxyTools import getVO
72
+ from Pilot .proxyTools import (
73
+ getVO ,
74
+ BaseRequest ,
75
+ TokenBasedRequest ,
76
+ extract_diracx_payload ,
77
+ refreshPilotToken ,
78
+ refreshUserToken
79
+ )
73
80
except ImportError :
74
- from proxyTools import getVO
81
+ from proxyTools import (
82
+ getVO ,
83
+ BaseRequest ,
84
+ TokenBasedRequest ,
85
+ extract_diracx_payload ,
86
+ refreshPilotToken ,
87
+ refreshUserToken
88
+ )
75
89
76
90
try :
77
91
FileNotFoundError # pylint: disable=used-before-assignment
@@ -908,10 +922,14 @@ def __init__(self):
908
922
self .site = ""
909
923
self .setup = ""
910
924
self .configServer = ""
925
+ self .diracXServer = ""
911
926
self .ceName = ""
912
927
self .ceType = ""
913
928
self .queueName = ""
914
929
self .gridCEType = ""
930
+ self .pilotSecret = ""
931
+ self .clientID = ""
932
+ self .jwt = {}
915
933
# maxNumberOfProcessors: the number of
916
934
# processors allocated to the pilot which the pilot can allocate to one payload
917
935
# used to set payloadProcessors unless other limits are reached (like the number of processors on the WN)
@@ -946,6 +964,7 @@ def __init__(self):
946
964
self .pilotCFGFile = "pilot.json"
947
965
self .pilotLogging = False
948
966
self .loggerURL = None
967
+ self .isLegacyPilot = False
949
968
self .loggerTimerInterval = 0
950
969
self .loggerBufsize = 1000
951
970
self .pilotUUID = "unknown"
@@ -996,6 +1015,7 @@ def __init__(self):
996
1015
("y:" , "CEType=" , "CE Type (normally InProcess)" ),
997
1016
("z" , "pilotLogging" , "Activate pilot logging system" ),
998
1017
("C:" , "configurationServer=" , "Configuration servers to use" ),
1018
+ ("" , "diracx_URL=" , "DiracX Server URL to use" ),
999
1019
("D:" , "disk=" , "Require at least <space> MB available" ),
1000
1020
("E:" , "commandExtensions=" , "Python modules with extra commands" ),
1001
1021
("F:" , "pilotCFGFile=" , "Specify pilot CFG file" ),
@@ -1021,6 +1041,8 @@ def __init__(self):
1021
1041
("" , "preinstalledEnvPrefix=" , "preinstalled pilot environment area prefix" ),
1022
1042
("" , "architectureScript=" , "architecture script to use" ),
1023
1043
("" , "CVMFS_locations=" , "comma-separated list of CVMS locations" ),
1044
+ ("" , "pilotSecret=" , "secret that the pilot uses with DiracX" ),
1045
+ ("" , "clientID=" , "client id used by DiracX to revoke a token" ),
1024
1046
)
1025
1047
1026
1048
# Possibly get Setup and JSON URL/filename from command line
@@ -1047,6 +1069,73 @@ def __init__(self):
1047
1069
self .installEnv ["X509_USER_PROXY" ] = self .certsLocation
1048
1070
os .environ ["X509_USER_PROXY" ] = self .certsLocation
1049
1071
1072
+ try :
1073
+ self .__get_diracx_jwt ()
1074
+ except Exception as e :
1075
+ self .log .error ("Error setting DiracX: %s" % e )
1076
+ # Remove all settings to prevent using it.
1077
+ self .diracXServer = None
1078
+ self .pilotSecret = None
1079
+ self .loggerURL = None
1080
+ self .jwt = {}
1081
+ self .log .error ("Won't use DiracX." )
1082
+
1083
+ def __get_diracx_jwt (self ):
1084
+ # Pilot auth: two cases
1085
+ # 1. Has a secret (DiracX Pilot), exchange for a token
1086
+ # 2. Legacy Pilot, has a proxy with a DiracX section in it (extract the jwt from it)
1087
+ if self .pilotUUID and self .pilotSecret and self .diracXServer :
1088
+ self .log .info ("Fetching JWT in DiracX (URL: %s)" % self .diracXServer )
1089
+
1090
+ config = BaseRequest (
1091
+ "%s/api/auth/secret-exchange" % (
1092
+ self .diracXServer
1093
+ ),
1094
+ os .getenv ("X509_CERT_DIR" ),
1095
+ self .pilotUUID
1096
+ )
1097
+
1098
+ try :
1099
+ self .jwt = config .executeRequest ({
1100
+ "pilot_stamp" : self .pilotUUID ,
1101
+ "pilot_secret" : self .pilotSecret
1102
+ })
1103
+ except HTTPError as e :
1104
+ self .log .error ("Request failed: %s" % str (e ))
1105
+ self .log .error ("Could not fetch pilot tokens." )
1106
+ if e .code == 401 :
1107
+ # First test if the error occurred because of "bad pilot_stamp"
1108
+ # If so, this pilot is in the vacuum case
1109
+ # So we redo auth, but this time with the right data for vacuum cases
1110
+ self .log .error ("Retrying with vacuum case data..." )
1111
+ self .jwt = config .executeRequest ({
1112
+ "pilot_stamp" : self .pilotUUID ,
1113
+ "pilot_secret" : self .pilotSecret ,
1114
+ "vo" : self .wnVO ,
1115
+ "grid_type" : self .gridCEType ,
1116
+ "grid_site" : self .site ,
1117
+ "status" : "Running"
1118
+ })
1119
+ else :
1120
+ raise RuntimeError ("Can't be a vacuum case." )
1121
+
1122
+ self .log .info ("Fetched the pilot token with the pilot secret." )
1123
+ self .isLegacyPilot = False
1124
+ elif self .pilotUUID and self .diracXServer :
1125
+ # Try to extract a token for proxy
1126
+ self .log .info ("Trying to extract diracx token from proxy." )
1127
+
1128
+ cert = os .getenv ("X509_USER_PROXY" )
1129
+ if cert :
1130
+ with open (cert , "rb" ) as fp :
1131
+ self .jwt = extract_diracx_payload (fp .read ())
1132
+ self .isLegacyPilot = True
1133
+ self .log .info ("Successfully extracted token from proxy." )
1134
+ else :
1135
+ raise RuntimeError ("Could not locate a proxy via X509_USER_PROXY" )
1136
+ else :
1137
+ self .log .info ("PilotUUID, pilotSecret, and diracXServer are needed to support DiracX." )
1138
+
1050
1139
def __setSecurityDir (self , envName , dirLocation ):
1051
1140
"""Set the environment variable of the `envName`, and add it also to the Pilot Parameters
1052
1141
@@ -1151,6 +1240,8 @@ def __initCommandLine2(self):
1151
1240
self .keepPythonPath = True
1152
1241
elif o in ("-C" , "--configurationServer" ):
1153
1242
self .configServer = v
1243
+ elif o == "--diracx_URL" :
1244
+ self .diracXServer = v
1154
1245
elif o in ("-G" , "--Group" ):
1155
1246
self .userGroup = v
1156
1247
elif o in ("-x" , "--execute" ):
@@ -1224,6 +1315,10 @@ def __initCommandLine2(self):
1224
1315
self .architectureScript = v
1225
1316
elif o == "--CVMFS_locations" :
1226
1317
self .CVMFS_locations = v .split ("," )
1318
+ elif o == "--pilotSecret" :
1319
+ self .pilotSecret = v
1320
+ elif o == "--clientID" :
1321
+ self .clientID = v
1227
1322
1228
1323
def __loadJSON (self ):
1229
1324
"""
0 commit comments