From b23610fdf12ef8cbecbccaa5f80df24d6f123391 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 19:32:40 -0600 Subject: [PATCH] refactor(formmailer): production-readiness + dataset.View compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Blacklist is now a CohortSource interface (Value() *ipcohort.Cohort). *dataset.View[ipcohort.Cohort] satisfies it directly; callers with an atomic.Pointer can wrap. Drops the atomic/sync import from the public API. - SMTP send now uses net.Dialer.DialContext with a bounded SMTPTimeout (default 15s) and conn deadline, so a slow/hung relay no longer holds the request goroutine for WriteTimeout. Opportunistic STARTTLS added. - MX lookup uses net.DefaultResolver.LookupMX with a bounded MXTimeout (default 3s), cancellable via r.Context(). - clientIP uses net.SplitHostPort (was LastIndex(":"), broken for IPv6). - Per-IP limiter map now has a 10-minute TTL with opportunistic sweep every 1024 requests — previously grew unbounded. - Sentinel errors switched to errors.New; fmt.Errorf was unused. --- net/formmailer/formmailer.go | 176 +++++++++++++++++++++++++++-------- 1 file changed, 139 insertions(+), 37 deletions(-) diff --git a/net/formmailer/formmailer.go b/net/formmailer/formmailer.go index 081dc6c..259001e 100644 --- a/net/formmailer/formmailer.go +++ b/net/formmailer/formmailer.go @@ -3,8 +3,9 @@ // // Typical setup: // -// var blacklist atomic.Pointer[ipcohort.Cohort] -// // ... caller loads blacklist and hot-swaps on a timer ... +// // Blacklist can be any snapshot source — e.g. *dataset.View[ipcohort.Cohort] +// // satisfies CohortSource directly via its Value() method. +// blacklist := dataset.Add(set, func() (*ipcohort.Cohort, error) { ... }) // // fm := &formmailer.FormMailer{ // SMTPHost: "smtp.example.com:587", @@ -15,7 +16,7 @@ // Subject: "Contact from {.Email}", // SuccessBody: successHTML, // ErrorBody: errorHTML, -// Blacklist: &blacklist, +// Blacklist: blacklist, // AllowedCountries: []string{"US", "CA", "MX"}, // } // http.Handle("POST /contact", fm) @@ -23,6 +24,8 @@ package formmailer import ( "bytes" + "context" + "errors" "fmt" "log" "net" @@ -34,7 +37,6 @@ import ( "slices" "strings" "sync" - "sync/atomic" "time" "github.com/phuslu/iploc" @@ -51,20 +53,31 @@ const ( maxEmailLength = 254 maxPhoneLength = 20 - defaultRPM = 5 - defaultBurst = 3 + defaultRPM = 5 + defaultBurst = 3 + defaultSMTPTimeout = 15 * time.Second + defaultMXTimeout = 3 * time.Second + + limiterTTL = 10 * time.Minute + limiterSweepEvery = 1024 // sweep once every N handler invocations ) var ( - ErrInvalidEmail = fmt.Errorf("email address doesn't look like an email address") - ErrInvalidMX = fmt.Errorf("email address isn't deliverable") - ErrInvalidPhone = fmt.Errorf("phone number is not properly formatted") - ErrContentTooLong = fmt.Errorf("one or more field values was too long") - ErrInvalidNewlines = fmt.Errorf("invalid use of newlines or carriage returns") + ErrInvalidEmail = errors.New("email address doesn't look like an email address") + ErrInvalidMX = errors.New("email address isn't deliverable") + ErrInvalidPhone = errors.New("phone number is not properly formatted") + ErrContentTooLong = errors.New("one or more field values was too long") + ErrInvalidNewlines = errors.New("invalid use of newlines or carriage returns") phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`) ) +// CohortSource returns the current cohort snapshot, or nil if not yet loaded. +// *dataset.View[ipcohort.Cohort] satisfies this interface directly. +type CohortSource interface { + Value() *ipcohort.Cohort +} + // FormFields maps logical field names to the HTML form field names. // Zero values use GravityForms-compatible defaults (input_1, input_3, etc.). type FormFields struct { @@ -93,6 +106,11 @@ type FormMailer struct { SMTPPass string Subject string // may contain {.Email} + // SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 15s. + SMTPTimeout time.Duration + // MXTimeout bounds the per-submission MX lookup. Zero uses 3s. + MXTimeout time.Duration + // SuccessBody and ErrorBody are the response bodies sent to the client. // ErrorBody may contain {.Error} and {.SupportEmail} placeholders. // Load from files before use: fm.SuccessBody, _ = os.ReadFile("success.html") @@ -101,7 +119,8 @@ type FormMailer struct { ContentType string // inferred from SuccessBody if empty // Blacklist — if set, matching IPs are rejected before any other processing. - Blacklist *atomic.Pointer[ipcohort.Cohort] + // *dataset.View[ipcohort.Cohort] satisfies this interface. + Blacklist CohortSource // AllowedCountries — if non-nil, only requests from listed ISO codes are // accepted. Unknown country ("") is always allowed. @@ -117,11 +136,17 @@ type FormMailer struct { once sync.Once mu sync.Mutex - limiters map[string]*rate.Limiter + limiters map[string]*limiterEntry + reqCount uint64 +} + +type limiterEntry struct { + lim *rate.Limiter + lastUsed time.Time } func (fm *FormMailer) init() { - fm.limiters = make(map[string]*rate.Limiter) + fm.limiters = make(map[string]*limiterEntry) } func (fm *FormMailer) contentType() string { @@ -154,8 +179,8 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) { } // Blocklist check — fail before any other processing. - if bl := fm.Blacklist; bl != nil { - if c := bl.Load(); c != nil && c.ContainsAddr(ip) { + if fm.Blacklist != nil { + if c := fm.Blacklist.Value(); c != nil && c.ContainsAddr(ip) { fm.writeError(w, fmt.Errorf("automated requests are not accepted"), false) return } @@ -194,7 +219,7 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) { fm.writeError(w, err, true) return } - if err := validateEmailAndMX(email); err != nil { + if err := fm.validateEmailAndMX(r.Context(), email); err != nil { fm.writeError(w, err, true) return } @@ -213,16 +238,77 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) { strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body, ) - hostname := strings.Split(fm.SMTPHost, ":")[0] - auth := smtp.PlainAuth("", fm.SMTPUser, fm.SMTPPass, hostname) - if err := smtp.SendMail(fm.SMTPHost, auth, fm.SMTPFrom, fm.SMTPTo, msg); err != nil { + if err := fm.sendMail(r.Context(), msg); err != nil { log.Printf("contact form: smtp error: %v", err) http.Error(w, "failed to send — please try again later", http.StatusInternalServerError) return } w.Header().Set("Content-Type", fm.contentType()) - w.Write(fm.SuccessBody) + _, _ = w.Write(fm.SuccessBody) +} + +// sendMail dials SMTPHost with a bounded timeout and writes the message. +// Uses smtp.NewClient directly so the dial respects ctx; stdlib smtp.SendMail +// has no context plumbing. +func (fm *FormMailer) sendMail(ctx context.Context, msg []byte) error { + timeout := fm.SMTPTimeout + if timeout == 0 { + timeout = defaultSMTPTimeout + } + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + d := net.Dialer{} + conn, err := d.DialContext(ctx, "tcp", fm.SMTPHost) + if err != nil { + return fmt.Errorf("dial: %w", err) + } + if dl, ok := ctx.Deadline(); ok { + _ = conn.SetDeadline(dl) + } + hostname, _, err := net.SplitHostPort(fm.SMTPHost) + if err != nil { + hostname = fm.SMTPHost + } + c, err := smtp.NewClient(conn, hostname) + if err != nil { + _ = conn.Close() + return fmt.Errorf("smtp client: %w", err) + } + defer func() { _ = c.Close() }() + + if ok, _ := c.Extension("STARTTLS"); ok { + if err := c.StartTLS(nil); err != nil { + return fmt.Errorf("starttls: %w", err) + } + } + if fm.SMTPUser != "" { + auth := smtp.PlainAuth("", fm.SMTPUser, fm.SMTPPass, hostname) + if err := c.Auth(auth); err != nil { + return fmt.Errorf("auth: %w", err) + } + } + if err := c.Mail(fm.SMTPFrom); err != nil { + return fmt.Errorf("mail from: %w", err) + } + for _, to := range fm.SMTPTo { + if err := c.Rcpt(to); err != nil { + return fmt.Errorf("rcpt to %s: %w", to, err) + } + } + wc, err := c.Data() + if err != nil { + return fmt.Errorf("data: %w", err) + } + if _, err := wc.Write(msg); err != nil { + _ = wc.Close() + return fmt.Errorf("write: %w", err) + } + if err := wc.Close(); err != nil { + return fmt.Errorf("close data: %w", err) + } + return c.Quit() } func (fm *FormMailer) writeError(w http.ResponseWriter, err error, showSupport bool) { @@ -234,7 +320,7 @@ func (fm *FormMailer) writeError(w http.ResponseWriter, err error, showSupport b } b := bytes.ReplaceAll(fm.ErrorBody, []byte("{.Error}"), []byte(err.Error())) b = bytes.ReplaceAll(b, []byte("{.SupportEmail}"), []byte(support)) - w.Write(b) + _, _ = w.Write(b) } func (fm *FormMailer) allow(ipStr string) bool { @@ -247,12 +333,25 @@ func (fm *FormMailer) allow(ipStr string) bool { burst = defaultBurst } + now := time.Now() fm.mu.Lock() - lim, ok := fm.limiters[ipStr] + e, ok := fm.limiters[ipStr] if !ok { - lim = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), burst) - fm.limiters[ipStr] = lim + e = &limiterEntry{ + lim: rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), burst), + } + fm.limiters[ipStr] = e } + e.lastUsed = now + fm.reqCount++ + if fm.reqCount%limiterSweepEvery == 0 { + for k, v := range fm.limiters { + if now.Sub(v.lastUsed) > limiterTTL { + delete(fm.limiters, k) + } + } + } + lim := e.lim fm.mu.Unlock() if !lim.Allow() { @@ -271,15 +370,21 @@ func validateLengths(name, email, phone, company, message string) error { return nil } -func validateEmailAndMX(email string) error { +func (fm *FormMailer) validateEmailAndMX(ctx context.Context, email string) error { if _, err := mail.ParseAddress(email); err != nil { return ErrInvalidEmail } - parts := strings.Split(email, "@") - if len(parts) != 2 { + _, domain, ok := strings.Cut(email, "@") + if !ok { return ErrInvalidEmail } - if _, err := net.LookupMX(parts[1]); err != nil { + timeout := fm.MXTimeout + if timeout == 0 { + timeout = defaultMXTimeout + } + lookupCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + if _, err := net.DefaultResolver.LookupMX(lookupCtx, domain); err != nil { return ErrInvalidMX } return nil @@ -307,14 +412,11 @@ func validateNoHeaderInjection(fields ...string) error { // clientIP returns the originating IP, preferring X-Forwarded-For. func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 { - return strings.TrimSpace(parts[0]) - } + first, _, _ := strings.Cut(xff, ",") + return strings.TrimSpace(first) } - ip := r.RemoteAddr - if idx := strings.LastIndex(ip, ":"); idx > 0 { - return ip[:idx] + if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + return host } - return ip + return r.RemoteAddr } -