diff --git a/CHANGES.rst b/CHANGES.rst index e078cefef..3c8a6bfe8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,7 +7,10 @@ Changes New features: * 'region' argument for :ref:`splash-png` and :ref:`splash-jpeg` methods - allow to take screenshots of parts of pages. + allow to take screenshots of parts of pages; +* :ref:`save_args <arg-save-args>` and :ref:`load_args <arg-load-args>` + parameters allow to save network traffic by caching large request arguments + inside Splash server. Bug fixes: diff --git a/docs/api.rst b/docs/api.rst index e0bb18702..0cff1142c 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -171,6 +171,47 @@ http_method : string : optional HTTP method of outgoing Splash request. Default method is GET. Splash also supports POST. +.. _arg-save-args: + +save_args : JSON array or a comma-separated string : optional + A list of argument names to put in cache. Splash will store each + argument value in an internal cache and return ``X-Splash-Saved-Arguments`` + HTTP header with a list of SHA1 hashes for each argument + (a semicolon-separated list of name=hash pairs):: + + name1=9a6747fc6259aa374ab4e1bb03074b6ec672cf99;name2=ba001160ef96fe2a3f938fea9e6762e204a562b3 + + Client can then use :ref:`load_args <arg-load-args>` parameter + to pass these hashes instead of argument values. This is most useful + when argument value is large and doesn't change often + (:ref:`js_source ` or :ref:`lua_source ` + are often good candidates). + +.. _arg-load-args: + +load_args : JSON object or a string : optional + Parameter values to load from cache. + ``load_args`` should be either ``{"name": "<SHA1 hash>", ...}`` + JSON object or a raw ``X-Splash-Saved-Arguments`` header value + (a semicolon-separated list of name=hash pairs). + + For each parameter in ``load_args`` Splash tries to fetch the + value from the internal cache using a provided SHA1 hash as a key. + If all values are in cache then Splash uses them as argument values + and then handles the request as usual. 
+ + If at least one argument can't be found Splash returns **HTTP 498** status + code. In this case client should repeat the request, but + use :ref:`save_args <arg-save-args>` and send full argument values. + + :ref:`load_args <arg-load-args>` and :ref:`save_args <arg-save-args>` + allow to save network traffic by not sending large arguments with each + request (:ref:`js_source ` and + :ref:`lua_source ` are often good candidates). + + Splash uses LRU cache to store values; the number of entries is limited, + and cache is cleared after each Splash restart. In other words, storage + is not persistent; client should be ready to re-send the arguments. Examples ~~~~~~~~ @@ -522,9 +563,10 @@ execute Execute a custom rendering script and return a result. -:ref:`render.html`, :ref:`render.png`, :ref:`render.jpeg`, :ref:`render.har` and :ref:`render.json` -endpoints cover many common use cases, but sometimes they are not enough. -This endpoint allows to write custom :ref:`Splash Scripts `. +:ref:`render.html`, :ref:`render.png`, :ref:`render.jpeg`, :ref:`render.har` +and :ref:`render.json` endpoints cover many common use cases, but sometimes +they are not enough. This endpoint allows to write custom +:ref:`Splash Scripts `. Arguments: @@ -547,6 +589,16 @@ proxy : string : optional filters : string : optional Same as :ref:`'filters' ` argument for `render.html`_. +save_args : JSON array or a comma-separated string : optional + Same as :ref:`'save_args' <arg-save-args>` argument for `render.html`_. + Note that you can save not only default Splash arguments, + but any other parameters as well. + +load_args : JSON object or a string : optional + Same as :ref:`'load_args' <arg-load-args>` argument for `render.html`_. + Note that you can load not only default Splash arguments, + but any other parameters as well. + .. 
class ArgumentCache(object):
    """
    In-memory LRU cache of argument values addressed by content hash.

    A value is stored under the SHA1 hex digest of its canonical JSON
    representation (see :meth:`get_key`), so equal values always share
    one entry. When ``maxsize`` entries are stored, adding a new value
    evicts the least recently used one; both :meth:`add` and item lookup
    (``cache[key]``) count as a use. Membership tests (``key in cache``,
    :meth:`get_missing`) do not change the eviction order.

    >>> cache = ArgumentCache()
    >>> "foo" in cache
    False
    >>> cache['foo']
    Traceback (most recent call last):
    ...
    KeyError: 'foo'
    >>> len(cache)
    0
    >>> key = cache.add("Hello, world!")
    >>> key
    'bea2c9d7fd040292e0424938af39f7d6334e8d8a'
    >>> cache[key]
    'Hello, world!'
    >>> key in cache
    True
    >>> len(cache)
    1
    >>> cache.get_missing([
    ...     ('bar', key),
    ...     ('baz', '1111111111111111111111111111111111111111'),
    ... ])
    ['baz']
    >>> cache.add_many(['value1', 'value2'])
    ['daf626c4ebd6bdd697e043111454304e5fb1459e', '849988af22dbd04d3e353caf77f9d81241ca9ee2']
    >>> cache['daf626c4ebd6bdd697e043111454304e5fb1459e']
    'value1'
    >>> cache['849988af22dbd04d3e353caf77f9d81241ca9ee2']
    'value2'
    >>> cache[key]
    'Hello, world!'
    >>> len(cache)
    3
    >>> cache.clear()
    >>> len(cache)
    0

    Size of ArgumentCache can be limited:

    >>> cache = ArgumentCache(0)
    Traceback (most recent call last):
    ...
    ValueError: maxsize must be greater than 0
    >>> cache = ArgumentCache(2)  # limit it to 2 elements
    >>> cache.add_many(['value1', 'value2'])
    ['daf626c4ebd6bdd697e043111454304e5fb1459e', '849988af22dbd04d3e353caf77f9d81241ca9ee2']
    >>> len(cache)
    2
    >>> cache.add("Hello, world!")
    'bea2c9d7fd040292e0424938af39f7d6334e8d8a'
    >>> len(cache)
    2
    >>> cache["bea2c9d7fd040292e0424938af39f7d6334e8d8a"]
    'Hello, world!'
    >>> cache['849988af22dbd04d3e353caf77f9d81241ca9ee2']
    'value2'
    >>> cache['daf626c4ebd6bdd697e043111454304e5fb1459e']
    Traceback (most recent call last):
    ...
    KeyError: 'daf626c4ebd6bdd697e043111454304e5fb1459e'
    >>> cache.add("foo")
    'd465e627f9946f2fa0d2dc0fc04e5385bc6cd46d'
    >>> len(cache)
    2
    >>> 'bea2c9d7fd040292e0424938af39f7d6334e8d8a' in cache
    False
    """
    def __init__(self, maxsize=None):
        """
        Create an empty cache.

        ``maxsize`` is the maximum number of entries to keep; ``None``
        (the default) means "unlimited". ValueError is raised when
        ``maxsize`` is zero or negative.
        """
        if maxsize is None:
            maxsize = float("+inf")
        if maxsize <= 0:
            raise ValueError("maxsize must be greater than 0")
        self.maxsize = maxsize
        # Insertion order doubles as recency order:
        # first item == least recently used, last item == most recent.
        self._values = OrderedDict()

    def add(self, value):
        """
        Store ``value`` as the most recently used entry; return its key.
        """
        key = self.get_key(value)
        if key in self._values:
            # Already cached: remove it so that the assignment below
            # re-inserts the entry at the "most recent" end.
            del self._values[key]
        else:
            # Make room for exactly one new entry.
            while len(self._values) >= self.maxsize:
                self._values.popitem(last=False)
        self._values[key] = value
        return key

    def __getitem__(self, key):
        """
        Return the value stored under ``key`` and mark it as recently used.
        KeyError is raised if ``key`` is not in the cache.
        """
        # pop + re-insert is equivalent to OrderedDict.move_to_end(key),
        # but move_to_end is only available on Python 3; the module
        # otherwise looks Python 2 compatible (__future__ import), so the
        # portable spelling is used. pop() raises KeyError(key) on a miss.
        value = self._values.pop(key)
        self._values[key] = value
        return value

    def __contains__(self, key):
        return key in self._values

    def __len__(self):
        return len(self._values)

    def clear(self):
        """ Remove all entries from the cache. """
        self._values.clear()

    def get_missing(self, items):
        """
        Return names from ``items`` whose keys are not in the cache.

        ``items`` is an iterable of ``(name, key)`` pairs; the result
        keeps the order of ``items``.
        """
        return [name for name, key in items if key not in self]

    def add_many(self, values):
        """
        Add all values from ``values`` list to cache. Return a list of keys.
        """
        return [self.add(value) for value in values]

    @classmethod
    def get_key(cls, value):
        """
        Return a cache key for ``value``: SHA1 hex digest of its JSON
        representation (keys sorted, non-ASCII characters kept as-is and
        encoded to UTF-8).
        """
        value_json = json.dumps(value, sort_keys=True, ensure_ascii=False)
        # Values decoded from client JSON may contain lone surrogates
        # (e.g. "\ud83d"); strict UTF-8 encoding raises UnicodeEncodeError
        # for them on Python 3. 'surrogatepass' makes such values hashable
        # while producing identical bytes (and thus identical keys) for
        # every string that could be encoded before.
        value_bytes = value_json.encode('utf8', 'surrogatepass')
        return hashlib.sha1(value_bytes).hexdigest()
