diff --git a/README.md b/README.md index f054714..d7abe20 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,8 @@ This JSON file specifies the plugin and Slurm configuration parameters. "LogLevel": "STRING", "LogFileName": "STRING", "SlurmBinPath": "STRING", + "NodeNameStartsFrom1": BOOLEAN, + "NodeNameStartsWithNodeGroupName": BOOLEAN, "SlurmConf": { "PrivateData": "STRING", "ResumeProgram": "STRING", @@ -60,6 +62,8 @@ This JSON file specifies the plugin and Slurm configuration parameters. * `LogLevel`: Logging level. Possible values are `CRITICAL`, `ERROR`, `WARNING`, `INFO`, `DEBUG`. Default is `DEBUG`. * `LogFileName`: Full path to the log file location. Default is `PLUGIN_PATH\aws_plugin.log`. * `SlurmBinPath`: Full path to the folder that contains Slurm binaries like `scontrol` or `sinfo`. Example: `/slurm/bin`. +* `NodeNameStartsFrom1`: Optional. By default, node numbering starts from 0, like aws-c5_24xlarge-0. Setting this flag makes node numbering start from 1 instead. +* `NodeNameStartsWithNodeGroupName`: Optional. By default, a node name starts with the partition name followed by "-". Setting this flag makes the node name consist of the node group name followed directly by a number, like: c5_24xlarge001. * `SlurmConf`: These attributes are used by `generate_conf.py` to generate the content that must be appended to the Slurm configuration file. You must specify at least the following attributes: * `PrivateData`: Must be equal to `CLOUD` such that EC2 compute nodes that are idle are returned by Slurm command outputs such as `sinfo`. * `ResumeProgram`: Full path to the location of `resume.py`. Example: `/slurm/etc/aws/resume.py`. 
diff --git a/common.py b/common.py index 58aa9a7..acf2914 100644 --- a/common.py +++ b/common.py @@ -7,6 +7,11 @@ import boto3 +# set to use proxy and AWS credentials +# os.environ["https_proxy"] = "https://user:password@proxy-server:port" +# os.environ["AWS_ACCESS_KEY_ID"] = "" +# os.environ["AWS_SECRET_ACCESS_KEY"] = "" + dir_path = os.path.dirname(os.path.realpath(__file__)) # Folder where resides the Python files @@ -14,7 +19,6 @@ config = None # Global variable for the config parameters partitions = None # Global variable that stores partitions details - # Create and return a logging.Logger object # - scriptname: name of the module # - levelname: log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) @@ -142,6 +146,7 @@ def get_common(scriptname): # Create a logger logger = get_logger(scriptname, config['LogLevel'], config['LogFileName']) + # comment to reduce output logger.debug('Config: %s' %json.dumps(config, indent=4)) # Validate the structure of config.json @@ -171,6 +176,7 @@ def get_common(scriptname): sys.exit(1) finally: partitions = partitions_json['Partitions'] + # comment to reduce output logger.debug('Partitions: %s' %json.dumps(partitions_json, indent=4)) return logger, config, partitions @@ -191,11 +197,17 @@ def get_node_name(partition, nodegroup, node_id=''): nodegroup_name = nodegroup['NodeGroupName'] else: nodegroup_name = nodegroup - - if node_id == '': - return '%s-%s' %(partition_name, nodegroup_name) + + if config['NodeNameStartsWithNodeGroupName']: + if node_id == '': + return '%s' %(nodegroup_name) + else: + return '%s%s' %(nodegroup_name, node_id) else: - return '%s-%s-%s' %(partition_name, nodegroup_name, node_id) + if node_id == '': + return '%s-%s' %(partition_name, nodegroup_name) + else: + return '%s-%s-%s' %(partition_name, nodegroup_name, node_id) # Return the name of a node [partition_name]-[nodegroup_name][id] @@ -208,7 +220,18 @@ def get_node_range(partition, nodegroup, nb_nodes=None): nb_nodes = nodegroup['MaxNodes'] if nb_nodes 
> 1: - return '%s-[0-%s]' %(get_node_name(partition, nodegroup), nb_nodes-1) + if config['NodeNameStartsFrom1']: + if config['NodeNameStartsWithNodeGroupName']: + digits=len(str(nb_nodes)) + return '%s[%s-%s]' %(get_node_name(partition, nodegroup),str(1).zfill(digits),nb_nodes) + else: + return '%s-[1-%s]' %(get_node_name(partition, nodegroup), nb_nodes) + else: + if config['NodeNameStartsWithNodeGroupName']: + digits=len(str(nb_nodes)) + return '%s[%s-%s]' %(get_node_name(partition, nodegroup),str(0).zfill(digits),nb_nodes-1) + else: + return '%s-[0-%s]' %(get_node_name(partition, nodegroup), nb_nodes-1) else: return '%s-0' %(get_node_name(partition, nodegroup)) @@ -241,14 +264,23 @@ def expand_hostlist(hostlist): # Take a list of node names in input and return a dict with result[partition_name][nodegroup_name] = list of node ids def parse_node_names(node_names): result = {} + + if config['NodeNameStartsWithNodeGroupName']: + groups = 2 + pattern = '^([-a-zA-Z0-9]+[-a-zA-z])([0-9]+)$' + else: + pattern = '^([a-zA-Z0-9]+)-([a-zA-Z0-9]+)-([0-9]+)$' + for node_name in node_names: - # For each node: extract partition name, node group name and node id - pattern = '^([a-zA-Z0-9]+)-([a-zA-Z0-9]+)-([0-9]+)$' match = re.match(pattern, node_name) if match: - partition_name, nodegroup_name, node_id = match.groups() - + if groups == 2: + nodegroup_name, node_id = match.groups() + partition_name = get_partition_name(nodegroup_name) + else: + partition_name, nodegroup_name, node_id = match.groups() + # Add to result if not partition_name in result: result[partition_name] = {} @@ -271,6 +303,16 @@ def get_partition_nodegroup(partition_name, nodegroup_name): # Return None if it does not exist return None +# Return partition name based on node group name +def get_partition_name(nodegroup_name): + + for partition in partitions: + for nodegroup in partition['NodeGroups']: + if nodegroup['NodeGroupName'] == nodegroup_name: + return partition['PartitionName'] + # ReturnNone if it 
doesn't exist + return None + # Use 'scontrol update node' to update nodes def update_node(node_name, parameters): @@ -303,3 +345,25 @@ def get_ec2_client(nodegroup): sys.exit(1) else: return boto3.client('ec2', region_name=nodegroup['Region']) + + +def check_ha(): + args = ['show', 'config'] + out = run_scommand('scontrol', args) + ctld_hosts = 0 + for line in out: + if 'SlurmctldHost' in line: + ctld_hosts += 1 + if 'ClusterName' in line: + cluster_name = line.split('= ')[-1] # get last list element + if ctld_hosts > 1: + args = ['show', 'cluster', cluster_name] + out = run_scommand('sacctmgr', args) + for line in out: + if cluster_name in line: + primary_ip = re.sub(r'([ ])(\1+)', r'\1', line).split(' ')[2] + hostname = socket.gethostname() + host_ip = socket.gethostbyname(hostname) + if primary_ip != host_ip: + logger.info('This host is not the primary slurmctld. Exiting...') + sys.exit(1) \ No newline at end of file