Endpoint integration tests (#28)

* Quick fix where JWTRenewalMiddleware wasn't saving the update session to the client

* integration tests for all user/account management HTTP endpoints

close #15

* Combine checks in CheckUser

* Moving t.Fatals into utility functions

* Simplifying get user by id or username flow

* Fixing incorrect error log message

* Simplifying checkUser compare
This commit is contained in:
Mike Arpaia 2016-08-02 15:39:20 -07:00 committed by GitHub
parent fc1b8eaa05
commit 24638413c4
6 changed files with 627 additions and 175 deletions

25
auth.go
View File

@ -47,7 +47,7 @@ func (vc *ViewerContext) UserID() (uint, error) {
return 0, errors.New("No user set") return 0, errors.New("No user set")
} }
func (vc *ViewerContext) CanPerformActions(db *gorm.DB) bool { func (vc *ViewerContext) CanPerformActions() bool {
if vc.user == nil { if vc.user == nil {
return false return false
} }
@ -70,12 +70,12 @@ func (vc *ViewerContext) IsUserID(id uint) bool {
return false return false
} }
func (vc *ViewerContext) CanPerformWriteActionOnUser(db *gorm.DB, u *User) bool { func (vc *ViewerContext) CanPerformWriteActionOnUser(u *User) bool {
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) return vc.CanPerformActions() && (vc.IsUserID(u.ID) || vc.IsAdmin())
} }
func (vc *ViewerContext) CanPerformReadActionOnUser(db *gorm.DB, u *User) bool { func (vc *ViewerContext) CanPerformReadActionOnUser(u *User) bool {
return vc.CanPerformActions(db) && (vc.IsUserID(u.ID) || vc.IsAdmin()) return vc.CanPerformActions()
} }
// GenerateJWT generates a JWT token in serialized string form given a // GenerateJWT generates a JWT token in serialized string form given a
@ -147,6 +147,7 @@ func JWTRenewalMiddleware(c *gin.Context) {
} }
session.Set("jwt", jwt) session.Set("jwt", jwt)
session.Save()
c.Next() c.Next()
} }
@ -246,12 +247,14 @@ func Login(c *gin.Context) {
session.Set("jwt", token) session.Set("jwt", token)
session.Save() session.Save()
c.JSON(200, map[string]interface{}{ c.JSON(200, GetUserResponseBody{
"id": user.ID, ID: user.ID,
"username": user.Username, Username: user.Username,
"email": user.Email, Name: user.Name,
"name": user.Name, Email: user.Email,
"admin": user.Admin, Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
}) })
} }

View File

