// Copyright 2015 The etcd Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package srv looks up DNS SRV records. package srv import ( "fmt" "net" "net/url" "strings" "go.etcd.io/etcd/pkg/types" ) var ( // indirection for testing lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict resolveTCPAddr = net.ResolveTCPAddr ) // GetCluster gets the cluster information via DNS discovery. // Also sees each entry as a separate instance. func GetCluster(serviceScheme, service, name, dns string, apurls types.URLs) ([]string, error) { tempName := int(0) tcp2ap := make(map[string]url.URL) // First, resolve the apurls for _, url := range apurls { tcpAddr, err := resolveTCPAddr("tcp", url.Host) if err != nil { return nil, err } tcp2ap[tcpAddr.String()] = url } stringParts := []string{} updateNodeMap := func(service, scheme string) error { _, addrs, err := lookupSRV(service, "tcp", dns) if err != nil { return err } for _, srv := range addrs { port := fmt.Sprintf("%d", srv.Port) host := net.JoinHostPort(srv.Target, port) tcpAddr, terr := resolveTCPAddr("tcp", host) if terr != nil { err = terr continue } n := "" url, ok := tcp2ap[tcpAddr.String()] if ok { n = name } if n == "" { n = fmt.Sprintf("%d", tempName) tempName++ } // SRV records have a trailing dot but URL shouldn't. shortHost := strings.TrimSuffix(srv.Target, ".") urlHost := net.JoinHostPort(shortHost, port) if ok && url.Scheme != scheme { err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) } else { stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost)) } } if len(stringParts) == 0 { return err } return nil } err := updateNodeMap(service, serviceScheme) if err != nil { return nil, fmt.Errorf("error querying DNS SRV records for _%s %s", service, err) } return stringParts, nil } type SRVClients struct { Endpoints []string SRVs []*net.SRV } // GetClient looks up the client endpoints for a service and domain. func GetClient(service, domain string, serviceName string) (*SRVClients, error) { var urls []*url.URL var srvs []*net.SRV updateURLs := func(service, scheme string) error { _, addrs, err := lookupSRV(service, "tcp", domain) if err != nil { return err } for _, srv := range addrs { urls = append(urls, &url.URL{ Scheme: scheme, Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), }) } srvs = append(srvs, addrs...) return nil } errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https") errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http") if errHTTPS != nil && errHTTP != nil { return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) } endpoints := make([]string, len(urls)) for i := range urls { endpoints[i] = urls[i].String() } return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil } // GetSRVService generates a SRV service including an optional suffix. func GetSRVService(service, serviceName string, scheme string) (SRVService string) { if scheme == "https" { service = fmt.Sprintf("%s-ssl", service) } if serviceName != "" { return fmt.Sprintf("%s-%s", service, serviceName) } return service }