From 8e1c852b0db2a1d86d2e0c7f34bebd5960a502d5 Mon Sep 17 00:00:00 2001 From: Arik Fraimovich Date: Sun, 18 May 2014 13:57:08 +0300 Subject: [PATCH] Reset the database lock if pid changed --- redash/__init__.py | 23 ++++++++++++++--------- redash/models.py | 12 ++++++++++++ redash/tasks.py | 26 +++++++++++++++++++------- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/redash/__init__.py b/redash/__init__.py index afbd1d8c..20cdc5f5 100644 --- a/redash/__init__.py +++ b/redash/__init__.py @@ -17,14 +17,19 @@ def setup_logging(): events.setup_logging(settings.EVENTS_LOG_PATH, settings.EVENTS_CONSOLE_OUTPUT) + +def create_redis_connection(): + redis_url = urlparse.urlparse(settings.REDIS_URL) + if redis_url.path: + redis_db = redis_url.path[1] + else: + redis_db = 0 + + r = redis.StrictRedis(host=redis_url.hostname, port=redis_url.port, db=redis_db, password=redis_url.password) + + return r + + setup_logging() - -redis_url = urlparse.urlparse(settings.REDIS_URL) -if redis_url.path: - redis_db = redis_url.path[1] -else: - redis_db = 0 - -# TODO: move this to function that create a connection? -redis_connection = redis.StrictRedis(host=redis_url.hostname, port=redis_url.port, db=redis_db, password=redis_url.password) +redis_connection = create_redis_connection() statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX) \ No newline at end of file diff --git a/redash/models.py b/redash/models.py index 7b792703..423c50a3 100644 --- a/redash/models.py +++ b/redash/models.py @@ -1,6 +1,8 @@ import json import hashlib import logging +import os +import threading import time import datetime import itertools @@ -19,18 +21,28 @@ class Database(object): self.database_name = self.database_config.pop('name') self.database = peewee.PostgresqlDatabase(self.database_name, **self.database_config) self.app = None + self.pid = os.getpid() def init_app(self, app): self.app = app self.register_handlers() def connect_db(self): + self._check_pid() self.database.connect() def close_db(self, exc): + self._check_pid() if not self.database.is_closed(): self.database.close() + def _check_pid(self): + current_pid = os.getpid() + if self.pid != current_pid: + logging.info("New pid detected (%d!=%d); resetting database lock.", self.pid, current_pid) + self.pid = os.getpid() + self.database._conn_lock = threading.Lock() + def register_handlers(self): self.app.before_request(self.connect_db) self.app.teardown_request(self.close_db) diff --git a/redash/tasks.py b/redash/tasks.py index f690fb3c..db416961 100644 --- a/redash/tasks.py +++ b/redash/tasks.py @@ -1,17 +1,29 @@ import time import datetime -from celery.utils.log import get_task_logger import logging -from celery.result import AsyncResult import redis -from redash.data.query_runner import get_query_runner -from redash import models, redis_connection, statsd_client -from redash.worker import celery +from celery import Task +from celery.result import AsyncResult +from celery.utils.log import get_task_logger +from redash import redis_connection, models, statsd_client from redash.utils import gen_query_hash +from redash.worker import celery +from redash.data.query_runner import get_query_runner logger = get_task_logger(__name__) +class BaseTask(Task): + abstract = True + + def after_return(self, *args, **kwargs): + models.db.close_db(None) + + def __call__(self, *args, **kwargs): + models.db.connect_db() + return super(BaseTask, self).__call__(*args, **kwargs) + + class QueryTask(object): MAX_RETRIES = 5 @@ -105,7 +117,7 @@ class QueryTask(object): return self._async_result.revoke(terminate=True) -@celery.task +@celery.task(base=BaseTask) def refresh_queries(): # self.status['last_refresh_at'] = time.time() # self._save_status() @@ -134,7 +146,7 @@ def refresh_queries(): statsd_client.gauge('manager.seconds_since_refresh', now - float(status.get('last_refresh_at', now))) -@celery.task(bind=True, track_started=True) +@celery.task(bind=True, base=BaseTask, track_started=True) def execute_query(self, query, data_source_id): # TODO: maybe this should be a class? start_time = time.time()