1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
17 Allows JSON-serializable collections to be used as SQLAlchemy column types.
21 from collections import namedtuple
23 from sqlalchemy import (
28 from sqlalchemy.ext import mutable
29 from ruamel import yaml
31 from . import exceptions
class _MutableType(TypeDecorator):
    """
    Base SQLAlchemy column type that stores a JSON-serializable collection as text.

    Values are encoded with ``json.dumps`` when bound to a statement and decoded with
    ``json.loads`` when read back; ``None`` passes through untouched in both directions.
    Concrete subclasses must override ``python_type``.
    """

    impl = VARCHAR

    @property
    def python_type(self):
        # Abstract here; subclasses report the collection type they round-trip.
        raise NotImplementedError

    def process_literal_param(self, value, dialect):
        # No literal conversion is performed; this always yields None.
        return None

    def process_bind_param(self, value, dialect):
        """Encode ``value`` as a JSON string for storage (``None`` is left as-is)."""
        return value if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        """Decode the stored JSON string back into Python (``None`` is left as-is)."""
        return value if value is None else json.loads(value)
class Dict(_MutableType):
    """
    JSON-serializable dict type for SQLAlchemy columns.

    Column values are stored as JSON text and surface in Python as ``dict`` objects.
    """

    @property
    def python_type(self):
        # The Python collection type this column round-trips.
        return dict
class List(_MutableType):
    """
    JSON-serializable list type for SQLAlchemy columns.

    Column values are stored as JSON text and surface in Python as ``list`` objects.
    """

    @property
    def python_type(self):
        # The Python collection type this column round-trips.
        return list
76 class _StrictDictMixin(object):
79 def coerce(cls, key, value):
81 Convert plain dictionaries to MutableDict.
84 if not isinstance(value, cls):
85 if isinstance(value, dict):
86 for k, v in value.items():
87 cls._assert_strict_key(k)
88 cls._assert_strict_value(v)
90 return mutable.MutableDict.coerce(key, value)
93 except ValueError as e:
94 raise exceptions.ValueFormatException('could not coerce to MutableDict', cause=e)
96 def __setitem__(self, key, value):
97 self._assert_strict_key(key)
98 self._assert_strict_value(value)
99 super(_StrictDictMixin, self).__setitem__(key, value)
101 def setdefault(self, key, value):
102 self._assert_strict_key(key)
103 self._assert_strict_value(value)
104 super(_StrictDictMixin, self).setdefault(key, value)
106 def update(self, *args, **kwargs):
107 for k, v in kwargs.items():
108 self._assert_strict_key(k)
109 self._assert_strict_value(v)
110 super(_StrictDictMixin, self).update(*args, **kwargs)
113 def _assert_strict_key(cls, key):
114 if cls._key_cls is not None and not isinstance(key, cls._key_cls):
115 raise exceptions.ValueFormatException('key type was set strictly to {0}, but was {1}'
116 .format(cls._key_cls, type(key)))
119 def _assert_strict_value(cls, value):
120 if cls._value_cls is not None and not isinstance(value, cls._value_cls):
121 raise exceptions.ValueFormatException('value type was set strictly to {0}, but was {1}'
122 .format(cls._value_cls, type(value)))
class _MutableDict(mutable.MutableDict):
    """
    Enables tracking for dict values.

    Thin subclass of SQLAlchemy's ``MutableDict`` whose only change is to report coercion
    failures as ``ValueFormatException``.
    """

    @classmethod
    def coerce(cls, key, value):
        """
        Convert plain dictionaries to MutableDict.

        :raises exceptions.ValueFormatException: wrapping the ``ValueError`` raised by
         ``mutable.MutableDict.coerce``
        """
        try:
            coerced = mutable.MutableDict.coerce(key, value)
        except ValueError as err:
            raise exceptions.ValueFormatException('could not coerce value', cause=err)
        return coerced
141 class _StrictListMixin(object):
144 def coerce(cls, key, value):
145 "Convert plain dictionaries to MutableDict."
147 if not isinstance(value, cls):
148 if isinstance(value, list):
150 cls._assert_item(item)
152 return mutable.MutableList.coerce(key, value)
155 except ValueError as e:
156 raise exceptions.ValueFormatException('could not coerce to MutableDict', cause=e)
158 def __setitem__(self, index, value):
160 Detect list set events and emit change events.
162 self._assert_item(value)
163 super(_StrictListMixin, self).__setitem__(index, value)
165 def append(self, item):
166 self._assert_item(item)
167 super(_StrictListMixin, self).append(item)
169 def extend(self, item):
170 self._assert_item(item)
171 super(_StrictListMixin, self).extend(item)
173 def insert(self, index, item):
174 self._assert_item(item)
175 super(_StrictListMixin, self).insert(index, item)
178 def _assert_item(cls, item):
179 if cls._item_cls is not None and not isinstance(item, cls._item_cls):
180 raise exceptions.ValueFormatException('key type was set strictly to {0}, but was {1}'
181 .format(cls._item_cls, type(item)))
class _MutableList(mutable.MutableList):
    """
    Enables tracking for list values.

    Thin subclass of SQLAlchemy's ``MutableList`` whose only change is to report coercion
    failures as ``ValueFormatException``.
    """

    @classmethod
    def coerce(cls, key, value):
        """
        Convert plain lists to MutableList.

        :raises exceptions.ValueFormatException: wrapping the ``ValueError`` raised by
         ``mutable.MutableList.coerce``
        """
        try:
            return mutable.MutableList.coerce(key, value)
        except ValueError as e:
            # The message previously (wrongly) named MutableDict.
            raise exceptions.ValueFormatException('could not coerce to MutableList', cause=e)
197 _StrictDictID = namedtuple('_StrictDictID', 'key_cls, value_cls')
198 _StrictValue = namedtuple('_StrictValue', 'type_cls, listener_cls')
class _StrictDict(object):
    """
    This entire class functions as a factory for strict dicts and their listeners. Neither a
    type class nor a listener class is created more than once for the same key/value types;
    if the relevant type class already exists it is returned.
    """

    # Maps _StrictDictID -> _StrictValue. Kept on the class (not the instance) so every
    # factory instance, and _mutable_association_listener, which reads
    # _StrictDict._strict_map, sees the same generated classes.
    _strict_map = {}

    def __call__(self, key_cls=None, value_cls=None):
        """
        Return the column type class for a dict with the given key and value types.

        :param key_cls: required type of the dict's keys, or ``None`` for no restriction
        :param value_cls: required type of the dict's values, or ``None`` for no restriction
        :returns: a generated ``Dict`` subclass for use as a SQLAlchemy column type
        """
        strict_dict_map_key = _StrictDictID(key_cls=key_cls, value_cls=value_cls)
        if strict_dict_map_key not in self._strict_map:
            key_cls_name = getattr(key_cls, '__name__', str(key_cls))
            value_cls_name = getattr(value_cls, '__name__', str(value_cls))
            # Creating the type class itself. This class is returned (used by the SQLAlchemy
            # Column).
            strict_dict_cls = type(
                'StrictDict_{0}_{1}'.format(key_cls_name, value_cls_name),
                (Dict, ),
                {}
            )
            # Creating the type listening class.
            # The new class inherits from both the _MutableDict class and the _StrictDictMixin,
            # while setting the necessary _key_cls and _value_cls as class attributes.
            listener_cls = type(
                'StrictMutableDict_{0}_{1}'.format(key_cls_name, value_cls_name),
                (_StrictDictMixin, _MutableDict),
                {'_key_cls': key_cls, '_value_cls': value_cls}
            )
            # The listener is a dict, so it must be dumped with the mapping representer. The
            # list representer used previously treats it as a sequence, and iterating a dict
            # yields only its keys.
            yaml.representer.RoundTripRepresenter.add_representer(
                listener_cls, yaml.representer.RoundTripRepresenter.represent_dict)
            self._strict_map[strict_dict_map_key] = _StrictValue(type_cls=strict_dict_cls,
                                                                 listener_cls=listener_cls)

        return self._strict_map[strict_dict_map_key].type_cls
# Module-level factory: call it with key/value types to obtain a strict dict column type.
StrictDict = _StrictDict()
"""
JSON-serializable strict dict type for SQLAlchemy columns.

Call as ``StrictDict(key_cls=..., value_cls=...)``; repeated calls with the same arguments
return the same generated class.

:param key_cls: required type of the dict's keys (``None`` for any)
:param value_cls: required type of the dict's values (``None`` for any)
"""
class _StrictList(object):
    """
    Factory for strict list column types and their mutation-tracking listener classes.

    Each distinct item type yields exactly one type class and one listener class; later
    calls with the same item type return the cached type class.
    """

    # Maps item_cls -> _StrictValue. Class-level so it is shared by every instance and can
    # be read as _StrictList._strict_map by _mutable_association_listener.
    _strict_map = {}

    def __call__(self, item_cls=None):
        """
        Return the column type class for a list whose items must be ``item_cls``.

        :param item_cls: required type of the list's items, or ``None`` for no restriction
        :returns: a generated ``List`` subclass for use as a SQLAlchemy column type
        """
        if item_cls in self._strict_map:
            return self._strict_map[item_cls].type_cls

        item_name = getattr(item_cls, '__name__', str(item_cls))
        # The type class handed back to callers (used by the SQLAlchemy Column).
        type_cls = type('StrictList_{0}'.format(item_name), (List, ), {})
        # The listener class: _StrictListMixin comes first in the MRO so its validating
        # overrides run before _MutableList's; the item type is stored as a class attribute.
        listener_cls = type('StrictMutableList_{0}'.format(item_name),
                            (_StrictListMixin, _MutableList),
                            {'_item_cls': item_cls})
        yaml.representer.RoundTripRepresenter.add_representer(
            listener_cls, yaml.representer.RoundTripRepresenter.represent_list)
        self._strict_map[item_cls] = _StrictValue(type_cls=type_cls, listener_cls=listener_cls)
        return type_cls
# Module-level factory: call it with an item type to obtain a strict list column type.
StrictList = _StrictList()
"""
JSON-serializable strict list type for SQLAlchemy columns.

Call as ``StrictList(item_cls=...)``; repeated calls with the same argument return the same
generated class.

:param item_cls: required type of the list's items (``None`` for any)
"""
def _mutable_association_listener(mapper, cls):
    """
    Attach mutation-tracking listeners to the JSON collection columns of a mapped class.

    Registered for SQLAlchemy's ``mapper_configured`` event. For each column attribute of
    ``mapper``, the column type is first looked up by exact type among the generated strict
    dict/list types, and gets the matching strict listener class on a hit. Otherwise any
    ``Dict`` or ``List`` column gets the plain ``_MutableDict`` / ``_MutableList`` tracking.

    :param mapper: the mapper that was just configured
    :param cls: the mapped class whose attributes receive the listeners
    """
    # .values() replaces the Python-2-only .itervalues(), which raises AttributeError on
    # Python 3; .values() works on both.
    strict_dict_type_to_listener = \
        dict((v.type_cls, v.listener_cls) for v in _StrictDict._strict_map.values())

    strict_list_type_to_listener = \
        dict((v.type_cls, v.listener_cls) for v in _StrictList._strict_map.values())

    for prop in mapper.column_attrs:
        column_type = prop.columns[0].type
        # Exact-type lookup on purpose: the plain Dict/List fallbacks below use isinstance.
        column_type_cls = type(column_type)

        # Dict listeners
        if column_type_cls in strict_dict_type_to_listener:
            strict_dict_type_to_listener[column_type_cls].associate_with_attribute(
                getattr(cls, prop.key))
        elif isinstance(column_type, Dict):
            _MutableDict.associate_with_attribute(getattr(cls, prop.key))

        # List listeners
        if column_type_cls in strict_list_type_to_listener:
            strict_list_type_to_listener[column_type_cls].associate_with_attribute(
                getattr(cls, prop.key))
        elif isinstance(column_type, List):
            _MutableList.associate_with_attribute(getattr(cls, prop.key))
# Arguments for event.listen: run _mutable_association_listener on SQLAlchemy's
# 'mapper_configured' event for mutable.mapper.
_LISTENER_ARGS = (
    mutable.mapper,
    'mapper_configured',
    _mutable_association_listener,
)


def _register_mutable_association_listener():
    """Register ``_mutable_association_listener`` with SQLAlchemy using ``_LISTENER_ARGS``."""
    target, identifier, listener = _LISTENER_ARGS
    event.listen(target, identifier, listener)


# Registration happens once, at import time.
_register_mutable_association_listener()