Skip to content

Commit

Permalink
feat: refactor credentials fetching
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Mar 22, 2023
1 parent f905408 commit ed9355e
Show file tree
Hide file tree
Showing 18 changed files with 162 additions and 178 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ require (
github.com/go-swagger/go-swagger v0.30.3
github.com/gobuffalo/fizz v1.14.4
github.com/gobuffalo/httptest v1.5.2
github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7
github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0
github.com/gofrs/uuid v4.3.1+incompatible
github.com/golang-jwt/jwt/v4 v4.1.0
github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,8 @@ github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/V
github.com/gobuffalo/plush/v4 v4.1.16/go.mod h1:6t7swVsarJ8qSLw1qyAH/KbrcSTwdun2ASEQkOznakg=
github.com/gobuffalo/plush/v4 v4.1.18 h1:bnPjdMTEUQHqj9TNX2Ck3mxEXYZa+0nrFMNM07kpX9g=
github.com/gobuffalo/plush/v4 v4.1.18/go.mod h1:xi2tJIhFI4UdzIL8sxZtzGYOd2xbBpcFbLZlIPGGZhU=
github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7 h1:lwf/5cRw46IrLrhZnCg8J9NKgskkwMPuVvEOc2Wy72I=
github.com/gobuffalo/pop/v6 v6.1.2-0.20230124165254-ec9229dbf7d7/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI=
github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0 h1:+LF3Enal3HZ+rFmaLZfBRNHKqtnoA0d8jk0Iio8InZM=
github.com/gobuffalo/pop/v6 v6.1.2-0.20230318123913-c85387acc9a0/go.mod h1:1n7jAmI1i7fxuXPZjZb0VBPQDbksRtCoFnrDV5IsvaI=
github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
github.com/gobuffalo/tags/v3 v3.1.4 h1:X/ydLLPhgXV4h04Hp2xlbI2oc5MDaa7eub6zw8oHjsM=
github.com/gobuffalo/tags/v3 v3.1.4/go.mod h1:ArRNo3ErlHO8BtdA0REaZxijuWnWzF6PUXngmMXd2I0=
Expand Down
42 changes: 3 additions & 39 deletions identity/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ import (
"reflect"
"time"

"github.com/gobuffalo/pop/v6"

"github.com/ory/kratos/ui/node"

"github.com/gofrs/uuid"

"github.com/ory/kratos/ui/node"
"github.com/ory/x/sqlxx"
)

Expand Down Expand Up @@ -87,7 +84,8 @@ type Credentials struct {
ID uuid.UUID `json:"-" db:"id"`

// Type discriminates between different types of credentials.
Type CredentialsType `json:"type" db:"-"`
Type CredentialsType `json:"type" db:"-"`
IdentityCredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"`

// Identifiers represents a list of unique identifiers this credential type matches.
Identifiers []string `json:"identifiers" db:"-"`
Expand All @@ -107,26 +105,6 @@ type Credentials struct {
// UpdatedAt is a helper struct field for gobuffalo.pop.
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
NID uuid.UUID `json:"-" faker:"-" db:"nid"`

IdentityCredentialTypeID uuid.UUID `json:"-" db:"identity_credential_type_id"`
IdentityCredentialType CredentialsTypeTable `json:"-" faker:"-" belongs_to:"identity_credential_types"`
CredentialIdentifiers CredentialIdentifierCollection `json:"-" faker:"-" has_many:"identity_credential_identifiers" fk_id:"identity_credential_id" order_by:"id asc"`
}

func (c *Credentials) AfterEagerFind(tx *pop.Connection) error {
return c.setCredentials()
}

func (c *Credentials) setCredentials() error {
c.Type = c.IdentityCredentialType.Name
c.Identifiers = make([]string, 0, len(c.CredentialIdentifiers))
for _, id := range c.CredentialIdentifiers {
if c.NID != id.NID {
continue
}
c.Identifiers = append(c.Identifiers, id.Identifier)
}
return nil
}

func (c Credentials) TableName(ctx context.Context) string {
Expand Down Expand Up @@ -155,12 +133,6 @@ type (
Name CredentialsType `json:"-" db:"name"`
}

// swagger:ignore
CredentialsCollection []Credentials

// swagger:ignore
CredentialIdentifierCollection []CredentialIdentifier

// swagger:ignore
ActiveCredentialsCounter interface {
ID() CredentialsType
Expand All @@ -178,14 +150,6 @@ func (c CredentialsTypeTable) TableName(ctx context.Context) string {
return "identity_credential_types"
}

func (c CredentialsCollection) TableName(ctx context.Context) string {
return "identity_credentials"
}

func (c CredentialIdentifierCollection) TableName(ctx context.Context) string {
return "identity_credential_identifiers"
}

func (c CredentialIdentifier) TableName(ctx context.Context) string {
return "identity_credential_identifiers"
}
Expand Down
12 changes: 3 additions & 9 deletions identity/expandables.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ type Expandable = sqlxx.Expandable
type Expandables = sqlxx.Expandables

const (
ExpandFieldVerifiableAddresses Expandable = "VerifiableAddresses"
ExpandFieldRecoveryAddresses Expandable = "RecoveryAddresses"
ExpandFieldCredentials Expandable = "InternalCredentials"
ExpandFieldCredentialType Expandable = "InternalCredentials.IdentityCredentialType"
ExpandFieldCredentialIdentifiers Expandable = "InternalCredentials.CredentialIdentifiers"
ExpandFieldVerifiableAddresses Expandable = "VerifiableAddresses"
ExpandFieldRecoveryAddresses Expandable = "RecoveryAddresses"
ExpandFieldCredentials Expandable = "Credentials"
)

// ExpandNothing expands nothing
Expand All @@ -31,15 +29,11 @@ var ExpandDefault = Expandables{
// ExpandCredentials expands the identity's credentials.
var ExpandCredentials = Expandables{
ExpandFieldCredentials,
ExpandFieldCredentialType,
ExpandFieldCredentialIdentifiers,
}

// ExpandEverything expands all the fields of an identity.
var ExpandEverything = Expandables{
ExpandFieldVerifiableAddresses,
ExpandFieldRecoveryAddresses,
ExpandFieldCredentials,
ExpandFieldCredentialType,
ExpandFieldCredentialIdentifiers,
}
35 changes: 0 additions & 35 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ import (

"github.com/samber/lo"

"github.com/gobuffalo/pop/v6"

"github.com/tidwall/sjson"

"github.com/tidwall/gjson"
Expand Down Expand Up @@ -121,9 +119,6 @@ type Identity struct {
// Store metadata about the user which is only accessible through admin APIs such as `GET /admin/identities/<id>`.
MetadataAdmin sqlxx.NullJSONRawMessage `json:"metadata_admin,omitempty" faker:"-" db:"metadata_admin"`

// InternalCredentials is an internal representation of the credentials.
InternalCredentials CredentialsCollection `json:"-" faker:"-" has_many:"identity_credentials" fk_id:"identity_id" order_by:"id asc"`

// CreatedAt is a helper struct field for gobuffalo.pop.
CreatedAt time.Time `json:"created_at" db:"created_at"`

Expand All @@ -132,36 +127,6 @@ type Identity struct {
NID uuid.UUID `json:"-" faker:"-" db:"nid"`
}

func (i *Identity) AfterEagerFind(tx *pop.Connection) error {
if err := i.setCredentials(tx); err != nil {
return err
}

if err := i.Validate(); err != nil {
return err
}

return UpgradeCredentials(i)
}

func (i *Identity) setCredentials(tx *pop.Connection) error {
creds := i.InternalCredentials
i.Credentials = make(map[CredentialsType]Credentials, len(creds))
for k := range creds {
cred := &creds[k]
if cred.NID != i.NID {
continue
}
if err := cred.AfterEagerFind(tx); err != nil {
return err

}
i.Credentials[cred.Type] = *cred
}

return nil
}

// Traits represent an identity's traits. The identity is able to create, modify, and delete traits
// in a self-service manner. The input will always be validated against the JSON Schema defined
// in `schema_url`.
Expand Down
4 changes: 0 additions & 4 deletions identity/test/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
assert.Empty(t, actual.RecoveryAddresses)
assert.Empty(t, actual.VerifiableAddresses)
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
})
})

Expand All @@ -142,7 +141,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
t.Run("expand=recovery address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldRecoveryAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.VerifiableAddresses)

require.Len(t, actual.RecoveryAddresses, 1)
Expand All @@ -153,7 +151,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
t.Run("expand=verification address", func(t *testing.T) {
runner(t, sqlxx.Expandables{identity.ExpandFieldVerifiableAddresses}, func(t *testing.T, actual *identity.Identity) {
assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)
assert.Empty(t, actual.RecoveryAddresses)

require.Len(t, actual.VerifiableAddresses, 1)
Expand All @@ -165,7 +162,6 @@ func TestPool(ctx context.Context, conf *config.Config, p persistence.Persister,
runner(t, identity.ExpandDefault, func(t *testing.T, actual *identity.Identity) {

assert.Empty(t, actual.Credentials)
assert.Empty(t, actual.InternalCredentials)

require.Len(t, actual.RecoveryAddresses, 1)
assertx.EqualAsJSONExcept(t, expected.RecoveryAddresses, actual.RecoveryAddresses, []string{"0.updated_at", "0.created_at"})
Expand Down
111 changes: 96 additions & 15 deletions persistence/sql/identity/persister_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
var (
con = p.GetConnection(ctx)
nid = p.NetworkID(ctx)
credentials []identity.Credentials
credentials map[identity.CredentialsType](identity.Credentials)
verifiableAddresses []identity.VerifiableAddress
recoveryAddresses []identity.RecoveryAddress
)
Expand Down Expand Up @@ -455,18 +455,17 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *
}

if expand.Has(identity.ExpandFieldCredentials) {
eg.Go(func() error {
eg.Go(func() (err error) {
// We use WithContext to get a copy of the connection struct, which solves the race detector
// from complaining incorrectly.
//
// https://github.com/gobuffalo/pop/issues/723
if err := con.WithContext(ctx).
EagerPreload("IdentityCredentialType", "CredentialIdentifiers").
Where("identity_id = ? AND nid = ?", i.ID, nid).
All(&credentials); err != nil {
return sqlcon.HandleError(err)
}
return nil
con := con.WithContext(ctx)
credentials, err = QueryForCredentials(con, Where{
"(identity_credentials.identity_id = ? AND identity_credentials.nid = ?)",
[]interface{}{i.ID, nid},
})
return
})
}

Expand All @@ -476,15 +475,93 @@ func (p *IdentityPersister) HydrateIdentityAssociations(ctx context.Context, i *

i.VerifiableAddresses = verifiableAddresses
i.RecoveryAddresses = recoveryAddresses
i.InternalCredentials = credentials
i.Credentials = credentials

if err := i.Validate(); err != nil {
return err
}

if err := i.AfterEagerFind(con); err != nil {
if err := identity.UpgradeCredentials(i); err != nil {
return err
}

return p.InjectTraitsSchemaURL(ctx, i)
}

type QueryCredentials struct {
ID uuid.UUID `db:"cred_id"`
IdentityID uuid.UUID `db:"identity_id"`
NID uuid.UUID `db:"nid"`
Type identity.CredentialsType `db:"cred_type"`
TypeID uuid.UUID `db:"cred_type_id"`
Identifier string `db:"cred_identifier"`
Config sqlxx.JSONRawMessage `db:"cred_config"`
Version int `db:"cred_version"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (QueryCredentials) TableName() string {
return "identity_credentials"
}

type Where struct {
Condition string
Args []interface{}
}

func QueryForCredentials(con *pop.Connection, where ...Where) (map[identity.CredentialsType]identity.Credentials, error) {
q := con.Select(
"identity_credentials.id cred_id",
"identity_credentials.identity_id identity_id",
"identity_credentials.nid nid",
"ict.name cred_type",
"ict.id cred_type_id",
"COALESCE(ici.identifier, '') cred_identifier",
"identity_credentials.config cred_config",
"identity_credentials.version cred_version",
"identity_credentials.created_at created_at",
"identity_credentials.updated_at updated_at",
).InnerJoin(
"identity_credential_types ict",
"(identity_credentials.identity_credential_type_id = ict.id)",
).LeftJoin(
"identity_credential_identifiers ici",
"(ici.identity_credential_id = identity_credentials.id AND ici.nid = identity_credentials.nid)",
)
for _, w := range where {
q = q.Where(w.Condition, w.Args...)
}
var creds []QueryCredentials
if err := q.All(&creds); err != nil {
return nil, sqlcon.HandleError(err)
}
credentials := map[identity.CredentialsType]identity.Credentials{}
for _, cred := range creds {
identifiers := credentials[cred.Type].Identifiers
if cred.Identifier != "" {
identifiers = append(identifiers, cred.Identifier)
}
if identifiers == nil {
identifiers = make([]string, 0)
}
c := identity.Credentials{
ID: cred.ID,
IdentityID: cred.IdentityID,
NID: cred.NID,
Type: cred.Type,
IdentityCredentialTypeID: cred.TypeID,
Identifiers: identifiers,
Config: cred.Config,
Version: cred.Version,
CreatedAt: cred.CreatedAt,
UpdatedAt: cred.UpdatedAt,
}
credentials[cred.Type] = c
}
return credentials, nil
}

func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.ListIdentityParameters) (res []identity.Identity, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.ListIdentities")
defer otelx.End(span, &err)
Expand All @@ -503,10 +580,6 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
nid := p.NetworkID(ctx)
query := con.Where("identities.nid = ?", nid).Order("identities.id DESC")

if len(params.Expand) > 0 {
query = query.EagerPreload(params.Expand.ToEager()...)
}

if match := params.CredentialsIdentifier; len(match) > 0 {
// When filtering by credentials identifier, we most likely are looking for a username or email. It is therefore
// important to normalize the identifier before querying the database.
Expand All @@ -526,6 +599,14 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity.
return nil, err
}

if len(params.Expand) > 0 {
for i := range is {
if err := p.HydrateIdentityAssociations(ctx, &is[i], params.Expand); err != nil {
return nil, err
}
}
}

schemaCache := map[string]string{}
for k := range is {
i := &is[k]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"password": {
"type": "password",
"identifiers": [
"foo-dupe@ory.sh",
"[email protected]"
"[email protected]",
"foo-dupe@ory.sh"
],
"config": {
"hashed_password": "$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"password": {
"type": "password",
"identifiers": [
"foo-dupe@ory.sh",
"foobar@ory.sh"
"foobar@ory.sh",
"foo-dupe@ory.sh"
],
"config": {
"hashed_password": "$argon2id$v=19$m=131072,t=2,p=1$lQFPaKxXqPL56/mU7vRi4w$6aldHyBnURt8sP8+xu41Ng"
Expand Down
Loading

0 comments on commit ed9355e

Please sign in to comment.