#!/usr/bin/python
# -*- coding: utf-8 -*-
from webhelpers import paginate as _paginate
import functools
from datetime import datetime, date
import logging
from yarhp.utils import get_composite_key
# Module-level logger; validate_types uses it to report request params
# that have no declared type.
log = logging.getLogger(__name__)
class ValidationError(Exception):
    """Raised when request params fail type or required-field validation."""
class ResourceNotFound(Exception):
    """Raised when a referenced resource (e.g. an ancestor) does not exist."""
def get_pagination_links(request):
    """Return navigation URLs for the page stored on ``request.page``.

    The returned dict maps ``'first'``, ``'previous'``, ``'next'`` and
    ``'last'`` to URLs built by webhelpers' ``make_page_url`` from the
    current path and params. A name is left out whenever the matching page
    attribute of ``request.page`` is falsy.
    """
    make_url = functools.partial(_paginate.make_page_url, request.path_url,
                                 request.params)
    page = request.page
    candidates = (
        ('first', page.first_page),
        ('previous', page.previous_page),
        ('next', page.next_page),
        ('last', page.last_page),
    )
    links = {}
    for link_name, page_number in candidates:
        if page_number:
            links[link_name] = make_url(page_number)
    return links
# Decorators
class wrap_me(object):
    """Base class for decorators that register before- and after-calls.

    ``before`` and ``after`` may each be None, a single callable, or a list
    or tuple of callables. When the decorator is applied, they are appended
    to the decorated function's ``_before_calls`` / ``_after_calls`` lists,
    which are created on first use. The function itself is returned
    unwrapped; the only change made to it is those two attributes.
    """

    def __init__(self, before=None, after=None):
        self.before = self._as_list(before)
        self.after = self._as_list(after)

    @staticmethod
    def _as_list(calls):
        """Normalize None, a single callable, or a list/tuple to a new list."""
        # Tuples were previously treated as a single (non-callable) element.
        if isinstance(calls, (list, tuple)):
            return list(calls)
        return [calls] if calls else []

    def __call__(self, meth):
        if not hasattr(meth, '_before_calls'):
            meth._before_calls = []
        if not hasattr(meth, '_after_calls'):
            meth._after_calls = []
        meth._before_calls += self.before
        meth._after_calls += self.after
        return meth
class paginate(wrap_me):
    """Decorator that paginates a view's results.

    Registers ``pager(**kwargs)`` as an after-call; every keyword argument
    is forwarded to ``pager`` unchanged.
    """

    def __init__(self, **kwargs):
        # NOTE(review): ``pager`` is not defined or imported in the visible
        # part of this module; presumably it is defined further down - confirm.
        super(paginate, self).__init__(after=pager(**kwargs))
class validator(wrap_me):
    """Decorator that checks request params against per-field rules.

    Each keyword argument names a request param and maps to a rule dict
    that may contain ``'type'`` (checked by ``validate_types``) and
    ``'required'`` (checked by ``validate_required``). Both checks are
    registered as before-calls, the type check first. ::

        class MyView():
            @validator(age={'type': int, 'required': True})
            def index(self):
                return response
    """

    def __init__(self, **kwargs):
        checks = [validate_types(**kwargs), validate_required(**kwargs)]
        super(validator, self).__init__(before=checks)
class check_ancestors(wrap_me):
    """Decorator that registers a ``validate_ancestors`` before-call.

    That call raises ResourceNotFound when an ancestor of the requested
    resource cannot be found.
    """

    def __init__(self):
        checks = [validate_ancestors()]
        super(check_ancestors, self).__init__(before=checks)
class callable_base(object):
    """Base class for all before- and after-calls.

    Keyword arguments given to the constructor are stored in ``self.kwargs``.

    Any two instances of the same concrete class compare equal, so callers
    can detect duplicate callables of one type. For example, a base class
    may register a ``pager`` after-call while an action is also decorated
    with ``paginate``. ``__ne__`` and ``__hash__`` are defined to stay
    consistent with that equality.
    """

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __eq__(self, other):
        "we only allow one instance of the same type of callable."
        return type(self) == type(other)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; without this,
        # two same-type callables would be both equal and not-equal.
        return not self.__eq__(other)

    def __hash__(self):
        # Equal objects must hash equal. Also keeps instances hashable on
        # Python 3, where defining ``__eq__`` alone sets ``__hash__`` to None.
        return hash(type(self))
