From d375ebf1ee7a390ad926a9db40528794e27dba18 Mon Sep 17 00:00:00 2001 From: Craig Loftus Date: Thu, 5 Jan 2017 21:01:06 +0000 Subject: [PATCH] Allow passing of json kwargs to get, set, push and update --- pyrebase/pyrebase.py | 16 ++++++++-------- tests/test_database.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/pyrebase/pyrebase.py b/pyrebase/pyrebase.py index 90b4e5c..2be9034 100644 --- a/pyrebase/pyrebase.py +++ b/pyrebase/pyrebase.py @@ -253,7 +253,7 @@ def build_headers(self, token=None): headers['Authorization'] = 'Bearer ' + access_token return headers - def get(self, token=None): + def get(self, token=None, json_kwargs={}): build_query = self.build_query query_key = self.path.split("/")[-1] request_ref = self.build_request_url(token) @@ -262,7 +262,7 @@ def get(self, token=None): # do request request_object = self.requests.get(request_ref, headers=headers) raise_detailed_error(request_object) - request_dict = request_object.json() + request_dict = request_object.json(**json_kwargs) # if primitive or simple query return if isinstance(request_dict, list): @@ -285,27 +285,27 @@ def get(self, token=None): sorted_response = sorted(request_dict.items(), key=lambda item: item[1][build_query["orderBy"]]) return PyreResponse(convert_to_pyre(sorted_response), query_key) - def push(self, data, token=None): + def push(self, data, token=None, json_kwargs={}): request_ref = self.check_token(self.database_url, self.path, token) self.path = "" headers = self.build_headers(token) - request_object = self.requests.post(request_ref, headers=headers, data=json.dumps(data).encode("utf-8")) + request_object = self.requests.post(request_ref, headers=headers, data=json.dumps(data, **json_kwargs).encode("utf-8")) raise_detailed_error(request_object) return request_object.json() - def set(self, data, token=None): + def set(self, data, token=None, json_kwargs={}): request_ref = self.check_token(self.database_url, self.path, token) self.path = "" headers = self.build_headers(token) - request_object = self.requests.put(request_ref, headers=headers, data=json.dumps(data).encode("utf-8")) + request_object = self.requests.put(request_ref, headers=headers, data=json.dumps(data, **json_kwargs).encode("utf-8")) raise_detailed_error(request_object) return request_object.json() - def update(self, data, token=None): + def update(self, data, token=None, json_kwargs={}): request_ref = self.check_token(self.database_url, self.path, token) self.path = "" headers = self.build_headers(token) - request_object = self.requests.patch(request_ref, headers=headers, data=json.dumps(data).encode("utf-8")) + request_object = self.requests.patch(request_ref, headers=headers, data=json.dumps(data, **json_kwargs).encode("utf-8")) raise_detailed_error(request_object) return request_object.json() diff --git a/tests/test_database.py b/tests/test_database.py index dc40aea..08ea0ed 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,3 +1,4 @@ +import datetime import random import time from contextlib import contextmanager @@ -59,6 +60,36 @@ def test_put_deeper_dictionnary(self, db_sa): assert db_sa().get().val() == v +class TestJsonKwargs: + + def encoder(self, obj): + if isinstance(obj, datetime.datetime): + return { + '__type__': obj.__class__.__name__, + 'value': obj.timestamp(), + } + return obj + + def decoder(self, obj): + if '__type__' in obj and obj['__type__'] == datetime.datetime.__name__: + return datetime.datetime.utcfromtimestamp(obj['value']) + return obj + + def test_put_fail(self, db_sa): + v = {'some_datetime': datetime.datetime.now()} + with pytest.raises(TypeError): + db_sa().set(v) + + def test_put_succeed(self, db_sa): + v = {'some_datetime': datetime.datetime.now()} + assert db_sa().set(v, json_kwargs={'default': str}) + + def test_put_then_get_succeed(self, db_sa): + v = {'another_datetime': datetime.datetime.now()} + db_sa().set(v, json_kwargs={'default': self.encoder}) + assert db_sa().get(json_kwargs={'object_hook': self.decoder}).val() == v + + class TestChildNavigation: def test_get_child_none(self, db_sa): assert db_sa().child('lorem').get().val() is None