Skip to content

Commit

Permalink
httpresolver: Always use the fallback resolver
Browse files Browse the repository at this point in the history
When the httpresolver needs to resolve the name of our upstream server
(e.g. `dns.google.`), it uses the system resolver.

Normally the system resover is either configured to work fine
independently, or will point back to dnss, in which case dnss will use
`-fallback_domains` to help identify these recursive requests and use
`-fallback_upstream` to get an answer.

However, in some cases, the system resolver won't work properly or send
some requests to dnss that don't fall within `-fallback_domains`. An
example of this was reported in
#9, where `systemd-resolved`
does some DNSSEC resolutions to `.` and `google.` causing it to
mis-behave.

This patch makes dnss always use the fallback resolver, without falling
back to the system resolver at all. This should result in more
predictible behaviour and simpler setups, as now dnss should be fully
independent from the system resolver.

Thanks to David Mandelberg ([email protected]) for finding and
helping debug this issue.
  • Loading branch information
albertito committed Mar 5, 2021
1 parent c98400d commit 5567591
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 122 deletions.
38 changes: 4 additions & 34 deletions dnss.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@ import (
"fmt"
"net/http"
"net/url"
"strings"
"sync"

"golang.org/x/net/http/httpproxy"

"blitiri.com.ar/go/dnss/internal/dnsserver"
"blitiri.com.ar/go/dnss/internal/httpresolver"
"blitiri.com.ar/go/dnss/internal/httpserver"
Expand All @@ -38,10 +35,8 @@ var (
"DNS server to forward unqualified requests to")

fallbackUpstream = flag.String("fallback_upstream", "8.8.8.8:53",
"DNS server to resolve domains in --fallback_domains")
fallbackDomains = flag.String("fallback_domains", "dns.google.",
"Domains we resolve via DNS, using --fallback_upstream"+
" (space-separated list)")
"DNS server used to resolve domains in -https_upstream"+
" (including proxy if needed)")

enableDNStoHTTPS = flag.Bool("enable_dns_to_https", false,
"enable DNS-to-HTTPS proxy")
Expand Down Expand Up @@ -74,6 +69,7 @@ var (
_ = flag.Duration("log_flush_every", 0, "deprecated, will be removed")
_ = flag.Bool("logtostderr", false, "deprecated, will be removed")
_ = flag.String("force_mode", "", "deprecated, will be removed")
_ = flag.String("fallback_domains", "", "deprecated, will be removed")
)

