diff --git a/net/formmailer/formmailer.go b/net/formmailer/formmailer.go index c0667e0..be336be 100644 --- a/net/formmailer/formmailer.go +++ b/net/formmailer/formmailer.go @@ -1,23 +1,33 @@ // Package formmailer provides an HTTP handler that validates, rate-limits, // and emails contact form submissions. // +// Fields are declared as an ordered slice; each Field names the HTML input, +// the label for the email body, and the validation Kind. Exactly one Kind +// must be KindEmail — its value is used for Reply-To, Subject substitution, +// and the MX check. +// // Typical setup: // -// // Blacklist can be any snapshot source — e.g. *dataset.View[ipcohort.Cohort] -// // satisfies CohortSource directly via its Value() method. // blacklist := dataset.Add(set, func() (*ipcohort.Cohort, error) { ... }) // // fm := &formmailer.FormMailer{ -// SMTPHost: "smtp.example.com:587", -// SMTPFrom: "noreply@example.com", -// SMTPTo: []string{"contact@example.com"}, -// SMTPUser: "noreply@example.com", -// SMTPPass: os.Getenv("SMTP_PASS"), -// Subject: "Contact from {.Email}", +// SMTPHost: "smtp.example.com:587", +// SMTPFrom: "noreply@example.com", +// SMTPTo: []string{"contact@example.com"}, +// SMTPUser: "noreply@example.com", +// SMTPPass: os.Getenv("SMTP_PASS"), +// Subject: "Contact from {.Email}", +// Fields: []formmailer.Field{ +// {Label: "Name", FormName: "input_1", Kind: formmailer.KindText}, +// {Label: "Email", FormName: "input_3", Kind: formmailer.KindEmail}, +// {Label: "Phone", FormName: "input_4", Kind: formmailer.KindPhone}, +// {Label: "Company", FormName: "input_5", Kind: formmailer.KindText}, +// {Label: "Budget", FormName: "input_8", Kind: formmailer.KindText}, +// {Label: "Message", FormName: "input_7", Kind: formmailer.KindMessage}, +// }, // SuccessBody: successHTML, // ErrorBody: errorHTML, // Blacklist: blacklist, -// AllowedCountries: []string{"US", "CA", "MX"}, // } // http.Handle("POST /contact", fm) package formmailer @@ -47,12 +57,13 @@ import ( ) const ( - maxFormSize = 10 * 1024 - maxMessageLength = 4000 - maxCompanyLength = 200 - maxNameLength = 100 + maxFormSize = 10 * 1024 + + // Default per-Kind length caps; override with Field.MaxLen. maxEmailLength = 254 maxPhoneLength = 20 + maxTextLength = 200 + maxMessageLength = 4000 defaultRPM = 5 defaultBurst = 3 @@ -69,26 +80,45 @@ var ( ErrInvalidPhone = errors.New("phone number is not properly formatted") ErrContentTooLong = errors.New("one or more field values was too long") ErrInvalidNewlines = errors.New("invalid use of newlines or carriage returns") + ErrMissingRequired = errors.New("required field was empty") + ErrNoEmailField = errors.New("FormMailer.Fields must contain exactly one KindEmail field") phoneRe = regexp.MustCompile(`^[0-9+\-\(\) ]{7,20}$`) ) -// FormFields maps logical field names to the HTML form field names. -// Zero values use GravityForms-compatible defaults (input_1, input_3, etc.). -type FormFields struct { - Name string // default "input_1" - Email string // default "input_3" - Phone string // default "input_4" - Company string // default "input_5" - Message string // default "input_7" +// FieldKind picks validation rules and default length cap for a Field. +type FieldKind int + +const ( + KindText FieldKind = iota // default; length-capped text + KindEmail // RFC 5321 parse + MX lookup + KindPhone // phoneRe match + KindMessage // long free text (body of the submission) +) + +// Field declares one form input. Order is preserved in the email body. +type Field struct { + Label string // shown in email body, e.g. "Name" + FormName string // HTML form field name, e.g. "input_1" + Kind FieldKind // validation rules + default MaxLen + MaxLen int // 0 = default for Kind + Required bool // if true, empty value is rejected } -func (f FormFields) get(r *http.Request, field, def string) string { - key := field - if key == "" { - key = def +func (f Field) maxLen() int { + if f.MaxLen > 0 { + return f.MaxLen + } + switch f.Kind { + case KindEmail: + return maxEmailLength + case KindPhone: + return maxPhoneLength + case KindMessage: + return maxMessageLength + default: + return maxTextLength } - return strings.TrimSpace(r.FormValue(key)) } // FormMailer is an http.Handler that validates and emails contact form submissions. @@ -101,35 +131,35 @@ type FormMailer struct { SMTPPass string Subject string // may contain {.Email} - // SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 15s. + // SMTPTimeout bounds the entire connect+auth+send cycle. Zero uses 5s. SMTPTimeout time.Duration - // MXTimeout bounds the per-submission MX lookup. Zero uses 3s. + // MXTimeout bounds the per-submission MX lookup. Zero uses 2s. MXTimeout time.Duration // SuccessBody and ErrorBody are the response bodies sent to the client. // ErrorBody may contain {.Error} and {.SupportEmail} placeholders. - // Load from files before use: fm.SuccessBody, _ = os.ReadFile("success.html") SuccessBody []byte ErrorBody []byte ContentType string // inferred from SuccessBody if empty // Blacklist — if set, matching IPs are rejected before any other processing. - // Value() returns nil before the first successful load (no blocks applied). Blacklist *dataset.View[ipcohort.Cohort] // AllowedCountries — if non-nil, only requests from listed ISO codes are // accepted. Unknown country ("") is always allowed. - // Example: []string{"US", "CA", "MX"} AllowedCountries []string - // Fields maps logical names to HTML form field names. - Fields FormFields + // Fields declares the form inputs in display order. Exactly one entry + // must have Kind == KindEmail. + Fields []Field // RPM and Burst control per-IP rate limiting. Zero uses defaults (5/3). RPM int Burst int once sync.Once + initErr error + emailIdx int // index into Fields of the KindEmail entry mu sync.Mutex limiters map[string]*limiterEntry reqCount uint64 @@ -142,13 +172,25 @@ type limiterEntry struct { func (fm *FormMailer) init() { fm.limiters = make(map[string]*limiterEntry) + fm.emailIdx = -1 + for i, f := range fm.Fields { + if f.Kind == KindEmail { + if fm.emailIdx >= 0 { + fm.initErr = ErrNoEmailField + return + } + fm.emailIdx = i + } + } + if fm.emailIdx < 0 { + fm.initErr = ErrNoEmailField + } } func (fm *FormMailer) contentType() string { if fm.ContentType != "" { return fm.ContentType } - // Infer from SuccessBody sniff or leave as plain text. if bytes.Contains(fm.SuccessBody[:min(512, len(fm.SuccessBody))], []byte(" f.maxLen() { + fm.writeError(w, ErrContentTooLong, true) + return + } + // Header-injection check: all fields except free-form message bodies. + if f.Kind != KindMessage && strings.ContainsAny(v, "\r\n") { + fm.writeError(w, ErrInvalidNewlines, true) + return + } + switch f.Kind { + case KindPhone: + if err := validatePhone(v); err != nil { + fm.writeError(w, err, true) + return + } + case KindEmail: + if err := fm.validateEmailAndMX(r.Context(), v); err != nil { + fm.writeError(w, err, true) + return + } + } + } + + email := values[fm.emailIdx] + + var logBuf strings.Builder + fmt.Fprintf(&logBuf, "contact form: ip=%s", ipStr) + for i, f := range fm.Fields { + v := values[i] + if len(v) > 100 { + v = v[:100] + } + fmt.Fprintf(&logBuf, " %s=%q", f.Label, v) + } + log.Print(logBuf.String()) subject := strings.ReplaceAll(fm.Subject, "{.Email}", email) - body := fmt.Sprintf( - "New contact form submission:\n\nName: %s\nEmail: %s\nPhone: %s\nCompany: %s\nMessage:\n%s\n", - name, email, phone, company, message, - ) + var body strings.Builder + body.WriteString("New contact form submission:\n\n") + for i, f := range fm.Fields { + if f.Kind == KindMessage { + fmt.Fprintf(&body, "%s:\n%s\n", f.Label, values[i]) + continue + } + fmt.Fprintf(&body, "%s: %s\n", f.Label, values[i]) + } msg := fmt.Appendf(nil, "To: %s\r\nFrom: %s\r\nReply-To: %s\r\nSubject: %s\r\n\r\n%s\r\n", - strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body, + strings.Join(fm.SMTPTo, ", "), fm.SMTPFrom, email, subject, body.String(), ) if err := fm.sendMail(r.Context(), msg); err != nil { @@ -356,16 +429,10 @@ func (fm *FormMailer) allow(ipStr string) bool { return true } -func validateLengths(name, email, phone, company, message string) error { - if len(name) > maxNameLength || len(email) > maxEmailLength || - len(phone) > maxPhoneLength || len(company) > maxCompanyLength || - len(message) > maxMessageLength { - return ErrContentTooLong - } - return nil -} - func (fm *FormMailer) validateEmailAndMX(ctx context.Context, email string) error { + if email == "" { + return ErrInvalidEmail + } if _, err := mail.ParseAddress(email); err != nil { return ErrInvalidEmail } @@ -395,15 +462,6 @@ func validatePhone(phone string) error { return nil } -func validateNoHeaderInjection(fields ...string) error { - for _, f := range fields { - if strings.ContainsAny(f, "\r\n") { - return ErrInvalidNewlines - } - } - return nil -} - // clientIP returns the originating IP, preferring X-Forwarded-For. func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" {