Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add required tests for go/netutil #15392

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions go/netutil/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,17 @@ package netutil

import (
"net"
"strings"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
// Create a listener.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen failed: %v", err)
}
assert.NoError(t, err)
addr := listener.Addr().String()

// Dial a client, Accept a server.
Expand All @@ -38,9 +37,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
defer wg.Done()
var err error
clientConn, err = net.Dial("tcp", addr)
if err != nil {
t.Errorf("Dial failed: %v", err)
}
assert.NoError(t, err)
}()

var serverConn net.Conn
Expand All @@ -49,9 +46,7 @@ func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
defer wg.Done()
var err error
serverConn, err = listener.Accept()
if err != nil {
t.Errorf("Accept failed: %v", err)
}
assert.NoError(t, err)
}()

wg.Wait()
Expand All @@ -77,13 +72,7 @@ func TestReadTimeout(t *testing.T) {

select {
case err := <-c:
if err == nil {
t.Fatalf("Expected error, got nil")
}

if !strings.HasSuffix(err.Error(), "i/o timeout") {
t.Errorf("Expected error timeout, got %s", err)
}
assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout")
case <-time.After(10 * time.Second):
t.Errorf("Timeout did not happen")
}
Expand Down Expand Up @@ -113,13 +102,7 @@ func TestWriteTimeout(t *testing.T) {

select {
case err := <-c:
if err == nil {
t.Fatalf("Expected error, got nil")
}

if !strings.HasSuffix(err.Error(), "i/o timeout") {
t.Errorf("Expected error timeout, got %s", err)
}
assert.ErrorContains(t, err, "i/o timeout", "Expected error timeout")
case <-time.After(10 * time.Second):
t.Errorf("Timeout did not happen")
}
Expand Down Expand Up @@ -167,3 +150,42 @@ func TestNoTimeouts(t *testing.T) {
// NOOP
}
}

func TestSetDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetDeadline(time.Now()) })
}

func TestSetReadDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetReadDeadline(time.Now()) })
}

func TestSetWriteDeadline(t *testing.T) {
listener, sConn, cConn := createSocketPair(t)
defer func() {
listener.Close()
sConn.Close()
cConn.Close()
}()

cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)

assert.Panics(t, func() { _ = cConnWithTimeout.SetWriteDeadline(time.Now()) })
}
52 changes: 37 additions & 15 deletions go/netutil/netutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License.
package netutil

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitHostPort(t *testing.T) {
Expand All @@ -33,12 +36,9 @@ func TestSplitHostPort(t *testing.T) {
}
for input, want := range table {
gotHost, gotPort, err := SplitHostPort(input)
if err != nil {
t.Errorf("SplitHostPort error: %v", err)
}
if gotHost != want.host || gotPort != want.port {
t.Errorf("SplitHostPort(%#v) = (%v, %v), want (%v, %v)", input, gotHost, gotPort, want.host, want.port)
}
assert.NoError(t, err)
assert.Equal(t, want.host, gotHost)
assert.Equal(t, want.port, gotPort)
}
}

Expand All @@ -50,9 +50,7 @@ func TestSplitHostPortFail(t *testing.T) {
}
for _, input := range inputs {
_, _, err := SplitHostPort(input)
if err == nil {
t.Errorf("expected error from SplitHostPort(%q), but got none", input)
}
assert.Error(t, err)
}
}

Expand All @@ -66,9 +64,7 @@ func TestJoinHostPort(t *testing.T) {
"[::1]:321": {host: "::1", port: 321},
}
for want, input := range table {
if got := JoinHostPort(input.host, input.port); got != want {
t.Errorf("SplitHostPort(%v, %v) = %#v, want %#v", input.host, input.port, got, want)
}
assert.Equal(t, want, JoinHostPort(input.host, input.port))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

outside the scope of this PR, but should we delete this function in favor of the stdlib net.JoinHostPort? cc @mattlord, the signature is slightly different (we take an int32 port whereas stdlib takes string)

}
}

Expand All @@ -83,8 +79,34 @@ func TestNormalizeIP(t *testing.T) {
"127.": "127.",
}
for input, want := range table {
if got := NormalizeIP(input); got != want {
t.Errorf("NormalizeIP(%#v) = %#v, want %#v", input, got, want)
}
assert.Equal(t, want, NormalizeIP(input))
}
}

func TestDNSTracker(t *testing.T) {
refresh := DNSTracker("localhost")
_, err := refresh()
assert.NoError(t, err)

refresh = DNSTracker("")
val, err := refresh()
assert.NoError(t, err)
assert.False(t, val, "DNS name resolution should not have changed")
}

func TestAddrEqual(t *testing.T) {
addr1 := net.ParseIP("1.2.3.4")
addr2 := net.ParseIP("127.0.0.1")

addrSet1 := []net.IP{addr1, addr2}
addrSet2 := []net.IP{addr1}
addrSet3 := []net.IP{addr2}
ok := addrEqual(addrSet1, addrSet2)
assert.False(t, ok, "addresses %q and %q should not be equal", addrSet1, addrSet2)

ok = addrEqual(addrSet3, addrSet2)
assert.False(t, ok, "addresses %q and %q should not be equal", addrSet3, addrSet2)

ok = addrEqual(addrSet1, addrSet1)
assert.True(t, ok, "addresses %q and %q should be equal", addrSet1, addrSet1)
}
Loading