Skip to content

Commit

Permalink
Sec web socket protocol (#11)
Browse files Browse the repository at this point in the history
* 回显子协议的值

* 更新

* 更新
  • Loading branch information
guonaihong authored Sep 2, 2023
1 parent 608a7f8 commit 8faa921
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 24 deletions.
30 changes: 17 additions & 13 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,35 @@ var (
)

type DialOption struct {
Header http.Header
u *url.URL
tlsConfig *tls.Config
dialTimeout time.Duration
Header http.Header
u *url.URL
tlsConfig *tls.Config
dialTimeout time.Duration
bindClientHttpHeader *http.Header // 握手成功之后, 客户端获取http.Header,
Config
}

func ClientOptionToConf(opts ...ClientOption) *Config {
func ClientOptionToConf(opts ...ClientOption) *DialOption {
var dial DialOption
dial.defaultSetting()
for _, o := range opts {
o(&dial)
}
return &dial.Config
return &dial
}

func DialConf(rawUrl string, conf *Config) (*Conn, error) {
var dial DialOption
func DialConf(rawUrl string, conf *DialOption) (*Conn, error) {
u, err := url.Parse(rawUrl)
if err != nil {
return nil, err
}

dial.u = u
dial.dialTimeout = defaultTimeout
if dial.Header == nil {
dial.Header = make(http.Header)
conf.u = u
conf.dialTimeout = defaultTimeout
if conf.Header == nil {
conf.Header = make(http.Header)
}
return dial.Dial()
return conf.Dial()
}

// https://datatracker.ietf.org/doc/html/rfc6455#section-4.1
Expand Down Expand Up @@ -222,6 +222,10 @@ func (d *DialOption) Dial() (c *Conn, err error) {
return nil, err
}

if d.bindClientHttpHeader != nil {
*d.bindClientHttpHeader = rsp.Header.Clone()
}

cd := maybeCompressionDecompression(rsp.Header)
if d.decompression {
d.decompression = cd
Expand Down
50 changes: 50 additions & 0 deletions client_option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,54 @@ func Test_ClientOption(t *testing.T) {
t.Error("not run server:method fail")
}
})

t.Run("6.1 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := Dial(url, WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"token"},
}))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "token" {
t.Error("header fail")
}
})

t.Run("6.2 Dial: WithClientBindHTTPHeader", func(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := Upgrade(w, r)
if err != nil {
t.Error(err)
}
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
h := make(http.Header)
con, err := DialConf(url, ClientOptionToConf(WithClientBindHTTPHeader(&h), WithClientHTTPHeader(http.Header{
"Sec-WebSocket-Protocol": []string{"token"},
})))
if err != nil {
t.Error(err)
}
defer con.Close()

if h["Sec-Websocket-Protocol"][0] != "token" {
t.Error("header fail")
}
})
}
17 changes: 12 additions & 5 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,38 +22,45 @@ import (

type ClientOption func(*DialOption)

// 配置tls.config
// 1.配置tls.config
func WithClientTLSConfig(tls *tls.Config) ClientOption {
return func(o *DialOption) {
o.tlsConfig = tls
}
}

// 配置http.Header
// 2.配置http.Header
func WithClientHTTPHeader(h http.Header) ClientOption {
return func(o *DialOption) {
o.Header = h
}
}

// 配置握手时的timeout
// 3.配置握手时的timeout
func WithClientDialTimeout(t time.Duration) ClientOption {
return func(o *DialOption) {
o.dialTimeout = t
}
}

// 配置压缩
// 4.配置压缩
func WithClientCompression() ClientOption {
return func(o *DialOption) {
o.compression = true
}
}

// 配置压缩和解压缩
// 5.配置压缩和解压缩
func WithClientDecompressAndCompress() ClientOption {
return func(o *DialOption) {
o.compression = true
o.decompression = true
}
}

// 6.获取http header
func WithClientBindHTTPHeader(h *http.Header) ClientOption {
return func(o *DialOption) {
o.bindClientHttpHeader = h
}
}
23 changes: 17 additions & 6 deletions server_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ import (
)

var (
ErrNotFoundHijacker = errors.New("not found Hijacker")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
bytesCRLF = []byte("\r\n")
bytesColon = []byte(": ")
ErrNotFoundHijacker = errors.New("not found Hijacker")
bytesHeaderUpgrade = []byte("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept:")
bytesHeaderExtensions = []byte("Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n")
bytesCRLF = []byte("\r\n")
strGetSecWebSocketProtocolKey = "Sec-WebSocket-Protocol"
bytesPutSecWebSocketProtocolKey = []byte("Sec-WebSocket-Protocol: ")
)

type ConnOption struct {
Expand Down Expand Up @@ -67,6 +68,17 @@ func prepareWriteResponse(r *http.Request, w io.Writer, cnf *Config) (err error)
}
}

v = r.Header.Get(strGetSecWebSocketProtocolKey)
if len(v) > 0 {
if _, err = w.Write(bytesPutSecWebSocketProtocolKey); err != nil {
return
}

if err = writeHeaderVal(w, StringToBytes(v)); err != nil {
return err
}
}

_, err = w.Write(bytesCRLF)
return err
}
Expand Down Expand Up @@ -111,7 +123,6 @@ func checkRequest(r *http.Request) (ecode int, err error) {
return http.StatusUpgradeRequired, ErrSecWebSocketVersion
}

// TODO Sec-WebSocket-Protocol
// TODO Sec-WebSocket-Extensions
return 0, nil
}

0 comments on commit 8faa921

Please sign in to comment.