diff --git a/context/context.go b/context/context.go index cf66309c4a..db1c95fab1 100644 --- a/context/context.go +++ b/context/context.go @@ -3,29 +3,31 @@ // license that can be found in the LICENSE file. // Package context defines the Context type, which carries deadlines, -// cancelation signals, and other request-scoped values across API boundaries +// cancellation signals, and other request-scoped values across API boundaries // and between processes. // As of Go 1.7 this package is available in the standard library under the -// name context. https://golang.org/pkg/context. +// name [context], and migrating to it can be done automatically with [go fix]. // -// Incoming requests to a server should create a Context, and outgoing calls to -// servers should accept a Context. The chain of function calls between must -// propagate the Context, optionally replacing it with a modified copy created -// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// Incoming requests to a server should create a [Context], and outgoing +// calls to servers should accept a Context. The chain of function +// calls between them must propagate the Context, optionally replacing +// it with a derived Context created using [WithCancel], [WithDeadline], +// [WithTimeout], or [WithValue]. // // Programs that use Contexts should follow these rules to keep interfaces // consistent across packages and enable static analysis tools to check context // propagation: // // Do not store Contexts inside a struct type; instead, pass a Context -// explicitly to each function that needs it. The Context should be the first +// explicitly to each function that needs it. This is discussed further in +// https://go.dev/blog/context-and-structs. The Context should be the first // parameter, typically named ctx: // // func DoSomething(ctx context.Context, arg Arg) error { // // ... use ctx ... // } // -// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO] // if you are unsure about which Context to use. // // Use context Values only for request-scoped data that transits processes and @@ -34,9 +36,30 @@ // The same Context may be passed to functions running in different goroutines; // Contexts are safe for simultaneous use by multiple goroutines. // -// See http://blog.golang.org/context for example code for a server that uses +// See https://go.dev/blog/context for example code for a server that uses // Contexts. -package context // import "golang.org/x/net/context" +// +// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +// A Context carries a deadline, a cancellation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context = context.Context + +// Canceled is the error returned by [Context.Err] when the context is canceled +// for some reason other than its deadline passing. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by [Context.Err] when the context is canceled +// due to its deadline passing. +var DeadlineExceeded = context.DeadlineExceeded // Background returns a non-nil, empty Context. It is never canceled, has no // values, and has no deadline. It is typically used by the main function, @@ -49,8 +72,73 @@ func Background() Context { // TODO returns a non-nil, empty Context. Code should use context.TODO when // it's unclear which Context to use or it is not yet available (because the // surrounding function has not yet been extended to accept a Context -// parameter). TODO is recognized by static analysis tools that determine -// whether Contexts are propagated correctly in a program. +// parameter). func TODO() Context { return todo } + +var ( + background = context.Background() + todo = context.TODO() +) + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// A CancelFunc may be called by multiple goroutines simultaneously. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc = context.CancelFunc + +// WithCancel returns a derived context that points to the parent context +// but has a new Done channel. The returned context's Done channel is closed +// when the returned cancel function is called or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + return context.WithCancel(parent) +} + +// WithDeadline returns a derived context that points to the parent context +// but has the deadline adjusted to be no later than d. If the parent's +// deadline is already earlier than d, WithDeadline(parent, d) is semantically +// equivalent to parent. The returned [Context.Done] channel is closed when +// the deadline expires, when the returned cancel function is called, +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete. +func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { + return context.WithDeadline(parent, d) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return context.WithTimeout(parent, timeout) +} + +// WithValue returns a derived context that points to the parent Context. +// In the derived context, the value associated with key is val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The provided key must be comparable and should not be of type +// string or any other built-in type to avoid collisions between +// packages using context. Users of WithValue should define their own +// types for keys. To avoid allocating when assigning to an +// interface{}, context keys often have concrete type +// struct{}. Alternatively, exported context key variables' static +// type should be a pointer or interface. +func WithValue(parent Context, key, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/context/context_test.go b/context/context_test.go deleted file mode 100644 index 2cb54edb89..0000000000 --- a/context/context_test.go +++ /dev/null @@ -1,583 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.7 - -package context - -import ( - "fmt" - "math/rand" - "runtime" - "strings" - "sync" - "testing" - "time" -) - -// otherContext is a Context that's not one of the types defined in context.go. -// This lets us test code paths that differ based on the underlying type of the -// Context. -type otherContext struct { - Context -} - -func TestBackground(t *testing.T) { - c := Background() - if c == nil { - t.Fatalf("Background returned nil") - } - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - if got, want := fmt.Sprint(c), "context.Background"; got != want { - t.Errorf("Background().String() = %q want %q", got, want) - } -} - -func TestTODO(t *testing.T) { - c := TODO() - if c == nil { - t.Fatalf("TODO returned nil") - } - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - if got, want := fmt.Sprint(c), "context.TODO"; got != want { - t.Errorf("TODO().String() = %q want %q", got, want) - } -} - -func TestWithCancel(t *testing.T) { - c1, cancel := WithCancel(Background()) - - if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { - t.Errorf("c1.String() = %q want %q", got, want) - } - - o := otherContext{c1} - c2, _ := WithCancel(o) - contexts := []Context{c1, o, c2} - - for i, c := range contexts { - if d := c.Done(); d == nil { - t.Errorf("c[%d].Done() == %v want non-nil", i, d) - } - if e := c.Err(); e != nil { - t.Errorf("c[%d].Err() == %v want nil", i, e) - } - - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - } - - cancel() - time.Sleep(100 * time.Millisecond) // let cancelation propagate - - for i, c := range contexts { - select { - case <-c.Done(): - default: - t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) - } - if e := c.Err(); e != Canceled { - t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) - } - } -} - -func TestParentFinishesChild(t *testing.T) { - // Context tree: - // parent -> cancelChild - // parent -> valueChild -> timerChild - parent, cancel := WithCancel(Background()) - cancelChild, stop := WithCancel(parent) - defer stop() - valueChild := WithValue(parent, "key", "value") - timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) - defer stop() - - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - case x := <-cancelChild.Done(): - t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) - case x := <-timerChild.Done(): - t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) - case x := <-valueChild.Done(): - t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) - default: - } - - // The parent's children should contain the two cancelable children. - pc := parent.(*cancelCtx) - cc := cancelChild.(*cancelCtx) - tc := timerChild.(*timerCtx) - pc.mu.Lock() - if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { - t.Errorf("bad linkage: pc.children = %v, want %v and %v", - pc.children, cc, tc) - } - pc.mu.Unlock() - - if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { - t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) - } - if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { - t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) - } - - cancel() - - pc.mu.Lock() - if len(pc.children) != 0 { - t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) - } - pc.mu.Unlock() - - // parent and children should all be finished. - check := func(ctx Context, name string) { - select { - case <-ctx.Done(): - default: - t.Errorf("<-%s.Done() blocked, but shouldn't have", name) - } - if e := ctx.Err(); e != Canceled { - t.Errorf("%s.Err() == %v want %v", name, e, Canceled) - } - } - check(parent, "parent") - check(cancelChild, "cancelChild") - check(valueChild, "valueChild") - check(timerChild, "timerChild") - - // WithCancel should return a canceled context on a canceled parent. - precanceledChild := WithValue(parent, "key", "value") - select { - case <-precanceledChild.Done(): - default: - t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") - } - if e := precanceledChild.Err(); e != Canceled { - t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) - } -} - -func TestChildFinishesFirst(t *testing.T) { - cancelable, stop := WithCancel(Background()) - defer stop() - for _, parent := range []Context{Background(), cancelable} { - child, cancel := WithCancel(parent) - - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - case x := <-child.Done(): - t.Errorf("<-child.Done() == %v want nothing (it should block)", x) - default: - } - - cc := child.(*cancelCtx) - pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() - if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { - t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) - } - - if pcok { - pc.mu.Lock() - if len(pc.children) != 1 || !pc.children[cc] { - t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) - } - pc.mu.Unlock() - } - - cancel() - - if pcok { - pc.mu.Lock() - if len(pc.children) != 0 { - t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) - } - pc.mu.Unlock() - } - - // child should be finished. - select { - case <-child.Done(): - default: - t.Errorf("<-child.Done() blocked, but shouldn't have") - } - if e := child.Err(); e != Canceled { - t.Errorf("child.Err() == %v want %v", e, Canceled) - } - - // parent should not be finished. - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - default: - } - if e := parent.Err(); e != nil { - t.Errorf("parent.Err() == %v want nil", e) - } - } -} - -func testDeadline(c Context, wait time.Duration, t *testing.T) { - select { - case <-time.After(wait): - t.Fatalf("context should have timed out") - case <-c.Done(): - } - if e := c.Err(); e != DeadlineExceeded { - t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) - } -} - -func TestDeadline(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithDeadline(Background(), time.Now().Add(1*timeUnit)) - if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { - t.Errorf("c.String() = %q want prefix %q", got, prefix) - } - testDeadline(c, 2*timeUnit, t) - - c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) - o := otherContext{c} - testDeadline(o, 2*timeUnit, t) - - c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) - o = otherContext{c} - c, _ = WithDeadline(o, time.Now().Add(3*timeUnit)) - testDeadline(c, 2*timeUnit, t) -} - -func TestTimeout(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithTimeout(Background(), 1*timeUnit) - if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { - t.Errorf("c.String() = %q want prefix %q", got, prefix) - } - testDeadline(c, 2*timeUnit, t) - - c, _ = WithTimeout(Background(), 1*timeUnit) - o := otherContext{c} - testDeadline(o, 2*timeUnit, t) - - c, _ = WithTimeout(Background(), 1*timeUnit) - o = otherContext{c} - c, _ = WithTimeout(o, 3*timeUnit) - testDeadline(c, 2*timeUnit, t) -} - -func TestCanceledTimeout(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithTimeout(Background(), 2*timeUnit) - o := otherContext{c} - c, cancel := WithTimeout(o, 4*timeUnit) - cancel() - time.Sleep(1 * timeUnit) // let cancelation propagate - select { - case <-c.Done(): - default: - t.Errorf("<-c.Done() blocked, but shouldn't have") - } - if e := c.Err(); e != Canceled { - t.Errorf("c.Err() == %v want %v", e, Canceled) - } -} - -type key1 int -type key2 int - -var k1 = key1(1) -var k2 = key2(1) // same int as k1, different type -var k3 = key2(3) // same type as k2, different int - -func TestValues(t *testing.T) { - check := func(c Context, nm, v1, v2, v3 string) { - if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { - t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) - } - if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { - t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) - } - if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { - t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) - } - } - - c0 := Background() - check(c0, "c0", "", "", "") - - c1 := WithValue(Background(), k1, "c1k1") - check(c1, "c1", "c1k1", "", "") - - if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { - t.Errorf("c.String() = %q want %q", got, want) - } - - c2 := WithValue(c1, k2, "c2k2") - check(c2, "c2", "c1k1", "c2k2", "") - - c3 := WithValue(c2, k3, "c3k3") - check(c3, "c2", "c1k1", "c2k2", "c3k3") - - c4 := WithValue(c3, k1, nil) - check(c4, "c4", "", "c2k2", "c3k3") - - o0 := otherContext{Background()} - check(o0, "o0", "", "", "") - - o1 := otherContext{WithValue(Background(), k1, "c1k1")} - check(o1, "o1", "c1k1", "", "") - - o2 := WithValue(o1, k2, "o2k2") - check(o2, "o2", "c1k1", "o2k2", "") - - o3 := otherContext{c4} - check(o3, "o3", "", "c2k2", "c3k3") - - o4 := WithValue(o3, k3, nil) - check(o4, "o4", "", "c2k2", "") -} - -func TestAllocs(t *testing.T) { - bg := Background() - for _, test := range []struct { - desc string - f func() - limit float64 - gccgoLimit float64 - }{ - { - desc: "Background()", - f: func() { Background() }, - limit: 0, - gccgoLimit: 0, - }, - { - desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), - f: func() { - c := WithValue(bg, k1, nil) - c.Value(k1) - }, - limit: 3, - gccgoLimit: 3, - }, - { - desc: "WithTimeout(bg, 15*time.Millisecond)", - f: func() { - c, _ := WithTimeout(bg, 15*time.Millisecond) - <-c.Done() - }, - limit: 8, - gccgoLimit: 16, - }, - { - desc: "WithCancel(bg)", - f: func() { - c, cancel := WithCancel(bg) - cancel() - <-c.Done() - }, - limit: 5, - gccgoLimit: 8, - }, - { - desc: "WithTimeout(bg, 100*time.Millisecond)", - f: func() { - c, cancel := WithTimeout(bg, 100*time.Millisecond) - cancel() - <-c.Done() - }, - limit: 8, - gccgoLimit: 25, - }, - } { - limit := test.limit - if runtime.Compiler == "gccgo" { - // gccgo does not yet do escape analysis. - // TODO(iant): Remove this when gccgo does do escape analysis. - limit = test.gccgoLimit - } - if n := testing.AllocsPerRun(100, test.f); n > limit { - t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) - } - } -} - -func TestSimultaneousCancels(t *testing.T) { - root, cancel := WithCancel(Background()) - m := map[Context]CancelFunc{root: cancel} - q := []Context{root} - // Create a tree of contexts. - for len(q) != 0 && len(m) < 100 { - parent := q[0] - q = q[1:] - for i := 0; i < 4; i++ { - ctx, cancel := WithCancel(parent) - m[ctx] = cancel - q = append(q, ctx) - } - } - // Start all the cancels in a random order. - var wg sync.WaitGroup - wg.Add(len(m)) - for _, cancel := range m { - go func(cancel CancelFunc) { - cancel() - wg.Done() - }(cancel) - } - // Wait on all the contexts in a random order. - for ctx := range m { - select { - case <-ctx.Done(): - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) - } - } - // Wait for all the cancel functions to return. - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) - } -} - -func TestInterlockedCancels(t *testing.T) { - parent, cancelParent := WithCancel(Background()) - child, cancelChild := WithCancel(parent) - go func() { - parent.Done() - cancelChild() - }() - cancelParent() - select { - case <-child.Done(): - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) - } -} - -func TestLayersCancel(t *testing.T) { - testLayers(t, time.Now().UnixNano(), false) -} - -func TestLayersTimeout(t *testing.T) { - testLayers(t, time.Now().UnixNano(), true) -} - -func testLayers(t *testing.T, seed int64, testTimeout bool) { - rand.Seed(seed) - errorf := func(format string, a ...interface{}) { - t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) - } - const ( - timeout = 200 * time.Millisecond - minLayers = 30 - ) - type value int - var ( - vals []*value - cancels []CancelFunc - numTimers int - ctx = Background() - ) - for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { - switch rand.Intn(3) { - case 0: - v := new(value) - ctx = WithValue(ctx, v, v) - vals = append(vals, v) - case 1: - var cancel CancelFunc - ctx, cancel = WithCancel(ctx) - cancels = append(cancels, cancel) - case 2: - var cancel CancelFunc - ctx, cancel = WithTimeout(ctx, timeout) - cancels = append(cancels, cancel) - numTimers++ - } - } - checkValues := func(when string) { - for _, key := range vals { - if val := ctx.Value(key).(*value); key != val { - errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) - } - } - } - select { - case <-ctx.Done(): - errorf("ctx should not be canceled yet") - default: - } - if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { - t.Errorf("ctx.String() = %q want prefix %q", s, prefix) - } - t.Log(ctx) - checkValues("before cancel") - if testTimeout { - select { - case <-ctx.Done(): - case <-time.After(timeout + 100*time.Millisecond): - errorf("ctx should have timed out") - } - checkValues("after timeout") - } else { - cancel := cancels[rand.Intn(len(cancels))] - cancel() - select { - case <-ctx.Done(): - default: - errorf("ctx should be canceled") - } - checkValues("after cancel") - } -} - -func TestCancelRemoves(t *testing.T) { - checkChildren := func(when string, ctx Context, want int) { - if got := len(ctx.(*cancelCtx).children); got != want { - t.Errorf("%s: context has %d children, want %d", when, got, want) - } - } - - ctx, _ := WithCancel(Background()) - checkChildren("after creation", ctx, 0) - _, cancel := WithCancel(ctx) - checkChildren("with WithCancel child ", ctx, 1) - cancel() - checkChildren("after cancelling WithCancel child", ctx, 0) - - ctx, _ = WithCancel(Background()) - checkChildren("after creation", ctx, 0) - _, cancel = WithTimeout(ctx, 60*time.Minute) - checkChildren("with WithTimeout child ", ctx, 1) - cancel() - checkChildren("after cancelling WithTimeout child", ctx, 0) -} diff --git a/context/ctxhttp/ctxhttp.go b/context/ctxhttp/ctxhttp.go index 37dc0cfdb5..e0df203cea 100644 --- a/context/ctxhttp/ctxhttp.go +++ b/context/ctxhttp/ctxhttp.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // Package ctxhttp provides helper functions for performing context-aware HTTP requests. -package ctxhttp // import "golang.org/x/net/context/ctxhttp" +package ctxhttp import ( "context" diff --git a/context/go17.go b/context/go17.go deleted file mode 100644 index 0c1b867937..0000000000 --- a/context/go17.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.7 - -package context - -import ( - "context" // standard library's context, as of Go 1.7 - "time" -) - -var ( - todo = context.TODO() - background = context.Background() -) - -// Canceled is the error returned by Context.Err when the context is canceled. -var Canceled = context.Canceled - -// DeadlineExceeded is the error returned by Context.Err when the context's -// deadline passes. -var DeadlineExceeded = context.DeadlineExceeded - -// WithCancel returns a copy of parent with a new Done channel. The returned -// context's Done channel is closed when the returned cancel function is called -// or when the parent context's Done channel is closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { - ctx, f := context.WithCancel(parent) - return ctx, f -} - -// WithDeadline returns a copy of the parent context with the deadline adjusted -// to be no later than d. If the parent's deadline is already earlier than d, -// WithDeadline(parent, d) is semantically equivalent to parent. The returned -// context's Done channel is closed when the deadline expires, when the returned -// cancel function is called, or when the parent context's Done channel is -// closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { - ctx, f := context.WithDeadline(parent, deadline) - return ctx, f -} - -// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete: -// -// func slowOperationWithTimeout(ctx context.Context) (Result, error) { -// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) -// defer cancel() // releases resources if slowOperation completes before timeout elapses -// return slowOperation(ctx) -// } -func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { - return WithDeadline(parent, time.Now().Add(timeout)) -} - -// WithValue returns a copy of parent in which the value associated with key is -// val. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { - return context.WithValue(parent, key, val) -} diff --git a/context/go19.go b/context/go19.go deleted file mode 100644 index e31e35a904..0000000000 --- a/context/go19.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.9 - -package context - -import "context" // standard library's context, as of Go 1.7 - -// A Context carries a deadline, a cancelation signal, and other values across -// API boundaries. -// -// Context's methods may be called by multiple goroutines simultaneously. -type Context = context.Context - -// A CancelFunc tells an operation to abandon its work. -// A CancelFunc does not wait for the work to stop. -// After the first call, subsequent calls to a CancelFunc do nothing. -type CancelFunc = context.CancelFunc diff --git a/context/pre_go17.go b/context/pre_go17.go deleted file mode 100644 index 065ff3dfa5..0000000000 --- a/context/pre_go17.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.7 - -package context - -import ( - "errors" - "fmt" - "sync" - "time" -) - -// An emptyCtx is never canceled, has no values, and has no deadline. It is not -// struct{}, since vars of this type must have distinct addresses. -type emptyCtx int - -func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { - return -} - -func (*emptyCtx) Done() <-chan struct{} { - return nil -} - -func (*emptyCtx) Err() error { - return nil -} - -func (*emptyCtx) Value(key interface{}) interface{} { - return nil -} - -func (e *emptyCtx) String() string { - switch e { - case background: - return "context.Background" - case todo: - return "context.TODO" - } - return "unknown empty Context" -} - -var ( - background = new(emptyCtx) - todo = new(emptyCtx) -) - -// Canceled is the error returned by Context.Err when the context is canceled. -var Canceled = errors.New("context canceled") - -// DeadlineExceeded is the error returned by Context.Err when the context's -// deadline passes. -var DeadlineExceeded = errors.New("context deadline exceeded") - -// WithCancel returns a copy of parent with a new Done channel. The returned -// context's Done channel is closed when the returned cancel function is called -// or when the parent context's Done channel is closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { - c := newCancelCtx(parent) - propagateCancel(parent, c) - return c, func() { c.cancel(true, Canceled) } -} - -// newCancelCtx returns an initialized cancelCtx. -func newCancelCtx(parent Context) *cancelCtx { - return &cancelCtx{ - Context: parent, - done: make(chan struct{}), - } -} - -// propagateCancel arranges for child to be canceled when parent is. -func propagateCancel(parent Context, child canceler) { - if parent.Done() == nil { - return // parent is never canceled - } - if p, ok := parentCancelCtx(parent); ok { - p.mu.Lock() - if p.err != nil { - // parent has already been canceled - child.cancel(false, p.err) - } else { - if p.children == nil { - p.children = make(map[canceler]bool) - } - p.children[child] = true - } - p.mu.Unlock() - } else { - go func() { - select { - case <-parent.Done(): - child.cancel(false, parent.Err()) - case <-child.Done(): - } - }() - } -} - -// parentCancelCtx follows a chain of parent references until it finds a -// *cancelCtx. This function understands how each of the concrete types in this -// package represents its parent. -func parentCancelCtx(parent Context) (*cancelCtx, bool) { - for { - switch c := parent.(type) { - case *cancelCtx: - return c, true - case *timerCtx: - return c.cancelCtx, true - case *valueCtx: - parent = c.Context - default: - return nil, false - } - } -} - -// removeChild removes a context from its parent. -func removeChild(parent Context, child canceler) { - p, ok := parentCancelCtx(parent) - if !ok { - return - } - p.mu.Lock() - if p.children != nil { - delete(p.children, child) - } - p.mu.Unlock() -} - -// A canceler is a context type that can be canceled directly. The -// implementations are *cancelCtx and *timerCtx. -type canceler interface { - cancel(removeFromParent bool, err error) - Done() <-chan struct{} -} - -// A cancelCtx can be canceled. When canceled, it also cancels any children -// that implement canceler. -type cancelCtx struct { - Context - - done chan struct{} // closed by the first cancel call. - - mu sync.Mutex - children map[canceler]bool // set to nil by the first cancel call - err error // set to non-nil by the first cancel call -} - -func (c *cancelCtx) Done() <-chan struct{} { - return c.done -} - -func (c *cancelCtx) Err() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.err -} - -func (c *cancelCtx) String() string { - return fmt.Sprintf("%v.WithCancel", c.Context) -} - -// cancel closes c.done, cancels each of c's children, and, if -// removeFromParent is true, removes c from its parent's children. -func (c *cancelCtx) cancel(removeFromParent bool, err error) { - if err == nil { - panic("context: internal error: missing cancel error") - } - c.mu.Lock() - if c.err != nil { - c.mu.Unlock() - return // already canceled - } - c.err = err - close(c.done) - for child := range c.children { - // NOTE: acquiring the child's lock while holding parent's lock. - child.cancel(false, err) - } - c.children = nil - c.mu.Unlock() - - if removeFromParent { - removeChild(c.Context, c) - } -} - -// WithDeadline returns a copy of the parent context with the deadline adjusted -// to be no later than d. If the parent's deadline is already earlier than d, -// WithDeadline(parent, d) is semantically equivalent to parent. The returned -// context's Done channel is closed when the deadline expires, when the returned -// cancel function is called, or when the parent context's Done channel is -// closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { - if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { - // The current deadline is already sooner than the new one. - return WithCancel(parent) - } - c := &timerCtx{ - cancelCtx: newCancelCtx(parent), - deadline: deadline, - } - propagateCancel(parent, c) - d := deadline.Sub(time.Now()) - if d <= 0 { - c.cancel(true, DeadlineExceeded) // deadline has already passed - return c, func() { c.cancel(true, Canceled) } - } - c.mu.Lock() - defer c.mu.Unlock() - if c.err == nil { - c.timer = time.AfterFunc(d, func() { - c.cancel(true, DeadlineExceeded) - }) - } - return c, func() { c.cancel(true, Canceled) } -} - -// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to -// implement Done and Err. It implements cancel by stopping its timer then -// delegating to cancelCtx.cancel. -type timerCtx struct { - *cancelCtx - timer *time.Timer // Under cancelCtx.mu. - - deadline time.Time -} - -func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { - return c.deadline, true -} - -func (c *timerCtx) String() string { - return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) -} - -func (c *timerCtx) cancel(removeFromParent bool, err error) { - c.cancelCtx.cancel(false, err) - if removeFromParent { - // Remove this timerCtx from its parent cancelCtx's children. - removeChild(c.cancelCtx.Context, c) - } - c.mu.Lock() - if c.timer != nil { - c.timer.Stop() - c.timer = nil - } - c.mu.Unlock() -} - -// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete: -// -// func slowOperationWithTimeout(ctx context.Context) (Result, error) { -// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) -// defer cancel() // releases resources if slowOperation completes before timeout elapses -// return slowOperation(ctx) -// } -func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { - return WithDeadline(parent, time.Now().Add(timeout)) -} - -// WithValue returns a copy of parent in which the value associated with key is -// val. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { - return &valueCtx{parent, key, val} -} - -// A valueCtx carries a key-value pair. It implements Value for that key and -// delegates all other calls to the embedded Context. -type valueCtx struct { - Context - key, val interface{} -} - -func (c *valueCtx) String() string { - return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) -} - -func (c *valueCtx) Value(key interface{}) interface{} { - if c.key == key { - return c.val - } - return c.Context.Value(key) -} diff --git a/context/pre_go19.go b/context/pre_go19.go deleted file mode 100644 index ec5a638033..0000000000 --- a/context/pre_go19.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.9 - -package context - -import "time" - -// A Context carries a deadline, a cancelation signal, and other values across -// API boundaries. -// -// Context's methods may be called by multiple goroutines simultaneously. -type Context interface { - // Deadline returns the time when work done on behalf of this context - // should be canceled. Deadline returns ok==false when no deadline is - // set. Successive calls to Deadline return the same results. - Deadline() (deadline time.Time, ok bool) - - // Done returns a channel that's closed when work done on behalf of this - // context should be canceled. Done may return nil if this context can - // never be canceled. Successive calls to Done return the same value. - // - // WithCancel arranges for Done to be closed when cancel is called; - // WithDeadline arranges for Done to be closed when the deadline - // expires; WithTimeout arranges for Done to be closed when the timeout - // elapses. - // - // Done is provided for use in select statements: - // - // // Stream generates values with DoSomething and sends them to out - // // until DoSomething returns an error or ctx.Done is closed. - // func Stream(ctx context.Context, out chan<- Value) error { - // for { - // v, err := DoSomething(ctx) - // if err != nil { - // return err - // } - // select { - // case <-ctx.Done(): - // return ctx.Err() - // case out <- v: - // } - // } - // } - // - // See http://blog.golang.org/pipelines for more examples of how to use - // a Done channel for cancelation. - Done() <-chan struct{} - - // Err returns a non-nil error value after Done is closed. Err returns - // Canceled if the context was canceled or DeadlineExceeded if the - // context's deadline passed. No other values for Err are defined. - // After Done is closed, successive calls to Err return the same value. - Err() error - - // Value returns the value associated with this context for key, or nil - // if no value is associated with key. Successive calls to Value with - // the same key returns the same result. - // - // Use context values only for request-scoped data that transits - // processes and API boundaries, not for passing optional parameters to - // functions. - // - // A key identifies a specific value in a Context. Functions that wish - // to store values in Context typically allocate a key in a global - // variable then use that key as the argument to context.WithValue and - // Context.Value. A key can be any type that supports equality; - // packages should define keys as an unexported type to avoid - // collisions. - // - // Packages that define a Context key should provide type-safe accessors - // for the values stores using that key: - // - // // Package user defines a User type that's stored in Contexts. - // package user - // - // import "golang.org/x/net/context" - // - // // User is the type of value stored in the Contexts. - // type User struct {...} - // - // // key is an unexported type for keys defined in this package. - // // This prevents collisions with keys defined in other packages. - // type key int - // - // // userKey is the key for user.User values in Contexts. It is - // // unexported; clients use user.NewContext and user.FromContext - // // instead of using this key directly. - // var userKey key = 0 - // - // // NewContext returns a new Context that carries value u. - // func NewContext(ctx context.Context, u *User) context.Context { - // return context.WithValue(ctx, userKey, u) - // } - // - // // FromContext returns the User value stored in ctx, if any. - // func FromContext(ctx context.Context) (*User, bool) { - // u, ok := ctx.Value(userKey).(*User) - // return u, ok - // } - Value(key interface{}) interface{} -} - -// A CancelFunc tells an operation to abandon its work. -// A CancelFunc does not wait for the work to stop. -// After the first call, subsequent calls to a CancelFunc do nothing. -type CancelFunc func() diff --git a/context/withtimeout_test.go b/context/withtimeout_test.go deleted file mode 100644 index e6f56691d1..0000000000 --- a/context/withtimeout_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package context_test - -import ( - "fmt" - "time" - - "golang.org/x/net/context" -) - -// This example passes a context with a timeout to tell a blocking function that -// it should abandon its work after the timeout elapses. -func ExampleWithTimeout() { - // Pass a context with a timeout to tell a blocking function that it - // should abandon its work after the timeout elapses. - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - select { - case <-time.After(1 * time.Second): - fmt.Println("overslept") - case <-ctx.Done(): - fmt.Println(ctx.Err()) // prints "context deadline exceeded" - } - - // Output: - // context deadline exceeded -} diff --git a/go.mod b/go.mod index 8de393204b..37aac27a62 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,9 @@ module golang.org/x/net -go 1.18 +go 1.23.0 require ( - golang.org/x/crypto v0.33.0 + golang.org/x/crypto v0.35.0 golang.org/x/sys v0.30.0 golang.org/x/term v0.29.0 golang.org/x/text v0.22.0 diff --git a/go.sum b/go.sum index 553516bb0f..5f95431dfa 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU= diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index 6404aaf157..d89c257ae7 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "net" + "net/netip" "net/url" "os" "strings" @@ -177,8 +178,10 @@ func (cfg *config) useProxy(addr string) bool { if host == "localhost" { return false } - ip := net.ParseIP(host) - if ip != nil { + nip, err := netip.ParseAddr(host) + var ip net.IP + if err == nil { + ip = net.IP(nip.AsSlice()) if ip.IsLoopback() { return false } @@ -360,6 +363,9 @@ type domainMatch struct { } func (m domainMatch) match(host, port string, ip net.IP) bool { + if ip != nil { + return false + } if strings.HasSuffix(host, m.host) || (m.matchHost && host == m.host[1:]) { return m.port == "" || m.port == port } diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index 790afdab77..a1dd2e83fd 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -211,6 +211,13 @@ var proxyForURLTests = []proxyForURLTest{{ }, req: "http://www.xn--fsq092h.com", want: "<nil>", +}, { + cfg: httpproxy.Config{ + NoProxy: "example.com", + HTTPProxy: "proxy", + }, + req: "http://[1000::%25.example.com]:123", + want: "http://proxy", }, } diff --git a/http2/server.go b/http2/server.go index 7434b87843..b640deb0e0 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2233,25 +2233,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { sc.serveG.check() - rp := requestParam{ - method: f.PseudoValue("method"), - scheme: f.PseudoValue("scheme"), - authority: f.PseudoValue("authority"), - path: f.PseudoValue("path"), - protocol: f.PseudoValue("protocol"), + rp := httpcommon.ServerRequestParam{ + Method: f.PseudoValue("method"), + Scheme: f.PseudoValue("scheme"), + Authority: f.PseudoValue("authority"), + Path: f.PseudoValue("path"), + Protocol: f.PseudoValue("protocol"), } // extended connect is disabled, so we should not see :protocol - if disableExtendedConnectProtocol && rp.protocol != "" { + if disableExtendedConnectProtocol && rp.Protocol != "" { return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) } - isConnect := rp.method == "CONNECT" + isConnect := rp.Method == "CONNECT" if isConnect { - if rp.protocol == "" && (rp.path != "" || rp.scheme != "" || rp.authority == "") { + if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") { return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) } - } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { + } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: // // Malformed requests or responses that are detected @@ -2265,15 +2265,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) } - rp.header = make(http.Header) + header := make(http.Header) + rp.Header = header for _, hf := range f.RegularFields() { - rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + header.Add(sc.canonicalHeader(hf.Name), hf.Value) } - if rp.authority == "" { - rp.authority = rp.header.Get("Host") + if rp.Authority == "" { + rp.Authority = header.Get("Host") } - if rp.protocol != "" { - rp.header.Set(":protocol", rp.protocol) + if rp.Protocol != "" { + header.Set(":protocol", rp.Protocol) } rw, req, err := sc.newWriterAndRequestNoBody(st, rp) @@ -2282,7 +2283,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res } bodyOpen := !f.StreamEnded() if bodyOpen { - if vv, ok := rp.header["Content-Length"]; ok { + if vv, ok := rp.Header["Content-Length"]; ok { if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { req.ContentLength = int64(cl) } else { @@ -2298,84 +2299,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return rw, req, nil } -type requestParam struct { - method string - scheme, authority, path string - protocol string - header http.Header -} - -func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) { sc.serveG.check() var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { + if rp.Scheme == "https" { tlsState = sc.tlsState } - needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue") - if needsContinue { - rp.header.Del("Expect") - } - // Merge Cookie headers into one "; "-delimited value. - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) - } - - // Setup Trailers - var trailer http.Header - for _, v := range rp.header["Trailer"] { - for _, key := range strings.Split(v, ",") { - key = http.CanonicalHeaderKey(textproto.TrimString(key)) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - // Bogus. (copy of http1 rules) - // Ignore. - default: - if trailer == nil { - trailer = make(http.Header) - } - trailer[key] = nil - } - } - } - delete(rp.header, "Trailer") - - var url_ *url.URL - var requestURI string - if rp.method == "CONNECT" && rp.protocol == "" { - url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority // mimic HTTP/1 server behavior - } else { - var err error - url_, err = url.ParseRequestURI(rp.path) - if err != nil { - return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol)) - } - requestURI = rp.path + res := httpcommon.NewServerRequest(rp) + if res.InvalidReason != "" { + return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol)) } body := &requestBody{ conn: sc, stream: st, - needsContinue: needsContinue, + needsContinue: res.NeedsContinue, } - req := &http.Request{ - Method: rp.method, - URL: url_, + req := (&http.Request{ + Method: rp.Method, + URL: res.URL, RemoteAddr: sc.remoteAddrStr, - Header: rp.header, - RequestURI: requestURI, + Header: rp.Header, + RequestURI: res.RequestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: rp.authority, + Host: rp.Authority, Body: body, - Trailer: trailer, - } - req = req.WithContext(st.ctx) - + Trailer: res.Trailer, + }).WithContext(st.ctx) rw := sc.newResponseWriter(st, req) return rw, req, nil } @@ -3270,12 +3225,12 @@ func (sc *serverConn) startPush(msg *startPushRequest) { // we start in "half closed (remote)" for simplicity. // See further comments at the definition of stateHalfClosedRemote. promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) - rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ - method: msg.method, - scheme: msg.url.Scheme, - authority: msg.url.Host, - path: msg.url.RequestURI(), - header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{ + Method: msg.method, + Scheme: msg.url.Scheme, + Authority: msg.url.Host, + Path: msg.url.RequestURI(), + Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE }) if err != nil { // Should not happen, since we've already validated msg.url. diff --git a/http2/server_test.go b/http2/server_test.go index 08f2dd3b21..b27a127a5e 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -1032,6 +1032,26 @@ func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) { }) } +func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) { + // "':authority' MUST NOT include the deprecated userinfo subcomponent + // for "http" or "https" schemed URIs." + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8 + testRejectRequest(t, func(st *serverTester) { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "userinfo@example.tld"}) + enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"}) + enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"}) + enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"}) + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }) + }) +} + func testRejectRequest(t *testing.T, send func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") @@ -2794,6 +2814,8 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) { w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2") return nil }, func(st *serverTester) { + // Ignore errors from writing invalid trailers. + st.h1server.ErrorLog = log.New(io.Discard, "", 0) getSlash(st) st.wantHeaders(wantHeader{ streamID: 1, diff --git a/http2/sync_test.go b/http2/sync_test.go index aeddbd6f3c..6687202d2c 100644 --- a/http2/sync_test.go +++ b/http2/sync_test.go @@ -24,9 +24,10 @@ type synctestGroup struct { } type goroutine struct { - id int - parent int - state string + id int + parent int + state string + syscall bool } // newSynctest creates a new group with the synthetic clock set the provided time. @@ -76,6 +77,14 @@ func (g *synctestGroup) Wait() { return } runtime.Gosched() + if runtime.GOOS == "js" { + // When GOOS=js, we appear to need to time.Sleep to make progress + // on some syscalls. In particular, without this sleep + // writing to stdout (including via t.Log) can block forever. + for range 10 { + time.Sleep(1) + } + } } } @@ -87,6 +96,9 @@ func (g *synctestGroup) idle() bool { if !g.gids[gr.id] && !g.gids[gr.parent] { continue } + if gr.syscall { + return false + } // From runtime/runtime2.go. switch gr.state { case "IO wait": @@ -97,9 +109,6 @@ func (g *synctestGroup) idle() bool { case "chan receive": case "chan send": case "sync.Cond.Wait": - case "sync.Mutex.Lock": - case "sync.RWMutex.RLock": - case "sync.RWMutex.Lock": default: return false } @@ -138,6 +147,10 @@ func stacks(all bool) []goroutine { panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs)) } state, rest, ok := strings.Cut(rest, "]") + isSyscall := false + if strings.Contains(rest, "\nsyscall.") { + isSyscall = true + } var parent int _, rest, ok = strings.Cut(rest, "\ncreated by ") if ok && strings.Contains(rest, " in goroutine ") { @@ -155,9 +168,10 @@ func stacks(all bool) []goroutine { } } goroutines = append(goroutines, goroutine{ - id: id, - parent: parent, - state: state, + id: id, + parent: parent, + state: state, + syscall: isSyscall, }) } return goroutines @@ -291,3 +305,25 @@ func (tm *fakeTimer) Stop() bool { delete(tm.g.timers, tm) return stopped } + +// TestSynctestLogs verifies that t.Log works, +// in particular that the GOOS=js workaround in synctestGroup.Wait is working. +// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops +// calling runtime.Gosched; see Wait for the workaround.) +func TestSynctestLogs(t *testing.T) { + g := newSynctest(time.Now()) + donec := make(chan struct{}) + go func() { + g.Join() + for range 100 { + t.Logf("logging a long line") + } + close(donec) + }() + g.Wait() + select { + case <-donec: + default: + panic("done") + } +} diff --git a/http2/transport.go b/http2/transport.go index f2c166b615..f26356b9cd 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1286,6 +1286,19 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration { return 0 } +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + func (cc *ClientConn) decrStreamReservations() { cc.mu.Lock() defer cc.mu.Unlock() @@ -1310,7 +1323,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) reqCancel: req.Cancel, isHead: req.Method == "HEAD", reqBody: req.Body, - reqBodyContentLength: httpcommon.ActualContentLength(req), + reqBodyContentLength: actualContentLength(req), trace: httptrace.ContextClientTrace(ctx), peerClosed: make(chan struct{}), abort: make(chan struct{}), @@ -1318,7 +1331,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) donec: make(chan struct{}), } - cs.requestedGzip = httpcommon.IsRequestGzip(req, cc.t.disableCompression()) + cs.requestedGzip = httpcommon.IsRequestGzip(req.Method, req.Header, cc.t.disableCompression()) go cs.doRequest(req, streamf) @@ -1349,7 +1362,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) } res.Request = req res.TLS = cc.tlsState - if res.Body == noBody && httpcommon.ActualContentLength(req) == 0 { + if res.Body == noBody && actualContentLength(req) == 0 { // If there isn't a request or response body still being // written, then wait for the stream to be closed before // RoundTrip returns. @@ -1596,12 +1609,7 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { // sent by writeRequestBody below, along with any Trailers, // again in form HEADERS{1}, CONTINUATION{0,}) cc.hbuf.Reset() - res, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, - AddGzipHeader: cs.requestedGzip, - PeerMaxHeaderListSize: cc.peerMaxHeaderListSize, - DefaultUserAgent: defaultUserAgent, - }, func(name, value string) { + res, err := encodeRequestHeaders(req, cs.requestedGzip, cc.peerMaxHeaderListSize, func(name, value string) { cc.writeHeader(name, value) }) if err != nil { @@ -1617,6 +1625,22 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { return err } +func encodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { + return httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + Header: req.Header, + Trailer: req.Trailer, + URL: req.URL, + Host: req.Host, + Method: req.Method, + ActualContentLength: actualContentLength(req), + }, + AddGzipHeader: addGzipHeader, + PeerMaxHeaderListSize: peerMaxHeaderListSize, + DefaultUserAgent: defaultUserAgent, + }, headerf) +} + // cleanupWriteRequest performs post-request tasks. // // If err (the result of writeRequest) is non-nil and the stream is not closed, @@ -2186,6 +2210,13 @@ func (rl *clientConnReadLoop) cleanup() { } cc.cond.Broadcast() cc.mu.Unlock() + + if !cc.seenSettings { + // If we have a pending request that wants extended CONNECT, + // let it continue and fail with the connection error. + cc.extendedConnectAllowed = true + close(cc.seenSettingsChan) + } } // countReadFrameError calls Transport.CountError with a string @@ -2278,9 +2309,6 @@ func (rl *clientConnReadLoop) run() error { if VerboseLogs { cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, summarizeFrame(f), err) } - if !cc.seenSettings { - close(cc.seenSettingsChan) - } return err } } diff --git a/http2/transport_test.go b/http2/transport_test.go index 47eac2fa80..1eeb76e06e 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -36,7 +36,6 @@ import ( "time" "golang.org/x/net/http2/hpack" - "golang.org/x/net/internal/httpcommon" ) var ( @@ -571,6 +570,45 @@ func randString(n int) string { return string(b) } +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + // http.NoBody means 0, not -1. + 3: { + req: &http.Request{Body: http.NoBody}, + want: 0, + }, + } + for i, tt := range tests { + got := actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + func TestTransportBody(t *testing.T) { bodyTests := []struct { body string @@ -1405,12 +1443,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } } headerListSizeForRequest := func(req *http.Request) (size uint64) { - _, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, - AddGzipHeader: true, - PeerMaxHeaderListSize: 0xffffffffffffffff, - DefaultUserAgent: defaultUserAgent, - }, func(name, value string) { + const addGzipHeader = true + const peerMaxHeaderListSize = 0xffffffffffffffff + _, err := encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { hf := hpack.HeaderField{Name: name, Value: value} size += uint64(hf.Size()) }) @@ -2808,11 +2843,10 @@ func TestTransportRequestPathPseudo(t *testing.T) { for i, tt := range tests { hbuf := &bytes.Buffer{} henc := hpack.NewEncoder(hbuf) - _, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: tt.req, - AddGzipHeader: false, - PeerMaxHeaderListSize: 0xffffffffffffffff, - }, func(name, value string) { + + const addGzipHeader = false + const peerMaxHeaderListSize = 0xffffffffffffffff + _, err := encodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { henc.WriteField(hpack.HeaderField{Name: name, Value: value}) }) hdrs := hbuf.Bytes() @@ -5880,3 +5914,24 @@ func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err) } } + +// Issue #70658: Make sure extended CONNECT requests don't get stuck if a +// connection fails early in its lifetime. +func TestExtendedConnectReadFrameError(t *testing.T) { + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil) + req.Header.Set(":protocol", "extended-connect") + rt := tc.roundTrip(req) + tc.wantIdle() // waiting for SETTINGS response + + tc.closeWrite() // connection breaks without sending SETTINGS + if !rt.done() { + t.Fatalf("after connection closed: RoundTrip still running; want done") + } + if rt.err() == nil { + t.Fatalf("after connection closed: RoundTrip succeeded; want error") + } +} diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go new file mode 100644 index 0000000000..599e0df816 --- /dev/null +++ b/internal/http3/body_test.go @@ -0,0 +1,276 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 && goexperiment.synctest + +package http3 + +import ( + "bytes" + "fmt" + "io" + "net/http" + "testing" +) + +// TestReadData tests servers reading request bodies, and clients reading response bodies. +func TestReadData(t *testing.T) { + // These tests consist of a series of steps, + // where each step is either something arriving on the stream + // or the client/server reading from the body. + type ( + // HEADERS frame arrives (headers). + receiveHeaders struct { + contentLength int64 // -1 for no content-length + } + // DATA frame header arrives. + receiveDataHeader struct { + size int64 + } + // DATA frame content arrives. + receiveData struct { + size int64 + } + // HEADERS frame arrives (trailers). + receiveTrailers struct{} + // Some other frame arrives. + receiveFrame struct { + ftype frameType + data []byte + } + // Stream closed, ending the body. + receiveEOF struct{} + // Server reads from Request.Body, or client reads from Response.Body. + wantBody struct { + size int64 + eof bool + } + wantError struct{} + ) + for _, test := range []struct { + name string + respHeader http.Header + steps []any + wantError bool + }{{ + name: "no content length", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "valid content length", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "data frame exceeds content length", + steps: []any{ + receiveHeaders{contentLength: 5}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantError{}, + }, + }, { + name: "data frame after all content read", + steps: []any{ + receiveHeaders{contentLength: 5}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveDataHeader{size: 1}, + receiveData{size: 1}, + wantError{}, + }, + }, { + name: "content length too long", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveEOF{}, + wantBody{size: 5}, + wantError{}, + }, + }, { + name: "stream ended by trailers", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveTrailers{}, + wantBody{size: 5, eof: true}, + }, + }, { + name: "trailers and content length too long", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveTrailers{}, + wantError{}, + }, + }, { + name: "unknown frame before headers", + steps: []any{ + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "unknown frame after headers", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "invalid frame", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveFrame{ + ftype: frameTypeSettings, // not a valid frame on this stream + data: []byte{1, 2, 3, 4}, + }, + wantError{}, + }, + }, { + name: "data frame consumed by several reads", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 16}, + receiveData{size: 16}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + wantBody{size: 2}, + }, + }, { + name: "read multiple frames", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 2}, + receiveData{size: 2}, + receiveDataHeader{size: 4}, + receiveData{size: 4}, + receiveDataHeader{size: 8}, + receiveData{size: 8}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + }, + }} { + + runTest := func(t testing.TB, h http.Header, st *testQUICStream, body func() io.ReadCloser) { + var ( + bytesSent int + bytesReceived int + ) + for _, step := range test.steps { + switch step := step.(type) { + case receiveHeaders: + header := h.Clone() + if step.contentLength != -1 { + header["content-length"] = []string{ + fmt.Sprint(step.contentLength), + } + } + st.writeHeaders(header) + case receiveDataHeader: + t.Logf("receive DATA frame header: size=%v", step.size) + st.writeVarint(int64(frameTypeData)) + st.writeVarint(step.size) + st.Flush() + case receiveData: + t.Logf("receive DATA frame content: size=%v", step.size) + for range step.size { + st.stream.stream.WriteByte(byte(bytesSent)) + bytesSent++ + } + st.Flush() + case receiveTrailers: + st.writeHeaders(http.Header{ + "x-trailer": []string{"trailer"}, + }) + case receiveFrame: + st.writeVarint(int64(step.ftype)) + st.writeVarint(int64(len(step.data))) + st.Write(step.data) + st.Flush() + case receiveEOF: + t.Logf("receive EOF on request stream") + st.stream.stream.CloseWrite() + case wantBody: + t.Logf("read %v bytes from response body", step.size) + want := make([]byte, step.size) + for i := range want { + want[i] = byte(bytesReceived) + bytesReceived++ + } + got := make([]byte, step.size) + n, err := body().Read(got) + got = got[:n] + if !bytes.Equal(got, want) { + t.Errorf("resp.Body.Read:") + t.Errorf(" got: {%x}", got) + t.Fatalf(" want: {%x}", want) + } + if err != nil { + if step.eof && err == io.EOF { + continue + } + t.Fatalf("resp.Body.Read: unexpected error %v", err) + } + if step.eof { + if n, err := body().Read([]byte{0}); n != 0 || err != io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err) + } + } + case wantError: + if n, err := body().Read([]byte{0}); n != 0 || err == nil || err == io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err) + } + default: + t.Fatalf("unknown test step %T", step) + } + } + + } + + runSynctestSubtest(t, test.name+"/client", func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + header := http.Header{ + ":status": []string{"200"}, + } + runTest(t, header, st, func() io.ReadCloser { + return rt.response().Body + }) + }) + } +} diff --git a/internal/http3/conn.go b/internal/http3/conn.go index e9a58471ea..5eb803115e 100644 --- a/internal/http3/conn.go +++ b/internal/http3/conn.go @@ -19,7 +19,7 @@ type streamHandler interface { handlePushStream(*stream) error handleEncoderStream(*stream) error handleDecoderStream(*stream) error - handleRequestStream(*stream) + handleRequestStream(*stream) error abort(error) } @@ -43,7 +43,7 @@ func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) { if st.IsReadOnly() { go c.handleUnidirectionalStream(newStream(st), h) } else { - go h.handleRequestStream(newStream(st)) + go c.handleRequestStream(newStream(st), h) } } } @@ -81,7 +81,6 @@ func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) { // but the quic package currently doesn't allow setting error codes // for STOP_SENDING frames. // TODO: Should CloseRead take an error code? - st.stream.CloseRead() err = nil } if err == io.EOF { @@ -90,8 +89,26 @@ func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) { message: streamType(stype).String() + " stream closed", } } - if err != nil { + c.handleStreamError(st, h, err) +} + +func (c *genericConn) handleRequestStream(st *stream, h streamHandler) { + c.handleStreamError(st, h, h.handleRequestStream(st)) +} + +func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) { + switch err := err.(type) { + case *connectionError: h.abort(err) + case nil: + st.stream.CloseRead() + st.stream.CloseWrite() + case *streamError: + st.stream.CloseRead() + st.stream.Reset(uint64(err.code)) + default: + st.stream.CloseRead() + st.stream.Reset(uint64(errH3InternalError)) } } diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go index e9b5b4189d..a9afb1f9e9 100644 --- a/internal/http3/conn_test.go +++ b/internal/http3/conn_test.go @@ -146,4 +146,9 @@ func runConnTest(t *testing.T, f func(testing.TB, *testQUICConn)) { tc := newTestClientConn(t) f(t, tc.testQUICConn) }) + runSynctestSubtest(t, "server", func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + f(t, tc.testQUICConn) + }) } diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index b24a303081..bf55a13159 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -82,10 +82,19 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) st.stream.SetReadContext(req.Context()) st.stream.SetWriteContext(req.Context()) + contentLength := actualContentLength(req) + var encr httpcommon.EncodeHeadersResult headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) { - encr, err = httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, + encr, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + URL: req.URL, + Method: req.Method, + Host: req.Host, + Header: req.Header, + Trailer: req.Trailer, + ActualContentLength: contentLength, + }, AddGzipHeader: false, // TODO: add when appropriate PeerMaxHeaderListSize: 0, DefaultUserAgent: "Go-http-client/3", @@ -110,7 +119,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) // TODO: Defer sending the request body when "Expect: 100-continue" is set. rt.reqBody = req.Body rt.reqBodyWriter.st = st - rt.reqBodyWriter.remain = httpcommon.ActualContentLength(req) + rt.reqBodyWriter.remain = contentLength rt.reqBodyWriter.flush = true rt.reqBodyWriter.name = "request" go copyRequestBody(rt) @@ -165,6 +174,18 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) } } +// actualContentLength returns a sanitized version of req.ContentLength, +// where 0 actually means zero (not unknown) and -1 means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + func copyRequestBody(rt *roundTripState) { defer rt.closeReqBody() _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody) diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index 533b750a55..acd8613d0e 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -237,274 +237,6 @@ func TestRoundTripCrumbledCookiesInResponse(t *testing.T) { }) } -func TestRoundTripResponseBody(t *testing.T) { - // These tests consist of a series of steps, - // where each step is either something arriving on the response stream - // or the client reading from the request body. - type ( - // HEADERS frame arrives on the response stream (headers or trailers). - receiveHeaders http.Header - // DATA frame header arrives on the response stream. - receiveDataHeader struct { - size int64 - } - // DATA frame content arrives on the response stream. - receiveData struct { - size int64 - } - // Some other frame arrives on the response stream. - receiveFrame struct { - ftype frameType - data []byte - } - // Response stream closed, ending the body. - receiveEOF struct{} - // Client reads from Response.Body. - wantBody struct { - size int64 - eof bool - } - wantError struct{} - ) - for _, test := range []struct { - name string - respHeader http.Header - steps []any - wantError bool - }{{ - name: "no content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - receiveEOF{}, - wantBody{size: 10, eof: true}, - }, - }, { - name: "valid content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - receiveEOF{}, - wantBody{size: 10, eof: true}, - }, - }, { - name: "data frame exceeds content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"5"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantError{}, - }, - }, { - name: "data frame after all content read", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"5"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - wantBody{size: 5}, - receiveDataHeader{size: 1}, - receiveData{size: 1}, - wantError{}, - }, - }, { - name: "content length too long", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - receiveEOF{}, - wantBody{size: 5}, - wantError{}, - }, - }, { - name: "stream ended by trailers", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - receiveHeaders{ - "x-trailer": []string{"value"}, - }, - wantBody{size: 5, eof: true}, - }, - }, { - name: "trailers and content length too long", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - wantBody{size: 5}, - receiveHeaders{ - "x-trailer": []string{"value"}, - }, - wantError{}, - }, - }, { - name: "unknown frame before headers", - steps: []any{ - receiveFrame{ - ftype: 0x1f + 0x21, // reserved frame type - data: []byte{1, 2, 3, 4}, - }, - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantBody{size: 10}, - }, - }, { - name: "unknown frame after headers", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveFrame{ - ftype: 0x1f + 0x21, // reserved frame type - data: []byte{1, 2, 3, 4}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantBody{size: 10}, - }, - }, { - name: "invalid frame", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveFrame{ - ftype: frameTypeSettings, // not a valid frame on this stream - data: []byte{1, 2, 3, 4}, - }, - wantError{}, - }, - }, { - name: "data frame consumed by several reads", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 16}, - receiveData{size: 16}, - wantBody{size: 2}, - wantBody{size: 4}, - wantBody{size: 8}, - wantBody{size: 2}, - }, - }, { - name: "read multiple frames", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 2}, - receiveData{size: 2}, - receiveDataHeader{size: 4}, - receiveData{size: 4}, - receiveDataHeader{size: 8}, - receiveData{size: 8}, - wantBody{size: 2}, - wantBody{size: 4}, - wantBody{size: 8}, - }, - }} { - runSynctestSubtest(t, test.name, func(t testing.TB) { - tc := newTestClientConn(t) - tc.greet() - - req, _ := http.NewRequest("GET", "https://example.tld/", nil) - rt := tc.roundTrip(req) - st := tc.wantStream(streamTypeRequest) - st.wantHeaders(nil) - - var ( - bytesSent int - bytesReceived int - ) - for _, step := range test.steps { - switch step := step.(type) { - case receiveHeaders: - st.writeHeaders(http.Header(step)) - case receiveDataHeader: - t.Logf("receive DATA frame header: size=%v", step.size) - st.writeVarint(int64(frameTypeData)) - st.writeVarint(step.size) - st.Flush() - case receiveData: - t.Logf("receive DATA frame content: size=%v", step.size) - for range step.size { - st.stream.stream.WriteByte(byte(bytesSent)) - bytesSent++ - } - st.Flush() - case receiveFrame: - st.writeVarint(int64(step.ftype)) - st.writeVarint(int64(len(step.data))) - st.Write(step.data) - st.Flush() - case receiveEOF: - t.Logf("receive EOF on request stream") - st.stream.stream.CloseWrite() - case wantBody: - t.Logf("read %v bytes from response body", step.size) - want := make([]byte, step.size) - for i := range want { - want[i] = byte(bytesReceived) - bytesReceived++ - } - got := make([]byte, step.size) - n, err := rt.response().Body.Read(got) - got = got[:n] - if !bytes.Equal(got, want) { - t.Errorf("resp.Body.Read:") - t.Errorf(" got: {%x}", got) - t.Fatalf(" want: {%x}", want) - } - if err != nil { - if step.eof && err == io.EOF { - continue - } - t.Fatalf("resp.Body.Read: unexpected error %v", err) - } - if step.eof { - if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err != io.EOF { - t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err) - } - } - case wantError: - if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err == nil || err == io.EOF { - t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err) - } - default: - t.Fatalf("unknown test step %T", step) - } - } - }) - } -} - func TestRoundTripRequestBodySent(t *testing.T) { runSynctest(t, func(t testing.TB) { tc := newTestClientConn(t) diff --git a/internal/http3/server.go b/internal/http3/server.go new file mode 100644 index 0000000000..ca93c5298a --- /dev/null +++ b/internal/http3/server.go @@ -0,0 +1,172 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package http3 + +import ( + "context" + "net/http" + "sync" + + "golang.org/x/net/quic" +) + +// A Server is an HTTP/3 server. +// The zero value for Server is a valid server. +type Server struct { + // Handler to invoke for requests, http.DefaultServeMux if nil. + Handler http.Handler + + // Config is the QUIC configuration used by the server. + // The Config may be nil. + // + // ListenAndServe may clone and modify the Config. + // The Config must not be modified after calling ListenAndServe. + Config *quic.Config + + initOnce sync.Once +} + +func (s *Server) init() { + s.initOnce.Do(func() { + s.Config = initConfig(s.Config) + if s.Handler == nil { + s.Handler = http.DefaultServeMux + } + }) +} + +// ListenAndServe listens on the UDP network address addr +// and then calls Serve to handle requests on incoming connections. +func (s *Server) ListenAndServe(addr string) error { + s.init() + e, err := quic.Listen("udp", addr, s.Config) + if err != nil { + return err + } + return s.Serve(e) +} + +// Serve accepts incoming connections on the QUIC endpoint e, +// and handles requests from those connections. +func (s *Server) Serve(e *quic.Endpoint) error { + s.init() + for { + qconn, err := e.Accept(context.Background()) + if err != nil { + return err + } + go newServerConn(qconn) + } +} + +type serverConn struct { + qconn *quic.Conn + + genericConn // for handleUnidirectionalStream + enc qpackEncoder + dec qpackDecoder +} + +func newServerConn(qconn *quic.Conn) { + sc := &serverConn{ + qconn: qconn, + } + sc.enc.init() + + // Create control stream and send SETTINGS frame. + // TODO: Time out on creating stream. + controlStream, err := newConnStream(context.Background(), sc.qconn, streamTypeControl) + if err != nil { + return + } + controlStream.writeSettings() + controlStream.Flush() + + sc.acceptStreams(sc.qconn, sc) +} + +func (sc *serverConn) handleControlStream(st *stream) error { + // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2 + if err := st.readSettings(func(settingsType, settingsValue int64) error { + switch settingsType { + case settingsMaxFieldSectionSize: + _ = settingsValue // TODO + case settingsQPACKMaxTableCapacity: + _ = settingsValue // TODO + case settingsQPACKBlockedStreams: + _ = settingsValue // TODO + default: + // Unknown settings types are ignored. + } + return nil + }); err != nil { + return err + } + + for { + ftype, err := st.readFrameHeader() + if err != nil { + return err + } + switch ftype { + case frameTypeCancelPush: + // "If a server receives a CANCEL_PUSH frame for a push ID + // that has not yet been mentioned by a PUSH_PROMISE frame, + // this MUST be treated as a connection error of type H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8 + return &connectionError{ + code: errH3IDError, + message: "CANCEL_PUSH for unsent push ID", + } + case frameTypeGoaway: + return errH3NoError + default: + // Unknown frames are ignored. + if err := st.discardUnknownFrame(ftype); err != nil { + return err + } + } + } +} + +func (sc *serverConn) handleEncoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handleDecoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handlePushStream(*stream) error { + // "[...] if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3 + return &connectionError{ + code: errH3StreamCreationError, + message: "client created push stream", + } +} + +func (sc *serverConn) handleRequestStream(st *stream) error { + // TODO + return nil +} + +// abort closes the connection with an error. +func (sc *serverConn) abort(err error) { + if e, ok := err.(*connectionError); ok { + sc.qconn.Abort(&quic.ApplicationError{ + Code: uint64(e.code), + Reason: e.message, + }) + } else { + sc.qconn.Abort(err) + } +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go new file mode 100644 index 0000000000..8e727d2512 --- /dev/null +++ b/internal/http3/server_test.go @@ -0,0 +1,110 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 && goexperiment.synctest + +package http3 + +import ( + "net/netip" + "testing" + "testing/synctest" + + "golang.org/x/net/internal/quic/quicwire" + "golang.org/x/net/quic" +) + +func TestServerReceivePushStream(t *testing.T) { + // "[...] if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3 + runSynctest(t, func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + tc.newStream(streamTypePush) + tc.wantClosed("invalid client-created push stream", errH3StreamCreationError) + }) +} + +func TestServerCancelPushForUnsentPromise(t *testing.T) { + runSynctest(t, func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + tc.greet() + + const pushID = 100 + tc.control.writeVarint(int64(frameTypeCancelPush)) + tc.control.writeVarint(int64(quicwire.SizeVarint(pushID))) + tc.control.writeVarint(pushID) + tc.control.Flush() + + tc.wantClosed("client canceled never-sent push ID", errH3IDError) + }) +} + +type testServer struct { + t testing.TB + s *Server + tn testNet + *testQUICEndpoint + + addr netip.AddrPort +} + +type testQUICEndpoint struct { + t testing.TB + e *quic.Endpoint +} + +func (te *testQUICEndpoint) dial() { +} + +type testServerConn struct { + ts *testServer + + *testQUICConn + control *testQUICStream +} + +func newTestServer(t testing.TB) *testServer { + t.Helper() + ts := &testServer{ + t: t, + s: &Server{ + Config: &quic.Config{ + TLSConfig: testTLSConfig, + }, + }, + } + e := ts.tn.newQUICEndpoint(t, ts.s.Config) + ts.addr = e.LocalAddr() + go ts.s.Serve(e) + return ts +} + +func (ts *testServer) connect() *testServerConn { + ts.t.Helper() + config := &quic.Config{TLSConfig: testTLSConfig} + e := ts.tn.newQUICEndpoint(ts.t, nil) + qconn, err := e.Dial(ts.t.Context(), "udp", ts.addr.String(), config) + if err != nil { + ts.t.Fatal(err) + } + tc := &testServerConn{ + ts: ts, + testQUICConn: newTestQUICConn(ts.t, qconn), + } + synctest.Wait() + return tc +} + +// greet performs initial connection handshaking with the server. +func (tc *testServerConn) greet() { + // Client creates a control stream. + tc.control = tc.newStream(streamTypeControl) + tc.control.writeVarint(int64(frameTypeSettings)) + tc.control.writeVarint(0) // size + tc.control.Flush() + synctest.Wait() +} diff --git a/internal/http3/settings.go b/internal/http3/settings.go index 45018aadd2..b5e562ecad 100644 --- a/internal/http3/settings.go +++ b/internal/http3/settings.go @@ -8,7 +8,6 @@ package http3 import ( "golang.org/x/net/internal/quic/quicwire" - "golang.org/x/net/quic" ) const ( @@ -39,9 +38,9 @@ func (st *stream) writeSettings(settings ...int64) { func (st *stream) readSettings(f func(settingType, value int64) error) error { frameType, err := st.readFrameHeader() if err != nil || frameType != frameTypeSettings { - return &quic.ApplicationError{ - Code: uint64(errH3MissingSettings), - Reason: "settings not sent on control stream", + return &connectionError{ + code: errH3MissingSettings, + message: "settings not sent on control stream", } } for st.lim > 0 { @@ -59,9 +58,9 @@ func (st *stream) readSettings(f func(settingType, value int64) error) error { // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5 switch settingsType { case 0x02, 0x03, 0x04, 0x05: - return &quic.ApplicationError{ - Code: uint64(errH3SettingsError), - Reason: "use of reserved setting", + return &connectionError{ + code: errH3SettingsError, + message: "use of reserved setting", } } diff --git a/internal/http3/transport.go b/internal/http3/transport.go index 83bc56c2bf..b26524cbda 100644 --- a/internal/http3/transport.go +++ b/internal/http3/transport.go @@ -167,14 +167,14 @@ func (cc *ClientConn) handlePushStream(*stream) error { } } -func (cc *ClientConn) handleRequestStream(st *stream) { +func (cc *ClientConn) handleRequestStream(st *stream) error { // "Clients MUST treat receipt of a server-initiated bidirectional // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3 - cc.abort(&connectionError{ + return &connectionError{ code: errH3StreamCreationError, message: "server created bidirectional stream", - }) + } } // abort closes the connection with an error. diff --git a/internal/httpcommon/headermap.go b/internal/httpcommon/headermap.go index ad3fbacd60..92483d8e41 100644 --- a/internal/httpcommon/headermap.go +++ b/internal/httpcommon/headermap.go @@ -5,7 +5,7 @@ package httpcommon import ( - "net/http" + "net/textproto" "sync" ) @@ -82,7 +82,7 @@ func buildCommonHeaderMaps() { commonLowerHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common)) for _, v := range common { - chk := http.CanonicalHeaderKey(v) + chk := textproto.CanonicalMIMEHeaderKey(v) commonLowerHeader[chk] = v commonCanonHeader[v] = chk } @@ -104,7 +104,7 @@ func CanonicalHeader(v string) string { if s, ok := commonCanonHeader[v]; ok { return s } - return http.CanonicalHeaderKey(v) + return textproto.CanonicalMIMEHeaderKey(v) } // CachedCanonicalHeader returns the canonical form of a well-known header name. diff --git a/internal/httpcommon/httpcommon_test.go b/internal/httpcommon/httpcommon_test.go new file mode 100644 index 0000000000..e725ec76cb --- /dev/null +++ b/internal/httpcommon/httpcommon_test.go @@ -0,0 +1,37 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpcommon_test + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +// This package is imported by the net/http package, +// and therefore must not itself import net/http. +func TestNoNetHttp(t *testing.T) { + files, err := filepath.Glob("*.go") + if err != nil { + t.Fatal(err) + } + for _, file := range files { + if strings.HasSuffix(file, "_test.go") { + continue + } + // Could use something complex like go/build or x/tools/go/packages, + // but there's no reason for "net/http" to appear (in quotes) in the source + // otherwise, so just use a simple substring search. + data, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + if bytes.Contains(data, []byte(`"net/http"`)) { + t.Errorf(`%s: cannot import "net/http"`, file) + } + } +} diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go index 3439147738..4b70553179 100644 --- a/internal/httpcommon/request.go +++ b/internal/httpcommon/request.go @@ -5,10 +5,12 @@ package httpcommon import ( + "context" "errors" "fmt" - "net/http" "net/http/httptrace" + "net/textproto" + "net/url" "sort" "strconv" "strings" @@ -21,9 +23,21 @@ var ( ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit") ) +// Request is a subset of http.Request. +// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http +// without creating a dependency cycle. +type Request struct { + URL *url.URL + Method string + Host string + Header map[string][]string + Trailer map[string][]string + ActualContentLength int64 // 0 means 0, -1 means unknown +} + // EncodeHeadersParam is parameters to EncodeHeaders. type EncodeHeadersParam struct { - Request *http.Request + Request Request // AddGzipHeader indicates that an "accept-encoding: gzip" header should be // added to the request. @@ -47,11 +61,11 @@ type EncodeHeadersResult struct { // It validates a request and calls headerf with each pseudo-header and header // for the request. // The headerf function is called with the validated, canonicalized header name. -func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) { +func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) { req := param.Request // Check for invalid connection-level headers. - if err := checkConnHeaders(req); err != nil { + if err := checkConnHeaders(req.Header); err != nil { return res, err } @@ -73,7 +87,10 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( // isNormalConnect is true if this is a non-extended CONNECT request. isNormalConnect := false - protocol := req.Header.Get(":protocol") + var protocol string + if vv := req.Header[":protocol"]; len(vv) > 0 { + protocol = vv[0] + } if req.Method == "CONNECT" && protocol == "" { isNormalConnect = true } else if protocol != "" && req.Method != "CONNECT" { @@ -107,9 +124,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( return res, fmt.Errorf("invalid HTTP trailer %s", err) } - contentLength := ActualContentLength(req) - - trailers, err := commaSeparatedTrailers(req) + trailers, err := commaSeparatedTrailers(req.Trailer) if err != nil { return res, err } @@ -123,7 +138,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( f(":authority", host) m := req.Method if m == "" { - m = http.MethodGet + m = "GET" } f(":method", m) if !isNormalConnect { @@ -198,8 +213,8 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( f(k, v) } } - if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + if shouldSendReqContentLength(req.Method, req.ActualContentLength) { + f("content-length", strconv.FormatInt(req.ActualContentLength, 10)) } if param.AddGzipHeader { f("accept-encoding", "gzip") @@ -225,7 +240,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( } } - trace := httptrace.ContextClientTrace(req.Context()) + trace := httptrace.ContextClientTrace(ctx) // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { @@ -243,19 +258,19 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( } }) - res.HasBody = contentLength != 0 + res.HasBody = req.ActualContentLength != 0 res.HasTrailers = trailers != "" return res, nil } // IsRequestGzip reports whether we should add an Accept-Encoding: gzip header // for a request. -func IsRequestGzip(req *http.Request, disableCompression bool) bool { +func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool { // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !disableCompression && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - req.Method != "HEAD" { + len(header["Accept-Encoding"]) == 0 && + len(header["Range"]) == 0 && + method != "HEAD" { // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: https://zlib.net/zlib_faq.html#faq39 @@ -280,22 +295,22 @@ func IsRequestGzip(req *http.Request, disableCompression bool) bool { // // Certain headers are special-cased as okay but not transmitted later. // For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding. -func checkConnHeaders(req *http.Request) error { - if v := req.Header.Get("Upgrade"); v != "" { - return fmt.Errorf("invalid Upgrade request header: %q", req.Header["Upgrade"]) +func checkConnHeaders(h map[string][]string) error { + if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("invalid Upgrade request header: %q", vv) } - if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv) } - if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) { + if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) { return fmt.Errorf("invalid Connection request header: %q", vv) } return nil } -func commaSeparatedTrailers(req *http.Request) (string, error) { - keys := make([]string, 0, len(req.Trailer)) - for k := range req.Trailer { +func commaSeparatedTrailers(trailer map[string][]string) (string, error) { + keys := make([]string, 0, len(trailer)) + for k := range trailer { k = CanonicalHeader(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": @@ -310,19 +325,6 @@ func commaSeparatedTrailers(req *http.Request) (string, error) { return "", nil } -// ActualContentLength returns a sanitized version of -// req.ContentLength, where 0 actually means zero (not unknown) and -1 -// means unknown. -func ActualContentLength(req *http.Request) int64 { - if req.Body == nil || req.Body == http.NoBody { - return 0 - } - if req.ContentLength != 0 { - return req.ContentLength - } - return -1 -} - // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // @@ -340,7 +342,7 @@ func validPseudoPath(v string) bool { return (len(v) > 0 && v[0] == '/') || v == "*" } -func validateHeaders(hdrs http.Header) string { +func validateHeaders(hdrs map[string][]string) string { for k, vv := range hdrs { if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" { return fmt.Sprintf("name %q", k) @@ -377,3 +379,89 @@ func shouldSendReqContentLength(method string, contentLength int64) bool { return false } } + +// ServerRequestParam is parameters to NewServerRequest. +type ServerRequestParam struct { + Method string + Scheme, Authority, Path string + Protocol string + Header map[string][]string +} + +// ServerRequestResult is the result of NewServerRequest. +type ServerRequestResult struct { + // Various http.Request fields. + URL *url.URL + RequestURI string + Trailer map[string][]string + + NeedsContinue bool // client provided an "Expect: 100-continue" header + + // If the request should be rejected, this is a short string suitable for passing + // to the http2 package's CountError function. + // It might be a bit odd to return errors this way rather than returing an error, + // but this ensures we don't forget to include a CountError reason. + InvalidReason string +} + +func NewServerRequest(rp ServerRequestParam) ServerRequestResult { + needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue") + if needsContinue { + delete(rp.Header, "Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.Header["Cookie"]; len(cookies) > 1 { + rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")} + } + + // Setup Trailers + var trailer map[string][]string + for _, v := range rp.Header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(map[string][]string) + } + trailer[key] = nil + } + } + } + delete(rp.Header, "Trailer") + + // "':authority' MUST NOT include the deprecated userinfo subcomponent + // for "http" or "https" schemed URIs." + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8 + if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") { + return ServerRequestResult{ + InvalidReason: "userinfo_in_authority", + } + } + + var url_ *url.URL + var requestURI string + if rp.Method == "CONNECT" && rp.Protocol == "" { + url_ = &url.URL{Host: rp.Authority} + requestURI = rp.Authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.Path) + if err != nil { + return ServerRequestResult{ + InvalidReason: "bad_path", + } + } + requestURI = rp.Path + } + + return ServerRequestResult{ + URL: url_, + NeedsContinue: needsContinue, + RequestURI: requestURI, + Trailer: trailer, + } +} diff --git a/internal/httpcommon/request_test.go b/internal/httpcommon/request_test.go index b453983e0d..b8792977c1 100644 --- a/internal/httpcommon/request_test.go +++ b/internal/httpcommon/request_test.go @@ -6,6 +6,7 @@ package httpcommon import ( "cmp" + "context" "io" "net/http" "slices" @@ -27,9 +28,9 @@ func TestEncodeHeaders(t *testing.T) { }{{ name: "simple request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("GET", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -47,12 +48,12 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "host set from URL", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Host = "" req.URL.Host = "example.tld" return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -70,11 +71,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "chunked transfer-encoding", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Transfer-Encoding", "chunked") // ignored return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -92,11 +93,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "connection close", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "close") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -114,11 +115,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "connection keep-alive", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "keep-alive") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -136,9 +137,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "normal connect", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("CONNECT", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -154,11 +155,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "extended connect", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("CONNECT", "https://example.tld/", nil)) req.Header.Set(":protocol", "foo") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -177,13 +178,13 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "trailers", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("a", "1") req.Trailer.Set("b", "2") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -202,11 +203,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "override user-agent", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("User-Agent", "GopherTron 9000") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -224,11 +225,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "disable user-agent", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header["User-Agent"] = nil return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -245,11 +246,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "ignore host header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Host", "gophers.tld/") // ignored return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -267,11 +268,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "crumble cookie header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Cookie", "a=b; b=c; c=d") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -293,9 +294,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with nil body", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("POST", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -314,9 +315,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with NoBody", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("POST", "https://example.tld/", http.NoBody)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -335,12 +336,12 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with Content-Length", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { type reader struct{ io.ReadCloser } req := must(http.NewRequest("POST", "https://example.tld/", reader{})) req.ContentLength = 10 return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -359,11 +360,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with unknown Content-Length", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { type reader struct{ io.ReadCloser } req := must(http.NewRequest("POST", "https://example.tld/", reader{})) return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -381,11 +382,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "explicit accept-encoding", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Accept-Encoding", "deflate") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -403,9 +404,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "head request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("HEAD", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -422,11 +423,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "range request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("HEAD", "https://example.tld/", nil)) req.Header.Set("Range", "bytes=0-10") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -444,11 +445,11 @@ func TestEncodeHeaders(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { var gotHeaders []header - if IsRequestGzip(test.in.Request, test.disableCompression) { + if IsRequestGzip(test.in.Request.Method, test.in.Request.Header, test.disableCompression) { test.in.AddGzipHeader = true } - got, err := EncodeHeaders(test.in, func(name, value string) { + got, err := EncodeHeaders(context.Background(), test.in, func(name, value string) { gotHeaders = append(gotHeaders, header{name, value}) }) if err != nil { @@ -490,151 +491,151 @@ func TestEncodeHeaderErrors(t *testing.T) { }{{ name: "URL is nil", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.URL = nil return req - }(), + }), }, want: "URL is nil", }, { name: "upgrade header is set", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Upgrade", "foo") return req - }(), + }), }, want: "Upgrade", }, { name: "unsupported transfer-encoding header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Transfer-Encoding", "identity") return req - }(), + }), }, want: "Transfer-Encoding", }, { name: "unsupported connection header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "x") return req - }(), + }), }, want: "Connection", }, { name: "invalid host", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Host = "\x00.tld" return req - }(), + }), }, want: "Host", }, { name: "protocol header is set", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set(":protocol", "foo") return req - }(), + }), }, want: ":protocol", }, { name: "invalid path", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.URL.Path = "no_leading_slash" return req - }(), + }), }, want: "path", }, { name: "invalid header name", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("x\ny", "foo") return req - }(), + }), }, want: "header", }, { name: "invalid header value", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("x", "foo\nbar") return req - }(), + }), }, want: "header", }, { name: "invalid trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("x\ny", "foo") return req - }(), + }), }, want: "trailer", }, { name: "transfer-encoding trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Transfer-Encoding", "chunked") return req - }(), + }), }, want: "Trailer", }, { name: "trailer trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Trailer", "chunked") return req - }(), + }), }, want: "Trailer", }, { name: "content-length trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Content-Length", "0") return req - }(), + }), }, want: "Trailer", }, { name: "too many headers", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("X-Foo", strings.Repeat("x", 1000)) return req - }(), + }), PeerMaxHeaderListSize: 1000, }, want: "limit", }} { t.Run(test.name, func(t *testing.T) { - _, err := EncodeHeaders(test.in, func(name, value string) {}) + _, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {}) if err == nil { t.Fatalf("EncodeHeaders = nil, want %q", test.want) } @@ -645,48 +646,27 @@ func TestEncodeHeaderErrors(t *testing.T) { } } +func newReq(f func() *http.Request) Request { + req := f() + contentLength := req.ContentLength + if req.Body == nil || req.Body == http.NoBody { + contentLength = 0 + } else if contentLength == 0 { + contentLength = -1 + } + return Request{ + Header: req.Header, + Trailer: req.Trailer, + URL: req.URL, + Host: req.Host, + Method: req.Method, + ActualContentLength: contentLength, + } +} + func must[T any](v T, err error) T { if err != nil { panic(err) } return v } - -type panicReader struct{} - -func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } -func (panicReader) Close() error { panic("unexpected Close") } - -func TestActualContentLength(t *testing.T) { - tests := []struct { - req *http.Request - want int64 - }{ - // Verify we don't read from Body: - 0: { - req: &http.Request{Body: panicReader{}}, - want: -1, - }, - // nil Body means 0, regardless of ContentLength: - 1: { - req: &http.Request{Body: nil, ContentLength: 5}, - want: 0, - }, - // ContentLength is used if set. - 2: { - req: &http.Request{Body: panicReader{}, ContentLength: 5}, - want: 5, - }, - // http.NoBody means 0, not -1. - 3: { - req: &http.Request{Body: http.NoBody}, - want: 0, - }, - } - for i, tt := range tests { - got := ActualContentLength(tt.req) - if got != tt.want { - t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) - } - } -} diff --git a/proxy/per_host.go b/proxy/per_host.go index d7d4b8b6e3..32bdf435ec 100644 --- a/proxy/per_host.go +++ b/proxy/per_host.go @@ -7,6 +7,7 @@ package proxy import ( "context" "net" + "net/netip" "strings" ) @@ -57,7 +58,8 @@ func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net. } func (p *PerHost) dialerForRequest(host string) Dialer { - if ip := net.ParseIP(host); ip != nil { + if nip, err := netip.ParseAddr(host); err == nil { + ip := net.IP(nip.AsSlice()) for _, net := range p.bypassNetworks { if net.Contains(ip) { return p.bypass @@ -108,8 +110,8 @@ func (p *PerHost) AddFromString(s string) { } continue } - if ip := net.ParseIP(host); ip != nil { - p.AddIP(ip) + if nip, err := netip.ParseAddr(host); err == nil { + p.AddIP(net.IP(nip.AsSlice())) continue } if strings.HasPrefix(host, "*.") { diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go index 0447eb427a..b7bcec8ae3 100644 --- a/proxy/per_host_test.go +++ b/proxy/per_host_test.go @@ -7,8 +7,9 @@ package proxy import ( "context" "errors" + "fmt" "net" - "reflect" + "slices" "testing" ) @@ -22,55 +23,118 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) { } func TestPerHost(t *testing.T) { - expectedDef := []string{ - "example.com:123", - "1.2.3.4:123", - "[1001::]:123", - } - expectedBypass := []string{ - "localhost:123", - "zone:123", - "foo.zone:123", - "127.0.0.1:123", - "10.1.2.3:123", - "[1000::]:123", - } - - t.Run("Dial", func(t *testing.T) { - var def, bypass recordingProxy - perHost := NewPerHost(&def, &bypass) - perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") - for _, addr := range expectedDef { - perHost.Dial("tcp", addr) + for _, test := range []struct { + config string // passed to PerHost.AddFromString + nomatch []string // addrs using the default dialer + match []string // addrs using the bypass dialer + }{{ + config: "localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16", + nomatch: []string{ + "example.com:123", + "1.2.3.4:123", + "[1001::]:123", + }, + match: []string{ + "localhost:123", + "zone:123", + "foo.zone:123", + "127.0.0.1:123", + "10.1.2.3:123", + "[1000::]:123", + "[1000::%25.example.com]:123", + }, + }, { + config: "localhost", + nomatch: []string{ + "127.0.0.1:80", + }, + match: []string{ + "localhost:80", + }, + }, { + config: "*.zone", + nomatch: []string{ + "foo.com:80", + }, + match: []string{ + "foo.zone:80", + "foo.bar.zone:80", + }, + }, { + config: "1.2.3.4", + nomatch: []string{ + "127.0.0.1:80", + "11.2.3.4:80", + }, + match: []string{ + "1.2.3.4:80", + }, + }, { + config: "10.0.0.0/24", + nomatch: []string{ + "10.0.1.1:80", + }, + match: []string{ + "10.0.0.1:80", + "10.0.0.255:80", + }, + }, { + config: "fe80::/10", + nomatch: []string{ + "[fec0::1]:80", + "[fec0::1%en0]:80", + }, + match: []string{ + "[fe80::1]:80", + "[fe80::1%en0]:80", + }, + }, { + // We don't allow zone IDs in network prefixes, + // so this config matches nothing. + config: "fe80::%en0/10", + nomatch: []string{ + "[fec0::1]:80", + "[fec0::1%en0]:80", + "[fe80::1]:80", + "[fe80::1%en0]:80", + "[fe80::1%en1]:80", + }, + }} { + for _, addr := range test.match { + testPerHost(t, test.config, addr, true) } - for _, addr := range expectedBypass { - perHost.Dial("tcp", addr) + for _, addr := range test.nomatch { + testPerHost(t, test.config, addr, false) } + } +} - if !reflect.DeepEqual(expectedDef, def.addrs) { - t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) - } - if !reflect.DeepEqual(expectedBypass, bypass.addrs) { - t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) - } - }) +func testPerHost(t *testing.T, config, addr string, wantMatch bool) { + name := fmt.Sprintf("config %q, dial %q", config, addr) - t.Run("DialContext", func(t *testing.T) { - var def, bypass recordingProxy - perHost := NewPerHost(&def, &bypass) - perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") - for _, addr := range expectedDef { - perHost.DialContext(context.Background(), "tcp", addr) - } - for _, addr := range expectedBypass { - perHost.DialContext(context.Background(), "tcp", addr) - } + var def, bypass recordingProxy + perHost := NewPerHost(&def, &bypass) + perHost.AddFromString(config) + perHost.Dial("tcp", addr) - if !reflect.DeepEqual(expectedDef, def.addrs) { - t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) - } - if !reflect.DeepEqual(expectedBypass, bypass.addrs) { - t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) - } - }) + // Dial and DialContext should have the same results. + var defc, bypassc recordingProxy + perHostc := NewPerHost(&defc, &bypassc) + perHostc.AddFromString(config) + perHostc.DialContext(context.Background(), "tcp", addr) + if !slices.Equal(def.addrs, defc.addrs) { + t.Errorf("%v: Dial default=%v, bypass=%v; DialContext default=%v, bypass=%v", name, def.addrs, bypass.addrs, defc.addrs, bypass.addrs) + return + } + + if got, want := slices.Concat(def.addrs, bypass.addrs), []string{addr}; !slices.Equal(got, want) { + t.Errorf("%v: dialed %q, want %q", name, got, want) + return + } + + gotMatch := len(bypass.addrs) > 0 + if gotMatch != wantMatch { + t.Errorf("%v: matched=%v, want %v", name, gotMatch, wantMatch) + return + } } diff --git a/publicsuffix/gen.go b/publicsuffix/gen.go index 7f7d08dbc2..5f454e57e9 100644 --- a/publicsuffix/gen.go +++ b/publicsuffix/gen.go @@ -21,6 +21,7 @@ package main import ( "bufio" "bytes" + "cmp" "encoding/binary" "flag" "fmt" @@ -29,7 +30,7 @@ import ( "net/http" "os" "regexp" - "sort" + "slices" "strings" "golang.org/x/net/idna" @@ -62,20 +63,6 @@ var ( maxLo uint32 ) -func max(a, b int) int { - if a < b { - return b - } - return a -} - -func u32max(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - const ( nodeTypeNormal = 0 nodeTypeException = 1 @@ -83,18 +70,6 @@ const ( numNodeType = 3 ) -func nodeTypeStr(n int) string { - switch n { - case nodeTypeNormal: - return "+" - case nodeTypeException: - return "!" - case nodeTypeParentOnly: - return "o" - } - panic("unreachable") -} - const ( defaultURL = "https://publicsuffix.org/list/effective_tld_names.dat" gitCommitURL = "https://api.github.com/repos/publicsuffix/list/commits?path=public_suffix_list.dat" @@ -251,7 +226,7 @@ func main1() error { for label := range labelsMap { labelsList = append(labelsList, label) } - sort.Strings(labelsList) + slices.Sort(labelsList) combinedText = combineText(labelsList) if combinedText == "" { @@ -509,15 +484,13 @@ func (n *node) child(label string) *node { icann: true, } n.children = append(n.children, c) - sort.Sort(byLabel(n.children)) + slices.SortFunc(n.children, byLabel) return c } -type byLabel []*node - -func (b byLabel) Len() int { return len(b) } -func (b byLabel) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label } +func byLabel(a, b *node) int { + return strings.Compare(a.label, b.label) +} var nextNodesIndex int @@ -557,7 +530,7 @@ func assignIndexes(n *node) error { n.childrenIndex = len(childrenEncoding) lo := uint32(n.firstChild) hi := lo + uint32(len(n.children)) - maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi) + maxLo, maxHi = max(maxLo, lo), max(maxHi, hi) if lo >= 1<<childrenBitsLo { return fmt.Errorf("children lo %d is too large, or childrenBitsLo is too small", lo) } @@ -586,20 +559,6 @@ func printNodeLabel(w io.Writer, n *node) error { return nil } -func icannStr(icann bool) string { - if icann { - return "I" - } - return " " -} - -func wildcardStr(wildcard bool) string { - if wildcard { - return "*" - } - return " " -} - // combineText combines all the strings in labelsList to form one giant string. // Overlapping strings will be merged: "arpa" and "parliament" could yield // "arparliament". @@ -616,18 +575,15 @@ func combineText(labelsList []string) string { return text } -type byLength []string - -func (s byLength) Len() int { return len(s) } -func (s byLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byLength) Less(i, j int) bool { return len(s[i]) < len(s[j]) } +func byLength(a, b string) int { + return cmp.Compare(len(a), len(b)) +} // removeSubstrings returns a copy of its input with any strings removed // that are substrings of other provided strings. func removeSubstrings(input []string) []string { - // Make a copy of input. - ss := append(make([]string, 0, len(input)), input...) - sort.Sort(byLength(ss)) + ss := slices.Clone(input) + slices.SortFunc(ss, byLength) for i, shortString := range ss { // For each string, only consider strings higher than it in sort order, i.e. @@ -641,7 +597,7 @@ func removeSubstrings(input []string) []string { } // Remove the empty strings. - sort.Strings(ss) + slices.Sort(ss) for len(ss) > 0 && ss[0] == "" { ss = ss[1:] } diff --git a/publicsuffix/list.go b/publicsuffix/list.go index d56e9e7624..56069d0429 100644 --- a/publicsuffix/list.go +++ b/publicsuffix/list.go @@ -88,7 +88,7 @@ func PublicSuffix(domain string) (publicSuffix string, icann bool) { s, suffix, icannNode, wildcard := domain, len(domain), false, false loop: for { - dot := strings.LastIndex(s, ".") + dot := strings.LastIndexByte(s, '.') if wildcard { icann = icannNode suffix = 1 + dot @@ -129,7 +129,7 @@ loop: } if suffix == len(domain) { // If no rules match, the prevailing rule is "*". - return domain[1+strings.LastIndex(domain, "."):], icann + return domain[1+strings.LastIndexByte(domain, '.'):], icann } return domain[suffix:], icann } @@ -178,26 +178,28 @@ func EffectiveTLDPlusOne(domain string) (string, error) { if domain[i] != '.' { return "", fmt.Errorf("publicsuffix: invalid public suffix %q for domain %q", suffix, domain) } - return domain[1+strings.LastIndex(domain[:i], "."):], nil + return domain[1+strings.LastIndexByte(domain[:i], '.'):], nil } type uint32String string func (u uint32String) get(i uint32) uint32 { off := i * 4 - return (uint32(u[off])<<24 | - uint32(u[off+1])<<16 | - uint32(u[off+2])<<8 | - uint32(u[off+3])) + u = u[off:] // help the compiler reduce bounds checks + return uint32(u[3]) | + uint32(u[2])<<8 | + uint32(u[1])<<16 | + uint32(u[0])<<24 } type uint40String string func (u uint40String) get(i uint32) uint64 { off := uint64(i * (nodesBits / 8)) - return uint64(u[off])<<32 | - uint64(u[off+1])<<24 | - uint64(u[off+2])<<16 | - uint64(u[off+3])<<8 | - uint64(u[off+4]) + u = u[off:] // help the compiler reduce bounds checks + return uint64(u[4]) | + uint64(u[3])<<8 | + uint64(u[2])<<16 | + uint64(u[1])<<24 | + uint64(u[0])<<32 } diff --git a/quic/conn.go b/quic/conn.go index bf54409bfe..1f1cfa6d0a 100644 --- a/quic/conn.go +++ b/quic/conn.go @@ -186,6 +186,11 @@ func (c *Conn) RemoteAddr() netip.AddrPort { return c.peerAddr } +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() tls.ConnectionState { + return c.tls.ConnectionState() +} + // confirmHandshake is called when the handshake is confirmed. // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 func (c *Conn) confirmHandshake(now time.Time) { diff --git a/route/address.go b/route/address.go index 279505b109..492838a7fe 100644 --- a/route/address.go +++ b/route/address.go @@ -396,13 +396,19 @@ func marshalAddrs(b []byte, as []Addr) (uint, error) { func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ([]Addr, error) { var as [syscall.RTAX_MAX]Addr af := int(syscall.AF_UNSPEC) + isInet := func(fam int) bool { + return fam == syscall.AF_INET || fam == syscall.AF_INET6 + } + isMask := func(addrType uint) bool { + return addrType == syscall.RTAX_NETMASK || addrType == syscall.RTAX_GENMASK + } for i := uint(0); i < syscall.RTAX_MAX && len(b) >= roundup(0); i++ { if attrs&(1<<i) == 0 { continue } if i <= syscall.RTAX_BRD { - switch b[1] { - case syscall.AF_LINK: + switch { + case b[1] == syscall.AF_LINK: a, err := parseLinkAddr(b) if err != nil { return nil, err @@ -413,8 +419,10 @@ func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ( return nil, errMessageTooShort } b = b[l:] - case syscall.AF_INET, syscall.AF_INET6: - af = int(b[1]) + case isInet(int(b[1])) || (isMask(i) && isInet(af)): + if isInet(int(b[1])) { + af = int(b[1]) + } a, err := parseInetAddr(af, b) if err != nil { return nil, err diff --git a/route/address_darwin_test.go b/route/address_darwin_test.go index 80f686e97d..e7e666ab30 100644 --- a/route/address_darwin_test.go +++ b/route/address_darwin_test.go @@ -29,12 +29,12 @@ var parseAddrsOnDarwinLittleEndianTests = []parseAddrsOnDarwinTest{ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, }, []Addr{ &Inet4Addr{IP: [4]byte{192, 168, 86, 0}}, &LinkAddr{Index: 4}, - &Inet4Addr{IP: [4]byte{255, 255, 255, 255}}, + &Inet4Addr{IP: [4]byte{255, 255, 255, 0}}, nil, nil, nil, diff --git a/route/example_darwin_test.go b/route/example_darwin_test.go new file mode 100644 index 0000000000..e442c3ecf7 --- /dev/null +++ b/route/example_darwin_test.go @@ -0,0 +1,70 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package route_test + +import ( + "fmt" + "net/netip" + "os" + "syscall" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +// This example demonstrates how to parse a response to RTM_GET request. +func ExampleParseRIB() { + fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return + } + defer unix.Close(fd) + + // Create a RouteMessage with RTM_GET type + rtm := &route.RouteMessage{ + Version: syscall.RTM_VERSION, + Type: unix.RTM_GET, + ID: uintptr(os.Getpid()), + Seq: 0, + Addrs: []route.Addr{ + &route.Inet4Addr{IP: [4]byte{127, 0, 0, 0}}, + }, + } + + // Marshal the message into bytes + msgBytes, err := rtm.Marshal() + if err != nil { + return + } + + // Send the message over the routing socket + _, err = unix.Write(fd, msgBytes) + if err != nil { + return + } + + // Read the response from the routing socket + var buf [2 << 10]byte + n, err := unix.Read(fd, buf[:]) + if err != nil { + return + } + + // Parse the response messages + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n]) + if err != nil { + return + } + routeMsg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return + } + netmask, ok := routeMsg.Addrs[2].(*route.Inet4Addr) + if !ok { + return + } + fmt.Println(netip.AddrFrom4(netmask.IP)) + // Output: 255.0.0.0 +}