diff --git a/conn.go b/conn.go index ff43e0ed..df3870c2 100644 --- a/conn.go +++ b/conn.go @@ -23,10 +23,11 @@ import ( const errThreshold = 3 type Conn struct { - conn net.Conn - text *textproto.Conn - server *Server - helo string + conn net.Conn + text *textproto.Conn + server *Server + helo string + rejected bool // Number of errors witnessed on this connection errCount int @@ -105,6 +106,17 @@ func (c *Conn) handle(cmd string, arg string) { return } + //as per RFC5321 3.1 + if c.rejected { + if cmd == "QUIT" { + c.writeResponse(221, NoEnhancedCode, "OK") + c.Close() + } else { + c.protocolError(503, NoEnhancedCode, "bad sequence of commands") + } + return + } + cmd = strings.ToUpper(cmd) switch cmd { case "SEND", "SOML", "SAML", "EXPN", "HELP", "TURN": @@ -1276,7 +1288,13 @@ func (c *Conn) greet() { if c.server.LMTP { protocol = "LMTP" } - c.writeResponse(220, NoEnhancedCode, fmt.Sprintf("%v %s Service Ready", c.server.Domain, protocol)) + domain, err := c.server.GetDomain(c) + if err != nil { + c.writeResponse(554, NoEnhancedCode, "Error: "+err.Error()) + c.rejected = true + return + } + c.writeResponse(220, NoEnhancedCode, fmt.Sprintf("%v %s Service Ready", domain, protocol)) } func (c *Conn) writeResponse(code int, enhCode EnhancedCode, text ...string) { diff --git a/server.go b/server.go index e0e0acd0..8a32754c 100644 --- a/server.go +++ b/server.go @@ -40,6 +40,7 @@ type Server struct { ErrorLog Logger ReadTimeout time.Duration WriteTimeout time.Duration + domainFn func(*Conn) (string, error) // Advertise SMTPUTF8 (RFC 6531) capability. // Should be used only if backend supports it. @@ -148,6 +149,19 @@ func (s *Server) Serve(l net.Listener) error { } } +// Getter for the optional dynamic domain string generator function +func (s *Server) GetDomain(c *Conn) (string, error) { + if s.domainFn != nil { + return s.domainFn(c) + } + return s.Domain, nil +} + +// Setter for the dynamic domain string generator function +func (s *Server) SetDomainFunc(fn func(*Conn) (string, error)) { + s.domainFn = fn +} + func (s *Server) handleConn(c *Conn) error { s.locker.Lock() s.conns[c] = struct{}{} diff --git a/server_test.go b/server_test.go index 06db7468..1febebfe 100644 --- a/server_test.go +++ b/server_test.go @@ -1712,3 +1712,95 @@ func TestServerMTPRIORITY(t *testing.T) { t.Fatal("Incorrect MtPriority parameter value:", fmt.Sprintf("expected %d, got %d", expectedPriority, *priority)) } } + +func getDynamicDomainResponse(c *smtp.Conn) (string, error) { + return "dynamichost.local", nil +} + +func getNegativeDynamicDomainResponse(c *smtp.Conn) (string, error) { + return "", fmt.Errorf("no service") +} + +func testServerDynamicDomain(t *testing.T, fn ...serverConfigureFunc) (be *backend, s *smtp.Server, c net.Conn, scanner *bufio.Scanner) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + be = new(backend) + s = smtp.NewServer(be) + s.Domain = "localhost" + s.SetDomainFunc(getDynamicDomainResponse) + s.AllowInsecureAuth = true + for _, f := range fn { + f(s) + } + + go s.Serve(l) + + c, err = net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + scanner = bufio.NewScanner(c) + return +} + +func testServerNegativeDynamicDomain(t *testing.T, fn ...serverConfigureFunc) (be *backend, s *smtp.Server, c net.Conn, scanner *bufio.Scanner) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + + be = new(backend) + s = smtp.NewServer(be) + s.Domain = "localhost" + s.SetDomainFunc(getNegativeDynamicDomainResponse) + s.AllowInsecureAuth = true + for _, f := range fn { + f(s) + } + + go s.Serve(l) + + c, err = net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + + scanner = bufio.NewScanner(c) + return +} + +func TestServerDynamicDomainGreeted(t *testing.T) { + _, _, _, scanner := testServerDynamicDomain(t) + + scanner.Scan() + if scanner.Text() != "220 dynamichost.local ESMTP Service Ready" { + t.Fatal("Invalid greeting:", scanner.Text()) + } +} + +func TestServerNegativeDynamicDomainGreeted(t *testing.T) { + _, _, c, scanner := testServerNegativeDynamicDomain(t) + + scanner.Scan() + if scanner.Text() != "554 Error: no service" { + t.Fatal("Invalid greeting:", scanner.Text()) + } + + //Now test for 503 error as per RFC521 3.1 + io.WriteString(c, "HELO localhost\r\n") + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "503 bad sequence of commands") { + t.Fatal("Invalid HELO response:", scanner.Text()) + } + io.WriteString(c, "QUIT\r\n") + + scanner.Scan() + if !strings.HasPrefix(scanner.Text(), "221 OK") { + t.Fatal("Invalid HELO response:", scanner.Text()) + } +}