Skip to content

Commit

Permalink
Extend DialOptions to allow Host header override
Browse files Browse the repository at this point in the history
Commit from coder#336
  • Loading branch information
dejan-lokar committed Oct 12, 2023
1 parent 14fb98e commit 1a02d4f
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
websocket.test
.idea
8 changes: 8 additions & 0 deletions dial.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !js
// +build !js

package websocket
Expand Down Expand Up @@ -30,6 +31,10 @@ type DialOptions struct {
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header

// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string

// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string

Expand Down Expand Up @@ -158,6 +163,9 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
}

req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
Expand Down
51 changes: 51 additions & 0 deletions dial_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !js
// +build !js

package websocket
Expand Down Expand Up @@ -108,6 +109,56 @@ func TestBadDials(t *testing.T) {
})
}

func Test_verifyHostOverride(t *testing.T) {
testCases := []struct {
name string
host string
exp string
}{
{
name: "noOverride",
host: "",
exp: "example.com",
},
{
name: "hostOverride",
host: "example.net",
exp: "example.net",
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

rt := func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Host", tc.exp, r.Host)

h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))

return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: ioutil.NopCloser(strings.NewReader("hi")),
}, nil
}

_, _, _ = Dial(ctx, "ws://example.com", &DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
})
}

}

func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()

Expand Down

0 comments on commit 1a02d4f

Please sign in to comment.