diff --git a/Makefile b/Makefile index e6e808d4..dd36fff6 100644 --- a/Makefile +++ b/Makefile @@ -66,8 +66,8 @@ vet: # Generate code generate: controller-gen mockgen manifests - go generate ./... $(CONTROLLER_GEN) object paths="./..." + go generate ./... .PHONY: controller-gen controller-gen: $(CONTROLLER_GEN) diff --git a/api/v1/clusterwidenetworkpolicy_types.go b/api/v1/clusterwidenetworkpolicy_types.go index 7a7b32c5..f5362514 100644 --- a/api/v1/clusterwidenetworkpolicy_types.go +++ b/api/v1/clusterwidenetworkpolicy_types.go @@ -151,6 +151,14 @@ type FQDNSelector struct { MatchPattern string `json:"matchPattern,omitempty"` } +// IPSet stores set name association to IP addresses +// type IPSet struct { +// FQDN string `json:"fqdn,omitempty"` +// SetName string `json:"setName,omitempty"` +// IPs map[string]metav1.Time `json:"ips,omitempty"` +// Version IPVersion `json:"version,omitempty"` +// } + // IPSet stores set name association to IP addresses type IPSet struct { FQDN string `json:"fqdn,omitempty"` diff --git a/pkg/dns/dnscache.go b/pkg/dns/dnscache.go index aa5be334..bb4ba940 100644 --- a/pkg/dns/dnscache.go +++ b/pkg/dns/dnscache.go @@ -4,10 +4,7 @@ import ( "crypto/md5" //nolint:gosec "encoding/hex" "fmt" - "math" - "net" "regexp" - "sort" "strings" "sync" "time" @@ -16,7 +13,7 @@ import ( "github.com/go-logr/logr" "github.com/google/nftables" dnsgo "github.com/miekg/dns" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + // metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" firewallv1 "github.com/metal-stack/firewall-controller/v2/api/v1" ) @@ -42,31 +39,23 @@ type RenderIPSet struct { } type ipEntry struct { - ips []string - expirationTime time.Time - setName string + // ips is a map of the ip address and its expiration time which is the time of the DNS lookup + the TTL + ips map[string]time.Time + setName string } -func newIPEntry(setName string, expirationTime time.Time) *ipEntry { +func newIPEntry(setName string) *ipEntry { return &ipEntry{ - expirationTime: expirationTime, - setName: setName, + setName: setName, + ips: map[string]time.Time{}, } } -func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error { - newIPs, deletedIPs := e.getNewAndDeletedIPs(ips) - if !e.expirationTime.After(time.Now()) { - e.expirationTime = expirationTime - } +func (e *ipEntry) update(log logr.Logger, setName string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error { + deletedIPs := e.expireIPs() + newIPs := e.addAndUpdateIPs(log, rrs, lookupTime) if newIPs != nil || deletedIPs != nil { - e.ips = make([]string, len(ips)) - for i, ip := range ips { - e.ips[i] = ip.String() - } - sort.Strings(e.ips) - if err := updateNftSet(newIPs, deletedIPs, setName, dtype); err != nil { return fmt.Errorf("failed to update nft set: %w", err) } @@ -75,27 +64,32 @@ func (e *ipEntry) update(setName string, ips []net.IP, expirationTime time.Time, return nil } -func (e *ipEntry) getNewAndDeletedIPs(ips []net.IP) (newIPs, deletedIPs []nftables.SetElement) { - currentIps := make(map[string]bool, len(e.ips)) - for _, ip := range e.ips { - currentIps[ip] = false - } - - for _, ip := range ips { - s := ip.String() - if _, ok := currentIps[s]; ok { - currentIps[s] = true - } else { - newIPs = append(newIPs, nftables.SetElement{Key: ip}) +func (e *ipEntry) expireIPs() (deletedIPs []nftables.SetElement) { + for ip, expirationTime := range e.ips { + if expirationTime.Before(time.Now()) { + deletedIPs = append(deletedIPs, nftables.SetElement{Key: []byte(ip)}) + delete(e.ips, ip) } } + return +} - for ip, exists := range currentIps { - if !exists { - deletedIPs = append(deletedIPs, nftables.SetElement{Key: net.ParseIP(ip)}) +func (e *ipEntry) addAndUpdateIPs(log logr.Logger, rrs []dnsgo.RR, lookupTime time.Time) (newIPs []nftables.SetElement) { + for _, rr := range rrs { + var s string + switch r := rr.(type) { + case *dnsgo.A: + s = r.A.String() + case *dnsgo.AAAA: + s = r.AAAA.String() } - } + if _, ok := e.ips[s]; !ok { + newIPs = append(newIPs, nftables.SetElement{Key: []byte(s)}) + } + log.WithValues("ip", s, "rr header ttl", rr.Header().Ttl, "expiration time", lookupTime.Add(time.Duration(rr.Header().Ttl)*time.Second)) + e.ips[s] = lookupTime.Add(time.Duration(rr.Header().Ttl) * time.Second) + } return } @@ -197,9 +191,17 @@ func (c *DNSCache) restoreSets(fqdnSets []firewallv1.IPSet) { } ipe := &ipEntry{ - ips: s.IPs, - expirationTime: s.ExpirationTime.Time, - setName: s.SetName, + setName: s.SetName, + } + for _, ip := range s.IPs { + ipa, _, _ := strings.Cut(ip, ",") + expirationTime := time.Now() + if _, ets, found := strings.Cut(ip, ": "); found { + if err := expirationTime.UnmarshalText([]byte(ets)); err != nil { + expirationTime = time.Now() + } + } + ipe.ips[ipa] = expirationTime } switch s.Version { case firewallv1.IPv4: @@ -311,10 +313,8 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq return true, fmt.Errorf("too many hops, fqdn chain: %s", strings.Join(fqdns, ",")) } - ipv4 := []net.IP{} - ipv6 := []net.IP{} - minIPv4TTL := uint32(math.MaxUint32) - minIPv6TTL := uint32(math.MaxUint32) + ipv4 := []dnsgo.RR{} + ipv6 := []dnsgo.RR{} found := false for _, ans := range msg.Answer { @@ -326,17 +326,11 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq switch rr := ans.(type) { case *dnsgo.A: - ipv4 = append(ipv4, rr.A) - if minIPv4TTL > rr.Hdr.Ttl { - minIPv4TTL = rr.Hdr.Ttl - } + ipv4 = append(ipv4, rr) found = true c.log.V(4).Info("DEBUG dnscache Update function A record found", "IPs", ipv4) case *dnsgo.AAAA: - ipv6 = append(ipv6, rr.AAAA) - if minIPv6TTL > rr.Hdr.Ttl { - minIPv6TTL = rr.Hdr.Ttl - } + ipv6 = append(ipv6, rr) found = true c.log.V(4).Info("DEBUG dnscache Update function AAAA record found", "IPs", ipv6) case *dnsgo.CNAME: @@ -362,12 +356,12 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq for _, fqdn := range fqdns { c.log.V(4).Info("DEBUG dnscache Update function Updating DNS cache for", "fqdn", fqdn, "ipv4", ipv4, "ipv6", ipv6) if c.ipv4Enabled && len(ipv4) > 0 { - if err := c.updateIPEntry(fqdn, ipv4, lookupTime.Add(time.Duration(minIPv4TTL)), nftables.TypeIPAddr); err != nil { + if err := c.updateIPEntry(fqdn, ipv4, lookupTime, nftables.TypeIPAddr); err != nil { return false, fmt.Errorf("failed to update IPv4 addresses: %w", err) } } if c.ipv6Enabled && len(ipv6) > 0 { - if err := c.updateIPEntry(fqdn, ipv6, lookupTime.Add(time.Duration(minIPv6TTL)), nftables.TypeIP6Addr); err != nil { + if err := c.updateIPEntry(fqdn, ipv6, lookupTime, nftables.TypeIP6Addr); err != nil { return false, fmt.Errorf("failed to update IPv6 addresses: %w", err) } } @@ -376,10 +370,10 @@ func (c *DNSCache) Update(lookupTime time.Time, qname string, msg *dnsgo.Msg, fq return found, nil } -func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time.Time, dtype nftables.SetDatatype) error { +func (c *DNSCache) updateIPEntry(qname string, rrs []dnsgo.RR, lookupTime time.Time, dtype nftables.SetDatatype) error { scopedLog := c.log.WithValues( "fqdn", qname, - "ip_len", len(ips), + "ip_len", len(rrs), "dtype", dtype.Name, ) @@ -396,21 +390,22 @@ func (c *DNSCache) updateIPEntry(qname string, ips []net.IP, expirationTime time case nftables.TypeIPAddr: if entry.ipv4 == nil { setName := c.createSetName(qname, dtype.Name, 0) - ipe = newIPEntry(setName, expirationTime) + ipe = newIPEntry(setName) entry.ipv4 = ipe } ipe = entry.ipv4 case nftables.TypeIP6Addr: if entry.ipv6 == nil { setName := c.createSetName(qname, dtype.Name, 0) - ipe = newIPEntry(setName, expirationTime) + ipe = newIPEntry(setName) entry.ipv6 = ipe } ipe = entry.ipv6 } setName := ipe.setName - if err := ipe.update(setName, ips, expirationTime, dtype); err != nil { + scopedLog.WithValues("set", setName, "lookupTime", lookupTime, "rrs", rrs).Info("updating ip entry") + if err := ipe.update(scopedLog, setName, rrs, lookupTime, dtype); err != nil { return fmt.Errorf("failed to update ipEntry: %w", err) } c.fqdnToEntry[qname] = entry @@ -478,19 +473,29 @@ func updateNftSet( } func createIPSetFromIPEntry(fqdn string, version firewallv1.IPVersion, entry *ipEntry) firewallv1.IPSet { - return firewallv1.IPSet{ - FQDN: fqdn, - SetName: entry.setName, - IPs: entry.ips, - ExpirationTime: metav1.Time{Time: entry.expirationTime}, - Version: version, + ips := firewallv1.IPSet{ + FQDN: fqdn, + SetName: entry.setName, + IPs: []string{}, + Version: version, } + for ip, expirationTime := range entry.ips { + if et, err := expirationTime.MarshalText(); err == nil { + ip = ip + ", expiration time: " + string(et) + } + ips.IPs = append(ips.IPs, ip) + } + return ips } func createRenderIPSetFromIPEntry(version IPVersion, entry *ipEntry) RenderIPSet { + var ips []string + for ip, _ := range entry.ips { + ips = append(ips, ip) + } return RenderIPSet{ SetName: entry.setName, - IPs: entry.ips, + IPs: ips, Version: version, } } diff --git a/pkg/nftables/networkpolicy.go b/pkg/nftables/networkpolicy.go index cb46d0d7..b29206bb 100644 --- a/pkg/nftables/networkpolicy.go +++ b/pkg/nftables/networkpolicy.go @@ -78,6 +78,7 @@ func clusterwideNetworkPolicyEgressRules( np firewallv1.ClusterwideNetworkPolicy, logAcceptedConnections bool, ) (rules nftablesRules, updated firewallv1.ClusterwideNetworkPolicy) { + var fqdnState firewallv1.FQDNState for _, e := range np.Spec.Egress { tcpPorts, udpPorts := calculatePorts(e.Ports) ruleBases := []ruleBase{} @@ -95,9 +96,9 @@ func clusterwideNetworkPolicyEgressRules( ruleBases = append(ruleBases, ruleBase{base: rb}) } else if len(e.ToFQDNs) > 0 && cache.IsInitialized() { // Generate allow rules based on DNS selectors - rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, e) - np.Status.FQDNState = u + rbs, u := clusterwideNetworkPolicyEgressToFQDNRules(cache, fqdnState, e) ruleBases = append(ruleBases, rbs...) + fqdnState = u } comment := fmt.Sprintf("accept traffic for np %s", np.ObjectMeta.Name) @@ -111,6 +112,7 @@ func clusterwideNetworkPolicyEgressRules( } } + np.Status.FQDNState = fqdnState return uniqueSorted(rules), np } @@ -125,9 +127,12 @@ func clusterwideNetworkPolicyEgressToRules(e firewallv1.EgressRule) (allow, exce func clusterwideNetworkPolicyEgressToFQDNRules( cache FQDNCache, + fqdnState firewallv1.FQDNState, e firewallv1.EgressRule, ) (rules []ruleBase, updatedState firewallv1.FQDNState) { - fqdnState := firewallv1.FQDNState{} + if fqdnState == nil { + fqdnState = firewallv1.FQDNState{} + } for _, fqdn := range e.ToFQDNs { fqdnName := fqdn.MatchName