Add basic math extension functions (#2123)

This commit is contained in:
Teddy Reed 2016-05-28 17:46:24 -07:00
parent d08e55545a
commit 37544138b5
4 changed files with 229 additions and 36 deletions

View File

@ -4,6 +4,7 @@ ADD_OSQUERY_LIBRARY(TRUE osquery_sql
ADD_OSQUERY_LIBRARY(FALSE osquery_sql_internal
sqlite_util.cpp
sqlite_math.cpp
virtual_table.cpp
)

205
osquery/sql/sqlite_math.cpp Normal file
View File

@ -0,0 +1,205 @@
/*
* Copyright (c) 2014-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree. An additional grant
* of patent rights can be found in the PATENTS file in the same directory.
*
*/
#include <assert.h>
#include <errno.h>
#include <math.h>
#include <string.h>
#include <functional>
#include <sqlite3.h>
namespace osquery {
using DoubleDoubleFunction = std::function<double(double)>;
/**
* @brief Call a math function that takes a double and returns a double.
*/
static void callDoubleFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv,
DoubleDoubleFunction f) {
double rVal = 0.0, val;
assert(argc == 1);
switch (sqlite3_value_type(argv[0])) {
case SQLITE_NULL:
sqlite3_result_null(context);
break;
default:
rVal = sqlite3_value_double(argv[0]);
errno = 0;
val = f(rVal);
if (errno == 0) {
sqlite3_result_double(context, val);
} else {
sqlite3_result_error(context, strerror(errno), errno);
}
break;
}
}
static void sinFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, sin);
}
static void cosFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, cos);
}
static void tanFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, tan);
}
static void asinFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, asin);
}
static void acosFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, acos);
}
static void atanFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, atan);
}
static double cot(double x) { return 1.0 / tan(x); }
static void cotFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, cot);
}
static void logFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, log);
}
static void log10Func(sqlite3_context *context,
int argc,
sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, log10);
}
static void sqrtFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, sqrt);
}
static void expFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, exp);
}
static void powerFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv) {
double r1 = 0.0;
double r2 = 0.0;
double val;
assert(argc == 2);
if (sqlite3_value_type(argv[0]) == SQLITE_NULL ||
sqlite3_value_type(argv[1]) == SQLITE_NULL) {
sqlite3_result_null(context);
} else {
r1 = sqlite3_value_double(argv[0]);
r2 = sqlite3_value_double(argv[1]);
errno = 0;
val = pow(r1, r2);
if (errno == 0) {
sqlite3_result_double(context, val);
} else {
sqlite3_result_error(context, strerror(errno), errno);
}
}
}
static void callCastedDoubleFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv,
DoubleDoubleFunction f) {
double rVal = 0.0;
assert(argc == 1);
switch (sqlite3_value_type(argv[0])) {
case SQLITE_INTEGER: {
int64_t iVal = sqlite3_value_int64(argv[0]);
sqlite3_result_int64(context, iVal);
break;
}
case SQLITE_NULL:
sqlite3_result_null(context);
break;
default:
rVal = sqlite3_value_double(argv[0]);
sqlite3_result_int64(context, (int64_t)f(rVal));
break;
}
}
static void ceilFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
callCastedDoubleFunc(context, argc, argv, ceil);
}
static void floorFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv) {
callCastedDoubleFunc(context, argc, argv, floor);
}
/** Convert Degrees into Radians */
static double deg2rad(double x) { return x * M_PI / 180.0; }
/** Convert Radians into Degrees */
static double rad2deg(double x) { return 180.0 * x / M_PI; }
static void rad2degFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, rad2deg);
}
static void deg2radFunc(sqlite3_context *context,
int argc,
sqlite3_value **argv) {
callDoubleFunc(context, argc, argv, deg2rad);
}
static void piFunc(sqlite3_context *context, int argc, sqlite3_value **argv) {
sqlite3_result_double(context, M_PI);
}
struct FuncDef {
const char *zFunctionName;
int nArg;
void (*xFunc)(sqlite3_context *, int, sqlite3_value **);
};
void registerMathExtensions(sqlite3 *db) {
// This approach to adding non-standard Math functions was inspired by the
// somewhat deprecated/legacy work by Liam Healy from 2010 in the extension
// functions contribution.
static const struct FuncDef aFuncs[] = {
{"sqrt", 1, sqrtFunc}, {"acos", 1, acosFunc},
{"asin", 1, asinFunc}, {"atan", 1, atanFunc},
{"cos", 1, cosFunc}, {"sin", 1, sinFunc},
{"tan", 1, tanFunc}, {"cot", 1, cotFunc},
{"exp", 1, expFunc}, {"log", 1, logFunc},
{"log10", 1, log10Func}, {"power", 2, powerFunc},
{"ceil", 1, ceilFunc}, {"floor", 1, floorFunc},
{"degrees", 1, rad2degFunc}, {"radians", 1, deg2radFunc},
{"pi", 0, piFunc},
};
for (size_t i = 0; i < sizeof(aFuncs) / sizeof(struct FuncDef); i++) {
sqlite3_create_function(db, aFuncs[i].zFunctionName, aFuncs[i].nArg,
SQLITE_UTF8, nullptr, aFuncs[i].xFunc, nullptr,
nullptr);
}
}
}

