add check from ip list

This commit is contained in:
AJ ONeal 2020-08-25 01:29:10 -06:00
parent e7db7b9bba
commit b8cf135c45
3 changed files with 184 additions and 0 deletions

1
.gitignore vendored
View File

@ -16,6 +16,7 @@ telebit-client-macos
telebit-client-windows-debug.exe
telebit-client-windows.exe
/cmd/iplist/iplist
/cmd/machineid/machineid
/cmd/dnsclient/dnsclient
/cmd/sqlstore/sqlstore

46
cmd/iplist/iplist.go Normal file
View File

@ -0,0 +1,46 @@
package main
import (
"fmt"
"net"
"os"
"git.rootprojects.org/root/telebit/iplist"
)
func help() {
fmt.Fprintf(os.Stderr, "Usage: iplist domain.tld 123.45.6.78\n")
fmt.Fprintf(os.Stderr, "(`dig TXT +short domain.tld` should return a list like `v=spf1 ip4:123.45.6.78 ip4:123.45.6.1/24`\n")
os.Exit(1)
}
func main() {
if 3 != len(os.Args) {
help()
return
}
txtDomain := os.Args[1]
remoteIP := net.ParseIP(os.Args[2])
if nil == remoteIP {
fmt.Fprintf(os.Stderr, "bad remote IP\n")
os.Exit(1)
return
}
iplist.Init(txtDomain)
allowed, err := iplist.IsAllowed(remoteIP)
if nil != err {
fmt.Fprintf(os.Stderr, err.Error())
os.Exit(1)
return
}
if !allowed {
fmt.Fprintf(os.Stderr, "not allowed\n")
os.Exit(1)
return
}
fmt.Println("allowed")
}

137
iplist/iplist.go Normal file
View File

@ -0,0 +1,137 @@
package iplist
import (
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
var fields []string
var initialized bool
// Init should be called with domain that has valid SPF-like records
// to populate the IP whitelist, or with an empty string "" to disable
func Init(txtDomain string) []string {
initialized = true
if "" == txtDomain {
return []string{}
}
err := updateTxt(txtDomain)
if nil != err {
panic(err)
}
go func() {
for {
time.Sleep(5 * time.Minute)
if err := updateTxt(txtDomain); nil != err {
fmt.Fprintf(os.Stderr, "warn: could not update iplist: %s\n", err)
continue
}
}
}()
for _, section := range fields {
parts := strings.Split(section, ":")
if 2 != len(parts) || !strings.HasPrefix(parts[0], "ip") {
// ignore unsupported bits
// (i.e. +mx +ip include:xxx)
continue
}
ip := parts[1]
if strings.Contains(ip, "/") {
_, _, err := net.ParseCIDR(ip)
if nil != err {
panic(fmt.Errorf("invalid CIDR %q", ip))
}
continue
}
ipAddr := net.ParseIP(ip)
if nil == ipAddr {
panic(fmt.Errorf(
"IP %q from SPF record could not be parsed",
ipAddr.String(),
))
}
}
return fields
}
func updateTxt(txtDomain string) error {
var newFields []string
records, err := net.LookupTXT(txtDomain)
if nil != err {
return fmt.Errorf("bad spf-domain: %s", err)
}
for _, record := range records {
newFields, err = parseSpf(record)
if nil != err {
continue
}
if len(fields) > 0 {
break
}
}
// TODO put a lock here?
fields = newFields
return nil
}
// IsAllowed returns true if the given IP matches an IP or CIDR in
// the whitelist, or if the spf-domain is an empty string explicitly
func IsAllowed(remoteIP net.IP) (bool, error) {
if !initialized {
panic(fmt.Errorf("was not initialized"))
}
if 0 == len(fields) {
return true, nil
}
for _, section := range fields {
parts := strings.Split(section, ":")
if 2 != len(parts) || !strings.HasPrefix(parts[0], "ip") {
// ignore unsupported bits
// (i.e. +mx +ip include:xxx)
continue
}
ip := parts[1]
if strings.Contains(ip, "/") {
_, ipNet, err := net.ParseCIDR(ip)
if nil != err {
return false, fmt.Errorf("invalid CIDR %q", ip)
}
return ipNet.Contains(remoteIP), nil
}
ipAddr := net.ParseIP(ip)
if nil == ipAddr {
return false, fmt.Errorf(
"IP %q from SPF record could not be parsed",
ipAddr.String(),
)
}
if remoteIP.Equal(ipAddr) {
return true, nil
}
}
return false, nil
}
func parseSpf(spf1 string) ([]string, error) {
fields := strings.Fields(spf1)
if len(fields) < 1 ||
len(fields[0]) < 1 ||
!strings.HasPrefix(fields[0], "v=") {
return nil, errors.New("missing v=")
}
fields = fields[1:]
return fields, nil
}