Skip to content

Commit 6c9d083

Browse files
authored
Merge pull request #315 from djarecka/fix/ext_template
[mnt, bug] refactoring the template formatting (closes #314)
2 parents 5005173 + d94b3b3 commit 6c9d083

File tree

7 files changed

+481
-75
lines changed

7 files changed

+481
-75
lines changed

pydra/engine/core.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,7 @@ def _collect_outputs(self):
423423
self.output_spec = output_from_inputfields(self.output_spec, self.inputs)
424424
output_klass = make_klass(self.output_spec)
425425
output = output_klass(**{f.name: None for f in attr.fields(output_klass)})
426-
other_output = output.collect_additional_outputs(
427-
self.input_spec, self.inputs, self.output_dir
428-
)
426+
other_output = output.collect_additional_outputs(self.inputs, self.output_dir)
429427
return attr.evolve(output, **run_output, **other_output)
430428

431429
def split(self, splitter, overwrite=False, **kwargs):

pydra/engine/helpers.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from hashlib import sha256
1111
import subprocess as sp
1212
import getpass
13-
import uuid
13+
import re
1414
from time import strftime
1515
from traceback import format_exception
1616

@@ -636,3 +636,36 @@ def position_adjustment(pos_args):
636636
cmd_args += el[1]
637637

638638
return cmd_args
639+
640+
641+
def argstr_formatting(argstr, inputs, value_updates=None):
    """Format an ``argstr`` template that contains ``{field_name}`` placeholders.

    Parameters
    ----------
    argstr : str
        Template string; every ``{name}`` placeholder made only of word
        characters is replaced with the value of the input field ``name``.
    inputs
        attrs-based inputs object; it is converted with ``attr.asdict``.
    value_updates : dict, optional
        Values that override the ones taken from ``inputs``
        (e.g. a single element when the field holds a list).

    Returns
    -------
    str
        The formatted string. Placeholders whose value is ``attr.NOTHING``
        are replaced with an empty string, the brackets/commas left behind
        are tidied up, and surrounding whitespace is stripped.

    Raises
    ------
    KeyError
        If a placeholder names a field that is neither in ``inputs``
        nor in ``value_updates``.
    """
    inputs_dict = attr.asdict(inputs)
    # if there is a value that has to be updated (e.g. single value from a list)
    if value_updates:
        inputs_dict.update(value_updates)
    # getting all fields that should be formatted, i.e. {field_name}, ...
    # (raw string: "\w" inside a plain literal is an invalid escape sequence)
    inp_fields = re.findall(r"{\w+}", argstr)
    val_dict = {}
    for fld in inp_fields:
        fld_name = fld[1:-1]  # extracting the name from {field_name}
        fld_value = inputs_dict[fld_name]
        # if value is NOTHING, nothing should be added to the command
        val_dict[fld_name] = "" if fld_value is attr.NOTHING else fld_value

    # formatting string based on the val_dict
    argstr_formatted = argstr.format(**val_dict)
    # removing extra commas and spaces left after dropping fields that were NOTHING
    argstr_formatted = (
        argstr_formatted.replace("[ ", "[")
        .replace(" ]", "]")
        .replace("[,", "[")
        .replace(",]", "]")
        .strip()
    )
    return argstr_formatted

pydra/engine/helpers_file.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -521,43 +521,93 @@ def template_update(inputs, map_copyfiles=None):
521521
f"fields with output_file_template"
522522
"has to be a string or Union[str, bool]"
523523
)
524-
inp_val_set = getattr(inputs, fld.name)
524+
dict_[fld.name] = template_update_single(field=fld, inputs_dict=dict_)
525+
# using is and == so it covers list and numpy arrays
526+
updated_templ_dict = {
527+
k: v
528+
for k, v in dict_.items()
529+
if not (getattr(inputs, k) is v or getattr(inputs, k) == v)
530+
}
531+
return updated_templ_dict
532+
533+
534+
def template_update_single(field, inputs_dict, spec_type="input"):
535+
"""Update a single template from the input_spec or output_spec
536+
based on the value from inputs_dict
537+
(checking the types of the fields that have "output_file_template").
538+
"""
539+
from .specs import File
540+
541+
if spec_type == "input":
542+
if field.type not in [str, ty.Union[str, bool]]:
543+
raise Exception(
544+
f"fields with output_file_template"
545+
"has to be a string or Union[str, bool]"
546+
)
547+
inp_val_set = inputs_dict[field.name]
525548
if inp_val_set is not attr.NOTHING and not isinstance(inp_val_set, (str, bool)):
526-
raise Exception(f"{fld.name} has to be str or bool, but {inp_val_set} set")
527-
if isinstance(inp_val_set, bool) and fld.type is str:
528549
raise Exception(
529-
f"type of {fld.name} is str, consider using Union[str, bool]"
550+
f"{field.name} has to be str or bool, but {inp_val_set} set"
530551
)
552+
if isinstance(inp_val_set, bool) and field.type is str:
553+
raise Exception(
554+
f"type of {field.name} is str, consider using Union[str, bool]"
555+
)
556+
elif spec_type == "output":
557+
if field.type is not File:
558+
raise Exception(
559+
f"output {field.name} should be a File, but {field.type} set as the type"
560+
)
561+
else:
562+
raise Exception(f"spec_type can be input or output, but {spec_type} provided")
563+
if spec_type == "input" and isinstance(inputs_dict[field.name], str):
564+
return inputs_dict[field.name]
565+
elif spec_type == "input" and inputs_dict[field.name] is False:
566+
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
567+
return attr.NOTHING
568+
else: # inputs_dict[field.name] is True or spec_type is output
569+
template = field.metadata["output_file_template"]
570+
# as default, we assume that keep_extension is True
571+
keep_extension = field.metadata.get("keep_extension", True)
572+
value = _template_formatting(
573+
template, inputs_dict, keep_extension=keep_extension
574+
)
575+
return value
531576

