1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
import logging
import uuid
from contextlib import contextmanager

import aria
from aria.modeling import models
from aria.orchestrator.context.common import BaseContext
class MockContext(object):
    """Minimal stand-in for an operation context, used by executor tests.

    Wraps a model storage and a single :class:`MockTask`. Any attribute that
    is not defined explicitly on the mock resolves to ``None`` through
    :meth:`__getattr__`.
    """

    INSTRUMENTATION_FIELDS = BaseContext.INSTRUMENTATION_FIELDS

    def __init__(self, storage, task_kwargs=None):
        """
        :param storage: model storage passed to the mock task; its
            ``serialization_dict`` is reused by :attr:`serialization_dict`.
        :param task_kwargs: keyword arguments forwarded to :class:`MockTask`;
            ``None`` is treated as an empty dict.
        """
        self.logger = logging.getLogger('mock_logger')
        self._task_kwargs = task_kwargs or {}
        self._storage = storage
        # Unpack the normalized dict: ``**task_kwargs`` on the raw argument
        # raises TypeError when the default ``None`` is used.
        self.task = MockTask(storage, **self._task_kwargs)
        self.states = []
        self.exception = None

    @property
    def serialization_dict(self):
        """Return the data needed to rebuild this context.

        ``context_cls`` is this class; ``context`` holds the keyword arguments
        accepted by :meth:`instantiate_from_dict`.
        """
        return {
            'context_cls': self.__class__,
            'context': {
                'storage_kwargs': self._storage.serialization_dict,
                'task_kwargs': self._task_kwargs
            }
        }

    def __getattr__(self, item):
        """Resolve any attribute that is not defined on the mock to ``None``."""
        return None

    def close(self):
        """No-op close hook."""
        pass

    @property
    def model(self):
        """The model storage this context was created with."""
        return self._storage

    @classmethod
    def instantiate_from_dict(cls, storage_kwargs=None, task_kwargs=None):
        """Build a new context from serialized storage and task kwargs.

        :param storage_kwargs: keyword arguments for
            ``aria.application_model_storage`` (``None`` means no arguments).
        :param task_kwargs: keyword arguments forwarded to :class:`MockTask`.
        """
        return cls(storage=aria.application_model_storage(**(storage_kwargs or {})),
                   task_kwargs=(task_kwargs or {}))

    @property
    @contextmanager
    def persist_changes(self):
        """Context manager (``with ctx.persist_changes:``) that persists nothing."""
        yield
class MockActor(object):
    """Placeholder actor object with a fixed ``name``."""

    def __init__(self):
        # Fixed name so anything reading ``task.actor.name`` gets a string.
        self.name = 'actor_name'
class MockTask(object):
    """Lightweight stand-in for a ``models.Task`` used by executor tests."""

    INFINITE_RETRIES = models.Task.INFINITE_RETRIES

    def __init__(self, model, function, arguments=None, plugin_fk=None):
        """
        :param model: model storage; :attr:`plugin` looks plugins up through
            ``model.plugin``.
        :param function: the function to run; also used as the task ``name``.
        :param arguments: arguments for ``function``; ``None`` becomes ``{}``.
        :param plugin_fk: key passed to ``model.plugin.get`` to find the plugin,
            or a falsy value when there is no plugin.
        """
        self.function = self.name = function
        self.plugin_fk = plugin_fk
        self.arguments = arguments or {}
        self.states = []
        self.exception = None
        self.id = str(uuid.uuid4())
        self.logger = logging.getLogger()
        self.attempts_count = 1
        self.max_attempts = 1
        self.ignore_failure = False
        self.interface_name = 'interface_name'
        self.operation_name = 'operation_name'
        self.actor = MockActor()
        self.node = self.actor
        # Kept because ``plugin`` resolves ``plugin_fk`` through it.
        self.model = model

        # Expose each entry of models.Task.STATES as an attribute named after
        # its upper-cased value, whose value is the state itself.
        for state in models.Task.STATES:
            setattr(self, state.upper(), state)

    @property
    def plugin(self):
        """The plugin referenced by ``plugin_fk``, or ``None`` when it is falsy."""
        return self.model.plugin.get(self.plugin_fk) if self.plugin_fk else None