mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 08:55:24 +00:00
Moving sessions code into sub-package (#42)
Since the sessions code mostly stands on it's own, I wanted to break the dependencies apart from it and move it into it's own package.
This commit is contained in:
parent
cd8057e860
commit
fe2bf7eb2b
41
auth.go
41
auth.go
@ -7,8 +7,8 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
@ -99,32 +99,23 @@ func EmptyVC() *ViewerContext {
|
||||
// and generate an appropriate ViewerContext given the data in the session.
|
||||
func VC(c *gin.Context) *ViewerContext {
|
||||
sm := NewSessionManager(c)
|
||||
return sm.VC()
|
||||
session, err := sm.Session()
|
||||
if err != nil {
|
||||
return EmptyVC()
|
||||
}
|
||||
return VCForID(GetDB(c), session.UserID)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// JSON Web Tokens
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
func VCForID(db *gorm.DB, id uint) *ViewerContext {
|
||||
// Generating a VC requires a user struct. Attempt to populate one using
|
||||
// the user id of the current session holder
|
||||
user := &User{BaseModel: BaseModel{ID: id}}
|
||||
err := db.Where(user).First(user).Error
|
||||
if err != nil {
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
// Given a session key create a JWT to be delivered to the client
|
||||
func GenerateJWT(sessionKey string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"session_key": sessionKey,
|
||||
})
|
||||
|
||||
return token.SignedString([]byte(config.App.JWTKey))
|
||||
}
|
||||
|
||||
// ParseJWT attempts to parse a JWT token in serialized string form into a
|
||||
// JWT token in a deserialized jwt.Token struct.
|
||||
func ParseJWT(token string) (*jwt.Token, error) {
|
||||
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
method, ok := t.Method.(*jwt.SigningMethodHMAC)
|
||||
if !ok || method != jwt.SigningMethodHS256 {
|
||||
return nil, errors.New("Unexpected signing method")
|
||||
}
|
||||
return []byte(config.App.JWTKey), nil
|
||||
})
|
||||
return GenerateVC(user)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
@ -194,7 +185,7 @@ func Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
sm := NewSessionManager(c)
|
||||
sm.MakeSessionForUser(user)
|
||||
sm.MakeSessionForUserID(user.ID)
|
||||
err = sm.Save()
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
|
22
auth_test.go
22
auth_test.go
@ -5,7 +5,6 @@ import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@ -24,23 +23,6 @@ func TestGenerateVC(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestGenerateJWT(t *testing.T) {
|
||||
tokenString, err := GenerateJWT("4")
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
t.Fatal("Token is invalid")
|
||||
}
|
||||
|
||||
sessionKey := claims["session_key"].(string)
|
||||
if sessionKey != "4" {
|
||||
t.Fatalf("Claims are incorrect. session key is %s", sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVC(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
r := createEmptyTestServer(db)
|
||||
@ -57,7 +39,7 @@ func TestVC(t *testing.T) {
|
||||
|
||||
r.GET("/admin_login", func(c *gin.Context) {
|
||||
sm := NewSessionManager(c)
|
||||
sm.MakeSessionForUser(admin)
|
||||
sm.MakeSessionForUserID(admin.ID)
|
||||
err := sm.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
@ -67,7 +49,7 @@ func TestVC(t *testing.T) {
|
||||
|
||||
r.GET("/user_login", func(c *gin.Context) {
|
||||
sm := NewSessionManager(c)
|
||||
sm.MakeSessionForUser(user)
|
||||
sm.MakeSessionForUserID(user.ID)
|
||||
err := sm.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
_ "github.com/jinzhu/gorm/dialects/sqlite"
|
||||
"github.com/kolide/kolide-ose/sessions"
|
||||
)
|
||||
|
||||
// Get the database connection from the context, or panic
|
||||
@ -143,7 +144,7 @@ type Decorator struct {
|
||||
|
||||
var tables = [...]interface{}{
|
||||
&User{},
|
||||
&Session{},
|
||||
&sessions.Session{},
|
||||
&ScheduledQuery{},
|
||||
&Pack{},
|
||||
&DiscoveryQuery{},
|
||||
|
19
server.go
19
server.go
@ -9,6 +9,7 @@ import (
|
||||
"github.com/gin-gonic/contrib/ginrus"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/sessions"
|
||||
)
|
||||
|
||||
// ServerError is a helper which accepts a string error and returns a map in
|
||||
@ -54,6 +55,17 @@ func DatabaseMiddleware(db *gorm.DB) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// NewSessionManager allows you to get a SessionManager instance for a given
|
||||
// web request. Unless you're interacting with login, logout, or core auth
|
||||
// code, this should be abstracted by the ViewerContext pattern.
|
||||
func NewSessionManager(c *gin.Context) *sessions.SessionManager {
|
||||
return &sessions.SessionManager{
|
||||
Request: c.Request,
|
||||
Backend: GetSessionBackend(c),
|
||||
Writer: c.Writer,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateServer creates a gin.Engine HTTP server and configures it to be in a
|
||||
// state such that it is ready to serve HTTP requests for the kolide application
|
||||
func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine {
|
||||
@ -61,6 +73,13 @@ func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine {
|
||||
server.Use(DatabaseMiddleware(db))
|
||||
server.Use(SessionBackendMiddleware)
|
||||
|
||||
sessions.Configure(&sessions.SessionConfiguration{
|
||||
CookieName: "KolideSession",
|
||||
JWTKey: config.App.JWTKey,
|
||||
SessionKeySize: config.App.SessionKeySize,
|
||||
Lifespan: config.App.SessionExpirationSeconds,
|
||||
})
|
||||
|
||||
// TODO: The following loggers are not synchronized with each other or
|
||||
// logrus.StandardLogger() used through the rest of the codebase. As
|
||||
// such, their output may become intermingled.
|
||||
|
570
sessions.go
570
sessions.go
@ -1,570 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
// An error returned by SessionBackend.Get() if no session record was found
|
||||
// in the database
|
||||
ErrNoActiveSession = errors.New("Active session is not present in the database")
|
||||
|
||||
// An error returned by SessionBackend methods when no session object has
|
||||
// been created yet but the requested action requires one
|
||||
ErrSessionNotCreated = errors.New("The session has not been created")
|
||||
|
||||
// An error returned by SessionBackend.Get() when a session is requested but
|
||||
// it has expired
|
||||
ErrSessionExpired = errors.New("The session has expired")
|
||||
)
|
||||
|
||||
const (
|
||||
// The name of the session cookie
|
||||
CookieName = "KolideSession"
|
||||
)
|
||||
|
||||
// Session is the model object which represents what an active session is
|
||||
type Session struct {
|
||||
BaseModel
|
||||
UserID uint `gorm:"not null"`
|
||||
Key string `gorm:"not null;unique_index:idx_session_unique_key"`
|
||||
AccessedAt time.Time
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Managing sessions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SessionManager is a management object which helps with the administration of
|
||||
// sessions within the application. Use NewSessionManager to create an instance
|
||||
type SessionManager struct {
|
||||
backend SessionBackend
|
||||
request *http.Request
|
||||
writer http.ResponseWriter
|
||||
session *Session
|
||||
vc *ViewerContext
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewSessionManager allows you to get a SessionManager instance for a given
|
||||
// web request. Unless you're interacting with login, logout, or core auth
|
||||
// code, this should be abstracted by the ViewerContext pattern.
|
||||
func NewSessionManager(c *gin.Context) *SessionManager {
|
||||
return &SessionManager{
|
||||
request: c.Request,
|
||||
backend: GetSessionBackend(c),
|
||||
writer: c.Writer,
|
||||
db: GetDB(c),
|
||||
}
|
||||
}
|
||||
|
||||
// Get the ViewerContext instance for a user represented by the active session
|
||||
func (sm *SessionManager) VC() *ViewerContext {
|
||||
if sm.session == nil {
|
||||
cookie, err := sm.request.Cookie(CookieName)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case http.ErrNoCookie:
|
||||
// No cookie was set
|
||||
return EmptyVC()
|
||||
default:
|
||||
// Something went wrong and the cookie may or may not be set
|
||||
logrus.Errorf("Couldn't get cookie: %s", err.Error())
|
||||
return EmptyVC()
|
||||
}
|
||||
}
|
||||
|
||||
token, err := ParseJWT(cookie.Value)
|
||||
if err != nil {
|
||||
logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error())
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
logrus.Error("Could not parse the claims from the JWT token")
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
sessionKeyClaim, ok := claims["session_key"]
|
||||
if !ok {
|
||||
logrus.Warn("JWT did not have session_key claim")
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
sessionKey, ok := sessionKeyClaim.(string)
|
||||
if !ok {
|
||||
logrus.Warn("JWT session_key claim was not a string")
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
session, err := sm.backend.FindKey(sessionKey)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrNoActiveSession:
|
||||
// If the code path got this far, it's likely that the user was logged
|
||||
// in some time in the past, but their session has been expired since
|
||||
// their last usage of the application
|
||||
return EmptyVC()
|
||||
default:
|
||||
logrus.Errorf("Couldn't call Get on backend object: %s", err.Error())
|
||||
return EmptyVC()
|
||||
}
|
||||
}
|
||||
sm.session = session
|
||||
}
|
||||
|
||||
if sm.vc == nil {
|
||||
// Generating a VC requires a user struct. Attempt to populate one using
|
||||
// the user id of the current session holder
|
||||
user := &User{BaseModel: BaseModel{ID: sm.session.UserID}}
|
||||
err := sm.db.Where(user).First(user).Error
|
||||
if err != nil {
|
||||
return EmptyVC()
|
||||
}
|
||||
|
||||
sm.vc = GenerateVC(user)
|
||||
}
|
||||
|
||||
return sm.vc
|
||||
}
|
||||
|
||||
// MakeSessionForUserID creates a session in the database for a given user id.
|
||||
// You must call Save() after calling this.
|
||||
func (sm *SessionManager) MakeSessionForUserID(id uint) error {
|
||||
session, err := sm.backend.Create(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sm.session = session
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeSessionForUserID creates a session in the database for a given user
|
||||
// You must call Save() after calling this.
|
||||
func (sm *SessionManager) MakeSessionForUser(u *User) error {
|
||||
return sm.MakeSessionForUserID(u.ID)
|
||||
}
|
||||
|
||||
// Save writes the current session to a token and delivers the token as a cookie
|
||||
// to the user. Save must be called after every write action on this struct
|
||||
// (MakeSessionForUser, Destroy, etc.)
|
||||
func (sm *SessionManager) Save() error {
|
||||
token, err := GenerateJWT(sm.session.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: set proper flags on cookie for maximum security
|
||||
http.SetCookie(sm.writer, &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: token,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Destroy deletes the active session from the database and erases the session
|
||||
// instance from this object's access. You must call Save() after calling this.
|
||||
func (sm *SessionManager) Destroy() error {
|
||||
if sm.backend != nil {
|
||||
err := sm.backend.Destroy(sm.session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session Backend API
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SessionBackend is the abstract interface that all session backends must
|
||||
// conform to. SessionBackend instances are only expected to exist within the
|
||||
// context of a single request.
|
||||
type SessionBackend interface {
|
||||
// Given a session key, find and return a session object or an error if one
|
||||
// could not be found for the given key
|
||||
FindKey(key string) (*Session, error)
|
||||
|
||||
// Given a session id, find and return a session object or an error if one
|
||||
// could not be found for the given id
|
||||
FindID(id uint) (*Session, error)
|
||||
|
||||
// Find all of the active sessions for a given user
|
||||
FindAllForUser(id uint) ([]*Session, error)
|
||||
|
||||
// Create a session object tied to the given user ID
|
||||
Create(userID uint) (*Session, error)
|
||||
|
||||
// Destroy the currently tracked session
|
||||
Destroy(session *Session) error
|
||||
|
||||
// Destroy all of the sessions for a given user
|
||||
DestroyAllForUser(id uint) error
|
||||
|
||||
// Mark the currently tracked session as access to extend expiration
|
||||
MarkAccessed(session *Session) error
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session Backend Plugins
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GormSessionBackend stores sessions using a pre-instantiated gorm database
|
||||
// object
|
||||
type GormSessionBackend struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) validate(session *Session) error {
|
||||
if time.Since(session.AccessedAt).Seconds() >= config.App.SessionExpirationSeconds {
|
||||
err := s.db.Delete(session).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrSessionExpired
|
||||
}
|
||||
|
||||
err := s.MarkAccessed(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindID(id uint) (*Session, error) {
|
||||
session := &Session{
|
||||
BaseModel: BaseModel{
|
||||
ID: id,
|
||||
},
|
||||
}
|
||||
|
||||
err := s.db.Where(session).First(session).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
return nil, ErrNoActiveSession
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.validate(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindKey(key string) (*Session, error) {
|
||||
session := &Session{
|
||||
Key: key,
|
||||
}
|
||||
|
||||
err := s.db.Where(session).First(session).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
return nil, ErrNoActiveSession
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.validate(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
|
||||
var sessions []*Session
|
||||
err := s.db.Where("user_id = ?", id).Find(&sessions).Error
|
||||
return sessions, err
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) Create(userID uint) (*Session, error) {
|
||||
key, err := generateRandomText(config.App.SessionKeySize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
}
|
||||
|
||||
err = s.db.Create(session).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.MarkAccessed(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) Destroy(session *Session) error {
|
||||
err := s.db.Delete(session).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) DestroyAllForUser(id uint) error {
|
||||
return s.db.Delete(&Session{}, "user_id = ?", id).Error
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) MarkAccessed(session *Session) error {
|
||||
session.AccessedAt = time.Now().UTC()
|
||||
return s.db.Save(session).Error
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session management HTTP endpoints
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Setting the session backend via a middleware
|
||||
func SessionBackendMiddleware(c *gin.Context) {
|
||||
db := GetDB(c)
|
||||
c.Set("SessionBackend", &GormSessionBackend{db})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// Get the database connection from the context, or panic
|
||||
func GetSessionBackend(c *gin.Context) SessionBackend {
|
||||
return c.MustGet("SessionBackend").(SessionBackend)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session management HTTP endpoints
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type DeleteSessionRequestBody struct {
|
||||
SessionID uint `json:"session_id" binding:"required"`
|
||||
}
|
||||
|
||||
func DeleteSession(c *gin.Context) {
|
||||
var body DeleteSessionRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
|
||||
session, err := sb.FindID(body.SessionID)
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
user := &User{
|
||||
BaseModel: BaseModel{
|
||||
ID: session.UserID,
|
||||
},
|
||||
}
|
||||
err = db.Where(user).First(user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = sb.Destroy(session)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type DeleteSessionsForUserRequestBody struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func DeleteSessionsForUser(c *gin.Context) {
|
||||
var body DeleteSessionsForUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = body.ID
|
||||
user.Username = body.Username
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(&user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
err = sb.DestroyAllForUser(user.ID)
|
||||
err = db.Delete(&Session{}, "user_id = ?", user.ID).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, nil)
|
||||
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionRequestBody struct {
|
||||
SessionKey string `json:"session_key" binding:"required"`
|
||||
}
|
||||
|
||||
type SessionInfoResponseBody struct {
|
||||
SessionID uint `json:"session_id"`
|
||||
UserID uint `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
AccessedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func GetInfoAboutSession(c *gin.Context) {
|
||||
var body GetInfoAboutSessionRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
session, err := sb.FindKey(body.SessionKey)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = session.UserID
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, &SessionInfoResponseBody{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
AccessedAt: session.AccessedAt,
|
||||
})
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionsForUserRequestBody struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionsForUserResponseBody struct {
|
||||
Sessions []*SessionInfoResponseBody `json:"sessions"`
|
||||
}
|
||||
|
||||
func GetInfoAboutSessionsForUser(c *gin.Context) {
|
||||
var body GetInfoAboutSessionsForUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = body.ID
|
||||
user.Username = body.Username
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
sessions, err := sb.FindAllForUser(user.ID)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var response []*SessionInfoResponseBody
|
||||
for _, session := range sessions {
|
||||
response = append(response, &SessionInfoResponseBody{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
AccessedAt: session.AccessedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{
|
||||
Sessions: response,
|
||||
})
|
||||
}
|
164
sessions/backends.go
Normal file
164
sessions/backends.go
Normal file
@ -0,0 +1,164 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session Backend API
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SessionBackend is the abstract interface that all session backends must
|
||||
// conform to. SessionBackend instances are only expected to exist within the
|
||||
// context of a single request.
|
||||
type SessionBackend interface {
|
||||
// Given a session key, find and return a session object or an error if one
|
||||
// could not be found for the given key
|
||||
FindKey(key string) (*Session, error)
|
||||
|
||||
// Given a session id, find and return a session object or an error if one
|
||||
// could not be found for the given id
|
||||
FindID(id uint) (*Session, error)
|
||||
|
||||
// Find all of the active sessions for a given user
|
||||
FindAllForUser(id uint) ([]*Session, error)
|
||||
|
||||
// Create a session object tied to the given user ID
|
||||
Create(userID uint) (*Session, error)
|
||||
|
||||
// Destroy the currently tracked session
|
||||
Destroy(session *Session) error
|
||||
|
||||
// Destroy all of the sessions for a given user
|
||||
DestroyAllForUser(id uint) error
|
||||
|
||||
// Mark the currently tracked session as access to extend expiration
|
||||
MarkAccessed(session *Session) error
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session Backend Plugins
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// GormSessionBackend stores sessions using a pre-instantiated gorm database
|
||||
// object
|
||||
type GormSessionBackend struct {
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) validate(session *Session) error {
|
||||
if time.Since(session.AccessedAt).Seconds() >= Lifespan {
|
||||
err := s.DB.Delete(session).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrSessionExpired
|
||||
}
|
||||
|
||||
err := s.MarkAccessed(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindID(id uint) (*Session, error) {
|
||||
session := &Session{
|
||||
ID: id,
|
||||
}
|
||||
|
||||
err := s.DB.Where(session).First(session).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
return nil, ErrNoActiveSession
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.validate(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindKey(key string) (*Session, error) {
|
||||
session := &Session{
|
||||
Key: key,
|
||||
}
|
||||
|
||||
err := s.DB.Where(session).First(session).Error
|
||||
if err != nil {
|
||||
switch err {
|
||||
case gorm.ErrRecordNotFound:
|
||||
return nil, ErrNoActiveSession
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.validate(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
|
||||
var sessions []*Session
|
||||
err := s.DB.Where("user_id = ?", id).Find(&sessions).Error
|
||||
return sessions, err
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) Create(userID uint) (*Session, error) {
|
||||
key := make([]byte, SessionKeySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
UserID: userID,
|
||||
Key: base64.StdEncoding.EncodeToString(key),
|
||||
}
|
||||
|
||||
err = s.DB.Create(session).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.MarkAccessed(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) Destroy(session *Session) error {
|
||||
err := s.DB.Delete(session).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) DestroyAllForUser(id uint) error {
|
||||
return s.DB.Delete(&Session{}, "user_id = ?", id).Error
|
||||
}
|
||||
|
||||
func (s *GormSessionBackend) MarkAccessed(session *Session) error {
|
||||
session.AccessedAt = time.Now().UTC()
|
||||
return s.DB.Save(session).Error
|
||||
}
|
129
sessions/backends_test.go
Normal file
129
sessions/backends_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type mockSessionBackend struct {
|
||||
sessions []*Session
|
||||
id uint
|
||||
}
|
||||
|
||||
func newMockSessionBackend() *mockSessionBackend {
|
||||
return &mockSessionBackend{
|
||||
sessions: []*Session{},
|
||||
id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) FindID(id uint) (*Session, error) {
|
||||
for _, each := range s.sessions {
|
||||
if each.ID == id {
|
||||
return each, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrNoActiveSession
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) FindKey(key string) (*Session, error) {
|
||||
for _, each := range s.sessions {
|
||||
if each.Key == key {
|
||||
return each, nil
|
||||
}
|
||||
}
|
||||
return nil, ErrNoActiveSession
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
|
||||
var sessions []*Session
|
||||
for _, each := range sessions {
|
||||
if each.UserID == id {
|
||||
sessions = append(sessions, each)
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) nextID() uint {
|
||||
s.id = s.id + 1
|
||||
return s.id
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) Create(userID uint) (*Session, error) {
|
||||
key := make([]byte, SessionKeySize)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session := &Session{
|
||||
ID: s.nextID(),
|
||||
UserID: userID,
|
||||
Key: base64.StdEncoding.EncodeToString(key),
|
||||
}
|
||||
|
||||
err = s.MarkAccessed(session)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.sessions = append(s.sessions, session)
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) Destroy(session *Session) error {
|
||||
var sessions []*Session
|
||||
for _, each := range s.sessions {
|
||||
if each.ID != session.ID {
|
||||
sessions = append(sessions, each)
|
||||
}
|
||||
}
|
||||
s.sessions = sessions
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) DestroyAllForUser(id uint) error {
|
||||
var sessions []*Session
|
||||
for _, each := range s.sessions {
|
||||
if each.UserID != id {
|
||||
sessions = append(sessions, each)
|
||||
}
|
||||
}
|
||||
s.sessions = sessions
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *mockSessionBackend) MarkAccessed(session *Session) error {
|
||||
session.AccessedAt = time.Now().UTC()
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockResponseWriter struct {
|
||||
headers map[string][]string
|
||||
}
|
||||
|
||||
func newMocResponseWriter() *mockResponseWriter {
|
||||
return &mockResponseWriter{
|
||||
headers: map[string][]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *mockResponseWriter) Header() http.Header {
|
||||
return w.headers
|
||||
}
|
||||
|
||||
func (w *mockResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (w *mockResponseWriter) WriteHeader(int) {
|
||||
}
|
||||
|
||||
func TestFindID(t *testing.T) {
|
||||
|
||||
}
|
212
sessions/sessions.go
Normal file
212
sessions/sessions.go
Normal file
@ -0,0 +1,212 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
var (
|
||||
// An error returned by SessionBackend.Get() if no session record was found
|
||||
// in the database
|
||||
ErrNoActiveSession = errors.New("Active session is not present in the database")
|
||||
|
||||
// An error returned by SessionBackend methods when no session object has
|
||||
// been created yet but the requested action requires one
|
||||
ErrSessionNotCreated = errors.New("The session has not been created")
|
||||
|
||||
// An error returned by SessionBackend.Get() when a session is requested but
|
||||
// it has expired
|
||||
ErrSessionExpired = errors.New("The session has expired")
|
||||
|
||||
// An error returned by SessionBackend which indicates that the token
|
||||
// or it's content were malformed
|
||||
ErrSessionMalformed = errors.New("The session token was malformed")
|
||||
)
|
||||
|
||||
var (
|
||||
// The name of the session cookie
|
||||
CookieName = "Session"
|
||||
|
||||
// The key to be used to sign and verify JWTs
|
||||
jwtKey = ""
|
||||
|
||||
// The amount of random data, in bytes, which will be used to create each
|
||||
// session key
|
||||
SessionKeySize = 64
|
||||
|
||||
// The amount of seconds that will pass before an inactive user is logged out
|
||||
Lifespan = float64(60 * 60 * 24 * 90)
|
||||
)
|
||||
|
||||
// Session is the model object which represents what an active session is
|
||||
type Session struct {
|
||||
ID uint `gorm:"primary_key"`
|
||||
CreatedAt time.Time
|
||||
AccessedAt time.Time
|
||||
UserID uint `gorm:"not null"`
|
||||
Key string `gorm:"not null;unique_index:idx_session_unique_key"`
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Configuring the library
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type SessionConfiguration struct {
|
||||
CookieName string
|
||||
JWTKey string
|
||||
SessionKeySize int
|
||||
Lifespan float64
|
||||
}
|
||||
|
||||
func Configure(s *SessionConfiguration) {
|
||||
CookieName = s.CookieName
|
||||
jwtKey = s.JWTKey
|
||||
SessionKeySize = s.SessionKeySize
|
||||
Lifespan = s.Lifespan
|
||||
}
|
||||
|
||||
// Set the name of the cookie
|
||||
func SetCookieName(name string) {
|
||||
CookieName = name
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Managing sessions
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// SessionManager is a management object which helps with the administration of
|
||||
// sessions within the application. Use NewSessionManager to create an instance
|
||||
type SessionManager struct {
|
||||
Backend SessionBackend
|
||||
Request *http.Request
|
||||
Writer http.ResponseWriter
|
||||
session *Session
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Session() (*Session, error) {
|
||||
if sm.session == nil {
|
||||
cookie, err := sm.Request.Cookie(CookieName)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case http.ErrNoCookie:
|
||||
// No cookie was set
|
||||
return nil, err
|
||||
default:
|
||||
// Something went wrong and the cookie may or may not be set
|
||||
logrus.Errorf("Couldn't get cookie: %s", err.Error())
|
||||
return nil, ErrSessionMalformed
|
||||
}
|
||||
}
|
||||
|
||||
token, err := ParseJWT(cookie.Value)
|
||||
if err != nil {
|
||||
logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error())
|
||||
return nil, ErrSessionMalformed
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
logrus.Error("Could not parse the claims from the JWT token")
|
||||
return nil, ErrSessionMalformed
|
||||
}
|
||||
|
||||
sessionKeyClaim, ok := claims["session_key"]
|
||||
if !ok {
|
||||
logrus.Warn("JWT did not have session_key claim")
|
||||
return nil, ErrSessionMalformed
|
||||
}
|
||||
|
||||
sessionKey, ok := sessionKeyClaim.(string)
|
||||
if !ok {
|
||||
logrus.Warn("JWT session_key claim was not a string")
|
||||
return nil, ErrSessionMalformed
|
||||
}
|
||||
|
||||
session, err := sm.Backend.FindKey(sessionKey)
|
||||
if err != nil {
|
||||
switch err {
|
||||
case ErrNoActiveSession:
|
||||
// If the code path got this far, it's likely that the user was logged
|
||||
// in some time in the past, but their session has been expired since
|
||||
// their last usage of the application
|
||||
return nil, err
|
||||
default:
|
||||
logrus.Errorf("Couldn't call Get on backend object: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
sm.session = session
|
||||
}
|
||||
return sm.session, nil
|
||||
|
||||
}
|
||||
|
||||
// MakeSessionForUserID creates a session in the database for a given user id.
|
||||
// You must call Save() after calling this.
|
||||
func (sm *SessionManager) MakeSessionForUserID(id uint) error {
|
||||
session, err := sm.Backend.Create(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sm.session = session
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save writes the current session to a token and delivers the token as a cookie
|
||||
// to the user. Save must be called after every write action on this struct
|
||||
// (MakeSessionForUser, Destroy, etc.)
|
||||
func (sm *SessionManager) Save() error {
|
||||
token, err := GenerateJWT(sm.session.Key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: set proper flags on cookie for maximum security
|
||||
http.SetCookie(sm.Writer, &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: token,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Destroy deletes the active session from the database and erases the session
|
||||
// instance from this object's access. You must call Save() after calling this.
|
||||
func (sm *SessionManager) Destroy() error {
|
||||
if sm.Backend != nil {
|
||||
err := sm.Backend.Destroy(sm.session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// JSON Web Tokens
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Given a session key create a JWT to be delivered to the client
|
||||
func GenerateJWT(sessionKey string) (string, error) {
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
|
||||
"session_key": sessionKey,
|
||||
})
|
||||
|
||||
return token.SignedString([]byte(jwtKey))
|
||||
}
|
||||
|
||||
// ParseJWT attempts to parse a JWT token in serialized string form into a
|
||||
// JWT token in a deserialized jwt.Token struct.
|
||||
func ParseJWT(token string) (*jwt.Token, error) {
|
||||
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
|
||||
method, ok := t.Method.(*jwt.SigningMethodHMAC)
|
||||
if !ok || method != jwt.SigningMethodHS256 {
|
||||
return nil, errors.New("Unexpected signing method")
|
||||
}
|
||||
return []byte(jwtKey), nil
|
||||
})
|
||||
}
|
66
sessions/sessions_test.go
Normal file
66
sessions/sessions_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
func TestGenerateJWT(t *testing.T) {
|
||||
jwtKey = "very secure"
|
||||
tokenString, err := GenerateJWT("4")
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
t.Fatal("Token is invalid")
|
||||
}
|
||||
|
||||
sessionKey := claims["session_key"].(string)
|
||||
if sessionKey != "4" {
|
||||
t.Fatalf("Claims are incorrect. session key is %s", sessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionManager(t *testing.T) {
|
||||
r, _ := http.NewRequest("GET", "/", nil)
|
||||
w := newMocResponseWriter()
|
||||
sb := newMockSessionBackend()
|
||||
|
||||
sm := &SessionManager{
|
||||
Backend: sb,
|
||||
Request: r,
|
||||
Writer: w,
|
||||
}
|
||||
|
||||
err := sm.MakeSessionForUserID(1)
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
err = sm.Save()
|
||||
if err != nil {
|
||||
t.Fatalf(err.Error())
|
||||
}
|
||||
|
||||
header := w.Header().Get("Set-Cookie")
|
||||
tokenString := strings.Split(header, "=")[1]
|
||||
token, err := ParseJWT(tokenString)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
session_key := token.Claims.(jwt.MapClaims)["session_key"].(string)
|
||||
session, err := sb.FindKey(session_key)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if session.UserID != 1 {
|
||||
t.Fatal("User ID doesn't match. Got: %s", session.UserID)
|
||||
}
|
||||
|
||||
}
|
117
sessions_test.go
117
sessions_test.go
@ -1,117 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type MockResponseWriter struct {
|
||||
}
|
||||
|
||||
func (w *MockResponseWriter) Header() http.Header {
|
||||
return map[string][]string{}
|
||||
}
|
||||
|
||||
func (w *MockResponseWriter) Write([]byte) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (w *MockResponseWriter) WriteHeader(int) {
|
||||
}
|
||||
|
||||
func TestSessionManagerVC(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
backend := &GormSessionBackend{db}
|
||||
session, err := backend.Create(admin.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
if session.UserID != admin.ID {
|
||||
t.Fatal("IDs do not match")
|
||||
}
|
||||
|
||||
token, err := GenerateJWT(session.Key)
|
||||
|
||||
cookie := &http.Cookie{
|
||||
Name: CookieName,
|
||||
Value: token,
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
req.AddCookie(cookie)
|
||||
|
||||
writer := &MockResponseWriter{}
|
||||
|
||||
sm := &SessionManager{
|
||||
request: req,
|
||||
writer: writer,
|
||||
backend: backend,
|
||||
db: db,
|
||||
}
|
||||
vc := sm.VC()
|
||||
|
||||
if !vc.IsAdmin() {
|
||||
t.Fatal("User should be admin")
|
||||
}
|
||||
|
||||
vcID, _ := vc.UserID()
|
||||
if vcID != admin.ID {
|
||||
t.Fatal("IDs don't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCreation(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
r := createEmptyTestServer(db)
|
||||
admin, _ := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
|
||||
|
||||
r.GET("/login", func(c *gin.Context) {
|
||||
sm := NewSessionManager(c)
|
||||
sm.MakeSessionForUser(admin)
|
||||
err := sm.Save()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/resource", func(c *gin.Context) {
|
||||
sm := NewSessionManager(c)
|
||||
vc := sm.VC()
|
||||
if !vc.IsAdmin() {
|
||||
t.Fatal("Request is not admin")
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
r.GET("/nope", func(c *gin.Context) {
|
||||
sm := NewSessionManager(c)
|
||||
vc := sm.VC()
|
||||
if !vc.IsAdmin() {
|
||||
t.Fatal("Request is not admin")
|
||||
}
|
||||
c.JSON(200, nil)
|
||||
})
|
||||
|
||||
res1 := httptest.NewRecorder()
|
||||
req1, _ := http.NewRequest("GET", "/login", nil)
|
||||
r.ServeHTTP(res1, req1)
|
||||
|
||||
res2 := httptest.NewRecorder()
|
||||
req2, _ := http.NewRequest("GET", "/resource", nil)
|
||||
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
|
||||
r.ServeHTTP(res2, req2)
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/sessions"
|
||||
)
|
||||
|
||||
func TestUserAndAccountManagement(t *testing.T) {
|
||||
@ -59,7 +60,7 @@ func TestUserAndAccountManagement(t *testing.T) {
|
||||
}
|
||||
|
||||
// Pull the token out of the JWT token and get the session info via that
|
||||
token, err := ParseJWT(strings.Split(adminSession, "=")[1])
|
||||
token, err := sessions.ParseJWT(strings.Split(adminSession, "=")[1])
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
@ -75,7 +76,7 @@ func TestUserAndAccountManagement(t *testing.T) {
|
||||
req.DeleteSession(adminSessionInfo.Sessions[0].SessionID, adminSession)
|
||||
|
||||
// Verify the session was deleted
|
||||
sessionVerify := &Session{
|
||||
sessionVerify := &sessions.Session{
|
||||
Key: sessionKey,
|
||||
}
|
||||
err = req.db.Where(sessionVerify).First(sessionVerify).Error
|
||||
|
231
users.go
231
users.go
@ -2,10 +2,12 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Sirupsen/logrus"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/jinzhu/gorm"
|
||||
"github.com/kolide/kolide-ose/sessions"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
@ -452,3 +454,232 @@ func SetUserEnabledState(c *gin.Context) {
|
||||
NeedsPasswordReset: user.NeedsPasswordReset,
|
||||
})
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
// Session management HTTP endpoints
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Setting the session backend via a middleware
|
||||
func SessionBackendMiddleware(c *gin.Context) {
|
||||
db := GetDB(c)
|
||||
c.Set("SessionBackend", &sessions.GormSessionBackend{db})
|
||||
c.Next()
|
||||
}
|
||||
|
||||
// Get the database connection from the context, or panic
|
||||
func GetSessionBackend(c *gin.Context) sessions.SessionBackend {
|
||||
return c.MustGet("SessionBackend").(sessions.SessionBackend)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Session management HTTP endpoints
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
type DeleteSessionRequestBody struct {
|
||||
SessionID uint `json:"session_id" binding:"required"`
|
||||
}
|
||||
|
||||
func DeleteSession(c *gin.Context) {
|
||||
var body DeleteSessionRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
|
||||
session, err := sb.FindID(body.SessionID)
|
||||
if err != nil {
|
||||
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
user := &User{
|
||||
BaseModel: BaseModel{
|
||||
ID: session.UserID,
|
||||
},
|
||||
}
|
||||
err = db.Where(user).First(user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
err = sb.Destroy(session)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, nil)
|
||||
}
|
||||
|
||||
type DeleteSessionsForUserRequestBody struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func DeleteSessionsForUser(c *gin.Context) {
|
||||
var body DeleteSessionsForUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = body.ID
|
||||
user.Username = body.Username
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.CanPerformWriteActionOnUser(&user) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
err = sb.DestroyAllForUser(user.ID)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, nil)
|
||||
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionRequestBody struct {
|
||||
SessionKey string `json:"session_key" binding:"required"`
|
||||
}
|
||||
|
||||
type SessionInfoResponseBody struct {
|
||||
SessionID uint `json:"session_id"`
|
||||
UserID uint `json:"user_id"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
AccessedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func GetInfoAboutSession(c *gin.Context) {
|
||||
var body GetInfoAboutSessionRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
session, err := sb.FindKey(body.SessionKey)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = session.UserID
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(200, &SessionInfoResponseBody{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
AccessedAt: session.AccessedAt,
|
||||
})
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionsForUserRequestBody struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
type GetInfoAboutSessionsForUserResponseBody struct {
|
||||
Sessions []*SessionInfoResponseBody `json:"sessions"`
|
||||
}
|
||||
|
||||
func GetInfoAboutSessionsForUser(c *gin.Context) {
|
||||
var body GetInfoAboutSessionsForUserRequestBody
|
||||
err := c.BindJSON(&body)
|
||||
if err != nil {
|
||||
logrus.Errorf(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
vc := VC(c)
|
||||
if !vc.CanPerformActions() {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
db := GetDB(c)
|
||||
var user User
|
||||
user.ID = body.ID
|
||||
user.Username = body.Username
|
||||
err = db.Where(&user).First(&user).Error
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
|
||||
UnauthorizedError(c)
|
||||
return
|
||||
}
|
||||
|
||||
sb := GetSessionBackend(c)
|
||||
sessions, err := sb.FindAllForUser(user.ID)
|
||||
if err != nil {
|
||||
DatabaseError(c)
|
||||
return
|
||||
}
|
||||
|
||||
var response []*SessionInfoResponseBody
|
||||
for _, session := range sessions {
|
||||
response = append(response, &SessionInfoResponseBody{
|
||||
SessionID: session.ID,
|
||||
UserID: session.UserID,
|
||||
CreatedAt: session.CreatedAt,
|
||||
AccessedAt: session.AccessedAt,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{
|
||||
Sessions: response,
|
||||
})
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user