Commit 2a2a0347 by Iwasaki Yudai

Fix possible race condition on timeout

1 parent 9b8d2d5e
package server
import (
"sync"
"time"
)
type counter struct {
duration time.Duration
zeroTimer *time.Timer
wg sync.WaitGroup
connections int
mutex sync.Mutex
}
func newCounter(duration time.Duration) *counter {
zeroTimer := time.NewTimer(duration)
// when duration is 0, drain the expire event here
// so that user will never get the event.
if duration == 0 {
<-zeroTimer.C
}
return &counter{
duration: duration,
zeroTimer: zeroTimer,
}
}
func (counter *counter) add(n int) int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
if counter.duration > 0 {
counter.zeroTimer.Stop()
}
counter.wg.Add(n)
counter.connections += n
return counter.connections
}
func (counter *counter) done() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
counter.connections--
counter.wg.Done()
if counter.connections == 0 && counter.duration > 0 {
counter.zeroTimer.Reset(counter.duration)
}
return counter.connections
}
func (counter *counter) count() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
return counter.connections
}
func (counter *counter) wait() {
counter.wg.Wait()
}
func (counter *counter) timer() *time.Timer {
return counter.zeroTimer
}
......@@ -8,9 +8,7 @@ import (
"log"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
......@@ -18,42 +16,32 @@ import (
"github.com/yudai/gotty/webtty"
)
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc {
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, counter *counter) http.HandlerFunc {
once := new(int64)
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
if server.options.Timeout > 0 {
go func() {
select {
case <-timer.C:
cancel()
case <-ctx.Done():
}
}()
}
go func() {
select {
case <-counter.timer().C:
cancel()
case <-ctx.Done():
}
}()
return func(w http.ResponseWriter, r *http.Request) {
if server.options.Once {
success := atomic.CompareAndSwapInt64(once, 0, 1)
if !success {
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
return
}
}
if server.options.Timeout > 0 {
timer.Stop()
}
wg.Add(1)
num := atomic.AddInt64(connections, 1)
num := counter.add(1)
closeReason := "unknown reason"
defer func() {
num := atomic.AddInt64(connections, -1)
if num == 0 && server.options.Timeout > 0 {
timer.Reset(time.Duration(server.options.Timeout) * time.Second)
}
num := counter.done()
log.Printf(
"Connection closed by %s: %s, connections: %d/%d",
closeReason, r.RemoteAddr, num, server.options.MaxConnection,
......@@ -62,12 +50,10 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance
if server.options.Once {
cancel()
}
wg.Done()
}()
if int64(server.options.MaxConnection) != 0 {
if num > int64(server.options.MaxConnection) {
if num > server.options.MaxConnection {
closeReason = "exceeding max number of connections"
return
}
......
......@@ -10,9 +10,8 @@ import (
"net"
"net/http"
"net/url"
"sync"
"sync/atomic"
noesctmpl "text/template"
"time"
"github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/websocket"
......@@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
opt(opts)
}
// wg and connections can be incosistent because they are handled nonatomically
wg := new(sync.WaitGroup) // to wait all connections to be closed
connections := new(int64) // number of active connections
counter := newCounter(time.Duration(server.options.Timeout) * time.Second)
url := server.setupURL()
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
handlers := server.setupHandlers(cctx, cancel, url, counter)
srv, err := server.setupHTTPServer(handlers, url)
if err != nil {
return errors.Wrapf(err, "failed to setup an HTTP server")
......@@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
err = cctx.Err()
}
conn := atomic.LoadInt64(connections)
conn := counter.count()
if conn > 0 {
log.Printf("Waiting for %d connections to be closed", conn)
}
wg.Wait()
counter.wait()
return err
}
......@@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL {
return &url.URL{Scheme: scheme, Host: host, Path: path}
}
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, counter *counter) http.Handler {
staticFileHandler := http.FileServer(
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
)
......@@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
wsMux := http.NewServeMux()
wsMux.Handle("/", siteHandler)
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, counter))
siteHandler = http.Handler(wsMux)
return server.wrapLogger(siteHandler)
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!