Skip to content

Commit 83747c5

Browse files
mergify[bot]islishudejulienrbrt
authored
feat(client): overwrite client context instead of setting new one (backport #20356) (#20383)
Co-authored-by: Shude Li <[email protected]> Co-authored-by: Julien Robert <[email protected]>
1 parent 08fdfec commit 83747c5

File tree

3 files changed

+32
-8
lines changed

3 files changed

+32
-8
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
4242

4343
* (debug) [#20328](https://github.com/cosmos/cosmos-sdk/pull/20328) Add consensus address for debug cmd.
4444
* (runtime) [#20264](https://github.com/cosmos/cosmos-sdk/pull/20264) Expose grpc query router via depinject.
45+
* (client) [#20356](https://github.com/cosmos/cosmos-sdk/pull/20356) Overwrite client context when available in `SetCmdClientContext`.
4546

4647
### Bug Fixes
4748

client/cmd.go

+8-5
Original file line numberDiff line numberDiff line change
@@ -359,14 +359,17 @@ func GetClientContextFromCmd(cmd *cobra.Command) Context {
359359
// SetCmdClientContext sets a command's Context value to the provided argument.
360360
// If the context has not been set, set the given context as the default.
361361
func SetCmdClientContext(cmd *cobra.Command, clientCtx Context) error {
362-
var cmdCtx context.Context
363-
364-
if cmd.Context() == nil {
362+
cmdCtx := cmd.Context()
363+
if cmdCtx == nil {
365364
cmdCtx = context.Background()
365+
}
366+
367+
v := cmd.Context().Value(ClientContextKey)
368+
if clientCtxPtr, ok := v.(*Context); ok {
369+
*clientCtxPtr = clientCtx
366370
} else {
367-
cmdCtx = cmd.Context()
371+
cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx))
368372
}
369373

370-
cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx))
371374
return nil
372375
}

client/cmd_test.go

+23-3
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,21 @@ func TestSetCmdClientContextHandler(t *testing.T) {
7979
name string
8080
expectedContext client.Context
8181
args []string
82+
ctx context.Context
8283
}{
8384
{
8485
"no flags set",
8586
initClientCtx,
8687
[]string{},
88+
context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}),
8789
},
8890
{
8991
"flags set",
9092
initClientCtx.WithChainID("new-chain-id"),
9193
[]string{
9294
fmt.Sprintf("--%s=new-chain-id", flags.FlagChainID),
9395
},
96+
context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}),
9497
},
9598
{
9699
"flags set with space",
@@ -99,20 +102,37 @@ func TestSetCmdClientContextHandler(t *testing.T) {
99102
fmt.Sprintf("--%s", flags.FlagHome),
100103
"/tmp/dir",
101104
},
105+
context.Background(),
106+
},
107+
{
108+
"no context provided",
109+
initClientCtx.WithHomeDir("/tmp/noctx"),
110+
[]string{
111+
fmt.Sprintf("--%s", flags.FlagHome),
112+
"/tmp/noctx",
113+
},
114+
nil,
115+
},
116+
{
117+
"with invalid client value in the context",
118+
initClientCtx.WithHomeDir("/tmp/invalid"),
119+
[]string{
120+
fmt.Sprintf("--%s", flags.FlagHome),
121+
"/tmp/invalid",
122+
},
123+
context.WithValue(context.Background(), client.ClientContextKey, "invalid"),
102124
},
103125
}
104126

105127
for _, tc := range testCases {
106128
tc := tc
107129

108130
t.Run(tc.name, func(t *testing.T) {
109-
ctx := context.WithValue(context.Background(), client.ClientContextKey, &client.Context{})
110-
111131
cmd := newCmd()
112132
_ = testutil.ApplyMockIODiscardOutErr(cmd)
113133
cmd.SetArgs(tc.args)
114134

115-
require.NoError(t, cmd.ExecuteContext(ctx))
135+
require.NoError(t, cmd.ExecuteContext(tc.ctx))
116136

117137
clientCtx := client.GetClientContextFromCmd(cmd)
118138
require.Equal(t, tc.expectedContext, clientCtx)

0 commit comments

Comments
 (0)