Class for SQL metadata logic (tables, DDL and DML statements)

This commit is contained in:
Arik Fraimovich 2014-05-13 18:16:30 +03:00
parent d8d666c971
commit 16125327b1
2 changed files with 79 additions and 13 deletions

View File

@ -10,20 +10,58 @@ import sqlparse
# Matches C-style /* ... */ comments, non-greedily so that each comment is
# matched on its own. No DOTALL flag is set, so "." does not cross newlines.
# Raw string: "\*" is not a valid escape sequence in a normal string literal.
COMMENTS_REGEX = re.compile(r"/\*.*?\*/")
class SQLMetaData(object):
    """Parse a SQL string with sqlparse and expose coarse metadata about it.

    Attributes set by ``__init__``:
        sql: the SQL text exactly as given.
        parsed_sql: the result of ``sqlparse.parse(sql)``.
        has_ddl_statements: True if any token of any statement has the
            ``sqlparse.tokens.DDL`` type.
        has_non_select_dml_statements: True if any token has the
            ``sqlparse.tokens.DML`` type and is not ``SELECT``.
        used_tables: set of identifier texts that directly follow one of
            ``TABLE_SELECTION_KEYWORDS``.
    """

    # Normalized keywords after which a table identifier (or a list of them)
    # is looked for.
    TABLE_SELECTION_KEYWORDS = ('FROM', 'JOIN', 'LEFT JOIN', 'FULL JOIN', 'RIGHT JOIN', 'CROSS JOIN', 'INNER JOIN',
                                'OUTER JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN')

    def __init__(self, sql):
        """Parse ``sql`` and compute all metadata attributes eagerly."""
        self.sql = sql
        self.parsed_sql = sqlparse.parse(self.sql)
        self.has_ddl_statements = self._find_ddl_statements()
        self.has_non_select_dml_statements = self._find_dml_statements()
        self.used_tables = self._find_tables()

    def _find_ddl_statements(self):
        """Return True if any parsed statement contains a DDL token."""
        return any(
            token.ttype == sqlparse.tokens.DDL
            for statement in self.parsed_sql
            for token in statement.flatten()
        )

    def _find_tables(self):
        """Return the union of table names found in every parsed statement."""
        tables = set()
        for statement in self.parsed_sql:
            tables.update(self.extract_table_names(statement.tokens))
        return tables

    def extract_table_names(self, tokens):
        """Collect identifiers that follow a table-selection keyword.

        Whitespace and newline tokens are dropped first, so "follows" means
        the next significant token. Grouped tokens are searched recursively.

        Args:
            tokens: a sequence of sqlparse tokens (e.g. ``statement.tokens``).

        Returns:
            A set with the ``value`` of every matching identifier.
        """
        tables = set()
        significant = [t for t in tokens
                       if t.ttype not in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)]

        for position, token in enumerate(significant):
            # Groups are TokenList instances. Checking the type instead of
            # calling ``is_group()`` works both where ``is_group`` is a method
            # and where it is a plain attribute.
            if isinstance(token, sqlparse.sql.TokenList):
                tables.update(self.extract_table_names(token.tokens))
                continue

            is_table_keyword = (token.ttype == sqlparse.tokens.Keyword
                                and token.normalized in self.TABLE_SELECTION_KEYWORDS)
            if not is_table_keyword:
                continue

            # A keyword at the very end (e.g. "SELECT * FROM") has nothing
            # after it; skip it instead of indexing past the list.
            if position + 1 >= len(significant):
                continue

            following = significant[position + 1]
            # NOTE(review): ``value`` is the identifier's full text as grouped
            # by sqlparse; confirm how aliases ("t AS x") should be reported.
            if isinstance(following, sqlparse.sql.Identifier):
                tables.add(following.value)
            elif isinstance(following, sqlparse.sql.IdentifierList):
                tables.update(t.value for t in following.get_identifiers())

        return tables

    def _find_dml_statements(self):
        """Return True if any parsed statement has a DML token other than SELECT."""
        return any(
            token.ttype == sqlparse.tokens.DML and token.normalized != 'SELECT'
            for statement in self.parsed_sql
            for token in statement.flatten()
        )
def gen_query_hash(sql):
"""Returns hash of the given query after stripping all comments, line breaks and multiple
spaces, and lower casing all text.

View File

@ -0,0 +1,28 @@
from redash.utils import SQLMetaData
from unittest import TestCase
class TestSQLMetaData(TestCase):
    """Checks SQLMetaData's used_tables set and its DDL / non-SELECT DML flags."""

    def test_simple_select(self):
        # A single SELECT: one table, no DDL, no non-SELECT DML.
        metadata = SQLMetaData("SELECT t FROM test")

        self.assertEqual(metadata.used_tables, set(("test",)))
        self.assertFalse(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_multiple_select(self):
        # Two statements, a comma-separated table list and an aliased table.
        metadata = SQLMetaData("SELECT t FROM test, test2 WHERE t > 1; SELECT a, b, c FROM testing as tbl")

        self.assertEqual(metadata.used_tables, set(("test", "test2", "testing")))
        self.assertFalse(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_detects_ddl(self):
        # DROP TABLE must set the DDL flag without setting the DML flag.
        metadata = SQLMetaData("SELECT t FROM test; DROP TABLE test")

        self.assertEqual(metadata.used_tables, set(("test",)))
        self.assertTrue(metadata.has_ddl_statements)
        self.assertFalse(metadata.has_non_select_dml_statements)

    def test_detects_dml(self):
        # DELETE must set the non-SELECT DML flag without setting the DDL flag.
        metadata = SQLMetaData("SELECT t FROM test; DELETE * FROM test")

        self.assertEqual(metadata.used_tables, set(("test",)))
        self.assertFalse(metadata.has_ddl_statements)
        self.assertTrue(metadata.has_non_select_dml_statements)