From 117a497a67a76d74e161e4f228b44f4fdbcbecec Mon Sep 17 00:00:00 2001 From: John Guo Date: Thu, 5 Dec 2024 22:34:49 +0800 Subject: [PATCH] feat(os/gsession): add RegenerateId/MustRegenerateId support --- os/gsession/gsession_session.go | 88 ++++++++++++++++++----- os/gsession/gsession_z_unit_test.go | 108 ++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 17 deletions(-) diff --git a/os/gsession/gsession_session.go b/os/gsession/gsession_session.go index c89b4239c43..1844b51de2b 100644 --- a/os/gsession/gsession_session.go +++ b/os/gsession/gsession_session.go @@ -45,7 +45,7 @@ func (s *Session) init() error { // Retrieve stored session data from storage. if s.manager.storage != nil { s.data, err = s.manager.storage.GetSession(s.ctx, s.id, s.manager.GetTTL()) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { intlog.Errorf(s.ctx, `session restoring failed for id "%s": %+v`, s.id, err) return err } @@ -59,7 +59,7 @@ func (s *Session) init() error { } else { // Use default session id creating function of storage. s.id, err = s.manager.storage.New(s.ctx, s.manager.ttl) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { intlog.Errorf(s.ctx, "create session id failed: %+v", err) return err } @@ -89,12 +89,12 @@ func (s *Session) Close() error { size := s.data.Size() if s.dirty { err := s.manager.storage.SetSession(s.ctx, s.id, s.data, s.manager.ttl) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { return err } } else if size > 0 { err := s.manager.storage.UpdateTTL(s.ctx, s.id, s.manager.ttl) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { return err } } @@ -108,11 +108,10 @@ func (s *Session) Set(key string, value interface{}) (err error) { return err } if err = s.manager.storage.Set(s.ctx, s.id, key, value, s.manager.ttl); err != nil { - if err == ErrorDisabled { - s.data.Set(key, value) - } else { + if !gerror.Is(err, ErrorDisabled) { return err } + s.data.Set(key, value) } s.dirty = true return nil @@ -124,11 +123,10 @@ func (s *Session) SetMap(data map[string]interface{}) (err error) { return err } if err = s.manager.storage.SetMap(s.ctx, s.id, data, s.manager.ttl); err != nil { - if err == ErrorDisabled { - s.data.Sets(data) - } else { + if !gerror.Is(err, ErrorDisabled) { return err } + s.data.Sets(data) } s.dirty = true return nil @@ -144,11 +142,10 @@ func (s *Session) Remove(keys ...string) (err error) { } for _, key := range keys { if err = s.manager.storage.Remove(s.ctx, s.id, key); err != nil { - if err == ErrorDisabled { - s.data.Remove(key) - } else { + if !gerror.Is(err, ErrorDisabled) { return err } + s.data.Remove(key) } } s.dirty = true @@ -164,7 +161,7 @@ func (s *Session) RemoveAll() (err error) { return err } if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil { - if err != ErrorDisabled { + if !gerror.Is(err, ErrorDisabled) { return err } } @@ -215,7 +212,7 @@ func (s *Session) Data() (sessionData map[string]interface{}, err error) { return nil, err } sessionData, err = s.manager.storage.Data(s.ctx, s.id) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { intlog.Errorf(s.ctx, `%+v`, err) } if sessionData != nil { @@ -233,7 +230,7 @@ func (s *Session) Size() (size int, err error) { return 0, err } size, err = s.manager.storage.GetSize(s.ctx, s.id) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { intlog.Errorf(s.ctx, `%+v`, err) } if size > 0 { @@ -273,7 +270,7 @@ func (s *Session) Get(key string, def ...interface{}) (value *gvar.Var, err erro return nil, err } v, err := s.manager.storage.Get(s.ctx, s.id, key) - if err != nil && err != ErrorDisabled { + if err != nil && !gerror.Is(err, ErrorDisabled) { intlog.Errorf(s.ctx, `%+v`, err) return nil, err } @@ -357,3 +354,60 @@ func (s *Session) MustRemove(keys ...string) { panic(err) } } + +// RegenerateId regenerates a new session id for current session. +// It keeps the session data and updates the session id with a new one. +// This is commonly used to prevent session fixation attacks and increase security. +// +// The parameter `deleteOld` specifies whether to delete the old session data: +// - If true: the old session data will be deleted immediately +// - If false: the old session data will be kept and expire according to its TTL +func (s *Session) RegenerateId(deleteOld bool) (newId string, err error) { + if err = s.init(); err != nil { + return "", err + } + + // Generate new session id + if s.idFunc != nil { + newId = s.idFunc(s.manager.ttl) + } else { + newId, err = s.manager.storage.New(s.ctx, s.manager.ttl) + if err != nil && !gerror.Is(err, ErrorDisabled) { + return "", err + } + if newId == "" { + newId = NewSessionId() + } + } + + // If using storage, need to copy data to new id + if s.manager.storage != nil { + if err = s.manager.storage.SetSession(s.ctx, newId, s.data, s.manager.ttl); err != nil { + if !gerror.Is(err, ErrorDisabled) { + return "", err + } + } + // Delete old session data if requested + if deleteOld { + if err = s.manager.storage.RemoveAll(s.ctx, s.id); err != nil { + if !gerror.Is(err, ErrorDisabled) { + return "", err + } + } + } + } + + // Update session id + s.id = newId + s.dirty = true + return newId, nil +} + +// MustRegenerateId performs as function RegenerateId, but it panics if any error occurs. +func (s *Session) MustRegenerateId(deleteOld bool) string { + newId, err := s.RegenerateId(deleteOld) + if err != nil { + panic(err) + } + return newId +} diff --git a/os/gsession/gsession_z_unit_test.go b/os/gsession/gsession_z_unit_test.go index 54940183cc0..613c00dc465 100644 --- a/os/gsession/gsession_z_unit_test.go +++ b/os/gsession/gsession_z_unit_test.go @@ -7,11 +7,15 @@ package gsession import ( + "context" "testing" + "time" "github.com/gogf/gf/v2/test/gtest" ) +var ctx = context.TODO() + func Test_NewSessionId(t *testing.T) { gtest.C(t, func(t *gtest.T) { id1 := NewSessionId() @@ -20,3 +24,107 @@ func Test_NewSessionId(t *testing.T) { t.Assert(len(id1), 32) }) } + +func Test_Session_RegenerateId(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + // 1. Test with memory storage + storage := NewStorageMemory() + manager := New(time.Hour, storage) + session := manager.New(ctx) + + // Store some data + err := session.Set("key1", "value1") + t.AssertNil(err) + err = session.Set("key2", "value2") + t.AssertNil(err) + + // Get original session id + oldId := session.MustId() + + // Test regenerate with deleteOld = true + newId1, err := session.RegenerateId(true) + t.AssertNil(err) + t.AssertNE(oldId, newId1) + + // Verify data is preserved + v1 := session.MustGet("key1") + t.Assert(v1.String(), "value1") + v2 := session.MustGet("key2") + t.Assert(v2.String(), "value2") + + // Verify old session is deleted + oldSession := manager.New(ctx) + err = oldSession.SetId(oldId) + t.AssertNil(err) + v3 := oldSession.MustGet("key1") + t.Assert(v3.IsNil(), true) + + // Test regenerate with deleteOld = false + currentId := newId1 + newId2, err := session.RegenerateId(false) + t.AssertNil(err) + t.AssertNE(currentId, newId2) + + // Verify data is preserved in new session + v4 := session.MustGet("key1") + t.Assert(v4.String(), "value1") + + // Create another session instance with the previous id + prevSession := manager.New(ctx) + err = prevSession.SetId(currentId) + t.AssertNil(err) + // Data should still be accessible in previous session + v5 := prevSession.MustGet("key1") + t.Assert(v5.String(), "value1") + }) + + gtest.C(t, func(t *gtest.T) { + // 2. Test with custom id function + storage := NewStorageMemory() + manager := New(time.Hour, storage) + session := manager.New(ctx) + + customId := "custom_session_id" + err := session.SetIdFunc(func(ttl time.Duration) string { + return customId + }) + t.AssertNil(err) + + newId, err := session.RegenerateId(true) + t.AssertNil(err) + t.Assert(newId, customId) + }) + + gtest.C(t, func(t *gtest.T) { + // 3. Test with disabled storage + storage := &StorageBase{} // implements Storage interface but all methods return ErrorDisabled + manager := New(time.Hour, storage) + session := manager.New(ctx) + + // Should still work even with disabled storage + newId, err := session.RegenerateId(true) + t.AssertNil(err) + t.Assert(len(newId), 32) + }) +} + +// Test MustRegenerateId +func Test_Session_MustRegenerateId(t *testing.T) { + gtest.C(t, func(t *gtest.T) { + storage := NewStorageMemory() + manager := New(time.Hour, storage) + session := manager.New(ctx) + + // Normal case should not panic + t.AssertNil(session.Set("key", "value")) + newId := session.MustRegenerateId(true) + t.Assert(len(newId), 32) + + // Test with disabled storage (should not panic) + storage2 := &StorageBase{} + manager2 := New(time.Hour, storage2) + session2 := manager2.New(ctx) + newId2 := session2.MustRegenerateId(true) + t.Assert(len(newId2), 32) + }) +}