Update project maturity status
[multicloud/azure.git] / azure / aria / aria-extension-cloudify / src / aria / test_ssh.py
1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements.  See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License.  You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 import contextlib
17 import json
18 import logging
19 import os
20
21 import pytest
22
23 import fabric.api
24 from fabric.contrib import files
25 from fabric import context_managers
26
27 from aria.modeling import models
28 from aria.orchestrator import events
29 from aria.orchestrator import workflow
30 from aria.orchestrator.workflows import api
31 from aria.orchestrator.workflows.executor import process
32 from aria.orchestrator.workflows.core import engine, graph_compiler
33 from aria.orchestrator.workflows.exceptions import ExecutorException
34 from aria.orchestrator.exceptions import TaskAbortException, TaskRetryException
35 from aria.orchestrator.execution_plugin import operations
36 from aria.orchestrator.execution_plugin import constants
37 from aria.orchestrator.execution_plugin.exceptions import ProcessException, TaskException
38 from aria.orchestrator.execution_plugin.ssh import operations as ssh_operations
39
40 from tests import mock, storage, resources
41 from tests.orchestrator.workflows.helpers import events_collector
42
# Alternate remote base dir used by tests that override 'base_dir' in the
# process config (see test_run_script_process_config).
_CUSTOM_BASE_DIR = '/tmp/new-aria-ctx'

import tests
# Private key used both by the mock SSH server fixture below and as the
# client credential in _FABRIC_ENV.
KEY_FILENAME = os.path.join(tests.ROOT_DIR, 'tests/resources/keys/test')

# Fabric settings shared by every SSH connection made in these tests.
_FABRIC_ENV = {
    'disable_known_hosts': True,
    'user': 'test',
    'key_filename': KEY_FILENAME
}
53
54
import mockssh
@pytest.fixture(scope='session')
def server():
    """Run an in-process mock SSH server for the whole test session.

    The server authenticates the 'test' user with the same private key that
    _FABRIC_ENV hands to the fabric client.
    """
    users = {'test': KEY_FILENAME}
    with mockssh.Server(users) as ssh_server:
        yield ssh_server