func main() {
Expand Down Expand Up @@ -101,7 +97,7 @@ func main() {
}

var resolver dnsserver.Resolver
resolver = httpresolver.NewDoH(upstream, *httpsClientCAFile)
resolver = httpresolver.NewDoH(upstream, *httpsClientCAFile, *fallbackUpstream)

if *enableCache {
cr := dnsserver.NewCachingResolver(resolver)
Expand All @@ -110,15 +106,6 @@ func main() {
}
dth := dnsserver.New(*dnsListenAddr, resolver, *dnsUnqualifiedUpstream)

// If we're using an HTTP proxy, add the name to the fallback domain
// so we don't have problems resolving it.
fallbackDoms := strings.Split(*fallbackDomains, " ")
if proxyDomain := proxyServerDomain(); proxyDomain != "" {
log.Infof("Adding proxy %q to fallback domains", proxyDomain)
fallbackDoms = append(fallbackDoms, proxyDomain)
}

dth.SetFallback(*fallbackUpstream, fallbackDoms)
wg.Add(1)
go func() {
defer wg.Done()
Expand Down Expand Up @@ -146,23 +133,6 @@ func main() {
wg.Wait()
}

// proxyServerDomain checks if we're using an HTTP proxy server, and if so
// returns its domain.
func proxyServerDomain() string {
url, err := url.Parse(*httpsUpstream)
if err != nil {
return ""
}

proxyFunc := httpproxy.FromEnvironment().ProxyFunc()
proxyURL, err := proxyFunc(url)
if err != nil || proxyURL == nil {
return ""
}

return proxyURL.Hostname()
}

func launchMonitoringServer(addr string) {
log.Infof("Monitoring HTTP server listening on %s", addr)

Expand Down
41 changes: 5 additions & 36 deletions dnss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ func Setup(tb testing.TB) string {
tb.Fatalf("invalid URL: %v", err)
}

r := httpresolver.NewDoH(HTTPSToDNSURL, "")
// Create the DoH resolver and DNS server backed by it.
// Note that we use an invalid address as fallback resolver - since we use
// IP addresses directly in the http requests, the fallback resolver
// should not be needed.
r := httpresolver.NewDoH(HTTPSToDNSURL, "", "0.0.0.0:0")
dtoh := dnsserver.New(DNSToHTTPSAddr, r, "")
go dtoh.ListenAndServe()

Expand Down Expand Up @@ -185,41 +189,6 @@ func BenchmarkSimple(b *testing.B) {
/////////////////////////////////////////////////////////////////////
// Tests for main-specific helpers

func TestProxyServerDomain(t *testing.T) {
prevProxy, wasSet := os.LookupEnv("HTTPS_PROXY")

// Valid case, proxy set.
os.Setenv("HTTPS_PROXY", "http://proxy:1234/p")
*httpsUpstream = "https://montoto/xyz"
if got := proxyServerDomain(); got != "proxy" {
t.Errorf("got %q, expected 'proxy'", got)
}

// Valid case, proxy not set.
os.Unsetenv("HTTPS_PROXY")
*httpsUpstream = "https://montoto/xyz"
if got := proxyServerDomain(); got != "" {
t.Errorf("got %q, expected ''", got)
}

// Invalid upstream URL.
*httpsUpstream = "in%20valid:url"
if got := proxyServerDomain(); got != "" {
t.Errorf("got %q, expected ''", got)
}

// Invalid proxy.
os.Setenv("HTTPS_PROXY", "invalid value")
*httpsUpstream = "https://montoto/xyz"
if got := proxyServerDomain(); got != "" {
t.Errorf("got %q, expected ''", got)
}

if wasSet {
os.Setenv("HTTPS_PROXY", prevProxy)
}
}

func TestMonitoringServer(t *testing.T) {
addr := testutil.GetFreePort()
launchMonitoringServer(addr)
Expand Down
33 changes: 3 additions & 30 deletions internal/dnsserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,15 @@ type Server struct {
Addr string
unqUpstream string
resolver Resolver

fallbackDomains map[string]struct{}
fallbackUpstream string
}

// New *Server, which will listen on addr, use resolver as the backend
// resolver, and use unqUpstream to resolve unqualified queries.
func New(addr string, resolver Resolver, unqUpstream string) *Server {
return &Server{
Addr: addr,
resolver: resolver,
unqUpstream: unqUpstream,
fallbackDomains: map[string]struct{}{},
}
}

// SetFallback upstream server for the given domains.
func (s *Server) SetFallback(upstream string, domains []string) {
s.fallbackUpstream = upstream
for _, d := range domains {
s.fallbackDomains[d] = struct{}{}
Addr: addr,
resolver: resolver,
unqUpstream: unqUpstream,
}
}

Expand Down Expand Up @@ -109,21 +97,6 @@ func (s *Server) Handler(w dns.ResponseWriter, r *dns.Msg) {
return
}

// Forward to the fallback server if the domain is on our list.
if _, ok := s.fallbackDomains[r.Question[0].Name]; ok {
u, err := dns.Exchange(r, s.fallbackUpstream)
if err == nil {
tr.LazyPrintf("used fallback upstream (%s)", s.fallbackUpstream)
util.TraceAnswer(tr, u)
w.WriteMsg(u)
} else {
tr.LazyPrintf("fallback upstream error: %v", err)
dns.HandleFailed(w, r)
}

return
}

// Create our own IDs, in case different users pick the same id and we
// pass that upstream.
oldid := r.Id
Expand Down
11 changes: 0 additions & 11 deletions internal/dnsserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,12 @@ func TestServe(t *testing.T) {
go testutil.ServeTestDNSServer(unqUpstreamAddr,
testutil.MakeStaticHandler(t, "unq. A 2.2.2.2"))

fallbackAddr := testutil.GetFreePort()
go testutil.ServeTestDNSServer(fallbackAddr,
testutil.MakeStaticHandler(t, "fallback. A 3.3.3.3"))

srv := New(testutil.GetFreePort(), res, unqUpstreamAddr)
srv.SetFallback(fallbackAddr, []string{"one.fallback.", "two.fallback."})
go srv.ListenAndServe()
testutil.WaitForDNSServer(srv.Addr)

query(t, srv.Addr, "response.test.", "1.1.1.1")
query(t, srv.Addr, "unqualified.", "2.2.2.2")
query(t, srv.Addr, "one.fallback.", "3.3.3.3")
query(t, srv.Addr, "two.fallback.", "3.3.3.3")
}

func query(t *testing.T, srv, domain, expected string) {
Expand All @@ -56,17 +49,13 @@ func TestBadUpstreams(t *testing.T) {
// Get addresses but don't start the servers, so we get an error when
// trying to reach them.
unqUpstreamAddr := testutil.GetFreePort()
fallbackAddr := testutil.GetFreePort()

srv := New(testutil.GetFreePort(), res, unqUpstreamAddr)
srv.SetFallback(fallbackAddr, []string{"one.fallback.", "two.fallback."})
go srv.ListenAndServe()
testutil.WaitForDNSServer(srv.Addr)

queryFailure(t, srv.Addr, "response.test.")
queryFailure(t, srv.Addr, "unqualified.")
queryFailure(t, srv.Addr, "one.fallback.")
queryFailure(t, srv.Addr, "two.fallback.")
}

func queryFailure(t *testing.T, srv, domain string) {
Expand Down
26 changes: 24 additions & 2 deletions internal/httpresolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package httpresolver

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"fmt"
Expand All @@ -28,6 +29,10 @@ type httpsResolver struct {
CAFile string
tlsConfig *tls.Config

// net.Resolver that will contact the server at --fallback_upstream for
// DNS resolutions.
fallbackResolver *net.Resolver

mu sync.Mutex
client *http.Client
firstErr time.Time
Expand All @@ -51,11 +56,27 @@ func loadCertPool(caFile string) (*x509.CertPool, error) {

// NewDoH creates a new DoH resolver, which uses the given upstream
// URL to resolve queries.
func NewDoH(upstream *url.URL, caFile string) *httpsResolver {
return &httpsResolver{
func NewDoH(upstream *url.URL, caFile, fallback string) *httpsResolver {
r := &httpsResolver{
Upstream: upstream,
CAFile: caFile,
}

if fallback != "" {
// Dial function that will always use the fallback address to contact
// DNS.
dialer := net.Dialer{}
dialFallback := func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialContext(ctx, network, fallback)
}

r.fallbackResolver = &net.Resolver{
PreferGo: true, // Avoid the system resolver.
Dial: dialFallback,
}
}

return r
}

func (r *httpsResolver) Init() error {
Expand Down Expand Up @@ -101,6 +122,7 @@ func (r *httpsResolver) newClient() (*http.Client, error) {
Timeout: 10 * time.Second,
KeepAlive: 1 * time.Second,
DualStack: true,
Resolver: r.fallbackResolver,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
Expand Down
33 changes: 24 additions & 9 deletions tests/external.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ function dnss() {
PID=$!
}

# Run minidns in the background (sets $MINIDNS_PID to its process id).
function minidns() {
go run tests/minidns.go \
-addr ":1953" \
-zones tests/testzones \
> .minidns.log 2>&1 &
MINIDNS_PID=$!
}

# Wait until there's something listening on the given port.
function wait_until_ready() {
PROTO=$1
Expand Down Expand Up @@ -88,11 +97,11 @@ function get() {
}

function generate_certs() {
mkdir -p .certs/localhost
mkdir -p .certs/$1
(
cd .certs/localhost
cd .certs/$1
go run ../../tests/generate_cert.go \
-ca -duration=1h --host=localhost
-ca -duration=1h --host=$1
)
}

Expand All @@ -105,6 +114,9 @@ if wait $PID; then
exit 1
fi

echo "## Launching minidns for testing"
minidns
wait_until_ready tcp 1953

echo "## Launching HTTPS server"
dnss -enable_https_to_dns \
Expand All @@ -126,7 +138,8 @@ fi

echo "## DoH against dnss"
dnss -enable_dns_to_https -dns_listen_addr "localhost:1053" \
-https_upstream "http://localhost:1999/dns-query"
-fallback_upstream "127.0.0.1:1953" \
-https_upstream "http://upstream:1999/dns-query"

# Exercise DoH via GET (dnss always uses POST).
get "http://localhost:1999/resolve?&dns=q80BAAABAAAAAAAAA3d3dwdleGFtcGxlA2NvbQAAAQAB"
Expand All @@ -148,23 +161,25 @@ kill $HTTP_PID


echo "## HTTPS with custom certificates"
generate_certs
generate_certs upstream
dnss -enable_https_to_dns \
-https_key .certs/localhost/privkey.pem \
-https_cert .certs/localhost/fullchain.pem \
-https_key .certs/upstream/privkey.pem \
-https_cert .certs/upstream/fullchain.pem \
-https_server_addr "localhost:1999"
HTTP_PID=$PID
mv .dnss.log .dnss.http.log
wait_until_ready tcp 1999

dnss -enable_dns_to_https -dns_listen_addr "localhost:1053" \
-https_client_cafile .certs/localhost/fullchain.pem \
-https_upstream "https://localhost:1999/dns-query"
-fallback_upstream "127.0.0.1:1953" \
-https_client_cafile .certs/upstream/fullchain.pem \
-https_upstream "https://upstream:1999/dns-query"

resolve

kill $PID
kill $HTTP_PID
kill $MINIDNS_PID


# DoH integration test against some publicly available servers.
Expand Down
Loading

0 comments on commit 5567591

Please sign in to comment.