1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
20 from collections import Iterable
22 from networkx import DiGraph, topological_sort
24 from ....utils.uuid import generate_uuid
25 from . import task as api_task
class TaskNotInGraphError(Exception):
    """
    An error representing a scenario where a given task is not in the graph as expected.
    """
35 def _filter_out_empty_tasks(func=None):
37 return lambda f: _filter_out_empty_tasks(func=f)
39 def _wrapper(task, *tasks, **kwargs):
40 return func(*(t for t in (task,) + tuple(tasks) if t), **kwargs)
class TaskGraph(object):
    """
    Directed graph of tasks and the dependencies between them.

    Every task is stored as a node keyed by the task's ID, with the task instance kept in the
    node's ``task`` attribute. An edge from task A to task B records that A depends on B.
    """

    def __init__(self, name):
        """
        :param name: name of the graph
        """
        self.name = name
        self._id = generate_uuid(variant='uuid')
        self._graph = DiGraph()

    def __repr__(self):
        return '{name}(id={self._id}, name={self.name}, graph={self._graph!r})'.format(
            name=self.__class__.__name__, self=self)

    @property
    def id(self):
        """
        ID of the graph.
        """
        return self._id

    # graph traversal methods

    @property
    def tasks(self):
        """
        Iterator over tasks in the graph.
        """
        for _, data in self._graph.nodes_iter(data=True):
            yield data['task']

    def topological_order(self, reverse=False):
        """
        Topological sort of the graph.

        :param reverse: whether to reverse the sort
        :return: iterator over the tasks, in the order produced by the topological sort
        """
        for task_id in topological_sort(self._graph, reverse=reverse):
            yield self.get_task(task_id)

    def get_dependencies(self, dependent_task):
        """
        Iterates over the task's dependencies.

        This is a generator, so the membership check below only runs (and may only raise) once
        iteration starts.

        :param dependent_task: task whose dependencies are requested
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if
         ``dependent_task`` is not in the graph
        """
        if not self.has_tasks(dependent_task):
            raise TaskNotInGraphError('Task id: {0}'.format(dependent_task.id))
        for _, dependency_id in self._graph.out_edges_iter(dependent_task.id):
            yield self.get_task(dependency_id)

    def get_dependents(self, dependency_task):
        """
        Iterates over the task's dependents.

        This is a generator, so the membership check below only runs (and may only raise) once
        iteration starts.

        :param dependency_task: task whose dependents are requested
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if
         ``dependency_task`` is not in the graph
        """
        if not self.has_tasks(dependency_task):
            raise TaskNotInGraphError('Task id: {0}'.format(dependency_task.id))
        for dependent_id, _ in self._graph.in_edges_iter(dependency_task.id):
            yield self.get_task(dependent_id)

    # task methods

    def get_task(self, task_id):
        """
        Get a task instance that's been inserted to the graph by the task's ID.

        :param basestring task_id: task ID
        :return: task stored in the graph under ``task_id``
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if no task found in
         the graph with the given ID
        """
        if not self._graph.has_node(task_id):
            raise TaskNotInGraphError('Task id: {0}'.format(task_id))
        data = self._graph.node[task_id]
        return data['task']

    @_filter_out_empty_tasks
    def add_tasks(self, *tasks):
        """
        Adds tasks to the graph.

        Iterables of tasks are added recursively; tasks which are already in the graph are
        skipped.

        :param tasks: tasks, or iterables of tasks, to add
        :return: list of added tasks
        """
        assert all(isinstance(task, (api_task.BaseTask, Iterable)) for task in tasks)
        return_tasks = []

        for task in tasks:
            if isinstance(task, Iterable):
                return_tasks += self.add_tasks(*task)
            elif not self.has_tasks(task):
                self._graph.add_node(task.id, task=task)
                return_tasks.append(task)

        return return_tasks

    @_filter_out_empty_tasks
    def remove_tasks(self, *tasks):
        """
        Removes the provided tasks from the graph.

        Iterables of tasks are removed recursively; tasks which are not in the graph are
        skipped.

        :param tasks: tasks, or iterables of tasks, to remove
        :return: list of removed tasks
        """
        return_tasks = []

        for task in tasks:
            if isinstance(task, Iterable):
                return_tasks += self.remove_tasks(*task)
            elif self.has_tasks(task):
                self._graph.remove_node(task.id)
                return_tasks.append(task)

        return return_tasks

    @_filter_out_empty_tasks
    def has_tasks(self, *tasks):
        """
        Checks whether tasks are in the graph.

        :param tasks: tasks, or iterables of tasks, to look for
        :return: ``True`` if all tasks are in the graph, otherwise ``False``
        """
        assert all(isinstance(t, (api_task.BaseTask, Iterable)) for t in tasks)
        return_value = True

        for task in tasks:
            if isinstance(task, Iterable):
                return_value &= self.has_tasks(*task)
            else:
                return_value &= self._graph.has_node(task.id)

        return return_value

    def add_dependency(self, dependent, dependency):
        """
        Adds a dependency for one item (task, sequence or parallel) on another.

        The dependent will only be executed after the dependency terminates. If either of the items
        is either a sequence or a parallel, multiple dependencies may be added.

        :param dependent: dependent (task, sequence or parallel)
        :param dependency: dependency (task, sequence or parallel)
        :return: ``True`` if the dependency between the two hadn't already existed, otherwise
         ``False``
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the
         dependent or dependency are tasks which are not in the graph
        """
        if not (self.has_tasks(dependent) and self.has_tasks(dependency)):
            raise TaskNotInGraphError()

        if self.has_dependency(dependent, dependency):
            return False

        # Expand iterables one side at a time until both sides are single tasks.
        if isinstance(dependent, Iterable):
            for dependent_task in dependent:
                self.add_dependency(dependent_task, dependency)
        elif isinstance(dependency, Iterable):
            for dependency_task in dependency:
                self.add_dependency(dependent, dependency_task)
        else:
            self._graph.add_edge(dependent.id, dependency.id)

        return True

    def has_dependency(self, dependent, dependency):
        """
        Checks whether one item (task, sequence or parallel) depends on another.

        Note that if either of the items is either a sequence or a parallel, and some of the
        dependencies exist in the graph but not all of them, this method will return ``False``.

        :param dependent: dependent (task, sequence or parallel)
        :param dependency: dependency (task, sequence or parallel)
        :return: ``True`` if the dependency between the two exists, otherwise ``False``
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the
         dependent or dependency are tasks which are not in the graph
        """
        if not (dependent and dependency):
            return False
        elif not (self.has_tasks(dependent) and self.has_tasks(dependency)):
            raise TaskNotInGraphError()

        return_value = True

        if isinstance(dependent, Iterable):
            for dependent_task in dependent:
                return_value &= self.has_dependency(dependent_task, dependency)
        elif isinstance(dependency, Iterable):
            for dependency_task in dependency:
                return_value &= self.has_dependency(dependent, dependency_task)
        else:
            return_value &= self._graph.has_edge(dependent.id, dependency.id)

        return return_value

    def remove_dependency(self, dependent, dependency):
        """
        Removes a dependency for one item (task, sequence or parallel) on another.

        Note that if either of the items is either a sequence or a parallel, and some of the
        dependencies exist in the graph but not all of them, this method will not remove any of the
        dependencies and return ``False``.

        :param dependent: dependent (task, sequence or parallel)
        :param dependency: dependency (task, sequence or parallel)
        :return: ``False`` if the dependency between the two hadn't existed, otherwise ``True``
        :raises ~aria.orchestrator.workflows.api.task_graph.TaskNotInGraphError: if either the
         dependent or dependency are tasks which are not in the graph
        """
        if not (self.has_tasks(dependent) and self.has_tasks(dependency)):
            raise TaskNotInGraphError()

        if not self.has_dependency(dependent, dependency):
            return False

        if isinstance(dependent, Iterable):
            for dependent_task in dependent:
                self.remove_dependency(dependent_task, dependency)
        elif isinstance(dependency, Iterable):
            for dependency_task in dependency:
                self.remove_dependency(dependent, dependency_task)
        else:
            self._graph.remove_edge(dependent.id, dependency.id)

        return True

    @_filter_out_empty_tasks
    def sequence(self, *tasks):
        """
        Creates and inserts a sequence into the graph, effectively each task i depends on i-1.

        :param tasks: iterable of dependencies
        :return: provided tasks
        """
        self.add_tasks(*tasks)

        # Pair each task with its predecessor: (tasks[1], tasks[0]), (tasks[2], tasks[1]), ...
        for task, previous_task in zip(tasks[1:], tasks):
            self.add_dependency(task, previous_task)

        return tasks