mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
feat(formmailer)!: replace FormFields struct with ordered []Field
Form inputs are now declared as an ordered slice with Kind-driven
validation (KindText, KindEmail, KindPhone, KindMessage). Arbitrary
input names are fine — callers pick the Label shown in the email
body and the FormName of the HTML input. Per-field MaxLen and
Required overrides supported; defaults come from Kind.
Exactly one KindEmail entry is required (used for Reply-To, Subject
{.Email} substitution, and the MX check); misconfiguration is
detected at first request and returns 500.
Email body, log line, and validation now iterate Fields in order, so
the email preserves the form's declared layout.
BREAKING: FormMailer.Fields is now []Field, not FormFields struct.
Callers must migrate to the slice form.
This commit is contained in:
parent
f972d6f117
commit
b77872623a
@ -1,23 +1,33 @@
|
||||
// Package formmailer provides an HTTP handler that validates, rate-limits,
|
||||
// and emails contact form submissions.
|
||||
//
|
||||
// Fields are declared as an ordered slice; each Field names the HTML input,
|
||||
// the label for the email body, and the validation Kind. Exactly one Kind
|
||||
// must be KindEmail — its value is used for Reply-To, Subject substitution,
|
||||
// and the MX check.
|
||||
//
|
||||
// Typical setup:
|
||||
//
|
||||
// // Blacklist can be any snapshot source — e.g. *dataset.View[ipcohort.Cohort]
|
||||
// // satisfies CohortSource directly via its Value() method.
|
||||
// blacklist := dataset.Add(set, func() (*ipcohort.Cohort, error) { ... })
|
||||
//
|
||||
// fm := &formmailer.FormMailer{
|
||||
// SMTPHost: "smtp.example.com:587",
|
||||
// SMTPFrom: "noreply@example.com",
|
||||
// SMTPTo: []string{"contact@example.com"},
|
||||
// SMTPUser: "noreply@example.com",
|
||||
// SMTPPass: os.Getenv("SMTP_PASS"),
|
||||
// Subject: "Contact from {.Email}",
|
||||
// SMTPHost: "smtp.example.com:587",
|
||||
// SMTPFrom: "noreply@example.com",
|
||||
// SMTPTo: []string{"contact@example.com"},
|
||||
// SMTPUser: "noreply@example.com",
|
||||
// SMTPPass: os.Getenv("SMTP_PASS"),
|
||||
// Subject: "Contact from {.Email}",
|
||||
// Fields: []formmailer.Field{
|
||||
// {Label: "Name", FormName: "input_1", Kind: formmailer.KindText},
|
||||
// {Label: "Email", FormName: "input_3", Kind: formmailer.KindEmail},
|
||||
// {Label: "Phone", FormName: "input_4", Kind: formmailer.KindPhone},
|
||||
// {Label: "Company", FormName: "input_5", Kind: formmailer.KindText},
|
||||
// {Label: "Budget", FormName: "input_8", Kind: formmailer.KindText},
|
||||
// {Label: "Message", FormName: "input_7", Kind: formmailer.KindMessage},
|
||||
// },
|
||||
// SuccessBody: successHTML,
|
||||
// ErrorBody: errorHTML,
|
||||
// Blacklist: blacklist,
|
||||
// AllowedCountries: []string{"US", "CA", "MX"},
|
||||
// }
|
||||
// http.Handle("POST /contact", fm)
|
||||
package formmailer
|
||||
@ -47,12 +57,13 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
maxFormSize = 10 * 1024
|
||||
maxMessageLength = 4000
|
||||
maxCompanyLength = 200
|
||||
maxNameLength = 100
|
||||
maxFormSize = 10 * 1024
|
||||
|
||||
// Default per-Kind length caps; override with Field.MaxLen.
|
||||
maxEmailLength = 254
|
||||
maxPhoneLength = 20
|
||||
maxTextLength = 200
|
||||
maxMessageLength = 4000
|
||||
|
||||
defaultRPM = 5
|
||||
defaultBurst = 3
|
||||
@ -69,26 +80,45 @@ var (
|
||||
ErrInvalidPhone = errors.New("phone number is not properly formatted")
|
||||
ErrContentTooLong = errors.New("one or more field values was too long")
|
||||
ErrInvalidNewlines = errors.New("invalid use of newlines or carriage returns")
|
||||
ErrMissingRequired = errors.New("required field was empty")
|
||||
ErrNoEmailField = errors.New("FormMailer.Fields must contain exactly one KindEmail field")
|
||||
|
||||
phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`)
|
||||
)
|
||||
|
||||
// FormFields maps logical field names to the HTML form field names.
|
||||
// Zero values use GravityForms-compatible defaults (input_1, input_3, etc.).
|
||||
type FormFields struct {
|
||||
Name string // default "input_1"
|
||||
Email string // default "input_3"
|
||||
Phone string // default "input_4"
|
||||
Company string // default "input_5"
|
||||
Message string // default "input_7"
|
||||
// FieldKind picks validation rules and default length cap for a Field.
|
||||
type FieldKind int
|
||||
|
||||
const (
|
||||
KindText FieldKind = iota // default; length-capped text
|
||||
KindEmail // RFC 5321 parse + MX lookup
|
||||
KindPhone // phoneRe match
|
||||
KindMessage // long free text (body of the submission)
|
||||
)
|
||||
|
||||
// Field declares one form input. Order is preserved in the email body.
|
||||
type Field struct {
|
||||
Label string // shown in email body, e.g. "Name"
|
||||
FormName string // HTML form field name, e.g. "input_1"
|
||||
Kind FieldKind // validation rules + default MaxLen
|
||||
MaxLen int // 0 = default for Kind
|
||||
Required bool // if true, empty value is rejected
|
||||
}
|
||||
|
||||
func (f FormFields) get(r *http.Request, field, def string) string {
|
||||
key := field
|
||||
if key == "" {
|
||||
key = def
|
||||
func (f Field) maxLen() int {
|
||||
if f.MaxLen > 0 {
|
||||
return f.MaxLen
|
||||
}
|
||||
switch f.Kind {
|
||||
case KindEmail:
|
||||
return maxEmailLength
|
||||
case KindPhone:
|
||||
return maxPhoneLength
|
||||
case KindMessage:
|
||||
return maxMessageLength
|
||||
default:
|
||||
return maxTextLength
|
||||
}
|
||||
return strings.TrimSpace(r.FormValue(key))
|
||||
}
|
||||
|
||||
// FormMailer is an http.Handler that validates and emails contact form submissions.
|
||||
@ -101,35 +131,35 @@ type FormMailer struct {
|
||||
SMTPPass string
|
||||
Subject string // may contain {.Email}
|
||||
|
||||
// SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 15s.
|
||||
// SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 5s.
|
||||
SMTPTimeout time.Duration
|
||||
// MXTimeout bounds the per-submission MX lookup. Zero uses 3s.
|
||||
// MXTimeout bounds the per-submission MX lookup. Zero uses 2s.
|
||||
MXTimeout time.Duration
|
||||
|
||||
// SuccessBody and ErrorBody are the response bodies sent to the client.
|
||||
// ErrorBody may contain {.Error} and {.SupportEmail} placeholders.
|
||||
// Load from files before use: fm.SuccessBody, _ = os.ReadFile("success.html")
|
||||
SuccessBody []byte
|
||||
ErrorBody []byte
|
||||
ContentType string // inferred from SuccessBody if empty
|
||||
|
||||
// Blacklist — if set, matching IPs are rejected before any other processing.
|
||||
// Value() returns nil before the first successful load (no blocks applied).
|
||||
Blacklist *dataset.View[ipcohort.Cohort]
|
||||
|
||||
// AllowedCountries — if non-nil, only requests from listed ISO codes are
|
||||
// accepted. Unknown country ("") is always allowed.
|
||||
// Example: []string{"US", "CA", "MX"}
|
||||
AllowedCountries []string
|
||||
|
||||
// Fields maps logical names to HTML form field names.
|
||||
Fields FormFields
|
||||
// Fields declares the form inputs in display order. Exactly one entry
|
||||
// must have Kind == KindEmail.
|
||||
Fields []Field
|
||||
|
||||
// RPM and Burst control per-IP rate limiting. Zero uses defaults (5/3).
|
||||
RPM int
|
||||
Burst int
|
||||
|
||||
once sync.Once
|
||||
initErr error
|
||||
emailIdx int // index into Fields of the KindEmail entry
|
||||
mu sync.Mutex
|
||||
limiters map[string]*limiterEntry
|
||||
reqCount uint64
|
||||
@ -142,13 +172,25 @@ type limiterEntry struct {
|
||||
|
||||
func (fm *FormMailer) init() {
|
||||
fm.limiters = make(map[string]*limiterEntry)
|
||||
fm.emailIdx = -1
|
||||
for i, f := range fm.Fields {
|
||||
if f.Kind == KindEmail {
|
||||
if fm.emailIdx >= 0 {
|
||||
fm.initErr = ErrNoEmailField
|
||||
return
|
||||
}
|
||||
fm.emailIdx = i
|
||||
}
|
||||
}
|
||||
if fm.emailIdx < 0 {
|
||||
fm.initErr = ErrNoEmailField
|
||||
}
|
||||
}
|
||||
|
||||
func (fm *FormMailer) contentType() string {
|
||||
if fm.ContentType != "" {
|
||||
return fm.ContentType
|
||||
}
|
||||
// Infer from SuccessBody sniff or leave as plain text.
|
||||
if bytes.Contains(fm.SuccessBody[:min(512, len(fm.SuccessBody))], []byte("<html")) {
|
||||
return "text/html; charset=utf-8"
|
||||
}
|
||||
@ -160,6 +202,11 @@ func (fm *FormMailer) contentType() string {
|
||||
|
||||
func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
fm.once.Do(fm.init)
|
||||
if fm.initErr != nil {
|
||||
log.Printf("contact form: misconfigured: %v", fm.initErr)
|
||||
http.Error(w, "contact form misconfigured", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := r.ParseMultipartForm(maxFormSize); err != nil {
|
||||
http.Error(w, "form too large or invalid", http.StatusBadRequest)
|
||||
@ -173,7 +220,6 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Blocklist check — fail before any other processing.
|
||||
if fm.Blacklist != nil {
|
||||
if c := fm.Blacklist.Value(); c != nil && c.ContainsAddr(ip) {
|
||||
fm.writeError(w, fmt.Errorf("automated requests are not accepted"), false)
|
||||
@ -181,7 +227,6 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Geo check.
|
||||
if fm.AllowedCountries != nil {
|
||||
country := string(iploc.IPCountry(ip))
|
||||
if country != "" && !slices.Contains(fm.AllowedCountries, country) {
|
||||
@ -190,47 +235,75 @@ func (fm *FormMailer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limit.
|
||||
if !fm.allow(ipStr) {
|
||||
http.Error(w, "rate limit exceeded — please try again later", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
name := fm.Fields.get(r, fm.Fields.Name, "input_1")
|
||||
email := strings.ToLower(fm.Fields.get(r, fm.Fields.Email, "input_3"))
|
||||
phone := fm.Fields.get(r, fm.Fields.Phone, "input_4")
|
||||
company := fm.Fields.get(r, fm.Fields.Company, "input_5")
|
||||
message := fm.Fields.get(r, fm.Fields.Message, "input_7")
|
||||
|
||||
if err := validateLengths(name, email, phone, company, message); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
}
|
||||
if err := validateNoHeaderInjection(name, email, company); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
}
|
||||
if err := validatePhone(phone); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
}
|
||||
if err := fm.validateEmailAndMX(r.Context(), email); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
values := make([]string, len(fm.Fields))
|
||||
for i, f := range fm.Fields {
|
||||
v := strings.TrimSpace(r.FormValue(f.FormName))
|
||||
if f.Kind == KindEmail {
|
||||
v = strings.ToLower(v)
|
||||
}
|
||||
values[i] = v
|
||||
}
|
||||
|
||||
n := min(len(message), 100)
|
||||
log.Printf("contact form: ip=%s name=%q email=%q phone=%q company=%q message=%q",
|
||||
ipStr, name, email, phone, company, message[:n])
|
||||
for i, f := range fm.Fields {
|
||||
v := values[i]
|
||||
if f.Required && v == "" {
|
||||
fm.writeError(w, fmt.Errorf("%w: %s", ErrMissingRequired, f.Label), true)
|
||||
return
|
||||
}
|
||||
if len(v) > f.maxLen() {
|
||||
fm.writeError(w, ErrContentTooLong, true)
|
||||
return
|
||||
}
|
||||
// Header-injection check: all fields except free-form message bodies.
|
||||
if f.Kind != KindMessage && strings.ContainsAny(v, "\r\n") {
|
||||
fm.writeError(w, ErrInvalidNewlines, true)
|
||||
return
|
||||
}
|
||||
switch f.Kind {
|
||||
case KindPhone:
|
||||
if err := validatePhone(v); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
}
|
||||
case KindEmail:
|
||||
if err := fm.validateEmailAndMX(r.Context(), v); err != nil {
|
||||
fm.writeError(w, err, true)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
email := values[fm.emailIdx]
|
||||
|
||||
var logBuf strings.Builder
|
||||
fmt.Fprintf(&logBuf, "contact form: ip=%s", ipStr)
|
||||
for i, f := range fm.Fields {
|
||||
v := values[i]
|
||||
if len(v) > 100 {
|
||||
v = v[:100]
|
||||
}
|
||||
fmt.Fprintf(&logBuf, " %s=%q", f.Label, v)
|
||||
}
|
||||
log.Print(logBuf.String())
|
||||
|
||||
subject := strings.ReplaceAll(fm.Subject, "{.Email}", email)
|
||||
body := fmt.Sprintf(
|
||||
"New contact form submission:\n\nName: %s\nEmail: %s\nPhone: %s\nCompany: %s\nMessage:\n%s\n",
|
||||
name, email, phone, company, message,
|
||||
)
|
||||
var body strings.Builder
|
||||
body.WriteString("New contact form submission:\n\n")
|
||||
for i, f := range fm.Fields {
|
||||
if f.Kind == KindMessage {
|
||||
fmt.Fprintf(&body, "%s:\n%s\n", f.Label, values[i])
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&body, "%s: %s\n", f.Label, values[i])
|
||||
}
|
||||
msg := fmt.Appendf(nil,
|
||||
"To: %s\r\nFrom: %s\r\nReply-To: %s\r\nSubject: %s\r\n\r\n%s\r\n",
|
||||
strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body,
|
||||
strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body.String(),
|
||||
)
|
||||
|
||||
if err := fm.sendMail(r.Context(), msg); err != nil {
|
||||
@ -356,16 +429,10 @@ func (fm *FormMailer) allow(ipStr string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func validateLengths(name, email, phone, company, message string) error {
|
||||
if len(name) > maxNameLength || len(email) > maxEmailLength ||
|
||||
len(phone) > maxPhoneLength || len(company) > maxCompanyLength ||
|
||||
len(message) > maxMessageLength {
|
||||
return ErrContentTooLong
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fm *FormMailer) validateEmailAndMX(ctx context.Context, email string) error {
|
||||
if email == "" {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
if _, err := mail.ParseAddress(email); err != nil {
|
||||
return ErrInvalidEmail
|
||||
}
|
||||
@ -395,15 +462,6 @@ func validatePhone(phone string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNoHeaderInjection(fields ...string) error {
|
||||
for _, f := range fields {
|
||||
if strings.ContainsAny(f, "\r\n") {
|
||||
return ErrInvalidNewlines
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// clientIP returns the originating IP, preferring X-Forwarded-For.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user