From 16125327b190e6a739b4431c567836bd2b4fc89a Mon Sep 17 00:00:00 2001
From: Arik Fraimovich
Date: Tue, 13 May 2014 18:16:30 +0300
Subject: [PATCH] Class for SQL metadata logic (tables, ddl, dml statements)

---
Notes (outside the commit message):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch. Every token is unchanged, but blank context lines and
  continuation-line alignment are best guesses; if a hunk does not match the
  target tree exactly, try `git apply --ignore-whitespace`.
- Review: SQLMetaData.extract_table_names reads tokens[i + 1] without a
  bounds check, so a table-selection keyword that is the last non-whitespace
  token of its group would raise IndexError.

 redash/utils.py             | 64 +++++++++++++++++++++++++++++--------
 tests/test_sql_meta_data.py | 28 ++++++++++++++++
 2 files changed, 79 insertions(+), 13 deletions(-)
 create mode 100644 tests/test_sql_meta_data.py

diff --git a/redash/utils.py b/redash/utils.py
index 7a9f0d44..9786331d 100644
--- a/redash/utils.py
+++ b/redash/utils.py
@@ -10,19 +10,57 @@ import sqlparse
 
 COMMENTS_REGEX = re.compile("/\*.*?\*/")
 
-def extract_table_names(tokens):
-    tables = set()
-    tokens = [t for t in tokens if t.ttype not in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)]
-
-    for i in range(len(tokens)):
-        if tokens[i].is_group():
-            tables.update(extract_table_names(tokens[i].tokens))
-        else:
-            if tokens[i].ttype == sqlparse.tokens.Keyword \
-                and tokens[i].normalized in ['FROM', 'JOIN', 'LEFT JOIN', 'FULL JOIN', 'RIGHT JOIN', 'CROSS JOIN', 'INNER JOIN', 'OUTER JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN'] \
-                and isinstance(tokens[i + 1], sqlparse.sql.Identifier):
-                tables.add(tokens[i + 1].value)
-    return tables
+
+class SQLMetaData(object):
+    TABLE_SELECTION_KEYWORDS = ('FROM', 'JOIN', 'LEFT JOIN', 'FULL JOIN', 'RIGHT JOIN', 'CROSS JOIN', 'INNER JOIN',
+                                'OUTER JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN', 'FULL OUTER JOIN')
+
+    def __init__(self, sql):
+        self.sql = sql
+        self.parsed_sql = sqlparse.parse(self.sql)
+
+        self.has_ddl_statements = self._find_ddl_statements()
+        self.has_non_select_dml_statements = self._find_dml_statements()
+        self.used_tables = self._find_tables()
+
+    def _find_ddl_statements(self):
+        for statement in self.parsed_sql:
+            if len([x for x in statement.flatten() if x.ttype == sqlparse.tokens.DDL]):
+                return True
+
+        return False
+
+    def _find_tables(self):
+        tables = set()
+        for statement in self.parsed_sql:
+            tables.update(self.extract_table_names(statement.tokens))
+
+        return tables
+
+    def extract_table_names(self, tokens):
+        tables = set()
+        tokens = [t for t in tokens if t.ttype not in (sqlparse.tokens.Whitespace, sqlparse.tokens.Newline)]
+
+        for i in range(len(tokens)):
+            if tokens[i].is_group():
+                tables.update(self.extract_table_names(tokens[i].tokens))
+            else:
+                if tokens[i].ttype == sqlparse.tokens.Keyword and tokens[i].normalized in self.TABLE_SELECTION_KEYWORDS:
+                    if isinstance(tokens[i + 1], sqlparse.sql.Identifier):
+                        tables.add(tokens[i + 1].value)
+
+                    if isinstance(tokens[i + 1], sqlparse.sql.IdentifierList):
+                        tables.update(set([t.value for t in tokens[i+1].get_identifiers()]))
+        return tables
+
+    def _find_dml_statements(self):
+        for statement in self.parsed_sql:
+            for token in statement.flatten():
+                if token.ttype == sqlparse.tokens.DML and token.normalized != 'SELECT':
+                    return True
+
+        return False
+
 
 def gen_query_hash(sql):
     """Returns hash of the given query after stripping all comments, line breaks and multiple
diff --git a/tests/test_sql_meta_data.py b/tests/test_sql_meta_data.py
new file mode 100644
index 00000000..68dc40da
--- /dev/null
+++ b/tests/test_sql_meta_data.py
@@ -0,0 +1,28 @@
+from redash.utils import SQLMetaData
+from unittest import TestCase
+
+
+class TestSQLMetaData(TestCase):
+    def test_simple_select(self):
+        metadata = SQLMetaData("SELECT t FROM test")
+        self.assertEquals(metadata.used_tables, set(("test",)))
+        self.assertFalse(metadata.has_ddl_statements)
+        self.assertFalse(metadata.has_non_select_dml_statements)
+
+    def test_multiple_select(self):
+        metadata = SQLMetaData("SELECT t FROM test, test2 WHERE t > 1; SELECT a, b, c FROM testing as tbl")
+        self.assertEquals(metadata.used_tables, set(("test", "test2", "testing")))
+        self.assertFalse(metadata.has_ddl_statements)
+        self.assertFalse(metadata.has_non_select_dml_statements)
+
+    def test_detects_ddl(self):
+        metadata = SQLMetaData("SELECT t FROM test; DROP TABLE test")
+        self.assertEquals(metadata.used_tables, set(("test",)))
+        self.assertTrue(metadata.has_ddl_statements)
+        self.assertFalse(metadata.has_non_select_dml_statements)
+
+    def test_detects_dml(self):
+        metadata = SQLMetaData("SELECT t FROM test; DELETE * FROM test")
+        self.assertEquals(metadata.used_tables, set(("test",)))
+        self.assertFalse(metadata.has_ddl_statements)
+        self.assertTrue(metadata.has_non_select_dml_statements)