Skip to content

Commit e68e2bb

Browse files
committed
use task_factory instead of decorator
1 parent 4111917 commit e68e2bb

File tree

2 files changed

+76
-83
lines changed

2 files changed

+76
-83
lines changed

backend/apps/ifc_validation/tasks/configs.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ class TaskConfig:
1313
blocks: Optional[List[str]]
1414
execution_stage: str = "parallel"
1515
process_results: Callable | None = None
16+
17+
@property
18+
def celery_task_name(self) -> str:
19+
return f"apps.ifc_validation.tasks.{self.type.name.lower()}_subtask"
1620

1721
# create blueprint
1822
def make_task(*, type, increment, field=None, stage="parallel"):

backend/apps/ifc_validation/tasks/task_runner.py

Lines changed: 72 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -91,71 +91,71 @@ def on_workflow_failed(self, *args, **kwargs):
9191
send_failure_admin_email_task.delay(id=id, file_name=request.file_name)
9292

9393

94-
def validation_task_runner(task_type):
95-
def decorator(func):
96-
@shared_task(bind=True)
97-
@log_execution
98-
@requires_django_user_context
99-
@functools.wraps(func)
100-
def wrapper(self, *args, **kwargs):
101-
id = kwargs.get('id')
102-
103-
request = ValidationRequest.objects.get(pk=id)
104-
file_path = get_absolute_file_path(request.file.name)
105-
106-
# Always create the task record, even if it will be skipped due to blocking conditions,
107-
# so it is logged and its status can be marked as 'skipped'
108-
task = ValidationTask.objects.create(request=request, type=task_type)
94+
def task_factory(task_type):
95+
config = task_registry[task_type]
96+
97+
@shared_task(bind=True, name=config.celery_task_name)
98+
@log_execution
99+
@requires_django_user_context
100+
def validation_subtask_runner(self, *args, **kwargs):
101+
102+
id = kwargs.get('id')
103+
104+
request = ValidationRequest.objects.get(pk=id)
105+
file_path = get_absolute_file_path(request.file.name)
106+
107+
# Always create the task record, even if it will be skipped due to blocking conditions,
108+
# so it is logged and its status can be marked as 'skipped'
109+
task = ValidationTask.objects.create(request=request, type=task_type)
110+
111+
if model := request.model:
112+
invalid_blockers = list(filter(
113+
lambda b: getattr(model, task_registry[b].status_field.name) == Model.Status.INVALID,
114+
task_registry.get_blockers_of(task_type)
115+
))
116+
else: # for testing, we're not instantiating a model
117+
invalid_blockers = []
118+
119+
# update progress
120+
increment = config.increment
121+
request.progress = min(request.progress + increment, 100)
122+
request.save()
123+
124+
# run or skip
125+
if not invalid_blockers:
126+
task.mark_as_initiated()
109127

