1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
21 from datetime import datetime
23 from aria import logger
24 from aria.modeling import models
25 from aria.orchestrator import events
26 from aria.orchestrator.context import operation
28 from .. import exceptions
29 from ..executor.base import StubTaskExecutor
30 # Import required so all signals are registered
31 from . import events_handler # pylint: disable=unused-import
# Workflow execution engine: dispatches each ready task to its executor,
# follows task progress through a _TasksTracker, and sends the workflow
# lifecycle signals from aria.orchestrator.events (resume, start, success,
# cancelled, failure, cancelling).
# NOTE(review): this copy of the module has lost its indentation and several
# lines (docstring quotes, `if`/`try`/`while`/`else` headers, decorators, some
# call arguments). Comments below describe only what the visible lines show;
# anything about the missing lines is marked as an assumption.
34 class Engine(logger.LoggerMixin):
# Constructor.
#   executors -- mapping from an executor key (the value a task stores in
#                `task._executor`) to an executor instance. It is copied, so
#                the caller's mapping is never mutated, and a StubTaskExecutor
#                instance is registered under the StubTaskExecutor class
#                unless the caller already supplied one for that key.
#   kwargs    -- passed through to the LoggerMixin base initializer.
39 def __init__(self, executors, **kwargs):
40 super(Engine, self).__init__(**kwargs)
41 self._executors = executors.copy()
42 self._executors.setdefault(StubTaskExecutor, StubTaskExecutor())
# Run the workflow held by `ctx`.
#   ctx          -- workflow context; this method and its helpers read
#                   ctx.execution, ctx.model, ctx.resource and ctx._workdir.
#   resuming     -- not referenced by any visible line. NOTE(review): upstream
#                   presumably sends on_resume_workflow_signal only when this
#                   is true (the guarding `if` looks lost) -- confirm.
#   retry_failed -- forwarded to on_resume_workflow_signal.
# Visible steps:
#   1. Send start_workflow_signal.
#   2. Check for cancellation.
#   3. For each ended task, call _handle_ended_tasks (which raises if the task
#      failed without ignore_failure), then mark it finished.
#   4. Mark each executable task as executing and dispatch it.
#   5. Stop once the tracker reports every task consumed.
# The cancelled and success signals are presumably chosen by `cancel`; that
# if/else is not visible. On any BaseException, still-executing tasks are
# terminated and on_failure_workflow_signal is sent with the exception.
# Whether the exception is then re-raised is not visible here.
44 def execute(self, ctx, resuming=False, retry_failed=False):
46 Executes the workflow.
49 events.on_resume_workflow_signal.send(ctx, retry_failed=retry_failed)
51 tasks_tracker = _TasksTracker(ctx)
54 events.start_workflow_signal.send(ctx)
56 cancel = self._is_cancel(ctx)  # presumably true once the stored status is CANCELLING/CANCELLED
59 for task in tasks_tracker.ended_tasks:
60 self._handle_ended_tasks(task)  # raises if the task failed and is not ignore_failure
61 tasks_tracker.finished(task)
62 for task in tasks_tracker.executable_tasks:
63 tasks_tracker.executing(task)  # record as executing before dispatching it
64 self._handle_executable_task(ctx, task)
65 if tasks_tracker.all_tasks_consumed:  # every task executed and none still executing
# NOTE(review): the next two lines look like the cancellation branch and the
# line after them like the success branch; their `if cancel:` / `else:`
# headers are not present in this copy.
70 self._terminate_tasks(tasks_tracker.executing_tasks)
71 events.on_cancelled_workflow_signal.send(ctx)
73 events.on_success_workflow_signal.send(ctx)
# BaseException (not Exception) so that KeyboardInterrupt/SystemExit also
# terminate running tasks and mark the workflow as failed.
# NOTE(review): the matching `try:` is not present in this copy.
74 except BaseException as e:
75 # Clean up (terminate) any tasks that are still executing
76 self._terminate_tasks(tasks_tracker.executing_tasks)
77 events.on_failure_workflow_signal.send(ctx, exception=e)
# Ask the executor responsible for each task to terminate it, by task id.
# NOTE(review): `task` is not bound by any visible line. A `for task in tasks:`
# loop, and possibly error handling around terminate(), seem to have been lost
# in this copy; confirm against upstream.
80 def _terminate_tasks(self, tasks):
83 self._executors[task._executor].terminate(task.id)
# Request cancellation of the execution in `ctx` by sending
# on_cancelling_workflow_signal. The status changes described in the docstring
# text below are not made here; presumably they are made by that signal's
# registered handlers (events_handler is imported for its registrations).
# NOTE(review): takes no `self`; a @staticmethod decorator appears to be
# missing in this copy. Without it, engine.cancel_execution(ctx) would raise
# TypeError (two positional arguments for one parameter).
88 def cancel_execution(ctx):
90 Send a cancel request to the engine. If execution already started, execution status
91 will be modified to ``cancelling`` status. If execution is in pending mode, execution status
92 will be modified to ``cancelled`` directly.
94 events.on_cancelling_workflow_signal.send(ctx)
# NOTE(review): the `def` line for the next two statements is missing in this
# copy. They match the `self._is_cancel(ctx)` call in execute(): re-read the
# execution via ctx.model.execution.refresh(...) and report whether its status
# is CANCELLING or CANCELLED. As laid out here they would instead read as part
# of cancel_execution; confirm against upstream.
98 execution = ctx.model.execution.refresh(ctx.execution)
99 return execution.status in (models.Execution.CANCELLING, models.Execution.CANCELLED)
# Build an operation context for `task` and hand it to the task's executor.
# - Stub tasks (truthy `_stub_type`) get a plain BaseOperationContext and no
#   sent_task_signal.
# - Other tasks are built with their own `_context_cls`, and sent_task_signal
#   is sent with the new context before execute() is called.
# NOTE(review): with indentation lost, it cannot be confirmed from here that
# task_executor.execute() sits outside the `if`; presumably it runs for every
# task.
101 def _handle_executable_task(self, ctx, task):
102 task_executor = self._executors[task._executor]
104 # If the task is a stub, a default context is provided, else it should hold the context cls
105 context_cls = operation.BaseOperationContext if task._stub_type else task._context_cls
106 op_ctx = context_cls(
107 model_storage=ctx.model,
108 resource_storage=ctx.resource,
109 workdir=ctx._workdir,
# NOTE(review): an argument line appears to be missing here in this copy.
111 actor_id=task.actor.id if task.actor else None,  # tasks without an actor pass None
112 service_id=task.execution.service.id,
113 execution_id=task.execution.id,
# NOTE(review): the remaining arguments and the closing `)` of the
# context_cls(...) call are missing in this copy.
117 if not task._stub_type:
118 events.sent_task_signal.send(op_ctx)
119 task_executor.execute(op_ctx)
# Fail the workflow on an ended task. Raises ExecutorException('Workflow
# failed') if `task` has status models.Task.FAILED and its ignore_failure flag
# is falsy; otherwise does nothing.
# NOTE(review): defined without `self` but called as
# self._handle_ended_tasks(task) in execute(). The @staticmethod decorator it
# needs appears to be missing in this copy.
122 def _handle_ended_tasks(task):
123 if task.status == models.Task.FAILED and not task.ignore_failure:
124 raise exceptions.ExecutorException('Workflow failed')
# Book-keeping for the tasks of one execution (ctx.execution.tasks), split
# into three lists:
#   - executed: tasks that have ended;
#   - executable: tasks not yet dispatched;
#   - executing: dispatched tasks, including tasks being retried.
# The generator accessors yield tasks re-read through
# ctx.model.task.refresh(...) via _update_tasks.
# NOTE(review): this copy has lost its indentation, its @property decorators
# and several body lines; comments below describe the visible lines only.
127 class _TasksTracker(object):
# Partition ctx.execution.tasks: tasks whose has_ended() is true start in the
# executed list, the rest in the executable list, and nothing is executing yet.
# list(set(...)) drops the original task order (and requires hashable tasks),
# so executable tasks are held in arbitrary order.
# NOTE(review): self._ctx is read by _update_tasks but is not assigned by any
# visible line; its assignment (presumably `self._ctx = ctx`) seems lost.
129 def __init__(self, ctx):
132 self._tasks = ctx.execution.tasks
133 self._executed_tasks = [task for task in self._tasks if task.has_ended()]
134 self._executable_tasks = list(set(self._tasks) - set(self._executed_tasks))
135 self._executing_tasks = []
# True when every task is in the executed list and none is executing.
# NOTE(review): execute() reads this as an attribute, not a call, so it needs
# @property (not present in this copy). A bare bound method is always truthy.
138 def all_tasks_consumed(self):
139 return len(self._executed_tasks) == len(self._tasks) and len(self._executing_tasks) == 0
# Move `task` from the executable list to the executing list. A task being
# retried is already in the executing list and is left where it is.
# list.remove raises ValueError if a task that is not executing is also not
# in the executable list.
# NOTE(review): presumably both list updates sit inside the `if` (indentation
# lost); otherwise a retried task would be appended twice.
141 def executing(self, task):
142 # Task executing could be retrying (thus removed and added earlier)
143 if task not in self._executing_tasks:
144 self._executable_tasks.remove(task)
145 self._executing_tasks.append(task)
# Move `task` from the executing list to the executed list. list.remove
# raises ValueError if the task is not currently executing.
147 def finished(self, task):
148 self._executing_tasks.remove(task)
149 self._executed_tasks.append(task)
# Yield executing tasks (refreshed), presumably only those that have ended.
# NOTE(review): only the loop header survives; its body is missing here.
152 def ended_tasks(self):
153 for task in self.executing_tasks:
# Yield refreshed tasks, drawn from both the executing and executable lists,
# that are waiting and whose dependencies are all in the executed list.
# The dependency check is a linear scan of the executed list per dependency.
# NOTE(review): `now` is not used by any visible line; a missing condition
# presumably compares a due time against it. The list literal and the `if`
# are not closed, and the `yield` is missing, in this copy.
# datetime.utcnow() returns a naive UTC datetime and is deprecated on
# Python 3.12+.
158 def executable_tasks(self):
159 now = datetime.utcnow()
160 # We need both lists since retrying tasks are in the executing task list.
161 for task in self._update_tasks(set(self._executing_tasks + self._executable_tasks)):
162 if all([task.is_waiting(),
164 all(dependency in self._executed_tasks for dependency in task.dependencies)
# Yield each executing task re-read through _update_tasks.
# NOTE(review): the loop body (presumably `yield task`) is missing in this copy.
169 def executing_tasks(self):
170 for task in self._update_tasks(self._executing_tasks):
# Yield each executed task re-read through _update_tasks.
# NOTE(review): the loop body is missing in this copy.
174 def executed_tasks(self):
175 for task in self._update_tasks(self._executed_tasks):
# NOTE(review): the `def` line and loop body for this statement are missing in
# this copy. It walks every task of the execution through _update_tasks.
180 for task in self._update_tasks(self._tasks):
# Yield each task in `tasks` after re-reading it via
# self._ctx.model.task.refresh(...), presumably so callers see its current
# stored state.
# NOTE(review): the `for task in tasks:` header that binds `task` is missing
# in this copy.
183 def _update_tasks(self, tasks):
185 yield self._ctx.model.task.refresh(task)