mirror of
https://github.com/valitydev/redash.git
synced 2024-11-06 17:15:17 +00:00
Remove tree validations and introduce ParameterizedQuery (#3230)
This commit is contained in:
parent
823e4ccdd6
commit
121a44ef15
@ -1,66 +1,14 @@
|
||||
from __future__ import absolute_import
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
|
||||
from .authentication import current_org
|
||||
from flask_login import current_user, login_required
|
||||
from flask_restful import abort
|
||||
from redash import models, utils
|
||||
from redash import models
|
||||
from redash.handlers import routes
|
||||
from redash.handlers.base import (get_object_or_404, org_scoped_rule,
|
||||
record_event)
|
||||
from redash.utils import find_missing_params
|
||||
from redash.handlers.static import render_index
|
||||
from redash.utils import gen_query_hash, mustache_render
|
||||
|
||||
|
||||
#
|
||||
# Run a parameterized query synchronously and return the result
|
||||
# DISCLAIMER: Temporary solution to support parameters in queries. Should be
|
||||
# removed once we refactor the query results API endpoints and handling
|
||||
# on the client side. Please don't reuse in other API handlers.
|
||||
#
|
||||
def run_query_sync(data_source, parameter_values, query_text, max_age=0):
    """Render a parameterized query, run it synchronously and return its data.

    `query_text` is treated as a template and rendered with
    `parameter_values`. When `max_age` is positive, a stored result is looked
    up first through `models.QueryResult.get_latest`; if one is found its
    `data` is returned without running anything. Otherwise the rendered query
    is executed on `data_source.query_runner`, and (again only when `max_age`
    is positive) the fresh result is stored and committed.

    Returns:
        The cached or freshly fetched data, or None when the query runner
        reports an error.

    Raises:
        Exception: when `find_missing_params` reports template parameters
            that `parameter_values` does not supply.

    Any exception raised while running or storing the query ends in
    `abort(404, ...)` when `max_age` is positive and `abort(503, ...)`
    otherwise.

    NOTE(review): the block nesting below was reconstructed from a flattened
    listing -- confirm it against the real file.
    """
    missing_params = find_missing_params(query_text, parameter_values)
    if missing_params:
        raise Exception('Missing parameter value for: {}'.format(", ".join(missing_params)))

    # From here on `query_text` is the rendered query, not the template.
    query_text = mustache_render(query_text, parameter_values)

    # A non-positive max_age disables the cache lookup entirely.
    if max_age <= 0:
        query_result = None
    else:
        query_result = models.QueryResult.get_latest(data_source, query_text, max_age)

    query_hash = gen_query_hash(query_text)

    if query_result:
        logging.info("Returning cached result for query %s" % query_hash)
        return query_result.data

    try:
        # Timestamp taken before execution so the stored run time covers the query itself.
        started_at = time.time()
        data, error = data_source.query_runner.run_query(query_text, current_user)

        # A runner-reported error yields None; the error value itself is not surfaced.
        if error:
            return None
        # update cache
        if max_age > 0:
            run_time = time.time() - started_at
            query_result, updated_query_ids = models.QueryResult.store_result(data_source.org_id, data_source.id,
                                                                              query_hash, query_text, data,
                                                                              run_time, utils.utcnow())

            models.db.session.commit()
        return data
    except Exception:
        if max_age > 0:
            abort(404, message="Unable to get result from the database, and no cached query result found.")
        else:
            abort(503, message="Unable to get result from the database.")
        # presumably unreachable: flask_restful's abort() is expected to raise -- TODO confirm
        return None
|
||||
|
||||
|
||||
@routes.route(org_scoped_rule('/embed/query/<query_id>/visualization/<visualization_id>'), methods=['GET'])
|
||||
|
@ -8,39 +8,16 @@ from redash import models, settings
|
||||
from redash.handlers.base import BaseResource, get_object_or_404
|
||||
from redash.permissions import (has_access, not_view_only, require_access,
|
||||
require_permission, view_only)
|
||||
from redash.tasks import QueryTask, record_event
|
||||
from redash.tasks import QueryTask
|
||||
from redash.tasks.queries import enqueue_query
|
||||
from redash.utils import (collect_parameters_from_request, find_missing_params, gen_query_hash, json_dumps, utcnow)
|
||||
from redash.utils.sql_query import SQLInjectionError, SQLQuery
|
||||
from redash.utils import (collect_parameters_from_request, gen_query_hash, json_dumps, utcnow)
|
||||
from redash.utils.parameterized_query import ParameterizedQuery
|
||||
|
||||
|
||||
def error_response(message):
    """Build the `(payload, status_code)` pair used to report an error.

    The payload wraps a job entry with status 4 whose `error` field is
    `message`; the accompanying HTTP status code is always 400.
    """
    job = {'status': 4, 'error': message}
    payload = {'job': job}
    return payload, 400
|
||||
|
||||
|
||||
def apply_parameters(template, parameters, data_source):
    """Render `template` with `parameters` and return the rendered query.

    `query.text` is read only so that a `SQLInjectionError` can be recorded
    as an event (or any other failure logged). The `finally` clause then
    rebinds `text` to `query.query` unconditionally, so whenever this
    function returns normally it returns the rendered query, regardless of
    the outcome of the safety check.

    `data_source` is used only for its `org_id` in the recorded event.
    """
    query = SQLQuery(template).apply(parameters)

    # for now we only log `SQLInjectionError` to detect false positives
    try:
        text = query.text
    except SQLInjectionError:
        record_event({
            'action': 'sql_injection',
            'object_type': 'query',
            'query': template,
            'parameters': parameters,
            'timestamp': time.time(),
            'org_id': data_source.org_id
        })
    except Exception as e:
        # NOTE(review): `e.message` is a Python 2 exception attribute -- confirm before porting.
        logging.info(u"Failed applying parameters for query %s: %s", gen_query_hash(query.query), e.message)
    finally:
        # Always fall back to the unchecked rendered query (see docstring).
        text = query.query

    return text
|
||||
|
||||
|
||||
#
|
||||
# Run a parameterized query synchronously and return the result
|
||||
# DISCLAIMER: Temporary solution to support parameters in queries. Should be
|
||||
@ -48,18 +25,17 @@ def apply_parameters(template, parameters, data_source):
|
||||
# on the client side. Please don't reuse in other API handlers.
|
||||
#
|
||||
def run_query_sync(data_source, parameter_values, query_text, max_age=0):
|
||||
missing_params = find_missing_params(query_text, parameter_values)
|
||||
if missing_params:
|
||||
raise Exception('Missing parameter value for: {}'.format(", ".join(missing_params)))
|
||||
query = ParameterizedQuery(query_text).apply(parameter_values)
|
||||
|
||||
query_text = apply_parameters(query_text, parameter_values, data_source)
|
||||
if query.missing_params:
|
||||
raise Exception('Missing parameter value for: {}'.format(", ".join(query.missing_params)))
|
||||
|
||||
if max_age <= 0:
|
||||
query_result = None
|
||||
else:
|
||||
query_result = models.QueryResult.get_latest(data_source, query_text, max_age)
|
||||
query_result = models.QueryResult.get_latest(data_source, query.text, max_age)
|
||||
|
||||
query_hash = gen_query_hash(query_text)
|
||||
query_hash = gen_query_hash(query.text)
|
||||
|
||||
if query_result:
|
||||
logging.info("Returning cached result for query %s" % query_hash)
|
||||
@ -67,7 +43,7 @@ def run_query_sync(data_source, parameter_values, query_text, max_age=0):
|
||||
|
||||
try:
|
||||
started_at = time.time()
|
||||
data, error = data_source.query_runner.run_query(query_text, current_user)
|
||||
data, error = data_source.query_runner.run_query(query.text, current_user)
|
||||
|
||||
if error:
|
||||
logging.info('got bak error')
|
||||
@ -76,9 +52,8 @@ def run_query_sync(data_source, parameter_values, query_text, max_age=0):
|
||||
|
||||
run_time = time.time() - started_at
|
||||
query_result, updated_query_ids = models.QueryResult.store_result(data_source.org_id, data_source,
|
||||
query_hash, query_text, data,
|
||||
query_hash, query.text, data,
|
||||
run_time, utcnow())
|
||||
|
||||
models.db.session.commit()
|
||||
return query_result
|
||||
except Exception as e:
|
||||
@ -90,10 +65,6 @@ def run_query_sync(data_source, parameter_values, query_text, max_age=0):
|
||||
|
||||
|
||||
def run_query(data_source, parameter_values, query_text, query_id, max_age=0):
|
||||
missing_params = find_missing_params(query_text, parameter_values)
|
||||
if missing_params:
|
||||
return error_response(u'Missing parameter value for: {}'.format(u", ".join(missing_params)))
|
||||
|
||||
if data_source.paused:
|
||||
if data_source.pause_reason:
|
||||
message = '{} is paused ({}). Please try later.'.format(data_source.name, data_source.pause_reason)
|
||||
@ -102,17 +73,20 @@ def run_query(data_source, parameter_values, query_text, query_id, max_age=0):
|
||||
|
||||
return error_response(message)
|
||||
|
||||
query_text = apply_parameters(query_text, parameter_values, data_source)
|
||||
query = ParameterizedQuery(query_text).apply(parameter_values)
|
||||
|
||||
if query.missing_params:
|
||||
return error_response(u'Missing parameter value for: {}'.format(u", ".join(query.missing_params)))
|
||||
|
||||
if max_age == 0:
|
||||
query_result = None
|
||||
else:
|
||||
query_result = models.QueryResult.get_latest(data_source, query_text, max_age)
|
||||
query_result = models.QueryResult.get_latest(data_source, query.text, max_age)
|
||||
|
||||
if query_result:
|
||||
return {'query_result': query_result.to_dict()}
|
||||
else:
|
||||
job = enqueue_query(query_text, data_source, current_user.id, metadata={"Username": current_user.email, "Query ID": query_id})
|
||||
job = enqueue_query(query.text, data_source, current_user.id, metadata={"Username": current_user.email, "Query ID": query_id})
|
||||
return {'job': job.to_dict()}
|
||||
|
||||
|
||||
|
@ -124,6 +124,9 @@ class TimeoutSession(Session):
|
||||
|
||||
|
||||
class GoogleSpreadsheet(BaseQueryRunner):
|
||||
def __init__(self, configuration):
|
||||
super(GoogleSpreadsheet, self).__init__(configuration)
|
||||
self.syntax = 'custom'
|
||||
|
||||
@classmethod
|
||||
def annotate_query(cls):
|
||||
|
@ -55,6 +55,7 @@ class Graphite(BaseQueryRunner):
|
||||
|
||||
def __init__(self, configuration):
|
||||
super(Graphite, self).__init__(configuration)
|
||||
self.syntax = 'custom'
|
||||
|
||||
if "username" in self.configuration and self.configuration["username"]:
|
||||
self.auth = (self.configuration["username"], self.configuration["password"])
|
||||
|
@ -15,7 +15,7 @@ from six import string_types
|
||||
import pystache
|
||||
import pytz
|
||||
import simplejson
|
||||
from funcy import distinct, select_values
|
||||
from funcy import select_values
|
||||
from redash import settings
|
||||
from sqlalchemy.orm.query import Query
|
||||
|
||||
@ -167,41 +167,6 @@ class UnicodeWriter:
|
||||
self.writerow(row)
|
||||
|
||||
|
||||
def _collect_key_names(nodes):
    """Return the distinct keys referenced by a parsed pystache template.

    Walks `nodes._parse_tree` and records the key of every `_EscapeNode` and
    every `_SectionNode`; the parsed body of a section is walked recursively.
    Nodes of any other type are skipped.
    """
    found = []
    for item in nodes._parse_tree:
        if isinstance(item, pystache.parser._EscapeNode):
            found.append(item.key)
            continue
        if not isinstance(item, pystache.parser._SectionNode):
            continue
        # A section contributes its own key plus everything nested inside it.
        found.append(item.key)
        found.extend(_collect_key_names(item.parsed))

    return distinct(found)
|
||||
|
||||
|
||||
def collect_query_parameters(query):
    """Parse `query` as a pystache template and return the keys it references."""
    return _collect_key_names(pystache.parse(query))
|
||||
|
||||
|
||||
def parameter_names(parameter_values):
    """Flatten a parameter mapping into the names a template may reference.

    A non-dict value contributes its key as-is. A dict value contributes one
    dotted `key.inner_key` name per inner key (and not the bare key itself).

    Returns a list of names; order follows the iteration order of the
    mappings.
    """
    names = []
    # `items()` rather than the Python 2-only `iteritems()`, so the helper
    # runs unchanged on both Python 2 and Python 3.
    for key, value in parameter_values.items():
        if isinstance(value, dict):
            names.extend(u'{}.{}'.format(key, inner_key) for inner_key in value)
        else:
            names.append(key)

    return names
|
||||
|
||||
|
||||
def find_missing_params(query_text, parameter_values):
    """Return the set of keys used in `query_text` that `parameter_values` lacks."""
    required = set(collect_query_parameters(query_text))
    supplied = set(parameter_names(parameter_values))
    return required.difference(supplied)
|
||||
|
||||
|
||||
def collect_parameters_from_request(args):
|
||||
parameters = {}
|
||||
|
||||
|
54
redash/utils/parameterized_query.py
Normal file
54
redash/utils/parameterized_query.py
Normal file
@ -0,0 +1,54 @@
|
||||
import pystache
|
||||
from redash.utils import mustache_render
|
||||
from funcy import distinct
|
||||
|
||||
|
||||
def _collect_key_names(nodes):
    """Collect the distinct keys that a parsed pystache template refers to.

    Every `_EscapeNode` and `_SectionNode` found in `nodes._parse_tree`
    contributes its key; for sections, the keys of the nested parsed template
    are gathered recursively as well. Other node types contribute nothing.
    """
    names = []
    for entry in nodes._parse_tree:
        is_variable = isinstance(entry, pystache.parser._EscapeNode)
        is_section = (not is_variable) and isinstance(entry, pystache.parser._SectionNode)

        if is_variable or is_section:
            names.append(entry.key)
        if is_section:
            names.extend(_collect_key_names(entry.parsed))

    return distinct(names)
|
||||
|
||||
|
||||
def _collect_query_parameters(query):
    """Return the keys referenced by `query` once parsed as a pystache template."""
    parsed_template = pystache.parse(query)
    return _collect_key_names(parsed_template)
|
||||
|
||||
|
||||
def _parameter_names(parameter_values):
|
||||
names = []
|
||||
for key, value in parameter_values.iteritems():
|
||||
if isinstance(value, dict):
|
||||
for inner_key in value.keys():
|
||||
names.append(u'{}.{}'.format(key, inner_key))
|
||||
else:
|
||||
names.append(key)
|
||||
|
||||
return names
|
||||
|
||||
|
||||
class ParameterizedQuery(object):
    """A query template together with the parameter values applied to it.

    Parameter values accumulate across `apply` calls; the rendered query is
    exposed through `text`, and `missing_params` reports template keys that
    no applied value covers.
    """

    def __init__(self, template):
        self.parameters = {}
        self.template = template
        # Until `apply` is called, the rendered query is the raw template.
        self.query = template

    def apply(self, parameters):
        """Merge `parameters` into the stored values, re-render, return `self`."""
        self.parameters.update(parameters)
        self.query = mustache_render(self.template, self.parameters)
        return self

    @property
    def missing_params(self):
        """Set of keys used by the template but absent from the applied values."""
        required = set(_collect_query_parameters(self.template))
        supplied = set(_parameter_names(self.parameters))
        return required - supplied

    @property
    def text(self):
        """The query as last rendered (the raw template before any `apply`)."""
        return self.query
|
@ -1,71 +0,0 @@
|
||||
import re
|
||||
|
||||
import sqlparse
|
||||
from redash.utils import mustache_render
|
||||
|
||||
|
||||
def _replace_params(template):
|
||||
return re.sub('-?{{.+?}}', 'param', template)
|
||||
|
||||
|
||||
def _inside_a_where_clause(a):
|
||||
if a is None:
|
||||
return False
|
||||
else:
|
||||
return type(a.parent) is sqlparse.sql.Where or _inside_a_where_clause(a.parent)
|
||||
|
||||
|
||||
def _populating_an_in_operator(a, b):
    """Return True when an identifier inside a WHERE became an identifier list.

    That is: `a` is exactly a `sqlparse.sql.Identifier`, `b` is exactly a
    `sqlparse.sql.IdentifierList`, and `a` sits inside a where clause.
    In every other case the result is None (falsy), not False.
    """
    # The checks run in this order so that `a.parent` is only touched once
    # `a` is known to be an Identifier.
    if type(a) is not sqlparse.sql.Identifier:
        return None
    if type(b) is not sqlparse.sql.IdentifierList:
        return None
    if _inside_a_where_clause(a):
        return True
    return None
|
||||
|
||||
|
||||
def _equivalent_leaves(a, b):
|
||||
return type(a) == type(b) or \
|
||||
(type(a) is sqlparse.sql.Identifier and type(b) is sqlparse.sql.Token)
|
||||
|
||||
|
||||
def _filter_noise(tokens):
    """Return `tokens` without those whose `ttype` is Error or Whitespace."""
    ignored_types = [sqlparse.tokens.Error, sqlparse.tokens.Whitespace]
    kept = []
    for token in tokens:
        if token.ttype in ignored_types:
            continue
        kept.append(token)
    return kept
|
||||
|
||||
|
||||
def _same_type(a, b):
    """Recursively compare the structure of two parsed trees.

    - An identifier in a where clause that turned into an identifier list is
      accepted outright.
    - Lists/tuples match when they have equal length and every zipped pair
      of children matches.
    - Objects that both expose `tokens` are compared through their tokens
      after noise filtering.
    - Anything else is compared as a pair of leaves.
    """
    if _populating_an_in_operator(a, b):
        return True

    if type(a) in (list, tuple):
        # Children are compared before the lengths are, as a full list.
        pair_results = []
        for child_a, child_b in zip(a, b):
            pair_results.append(_same_type(child_a, child_b))
        return len(a) == len(b) and all(pair_results)

    if hasattr(a, 'tokens') and hasattr(b, 'tokens'):
        return _same_type(_filter_noise(a.tokens), _filter_noise(b.tokens))

    return _equivalent_leaves(a, b)
|
||||
|
||||
|
||||
class SQLQuery(object):
    """A SQL template plus its rendered form, with a structural safety check.

    The check compares the parse tree of the template (placeholders replaced
    by a dummy word) with the parse tree of the rendered query.
    """

    def __init__(self, template):
        # Until `apply` is called, the rendered query is the raw template.
        self.query = template
        self.template = template

    def apply(self, parameters):
        """Render the template with `parameters` and return `self`."""
        rendered = mustache_render(self.template, parameters)
        self.query = rendered
        return self

    def is_safe(self):
        """Tell whether the rendered query has the same tree shape as the template."""
        return _same_type(
            sqlparse.parse(_replace_params(self.template)),
            sqlparse.parse(self.query),
        )

    @property
    def text(self):
        """The rendered query; raises `SQLInjectionError` if it is not safe."""
        if self.is_safe():
            return self.query
        raise SQLInjectionError()
|
||||
|
||||
|
||||
class SQLInjectionError(Exception):
    """Signals that a rendered query did not pass the `is_safe` check."""
|
@ -2,8 +2,7 @@ from collections import namedtuple
|
||||
from unittest import TestCase
|
||||
|
||||
from redash.utils import (build_url, collect_parameters_from_request,
|
||||
find_missing_params, filter_none,
|
||||
json_dumps, generate_token)
|
||||
filter_none, json_dumps, generate_token)
|
||||
|
||||
DummyRequest = namedtuple('DummyRequest', ['host', 'scheme'])
|
||||
|
||||
@ -25,33 +24,6 @@ class TestBuildUrl(TestCase):
|
||||
self.assertEqual("http://example.com:443/test", build_url(DummyRequest("example.com:443", "http"), "example.com", "/test"))
|
||||
|
||||
|
||||
class TestFindMissingParams(TestCase):
    """Exercises `find_missing_params` against templates and parameter dicts."""

    def test_returns_empty_list_for_regular_query(self):
        missing = find_missing_params(u"SELECT 1", {})
        self.assertEqual(set(), missing)

    def test_finds_all_params_when_missing(self):
        template = u"SELECT {{param}} FROM {{table}}"
        missing = find_missing_params(template, {})
        self.assertEqual({'param', 'table'}, missing)

    def test_finds_all_params(self):
        template = u"SELECT {{param}} FROM {{table}}"
        values = {'param': 'value', 'table': 'value'}
        self.assertEqual(set(), find_missing_params(template, values))

    def test_deduplicates_params(self):
        template = u"SELECT {{param}}, {{param}} FROM {{table}}"
        values = {'param': 'value', 'table': 'value'}
        self.assertEqual(set(), find_missing_params(template, values))

    def test_handles_nested_params(self):
        # The section key and the key nested inside it must both be reported.
        template = u"SELECT {{param}}, {{param}} FROM {{table}} -- {{#test}} {{nested_param}} {{/test}}"
        values = {'param': 'value', 'table': 'value'}
        self.assertEqual({'test', 'nested_param'}, find_missing_params(template, values))

    def test_handles_objects(self):
        # Dotted references are satisfied by the inner keys of a dict value.
        template = u"SELECT * FROM USERS WHERE created_at between '{{ created_at.start }}' and '{{ created_at.end }}'"
        values = {'created_at': {'start': 1, 'end': 2}}
        self.assertEqual(set(), find_missing_params(template, values))