def get_load_args(self):
    """
    Return the ``load_args`` option as a ``{argument name: hash}`` dict.

    ``load_args`` may be absent (an empty dict is returned), a JSON
    object, or a string with semicolon-separated ``name=hash`` pairs
    (the format of ``X-Splash-Saved-Arguments`` header). Any other
    shape is reported via ``self.raise_error`` for the ``load_args``
    argument; the same is done when a hash is not a string.
    """
    load_args = self.get("load_args", default=None, type=None)
    if load_args is None:
        return {}

    if isinstance(load_args, six.text_type):
        try:
            # dict() raises ValueError for a chunk without "=":
            # such chunk is split to a 1-element list, not to a pair.
            load_args = dict(
                kv.split("=", 1) for kv in load_args.split(';')
            )
        except ValueError:
            self.raise_error(
                argument="load_args",
                description="'load_args' string value is not a "
                            "semicolon-separated list of name=hash pairs"
            )

    if not isinstance(load_args, dict):
        self.raise_error(
            argument="load_args",
            description="'load_args' should be either a JSON object with "
                        "argument hashes or a semicolon-separated list "
                        "of name=hash pairs"
        )

    # Hashes are used as cache keys later. A JSON object can carry
    # arbitrary values; an unhashable one (list, object) would make the
    # cache lookup fail with TypeError instead of a proper "bad argument"
    # response, so reject non-string hashes here, the same way
    # get_save_args rejects non-string names.
    if not all(isinstance(h, six.text_type) for h in load_args.values()):
        self.raise_error(
            argument="load_args",
            description="'load_args' values should be strings "
                        "(argument hashes)",
        )

    return load_args