532-
if isinstance(inp_val_set, str):
533-
dict_[fld.name] = inp_val_set
534-
elif inp_val_set is False:
535-
# if False, the field should not be used, so setting attr.NOTHING
536-
dict_[fld.name] = attr.NOTHING
537-
else: # True or attr.NOTHING
538-
template = fld.metadata["output_file_template"]
539-
value = template.format(**dict_)
540-
value = removing_nothing(value)
541-
dict_[fld.name] = value
542-
return {k: v for k, v in dict_.items() if getattr(inputs, k) is not v}
543-
544-
545-
def removing_nothing(template_str):
546-
""" removing all fields that had NOTHING"""
547-
if "NOTHING" not in template_str:
548-
return template_str
549-
regex = re.compile(r"[^a-zA-Z_\-]")
550-
fields_str = regex.sub(" ", template_str)
551-
for fld in fields_str.split():
552-
if "NOTHING" in fld:
553-
template_str = template_str.replace(fld, "")
554-
return (
555-
template_str.replace("[ ", "[")
556-
.replace(" ]", "]")
557-
.replace(",]", "]")
558-
.replace("[,", "[")
559-
.strip()
560-
)
577+
578+
def _template_formatting(template, inputs_dict, keep_extension=True):
579+
"""Formatting a single template based on values from inputs_dict.
580+
Taking into account that field values and template could have file extensions
581+
(assuming that if template has extension, the field value extension is removed,
582+
if field has extension, and no template extension, than it is moved to the end),
583+
"""
584+
inp_fields = re.findall("{\w+}", template)
585+
if len(inp_fields) == 0:
586+
return template
587+
elif len(inp_fields) == 1:
588+
fld_name = inp_fields[0][1:-1]
589+
fld_value = inputs_dict[fld_name]
590+
if fld_value is attr.NOTHING:
591+
return attr.NOTHING
592+
fld_value = str(fld_value) # in case it's a path
593+
filename, *ext = fld_value.split(".", maxsplit=1)
594+
# if keep_extension is False, the extensions are removed
595+
if keep_extension is False:
596+
ext = []
597+
if template.endswith(inp_fields[0]):
598+
# if no suffix added in template, the simplest formatting should work
599+
# recreating fld_value with the updated extension
600+
fld_value_upd = ".".join([filename] + ext)
601+
formatted_value = template.format(**{fld_name: fld_value_upd})
602+
elif "." not in template: # the template doesn't have its own extension
603+
# if the fld_value has extension, it will be moved to the end
604+
formatted_value = ".".join([template.format(**{fld_name: filename})] + ext)
605+
else: # template has its own extension
606+
# removing fld_value extension if any
607+
formatted_value = template.format(**{fld_name: filename})
608+
return formatted_value
609+
else:
610+
raise NotImplementedError("should we allow for more args in the template?")
561611

562612

563613
def is_local_file(f):

pydra/engine/specs.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from pathlib import Path
44
import typing as ty
55

6+
from .helpers_file import template_update_single
7+
68

79
def attr_fields(x):
810
return x.__attrs_attrs__
@@ -33,7 +35,7 @@ class SpecInfo:
3335
class BaseSpec:
3436
"""The base dataclass specs for all inputs and outputs."""
3537

36-
def collect_additional_outputs(self, input_spec, inputs, output_dir):
38+
def collect_additional_outputs(self, inputs, output_dir):
3739
"""Get additional outputs."""
3840
return {}
3941

