Skip to content

Commit

Permalink
feat: support for setting callbacks when calling asynchronous APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
panjf2000 committed Jan 9, 2022
1 parent 8b84977 commit 139098f
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 35 deletions.
6 changes: 3 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ func (cli *Client) Start() error {

// Stop stops the client event-loop.
func (cli *Client) Stop() (err error) {
err = cli.el.poller.UrgentTrigger(func(_ interface{}) error { return gerrors.ErrEngineShutdown }, nil)
logging.Error(cli.el.poller.UrgentTrigger(func(_ interface{}) error { return gerrors.ErrEngineShutdown }, nil))
cli.el.engine.wg.Wait()
cli.el.poller.Close()
logging.Error(cli.el.poller.Close())
cli.el.eventHandler.OnShutdown(Engine{})
// Stop the ticker.
if cli.opts.Ticker {
Expand Down Expand Up @@ -199,7 +199,7 @@ func (cli *Client) Dial(network, address string) (Conn, error) {
}
err = cli.el.poller.UrgentTrigger(cli.el.register, gc)
if err != nil {
gc.Close()
gc.Close(nil)
return nil, err
}
return gc, nil
Expand Down
7 changes: 3 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,10 @@ func (s *testClientServer) OnTraffic(c Conn) (action Action) {
_ = c.InboundBuffered()
_ = c.OutboundBuffered()
_, _ = c.Discard(1)

}
_ = s.workerPool.Submit(
func() {
_ = c.AsyncWrite(buf.Bytes())
_ = c.AsyncWrite(buf.Bytes(), nil)
})
return
}
Expand Down Expand Up @@ -316,7 +315,7 @@ func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr
rand.Seed(time.Now().UnixNano())
c, err := cli.Dial(network, addr)
require.NoError(t, err)
defer c.Close()
defer c.Close(nil)
var rspCh chan []byte
if network == "udp" {
rspCh = make(chan []byte, 1)
Expand Down Expand Up @@ -347,7 +346,7 @@ func startGnetClient(t *testing.T, cli *Client, ev *clientEvents, network, addr
}
_, err = rand.Read(reqData)
require.NoError(t, err)
err = c.AsyncWrite(reqData)
err = c.AsyncWrite(reqData, nil)
require.NoError(t, err)
respData := <-rspCh
require.NoError(t, err)
Expand Down
71 changes: 55 additions & 16 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,20 +197,40 @@ func (c *conn) writev(bs [][]byte) (err error) {
return
}

func (c *conn) asyncWrite(itf interface{}) error {
type asyncWriteHook struct {
callback AsyncCallback
data []byte
}

func (c *conn) asyncWrite(itf interface{}) (err error) {
if !c.opened {
return nil
}

return c.write(itf.([]byte))
hook := itf.(*asyncWriteHook)
err = c.write(hook.data)
if hook.callback != nil {
_ = hook.callback(c)
}
return
}

func (c *conn) asyncWritev(itf interface{}) error {
type asyncWritevHook struct {
callback AsyncCallback
data [][]byte
}

func (c *conn) asyncWritev(itf interface{}) (err error) {
if !c.opened {
return nil
}

return c.writev(itf.([][]byte))
hook := itf.(*asyncWritevHook)
err = c.writev(hook.data)
if hook.callback != nil {
_ = hook.callback(c)
}
return
}

func (c *conn) sendTo(buf []byte) error {
Expand Down Expand Up @@ -401,29 +421,48 @@ func (c *conn) RemoteAddr() net.Addr { return c.remoteAddr }

// ==================================== Concurrency-safe API's ====================================

func (c *conn) AsyncWrite(buf []byte) error {
func (c *conn) AsyncWrite(buf []byte, callback AsyncCallback) (err error) {
if c.isDatagram {
return c.sendTo(buf)
err = c.sendTo(buf)
if callback != nil {
_ = callback(c)
}
return
}
return c.loop.poller.Trigger(c.asyncWrite, buf)
return c.loop.poller.Trigger(c.asyncWrite, &asyncWriteHook{callback, buf})
}

func (c *conn) AsyncWritev(bs [][]byte) error {
func (c *conn) AsyncWritev(bs [][]byte, callback AsyncCallback) (err error) {
if c.isDatagram {
for _, b := range bs {
if err := c.sendTo(b); err != nil {
return err
if err = c.sendTo(b); err != nil {
return
}
}
return nil
if callback != nil {
_ = callback(c)
}
return
}
return c.loop.poller.Trigger(c.asyncWritev, bs)
return c.loop.poller.Trigger(c.asyncWritev, &asyncWritevHook{callback, bs})
}

func (c *conn) Wake() error {
return c.loop.poller.UrgentTrigger(func(_ interface{}) error { return c.loop.wake(c) }, nil)
func (c *conn) Wake(callback AsyncCallback) error {
return c.loop.poller.UrgentTrigger(func(_ interface{}) (err error) {
err = c.loop.wake(c)
if callback != nil {
_ = callback(c)
}
return
}, nil)
}

func (c *conn) Close() error {
return c.loop.poller.Trigger(func(_ interface{}) error { return c.loop.closeConn(c, nil) }, nil)
func (c *conn) Close(callback AsyncCallback) error {
return c.loop.poller.Trigger(func(_ interface{}) (err error) {
err = c.loop.closeConn(c, nil)
if callback != nil {
_ = callback(c)
}
return
}, nil)
}
16 changes: 11 additions & 5 deletions gnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,18 @@ type Writer interface {

// AsyncWrite writes one byte slice to peer asynchronously, usually you would call it in individual goroutines
// instead of the event-loop goroutines.
AsyncWrite(buf []byte) (err error)
AsyncWrite(buf []byte, callback AsyncCallback) (err error)

// AsyncWritev writes multiple byte slices to peer asynchronously, usually you would call it in individual goroutines
// instead of the event-loop goroutines.
AsyncWritev(bs [][]byte) (err error)
AsyncWritev(bs [][]byte, callback AsyncCallback) (err error)
}

// AsyncCallback is a callback which will be invoked after the asynchronous functions has finished executing.
//
// Note that the parameter gnet.Conn is already released under UDP protocol, thus it's not allowed to be accessed.
type AsyncCallback func(c Conn) error

// Conn is an interface of underlying connection.
type Conn interface {
Reader
Expand Down Expand Up @@ -167,10 +172,11 @@ type Conn interface {
// ==================================== Concurrency-safe API's ====================================

// Wake triggers a OnTraffic event for the connection.
Wake() (err error)
Wake(callback AsyncCallback) (err error)

// Close closes the current connection.
Close() (err error)
// Close closes the current connection, usually you don't need to pass a non-nil callback
// because you should use OnClose() instead, the callback here is only for compatibility.
Close(callback AsyncCallback) (err error)
}

type (
Expand Down
23 changes: 16 additions & 7 deletions gnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,22 @@ func (s *testServer) OnTraffic(c Conn) (action Action) {
bs := make([][]byte, 2)
bs[0] = buf.B[:mid]
bs[1] = buf.B[mid:]
_ = c.AsyncWritev(bs)
_ = c.AsyncWritev(bs, func(c Conn) error {
logging.Debugf("conn=%s done writev", c.RemoteAddr().String())
return nil
})
} else {
_ = c.AsyncWrite(buf.Bytes())
_ = c.AsyncWrite(buf.Bytes(), func(c Conn) error {
logging.Debugf("conn=%s done write", c.RemoteAddr().String())
return nil
})
}
})
return
} else if s.network == "udp" {
_ = s.workerPool.Submit(
func() {
_ = c.AsyncWrite(buf.Bytes())
_ = c.AsyncWrite(buf.Bytes(), nil)
})
return
}
Expand Down Expand Up @@ -508,7 +514,10 @@ func (t *testWakeConnServer) OnTick() (delay time.Duration, action Action) {
return
}
t.c = <-t.conn
_ = t.c.Wake()
_ = t.c.Wake(func(c Conn) error {
logging.Debugf("conn=%s done wake", c.RemoteAddr().String())
return nil
})
delay = time.Millisecond * 100
return
}
Expand Down Expand Up @@ -819,7 +828,7 @@ func (t *testCloseConnectionServer) OnTraffic(c Conn) (action Action) {
_, _ = c.Discard(-1)
go func() {
time.Sleep(time.Second)
_ = c.Close()
_ = c.Close(nil)
}()
return
}
Expand Down Expand Up @@ -967,8 +976,8 @@ func (s *testClosedWakeUpServer) OnTraffic(c Conn) Action {
close(s.wakeup)
}

go func() { require.NoError(s.tester, c.Wake()) }()
go func() { require.NoError(s.tester, c.Close()) }()
go func() { require.NoError(s.tester, c.Wake(nil)) }()
go func() { require.NoError(s.tester, c.Close(nil)) }()

<-s.clientClosed

Expand Down

0 comments on commit 139098f

Please sign in to comment.