a/splash/resources.py +++ b/splash/resources.py @@ -17,6 +17,7 @@ import six import splash +from splash.argument_cache import ArgumentCache from splash.qtrender import ( HtmlRender, PngRender, JsonRender, HarRender, JpegRender ) @@ -34,6 +35,7 @@ from splash.exceptions import ( BadOption, RenderError, InternalError, GlobalTimeoutError, UnsupportedContentType, + ExpiredArguments, ) if lua_is_supported(): @@ -83,17 +85,33 @@ class BaseRenderResource(_ValidatingResource): isLeaf = True content_type = "text/html; charset=utf-8" - def __init__(self, pool, max_timeout): + def __init__(self, pool, max_timeout, argument_cache): Resource.__init__(self) self.pool = pool self.js_profiles_path = self.pool.js_profiles_path self.max_timeout = max_timeout + self.argument_cache = argument_cache def render_GET(self, request): #log.msg("%s %s %s %s" % (id(request), request.method, request.path, request.args)) request.starttime = time.time() render_options = RenderOptions.fromrequest(request, self.max_timeout) + # process argument cache + original_options = render_options.data.copy() + expired_args = render_options.get_expired_args(self.argument_cache) + if expired_args: + error = self._write_expired_args(request, expired_args) + self._log_stats(request, original_options, error) + return b"\n" + + saved_args = render_options.save_args_to_cache(self.argument_cache) + if saved_args: + value = ';'.join("{}={}".format(name, value) + for name, value in saved_args) + request.setHeader(b'X-Splash-Saved-Arguments', value.encode('utf8')) + render_options.load_cached_args(self.argument_cache) + # check arguments before starting the render render_options.get_filters(self.pool) @@ -110,7 +128,8 @@ def render_GET(self, request): pool_d.addErrback(self._on_render_error, request) pool_d.addErrback(self._on_bad_request, request) pool_d.addErrback(self._on_internal_error, request) - pool_d.addBoth(self._finish_request, request, options=render_options.data) + pool_d.addBoth(self._finish_request, 
def _write_expired_args(self, request, expired_args):
    """
    Write "expired arguments" error response (status code 498).

    ``expired_args`` is a list of argument names; it is sent to the
    client under the ``expired`` key of ExpiredArguments error info.
    The result of ``self._write_error`` is returned.
    """
    error_info = {'expired': expired_args}
    return self._write_error(request, 498, ExpiredArguments(error_info))
