diff --git a/net/formmailer/formmailer.go b/net/formmailer/formmailer.go new file mode 100644 index 0000000..ea390fd --- /dev/null +++ b/net/formmailer/formmailer.go @@ -0,0 +1,381 @@ +package formmailer + +import ( + "bytes" + "context" + "fmt" + "log" + "net" + "net/http" + "net/mail" + "net/netip" + "net/smtp" + "os" + "regexp" + "strings" + "sync" + "time" + + "golang.org/x/term" + "golang.org/x/time/rate" + + "github.com/phuslu/iploc" +) + +const ( + maxFormSize = 10 * 1024 // 10KB total form limit + maxMessageLength = 4000 + maxCompanyLength = 200 + maxNameLength = 100 + maxEmailLength = 254 + maxPhoneLength = 20 + + requestsPerMinute = 5 + burstSize = 3 +) + +var ErrInvalidEmail = fmt.Errorf("email address doesn't look like an email address") +var ErrInvalidMX = fmt.Errorf("email address isn't deliverable") +var ErrInvalidPhone = fmt.Errorf("phone number is not properly formatted") +var ErrContentTooLong = fmt.Errorf("one or more of the field values was too long") +var ErrInvalidNewlines = fmt.Errorf("invalid use of newlines or returns") + +var ( + phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`) + + // Global per-IP limiter map + limiterMu sync.Mutex + limiters = make(map[string]*rate.Limiter) +) + +type FormMailer struct { + showVersion bool + listenAddr string + smtpHost string + smtpFrom string + smtpToList string + smtpUser string + smtpPass string + smtpSubject string + successFile string + errorFile string + responseType string + Blacklist *ipcohort.Cohort +} + +func Init() { + gitURL := "git@github.com:bitwire-it/ipblocklist.git" + blacklistPath := "/home/app/srv/ipblocklist/inbound.txt" + + cfg := &FormMailer{ + listenAddr: "localhost:3081", + smtpHost: os.Getenv("SMTP_HOST"), + smtpFrom: os.Getenv("SMTP_FROM"), + smtpToList: os.Getenv("SMTP_TO"), + smtpUser: os.Getenv("SMTP_USER"), + smtpPass: "", + smtpSubject: "Website contact request from {.Email}", + successFile: "success-file.html", + errorFile: "error-file.html", + responseType: "text/plain", + Blacklist: nil, + } + + if cfg.smtpHost == "" || cfg.smtpFrom == "" || cfg.smtpToList == "" { + return fmt.Errorf("missing required SMTP settings") + } + + if _, err := os.ReadFile(cfg.successFile); err != nil { + fmt.Fprintf(os.Stderr, "\nError: couldn't read success response file %q: %v\n\n", cfg.successFile, err) + os.Exit(1) + } + if _, err := os.ReadFile(cfg.errorFile); err != nil { + fmt.Fprintf(os.Stderr, "\nError: couldn't read error response file %q: %v\n\n", cfg.errorFile, err) + os.Exit(1) + } + + if cfg.smtpUser == "" { + cfg.smtpUser = cfg.smtpFrom + } + if cfg.smtpFrom == "" { + cfg.smtpFrom = cfg.smtpUser + } + + if pass, hasPass := os.LookupEnv("SMTP_PASS"); !hasPass { + fmt.Fprintf(os.Stderr, "SMTP_PASS not set → ") + pwBytes, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + log.Fatalf("Failed to read password: %v", err) + } + fmt.Fprintln(os.Stderr) + cfg.smtpPass = strings.TrimSpace(string(pwBytes)) + } else { + cfg.smtpPass = pass + } + + if strings.HasSuffix(strings.ToLower(cfg.successFile), ".html") { + cfg.responseType = "text/html" + } else if strings.HasSuffix(strings.ToLower(cfg.successFile), ".json") { + cfg.responseType = "application/json" + } + + cfg.Blacklist = NewBlacklist(gitURL, blacklistPath) + fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") + skipGCOnce := true + if n, err := cfg.Blacklist.Init(skipGCOnce); err != nil { + fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) + } else if n > 0 { + fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) + } + go func() { + cfg.Blacklist.Run(context.TODO()) + }() + + http.HandleFunc("POST /contact", cfg.submitHandler) + http.HandleFunc("POST /contact/", cfg.submitHandler) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintln(w, "form2email server running. POST form data to /contact") + }) + + fmt.Printf("form2email listening on http://%s\n", cfg.listenAddr) + fmt.Printf("Forwarding submissions from %s → %s via %s\n", cfg.smtpFrom, cfg.smtpToList, cfg.smtpHost) + fmt.Printf("Rate limit: ~%d req/min per IP (burst %d)\n", requestsPerMinute, burstSize) + fmt.Println("CTRL+C to stop") + + log.Fatal(http.ListenAndServe(cfg.listenAddr, nil)) +} + +func (cfg *FormMailer) submitHandler(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", cfg.responseType) + + // Parse form early (needed for rate limit decision, but still protected by size limit) + err := r.ParseMultipartForm(maxFormSize) + if err != nil { + http.Error(w, "Form too large or invalid", http.StatusBadRequest) + log.Printf("ParseMultipartForm error: %v", err) + return + } + + // Rate limit FIRST (cheap check) + ipStr := getClientIP(r) + ip, err := netip.ParseAddr(ipStr) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(fmt.Errorf("malformed proxy headers")) + _, _ = w.Write(b) + return + } + + if cfg.Blacklist.Contains(ipStr) { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBotty(fmt.Errorf("bots are not allowed to submit contact requests")) + _, _ = w.Write(b) + return + } + + switch iploc.IPCountry(ip) { + case "", "US", "CA", "MX", "CR", "VI": + // North America, or unknown + default: + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(fmt.Errorf("it appears that you are contacting us from outside of the United States, please email us directly for international inquiries")) + _, _ = w.Write(b) + return + } + + if !validateRateLimit(ipStr) { + http.Error(w, "Rate limit exceeded (try again later)", http.StatusTooManyRequests) + return + } + + stuff := make(map[string]string) + // Extract & trim fields + email := strings.ToLower(strings.TrimSpace(r.FormValue("input_3"))) + stuff["name"] = strings.TrimSpace(r.FormValue("input_1")) + stuff["phone"] = strings.TrimSpace(r.FormValue("input_4")) + stuff["company"] = strings.TrimSpace(r.FormValue("input_5")) + stuff["message"] = strings.TrimSpace(r.FormValue("input_7")) + + // Validation chain + if err := validateLengths(stuff["name"], email, stuff["phone"], stuff["company"], stuff["message"]); err != nil { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(err) + _, _ = w.Write(b) + return + } + + if err := validatePhone(stuff["phone"]); err != nil { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(err) + _, _ = w.Write(b) + return + } + + if err := validateNoHeaderInjection(stuff["name"], email, stuff["company"]); err != nil { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(err) + _, _ = w.Write(b) + return + } + + if err := validateEmailAndMX(email); err != nil { + w.WriteHeader(http.StatusBadRequest) + b := cfg.getErrorBody(err) + _, _ = w.Write(b) + return + } + + // Log submission + n := min(len(stuff["message"]), 100) + log.Printf("Submission from %s | Name=%q Email=%q Phone=%q Company=%q Message=%q", + ipStr, stuff["name"], email, stuff["phone"], stuff["company"], stuff["message"][:n]+"...") + + // TODO blacklist + + if strings.HasSuffix(email, ".ru") { + b, _ := os.ReadFile(cfg.successFile) + _, _ = w.Write(b) + return + } + + // Build email + body := fmt.Sprintf( + "New contact form submission:\n\n"+ + "Name: %s\n"+ + "Email: %s\n"+ + "Phone: %s\n"+ + "Company: %s\n"+ + "Message:\n%s\n", + stuff["name"], email, stuff["phone"], stuff["company"], stuff["message"], + ) + + msg := fmt.Appendf(nil, + "To: %s\r\n"+ + "From: %s\r\n"+ + "Reply-To: %s\r\n"+ + "Subject: %s\r\n"+ + "\r\n"+ + "%s\r\n", + cfg.smtpToList, cfg.smtpFrom, email, strings.ReplaceAll(cfg.smtpSubject, "{.Email}", email), body, + ) + + hostname := strings.Split(cfg.smtpHost, ":")[0] + auth := smtp.PlainAuth("", cfg.smtpUser, cfg.smtpPass, hostname) + + smtpTo := strings.Split(cfg.smtpToList, ",") + err = smtp.SendMail(cfg.smtpHost, auth, cfg.smtpFrom, smtpTo, msg) + if err != nil { + http.Error(w, "Failed to send email", http.StatusInternalServerError) + log.Printf("SMTP error: %v", err) + return + } + + b, _ := os.ReadFile(cfg.successFile) + _, _ = w.Write(b) +} + +func (cfg *FormMailer) getErrorBody(err error) []byte { + b, _ := os.ReadFile(cfg.errorFile) + b = bytes.ReplaceAll(b, []byte("{.Error}"), []byte(err.Error())) + b = bytes.ReplaceAll(b, []byte("{.SupportEmail}"), []byte(cfg.smtpFrom)) + return b +} + +func (cfg *FormMailer) getErrorBotty(err error) []byte { + b, _ := os.ReadFile(cfg.errorFile) + b = bytes.ReplaceAll(b, []byte("{.Error}"), []byte(err.Error())) + b = bytes.ReplaceAll(b, []byte("{.SupportEmail}"), []byte("[REDACTED]")) + return b +} + +// ──────────────────────────────────────────────────────────────────────────────── +// Validation functions +// ──────────────────────────────────────────────────────────────────────────────── + +func validateRateLimit(ipStr string) bool { + limiterMu.Lock() + lim, ok := limiters[ipStr] + if !ok { + lim = rate.NewLimiter(rate.Every(time.Minute/time.Duration(requestsPerMinute)), burstSize) + limiters[ipStr] = lim + } + limiterMu.Unlock() + + if !lim.Allow() { + log.Printf("Rate limited IP: %s", ipStr) + return false + } + return true +} + +func validateLengths(name, email, phone, company, message string) error { + if len(name) > maxNameLength || + len(email) > maxEmailLength || + len(phone) > maxPhoneLength || + len(company) > maxCompanyLength || + len(message) > maxMessageLength { + return ErrContentTooLong + } + + return nil +} + +func validateEmailAndMX(email string) error { + _, err := mail.ParseAddress(email) + if err != nil { + return ErrInvalidEmail + } + + parts := strings.Split(email, "@") + if len(parts) != 2 { + return ErrInvalidEmail + } + domain := parts[1] + + _, err = net.LookupMX(domain) + if err != nil { + return ErrInvalidMX + } + + return nil +} + +func validatePhone(phone string) error { + if phone == "" { + return nil + } + if !phoneRe.MatchString(phone) { + return ErrInvalidPhone + } + + return nil +} + +func validateNoHeaderInjection(name, email, company string) error { + combined := name + email + company + if strings.ContainsAny(combined, "\r\n") { + return ErrInvalidNewlines + } + + return nil +} + +// getClientIP prefers X-Forwarded-For (first value) over RemoteAddr +func getClientIP(r *http.Request) string { + xff := r.Header.Get("X-Forwarded-For") + if xff != "" { + // Take the first (original client) IP in case of multiple proxies + parts := strings.Split(xff, ",") + if len(parts) > 0 { + fmt.Println("Remote IP XFF:", xff) + return strings.TrimSpace(parts[0]) + } + } + // Fallback to RemoteAddr (strip port) + ip := r.RemoteAddr + if idx := strings.LastIndex(ip, ":"); idx > -1 { + ip = ip[:idx] + fmt.Println("Remote IP:", ip) + } + return ip +}