|
||||
|
||||
|
||||
class TestCollectParametersFromRequest(TestCase):
|
||||
def test_ignores_non_prefixed_values(self):
|
||||
self.assertEqual({}, collect_parameters_from_request({'test': 1}))
|
||||
|
43
tests/utils/test_parameterized_query.py
Normal file
43
tests/utils/test_parameterized_query.py
Normal file
@ -0,0 +1,43 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from redash.utils.parameterized_query import ParameterizedQuery
|
||||
|
||||
|
||||
class TestParameterizedQuery(TestCase):
    """Exercises `ParameterizedQuery.missing_params` for various templates."""

    def test_returns_empty_list_for_regular_query(self):
        subject = ParameterizedQuery(u"SELECT 1")
        self.assertEqual(set(), subject.missing_params)

    def test_finds_all_params_when_missing(self):
        subject = ParameterizedQuery(u"SELECT {{param}} FROM {{table}}")
        self.assertEqual({'param', 'table'}, subject.missing_params)

    def test_finds_all_params(self):
        values = {'param': 'value', 'table': 'value'}
        subject = ParameterizedQuery(u"SELECT {{param}} FROM {{table}}").apply(values)
        self.assertEqual(set(), subject.missing_params)

    def test_deduplicates_params(self):
        values = {'param': 'value', 'table': 'value'}
        subject = ParameterizedQuery(u"SELECT {{param}}, {{param}} FROM {{table}}").apply(values)
        self.assertEqual(set(), subject.missing_params)

    def test_handles_nested_params(self):
        # The section key and the key nested inside it must both be reported.
        template = u"SELECT {{param}}, {{param}} FROM {{table}} -- {{#test}} {{nested_param}} {{/test}}"
        values = {'param': 'value', 'table': 'value'}
        subject = ParameterizedQuery(template).apply(values)
        self.assertEqual({'test', 'nested_param'}, subject.missing_params)

    def test_handles_objects(self):
        # Dotted references are satisfied by the inner keys of a dict value.
        template = u"SELECT * FROM USERS WHERE created_at between '{{ created_at.start }}' and '{{ created_at.end }}'"
        values = {'created_at': {'start': 1, 'end': 2}}
        subject = ParameterizedQuery(template).apply(values)
        self.assertEqual(set(), subject.missing_params)
