From cd9d6616c0e47be812dd5ccc12c26ee21c7a3d36 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez <carlos@hrndz.ca> Date: Mon, 10 Feb 2025 19:47:21 +0000 Subject: [PATCH 01/17] route: fix RTM_GET netmask parsing on Darwin On Darwin, the AF_FAMILY byte of a sockaddr for a netmask or genmask can be ignored if unreasonable. In such cases, it is the family of the DST address that should instead be used. Additionally, fixing faulty test data. 192.168.86.0 is a Class C network address, that should have a subnet mask of 255.255.255.0. What's more is the data can also be flag as incorrect considering structure padding rules alone. Further more, you can validate that `route get` will never actually return a netmask for a host query, even though it should be 255.255.255.255. You can run the following to check: route -n get -host 127.0.0.1 You will note the reply has no mention of netmask. Depends on CL 646556 - https://go.dev/cl/646556 Fixes golang/go#71578. Change-Id: Id95669b649a416a380d26c5cdba0e3d1c4bc1ffb GitHub-Last-Rev: 20064b27979b0244db23b4d1f38c5e40479df166 GitHub-Pull-Request: golang/net#232 Reviewed-on: https://go-review.googlesource.com/c/net/+/647176 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Ian Lance Taylor <iant@google.com> Auto-Submit: Ian Lance Taylor <iant@google.com> Commit-Queue: Ian Lance Taylor <iant@google.com> Reviewed-by: Damien Neil <dneil@google.com> --- route/address.go | 16 ++++++--- route/address_darwin_test.go | 4 +-- route/example_darwin_test.go | 70 ++++++++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+), 6 deletions(-) create mode 100644 route/example_darwin_test.go diff --git a/route/address.go b/route/address.go index 279505b10..492838a7f 100644 --- a/route/address.go +++ b/route/address.go @@ -396,13 +396,19 @@ func marshalAddrs(b []byte, as []Addr) (uint, error) { func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ([]Addr, error) { var as [syscall.RTAX_MAX]Addr af := int(syscall.AF_UNSPEC) + isInet := func(fam int) bool { + return fam == syscall.AF_INET || fam == syscall.AF_INET6 + } + isMask := func(addrType uint) bool { + return addrType == syscall.RTAX_NETMASK || addrType == syscall.RTAX_GENMASK + } for i := uint(0); i < syscall.RTAX_MAX && len(b) >= roundup(0); i++ { if attrs&(1<<i) == 0 { continue } if i <= syscall.RTAX_BRD { - switch b[1] { - case syscall.AF_LINK: + switch { + case b[1] == syscall.AF_LINK: a, err := parseLinkAddr(b) if err != nil { return nil, err @@ -413,8 +419,10 @@ func parseAddrs(attrs uint, fn func(int, []byte) (int, Addr, error), b []byte) ( return nil, errMessageTooShort } b = b[l:] - case syscall.AF_INET, syscall.AF_INET6: - af = int(b[1]) + case isInet(int(b[1])) || (isMask(i) && isInet(af)): + if isInet(int(b[1])) { + af = int(b[1]) + } a, err := parseInetAddr(af, b) if err != nil { return nil, err diff --git a/route/address_darwin_test.go b/route/address_darwin_test.go index 80f686e97..e7e666ab3 100644 --- a/route/address_darwin_test.go +++ b/route/address_darwin_test.go @@ -29,12 +29,12 @@ var parseAddrsOnDarwinLittleEndianTests = []parseAddrsOnDarwinTest{ 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, - 0x7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x7, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x0, }, []Addr{ &Inet4Addr{IP: [4]byte{192, 168, 86, 0}}, &LinkAddr{Index: 4}, - &Inet4Addr{IP: [4]byte{255, 255, 255, 255}}, + &Inet4Addr{IP: [4]byte{255, 255, 255, 0}}, nil, nil, nil, diff --git a/route/example_darwin_test.go b/route/example_darwin_test.go new file mode 100644 index 000000000..e442c3ecf --- /dev/null +++ b/route/example_darwin_test.go @@ -0,0 +1,70 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package route_test + +import ( + "fmt" + "net/netip" + "os" + "syscall" + + "golang.org/x/net/route" + "golang.org/x/sys/unix" +) + +// This example demonstrates how to parse a response to RTM_GET request. +func ExampleParseRIB() { + fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC) + if err != nil { + return + } + defer unix.Close(fd) + + // Create a RouteMessage with RTM_GET type + rtm := &route.RouteMessage{ + Version: syscall.RTM_VERSION, + Type: unix.RTM_GET, + ID: uintptr(os.Getpid()), + Seq: 0, + Addrs: []route.Addr{ + &route.Inet4Addr{IP: [4]byte{127, 0, 0, 0}}, + }, + } + + // Marshal the message into bytes + msgBytes, err := rtm.Marshal() + if err != nil { + return + } + + // Send the message over the routing socket + _, err = unix.Write(fd, msgBytes) + if err != nil { + return + } + + // Read the response from the routing socket + var buf [2 << 10]byte + n, err := unix.Read(fd, buf[:]) + if err != nil { + return + } + + // Parse the response messages + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf[:n]) + if err != nil { + return + } + routeMsg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return + } + netmask, ok := routeMsg.Addrs[2].(*route.Inet4Addr) + if !ok { + return + } + fmt.Println(netip.AddrFrom4(netmask.IP)) + // Output: 255.0.0.0 +} From 884432780bfdc3f8033af387a3adb3bf4f59fbd3 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Thu, 13 Feb 2025 12:39:30 -0800 Subject: [PATCH 02/17] internal/httpcommon: don't depend on net/http When the http2 package is bundled into net/http, it imports httpcommon, so httpcommon must not depend on net/http. Change-Id: I2aa34e913a0df757fa83deb56f650394a924933e Reviewed-on: https://go-review.googlesource.com/c/net/+/649415 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- http2/transport.go | 42 ++++-- http2/transport_test.go | 58 ++++++-- internal/http3/roundtrip.go | 27 +++- internal/httpcommon/headermap.go | 6 +- internal/httpcommon/httpcommon_test.go | 37 +++++ internal/httpcommon/request.go | 79 +++++----- internal/httpcommon/request_test.go | 196 +++++++++++-------------- 7 files changed, 271 insertions(+), 174 deletions(-) create mode 100644 internal/httpcommon/httpcommon_test.go diff --git a/http2/transport.go b/http2/transport.go index f2c166b61..94b397c69 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -1286,6 +1286,19 @@ func (cc *ClientConn) responseHeaderTimeout() time.Duration { return 0 } +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + func (cc *ClientConn) decrStreamReservations() { cc.mu.Lock() defer cc.mu.Unlock() @@ -1310,7 +1323,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) reqCancel: req.Cancel, isHead: req.Method == "HEAD", reqBody: req.Body, - reqBodyContentLength: httpcommon.ActualContentLength(req), + reqBodyContentLength: actualContentLength(req), trace: httptrace.ContextClientTrace(ctx), peerClosed: make(chan struct{}), abort: make(chan struct{}), @@ -1318,7 +1331,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) donec: make(chan struct{}), } - cs.requestedGzip = httpcommon.IsRequestGzip(req, cc.t.disableCompression()) + cs.requestedGzip = httpcommon.IsRequestGzip(req.Method, req.Header, cc.t.disableCompression()) go cs.doRequest(req, streamf) @@ -1349,7 +1362,7 @@ func (cc *ClientConn) roundTrip(req *http.Request, streamf func(*clientStream)) } res.Request = req res.TLS = cc.tlsState - if res.Body == noBody && httpcommon.ActualContentLength(req) == 0 { + if res.Body == noBody && actualContentLength(req) == 0 { // If there isn't a request or response body still being // written, then wait for the stream to be closed before // RoundTrip returns. @@ -1596,12 +1609,7 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { // sent by writeRequestBody below, along with any Trailers, // again in form HEADERS{1}, CONTINUATION{0,}) cc.hbuf.Reset() - res, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, - AddGzipHeader: cs.requestedGzip, - PeerMaxHeaderListSize: cc.peerMaxHeaderListSize, - DefaultUserAgent: defaultUserAgent, - }, func(name, value string) { + res, err := encodeRequestHeaders(req, cs.requestedGzip, cc.peerMaxHeaderListSize, func(name, value string) { cc.writeHeader(name, value) }) if err != nil { @@ -1617,6 +1625,22 @@ func (cs *clientStream) encodeAndWriteHeaders(req *http.Request) error { return err } +func encodeRequestHeaders(req *http.Request, addGzipHeader bool, peerMaxHeaderListSize uint64, headerf func(name, value string)) (httpcommon.EncodeHeadersResult, error) { + return httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + Header: req.Header, + Trailer: req.Trailer, + URL: req.URL, + Host: req.Host, + Method: req.Method, + ActualContentLength: actualContentLength(req), + }, + AddGzipHeader: addGzipHeader, + PeerMaxHeaderListSize: peerMaxHeaderListSize, + DefaultUserAgent: defaultUserAgent, + }, headerf) +} + // cleanupWriteRequest performs post-request tasks. // // If err (the result of writeRequest) is non-nil and the stream is not closed, diff --git a/http2/transport_test.go b/http2/transport_test.go index 47eac2fa8..d1d27f8f9 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -36,7 +36,6 @@ import ( "time" "golang.org/x/net/http2/hpack" - "golang.org/x/net/internal/httpcommon" ) var ( @@ -571,6 +570,45 @@ func randString(n int) string { return string(b) } +type panicReader struct{} + +func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } +func (panicReader) Close() error { panic("unexpected Close") } + +func TestActualContentLength(t *testing.T) { + tests := []struct { + req *http.Request + want int64 + }{ + // Verify we don't read from Body: + 0: { + req: &http.Request{Body: panicReader{}}, + want: -1, + }, + // nil Body means 0, regardless of ContentLength: + 1: { + req: &http.Request{Body: nil, ContentLength: 5}, + want: 0, + }, + // ContentLength is used if set. + 2: { + req: &http.Request{Body: panicReader{}, ContentLength: 5}, + want: 5, + }, + // http.NoBody means 0, not -1. + 3: { + req: &http.Request{Body: http.NoBody}, + want: 0, + }, + } + for i, tt := range tests { + got := actualContentLength(tt.req) + if got != tt.want { + t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) + } + } +} + func TestTransportBody(t *testing.T) { bodyTests := []struct { body string @@ -1405,12 +1443,9 @@ func TestTransportChecksRequestHeaderListSize(t *testing.T) { } } headerListSizeForRequest := func(req *http.Request) (size uint64) { - _, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, - AddGzipHeader: true, - PeerMaxHeaderListSize: 0xffffffffffffffff, - DefaultUserAgent: defaultUserAgent, - }, func(name, value string) { + const addGzipHeader = true + const peerMaxHeaderListSize = 0xffffffffffffffff + _, err := encodeRequestHeaders(req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { hf := hpack.HeaderField{Name: name, Value: value} size += uint64(hf.Size()) }) @@ -2808,11 +2843,10 @@ func TestTransportRequestPathPseudo(t *testing.T) { for i, tt := range tests { hbuf := &bytes.Buffer{} henc := hpack.NewEncoder(hbuf) - _, err := httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: tt.req, - AddGzipHeader: false, - PeerMaxHeaderListSize: 0xffffffffffffffff, - }, func(name, value string) { + + const addGzipHeader = false + const peerMaxHeaderListSize = 0xffffffffffffffff + _, err := encodeRequestHeaders(tt.req, addGzipHeader, peerMaxHeaderListSize, func(name, value string) { henc.WriteField(hpack.HeaderField{Name: name, Value: value}) }) hdrs := hbuf.Bytes() diff --git a/internal/http3/roundtrip.go b/internal/http3/roundtrip.go index b24a30308..bf55a1315 100644 --- a/internal/http3/roundtrip.go +++ b/internal/http3/roundtrip.go @@ -82,10 +82,19 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) st.stream.SetReadContext(req.Context()) st.stream.SetWriteContext(req.Context()) + contentLength := actualContentLength(req) + var encr httpcommon.EncodeHeadersResult headers := cc.enc.encode(func(yield func(itype indexType, name, value string)) { - encr, err = httpcommon.EncodeHeaders(httpcommon.EncodeHeadersParam{ - Request: req, + encr, err = httpcommon.EncodeHeaders(req.Context(), httpcommon.EncodeHeadersParam{ + Request: httpcommon.Request{ + URL: req.URL, + Method: req.Method, + Host: req.Host, + Header: req.Header, + Trailer: req.Trailer, + ActualContentLength: contentLength, + }, AddGzipHeader: false, // TODO: add when appropriate PeerMaxHeaderListSize: 0, DefaultUserAgent: "Go-http-client/3", @@ -110,7 +119,7 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) // TODO: Defer sending the request body when "Expect: 100-continue" is set. rt.reqBody = req.Body rt.reqBodyWriter.st = st - rt.reqBodyWriter.remain = httpcommon.ActualContentLength(req) + rt.reqBodyWriter.remain = contentLength rt.reqBodyWriter.flush = true rt.reqBodyWriter.name = "request" go copyRequestBody(rt) @@ -165,6 +174,18 @@ func (cc *ClientConn) RoundTrip(req *http.Request) (_ *http.Response, err error) } } +// actualContentLength returns a sanitized version of req.ContentLength, +// where 0 actually means zero (not unknown) and -1 means unknown. +func actualContentLength(req *http.Request) int64 { + if req.Body == nil || req.Body == http.NoBody { + return 0 + } + if req.ContentLength != 0 { + return req.ContentLength + } + return -1 +} + func copyRequestBody(rt *roundTripState) { defer rt.closeReqBody() _, err := io.Copy(&rt.reqBodyWriter, rt.reqBody) diff --git a/internal/httpcommon/headermap.go b/internal/httpcommon/headermap.go index ad3fbacd6..92483d8e4 100644 --- a/internal/httpcommon/headermap.go +++ b/internal/httpcommon/headermap.go @@ -5,7 +5,7 @@ package httpcommon import ( - "net/http" + "net/textproto" "sync" ) @@ -82,7 +82,7 @@ func buildCommonHeaderMaps() { commonLowerHeader = make(map[string]string, len(common)) commonCanonHeader = make(map[string]string, len(common)) for _, v := range common { - chk := http.CanonicalHeaderKey(v) + chk := textproto.CanonicalMIMEHeaderKey(v) commonLowerHeader[chk] = v commonCanonHeader[v] = chk } @@ -104,7 +104,7 @@ func CanonicalHeader(v string) string { if s, ok := commonCanonHeader[v]; ok { return s } - return http.CanonicalHeaderKey(v) + return textproto.CanonicalMIMEHeaderKey(v) } // CachedCanonicalHeader returns the canonical form of a well-known header name. diff --git a/internal/httpcommon/httpcommon_test.go b/internal/httpcommon/httpcommon_test.go new file mode 100644 index 000000000..e725ec76c --- /dev/null +++ b/internal/httpcommon/httpcommon_test.go @@ -0,0 +1,37 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package httpcommon_test + +import ( + "bytes" + "os" + "path/filepath" + "strings" + "testing" +) + +// This package is imported by the net/http package, +// and therefore must not itself import net/http. +func TestNoNetHttp(t *testing.T) { + files, err := filepath.Glob("*.go") + if err != nil { + t.Fatal(err) + } + for _, file := range files { + if strings.HasSuffix(file, "_test.go") { + continue + } + // Could use something complex like go/build or x/tools/go/packages, + // but there's no reason for "net/http" to appear (in quotes) in the source + // otherwise, so just use a simple substring search. + data, err := os.ReadFile(file) + if err != nil { + t.Fatal(err) + } + if bytes.Contains(data, []byte(`"net/http"`)) { + t.Errorf(`%s: cannot import "net/http"`, file) + } + } +} diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go index 343914773..bec16d0b9 100644 --- a/internal/httpcommon/request.go +++ b/internal/httpcommon/request.go @@ -5,10 +5,11 @@ package httpcommon import ( + "context" "errors" "fmt" - "net/http" "net/http/httptrace" + "net/url" "sort" "strconv" "strings" @@ -21,9 +22,21 @@ var ( ErrRequestHeaderListSize = errors.New("request header list larger than peer's advertised limit") ) +// Request is a subset of http.Request. +// It'd be simpler to pass an *http.Request, of course, but we can't depend on net/http +// without creating a dependency cycle. +type Request struct { + URL *url.URL + Method string + Host string + Header map[string][]string + Trailer map[string][]string + ActualContentLength int64 // 0 means 0, -1 means unknown +} + // EncodeHeadersParam is parameters to EncodeHeaders. type EncodeHeadersParam struct { - Request *http.Request + Request Request // AddGzipHeader indicates that an "accept-encoding: gzip" header should be // added to the request. @@ -47,11 +60,11 @@ type EncodeHeadersResult struct { // It validates a request and calls headerf with each pseudo-header and header // for the request. // The headerf function is called with the validated, canonicalized header name. -func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) { +func EncodeHeaders(ctx context.Context, param EncodeHeadersParam, headerf func(name, value string)) (res EncodeHeadersResult, _ error) { req := param.Request // Check for invalid connection-level headers. - if err := checkConnHeaders(req); err != nil { + if err := checkConnHeaders(req.Header); err != nil { return res, err } @@ -73,7 +86,10 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( // isNormalConnect is true if this is a non-extended CONNECT request. isNormalConnect := false - protocol := req.Header.Get(":protocol") + var protocol string + if vv := req.Header[":protocol"]; len(vv) > 0 { + protocol = vv[0] + } if req.Method == "CONNECT" && protocol == "" { isNormalConnect = true } else if protocol != "" && req.Method != "CONNECT" { @@ -107,9 +123,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( return res, fmt.Errorf("invalid HTTP trailer %s", err) } - contentLength := ActualContentLength(req) - - trailers, err := commaSeparatedTrailers(req) + trailers, err := commaSeparatedTrailers(req.Trailer) if err != nil { return res, err } @@ -123,7 +137,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( f(":authority", host) m := req.Method if m == "" { - m = http.MethodGet + m = "GET" } f(":method", m) if !isNormalConnect { @@ -198,8 +212,8 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( f(k, v) } } - if shouldSendReqContentLength(req.Method, contentLength) { - f("content-length", strconv.FormatInt(contentLength, 10)) + if shouldSendReqContentLength(req.Method, req.ActualContentLength) { + f("content-length", strconv.FormatInt(req.ActualContentLength, 10)) } if param.AddGzipHeader { f("accept-encoding", "gzip") @@ -225,7 +239,7 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( } } - trace := httptrace.ContextClientTrace(req.Context()) + trace := httptrace.ContextClientTrace(ctx) // Header list size is ok. Write the headers. enumerateHeaders(func(name, value string) { @@ -243,19 +257,19 @@ func EncodeHeaders(param EncodeHeadersParam, headerf func(name, value string)) ( } }) - res.HasBody = contentLength != 0 + res.HasBody = req.ActualContentLength != 0 res.HasTrailers = trailers != "" return res, nil } // IsRequestGzip reports whether we should add an Accept-Encoding: gzip header // for a request. -func IsRequestGzip(req *http.Request, disableCompression bool) bool { +func IsRequestGzip(method string, header map[string][]string, disableCompression bool) bool { // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? if !disableCompression && - req.Header.Get("Accept-Encoding") == "" && - req.Header.Get("Range") == "" && - req.Method != "HEAD" { + len(header["Accept-Encoding"]) == 0 && + len(header["Range"]) == 0 && + method != "HEAD" { // Request gzip only, not deflate. Deflate is ambiguous and // not as universally supported anyway. // See: https://zlib.net/zlib_faq.html#faq39 @@ -280,22 +294,22 @@ func IsRequestGzip(req *http.Request, disableCompression bool) bool { // // Certain headers are special-cased as okay but not transmitted later. // For example, we allow "Transfer-Encoding: chunked", but drop the header when encoding. -func checkConnHeaders(req *http.Request) error { - if v := req.Header.Get("Upgrade"); v != "" { - return fmt.Errorf("invalid Upgrade request header: %q", req.Header["Upgrade"]) +func checkConnHeaders(h map[string][]string) error { + if vv := h["Upgrade"]; len(vv) > 0 && (vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("invalid Upgrade request header: %q", vv) } - if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + if vv := h["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { return fmt.Errorf("invalid Transfer-Encoding request header: %q", vv) } - if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) { + if vv := h["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !asciiEqualFold(vv[0], "close") && !asciiEqualFold(vv[0], "keep-alive")) { return fmt.Errorf("invalid Connection request header: %q", vv) } return nil } -func commaSeparatedTrailers(req *http.Request) (string, error) { - keys := make([]string, 0, len(req.Trailer)) - for k := range req.Trailer { +func commaSeparatedTrailers(trailer map[string][]string) (string, error) { + keys := make([]string, 0, len(trailer)) + for k := range trailer { k = CanonicalHeader(k) switch k { case "Transfer-Encoding", "Trailer", "Content-Length": @@ -310,19 +324,6 @@ func commaSeparatedTrailers(req *http.Request) (string, error) { return "", nil } -// ActualContentLength returns a sanitized version of -// req.ContentLength, where 0 actually means zero (not unknown) and -1 -// means unknown. -func ActualContentLength(req *http.Request) int64 { - if req.Body == nil || req.Body == http.NoBody { - return 0 - } - if req.ContentLength != 0 { - return req.ContentLength - } - return -1 -} - // validPseudoPath reports whether v is a valid :path pseudo-header // value. It must be either: // @@ -340,7 +341,7 @@ func validPseudoPath(v string) bool { return (len(v) > 0 && v[0] == '/') || v == "*" } -func validateHeaders(hdrs http.Header) string { +func validateHeaders(hdrs map[string][]string) string { for k, vv := range hdrs { if !httpguts.ValidHeaderFieldName(k) && k != ":protocol" { return fmt.Sprintf("name %q", k) diff --git a/internal/httpcommon/request_test.go b/internal/httpcommon/request_test.go index b453983e0..b8792977c 100644 --- a/internal/httpcommon/request_test.go +++ b/internal/httpcommon/request_test.go @@ -6,6 +6,7 @@ package httpcommon import ( "cmp" + "context" "io" "net/http" "slices" @@ -27,9 +28,9 @@ func TestEncodeHeaders(t *testing.T) { }{{ name: "simple request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("GET", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -47,12 +48,12 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "host set from URL", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Host = "" req.URL.Host = "example.tld" return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -70,11 +71,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "chunked transfer-encoding", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Transfer-Encoding", "chunked") // ignored return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -92,11 +93,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "connection close", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "close") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -114,11 +115,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "connection keep-alive", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "keep-alive") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -136,9 +137,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "normal connect", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("CONNECT", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -154,11 +155,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "extended connect", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("CONNECT", "https://example.tld/", nil)) req.Header.Set(":protocol", "foo") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -177,13 +178,13 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "trailers", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("a", "1") req.Trailer.Set("b", "2") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -202,11 +203,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "override user-agent", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("User-Agent", "GopherTron 9000") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -224,11 +225,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "disable user-agent", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header["User-Agent"] = nil return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -245,11 +246,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "ignore host header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Host", "gophers.tld/") // ignored return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -267,11 +268,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "crumble cookie header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Cookie", "a=b; b=c; c=d") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -293,9 +294,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with nil body", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("POST", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -314,9 +315,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with NoBody", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("POST", "https://example.tld/", http.NoBody)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -335,12 +336,12 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with Content-Length", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { type reader struct{ io.ReadCloser } req := must(http.NewRequest("POST", "https://example.tld/", reader{})) req.ContentLength = 10 return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -359,11 +360,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "post with unknown Content-Length", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { type reader struct{ io.ReadCloser } req := must(http.NewRequest("POST", "https://example.tld/", reader{})) return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -381,11 +382,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "explicit accept-encoding", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Accept-Encoding", "deflate") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -403,9 +404,9 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "head request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { return must(http.NewRequest("HEAD", "https://example.tld/", nil)) - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -422,11 +423,11 @@ func TestEncodeHeaders(t *testing.T) { }, { name: "range request", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("HEAD", "https://example.tld/", nil)) req.Header.Set("Range", "bytes=0-10") return req - }(), + }), DefaultUserAgent: "default-user-agent", }, want: EncodeHeadersResult{ @@ -444,11 +445,11 @@ func TestEncodeHeaders(t *testing.T) { }} { t.Run(test.name, func(t *testing.T) { var gotHeaders []header - if IsRequestGzip(test.in.Request, test.disableCompression) { + if IsRequestGzip(test.in.Request.Method, test.in.Request.Header, test.disableCompression) { test.in.AddGzipHeader = true } - got, err := EncodeHeaders(test.in, func(name, value string) { + got, err := EncodeHeaders(context.Background(), test.in, func(name, value string) { gotHeaders = append(gotHeaders, header{name, value}) }) if err != nil { @@ -490,151 +491,151 @@ func TestEncodeHeaderErrors(t *testing.T) { }{{ name: "URL is nil", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.URL = nil return req - }(), + }), }, want: "URL is nil", }, { name: "upgrade header is set", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Upgrade", "foo") return req - }(), + }), }, want: "Upgrade", }, { name: "unsupported transfer-encoding header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Transfer-Encoding", "identity") return req - }(), + }), }, want: "Transfer-Encoding", }, { name: "unsupported connection header", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("Connection", "x") return req - }(), + }), }, want: "Connection", }, { name: "invalid host", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Host = "\x00.tld" return req - }(), + }), }, want: "Host", }, { name: "protocol header is set", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set(":protocol", "foo") return req - }(), + }), }, want: ":protocol", }, { name: "invalid path", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.URL.Path = "no_leading_slash" return req - }(), + }), }, want: "path", }, { name: "invalid header name", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("x\ny", "foo") return req - }(), + }), }, want: "header", }, { name: "invalid header value", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("x", "foo\nbar") return req - }(), + }), }, want: "header", }, { name: "invalid trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("x\ny", "foo") return req - }(), + }), }, want: "trailer", }, { name: "transfer-encoding trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Transfer-Encoding", "chunked") return req - }(), + }), }, want: "Trailer", }, { name: "trailer trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Trailer", "chunked") return req - }(), + }), }, want: "Trailer", }, { name: "content-length trailer", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Trailer = make(http.Header) req.Trailer.Set("Content-Length", "0") return req - }(), + }), }, want: "Trailer", }, { name: "too many headers", in: EncodeHeadersParam{ - Request: func() *http.Request { + Request: newReq(func() *http.Request { req := must(http.NewRequest("GET", "https://example.tld/", nil)) req.Header.Set("X-Foo", strings.Repeat("x", 1000)) return req - }(), + }), PeerMaxHeaderListSize: 1000, }, want: "limit", }} { t.Run(test.name, func(t *testing.T) { - _, err := EncodeHeaders(test.in, func(name, value string) {}) + _, err := EncodeHeaders(context.Background(), test.in, func(name, value string) {}) if err == nil { t.Fatalf("EncodeHeaders = nil, want %q", test.want) } @@ -645,48 +646,27 @@ func TestEncodeHeaderErrors(t *testing.T) { } } +func newReq(f func() *http.Request) Request { + req := f() + contentLength := req.ContentLength + if req.Body == nil || req.Body == http.NoBody { + contentLength = 0 + } else if contentLength == 0 { + contentLength = -1 + } + return Request{ + Header: req.Header, + Trailer: req.Trailer, + URL: req.URL, + Host: req.Host, + Method: req.Method, + ActualContentLength: contentLength, + } +} + func must[T any](v T, err error) T { if err != nil { panic(err) } return v } - -type panicReader struct{} - -func (panicReader) Read([]byte) (int, error) { panic("unexpected Read") } -func (panicReader) Close() error { panic("unexpected Close") } - -func TestActualContentLength(t *testing.T) { - tests := []struct { - req *http.Request - want int64 - }{ - // Verify we don't read from Body: - 0: { - req: &http.Request{Body: panicReader{}}, - want: -1, - }, - // nil Body means 0, regardless of ContentLength: - 1: { - req: &http.Request{Body: nil, ContentLength: 5}, - want: 0, - }, - // ContentLength is used if set. - 2: { - req: &http.Request{Body: panicReader{}, ContentLength: 5}, - want: 5, - }, - // http.NoBody means 0, not -1. - 3: { - req: &http.Request{Body: http.NoBody}, - want: 0, - }, - } - for i, tt := range tests { - got := ActualContentLength(tt.req) - if got != tt.want { - t.Errorf("test[%d]: got %d; want %d", i, got, tt.want) - } - } -} From 5095d0cf1463414ad99ced9d5032eae6175f5ac5 Mon Sep 17 00:00:00 2001 From: Gopher Robot <gobot@golang.org> Date: Fri, 14 Feb 2025 21:11:59 +0000 Subject: [PATCH 03/17] all: upgrade go directive to at least 1.23.0 [generated] By now Go 1.24.0 has been released, and Go 1.22 is no longer supported per the Go Release Policy (https://go.dev/doc/devel/release#policy). For golang/go#69095. [git-generate] (cd . && go get go@1.23.0 && go mod tidy && go mod edit -toolchain=none) Change-Id: I7e0b4b38a9838b5489cb674cd20ae60233a304e6 Reviewed-on: https://go-review.googlesource.com/c/net/+/649775 Reviewed-by: Cherry Mui <cherryyz@google.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Auto-Submit: Gopher Robot <gobot@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 8de393204..162f7073e 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module golang.org/x/net -go 1.18 +go 1.23.0 require ( golang.org/x/crypto v0.33.0 From 918d64e8e6a411fc6d2ff215a17198d6db0e9fd0 Mon Sep 17 00:00:00 2001 From: Dmitri Shuralyov <dmitshur@golang.org> Date: Sun, 16 Feb 2025 15:28:12 -0500 Subject: [PATCH 04/17] context: delete dead code, sync docs with upstream context package The go directive is now at 1.23.0, so the go1.7 and go1.9 build constraints are guaranteed to always be satisfied, and their inverse will never be satisfied. Delete all the dead code and merge everything that's left in a single context.go file. Also update docs to match the upstream context package. For golang/go#49506. Change-Id: I317550767838a93af2c2d3dbc7b61f2e37e6fe1c Reviewed-on: https://go-review.googlesource.com/c/net/+/650155 Reviewed-by: Ian Lance Taylor <iant@google.com> Reviewed-by: Damien Neil <dneil@google.com> Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> --- context/context.go | 112 ++++++- context/context_test.go | 583 ------------------------------------- context/ctxhttp/ctxhttp.go | 2 +- context/go17.go | 72 ----- context/go19.go | 20 -- context/pre_go17.go | 300 ------------------- context/pre_go19.go | 109 ------- 7 files changed, 101 insertions(+), 1097 deletions(-) delete mode 100644 context/context_test.go delete mode 100644 context/go17.go delete mode 100644 context/go19.go delete mode 100644 context/pre_go17.go delete mode 100644 context/pre_go19.go diff --git a/context/context.go b/context/context.go index cf66309c4..db1c95fab 100644 --- a/context/context.go +++ b/context/context.go @@ -3,29 +3,31 @@ // license that can be found in the LICENSE file. // Package context defines the Context type, which carries deadlines, -// cancelation signals, and other request-scoped values across API boundaries +// cancellation signals, and other request-scoped values across API boundaries // and between processes. // As of Go 1.7 this package is available in the standard library under the -// name context. https://golang.org/pkg/context. +// name [context], and migrating to it can be done automatically with [go fix]. // -// Incoming requests to a server should create a Context, and outgoing calls to -// servers should accept a Context. The chain of function calls between must -// propagate the Context, optionally replacing it with a modified copy created -// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// Incoming requests to a server should create a [Context], and outgoing +// calls to servers should accept a Context. The chain of function +// calls between them must propagate the Context, optionally replacing +// it with a derived Context created using [WithCancel], [WithDeadline], +// [WithTimeout], or [WithValue]. // // Programs that use Contexts should follow these rules to keep interfaces // consistent across packages and enable static analysis tools to check context // propagation: // // Do not store Contexts inside a struct type; instead, pass a Context -// explicitly to each function that needs it. The Context should be the first +// explicitly to each function that needs it. This is discussed further in +// https://go.dev/blog/context-and-structs. The Context should be the first // parameter, typically named ctx: // // func DoSomething(ctx context.Context, arg Arg) error { // // ... use ctx ... // } // -// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO] // if you are unsure about which Context to use. // // Use context Values only for request-scoped data that transits processes and @@ -34,9 +36,30 @@ // The same Context may be passed to functions running in different goroutines; // Contexts are safe for simultaneous use by multiple goroutines. // -// See http://blog.golang.org/context for example code for a server that uses +// See https://go.dev/blog/context for example code for a server that uses // Contexts. -package context // import "golang.org/x/net/context" +// +// [go fix]: https://go.dev/cmd/go#hdr-Update_packages_to_use_new_APIs +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +// A Context carries a deadline, a cancellation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context = context.Context + +// Canceled is the error returned by [Context.Err] when the context is canceled +// for some reason other than its deadline passing. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by [Context.Err] when the context is canceled +// due to its deadline passing. +var DeadlineExceeded = context.DeadlineExceeded // Background returns a non-nil, empty Context. It is never canceled, has no // values, and has no deadline. It is typically used by the main function, @@ -49,8 +72,73 @@ func Background() Context { // TODO returns a non-nil, empty Context. Code should use context.TODO when // it's unclear which Context to use or it is not yet available (because the // surrounding function has not yet been extended to accept a Context -// parameter). TODO is recognized by static analysis tools that determine -// whether Contexts are propagated correctly in a program. +// parameter). func TODO() Context { return todo } + +var ( + background = context.Background() + todo = context.TODO() +) + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// A CancelFunc may be called by multiple goroutines simultaneously. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc = context.CancelFunc + +// WithCancel returns a derived context that points to the parent context +// but has a new Done channel. The returned context's Done channel is closed +// when the returned cancel function is called or when the parent context's +// Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + return context.WithCancel(parent) +} + +// WithDeadline returns a derived context that points to the parent context +// but has the deadline adjusted to be no later than d. If the parent's +// deadline is already earlier than d, WithDeadline(parent, d) is semantically +// equivalent to parent. The returned [Context.Done] channel is closed when +// the deadline expires, when the returned cancel function is called, +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete. +func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { + return context.WithDeadline(parent, d) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return context.WithTimeout(parent, timeout) +} + +// WithValue returns a derived context that points to the parent Context. +// In the derived context, the value associated with key is val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The provided key must be comparable and should not be of type +// string or any other built-in type to avoid collisions between +// packages using context. Users of WithValue should define their own +// types for keys. To avoid allocating when assigning to an +// interface{}, context keys often have concrete type +// struct{}. Alternatively, exported context key variables' static +// type should be a pointer or interface. +func WithValue(parent Context, key, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/context/context_test.go b/context/context_test.go deleted file mode 100644 index 2cb54edb8..000000000 --- a/context/context_test.go +++ /dev/null @@ -1,583 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.7 - -package context - -import ( - "fmt" - "math/rand" - "runtime" - "strings" - "sync" - "testing" - "time" -) - -// otherContext is a Context that's not one of the types defined in context.go. -// This lets us test code paths that differ based on the underlying type of the -// Context. -type otherContext struct { - Context -} - -func TestBackground(t *testing.T) { - c := Background() - if c == nil { - t.Fatalf("Background returned nil") - } - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - if got, want := fmt.Sprint(c), "context.Background"; got != want { - t.Errorf("Background().String() = %q want %q", got, want) - } -} - -func TestTODO(t *testing.T) { - c := TODO() - if c == nil { - t.Fatalf("TODO returned nil") - } - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - if got, want := fmt.Sprint(c), "context.TODO"; got != want { - t.Errorf("TODO().String() = %q want %q", got, want) - } -} - -func TestWithCancel(t *testing.T) { - c1, cancel := WithCancel(Background()) - - if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { - t.Errorf("c1.String() = %q want %q", got, want) - } - - o := otherContext{c1} - c2, _ := WithCancel(o) - contexts := []Context{c1, o, c2} - - for i, c := range contexts { - if d := c.Done(); d == nil { - t.Errorf("c[%d].Done() == %v want non-nil", i, d) - } - if e := c.Err(); e != nil { - t.Errorf("c[%d].Err() == %v want nil", i, e) - } - - select { - case x := <-c.Done(): - t.Errorf("<-c.Done() == %v want nothing (it should block)", x) - default: - } - } - - cancel() - time.Sleep(100 * time.Millisecond) // let cancelation propagate - - for i, c := range contexts { - select { - case <-c.Done(): - default: - t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) - } - if e := c.Err(); e != Canceled { - t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) - } - } -} - -func TestParentFinishesChild(t *testing.T) { - // Context tree: - // parent -> cancelChild - // parent -> valueChild -> timerChild - parent, cancel := WithCancel(Background()) - cancelChild, stop := WithCancel(parent) - defer stop() - valueChild := WithValue(parent, "key", "value") - timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) - defer stop() - - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - case x := <-cancelChild.Done(): - t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) - case x := <-timerChild.Done(): - t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) - case x := <-valueChild.Done(): - t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) - default: - } - - // The parent's children should contain the two cancelable children. - pc := parent.(*cancelCtx) - cc := cancelChild.(*cancelCtx) - tc := timerChild.(*timerCtx) - pc.mu.Lock() - if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { - t.Errorf("bad linkage: pc.children = %v, want %v and %v", - pc.children, cc, tc) - } - pc.mu.Unlock() - - if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { - t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) - } - if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { - t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) - } - - cancel() - - pc.mu.Lock() - if len(pc.children) != 0 { - t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) - } - pc.mu.Unlock() - - // parent and children should all be finished. - check := func(ctx Context, name string) { - select { - case <-ctx.Done(): - default: - t.Errorf("<-%s.Done() blocked, but shouldn't have", name) - } - if e := ctx.Err(); e != Canceled { - t.Errorf("%s.Err() == %v want %v", name, e, Canceled) - } - } - check(parent, "parent") - check(cancelChild, "cancelChild") - check(valueChild, "valueChild") - check(timerChild, "timerChild") - - // WithCancel should return a canceled context on a canceled parent. - precanceledChild := WithValue(parent, "key", "value") - select { - case <-precanceledChild.Done(): - default: - t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") - } - if e := precanceledChild.Err(); e != Canceled { - t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) - } -} - -func TestChildFinishesFirst(t *testing.T) { - cancelable, stop := WithCancel(Background()) - defer stop() - for _, parent := range []Context{Background(), cancelable} { - child, cancel := WithCancel(parent) - - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - case x := <-child.Done(): - t.Errorf("<-child.Done() == %v want nothing (it should block)", x) - default: - } - - cc := child.(*cancelCtx) - pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() - if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { - t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) - } - - if pcok { - pc.mu.Lock() - if len(pc.children) != 1 || !pc.children[cc] { - t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) - } - pc.mu.Unlock() - } - - cancel() - - if pcok { - pc.mu.Lock() - if len(pc.children) != 0 { - t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) - } - pc.mu.Unlock() - } - - // child should be finished. - select { - case <-child.Done(): - default: - t.Errorf("<-child.Done() blocked, but shouldn't have") - } - if e := child.Err(); e != Canceled { - t.Errorf("child.Err() == %v want %v", e, Canceled) - } - - // parent should not be finished. - select { - case x := <-parent.Done(): - t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) - default: - } - if e := parent.Err(); e != nil { - t.Errorf("parent.Err() == %v want nil", e) - } - } -} - -func testDeadline(c Context, wait time.Duration, t *testing.T) { - select { - case <-time.After(wait): - t.Fatalf("context should have timed out") - case <-c.Done(): - } - if e := c.Err(); e != DeadlineExceeded { - t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) - } -} - -func TestDeadline(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithDeadline(Background(), time.Now().Add(1*timeUnit)) - if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { - t.Errorf("c.String() = %q want prefix %q", got, prefix) - } - testDeadline(c, 2*timeUnit, t) - - c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) - o := otherContext{c} - testDeadline(o, 2*timeUnit, t) - - c, _ = WithDeadline(Background(), time.Now().Add(1*timeUnit)) - o = otherContext{c} - c, _ = WithDeadline(o, time.Now().Add(3*timeUnit)) - testDeadline(c, 2*timeUnit, t) -} - -func TestTimeout(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithTimeout(Background(), 1*timeUnit) - if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { - t.Errorf("c.String() = %q want prefix %q", got, prefix) - } - testDeadline(c, 2*timeUnit, t) - - c, _ = WithTimeout(Background(), 1*timeUnit) - o := otherContext{c} - testDeadline(o, 2*timeUnit, t) - - c, _ = WithTimeout(Background(), 1*timeUnit) - o = otherContext{c} - c, _ = WithTimeout(o, 3*timeUnit) - testDeadline(c, 2*timeUnit, t) -} - -func TestCanceledTimeout(t *testing.T) { - t.Parallel() - const timeUnit = 500 * time.Millisecond - c, _ := WithTimeout(Background(), 2*timeUnit) - o := otherContext{c} - c, cancel := WithTimeout(o, 4*timeUnit) - cancel() - time.Sleep(1 * timeUnit) // let cancelation propagate - select { - case <-c.Done(): - default: - t.Errorf("<-c.Done() blocked, but shouldn't have") - } - if e := c.Err(); e != Canceled { - t.Errorf("c.Err() == %v want %v", e, Canceled) - } -} - -type key1 int -type key2 int - -var k1 = key1(1) -var k2 = key2(1) // same int as k1, different type -var k3 = key2(3) // same type as k2, different int - -func TestValues(t *testing.T) { - check := func(c Context, nm, v1, v2, v3 string) { - if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { - t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) - } - if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { - t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) - } - if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { - t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) - } - } - - c0 := Background() - check(c0, "c0", "", "", "") - - c1 := WithValue(Background(), k1, "c1k1") - check(c1, "c1", "c1k1", "", "") - - if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { - t.Errorf("c.String() = %q want %q", got, want) - } - - c2 := WithValue(c1, k2, "c2k2") - check(c2, "c2", "c1k1", "c2k2", "") - - c3 := WithValue(c2, k3, "c3k3") - check(c3, "c2", "c1k1", "c2k2", "c3k3") - - c4 := WithValue(c3, k1, nil) - check(c4, "c4", "", "c2k2", "c3k3") - - o0 := otherContext{Background()} - check(o0, "o0", "", "", "") - - o1 := otherContext{WithValue(Background(), k1, "c1k1")} - check(o1, "o1", "c1k1", "", "") - - o2 := WithValue(o1, k2, "o2k2") - check(o2, "o2", "c1k1", "o2k2", "") - - o3 := otherContext{c4} - check(o3, "o3", "", "c2k2", "c3k3") - - o4 := WithValue(o3, k3, nil) - check(o4, "o4", "", "c2k2", "") -} - -func TestAllocs(t *testing.T) { - bg := Background() - for _, test := range []struct { - desc string - f func() - limit float64 - gccgoLimit float64 - }{ - { - desc: "Background()", - f: func() { Background() }, - limit: 0, - gccgoLimit: 0, - }, - { - desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), - f: func() { - c := WithValue(bg, k1, nil) - c.Value(k1) - }, - limit: 3, - gccgoLimit: 3, - }, - { - desc: "WithTimeout(bg, 15*time.Millisecond)", - f: func() { - c, _ := WithTimeout(bg, 15*time.Millisecond) - <-c.Done() - }, - limit: 8, - gccgoLimit: 16, - }, - { - desc: "WithCancel(bg)", - f: func() { - c, cancel := WithCancel(bg) - cancel() - <-c.Done() - }, - limit: 5, - gccgoLimit: 8, - }, - { - desc: "WithTimeout(bg, 100*time.Millisecond)", - f: func() { - c, cancel := WithTimeout(bg, 100*time.Millisecond) - cancel() - <-c.Done() - }, - limit: 8, - gccgoLimit: 25, - }, - } { - limit := test.limit - if runtime.Compiler == "gccgo" { - // gccgo does not yet do escape analysis. - // TODO(iant): Remove this when gccgo does do escape analysis. - limit = test.gccgoLimit - } - if n := testing.AllocsPerRun(100, test.f); n > limit { - t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) - } - } -} - -func TestSimultaneousCancels(t *testing.T) { - root, cancel := WithCancel(Background()) - m := map[Context]CancelFunc{root: cancel} - q := []Context{root} - // Create a tree of contexts. - for len(q) != 0 && len(m) < 100 { - parent := q[0] - q = q[1:] - for i := 0; i < 4; i++ { - ctx, cancel := WithCancel(parent) - m[ctx] = cancel - q = append(q, ctx) - } - } - // Start all the cancels in a random order. - var wg sync.WaitGroup - wg.Add(len(m)) - for _, cancel := range m { - go func(cancel CancelFunc) { - cancel() - wg.Done() - }(cancel) - } - // Wait on all the contexts in a random order. - for ctx := range m { - select { - case <-ctx.Done(): - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) - } - } - // Wait for all the cancel functions to return. - done := make(chan struct{}) - go func() { - wg.Wait() - close(done) - }() - select { - case <-done: - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) - } -} - -func TestInterlockedCancels(t *testing.T) { - parent, cancelParent := WithCancel(Background()) - child, cancelChild := WithCancel(parent) - go func() { - parent.Done() - cancelChild() - }() - cancelParent() - select { - case <-child.Done(): - case <-time.After(1 * time.Second): - buf := make([]byte, 10<<10) - n := runtime.Stack(buf, true) - t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) - } -} - -func TestLayersCancel(t *testing.T) { - testLayers(t, time.Now().UnixNano(), false) -} - -func TestLayersTimeout(t *testing.T) { - testLayers(t, time.Now().UnixNano(), true) -} - -func testLayers(t *testing.T, seed int64, testTimeout bool) { - rand.Seed(seed) - errorf := func(format string, a ...interface{}) { - t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) - } - const ( - timeout = 200 * time.Millisecond - minLayers = 30 - ) - type value int - var ( - vals []*value - cancels []CancelFunc - numTimers int - ctx = Background() - ) - for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { - switch rand.Intn(3) { - case 0: - v := new(value) - ctx = WithValue(ctx, v, v) - vals = append(vals, v) - case 1: - var cancel CancelFunc - ctx, cancel = WithCancel(ctx) - cancels = append(cancels, cancel) - case 2: - var cancel CancelFunc - ctx, cancel = WithTimeout(ctx, timeout) - cancels = append(cancels, cancel) - numTimers++ - } - } - checkValues := func(when string) { - for _, key := range vals { - if val := ctx.Value(key).(*value); key != val { - errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) - } - } - } - select { - case <-ctx.Done(): - errorf("ctx should not be canceled yet") - default: - } - if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { - t.Errorf("ctx.String() = %q want prefix %q", s, prefix) - } - t.Log(ctx) - checkValues("before cancel") - if testTimeout { - select { - case <-ctx.Done(): - case <-time.After(timeout + 100*time.Millisecond): - errorf("ctx should have timed out") - } - checkValues("after timeout") - } else { - cancel := cancels[rand.Intn(len(cancels))] - cancel() - select { - case <-ctx.Done(): - default: - errorf("ctx should be canceled") - } - checkValues("after cancel") - } -} - -func TestCancelRemoves(t *testing.T) { - checkChildren := func(when string, ctx Context, want int) { - if got := len(ctx.(*cancelCtx).children); got != want { - t.Errorf("%s: context has %d children, want %d", when, got, want) - } - } - - ctx, _ := WithCancel(Background()) - checkChildren("after creation", ctx, 0) - _, cancel := WithCancel(ctx) - checkChildren("with WithCancel child ", ctx, 1) - cancel() - checkChildren("after cancelling WithCancel child", ctx, 0) - - ctx, _ = WithCancel(Background()) - checkChildren("after creation", ctx, 0) - _, cancel = WithTimeout(ctx, 60*time.Minute) - checkChildren("with WithTimeout child ", ctx, 1) - cancel() - checkChildren("after cancelling WithTimeout child", ctx, 0) -} diff --git a/context/ctxhttp/ctxhttp.go b/context/ctxhttp/ctxhttp.go index 37dc0cfdb..e0df203ce 100644 --- a/context/ctxhttp/ctxhttp.go +++ b/context/ctxhttp/ctxhttp.go @@ -3,7 +3,7 @@ // license that can be found in the LICENSE file. // Package ctxhttp provides helper functions for performing context-aware HTTP requests. -package ctxhttp // import "golang.org/x/net/context/ctxhttp" +package ctxhttp import ( "context" diff --git a/context/go17.go b/context/go17.go deleted file mode 100644 index 0c1b86793..000000000 --- a/context/go17.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2016 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.7 - -package context - -import ( - "context" // standard library's context, as of Go 1.7 - "time" -) - -var ( - todo = context.TODO() - background = context.Background() -) - -// Canceled is the error returned by Context.Err when the context is canceled. -var Canceled = context.Canceled - -// DeadlineExceeded is the error returned by Context.Err when the context's -// deadline passes. -var DeadlineExceeded = context.DeadlineExceeded - -// WithCancel returns a copy of parent with a new Done channel. The returned -// context's Done channel is closed when the returned cancel function is called -// or when the parent context's Done channel is closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { - ctx, f := context.WithCancel(parent) - return ctx, f -} - -// WithDeadline returns a copy of the parent context with the deadline adjusted -// to be no later than d. If the parent's deadline is already earlier than d, -// WithDeadline(parent, d) is semantically equivalent to parent. The returned -// context's Done channel is closed when the deadline expires, when the returned -// cancel function is called, or when the parent context's Done channel is -// closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { - ctx, f := context.WithDeadline(parent, deadline) - return ctx, f -} - -// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete: -// -// func slowOperationWithTimeout(ctx context.Context) (Result, error) { -// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) -// defer cancel() // releases resources if slowOperation completes before timeout elapses -// return slowOperation(ctx) -// } -func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { - return WithDeadline(parent, time.Now().Add(timeout)) -} - -// WithValue returns a copy of parent in which the value associated with key is -// val. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { - return context.WithValue(parent, key, val) -} diff --git a/context/go19.go b/context/go19.go deleted file mode 100644 index e31e35a90..000000000 --- a/context/go19.go +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2017 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build go1.9 - -package context - -import "context" // standard library's context, as of Go 1.7 - -// A Context carries a deadline, a cancelation signal, and other values across -// API boundaries. -// -// Context's methods may be called by multiple goroutines simultaneously. -type Context = context.Context - -// A CancelFunc tells an operation to abandon its work. -// A CancelFunc does not wait for the work to stop. -// After the first call, subsequent calls to a CancelFunc do nothing. -type CancelFunc = context.CancelFunc diff --git a/context/pre_go17.go b/context/pre_go17.go deleted file mode 100644 index 065ff3dfa..000000000 --- a/context/pre_go17.go +++ /dev/null @@ -1,300 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.7 - -package context - -import ( - "errors" - "fmt" - "sync" - "time" -) - -// An emptyCtx is never canceled, has no values, and has no deadline. It is not -// struct{}, since vars of this type must have distinct addresses. -type emptyCtx int - -func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { - return -} - -func (*emptyCtx) Done() <-chan struct{} { - return nil -} - -func (*emptyCtx) Err() error { - return nil -} - -func (*emptyCtx) Value(key interface{}) interface{} { - return nil -} - -func (e *emptyCtx) String() string { - switch e { - case background: - return "context.Background" - case todo: - return "context.TODO" - } - return "unknown empty Context" -} - -var ( - background = new(emptyCtx) - todo = new(emptyCtx) -) - -// Canceled is the error returned by Context.Err when the context is canceled. -var Canceled = errors.New("context canceled") - -// DeadlineExceeded is the error returned by Context.Err when the context's -// deadline passes. -var DeadlineExceeded = errors.New("context deadline exceeded") - -// WithCancel returns a copy of parent with a new Done channel. The returned -// context's Done channel is closed when the returned cancel function is called -// or when the parent context's Done channel is closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { - c := newCancelCtx(parent) - propagateCancel(parent, c) - return c, func() { c.cancel(true, Canceled) } -} - -// newCancelCtx returns an initialized cancelCtx. -func newCancelCtx(parent Context) *cancelCtx { - return &cancelCtx{ - Context: parent, - done: make(chan struct{}), - } -} - -// propagateCancel arranges for child to be canceled when parent is. -func propagateCancel(parent Context, child canceler) { - if parent.Done() == nil { - return // parent is never canceled - } - if p, ok := parentCancelCtx(parent); ok { - p.mu.Lock() - if p.err != nil { - // parent has already been canceled - child.cancel(false, p.err) - } else { - if p.children == nil { - p.children = make(map[canceler]bool) - } - p.children[child] = true - } - p.mu.Unlock() - } else { - go func() { - select { - case <-parent.Done(): - child.cancel(false, parent.Err()) - case <-child.Done(): - } - }() - } -} - -// parentCancelCtx follows a chain of parent references until it finds a -// *cancelCtx. This function understands how each of the concrete types in this -// package represents its parent. -func parentCancelCtx(parent Context) (*cancelCtx, bool) { - for { - switch c := parent.(type) { - case *cancelCtx: - return c, true - case *timerCtx: - return c.cancelCtx, true - case *valueCtx: - parent = c.Context - default: - return nil, false - } - } -} - -// removeChild removes a context from its parent. -func removeChild(parent Context, child canceler) { - p, ok := parentCancelCtx(parent) - if !ok { - return - } - p.mu.Lock() - if p.children != nil { - delete(p.children, child) - } - p.mu.Unlock() -} - -// A canceler is a context type that can be canceled directly. The -// implementations are *cancelCtx and *timerCtx. -type canceler interface { - cancel(removeFromParent bool, err error) - Done() <-chan struct{} -} - -// A cancelCtx can be canceled. When canceled, it also cancels any children -// that implement canceler. -type cancelCtx struct { - Context - - done chan struct{} // closed by the first cancel call. - - mu sync.Mutex - children map[canceler]bool // set to nil by the first cancel call - err error // set to non-nil by the first cancel call -} - -func (c *cancelCtx) Done() <-chan struct{} { - return c.done -} - -func (c *cancelCtx) Err() error { - c.mu.Lock() - defer c.mu.Unlock() - return c.err -} - -func (c *cancelCtx) String() string { - return fmt.Sprintf("%v.WithCancel", c.Context) -} - -// cancel closes c.done, cancels each of c's children, and, if -// removeFromParent is true, removes c from its parent's children. -func (c *cancelCtx) cancel(removeFromParent bool, err error) { - if err == nil { - panic("context: internal error: missing cancel error") - } - c.mu.Lock() - if c.err != nil { - c.mu.Unlock() - return // already canceled - } - c.err = err - close(c.done) - for child := range c.children { - // NOTE: acquiring the child's lock while holding parent's lock. - child.cancel(false, err) - } - c.children = nil - c.mu.Unlock() - - if removeFromParent { - removeChild(c.Context, c) - } -} - -// WithDeadline returns a copy of the parent context with the deadline adjusted -// to be no later than d. If the parent's deadline is already earlier than d, -// WithDeadline(parent, d) is semantically equivalent to parent. The returned -// context's Done channel is closed when the deadline expires, when the returned -// cancel function is called, or when the parent context's Done channel is -// closed, whichever happens first. -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete. -func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { - if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { - // The current deadline is already sooner than the new one. - return WithCancel(parent) - } - c := &timerCtx{ - cancelCtx: newCancelCtx(parent), - deadline: deadline, - } - propagateCancel(parent, c) - d := deadline.Sub(time.Now()) - if d <= 0 { - c.cancel(true, DeadlineExceeded) // deadline has already passed - return c, func() { c.cancel(true, Canceled) } - } - c.mu.Lock() - defer c.mu.Unlock() - if c.err == nil { - c.timer = time.AfterFunc(d, func() { - c.cancel(true, DeadlineExceeded) - }) - } - return c, func() { c.cancel(true, Canceled) } -} - -// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to -// implement Done and Err. It implements cancel by stopping its timer then -// delegating to cancelCtx.cancel. -type timerCtx struct { - *cancelCtx - timer *time.Timer // Under cancelCtx.mu. - - deadline time.Time -} - -func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { - return c.deadline, true -} - -func (c *timerCtx) String() string { - return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) -} - -func (c *timerCtx) cancel(removeFromParent bool, err error) { - c.cancelCtx.cancel(false, err) - if removeFromParent { - // Remove this timerCtx from its parent cancelCtx's children. - removeChild(c.cancelCtx.Context, c) - } - c.mu.Lock() - if c.timer != nil { - c.timer.Stop() - c.timer = nil - } - c.mu.Unlock() -} - -// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). -// -// Canceling this context releases resources associated with it, so code should -// call cancel as soon as the operations running in this Context complete: -// -// func slowOperationWithTimeout(ctx context.Context) (Result, error) { -// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) -// defer cancel() // releases resources if slowOperation completes before timeout elapses -// return slowOperation(ctx) -// } -func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { - return WithDeadline(parent, time.Now().Add(timeout)) -} - -// WithValue returns a copy of parent in which the value associated with key is -// val. -// -// Use context Values only for request-scoped data that transits processes and -// APIs, not for passing optional parameters to functions. -func WithValue(parent Context, key interface{}, val interface{}) Context { - return &valueCtx{parent, key, val} -} - -// A valueCtx carries a key-value pair. It implements Value for that key and -// delegates all other calls to the embedded Context. -type valueCtx struct { - Context - key, val interface{} -} - -func (c *valueCtx) String() string { - return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) -} - -func (c *valueCtx) Value(key interface{}) interface{} { - if c.key == key { - return c.val - } - return c.Context.Value(key) -} diff --git a/context/pre_go19.go b/context/pre_go19.go deleted file mode 100644 index ec5a63803..000000000 --- a/context/pre_go19.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -//go:build !go1.9 - -package context - -import "time" - -// A Context carries a deadline, a cancelation signal, and other values across -// API boundaries. -// -// Context's methods may be called by multiple goroutines simultaneously. -type Context interface { - // Deadline returns the time when work done on behalf of this context - // should be canceled. Deadline returns ok==false when no deadline is - // set. Successive calls to Deadline return the same results. - Deadline() (deadline time.Time, ok bool) - - // Done returns a channel that's closed when work done on behalf of this - // context should be canceled. Done may return nil if this context can - // never be canceled. Successive calls to Done return the same value. - // - // WithCancel arranges for Done to be closed when cancel is called; - // WithDeadline arranges for Done to be closed when the deadline - // expires; WithTimeout arranges for Done to be closed when the timeout - // elapses. - // - // Done is provided for use in select statements: - // - // // Stream generates values with DoSomething and sends them to out - // // until DoSomething returns an error or ctx.Done is closed. - // func Stream(ctx context.Context, out chan<- Value) error { - // for { - // v, err := DoSomething(ctx) - // if err != nil { - // return err - // } - // select { - // case <-ctx.Done(): - // return ctx.Err() - // case out <- v: - // } - // } - // } - // - // See http://blog.golang.org/pipelines for more examples of how to use - // a Done channel for cancelation. - Done() <-chan struct{} - - // Err returns a non-nil error value after Done is closed. Err returns - // Canceled if the context was canceled or DeadlineExceeded if the - // context's deadline passed. No other values for Err are defined. - // After Done is closed, successive calls to Err return the same value. - Err() error - - // Value returns the value associated with this context for key, or nil - // if no value is associated with key. Successive calls to Value with - // the same key returns the same result. - // - // Use context values only for request-scoped data that transits - // processes and API boundaries, not for passing optional parameters to - // functions. - // - // A key identifies a specific value in a Context. Functions that wish - // to store values in Context typically allocate a key in a global - // variable then use that key as the argument to context.WithValue and - // Context.Value. A key can be any type that supports equality; - // packages should define keys as an unexported type to avoid - // collisions. - // - // Packages that define a Context key should provide type-safe accessors - // for the values stores using that key: - // - // // Package user defines a User type that's stored in Contexts. - // package user - // - // import "golang.org/x/net/context" - // - // // User is the type of value stored in the Contexts. - // type User struct {...} - // - // // key is an unexported type for keys defined in this package. - // // This prevents collisions with keys defined in other packages. - // type key int - // - // // userKey is the key for user.User values in Contexts. It is - // // unexported; clients use user.NewContext and user.FromContext - // // instead of using this key directly. - // var userKey key = 0 - // - // // NewContext returns a new Context that carries value u. - // func NewContext(ctx context.Context, u *User) context.Context { - // return context.WithValue(ctx, userKey, u) - // } - // - // // FromContext returns the User value stored in ctx, if any. - // func FromContext(ctx context.Context) (*User, bool) { - // u, ok := ctx.Value(userKey).(*User) - // return u, ok - // } - Value(key interface{}) interface{} -} - -// A CancelFunc tells an operation to abandon its work. -// A CancelFunc does not wait for the work to stop. -// After the first call, subsequent calls to a CancelFunc do nothing. -type CancelFunc func() From 447f458ae023a20be3c2b5481591fb1fc920e464 Mon Sep 17 00:00:00 2001 From: Dmitri Shuralyov <dmitshur@golang.org> Date: Sun, 16 Feb 2025 15:43:41 -0500 Subject: [PATCH 05/17] context: delete lone example This is motivated by a few reasons. One, the upstream package has more examples, and no one should be looking at this old package to learn how to use it. Seeing an example might make it seem like the scope of the documentation here is to provide examples, and that there aren't many of them. Instead of trying to add more examples or maintain the current one by porting the de-flake enhancement from CL 460999, delete the only example here. Second, running 'go fix ./...' causes the 'context' fix to rewrite the import path of the example from "golang.org/x/net/context" to "context". That is a false positive in the fix, and I would've liked it fix the fix, but it only has the AST information at this time, not type info, so the import path isn't currently available to the check. That means it can't know when it's running on the golang.org/x/net/context package, which is the one place it should skip the rewrite. It seems simpler to just delete the example, and then it becomes possible to use 'go fix ./...' safely on the entire x/net module. For golang/go#49506. Change-Id: I97eba33ca2e1f2960aef8340d8b561639a26ee48 Reviewed-on: https://go-review.googlesource.com/c/net/+/650156 Reviewed-by: Ian Lance Taylor <iant@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Dmitri Shuralyov <dmitshur@google.com> Reviewed-by: Damien Neil <dneil@google.com> Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org> --- context/withtimeout_test.go | 31 ------------------------------- 1 file changed, 31 deletions(-) delete mode 100644 context/withtimeout_test.go diff --git a/context/withtimeout_test.go b/context/withtimeout_test.go deleted file mode 100644 index e6f56691d..000000000 --- a/context/withtimeout_test.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2014 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package context_test - -import ( - "fmt" - "time" - - "golang.org/x/net/context" -) - -// This example passes a context with a timeout to tell a blocking function that -// it should abandon its work after the timeout elapses. -func ExampleWithTimeout() { - // Pass a context with a timeout to tell a blocking function that it - // should abandon its work after the timeout elapses. - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - - select { - case <-time.After(1 * time.Second): - fmt.Println("overslept") - case <-ctx.Done(): - fmt.Println(ctx.Err()) // prints "context deadline exceeded" - } - - // Output: - // context deadline exceeded -} From 163d83654d4d78be90251b9bf05aa502b6f7e79d Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Tue, 4 Feb 2025 13:39:37 -0800 Subject: [PATCH 06/17] internal/http3: add Server Add the general structure of an HTTP/3 server. The server currently accepts QUIC connections and establishes a control stream on them, but does not handle requests. For golang/go#70914 Change-Id: I28193ddacef028233248601979b0b45ad844205a Reviewed-on: https://go-review.googlesource.com/c/net/+/646617 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- internal/http3/conn_test.go | 5 + internal/http3/server.go | 183 ++++++++++++++++++++++++++++++++++ internal/http3/server_test.go | 110 ++++++++++++++++++++ 3 files changed, 298 insertions(+) create mode 100644 internal/http3/server.go create mode 100644 internal/http3/server_test.go diff --git a/internal/http3/conn_test.go b/internal/http3/conn_test.go index e9b5b4189..a9afb1f9e 100644 --- a/internal/http3/conn_test.go +++ b/internal/http3/conn_test.go @@ -146,4 +146,9 @@ func runConnTest(t *testing.T, f func(testing.TB, *testQUICConn)) { tc := newTestClientConn(t) f(t, tc.testQUICConn) }) + runSynctestSubtest(t, "server", func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + f(t, tc.testQUICConn) + }) } diff --git a/internal/http3/server.go b/internal/http3/server.go new file mode 100644 index 000000000..2d8d1df22 --- /dev/null +++ b/internal/http3/server.go @@ -0,0 +1,183 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 + +package http3 + +import ( + "context" + "net/http" + "sync" + + "golang.org/x/net/quic" +) + +// A Server is an HTTP/3 server. +// The zero value for Server is a valid server. +type Server struct { + // Handler to invoke for requests, http.DefaultServeMux if nil. + Handler http.Handler + + // Config is the QUIC configuration used by the server. + // The Config may be nil. + // + // ListenAndServe may clone and modify the Config. + // The Config must not be modified after calling ListenAndServe. + Config *quic.Config + + initOnce sync.Once +} + +func (s *Server) init() { + s.initOnce.Do(func() { + s.Config = initConfig(s.Config) + if s.Handler == nil { + s.Handler = http.DefaultServeMux + } + }) +} + +// ListenAndServe listens on the UDP network address addr +// and then calls Serve to handle requests on incoming connections. +func (s *Server) ListenAndServe(addr string) error { + s.init() + e, err := quic.Listen("udp", addr, s.Config) + if err != nil { + return err + } + return s.Serve(e) +} + +// Serve accepts incoming connections on the QUIC endpoint e, +// and handles requests from those connections. +func (s *Server) Serve(e *quic.Endpoint) error { + s.init() + for { + qconn, err := e.Accept(context.Background()) + if err != nil { + return err + } + go newServerConn(qconn) + } +} + +type serverConn struct { + qconn *quic.Conn + + genericConn // for handleUnidirectionalStream + enc qpackEncoder + dec qpackDecoder +} + +func newServerConn(qconn *quic.Conn) { + sc := &serverConn{ + qconn: qconn, + } + sc.enc.init() + + // Create control stream and send SETTINGS frame. + // TODO: Time out on creating stream. + controlStream, err := newConnStream(context.Background(), sc.qconn, streamTypeControl) + if err != nil { + return + } + controlStream.writeSettings() + controlStream.Flush() + + // Accept streams on the connection. + for { + st, err := sc.qconn.AcceptStream(context.Background()) + if err != nil { + return // connection closed + } + if st.IsReadOnly() { + go sc.handleUnidirectionalStream(newStream(st), sc) + } else { + go sc.handleRequestStream(newStream(st)) + } + } +} + +func (sc *serverConn) handleControlStream(st *stream) error { + // "A SETTINGS frame MUST be sent as the first frame of each control stream [...]" + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4-2 + if err := st.readSettings(func(settingsType, settingsValue int64) error { + switch settingsType { + case settingsMaxFieldSectionSize: + _ = settingsValue // TODO + case settingsQPACKMaxTableCapacity: + _ = settingsValue // TODO + case settingsQPACKBlockedStreams: + _ = settingsValue // TODO + default: + // Unknown settings types are ignored. + } + return nil + }); err != nil { + return err + } + + for { + ftype, err := st.readFrameHeader() + if err != nil { + return err + } + switch ftype { + case frameTypeCancelPush: + // "If a server receives a CANCEL_PUSH frame for a push ID + // that has not yet been mentioned by a PUSH_PROMISE frame, + // this MUST be treated as a connection error of type H3_ID_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.3-8 + return &connectionError{ + code: errH3IDError, + message: "CANCEL_PUSH for unsent push ID", + } + case frameTypeGoaway: + return errH3NoError + default: + // Unknown frames are ignored. + if err := st.discardUnknownFrame(ftype); err != nil { + return err + } + } + } +} + +func (sc *serverConn) handleEncoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handleDecoderStream(*stream) error { + // TODO + return nil +} + +func (sc *serverConn) handlePushStream(*stream) error { + // "[...] if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3 + return &connectionError{ + code: errH3StreamCreationError, + message: "client created push stream", + } +} + +func (sc *serverConn) handleRequestStream(st *stream) { + // TODO + return +} + +// abort closes the connection with an error. +func (sc *serverConn) abort(err error) { + if e, ok := err.(*connectionError); ok { + sc.qconn.Abort(&quic.ApplicationError{ + Code: uint64(e.code), + Reason: e.message, + }) + } else { + sc.qconn.Abort(err) + } +} diff --git a/internal/http3/server_test.go b/internal/http3/server_test.go new file mode 100644 index 000000000..8e727d251 --- /dev/null +++ b/internal/http3/server_test.go @@ -0,0 +1,110 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 && goexperiment.synctest + +package http3 + +import ( + "net/netip" + "testing" + "testing/synctest" + + "golang.org/x/net/internal/quic/quicwire" + "golang.org/x/net/quic" +) + +func TestServerReceivePushStream(t *testing.T) { + // "[...] if a server receives a client-initiated push stream, + // this MUST be treated as a connection error of type H3_STREAM_CREATION_ERROR." + // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.2.2-3 + runSynctest(t, func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + tc.newStream(streamTypePush) + tc.wantClosed("invalid client-created push stream", errH3StreamCreationError) + }) +} + +func TestServerCancelPushForUnsentPromise(t *testing.T) { + runSynctest(t, func(t testing.TB) { + ts := newTestServer(t) + tc := ts.connect() + tc.greet() + + const pushID = 100 + tc.control.writeVarint(int64(frameTypeCancelPush)) + tc.control.writeVarint(int64(quicwire.SizeVarint(pushID))) + tc.control.writeVarint(pushID) + tc.control.Flush() + + tc.wantClosed("client canceled never-sent push ID", errH3IDError) + }) +} + +type testServer struct { + t testing.TB + s *Server + tn testNet + *testQUICEndpoint + + addr netip.AddrPort +} + +type testQUICEndpoint struct { + t testing.TB + e *quic.Endpoint +} + +func (te *testQUICEndpoint) dial() { +} + +type testServerConn struct { + ts *testServer + + *testQUICConn + control *testQUICStream +} + +func newTestServer(t testing.TB) *testServer { + t.Helper() + ts := &testServer{ + t: t, + s: &Server{ + Config: &quic.Config{ + TLSConfig: testTLSConfig, + }, + }, + } + e := ts.tn.newQUICEndpoint(t, ts.s.Config) + ts.addr = e.LocalAddr() + go ts.s.Serve(e) + return ts +} + +func (ts *testServer) connect() *testServerConn { + ts.t.Helper() + config := &quic.Config{TLSConfig: testTLSConfig} + e := ts.tn.newQUICEndpoint(ts.t, nil) + qconn, err := e.Dial(ts.t.Context(), "udp", ts.addr.String(), config) + if err != nil { + ts.t.Fatal(err) + } + tc := &testServerConn{ + ts: ts, + testQUICConn: newTestQUICConn(ts.t, qconn), + } + synctest.Wait() + return tc +} + +// greet performs initial connection handshaking with the server. +func (tc *testServerConn) greet() { + // Client creates a control stream. + tc.control = tc.newStream(streamTypeControl) + tc.control.writeVarint(int64(frameTypeSettings)) + tc.control.writeVarint(0) // size + tc.control.Flush() + synctest.Wait() +} From b4c86550a5be2d314b04727f13affd9bb07fcf46 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Tue, 3 Dec 2024 10:39:37 -0800 Subject: [PATCH 07/17] http2: avoid extended CONNECT hang when connection breaks during startup Fixes golang/go#70658 Change-Id: Iaac5c7730a10afc8a8bb2e725746fa7387970582 Reviewed-on: https://go-review.googlesource.com/c/net/+/633277 Auto-Submit: Damien Neil <dneil@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Antonio Ojea <aojea@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- http2/transport.go | 10 +++++++--- http2/transport_test.go | 21 +++++++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/http2/transport.go b/http2/transport.go index 94b397c69..f26356b9c 100644 --- a/http2/transport.go +++ b/http2/transport.go @@ -2210,6 +2210,13 @@ func (rl *clientConnReadLoop) cleanup() { } cc.cond.Broadcast() cc.mu.Unlock() + + if !cc.seenSettings { + // If we have a pending request that wants extended CONNECT, + // let it continue and fail with the connection error. + cc.extendedConnectAllowed = true + close(cc.seenSettingsChan) + } } // countReadFrameError calls Transport.CountError with a string @@ -2302,9 +2309,6 @@ func (rl *clientConnReadLoop) run() error { if VerboseLogs { cc.vlogf("http2: Transport conn %p received error from processing frame %v: %v", cc, summarizeFrame(f), err) } - if !cc.seenSettings { - close(cc.seenSettingsChan) - } return err } } diff --git a/http2/transport_test.go b/http2/transport_test.go index d1d27f8f9..1eeb76e06 100644 --- a/http2/transport_test.go +++ b/http2/transport_test.go @@ -5914,3 +5914,24 @@ func TestExtendedConnectClientWithoutServerSupport(t *testing.T) { t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err) } } + +// Issue #70658: Make sure extended CONNECT requests don't get stuck if a +// connection fails early in its lifetime. +func TestExtendedConnectReadFrameError(t *testing.T) { + tc := newTestClientConn(t) + tc.wantFrameType(FrameSettings) + tc.wantFrameType(FrameWindowUpdate) + + req, _ := http.NewRequest("CONNECT", "https://dummy.tld/", nil) + req.Header.Set(":protocol", "extended-connect") + rt := tc.roundTrip(req) + tc.wantIdle() // waiting for SETTINGS response + + tc.closeWrite() // connection breaks without sending SETTINGS + if !rt.done() { + t.Fatalf("after connection closed: RoundTrip still running; want done") + } + if rt.err() == nil { + t.Fatalf("after connection closed: RoundTrip succeeded; want error") + } +} From 0d7dc54a591c12b4bd03bcd745024178d03d9218 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Wed, 5 Feb 2025 16:56:42 -0800 Subject: [PATCH 08/17] quic: add Conn.ConnectionState Add a method that returns the tls.ConnectionState for a connection. Generally useful, and also required to let HTTP/3 expose the ConnectionState in Requests. Change-Id: Iba725e0f40c68020fc6ee45d49f5c609a2b6b493 Reviewed-on: https://go-review.googlesource.com/c/net/+/647075 Auto-Submit: Damien Neil <dneil@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> --- quic/conn.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/quic/conn.go b/quic/conn.go index bf54409bf..1f1cfa6d0 100644 --- a/quic/conn.go +++ b/quic/conn.go @@ -186,6 +186,11 @@ func (c *Conn) RemoteAddr() netip.AddrPort { return c.peerAddr } +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() tls.ConnectionState { + return c.tls.ConnectionState() +} + // confirmHandshake is called when the handshake is confirmed. // https://www.rfc-editor.org/rfc/rfc9001#section-4.1.2 func (c *Conn) confirmHandshake(now time.Time) { From 1d78a085008d9fedfe3f303591058325f99727d7 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Wed, 5 Feb 2025 16:56:02 -0800 Subject: [PATCH 09/17] http2, internal/httpcommon: factor out server header logic for h2/h3 Move common elements of constructing a http.Request for a server handler into internal/httpcommon. For golang/go#70914 Change-Id: I5dcd902e189a0bb8daf47c0a815045d274346923 Reviewed-on: https://go-review.googlesource.com/c/net/+/652455 Auto-Submit: Damien Neil <dneil@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- http2/server.go | 121 +++++++++++---------------------- internal/httpcommon/request.go | 77 +++++++++++++++++++++ 2 files changed, 115 insertions(+), 83 deletions(-) diff --git a/http2/server.go b/http2/server.go index 7434b8784..b640deb0e 100644 --- a/http2/server.go +++ b/http2/server.go @@ -2233,25 +2233,25 @@ func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) { sc.serveG.check() - rp := requestParam{ - method: f.PseudoValue("method"), - scheme: f.PseudoValue("scheme"), - authority: f.PseudoValue("authority"), - path: f.PseudoValue("path"), - protocol: f.PseudoValue("protocol"), + rp := httpcommon.ServerRequestParam{ + Method: f.PseudoValue("method"), + Scheme: f.PseudoValue("scheme"), + Authority: f.PseudoValue("authority"), + Path: f.PseudoValue("path"), + Protocol: f.PseudoValue("protocol"), } // extended connect is disabled, so we should not see :protocol - if disableExtendedConnectProtocol && rp.protocol != "" { + if disableExtendedConnectProtocol && rp.Protocol != "" { return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) } - isConnect := rp.method == "CONNECT" + isConnect := rp.Method == "CONNECT" if isConnect { - if rp.protocol == "" && (rp.path != "" || rp.scheme != "" || rp.authority == "") { + if rp.Protocol == "" && (rp.Path != "" || rp.Scheme != "" || rp.Authority == "") { return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol)) } - } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { + } else if rp.Method == "" || rp.Path == "" || (rp.Scheme != "https" && rp.Scheme != "http") { // See 8.1.2.6 Malformed Requests and Responses: // // Malformed requests or responses that are detected @@ -2265,15 +2265,16 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol)) } - rp.header = make(http.Header) + header := make(http.Header) + rp.Header = header for _, hf := range f.RegularFields() { - rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + header.Add(sc.canonicalHeader(hf.Name), hf.Value) } - if rp.authority == "" { - rp.authority = rp.header.Get("Host") + if rp.Authority == "" { + rp.Authority = header.Get("Host") } - if rp.protocol != "" { - rp.header.Set(":protocol", rp.protocol) + if rp.Protocol != "" { + header.Set(":protocol", rp.Protocol) } rw, req, err := sc.newWriterAndRequestNoBody(st, rp) @@ -2282,7 +2283,7 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res } bodyOpen := !f.StreamEnded() if bodyOpen { - if vv, ok := rp.header["Content-Length"]; ok { + if vv, ok := rp.Header["Content-Length"]; ok { if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil { req.ContentLength = int64(cl) } else { @@ -2298,84 +2299,38 @@ func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*res return rw, req, nil } -type requestParam struct { - method string - scheme, authority, path string - protocol string - header http.Header -} - -func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) { +func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp httpcommon.ServerRequestParam) (*responseWriter, *http.Request, error) { sc.serveG.check() var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { + if rp.Scheme == "https" { tlsState = sc.tlsState } - needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue") - if needsContinue { - rp.header.Del("Expect") - } - // Merge Cookie headers into one "; "-delimited value. - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) - } - - // Setup Trailers - var trailer http.Header - for _, v := range rp.header["Trailer"] { - for _, key := range strings.Split(v, ",") { - key = http.CanonicalHeaderKey(textproto.TrimString(key)) - switch key { - case "Transfer-Encoding", "Trailer", "Content-Length": - // Bogus. (copy of http1 rules) - // Ignore. - default: - if trailer == nil { - trailer = make(http.Header) - } - trailer[key] = nil - } - } - } - delete(rp.header, "Trailer") - - var url_ *url.URL - var requestURI string - if rp.method == "CONNECT" && rp.protocol == "" { - url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority // mimic HTTP/1 server behavior - } else { - var err error - url_, err = url.ParseRequestURI(rp.path) - if err != nil { - return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol)) - } - requestURI = rp.path + res := httpcommon.NewServerRequest(rp) + if res.InvalidReason != "" { + return nil, nil, sc.countError(res.InvalidReason, streamError(st.id, ErrCodeProtocol)) } body := &requestBody{ conn: sc, stream: st, - needsContinue: needsContinue, + needsContinue: res.NeedsContinue, } - req := &http.Request{ - Method: rp.method, - URL: url_, + req := (&http.Request{ + Method: rp.Method, + URL: res.URL, RemoteAddr: sc.remoteAddrStr, - Header: rp.header, - RequestURI: requestURI, + Header: rp.Header, + RequestURI: res.RequestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: rp.authority, + Host: rp.Authority, Body: body, - Trailer: trailer, - } - req = req.WithContext(st.ctx) - + Trailer: res.Trailer, + }).WithContext(st.ctx) rw := sc.newResponseWriter(st, req) return rw, req, nil } @@ -3270,12 +3225,12 @@ func (sc *serverConn) startPush(msg *startPushRequest) { // we start in "half closed (remote)" for simplicity. // See further comments at the definition of stateHalfClosedRemote. promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote) - rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{ - method: msg.method, - scheme: msg.url.Scheme, - authority: msg.url.Host, - path: msg.url.RequestURI(), - header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE + rw, req, err := sc.newWriterAndRequestNoBody(promised, httpcommon.ServerRequestParam{ + Method: msg.method, + Scheme: msg.url.Scheme, + Authority: msg.url.Host, + Path: msg.url.RequestURI(), + Header: cloneHeader(msg.header), // clone since handler runs concurrently with writing the PUSH_PROMISE }) if err != nil { // Should not happen, since we've already validated msg.url. diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go index bec16d0b9..a9629809c 100644 --- a/internal/httpcommon/request.go +++ b/internal/httpcommon/request.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "net/http/httptrace" + "net/textproto" "net/url" "sort" "strconv" @@ -378,3 +379,79 @@ func shouldSendReqContentLength(method string, contentLength int64) bool { return false } } + +// ServerRequestParam is parameters to NewServerRequest. +type ServerRequestParam struct { + Method string + Scheme, Authority, Path string + Protocol string + Header map[string][]string +} + +// ServerRequestResult is the result of NewServerRequest. +type ServerRequestResult struct { + // Various http.Request fields. + URL *url.URL + RequestURI string + Trailer map[string][]string + + NeedsContinue bool // client provided an "Expect: 100-continue" header + + // If the request should be rejected, this is a short string suitable for passing + // to the http2 package's CountError function. + // It might be a bit odd to return errors this way rather than returing an error, + // but this ensures we don't forget to include a CountError reason. + InvalidReason string +} + +func NewServerRequest(rp ServerRequestParam) ServerRequestResult { + needsContinue := httpguts.HeaderValuesContainsToken(rp.Header["Expect"], "100-continue") + if needsContinue { + delete(rp.Header, "Expect") + } + // Merge Cookie headers into one "; "-delimited value. + if cookies := rp.Header["Cookie"]; len(cookies) > 1 { + rp.Header["Cookie"] = []string{strings.Join(cookies, "; ")} + } + + // Setup Trailers + var trailer map[string][]string + for _, v := range rp.Header["Trailer"] { + for _, key := range strings.Split(v, ",") { + key = textproto.CanonicalMIMEHeaderKey(textproto.TrimString(key)) + switch key { + case "Transfer-Encoding", "Trailer", "Content-Length": + // Bogus. (copy of http1 rules) + // Ignore. + default: + if trailer == nil { + trailer = make(map[string][]string) + } + trailer[key] = nil + } + } + } + delete(rp.Header, "Trailer") + var url_ *url.URL + var requestURI string + if rp.Method == "CONNECT" && rp.Protocol == "" { + url_ = &url.URL{Host: rp.Authority} + requestURI = rp.Authority // mimic HTTP/1 server behavior + } else { + var err error + url_, err = url.ParseRequestURI(rp.Path) + if err != nil { + return ServerRequestResult{ + InvalidReason: "bad_path", + } + } + requestURI = rp.Path + } + + return ServerRequestResult{ + URL: url_, + NeedsContinue: needsContinue, + RequestURI: requestURI, + Trailer: trailer, + } +} From 43c2540165a4d1bc9a81e06a86eb1e22ece64145 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Mon, 24 Feb 2025 16:51:59 -0800 Subject: [PATCH 10/17] http2, internal/httpcommon: reject userinfo in :authority RFC 9113, section 8.3.1: The :authority (host) in an HTTP request must not include a userinfo (e.g., user@host). Change-Id: I459a3da40b825c9662467778f582050c7358f8bb Reviewed-on: https://go-review.googlesource.com/c/net/+/652456 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> Auto-Submit: Damien Neil <dneil@google.com> --- http2/server_test.go | 20 ++++++++++++++++++++ internal/httpcommon/request.go | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/http2/server_test.go b/http2/server_test.go index 08f2dd3b2..376227e66 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -1032,6 +1032,26 @@ func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) { }) } +func TestServer_Request_Reject_Authority_Userinfo(t *testing.T) { + // "':authority' MUST NOT include the deprecated userinfo subcomponent + // for "http" or "https" schemed URIs." + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8 + testRejectRequest(t, func(st *serverTester) { + var buf bytes.Buffer + enc := hpack.NewEncoder(&buf) + enc.WriteField(hpack.HeaderField{Name: ":authority", Value: "userinfo@example.tld"}) + enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"}) + enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"}) + enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"}) + st.writeHeaders(HeadersFrameParam{ + StreamID: 1, // clients send odd numbers + BlockFragment: buf.Bytes(), + EndStream: true, + EndHeaders: true, + }) + }) +} + func testRejectRequest(t *testing.T, send func(*serverTester)) { st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) { t.Error("server request made it to handler; should've been rejected") diff --git a/internal/httpcommon/request.go b/internal/httpcommon/request.go index a9629809c..4b7055317 100644 --- a/internal/httpcommon/request.go +++ b/internal/httpcommon/request.go @@ -432,6 +432,16 @@ func NewServerRequest(rp ServerRequestParam) ServerRequestResult { } } delete(rp.Header, "Trailer") + + // "':authority' MUST NOT include the deprecated userinfo subcomponent + // for "http" or "https" schemed URIs." + // https://www.rfc-editor.org/rfc/rfc9113.html#section-8.3.1-2.3.8 + if strings.IndexByte(rp.Authority, '@') != -1 && (rp.Scheme == "http" || rp.Scheme == "https") { + return ServerRequestResult{ + InvalidReason: "userinfo_in_authority", + } + } + var url_ *url.URL var requestURI string if rp.Method == "CONNECT" && rp.Protocol == "" { From 5f45c776a9c4d415cbe67d6c22c06fd704f8c9f1 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Mon, 24 Feb 2025 16:54:39 -0800 Subject: [PATCH 11/17] internal/http3: make read-data tests usable for server handlers A reading a transport response body behaves much the same as a server handler reading a request body. Move the transport test into body_test.go and rearrange it a bit so we can reuse it as a server test. For golang/go#70914 Change-Id: I24e10dd078ffab867c9b678e1d0b99172763b069 Reviewed-on: https://go-review.googlesource.com/c/net/+/652457 Auto-Submit: Damien Neil <dneil@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- internal/http3/body_test.go | 276 +++++++++++++++++++++++++++++++ internal/http3/roundtrip_test.go | 268 ------------------------------ 2 files changed, 276 insertions(+), 268 deletions(-) create mode 100644 internal/http3/body_test.go diff --git a/internal/http3/body_test.go b/internal/http3/body_test.go new file mode 100644 index 000000000..599e0df81 --- /dev/null +++ b/internal/http3/body_test.go @@ -0,0 +1,276 @@ +// Copyright 2025 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.24 && goexperiment.synctest + +package http3 + +import ( + "bytes" + "fmt" + "io" + "net/http" + "testing" +) + +// TestReadData tests servers reading request bodies, and clients reading response bodies. +func TestReadData(t *testing.T) { + // These tests consist of a series of steps, + // where each step is either something arriving on the stream + // or the client/server reading from the body. + type ( + // HEADERS frame arrives (headers). + receiveHeaders struct { + contentLength int64 // -1 for no content-length + } + // DATA frame header arrives. + receiveDataHeader struct { + size int64 + } + // DATA frame content arrives. + receiveData struct { + size int64 + } + // HEADERS frame arrives (trailers). + receiveTrailers struct{} + // Some other frame arrives. + receiveFrame struct { + ftype frameType + data []byte + } + // Stream closed, ending the body. + receiveEOF struct{} + // Server reads from Request.Body, or client reads from Response.Body. + wantBody struct { + size int64 + eof bool + } + wantError struct{} + ) + for _, test := range []struct { + name string + respHeader http.Header + steps []any + wantError bool + }{{ + name: "no content length", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "valid content length", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + receiveEOF{}, + wantBody{size: 10, eof: true}, + }, + }, { + name: "data frame exceeds content length", + steps: []any{ + receiveHeaders{contentLength: 5}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantError{}, + }, + }, { + name: "data frame after all content read", + steps: []any{ + receiveHeaders{contentLength: 5}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveDataHeader{size: 1}, + receiveData{size: 1}, + wantError{}, + }, + }, { + name: "content length too long", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveEOF{}, + wantBody{size: 5}, + wantError{}, + }, + }, { + name: "stream ended by trailers", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + receiveTrailers{}, + wantBody{size: 5, eof: true}, + }, + }, { + name: "trailers and content length too long", + steps: []any{ + receiveHeaders{contentLength: 10}, + receiveDataHeader{size: 5}, + receiveData{size: 5}, + wantBody{size: 5}, + receiveTrailers{}, + wantError{}, + }, + }, { + name: "unknown frame before headers", + steps: []any{ + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "unknown frame after headers", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveFrame{ + ftype: 0x1f + 0x21, // reserved frame type + data: []byte{1, 2, 3, 4}, + }, + receiveDataHeader{size: 10}, + receiveData{size: 10}, + wantBody{size: 10}, + }, + }, { + name: "invalid frame", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveFrame{ + ftype: frameTypeSettings, // not a valid frame on this stream + data: []byte{1, 2, 3, 4}, + }, + wantError{}, + }, + }, { + name: "data frame consumed by several reads", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 16}, + receiveData{size: 16}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + wantBody{size: 2}, + }, + }, { + name: "read multiple frames", + steps: []any{ + receiveHeaders{contentLength: -1}, + receiveDataHeader{size: 2}, + receiveData{size: 2}, + receiveDataHeader{size: 4}, + receiveData{size: 4}, + receiveDataHeader{size: 8}, + receiveData{size: 8}, + wantBody{size: 2}, + wantBody{size: 4}, + wantBody{size: 8}, + }, + }} { + + runTest := func(t testing.TB, h http.Header, st *testQUICStream, body func() io.ReadCloser) { + var ( + bytesSent int + bytesReceived int + ) + for _, step := range test.steps { + switch step := step.(type) { + case receiveHeaders: + header := h.Clone() + if step.contentLength != -1 { + header["content-length"] = []string{ + fmt.Sprint(step.contentLength), + } + } + st.writeHeaders(header) + case receiveDataHeader: + t.Logf("receive DATA frame header: size=%v", step.size) + st.writeVarint(int64(frameTypeData)) + st.writeVarint(step.size) + st.Flush() + case receiveData: + t.Logf("receive DATA frame content: size=%v", step.size) + for range step.size { + st.stream.stream.WriteByte(byte(bytesSent)) + bytesSent++ + } + st.Flush() + case receiveTrailers: + st.writeHeaders(http.Header{ + "x-trailer": []string{"trailer"}, + }) + case receiveFrame: + st.writeVarint(int64(step.ftype)) + st.writeVarint(int64(len(step.data))) + st.Write(step.data) + st.Flush() + case receiveEOF: + t.Logf("receive EOF on request stream") + st.stream.stream.CloseWrite() + case wantBody: + t.Logf("read %v bytes from response body", step.size) + want := make([]byte, step.size) + for i := range want { + want[i] = byte(bytesReceived) + bytesReceived++ + } + got := make([]byte, step.size) + n, err := body().Read(got) + got = got[:n] + if !bytes.Equal(got, want) { + t.Errorf("resp.Body.Read:") + t.Errorf(" got: {%x}", got) + t.Fatalf(" want: {%x}", want) + } + if err != nil { + if step.eof && err == io.EOF { + continue + } + t.Fatalf("resp.Body.Read: unexpected error %v", err) + } + if step.eof { + if n, err := body().Read([]byte{0}); n != 0 || err != io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err) + } + } + case wantError: + if n, err := body().Read([]byte{0}); n != 0 || err == nil || err == io.EOF { + t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err) + } + default: + t.Fatalf("unknown test step %T", step) + } + } + + } + + runSynctestSubtest(t, test.name+"/client", func(t testing.TB) { + tc := newTestClientConn(t) + tc.greet() + + req, _ := http.NewRequest("GET", "https://example.tld/", nil) + rt := tc.roundTrip(req) + st := tc.wantStream(streamTypeRequest) + st.wantHeaders(nil) + + header := http.Header{ + ":status": []string{"200"}, + } + runTest(t, header, st, func() io.ReadCloser { + return rt.response().Body + }) + }) + } +} diff --git a/internal/http3/roundtrip_test.go b/internal/http3/roundtrip_test.go index 533b750a5..acd8613d0 100644 --- a/internal/http3/roundtrip_test.go +++ b/internal/http3/roundtrip_test.go @@ -237,274 +237,6 @@ func TestRoundTripCrumbledCookiesInResponse(t *testing.T) { }) } -func TestRoundTripResponseBody(t *testing.T) { - // These tests consist of a series of steps, - // where each step is either something arriving on the response stream - // or the client reading from the request body. - type ( - // HEADERS frame arrives on the response stream (headers or trailers). - receiveHeaders http.Header - // DATA frame header arrives on the response stream. - receiveDataHeader struct { - size int64 - } - // DATA frame content arrives on the response stream. - receiveData struct { - size int64 - } - // Some other frame arrives on the response stream. - receiveFrame struct { - ftype frameType - data []byte - } - // Response stream closed, ending the body. - receiveEOF struct{} - // Client reads from Response.Body. - wantBody struct { - size int64 - eof bool - } - wantError struct{} - ) - for _, test := range []struct { - name string - respHeader http.Header - steps []any - wantError bool - }{{ - name: "no content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - receiveEOF{}, - wantBody{size: 10, eof: true}, - }, - }, { - name: "valid content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - receiveEOF{}, - wantBody{size: 10, eof: true}, - }, - }, { - name: "data frame exceeds content length", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"5"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantError{}, - }, - }, { - name: "data frame after all content read", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"5"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - wantBody{size: 5}, - receiveDataHeader{size: 1}, - receiveData{size: 1}, - wantError{}, - }, - }, { - name: "content length too long", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - receiveEOF{}, - wantBody{size: 5}, - wantError{}, - }, - }, { - name: "stream ended by trailers", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - receiveHeaders{ - "x-trailer": []string{"value"}, - }, - wantBody{size: 5, eof: true}, - }, - }, { - name: "trailers and content length too long", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - "content-length": []string{"10"}, - }, - receiveDataHeader{size: 5}, - receiveData{size: 5}, - wantBody{size: 5}, - receiveHeaders{ - "x-trailer": []string{"value"}, - }, - wantError{}, - }, - }, { - name: "unknown frame before headers", - steps: []any{ - receiveFrame{ - ftype: 0x1f + 0x21, // reserved frame type - data: []byte{1, 2, 3, 4}, - }, - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantBody{size: 10}, - }, - }, { - name: "unknown frame after headers", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveFrame{ - ftype: 0x1f + 0x21, // reserved frame type - data: []byte{1, 2, 3, 4}, - }, - receiveDataHeader{size: 10}, - receiveData{size: 10}, - wantBody{size: 10}, - }, - }, { - name: "invalid frame", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveFrame{ - ftype: frameTypeSettings, // not a valid frame on this stream - data: []byte{1, 2, 3, 4}, - }, - wantError{}, - }, - }, { - name: "data frame consumed by several reads", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 16}, - receiveData{size: 16}, - wantBody{size: 2}, - wantBody{size: 4}, - wantBody{size: 8}, - wantBody{size: 2}, - }, - }, { - name: "read multiple frames", - steps: []any{ - receiveHeaders{ - ":status": []string{"200"}, - }, - receiveDataHeader{size: 2}, - receiveData{size: 2}, - receiveDataHeader{size: 4}, - receiveData{size: 4}, - receiveDataHeader{size: 8}, - receiveData{size: 8}, - wantBody{size: 2}, - wantBody{size: 4}, - wantBody{size: 8}, - }, - }} { - runSynctestSubtest(t, test.name, func(t testing.TB) { - tc := newTestClientConn(t) - tc.greet() - - req, _ := http.NewRequest("GET", "https://example.tld/", nil) - rt := tc.roundTrip(req) - st := tc.wantStream(streamTypeRequest) - st.wantHeaders(nil) - - var ( - bytesSent int - bytesReceived int - ) - for _, step := range test.steps { - switch step := step.(type) { - case receiveHeaders: - st.writeHeaders(http.Header(step)) - case receiveDataHeader: - t.Logf("receive DATA frame header: size=%v", step.size) - st.writeVarint(int64(frameTypeData)) - st.writeVarint(step.size) - st.Flush() - case receiveData: - t.Logf("receive DATA frame content: size=%v", step.size) - for range step.size { - st.stream.stream.WriteByte(byte(bytesSent)) - bytesSent++ - } - st.Flush() - case receiveFrame: - st.writeVarint(int64(step.ftype)) - st.writeVarint(int64(len(step.data))) - st.Write(step.data) - st.Flush() - case receiveEOF: - t.Logf("receive EOF on request stream") - st.stream.stream.CloseWrite() - case wantBody: - t.Logf("read %v bytes from response body", step.size) - want := make([]byte, step.size) - for i := range want { - want[i] = byte(bytesReceived) - bytesReceived++ - } - got := make([]byte, step.size) - n, err := rt.response().Body.Read(got) - got = got[:n] - if !bytes.Equal(got, want) { - t.Errorf("resp.Body.Read:") - t.Errorf(" got: {%x}", got) - t.Fatalf(" want: {%x}", want) - } - if err != nil { - if step.eof && err == io.EOF { - continue - } - t.Fatalf("resp.Body.Read: unexpected error %v", err) - } - if step.eof { - if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err != io.EOF { - t.Fatalf("resp.Body.Read() = %v, %v; want io.EOF", n, err) - } - } - case wantError: - if n, err := rt.response().Body.Read([]byte{0}); n != 0 || err == nil || err == io.EOF { - t.Fatalf("resp.Body.Read() = %v, %v; want error", n, err) - } - default: - t.Fatalf("unknown test step %T", step) - } - } - }) - } -} - func TestRoundTripRequestBodySent(t *testing.T) { runSynctest(t, func(t testing.TB) { tc := newTestClientConn(t) From b73e5746f64471c22097f07593643a743e7cfb0f Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Fri, 28 Feb 2025 10:57:15 -0800 Subject: [PATCH 12/17] http2: don't log expected errors from writing invalid trailers Change-Id: I1c8af5a1f7539a25d5602a7bc8e15756d3cafa56 Reviewed-on: https://go-review.googlesource.com/c/net/+/653695 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Damien Neil <dneil@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- http2/server_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/http2/server_test.go b/http2/server_test.go index 376227e66..b27a127a5 100644 --- a/http2/server_test.go +++ b/http2/server_test.go @@ -2814,6 +2814,8 @@ func testServerWritesTrailers(t *testing.T, withFlush bool) { w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2") return nil }, func(st *serverTester) { + // Ignore errors from writing invalid trailers. + st.h1server.ErrorLog = log.New(io.Discard, "", 0) getSlash(st) st.wantHeaders(wantHeader{ streamID: 1, From aad0180cad195ab7bcd14347e7ab51bece53f61d Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Fri, 28 Feb 2025 10:57:40 -0800 Subject: [PATCH 13/17] http2: fix flakiness from t.Log when GOOS=js The http2 package uses a precursor to the experimental testing/synctest package, parsing runtime.Stack output to determine when goroutines are idle. When GOOS=js, some tests which use t.Log are flaky. t.Log blocks in the syscall package writing to stdout. The GOOS=js implementation of the syscall leaves the goroutine blocked on a channel operation, which synctest interprets as the goroutine being "durably blocked". Fix the http2 synctest to treat any goroutine blocked in the syscall package as not being durably blocked. Making this fix reveals another bug when GOOS=js: Looping while calling runtime.Gosched does not appear to permit syscalls to make progress. Add a few time.Sleep(1) calls while waiting for idleness to work around the problem. While changing things in here, change http2's synctest to not treat goroutines blocked on mutex operations as durably blocked. This matches the behavior of testing/synctest. (This would all be simpler if we just used testing/synctest, but we don't want to make the http2 package depend on an experimental API.) Fixes golang/go#67693 Change-Id: I889834e97e4a33f4ef232278b1a78af00d52d261 Reviewed-on: https://go-review.googlesource.com/c/net/+/653696 Reviewed-by: Jonathan Amsterdam <jba@google.com> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Damien Neil <dneil@google.com> --- http2/sync_test.go | 54 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 45 insertions(+), 9 deletions(-) diff --git a/http2/sync_test.go b/http2/sync_test.go index aeddbd6f3..6687202d2 100644 --- a/http2/sync_test.go +++ b/http2/sync_test.go @@ -24,9 +24,10 @@ type synctestGroup struct { } type goroutine struct { - id int - parent int - state string + id int + parent int + state string + syscall bool } // newSynctest creates a new group with the synthetic clock set the provided time. @@ -76,6 +77,14 @@ func (g *synctestGroup) Wait() { return } runtime.Gosched() + if runtime.GOOS == "js" { + // When GOOS=js, we appear to need to time.Sleep to make progress + // on some syscalls. In particular, without this sleep + // writing to stdout (including via t.Log) can block forever. + for range 10 { + time.Sleep(1) + } + } } } @@ -87,6 +96,9 @@ func (g *synctestGroup) idle() bool { if !g.gids[gr.id] && !g.gids[gr.parent] { continue } + if gr.syscall { + return false + } // From runtime/runtime2.go. switch gr.state { case "IO wait": @@ -97,9 +109,6 @@ func (g *synctestGroup) idle() bool { case "chan receive": case "chan send": case "sync.Cond.Wait": - case "sync.Mutex.Lock": - case "sync.RWMutex.RLock": - case "sync.RWMutex.Lock": default: return false } @@ -138,6 +147,10 @@ func stacks(all bool) []goroutine { panic(fmt.Errorf("3 unparsable goroutine stack:\n%s", gs)) } state, rest, ok := strings.Cut(rest, "]") + isSyscall := false + if strings.Contains(rest, "\nsyscall.") { + isSyscall = true + } var parent int _, rest, ok = strings.Cut(rest, "\ncreated by ") if ok && strings.Contains(rest, " in goroutine ") { @@ -155,9 +168,10 @@ func stacks(all bool) []goroutine { } } goroutines = append(goroutines, goroutine{ - id: id, - parent: parent, - state: state, + id: id, + parent: parent, + state: state, + syscall: isSyscall, }) } return goroutines @@ -291,3 +305,25 @@ func (tm *fakeTimer) Stop() bool { delete(tm.g.timers, tm) return stopped } + +// TestSynctestLogs verifies that t.Log works, +// in particular that the GOOS=js workaround in synctestGroup.Wait is working. +// (When GOOS=js, writing to stdout can hang indefinitely if some goroutine loops +// calling runtime.Gosched; see Wait for the workaround.) +func TestSynctestLogs(t *testing.T) { + g := newSynctest(time.Now()) + donec := make(chan struct{}) + go func() { + g.Join() + for range 100 { + t.Logf("logging a long line") + } + close(donec) + }() + g.Wait() + select { + case <-donec: + default: + panic("done") + } +} From 459513d1f8abff01b4854c93ff0bff7e87985a0a Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Thu, 27 Feb 2025 10:40:55 -0800 Subject: [PATCH 14/17] internal/http3: move more common stream processing to genericConn Move the server stream-accept loop into genericConn. (Overlooked in a previous CL.) Be more consistent about having genericConn handle errors. For golang/go#70914 Change-Id: I872673482f16539e95a1a1381ada7d3e22affb82 Reviewed-on: https://go-review.googlesource.com/c/net/+/653395 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Damien Neil <dneil@google.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> --- internal/http3/conn.go | 25 +++++++++++++++++++++---- internal/http3/server.go | 17 +++-------------- internal/http3/settings.go | 13 ++++++------- internal/http3/transport.go | 6 +++--- 4 files changed, 33 insertions(+), 28 deletions(-) diff --git a/internal/http3/conn.go b/internal/http3/conn.go index e9a58471e..5eb803115 100644 --- a/internal/http3/conn.go +++ b/internal/http3/conn.go @@ -19,7 +19,7 @@ type streamHandler interface { handlePushStream(*stream) error handleEncoderStream(*stream) error handleDecoderStream(*stream) error - handleRequestStream(*stream) + handleRequestStream(*stream) error abort(error) } @@ -43,7 +43,7 @@ func (c *genericConn) acceptStreams(qconn *quic.Conn, h streamHandler) { if st.IsReadOnly() { go c.handleUnidirectionalStream(newStream(st), h) } else { - go h.handleRequestStream(newStream(st)) + go c.handleRequestStream(newStream(st), h) } } } @@ -81,7 +81,6 @@ func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) { // but the quic package currently doesn't allow setting error codes // for STOP_SENDING frames. // TODO: Should CloseRead take an error code? - st.stream.CloseRead() err = nil } if err == io.EOF { @@ -90,8 +89,26 @@ func (c *genericConn) handleUnidirectionalStream(st *stream, h streamHandler) { message: streamType(stype).String() + " stream closed", } } - if err != nil { + c.handleStreamError(st, h, err) +} + +func (c *genericConn) handleRequestStream(st *stream, h streamHandler) { + c.handleStreamError(st, h, h.handleRequestStream(st)) +} + +func (c *genericConn) handleStreamError(st *stream, h streamHandler, err error) { + switch err := err.(type) { + case *connectionError: h.abort(err) + case nil: + st.stream.CloseRead() + st.stream.CloseWrite() + case *streamError: + st.stream.CloseRead() + st.stream.Reset(uint64(err.code)) + default: + st.stream.CloseRead() + st.stream.Reset(uint64(errH3InternalError)) } } diff --git a/internal/http3/server.go b/internal/http3/server.go index 2d8d1df22..ca93c5298 100644 --- a/internal/http3/server.go +++ b/internal/http3/server.go @@ -86,18 +86,7 @@ func newServerConn(qconn *quic.Conn) { controlStream.writeSettings() controlStream.Flush() - // Accept streams on the connection. - for { - st, err := sc.qconn.AcceptStream(context.Background()) - if err != nil { - return // connection closed - } - if st.IsReadOnly() { - go sc.handleUnidirectionalStream(newStream(st), sc) - } else { - go sc.handleRequestStream(newStream(st)) - } - } + sc.acceptStreams(sc.qconn, sc) } func (sc *serverConn) handleControlStream(st *stream) error { @@ -165,9 +154,9 @@ func (sc *serverConn) handlePushStream(*stream) error { } } -func (sc *serverConn) handleRequestStream(st *stream) { +func (sc *serverConn) handleRequestStream(st *stream) error { // TODO - return + return nil } // abort closes the connection with an error. diff --git a/internal/http3/settings.go b/internal/http3/settings.go index 45018aadd..b5e562eca 100644 --- a/internal/http3/settings.go +++ b/internal/http3/settings.go @@ -8,7 +8,6 @@ package http3 import ( "golang.org/x/net/internal/quic/quicwire" - "golang.org/x/net/quic" ) const ( @@ -39,9 +38,9 @@ func (st *stream) writeSettings(settings ...int64) { func (st *stream) readSettings(f func(settingType, value int64) error) error { frameType, err := st.readFrameHeader() if err != nil || frameType != frameTypeSettings { - return &quic.ApplicationError{ - Code: uint64(errH3MissingSettings), - Reason: "settings not sent on control stream", + return &connectionError{ + code: errH3MissingSettings, + message: "settings not sent on control stream", } } for st.lim > 0 { @@ -59,9 +58,9 @@ func (st *stream) readSettings(f func(settingType, value int64) error) error { // https://www.rfc-editor.org/rfc/rfc9114.html#section-7.2.4.1-5 switch settingsType { case 0x02, 0x03, 0x04, 0x05: - return &quic.ApplicationError{ - Code: uint64(errH3SettingsError), - Reason: "use of reserved setting", + return &connectionError{ + code: errH3SettingsError, + message: "use of reserved setting", } } diff --git a/internal/http3/transport.go b/internal/http3/transport.go index 83bc56c2b..b26524cbd 100644 --- a/internal/http3/transport.go +++ b/internal/http3/transport.go @@ -167,14 +167,14 @@ func (cc *ClientConn) handlePushStream(*stream) error { } } -func (cc *ClientConn) handleRequestStream(st *stream) { +func (cc *ClientConn) handleRequestStream(st *stream) error { // "Clients MUST treat receipt of a server-initiated bidirectional // stream as a connection error of type H3_STREAM_CREATION_ERROR [...]" // https://www.rfc-editor.org/rfc/rfc9114.html#section-6.1-3 - cc.abort(&connectionError{ + return &connectionError{ code: errH3StreamCreationError, message: "server created bidirectional stream", - }) + } } // abort closes the connection with an error. From fe7f0391aa994a401c82d829183c1efab7a64df4 Mon Sep 17 00:00:00 2001 From: Julien Cretel <jub0bsinthecloud@gmail.com> Date: Tue, 25 Feb 2025 11:40:25 +0000 Subject: [PATCH 15/17] publicsuffix: spruce up code gen and speed up PublicSuffix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rely on functions from the slices package where convenient. Drop custom max functions in favor of max builtin. Remove unused non-exported functions. Reduce the number of bounds checks. Replace calls to strings.LastIndex by calls to strings.LastIndexByte. goos: darwin goarch: amd64 pkg: golang.org/x/net/publicsuffix cpu: Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz │ old │ new │ │ sec/op │ sec/op vs base │ PublicSuffix-8 13.46µ ± 0% 13.23µ ± 0% -1.67% (p=0.000 n=20) │ old │ new │ │ B/op │ B/op vs base │ PublicSuffix-8 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=20) ¹ ¹ all samples are equal │ old │ new │ │ allocs/op │ allocs/op vs base │ PublicSuffix-8 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=20) ¹ ¹ all samples are equal Change-Id: Id72967560884d98a5c0791ccea73dbb27d120c2c GitHub-Last-Rev: 87567e7cb5b80e0e50f2c90a8266f656e99577b5 GitHub-Pull-Request: golang/net#233 Reviewed-on: https://go-review.googlesource.com/c/net/+/652236 Reviewed-by: Damien Neil <dneil@google.com> Commit-Queue: Ian Lance Taylor <iant@golang.org> Auto-Submit: Ian Lance Taylor <iant@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Michael Pratt <mpratt@google.com> --- publicsuffix/gen.go | 72 +++++++++----------------------------------- publicsuffix/list.go | 26 ++++++++-------- 2 files changed, 28 insertions(+), 70 deletions(-) diff --git a/publicsuffix/gen.go b/publicsuffix/gen.go index 7f7d08dbc..5f454e57e 100644 --- a/publicsuffix/gen.go +++ b/publicsuffix/gen.go @@ -21,6 +21,7 @@ package main import ( "bufio" "bytes" + "cmp" "encoding/binary" "flag" "fmt" @@ -29,7 +30,7 @@ import ( "net/http" "os" "regexp" - "sort" + "slices" "strings" "golang.org/x/net/idna" @@ -62,20 +63,6 @@ var ( maxLo uint32 ) -func max(a, b int) int { - if a < b { - return b - } - return a -} - -func u32max(a, b uint32) uint32 { - if a < b { - return b - } - return a -} - const ( nodeTypeNormal = 0 nodeTypeException = 1 @@ -83,18 +70,6 @@ const ( numNodeType = 3 ) -func nodeTypeStr(n int) string { - switch n { - case nodeTypeNormal: - return "+" - case nodeTypeException: - return "!" - case nodeTypeParentOnly: - return "o" - } - panic("unreachable") -} - const ( defaultURL = "https://publicsuffix.org/list/effective_tld_names.dat" gitCommitURL = "https://api.github.com/repos/publicsuffix/list/commits?path=public_suffix_list.dat" @@ -251,7 +226,7 @@ func main1() error { for label := range labelsMap { labelsList = append(labelsList, label) } - sort.Strings(labelsList) + slices.Sort(labelsList) combinedText = combineText(labelsList) if combinedText == "" { @@ -509,15 +484,13 @@ func (n *node) child(label string) *node { icann: true, } n.children = append(n.children, c) - sort.Sort(byLabel(n.children)) + slices.SortFunc(n.children, byLabel) return c } -type byLabel []*node - -func (b byLabel) Len() int { return len(b) } -func (b byLabel) Swap(i, j int) { b[i], b[j] = b[j], b[i] } -func (b byLabel) Less(i, j int) bool { return b[i].label < b[j].label } +func byLabel(a, b *node) int { + return strings.Compare(a.label, b.label) +} var nextNodesIndex int @@ -557,7 +530,7 @@ func assignIndexes(n *node) error { n.childrenIndex = len(childrenEncoding) lo := uint32(n.firstChild) hi := lo + uint32(len(n.children)) - maxLo, maxHi = u32max(maxLo, lo), u32max(maxHi, hi) + maxLo, maxHi = max(maxLo, lo), max(maxHi, hi) if lo >= 1<<childrenBitsLo { return fmt.Errorf("children lo %d is too large, or childrenBitsLo is too small", lo) } @@ -586,20 +559,6 @@ func printNodeLabel(w io.Writer, n *node) error { return nil } -func icannStr(icann bool) string { - if icann { - return "I" - } - return " " -} - -func wildcardStr(wildcard bool) string { - if wildcard { - return "*" - } - return " " -} - // combineText combines all the strings in labelsList to form one giant string. // Overlapping strings will be merged: "arpa" and "parliament" could yield // "arparliament". @@ -616,18 +575,15 @@ func combineText(labelsList []string) string { return text } -type byLength []string - -func (s byLength) Len() int { return len(s) } -func (s byLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byLength) Less(i, j int) bool { return len(s[i]) < len(s[j]) } +func byLength(a, b string) int { + return cmp.Compare(len(a), len(b)) +} // removeSubstrings returns a copy of its input with any strings removed // that are substrings of other provided strings. func removeSubstrings(input []string) []string { - // Make a copy of input. - ss := append(make([]string, 0, len(input)), input...) - sort.Sort(byLength(ss)) + ss := slices.Clone(input) + slices.SortFunc(ss, byLength) for i, shortString := range ss { // For each string, only consider strings higher than it in sort order, i.e. @@ -641,7 +597,7 @@ func removeSubstrings(input []string) []string { } // Remove the empty strings. - sort.Strings(ss) + slices.Sort(ss) for len(ss) > 0 && ss[0] == "" { ss = ss[1:] } diff --git a/publicsuffix/list.go b/publicsuffix/list.go index d56e9e762..56069d042 100644 --- a/publicsuffix/list.go +++ b/publicsuffix/list.go @@ -88,7 +88,7 @@ func PublicSuffix(domain string) (publicSuffix string, icann bool) { s, suffix, icannNode, wildcard := domain, len(domain), false, false loop: for { - dot := strings.LastIndex(s, ".") + dot := strings.LastIndexByte(s, '.') if wildcard { icann = icannNode suffix = 1 + dot @@ -129,7 +129,7 @@ loop: } if suffix == len(domain) { // If no rules match, the prevailing rule is "*". - return domain[1+strings.LastIndex(domain, "."):], icann + return domain[1+strings.LastIndexByte(domain, '.'):], icann } return domain[suffix:], icann } @@ -178,26 +178,28 @@ func EffectiveTLDPlusOne(domain string) (string, error) { if domain[i] != '.' { return "", fmt.Errorf("publicsuffix: invalid public suffix %q for domain %q", suffix, domain) } - return domain[1+strings.LastIndex(domain[:i], "."):], nil + return domain[1+strings.LastIndexByte(domain[:i], '.'):], nil } type uint32String string func (u uint32String) get(i uint32) uint32 { off := i * 4 - return (uint32(u[off])<<24 | - uint32(u[off+1])<<16 | - uint32(u[off+2])<<8 | - uint32(u[off+3])) + u = u[off:] // help the compiler reduce bounds checks + return uint32(u[3]) | + uint32(u[2])<<8 | + uint32(u[1])<<16 | + uint32(u[0])<<24 } type uint40String string func (u uint40String) get(i uint32) uint64 { off := uint64(i * (nodesBits / 8)) - return uint64(u[off])<<32 | - uint64(u[off+1])<<24 | - uint64(u[off+2])<<16 | - uint64(u[off+3])<<8 | - uint64(u[off+4]) + u = u[off:] // help the compiler reduce bounds checks + return uint64(u[4]) | + uint64(u[3])<<8 | + uint64(u[2])<<16 | + uint64(u[1])<<24 | + uint64(u[0])<<32 } From cde1dda944dcf6350753df966bb5bda87a544842 Mon Sep 17 00:00:00 2001 From: Damien Neil <dneil@google.com> Date: Tue, 21 Jan 2025 16:36:50 -0800 Subject: [PATCH 16/17] proxy, http/httpproxy: do not mismatch IPv6 zone ids against hosts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When matching against a host "example.com", don't match an IPv6 address like "[1000::1%25.example.com]:80". Thanks to Juho Forsén of Mattermost for reporting this issue. Fixes CVE-2025-22870 For #71984 Change-Id: I0c4fdf18765decc27e6ddf220ebe3a9bf4a6454d Reviewed-on: https://go-review.googlesource.com/c/net/+/654697 Auto-Submit: Roland Shoemaker <roland@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Commit-Queue: Roland Shoemaker <roland@golang.org> Reviewed-by: Roland Shoemaker <roland@golang.org> Reviewed-by: Damien Neil <dneil@google.com> --- http/httpproxy/proxy.go | 10 ++- http/httpproxy/proxy_test.go | 7 ++ proxy/per_host.go | 8 +- proxy/per_host_test.go | 158 ++++++++++++++++++++++++----------- 4 files changed, 131 insertions(+), 52 deletions(-) diff --git a/http/httpproxy/proxy.go b/http/httpproxy/proxy.go index 6404aaf15..d89c257ae 100644 --- a/http/httpproxy/proxy.go +++ b/http/httpproxy/proxy.go @@ -14,6 +14,7 @@ import ( "errors" "fmt" "net" + "net/netip" "net/url" "os" "strings" @@ -177,8 +178,10 @@ func (cfg *config) useProxy(addr string) bool { if host == "localhost" { return false } - ip := net.ParseIP(host) - if ip != nil { + nip, err := netip.ParseAddr(host) + var ip net.IP + if err == nil { + ip = net.IP(nip.AsSlice()) if ip.IsLoopback() { return false } @@ -360,6 +363,9 @@ type domainMatch struct { } func (m domainMatch) match(host, port string, ip net.IP) bool { + if ip != nil { + return false + } if strings.HasSuffix(host, m.host) || (m.matchHost && host == m.host[1:]) { return m.port == "" || m.port == port } diff --git a/http/httpproxy/proxy_test.go b/http/httpproxy/proxy_test.go index 790afdab7..a1dd2e83f 100644 --- a/http/httpproxy/proxy_test.go +++ b/http/httpproxy/proxy_test.go @@ -211,6 +211,13 @@ var proxyForURLTests = []proxyForURLTest{{ }, req: "http://www.xn--fsq092h.com", want: "<nil>", +}, { + cfg: httpproxy.Config{ + NoProxy: "example.com", + HTTPProxy: "proxy", + }, + req: "http://[1000::%25.example.com]:123", + want: "http://proxy", }, } diff --git a/proxy/per_host.go b/proxy/per_host.go index d7d4b8b6e..32bdf435e 100644 --- a/proxy/per_host.go +++ b/proxy/per_host.go @@ -7,6 +7,7 @@ package proxy import ( "context" "net" + "net/netip" "strings" ) @@ -57,7 +58,8 @@ func (p *PerHost) DialContext(ctx context.Context, network, addr string) (c net. } func (p *PerHost) dialerForRequest(host string) Dialer { - if ip := net.ParseIP(host); ip != nil { + if nip, err := netip.ParseAddr(host); err == nil { + ip := net.IP(nip.AsSlice()) for _, net := range p.bypassNetworks { if net.Contains(ip) { return p.bypass @@ -108,8 +110,8 @@ func (p *PerHost) AddFromString(s string) { } continue } - if ip := net.ParseIP(host); ip != nil { - p.AddIP(ip) + if nip, err := netip.ParseAddr(host); err == nil { + p.AddIP(net.IP(nip.AsSlice())) continue } if strings.HasPrefix(host, "*.") { diff --git a/proxy/per_host_test.go b/proxy/per_host_test.go index 0447eb427..b7bcec8ae 100644 --- a/proxy/per_host_test.go +++ b/proxy/per_host_test.go @@ -7,8 +7,9 @@ package proxy import ( "context" "errors" + "fmt" "net" - "reflect" + "slices" "testing" ) @@ -22,55 +23,118 @@ func (r *recordingProxy) Dial(network, addr string) (net.Conn, error) { } func TestPerHost(t *testing.T) { - expectedDef := []string{ - "example.com:123", - "1.2.3.4:123", - "[1001::]:123", - } - expectedBypass := []string{ - "localhost:123", - "zone:123", - "foo.zone:123", - "127.0.0.1:123", - "10.1.2.3:123", - "[1000::]:123", - } - - t.Run("Dial", func(t *testing.T) { - var def, bypass recordingProxy - perHost := NewPerHost(&def, &bypass) - perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") - for _, addr := range expectedDef { - perHost.Dial("tcp", addr) + for _, test := range []struct { + config string // passed to PerHost.AddFromString + nomatch []string // addrs using the default dialer + match []string // addrs using the bypass dialer + }{{ + config: "localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16", + nomatch: []string{ + "example.com:123", + "1.2.3.4:123", + "[1001::]:123", + }, + match: []string{ + "localhost:123", + "zone:123", + "foo.zone:123", + "127.0.0.1:123", + "10.1.2.3:123", + "[1000::]:123", + "[1000::%25.example.com]:123", + }, + }, { + config: "localhost", + nomatch: []string{ + "127.0.0.1:80", + }, + match: []string{ + "localhost:80", + }, + }, { + config: "*.zone", + nomatch: []string{ + "foo.com:80", + }, + match: []string{ + "foo.zone:80", + "foo.bar.zone:80", + }, + }, { + config: "1.2.3.4", + nomatch: []string{ + "127.0.0.1:80", + "11.2.3.4:80", + }, + match: []string{ + "1.2.3.4:80", + }, + }, { + config: "10.0.0.0/24", + nomatch: []string{ + "10.0.1.1:80", + }, + match: []string{ + "10.0.0.1:80", + "10.0.0.255:80", + }, + }, { + config: "fe80::/10", + nomatch: []string{ + "[fec0::1]:80", + "[fec0::1%en0]:80", + }, + match: []string{ + "[fe80::1]:80", + "[fe80::1%en0]:80", + }, + }, { + // We don't allow zone IDs in network prefixes, + // so this config matches nothing. + config: "fe80::%en0/10", + nomatch: []string{ + "[fec0::1]:80", + "[fec0::1%en0]:80", + "[fe80::1]:80", + "[fe80::1%en0]:80", + "[fe80::1%en1]:80", + }, + }} { + for _, addr := range test.match { + testPerHost(t, test.config, addr, true) } - for _, addr := range expectedBypass { - perHost.Dial("tcp", addr) + for _, addr := range test.nomatch { + testPerHost(t, test.config, addr, false) } + } +} - if !reflect.DeepEqual(expectedDef, def.addrs) { - t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) - } - if !reflect.DeepEqual(expectedBypass, bypass.addrs) { - t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) - } - }) +func testPerHost(t *testing.T, config, addr string, wantMatch bool) { + name := fmt.Sprintf("config %q, dial %q", config, addr) - t.Run("DialContext", func(t *testing.T) { - var def, bypass recordingProxy - perHost := NewPerHost(&def, &bypass) - perHost.AddFromString("localhost,*.zone,127.0.0.1,10.0.0.1/8,1000::/16") - for _, addr := range expectedDef { - perHost.DialContext(context.Background(), "tcp", addr) - } - for _, addr := range expectedBypass { - perHost.DialContext(context.Background(), "tcp", addr) - } + var def, bypass recordingProxy + perHost := NewPerHost(&def, &bypass) + perHost.AddFromString(config) + perHost.Dial("tcp", addr) - if !reflect.DeepEqual(expectedDef, def.addrs) { - t.Errorf("Hosts which went to the default proxy didn't match. Got %v, want %v", def.addrs, expectedDef) - } - if !reflect.DeepEqual(expectedBypass, bypass.addrs) { - t.Errorf("Hosts which went to the bypass proxy didn't match. Got %v, want %v", bypass.addrs, expectedBypass) - } - }) + // Dial and DialContext should have the same results. + var defc, bypassc recordingProxy + perHostc := NewPerHost(&defc, &bypassc) + perHostc.AddFromString(config) + perHostc.DialContext(context.Background(), "tcp", addr) + if !slices.Equal(def.addrs, defc.addrs) { + t.Errorf("%v: Dial default=%v, bypass=%v; DialContext default=%v, bypass=%v", name, def.addrs, bypass.addrs, defc.addrs, bypass.addrs) + return + } + + if got, want := slices.Concat(def.addrs, bypass.addrs), []string{addr}; !slices.Equal(got, want) { + t.Errorf("%v: dialed %q, want %q", name, got, want) + return + } + + gotMatch := len(bypass.addrs) > 0 + if gotMatch != wantMatch { + t.Errorf("%v: matched=%v, want %v", name, gotMatch, wantMatch) + return + } } From 85d1d54551b68719346cb9fec24b911da4e452a1 Mon Sep 17 00:00:00 2001 From: Gopher Robot <gobot@golang.org> Date: Tue, 4 Mar 2025 11:17:41 -0800 Subject: [PATCH 17/17] go.mod: update golang.org/x dependencies Update golang.org/x dependencies to their latest tagged versions. Change-Id: Id043663bf74d33d77fcea718ff308fa9461f242b Reviewed-on: https://go-review.googlesource.com/c/net/+/654320 Reviewed-by: Roland Shoemaker <roland@golang.org> LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Auto-Submit: Gopher Robot <gobot@golang.org> Auto-Submit: Junyang Shao <shaojunyang@google.com> Reviewed-by: Michael Pratt <mpratt@google.com> --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 162f7073e..37aac27a6 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module golang.org/x/net go 1.23.0 require ( - golang.org/x/crypto v0.33.0 + golang.org/x/crypto v0.35.0 golang.org/x/sys v0.30.0 golang.org/x/term v0.29.0 golang.org/x/text v0.22.0 diff --git a/go.sum b/go.sum index 553516bb0..5f95431df 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= -golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= +golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= +golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=