Moving sessions code into sub-package (#42)

Since the sessions code mostly stands on it's own, I wanted to break the
dependencies apart from it and move it into it's own package.
This commit is contained in:
Mike Arpaia 2016-08-05 10:47:41 -07:00 committed by GitHub
parent cd8057e860
commit fe2bf7eb2b
12 changed files with 844 additions and 735 deletions

41
auth.go
View File

@ -7,8 +7,8 @@ import (
"fmt"
"github.com/Sirupsen/logrus"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"golang.org/x/crypto/bcrypt"
)
@ -99,32 +99,23 @@ func EmptyVC() *ViewerContext {
// and generate an appropriate ViewerContext given the data in the session.
func VC(c *gin.Context) *ViewerContext {
sm := NewSessionManager(c)
return sm.VC()
session, err := sm.Session()
if err != nil {
return EmptyVC()
}
return VCForID(GetDB(c), session.UserID)
}
////////////////////////////////////////////////////////////////////////////////
// JSON Web Tokens
////////////////////////////////////////////////////////////////////////////////
func VCForID(db *gorm.DB, id uint) *ViewerContext {
// Generating a VC requires a user struct. Attempt to populate one using
// the user id of the current session holder
user := &User{BaseModel: BaseModel{ID: id}}
err := db.Where(user).First(user).Error
if err != nil {
return EmptyVC()
}
// Given a session key create a JWT to be delivered to the client
func GenerateJWT(sessionKey string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"session_key": sessionKey,
})
return token.SignedString([]byte(config.App.JWTKey))
}
// ParseJWT attempts to parse a JWT token in serialized string form into a
// JWT token in a deserialized jwt.Token struct.
func ParseJWT(token string) (*jwt.Token, error) {
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
method, ok := t.Method.(*jwt.SigningMethodHMAC)
if !ok || method != jwt.SigningMethodHS256 {
return nil, errors.New("Unexpected signing method")
}
return []byte(config.App.JWTKey), nil
})
return GenerateVC(user)
}
////////////////////////////////////////////////////////////////////////////////
@ -194,7 +185,7 @@ func Login(c *gin.Context) {
}
sm := NewSessionManager(c)
sm.MakeSessionForUser(user)
sm.MakeSessionForUserID(user.ID)
err = sm.Save()
if err != nil {
DatabaseError(c)

View File

@ -5,7 +5,6 @@ import (
"net/http/httptest"
"testing"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
)
@ -24,23 +23,6 @@ func TestGenerateVC(t *testing.T) {
}
func TestGenerateJWT(t *testing.T) {
tokenString, err := GenerateJWT("4")
token, err := ParseJWT(tokenString)
if err != nil {
t.Fatal(err.Error())
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
t.Fatal("Token is invalid")
}
sessionKey := claims["session_key"].(string)
if sessionKey != "4" {
t.Fatalf("Claims are incorrect. session key is %s", sessionKey)
}
}
func TestVC(t *testing.T) {
db := openTestDB(t)
r := createEmptyTestServer(db)
@ -57,7 +39,7 @@ func TestVC(t *testing.T) {
r.GET("/admin_login", func(c *gin.Context) {
sm := NewSessionManager(c)
sm.MakeSessionForUser(admin)
sm.MakeSessionForUserID(admin.ID)
err := sm.Save()
if err != nil {
t.Fatal(err.Error())
@ -67,7 +49,7 @@ func TestVC(t *testing.T) {
r.GET("/user_login", func(c *gin.Context) {
sm := NewSessionManager(c)
sm.MakeSessionForUser(user)
sm.MakeSessionForUserID(user.ID)
err := sm.Save()
if err != nil {
t.Fatal(err.Error())

View File

@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
_ "github.com/jinzhu/gorm/dialects/mysql"
_ "github.com/jinzhu/gorm/dialects/sqlite"
"github.com/kolide/kolide-ose/sessions"
)
// Get the database connection from the context, or panic
@ -143,7 +144,7 @@ type Decorator struct {
var tables = [...]interface{}{
&User{},
&Session{},
&sessions.Session{},
&ScheduledQuery{},
&Pack{},
&DiscoveryQuery{},

View File

@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/contrib/ginrus"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/sessions"
)
// ServerError is a helper which accepts a string error and returns a map in
@ -54,6 +55,17 @@ func DatabaseMiddleware(db *gorm.DB) gin.HandlerFunc {
}
}
// NewSessionManager allows you to get a SessionManager instance for a given
// web request. Unless you're interacting with login, logout, or core auth
// code, this should be abstracted by the ViewerContext pattern.
func NewSessionManager(c *gin.Context) *sessions.SessionManager {
return &sessions.SessionManager{
Request: c.Request,
Backend: GetSessionBackend(c),
Writer: c.Writer,
}
}
// CreateServer creates a gin.Engine HTTP server and configures it to be in a
// state such that it is ready to serve HTTP requests for the kolide application
func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine {
@ -61,6 +73,13 @@ func CreateServer(db *gorm.DB, w io.Writer) *gin.Engine {
server.Use(DatabaseMiddleware(db))
server.Use(SessionBackendMiddleware)
sessions.Configure(&sessions.SessionConfiguration{
CookieName: "KolideSession",
JWTKey: config.App.JWTKey,
SessionKeySize: config.App.SessionKeySize,
Lifespan: config.App.SessionExpirationSeconds,
})
// TODO: The following loggers are not synchronized with each other or
// logrus.StandardLogger() used through the rest of the codebase. As
// such, their output may become intermingled.

View File

@ -1,570 +0,0 @@
package main
import (
"errors"
"net/http"
"time"
"github.com/Sirupsen/logrus"
"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
)
var (
// An error returned by SessionBackend.Get() if no session record was found
// in the database
ErrNoActiveSession = errors.New("Active session is not present in the database")
// An error returned by SessionBackend methods when no session object has
// been created yet but the requested action requires one
ErrSessionNotCreated = errors.New("The session has not been created")
// An error returned by SessionBackend.Get() when a session is requested but
// it has expired
ErrSessionExpired = errors.New("The session has expired")
)
const (
// The name of the session cookie
CookieName = "KolideSession"
)
// Session is the model object which represents what an active session is
type Session struct {
BaseModel
UserID uint `gorm:"not null"`
Key string `gorm:"not null;unique_index:idx_session_unique_key"`
AccessedAt time.Time
}
////////////////////////////////////////////////////////////////////////////////
// Managing sessions
////////////////////////////////////////////////////////////////////////////////
// SessionManager is a management object which helps with the administration of
// sessions within the application. Use NewSessionManager to create an instance
type SessionManager struct {
backend SessionBackend
request *http.Request
writer http.ResponseWriter
session *Session
vc *ViewerContext
db *gorm.DB
}
// NewSessionManager allows you to get a SessionManager instance for a given
// web request. Unless you're interacting with login, logout, or core auth
// code, this should be abstracted by the ViewerContext pattern.
func NewSessionManager(c *gin.Context) *SessionManager {
return &SessionManager{
request: c.Request,
backend: GetSessionBackend(c),
writer: c.Writer,
db: GetDB(c),
}
}
// Get the ViewerContext instance for a user represented by the active session
func (sm *SessionManager) VC() *ViewerContext {
if sm.session == nil {
cookie, err := sm.request.Cookie(CookieName)
if err != nil {
switch err {
case http.ErrNoCookie:
// No cookie was set
return EmptyVC()
default:
// Something went wrong and the cookie may or may not be set
logrus.Errorf("Couldn't get cookie: %s", err.Error())
return EmptyVC()
}
}
token, err := ParseJWT(cookie.Value)
if err != nil {
logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error())
return EmptyVC()
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logrus.Error("Could not parse the claims from the JWT token")
return EmptyVC()
}
sessionKeyClaim, ok := claims["session_key"]
if !ok {
logrus.Warn("JWT did not have session_key claim")
return EmptyVC()
}
sessionKey, ok := sessionKeyClaim.(string)
if !ok {
logrus.Warn("JWT session_key claim was not a string")
return EmptyVC()
}
session, err := sm.backend.FindKey(sessionKey)
if err != nil {
switch err {
case ErrNoActiveSession:
// If the code path got this far, it's likely that the user was logged
// in some time in the past, but their session has been expired since
// their last usage of the application
return EmptyVC()
default:
logrus.Errorf("Couldn't call Get on backend object: %s", err.Error())
return EmptyVC()
}
}
sm.session = session
}
if sm.vc == nil {
// Generating a VC requires a user struct. Attempt to populate one using
// the user id of the current session holder
user := &User{BaseModel: BaseModel{ID: sm.session.UserID}}
err := sm.db.Where(user).First(user).Error
if err != nil {
return EmptyVC()
}
sm.vc = GenerateVC(user)
}
return sm.vc
}
// MakeSessionForUserID creates a session in the database for a given user id.
// You must call Save() after calling this.
func (sm *SessionManager) MakeSessionForUserID(id uint) error {
session, err := sm.backend.Create(id)
if err != nil {
return err
}
sm.session = session
return nil
}
// MakeSessionForUserID creates a session in the database for a given user
// You must call Save() after calling this.
func (sm *SessionManager) MakeSessionForUser(u *User) error {
return sm.MakeSessionForUserID(u.ID)
}
// Save writes the current session to a token and delivers the token as a cookie
// to the user. Save must be called after every write action on this struct
// (MakeSessionForUser, Destroy, etc.)
func (sm *SessionManager) Save() error {
token, err := GenerateJWT(sm.session.Key)
if err != nil {
return err
}
// TODO: set proper flags on cookie for maximum security
http.SetCookie(sm.writer, &http.Cookie{
Name: CookieName,
Value: token,
})
return nil
}
// Destroy deletes the active session from the database and erases the session
// instance from this object's access. You must call Save() after calling this.
func (sm *SessionManager) Destroy() error {
if sm.backend != nil {
err := sm.backend.Destroy(sm.session)
if err != nil {
return err
}
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// Session Backend API
////////////////////////////////////////////////////////////////////////////////
// SessionBackend is the abstract interface that all session backends must
// conform to. SessionBackend instances are only expected to exist within the
// context of a single request.
type SessionBackend interface {
// Given a session key, find and return a session object or an error if one
// could not be found for the given key
FindKey(key string) (*Session, error)
// Given a session id, find and return a session object or an error if one
// could not be found for the given id
FindID(id uint) (*Session, error)
// Find all of the active sessions for a given user
FindAllForUser(id uint) ([]*Session, error)
// Create a session object tied to the given user ID
Create(userID uint) (*Session, error)
// Destroy the currently tracked session
Destroy(session *Session) error
// Destroy all of the sessions for a given user
DestroyAllForUser(id uint) error
// Mark the currently tracked session as access to extend expiration
MarkAccessed(session *Session) error
}
////////////////////////////////////////////////////////////////////////////////
// Session Backend Plugins
////////////////////////////////////////////////////////////////////////////////
// GormSessionBackend stores sessions using a pre-instantiated gorm database
// object
type GormSessionBackend struct {
db *gorm.DB
}
func (s *GormSessionBackend) validate(session *Session) error {
if time.Since(session.AccessedAt).Seconds() >= config.App.SessionExpirationSeconds {
err := s.db.Delete(session).Error
if err != nil {
return err
}
return ErrSessionExpired
}
err := s.MarkAccessed(session)
if err != nil {
return err
}
return nil
}
func (s *GormSessionBackend) FindID(id uint) (*Session, error) {
session := &Session{
BaseModel: BaseModel{
ID: id,
},
}
err := s.db.Where(session).First(session).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
return nil, ErrNoActiveSession
default:
return nil, err
}
}
err = s.validate(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) FindKey(key string) (*Session, error) {
session := &Session{
Key: key,
}
err := s.db.Where(session).First(session).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
return nil, ErrNoActiveSession
default:
return nil, err
}
}
err = s.validate(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
var sessions []*Session
err := s.db.Where("user_id = ?", id).Find(&sessions).Error
return sessions, err
}
func (s *GormSessionBackend) Create(userID uint) (*Session, error) {
key, err := generateRandomText(config.App.SessionKeySize)
if err != nil {
return nil, err
}
session := &Session{
UserID: userID,
Key: key,
}
err = s.db.Create(session).Error
if err != nil {
return nil, err
}
err = s.MarkAccessed(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) Destroy(session *Session) error {
err := s.db.Delete(session).Error
if err != nil {
return err
}
return nil
}
func (s *GormSessionBackend) DestroyAllForUser(id uint) error {
return s.db.Delete(&Session{}, "user_id = ?", id).Error
}
func (s *GormSessionBackend) MarkAccessed(session *Session) error {
session.AccessedAt = time.Now().UTC()
return s.db.Save(session).Error
}
////////////////////////////////////////////////////////////////////////////////
// Session management HTTP endpoints
////////////////////////////////////////////////////////////////////////////////
// Setting the session backend via a middleware
func SessionBackendMiddleware(c *gin.Context) {
db := GetDB(c)
c.Set("SessionBackend", &GormSessionBackend{db})
c.Next()
}
// Get the database connection from the context, or panic
func GetSessionBackend(c *gin.Context) SessionBackend {
return c.MustGet("SessionBackend").(SessionBackend)
}
////////////////////////////////////////////////////////////////////////////////
// Session management HTTP endpoints
////////////////////////////////////////////////////////////////////////////////
type DeleteSessionRequestBody struct {
SessionID uint `json:"session_id" binding:"required"`
}
func DeleteSession(c *gin.Context) {
var body DeleteSessionRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
session, err := sb.FindID(body.SessionID)
if err != nil {
}
db := GetDB(c)
user := &User{
BaseModel: BaseModel{
ID: session.UserID,
},
}
err = db.Where(user).First(user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(user) {
UnauthorizedError(c)
return
}
err = sb.Destroy(session)
if err != nil {
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type DeleteSessionsForUserRequestBody struct {
ID uint `json:"id"`
Username string `json:"username"`
}
func DeleteSessionsForUser(c *gin.Context) {
var body DeleteSessionsForUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
db := GetDB(c)
var user User
user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(&user) {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
err = sb.DestroyAllForUser(user.ID)
err = db.Delete(&Session{}, "user_id = ?", user.ID).Error
if err != nil {
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type GetInfoAboutSessionRequestBody struct {
SessionKey string `json:"session_key" binding:"required"`
}
type SessionInfoResponseBody struct {
SessionID uint `json:"session_id"`
UserID uint `json:"user_id"`
CreatedAt time.Time `json:"created_at"`
AccessedAt time.Time `json:"created_at"`
}
func GetInfoAboutSession(c *gin.Context) {
var body GetInfoAboutSessionRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
session, err := sb.FindKey(body.SessionKey)
if err != nil {
DatabaseError(c)
return
}
db := GetDB(c)
var user User
user.ID = session.UserID
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
UnauthorizedError(c)
return
}
c.JSON(200, &SessionInfoResponseBody{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
AccessedAt: session.AccessedAt,
})
}
type GetInfoAboutSessionsForUserRequestBody struct {
ID uint `json:"id"`
Username string `json:"username"`
}
type GetInfoAboutSessionsForUserResponseBody struct {
Sessions []*SessionInfoResponseBody `json:"sessions"`
}
func GetInfoAboutSessionsForUser(c *gin.Context) {
var body GetInfoAboutSessionsForUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
db := GetDB(c)
var user User
user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
sessions, err := sb.FindAllForUser(user.ID)
if err != nil {
DatabaseError(c)
return
}
var response []*SessionInfoResponseBody
for _, session := range sessions {
response = append(response, &SessionInfoResponseBody{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
AccessedAt: session.AccessedAt,
})
}
c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{
Sessions: response,
})
}

164
sessions/backends.go Normal file
View File

@ -0,0 +1,164 @@
package sessions
import (
"crypto/rand"
"encoding/base64"
"time"
"github.com/jinzhu/gorm"
)
////////////////////////////////////////////////////////////////////////////////
// Session Backend API
////////////////////////////////////////////////////////////////////////////////
// SessionBackend is the abstract interface that all session backends must
// conform to. SessionBackend instances are only expected to exist within the
// context of a single request.
type SessionBackend interface {
// Given a session key, find and return a session object or an error if one
// could not be found for the given key
FindKey(key string) (*Session, error)
// Given a session id, find and return a session object or an error if one
// could not be found for the given id
FindID(id uint) (*Session, error)
// Find all of the active sessions for a given user
FindAllForUser(id uint) ([]*Session, error)
// Create a session object tied to the given user ID
Create(userID uint) (*Session, error)
// Destroy the currently tracked session
Destroy(session *Session) error
// Destroy all of the sessions for a given user
DestroyAllForUser(id uint) error
// Mark the currently tracked session as access to extend expiration
MarkAccessed(session *Session) error
}
////////////////////////////////////////////////////////////////////////////////
// Session Backend Plugins
////////////////////////////////////////////////////////////////////////////////
// GormSessionBackend stores sessions using a pre-instantiated gorm database
// object
type GormSessionBackend struct {
DB *gorm.DB
}
func (s *GormSessionBackend) validate(session *Session) error {
if time.Since(session.AccessedAt).Seconds() >= Lifespan {
err := s.DB.Delete(session).Error
if err != nil {
return err
}
return ErrSessionExpired
}
err := s.MarkAccessed(session)
if err != nil {
return err
}
return nil
}
func (s *GormSessionBackend) FindID(id uint) (*Session, error) {
session := &Session{
ID: id,
}
err := s.DB.Where(session).First(session).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
return nil, ErrNoActiveSession
default:
return nil, err
}
}
err = s.validate(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) FindKey(key string) (*Session, error) {
session := &Session{
Key: key,
}
err := s.DB.Where(session).First(session).Error
if err != nil {
switch err {
case gorm.ErrRecordNotFound:
return nil, ErrNoActiveSession
default:
return nil, err
}
}
err = s.validate(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
var sessions []*Session
err := s.DB.Where("user_id = ?", id).Find(&sessions).Error
return sessions, err
}
func (s *GormSessionBackend) Create(userID uint) (*Session, error) {
key := make([]byte, SessionKeySize)
_, err := rand.Read(key)
if err != nil {
return nil, err
}
session := &Session{
UserID: userID,
Key: base64.StdEncoding.EncodeToString(key),
}
err = s.DB.Create(session).Error
if err != nil {
return nil, err
}
err = s.MarkAccessed(session)
if err != nil {
return nil, err
}
return session, nil
}
func (s *GormSessionBackend) Destroy(session *Session) error {
err := s.DB.Delete(session).Error
if err != nil {
return err
}
return nil
}
func (s *GormSessionBackend) DestroyAllForUser(id uint) error {
return s.DB.Delete(&Session{}, "user_id = ?", id).Error
}
func (s *GormSessionBackend) MarkAccessed(session *Session) error {
session.AccessedAt = time.Now().UTC()
return s.DB.Save(session).Error
}

129
sessions/backends_test.go Normal file
View File

@ -0,0 +1,129 @@
package sessions
import (
"crypto/rand"
"encoding/base64"
"net/http"
"testing"
"time"
)
type mockSessionBackend struct {
sessions []*Session
id uint
}
func newMockSessionBackend() *mockSessionBackend {
return &mockSessionBackend{
sessions: []*Session{},
id: 0,
}
}
func (s *mockSessionBackend) FindID(id uint) (*Session, error) {
for _, each := range s.sessions {
if each.ID == id {
return each, nil
}
}
return nil, ErrNoActiveSession
}
func (s *mockSessionBackend) FindKey(key string) (*Session, error) {
for _, each := range s.sessions {
if each.Key == key {
return each, nil
}
}
return nil, ErrNoActiveSession
}
func (s *mockSessionBackend) FindAllForUser(id uint) ([]*Session, error) {
var sessions []*Session
for _, each := range sessions {
if each.UserID == id {
sessions = append(sessions, each)
}
}
return sessions, nil
}
func (s *mockSessionBackend) nextID() uint {
s.id = s.id + 1
return s.id
}
func (s *mockSessionBackend) Create(userID uint) (*Session, error) {
key := make([]byte, SessionKeySize)
_, err := rand.Read(key)
if err != nil {
return nil, err
}
session := &Session{
ID: s.nextID(),
UserID: userID,
Key: base64.StdEncoding.EncodeToString(key),
}
err = s.MarkAccessed(session)
if err != nil {
return nil, err
}
s.sessions = append(s.sessions, session)
return session, nil
}
func (s *mockSessionBackend) Destroy(session *Session) error {
var sessions []*Session
for _, each := range s.sessions {
if each.ID != session.ID {
sessions = append(sessions, each)
}
}
s.sessions = sessions
return nil
}
func (s *mockSessionBackend) DestroyAllForUser(id uint) error {
var sessions []*Session
for _, each := range s.sessions {
if each.UserID != id {
sessions = append(sessions, each)
}
}
s.sessions = sessions
return nil
}
func (s *mockSessionBackend) MarkAccessed(session *Session) error {
session.AccessedAt = time.Now().UTC()
return nil
}
type mockResponseWriter struct {
headers map[string][]string
}
func newMocResponseWriter() *mockResponseWriter {
return &mockResponseWriter{
headers: map[string][]string{},
}
}
func (w *mockResponseWriter) Header() http.Header {
return w.headers
}
func (w *mockResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *mockResponseWriter) WriteHeader(int) {
}
func TestFindID(t *testing.T) {
}

212
sessions/sessions.go Normal file
View File

@ -0,0 +1,212 @@
package sessions
import (
"errors"
"net/http"
"time"
"github.com/Sirupsen/logrus"
"github.com/dgrijalva/jwt-go"
)
var (
// An error returned by SessionBackend.Get() if no session record was found
// in the database
ErrNoActiveSession = errors.New("Active session is not present in the database")
// An error returned by SessionBackend methods when no session object has
// been created yet but the requested action requires one
ErrSessionNotCreated = errors.New("The session has not been created")
// An error returned by SessionBackend.Get() when a session is requested but
// it has expired
ErrSessionExpired = errors.New("The session has expired")
// An error returned by SessionBackend which indicates that the token
// or it's content were malformed
ErrSessionMalformed = errors.New("The session token was malformed")
)
var (
// The name of the session cookie
CookieName = "Session"
// The key to be used to sign and verify JWTs
jwtKey = ""
// The amount of random data, in bytes, which will be used to create each
// session key
SessionKeySize = 64
// The amount of seconds that will pass before an inactive user is logged out
Lifespan = float64(60 * 60 * 24 * 90)
)
// Session is the model object which represents what an active session is
type Session struct {
ID uint `gorm:"primary_key"`
CreatedAt time.Time
AccessedAt time.Time
UserID uint `gorm:"not null"`
Key string `gorm:"not null;unique_index:idx_session_unique_key"`
}
////////////////////////////////////////////////////////////////////////////////
// Configuring the library
////////////////////////////////////////////////////////////////////////////////
type SessionConfiguration struct {
CookieName string
JWTKey string
SessionKeySize int
Lifespan float64
}
func Configure(s *SessionConfiguration) {
CookieName = s.CookieName
jwtKey = s.JWTKey
SessionKeySize = s.SessionKeySize
Lifespan = s.Lifespan
}
// Set the name of the cookie
func SetCookieName(name string) {
CookieName = name
}
////////////////////////////////////////////////////////////////////////////////
// Managing sessions
////////////////////////////////////////////////////////////////////////////////
// SessionManager is a management object which helps with the administration of
// sessions within the application. Use NewSessionManager to create an instance
type SessionManager struct {
Backend SessionBackend
Request *http.Request
Writer http.ResponseWriter
session *Session
}
func (sm *SessionManager) Session() (*Session, error) {
if sm.session == nil {
cookie, err := sm.Request.Cookie(CookieName)
if err != nil {
switch err {
case http.ErrNoCookie:
// No cookie was set
return nil, err
default:
// Something went wrong and the cookie may or may not be set
logrus.Errorf("Couldn't get cookie: %s", err.Error())
return nil, ErrSessionMalformed
}
}
token, err := ParseJWT(cookie.Value)
if err != nil {
logrus.Errorf("Couldn't parse JWT token string from cookie: %s", err.Error())
return nil, ErrSessionMalformed
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
logrus.Error("Could not parse the claims from the JWT token")
return nil, ErrSessionMalformed
}
sessionKeyClaim, ok := claims["session_key"]
if !ok {
logrus.Warn("JWT did not have session_key claim")
return nil, ErrSessionMalformed
}
sessionKey, ok := sessionKeyClaim.(string)
if !ok {
logrus.Warn("JWT session_key claim was not a string")
return nil, ErrSessionMalformed
}
session, err := sm.Backend.FindKey(sessionKey)
if err != nil {
switch err {
case ErrNoActiveSession:
// If the code path got this far, it's likely that the user was logged
// in some time in the past, but their session has been expired since
// their last usage of the application
return nil, err
default:
logrus.Errorf("Couldn't call Get on backend object: %s", err.Error())
return nil, err
}
}
sm.session = session
}
return sm.session, nil
}
// MakeSessionForUserID creates a session in the database for a given user id.
// You must call Save() after calling this.
func (sm *SessionManager) MakeSessionForUserID(id uint) error {
session, err := sm.Backend.Create(id)
if err != nil {
return err
}
sm.session = session
return nil
}
// Save writes the current session to a token and delivers the token as a cookie
// to the user. Save must be called after every write action on this struct
// (MakeSessionForUser, Destroy, etc.)
func (sm *SessionManager) Save() error {
token, err := GenerateJWT(sm.session.Key)
if err != nil {
return err
}
// TODO: set proper flags on cookie for maximum security
http.SetCookie(sm.Writer, &http.Cookie{
Name: CookieName,
Value: token,
})
return nil
}
// Destroy deletes the active session from the database and erases the session
// instance from this object's access. You must call Save() after calling this.
func (sm *SessionManager) Destroy() error {
if sm.Backend != nil {
err := sm.Backend.Destroy(sm.session)
if err != nil {
return err
}
}
return nil
}
////////////////////////////////////////////////////////////////////////////////
// JSON Web Tokens
////////////////////////////////////////////////////////////////////////////////
// Given a session key create a JWT to be delivered to the client
func GenerateJWT(sessionKey string) (string, error) {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"session_key": sessionKey,
})
return token.SignedString([]byte(jwtKey))
}
// ParseJWT attempts to parse a JWT token in serialized string form into a
// JWT token in a deserialized jwt.Token struct.
func ParseJWT(token string) (*jwt.Token, error) {
return jwt.Parse(token, func(t *jwt.Token) (interface{}, error) {
method, ok := t.Method.(*jwt.SigningMethodHMAC)
if !ok || method != jwt.SigningMethodHS256 {
return nil, errors.New("Unexpected signing method")
}
return []byte(jwtKey), nil
})
}

66
sessions/sessions_test.go Normal file
View File

@ -0,0 +1,66 @@
package sessions
import (
"net/http"
"strings"
"testing"
jwt "github.com/dgrijalva/jwt-go"
)
func TestGenerateJWT(t *testing.T) {
jwtKey = "very secure"
tokenString, err := GenerateJWT("4")
token, err := ParseJWT(tokenString)
if err != nil {
t.Fatal(err.Error())
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
t.Fatal("Token is invalid")
}
sessionKey := claims["session_key"].(string)
if sessionKey != "4" {
t.Fatalf("Claims are incorrect. session key is %s", sessionKey)
}
}
func TestSessionManager(t *testing.T) {
r, _ := http.NewRequest("GET", "/", nil)
w := newMocResponseWriter()
sb := newMockSessionBackend()
sm := &SessionManager{
Backend: sb,
Request: r,
Writer: w,
}
err := sm.MakeSessionForUserID(1)
if err != nil {
t.Fatalf(err.Error())
}
err = sm.Save()
if err != nil {
t.Fatalf(err.Error())
}
header := w.Header().Get("Set-Cookie")
tokenString := strings.Split(header, "=")[1]
token, err := ParseJWT(tokenString)
if err != nil {
t.Fatal(err.Error())
}
session_key := token.Claims.(jwt.MapClaims)["session_key"].(string)
session, err := sb.FindKey(session_key)
if err != nil {
t.Fatal(err.Error())
}
if session.UserID != 1 {
t.Fatal("User ID doesn't match. Got: %s", session.UserID)
}
}

View File

@ -1,117 +0,0 @@
package main
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
type MockResponseWriter struct {
}
func (w *MockResponseWriter) Header() http.Header {
return map[string][]string{}
}
func (w *MockResponseWriter) Write([]byte) (int, error) {
return 0, nil
}
func (w *MockResponseWriter) WriteHeader(int) {
}
func TestSessionManagerVC(t *testing.T) {
db := openTestDB(t)
admin, err := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
if err != nil {
t.Fatal(err.Error())
}
backend := &GormSessionBackend{db}
session, err := backend.Create(admin.ID)
if err != nil {
t.Fatal(err.Error())
}
if session.UserID != admin.ID {
t.Fatal("IDs do not match")
}
token, err := GenerateJWT(session.Key)
cookie := &http.Cookie{
Name: CookieName,
Value: token,
}
req, err := http.NewRequest("GET", "/", nil)
if err != nil {
t.Fatal(err.Error())
}
req.AddCookie(cookie)
writer := &MockResponseWriter{}
sm := &SessionManager{
request: req,
writer: writer,
backend: backend,
db: db,
}
vc := sm.VC()
if !vc.IsAdmin() {
t.Fatal("User should be admin")
}
vcID, _ := vc.UserID()
if vcID != admin.ID {
t.Fatal("IDs don't match")
}
}
func TestSessionCreation(t *testing.T) {
db := openTestDB(t)
r := createEmptyTestServer(db)
admin, _ := NewUser(db, "admin", "foobar", "admin@kolide.co", true, false)
r.GET("/login", func(c *gin.Context) {
sm := NewSessionManager(c)
sm.MakeSessionForUser(admin)
err := sm.Save()
if err != nil {
t.Fatal(err.Error())
}
c.JSON(200, nil)
})
r.GET("/resource", func(c *gin.Context) {
sm := NewSessionManager(c)
vc := sm.VC()
if !vc.IsAdmin() {
t.Fatal("Request is not admin")
}
c.JSON(200, nil)
})
r.GET("/nope", func(c *gin.Context) {
sm := NewSessionManager(c)
vc := sm.VC()
if !vc.IsAdmin() {
t.Fatal("Request is not admin")
}
c.JSON(200, nil)
})
res1 := httptest.NewRecorder()
req1, _ := http.NewRequest("GET", "/login", nil)
r.ServeHTTP(res1, req1)
res2 := httptest.NewRecorder()
req2, _ := http.NewRequest("GET", "/resource", nil)
req2.Header.Set("Cookie", res1.Header().Get("Set-Cookie"))
r.ServeHTTP(res2, req2)
}

View File

@ -6,6 +6,7 @@ import (
jwt "github.com/dgrijalva/jwt-go"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/sessions"
)
func TestUserAndAccountManagement(t *testing.T) {
@ -59,7 +60,7 @@ func TestUserAndAccountManagement(t *testing.T) {
}
// Pull the token out of the JWT token and get the session info via that
token, err := ParseJWT(strings.Split(adminSession, "=")[1])
token, err := sessions.ParseJWT(strings.Split(adminSession, "=")[1])
if err != nil {
t.Fatal(err.Error())
}
@ -75,7 +76,7 @@ func TestUserAndAccountManagement(t *testing.T) {
req.DeleteSession(adminSessionInfo.Sessions[0].SessionID, adminSession)
// Verify the session was deleted
sessionVerify := &Session{
sessionVerify := &sessions.Session{
Key: sessionKey,
}
err = req.db.Where(sessionVerify).First(sessionVerify).Error

231
users.go
View File

@ -2,10 +2,12 @@ package main
import (
"fmt"
"time"
"github.com/Sirupsen/logrus"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
"github.com/kolide/kolide-ose/sessions"
"golang.org/x/crypto/bcrypt"
)
@ -452,3 +454,232 @@ func SetUserEnabledState(c *gin.Context) {
NeedsPasswordReset: user.NeedsPasswordReset,
})
}
///////////////////////////////////////////////////////////////////////////////
// Session management HTTP endpoints
////////////////////////////////////////////////////////////////////////////////
// Setting the session backend via a middleware
func SessionBackendMiddleware(c *gin.Context) {
db := GetDB(c)
c.Set("SessionBackend", &sessions.GormSessionBackend{db})
c.Next()
}
// Get the database connection from the context, or panic
func GetSessionBackend(c *gin.Context) sessions.SessionBackend {
return c.MustGet("SessionBackend").(sessions.SessionBackend)
}
////////////////////////////////////////////////////////////////////////////////
// Session management HTTP endpoints
////////////////////////////////////////////////////////////////////////////////
type DeleteSessionRequestBody struct {
SessionID uint `json:"session_id" binding:"required"`
}
func DeleteSession(c *gin.Context) {
var body DeleteSessionRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
session, err := sb.FindID(body.SessionID)
if err != nil {
}
db := GetDB(c)
user := &User{
BaseModel: BaseModel{
ID: session.UserID,
},
}
err = db.Where(user).First(user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(user) {
UnauthorizedError(c)
return
}
err = sb.Destroy(session)
if err != nil {
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type DeleteSessionsForUserRequestBody struct {
ID uint `json:"id"`
Username string `json:"username"`
}
func DeleteSessionsForUser(c *gin.Context) {
var body DeleteSessionsForUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
db := GetDB(c)
var user User
user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.CanPerformWriteActionOnUser(&user) {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
err = sb.DestroyAllForUser(user.ID)
if err != nil {
DatabaseError(c)
return
}
c.JSON(200, nil)
}
type GetInfoAboutSessionRequestBody struct {
SessionKey string `json:"session_key" binding:"required"`
}
type SessionInfoResponseBody struct {
SessionID uint `json:"session_id"`
UserID uint `json:"user_id"`
CreatedAt time.Time `json:"created_at"`
AccessedAt time.Time `json:"created_at"`
}
func GetInfoAboutSession(c *gin.Context) {
var body GetInfoAboutSessionRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
session, err := sb.FindKey(body.SessionKey)
if err != nil {
DatabaseError(c)
return
}
db := GetDB(c)
var user User
user.ID = session.UserID
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
UnauthorizedError(c)
return
}
c.JSON(200, &SessionInfoResponseBody{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
AccessedAt: session.AccessedAt,
})
}
type GetInfoAboutSessionsForUserRequestBody struct {
ID uint `json:"id"`
Username string `json:"username"`
}
type GetInfoAboutSessionsForUserResponseBody struct {
Sessions []*SessionInfoResponseBody `json:"sessions"`
}
func GetInfoAboutSessionsForUser(c *gin.Context) {
var body GetInfoAboutSessionsForUserRequestBody
err := c.BindJSON(&body)
if err != nil {
logrus.Errorf(err.Error())
return
}
vc := VC(c)
if !vc.CanPerformActions() {
UnauthorizedError(c)
return
}
db := GetDB(c)
var user User
user.ID = body.ID
user.Username = body.Username
err = db.Where(&user).First(&user).Error
if err != nil {
DatabaseError(c)
return
}
if !vc.IsAdmin() && !vc.IsUserID(user.ID) {
UnauthorizedError(c)
return
}
sb := GetSessionBackend(c)
sessions, err := sb.FindAllForUser(user.ID)
if err != nil {
DatabaseError(c)
return
}
var response []*SessionInfoResponseBody
for _, session := range sessions {
response = append(response, &SessionInfoResponseBody{
SessionID: session.ID,
UserID: session.UserID,
CreatedAt: session.CreatedAt,
AccessedAt: session.AccessedAt,
})
}
c.JSON(200, &GetInfoAboutSessionsForUserResponseBody{
Sessions: response,
})
}