60
61
# Previously guarded by a TRAVIS-only skipif; the session-scoped mockssh
# `server` fixture now provides an in-process SSH server, so the suite can
# run anywhere.
class TestWithActualSSHServer(object):
    """End-to-end tests for the ssh execution plugin against an SSH server.

    Each test builds a one-interface workflow that runs either
    ``run_script_with_ssh`` or ``run_commands_with_ssh`` through the process
    executor, then asserts on the node attributes the remote script wrote
    back through ``ctx``.
    """

    def test_run_script_basic(self):
        expected_attribute_value = 'some_value'
        props = self._execute(env={'test_value': expected_attribute_value})
        assert props['test_value'].value == expected_attribute_value

    @pytest.mark.skip(reason='sudo privileges are required')
    def test_run_script_as_sudo(self):
        self._execute(use_sudo=True)
        with self._ssh_env():
            assert files.exists('/opt/test_dir')
            fabric.api.sudo('rm -rf /opt/test_dir')

    def test_run_script_default_base_dir(self):
        props = self._execute()
        assert props['work_dir'].value == '{0}/work'.format(constants.DEFAULT_BASE_DIR)

    @pytest.mark.skip(reason='Re-enable once output from process executor can be captured')
    @pytest.mark.parametrize('hide_groups', [[], ['everything']])
    def test_run_script_with_hide(self, hide_groups):
        self._execute(hide_output=hide_groups)
        output = 'TODO'
        expected_log_message = ('[localhost] run: source {0}/scripts/'
                                .format(constants.DEFAULT_BASE_DIR))
        if hide_groups:
            assert expected_log_message not in output
        else:
            assert expected_log_message in output

    def test_run_script_process_config(self):
        expected_env_value = 'test_value_env'
        expected_arg1_value = 'test_value_arg1'
        expected_arg2_value = 'test_value_arg2'
        expected_cwd = '/tmp'
        expected_base_dir = _CUSTOM_BASE_DIR
        props = self._execute(
            env={'test_value_env': expected_env_value},
            process={
                'args': [expected_arg1_value, expected_arg2_value],
                'cwd': expected_cwd,
                'base_dir': expected_base_dir
            })
        assert props['env_value'].value == expected_env_value
        assert len(props['bash_version'].value) > 0
        assert props['arg1_value'].value == expected_arg1_value
        assert props['arg2_value'].value == expected_arg2_value
        assert props['cwd'].value == expected_cwd
        assert props['ctx_path'].value == '{0}/ctx'.format(expected_base_dir)

    def test_run_script_command_prefix(self):
        props = self._execute(process={'command_prefix': 'bash -i'})
        assert 'i' in props['dollar_dash'].value

    def test_run_script_reuse_existing_ctx(self):
        # Two operations in one workflow share the already-deployed ctx.
        expected_test_value_1 = 'test_value_1'
        expected_test_value_2 = 'test_value_2'
        props = self._execute(
            test_operations=['{0}_1'.format(self.test_name),
                             '{0}_2'.format(self.test_name)],
            env={'test_value1': expected_test_value_1,
                 'test_value2': expected_test_value_2})
        assert props['test_value1'].value == expected_test_value_1
        assert props['test_value2'].value == expected_test_value_2

    def test_run_script_download_resource_plain(self, tmpdir):
        resource = tmpdir.join('resource')
        resource.write('content')
        self._upload(str(resource), 'test_resource')
        props = self._execute()
        assert props['test_value'].value == 'content'

    def test_run_script_download_resource_and_render(self, tmpdir):
        # The downloaded resource is a jinja template rendered against ctx.
        resource = tmpdir.join('resource')
        resource.write('{{ctx.service.name}}')
        self._upload(str(resource), 'test_resource')
        props = self._execute()
        assert props['test_value'].value == self._workflow_context.service.name

    @pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
    def test_run_script_inputs_as_env_variables_no_override(self, value):
        props = self._execute(custom_input=value)
        return_value = props['test_value'].value
        # Non-string inputs round-trip through the env as JSON.
        expected = return_value if isinstance(value, basestring) else json.loads(return_value)
        assert value == expected

    @pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
    def test_run_script_inputs_as_env_variables_process_env_override(self, value):
        props = self._execute(custom_input='custom-input-value',
                              env={'custom_env_var': value})
        return_value = props['test_value'].value
        expected = return_value if isinstance(value, basestring) else json.loads(return_value)
        assert value == expected

    def test_run_script_error_in_script(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskException)

    def test_run_script_abort_immediate(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskAbortException)
        assert exception.message == 'abort-message'

    def test_run_script_retry(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskRetryException)
        assert exception.message == 'retry-message'

    def test_run_script_abort_error_ignored_by_script(self):
        exception = self._execute_and_get_task_exception()
        assert isinstance(exception, TaskAbortException)
        assert exception.message == 'abort-message'

    def test_run_commands(self):
        temp_file_path = '/tmp/very_temporary_file'
        with self._ssh_env():
            if files.exists(temp_file_path):
                fabric.api.run('rm {0}'.format(temp_file_path))
        self._execute(commands=['touch {0}'.format(temp_file_path)])
        with self._ssh_env():
            assert files.exists(temp_file_path)
            fabric.api.run('rm {0}'.format(temp_file_path))

    @pytest.fixture(autouse=True)
    def _setup(self, request, workflow_context, executor, capfd, server):
        """Wire per-test state and wipe leftover remote work directories."""
        self._workflow_context = workflow_context
        self._executor = executor
        self._capfd = capfd
        # Keep the session-wide mock SSH server so helpers such as
        # _ssh_env() can reach it without every caller passing it in.
        # (Previously _ssh_env required a `server` argument but several
        # tests called it with none, raising TypeError.)
        self._server = server
        self.test_name = request.node.originalname or request.node.name
        with self._ssh_env():
            for directory in [constants.DEFAULT_BASE_DIR, _CUSTOM_BASE_DIR]:
                if files.exists(directory):
                    fabric.api.run('rm -rf {0}'.format(directory))

    @contextlib.contextmanager
    def _ssh_env(self, server=None):
        """Apply the fabric settings needed to talk to the test SSH server.

        :param server: server to target; defaults to the session fixture
         stored by ``_setup``. The parameter is kept so callers may still
         pass a server explicitly.
        """
        if server is None:
            server = self._server
        # capfd must be disabled while fabric talks to the server, otherwise
        # pytest's capture interferes with fabric's terminal handling.
        with self._capfd.disabled():
            with context_managers.settings(fabric.api.hide('everything'),
                                           host_string='localhost:{0}'.format(server.port),
                                           **_FABRIC_ENV):
                yield

    def _execute(self,
                 env=None,
                 use_sudo=False,
                 hide_output=None,
                 process=None,
                 custom_input='',
                 test_operations=None,
                 commands=None):
        """Build and run a one-operation workflow; return node attributes.

        Creates one task per entry in ``test_operations`` (defaulting to the
        current test's name, which selects the matching case in
        test_ssh.sh), executes the graph with the process executor and
        returns the dependency node's attributes for assertions.
        """
        process = process or {}
        if env:
            process.setdefault('env', {}).update(env)

        test_operations = test_operations or [self.test_name]

        local_script_path = os.path.join(resources.DIR, 'scripts', 'test_ssh.sh')
        script_path = os.path.basename(local_script_path)
        self._upload(local_script_path, script_path)

        if commands:
            operation = operations.run_commands_with_ssh
        else:
            operation = operations.run_script_with_ssh

        node = self._workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
        arguments = {
            'script_path': script_path,
            'fabric_env': _FABRIC_ENV,
            'process': process,
            'use_sudo': use_sudo,
            'custom_env_var': custom_input,
            'test_operation': '',
        }
        if hide_output:
            arguments['hide_output'] = hide_output
        if commands:
            arguments['commands'] = commands
        interface = mock.models.create_interface(
            node.service,
            'test',
            'op',
            operation_kwargs=dict(
                function='{0}.{1}'.format(
                    operations.__name__,
                    operation.__name__),
                arguments=arguments)
        )
        node.interfaces[interface.name] = interface

        @workflow
        def mock_workflow(ctx, graph):
            ops = []
            for test_operation in test_operations:
                # Each task gets its own copy so 'test_operation' differs.
                op_arguments = arguments.copy()
                op_arguments['test_operation'] = test_operation
                ops.append(api.task.OperationTask(
                    node,
                    interface_name='test',
                    operation_name='op',
                    arguments=op_arguments))

            graph.sequence(*ops)
            return graph
        tasks_graph = mock_workflow(ctx=self._workflow_context)  # pylint: disable=no-value-for-parameter
        graph_compiler.GraphCompiler(
            self._workflow_context, self._executor.__class__).compile(tasks_graph)
        eng = engine.Engine({self._executor.__class__: self._executor})
        eng.execute(self._workflow_context)
        return self._workflow_context.model.node.get_by_name(
            mock.models.DEPENDENCY_NODE_NAME).attributes

    def _execute_and_get_task_exception(self, *args, **kwargs):
        """Run _execute expecting failure; return the task's exception."""
        signal = events.on_failure_task_signal
        with events_collector(signal) as collected:
            with pytest.raises(ExecutorException):
                self._execute(*args, **kwargs)
        return collected[signal][0]['kwargs']['exception']

    def _upload(self, source, path):
        # Store a file in the service's resource storage so the remote
        # script can ctx-download it.
        self._workflow_context.resource.service.upload(
            entry_id=str(self._workflow_context.service.id),
            source=source,
            path=path)

    @pytest.fixture
    def executor(self):
        """Process executor, closed after the test to reap child processes."""
        result = process.ProcessExecutor()
        try:
            yield result
        finally:
            result.close()

    @pytest.fixture
    def workflow_context(self, tmpdir):
        """Simple mock workflow context backed by a tmpdir sqlite storage."""
        workflow_context = mock.context.simple(str(tmpdir))
        workflow_context.states = []
        workflow_context.exception = None
        yield workflow_context
        storage.release_sqlite_storage(workflow_context.model)