class ClearCachesResource(Resource):
    """
    Leaf resource which handles POST requests by clearing the argument
    cache, calling ``clear_caches()`` and running a garbage collection;
    a JSON report is returned.
    """
    isLeaf = True
    content_type = "application/json"

    def __init__(self, argument_cache):
        Resource.__init__(self)
        self.argument_cache = argument_cache

    def render_POST(self, request):
        # Remember the size before clearing: it is reported to the client.
        removed_count = len(self.argument_cache)
        self.argument_cache.clear()
        clear_caches()
        collected = gc.collect()
        report = {
            "status": "ok",
            "pyobjects_collected": collected,
            "cached_args_removed": removed_count,
        }
        body = json.dumps(report, sort_keys=True)
        return body.encode('utf-8')
- - self.putChild(b"_debug", DebugResource(pool)) - self.putChild(b"_gc", ClearCachesResource()) + + _args = pool, max_timeout, self.argument_cache + self.putChild(b"render.html", RenderHtmlResource(*_args)) + self.putChild(b"render.png", RenderPngResource(*_args)) + self.putChild(b"render.jpeg", RenderJpegResource(*_args)) + self.putChild(b"render.json", RenderJsonResource(*_args)) + self.putChild(b"render.har", RenderHarResource(*_args)) + + self.putChild(b"_debug", DebugResource(pool, self.argument_cache)) + self.putChild(b"_gc", ClearCachesResource(self.argument_cache)) self.putChild(b"_ping", PingResource()) # backwards compatibility - self.putChild(b"debug", DebugResource(pool, warn=True)) + self.putChild(b"debug", DebugResource(pool, self.argument_cache, warn=True)) if self.lua_enabled and ExecuteLuaScriptResource is not None: self.putChild(b"execute", ExecuteLuaScriptResource( @@ -550,7 +591,8 @@ def __init__(self, pool, ui_enabled, lua_enabled, lua_sandbox_enabled, sandboxed=lua_sandbox_enabled, lua_package_path=lua_package_path, lua_sandbox_allowed_modules=lua_sandbox_allowed_modules, - max_timeout=max_timeout + max_timeout=max_timeout, + argument_cache=self.argument_cache, )) if self.ui_enabled: diff --git a/splash/server.py b/splash/server.py index 44664d9bb..dcf51d887 100644 --- a/splash/server.py +++ b/splash/server.py @@ -74,6 +74,9 @@ def parse_opts(jupyter=False, argv=sys.argv): help="disable web UI") op.add_option("--disable-lua", action="store_true", default=False, help="disable Lua scripting") + op.add_option("--argument-cache-max-entries", type="int", + default=defaults.ARGUMENT_CACHE_MAX_ENTRIES, + help="maximum number of entries in arguments cache (default: %default)") opts, args = op.parse_args(argv) @@ -84,11 +87,11 @@ def parse_opts(jupyter=False, argv=sys.argv): opts.port = None opts.slots = None opts.max_timeout = None + opts.argument_cache_max_entries = None return opts, args - def start_logging(opts): import twisted from twisted.python 
import log @@ -168,6 +171,7 @@ def splash_server(portnum, slots, network_manager_factory, max_timeout, lua_sandbox_enabled=True, lua_package_path="", lua_sandbox_allowed_modules=(), + argument_cache_max_entries=None, verbosity=None): from twisted.internet import reactor from twisted.web.server import Site @@ -182,6 +186,9 @@ def splash_server(portnum, slots, network_manager_factory, max_timeout, slots = defaults.SLOTS if slots is None else slots log.msg("slots=%s" % slots) + if argument_cache_max_entries: + log.msg("argument_cache_max_entries=%s" % argument_cache_max_entries) + pool = RenderPool( slots=slots, network_manager_factory=network_manager_factory, @@ -211,7 +218,8 @@ def splash_server(portnum, slots, network_manager_factory, max_timeout, lua_sandbox_enabled=lua_sandbox_enabled, lua_package_path=lua_package_path, lua_sandbox_allowed_modules=lua_sandbox_allowed_modules, - max_timeout=max_timeout + max_timeout=max_timeout, + argument_cache_max_entries=argument_cache_max_entries, ) factory = Site(root) reactor.listenTCP(portnum, factory) @@ -255,6 +263,7 @@ def default_splash_server(portnum, max_timeout, slots=None, lua_sandbox_enabled=True, lua_package_path="", lua_sandbox_allowed_modules=(), + argument_cache_max_entries=None, verbosity=None, server_factory=splash_server): from splash import network_manager @@ -278,7 +287,8 @@ def default_splash_server(portnum, max_timeout, slots=None, lua_package_path=lua_package_path, lua_sandbox_allowed_modules=lua_sandbox_allowed_modules, verbosity=verbosity, - max_timeout=max_timeout + max_timeout=max_timeout, + argument_cache_max_entries=argument_cache_max_entries, ) @@ -358,6 +368,7 @@ def main(jupyter=False, argv=sys.argv, server_factory=splash_server): lua_sandbox_allowed_modules=opts.lua_sandbox_allowed_modules.split(";"), verbosity=opts.verbosity, max_timeout=opts.max_timeout, + argument_cache_max_entries=opts.argument_cache_max_entries, server_factory=server_factory, ) signal.signal(signal.SIGUSR1, lambda s, f: 
class RenderHtmlArgumentCachingTest(BaseRenderTest):
    """ save_args / load_args round trip for render.html endpoint """
    endpoint = 'render.html'

    def test_cache_url(self):
        # 1. Save "url" and "wait" values in the argument cache.
        query = {
            "url": self.mockurl('jsrender'),
            "wait": 0.5,
            "save_args": "url,wait",
        }
        saved = self.request(query)
        self.assertStatusCode(saved, 200)
        self.assertIn("After", saved.text)

        # 2. Send only the hashes; the result must be the same.
        hashes_header = saved.headers['X-Splash-Saved-Arguments']
        loaded = self.request({"load_args": hashes_header})
        self.assertStatusCode(loaded, 200)
        assert loaded.text == saved.text

        # 3. Drop caches via _gc endpoint and check its report.
        gc_resp = self.post({}, endpoint="_gc")
        self.assertStatusCode(gc_resp, 200)
        report = gc_resp.json()
        assert report['cached_args_removed'] >= 2
        assert report['pyobjects_collected'] > 0
        assert report['status'] == 'ok'

        # 4. The same hashes are reported as expired after the cleanup.
        expired = self.request({"load_args": hashes_header})
        error = self.assertJsonError(expired, 498, 'ExpiredArguments')
        assert set(error['info']['expired']) == {'wait', 'url'}
def assertBadArgument(self, response, argname):
    """
    Check that ``response`` is a 400 "BadOption" JSON error
    which is reported for ``argname`` argument.
    """
    error = self.assertJsonError(response, 400, "BadOption")
    reported_name = error['info']['argument']
    self.assertEqual(reported_name, argname)