Commit 9b8d2d5e by Iwasaki Yudai

Reduce struct variables of server.Server

1 parent 21899e63
......@@ -8,7 +8,9 @@ import (
"log"
"net/http"
"net/url"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
......@@ -16,46 +18,63 @@ import (
"github.com/yudai/gotty/webtty"
)
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc) http.HandlerFunc {
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc {
once := new(int64)
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
if server.options.Timeout > 0 {
go func() {
select {
case <-timer.C:
cancel()
case <-ctx.Done():
}
}()
}
return func(w http.ResponseWriter, r *http.Request) {
if server.options.Once {
if atomic.LoadInt64(server.once) > 0 {
success := atomic.CompareAndSwapInt64(once, 0, 1)
if !success {
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
return
}
atomic.AddInt64(server.once, 1)
}
connections := atomic.AddInt64(server.connections, 1)
server.wsWG.Add(1)
server.stopTimer()
if server.options.Timeout > 0 {
timer.Stop()
}
wg.Add(1)
num := atomic.AddInt64(connections, 1)
closeReason := "unknown reason"
defer func() {
server.wsWG.Done()
connections := atomic.AddInt64(server.connections, -1)
if connections == 0 {
server.resetTimer()
num := atomic.AddInt64(connections, -1)
if num == 0 && server.options.Timeout > 0 {
timer.Reset(time.Duration(server.options.Timeout) * time.Second)
}
log.Printf(
"Connection closed by %s: %s, connections: %d/%d",
closeReason, r.RemoteAddr, connections, server.options.MaxConnection,
closeReason, r.RemoteAddr, num, server.options.MaxConnection,
)
if server.options.Once {
cancel()
}
wg.Done()
}()
log.Printf("New client connected: %s", r.RemoteAddr)
if int64(server.options.MaxConnection) != 0 {
if connections > int64(server.options.MaxConnection) {
if num > int64(server.options.MaxConnection) {
closeReason = "exceeding max number of connections"
return
}
}
log.Printf("New client connected: %s, connections: %d/%d", r.RemoteAddr, num, server.options.MaxConnection)
if r.Method != "GET" {
http.Error(w, "Method not allowed", 405)
return
......
......@@ -13,7 +13,6 @@ import (
"sync"
"sync/atomic"
noesctmpl "text/template"
"time"
"github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/websocket"
......@@ -29,18 +28,9 @@ type Server struct {
factory Factory
options *Options
srv *http.Server
upgrader *websocket.Upgrader
indexTemplate *template.Template
titleTemplate *noesctmpl.Template
titleVars map[string]interface{}
timer *time.Timer
wsWG sync.WaitGroup
url *url.URL // use URL()
connections *int64 // Use atomic operations
once *int64 // use atomic operations
}
// New creates a new instance of Server.
......@@ -51,7 +41,6 @@ func New(factory Factory, options *Options) (*Server, error) {
panic("index not found") // must be in bindata
}
if options.IndexFile != "" {
log.Printf("Using index file at " + options.IndexFile)
path := homedir.Expand(options.IndexFile)
indexData, err = ioutil.ReadFile(path)
if err != nil {
......@@ -68,9 +57,6 @@ func New(factory Factory, options *Options) (*Server, error) {
return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat)
}
connections := int64(0)
once := int64(0)
return &Server{
factory: factory,
options: options,
......@@ -82,8 +68,6 @@ func New(factory Factory, options *Options) (*Server, error) {
},
indexTemplate: indexTemplate,
titleTemplate: titleTemplate,
connections: &connections,
once: &once,
}, nil
}
......@@ -97,12 +81,18 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
opt(opts)
}
handlers := server.setupHandlers(cctx, cancel)
srv, err := server.setupHTTPServer(handlers)
// wg and connections can be incosistent because they are handled nonatomically
wg := new(sync.WaitGroup) // to wait all connections to be closed
connections := new(int64) // number of active connections
url := server.setupURL()
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
srv, err := server.setupHTTPServer(handlers, url)
if err != nil {
return errors.Wrapf(err, "failed to setup an HTTP server")
}
log.Printf("URL: %s", url.String())
if server.options.PermitWrite {
log.Printf("Permitting clients to write input to the PTY.")
}
......@@ -111,19 +101,6 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
log.Printf("Once option is provided, accepting only one client")
}
server.srv = srv
if server.options.Timeout > 0 {
server.timer = time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
go func() {
select {
case <-server.timer.C:
cancel()
case <-cctx.Done():
}
}()
}
listenErr := make(chan error, 1)
go func() {
if server.options.EnableTLS {
......@@ -161,21 +138,35 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
err = cctx.Err()
}
conn := atomic.LoadInt64(server.connections)
conn := atomic.LoadInt64(connections)
if conn > 0 {
log.Printf("Waiting for %d connections to be closed", conn)
}
server.wsWG.Wait()
wg.Wait()
return err
}
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc) http.Handler {
func (server *Server) setupURL() *url.URL {
host := net.JoinHostPort(server.options.Address, server.options.Port)
scheme := "http"
path := "/"
if server.options.EnableRandomUrl {
path = "/" + randomstring.Generate(server.options.RandomUrlLength) + "/"
}
if server.options.EnableTLS {
scheme = "https"
}
return &url.URL{Scheme: scheme, Host: host, Path: path}
}
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
staticFileHandler := http.FileServer(
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
)
url := server.URL()
var siteMux = http.NewServeMux()
siteMux.HandleFunc(url.Path, server.handleIndex)
siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler))
......@@ -193,16 +184,13 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
wsMux := http.NewServeMux()
wsMux.Handle("/", siteHandler)
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel))
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
siteHandler = http.Handler(wsMux)
return server.wrapLogger(siteHandler)
}
func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error) {
url := server.URL()
log.Printf("URL: %s", url.String())
func (server *Server) setupHTTPServer(handler http.Handler, url *url.URL) (*http.Server, error) {
srv := &http.Server{
Addr: url.Host,
Handler: handler,
......@@ -219,22 +207,6 @@ func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error
return srv, nil
}
func (server *Server) URL() *url.URL {
if server.url == nil {
host := net.JoinHostPort(server.options.Address, server.options.Port)
path := ""
if server.options.EnableRandomUrl {
path += "/" + randomstring.Generate(server.options.RandomUrlLength)
}
scheme := "http"
if server.options.EnableTLS {
scheme = "https"
}
server.url = &url.URL{Scheme: scheme, Host: host, Path: path + "/"}
}
return server.url
}
func (server *Server) tlsConfig() (*tls.Config, error) {
caFile := homedir.Expand(server.options.TLSCACrtFile)
caCert, err := ioutil.ReadFile(caFile)
......
package server
import (
"time"
)
func (server *Server) stopTimer() {
if server.options.Timeout > 0 {
server.timer.Stop()
}
}
func (server *Server) resetTimer() {
if server.options.Timeout > 0 {
server.timer.Reset(time.Duration(server.options.Timeout) * time.Second)
}
}
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!