304
305
class TestFabricEnvHideGroupsAndRunCommands(object):
    """Unit tests for fabric-env merging, output hiding and run_commands_with_ssh.

    ``fabric.api`` is replaced with :class:`MockFabricApi`, which records the
    merged settings and the commands it was asked to run instead of opening
    real SSH connections.
    """

    def test_fabric_env_default_override(self):
        """Explicit fabric_env values take precedence over FABRIC_ENV_DEFAULTS."""
        # first sanity for no override
        self._run()
        assert self.mock.settings_merged['timeout'] == constants.FABRIC_ENV_DEFAULTS['timeout']
        # now override
        invocation_fabric_env = self.default_fabric_env.copy()
        timeout = 1000000
        invocation_fabric_env['timeout'] = timeout
        self._run(fabric_env=invocation_fabric_env)
        assert self.mock.settings_merged['timeout'] == timeout

    def test_implicit_host_string(self, mocker):
        """Without host_string, the actor host's host_address is used."""
        expected_host_address = '1.1.1.1'
        mocker.patch.object(self._Ctx.task.actor, 'host')
        mocker.patch.object(self._Ctx.task.actor.host, 'host_address', expected_host_address)
        fabric_env = self.default_fabric_env.copy()
        del fabric_env['host_string']
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['host_string'] == expected_host_address

    def test_explicit_host_string(self):
        """An explicit host_string in fabric_env is passed through untouched."""
        fabric_env = self.default_fabric_env.copy()
        host_string = 'explicit_host_string'
        fabric_env['host_string'] = host_string
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['host_string'] == host_string

    def test_override_warn_only(self):
        """warn_only defaults to True but can be overridden by the caller."""
        fabric_env = self.default_fabric_env.copy()
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['warn_only'] is True
        fabric_env = self.default_fabric_env.copy()
        fabric_env['warn_only'] = False
        self._run(fabric_env=fabric_env)
        assert self.mock.settings_merged['warn_only'] is False

    def test_missing_host_string(self):
        """Missing host_string (and no actor host) aborts the task."""
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['host_string']
            self._run(fabric_env=fabric_env)
        assert '`host_string` not supplied' in str(exc_ctx.value)

    def test_missing_user(self):
        """Missing user aborts the task with a descriptive message."""
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['user']
            self._run(fabric_env=fabric_env)
        assert '`user` not supplied' in str(exc_ctx.value)

    def test_missing_key_or_password(self):
        """At least one credential (key_filename or password) is required."""
        with pytest.raises(TaskAbortException) as exc_ctx:
            fabric_env = self.default_fabric_env.copy()
            del fabric_env['key_filename']
            self._run(fabric_env=fabric_env)
        assert 'Access credentials not supplied' in str(exc_ctx.value)

    def test_hide_in_settings_and_non_viable_groups(self):
        """Valid hide groups reach settings(); invalid ones abort the task."""
        groups = ('running', 'stdout')
        self._run(hide_output=groups)
        assert set(self.mock.settings_merged['hide_output']) == set(groups)
        with pytest.raises(TaskAbortException) as exc_ctx:
            self._run(hide_output=('running', 'bla'))
        assert '`hide_output` must be a subset of' in str(exc_ctx.value)

    def test_run_commands(self):
        """Commands run in order via run() or sudo() depending on use_sudo."""
        def test(use_sudo):
            commands = ['command1', 'command2']
            self._run(
                commands=commands,
                use_sudo=use_sudo)
            # The full default fabric env must be part of the merged settings.
            assert all(item in self.mock.settings_merged.items() for
                       item in self.default_fabric_env.items())
            assert self.mock.settings_merged['warn_only'] is True
            assert self.mock.settings_merged['use_sudo'] == use_sudo
            assert self.mock.commands == commands
            # Reset the shared mock between the two sub-runs.
            self.mock.settings_merged = {}
            self.mock.commands = []
        test(use_sudo=False)
        test(use_sudo=True)

    def test_failed_command(self):
        """A failed command surfaces as ProcessException with full details."""
        with pytest.raises(ProcessException) as exc_ctx:
            self._run(commands=['fail'])
        exception = exc_ctx.value
        assert exception.stdout == self.MockCommandResult.stdout
        assert exception.stderr == self.MockCommandResult.stderr
        assert exception.command == self.MockCommandResult.command
        assert exception.exit_code == self.MockCommandResult.return_code

    class MockCommandResult(object):
        """Canned fabric command result; class attrs double as expectations."""
        stdout = 'mock_stdout'
        stderr = 'mock_stderr'
        command = 'mock_command'
        return_code = 1

        def __init__(self, failed):
            self.failed = failed

    class MockFabricApi(object):
        """Drop-in for fabric.api recording settings and executed commands."""

        def __init__(self):
            self.commands = []
            self.settings_merged = {}

        @contextlib.contextmanager
        def settings(self, *args, **kwargs):
            # Record everything passed in; a positional arg is the hide()
            # result, i.e. the groups to hide.
            self.settings_merged.update(kwargs)
            if args:
                groups = args[0]
                self.settings_merged.update({'hide_output': groups})
            yield

        def run(self, command):
            # 'fail' is the magic command name that produces a failed result.
            self.commands.append(command)
            self.settings_merged['use_sudo'] = False
            return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')

        def sudo(self, command):
            self.commands.append(command)
            self.settings_merged['use_sudo'] = True
            return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')

        def hide(self, *groups):
            # Real fabric returns context managers; the tuple is enough here
            # because settings() above just records it.
            return groups

        def exists(self, *args, **kwargs):
            # No test should reach remote file checks through this mock.
            raise RuntimeError

    class _Ctx(object):
        """Minimal operation-context stand-in for run_commands_with_ssh."""
        INSTRUMENTATION_FIELDS = ()

        class Task(object):
            @staticmethod
            def abort(message=None):
                models.Task.abort(message)
            actor = None

        class Actor(object):
            host = None

        class Model(object):
            @contextlib.contextmanager
            def instrument(self, *args, **kwargs):
                yield
        # Wire the nested stand-ins together the way a real ctx exposes them.
        task = Task
        task.actor = Actor
        model = Model()
        logger = logging.getLogger()

    @staticmethod
    @contextlib.contextmanager
    def _mock_self_logging(*args, **kwargs):
        # No-op replacement for the ctx logging-handlers context manager.
        yield
    _Ctx.logging_handlers = _mock_self_logging

    @pytest.fixture(autouse=True)
    def _setup(self, mocker):
        # Baseline env with every required key; tests copy it and delete
        # entries to trigger the validation paths.
        self.default_fabric_env = {
            'host_string': 'test',
            'user': 'test',
            'key_filename': 'test',
        }
        self.mock = self.MockFabricApi()
        mocker.patch('fabric.api', self.mock)

    def _run(self,
             commands=(),
             fabric_env=None,
             process=None,
             use_sudo=False,
             hide_output=None):
        """Invoke run_commands_with_ssh against the mocked fabric API."""
        operations.run_commands_with_ssh(
            ctx=self._Ctx,
            commands=commands,
            process=process,
            fabric_env=fabric_env or self.default_fabric_env,
            use_sudo=use_sudo,
            hide_output=hide_output)
