diff --git a/conn.go b/conn.go index 43c8df29f..fe0f6a7cf 100644 --- a/conn.go +++ b/conn.go @@ -132,6 +132,12 @@ type conn struct { // round-trip mode for non-prepared Query calls. binaryParameters bool + // Timeouts for read and write operations against the database server. + // A duration of 0 indicates no timeout. + // specified in milliseconds + readTimeout time.Duration + writeTimeout time.Duration + // If true this connection is in the middle of a COPY inCopy bool } @@ -151,6 +157,28 @@ func (cn *conn) handleDriverSettings(o values) (err error) { return nil } + timeSetting := func(key string, val *time.Duration) error { + if value := o[key]; value != "" { + timeout, err := strconv.Atoi(value) + if err != nil { + return err + } + // timeout is specified in milliseconds. + *val = time.Duration(timeout) * time.Millisecond + } else { + *val = time.Duration(0) * time.Millisecond + } + return nil + } + + err = timeSetting("read_timeout", &cn.readTimeout) + if err != nil { + return err + } + err = timeSetting("write_timeout", &cn.writeTimeout) + if err != nil { + return err + } err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult) if err != nil { return err @@ -911,15 +939,47 @@ func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err return r, err } +func (cn *conn) setWriteTimeout() { + // Set the write deadline if we have a write timeout set. + if cn.writeTimeout != 0 { + cn.c.SetWriteDeadline(time.Now().Add(cn.writeTimeout)) + } +} + +func (cn *conn) cleanWriteTimeout() { + if cn.writeTimeout != 0 { + // Clear the write deadline if we set one. + cn.c.SetWriteDeadline(time.Time{}) + } +} + +func (cn *conn) setReadTimeout() { + // Set the read deadline if we have a read timeout set. + if cn.readTimeout != 0 { + cn.c.SetReadDeadline(time.Now().Add(cn.readTimeout)) + } +} + +func (cn *conn) cleanReadTimeout() { + if cn.readTimeout != 0 { + // Clear the read deadline if we set one. + cn.c.SetReadDeadline(time.Time{}) + } +} + func (cn *conn) send(m *writeBuf) { + cn.setWriteTimeout() _, err := cn.c.Write(m.wrap()) if err != nil { panic(err) } + cn.cleanWriteTimeout() } func (cn *conn) sendStartupPacket(m *writeBuf) error { + cn.setWriteTimeout() _, err := cn.c.Write((m.wrap())[1:]) + cn.cleanWriteTimeout() return err } @@ -927,7 +987,9 @@ func (cn *conn) sendStartupPacket(m *writeBuf) error { // message should have no payload. This method does not use the scratch // buffer. func (cn *conn) sendSimpleMessage(typ byte) (err error) { + cn.setWriteTimeout() _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'}) + cn.cleanWriteTimeout() return err } @@ -957,9 +1019,11 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { return t, nil } + cn.setReadTimeout() x := cn.scratch[:5] _, err := io.ReadFull(cn.buf, x) if err != nil { + cn.cleanReadTimeout() return 0, err } @@ -972,11 +1036,14 @@ func (cn *conn) recvMessage(r *readBuf) (byte, error) { } else { y = make([]byte, n) } + cn.setReadTimeout() _, err = io.ReadFull(cn.buf, y) if err != nil { + cn.cleanReadTimeout() return 0, err } *r = y + cn.cleanReadTimeout() return t, nil } @@ -1083,6 +1150,10 @@ func isDriverSetting(key string) bool { return true case "binary_parameters": return true + case "read_timeout": + return true + case "write_timeout": + return true default: return false