diff --git a/redash/cli/data_sources.py b/redash/cli/data_sources.py index 6196f1e1..b7d77e37 100644 --- a/redash/cli/data_sources.py +++ b/redash/cli/data_sources.py @@ -2,6 +2,8 @@ from sys import exit import json import click +from flask.cli import with_appcontext +from sqlalchemy.orm.exc import NoResultFound from redash import models from redash.query_runner import query_runners @@ -12,6 +14,7 @@ manager = click.Group(help="Data sources management commands.") @manager.command() +@with_appcontext @click.option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for " "all organizations).") @@ -19,9 +22,10 @@ def list(organization=None): """List currently configured data sources.""" if organization: org = models.Organization.get_by_slug(organization) - data_sources = models.DataSource.select().where(models.DataSource.org==org.id) + data_sources = models.DataSource.query.filter( + models.DataSource.org == org) else: - data_sources = models.DataSource.select() + data_sources = models.DataSource.query for i, ds in enumerate(data_sources): if i > 0: print "-" * 20 @@ -34,10 +38,12 @@ def validate_data_source_type(type): if type not in query_runners.keys(): print ("Error: the type \"{}\" is not supported (supported types: {})." .format(type, ", ".join(query_runners.keys()))) - exit() + print "OJNK" + exit(1) @manager.command() +@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to " @@ -46,10 +52,9 @@ def test(name, organization='default'): """Test connection to data source by issuing a trivial query.""" try: org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( + data_source = models.DataSource.query.filter( models.DataSource.name == name, - models.DataSource.org == org, - ) + models.DataSource.org == org).one() print "Testing connection to data source: {} (id={})".format( name, data_source.id) try: @@ -59,12 +64,13 @@ def test(name, organization='default'): exit(1) else: print "Success" - except models.DataSource.DoesNotExist: + except NoResultFound: print "Couldn't find data source named: {}".format(name) exit(1) @manager.command() +@with_appcontext @click.argument('name', default=None, required=False) @click.option('--type', default=None, help="new type for the data source") @@ -140,6 +146,7 @@ def new(name=None, type=None, options=None, organization='default'): @manager.command() +@with_appcontext @click.argument('name') @click.option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -148,13 +155,12 @@ def delete(name, organization='default'): """Delete data source by name.""" try: org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( - models.DataSource.name==name, - models.DataSource.org==org, - ) + data_source = models.DataSource.query.filter( + models.DataSource.name == name, + models.DataSource.org == org).one() print "Deleting data source: {} (id={})".format(name, data_source.id) - data_source.delete_instance(recursive=True) - except models.DataSource.DoesNotExist: + models.db.session.delete(data_source) + except NoResultFound: print "Couldn't find data source named: {}".format(name) exit(1) @@ -167,6 +173,7 @@ def update_attr(obj, attr, new_value): @manager.command() +@with_appcontext @click.argument('name') @click.option('--name', 'new_name', default=None, help="new name for the data source") @@ -183,10 +190,9 @@ def edit(name, new_name=None, options=None, type=None, organization='default'): if type is not None: validate_data_source_type(type) org = models.Organization.get_by_slug(organization) - data_source = models.DataSource.get( - models.DataSource.name==name, - models.DataSource.org==org, - ) + data_source = models.DataSource.query.filter( + models.DataSource.name == name, + models.DataSource.org == org).one() update_attr(data_source, "name", new_name) update_attr(data_source, "type", type) @@ -197,7 +203,8 @@ def edit(name, new_name=None, options=None, type=None, organization='default'): data_source.options.set_schema(schema) data_source.options.update(options) - data_source.save() + models.db.session.add(data_source) + models.db.session.commit() - except models.DataSource.DoesNotExist: + except NoResultFound: print "Couldn't find data source named: {}".format(name) diff --git a/redash/cli/database.py b/redash/cli/database.py index 0de7e257..57fb2341 100644 --- a/redash/cli/database.py +++ b/redash/cli/database.py @@ -1,9 +1,11 @@ from click import Group +from flask.cli import with_appcontext manager = Group(help="Manage the database (create/drop tables).") @manager.command() +@with_appcontext def create_tables(): """Create the database tables.""" from redash.models import db, create_db, init_db @@ -13,6 +15,7 @@ def create_tables(): @manager.command() +@with_appcontext def drop_tables(): """Drop the database tables.""" from redash.models import create_db diff --git a/redash/cli/groups.py b/redash/cli/groups.py index 7aae322d..92b264f1 100644 --- a/redash/cli/groups.py +++ b/redash/cli/groups.py @@ -1,12 +1,16 @@ from sys import exit +from sqlalchemy.orm.exc import NoResultFound +from flask.cli import with_appcontext from click import Group, argument, option + from redash import models manager = Group(help="Groups management commands.") @manager.command() +@with_appcontext @argument('name') @option('--org', 'organization', default='default', help="The organization the user belongs to (leave blank for " @@ -27,13 +31,17 @@ def create(name, permissions=None, organization='default'): print "permissions: [%s]" % ",".join(permissions) try: - models.Group.create(name=name, org=org, permissions=permissions) + models.db.session.add(models.Group( + name=name, org=org, + permissions=permissions)) + models.db.session.commit() except Exception, e: print "Failed create group: %s" % e.message exit(1) @manager.command() +@with_appcontext @argument('group_id') @option('--permissions', default=None, help="Comma separated list of permissions ('create_dashboard'," @@ -45,8 +53,8 @@ def change_permissions(group_id, permissions=None): print "Change permissions of group %s ..." % group_id try: - group = models.Group.get_by_id(group_id) - except models.Group.DoesNotExist: + group = models.Group.query.get(group_id) + except NoResultFound: print "User [%s] not found." % group_id exit(1) @@ -57,7 +65,8 @@ def change_permissions(group_id, permissions=None): group.permissions = permissions try: - group.save() + models.db.session.add(group) + models.db.session.commit() except Exception, e: print "Failed change permission: %s" % e.message exit(1) @@ -73,15 +82,16 @@ def extract_permissions_string(permissions): @manager.command() +@with_appcontext @option('--org', 'organization', default=None, help="The organization to limit to (leave blank for all).") def list(organization=None): """List all groups""" if organization: org = models.Organization.get_by_slug(organization) - groups = models.Group.select().where(models.Group.org == org) + groups = models.Group.query.filter(models.Group.org == org) else: - groups = models.Group.select() + groups = models.Group.query for i, group in enumerate(groups): if i > 0: diff --git a/redash/cli/organization.py b/redash/cli/organization.py index 04e08687..73daef06 100644 --- a/redash/cli/organization.py +++ b/redash/cli/organization.py @@ -1,35 +1,40 @@ from click import Group, argument +from flask.cli import with_appcontext + from redash import models manager = Group(help="Organization management commands.") @manager.command() +@with_appcontext @argument('domains') def set_google_apps_domains(domains): """ Sets the allowable domains to the comma separated list DOMAINS. """ - organization = models.Organization.select().first() + organization = models.Organization.query.first() k = models.Organization.SETTING_GOOGLE_APPS_DOMAINS organization.settings[k] = domains.split(',') - organization.save() - + models.db.session.add(organization) + models.db.session.commit() print "Updated list of allowed domains to: {}".format( organization.google_apps_domains) @manager.command() +@with_appcontext def show_google_apps_domains(): - organization = models.Organization.select().first() + organization = models.Organization.query.first() print "Current list of Google Apps domains: {}".format( ', '.join(organization.google_apps_domains)) @manager.command() +@with_appcontext def list(): """List all organizations""" - orgs = models.Organization.select() + orgs = models.Organization.query for i, org in enumerate(orgs): if i > 0: print "-" * 20 diff --git a/redash/cli/users.py b/redash/cli/users.py index 4f2ffdad..11049c73 100644 --- a/redash/cli/users.py +++ b/redash/cli/users.py @@ -1,6 +1,8 @@ from sys import exit from click import BOOL, Group, argument, option, prompt +from flask.cli import with_appcontext +from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.exc import IntegrityError from redash import models @@ -25,6 +27,7 @@ def build_groups(org, groups, is_admin): @manager.command() +@with_appcontext @argument('email') @option('--org', 'organization', default='default', help="the organization the user belongs to, (leave blank for " @@ -38,18 +41,19 @@ def grant_admin(email, organization='default'): admin_group = org.admin_group user = models.User.get_by_email_and_org(email, org) - if admin_group.id in user.groups: + if admin_group.id in user.group_ids: print "User is already an admin." else: - user.groups.append(org.admin_group.id) - user.save() - + user.group_ids = user.group_ids + [org.admin_group.id] + models.db.session.add(user) + models.db.session.commit() print "User updated." - except models.User.DoesNotExist: + except NoResultFound: print "User [%s] not found." % email @manager.command() +@with_appcontext @argument('email') @argument('name') @option('--org', 'organization', default='default', @@ -78,7 +82,7 @@ def create(email, name, groups, is_admin=False, google_auth=False, org = models.Organization.get_by_slug(organization) groups = build_groups(org, groups, is_admin) - user = models.User(org=org, email=email, name=name, groups=groups) + user = models.User(org=org, email=email, name=name, group_ids=groups) if not password and not google_auth: password = prompt("Password", hide_input=True, confirmation_prompt=True) @@ -86,13 +90,15 @@ def create(email, name, groups, is_admin=False, google_auth=False, user.hash_password(password) try: - user.save() + models.db.session.add(user) + models.db.session.commit() except Exception, e: print "Failed creating user: %s" % e.message exit(1) @manager.command() +@with_appcontext @argument('email') @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" @@ -103,16 +109,17 @@ def delete(email, organization=None): """ if organization: org = models.Organization.get_by_slug(organization) - deleted_count = models.User.delete().where( + deleted_count = models.User.query.filter( models.User.email == email, models.User.org == org.id, - ).execute() + ).delete() else: - deleted_count = models.User.delete().where(models.User.email == email).execute() + deleted_count = models.User.query.filter(models.User.email == email).delete() print "Deleted %d users." % deleted_count @manager.command() +@with_appcontext @argument('email') @argument('password') @option('--org', 'organization', default=None, @@ -124,16 +131,17 @@ def password(email, password, organization=None): """ if organization: org = models.Organization.get_by_slug(organization) - user = models.User.select().where( + user = models.User.query.filter( models.User.email == email, - models.User.org == org.id, + models.User.org == org, ).first() else: - user = models.User.select().where(models.User.email == email).first() + user = models.User.query.filter(models.User.email == email).first() if user is not None: user.hash_password(password) - user.save() + models.db.session.add(user) + models.db.session.commit() print "User updated." else: print "User [%s] not found." % email @@ -141,6 +149,7 @@ def password(email, password, organization=None): @manager.command() +@with_appcontext @argument('email') @argument('name') @argument('inviter_email') @@ -159,22 +168,23 @@ def invite(email, name, inviter_email, groups, is_admin=False, groups = build_groups(org, groups, is_admin) try: user_from = models.User.get_by_email_and_org(inviter_email, org) - user = models.User(org=org, name=name, email=email, groups=groups) - + user = models.User(org=org, name=name, email=email, group_ids=groups) + models.db.session.add(user) try: - user.save() - invite_url = invite_user(org, user_from, user) + models.db.session.commit() + invite_user(org, user_from, user) print "An invitation was sent to [%s] at [%s]." % (name, email) except IntegrityError as e: if "email" in e.message: print "Cannot invite. User already exists [%s]" % email else: print e - except models.User.DoesNotExist: + except NoResultFound: print "The inviter [%s] was not found." % inviter_email @manager.command() +@with_appcontext @option('--org', 'organization', default=None, help="The organization the user belongs to (leave blank for all" " organizations)") @@ -182,9 +192,9 @@ def list(organization=None): """List all users""" if organization: org = models.Organization.get_by_slug(organization) - users = models.User.select().where(models.User.org==org.id) + users = models.User.query.filter(models.User.org == org) else: - users = models.User.select() + users = models.User.query for i, user in enumerate(users): if i > 0: print "-" * 20 diff --git a/redash/models.py b/redash/models.py index e13978ad..789907ed 100644 --- a/redash/models.py +++ b/redash/models.py @@ -14,6 +14,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.event import listens_for from sqlalchemy.inspection import inspect from sqlalchemy.types import TypeDecorator +from sqlalchemy.ext.mutable import Mutable from sqlalchemy.orm import object_session # noinspection PyUnresolvedReferences from sqlalchemy.orm.exc import NoResultFound @@ -70,6 +71,8 @@ class GFKBase(object): # return peewee.Expression(self, '::', peewee.SQL(as_type)) +# XXX replace PseudoJSON and MutableDict with real JSON field + class PseudoJSON(TypeDecorator): impl = db.Text @@ -82,6 +85,33 @@ class PseudoJSON(TypeDecorator): return json.loads(value) +class MutableDict(Mutable, dict): + @classmethod + def coerce(cls, key, value): + "Convert plain dictionaries to MutableDict." + + if not isinstance(value, MutableDict): + if isinstance(value, dict): + return MutableDict(value) + + # this call will raise ValueError + return Mutable.coerce(key, value) + else: + return value + + def __setitem__(self, key, value): + "Detect dictionary set events and emit change events." + + dict.__setitem__(self, key, value) + self.changed() + + def __delitem__(self, key): + "Detect dictionary del events and emit change events." + + dict.__delitem__(self, key) + self.changed() + + class TimestampMixin(object): updated_at = Column(db.DateTime(True), default=db.func.now(), onupdate=db.func.now(), nullable=False) @@ -186,7 +216,7 @@ class Organization(TimestampMixin, db.Model): id = Column(db.Integer, primary_key=True) name = Column(db.String(255)) slug = Column(db.String(255), unique=True) - settings = Column(PseudoJSON) + settings = Column(MutableDict.as_mutable(PseudoJSON)) groups = db.relationship("Group", lazy="dynamic") __tablename__ = 'organizations' @@ -415,7 +445,9 @@ class DataSource(BelongsToOrgMixin, db.Model): @classmethod def create_with_group(cls, *args, **kwargs): data_source = cls(*args, **kwargs) - data_source_group = DataSourceGroup(data_source=data_source, group=data_source.org.default_group) + data_source_group = DataSourceGroup( + data_source=data_source, + group=data_source.org.default_group) db.session.add_all([data_source, data_source_group]) return data_source @@ -634,7 +666,7 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): is_draft = Column(db.Boolean, default=True, index=True) schedule = Column(db.String(10), nullable=True) visualizations = db.relationship("Visualization", cascade="all, delete-orphan") - options = Column(PseudoJSON, default={}) + options = Column(MutableDict.as_mutable(PseudoJSON), default={}) __tablename__ = 'queries' __mapper_args__ = { @@ -973,7 +1005,7 @@ class Alert(TimestampMixin, db.Model): query_rel = db.relationship(Query, backref='alerts', cascade="all") user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, backref='alerts') - options = Column(PseudoJSON) + options = Column(MutableDict.as_mutable(PseudoJSON)) state = Column(db.String(255), default=UNKNOWN_STATE) subscriptions = db.relationship("AlertSubscription", cascade="all, delete-orphan") last_triggered_at = Column(db.DateTime(True), nullable=True) diff --git a/tests/test_cli.py b/tests/test_cli.py index 7215a150..d292433f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,7 +14,7 @@ from redash.cli.organization import (list as list_org, set_google_apps_domains, show_google_apps_domains) from redash.cli.users import (create as create_user, delete as delete_user, grant_admin, invite, list as list_user, password) -from redash.models import DataSource, Group, Organization, User +from redash.models import DataSource, Group, Organization, User, db class DataSourceCommandTests(BaseTestCase): @@ -26,8 +26,8 @@ class DataSourceCommandTests(BaseTestCase): input="test\n%s\n\n\nexample.com\n\ntestdb\n" % (pg_i,)) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['dbname'], 'testdb') @@ -40,8 +40,8 @@ class DataSourceCommandTests(BaseTestCase): '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['host'], 'example.com') @@ -54,7 +54,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_bad_options_new(self): runner = CliRunner() @@ -65,7 +65,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_list(self): self.factory.create_data_source( @@ -122,7 +122,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) self.assertIn('Deleting', result.output) - self.assertEqual(DataSource.select().count(), 0) + self.assertEqual(DataSource.query.count(), 0) def test_connection_bad_delete(self): self.factory.create_data_source( @@ -133,7 +133,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn("Couldn't find", result.output) - self.assertEqual(DataSource.select().count(), 1) + self.assertEqual(DataSource.query.count(), 1) def test_options_edit(self): self.factory.create_data_source( @@ -147,8 +147,8 @@ class DataSourceCommandTests(BaseTestCase): '--type', 'pg']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(DataSource.select().count(), 1) - ds = DataSource.select().first() + self.assertEqual(DataSource.query.count(), 1) + ds = DataSource.query.first() self.assertEqual(ds.name, 'test2') self.assertEqual(ds.type, 'pg') self.assertEqual(ds.options['host'], 'example.com') @@ -164,7 +164,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('not supported', result.output) - ds = DataSource.select().first() + ds = DataSource.query.first() self.assertEqual(ds.type, 'sqlite') def test_bad_options_edit(self): @@ -179,7 +179,7 @@ class DataSourceCommandTests(BaseTestCase): self.assertTrue(result.exception) self.assertEqual(result.exit_code, 1) self.assertIn('invalid configuration', result.output) - ds = DataSource.select().first() + ds = DataSource.query.first() self.assertEqual(ds.type, 'sqlite') self.assertEqual(ds.options._config, {"dbpath": "/tmp/test.db"}) @@ -187,21 +187,21 @@ class DataSourceCommandTests(BaseTestCase): class GroupCommandTests(BaseTestCase): def test_create(self): - gcount = Group.select().count() + gcount = Group.query.count() perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() result = runner.invoke( create_group, ['test', '--permissions', ','.join(perms)]) - print result.output self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(Group.select().count(), gcount + 1) - g = Group.select().order_by(Group.id.desc()).first() + self.assertEqual(Group.query.count(), gcount + 1) + g = Group.query.order_by(Group.id.desc()).first() self.assertEqual(g.org, self.factory.org) self.assertEqual(g.permissions, perms) def test_change_permissions(self): g = self.factory.create_group(permissions=['list_dashboards']) + db.session.flush() g_id = g.id perms = ['create_query', 'edit_query', 'view_query'] runner = CliRunner() @@ -209,7 +209,7 @@ class GroupCommandTests(BaseTestCase): change_permissions, [str(g_id), '--permissions', ','.join(perms)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - g = Group.select().where(Group.id == g_id).first() + g = Group.query.filter(Group.id == g_id).first() self.assertEqual(g.permissions, perms) def test_list(self): @@ -245,14 +245,15 @@ class OrganizationCommandTests(BaseTestCase): result = runner.invoke(set_google_apps_domains, [','.join(domains)]) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - o = Organization.select().where( - Organization.id == self.factory.org.id).first() - self.assertEqual(o.google_apps_domains, domains) + #db.session. + db.session.refresh(self.factory.org) + self.assertEqual(self.factory.org.google_apps_domains, domains) def test_show_google_apps_domains(self): self.factory.org.settings[Organization.SETTING_GOOGLE_APPS_DOMAINS] = [ 'example.org', 'example.com'] - self.factory.org.save() + db.session.add(self.factory.org) + db.session.commit() runner = CliRunner() result = runner.invoke(show_google_apps_domains, []) self.assertFalse(result.exception) @@ -290,10 +291,10 @@ class UserCommandTests(BaseTestCase): input="password1\npassword1\n") self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.groups, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [self.factory.default_group.id]) def test_create_admin(self): runner = CliRunner() @@ -302,10 +303,10 @@ class UserCommandTests(BaseTestCase): '--password', 'password1', '--admin']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertTrue(u.verify_password('password1')) - self.assertEqual(u.groups, [self.factory.default_group.id, + self.assertEqual(u.group_ids, [self.factory.default_group.id, self.factory.admin_group.id]) def test_create_googleauth(self): @@ -314,10 +315,10 @@ class UserCommandTests(BaseTestCase): create_user, ['foobar@example.com', 'Fred Foobar', '--google']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertEqual(u.name, "Fred Foobar") self.assertIsNone(u.password_hash) - self.assertEqual(u.groups, [self.factory.default_group.id]) + self.assertEqual(u.group_ids, [self.factory.default_group.id]) def test_create_bad(self): self.factory.create_user(email='foobar@example.com') @@ -331,23 +332,23 @@ class UserCommandTests(BaseTestCase): def test_delete(self): self.factory.create_user(email='foobar@example.com') - ucount = User.select().count() + ucount = User.query.count() runner = CliRunner() result = runner.invoke( delete_user, ['foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - self.assertEqual(User.select().where(User.email == + self.assertEqual(User.query.filter(User.email == "foobar@example.com").count(), 0) - self.assertEqual(User.select().count(), ucount - 1) + self.assertEqual(User.query.count(), ucount - 1) def test_delete_bad(self): - ucount = User.select().count() + ucount = User.query.count() runner = CliRunner() result = runner.invoke( delete_user, ['foobar@example.com']) self.assertIn('Deleted 0 users', result.output) - self.assertEqual(User.select().count(), ucount) + self.assertEqual(User.query.count(), ucount) def test_password(self): self.factory.create_user(email='foobar@example.com') @@ -356,7 +357,7 @@ class UserCommandTests(BaseTestCase): password, ['foobar@example.com', 'xyzzy']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().where(User.email == "foobar@example.com").first() + u = User.query.filter(User.email == "foobar@example.com").first() self.assertTrue(u.verify_password('xyzzy')) def test_password_bad(self): @@ -394,7 +395,7 @@ class UserCommandTests(BaseTestCase): def test_list(self): self.factory.create_user(name='Fred Foobar', email='foobar@example.com', - organization=self.factory.org) + org=self.factory.org) runner = CliRunner() result = runner.invoke(list_user, []) self.assertFalse(result.exception) @@ -409,15 +410,15 @@ class UserCommandTests(BaseTestCase): textwrap.dedent(output).lstrip()) def test_grant_admin(self): - self.factory.create_user(name='Fred Foobar', + u = self.factory.create_user(name='Fred Foobar', email='foobar@example.com', org=self.factory.org, - groups=[self.factory.default_group.id]) + group_ids=[self.factory.default_group.id]) runner = CliRunner() result = runner.invoke( grant_admin, ['foobar@example.com']) self.assertFalse(result.exception) self.assertEqual(result.exit_code, 0) - u = User.select().order_by(User.id.desc()).first() - self.assertEqual(u.groups, [self.factory.default_group.id, - self.factory.admin_group.id]) + db.session.refresh(u) + self.assertEqual(u.group_ids, [self.factory.default_group.id, + self.factory.admin_group.id])