Merge pull request #1355 from getredash/add_user_id_to_run_query
Change: pass the user object to the run_query method
Commit 880627c69c
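In brief: the requesting user is now threaded from the HTTP handler, through the Celery task, and into every query runner's run_query method. A condensed sketch of the new call chain, assembled from the hunks below (bodies abridged; only the signatures and the user-forwarding arguments come from the diff):

    # Condensed sketch of the new call chain (bodies abridged).
    # 1. HTTP handler: enqueue with the requesting user's id.
    def run_query(data_source, parameter_values, query_text, query_id, max_age=0):
        job = enqueue_query(query_text, data_source, current_user.id,
                            metadata={"Username": current_user.email, "Query ID": query_id})
        return {'job': job.to_dict()}

    # 2. Task layer: append user_id to the celery args tuple (last, after metadata).
    def enqueue_query(query, data_source, user_id, scheduled=False, metadata={}):
        return execute_query.apply_async(args=(query, data_source.id, metadata, user_id))

    # 3. Worker: resolve the id back into a User and hand it to the runner.
    def execute_query(self, query, data_source_id, metadata, user_id=None):
        return QueryExecutor(self, query, data_source_id, user_id, metadata).run()
        # ...which ultimately calls: query_runner.run_query(annotated_query, self.user)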
@@ -45,7 +45,7 @@ def run_query(data_source, parameter_values, query_text, query_id, max_age=0):
     if query_result:
         return {'query_result': query_result.to_dict()}
     else:
-        job = enqueue_query(query_text, data_source, metadata={"Username": current_user.name, "Query ID": query_id})
+        job = enqueue_query(query_text, data_source, current_user.id, metadata={"Username": current_user.email, "Query ID": query_id})
         return {'job': job.to_dict()}
 
 
@@ -45,6 +45,7 @@ class InterruptException(Exception):
 
 class BaseQueryRunner(object):
+    noop_query = None
 
     def __init__(self, configuration):
         self.syntax = 'sql'
         self.configuration = configuration
@@ -74,7 +75,7 @@ class BaseQueryRunner(object):
             raise NotImplementedError()
         self.run_query(self.noop_query)
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         raise NotImplementedError()
 
     def fetch_columns(self, columns):
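With the base class's contract changed, every runner below updates its override to match. A minimal conforming runner, purely illustrative (EchoRunner is not part of this diff); the (data, error) return shape is inferred from the `data, error = query_runner.run_query(...)` call further down:

    import json

    class EchoRunner(BaseQueryRunner):  # hypothetical example, not in the codebase
        noop_query = "noop"

        def run_query(self, query, user):
            # `user` is the executing models.User, or None (e.g. for jobs enqueued
            # before the upgrade). Most runners simply ignore it.
            data = {'columns': [{'name': 'echo', 'type': 'string'}],
                    'rows': [{'echo': query}]}
            return json.dumps(data), None  # (data, error); payload encoding follows each runner's own convention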
@@ -215,7 +215,7 @@ class BigQuery(BaseQueryRunner):
 
         return schema
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         logger.debug("BigQuery got query: %s", query)
 
         bigquery_service = self._get_bigquery_service()
@@ -60,7 +60,7 @@ class Cassandra(BaseQueryRunner):
         results = self.run_query(query)
         return results, error
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         from cassandra.cluster import Cluster
         connection = None
         try:
@@ -107,7 +107,7 @@ class DynamoDBSQL(BaseSQLQueryRunner):
             logging.exception(e)
             raise sys.exc_info()[1], None, sys.exc_info()[2]
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = None
         try:
             engine, connection = self._connect()
@@ -262,7 +262,7 @@ class Kibana(BaseElasticSearch):
 
         return raw_result["hits"]["total"]
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         try:
             error = None
 
@@ -335,7 +335,7 @@ class ElasticSearch(BaseElasticSearch):
     def annotate_query(cls):
         return False
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         try:
             error = None
 
@@ -171,7 +171,7 @@ class GoogleSpreadsheet(BaseQueryRunner):
     def test_connection(self):
         self._get_spreadsheet_service()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         logger.debug("Spreadsheet is about to execute query: %s", query)
         values = query.split("|")
         key = values[0]  # key of the spreadsheet
@@ -69,7 +69,7 @@ class Graphite(BaseQueryRunner):
             raise Exception("Connection test failed. Return Code: {0}"
                             " Reason: {1}".format(r.status_code, r.text))
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         url = "%s%s" % (self.base_url, "&".join(query.split("\n")))
         error = None
         data = None
@@ -92,7 +92,7 @@ class Hive(BaseSQLQueryRunner):
             raise sys.exc_info()[1], None, sys.exc_info()[2]
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
 
         connection = None
         try:
@@ -95,7 +95,7 @@ class Impala(BaseSQLQueryRunner):
 
         return schema_dict.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
 
         connection = None
         try:
@@ -78,7 +78,7 @@ class InfluxDB(BaseQueryRunner):
     def __init__(self, configuration):
         super(InfluxDB, self).__init__(configuration)
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         client = InfluxDBClusterClient.from_DSN(self.configuration['url'])
 
         logger.debug("influxdb url: %s", self.configuration['url'])
@@ -101,11 +101,11 @@ class JiraJQL(BaseQueryRunner):
         super(JiraJQL, self).__init__(configuration)
         self.syntax = 'json'
 
-    def run_query(self, query_string):
+    def run_query(self, query, user):
         jql_url = '{}/rest/api/2/search'.format(self.configuration["url"])
 
         try:
-            query = json.loads(query_string)
+            query = json.loads(query)
             query_type = query.pop('queryType', 'select')
 
             if query_type == 'count':
@@ -161,7 +161,7 @@ class MongoDB(BaseQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         db = self._get_db()
 
         logger.debug("mongodb connection string: %s", self.configuration['connectionString'])
@@ -50,7 +50,7 @@ class MQL(BaseQueryRunner):
         if not conn.command("connectionStatus")["ok"]:
             raise Exception("MongoDB connection error")
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         conn = pymongo.MongoClient(self.configuration['uri'])
         # execute() returns a generator (that wraps a cursor)
         gen = query_to_plan(query).execute(conn)
@@ -121,7 +121,7 @@ class SqlServer(BaseSQLQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = None
 
         try:
@@ -121,7 +121,7 @@ class Mysql(BaseSQLQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         import MySQLdb
 
         connection = None
@@ -138,7 +138,7 @@ class Oracle(BaseSQLQueryRunner):
         if scale <= 0:
             return cursor.var(cx_Oracle.STRING, 255, outconverter=Oracle._convert_number, arraysize=cursor.arraysize)
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = cx_Oracle.connect(self.connection_string)
         connection.outputtypehandler = Oracle.output_handler
 
@@ -115,7 +115,7 @@ class PostgreSQL(BaseSQLQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = psycopg2.connect(self.connection_string, async=True)
         _wait(connection, timeout=10)
 
@@ -96,7 +96,7 @@ class Presto(BaseQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = presto.connect(
             host=self.configuration.get('host', ''),
             port=self.configuration.get('port', 8080),
@@ -182,7 +182,7 @@ class Python(BaseQueryRunner):
     def test_connection(self):
         pass
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         try:
             error = None
 
@@ -37,7 +37,7 @@ class Script(BaseQueryRunner):
     def test_connection(self):
         pass
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         try:
             json_data = None
             error = None
@@ -59,7 +59,7 @@ class Sqlite(BaseSQLQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = sqlite3.connect(self._dbpath)
 
         cursor = connection.cursor()
@@ -93,7 +93,7 @@ class TreasureData(BaseQueryRunner):
             raise Exception("Failed getting schema")
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         connection = tdclient.connect(
             endpoint=self.configuration.get('endpoint', 'https://api.treasuredata.com'),
             apikey=self.configuration.get('apikey'),
@@ -23,7 +23,7 @@ class Url(BaseQueryRunner):
     def annotate_query(cls):
         return False
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         base_url = self.configuration.get("url", None)
 
         try:
@@ -93,7 +93,7 @@ class Vertica(BaseSQLQueryRunner):
 
         return schema.values()
 
-    def run_query(self, query):
+    def run_query(self, query, user):
         import vertica_python
 
         if query == "":
@@ -198,7 +198,7 @@ class QueryTask(object):
         return self._async_result.revoke(terminate=True, signal='SIGINT')
 
 
-def enqueue_query(query, data_source, scheduled=False, metadata={}):
+def enqueue_query(query, data_source, user_id, scheduled=False, metadata={}):
     query_hash = gen_query_hash(query)
     logging.info("Inserting job for %s with metadata=%s", query_hash, metadata)
     try_count = 0
@@ -229,7 +229,7 @@ def enqueue_query(query, data_source, scheduled=False, metadata={}):
             else:
                 queue_name = data_source.queue_name
 
-            result = execute_query.apply_async(args=(query, data_source.id, metadata), queue=queue_name)
+            result = execute_query.apply_async(args=(query, data_source.id, metadata, user_id), queue=queue_name)
             job = QueryTask(async_result=result)
             tracker = QueryTaskTracker.create(result.id, 'created', query_hash, data_source.id, scheduled, metadata)
             tracker.save(connection=pipe)
@@ -262,7 +262,7 @@ def refresh_queries():
         elif query.data_source.paused:
             logging.info("Skipping refresh of %s because datasource - %s is paused (%s).", query.id, query.data_source.name, query.data_source.pause_reason)
         else:
-            enqueue_query(query.query, query.data_source,
+            enqueue_query(query.query, query.data_source, query.user_id,
                           scheduled=True,
                           metadata={'Query ID': query.id, 'Username': 'Scheduled'})
 
@@ -378,12 +378,13 @@ class QueryExecutionError(Exception):
 # We could have created this as a celery.Task derived class, and act as the task itself. But this might result in weird
 # issues as the task class created once per process, so decided to have a plain object instead.
 class QueryExecutor(object):
-    def __init__(self, task, query, data_source_id, metadata):
+    def __init__(self, task, query, data_source_id, user_id, metadata):
         self.task = task
         self.query = query
         self.data_source_id = data_source_id
         self.metadata = metadata
         self.data_source = self._load_data_source()
+        self.user = models.User.get_by_id(user_id)
         self.query_hash = gen_query_hash(self.query)
         # Load existing tracker or create a new one if the job was created before code update:
         self.tracker = QueryTaskTracker.get_by_task_id(task.request.id) or QueryTaskTracker.create(task.request.id,
@@ -403,7 +404,7 @@ class QueryExecutor(object):
         annotated_query = self._annotate_query(query_runner)
 
         try:
-            data, error = query_runner.run_query(annotated_query)
+            data, error = query_runner.run_query(annotated_query, self.user)
         except Exception as e:
             error = unicode(e)
             data = None
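Nothing in this diff consumes the new argument yet; it only threads the user through to the runners. A hypothetical runner could use it for attribution, for example:

    class AuditingRunner(BaseQueryRunner):  # hypothetical, not in the diff
        def run_query(self, query, user):
            # `user` may be None for pre-upgrade jobs, so guard the attribute access.
            username = user.email if user is not None else 'unknown'
            audited_query = '/* redash user: {} */ {}'.format(username, query)
            ...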
@@ -456,6 +457,8 @@ class QueryExecutor(object):
         return models.DataSource.get_by_id(self.data_source_id)
 
 
+# user_id is added last as a keyword argument for backward compatability -- to support executing previously submitted
+# jobs before the upgrade to this version.
 @celery.task(name="redash.tasks.execute_query", bind=True, base=BaseTask, track_started=True)
-def execute_query(self, query, data_source_id, metadata):
-    return QueryExecutor(self, query, data_source_id, metadata).run()
+def execute_query(self, query, data_source_id, metadata, user_id=None):
+    return QueryExecutor(self, query, data_source_id, user_id, metadata).run()
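The comment in the hunk above explains the ordering: because user_id rides last in the args tuple and defaults to None, argument tuples serialized by the old code still bind cleanly against the new task signature. A toy illustration (names hypothetical):

    def execute_query_new(query, data_source_id, metadata, user_id=None):
        return user_id

    # A job enqueued before the upgrade carries only three positional args...
    old_args = ("SELECT 1", 42, {"Username": "Scheduled"})
    # ...and still binds cleanly: user_id falls back to None.
    assert execute_query_new(*old_args) is None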
@@ -19,7 +19,7 @@ class TestRefreshQueries(BaseTestCase):
 
         with patch('redash.tasks.queries.enqueue_query') as add_job_mock:
             refresh_queries()
-            add_job_mock.assert_called_with(query.query, query.data_source, scheduled=True, metadata=ANY)
+            add_job_mock.assert_called_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)
 
     def test_doesnt_enqueue_outdated_queries_for_paused_data_source(self):
         query = self.factory.create_query(schedule="60")
@@ -39,7 +39,7 @@ class TestRefreshQueries(BaseTestCase):
 
         with patch('redash.tasks.queries.enqueue_query') as add_job_mock:
             refresh_queries()
-            add_job_mock.assert_called_with(query.query, query.data_source, scheduled=True, metadata=ANY)
+            add_job_mock.assert_called_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)
 
     def test_skips_fresh_queries(self):
         query = self.factory.create_query(schedule="1200")
@@ -74,7 +74,7 @@ class TestRefreshQueries(BaseTestCase):
 
         with patch('redash.tasks.queries.enqueue_query') as add_job_mock:
             refresh_queries()
-            add_job_mock.assert_called_once_with(query.query, query.data_source, scheduled=True, metadata=ANY)#{'Query ID': query.id, 'Username': 'Scheduled'})
+            add_job_mock.assert_called_once_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)#{'Query ID': query.id, 'Username': 'Scheduled'})
 
     def test_enqueues_query_with_correct_data_source(self):
         query = self.factory.create_query(schedule="60", data_source=self.factory.create_data_source())
@@ -89,8 +89,8 @@ class TestRefreshQueries(BaseTestCase):
 
         with patch('redash.tasks.queries.enqueue_query') as add_job_mock:
             refresh_queries()
-            add_job_mock.assert_has_calls([call(query2.query, query2.data_source, scheduled=True, metadata=ANY),
-                                           call(query.query, query.data_source, scheduled=True, metadata=ANY)],
+            add_job_mock.assert_has_calls([call(query2.query, query2.data_source, query2.user_id, scheduled=True, metadata=ANY),
+                                           call(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)],
                                           any_order=True)
             self.assertEquals(2, add_job_mock.call_count)
 
@@ -108,4 +108,4 @@ class TestRefreshQueries(BaseTestCase):
 
         with patch('redash.tasks.queries.enqueue_query') as add_job_mock:
             refresh_queries()
-            add_job_mock.assert_called_once_with(query.query, query.data_source, scheduled=True, metadata=ANY)
+            add_job_mock.assert_called_once_with(query.query, query.data_source, query.user_id, scheduled=True, metadata=ANY)