# Before calls
class validate_base(callable_base):
    """Base class for validation callables.

    Instances are created once, when a view method is decorated (e.g. by
    ``validator``), and are stored on that method, so one instance serves
    every call of the view. Per-request data should therefore live in
    locals: subclasses take a private params copy from ``_get_params``
    instead of reading ``self.params``, which every call overwrites.
    """

    @staticmethod
    def _get_params(request):
        """Return a copy of ``request.params`` without the ``_method`` key.

        ``_method`` is an internal HTTP-method tunneling param, so it is
        never validated.
        """
        params = request.params.copy()
        params.pop('_method', None)
        return params

    def __call__(self, **kwargs):
        """Store the request and its cleaned params on the instance.

        Kept for backward compatibility with subclasses that read
        ``self.request``/``self.params``. Because the instance is shared,
        those attributes reflect whichever call ran last.
        """
        self.request = kwargs['request']
        self.params = self._get_params(self.request)


class validate_types(validate_base):
    """Check that params in ``request.params`` match their declared types.

    ``self.kwargs`` maps a param name to a rule dict whose ``'type'`` entry
    may be ``datetime`` (format ``%Y-%m-%dT%H:%M:%S``), ``date`` (format
    ``%Y-%m-%d``) or any class whose constructor accepts the raw value.
    Params without a declared type, and params whose value is the string
    ``'None'``, are skipped.

    Raises ValidationError on the first param that fails its check.
    """

    @staticmethod
    def _check_type(_type, value):
        """Raise ValueError if ``value`` cannot be interpreted as ``_type``."""
        if _type is datetime:
            datetime.strptime(value, '%Y-%m-%dT%H:%M:%S')
        elif _type is date:
            datetime.strptime(value, '%Y-%m-%d')
        elif isinstance(_type, type):
            # isinstance (not ``type(_type) is type``) so classes with a
            # custom metaclass are accepted too.
            _type(value)
        else:
            # The declared "type" is not a class: treat it as a failed check.
            raise ValueError(_type)

    def __call__(self, **kwargs):
        params = self._get_params(kwargs['request'])
        for name, value in params.items():
            if value == 'None':  # TODO: handle explicit nulls properly.
                continue
            _type = self.kwargs.get(name, {}).get('type')
            if _type is None:
                log.debug('Incorrect or unsupported type for %s(%s)',
                          name, value)
                continue
            try:
                self._check_type(_type, value)
            except ValueError:
                raise ValidationError('Bad type %s for %s=%s. Suppose to be %s'
                                      % (type(value), name, value, _type))


class validate_required(validate_base):
    """Check that every param whose rule has ``'required': True`` is present.

    Values from ``request.matchdict`` (e.g. parent resource ids taken from
    the URL) count as present, so they need not be repeated in the request
    params. The ``id`` field is never treated as required.

    Raises ValidationError, listing the required and received fields, when
    any required field is missing.
    """

    def __call__(self, **kwargs):
        request = kwargs['request']
        params = self._get_params(request)
        params.update(request.matchdict)
        # Skip 'id' without mutating self.kwargs: this instance is shared by
        # every call of the decorated view.
        required_fields = set(name for name, rule in self.kwargs.items()
                              if name != 'id' and rule.get('required', False))
        received = list(params.keys())
        if not required_fields.issubset(received):
            raise ValidationError('Required fields: %s. Received: %s'
                                  % (list(required_fields), received))
class validate_ancestors(validate_base):
    """Check that every ancestor of the requested resource exists.

    For each entry of ``request.context.ancestors``, a lookup key is built
    with ``get_composite_key`` from a copy of ``request.matchdict`` in which
    ``'id'`` is set to the ``<member_name>_id`` value. ResourceNotFound is
    raised for the first ancestor whose ``model.get(key)`` is falsy.
    """

    def __call__(self, **kwargs):
        request = kwargs['request']
        ancestors = request.context.ancestors
        for ancestor in ancestors:
            # Fresh copy per ancestor, since 'id' is overwritten each time.
            lookup = request.matchdict.copy()
            # The "<member_name>_id" entry is deliberately kept (get, not
            # pop): models whose primary key is only a composite of foreign
            # keys still need it to build the key.
            lookup['id'] = lookup.get(ancestor.member_name + '_id', None)
            key = get_composite_key(ancestor.model, lookup)
            found = ancestor.model.get(key)
            if not found:
                raise ResourceNotFound('Resource %s(%s) does not exist'
                                       % (ancestor.model, key))
