Merge pull request #220 from EverythingMe/celery

Reconnect to database on every task.
This commit is contained in:
Arik Fraimovich 2014-05-18 13:59:42 +03:00
commit b0cc646b5e
3 changed files with 45 additions and 16 deletions

View File

@ -17,14 +17,19 @@ def setup_logging():
events.setup_logging(settings.EVENTS_LOG_PATH, settings.EVENTS_CONSOLE_OUTPUT)
def create_redis_connection():
    """Build a StrictRedis client from settings.REDIS_URL.

    The URL path selects the database number ("redis://host:6379/10" ->
    db 10); an empty or bare "/" path falls back to db 0.

    Returns:
        redis.StrictRedis client for the configured host/port/db/password.
    """
    redis_url = urlparse.urlparse(settings.REDIS_URL)
    # Take the whole path after the leading slash: the original
    # `path[1]` grabbed a single character, truncating multi-digit
    # database numbers (e.g. "/10" became db 1) and raising IndexError
    # on a bare "/" path.
    db_part = redis_url.path.lstrip('/')
    redis_db = int(db_part) if db_part else 0
    r = redis.StrictRedis(host=redis_url.hostname, port=redis_url.port,
                          db=redis_db, password=redis_url.password)
    return r
setup_logging()
redis_url = urlparse.urlparse(settings.REDIS_URL)
if redis_url.path:
redis_db = redis_url.path[1]
else:
redis_db = 0
# TODO: move this to function that create a connection?
redis_connection = redis.StrictRedis(host=redis_url.hostname, port=redis_url.port, db=redis_db, password=redis_url.password)
redis_connection = create_redis_connection()
statsd_client = StatsClient(host=settings.STATSD_HOST, port=settings.STATSD_PORT, prefix=settings.STATSD_PREFIX)

View File

@ -1,6 +1,8 @@
import json
import hashlib
import logging
import os
import threading
import time
import datetime
import itertools
@ -19,18 +21,28 @@ class Database(object):
self.database_name = self.database_config.pop('name')
self.database = peewee.PostgresqlDatabase(self.database_name, **self.database_config)
self.app = None
self.pid = os.getpid()
def init_app(self, app):
    """Bind this Database wrapper to a Flask application and install
    the per-request connect/close hooks on it."""
    self.app = app
    self.register_handlers()
def connect_db(self):
    """Open the underlying database connection.

    Runs the fork check first so a child process (e.g. a celery
    worker) never reuses state inherited from its parent.
    """
    self._check_pid()
    self.database.connect()
def close_db(self, exc):
    """Close the database connection if it is currently open.

    `exc` is the teardown exception (ignored); the signature matches
    Flask's teardown_request handler contract.
    """
    self._check_pid()
    if self.database.is_closed():
        return
    self.database.close()
def _check_pid(self):
    """Detect a fork since this object was created.

    If the process id changed, replace the peewee connection lock so
    the child process does not share (and potentially deadlock on)
    the lock object inherited from its parent.
    """
    current_pid = os.getpid()
    if self.pid == current_pid:
        return
    logging.info("New pid detected (%d!=%d); resetting database lock.",
                 self.pid, current_pid)
    self.pid = current_pid
    # NOTE(review): reaches into peewee's private _conn_lock — verify
    # against the installed peewee version.
    self.database._conn_lock = threading.Lock()
def register_handlers(self):
    """Hook connection management into the Flask request lifecycle:
    connect before each request, close on request teardown."""
    self.app.before_request(self.connect_db)
    self.app.teardown_request(self.close_db)

View File

@ -1,17 +1,29 @@
import time
import datetime
from celery.utils.log import get_task_logger
import logging
from celery.result import AsyncResult
import redis
from redash.data.query_runner import get_query_runner
from redash import models, redis_connection, statsd_client
from redash.worker import celery
from celery import Task
from celery.result import AsyncResult
from celery.utils.log import get_task_logger
from redash import redis_connection, models, statsd_client
from redash.utils import gen_query_hash
from redash.worker import celery
from redash.data.query_runner import get_query_runner
logger = get_task_logger(__name__)
class BaseTask(Task):
    """Celery task base class that opens a fresh database connection
    before each task invocation and closes it when the task returns,
    so worker processes never reuse a stale connection."""
    abstract = True

    def __call__(self, *args, **kwargs):
        # Ensure a live connection before the task body runs.
        models.db.connect_db()
        return super(BaseTask, self).__call__(*args, **kwargs)

    def after_return(self, *args, **kwargs):
        # Always release the connection, regardless of task outcome.
        models.db.close_db(None)
class QueryTask(object):
MAX_RETRIES = 5
@ -105,7 +117,7 @@ class QueryTask(object):
return self._async_result.revoke(terminate=True)
@celery.task
@celery.task(base=BaseTask)
def refresh_queries():
# self.status['last_refresh_at'] = time.time()
# self._save_status()
@ -134,7 +146,7 @@ def refresh_queries():
statsd_client.gauge('manager.seconds_since_refresh', now - float(status.get('last_refresh_at', now)))
@celery.task(bind=True, track_started=True)
@celery.task(bind=True, base=BaseTask, track_started=True)
def execute_query(self, query, data_source_id):
# TODO: maybe this should be a class?
start_time = time.time()