1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
17 Utilities for instrumenting collections of models in storage.
20 from . import exceptions
# Base class for change-tracking collections stored on a model field.
# Subclasses (_InstrumentedDict, _InstrumentedList) supply _load, _set and _del.
# Writes go through __setitem__, which updates the local copy and then propagates
# the change upward: into the parent model's field at the top level, or into the
# enclosing instrumented collection otherwise.
# NOTE(review): this excerpt is a line-numbered dump with gaps (e.g. 23 -> 35,
# 38 -> 42). The __init__ header, docstring delimiters and several else/return lines
# are not shown, so these comments only describe the visible lines.
23 class _InstrumentedCollection(object):
35 self._field_name = field_name
36 self._is_top_level = is_top_level
37 self._field_cls = field_cls
38 self._load(seq, **kwargs)
# NOTE(review): the header of the abstract member that raises here (original lines
# 39-41) is missing from this excerpt.
42 raise NotImplementedError
# Abstract: populate this collection from `seq`; implemented by subclasses.
44 def _load(self, seq, **kwargs):
46 Instantiates the object from existing seq.
48 :param seq: the original sequence to load from
50 raise NotImplementedError
# Abstract: apply a change to this collection only (no database write).
52 def _set(self, key, value):
54 Sets the changes for the current object (not in the database).
59 raise NotImplementedError
# Abstract: remove `key` from `collection` (called from _set_field).
61 def _del(self, collection, key):
62 raise NotImplementedError
# Wraps dict/list values in _InstrumentedDict/_InstrumentedList. The wrapper is built
# with this collection as its parent and (presumably) is_top_level=False, so writes to
# nested collections propagate back through this one.
# NOTE(review): the bodies of the already-instrumented branch and of the
# non-collection fallback (original lines 72 and 77-79) are not visible here.
64 def _instrument(self, key, value):
66 Instruments any collection to track changes (and ease of access).
71 if isinstance(value, _InstrumentedCollection):
73 elif isinstance(value, dict):
74 instrumentation_cls = _InstrumentedDict
75 elif isinstance(value, list):
76 instrumentation_cls = _InstrumentedList
80 return instrumentation_cls(self._mapi, self, key, self._field_cls, value, False)
# Presumably unwraps values that are instances of the field class. Only the isinstance
# test is visible; the unwrap expression and the pass-through return (original lines
# 83-91) are missing.
82 def _raw_value(self, value):
88 if isinstance(value, self._field_cls):
# Wraps `value` with self._field_cls.wrap(key, value) unless it is already a
# field-class instance. That branch's return (original line 100) is not visible.
92 def _encapsulate_value(self, key, value):
94 Creates a new item class if needed.
99 if isinstance(value, self._field_cls):
101 # If it is not wrapped
102 return self._field_cls.wrap(key, value)
# Applies the change locally via _set.
# At the top level it also writes into the parent model's field, wrapping the value
# with _encapsulate_value only for keys not already present, and then persists the
# parent with self._mapi.update.
# Nested collections instead push themselves into their parent via _set_field.
# NOTE(review): the call opener on line 115 (presumably `self._set_field(`) and the
# `else:` on line 118 are missing from this excerpt.
104 def __setitem__(self, key, value):
106 Updates the values in both the local and the database locations.
111 self._set(key, value)
112 if self._is_top_level:
113 # We are at the top level
114 field = getattr(self._parent, self._field_name)
116 field, key, value if key in field else self._encapsulate_value(key, value))
117 self._mapi.update(self._parent)
119 # We are not at the top level
120 self._set_field(self._parent, self._field_name, self)
# Writes `value` under `key` in an ancestor collection and returns the stored entry.
# An existing field-class entry has its .value updated in place; otherwise the key is
# assigned directly (the `else:` on line 136 is not shown).
# NOTE(review): when the existing entry is also an _InstrumentedCollection,
# _del(collection, key) runs before `collection[key].value = value` reads the same
# key again. Confirm this ordering is intended. Line 131 is also missing here.
122 def _set_field(self, collection, key, value):
124 Enables updating the current change in the ancestors.
126 :param collection: collection to change
127 :param key: key for the specific field
128 :param value: new value
130 if isinstance(value, _InstrumentedCollection):
132 if key in collection and isinstance(collection[key], self._field_cls):
133 if isinstance(collection[key], _InstrumentedCollection):
134 self._del(collection, key)
135 collection[key].value = value
137 collection[key] = value
138 return collection[key]
# NOTE(review): the __deepcopy__ body (original lines 141-142) is not visible in this
# excerpt.
140 def __deepcopy__(self, *args, **kwargs):
# dict-based instrumented collection. It keeps raw (unwrapped) values locally and
# instruments nested values when they are read.
# NOTE(review): `iteritems()` (used in _load and update) exists only on Python 2
# dicts; `items()` works on both Python 2 and 3.
144 class _InstrumentedDict(_InstrumentedCollection, dict):
# Initializes the underlying dict from `dict_` (or an empty dict), storing raw values.
# The call opener (lines 147-148, presumably `dict.__init__(self,`) is not shown.
146 def _load(self, dict_=None, **kwargs):
149 tuple((key, self._raw_value(value)) for key, value in (dict_ or {}).iteritems()),
# update() walks `dict_` and then `kwargs`. The loop bodies (lines 155/157) are not
# shown; presumably `self[key] = value`, i.e. routed through __setitem__.
# NOTE(review): line 153 is missing. Unless it defaults `dict_`, calling update() with
# no positional argument fails on `None.iteritems()`.
152 def update(self, dict_=None, **kwargs):
154 for key, value in dict_.iteritems():
156 for key, value in kwargs.iteritems():
159 def __getitem__(self, key):
160 return self._instrument(key, dict.__getitem__(self, key))
# Local-only write: stores the raw value via dict.__setitem__, bypassing the
# propagating __setitem__.
162 def _set(self, key, value):
163 dict.__setitem__(self, key, self._raw_value(value))
# NOTE(review): original lines 164-168 and the _del body (lines 170-171) are not
# visible here.
169 def _del(self, collection, key):
# list-based instrumented collection.
173 class _InstrumentedList(_InstrumentedCollection, list):
# Copies `list_` (or an empty list) into the underlying list.
# NOTE(review): unlike _InstrumentedDict._load, items are not passed through
# _raw_value here. Confirm whether list items can be field-class instances.
175 def _load(self, list_=None, **kwargs):
176 list.__init__(self, list(item for item in list_ or []))
178 def append(self, value):
179 self.insert(len(self), value)
# Inserts the raw value locally.
# At the top level it also inserts the encapsulated value into the parent model's
# field; otherwise it reassigns this list into its parent collection, which goes
# through the parent's item assignment.
# NOTE(review): line 186 (presumably `else:`) is missing. Unlike
# _InstrumentedCollection.__setitem__, no self._mapi.update call is visible here.
181 def insert(self, index, value):
182 list.insert(self, index, self._raw_value(value))
183 if self._is_top_level:
184 field = getattr(self._parent, self._field_name)
185 field.insert(index, self._encapsulate_value(index, value))
187 self._parent[self._field_name] = self
189 def __getitem__(self, key):
190 return self._instrument(key, list.__getitem__(self, key))
# Local-only write.
# NOTE(review): unlike _InstrumentedDict._set, `value` is stored without _raw_value.
192 def _set(self, key, value):
193 list.__setitem__(self, key, value)
# NOTE(review): the _del body (original lines 196-202) is not visible here.
195 def _del(self, collection, key):
# Proxy around a model. Attribute access is delegated to the wrapped object, and the
# result goes through _wrap so nested models and collections are proxied too.
203 class _WrappedBase(object):
205 def __init__(self, wrapped, instrumentation, instrumentation_kwargs=None):
207 :param wrapped: model to be instrumented
208 :param instrumentation: instrumentation dict
209 :param instrumentation_kwargs: arguments for instrumentation class
211 self._wrapped = wrapped
212 self._instrumentation = instrumentation
213 self._instrumentation_kwargs = instrumentation_kwargs or {}
# Wraps `value`:
# - an object whose exact class equals some instrumentation entry's `class_` becomes
#   an instrumented model;
# - other objects with a `metadata` attribute, and dicts/lists, become plain wrapped
#   models.
# NOTE(review): the match uses exact class identity, while instrument() uses
# isinstance, so subclasses are handled differently. The lookup set is also rebuilt on
# every call. The fallback return for any other value (line 223) is not visible here.
215 def _wrap(self, value):
216 if value.__class__ in set(class_.class_ for class_ in self._instrumentation):
217 return _create_instrumented_model(
218 value, instrumentation=self._instrumentation, **self._instrumentation_kwargs)
219 # Check that the value is a SQLAlchemy model (it should have metadata) or a collection
220 elif hasattr(value, 'metadata') or isinstance(value, (dict, list)):
221 return _create_wrapped_model(
222 value, instrumentation=self._instrumentation, **self._instrumentation_kwargs)
# Delegates attribute lookup to the wrapped object and wraps the result.
# NOTE(review): the result of the __getattribute__ call on line 229 is not returned.
# __getattr__ only runs after normal lookup has failed, so that call presumably raises
# AttributeError anyway. If `_wrapped` is not yet set, `hasattr(self, '_wrapped')`
# re-enters __getattr__ and can recurse.
225 def __getattr__(self, item):
226 if hasattr(self, '_wrapped'):
227 return self._wrap(getattr(self._wrapped, item))
229 super(_WrappedBase, self).__getattribute__(item)
# Wrapped model whose dict/list fields, as named by the instrumentation, are replaced
# by instrumented collections. Top-level writes to those collections call mapi.update
# on the parent model (presumably this `mapi` is passed to them on the missing lines
# 279/281).
232 class _InstrumentedModel(_WrappedBase):
# Stores `mapi` in instrumentation_kwargs (so nested values are wrapped with the same
# mapi), then instruments the wrapped model's fields.
# NOTE(review): lines 244-245 (the rest of the super() call, presumably forwarding
# *args/**kwargs) are not visible here.
234 def __init__(self, mapi, *args, **kwargs):
238 :param mapi: MAPI for the wrapped model
239 :param wrapped: model to be instrumented
240 :param instrumentation: instrumentation dict
241 :param instrumentation_kwargs: arguments for instrumentation class
243 super(_InstrumentedModel, self).__init__(instrumentation_kwargs=dict(mapi=mapi),
246 self._apply_instrumentation()
# Handles each instrumentation entry whose parent class matches the wrapped model:
# - keeps the original field value on this wrapper as `_<field_name>`;
# - exposes an _InstrumentedDict/_InstrumentedList under `<field_name>`.
# Any other field type raises exceptions.StorageError.
# NOTE(review): line 252 (presumably `continue`) and lines 273-276, 279, 281 are
# missing.
248 def _apply_instrumentation(self):
249 for field in self._instrumentation:
250 if not issubclass(type(self._wrapped), field.parent.class_):
251 # Do not apply if this field is not for our class
254 field_name = field.key
255 field_cls = field.mapper.class_
# NOTE(review): `field` is rebound here from the instrumentation entry to the model's
# value for that attribute.
257 field = getattr(self._wrapped, field_name)
259 # Preserve the original field, e.g. original "attributes" would be located under
261 setattr(self, '_{0}'.format(field_name), field)
263 # Set instrumented value
264 if isinstance(field, dict):
265 instrumentation_cls = _InstrumentedDict
266 elif isinstance(field, list):
267 instrumentation_cls = _InstrumentedList
269 # TODO: raise proper error
270 raise exceptions.StorageError(
271 "ARIA supports instrumentation for dict and list. Field {field} of the "
272 "class `{model}` is of type `{type}`.".format(
277 instrumented_class = instrumentation_cls(seq=field,
278 parent=self._wrapped,
280 field_name=field_name,
282 setattr(self, field_name, instrumented_class)
# Wrapped model/collection that supports item access and iteration; every element it
# hands out is passed through _wrap first.
# NOTE(review): the header of the generator on lines 291-292 (line 290, presumably
# `def __iter__(self):`) is missing from this excerpt.
285 class _WrappedModel(_WrappedBase):
287 def __getitem__(self, item):
288 return self._wrap(self._wrapped[item])
291 for item in self._wrapped.__iter__():
292 yield self._wrap(item)
def _create_instrumented_model(original_model, mapi, instrumentation):
    """
    Wraps a model in a freshly generated ``_InstrumentedModel`` subclass.

    The generated subclass is named ``Instrumented<ModelClassName>``.

    :param original_model: model instance to wrap
    :param mapi: model API handed to the instrumented wrapper
    :param instrumentation: instrumentation passed to the wrapper
    :return: an instance of the generated subclass wrapping ``original_model``
    """
    generated_name = 'Instrumented{0}'.format(original_model.__class__.__name__)
    generated_cls = type(generated_name, (_InstrumentedModel,), {})
    return generated_cls(wrapped=original_model, instrumentation=instrumentation, mapi=mapi)
# Wraps `original_model` in a dynamically created class named `Wrapped<ClassName>`.
# `mapi` is passed inside instrumentation_kwargs, unlike _create_instrumented_model,
# which passes it as a direct keyword.
# NOTE(review): line 303 (the bases tuple, presumably `(_WrappedModel,),`) is missing
# from this excerpt.
301 def _create_wrapped_model(original_model, mapi, instrumentation):
302 return type('Wrapped{0}'.format(original_model.__class__.__name__),
304 {})(wrapped=original_model,
305 instrumentation=instrumentation,
306 instrumentation_kwargs=dict(mapi=mapi))
def instrument(instrumentation, original_model, mapi):
    """
    Wraps a model for instrumented access.

    If ``original_model`` is an instance of the ``class_`` of any entry in
    ``instrumentation``, it is wrapped as an instrumented model. Otherwise it is
    wrapped as a plain wrapped model.

    :param instrumentation: instrumentation entries, each exposing ``class_``
    :param original_model: model instance to wrap
    :param mapi: model API passed through to the wrapper
    :return: the wrapper produced by the selected factory
    """
    # any() stops at the first match, just like an early return inside a loop.
    is_instrumented = any(
        isinstance(original_model, entry.class_) for entry in instrumentation)
    factory = _create_instrumented_model if is_instrumented else _create_wrapped_model
    return factory(original_model, mapi, instrumentation)