Merge "vFW and vDNS support added to azure-plugin"
[multicloud/azure.git] / azure / aria / aria-extension-cloudify / src / aria / aria / orchestrator / workflows / executor / celery.py
1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements.  See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License.  You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """
17 Celery task executor.
18 """
19
20 import threading
21 import Queue
22
23 from aria.orchestrator.workflows.executor import BaseExecutor
24
25
class CeleryExecutor(BaseExecutor):
    """
    Executor that dispatches operation tasks through a Celery application and
    tracks their lifecycle (started / succeeded / failed) by consuming the
    Celery event stream on a background daemon thread.
    """

    def __init__(self, app, *args, **kwargs):
        """
        :param app: configured Celery application, used both for sending tasks
                    and for receiving task lifecycle events
        """
        super(CeleryExecutor, self).__init__(*args, **kwargs)
        self._app = app
        self._started_signaled = False
        # Single-slot queue used purely as a "receiver thread is up" handshake.
        self._started_queue = Queue.Queue(maxsize=1)
        self._tasks = {}     # task id -> operation context
        self._results = {}   # task id -> Celery AsyncResult
        self._receiver = None
        self._stopped = False
        self._receiver_thread = threading.Thread(target=self._events_receiver)
        self._receiver_thread.daemon = True
        self._receiver_thread.start()
        # Block until the receiver captures its first event; Queue.Empty
        # propagates if nothing arrives within 30 seconds.
        self._started_queue.get(timeout=30)

    def _execute(self, ctx):
        """Send the operation described by ``ctx`` to Celery as a task."""
        self._tasks[ctx.id] = ctx
        call_kwargs = dict(argument.unwrapped
                           for argument in ctx.task.arguments.itervalues())
        call_kwargs['ctx'] = ctx.context
        self._results[ctx.id] = self._app.send_task(
            ctx.operation_mapping,
            kwargs=call_kwargs,
            task_id=ctx.task.id,
            queue=self._get_queue(ctx))

    def close(self):
        """Stop the event receiver and wait for its thread to finish."""
        self._stopped = True
        if self._receiver:
            self._receiver.should_stop = True
        self._receiver_thread.join()

    @staticmethod
    def _get_queue(task):
        # Queue routing is not implemented yet.
        return None  # TODO

    def _events_receiver(self):
        """Background thread body: capture Celery task events until stopped."""
        with self._app.connection() as connection:
            handlers = {
                'task-started': self._celery_task_started,
                'task-succeeded': self._celery_task_succeeded,
                'task-failed': self._celery_task_failed,
            }
            self._receiver = self._app.events.Receiver(connection,
                                                       handlers=handlers)
            for _ in self._receiver.itercapture(limit=None, timeout=None,
                                                wakeup=True):
                # First captured event doubles as the readiness signal
                # awaited in __init__.
                if not self._started_signaled:
                    self._started_queue.put(True)
                    self._started_signaled = True
                if self._stopped:
                    return

    def _celery_task_started(self, event):
        self._task_started(self._tasks[event['uuid']])

    def _celery_task_succeeded(self, event):
        task, _ = self._remove_task(event['uuid'])
        self._task_succeeded(task)

    def _celery_task_failed(self, event):
        task, async_result = self._remove_task(event['uuid'])
        # Fetching ``.result`` may itself raise if the remote exception cannot
        # be de-serialized; fall back to a descriptive RuntimeError.
        try:
            failure = async_result.result
        except BaseException as deserialization_error:
            failure = RuntimeError(
                'Could not de-serialize exception of task {0} --> {1}: {2}'
                .format(task.name,
                        type(deserialization_error).__name__,
                        str(deserialization_error)))
        self._task_failed(task, exception=failure)

    def _remove_task(self, task_id):
        """Drop and return the ``(context, async_result)`` pair for a task."""
        task = self._tasks.pop(task_id)
        async_result = self._results.pop(task_id)
        return task, async_result