diff --git a/go.mod b/go.mod index 28e218203320..07565e5100df 100644 --- a/go.mod +++ b/go.mod @@ -42,7 +42,7 @@ require ( github.com/go-swagger/go-swagger v0.30.3 github.com/gobuffalo/fizz v1.14.4 github.com/gobuffalo/httptest v1.5.2 - github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7 + github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0 github.com/gofrs/uuid v4.3.1+incompatible github.com/golang-jwt/jwt/v4 v4.1.0 github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2 diff --git a/go.sum b/go.sum index cf8c85102e95..0e68d3f6d312 100644 --- a/go.sum +++ b/go.sum @@ -501,8 +501,8 @@ github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/V github.com/gobuffalo/plush/v4 v4.1.16/go.mod h1:6t7swVsarJ8qSLw1qyAH/KbrcSTwdun2ASEQkOznakg= github.com/gobuffalo/plush/v4 v4.1.18 h1:bnPjdMTEUQHqj9TNX2Ck3mxEXYZa+0nrFMNM07kpX9g= github.com/gobuffalo/plush/v4 v4.1.18/go.mod h1:xi2tJIhFI4UdzIL8sxZtzGYOd2xbBpcFbLZlIPGGZhU= -github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7 h1:lwf/5cRw46IrLrhZnCg8J9NKgskkwMPuVvEOc2Wy72I= -github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI= +github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0 h1:+LF3Enal3HZ+rFmaLZfBRNHKqtnoA0d8jk0Iio8InZM= +github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= github.com/gobuffalo/tags/v3 v3.1.4 h1:X/ydLLPhgXV4h04Hp2xlbI2oc5MDaa7eub6zw8oHjsM= github.com/gobuffalo/tags/v3 v3.1.4/go.mod h1:ArRNo3ErlHO8BtdA0REaZxijuWnWzF6PUXngmMXd2I0= diff --git a/identity/credentials.go b/identity/credentials.go index 7b1c485579b6..d44875e9ff99 100644 --- a/identity/credentials.go +++ b/identity/credentials.go @@ -8,12 +8,9 @@ import ( "reflect" "time" - "github.com/gobuffalo/pop/v6" - - "github.com/ory/kratos/ui/node" - "github.com/gofrs/uuid" + "github.com/ory/kratos/ui/node" "github.com/ory/x/sqlxx" ) @@ -87,7 +84,8 @@ type Credentials struct { ID uuid.UUID `json:"-" db:"id"` // Type discriminates between different types of credentials. - Type CredentialsType `json:"type" db:"-"` + Type CredentialsType `json:"type" db:"-"` + IdentityCredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"` // Identifiers represents a list of unique identifiers this credential type matches. Identifiers []string `json:"identifiers" db:"-"` @@ -107,26 +105,6 @@ type Credentials struct { // UpdatedAt is a helper struct field for gobuffalo.pop. UpdatedAt time.Time `json:"updated_at" db:"updated_at"` NID uuid.UUID `json:"-" faker:"-" db:"nid"` - - IdentityCredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"` - IdentityCredentialType CredentialsTypeTable `json:"-" faker:"-" belongs_to:"identity_credential_types"` - CredentialIdentifiers CredentialIdentifierCollection `json:"-" faker:"-" has_many:"identity_credential_identifiers" fk_id:"identity_credential_id" order_by:"id asc"` -} - -func (c *Credentials) AfterEagerFind(tx *pop.Connection) error { - return c.setCredentials() -} - -func (c *Credentials) setCredentials() error { - c.Type = c.IdentityCredentialType.Name - c.Identifiers = make([]string, 0, len(c.CredentialIdentifiers)) - for _, id := range c.CredentialIdentifiers { - if c.NID != id.NID { - continue - } - c.Identifiers = append(c.Identifiers, id.Identifier) - } - return nil } func (c Credentials) TableName(ctx context.Context) string { @@ -155,12 +133,6 @@ type ( Name CredentialsType `json:"-" db:"name"` } - // swagger:ignore - CredentialsCollection []Credentials - - // swagger:ignore - CredentialIdentifierCollection []CredentialIdentifier - // swagger:ignore ActiveCredentialsCounter interface { ID() CredentialsType @@ -178,14 +150,6 @@ func (c CredentialsTypeTable) TableName(ctx context.Context) string { return "identity_credential_types" } -func (c CredentialsCollection) TableName(ctx context.Context) string { - return "identity_credentials" -} - -func (c CredentialIdentifierCollection) TableName(ctx context.Context) string { - return "identity_credential_identifiers" -} - func (c CredentialIdentifier) TableName(ctx context.Context) string { return "identity_credential_identifiers" } diff --git a/identity/expandables.go b/identity/expandables.go index a0abf73ff288..93c752939c19 100644 --- a/identity/expandables.go +++ b/identity/expandables.go @@ -9,11 +9,9 @@ type Expandable = sqlxx.Expandable type Expandables = sqlxx.Expandables const ( - ExpandFieldVerifiableAddresses Expandable = "VerifiableAddresses" - ExpandFieldRecoveryAddresses Expandable = "RecoveryAddresses" - ExpandFieldCredentials Expandable = "InternalCredentials" - ExpandFieldCredentialType Expandable = "InternalCredentials.IdentityCredentialType" - ExpandFieldCredentialIdentifiers Expandable = "InternalCredentials.CredentialIdentifiers" + ExpandFieldVerifiableAddresses Expandable = "VerifiableAddresses" + ExpandFieldRecoveryAddresses Expandable = "RecoveryAddresses" + ExpandFieldCredentials Expandable = "Credentials" ) // ExpandNothing expands nothing @@ -31,8 +29,6 @@ var ExpandDefault = Expandables{ // ExpandCredentials expands the identity's credentials. var ExpandCredentials = Expandables{ ExpandFieldCredentials, - ExpandFieldCredentialType, - ExpandFieldCredentialIdentifiers, } // ExpandEverything expands all the fields of an identity. @@ -40,6 +36,4 @@ var ExpandEverything = Expandables{ ExpandFieldVerifiableAddresses, ExpandFieldRecoveryAddresses, ExpandFieldCredentials, - ExpandFieldCredentialType, - ExpandFieldCredentialIdentifiers, } diff --git a/identity/identity.go b/identity/identity.go index 520d824cf867..1581f7e1ed9b 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -13,8 +13,6 @@ import ( "github.com/samber/lo" - "github.com/gobuffalo/pop/v6" - "github.com/tidwall/sjson" "github.com/tidwall/gjson" @@ -121,9 +119,6 @@ type Identity struct { // Store metadata about the user which is only accessible through admin APIs such as `GET /admin/identities/`. MetadataAdmin sqlxx.NullJSONRawMessage `json:"metadata_admin,omitempty" faker:"-" db:"metadata_admin"` - // InternalCredentials is an internal representation of the credentials. - InternalCredentials CredentialsCollection `json:"-" faker:"-" has_many:"identity_credentials" fk_id:"identity_id" order_by:"id asc"` - // CreatedAt is a helper struct field for gobuffalo.pop. CreatedAt time.Time `json:"created_at" db:"created_at"` @@ -132,36 +127,6 @@ type Identity struct { NID uuid.UUID `json:"-" faker:"-" db:"nid"` } -func (i *Identity) AfterEagerFind(tx *pop.Connection) error { - if err := i.setCredentials(tx); err != nil { - return err - } - - if err := i.Validate(); err != nil { - return err - } - - return UpgradeCredentials(i) -} - -func (i *Identity) setCredentials(tx *pop.Connection) error { - creds := i.InternalCredentials - i.Credentials = make(map[CredentialsType]Credentials, len(creds)) - for k := range creds { - cred := &creds[k] - if cred.NID != i.NID { - continue - } - if err := cred.AfterEagerFind(tx); err != nil { - return err - - } - i.Credentials[cred.Type] = *cred - } - - return nil -} - // Traits represent an identity's traits. The identity is able to create, modify, and delete traits // in a self-service manner. The input will always be validated against the JSON Schema defined // in `schema_url`. diff --git a/identity/test/pool.go b/identity/test/pool.go index 5f1b2fe32827..1fcde3b22db5 100644 --- a/identity/test/pool.go +++ b/identity/test/pool.go @@ -123,7 +123,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, assert.Empty(t, actual.RecoveryAddresses) assert.Empty(t, actual.VerifiableAddresses) assert.Empty(t, actual.Credentials) - assert.Empty(t, actual.InternalCredentials) }) }) @@ -142,7 +141,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("expand=recovery address", func(t *testing.T) { runner(t, sqlxx.Expandables{identity.ExpandFieldRecoveryAddresses}, func(t *testing.T, actual *identity.Identity) { assert.Empty(t, actual.Credentials) - assert.Empty(t, actual.InternalCredentials) assert.Empty(t, actual.VerifiableAddresses) require.Len(t, actual.RecoveryAddresses, 1) @@ -153,7 +151,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, t.Run("expand=verification address", func(t *testing.T) { runner(t, sqlxx.Expandables{identity.ExpandFieldVerifiableAddresses}, func(t *testing.T, actual *identity.Identity) { assert.Empty(t, actual.Credentials) - assert.Empty(t, actual.InternalCredentials) assert.Empty(t, actual.RecoveryAddresses) require.Len(t, actual.VerifiableAddresses, 1) @@ -165,7 +162,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister, runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) { assert.Empty(t, actual.Credentials) - assert.Empty(t, actual.InternalCredentials) require.Len(t, actual.RecoveryAddresses, 1) assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"}) diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 6afa4cc5b7a3..22ddc1522f9d 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -417,7 +417,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * var ( con = p.GetConnection(ctx) nid = p.NetworkID(ctx) - credentials []identity.Credentials + credentials map[identity.CredentialsType](identity.Credentials) verifiableAddresses []identity.VerifiableAddress recoveryAddresses []identity.RecoveryAddress ) @@ -455,18 +455,17 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * } if expand.Has(identity.ExpandFieldCredentials) { - eg.Go(func() error { + eg.Go(func() (err error) { // We use WithContext to get a copy of the connection struct, which solves the race detector // from complaining incorrectly. // // https://github.com/gobuffalo/pop/issues/723 - if err := con.WithContext(ctx). - EagerPreload("IdentityCredentialType", "CredentialIdentifiers"). - Where("identity_id = ? AND nid = ?", i.ID, nid). - All(&credentials); err != nil { - return sqlcon.HandleError(err) - } - return nil + con := con.WithContext(ctx) + credentials, err = QueryForCredentials(con, Where{ + "(identity_credentials.identity_id = ? AND identity_credentials.nid = ?)", + []interface{}{i.ID, nid}, + }) + return }) } @@ -476,15 +475,93 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i * i.VerifiableAddresses = verifiableAddresses i.RecoveryAddresses = recoveryAddresses - i.InternalCredentials = credentials + i.Credentials = credentials + + if err := i.Validate(); err != nil { + return err + } - if err := i.AfterEagerFind(con); err != nil { + if err := identity.UpgradeCredentials(i); err != nil { return err } return p.InjectTraitsSchemaURL(ctx, i) } +type QueryCredentials struct { + ID uuid.UUID `db:"cred_id"` + IdentityID uuid.UUID `db:"identity_id"` + NID uuid.UUID `db:"nid"` + Type identity.CredentialsType `db:"cred_type"` + TypeID uuid.UUID `db:"cred_type_id"` + Identifier string `db:"cred_identifier"` + Config sqlxx.JSONRawMessage `db:"cred_config"` + Version int `db:"cred_version"` + CreatedAt time.Time `db:"created_at"` + UpdatedAt time.Time `db:"updated_at"` +} + +func (QueryCredentials) TableName() string { + return "identity_credentials" +} + +type Where struct { + Condition string + Args []interface{} +} + +func QueryForCredentials(con *pop.Connection, where ...Where) (map[identity.CredentialsType]identity.Credentials, error) { + q := con.Select( + "identity_credentials.id cred_id", + "identity_credentials.identity_id identity_id", + "identity_credentials.nid nid", + "ict.name cred_type", + "ict.id cred_type_id", + "COALESCE(ici.identifier, '') cred_identifier", + "identity_credentials.config cred_config", + "identity_credentials.version cred_version", + "identity_credentials.created_at created_at", + "identity_credentials.updated_at updated_at", + ).InnerJoin( + "identity_credential_types ict", + "(identity_credentials.identity_credential_type_id = ict.id)", + ).LeftJoin( + "identity_credential_identifiers ici", + "(ici.identity_credential_id = identity_credentials.id AND ici.nid = identity_credentials.nid)", + ) + for _, w := range where { + q = q.Where(w.Condition, w.Args...) + } + var creds []QueryCredentials + if err := q.All(&creds); err != nil { + return nil, sqlcon.HandleError(err) + } + credentials := map[identity.CredentialsType]identity.Credentials{} + for _, cred := range creds { + identifiers := credentials[cred.Type].Identifiers + if cred.Identifier != "" { + identifiers = append(identifiers, cred.Identifier) + } + if identifiers == nil { + identifiers = make([]string, 0) + } + c := identity.Credentials{ + ID: cred.ID, + IdentityID: cred.IdentityID, + NID: cred.NID, + Type: cred.Type, + IdentityCredentialTypeID: cred.TypeID, + Identifiers: identifiers, + Config: cred.Config, + Version: cred.Version, + CreatedAt: cred.CreatedAt, + UpdatedAt: cred.UpdatedAt, + } + credentials[cred.Type] = c + } + return credentials, nil +} + func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) { ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities") defer otelx.End(span, &err) @@ -503,10 +580,6 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. nid := p.NetworkID(ctx) query := con.Where("identities.nid = ?", nid).Order("identities.id DESC") - if len(params.Expand) > 0 { - query = query.EagerPreload(params.Expand.ToEager()...) - } - if match := params.CredentialsIdentifier; len(match) > 0 { // When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore // important to normalize the identifier before querying the database. @@ -526,6 +599,14 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. return nil, err } + if len(params.Expand) > 0 { + for i := range is { + if err := p.HydrateIdentityAssociations(ctx, &is[i], params.Expand); err != nil { + return nil, err + } + } + } + schemaCache := map[string]string{} for k := range is { i := &is[k] diff --git a/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json b/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json index cf61fd2a8bed..d634388cc31e 100644 --- a/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json +++ b/persistence/sql/migratest/fixtures/identity/5ff66179-c240-4703-b0d8-494592cefff5.json @@ -4,8 +4,8 @@ "password": { "type": "password", "identifiers": [ - "foo-dupe@ory.sh", - "foo@ory.sh" + "foo@ory.sh", + "foo-dupe@ory.sh" ], "config": { "hashed_password": "$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng" diff --git a/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json b/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json index 9c63190a6fa7..3fe737773cef 100644 --- a/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json +++ b/persistence/sql/migratest/fixtures/identity/a251ebc2-880c-4f76-a8f3-38e6940eab0e.json @@ -4,8 +4,8 @@ "password": { "type": "password", "identifiers": [ - "foo-dupe@ory.sh", - "foobar@ory.sh" + "foobar@ory.sh", + "foo-dupe@ory.sh" ], "config": { "hashed_password": "$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng" diff --git a/persistence/sql/migratest/migration_test.go b/persistence/sql/migratest/migration_test.go index 21481c268289..36126f147935 100644 --- a/persistence/sql/migratest/migration_test.go +++ b/persistence/sql/migratest/migration_test.go @@ -66,7 +66,8 @@ func CompareWithFixture(t *testing.T, actual interface{}, prefix string, id stri s := snapshotFor("fixtures", prefix) actualJSON, err := json.MarshalIndent(actual, "", " ") require.NoError(t, err) - assert.NoError(t, s.SnapshotWithName(id, actualJSON)) + err = s.SnapshotWithName(id, actualJSON) + assert.NoErrorf(t, err, "actual = %s", string(actualJSON)) } func TestMigrations_SQLite(t *testing.T) { diff --git a/persistence/sql/persister_test.go b/persistence/sql/persister_test.go index f514b22fb4ff..3a4374f82882 100644 --- a/persistence/sql/persister_test.go +++ b/persistence/sql/persister_test.go @@ -58,12 +58,8 @@ func init() { } func TestMain(m *testing.M) { - atexit := dockertest.NewOnExit() - atexit.Add(func() { - // _ = os.Remove(strings.TrimPrefix(sqlite, "sqlite://")) - dockertest.KillAllTestDatabases() - }) - atexit.Exit(m.Run()) + m.Run() + dockertest.KillAllTestDatabases() } func pl(t *testing.T) func(lvl logging.Level, s string, args ...interface{}) { diff --git a/selfservice/hook/web_hook_integration_test.go b/selfservice/hook/web_hook_integration_test.go index 1525f1410266..94569e80b6da 100644 --- a/selfservice/hook/web_hook_integration_test.go +++ b/selfservice/hook/web_hook_integration_test.go @@ -682,9 +682,8 @@ func TestWebHooks(t *testing.T) { Value: "some@example.org", Via: "email", }}, - MetadataPublic: []byte(`{"public":"data"}`), - MetadataAdmin: []byte(`{"admin":"data"}`), - InternalCredentials: identity.CredentialsCollection{{Type: "password", Identifiers: []string{"test"}, Config: []byte(`{}`)}}, + MetadataPublic: []byte(`{"public":"data"}`), + MetadataAdmin: []byte(`{"admin":"data"}`), } t.Run("case=body is empty", func(t *testing.T) { diff --git a/selfservice/strategy/lookup/strategy_test.go b/selfservice/strategy/lookup/strategy_test.go index 9d189e68ed14..f9674ab2805e 100644 --- a/selfservice/strategy/lookup/strategy_test.go +++ b/selfservice/strategy/lookup/strategy_test.go @@ -28,25 +28,25 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { t.Run("multi factor", func(t *testing.T) { for k, tc := range []struct { - in identity.CredentialsCollection + in map[identity.CredentialsType]identity.Credentials expected int }{ { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte{}, }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{"recovery_codes": []}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"recovery_codes": [{}]}`), @@ -54,24 +54,19 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expected: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{}, {}}, + in: nil, expected: 0, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - cc := map[identity.CredentialsType]identity.Credentials{} - for _, c := range tc.in { - cc[c.Type] = c - } - - actual, err := strategy.CountActiveMultiFactorCredentials(cc) + actual, err := strategy.CountActiveMultiFactorCredentials(tc.in) require.NoError(t, err) assert.Equal(t, tc.expected, actual) }) diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 1505cfe9d2ac..8fda01141e45 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -32,8 +32,6 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" - "github.com/ory/x/sqlxx" - "github.com/ory/x/urlx" "github.com/ory/kratos/driver/config" @@ -671,17 +669,17 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { } for k, tc := range []struct { - in identity.CredentialsCollection + in map[identity.CredentialsType]identity.Credentials expected int }{ { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), - Config: sqlxx.JSONRawMessage{}, + Config: []byte{}, }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ {Subject: "foo", Provider: "bar"}, @@ -689,7 +687,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{""}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ @@ -698,7 +696,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"bar:"}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ @@ -707,7 +705,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{":foo"}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ @@ -716,7 +714,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"not-bar:foo"}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ @@ -725,7 +723,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"bar:not-foo"}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ @@ -734,7 +732,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"bar:foo"}, Config: toJson(identity.CredentialsOIDC{Providers: []identity.CredentialsOIDCProvider{ diff --git a/selfservice/strategy/password/strategy_test.go b/selfservice/strategy/password/strategy_test.go index c1b17ea15319..e27a0950ca28 100644 --- a/selfservice/strategy/password/strategy_test.go +++ b/selfservice/strategy/password/strategy_test.go @@ -29,25 +29,25 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { require.NoError(t, err) for k, tc := range []struct { - in identity.CredentialsCollection + in map[identity.CredentialsType]identity.Credentials expected int }{ { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte{}, }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{""}, Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), @@ -55,7 +55,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"hashed_password": "` + string(h1) + `"}`), @@ -63,7 +63,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expected: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"hashed_password": "` + string(h2) + `"}`), @@ -71,36 +71,31 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expected: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{"hashed_password": "asdf"}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{}, {}}, + in: nil, expected: 0, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { - cc := map[identity.CredentialsType]identity.Credentials{} - for _, c := range tc.in { - cc[c.Type] = c - } - - actual, err := strategy.CountActiveFirstFactorCredentials(cc) - require.NoError(t, err) + actual, err := strategy.CountActiveFirstFactorCredentials(tc.in) + assert.NoError(t, err) assert.Equal(t, tc.expected, actual) - actual, err = strategy.CountActiveMultiFactorCredentials(cc) - require.NoError(t, err) + actual, err = strategy.CountActiveMultiFactorCredentials(tc.in) + assert.NoError(t, err) assert.Equal(t, 0, actual) }) } diff --git a/selfservice/strategy/totp/strategy_test.go b/selfservice/strategy/totp/strategy_test.go index 60a75e0f6ff7..17c507c9d144 100644 --- a/selfservice/strategy/totp/strategy_test.go +++ b/selfservice/strategy/totp/strategy_test.go @@ -32,25 +32,25 @@ func TestCountActiveCredentials(t *testing.T) { t.Run("multi factor", func(t *testing.T) { for k, tc := range []struct { - in identity.CredentialsCollection + in map[identity.CredentialsType]identity.Credentials expected int }{ { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte{}, }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{"totp_url": ""}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"totp_url": "` + key.URL() + `"}`), @@ -58,14 +58,14 @@ func TestCountActiveCredentials(t *testing.T) { expected: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{}`), }}, expected: 0, }, { - in: identity.CredentialsCollection{{}, {}}, + in: nil, expected: 0, }, } { diff --git a/selfservice/strategy/webauthn/strategy_test.go b/selfservice/strategy/webauthn/strategy_test.go index 6926b9a6e76e..699a36ab6cbb 100644 --- a/selfservice/strategy/webauthn/strategy_test.go +++ b/selfservice/strategy/webauthn/strategy_test.go @@ -40,24 +40,24 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { strategy := webauthn.NewStrategy(reg) for k, tc := range []struct { - in identity.CredentialsCollection + in map[identity.CredentialsType]identity.Credentials expectedFirst int expectedMulti int }{ { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte{}, }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{"credentials": []}`), }}, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"credentials": [{}]}`), @@ -65,7 +65,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expectedMulti: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"credentials": [{"is_passwordless": true}]}`), @@ -73,7 +73,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expectedFirst: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"credentials": [{"is_passwordless": true}, {"is_passwordless": true}]}`), @@ -81,7 +81,7 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expectedFirst: 2, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Identifiers: []string{"foo"}, Config: []byte(`{"credentials": [{"is_passwordless": true}, {"is_passwordless": false}]}`), @@ -90,13 +90,13 @@ func TestCountActiveFirstFactorCredentials(t *testing.T) { expectedMulti: 1, }, { - in: identity.CredentialsCollection{{ + in: map[identity.CredentialsType]identity.Credentials{strategy.ID(): { Type: strategy.ID(), Config: []byte(`{}`), }}, }, { - in: identity.CredentialsCollection{{}, {}}, + in: nil, }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { diff --git a/x/xsql/sql.go b/x/xsql/sql.go index ba558e84426d..176d3a494dbb 100644 --- a/x/xsql/sql.go +++ b/x/xsql/sql.go @@ -48,8 +48,8 @@ func CleanSQL(t *testing.T, c *pop.Connection) { new(errorx.ErrorContainer).TableName(ctx), - new(identity.CredentialIdentifierCollection).TableName(ctx), - new(identity.CredentialsCollection).TableName(ctx), + new(identity.CredentialIdentifier).TableName(ctx), + new(identity.Credentials).TableName(ctx), new(identity.VerifiableAddress).TableName(ctx), new(identity.RecoveryAddress).TableName(ctx), new(identity.Identity).TableName(ctx),