From b8cf135c4570ee6a693b5a1991a0ace198c25fd4 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 25 Aug 2020 01:29:10 -0600 Subject: [PATCH] add check from ip list --- .gitignore | 1 + cmd/iplist/iplist.go | 46 +++++++++++++++ iplist/iplist.go | 137 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 184 insertions(+) create mode 100644 cmd/iplist/iplist.go create mode 100644 iplist/iplist.go diff --git a/.gitignore b/.gitignore index 8e22740..a5d4630 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ telebit-client-macos telebit-client-windows-debug.exe telebit-client-windows.exe +/cmd/iplist/iplist /cmd/machineid/machineid /cmd/dnsclient/dnsclient /cmd/sqlstore/sqlstore diff --git a/cmd/iplist/iplist.go b/cmd/iplist/iplist.go new file mode 100644 index 0000000..3340088 --- /dev/null +++ b/cmd/iplist/iplist.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "net" + "os" + + "git.rootprojects.org/root/telebit/iplist" +) + +func help() { + fmt.Fprintf(os.Stderr, "Usage: iplist domain.tld 123.45.6.78\n") + fmt.Fprintf(os.Stderr, "(`dig TXT +short domain.tld` should return a list like `v=spf1 ip4:123.45.6.78 ip4:123.45.6.1/24`\n") + os.Exit(1) +} + +func main() { + if 3 != len(os.Args) { + help() + return + } + + txtDomain := os.Args[1] + remoteIP := net.ParseIP(os.Args[2]) + if nil == remoteIP { + fmt.Fprintf(os.Stderr, "bad remote IP\n") + os.Exit(1) + return + } + + iplist.Init(txtDomain) + + allowed, err := iplist.IsAllowed(remoteIP) + if nil != err { + fmt.Fprintf(os.Stderr, err.Error()) + os.Exit(1) + return + } + if !allowed { + fmt.Fprintf(os.Stderr, "not allowed\n") + os.Exit(1) + return + } + + fmt.Println("allowed") +} diff --git a/iplist/iplist.go b/iplist/iplist.go new file mode 100644 index 0000000..c1d4a68 --- /dev/null +++ b/iplist/iplist.go @@ -0,0 +1,137 @@ +package iplist + +import ( + "errors" + "fmt" + "net" + "os" + "strings" + "time" +) + +var fields []string +var initialized bool + +// Init should be called with domain that has valid SPF-like records +// to populate the IP whitelist, or with an empty string "" to disable +func Init(txtDomain string) []string { + initialized = true + if "" == txtDomain { + return []string{} + } + + err := updateTxt(txtDomain) + if nil != err { + panic(err) + } + go func() { + for { + time.Sleep(5 * time.Minute) + if err := updateTxt(txtDomain); nil != err { + fmt.Fprintf(os.Stderr, "warn: could not update iplist: %s\n", err) + continue + } + } + }() + + for _, section := range fields { + parts := strings.Split(section, ":") + if 2 != len(parts) || !strings.HasPrefix(parts[0], "ip") { + // ignore unsupported bits + // (i.e. +mx +ip include:xxx) + continue + } + ip := parts[1] + + if strings.Contains(ip, "/") { + _, _, err := net.ParseCIDR(ip) + if nil != err { + panic(fmt.Errorf("invalid CIDR %q", ip)) + } + continue + } + + ipAddr := net.ParseIP(ip) + if nil == ipAddr { + panic(fmt.Errorf( + "IP %q from SPF record could not be parsed", + ipAddr.String(), + )) + } + } + return fields +} + +func updateTxt(txtDomain string) error { + var newFields []string + records, err := net.LookupTXT(txtDomain) + if nil != err { + return fmt.Errorf("bad spf-domain: %s", err) + } + for _, record := range records { + newFields, err = parseSpf(record) + if nil != err { + continue + } + if len(fields) > 0 { + break + } + } + + // TODO put a lock here? + fields = newFields + return nil +} + +// IsAllowed returns true if the given IP matches an IP or CIDR in +// the whitelist, or if the spf-domain is an empty string explicitly +func IsAllowed(remoteIP net.IP) (bool, error) { + if !initialized { + panic(fmt.Errorf("was not initialized")) + } + if 0 == len(fields) { + return true, nil + } + + for _, section := range fields { + parts := strings.Split(section, ":") + if 2 != len(parts) || !strings.HasPrefix(parts[0], "ip") { + // ignore unsupported bits + // (i.e. +mx +ip include:xxx) + continue + } + ip := parts[1] + + if strings.Contains(ip, "/") { + _, ipNet, err := net.ParseCIDR(ip) + if nil != err { + return false, fmt.Errorf("invalid CIDR %q", ip) + } + return ipNet.Contains(remoteIP), nil + } + + ipAddr := net.ParseIP(ip) + if nil == ipAddr { + return false, fmt.Errorf( + "IP %q from SPF record could not be parsed", + ipAddr.String(), + ) + } + if remoteIP.Equal(ipAddr) { + return true, nil + } + } + return false, nil +} + +func parseSpf(spf1 string) ([]string, error) { + fields := strings.Fields(spf1) + if len(fields) < 1 || + len(fields[0]) < 1 || + !strings.HasPrefix(fields[0], "v=") { + return nil, errors.New("missing v=") + } + fields = fields[1:] + + return fields, nil +}