1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
24 from fabric.contrib import files
25 from fabric import context_managers
27 from aria.modeling import models
28 from aria.orchestrator import events
29 from aria.orchestrator import workflow
30 from aria.orchestrator.workflows import api
31 from aria.orchestrator.workflows.executor import process
32 from aria.orchestrator.workflows.core import engine, graph_compiler
33 from aria.orchestrator.workflows.exceptions import ExecutorException
34 from aria.orchestrator.exceptions import TaskAbortException, TaskRetryException
35 from aria.orchestrator.execution_plugin import operations
36 from aria.orchestrator.execution_plugin import constants
37 from aria.orchestrator.execution_plugin.exceptions import ProcessException, TaskException
38 from aria.orchestrator.execution_plugin.ssh import operations as ssh_operations
40 from tests import mock, storage, resources
41 from tests.orchestrator.workflows.helpers import events_collector
# Alternate remote base directory: passed as `base_dir` by the process-config
# test and removed from the remote host before each test in `_setup`.
43 _CUSTOM_BASE_DIR = '/tmp/new-aria-ctx'
# Private key registered with the mock SSH server and handed to fabric as
# `key_filename` in `_FABRIC_ENV`.
46 KEY_FILENAME = os.path.join(tests.ROOT_DIR, 'tests/resources/keys/test')
49 'disable_known_hosts': True,
51 'key_filename': KEY_FILENAME
56 @pytest.fixture(scope='session')
58 with mockssh.Server({'test': KEY_FILENAME}) as s:
62 #@pytest.mark.skipif(not os.environ.get('TRAVIS'), reason='actual ssh server required')
63 class TestWithActualSSHServer(object):
def test_run_script_basic(self):
    """Pass `test_value` through the process env and read it back as a node attribute."""
    value = 'some_value'
    attributes = self._execute(env={'test_value': value})
    assert attributes['test_value'].value == value
70 @pytest.mark.skip(reason='sudo privileges are required')
71 def test_run_script_as_sudo(self):
72 self._execute(use_sudo=True)
74 assert files.exists('/opt/test_dir')
75 fabric.api.sudo('rm -rf /opt/test_dir')
def test_run_script_default_base_dir(self):
    """Without a custom base dir, the reported work dir sits under DEFAULT_BASE_DIR."""
    attributes = self._execute()
    expected_work_dir = '{0}/work'.format(constants.DEFAULT_BASE_DIR)
    assert attributes['work_dir'].value == expected_work_dir
