1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
23 from aria.orchestrator.workflows.executor import BaseExecutor
# Executor that submits operation tasks to Celery via send_task and tracks
# their progress through Celery task events consumed on a background thread.
# NOTE(review): this listing carries original line-number prefixes and has lost
# its indentation and several original lines (docstrings, imports, statements),
# so it is not runnable as shown.
26 class CeleryExecutor(BaseExecutor):
31 def __init__(self, app, *args, **kwargs):
# Start a daemon thread running _events_receiver, then block until that thread
# reports (via _started_queue) that it completed its first itercapture
# iteration, waiting at most 30 seconds.
# :param app: Celery application used for the broker connection, the events
#     Receiver and send_task -- presumably a celery.Celery instance (confirm).
# Remaining positional/keyword arguments are forwarded to BaseExecutor.
# :raises Queue.Empty: if the receiver does not signal within 30 seconds
#     (presumably because no worker produced an event); the already-started
#     receiver thread is not stopped in that case.
32 super(CeleryExecutor, self).__init__(*args, **kwargs)
# NOTE(review): original lines 33 and 36-39 are missing from this view;
# self._app, self._tasks, self._results and self._receiver, which are used
# elsewhere, are presumably initialised there.
34 self._started_signaled = False
# Capacity-1 hand-off: the receiver thread puts True once, __init__ takes it.
35 self._started_queue = Queue.Queue(maxsize=1)
40 self._receiver_thread = threading.Thread(target=self._events_receiver)
# Daemon thread, so a blocked receiver does not keep the interpreter alive.
41 self._receiver_thread.daemon = True
42 self._receiver_thread.start()
43 self._started_queue.get(timeout=30)
45 def _execute(self, ctx):
# Register the task context under ctx.id and submit the operation to Celery by
# name (ctx.operation_mapping), keeping the returned result handle.
# The argument dict is built from ctx.task.arguments -- assumes each
# `unwrapped` value is a (name, value) pair (confirm) -- plus the execution
# context under the 'ctx' key; presumably it is passed to send_task on the
# lines missing below.
46 self._tasks[ctx.id] = ctx
47 arguments = dict(arg.unwrapped for arg in ctx.task.arguments.itervalues())
48 arguments['ctx'] = ctx.context
# NOTE(review): the result handle is stored only after send_task() returns,
# while task events are handled on the receiver thread, so a completion event
# can in principle be processed before self._results has this entry.
49 self._results[ctx.id] = self._app.send_task(
50 ctx.operation_mapping,
# NOTE(review): original lines 51-52 (further send_task arguments) are missing
# from this view.
53 queue=self._get_queue(ctx))
# NOTE(review): the enclosing method header (presumably close()) and any lines
# before it (original 54-57) are missing from this view.
# Ask the Celery event receiver to stop consuming, then wait for the receiver
# thread to finish. join() has no timeout, so this blocks until the receiver
# loop notices should_stop.
58 self._receiver.should_stop = True
59 self._receiver_thread.join()
# NOTE(review): the enclosing definition header (original lines 60-62) is
# missing from this view. Both branches of the conditional yield None, so this
# always returns None and _execute passes queue=None to send_task -- presumably
# leaving queue selection to Celery's default routing. Placeholder (see TODO).
63 return None if task else None # TODO
65 def _events_receiver(self):
# Receiver-thread body: open a broker connection, create a Celery events
# Receiver that dispatches task-started / task-succeeded / task-failed events
# to the _celery_task_* handlers, and iterate itercapture. The first iteration
# signals __init__ through _started_queue.
66 with self._app.connection() as connection:
67 self._receiver = self._app.events.Receiver(connection, handlers={
68 'task-started': self._celery_task_started,
69 'task-succeeded': self._celery_task_succeeded,
70 'task-failed': self._celery_task_failed,
# NOTE(review): original line 71 (closing the handlers dict and the Receiver
# call) is missing from this view.
# limit=None / timeout=None: no iteration or time limit; wakeup=True presumably
# asks workers to announce themselves so a first event arrives promptly.
72 for _ in self._receiver.itercapture(limit=None, timeout=None, wakeup=True):
# Signal start only once: a second put() on the capacity-1 queue could block if
# the first value had not been consumed.
73 if not self._started_signaled:
74 self._started_queue.put(True)
75 self._started_signaled = True
# NOTE(review): original lines 76-78 are missing from this view.
def _celery_task_started(self, event):
    """
    Handle a Celery ``task-started`` event.

    Reports the start through ``_task_started`` with the task context that
    ``_execute`` registered under the event's ``uuid``.

    Events whose ``uuid`` is not registered are ignored. A Celery events
    receiver sees events from every worker on the broker, including tasks this
    executor never submitted. A KeyError raised here would escape the
    receiver's ``itercapture`` loop and stop event processing for every task.

    :param event: Celery event dict; only its ``'uuid'`` key is read.
    """
    task = self._tasks.get(event['uuid'])
    if task is None:
        # Not submitted by this executor (or already finished): nothing to do.
        return
    self._task_started(task)
def _celery_task_succeeded(self, event):
    """
    Handle a Celery ``task-succeeded`` event.

    Unregisters the task (dropping its stored result handle) and reports
    success through ``_task_succeeded``.

    Events for task ids this executor has not registered are ignored. The
    events receiver sees events from every worker on the broker, and a
    KeyError here would end the receiver thread's event loop.

    :param event: Celery event dict; only its ``'uuid'`` key is read.
    """
    task_id = event['uuid']
    if task_id not in self._tasks:
        # Not submitted by this executor (or already handled).
        return
    task, _ = self._remove_task(task_id)
    self._task_succeeded(task)
def _celery_task_failed(self, event):
    """
    Handle a Celery ``task-failed`` event.

    Unregisters the task and reports the failure through ``_task_failed``,
    passing the exception read from the task's Celery result. If reading that
    exception itself raises, a ``RuntimeError`` describing the problem is
    reported instead, so the task is still marked as failed.

    Events for task ids this executor has not registered are ignored. The
    events receiver sees events from every worker on the broker, and a
    KeyError here would end the receiver thread's event loop.

    :param event: Celery event dict; only its ``'uuid'`` key is read.
    """
    task_id = event['uuid']
    if task_id not in self._tasks:
        # Not submitted by this executor (or already handled).
        return
    task, async_result = self._remove_task(task_id)
    if async_result is None:
        # _execute records the AsyncResult only after send_task() returns, so
        # a fast failure can be handled before it exists; look it up by id.
        async_result = self._app.AsyncResult(task_id)
    try:
        exception = async_result.result
    except BaseException as e:
        # Deliberately broad: any failure to obtain the remote exception
        # (e.g. deserialization) must still result in a reported failure.
        exception = RuntimeError(
            'Could not de-serialize exception of task {0} --> {1}: {2}'
            .format(task.name, type(e).__name__, str(e)))
    self._task_failed(task, exception=exception)
def _remove_task(self, task_id):
    """
    Unregister a task and return its bookkeeping entries.

    :param task_id: id under which ``_execute`` registered the task context.
    :returns: ``(task_context, async_result)``. ``async_result`` is ``None``
        when no result handle has been stored for ``task_id`` yet.
    :raises KeyError: if no task context is registered under ``task_id``.
    """
    task = self._tasks.pop(task_id)
    # _execute stores self._results[...] only after send_task() returns, but
    # events are handled on the receiver thread. Tolerate the window where the
    # task is registered and its result handle is not, instead of raising after
    # the task entry was already removed.
    # NOTE(review): in that window _execute will add the result entry later and
    # nothing removes it.
    async_result = self._results.pop(task_id, None)
    return task, async_result