View File

@ -35,49 +35,28 @@ using SQLiteDBInstanceRef = std::shared_ptr<SQLiteDBInstance>;
*
* Details of this map are defined at: http://www.sqlite.org/c3ref/c_abort.html
*/
// clang-format off
const std::map<int, std::string> kSQLiteReturnCodes = {
{0, "SQLITE_OK"},
{1, "SQLITE_ERROR"},
{2, "SQLITE_INTERNAL"},
{3, "SQLITE_PERM"},
{4, "SQLITE_ABORT"},
{5, "SQLITE_BUSY"},
{6, "SQLITE_LOCKED"},
{7, "SQLITE_NOMEM"},
{8, "SQLITE_READONLY"},
{9, "SQLITE_INTERRUPT"},
{10, "SQLITE_IOERR"},
{11, "SQLITE_CORRUPT"},
{12, "SQLITE_NOTFOUND"},
{13, "SQLITE_FULL"},
{14, "SQLITE_CANTOPEN"},
{15, "SQLITE_PROTOCOL"},
{16, "SQLITE_EMPTY"},
{17, "SQLITE_SCHEMA"},
{18, "SQLITE_TOOBIG"},
{19, "SQLITE_CONSTRAINT"},
{20, "SQLITE_MISMATCH"},
{21, "SQLITE_MISUSE"},
{22, "SQLITE_NOLFS"},
{23, "SQLITE_AUTH"},
{24, "SQLITE_FORMAT"},
{25, "SQLITE_RANGE"},
{26, "SQLITE_NOTADB"},
{27, "SQLITE_NOTICE"},
{28, "SQLITE_WARNING"},
{100, "SQLITE_ROW"},
{0, "SQLITE_OK"}, {1, "SQLITE_ERROR"}, {2, "SQLITE_INTERNAL"},
{3, "SQLITE_PERM"}, {4, "SQLITE_ABORT"}, {5, "SQLITE_BUSY"},
{6, "SQLITE_LOCKED"}, {7, "SQLITE_NOMEM"}, {8, "SQLITE_READONLY"},
{9, "SQLITE_INTERRUPT"}, {10, "SQLITE_IOERR"}, {11, "SQLITE_CORRUPT"},
{12, "SQLITE_NOTFOUND"}, {13, "SQLITE_FULL"}, {14, "SQLITE_CANTOPEN"},
{15, "SQLITE_PROTOCOL"}, {16, "SQLITE_EMPTY"}, {17, "SQLITE_SCHEMA"},
{18, "SQLITE_TOOBIG"}, {19, "SQLITE_CONSTRAINT"}, {20, "SQLITE_MISMATCH"},
{21, "SQLITE_MISUSE"}, {22, "SQLITE_NOLFS"}, {23, "SQLITE_AUTH"},
{24, "SQLITE_FORMAT"}, {25, "SQLITE_RANGE"}, {26, "SQLITE_NOTADB"},
{27, "SQLITE_NOTICE"}, {28, "SQLITE_WARNING"}, {100, "SQLITE_ROW"},
{101, "SQLITE_DONE"},
};
const std::map<std::string, std::string> kMemoryDBSettings = {
{"synchronous", "OFF"},
{"count_changes", "OFF"},
{"default_temp_store", "0"},
{"auto_vacuum", "FULL"},
{"journal_mode", "OFF"},
{"cache_size", "0"},
{"synchronous", "OFF"}, {"count_changes", "OFF"},
{"default_temp_store", "0"}, {"auto_vacuum", "FULL"},
{"journal_mode", "OFF"}, {"cache_size", "0"},
{"page_count", "0"},
};
// clang-format on
#define OpComparator(x) \
{ x, QueryPlanner::Opcode(OpReg::P2, INTEGER_TYPE) }
@ -200,6 +179,9 @@ static inline void openOptimized(sqlite3*& db) {
settings += "PRAGMA " + setting.first + "=" + setting.second + "; ";
}
sqlite3_exec(db, settings.c_str(), nullptr, nullptr, nullptr);
// Register function extensions.
registerMathExtensions(db);
}
void SQLiteDBInstance::init() {

View File

@ -317,4 +317,9 @@ std::string getStringForSQLiteReturnCode(int code);
* should be a non-const reference to a std::vector<Row>.
*/
int queryDataCallback(void* argument, int argc, char* argv[], char* column[]);
/**
* @brief Register math-related 'custom' functions.
*/
void registerMathExtensions(sqlite3* db);
}