mirror of
https://github.com/valitydev/redash.git
synced 2024-11-07 01:25:16 +00:00
Class for SQL metadata logic (tables, ddl, dml statements)
This commit is contained in:
parent
d8d666c971
commit
16125327b1
@ -10,19 +10,57 @@ import sqlparse
|
||||
|
||||
COMMENTS_REGEX = re.compile("/\*.*?\*/")
|
||||
|
||||
def extract_table_names(tokens):
    """Collect the table references that follow FROM/JOIN keywords.

    Walks a sequence of sqlparse tokens, recursing into grouped tokens, and
    records the text of whatever identifier (or comma-separated identifier
    list) comes directly after a table-selecting keyword.

    Args:
        tokens: sequence of sqlparse tokens, e.g. ``statement.tokens``.

    Returns:
        A set with the ``value`` of every identifier found.
    """
    table_keywords = (
        'FROM', 'JOIN', 'LEFT JOIN', 'FULL JOIN', 'RIGHT JOIN', 'CROSS JOIN',
        'INNER JOIN', 'OUTER JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN',
        'FULL OUTER JOIN',
    )
    tables = set()
    # Drop whitespace/newline tokens so that "the next token" after a keyword
    # is the next meaningful one.
    tokens = [t for t in tokens if t.ttype not in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)]

    for index, token in enumerate(tokens):
        # NOTE(review): is_group is invoked as a method here; verify that this
        # matches the sqlparse version the project pins.
        if token.is_group():
            tables.update(extract_table_names(token.tokens))
            continue

        if token.ttype == sqlparse.tokens.Keyword and token.normalized in table_keywords:
            # A keyword that closes the token sequence (e.g. an unfinished
            # "SELECT a FROM") has no successor; indexing past the end used
            # to raise IndexError.
            if index + 1 >= len(tokens):
                continue

            following = tokens[index + 1]
            if isinstance(following, sqlparse.sql.Identifier):
                tables.add(following.value)
            elif isinstance(following, sqlparse.sql.IdentifierList):
                # "FROM a, b" is grouped as one IdentifierList.
                tables.update(t.value for t in following.get_identifiers())

    return tables
|
||||
|
||||
class SQLMetaData(object):
    """Metadata derived from a SQL string by parsing it with sqlparse.

    The constructor parses the text once and stores:
        sql: the SQL text as given.
        parsed_sql: the result of ``sqlparse.parse(sql)``.
        has_ddl_statements: True when any token is typed as DDL.
        has_non_select_dml_statements: True when any DML token other than
            SELECT is present.
        used_tables: set of identifier texts that follow FROM/JOIN keywords.
    """

    # Keywords, compared against each token's ``normalized`` text, after which
    # a table reference is looked for.
    TABLE_SELECTION_KEYWORDS = ('FROM', 'JOIN', 'LEFT JOIN', 'FULL JOIN', 'RIGHT JOIN', 'CROSS JOIN', 'INNER JOIN',
                                'OUTER JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN')

    def __init__(self, sql):
        """Parse ``sql`` and compute all metadata attributes."""
        self.sql = sql
        self.parsed_sql = sqlparse.parse(self.sql)

        self.has_ddl_statements = self._find_ddl_statements()
        self.has_non_select_dml_statements = self._find_dml_statements()
        self.used_tables = self._find_tables()

    def _find_ddl_statements(self):
        """Return True if any token of any parsed statement is typed as DDL."""
        return any(
            token.ttype == sqlparse.tokens.DDL
            for statement in self.parsed_sql
            for token in statement.flatten()
        )

    def _find_tables(self):
        """Return the union of table references found in all parsed statements."""
        tables = set()
        for statement in self.parsed_sql:
            tables.update(self.extract_table_names(statement.tokens))

        return tables

    def extract_table_names(self, tokens):
        """Collect identifiers that directly follow a table-selecting keyword.

        Args:
            tokens: sequence of sqlparse tokens; grouped tokens are searched
                recursively.

        Returns:
            A set with the ``value`` of every identifier found.
        """
        tables = set()
        # Drop whitespace/newline tokens so that "the next token" after a
        # keyword is the next meaningful one.
        tokens = [t for t in tokens if t.ttype not in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)]

        for index, token in enumerate(tokens):
            # NOTE(review): is_group is invoked as a method here; verify that
            # this matches the sqlparse version the project pins.
            if token.is_group():
                tables.update(self.extract_table_names(token.tokens))
                continue

            if token.ttype == sqlparse.tokens.Keyword and token.normalized in self.TABLE_SELECTION_KEYWORDS:
                # A keyword that closes the token sequence (e.g. an unfinished
                # "SELECT a FROM") has no successor; indexing past the end
                # used to raise IndexError.
                if index + 1 >= len(tokens):
                    continue

                following = tokens[index + 1]
                # NOTE(review): ``value`` is the identifier's full text; if
                # sqlparse folds an alias into the Identifier the alias would
                # be included - confirm against the expected table names.
                if isinstance(following, sqlparse.sql.Identifier):
                    tables.add(following.value)
                elif isinstance(following, sqlparse.sql.IdentifierList):
                    # "FROM a, b" is grouped as one IdentifierList.
                    tables.update(t.value for t in following.get_identifiers())

        return tables

    def _find_dml_statements(self):
        """Return True if any DML token other than SELECT is present."""
        return any(
            token.ttype == sqlparse.tokens.DML and token.normalized != 'SELECT'
            for statement in self.parsed_sql
            for token in statement.flatten()
        )
|
||||
|
||||
|
||||
def gen_query_hash(sql):
|
||||
"""Returns hash of the given query after stripping all comments, line breaks and multiple
|
||||
|
28
tests/test_sql_meta_data.py
Normal file
28
tests/test_sql_meta_data.py
Normal file
@ -0,0 +1,28 @@
|
||||
from redash.utils import SQLMetaData
|
||||
from unittest import TestCase
|
||||
|
||||
|
||||
class TestSQLMetaData(TestCase):
    """Checks SQLMetaData's table extraction and DDL/DML detection."""

    def test_simple_select(self):
        metadata = SQLMetaData("SELECT t FROM test")
        # assertEqual replaces the deprecated assertEquals alias, which newer
        # Python versions no longer provide.
        self.assertEqual(metadata.used_tables, {"test"})
        self.assertFalse(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_multiple_select(self):
        # Covers a comma-separated table list, an aliased table and more than
        # one statement in the same string.
        metadata = SQLMetaData("SELECT t FROM test, test2 WHERE t > 1; SELECT a, b, c FROM testing as tbl")
        self.assertEqual(metadata.used_tables, {"test", "test2", "testing"})
        self.assertFalse(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_detects_ddl(self):
        metadata = SQLMetaData("SELECT t FROM test; DROP TABLE test")
        self.assertEqual(metadata.used_tables, {"test"})
        self.assertTrue(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_detects_dml(self):
        metadata = SQLMetaData("SELECT t FROM test; DELETE * FROM test")
        self.assertEqual(metadata.used_tables, {"test"})
        self.assertFalse(metadata.has_ddl_statements)
        self.assertTrue(metadata.has_non_select_dml_statements)
|
Loading…
Reference in New Issue
Block a user