package websocket import ( "encoding/json" "fmt" "github.com/gorilla/websocket" "go.uber.org/zap" "io" "time" ) type writer struct { conn *websocket.Conn channel chan ServerMessage readNotifications <-chan time.Duration timer *time.Ticker logger *zap.Logger } func (w *writer) act() { defer w.gracefulShutdown() w.logger.Debug("Starting up") w.timer = time.NewTicker(PingDelay) for { select { case _, open := <-w.readNotifications: if open { w.logger.Debug("Received reader read, extending ping") w.timer.Reset(PingDelay) } else { w.logger.Debug("Received reader close, shutting down") w.readNotifications = nil // bye bye, we'll graceful shutdown because we deferred it return } case raw := <-w.channel: switch msg := raw.(type) { case SocketClosed: w.logger.Debug("Received close message, forwarding and shutting down", zap.Object("msg", msg)) w.sendClose(msg) // bye bye, we'll graceful shutdown because we deferred it return default: w.logger.Debug("Received message, forwarding", zap.Object("msg", msg)) w.send(msg) } case <-w.timer.C: w.sendPing() } w.logger.Debug("Awakening handled, resuming listening") } } func (w *writer) send(msg ServerMessage) { writer, err := w.conn.NextWriter(websocket.TextMessage) if err != nil { w.logger.Error("error while getting writer from connection", zap.Error(err)) return } defer func(writer io.WriteCloser) { err := writer.Close() if err != nil { w.logger.Error("error while closing writer to send message", zap.Error(err)) } }(writer) payload, err := json.Marshal(msg) if err != nil { w.logger.Error("error while rendering message payload to JSON", zap.Error(err)) return } if len(payload) == 2 { // This is an empty JSON message. We can leave it out. _, err = fmt.Fprintf(writer, "%s!", msg.ServerType()) if err != nil { w.logger.Error("error while writing command-only message", zap.Error(err)) } } else { // Because we need to send this, we put in a space instead of an exclamation mark. _, err = fmt.Fprintf(writer, "%s %s", msg.ServerType(), payload) if err != nil { w.logger.Error("error while writing command-only message", zap.Error(err)) } } } func (w *writer) sendClose(msg SocketClosed) { w.logger.Debug("Shutting down the writer channel") close(w.channel) w.channel = nil w.logger.Debug("Writing close message") err := w.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(int(msg.Code), msg.Text), time.Now().Add(ControlTimeLimit)) if err != nil { w.logger.Error("Error while sending close", zap.Error(err)) } } func (w *writer) sendPing() { w.logger.Debug("Sending ping") err := w.conn.WriteControl(websocket.PingMessage, []byte("are you still there?"), time.Now().Add(ControlTimeLimit)) if err != nil { w.logger.Error("Error while sending ping", zap.Error(err)) } } // gracefulShutdown causes the writer to wait for the close handshake to finish and then shut down. // It waits for the reader's readNotifications to close, indicating that it has also shut down, and for the channel to // receive a SocketClosed message indicating that the client has shut down. // During this time, the writer ignores all other messages from the channel and sends no pings. func (w *writer) gracefulShutdown() { defer w.finalShutdown() w.logger.Debug("Waiting for all channels to shut down") w.timer.Stop() for { if w.channel == nil && w.readNotifications == nil { w.logger.Debug("All channels closed, beginning final shutdown") // all done, we outta here, let the defer pick up the final shutdown return } select { case _, open := <-w.readNotifications: if !open { w.logger.Debug("Received reader close while shutting down") w.readNotifications = nil } case raw := <-w.channel: switch msg := raw.(type) { case SocketClosed: w.logger.Debug("Received close message from channel while shutting down, forwarding", zap.Object("msg", msg)) w.sendClose(msg) default: w.logger.Debug("Ignoring non-close message while shutting down", zap.Object("msg", msg)) } } } } func (w *writer) finalShutdown() { w.logger.Debug("Closing WebSocket connection") err := w.conn.Close() if err != nil { w.logger.Error("Received an error while closing", zap.Error(err)) } w.logger.Debug("Shut down") }