Skip to content

Commit 9781f7b

Browse files
authored
Merge pull request #150 from common-workflow-lab/select_first_function
Fixes: #71 Fixes: #104 Fixes: #103 Closes: #72 Support WDL.Type.Struct support select_first function support div support mul support ceil support size with input of array support size with array of files support time_minutes with size and ceil function
1 parent 43302f4 commit 9781f7b

20 files changed

+2320
-696
lines changed

wdl2cwl/main.py

Lines changed: 173 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
from ruamel.yaml import scalarstring
1313
from ruamel.yaml.main import YAML
1414

15-
from . import _logger
16-
from .errors import WDLSourceLine
15+
from wdl2cwl import _logger
16+
from wdl2cwl.errors import WDLSourceLine
1717

1818
valid_js_identifier = regex.compile(
1919
r"^(?!(?:do|if|in|for|let|new|try|var|case|else|enum|eval|null|this|true|"
@@ -69,6 +69,32 @@ def get_cwl_type(input_type: WDL.Type.Base) -> str:
6969
return type_of
7070

7171

72+
def get_mem_in_bytes(unit: str) -> str:
73+
"""Determine the value of a memory unit in bytes."""
74+
with WDLSourceLine(unit, ConversionException):
75+
if unit == "KiB" or unit == "Ki":
76+
mem_in_bytes = "1024^1"
77+
elif unit == "MiB" or unit == "Mi":
78+
mem_in_bytes = "1024^2"
79+
elif unit == "GiB" or unit == "Gi":
80+
mem_in_bytes = "1024^3"
81+
elif unit == "TiB" or unit == "Ti":
82+
mem_in_bytes = "1024^4"
83+
elif unit == "B":
84+
mem_in_bytes = "1024^0"
85+
elif unit == "KB" or unit == "K":
86+
mem_in_bytes = "1000^1"
87+
elif unit == "MB" or unit == "M":
88+
mem_in_bytes = "1000^2"
89+
elif unit == "GB" or unit == "G":
90+
mem_in_bytes = "1000^3"
91+
elif unit == "TB" or unit == "T":
92+
mem_in_bytes = "1000^4"
93+
else:
94+
raise ConversionException(f"Invalid memory unit: ${unit}")
95+
return mem_in_bytes
96+
97+
7298
def get_outdir_requirement(outdir: Union[WDL.Expr.Get, WDL.Expr.Apply]) -> int:
7399
"""Produce the memory requirement for the output directory from WDL runtime disks."""
74100
# This is yet to be implemented. After Feature Parity.
@@ -125,7 +151,7 @@ def get_literal_name(
125151
# if the literal expr is used inside WDL.Expr.Apply
126152
# the literal value is what's needed
127153
parent = expr.parent # type: ignore[union-attr]
128-
if isinstance(parent, WDL.Expr.Apply):
154+
if isinstance(parent, (WDL.Expr.Apply, WDL.Expr.IfThenElse)):
129155
return expr.literal.value # type: ignore
130156
raise WDLSourceLine(expr, ConversionException).makeError(
131157
f"The parent expression for {expr} is not WDL.Expr.Apply, but {parent}."
@@ -289,9 +315,6 @@ def load_wdl_task(self, obj: WDL.Tree.Task) -> cwl.CommandLineTool:
289315
_logger.warning("Skipping parameter_meta: %s", obj.parameter_meta)
290316
if obj.meta:
291317
_logger.warning("Skipping meta: %s", obj.meta)
292-
if len(obj.postinputs) > 0:
293-
for a in obj.postinputs:
294-
_logger.warning("Skipping variable: %s", a)
295318
return cwl.CommandLineTool(
296319
id=obj.name,
297320
inputs=cwl_inputs,
@@ -356,7 +379,7 @@ def get_time_minutes_requirement(
356379
self, time_minutes: WDL.Expr.Get
357380
) -> Union[str, int]:
358381
"""Produce the time limit expression from WDL runtime time minutes."""
359-
with WDLSourceLine(time_minutes, WDLSourceLine):
382+
with WDLSourceLine(time_minutes, ConversionException):
360383
if isinstance(time_minutes, (WDL.Expr.Int, WDL.Expr.Float)):
361384
literal = time_minutes.literal.value # type: ignore
362385
return literal * 60 # type: ignore
@@ -379,7 +402,10 @@ def get_memory_literal(self, memory_runtime: WDL.Expr.String) -> float:
379402
if memory_runtime.literal is None:
380403
_, placeholder, unit, _ = memory_runtime.parts
381404
with WDLSourceLine(placeholder, ConversionException):
382-
value_name = self.get_expr_get(placeholder.expr) # type: ignore
405+
if isinstance(placeholder.expr, WDL.Expr.Get): # type: ignore
406+
value_name = self.get_expr_get(placeholder.expr) # type: ignore
407+
else:
408+
value_name = self.get_expr_apply(placeholder.expr) # type: ignore
383409
return self.get_ram_min_js(value_name, unit.strip()) # type: ignore
384410

385411
ram_min = self.get_expr_string(memory_runtime)[1:-1]
@@ -388,27 +414,8 @@ def get_memory_literal(self, memory_runtime: WDL.Expr.String) -> float:
388414
raise ConversionException("Missing Memory units, yet still a string?")
389415
unit = unit_result.group()
390416
value = float(ram_min.split(unit)[0])
391-
392-
if unit == "KiB":
393-
memory = value / 1024
394-
elif unit == "MiB":
395-
memory = value
396-
elif unit == "GiB":
397-
memory = value * 1024
398-
elif unit == "TiB":
399-
memory = value * 1024 * 1024
400-
elif unit == "B":
401-
memory = value / (1024 * 1024)
402-
elif unit == "KB" or unit == "K":
403-
memory = (value * 1000) / (1024 * 1024)
404-
elif unit == "MB" or unit == "M":
405-
memory = (value * (1000 * 1000)) / (1024 * 1024)
406-
elif unit == "GB" or unit == "G":
407-
memory = (value * (1000 * 1000 * 1000)) / (1024 * 1024)
408-
elif unit == "TB" or unit == "T":
409-
memory = (value * (1000 * 1000 * 1000 * 1000)) / (1024 * 1024)
410-
else:
411-
raise ConversionException(f"Invalid memory unit: ${unit}")
417+
byte, power = get_mem_in_bytes(unit).split("^")
418+
memory: float = value * float(byte) ** float(power) / (1024 * 1024)
412419

413420
return memory
414421

@@ -586,9 +593,25 @@ def get_expr_apply(self, wdl_apply_expr: WDL.Expr.Apply) -> str:
586593
return f"{iterable_object_expr}[{index_expr}]"
587594
elif function_name == "_gt":
588595
left_operand, right_operand = arguments
589-
left_operand_expr = self.get_expr_apply(left_operand) # type: ignore
596+
if isinstance(left_operand, WDL.Expr.Get):
597+
left_operand_expr = self.get_expr(left_operand)
598+
else:
599+
left_operand_expr = self.get_expr_apply(left_operand) # type: ignore
590600
right_operand_expr = self.get_expr(right_operand)
591601
return f"{left_operand_expr} > {right_operand_expr}"
602+
elif function_name == "_lt":
603+
left_operand, right_operand = arguments
604+
if isinstance(left_operand, WDL.Expr.Get):
605+
left_operand_expr = self.get_expr(left_operand)
606+
else:
607+
left_operand_expr = self.get_expr_apply(left_operand) # type: ignore
608+
right_operand_expr = self.get_expr(right_operand)
609+
return f"{left_operand_expr} < {right_operand_expr}"
610+
elif function_name == "_lor":
611+
left_operand, right_operand = arguments
612+
left_operand_expr = self.get_expr_apply(left_operand) # type: ignore
613+
right_operand_expr = self.get_expr(right_operand)
614+
return f"{left_operand_expr} || {right_operand_expr}"
592615
elif function_name == "length":
593616
only_arg_expr = self.get_expr_get(arguments[0]) # type: ignore
594617
return f"{only_arg_expr}.length"
@@ -602,27 +625,77 @@ def get_expr_apply(self, wdl_apply_expr: WDL.Expr.Apply) -> str:
602625
elif function_name == "read_string":
603626
only_arg = arguments[0]
604627
return self.get_expr(only_arg)
628+
elif function_name == "read_float":
629+
only_arg = arguments[0]
630+
return self.get_expr(only_arg)
605631
elif function_name == "glob":
606632
only_arg = arguments[0]
607633
return self.get_expr(only_arg)
634+
elif function_name == "select_first":
635+
array_obj = arguments[0]
636+
array_items = [self.get_expr(item) for item in array_obj.items] # type: ignore
637+
items_str = ", ".join(array_items)
638+
return f"[{items_str}].find(element => element !== null) "
639+
elif function_name == "_mul":
640+
left_operand, right_operand = arguments
641+
left_str = self.get_expr(left_operand)
642+
right_str = self.get_expr(right_operand)
643+
return f"{left_str}*{right_str}"
644+
elif function_name == "_eqeq":
645+
left_operand, right_operand = arguments
646+
left_str = self.get_expr(left_operand)
647+
right_str = self.get_expr(right_operand)
648+
return f"{left_str} === {right_str}"
649+
elif function_name == "ceil":
650+
only_arg = self.get_expr(arguments[0]) # type: ignore
651+
return f"Math.ceil({only_arg}) "
652+
elif function_name == "_div":
653+
left_operand, right_operand = arguments
654+
left_str = self.get_expr(left_operand)
655+
right_str = self.get_expr(right_operand)
656+
return f"{left_str}/{right_str}"
657+
elif function_name == "_sub":
658+
left_operand, right_operand = arguments
659+
left_str = self.get_expr(left_operand)
660+
right_str = self.get_expr(right_operand)
661+
return f"{left_str}-{right_str}"
662+
elif function_name == "size":
663+
left_operand, right_operand = arguments
664+
if isinstance(left_operand, WDL.Expr.Array):
665+
array_items = [self.get_expr(item) for item in left_operand.items]
666+
left = ", ".join(array_items)
667+
left_str = f"[{left}]"
668+
else:
669+
left_str = self.get_expr(left_operand)
670+
size_unit = self.get_expr(right_operand)[1:-1]
671+
unit_value = get_mem_in_bytes(size_unit)
672+
return (
673+
"(function(size_of=0)"
674+
+ "{"
675+
+ f"{left_str}.forEach(function(element)"
676+
+ "{ if (element) {"
677+
+ "size_of += element.size"
678+
+ "}})}"
679+
+ f") / {unit_value}"
680+
)
608681

609682
raise WDLSourceLine(wdl_apply_expr, ConversionException).makeError(
610683
f"Function name '{function_name}' not yet handled."
611684
)
612685

613686
def get_expr_get(self, wdl_get_expr: WDL.Expr.Get) -> str:
614687
"""Translate WDL Get Expressions."""
615-
with WDLSourceLine(wdl_get_expr, ConversionException):
616-
member = wdl_get_expr.member
617-
if (
618-
not member
619-
and isinstance(wdl_get_expr.expr, WDL.Expr.Ident)
620-
and wdl_get_expr.expr
621-
):
622-
return self.get_expr_ident(wdl_get_expr.expr)
623-
raise ConversionException(
624-
f"Get expressions with {member} are not yet handled."
625-
)
688+
member = wdl_get_expr.member
689+
690+
if not member:
691+
return self.get_expr_ident(wdl_get_expr.expr) # type: ignore
692+
struct_name = self.get_expr(wdl_get_expr.expr)
693+
member_str = f"{struct_name}.{member}"
694+
return (
695+
member_str
696+
if not isinstance(wdl_get_expr.type, WDL.Type.File)
697+
else f"{member_str}.path"
698+
)
626699

627700
def get_expr_ident(self, wdl_ident_expr: WDL.Expr.Ident) -> str:
628701
"""Translate WDL Ident Expressions."""
@@ -766,22 +839,37 @@ def get_cwl_task_inputs(
766839
input_name = wdl_input.name
767840
self.non_static_values.add(input_name)
768841
input_value = None
769-
type_of: Union[str, cwl.CommandInputArraySchema]
842+
type_of: Union[
843+
str, cwl.CommandInputArraySchema, cwl.CommandInputRecordSchema
844+
]
770845

771846
if hasattr(wdl_input, "value"):
772847
wdl_input = wdl_input.value # type: ignore
773848

774849
if isinstance(wdl_input.type, WDL.Type.Array):
775850
input_type = get_cwl_type(wdl_input.type.item_type)
776851
type_of = cwl.CommandInputArraySchema(items=input_type, type="array")
852+
elif isinstance(wdl_input.type, WDL.Type.StructInstance):
853+
type_of = cwl.CommandInputRecordSchema(
854+
type="record",
855+
name=wdl_input.type.type_name,
856+
fields=self.get_struct_inputs(wdl_input.type.members),
857+
)
777858
else:
778859
type_of = get_cwl_type(wdl_input.type)
779860

780861
if wdl_input.type.optional or isinstance(wdl_input.expr, WDL.Expr.Apply):
781862
final_type_of: Union[
782-
List[Union[str, cwl.CommandInputArraySchema]],
863+
List[
864+
Union[
865+
str,
866+
cwl.CommandInputArraySchema,
867+
cwl.CommandInputRecordSchema,
868+
]
869+
],
783870
str,
784871
cwl.CommandInputArraySchema,
872+
cwl.CommandInputRecordSchema,
785873
] = [type_of, "null"]
786874
if isinstance(wdl_input.expr, WDL.Expr.Apply):
787875
self.optional_cwl_null.add(input_name)
@@ -805,6 +893,24 @@ def get_cwl_task_inputs(
805893

806894
return inputs
807895

896+
def get_struct_inputs(
897+
self, members: Optional[Dict[str, WDL.Type.Base]]
898+
) -> List[cwl.CommandInputRecordField]:
899+
"""Get member items of a WDL struct and return a list of cwl.CommandInputRecordField."""
900+
inputs: List[cwl.CommandInputRecordField] = []
901+
if not members:
902+
return inputs
903+
for member, value in members.items():
904+
input_name = member
905+
if isinstance(value, WDL.Type.Array):
906+
array_items_type = value.item_type
907+
input_type = get_cwl_type(array_items_type)
908+
type_of = cwl.CommandInputArraySchema(items=input_type, type="array")
909+
else:
910+
type_of = get_cwl_type(value) # type: ignore
911+
inputs.append(cwl.CommandInputRecordField(name=input_name, type=type_of))
912+
return inputs
913+
808914
def get_cwl_task_outputs(
809915
self, wdl_outputs: List[WDL.Tree.Decl]
810916
) -> List[cwl.CommandOutputParameter]:
@@ -856,6 +962,30 @@ def get_cwl_task_outputs(
856962
),
857963
)
858964
)
965+
elif (
966+
isinstance(wdl_output.expr, WDL.Expr.Apply)
967+
and wdl_output.expr.function_name == "read_float"
968+
):
969+
glob_expr = self.get_expr(wdl_output)
970+
is_literal = wdl_output.expr.arguments[0].literal
971+
if is_literal:
972+
glob_str = glob_expr[
973+
1:-1
974+
] # remove quotes from the string returned by get_expr_string
975+
else:
976+
glob_str = f"$({glob_expr})"
977+
978+
outputs.append(
979+
cwl.CommandOutputParameter(
980+
id=output_name,
981+
type=type_of,
982+
outputBinding=cwl.CommandOutputBinding(
983+
glob=glob_str,
984+
loadContents=True,
985+
outputEval=r"$(parseFloat(self[0].contents))",
986+
),
987+
)
988+
)
859989
elif (
860990
isinstance(wdl_output.expr, WDL.Expr.Apply)
861991
and wdl_output.expr.function_name == "stdout"

wdl2cwl/tests/cwl_files/CollectQualityYieldMetrics.cwl

Lines changed: 0 additions & 38 deletions
This file was deleted.

0 commit comments

Comments
 (0)