487
488
class TestUtilityFunctions(object):
    """Unit tests for private helpers of the ssh operations module."""

    def test_paths(self):
        """_Paths derives every remote location from base dir + script name."""
        paths = ssh_operations._Paths(base_dir='/path',
                                      local_script_path='/local/script/path.py')
        assert paths.local_script_path == '/local/script/path.py'
        assert paths.remote_ctx_dir == '/path'
        assert paths.base_script_path == 'path.py'
        assert paths.remote_ctx_path == '/path/ctx'
        assert paths.remote_scripts_dir == '/path/scripts'
        assert paths.remote_work_dir == '/path/work'
        # Remote script paths carry a random suffix, so only check prefixes.
        assert paths.remote_env_script_path.startswith('/path/scripts/env-path.py-')
        assert paths.remote_script_path.startswith('/path/scripts/path.py-')

    def test_write_environment_script_file(self):
        """The generated env script contains exactly the expected exports."""
        paths = ssh_operations._Paths(base_dir='/path',
                                      local_script_path='/local/script/path.py')
        local_url = 'local_socket_url'
        remote_url = 'remote_socket_url'
        script_file = ssh_operations._write_environment_script_file(
            process={'env': {'one': "'1'"}},
            paths=paths,
            local_socket_url=local_url,
            remote_socket_url=remote_url)
        # Compare as sets of non-empty lines; ordering is irrelevant.
        actual_lines = set(
            line for line in script_file.getvalue().split('\n') if line)
        expected_lines = {
            'export PATH=/path:$PATH',
            'export PYTHONPATH=/path:$PYTHONPATH',
            'chmod +x /path/ctx',
            'chmod +x {0}'.format(paths.remote_script_path),
            'export CTX_SOCKET_URL={0}'.format(remote_url),
            'export LOCAL_CTX_SOCKET_URL={0}'.format(local_url),
            'export one=\'1\''
        }
        assert actual_lines == expected_lines
527         ])
528         assert env_script_lines == expected_env_script_lines