From 44dff83046fd0f0a3666cbb262ea7b7ac8eeefb8 Mon Sep 17 00:00:00 2001 From: Jannis Leidel Date: Mon, 7 Jan 2019 09:30:42 +0100 Subject: [PATCH] Add "Active at" column to user list. (#3026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add last_active_at to users page * Use our JSON encoder as the SQLAlchemy JSON serializer. * Fixed some inconsistencies in the user query class methods. * Minor cosmetic fixes. * Add some make tasks for easier development. * Add user detail sync system based on Redis backend. There is a periodic Celery task that updates a new “details” JSONB column in the “user” table with the data from Redis. Currently this is only used for tracking the date of last activity of a user but can be extended with other user information later. Updates a few dependencies. * Normalize a few Flask extension API names. * Reduce implementation complexity of JSONEncoder. * Use request_started signal to make sure we have a request context. Otherwise loading the user based on the request won’t work. * Fix test that checks if disabled users can login. This correctly uses a URL path that includes the current organization and checks for the error message. The previous test seems to have been a red herring. * Minor cosmetic fixes. * Remove needs_sync in favor of just deleting things. * Misc review fixes. * Ignore line length. * Split redash.models import several modules. * Move walrus UTC DateTimeField into redash.models.types. * Restore distinctly loading dashboards. * Simplify default values for user details. * Define __repr__ methods generically. * Consistently have underscore methods at the top of model methods. * Fix tests. * Split redash.models import several modules. * Update to latest walrus and redis-py. * Update kombu to 4.2.2 for redis-py 3.x compatibility. * Remove redis-cli container after running Make task. * Move buffer condition after datetime/time conditions. * Update walrus to 0.7.1. 
* Refactor some query APIs. This uses the flask-sqlalchemy helpers consistently and makes more use of mixins. * Post rebase fixes. * Use correct kombu version * Fix migration down revision --- Makefile | 8 +- client/app/pages/users/list.html | 7 + ...f8a917aa8e_add_user_details_json_column.py | 24 + redash/__init__.py | 21 +- redash/authentication/__init__.py | 4 +- redash/handlers/organization.py | 2 +- redash/handlers/users.py | 2 + redash/{models.py => models/__init__.py} | 864 +++--------------- redash/models/base.py | 85 ++ redash/models/changes.py | 86 ++ redash/models/mixins.py | 28 + redash/models/organizations.py | 82 ++ redash/models/types.py | 106 +++ redash/models/users.py | 448 +++++++++ redash/monitor.py | 6 +- redash/tasks/__init__.py | 4 +- redash/tasks/general.py | 13 + redash/tasks/queries.py | 8 +- redash/utils/__init__.py | 28 +- redash/worker.py | 4 + requirements.txt | 10 +- setup.cfg | 5 + tests/handlers/test_users.py | 9 +- tests/handlers/test_visualizations.py | 2 +- tests/models/test_query_results.py | 4 +- tests/tasks/test_queries.py | 2 +- tests/test_models.py | 112 ++- 27 files changed, 1201 insertions(+), 773 deletions(-) create mode 100644 migrations/versions/e7f8a917aa8e_add_user_details_json_column.py rename redash/{models.py => models/__init__.py} (61%) create mode 100644 redash/models/base.py create mode 100644 redash/models/changes.py create mode 100644 redash/models/mixins.py create mode 100644 redash/models/organizations.py create mode 100644 redash/models/types.py create mode 100644 redash/models/users.py diff --git a/Makefile b/Makefile index d13ecacf..19a146fa 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: compose_build up test_db create_database clean down bundle tests lint backend-unit-tests frontend-unit-tests test build watch start +.PHONY: compose_build up test_db create_database clean down bundle tests lint backend-unit-tests frontend-unit-tests test build watch start redis-cli bash compose_build: 
docker-compose build @@ -49,3 +49,9 @@ watch: bundle start: bundle npm run start + +redis-cli: + docker-compose run --rm redis redis-cli -h redis + +bash: + docker-compose run --rm server bash diff --git a/client/app/pages/users/list.html b/client/app/pages/users/list.html index bef079d7..dcf28a7d 100644 --- a/client/app/pages/users/list.html +++ b/client/app/pages/users/list.html @@ -46,6 +46,10 @@ Joined + + Last Active At + + @@ -62,6 +66,9 @@ + + +
diff --git a/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py b/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py new file mode 100644 index 00000000..cd47bb77 --- /dev/null +++ b/migrations/versions/e7f8a917aa8e_add_user_details_json_column.py @@ -0,0 +1,24 @@ +"""Add user details JSON column. + +Revision ID: e7f8a917aa8e +Revises: 640888ce445d +Create Date: 2018-11-08 16:12:17.023569 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'e7f8a917aa8e' +down_revision = '640888ce445d' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('users', sa.Column('details', postgresql.JSON(astext_type=sa.Text()), server_default='{}', nullable=True)) + + +def downgrade(): + op.drop_column('users', 'details') diff --git a/redash/__init__.py b/redash/__init__.py index 94255090..dd2f33c6 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -2,7 +2,8 @@ import sys import logging import urlparse import urllib -import redis + +import walrus from flask import Flask, current_app from flask_sslify import SSLify from werkzeug.contrib.fixers import ProxyFix @@ -47,7 +48,7 @@ def create_redis_connection(): else: db = 0 - r = redis.StrictRedis(unix_socket_path=redis_url.path, db=db) + client = walrus.Database(unix_socket_path=redis_url.path, db=db) else: if redis_url.path: redis_db = redis_url.path[1] @@ -55,13 +56,14 @@ def create_redis_connection(): redis_db = 0 # Redis passwords might be quoted with special characters redis_password = redis_url.password and urllib.unquote(redis_url.password) - r = redis.StrictRedis(host=redis_url.hostname, port=redis_url.port, db=redis_db, password=redis_password) + client = walrus.Database(host=redis_url.hostname, port=redis_url.port, db=redis_db, password=redis_password) - return r + return client setup_logging() redis_connection = create_redis_connection() + mail = Mail() migrate = 
Migrate() mail.init_mail(settings.all_settings()) @@ -90,12 +92,10 @@ class SlugConverter(BaseConverter): def create_app(load_admin=True): - from redash import extensions, handlers + from redash import admin, authentication, extensions, handlers from redash.handlers.webpack import configure_webpack from redash.handlers import chrome_logger - from redash.admin import init_admin - from redash.models import db - from redash.authentication import setup_authentication + from redash.models import db, users from redash.metrics.request import provision_app app = Flask(__name__, @@ -131,14 +131,15 @@ def create_app(load_admin=True): db.init_app(app) migrate.init_app(app, db) if load_admin: - init_admin(app) + admin.init_admin(app) mail.init_app(app) - setup_authentication(app) + authentication.init_app(app) limiter.init_app(app) handlers.init_app(app) configure_webpack(app) extensions.init_extensions(app) chrome_logger.init_app(app) + users.init_app(app) return app diff --git a/redash/authentication/__init__.py b/redash/authentication/__init__.py index c507cdf6..24bc50c0 100644 --- a/redash/authentication/__init__.py +++ b/redash/authentication/__init__.py @@ -105,7 +105,7 @@ def hmac_load_user_from_request(request): return user if query_id: - query = models.db.session.query(models.Query).filter(models.Query.id == query_id).one() + query = models.Query.query.filter(models.Query.id == query_id).one() calculated_signature = sign(query.api_key, request.path, expires) if query.api_key and signature == calculated_signature: @@ -238,7 +238,7 @@ def logout_and_redirect_to_index(): return redirect(index_url) -def setup_authentication(app): +def init_app(app): from redash.authentication import google_oauth, saml_auth, remote_user_auth, ldap_auth login_manager.init_app(app) diff --git a/redash/handlers/organization.py b/redash/handlers/organization.py index eb81c773..384f4018 100644 --- a/redash/handlers/organization.py +++ b/redash/handlers/organization.py @@ -10,7 +10,7 @@ from 
redash.authentication import current_org @login_required def organization_status(org_slug=None): counters = { - 'users': models.User.all_not_disabled(current_org).count(), + 'users': models.User.all(current_org).count(), 'alerts': models.Alert.all(group_ids=current_user.group_ids).count(), 'data_sources': models.DataSource.all(current_org, group_ids=current_user.group_ids).count(), 'queries': models.Query.all_queries(current_user.group_ids, current_user.id, drafts=True).count(), diff --git a/redash/handlers/users.py b/redash/handlers/users.py index 1ab92dcf..ae3b9b88 100644 --- a/redash/handlers/users.py +++ b/redash/handlers/users.py @@ -20,6 +20,8 @@ from redash.authentication.account import invite_link_for_user, send_invite_emai order_map = { 'name': 'name', '-name': '-name', + 'active_at': 'active_at', + '-active_at': '-active_at', 'created_at': 'created_at', '-created_at': '-created_at', 'groups': 'group_ids', diff --git a/redash/models.py b/redash/models/__init__.py similarity index 61% rename from redash/models.py rename to redash/models/__init__.py index 0f10190f..3ca5a32c 100644 --- a/redash/models.py +++ b/redash/models/__init__.py @@ -10,14 +10,20 @@ import time import pytz from functools import reduce -from six import python_2_unicode_compatible, string_types, text_type import xlsxwriter -from flask import current_app as app, url_for -from flask_login import AnonymousUserMixin, UserMixin -from flask_sqlalchemy import SQLAlchemy, BaseQuery -from passlib.apps import custom_app_context as pwd_context +from six import python_2_unicode_compatible, text_type +from sqlalchemy import distinct, or_, and_, UniqueConstraint +from sqlalchemy.dialects import postgresql +from sqlalchemy.event import listens_for +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import backref, contains_eager, joinedload, subqueryload, load_only +from sqlalchemy.orm.exc import NoResultFound # noqa: F401 +from sqlalchemy import func +from sqlalchemy_utils import 
generic_relationship +from sqlalchemy_utils.types import TSVectorType +from sqlalchemy_utils.models import generic_repr -from redash import settings, redis_connection, utils +from redash import redis_connection, utils from redash.destinations import (get_configuration_schema_for_destination_type, get_destination) from redash.metrics import database # noqa: F401 @@ -25,53 +31,16 @@ from redash.query_runner import (get_configuration_schema_for_query_runner_type, get_query_runner) from redash.utils import generate_token, json_dumps, json_loads from redash.utils.configuration import ConfigurationContainer -from redash.settings.organization import settings as org_settings -from sqlalchemy import distinct, or_, and_, UniqueConstraint -from sqlalchemy.dialects import postgresql -from sqlalchemy.event import listens_for -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.ext.mutable import Mutable -from sqlalchemy.inspection import inspect -from sqlalchemy.orm import backref, contains_eager, joinedload, object_session, load_only -from sqlalchemy.orm.exc import NoResultFound # noqa: F401 -from sqlalchemy.types import TypeDecorator -from sqlalchemy.orm.attributes import flag_modified -from sqlalchemy import func -from sqlalchemy_searchable import SearchQueryMixin, make_searchable, vectorizer -from sqlalchemy_utils import generic_relationship, EmailType -from sqlalchemy_utils.types import TSVectorType +from .base import db, gfk_type, Column, GFKBase, SearchBaseQuery +from .changes import ChangeTrackingMixin, Change # noqa +from .mixins import BelongsToOrgMixin, TimestampMixin +from .organizations import Organization +from .types import Configuration, MutableDict, MutableList, PseudoJSON +from .users import (AccessPermission, AnonymousUser, ApiUser, Group, User, + UserDetail) # noqa - -class SQLAlchemyExt(SQLAlchemy): - def apply_pool_defaults(self, app, options): - if settings.SQLALCHEMY_DISABLE_POOL: - from sqlalchemy.pool import NullPool - 
options['poolclass'] = NullPool - else: - return super(SQLAlchemyExt, self).apply_pool_defaults(app, options) - - -db = SQLAlchemyExt(session_options={ - 'expire_on_commit': False -}) -# Make sure the SQLAlchemy mappers are all properly configured first. -# This is required by SQLAlchemy-Searchable as it adds DDL listeners -# on the configuration phase of models. -db.configure_mappers() - -# listen to a few database events to set up functions, trigger updates -# and indexes for the full text search -make_searchable(options={'regconfig': 'pg_catalog.simple'}) - - -class SearchBaseQuery(BaseQuery, SearchQueryMixin): - """ - The SQA query class to use when full text search is wanted. - """ - - -Column = functools.partial(db.Column, nullable=False) +logger = logging.getLogger(__name__) class ScheduledQueriesExecutions(object): @@ -98,485 +67,9 @@ class ScheduledQueriesExecutions(object): scheduled_queries_executions = ScheduledQueriesExecutions() -# AccessPermission and Change use a 'generic foreign key' approach to refer to -# either queries or dashboards. -# TODO replace this with association tables. -_gfk_types = {} - - -class GFKBase(object): - """ - Compatibility with 'generic foreign key' approach Peewee used. - """ - # XXX Replace this with table-per-association. 
- object_type = Column(db.String(255)) - object_id = Column(db.Integer) - - _object = None - - @property - def object(self): - session = object_session(self) - if self._object or not session: - return self._object - else: - object_class = _gfk_types[self.object_type] - self._object = session.query(object_class).filter( - object_class.id == self.object_id).first() - return self._object - - @object.setter - def object(self, value): - self._object = value - self.object_type = value.__class__.__tablename__ - self.object_id = value.id - - -# XXX replace PseudoJSON and MutableDict with real JSON field -class PseudoJSON(TypeDecorator): - impl = db.Text - - def process_bind_param(self, value, dialect): - return json_dumps(value) - - def process_result_value(self, value, dialect): - if not value: - return value - return json_loads(value) - - -class MutableDict(Mutable, dict): - @classmethod - def coerce(cls, key, value): - "Convert plain dictionaries to MutableDict." - - if not isinstance(value, MutableDict): - if isinstance(value, dict): - return MutableDict(value) - - # this call will raise ValueError - return Mutable.coerce(key, value) - else: - return value - - def __setitem__(self, key, value): - "Detect dictionary set events and emit change events." - - dict.__setitem__(self, key, value) - self.changed() - - def __delitem__(self, key): - "Detect dictionary del events and emit change events." 
- - dict.__delitem__(self, key) - self.changed() - - -class MutableList(Mutable, list): - def append(self, value): - list.append(self, value) - self.changed() - - def remove(self, value): - list.remove(self, value) - self.changed() - - @classmethod - def coerce(cls, key, value): - if not isinstance(value, MutableList): - if isinstance(value, list): - return MutableList(value) - return Mutable.coerce(key, value) - else: - return value - - -class TimestampMixin(object): - updated_at = Column(db.DateTime(True), default=db.func.now(), nullable=False) - created_at = Column(db.DateTime(True), default=db.func.now(), nullable=False) - - -@listens_for(TimestampMixin, 'before_update', propagate=True) -def timestamp_before_update(mapper, connection, target): - # Check if we really want to update the updated_at value - if hasattr(target, 'skip_updated_at'): - return - - target.updated_at = db.func.now() - - -class ChangeTrackingMixin(object): - skipped_fields = ('id', 'created_at', 'updated_at', 'version') - _clean_values = None - - def __init__(self, *a, **kw): - super(ChangeTrackingMixin, self).__init__(*a, **kw) - self.record_changes(self.user) - - def prep_cleanvalues(self): - self.__dict__['_clean_values'] = {} - for attr in inspect(self.__class__).column_attrs: - col, = attr.columns - # 'query' is col name but not attr name - self._clean_values[col.name] = None - - def __setattr__(self, key, value): - if self._clean_values is None: - self.prep_cleanvalues() - for attr in inspect(self.__class__).column_attrs: - col, = attr.columns - previous = getattr(self, attr.key, None) - self._clean_values[col.name] = previous - - super(ChangeTrackingMixin, self).__setattr__(key, value) - - def record_changes(self, changed_by): - db.session.add(self) - db.session.flush() - changes = {} - for attr in inspect(self.__class__).column_attrs: - col, = attr.columns - if attr.key not in self.skipped_fields: - changes[col.name] = {'previous': self._clean_values[col.name], - 'current': 
getattr(self, attr.key)} - - db.session.add(Change(object=self, - object_version=self.version, - user=changed_by, - change=changes)) - - -class BelongsToOrgMixin(object): - @classmethod - def get_by_id_and_org(cls, object_id, org): - return db.session.query(cls).filter(cls.id == object_id, cls.org == org).one() - - -class PermissionsCheckMixin(object): - def has_permission(self, permission): - return self.has_permissions((permission,)) - - def has_permissions(self, permissions): - has_permissions = reduce(lambda a, b: a and b, - map(lambda permission: permission in self.permissions, - permissions), - True) - - return has_permissions - - -class AnonymousUser(AnonymousUserMixin, PermissionsCheckMixin): - @property - def permissions(self): - return [] - - def is_api_user(self): - return False - - -class ApiUser(UserMixin, PermissionsCheckMixin): - def __init__(self, api_key, org, groups, name=None): - self.object = None - if isinstance(api_key, string_types): - self.id = api_key - self.name = name - else: - self.id = api_key.api_key - self.name = "ApiKey: {}".format(api_key.id) - self.object = api_key.object - self.group_ids = groups - self.org = org - - def __repr__(self): - return u"<{}>".format(self.name) - - def is_api_user(self): - return True - - @property - def org_id(self): - if not self.org: - return None - return self.org.id - - @property - def permissions(self): - return ['view_query'] - - def has_access(self, obj, access_type): - return False - - -@python_2_unicode_compatible -class Organization(TimestampMixin, db.Model): - SETTING_GOOGLE_APPS_DOMAINS = 'google_apps_domains' - SETTING_IS_PUBLIC = "is_public" - - id = Column(db.Integer, primary_key=True) - name = Column(db.String(255)) - slug = Column(db.String(255), unique=True) - settings = Column(MutableDict.as_mutable(PseudoJSON)) - groups = db.relationship("Group", lazy="dynamic") - events = db.relationship("Event", lazy="dynamic", order_by="desc(Event.created_at)",) - - - __tablename__ = 
'organizations' - - def __repr__(self): - return u"".format(self.id, self.name) - - def __str__(self): - return u'%s (%s)' % (self.name, self.id) - - @classmethod - def get_by_slug(cls, slug): - return cls.query.filter(cls.slug == slug).first() - - @property - def default_group(self): - return self.groups.filter(Group.name == 'default', Group.type == Group.BUILTIN_GROUP).first() - - @property - def google_apps_domains(self): - return self.settings.get(self.SETTING_GOOGLE_APPS_DOMAINS, []) - - @property - def is_public(self): - return self.settings.get(self.SETTING_IS_PUBLIC, False) - - @property - def is_disabled(self): - return self.settings.get('is_disabled', False) - - def disable(self): - self.settings['is_disabled'] = True - - def enable(self): - self.settings['is_disabled'] = False - - def set_setting(self, key, value): - if key not in org_settings: - raise KeyError(key) - - self.settings.setdefault('settings', {}) - self.settings['settings'][key] = value - flag_modified(self, 'settings') - - def get_setting(self, key, raise_on_missing=True): - if key in self.settings.get('settings', {}): - return self.settings['settings'][key] - - if key in org_settings: - return org_settings[key] - - if raise_on_missing: - raise KeyError(key) - - return None - - @property - def admin_group(self): - return self.groups.filter(Group.name == 'admin', Group.type == Group.BUILTIN_GROUP).first() - - def has_user(self, email): - return self.users.filter(User.email == email).count() == 1 - - -@python_2_unicode_compatible -class Group(db.Model, BelongsToOrgMixin): - DEFAULT_PERMISSIONS = ['create_dashboard', 'create_query', 'edit_dashboard', 'edit_query', - 'view_query', 'view_source', 'execute_query', 'list_users', 'schedule_query', - 'list_dashboards', 'list_alerts', 'list_data_sources'] - - BUILTIN_GROUP = 'builtin' - REGULAR_GROUP = 'regular' - - id = Column(db.Integer, primary_key=True) - data_sources = db.relationship("DataSourceGroup", back_populates="group", - cascade="all") 
- org_id = Column(db.Integer, db.ForeignKey('organizations.id')) - org = db.relationship(Organization, back_populates="groups") - type = Column(db.String(255), default=REGULAR_GROUP) - name = Column(db.String(100)) - permissions = Column(postgresql.ARRAY(db.String(255)), - default=DEFAULT_PERMISSIONS) - created_at = Column(db.DateTime(True), default=db.func.now()) - - __tablename__ = 'groups' - - def to_dict(self): - return { - 'id': self.id, - 'name': self.name, - 'permissions': self.permissions, - 'type': self.type, - 'created_at': self.created_at - } - - @classmethod - def all(cls, org): - return cls.query.filter(cls.org == org) - - @classmethod - def members(cls, group_id): - return User.query.filter(User.group_ids.any(group_id)) - - @classmethod - def find_by_name(cls, org, group_names): - result = cls.query.filter(cls.org == org, cls.name.in_(group_names)) - return list(result) - - def __str__(self): - return text_type(self.id) - - -@python_2_unicode_compatible -class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): - id = Column(db.Integer, primary_key=True) - org_id = Column(db.Integer, db.ForeignKey('organizations.id')) - org = db.relationship(Organization, backref=db.backref("users", lazy="dynamic")) - name = Column(db.String(320)) - email = Column(EmailType) - _profile_image_url = Column('profile_image_url', db.String(320), nullable=True) - password_hash = Column(db.String(128), nullable=True) - # XXX replace with association table - group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True) - api_key = Column(db.String(40), - default=lambda: generate_token(40), - unique=True) - - disabled_at = Column(db.DateTime(True), default=None, nullable=True) - - __tablename__ = 'users' - __table_args__ = (db.Index('users_org_id_email', 'org_id', 'email', unique=True),) - - @property - def is_disabled(self): - return self.disabled_at is not None - - def disable(self): - self.disabled_at = 
db.func.now() - - def enable(self): - self.disabled_at = None - - def __init__(self, *args, **kwargs): - if kwargs.get('email') is not None: - kwargs['email'] = kwargs['email'].lower() - super(User, self).__init__(*args, **kwargs) - - def to_dict(self, with_api_key=False): - profile_image_url = self.profile_image_url - if self.is_disabled: - assets = app.extensions['webpack']['assets'] or {} - path = 'images/avatar.svg' - profile_image_url = url_for('static', filename=assets.get(path, path)) - - d = { - 'id': self.id, - 'name': self.name, - 'email': self.email, - 'profile_image_url': profile_image_url, - 'groups': self.group_ids, - 'updated_at': self.updated_at, - 'created_at': self.created_at, - 'disabled_at': self.disabled_at, - 'is_disabled': self.is_disabled, - } - - if self.password_hash is None: - d['auth_type'] = 'external' - else: - d['auth_type'] = 'password' - - if with_api_key: - d['api_key'] = self.api_key - - return d - - def is_api_user(self): - return False - - @property - def profile_image_url(self): - if self._profile_image_url is not None: - return self._profile_image_url - - email_md5 = hashlib.md5(self.email.lower()).hexdigest() - return "https://www.gravatar.com/avatar/{}?s=40&d=identicon".format(email_md5) - - @property - def permissions(self): - # TODO: this should be cached. 
- return list(itertools.chain(*[g.permissions for g in - Group.query.filter(Group.id.in_(self.group_ids))])) - - @classmethod - def get_by_email_and_org(cls, email, org): - return cls.query.filter(cls.email == email, cls.org == org).one() - - @classmethod - def get_by_api_key_and_org(cls, api_key, org): - return cls.query.filter(cls.api_key == api_key, cls.org == org).one() - - @classmethod - def all(cls, org): - return cls.query.filter(cls.org == org).filter(cls.disabled_at == None) - - @classmethod - def search(cls, base_query, term): - term = u'%{}%'.format(term) - search_filter = or_(cls.name.ilike(term), cls.email.like(term)) - - return base_query.filter(search_filter) - - @classmethod - def all_disabled(cls, org): - return cls.query.filter(cls.org == org).filter(cls.disabled_at != None) - - @classmethod - def all_not_disabled(cls, org): - return cls.all(org).filter(cls.disabled_at == None) - - @classmethod - def find_by_email(cls, email): - return cls.query.filter(cls.email == email) - - def __str__(self): - return u'%s (%s)' % (self.name, self.email) - - def hash_password(self, password): - self.password_hash = pwd_context.encrypt(password) - - def verify_password(self, password): - return self.password_hash and pwd_context.verify(password, self.password_hash) - - def update_group_assignments(self, group_names): - groups = Group.find_by_name(self.org, group_names) - groups.append(self.org.default_group) - self.group_ids = [g.id for g in groups] - db.session.add(self) - db.session.commit() - - def has_access(self, obj, access_type): - return AccessPermission.exists(obj, access_type, grantee=self) - - def get_id(self): - identity = hashlib.md5("{0},{1}".format(self.email, self.password_hash)).hexdigest() - return u"{0}-{1}".format(self.id, identity) - - -class Configuration(TypeDecorator): - impl = db.Text - - def process_bind_param(self, value, dialect): - return value.to_json() - - def process_result_value(self, value, dialect): - return 
ConfigurationContainer.from_json(value) - @python_2_unicode_compatible +@generic_repr('id', 'name', 'type', 'org_id', 'created_at') class DataSource(BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey('organizations.id')) @@ -684,7 +177,7 @@ class DataSource(BelongsToOrgMixin, db.Model): return redis_connection.get(self._pause_key()) def pause(self, reason=None): - redis_connection.set(self._pause_key(), reason) + redis_connection.set(self._pause_key(), reason or '') def resume(self): redis_connection.delete(self._pause_key()) @@ -695,9 +188,10 @@ class DataSource(BelongsToOrgMixin, db.Model): return dsg def remove_group(self, group): - db.session.query(DataSourceGroup).filter( + DataSourceGroup.query.filter( DataSourceGroup.group == group, - DataSourceGroup.data_source == self).delete() + DataSourceGroup.data_source == self + ).delete() db.session.commit() def update_group_permission(self, group, view_only): @@ -719,11 +213,13 @@ class DataSource(BelongsToOrgMixin, db.Model): # XXX examine call sites to see if a regular SQLA collection would work better @property def groups(self): - groups = db.session.query(DataSourceGroup).filter( - DataSourceGroup.data_source == self) + groups = DataSourceGroup.query.filter( + DataSourceGroup.data_source == self + ) return dict(map(lambda g: (g.group_id, g.view_only), groups)) +@generic_repr('id', 'data_source_id', 'group_id', 'view_only') class DataSourceGroup(db.Model): # XXX drop id, use datasource/group as PK id = Column(db.Integer, primary_key=True) @@ -737,6 +233,7 @@ class DataSourceGroup(db.Model): @python_2_unicode_compatible +@generic_repr('id', 'org_id', 'data_source_id', 'query_hash', 'runtime', 'retrieved_at') class QueryResult(db.Model, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey('organizations.id')) @@ -751,6 +248,9 @@ class QueryResult(db.Model, BelongsToOrgMixin): __tablename__ = 
'query_results' + def __str__(self): + return u"%d | %s | %s" % (self.id, self.query_hash, self.retrieved_at) + def to_dict(self): return { 'id': self.id, @@ -765,32 +265,35 @@ class QueryResult(db.Model, BelongsToOrgMixin): @classmethod def unused(cls, days=7): age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) - - unused_results = (db.session.query(QueryResult.id).filter( - Query.id == None, QueryResult.retrieved_at < age_threshold) - .outerjoin(Query)) - - return unused_results + return ( + cls.query.filter( + Query.id.is_(None), + cls.retrieved_at < age_threshold + ) + .outerjoin(Query) + ).options(load_only('id')) @classmethod def get_latest(cls, data_source, query, max_age=0): query_hash = utils.gen_query_hash(query) if max_age == -1: - q = db.session.query(QueryResult).filter( + query = cls.query.filter( cls.query_hash == query_hash, - cls.data_source == data_source).order_by( - QueryResult.retrieved_at.desc()) + cls.data_source == data_source + ) else: - q = db.session.query(QueryResult).filter( - QueryResult.query_hash == query_hash, - QueryResult.data_source == data_source, - db.func.timezone('utc', QueryResult.retrieved_at) + - datetime.timedelta(seconds=max_age) >= - db.func.timezone('utc', db.func.now()) - ).order_by(QueryResult.retrieved_at.desc()) + query = cls.query.filter( + cls.query_hash == query_hash, + cls.data_source == data_source, + ( + db.func.timezone('utc', cls.retrieved_at) + + datetime.timedelta(seconds=max_age) >= + db.func.timezone('utc', db.func.now()) + ) + ) - return q.first() + return query.order_by(cls.retrieved_at.desc()).first() @classmethod def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): @@ -804,11 +307,13 @@ class QueryResult(db.Model, BelongsToOrgMixin): db.session.add(query_result) logging.info("Inserted query (%s) data; id=%s", query_hash, query_result.id) # TODO: Investigate how big an impact this select-before-update makes. 
- queries = db.session.query(Query).filter( + queries = Query.query.filter( Query.query_hash == query_hash, - Query.data_source == data_source) + Query.data_source == data_source + ) for q in queries: q.latest_query_data = query_result + # don't auto-update the updated_at timestamp q.skip_updated_at = True db.session.add(q) query_ids = [q.id for q in queries] @@ -816,9 +321,6 @@ class QueryResult(db.Model, BelongsToOrgMixin): return query_result, query_ids - def __str__(self): - return u"%d | %s | %s" % (self.id, self.query_hash, self.retrieved_at) - @property def groups(self): return self.data_source.groups @@ -892,6 +394,10 @@ def should_schedule_next(previous_iteration, now, interval, time=None, day_of_we @python_2_unicode_compatible +@gfk_type +@generic_repr('id', 'name', 'query_hash', 'version', 'user_id', 'org_id', + 'data_source_id', 'query_hash', 'last_modified_by_id', + 'is_archived', 'is_draft', 'schedule', 'schedule_failures') class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) version = Column(db.Integer, default=1) @@ -932,6 +438,9 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): 'version_id_generator': False } + def __str__(self): + return text_type(self.id) + def archive(self, user=None): db.session.add(self) self.is_archived = True @@ -1006,9 +515,15 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): @classmethod def favorites(cls, user, base_query=None): - if base_query == None: + if base_query is None: base_query = cls.all_queries(user.group_ids, user.id, drafts=True) - return base_query.join((Favorite, and_(Favorite.object_type==u'Query', Favorite.object_id==Query.id))).filter(Favorite.user_id==user.id) + return base_query.join(( + Favorite, + and_( + Favorite.object_type == u'Query', + Favorite.object_id == Query.id + ) + )).filter(Favorite.user_id == user.id) @classmethod def all_tags(cls, user, 
include_drafts=False): @@ -1146,17 +661,6 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): "The SQLAlchemy expression for the property above." return func.lower(cls.name) - def __str__(self): - return text_type(self.id) - - def __repr__(self): - return '' % (self.id, self.name or 'untitled') - - -@vectorizer(db.Integer) -def integer_vectorizer(column): - return db.func.cast(column, db.Text) - @listens_for(Query.query_text, 'set') def gen_query_hash(target, val, oldval, initiator): @@ -1169,9 +673,8 @@ def query_last_modified_by(target, val, oldval, initiator): target.last_modified_by_id = val +@generic_repr('id', 'object_type', 'object_id', 'user_id', 'org_id') class Favorite(TimestampMixin, db.Model): - __tablename__ = "favorites" - id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1182,7 +685,10 @@ class Favorite(TimestampMixin, db.Model): user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, backref='favorites') - __table_args__ = (UniqueConstraint("object_type", "object_id", "user_id", name="unique_favorite"),) + __tablename__ = "favorites" + __table_args__ = ( + UniqueConstraint("object_type", "object_id", "user_id", name="unique_favorite"), + ) @classmethod def is_favorite(cls, user, object): @@ -1198,114 +704,8 @@ class Favorite(TimestampMixin, db.Model): return map(lambda fav: fav.object_id, cls.query.filter(cls.object_id.in_(map(lambda o: o.id, objects)), cls.object_type == object_type, cls.user_id == user)) -class AccessPermission(GFKBase, db.Model): - id = Column(db.Integer, primary_key=True) - # 'object' defined in GFKBase - access_type = Column(db.String(255)) - grantor_id = Column(db.Integer, db.ForeignKey("users.id")) - grantor = db.relationship(User, backref='grantor', foreign_keys=[grantor_id]) - grantee_id = Column(db.Integer, db.ForeignKey("users.id")) - grantee = db.relationship(User, backref='grantee', 
foreign_keys=[grantee_id]) - - __tablename__ = 'access_permissions' - - @classmethod - def grant(cls, obj, access_type, grantee, grantor): - grant = cls.query.filter(cls.object_type == obj.__tablename__, - cls.object_id == obj.id, - cls.access_type == access_type, - cls.grantee == grantee, - cls.grantor == grantor).one_or_none() - - if not grant: - grant = cls(object_type=obj.__tablename__, - object_id=obj.id, - access_type=access_type, - grantee=grantee, - grantor=grantor) - db.session.add(grant) - - return grant - - @classmethod - def revoke(cls, obj, grantee, access_type=None): - permissions = cls._query(obj, access_type, grantee) - return permissions.delete() - - @classmethod - def find(cls, obj, access_type=None, grantee=None, grantor=None): - return cls._query(obj, access_type, grantee, grantor) - - @classmethod - def exists(cls, obj, access_type, grantee): - return cls.find(obj, access_type, grantee).count() > 0 - - @classmethod - def _query(cls, obj, access_type=None, grantee=None, grantor=None): - q = cls.query.filter(cls.object_id == obj.id, - cls.object_type == obj.__tablename__) - - if access_type: - q = q.filter(AccessPermission.access_type == access_type) - - if grantee: - q = q.filter(AccessPermission.grantee == grantee) - - if grantor: - q = q.filter(AccessPermission.grantor == grantor) - - return q - - def to_dict(self): - d = { - 'id': self.id, - 'object_id': self.object_id, - 'object_type': self.object_type, - 'access_type': self.access_type, - 'grantor': self.grantor_id, - 'grantee': self.grantee_id - } - return d - - -class Change(GFKBase, db.Model): - id = Column(db.Integer, primary_key=True) - # 'object' defined in GFKBase - object_version = Column(db.Integer, default=0) - user_id = Column(db.Integer, db.ForeignKey("users.id")) - user = db.relationship(User, backref='changes') - change = Column(PseudoJSON) - created_at = Column(db.DateTime(True), default=db.func.now()) - - __tablename__ = 'changes' - - def to_dict(self, full=True): - d = { - 
'id': self.id, - 'object_id': self.object_id, - 'object_type': self.object_type, - 'change_type': self.change_type, - 'object_version': self.object_version, - 'change': self.change, - 'created_at': self.created_at - } - - if full: - d['user'] = self.user.to_dict() - else: - d['user_id'] = self.user_id - - return d - - @classmethod - def last_change(cls, obj): - return db.session.query(cls).filter( - cls.object_id == obj.id, - cls.object_type == obj.__class__.__tablename__).order_by( - cls.object_version.desc()).first() - - -class Alert(TimestampMixin, db.Model): +@generic_repr('id', 'name', 'query_id', 'user_id', 'state', 'last_triggered_at', 'rearm') +class Alert(TimestampMixin, BelongsToOrgMixin, db.Model): UNKNOWN_STATE = 'unknown' OK_STATE = 'ok' TRIGGERED_STATE = 'triggered' @@ -1326,15 +726,23 @@ class Alert(TimestampMixin, db.Model): @classmethod def all(cls, group_ids): - return db.session.query(Alert)\ - .options(joinedload(Alert.user), joinedload(Alert.query_rel))\ - .join(Query)\ - .join(DataSourceGroup, DataSourceGroup.data_source_id == Query.data_source_id)\ + return ( + cls.query + .options( + joinedload(Alert.user), + joinedload(Alert.query_rel), + ) + .join(Query) + .join( + DataSourceGroup, + DataSourceGroup.data_source_id == Query.data_source_id + ) .filter(DataSourceGroup.group_id.in_(group_ids)) + ) @classmethod - def get_by_id_and_org(cls, id, org): - return db.session.query(Alert).join(Query).filter(Alert.id == id, Query.org == org).one() + def get_by_id_and_org(cls, object_id, org): + return super(Alert, cls).get_by_id_and_org(object_id, org, Query) def evaluate(self): data = json_loads(self.query_rel.latest_query_data.data) @@ -1374,6 +782,8 @@ def generate_slug(ctx): @python_2_unicode_compatible +@gfk_type +@generic_repr('id', 'name', 'slug', 'user_id', 'org_id', 'version', 'is_archived', 'is_draft') class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) version = 
Column(db.Integer) @@ -1396,11 +806,16 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model "version_id_col": version } + def __str__(self): + return u"%s=%s" % (self.id, self.name) + @classmethod def all(cls, org, group_ids, user_id): query = ( Dashboard.query - .options(joinedload(Dashboard.user)) + .options( + subqueryload(Dashboard.user).load_only('_profile_image_url', 'name'), + ) .outerjoin(Widget) .outerjoin(Visualization) .outerjoin(Query) @@ -1442,7 +857,15 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model def favorites(cls, user, base_query=None): if base_query is None: base_query = cls.all(user.org, user.group_ids, user.id) - return base_query.join((Favorite, and_(Favorite.object_type==u'Dashboard', Favorite.object_id==Dashboard.id))).filter(Favorite.user_id==user.id) + return base_query.join( + ( + Favorite, + and_( + Favorite.object_type == u'Dashboard', + Favorite.object_id == Dashboard.id + ) + ) + ).filter(Favorite.user_id == user.id) @classmethod def get_by_slug_and_org(cls, slug, org): @@ -1458,12 +881,10 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model "The SQLAlchemy expression for the property above." 
return func.lower(cls.name) - def __str__(self): - return u"%s=%s" % (self.id, self.name) - @python_2_unicode_compatible -class Visualization(TimestampMixin, db.Model): +@generic_repr('id', 'name', 'type', 'query_id') +class Visualization(TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) type = Column(db.String(100)) query_id = Column(db.Integer, db.ForeignKey("queries.id")) @@ -1475,15 +896,13 @@ class Visualization(TimestampMixin, db.Model): __tablename__ = 'visualizations' - @classmethod - def get_by_id_and_org(cls, visualization_id, org): - return db.session.query(Visualization).join(Query).filter( - cls.id == visualization_id, - Query.org == org).one() - def __str__(self): return u"%s %s" % (self.id, self.type) + @classmethod + def get_by_id_and_org(cls, object_id, org): + return super(Visualization, cls).get_by_id_and_org(object_id, org, Query) + def copy(self): return { 'type': self.type, @@ -1494,7 +913,8 @@ class Visualization(TimestampMixin, db.Model): @python_2_unicode_compatible -class Widget(TimestampMixin, db.Model): +@generic_repr('id', 'visualization_id', 'dashboard_id') +class Widget(TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) visualization_id = Column(db.Integer, db.ForeignKey('visualizations.id'), nullable=True) visualization = db.relationship(Visualization, backref='widgets') @@ -1509,11 +929,12 @@ class Widget(TimestampMixin, db.Model): return u"%s" % self.id @classmethod - def get_by_id_and_org(cls, widget_id, org): - return db.session.query(cls).join(Dashboard).filter(cls.id == widget_id, Dashboard.org == org).one() + def get_by_id_and_org(cls, object_id, org): + return super(Widget, cls).get_by_id_and_org(object_id, org, Dashboard) @python_2_unicode_compatible +@generic_repr('id', 'object_type', 'object_id', 'action', 'user_id', 'org_id', 'created_at') class Event(db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, 
db.ForeignKey("organizations.id")) @@ -1560,6 +981,7 @@ class Event(db.Model): return event +@generic_repr('id', 'created_by_id', 'org_id', 'active') class ApiKey(TimestampMixin, GFKBase, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1571,7 +993,9 @@ class ApiKey(TimestampMixin, GFKBase, db.Model): created_by = db.relationship(User) __tablename__ = 'api_keys' - __table_args__ = (db.Index('api_keys_object_type_object_id', 'object_type', 'object_id'),) + __table_args__ = ( + db.Index('api_keys_object_type_object_id', 'object_type', 'object_id'), + ) @classmethod def get_by_api_key(cls, api_key): @@ -1579,7 +1003,11 @@ class ApiKey(TimestampMixin, GFKBase, db.Model): @classmethod def get_by_object(cls, object): - return cls.query.filter(cls.object_type == object.__class__.__tablename__, cls.object_id == object.id, cls.active == True).first() + return cls.query.filter( + cls.object_type == object.__class__.__tablename__, + cls.object_id == object.id, + cls.active == True + ).first() @classmethod def create_for_object(cls, object, user): @@ -1589,6 +1017,7 @@ class ApiKey(TimestampMixin, GFKBase, db.Model): @python_2_unicode_compatible +@generic_repr('id', 'name', 'type', 'user_id', 'org_id', 'created_at') class NotificationDestination(BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1601,8 +1030,14 @@ class NotificationDestination(BelongsToOrgMixin, db.Model): created_at = Column(db.DateTime(True), default=db.func.now()) __tablename__ = 'notification_destinations' - __table_args__ = (db.Index('notification_destinations_org_id_name', 'org_id', - 'name', unique=True),) + __table_args__ = ( + db.Index( + 'notification_destinations_org_id_name', 'org_id', 'name', unique=True + ), + ) + + def __str__(self): + return text_type(self.name) def to_dict(self, all=False): d = { @@ -1619,9 +1054,6 @@ class 
NotificationDestination(BelongsToOrgMixin, db.Model): return d - def __str__(self): - return text_type(self.name) - @property def destination(self): return get_destination(self.type, self.options) @@ -1639,20 +1071,25 @@ class NotificationDestination(BelongsToOrgMixin, db.Model): app, host, self.options) +@generic_repr('id', 'user_id', 'destination_id', 'alert_id') class AlertSubscription(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User) destination_id = Column(db.Integer, - db.ForeignKey("notification_destinations.id"), - nullable=True) + db.ForeignKey("notification_destinations.id"), + nullable=True) destination = db.relationship(NotificationDestination) alert_id = Column(db.Integer, db.ForeignKey("alerts.id")) alert = db.relationship(Alert, back_populates="subscriptions") __tablename__ = 'alert_subscriptions' - __table_args__ = (db.Index('alert_subscriptions_destination_id_alert_id', - 'destination_id', 'alert_id', unique=True),) + __table_args__ = ( + db.Index( + 'alert_subscriptions_destination_id_alert_id', + 'destination_id', 'alert_id', unique=True + ), + ) def to_dict(self): d = { @@ -1683,6 +1120,7 @@ class AlertSubscription(TimestampMixin, db.Model): return destination.notify(alert, query, user, new_state, app, host, options) +@generic_repr('id', 'trigger', 'user_id', 'org_id') class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey("organizations.id")) @@ -1692,6 +1130,7 @@ class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, backref="query_snippets") snippet = Column(db.Text) + __tablename__ = 'query_snippets' @classmethod @@ -1712,9 +1151,6 @@ class QuerySnippet(TimestampMixin, db.Model, BelongsToOrgMixin): return d -_gfk_types = {'queries': Query, 'dashboards': 
Dashboard} - - def init_db(): default_org = Organization(name="Default", slug='default', settings={}) admin_group = Group(name='admin', permissions=['admin', 'super_admin'], org=default_org, type=Group.BUILTIN_GROUP) diff --git a/redash/models/base.py b/redash/models/base.py new file mode 100644 index 00000000..ec19ebce --- /dev/null +++ b/redash/models/base.py @@ -0,0 +1,85 @@ +import functools + +from flask_sqlalchemy import BaseQuery, SQLAlchemy +from sqlalchemy.orm import object_session +from sqlalchemy.pool import NullPool +from sqlalchemy_searchable import make_searchable, vectorizer, SearchQueryMixin + +from redash import settings +from redash.utils import json_dumps + + +class RedashSQLAlchemy(SQLAlchemy): + def apply_driver_hacks(self, app, info, options): + options.update(json_serializer=json_dumps) + super(RedashSQLAlchemy, self).apply_driver_hacks(app, info, options) + + def apply_pool_defaults(self, app, options): + super(RedashSQLAlchemy, self).apply_pool_defaults(app, options) + if settings.SQLALCHEMY_DISABLE_POOL: + options['poolclass'] = NullPool + + +db = RedashSQLAlchemy(session_options={ + 'expire_on_commit': False +}) +# Make sure the SQLAlchemy mappers are all properly configured first. +# This is required by SQLAlchemy-Searchable as it adds DDL listeners +# on the configuration phase of models. +db.configure_mappers() + +# listen to a few database events to set up functions, trigger updates +# and indexes for the full text search +make_searchable(options={'regconfig': 'pg_catalog.simple'}) + + +class SearchBaseQuery(BaseQuery, SearchQueryMixin): + """ + The SQA query class to use when full text search is wanted. + """ + + +@vectorizer(db.Integer) +def integer_vectorizer(column): + return db.func.cast(column, db.Text) + + +Column = functools.partial(db.Column, nullable=False) + +# AccessPermission and Change use a 'generic foreign key' approach to refer to +# either queries or dashboards. +# TODO replace this with association tables. 
+_gfk_types = {} + + +def gfk_type(cls): + _gfk_types[cls.__tablename__] = cls + return cls + + +class GFKBase(object): + """ + Compatibility with 'generic foreign key' approach Peewee used. + """ + # XXX Replace this with table-per-association. + object_type = Column(db.String(255)) + object_id = Column(db.Integer) + + _object = None + + @property + def object(self): + session = object_session(self) + if self._object or not session: + return self._object + else: + object_class = _gfk_types[self.object_type] + self._object = session.query(object_class).filter( + object_class.id == self.object_id).first() + return self._object + + @object.setter + def object(self, value): + self._object = value + self.object_type = value.__class__.__tablename__ + self.object_id = value.id diff --git a/redash/models/changes.py b/redash/models/changes.py new file mode 100644 index 00000000..1529c4b3 --- /dev/null +++ b/redash/models/changes.py @@ -0,0 +1,86 @@ +from sqlalchemy.inspection import inspect +from sqlalchemy_utils.models import generic_repr + +from .base import GFKBase, db, Column +from .types import PseudoJSON + + +@generic_repr('id', 'object_type', 'object_id', 'created_at') +class Change(GFKBase, db.Model): + id = Column(db.Integer, primary_key=True) + # 'object' defined in GFKBase + object_version = Column(db.Integer, default=0) + user_id = Column(db.Integer, db.ForeignKey("users.id")) + user = db.relationship("User", backref='changes') + change = Column(PseudoJSON) + created_at = Column(db.DateTime(True), default=db.func.now()) + + __tablename__ = 'changes' + + def to_dict(self, full=True): + d = { + 'id': self.id, + 'object_id': self.object_id, + 'object_type': self.object_type, + 'change_type': self.change_type, + 'object_version': self.object_version, + 'change': self.change, + 'created_at': self.created_at + } + + if full: + d['user'] = self.user.to_dict() + else: + d['user_id'] = self.user_id + + return d + + @classmethod + def last_change(cls, obj): + return 
cls.query.filter( + cls.object_id == obj.id, + cls.object_type == obj.__class__.__tablename__ + ).order_by( + cls.object_version.desc() + ).first() + + +class ChangeTrackingMixin(object): + skipped_fields = ('id', 'created_at', 'updated_at', 'version') + _clean_values = None + + def __init__(self, *a, **kw): + super(ChangeTrackingMixin, self).__init__(*a, **kw) + self.record_changes(self.user) + + def prep_cleanvalues(self): + self.__dict__['_clean_values'] = {} + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + # 'query' is col name but not attr name + self._clean_values[col.name] = None + + def __setattr__(self, key, value): + if self._clean_values is None: + self.prep_cleanvalues() + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + previous = getattr(self, attr.key, None) + self._clean_values[col.name] = previous + + super(ChangeTrackingMixin, self).__setattr__(key, value) + + def record_changes(self, changed_by): + db.session.add(self) + db.session.flush() + changes = {} + for attr in inspect(self.__class__).column_attrs: + col, = attr.columns + if attr.key not in self.skipped_fields: + changes[col.name] = {'previous': self._clean_values[col.name], + 'current': getattr(self, attr.key)} + + db.session.add(Change(object=self, + object_version=self.version, + user=changed_by, + change=changes)) diff --git a/redash/models/mixins.py b/redash/models/mixins.py new file mode 100644 index 00000000..a19f4717 --- /dev/null +++ b/redash/models/mixins.py @@ -0,0 +1,28 @@ +from sqlalchemy.event import listens_for + +from .base import db, Column + + +class TimestampMixin(object): + updated_at = Column(db.DateTime(True), default=db.func.now(), nullable=False) + created_at = Column(db.DateTime(True), default=db.func.now(), nullable=False) + + +@listens_for(TimestampMixin, 'before_update', propagate=True) +def timestamp_before_update(mapper, connection, target): + # Check if we really want to update the updated_at value + if 
hasattr(target, 'skip_updated_at'): + return + + target.updated_at = db.func.now() + + +class BelongsToOrgMixin(object): + @classmethod + def get_by_id_and_org(cls, object_id, org, org_cls=None): + query = cls.query.filter(cls.id == object_id) + if org_cls is None: + query = query.filter(cls.org == org) + else: + query = query.join(org_cls).filter(org_cls.org == org) + return query.one() diff --git a/redash/models/organizations.py b/redash/models/organizations.py new file mode 100644 index 00000000..fcf8cc8c --- /dev/null +++ b/redash/models/organizations.py @@ -0,0 +1,82 @@ +from six import python_2_unicode_compatible +from sqlalchemy.orm.attributes import flag_modified +from sqlalchemy_utils.models import generic_repr + +from redash.settings.organization import settings as org_settings + +from .base import db, Column +from .mixins import TimestampMixin +from .types import MutableDict, PseudoJSON +from .users import User, Group + + +@python_2_unicode_compatible +@generic_repr('id', 'name', 'slug') +class Organization(TimestampMixin, db.Model): + SETTING_GOOGLE_APPS_DOMAINS = 'google_apps_domains' + SETTING_IS_PUBLIC = "is_public" + + id = Column(db.Integer, primary_key=True) + name = Column(db.String(255)) + slug = Column(db.String(255), unique=True) + settings = Column(MutableDict.as_mutable(PseudoJSON)) + groups = db.relationship("Group", lazy="dynamic") + events = db.relationship("Event", lazy="dynamic", order_by="desc(Event.created_at)",) + + __tablename__ = 'organizations' + + def __str__(self): + return u'%s (%s)' % (self.name, self.id) + + @classmethod + def get_by_slug(cls, slug): + return cls.query.filter(cls.slug == slug).first() + + @property + def default_group(self): + return self.groups.filter(Group.name == 'default', Group.type == Group.BUILTIN_GROUP).first() + + @property + def google_apps_domains(self): + return self.settings.get(self.SETTING_GOOGLE_APPS_DOMAINS, []) + + @property + def is_public(self): + return 
self.settings.get(self.SETTING_IS_PUBLIC, False) + + @property + def is_disabled(self): + return self.settings.get('is_disabled', False) + + def disable(self): + self.settings['is_disabled'] = True + + def enable(self): + self.settings['is_disabled'] = False + + def set_setting(self, key, value): + if key not in org_settings: + raise KeyError(key) + + self.settings.setdefault('settings', {}) + self.settings['settings'][key] = value + flag_modified(self, 'settings') + + def get_setting(self, key, raise_on_missing=True): + if key in self.settings.get('settings', {}): + return self.settings['settings'][key] + + if key in org_settings: + return org_settings[key] + + if raise_on_missing: + raise KeyError(key) + + return None + + @property + def admin_group(self): + return self.groups.filter(Group.name == 'admin', Group.type == Group.BUILTIN_GROUP).first() + + def has_user(self, email): + return self.users.filter(User.email == email).count() == 1 diff --git a/redash/models/types.py b/redash/models/types.py new file mode 100644 index 00000000..3220b9db --- /dev/null +++ b/redash/models/types.py @@ -0,0 +1,106 @@ +import pytz +import walrus +from sqlalchemy.types import TypeDecorator +from sqlalchemy.ext.indexable import index_property +from sqlalchemy.ext.mutable import Mutable + +from redash.utils import json_dumps, json_loads +from redash.utils.configuration import ConfigurationContainer + +from .base import db + + +class Configuration(TypeDecorator): + impl = db.Text + + def process_bind_param(self, value, dialect): + return value.to_json() + + def process_result_value(self, value, dialect): + return ConfigurationContainer.from_json(value) + + +# XXX replace PseudoJSON and MutableDict with real JSON field +class PseudoJSON(TypeDecorator): + impl = db.Text + + def process_bind_param(self, value, dialect): + return json_dumps(value) + + def process_result_value(self, value, dialect): + if not value: + return value + return json_loads(value) + + +class 
MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + + +class MutableList(Mutable, list): + def append(self, value): + list.append(self, value) + self.changed() + + def remove(self, value): + list.remove(self, value) + self.changed() + + @classmethod + def coerce(cls, key, value): + if not isinstance(value, MutableList): + if isinstance(value, list): + return MutableList(value) + return Mutable.coerce(key, value) + else: + return value + + +class json_cast_property(index_property): + """ + A SQLAlchemy index property that is able to cast the + entity attribute as the specified cast type. Useful + for JSON and JSONB colums for easier querying/filtering. + """ + def __init__(self, cast_type, *args, **kwargs): + super(json_cast_property, self).__init__(*args, **kwargs) + self.cast_type = cast_type + + def expr(self, model): + expr = super(json_cast_property, self).expr(model) + return expr.astext.cast(self.cast_type) + + +class UTCDateTimeField(walrus.DateTimeField): + """ + A walrus DateTimeField that makes the value timezone aware + using the pytz.utc timezone on return. 
+ """ + def python_value(self, value): + value = super(UTCDateTimeField, self).python_value(value) + if value.tzinfo is None: + value = value.replace(tzinfo=pytz.utc) + return value diff --git a/redash/models/users.py b/redash/models/users.py new file mode 100644 index 00000000..036543d7 --- /dev/null +++ b/redash/models/users.py @@ -0,0 +1,448 @@ +import hashlib +import itertools +import logging +from functools import reduce +from operator import or_ + +import walrus +from flask import current_app as app, url_for, request_started +from flask_login import current_user, AnonymousUserMixin, UserMixin +from passlib.apps import custom_app_context as pwd_context +from six import python_2_unicode_compatible, string_types, text_type +from sqlalchemy.exc import DBAPIError +from sqlalchemy.dialects import postgresql + +from sqlalchemy_utils import EmailType +from sqlalchemy_utils.models import generic_repr + +from redash import redis_connection +from redash.utils import generate_token, utcnow + +from .base import db, Column, GFKBase +from .mixins import TimestampMixin, BelongsToOrgMixin +from .types import json_cast_property, MutableDict, MutableList, UTCDateTimeField + +logger = logging.getLogger(__name__) + + +class UserDetail(walrus.Model): + """ + A walrus data model to store some user data to Redis to be + synced to Postgres asynchronously. + """ + __database__ = redis_connection + __namespace__ = 'redash.user.details' + + user_id = walrus.IntegerField(index=True) + updated_at = UTCDateTimeField(index=True, default=utcnow) + + @classmethod + def update(cls, user_id): + """ + Update the user details hash using the given redis + pipeline, user id, optional redis id and optional user + details. + + The fields uid, rid and updated (timestamp) are + enforced and can't be overwritten. + """ + # try getting the user detail with the given user ID + # or create one if it doesn't exist yet (e.g. 
when key was purged) + try: + user_detail = cls.get(cls.user_id == user_id) + # update the timestamp with the current time + user_detail.updated_at = utcnow() + # save to Redis + user_detail.save() + except ValueError: + user_detail = cls.create( + user_id=user_id, + updated_at=utcnow(), + ) + return user_detail + + @classmethod + def sync(cls, chunksize=1000): + """ + Syncs user details to Postgres (to the JSON field User.details). + """ + to_sync = {} + try: + for user_detail in cls.all(): + to_sync[user_detail.user_id] = user_detail + + user_ids = list(to_sync.keys()) + if not user_ids: + return + logger.info( + 'syncing users: %s', + ', '.join([str(uid) for uid in user_ids]) + ) + # get all SQLA users that need to be updated + users = User.query.filter(User.id.in_(user_ids)) + for i, user in enumerate(users): + update = to_sync[user.id] + user.active_at = update.updated_at + # flush changes to the database after a certain + # number of items and extend the list of keys to + # stop sync in case of exceptions + if i % chunksize == 0: + db.session.flush() + db.session.commit() + except DBAPIError: + # reset list of keys to stop sync + pass + finally: + user_ids = [str(user_id) for user_id in to_sync.keys()] + if user_ids: + logger.info( + 'Deleting temporary user details for users %s', + ', '.join(user_ids) + ) + delete_query = [ + UserDetail.user_id == str(user_id) + for user_id in user_ids + ] + UserDetail.query_delete(reduce(or_, delete_query)) + + +def update_user_detail(sender, *args, **kwargs): + """ + Used as a Flask request_started signal callback that adds + the current user's details to Redis + """ + if ( + current_user.get_id() and + current_user.is_authenticated and + not current_user.is_api_user() + ): + UserDetail.update(current_user.id) + + +def init_app(app): + """ + A Flask extension to keep user details updates in Redis and + sync it periodically to the database (User.details). 
+ """ + request_started.connect(update_user_detail, app) + + +class PermissionsCheckMixin(object): + def has_permission(self, permission): + return self.has_permissions((permission,)) + + def has_permissions(self, permissions): + has_permissions = reduce(lambda a, b: a and b, + map(lambda permission: permission in self.permissions, + permissions), + True) + + return has_permissions + + +@python_2_unicode_compatible +@generic_repr('id', 'name', 'email') +class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCheckMixin): + id = Column(db.Integer, primary_key=True) + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship("Organization", backref=db.backref("users", lazy="dynamic")) + name = Column(db.String(320)) + email = Column(EmailType) + _profile_image_url = Column('profile_image_url', db.String(320), nullable=True) + password_hash = Column(db.String(128), nullable=True) + # XXX replace with association table + group_ids = Column('groups', MutableList.as_mutable(postgresql.ARRAY(db.Integer)), nullable=True) + api_key = Column(db.String(40), + default=lambda: generate_token(40), + unique=True) + + disabled_at = Column(db.DateTime(True), default=None, nullable=True) + details = Column(MutableDict.as_mutable(postgresql.JSON), nullable=True, + server_default='{}', default={}) + active_at = json_cast_property(db.DateTime(True), 'details', 'active_at', + default=None) + + __tablename__ = 'users' + __table_args__ = ( + db.Index('users_org_id_email', 'org_id', 'email', unique=True), + ) + + def __str__(self): + return u'%s (%s)' % (self.name, self.email) + + def __init__(self, *args, **kwargs): + if kwargs.get('email') is not None: + kwargs['email'] = kwargs['email'].lower() + super(User, self).__init__(*args, **kwargs) + + @property + def is_disabled(self): + return self.disabled_at is not None + + def disable(self): + self.disabled_at = db.func.now() + + def enable(self): + self.disabled_at = None + + def 
to_dict(self, with_api_key=False): + profile_image_url = self.profile_image_url + if self.is_disabled: + assets = app.extensions['webpack']['assets'] or {} + path = 'images/avatar.svg' + profile_image_url = url_for('static', filename=assets.get(path, path)) + + d = { + 'id': self.id, + 'name': self.name, + 'email': self.email, + 'profile_image_url': profile_image_url, + 'groups': self.group_ids, + 'updated_at': self.updated_at, + 'created_at': self.created_at, + 'disabled_at': self.disabled_at, + 'is_disabled': self.is_disabled, + 'active_at': self.active_at, + } + + if self.password_hash is None: + d['auth_type'] = 'external' + else: + d['auth_type'] = 'password' + + if with_api_key: + d['api_key'] = self.api_key + + return d + + def is_api_user(self): + return False + + @property + def profile_image_url(self): + if self._profile_image_url is not None: + return self._profile_image_url + + email_md5 = hashlib.md5(self.email.lower()).hexdigest() + return "https://www.gravatar.com/avatar/{}?s=40&d=identicon".format(email_md5) + + @property + def permissions(self): + # TODO: this should be cached. 
+ return list(itertools.chain(*[g.permissions for g in + Group.query.filter(Group.id.in_(self.group_ids))])) + + @classmethod + def get_by_org(cls, org): + return cls.query.filter(cls.org == org) + + @classmethod + def get_by_email_and_org(cls, email, org): + return cls.get_by_org(org).filter(cls.email == email).one() + + @classmethod + def get_by_api_key_and_org(cls, api_key, org): + return cls.get_by_org(org).filter(cls.api_key == api_key).one() + + @classmethod + def all(cls, org): + return cls.get_by_org(org).filter(cls.disabled_at.is_(None)) + + @classmethod + def search(cls, base_query, term): + term = u'%{}%'.format(term) + search_filter = or_(cls.name.ilike(term), cls.email.like(term)) + + return base_query.filter(search_filter) + + @classmethod + def all_disabled(cls, org): + return cls.get_by_org(org).filter(cls.disabled_at.isnot(None)) + + @classmethod + def find_by_email(cls, email): + return cls.query.filter(cls.email == email) + + def hash_password(self, password): + self.password_hash = pwd_context.encrypt(password) + + def verify_password(self, password): + return self.password_hash and pwd_context.verify(password, self.password_hash) + + def update_group_assignments(self, group_names): + groups = Group.find_by_name(self.org, group_names) + groups.append(self.org.default_group) + self.group_ids = [g.id for g in groups] + db.session.add(self) + db.session.commit() + + def has_access(self, obj, access_type): + return AccessPermission.exists(obj, access_type, grantee=self) + + def get_id(self): + identity = hashlib.md5( + "{},{}".format(self.email, self.password_hash) + ).hexdigest() + return u"{0}-{1}".format(self.id, identity) + + +@python_2_unicode_compatible +@generic_repr('id', 'name', 'type', 'org_id') +class Group(db.Model, BelongsToOrgMixin): + DEFAULT_PERMISSIONS = ['create_dashboard', 'create_query', 'edit_dashboard', 'edit_query', + 'view_query', 'view_source', 'execute_query', 'list_users', 'schedule_query', + 'list_dashboards', 
'list_alerts', 'list_data_sources'] + + BUILTIN_GROUP = 'builtin' + REGULAR_GROUP = 'regular' + + id = Column(db.Integer, primary_key=True) + data_sources = db.relationship("DataSourceGroup", back_populates="group", + cascade="all") + org_id = Column(db.Integer, db.ForeignKey('organizations.id')) + org = db.relationship("Organization", back_populates="groups") + type = Column(db.String(255), default=REGULAR_GROUP) + name = Column(db.String(100)) + permissions = Column(postgresql.ARRAY(db.String(255)), + default=DEFAULT_PERMISSIONS) + created_at = Column(db.DateTime(True), default=db.func.now()) + + __tablename__ = 'groups' + + def __str__(self): + return text_type(self.id) + + def to_dict(self): + return { + 'id': self.id, + 'name': self.name, + 'permissions': self.permissions, + 'type': self.type, + 'created_at': self.created_at + } + + @classmethod + def all(cls, org): + return cls.query.filter(cls.org == org) + + @classmethod + def members(cls, group_id): + return User.query.filter(User.group_ids.any(group_id)) + + @classmethod + def find_by_name(cls, org, group_names): + result = cls.query.filter(cls.org == org, cls.name.in_(group_names)) + return list(result) + + +@generic_repr('id', 'object_type', 'object_id', 'access_type', 'grantor_id', 'grantee_id') +class AccessPermission(GFKBase, db.Model): + id = Column(db.Integer, primary_key=True) + # 'object' defined in GFKBase + access_type = Column(db.String(255)) + grantor_id = Column(db.Integer, db.ForeignKey("users.id")) + grantor = db.relationship(User, backref='grantor', foreign_keys=[grantor_id]) + grantee_id = Column(db.Integer, db.ForeignKey("users.id")) + grantee = db.relationship(User, backref='grantee', foreign_keys=[grantee_id]) + + __tablename__ = 'access_permissions' + + @classmethod + def grant(cls, obj, access_type, grantee, grantor): + grant = cls.query.filter(cls.object_type == obj.__tablename__, + cls.object_id == obj.id, + cls.access_type == access_type, + cls.grantee == grantee, + cls.grantor 
== grantor).one_or_none() + + if not grant: + grant = cls(object_type=obj.__tablename__, + object_id=obj.id, + access_type=access_type, + grantee=grantee, + grantor=grantor) + db.session.add(grant) + + return grant + + @classmethod + def revoke(cls, obj, grantee, access_type=None): + permissions = cls._query(obj, access_type, grantee) + return permissions.delete() + + @classmethod + def find(cls, obj, access_type=None, grantee=None, grantor=None): + return cls._query(obj, access_type, grantee, grantor) + + @classmethod + def exists(cls, obj, access_type, grantee): + return cls.find(obj, access_type, grantee).count() > 0 + + @classmethod + def _query(cls, obj, access_type=None, grantee=None, grantor=None): + q = cls.query.filter(cls.object_id == obj.id, + cls.object_type == obj.__tablename__) + + if access_type: + q = q.filter(AccessPermission.access_type == access_type) + + if grantee: + q = q.filter(AccessPermission.grantee == grantee) + + if grantor: + q = q.filter(AccessPermission.grantor == grantor) + + return q + + def to_dict(self): + d = { + 'id': self.id, + 'object_id': self.object_id, + 'object_type': self.object_type, + 'access_type': self.access_type, + 'grantor': self.grantor_id, + 'grantee': self.grantee_id + } + return d + + +class AnonymousUser(AnonymousUserMixin, PermissionsCheckMixin): + @property + def permissions(self): + return [] + + def is_api_user(self): + return False + + +class ApiUser(UserMixin, PermissionsCheckMixin): + def __init__(self, api_key, org, groups, name=None): + self.object = None + if isinstance(api_key, string_types): + self.id = api_key + self.name = name + else: + self.id = api_key.api_key + self.name = "ApiKey: {}".format(api_key.id) + self.object = api_key.object + self.group_ids = groups + self.org = org + + def __repr__(self): + return u"<{}>".format(self.name) + + def is_api_user(self): + return True + + @property + def org_id(self): + if not self.org: + return None + return self.org.id + + @property + def 
permissions(self): + return ['view_query'] + + def has_access(self, obj, access_type): + return False \ No newline at end of file diff --git a/redash/monitor.py b/redash/monitor.py index f1f241eb..6b41f221 100644 --- a/redash/monitor.py +++ b/redash/monitor.py @@ -8,9 +8,9 @@ def get_redis_status(): def get_object_counts(): status = {} - status['queries_count'] = models.db.session.query(models.Query).count() + status['queries_count'] = models.Query.query.count() if settings.FEATURE_SHOW_QUERY_RESULTS_COUNT: - status['query_results_count'] = models.db.session.query(models.QueryResult).count() + status['query_results_count'] = models.QueryResult.query.count() status['unused_query_results_count'] = models.QueryResult.unused().count() status['dashboards_count'] = models.Dashboard.query.count() status['widgets_count'] = models.Widget.query.count() @@ -35,7 +35,7 @@ def get_queues_status(): 'data_sources': ', '.join(sources), 'size': redis_connection.llen(queue) } - + queues['celery'] = { 'size': redis_connection.llen('celery'), 'data_sources': '' diff --git a/redash/tasks/__init__.py b/redash/tasks/__init__.py index f242e4c5..418ed92c 100644 --- a/redash/tasks/__init__.py +++ b/redash/tasks/__init__.py @@ -1,3 +1,3 @@ -from .general import record_event, version_check, send_mail +from .general import record_event, version_check, send_mail, sync_user_details from .queries import QueryTask, refresh_queries, refresh_schemas, cleanup_tasks, cleanup_query_results, execute_query -from .alerts import check_alerts_for_query \ No newline at end of file +from .alerts import check_alerts_for_query diff --git a/redash/tasks/general.py b/redash/tasks/general.py index 1257b82e..8be76d27 100644 --- a/redash/tasks/general.py +++ b/redash/tasks/general.py @@ -57,3 +57,16 @@ def send_mail(to, subject, html, text): mail.send(message) except Exception: logger.exception('Failed sending message: %s', message.subject) + + +@celery.task( + name="redash.tasks.sync_user_details", + 
ignore_result=True, + soft_time_limit=60, + time_limit=120, + # let the task expire after 45 seconds since there will be + # another task 15 seconds later anyway + expires=45, +) +def sync_user_details(): + models.UserDetail.sync() diff --git a/redash/tasks/queries.py b/redash/tasks/queries.py index a178496d..1fc2c50e 100644 --- a/redash/tasks/queries.py +++ b/redash/tasks/queries.py @@ -59,11 +59,11 @@ class QueryTaskTracker(object): self.data['updated_at'] = time.time() key_name = self._key_name(self.data['task_id']) connection.set(key_name, json_dumps(self.data)) - connection.zadd(self._get_list(), time.time(), key_name) + connection.zadd(self._get_list(), {key_name: time.time()}) - for l in self.ALL_LISTS: - if l != self._get_list(): - connection.zrem(l, key_name) + for _list in self.ALL_LISTS: + if _list != self._get_list(): + connection.zrem(_list, key_name) # TOOD: this is not thread/concurrency safe. In current code this is not an issue, but better to fix this. def update(self, **kwargs): diff --git a/redash/utils/__init__.py b/redash/utils/__init__.py index 5a14805f..b2b8766a 100644 --- a/redash/utils/__init__.py +++ b/redash/utils/__init__.py @@ -76,17 +76,31 @@ class JSONEncoder(simplejson.JSONEncoder): def default(self, o): # Some SQLAlchemy collections are lazy. if isinstance(o, Query): - return list(o) + result = list(o) elif isinstance(o, decimal.Decimal): - return float(o) + result = float(o) elif isinstance(o, (datetime.timedelta, uuid.UUID)): - return str(o) - elif isinstance(o, (datetime.date, datetime.time)): - return o.isoformat() + result = str(o) + # See "Date Time String Format" in the ECMA-262 specification. 
+ if isinstance(o, datetime.datetime): + result = o.isoformat() + if o.microsecond: + result = result[:23] + result[26:] + if result.endswith('+00:00'): + result = result[:-6] + 'Z' + elif isinstance(o, datetime.date): + result = o.isoformat() + elif isinstance(o, datetime.time): + if o.utcoffset() is not None: + raise ValueError("JSON can't represent timezone-aware times.") + result = o.isoformat() + if o.microsecond: + result = result[:12] elif isinstance(o, buffer): - return binascii.hexlify(o) + result = binascii.hexlify(o) else: - return super(JSONEncoder, self).default(o) + result = super(JSONEncoder, self).default(o) + return result def json_loads(data, *args, **kwargs): diff --git a/redash/worker.py b/redash/worker.py index 0a3224a0..dfd504a0 100644 --- a/redash/worker.py +++ b/redash/worker.py @@ -27,6 +27,10 @@ celery_schedule = { 'refresh_schemas': { 'task': 'redash.tasks.refresh_schemas', 'schedule': timedelta(minutes=settings.SCHEMAS_REFRESH_SCHEDULE) + }, + 'sync_user_details': { + 'task': 'redash.tasks.sync_user_details', + 'schedule': timedelta(minutes=1), } } diff --git a/requirements.txt b/requirements.txt index 6f51c79f..191497de 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ Flask-Admin==1.5.2 Flask-RESTful==0.3.5 Flask-Login==0.4.0 Flask-OAuthLib==0.9.3 -Flask-SQLAlchemy==2.1 +Flask-SQLAlchemy==2.3.2 Flask-Migrate==2.0.1 flask-mail==0.9.1 flask-sslify==0.1.5 @@ -19,19 +19,21 @@ passlib==1.6.2 aniso8601==1.1.0 blinker==1.3 psycopg2==2.7.3.2 -python-dateutil==2.4.2 +python-dateutil==2.7.3 pytz==2016.7 PyYAML==3.12 -redis==2.10.5 +redis==3.0.1 +walrus==0.7.1 requests==2.21.0 six==1.11.0 -SQLAlchemy==1.2.7 +SQLAlchemy==1.2.12 SQLAlchemy-Searchable==0.10.6 SQLAlchemy-Utils>=0.29.0 sqlparse==0.2.4 statsd==2.1.2 gunicorn==19.7.1 celery==4.2.1 +kombu==4.2.2.post1 jsonschema==2.4.0 RestrictedPython==3.6.0 pysaml2==4.5.0 diff --git a/setup.cfg b/setup.cfg index 68859ad0..9fa9f5c9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,7 
@@ [pep8] max-line-length = 120 + +[flake8] +ignore = E501 +exclude = .git +max-complexity = 10 diff --git a/tests/handlers/test_users.py b/tests/handlers/test_users.py index 3939ab5d..8cf7d913 100644 --- a/tests/handlers/test_users.py +++ b/tests/handlers/test_users.py @@ -28,7 +28,6 @@ class TestUserListResourcePost(BaseTestCase): rv = self.make_request('post', '/api/users', data=test_user, user=admin) self.assertEqual(rv.status_code, 400) - def test_creates_user(self): admin = self.factory.create_admin() @@ -274,12 +273,12 @@ class TestUserDisable(BaseTestCase): self.db.session.commit() with patch('redash.handlers.authentication.login_user') as login_user_mock: - rv = self.client.post('/login', data={'email': user.email, 'password': 'password'}) + rv = self.post_request('/login', data={'email': user.email, 'password': 'password'}, org=self.factory.org) # login handler should not be called login_user_mock.assert_not_called() - # check for redirect back to login page - self.assertEquals(rv.status_code, 301) - self.assertIn('/login', rv.headers.get('Location', None)) + # check if error is raised + self.assertEquals(rv.status_code, 200) + self.assertIn('Wrong email or password', rv.data) def test_disabled_user_should_not_access_api(self): # Note: some API does not require user, so check the one which requires diff --git a/tests/handlers/test_visualizations.py b/tests/handlers/test_visualizations.py index 85db9ffe..cc266c74 100644 --- a/tests/handlers/test_visualizations.py +++ b/tests/handlers/test_visualizations.py @@ -27,7 +27,7 @@ class VisualizationResourceTest(BaseTestCase): rv = self.make_request('delete', '/api/visualizations/{}'.format(visualization.id)) self.assertEquals(rv.status_code, 200) - self.assertEquals(models.db.session.query(models.Visualization).count(), 0) + self.assertEquals(models.Visualization.query.count(), 0) def test_update_visualization(self): visualization = self.factory.create_visualization() diff --git 
a/tests/models/test_query_results.py b/tests/models/test_query_results.py index 94a160ba..a3c805c8 100644 --- a/tests/models/test_query_results.py +++ b/tests/models/test_query_results.py @@ -43,7 +43,7 @@ class QueryResultTest(BaseTestCase): def test_get_latest_returns_the_most_recent_result(self): yesterday = utcnow() - datetime.timedelta(seconds=30) - old_qr = self.factory.create_query_result(retrieved_at=yesterday) + self.factory.create_query_result(retrieved_at=yesterday) qr = self.factory.create_query_result() found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query_text, 60) @@ -52,7 +52,7 @@ class QueryResultTest(BaseTestCase): def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): yesterday = utcnow() + datetime.timedelta(days=-100) - very_old = self.factory.create_query_result(retrieved_at=yesterday) + self.factory.create_query_result(retrieved_at=yesterday) yesterday = utcnow() + datetime.timedelta(days=-1) qr = self.factory.create_query_result(retrieved_at=yesterday) diff --git a/tests/tasks/test_queries.py b/tests/tasks/test_queries.py index 17680a00..14f80e07 100644 --- a/tests/tasks/test_queries.py +++ b/tests/tasks/test_queries.py @@ -19,7 +19,7 @@ class TestPrune(TestCase): for score in range(0, 100): key = 'k:{}'.format(score) self.keys.append(key) - redis_connection.zadd(self.list, score, key) + redis_connection.zadd(self.list, {key: score}) redis_connection.set(key, 1) def test_does_nothing_when_below_threshold(self): diff --git a/tests/test_models.py b/tests/test_models.py index 31404069..4c83cb40 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,15 +1,15 @@ #encoding: utf8 import calendar import datetime -import json from unittest import TestCase -import mock +import pytz +import walrus from dateutil.parser import parse as date_parse -from tests import BaseTestCase +from tests import BaseTestCase, authenticated_user -from redash import models -from redash.models import db +from redash 
import models, redis_connection +from redash.models import db, types from redash.utils import gen_query_hash, utcnow @@ -282,9 +282,6 @@ class QueryOutdatedQueriesTest(BaseTestCase): class QueryArchiveTest(BaseTestCase): - def setUp(self): - super(QueryArchiveTest, self).setUp() - def test_archive_query_sets_flag(self): query = self.factory.create_query() db.session.flush() @@ -315,7 +312,7 @@ class QueryArchiveTest(BaseTestCase): db.session.commit() query.archive() db.session.flush() - self.assertEqual(db.session.query(models.Widget).get(widget.id), None) + self.assertEqual(models.Widget.query.get(widget.id), None) def test_removes_scheduling(self): query = self.factory.create_query(schedule={'interval':'1', 'until':None, 'time': None, 'day_of_week':None}) @@ -330,19 +327,19 @@ class QueryArchiveTest(BaseTestCase): db.session.commit() query.archive() db.session.flush() - self.assertEqual(db.session.query(models.Alert).get(subscription.alert.id), None) - self.assertEqual(db.session.query(models.AlertSubscription).get(subscription.id), None) + self.assertEqual(models.Alert.query.get(subscription.alert.id), None) + self.assertEqual(models.AlertSubscription.query.get(subscription.id), None) class TestUnusedQueryResults(BaseTestCase): def test_returns_only_unused_query_results(self): two_weeks_ago = utcnow() - datetime.timedelta(days=14) qr = self.factory.create_query_result() - query = self.factory.create_query(latest_query_data=qr) + self.factory.create_query(latest_query_data=qr) db.session.flush() unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) - self.assertIn((unused_qr.id,), models.QueryResult.unused()) - self.assertNotIn((qr.id,), list(models.QueryResult.unused())) + self.assertIn(unused_qr, list(models.QueryResult.unused())) + self.assertNotIn(qr, list(models.QueryResult.unused())) def test_returns_only_over_a_week_old_results(self): two_weeks_ago = utcnow() - datetime.timedelta(days=14) @@ -350,8 +347,8 @@ class 
TestUnusedQueryResults(BaseTestCase): db.session.flush() new_unused_qr = self.factory.create_query_result() - self.assertIn((unused_qr.id,), models.QueryResult.unused()) - self.assertNotIn((new_unused_qr.id,), models.QueryResult.unused()) + self.assertIn(unused_qr, list(models.QueryResult.unused())) + self.assertNotIn(new_unused_qr, list(models.QueryResult.unused())) class TestQueryAll(BaseTestCase): @@ -546,6 +543,7 @@ def _set_up_dashboard_test(d): d.w2.dashboard.is_draft = False d.w4.dashboard.is_draft = False + class TestDashboardAll(BaseTestCase): def setUp(self): super(TestDashboardAll, self).setUp() @@ -596,3 +594,85 @@ class TestDashboardAll(BaseTestCase): self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.group_ids, None)) + + +class Timestamp(walrus.Model): + __database__ = redis_connection + __namespace__ = 'redash.tests.timestamp' + + created_at = types.UTCDateTimeField(index=True, default=utcnow) + + +class TestUserDetail(BaseTestCase): + + def setUp(self): + super(TestUserDetail, self).setUp() + redis_connection.flushdb() + + def test_walrus_utcdatetimefield(self): + timestamp = Timestamp() + timestamp.save() + + timestamps = list(Timestamp.all()) + self.assertEqual(len(timestamps), 1) + self.assertIsInstance(timestamps[0].created_at, datetime.datetime) + self.assertEqual(timestamps[0].created_at.tzinfo, pytz.utc) + + def test_userdetail_db_default(self): + with authenticated_user(self.client) as user: + self.assertEqual(user.details, {}) + self.assertIsNone(user.active_at) + + def test_userdetail_db_default_save(self): + with authenticated_user(self.client) as user: + user.details['test'] = 1 + models.db.session.commit() + + user_reloaded = models.User.query.filter_by(id=user.id).first() + self.assertEqual(user.details['test'], 1) + self.assertEqual( + user_reloaded, + models.User.query.filter( + 
models.User.details['test'].astext.cast(models.db.Integer) == 1 + ).first() + ) + + def test_userdetail_create(self): + self.assertEqual(len(list(models.UserDetail.all())), 0) + user_detail = models.UserDetail.create(user_id=1) + user_detail.save() + self.assertEqual( + models.UserDetail.get(models.UserDetail.user_id == 1)._id, + user_detail._id, + ) + + def test_userdetail_update(self): + self.assertEqual(len(list(models.UserDetail.all())), 0) + # first try to create a user with a user id that we haven't used before + # and see if the creation was successful + models.UserDetail.update(user_id=1000) # non-existent user + all_user_details = list(models.UserDetail.all()) + self.assertEqual(len(all_user_details), 1) + created_user_detail = all_user_details[0] + + # then see if we can update the same user detail again + updated_user_detail = models.UserDetail.update( + user_id=created_user_detail.user_id + ) + self.assertGreater( + updated_user_detail.updated_at, + created_user_detail.updated_at + ) + + def test_sync(self): + with authenticated_user(self.client) as user: + user_detail = models.UserDetail.update(user_id=user.id) + self.assertEqual(user.details, {}) + + self.assertEqual(len(list(models.UserDetail.all())), 1) + models.UserDetail.sync() + self.assertEqual(len(list(models.UserDetail.all())), 0) + + user_reloaded = models.User.query.filter_by(id=user.id).first() + self.assertIn('active_at', user_reloaded.details) + self.assertEqual(user_reloaded.active_at, user_detail.updated_at)