@@ -213,6 +215,7 @@ def check_metadata(self):
213215
"position",
214216
"requires",
215217
"separate_ext",
218+
"keep_extension",
216219
"xor",
217220
"sep",
218221
}
@@ -322,7 +325,7 @@ class ShellOutSpec(BaseSpec):
322325
stderr: ty.Union[File, str]
323326
"""The process' standard error."""
324327

325-
def collect_additional_outputs(self, input_spec, inputs, output_dir):
328+
def collect_additional_outputs(self, inputs, output_dir):
326329
"""Collect additional outputs from shelltask output_spec."""
327330
additional_out = {}
328331
for fld in attr_fields(self):
@@ -379,24 +382,13 @@ def _field_metadata(self, fld, inputs, output_dir):
379382
"""Collect output file if metadata specified."""
380383
if "value" in fld.metadata:
381384
return output_dir / fld.metadata["value"]
385+
# this block is only run if "output_file_template" is provided in output_spec
386+
# if the field is set in input_spec with output_file_template,
387+
# then the field should already have a value
382388
elif "output_file_template" in fld.metadata:
383-
sfx_tmpl = (output_dir / fld.metadata["output_file_template"]).suffixes
384-
if sfx_tmpl:
385-
# removing suffix from input field if template has it's own suffix
386-
inputs_templ = {
387-
k: v.split(".")[0]
388-
for k, v in inputs.__dict__.items()
389-
if isinstance(v, str)
390-
}
391-
else:
392-
inputs_templ = {
393-
k: v for k, v in inputs.__dict__.items() if isinstance(v, str)
394-
}
395-
out_path = output_dir / fld.metadata["output_file_template"].format(
396-
**inputs_templ
397-
)
398-
return out_path
399-
389+
inputs_templ = attr.asdict(inputs)
390+
value = template_update_single(fld, inputs_templ, spec_type="output")
391+
return output_dir / value
400392
elif "callable" in fld.metadata:
401393
return fld.metadata["callable"](fld.name, output_dir)
402394
else:

pydra/engine/task.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@
5757
SingularitySpec,
5858
attr_fields,
5959
)
60-
from .helpers import ensure_list, execute, position_adjustment
61-
from .helpers_file import template_update, is_local_file, removing_nothing
60+
from .helpers import ensure_list, execute, position_adjustment, argstr_formatting
61+
from .helpers_file import template_update, is_local_file
6262

6363

6464
class FunctionTask(TaskBase):
@@ -431,38 +431,40 @@ def _command_pos_args(self, field, state_ind, ind):
431431

432432
cmd_add = []
433433
if field.type is bool:
434+
# if value is simply True the original argstr is used,
435+
# if False, nothing is added to the command
434436
if value is True:
435437
cmd_add.append(argstr)
436438
else:
437439
sep = field.metadata.get("sep", " ")
438440
if argstr.endswith("...") and isinstance(value, list):
439441
argstr = argstr.replace("...", "")
442+
# if argstr has a more complex form, with "{input_field}"
440443
if "{" in argstr and "}" in argstr:
441444
argstr_formatted_l = []
442445
for val in value:
443-
argstr_f = argstr.format(**{field.name: val}).format(
444-
**attr.asdict(self.inputs)
446+
argstr_f = argstr_formatting(
447+
argstr, self.inputs, value_updates={field.name: val}
445448
)
446-
argstr_formatted_l.append(removing_nothing(argstr_f))
447-
449+
argstr_formatted_l.append(argstr_f)
448450
cmd_el_str = sep.join(argstr_formatted_l)
449-
else:
451+
else: # argstr has a simple form, e.g. "-f", or "--f"
450452
cmd_el_str = sep.join([f" {argstr} {val}" for val in value])
451453
else:
452454
# in case there are ... when input is not a list
453455
argstr = argstr.replace("...", "")
454456
if isinstance(value, list):
455457
cmd_el_str = sep.join([str(val) for val in value])
456458
value = cmd_el_str
457-
459+
# if argstr has a more complex form, with "{input_field}"
458460
if "{" in argstr and "}" in argstr:
459-
argstr_f = argstr.format(**attr.asdict(self.inputs))
460-
cmd_el_str = removing_nothing(argstr_f)
461-
else:
461+
cmd_el_str = argstr_formatting(argstr, self.inputs)
462+
else: # argstr has a simple form, e.g. "-f", or "--f"
462463
if value:
463464
cmd_el_str = f"{argstr} {value}"
464465
else:
465466
cmd_el_str = ""
467+
# removing double spacing
466468
cmd_el_str = cmd_el_str.strip().replace(" ", " ")
467469
if cmd_el_str:
468470
cmd_add += cmd_el_str.split(" ")

0 commit comments

Comments
 (0)