Skip to content
Open
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ This JSON file specifies the plugin and Slurm configuration parameters.
"LogLevel": "STRING",
"LogFileName": "STRING",
"SlurmBinPath": "STRING",
"NodeNameStartsFrom1": BOOLEAN,
"NodeNameStartsWithNodeGroupName": BOOLEAN,
"SlurmConf": {
"PrivateData": "STRING",
"ResumeProgram": "STRING",
Expand All @@ -60,6 +62,8 @@ This JSON file specifies the plugin and Slurm configuration parameters.
* `LogLevel`: Logging level. Possible values are `CRITICAL`, `ERROR`, `WARNING`, `INFO`, `DEBUG`. Default is `DEBUG`.
* `LogFileName`: Full path to the log file location. Default is `PLUGIN_PATH\aws_plugin.log`.
* `SlurmBinPath`: Full path to the folder that contains Slurm binaries like `scontrol` or `sinfo`. Example: `/slurm/bin`.
* `NodeNameStartsFrom1`: Optional. By default, node numbering starts from 0, e.g. `aws-c5_24xlarge-0`. Setting this flag makes node numbering start from 1 instead.
* `NodeNameStartsWithNodeGroupName`: Optional. By default, a node name starts with the partition name followed by "-". Setting this flag makes node names consist of the node group name followed directly by a number, e.g. `c5_24xlarge001`.
* `SlurmConf`: These attributes are used by `generate_conf.py` to generate the content that must be appended to the Slurm configuration file. You must specify at least the following attributes:
* `PrivateData`: Must be equal to `CLOUD` such that EC2 compute nodes that are idle are returned by Slurm command outputs such as `sinfo`.
* `ResumeProgram`: Full path to the location of `resume.py`. Example: `/slurm/etc/aws/resume.py`.
Expand Down
84 changes: 74 additions & 10 deletions common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@

import boto3

# set to use proxy and AWS credentials
# os.environ["https_proxy"] = "https://user:password@proxy-server:port"
# os.environ["AWS_ACCESS_KEY_ID"] = ""
# os.environ["AWS_SECRET_ACCESS_KEY"] = ""


dir_path = os.path.dirname(os.path.realpath(__file__)) # Folder where resides the Python files

logger = None # Global variable for the logging.Logger object
config = None # Global variable for the config parameters
partitions = None # Global variable that stores partitions details


# Create and return a logging.Logger object
# - scriptname: name of the module
# - levelname: log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
Expand Down Expand Up @@ -142,6 +146,7 @@ def get_common(scriptname):

# Create a logger
logger = get_logger(scriptname, config['LogLevel'], config['LogFileName'])
# Comment out the following line to reduce log output
logger.debug('Config: %s' %json.dumps(config, indent=4))

# Validate the structure of config.json
Expand Down Expand Up @@ -171,6 +176,7 @@ def get_common(scriptname):
sys.exit(1)
finally:
partitions = partitions_json['Partitions']
# Comment out the following line to reduce log output
logger.debug('Partitions: %s' %json.dumps(partitions_json, indent=4))

return logger, config, partitions
Expand All @@ -191,11 +197,17 @@ def get_node_name(partition, nodegroup, node_id=''):
nodegroup_name = nodegroup['NodeGroupName']
else:
nodegroup_name = nodegroup

if node_id == '':
return '%s-%s' %(partition_name, nodegroup_name)

if config['NodeNameStartsWithNodeGroupName']:
if node_id == '':
return '%s' %(nodegroup_name)
else:
return '%s%s' %(nodegroup_name, node_id)
else:
return '%s-%s-%s' %(partition_name, nodegroup_name, node_id)
if node_id == '':
return '%s-%s' %(partition_name, nodegroup_name)
else:
return '%s-%s-%s' %(partition_name, nodegroup_name, node_id)


# Return the name of a node [partition_name]-[nodegroup_name][id]
Expand All @@ -208,7 +220,18 @@ def get_node_range(partition, nodegroup, nb_nodes=None):
nb_nodes = nodegroup['MaxNodes']

if nb_nodes > 1:
return '%s-[0-%s]' %(get_node_name(partition, nodegroup), nb_nodes-1)
if config['NodeNameStartsFrom1']:
if config['NodeNameStartsWithNodeGroupName']:
digits=len(str(nb_nodes))
return '%s[%s-%s]' %(get_node_name(partition, nodegroup),str(1).zfill(digits),nb_nodes)
else:
return '%s-[1-%s]' %(get_node_name(partition, nodegroup), nb_nodes)
else:
if config['NodeNameStartsWithNodeGroupName']:
digits=len(str(nb_nodes))
return '%s[%s-%s]' %(get_node_name(partition, nodegroup),str(0).zfill(digits),nb_nodes-1)
else:
return '%s-[0-%s]' %(get_node_name(partition, nodegroup), nb_nodes-1)
else:
return '%s-0' %(get_node_name(partition, nodegroup))

Expand Down Expand Up @@ -241,14 +264,23 @@ def expand_hostlist(hostlist):
# Take a list of node names in input and return a dict with result[partition_name][nodegroup_name] = list of node ids
def parse_node_names(node_names):
result = {}

if config['NodeNameStartsWithNodeGroupName']:
groups = 2
pattern = '^([-a-zA-Z0-9]+[-a-zA-z])([0-9]+)$'
else:
pattern = '^([a-zA-Z0-9]+)-([a-zA-Z0-9]+)-([0-9]+)$'

for node_name in node_names:

# For each node: extract partition name, node group name and node id
pattern = '^([a-zA-Z0-9]+)-([a-zA-Z0-9]+)-([0-9]+)$'
match = re.match(pattern, node_name)
if match:
partition_name, nodegroup_name, node_id = match.groups()

if groups == 2:
nodegroup_name, node_id = match.groups()
partition_name = get_partition_name(nodegroup_name)
else:
partition_name, nodegroup_name, node_id = match.groups()

# Add to result
if not partition_name in result:
result[partition_name] = {}
Expand All @@ -271,6 +303,16 @@ def get_partition_nodegroup(partition_name, nodegroup_name):
# Return None if it does not exist
return None

# Look up the partition that owns a node group
# - nodegroup_name: value compared against the 'NodeGroupName' of every node group
# Return the 'PartitionName' of the first partition that contains a matching node group,
# or None if no partition contains one
def get_partition_name(nodegroup_name):

    # Lazily walk partitions and their node groups in declaration order
    owners = (
        candidate['PartitionName']
        for candidate in partitions
        for group in candidate['NodeGroups']
        if group['NodeGroupName'] == nodegroup_name
    )

    # The first hit wins; None when nothing matched
    return next(owners, None)


# Use 'scontrol update node' to update nodes
def update_node(node_name, parameters):
Expand Down Expand Up @@ -303,3 +345,25 @@ def get_ec2_client(nodegroup):
sys.exit(1)
else:
return boto3.client('ec2', region_name=nodegroup['Region'])


# Exit when this host is not the primary slurmctld of a multi-controller cluster
# - Counts the lines of 'scontrol show config' that contain 'SlurmctldHost' and takes the
#   cluster name from the line that contains 'ClusterName'
# - With at most one controller line there is nothing to check and the function returns
# - Otherwise reads the primary address from 'sacctmgr show cluster <name>' and compares it
#   with the address this host's name resolves to
# Exits with status 1 if the addresses differ, or if the cluster name or the primary
# address cannot be found in the command output
def check_ha():
    out = run_scommand('scontrol', ['show', 'config'])
    ctld_hosts = 0
    cluster_name = None
    for line in out:
        if 'SlurmctldHost' in line:
            ctld_hosts += 1
        if 'ClusterName' in line:
            cluster_name = line.split('= ')[-1]  # get last list element

    # A single controller means no high-availability setup: nothing to verify
    if ctld_hosts <= 1:
        return

    # Previously an absent 'ClusterName' line surfaced as an UnboundLocalError
    if cluster_name is None:
        logger.error('Unable to find ClusterName in the scontrol output. Exiting...')
        sys.exit(1)

    out = run_scommand('sacctmgr', ['show', 'cluster', cluster_name])
    primary_ip = None
    for line in out:
        if cluster_name in line:
            # Collapse runs of spaces so the columns can be split on a single space
            fields = re.sub(r'([ ])(\1+)', r'\1', line).split(' ')
            # NOTE(review): the third field is presumably the ControlHost column - confirm
            # against the sacctmgr output format and how run_scommand trims lines
            if len(fields) > 2:
                primary_ip = fields[2]

    # Previously a missing cluster line surfaced as an UnboundLocalError
    if primary_ip is None:
        logger.error('Unable to find the primary slurmctld address for cluster %s. Exiting...' %cluster_name)
        sys.exit(1)

    hostname = socket.gethostname()
    host_ip = socket.gethostbyname(hostname)
    if primary_ip != host_ip:
        logger.info('This host is not the primary slurmctld. Exiting...')
        sys.exit(1)