Skip to content

Commit b2fe0ff

Browse files
authored
feat(client): overwrite client context instead of setting new one (#20356)
1 parent 7ae23e2 commit b2fe0ff

File tree

2 files changed

+31
-8
lines changed

2 files changed

+31
-8
lines changed

client/cmd.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -359,14 +359,17 @@ func GetClientContextFromCmd(cmd *cobra.Command) Context {
359359
// SetCmdClientContext sets a command's Context value to the provided argument.
360360
// If the context has not been set, set the given context as the default.
361361
func SetCmdClientContext(cmd *cobra.Command, clientCtx Context) error {
362-
var cmdCtx context.Context
363-
364-
if cmd.Context() == nil {
362+
cmdCtx := cmd.Context()
363+
if cmdCtx == nil {
365364
cmdCtx = context.Background()
365+
}
366+
367+
v := cmd.Context().Value(ClientContextKey)
368+
if clientCtxPtr, ok := v.(*Context); ok {
369+
*clientCtxPtr = clientCtx
366370
} else {
367-
cmdCtx = cmd.Context()
371+
cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx))
368372
}
369373

370-
cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx))
371374
return nil
372375
}

client/cmd_test.go

+23-3
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,21 @@ func TestSetCmdClientContextHandler(t *testing.T) {
7979
name string
8080
expectedContext client.Context
8181
args []string
82+
ctx context.Context
8283
}{
8384
{
8485
"no flags set",
8586
initClientCtx,
8687
[]string{},
88+
context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}),
8789
},
8890
{
8991
"flags set",
9092
initClientCtx.WithChainID("new-chain-id"),
9193
[]string{
9294
fmt.Sprintf("--%s=new-chain-id", flags.FlagChainID),
9395
},
96+
context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}),
9497
},
9598
{
9699
"flags set with space",
@@ -99,20 +102,37 @@ func TestSetCmdClientContextHandler(t *testing.T) {
99102
fmt.Sprintf("--%s", flags.FlagHome),
100103
"/tmp/dir",
101104
},
105+
context.Background(),
106+
},
107+
{
108+
"no context provided",
109+
initClientCtx.WithHomeDir("/tmp/noctx"),
110+
[]string{
111+
fmt.Sprintf("--%s", flags.FlagHome),
112+
"/tmp/noctx",
113+
},
114+
nil,
115+
},
116+
{
117+
"with invalid client value in the context",
118+
initClientCtx.WithHomeDir("/tmp/invalid"),
119+
[]string{
120+
fmt.Sprintf("--%s", flags.FlagHome),
121+
"/tmp/invalid",
122+
},
123+
context.WithValue(context.Background(), client.ClientContextKey, "invalid"),
102124
},
103125
}
104126

105127
for _, tc := range testCases {
106128
tc := tc
107129

108130
t.Run(tc.name, func(t *testing.T) {
109-
ctx := context.WithValue(context.Background(), client.ClientContextKey, &client.Context{})
110-
111131
cmd := newCmd()
112132
_ = testutil.ApplyMockIODiscardOutErr(cmd)
113133
cmd.SetArgs(tc.args)
114134

115-
require.NoError(t, cmd.ExecuteContext(ctx))
135+
require.NoError(t, cmd.ExecuteContext(tc.ctx))
116136

117137
clientCtx := client.GetClientContextFromCmd(cmd)
118138
require.Equal(t, tc.expectedContext, clientCtx)

0 commit comments

Comments
 (0)