# After calls.
def obj2dict(**kwargs):
    """Convert ``kwargs['result']`` to dicts using each object's ``to_dict``.

    A list result is converted element-wise in place, and the same list
    object is returned. Any other result is converted directly. Values
    without a ``to_dict`` attribute are returned unchanged.
    """
    def to_dict(each):
        # Look the method up separately so that an AttributeError raised
        # *inside* a to_dict() implementation propagates instead of being
        # mistaken for "this object has no to_dict".
        converter = getattr(each, 'to_dict', None)
        if converter is None:
            return each
        return converter()

    result = kwargs['result']
    if type(result) is list:
        result[:] = [to_dict(each) for each in result]
    else:
        result = to_dict(result)
    return result
def wrap_in_dict(**kwargs):
    """Wrap a list result in a dict keyed by the resource's collection name.

    Any non-list result is returned unchanged.
    """
    resource, result = kwargs['resource'], kwargs['result']
    if type(result) is not list:
        return result
    return {resource.collection_name: result}
def update_links(result, new_links):
    """Merge ``new_links`` into the link mapping of ``result``, in place.

    Existing ``'link'`` and ``'links'`` entries are both removed and merged
    (``'links'`` values, then ``new_links``, taking precedence). The merged
    mapping is stored under ``'links'`` when it holds more than one entry,
    otherwise under ``'link'``. Results that are not plain dicts, and calls
    with an empty ``new_links``, leave ``result`` untouched.

    Returns ``result``.
    """
    if type(result) is dict and len(new_links) > 0:
        merged = result.pop('link', {})
        merged.update(result.pop('links', {}))
        merged.update(new_links)
        target_key = 'links' if len(merged) > 1 else 'link'
        result[target_key] = merged
    return result
def add_pagination_links(**kwargs):
    """Add the total item count and page navigation links to ``result``.

    When the request has no ``page`` attribute (or it is None), ``result``
    is returned unchanged. Otherwise ``result['total']`` is set from
    ``request.page.item_count``, and the links from ``get_pagination_links``
    are merged in via ``update_links``.
    """
    request = kwargs['request']
    result = kwargs['result']
    # Only a missing ``page`` is expected here. AttributeErrors raised while
    # building the links are genuine bugs and are no longer swallowed.
    page = getattr(request, 'page', None)
    if page is None:
        return result
    result['total'] = page.item_count
    return update_links(result, get_pagination_links(request))
def add_self_links(**kwargs):
    """Add ``self`` and ``edit`` links, both equal to ``request.url``.

    ``result`` may be a single item or a list of items. Each item is passed
    to ``update_links``, which changes dict items in place and leaves other
    items untouched. The original ``result`` object is returned.
    """
    request = kwargs['request']
    result = kwargs['result']
    targets = result if type(result) is list else [result]
    for target in targets:
        update_links(target, {'self': request.url, 'edit': request.url})
    return result
def add_parent_links(**kwargs):
    """Replace ``<name>_id`` fields with a ``<name>`` dict of id and link.

    ``result`` may be a single item or a list. In every dict item, each key
    ending in ``_id`` for which ``request.route_url(<name>, id=value)``
    succeeds is removed and replaced by ``<name>: {'id': value, 'link': url}``.
    Keys whose route lookup raises KeyError are left alone, and non-dict
    items are skipped. Items are modified in place; ``result`` is returned.
    """
    request = kwargs['request']
    result = kwargs['result']
    for each in (result if type(result) is list else [result]):
        if type(each) is not dict:
            continue
        # Iterate over a snapshot: the body pops and inserts keys, which is
        # an error while iterating a live dict view on Python 3.
        for name_id, value in list(each.items()):
            if not name_id.endswith('_id'):
                continue
            name = name_id[:-len('_id')]
            try:
                url = request.route_url(name, id=value)
            except KeyError:
                # There is no route defined for this name.
                continue
            # HACK: this drops the original ``<name>_id`` key from the data.
            each.pop(name_id)  # e.g. remove account_id
            each[name] = dict(id=value, link=url)  # e.g. account={'id': 1, 'link': url}
    return result