mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
refactor(formmailer): production-readiness + dataset.View compatibility
- Blacklist is now a CohortSource interface (Value() *ipcohort.Cohort).
*dataset.View[ipcohort.Cohort] satisfies it directly; callers with
an atomic.Pointer can wrap. Drops the atomic/sync import from the
public API.
- SMTP send now uses net.Dialer.DialContext with a bounded SMTPTimeout
(default 15s) and conn deadline, so a slow/hung relay no longer holds
the request goroutine for WriteTimeout. Opportunistic STARTTLS added.
- MX lookup uses net.DefaultResolver.LookupMX with a bounded MXTimeout
(default 3s), cancellable via r.Context().
- clientIP uses net.SplitHostPort (was LastIndex(":"), broken for IPv6).
- Per-IP limiter map now has a 10-minute TTL with opportunistic sweep
every 1024 requests — previously grew unbounded.
- Sentinel errors switched to errors.New; fmt.Errorf was unused.
This commit is contained in:
parent
b40abe0a06
commit
b23610fdf1
@ -3,8 +3,9 @@
|
|||||||
//
|
//
|
||||||
// Typical setup:
|
// Typical setup:
|
||||||
//
|
//
|
||||||
// var blacklist atomic.Pointer[ipcohort.Cohort]
|
// // Blacklist can be any snapshot source — e.g. *dataset.View[ipcohort.Cohort]
|
||||||
// // ... caller loads blacklist and hot-swaps on a timer ...
|
// // satisfies CohortSource directly via its Value() method.
|
||||||
|
// blacklist := dataset.Add(set, func() (*ipcohort.Cohort, error) { ... })
|
||||||
//
|
//
|
||||||
// fm := &formmailer.FormMailer{
|
// fm := &formmailer.FormMailer{
|
||||||
// SMTPHost: "smtp.example.com:587",
|
// SMTPHost: "smtp.example.com:587",
|
||||||
@ -15,7 +16,7 @@
|
|||||||
// Subject: "Contact from {.Email}",
|
// Subject: "Contact from {.Email}",
|
||||||
// SuccessBody: successHTML,
|
// SuccessBody: successHTML,
|
||||||
// ErrorBody: errorHTML,
|
// ErrorBody: errorHTML,
|
||||||
// Blacklist: &blacklist,
|
// Blacklist: blacklist,
|
||||||
// AllowedCountries: []string{"US", "CA", "MX"},
|
// AllowedCountries: []string{"US", "CA", "MX"},
|
||||||
// }
|
// }
|
||||||
// http.Handle("POST /contact", fm)
|
// http.Handle("POST /contact", fm)
|
||||||
@ -23,6 +24,8 @@ package formmailer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -34,7 +37,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/phuslu/iploc"
|
"github.com/phuslu/iploc"
|
||||||
@ -53,18 +55,29 @@ const (
|
|||||||
|
|
||||||
defaultRPM = 5
|
defaultRPM = 5
|
||||||
defaultBurst = 3
|
defaultBurst = 3
|
||||||
|
defaultSMTPTimeout = 15 * time.Second
|
||||||
|
defaultMXTimeout = 3 * time.Second
|
||||||
|
|
||||||
|
limiterTTL = 10 * time.Minute
|
||||||
|
limiterSweepEvery = 1024 // sweep once every N handler invocations
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrInvalidEmail = fmt.Errorf("email address doesn't look like an email address")
|
ErrInvalidEmail = errors.New("email address doesn't look like an email address")
|
||||||
ErrInvalidMX = fmt.Errorf("email address isn't deliverable")
|
ErrInvalidMX = errors.New("email address isn't deliverable")
|
||||||
ErrInvalidPhone = fmt.Errorf("phone number is not properly formatted")
|
ErrInvalidPhone = errors.New("phone number is not properly formatted")
|
||||||
ErrContentTooLong = fmt.Errorf("one or more field values was too long")
|
ErrContentTooLong = errors.New("one or more field values was too long")
|
||||||
ErrInvalidNewlines = fmt.Errorf("invalid use of newlines or carriage returns")
|
ErrInvalidNewlines = errors.New("invalid use of newlines or carriage returns")
|
||||||
|
|
||||||
phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`)
|
phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CohortSource returns the current cohort snapshot, or nil if not yet loaded.
|
||||||
|
// *dataset.View[ipcohort.Cohort] satisfies this interface directly.
|
||||||
|
type CohortSource interface {
|
||||||
|
Value() *ipcohort.Cohort
|
||||||
|
}
|
||||||
|
|
||||||
// FormFields maps logical field names to the HTML form field names.
|
// FormFields maps logical field names to the HTML form field names.
|
||||||
// Zero values use GravityForms-compatible defaults (input_1, input_3, etc.).
|
// Zero values use GravityForms-compatible defaults (input_1, input_3, etc.).
|
||||||
type FormFields struct {
|
type FormFields struct {
|
||||||
@ -93,6 +106,11 @@ type FormMailer struct {
|
|||||||
SMTPPass string
|
SMTPPass string
|
||||||
Subject string // may contain {.Email}
|
Subject string // may contain {.Email}
|
||||||
|
|
||||||
|
// SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 15s.
|
||||||
|
SMTPTimeout time.Duration
|
||||||
|
// MXTimeout bounds the per-submission MX lookup. Zero uses 3s.
|
||||||
|
MXTimeout time.Duration
|
||||||
|
|
||||||
// SuccessBody and ErrorBody are the response bodies sent to the client.
|
// SuccessBody and ErrorBody are the response bodies sent to the client.
|
||||||
// ErrorBody may contain {.Error} and {.SupportEmail} placeholders.
|
// ErrorBody may contain {.Error} and {.SupportEmail} placeholders.
|
||||||
// Load from files before use: fm.SuccessBody, _ = os.ReadFile("success.html")
|
// Load from files before use: fm.SuccessBody, _ = os.ReadFile("success.html")
|
||||||
@ -101,7 +119,8 @@ type FormMailer struct {
|
|||||||
ContentType string // inferred from SuccessBody if empty
|
ContentType string // inferred from SuccessBody if empty
|
||||||
|
|
||||||
// Blacklist — if set, matching IPs are rejected before any other processing.
|
// Blacklist — if set, matching IPs are rejected before any other processing.
|
||||||
Blacklist *atomic.Pointer[ipcohort.Cohort]
|
// *dataset.View[ipcohort.Cohort] satisfies this interface.
|
||||||
|
Blacklist CohortSource
|
||||||
|
|
||||||
// AllowedCountries — if non-nil, only requests from listed ISO codes are
|
// AllowedCountries — if non-nil, only requests from listed ISO codes are
|
||||||
// accepted. Unknown country ("") is always allowed.
|
// accepted. Unknown country ("") is always allowed.
|
||||||
@ -117,11 +136,17 @@ type FormMailer struct {
|
|||||||
|
|
||||||
once sync.Once
|
once sync.Once
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
limiters map[string]*rate.Limiter
|
limiters map[string]*limiterEntry
|
||||||
|
reqCount uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
type limiterEntry struct {
|
||||||
|
lim *rate.Limiter
|
||||||
|
lastUsed time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FormMailer) init() {
|
func (fm *FormMailer) init() {
|
||||||
fm.limiters = make(map[string]*rate.Limiter)
|
fm.limiters = make(map[string]*limiterEntry)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FormMailer) contentType() string {
|
func (fm *FormMailer) contentType() string {
|
||||||
@ -154,8 +179,8 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Blocklist check — fail before any other processing.
|
// Blocklist check — fail before any other processing.
|
||||||
if bl := fm.Blacklist; bl != nil {
|
if fm.Blacklist != nil {
|
||||||
if c := bl.Load(); c != nil && c.ContainsAddr(ip) {
|
if c := fm.Blacklist.Value(); c != nil && c.ContainsAddr(ip) {
|
||||||
fm.writeError(w, fmt.Errorf("automated requests are not accepted"), false)
|
fm.writeError(w, fmt.Errorf("automated requests are not accepted"), false)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -194,7 +219,7 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
fm.writeError(w, err, true)
|
fm.writeError(w, err, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := validateEmailAndMX(email); err != nil {
|
if err := fm.validateEmailAndMX(r.Context(), email); err != nil {
|
||||||
fm.writeError(w, err, true)
|
fm.writeError(w, err, true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -213,16 +238,77 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||||||
strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body,
|
strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body,
|
||||||
)
|
)
|
||||||
|
|
||||||
hostname := strings.Split(fm.SMTPHost, ":")[0]
|
if err := fm.sendMail(r.Context(), msg); err != nil {
|
||||||
auth := smtp.PlainAuth("", fm.SMTPUser, fm.SMTPPass, hostname)
|
|
||||||
if err := smtp.SendMail(fm.SMTPHost, auth, fm.SMTPFrom, fm.SMTPTo, msg); err != nil {
|
|
||||||
log.Printf("contact form: smtp error: %v", err)
|
log.Printf("contact form: smtp error: %v", err)
|
||||||
http.Error(w, "failed to send — please try again later", http.StatusInternalServerError)
|
http.Error(w, "failed to send — please try again later", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Header().Set("Content-Type", fm.contentType())
|
w.Header().Set("Content-Type", fm.contentType())
|
||||||
w.Write(fm.SuccessBody)
|
_, _ = w.Write(fm.SuccessBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendMail dials SMTPHost with a bounded timeout and writes the message.
|
||||||
|
// Uses smtp.NewClient directly so the dial respects ctx; stdlib smtp.SendMail
|
||||||
|
// has no context plumbing.
|
||||||
|
func (fm *FormMailer) sendMail(ctx context.Context, msg []byte) error {
|
||||||
|
timeout := fm.SMTPTimeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = defaultSMTPTimeout
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
d := net.Dialer{}
|
||||||
|
conn, err := d.DialContext(ctx, "tcp", fm.SMTPHost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dial: %w", err)
|
||||||
|
}
|
||||||
|
if dl, ok := ctx.Deadline(); ok {
|
||||||
|
_ = conn.SetDeadline(dl)
|
||||||
|
}
|
||||||
|
hostname, _, err := net.SplitHostPort(fm.SMTPHost)
|
||||||
|
if err != nil {
|
||||||
|
hostname = fm.SMTPHost
|
||||||
|
}
|
||||||
|
c, err := smtp.NewClient(conn, hostname)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return fmt.Errorf("smtp client: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = c.Close() }()
|
||||||
|
|
||||||
|
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||||
|
if err := c.StartTLS(nil); err != nil {
|
||||||
|
return fmt.Errorf("starttls: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if fm.SMTPUser != "" {
|
||||||
|
auth := smtp.PlainAuth("", fm.SMTPUser, fm.SMTPPass, hostname)
|
||||||
|
if err := c.Auth(auth); err != nil {
|
||||||
|
return fmt.Errorf("auth: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := c.Mail(fm.SMTPFrom); err != nil {
|
||||||
|
return fmt.Errorf("mail from: %w", err)
|
||||||
|
}
|
||||||
|
for _, to := range fm.SMTPTo {
|
||||||
|
if err := c.Rcpt(to); err != nil {
|
||||||
|
return fmt.Errorf("rcpt to %s: %w", to, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
wc, err := c.Data()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("data: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := wc.Write(msg); err != nil {
|
||||||
|
_ = wc.Close()
|
||||||
|
return fmt.Errorf("write: %w", err)
|
||||||
|
}
|
||||||
|
if err := wc.Close(); err != nil {
|
||||||
|
return fmt.Errorf("close data: %w", err)
|
||||||
|
}
|
||||||
|
return c.Quit()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FormMailer) writeError(w http.ResponseWriter, err error, showSupport bool) {
|
func (fm *FormMailer) writeError(w http.ResponseWriter, err error, showSupport bool) {
|
||||||
@ -234,7 +320,7 @@ func (fm *FormMailer) writeError(w http.ResponseWriter, err error, showSupport b
|
|||||||
}
|
}
|
||||||
b := bytes.ReplaceAll(fm.ErrorBody, []byte("{.Error}"), []byte(err.Error()))
|
b := bytes.ReplaceAll(fm.ErrorBody, []byte("{.Error}"), []byte(err.Error()))
|
||||||
b = bytes.ReplaceAll(b, []byte("{.SupportEmail}"), []byte(support))
|
b = bytes.ReplaceAll(b, []byte("{.SupportEmail}"), []byte(support))
|
||||||
w.Write(b)
|
_, _ = w.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fm *FormMailer) allow(ipStr string) bool {
|
func (fm *FormMailer) allow(ipStr string) bool {
|
||||||
@ -247,12 +333,25 @@ func (fm *FormMailer) allow(ipStr string) bool {
|
|||||||
burst = defaultBurst
|
burst = defaultBurst
|
||||||
}
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
fm.mu.Lock()
|
fm.mu.Lock()
|
||||||
lim, ok := fm.limiters[ipStr]
|
e, ok := fm.limiters[ipStr]
|
||||||
if !ok {
|
if !ok {
|
||||||
lim = rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), burst)
|
e = &limiterEntry{
|
||||||
fm.limiters[ipStr] = lim
|
lim: rate.NewLimiter(rate.Every(time.Minute/time.Duration(rpm)), burst),
|
||||||
}
|
}
|
||||||
|
fm.limiters[ipStr] = e
|
||||||
|
}
|
||||||
|
e.lastUsed = now
|
||||||
|
fm.reqCount++
|
||||||
|
if fm.reqCount%limiterSweepEvery == 0 {
|
||||||
|
for k, v := range fm.limiters {
|
||||||
|
if now.Sub(v.lastUsed) > limiterTTL {
|
||||||
|
delete(fm.limiters, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lim := e.lim
|
||||||
fm.mu.Unlock()
|
fm.mu.Unlock()
|
||||||
|
|
||||||
if !lim.Allow() {
|
if !lim.Allow() {
|
||||||
@ -271,15 +370,21 @@ func validateLengths(name, email, phone, company, message string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateEmailAndMX(email string) error {
|
func (fm *FormMailer) validateEmailAndMX(ctx context.Context, email string) error {
|
||||||
if _, err := mail.ParseAddress(email); err != nil {
|
if _, err := mail.ParseAddress(email); err != nil {
|
||||||
return ErrInvalidEmail
|
return ErrInvalidEmail
|
||||||
}
|
}
|
||||||
parts := strings.Split(email, "@")
|
_, domain, ok := strings.Cut(email, "@")
|
||||||
if len(parts) != 2 {
|
if !ok {
|
||||||
return ErrInvalidEmail
|
return ErrInvalidEmail
|
||||||
}
|
}
|
||||||
if _, err := net.LookupMX(parts[1]); err != nil {
|
timeout := fm.MXTimeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = defaultMXTimeout
|
||||||
|
}
|
||||||
|
lookupCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||||
|
defer cancel()
|
||||||
|
if _, err := net.DefaultResolver.LookupMX(lookupCtx, domain); err != nil {
|
||||||
return ErrInvalidMX
|
return ErrInvalidMX
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -307,14 +412,11 @@ func validateNoHeaderInjection(fields ...string) error {
|
|||||||
// clientIP returns the originating IP, preferring X-Forwarded-For.
|
// clientIP returns the originating IP, preferring X-Forwarded-For.
|
||||||
func clientIP(r *http.Request) string {
|
func clientIP(r *http.Request) string {
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
if parts := strings.SplitN(xff, ",", 2); len(parts) > 0 {
|
first, _, _ := strings.Cut(xff, ",")
|
||||||
return strings.TrimSpace(parts[0])
|
return strings.TrimSpace(first)
|
||||||
}
|
}
|
||||||
|
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
return host
|
||||||
}
|
}
|
||||||
ip := r.RemoteAddr
|
return r.RemoteAddr
|
||||||
if idx := strings.LastIndex(ip, ":"); idx > 0 {
|
|
||||||
return ip[:idx]
|
|
||||||
}
|
}
|
||||||
return ip
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user