Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(os/gsession): add RegenerateId/MustRegenerateId support #4012

Merged
merged 1 commit into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 71 additions & 17 deletions os/gsession/gsession_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (s *Session) init() error {
// Retrieve stored session data from storage.
if s.manager.storage != nil {
s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL())
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err)
return err
}
Expand All @@ -59,7 +59,7 @@ func (s *Session) init() error {
} else {
// Use default session id creating function of storage.
s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, "create session id failed: %+v", err)
return err
}
Expand Down Expand Up @@ -89,12 +89,12 @@ func (s *Session) Close() error {
size := s.data.Size()
if s.dirty {
err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
return err
}
} else if size > 0 {
err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
return err
}
}
Expand All @@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) {
return err
}
if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil {
if err == ErrorDisabled {
s.data.Set(key, value)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Set(key, value)
}
s.dirty = true
return nil
Expand All @@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) {
return err
}
if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil {
if err == ErrorDisabled {
s.data.Sets(data)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Sets(data)
}
s.dirty = true
return nil
Expand All @@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) {
}
for _, key := range keys {
if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil {
if err == ErrorDisabled {
s.data.Remove(key)
} else {
if !gerror.Is(err, ErrorDisabled) {
return err
}
s.data.Remove(key)
}
}
s.dirty = true
Expand All @@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) {
return err
}
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
if err != ErrorDisabled {
if !gerror.Is(err, ErrorDisabled) {
return err
}
}
Expand Down Expand Up @@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) {
return nil, err
}
sessionData, err = s.manager.storage.Data(s.ctx, s.id)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
}
if sessionData != nil {
Expand All @@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) {
return 0, err
}
size, err = s.manager.storage.GetSize(s.ctx, s.id)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
}
if size > 0 {
Expand Down Expand Up @@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro
return nil, err
}
v, err := s.manager.storage.Get(s.ctx, s.id, key)
if err != nil && err != ErrorDisabled {
if err != nil && !gerror.Is(err, ErrorDisabled) {
intlog.Errorf(s.ctx, `%+v`, err)
return nil, err
}
Expand Down Expand Up @@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) {
panic(err)
}
}

// RegenerateId regenerates a new session id for current session.
// It keeps the session data and updates the session id with a new one.
// This is commonly used to prevent session fixation attacks and increase security.
//
// The parameter `deleteOld` specifies whether to delete the old session data:
// - If true: the old session data will be deleted immediately
// - If false: the old session data will be kept and expire according to its TTL
func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) {
if err = s.init(); err != nil {
return "", err
}

// Generate new session id
if s.idFunc != nil {
newId = s.idFunc(s.manager.ttl)
} else {
newId, err = s.manager.storage.New(s.ctx, s.manager.ttl)
if err != nil && !gerror.Is(err, ErrorDisabled) {
return "", err
}
if newId == "" {
newId = NewSessionId()
}
}

// If using storage, need to copy data to new id
if s.manager.storage != nil {
if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil {
if !gerror.Is(err, ErrorDisabled) {
return "", err
}
}
// Delete old session data if requested
if deleteOld {
if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil {
if !gerror.Is(err, ErrorDisabled) {
return "", err
}
}
}
}

// Update session id
s.id = newId
s.dirty = true
return newId, nil
}

// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs.
func (s *Session) MustRegenerateId(deleteOld bool) string {
newId, err := s.RegenerateId(deleteOld)
if err != nil {
panic(err)
}
return newId
}
108 changes: 108 additions & 0 deletions os/gsession/gsession_z_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
package gsession

import (
"context"
"testing"
"time"

"github.com/gogf/gf/v2/test/gtest"
)

var ctx = context.TODO()

func Test_NewSessionId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
id1 := NewSessionId()
Expand All @@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) {
t.Assert(len(id1), 32)
})
}

func Test_Session_RegenerateId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
// 1. Test with memory storage
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)

// Store some data
err := session.Set("key1", "value1")
t.AssertNil(err)
err = session.Set("key2", "value2")
t.AssertNil(err)

// Get original session id
oldId := session.MustId()

// Test regenerate with deleteOld = true
newId1, err := session.RegenerateId(true)
t.AssertNil(err)
t.AssertNE(oldId, newId1)

// Verify data is preserved
v1 := session.MustGet("key1")
t.Assert(v1.String(), "value1")
v2 := session.MustGet("key2")
t.Assert(v2.String(), "value2")

// Verify old session is deleted
oldSession := manager.New(ctx)
err = oldSession.SetId(oldId)
t.AssertNil(err)
v3 := oldSession.MustGet("key1")
t.Assert(v3.IsNil(), true)

// Test regenerate with deleteOld = false
currentId := newId1
newId2, err := session.RegenerateId(false)
t.AssertNil(err)
t.AssertNE(currentId, newId2)

// Verify data is preserved in new session
v4 := session.MustGet("key1")
t.Assert(v4.String(), "value1")

// Create another session instance with the previous id
prevSession := manager.New(ctx)
err = prevSession.SetId(currentId)
t.AssertNil(err)
// Data should still be accessible in previous session
v5 := prevSession.MustGet("key1")
t.Assert(v5.String(), "value1")
})

gtest.C(t, func(t *gtest.T) {
// 2. Test with custom id function
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)

customId := "custom_session_id"
err := session.SetIdFunc(func(ttl time.Duration) string {
return customId
})
t.AssertNil(err)

newId, err := session.RegenerateId(true)
t.AssertNil(err)
t.Assert(newId, customId)
})

gtest.C(t, func(t *gtest.T) {
// 3. Test with disabled storage
storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled
manager := New(time.Hour, storage)
session := manager.New(ctx)

// Should still work even with disabled storage
newId, err := session.RegenerateId(true)
t.AssertNil(err)
t.Assert(len(newId), 32)
})
}

// Test MustRegenerateId
func Test_Session_MustRegenerateId(t *testing.T) {
gtest.C(t, func(t *gtest.T) {
storage := NewStorageMemory()
manager := New(time.Hour, storage)
session := manager.New(ctx)

// Normal case should not panic
t.AssertNil(session.Set("key", "value"))
newId := session.MustRegenerateId(true)
t.Assert(len(newId), 32)

// Test with disabled storage (should not panic)
storage2 := &StorageBase{}
manager2 := New(time.Hour, storage2)
session2 := manager2.New(ctx)
newId2 := session2.MustRegenerateId(true)
t.Assert(len(newId2), 32)
})
}
Loading