@ -179,19 +179,19 @@ func TestCanPerformActionsOnUser(t *testing.T) {
adminVC := GenerateVC(admin) adminVC := GenerateVC(admin)
user1VC := GenerateVC(user1) user1VC := GenerateVC(user1)
if !adminVC.CanPerformWriteActionOnUser(db, user1) || !adminVC.CanPerformWriteActionOnUser(db, user2) { if !adminVC.CanPerformWriteActionOnUser(user1) || !adminVC.CanPerformWriteActionOnUser(user2) {
t.Fatal("Admin should be able to perform writes on users") t.Fatal("Admin should be able to perform writes on users")
} }
if !adminVC.CanPerformReadActionOnUser(db, user1) || !adminVC.CanPerformReadActionOnUser(db, user2) { if !adminVC.CanPerformReadActionOnUser(user1) || !adminVC.CanPerformReadActionOnUser(user2) {
t.Fatal("Admin should be able to perform reads on users") t.Fatal("Admin should be able to perform reads on users")
} }
if user1VC.CanPerformWriteActionOnUser(db, user2) { if user1VC.CanPerformWriteActionOnUser(user2) {
t.Fatal("user1 shouldn't be able to perform writes on user2") t.Fatal("user1 shouldn't be able to perform writes on user2")
} }
if user1VC.CanPerformReadActionOnUser(db, user2) { if !user1VC.CanPerformReadActionOnUser(user2) {
t.Fatal("user1 should be able to perform reads on user2") t.Fatal("user1 should be able to perform reads on user2")
} }

View File

@ -25,6 +25,13 @@ func UnauthorizedError(c *gin.Context) {
c.JSON(401, ServerError("Unauthorized")) c.JSON(401, ServerError("Unauthorized"))
} }
// MalformedRequestError emits a response that is appropriate in the event that
// a request is received by a user which does not have required fields or is in
// some way malformed
func MalformedRequestError(c *gin.Context) {
c.JSON(400, ServerError("Malformed request"))
}
func createTestServer() *gin.Engine { func createTestServer() *gin.Engine {
server := gin.New() server := gin.New()
server.Use(TestingDatabaseMiddleware) server.Use(TestingDatabaseMiddleware)
@ -52,7 +59,7 @@ func CreateServer() *gin.Engine {
kolide.PATCH("/user", ModifyUser) kolide.PATCH("/user", ModifyUser)
kolide.DELETE("/user", DeleteUser) kolide.DELETE("/user", DeleteUser)
kolide.PATCH("/user/password", ResetUserPassword) kolide.PATCH("/user/password", ChangeUserPassword)
kolide.PATCH("/user/admin", SetUserAdminState) kolide.PATCH("/user/admin", SetUserAdminState)
kolide.PATCH("/user/enabled", SetUserEnabledState) kolide.PATCH("/user/enabled", SetUserEnabledState)

469
story_test.go Normal file
View File

@ -0,0 +1,469 @@
package main
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
)
type integrationRequests struct {
r *gin.Engine
db *gorm.DB
t *testing.T
}
func (req *integrationRequests) New(t *testing.T) {
req.t = t
req.r = createTestServer()
req.r.Use(testSessionMiddleware)
req.r.Use(JWTRenewalMiddleware)
req.db, _ = openTestDB()
injectedTestDB = req.db
// Until we have a better solution for first-user onboarding, manually
// create an admin
_, err := NewUser(req.db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
panic(err.Error())
}
req.r.POST("/login", Login)
req.r.GET("/logout", Logout)
req.r.POST("/user", GetUser)
req.r.PUT("/user", CreateUser)
req.r.PATCH("/user", ModifyUser)
req.r.DELETE("/user", DeleteUser)
req.r.PATCH("/user/password", ChangeUserPassword)
req.r.PATCH("/user/admin", SetUserAdminState)
req.r.PATCH("/user/enabled", SetUserEnabledState)
}
func (req *integrationRequests) Login(username, password string, sessionOut *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(LoginRequestBody{
Username: username,
Password: password,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/login", buff)
request.Header.Set("Content-Type", "application/json")
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*sessionOut = response.Header().Get("Set-Cookie")
return
}
func (req *integrationRequests) CreateUser(username, password, email string, admin, reset bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(CreateUserRequestBody{
Username: username,
Password: password,
Email: email,
Admin: admin,
NeedsPasswordReset: reset,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PUT", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) GetUser(username string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(GetUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("POST", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) ModifyUser(username, name, email string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ModifyUserRequestBody{
Username: username,
Name: name,
Email: email,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) DeleteUser(username string, session *string) {
response := httptest.NewRecorder()
body, err := json.Marshal(DeleteUserRequestBody{
Username: username,
})
if err != nil {
req.t.Fatal(err.Error())
return
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("DELETE", "/user", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return
}
*session = response.Header().Get("Set-Cookie")
return
}
func (req *integrationRequests) ChangePassword(username, currentPassword, newPassword string, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(ChangePasswordRequestBody{
Username: username,
CurrentPassword: currentPassword,
NewPassword: newPassword,
NewPasswordConfim: newPassword,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/password", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *integrationRequests) SetAdminState(username string, admin bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserAdminStateRequestBody{
Username: username,
Admin: admin,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/admin", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
}
return &responseBody
}
func (req *integrationRequests) SetEnabledState(username string, enabled bool, session *string) *GetUserResponseBody {
response := httptest.NewRecorder()
body, err := json.Marshal(SetUserEnabledStateRequestBody{
Username: username,
Enabled: enabled,
})
if err != nil {
req.t.Fatal(err.Error())
return nil
}
buff := new(bytes.Buffer)
buff.Write(body)
request, _ := http.NewRequest("PATCH", "/user/enabled", buff)
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Cookie", *session)
req.r.ServeHTTP(response, request)
if response.Code != 200 {
req.t.Fatalf("Response code: %d", response.Code)
return nil
}
*session = response.Header().Get("Set-Cookie")
var responseBody GetUserResponseBody
err = json.Unmarshal(response.Body.Bytes(), &responseBody)
if err != nil {
req.t.Fatal(err.Error())
return nil
}
return &responseBody
}
func (req *integrationRequests) CheckUser(username, email, name string, admin, reset, enabled bool) {
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err != nil {
req.t.Fatal(err.Error())
return
}
if user.Email != email {
req.t.Fatalf("user's email was not set in the DB: %s", user.Email)
}
if user.Admin != admin {
req.t.Fatal("user admin settings don't match")
}
if user.NeedsPasswordReset != reset {
req.t.Fatal("user reset settings don't match")
}
if user.Enabled != enabled {
req.t.Fatal("user enabled settings don't match")
}
if user.Name != name {
req.t.Fatalf("user names don't match: %s and %s", user.Name, name)
}
return
}
func (req *integrationRequests) GetAndCheckUser(username string, session *string) {
resp := req.GetUser(username, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, resp.Enabled)
}
func (req *integrationRequests) CreateAndCheckUser(username, password, email, name string, admin, reset bool, session *string) {
resp := req.CreateUser(username, password, email, admin, reset, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *integrationRequests) ModifyAndCheckUser(username, email, name string, admin, reset bool, session *string) {
resp := req.ModifyUser(username, name, email, session)
req.CheckUser(username, email, name, admin, reset, resp.Enabled)
}
func (req *integrationRequests) DeleteAndCheckUser(username string, session *string) {
req.DeleteUser(username, session)
var user User
err := req.db.Where("username = ?", username).First(&user).Error
if err == nil {
req.t.Fatal("User should have been deleted.")
}
}
func (req *integrationRequests) SetEnabledStateAndCheckUser(username string, enabled bool, session *string) {
resp := req.SetEnabledState(username, enabled, session)
req.CheckUser(username, resp.Email, resp.Name, resp.Admin, resp.NeedsPasswordReset, enabled)
}
func (req *integrationRequests) SetAdminStateAndCheckUser(username string, admin bool, session *string) {
resp := req.SetAdminState(username, admin, session)
req.CheckUser(username, resp.Email, resp.Name, admin, resp.NeedsPasswordReset, resp.Enabled)
}
func TestUserAndAccountManagement(t *testing.T) {
// Create and configure the webserver which will be used to handle the tests
var req integrationRequests
req.New(t)
// Instantiate the variables that will store the most recent session cookie
// for each user context that will be created
var adminSession string
var admin2Session string
var user1Session string
var user2Session string
// Test logging in with the first admin
req.Login("admin", "foobar", &adminSession)
// Once admin is logged in, create a user using a valid admin session
req.CreateAndCheckUser("user1", "foobar", "user1@kolide.co", "", false, false, &adminSession)
// Once admin is logged in, create another admin account using a valid
// admin session
req.CreateAndCheckUser("admin2", "foobar", "admin2@kolide.co", "", true, false, &adminSession)
// Once admin has created admin2, log in with admin2 to get a session
// context for admin2
req.Login("admin2", "foobar", &admin2Session)
// Use an admin created via the API to create a user via the API
req.CreateAndCheckUser("user2", "foobar", "user2@kolide.co", "", false, false, &admin2Session)
// Once admin has created user1, log in with user1 to get a session context
// for user1
req.Login("user1", "foobar", &user1Session)
// Once admin2 has created user2, log in with user1 to get a session context
// for user2
req.Login("user2", "foobar", &user2Session)
// Get info on user2 as admin2
req.GetAndCheckUser("user2", &admin2Session)
// Get info on admin2 as user2
req.GetAndCheckUser("admin2", &user2Session)
// Modify user1 as admin
req.ModifyAndCheckUser("user1", "user1@kolide.co", "User One", false, false, &adminSession)
// Modify user2 as user2
req.ModifyAndCheckUser("user2", "user2@kolide.co", "User Two", false, false, &user2Session)
// admin resets user1 password
req.ChangePassword("user1", "", "bazz1", &adminSession)
// user1 logs in with new password
req.Login("user1", "bazz1", &user1Session)
// user2 resets user2 password
req.ChangePassword("user2", "foobar", "bazz2", &user2Session)
// user2 logs in with new password
req.Login("user2", "bazz2", &user2Session)
// admin2 promotes user2 to admin
req.SetAdminStateAndCheckUser("user2", true, &admin2Session)
// user2 is admin
resp := req.GetUser("user2", &user2Session)
if !resp.Admin {
t.Fatal("user2 should be an admin")
}
// admin demotes user2 from admin
req.SetAdminStateAndCheckUser("user2", false, &adminSession)
// user2 is no longer an admin
resp = req.GetUser("user2", &user2Session)
if resp.Admin {
t.Fatal("user2 shouldn't be an admin")
}
// admin sets user1 as no longer enabled
req.SetEnabledStateAndCheckUser("user1", false, &adminSession)
// user1 is no longer enabled
resp = req.GetUser("user1", &user2Session)
if resp.Enabled {
t.Fatal("user1 shouldn't be enabled")
}
// admin2 re-enables user1
req.SetEnabledStateAndCheckUser("user1", true, &admin2Session)
// user1 can view user2
req.GetUser("user2", &user2Session)
// Delete admin2 as admin1
req.DeleteAndCheckUser("admin2", &adminSession)
// Delete user2 as admin
req.DeleteAndCheckUser("user2", &adminSession)
}

175
users.go
View File

@ -81,7 +81,18 @@ func (u *User) MakeAdmin(db *gorm.DB) error {
} }
type GetUserRequestBody struct { type GetUserRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Username string `json:"username"`
}
type GetUserResponseBody struct {
ID uint `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Name string `json:"name"`
Admin bool `json:"admin"`
Enabled bool `json:"enabled"`
NeedsPasswordReset bool `json:"needs_password_reset"`
} }
func GetUser(c *gin.Context) { func GetUser(c *gin.Context) {
@ -107,26 +118,27 @@ func GetUser(c *gin.Context) {
} }
var user User var user User
err = db.Where("id = ?", body.ID).First(&user).Error user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
if !vc.CanPerformReadActionOnUser(db, &user) { if !vc.CanPerformReadActionOnUser(&user) {
UnauthorizedError(c) UnauthorizedError(c)
return return
} }
c.JSON(200, map[string]interface{}{ c.JSON(200, GetUserResponseBody{
"id": user.ID, ID: user.ID,
"username": user.Username, Username: user.Username,
"name": user.Name, Name: user.Name,
"email": user.Email, Email: user.Email,
"admin": user.Admin, Admin: user.Admin,
"enabled": user.Enabled, Enabled: user.Enabled,
"needs_password_reset": user.NeedsPasswordReset, NeedsPasswordReset: user.NeedsPasswordReset,
}) })
} }
@ -165,18 +177,26 @@ func CreateUser(c *gin.Context) {
return return
} }
_, err = NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset) user, err := NewUser(db, body.Username, body.Password, body.Email, body.Admin, body.NeedsPasswordReset)
if err != nil { if err != nil {
logrus.Errorf("Error creating new user: %s", err.Error()) logrus.Errorf("Error creating new user: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
c.JSON(200, nil) c.JSON(200, GetUserResponseBody{
ID: user.ID,
Username: user.Username,
Name: user.Name,
Email: user.Email,
Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
})
} }
type ModifyUserRequestBody struct { type ModifyUserRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Username string `json:"username"` Username string `json:"username"`
Name string `json:"name"` Name string `json:"name"`
Email string `json:"email"` Email string `json:"email"`
@ -205,29 +225,46 @@ func ModifyUser(c *gin.Context) {
} }
var user User var user User
err = db.Where("id = ?", body.ID).First(&user).Error user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
if !vc.CanPerformWriteActionOnUser(db, &user) { if !vc.CanPerformWriteActionOnUser(&user) {
UnauthorizedError(c) UnauthorizedError(c)
return return
} }
if body.Name != "" {
user.Name = body.Name
}
if body.Email != "" {
user.Email = body.Email
}
err = db.Save(&user).Error err = db.Save(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error updating user in database: %s", err.Error()) logrus.Errorf("Error updating user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
c.JSON(200, nil) c.JSON(200, GetUserResponseBody{
ID: user.ID,
Username: user.Username,
Name: user.Name,
Email: user.Email,
Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
})
} }
type DeleteUserRequestBody struct { type DeleteUserRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Username string `json:"username"`
} }
func DeleteUser(c *gin.Context) { func DeleteUser(c *gin.Context) {
@ -258,9 +295,10 @@ func DeleteUser(c *gin.Context) {
} }
var user User var user User
err = db.Where("id = ?", body.ID).First(&user).Error user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
@ -275,20 +313,29 @@ func DeleteUser(c *gin.Context) {
} }
type ResetPasswordRequestBody struct { type ResetPasswordRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Username string `json:"username"`
Password string `json:"password" binding:"required"` Password string `json:"password" binding:"required"`
PasswordConfim string `json:"password_confirm" binding:"required"` PasswordConfim string `json:"password_confirm" binding:"required"`
} }
func ResetUserPassword(c *gin.Context) { type ChangePasswordRequestBody struct {
var body ResetPasswordRequestBody ID uint `json:"id"`
Username string `json:"username"`
CurrentPassword string `json:"current_password"`
NewPassword string `json:"new_password" binding:"required"`
NewPasswordConfim string `json:"new_password_confirm" binding:"required"`
}
func ChangeUserPassword(c *gin.Context) {
var body ChangePasswordRequestBody
err := c.BindJSON(&body) err := c.BindJSON(&body)
if err != nil { if err != nil {
logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error()) logrus.Errorf("Error parsing ResetPassword post body: %s", err.Error())
return return
} }
if body.Password != body.PasswordConfim { if body.NewPassword != body.NewPasswordConfim {
c.JSON(406, map[string]interface{}{"error": "Passwords do not match"}) c.JSON(406, map[string]interface{}{"error": "Passwords do not match"})
return return
} }
@ -300,26 +347,34 @@ func ResetUserPassword(c *gin.Context) {
return return
} }
var user User
user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
vc, err := VC(c, db) vc, err := VC(c, db)
if err != nil { if err != nil {
logrus.Errorf("Could not create VC: %s", err.Error()) logrus.Errorf("Could not create VC: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
var user User
err = db.Where("id = ?", body.ID).First(&user).Error
if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(db, &user) { if !vc.IsAdmin() {
if !vc.IsUserID(user.ID) {
UnauthorizedError(c) UnauthorizedError(c)
return return
} }
if user.ValidatePassword(body.CurrentPassword) != nil {
UnauthorizedError(c)
return
}
}
err = user.SetPassword(db, body.Password) err = user.SetPassword(db, body.NewPassword)
if err != nil { if err != nil {
logrus.Errorf("Error setting user password: %s", err.Error()) logrus.Errorf("Error setting user password: %s", err.Error())
// xxx don't try to write to the db? // xxx don't try to write to the db?
@ -331,12 +386,21 @@ func ResetUserPassword(c *gin.Context) {
DatabaseError(c) DatabaseError(c)
return return
} }
c.JSON(200, nil) c.JSON(200, GetUserResponseBody{
ID: user.ID,
Username: user.Username,
Name: user.Name,
Email: user.Email,
Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
})
} }
type SetUserAdminStateRequestBody struct { type SetUserAdminStateRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Admin bool `json:"admin" binding:"required"` Username string `json:"username"`
Admin bool `json:"admin"`
} }
func SetUserAdminState(c *gin.Context) { func SetUserAdminState(c *gin.Context) {
@ -367,9 +431,10 @@ func SetUserAdminState(c *gin.Context) {
} }
var user User var user User
err = db.Where("id = ?", body.ID).First(&user).Error user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
@ -381,12 +446,21 @@ func SetUserAdminState(c *gin.Context) {
DatabaseError(c) DatabaseError(c)
return return
} }
c.JSON(200, nil) c.JSON(200, GetUserResponseBody{
ID: user.ID,
Username: user.Username,
Name: user.Name,
Email: user.Email,
Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
})
} }
type SetUserEnabledStateRequestBody struct { type SetUserEnabledStateRequestBody struct {
ID uint `json:"id" binding:"required"` ID uint `json:"id"`
Enabled bool `json:"enabled" binding:"required"` Username string `json:"username"`
Enabled bool `json:"enabled"`
} }
func SetUserEnabledState(c *gin.Context) { func SetUserEnabledState(c *gin.Context) {
@ -417,9 +491,10 @@ func SetUserEnabledState(c *gin.Context) {
} }
var user User var user User
err = db.Where("id = ?", body.ID).First(&user).Error user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil { if err != nil {
logrus.Errorf("Error finding user in database: %s", err.Error())
DatabaseError(c) DatabaseError(c)
return return
} }
@ -431,5 +506,13 @@ func SetUserEnabledState(c *gin.Context) {
DatabaseError(c) DatabaseError(c)
return return
} }
c.JSON(200, nil) c.JSON(200, GetUserResponseBody{
ID: user.ID,
Username: user.Username,
Name: user.Name,
Email: user.Email,
Admin: user.Admin,
Enabled: user.Enabled,
NeedsPasswordReset: user.NeedsPasswordReset,
})
} }

View File

@ -1,10 +1,6 @@
package main package main
import ( import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"testing" "testing"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@ -201,109 +197,3 @@ func TestSetPassword(t *testing.T) {
t.Fatal(err.Error()) t.Fatal(err.Error())
} }
} }
func TestUserManagementIntegration(t *testing.T) {
r := createTestServer()
r.Use(testSessionMiddleware)
r.Use(JWTRenewalMiddleware)
db, err := openTestDB()
if err != nil {
t.Fatal(err.Error())
}
injectedTestDB = db
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
_ = admin
r.POST("/login", Login)
r.GET("/logout", Logout)
r.GET("/user", GetUser)
r.PUT("/user", CreateUser)
r.PATCH("/user", ModifyUser)
r.DELETE("/user", DeleteUser)
res1 := httptest.NewRecorder()
body1, err := json.Marshal(LoginRequestBody{
Username: "admin",
Password: "foobar",
})
if err != nil {
t.Fatal(err.Error())
}
buff1 := new(bytes.Buffer)
buff1.Write(body1)
req1, _ := http.NewRequest("POST", "/login", buff1)
req1.Header.Set("Content-Type", "application/json")
r.ServeHTTP(res1, req1)
if res1.Code != 200 {
t.Fatalf("Response code: %d", res1.Code)
}
res2 := httptest.NewRecorder()
body2, err := json.Marshal(CreateUserRequestBody{
Username: "marpaia",
Password: "foobar",
Email: "mike@kolide.co",
Admin: false,
NeedsPasswordReset: false,
})
if err != nil {
t.Fatal(err.Error())
}
buff2 := new(bytes.Buffer)
buff2.Write(body2)
req2, _ := http.NewRequest("PUT", "/user", buff2)
req2.Header.Set("Content-Type", "application/json")
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
res3 := httptest.NewRecorder()
body3, err := json.Marshal(CreateUserRequestBody{
Username: "admin2",
Password: "foobar",
Email: "admin2@kolide.co",
Admin: true,
NeedsPasswordReset: false,
})
if err != nil {
t.Fatal(err.Error())
}
buff3 := new(bytes.Buffer)
buff3.Write(body3)
req3, _ := http.NewRequest("PUT", "/user", buff3)
req3.Header.Set("Content-Type", "application/json")
req3.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res3, req3)
var user User
err = db.Where("username = ?", "marpaia").First(&user).Error
if err != nil {
t.Fatal(err.Error())
}
if user.Email != "mike@kolide.co" {
t.Fatalf("user's email was not set in the DB: %s", user.Email)
}
if user.Admin {
t.Fatal("user shouldn't be admin")
}
var admin2 User
err = db.Where("username = ?", "admin2").First(&admin2).Error
if err != nil {
t.Fatal(err.Error())
}
if admin2.Email != "admin2@kolide.co" {
t.Fatalf("admin2's email was not set in the DB: %s", admin2.Email)
}
if !admin2.Admin {
t.Fatal("admin2 should be admin")
}
}