Commit 2a2a0347 by Iwasaki Yudai

Fix possible race condition on timeout

1 parent 9b8d2d5e
package server
import (
"sync"
"time"
)
type counter struct {
duration time.Duration
zeroTimer *time.Timer
wg sync.WaitGroup
connections int
mutex sync.Mutex
}
func newCounter(duration time.Duration) *counter {
zeroTimer := time.NewTimer(duration)
// when duration is 0, drain the expire event here
// so that user will never get the event.
if duration == 0 {
<-zeroTimer.C
}
return &counter{
duration: duration,
zeroTimer: zeroTimer,
}
}
func (counter *counter) add(n int) int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
if counter.duration > 0 {
counter.zeroTimer.Stop()
}
counter.wg.Add(n)
counter.connections += n
return counter.connections
}
func (counter *counter) done() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
counter.connections--
counter.wg.Done()
if counter.connections == 0 && counter.duration > 0 {
counter.zeroTimer.Reset(counter.duration)
}
return counter.connections
}
func (counter *counter) count() int {
counter.mutex.Lock()
defer counter.mutex.Unlock()
return counter.connections
}
func (counter *counter) wait() {
counter.wg.Wait()
}
func (counter *counter) timer() *time.Timer {
return counter.zeroTimer
}
...@@ -8,9 +8,7 @@ import ( ...@@ -8,9 +8,7 @@ import (
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
...@@ -18,42 +16,32 @@ import ( ...@@ -18,42 +16,32 @@ import (
"github.com/yudai/gotty/webtty" "github.com/yudai/gotty/webtty"
) )
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc { func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, counter *counter) http.HandlerFunc {
once := new(int64) once := new(int64)
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
if server.options.Timeout > 0 {
go func() { go func() {
select { select {
case <-timer.C: case <-counter.timer().C:
cancel() cancel()
case <-ctx.Done(): case <-ctx.Done():
} }
}() }()
}
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if server.options.Once { if server.options.Once {
success := atomic.CompareAndSwapInt64(once, 0, 1) success := atomic.CompareAndSwapInt64(once, 0, 1)
if !success { if !success {
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
return return
} }
} }
if server.options.Timeout > 0 { num := counter.add(1)
timer.Stop()
}
wg.Add(1)
num := atomic.AddInt64(connections, 1)
closeReason := "unknown reason" closeReason := "unknown reason"
defer func() { defer func() {
num := atomic.AddInt64(connections, -1) num := counter.done()
if num == 0 && server.options.Timeout > 0 {
timer.Reset(time.Duration(server.options.Timeout) * time.Second)
}
log.Printf( log.Printf(
"Connection closed by %s: %s, connections: %d/%d", "Connection closed by %s: %s, connections: %d/%d",
closeReason, r.RemoteAddr, num, server.options.MaxConnection, closeReason, r.RemoteAddr, num, server.options.MaxConnection,
...@@ -62,12 +50,10 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance ...@@ -62,12 +50,10 @@ func (server *Server) generateHandleWS(ctx context.Context, cancel context.Cance
if server.options.Once { if server.options.Once {
cancel() cancel()
} }
wg.Done()
}() }()
if int64(server.options.MaxConnection) != 0 { if int64(server.options.MaxConnection) != 0 {
if num > int64(server.options.MaxConnection) { if num > server.options.MaxConnection {
closeReason = "exceeding max number of connections" closeReason = "exceeding max number of connections"
return return
} }
......
...@@ -10,9 +10,8 @@ import ( ...@@ -10,9 +10,8 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"sync/atomic"
noesctmpl "text/template" noesctmpl "text/template"
"time"
"github.com/elazarl/go-bindata-assetfs" "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { ...@@ -81,12 +80,9 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
opt(opts) opt(opts)
} }
// wg and connections can be incosistent because they are handled nonatomically counter := newCounter(time.Duration(server.options.Timeout) * time.Second)
wg := new(sync.WaitGroup) // to wait all connections to be closed
connections := new(int64) // number of active connections
url := server.setupURL() url := server.setupURL()
handlers := server.setupHandlers(cctx, cancel, url, connections, wg) handlers := server.setupHandlers(cctx, cancel, url, counter)
srv, err := server.setupHTTPServer(handlers, url) srv, err := server.setupHTTPServer(handlers, url)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to setup an HTTP server") return errors.Wrapf(err, "failed to setup an HTTP server")
...@@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { ...@@ -138,11 +134,11 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
err = cctx.Err() err = cctx.Err()
} }
conn := atomic.LoadInt64(connections) conn := counter.count()
if conn > 0 { if conn > 0 {
log.Printf("Waiting for %d connections to be closed", conn) log.Printf("Waiting for %d connections to be closed", conn)
} }
wg.Wait() counter.wait()
return err return err
} }
...@@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL { ...@@ -162,7 +158,7 @@ func (server *Server) setupURL() *url.URL {
return &url.URL{Scheme: scheme, Host: host, Path: path} return &url.URL{Scheme: scheme, Host: host, Path: path}
} }
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler { func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, counter *counter) http.Handler {
staticFileHandler := http.FileServer( staticFileHandler := http.FileServer(
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"}, &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
) )
...@@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu ...@@ -184,7 +180,7 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
wsMux := http.NewServeMux() wsMux := http.NewServeMux()
wsMux.Handle("/", siteHandler) wsMux.Handle("/", siteHandler)
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg)) wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, counter))
siteHandler = http.Handler(wsMux) siteHandler = http.Handler(wsMux)
return server.wrapLogger(siteHandler) return server.wrapLogger(siteHandler)
......
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!