1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
17 Sub-process task executor.
20 # pylint: disable=wrong-import-position
25 # As part of the process executor implementation, subprocess are started with this module as their
26 # entry point. We thus remove this module's directory from the python path if it happens to be
29 from collections import namedtuple
31 script_dir = os.path.dirname(__file__)
32 if script_dir in sys.path:
33 sys.path.remove(script_dir)
49 from aria.orchestrator.workflows.executor import base
50 from aria.extension import process_executor
51 from aria.utils import (
54 process as process_utils
59 _INT_SIZE = struct.calcsize(_INT_FMT)
60 UPDATE_TRACKED_CHANGES_FAILED_STR = \
61 'Some changes failed writing to storage. For more info refer to the log.'
64 _Task = namedtuple('_Task', 'proc, ctx')
class ProcessExecutor(base.BaseExecutor):
    """
    Sub-process task executor.

    Starts each task in its own Python subprocess (see ``_execute``) and
    receives task status messages back over a local TCP server socket serviced
    by a dedicated listener thread.
    """

    def __init__(self, plugin_manager=None, python_path=None, *args, **kwargs):
        # plugin_manager: optional manager used to load a task's plugin into
        #   the subprocess environment (see _construct_subprocess_env).
        # python_path: optional list of directories appended to the
        #   subprocess's PYTHONPATH.
        super(ProcessExecutor, self).__init__(*args, **kwargs)
        self._plugin_manager = plugin_manager

        # Optional list of additional directories that should be added to
        # subprocesses python path
        self._python_path = python_path or []

        # Flag that denotes whether this executor has been stopped
        # NOTE(review): the assignment itself (presumably ``self._stopped = False``)
        # is not visible in this chunk; ``_listener`` loops on ``self._stopped``.

        # Contains reference to all currently running tasks
        # NOTE(review): initialization of ``self._tasks`` is not visible in this
        # chunk; ``_execute`` stores ``_Task`` records in it keyed by task id.

        # Dispatch table: status-message 'type' -> handler method (used by _listener).
        self._request_handlers = {
            'started': self._handle_task_started_request,
            'succeeded': self._handle_task_succeeded_request,
            'failed': self._handle_task_failed_request,
            # NOTE(review): the closing ``}`` of this dict is not visible in this chunk.

        # Server socket used to accept task status messages from subprocesses
        # Binding to port 0 lets the OS pick a free port; the chosen port is
        # read back below and handed to each subprocess via the messenger.
        self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._server_socket.bind(('localhost', 0))
        self._server_socket.listen(10)
        self._server_port = self._server_socket.getsockname()[1]

        # Used to send a "closed" message to the listener when this executor is closed
        self._messenger = _Messenger(task_id=None, port=self._server_port)

        # Queue object used by the listener thread to notify this constructed it has started
        # (see last line of this __init__ method)
        self._listener_started = Queue.Queue()

        # Listener thread to handle subprocesses task status messages
        self._listener_thread = threading.Thread(target=self._listener)
        self._listener_thread.daemon = True
        self._listener_thread.start()

        # Wait for listener thread to actually start before returning
        self._listener_started.get(timeout=60)

    # NOTE(review): a method header (presumably ``def close(self):``) is not
    # visible in this chunk; the following statements appear to belong to it.
        # Listener thread may be blocked on "accept" call. This will wake it up with an explicit
        self._messenger.closed()
        self._server_socket.close()
        self._listener_thread.join(timeout=60)

        # we use set(self._tasks) since tasks may change in the process of closing
        for task_id in set(self._tasks):
            self.terminate(task_id)

    def terminate(self, task_id):
        # Kill the subprocess tree running the given task, if it is still tracked.
        task = self._remove_task(task_id)
        # The process might have managed to finish, thus it would not be in the tasks list
        # NOTE(review): the guard/``try:`` lines around the psutil calls are not
        # visible in this chunk.
            parent_process = psutil.Process(task.proc.pid)
            # Children are killed before the parent; ``reversed`` walks the
            # recursive child list deepest-first.
            for child_process in reversed(parent_process.children(recursive=True)):
            # NOTE(review): the per-child ``try:``/kill call is not visible here;
            # failures to kill are deliberately swallowed (best-effort cleanup).
            except BaseException:
            parent_process.kill()
            except BaseException:

    def _execute(self, ctx):
        # Start the operation asynchronously in a subprocess; does not wait.
        # Temporary file used to pass arguments to the started subprocess
        file_descriptor, arguments_json_path = tempfile.mkstemp(prefix='executor-', suffix='.json')
        os.close(file_descriptor)
        with open(arguments_json_path, 'wb') as f:
            f.write(pickle.dumps(self._create_arguments_dict(ctx)))

        env = self._construct_subprocess_env(task=ctx.task)
        # Asynchronously start the operation in a subprocess
        # NOTE(review): the opening of the argv list (the Python interpreter
        # element) and the trailing ``env=env)`` are not visible in this chunk.
        # This module itself serves as the subprocess entry point.
        proc = subprocess.Popen(
            os.path.expanduser(os.path.expandvars(__file__)),
            os.path.expanduser(os.path.expandvars(arguments_json_path))
        self._tasks[ctx.task.id] = _Task(ctx=ctx, proc=proc)

    def _remove_task(self, task_id):
        # Pop and return the task record; returns None if it already finished
        # and was removed by a status handler.
        return self._tasks.pop(task_id, None)

    def _check_closed(self):
        # NOTE(review): the guarding condition line (checking the stopped/closed
        # flag) is not visible in this chunk.
        raise RuntimeError('Executor closed')

    def _create_arguments_dict(self, ctx):
        # Build the picklable payload written to the temp file and read back by
        # the subprocess entry point at the bottom of this module.
        # NOTE(review): the ``return {`` opening and closing ``}`` are not
        # visible in this chunk.
            'task_id': ctx.task.id,
            'function': ctx.task.function,
            'operation_arguments': dict(arg.unwrapped for arg in ctx.task.arguments.itervalues()),
            'port': self._server_port,
            'context': ctx.serialization_dict

    def _construct_subprocess_env(self, task):
        # Build the environment for the task subprocess, starting from a copy
        # of the current process environment.
        env = os.environ.copy()

        if task.plugin_fk and self._plugin_manager:
            # If this is a plugin operation,
            # load the plugin on the subprocess env we're constructing
            self._plugin_manager.load_plugin(task.plugin, env=env)

        # Add user supplied directories to injected PYTHONPATH
        if self._python_path:
            process_utils.append_to_pythonpath(*self._python_path, env=env)
        # NOTE(review): the ``return env`` line is not visible in this chunk.

    # NOTE(review): the ``def _listener(self):`` header is not visible in this
    # chunk; the following loop is the listener thread's target.
        # Notify __init__ method this thread has actually started
        self._listener_started.put(True)
        while not self._stopped:
            # NOTE(review): the ``try:`` opening this per-request block is not
            # visible in this chunk.
                with self._accept_request() as (request, response):
                    request_type = request['type']
                    if request_type == 'closed':
                        # NOTE(review): the body handling the 'closed' message
                        # (stopping the loop) is not visible in this chunk.
                    request_handler = self._request_handlers.get(request_type)
                    if not request_handler:
                        raise RuntimeError('Invalid request type: {0}'.format(request_type))
                    task_id = request['task_id']
                    request_handler(task_id=task_id, request=request, response=response)
            except BaseException as e:
                # Listener must never die from a single bad request; log and continue.
                self.logger.debug('Error in process executor listener: {0}'.format(e))

    @contextlib.contextmanager
    def _accept_request(self):
        # Accept one status-message connection and yield (request, response);
        # the response dict is sent back to the subprocess before the
        # connection is closed.
        with contextlib.closing(self._server_socket.accept()[0]) as connection:
            message = _recv_message(connection)
            # NOTE(review): the creation of ``response`` and the ``try:`` around
            # the yield are not visible in this chunk.
                yield message, response
            except BaseException as e:
                # Hand the handler's failure back to the subprocess rather than
                # letting it propagate here.
                response['exception'] = exceptions.wrap_if_needed(e)
            _send_message(connection, response)

    def _handle_task_started_request(self, task_id, **kwargs):
        # 'started' message: forward to the base-executor hook.
        self._task_started(self._tasks[task_id].ctx)

    def _handle_task_succeeded_request(self, task_id, **kwargs):
        task = self._remove_task(task_id)
        # NOTE(review): a guard (e.g. ``if task:``) appears to be elided here.
        self._task_succeeded(task.ctx)

    def _handle_task_failed_request(self, task_id, request, **kwargs):
        task = self._remove_task(task_id)
        # NOTE(review): the opening of the ``self._task_failed(`` call is not
        # visible in this chunk; these are its arguments.
            task.ctx, exception=request['exception'], traceback=request['traceback'])
def _send_message(connection, message):
    """
    Serialize *message* with jsonpickle and write it to *connection*, prefixed
    with a fixed-size length header so the receiver (``_recv_message``) knows
    how many payload bytes to read.
    """

    # Packing the length of the entire msg using struct.pack.
    # This enables later reading of the content.
    def _pack(data):
        return struct.pack(_INT_FMT, len(data))

    data = jsonpickle.dumps(message)
    msg_metadata = _pack(data)
    # Length header first, then the payload; sendall guarantees the whole
    # payload is written.
    connection.send(msg_metadata)
    connection.sendall(data)
def _recv_message(connection):
    """
    Read one length-prefixed jsonpickle message from *connection* and return
    the deserialized object (inverse of ``_send_message``).
    """
    # Retrieving the length of the msg to come.
    def _unpack(conn):
        return struct.unpack(_INT_FMT, _recv_bytes(conn, _INT_SIZE))[0]

    msg_metadata_len = _unpack(connection)
    msg = _recv_bytes(connection, msg_metadata_len)
    return jsonpickle.loads(msg)
261 def _recv_bytes(connection, count):
262 result = io.BytesIO()
265 return result.getvalue()
266 read = connection.recv(count)
268 return result.getvalue()
class _Messenger(object):
    """
    Client-side helper used inside the task subprocess to report task status
    ('started' / 'succeeded' / 'failed' / 'closed') to the executor's listener
    socket, one fresh TCP connection per message.
    """

    def __init__(self, task_id, port):
        # task_id: id included in every message (None for the executor's own
        #   'closed' message — see ProcessExecutor.__init__).
        # NOTE(review): the assignment of ``self.port`` is not visible in this
        # chunk; ``_send_message`` reads ``self.port``.
        self.task_id = task_id

    # NOTE(review): a method header (presumably ``def started(self):``) is not
    # visible in this chunk.
        """Task started message"""
        self._send_message(type='started')

    # NOTE(review): a method header (presumably ``def succeeded(self):``) is
    # not visible in this chunk.
        """Task succeeded message"""
        self._send_message(type='succeeded')

    def failed(self, exception):
        """Task failed message"""
        self._send_message(type='failed', exception=exception)

    # NOTE(review): a method header (presumably ``def closed(self):``) is not
    # visible in this chunk.
        """Executor closed message"""
        self._send_message(type='closed')

    def _send_message(self, type, exception=None):
        # Connect, send the status payload, then wait for the listener's
        # response and re-raise any exception it reports back.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(('localhost', self.port))
        # NOTE(review): a try/finally that closes ``sock`` appears to be elided
        # in this chunk.
        _send_message(sock, {
            'task_id': self.task_id,
            'exception': exceptions.wrap_if_needed(exception),
            'traceback': exceptions.get_exception_as_string(*sys.exc_info()),
        # NOTE(review): the 'type' entry and the closing ``})`` of this call
        # are not visible in this chunk.
        response = _recv_message(sock)
        response_exception = response.get('exception')
        if response_exception:
            raise response_exception
# NOTE(review): the following statements appear to be the body of the
# subprocess entry-point function (its ``def`` header, presumably
# ``def _main():``, is not visible in this chunk).
arguments_json_path = sys.argv[1]
with open(arguments_json_path) as f:
    # SECURITY NOTE: pickle.loads on file contents. The file is written by the
    # parent executor process (_execute), so the input is trusted here — but
    # pickle must never be applied to untrusted data.
    arguments = pickle.loads(f.read())

# arguments_json_path is a temporary file created by the parent process.
# so we remove it here
os.remove(arguments_json_path)

task_id = arguments['task_id']
port = arguments['port']
messenger = _Messenger(task_id=task_id, port=port)

function = arguments['function']
operation_arguments = arguments['operation_arguments']
context_dict = arguments['context']

# NOTE(review): the ``try:`` around context instantiation is not visible in
# this chunk.
    ctx = context_dict['context_cls'].instantiate_from_dict(**context_dict['context'])
except BaseException as e:
# NOTE(review): the failure-handling body for context instantiation is not
# visible in this chunk.

# Load the operation function, apply any registered process-executor
# decorators, then invoke it with the deserialized context and arguments.
task_func = imports.load_attribute(function)
aria.install_aria_extensions()
for decorate in process_executor.decorate():
    task_func = decorate(task_func)
task_func(ctx=ctx, **operation_arguments)
# Report success back to the executor over the status socket.
messenger.succeeded()
except BaseException as e:
# NOTE(review): the failure-reporting body (presumably calling
# ``messenger.failed``) is not visible in this chunk.

# Script entry point — NOTE(review): the call line (e.g. ``_main()``) is not
# visible in this chunk.
if __name__ == '__main__':