|
@ -1,94 +0,0 @@
|
||||
from unittest import TestCase
|
||||
|
||||
from redash.utils.sql_query import SQLInjectionError, SQLQuery
|
||||
|
||||
|
||||
class TestSQLQuery(TestCase):
    """Exercises rendering and the `is_safe` / `text` checks of `SQLQuery`."""

    def test_serializes(self):
        template = "SELECT * FROM users WHERE userid='{{userid}}'"
        rendered = SQLQuery(template).apply({"userid": 22})

        self.assertEqual(rendered.text, "SELECT * FROM users WHERE userid='22'")

    def test_raises_when_serializing_unsafe_queries(self):
        template = "SELECT * FROM users WHERE userid={{userid}}"
        rendered = SQLQuery(template).apply({"userid": "22 OR 1==1"})

        # Reading the property is what triggers the safety check.
        with self.assertRaises(SQLInjectionError):
            rendered.text

    def test_marks_queries_without_params_as_safe(self):
        untouched = SQLQuery("SELECT * FROM users")

        self.assertTrue(untouched.is_safe())

    def test_marks_simple_queries_with_where_params_as_safe(self):
        template = "SELECT * FROM users WHERE userid='{{userid}}'"
        rendered = SQLQuery(template).apply({"userid": 22})

        self.assertTrue(rendered.is_safe())

    def test_marks_simple_queries_with_column_params_as_safe(self):
        template = "SELECT {{this_column}} FROM users"
        rendered = SQLQuery(template).apply({"this_column": "username"})

        self.assertTrue(rendered.is_safe())

    def test_marks_multiple_simple_queries_as_safe(self):
        template = "SELECT * FROM users WHERE userid='{{userid}}' ; SELECT * FROM profiles"
        rendered = SQLQuery(template).apply({"userid": 22})

        self.assertTrue(rendered.is_safe())

    def test_marks_tautologies_as_not_safe(self):
        template = "SELECT * FROM users WHERE userid={{userid}}"
        rendered = SQLQuery(template).apply({"userid": "22 OR 1==1"})

        self.assertFalse(rendered.is_safe())

    def test_marks_union_queries_as_not_safe(self):
        template = "SELECT * FROM users WHERE userid={{userid}}"
        injected = {"userid": "22 UNION SELECT body, results, 1 FROM reports"}
        rendered = SQLQuery(template).apply(injected)

        self.assertFalse(rendered.is_safe())

    def test_marks_comment_attacks_as_not_safe(self):
        template = "SELECT * FROM users WHERE username='{{username}}' AND password='{{password}}'"
        # Only `username` is supplied; its value tries to comment out the rest.
        rendered = SQLQuery(template).apply({"username": "admin' --"})

        self.assertFalse(rendered.is_safe())

    def test_marks_additional_columns_as_not_safe(self):
        template = "SELECT {{this_column}} FROM users"
        rendered = SQLQuery(template).apply({"this_column": "username, password"})

        self.assertFalse(rendered.is_safe())

    def test_marks_query_additions_as_not_safe(self):
        template = "SELECT * FROM users ORDER BY {{this_column}}"
        rendered = SQLQuery(template).apply({"this_column": "id ; DROP TABLE midgets"})

        self.assertFalse(rendered.is_safe())

    def test_marks_multiple_word_params_as_safe(self):
        template = "SELECT {{why would you do this}} FROM users"
        rendered = SQLQuery(template).apply({"why would you do this": "shrug"})

        self.assertTrue(rendered.is_safe())

    def test_marks_param_negations_as_safe(self):
        template = "SELECT date_add(some_column, INTERVAL -{{days}} DAY) FROM events"
        rendered = SQLQuery(template).apply({"days": 7})

        self.assertTrue(rendered.is_safe())
|
Loading…
Reference in New Issue
Block a user