diff --git a/server/service/service_campaigns.go b/server/service/service_campaigns.go index 9ecf559e9..ff5761d0e 100644 --- a/server/service/service_campaigns.go +++ b/server/service/service_campaigns.go @@ -70,6 +70,17 @@ type targetTotals struct { MissingInAction uint `json:"missing_in_action"` } +const ( + campaignStatusPending = "pending" + campaignStatusFinished = "finished" +) + +type campaignStatus struct { + ExpectedResults uint `json:"expected_results"` + ActualResults uint `json:"actual_results"` + Status string `json:"status"` +} + func (svc service) StreamCampaignResults(ctx context.Context, conn *websocket.Conn, campaignID uint) { // Find the campaign and ensure it is active campaign, err := svc.ds.DistributedQueryCampaign(campaignID) @@ -107,6 +118,12 @@ func (svc service) StreamCampaignResults(ctx context.Context, conn *websocket.Co return } + status := campaignStatus{ + Status: campaignStatusPending, + } + + lastStatus := status.Status + // Loop, pushing updates to results and expected totals for { select { @@ -118,6 +135,7 @@ func (svc service) StreamCampaignResults(ctx context.Context, conn *websocket.Co if err != nil { fmt.Println("error writing to channel") } + status.ActualResults++ } case <-time.After(1 * time.Second): @@ -142,10 +160,22 @@ func (svc service) StreamCampaignResults(ctx context.Context, conn *websocket.Co Offline: metrics.OfflineHosts, MissingInAction: metrics.MissingInActionHosts, } - if err = conn.WriteJSONMessage("totals", totals); err != nil { return } + + status.ExpectedResults = totals.Online + if status.ActualResults >= status.ExpectedResults { + status.Status = campaignStatusFinished + } + // only write status message if status has changed + if lastStatus != status.Status { + lastStatus = status.Status + if err = conn.WriteJSONMessage("status", status); err != nil { + return + } + } + } }