Commit 9b8d2d5e by Iwasaki Yudai

Reduce struct variables of server.Server

1 parent 21899e63
...@@ -8,7 +8,9 @@ import ( ...@@ -8,7 +8,9 @@ import (
"log" "log"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"sync/atomic" "sync/atomic"
"time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pkg/errors" "github.com/pkg/errors"
...@@ -16,46 +18,63 @@ import ( ...@@ -16,46 +18,63 @@ import (
"github.com/yudai/gotty/webtty" "github.com/yudai/gotty/webtty"
) )
func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc) http.HandlerFunc { func (server *Server) generateHandleWS(ctx context.Context, cancel context.CancelFunc, connections *int64, wg *sync.WaitGroup) http.HandlerFunc {
once := new(int64)
timer := time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
if server.options.Timeout > 0 {
go func() {
select {
case <-timer.C:
cancel()
case <-ctx.Done():
}
}()
}
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if server.options.Once { if server.options.Once {
if atomic.LoadInt64(server.once) > 0 { success := atomic.CompareAndSwapInt64(once, 0, 1)
if !success {
http.Error(w, "Server is shutting down", http.StatusServiceUnavailable) http.Error(w, "Server is shutting down", http.StatusServiceUnavailable)
return return
} }
atomic.AddInt64(server.once, 1)
} }
connections := atomic.AddInt64(server.connections, 1)
server.wsWG.Add(1) if server.options.Timeout > 0 {
server.stopTimer() timer.Stop()
}
wg.Add(1)
num := atomic.AddInt64(connections, 1)
closeReason := "unknown reason" closeReason := "unknown reason"
defer func() { defer func() {
server.wsWG.Done() num := atomic.AddInt64(connections, -1)
if num == 0 && server.options.Timeout > 0 {
connections := atomic.AddInt64(server.connections, -1) timer.Reset(time.Duration(server.options.Timeout) * time.Second)
if connections == 0 {
server.resetTimer()
} }
log.Printf( log.Printf(
"Connection closed by %s: %s, connections: %d/%d", "Connection closed by %s: %s, connections: %d/%d",
closeReason, r.RemoteAddr, connections, server.options.MaxConnection, closeReason, r.RemoteAddr, num, server.options.MaxConnection,
) )
if server.options.Once { if server.options.Once {
cancel() cancel()
} }
wg.Done()
}() }()
log.Printf("New client connected: %s", r.RemoteAddr)
if int64(server.options.MaxConnection) != 0 { if int64(server.options.MaxConnection) != 0 {
if connections > int64(server.options.MaxConnection) { if num > int64(server.options.MaxConnection) {
closeReason = "exceeding max number of connections" closeReason = "exceeding max number of connections"
return return
} }
} }
log.Printf("New client connected: %s, connections: %d/%d", r.RemoteAddr, num, server.options.MaxConnection)
if r.Method != "GET" { if r.Method != "GET" {
http.Error(w, "Method not allowed", 405) http.Error(w, "Method not allowed", 405)
return return
......
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
noesctmpl "text/template" noesctmpl "text/template"
"time"
"github.com/elazarl/go-bindata-assetfs" "github.com/elazarl/go-bindata-assetfs"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
...@@ -29,18 +28,9 @@ type Server struct { ...@@ -29,18 +28,9 @@ type Server struct {
factory Factory factory Factory
options *Options options *Options
srv *http.Server upgrader *websocket.Upgrader
upgrader *websocket.Upgrader
indexTemplate *template.Template indexTemplate *template.Template
titleTemplate *noesctmpl.Template titleTemplate *noesctmpl.Template
titleVars map[string]interface{}
timer *time.Timer
wsWG sync.WaitGroup
url *url.URL // use URL()
connections *int64 // Use atomic operations
once *int64 // use atomic operations
} }
// New creates a new instance of Server. // New creates a new instance of Server.
...@@ -51,7 +41,6 @@ func New(factory Factory, options *Options) (*Server, error) { ...@@ -51,7 +41,6 @@ func New(factory Factory, options *Options) (*Server, error) {
panic("index not found") // must be in bindata panic("index not found") // must be in bindata
} }
if options.IndexFile != "" { if options.IndexFile != "" {
log.Printf("Using index file at " + options.IndexFile)
path := homedir.Expand(options.IndexFile) path := homedir.Expand(options.IndexFile)
indexData, err = ioutil.ReadFile(path) indexData, err = ioutil.ReadFile(path)
if err != nil { if err != nil {
...@@ -68,9 +57,6 @@ func New(factory Factory, options *Options) (*Server, error) { ...@@ -68,9 +57,6 @@ func New(factory Factory, options *Options) (*Server, error) {
return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat) return nil, errors.Wrapf(err, "failed to parse window title format `%s`", options.TitleFormat)
} }
connections := int64(0)
once := int64(0)
return &Server{ return &Server{
factory: factory, factory: factory,
options: options, options: options,
...@@ -82,8 +68,6 @@ func New(factory Factory, options *Options) (*Server, error) { ...@@ -82,8 +68,6 @@ func New(factory Factory, options *Options) (*Server, error) {
}, },
indexTemplate: indexTemplate, indexTemplate: indexTemplate,
titleTemplate: titleTemplate, titleTemplate: titleTemplate,
connections: &connections,
once: &once,
}, nil }, nil
} }
...@@ -97,12 +81,18 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { ...@@ -97,12 +81,18 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
opt(opts) opt(opts)
} }
handlers := server.setupHandlers(cctx, cancel) // wg and connections can be incosistent because they are handled nonatomically
srv, err := server.setupHTTPServer(handlers) wg := new(sync.WaitGroup) // to wait all connections to be closed
connections := new(int64) // number of active connections
url := server.setupURL()
handlers := server.setupHandlers(cctx, cancel, url, connections, wg)
srv, err := server.setupHTTPServer(handlers, url)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to setup an HTTP server") return errors.Wrapf(err, "failed to setup an HTTP server")
} }
log.Printf("URL: %s", url.String())
if server.options.PermitWrite { if server.options.PermitWrite {
log.Printf("Permitting clients to write input to the PTY.") log.Printf("Permitting clients to write input to the PTY.")
} }
...@@ -111,19 +101,6 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { ...@@ -111,19 +101,6 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
log.Printf("Once option is provided, accepting only one client") log.Printf("Once option is provided, accepting only one client")
} }
server.srv = srv
if server.options.Timeout > 0 {
server.timer = time.NewTimer(time.Duration(server.options.Timeout) * time.Second)
go func() {
select {
case <-server.timer.C:
cancel()
case <-cctx.Done():
}
}()
}
listenErr := make(chan error, 1) listenErr := make(chan error, 1)
go func() { go func() {
if server.options.EnableTLS { if server.options.EnableTLS {
...@@ -161,21 +138,35 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error { ...@@ -161,21 +138,35 @@ func (server *Server) Run(ctx context.Context, options ...RunOption) error {
err = cctx.Err() err = cctx.Err()
} }
conn := atomic.LoadInt64(server.connections) conn := atomic.LoadInt64(connections)
if conn > 0 { if conn > 0 {
log.Printf("Waiting for %d connections to be closed", conn) log.Printf("Waiting for %d connections to be closed", conn)
} }
server.wsWG.Wait() wg.Wait()
return err return err
} }
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc) http.Handler { func (server *Server) setupURL() *url.URL {
host := net.JoinHostPort(server.options.Address, server.options.Port)
scheme := "http"
path := "/"
if server.options.EnableRandomUrl {
path = "/" + randomstring.Generate(server.options.RandomUrlLength) + "/"
}
if server.options.EnableTLS {
scheme = "https"
}
return &url.URL{Scheme: scheme, Host: host, Path: path}
}
func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFunc, url *url.URL, connections *int64, wg *sync.WaitGroup) http.Handler {
staticFileHandler := http.FileServer( staticFileHandler := http.FileServer(
&assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"}, &assetfs.AssetFS{Asset: Asset, AssetDir: AssetDir, Prefix: "static"},
) )
url := server.URL()
var siteMux = http.NewServeMux() var siteMux = http.NewServeMux()
siteMux.HandleFunc(url.Path, server.handleIndex) siteMux.HandleFunc(url.Path, server.handleIndex)
siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler)) siteMux.Handle(url.Path+"js/", http.StripPrefix(url.Path, staticFileHandler))
...@@ -193,16 +184,13 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu ...@@ -193,16 +184,13 @@ func (server *Server) setupHandlers(ctx context.Context, cancel context.CancelFu
wsMux := http.NewServeMux() wsMux := http.NewServeMux()
wsMux.Handle("/", siteHandler) wsMux.Handle("/", siteHandler)
wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel)) wsMux.HandleFunc(url.Path+"ws", server.generateHandleWS(ctx, cancel, connections, wg))
siteHandler = http.Handler(wsMux) siteHandler = http.Handler(wsMux)
return server.wrapLogger(siteHandler) return server.wrapLogger(siteHandler)
} }
func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error) { func (server *Server) setupHTTPServer(handler http.Handler, url *url.URL) (*http.Server, error) {
url := server.URL()
log.Printf("URL: %s", url.String())
srv := &http.Server{ srv := &http.Server{
Addr: url.Host, Addr: url.Host,
Handler: handler, Handler: handler,
...@@ -219,22 +207,6 @@ func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error ...@@ -219,22 +207,6 @@ func (server *Server) setupHTTPServer(handler http.Handler) (*http.Server, error
return srv, nil return srv, nil
} }
func (server *Server) URL() *url.URL {
if server.url == nil {
host := net.JoinHostPort(server.options.Address, server.options.Port)
path := ""
if server.options.EnableRandomUrl {
path += "/" + randomstring.Generate(server.options.RandomUrlLength)
}
scheme := "http"
if server.options.EnableTLS {
scheme = "https"
}
server.url = &url.URL{Scheme: scheme, Host: host, Path: path + "/"}
}
return server.url
}
func (server *Server) tlsConfig() (*tls.Config, error) { func (server *Server) tlsConfig() (*tls.Config, error) {
caFile := homedir.Expand(server.options.TLSCACrtFile) caFile := homedir.Expand(server.options.TLSCACrtFile)
caCert, err := ioutil.ReadFile(caFile) caCert, err := ioutil.ReadFile(caFile)
......
package server
import (
"time"
)
func (server *Server) stopTimer() {
if server.options.Timeout > 0 {
server.timer.Stop()
}
}
func (server *Server) resetTimer() {
if server.options.Timeout > 0 {
server.timer.Reset(time.Duration(server.options.Timeout) * time.Second)
}
}
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!