Skip to content

Commit

Permalink
fix: only return one result set for credentials_identifier
Browse files Browse the repository at this point in the history
Closes #3105
  • Loading branch information
aeneasr committed Feb 15, 2023
1 parent 3d07161 commit 2659765
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
13 changes: 13 additions & 0 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,11 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
expectedIdentities = append(expectedIdentities, expected)
}

create := identity.NewIdentity("")
create.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{Type: identity.CredentialsTypePassword, Identifiers: []string{"[email protected]"}, Config: sqlxx.JSONRawMessage(`{}`)})
create.SetCredentials(identity.CredentialsTypeWebAuthn, identity.Credentials{Type: identity.CredentialsTypeWebAuthn, Identifiers: []string{"[email protected]"}, Config: sqlxx.JSONRawMessage(`{}`)})
require.NoError(t, p.CreateIdentity(ctx, create))

actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
Expand: identity.ExpandEverything,
})
Expand Down Expand Up @@ -643,6 +648,14 @@ func TestPool(ctx context.Context, conf *config.Config, p interface {
assert.Len(t, actual, 0)
})

t.Run("one result set even if multiple matches", func(t *testing.T) {
actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
CredentialsIdentifier: "[email protected]",
})
require.NoError(t, err)
assert.Len(t, actual, 1)
})

t.Run("non existing identifier", func(t *testing.T) {
actual, err := p.ListIdentities(ctx, identity.ListIdentityParameters{
CredentialsIdentifier: "[email protected]",
Expand Down
8 changes: 5 additions & 3 deletions persistence/sql/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,7 @@ func (p *Persister) ListIdentities(ctx context.Context, params identity.ListIden

con := p.GetConnection(ctx)
nid := p.NetworkID(ctx)
query := con.Where("identities.nid = ?", nid).Paginate(params.Page, params.PerPage).
Order("identities.id DESC")
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")

if len(params.Expand) > 0 {
query = query.EagerPreload(params.Expand.ToEager()...)
Expand All @@ -408,7 +407,10 @@ func (p *Persister) ListIdentities(ctx context.Context, params identity.ListIden
InnerJoin("identity_credential_types ict", "ict.id = ic.identity_credential_type_id").
InnerJoin("identity_credential_identifiers ici", "ici.identity_credential_id = ic.id").
Where("(ic.nid = ? AND ici.nid = ? AND ici.identifier = ?)", nid, nid, match).
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword)
Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword).
Limit(1)
} else {
query = query.Paginate(params.Page, params.PerPage)
}

/* #nosec G201 TableName is static */
Expand Down

0 comments on commit 2659765

Please sign in to comment.