From 3d6ca7d5a4bb7069876bdd5972104a625082a2ac Mon Sep 17 00:00:00 2001 From: Zachary Wasserman Date: Wed, 1 Mar 2017 13:14:26 -0800 Subject: [PATCH] Use sockjs to gracefully degrade websockets (#1255) Use the [SockJS Protocol](https://github.com/sockjs/sockjs-protocol) to handle bidirectional communication instead of plain websockets. This allows distributed queries to function in situations in which they previously failed (Load balancers not supporting websockets, issues with Safari and self-signed certs, etc.). Also includes fixes to the JS message handling logic where slightly different message delivery semantics (when using XHR) were exposing bugs. Fixes #1241, #1327. --- CHANGELOG.md | 10 +- docs/third-party/licenses.md | 10 ++ frontend/kolide/base.js | 4 +- frontend/kolide/index.js | 4 +- .../pages/queries/QueryPage/QueryPage.jsx | 22 +-- .../redux/nodes/entities/campaigns/helpers.js | 132 ++++++++------- .../nodes/entities/campaigns/helpers.tests.js | 119 +++++++------ glide.lock | 8 +- glide.yaml | 1 + package.json | 1 + server/service/endpoint_campaigns.go | 48 ++++-- server/service/handler.go | 8 +- server/websocket/websocket.go | 36 +--- server/websocket/websocket_test.go | 159 +++++++++--------- yarn.lock | 63 ++++++- 15 files changed, 360 insertions(+), 265 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ca53d9a3..0c9366606 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,12 @@ -* Fix issue with Distributed Query Pack results full screen feature that broke the browser scrolling abilities +* Fix Distributed Query compatibility with load balancers and Safari. + + Customers running Kolide behind a web balancer lacking support for + websockets were unable to use the distributed query feature. Also, in + certain circumstances, Safari users with a self-signed cert for Kolide + would receive an error. This release add a fallback mechanism from + websockets using SockJS for improved compatibility. + +* Fix issue with Distributed Query Pack results full screen feature that broke the browser scrolling abilities. * Fix bug in which host counts in the sidebar did not match up with displayed hosts. diff --git a/docs/third-party/licenses.md b/docs/third-party/licenses.md index 424847441..5f63580ae 100644 --- a/docs/third-party/licenses.md +++ b/docs/third-party/licenses.md @@ -311,6 +311,7 @@ Third-Party Licenses | [etag](https://www.npmjs.com/package/etag) | [MIT](https://opensource.org/licenses/MIT) | | [event-emitter](https://www.npmjs.com/package/event-emitter) | [MIT](https://opensource.org/licenses/MIT) | | [events](https://www.npmjs.com/package/events) | [MIT](https://opensource.org/licenses/MIT) | +| [eventsource](https://www.npmjs.com/package/eventsource) | [MIT](https://opensource.org/licenses/MIT) | | [evp_bytestokey](https://www.npmjs.com/package/evp_bytestokey) | [MIT](https://opensource.org/licenses/MIT) | | [exit-hook](https://www.npmjs.com/package/exit-hook) | [MIT](https://opensource.org/licenses/MIT) | | [expand-brackets](https://www.npmjs.com/package/expand-brackets) | [MIT](https://opensource.org/licenses/MIT) | @@ -324,6 +325,7 @@ Third-Party Licenses | [extsprintf](https://www.npmjs.com/package/extsprintf) | [MIT](https://opensource.org/licenses/MIT) | | [fast-levenshtein](https://www.npmjs.com/package/fast-levenshtein) | [MIT](https://opensource.org/licenses/MIT) | | [fastparse](https://www.npmjs.com/package/fastparse) | [MIT](https://opensource.org/licenses/MIT) | +| [faye-websocket](https://www.npmjs.com/package/faye-websocket) | [MIT](https://opensource.org/licenses/MIT) | | [fbjs](https://www.npmjs.com/package/fbjs) | [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause) | | [figures](https://www.npmjs.com/package/figures) | [MIT](https://opensource.org/licenses/MIT) | | [file-entry-cache](https://www.npmjs.com/package/file-entry-cache) | [MIT](https://opensource.org/licenses/MIT) | @@ -381,6 +383,7 @@ Third-Party Licenses | [github.com/gorilla/mux](https://github.com/gorilla/mux) | [NewBSD](https://opensource.org/licenses/BSD-3-Clause) | | [github.com/gorilla/websocket](https://github.com/gorilla/websocket) | [FreeBSD](https://opensource.org/licenses/BSD-2-Clause) | | [github.com/hashicorp/hcl](https://github.com/hashicorp/hcl) | [MPL-2.0](https://www.mozilla.org/en-US/MPL/2.0/) | +| [github.com/igm/sockjs-go](https://github.com/igm/sockjs-go) | [NewBSD](https://opensource.org/licenses/BSD-3-Clause) | | [github.com/inconshreveable/mousetrap](https://github.com/inconshreveable/mousetrap) | [Apache-2.0](https://www.apache.org/licenses/LICENSE-2.0) | | [github.com/jmoiron/sqlx](https://github.com/jmoiron/sqlx) | [MIT](https://opensource.org/licenses/MIT) | | [github.com/jordan-wright/email](https://github.com/jordan-wright/email) | [MIT](https://opensource.org/licenses/MIT) | @@ -660,6 +663,7 @@ Third-Party Licenses | [onetime](https://www.npmjs.com/package/onetime) | [MIT](https://opensource.org/licenses/MIT) | | [optimist](https://www.npmjs.com/package/optimist) | [MIT/X11](https://opensource.org/licenses/MIT) | | [optionator](https://www.npmjs.com/package/optionator) | [MIT](https://opensource.org/licenses/MIT) | +| [original](https://www.npmjs.com/package/original) | [MIT](https://opensource.org/licenses/MIT) | | [os-browserify](https://www.npmjs.com/package/os-browserify) | [MIT](https://opensource.org/licenses/MIT) | | [os-homedir](https://www.npmjs.com/package/os-homedir) | [MIT](https://opensource.org/licenses/MIT) | | [os-locale](https://www.npmjs.com/package/os-locale) | [MIT](https://opensource.org/licenses/MIT) | @@ -761,6 +765,7 @@ Third-Party Licenses | [query-string](https://www.npmjs.com/package/query-string) | [MIT](https://opensource.org/licenses/MIT) | | [querystring](https://www.npmjs.com/package/querystring) | [MIT](https://opensource.org/licenses/MIT) | | [querystring-es3](https://www.npmjs.com/package/querystring-es3) | [MIT](https://opensource.org/licenses/MIT) | +| [querystringify](https://www.npmjs.com/package/querystringify) | [MIT](https://opensource.org/licenses/MIT) | | [raf](https://www.npmjs.com/package/raf) | [MIT](https://opensource.org/licenses/MIT) | | [randomatic](https://www.npmjs.com/package/randomatic) | [MIT](https://opensource.org/licenses/MIT) | | [randombytes](https://www.npmjs.com/package/randombytes) | [MIT](https://opensource.org/licenses/MIT) | @@ -815,6 +820,7 @@ Third-Party Licenses | [require-hacker](https://www.npmjs.com/package/require-hacker) | [MIT](https://opensource.org/licenses/MIT) | | [require-main-filename](https://www.npmjs.com/package/require-main-filename) | [ISC](https://opensource.org/licenses/ISC) | | [require-uncached](https://www.npmjs.com/package/require-uncached) | [MIT](https://opensource.org/licenses/MIT) | +| [requires-port](https://www.npmjs.com/package/requires-port) | [MIT](https://opensource.org/licenses/MIT) | | [resolve](https://www.npmjs.com/package/resolve) | [MIT](https://opensource.org/licenses/MIT) | | [resolve-from](https://www.npmjs.com/package/resolve-from) | [MIT](https://opensource.org/licenses/MIT) | | [restore-cursor](https://www.npmjs.com/package/restore-cursor) | [MIT](https://opensource.org/licenses/MIT) | @@ -845,6 +851,7 @@ Third-Party Licenses | [slash](https://www.npmjs.com/package/slash) | [MIT](https://opensource.org/licenses/MIT) | | [slice-ansi](https://www.npmjs.com/package/slice-ansi) | [MIT](https://opensource.org/licenses/MIT) | | [sntp](https://www.npmjs.com/package/sntp) | [BSD](https://opensource.org/licenses/BSD-3-Clause) | +| [sockjs-client](https://www.npmjs.com/package/sockjs-client) | [MIT](https://opensource.org/licenses/MIT) | | [sort-keys](https://www.npmjs.com/package/sort-keys) | [MIT](https://opensource.org/licenses/MIT) | | [source-list-map](https://www.npmjs.com/package/source-list-map) | [MIT](https://opensource.org/licenses/MIT) | | [source-map](https://www.npmjs.com/package/source-map) | [BSD-3-Clause](https://opensource.org/licenses/BSD-3-Clause) | @@ -922,6 +929,7 @@ Third-Party Licenses | [upper-case](https://www.npmjs.com/package/upper-case) | [MIT](https://opensource.org/licenses/MIT) | | [url](https://www.npmjs.com/package/url) | [MIT](https://opensource.org/licenses/MIT) | | [url-loader](https://www.npmjs.com/package/url-loader) | [MIT](https://opensource.org/licenses/MIT) | +| [url-parse](https://www.npmjs.com/package/url-parse) | [MIT](https://opensource.org/licenses/MIT) | | [user-home](https://www.npmjs.com/package/user-home) | [MIT](https://opensource.org/licenses/MIT) | | [util](https://www.npmjs.com/package/util) | [MIT](https://opensource.org/licenses/MIT) | | [util-deprecate](https://www.npmjs.com/package/util-deprecate) | [MIT](https://opensource.org/licenses/MIT) | @@ -945,6 +953,8 @@ Third-Party Licenses | [webpack-hot-middleware](https://www.npmjs.com/package/webpack-hot-middleware) | [MIT](https://opensource.org/licenses/MIT) | | [webpack-hot-middleware-example](https://www.npmjs.com/package/webpack-hot-middleware-example) | [MIT](https://opensource.org/licenses/MIT) | | [webpack-sources](https://www.npmjs.com/package/webpack-sources) | [MIT](https://opensource.org/licenses/MIT) | +| [websocket-driver](https://www.npmjs.com/package/websocket-driver) | [MIT](https://opensource.org/licenses/MIT) | +| [websocket-extensions](https://www.npmjs.com/package/websocket-extensions) | [MIT](https://opensource.org/licenses/MIT) | | [whatwg-encoding](https://www.npmjs.com/package/whatwg-encoding) | [WTFPL](http://www.wtfpl.net/txt/copying/) | | [whatwg-fetch](https://www.npmjs.com/package/whatwg-fetch) | [MIT](https://opensource.org/licenses/MIT) | | [whatwg-url](https://www.npmjs.com/package/whatwg-url) | [MIT](https://opensource.org/licenses/MIT) | diff --git a/frontend/kolide/base.js b/frontend/kolide/base.js index 181d6666c..7b3667ee3 100644 --- a/frontend/kolide/base.js +++ b/frontend/kolide/base.js @@ -12,10 +12,9 @@ const REQUEST_METHODS = { class Base { constructor () { - const { host, origin } = global.window.location; + const { origin } = global.window.location; this.baseURL = `${origin}/api`; - this.websocketBaseURL = `wss://${host}/api`; this.bearerToken = local.getItem('auth_token'); } @@ -119,4 +118,3 @@ class Base { } export default Base; - diff --git a/frontend/kolide/index.js b/frontend/kolide/index.js index 2a95bef36..6e9c91903 100644 --- a/frontend/kolide/index.js +++ b/frontend/kolide/index.js @@ -1,4 +1,5 @@ import { get, omit, trim } from 'lodash'; +import SockJS from 'sockjs-client'; import { appendTargetTypeToTargets } from 'redux/nodes/entities/targets/helpers'; import Base from 'kolide/base'; @@ -239,10 +240,11 @@ class Kolide extends Base { queries: { run: (campaignID) => { return new Promise((resolve) => { - const socket = new global.WebSocket(`${this.websocketBaseURL}/v1/kolide/results/${campaignID}`); + const socket = new SockJS(`${this.baseURL}/v1/kolide/results`, undefined, {}); socket.onopen = () => { socket.send(JSON.stringify({ type: 'auth', data: { token: local.getItem('auth_token') } })); + socket.send(JSON.stringify({ type: 'select_campaign', data: { campaign_id: campaignID } })); }; return resolve(socket); diff --git a/frontend/pages/queries/QueryPage/QueryPage.jsx b/frontend/pages/queries/QueryPage/QueryPage.jsx index ba5274bad..d23236489 100644 --- a/frontend/pages/queries/QueryPage/QueryPage.jsx +++ b/frontend/pages/queries/QueryPage/QueryPage.jsx @@ -189,6 +189,7 @@ export class QueryPage extends Component { return Kolide.websockets.queries.run(campaignResponse.id) .then((socket) => { this.setupDistributedQuery(socket); + this.setState({ campaign: campaignResponse, queryIsRunning: true, @@ -199,26 +200,17 @@ export class QueryPage extends Component { const { previousSocketData } = this; if (previousSocketData && isEqual(socketData, previousSocketData)) { - this.previousSocketData = socketData; - return false; } + this.previousSocketData = socketData; - return campaignHelpers.update(this.state.campaign, socketData) - .then((updatedCampaign) => { - const { status } = updatedCampaign; + this.setState(campaignHelpers.updateCampaignState(socketData)); - if (status === 'finished') { - this.teardownDistributedQuery(); + if (socketData.type === 'status' && socketData.data === 'finished') { + return this.teardownDistributedQuery(); + } - return false; - } - - this.previousSocketData = socketData; - this.setState({ campaign: updatedCampaign }); - - return false; - }); + return false; }; }); }) diff --git a/frontend/redux/nodes/entities/campaigns/helpers.js b/frontend/redux/nodes/entities/campaigns/helpers.js index 01cfbd086..3fe3cf1e3 100644 --- a/frontend/redux/nodes/entities/campaigns/helpers.js +++ b/frontend/redux/nodes/entities/campaigns/helpers.js @@ -2,66 +2,76 @@ export const destroyFunc = (campaign) => { return Promise.resolve(campaign); }; -export const update = (campaign, socketData) => { - return new Promise((resolve) => { - const { type, data } = socketData; - - if (type === 'totals') { - return resolve({ - ...campaign, - totals: data, - }); - } - - if (type === 'result') { - const queryResults = campaign.query_results || []; - const hosts = campaign.hosts || []; - const { host, rows } = data; - const { hosts_count: hostsCount } = campaign; - let newHostsCount; - - if (data.error) { - const newFailed = hostsCount.failed + 1; - const newTotal = hostsCount.successful + newFailed; - - newHostsCount = { - successful: hostsCount.successful, - failed: newFailed, - total: newTotal, - }; - } else { - const newSuccessful = hostsCount.successful + 1; - const newTotal = hostsCount.failed + newSuccessful; - - newHostsCount = { - successful: newSuccessful, - failed: hostsCount.failed, - total: newTotal, - }; - } - - return resolve({ - ...campaign, - hosts: [ - ...hosts, - host, - ], - query_results: [ - ...queryResults, - ...rows, - ], - hosts_count: newHostsCount, - }); - } - - if (type === 'status') { - const { status } = data; - - return resolve({ ...campaign, status }); - } - - return resolve(campaign); - }); +const updateCampaignStateFromTotals = (campaign, { data }) => { + return { + campaign: { ...campaign, totals: data }, + }; }; -export default { destroyFunc, update }; +const updateCampaignStateFromResults = (campaign, { data }) => { + const queryResults = campaign.query_results || []; + const hosts = campaign.hosts || []; + const { host, rows } = data; + const { hosts_count: hostsCount } = campaign; + const newHosts = [...hosts, host]; + const newQueryResults = [...queryResults, ...rows]; + let newHostsCount; + + if (data.error) { + const newFailed = hostsCount.failed + 1; + const newTotal = hostsCount.successful + newFailed; + + newHostsCount = { + successful: hostsCount.successful, + failed: newFailed, + total: newTotal, + }; + } else { + const newSuccessful = hostsCount.successful + 1; + const newTotal = hostsCount.failed + newSuccessful; + + newHostsCount = { + successful: newSuccessful, + failed: hostsCount.failed, + total: newTotal, + }; + } + + return { + campaign: { + ...campaign, + hosts: newHosts, + query_results: newQueryResults, + hosts_count: newHostsCount, + }, + }; +}; + +const updateCampaignStateFromStatus = (campaign, { data }) => { + const { status } = data; + const updatedCampaign = { ...campaign, status }; + + return { + campaign: updatedCampaign, + queryIsRunning: data !== 'finished', + }; +}; + +export const updateCampaignState = (socketData) => { + return (prevState) => { + const { campaign } = prevState; + + switch (socketData.type) { + case 'totals': + return updateCampaignStateFromTotals(campaign, socketData); + case 'result': + return updateCampaignStateFromResults(campaign, socketData); + case 'status': + return updateCampaignStateFromStatus(campaign, socketData); + default: + return { campaign }; + } + }; +}; + +export default { destroyFunc, updateCampaignState }; diff --git a/frontend/redux/nodes/entities/campaigns/helpers.tests.js b/frontend/redux/nodes/entities/campaigns/helpers.tests.js index 7c5c95a87..b4106347f 100644 --- a/frontend/redux/nodes/entities/campaigns/helpers.tests.js +++ b/frontend/redux/nodes/entities/campaigns/helpers.tests.js @@ -28,7 +28,7 @@ const campaignWithResults = { online: 2, }, }; -const { destroyFunc, update } = helpers; +const { destroyFunc, updateCampaignState } = helpers; const resultSocketData = { type: 'result', data: { @@ -40,6 +40,10 @@ const resultSocketData = { ], }, }; +const statusSocketData = { + type: 'status', + data: 'finished', +}; const totalsSocketData = { type: 'totals', data: { @@ -60,67 +64,56 @@ describe('campaign entity - helpers', () => { }); }); - describe('#update', () => { - it('appends query results to the campaign when the campaign has query results', (done) => { - update(campaignWithResults, resultSocketData) - .then((response) => { - expect(response.query_results).toEqual([ - ...campaignWithResults.query_results, - { feature: 'product_name', value: 'Intel Core' }, - { feature: 'family', value: '0600' }, - ]); - expect(response.hosts).toInclude(host); - done(); - }) - .catch(done); + describe('#updateCampaignState', () => { + it('appends query results to the campaign when the campaign has query results', () => { + const state = { campaign: campaignWithResults }; + const updatedState = updateCampaignState(resultSocketData)(state, {}); + + expect(updatedState.campaign.query_results).toEqual([ + ...campaignWithResults.query_results, + { feature: 'product_name', value: 'Intel Core' }, + { feature: 'family', value: '0600' }, + ]); + expect(updatedState.campaign.hosts).toInclude(host); }); - it('adds query results to the campaign when the campaign does not have query results', (done) => { - update(campaign, resultSocketData) - .then((response) => { - expect(response.query_results).toEqual([ - { feature: 'product_name', value: 'Intel Core' }, - { feature: 'family', value: '0600' }, - ]); - expect(response.hosts).toInclude(host); - done(); - }) - .catch(done); + it('adds query results to the campaign when the campaign does not have query results', () => { + const state = { campaign }; + const updatedState = updateCampaignState(resultSocketData)(state, {}); + + expect(updatedState.campaign.query_results).toEqual([ + { feature: 'product_name', value: 'Intel Core' }, + { feature: 'family', value: '0600' }, + ]); + expect(updatedState.campaign.hosts).toInclude(host); }); - it('updates totals on the campaign when the campaign has totals', (done) => { - update(campaignWithResults, totalsSocketData) - .then((response) => { - expect(response.totals).toEqual(totalsSocketData.data); - done(); - }) - .catch(done); + it('updates totals on the campaign when the campaign has totals', () => { + const state = { campaign: campaignWithResults }; + const updatedState = updateCampaignState(totalsSocketData)(state, {}); + + expect(updatedState.campaign.totals).toEqual(totalsSocketData.data); }); - it('adds totals to the campaign when the campaign does not have totals', (done) => { - update(campaign, totalsSocketData) - .then((response) => { - expect(response.totals).toEqual(totalsSocketData.data); - done(); - }) - .catch(done); + it('adds totals to the campaign when the campaign does not have totals', () => { + const state = { campaign }; + const updatedState = updateCampaignState(totalsSocketData)(state, {}); + + expect(updatedState.campaign.totals).toEqual(totalsSocketData.data); }); - it('increases the successful hosts count and total when the result has no error', (done) => { - update(campaign, resultSocketData) - .then((response) => { - expect(response.hosts_count).toEqual({ - successful: 1, - failed: 0, - total: 1, - }); + it('increases the successful hosts count and total when the result has no error', () => { + const state = { campaign }; + const updatedState = updateCampaignState(resultSocketData)(state, {}); - done(); - }) - .catch(done); + expect(updatedState.campaign.hosts_count).toEqual({ + successful: 1, + failed: 0, + total: 1, + }); }); - it('increases the failed hosts count and total when the result has an error', (done) => { + it('increases the failed hosts count and total when the result has an error', () => { const resultErrorSocketData = { type: 'result', data: { @@ -129,17 +122,21 @@ describe('campaign entity - helpers', () => { }, }; - update(campaign, resultErrorSocketData) - .then((response) => { - expect(response.hosts_count).toEqual({ - successful: 0, - failed: 1, - total: 1, - }); + const state = { campaign }; + const updatedState = updateCampaignState(resultErrorSocketData)(state, {}); - done(); - }) - .catch(done); + expect(updatedState.campaign.hosts_count).toEqual({ + successful: 0, + failed: 1, + total: 1, + }); + }); + + it('sets the queryIsRunning attribute for status socket data', () => { + const state = { campaign }; + const updatedState = updateCampaignState(statusSocketData)(state, {}); + + expect(updatedState.queryIsRunning).toEqual(false); }); }); }); diff --git a/glide.lock b/glide.lock index 4417eb38c..aa91857d8 100644 --- a/glide.lock +++ b/glide.lock @@ -1,5 +1,5 @@ -hash: 2e60f99104f57234e2a153d088d23ae4db10245052c3335696e9abe7e9e2064e -updated: 2017-02-06T20:10:33.182812801-08:00 +hash: 23a51e1dbed0a504e02209cd5ef0fbb1bcb894ca7861a818154f7e8fab5fac48 +updated: 2017-02-28T11:03:55.538531823-08:00 imports: - name: github.com/alecthomas/template version: a0175ee3bccc567396460bf5acd36800cb10c49c @@ -66,6 +66,10 @@ imports: - json/parser - json/scanner - json/token +- name: github.com/igm/sockjs-go + version: 1f275fbd3bcc9a21ec90217b80f40db44404410b + subpackages: + - sockjs - name: github.com/inconshreveable/mousetrap version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75 - name: github.com/jmoiron/sqlx diff --git a/glide.yaml b/glide.yaml index a1d238591..7e6618c26 100644 --- a/glide.yaml +++ b/glide.yaml @@ -66,3 +66,4 @@ import: version: ^0.3.0 - package: github.com/go-yaml/yaml - package: github.com/ryanuber/go-license +- package: github.com/igm/sockjs-go diff --git a/package.json b/package.json index 7d90287f3..3027632ba 100644 --- a/package.json +++ b/package.json @@ -72,6 +72,7 @@ "require-hacker": "^2.1.4", "sass-loader": "^4.0.2", "select": "^1.0.6", + "sockjs-client": "^1.1.2", "sqlite-parser": "^1.0.0", "style-loader": "^0.13.0", "stylus-loader": "1.5.1", diff --git a/server/service/endpoint_campaigns.go b/server/service/endpoint_campaigns.go index fb8f81d2b..392343ada 100644 --- a/server/service/endpoint_campaigns.go +++ b/server/service/endpoint_campaigns.go @@ -1,10 +1,12 @@ package service import ( + "encoding/json" "net/http" "github.com/go-kit/kit/endpoint" kitlog "github.com/go-kit/kit/log" + "github.com/igm/sockjs-go/sockjs" "github.com/kolide/kolide/server/contexts/viewer" "github.com/kolide/kolide/server/kolide" "github.com/kolide/kolide/server/websocket" @@ -45,15 +47,13 @@ func makeCreateDistributedQueryCampaignEndpoint(svc kolide.Service) endpoint.End // Stream Distributed Query Campaign Results and Metadata //////////////////////////////////////////////////////////////////////////////// -func makeStreamDistributedQueryCampaignResultsHandler(svc kolide.Service, jwtKey string, logger kitlog.Logger) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // Upgrade to websocket connection - conn, err := websocket.Upgrade(w, r) - if err != nil { - logger.Log("err", err, "msg", "could not upgrade to websocket") - return - } - defer conn.Close() +func makeStreamDistributedQueryCampaignResultsHandler(svc kolide.Service, jwtKey string, logger kitlog.Logger) http.Handler { + opt := sockjs.DefaultOptions + opt.Websocket = true + return sockjs.NewHandler("/api/v1/kolide/results", opt, func(session sockjs.Session) { + defer session.Close(0, "none") + + conn := &websocket.Conn{Session: session} // Receive the auth bearer token token, err := conn.ReadAuthToken() @@ -72,14 +72,34 @@ func makeStreamDistributedQueryCampaignResultsHandler(svc kolide.Service, jwtKey ctx := viewer.NewContext(context.Background(), *vc) - campaignID, err := idFromRequest(r, "id") + msg, err := conn.ReadJSONMessage() if err != nil { - logger.Log("err", err, "invalid campaign ID") - conn.WriteJSONError("invalid campaign ID") + logger.Log("err", err, "msg", "reading select_campaign JSON") + conn.WriteJSONError("error reading select_campaign") + return + } + if msg.Type != "select_campaign" { + logger.Log("err", "unexpected msg type, expected select_campaign", "msg-type", msg.Type) + conn.WriteJSONError("expected select_campaign") return } - svc.StreamCampaignResults(ctx, conn, campaignID) + var info struct { + CampaignID uint `json:"campaign_id"` + } + err = json.Unmarshal(*(msg.Data.(*json.RawMessage)), &info) + if err != nil { + logger.Log("err", err, "msg", "unmarshaling select_campaign data") + conn.WriteJSONError("error unmarshaling select_campaign data") + return + } + if info.CampaignID == 0 { + logger.Log("err", "campaign ID not set") + conn.WriteJSONError("0 is not a valid campaign ID") + return + } - } + svc.StreamCampaignResults(ctx, conn, info.CampaignID) + + }) } diff --git a/server/service/handler.go b/server/service/handler.go index 2b37d29ae..31b928a48 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -316,12 +316,12 @@ func MakeHandler(ctx context.Context, svc kolide.Service, jwtKey string, logger r := mux.NewRouter() attachKolideAPIRoutes(r, kolideHandlers) - r.HandleFunc("/api/v1/kolide/results/{id}", - makeStreamDistributedQueryCampaignResultsHandler(svc, jwtKey, logger)). - Methods("GET").Name("distributed_query_results") - addMetrics(r) + r.PathPrefix("/api/v1/kolide/results/"). + Handler(makeStreamDistributedQueryCampaignResultsHandler(svc, jwtKey, logger)). + Name("distributed_query_results") + return r } diff --git a/server/websocket/websocket.go b/server/websocket/websocket.go index b25b0d750..6fd266581 100644 --- a/server/websocket/websocket.go +++ b/server/websocket/websocket.go @@ -4,10 +4,10 @@ package websocket import ( "encoding/json" - "net/http" "time" - "github.com/gorilla/websocket" + "github.com/igm/sockjs-go/sockjs" + "github.com/kolide/kolide/server/contexts/token" "github.com/pkg/errors" ) @@ -42,32 +42,20 @@ type JSONMessage struct { // Conn is a wrapper for a standard websocket connection with utility methods // added for interacting with Kolide specific message types. type Conn struct { - *websocket.Conn - Timeout time.Duration + sockjs.Session } -// Upgrade is used to upgrade a normal HTTP request to a websocket connection. -func Upgrade(w http.ResponseWriter, r *http.Request) (*Conn, error) { - var upgrader = websocket.Upgrader{ - HandshakeTimeout: defaultTimeout, - } - - conn, err := upgrader.Upgrade(w, r, nil) +func (c *Conn) WriteJSON(msg JSONMessage) error { + buf, err := json.Marshal(msg) if err != nil { - return nil, errors.Wrap(err, "upgrading connection") + return errors.Wrap(err, "marshalling JSON") } - - conn.SetReadLimit(maxMessageSize) - - return &Conn{conn, defaultTimeout}, nil + return errors.Wrap(c.Send(string(buf)), "sending") } // WriteJSONMessage writes the provided data as JSON (using the Message struct), // returning any error condition from the connection. func (c *Conn) WriteJSONMessage(typ string, data interface{}) error { - c.SetWriteDeadline(time.Now().Add(c.Timeout)) - defer c.SetWriteDeadline(time.Time{}) - return c.WriteJSON(JSONMessage{Type: typ, Data: data}) } @@ -86,20 +74,14 @@ func (c *Conn) WriteJSONError(data interface{}) error { // json.Unmarshal(*(msg.Data.(*json.RawMessage)), &foo) // } func (c *Conn) ReadJSONMessage() (*JSONMessage, error) { - c.SetReadDeadline(time.Now().Add(c.Timeout)) - defer c.SetReadDeadline(time.Time{}) - - mType, data, err := c.ReadMessage() + data, err := c.Recv() if err != nil { return nil, errors.Wrap(err, "reading from websocket") } - if mType != websocket.TextMessage { - return nil, errors.Errorf("unsupported websocket message type: %d", mType) - } msg := &JSONMessage{Data: &json.RawMessage{}} - if err := json.Unmarshal(data, msg); err != nil { + if err := json.Unmarshal([]byte(data), msg); err != nil { return nil, errors.Wrap(err, "parsing msg json") } diff --git a/server/websocket/websocket_test.go b/server/websocket/websocket_test.go index ef851fe98..7afa664a0 100644 --- a/server/websocket/websocket_test.go +++ b/server/websocket/websocket_test.go @@ -2,49 +2,57 @@ package websocket import ( "encoding/json" + "errors" "fmt" - "net/http" "net/http/httptest" "net/url" "testing" "time" "github.com/gorilla/websocket" - "github.com/kolide/kolide/server/contexts/token" - "github.com/pkg/errors" + + "github.com/igm/sockjs-go/sockjs" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestTimeout(t *testing.T) { - completed := make(chan struct{}) - handler := func(w http.ResponseWriter, req *http.Request) { - defer func() { completed <- struct{}{} }() - - conn, err := Upgrade(w, req) - require.Nil(t, err) - defer conn.Close() - - conn.Timeout = 1 * time.Millisecond - - _, err = conn.ReadJSONMessage() - assert.NotNil(t, err, "read should timeout and error") - } - - // Connect to websocket handler server - srv := httptest.NewServer(http.HandlerFunc(handler)) - u, _ := url.Parse(srv.URL) - u.Scheme = "ws" - conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) +// readOpenMessage reads the sockjs open message +func readOpenMessage(t *testing.T, conn *websocket.Conn) { + // Read the open message + mType, data, err := conn.ReadMessage() + require.Equal(t, websocket.TextMessage, mType) require.Nil(t, err) - defer conn.Close() - select { - case <-completed: - // Normal - case <-time.After(1 * time.Second): - t.Error("handler did not complete") - } + require.Equal(t, []byte("o"), data, "expected sockjs open message") +} + +// readJSONMessage reads a sockjs JSON message +func readJSONMessage(t *testing.T, conn *websocket.Conn) string { + mType, data, err := conn.ReadMessage() + require.Nil(t, err) + require.Equal(t, websocket.TextMessage, mType) + + assert.Equal(t, "a", string(data[0]), "expected sockjs data frame") + + // Unwrap from sockjs frame + d := []string{} + err = json.Unmarshal(data[1:], &d) + require.Nil(t, err) + require.Len(t, d, 1) + + return d[0] +} + +func writeJSONMessage(t *testing.T, conn *websocket.Conn, typ string, data interface{}) { + buf, err := json.Marshal(JSONMessage{typ, data}) + require.Nil(t, err) + + // Wrap in sockjs frame + d, err := json.Marshal([]string{string(buf)}) + require.Nil(t, err) + + // Writes from the client to the server do not include the "a" + conn.WriteMessage(websocket.TextMessage, d) } func TestWriteJSONMessage(t *testing.T) { @@ -74,34 +82,33 @@ func TestWriteJSONMessage(t *testing.T) { for _, tt := range cases { t.Run("", func(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - conn, err := Upgrade(w, req) - require.Nil(t, err) - defer conn.Close() + handler := sockjs.NewHandler("/test", sockjs.DefaultOptions, func(session sockjs.Session) { + defer session.Close(0, "none") + + conn := &Conn{session} require.Nil(t, conn.WriteJSONMessage(tt.typ, tt.data)) - } + }) - // Connect to websocket handler server - srv := httptest.NewServer(http.HandlerFunc(handler)) + srv := httptest.NewServer(handler) u, _ := url.Parse(srv.URL) u.Scheme = "ws" + u.Path += "/test/123/abcdefghijklmnop/websocket" + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) require.Nil(t, err) defer conn.Close() + readOpenMessage(t, conn) dataJSON, err := json.Marshal(tt.data) require.Nil(t, err) // Ensure we read the correct message - mType, data, err := conn.ReadMessage() - require.Nil(t, err) - assert.Equal(t, websocket.TextMessage, mType) + data := readJSONMessage(t, conn) assert.JSONEq(t, fmt.Sprintf(`{"type": "%s", "data": %s}`, tt.typ, dataJSON), - string(data), + data, ) - }) } } @@ -126,34 +133,33 @@ func TestWriteJSONError(t *testing.T) { for _, tt := range cases { t.Run("", func(t *testing.T) { - handler := func(w http.ResponseWriter, req *http.Request) { - conn, err := Upgrade(w, req) - require.Nil(t, err) - defer conn.Close() + handler := sockjs.NewHandler("/test", sockjs.DefaultOptions, func(session sockjs.Session) { + defer session.Close(0, "none") + + conn := &Conn{session} require.Nil(t, conn.WriteJSONError(tt.err)) - } + }) - // Connect to websocket handler server - srv := httptest.NewServer(http.HandlerFunc(handler)) + srv := httptest.NewServer(handler) u, _ := url.Parse(srv.URL) u.Scheme = "ws" + u.Path += "/test/123/abcdefghijklmnop/websocket" + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) require.Nil(t, err) defer conn.Close() + readOpenMessage(t, conn) errJSON, err := json.Marshal(tt.err) require.Nil(t, err) // Ensure we read the correct message - mType, data, err := conn.ReadMessage() - require.Nil(t, err) - assert.Equal(t, websocket.TextMessage, mType) + data := readJSONMessage(t, conn) assert.JSONEq(t, fmt.Sprintf(`{"type": "error", "data": %s}`, errJSON), - string(data), + data, ) - }) } } @@ -194,12 +200,11 @@ func TestReadJSONMessage(t *testing.T) { require.Nil(t, err) completed := make(chan struct{}) - handler := func(w http.ResponseWriter, req *http.Request) { + handler := sockjs.NewHandler("/test", sockjs.DefaultOptions, func(session sockjs.Session) { + defer session.Close(0, "none") defer func() { completed <- struct{}{} }() - conn, err := Upgrade(w, req) - require.Nil(t, err) - defer conn.Close() + conn := &Conn{session} msg, err := conn.ReadJSONMessage() if tt.err == nil { @@ -211,19 +216,21 @@ func TestReadJSONMessage(t *testing.T) { assert.Equal(t, tt.typ, msg.Type) assert.EqualValues(t, &dataJSON, msg.Data) - - } + }) // Connect to websocket handler server - srv := httptest.NewServer(http.HandlerFunc(handler)) + srv := httptest.NewServer(handler) u, _ := url.Parse(srv.URL) u.Scheme = "ws" - wsConn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + u.Path += "/test/123/abcdefghijklmnop/websocket" + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) require.Nil(t, err) - conn := &Conn{wsConn, defaultTimeout} defer conn.Close() - require.Nil(t, conn.WriteJSONMessage(tt.typ, tt.data)) + readOpenMessage(t, conn) + + writeJSONMessage(t, conn, tt.typ, tt.data) select { case <-completed: @@ -239,7 +246,7 @@ func TestReadAuthToken(t *testing.T) { var cases = []struct { typ string data authData - token token.Token + token string err error }{ { @@ -262,12 +269,11 @@ func TestReadAuthToken(t *testing.T) { for _, tt := range cases { t.Run("", func(t *testing.T) { completed := make(chan struct{}) - handler := func(w http.ResponseWriter, req *http.Request) { + handler := sockjs.NewHandler("/test", sockjs.DefaultOptions, func(session sockjs.Session) { + defer session.Close(0, "none") defer func() { completed <- struct{}{} }() - conn, err := Upgrade(w, req) - require.Nil(t, err) - defer conn.Close() + conn := &Conn{session} token, err := conn.ReadAuthToken() if tt.err == nil { @@ -277,19 +283,22 @@ func TestReadAuthToken(t *testing.T) { return } - assert.Equal(t, tt.token, token) - } + assert.EqualValues(t, tt.token, token) + }) // Connect to websocket handler server - srv := httptest.NewServer(http.HandlerFunc(handler)) + srv := httptest.NewServer(handler) u, _ := url.Parse(srv.URL) u.Scheme = "ws" - wsConn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + u.Path += "/test/123/abcdefghijklmnop/websocket" + + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) require.Nil(t, err) - conn := &Conn{wsConn, defaultTimeout} defer conn.Close() - require.Nil(t, conn.WriteJSONMessage(tt.typ, tt.data)) + readOpenMessage(t, conn) + + writeJSONMessage(t, conn, tt.typ, tt.data) select { case <-completed: diff --git a/yarn.lock b/yarn.lock index 55d5848be..3235633a7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -2267,6 +2267,12 @@ events@^1.0.0: version "1.1.1" resolved "https://registry.yarnpkg.com/events/-/events-1.1.1.tgz#9ebdb7635ad099c70dcc4c2a1f5004288e8bd924" +eventsource@0.1.6: + version "0.1.6" + resolved "https://registry.yarnpkg.com/eventsource/-/eventsource-0.1.6.tgz#0acede849ed7dd1ccc32c811bb11b944d4f29232" + dependencies: + original ">=0.0.5" + evp_bytestokey@^1.0.0: version "1.0.0" resolved "https://registry.yarnpkg.com/evp_bytestokey/-/evp_bytestokey-1.0.0.tgz#497b66ad9fef65cd7c08a6180824ba1476b66e53" @@ -2366,6 +2372,12 @@ fastparse@^1.1.1: version "1.1.1" resolved "https://registry.yarnpkg.com/fastparse/-/fastparse-1.1.1.tgz#d1e2643b38a94d7583b479060e6c4affc94071f8" +faye-websocket@~0.11.0: + version "0.11.1" + resolved "https://registry.yarnpkg.com/faye-websocket/-/faye-websocket-0.11.1.tgz#f0efe18c4f56e4f40afc7e06c719fd5ee6188f38" + dependencies: + websocket-driver ">=0.5.1" + fbjs@^0.8.1, fbjs@^0.8.4: version "0.8.9" resolved "https://registry.yarnpkg.com/fbjs/-/fbjs-0.8.9.tgz#180247fbd347dcc9004517b904f865400a0c8f14" @@ -3342,7 +3354,7 @@ json-stringify-safe@^5.0.1, json-stringify-safe@~5.0.1: version "5.0.1" resolved "https://registry.yarnpkg.com/json-stringify-safe/-/json-stringify-safe-5.0.1.tgz#1296a2d58fd45f19a0f6ce01d65701e2c735b6eb" -json3@3.3.2: +json3@3.3.2, json3@^3.3.2: version "3.3.2" resolved "https://registry.yarnpkg.com/json3/-/json3-3.3.2.tgz#3c0434743df93e2f5c42aee7b19bcb483575f4e1" @@ -4237,6 +4249,12 @@ optionator@^0.8.1: type-check "~0.3.2" wordwrap "~1.0.0" +original@>=0.0.5: + version "1.0.0" + resolved "https://registry.yarnpkg.com/original/-/original-1.0.0.tgz#9147f93fa1696d04be61e01bd50baeaca656bd3b" + dependencies: + url-parse "1.0.x" + os-browserify@^0.2.0: version "0.2.1" resolved "https://registry.yarnpkg.com/os-browserify/-/os-browserify-0.2.1.tgz#63fc4ccee5d2d7763d26bbf8601078e6c2e0044f" @@ -4909,6 +4927,10 @@ querystring@0.2.0, querystring@^0.2.0: version "0.2.0" resolved "https://registry.yarnpkg.com/querystring/-/querystring-0.2.0.tgz#b209849203bb25df820da756e747005878521620" +querystringify@0.0.x: + version "0.0.4" + resolved "https://registry.yarnpkg.com/querystringify/-/querystringify-0.0.4.tgz#0cf7f84f9463ff0ae51c4c4b142d95be37724d9c" + raf@^3.1.0: version "3.3.0" resolved "https://registry.yarnpkg.com/raf/-/raf-3.3.0.tgz#93845eeffc773f8129039f677f80a36044eee2c3" @@ -5309,6 +5331,10 @@ require-uncached@^1.0.2: caller-path "^0.1.0" resolve-from "^1.0.0" +requires-port@1.0.x: + version "1.0.0" + resolved "https://registry.yarnpkg.com/requires-port/-/requires-port-1.0.0.tgz#925d2601d39ac485e091cf0da5c6e694dc3dcaff" + resolve-from@^1.0.0: version "1.0.1" resolved "https://registry.yarnpkg.com/resolve-from/-/resolve-from-1.0.1.tgz#26cbfe935d1aeeeabb29bc3fe5aeb01e93d44226" @@ -5495,6 +5521,17 @@ sntp@1.x.x: dependencies: hoek "2.x.x" +sockjs-client@^1.1.2: + version "1.1.2" + resolved "https://registry.yarnpkg.com/sockjs-client/-/sockjs-client-1.1.2.tgz#f0212a8550e4c9468c8cceaeefd2e3493c033ad5" + dependencies: + debug "^2.2.0" + eventsource "0.1.6" + faye-websocket "~0.11.0" + inherits "^2.0.1" + json3 "^3.3.2" + url-parse "^1.1.1" + sort-keys@^1.0.0: version "1.1.2" resolved "https://registry.yarnpkg.com/sort-keys/-/sort-keys-1.1.2.tgz#441b6d4d346798f1b4e49e8920adfba0e543f9ad" @@ -6007,6 +6044,20 @@ url-loader@^0.5.7: loader-utils "0.2.x" mime "1.2.x" +url-parse@1.0.x: + version "1.0.5" + resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.0.5.tgz#0854860422afdcfefeb6c965c662d4800169927b" + dependencies: + querystringify "0.0.x" + requires-port "1.0.x" + +url-parse@^1.1.1: + version "1.1.7" + resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.1.7.tgz#025cff999653a459ab34232147d89514cc87d74a" + dependencies: + querystringify "0.0.x" + requires-port "1.0.x" + url@^0.11.0: version "0.11.0" resolved "https://registry.yarnpkg.com/url/-/url-0.11.0.tgz#3838e97cfc60521eb73c525a8e55bfdd9e2e28f1" @@ -6182,6 +6233,16 @@ webpack@1.13.1: watchpack "^0.2.1" webpack-core "~0.6.0" +websocket-driver@>=0.5.1: + version "0.6.5" + resolved "https://registry.yarnpkg.com/websocket-driver/-/websocket-driver-0.6.5.tgz#5cb2556ceb85f4373c6d8238aa691c8454e13a36" + dependencies: + websocket-extensions ">=0.1.1" + +websocket-extensions@>=0.1.1: + version "0.1.1" + resolved "https://registry.yarnpkg.com/websocket-extensions/-/websocket-extensions-0.1.1.tgz#76899499c184b6ef754377c2dbb0cd6cb55d29e7" + whatwg-encoding@^1.0.1: version "1.0.1" resolved "https://registry.yarnpkg.com/whatwg-encoding/-/whatwg-encoding-1.0.1.tgz#3c6c451a198ee7aec55b1ec61d0920c67801a5f4"