81 @pytest.mark.skip(reason='Re-enable once output from process executor can be captured')
82 @pytest.mark.parametrize('hide_groups', [[], ['everything']])
83 def test_run_script_with_hide(self, hide_groups):
84 self._execute(hide_output=hide_groups)
86 expected_log_message = ('[localhost] run: source {0}/scripts/'
87 .format(constants.DEFAULT_BASE_DIR))
89 assert expected_log_message not in output
91 assert expected_log_message in output
93 def test_run_script_process_config(self):
94 expected_env_value = 'test_value_env'
95 expected_arg1_value = 'test_value_arg1'
96 expected_arg2_value = 'test_value_arg2'
98 expected_base_dir = _CUSTOM_BASE_DIR
99 props = self._execute(
100 env={'test_value_env': expected_env_value},
102 'args': [expected_arg1_value, expected_arg2_value],
104 'base_dir': expected_base_dir
106 assert props['env_value'].value == expected_env_value
107 assert len(props['bash_version'].value) > 0
108 assert props['arg1_value'].value == expected_arg1_value
109 assert props['arg2_value'].value == expected_arg2_value
110 assert props['cwd'].value == expected_cwd
111 assert props['ctx_path'].value == '{0}/ctx'.format(expected_base_dir)
def test_run_script_command_prefix(self):
    """Run the script through `bash -i` and check the shell reports interactive mode."""
    attributes = self._execute(process={'command_prefix': 'bash -i'})
    # Presumably the script stores bash's `$-` option flags here; `i` marks an
    # interactive shell.
    shell_flags = attributes['dollar_dash'].value
    assert 'i' in shell_flags
def test_run_script_reuse_existing_ctx(self):
    """Run two operations in one workflow; each must set its own node attribute."""
    first_value, second_value = 'test_value_1', 'test_value_2'
    operation_names = ['{0}_{1}'.format(self.test_name, suffix) for suffix in (1, 2)]
    attributes = self._execute(
        test_operations=operation_names,
        env={'test_value1': first_value, 'test_value2': second_value})
    assert attributes['test_value1'].value == first_value
    assert attributes['test_value2'].value == second_value
def test_run_script_download_resource_plain(self, tmpdir):
    """Upload a plain resource as `test_resource`; the script must see its content."""
    resource_file = tmpdir.join('resource')
    resource_file.write('content')
    self._upload(str(resource_file), 'test_resource')
    assert self._execute()['test_value'].value == 'content'
def test_run_script_download_resource_and_render(self, tmpdir):
    """Upload a template resource; the script must see it rendered with the service name."""
    template = tmpdir.join('resource')
    template.write('{{ctx.service.name}}')
    self._upload(str(template), 'test_resource')
    attributes = self._execute()
    expected_name = self._workflow_context.service.name
    assert attributes['test_value'].value == expected_name
@pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
def test_run_script_inputs_as_env_variables_no_override(self, value):
    """The `custom_env_var` input reaches the script; non-strings come back as JSON."""
    raw = self._execute(custom_input=value)['test_value'].value
    if isinstance(value, basestring):
        actual = raw
    else:
        actual = json.loads(raw)
    assert value == actual
@pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
def test_run_script_inputs_as_env_variables_process_env_override(self, value):
    """A process env entry named `custom_env_var` wins over the same-named input."""
    attributes = self._execute(custom_input='custom-input-value',
                               env={'custom_env_var': value})
    raw = attributes['test_value'].value
    if isinstance(value, basestring):
        actual = raw
    else:
        actual = json.loads(raw)
    assert value == actual
def test_run_script_error_in_script(self):
    """The failure recorded for this operation must be a TaskException."""
    assert isinstance(self._execute_and_get_task_exception(), TaskException)
def test_run_script_abort_immediate(self):
    """The recorded failure must be a TaskAbortException carrying 'abort-message'."""
    error = self._execute_and_get_task_exception()
    assert isinstance(error, TaskAbortException)
    assert error.message == 'abort-message'
def test_run_script_retry(self):
    """The recorded failure must be a TaskRetryException carrying 'retry-message'."""
    error = self._execute_and_get_task_exception()
    assert isinstance(error, TaskRetryException)
    assert error.message == 'retry-message'
def test_run_script_abort_error_ignored_by_script(self):
    """Even if the script ignores the abort error, the task still fails with the abort."""
    error = self._execute_and_get_task_exception()
    assert isinstance(error, TaskAbortException)
    assert error.message == 'abort-message'
def test_run_commands(self, server):
    """Run a raw `touch` command over SSH and verify the file appears remotely.

    ``_ssh_env`` needs the mock SSH ``server`` fixture (it reads
    ``server.port`` to build the host string). The fixture is therefore
    requested here and passed through; calling ``_ssh_env()`` with no
    argument raises ``TypeError``.
    """
    temp_file_path = '/tmp/very_temporary_file'
    with self._ssh_env(server):
        # Start clean in case an earlier run left the file behind.
        if files.exists(temp_file_path):
            fabric.api.run('rm {0}'.format(temp_file_path))
    self._execute(commands=['touch {0}'.format(temp_file_path)])
    with self._ssh_env(server):
        assert files.exists(temp_file_path)
        fabric.api.run('rm {0}'.format(temp_file_path))
186 @pytest.fixture(autouse=True)
187 def _setup(self, request, workflow_context, executor, capfd, server):
188 print 'HI!!!!!!!!!!', server.port
189 self._workflow_context = workflow_context
190 self._executor = executor
192 self.test_name = request.node.originalname or request.node.name
193 with self._ssh_env(server):
194 for directory in [constants.DEFAULT_BASE_DIR, _CUSTOM_BASE_DIR]:
195 if files.exists(directory):
196 fabric.api.run('rm -rf {0}'.format(directory))
198 @contextlib.contextmanager
199 def _ssh_env(self, server):
200 with self._capfd.disabled():
201 with context_managers.settings(fabric.api.hide('everything'),
202 host_string='localhost:{0}'.format(server.port),
212 test_operations=None,
214 process = process or {}
216 process.setdefault('env', {}).update(env)
218 test_operations = test_operations or [self.test_name]
220 local_script_path = os.path.join(resources.DIR, 'scripts', 'test_ssh.sh')
221 script_path = os.path.basename(local_script_path)
222 self._upload(local_script_path, script_path)
225 operation = operations.run_commands_with_ssh
227 operation = operations.run_script_with_ssh
229 node = self._workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
231 'script_path': script_path,
232 'fabric_env': _FABRIC_ENV,
234 'use_sudo': use_sudo,
235 'custom_env_var': custom_input,
236 'test_operation': '',
239 arguments['hide_output'] = hide_output
241 arguments['commands'] = commands
242 interface = mock.models.create_interface(
246 operation_kwargs=dict(
247 function='{0}.{1}'.format(
252 node.interfaces[interface.name] = interface
255 def mock_workflow(ctx, graph):
257 for test_operation in test_operations:
258 op_arguments = arguments.copy()
259 op_arguments['test_operation'] = test_operation
260 ops.append(api.task.OperationTask(
262 interface_name='test',
264 arguments=op_arguments))
268 tasks_graph = mock_workflow(ctx=self._workflow_context) # pylint: disable=no-value-for-parameter
269 graph_compiler.GraphCompiler(
270 self._workflow_context, self._executor.__class__).compile(tasks_graph)
271 eng = engine.Engine({self._executor.__class__: self._executor})
272 eng.execute(self._workflow_context)
273 return self._workflow_context.model.node.get_by_name(
274 mock.models.DEPENDENCY_NODE_NAME).attributes
def _execute_and_get_task_exception(self, *args, **kwargs):
    """Run ``_execute`` expecting an ExecutorException; return the first failed task's exception.

    Failure signals are collected while the workflow runs. The exception is
    taken from the kwargs of the first ``on_failure_task_signal`` event.
    """
    failure_signal = events.on_failure_task_signal
    with events_collector(failure_signal) as received:
        with pytest.raises(ExecutorException):
            self._execute(*args, **kwargs)
    first_event = received[failure_signal][0]
    return first_event['kwargs']['exception']
283 def _upload(self, source, path):
284 self._workflow_context.resource.service.upload(
285 entry_id=str(self._workflow_context.service.id),
291 result = process.ProcessExecutor()
def workflow_context(self, tmpdir):
    """Yield a mock workflow context rooted at ``tmpdir``; release its sqlite storage afterwards."""
    context = mock.context.simple(str(tmpdir))
    context.states = []
    context.exception = None
    yield context
    storage.release_sqlite_storage(context.model)
306 class TestFabricEnvHideGroupsAndRunCommands(object):
308 def test_fabric_env_default_override(self):
309 # first sanity for no override
311 assert self.mock.settings_merged['timeout'] == constants.FABRIC_ENV_DEFAULTS['timeout']
313 invocation_fabric_env = self.default_fabric_env.copy()
315 invocation_fabric_env['timeout'] = timeout
316 self._run(fabric_env=invocation_fabric_env)
317 assert self.mock.settings_merged['timeout'] == timeout
def test_implicit_host_string(self, mocker):
    """Without `host_string` in fabric_env, the actor host's `host_address` is used."""
    host_address = '1.1.1.1'
    actor = self._Ctx.task.actor
    mocker.patch.object(actor, 'host')
    mocker.patch.object(actor.host, 'host_address', host_address)
    fabric_env = dict(self.default_fabric_env)
    fabric_env.pop('host_string')
    self._run(fabric_env=fabric_env)
    assert self.mock.settings_merged['host_string'] == host_address
def test_explicit_host_string(self):
    """An explicit `host_string` in fabric_env reaches fabric's settings unchanged."""
    host_string = 'explicit_host_string'
    fabric_env = dict(self.default_fabric_env, host_string=host_string)
    self._run(fabric_env=fabric_env)
    assert self.mock.settings_merged['host_string'] == host_string
def test_override_warn_only(self):
    """`warn_only` is True for the default env and False when explicitly overridden."""
    self._run(fabric_env=dict(self.default_fabric_env))
    assert self.mock.settings_merged['warn_only'] is True
    self._run(fabric_env=dict(self.default_fabric_env, warn_only=False))
    assert self.mock.settings_merged['warn_only'] is False
def test_missing_host_string(self):
    """Omitting `host_string` aborts the task with a message naming it."""
    fabric_env = dict(self.default_fabric_env)
    del fabric_env['host_string']
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    assert '`host_string` not supplied' in str(exc_ctx.value)
def test_missing_user(self):
    """Omitting `user` aborts the task with a message naming it."""
    fabric_env = dict(self.default_fabric_env)
    del fabric_env['user']
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    assert '`user` not supplied' in str(exc_ctx.value)
def test_missing_key_or_password(self):
    """Removing `key_filename` from the default env aborts for missing credentials."""
    fabric_env = dict(self.default_fabric_env)
    del fabric_env['key_filename']
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    assert 'Access credentials not supplied' in str(exc_ctx.value)
def test_hide_in_settings_and_non_viable_groups(self):
    """Valid hide groups are forwarded to fabric's `hide`; an unknown group aborts."""
    valid_groups = ('running', 'stdout')
    invalid_groups = ('running', 'bla')
    self._run(hide_output=valid_groups)
    hidden = set(self.mock.settings_merged['hide_output'])
    assert hidden == set(valid_groups)
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(hide_output=invalid_groups)
    assert '`hide_output` must be a subset of' in str(exc_ctx.value)
373 def test_run_commands(self):
375 commands = ['command1', 'command2']
379 assert all(item in self.mock.settings_merged.items() for
380 item in self.default_fabric_env.items())
381 assert self.mock.settings_merged['warn_only'] is True
382 assert self.mock.settings_merged['use_sudo'] == use_sudo
383 assert self.mock.commands == commands
384 self.mock.settings_merged = {}
385 self.mock.commands = []
def test_failed_command(self):
    """A failing command raises ProcessException populated from the command result."""
    with pytest.raises(ProcessException) as exc_ctx:
        self._run(commands=['fail'])
    error = exc_ctx.value
    result = self.MockCommandResult
    assert error.stdout == result.stdout
    assert error.stderr == result.stderr
    assert error.command == result.command
    assert error.exit_code == result.return_code
398 class MockCommandResult(object):
399 stdout = 'mock_stdout'
400 stderr = 'mock_stderr'
401 command = 'mock_command'
404 def __init__(self, failed):
407 class MockFabricApi(object):
411 self.settings_merged = {}
413 @contextlib.contextmanager
414 def settings(self, *args, **kwargs):
415 self.settings_merged.update(kwargs)
418 self.settings_merged.update({'hide_output': groups})
421 def run(self, command):
422 self.commands.append(command)
423 self.settings_merged['use_sudo'] = False
424 return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')
426 def sudo(self, command):
427 self.commands.append(command)
428 self.settings_merged['use_sudo'] = True
429 return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')
431 def hide(self, *groups):
434 def exists(self, *args, **kwargs):
438 INSTRUMENTATION_FIELDS = ()
442 def abort(message=None):
443 models.Task.abort(message)
450 @contextlib.contextmanager
451 def instrument(self, *args, **kwargs):
456 logger = logging.getLogger()
459 @contextlib.contextmanager
460 def _mock_self_logging(*args, **kwargs):
462 _Ctx.logging_handlers = _mock_self_logging
464 @pytest.fixture(autouse=True)
465 def _setup(self, mocker):
466 self.default_fabric_env = {
467 'host_string': 'test',
469 'key_filename': 'test',
471 self.mock = self.MockFabricApi()
472 mocker.patch('fabric.api', self.mock)
480 operations.run_commands_with_ssh(
484 fabric_env=fabric_env or self.default_fabric_env,
486 hide_output=hide_output)
489 class TestUtilityFunctions(object):
491 def test_paths(self):
493 local_script_path = '/local/script/path.py'
494 paths = ssh_operations._Paths(base_dir=base_dir,
495 local_script_path=local_script_path)
496 assert paths.local_script_path == local_script_path
497 assert paths.remote_ctx_dir == base_dir
498 assert paths.base_script_path == 'path.py'
499 assert paths.remote_ctx_path == '/path/ctx'
500 assert paths.remote_scripts_dir == '/path/scripts'
501 assert paths.remote_work_dir == '/path/work'
502 assert paths.remote_env_script_path.startswith('/path/scripts/env-path.py-')
503 assert paths.remote_script_path.startswith('/path/scripts/path.py-')
505 def test_write_environment_script_file(self):
507 local_script_path = '/local/script/path.py'
508 paths = ssh_operations._Paths(base_dir=base_dir,
509 local_script_path=local_script_path)
511 local_socket_url = 'local_socket_url'
512 remote_socket_url = 'remote_socket_url'
513 env_script_lines = set([l for l in ssh_operations._write_environment_script_file(
514 process={'env': env},
516 local_socket_url=local_socket_url,
517 remote_socket_url=remote_socket_url
518 ).getvalue().split('\n') if l])
519 expected_env_script_lines = set([
520 'export PATH=/path:$PATH',
521 'export PYTHONPATH=/path:$PYTHONPATH',
522 'chmod +x /path/ctx',
523 'chmod +x {0}'.format(paths.remote_script_path),
524 'export CTX_SOCKET_URL={0}'.format(remote_socket_url),
525 'export LOCAL_CTX_SOCKET_URL={0}'.format(local_socket_url),
528 assert env_script_lines == expected_env_script_lines