110-
if model := request.model:
111-
invalid_blockers = list(filter(
112-
lambda b: getattr(model, task_registry[b].status_field.name) == Model.Status.INVALID,
113-
task_registry.get_blockers_of(task_type)
128+
# Execution Layer
129+
try:
130+
context = config.check_program(TaskContext(
131+
config=config,
132+
task=task,
133+
request=request,
134+
file_path=file_path,
114135
))
115-
else: # for testing, we're not instantiating a model
116-
invalid_blockers = []
136+
except Exception as err:
137+
task.mark_as_failed(str(err))
138+
logger.exception(f"Execution failed in task {task_type}: {task}")
139+
return
140+
141+
# Processing Layer / write to DB
142+
try:
143+
reason = config.process_results(context)
144+
task.mark_as_completed(reason)
145+
logger.debug(f"Task {task_type} completed, reason: {reason}")
146+
except Exception as err:
147+
task.mark_as_failed(str(err))
148+
logger.exception(f"Processing failed in task {task_type}: {err}")
149+
return
117150

118-
# get task configuration
119-
config = task_registry[task_type]
151+
# Handle skipped tasks
152+
else:
153+
reason = f"Skipped due to fail in blocking tasks: {', '.join(invalid_blockers)}"
154+
logger.debug(reason)
155+
task.mark_as_skipped(reason)
120156

121-
# update progress
122-
increment = config.increment
123-
request.progress = min(request.progress + increment, 100)
124-
request.save()
125-
126-
if not invalid_blockers:
127-
task.mark_as_initiated()
128-
129-
# Execution Layer
130-
try:
131-
context = config.check_program(TaskContext(
132-
config=config,
133-
task=task,
134-
request=request,
135-
file_path=file_path,
136-
))
137-
except Exception as err:
138-
task.mark_as_failed(str(err))
139-
logger.exception(f"Execution failed in task {task_type}: {task}")
140-
return
141-
142-
# Processing Layer / write to DB
143-
try:
144-
reason = config.process_results(context)
145-
task.mark_as_completed(reason)
146-
logger.debug(f"Task {task_type} completed, reason: {reason}")
147-
except Exception as err:
148-
task.mark_as_failed(str(err))
149-
logger.exception(f"Processing failed in task {task_type}: {err}")
150-
return
151-
152-
# Handle skipped tasks
153-
else:
154-
reason = f"Skipped due to fail in blocking tasks: {', '.join(invalid_blockers)}"
155-
logger.debug(reason)
156-
task.mark_as_skipped(reason)
157-
return wrapper
158-
return decorator
157+
validation_subtask_runner.__doc__ = f"Validation task for {task_type} generated by the task_factory func."
158+
return validation_subtask_runner
159159

160160

161161
@shared_task(bind=True)
@@ -202,35 +202,24 @@ def ifc_file_validation_task(self, id, file_name, *args, **kwargs):
202202
workflow.apply_async()
203203

204204

205-
@validation_task_runner(ValidationTask.Type.INSTANCE_COMPLETION)
206-
def instance_completion_subtask(): pass
205+
instance_completion_subtask = task_factory(ValidationTask.Type.INSTANCE_COMPLETION)
207206

208-
@validation_task_runner(ValidationTask.Type.NORMATIVE_IA)
209-
def normative_rules_ia_validation_subtask(): pass
207+
normative_rules_ia_validation_subtask = task_factory(ValidationTask.Type.NORMATIVE_IA)
210208

211-
@validation_task_runner(ValidationTask.Type.NORMATIVE_IP)
212-
def normative_rules_ip_validation_subtask(): pass
209+
normative_rules_ip_validation_subtask = task_factory(ValidationTask.Type.NORMATIVE_IP)
213210

214-
@validation_task_runner(ValidationTask.Type.PREREQUISITES)
215-
def prerequisites_subtask(): pass
211+
prerequisites_subtask = task_factory(ValidationTask.Type.PREREQUISITES)
216212

217-
@validation_task_runner(ValidationTask.Type.SYNTAX)
218-
def syntax_validation_subtask(): pass
213+
syntax_validation_subtask = task_factory(ValidationTask.Type.SYNTAX)
219214

220-
@validation_task_runner(ValidationTask.Type.HEADER_SYNTAX)
221-
def header_syntax_validation_subtask(): pass
215+
header_syntax_validation_subtask = task_factory(ValidationTask.Type.HEADER_SYNTAX)
222216

223-
@validation_task_runner(ValidationTask.Type.SCHEMA)
224-
def schema_validation_subtask(): pass
217+
schema_validation_subtask = task_factory(ValidationTask.Type.SCHEMA)
225218

226-
@validation_task_runner(ValidationTask.Type.HEADER)
227-
def header_validation_subtask(): pass
219+
header_validation_subtask = task_factory(ValidationTask.Type.HEADER)
228220

229-
@validation_task_runner(ValidationTask.Type.DIGITAL_SIGNATURES)
230-
def digital_signatures_subtask(): pass
221+
digital_signatures_subtask = task_factory(ValidationTask.Type.DIGITAL_SIGNATURES)
231222

232-
@validation_task_runner(ValidationTask.Type.BSDD)
233-
def bsdd_validation_subtask(): pass
223+
bsdd_validation_subtask = task_factory(ValidationTask.Type.BSDD)
234224

235-
@validation_task_runner(ValidationTask.Type.INDUSTRY_PRACTICES)
236-
def industry_practices_subtask(): pass
225+
industry_practices_subtask = task_factory(ValidationTask.Type.INDUSTRY_PRACTICES)

0 commit comments

Comments
 (0)